首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使用tf.data.Dataset.from_generator()向生成器函数发送参数?

如何使用tf.data.Dataset.from_generator()向生成器函数发送参数?
EN

Stack Overflow用户
提问于 2018-09-21 11:57:03
回答 2查看 5.1K关注 0票数 14

我想使用tf.data.Dataset函数创建一些from_generator()。我想向生成器函数(raw_data_gen)发送一个参数。其思想是生成器函数将根据发送的参数产生不同的数据。通过这种方式,我希望raw_data_gen能够提供培训、验证或测试数据。

代码语言:javascript
复制
training_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([1]))

validation_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([2]))

test_dataset = tf.data.Dataset.from_generator(raw_data_gen, (tf.float32, tf.uint8), ([None, 1], [None]), args=([3]))

当我试图以这种方式调用from_generator()时,收到的错误消息是:

代码语言:javascript
复制
TypeError: from_generator() got an unexpected keyword argument 'args'

这是raw_data_gen函数,不过我不确定您是否需要这样做,因为我的直觉是,问题在于调用from_generator()

代码语言:javascript
复制
def raw_data_gen(train_val_or_test):

    if train_val_or_test == 1:        
        #For every filename collected in the list
        for filename, lab in training_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try: #assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:,0] #raw_data is a np.array, just take first channel with slice
            except IndexError:
                pass #this must be mono audio
            yield raw_data, lab

    elif train_val_or_test == 2:
        #For every filename collected in the list
        for filename, lab in validation_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try: #assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:,0] #raw_data is a np.array, just take first channel with slice
            except IndexError:
                pass #this must be mono audio
            yield raw_data, lab

    elif train_val_or_test == 3:
        #For every filename collected in the list
        for filename, lab in test_filepath_label_dict.items():
            raw_data, samplerate = soundfile.read(filename)
            try: #assume the audio is stereo, ready to be sliced
                raw_data = raw_data[:,0] #raw_data is a np.array, just take first channel with slice
            except IndexError:
                pass #this must be mono audio
            yield raw_data, lab

    else:
        print("generator function called with an argument not in [1, 2, 3]")
        raise ValueError()
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2018-09-21 12:36:25

您需要基于raw_data_gen定义一个不带任何参数的新函数。您可以使用lambda关键字来执行此操作。

代码语言:javascript
复制
training_dataset = tf.data.Dataset.from_generator(lambda: raw_data_gen(train_val_or_test=1), (tf.float32, tf.uint8), ([None, 1], [None]))
...

现在,我们将一个函数传递给from_generator,它不带任何参数,但它只是充当raw_data_gen,参数设置为1。您可以对验证集和测试集使用相同的方案,分别传递2和3。

票数 16
EN

Stack Overflow用户

发布于 2021-01-30 15:32:01

关于Tensorflow 2.4:

代码语言:javascript
复制
training_dataset = tf.data.Dataset.from_generator(
     raw_data_gen, 
     args=(1), 
     output_types=(tf.float32, tf.uint8), 
     output_shapes=([None, 1], [None]))
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52443273

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档