首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >加载tensorflow图像并创建补丁

加载tensorflow图像并创建补丁
EN

Stack Overflow用户
提问于 2020-10-12 22:18:54
回答 2查看 2.4K关注 0票数 2

我使用目录将一个非常大的RGB图像数据集从磁盘加载到一个数据集中。例如,

代码语言:javascript
复制
dataset = tf.keras.preprocessing.image_dataset_from_directory(
    <directory>,
    label_mode=None,
    seed=1,
    subset='training',
    validation_split=0.1)

例如,该数据集有100000幅图像分组成32大小的批次,生成了一个带有规范tf.data.Dataset(batch=32, width=256, height=256, channels=3)

我想从图像中提取补丁,创建一个图像空间维数为64x64的新tf.data.Dataset

因此,我想要创建一个新的数据集,该数据集的400000个补丁仍在32的批处理中,其中包含一个带有规范tf.data.Dataset(batch=32, width=64, height=64, channels=3)

我已经看过窗户方法和补丁函数,但是从文档中还不清楚如何使用它们来创建一个新的数据集,我需要开始有关补丁的培训。window似乎面向一维张量,而extract_patches似乎适用于数组而不是数据集。

对于如何做到这一点,有什么建议吗?

更新:

只是为了澄清我的需要。我试图避免手动创建磁盘上的修补程序。第一,这将是无法维持的磁盘智慧。二,补丁大小不是固定的。实验将在几个贴片尺寸上进行。因此,我不希望在磁盘上手动执行修补程序创建,也不希望手动加载内存中的映像并执行修补程序。我更愿意让tensorflow作为管道工作流的一部分来处理补丁创建,以减少磁盘和内存的使用。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2020-10-12 23:41:42

你要找的是tf.image.extract_patches。下面是一个例子:

代码语言:javascript
复制
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np

data = tfds.load('mnist', split='test', as_supervised=True)

get_patches = lambda x, y: (tf.reshape(
    tf.image.extract_patches(
        images=tf.expand_dims(x, 0),
        sizes=[1, 14, 14, 1],
        strides=[1, 14, 14, 1],
        rates=[1, 1, 1, 1],
        padding='VALID'), (4, 14, 14, 1)), y)

data = data.map(get_patches)

fig = plt.figure()
plt.subplots_adjust(wspace=.1, hspace=.2)
images, labels = next(iter(data))
for index, image in enumerate(images):
    ax = plt.subplot(2, 2, index + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(image)
plt.show()

票数 1
EN

Stack Overflow用户

发布于 2020-10-12 23:37:59

我相信您可以使用python类生成器。如果需要,可以将此生成器传递给model.fit函数。实际上,我曾经用它来做标签预处理。

我编写了以下数据集生成器,该生成器从数据集加载批处理,根据tile_shape参数将批处理中的图像分割为多个图像。如果有足够的图像,将返回下一批图像。

在本例中,我使用了一个简单的dataset from_tensor_slices进行简化。当然,你可以用你的代替。

代码语言:javascript
复制
import tensorflow as tf

class TileDatasetGenerator:
    
    def __init__(self, dataset, batch_size, tile_shape):
        self.dataset_iterator = iter(dataset)
        self.batch_size = batch_size
        self.tile_shape = tile_shape
        self.image_queue = None
    
    def __iter__(self):
        return self
    
    def __next__(self):
        if self._has_queued_enough_for_batch():
            return self._dequeue_batch()
        
        batch = next(self.dataset_iterator)
        self._split_images(batch)    
        return self.__next__()
            
    def _has_queued_enough_for_batch(self):
        return self.image_queue is not None and tf.shape(self.image_queue)[0] >= self.batch_size
    
    def _dequeue_batch(self):
        batch, remainder = tf.split(self.image_queue, [self.batch_size, -1], axis=0)
        self.image_queue = remainder
        return batch
        
    def _split_images(self, batch):
        batch_shape = tf.shape(batch)
        batch_splitted = tf.reshape(batch, shape=[-1, self.tile_shape[0], self.tile_shape[1], batch_shape[-1]])
        if self.image_queue is None:
            self.image_queue = batch_splitted
        else:
            self.image_queue = tf.concat([self.image_queue, batch_splitted], axis=0)
            


dataset = tf.data.Dataset.from_tensor_slices(tf.ones(shape=[128, 64, 64, 3]))
dataset.batch(32)
generator = TileDatasetGenerator(dataset, batch_size = 16, tile_shape = [32,32])

for batch in generator:
    tf.print(tf.shape(batch))

编辑:如果需要,可以将生成器转换为tf.data.Dataset,但需要向返回迭代器的生成器添加__call__函数(在本例中为self)。

代码语言:javascript
复制
new_dataset = tf.data.Dataset.from_generator(generator, output_types=(tf.int64))
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/64326029

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档