首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >为什么这个PyTorch AutoEncoder的参数是这样硬编码的?

为什么这个PyTorch AutoEncoder的参数是这样硬编码的?
EN

Stack Overflow用户
提问于 2022-01-10 18:50:50
回答 1查看 231关注 0票数 0

嗨,我想了解下面的PyTorch AutoEncoder代码是如何工作的。下面的代码使用MNIST数据集,该数据集为28X28。我的问题是如何选择nn.Linear(128,3)参数?

我有一个数据集,它是512X512,我想修改这个AutoEncoder所支持的代码。

代码语言:javascript
复制
class LitAutoEncoder(pl.LightningModule):

def __init__(self):
    super().__init__()
    self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
    self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

def forward(self, x):
    # in lightning, forward defines the prediction/inference actions
    embedding = self.encoder(x)
    return embedding

def training_step(self, batch, batch_idx):
    # training_step defined the train loop. It is independent of forward
    x, y = batch
    x = x.view(x.size(0), -1)
    z = self.encoder(x)
    x_hat = self.decoder(z)
    loss = F.mse_loss(x_hat, x)
    return loss

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    return optimizer
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-01-11 02:23:32

我假设输入的图像数据是这样的:x.shape == [bs, 1, h, w],其中bs是批处理大小。然后,x首先被看作是[bs, h*w],即[bs, 28*28]。这意味着图像中的所有像素都被压平成一维矢量。

然后在编码器中:

  • nn.Linear(28*28, 128)接受尺寸[bs, 28*28]的扁平输入,并输出大小[bs, 28*28]的中间结果。

然后在解码器中:

  • nn.Linear(3, 128)[bs, 3] -> [bs, 128]
  • nn.Linear(128, 28*28)[bs, 128] -> [bs, 28*28]

最后的输出将与输入相匹配。

如果您想要对512x512图像使用确切的架构,只需将代码中的28*28的每次出现都更改为512*512即可。然而,由于这些原因,这是一个相当不可行的选择:

  • 用于MNIST图像,nn.Linear(28*28, 128)包含28x28x128+128=100480参数,而对于您的图像,nn.Linear(512*512, 128)包含512x512x128+128=33554560参数。大小太大,可能导致overfitting
  • The中间数据[bs, 3]只使用3个浮点数来编码512x512图像。我不认为你能用这种压缩

恢复任何东西

为了你的目的,我建议你查一下复杂的建筑

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70657510

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档