Reputation: 678
I have time-series data collected on a weekly basis, and I want to see the correlation between two of its columns. I was able to compute the correlation between the two columns, and now I want to see how the rolling correlation moves each year. My current approach works fine, but I need to normalize the two columns before computing the rolling correlation and making a line plot. In my current attempt, I don't know how to show the 3-year and 5-year rolling correlations. Can anyone suggest a way of doing this in matplotlib?
current attempt:
Here is my current attempt:
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns

    dataPath = "https://gist.github.com/jerry-shad/503a7f6915b8e66fe4a0afbc52be7bfa#file-sample_data-csv"

    def ts_corr_plot(dataPath, roll_window=4):
        df = pd.read_csv(dataPath)
        df['Date'] = pd.to_datetime(df['Date'])
        df['year'] = df['Date'].dt.year
        # Week of the year as an integer (Monday as the first day of the week)
        df['week'] = df['Date'].dt.strftime('%W').astype('uint8')

        def find_corr(x):
            # x is one rolling window; its index selects the matching rows of df
            dfc = df.loc[x.index]
            # Correlation between the second and third columns of the data
            return dfc.iloc[:, 1].corr(dfc.iloc[:, 2])

        df['corr'] = df['week'].rolling(roll_window).apply(find_corr)
        fig, ax = plt.subplots(figsize=(7, 4), dpi=144)
        sns.lineplot(x='week', y='corr', hue='year', data=df, alpha=.8)
        plt.show()
        plt.close()
update:
I want to see the rolling correlation over different time windows, such as:
plt_1 = ts_corr_plot(dataPath, roll_window=4)
plt_2 = ts_corr_plot(dataPath, roll_window=12)
plt_3 = ts_corr_plot(dataPath, roll_window=24)
I need to add 3-year and 5-year rolling correlations to the plots, but I couldn't find a good way of doing this. Can anyone point out how to make a rolling-correlation line plot for time-series data? How can I improve my current attempt?
desired plot
This is the plot I expect to obtain: [image]
Upvotes: 1
Views: 823
Reputation: 35135
Customizing the legend in seaborn is painstaking, so I wrote the code in matplotlib.
    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt

    dataPath = "https://gist.githubusercontent.com/jerry-shad/503a7f6915b8e66fe4a0afbc52be7bfa/raw/414a2fc2988fcf0b8e6911d77cccfbeb4b9e9664/sample_data.csv"
    df = pd.read_csv(dataPath)
    df['Date'] = pd.to_datetime(df['Date'])
    df['year'] = df['Date'].dt.year
    # Week of the year as an integer (Monday as the first day of the week)
    df['week'] = df['Date'].dt.strftime('%W').astype('uint8')

    def find_corr(x):
        # x is one rolling window; its index selects the matching rows of df
        dfc = df.loc[x.index]
        # Correlation between the second and third columns of the data
        return dfc.iloc[:, [1, 2]].corr().iloc[0, 1]

    roll_window = 4
    df['corr'] = df['week'].rolling(roll_window).apply(find_corr)

    # 3 years = 52 weeks x 3 = 156 rows
    df3 = df.copy()
    df3['corr3'] = df3['year'].rolling(156).apply(find_corr)

    fig, ax = plt.subplots(figsize=(12, 4), dpi=144)
    cmap = plt.get_cmap("tab10")

    # One line per year for the short-window correlation
    for i, y in enumerate(df['year'].unique()):
        tmp = df[df['year'] == y]
        ax.plot(tmp['week'], tmp['corr'], color=cmap(i), label=y)

    # Dashed line for each year whose 3-year window is fully populated
    for i, y in enumerate(df['year'].unique()):
        tmp = df3[df3['year'] == y]
        if tmp['corr3'].notnull().all():
            ax.plot(tmp['week'], tmp['corr3'], color=cmap(i), lw=3, linestyle='--', label=str(y) + ' 3 year avg')

    ax.grid(axis='both')
    ax.legend(loc='upper left', bbox_to_anchor=(1.0, 1.0), borderaxespad=1)
    plt.show()
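For the 5-year window and the normalization point in the question, here is a minimal sketch of an alternative that uses pandas' built-in `Series.rolling(window).corr(other)` instead of a custom `apply`. It runs on synthetic weekly data, not the CSV from the question. The column names `col_a` and `col_b`, the helper names `rolling_corr_by_window` and `zscore`, and the 52-weeks-per-year approximation are all my own assumptions for illustration.

```python
import numpy as np
import pandas as pd


def rolling_corr_by_window(df, col_a, col_b, windows):
    """Return one rolling-correlation column per window size (in rows, i.e. weeks)."""
    out = pd.DataFrame(index=df.index)
    for label, n_weeks in windows.items():
        # Series.rolling(...).corr(other) computes the Pearson correlation
        # over each trailing window without a Python-level apply.
        out[label] = df[col_a].rolling(n_weeks).corr(df[col_b])
    return out


def zscore(s):
    """Standardize a series to mean 0 and standard deviation 1."""
    return (s - s.mean()) / s.std()


# Synthetic weekly data standing in for the real CSV (hypothetical columns).
rng = np.random.default_rng(0)
n = 6 * 52  # about six years of weekly observations
dates = pd.date_range("2015-01-04", periods=n, freq="W")
a = rng.normal(size=n).cumsum()
b = 0.5 * a + rng.normal(size=n)
df = pd.DataFrame({"col_a": a, "col_b": b}, index=dates)

# Window lengths in weeks; 52 weeks per year is an approximation.
windows = {"4w": 4, "12w": 12, "24w": 24, "3y": 52 * 3, "5y": 52 * 5}
corr = rolling_corr_by_window(df, "col_a", "col_b", windows)

# Same computation on z-scored columns, to check the effect of normalizing.
df_norm = df.apply(zscore)
corr_norm = rolling_corr_by_window(df_norm, "col_a", "col_b", windows)
```

Pearson correlation is unchanged by shifting or positively rescaling either column, so `corr` and `corr_norm` should agree up to floating-point error. Normalizing first should therefore mainly matter if you also want to draw the two raw series on a shared axis. To plot the result, `corr.plot(ax=ax)` on a matplotlib axis should draw one line per window.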
Upvotes: 1