首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >两个torch.distribution.Distribution对象的KL散度

两个torch.distribution.Distribution对象的KL散度
EN

Stack Overflow用户
提问于 2022-06-23 07:36:53
回答 2查看 367关注 0票数 3

我试图确定如何计算两个torch.distribution.Distribution对象的KL散度。到目前为止,我还没有找到一个函数来完成这个任务。以下是我尝试过的:

代码语言:javascript
复制
import torch as t
from torch import distributions as tdist
import torch.nn.functional as F

def kl_divergence(x: t.distributions.Distribution, y: t.distributions.Distribution):
    """Compute the KL divergence between two distributions."""
    return F.kl_div(x, y)  

a = tdist.Normal(0, 1)
b = tdist.Normal(1, 1)

print(kl_divergence(a, b))  # TypeError: kl_div(): argument 'input' (position 1) must be Tensor, not Normal
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2022-06-23 17:13:54

torch.nn.functional.kl_div正在计算KL散度损失.两个分布之间的KL-散度可以用torch.distributions.kl.kl_divergence计算.

票数 2
EN

Stack Overflow用户

发布于 2022-06-23 13:26:40

tdist.Normal(...)将返回一个正态分布对象,您必须从发行版中获取一个样本.

代码语言:javascript
复制
x = a.sample()
y = b.sample()
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72726304

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档