xuhdev

Reputation: 9380

How to test the correctness of a Keras custom layer?

After creating a Keras custom layer with trainable weights, how can one test that the code is correct? This does not seem to be described in the Keras documentation.

For example, to test the expected behavior of a function, one can write a unit test. How can we do this for a Keras custom layer?

Upvotes: 5

Views: 2629

Answers (2)

Markus Weber

Reputation: 1107

Use layer_test from the Keras test utilities (keras/utils/test_utils.py): https://github.com/keras-team/keras/blob/master/keras/utils/test_utils.py

It provides the following function, which checks the output shape, the actual output values, serialization (config and weight round-trips) and training:

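# Imports this function needs (standalone Keras 2.x package)
import numpy as np
from numpy.testing import assert_allclose
from keras import backend as K
from keras.layers import Input
from keras.models import Model
from keras.utils.generic_utils import has_arg


def layer_test(layer_cls, kwargs={}, input_shape=None, input_dtype=None,
               input_data=None, expected_output=None,
               expected_output_dtype=None, fixed_batch_size=False):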
    """Test routine for a layer with a single input tensor
    and single output tensor.
    """
    # generate input data
    if input_data is None:
        assert input_shape
        if not input_dtype:
            input_dtype = K.floatx()
        input_data_shape = list(input_shape)
        for i, e in enumerate(input_data_shape):
            if e is None:
                input_data_shape[i] = np.random.randint(1, 4)
        input_data = (10 * np.random.random(input_data_shape))
        input_data = input_data.astype(input_dtype)
    else:
        if input_shape is None:
            input_shape = input_data.shape
        if input_dtype is None:
            input_dtype = input_data.dtype
    if expected_output_dtype is None:
        expected_output_dtype = input_dtype

    # instantiation
    layer = layer_cls(**kwargs)

    # test get_weights , set_weights at layer level
    weights = layer.get_weights()
    layer.set_weights(weights)

    expected_output_shape = layer.compute_output_shape(input_shape)

    # test in functional API
    if fixed_batch_size:
        x = Input(batch_shape=input_shape, dtype=input_dtype)
    else:
        x = Input(shape=input_shape[1:], dtype=input_dtype)
    y = layer(x)
    assert K.dtype(y) == expected_output_dtype

    # check with the functional API
    model = Model(x, y)

    actual_output = model.predict(input_data)
    actual_output_shape = actual_output.shape
    for expected_dim, actual_dim in zip(expected_output_shape,
                                        actual_output_shape):
        if expected_dim is not None:
            assert expected_dim == actual_dim

    if expected_output is not None:
        assert_allclose(actual_output, expected_output, rtol=1e-3)

    # test serialization, weight setting at model level
    model_config = model.get_config()
    recovered_model = model.__class__.from_config(model_config)
    if model.weights:
        weights = model.get_weights()
        recovered_model.set_weights(weights)
        _output = recovered_model.predict(input_data)
        assert_allclose(_output, actual_output, rtol=1e-3)

    # test training mode (e.g. useful when the layer has a
    # different behavior at training and testing time).
    if has_arg(layer.call, 'training'):
        model.compile('rmsprop', 'mse')
        model.train_on_batch(input_data, actual_output)

    # test instantiation from layer config
    layer_config = layer.get_config()
    layer_config['batch_input_shape'] = input_shape
    layer = layer.__class__.from_config(layer_config)

    # for further checks in the caller function
    return actual_output

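The import path of layer_test has changed between Keras versions, so on a TensorFlow 2.x install it may be simpler to write the same kinds of checks directly. The sketch below is not the Keras helper: it assumes TensorFlow 2.x with tf.keras, and both the Scale layer (a per-feature trainable scaling) and the check_layer helper are hypothetical names made up for illustration. It compares the output shape with compute_output_shape, rebuilds the layer from get_config() and copies the weights over, and confirms that one train_on_batch step actually updates the trainable weight.

```python
import numpy as np
import tensorflow as tf


class Scale(tf.keras.layers.Layer):
    """Hypothetical custom layer: multiplies each feature by a trainable factor."""

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel",
            shape=(int(input_shape[-1]),),
            initializer="ones",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        return inputs * self.kernel

    def compute_output_shape(self, input_shape):
        return input_shape


def check_layer(layer_cls, input_shape=(2, 5)):
    """Hypothetical helper running layer_test-style checks on layer_cls."""
    # Fixed seed so the test is reproducible
    x = np.random.default_rng(0).random(input_shape).astype("float32")
    layer = layer_cls()
    y = np.asarray(layer(x))

    # 1. Output shape agrees with compute_output_shape
    assert y.shape == tuple(layer.compute_output_shape(input_shape))

    # 2. Serialization: rebuild from config, copy weights, expect the same output
    clone = layer_cls.from_config(layer.get_config())
    clone.build(input_shape)
    clone.set_weights(layer.get_weights())
    np.testing.assert_allclose(np.asarray(clone(x)), y, rtol=1e-6)

    # 3. Training: one SGD step on an MSE loss should change the weights
    inp = tf.keras.Input(shape=input_shape[1:])
    model = tf.keras.Model(inp, layer(inp))
    model.compile(optimizer="sgd", loss="mse")
    before = [w.copy() for w in layer.get_weights()]
    model.train_on_batch(x, np.zeros_like(y))
    after = layer.get_weights()
    assert any(not np.allclose(b, a) for b, a in zip(before, after))

    return y.shape
```

In a real test suite you would call check_layer(YourLayer) from a unittest or pytest test function; the assertions inside raise AssertionError when a check fails.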
Upvotes: 2

Suba Selvandran

Reputation: 314

You can still write something like a unit test: get the output of the custom layer for a given input and check it against an output you calculated by hand.

For example, say your custom layer Custom takes inputs of shape (None, 3, 200) and returns outputs of shape (None, 3):

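from keras.layers import Input
from keras.models import Model

# Custom is your own layer class; your_input is a known NumPy array of shape (batch, 3, 200)
inp = Input(shape=(3, 200))
out = Custom()(inp)
model = Model(inp, out)

output = model.predict(your_input)  # shape (batch, 3)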

Then compare output against the expected output you calculated by hand for the known input your_input, for example with numpy.testing.assert_allclose.
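A concrete version of this, as a sketch assuming TensorFlow 2.x (tf.keras): WeightedSum is a hypothetical stand-in for Custom that maps (None, 3, 200) to (None, 3) with one trainable vector w of length 200, so the hand-calculated reference is simply your_input @ w in NumPy. check_against_reference is likewise a hypothetical name.

```python
import numpy as np
import tensorflow as tf


class WeightedSum(tf.keras.layers.Layer):
    """Hypothetical stand-in for Custom: (batch, 3, 200) -> (batch, 3).

    Computes a weighted sum over the last axis with one trainable vector.
    """

    def build(self, input_shape):
        self.w = self.add_weight(
            name="w",
            shape=(int(input_shape[-1]),),
            initializer="random_normal",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        return tf.reduce_sum(inputs * self.w, axis=-1)


def check_against_reference():
    layer = WeightedSum()
    inp = tf.keras.Input(shape=(3, 200))
    model = tf.keras.Model(inp, layer(inp))

    # A known input with a fixed seed so the test is reproducible
    your_input = np.random.default_rng(0).random((4, 3, 200)).astype("float32")
    output = model.predict(your_input, verbose=0)

    # Manually calculated output from the layer's current weights
    (w,) = layer.get_weights()
    expected = your_input @ w
    np.testing.assert_allclose(output, expected, rtol=1e-4, atol=1e-6)
    return output.shape
```

Reading the weights back with get_weights() keeps the reference independent of the random initializer; for a layer with more complex math, the NumPy reference would be a separate, straightforward reimplementation of the same formula.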

Upvotes: 5
