Josiah Glyshaw

How to efficiently plot many lines in a 3D matplotlib graph

I need to draw edges between nodes in a 3D matplotlib plot. For example, one edge goes from (0,0,0) to (5,2,1).

I can't draw all the edges with a single plot call, because plot would then connect consecutive nodes that don't share an edge. I could call plot separately for each edge, but with thousands of edges, interacting with the plot (rotating or panning it) becomes very slow.

Is there a way to draw thousands of separate lines without the performance dropping?

Here is example code with dummy data showing what I'm trying to do. The data is a list of (nodeA, nodeB) pairs, where nodeA holds the coordinates of the edge's first vertex and nodeB those of its second vertex.

import matplotlib.pyplot as plt
from random import randrange

def generateDummyData(size):
    # build `size` random edges; each edge is [nodeA, nodeB] and each node
    # is an (x, y, z) tuple with integer coordinates in [-10, 9]
    data = []
    for i in range(size):
        nodeA = (randrange(-10,10),randrange(-10,10),randrange(-10,10))
        nodeB = (randrange(-10,10),randrange(-10,10),randrange(-10,10))
        data.append([nodeA,nodeB])
    return data


# graph each edge in the 3D graph with its own plot call
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
data = generateDummyData(4000)
for edge in data:
    nodeA = edge[0]
    nodeB = edge[1]

    # x, y and z coordinates of the edge's two endpoints
    x = [nodeA[0],nodeB[0]]
    y = [nodeA[1],nodeB[1]]
    z = [nodeA[2],nodeB[2]]
    ax.plot(x,y,z)

plt.show()
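
The edge list described above maps naturally onto a NumPy array of shape (N, 2, 3): N edges, 2 endpoints per edge, 3 coordinates per endpoint. A minimal sketch, assuming NumPy is available; the helper name `generate_dummy_segments` is hypothetical and simply builds the same kind of random data without a Python loop. A list of [nodeA, nodeB] pairs like the one above converts to the same layout with np.asarray.

```python
import numpy as np


def generate_dummy_segments(size, seed=None):
    """Hypothetical vectorized version of generateDummyData.

    Returns an integer array of shape (size, 2, 3): one row per edge,
    two endpoints per edge, three coordinates per endpoint.
    """
    rng = np.random.default_rng(seed)
    # integers() excludes the upper bound, like randrange(-10, 10)
    return rng.integers(-10, 10, size=(size, 2, 3))


segments = generate_dummy_segments(4000, seed=0)
print(segments.shape)  # (4000, 2, 3)

# A list of [nodeA, nodeB] pairs converts to the same (N, 2, 3) layout
edges = [[(0, 0, 0), (5, 2, 1)]]  # the example edge from the question
edge_array = np.asarray(edges)
print(edge_array.shape)  # (1, 2, 3)
node_a, node_b = edge_array[0]  # the two endpoints of the first edge
```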

Answers (1)

chrslg

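It seems you are looking for a LineCollection. For 3D axes the matching class is Line3DCollection (from mpl_toolkits.mplot3d.art3d): it draws all the segments as a single artist, so matplotlib handles one object instead of thousands of separate lines, which is typically much faster to redraw while you rotate or pan the view.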
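
A rough way to see why a collection helps is to count the artists each approach creates. This sketch assumes the non-interactive Agg backend so it runs without opening a window, and uses 500 random edges (an arbitrary number) to keep it quick; it only counts artists and does not measure drawing speed.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; no window is opened
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d.art3d import Line3DCollection

rng = np.random.default_rng(0)
segments = rng.integers(-10, 10, size=(500, 2, 3))  # (edges, endpoints, xyz)

# Approach 1: one ax.plot call per edge -> one line artist per edge
fig1 = plt.figure()
ax1 = fig1.add_subplot(projection="3d")
for (xa, ya, za), (xb, yb, zb) in segments:
    ax1.plot([xa, xb], [ya, yb], [za, zb])

# Approach 2: all edges in a single collection -> one artist in total
fig2 = plt.figure()
ax2 = fig2.add_subplot(projection="3d")
ax2.add_collection3d(Line3DCollection(segments))

print(len(ax1.lines), len(ax2.collections))  # 500 1
plt.close("all")
```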

import matplotlib.pyplot as plt
import numpy as np
from random import randrange
from mpl_toolkits.mplot3d.art3d import Line3DCollection

def generateDummyData(size):
    # each edge is [nodeA, nodeB], each node an (x, y, z) tuple
    data = []
    for i in range(size):
        nodeA = (randrange(-10,10),randrange(-10,10),randrange(-10,10))
        nodeB = (randrange(-10,10),randrange(-10,10),randrange(-10,10))
        data.append([nodeA,nodeB])
    return data

# draw every edge through one Line3DCollection instead of one plot call per edge
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
data = np.array(generateDummyData(4000))  # shape (4000, 2, 3): edges x endpoints x coordinates
color = np.random.random((len(data), 3))  # one random RGB color per edge

lc = Line3DCollection(data, color=color)
ax.add_collection3d(lc)
# adding a collection may not rescale the axes, so set the limits explicitly
ax.set_xlim(-10, 10)
ax.set_ylim(-10, 10)
ax.set_zlim(-10, 10)

plt.show()
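
In the code above the axis limits are hard-coded to match the dummy data. A minimal sketch that derives them from the segments instead, assuming the (N, 2, 3) array layout used above; the helper `add_segments_3d` and its `pad` parameter are hypothetical. Depending on the matplotlib version, add_collection3d may not rescale the axes on its own, so setting the limits explicitly is the safe option.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; drop it to get a window with plt.show()
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d.art3d import Line3DCollection


def add_segments_3d(ax, segments, pad=0.5, **kwargs):
    """Hypothetical helper: draw an (N, 2, 3) array of segments as one
    Line3DCollection and fit the axis limits to the data.

    Extra keyword arguments (e.g. color=...) are passed to Line3DCollection.
    """
    segments = np.asarray(segments, dtype=float)
    lc = Line3DCollection(segments, **kwargs)
    ax.add_collection3d(lc)

    # Flatten all endpoints to shape (2N, 3) and take per-axis bounds;
    # with pad > 0 the lines stay off the box edges and limits never collapse.
    points = segments.reshape(-1, 3)
    lo = points.min(axis=0) - pad
    hi = points.max(axis=0) + pad
    ax.set_xlim(lo[0], hi[0])
    ax.set_ylim(lo[1], hi[1])
    ax.set_zlim(lo[2], hi[2])
    return lc


rng = np.random.default_rng(0)
segments = rng.integers(-10, 10, size=(4000, 2, 3))
colors = rng.random((len(segments), 3))  # one random RGB color per edge

fig = plt.figure()
ax = fig.add_subplot(projection="3d")
lc = add_segments_3d(ax, segments, color=colors)
# plt.show()  # uncomment when using an interactive backend
```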
