首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow 2+ Keras的知识蒸馏损失

Tensorflow 2+ Keras的知识蒸馏损失
EN

Stack Overflow用户
提问于 2019-12-02 19:22:43
回答 1查看 668关注 0票数 5

我正在尝试实现一个非常简单的keras模型,它使用来自另一个模型的知识蒸馏1。粗略地说,我需要将原始的损失L(y_true, y_pred)替换为L(y_true, y_pred)+L(y_teacher_pred, y_pred),其中y_teacher_pred是另一个模型的预测。

我试着去做

代码语言:javascript
复制
def create_student_model_with_distillation(teacher_model):

  inp = tf.keras.layers.Input(shape=(21,))

  model = tf.keras.models.Sequential()
  model.add(inp)

  model.add(...) 
  model.add(tf.keras.layers.Dense(units=1))

  teacher_pred = teacher_model(inp)

  def my_loss(y_true,y_pred):
      loss = tf.keras.losses.mean_squared_error(y_true, y_pred)
      loss += tf.keras.losses.mean_squared_error(teacher_pred, y_pred)
      return loss

  model.compile(loss=my_loss, optimizer='adam')

  return model

然而,当我尝试在我的模型上调用fit时,我得到了

代码语言:javascript
复制
TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.

我该如何解决这个问题?

参考文献

1

EN

回答 1

Stack Overflow用户

发布于 2021-04-21 20:46:41

实际上,这篇博文回答了你的问题:keras blog

但简而言之,你应该使用新的TF2接口,并在tf.GradientTape()块之前调用教师的predict

代码语言:javascript
复制
def train_step(self, data):
        # Unpack data
        x, y = data

        # Forward pass of teacher
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(x, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                tf.nn.softmax(student_predictions / self.temperature, axis=1),
            )
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/59137907

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档