Tom Hale

Reputation: 46795

Check if PyTorch tensors are equal within epsilon

How do I check if two PyTorch tensors are semantically equal?

Given floating-point errors, I want to know whether the elements differ only by a small epsilon value.

Upvotes: 17

Views: 11163

Answers (1)

Tom Hale

Reputation: 46795

At the time of writing, torch.allclose() is undocumented in the latest stable release (0.4.1), but its documentation is in the master (unstable) branch.

torch.allclose() returns a boolean indicating whether all corresponding elements of two tensors are equal within a margin of error, which is set by the rtol and atol arguments.
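To make the margin of error concrete: as far as I understand the documented behaviour, each pair of elements passes when |a - b| <= atol + rtol * |b|, with defaults rtol=1e-05 and atol=1e-08. Here is a rough pure-Python sketch of that check. The helper name allclose_reference is hypothetical (it is not part of PyTorch), and it ignores NaN and infinity handling.

```python
def allclose_reference(a, b, rtol=1e-05, atol=1e-08):
    """Hypothetical helper (not a PyTorch function) mirroring the tolerance check."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    # Every pair must satisfy |x - y| <= atol + rtol * |y|.
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


# A difference of 1e-9 is inside the default tolerance.
within = allclose_reference([1.0, 2.0], [1.0 + 1e-9, 2.0])
# A difference of 1e-3 is well outside it.
outside = allclose_reference([1.0, 2.0], [1.001, 2.0])
print(within, outside)
```

Note that the relative term scales with the second argument, so the check is not symmetric in a and b.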

Additionally, there's the undocumented isclose(), which returns the element-wise result instead of a single boolean:

>>> torch.isclose(torch.Tensor([1]), torch.Tensor([1.00000001]))
tensor([1], dtype=torch.uint8)
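Putting the two together, here is a small sketch of how I would use them. It assumes a recent PyTorch release in which torch.isclose() returns a bool tensor (on 0.4.1 it returns uint8, as shown above, so the ~ inversion would need adjusting). The tensor values and the atol of 1e-2 are made up for illustration.

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.001])

# Single yes/no answer with the default tolerances (rtol=1e-05, atol=1e-08).
all_equal = torch.allclose(a, b)

# Same check with a wider absolute tolerance (illustrative value).
loose_equal = torch.allclose(a, b, atol=1e-2)

# Element-wise mask; assumes isclose() returns a bool tensor (recent releases).
close_mask = torch.isclose(a, b)

# Indices of the elements that fall outside the tolerance.
mismatched = torch.nonzero(~close_mask).flatten().tolist()

print(all_equal, loose_equal, mismatched)
```

allclose() is the one to use in an assertion or an if statement, while isclose() is handy when you want to see which elements are off.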

Upvotes: 16
