tamasgal

Reputation: 26309

Retrieve indices for rows of a PyTables table matching a condition using `Table.where()`

I need the indices (as a numpy array) of the rows matching a given condition in a table (with billions of rows). This is the line I currently use in my code; it works, but it is quite ugly:

indices = np.array([row.nrow for row in the_table.where("foo == 42")])

It also takes half a minute, and I'm sure that the list creation is one of the reasons why.

I could not find an elegant solution yet and I'm still struggling with the pytables docs, so does anybody know a magical way to do this more beautifully and maybe also a bit faster? Maybe there is a special query keyword I am missing, since I have the feeling that pytables should be able to return the matched rows' indices as a numpy array.

Upvotes: 2

Views: 1014

Answers (2)

And0k

Reputation: 74

tables.Table.get_where_list() gives the indices of the rows matching a given condition.
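
For illustration, a minimal sketch of how this could replace the list comprehension from the question; the file name and node path here are hypothetical placeholders:

import tables

# Hypothetical file and node path; adjust to your setup.
with tables.open_file("data.h5") as h5file:
    the_table = h5file.get_node("/the_table")
    # get_where_list() evaluates the condition with the same in-kernel
    # search as Table.where(), but returns the matching row coordinates
    # directly as a numpy integer array, without building a Python list.
    indices = the_table.get_where_list("foo == 42")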

Upvotes: 2

HYRY

Reputation: 97331

I read the source of pytables; where() is implemented in Cython, but it seems not fast enough. Here is a more involved method that can speed it up:

Create some data first:

from tables import *
import numpy as np

class Particle(IsDescription):
    name      = StringCol(16)   # 16-character String
    idnumber  = Int64Col()      # Signed 64-bit integer
    ADCcount  = UInt16Col()     # Unsigned short integer
    TDCcount  = UInt8Col()      # unsigned byte
    grid_i    = Int32Col()      # 32-bit integer
    grid_j    = Int32Col()      # 32-bit integer
    pressure  = Float32Col()    # float  (single-precision)
    energy    = Float64Col()    # double (double-precision)

h5file = open_file("tutorial1.h5", mode="w", title="Test file")
group = h5file.create_group("/", 'detector', 'Detector information')
table = h5file.create_table(group, 'readout', Particle, "Readout example")
particle = table.row
for i in range(1001000):
    particle['name']  = 'Particle: %6d' % (i)
    particle['TDCcount'] = i % 256
    particle['ADCcount'] = (i * 256) % (1 << 16)
    particle['grid_i'] = i
    particle['grid_j'] = 10 - i
    particle['pressure'] = float(i*i)
    particle['energy'] = float(particle['pressure'] ** 4)
    particle['idnumber'] = i * (2 ** 34)
    # Insert a new particle record
    particle.append()

table.flush()
h5file.close()

Read the column in chunks, append the matching indices to a list, and finally concatenate the list into a single array. You can change the chunk size according to your memory size:

h5file = open_file("tutorial1.h5")

table = h5file.get_node("/detector/readout")

size = 10000                 # chunk size; tune it to your memory
col = "energy"
buf = np.zeros(size, dtype=table.coldtypes[col])  # preallocated read buffer
res = []
for start in range(0, table.nrows, size):
    length = min(size, table.nrows - start)
    # Read one chunk of the column directly into the reusable buffer.
    data = table.read(start, start + length, field=col, out=buf[:length])
    tmp = np.where(data > 10000)[0]
    tmp += start             # convert chunk-local indices to global row indices
    res.append(tmp)
res = np.concatenate(res)
h5file.close()
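
Note that the out= argument of Table.read() fills the preallocated buffer in place, so the loop avoids allocating a fresh array for every chunk; most of the time should then be spent in the HDF5 reads and the vectorized comparison rather than in Python-level object creation.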

Upvotes: 0
