Reputation: 11
I'm trying to build a graph neural network for edge prediction and got the error below. I'd really appreciate it if someone could help me out.
from sklearn.metrics import roc_auc_score
model = GraphSAGE(train_g.ndata['congestion_onehot'].shape[1],16)
# You can replace DotPredictor with MLPPredictor.
#pred = MLPPredictor(16)
pred = DotPredictor()
def compute_loss(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat([torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])])
    return F.binary_cross_entropy_with_logits(scores, labels)

def compute_auc(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score]).numpy()
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]).numpy()
    return roc_auc_score(labels, scores)
The error was:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-56-d9c7e915d747> in <module>()
1 from sklearn.metrics import roc_auc_score
----> 2 model = GraphSAGE(train_g.ndata['congestion_onehot'].shape[1],16)
3 # You can replace DotPredictor with MLPPredictor.
4 #pred = MLPPredictor(16)
5 pred = DotPredictor()
IndexError: tuple index out of range
If it helps, this is what train_g looks like:
train_g
Graph(num_nodes=4333, num_edges=60222,
ndata_schemes={'congestion_onehot': Scheme(shape=(), dtype=torch.int64)}
edata_schemes={'weight': Scheme(shape=(), dtype=torch.float64)})
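One thing the schema above suggests: 'congestion_onehot' has Scheme(shape=(), dtype=torch.int64), so each node seems to carry a single integer rather than a one-hot vector. That makes the feature tensor 1-D, and .shape[1] does not exist on it. The sketch below reproduces this with plain PyTorch and shows one possible way to get a 2-D feature matrix. The `labels` tensor is a hypothetical stand-in for train_g.ndata['congestion_onehot'], and it assumes the stored integers are class indices starting at 0, which the post does not confirm.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in for train_g.ndata['congestion_onehot']:
# one int64 class index per node, i.e. a 1-D tensor (Scheme shape=()).
labels = torch.tensor([0, 2, 1, 2], dtype=torch.int64)

# A 1-D tensor has only one dimension, so shape[1] raises IndexError.
try:
    labels.shape[1]
except IndexError as err:
    print("IndexError:", err)

# Assumption: the values are class indices in 0..num_classes-1.
num_classes = int(labels.max()) + 1

# Expand to a (num_nodes, num_classes) float matrix usable as node features.
features = F.one_hot(labels, num_classes=num_classes).float()

# This is the value that would play the role of the input feature size.
in_feats = features.shape[1]
print(features.shape, in_feats)
```

If that assumption holds for the real graph, the expanded tensor could be stored as a node feature (for example under a new, hypothetical key such as train_g.ndata['feat']) and features.shape[1] passed as the first argument to GraphSAGE. Whether the GraphSAGE class here expects float features of that shape depends on its definition, which isn't shown above.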
Upvotes: 1
Views: 241