首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow预训练模型输入误差

Tensorflow预训练模型输入误差
EN

Stack Overflow用户
提问于 2018-08-08 04:23:19
回答 1查看 75关注 0票数 0

我正在研究tensorflow中的对象检测模型。我有一个文件model.py

代码语言:javascript
复制
from PIL import Image   
import cv2   
import numpy as np   
import tensorflow as tf   
from .squeezenet import SqueezeNet

save_path = "sqnet/squeezenet.ckpt"
sess = tf.Session()
model = SqueezeNet(save_path=save_path, sess=sess)

class Finder(object):
    def __init__(self, image_path):
        self.image_path = image_path

    def predict(self):
        image = process(self.image_path)
        ans = sess.run(model.classifier, feed_dict={model.image: 
                       image})
        return ans


def process(path):
    image = Image.open(path)
    # image.show()
    image = np.array(image)
    image = cv2.resize(image, dsize=(224, 224), 
                       interpolation=cv2.INTER_CUBIC)
    image = image.reshape((1, 224, 224, 3))
    #print(image.shape)
    #img = Image.fromarray(image, 'RGB')
    return image


image_path = "/home/jatin/ai.jpeg"

object_detector = Finder(image_path)

ans = object_detector.predict()

print(np.argmax(ans))

sess.close()

我在model.py文件旁边有一个名为squuezenet.cpkt的文件夹,其中有squuezenet.cpkt文件。但是运行这会导致错误:

InvalidArgumentError (参见上面的回溯):不成功的TensorSliceReader构造函数:未能在sqnet/PRECZenet.ckpt: Not :sqnet上获得匹配文件;没有这样的文件或目录。

有什么问题吗?

EN

回答 1

Stack Overflow用户

发布于 2018-08-08 07:17:31

对我来说似乎是一个简单的IO错误。你试过使用绝对路径吗?

代码语言:javascript
复制
save_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'sqnet', 'squeezenet.ckpt')
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/51738581

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档