Whisht

Reputation: 753

How to check whether a torch.Tensor's dtype is int or not?

I want to check whether a tensor converted from an image is normalized or not, i.e. whether its dtype is int or float. Is there a convenient way to achieve this? I do not want an enumerated condition check like a.dtype == torch.int or a.dtype == torch.int32 or a.dtype == torch.uint8 ... Or is there another way to check whether an image tensor is normalized?

Upvotes: 5

Views: 4370

Answers (2)

LemmeTestThat

Reputation: 658

As mentioned by other answers and comments, torch.is_floating_point will cover most situations.

However, do note that torch.is_floating_point will return False for complex dtypes, which are not integers. It seems that, as of now, PyTorch is missing a torch.is_integer check.

As a workaround to cover more cases than just int vs. float, one may consider using the following condition:

not torch.is_floating_point(my_tensor) and not torch.is_complex(my_tensor)
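A minimal sketch of that condition applied to a few dtypes, wrapped in a hypothetical helper named is_integer_like (not a PyTorch function). It shows that torch.is_floating_point is False for complex tensors, and that the workaround separates them from integer tensors. Note that torch.bool also passes the check, since it is neither floating point nor complex.

```python
import torch

def is_integer_like(t: torch.Tensor) -> bool:
    # Hypothetical helper: True when the dtype is neither floating point
    # nor complex. torch.bool tensors also return True here.
    return not torch.is_floating_point(t) and not torch.is_complex(t)

samples = {
    "uint8": torch.zeros(2, dtype=torch.uint8),
    "int64": torch.zeros(2, dtype=torch.int64),
    "float32": torch.zeros(2, dtype=torch.float32),
    "complex64": torch.zeros(2, dtype=torch.complex64),
    "bool": torch.zeros(2, dtype=torch.bool),
}

for name, t in samples.items():
    # Columns: dtype name, torch.is_floating_point, is_integer_like
    print(name, torch.is_floating_point(t), is_integer_like(t))
```

If bool tensors should not count as integers, an extra `t.dtype != torch.bool` check can be added to the helper.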

Upvotes: 2

Shai

Reputation: 114796

As pointed out by jasonharper, you can use torch.is_floating_point to check your tensor:

import torch

if torch.is_floating_point(my_tensor):
    # float dtype: the tensor is assumed to be "normalized" already
    pass
else:
    # integer dtype (e.g. a uint8 image): the tensor needs to be normalized
    pass
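A minimal sketch of filling in the "needs to be normalized" branch, using a hypothetical helper to_float_image. It assumes integer images are 8-bit (values 0 to 255), so dividing by 255 maps them into [0, 1], similar to what torchvision's ToTensor does for uint8 images. Other integer dtypes (e.g. 16-bit images) would need a different scale.

```python
import torch

def to_float_image(img: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper. Assumes integer images are 8-bit (0..255).
    if torch.is_floating_point(img):
        # Already floating point: treated as "normalized", returned unchanged.
        return img
    # Integer image: convert to float32 and scale into [0, 1].
    return img.to(torch.float32) / 255.0

img_u8 = torch.tensor([[0, 128, 255]], dtype=torch.uint8)
out = to_float_image(img_u8)
print(out.dtype)         # torch.float32
print(out.max().item())  # 1.0
```

A float dtype only says the tensor was converted, not that its values lie in [0, 1]; the helper simply trusts that assumption.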

Upvotes: 5
