A.E

Reputation: 1013

How to efficiently loop through date columns in pandas

I have a large dataset whose column labels are dates (strings such as 01/01/20). To illustrate my problem, I am building a smaller, similar dataset below:

import pandas as pd

Cities = ['San Francisco', 'Los Angeles', 'New York', 'Huston', 'Chicago']
Jan = [10, 20, 15, 10, 35]
Feb = [12, 23, 17, 15, 41]
Mar = [15, 29, 21, 21, 53]
Apr = [27, 48, 56, 49, 73]

data = pd.DataFrame({'City': Cities, '01/01/20': Jan, '02/01/20': Feb, '03/01/20': Mar, '04/01/20': Apr})

print(data)

            City  01/01/20  02/01/20  03/01/20  04/01/20
0  San Francisco        10        12        15        27
1    Los Angeles        20        23        29        48
2       New York        15        17        21        56
3         Huston        10        15        21        49
4        Chicago        35        41        53        73

I want to plot the data for each city as a function of time. Here is my attempt:

import matplotlib.pyplot as plt 

cols = data.columns 

dates = data.loc[:, cols[1:]].columns

San_Francisco = []
Los_Angeles = []
New_York = []
Huston = []
Chicago = []

for i in dates:
    San_Francisco.append(data[data['City'] == 'San Francisco'][i].sum())
    Los_Angeles.append(data[data['City'] == 'Los Angeles'][i].sum())
    New_York.append(data[data['City'] == 'New York'][i].sum())
    Huston.append(data[data['City'] == 'Huston'][i].sum())
    Chicago.append(data[data['City'] == 'Chicago'][i].sum())
    
plt.plot(dates, San_Francisco, label='San Francisco')
plt.plot(dates, Los_Angeles, label='Los Angeles')
plt.plot(dates, New_York, label='New York')
plt.plot(dates, Huston, label='Huston')
plt.plot(dates, Chicago, label='Chicago')
plt.legend()

The result is what I want; however, my approach is not efficient for a large dataset. How can I speed it up? Also, in the plotting section I have many rows of cities, and hardcoding every name by hand is tedious. Is there a better way?

Thanks

Upvotes: 2

Views: 243

Answers (1)

jezrael

Reputation: 862671

If some values of City can be duplicated, first aggregate them with GroupBy.sum, then transpose with DataFrame.T, and finally plot with DataFrame.plot. This replaces the Python loop over cities and dates with a single vectorized aggregation:

data.groupby('City').sum().T.plot()
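The column labels are strings such as '01/01/20', so pandas treats them as text and plots them in the order they appear. A minimal sketch, assuming the labels are month/day/two-digit-year (the format string is my assumption, not stated in the question), that converts them to real timestamps after the group-and-transpose step, so the x-axis is a true time axis:

```python
import pandas as pd

# Small copy of the question's dataset
data = pd.DataFrame({
    'City': ['San Francisco', 'Los Angeles', 'New York', 'Huston', 'Chicago'],
    '01/01/20': [10, 20, 15, 10, 35],
    '02/01/20': [12, 23, 17, 15, 41],
    '03/01/20': [15, 29, 21, 21, 53],
    '04/01/20': [27, 48, 56, 49, 73],
})

# One row per city (duplicates summed), then transpose:
# rows become dates, columns become cities (sorted by name by groupby)
df = data.groupby('City').sum().T

# Assumed label format: month/day/two-digit year
df.index = pd.to_datetime(df.index, format='%m/%d/%y')
df = df.sort_index()

print(df)
```

Calling `df.plot()` on this frame then draws one line per city against actual dates.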


If the City column always has unique values, DataFrame.set_index is enough (no aggregation is needed):

data.set_index("City").T.plot()
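Which of the two forms applies can also be decided at run time. A small sketch (the function name `cities_by_date` and the sample frame with a repeated city are hypothetical) that uses `Series.is_unique` to skip the aggregation when no city repeats:

```python
import pandas as pd


def cities_by_date(frame: pd.DataFrame, key: str = 'City') -> pd.DataFrame:
    """Return a frame with dates as rows and one column per city."""
    if frame[key].is_unique:
        # No duplicates: just move the city names into the index
        wide = frame.set_index(key)
    else:
        # Duplicates present: sum them into a single row per city
        wide = frame.groupby(key).sum()
    return wide.T


# Made-up frame in which Chicago appears twice
data = pd.DataFrame({
    'City': ['Chicago', 'Chicago', 'New York'],
    '01/01/20': [1, 2, 3],
    '02/01/20': [4, 5, 6],
})
print(cities_by_date(data))
```

The result can be passed straight to `.plot()`; the uniqueness check costs one pass over the City column.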

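EDIT: with many cities a single plot gets crowded, so the city columns can be split into groups of N and each group plotted separately: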

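import numpy as np

df = data.groupby('City').sum().T

N = 10  # max cities per figure
# groupby(axis=1) is deprecated in recent pandas, so group the rows of df.T instead
for _, group in df.T.groupby(np.arange(len(df.columns)) // N):
    group.T.plot()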
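Another way to split the columns into blocks of N cities, without groupby, is positional slicing with `iloc`. A sketch under stated assumptions: `chunk_columns` and `plot_chunks` are hypothetical helper names, the 23-city frame is made-up data, and matplotlib is imported inside the plotting function so the chunking can be used on its own:

```python
import pandas as pd


def chunk_columns(df: pd.DataFrame, n: int) -> list[pd.DataFrame]:
    """Split df into consecutive blocks of at most n columns."""
    return [df.iloc[:, i:i + n] for i in range(0, df.shape[1], n)]


def plot_chunks(df: pd.DataFrame, n: int = 10) -> None:
    """Draw one figure per block of n cities."""
    import matplotlib.pyplot as plt  # requires matplotlib to be installed

    for block in chunk_columns(df, n):
        block.plot()  # DataFrame.plot without ax= opens a new figure each time
    plt.show()


# 23 made-up cities (columns) over 4 dates (rows)
df = pd.DataFrame(
    [[c * 10 + d for c in range(23)] for d in range(4)],
    columns=[f'city_{c}' for c in range(23)],
)
blocks = chunk_columns(df, 10)
print([b.shape[1] for b in blocks])  # [10, 10, 3]
```

Slicing keeps the cities in their existing column order, so the same N cities always share a figure.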

Upvotes: 5
