Kira

Reputation: 115

Speed up Python cKDTree

I have a function that connects each blue dot to its (at most) 3 nearest neighbors within a range of 55 pixels. vertices_xy_list is a very large nested list of points, about 5000-10000 pairs.

Example of vertices_xy_list:

[[3673.3333333333335, 2483.3333333333335],
 [3718.6666666666665, 2489.0],
 [3797.6666666666665, 2463.0],
 [3750.3333333333335, 2456.6666666666665],...]

I have written this calculate_draw_vertice_lines() function, which uses a cKDTree inside a while loop to find all points within 55 pixels of the current point and then connects each of them to it with a green line.

Because the tree is rebuilt on every iteration, this becomes much slower as the list gets longer. Is there any way to speed up this function significantly, such as vectorizing the operations?

import math

import matplotlib.pyplot as plt
from scipy import spatial

def calculate_draw_vertice_lines():
    # fig and the three lists below are module-level globals defined elsewhere
    global vertices_xy_list
    global cell_wall_lengths
    global list_of_lines_references

    index = 0
    while True:
        if len(vertices_xy_list) == 1:
            break

        # the tree is rebuilt on every iteration because the list shrinks below
        point_tree = spatial.cKDTree(vertices_xy_list)
        index_of_closest_points = point_tree.query_ball_point(vertices_xy_list[index], 55)
        index_of_closest_points.remove(index)

        for stuff in index_of_closest_points:
            x0, y0 = vertices_xy_list[index]
            x1, y1 = vertices_xy_list[stuff]
            list_of_lines_references.append(plt.plot([x0, x1], [y0, y1], color='green'))
            wall_length = math.sqrt((x0 - x1)**2 + (y0 - y1)**2)
            cell_wall_lengths.append(wall_length)

        del vertices_xy_list[index]

    fig.canvas.draw()


Upvotes: 4

Views: 1388

Answers (1)

hilberts_drinking_problem

Reputation: 11602

If I understand the logic of selecting the green lines correctly, there is no need to create a KDTree at each iteration. For each pair (p1, p2) of blue points, the line should be drawn if and only if the following hold:

  1. p1 is one of the 3 closest neighbors of p2.
  2. p2 is one of the 3 closest neighbors of p1.
  3. dist(p1, p2) < 55.

You can create the KDTree once and create a list of green lines efficiently. Here is part of the implementation that returns a list of pairs of indices for points between which the green lines need to be drawn. The runtime is about 0.5 seconds on my machine for 10,000 points.

import numpy as np
from scipy import spatial


data = np.random.randint(0, 1000, size=(10_000, 2))

def get_green_lines(data):
    tree = spatial.cKDTree(data)
    # each key in g points to indices of 3 nearest blue points
    g = {i: set(tree.query(data[i,:], 4)[-1][1:]) for i in range(data.shape[0])}

    green_lines = list()
    for node, candidates in g.items():
        for node2 in candidates:
            if node2 < node:
                # avoid double-counting
                continue

            if node in g[node2] and spatial.distance.euclidean(data[node,:], data[node2,:]) < 55:
                green_lines.append((node, node2))

    return green_lines
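The per-point queries and the Python loop above can also be expressed with array operations, since cKDTree.query accepts all points in one call. The following is only a sketch of that idea under the same three conditions; the function name mutual_neighbor_pairs and its parameters k and max_dist are made up for illustration, and it assumes the points are distinct so that each point's own index shows up among its k + 1 nearest results.

```python
import numpy as np
from scipy import spatial


def mutual_neighbor_pairs(points, k=3, max_dist=55.0):
    """Return an (m, 2) array of index pairs (i, j), i < j, that are
    mutual k-nearest neighbors and closer than max_dist."""
    points = np.asarray(points, dtype=float)
    n = len(points)
    tree = spatial.cKDTree(points)

    # One batched query; k + 1 because each point is its own nearest neighbor.
    # Missing neighbors come back with distance inf and index n.
    dist, idx = tree.query(points, k=k + 1)

    # Flatten to a list of directed edges src -> dst.
    src = np.repeat(np.arange(n), k + 1)
    dst = idx.ravel()

    # Drop self-matches and anything at or beyond max_dist (this also drops inf).
    keep = (dist.ravel() < max_dist) & (dst != src)
    src, dst = src[keep], dst[keep]

    # Encode each directed edge as one integer so the reverse edge can be
    # looked up with a single membership test.
    forward = src * n + dst
    backward = dst * n + src
    mutual = np.isin(forward, backward) & (src < dst)  # src < dst avoids duplicates

    return np.column_stack((src[mutual], dst[mutual]))
```

The returned array can be used in place of the list of index tuples, for example by iterating over its rows as (i, j) pairs.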

You can proceed to plot green lines as follows:

import matplotlib.pyplot as plt
from matplotlib import collections as mc

green_lines = get_green_lines(data)
fig, ax = plt.subplots()
ax.scatter(data[:, 0], data[:, 1], s=1)
# a single LineCollection draws all segments as one artist
# instead of one plt.plot call per line
lines = [[data[i], data[j]] for i, j in green_lines]
line_collection = mc.LineCollection(lines, color='green')
ax.add_collection(line_collection)
plt.show()
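The function in the question also collects the length of every line in cell_wall_lengths. Given index pairs such as those returned by get_green_lines, the lengths can be computed in one vectorized step instead of calling math.sqrt per line. This is a minimal sketch; the helper name wall_lengths is hypothetical and not part of the original code.

```python
import numpy as np


def wall_lengths(points, pairs):
    """Euclidean length of the segment between each pair of point indices."""
    points = np.asarray(points, dtype=float)
    # reshape keeps this working when the list of pairs is empty
    pairs = np.asarray(pairs, dtype=int).reshape(-1, 2)
    diffs = points[pairs[:, 0]] - points[pairs[:, 1]]
    return np.linalg.norm(diffs, axis=1)
```

For example, cell_wall_lengths = wall_lengths(data, green_lines).tolist() would give a plain Python list like the one the original function builds.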


Upvotes: 2
