首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >小镜头学习/连体网络-3通道输入图像

小镜头学习/连体网络-3通道输入图像
EN

Stack Overflow用户
提问于 2020-12-09 22:34:42
回答 1查看 93关注 0票数 0

我正在尝试在一个有不同的几个类和40个训练样本的已准备好的数据集上学习几个镜头(40个镜头学习)。为了加载我的数据,我使用了以下代码:

代码语言:javascript
复制
def list_files(startpath):
    X = []
    images = []
    full_path = []
    for root, dirs, files in os.walk(startpath):
        level = root.replace(startpath, '').count(os.sep)
        indent = ' ' * 4 * (level)
        print('{}{}/'.format(indent, os.path.basename(root)))
        subindent = ' ' * 4 * (level + 1)
        l_alpha.append(dirs)
        for f in files:
        #    print('{}{}'.format(subindent, f))
        #    l_char.append(f)
            full_path.append(str(dirs)+'\\'+str(root)+'\\'+str(f))
            for pixel in f:
                #img_data = cv2.imread(str(root)+'\\'+str(f))
            #    # store loaded image
            #    loaded_images.append([img_data])
            ##X.append(np.stack(loaded_images))
            #X = np.concatenate( loaded_images]]
        #full_path = [full_path[i][2:] for i in range(len(full_path))]
            #print(full_path[i][3:])
                images = [np.array(Image.open(v[3:])) for v in full_path]
        images = [images]
    return root, dirs, files, full_path, images

这完美地解决了我的shape (3267, 100, 100, 3)输出图像的形状问题。我的问题出在下一部分代码中,每次都会收集一批数据:

代码语言:javascript
复制
def get_batch(batch_size,s="train"):
    n_examples= 40
    """Create batch of n pairs, half same class, half different class"""
    if s == 'train':
        X = Xtrain
        X= X.reshape(-1,100,100,3)
        #X= X.reshape(-1,20,105,105)
        categories = train_classes
    else:
        X = Xval
        X= X.reshape(-1,100,100,3)
        categories = val_classes
    #n_classes, n_examples, w, h, chan = X.shape
    tot_examples, w, h, chan = X.shape
    
    n_classes = 51#tot_examples / len(full_path) *100
    
    # randomly sample several classes to use in the batch
    #categories = rng.choice(n_classes,size=(batch_size,),replace=False)
    categories = rng.choice(int(n_classes),size=(batch_size,),replace=True)
    
    # initialize 2 empty arrays for the input image batch
    #pairs=[np.zeros((batch_size, h, w,1)) for i in range(2)]
    pairs=[np.zeros((batch_size, h, w, chan)) for i in range(2)]
    
    # initialize vector for the targets
    targets=np.zeros((batch_size,))
    
    # make one half of it '1's, so 2nd half of batch has same class
    targets[batch_size//2:] = 1
    for i in range(batch_size):
        category = categories[i]
        print(category)
        idx_1 = rng.randint(0, n_examples)#
        pairs[0][i,:,:,:] = X[category, idx_1].reshape(w, h, chan)
        idx_2 = rng.randint(0, n_examples)
        
        # pick images of same class for 1st half, different for 2nd
        if i >= batch_size // 2:
            category_2 = category  
        else: 
            # add a random number to the category modulo n classes to ensure 2nd image has a different category
            category_2 = (category + rng.randint(1,n_classes)) % n_classes
        
        pairs[1][i,:,:,:] = X[category_2,idx_2].reshape(w, h,1)
    
    return pairs, targets

我得到的回溯是:

代码语言:javascript
复制
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-129-a37505d52b27> in <module>
----> 1 (inputs,targets) = get_batch(batch_size)

<ipython-input-128-e55ce72910a7> in get_batch(batch_size, s)
     33         print(category)
     34         idx_1 = rng.randint(0, n_examples)#
---> 35         pairs[0][i,:,:,:] = X[category, idx_1].reshape(w, h, chan)
     36         idx_2 = rng.randint(0, n_examples)
     37 

ValueError: cannot reshape array of size 300 into shape (100,100,3)

我理解错误,但我看不出问题所在。因为示例的数量对应于指定的40个。有人能告诉我代码的问题出在哪里吗?我做了一些评论,这可能有助于更好地理解代码。提前谢谢你。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-12-09 23:45:49

由于我可以跟踪X的形状,您之前已经将其重塑为(-1, 100, 100, 3)

在出现错误的行pairs[0][i,:,:,:] = X[category, idx_1].reshape(w, h, chan)中,对X进行第一和第二个维度的切片。因此,X[category, idx_1]的形状将是(100, 3),使其大小为300。因此,不可能将其重塑为(100, 100, 3)

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65218857

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档