Toon Tran

Where is torch.cholesky implemented, and how does torch locate its methods?

I'm researching Cholesky decomposition, which requires some insight into how torch.cholesky works. After a while of grepping and searching through ATen, I got stuck at TensorMethods.h, which contains the following code:

inline Tensor Tensor::cholesky(bool upper) const {
#ifdef USE_STATIC_DISPATCH
    return TypeDefault::cholesky(const_cast<Tensor&>(*this), upper);
#else
    static c10::OperatorHandle op = c10::Dispatcher::singleton().findSchema({"aten::cholesky", ""}).value();
    return c10::Dispatcher::singleton().callUnboxed<Tensor, const Tensor &, bool>(
        op, impl::dispatchTypeId(at::detail::multi_dispatch_tensor_type_set(*this)), const_cast<Tensor&>(*this), upper);
#endif
}

This raises the question: how does torch locate the implementations of its methods? Thank you!

Answers (1)

jodag

Take a look at aten/src/ATen/native/README.md, which describes how functions are registered with the API:

ATen "native" functions are the modern mechanism for adding operators and functions to ATen (they are "native" in contrast to legacy functions, which are bound via TH/THC cwrap metadata). Native functions are declared in native_functions.yaml and have implementations defined in one of the cpp files in this directory.
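To connect this to the code in the question: the #else branch of Tensor::cholesky looks up the operator "aten::cholesky" by name in a global registry (c10::Dispatcher::singleton()), caches the handle in a function-local static, and then calls whichever kernel is registered for the tensor's dispatch key (CPU, CUDA, ...). The USE_STATIC_DISPATCH branch skips that table and calls TypeDefault::cholesky directly. Below is a heavily simplified toy model of the dynamic lookup, not PyTorch's real API: ToyDispatcher, ToyKey, ToyTensor, Kernel and toyCholesky are all hypothetical names, and the kernels return strings only so you can see which one ran.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical dispatch keys; the real dispatcher distinguishes many more.
enum class ToyKey { CPU, CUDA };

// Hypothetical toy tensor that only records which backend its data lives on.
struct ToyTensor {
  ToyKey key = ToyKey::CPU;
};

// A kernel receives the tensor plus the schema's extra argument (bool upper).
using Kernel = std::function<std::string(const ToyTensor&, bool)>;

// Hypothetical global table: operator name -> (dispatch key -> kernel).
class ToyDispatcher {
 public:
  static ToyDispatcher& singleton() {
    static ToyDispatcher instance;  // one process-wide registry
    return instance;
  }
  void registerKernel(const std::string& op, ToyKey key, Kernel kernel) {
    assert(kernel && "kernel must be callable");
    table_[op][key] = std::move(kernel);
  }
  // Rough analogue of findSchema(...).value(): fails loudly for an unknown op.
  const std::map<ToyKey, Kernel>& findOp(const std::string& op) const {
    auto it = table_.find(op);
    if (it == table_.end()) throw std::out_of_range("unknown operator: " + op);
    return it->second;
  }

 private:
  std::map<std::string, std::map<ToyKey, Kernel>> table_;
};

// Same shape as Tensor::cholesky in TensorMethods.h: resolve the operator once
// (cached in a function-local static), then pick the kernel for this tensor's key.
std::string toyCholesky(const ToyTensor& self, bool upper) {
  static const std::map<ToyKey, Kernel>& kernels =
      ToyDispatcher::singleton().findOp("aten::cholesky");
  return kernels.at(self.key)(self, upper);
}
```

In this toy, kernels have to be registered before the first call, because the static caches the lookup. In PyTorch, as far as I know, that registration is done by static registration objects at library load time, so it has already happened by the time your code runs.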

If we open aten/src/ATen/native/native_functions.yaml and search for cholesky, we find:

- func: cholesky(Tensor self, bool upper=False) -> Tensor
  use_c10_dispatcher: full
  variants: method, function
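The variants: method, function line is what makes both t.cholesky() and at::cholesky(t) (torch.cholesky(t) in Python) exist, and upper=False becomes a C++ default argument. Here is a toy sketch of those two entry points, with hypothetical names (toy::Tensor, toy::native_cholesky). One simplification: here the method just forwards to the free function, whereas as far as I can tell the generated code routes each of them through the dispatcher shown in the question.

```cpp
#include <cassert>
#include <string>

namespace toy {

struct Tensor;

// Hypothetical stand-in for the hand-written kernel in aten/src/ATen/native
// (the real one takes and returns at::Tensor).
std::string native_cholesky(const Tensor& self, bool upper);

struct Tensor {
  std::string name;
  // "method" variant: enables t.cholesky(); the default mirrors upper=False.
  std::string cholesky(bool upper = false) const;
};

// "function" variant: enables toy::cholesky(t).
inline std::string cholesky(const Tensor& self, bool upper = false) {
  return native_cholesky(self, upper);
}

// Both variants funnel into the same entry point.
inline std::string Tensor::cholesky(bool upper) const {
  return toy::cholesky(*this, upper);
}

std::string native_cholesky(const Tensor& self, bool upper) {
  return "cholesky(" + self.name + (upper ? ", upper)" : ", lower)");
}

// Quick self-check: the two spellings produce the same result.
inline void check_variants_agree() {
  Tensor t{"A"};
  assert(t.cholesky() == toy::cholesky(t));
}

}  // namespace toy
```

With t declared as toy::Tensor t{"A"}, both t.cholesky() and toy::cholesky(t) end up in the same kernel and return "cholesky(A, lower)".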

To find the entry point, search the .cpp files in the aten/src/ATen/native directory for the function named cholesky. At the time of writing it is at BatchLinearAlgebra.cpp:550:

Tensor cholesky(const Tensor &self, bool upper) {
  if (self.size(-1) == 0) {
    return at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  }
  squareCheckInputs(self);

  auto raw_cholesky_output = at::_cholesky_helper(self, upper);
  if (upper) {
    return raw_cholesky_output.triu_();
  } else {
    return raw_cholesky_output.tril_();
  }
}
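The actual factorization happens in at::_cholesky_helper, which, as far as I can tell, hands the work to a backend routine (LAPACK's potrf on CPU, MAGMA on CUDA). Those routines only write the requested triangle and leave the other one holding the input values, which is presumably why the wrapper finishes with triu_() or tril_(). For a self-contained picture of what is being computed, here is a plain Cholesky-Banachiewicz sketch for a single dense matrix. The names Matrix, choleskyLower and cholesky are hypothetical. It does no batching, the empty-input early return mirrors the size(-1) == 0 check, and the assert stands in for squareCheckInputs.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical dense row-major matrix type for this sketch.
using Matrix = std::vector<std::vector<double>>;

// Cholesky-Banachiewicz: for symmetric positive-definite A, build lower-triangular
// L with A = L * L^T, row by row. Entries above the diagonal of L stay 0.
Matrix choleskyLower(const Matrix& a) {
  const std::size_t n = a.size();
  Matrix l(n, std::vector<double>(n, 0.0));
  for (std::size_t i = 0; i < n; ++i) {
    assert(a[i].size() == n && "input must be square");  // stand-in for squareCheckInputs
    for (std::size_t j = 0; j <= i; ++j) {
      double sum = a[i][j];
      for (std::size_t k = 0; k < j; ++k) sum -= l[i][k] * l[j][k];
      if (i == j) {
        if (sum <= 0.0) throw std::domain_error("matrix is not positive-definite");
        l[i][i] = std::sqrt(sum);
      } else {
        l[i][j] = sum / l[j][j];
      }
    }
  }
  return l;
}

// Mirrors the shape of native cholesky(self, upper): empty input short-circuits,
// and upper=true returns U = L^T, so that A = U^T * U.
Matrix cholesky(const Matrix& a, bool upper) {
  if (a.empty()) return {};
  Matrix l = choleskyLower(a);
  if (!upper) return l;
  const std::size_t n = l.size();
  Matrix u(n, std::vector<double>(n, 0.0));
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = i; j < n; ++j) u[i][j] = l[j][i];
  return u;
}
```

For example, A = {{4, 12, -16}, {12, 37, -43}, {-16, -43, 98}} factors into L = {{2, 0, 0}, {6, 1, 0}, {-8, 5, 3}}, and the upper variant returns the transpose of L.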

From this point it's just a matter of following the C++ code to understand what's going on.
