Tobias Hermann

How to rename the layers of a Keras model without corrupting the structure?

For some library functionality I'm trying to rename the layers (including the input layers) of a given model.

The following minimal example shows the error I run into with my current approach (using TensorFlow 2.3):

from tensorflow.keras.models import load_model

model = load_model("model.h5")
for layer in model.layers:
    layer._name = layer.name + "_renamed"

model.to_json()

This fails with:

ValueError: The target structure is of type `<class 'tensorflow.python.framework.ops.Tensor'>`
  Tensor("input_1:0", shape=(None, 4), dtype=float32)
However the input structure is a sequence (<class 'list'>) of length 0.

The model.h5 file might have been created like this, for example:

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

inputs = Input(shape=(4,))
x = Dense(5, activation='relu', name='a')(inputs)
x = Dense(3, activation='softmax', name='b')(x)
model = Model(inputs=inputs, outputs=x)
model.compile(loss='categorical_crossentropy', optimizer='nadam')
model.save("model.h5")

Any idea on how to fix this?


Answers (1)

OverLordGoldDragon

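Problem: Keras serializes the network by traversing each layer's layer._inbound_nodes and keeping only the nodes whose keys are found in model._network_nodes. Setting layer._name changes the layer's name, but _network_nodes still holds the keys built from the original names, so the renamed layers' nodes are skipped (hence the empty list in the error).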
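To see why the renamed model serializes to an empty structure, here is a toy, pure-Python sketch of that membership check. It assumes node keys have the form `<layer name>_ib-<node index>`, which is what TF 2.3's private `_make_node_key` helper appears to produce; `make_node_key` and `serialized_nodes` are hypothetical stand-ins, not Keras APIs.

```python
# Hypothetical stand-in for TF 2.x's private _make_node_key helper;
# the "<layer name>_ib-<node index>" format is an assumption.
def make_node_key(layer_name, node_index):
    return layer_name + "_ib-" + str(node_index)


def serialized_nodes(layer_names, network_nodes):
    """Return the (layer, node_index) pairs a serializer would keep.

    Roughly mimics the config code: a layer's node is only emitted if its
    key, rebuilt from the layer's *current* name, is in network_nodes.
    """
    kept = []
    for name in layer_names:
        # Each layer in this toy model has exactly one inbound node (index 0).
        if make_node_key(name, 0) in network_nodes:
            kept.append((name, 0))
    return kept


# Keys are computed once, when the model is built.
original_names = ["input_1", "a", "b"]
network_nodes = {make_node_key(n, 0) for n in original_names}

# Renaming only changes the layer names; network_nodes keeps the old keys.
renamed = [n + "_renamed" for n in original_names]

print(serialized_nodes(original_names, network_nodes))  # [('input_1', 0), ('a', 0), ('b', 0)]
print(serialized_nodes(renamed, network_nodes))         # []
```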


Solution: rename the entries of _network_nodes accordingly. The working function is at the bottom; here is an example of its use:

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model, load_model

ipt = Input((16,))
out = Dense(16)(ipt)
model = Model(ipt, out)
model.compile('sgd', 'mse')

rename(model, model.layers[1], 'new_name')
model.save('model.h5')
loaded = load_model('model.h5')

Note: layer.name is a @property without a setter, which suggests it is not meant to be set. Further, layer.__setattr__ is overridden and performs steps beyond setting the attribute; these are likely necessary, but I can't be sure what other effects they may have. The commented-out line in the function is an alternative that bypasses __setattr__ by writing to the instance's __dict__ directly. Treat this as a temporary workaround at best; I suggest opening an issue on GitHub, since the API itself needs to change.
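The difference between the two ways of setting `_name` is plain Python behavior, so it can be shown without TensorFlow. In this sketch `Tracked` is a hypothetical class whose `__setattr__` has a side effect, loosely mimicking how Keras layers use that hook for attribute tracking; writing through `vars(obj)` skips the hook entirely.

```python
class Tracked:
    """Toy class whose __setattr__ does extra work on every assignment."""

    def __init__(self):
        # Write straight into the instance dict so `log` exists before
        # __setattr__ is ever invoked.
        vars(self)["log"] = []

    def __setattr__(self, key, value):
        # Side effect: record every attribute set through normal assignment.
        self.log.append(key)
        super().__setattr__(key, value)


obj = Tracked()
obj._name = "via_setattr"              # goes through __setattr__, gets logged
vars(obj).__setitem__("_name", "raw")  # writes obj.__dict__ directly, not logged

print(obj._name)  # raw
print(obj.log)    # ['_name']
```

Skipping the hook avoids its side effects, which is exactly why it is risky on a real layer: whatever bookkeeping Keras does there simply does not happen.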


Function:

Not foolproof: _get_node_suffix matches node keys by prefix, so its naming logic needs work (e.g. dense_1 can be confused with dense_11).
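The prefix problem can be reproduced on plain strings. This sketch assumes node keys look like `<layer name>_ib-<node index>`; `prefix_suffix` copies the matching logic of `_get_node_suffix`, while `parsed_suffix` is a hypothetical alternative that splits each key at its last `_ib-` and compares whole layer names.

```python
# Assumed TF 2.x node-key format: "<layer name>_ib-<node index>".
# A list (not a set) keeps the iteration order, and hence the demo, deterministic.
old_nodes = ["dense_11_ib-0", "dense_1_ib-0"]


def prefix_suffix(name):
    """Same prefix matching as _get_node_suffix: first key starting with `name`."""
    for key in old_nodes:
        if key.startswith(name):
            return key[len(name):]


def parsed_suffix(name):
    """Split each key at its last '_ib-' and require an exact layer-name match."""
    for key in old_nodes:
        layer_name, sep, index = key.rpartition("_ib-")
        if sep and layer_name == name:
            return sep + index


print(prefix_suffix("dense_1"))  # 1_ib-0  (wrongly matched dense_11's key)
print(parsed_suffix("dense_1"))  # _ib-0
```

Since the original function builds `old_nodes` from a set, whether the wrong key is hit first depends on set iteration order, so the bug may only show up intermittently.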

def rename(model, layer, new_name):
    def _get_node_suffix(name):
        # Return what follows `name` in the first node key starting with it
        # (TF 2.3 keys look like 'dense_ib-0'); prefix matching, see caveat above.
        for node_key in old_nodes:
            if node_key.startswith(name):
                return node_key[len(name):]

    old_name = layer.name
    old_nodes = list(model._network_nodes)  # node keys recorded when the model was built
    new_nodes = []

    for l in model.layers:
        if l.name == old_name:
            l._name = new_name  # `name` has no setter, so set the private attribute
            # vars(l).__setitem__('_name', new_name)  # bypasses .__setattr__
            new_nodes.append(new_name + _get_node_suffix(old_name))
        else:
            new_nodes.append(l.name + _get_node_suffix(l.name))
    model._network_nodes = set(new_nodes)
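For comparison, here is a hedged sketch of a variant, `rename_layer` (a hypothetical name), that remaps every key in `_network_nodes` by parsing it instead of prefix-matching, so a shared layer with several inbound nodes keeps all of its keys. It still relies on private TF 2.x attributes (`_name`, `_network_nodes`) and on the assumed `<layer name>_ib-<node index>` key format; `_FakeLayer` and the `SimpleNamespace` model are stand-ins so the logic can run without TensorFlow.

```python
from types import SimpleNamespace


def rename_layer(model, layer, new_name):
    """Rename `layer` and rewrite its keys in model._network_nodes.

    Assumes keys of the form "<layer name>_ib-<node index>" (private TF 2.x detail).
    """
    old_name = layer.name
    new_nodes = set()
    for key in model._network_nodes:
        layer_name, sep, index = key.rpartition("_ib-")
        if sep and layer_name == old_name:
            # Keep the node index; only swap the layer-name part of the key.
            new_nodes.add(new_name + sep + index)
        else:
            new_nodes.add(key)
    layer._name = new_name  # same private attribute the original function sets
    model._network_nodes = new_nodes


class _FakeLayer:
    """Stand-in for a Keras layer: `name` is a read-only view of `_name`."""

    def __init__(self, name):
        self._name = name

    @property
    def name(self):
        return self._name


# "dense_1" is a shared layer called twice; "dense_11" would trip prefix matching.
shared, other = _FakeLayer("dense_1"), _FakeLayer("dense_11")
fake_model = SimpleNamespace(
    _network_nodes={"dense_1_ib-0", "dense_1_ib-1", "dense_11_ib-0"},
)
rename_layer(fake_model, shared, "encoder")
print(sorted(fake_model._network_nodes))  # ['dense_11_ib-0', 'encoder_ib-0', 'encoder_ib-1']
```

With a real model the call would mirror the original, e.g. `rename_layer(model, model.layers[1], 'new_name')`, subject to the same caveats about private attributes.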
