首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使Google平台在培训过程中检测`tf.summary.scalar`调用?

如何使Google平台在培训过程中检测`tf.summary.scalar`调用?
EN

Stack Overflow用户
提问于 2020-04-28 12:20:22
回答 2查看 452关注 0票数 7

(注:我也问过这个问题,这里)

问题

我一直试图让Google的人工智能平台显示出Keras模型的准确性,该模型是在人工智能平台上进行培训的。我用hptuning_config.yaml配置了超参数调优,它可以工作。然而,我不能让AI平台在培训期间接收tf.summary.scalar呼叫。

文档

我一直在跟踪下列文件页:

1. 超参数整定综述

2. 使用超参数整定

根据1

AI平台培训如何获得度量您可能会注意到,在本文档中没有说明如何将您的超参数度量传递给AI平台培训服务。这是因为服务监视由您的培训应用程序生成的TensorFlow摘要事件并检索度量。“

根据2,生成此类Tensorflow摘要事件的一种方法是创建回调类,如下所示:

代码语言:javascript
复制
class MyMetricCallback(tf.keras.callbacks.Callback):

    def on_epoch_end(self, epoch, logs=None):
        tf.summary.scalar('metric1', logs['RootMeanSquaredError'], epoch)

我的代码

所以在我的代码中我包括:

代码语言:javascript
复制
# hptuning_config.yaml

trainingInput:
  hyperparameters:
    goal: MAXIMIZE
    maxTrials: 4
    maxParallelTrials: 2
    hyperparameterMetricTag: val_accuracy
    params:
    - parameterName: learning_rate
      type: DOUBLE
      minValue: 0.001
      maxValue: 0.01
      scaleType: UNIT_LOG_SCALE
代码语言:javascript
复制
# model.py

class MetricCallback(tf.keras.callbacks.Callback):

    def on_epoch_end(self, epoch, logs):
        tf.summary.scalar('val_accuracy', logs['val_accuracy'], epoch)

我甚至试过

代码语言:javascript
复制
# model.py

class MetricCallback(tf.keras.callbacks.Callback):
    def __init__(self, logdir):
        self.writer = tf.summary.create_file_writer(logdir)

    def on_epoch_end(self, epoch, logs):
        with writer.as_default():
            tf.summary.scalar('val_accuracy', logs['val_accuracy'], epoch)

它成功地将“val_accuracy”指标保存到Google存储区(我也可以在TensorBoard中看到这一点)。但这并没有被人工智能平台所接受,尽管这是在1中提出的。

部分解决办法:

使用云ML高调包,我创建了以下类:

代码语言:javascript
复制
# model.py

class MetricCallback(tf.keras.callbacks.Callback):
    def __init__(self):
        self.hpt = hypertune.HyperTune()

    def on_epoch_end(self, epoch, logs):
        self.hpt.report_hyperparameter_tuning_metric(
            hyperparameter_metric_tag='val_accuracy',
            metric_value=logs['val_accuracy'],
            global_step=epoch
        )

起作用了!但我不明白它是如何做到的,因为它看起来只是写到人工智能平台worker at /tmp/hypertune/*上的一个文件。谷歌云文档中没有任何东西可以解释人工智能平台是如何发现的.

为了让tf.summary.scalar事件显示出来,我是不是遗漏了什么?

EN

回答 2

Stack Overflow用户

发布于 2020-07-20 17:41:25

我有同样的问题,我不能让人工智能平台拿起tf.summary.scalar。在过去的两个月里,我尝试用GCP支持和AI平台工程团队来调试它。即使我们使用的是几乎相同的代码,他们也无法重现这个问题。我们甚至做了一个编码会话,但仍然有不同的结果。

GCP平台工程团队的建议:“不要使用tf.summary.scalar"主要原因是通过使用其他方法:

  • 它对每个人都很好
  • 你可以控制和观察发生的事情(不是黑匣子)。

他们将更新文件,以反映这一新建议。

设置:

  • Tensoflow 2.2.0
  • TensorBoard 2.2.2
  • keras模型是在tf.distribute.MirroredStrategy()范围内创建的
  • 用于TensorBoard的keras回调

通过以下设置,可以观察到“问题”:

  • 当使用TensorBoard时,update_freq=' epoch‘和只使用1个epoch

它似乎与其他设置一起工作。无论如何,我将遵循GCP的建议,并使用自定义解决方案来避免问题。

票数 1
EN

Stack Overflow用户

发布于 2020-04-28 21:34:14

我们在TF 2.1和TF Keras和AI平台上测试了这一点,并成功地完成了以下工作:

代码语言:javascript
复制
class CustomCallback(tf.keras.callbacks.TensorBoard):
    """Callback to write out a custom metric used by CAIP for HP Tuning."""

    def on_epoch_end(self, epoch, logs=None):  # pylint: disable=no-self-use
        """Write tf.summary.scalar on epoch end."""
        tf.summary.scalar('epoch_accuracy', logs['accuracy'], epoch)

# Setup TensorBoard callback.
custom_cb = CustomCallback(os.path.join(args.job_dir, 'metric_tb'),
                               histogram_freq=1)

# Train model
keras_model.fit(
        training_dataset,
        steps_per_epoch=int(num_train_examples / args.batch_size),
        epochs=args.num_epochs,
        validation_data=validation_dataset,
        validation_steps=1,
        verbose=1,
        callbacks=[custom_cb])
代码语言:javascript
复制
trainingInput:
  hyperparameters:
    goal: MAXIMIZE
    maxTrials: 4
    maxParallelTrials: 2
    hyperparameterMetricTag: epoch_accuracy
    params:
    - parameterName: batch-size
      type: INTEGER
      minValue: 8
      maxValue: 256
      scaleType: UNIT_LINEAR_SCALE
    - parameterName: learning-rate
      type: DOUBLE
      minValue: 0.01
      maxValue: 0.1
      scaleType: UNIT_LOG_SCALE

似乎与您的代码相同,只是我没有访问权限,您是如何传递回调的。我记得在没有直接指定回调时看到了一些问题。

代码这里

票数 -2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61480051

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档