Georgii Firsov


Issue passing concatenated inputs to LSTM in keras

I have several neural networks whose outputs are concatenated and then passed to an LSTM.

Here is a simplified code snippet:

import keras
import keras.backend as K

from keras.layers import Input, Dense, LSTM, concatenate
from keras.models import Model

# 1st NN
input_l1 = Input(shape=(1, ))
out_l1 = Dense(1)(input_l1)

# 2nd NN
input_l2 = Input(shape=(1, ))
out_l2 = Dense(1)(input_l2)

# concatenated layer
concat_vec = concatenate([out_l1, out_l2])

# expand dimensions to (None, 2, 1)
expanded_concat = K.expand_dims(concat_vec, axis=2)

lstm_out = LSTM(10)(expanded_concat)

model = keras.Model(inputs=[input_l1, input_l2], outputs=lstm_out)

Unfortunately, I get an error on the last line:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-53-a16fe60c0fc3> in <module>
      2 lstm_out = LSTM(10)(expanded_concat)
      3 
----> 4 model = keras.Model(inputs=[input_l1, input_l2], outputs=lstm_out)

/usr/local/lib/python3.9/site-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name + '` call to the ' +
     90                               'Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

/usr/local/lib/python3.9/site-packages/keras/engine/network.py in __init__(self, *args, **kwargs)
     91                 'inputs' in kwargs and 'outputs' in kwargs):
     92             # Graph network
---> 93             self._init_graph_network(*args, **kwargs)
     94         else:
     95             # Subclassed network

/usr/local/lib/python3.9/site-packages/keras/engine/network.py in _init_graph_network(self, inputs, outputs, name)
    228 
    229         # Keep track of the network's nodes and layers.
--> 230         nodes, nodes_by_depth, layers, layers_by_depth = _map_graph_network(
    231             self.inputs, self.outputs)
    232         self._network_nodes = nodes

/usr/local/lib/python3.9/site-packages/keras/engine/network.py in _map_graph_network(inputs, outputs)
   1361     for x in outputs:
   1362         layer, node_index, tensor_index = x._keras_history
-> 1363         build_map(x, finished_nodes, nodes_in_progress,
   1364                   layer=layer,
   1365                   node_index=node_index,

/usr/local/lib/python3.9/site-packages/keras/engine/network.py in build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index, tensor_index)
   1350             node_index = node.node_indices[i]
   1351             tensor_index = node.tensor_indices[i]
-> 1352             build_map(x, finished_nodes, nodes_in_progress, layer,
   1353                       node_index, tensor_index)
   1354 

/usr/local/lib/python3.9/site-packages/keras/engine/network.py in build_map(tensor, finished_nodes, nodes_in_progress, layer, node_index, tensor_index)
   1323             ValueError: if a cycle is detected.
   1324         """
-> 1325         node = layer._inbound_nodes[node_index]
   1326 
   1327         # Prevent cycles.

AttributeError: 'NoneType' object has no attribute '_inbound_nodes'

Is there a way to fix it? In case it matters, I use the PlaidML backend, since it is the only option on macOS with discrete GPU support.


Answers (1)

user11530462


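The error occurs because K.expand_dims is a raw backend operation rather than a Keras layer: its output carries no Keras layer history, so when Model walks the graph back from the LSTM output, it finds None where the layer for the expand step should be. To fix this, use a Reshape layer instead; it converts its input into the target shape while remaining part of the model graph.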
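Numerically, Reshape((2, 1)) produces the same tensor as K.expand_dims(..., axis=2); only the graph bookkeeping differs. A minimal NumPy sketch of the shape transformation outside Keras, assuming a made-up batch of 4 samples standing in for the two Dense outputs:

```python
import numpy as np

# Hypothetical stand-ins for the two Dense(1) outputs, each of shape (batch, 1)
out_l1 = np.arange(4, dtype=np.float32).reshape(4, 1)
out_l2 = np.arange(4, 8, dtype=np.float32).reshape(4, 1)

# concatenate([...]) joins along the last axis by default -> (batch, 2)
concat_vec = np.concatenate([out_l1, out_l2], axis=-1)

# What expand_dims(concat_vec, axis=2) computes -> (batch, 2, 1)
expanded = np.expand_dims(concat_vec, axis=2)

# What Reshape((2, 1)) computes; the Keras target shape excludes the batch axis
reshaped = concat_vec.reshape((-1, 2, 1))

print(concat_vec.shape, expanded.shape, reshaped.shape)
print(np.array_equal(expanded, reshaped))
```

The resulting (batch, 2, 1) layout is what the LSTM expects: 2 timesteps with 1 feature each.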

Keras is integrated into TensorFlow as tf.keras. Here is a working version of the code using tf.keras (the same layer is also available as keras.layers.Reshape in standalone Keras):

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, LSTM, concatenate
from tensorflow.keras.models import Model

# 1st NN
input_l1 = Input(shape=(1, ))
out_l1 = Dense(1)(input_l1)

# 2nd NN
input_l2 = Input(shape=(1, ))
out_l2 = Dense(1)(input_l2)

# concatenated layer
concat_vec = concatenate([out_l1, out_l2])

# expand dimensions to (None, 2, 1) with a Reshape layer; unlike the raw
# backend op K.expand_dims, a layer records its place in the model graph
expanded_concat = tf.keras.layers.Reshape((2, 1))(concat_vec)

lstm_out = LSTM(10)(expanded_concat)

model = Model(inputs=[input_l1, input_l2], outputs=lstm_out)
model.summary()

Output:

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 1)]          0                                            
__________________________________________________________________________________________________
dense (Dense)                   (None, 1)            2           input_1[0][0]                    
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1)            2           input_2[0][0]                    
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 2)            0           dense[0][0]                      
                                                                 dense_1[0][0]                    
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 2, 1)         0           concatenate[0][0]                
__________________________________________________________________________________________________
lstm (LSTM)                     (None, 10)           480         reshape_1[0][0]                  
==================================================================================================
Total params: 484
Trainable params: 484
Non-trainable params: 0
__________________________________________________________________________________________________
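The parameter counts in the summary can be checked by hand. A short sketch, assuming the standard Keras layouts with the default use_bias=True: a Dense layer has a kernel plus one bias per unit, and an LSTM has four gates, each with an input kernel, a recurrent kernel and a bias. The helper names below are illustrative only.

```python
def dense_params(input_dim, units):
    # kernel (input_dim x units) plus one bias per unit
    return input_dim * units + units


def lstm_params(input_dim, units):
    # 4 gates (input, forget, cell, output), each with an input kernel
    # (input_dim x units), a recurrent kernel (units x units) and a bias
    return 4 * (input_dim * units + units * units + units)


dense = dense_params(1, 1)   # each Dense(1) applied to a (None, 1) input
lstm = lstm_params(1, 10)    # LSTM(10) on (None, 2, 1): 1 feature per timestep
total = 2 * dense + lstm

print(dense, lstm, total)  # 2 480 484
```

Note that the number of timesteps (2) does not affect the LSTM's parameter count; only the feature dimension and the number of units do.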
