I am trying to plot multiple graphs with plt.subplots(). The code "works", but it always gives me an IndexError that I can't for the life of me figure out.
As a side question, does anyone know how to keep each plot the same size? For example, if I add more rows or columns, each plot gets smaller. Thanks.
count = 0
n_rows = 2
n_columns = 2
f, axarr = plt.subplots(n_rows, n_columns)
plt.figure(figsize=(20,20))
for column in range(n_cols):
    for row in range(n_rows):
        axarr[row, column].imshow(generate_pattern('block3_conv1', count, size=150))
        count += 1
Error
IndexError Traceback (most recent call last)
<ipython-input-37-7f7ae19e07e9> in <module>()
7 for column in range(n_cols):
8 for row in range(n_rows):
----> 9 axarr[row, column].imshow(generate_pattern('block3_conv1', count, size=150))
10
11 count += 1
IndexError: index 2 is out of bounds for axis 1 with size 2
Code for the functions used:
import numpy as np
from keras import backend as K

# `model` is a Keras model defined elsewhere in the notebook (not shown here)

def generate_pattern(layer_name, filter_index, size=150):
    # Build a loss function that maximizes the activation
    # of the nth filter of the layer considered.
    layer_output = model.get_layer(layer_name).output
    loss = K.mean(layer_output[:, :, :, filter_index])
    # Compute the gradient of the input picture wrt this loss
    grads = K.gradients(loss, model.input)[0]
    # Normalization trick: we normalize the gradient
    grads /= (K.sqrt(K.mean(K.square(grads))) + 1e-5)
    # This function returns the loss and grads given the input picture
    iterate = K.function([model.input], [loss, grads])
    # We start from a gray image with some noise
    input_img_data = np.random.random((1, size, size, 3)) * 20 + 128.
    # Run gradient ascent for 40 steps
    step = 1.
    for i in range(40):
        loss_value, grads_value = iterate([input_img_data])
        input_img_data += grads_value * step
    img = input_img_data[0]
    return deprocess_image(img)

def deprocess_image(x):
    # Center on 0 and scale to a standard deviation of 0.1
    x -= x.mean()
    x /= (x.std() + 1e-5)
    x *= 0.1
    # Shift to 0.5 and clip to [0, 1]
    x += 0.5
    x = np.clip(x,0,1)
    # Convert to an 8-bit RGB array
    x *= 255
    x = np.clip(x,0,255).astype('uint8')
    return x
This error is the result of trying to index the array of axes returned by plt.subplots() with an index that is out of range. One way to show this is to replace the loop variables with literal numbers. With a 2x2 grid, you will see that axarr[0, 2] produces the following error:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-39-31f90736bd1d> in <module>()
2 #plt.figure(figsize=(20,20))
3
----> 4 a[0,2]
IndexError: index 2 is out of bounds for axis 1 with size 2
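A minimal, self-contained reproduction of the same check, assuming only matplotlib and NumPy are installed (the Agg backend is chosen here only so the sketch runs without a display):

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt

# A 2x2 grid: axarr is a NumPy array of Axes objects with shape (2, 2)
f, axarr = plt.subplots(2, 2)
print(axarr.shape)  # (2, 2)

# Valid indices along each axis are 0 and 1, so column index 2 is out of range
message = None
try:
    axarr[0, 2]
except IndexError as err:
    message = str(err)
print(message)  # index 2 is out of bounds for axis 1 with size 2
```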
We know that the error did not occur inside the generate_pattern function, because the traceback would then have included a frame from that function.
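In the question's loop, the columns are iterated with range(n_cols), while the grid was built with n_columns. Since no NameError was raised, n_cols was presumably defined in an earlier notebook cell with a value of at least 3, which is how column reaches 2. The sketch below is one possible fix under that assumption: it loops over the same variable passed to plt.subplots(), and it also addresses the side question by passing figsize to plt.subplots() itself and scaling it with the grid. (The separate plt.figure(figsize=(20,20)) call opens a second, empty figure and does not resize f.) fake_pattern and cell_inches are hypothetical stand-ins so the example runs without Keras; in the real code, generate_pattern would be called instead.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt
import numpy as np


def fake_pattern(index, size=150):
    """Hypothetical stand-in for generate_pattern(): a random uint8 RGB image."""
    rng = np.random.default_rng(index)
    return rng.integers(0, 256, size=(size, size, 3), dtype=np.uint8)


n_rows = 2
n_columns = 2
cell_inches = 5  # assumed width/height of each subplot, in inches

# Scale the figure with the grid so each cell keeps roughly the same size
# when rows or columns are added.
f, axarr = plt.subplots(
    n_rows,
    n_columns,
    figsize=(cell_inches * n_columns, cell_inches * n_rows),
    squeeze=False,  # always return a 2-D array, even for a 1xN or Nx1 grid
)

count = 0
for column in range(n_columns):  # the same variable that built the grid
    for row in range(n_rows):
        axarr[row, column].imshow(fake_pattern(count))
        count += 1
```

If titles or tick labels start to overlap as the grid grows, calling f.tight_layout() after plotting is one option.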