首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >TensorFlow -矩阵乘法转换为浮点型的矩阵需要很长时间,为什么?

TensorFlow -矩阵乘法转换为浮点型的矩阵需要很长时间,为什么?
EN

Stack Overflow用户
提问于 2020-08-27 02:12:48
回答 1查看 154关注 0票数 0

TensorFlow2.x中的以下矩阵乘法需要很长时间才能执行

代码语言:javascript
复制
    a = tf.random.uniform(shape=(9180, 3049))
    b = tf.random.uniform(shape=(3049, 1913))
    a = tf.cast(a ,tf.float16)
    b = tf.cast(b ,tf.float16)
    tf.matmul(a,b)

但是如果我简单地使用下面的方法,它是很快的

代码语言:javascript
复制
    a = tf.random.uniform(shape=(9180, 3049))
    b = tf.random.uniform(shape=(3049, 1913))
    tf.matmul(a,b)

为甚麽呢?出于某种目的,我需要将张量转换为浮点数。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-08-27 02:56:49

实际上,在这两种情况下,您都在尝试对浮点值进行矩阵乘法运算。在第一种情况下,您使用float16,在第二种情况下,您使用float32。

代码语言:javascript
复制
import tensorflow as tf
import time
a = tf.random.uniform(shape=(9180, 3049), seed = 10)
b = tf.random.uniform(shape=(3049, 1913), seed = 10)

第一次运行

代码语言:javascript
复制
x2 = a
y2 = b
s = time.time()
r2 = tf.matmul(x2,y2)
e = time.time()
print((e-s)*1000)

x1 = tf.cast(a ,tf.float16)
y1 = tf.cast(b ,tf.float16)
s = time.time()
r1 = tf.matmul(x1,y1)
e = time.time()
print((e-s)*1000)

输出:

代码语言:javascript
复制
184.76319313049316
0.0

在重启我的内核后第二次运行。

代码语言:javascript
复制
x1 = tf.cast(a ,tf.float16)
y1 = tf.cast(b ,tf.float16)
s = time.time()
r1 = tf.matmul(x1,y1)
e = time.time()
print((e-s)*1000)

x2 = a
y2 = b
s = time.time()
r2 = tf.matmul(x2,y2)
e = time.time()
print((e-s)*1000)

输出:

代码语言:javascript
复制
183.03942680358887
1.0335445404052734

现在,如果我再次运行相同的代码,而不重新启动内核,即使在更改了a和b的值之后。

代码语言:javascript
复制
x1 = tf.cast(a ,tf.float16)
y1 = tf.cast(b ,tf.float16)
s = time.time()
r1 = tf.matmul(x1,y1)
e = time.time()
print((e-s)*1000)

x2 = a
y2 = b
s = time.time()
r2 = tf.matmul(x2,y2)
e = time.time()
print((e-s)*1000)

输出:

代码语言:javascript
复制
0.0
0.0

所以从本质上讲,这不是TensorFlow的问题。Tensorflow以图的形式执行。当您第一次运行它时,它会使用前面提到的数据结构初始化图形,并对其进行优化以进行进一步计算。看看this中的最后一条评论。

因此,第二次执行操作的速度会更快

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63603468

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档