首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用batch_normalization时无法实例化Keras模型

使用batch_normalization时无法实例化Keras模型
EN

Stack Overflow用户
提问于 2019-06-09 00:47:45
回答 1查看 41关注 0票数 0

我不确定我做错了什么,但我正在按照书中的代码创建GAN模型,在实例化期间,Python shell只是冻结了。代码实际上是书中的一些代码的子集,但是书中的代码也无法创建模型。

但是,如果我注释掉batch_norm,我就可以实例化一个模型。

这里:

https://github.com/PacktPublishing/Advanced-Deep-Learning-with-Keras/blob/master/chapter4-gan/dcgan-mnist-4.2.1.py

文档:https://keras.io/layers/normalization/

代码语言:javascript
复制
from keras.layers import Activation, Dense, Input
from keras.layers import Conv2D, Flatten
from keras.layers import Reshape, Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import BatchNormalization
from keras.optimizers import RMSprop
from keras.models import Model
from keras.datasets import mnist
from keras.models import load_model
import keras

import numpy as np
import math
import matplotlib.pyplot as plt
import os
import argparse




def generator_model(inputs, image_size, verbose = True):
    """Generator Model

    args
    =======
    inputs = input layer
    image_size = size of image dimension (299? 480? 28?etc)

    """

    #resized dependent on how many Conv2d Transpore

    print("build generator model")

    image_resize = image_size // 4 
    kernel_size = 5
    layer_filters = [128, 64] #first two convs
    final_layer_filters = [32, 1] # last two conbs

    x= inputs
    x = Dense(image_resize * image_resize * layer_filters[0])(x)
    x = Reshape((image_resize, image_resize, layer_filters[0]))(x)
    print(x)

    for filter_ in layer_filters:
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Conv2DTranspose(filters=filter_,
                            kernel_size=kernel_size,
                            strides=2,
                            padding='same')(x)


    print("built first part")
    for filter_ in final_layer_filters:
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Conv2DTranspose(filters=filter_,
                            kernel_size=kernel_size,
                            strides=1,
                            padding='same')(x)

    x = Activation('sigmoid')(x)
    print("finised building")
    generator = Model(inputs, x, name='generator')
    if verbose:
        print(generator.summary())
    return generator






print(keras.__version__) #2.24
z_size = 100
img_size = 28
gen_input =  Input(shape= (z_size,), name='gen_input')
generator = generator_model(gen_input, img_size)

Shell输出以下内容,当它仍在运行时,它没有完成脚本的运行,它只是处于停顿状态:

代码语言:javascript
复制
2.2.4
build generator model
Tensor("reshape_1/Reshape:0", shape=(?, 7, 7, 128), dtype=float32)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2019-06-09 01:15:39

我在谷歌colab里试过你的代码。将生成以下内容。我认为这不是代码的问题。您可以检查其他问题,例如设置。

代码语言:javascript
复制
    Using TensorFlow backend.
    2.2.4
    build generator model
    WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
    Instructions for updating:
    Colocations handled automatically by placer.
    Tensor("reshape_1/Reshape:0", shape=(?, 7, 7, 128), dtype=float32)
    built first part
    finised building
    _________________________________________________________________
    Layer (type)                 Output Shape              Param #
    =================================================================
    gen_input (InputLayer)       (None, 100)               0
    _________________________________________________________________
    dense_1 (Dense)              (None, 6272)              633472
    _________________________________________________________________
    reshape_1 (Reshape)          (None, 7, 7, 128)         0
    _________________________________________________________________
    batch_normalization_1 (Batch (None, 7, 7, 128)         512
    _________________________________________________________________
    activation_1 (Activation)    (None, 7, 7, 128)         0
    _________________________________________________________________
    conv2d_transpose_1 (Conv2DTr (None, 14, 14, 128)       409728
    _________________________________________________________________
    batch_normalization_2 (Batch (None, 14, 14, 128)       512
    _________________________________________________________________
    activation_2 (Activation)    (None, 14, 14, 128)       0
    _________________________________________________________________
    conv2d_transpose_2 (Conv2DTr (None, 28, 28, 64)        204864
    _________________________________________________________________
    batch_normalization_3 (Batch (None, 28, 28, 64)        256
    _________________________________________________________________
    activation_3 (Activation)    (None, 28, 28, 64)        0
    _________________________________________________________________
    conv2d_transpose_3 (Conv2DTr (None, 28, 28, 32)        51232
    _________________________________________________________________
    batch_normalization_4 (Batch (None, 28, 28, 32)        128
    _________________________________________________________________
    activation_4 (Activation)    (None, 28, 28, 32)        0
    _________________________________________________________________
    conv2d_transpose_4 (Conv2DTr (None, 28, 28, 1)         801
    _________________________________________________________________
    activation_5 (Activation)    (None, 28, 28, 1)         0
    =================================================================
    Total params: 1,301,505
    Trainable params: 1,300,801
    Non-trainable params: 704
    _________________________________________________________________
            None
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56508349

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档