首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >模型不适合Tensorflow数据样条具有未知TensorShape

模型不适合Tensorflow数据样条具有未知TensorShape
EN

Stack Overflow用户
提问于 2022-04-28 10:31:43
回答 2查看 88关注 0票数 0

我有一个用于视频数据的数据加载器管道。虽然我指定了管道的输出,但在调用model.fit时仍然会出现以下错误。"ValueError: as_list()不是在未知的TensorShape上定义的“。我搜索了这个错误,大多数人说是因为tf.numpy_function返回了一个未知的形状(到Tensorflow管道)。在该函数之后指定形状应该可以解决问题。然而,事实并非如此。

代码语言:javascript
复制
AUTOTUNE = tf.data.experimental.AUTOTUNE

#get list of numpy files in directory
train_ds = tf.data.Dataset.list_files("dir") 

#load numpy files (video with shape 40,160,160,3), get corresponding label and output both
#video and label
def get_label(file_path):
    label = tf.strings.split(file_path, os.path.sep)
    return label [-2]

def process_image(file_path):
    label = get_label(file_path)
    video= np.load(file_path, allow_pickle=True)
    video= tf.convert_to_tensor(video/255, dtype=tf.float32) 
    return video, np.float32(label)

train_ds = train_ds.map(lambda item: tf.numpy_function(
          process_image, [item], (tf.float32, tf.float32)),num_parallel_calls = AUTOTUNE ) 

#Convert video to tf object
def set_shape(video, label):
  video = tf.reshape(video, (40,160,160,3))
  #video = tf.ensure_shape(video, (40,160,160,3)) #also does not work
  #video = tf.convert_to_tensor(video, dtype=tf.float32) #also does not work
  return video, label

train_ds = train_ds.map(set_shape)

#batching
train_ds = train_ds.batch(batch_size =5)

#optimazation
train_ds = train_ds.prefetch(AUTOTUNE)


train_ds.take(1) 

尽管代码的其余部分看起来很好(当我手动输入数据时确实可以工作),但如果不是,我将粘贴它。

代码语言:javascript
复制
def create_LRCN_model():
    '''
    This function will construct the required LRCN model.
    Returns:
        model: It is the required constructed LRCN model.
    '''
 
    # We will use a Sequential model for model construction.
    model = Sequential()
    
    # Define the Model Architecture.

    ########################################################################################
    
    model.add(TimeDistributed(Conv2D(128, (3, 3), padding='same',activation = 'relu'),
                              input_shape = (40, 160, 160, 3)))
    
    model.add(TimeDistributed(MaxPooling2D((4, 4)))) 
    model.add(TimeDistributed(Dropout(0.25)))
    
    model.add(TimeDistributed(Conv2D(256, (3, 3), padding='same',activation = 'relu')))
    model.add(TimeDistributed(MaxPooling2D((4, 4))))
    model.add(TimeDistributed(Dropout(0.25)))
    
    model.add(TimeDistributed(Conv2D(128, (3, 3), padding='same',activation = 'relu')))
    model.add(TimeDistributed(MaxPooling2D((2, 2))))
    model.add(TimeDistributed(Dropout(0.25)))
    
    model.add(TimeDistributed(Conv2D(64, (3, 3), padding='same',activation = 'relu')))
    model.add(TimeDistributed(MaxPooling2D((2, 2))))
    #model.add(TimeDistributed(Dropout(0.25)))
                                      
    model.add(TimeDistributed(Flatten()))
                                      
    model.add(LSTM(32))
                                      
    model.add(Dense(1, activation = 'sigmoid'))
 
    ########################################################################################
    # Display the models summary.
    model.summary()
    
    # Return the constructed LRCN model.
    return model

LRCN_model = create_LRCN_model()
early_stopping_callback = EarlyStopping(monitor = 'val_loss', patience = 15, mode = 'min', restore_best_weights = True)
LRCN_model.compile(loss='binary_crossentropy', optimizer = 'Adam', metrics = ["accuracy"])
LRCN_model_training_history = LRCN_model.fit(train_ds, validation_data= val_ds, epochs = 70,   callbacks = [early_stopping_callback])
EN

回答 2

Stack Overflow用户

发布于 2022-04-28 12:36:05

好吧我找到了另一个解决方案。我不太清楚它为什么工作,只需调用下面的函数就可以了。

代码语言:javascript
复制
def set_shape(video, label):

  video.set_shape((40,160,160, 3))
  label.set_shape([])

  return video, label
票数 1
EN

Stack Overflow用户

发布于 2022-04-28 11:56:12

明白了!您只需将模型编译中的“准确性”更改为"binary_accuracy“即可。它为我工作与您的代码和一些虚拟视频和标签输入数据。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72042131

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档