selfAlex

Reputation: 23

Optimizing a brute-force algorithm that counts pairs of numbers

How many pairs (i, j) exist such that 1 <= i <= j <= n and j - i <= a? 'n' and 'a' are input numbers.

The problem is that my algorithm is too slow as 'n' or 'a' grows.
I cannot think of a better algorithm.
Execution time must be less than 10 seconds.

Tests:


# input() returns strings, so convert both values to int before comparing them
n, a = map(int, input().split())

i = 1
j = 1

answer = 0

while True:
    if n >= j:
        if a >= (j-i):
            answer += 1

            j += 1

        else:
            i += 1
            j = i

            if j > n:
                break

    else:
        i += 1
        j = i

        if j > n:
            break

print(answer)
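For checking any faster method on small inputs, the definition can also be transcribed directly into a double loop. This is a minimal sketch; the name `count_pairs_brute` is hypothetical, and the loop is far too slow for large n, so it is only meant as a reference to test against.

```python
def count_pairs_brute(n: int, a: int) -> int:
    """Count pairs (i, j) with 1 <= i <= j <= n and j - i <= a, straight from the definition.

    Runs in O(n^2) time, so it is only suitable for small n (e.g. as a test oracle).
    """
    count = 0
    for i in range(1, n + 1):
        for j in range(i, n + 1):
            if j - i <= a:
                count += 1
    return count


print(count_pairs_brute(5, 2))  # 12
```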

Upvotes: 1

Views: 128

Answers (3)

Yogesh

Reputation: 801

One can derive a direct formula to solve this problem.

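ans = ((a+1)*a)//2 + (a+1) + (a+1)*(n-a-1)

This holds for a < n. When a >= n, every pair qualifies, so substitute a = n - 1 (which gives n(n+1)/2).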

Thus the time complexity is O(1). This is the fastest way to solve this problem.

Derivation:

Among the first a+1 numbers, every pair satisfies j - i <= a, which gives (a+1)C2 pairs with i < j plus (a+1) pairs with i = j. Every additional number j has a+1 options for i (j-a through j). There are n-a-1 remaining numbers, each with (a+1) options, giving (a+1)*(n-a-1).

Therefore the final answer is (a+1)C2 + (a+1) + (a+1)*(n-a-1), which expands to ((a+1)*a)/2 + (a+1) + (a+1)*(n-a-1).
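A minimal sketch of the closed form in Python, cross-checked against a brute-force count on small inputs. The function names are hypothetical, and clamping a to n - 1 is an added assumption so the formula also covers a >= n; integer division `//` keeps the result exact for large values.

```python
def count_pairs_formula(n: int, a: int) -> int:
    """O(1) closed form: (a+1)C2 + (a+1) + (a+1)*(n-a-1).

    The formula assumes a < n. Any a >= n - 1 allows every pair, so a is
    clamped to n - 1 first (an added assumption, not part of the original answer).
    """
    if n <= 0:
        return 0
    a = min(a, n - 1)
    # (a+1)*a is always even, so integer division is exact here.
    return (a + 1) * a // 2 + (a + 1) + (a + 1) * (n - a - 1)


def count_pairs_brute(n: int, a: int) -> int:
    """Small-input reference: count the pairs directly from the definition."""
    return sum(1 for i in range(1, n + 1) for j in range(i, n + 1) if j - i <= a)


# Cross-check the closed form against brute force on small inputs.
for n in range(0, 30):
    for a in range(0, 35):
        assert count_pairs_formula(n, a) == count_pairs_brute(n, a)

print(count_pairs_formula(898982, 40000))  # 35160158982
```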

Upvotes: 3

John Coleman

Reputation: 51998

You are using a quadratic algorithm but you should be able to get it to linear (or perhaps even better).

The main problem is to determine how many pairs there are, given n and a. It is natural to split that off into a function.

A key point is that, for a fixed i, the j that go with that i are in the range i to min(n,i+a). This is because j-i <= a is equivalent to j <= i+a.
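As a small illustration of that range (a sketch; `valid_js` is a hypothetical helper name), the valid j for a fixed i can be produced with `range(i, min(n, i + a) + 1)`:

```python
from typing import List


def valid_js(n: int, a: int, i: int) -> List[int]:
    """Return every j with i <= j <= n and j - i <= a, i.e. j in [i, min(n, i + a)]."""
    return list(range(i, min(n, i + a) + 1))


# With n = 5 and a = 2, small i get a + 1 = 3 partners,
# while i close to n are cut off by the upper bound n.
for i in range(1, 6):
    print(i, valid_js(5, 2, i))
```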

There are min(n,i+a) - i + 1 valid j for a given i. This leads to:

def count_pairs(n,a):
    return sum(min(n,i+a) - i + 1 for i in range(1,n+1))

count_pairs(898982,40000) evaluates to 35160158982 in about a second. If that is still too slow, do some more mathematical analysis.
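One way to carry out that further analysis (a sketch, assuming 0 <= a < n) is to split the sum at i = n - a. For i <= n - a the minimum is i + a, so each such i contributes a + 1; for the remaining a values of i the minimum is n, and the contributions are a, a - 1, ..., 1.

```latex
\sum_{i=1}^{n} \bigl(\min(n, i+a) - i + 1\bigr)
  = \sum_{i=1}^{n-a} (a+1) + \sum_{i=n-a+1}^{n} (n - i + 1)
  = (n-a)(a+1) + \frac{a(a+1)}{2}
```

For n = 898982 and a = 40000 this gives 858982 * 40001 + 800020000 = 35160158982, matching the value above, but in constant time.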

Upvotes: 3

Red

Reputation: 27567

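Here is an improvement: it converts the inputs to int and merges the two branches into a single chained comparison (the number of loop iterations is unchanged):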

n, a = map(int, input().split())

i = 1
j = 1

answer = 0

while True:
    if n >= j <= a + i:
        answer += 1
        j += 1
        continue
    i = j = i + 1
    if j > n:
        break
        
print(answer)
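The condition `n >= j <= a + i` relies on Python's comparison chaining: `x >= y <= z` is evaluated as `x >= y and y <= z`, with `y` evaluated only once. A small sketch checking that equivalence; `in_window` is a hypothetical name and the tested value ranges are arbitrary.

```python
import itertools


def in_window(n: int, a: int, i: int, j: int) -> bool:
    """True when j is a valid partner for i: j <= n and j - i <= a.

    Uses the same chained comparison as the loop above; Python evaluates
    `n >= j <= a + i` as `n >= j and j <= a + i`.
    """
    return n >= j <= a + i


# Check the chained form against the explicit `and` form on small values.
for n, a, i, j in itertools.product(range(5), repeat=4):
    assert in_window(n, a, i, j) == ((n >= j) and (j <= a + i))
```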

Upvotes: 1
