James Oliver

Reputation: 597

Generate a multi scatter subplot - unexpected results returned

I am new to Python and am trying to create a grid of scatter subplots, one per month. The 3 x 3 grid is generated but stays empty, and each chart is then drawn separately, one underneath the other, which makes them hard to compare.

Here is my code, which I've lifted from a similar question.

def scat_months2(df,prod):
    """Print scatter for all months as sub plots of any given product"""
    uniq=sorted(set(train2.YM))[0:9]
    fig, axes = plt.subplots(3, 3, figsize=(6, 4), sharex=True, sharey=True)
    for period in uniq:
        df[(df["YM"]==period) & (df["item_id"]==prod)].plot(x='shop_id',
            y='item_price',
            kind='scatter',
            label=period,alpha=0.2)

    fig.tight_layout()

I have tried to generate some random data so that you can help me, but this hasn't worked out as I'd hoped either (again, Python newbie): it raises a different error. I hope you can still easily fix my example and then see the same result I see.

If you could also tell me how to generate my random data properly, that would really help my learning. I apologise for not being able to come up with a working example.

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

T=pd.Series([201301,201301,201301,201301,201301,201301,201301,201301,201301])
Shop=pd.Series([1,1,1,2,2,2,3,3,3])
Price=pd.Series(np.random.randint(10, size=(9)))
ds1=pd.DataFrame(dict(T = T, Shop=Shop,Price = Price))
T2=pd.Series([201302,201302,201302,201302,201302,201302,201302,201302,201302])
ds2=pd.DataFrame(dict(T = T2, Shop=Shop,Price = Price))
T3=pd.Series([201303,201303,201303,201303,201303,201303,201303,201303,201303])
ds3=pd.DataFrame(dict(T = T3, Shop=Shop,Price = Price))
ds=pd.concat([ds1,ds2,ds3], axis=0)
ds.index=range(27)

def scat_months2(df):
    """Print scatter for all months as sub plots of any given product"""
    uniq=sorted(set(df.T))
    fig, axes = plt.subplots(3, 1, figsize=(6, 4), sharex=True, sharey=True)
    for period in uniq:
        df[df["T"]==period].plot(x='Shop',
            y='Price',
            kind='scatter')

    fig.tight_layout()

Upvotes: 0

Views: 54

Answers (1)

Stef

Reputation: 30579

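You need to pass the target axes to the plot function through its ax parameter. Without it, each call draws on a new figure of its own, which is why the grid stays empty: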

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

T=pd.Series([201301,201301,201301,201301,201301,201301,201301,201301,201301])
Shop=pd.Series([1,1,1,2,2,2,3,3,3])
Price=pd.Series(np.random.randint(10, size=(9)))
ds1=pd.DataFrame(dict(T = T, Shop=Shop,Price = Price))
T2=pd.Series([201302,201302,201302,201302,201302,201302,201302,201302,201302])
ds2=pd.DataFrame(dict(T = T2, Shop=Shop,Price = Price))
T3=pd.Series([201303,201303,201303,201303,201303,201303,201303,201303,201303])
ds3=pd.DataFrame(dict(T = T3, Shop=Shop,Price = Price))
ds=pd.concat([ds1,ds2,ds3], axis=0)
ds.index=range(27)

def scat_months2(df):
    """Print scatter for all months as sub plots of any given product"""
    uniq=sorted(set(df['T']))
    fig, axes = plt.subplots(len(uniq), 1, figsize=(6, 4), sharex=True, sharey=True)
    for i, period in enumerate(uniq):
        df[df["T"]==period].plot(x='Shop',
            y='Price',
            kind='scatter',
            ax=axes[i])

    fig.tight_layout()
    plt.show()

scat_months2(ds)


(There was a small error in your example that I corrected to make it work: df.T returns the transposed dataframe. If your column is named T, you need to write df['T'] explicitly instead of df.T.)
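
To see the difference between the two forms, here is a minimal sketch on a small made-up frame (the values are placeholders, not the data from the question):

```python
import pandas as pd

# Hypothetical three-row frame with a column that happens to be named T.
df = pd.DataFrame({"T": [201301, 201302, 201303], "Shop": [1, 2, 3]})

# Attribute access resolves to the built-in transpose property, not the column.
transposed = df.T
print(transposed.shape)  # (2, 3): the two columns have become rows

# Bracket access always returns the column as a Series.
column = df["T"]
print(sorted(set(column.tolist())))  # [201301, 201302, 201303]
```

The same clash applies to any column whose name matches an existing DataFrame attribute or method, so bracket access is the safer habit.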


PS: In order to conveniently create sample data you can use numpy's repeat and tile functions:

months = [201301, 201302, 201303]
shops = [1,2,3]
n = 3
df = pd.DataFrame({
    'Month': np.repeat(months, n*len(shops)),
    'Shop': np.tile(shops, n*len(months)),
    'Price': np.random.randint(10, size=n*len(shops)*len(months)),
})
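
For the 3 x 3 layout in the question, axes is a 2-D array, so axes[i] would select a whole row rather than a single subplot. A possible sketch, assuming nine months of made-up data built the same way as above, iterates over axes.flat instead. The function name scat_months_grid, the titles and the figure size are my own choices, and the Agg backend line is only there so the snippet runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; drop this line for interactive use
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

# Hypothetical sample data: 9 months, 3 shops, 3 random prices per shop and month.
months = [201301 + i for i in range(9)]
shops = [1, 2, 3]
n = 3
df = pd.DataFrame({
    'Month': np.repeat(months, n * len(shops)),
    'Shop': np.tile(shops, n * len(months)),
    'Price': np.random.randint(10, size=n * len(shops) * len(months)),
})

def scat_months_grid(df, nrows=3, ncols=3):
    """Scatter Price against Shop, one subplot per month, on an nrows x ncols grid."""
    uniq = sorted(set(df['Month']))[:nrows * ncols]
    # squeeze=False keeps axes 2-D even for a single row or column.
    fig, axes = plt.subplots(nrows, ncols, figsize=(9, 6),
                             sharex=True, sharey=True, squeeze=False)
    # axes.flat walks the grid row by row, yielding one Axes per month.
    for ax, period in zip(axes.flat, uniq):
        df[df['Month'] == period].plot(x='Shop', y='Price',
                                       kind='scatter', ax=ax, alpha=0.5)
        ax.set_title(str(period))
    fig.tight_layout()
    return fig

fig = scat_months_grid(df)
```

Call plt.show() or fig.savefig(...) afterwards to view the result. If there are fewer months than grid cells, the remaining subplots simply stay empty.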

Upvotes: 1
