首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow模型剪枝为训练和验证损失提供“nan”

Tensorflow模型剪枝为训练和验证损失提供“nan”
EN

Stack Overflow用户
提问于 2021-08-17 17:00:36
回答 1查看 145关注 0票数 0

我试图修剪一个基本模型,它由VGG网络上的几个层组成。它还包含一个名为instance_normalization的用户定义层。为了成功地剪枝,我定义了该层的get_prunable_weights函数如下:

代码语言:javascript
复制
### defined for model pruning
    def get_prunable_weights(self):
        return self.weights

我使用了以下函数,使用一个名为model的基本模型获得了要修剪的模型结构。

代码语言:javascript
复制
def define_prune_model(self, model, img_shape, epochs, batch_size, validation_split=0.1):
        num_images = img_shape[0] * (1 - validation_split)
        end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

        # Define model for pruning.
        pruning_params = {
            'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.5,
                                                                    final_sparsity=0.80,
                                                                    begin_step=0,
                                                                    end_step=end_step)
        }

        model_for_pruning = prune_low_magnitude(model, **pruning_params)

        model_for_pruning.compile(optimizer='adam',
                    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                    metrics=['accuracy'])

        model_for_pruning.summary()

        return model_for_pruning

然后,我编写了以下函数来执行这个剪枝模型的培训:

代码语言:javascript
复制
def train_prune_model(self, model_for_pruning, train_images, train_labels,
                     epochs, batch_size, validation_split=0.1):
    callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    tfmot.sparsity.keras.PruningSummaries(log_dir='./models/pruned'),
    ]
    model_for_pruning.fit(train_images, train_labels,
                batch_size=batch_size, epochs=epochs, validation_split=validation_split,
                callbacks=callbacks)
    return model_for_pruning

然而,在训练时,我发现所有的训练和验证损失都是nan,最终的模型预测输出完全为零。然而,传递给define_prune_model的基本模型已经成功地进行了正确的训练和预测。

我怎么才能解决这个问题?提前谢谢你。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-08-30 12:43:12

如果没有更多的信息,就很难确定这个问题。特别是,您能给出更多关于自定义instance_normalization层的详细信息(最好是代码)吗?

假设代码很好:既然您提到模型不需要修剪就可以正确地进行训练,那么这些剪枝参数是否太苛刻了呢?毕竟,这些选项从第一步开始就将权重的50%设置为零。

以下是我要尝试的:

  • 以较低的稀疏度(特别是initial_sparsity)为实验对象。
  • 在稍后的培训期间开始应用剪枝(begin_step参数的剪枝时间表)。有些人甚至更喜欢在没有修剪的情况下训练一次模型。然后再用prune_low_magnitude()重新训练。
  • 只在某些步骤进行剪枝,给模型在剪枝之间恢复的时间(frequency参数)。
  • 最后,当它仍然失败时,通常的治疗方法是:降低学习率,使用正则化或梯度剪裁。
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68821353

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档