首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >python实现中的Strassen算法错误

python实现中的Strassen算法错误
EN

Stack Overflow用户
提问于 2018-09-02 13:20:20
回答 1查看 2.3K关注 0票数 0

通过Strassen算法和Python 3中朴素嵌套的循环实现,我得到了不同的矩阵乘法输出。

代码:

代码语言:javascript
复制
def new_matrix(r, c):
    """Create a new matrix filled with zeros."""
    matrix = [[0 for row in range(r)] for col in range(c)]
    return matrix


def direct_multiply(x, y):
    if len(x[0]) != len(y):
        return "Multiplication is not possible!"
    else:
        p_matrix = new_matrix(len(x), len(y[0]))
        for i in range(len(x)):
            for j in range(len(y[0])):
                for k in range(len(y)):
                    p_matrix[i][j] += x[i][k] * y[k][i]
    return p_matrix


def split(matrix):
    """Split matrix into quarters."""
    a = b = c = d = matrix

    while len(a) > len(matrix)/2:
        a = a[:len(a)//2]
        b = b[:len(b)//2]
        c = c[len(c)//2:]
        d = d[len(d)//2:]

    while len(a[0]) > len(matrix[0])//2:
        for i in range(len(a[0])//2):
            a[i] = a[i][:len(a[i])//2]
            b[i] = b[i][len(b[i])//2:]
            c[i] = c[i][:len(c[i])//2]
            d[i] = d[i][len(d[i])//2:]

    return a, b, c, d


def add_matrix(a, b):
    if type(a) == int:
        d = a + b
    else:
        d = []
        for i in range(len(a)):
            c = []
            for j in range(len(a[0])):
                c.append(a[i][j] + b[i][j])
            d.append(c)
    return d


def subtract_matrix(a, b):
    if type(a) == int:
        d = a - b
    else:
        d = []
        for i in range(len(a)):
            c = []
            for j in range(len(a[0])):
                c.append(a[i][j] - b[i][j])
            d.append(c)
    return d


def strassen(x, y, n):
    # base case: 1x1 matrix
    if n == 1:
        z = [[0]]
        z[0][0] = x[0][0] * y[0][0]
        return z
    else:
        # split matrices into quarters
        a, b, c, d = split(x)
        e, f, g, h = split(y)

        # p1 = a*(f-h)
        p1 = strassen(a, subtract_matrix(f, h), n/2)

        # p2 = (a+b)*h
        p2 = strassen(add_matrix(a, b), h, n/2)

        # p3 = (c+d)*e
        p3 = strassen(add_matrix(c, d), e, n/2)

        # p4 = d*(g-e)
        p4 = strassen(d, subtract_matrix(g, e), n/2)

        # p5 = (a+d)*(e+h)
        p5 = strassen(add_matrix(a, d), add_matrix(e, h), n/2)

        # p6 = (b-d)*(g+h)
        p6 = strassen(subtract_matrix(b, d), add_matrix(g, h), n/2)

        # p7 = (a-c)*(e+f)
        p7 = strassen(subtract_matrix(a, c), add_matrix(e, f), n/2)

        z11 = add_matrix(subtract_matrix(add_matrix(p5, p4), p2), p6)

        z12 = add_matrix(p1, p2)

        z21 = add_matrix(p3, p4)

        z22 = add_matrix(subtract_matrix(subtract_matrix(p5, p3), p7), p1)

        z = new_matrix(len(z11)*2, len(z11)*2)
        for i in range(len(z11)):
            for j in range(len(z11)):
                z[i][j] = z11[i][j]
                z[i][j+len(z11)] = z12[i][j]
                z[i+len(z11)][j] = z21[i][j]
                z[i+len(z11)][j+len(z11)] = z22[i][j]

        return z


a = [[11,11,11,11],[22,22,22,22],[33,33,33,33],[44,44,44,44]]
b = [[101,181,119,113],[22,22,22,22],[33,33,33,33],[44,44,44,44]]

print(f"a = {a}")
print(f"b = {b}")

print(f"Using Strassen's algorithm:\na*b = {strassen(a, b, 4)}")

print(f"Using naive algorithm:\na*b = {direct_multiply(a, b)}")

输出:

$ python3 strassen.py A= [11、11、11、11、22、22、22、33、33、33、33、44、44、44 ] B= [101、181、119、113、22、22、22、33、33、33、44、44、44、44] 使用Strassen算法: a*b = [2200、3080、2398、2332、4400、6160、4796、4664、6600、9240、7194、6996、8800、12320、9592、9328] 使用朴素算法: a*b = [2200、2200、2200、2200、6160、6160、6160、6160、6160、7194、7194、7194、7194、9328、9328、9328、9328、9328]

有人能帮忙吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-09-02 13:25:23

它应该是

代码语言:javascript
复制
p_matrix[i][j] += x[i][k] * y[k][j]

direct_multiply函数中。这样就可以将一行中的kth元素与列上的kth元素相乘,并将其累加起来。就像你做矩阵乘法一样。

输出

代码语言:javascript
复制
a = [[11, 11, 11, 11], [22, 22, 22, 22], [33, 33, 33, 33], [44, 44, 44, 44]]
b = [[101, 181, 119, 113], [22, 22, 22, 22], [33, 33, 33, 33], [44, 44, 44, 44]]
Using Strassen's algorithm:
a*b = [[2200, 3080, 2398, 2332], [4400, 6160, 4796, 4664], [6600, 9240, 7194, 6996], [8800, 12320, 9592, 9328]]
Using naive algorithm:
a*b = [[2200, 3080, 2398, 2332], [4400, 6160, 4796, 4664], [6600, 9240, 7194, 6996], [8800, 12320, 9592, 9328]]
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52137431

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档