如何解决用Numba和CUDA求和的数组
我刚刚开始学习如何使用Numba和CUDA进行编程,因此此代码可能是非常错误的,但是我不明白为什么它不起作用。我正在尝试对N个不同的数组求和,其内容取决于另一个数组。显示代码可能比以下解释更好:
import numba as nb
from numba import cuda
import numpy as np
from math import exp,ceil
t0s = np.array([2.5,6.7,8.1,9.6,10.5])
threadsperblock = 32
blockspergrid = ceil(t0s.shape[0] / threadsperblock)
time = np.linspace(0,10,2000)
waveform = np.zeros_like(time)
total_waveform = np.zeros_like(waveform)
@cuda.jit(device=True)
def current(waveform,time,t0):
for i in range(waveform.shape[0]):
if time[i] > t0:
waveform[i] = 0
else:
waveform[i] = exp(time[i]-t0)
@cuda.jit
def total(time,waveform,total_waveform,t0s):
i = cuda.grid(1)
if i < t0s.shape[0]:
current(waveform,t0s[i])
for j in range(total_waveform.shape[0]):
total_waveform[j] += waveform[j]
total[blockspergrid,threadsperblock](time,t0s)
不幸的是,total_waveform
仅包含第一个波形(就像在t0s
的第一个元素之后停止一样),我真的不明白为什么。救命! :)
解决方法
基于已发布的代码和此注释:
我的正确结果将是一个包含5条上升的指数曲线的数组,每条曲线都以
t0s[i]
结尾
假设您的意思是,您似乎可以极大地简化代码并获得所需的结果
我的正确结果是一个数组,其中包含 5条上升的指数曲线的总和,每条曲线的终点为
t0s[i]
。
当t0较大时,每条曲线在小t处接近零,而对于所有t0> 0,每条曲线在[0,t0)上始终不为零。如果我没有误解您的意图和代码,您可以: / p>
- 将
current
更改为标量函数 - 一起消除
waveform
,这是不需要存储的中间结果 - 更改并行化策略,以便每个线程仅计算输出中的一个时间点(即,从原始代码中反转循环的顺序)。如果这样做,则不会出现内存争用或同步问题。
如果您做这三件事,您将得到如下信息:
$ cat wavegoodbye.py
import numba as nb
from numba import cuda
import numpy as np
from math import exp,ceil
t0s = np.array([2.5,6.7,8.1,9.6,10.5])
time = np.linspace(0,10,2000)
total_waveform = np.zeros_like(time)
threadsperblock = 32
blockspergrid = ceil(total_waveform.shape[0] / threadsperblock)
@cuda.jit(device=True)
def current(time,t0):
if time > t0:
waveform = 0
else:
waveform = exp(time-t0)
return waveform
@cuda.jit
def total(time,total_waveform,t0s):
i = cuda.grid(1)
if i < total_waveform.shape[0]:
for j in range(t0s.shape[0]):
total_waveform[i] += current(time[i],t0s[j])
total[blockspergrid,threadsperblock](time,t0s)
这样做:
$ ipython
Python 3.7.4 (default,Aug 13 2019,20:35:49)
Type 'copyright','credits' or 'license' for more information
IPython 7.11.1 -- An enhanced Interactive Python. Type '?' for help.
In [1]: %run wavegoodbye.py
In [2]: import pylab as pl
In [3]: pl.plot(time,total_waveform)
我想这就是你的想法。
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。