BizarroJr

Reputation: 13

Trouble retrieving nodes by attributes in NetworkX

I am developing code using NetworkX in which I have a multipartite graph similar to the following:

[Image: the multipartite graph]

Each node has the following attributes:

In the image above, each node is labelled 'label.layer' (this is a subset of a bigger graph, hence the layer numbers start at 451 rather than 0).

What I would like is a separate dictionary for each trajectory, containing only the nodes that belong to that trajectory, i.e. all the nodes that are linked to each other as neighbours. So far I have followed these posts, and my solution is:

Select network nodes with a given attribute value

Select nodes and edges form networkx graph with attributes

# `trajectory` is presumably the total number of trajectories, defined elsewhere
for i in range(trajectory):
    sel_nodes = {node: attribute['trajectory'] for node, attribute in G.nodes(data=True) if attribute['trajectory'] == i}
    print(sel_nodes)

This should return one dictionary for each 'row' of nodes; however, the output is the following:

{'0.451': 0, '0.452': 0, '0.453': 0, '0.454': 0, '0.455': 0, '0.456': 0, '0.457': 0, '0.458': 0, '0.459': 0, '0.460': 0}
{'1.451': 1, '1.452': 1, '1.453': 1, '1.454': 1, '1.455': 1, '1.456': 1, '1.457': 1, '1.458': 1, '1.459': 1, '1.460': 1}
{'2.451': 2, '3.452': 2, '3.453': 2, '3.454': 2, '3.455': 2, '3.456': 2, '3.457': 2, '3.458': 2, '3.459': 2, '4.460': 2}
{'3.451': 3, '2.452': 3, '2.453': 3, '2.454': 3, '2.455': 3, '2.456': 3, '2.457': 3, '2.458': 3, '2.459': 3, '3.460': 3}
{'4.451': 4, '4.452': 4, '4.453': 4, '4.454': 4, '4.455': 4, '4.456': 4, '4.457': 4, '4.458': 4, '4.459': 4, '5.460': 4}
{'5.451': 5, '5.452': 5, '5.453': 5, '5.454': 5, '5.455': 5, '5.456': 5, '5.457': 5, '5.458': 5, '5.459': 5, '6.460': 5}
{'6.451': 6, '6.452': 6, '6.453': 6, '6.454': 6, '6.455': 6, '6.456': 6, '6.457': 6, '6.458': 6, '6.459': 6, '7.460': 6}
{'7.451': 7, '7.452': 7, '7.453': 7, '7.454': 7, '7.455': 7, '7.456': 7, '7.457': 7, '7.458': 7, '7.459': 7, '8.460': 7}
{'8.451': 8, '8.452': 8, '8.453': 8, '8.454': 8, '8.455': 8, '8.456': 8, '8.457': 8, '8.458': 8, '8.459': 8, '9.460': 8}
{}
{}

The last two dicts should contain the bottom row of nodes and the lone node in the last column of the graph, respectively. Instead they are empty, and I can only retrieve the nodes that are somehow connected to the first column.

Is there some way to fix this behaviour?
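A side note on the question's own wording ("all the nodes which are neighbours between each other"): if edges only ever join consecutive nodes of the same trajectory (an assumption; the post does not say whether the graph has other edges), each trajectory is exactly a connected component, and NetworkX can group the nodes without reading the attribute at all. A minimal sketch on a hypothetical graph, with node names made up for illustration:

```python
import networkx as nx

# Hypothetical graph: two chains that swap rows between layers 451 and 452,
# one straight chain, and a lone node in the last layer.
G = nx.Graph()
G.add_edges_from([
    ("0.451", "0.452"),
    ("1.451", "2.452"),
    ("2.451", "1.452"),
])
G.add_node("3.453")  # isolated node: it forms a component on its own

# nx.connected_components yields one set of nodes per component.
components = [sorted(c) for c in nx.connected_components(G)]
for nodes in components:
    print(nodes)
```

If the graph also links nodes of different trajectories (e.g. every node to every node in the next layer), components would merge trajectories, and grouping by the `trajectory` attribute is the safer route.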

EDIT: To narrow the problem down a little, I believe it lies in the dictionary comprehension I used, because I have checked that the trajectory attribute has a value assigned to it by running:

print(G.nodes['9.455']['trajectory']) 

The output is 9, which is the trajectory I expect for that node.
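One thing worth ruling out (a hedged suggestion, not something confirmed in the post): `print` shows the string `'9'` and the integer `9` identically, so the check above cannot tell them apart, yet `'9' == 9` is `False` inside the comprehension. Printing the `repr` and type of every distinct value makes such a mismatch visible. The graph below is a hypothetical reproduction in which one node's value was stored as a string; nodes lacking the attribute would show up as `None`.

```python
from collections import Counter

import networkx as nx

# Hypothetical reproduction: one trajectory value stored as a string.
G = nx.Graph()
G.add_nodes_from([
    ("8.455", {"trajectory": 8}),
    ("9.455", {"trajectory": "9"}),  # looks like 9 when printed
])

print(G.nodes["9.455"]["trajectory"])        # prints: 9
print(repr(G.nodes["9.455"]["trajectory"]))  # prints: '9'

# G.nodes(data="trajectory") yields (node, value) pairs (None if missing).
# Count how many nodes carry each (repr, type) combination.
value_types = Counter(
    (repr(t), type(t).__name__) for _, t in G.nodes(data="trajectory")
)
for (value, type_name), count in sorted(value_types.items()):
    print(f"{value:>6} {type_name:<6} {count} node(s)")
```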

Upvotes: 0

Views: 956

Answers (1)

Andrew Eckart

Reputation: 1726

I feel like a list of dicts is the wrong data structure here. What you're really looking for seems to be a data structure that answers the question: given a trajectory i, which nodes belong to it? A dict of sets or a list of sets seems like a better fit.

Here is how you could construct such a dict on a small example graph where the top trajectory is an isolated node and the bottom two cross over each other:

>>> import networkx as nx
>>> G = nx.Graph()
>>> G.add_nodes_from([(0.3, {"trajectory": 0}), (1.1, {"trajectory": 1}), (2.2, {"trajectory": 1}), (2.3, {"trajectory": 1}), (2.1, {"trajectory": 2}), (1.2, {"trajectory": 2}), (1.3, {"trajectory": 2})])
>>> from collections import defaultdict
>>> d = defaultdict(set)
>>> for node, attrs in G.nodes().items():
...     d[attrs["trajectory"]].add(node)
... 
>>> d
defaultdict(<class 'set'>, {0: {0.3}, 1: {1.1, 2.2, 2.3}, 2: {1.2, 2.1, 1.3}})

If you really want one dict per trajectory, as described in the question, you can easily construct them from here:

>>> for trajectory, nodes in d.items():
...     print({node: trajectory for node in nodes})
... 
{0.3: 0}
{1.1: 1, 2.2: 1, 2.3: 1}
{1.2: 2, 2.1: 2, 1.3: 2}
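To keep those dicts for later use rather than print them, one option is a list comprehension over the trajectories in sorted order. A minimal sketch that rebuilds the same example graph; sorting each trajectory's nodes by their float name is an arbitrary choice made here so the result is deterministic:

```python
from collections import defaultdict

import networkx as nx

# Same small example graph as above.
G = nx.Graph()
G.add_nodes_from([
    (0.3, {"trajectory": 0}),
    (1.1, {"trajectory": 1}), (2.2, {"trajectory": 1}), (2.3, {"trajectory": 1}),
    (2.1, {"trajectory": 2}), (1.2, {"trajectory": 2}), (1.3, {"trajectory": 2}),
])

# Group node names by trajectory; G.nodes(data="trajectory") yields (node, value).
d = defaultdict(set)
for node, trajectory in G.nodes(data="trajectory"):
    d[trajectory].add(node)

# One {node: trajectory} dict per trajectory, stored instead of printed.
trajectory_dicts = [
    {node: trajectory for node in sorted(nodes)}
    for trajectory, nodes in sorted(d.items())
]
print(trajectory_dicts)
```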

If you want an ordered list for each trajectory instead of a set, you can sort on the part of the name that comes after the decimal point (the layer), converted to int so that multi-digit labels and layers still sort correctly:

>>> d = {k: sorted(nodes, key=lambda x: int(str(x).split(".")[1])) for k, nodes in d.items()}
>>> d
{0: [0.3], 1: [1.1, 2.2, 2.3], 2: [2.1, 1.2, 1.3]}
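The example above uses float node names, whereas the question uses strings such as `'0.451'`. A sketch for that case, assuming every name is exactly `'label.layer'` with integer parts; `layer_of` is a hypothetical helper introduced here. Splitting on the dot and converting the layer to `int` keeps the order right for multi-digit labels (e.g. `'10.100'`, where slicing with `[2:]` would give `'.100'`) and for layers of different lengths (`99` before `100`).

```python
from collections import defaultdict

import networkx as nx

# Hypothetical nodes named 'label.layer', as in the question.
G = nx.Graph()
G.add_nodes_from([
    ("10.100", {"trajectory": 0}),
    ("9.99", {"trajectory": 0}),
    ("10.101", {"trajectory": 0}),
    ("2.451", {"trajectory": 1}),
])

def layer_of(name):
    """Hypothetical helper: integer layer from a 'label.layer' string."""
    return int(name.split(".")[1])

# Group node names by trajectory.
by_trajectory = defaultdict(set)
for node, t in G.nodes(data="trajectory"):
    by_trajectory[t].add(node)

# Order each trajectory's nodes by numeric layer.
ordered = {t: sorted(nodes, key=layer_of) for t, nodes in by_trajectory.items()}
print(ordered)
```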

Upvotes: 0
