Rachel Zhang

Reputation: 564

Find the nearest neighbor using caret

I'm fitting a k-nearest neighbor model using R's caret package.

    library(caret)

    set.seed(0)
    y = rnorm(20, 100, 15)
    predictors = matrix(rnorm(80, 10, 5), ncol=4)
    data = data.frame(cbind(y, predictors))
    colnames(data)=c('Price', 'Distance', 'Cost', 'Tax', 'Transport')

I held out one observation as test data and fit the model on the remaining training data.

    id = sample(nrow(data)-1)
    train = data[id, ]
    test = data[-id,]

    knn.model = train(Price~., method='knn', train)
    predict(knn.model, test)

When I display knn.model, it tells me it uses k=9. I would love to know which 9 observations are actually the "nearest" to the test observation. Besides manually calculating the distances, is there an easier way to display the nearest neighbors?

Thanks!

Upvotes: 1

Views: 778

Answers (1)

Carles

Reputation: 2829

When you use knn, each prediction is built from the training points that are nearest to the new observation in terms of the independent variables. Normally this is done with train(Price~., method='knn', train), so that the model chooses the best k based on some criterion (which also takes the dependent variable into account). Since I have not checked whether the fitted R object stores the predicted price for each training observation, I simply used the trained model to predict the expected price for each point (that is, where each point is located in the space).
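
To make the mechanics concrete, here is a minimal base-R sketch of k-nearest-neighbor regression for a single new point. It assumes plain, unscaled Euclidean distance and predicts the unweighted mean price of the k closest training rows, which is my understanding of what caret's knn method does for a numeric outcome when no preProcess is specified. The helper knn_predict_one is hypothetical, not a caret function; it also returns the neighbor indices, which is what the question is after.

```r
# Hypothetical helper (not part of caret): k-NN regression for one query point.
# Assumes numeric predictors, unscaled Euclidean distance and an unweighted mean.
knn_predict_one <- function(X_train, y_train, x_new, k) {
  X_train <- as.matrix(X_train)
  # Squared Euclidean distance from every training row to x_new
  d2 <- rowSums(sweep(X_train, 2, as.numeric(x_new))^2)
  nn <- order(d2)[seq_len(k)]           # indices of the k closest rows
  list(prediction = mean(y_train[nn]),  # kNN regression: mean of the neighbors' outcomes
       neighbors  = nn,
       distances  = sqrt(d2[nn]))
}

# Toy example with two predictors
X <- data.frame(a = c(0, 1, 5, 6), b = c(0, 1, 5, 6))
y <- c(10, 12, 50, 52)
res <- knn_predict_one(X, y, c(0.5, 0.5), k = 2)
res$neighbors   # 1 2
res$prediction  # 11
```

Ties at the k-th distance are broken here by row order; caret's own implementation may handle ties differently.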

In the end, the predicted price is just a summary of all the other variables in a common space: observations that are close to each other are assumed to have similar prices, because the prediction is built from proximity. In summary, you need to do the following:

  1. Get the predicted price for each of the training data points by predicting over them.
  2. Calculate the difference between those predictions and the prediction for your observation of interest, in absolute value (you do not care about the sign, only about the size of the difference).
  3. Take the indices of the N smallest differences (e.g. N = 9) and use them to retrieve the corresponding observations.

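    # Predicted price for the test row and for every training row
    TestPred <- predict(knn.model, newdata = test)
    TrainPred <- predict(knn.model, newdata = train)
    
    # Indices of the 9 training rows (k = 9 here, see knn.model$bestTune$k)
    # whose predicted price is closest to the test prediction
    Nearest9neighbors <- order(abs(TestPred - TrainPred))[1:9]
    
    train[Nearest9neighbors,]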
         Price    Distance      Cost       Tax Transport
    15  95.51177 13.633754  9.725613 13.320678 12.981295
    7   86.07149 15.428847  2.181090  2.874508 14.984934
    19 106.53525 16.191521 -1.119501  5.439658 11.145098
    2   95.10650 11.886978 12.803730  9.944773 16.270416
    4  119.08644 14.020948  5.839784  9.420873  8.902422
    9   99.91349  3.577003 14.160236 11.242063 16.280094
    18  86.62118  7.852434  9.136882  9.411232 17.279942
    11 111.45390  8.821467 11.330687 10.095782 16.496562
    17 103.78335 14.960802 13.091216 10.718857  8.589131
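
The steps above rank training rows by how close their predicted price is to the test prediction, which is not necessarily the same set of rows the model averaged over: those are the rows closest in predictor space. If you want the latter, a short base-R sketch is to compute the distances directly, assuming unscaled Euclidean distance (my understanding of caret's default for method='knn' without preProcess). It regenerates the question's data so it runs on its own; the local names predictor_cols, dists and nearest are just illustrative, and the rows you get depend on your R version's random number generator, so they may not match the table above.

```r
# Rebuild the question's data (base R only; caret is not needed for this part)
set.seed(0)
y <- rnorm(20, 100, 15)
predictors <- matrix(rnorm(80, 10, 5), ncol = 4)
data <- data.frame(y, predictors)
colnames(data) <- c('Price', 'Distance', 'Cost', 'Tax', 'Transport')

id <- sample(nrow(data) - 1)  # a permutation of rows 1..19, so row 20 is the test row
train <- data[id, ]
test  <- data[-id, ]

k <- 9  # the value caret reported; with a fitted model use knn.model$bestTune$k
predictor_cols <- c('Distance', 'Cost', 'Tax', 'Transport')

# Euclidean distance from the test row to every training row, using predictors only
X  <- as.matrix(train[, predictor_cols])
x0 <- unlist(test[1, predictor_cols])
dists <- sqrt(rowSums(sweep(X, 2, x0)^2))

# The k training rows closest to the test row, with their distances
nearest <- order(dists)[seq_len(k)]
neighbors <- train[nearest, ]
neighbors$dist <- dists[nearest]
neighbors
```

If you fit the model with preProcess = c('center', 'scale'), scale the training predictors and the test row the same way before computing distances, otherwise the ranking will not match the one the model used.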
    

Upvotes: 2
