Shishir Pandey

Reputation: 832

How to use a user defined metric for nearest neighbors in scikit-learn?

I am using scikit-learn 0.18.dev0. I know that exactly the same question has been asked before here. I tried the answer presented there, but I get the following error:

>>> def mydist(x, y):
...     return np.sum((x-y)**2)
...
>>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])

>>> nbrs = NearestNeighbors(n_neighbors=4, algorithm='ball_tree',
...            metric='pyfunc', func=mydist)

Error message: _init_params() got an unexpected keyword argument 'func'

It looks like this option has been removed. How can I use a user-defined metric in sklearn.neighbors?

Upvotes: 4

Views: 2208

Answers (1)

eickenberg

Reputation: 14377

The proper keyword is metric. Pass the function itself as metric, instead of metric='pyfunc' together with func:

import numpy as np
from sklearn.neighbors import NearestNeighbors

def mydist(x, y):
    return np.sum((x-y)**2)

nn = NearestNeighbors(n_neighbors=4, algorithm='ball_tree', metric=mydist)

X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
nn.fit(X)

This is also mentioned in the docstring in the development version: https://github.com/scikit-learn/scikit-learn/blob/86b1ba72771718acbd1e07fbdc5caaf65ae65440/sklearn/neighbors/unsupervised.py#L48
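To check that the fitted estimator actually uses the callable, you can query it with kneighbors. The following is a minimal sketch, not the only way to do it. It makes one change that is my assumption rather than part of the question: mydist returns the square root of the sum of squares (plain Euclidean distance), because the scikit-learn documentation states that a function used inside a ball tree must be a true metric, and squared Euclidean distance does not satisfy the triangle inequality. The float dtype on X is also just a choice for this sketch.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def mydist(x, y):
    # x and y are 1-D arrays (one sample each); return a scalar distance.
    # Assumption for this sketch: use true Euclidean distance so that the
    # triangle inequality holds, which the ball tree relies on.
    return np.sqrt(np.sum((x - y) ** 2))

X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]], dtype=float)

# Pass the callable itself through the metric keyword.
nn = NearestNeighbors(n_neighbors=4, algorithm='ball_tree', metric=mydist)
nn.fit(X)

# For every row of X, get the distances to and indices of its 4 nearest
# neighbours (each point is its own nearest neighbour at distance 0).
distances, indices = nn.kneighbors(X)
print(distances)
print(indices)
```

Keep in mind that a Python callable is invoked once per pair of points the tree compares, so it will be noticeably slower than a built-in metric string such as 'euclidean'.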

Upvotes: 6
