首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >带有tf.contrib.losses.metric_learning.triplet_semihard_loss断言误差的keras模型

带有tf.contrib.losses.metric_learning.triplet_semihard_loss断言误差的keras模型
EN

Stack Overflow用户
提问于 2019-01-01 13:53:44
回答 2查看 2.1K关注 0票数 5

我将python 3与anaconda结合使用,并尝试使用带有Keras模型的tf.contrib丢失函数。

代码如下

代码语言:javascript
复制
from keras.layers import Dense, Flatten
from keras.optimizers import Adam
from keras.models import Sequential
from tensorflow.contrib.losses import metric_learning
model = Sequential()
model.add(Flatten(input_shape=input_shape))
model.add(Dense(50,  activation="relu"))
model.compile(loss=metric_learning.triplet_semihard_loss, optimizer=Adam())

我得到以下错误:

文件"/home/user/.local/lib/python3.6/site-packages/keras/engine/training_utils.py",行404,在加权score_array = fn(y_true,y_pred)文件"/home/user/anaconda3/envs/siamese/lib/python3.6/site-packages/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py",第179行中,在triplet_semihard_loss assert lshape.shape == 1 AssertionError中

当我使用同样的网络和keras损失函数时,它工作得很好,我尝试将tf损失函数封装成这样的函数

代码语言:javascript
复制
def func(y_true, y_pred): 
    import tensorflow as tf
    return tf.contrib.losses.metric_learning.triplet_semihard_loss(y_true, y_pred) 

仍然会犯同样的错误

我在这里做错什么了?

更新:当更改func以返回以下内容时

代码语言:javascript
复制
return K.categorical_crossentropy(y_true, y_pred)

一切都很好!但我不能让它和特殊的tf损失函数一起工作.

当我进入tf.contrib.losses.metric_learning.triplet_semihard_loss并删除这一行代码:assert lshape.shape == 1时,它运行得很好

谢谢

EN

回答 2

Stack Overflow用户

发布于 2019-02-12 07:48:20

问题是,您将错误的输入传递给损失函数。

根据损失文件串,你需要通过labelsembeddings

所以你的代码必须是:

代码语言:javascript
复制
def func(y, embeddings): 
    return tf.contrib.losses.metric_learning.triplet_semihard_loss(labels=y, embeddings=embeddings) 

还有两个关于嵌入网络的注意事项:

  1. 最后的致密层必须没有激活。
  2. 不要忘记标准化输出向量model.add(Lambda(lambda x: K.l2_normalize(x, axis=1)))
票数 1
EN

Stack Overflow用户

发布于 2019-01-01 18:37:15

似乎您的问题来自于丢失函数中的不正确输入。实际上,三重态损失需要参数:

代码语言:javascript
复制
Args:
labels: 1-D tf.int32 `Tensor` with shape [batch_size] of
  multiclass integer labels.
embeddings: 2-D float `Tensor` of embedding vectors. Embeddings should
  be l2 normalized.

你确定y_true的形状是正确的吗?你能给我们提供更多关于你使用的张量的更多细节吗?

票数 -1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/53996020

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档