I have a DataFrame df, shown below. All columns are numeric, but the last two should be treated as categorical.
                Close  Direction  prediction
Date
2018-03-31  40.889999         -1           1
2017-12-31  34.459999          1           1
2017-09-30  40.529999         -1          -1
2017-06-30  38.200001          1          -1
2017-03-31  43.160000          1          -1
2016-12-31  46.369999          1          -1
2016-09-30  63.180000          1          -1
2016-06-30  64.300003          1           1
2016-03-31  66.500000          1           1
2015-12-31  85.250000         -1          -1
2015-09-30  63.020000          1           1
2015-06-30  87.139999         -1          -1
2015-03-31  83.169998         -1           1
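To experiment without the original CSV, the printed rows can be rebuilt as a DataFrame. This is a minimal sketch that assumes the rows above are the complete data. It also sorts the index into chronological order and casts the two signal columns to the pandas "category" dtype, which is one way to record that they are categorical.

```python
import pandas as pd

# Rows copied from the printout above (newest quarter first, as printed)
dates = pd.to_datetime([
    "2018-03-31", "2017-12-31", "2017-09-30", "2017-06-30", "2017-03-31",
    "2016-12-31", "2016-09-30", "2016-06-30", "2016-03-31", "2015-12-31",
    "2015-09-30", "2015-06-30", "2015-03-31",
])
df = pd.DataFrame(
    {
        "Close": [40.889999, 34.459999, 40.529999, 38.200001, 43.160000,
                  46.369999, 63.180000, 64.300003, 66.500000, 85.250000,
                  63.020000, 87.139999, 83.169998],
        "Direction": [-1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, -1],
        "prediction": [1, 1, -1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1],
    },
    index=pd.DatetimeIndex(dates, name="Date"),
)

# Chronological order, so ranges built from the first and last index entries run forward
df = df.sort_index()

# Mark the two -1/1 signal columns as categorical
df = df.astype({"Direction": "category", "prediction": "category"})
```

Sorting is optional for plotting itself, since matplotlib places points by date either way, but it keeps later date arithmetic on the index predictable.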
I want to plot these three columns as three plots stacked one above the other, sharing the x axis:
plot1 - line plot (x = the DataFrame's date index, y = df['Close'])
plot2 - scatter plot (x = the DataFrame's date index, y = df['Direction'])
plot3 - scatter plot (x = the DataFrame's date index, y = df['prediction'])
I have tried the code below, but it does not produce the desired output:
fig, (ax1, ax2,ax3) = plt.subplots(3, 1,figsize=(10,7), sharex=True)
ax1.plot(x= df.index, y=df['Close'])
ax2.scatter(x= df.index, y=df['Direction'].astype('category'),color='blue')
ax3.scatter(x= df.index, y=df['prediction'].astype('category'),color='red')
xtick_dates = pd.date_range(df.index[0], df.index[-1], freq='3M')
plt.xticks(dates_rng, [dtz.strftime('%Y-%m') for dtz in xtick_dates], rotation=90)
plt.show()
Can anyone find a way to solve this?
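A few things in the attempt above stand out. Axes.plot takes the x and y data positionally rather than as x= and y= keywords, and plt.xticks is given dates_rng, which is never defined. Separately, the tick range runs from df.index[0] to df.index[-1]; with the index listed newest quarter first, that start date is later than the end date, and pd.date_range then generates no dates at all. The small check below uses three dates from the table to compare the two orderings. It spells the step as pd.offsets.MonthEnd(3), which gives the same spacing as the '3M' alias that newer pandas versions deprecate.

```python
import pandas as pd

# Index order as printed in the question: newest quarter first
idx = pd.DatetimeIndex(["2018-03-31", "2017-12-31", "2015-03-31"], name="Date")

# Three-month steps anchored at month ends (same spacing as the '3M' alias)
step = pd.offsets.MonthEnd(3)

# start (idx[0]) is later than end (idx[-1]), so no dates are generated
backwards = pd.date_range(start=idx[0], end=idx[-1], freq=step)

# min()/max() give a forward range regardless of how the index is ordered
forwards = pd.date_range(start=idx.min(), end=idx.max(), freq=step)

print(len(backwards))  # 0
print(len(forwards))   # 13
print(forwards[0].strftime("%Y-%m"), forwards[-1].strftime("%Y-%m"))  # 2015-03 2018-03
```

Sorting the DataFrame with df.sort_index() before building the range would work equally well.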
I figured it out. Here is the working function:
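import pandas as pd
import matplotlib.pyplot as plt

def plot_predictions(details):
    global path  # directory of the CSV, defined elsewhere in the script
    df = pd.read_csv(path + "Quarterly_prediction.csv", parse_dates=['Date'], index_col=0)
    # Three stacked subplots sharing one x axis
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 7), sharex=True)
    fig.subplots_adjust(bottom=0.2)  # leave room for the rotated tick labels
    # Top: closing price as a line (x and y are passed positionally)
    ax1.plot(df.index, df['Close'])
    ax1.xaxis.grid(True, alpha=0.3)
    ax1.set_ylabel('Quarterly Closing')
    # Middle: actual direction (-1/1)
    ax2.scatter(x=df.index, y=df['Direction'], color='blue')
    ax2.xaxis.grid(True, alpha=0.3)
    ax2.set_ylabel('Actual Direction')
    # Bottom: predicted direction (-1/1)
    ax3.scatter(x=df.index, y=df['prediction'], color='red')
    ax3.xaxis.grid(True, alpha=0.3)
    ax3.set_ylabel('Predicted Direction')
    fig.suptitle(details, fontsize=10)
    # Quarterly ticks; min()/max() keep the range forward even though the index is newest first.
    # MonthEnd(3) matches the '3M' alias, which newer pandas deprecates in favor of '3ME'.
    xtick_dates = pd.date_range(start=df.index.min(), end=df.index.max(), freq=pd.offsets.MonthEnd(3))
    plt.xticks(xtick_dates, [dtz.strftime('%Y-%m') for dtz in xtick_dates], rotation=88)
    fig.text(0.5, 0.04, 'Quarter Closing Dates', ha='center')
    plt.show()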
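The function above plots the raw -1/1 values on a continuous y scale. Since the question says the last two columns should be categorical, one option is to keep the same stacked, shared-x layout and label only the two classes on the y axis. The following is a self-contained variant built on a few assumptions: plot_direction_panels is a hypothetical helper name, the "Down (-1)" / "Up (1)" labels are illustrative, the data comes from an in-memory DataFrame instead of the CSV, and the Agg backend is selected so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; remove this line to get an interactive window
import matplotlib.pyplot as plt
import pandas as pd


def plot_direction_panels(df):
    """Hypothetical helper: stacked Close / Direction / prediction panels.

    Assumes df has a DatetimeIndex and columns 'Close', 'Direction' and
    'prediction', the last two holding -1/1 values (int or category dtype).
    """
    df = df.sort_index()
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 7), sharex=True)
    ax1.plot(df.index, df["Close"])
    ax1.set_ylabel("Quarterly Closing")
    for ax, col, color, label in [
        (ax2, "Direction", "blue", "Actual Direction"),
        (ax3, "prediction", "red", "Predicted Direction"),
    ]:
        # Cast to int so a categorical column still plots as numbers
        ax.scatter(df.index, df[col].astype(int), color=color)
        # Name the two classes instead of showing a continuous -1..1 scale
        ax.set_yticks([-1, 1])
        ax.set_yticklabels(["Down (-1)", "Up (1)"])
        ax.set_ylim(-1.5, 1.5)
        ax.set_ylabel(label)
    for ax in (ax1, ax2, ax3):
        ax.xaxis.grid(True, alpha=0.3)
    # One tick per quarter end; the shared x axis propagates these to every panel
    ticks = pd.date_range(df.index.min(), df.index.max(), freq=pd.offsets.MonthEnd(3))
    ax3.set_xticks(ticks)
    ax3.set_xticklabels([t.strftime("%Y-%m") for t in ticks], rotation=90)
    fig.tight_layout()
    return fig


# Usage with four rows from the question
demo = pd.DataFrame(
    {"Close": [83.169998, 87.139999, 63.020000, 85.250000],
     "Direction": [-1, -1, 1, -1],
     "prediction": [1, -1, 1, -1]},
    index=pd.to_datetime(["2015-03-31", "2015-06-30", "2015-09-30", "2015-12-31"]),
).astype({"Direction": "category", "prediction": "category"})
fig = plot_direction_panels(demo)
```

Returning the figure instead of calling plt.show() inside the helper makes it easier to save with fig.savefig or to inspect in tests; call plt.show() afterwards when working interactively.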