namespace-Pt

Reputation: 1924

Output of nn.Linear is different for the same input

In torch==1.7.1+cu101, I have two tensors

import torch
a = torch.rand(1,5,10)
b = torch.rand(100,1,10)

and a feed-forward network

import torch.nn as nn
l = nn.Linear(10,10)

I force one row of them to be equal:

a[0,0] = b[80][0].clone()

Then I feed both tensors to l:

r1 = l(a)
r2 = l(b)

Since a[0,0] is equal to b[80,0], r1[0,0] should be equal to r2[80,0]. But instead I get:

(r1[0,0] == r2[80,0]).all()
>>> False

I've fixed the randomness by:

import os
import random
import numpy as np
import torch

seed = 42

random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

Does anyone know why (r1[0,0] == r2[80,0]).all() is False?

Upvotes: 3

Views: 1027

Answers (1)

rkechols

Reputation: 597

If you were to print r1[0, 0] and r2[80, 0], you'd see that they are extremely similar, even identical to the number of digits that get printed.

However, if you print r1[0, 0] - r2[80, 0] you'll see that the resulting entries are not perfectly 0.0 (though they are close to it), meaning that r1[0, 0] and r2[80, 0] are close but not perfectly identical.
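To see this concretely, here is a minimal sketch of the setup from the question (the exact size of the difference, and whether it happens to be exactly zero, depends on hardware and backend):

```python
import torch

torch.manual_seed(42)

a = torch.rand(1, 5, 10)
b = torch.rand(100, 1, 10)
a[0, 0] = b[80][0].clone()

l = torch.nn.Linear(10, 10)
r1 = l(a)
r2 = l(b)

# The two results agree to within float32 rounding error,
# even when bitwise equality fails.
diff = (r1[0, 0] - r2[80, 0]).abs().max()
print(diff)  # a tiny value, possibly exactly 0 on some machines
```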

Now, if we were to take those individual vectors out first, and then pass them through the linear layer like this:

r1_small = l(a[0, 0])
r2_small = l(b[80, 0])
print((r1_small == r2_small).all())  # tensor(True)

we get that they are perfectly identical, even though they are floats.

So the difference is introduced when the identical vectors are passed through the linear layer as slices of larger tensors, rather than on their own.

It's also very much worth noting that the same difference does not arise when the first n-1 dimensions are all powers of 2:

a2 = torch.randn(8, 8, 10)
b2 = torch.randn(4, 16, 10)
a2[0, 0] = b2[1, 0].clone()
r1_2 = l(a2)
r2_2 = l(b2)
print((r1_2[0, 0] == r2_2[1, 0]).all())  # tensor(True)

So, while I don't know the details, I suspect it has to do with byte alignment.

In general, testing for perfect equality between float values that should be mathematically equal will not always give the expected result. To handle these small differences, use torch.isclose or torch.allclose to check for approximate equality instead of ==.
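A small sketch of the difference between the two kinds of comparison (the tolerances shown are torch.allclose's defaults):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])
y = x + 1e-7  # introduce a tiny float difference

# Bitwise equality fails because at least one element changed.
print((x == y).all())        # tensor(False)

# Approximate equality passes within rtol=1e-05, atol=1e-08.
print(torch.allclose(x, y))  # True
```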

Upvotes: 2
