首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >distribute.MirroredStrategy和tf.Estimator温馨入门

distribute.MirroredStrategy和tf.Estimator温馨入门
EN

Stack Overflow用户
提问于 2018-06-08 18:00:38
回答 1查看 872关注 0票数 0

我正在尝试使用MirroredStartegy和tf.Estimator运行多gpus训练。第一次尝试是在估计器model_fn中使用tf.train.init_from_chekpoint,如下所示

代码语言:javascript
复制
def model_fn(features, labels, mode, params):

    .....

   tf.train.init_from_checkpoint(params['resnet_checkpoint'], {'/': 'resnet50/'})

   ....

这将抛出以下错误

代码语言:javascript
复制
.../tensorflow/contrib/distribute/python/values.py", line 285, in _get_update_device
    "Use DistributionStrategy.update() to modify a MirroredVariable.")

下一次尝试是使用tf.estimator.WarmStartSetting

代码语言:javascript
复制
ws = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=params['resnet_checkpoint'],
        vars_to_warm_start='resnet50.*',
        var_name_to_prev_var_name=var_name_to_prev_var_name
    )

session_config = tf.ConfigProto(allow_soft_placement=True)

if FLAGS.num_gpus == 0:
        distribution = tf.contrib.distribute.OneDeviceStrategy('device:CPU:0')
elif FLAGS.num_gpus == 1:
        distribution = tf.contrib.distribute.OneDeviceStrategy('device:GPU:0')
else:
        distribution = tf.contrib.distribute.MirroredStrategy(
            num_gpus=FLAGS.num_gpus
        )
run_config = tf.estimator.RunConfig(train_distribute=distribution,
                                        session_config=session_config)

estimator = tf.estimator.Estimator(
        model_fn=model_function,
        params=params,
        config=run_config,
        model_dir=FLAGS.model_dir,
        warm_start_from=ws
    )

同样,这会抛出一个错误

代码语言:javascript
复制
TypeError: var MUST be one of the following: a Variable, list of Variable or PartitionedVariable, but is <class 'tensorflow.contrib.distribute.python.values.MirroredVariable'>

有什么想法可以解决这两种方法中的一种吗?

EN

回答 1

Stack Overflow用户

发布于 2018-06-13 06:43:31

不幸的是,MirroredStrategy还不支持使用您尝试过的两种机制从检查点进行恢复。我已经提交了一个github问题来追踪这个https://github.com/tensorflow/tensorflow/issues/19958。请关注此问题以获取进展。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50758110

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档