Reputation: 7642
For some matrices in my batch I'm getting an exception due to the matrix being singular.
L = th.cholesky(Xt.bmm(X))
cholesky_cpu: For batch 51100: U(22,22) is zero, singular U
Since these cases are few for my use case, I would like to ignore the exception and deal with them afterwards, e.g. by setting the failed results to NaN. Is that possible somehow?
Even if I catch the exception and use continue, the calculation for the rest of the batch still doesn't finish.
The same happens in C++ with PyTorch's libtorch.
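For illustration, here is a minimal self-contained reproduction of the behaviour I'm describing (the shapes and data are made up just to trigger the failure):

import torch as th

# toy batch: two well-conditioned matrices and one singular one
X = th.randn(3, 5, 4)
X[2, :, :] = 0.0                 # makes Xt.bmm(X) singular for batch index 2
Xt = X.transpose(1, 2)

try:
    L = th.cholesky(Xt.bmm(X))   # raises for the whole batch, not just index 2
except RuntimeError as e:
    # catching here doesn't help: no factorization is returned for the
    # non-singular matrices in the batch either
    print(e)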
Upvotes: 1
Views: 2105
Reputation: 7642
According to the PyTorch Discuss forum, it's not possible to catch the exception.
The solution, unfortunately, was to implement my own simple batched Cholesky (the equivalent of th.cholesky(..., upper=False)) and then deal with the NaN values using th.isnan.
import torch as th

# nograd cholesky: lower-triangular, batched; singular inputs yield
# NaN/Inf entries instead of raising an exception
def cholesky(A):
    L = th.zeros_like(A)
    for i in range(A.shape[-1]):
        for j in range(i + 1):
            s = 0.0
            for k in range(j):
                s = s + L[..., i, k] * L[..., j, k]
            L[..., i, j] = th.sqrt(A[..., i, i] - s) if (i == j) else \
                           (1.0 / L[..., j, j] * (A[..., i, j] - s))
    return L
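As a sketch of how the NaN handling then looks (continuing with the Xt and X from the question):

output = Xt.bmm(X)
L = cholesky(output)                         # never raises, unlike th.cholesky

# a batch entry failed if its factor contains NaN entries
bad = th.isnan(L).any(dim=-1).any(dim=-1)    # shape (batch,)
L_valid = L[~bad]                            # keep only the successful factorizations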
Upvotes: 1
Reputation: 5373
I don't know how this compares speed-wise to the other solutions posted, but it may be faster.
First use torch.det to determine if there are any singular matrices in your batch. Then mask out those matrices.
output = Xt.bmm(X)
dets = torch.det(output)
# if output is of shape (bs, x, y), dets will be of shape (bs)
bad_idxs = dets == 0  # might want an allclose/tolerance here
# replace the singular matrices with the identity so the batched cholesky
# succeeds (filling with all 1s would itself give a singular matrix)
output[bad_idxs] = torch.eye(output.shape[-1], dtype=output.dtype, device=output.device)
L = torch.cholesky(output)
Afterwards you probably need to deal with the entries corresponding to the singular matrices you replaced, but you have their index values so it's easy to grab them or exclude them.
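For example (a sketch continuing from the snippet above), the placeholder factorizations can be marked as NaN or simply dropped:

L[bad_idxs] = float('nan')   # mark the placeholder factorizations as invalid
L_valid = L[~bad_idxs]       # or keep only the genuinely factorized matrices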
Upvotes: 1
Reputation: 22244
When performing Cholesky decomposition, PyTorch relies on LAPACK for CPU tensors and MAGMA for CUDA tensors. In the PyTorch code that calls LAPACK, the batch is just iterated over, invoking LAPACK's potrf routine on each matrix separately. In the PyTorch code that calls MAGMA, the entire batch is processed with MAGMA's batched potrf routine, which is probably faster than iterating over each matrix separately.
AFAIK there's no way to tell MAGMA or LAPACK not to raise exceptions (though to be fair, I'm not an expert on these packages). Since MAGMA may be exploiting the batch structure in some way, we may not want to simply default to an iterative approach, since we would potentially lose performance by not performing the batched Cholesky.
One potential solution is to first try to perform the batched Cholesky decomposition; if it fails, we can then perform Cholesky decomposition on each element of the batch separately, setting the entries that fail to NaN.
import torch

def cholesky_no_except(x, upper=False, force_iterative=False):
    success = False
    if not force_iterative:
        try:
            results = torch.cholesky(x, upper=upper)
            success = True
        except RuntimeError:
            pass

    if not success:
        # fall back to operating on each element separately
        results_list = []
        x_batched = x.reshape(-1, x.shape[-2], x.shape[-1])
        for batch_idx in range(x_batched.shape[0]):
            try:
                result = torch.cholesky(x_batched[batch_idx, :, :], upper=upper)
            except RuntimeError:
                # may want to only catch certain RuntimeErrors; add a check here if so
                # on failure create a "nan" matrix
                result = float('nan') + torch.empty(x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype)
            results_list.append(result)
        results = torch.cat(results_list, dim=0).reshape(*x.shape)

    return results
If you expect exceptions to be common during Cholesky decomposition, you may want to use force_iterative=True to skip the initial batched call, since in that case that first attempt would likely just be wasted time.
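For illustration, a minimal usage sketch (the batch here is made up, with the last matrix singular on purpose):

A = torch.eye(3).repeat(4, 1, 1)   # batch of 4 identity (SPD) matrices
A[3, 2, 2] = 0.0                   # make the last one singular
L = cholesky_no_except(A)
# only the last factorization should come back as NaN
print(torch.isnan(L).any(dim=-1).any(dim=-1))  # expected: tensor([False, False, False,  True])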
Upvotes: 1