This is what my df looks like:
hr slope value
8 s_1 6
10 s_1 2
8 s_2 4
10 s_2 8
I would like to make a 3D bar plot with 'hr' on the x-axis, 'value' on the y-axis, and 'slope' on the z-axis.
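import matplotlib.pyplot as plt
import numpy as np
# df is the DataFrame shown above
xpos = df['hr']
ypos = df['value']
xpos, ypos = np.meshgrid(xpos+0.25, ypos+0.25)
xpos = xpos.flatten()
ypos = ypos.flatten()
zpos = np.zeros(df.shape).flatten()
dx = 0.5 * np.ones_like(zpos)
dy = 0.5 * np.ones_like(zpos)
dz = df.values.ravel()
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.bar3d(xpos, ypos, zpos, dx, dy, dz, color='b', alpha=0.5)
plt.show()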
I get the following error message:
ValueError: shape mismatch: objects cannot be broadcast to a single shape
Any help is very welcome. Thank you in advance!
The documentation of bar3d() can be found at https://matplotlib.org/mpl_toolkits/mplot3d/api.html#mpl_toolkits.mplot3d.axes3d.Axes3D.bar3d, and the official demo at https://matplotlib.org/3.1.1/gallery/mplot3d/3d_bars.html. In short, xpos, ypos and zpos give the anchor point of each bar, and dx, dy and dz give its width, depth and height. Here is a minimal example:
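import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401, registers the '3d' projection on matplotlib < 3.2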
xpos = [1, 2, 3] # x coordinates of each bar
ypos = [0, 0, 0] # y coordinates of each bar
zpos = [0, 0, 0] # z coordinates of each bar
dx = [0.5, 0.5, 0.5] # Width of each bar
dy = [0.5, 0.5, 0.5] # Depth of each bar
dz = [5, 4, 7] # Height of each bar
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.bar3d(xpos,ypos,zpos,dx,dy,dz, color='b', alpha=0.5)
plt.show()
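A note beyond the linked docs: as far as I can tell from the matplotlib source, bar3d broadcasts its six arguments against each other (recent versions call np.broadcast_arrays on them), so an argument that is the same for every bar can be passed as a plain scalar. A sketch of the same three bars written that way:

```python
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401, only needed on matplotlib < 3.2

xpos = [1, 2, 3]  # one x coordinate per bar
dz = [5, 4, 7]    # one height per bar

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# ypos=0, zpos=0, dx=0.5 and dy=0.5 are scalars that get broadcast to all three bars
bars = ax.bar3d(xpos, 0, 0, 0.5, 0.5, dz, color='b', alpha=0.5)
plt.show()
```

The flip side of this broadcasting is that arrays of different, incompatible lengths are rejected, which is exactly the error in the question.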
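You get this error because xpos, ypos, zpos, dx, dy and dz do not all have the same length. np.meshgrid on the two 4-element columns gives 16 elements each for xpos and ypos, while zpos, dx and dy are built from df.shape, which is (4, 3), so they have 12 elements. In addition, dz contains strings, because df.values.ravel() also picks up the 'slope' column.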
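To see both problems without plotting, here is a sketch that builds the question's df in memory (instead of reading 1.csv) and repeats its array steps. The final np.broadcast_arrays call stands in for what bar3d does with its arguments internally; that is my reading of the matplotlib source, not documented behaviour.

```python
import numpy as np
import pandas as pd

# The question's data, built in memory instead of read from 1.csv
df = pd.DataFrame({
    "hr": [8, 10, 8, 10],
    "slope": ["s_1", "s_1", "s_2", "s_2"],
    "value": [6, 2, 4, 8],
})

# Same steps as the question's code
xpos, ypos = np.meshgrid(df["hr"] + 0.25, df["value"] + 0.25)  # two 4x4 grids
xpos, ypos = xpos.flatten(), ypos.flatten()  # 4 * 4 = 16 elements each
zpos = np.zeros(df.shape).flatten()          # df.shape is (4, 3): 12 elements
dx = 0.5 * np.ones_like(zpos)                # 12 elements
dy = 0.5 * np.ones_like(zpos)                # 12 elements
dz = df.values.ravel()                       # 12 elements, including 's_1' and 's_2'

print([len(a) for a in (xpos, ypos, zpos, dx, dy, dz)])  # [16, 16, 12, 12, 12, 12]
print(dz.dtype)  # object: the mixed column types force a single object array

# Arrays of length 16 and 12 cannot be broadcast together
error_message = ""
try:
    np.broadcast_arrays(xpos, ypos, zpos, dx, dy, dz)
except ValueError as exc:
    error_message = str(exc)
print(error_message)  # starts with "shape mismatch: objects cannot be broadcast to a single shape"
```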
Here is how I reproduced your example:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
df = pd.read_csv('1.csv')
xpos = df['hr']
ypos = df['value']
xpos, ypos = np.meshgrid(xpos+0.25, ypos+0.25)
xpos = xpos.flatten()
ypos = ypos.flatten()
zpos = np.zeros(df.shape).flatten()
dx = 0.5 * np.ones_like(zpos)
dy = 0.5 * np.ones_like(zpos)
dz = df.values.ravel()
print(xpos)
print(ypos)
print(zpos)
print(dx)
print(dy)
print(dz) # [8 's_1' 6 10 's_1' 2 8 's_2' 4 10 's_2' 8]
print(len(xpos)) # 16
print(len(ypos)) # 16
print(len(zpos)) # 12
print(len(dx)) # 12
print(len(dy)) # 12
print(len(dz)) # 12
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.bar3d(xpos, ypos, zpos, dx, dy, dz, color='b', alpha=0.5)  # raises ValueError: shape mismatch
plt.show()
The content of 1.csv is:
hr,slope,value
8,s_1,6
10,s_1,2
8,s_2,4
10,s_2,8
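To actually draw the plot, here is a minimal sketch of one possible fix. It rests on an assumption about the intent: one bar per row, placed at (hr, slope) with height value. bar3d's z direction is the bar height and 'slope' is categorical, so the sketch maps each slope to an integer position with pd.factorize and puts the slope names on the y-axis as tick labels; swap the columns if a different layout was meant.

```python
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401, only needed on matplotlib < 3.2
import numpy as np
import pandas as pd

# Same data as 1.csv
df = pd.DataFrame({
    "hr": [8, 10, 8, 10],
    "slope": ["s_1", "s_1", "s_2", "s_2"],
    "value": [6, 2, 4, 8],
})

# 'slope' is categorical, so give each category an integer position
slope_codes, slope_labels = pd.factorize(df["slope"])

# One entry per row, so every array has the same length (4)
xpos = df["hr"].to_numpy(dtype=float) - 0.25  # center each 0.5-wide bar on its hr
ypos = slope_codes.astype(float) - 0.25        # center each bar on its slope position
zpos = np.zeros(len(df))                       # bars start at z = 0
dx = dy = 0.5                                  # scalars, broadcast by bar3d
dz = df["value"].to_numpy(dtype=float)         # numeric heights only, no strings

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
bars = ax.bar3d(xpos, ypos, zpos, dx, dy, dz, color="b", alpha=0.5)

ax.set_xlabel("hr")
ax.set_yticks(range(len(slope_labels)))
ax.set_yticklabels(slope_labels)
ax.set_ylabel("slope")
ax.set_zlabel("value")
plt.show()
```

Because each array now holds one element per row, there is no need for meshgrid or flattening, and dz only ever sees the numeric 'value' column.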