In Michael Nielsen's tutorial on neural networks he has the following code:
def update_mini_batch(self, mini_batch, eta):
    """The ``mini_batch`` is a list of tuples ``(x, y)``, and ``eta``
    is the learning rate."""
    nabla_b = [np.zeros(b.shape) for b in self.biases]
    nabla_w = [np.zeros(w.shape) for w in self.weights]
    for x, y in mini_batch:
        delta_nabla_b, delta_nabla_w = self.backprop(x, y)
        nabla_b = [nb+dnb for nb, dnb in zip(nabla_b, delta_nabla_b)]
        nabla_w = [nw+dnw for nw, dnw in zip(nabla_w, delta_nabla_w)]
    self.weights = [w-(eta/len(mini_batch))*nw
                    for w, nw in zip(self.weights, nabla_w)]
    self.biases = [b-(eta/len(mini_batch))*nb
                   for b, nb in zip(self.biases, nabla_b)]
I understand what tuples and lists are, and I understand what the zip function is doing, but I don't understand how the variables nb, dnb, nw, and dnw are updated in these 2 lines of code:
nabla_b = [nb+dnb for nb, dnb in zip(nabla_b, delta_nabla_b)]
nabla_w = [nw+dnw for nw, dnw in zip(nabla_w, delta_nabla_w)]
Can anyone help explain the magic going on in these 2 lines?
These 2 lines are typical examples of Python list comprehensions.
In essence, for your first list:
nabla_b = [nb+dnb for nb, dnb in zip(nabla_b, delta_nabla_b)]
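this means:
- take each pair of elements produced by zip(nabla_b, delta_nabla_b); name them nb and dnb
- add them together (nb+dnb)
- make that sum the corresponding element of a new list, which is then assigned to nabla_b
- move on to the next pair and repeat, until all pairs from zip(nabla_b, delta_nabla_b) have been exhausted

So nb and dnb are not "updated" in any special way: on every pass they are simply bound to the next pair, and the old nabla_b is replaced only once the whole new list has been built.

As a simple example, the list comprehension below: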
squares = [x**2 for x in range(10)]
print(squares)
# [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
is equivalent to the following for loop:
squares = []
for x in range(10):
    squares.append(x**2)
print(squares)
# [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
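Applied to the line from the question, the same expansion looks like the sketch below. Plain integers stand in for the per-layer NumPy arrays of the original code (an assumption made only to keep the output readable), and `comprehension_result` / `loop_result` are illustrative names, not part of Nielsen's code.

```python
delta_nabla_b = [1, 2, 3]  # one example's contribution (ints stand in for arrays)

# 1) Comprehension form, exactly as in the question
nabla_b = [10, 20, 30]
nabla_b = [nb + dnb for nb, dnb in zip(nabla_b, delta_nabla_b)]
comprehension_result = nabla_b

# 2) Equivalent explicit loop
nabla_b = [10, 20, 30]
new_list = []
for nb, dnb in zip(nabla_b, delta_nabla_b):
    # nb and dnb are rebound to the next pair on every pass;
    # nothing is modified in place
    new_list.append(nb + dnb)
nabla_b = new_list  # only now is the name nabla_b rebound
loop_result = nabla_b

print(comprehension_result)  # [11, 22, 33]
print(loop_result)           # [11, 22, 33]
```

Using nabla_b on both sides of the assignment is safe because the right-hand side is fully evaluated before the name is rebound.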
See here for more examples and a quick introduction.
The zip function sticks the two lists together element by element, so that if you gave it:
a = [1, 2, 3, 4]
b = ["a", "b", "c", "d"]
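list(zip(a, b))
would return:
[(1, "a"), (2, "b"), ...]
(each element being a tuple). In Python 3, zip itself returns a lazy iterator that yields these tuples one at a time; wrapping it in list() collects them into a list.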
You can unpack elements of lists that are tuples (or lists) by putting a comma between the loop variables, one variable per element of the tuple:
for elem_a, elem_b in zip(a, b):
    print(elem_a, elem_b)
This would print:
1 a
2 b
3 c
4 d
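Each pass of that for loop is an ordinary tuple-unpacking assignment. A small sketch that writes out by hand what `for elem_a, elem_b in zip(a, b)` does on each iteration (the names `pairs`, `unpacked` and `it` are only for illustration):

```python
a = [1, 2, 3, 4]
b = ["a", "b", "c", "d"]

pairs = list(zip(a, b))  # zip() is lazy in Python 3; list() materializes it
print(pairs)  # [(1, 'a'), (2, 'b'), (3, 'c'), (4, 'd')]

unpacked = []
it = iter(zip(a, b))
while True:
    try:
        pair = next(it)        # fetch the next tuple from zip
    except StopIteration:      # zip is exhausted, so the loop ends
        break
    elem_a, elem_b = pair      # the unpacking the for loop performs
    unpacked.append((elem_a, elem_b))

print(unpacked == pairs)  # True
```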
So in your case, it's adding the two lists nabla_b and delta_nabla_b element-wise, so you get one new list with each element being the sum of the corresponding elements in the zipped lists; that new list then replaces the old nabla_b.
It might look a bit strange because the for loop is written on a single line, but this construct is called a "list comprehension". Simple list comprehensions read almost like English.
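In Nielsen's code the elements being added are not plain numbers but NumPy arrays (nabla_b starts as a list of np.zeros arrays, one per layer), so nb+dnb is itself an element-wise array addition. A minimal sketch, assuming a toy network with layer sizes `[2, 3, 1]`, bias column vectors of shape `(y, 1)`, and a hypothetical `fake_backprop` that returns made-up gradients of the right shapes (it is not Nielsen's `backprop`):

```python
import numpy as np

sizes = [2, 3, 1]  # assumed toy layer sizes
# Assumed shapes: one bias column vector per non-input layer
biases = [np.random.randn(y, 1) for y in sizes[1:]]

def fake_backprop(x, y):
    """Hypothetical stand-in for backprop: every bias gradient is all ones."""
    return [np.ones(b.shape) for b in biases]

mini_batch = [(None, None)] * 4  # four dummy (x, y) examples

nabla_b = [np.zeros(b.shape) for b in biases]
for x, y in mini_batch:
    delta_nabla_b = fake_backprop(x, y)
    # Pair layer i of the running total with layer i of this example's
    # gradient and add the two arrays element-wise
    nabla_b = [nb + dnb for nb, dnb in zip(nabla_b, delta_nabla_b)]

for layer in nabla_b:
    print(layer.shape, layer.ravel())
# (3, 1) [4. 4. 4.]
# (1, 1) [4.]
```

After the loop each array holds the sum of the per-example gradients for its layer, which is what the later `(eta/len(mini_batch))*nb` step scales.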