首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow - TextSum模型:如何创建自己的训练数据

Tensorflow - TextSum模型:如何创建自己的训练数据
EN

Stack Overflow用户
提问于 2016-08-27 09:05:51
回答 2查看 1.6K关注 0票数 1

我正在尝试为TextSum模型创建自己的训练数据。据我所知,我需要将我的文章和摘要放到一个二进制文件中(在TFRecords中)。但是,我不能从原始文本文件创建我自己的训练数据。我对格式的理解不是很清楚,所以我尝试使用以下代码创建一个非常简单的二进制文件:

代码语言:javascript
复制
files = os.listdir(path)
writer = tf.python_io.TFRecordWriter("test_data")
for i, file in enumerate(files):
    content = open(os.path.join(path, file), "r").read()
    example = tf.train.Example(
        features = tf.train.Features(
            feature = {
                'content': tf.train.Feature(bytes_list=tf.train.BytesList(value=[content]))
            }
        )
    )

    serialized = example.SerializeToString()
    writer.write(serialized)

我尝试使用以下代码来读出这个test_data文件的值

代码语言:javascript
复制
reader = open("test_data", 'rb')
len_bytes = reader.read(8)
str_len = struct.unpack('q', len_bytes)[0]
example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
example_pb2.Example.FromString(example_str)

但我总是得到以下错误:

代码语言:javascript
复制
  File "dailymail_corpus_to_tfrecords.py", line 34, in check_file
    example_pb2.Example.FromString(example_str)
  File "/home/s1510032/anaconda2/lib/python2.7/site-packages/google/protobuf/internal/python_message.py", line 770, in FromString
    message.MergeFromString(s)
  File "/home/s1510032/anaconda2/lib/python2.7/site-packages/google/protobuf/internal/python_message.py", line 1091, in MergeFromString
    if self._InternalParse(serialized, 0, length) != length:
  File "/home/s1510032/anaconda2/lib/python2.7/site-packages/google/protobuf/internal/python_message.py", line 1117, in InternalParse
    new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
  File "/home/s1510032/anaconda2/lib/python2.7/site-packages/google/protobuf/internal/decoder.py", line 850, in SkipField
    return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
  File "/home/s1510032/anaconda2/lib/python2.7/site-packages/google/protobuf/internal/decoder.py", line 791, in _SkipLengthDelimited
    raise _DecodeError('Truncated message.')
google.protobuf.message.DecodeError: Truncated message.

我不知道哪里出了问题。如果你有任何解决这个问题的建议,请告诉我。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2016-08-27 12:30:10

对于那些有同样问题的人。我不得不看一下TensorFlow的源代码,看看他们是如何用TFRecordWriter写出数据的。我意识到他们实际上写了8个字节用于长度,4个字节用于CRC校验,这意味着前12个字节用于报头。因为在TextSum代码中,示例二进制文件似乎只有8字节头,这就是为什么他们使用reader.read(8)来获取数据的长度,并将其余的读取为特征。

我的工作解决方案是:

代码语言:javascript
复制
reader = open("test_data", 'rb')
len_bytes = reader.read(8)
reader.read(4) #ignore next 4 bytes
str_len = struct.unpack('q', len_bytes)[0]
example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
example_pb2.Example.FromString(example_str)
票数 3
EN

Stack Overflow用户

发布于 2016-09-28 08:38:59

我希望你的textsum目录中有data_convert_example.py。如果没有,你可以在这篇文章中找到:https://github.com/tensorflow/models/pull/379/files

使用python文件将给定的二进制玩具数据(文件名: data目录中的数据)转换为文本格式。python data_convert_example.py --command binary_to_text --in_file ../data/data --out_file ../data/result_text

您可以看到在result_text格式中应该给出的实际文本格式。

以这种格式准备数据,并使用相同的python脚本从text_to_binary进行转换,并将结果用于训练/测试/评估。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/39176529

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档