Tauquir

Reputation: 6913

How to optimize this Python problem?

Here is the problem, and here is my solution:

import sys
LIST_ITEM = []
NUMBER_OF_TEST_CASES = int(raw_input())
def non_decreasing(my_list):
    if len(my_list) < 2:
        return my_list
    my_list.sort()
    return my_list

if __name__ == "__main__":
    if NUMBER_OF_TEST_CASES >= 1000001 or NUMBER_OF_TEST_CASES <= 0:
        sys.exit()

    for val in range(1, NUMBER_OF_TEST_CASES+1):
        x = int(raw_input())

        if x >= 1000001 or x<0:
            sys.exit()
        else:
            LIST_ITEM.append(x)
    values =  non_decreasing(LIST_ITEM)
    for i in values:
        print i

But it tells me Time Limit Exceeded. Here is my solution link.

Upvotes: 1

Views: 541

Answers (3)

Shawn Chin

Reputation: 86924

Since the input values and the input size are both limited to 10^6, you can simply initialise an array of 10^6 + 1 values and keep track of which values have appeared. Sorting and duplicate detection "come for free".

This method has an initial cost (initialising the array), but it is worth it for large input sizes.

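Example:

import array
from sys import stdin
from itertools import repeat

low  = 1000001
high = -1
# one slot for every possible value 0..10^6; -1 means "not seen yet"
data = array.array('i', repeat(-1, 1000001))
count = int(stdin.readline().strip())

while count:
    v = int(stdin.readline().strip())
    count -= 1
    data[v] = v          # mark v as seen (a duplicate just overwrites it)
    low  = min(v, low)
    high = max(v, high)

# walk the seen range in order; test >= 0 so that the value 0 is printed too
for v in xrange(low, high+1):
    if data[v] >= 0:
        print v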

Note that I used array because the size and type are known in advance, so we can bypass the overhead that comes with using a list.
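A small Python 3 sketch of the memory side of that choice, using sys.getsizeof: a list stores one pointer per slot (8 bytes on a 64-bit build) referring to separate int objects, while array('i') stores the C ints inline (usually 4 bytes each). The exact figures depend on the platform and Python version; the function names are illustrative only.

```python
import array
import sys
from itertools import repeat

SLOTS = 10**6 + 1  # one slot per possible value 0..10^6


def list_bytes(n):
    """Size of a list of n slots (pointers only; the single -1 object is shared)."""
    return sys.getsizeof([-1] * n)


def array_bytes(n):
    """Size of an array('i') of n slots, including its inline C-int buffer."""
    return sys.getsizeof(array.array('i', repeat(-1, n)))


if __name__ == "__main__":
    print("list : %d bytes" % list_bytes(SLOTS))
    print("array: %d bytes" % array_bytes(SLOTS))
```

On a typical 64-bit CPython the array comes out at roughly half the size of the list, before even counting the int objects a list of distinct values would also need.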

If there is a limit on memory usage, one could use a bit array instead, which reduces the size of data but incurs some additional overhead (and complexity) when setting and iterating over the values.
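A minimal Python 3 sketch of that bit-array idea, assuming values in [0, 10^6] and that duplicates should be dropped (as in the answer above). It packs one flag per value into a bytearray, so the table takes about 125 KB instead of the roughly 4 MB of an array('i'); `set_bit` and `unique_sorted` are hypothetical helper names.

```python
MAX_VALUE = 10**6  # assumed inclusive upper bound on the input values


def set_bit(bits, v):
    # value v lives in byte v // 8, at bit position v % 8
    bits[v >> 3] |= 1 << (v & 7)


def unique_sorted(values, max_value=MAX_VALUE):
    """Return the distinct values in ascending order using a bit array."""
    bits = bytearray((max_value >> 3) + 1)  # one bit per value 0..max_value
    for v in values:
        set_bit(bits, v)
    result = []
    for byte_index, byte in enumerate(bits):
        if byte:  # skip whole bytes with no bits set
            base = byte_index << 3
            for bit in range(8):
                if byte & (1 << bit):
                    result.append(base + bit)
    return result


if __name__ == "__main__":
    print(unique_sorted([5, 3, 5, 0, 10**6, 3]))  # [0, 3, 5, 1000000]
```

Skipping zero bytes keeps the extra iteration cost down when the input is sparse; with dense input the inner bit loop is where the overhead mentioned above shows up.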

Upvotes: 1

chmullig

Reputation: 13416

EDIT: Well shit, apparently you aren't supposed to have dupes? That changes things, and your current code doesn't work right anyway. Leaving my old answer for now...

The simple answer is: profile it! I generated test data with:

import random

with open('foobar.txt', 'w') as out:
    total = random.randint(100000, 10**6)
    out.write('%s\n' % total)

    for x in xrange(total):
        out.write('%s\n' % random.randint(0, 10**6))

I then tested with this command:

time python -m cProfile -o foo.profile foo.py < foobar.txt > fooout.txt && gprof2dot -f pstats foo.profile | dot -Tpng -o foo_profile.png

It reports the running time (1.9s on my system with 266k input rows) and, via the gprof2dot tool, generates a nifty call-graph image of the profile. My gold standard is sort -n foobar.txt > foo_sorted.txt, at ~0.41s.
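If gprof2dot and Graphviz aren't available, the standard library's cProfile and pstats modules can print the same numbers as text. A minimal Python 3 sketch; `slow_sort_lines` is a hypothetical stand-in for the script being profiled, and an in-memory io.StringIO replaces the redirected input file.

```python
import cProfile
import io
import pstats
import random


def slow_sort_lines(stream):
    """Hypothetical stand-in for the solution being profiled."""
    count = int(stream.readline())
    values = [int(stream.readline()) for _ in range(count)]
    values.sort()
    return values


def profile_report(func, *args, top=5):
    """Run func(*args) under cProfile and return (result, text report)."""
    profiler = cProfile.Profile()
    result = profiler.runcall(func, *args)
    out = io.StringIO()
    # sort by cumulative time so the most expensive call paths come first
    pstats.Stats(profiler, stream=out).sort_stats("cumulative").print_stats(top)
    return result, out.getvalue()


if __name__ == "__main__":
    n = 100000
    data = "%d\n" % n + "".join("%d\n" % random.randint(0, 10**6) for _ in range(n))
    values, report = profile_report(slow_sort_lines, io.StringIO(data))
    print(report)
```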

The profile shows that 44.81% of your time is spent in your own code itself, 38.82% in raw_input, and 14% in sort.

[Image: profile output]

So next we start to optimize.

First, put your code into a function. Just wrap all of it in def main(), and add if __name__ == '__main__': main() at the end. For me that shaved the runtime down by about 5%, to 1.8s, and made raw_input the largest share of the load.
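The speed-up from wrapping the code in main() most likely comes from how CPython resolves names: variables inside a function are locals, accessed by index, while names in top-level code are looked up in a dictionary on every use. A small Python 3 sketch that runs the same loop both ways; the function names are illustrative, and the absolute timings will vary by machine and Python version.

```python
import time

LOOP_SRC = """
total = 0
for i in range(n):
    total += i
"""


def run_as_module_code(n):
    """Execute the loop as top-level code: every name is a dict lookup."""
    namespace = {"n": n}
    start = time.perf_counter()
    exec(LOOP_SRC, namespace)
    return namespace["total"], time.perf_counter() - start


def run_as_function(n):
    """The same loop inside a function: total, i and n are fast locals."""
    start = time.perf_counter()
    total = 0
    for i in range(n):
        total += i
    return total, time.perf_counter() - start


if __name__ == "__main__":
    n = 2_000_000
    t_mod = run_as_module_code(n)[1]
    t_fun = run_as_function(n)[1]
    print("module-level: %.3fs  function: %.3fs" % (t_mod, t_fun))
```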

Let's see if we can shave down that raw_input time. Perhaps replace raw_input with direct use of sys.stdin? I assume raw_input is designed for interactive use and probably isn't optimized for heavy use. By reading the first line with sys.stdin.readline() and then iterating over sys.stdin, we should get a more efficient code path. For me that brings the runtime down from 1.8s to 0.952s, roughly halving it! Here's the code now, and the profile output.

import sys 

def non_decreasing(my_list):
    if len(my_list) < 2:
        return my_list
    my_list.sort()
    return my_list

def main():
    LIST_ITEM = []
    NUMBER_OF_TEST_CASES = int(sys.stdin.readline().strip())

    if NUMBER_OF_TEST_CASES >= 1000001 or NUMBER_OF_TEST_CASES <= 0:
        sys.exit()

    for x in sys.stdin:
        x = int(x.strip())

        if x >= 1000001 or x<0:
            sys.exit()
        else:
            LIST_ITEM.append(x)
    values =  non_decreasing(LIST_ITEM)
    for i in values:
        print i

if __name__ == '__main__':
    main()

[Image: revised profile output]

So that's a good start: we're now at less than half our original runtime. Let's look at what's slow now: the main function itself, sort, strip() and append. Perhaps we can optimize something in main? I notice we're printing the lines one by one. Could we replace that with a single sys.stdout.write() call and see if that helps? I tried sys.stdout.writelines([str(x) for x in values]) and it actually seemed slower, so I guess print is quite efficient. (Note that writelines() doesn't add newlines, so a correct version would need something like ['%d\n' % x for x in values].) Let's stick with print.
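A common alternative, not something measured in this answer, is to build the whole output as one string with str.join and write it in a single call. A minimal Python 3 sketch; `format_output` and `write_sorted` are hypothetical helpers, and whether this beats print on a given judge would need its own timing.

```python
import sys


def format_output(values):
    """Join the numbers into one newline-terminated string."""
    # str.join builds the result in one pass instead of one write per number
    return "\n".join(map(str, values)) + "\n" if values else ""


def write_sorted(values, out=sys.stdout):
    out.write(format_output(sorted(values)))


if __name__ == "__main__":
    write_sorted([5, 3, 6, 7, 1])
```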

What else can we reduce? Maybe the if x >= 1000001 or x<0: statement? Is it really necessary? It looks like we can easily shave off a few hundredths of a second by removing it.

What else? Perhaps the whole non_decreasing function is unnecessary, and we can just call LIST_ITEM.sort()? I doubt your length check and the extra function call are speeding anything up. Yup, dropping them speeds it up a little bit!

Ideally, at this point we'd skip stripping the newlines from the input, sort the lines as strings and write them straight back out. Unfortunately, strings sort lexicographically ("10" comes before "9"), so that doesn't give the desired numeric order :( So let's try some alternatives:

  1. for x in sys.stdin: values.append(x[:-1])
  2. x.rstrip()
  3. x.rstrip('\n')
  4. values = sys.stdin.read().split('\n')
  5. values = sys.stdin.read().splitlines()
  6. values = sys.stdin.readlines()
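A rough Python 3 harness for comparing a few of these variants, assuming the input is available as text holding only the numbers (the count line already consumed); an io.StringIO stands in for sys.stdin, and the function names are hypothetical. Checking that every variant returns the same list matters as much as the timing: x[:-1] chops the last character of the final line when it has no trailing newline.

```python
import io
import timeit


def strip_last_char(stream):
    # variant 1: only correct if every line, including the last, ends in "\n"
    return [int(x[:-1]) for x in stream]


def rstrip_newline(stream):
    # variant 3
    return [int(x.rstrip("\n")) for x in stream]


def read_splitlines(stream):
    # variant 5
    return [int(x) for x in stream.read().splitlines()]


def read_lines(stream):
    # variant 6: int() itself ignores the trailing "\n"
    return [int(x) for x in stream.readlines()]


VARIANTS = [strip_last_char, rstrip_newline, read_splitlines, read_lines]


def compare(text, repeat=3):
    """Time each variant on the same text; raise if any result differs."""
    expected = read_lines(io.StringIO(text))
    timings = {}
    for fn in VARIANTS:
        if fn(io.StringIO(text)) != expected:
            raise ValueError("%s gave a different result" % fn.__name__)
        # best of `repeat` single runs, each on a fresh in-memory stream
        timings[fn.__name__] = min(
            timeit.repeat(lambda: fn(io.StringIO(text)), number=1, repeat=repeat))
    return timings


if __name__ == "__main__":
    import random
    sample = "".join("%d\n" % random.randint(0, 10**6) for _ in range(200000))
    for name, seconds in sorted(compare(sample).items(), key=lambda kv: kv[1]):
        print("%-16s %.3fs" % (name, seconds))
```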

In my testing, a variant on #1 is fastest while staying correct, at ~0.783s. Here's my final code:

import sys

def main():
    NUMBER_OF_TEST_CASES = int(sys.stdin.readline().strip())
    if NUMBER_OF_TEST_CASES >= 1000001 or NUMBER_OF_TEST_CASES <= 0:
        sys.exit()

    values = [int(x) for x in sys.stdin.readlines()]
    values.sort()
    for i in values:
        print i

if __name__ == '__main__':
    main()

And the final gprof2dot profile info:

[Image: final gprof2dot profile output]
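For Python 3 (where raw_input is renamed input and print is a function), a minimal sketch of the same final approach that also reads the whole input in one call and writes the output in one call. It assumes the first token is the count and keeps duplicates; reading sys.stdin.buffer skips text decoding, and `solve` is a hypothetical helper split out so it can be tested without real stdin.

```python
import sys


def solve(data):
    """Parse 'count, then count integers' and return them sorted, one per line."""
    tokens = data.split()  # any whitespace, including "\n" and "\r\n"
    if not tokens:
        return ""
    count = int(tokens[0])
    values = sorted(map(int, tokens[1:count + 1]))
    return "".join("%d\n" % v for v in values)


def main():
    # read all input as bytes in one call; int() accepts ASCII bytes directly
    sys.stdout.write(solve(sys.stdin.buffer.read()))


if __name__ == "__main__":
    main()
```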

Upvotes: 3

Andrew Jaffe

Reputation: 27097

Don't sort!

Just create an array of 10^6 + 1 zeros (in numpy, perhaps), read through the list setting a[i]=1 whenever you encounter the number i, and then print the index of every nonzero entry.

I think this works:

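import numpy as np
import sys

nn = 1000000                         # largest possible input value
a = np.zeros(nn+1, dtype=np.int8)    # a[i] becomes 1 once value i is seen

sys.stdin.readline()                 # skip the first line: it is the count, not a value
for l in sys.stdin:
    a[int(l)] = 1

for i in xrange(nn+1):
    if a[i]: print i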

I expect there's a way to speed up the i/o.
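One way to speed up the I/O, assuming NumPy is available: read all of stdin at once, convert the tokens in one pass, and let np.bincount do the counting instead of a Python-level loop over the array. Keeping counts rather than 0/1 flags also makes it easy to print duplicates if the problem requires them. This is a Python 3 sketch, not a tested submission, and `sorted_values` is a hypothetical helper.

```python
import sys

import numpy as np

MAX_VALUE = 10**6  # assumed inclusive upper bound on the input values


def sorted_values(data, keep_duplicates=True, max_value=MAX_VALUE):
    """Order 'count, then count integers' by counting, without a comparison sort."""
    tokens = data.split()
    if not tokens:
        return np.empty(0, dtype=np.int64)
    count = int(tokens[0])
    values = np.fromiter(map(int, tokens[1:count + 1]), dtype=np.int64)
    # counts[v] = how many times v occurred, for every v in 0..max_value
    counts = np.bincount(values, minlength=max_value + 1)
    if keep_duplicates:
        # repeat each index as often as it occurred, e.g. counts [1, 0, 2] -> [0, 2, 2]
        return np.repeat(np.arange(counts.size), counts)
    return np.flatnonzero(counts)  # each distinct value once, ascending


def main():
    result = sorted_values(sys.stdin.buffer.read())
    sys.stdout.write("".join("%d\n" % v for v in result.tolist()))


if __name__ == "__main__":
    main()
```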

Upvotes: 2
