Reputation: 13
In my research I frequently use a MATLAB script to animate the evolution of functions over time, purely for visualization purposes. I am currently porting my code to Python and am having trouble achieving the animation speed that MATLAB provides.
I have tried implementing the same script in Python using matplotlib.pyplot with a nearly identical code structure. I am aware of the matplotlib.animation.Animation framework as well, but I was unable to get the desired performance from it. There are some obvious workarounds (e.g., reducing the number of points in my vector, iterating with a larger step size, etc.), but I am really interested in whether Python can match the performance of MATLAB in this specific application.
Here is the script in MATLAB:
line = plot(0,0,'k','linewidth',2);
x=0:0.05:4*pi;
y=sin(x);
axis([min(x) max(x) min(y) max(y)])
for k=1:length(x)
    set(line,'XData',x(1:k),'YData',y(1:k))
    pause(0.0001)
end
and here is my implementation in Python:
import matplotlib.pyplot as plt
import numpy as np
import time
x = np.arange(0, 4*np.pi, 0.05)
y = np.sin(x)
fig = plt.figure()
plt.ion()
ax = fig.add_subplot(111)
line, = ax.plot(0, 0, 'k', linewidth=2)
ax.set_xlim([np.min(x), np.max(x)])
ax.set_ylim([np.min(y), np.max(y)])
plt.show()
for k in range(len(x)):
    line.set_data(x[:k], y[:k])
    fig.canvas.draw()
    time.sleep(0.0001)
Ideally, the Python animation would be just as fast as MATLAB's using the same sort of code architecture and parameters. My intuition is that something with plt.ion() or fig.canvas.draw() is slowing down the Python script. Any help here would be appreciated, and of course, showing that Python can match MATLAB's performance here would just strengthen the case for why no one should be using MATLAB anymore these days!
Upvotes: 0
Views: 201
Reputation: 339430
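Never use time.sleep in an interactive plot. While the script sleeps, the GUI event loop is blocked, so the window cannot process events or repaint. Use plt.pause instead: it updates the figure and runs the event loop for the given interval, which also makes the explicit fig.canvas.draw() call unnecessary.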
import matplotlib.pyplot as plt
import numpy as np
x = np.arange(0, 4*np.pi, 0.05)
y = np.sin(x)
fig = plt.figure()
plt.ion()
ax = fig.add_subplot(111)
line, = ax.plot(0, 0, 'k', linewidth=2)
ax.set_xlim([np.min(x), np.max(x)])
ax.set_ylim([np.min(y), np.max(y)])
plt.draw()
for k in range(len(x)):
    line.set_data(x[:k], y[:k])
    # plt.pause redraws the figure and runs the GUI event loop for the given time
    plt.pause(0.0001)
Upvotes: 2