KMoore

Reputation: 181

Python: How to return a list of connected nodes

I have a set of nodes (stored as nodes = np.array([[x1, y1], [x2, y2], ...])) and elements connecting the nodes. Each element is an array of the two node indices.

So [0,3] is an element connecting nodes 0 and 3.

How would I return a list of all node indices connected to a given node index?

For example, if:

import numpy as np

element = np.array([[3, 2], [1, 4], [1, 3]])
print(findConnectedNodes(3))

should print [2, 1].

Upvotes: 1

Views: 1903

Answers (4)

Divakar

Reputation: 221644

Here's a vectorized approach -

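def findConnectedNodes(a, node_id = 3):
    # Keep only the rows (pairs) that contain node_id
    b = a[(a == node_id).any(1)]
    # Flatten those rows and drop node_id itself, leaving its neighbours
    return b[b != node_id]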

Sample runs -

In [40]: element
Out[40]: 
array([[3, 2],
       [1, 4],
       [1, 3]])

In [41]: findConnectedNodes(element, node_id=3)
Out[41]: array([2, 1])

In [42]: findConnectedNodes(element, node_id=1)
Out[42]: array([4, 3])

In [43]: findConnectedNodes(element, node_id=2)
Out[43]: array([3])

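For better performance on a two-column array of node pairs, we can compare each column separately via slicing instead of calling .any(1) on the full 2D mask. We can also reuse the boolean mask more, so that the target node is filtered out without a second comparison. That gives two more approaches: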

def findConnectedNodes_v2(a, node_id = 3):
    b = a[(a[:,0] == node_id) | (a[:,1] == node_id)]
    return b[b!=node_id]

def findConnectedNodes_v3(a, node_id = 3):
    mask2D = a == node_id
    mask1D = mask2D[:,0] | mask2D[:,1]
    return a[mask1D][~mask2D[mask1D]]
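As a quick sanity check, the sketch below re-declares the three functions above and confirms they return the same neighbours in the same order, both on the sample array from the question and on a random array of node pairs. The random data (and the seed) is purely illustrative and only used for this comparison.

```python
import numpy as np

def findConnectedNodes(a, node_id=3):
    b = a[(a == node_id).any(1)]
    return b[b != node_id]

def findConnectedNodes_v2(a, node_id=3):
    b = a[(a[:, 0] == node_id) | (a[:, 1] == node_id)]
    return b[b != node_id]

def findConnectedNodes_v3(a, node_id=3):
    mask2D = a == node_id
    mask1D = mask2D[:, 0] | mask2D[:, 1]
    return a[mask1D][~mask2D[mask1D]]

# Sample array from the question, plus an illustrative random array of pairs.
element = np.array([[3, 2], [1, 4], [1, 3]])
rng = np.random.default_rng(0)
random_pairs = rng.integers(0, 10, size=(1000, 2))

for arr in (element, random_pairs):
    for node in range(10):
        ref = findConnectedNodes(arr, node_id=node)
        # All three variants should agree element for element, in order.
        assert np.array_equal(ref, findConnectedNodes_v2(arr, node_id=node))
        assert np.array_equal(ref, findConnectedNodes_v3(arr, node_id=node))

print(findConnectedNodes_v3(element, node_id=3))  # [2 1]
```

All three keep row-major order, so neighbours come back in the order their pairs appear in the input.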

Upvotes: 3

toine

Reputation: 2026

If you are looking for a function that takes:

data = np.array([[3,2], [1,4], [1,3]])

and returns:

[2, 1]

you can have:

res = []
matches = data[(data == 3).any(axis=1)]
for pair in matches:
    for node in pair:
        if node != 3:
            res.append(node)

or simply:

matches = data[(data == 3).any(axis = 1)]
[node for pair in matches for node in pair if node != 3]

This second form is a double (nested) list comprehension.
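Wrapped as a reusable function, the list-comprehension version might look like the sketch below. The name `connected_nodes` is hypothetical, and the `int(...)` conversion is an addition so that the result prints as plain Python integers rather than NumPy scalars.

```python
import numpy as np

def connected_nodes(data, node):
    """Return the nodes paired with `node`, in the order they appear in `data`."""
    # Keep only the rows (pairs) that contain the target node.
    matches = data[(data == node).any(axis=1)]
    # Flatten those pairs and drop the target node itself.
    return [int(n) for pair in matches for n in pair if n != node]

data = np.array([[3, 2], [1, 4], [1, 3]])
print(connected_nodes(data, 3))  # [2, 1]
```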

Upvotes: -1

Nikolas Rieble

Reputation: 2611

Using NumPy, you can iterate through all elements and keep those that contain the target node. Then take the unique node indices from those elements and remove the target node from that set, for example:

def findconnected_nodes(a, element):
    # Pairs containing node `a` -> sorted unique node indices, minus `a` itself
    return np.setdiff1d(np.unique([i for i in element if a in i]), a)
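A short usage sketch on the array from the question. Note that the node comes first and the array second here, and that `np.setdiff1d` returns sorted unique values, so the result is `[1 2]` rather than `[2, 1]` and repeated neighbours are collapsed. Because `np.setdiff1d` already deduplicates, the inner `np.unique` should not change the result.

```python
import numpy as np

def findconnected_nodes(a, element):
    # Rows of `element` that contain node `a`; `a in row` is True on any match.
    rows = [row for row in element if a in row]
    # Sorted unique node indices from those rows, excluding `a` itself.
    return np.setdiff1d(np.unique(rows), a)

element = np.array([[3, 2], [1, 4], [1, 3]])
print(findconnected_nodes(3, element))  # [1 2]
```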

EDIT

Comparing both valid solutions for performance:

element = np.random.randint(1,10,(10000, 2))
a = 2
%timeit findConnectedNodes(a, element)

10 loops, best of 3: 51.9 ms per loop

%timeit findconnected_nodes(a, element)

100 loops, best of 3: 16.6 ms per loop
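The `%timeit` calls above pass the node first to both functions, whereas the `findConnectedNodes` shown in Divakar's answer takes the array first. A reproducible comparison should call each function with its own argument order. The sketch below does that with the standard `timeit` module. It reports no numbers of its own, since timings depend on the machine and NumPy version, and it does not claim to reproduce the figures above.

```python
import timeit
import numpy as np

def findConnectedNodes(a, node_id=3):
    # Vectorized: select rows containing node_id, then drop node_id itself.
    b = a[(a == node_id).any(1)]
    return b[b != node_id]

def findconnected_nodes(a, element):
    # Python-level loop over rows, then sorted unique neighbours.
    return np.setdiff1d(np.unique([i for i in element if a in i]), a)

element = np.random.randint(1, 10, (10000, 2))
a = 2

# Note the different argument orders: (array, node) vs (node, array).
t_vec = timeit.timeit(lambda: findConnectedNodes(element, node_id=a), number=10) / 10
t_loop = timeit.timeit(lambda: findconnected_nodes(a, element), number=10) / 10
print(f"findConnectedNodes:  {t_vec * 1e3:.3f} ms per call")
print(f"findconnected_nodes: {t_loop * 1e3:.3f} ms per call")
```

The two functions also return different things: the vectorized one keeps every neighbour occurrence in input order, while the `setdiff1d` one returns sorted unique neighbours.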

Upvotes: 2

Ma0

Reputation: 15204

Here is a working solution:

import numpy as np

elements = np.array([[3, 2], [1, 4], [1, 3]])

def findConnectedNodes(lookup):
    connected = [x for y in elements for x in y if lookup in y]
    return list(set(connected) - {lookup,})

print(findConnectedNodes(3))  # [1, 2]

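It works by collecting all the pairs that contain the node we are looking for and flattening them. It then converts the flattened list to a set to remove duplicates, and finally removes the lookup node itself. Since sets are unordered, the neighbours are not guaranteed to come back in the order they appear in the input.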
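If the order from the question ([2, 1]) matters, one option is to deduplicate with `dict.fromkeys`, which keeps insertion order in Python 3.7+, instead of a set. The sketch below also takes the array as a parameter rather than reading a global; the name `find_connected_nodes` is hypothetical.

```python
import numpy as np

def find_connected_nodes(elements, lookup):
    """Neighbours of `lookup`, deduplicated, in first-seen order."""
    # Flatten every pair that contains `lookup` (as plain Python ints).
    connected = [int(x) for pair in elements if lookup in pair for x in pair]
    # dict.fromkeys removes duplicates but, unlike set(), keeps insertion order.
    unique = dict.fromkeys(connected)
    unique.pop(lookup, None)  # drop the lookup node itself
    return list(unique)

elements = np.array([[3, 2], [1, 4], [1, 3]])
print(find_connected_nodes(elements, 3))  # [2, 1]
```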

Upvotes: 0
