I'm trying to figure out how to compute the KL divergence between two torch.distributions.Distribution objects. So far I haven't found a function that does this. Here is what I tried:
import torch as t
from torch import distributions as tdist
import torch.nn.functional as F
def kl_divergence(x: t.distributions.Distribution, y: t.distributions.Distribution):
"""Compute the KL divergence between two distributions."""
return F.kl_div(x, y)
a = tdist.Normal(0, 1)
b = tdist.Normal(1, 1)
print(kl_divergence(a, b)) # TypeError: kl_div(): argument 'input' (position 1) must be Tensor, not Normal
Posted on 2022-06-23 17:13:54
torch.nn.functional.kl_div computes the KL divergence loss between tensors. The KL divergence between two distribution objects can be computed with torch.distributions.kl.kl_divergence.
Posted on 2022-06-23 13:26:40
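A minimal sketch of this answer's approach, reusing the two normal distributions from the question:

```python
import torch
from torch import distributions as tdist

a = tdist.Normal(0.0, 1.0)
b = tdist.Normal(1.0, 1.0)

# torch.distributions.kl.kl_divergence dispatches on the types of its
# arguments and returns the closed-form KL divergence where one is
# registered (it is for two Normals).
kl = tdist.kl.kl_divergence(a, b)
print(kl)  # tensor(0.5000)
```

For two Gaussians this matches the closed form log(σ₂/σ₁) + (σ₁² + (μ₁−μ₂)²)/(2σ₂²) − ½, which here is 0.5.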
tdist.Normal(...) returns a normal distribution object; you have to draw a sample from the distribution.
x = a.sample()
y = b.sample()
https://stackoverflow.com/questions/72726304