首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >CNN半监督学习中的问题

CNN半监督学习中的问题
EN

Stack Overflow用户
提问于 2021-07-26 08:04:22
回答 1查看 164关注 0票数 0

我进行了半监督学习,在数据集中标注未标注的图像。CNN模型以无标号图像为输入,经过softmax计算,生成一个概率指数。如果值超过某个数字(例如0.65),我将标记图像并将其添加到火车组中。获取说服力数据集的代码:

代码语言:javascript
复制
def get_pseudo_labels(trainset, dataset, model, threshold=0.65):
# This functions generates pseudo-labels of a dataset using given model.
# It returns an instance of DatasetFolder containing images whose prediction confidences exceed a given threshold.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Construct a data loader.
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
# The dataset is unlabelled image

# Make sure the model is in eval mode.
model.eval()
# Define softmax function.
softmax = nn.Softmax(dim=-1)

# Iterate over the dataset by batches.
for batch in tqdm(data_loader):

    img, labels = batch
    # Forward the data
    # Using torch.no_grad() accelerates the forward process.
    with torch.no_grad():
        logits = model(img.to(device))

    # Obtain the probability distributions by applying softmax on logits.
    probs = softmax(logits)
    # calculate probs

    for j in range(0, batch_size):
        for i in range(0, 11):
            if probs[j][i].item() > threshold:
                batch[1][j] = torch.Tensor([i]) # Label the imgae
                temp = batch[0][j] + batch[1][j] # contact two tensor
                trainset = ConcatDataset([trainset, temp]) # add this labelled image into trainset

model.train()
return trainset

编者提醒我:

if probsj.item() >阈值: IndexError:索引2对于尺寸为2的维度0是超出界限的

不过,我可以正常打印问题。

代码语言:javascript
复制
        for j in range(0, batch_size):
        for i in range(0, 11):
            print('batch:', j)
            print('The value of label', i)
            print(probs[j][i])
            if probs[j][i].item() > threshold:
                batch[1][j] = torch.Tensor([i])
                temp = batch[0][j] + batch[1][j]
                trainset = ConcatDataset([trainset, temp])

输出:

代码语言:javascript
复制
...
batch: 63
The value of label 9
tensor(0.0859, device='cuda:0')
batch: 63
The value of label 10
tensor(0.0977, device='cuda:0')

我不知道IndexError是什么意思..。

国际管理小组的主要工作是:

代码语言:javascript
复制
tensor([...(img)],[...(label)])
EN

回答 1

Stack Overflow用户

发布于 2022-06-28 02:50:45

确保dataset % batch_size = 0。

对于batch_size (例如4),您应该在数据集中有很多示例(8或12或16等等)。

这里16 %4=0

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68526388

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档