我正在尝试在Rust中实现弗洛伊德-沃尔算法的一个相当快的版本。该算法在有向加权图中寻找所有顶点之间的最短路径。
算法的主要部分可以写成这样:
// dist[i][j] contains edge length between vertices [i] and [j]
// after the end of the execution it contains shortest path between [i] and [j]
fn floyd_warshall(dist: &mut [Vec<i32>]) {
let n = dist.len();
for i in 0..n {
for j in 0..n {
for k in 0..n {
dist[j][k] = min(dist[j][k], dist[j][i] + dist[i][k]);
}
}
}
}这个实现很短,很容易理解,但是它的工作速度比类似的c++实现慢1.5倍。
正如我所理解的,问题是,在每个向量访问上,Rust检查索引是否在向量的边界内,并且增加了一些开销。
我用get_unchecked*函数重写了这个函数:
fn floyd_warshall_unsafe(dist: &mut [Vec<i32>]) {
let n = dist.len();
for i in 0..n {
for j in 0..n {
for k in 0..n {
unsafe {
*dist[j].get_unchecked_mut(k) = min(
*dist[j].get_unchecked(k),
dist[j].get_unchecked(i) + dist[i].get_unchecked(k),
)
}
}
}
}
}它真的开始运行1.5倍的速度(测试的全部代码)。
我没想到边界检查会增加这么多开销:
是否有可能在没有不安全的情况下以惯用的方式重写这段代码,这样它的工作速度和不安全版本一样快吗?例如,通过在代码中添加一些断言,是否有可能向编译器“证明”没有超出绑定的访问权限?
发布于 2021-11-21 22:48:01
经过一些实验,基于安德鲁的回答和有关问题的评论中提出的想法,我找到了解决方案,其中:
&mut [Vec<i32>]作为参数)代码如下:
fn floyd_warshall_fast(dist: &mut [Vec<i32>]) {
let n = dist.len();
for i in 0..n {
for j in 0..n {
if i == j {
continue;
}
let (dist_j, dist_i) = if j < i {
let (lo, hi) = dist.split_at_mut(i);
(&mut lo[j][..n], &mut hi[0][..n])
} else {
let (lo, hi) = dist.split_at_mut(j);
(&mut hi[0][..n], &mut lo[i][..n])
};
let dist_ji = dist_j[i];
for k in 0..n {
dist_j[k] = min(dist_j[k], dist_ji + dist_i[k]);
}
}
}
}里面有几个想法:
dist_ji一次,因为它在大多数内部循环中不发生变化,编译器不需要考虑它。dist[i]和dist[j]实际上是两个不同的向量。这是由这个丑陋的split_at_mut和i == j特例来完成的(我真的很想知道一个更简单的解决方案)。在此之后,我们可以将dist[i]和dist[j]完全分开处理,例如,编译器可以将这个循环向量化,因为它知道数据不重叠。dist[i]和dist[j]至少都有n元素。这是由[..n]在计算dist[i]和dist[j]时完成的(例如,我们使用&mut lo[j][..n]而不是&mut lo[j])。之后,编译器就会明白,内部循环从不使用越界值,并删除检查。有趣的是,只有当三种优化都被使用时,它才能大大加快速度。如果我们只使用其中的任何两个,编译器就无法优化它。
发布于 2021-11-21 10:46:27
乍一看,人们会希望这样做就足够了:
fn floyd_warshall(dist: &mut [Vec<i32>]) {
let n = dist.len();
for i in 0..n {
assert!(i < dist.len());
for j in 0..n {
assert!(j < dist.len());
assert!(i < dist[j].len());
let v2 = dist[j][i];
for k in 0..n {
assert!(k < dist[i].len());
assert!(k < dist[j].len());
dist[j][k] = min(dist[j][k], v2 + dist[i][k]);
}
}
}
}添加断言是一个众所周知的技巧,可以使Rust优化器相信变量确实在范围内。然而,它在这里行不通。我们需要做的是在某种程度上使Rust编译器更清楚地认识到,这些循环是在边界内的,而不需要使用深奥的代码。
为此,我移到了David建议的2D数组中:
fn floyd_warshall<const N:usize>(mut dist: Box<[[i32; N]; N]>) -> Box<[[i32; N]; N]> {
for i in 0..N {
for j in 0..N {
for k in 0..N {
dist[j][k] = min(dist[j][k], dist[j][i] + dist[i][k]);
}
}
}
dist
}这使用常量泛型( Rust的一个相对较新的特性)来指定堆上给定的2d数组的大小。就其本身而言,这一变化在我的机器上做得很好(比usafe快100毫秒,落后于不安全的20毫秒)。此外,如果您将v2计算移出k-循环,则如下所示:
fn floyd_warshall<const N:usize>(mut dist: Box<[[i32; N]; N]>) -> Box<[[i32; N]; N]> {
for i in 0..N {
for j in 0..N {
let v2 = dist[j][i];
for k in 0..N {
dist[j][k] = min(dist[j][k], v2 + dist[i][k]);
}
}
}
dist
}改进很大(在我的机器上从大约300毫秒提高到100毫秒)。同样的优化也适用于floyd_warshall_unsafe,使其在我的机器上平均达到100 my。当检查程序集(在#[inline(never)]上使用floyd_warshall)时,两者似乎都没有发生边界检查,而且两者在某种程度上看起来都是向量化的。虽然,我不是阅读集会的专家。
因为这是一个很热的循环(有最多三个边界检查),所以我并不惊讶性能会受到如此大的影响。不幸的是,在这种情况下,索引的使用非常复杂,从而阻止了assert技巧给您带来了一个简单的修复。在其他已知的情况下,需要进行断言检查以提高性能,但编译器无法充分使用这些信息。这里有一个这样的例子。
这是操场和我的变化。
https://stackoverflow.com/questions/70050040
复制相似问题