I have the following dataframe:
StockId Date Value
1 2015-01-02 -0.070012
2 2015-01-02 -0.022447
4 2015-01-02 -0.011474
6 2015-01-02 0.003796
13 2015-01-02 -0.032061
...
355 2018-09-14 -0.035717
356 2018-09-14 -0.007899
357 2018-09-14 0.065217
358 2018-09-14 0.063536
359 2018-09-14 -0.023433
I'm looking to find the correlation between stocks over time in order to find the five stocks that are most correlated with stock 1. Is there a quick way to do this using pandas? Or does this require creating arrays and then calculating the correlations one by one? There are 359 stocks in the data frame.
Assuming your dataframe is in a long format where each stock has exactly one value per day, you can use DataFrame.pivot to reshape it into a wide format: set Date as the index of the new dataframe, StockId as the columns, and Value as the values. If your data is sampled more often than daily, pivot will raise a ValueError on the repeated (Date, StockId) pairs. In that case use pivot_table instead and set its aggfunc argument to 'min', 'max', 'mean', or whatever aggregation suits your application. If your data is sampled less often than daily (or some stocks are missing on some dates), you can still run the code, but the wide dataframe will contain NaN values, and each correlation will be computed only over the dates on which both stocks have a value.
Note: I'm only saying daily because that's what your table seems to imply.
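If a stock can have several values on the same date, DataFrame.pivot raises a ValueError, and DataFrame.pivot_table is the variant that accepts aggfunc. The sketch below uses a small made-up intraday dataset (two observations per stock per day, column names taken from the question) to show both behaviours; the daily mean is just one possible choice of aggregation.

```python
import pandas as pd

# Hypothetical intraday data: two observations per stock per day.
long_df = pd.DataFrame({
    "StockId": [1, 1, 2, 2, 1, 1, 2, 2],
    "Date": ["2015-01-02"] * 4 + ["2015-01-05"] * 4,
    "Value": [0.1, 0.3, -0.2, 0.0, 0.5, 0.7, 0.4, 0.2],
})

# pivot() cannot handle repeated (Date, StockId) pairs and raises ValueError.
try:
    long_df.pivot(index="Date", columns="StockId", values="Value")
    pivot_failed = False
except ValueError as err:
    pivot_failed = True
    print("pivot failed:", err)

# pivot_table() collapses the repeats with aggfunc; here, the daily mean.
wide = long_df.pivot_table(index="Date", columns="StockId", values="Value", aggfunc="mean")
print(wide)
```

Swapping aggfunc="mean" for "min", "max" or "last" changes only how each day's repeated observations are reduced to a single number.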
From there you can use df.corr() to compute the correlation matrix.
```python
# Reshape long -> wide: one row per Date, one column per StockId.
# Passing values='Value' keeps the columns single-level, so no droplevel() is needed.
df = df.pivot(index='Date', columns='StockId', values='Value')
print(df)
# Example output from a small toy dataset with stocks a, b and c:
# StockId           a         b         c
# Date
# 1/10/2020  0.956625  0.175345  0.999375
# 1/11/2020  0.458859  0.714604  0.995440
# 1/12/2020  0.603331  0.881022  0.215262
# 1/13/2020  0.584198  0.303796  0.332117

# Pearson correlation between every pair of stock columns
matrix = df.corr()
print(matrix)
# StockId         a         b         c
# StockId
# a        1.000000 -0.680290  0.305365
# b       -0.680290  1.000000 -0.336229
# c        0.305365 -0.336229  1.000000
```
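When the wide dataframe has gaps, df.corr() ignores missing values pairwise: each pair of stocks is correlated only over the dates where both have a value. That can be misleading when two stocks share very few dates (with only two shared dates, any non-constant pair correlates at exactly +1 or -1). The min_periods argument of DataFrame.corr returns NaN for pairs with fewer overlapping observations than the threshold. The numbers below are made up purely to illustrate this, and the threshold of 3 is an arbitrary choice.

```python
import numpy as np
import pandas as pd

# Hypothetical wide dataframe with gaps (NaN = no value for that stock on that date).
wide = pd.DataFrame({
    1: [1.0, 2.0, 3.0, 4.0, 5.0],
    2: [2.0, np.nan, 6.0, 8.0, 10.0],
    3: [np.nan, np.nan, np.nan, 1.0, 2.0],
})

# Default: each pair uses only the rows where both columns are non-null.
default = wide.corr()
# Require at least 3 overlapping observations per pair, otherwise NaN.
strict = wide.corr(min_periods=3)

print(default.loc[1, 3], strict.loc[1, 3])
```

Stocks 1 and 3 overlap on only two dates, so the default result reports a perfect correlation, while min_periods=3 marks it as NaN. Stocks 1 and 2 share four dates and are kept either way.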
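From there, you can iterate over each stock's correlations, sort them by value, and get a dict ordered from the most positively to the most negatively correlated stock, excluding the stock itself. Note that this ranks by signed value; to rank by strength regardless of sign, sort on abs(item[1]) instead.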
```python
for stock, corr in matrix.to_dict().items():
    # Sort this stock's correlations from highest to lowest, dropping the self-correlation
    corr = {
        k: v for k, v
        in sorted(corr.items(), key=lambda item: -item[1])
        if k != stock
    }
    print(stock, corr)
# a {'c': 0.30536503121224934, 'b': -0.6802897760166712}
# b {'c': -0.3362290204607999, 'a': -0.6802897760166712}
# c {'a': 0.30536503121224934, 'b': -0.3362290204607999}
```
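For the question as asked (the five stocks most correlated with stock 1), the loop over every stock isn't strictly necessary: select stock 1's column of the correlation matrix, drop its self-correlation, and take Series.nlargest(5). The sketch below builds a small synthetic long-format dataframe with the question's column names; the stock IDs and values are invented for illustration. It also shows DataFrame.corrwith as an alternative that correlates every column with a single series without building the full 359 x 359 matrix.

```python
import pandas as pd

# Hypothetical data in the question's long format (StockId, Date, Value).
dates = pd.date_range("2015-01-02", periods=5, freq="D")
series = {
    1: [1.0, 2.0, 3.0, 4.0, 5.0],
    2: [2.0, 4.0, 6.0, 8.0, 10.0],  # perfectly correlated with stock 1
    3: [1.0, 2.0, 3.0, 5.0, 4.0],
    4: [3.0, 2.0, 1.0, 4.0, 5.0],
    5: [5.0, 4.0, 3.0, 2.0, 1.0],   # perfectly anti-correlated with stock 1
    6: [3.0, 1.0, 4.0, 1.0, 5.0],
    7: [1.0, 1.0, 1.0, 1.0, 2.0],
}
long_df = pd.DataFrame(
    [(sid, d, v) for sid, vals in series.items() for d, v in zip(dates, vals)],
    columns=["StockId", "Date", "Value"],
)

wide = long_df.pivot(index="Date", columns="StockId", values="Value")
matrix = wide.corr()

# Option 1: take stock 1's column from the full matrix and drop the self-correlation.
top5 = matrix[1].drop(1).nlargest(5)

# Option 2: correlate every other stock with stock 1 directly (no full matrix).
top5_alt = wide.drop(columns=1).corrwith(wide[1]).nlargest(5)

print(top5)
```

Both options return a Series indexed by StockId and sorted from most to least positively correlated. With real data, replace long_df with your own dataframe; to rank by absolute strength instead, apply .abs() before nlargest.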
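Or, for a visual comparison, plot the matrix as a heat map:

```python
import matplotlib.pyplot as plt

plt.matshow(matrix)  # one coloured cell per pair of stocks
plt.colorbar()
plt.show()
```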