saraherceg

Reputation: 341

Plot all predictions of a multi-label classification model

I would like to plot the inputs and outputs of the model I am trying to train:

Input data shape:

processed_data.shape
(100, 64, 256, 2)

What it looks like:

processed_data
array([[[[ 1.93965047e+04,  8.49532852e-01],
         [ 1.93965047e+04,  8.49463479e-01],

Output data shape:

output.shape
(100, 6)

The output is the predicted probability of each label:

output = model.predict(processed_data)

output
array([[0.53827614, 0.64929205, 0.48180097, 0.50065327, 0.43016508,
        0.50453395]

For each instance in processed_data, I would like to plot the predicted class probabilities (since this is a multi-label classification problem), but I am struggling to do so. I can plot the processed data, but I am not sure how to plot the probabilities for each input instance. I would also like each output plot to be labelled with all 6 possible classes. I am a bit lost... Any suggestions?

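So far I only plot the input:

import matplotlib.pyplot as plt
from random import randint

data = processed_data    # assumption: the array (or list of arrays) being plotted
it = 1                   # assumption: number of figures to draw (undefined in the original snippet)
shape = output.shape[0]  # number of instances (100)

for i in range(it):
    fig, axs = plt.subplots(5, 2, figsize=(10, 10))

    # data may be a list of batches or a single array
    if isinstance(data, list):
        inp = data[i]
        outp = output[i]
    else:
        inp = data
        outp = output

    for j in range(5):
        # random.randint includes both endpoints, so use shape - 1 to stay in range
        r = randint(0, shape - 1)
        # show the first channel of instance r in the left column
        axs[j, 0].imshow(inp[r, ..., 0])
        axs[j, 0].set_title('Input {}'.format(r))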

Upvotes: 0

Views: 1169

Answers (1)

l_l_l_l_l_l_l_l

Reputation: 538

I edited my response now that I understand the question better. This code plots each image together with a bar chart of its predicted output.

import matplotlib.image as mpimg
import numpy as np
import matplotlib.pyplot as plt

img_paths = ['../python/imgs/Image001.png',
             '../python/imgs/Image002.png',
             '../python/imgs/Image003.png',
             '../python/imgs/Image004.png',
             '../python/imgs/Image005.png']

# Stack the images into one array; random numbers stand in for model.predict output
images = np.array([mpimg.imread(path) for path in img_paths])
output = np.random.rand(5, 6)

print(images.shape, output.shape)

# Row 0 holds the images, row 1 the class probabilities;
# sharey='row' gives all bar charts the same y-axis
fig, axs = plt.subplots(2, 5, figsize=(8, 4), sharey='row')

for i in range(5):
    probs = output[i]

    axs[0, i].set_title(f'Sample {i + 1}')
    axs[0, i].imshow(images[i])
    axs[0, i].axis('off')

    # one bar per class, labelled 1..6 on the x-axis
    axs[1, i].bar(range(6), probs)
    axs[1, i].set_xticks(range(6))
    axs[1, i].set_xticklabels([f'{k + 1}' for k in range(6)])

plt.show()

Output:

(5, 1510, 2560, 4) (5, 6)

[output plot: five sample images, each above a bar chart of its 6 class probabilities]

The key part is the plt.subplots call, which lets you arrange the plots in any grid you like. If you want to plot all 100 images, you will probably prefer a vertical layout (for example, one row per sample).
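As a rough sketch of that vertical layout, the function below pages through every instance in figures of 5 rows, mirroring the 5x2 grid from the question: each row shows one channel of an input next to a bar chart of its 6 predicted probabilities, with class names on the x-axis. The function name plot_predictions, its channel and rows_per_fig parameters, and the class names are hypothetical; random arrays with the question's shapes stand in for processed_data and the model.predict output.

```python
import numpy as np
import matplotlib.pyplot as plt


def plot_predictions(images, probs, class_names, rows_per_fig=5, channel=0):
    """Plot each input beside a bar chart of its predicted class probabilities.

    images: array of shape (n, height, width, channels)
    probs:  array of shape (n, n_classes), e.g. the result of model.predict
    Returns a list of figures, one per page of `rows_per_fig` samples.
    """
    n, n_classes = probs.shape
    figures = []
    for start in range(0, n, rows_per_fig):
        stop = min(start + rows_per_fig, n)
        rows = stop - start
        # squeeze=False keeps axs 2-D even when a page has only one row
        fig, axs = plt.subplots(rows, 2, figsize=(10, 2 * rows), squeeze=False)
        for row, idx in enumerate(range(start, stop)):
            # Left column: one channel of the input, stretched to fill the axes
            axs[row, 0].imshow(images[idx, ..., channel], aspect='auto')
            axs[row, 0].set_title(f'Input {idx}')
            axs[row, 0].axis('off')

            # Right column: one bar per class; each probability lies in [0, 1],
            # and for multi-label output they need not sum to 1
            axs[row, 1].bar(range(n_classes), probs[idx])
            axs[row, 1].set_ylim(0, 1)
            axs[row, 1].set_xticks(range(n_classes))
            axs[row, 1].set_xticklabels(class_names, rotation=45, ha='right')
        fig.tight_layout()
        figures.append(fig)
    return figures


# Stand-in data with the shapes from the question (replace with your own arrays)
processed_data = np.random.rand(100, 64, 256, 2)
output = np.random.rand(100, 6)
class_names = [f'class {k}' for k in range(6)]  # hypothetical label names

figs = plot_predictions(processed_data, output, class_names)
```

With 100 instances and 5 rows per page this produces 20 figures. Call plt.show() to display them, or call fig.savefig(...) and then plt.close(fig) for each page so that not all of them stay open at once.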

Upvotes: 2
