Stella

Reputation: 69

optimize python code with basic libraries

I'm trying to do a non-equi self join with basic Python on a table that has 1.7 million rows and 4 variables. The data look like this:

product     position_min     position_max      count_pos
A.16        167804              167870              20
A.18        167804              167838              15
A.15        167896              167768              18
A.20        238359              238361              33
A.35        167835              167837              8

Here is the code I used:

import csv
from collections import defaultdict
import sys
import os

list_csv=[]
l=[]
with open(r'product.csv', 'r') as file1:
    my_reader1 = csv.reader(file1, delimiter=';')
    for row in my_reader1:
        list_csv.append(row)
with open(r'product.csv', 'r') as file2:
    my_reader2 = csv.reader(file2, delimiter=';') 
    with open('product_p.csv', "w") as csvfile_write:
        ecriture = csv.writer(csvfile_write, delimiter=';',
                                quotechar='"', quoting=csv.QUOTE_ALL)
        for row in my_reader2:
            res = defaultdict(list)
            for k in range(len(list_csv)):
                comp= list_csv[k]
                try:
                    if int(row[1]) >= int(comp[1]) and int(row[2]) <= int(comp[2]) and row[0] != comp[0]:
                        res[row[0]].append([comp[0],comp[3]]) 
                except:
                    pass
            


            if bool(res):    
                for key, value in res.items():
                    sublists = defaultdict(list)
                    for sublist in value:
                        l=[]
                        sublists[sublist[0]].append(int(sublist[1]))
                    l.append(str(key) + ";"+ str(min(sublists.keys(), key=(lambda k: sublists[k]))))
                    ecriture.writerow(l)

I should get this in the "product_p.csv" file:

'A.18'; 'A.16'
'A.15'; 'A.18'
'A.35'; 'A.18' 

What the code does is read the same file twice: the first time completely, converting it into a list, and the second time line by line. For each product (the 1st variable) it finds all the products it belongs to, based on the condition on position_min and position_max, and then keeps only one of them: the product with the minimum count_pos.

I tried it on a sample of the original data and it works, but with 1.7 million rows it runs for hours without producing any results. Is there a way to do this without loops, or with fewer of them? Could anyone help me optimize this with basic Python libraries?

Thank you in advance

Upvotes: 4

Views: 400

Answers (3)

Peter

Reputation: 81

Using an sqlite3 in-memory database, the search can be moved to B-tree indexes, which is more efficient than the approaches suggested so far. The following approach runs about 30 times faster than the original one: for a generated file of 2M rows it takes ~44 hours to compute a result for every item (versus ~1200 hours for the original approach).

import csv
import sqlite3
import sys
import time

with sqlite3.connect(':memory:') as con:
    cursor = con.cursor()
    cursor.execute('CREATE TABLE products (id integer PRIMARY KEY, product text, position_min int, position_max int, count_pos int)')
    cursor.execute('CREATE INDEX idx_products_main ON products(position_max, position_min, count_pos)')

    with open('product.csv', 'r') as products_file:
        reader = csv.reader(products_file, delimiter=';')
        # Omit parsing first row in file
        next(reader)

        for row in reader:
            row_id = row[0][len('A.'):] if row[0].startswith('A.') else row[0]
            cursor.execute('INSERT INTO products VALUES (?, ?, ?, ?, ?)', [row_id] + row)

    con.commit()

    with open('product_p.csv', 'wb') as write_file:
        with open('product.csv', 'r') as products_file:
            reader = csv.reader(products_file, delimiter=';')
            # Omit parsing first row in file
            next(reader)

            for row in reader:
                row_product_id, row_position_min, row_position_max, count_pos = row
                result_row = cursor.execute(
                    'SELECT product, count_pos FROM products WHERE position_min <= ? AND position_max >= ? ORDER BY count_pos, id LIMIT 1',
                    (row_position_min, row_position_max)
                ).fetchone()

                if (result_row and result_row[0] == row_product_id):
                    result_row = cursor.execute(
                        'SELECT product, count_pos FROM products WHERE product != ? AND position_min <= ? AND position_max >= ? ORDER BY count_pos, id LIMIT 1',
                        (row_product_id, row_position_min, row_position_max)
                    ).fetchone()

                if (result_row):
                    write_file.write(f'{row_product_id};{result_row[0]};{result_row[1]}\n'.encode())

Further optimisation can be done with threading if needed; with 10 threads, for instance, the whole process can be brought down to about 4-5 hours.
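
A minimal sketch of that threading idea could look as follows. It assumes the table and index above have been built in an on-disk file (products.db is an illustrative name) rather than in :memory:, so that each worker thread can open its own connection, and for brevity it always filters out the product itself in the query instead of re-querying as above; the chunking and thread count are likewise illustrative.

import csv
import sqlite3
from concurrent.futures import ThreadPoolExecutor

DB_PATH = 'products.db'   # assumed: same schema and index as above, built on disk
NUM_THREADS = 10

def process_chunk(rows):
    # Each thread uses its own connection; SQLite can serve
    # read-only queries from several connections at once.
    con = sqlite3.connect(DB_PATH)
    cursor = con.cursor()
    results = []
    for product, position_min, position_max, _count in rows:
        hit = cursor.execute(
            'SELECT product, count_pos FROM products '
            'WHERE product != ? AND position_min <= ? AND position_max >= ? '
            'ORDER BY count_pos, id LIMIT 1',
            (product, position_min, position_max)
        ).fetchone()
        if hit:
            results.append([product, hit[0], hit[1]])
    con.close()
    return results

with open('product.csv', 'r') as products_file:
    reader = csv.reader(products_file, delimiter=';')
    next(reader)  # skip the header row
    rows = list(reader)

# Split the rows into one chunk per thread
chunks = [rows[i::NUM_THREADS] for i in range(NUM_THREADS)]

with ThreadPoolExecutor(max_workers=NUM_THREADS) as pool:
    with open('product_p.csv', 'w', newline='') as write_file:
        writer = csv.writer(write_file, delimiter=';')
        for chunk_results in pool.map(process_chunk, chunks):
            writer.writerows(chunk_results)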

Upvotes: 0

gimix

Reputation: 3823

I think a different approach is needed here, because comparing each product with every other one will always give a time complexity of O(n^2).

I sorted the product list by ascending position_min (and descending position_max, just in case) and reversed the check from the answer above: instead of checking whether comp "contains" ref, I did the opposite. This way each product only needs to be checked against those with a higher position_min, and the search can stop as soon as a comp is found whose position_min is higher than ref's position_max.

To test this solution I generated a random list of 100 products and ran both a function copied from the answer above and one based on my suggestion. The latter executes about 1000 comparisons instead of 10000, and according to timeit it is about 4x faster despite the overhead of the initial sort.

Code follows:

##reference function
def f1(basedata):
    outd={}
    for ref in basedata:
        for comp in basedata:
            if ref == comp:
                continue
            elif ref[1] >= comp[1] and ref[2] <= comp[2]:
                if not outd.get(ref[0], False) or comp[3] < outd[ref[0]][1]:
                    outd[ref[0]] = (comp[0], comp[3])
    return outd

##optimized(?) function
def f2(basedata):
    outd={}
    # sort by ascending position_min, with descending position_max as tie-breaker
    sorteddata = sorted(basedata, key=lambda x:(x[1],-x[2]))
    runs = 0
    for i,ref in enumerate(sorteddata):
        toohigh=False
        j=i
        # only compare ref against products that come later in the sorted order,
        # i.e. those whose position_min is >= ref's position_min
        while j < len(sorteddata)-1 and not toohigh:
            j+=1
            runs+=1
            comp=sorteddata[j]
            if comp[1] > ref[2]:
                # comp starts after ref ends: no later product can be inside ref
                toohigh=True
            elif comp[2] <= ref[2]:
                # ref contains comp: record ref if it has the lowest count_pos so far
                if not outd.get(comp[0], False) or ref[3] < outd[comp[0]][1]:
                    outd[comp[0]] = (ref[0], ref[3])
    print(runs)
    return outd
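
A minimal sketch of the kind of timing harness described above (the random value ranges, seed, and repetition counts below are illustrative, and note that f2 also prints its comparison count on every call):

import random
import timeit

# build a random product list shaped like the data in the question
random.seed(0)
basedata = []
for i in range(100):
    start = random.randint(100_000, 200_000)
    basedata.append((f'A.{i}', start, start + random.randint(0, 100),
                     random.randint(1, 50)))

print('f1:', timeit.timeit(lambda: f1(basedata), number=10))
print('f2:', timeit.timeit(lambda: f2(basedata), number=10))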

Upvotes: 3

MatBBastos

Reputation: 401

I removed some libraries that were not used and tried to simplify the behavior of the code as much as I could.

The most important objects in the code are the list input_data, which stores the data from the input CSV file, and the dict out_dict, which stores the output of the comparisons.

Simply put, what the code does is:

  1. Reads product.csv (without headers) into a list input_data
  2. Iterates through input_data comparing each row to each other row
    • If the reference product range is within the comparing product range, we check a new condition: is there something in out_dict for the reference product?
      • If yes, we replace it with the new comparing product if it has a lower count_pos
      • If not, we add the comparing product regardless
  3. Writes the information in out_dict to the output file product_p.csv, but only for products that had valid comparing products

And here it is:

import csv

input_data = []
with open('product.csv', 'r') as csv_in:
    reader = csv.reader(csv_in, delimiter=';')
    next(reader)
    for row in reader:
        input_data.append(row)


out_dict = {}
for ref in input_data:
    for comp in input_data:
        if ref == comp:
            continue
        elif int(ref[1]) >= int(comp[1]) and int(ref[2]) <= int(comp[2]):
            if not out_dict.get(ref[0], False) or int(comp[3]) < out_dict[ref[0]][1]:
                out_dict[ref[0]] = (comp[0], int(comp[3]))
                # print(f"In '{ref[0]}': placed '{comp[0]}'")


with open('product_p.csv', "w") as csv_out:
    ecriture = csv.writer(csv_out, delimiter=';', quotechar='"', quoting=csv.QUOTE_ALL)
    for key, value in out_dict.items():
        if value[0]:
            ecriture.writerow([key, value[0]])

Also, I commented out a print line that can show you - using a sample file with only a few rows - what the script is doing.


Note: I believe your expected output is wrong. Either that or I'm missing something from the explanation. If that's the case, do tell me. The code presented takes this into account.

From the sample input:

product;position_min;position_max;count_pos
A.16;167804;167870;20
A.18;167804;167838;15
A.15;167896;167768;18
A.20;238359;238361;33
A.35;167835;167837;8

The expected output would be:

"A.18";"A.16"
"A.15";"A.35"
"A.35";"A.18"

Since, for "A.15", "A.35" satisfies the same conditions as "A.16" and "A.18" but has the smallest count_pos.

Upvotes: 1
