Reza Akraminejad

StellarGraph GraphSAGE sample in a Google Colab notebook: model.predict error

I started running the GraphSAGE node-classification demo for the StellarGraph Python library from this link:

https://stellargraph.readthedocs.io/en/stable/demos/node-classification/graphsage-node-classification.html

I can run the notebook without problems up to this code in cell [20]:

all_nodes = node_subjects.index
all_mapper = generator.flow(all_nodes)
all_predictions = model.predict(all_mapper)

However, as soon as model.predict is called, I receive this error:

ValueError: in user code:

ValueError: Layer "model_1" expects 3 input(s), but it received 1 input tensors. Inputs received: [<tf.Tensor 'IteratorGetNext:0' shape=(None, None, None) dtype=float32>]

Even though I am running the code unchanged from the official documentation, Keras reports that the input passed to the model is incompatible with its input layers. I do not know what to do.

For reference, here is the output of model.summary():

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
==================================================================================================
 input_2 (InputLayer)        [(None, 10, 1433)]           0         []                            
                                                                                                  
 input_3 (InputLayer)        [(None, 50, 1433)]           0         []                            
                                                                                                  
 input_1 (InputLayer)        [(None, 1, 1433)]            0         []                            
                                                                                                  
 reshape (Reshape)           (None, 1, 10, 1433)          0         ['input_2[0][0]']             
                                                                                                  
 reshape_1 (Reshape)         (None, 10, 5, 1433)          0         ['input_3[0][0]']             
                                                                                                  
 dropout_1 (Dropout)         (None, 1, 1433)              0         ['input_1[0][0]']             
                                                                                                  
 dropout (Dropout)           (None, 1, 10, 1433)          0         ['reshape[0][0]']             
                                                                                                  
 dropout_3 (Dropout)         (None, 10, 1433)             0         ['input_2[0][0]']             
                                                                                                  
 dropout_2 (Dropout)         (None, 10, 5, 1433)          0         ['reshape_1[0][0]']           
                                                                                                  
 mean_aggregator (MeanAggre  multiple                     45888     ['dropout_1[0][0]',           
 gator)                                                              'dropout[0][0]',             
                                                                     'dropout_3[0][0]',           
                                                                     'dropout_2[0][0]']           
                                                                                                  
 reshape_2 (Reshape)         (None, 1, 10, 32)            0         ['mean_aggregator[1][0]']     
                                                                                                  
 dropout_5 (Dropout)         (None, 1, 32)                0         ['mean_aggregator[0][0]']     
                                                                                                  
 dropout_4 (Dropout)         (None, 1, 10, 32)            0         ['reshape_2[0][0]']           
                                                                                                  
 mean_aggregator_1 (MeanAgg  (None, 1, 32)                1056      ['dropout_5[0][0]',           
 regator)                                                            'dropout_4[0][0]']           
                                                                                                  
 reshape_3 (Reshape)         (None, 32)                   0         ['mean_aggregator_1[0][0]']   
                                                                                                  
 lambda (Lambda)             (None, 32)                   0         ['reshape_3[0][0]']           
                                                                                                  
 dense (Dense)               (None, 7)                    231       ['lambda[0][0]']              
                                                                                                  
==================================================================================================
Total params: 47175 (184.28 KB)
Trainable params: 47175 (184.28 KB)
Non-trainable params: 0 (0.00 Byte)
__________________________________________________________________________________________________
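For context, the three InputLayer shapes in the summary follow from how GraphSAGE samples neighbourhoods: one tensor for the head node itself, one for its sampled 1-hop neighbours, and one for the sampled neighbours of each of those. The sampling sizes 10 and 5 are inferred from the summary's shapes (consistent with num_samples = [10, 5]), and 1433 is the node feature dimension shown there. The helpers below are hypothetical illustrations, not StellarGraph functions; they only reproduce the shape arithmetic, which shows why the model expects a list of three arrays per batch rather than the single tensor reported in the error.

```python
from typing import List, Tuple


def graphsage_input_shapes(num_samples: List[int], feature_dim: int) -> List[Tuple[int, int]]:
    """Per-example input shapes of a GraphSAGE node model (batch dimension omitted).

    Hop 0 is the head node; hop k holds num_samples[0] * ... * num_samples[k-1] nodes.
    """
    shapes = []
    nodes_at_hop = 1
    shapes.append((nodes_at_hop, feature_dim))
    for n in num_samples:
        nodes_at_hop *= n  # each node at the previous hop contributes n sampled neighbours
        shapes.append((nodes_at_hop, feature_dim))
    return shapes


def aggregator_reshapes(num_samples: List[int], feature_dim: int) -> List[Tuple[int, int, int]]:
    """Shapes the neighbour inputs are regrouped into (parent nodes, samples per parent, features)."""
    reshapes = []
    parents = 1
    for n in num_samples:
        reshapes.append((parents, n, feature_dim))
        parents *= n
    return reshapes


shapes = graphsage_input_shapes([10, 5], 1433)
print(len(shapes), shapes)
# Prints: 3 [(1, 1433), (10, 1433), (50, 1433)]
print(aggregator_reshapes([10, 5], 1433))
# Prints: [(1, 10, 1433), (10, 5, 1433)]
```

The first result matches input_1, input_2 and input_3 in the summary, and the second matches the reshape and reshape_1 layers, so every batch given to model.predict has to be a list of three feature arrays in that order.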

Answers (1)

Reza Akraminejad

I finally found that passing an extra argument to generator.flow when building the prediction mapper makes model.predict work:

all_mapper = generator.flow(all_nodes, G.nodes())
