我刚刚开始学习如何使用Numba和CUDA编程,所以这个代码可能是非常错误的,但是我不明白为什么它不起作用。我正在尝试和N个不同的数组,它们的内容依赖于另一个数组。显示代码可能比下面的解释更好:
import numba as nb
from numba import cuda
import numpy as np
from math import exp, ceil
t0s = np.array([2.5,6.7,8.1,9.6,10.5])
threadsperblock = 32
blockspergrid = ceil(t0s.shape[0] / threadsperblock)
time = np.linspace(0,10,2000)
waveform = np.zeros_like(time)
total_waveform = np.zeros_like(waveform)
@cuda.jit(device=True)
def current(waveform, time, t0):
for i in range(waveform.shape[0]):
if time[i] > t0:
waveform[i] = 0
else:
waveform[i] = exp(time[i]-t0)
@cuda.jit
def total(time, waveform, total_waveform, t0s):
i = cuda.grid(1)
if i < t0s.shape[0]:
current(waveform, time, t0s[i])
for j in range(total_waveform.shape[0]):
total_waveform[j] += waveform[j]
total[blockspergrid, threadsperblock](time, waveform, total_waveform, t0s)不幸的是,total_waveform只包含第一个波形(就像它在t0s的第一个元素之后停止一样),我真的不明白为什么。救命!:)
发布于 2020-08-11 09:11:17
根据张贴的代码和本评论:
我正确的结果是一个包含5条上升指数曲线的数组,每条曲线的结尾都是
t0s[i]
看起来你可以大大简化代码,并得到期望的结果,前提是你的意思是
,我正确的结果是一个数组,包含--5上升指数曲线之和,每条曲线以
t0s[i]结尾。
当t0较大时,当每条曲线在小t处接近于零时,对于所有t0 > 0,在[0,t0)上,每条曲线总是非零的。如果我没有误解你的意图和代码,你能:
将function
current改为标量current waveform,这是一个中间结果,不需要存储如果你做了这三件事,你会得到这样的东西:
$ cat wavegoodbye.py
import numba as nb
from numba import cuda
import numpy as np
from math import exp, ceil
t0s = np.array([2.5,6.7,8.1,9.6,10.5])
time = np.linspace(0,10,2000)
total_waveform = np.zeros_like(time)
threadsperblock = 32
blockspergrid = ceil(total_waveform.shape[0] / threadsperblock)
@cuda.jit(device=True)
def current(time, t0):
if time > t0:
waveform = 0
else:
waveform = exp(time-t0)
return waveform
@cuda.jit
def total(time, total_waveform, t0s):
i = cuda.grid(1)
if i < total_waveform.shape[0]:
for j in range(t0s.shape[0]):
total_waveform[i] += current(time[i], t0s[j])
total[blockspergrid, threadsperblock](time, total_waveform, t0s)它是这样的:
$ ipython
Python 3.7.4 (default, Aug 13 2019, 20:35:49)
Type 'copyright', 'credits' or 'license' for more information
IPython 7.11.1 -- An enhanced Interactive Python. Type '?' for help.
In [1]: %run wavegoodbye.py
In [2]: import pylab as pl
In [3]: pl.plot(time, total_waveform)

我想你就是这么想的。
https://stackoverflow.com/questions/63350559
复制相似问题