TensorFlow2.x中的以下矩阵乘法需要很长时间才能执行
a = tf.random.uniform(shape=(9180, 3049))
b = tf.random.uniform(shape=(3049, 1913))
a = tf.cast(a ,tf.float16)
b = tf.cast(b ,tf.float16)
tf.matmul(a,b)但是如果我简单地使用下面的方法,它是很快的
a = tf.random.uniform(shape=(9180, 3049))
b = tf.random.uniform(shape=(3049, 1913))
tf.matmul(a,b)为甚麽呢?出于某种目的,我需要将张量转换为浮点数。
发布于 2020-08-27 02:56:49
实际上,在这两种情况下,您都在尝试对浮点值进行矩阵乘法运算。在第一种情况下,您使用float16,在第二种情况下,您使用float32。
import tensorflow as tf
import time
a = tf.random.uniform(shape=(9180, 3049), seed = 10)
b = tf.random.uniform(shape=(3049, 1913), seed = 10)第一次运行
x2 = a
y2 = b
s = time.time()
r2 = tf.matmul(x2,y2)
e = time.time()
print((e-s)*1000)
x1 = tf.cast(a ,tf.float16)
y1 = tf.cast(b ,tf.float16)
s = time.time()
r1 = tf.matmul(x1,y1)
e = time.time()
print((e-s)*1000)输出:
184.76319313049316
0.0在重启我的内核后第二次运行。
x1 = tf.cast(a ,tf.float16)
y1 = tf.cast(b ,tf.float16)
s = time.time()
r1 = tf.matmul(x1,y1)
e = time.time()
print((e-s)*1000)
x2 = a
y2 = b
s = time.time()
r2 = tf.matmul(x2,y2)
e = time.time()
print((e-s)*1000)输出:
183.03942680358887
1.0335445404052734现在,如果我再次运行相同的代码,而不重新启动内核,即使在更改了a和b的值之后。
x1 = tf.cast(a ,tf.float16)
y1 = tf.cast(b ,tf.float16)
s = time.time()
r1 = tf.matmul(x1,y1)
e = time.time()
print((e-s)*1000)
x2 = a
y2 = b
s = time.time()
r2 = tf.matmul(x2,y2)
e = time.time()
print((e-s)*1000)输出:
0.0
0.0所以从本质上讲,这不是TensorFlow的问题。Tensorflow以图的形式执行。当您第一次运行它时,它会使用前面提到的数据结构初始化图形,并对其进行优化以进行进一步计算。看看this中的最后一条评论。
因此,第二次执行操作的速度会更快
https://stackoverflow.com/questions/63603468
复制相似问题