首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow剪枝模型与原始基线模型大小相同

Tensorflow剪枝模型与原始基线模型大小相同
EN

Stack Overflow用户
提问于 2021-03-16 12:08:34
回答 1查看 141关注 0票数 0

我有一个基线TF功能模型,我想修剪。我尝试遵循文档中的代码,但压缩剪枝模型的大小与压缩基线模型的大小相同。

(https://www.tensorflow.org/model_optimization/guide/pruning/comprehensive_guide#export_model_with_size_compression)

我不相信我的代码有什么问题,那么为什么会发生这种情况呢?

代码语言:javascript
复制
def get_gzipped_model_size(model):
  # Returns size of gzipped model, in bytes.
  import os
  import zipfile

  _, keras_file = tempfile.mkstemp('.h5')
  model.save(keras_file, include_optimizer=False)

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(keras_file)

  return os.path.getsize(zipped_file)


def test():
    model = keras.models.load_model('models/cifar10/baselines/convnet_small')
    model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(model)

    model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

    print("Size of gzipped baseline model: %.2f bytes" % (get_gzipped_model_size(model)))
    print("Size of gzipped pruned model without stripping: %.2f bytes" % (get_gzipped_model_size(model_for_pruning)))
    print("Size of gzipped pruned model with stripping: %.2f bytes" % (get_gzipped_model_size(model_for_export)))


if __name__ == "__main__":
    test()

输出:

代码语言:javascript
复制
Size of gzipped baseline model: 604286.00 bytes

Size of gzipped pruned model without stripping: 610750.00 bytes

Size of gzipped pruned model with stripping: 604287.00 bytes

编辑:

我也尝试使用与文档中相同的模型,并且剪枝模型仍然与基线大小相同:

代码语言:javascript
复制
input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
y_train = tf.keras.utils.to_categorical(np.random.randn(1), num_classes=20)


def setup_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(20, input_shape=input_shape),
      tf.keras.layers.Flatten()
  ])
  return model

def setup_pretrained_weights():
  model = setup_model()

  model.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
  )

  model.fit(x_train, y_train)

  _, pretrained_weights = tempfile.mkstemp('.tf')

  model.save_weights(pretrained_weights)

  return pretrained_weights


setup_model()
pretrained_weights = setup_pretrained_weights()

输出:

代码语言:javascript
复制
Size of gzipped baseline model: 2910.00 bytes
Size of gzipped pruned model without stripping: 3333.00 bytes
Size of gzipped pruned model with stripping: 2910.00 bytes
EN

回答 1

Stack Overflow用户

发布于 2021-06-10 14:15:20

在我看来,你似乎错过了实际进行修剪的步骤。如果我们查看test()函数,就可以将模型设置为剪枝,但实际上从未对其进行修剪。看看下面的编辑。

代码语言:javascript
复制
import tensorflow_model_optimization as tfmot

def test():
    model = keras.models.load_model('models/cifar10/baselines/convnet_small')
    pruning_schedule = tfmot.sparsity.keras.ConstantSparsity(
                         target_sparsity=0.95, 
                         begin_step=0, 
                         end_step=-1, 
                         frequency=100
                        )

    callbacks = [tfmot.sparsity.keras.UpdatePruningStep()]
    model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model, pruning_schedule=pruning_schedule)
    model_for_pruning.compile(optimizer="adam", loss="some-loss")
    model_for_pruning.fit(X, y, epochs=2, callbacks=callbacks)
    model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)


    print("Size of gzipped baseline model: %.2f bytes" %(get_gzipped_model_size(model)))
    print("Size of gzipped pruned model without stripping: %.2f bytes" % (get_gzipped_model_size(model_for_pruning)))
    print("Size of gzipped pruned model with stripping: %.2f bytes" % (get_gzipped_model_size(model_for_export)))

你可以看看我刚才问的问题中的代码。我有一个稍微不同的问题,但在那里张贴的代码有效(至少在某些情况下)。

Why is my pruned model larger than my base model when using Tensorflow's Model Optimization library to prune weights

如果您感兴趣,还可以查看tensorflow.sparsity.keras APIs以查看其他一些选项。

https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66654938

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档