首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >再培训pytorch模型(增强学习)

再培训pytorch模型(增强学习)
EN

Stack Overflow用户
提问于 2022-06-12 14:49:37
回答 1查看 171关注 0票数 2

我有以下代码:

代码语言:javascript
复制
import torch
from facenet_pytorch import InceptionResnetV1, MTCNN
from torch.utils.data import DataLoader
from torchvision import datasets
import numpy as np
import pandas as pd
import os


workers = 0 if os.name == 'nt' else 4
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

mtcnn = MTCNN(
    image_size=160, margin=0, min_face_size=20,
    thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
    device=device
)

def collate_fn(x):
    return x[0]

dataset = datasets.ImageFolder('data/images/')
dataset.idx_to_class = {i:c for c, i in dataset.class_to_idx.items()}
loader = DataLoader(dataset, collate_fn=collate_fn, num_workers=workers)
#print(dataset.idx_to_class)

aligned = []
names = []
i = 0
for x, y in loader:
    x_aligned, prob = mtcnn(x, return_prob=True)
    if x_aligned is not None:
        print('Face detected with probability: {:8f}'.format(prob))
        aligned.append(x_aligned)
        names.append(dataset.idx_to_class[y])
        i += 1
#print(i)

for name, param in mtcnn.named_parameters(): #Freezing everything but last layer
    #print(name)
    if name != "onet.dense6_3.bias":
        param.require_grad = False
    else:
        param.require_grad = True

现在我想重新训练这个模型来预测三个类(现在它只预测一张脸的概率)。假设我在data/images/中有三个文件夹,faces1faces2faces3。如何用这三个文件夹重新训练这个模型?我想要一个张量,像[prob1, prob2, prob3],每个类都有一个图像的概率。谢谢。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-06-13 14:09:48

MTCN:这个类加载预先训练过的P-、R-和O-网,并返回裁剪的图像,这些图像仅包括原始输入图像。

我假设您正在尝试使用InceptionResnetV1对您的数据集进行分类。为了重新训练Inception模型,您只需用所需的类数加载模型,然后对其进行培训。

代码语言:javascript
复制
resnet = InceptionResnetV1(
    classify=True,
    pretrained='vggface2',
    num_classes=3
)

完整的细化示例如下所示:https://github.com/timesler/facenet-pytorch/blob/master/examples/finetune.ipynb

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72593257

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档