Reputation: 7433
The task is to find the sum of the series 1a + 2a^2 + 3a^3 + ... + na^n, given n and a. For this series, we can find the n-th element with the following formula (from observation):
n-th element = a^n * ((n-(n-2))/(n-(n-1))) * ((n-(n-3))/(n-(n-2))) * ... * (n/(n-1))
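Written out, the first factor is 2/1, the second is 3/2, and so on. The product therefore telescopes, because each numerator cancels the following denominator:

$$a^n \cdot \frac{2}{1}\cdot\frac{3}{2}\cdots\frac{n}{n-1} = a^n \prod_{k=2}^{n}\frac{k}{k-1} = n\,a^n$$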
I think it's impossible to turn the above formula into a closed-form formula for the sum of the first n elements. Even if it is possible, I assume it will involve the exponent n, which would introduce an n-iteration loop and so keep the solution from being O(log n). The best solution I have is to use the ratio between consecutive elements, which is a(k+1)/k, and apply it to the k-th element to get the (k+1)-th element.
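For reference, here is a minimal sketch of the ratio-based approach just described, assuming a and n are nonnegative integers (the function name `sum_by_ratio` is hypothetical). Each term is derived from the previous one with the ratio a(k+1)/k, so the loop is O(n), which is exactly the limitation the question is about:

```python
def sum_by_ratio(a: int, n: int) -> int:
    """Sum 1*a + 2*a^2 + ... + n*a^n in O(n) by stepping from term to term."""
    if n <= 0:
        return 0  # empty sum
    term = a      # first element: 1 * a^1
    total = term
    for k in range(1, n):
        # term is k * a^k here; multiplying by a*(k+1)/k gives (k+1) * a^(k+1).
        # The integer division is exact because term is a multiple of k.
        term = term * a * (k + 1) // k
        total += term
    return total

print(sum_by_ratio(2, 5))  # 258
```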
I think that I may be missing something. Could someone provide me with solution(s)?
Upvotes: 1
Views: 1022
Reputation: 44957
I assume a, n are nonnegative integers. The explicit formula for a > 1 is
a * (n * a^{n + 1} - (n + 1) * a^n + 1) / (a - 1)^2
It can be evaluated efficiently in O(log(n)) using square and multiply for a^n.
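As a quick worked check (numbers chosen only for illustration), take a = 2 and n = 3:

$$2 \cdot \frac{3 \cdot 2^{4} - 4 \cdot 2^{3} + 1}{(2-1)^2} = 2 \cdot (48 - 32 + 1) = 34 = 1 \cdot 2 + 2 \cdot 4 + 3 \cdot 8$$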
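To derive the formula, you could use the following ingredients: linearity of the sum, the fact that i * a^(i-1) is the derivative of a^i with respect to a, and the formula for a geometric series, which requires a ≠ 1. Now you can simply calculate: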
sum_{i = 1}^n i * a^i // [0] ugly sum
= a * sum_{i = 1}^n i * a^{i-1} // [1] linearity
= a * d/da (sum_{i = 1}^n a^i) // [2] antiderivative
= a * d/da (sum_{i = 0}^n a^i - 1) // [3] + 1 - 1
= a * d/da ((a^{n + 1} - 1) / (a - 1) - 1) // [4] geom. series
= a * ((n + 1)*a^n / (a - 1) - (a^{n+1} - 1)/(a - 1)^2) // [5] derivative
= a * (n * a^{n + 1} - (n + 1)a^n + 1) / (a - 1)^2 // [6] explicit formula
This is just a simple arithmetic expression with a^n, which can be evaluated in O(log(n)) time using square-and-multiply.
The formula divides by (a - 1)^2, so it doesn't work for a = 1; in that case you return n * (n + 1) / 2 instead. For a = 0 the formula already evaluates to 0, but returning 0 immediately is a harmless shortcut.
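Step [4] (the geometric series) is where a ≠ 1 is needed, so the closed form [6] should also agree with the original sum [0] for non-integer and negative a. A quick sketch that compares both sides in exact rational arithmetic with Python's `fractions.Fraction` (the helper names `brute_sum` and `closed_form` are hypothetical):

```python
from fractions import Fraction

def brute_sum(a, n):
    # [0]: sum_{i=1}^n i * a^i, evaluated term by term
    return sum(i * a**i for i in range(1, n + 1))

def closed_form(a, n):
    # [6]: a * (n * a^(n+1) - (n+1) * a^n + 1) / (a - 1)^2, requires a != 1
    return a * (n * a**(n + 1) - (n + 1) * a**n + 1) / (a - 1) ** 2

# Exact rational inputs: fractions, negatives and integers (a = 1 excluded).
for a in [Fraction(1, 3), Fraction(-2), Fraction(5, 2), Fraction(7)]:
    for n in range(8):
        assert brute_sum(a, n) == closed_form(a, n), (a, n)
print("closed form matches the brute-force sum")
```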
Scala snippet to test the formula:
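def fast(a: Int, n: Int): Int = {
  // a^n by square-and-multiply: O(log n) multiplications
  // (Int arithmetic overflows for large a and n)
  def pow(a: Int, n: Int): Int =
    if (n == 0) 1
    else if (n == 1) a
    else {
      val r = pow(a, n / 2)
      if (n % 2 == 0) r * r else r * r * a
    }
  if (a == 0) 0                     // shortcut: every term is 0
  else if (a == 1) n * (n + 1) / 2  // the formula would divide by zero here
  else {
    val aPowN = pow(a, n)
    val d = a - 1
    // a * (n * a^(n+1) - (n+1) * a^n + 1) / (a - 1)^2
    a * (n * aPowN * a - (n + 1) * aPowN + 1) / (d * d)
  }
}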
Slower, but simpler version, for comparison:
def slow(a: Int, n: Int): Int = {
  def slowPow(a: Int, n: Int): Int = if (n == 0) 1 else slowPow(a, n - 1) * a
  (1 to n).map(i => i * slowPow(a, i)).sum
}
Comparison:
for (a <- 0 to 5; n <- 0 to 5) {
  println(s"${slow(a, n)} <-> ${fast(a, n)}")
}
Output:
0 <-> 0
0 <-> 0
0 <-> 0
0 <-> 0
0 <-> 0
0 <-> 0
0 <-> 0
1 <-> 1
3 <-> 3
6 <-> 6
10 <-> 10
15 <-> 15
0 <-> 0
2 <-> 2
10 <-> 10
34 <-> 34
98 <-> 98
258 <-> 258
0 <-> 0
3 <-> 3
21 <-> 21
102 <-> 102
426 <-> 426
1641 <-> 1641
0 <-> 0
4 <-> 4
36 <-> 36
228 <-> 228
1252 <-> 1252
6372 <-> 6372
0 <-> 0
5 <-> 5
55 <-> 55
430 <-> 430
2930 <-> 2930
18555 <-> 18555
So, yes, the O(log(n)) formula gives the same numbers as the O(n^2) brute-force version.
Upvotes: 4
Reputation: 59263
You can solve this problem, and lots of problems like it, with matrix exponentiation.
Let's start with this sequence:
A[n] = a + a^2 + a^3 + ... + a^n
That sequence can be generated with a simple formula:
A[i] = a*(A[i-1] + 1)
Now if we consider your sequence:
B[n] = a + 2a^2 + 3a^3 + ... + na^n
We can generate that with a formula that makes use of the previous one:
B[i] = (B[i-1] + A[i-1] + 1) * a
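Before turning these recurrences into a matrix, they can be sanity-checked with a plain O(n) loop. A minimal sketch, assuming integer a and starting from A[0] = B[0] = 0 (empty sums); `sums_by_recurrence` is a hypothetical name:

```python
def sums_by_recurrence(a: int, n: int) -> tuple[int, int]:
    """Return (B[n], A[n]) by iterating the two recurrences from A[0] = B[0] = 0."""
    A, B = 0, 0
    for _ in range(n):
        # Both updates must use the previous A, so they are computed simultaneously.
        B, A = (B + A + 1) * a, a * (A + 1)
    return B, A

print(sums_by_recurrence(2, 3))  # (34, 14)
```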
If we make a sequence of vectors containing all the components we need:
V[n] = (B[n], A[n], 1)
Then we can construct a matrix M so that:
V[i] = M*V[i-1]
And so:
V[n] = (M^(n-1))V[1]
Since the size of the matrix is fixed at 3x3, you can use exponentiation by squaring on the matrix itself to calculate M^(n-1) in O(log n) time, and the final matrix-vector multiplication takes constant time.
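Written out, with V[1] = (a, a, 1), the matrix M and one step of V[i] = M*V[i-1] look like this:

$$M = \begin{pmatrix} a & a & a \\ 0 & a & a \\ 0 & 0 & 1 \end{pmatrix}, \qquad M \begin{pmatrix} a \\ a \\ 1 \end{pmatrix} = \begin{pmatrix} a^2 + a^2 + a \\ a^2 + a \\ 1 \end{pmatrix} = \begin{pmatrix} a + 2a^2 \\ a + a^2 \\ 1 \end{pmatrix} = V[2]$$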
Here's an implementation in Python with numpy (so I don't have to include matrix-multiply code):
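import numpy as np

def getSum(a, n):
    # A[n] = a + a^2 + a^3 + ... + a^n
    # B[n] = a + 2a^2 + 3a^3 + ... + na^n
    # V[n] = [B[n], A[n], 1]
    if n <= 0:
        return 0  # empty sum
    M = np.array([
        [a, a, a],  # B[i] = B[i-1]*a + A[i-1]*a + a
        [0, a, a],  # A[i] = A[i-1]*a + a
        [0, 0, 1],  # the constant 1 stays 1
    ], dtype=np.int64)  # note: int64 arithmetic silently wraps for large results
    # calculate MsupN = M^(n-1) by exponentiation by squaring
    n -= 1
    MsupN = np.identity(3, dtype=np.int64)
    while n > 0:
        if n % 2 == 1:
            MsupN = MsupN @ M
        M = M @ M
        n //= 2  # integer division; n / 2 would turn n into a float
    # calculate V[n] = MsupN * V[1], where V[1] = [a, a, 1]
    Vn = MsupN @ np.array([a, a, 1], dtype=np.int64)
    return int(Vn[0])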
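One caveat with the numpy version: with a fixed-width integer dtype such as int64, the entries of M^(n-1) silently wrap around once the sum exceeds the dtype's range. A numpy-free sketch with plain Python lists keeps exact big integers; `mat_mul` and `get_sum_exact` are hypothetical names, and the algorithm is the same exponentiation by squaring on the 3x3 matrix:

```python
def mat_mul(X, Y):
    """Multiply two 3x3 matrices given as lists of lists."""
    return [[sum(X[i][k] * Y[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]

def get_sum_exact(a: int, n: int) -> int:
    """a + 2a^2 + ... + n*a^n via V[n] = M^(n-1) V[1], using exact Python ints."""
    if n <= 0:
        return 0  # empty sum
    M = [[a, a, a],   # B[i] = B[i-1]*a + A[i-1]*a + a
         [0, a, a],   # A[i] = A[i-1]*a + a
         [0, 0, 1]]   # constant 1
    P = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]  # identity; becomes M^(n-1)
    e = n - 1
    while e > 0:      # exponentiation by squaring
        if e % 2 == 1:
            P = mat_mul(P, M)
        M = mat_mul(M, M)
        e //= 2
    V1 = [a, a, 1]    # V[1] = (B[1], A[1], 1)
    return sum(P[0][k] * V1[k] for k in range(3))  # first component: B[n]

print(get_sum_exact(5, 5))  # 18555
```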
Upvotes: 9
Reputation: 28157
a^n can indeed be computed in O(log n).
The method is called exponentiation by squaring, and the main idea is that if you know a^n, you also know a^(2*n), which is just a^n * a^n.
So if you want to compute a^n and n is even, you can compute a^(n/2) and multiply the result by itself: a^n = a^(n/2) * a^(n/2). If n is odd, use a^n = a^(n-1) * a, where n - 1 is even. So instead of having a loop up to n, you now only have a loop up to n/2. But n/2 is just another number and can be computed the same way, again halving the work. Halving the exponent at every step leads to logarithmic complexity.
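A minimal recursive sketch of that idea, for integer a and n ≥ 0 (the name `pow_by_squaring` is hypothetical); the odd case keeps one extra factor of a so the exponent can still be halved:

```python
def pow_by_squaring(a: int, n: int) -> int:
    """Compute a^n with O(log n) multiplications."""
    if n == 0:
        return 1
    half = pow_by_squaring(a, n // 2)  # a^(n//2), computed once and reused
    if n % 2 == 0:
        return half * half             # a^n = a^(n/2) * a^(n/2)
    return half * half * a             # odd n: a^n = (a^((n-1)/2))^2 * a

print(pow_by_squaring(3, 13))  # 1594323
```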
As mentioned by @Sopel in the comments, the series can be written as a simple closed-form function (valid for a ≠ 1):
f(a, n) = a * (n * a^(n+1) - (n+1) * a^n + 1) / (a - 1)^2
So to find the answer, you only have to evaluate this formula, using the fast exponentiation described above, to do it in O(log n) time.
Upvotes: 2