Reputation: 407
I was wondering whether it is possible to optimise the following using Numpy
or mathematical trickery.
def f1(g, b, dt, t1, t2):
    p = np.copy(g)
    for i in range(dt):
        p += t1*np.tanh(np.dot(p, b)) + t2*p
    return p
where g is a vector of length n, b is an n x n matrix, dt is the number of iterations, and t1 and t2 are scalars.
I have quickly run out of ideas on how to optimise this further, because p is used within the loop in all three terms of the equation: when added to itself, in the dot product, and in a scalar multiplication.
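For reference, the update inside the loop can be regrouped as p_new = (1 + t2)*p + t1*tanh(p.b). Below is a minimal sketch of an in-place variant along those lines (the name f1_inplace and the scratch-buffer reuse are illustrative only, and I have not verified that it is actually faster than f1):

def f1_inplace(g, b, dt, t1, t2):
    # Same recurrence as f1, regrouped as p <- (1 + t2)*p + t1*tanh(p . b),
    # reusing a scratch buffer to cut down on temporary allocations.
    p = np.copy(g)
    tmp = np.empty_like(p)
    for _ in range(dt):
        np.dot(p, b, out=tmp)   # tmp = p . b   (uses the old p)
        np.tanh(tmp, out=tmp)   # tmp = tanh(p . b)
        p *= 1 + t2             # p = (1 + t2)*p
        tmp *= t1               # tmp = t1*tanh(p . b)
        p += tmp
    return p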
But maybe there is a different way to represent this function or there are other tricks to improve its efficiency. If possible, I would prefer not to use Cython
etc., but I'd be willing to use it if the speed improvements are significant. Thanks in advance, and apologies if the question is out of scope somehow.
The answers provided so far are more focused on what the values of the input/output could be in order to avoid unnecessary operations. I have now updated the MWE with proper initialisation values for the variables (I didn't expect the optimisation ideas to come from that side -- apologies). g will be in the range [-1, 1] and b will be in the range [-infinity, infinity]. Approximating the output is not an option, because the returned vectors are later given to an evaluation function and an approximation may return the same vector for fairly similar inputs.
import numpy as np
import timeit

iterations = 10000

setup = """
import numpy as np

n = 100
g = np.random.uniform(-1, 1, (n,))   # Updated.
b = np.random.uniform(-1, 1, (n, n)) # Updated.
dt = 10
t1 = 1
t2 = 1/2

def f1(g, b, dt, t1, t2):
    p = np.copy(g)
    for i in range(dt):
        p += t1*np.tanh(np.dot(p, b)) + t2*p
    return p
"""

functions = [
    """
p = f1(g, b, dt, t1, t2)
""",
]

if __name__ == '__main__':
    for function in functions:
        print(function)
        print('Time = {}'.format(timeit.timeit(function, setup=setup,
                                                number=iterations)))
Upvotes: 11
Views: 475
Reputation: 88158
I hesitate to give this as an answer, as I think it may be an artifact of the input data you gave us. Nevertheless, note that tanh(x) ~ 1 for x >> 1. With your input data, every time I have run it, x = np.dot(p, b) >> 1, hence we can replace f1 with f2.
def f1(g, b, dt, t1, t2):
    p = np.copy(g)
    for i in range(dt):
        p += t1*np.tanh(np.dot(p, b)) + t2*p
    return p

def f2(g, b, dt, t1, t2):
    p = np.copy(g)
    for i in range(dt):
        p += t1 + t2*p
    return p

print(np.allclose(f1(g, b, dt, t1, t2), f2(g, b, dt, t1, t2)))
This indeed shows that the two functions are numerically equivalent (for this input). Note that f2 is a non-homogeneous linear recurrence relation, and can be solved in one step if you choose to do so.
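For completeness, here is a sketch of that one-step solution (the name f2_closed is illustrative, and it assumes t2 != 0). Since f2's update is p_{k+1} = (1 + t2)*p_k + t1 with p_0 = g, the standard closed form of the recurrence gives p_dt = (1 + t2)**dt * g + t1*((1 + t2)**dt - 1)/t2:

def f2_closed(g, b, dt, t1, t2):
    # Closed form of p_{k+1} = (1 + t2)*p_k + t1, with p_0 = g.
    # b is unused; it is kept only so the signature matches f1/f2.
    r = (1 + t2)**dt
    return r*g + t1*(r - 1)/t2

np.allclose(f2(g, b, dt, t1, t2), f2_closed(g, b, dt, t1, t2)) should then hold whenever t2 != 0.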
Upvotes: 4
Reputation: 54340
To get the code running much faster without cython or jit will be very hard; some mathematical trickery may be the easier approach. It appears to me that if we define k(g, b) = f1(g, b, n+1, t1, t2)/f1(g, b, n, t1, t2) for positive integer n, the k function should have a limit of t1+t2 (I don't have a solid proof yet, just a gut feeling; it may also be a special case for E(g)=0 & E(p)=0). For t1=1 and t2=0.5, k() appears to approach the limit fairly quickly: for n>100 it is almost a constant of 1.5.
So I think a numerical approximation approach should be the easiest one.
In [81]:
t2=0.5
data=[f1(g, b, i+2, t1, t2)/f1(g, b, i+1, t1, t2) for i in range(1000)]
In [82]:
plt.figure(figsize=(10,5))
plt.plot(data[0], '.-', label='1')
plt.plot(data[4], '.-', label='5')
plt.plot(data[9], '.-', label='10')
plt.plot(data[49], '.-', label='50')
plt.plot(data[99], '.-', label='100')
plt.plot(data[999], '.-', label='1000')
plt.xlim(xmax=120)
plt.legend()
plt.savefig('limit.png')
In [83]:
data[999]
Out[83]:
array([ 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5,
1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5,
1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5,
1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5,
1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5,
1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5,
1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5,
1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5,
1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5, 1.5,
1.5])
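If an approximation were acceptable, the converged ratio could be exploited along these lines (a rough sketch only; the cut-off n0 and the element-wise scaling are assumptions, not measured code):

def f1_approx(g, b, dt, t1, t2, n0=100):
    # Iterate exactly for n0 steps, then scale by the (near-constant)
    # per-step ratio for the remaining dt - n0 steps.  Approximate only.
    if dt <= n0:
        return f1(g, b, dt, t1, t2)
    p = f1(g, b, n0, t1, t2)
    ratio = f1(g, b, n0 + 1, t1, t2) / p   # ~1.5 element-wise for these inputs
    return p * ratio**(dt - n0)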
Upvotes: 4