I'm trying to build a program that deals with baseball statistics. I ask the user to input a team, and then the code runs through a pandas DataFrame I have created, searching for the "teamID" that matches the user input.
I've tried grouping by "teamID" and indexing before the for loop.
import matplotlib.pyplot as plt

def AttendancePlot(teams, team_pick):
    fig, ax = plt.subplots()
    group_by_teamID = teams.groupby(by=['teamID'])
    print(group_by_teamID)
    for i in group_by_teamID.index:
        if i == team_pick:
            ax.scatter(teams['yearID'][i], teams['attendance'][i], color="#4DDB94", s=200)
            ax.annotate(i, (teams['yearID'][i], teams['attendance'][i]),
                        bbox=dict(boxstyle="round", color="#4DDB94"),
                        xytext=(-30, 30), textcoords='offset points',
                        arrowprops=dict(arrowstyle="->", connectionstyle="angle,angleA=0,angleB=90,rad=10"))
How I'm creating the DataFrame:
import pandas as pd

teams = pd.read_csv('Teams.csv')
salaries = pd.read_csv('Salaries.csv')
names = pd.read_csv('Names.csv')
teams = teams[teams['yearID'] >= 1985]
teams = teams[['yearID', 'teamID', 'Rank', 'R', 'RA', 'G', 'W', 'H', 'BB', 'HBP', 'AB', 'SF', 'HR', '2B', '3B', 'attendance']]
teams = teams.set_index(['yearID', 'teamID'])
salaries_by_yearID_teamID = salaries.groupby(['yearID', 'teamID'])['salary'].sum()
teams = teams.join(salaries_by_yearID_teamID)
print(teams.head(15))
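If you want to experiment without the CSV files, here is a minimal sketch that rebuilds the same (yearID, teamID) structure on a tiny frame. The numbers are made up for illustration, and the year filter and column selection are skipped. Only set_index, the groupby sum, and the join mirror the code above.

```python
import pandas as pd

# Made-up numbers for illustration only; the real data comes from Teams.csv and Salaries.csv
teams = pd.DataFrame({
    "yearID": [1985, 1985, 1986, 1986],
    "teamID": ["ATL", "BOS", "ATL", "BOS"],
    "attendance": [1000000.0, 2000000.0, 1500000.0, 2500000.0],
})
salaries = pd.DataFrame({
    "yearID": [1985, 1985, 1985, 1986, 1986],
    "teamID": ["ATL", "BOS", "BOS", "ATL", "BOS"],
    "salary": [900000, 500000, 400000, 1100000, 1000000],
})

# Two-level index: one row per (year, team)
teams = teams.set_index(["yearID", "teamID"])
# Total payroll per (year, team): a Series named 'salary' with the same two index levels
salary_totals = salaries.groupby(["yearID", "teamID"])["salary"].sum()
# join aligns on the shared (yearID, teamID) index and adds a 'salary' column
teams = teams.join(salary_totals)
print(teams)
```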
Output:
               Rank    R   RA    G  ...   2B  3B  attendance      salary
yearID teamID                       ...
1985   ATL        5  632  781  162  ...  213  28   1350137.0  14807000.0
       BAL        4  818  764  161  ...  234  22   2132387.0  11560712.0
       BOS        5  800  720  163  ...  292  31   1786633.0  10897560.0
       CAL        2  732  703  162  ...  215  31   2567427.0  14427894.0
I would like a scatter plot showing the yearly attendance of the team the user enters. Instead, I get a blank graph and no errors.
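When a plot comes up empty, one thing worth ruling out is a team ID that matches no rows, for example because the user typed "bos" instead of "BOS". Below is a minimal sketch, assuming the DataFrame is indexed by (yearID, teamID) as shown above. pick_team is a hypothetical helper, and the attendance figures are made up.

```python
import pandas as pd

def pick_team(teams, raw):
    """Hypothetical helper: normalize the typed team ID and make sure it exists."""
    team = raw.strip().upper()  # the IDs in the DataFrame (ATL, BAL, ...) are upper case
    known = set(teams.index.get_level_values("teamID"))
    if team not in known:
        raise ValueError(f"unknown team {team!r}; expected one of {sorted(known)}")
    return team

# Toy frame with the question's (yearID, teamID) index; attendance numbers are made up
teams = pd.DataFrame(
    {"attendance": [1000000.0, 2000000.0]},
    index=pd.MultiIndex.from_tuples([(1985, "ATL"), (1985, "BOS")], names=["yearID", "teamID"]),
)
print(pick_team(teams, " bos "))  # BOS
```

Assuming the team is read with input(), the program would call it as pick_team(teams, input('Team: ')) before plotting.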
There's no need to use groupby() here. groupby() is typically used when you want to compute something (a sum, a mean, etc.) over groups of rows. What you need is a proper selection of the data.
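If you do want to use groupby() for this, one approach is to group on the teamID index level and fetch a single group with get_group(). Iterating over a GroupBy yields (key, sub-DataFrame) pairs rather than bare team labels. The sketch below uses made-up toy numbers.

```python
import pandas as pd

# Toy frame with the same (yearID, teamID) MultiIndex as in the question; numbers are made up
teams = pd.DataFrame(
    {"attendance": [1000000.0, 2000000.0, 1500000.0, 2500000.0]},
    index=pd.MultiIndex.from_tuples(
        [(1985, "ATL"), (1985, "BOS"), (1986, "ATL"), (1986, "BOS")],
        names=["yearID", "teamID"],
    ),
)

# Group on the 'teamID' index level, then take the single group for the chosen team
bos = teams.groupby(level="teamID").get_group("BOS")
print(bos)

# Iterating a GroupBy yields (key, sub-DataFrame) pairs
for team_id, group in teams.groupby(level="teamID"):
    print(team_id, len(group))
```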
This function will plot year (x axis) vs. attendance (y axis) for the given team team_pick, assuming the DataFrame structure you described (the DataFrame is teams):
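import matplotlib.pyplot as plt

def AttendancePlot(teams, team_pick):
    # Keep only the rows whose 'teamID' index level equals the chosen team
    teamdata = teams.loc[teams.index.get_level_values('teamID') == team_pick]
    # One year per selected row (index.levels[0] would list every year in the full index)
    plt.scatter(teamdata.index.get_level_values('yearID'), teamdata['attendance'])
    plt.show()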
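A note on pulling the years out of teamdata: after a boolean selection, index.levels still holds every value from the original index, while get_level_values returns exactly one value per selected row. If a team is missing some years, x taken from levels and y would differ in length, and scatter would raise an error. The small sketch below uses made-up data to show the difference.

```python
import pandas as pd

# Made-up frame: BOS appears in only one of the two years
teams = pd.DataFrame(
    {"attendance": [1000000.0, 1500000.0, 2000000.0]},
    index=pd.MultiIndex.from_tuples(
        [(1985, "ATL"), (1986, "ATL"), (1986, "BOS")],
        names=["yearID", "teamID"],
    ),
)
teamdata = teams.loc[teams.index.get_level_values("teamID") == "BOS"]

# levels are not pruned by the selection: both years from the full index remain
print(list(teamdata.index.levels[0]))
# get_level_values gives exactly one year per selected row
print(list(teamdata.index.get_level_values("yearID")))
```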
I'll leave the annotation to you.
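For the annotation, here is a sketch that reuses the bbox and arrow styling from the question and labels each point with its year. attendance_plot is a hypothetical name, and the data is made up. The Agg backend is there only so the sketch runs without a display; drop that line and call plt.show() to see the figure.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import pandas as pd

def attendance_plot(teams, team_pick):
    """Scatter yearly attendance for one team and label each point with its year."""
    teamdata = teams.loc[teams.index.get_level_values("teamID") == team_pick]
    years = teamdata.index.get_level_values("yearID")
    attendance = teamdata["attendance"]
    fig, ax = plt.subplots()
    ax.scatter(years, attendance, color="#4DDB94", s=200)
    for year, value in zip(years, attendance):
        # Same annotation style as the question, applied to every point
        ax.annotate(str(year), (year, value),
                    bbox=dict(boxstyle="round", color="#4DDB94"),
                    xytext=(-30, 30), textcoords="offset points",
                    arrowprops=dict(arrowstyle="->",
                                    connectionstyle="angle,angleA=0,angleB=90,rad=10"))
    ax.set_xlabel("yearID")
    ax.set_ylabel("attendance")
    ax.set_title(team_pick)
    return fig, ax

# Made-up toy data with the question's (yearID, teamID) index
teams = pd.DataFrame(
    {"attendance": [1000000.0, 2000000.0, 1500000.0, 2500000.0]},
    index=pd.MultiIndex.from_tuples(
        [(1985, "ATL"), (1985, "BOS"), (1986, "ATL"), (1986, "BOS")],
        names=["yearID", "teamID"],
    ),
)
fig, ax = attendance_plot(teams, "BOS")
```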
The key is this line: teamdata = teams.loc[teams.index.get_level_values('teamID') == team_pick]
teams.index.get_level_values('teamID') == team_pick builds a boolean mask over the 'teamID' level of the MultiIndex, with True for every row whose team is team_pick. Passing that mask to teams.loc selects exactly those rows.
teamdata is therefore a DataFrame containing all the rows for the given team.
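To see what the mask looks like, here is a minimal sketch on a made-up four-row frame with the same index layout:

```python
import pandas as pd

# Made-up frame with the question's (yearID, teamID) index
teams = pd.DataFrame(
    {"attendance": [1000000.0, 2000000.0, 1500000.0, 2500000.0]},
    index=pd.MultiIndex.from_tuples(
        [(1985, "ATL"), (1985, "BOS"), (1986, "ATL"), (1986, "BOS")],
        names=["yearID", "teamID"],
    ),
)

# One True/False per row: True where the 'teamID' level equals "BOS"
mask = teams.index.get_level_values("teamID") == "BOS"
print(mask)

# Boolean indexing keeps only the True rows
teamdata = teams.loc[mask]
print(teamdata)
```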
This is boolean indexing in pandas. See also the pandas documentation on advanced indexing with a MultiIndex.
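Two other selections from the same part of the pandas docs return the same rows. In the sketch below (made-up data), xs() drops the teamID level, so the result is indexed by yearID alone. pd.IndexSlice keeps both levels, and the index is sorted first because the pandas docs recommend a sorted index for MultiIndex slicing.

```python
import pandas as pd

# Made-up frame with the question's (yearID, teamID) index
teams = pd.DataFrame(
    {"attendance": [1000000.0, 2000000.0, 1500000.0, 2500000.0]},
    index=pd.MultiIndex.from_tuples(
        [(1985, "ATL"), (1985, "BOS"), (1986, "ATL"), (1986, "BOS")],
        names=["yearID", "teamID"],
    ),
)

# Cross-section on the 'teamID' level; that level is dropped from the result
bos_xs = teams.xs("BOS", level="teamID")

# Slice every year, but only the "BOS" team; both index levels are kept
bos_slice = teams.sort_index().loc[pd.IndexSlice[:, "BOS"], :]
print(bos_xs)
print(bos_slice)
```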