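I am trying to use a conditional random field (CRF) loss in a TensorFlow graph.
I am working on a sequence labeling task:
My input is a sequence of elements [A, B, C, D]. Each element belongs to one of three classes. The classes are one-hot encoded: an element of class 0 is represented by the vector [1, 0, 0].
My input labels (y) have shape (batch_size x sequence_length x num_classes).
My network produces logits of the same shape.
Assume all my sequences have length 4.
Here is my code: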
import tensorflow as tf
sequence_length = 4
num_classes = 3
input_y = tf.placeholder(tf.int32, shape=[None, sequence_length, num_classes])
logits = tf.placeholder(tf.float32, shape=[None, None, num_classes])
dense_y = tf.argmax(input_y, -1, output_type=tf.int32)
log_likelihood, _ = tf.contrib.crf.crf_log_likelihood(logits, dense_y, sequence_length)
I get the following error:
File "", line 1
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/crf/python/ops/crf.py", line 182, in crf_log_likelihood
    transition_params)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/crf/python/ops/crf.py", line 109
    false_fn=_multi_seq_fn)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/layers/utils.py", line 206, in smart_cond
    pred, true_fn=true_fn, false_fn=false_fn, name=name)
File "...", line 59, in smart_cond
    name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/util/deprecation.py", line 432, in new_func
    return func(*args, **kwargs)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 2063, in cond
    orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/control_flow_ops.py", line 1913, in BuildCondBranch
    original_result = fn()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/crf/python/ops/crf.py", line 95, in _single_seq_fn
    array_ops.concat(example_inds, tag_indices)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 2975, in gather_nd
    "GatherNd", params=params, indices=indices, name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
    op_def=op_def)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 3392, in create_op
    op_def=op_def)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1734, in __init__
    control_input_ops)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1570, in _create_c_op
    raise ValueError(str(e))
ValueError: indices.shape[-1] must be <= params.rank, but saw indices shape: [?,5] and params shape: [?,3] for 'cond/GatherNd' (op: 'GatherNd') with input shapes: [?,3], [?,5].
Posted on 2018-08-24 04:16:42
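The error is caused by the sequence-length argument having the wrong rank. It must be a vector with one length per example in the batch, not a scalar.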
import tensorflow as tf
num_classes = 3
# Token ids, shape [batch_size, max_seq_len]; this assumes 0 is the padding id
input_x = tf.placeholder(tf.int32, shape=[None, None], name="input_x")
# One-hot labels, shape [batch_size, max_seq_len, num_classes]
input_y = tf.placeholder(tf.int32, shape=[None, None, num_classes])
# Per-example lengths, shape [batch_size]: number of non-zero ids in each row
sequence_length = tf.reduce_sum(tf.sign(input_x), 1)
# After some network operation you will come up with logits
logits = tf.placeholder(tf.float32, shape=[None, None, num_classes])
# Convert one-hot labels to integer class indices
dense_y = tf.argmax(input_y, -1, output_type=tf.int32)
log_likelihood, _ = tf.contrib.crf.crf_log_likelihood(logits, dense_y, sequence_length)
https://stackoverflow.com/questions/51105062
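To make the shape requirement concrete, here is a minimal pure-Python sketch of what the sequence-length vector contains. It is an illustration, not TensorFlow code: the helper `lengths_from_padded` and the sample batch are hypothetical, and it assumes token ids are positive with 0 used for padding, which is the case in which `tf.reduce_sum(tf.sign(input_x), 1)` yields the same counts.

```python
def lengths_from_padded(batch, pad_id=0):
    """Return one length per row: the number of non-padding ids in that row."""
    return [sum(1 for token_id in row if token_id != pad_id) for row in batch]


# Hypothetical padded batch of token ids (assumption: ids > 0, padding is 0).
padded_batch = [
    [5, 2, 9, 4],
    [7, 3, 0, 0],
    [1, 0, 0, 0],
]

# One entry per example, i.e. a vector of shape [batch_size], not a scalar.
sequence_lengths = lengths_from_padded(padded_batch)
print(sequence_lengths)

# Fixed-length case from the question: every sequence has length 4,
# so the vector simply repeats 4 once per example in the batch.
fixed_lengths = [4] * len(padded_batch)
print(fixed_lengths)
```

For the fixed-length setup in the question, a graph-side equivalent would presumably be something like `tf.fill([tf.shape(logits)[0]], 4)`, which builds a length-`batch_size` vector instead of passing the scalar 4.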