首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >MNIST数据集上的无条件生成对抗串行网络

MNIST数据集上的无条件生成对抗串行网络
EN

Stack Overflow用户
提问于 2019-12-23 16:53:21
回答 1查看 68关注 0票数 2

我正在使用tfgan库和tfgan估计器在MNIST数据集上训练无条件GAN。一切都很好,图像正在生成,see。生成器和鉴别器模型函数的助手函数是使用tf.layers编写的。但是,当我只更改助手函数并使用tf.keras编写它们时,相同的代码不起作用,也没有图像生成,see。有人能帮我解决这个问题吗?这两个脚本之间唯一的区别是帮助器函数从使用tf.layers改为使用tf.keras。使用tf.layers的助手函数:

代码语言:javascript
复制
def _dense(inputs, units, l2_weight):
  return tf.layers.dense(
      inputs, units, None,
      kernel_initializer=tf.keras.initializers.glorot_uniform,
      kernel_regularizer=tf.keras.regularizers.l2(l=l2_weight),
      bias_regularizer=tf.keras.regularizers.l2(l=l2_weight))

def _batch_norm(inputs, is_training):
  return tf.layers.batch_normalization(
      inputs, momentum=0.999, epsilon=0.001, training=is_training)

def _deconv2d(inputs, filters, kernel_size, stride, l2_weight):
  return tf.layers.conv2d_transpose(
      inputs, filters, [kernel_size, kernel_size], strides=[stride, stride], 
      activation=tf.nn.relu, padding='same',
      kernel_initializer=tf.keras.initializers.glorot_uniform,
      kernel_regularizer=tf.keras.regularizers.l2(l=l2_weight),
      bias_regularizer=tf.keras.regularizers.l2(l=l2_weight))

def _conv2d(inputs, filters, kernel_size, stride, l2_weight):
  return tf.layers.conv2d(
      inputs, filters, [kernel_size, kernel_size], strides=[stride, stride], 
      activation=None, padding='same',
      kernel_initializer=tf.keras.initializers.glorot_uniform,
      kernel_regularizer=tf.keras.regularizers.l2(l=l2_weight),
      bias_regularizer=tf.keras.regularizers.l2(l=l2_weight)) 

使用tf.keras的助手函数:

代码语言:javascript
复制
def _dense(inputs, units, l2_weight):
  return Dense(units,
      kernel_initializer=tf.keras.initializers.glorot_uniform,
      kernel_regularizer=tf.keras.regularizers.l2(l=l2_weight),
      bias_regularizer=tf.keras.regularizers.l2(l=l2_weight))(inputs)

def _batch_norm(inputs, is_training):
  return BatchNormalization(momentum=0.999, epsilon=0.001)(inputs, training = is_training)


def _deconv2d(inputs, filters, kernel_size, stride, l2_weight):
  return Conv2DTranspose(filters=filters, kernel_size=[kernel_size, kernel_size], strides=[stride, stride],
                                      activation=keras.activations.relu, padding='same',
                                      kernel_initializer=keras.initializers.glorot_uniform,
                                      kernel_regularizer=keras.regularizers.l2(l=l2_weight),
                                      bias_regularizer=keras.regularizers.l2(l=l2_weight))(inputs)

def _conv2d(inputs, filters, kernel_size, stride, l2_weight):
  return Conv2D(filters=filters, kernel_size=[kernel_size, kernel_size], strides=[stride, stride], padding='same',
      kernel_initializer=tf.keras.initializers.glorot_uniform,
      kernel_regularizer=tf.keras.regularizers.l2(l=l2_weight),
      bias_regularizer=tf.keras.regularizers.l2(l=l2_weight))(inputs)
EN

回答 1

Stack Overflow用户

发布于 2020-01-03 22:44:04

不幸的是,tfgan目前依赖于variable_scopes才能正常工作,而Keras层并不尊重variable_scopes。我们有一个重新设计的总体计划,将支持Keras,但不幸的是,目前我们没有任何东西可以展示它或ETA。欢迎贡献代码!

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59452422

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档