janoliver

Reputation: 7824

Make pyplot faster than gnuplot

I recently decided to give matplotlib.pyplot a try, after having used gnuplot for scientific data plotting for years. I started out by simply reading a data file and plotting two columns, like gnuplot does with plot 'datafile' u 1:2. The requirements for my comfort are:

The following code is my solution to the problem. However, it is noticeably slower than gnuplot. This is a bit odd, since I read that one big advantage of py(plot/thon) over gnuplot is its speed.

import numpy as np
import matplotlib.pyplot as plt
import sys

datafile = sys.argv[1]
data = []
for line in open(datafile,'r'):
    if line and line[0] != '#':
        cols = filter(lambda x: x!='',line.split(' '))
        for index,col in enumerate(cols):
            if len(data) <= index:
                data.append([])
            data[index].append(float(col))

plt.plot(data[0],data[1])
plt.show()

What can I do to make the data reading faster? I had a quick look at the csv module, but it didn't seem very flexible about comments in files, and one still needs to iterate over all lines in the file.

Upvotes: 1

Views: 3402

Answers (2)

unutbu

Reputation: 880269

Since you have matplotlib installed, you must also have numpy installed. numpy.genfromtxt meets all your requirements and should be much faster than parsing the file yourself in a Python loop:

import numpy as np
import matplotlib.pyplot as plt

import textwrap

# write a small sample file: a header line, comments and a blank line
fname='/tmp/tmp.dat'
with open(fname,'w') as f:
    f.write(textwrap.dedent('''\
        id col1 col2 col3 col4
        2010 1 2 3 4
        # Foo

        2011 5 6 7 8
        # Bar        
        # Baz
        2012 8 7 6 5
        '''))

data = np.genfromtxt(fname, 
                     comments='#',    # skip comment lines
                     dtype = None,    # guess dtype of each column
                     names=True)      # use first line as column names
print(data)
plt.plot(data['id'],data['col2'])
plt.show()
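For a headerless file like the one in the question (whitespace-separated numeric columns with # comments), np.loadtxt with unpack=True returns one array per column, which maps directly onto gnuplot's u 1:2. In recent NumPy releases loadtxt is implemented in C and is typically faster than genfromtxt. This is a minimal sketch, assuming every non-comment line has the same number of numeric columns; the read_columns helper and the sample data are illustrative, not part of the original answer:

```python
import os
import tempfile

import numpy as np


def read_columns(path):
    # comments='#' drops comment lines; blank lines are skipped automatically.
    # unpack=True transposes the result so each column is its own array.
    # ndmin=2 keeps that working even if the file has only one data row.
    return np.loadtxt(path, comments='#', unpack=True, ndmin=2)


if __name__ == '__main__':
    # Write a small illustrative sample file
    fd, path = tempfile.mkstemp(suffix='.dat')
    with os.fdopen(fd, 'w') as f:
        f.write('# x y\n0 1.5\n1 2.5\n\n# a comment\n2 4.0\n')
    x, y = read_columns(path)
    os.remove(path)
    print(x, y)
    # then: import matplotlib.pyplot as plt; plt.plot(x, y); plt.show()
```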

Upvotes: 6

agf

Reputation: 176910

You really need to profile your code to find out what the bottleneck is.
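One way to do that with only the standard library is cProfile plus pstats. A minimal sketch, where parse_file is a hypothetical stand-in for the question's reading loop and the input is a generated sample file:

```python
import cProfile
import io
import os
import pstats
import tempfile


def parse_file(path):
    # Hypothetical stand-in for the question's parsing loop
    data = []
    with open(path) as f:
        for line in f:
            if line.strip() and not line.startswith('#'):
                for index, col in enumerate(line.split()):
                    if len(data) <= index:
                        data.append([])
                    data[index].append(float(col))
    return data


def profile_parse(path, top=5):
    # Run parse_file under the profiler; return (result, report text)
    profiler = cProfile.Profile()
    profiler.enable()
    result = parse_file(path)
    profiler.disable()
    out = io.StringIO()
    stats = pstats.Stats(profiler, stream=out)
    # Sort by cumulative time and keep only the `top` most expensive entries
    stats.sort_stats('cumulative').print_stats(top)
    return result, out.getvalue()


if __name__ == '__main__':
    fd, path = tempfile.mkstemp(suffix='.dat')
    with os.fdopen(fd, 'w') as f:
        for i in range(100000):
            f.write('%d %f\n' % (i, i * 0.5))
    data, report = profile_parse(path)
    os.remove(path)
    print(report)
```

To profile the whole script, including the matplotlib import and plt.show(), the same module can be run from the shell, e.g. python -m cProfile -s cumulative script.py datafile (the script name is illustrative).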

Here are some micro-optimizations:

import numpy as np
import matplotlib.pyplot as plt
import sys

datafile = sys.argv[1]
data = []
# use with to auto-close the file
for line in open(datafile,'r'):
    # line is never empty (falsy): each line yielded has at least one character, normally the newline
    # maybe you mean line.rstrip()?
    # you can also try line.startswith('#') instead of line[0] != '#'
    if line and line[0] != '#':
        # not sure of the point of this
        # just line.split() will allow any number of spaces
        # if you do need it, use a list comprehension
        # cols = [col for col in line.split(' ') if col]
        # filter on a user-defined function is slow
        cols = filter(lambda x: x!='',line.split(' '))

        for index,col in enumerate(cols):
            # just make data a collections.defaultdict
            # initialized as data = defaultdict(list)
            # and you can skip this 'if' statement entirely
            if len(data) <= index:
                data.append([])
            data[index].append(float(col))

plt.plot(data[0],data[1])
plt.show()
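Putting the suggestions from those comments together gives something like the following. It is a sketch rather than a benchmarked improvement: it uses with, str.startswith, a bare split() and a collections.defaultdict(list) in place of the manual if check; the read_columns name is illustrative:

```python
import sys
from collections import defaultdict


def read_columns(path):
    # Map column index -> list of floats; missing keys start as empty lists
    data = defaultdict(list)
    # 'with' closes the file automatically
    with open(path) as f:
        for line in f:
            # Skip blank lines and comment lines
            if not line.strip() or line.startswith('#'):
                continue
            # split() with no argument splits on any run of whitespace
            for index, col in enumerate(line.split()):
                data[index].append(float(col))
    return data


# Only plot when a data file is given on the command line
if __name__ == '__main__' and len(sys.argv) > 1:
    import matplotlib.pyplot as plt

    data = read_columns(sys.argv[1])
    plt.plot(data[0], data[1])
    plt.show()
```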

You may be able to do something like:

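with open(datafile) as f:
    # skip blank lines and comment lines, split the rest on whitespace
    lines = (line.split() for line in f
             if line.rstrip() and not line.startswith('#'))
    # convert each row to floats, then transpose rows into columns with zip(*...);
    # list() makes Python 3's lazy zip run while the file is still open
    data = list(zip(*([float(col) for col in line] for line in lines)))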

This gives you a list of tuples (one tuple per column) instead of an int-keyed dict of lists, but otherwise behaves the same. It can be done as a one-liner, but I split it up to make it a little easier to read.
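To see what that list of tuples looks like, here is the same generator plus zip(*...) pattern run on an in-memory sample. io.StringIO stands in for the open file only so the example needs no data file, and columns_from_lines is a hypothetical helper name:

```python
import io


def columns_from_lines(f):
    # Convert each non-blank, non-comment line into a list of floats
    rows = ([float(col) for col in line.split()]
            for line in f
            if line.strip() and not line.startswith('#'))
    # zip(*rows) transposes rows into column tuples; list() forces evaluation
    return list(zip(*rows))


sample = io.StringIO('# x y\n1 2\n\n3 4\n5 6\n')
data = columns_from_lines(sample)
print(data)  # [(1.0, 3.0, 5.0), (2.0, 4.0, 6.0)]
```

One caveat of this design: zip stops at the shortest row, so a line with a missing value silently truncates every column instead of raising an error.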

Upvotes: 2
