eastafri

Reputation: 2226

How to change marker shapes/colours for 3D points with mplot3d

Python and mplot3d newbie here. I have a 3D scatter plot in which I would like to display coordinates using different shapes and colours depending on some attributes. The data looks like this:

col1 col2   col3 col4 col5
276  147    -6   K  dia
274  145    -8   A  cir
270  141    -12  B  dia
267  138    -15  K  cir
266  137    -16  K  cir
261  132    -21  B  bu
251  122    -31  C  cir

Now I would like to change the shape of each point based on col4 and its colour based on col5. For now I have this code, which reads the data points from a file and plots them all with the same shape and colour:

from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib versions
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')

data = np.genfromtxt('distances.txt')

x = data[:,0]
y = data[:,1]
z = data[:,2]  # col3; index 3 is col4, which is text and is read as nan

ax.scatter(x, y, z,c='red',marker='^')

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

plt.show()

How can I easily set the shape and colour of each point depending on the values of col4 and col5?

Upvotes: 0

Views: 949

Answers (1)

tacaswell

Reputation: 87376

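The first thing you need to do is import your data in a way that doesn't turn the text columns (col4 and col5) into nan, which is what np.genfromtxt does with its default float dtype. You then need to translate those column values into values that matplotlib understands: color codes and marker symbols.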

from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib versions
import matplotlib.pyplot as plt
import csv

color_map = {'A':'r', 'B':'b', 'K':'k', 'C':'c'}
shape_map = {'dia':'^', 'cir':'o', 'bu':'.'}

with open('/tmp/dist.txt','r') as in_file:
    reader = csv.DictReader(in_file, delimiter=' ', skipinitialspace=True)
    data = []
    for r in reader:
        data.append([float(r['col1']),
                     float(r['col2']),
                     float(r['col3']),
                     color_map[r['col4']],
                     shape_map[r['col5']]])
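If the column spacing in the file is irregular (for example tabs instead of spaces), a minimal alternative sketch is to split each line with str.split(), which treats any run of whitespace as a single separator. It assumes the file has one header line followed by exactly five whitespace-separated columns; parse_points and the inline sample are hypothetical names used only for illustration, and the two maps are the ones above.

```python
import io

color_map = {'A': 'r', 'B': 'b', 'K': 'k', 'C': 'c'}
shape_map = {'dia': '^', 'cir': 'o', 'bu': '.'}


def parse_points(in_file):
    """Hypothetical helper: return [x, y, z, colour, marker] rows.

    Assumes a single header line, then five whitespace-separated columns per row.
    """
    lines = iter(in_file)
    next(lines)  # skip the "col1 col2 col3 col4 col5" header
    data = []
    for line in lines:
        fields = line.split()  # no argument: any run of spaces/tabs is one separator
        if not fields:
            continue  # ignore blank lines
        x, y, z, cat, kind = fields
        data.append([float(x), float(y), float(z), color_map[cat], shape_map[kind]])
    return data


# In-memory stand-in for the data file (three rows from the question)
sample = io.StringIO(
    "col1 col2   col3 col4 col5\n"
    "276  147    -6   K  dia\n"
    "274  145    -8   A  cir\n"
    "261  132    -21  B  bu\n"
)
data = parse_points(sample)
print(data[0])  # [276.0, 147.0, -6.0, 'k', '^']
```

An unknown category raises a KeyError from the map lookup, which is usually preferable to silently plotting the point with a default style.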

Getting the colors is easy: scatter accepts an iterable of colors, one per point:

X, Y, Z, col, shape = zip(*data)


fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')


ax.scatter(X, Y, Z, c=col)
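The zip(*data) line transposes the list of per-point rows into one tuple per column, which is the form scatter expects for X, Y, Z and c. A tiny illustration with two rows from the question, assumed to be already translated by the two maps:

```python
# Two parsed rows in [x, y, z, colour, marker] form (values from the question's sample)
data = [
    [276.0, 147.0, -6.0, 'k', '^'],
    [274.0, 145.0, -8.0, 'r', 'o'],
]

# zip(*data) groups the i-th element of every row together, i.e. rows become columns
X, Y, Z, col, shape = zip(*data)

print(X)      # (276.0, 274.0)
print(col)    # ('k', 'r')
print(shape)  # ('^', 'o')
```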

Getting the shapes is trickier, because scatter takes only one marker for all the points in a single call. The workaround is to group the points by shape and make one scatter call per group:

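import collections

# Group the [x, y, z, color] part of each row by its marker symbol
by_shape = collections.defaultdict(list)
for d in data:
    by_shape[d[4]].append(d[:4])

# One scatter call per marker; use this loop instead of the single ax.scatter(X, Y, Z, c=col) call above
for key, val in by_shape.items():
    X, Y, Z, col = zip(*val)
    ax.scatter(X, Y, Z, c=col, marker=key)

plt.show()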
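Putting the pieces together, here is a self-contained sketch of the whole approach. As an assumption so it runs without the data file, the question's seven rows are hard-coded after translation by the two maps. It also selects the non-interactive Agg backend and writes to a PNG (the filename points3d.png is hypothetical); on a desktop you would drop the matplotlib.use line and call plt.show() instead.

```python
import collections

import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on older matplotlib)

# The question's rows, already mapped to [x, y, z, color, marker]
data = [
    [276.0, 147.0, -6.0, 'k', '^'],   # K, dia
    [274.0, 145.0, -8.0, 'r', 'o'],   # A, cir
    [270.0, 141.0, -12.0, 'b', '^'],  # B, dia
    [267.0, 138.0, -15.0, 'k', 'o'],  # K, cir
    [266.0, 137.0, -16.0, 'k', 'o'],  # K, cir
    [261.0, 132.0, -21.0, 'b', '.'],  # B, bu
    [251.0, 122.0, -31.0, 'c', 'o'],  # C, cir
]

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

# Group rows by marker, keeping x, y, z and color for each point
by_shape = collections.defaultdict(list)
for *xyz_col, marker in data:
    by_shape[marker].append(xyz_col)

# scatter accepts one marker per call, so draw each group separately;
# colors still vary point by point within a group via c=
for marker, rows in by_shape.items():
    X, Y, Z, col = zip(*rows)
    ax.scatter(X, Y, Z, c=col, marker=marker)

fig.savefig('points3d.png')
```

Each call adds one collection to the axes, so three marker types give three scatter collections.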

Upvotes: 1
