How do I use tf.contrib.model_pruning on MNIST?

Stack Overflow user
Asked 2018-08-28 18:45:02
Answers: 2, views: 2.2K, followers: 0, score: 6

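I'm having a hard time using TensorFlow's pruning library and haven't found many helpful examples, so I'm looking for help pruning a simple model trained on the MNIST dataset. If anyone can help fix my attempt, or provide an example of how to use this library on MNIST, I'd be very grateful.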

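The first half of my code is fairly standard, except that my model has two hidden layers, each 300 units wide, built with layers.masked_fully_connected so that they can be pruned.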
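For readers new to the library: as I understand it, masked_fully_connected behaves like an ordinary fully connected layer that also creates a non-trainable mask variable shaped like its weight matrix and multiplies the two elementwise in the forward pass, so pruning only has to set mask entries to 0. The pure-Python sketch below (no TensorFlow needed) illustrates that idea on small nested lists. masked_dense is a hypothetical helper, not the library's code, and the ReLU default is an assumption mirroring tf.contrib.layers.fully_connected.

```python
def relu(v):
    """Elementwise ReLU on a list of floats."""
    return [max(0.0, x) for x in v]


def masked_dense(x, weights, mask, bias, activation=relu):
    """Hypothetical helper: y = activation(x @ (weights * mask) + bias).

    x: list of length n_in; weights and mask: n_in x n_out nested lists;
    bias: list of length n_out. A mask entry of 0 removes that weight.
    """
    n_in, n_out = len(weights), len(weights[0])
    y = []
    for j in range(n_out):
        total = bias[j]
        for i in range(n_in):
            # The mask multiplies the weight, so a pruned weight contributes nothing
            total += x[i] * weights[i][j] * mask[i][j]
        y.append(total)
    return activation(y) if activation else y


# Usage: pruning the weight from input 0 to output 1 changes only that output.
x = [1.0, 2.0]
w = [[0.5, -1.0],
     [0.25, 3.0]]
b = [0.0, 0.0]
dense = masked_dense(x, w, [[1, 1], [1, 1]], b, activation=None)
pruned = masked_dense(x, w, [[1, 0], [1, 1]], b, activation=None)
print(dense)   # [1.0, 5.0]
print(pruned)  # [1.0, 6.0]
```

The weights themselves are never modified here; only the mask changes, which is why a pruned layer keeps its original weight variable alongside the mask.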

import tensorflow as tf
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import layers
from tensorflow.examples.tutorials.mnist import input_data

# Import dataset
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Define Placeholders
image = tf.placeholder(tf.float32, [None, 784])
label = tf.placeholder(tf.float32, [None, 10])

# Define the model
layer1 = layers.masked_fully_connected(image, 300)
layer2 = layers.masked_fully_connected(layer1, 300)
logits = tf.contrib.layers.fully_connected(layer2, 10, tf.nn.relu)

# Loss function
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label))

# Training op
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)

# Accuracy ops
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

Then I tried to define the necessary pruning operations, but I got an error.

############ Pruning Operations ##############
# Create global step variable
global_step = tf.contrib.framework.get_or_create_global_step()

# Create a pruning object using the pruning specification
pruning_hparams = pruning.get_pruning_hparams()
p = pruning.Pruning(pruning_hparams, global_step=global_step)

# Mask Update op
mask_update_op = p.conditional_mask_update_op()

# Set up the specification for model pruning
prune_train = tf.contrib.model_pruning.train(train_op=train_op, logdir=None, mask_update_op=mask_update_op)

This line raises the error:

prune_train = tf.contrib.model_pruning.train(train_op=train_op, logdir=None, mask_update_op=mask_update_op)

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder_1' with dtype float and shape [?,10]
    [[Node: Placeholder_1 = Placeholder[dtype=DT_FLOAT, shape=[?,10], _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
    [[Node: global_step/_57 = _Recv[_start_time=0, client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_71_global_step", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

I suspect it needs a different kind of op in place of train_op, but I haven't found any adjustment that works.
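A likely explanation, inferred from how the library is built rather than stated in this thread: tf.contrib.model_pruning.train wraps the slim training loop, which calls sess.run on train_op (together with global_step) without any feed_dict. A graph whose inputs are tf.placeholder tensors therefore fails with exactly this "must feed a value" error; the wrapper expects inputs produced inside the graph (queues or tf.data). The accepted answer sidesteps this with its own loop. The pure-Python sketch below shows that loop's control flow with stand-ins for the TF ops. should_update_mask and train_with_pruning are hypothetical, and the gating rule (global step inside [begin_pruning_step, end_pruning_step], at least pruning_frequency steps since the last mask update, negative end meaning "never stop") is my assumption about what conditional_mask_update_op checks.

```python
def should_update_mask(step, last_update, begin, end, frequency):
    """Assumed approximation of the check performed by conditional_mask_update_op."""
    in_window = step >= begin and (end < 0 or step <= end)  # end < 0: prune forever
    return in_window and last_update + frequency <= step


def train_with_pruning(batches, train_step, begin=0, end=250, frequency=1):
    """Hand-written loop: maybe update the masks, then feed one batch to the trainer.

    batches: iterable of (images, labels); train_step: callable standing in for
    sess.run(train_op, feed_dict={image: images, label: labels}).
    Returns the global steps at which the masks would have been updated.
    """
    global_step = 0
    last_update = 0  # assumed initial value of the library's last-update counter
    updates = []
    for images, labels in batches:
        # Stand-in for sess.run(prune_op): it depends only on variables
        # (weights, masks, global_step), not on placeholders, so no feed_dict.
        if should_update_mask(global_step, last_update, begin, end, frequency):
            last_update = global_step
            updates.append(global_step)
        # Stand-in for sess.run(train_op, feed_dict=...): both placeholders are fed,
        # and minimize(..., global_step=global_step) increments the step.
        train_step(images, labels)
        global_step += 1
    return updates


fed = []
steps = train_with_pruning([([0.0] * 784, [1.0] + [0.0] * 9)] * 5,
                           lambda x, y: fed.append((len(x), len(y))), end=3)
print(steps)     # [1, 2, 3]
print(len(fed))  # 5: every step received a batch
```

The key design point is that the data is supplied by the caller on every step, which the slim-based wrapper cannot do for placeholder inputs.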

Again, if you have a different working example that prunes a model trained on MNIST, I'll accept that as an answer.


2 Answers

Stack Overflow user

Accepted answer

Posted 2018-09-19 21:03:20

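This is the simplest example of the pruning library I could get working. I figured I'd post it here in case it helps other newbies who are struggling with the documentation.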

import tensorflow as tf
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import layers
from tensorflow.examples.tutorials.mnist import input_data

epochs = 250
batch_size = 55000 # Entire training set

# Import dataset
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batches = int(len(mnist.train.images) / batch_size)

# Define Placeholders
image = tf.placeholder(tf.float32, [None, 784])
label = tf.placeholder(tf.float32, [None, 10])

# Define the model
layer1 = layers.masked_fully_connected(image, 300)
layer2 = layers.masked_fully_connected(layer1, 300)
logits = layers.masked_fully_connected(layer2, 10, activation_fn=None)  # raw logits: no ReLU before softmax

# Create global step variable (needed for pruning)
global_step = tf.train.get_or_create_global_step()
reset_global_step_op = tf.assign(global_step, 0)

# Loss function
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label))

# Training op. Passing global_step is critical: it must be the same variable later passed to pruning.Pruning,
# because running this op increments it and the pruning schedule reads it
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss, global_step=global_step)

# Accuracy ops
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Get, Print, and Edit Pruning Hyperparameters
pruning_hparams = pruning.get_pruning_hparams()
print("Pruning Hyperparameters:", pruning_hparams)

# Change hyperparameters to meet our needs
pruning_hparams.begin_pruning_step = 0
pruning_hparams.end_pruning_step = 250
pruning_hparams.pruning_frequency = 1
pruning_hparams.sparsity_function_end_step = 250
pruning_hparams.target_sparsity = .9

# Create the pruning object; an explicit sparsity argument overrides the sparsity schedule built from the hparams
p = pruning.Pruning(pruning_hparams, global_step=global_step, sparsity=.9)
prune_op = p.conditional_mask_update_op()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Train the model before pruning (optional)
    for epoch in range(epochs):
        for batch in range(batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})

        # Calculate Test Accuracy every 10 epochs
        if epoch % 10 == 0:
            acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
            print("Un-pruned model step %d test accuracy %g" % (epoch, acc_print))

    acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
    print("Pre-Pruning accuracy:", acc_print)
    print("Sparsity of layers (should be 0)", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))

    # Reset the global step counter and begin pruning
    sess.run(reset_global_step_op)
    for epoch in range(epochs):
        for batch in range(batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Prune and retrain
            sess.run(prune_op)
            sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})

        # Calculate Test Accuracy every 10 epochs
        if epoch % 10 == 0:
            acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
            print("Pruned model step %d test accuracy %g" % (epoch, acc_print))
            print("Weight sparsities:", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))

    # Print final accuracy
    acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
    print("Final accuracy:", acc_print)
    print("Final sparsity by layer (should be ~0.9)", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))
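
The hyperparameters edited above (target_sparsity, sparsity_function_end_step) feed the library's gradual sparsity schedule, which, as far as I can tell from the model_pruning README, ramps sparsity from initial_sparsity to target_sparsity along a polynomial curve controlled by sparsity_function_begin_step, sparsity_function_end_step and sparsity_function_exponent. Because this answer also passes sparsity=.9 to pruning.Pruning, that schedule is bypassed and every mask update aims straight at 90%. The plain-Python sketch below computes the schedule so you can see what the hparams alone would do; scheduled_sparsity is a hypothetical helper, and the defaults used (exponent 3, initial sparsity 0) are assumptions about the library's defaults.

```python
def scheduled_sparsity(step, target_sparsity=0.9, initial_sparsity=0.0,
                       begin_step=0, end_step=250, exponent=3):
    """Polynomial sparsity ramp (assumed to mirror the library's schedule).

    s(t) = s_target + (s_initial - s_target) * (1 - p) ** exponent,
    where p = (t - begin_step) / (end_step - begin_step), clipped to [0, 1].
    """
    progress = (step - begin_step) / float(end_step - begin_step)
    progress = min(1.0, max(0.0, progress))
    return target_sparsity + (initial_sparsity - target_sparsity) * (1.0 - progress) ** exponent


# Sparsity rises quickly at first and flattens out as it approaches the target.
for step in (0, 50, 125, 250, 300):
    print(step, round(scheduled_sparsity(step), 4))
```

The cubic exponent prunes aggressively early, while there are many redundant weights, and slowly near the end, giving the remaining weights time to recover between updates.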
Score: 3

Stack Overflow user

Posted 2018-09-25 18:55:33

Here is the code Roman Nikishin asked for, which also saves the model; it is a slight extension of my original answer.

import os

import tensorflow as tf
from tensorflow.contrib.model_pruning.python import pruning
from tensorflow.contrib.model_pruning.python.layers import layers
from tensorflow.examples.tutorials.mnist import input_data

epochs = 250
batch_size = 55000 # Entire training set
model_path_unpruned = "Model_Saves/Unpruned.ckpt"
model_path_pruned = "Model_Saves/Pruned.ckpt"

# tf.train.Saver.save fails if the checkpoint directory does not exist, so create it first
os.makedirs(os.path.dirname(model_path_unpruned), exist_ok=True)

# Import dataset
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
batches = int(len(mnist.train.images) / batch_size)

# Define Placeholders
image = tf.placeholder(tf.float32, [None, 784])
label = tf.placeholder(tf.float32, [None, 10])

# Define the model
layer1 = layers.masked_fully_connected(image, 300)
layer2 = layers.masked_fully_connected(layer1, 300)
logits = layers.masked_fully_connected(layer2, 10, activation_fn=None)  # raw logits: no ReLU before softmax

# Create global step variable (needed for pruning)
global_step = tf.train.get_or_create_global_step()
reset_global_step_op = tf.assign(global_step, 0)

# Loss function
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label))

# Training op. Passing global_step is critical: it must be the same variable later passed to pruning.Pruning,
# because running this op increments it and the pruning schedule reads it
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss, global_step=global_step)

# Accuracy ops
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Get, Print, and Edit Pruning Hyperparameters
pruning_hparams = pruning.get_pruning_hparams()
print("Pruning Hyperparameters:", pruning_hparams)

# Change hyperparameters to meet our needs
pruning_hparams.begin_pruning_step = 0
pruning_hparams.end_pruning_step = 250
pruning_hparams.pruning_frequency = 1
pruning_hparams.sparsity_function_end_step = 250
pruning_hparams.target_sparsity = .9

# Create the pruning object; an explicit sparsity argument overrides the sparsity schedule built from the hparams
p = pruning.Pruning(pruning_hparams, global_step=global_step, sparsity=.9)
prune_op = p.conditional_mask_update_op()

# Create a saver for writing training checkpoints.
saver = tf.train.Saver()

with tf.Session() as sess:

    # Initialize all variables (if you already have a trained model, you can skip the
    # training loop below and restore your checkpoint instead)
    sess.run(tf.global_variables_initializer())

    # Train the model before pruning (optional)
    for epoch in range(epochs):
        for batch in range(batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})

        # Calculate Test Accuracy every 10 epochs
        if epoch % 10 == 0:
            acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
            print("Un-pruned model step %d test accuracy %g" % (epoch, acc_print))

    acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
    print("Pre-Pruning accuracy:", acc_print)
    print("Sparsity of layers (should be 0)", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))

    # Saves the model before pruning
    saver.save(sess, model_path_unpruned)

    # Re-initialize all variables, then restore the saved unpruned model
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, model_path_unpruned)

    # Reset the global step counter and begin pruning
    sess.run(reset_global_step_op)
    for epoch in range(epochs):
        for batch in range(batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Prune and retrain
            sess.run(prune_op)
            sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})

        # Calculate Test Accuracy every 10 epochs
        if epoch % 10 == 0:
            acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
            print("Pruned model step %d test accuracy %g" % (epoch, acc_print))
            print("Weight sparsities:", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))

    # Saves the model after pruning
    saver.save(sess, model_path_pruned)

    # Print final accuracy
    acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
    print("Final accuracy:", acc_print)
    print("Final sparsity by layer (should be ~0.9)", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))
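
As far as I can tell, get_weight_sparsity reports the fraction of zeros in each layer's mask, and the masks come from magnitude pruning: for a target sparsity s, the largest-magnitude weights are kept and the rest are masked out, using a threshold taken from the sorted absolute values. The pure-Python sketch below illustrates that idea. magnitude_mask and zero_fraction are hypothetical helpers; the sketch keeps round(n * (1 - s)) weights and ignores the library's threshold smoothing (threshold_decay) and any other options.

```python
def magnitude_mask(weights, sparsity):
    """Hypothetical helper: 0/1 mask keeping the largest-|w| entries.

    weights: flat list of floats; sparsity: fraction in [0, 1] to zero out.
    Keeps k = round(n * (1 - sparsity)) entries; ties at the threshold are kept,
    so the result can be slightly denser than requested.
    """
    n = len(weights)
    k = int(round(n * (1.0 - sparsity)))
    if k <= 0:
        return [0] * n
    # Threshold = k-th largest absolute value
    threshold = sorted((abs(w) for w in weights), reverse=True)[k - 1]
    return [1 if abs(w) >= threshold else 0 for w in weights]


def zero_fraction(mask):
    """Fraction of zeros in a mask (the per-layer sparsity being reported)."""
    return mask.count(0) / float(len(mask))


w = [0.05, -0.8, 0.3, -0.01, 0.6, 0.02, -0.4, 0.1, 0.07, -0.9]
m = magnitude_mask(w, 0.9)
print(m)                 # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
print(zero_fraction(m))  # 0.9
```

Sign is ignored: -0.9 survives because only the magnitude matters.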
Score: 1
Original page content provided by Stack Overflow; translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/52064450
