george123

How to provide specific training, validation and test sets in StellarGraph PaddedGraphGenerator

I am trying to train a graph convolutional neural network using the StellarGraph library. I would like to run this example https://stellargraph.readthedocs.io/en/stable/demos/graph-classification/gcn-supervised-graph-classification.html without the N-fold cross-validation, providing my own training, validation and test sets instead. This is the code I am using (taken from this post):

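from stellargraph.mapper import PaddedGraphGenerator
from stellargraph.layer import GCNSupervisedGraphClassification
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Dense
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

generator = PaddedGraphGenerator(graphs=graphs)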

train_gen = generator.flow([x for x in range(0, len(graphs_train))],
                           targets=graphs_train_labels,
                           batch_size=35)

test_gen = generator.flow([x for x in range(len(graphs_train),len(graphs_train) + len(graphs_test))],
                          targets=graphs_test_labels,
                          batch_size=35)

# Stopping criterion
es = EarlyStopping(monitor="val_loss",
                   min_delta=0,
                   patience=20,
                   restore_best_weights=True)

# Model definition
gc_model = GCNSupervisedGraphClassification(layer_sizes=[64, 64],
                                            activations=["relu", "relu"],
                                            generator=generator,
                                            dropout=0.5)

x_inp, x_out = gc_model.in_out_tensors()
predictions = Dense(units=32, activation="relu")(x_out)
predictions = Dense(units=16, activation="relu")(predictions)
predictions = Dense(units=1, activation="sigmoid")(predictions)

# Creating Keras model and preparing it for training
model = Model(inputs=x_inp, outputs=predictions)
model.compile(optimizer=Adam(0.001), loss=binary_crossentropy, metrics=["acc"])

# GNN Training
history = model.fit(train_gen, epochs=10, validation_data=test_gen, verbose=1)
model.fit(x=graphs_train,
          y=graphs_train_labels,
          epochs=10,
          verbose=1,
          callbacks=[es])


# Calculate performance on the validation data
test_metrics = model.evaluate(valid_gen, verbose=1)
valid_acc = test_metrics[model.metrics_names.index("acc")]

print(f"Test Accuracy model = {valid_acc}")

But at the end I get this error:

ValueError: Failed to find data adapter that can handle input: (<class 'list'> containing values of types {"<class 'stellargraph.core.graph.StellarGraph'>"}), <class 'numpy.ndarray'>

What am I missing here? Is it because of the way I created the graphs? In my case, graphs is a list that contains the StellarGraph objects.

Answers (1)

george123

Problem solved. I was calling

model.fit(x=graphs_train,
          y=graphs_train_labels,
          epochs=10,
          verbose=1,
          callbacks=[es])

after the line

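history = model.fit(train_gen, epochs=10, validation_data=test_gen, verbose=1)

That second call passed graphs_train, a plain Python list of StellarGraph objects, directly to model.fit. Keras has no data adapter for such a list, which is exactly what the ValueError reports. A model built on a PaddedGraphGenerator has to be fed through the sequences returned by generator.flow(...), as in the first call, so deleting the second model.fit call fixed the error.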
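Since the title also asks about a separate validation set, here is a minimal sketch of how the index lists for PaddedGraphGenerator.flow could be built. It assumes all graphs live in one list built as graphs_train + graphs_valid + graphs_test; split_indices, graphs_valid and graphs_valid_labels are hypothetical names, not part of StellarGraph. Note also that the removed call was the only one carrying callbacks=[es], so early stopping has to be attached to the remaining generator-based fit call.

```python
from typing import List, Tuple


def split_indices(n_train: int, n_val: int, n_test: int) -> Tuple[List[int], List[int], List[int]]:
    """Return positional indices of the train, validation and test blocks.

    Hypothetical helper. It assumes the single list handed to
    PaddedGraphGenerator was built as
        graphs = graphs_train + graphs_valid + graphs_test
    so each subset occupies one contiguous block of positions.
    """
    train_idx = list(range(0, n_train))
    val_idx = list(range(n_train, n_train + n_val))
    test_idx = list(range(n_train + n_val, n_train + n_val + n_test))
    return train_idx, val_idx, test_idx


# How the indices would be used with StellarGraph (not executed here;
# graphs_valid / graphs_valid_labels are hypothetical names):
#   generator = PaddedGraphGenerator(graphs=graphs_train + graphs_valid + graphs_test)
#   train_idx, val_idx, test_idx = split_indices(
#       len(graphs_train), len(graphs_valid), len(graphs_test))
#   train_gen = generator.flow(train_idx, targets=graphs_train_labels, batch_size=35)
#   valid_gen = generator.flow(val_idx, targets=graphs_valid_labels, batch_size=35)
#   test_gen = generator.flow(test_idx, targets=graphs_test_labels, batch_size=35)
#   # One fit call only, fed by the generators, with early stopping attached here.
#   history = model.fit(train_gen, epochs=10, validation_data=valid_gen, callbacks=[es])
#   test_metrics = model.evaluate(test_gen)

if __name__ == "__main__":
    # Toy sizes: 5 training, 2 validation and 3 test graphs.
    print(split_indices(5, 2, 3))
    # prints ([0, 1, 2, 3, 4], [5, 6], [7, 8, 9])
```

Keeping the validation graphs separate from the test graphs means the early-stopping decision (monitor="val_loss") never looks at the test set, so the final model.evaluate(test_gen) is not biased by it.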


