首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >对Bellman-Ford算法使用带有numpy的矢量化

对Bellman-Ford算法使用带有numpy的矢量化
EN

Stack Overflow用户
提问于 2013-01-16 07:48:05
回答 1查看 791关注 0票数 2

我一直在尝试编写用于在图中寻找最短路径的Bellman Ford algoritm,虽然我已经有了一个有效的解决方案,但它运行得并不是很快,我相信如果我使用numpy而不是我目前的方法,它会更快。

这是我使用for循环的解决方案:

代码语言:javascript
复制
import os                    
file = open(os.path.dirname(os.path.realpath(__file__)) + "/g_small.txt")

vertices, edges = map(lambda x: int(x), file.readline().replace("\n", "").split(" "))

adjacency_list = [[] for k in xrange(vertices)]
for line in file.readlines():
    tail, head, weight = line.split(" ")
    adjacency_list[int(head)-1].append({"from" : int(tail), "weight" : int(weight)})

n = vertices

shortest_paths = []
s=2

cache = [[0 for k in xrange(vertices)] for j in xrange(vertices)]
cache[0][s] = 0

for v in range(0, vertices):
    if v != s:
    cache[0][v] = float("inf")

# this can be done with numpy I think?
for i in range(1, vertices):
    for v in range(0, vertices):
        adjacent_nodes = adjacency_list[v]

        least_adjacent_cost = float("inf")
        for node in adjacent_nodes:
            adjacent_cost = cache[i-1][node["from"]-1] + node["weight"]
            if adjacent_cost < least_adjacent_cost:
                least_adjacent_cost = adjacent_cost

        cache[i][v] = min(cache[i-1][v], least_adjacent_cost)

shortest_paths.append([s, cache[vertices-1]])

for path in shortest_paths:
    print(str(path[1]))

shortest_path = min(reduce(lambda x, y: x + y, map(lambda x: x[1], shortest_paths)))  
print("Shortest Path: " + str(shortest_path))  

输入文件如下所示的-> https://github.com/mneedham/algorithms2/blob/master/shortestpath/g_small.txt

除了大约一半的嵌套循环之外,它几乎是无趣的。我尝试使用numpy对其进行矢量化,但我不确定如何做,因为矩阵/2D数组在每次迭代时都会发生变化。

如果任何人有任何关于我需要做什么的想法,或者甚至是一些可以阅读的东西,那将是非常棒的。

==================

我写了一个更新的版本来考虑Jaime的评论:

代码语言:javascript
复制
s=0

def initialise_cache(vertices, s):
    cache = [0 for k in xrange(vertices)]
    cache[s] = 0

    for v in range(0, vertices):
        if v != s:
            cache[v] = float("inf")
    return cache    

cache = initialise_cache(vertices, s)

for i in range(1, vertices):
    previous_cache = deepcopy(cache)
    cache = initialise_cache(vertices, s)
    for v in range(0, vertices):
        adjacent_nodes = adjacency_list[v]

    least_adjacent_cost = float("inf")
    for node in adjacent_nodes:
        adjacent_cost = previous_cache[node["from"]-1] + node["weight"]
        if adjacent_cost < least_adjacent_cost:
            least_adjacent_cost = adjacent_cost

    cache[v] = min(previous_cache[v], least_adjacent_cost)

================

和另一个新版本,这次使用了矢量化:

代码语言:javascript
复制
def initialise_cache(vertices, s):
    cache = empty(vertices)
    cache[:] = float("inf")
    cache[s] = 0
    return cache    

adjacency_matrix = zeros((vertices, vertices))
adjacency_matrix[:] = float("inf")
for line in file.readlines():
    tail, head, weight = line.split(" ")
    adjacency_matrix[int(head)-1][int(tail)-1] = int(weight)    

n = vertices
shortest_paths = []
s=2

cache = initialise_cache(vertices, s)
for i in range(1, vertices):
    previous_cache = cache
    combined = (previous_cache.T + adjacency_matrix).min(axis=1)
    cache = minimum(previous_cache, combined)

shortest_paths.append([s, cache])
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2013-01-18 03:53:33

在遵循了Jaime的建议后,我最终得到了以下矢量化代码:

代码语言:javascript
复制
def initialise_cache(vertices, s):
    cache = empty(vertices)
    cache[:] = float("inf")
    cache[s] = 0
    return cache    

adjacency_matrix = zeros((vertices, vertices))
adjacency_matrix[:] = float("inf")
for line in file.readlines():
    tail, head, weight = line.split(" ")
    adjacency_matrix[int(head)-1][int(tail)-1] = int(weight)    

n = vertices
shortest_paths = []
s=2

cache = initialise_cache(vertices, s)
for i in range(1, vertices):
    previous_cache = cache
    combined = (previous_cache.T + adjacency_matrix).min(axis=1)
    cache = minimum(previous_cache, combined)

shortest_paths.append([s, cache])
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/14349084

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档