Dhyey Shah

Reputation: 25

Optimizing algorithm for finding palindromic primes in Python

I've been trying to solve a LeetCode problem which takes an input number (less than 10^8) and returns the next palindromic prime. Also, the answer is guaranteed to exist and is less than 2 * 10^8. My approach seems to work fine for most numbers, but the runtime increases significantly and LeetCode tells me I've exceeded the time limit when a specific number is entered (like 9989900). Is it because the gap between palindromic primes is large in that range? This is the code I've written.

    import time

    start = time.time()


    def is_prime(num: int) -> bool:
        if num < 2:
            return False
        elif num == 2 or num == 3:
            return True
        if num % 6 != 1 and num % 6 != 5:
            return False
        else:
            for i in range(3, int(num ** 0.5) + 1, 2):
                if num % i == 0:
                    return False
            else:
                return True


    def is_palindrome(num: int) -> bool:
        return str(num) == str(num)[::-1]


    class Solution:
        def primePalindrome(self, N: int):
            if N == 1:
                return 2
            elif 8 <= N < 11:
                return 11

            elif is_prime(N) and is_palindrome(N):
                return N

            # To skip even numbers, take 2 cases, i.e., when N is even and when N is odd
            elif N % 2 == 0:
                for i in range(N + 1, 2 * 10 ** 8, 2):
                    if len(str(i)) % 2 == 0:  # Because even length palindromes are divisible by 11
                        continue  # note: "i += 2" here would not skip anything, the for loop reassigns i
                    elif is_palindrome(i):
                        if is_prime(i):
                            return i
                        else:
                            continue

            else:
                for i in range(N, 2 * 10 ** 8, 2):
                    if len(str(i)) % 2 == 0:  # Even length palindromes are divisible by 11
                        continue
                    elif is_palindrome(i):
                        if is_prime(i):
                            return i
                        else:
                            continue


    obj = Solution()
    print(obj.primePalindrome(9989900))  # 100030001
    print(time.time() - start)  # 9 seconds

Is my solution slow because of too many loops and conditional statements? How do I reduce the runtime? Solving this without using any external libraries/packages would be preferable. Thank you.

Upvotes: 0

Views: 725

Answers (1)

Alain T.

Reputation: 42133

Given that checking primes/palindromes sequentially isn't fast enough, I thought of this "number assembly" approach:

Apart from 2 and 5, prime numbers can only end with the digits 1, 3, 7 or 9, so a palindromic prime with more than one digit must also begin with one of these digits. So, if we fix the first and last digit and generate only the palindromic digits in between, we get a lot fewer numbers to check for primality.

For example: 1xxxxxx1, 3xxxxxx3, 7xxxxxx7 and 9xxxxxx9
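A quick brute-force sketch confirms the first-digit restriction for every palindromic prime with two to five digits. It uses a standard sieve of Eratosthenes, which is not part of this answer; the names `sieve`, `pal_primes` and `first_digits` are illustrative only:

```python
def sieve(limit: int) -> list:
    """Sieve of Eratosthenes: is_p[k] is True iff k is prime, for 0 <= k < limit."""
    is_p = [True] * limit
    is_p[0] = is_p[1] = False
    for k in range(2, int(limit ** 0.5) + 1):
        if is_p[k]:
            # cross out every multiple of k starting at k*k
            is_p[k * k::k] = [False] * len(range(k * k, limit, k))
    return is_p


is_p = sieve(100_000)
# palindromic primes with 2 to 5 digits
pal_primes = [k for k in range(10, 100_000) if is_p[k] and str(k) == str(k)[::-1]]
first_digits = sorted({str(k)[0] for k in pal_primes})
print(first_digits)  # ['1', '3', '7', '9']
```

Any other leading digit would also be the last digit, making the number divisible by 2 or 5.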

These middle parts must also be palindromes, so we only have half the digits to consider: 1xxxyyy1, where yyy is a mirror of xxx. For an odd-sized middle we will have xxzyy, where zyy is a mirror of xxz (the centre digit z is shared by both halves).
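To make the mirroring concrete, here is a small sketch that rebuilds a full palindrome from just its left half, sharing the centre digit when the length is odd. The helper name `mirror` and its string-based interface are hypothetical, not taken from the answer's code:

```python
def mirror(left: str, odd: bool) -> int:
    """Build a palindrome from its left half.

    odd=True:  the last digit of `left` is the shared centre digit ("123" -> 12321).
    odd=False: the whole half is mirrored ("123" -> 123321).
    """
    right = left[:-1][::-1] if odd else left[::-1]
    return int(left + right)


print(mirror("1234", odd=True))   # 1234321
print(mirror("1234", odd=False))  # 12344321
```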

Combining this with a sequential generation of the first/last digits and digits in the middle, we can get the next number after N. By generating the most significant digits sequentially (i.e. the xxx part) we are certain that the composed numbers will be generated in an increasing sequence.
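The ordering claim can be checked in isolation. This is a simplified sketch, assuming an odd width and counting the whole left half (first digit included) rather than choosing the first digit separately as the answer does; `odd_palindromes` is a hypothetical name:

```python
from typing import Iterator


def odd_palindromes(width: int) -> Iterator[int]:
    """Yield every palindrome with `width` digits (width must be odd), in increasing order.

    The left half, centre digit included, is counted upwards. Those digits are the
    most significant ones, so each mirrored number is larger than the previous one.
    """
    if width % 2 == 0:
        raise ValueError("width must be odd")
    half = width // 2 + 1
    for left in range(10 ** (half - 1), 10 ** half):
        s = str(left)
        yield int(s + s[:-1][::-1])  # mirror all but the centre digit


pals = list(odd_palindromes(3))
print(pals[:5])              # [101, 111, 121, 131, 141]
print(pals == sorted(pals))  # True
```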

    def isPrime(n):
        # 2 is the only even prime; for odd n, trial division by odd d up to sqrt(n)
        return n==2 if n<3 or n%2==0 else all(n%d for d in range(3,int(n**0.5)+2,2))

    def nextPalPrime(N):
        digits = list(map(int,str(N)))
        while True:
            if digits[0] not in (1,3,7,9):              # advance first/last digits
                digits[0]  = [1,1,3,3,7,7,7,7,9,9][digits[0]]   # round up to 1, 3, 7 or 9
                digits[1:] = [0]*(len(digits)-1)
            digits[-1] = digits[0]
            midSize  = (len(digits)-1)//2
            midStart = int("".join(map(str,digits[1:1+midSize] or [0])))
            for middle in range(midStart,10**midSize):            # generate middle digits
                if midSize:
                    midDigits = list(map(int,f"{middle:0{midSize}}")) # left half of middle
                    digits[1:1+midSize]   = midDigits                 # set left half
                    digits[-midSize-1:-1] = midDigits[::-1]           # set mirrored right half
                number = int("".join(map(str,digits)))
                if number>N and isPrime(number):                  # check for prime
                    return number
            digits[0] += 1                                        # next first digit
            if digits[0] > 9: digits = [1]+[0]*len(digits)        # need more digits

output:

    pp = 1000
    for _ in range(20):
        pp = nextPalPrime(pp)
        print(pp)

    10301
    10501
    10601
    11311
    11411
    12421
    12721
    12821
    13331
    13831
    13931
    14341
    14741
    15451
    15551
    16061
    16361
    16561
    16661
    17471

Performance:

    from time import time
    start=time()
    print(nextPalPrime(9989900),time()-start)

    100030001 0.023847103118896484

No even number of digits

Initially I was surprised that the solutions never produced a prime number with an even number of digits, but analyzing the composition of palindrome numbers I realized that those are always multiples of 11 (so, apart from 11 itself, not prime):

    abba     = a*1001   + b*110
             = a*11*91  + b*11*10
             = 11*(a*91 + b*10)

    abccba   = a*100001   + b*10010   + c*1100
             = a*11*9091  + b*11*910  + c*11*100
             = 11*(a*9091 + b*910     + c*100)

    abcddcba = a*10000001   + b*1000010   + c*100100  + d*11000
             = a*11*909091  + b*11*90910  + c*11*9100 + d*11*1000
             = 11*(a*909091 + b*90910     + c*9100    + d*1000)

    abcdeedcba = a*1000000001   + b*100000010  + c*10000100  + d*1001000  + e*110000
               = a*11*90909091  + b*11*9090910 + c*11*909100 + d*11*91000 + e*11*10000
               = 11*(a*90909091 + b*9090910    + c*909100    + d*91000    + e*10000)
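The pattern behind these expansions is that 10 ≡ -1 (mod 11), so 10^i + 10^j is a multiple of 11 whenever i and j have opposite parity. In an even-length palindrome each digit sits at two positions whose exponents add up to width - 1, an odd number, so every pair contributes a multiple of 11. A brute-force sketch (the helper `even_palindromes` is illustrative, not from the answer) confirms this for every even-length palindrome up to 8 digits:

```python
def even_palindromes(width: int):
    """Yield every palindrome with `width` digits, for an even width."""
    half = width // 2
    for left in range(10 ** (half - 1), 10 ** half):
        s = str(left)
        yield int(s + s[::-1])  # mirror the whole left half


checked = 0
for width in (2, 4, 6, 8):
    for p in even_palindromes(width):
        assert p % 11 == 0, p  # every even-length palindrome is a multiple of 11
        checked += 1
print(checked)  # 9 + 90 + 900 + 9000 = 9999
```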

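Using this observation to skip even widths entirely, and computing each palindrome arithmetically from precomputed place-value factors instead of joining strings, we get a nice performance boost: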

    def nextPalPrime(N):
        for width in range(len(str(N)),10):             # widths up to 9 digits
            if width%2==0: continue                     # even widths are multiples of 11
            size = width//2
            # combined place value of each mirrored digit pair, centre digit last
            factors = [(100**(size-n)+1)*10**n for n in range(size)]+[10**size]
            for firstDigit in (1,3,7,9):
                if (firstDigit+1)*factors[0]<N: continue   # all numbers with this first digit are below N
                for middle in range(10**size):
                    digits = [firstDigit]+[*map(int,f"{middle:0{size}}")]
                    number = sum(f*d for f,d in zip(factors,digits))
                    if number>N and isPrime(number):
                        return number
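To see what `factors` holds, here is a small sketch that rebuilds it; the function name `place_factors` is hypothetical, while the list comprehension is copied from the code above. Each entry is the combined place value of a digit and its mirror, 10**(width-1-n) + 10**n, with the unpaired centre digit last:

```python
from typing import List


def place_factors(width: int) -> List[int]:
    """Rebuild the `factors` list for an odd `width`.

    Entry n (n < width//2) is 10**(width-1-n) + 10**n, the combined place value of
    digit n and its mirror; the last entry, 10**(width//2), is the centre digit.
    """
    size = width // 2
    return [(100 ** (size - n) + 1) * 10 ** n for n in range(size)] + [10 ** size]


factors = place_factors(5)
print(factors)  # [10001, 1010, 100]

# First digit 1, then the middle digits 2 and 3 (the centre digit comes last):
digits = [1, 2, 3]
print(sum(f * d for f, d in zip(factors, digits)))  # 1*10001 + 2*1010 + 3*100 = 12321
```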

    from time import time
    start=time()
    print(nextPalPrime(9989900),time()-start)

    100030001 0.004210948944091797
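One difference from the question is worth flagging: both `nextPalPrime` versions return a number strictly greater than `N` and, because they only use first digits 1, 3, 7 and 9, never return 2 or 5; for example `nextPalPrime(11)` returns 101, not 11. If N itself should count when it is a palindromic prime (as the question's own code assumes), a wrapper along these lines could handle the small cases separately. This is a hypothetical sketch with a plain trial-division test, not the answer's code:

```python
def is_prime(n: int) -> bool:
    """Trial division by 2 and then odd numbers up to sqrt(n)."""
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    d = 3
    while d * d <= n:
        if n % d == 0:
            return False
        d += 2
    return True


def prime_palindrome(n: int) -> int:
    """Smallest palindromic prime >= n (hypothetical inclusive variant)."""
    for p in (2, 3, 5, 7, 11):  # single-digit answers and 11 handled directly
        if p >= n:
            return p
    width = len(str(n))
    if width % 2 == 0:
        width += 1  # even widths only hold multiples of 11
    while True:
        half = width // 2 + 1
        # counting the left half upwards yields this width's palindromes in order
        for left in range(10 ** (half - 1), 10 ** half):
            s = str(left)
            p = int(s + s[:-1][::-1])
            if p >= n and is_prime(p):
                return p
        width += 2


print(prime_palindrome(11))       # 11
print(prime_palindrome(9989900))  # 100030001
```

Counting the whole left half, instead of fixing the first digit to 1, 3, 7 or 9, keeps the sketch short at the cost of some extra (cheap) primality rejections.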

Upvotes: 2
