首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >双塔网做深时不学

双塔网做深时不学
EN

Data Science用户
提问于 2021-04-06 17:21:22
回答 1查看 403关注 0票数 2

我一直在尝试训练一个相对简单的双塔网作为推荐.我正在使用PyTorch,实现如下-基本上是为用户和项目嵌入层,可选的前馈网络为两个塔,点产品之间的用户和项目表示,和乙状结肠。

代码语言:javascript
复制
class SimpleTwoTower(nn.Module):
    
    def __init__(self, n_items, n_users, ln):
        super(SimpleTwoTower, self).__init__()
        
        self.ln = ln
        self.item_emb = nn.Embedding(num_embeddings=n_items, embedding_dim=self.ln[0])
        self.user_emb = nn.Embedding(num_embeddings=n_users, embedding_dim=self.ln[0])
       
        
        self.item_layers = [] #nn.ModuleList()
        self.user_layers = [] #nn.ModuleList()
        
        for i, n in enumerate(ln[0:-1]):
            m = int(ln[i+1])
            self.item_layers.append(nn.Linear(n, m, bias=True))
            self.item_layers.append(nn.ReLU())
            
            self.user_layers.append(nn.Linear(n, m, bias=True))
            self.user_layers.append(nn.ReLU())
            
            
        self.item_layers = nn.Sequential(*self.item_layers)
        self.user_layers = nn.Sequential(*self.user_layers)
        
        self.dot = torch.matmul
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, items, users):
        
        item_emb = self.item_emb(items)
        user_emb = self.user_emb(users)
        
        item_emb = self.item_layers(item_emb)
        user_emb = self.user_layers(user_emb)

        dp = self.dot(user_emb, item_emb.t())
        return self.sigmoid(dp)

我正在与二进制交叉熵损失和亚当优化器。当我只使用嵌入时,我看到了从一个时代到另一个时代的改进(损失在减少,评估指标在增加)。然而,一旦我添加了一个前馈层,网络在第一个时代就只学习一点点,然后就停滞了。我尝试用ReLU编写一个线性层,以检查问题是否与我创建层列表的方式有关,但这并没有改变任何事情。

其他人也有类似的问题吗?

编辑:这里我已经在PyTorch论坛上发布了这个问题,我有一些答复。

EN

回答 1

Data Science用户

回答已采纳

发布于 2021-04-09 08:27:25

我现在有了一个工作网络。结果表明,在大约3000次更新后,渐变都是零。我尝试了两种方法来修复这个问题--在前馈网络中的每个激活函数之后使用批处理规范化,并将激活函数从ReLU更改为Leaky。这两种方法都起作用了,最后我使用了漏式ReLU而没有标准化。

关于PyTorch讨论论坛这里中的完整线程

票数 2
EN
页面原文内容由Data Science提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://datascience.stackexchange.com/questions/92655

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档