Reputation: 150
I am trying to train a StellarGraph GCN model and I get a dimension error when fitting. Below is my code:
import tensorflow as tf
from sklearn import preprocessing
from stellargraph import StellarDiGraph
from stellargraph.layer import GCN
from stellargraph.mapper import FullBatchNodeGenerator
from tensorflow.keras import Model, layers, losses, optimizers
from tensorflow.keras.callbacks import EarlyStopping

graph = StellarDiGraph(nodes=domains, edges=edges)
generator = FullBatchNodeGenerator(graph, method="gcn")

# One-hot encode the string labels
target_encoding = preprocessing.LabelBinarizer()
train_targets = target_encoding.fit_transform(train_labels)
val_targets = target_encoding.transform(val_labels)
test_targets = target_encoding.transform(test_labels)

# Create train and val generators so they can be fed to the StellarGraph GCN
train_gen = generator.flow(domains.loc[x_train].index, train_targets)
val_gen = generator.flow(domains.loc[x_val].index, val_targets)

# Hyperparameters to tune
layer_sizes = [128, 128]
activations = ["relu", "relu"]

# Define the GCN model
gcn = GCN(layer_sizes=layer_sizes, activations=activations, generator=generator, dropout=0.2)
x_inp, x_out = gcn.in_out_tensors()
predictions = layers.Dense(units=9, activation="softmax")(x_out)

model = Model(inputs=x_inp, outputs=predictions)
model.compile(
    optimizer=optimizers.Adam(learning_rate=1e-3),
    loss=losses.categorical_crossentropy,
    metrics=["acc"],
)

early_stopping = EarlyStopping(monitor="val_loss", patience=10)

history = model.fit(
    train_gen,
    epochs=50,
    validation_data=val_gen,
    verbose=2,
    shuffle=False,  # should be False, since shuffling data means shuffling the whole graph
    callbacks=[early_stopping],
)
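In case the target shapes matter here, this is what the LabelBinarizer step produces on some toy labels (the labels below are made up; my real data has 9 classes, matching Dense(units=9)):

```python
from sklearn import preprocessing

# Made-up labels with 3 classes, standing in for my real 9-class labels
train_labels = ["benign", "malware", "phishing", "benign"]

target_encoding = preprocessing.LabelBinarizer()
train_targets = target_encoding.fit_transform(train_labels)

# One row per sample, one column per class
print(train_targets.shape)  # (4, 3)
```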
I get the following error when fitting the model:
ValueError: features: expected batch dimension = 1 when using sparse adjacency matrix in GraphConvolution, found features batch dimension None
Call arguments received by layer 'graph_convolution' (type GraphConvolution):
• inputs=['tf.Tensor(shape=(None, None, None), dtype=float32)', 'SparseTensor(indices=Tensor("model/squeezed_sparse_conversion/Squeeze:0", shape=(None, None), dtype=int64), values=Tensor("model/squeezed_sparse_conversion/Squeeze_1:0", shape=(None,), dtype=float32), dense_shape=Tensor("model/squeezed_sparse_conversion/SparseTensor/dense_shape:0", shape=(2,), dtype=int64))']
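As I read it, the error says GraphConvolution expects the node-feature tensor to carry a leading batch dimension of exactly 1 (in full-batch training the whole graph is a single "batch"), but it sees None instead. A minimal NumPy sketch of the shape contract I think it is checking (node and feature counts are made up):

```python
import numpy as np

# Toy full-batch setup: 5 nodes with 4 features each (made-up sizes)
num_nodes, num_feats = 5, 4
node_features = np.random.rand(num_nodes, num_feats)

# A full-batch generator hands the model the whole graph as one sample,
# so the feature tensor should be (1, num_nodes, num_feats)
features_batched = node_features[np.newaxis, ...]
print(features_batched.shape)  # (1, 5, 4)

# The layer appears to reject any tensor whose batch dimension is not
# statically known to be 1, which is what the ValueError complains about
assert features_batched.shape[0] == 1
```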
I also tried running the demo from their site using the Cora dataset and got the exact same error. I can't find anything online to help.
Upvotes: 0
Views: 54