While looking at https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/RNN.cpp#L744, I noticed that the function is declared with 4 template parameters, but only 2 template arguments are passed when it is called. Where do cell_params and io_type come from in this case?
template<template<typename,typename> class LayerT,
         template<typename,typename> class BidirLayerT,
         typename cell_params,
         typename io_type>
std::tuple<io_type, Tensor, Tensor> _lstm_impl(
    const io_type& input,
    const std::vector<cell_params>& params,
    const Tensor& hx,
    const Tensor& cx,
    int64_t num_layers,
    double dropout_p,
    bool train,
    bool bidirectional) {
  ...
}
auto results = _lstm_impl<FullLayer, FullBidirectionalLayer>(
    input, params, hx[0], hx[1], num_layers, dropout_p, train, bidirectional);
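The last two template parameters are deduced from the function arguments: io_type from input (declared as const io_type&), and cell_params from params (declared as const std::vector<cell_params>&). Explicitly specified template arguments are matched to template parameters from left to right, so FullLayer and FullBidirectionalLayer fill LayerT and BidirLayerT, and the compiler deduces the remaining two from the call.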
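A minimal sketch of the same mechanism, with hypothetical stand-in names (FullLayerSketch, BidirLayerSketch, CellParamsSketch, lstm_impl_sketch) instead of the real ATen types. It keeps the template-parameter shape of _lstm_impl: the two template template parameters do not appear in any function parameter type, so they must be given explicitly, while cell_params and io_type are deduced from params and input.

```cpp
#include <cassert>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>

// Hypothetical stand-ins for ATen's FullLayer / FullBidirectionalLayer.
template <typename Input, typename Params>
struct FullLayerSketch {
  static std::string name() { return "FullLayerSketch"; }
};

template <typename Input, typename Params>
struct BidirLayerSketch {
  static std::string name() { return "BidirLayerSketch"; }
};

// Hypothetical cell-parameter type.
struct CellParamsSketch {
  int hidden_size;
};

// Same template-parameter shape as _lstm_impl. LayerT and BidirLayerT do not
// appear in any function parameter type, so they can only be supplied
// explicitly; cell_params and io_type do appear, so they can be deduced.
template <template <typename, typename> class LayerT,
          template <typename, typename> class BidirLayerT,
          typename cell_params,
          typename io_type>
std::tuple<io_type, std::string> lstm_impl_sketch(
    const io_type& input,
    const std::vector<cell_params>& params,
    bool bidirectional) {
  using Layer = LayerT<io_type, cell_params>;
  using BidirLayer = BidirLayerT<io_type, cell_params>;
  (void)params;  // unused in this sketch
  return std::make_tuple(input, bidirectional ? BidirLayer::name() : Layer::name());
}

// Calls the sketch the same way RNN.cpp calls _lstm_impl: only the first two
// template arguments are written; the other two are deduced.
std::string demo_layer_name(bool bidirectional) {
  double input = 1.5;
  std::vector<CellParamsSketch> params{CellParamsSketch{16}};
  auto result = lstm_impl_sketch<FullLayerSketch, BidirLayerSketch>(input, params, bidirectional);
  // io_type was deduced as double (from input), so the return type follows.
  static_assert(std::is_same<decltype(result), std::tuple<double, std::string>>::value,
                "io_type deduced from input");
  assert(std::get<0>(result) == input);
  return std::get<1>(result);
}
```

Because explicit template arguments bind to the leading parameters, placing the non-deducible template template parameters first is what lets callers write just two arguments; if they came last, every template argument would have to be spelled out.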