首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >TensorFlow模型获得损失0

TensorFlow模型获得损失0
EN

Stack Overflow用户
提问于 2017-05-04 07:36:23
回答 1查看 5K关注 0票数 5
代码语言:javascript
复制
import tensorflow as tf
import numpy as np
def weight(shape):
return tf.Variable(tf.truncated_normal(shape, stddev=0.1))
def bias(shape):
return tf.Variable(tf.constant(0.1, shape=shape))
def output(input,w,b):
return tf.matmul(input,w)+b
x_columns = 33
y_columns = 1
layer1_num = 7
layer2_num = 7
epoch_num = 10
train_num = 1000
batch_size = 100
display_size = 1
x = tf.placeholder(tf.float32,[None,x_columns])
y = tf.placeholder(tf.float32,[None,y_columns])

layer1 = 
tf.nn.relu(output(x,weight([x_columns,layer1_num]),bias([layer1_num])))
layer2=tf.nn.relu
(output(layer1,weight([layer1_num,layer2_num]),bias([layer2_num])))
prediction = output(layer2,weight([layer2_num,y_columns]),bias([y_columns]))

loss=tf.reduce_mean
(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
train_step = tf.train.AdamOptimizer().minimize(loss)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
for epoch in range(epoch_num):
   avg_loss = 0.
   for i in range(train_num):
      index = np.random.choice(len(x_train),batch_size)
      x_train_batch = x_train[index]
      y_train_batch = y_train[index]
      _,c = sess.run([train_step,loss],feed_dict=
{x:x_train_batch,y:y_train_batch})
      avg_loss += c/train_num
   if epoch % display_size == 0:
      print("Epoch:{0},Loss:{1}".format(epoch+1,avg_loss))
print("Training Finished")

我的模型得到时代:2,损失:0.0时代:3,损失:0.0时代:4,损失:0.0时代:5,损失:0.0时代:6,损失:0.0时代:7,损失:0.0时代:8,损失:0.0时代:9,损失:0.0时代:10,损失:0.0训练结束

我该如何处理这个问题?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-05-04 12:19:45

softmax_cross_entropy_with_logits期望标签是一种热的形式,即带有形状[batch_size, num_classes]的标签.这里有y_columns = 1,这意味着只有一个类,它必须是预测的类和‘基本真理’(从网络的角度来看),所以不管权重是什么,输出总是正确的。因此,loss=0

我想您确实有不同的类,y_train包含标签的ID。那么predictions应该是[batch_size, num_classes]形状的,而不是softmax_cross_entropy_with_logits,您应该使用tf.nn.sparse_softmax_cross_entropy_with_logits

票数 7
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/43776661

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档