Daniel

Reputation: 173

How to speed up iteration?

I have this code, and I want to iterate over a CSV file with ~100000 rows.

The script runs very slowly when iterating over that many rows.

Does anyone have a suggestion for speeding up my code?

import pandas as pd 
import matplotlib.pyplot as plt
import math    

a=pd.read_csv('test.csv', sep = ';', skiprows=[1,2],usecols = [1,2,3,4],dtype=float, decimal=',')
v1 = (math.sqrt(2)/math.sqrt(3))
v2 = (1/math.sqrt(6))
v3 = (1/math.sqrt(6))
v4 = (1/math.sqrt(2))
v5 = (1/math.sqrt(2))

fig=plt.figure()
ax=fig.add_axes([0,0,1,1])

for i in range(len(a)):
    id = ((v1*a.iloc[i,1])-(v2*a.iloc[i,2])-(v3*a.iloc[i,3]))
    iq = (v4*a.iloc[i,2])-(v5*a.iloc[i,3])
    ax.scatter(id,iq)
    print(i)

plt.show()

This is a look into my csv data:

time;A;B;C;D
(s);(mV);(mV);(mV);(mV)

0,00000000;5,43279200;-19,49701000;5,09095300;1,83738200
0,00010000;6,84287600;-17,72677000;12,59309000;2,53937200
0,00020000;4,02270800;-20,08302000;-1,94725900;2,77743900
0,00030000;4,37675500;-17,84275000;9,07703500;2,30741000
0,00040000;5,66475400;-18,90490000;12,70907000;-6,98937800
0,00050000;2,61872800;-18,43487000;4,73690600;-2,28299400
0,00060000;4,26077400;-17,01868000;12,59309000;-5,81125600
0,00070000;4,02270800;-17,61079000;17,98926000;-6,51935100
0,00080000;2,02661500;-17,01868000;10,48102000;-5,93334100
0,00090000;3,08265200;-15,72458000;17,28116000;-7,45940700
0,00100000;-0,20144060;-18,08082000;3,68086900;-6,63533100
0,00110000;3,43669900;-7,34953000;17,75119000;-6,28738900
0,00120000;1,67867200;-17,25674000;16,81724000;-4,40117200
0,00130000;-0,67146870;-15,84056000;13,41716000;-6,28738900
0,00140000;-0,43340250;-14,90050000;15,05921000;-0,51886220
0,00150000;-3,13759000;-16,78672000;3,44890700;-0,04883409
0,00160000;6,25686700;-12,06812000;17,51923000;-3,57709700

Upvotes: 1

Views: 135

Answers (2)

Pierre D

Reputation: 26281

Edit

The OP is having trouble reading the CSV file, presumably because of the 2-row header and the slightly unusual separator (and the decimal comma).

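Here is a way to read such a file (here txt is a string holding the file contents, and io is the standard-library module; with a real file, pass its path instead):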

a = pd.read_csv(io.StringIO(txt), sep=';', decimal=',', header=[0,1])

>>> a
     time         A         B          C         D
      (s)      (mV)      (mV)       (mV)      (mV)
0  0.0000  5.432792 -19.49701   5.090953  1.837382
1  0.0001  6.842876 -17.72677  12.593090  2.539372
2  0.0002  4.022708 -20.08302  -1.947259  2.777439
...

To drop the units from the column labels (which makes indexing easier):

a.columns = a.columns.droplevel(1)

Then, in the expression below, replace a[1] by a['A'], etc. to get:

df = pd.DataFrame(dict(
    id_=v1 * a['A'] - v2 * a['B'] - v3 * a['C'],
    iq=v4 * a['B'] - v5 * a['C'],
))
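
Putting the steps of this edit together, here is a minimal self-contained sketch. It assumes the first few rows of the sample from the question, held in a string, and the column mapping A, B, C used just above; adapt the column names if your current components live elsewhere.

```python
import io
import math

import pandas as pd

# First lines of the sample from the question; with a real file,
# pass its path to read_csv instead of the StringIO wrapper.
txt = """time;A;B;C;D
(s);(mV);(mV);(mV);(mV)
0,00000000;5,43279200;-19,49701000;5,09095300;1,83738200
0,00010000;6,84287600;-17,72677000;12,59309000;2,53937200
0,00020000;4,02270800;-20,08302000;-1,94725900;2,77743900
"""

# Two header rows (names and units), ';' separator, decimal comma
a = pd.read_csv(io.StringIO(txt), sep=';', decimal=',', header=[0, 1])
a.columns = a.columns.droplevel(1)  # keep 'time', 'A', ... and drop the units

v1 = math.sqrt(2 / 3)
v2 = v3 = 1 / math.sqrt(6)
v4 = v5 = 1 / math.sqrt(2)

# Whole-column arithmetic: no Python-level loop over the rows
df = pd.DataFrame(dict(
    id_=v1 * a['A'] - v2 * a['B'] - v3 * a['C'],
    iq=v4 * a['B'] - v5 * a['C'],
))
print(df)
```

The resulting df has one row per sample and the two columns id_ and iq, which is the shape that df.plot.scatter('id_', 'iq') expects.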

Original answer

While the time spent computing your x and y values, if you do it right, is negligible (10 ms per million rows), the time plotting all the points is not.

For 100K points that is still tolerable, but with millions of points consider Matplotlib's hist2d() or Seaborn's jointplot() instead. As a bonus, the latter also plots the marginal distributions:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

a = pd.DataFrame(np.random.normal(size=(1_000_000, 4)))
v1 = np.sqrt(2/3)
v2 = 1 / np.sqrt(6)
v3 = v2
v4 = 1 / np.sqrt(2)
v5 = v4

df = pd.DataFrame(dict(
    id_=v1 * a[1] - v2 * a[2] - v3 * a[3],
    iq=v4 * a[2] - v5 * a[3]),
)
# timeit of the computation above gives:
# 9.89 ms ± 9.13 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%%timeit
df.plot.scatter('id_', 'iq')
plt.show()
# timeit (which draws the plot several times) reports:
# 2.91 s ± 11.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Instead:

%%timeit
sns.jointplot(
    data=df, x='id_', y='iq', kind='hex',
    marginal_kws=dict(bins=50),
    joint_kws=dict(bins=50),
)
plt.show()
# indicates:
# 852 ms ± 1.73 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
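
For completeness, the hist2d() alternative mentioned earlier could look roughly like the sketch below. It uses synthetic normal data as a stand-in for the real measurements, and the 200 x 200 grid is an arbitrary choice, not a recommendation; no timing is claimed for it here.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)          # synthetic stand-in for the real data
a = rng.normal(size=(1_000_000, 4))

v1 = np.sqrt(2 / 3)
v2 = v3 = 1 / np.sqrt(6)
v4 = v5 = 1 / np.sqrt(2)

id_ = v1 * a[:, 1] - v2 * a[:, 2] - v3 * a[:, 3]
iq = v4 * a[:, 2] - v5 * a[:, 3]

fig, ax = plt.subplots()
# Every point is counted into one cell of a 200 x 200 grid,
# so the drawing cost depends on the grid size, not on the number of points.
counts, xedges, yedges, image = ax.hist2d(id_, iq, bins=200)
fig.colorbar(image, ax=ax, label='points per bin')
ax.set_xlabel('id_')
ax.set_ylabel('iq')
# plt.show()  # uncomment in an interactive session
```

The returned counts array holds the number of points per cell, which is handy if you want to post-process the density rather than only look at it.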

Upvotes: 2

Corralien

Reputation: 120559

Don't use a loop; scatter accepts array-like arguments:

Sample:

# a = pd.DataFrame(np.random.random((100000, 4)), columns=['A', 'B', 'C', 'D'])
>>> a
              A         B         C         D
0      0.797683  0.819883  0.190643  0.838554
1      0.494024  0.757094  0.863671  0.492803
2      0.935607  0.272122  0.834900  0.707307
3      0.635601  0.329287  0.703526  0.984984
4      0.117422  0.583254  0.399773  0.182749
...         ...       ...       ...       ...
99995  0.182855  0.960854  0.531180  0.242445
99996  0.632885  0.607970  0.043772  0.080374
99997  0.570511  0.214377  0.063418  0.810628
99998  0.401211  0.713925  0.573271  0.500783
99999  0.028511  0.470635  0.315194  0.019288

[100000 rows x 4 columns]

Code:

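# Same positional columns (1, 2, 3) as the loop in the question
x = v1 * a.iloc[:, 1] - v2 * a.iloc[:, 2] - v3 * a.iloc[:, 3]
y = v4 * a.iloc[:, 2] - v5 * a.iloc[:, 3]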

fig = plt.figure()
ax = fig.add_axes([0,0,1,1])
ax.scatter(x, y)
plt.show()
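
If you want some reassurance that dropping the loop does not change the numbers, a quick check on a small synthetic frame is enough. The sketch below assumes the positional columns 1, 2 and 3 from the question's loop and a random 1000-row frame as a stand-in for the CSV.

```python
import math

import numpy as np
import pandas as pd

v1 = math.sqrt(2 / 3)
v2 = v3 = 1 / math.sqrt(6)
v4 = v5 = 1 / math.sqrt(2)

rng = np.random.default_rng(0)
a = pd.DataFrame(rng.random((1000, 4)), columns=['A', 'B', 'C', 'D'])

# Row by row, as in the question
x_loop = [v1 * a.iloc[i, 1] - v2 * a.iloc[i, 2] - v3 * a.iloc[i, 3]
          for i in range(len(a))]
y_loop = [v4 * a.iloc[i, 2] - v5 * a.iloc[i, 3] for i in range(len(a))]

# Vectorized: one expression per output column
x = v1 * a.iloc[:, 1] - v2 * a.iloc[:, 2] - v3 * a.iloc[:, 3]
y = v4 * a.iloc[:, 2] - v5 * a.iloc[:, 3]

# Both approaches should agree element by element
same = np.allclose(x, x_loop) and np.allclose(y, y_loop)
print(same)
```

Once the two agree, x and y can be passed to a single ax.scatter(x, y) call instead of one call per row.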

Upvotes: 1
