Reputation: 15
I want to plot a 3-D surface with the axes in the middle of the figure.
I use the following code to plot the figure:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(-1, 1, 10)
y = np.linspace(-1, 1, 10)
X, Y = np.meshgrid(x, y)
Z = np.array([[2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
[2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
[2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
[2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
[2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
[2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
[2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
[2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
[2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291],
[2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291]])
fig = plt.figure()
ax = plt.axes(projection='3d')
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, linewidth=0,
cmap='viridis', edgecolor='none', antialiased=False)
ax.set_xlim(-1.01, 1.01)
fig.colorbar(surf, shrink=0.5, aspect=5)
# Plot the axis in the middle of the figure
xline=((min(X[:,0]),max(X[:,0])),(0,0),(0,0))
ax.plot(xline[0],xline[1],xline[2],'grey')
yline=((0,0),(min(Y[:,1]),max(Y[:,1])),(0,0))
ax.plot(yline[0],yline[1],yline[2],'grey')
zline=((0,0),(0,0),(min(Z[:,2]),max(Z[:,2])))
ax.plot(zline[0],zline[1],zline[2],'grey')
ax.view_init(30,220) # Camera angel
ax.set_title('surface');
By using the above code, I obtain the figure like this
What I really want is to plot a 3-D axis origin figure like the following:
How can I eliminate the margin and put the axes in the middle of the graph?
Upvotes: 0
Views: 3329
Reputation: 34018
Here is a solution using plotly.
The code is shown below, but here are the most important remarks:
- The axis arrows are created by `get_arrow()`.
- The box is drawn using the `tickvals` and `range` parameters of `layout.scene.xaxis`, `layout.scene.yaxis` and `layout.scene.zaxis`. Each is set to only two values (the data minimum and maximum), so drawing the grid shows just the box. If you would also like to show the normal grid, it has to be drawn with vectors, too (like the arrows).
- To show the colorscale, set `showscale=True` for the `go.Surface`.
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
def get_data():
    """Return the sample surface as meshgrid arrays ``(X, Y, Z)``.

    X and Y are 10x10 grids spanning [-1, 1] on both axes. Z has the same
    shape, and every one of its rows is the same 10-point profile.
    """
    grid = np.linspace(-1, 1, 10)
    X, Y = np.meshgrid(grid, grid)
    profile = np.array([2.58677481, 3.22528864, 3.65334814, 3.86669336, 3.86399048,
                        3.64525411, 3.21186215, 2.56819809, 1.72989472, 0.78569291])
    # Stack ten copies of the profile, one per row.
    Z = np.tile(profile, (10, 1))
    return X, Y, Z
# Single-colour scale for the arrow bodies and heads: both stops (0 and 1)
# map to the same brown, so every arrow is drawn in one colour.
colorscale = [[stop, "rgb(84,48,5)"] for stop in (0, 1)]
X, Y, Z = get_data()
# Per-axis extent summary keyed by "x" / "y" / "z".
# Each entry holds the array's min, max, midpoint and span (max - min).
# The arrow, annotation and scene-axis helpers read their positions from it.
data = {
    name: dict(
        min=arr.min(),
        max=arr.max(),
        mid=(arr.max() + arr.min()) / 2,
        range=arr.max() - arr.min(),
    )
    for name, arr in zip("xyz", (X, Y, Z))
}
def get_arrow(axisname="x"):
    """Build the traces that draw one axis arrow.

    The shaft runs from ``data[axisname]["min"]`` to ``data[axisname]["max"]``
    and sits at the ``"mid"`` value of the other two axes. A cone placed at the
    max end points along *axisname* (direction component 1 for that axis and 0
    for the others).

    Returns ``[shaft, cone]``: a ``go.Scatter3d`` and a ``go.Cone``.
    """
    arrow_color = colorscale[0][1]
    shaft_coords = {}
    cone_coords = {}
    for axis, component in zip(("x", "y", "z"), ("u", "v", "w")):
        stats = data[axis]
        if axis != axisname:
            # Off-axis coordinates: pin both ends to the centre of that axis.
            shaft_coords[axis] = stats["mid"], stats["mid"]
            cone_coords[axis] = [stats["mid"]]
            cone_coords[component] = [0]
            continue
        shaft_coords[axis] = stats["min"], stats["max"]
        cone_coords[axis] = [stats["max"]]
        cone_coords[component] = [1]

    shaft = go.Scatter3d(
        marker=dict(size=1, color=arrow_color),
        line=dict(color=arrow_color, width=3),
        showlegend=False,  # keep the arrows out of the legend
        **shaft_coords,
    )
    cone = go.Cone(
        sizeref=0.1,
        autocolorscale=None,
        colorscale=colorscale,
        showscale=False,  # no extra colorbar for the arrow heads
        hovertext=axisname,
        **cone_coords,
    )
    return [shaft, cone]
def add_axis_arrows(fig):
    """Add the x, y and z arrow traces (from ``get_arrow``) to *fig* in place."""
    arrow_traces = [trace for axis in ("x", "y", "z") for trace in get_arrow(axis)]
    for trace in arrow_traces:
        fig.add_trace(trace)
def get_annotation_for_ax(ax):
    """Return a scene annotation that labels axis *ax* with its own name.

    Position:
    - along *ax*, 5% of its range short of the max end;
    - on the other two axes, at their ``"mid"`` values.
    The x and y labels also get ``xshift=15``.
    """
    label = {
        "showarrow": False,
        "text": ax,
        "xanchor": "left",
        "font": {"color": "#1f1f1f"},
    }
    for coord in ("x", "y", "z"):
        label[coord] = (
            data[ax]["max"] - data[ax]["range"] * 0.05
            if coord == ax
            else data[coord]["mid"]
        )
    if ax in ("x", "y"):
        label["xshift"] = 15
    return label
def get_axis_names():
    """Return the label annotations for the x, y and z axes, in that order."""
    return list(map(get_annotation_for_ax, "xyz"))
def get_scene_axis(axisname="x"):
    """Return a 3-D scene axis config that draws only a box outline.

    Both ``tickvals`` and ``range`` are pinned to the data min/max of
    *axisname*, so the grey grid lines fall exactly on the box edges and no
    extra lines appear around it. The axis title, tick labels and background
    are hidden.
    """
    lo = data[axisname]["min"]
    hi = data[axisname]["max"]
    axis_config = dict(
        title="",               # no x/y/z axis title
        showbackground=False,
        visible=True,
        showticklabels=False,   # hide numeric tick values
        showgrid=True,          # grid lines at tickvals form the box
        gridcolor="grey",       # box colour
    )
    axis_config["tickvals"] = [lo, hi]  # only two ticks -> box limits
    axis_config["range"] = [lo, hi]     # no padding beyond the box
    return axis_config
# Semi-transparent surface. The box, axis labels and arrows come from the
# helper functions above.
fig = go.Figure()
fig.add_trace(
    go.Surface(
        z=Z,
        x=X,
        y=Y,
        opacity=0.6,
        showscale=False,  # set to True to show the colour scale
    )
)
fig.update_layout(
    title="surface",
    autosize=True,
    width=700,
    height=500,
    margin=dict(l=20, r=20, b=25, t=25),  # small margins around the scene
    scene=dict(
        xaxis=get_scene_axis("x"),
        yaxis=get_scene_axis("y"),
        zaxis=get_scene_axis("z"),
        annotations=get_axis_names(),
    ),
)
add_axis_arrows(fig)
fig.show()
Upvotes: 2
Reputation: 4045
I don't know of a proper way, only a workaround.
There is a detailed and very good description of the 2D case here, with the following options:
# Alternatives for a regular 2-D Axes `ax`; as noted below, they do not
# carry over easily to 3-D axes. The last one needs `import seaborn`.
ax.axhline(y=0, color='k')  # draw a black line at y = 0
ax.spines['left'].set_position('zero')  # move the left spine to the data origin
seaborn.despine(ax=ax, offset=0)  # remove spines with seaborn's helper
However, I am afraid they don't work as easily in 3D.
I am only aware of this workaround, where the (outer) axes are turned off (`ax.set_axis_off()`) and arrows are drawn along the axes through the origin (`ax.quiver()`).
So you add
# Replace the default axis box with three arrows through the origin.
# Rows of `tails` are the x, y and z coordinates of the three arrow tails:
# (-1, 0, 0), (0, -1, 0) and (0, 0, 0).
# Rows of `deltas` are the matching u, v, w components: the x and y arrows
# span [-1, 1], and the z arrow rises from 0 to 5.
# Distinct names are used so this snippet does not overwrite the `x`/`y`
# grid vectors of the code it is appended to.
tails = np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 0]])
deltas = np.array([[2, 0, 0], [0, 2, 0], [0, 0, 5]])
ax.quiver(*tails, *deltas, arrow_length_ratio=0.1, color="black")
ax.set_axis_off()  # hide the outer 3-D axes/panes
plt.show()
to your code and you'll get this picture:
Upvotes: 0