首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使嵌入相似的火炬损失函数

使嵌入相似的火炬损失函数
EN

Stack Overflow用户
提问于 2020-12-31 13:59:40
回答 1查看 2.1K关注 0票数 0

我正在研究一个嵌入模型,其中有一个BERT模型,它接收文本输入并输出多维向量。该模型的目的是为相似的文本找到相似的嵌入(高余弦相似度),对于不相似的文本寻找不同的嵌入(低余弦相似度)。

在小型批处理模式下进行训练时,BERT模型给出N*D维输出,其中N是批处理大小,D是BERT模型的输出维数。

此外,我还有一个维数N*N的目标矩阵,如果sentence[i]sentence[j]在意义上是相似的,则在[i, j]第th位置包含sentence[i],如果没有,则包含-1

我想做的是通过找到BERT输出中所有嵌入的余弦相似性,并将其与目标矩阵进行比较,找到整个批处理的丢失/错误。

我所做的只是把张量乘以它的转置,然后取元素乙状体。

代码语言:javascript
复制
scores = torch.matmul(document_embedding, torch.transpose(document_embedding, 0, 1))
scores = torch.sigmoid(scores)

loss = self.bceloss(scores, targets)

但这似乎行不通。

还有别的办法吗?

我想做的是类似于本论文中描述的方法。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-12-31 15:36:44

要计算两个向量之间的余弦相似性,您可以使用nn.CosineSimilarity。但是,我认为这不允许你从一组n向量中得到对相似度。幸运的是,您可以通过一些张量操作实现它。

让我们称x为document_embedding of shape (n, d),其中d是嵌入大小。我们要n=3d=5。所以x是由[x1, x2, x3].T组成的。

代码语言:javascript
复制
>>> x = torch.rand(n, d)
tensor([[0.8620, 0.9322, 0.4220, 0.0280, 0.3789],
        [0.2747, 0.4047, 0.6418, 0.7147, 0.3409],
        [0.6573, 0.3432, 0.5663, 0.2512, 0.0582]])

余弦相似性是一个归一化的点积。x@x.T 矩阵乘法将为您提供成对的点产品:包含:||x1||²<x1/x2><x1/x3><x2/x1>||x2||²等。

代码语言:javascript
复制
>>> sim = x@x.T
tensor([[1.9343, 1.0340, 1.1545],
        [1.0340, 1.2782, 0.8822],
        [1.1545, 0.8822, 0.9370]])

将所有范数的向量标准化:||x1||||x2||||x3||

代码语言:javascript
复制
>>> norm = x.norm(dim=1)
tensor([1.3908, 1.1306, 0.9680])

构造了包含归一化因子:||x1||²||x1||.||x2||||x1||.||x3||||x2||.||x1||||x2||²等的矩阵。

代码语言:javascript
复制
>>> factor = norm*norm.unsqueeze(1)
tensor([[1.9343, 1.5724, 1.3462],
        [1.5724, 1.2782, 1.0944],
        [1.3462, 1.0944, 0.9370]])

然后正常化:

代码语言:javascript
复制
>>> sim /= factor
tensor([[1.0000, 0.6576, 0.8576],
        [0.6576, 1.0000, 0.8062],
        [0.8576, 0.8062, 1.0000]])

或者,是避免创建规范矩阵的一种更快捷的方法,它是在乘法之前进行规范化:

代码语言:javascript
复制
>>> x /= x.norm(dim=1, keepdim=True)
>>> sim = x@x.T
tensor([[1.0000, 0.6576, 0.8576],
        [0.6576, 1.0000, 0.8062],
        [0.8576, 0.8062, 1.0000]])

对于损失函数,我将直接在预测的相似矩阵和目标矩阵之间应用nn.CrossEntropyLoss,而不是使用sigmoid + BCE。注:nn.CrossEntropyLoss包括nn.LogSoftmax

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65521840

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档