Reputation: 570
I wrote a function that takes a pandas data frame and two of its columns. Inside the function, I want to group the elements of the first column by the elements of the second column. The goal of the function is to generate a bar chart using matplotlib that plots the grouped counts. I do not know how to refer to the column arguments so they can be recognized by the group-by call inside the function.
I tried passing df['col'] and 'col' as the arguments, but neither worked. When I pass df['col'], I get this error:
AttributeError: 'DataFrameGroupBy' object has no attribute 'x'
When I use 'col', I get this error:
AttributeError: 'DataFrameGroupBy' object has no attribute 'x'
Here is an example implementation, first without the function, to generate the expected result, and then with the function.
import pandas as pd
# generate dataframe
df = pd.DataFrame()
df['col_A'] = [1, 4, 3, 2, 2, 1, 1, 4, 3, 2]
df['col_B'] = ['a', 'a', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c']
# plot counts
import matplotlib.pyplot as plt
counts = df.groupby('col_B').col_A.count()
counts = counts.sort_values(ascending=False)
fig = plt.figure(figsize=(10,8))
counts.plot.barh(ylim=0).invert_yaxis()
# plot count with function
def count_barplot(data, x, y):
    counts = data.groupby(y).x.count()
    counts = counts.sort_values(ascending=False)
    fig = plt.figure(figsize=(10,8))
    counts.plot.barh(ylim=0).invert_yaxis()
# function call
count_barplot(df, df['col_A'], df['col_B'])
How do I specify the data frame column arguments inside the function and in the function call, so that the group-by function can recognize them?
Upvotes: 1
Views: 3034
Reputation: 4011
The problem is that your function call passes a dataframe and two Series as its arguments, whereas what you want to pass is a dataframe and two column names. Note that you also need the [] syntax to refer to the column in your groupby, because .x looks for a column literally named x instead of using the value of the argument x. You can also simplify the count by using the built-in value_counts() method, which gives the same numbers as long as the counted column has no missing values.
Thus, using your syntax:
# plot count with function
def count_barplot(data, x, y):
    counts = data.groupby(y)[x].count()
    counts = counts.sort_values(ascending=False)
    fig = plt.figure(figsize=(10,8))
    counts.plot.barh(ylim=0).invert_yaxis()
# function call
count_barplot(df, 'col_A', 'col_B')
or more simply:
# plot count with function
def count_barplot(data, y):
    # use the data argument, not the global df
    counts = data[y].value_counts()
    fig = plt.figure(figsize=(10,8))
    counts.plot.barh(ylim=0).invert_yaxis()
# function call
count_barplot(df, 'col_B')
or even
def count_barplot(data, x, y):
    # x is unused here; it is kept only so the original call signature still works
    fig = plt.figure(figsize=(10,8))
    data[y].value_counts(ascending=True).plot.barh(ylim=0)
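To see why .x fails while [x] works, here is a minimal sketch without any plotting. It reuses the data from the question; the names grouped, error_message, by_group and by_value are only illustrative and do not come from the original code.

```python
import pandas as pd

# same data as in the question
df = pd.DataFrame({
    "col_A": [1, 4, 3, 2, 2, 1, 1, 4, 3, 2],
    "col_B": ["a", "a", "b", "b", "b", "c", "c", "c", "c", "c"],
})

x, y = "col_A", "col_B"
grouped = df.groupby(y)

# Attribute access looks for a column literally named "x",
# no matter what the variable x contains.
try:
    grouped.x
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

# Bracket access uses the value stored in the variable x.
by_group = grouped[x].count()

# col_A has no missing values here, so counting the values of col_B
# gives the same numbers as the grouped count.
by_value = df[y].value_counts()

print(error_message)
print(by_group.to_dict())
print(by_value.to_dict())
```

If col_A contained missing values, the two results could differ, because count() skips missing entries in the counted column while value_counts() on col_B counts every row.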
Upvotes: 2
Reputation: 3910
This way it works for me:
def count_barplot(data, x, y):
    counts = data.groupby(y)[x].count()
    counts = counts.sort_values(ascending=False)
    fig = plt.figure(figsize=(10,8))
    counts.plot.barh(ylim=0).invert_yaxis()
# function call
count_barplot(df, 'col_A', 'col_B')
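If you want the function to complain early when it is handed df['col_A'] instead of 'col_A', one possible variant is sketched below. The Series check, the error text, the returned Axes and the Agg backend line are my own additions for illustration, not part of the question; drop the backend line if you want an interactive window.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd


def count_barplot(data, x, y):
    # The function expects column names, so reject Series arguments early.
    for arg in (x, y):
        if isinstance(arg, pd.Series):
            raise TypeError("pass column names such as 'col_A', not df['col_A']")
    counts = data.groupby(y)[x].count().sort_values(ascending=False)
    fig, ax = plt.subplots(figsize=(10, 8))
    counts.plot.barh(ax=ax)
    ax.invert_yaxis()  # largest count at the top
    return ax


df = pd.DataFrame({
    "col_A": [1, 4, 3, 2, 2, 1, 1, 4, 3, 2],
    "col_B": ["a", "a", "b", "b", "b", "c", "c", "c", "c", "c"],
})

# function call with column names
ax = count_barplot(df, "col_A", "col_B")
```

Returning the Axes lets the caller set a title or save the figure afterwards, which the chained .invert_yaxis() call cannot do because it returns None.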
Upvotes: 2