Ashutosh Baheti

Reputation: 420

How to store multiple float arrays efficiently in a file using python

I am trying to extract embeddings from a hidden layer of an LSTM. I have a huge dataset with multiple sentences, each of which will generate a numpy vector. I want to store all those vectors efficiently in a single file. This is what I have so far:

with open(src_vectors_save_file, "wb") as s_writer, open(tgt_vectors_save_file, "wb") as t_writer:
    for batch in data_iter:
        encoder_hidden_layer, decoder_hidden_layer = self.extract_lstm_hidden_states_for_batch(
            batch, data.src_vocabs, attn_debug
        )
        encoder_hidden_layer = encoder_hidden_layer.detach().numpy()
        decoder_hidden_layer = decoder_hidden_layer.detach().numpy()

        enc_hidden_bytes = pickle.dumps(encoder_hidden_layer)
        dec_hidden_bytes = pickle.dumps(decoder_hidden_layer)

        s_writer.write(enc_hidden_bytes)
        s_writer.write("\n")
        t_writer.write(dec_hidden_bytes)
        t_writer.write("\n")

Essentially I am using pickle to serialize each np.array to bytes and writing those bytes to a binary file. I naively tried to separate each byte-encoded array with an ASCII newline, which obviously throws an error (you cannot write a str to a file opened in binary mode). I was planning to use the .readlines() function, or read each byte-encoded array per line with a for loop, in the next program. However, that won't be possible now.

I am out of ideas. Can someone suggest an alternative? How can I efficiently store all the arrays in a compressed fashion in one file, and how can I read them back from that file?

Upvotes: 0

Views: 348

Answers (1)

Ajay Brahmakshatriya

Reputation: 9203

There is a problem with using \n as a separator: the dump from pickle (enc_hidden_bytes) can itself contain \n bytes, because the data is not ASCII-encoded.

There are two solutions. You could escape every \n appearing in the data and then use \n as a terminator, but this adds complexity on both the write and the read side.

The other solution is to write the size of the data into the file before the data itself. This acts as a length-prefix header and is a very common practice when sending data over a connection.

You can write the following two functions:

import struct

def write_bytes(handle, data):
    # Prefix the payload with its length as an 8-byte big-endian unsigned int
    total_bytes = len(data)
    handle.write(struct.pack(">Q", total_bytes))
    handle.write(data)

def read_bytes(handle):
    # Read the 8-byte length prefix; an empty read means end of file
    size_bytes = handle.read(8)
    if len(size_bytes) == 0:
        return None
    total_bytes = struct.unpack(">Q", size_bytes)[0]
    return handle.read(total_bytes)

Now you can replace

s_writer.write(enc_hidden_bytes)
s_writer.write("\n")

with

write_bytes(s_writer, enc_hidden_bytes)

and likewise for the decoder bytes written to t_writer.

When reading back from the file, call read_bytes in a loop until it returns None, and pass each returned chunk to pickle.loads to recover the array.
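For completeness, here is a round-trip sketch of that read loop. It uses an in-memory io.BytesIO buffer in place of the actual files, and redefines the two helper functions above so the snippet is self-contained; the arrays are placeholders for the real hidden-layer vectors.

```python
import io
import pickle
import struct

import numpy as np

def write_bytes(handle, data):
    # Length-prefix the payload with an 8-byte big-endian unsigned int
    handle.write(struct.pack(">Q", len(data)))
    handle.write(data)

def read_bytes(handle):
    # An empty read of the length prefix means end of file
    size_bytes = handle.read(8)
    if len(size_bytes) == 0:
        return None
    total_bytes = struct.unpack(">Q", size_bytes)[0]
    return handle.read(total_bytes)

# Write two pickled arrays back to back, then read them all out again
buf = io.BytesIO()
for arr in (np.zeros((2, 3)), np.ones(5)):
    write_bytes(buf, pickle.dumps(arr))
buf.seek(0)

arrays = []
while True:
    data = read_bytes(buf)
    if data is None:
        break
    arrays.append(pickle.loads(data))
```

After the loop, `arrays` holds the two original numpy arrays in order; with a real file you would open it with `open(path, "rb")` instead of the BytesIO buffer.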

Upvotes: 1

Related Questions