    import torch

    def generate_mask(data: list, max_seq_len: int):
        """
        Generates a mask for data where each element is expected to be max_seq_len long after padding.
        Args:
            data: The data being forwarded through the LSTM after being converted to a tensor
            max_seq_len: The length of the names after being padded
        """
        batch_sz = len(data)
        ret = torch.zeros(1, batch_sz, max_seq_len, dtype=torch.bool)
        for i in range(batch_sz):
            name = data[i]
            # Mark every real (non-padding) position of this name as True
            for letter_idx in range(len(name)):
                ret[0][i][letter_idx] = 1
        return ret
I have this code for generating a mask, and I really hate how I'm doing it. As you can see, I'm just going through every name and setting each index from 0 to the name's length minus one to 1. I'd prefer a more elegant way to do this.
Upvotes: 1
Views: 47
Reputation: 13631
Well, you can replace the inner loop with a single slice assignment, something like this:
    # [...]
    for i in range(batch_sz):
        ret[0, i, :len(data[i])] = 1
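Going one step further, the remaining loop over the batch can also be removed by broadcasting: compare a row of position indices against a column of sequence lengths. This is a minimal sketch, assuming `data` is a list of sequences that support `len()`, as in the question; the function name `generate_mask_vectorized` is hypothetical.

```python
import torch


def generate_mask_vectorized(data: list, max_seq_len: int) -> torch.Tensor:
    """Hypothetical helper: bool mask of shape (1, batch, max_seq_len),
    True where a position holds a real (non-padding) element."""
    # Length of each unpadded sequence, shape (batch,)
    lengths = torch.tensor([len(seq) for seq in data], dtype=torch.long)
    # Position indices 0..max_seq_len-1, shape (max_seq_len,)
    positions = torch.arange(max_seq_len)
    # (1, max_seq_len) vs (batch, 1) broadcasts to (batch, max_seq_len):
    # position j is valid for row i exactly when j < lengths[i]
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
    # Add the leading dimension that the original generate_mask returns
    return mask.unsqueeze(0)


names = ["ann", "bo", "chris"]
print(generate_mask_vectorized(names, 6).int())
```

The only Python-level work left is collecting the lengths; the mask itself is built in one tensor comparison. One behavioral difference: a name longer than `max_seq_len` makes the original loop raise an `IndexError`, whereas this version silently marks the whole row as valid.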
Upvotes: 2