Reputation: 420
I am trying to extract embeddings from a hidden layer of LSTM. I have a huge dataset with multiple sentences and therefore those will generate multiple numpy vectors. I want to store all those vectors efficiently into a single file. This is what I have so far
with open(src_vectors_save_file, "wb") as s_writer, open(tgt_vectors_save_file, "wb") as t_writer:
    for batch in data_iter:
        encoder_hidden_layer, decoder_hidden_layer = self.extract_lstm_hidden_states_for_batch(
            batch, data.src_vocabs, attn_debug
        )
        encoder_hidden_layer = encoder_hidden_layer.detach().numpy()
        decoder_hidden_layer = decoder_hidden_layer.detach().numpy()
        enc_hidden_bytes = pickle.dumps(encoder_hidden_layer)
        dec_hidden_bytes = pickle.dumps(decoder_hidden_layer)
        s_writer.write(enc_hidden_bytes)
        s_writer.write("\n")
        t_writer.write(dec_hidden_bytes)
        t_writer.write("\n")
Essentially I am using pickle to get the bytes from the np.array and writing them to a binary file. I tried to naively separate each byte-encoded array with an ASCII newline, which obviously throws an error. I was planning to use the .readlines() function, or to read each byte-encoded array per line with a for loop, in the next program. However, that won't be possible now.
I am out of ideas; can someone suggest an alternative? How can I efficiently store all the arrays in a compressed fashion in one file, and how can I read them back from that file?
Upvotes: 0
Views: 348
Reputation: 9203
There is a problem with using \n as a separator: the dump from pickle (enc_hidden_bytes) could itself contain \n bytes, because the data is not ASCII encoded.
There are two solutions. You can escape every \n appearing in the data and then use \n as a terminator, but this adds complexity even while reading.
The other solution is to write the size of the data into the file before the actual data. This acts as a small header and is a very common practice when sending data over a connection (length-prefix framing).
You can write the following two functions:
import struct

def write_bytes(handle, data):
    # Prefix the payload with its length as an 8-byte big-endian unsigned int
    total_bytes = len(data)
    handle.write(struct.pack(">Q", total_bytes))
    handle.write(data)

def read_bytes(handle):
    # Read the 8-byte length prefix; an empty read means end of file
    size_bytes = handle.read(8)
    if len(size_bytes) == 0:
        return None
    total_bytes = struct.unpack(">Q", size_bytes)[0]
    return handle.read(total_bytes)
Now you can replace
s_writer.write(enc_hidden_bytes)
s_writer.write("\n")
with
write_bytes(s_writer, enc_hidden_bytes)
and do the same for the other variables.
While reading back from the file in a loop, you can use the read_bytes function in a similar way.
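For completeness, here is a minimal round-trip sketch of that read-back loop. It uses an in-memory io.BytesIO so it is self-contained (a real file opened "wb"/"rb" works the same way), and small numpy arrays stand in for the hidden-state tensors:

```python
import io
import pickle
import struct

import numpy as np

def write_bytes(handle, data):
    # Length-prefix the payload with an 8-byte big-endian unsigned int
    handle.write(struct.pack(">Q", len(data)))
    handle.write(data)

def read_bytes(handle):
    # Empty read on the length prefix means we reached end of file
    size_bytes = handle.read(8)
    if len(size_bytes) == 0:
        return None
    total_bytes = struct.unpack(">Q", size_bytes)[0]
    return handle.read(total_bytes)

# Write several pickled arrays back to back
buf = io.BytesIO()
arrays = [np.arange(6).reshape(2, 3), np.ones(3)]
for arr in arrays:
    write_bytes(buf, pickle.dumps(arr))

# Read them all back until read_bytes signals end of file
buf.seek(0)
restored = []
while True:
    payload = read_bytes(buf)
    if payload is None:
        break
    restored.append(pickle.loads(payload))
```

After the loop, restored holds the same arrays in the same order they were written.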
Upvotes: 1