
Implementing a multi-task "nested" neural network

Stack Overflow user
Asked 2021-11-11 21:49:18
1 answer · 32 views · 0 following · 0 votes

I am trying to implement the multi-task neural network used in a paper, but I am unsure how to write the multi-task part, since the authors do not provide the code for it.

The network architecture looks like this (from the paper): [architecture figure not shown]

For simplicity, the architecture can be summarized as follows (for the demo, I changed their more complex operations on individual embeddings to concatenation): [simplified architecture figure not shown]

The authors sum the losses of the single-item task and the pairwise task, and use the total loss to optimize the parameters of all three networks (the encoder, MLP-1, and MLP-2) in each batch. What confuses me is how to combine the different types of data in a single batch so they can be fed into two different networks that share the initial encoder. I tried to find other networks with a similar structure but came up empty. Any ideas would be appreciated!
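One way to frame the batching part of the question: each dataset item can simply carry both data types, so every batch contains both tensors and a single forward pass can route each through the shared encoder. A minimal sketch, assuming the two data types are fixed-size tensors (the names `drugs`/`targets` and all shapes here are illustrative assumptions, not from the paper):

```python
import torch
from torch.utils.data import Dataset, DataLoader

class PairedDataset(Dataset):
    """Hypothetical dataset that yields both data types for every index."""
    def __init__(self, drugs, targets):
        # drugs: [N, n_mol_a, n_features], targets: [N, n_mol_b, n_features]
        self.drugs = drugs
        self.targets = targets

    def __len__(self):
        return len(self.drugs)

    def __getitem__(self, idx):
        return self.drugs[idx], self.targets[idx]

drugs = torch.randn(10, 3, 16)
targets = torch.randn(10, 5, 16)
loader = DataLoader(PairedDataset(drugs, targets), batch_size=4)

data_a, data_b = next(iter(loader))
# data_a: [4, 3, 16], data_b: [4, 5, 16] -- both types are present in every
# batch, so one training step can feed both through the shared encoder.
```

If the two data types have variable sizes, a custom `collate_fn` that pads each type separately would play the same role.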


1 Answer

Stack Overflow user

Answered 2021-11-11 22:39:45

This is actually a common pattern. It can be handled with code along the following lines.

class Network(nn.Module):
   def __init__(self, ...):
      super().__init__()  # required before registering submodules
      self.encoder = DrugTargetInteractionNetwork()
      self.mlp1 = ClassificationMLP()
      self.mlp2 = PairwiseMLP()

   def forward(self, data_a, data_b):
      a_encoded = self.encoder(data_a)
      b_encoded = self.encoder(data_b)

      a_classified = self.mlp1(a_encoded)
      b_classified = self.mlp1(b_encoded)

      # let me assume data_a and data_b are of shape
      # [batch_size, n_molecules, n_features].
      # and that those n_molecules are not necessarily
      # equal.
      # This can be generalized to more dimensions.
      # both results have shape [batch, n_mol_b, n_mol_a, n_features]
      a_broadcast, b_broadcast = torch.broadcast_tensors(
         a_encoded[:, None, :, :],
         b_encoded[:, :, None, :],
      )

      # this will work if your mlp2 accepts an arbitrary number of
      # leading dimensions and just broadcasts over them. That's true,
      # for example, if it uses only Linear and pointwise
      # operations, but it may fail if it makes specific assumptions
      # about the number of dimensions of its inputs
      pairwise_classified = self.mlp2(a_broadcast, b_broadcast)

      # if that is a problem, you have to reshape it such that it
      # works. Most torch models accept at least a leading batch dimension
      # for vectorization, so we can "fold" the pairwise dimension
      # into the batch dimension, presenting it as
      # [batch*n_mol_1*n_mol_2, n_features]
      # to mlp2 and then recover it back
      # after broadcast_tensors both tensors share the same 4-D shape,
      # so a single unpack suffices
      B, N1, N2, N_feat = a_broadcast.shape
      a_batched = a_broadcast.reshape(B*N1*N2, N_feat)
      b_batched = b_broadcast.reshape(B*N1*N2, N_feat)
      # above, -1 would suffice instead of B*N1*N2, just being explicit
      batch_output = self.mlp2(a_batched, b_batched)

      # this should be exactly the same as `pairwise_classified`
      alternative_classified = batch_output.reshape(B, N1, N2, -1)

      return a_classified, b_classified, pairwise_classified
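To see the pieces fit together end to end, here is a runnable sketch with toy `Linear` layers standing in for the paper's encoder and MLPs (all module definitions and shapes below are assumptions for the demo, with the pairwise head taking concatenated embeddings as in the question's simplification), including the summed multi-task loss driving a single backward pass:

```python
import torch
import torch.nn as nn

# Toy stand-ins for the three sub-networks; the real modules are not shown
# in the answer, so these definitions are assumptions.
encoder = nn.Linear(16, 8)   # shared encoder
mlp1 = nn.Linear(8, 1)       # per-item classification head
mlp2 = nn.Linear(16, 1)      # pairwise head on concatenated embeddings

data_a = torch.randn(4, 3, 16)  # [batch, n_mol_a, n_features]
data_b = torch.randn(4, 5, 16)  # [batch, n_mol_b, n_features]

a_enc, b_enc = encoder(data_a), encoder(data_b)  # [4, 3, 8], [4, 5, 8]
a_cls, b_cls = mlp1(a_enc), mlp1(b_enc)          # [4, 3, 1], [4, 5, 1]

# pairwise branch: broadcast to all (a, b) pairs, concatenate features
a_bc, b_bc = torch.broadcast_tensors(
    a_enc[:, None, :, :],   # [4, 1, 3, 8]
    b_enc[:, :, None, :],   # [4, 5, 1, 8]
)                           # both -> [4, 5, 3, 8]
pair = mlp2(torch.cat([a_bc, b_bc], dim=-1))     # [4, 5, 3, 1]

# sum the task losses (dummy mean-square losses here, in place of the
# paper's actual loss functions) and backpropagate once; the shared
# encoder receives gradients from all three terms
loss = (a_cls ** 2).mean() + (b_cls ** 2).mean() + (pair ** 2).mean()
loss.backward()
```

Because the encoder appears in every branch of the computation graph, the single `backward()` on the summed loss accumulates gradients into it from all three tasks, which is exactly the per-batch joint optimization described in the question.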
Votes: 1
Original content provided by Stack Overflow.
Original link: https://stackoverflow.com/questions/69935341
