首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >keras argmax没有渐变。如何定义argmax的梯度?

keras argmax没有渐变。如何定义argmax的梯度?
EN

Stack Overflow用户
提问于 2018-09-17 04:46:25
回答 1查看 2.3K关注 0票数 1

我在gamma层中使用Keras.Backend.armax()。模型编译得很好,但在fit()过程中抛出错误。

代码语言:javascript
复制
ValueError: An operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.

我的模型:

代码语言:javascript
复制
latent_dim = 512
encoder_inputs = Input(shape=(train_data.shape[1],))
encoder_dense = Dense(vocabulary, activation='softmax')
encoder_outputs = Embedding(vocabulary, latent_dim)(encoder_inputs)
encoder_outputs = LSTM(latent_dim, return_sequences=True)(encoder_outputs)
encoder_outputs = Dropout(0.5)(encoder_outputs)
encoder_outputs = encoder_dense(encoder_outputs)
encoder_outputs = Lambda(K.argmax, arguments={'axis':-1})(encoder_outputs)
encoder_outputs = Lambda(K.cast, arguments={'dtype':'float32'})(encoder_outputs)

encoder_dense1 = Dense(train_label.shape[1], activation='softmax')
decoder_embedding = Embedding(vocabulary, latent_dim)
decoder_lstm1 = LSTM(latent_dim, return_sequences=True)
decoder_lstm2 = LSTM(latent_dim, return_sequences=True)
decoder_dense2 = Dense(vocabulary, activation='softmax')

decoder_outputs = encoder_dense1(encoder_outputs)
decoder_outputs = decoder_embedding(decoder_outputs)
decoder_outputs = decoder_lstm1(decoder_outputs)
decoder_outputs = decoder_lstm2(decoder_outputs)
decoder_outputs = Dropout(0.5)(decoder_outputs)
decoder_outputs = decoder_dense2(decoder_outputs)
model = Model(encoder_inputs, decoder_outputs)
model.summary()

便于可视化的模型摘要:

代码语言:javascript
复制
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_7 (InputLayer)         (None, 32)                0         
_________________________________________________________________
embedding_13 (Embedding)     (None, 32, 512)           2018816   
_________________________________________________________________
lstm_19 (LSTM)               (None, 32, 512)           2099200   
_________________________________________________________________
dropout_10 (Dropout)         (None, 32, 512)           0         
_________________________________________________________________
dense_19 (Dense)             (None, 32, 3943)          2022759   
_________________________________________________________________
lambda_5 (Lambda)            (None, 32)                0         
_________________________________________________________________
lambda_6 (Lambda)            (None, 32)                0         
_________________________________________________________________
dense_20 (Dense)             (None, 501)               16533     
_________________________________________________________________
embedding_14 (Embedding)     (None, 501, 512)          2018816   
_________________________________________________________________
lstm_20 (LSTM)               (None, 501, 512)          2099200   
_________________________________________________________________
lstm_21 (LSTM)               (None, 501, 512)          2099200   
_________________________________________________________________
dropout_11 (Dropout)         (None, 501, 512)          0         
_________________________________________________________________
dense_21 (Dense)             (None, 501, 3943)         2022759   
=================================================================
Total params: 14,397,283
Trainable params: 14,397,283
Non-trainable params: 0
_________________________________________________________________

我在谷歌上搜索了解决方案,但几乎所有的都是关于一个错误的模型。有些人建议不要使用导致问题的函数。但是,正如您所看到的,我无法在没有K.argmax的情况下创建此模型(如果您知道任何其他方法,请告诉我)。我如何解决这个问题,从而训练我的模型?

EN

回答 1

Stack Overflow用户

发布于 2018-09-17 04:58:58

由于显而易见的原因,Argmax函数没有梯度;这是如何定义的?为了让您的模型工作,您需要将层设置为不可训练的。根据this question (或the documentation),您需要将trainable = False传递给您的层。至于层权重(如果适用),您可能希望将其设置为单位矩阵。

票数 -1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52358241

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档