在使用SRGAN的代码时,我想用tf.image.resize_bicubic替换UpSampling2D。我为这个函数使用了keras lambda层,如下所示
def bicubic_lambda(input_shape, scale=2):
shp = input_shape
def outshape(input_shape):
dims = [input_shape[0],input_shape[1] * scale,input_shape[2] * scale,input_shape[3]]
output_shape = tuple(dims)
return output_shape
def bic(x):
return image.resize_images(x, [shp[1]*scale,shp[2]*scale], method=image.ResizeMethod.BICUBIC)
return Lambda(bic, output_shape=outshape, name='bicubic_lambda')它训练得很好,没有任何错误,但问题是我不能保存模型。generator.save(model_save_dir + 'gen_model%d.h5' % e)导致代码在UpSampling2d中正常工作时出现错误。我的tf版本是1.14.0,keras是2.3.1。请帮助:)
发布于 2021-11-01 01:37:05
这是一个简单的解决方案。
model = Lambda(lambda image: tf.image.resize_images(image, (image.shape[1]*2, image.shape[2]*2), method = tf.image.ResizeMethod.BICUBIC))(model)https://stackoverflow.com/questions/69785637
复制相似问题