Jenny I

Reputation: 101

How can I plot a PyTorch tensor?

I would like to plot a PyTorch GPU tensor:

import torch
import matplotlib.pyplot as plt

device = torch.device('cuda')
input = torch.randn(100).to(device)
output = torch.where(input >= 0, input, -input)

input = input.('cpu').detach().numpy().copy()
output = output.('cpu').detach().numpy().copy()

plt.plot(input,out)

However, when I try to convert those tensors to CPU NumPy arrays, it does not work. How can I plot the tensors?

Upvotes: 4

Views: 26798

Answers (1)

anarchy

Reputation: 5184

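input.('cpu') is not valid Python syntax; moving a tensor to the CPU needs a method call such as .cpu() or .to('cpu'). Does this work?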

plt.plot(input.cpu().numpy(),output.cpu().numpy())

Alternatively, you can try:

plt.plot(input.to('cpu').numpy(),output.to('cpu').numpy())
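Both calls above work here because input and output do not require gradients. A slightly more defensive sketch follows, assuming a CUDA device may or may not be present. It calls .detach() before .cpu().numpy(), because .numpy() raises a RuntimeError on any tensor with requires_grad=True. The helper name to_numpy, the variable names x and y (used instead of shadowing Python's built-in input), and the output file name abs_plot.png are hypothetical choices for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import torch

# Use the GPU if one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.randn(100, device=device)
y = torch.where(x >= 0, x, -x)  # element-wise absolute value, as in the question


def to_numpy(t: torch.Tensor) -> np.ndarray:
    """Hypothetical helper: drop autograd history, copy to host memory, return a NumPy array."""
    return t.detach().cpu().numpy()


x_np = to_numpy(x)
y_np = to_numpy(y)

# A tensor that requires grad cannot call .numpy() directly (RuntimeError);
# detaching first, as to_numpy does, avoids that.
w = torch.randn(5, device=device, requires_grad=True)
w_np = to_numpy(w)

# Sort by x so plot() draws one continuous |x| curve instead of zig-zag segments.
order = np.argsort(x_np)
fig, ax = plt.subplots()
ax.plot(x_np[order], y_np[order])
ax.set_xlabel("input")
ax.set_ylabel("output")
fig.savefig("abs_plot.png")  # with an interactive backend, plt.show() works instead
```

The trailing .copy() in the question is only needed if you intend to modify the array in place: for a tensor that already lives on the CPU, .numpy() returns an array that shares memory with the tensor, while .cpu() on a GPU tensor already produces a separate copy.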

Upvotes: 1
