首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >TensorFlow -使用交织或parallel_interleave时出错

TensorFlow -使用交织或parallel_interleave时出错
EN

Stack Overflow用户
提问于 2019-02-21 18:23:30
回答 2查看 409关注 0票数 0

我使用V1.12API的tf.data.Datasets (像这个Q&A )来读取目录中每个文件预先保存的几个.h5文件。我第一次制造了发电机:

代码语言:javascript
复制
class generator_yield:
    def __init__(self, file):
        self.file = file

    def __call__(self):
        with h5py.File(self.file, 'r') as f:
            yield f['X'][:], f['y'][:]

然后列出文件名,并在Dataset中通过它们

代码语言:javascript
复制
def _fnamesmaker(dir, mode='h5'):
    fnames = []
    for dirpath, _, filenames in os.walk(dir):
        for fname in filenames:
            if fname.endswith(mode):
                fnames.append(os.path.abspath(os.path.join(dirpath, fname)))
    return fnames

fnames = _fnamesmaker('./')
len_fnames = len(fnames)
fnames = tf.data.Dataset.from_tensor_slices(fnames)

应用Dataset的交织方法:

代码语言:javascript
复制
# handle multiple files
ds = fnames.interleave(lambda filename: tf.data.Dataset.from_generator(
    generator_yield(filename), output_types=(tf.float32, tf.float32),
    output_shapes=(tf.TensorShape([100, 100, 1]), tf.TensorShape([100, 100, 1]))), cycle_length=len_fnames)
ds = ds.batch(5).shuffle(5).prefetch(5)

# init iterator
it = ds.make_initializable_iterator()
init_op = it.initializer
X_it, y_it = it.get_next()

型号:

代码语言:javascript
复制
# model
with tf.name_scope("Conv1"):
    W = tf.get_variable("W", shape=[3, 3, 1, 1],
                         initializer=tf.contrib.layers.xavier_initializer())
    b = tf.get_variable("b", shape=[1], initializer=tf.contrib.layers.xavier_initializer())
    layer1 = tf.nn.conv2d(X_it, W, strides=[1, 1, 1, 1], padding='SAME') + b
    logits = tf.nn.relu(layer1)


    loss = tf.reduce_mean(tf.losses.mean_squared_error(labels=y_it, predictions=logits))
    train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)

开始会议:

代码语言:javascript
复制
with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), init_op])
    while True:
        try:
            data = sess.run(train_op)
            print(data.shape)
        except tf.errors.OutOfRangeError:
            print('done.')
            break

该错误看起来如下:

TypeError:预期的str、字节或os.PathLike对象,而不是生成器的init方法的张量。显然,当一个应用交错时,它是一个张量传递到发电机

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2019-04-04 09:32:26

根据这个post,我的情况不会从使用parralel_interleave的性能中获益。

...have :将源数据集的每个元素转换为多个元素到目标数据集的转换.

它与数据(狗、猫.)保存在单独目录中的典型分类问题更相关。这里有一个分割问题,这意味着标签包含输入图像的相同维数。所有数据都存储在一个目录中,每个.h5文件都包含一个图像及其标签(掩码)。

这里,一个简单的mapnum_parallel_callssufficient

票数 0
EN

Stack Overflow用户

发布于 2019-02-21 19:19:47

不能通过sess.run直接运行dataset对象。你必须定义一个迭代器,得到下一个元素。试着做这样的事情:

代码语言:javascript
复制
next_elem = files.make_one_shot_iterator.get_next()
data = sess.run(next_elem)

你应该能得到你的张量。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/54813820

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档