Reputation: 6338
Suppose I have a pandas dataframe of shape (2,500,000, M) and a SciPy CSR sparse matrix of shape (2,500,000, N).
Each row of the dataframe and of the sparse matrix describes one entity. They are already ordered so that row 1 of the dataframe describes the same entity as row 1 of the sparse matrix. The dataframe has a fast mechanism for filtering (catalogue.where(catalogue.some_column != '')), but how do I find the corresponding rows in the sparse matrix given the filtered dataframe?
Assume the dataframe is called catalogue and the sparse matrix is called collection.
import functools
import pickle
from multiprocessing.pool import ThreadPool  # assumed source of ThreadPool

import scipy.sparse

def collection_filter_row(catalogue_filtered, catalogue_index_full, collection):
    # Fetch each filtered row separately, then stack the one-row results.
    return scipy.sparse.vstack(ThreadPool(100).map(
        functools.partial(collection_get_row,
                          catalogue_index=tuple(catalogue_index_full),
                          collection=collection),
        tuple(catalogue_filtered.index.values)))

def collection_get_row(document_id, catalogue_index, collection):
    # Find the position of document_id in the full index, then fetch that row.
    return collection.getrow(catalogue_index.index(document_id))

collection_partial = functools.partial(
    collection_filter_row,
    catalogue_index_full=catalogue.index.values,
    collection=pickle.load(open('collection-tfidf', 'rb')))
criteria = catalogue['criteria'].where(catalogue.criteria != '')
collection_state = collection_partial(criteria)
But even with some form of concurrency (gevent, thread pool), picking out the respective rows is still slow. Am I doing anything wrong, or rather, is there a faster way of doing this?
Upvotes: 1
Views: 992
Reputation: 6338
I found a faster way to solve this problem. Start by creating a dictionary that maps each catalogue index label to its row position in collection.
index_dict = dict(zip(
    catalogue.index.values.tolist(),
    range(collection.shape[0])))
Then my collection_filter_row becomes
def collection_filter_row(catalogue_filtered, index_dict, collection):
    return collection[[index_dict[document_id]
                       for document_id
                       in catalogue_filtered.index.values.tolist()]]
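To make the lookup concrete, here is a minimal sketch on toy data. The labels, column values, and matrix contents are invented for illustration and are not from the original data set; only the dictionary lookup and the single fancy-indexing call follow the approach described here.

```python
import numpy as np
import pandas as pd
import scipy.sparse

# Hypothetical toy data: 4 entities with non-default index labels.
catalogue = pd.DataFrame(
    {"some_column": ["a", "", "c", ""]},
    index=["doc10", "doc11", "doc12", "doc13"])
collection = scipy.sparse.csr_matrix(np.arange(12).reshape(4, 3))

# Map each catalogue label to its row position in the sparse matrix.
index_dict = dict(zip(catalogue.index.values.tolist(),
                      range(collection.shape[0])))

def collection_filter_row(catalogue_filtered, index_dict, collection):
    # Build one list of row positions, then index the CSR matrix once.
    rows = [index_dict[document_id]
            for document_id in catalogue_filtered.index.values.tolist()]
    return collection[rows]

# Keep only the rows whose some_column is non-empty.
subset = collection_filter_row(
    catalogue.loc[catalogue.some_column != ""], index_dict, collection)
dense_subset = subset.toarray()
```

Here the filtered frame keeps the labels doc10 and doc12, so the result holds rows 0 and 2 of the toy matrix.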
To return a subset of collection, I should really be using catalogue.loc[catalogue.some_column != ''] instead of catalogue.where(), since where() keeps every row and only masks the non-matching values. The proper call to collection_filter_row is then
collection_sub = collection_filter_row(
    catalogue.loc[catalogue.some_column != ''],
    index_dict,
    collection)
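The difference between where() and .loc[] is easy to check on a small frame. The sketch below uses made-up data and only illustrates that where() preserves the shape while .loc[] actually drops rows.

```python
import pandas as pd

# Hypothetical toy data, not from the original catalogue.
catalogue = pd.DataFrame({"some_column": ["a", "", "c", ""]})

# where() keeps all rows and replaces non-matching values with NaN.
masked = catalogue.where(catalogue.some_column != "")

# .loc[] with a boolean mask keeps only the matching rows.
selected = catalogue.loc[catalogue.some_column != ""]

masked_len = len(masked)
selected_index = selected.index.tolist()
```

Because the index of the where() result still contains every label, passing it to collection_filter_row would select every row of collection.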
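This dictionary-based version is much faster than the original method shown in the question, because each lookup is a constant-time dictionary access rather than a linear search through the full index.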
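Since the question states that the rows of catalogue and collection are already aligned by position, one possible alternative is to skip the dictionary and convert the boolean filter directly into row positions. This is a different technique from the one above, sketched on invented toy data, and it only holds if the positional alignment assumption is true.

```python
import numpy as np
import pandas as pd
import scipy.sparse

# Hypothetical toy data: row i of catalogue corresponds to row i of collection.
catalogue = pd.DataFrame(
    {"some_column": ["a", "", "c", ""]},
    index=["doc10", "doc11", "doc12", "doc13"])
collection = scipy.sparse.csr_matrix(np.arange(12).reshape(4, 3))

# Boolean filter as a plain NumPy array, independent of the index labels.
mask = (catalogue.some_column != "").to_numpy()

# Integer positions of the rows that pass the filter.
positions = np.flatnonzero(mask)

# One fancy-indexing call on the CSR matrix.
collection_sub = collection[positions]
```

The same mask can be applied to the dataframe with catalogue.loc[mask], so both subsets stay aligned row for row.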
Upvotes: 1