首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >MNIST数据处理- PyTorch

MNIST数据处理- PyTorch
EN

Stack Overflow用户
提问于 2022-01-31 06:05:30
回答 1查看 476关注 0票数 0

我试图为MNIST数据集编写一个变分自动编码器,数据预处理如下:

代码语言:javascript
复制
# Create transformations to be applied to dataset-
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
        (0.1307,), (0.3081,)
        # (0.5,), (0.5,)
        )
    ]
)


# Create training and validation datasets-
train_dataset = torchvision.datasets.MNIST(
    # root = 'data', train = True,
    root = path_to_data, train = True,
    download = True, transform = transforms
    )

val_dataset = torchvision.datasets.MNIST(
    # root = 'data', train = False,
    root = path_to_data, train = False,
    download = True, transform = transforms
    )

# Sanity check-
len(train_dataset), len(val_dataset)
# (60000, 10000)

# Create training and validation data loaders-
train_dataloader = torch.utils.data.DataLoader(
    dataset = train_dataset, batch_size = 32,
    shuffle = True,
    # num_workers = 2
    )

val_dataloader = torch.utils.data.DataLoader(
    dataset = val_dataset, batch_size = 32,
    shuffle = True,
    # num_workers = 2
    )

# Get a mini-batch of train data loaders-
imgs, labels = next(iter(train_dataloader))

imgs.shape, labels.shape
# (torch.Size([32, 1, 28, 28]), torch.Size([32]))

# Minimum & maximum pixel values-
imgs.min(), imgs.max()
# (tensor(-0.4242), tensor(2.8215))

# Compute min and max for train dataloader-
min_mnist, max_mnist = 0.0, 0.0

for img, _ in train_dataloader:
    if img.min() < min_mnist:
        min_mnist = img.min()
    if img.max() > max_mnist:
        max_mnist = img.max()

print(f"MNIST - train: min pixel value = {min_mnist:.4f} & max pixel value = {max_mnist:.4f}")
# MNIST - train: min pixel value = -0.4242 & max pixel value = 2.8215

min_mnist, max_mnist = 0.0, 0.0

for img, _ in val_dataloader:
    if img.min() < min_mnist:
        min_mnist = img.min()
    if img.max() > max_mnist:
        max_mnist = img.max()

print(f"MNIST - validation: min pixel value = {min_mnist:.4f} & max pixel value = {max_mnist:.4f}")
# MNIST - validation: min pixel value = -0.4242 & max pixel value = 2.8215

使用'ToTensor()‘和’正常化()‘转换,输出图像像素在-0.4242,2.8215之间。VAE中解码器的输出层使用sigmoid或tanh激活函数。Sigmoid在范围0,1中输出值,而tanh在范围- 1,1中输出值。

这可能是一个问题,因为输入在-0.4242,2.8215范围内,而输出可以在0,1或-1,1范围内,这取决于所使用的激活- sigmoid或tanh。

所使用的重建损失为MSE。BCE也可以使用,但它是针对Bernoulli分布和连续数据像素值提出的。

一个简单的解决方法是只使用“ToTensor()”转换,将输入缩放到0,1,然后在VAE中对输出解码器层使用sigmoid激活函数。但是,有什么更好的方法来进行数据预处理,使用图像进行规范化,并对每个通道进行“正常化()”转换,以便输入和输出/重构在相同的范围内?

EN

回答 1

Stack Overflow用户

发布于 2022-02-01 15:18:03

最简单的方法是删除最后一层中的sigmoid或tanh激活函数,然后使用一个线性层作为输出。在这种情况下,网络可以输出任何值,而不限于0、1或-1,1。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70921841

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档