Reputation: 181
I have a set of nodes (stored as nodes = np.array([[x1,y1],[x2,y2],...])) and elements connecting the nodes. Each element is an array of two node indices.
So [0,3] is an element connecting nodes 0 and 3.
How would I return a list of all node indices connected to a given node index?
For example, given:
element = np.array([[3, 2], [1, 4], [1, 3]])
print(findConnectedNodes(3))
should print [2, 1].
Upvotes: 1
Views: 1903
Reputation: 221644
Here's a vectorized approach -
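def findConnectedNodes(a, node_id=3):
    # Keep the rows (elements) in which either endpoint equals node_id
    b = a[(a == node_id).any(1)]
    # From those rows, return every endpoint that is not node_id itself
    return b[b != node_id]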
Sample runs -
In [40]: element
Out[40]:
array([[3, 2],
       [1, 4],
       [1, 3]])
In [41]: findConnectedNodes(element, node_id=3)
Out[41]: array([2, 1])
In [42]: findConnectedNodes(element, node_id=1)
Out[42]: array([4, 3])
In [43]: findConnectedNodes(element, node_id=2)
Out[43]: array([3])
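To make the masking steps in findConnectedNodes explicit, here is the same computation broken into named intermediates for the node_id=3 sample. The intermediate names (hits, row_mask, neighbors) are illustrative only.

```python
import numpy as np

element = np.array([[3, 2], [1, 4], [1, 3]])
node_id = 3

# Element-wise comparison: True wherever an entry equals the queried node
hits = element == node_id        # [[True, False], [False, False], [False, True]]
# An element (row) is relevant if either of its two endpoints matches
row_mask = hits.any(1)           # [True, False, True]
# Keep only the matching rows
b = element[row_mask]            # [[3, 2], [1, 3]]
# Boolean indexing with a 2-D mask flattens in row-major order,
# so the other endpoints come out in the order the elements appear
neighbors = b[b != node_id]
print(neighbors)                 # [2 1]
```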
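For performance, since the input is a two-column array of node pairs, we can build the row mask by comparing each column slice separately instead of calling .any(1). We can also reuse a single boolean mask for both the row selection and the final filtering. That gives two more approaches, like so -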
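def findConnectedNodes_v2(a, node_id=3):
    # Row mask built from the two column slices instead of .any(1)
    b = a[(a[:, 0] == node_id) | (a[:, 1] == node_id)]
    return b[b != node_id]

def findConnectedNodes_v3(a, node_id=3):
    # Compare once and reuse the 2-D mask for both selection steps
    mask2D = a == node_id
    mask1D = mask2D[:, 0] | mask2D[:, 1]
    return a[mask1D][~mask2D[mask1D]]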
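All three variants are meant to return identical output, so a quick consistency check on random edges can confirm that before comparing speed. The random edge array (seeded via np.random.default_rng(0), node ids 0 to 9) is an arbitrary test assumption, not data from the question.

```python
import numpy as np

def findConnectedNodes(a, node_id=3):
    b = a[(a == node_id).any(1)]
    return b[b != node_id]

def findConnectedNodes_v2(a, node_id=3):
    b = a[(a[:, 0] == node_id) | (a[:, 1] == node_id)]
    return b[b != node_id]

def findConnectedNodes_v3(a, node_id=3):
    mask2D = a == node_id
    mask1D = mask2D[:, 0] | mask2D[:, 1]
    return a[mask1D][~mask2D[mask1D]]

# Arbitrary random edge list used only for testing
rng = np.random.default_rng(0)
edges = rng.integers(0, 10, size=(1000, 2))

# Every node id should give the same neighbours, in the same order, from all three
for node in range(10):
    ref = findConnectedNodes(edges, node)
    assert np.array_equal(ref, findConnectedNodes_v2(edges, node))
    assert np.array_equal(ref, findConnectedNodes_v3(edges, node))
```

Note that a self-loop element such as [3, 3] matches the row mask but contributes nothing to the output in all three versions, because both of its endpoints equal the queried node.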
Upvotes: 3
Reputation: 2026
If you are looking for a function that takes:
data = np.array([[3,2], [1,4], [1,3]])
and returns:
[2, 1]
you can have:
res = []
# Rows (elements) that contain node 3 in either column
matches = data[(data == 3).any(axis=1)]
for pair in matches:
    for node in pair:
        if node != 3:
            res.append(node)
or simply:
matches = data[(data == 3).any(axis = 1)]
[node for pair in matches for node in pair if node != 3]
See this for more on double (nested) list comprehensions.
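The snippets above hard-code node 3. A minimal sketch of the same idea wrapped in a function with the node as a parameter could look like this; find_connected is a hypothetical name, and the int() call is an added assumption so the result prints as plain Python integers rather than NumPy scalars.

```python
import numpy as np

def find_connected(data, node):
    # Rows (elements) that contain the node in either column
    matches = data[(data == node).any(axis=1)]
    # The nested comprehension reads left to right like the loops:
    # for pair in matches, for n in pair, keep n if n != node
    return [int(n) for pair in matches for n in pair if n != node]

data = np.array([[3, 2], [1, 4], [1, 3]])
print(find_connected(data, 3))  # [2, 1]
```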
Upvotes: -1
Reputation: 2611
Using NumPy, you can iterate through all elements and keep those that contain the target node. Then take the unique node indices among them and remove the target node, for example:
def findconnected_nodes(a, element):
    # a is the target node, element is the (n, 2) array of node pairs
    return np.setdiff1d(np.unique([i for i in element if a in i]), a)
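The same set-difference idea can be sketched without the Python-level loop by selecting the matching rows with a boolean mask. Because np.setdiff1d already returns the sorted unique values of its first argument that are not in the second, the separate np.unique call is not needed. One consequence worth noting: the result is sorted, so the question's example gives [1, 2] rather than [2, 1]. The name findconnected_nodes_masked is hypothetical.

```python
import numpy as np

def findconnected_nodes_masked(a, element):
    # Select the rows that contain node a with a boolean mask
    # instead of looping over every element in Python
    rows = element[(element == a).any(axis=1)]
    # setdiff1d flattens, de-duplicates and sorts, then drops a itself
    return np.setdiff1d(rows, a)

element = np.array([[3, 2], [1, 4], [1, 3]])
print(findconnected_nodes_masked(3, element))  # [1 2]
```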
EDIT
Comparing both valid solutions for performance:
element = np.random.randint(1,10,(10000, 2))
a = 2
%timeit findConnectedNodes(a, element)
10 loops, best of 3: 51.9 ms per loop
%timeit findconnected_nodes(a, element)
100 loops, best of 3: 16.6 ms per loop
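The vectorized findConnectedNodes from the first answer takes the array first and the node second, whereas findconnected_nodes takes the node first. Assuming the timing above refers to that vectorized version, the call findConnectedNodes(a, element) passes its arguments in the opposite order. Here is a sketch of a comparison that calls each function with its own argument order using the standard timeit module. It prints whatever timings your machine produces, and it first checks that both return the same set of neighbours.

```python
import timeit
import numpy as np

def findConnectedNodes(a, node_id=3):
    # Vectorized version from the first answer: array first, node second
    b = a[(a == node_id).any(1)]
    return b[b != node_id]

def findconnected_nodes(a, element):
    # List comprehension + setdiff1d version: node first, array second
    return np.setdiff1d(np.unique([i for i in element if a in i]), a)

element = np.random.randint(1, 10, (10000, 2))
a = 2

# Both should agree on the neighbours (the vectorized one is unsorted, with repeats)
assert np.array_equal(np.unique(findConnectedNodes(element, a)),
                      findconnected_nodes(a, element))

N = 5  # calls per timing run
t_vec = min(timeit.repeat(lambda: findConnectedNodes(element, a), number=N, repeat=3)) / N
t_set = min(timeit.repeat(lambda: findconnected_nodes(a, element), number=N, repeat=3)) / N
print(f"vectorized: {t_vec * 1e3:.3f} ms per call")
print(f"setdiff1d:  {t_set * 1e3:.3f} ms per call")
```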
Upvotes: 2
Reputation: 15204
Here is a working solution:
import numpy as np

elements = np.array([[3, 2], [1, 4], [1, 3]])

def findConnectedNodes(lookup):
    # Flatten every element (pair) that contains the lookup node
    connected = [x for y in elements for x in y if lookup in y]
    # Remove duplicates and the lookup node itself
    return list(set(connected) - {lookup})

print(findConnectedNodes(3))  # [1, 2]
It works by collecting every element (pair) that contains the node we are looking for and flattening them into one list. It then converts that list to a set to remove duplicates, and finally removes the lookup node itself.
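A variant of this set-based approach could take the element array as a parameter instead of reading the global elements, and sort the result, since relying on set iteration order is not a guarantee of any particular ordering. The name find_connected_nodes is hypothetical, and converting to int is an assumption made so the output is plain Python integers.

```python
import numpy as np

def find_connected_nodes(elements, lookup):
    # Set comprehension: for each pair containing lookup, add both endpoints
    connected = {int(x) for pair in elements if lookup in pair for x in pair}
    # Drop the lookup node itself and return a deterministic, sorted list
    return sorted(connected - {lookup})

elements = np.array([[3, 2], [1, 4], [1, 3]])
print(find_connected_nodes(elements, 3))  # [1, 2]
```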
Upvotes: 0