Captain Proton

Reputation: 21

Reconstituting consumed sequence mask in Keras LSTM autoencoder

I'm trying to get my head around sequence-to-sequence autoencoding (LSTM) in Keras for inputs with a varying number of time steps. I'm particularly interested in using such a network to encode variable-length sequences as a single fixed-length vector (i.e. not using return_sequences=True on the LSTM encoder).

As an example, starting with just the masking layer:

import numpy as np
import tensorflow as tf

test_data = np.array([[[1,2,1],[3,2,1],[0,0,0],[0,0,0]],
                      [[2,1,1],[3,1,2],[1,2,1],[3,3,3]],
                      [[3,2,3],[1,3,2],[1,1,1],[0,0,0]],
                      [[1,1,2],[0,0,0],[0,0,0],[0,0,0]],
                      [[2,1,3],[4,2,0],[1,1,2],[1,3,2]]],dtype=np.float32)
input_layer = tf.keras.layers.Input((4,3))
masking_layer = tf.keras.layers.Masking(mask_value=0)(input_layer)
# masks correctly to this point

model = tf.keras.models.Model(inputs=input_layer, outputs=masking_layer)

print(model(test_data)._keras_mask)

gives

tf.Tensor(
[[ True  True False False]
 [ True  True  True  True]
 [ True  True  True False]
 [ True False False False]
 [ True  True  True  True]], shape=(5, 4), dtype=bool)

as expected. Adding the LSTM encoding layer consumes this mask:

encode_layer = tf.keras.layers.LSTM(64)(masking_layer)
model = tf.keras.models.Model(inputs=input_layer, outputs=encode_layer)
print(model(test_data)._keras_mask)

no longer yields a mask (again, this is expected: with return_sequences=False the LSTM collapses the time dimension, so there is no per-timestep mask left to propagate). What I'd like to do is (somehow) reconstitute the mask generated by the masking_layer and apply it to the RepeatVector output I'm using to feed the LSTM decoder:

repeated = tf.keras.layers.RepeatVector(4)(encode_layer)  # repeat to match the 4 input time steps
# it's here I'd like to reincarnate the mask, as it should propagate to the decoder

decode_layer = tf.keras.layers.LSTM(5, return_sequences=True)(repeated)

model = tf.keras.models.Model(inputs=input_layer, outputs=decode_layer)

I've tried implementing a Lambda layer to read the mask off the masking layer, but I think I'm seriously misconceiving this step:

get_mask = lambda x: masking_layer._keras_mask
Lambda_layer = tf.keras.layers.Lambda(lambda x: x, mask=get_mask)(repeated)

This just raises TypeError: <lambda>() takes 1 positional argument but 2 were given.
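From the traceback I gather that Keras calls the Lambda layer's mask function with two positional arguments (the inputs and the incoming mask), so a two-argument signature like the sketch below at least avoids the TypeError, though reading _keras_mask off a layer output like this still feels fragile and doesn't actually propagate the mask for me:

# Keras invokes the mask function as mask(inputs, previous_mask), hence the
# two parameters; this silences the TypeError but still isn't a real solution.
get_mask = lambda inputs, previous_mask: masking_layer._keras_mask
Lambda_layer = tf.keras.layers.Lambda(lambda x: x, mask=get_mask)(repeated)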

I'd appreciate any insight you might have about resurrecting a mask like this. I'm aware that I could rethink the network entirely, but I'd like to retain the single-vector encoded representation and avoid return_sequences=True on the encoding LSTM.

Thanks in advance.

Upvotes: 1

Views: 385

Answers (1)

Captain Proton

Reputation: 21

Here's what seems to be working for me, although I'd love to hear other solutions, likely more elegant than this one...

I created a custom layer, Reapply_Masking, adapted from Keras's built-in Masking layer. It takes a list of two inputs: input_list[0] is the tensor to which you're reapplying the mask (in my example, the RepeatVector output); input_list[1] is the tensor the original mask was computed from (in my example, input_layer).

It recomputes the mask exactly as the Masking layer does, and zeros out the masked time steps of input_list[0] for good measure (as Keras's Masking layer does).

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer
from tensorflow.python.ops import array_ops, math_ops


class Reapply_Masking(Layer):
    def __init__(self, mask_value, **kwargs):
        super(Reapply_Masking, self).__init__(**kwargs)
        self.supports_masking = True
        self.mask_value = mask_value

    def compute_output_shape(self, input_shape_list):
        return input_shape_list[0]

    def compute_mask(self, input_list, mask=None):
        # Recompute the mask from the reference tensor (input_list[1]),
        # exactly as Keras's Masking layer would
        return K.any(math_ops.not_equal(input_list[1], self.mask_value), axis=-1)

    def call(self, input_list):
        to_output = input_list[0]

        # Per-timestep boolean mask derived from the reference tensor
        boolean_mask = array_ops.squeeze(K.any(
            math_ops.not_equal(input_list[1], self.mask_value),
            axis=-1, keepdims=True), axis=-1)

        # Broadcast the mask across the feature dimension and zero out the
        # masked time steps (as the built-in Masking layer does)
        dim = input_list[0].shape[-1]
        killer = math_ops.cast(
            K.repeat_elements(tf.expand_dims(boolean_mask, axis=2), dim, axis=2),
            to_output.dtype)

        outputs = to_output * killer
        outputs._keras_mask = boolean_mask

        return outputs

    def get_config(self):
        config = {'mask_value': self.mask_value}
        base_config = super(Reapply_Masking, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
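Wiring it into the autoencoder sketched in the question might look like the following (the layer sizes are illustrative, and the RepeatVector count matches the 4 input time steps):

input_layer = tf.keras.layers.Input((4, 3))
masked = tf.keras.layers.Masking(mask_value=0)(input_layer)
encoded = tf.keras.layers.LSTM(64)(masked)
repeated = tf.keras.layers.RepeatVector(4)(encoded)
# Reapply_Masking recomputes the mask from input_layer and attaches it to the
# repeated tensor, so it propagates into the decoder LSTM
remasked = Reapply_Masking(mask_value=0)([repeated, input_layer])
decode_layer = tf.keras.layers.LSTM(3, return_sequences=True)(remasked)
model = tf.keras.models.Model(inputs=input_layer, outputs=decode_layer)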

I seem to be able to save and restore the model, but only in the .h5 format and with custom_objects.
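Something like this works for me (the filename is just an example):

# Save in HDF5 format, then restore by telling Keras about the custom layer
model.save('autoencoder.h5')
restored = tf.keras.models.load_model(
    'autoencoder.h5',
    custom_objects={'Reapply_Masking': Reapply_Masking})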

Upvotes: 1
