首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tensorflow cifar10教程失败

tensorflow cifar10教程失败
EN

Stack Overflow用户
提问于 2016-12-28 12:04:38
回答 1查看 2K关注 0票数 4

我已经从教程CIFAR10的链接下载了这里代码,并试图运行本教程。我用命令运行它

代码语言:javascript
复制
python cifar10_train.py

它启动ok并按预期下载数据文件。当它试图打开输入文件时,它会失败,跟踪如下:

代码语言:javascript
复制
Traceback (most recent call last):
  File "cifar10_train.py", line 120, in <module>
    tf.app.run()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 43, in run
    sys.exit(main(sys.argv[:1] + flags_passthrough))
  File "cifar10_train.py", line 116, in main
    train()
  File "cifar10_train.py", line 63, in train
    images, labels = cifar10.distorted_inputs()
  File "/notebooks/Python Scripts/tensorflowModels/tutorials/image/cifar10/cifar10.py", line 157, in distorted_inputs
    batch_size=FLAGS.batch_size)
  File "/notebooks/Python Scripts/tensorflowModels/tutorials/image/cifar10/cifar10_input.py", line 161, in distorted_inputs
    read_input = read_cifar10(filename_queue)
  File "/notebooks/Python Scripts/tensorflowModels/tutorials/image/cifar10/cifar10_input.py", line 87, in read_cifar10
    tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
TypeError: strided_slice() takes at least 4 arguments (3 given)

当然,当我研究代码时,cifar10_input.py中有一个对strided_slice()的调用,其中只有3个参数:

代码语言:javascript
复制
tf.strided_slice(record_bytes, [0], [label_bytes])

而tensorflow文档确实规定必须至少有4个论点。

出什么问题了?我下载了最新的tensorflow (0.12),并运行cifar代码的主分支。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2016-12-29 10:13:35

github进行了一些讨论之后,我进行了以下修改,这些更改似乎使其工作起来:

在cifar10_input.py中

代码语言:javascript
复制
-  result.label = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
+  result.label = tf.cast(tf.slice(record_bytes, [0], [label_bytes]), tf.int32)



-  depth_major = tf.reshape( tf.strided_slice(record_bytes, [label_bytes], [label_bytes + image_bytes]),      [result.depth, result.height, result.width])
+  depth_major = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]), [result.depth, result.height, result.width])

然后,在cifar10_input.py和cifar10.py中,我都必须搜索“废弃的”,无论在哪里找到它,根据我在api指南中看到的内容(希望是正确的),用一个有效的函数替换它。这方面的例子如下:

代码语言:javascript
复制
-  tf.contrib.deprecated.image_summary('images', images)
+  tf.summary.image('images', images)

代码语言:javascript
复制
 - tf.contrib.deprecated.histogram_summary(tensor_name + '/activations', x)
 - tf.contrib.deprecated.scalar_summary(tensor_name + '/sparsity',
 + tf.summary.histogram(tensor_name + '/activations', x)
 + tf.summary.scalar(tensor_name + '/sparsity',

它现在似乎正快活地前进着。我将看看它是否完成了OK,以及上面所做的更改是否提供了所需的诊断输出。

我还是希望听到接近密码的人给我一个明确的答案。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/41361766

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档