user12314098


Set all indices in a tensor within a range to 1

import torch

def generate_mask(data: list, max_seq_len: int):
    """
    Generates a boolean mask for data, where each element is expected to be max_seq_len long after padding
    Args:
    data : The data being forwarded through LSTM after being converted to a tensor
    max_seq_len : The length of the names after being padded
    """
    batch_sz = len(data)
    ret = torch.zeros(1, batch_sz, max_seq_len, dtype=torch.bool)
    for i in range(batch_sz):
        name = data[i]

        for letter_idx in range(len(name)):
            ret[0][i][letter_idx] = 1

    return ret

I have this code for generating a mask, and I really hate how I'm doing it. Essentially, I go through every name and set each index from 0 up to (but not including) the name's length to 1. I'd prefer a more elegant way to do this.

Upvotes: 1

Views: 47

Answers (1)

Berriel

Reputation: 13631

Well, you can replace the inner loop with a single slice assignment per row, something like this:

# [...]
for i in range(batch_sz):
    ret[0, i, :len(data[i])] = 1
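For reference, here is a sketch of the full function with the answer's slice assignment dropped into the question's original code. The sample names in the usage lines are made up for illustration, and `data` is assumed to be a list of sequences (e.g. strings) that support `len()`.

```python
import torch


def generate_mask(data: list, max_seq_len: int) -> torch.Tensor:
    """Boolean mask of shape (1, batch_sz, max_seq_len); True marks unpadded positions."""
    batch_sz = len(data)
    ret = torch.zeros(1, batch_sz, max_seq_len, dtype=torch.bool)
    for i in range(batch_sz):
        # One slice assignment sets positions 0 .. len(data[i]) - 1,
        # replacing the inner per-letter loop from the question.
        ret[0, i, :len(data[i])] = True
    return ret


# Hypothetical sample batch: names of length 2, 4 and 0.
mask = generate_mask(["ab", "abcd", ""], max_seq_len=5)
print(mask.shape)  # torch.Size([1, 3, 5])
```

One behavioral difference worth knowing: if a name is longer than `max_seq_len`, the slice is silently clipped, whereas the original per-index loop raises an `IndexError`.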

Upvotes: 2
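A further option, not part of the original answer, is to drop the per-row loop entirely by broadcasting a comparison between position indices and sequence lengths. This is a sketch under the same assumption that `data` is a list of sequences supporting `len()`; the name `generate_mask_vectorized` and the sample names are hypothetical.

```python
import torch


def generate_mask_vectorized(data: list, max_seq_len: int) -> torch.Tensor:
    """Loop-free variant: True wherever the position index is less than the sequence length."""
    # Sequence lengths, shape (batch_sz,). A Python-level pass over the list
    # is still needed to read len() of each element.
    lengths = torch.tensor([len(seq) for seq in data], dtype=torch.long)
    # Position indices 0 .. max_seq_len - 1, shape (max_seq_len,).
    positions = torch.arange(max_seq_len)
    # (1, max_seq_len) < (batch_sz, 1) broadcasts to (batch_sz, max_seq_len):
    # entry [i, j] is True exactly when j < len(data[i]).
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
    # Add the leading size-1 dimension that the question's version produces.
    return mask.unsqueeze(0)


# Hypothetical sample batch: names of length 2, 4 and 0.
mask = generate_mask_vectorized(["ab", "abcd", ""], max_seq_len=5)
```

The comparison builds the whole mask in one tensor operation, which tends to matter more as the batch grows; like the slice version, lengths above `max_seq_len` simply produce an all-True row.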
