我试图修剪一个基本模型,它由VGG网络上的几个层组成。它还包含一个名为instance_normalization的用户定义层。为了成功地剪枝,我定义了该层的get_prunable_weights函数如下:
### defined for model pruning
def get_prunable_weights(self):
return self.weights我使用了以下函数,使用一个名为model的基本模型获得了要修剪的模型结构。
def define_prune_model(self, model, img_shape, epochs, batch_size, validation_split=0.1):
num_images = img_shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs
# Define model for pruning.
pruning_params = {
'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.5,
final_sparsity=0.80,
begin_step=0,
end_step=end_step)
}
model_for_pruning = prune_low_magnitude(model, **pruning_params)
model_for_pruning.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model_for_pruning.summary()
return model_for_pruning然后,我编写了以下函数来执行这个剪枝模型的培训:
def train_prune_model(self, model_for_pruning, train_images, train_labels,
epochs, batch_size, validation_split=0.1):
callbacks = [
tfmot.sparsity.keras.UpdatePruningStep(),
tfmot.sparsity.keras.PruningSummaries(log_dir='./models/pruned'),
]
model_for_pruning.fit(train_images, train_labels,
batch_size=batch_size, epochs=epochs, validation_split=validation_split,
callbacks=callbacks)
return model_for_pruning然而,在训练时,我发现所有的训练和验证损失都是nan,最终的模型预测输出完全为零。然而,传递给define_prune_model的基本模型已经成功地进行了正确的训练和预测。
我怎么才能解决这个问题?提前谢谢你。
发布于 2021-08-30 12:43:12
如果没有更多的信息,就很难确定这个问题。特别是,您能给出更多关于自定义instance_normalization层的详细信息(最好是代码)吗?
假设代码很好:既然您提到模型不需要修剪就可以正确地进行训练,那么这些剪枝参数是否太苛刻了呢?毕竟,这些选项从第一步开始就将权重的50%设置为零。
以下是我要尝试的:
initial_sparsity)为实验对象。begin_step参数的剪枝时间表)。有些人甚至更喜欢在没有修剪的情况下训练一次模型。然后再用prune_low_magnitude()重新训练。frequency参数)。https://stackoverflow.com/questions/68821353
复制相似问题