henry147

Construct a rotation matrix in PyTorch

I want to construct a rotation matrix whose Euler angles are unknown, and then use regression to find the values of those Euler angles. My code is here:

import torch
from torch import cos, sin

roll = yaw = pitch = torch.randn(1,requires_grad=True)
RX = torch.tensor([
                [1, 0, 0],
                [0, cos(roll), -sin(roll)],
                [0, sin(roll), cos(roll)]
            ],requires_grad=True)
RY = torch.tensor([
                [cos(pitch), 0, sin(pitch)],
                [0, 1, 0],
                [-sin(pitch), 0, cos(pitch)]
            ],requires_grad=True)
RZ = torch.tensor([
                [cos(yaw), -sin(yaw), 0],
                [sin(yaw), cos(yaw), 0],
                [0, 0, 1]
            ],requires_grad=True)
R = torch.mm(RZ, RY).requires_grad_()
R = torch.mm(R, RX).requires_grad_()
R = R.mean().requires_grad_()
R.backward()

The gradient does not propagate from the matrix back to the Euler angles: after R.backward(), the angles have no gradient value. Can anyone help me fix this? Thanks!

Answers (1)

haojie yuan

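torch.tensor copies the values it is given into a brand-new leaf tensor, so the result has no autograd history: it is disconnected from roll, pitch and yaw, and gradients cannot backpropagate through it.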
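A small check of this behaviour, as a sketch assuming a reasonably recent PyTorch version: torch.tensor(x) is documented to act like x.clone().detach(), while torch.stack is an ordinary differentiable operation. The names angle, copied and stacked are only for illustration.

```python
import warnings

import torch

angle = torch.randn(1, requires_grad=True)
c = torch.cos(angle)  # has a grad_fn, so it is connected to `angle`

# torch.tensor(x) copies the data like x.clone().detach(): the result is a
# new leaf with no history. PyTorch warns about this, so the warning is silenced.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    copied = torch.tensor(c, requires_grad=True)

# torch.stack is a regular differentiable op, so the graph back to `angle` is kept.
stacked = torch.stack([c, torch.sin(angle)])

copied.sum().backward()
grad_after_copy = angle.grad    # None: no gradient reached `angle`

stacked.sum().backward()
grad_after_stack = angle.grad   # cos(angle) - sin(angle), the derivative of cos + sin

print(copied.grad_fn, stacked.grad_fn)
print(grad_after_copy, grad_after_stack)
```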

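A quick-and-dirty way to fix your code is to build the matrices with torch.stack, which is differentiable, and to use three separate angle tensors: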

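import torch
from torch import cos, sin

# Three separate leaf tensors (in the question, all three names point to the same tensor)
roll = torch.randn(1,requires_grad=True)
yaw = torch.randn(1,requires_grad=True)
pitch = torch.randn(1,requires_grad=True)

# Constant entries with the same shape (1,) as the angles, so they can be stacked together
tensor_0 = torch.zeros(1)
tensor_1 = torch.ones(1)

# torch.stack keeps the autograd graph: each row stacks to shape (3, 1),
# the rows stack to (3, 3, 1), and reshape turns that into a 3x3 matrix
RX = torch.stack([
                torch.stack([tensor_1, tensor_0, tensor_0]),
                torch.stack([tensor_0, cos(roll), -sin(roll)]),
                torch.stack([tensor_0, sin(roll), cos(roll)])]).reshape(3,3)

RY = torch.stack([
                torch.stack([cos(pitch), tensor_0, sin(pitch)]),
                torch.stack([tensor_0, tensor_1, tensor_0]),
                torch.stack([-sin(pitch), tensor_0, cos(pitch)])]).reshape(3,3)

RZ = torch.stack([
                torch.stack([cos(yaw), -sin(yaw), tensor_0]),
                torch.stack([sin(yaw), cos(yaw), tensor_0]),
                torch.stack([tensor_0, tensor_0, tensor_1])]).reshape(3,3)

# R = RZ @ RY @ RX; no requires_grad_() calls are needed, the graph is tracked automatically
R = torch.mm(RZ, RY)
R = torch.mm(R, RX)
R_mean = R.mean()

R_mean.backward()

# The gradients now reach the angles
print(roll.grad, pitch.grad, yaw.grad)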
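Building on the same torch.stack idea, here is a minimal sketch of the regression the question asks about. It assumes the same convention R = RZ·RY·RX and fits the angles with plain gradient descent on the squared Frobenius distance to a known target matrix. The helper euler_to_matrix, the made-up target angles, the learning rate and the step count are hypothetical choices, not part of the original answer.

```python
import torch


def euler_to_matrix(roll, pitch, yaw):
    """Hypothetical helper: returns R = RZ(yaw) @ RY(pitch) @ RX(roll).

    The angles are 0-dim tensors; torch.stack keeps them in the autograd graph.
    """
    zero = torch.zeros_like(roll)
    one = torch.ones_like(roll)
    cr, sr = torch.cos(roll), torch.sin(roll)
    cp, sp = torch.cos(pitch), torch.sin(pitch)
    cy, sy = torch.cos(yaw), torch.sin(yaw)

    rx = torch.stack([one, zero, zero,
                      zero, cr, -sr,
                      zero, sr, cr]).reshape(3, 3)
    ry = torch.stack([cp, zero, sp,
                      zero, one, zero,
                      -sp, zero, cp]).reshape(3, 3)
    rz = torch.stack([cy, -sy, zero,
                      sy, cy, zero,
                      zero, zero, one]).reshape(3, 3)
    return rz @ ry @ rx


# Made-up ground truth (roll, pitch, yaw), used only to generate a target matrix
true_angles = torch.tensor([0.3, -0.5, 1.2])
target = euler_to_matrix(*true_angles)

# Unknown angles to regress, starting from zero (the identity rotation)
angles = torch.zeros(3, requires_grad=True)
optimizer = torch.optim.SGD([angles], lr=0.05)

for step in range(1000):
    optimizer.zero_grad()
    # Squared Frobenius distance between the predicted and target matrices
    loss = ((euler_to_matrix(*angles) - target) ** 2).sum()
    loss.backward()
    optimizer.step()

final_loss = loss.item()
print(angles.detach(), final_loss)  # angles should end up close to true_angles
```

Euler angles are only unique up to the usual ambiguities (and degenerate near gimbal lock at pitch = ±π/2), so a different initial guess may converge to another set of angles that produces the same matrix.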