imbr

Reputation: 7642

Pytorch torch.cholesky ignoring exception

For some matrices in my batch I'm getting an exception due to the matrix being singular.

L = th.cholesky(Xt.bmm(X))

cholesky_cpu: For batch 51100: U(22,22) is zero, singular U

Since these are few for my use case, I would like to ignore the exception and deal with them afterwards. I would set the resulting decomposition to NaN for those entries; is that possible somehow?

Even if I catch the exception and use continue, the calculation for the rest of the batch still doesn't finish.
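
Roughly what I mean (batches here is just a stand-in for however the data is chunked):

for Xt, X in batches:  # stand-in for whatever chunking is used
    try:
        L = th.cholesky(Xt.bmm(X))
    except RuntimeError:
        # the exception is raised for the whole th.cholesky call, so the
        # non-singular matrices in this chunk are skipped along with the bad one
        continue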

The same happens in C++ with PyTorch's libtorch.

Upvotes: 1

Views: 2105

Answers (3)

imbr

Reputation: 7642

According to the PyTorch Discuss forum, it's not possible to catch the exception.

The solution, unfortunately, was to implement my own simple batched Cholesky (the equivalent of th.cholesky(..., upper=False)) and then deal with the NaN values using th.isnan.

import torch as th

# nograd cholesky (Cholesky-Banachiewicz, lower triangular)
# singular matrices produce NaN/Inf entries instead of raising
def cholesky(A):
    L = th.zeros_like(A)

    for i in range(A.shape[-1]):
        for j in range(i + 1):
            # sum over the columns of L that are already computed
            s = 0.0
            for k in range(j):
                s = s + L[..., i, k] * L[..., j, k]

            L[..., i, j] = th.sqrt(A[..., i, i] - s) if (i == j) else \
                           (A[..., i, j] - s) / L[..., j, j]
    return L
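
A rough usage sketch (toy shapes, just for illustration): failed entries come out as NaN and can be masked with th.isnan.

# toy batch of three 2x2 matrices; the middle one is singular
A = th.stack([th.eye(2), th.zeros(2, 2), 2 * th.eye(2)])

L = cholesky(A)                             # no exception; failures become NaN
bad = th.isnan(L).flatten(1).any(dim=1)     # per-matrix failure mask
print(bad)                                  # tensor([False,  True, False])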

Upvotes: 1

Karl

Reputation: 5373

I don't know how this compares speed-wise to the other solutions posted, but it may be faster.

First use torch.det to determine if there are any singular matrices in your batch. Then mask out those matrices.

import torch

output = Xt.bmm(X)
dets = torch.det(output)

# if output is of shape (bs, x, y), dets will be of shape (bs,)
bad_idxs = dets == 0  # might want a tolerance here, e.g. torch.abs(dets) < eps

# replace the singular matrices with the identity so the batched call succeeds
# (filling with all 1s would not work, since an all-ones matrix is itself singular)
output[bad_idxs] = torch.eye(output.shape[-1], dtype=output.dtype, device=output.device)

L = torch.cholesky(output)

Afterwards you'll probably need to deal with the matrices you filled in with placeholders, but you have their index values so it's easy to grab them or exclude them.
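
A rough sketch of that last step (reusing bad_idxs and dets from above):

# mark the placeholder factors as invalid so they can't be used by accident
L[bad_idxs] = float('nan')

# or keep only the factors of the non-singular matrices, remembering where they came from
good_L = L[~bad_idxs]
good_idx = torch.arange(len(dets))[~bad_idxs]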

Upvotes: 1

jodag

Reputation: 22244

When performing Cholesky decomposition PyTorch relies on LAPACK for CPU tensors and MAGMA for CUDA tensors. In the PyTorch code used to call LAPACK, the batch is just iterated over, invoking LAPACK's potrf routine on each matrix separately. In the PyTorch code used to call MAGMA, the entire batch is processed with MAGMA's magma_dpotrf_batched, which is probably faster than iterating over each matrix separately.

AFAIK there's no way to tell MAGMA or LAPACK not to raise exceptions (though to be fair, I'm not an expert on these packages). Since MAGMA may be exploiting the batch structure in some way, we don't want to simply fall back to an iterative approach, because we would potentially lose performance by not performing the batched Cholesky.

One potential solution is to first try the batched Cholesky decomposition; if it fails, we perform the decomposition on each element of the batch separately, setting the entries that fail to NaN.

import torch

def cholesky_no_except(x, upper=False, force_iterative=False):
    success = False
    if not force_iterative:
        try:
            results = torch.cholesky(x, upper=upper)
            success = True
        except RuntimeError:
            pass

    if not success:
        # fall back to operating on each element separately
        results_list = []
        x_batched = x.reshape(-1, x.shape[-2], x.shape[-1])
        for batch_idx in range(x_batched.shape[0]):
            try:
                result = torch.cholesky(x_batched[batch_idx, :, :], upper=upper)
            except RuntimeError:
                # may want to only accept certain RuntimeErrors; add a check here if that's the case
                # on failure create a "nan" matrix
                result = float('nan') + torch.empty(x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype)
            results_list.append(result)
        results = torch.cat(results_list, dim=0).reshape(*x.shape)

    return results

If you expect exceptions to be common during Cholesky decomposition, you may want to use force_iterative=True to skip the initial batched attempt, since in that case the function would likely just be wasting time on it.
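
For example, a quick sanity check with a toy batch (shapes made up for illustration):

# toy batch: the second matrix is singular, so the batched attempt fails
x = torch.stack([2 * torch.eye(3), torch.zeros(3, 3)])

L = cholesky_no_except(x)                      # falls back to the per-element loop
failed = torch.isnan(L).flatten(1).any(dim=1)  # per-matrix failure mask
print(failed)                                  # tensor([False,  True])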

Upvotes: 1
