首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >最优运输的Sinkhorn算法

最优运输的Sinkhorn算法
EN

Stack Overflow用户
提问于 2019-12-24 16:39:52
回答 1查看 391关注 0票数 1

我正在尝试编写Sinkhorn算法,尤其是当熵正则化的强度收敛到0时,我是否可以计算出两个度量之间的最优传输。

例如,让我们将$0;1$上的统一度量$U$传输到$1;2$上的统一度量$V$中。二次海岸的最优测度是$(x,x-1){#} U$。

让我们离散化$0;1$,度量$U$,$1;2$和度量$V$。使用辛克霍恩,我应该得到一个度量,使得支持度在行$y = x-1$的图形中。但它没有,所以我正在努力找出问题所在。我将向你展示我的代码和我的结果,也许有人可以帮助我。

代码语言:javascript
复制
import numpy as np
import math
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import matplotlib.colors as colors

#Parameters
N = 10                        #Step of the discritization of [0,1]
stop = 10**-3
Niter = 10**3


def Sinkhorn(C, mu, nu, lamb):
    # lam : strength of the entropic regularization

    #Initialization
    a1 = np.zeros(N)
    b1 = np.zeros(N)
    a2 = np.ones(N)
    b2 = np.ones(N)
    Iter = 0

    GammaB = np.exp(-lamb*C)

    #Sinkhorn
    while (np.linalg.norm(a2) > stop and  np.linalg.norm(b2) > stop and np.linalg.norm(a2) < 1/stop and np.linalg.norm(b2) < 1/stop and Iter < Niter and np.linalg.norm(a1-a2) + np.linalg.norm(b1-b2) > stop ):
        a1 = a2
        b1 = b2
        a2 = mu/(np.dot(GammaB,b1))
        b2 = nu/(np.dot(GammaB.T,a2))
        Iter +=1

    # Compute gamma_star
    Gamma = np.zeros((N,N))
    for i in range(N):
        for j in range(N):
            Gamma[i][j] = a2[i]*b2[j]*GammaB[i][j]
    Gamma /= Gamma.sum()

    return Gamma

## Test between uniform([0;1]) over uniform([1;2])

S = np.linspace(0,1,N, False)  #discritization of [0,1]
T = np.linspace(1,2,N,False)  #discritization of [1,2]

# Discretization of uniform([0;1])
U01 = np.ones(N)
Mass = np.sum(U01)
U01 = U01/Mass

# Discretization uniform([1;2])
U12 = np.ones(N)
Mass = np.sum(U12)
U12 = U12/Mass

# Cost function
X,Y = np.meshgrid(S,T)
C = (X-Y)**2               #Matrix of c[i,j]=(xi-yj)²

def plot_Sinkhorn_U01_U12():

    #plot optimal measure and convergence
    fig = plt.figure()
    for i in range(4):
        ax = fig.add_subplot(2, 2, i+1, projection='3d')
        Gamma_star = Sinkhorn(C, U01, U12, 1/10**i)
        ax.scatter(X, Y, Gamma_star, cmap='viridis', linewidth=0.5)
        plt.title("Gamma_bar({}) between uniform([0,1]) and uniform([1,2])".format(1/10**i))

    plt.show()

    plt.figure()
    for i in range(4):
        plt.subplot(2,2,i+1)
        Gamma_star = Sinkhorn(C, U01, U12, 1/10**i)
        plt.imshow(Gamma_star,interpolation='none')
        plt.title("Gamma_bar({}) between uniform([0,1]) and uniform([1,2])".format(1/10**i))

    plt.show()

    return

# The transport between U01 ans U12 is x -> x-1 so the support of gamma^* is contained in the graph of the function x -> (x,x+1) which is the line y = x+1

plot_Sinkhorn_U01_U12()

我得到了什么。

如前所述,这是我考虑1/lamb时的代码输出。

这要好得多,但仍然不正确。这是Gamma_star(125)

代码语言:javascript
复制
Gamma_star(125) : 
[[0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.1  0.09 0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.01 0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.01]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.01]]

我们可以看到,行$y = x-1$中没有包含对度量gamma_star的支持

