在 as_strided 之前进行整形以进行优化

如何解决在 as_strided 之前进行整形以进行优化

def forward(x,f,s):
    B,H,W,C = x.shape # e.g. 64,16,3
    Fh,Fw,C,_ = f.shape # e.g. 4,4,3,3 
    # C is redeclared to emphasise that the dimension is the same
    
    Sh,Sw = s # e.g. 2,2

    strided_shape = B,1 + (H - Fh) // Sh,1 + (W - Fw) // Sw,Fh,C

    x = as_strided(x,strided_shape,strides=(
        x.strides[0],Sh * x.strides[1],Sw * x.strides[2],x.strides[1],x.strides[2],x.strides[3]),)

    # print(x.flags,f.flags)

    # The reshaping changes the einsum from 'wxyijk,ijkd' to 'wxyz,zd->wxyd'
    f = f.reshape(-1,f.shape[-1])
    x = x.reshape(*x.shape[:3],-1) # Bottleneck!
    
    return np.einsum('wxyz,zd->wxyd',x,optimize='optimal')

(相反,变体没有重塑使用return np.einsum('wxyijk,ijkd->wxyd',f)

作为参考,以下是重塑前 xf 的标志:

x.flags:

C_CONTIGUOUS : False
F_CONTIGUOUS : False
OWNDATA : False
WRITEABLE : True
ALIGNED : True
WRITEBACKIFCOPY : False
UPDATEIFCOPY : False


f.flags:

C_CONTIGUOUS : True
F_CONTIGUOUS : False
OWNDATA : True
WRITEABLE : True
ALIGNED : True
WRITEBACKIFCOPY : False
UPDATEIFCOPY : False

有趣的是,例程中的主要瓶颈不是einsum,而是x 的重塑(扁平化)。我知道 f 不会遇到这样的问题,因为它的内存是 C 连续的,所以重塑相当于在不更改数据的情况下进行快速内部修改 - 但由于 x 不是 C 连续的(并且不拥有它的数据,就此而言),reshape 的成本要高得多,因为它涉及经常更改数据/获取非缓存对齐的数据。这反过来又是 as_strided 上执行的 x 函数的结果 - 步幅的修改必须以扰乱自然顺序的方式进行。 (仅供参考,as_strided 非常快,无论传递给它什么步幅都应该很快)

有没有办法在不产生瓶颈的情况下达到相同的结果?也许通过在使用 x 之前重塑 as_strided


另请注意,对于几乎 100% 的应用程序: B: [1-64],W: [1-60],C: [1-8] Fh,Fw: [1-12]

我还在这里包含了一些图表,用于随着张量维度 B(批量大小)以及我设备上的 H,W(图像大小)的变化而变化的时间(如你可以看到,涉及到 reshape 的那一个已经可以与 Tensorflow 竞争了):

Variation with batch size

Variation with image size


编辑:一个有趣的发现 - 重塑算法在 CPU 上以 5 倍的系数击败非重塑算法,但是当我使用 GPU(即使用 CuPy 而不是 NumPy)时,两种算法同样快(大约是 TensorFlow 的前向传递速度的两倍)

解决方法

由于您提到的原因(在非连续数组上复制),跨步数组的重新整形有点昂贵,但没有您想象的那么昂贵。 np.einsum 实际上可能是您的应用程序中的瓶颈,具体取决于张量大小。如 Convolutional layer in Python using Numpy 中所述,np.tensordot 可以很好地替代 np.einsum

举个简单的例子:

x = np.arange(64*221*221*3).reshape((64,221,3))
f = np.arange(4*4*3*5).reshape((4,4,3,5))
s = (2,2)

B,H,W,C = x.shape # e.g. 64,16,3
Fh,Fw,C,_ = f.shape # e.g. 4,3 
Sh,Sw = s # e.g. 2,2
strided_shape = B,1 + (H - Fh) // Sh,1 + (W - Fw) // Sw,Fh,C
print(strided_shape)
# (64,109,3)

初始化变量后,我们可以测试代码部分的时序

%timeit x_strided = as_strided(x,strided_shape,strides=(x.strides[0],Sh * x.strides[1],Sw * x.strides[2],x.strides[1],x.strides[2],x.strides[3]),)
>>> 7.11 µs ± 118 ns per loop (mean ± std. dev. of 7 runs,100000 loops each)

%timeit f_reshaped = f.reshape(-1,f.shape[-1])
>>> 450 ns ± 7.43 ns per loop (mean ± std. dev. of 7 runs,1000000 loops each)

%timeit x_reshaped = x_strided.reshape(*x_strided.shape[:3],-1) # Bottleneck!
>>> 94.6 ms ± 896 µs per loop (mean ± std. dev. of 7 runs,10 loops each)

# einsum without reshape
%timeit np.einsum('wxy...,...d->wxyd',x_strided,f,optimize='optimal')
>>> 809 ms ± 1.1 ms per loop (mean ± std. dev. of 7 runs,1 loop each)

# einsum with reshape
%%timeit
f_reshaped = f.reshape(-1,f.shape[-1])
x_reshaped = x_strided.reshape(*x_strided.shape[:3],-1) # Bottleneck!
k = np.einsum('wxyz,zd->wxyd',x_reshaped,f_reshaped,optimize='optimal')
>>> 549 ms ± 3.05 ms per loop (mean ± std. dev. of 7 runs,1 loop each)

# tensordot without reshape
%timeit k = np.tensordot(x_strided,axes=3)
>>> 271 ms ± 4.89 ms per loop (mean ± std. dev. of 7 runs,1 loop each)

# tensordot with reshape
%%timeit
f_reshaped = f.reshape(-1,-1) # Bottleneck!
k = np.tensordot(x_reshaped,axes=(3,0))
>>> 266 ms ± 3.15 ms per loop (mean ± std. dev. of 7 runs,1 loop each)

我在您的代码中使用张量大小得到了类似的结果(即 64、16、16、3 和 4、4、3、3)。

如您所见,调整大小操作存在开销,但由于连续数据,它使矩阵操作更快。请注意,结果会因 CPU 速度、CPU 架构/代等而异。

版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。

相关推荐


依赖报错 idea导入项目后依赖报错,解决方案:https://blog.csdn.net/weixin_42420249/article/details/81191861 依赖版本报错:更换其他版本 无法下载依赖可参考:https://blog.csdn.net/weixin_42628809/a
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下 2021-12-03 13:33:33.927 ERROR 7228 [ main] o.s.b.d.LoggingFailureAnalysisReporter : *************************** APPL
错误1:gradle项目控制台输出为乱码 # 解决方案:https://blog.csdn.net/weixin_43501566/article/details/112482302 # 在gradle-wrapper.properties 添加以下内容 org.gradle.jvmargs=-Df
错误还原:在查询的过程中,传入的workType为0时,该条件不起作用 <select id="xxx"> SELECT di.id, di.name, di.work_type, di.updated... <where> <if test=&qu
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct redisServer’没有名为‘server_cpulist’的成员 redisSetCpuAffinity(server.server_cpulist); ^ server.c: 在函数‘hasActiveC
解决方案1 1、改项目中.idea/workspace.xml配置文件,增加dynamic.classpath参数 2、搜索PropertiesComponent,添加如下 <property name="dynamic.classpath" value="tru
删除根组件app.vue中的默认代码后报错:Module Error (from ./node_modules/eslint-loader/index.js): 解决方案:关闭ESlint代码检测,在项目根目录创建vue.config.js,在文件中添加 module.exports = { lin
查看spark默认的python版本 [root@master day27]# pyspark /home/software/spark-2.3.4-bin-hadoop2.7/conf/spark-env.sh: line 2: /usr/local/hadoop/bin/hadoop: No s
使用本地python环境可以成功执行 import pandas as pd import matplotlib.pyplot as plt # 设置字体 plt.rcParams['font.sans-serif'] = ['SimHei'] # 能正确显示负号 p
错误1:Request method ‘DELETE‘ not supported 错误还原:controller层有一个接口,访问该接口时报错:Request method ‘DELETE‘ not supported 错误原因:没有接收到前端传入的参数,修改为如下 参考 错误2:cannot r
错误1:启动docker镜像时报错:Error response from daemon: driver failed programming external connectivity on endpoint quirky_allen 解决方法:重启docker -> systemctl r
错误1:private field ‘xxx‘ is never assigned 按Altʾnter快捷键,选择第2项 参考:https://blog.csdn.net/shi_hong_fei_hei/article/details/88814070 错误2:启动时报错,不能找到主启动类 #
报错如下,通过源不能下载,最后警告pip需升级版本 Requirement already satisfied: pip in c:\users\ychen\appdata\local\programs\python\python310\lib\site-packages (22.0.4) Coll
错误1:maven打包报错 错误还原:使用maven打包项目时报错如下 [ERROR] Failed to execute goal org.apache.maven.plugins:maven-resources-plugin:3.2.0:resources (default-resources)
错误1:服务调用时报错 服务消费者模块assess通过openFeign调用服务提供者模块hires 如下为服务提供者模块hires的控制层接口 @RestController @RequestMapping("/hires") public class FeignControl
错误1:运行项目后报如下错误 解决方案 报错2:Failed to execute goal org.apache.maven.plugins:maven-compiler-plugin:3.8.1:compile (default-compile) on project sb 解决方案:在pom.
参考 错误原因 过滤器或拦截器在生效时,redisTemplate还没有注入 解决方案:在注入容器时就生效 @Component //项目运行时就注入Spring容器 public class RedisBean { @Resource private RedisTemplate<String
使用vite构建项目报错 C:\Users\ychen\work>npm init @vitejs/app @vitejs/create-app is deprecated, use npm init vite instead C:\Users\ychen\AppData\Local\npm-