Attention注意力机制——ECANet以及加入到1DCNN网络方法

原文:https://arxiv.org/abs/1910.03151
代码:https://github.com/BangguWu/ECANet
论文题目:ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks

目录

引言

一、ECANet结构

 二、ECANet代码

三、将ECANet作为一个模块加入到CNN中 

1、要加入的CNN网络

 2、加入eca_block的语句

3、加入eca_block后的网络结构的代码(例如在第二层卷积层之后加入)


引言

ECANet是对SENet模块的改进,提出了一种不降维的局部跨信道交互策略(ECA模块)和自适应选择一维卷积核大小的方法,从而实现了性能上的提优。 

在给定输入特征的情况下,SE块首先对每个通道单独使用全局平均池化,然后使用两个具有非线性的完全连接(FC)层,然后使用一个Sigmoid函数来生成通道权值。两个FC层的设计是为了捕捉非线性的跨通道交互,其中包括降维来控制模型的复杂性。虽然该策略在后续的通道注意模块中得到了广泛的应用,但作者的实验研究表明,降维对通道注意预测带来了副作用,捕获所有通道之间的依赖是低效的,也是不必要的。

因此,提出了一种针对深度CNN的高效通道注意(ECA)模块,该模块避免了降维,有效捕获了跨通道交互的信息。如下图

一、ECANet结构

和SENet模块相比,ECANet在全局平均池化之后去除了全连接层,改用1*1卷积

在没有降维的通道全局平均池化之后,ECANet使用一维卷积来实现跨通道信息交互,而卷积核的大小k通过函数来自适应。

给定通道维度 C,卷积核大小 k 可以自适应地确定为:

odd为取奇数,\gamma = 2,b=1
 

 二、ECANet代码

 我用来处理一维信号,所以网络里的池化为1D

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, layers
import math


def eca_block(inputs, b=1, gama=2):
    # 输入特征图的通道数
    in_channel = inputs.shape[-1]

    # 根据公式计算自适应卷积核大小
    kernel_size = int(abs((math.log(in_channel, 2) + b) / gama))

    # 如果卷积核大小是偶数,就使用它
    if kernel_size % 2:
        kernel_size = kernel_size

    # 如果卷积核大小是奇数就变成偶数
    else:
        kernel_size = kernel_size + 1

    # [h,w,c]==>[None,c] 全局平均池化
    x = layers.GlobalAveragePooling1D()(inputs)

    # [None,c]==>[c,1]
    x = layers.Reshape(target_shape=(in_channel, 1))(x)

    # [c,1]==>[c,1]
    x = layers.Conv1D(filters=1, kernel_size=kernel_size, padding='same', use_bias=False)(x)

    # sigmoid激活
    x = tf.nn.sigmoid(x)

    # [c,1]==>[1,1,c]
    x = layers.Reshape((1, 1, in_channel))(x)

    # 结果和输入相乘
    outputs = layers.multiply([inputs, x])

    return outputs

三、将ECANet作为一个模块加入到CNN中 

1、要加入的CNN网络

from keras.layers import Conv1D, Dense, Dropout, BatchNormalization, MaxPooling1D, Activation, Flatten,Input
from keras.models import Model
import preprocess
from keras.callbacks import TensorBoard
import matplotlib.pyplot as plt
import numpy as np
from keras.regularizers import l2

# 数据路径
path = xxx
# 数据经过preprocess预处理
x_train, y_train, x_valid, y_valid, x_test, y_test = preprocess.prepro(d_path=path,length=length,
                                                                  number=number,
                                                                  normal=normal,
                                                                  rate=rate,
                                                                  enc=True, enc_step=28)

x_train, x_valid, x_test = x_train[:,:,np.newaxis], x_valid[:,:,np.newaxis], x_test[:,:,np.newaxis]

batch_size = 128
epochs = 20
num_classes = 10
length = 2048
BatchNorm = True # 是否批量归一化
number = 1000 # 每类样本的数量
normal = True # 是否标准化
rate = [0.7,0.2,0.1] # 测试集验证集划分比例

input_shape =x_train.shape[1:]

# 定义输入层,确定输入维度
input = Input(shape = input_shape)

# 卷积层1
x = Conv1D(filters=16, kernel_size=64, strides=16, padding='same',kernel_regularizer=l2(1e-4),input_shape=input_shape)(input)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = MaxPooling1D(pool_size=2)(x)