谢谢并致以问候。

EN

回答 1

Stack Overflow用户

发布于 2019-12-29 16:19:39

这不是最终的答案,但我们越来越接近了。

正如我所建议的,我减轻了我的while条件。例如,只有一个条件

代码语言:javascript
复制
while (Iter < Niter):

这是我得到的:

下面是我为gamma_star(125)得到的矩阵:

代码语言:javascript
复制
[[0.08 0.02 0.   0.   0.   0.   0.   0.   0.   0.  ]
 [0.02 0.06 0.02 0.   0.   0.   0.   0.   0.   0.  ]
 [0.   0.02 0.06 0.02 0.   0.   0.   0.   0.   0.  ]
 [0.   0.   0.02 0.06 0.02 0.   0.   0.   0.   0.  ]
 [0.   0.   0.   0.02 0.06 0.02 0.   0.   0.   0.  ]
 [0.   0.   0.   0.   0.02 0.06 0.02 0.   0.   0.  ]
 [0.   0.   0.   0.   0.   0.02 0.06 0.02 0.   0.  ]
 [0.   0.   0.   0.   0.   0.   0.02 0.06 0.02 0.  ]
 [0.   0.   0.   0.   0.   0.   0.   0.02 0.06 0.02]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.02 0.08]]

它与我的预期更接近:对于$j \ne i-1$,$\text{Gamma_star}(i,j) = 0$

新的代码是:

代码语言:javascript
复制
import numpy as np
import math
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import matplotlib.colors as colors

#Parameters
N = 10                        #Step of the discritization of [0,1]
Niter = 10**5

def Sinkhorn(C, mu, nu, lamb):
    # lam : strength of the entropic regularization

    #Initialization
    a1 = np.zeros(N)
    b1 = np.zeros(N)
    a2 = np.ones(N)
    b2 = np.ones(N)
    Iter = 0

    GammaB = np.exp(-lamb*C)

    #Sinkhorn
    while (Iter < Niter):
        a1 = a2
        b1 = b2
        a2 = mu/(np.dot(GammaB,b1))
        b2 = nu/(np.dot(GammaB.T,a2))
        Iter +=1

    # Compute gamma_star
    Gamma = np.zeros((N,N))
    for i in range(N):
        for j in range(N):
            Gamma[i][j] = a2[i]*b2[j]*GammaB[i][j]
    Gamma /= Gamma.sum()

    return Gamma

    ## Test between uniform([0;1]) over uniform([1;2])

    S = np.linspace(0,1,N, False)  #discritization of [0,1]
    T = np.linspace(1,2,N,False)  #discritization of [1,2]

    # Discretization of uniform([0;1])
    U01 = np.ones(N)
    Mass = np.sum(U01)
    U01 = U01/Mass

    # Discretization uniform([1;2])
    U12 = np.ones(N)
    Mass = np.sum(U12)
    U12 = U12/Mass

    # Cost function
    X,Y = np.meshgrid(S,T)
    C = (X-Y)**2               #Matrix of c[i,j]=(xi-yj)²

    def plot_Sinkhorn_U01_U12():

        #plot optimal measure and convergence
        fig = plt.figure()
        for i in range(4):
            ax = fig.add_subplot(2, 2, i+1, projection='3d')
            Gamma_star = Sinkhorn(C, U01, U12, 5**i)
            ax.scatter(X, Y, Gamma_star, cmap='viridis', linewidth=0.5)
            plt.title("Gamma_bar({}) between uniform([0,1]) and uniform([1,2])".format(5**i))
        plt.show()

        plt.figure()
        for i in range(4):
            plt.subplot(2,2,i+1)
            Gamma_star = Sinkhorn(C, U01, U12, 5**i)
            plt.imshow(Gamma_star,interpolation='none')
            plt.title("Gamma_bar({}) between uniform([0,1]) and uniform([1,2])".format(5**i))

        plt.show()

        return

    # The transport between U01 ans U12 is x -> x-1 so the support of gamma^* is contained in the graph of the function x -> (x,x-1) which is the line y = x-1

    plot_Sinkhorn_U01_U12()
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59466082

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档