Roman Voyt
Roman Voyt

Reputation: 113

How can I improve this solution to make it faster using numpy?

The problem statement:

An unnamed tourist got lost in New York. All he has is a map of M metro stations. The map shows the coordinates of every station and his own coordinates, which he read off a nearby signpost. The tourist is not sure that every station is open, so, just in case, he is looking for the N nearest stations. Like every New Yorker, he moves through the city along the street grid, so distance is measured in city blocks (the Manhattan distance |x1 − x2| + |y1 − y2|). Help the tourist find these stations.

Sample input

5 2
A 1 2
B 4.5 1.2
C 100500 100500
D 100501 100501
E 100502 100502
1 1

Sample output

A B

My code:

import scipy.spatial.distance as d
import math

def find_shortest_N(distance_list, name_list, number_of_stations):
    """Return the names of the ``number_of_stations`` stations nearest the tourist.

    Stations are ordered by ascending distance. Ties keep their input order
    because the sort is stable, so the first minimum still wins.

    Args:
        distance_list: distance of each station to the tourist.
        name_list: station names, aligned index-for-index with
            ``distance_list``.
        number_of_stations: how many names to return. If it exceeds the
            number of stations, every station is returned; if it is zero or
            negative, an empty list is returned.

    Returns:
        A new list of station names. The input lists are not modified.
    """
    # Sorting indices instead of repeatedly popping the minimum is
    # O(M log M) rather than O(N*M), and it leaves the caller's lists intact.
    order = sorted(range(len(distance_list)), key=distance_list.__getitem__)
    # Clamp at 0 so a negative count does not slice from the end.
    return [name_list[i] for i in order[:max(number_of_stations, 0)]]

def calculate_nearest(list_of_coords, tourist_coords):
    """Return the city-block (Manhattan) distance from the tourist to each station.

    Args:
        list_of_coords: sequence of station coordinates, one point per
            station, each with the same number of components as
            ``tourist_coords``.
        tourist_coords: the tourist's coordinates.

    Returns:
        A list of Python floats, one per station, in input order.
    """
    # cdist rejects an empty (0,)-shaped input, so handle "no stations" here.
    if len(list_of_coords) == 0:
        return []
    # A single vectorized cdist call replaces the per-station cityblock()
    # loop. Wrapping the tourist point in a list gives cdist the 2-D (1, k)
    # array it requires. City-block distances are non-negative, so no fabs().
    distances = d.cdist(list_of_coords, [tourist_coords], metric="cityblock")
    return distances[:, 0].tolist()


station_coords = []
station_names = []

input_stations = input("Input a number of stations: ").split()
input_stations = list(map(int, input_stations))

# total number of stations (M) listed on the map
station_M = input_stations[0]

# number of nearest stations the tourist is looking for (N)
stations_wanted_N = input_stations[1]

# Each of the next M lines is "<name> <coord> <coord> ...": store the name
# and the float coordinates in two parallel lists.
for _ in range(station_M):
    name, *coords = input().split()
    station_names.append(name)
    station_coords.append([float(c) for c in coords])

tourist_coordinates = input("Enter tourist position: ").split()
tourist_coordinates = list(map(float, tourist_coordinates))

distance_values = calculate_nearest(station_coords, tourist_coordinates)

result = find_shortest_N(distance_values, station_names, stations_wanted_N)

# Space-separated names on one line, ending with a newline, as in the
# sample output.
print(" ".join(result))

Upvotes: 0

Views: 90

Answers (2)

Daniel F
Daniel F

Reputation: 14409

Use `scipy.spatial.KDTree`; passing `p=1` to `query` selects the Manhattan (city-block) metric:

import numpy as np
from scipy.spatial import KDTree

subway_tree = KDTree(stations_coords)
# p=1 selects the Manhattan (city-block) metric the problem requires.
dist, idx = subway_tree.query(tourist_coords, nbr_wanted, p=1)
# query() returns scalars when k == 1. When k exceeds the number of points,
# it pads the result with index n and an infinite distance. Normalize to
# 1-D and drop that padding before looking up names. A plain list cannot
# be indexed by an ndarray, so use a comprehension.
dist, idx = np.atleast_1d(dist), np.atleast_1d(idx)
nearest_stations = [station_names[i] for i in idx[np.isfinite(dist)]]

Upvotes: 1

xdze2
xdze2

Reputation: 4151

You could also, for example, directly use the cdist function:

import numpy as np
from scipy.spatial.distance import cdist

sample_input = '''
5 2
A 1 2
B 4.5 1.2
C 100500 100500
D 100501 100501
E 100502 100502
1 1
'''

# Parsing the input data:
sample_data = [line.split()
               for line in sample_input.strip().splitlines()]

tourist_coords = np.array(sample_data.pop(), dtype=float)  # takes the last line
nbr_stations, nbr_wanted = [int(n) for n in sample_data.pop(0)]  # takes the first line

stations_coords = np.array([line[1:] for line in sample_data], dtype=float)
stations_names = [line[0] for line in sample_data]

# Computing the distances:
# cdist needs 2-D inputs; reshape(1, -1) works for any number of dimensions.
tourist_coords = tourist_coords.reshape(1, -1)
distance = cdist(stations_coords, tourist_coords, metric='cityblock').ravel()

# Sorting the distances: argsort keeps the work in NumPy, and kind='stable'
# keeps ties in input order, just like Python's sorted().
order = np.argsort(distance, kind='stable')

# Result, space-separated as in the expected output:
result = [stations_names[i] for i in order[:nbr_wanted]]
print(' '.join(result))

Upvotes: 1

Related Questions