Reputation: 618
I need to program a nearest neighbor algorithm in Stata from scratch, because my dataset does not allow me to use any of the available solutions (as far as I am aware).
To be precise, I have a dataset with a structure similar to the following (the original has around 14k observations):
input id value treatment match
1 0.14 0 .
2 0.32 0 .
3 0.465 1 2
4 0.878 1 2
5 0.912 1 2
6 0.001 1 1
end
I want to generate a variable called match (already included in the example above). For each observation with treatment == 1, match should store the id of the observation with treatment == 0 whose value is closest to the value of that treated observation.
I am new to Stata programming, so I am not yet familiar with the syntax. My first attempt is the following; however, it does not produce any changes to the match variable. I am sure this is a novice question, but I am hoping for some advice on how to get the code running.
EDIT: I have changed the code slightly and now it seems to work. Do you see any problems that may arise if I run it on a bigger dataset?
set more off
clear all
input id pscore treatment
1 0.14 0
2 0.32 0
3 0.465 1
4 0.878 1
5 0.912 1
6 0.001 1
end
gen match = .
forval i = 1/`= _N' {
if treatment[`i'] == 1 {
local dist 1
forvalues j = 1/`= _N' {
if (treatment[`j'] == 0) {
local current_dist (pscore[`i'] - pscore[`j'])^2
if `dist' > `current_dist' {
local dist `current_dist' // update smallest distance
replace match = id[`j'] in `i' // write match
}
}
}
}
}
Upvotes: 2
Views: 6236
Reputation: 11102
Consider some simulated data: 1,000 observations, 200 of them untreated (treat == 0) and the rest treated (treat == 1). The code below is much more efficient than the code originally posted. (As in your code, ties are not explicitly handled.)
clear
set more off
*----- example data -----
set obs 1000
set seed 32956
gen id = _n
gen pscore = runiform()
gen treat = cond(_n <= 200, 0, 1)
*----- new method -----
timer clear
timer on 1
// get id of last non-treated and first treated
// (assumes data are sorted by treat and that id equals the observation number)
bysort treat (id): gen firsttreat = id[1]
local firstt = firsttreat[_N]
local lastnt = `firstt' - 1
// start loop
gen match = .
gen dif = .
quietly forvalues i = `firstt'/`=_N' {
// compute distances
replace dif = (pscore[`i'] - pscore)^2
summarize dif in 1/`lastnt', meanonly
// identify id of minimum-distance observation
replace match = . in 1/`lastnt'
replace match = id in 1/`lastnt' if dif == r(min)
summarize match in 1/`lastnt', meanonly
// save the minimum-distance id
replace match = r(max) in `i'
}
// clean variable and drop
replace match = . in 1/`lastnt'
drop dif firsttreat
timer off 1
tempfile first
save `first'
*----- your method -----
drop match
timer on 2
gen match = .
quietly forval i = 1/`= _N' {
if treat[`i'] == 1 {
local dist 1
forvalues j = 1/`= _N' {
if (treat[`j'] == 0) {
local current_dist (pscore[`i'] - pscore[`j'])^2
if `dist' > `current_dist' {
local dist `current_dist' // update smallest distance
replace match = id[`j'] in `i' // write match
}
}
}
}
}
timer off 2
tempfile second
save `second'
// check for equality of results
cf _all using `first'
// check times
timer list
Execution times in seconds:
. timer list
1: 0.19 / 1 = 0.1930
2: 10.79 / 1 = 10.7900
The difference is huge, especially considering that this dataset has only 1,000 observations.
Interestingly, as the number of untreated cases increases relative to the number of treated, the original method improves, but it never reaches the efficiency of the new method. For example, invert the group sizes so that there are 800 untreated and 200 treated observations (change the data setup to gen treat = cond(_n <= 800, 0, 1)). The result is:
. timer list
1: 0.07 / 1 = 0.0720
2: 4.45 / 1 = 4.4470
The new method also improves and is still much faster. In fact, the relative difference stays roughly the same: the new method is about 56 times faster in the first run and about 62 times faster in the second.
Another way to do this is with joinby or cross. The problem is that they temporarily expand your dataset a lot, because every treated observation has to be paired with every untreated one, so the data grow to (number treated) times (number untreated) observations. In many cases this is not feasible because of memory and the limits Stata places on dataset size (see help limits). You can find an example of joinby here: https://stackoverflow.com/a/19784222/2077064.
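A minimal sketch of the cross approach, assuming the question's variable names (id, pscore, treatment) and its 6-observation example; the names id0, pscore0, dist and controls are illustrative choices, not part of any original code. It builds every treated/untreated pair, so memory use grows with the product of the two group sizes.

```stata
* Sketch: nearest neighbour via -cross- on the question's example data.
* Memory use grows with (number treated) x (number untreated).
clear
input id pscore treatment
1 0.14 0
2 0.32 0
3 0.465 1
4 0.878 1
5 0.912 1
6 0.001 1
end

* save the untreated observations under new (illustrative) variable names
preserve
keep if treatment == 0
keep id pscore
rename (id pscore) (id0 pscore0)
tempfile controls
save `controls'
restore

* pair every treated observation with every untreated one
keep if treatment == 1
cross using `controls'

* keep the closest untreated observation per treated id (smallest id0 on ties)
gen double dist = (pscore - pscore0)^2
bysort id (dist id0): keep if _n == 1
rename id0 match
drop pscore0 dist
list id pscore match
```

The result contains only the treated observations; if you need the untreated ones back, merge 1:1 on id with the original data.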
If there is a large number of treated relative to untreated observations, your code suffers because you go through the body of the outer loop many more times (the first if lets more observations through). Furthermore, each pass through that body means going through another loop, which itself contains two if conditions, _N more times.
In the opposite case, with few treated observations, you go through the whole inner loop only a small number of times, which speeds up your code substantially.
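To make the asymmetry concrete, count how often the question's loops execute. With N observations of which N_t are treated, the outer loop runs N times and the inner loop body runs N_t times N times. Plugging in the two simulated setups (800 and then 200 treated out of 1,000) gives the following rough operation count; it is not a timing model, and the measured times above need not scale by exactly the same factor.

```latex
\underbrace{N}_{\text{outer}} + \underbrace{N_t \, N}_{\text{inner}}:
\qquad 1000 + 800 \times 1000 = 801{,}000
\quad \text{vs.} \quad
1000 + 200 \times 1000 = 201{,}000
```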
My code maintains its efficiency because of the use of in, which generally offers speed gains over if: Stata goes directly to the listed observations, with no logical check needed for each one. Your problem provides an opportunity for that replacement, and it's wise to seize it.
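To see the in versus if difference in isolation, here is a small timing sketch on simulated data; the variable names x and grp, the seed and the data size are arbitrary choices for illustration. Both commands summarize the same 500,000 observations, but the if version has to evaluate grp == 0 for all 1,000,000 observations.

```stata
* Sketch: the same summary computed with -if- and with -in-.
* Assumes the group of interest occupies the first half of the data.
clear
set obs 1000000
set seed 12345
gen double x = runiform()
gen byte grp = _n > 500000          // grp == 0 occupies observations 1/500000

timer clear
timer on 1
summarize x if grp == 0, meanonly   // checks grp == 0 in every observation
timer off 1
scalar mean_if = r(mean)

timer on 2
summarize x in 1/500000, meanonly   // goes straight to observations 1/500000
timer off 2
scalar mean_in = r(mean)

display scalar(mean_if) "  " scalar(mean_in)
timer list
```

Compare the two timer entries on your own machine; the exact numbers will vary, but both means are identical.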
If my code used if where it now uses in, the results would be different. Your code would then be faster when there are many untreated relative to treated observations; again, that is because your code rarely needs to go through the complete inner loop, since the first if short-circuits the outer loop body for every untreated observation, requiring very little work. In the opposite case, my code would still dominate.
The key is to "separate" treated from untreated observations and work on each group using in.
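A sketch of that idea which does not rely on id matching the observation number (the simulated example above assumes it does): sort by treatment, count the untreated observations, and then work only with in ranges. It uses the question's variable names and example data, assumes both groups are non-empty, and keeps the smallest id on ties; cand is a hypothetical helper variable.

```stata
* Sketch: in-based nearest neighbour without assuming id == _n.
clear
input id pscore treatment
1 0.14 0
2 0.32 0
3 0.465 1
4 0.878 1
5 0.912 1
6 0.001 1
end

sort treatment id                 // untreated observations come first
count if treatment == 0
local nc = r(N)                   // untreated occupy observations 1/`nc'
* assumes at least one untreated and one treated observation

gen double dif = .
gen double cand = .               // hypothetical helper: ids at minimum distance
gen double match = .

quietly forvalues i = `=`nc' + 1'/`=_N' {
    * squared distance from treated obs `i' to every untreated obs
    replace dif = (pscore[`i'] - pscore)^2 in 1/`nc'
    summarize dif in 1/`nc', meanonly
    * record the id(s) at the minimum distance, then take the smallest
    replace cand = . in 1/`nc'
    replace cand = id in 1/`nc' if dif == r(min)
    summarize cand in 1/`nc', meanonly
    replace match = r(min) in `i'
}
drop dif cand
list id pscore treatment match
```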
Upvotes: 4