Reputation: 507
Original data:
> dt = data.table(v1 = c(3,1,1,5,6,12,13,11,10,0,2,1,3))
> dt
v1
1: 3
2: 1
3: 1
4: 5
5: 6
6: 12
7: 13
8: 11
9: 10
10: 0
11: 2
12: 1
13: 3
I would like to put v1 into 3 groups based on value as follows:
> dt %>% mutate(group = case_when(v1 <5 ~ 1,
+ v1 >=5 & v1 <10 ~ 2,
+ v1 >= 10 ~3))
v1 group
1 3 1
2 1 1
3 1 1
4 5 2
5 6 2
6 12 3
7 13 3
8 11 3
9 10 3
10 0 1
11 2 1
12 1 1
13 3 1
But I would also like to add a rule where if the total number of rows in a group is under 3, it takes the mean of those rows, and compares it to the rows (of v1) immediately before and after that group, and whichever value is closest to the mean absorbs that group.
In the example above, group 2 only has 2 rows, so I take their mean (5.5) and compare to the value above (1) and below (12). Since the smaller value is closer to the mean, those rows become group 1, making the desired output look as follows:
v1 group
1 3 1
2 1 1
3 1 1
4 5 1
5 6 1
6 12 3
7 13 3
8 11 3
9 10 3
10 0 1
11 2 1
12 1 1
13 3 1
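The comparison for group 2 can be checked by hand. This is a minimal base R sketch of that single step, not a general solution. The names small, m, before, after, dists and winner are illustrative and are not part of the question.

```r
# Values copied from the example data in the question.
v1 <- c(3, 1, 1, 5, 6, 12, 13, 11, 10, 0, 2, 1, 3)

small  <- 4:5                    # rows of the undersized group 2 (v1 = 5, 6)
m      <- mean(v1[small])        # mean of the undersized group
before <- v1[min(small) - 1]     # value in the row just above the group
after  <- v1[max(small) + 1]     # value in the row just below the group

# Absolute distance from the group mean to each neighbouring value.
dists  <- c(before = abs(m - before), after = abs(m - after))
winner <- names(which.min(dists))  # the neighbour with the smaller distance

print(dists)
print(winner)
```

Here the row above wins, so rows 4 and 5 take the group of row 3, which is group 1.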
I've made a few attempts to no avail and would really appreciate a dplyr or data.table solution.
Upvotes: 0
Views: 518
Reputation: 25223
Building on Frank's cut and rleid(ct):
#from Frank's answer
dt[,
    c("ct", "g") := {
        ct <- cut(v1, c(-Inf, 5, 10, Inf), right=FALSE, labels=FALSE)
        .(ct, rleid(ct))
    }
]
#calculate mean
dt[, c("N", "m") := .(.N, m=mean(v1)), by=.(ct, g)]
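#store the v1 and ct values just before/after each run for the rolling join later
#(ct is shifted rather than g, so that new_ct holds a category and not a run id)
ct_dt <- dt[, c(.(ct=ct, g=g), shift(.(v1, ct), c(1L, -1L)))][,
    .(near_v1=c(V3[1L], V4[.N]), new_ct=c(V5[1L], V6[.N])), .(ct, g)]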
#update join for those with less than 3 rows
dt[N<3L, ct := ct_dt[.SD, on=.(ct, g, near_v1=m), roll="nearest", new_ct]]
#delete unwanted columns
dt[, c("g","N","m") := NULL]
output:
v1 ct
1: 3 1
2: 1 1
3: 1 1
4: 5 1
5: 6 1
6: 12 3
7: 13 3
8: 11 3
9: 10 3
10: 0 1
11: 2 1
12: 1 1
13: 3 1
Upvotes: 1
Reputation: 66819
First, compute the original grouping and aggregate:
gDT = dt[, .(.N, m = mean(v1)), by=.(
    ct = ct <- cut(v1, c(-Inf, 5, 10, Inf), right=FALSE, labels=FALSE),
    g = rleid(ct)
)]
ct g N m
1: 1 1 3 1.666667
2: 2 2 2 5.500000
3: 3 3 4 11.500000
4: 1 4 4 1.500000
Flag groups to change and compare m with the nearest unchanging groups above and below:
gDT[, flag := N < 3]
gDT[, res := ct]
gDT[flag == TRUE, res := {
    ffDT = gDT[flag == FALSE]
    # nearest eligible rows going up and down -- possibly NA if at top or bottom
    w_dn = ffDT[.(g = .SD$g - 1L), on=.(g), roll=TRUE, which=TRUE]
    w_up = ffDT[.(g = .SD$g + 1L), on=.(g), roll=-Inf, which=TRUE]
    # diffs of the mean against eligible rows up and down
    diffs = lapply(list(dn = w_dn, up = w_up), function(w) abs(ffDT$m[w] - m))
    # elementwise if/else for whichever is nearer, ties broken in favor of dn
    # (ifelse keeps each flagged group paired with its own neighbors)
    ifelse(diffs$up < diffs$dn, ffDT$ct[w_up], ffDT$ct[w_dn])
}]
ct g N m flag res
1: 1 1 3 1.666667 FALSE 1
2: 2 2 2 5.500000 TRUE 1
3: 3 3 4 11.500000 FALSE 3
4: 1 4 4 1.500000 FALSE 1
Creating a separate table like this makes it easy to check your work (look at flagged groups, check N and ct, compare m with nearest unflagged neighbors, etc).
To add back to the original table, one way is:
dt[, res := gDT$res[ rleid(cut(v1, c(-Inf, 5, 10, Inf), right=FALSE, labels=FALSE)) ] ]
v1 ct res
1: 3 1 1
2: 1 1 1
3: 1 1 1
4: 5 2 1
5: 6 2 1
6: 12 3 3
7: 13 3 3
8: 11 3 3
9: 10 3 3
10: 0 1 1
11: 2 1 1
12: 1 1 1
13: 3 1 1
Details: The steps above are a lot more complicated than those in @RonakShah's answer since I assume that "group" here applies to contiguous rows:
"But I would also like to add a rule where if the total number of rows in a group is under 3, it takes the mean of those rows, and compares it to the rows (of v1) immediately before and after that group, and whichever value is closest to the mean absorbs that group."
Otherwise, the criterion is not well defined -- if there is a group of size 2 but the two rows are not contiguous, there is no "immediately before and after that group" to compare against.
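To illustrate the difference between grouping by value and grouping by contiguous run, here is a small sketch using base R's rle as a stand-in for data.table's rleid (a substitution made only so that the snippet has no package dependency). The names runs and g are illustrative.

```r
# Example data from the question.
v1 <- c(3, 1, 1, 5, 6, 12, 13, 11, 10, 0, 2, 1, 3)

# Value-based category, with the same breaks as in the answer above.
ct <- cut(v1, c(-Inf, 5, 10, Inf), right = FALSE, labels = FALSE)

# Run ids: a new id starts whenever ct changes from one row to the next.
runs <- rle(ct)
g <- rep(seq_along(runs$lengths), runs$lengths)

print(ct)            # category 1 appears in two separate stretches
print(g)             # each contiguous stretch gets its own id
print(runs$lengths)  # size of each contiguous run
```

Category 1 covers seven rows in total, but they form two separate runs, and only a run has a well-defined row "immediately before and after" it.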
Upvotes: 1
Reputation: 389265
One option using dplyr is to add a row column holding row_number(). Then, for each group with fewer than 3 rows, compare the group's mean of v1 with the v1 values in the rows just above and just below it, and assign the group of whichever is closer. Here, change is the final output.
library(dplyr)
dt1 <- dt %>%
  mutate(group = case_when(v1 < 5 ~ 1,
                           v1 >= 5 & v1 < 10 ~ 2,
                           v1 >= 10 ~ 3),
         row = row_number())
dt1 %>%
  group_by(group) %>%
  mutate(change = if (n() < 3) {
    c(dt1$group[first(row) - 1L], dt1$group[last(row) + 1L])[
      which.min(c(abs(mean(v1) - dt1$v1[first(row) - 1L]),
                  abs(mean(v1) - dt1$v1[last(row) + 1L])))]
  } else group)
# v1 group row change
# <dbl> <dbl> <int> <dbl>
# 1 3 1 1 1
# 2 1 1 2 1
# 3 1 1 3 1
# 4 5 2 4 1
# 5 6 2 5 1
# 6 12 3 6 3
# 7 13 3 7 3
# 8 11 3 8 3
# 9 10 3 9 3
#10 0 1 10 1
#11 2 1 11 1
#12 1 1 12 1
#13 3 1 13 1
Upvotes: 1