首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >MATLAB:使用bsxfun加速离散化函数

MATLAB:使用bsxfun加速离散化函数
EN

Stack Overflow用户
提问于 2016-08-16 22:22:58
回答 2查看 398关注 0票数 0

对于当前的项目,我必须将准连续值离散到由一些预定义的分箱分辨率定义的分箱中。为此,我编写了一个函数,我希望它是高效的,因为它能够使用bsxfun处理标量输入和向量输入。然而,经过一些分析之后,我发现我的更大项目的几乎所有处理时间都是在这个函数中产生的,并且在这个函数中,主要是bsxfun部分需要时间,其次是min-query。长话短说,我正在寻找关于如何在MATLAB中更快地解决这个任务的建议。附注:我通常会传递50k个元素的向量。

代码如下:

代码语言:javascript
复制
function sampleNo = value2sample(value,bins)

%Make sure both vectors have orientations fitting bsxfun
value = value(:);
bins = bins(:)';

%Recover bin resolution (avoids passing another parameter)
delta = median(diff(bins));

%Calculate distance matrix between all combinations
dist = abs(bsxfun(@minus,value,bins));

%What we really want to know is the minimum distance per row
[minval,ind] = min(dist,[],2);

%Make sure we don't accidentally further process NaNs as 1st bin
ind(isnan(minval))=NaN;

sampleNo = ind;
sampleNo(minval>delta) = NaN;

end
EN

回答 2

Stack Overflow用户

发布于 2016-08-16 22:39:58

你的函数慢的原因是因为你在计算valuesbins的每个元素之间的距离,并将它们全部存储在一个数组中-如果有N个值和M个big,那么你将需要NM个元素来存储所有的距离,这可能是一个非常大的数字(例如,如果每个输入有50,000个元素,那么你需要输出数组中的25亿个元素)。

此外,由于您的存储箱是排序的(您没有说明这一点,但看起来您在代码中假定了这一点),您不需要计算从每个值到每个存储箱的距离。你可以变得更聪明

代码语言:javascript
复制
function ind = value2sample(value, bins)

    % Find median bin distance
    delta = median(diff(bins));

    % Bucket into 'nearest' bin by using midpoints
    bins = bins(:);
    mids = [-Inf; 0.5 * (bins(1:end-1) + bins(2:end))];

    [~, ind] = histc(value, mids);

    % Ensure that NaN values and points that aren't near any bin are returned as NaN
    ind(isnan(value)) = NaN;
    ind(abs(value - bins(ind)) > delta) = NaN;

end

在我的测试中,使用values = randn(10000, 1)bins = -50:50运行原始函数大约需要4.5毫秒,运行上面的代码需要485微秒,所以您可以获得大约10倍的加速(并且随着输入大小的增加,速度会更快)。

票数 1
EN

Stack Overflow用户

发布于 2016-08-17 03:10:28

感谢@Chris Taylor,我能够非常高效地解决这个问题。代码现在的运行速度几乎是以前的400倍。我必须对他的版本所做的唯一更改反映在下面的代码中。主要问题是discretize取代了histc (不再鼓励使用它)。

代码语言:javascript
复制
function ind = value2sample(value, bins)

% Make sure the vectors are standing
value = value(:);
bins = bins(:);

% Bucket into 'nearest' bin by using midpoints
mids = [eps; 0.5 * (bins(1:end-1) + bins(2:end))];

ind = discretize(value, mids);

唯一的问题是,在这个实现中,您的bin必须是非负的。除此之外,这段代码完全符合我的要求,包括当valueNaN或超出bins范围时,ind的大小与value相同,并包含NaNs

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/38977482

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档