carl

Reputation: 4426

Understanding the difference between subplot and add_subplot (scatter) plots in matplotlib

I am trying to understand why the following two plots look so different:

plt.subplot(projection='3d')
plt.scatter(position1[:,0], position1[:,1], position1[:,2], marker='.')
plt.scatter(position2[:,0], position2[:,1], position2[:,2], marker='.')
plt.show()

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(position1[:,0], position1[:,1], position1[:,2], marker='.')
ax.scatter(position2[:,0], position2[:,1], position2[:,2], marker='.')
plt.show()

Basically, I don't understand why I need to add a subplot if I just want one plot anyway. Intuitively I would use the first version, but the two don't give the same result.

[Two images of the resulting plots]

Upvotes: 4

Views: 2736

Answers (1)

ImportanceOfBeingErnest

Reputation: 339220

The difference is not between plt.subplot and fig.add_subplot.
It is rather that in the first case you use pyplot's scatter function plt.scatter and in the second case you use the scatter of the axes, ax.scatter.

plt.scatter is a 2D function. It interprets its third argument as the size of the scatter points and draws the scatter in two dimensions. (You can see that the z axis is not scaled at all.)
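To make the "third argument is the size" point concrete, here is a minimal sketch on an ordinary 2D axes. The `points` array is a hypothetical stand-in for `position1` from the question, and the Agg backend is only selected so the snippet runs without a display.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, only so this runs headless
import matplotlib.pyplot as plt

# Hypothetical stand-in for position1: 20 points with positive x, y, z values.
rng = np.random.default_rng(0)
points = rng.uniform(1.0, 50.0, size=(20, 3))

fig, ax2d = plt.subplots()
# Passing three positional arrays to the 2D scatter: the third one is taken
# as the marker size `s`, not as a z coordinate.
coll = ax2d.scatter(points[:, 0], points[:, 1], points[:, 2])

# The marker sizes stored on the returned collection are the "z" column.
sizes_match_z = np.allclose(coll.get_sizes(), points[:, 2])
print(sizes_match_z)
plt.close(fig)
```

So the third column ends up controlling how large each marker is drawn rather than where it sits.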

When using ax.scatter, ax is a 3D axes (matplotlib.axes._subplots.Axes3DSubplot). Its scatter method is different from the 2D case, as it expects 3 arguments x,y,z as input.
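One way to see the difference between the two scatter methods is to compare their signatures. This is only a sketch: the exact parameter lists and class names may vary between matplotlib versions, so it looks at the first three positional parameters only.

```python
import inspect
from matplotlib.axes import Axes
from mpl_toolkits.mplot3d import Axes3D

def leading_params(func, n=3):
    """Return the first n parameter names of func, ignoring 'self'."""
    names = [p for p in inspect.signature(func).parameters if p != "self"]
    return names[:n]

# 2D scatter: the third positional parameter is the marker size.
names2d = leading_params(Axes.scatter)
# 3D scatter: the third positional parameter is the z coordinate.
names3d = leading_params(Axes3D.scatter)

print(names2d)
print(names3d)
```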

Now you can use either plt.subplot or fig.add_subplot for a 3D plot, but you cannot use plt.scatter with either of them. Instead you need to use ax.scatter in both cases, making sure that the matplotlib.axes._subplots.Axes3DSubplot's scatter method is called.

  • One option is to use plt.gca() to get the current axes (which is the 3D axes):

    plt.subplot(projection='3d')
    plt.gca().scatter(position1[:,0], position1[:,1], position1[:,2], marker='.')
    plt.gca().scatter(position2[:,0], position2[:,1], position2[:,2], marker='.')
    plt.show()
    
  • You can also get the axes from the call to plt.subplot():

    ax = plt.subplot(projection='3d')
    ax.scatter(position1[:,0], position1[:,1], position1[:,2], marker='.')
    ax.scatter(position2[:,0], position2[:,1], position2[:,2], marker='.')
    plt.show()
    
  • Of course, you may use the approach you already found to work:

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(position1[:,0], position1[:,1], position1[:,2], marker='.')
    ax.scatter(position2[:,0], position2[:,1], position2[:,2], marker='.')
    plt.show()
    
  • Or you may use plt.subplots (mind the s) to get a figure and an axes handle at the same time:

    fig, ax = plt.subplots(subplot_kw=dict(projection="3d"))
    ax.scatter(position1[:,0], position1[:,1], position1[:,2], marker='.')
    ax.scatter(position2[:,0], position2[:,1], position2[:,2], marker='.')
    plt.show()
    

The result will be the same in all cases.
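As a quick sanity check, the following sketch creates the axes in each of the four ways listed above and confirms that every one of them is a 3D axes. The helper name `make_axes_four_ways` is made up for this illustration, and the Agg backend is only used so that no window is needed.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, only so this runs headless
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

def make_axes_four_ways():
    """Create one 3D axes per approach from the list above."""
    plt.figure()
    plt.subplot(projection="3d")
    ax_gca = plt.gca()                              # via plt.gca()

    plt.figure()
    ax_subplot = plt.subplot(projection="3d")       # return value of plt.subplot

    fig = plt.figure()
    ax_add = fig.add_subplot(111, projection="3d")  # fig.add_subplot

    _, ax_subplots = plt.subplots(subplot_kw=dict(projection="3d"))  # plt.subplots
    return [ax_gca, ax_subplot, ax_add, ax_subplots]

axes = make_axes_four_ways()
all_3d = all(isinstance(ax, Axes3D) for ax in axes)
print(all_3d)
plt.close("all")
```

Since each handle is a 3D axes, calling its scatter method with x, y, z draws a real 3D scatter in every case.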

Upvotes: 4
