Reputation: 27
I am trying to learn backpropagation in PyTorch, and I came across this code:
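import torch

class Square(torch.autograd.Function):  # hypothetical name; forward assumed from the 2*i gradient
    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)  # saved values come back as a tuple in ctx.saved_tensors
        return i ** 2

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        i, = ctx.saved_tensors  # unpack the single element of the 1-tuple
        # chain rule: incoming gradient times d(i**2)/di = 2*i
        return grad_output * 2 * i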
I cannot understand what "i," means here.
Upvotes: 2
Views: 215
Reputation: 451
Python can unpack values. For example, if we have a tuple
t = (2,3,'wat')
we could assign its values to variables like this
coolnumber, othernumber, word = t
which would set word to 'wat'.
If our tuple has length one, we need to distinguish between assigning the whole tuple to a variable
a = (1,)
# a is now (1,)
and unpacking that one value
a, = (1,)
# a is now 1
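A few equivalent ways to write the one-element unpacking, as a small sketch (the variable names are illustrative only). The target can also be written with parentheses or square brackets, and the right-hand side can be any iterable with exactly one item, not only a tuple:

```python
# Assigning the whole tuple vs. unpacking its single element
whole = (1,)   # whole is the tuple (1,)
a, = (1,)      # a is the int 1
(b,) = (1,)    # same unpacking, with optional parentheses around the target
[c] = (1,)     # same unpacking, written as a list-style target
d, = [1]       # any iterable of length one works, e.g. a list

print(whole, a, b, c, d)  # (1,) 1 1 1 1
```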
Upvotes: 5
Reputation: 198324
i, is a 1-tuple whose sole element is i, just like i, j is a 2-tuple (you can't just write i, because that is not a tuple, but simply i). On the right side of an assignment, this syntax constructs a tuple; on the left side, the assignment deconstructs a tuple (or any other iterable) into its components. See this example:
# construction
scalar = 5               # => 5
two_tuple = 6, 7         # => (6, 7)
one_tuple = 8,           # => (8,)
# deconstruction
five = scalar            # five is 5
six, seven = two_tuple   # six is 6, seven is 7
eight, = one_tuple       # eight is 8
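Deconstruction only succeeds when the number of targets matches the number of elements. A short sketch, using a hypothetical helper unpack_one, of what happens when a 1-tuple target meets a sequence of a different length:

```python
def unpack_one(seq):
    """Hypothetical helper: unpack a sequence expected to hold exactly one element."""
    value, = seq  # raises ValueError unless seq has exactly one item
    return value

print(unpack_one((8,)))  # 8

for bad in [(6, 7), ()]:
    try:
        unpack_one(bad)
    except ValueError:
        # too many values for (6, 7), not enough for ()
        print("ValueError for", bad)
```

So writing i, = ... also acts as a quiet check that exactly one value was there.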
So, if ctx.saved_tensors is a 1-tuple (in PyTorch it is the tuple of tensors that were passed to ctx.save_for_backward in the forward pass), your code assigns its single element to i.
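To see why the backward code writes i, rather than plain i, here is a minimal pure-Python sketch. The FakeCtx class is a hypothetical stand-in, not PyTorch's real context object; it only mimics the fact that values passed to save_for_backward come back from saved_tensors as a tuple, even when only one was saved. Plain floats stand in for tensors.

```python
class FakeCtx:
    """Hypothetical stand-in for the ctx object PyTorch passes to forward/backward."""

    def save_for_backward(self, *tensors):
        # mimic PyTorch: saved values are exposed as a tuple
        self.saved_tensors = tuple(tensors)


def square_forward(ctx, i):
    ctx.save_for_backward(i)  # one value saved, so saved_tensors is a 1-tuple
    return i * i


def square_backward(ctx, grad_output):
    i, = ctx.saved_tensors       # unpack the single element of the 1-tuple
    return grad_output * 2 * i   # chain rule for d(i*i)/di


ctx = FakeCtx()
out = square_forward(ctx, 3.0)
print(ctx.saved_tensors)          # (3.0,)
print(square_backward(ctx, 1.0))  # 6.0
```

Without the trailing comma, i would be the tuple (3.0,) itself, and grad_output * 2 * i would raise a TypeError, because a tuple cannot be multiplied by a float.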
(Note that you will often read that tuple syntax includes parentheses, (6, 7) instead of 6, 7. This is only half-correct: the comma, not the parentheses, is what makes a tuple (the one exception is the empty tuple, written ()). The parentheses are still necessary in many contexts, because the comma has very low precedence, or because you need to separate the commas inside a tuple from those outside of it, as in a function argument list or inside another sequence. That said, the extra parentheses do not hurt.)
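A small sketch of where the parentheses matter, using a hypothetical helper count_args that just reports how many positional arguments it received:

```python
def count_args(*args):
    """Hypothetical helper: return how many positional arguments were passed."""
    return len(args)

t = 6, 7                   # no parentheses needed: t is the tuple (6, 7)
print(t)                   # (6, 7)

print(count_args(6, 7))    # 2: the commas separate two arguments
print(count_args((6, 7)))  # 1: the parentheses group a single tuple argument

print(len([6, 7]))         # 2: a list of two ints
print(len([(6, 7)]))       # 1: a list holding one tuple

empty = ()                 # the empty tuple is the case that needs parentheses
print(type(empty).__name__, len(empty))  # tuple 0
```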
Upvotes: 1