首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >锈蚀中快速惯用的Floyd-Warshall算法

锈蚀中快速惯用的Floyd-Warshall算法
EN

Stack Overflow用户
提问于 2021-11-20 21:35:16
回答 2查看 671关注 0票数 16

我正在尝试在Rust中实现弗洛伊德-沃尔算法的一个相当快的版本。该算法在有向加权图中寻找所有顶点之间的最短路径。

算法的主要部分可以写成这样:

代码语言:javascript
复制
// dist[i][j] contains edge length between vertices [i] and [j]
// after the end of the execution it contains shortest path between [i] and [j]
fn floyd_warshall(dist: &mut [Vec<i32>]) {
    let n = dist.len();
    for i in 0..n {
        for j in 0..n {
            for k in 0..n {
                dist[j][k] = min(dist[j][k], dist[j][i] + dist[i][k]);
            }
        }
    }
}

这个实现很短,很容易理解,但是它的工作速度比类似的c++实现慢1.5倍。

正如我所理解的,问题是,在每个向量访问上,Rust检查索引是否在向量的边界内,并且增加了一些开销。

我用get_unchecked*函数重写了这个函数:

代码语言:javascript
复制
fn floyd_warshall_unsafe(dist: &mut [Vec<i32>]) {
    let n = dist.len();
    for i in 0..n {
        for j in 0..n {
            for k in 0..n {
                unsafe {
                    *dist[j].get_unchecked_mut(k) = min(
                        *dist[j].get_unchecked(k),
                        dist[j].get_unchecked(i) + dist[i].get_unchecked(k),
                    )
                }
            }
        }
    }
}

它真的开始运行1.5倍的速度(测试的全部代码)。

我没想到边界检查会增加这么多开销:

是否有可能在没有不安全的情况下以惯用的方式重写这段代码,这样它的工作速度和不安全版本一样快吗?例如,通过在代码中添加一些断言,是否有可能向编译器“证明”没有超出绑定的访问权限?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-11-21 22:48:01

经过一些实验,基于安德鲁的回答有关问题的评论中提出的想法,我找到了解决方案,其中:

  • 仍然使用相同的接口(例如,&mut [Vec<i32>]作为参数)
  • 不使用不安全
  • 比不安全版本快3-4倍
  • 很丑的:

代码如下:

代码语言:javascript
复制
fn floyd_warshall_fast(dist: &mut [Vec<i32>]) {
    let n = dist.len();
    for i in 0..n {
        for j in 0..n {
            if i == j {
                continue;
            }
            let (dist_j, dist_i) = if j < i {
                let (lo, hi) = dist.split_at_mut(i);
                (&mut lo[j][..n], &mut hi[0][..n])
            } else {
                let (lo, hi) = dist.split_at_mut(j);
                (&mut hi[0][..n], &mut lo[i][..n])
            };
            let dist_ji = dist_j[i];
            for k in 0..n {
                dist_j[k] = min(dist_j[k], dist_ji + dist_i[k]);
            }
        }
    }
}

里面有几个想法:

  • 我们计算dist_ji一次,因为它在大多数内部循环中不发生变化,编译器不需要考虑它。
  • 我们“证明”dist[i]dist[j]实际上是两个不同的向量。这是由这个丑陋的split_at_muti == j特例来完成的(我真的很想知道一个更简单的解决方案)。在此之后,我们可以将dist[i]dist[j]完全分开处理,例如,编译器可以将这个循环向量化,因为它知道数据不重叠。
  • 最后一个窍门是向编译器“证明”dist[i]dist[j]至少都有n元素。这是由[..n]在计算dist[i]dist[j]时完成的(例如,我们使用&mut lo[j][..n]而不是&mut lo[j])。之后,编译器就会明白,内部循环从不使用越界值,并删除检查。

有趣的是,只有当三种优化都被使用时,它才能大大加快速度。如果我们只使用其中的任何两个,编译器就无法优化它。

票数 5
EN

Stack Overflow用户

发布于 2021-11-21 10:46:27

乍一看,人们会希望这样做就足够了:

代码语言:javascript
复制
fn floyd_warshall(dist: &mut [Vec<i32>]) {
    let n = dist.len();
    for i in 0..n {
        assert!(i < dist.len());
        for j in 0..n {
            assert!(j < dist.len());
            assert!(i < dist[j].len());
            let v2 = dist[j][i];
            for k in 0..n {
                assert!(k < dist[i].len());
                assert!(k < dist[j].len());
                dist[j][k] = min(dist[j][k], v2 + dist[i][k]);
            }
        }
    }
}

添加断言是一个众所周知的技巧,可以使Rust优化器相信变量确实在范围内。然而,它在这里行不通。我们需要做的是在某种程度上使Rust编译器更清楚地认识到,这些循环是在边界内的,而不需要使用深奥的代码。

为此,我移到了David建议的2D数组中:

代码语言:javascript
复制
fn floyd_warshall<const N:usize>(mut dist: Box<[[i32; N]; N]>) -> Box<[[i32; N]; N]> {
    for i in 0..N {
        for j in 0..N {
            for k in 0..N {
                dist[j][k] = min(dist[j][k], dist[j][i] + dist[i][k]);
            }
        }
    }
    dist
}

这使用常量泛型( Rust的一个相对较新的特性)来指定堆上给定的2d数组的大小。就其本身而言,这一变化在我的机器上做得很好(比usafe快100毫秒,落后于不安全的20毫秒)。此外,如果您将v2计算移出k-循环,则如下所示:

代码语言:javascript
复制
fn floyd_warshall<const N:usize>(mut dist: Box<[[i32; N]; N]>) -> Box<[[i32; N]; N]> {
    for i in 0..N {
        for j in 0..N {
            let v2 = dist[j][i];
            for k in 0..N {
                dist[j][k] = min(dist[j][k], v2 + dist[i][k]);
            }
        }
    }
    dist
}

改进很大(在我的机器上从大约300毫秒提高到100毫秒)。同样的优化也适用于floyd_warshall_unsafe,使其在我的机器上平均达到100 my。当检查程序集(在#[inline(never)]上使用floyd_warshall)时,两者似乎都没有发生边界检查,而且两者在某种程度上看起来都是向量化的。虽然,我不是阅读集会的专家。

因为这是一个很热的循环(有最多三个边界检查),所以我并不惊讶性能会受到如此大的影响。不幸的是,在这种情况下,索引的使用非常复杂,从而阻止了assert技巧给您带来了一个简单的修复。在其他已知的情况下,需要进行断言检查以提高性能,但编译器无法充分使用这些信息。这里有一个这样的例子

这是操场和我的变化。

票数 5
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70050040

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档