Saranraj K

Reputation: 430

TensorFlow - Matrix multiplication of matrices cast to float type takes a very long time. Why?

The following matrix multiplication in TensorFlow 2.x takes a very long time to execute:

    import tensorflow as tf

    a = tf.random.uniform(shape=(9180, 3049))
    b = tf.random.uniform(shape=(3049, 1913))
    a = tf.cast(a, tf.float16)
    b = tf.cast(b, tf.float16)
    tf.matmul(a, b)

But if I leave out the cast, as below, it runs fast:

    import tensorflow as tf

    a = tf.random.uniform(shape=(9180, 3049))
    b = tf.random.uniform(shape=(3049, 1913))
    tf.matmul(a, b)

Why is this? I need to cast the tensors to a float type (float16 in the code above) for another purpose.

Upvotes: 0

Views: 700

Answers (1)

Aniket Bote

Reputation: 3564

Actually, in both cases you are multiplying floating-point matrices. The first case uses float16, and the second uses float32, which is the default dtype of tf.random.uniform.
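To confirm which dtypes are actually involved, here is a minimal sketch, assuming TensorFlow 2.x is installed (the small shapes are only there so it runs quickly). It prints the dtype of each input and of each product:

```python
import tensorflow as tf

# Small shapes: only the dtypes matter here, not the timing.
a = tf.random.uniform(shape=(4, 3))
b = tf.random.uniform(shape=(3, 2))
print(a.dtype)  # tf.random.uniform defaults to float32

a16 = tf.cast(a, tf.float16)
b16 = tf.cast(b, tf.float16)
print(a16.dtype)  # half precision after the cast

# Both products are floating-point matmuls; only the precision differs.
r32 = tf.matmul(a, b)
r16 = tf.matmul(a16, b16)
print(r32.dtype, r16.dtype)
```

The result of tf.matmul keeps the dtype of its inputs, so the float16 product is itself float16.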

    import tensorflow as tf
    import time

    a = tf.random.uniform(shape=(9180, 3049), seed=10)
    b = tf.random.uniform(shape=(3049, 1913), seed=10)

1st run (float32 first, then float16):

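    # float32: the tensors as created by tf.random.uniform
    x2 = a
    y2 = b
    s = time.time()
    r2 = tf.matmul(x2, y2)
    e = time.time()
    print((e - s) * 1000)  # elapsed time in milliseconds

    # float16: the same values cast to half precision
    x1 = tf.cast(a, tf.float16)
    y1 = tf.cast(b, tf.float16)
    s = time.time()
    r1 = tf.matmul(x1, y1)
    e = time.time()
    print((e - s) * 1000)  # elapsed time in milliseconds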

Output:

    184.76319313049316
    0.0

2nd run, after restarting my kernel (float16 first, then float32):

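    # float16 now runs first
    x1 = tf.cast(a, tf.float16)
    y1 = tf.cast(b, tf.float16)
    s = time.time()
    r1 = tf.matmul(x1, y1)
    e = time.time()
    print((e - s) * 1000)  # elapsed time in milliseconds

    # float32 runs second
    x2 = a
    y2 = b
    s = time.time()
    r2 = tf.matmul(x2, y2)
    e = time.time()
    print((e - s) * 1000)  # elapsed time in milliseconds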

Output:

    183.03942680358887
    1.0335445404052734

Now, if I run the same code again without restarting the kernel, both timings drop, even after changing the values of a and b:

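    # Same code as the 2nd run, in the same kernel session
    x1 = tf.cast(a, tf.float16)
    y1 = tf.cast(b, tf.float16)
    s = time.time()
    r1 = tf.matmul(x1, y1)
    e = time.time()
    print((e - s) * 1000)  # elapsed time in milliseconds

    x2 = a
    y2 = b
    s = time.time()
    r2 = tf.matmul(x2, y2)
    e = time.time()
    print((e - s) * 1000)  # elapsed time in milliseconds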

Output:

    0.0
    0.0

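So essentially the slowdown is not caused by float16 itself: whichever matmul runs first is the slow one. The first operation pays a one-time setup cost (for example, initializing the device and preparing the kernels it needs), and later calls reuse that setup, even with new values of the same shape and dtype. Note that TensorFlow 2.x executes operations eagerly by default; graphs are only built for code wrapped in tf.function. Take a look at the final comment in this.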
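The idea of building a graph once and reusing it can be illustrated with tf.function. This is a minimal sketch, assuming TensorFlow 2.x; the function matmul_graph and the traces list are hypothetical names chosen for illustration. The Python-level append runs only while TensorFlow traces a new graph, so traces records one entry per traced input signature. New values with the same shape and dtype reuse the existing graph, while a new dtype triggers a fresh trace:

```python
import tensorflow as tf

traces = []  # hypothetical helper: records each time a new graph is traced


@tf.function
def matmul_graph(x, y):
    # Plain Python side effect: executes during tracing only,
    # not on later calls that reuse an already traced graph.
    traces.append(x.dtype.name)
    return tf.matmul(x, y)


a = tf.random.uniform(shape=(64, 32))
b = tf.random.uniform(shape=(32, 16))

matmul_graph(a, b)              # first call: traces a float32 graph
matmul_graph(a * 2.0, b * 2.0)  # new values, same shape and dtype: reuses it
matmul_graph(tf.cast(a, tf.float16), tf.cast(b, tf.float16))  # new dtype: traces again
print(traces)
```

After the three calls, traces should hold one entry for float32 and one for float16, since the second call did not trigger a new trace.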

Therefore, the second execution of an operation will be faster.
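A sketch of a fairer measurement, assuming the goal is to separate the one-time first-call cost from the steady per-call cost. The helper time_matmul is hypothetical. It uses time.perf_counter, Python's clock intended for timing short durations, and calls .numpy() on each result so that the timer waits for the computation to finish; on a GPU, an eager op may return before its kernel has completed, so timing tf.matmul alone can measure little more than the dispatch:

```python
import time

import tensorflow as tf


def time_matmul(x, y, repeats=5):
    """Return (first_call_ms, mean_later_call_ms) for tf.matmul(x, y).

    .numpy() copies the result to host memory, which blocks until the
    matmul has actually finished, so the timer covers the real work.
    """
    start = time.perf_counter()
    tf.matmul(x, y).numpy()
    first_ms = (time.perf_counter() - start) * 1000

    later = []
    for _ in range(repeats):
        start = time.perf_counter()
        tf.matmul(x, y).numpy()
        later.append((time.perf_counter() - start) * 1000)
    return first_ms, sum(later) / len(later)


# Smaller shapes than in the question, so the sketch runs quickly.
a = tf.random.uniform(shape=(512, 256), seed=10)
b = tf.random.uniform(shape=(256, 128), seed=10)

for dtype in (tf.float32, tf.float16):
    x = tf.cast(a, dtype)
    y = tf.cast(b, dtype)
    first_ms, later_ms = time_matmul(x, y)
    print(f"{dtype.name}: first call {first_ms:.2f} ms, "
          f"later calls {later_ms:.2f} ms on average")
```

Because the loop always measures float32 first, that dtype also absorbs any process-wide startup cost; swapping the order in the tuple is an easy way to check. Any difference that remains in the "later calls" column reflects per-call speed on your hardware rather than warm-up.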

Upvotes: 1
