首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >我正在用TensorFlow创建CNN函数,但我得到了一个与形状相关的错误

我正在用TensorFlow创建CNN函数,但我得到了一个与形状相关的错误
EN

Stack Overflow用户
提问于 2019-08-15 15:43:13
回答 2查看 56关注 0票数 0

我用Tensorflow尝试了卷积神经网络。

但是,形状会导致错误。

第一个是main函数的一部分。

代码语言:javascript
复制
while True:
        with mss.mss() as sct:
                Game_Scr = np.array(sct.grab(Game_Scr_pos))[:,:,:3]

                cv2.imshow('Game_Src', Game_Scr)
                cv2.waitKey(0)

                Game_Scr = cv2.resize(Game_Scr, dsize=(960, 540), interpolation=cv2.INTER_AREA)
                print(Game_Scr.shape)

                print(Convolution(Game_Scr))

第二个是我调用的函数。

代码语言:javascript
复制
def Convolution(img):
        kernel = tf.Variable(tf.truncated_normal(shape=[4], stddev=0.1))
        sess = tf.Session()
        with tf.Session() as sess:
                img = img.astype('float32')
                Bias1 = tf.Variable(tf.truncated_normal(shape=[4],stddev=0.1))
                conv2d = tf.nn.conv2d(img, kernel, strides=[1, 1, 1, 1], padding='SAME')# + Bias1
                conv2d = sess.run(conv2d)
        return conv2d

ValueError:形状的等级必须为4,但对于输入形状为540,960,3,4的'Conv2D‘(op:'Conv2D'),形状的等级为3。

我试着改变形状很多次,但我得到相同的错误。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2019-08-15 16:46:51

根据官方文档here,输入张量的形状应该是[batch, in_height, in_width, in_channels],而滤波器/核张量的形状应该是[filter_height, filter_width, in_channels, out_channels]

尝试将您的Convolution函数更改为如下所示:

代码语言:javascript
复制
def Convolution(img):
        kernel = tf.Variable(tf.truncated_normal(shape=[200, 200, 3, 3], stddev=0.1))
        sess = tf.Session()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            img = img.astype('float32')
            conv2d = tf.nn.conv2d(np.expand_dims(img, 0), kernel, strides=[1, 1, 1, 1], padding='SAME')# + Bias1
            conv2d = sess.run(conv2d)
        return conv2d
票数 0
EN

Stack Overflow用户

发布于 2019-08-15 16:46:25

尝试替换

img = img.astype('float32')

使用

img = tf.expand_dims(img.astype('float32'), 0)

tf.nn.conv2d输入的维数应为4,(batch_size,image_hight,image_with,image_channels)。如果缺少batch_size,tf.expand_dims只需添加该维度( batch_size为1,因为您只有一张图像)。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/57506352

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档