I am trying to understand why the following two plots look so different:
plt.subplot(projection='3d')
plt.scatter(position1[:,0], position1[:,1], position1[:,2], marker='.')
plt.scatter(position2[:,0], position2[:,1], position2[:,2], marker='.')
plt.show()
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(position1[:,0], position1[:,1], position1[:,2], marker='.')
ax.scatter(position2[:,0], position2[:,1], position2[:,2], marker='.')
plt.show()
Basically, I don't understand why I need to add a subplot if I only want one plot anyway. Intuitively I would use the first version, so why don't the two give the same result?
The difference is not between plt.subplot and fig.add_subplot. It is rather that in the first case you use pyplot's scatter function, plt.scatter, and in the second case you use the scatter method of the axes, ax.scatter.
plt.scatter is a 2D function. It interprets its third positional argument as the marker size (the s parameter) and draws a scatter in two dimensions, so your z values never become coordinates. (You can see that the points are not spread out along the z axis at all.)
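To see where the third column goes, a small sketch can inspect pyplot's scatter signature and the sizes of the collection it returns. The random array below is made-up data standing in for the question's position1, and the Agg backend is only chosen so the snippet runs without a display.

```python
import inspect

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

# The first three parameters of pyplot's scatter are x, y and the marker size s;
# there is no z parameter at all.
first_three = list(inspect.signature(plt.scatter).parameters)[:3]
print(first_three)

# Hypothetical sample data standing in for position1 from the question.
rng = np.random.default_rng(0)
position1 = rng.random((5, 3)) * 50

plt.figure()
coll = plt.scatter(position1[:, 0], position1[:, 1], position1[:, 2], marker='.')

# The "z" column has been consumed as the marker sizes.
sizes_match = np.allclose(coll.get_sizes(), position1[:, 2])
print(sizes_match)
```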
When using ax.scatter, ax is a 3D axes (matplotlib.axes._subplots.Axes3DSubplot). Its scatter method is different from the 2D one, as it expects three coordinate arrays x, y, z as input.
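A minimal sketch to confirm this on your own installation: the exact class name printed for a 3D axes depends on the matplotlib version (older releases show Axes3DSubplot, newer ones Axes3D), so the check below uses isinstance against mpl_toolkits.mplot3d.Axes3D and then looks at the parameter names of the bound scatter method.

```python
import inspect

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # also registers the '3d' projection on older versions

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Whatever the concrete class is called, it is an Axes3D instance.
is_3d = isinstance(ax, Axes3D)
print(is_3d)

# The 3D scatter method takes xs, ys and zs as its first data arguments.
params = list(inspect.signature(ax.scatter).parameters)
print(params[:3])
```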
Now, you can use both plt.subplot and fig.add_subplot to create a 3D plot, but plt.scatter will not give you a 3D scatter on either of them.
Instead, you need to use ax.scatter in both cases, making sure that the scatter method of matplotlib.axes._subplots.Axes3DSubplot is called.
One option is to use plt.gca() to get the current axes (which is the 3D axes):
plt.subplot(projection='3d')
plt.gca().scatter(position1[:,0], position1[:,1], position1[:,2], marker='.')
plt.gca().scatter(position2[:,0], position2[:,1], position2[:,2], marker='.')
plt.show()
You can also take the axes directly from the return value of plt.subplot():
ax = plt.subplot(projection='3d')
ax.scatter(position1[:,0], position1[:,1], position1[:,2], marker='.')
ax.scatter(position2[:,0], position2[:,1], position2[:,2], marker='.')
plt.show()
Of course, you may also use the approach you already found to work:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(position1[:,0], position1[:,1], position1[:,2], marker='.')
ax.scatter(position2[:,0], position2[:,1], position2[:,2], marker='.')
plt.show()
Or you may use plt.subplots (mind the s) to get a figure and an axes handle at the same time:
fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
ax.scatter(position1[:,0], position1[:,1], position1[:,2], marker='.')
ax.scatter(position2[:,0], position2[:,1], position2[:,2], marker='.')
plt.show()
The result will be the same in all cases.
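If you want to check that claim programmatically, a rough sketch is to build the plot with each of the four variants and confirm that every one yields an Axes3D whose scatter returns 3D collections (mpl_toolkits.mplot3d.art3d.Path3DCollection). The helper names via_gca, via_subplot, via_add_subplot and via_subplots are hypothetical, and the random arrays only stand in for position1 and position2.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D, art3d

# Hypothetical sample data standing in for position1/position2 from the question.
rng = np.random.default_rng(1)
position1 = rng.random((20, 3))
position2 = rng.random((20, 3)) + 1.0


def via_gca():
    plt.figure()
    plt.subplot(projection='3d')
    return plt.gca()  # the current axes is the 3D axes just created


def via_subplot():
    plt.figure()
    return plt.subplot(projection='3d')


def via_add_subplot():
    fig = plt.figure()
    return fig.add_subplot(111, projection='3d')


def via_subplots():
    fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
    return ax


results = {}
for make_axes in (via_gca, via_subplot, via_add_subplot, via_subplots):
    ax = make_axes()
    colls = [ax.scatter(p[:, 0], p[:, 1], p[:, 2], marker='.')
             for p in (position1, position2)]
    # Each variant should give a 3D axes whose scatter returns 3D collections.
    results[make_axes.__name__] = (
        isinstance(ax, Axes3D)
        and all(isinstance(c, art3d.Path3DCollection) for c in colls)
    )
    plt.close(ax.figure)

print(results)
```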