# 卷积层2
x = Conv1D(filters=32, kernel_size=3, strides=1, padding='same',kernel_regularizer=l2(1e-4))(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = MaxPooling1D(pool_size=2)(x)

# 卷积层3
x = Conv1D(filters=64, kernel_size=3, strides=1, padding='same',kernel_regularizer=l2(1e-4))(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = MaxPooling1D(pool_size=2)(x)

# 卷积层4
x = Conv1D(filters=64, kernel_size=3, strides=1, padding='same',kernel_regularizer=l2(1e-4))(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = MaxPooling1D(pool_size=2)(x)

# 卷积层5
x = Conv1D(filters=64, kernel_size=3, strides=1, padding='same',kernel_regularizer=l2(1e-4))(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = MaxPooling1D(pool_size=2)(x)

# 从卷积到全连接需要展平
x = Flatten()(x)

# 添加全连接层
x = Dense(units=100, activation='relu', kernel_regularizer=l2(1e-4))(x)

# 增加输出层
output = Dense(units=num_classes, activation='softmax', kernel_regularizer=l2(1e-4))(x)
model =Model(inputs = input,outputs = output)
model.compile(optimizer='Adam', loss='categorical_crossentropy',
              metrics=['accuracy'])

 2、加入eca_block的语句

# eca_block
eca = eca_block(x)
x = layers.add([x,eca])

3、加入eca_block后的网络结构的代码(例如在第二层卷积层之后加入)
 

# 定义输入层,确定输入维度
input = Input(shape = input_shape)
# 卷积层1
x = Conv1D(filters=16, kernel_size=64, strides=16, padding='same',kernel_regularizer=l2(1e-4),input_shape=input_shape)(input)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = MaxPooling1D(pool_size=2)(x)

# 卷积层2
x = Conv1D(filters=32, kernel_size=3, strides=1, padding='same',kernel_regularizer=l2(1e-4))(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = MaxPooling1D(pool_size=2)(x)

# eca_block
eca = eca_block(x)
x = layers.add([x,eca])

# 卷积层3
x = Conv1D(filters=64, kernel_size=3, strides=1, padding='same',kernel_regularizer=l2(1e-4))(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = MaxPooling1D(pool_size=2)(x)

# 卷积层4
x = Conv1D(filters=64, kernel_size=3, strides=1, padding='same',kernel_regularizer=l2(1e-4))(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = MaxPooling1D(pool_size=2)(x)

# 卷积层5
x = Conv1D(filters=64, kernel_size=3, strides=1, padding='same',kernel_regularizer=l2(1e-4))(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = MaxPooling1D(pool_size=2)(x)

# 从卷积到全连接需要展平
x = Flatten()(x)

# 添加全连接层
x = Dense(units=100, activation='relu', kernel_regularizer=l2(1e-4))(x)
# 增加输出层
output = Dense(units=num_classes, activation='softmax', kernel_regularizer=l2(1e-4))(x)
model =Model(inputs = input,outputs = output)
model.compile(optimizer='Adam', loss='categorical_crossentropy',
              metrics=['accuracy'])

 网络结构由keras中的Model方法构建,用来处理一维信号

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


学习编程是顺着互联网的发展潮流,是一件好事。新手如何学习编程?其实不难,不过在学习编程之前你得先了解你的目的是什么?这个很重要,因为目的决定你的发展方向、决定你的发展速度。
IT行业是什么工作做什么?IT行业的工作有:产品策划类、页面设计类、前端与移动、开发与测试、营销推广类、数据运营类、运营维护类、游戏相关类等,根据不同的分类下面有细分了不同的岗位。
女生学Java好就业吗?女生适合学Java编程吗?目前有不少女生学习Java开发,但要结合自身的情况,先了解自己适不适合去学习Java,不要盲目的选择不适合自己的Java培训班进行学习。只要肯下功夫钻研,多看、多想、多练
Can’t connect to local MySQL server through socket \'/var/lib/mysql/mysql.sock问题 1.进入mysql路径
oracle基本命令 一、登录操作 1.管理员登录 # 管理员登录 sqlplus / as sysdba 2.普通用户登录
一、背景 因为项目中需要通北京网络,所以需要连vpn,但是服务器有时候会断掉,所以写个shell脚本每五分钟去判断是否连接,于是就有下面的shell脚本。
BETWEEN 操作符选取介于两个值之间的数据范围内的值。这些值可以是数值、文本或者日期。
假如你已经使用过苹果开发者中心上架app,你肯定知道在苹果开发者中心的web界面,无法直接提交ipa文件,而是需要使用第三方工具,将ipa文件上传到构建版本,开...
下面的 SQL 语句指定了两个别名,一个是 name 列的别名,一个是 country 列的别名。**提示:**如果列名称包含空格,要求使用双引号或方括号:
在使用H5混合开发的app打包后,需要将ipa文件上传到appstore进行发布,就需要去苹果开发者中心进行发布。​
+----+--------------+---------------------------+-------+---------+
数组的声明并不是声明一个个单独的变量,比如 number0、number1、...、number99,而是声明一个数组变量,比如 numbers,然后使用 nu...
第一步:到appuploader官网下载辅助工具和iCloud驱动,使用前面创建的AppID登录。
如需删除表中的列,请使用下面的语法(请注意,某些数据库系统不允许这种在数据库表中删除列的方式):
前不久在制作win11pe,制作了一版,1.26GB,太大了,不满意,想再裁剪下,发现这次dism mount正常,commit或discard巨慢,以前都很快...
赛门铁克各个版本概览:https://knowledge.broadcom.com/external/article?legacyId=tech163829
实测Python 3.6.6用pip 21.3.1,再高就报错了,Python 3.10.7用pip 22.3.1是可以的
Broadcom Corporation (博通公司,股票代号AVGO)是全球领先的有线和无线通信半导体公司。其产品实现向家庭、 办公室和移动环境以及在这些环境...
发现个问题,server2016上安装了c4d这些版本,低版本的正常显示窗格,但红色圈出的高版本c4d打开后不显示窗格,
TAT:https://cloud.tencent.com/document/product/1340