wolfsatthedoor

Reputation: 7313

How to get mean of nonzero elements by row, varying which columns are used by a condition

Suppose I have the following data table:

  library(data.table)
  tempmat <- matrix(c(1,1,0,4, 1,0,0,4, 0,1,0,4, 0,0,1,4, 0,0,0,5), 5, 4, byrow = TRUE)
  tempmat <- rbind(rep(0, 4), tempmat)   # prepend an all-zero row
  tempmat <- data.table(tempmat)
  setnames(tempmat, paste0('prod1vint', 1:4))

Which looks like:

       prod1vint1 prod1vint2 prod1vint3 prod1vint4
1:          0          0          0          0
2:          1          1          0          4
3:          1          0          0          4
4:          0          1          0          4
5:          0          0          1          4
6:          0          0          0          5

I want to define a new column, TN, that takes the mean row-wise in the following fashion.

  1. For each row, find the first nonzero element going left to right.
  2. Then, find the mean of all nonzero elements to the RIGHT of that.

The output should be:

   prod1vint1 prod1vint2 prod1vint3 prod1vint4   TN
1:          0          0          0          0   NA
2:          1          1          0          4   2.5
3:          1          0          0          4   4
4:          0          1          0          4   4
5:          0          0          1          4   4 
6:          0          0          0          5   NA

The NAs arise because row 1 has no nonzero elements, and row 6 has no nonzero elements to the right of its first nonzero element.

Upvotes: 4

Views: 80

Answers (3)

Frank

Reputation: 66819

You can iterate over the columns, updating only the rows where the current column is nonzero and adding to the sum only after the first nonzero column in that row (with DT being the question's table):

DT[, `:=`(n = 0L, s = 0, v = NA_real_)]
for (k in sprintf("prod1vint%s", 1:4)) 
  DT[get(k) != 0, `:=`(s = s + (n > 0)*get(k), n = n + 1L)]
DT[n > 1L, v := s/(n - 1)][]

   prod1vint1 prod1vint2 prod1vint3 prod1vint4 n s   v
1:          0          0          0          0 0 0  NA
2:          1          1          0          4 3 5 2.5
3:          1          0          0          4 2 4 4.0
4:          0          1          0          4 2 4 4.0
5:          0          0          1          4 2 4 4.0
6:          0          0          0          5 1 0  NA

Because this is vectorized, doesn't coerce to a matrix, and operates selectively, I expect it to be fairly efficient. The get() calls are awkward, but they can be avoided with substitute():

DT[, `:=`(n = 0L, s = 0, v = NA_real_)]
for (k in sprintf("prod1vint%s", 1:4)){ 
  expr = substitute(DT[k != 0, `:=`(s = s + (n > 0)*k, n = n + 1L)], list(k = as.name(k)))
  eval(expr)
}
DT[n > 1L, v := s/(n - 1)][]
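
To see why the division by (n - 1) gives the right mean, it may help to trace the two accumulators for a single row. The sketch below mimics the loop in plain R on one vector; trace_row is a hypothetical helper introduced only for illustration and is not part of the data.table code above.

```r
# Trace the running sum `s` and the nonzero count `n` for one row.
# `trace_row` is a hypothetical helper used only for illustration.
trace_row <- function(x) {
  n <- 0L
  s <- 0
  for (val in x) {
    if (val != 0) {
      # (n > 0) is FALSE for the first nonzero value, so it adds nothing to s
      s <- s + (n > 0) * val
      n <- n + 1L
    }
  }
  # n counts the first nonzero value too, hence the (n - 1) denominator
  v <- if (n > 1L) s / (n - 1) else NA_real_
  c(n = n, s = s, v = v)
}

trace_row(c(1, 1, 0, 4))  # n = 3, s = 5, v = 2.5 (row 2 of the table)
trace_row(c(0, 0, 0, 5))  # n = 1, s = 0, v = NA  (row 6 of the table)
```

The first nonzero value bumps n but contributes nothing to s, so s holds the sum of the later nonzero values and n - 1 is how many there are.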

Upvotes: 0

akrun

Reputation: 887691

Here is one option with melt

library(data.table)
library(dplyr)
TN <- melt(tempmat[, rid := seq_len(.N)], id.vars = 'rid')[, 
    {i1 <- cumsum(value) > 0
    mean(na_if(value[i1][-1], 0), na.rm = TRUE)}, rid]$V1
tempmat[, TN := TN][]

Or using tidyverse

library(tidyverse)
tempmat %>% 
   # pmap_dbl() returns a numeric column; plain pmap() would give a list column
   mutate(TN = pmap_dbl(., ~ c(...) %>% 
           keep(., cumsum(.) > 0) %>%
           tail(-1) %>% 
           na_if(0) %>%
           mean(na.rm = TRUE)))

Or another option is to transpose the dataset and then do the colwise operation

t(tempmat) %>%
    as.data.frame %>% 
    summarise_all(list(~ mean(na_if(.[cumsum(.) > 0], 0)[-1],
          na.rm = TRUE))) %>%
    unlist %>%
    mutate(tempmat, TN = .)

Or using a vectorized approach

library(matrixStats)
m1 <- rowCumsums(as.matrix(tempmat)) > 0
m1[cbind(seq_len(nrow(m1)), max.col(m1, 'first'))] <- FALSE
rowMeans(na_if(tempmat * NA^!m1, 0), na.rm = TRUE)

Or using apply

apply(tempmat, 1, FUN = function(x) 
      mean(na_if(x[cumsum(x) > 0], 0)[-1], na.rm = TRUE))
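
One caveat, as far as I can tell: mean() of an empty vector is NaN, so the apply() variant above should give NaN rather than NA for rows 1 and 6. If a literal NA is wanted, a minimal base-R sketch along the same lines could look like this. Here row_tn is a hypothetical helper name, and the data is rebuilt as a plain matrix only to keep the example self-contained.

```r
# Hypothetical helper: mean of the nonzero values to the right of the
# first nonzero value in x, or NA when there are none.
row_tn <- function(x) {
  x <- x[cumsum(x != 0) > 0]  # drop leading zeros
  x <- x[-1]                  # drop the first nonzero value itself
  x <- x[x != 0]              # keep the nonzero values to its right
  if (length(x) == 0) NA_real_ else mean(x)
}

# The question's data as a plain matrix (no packages needed).
m <- matrix(c(0,0,0,0, 1,1,0,4, 1,0,0,4, 0,1,0,4, 0,0,1,4, 0,0,0,5),
            nrow = 6, ncol = 4, byrow = TRUE)

tn <- apply(m, 1, row_tn)
tn
```

An alternative is to keep any of the snippets above unchanged and post-process the result, for example by replacing the entries where is.nan() is TRUE with NA.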

Upvotes: 2

Ronak Shah

Reputation: 389175

Using apply row-wise, we can first find the indices in each row that are not 0. If there is at least one non-zero value and the first one is not in the last column, we take the mean of the non-zero values to its right; otherwise we return NA.

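tempmat$TN <- apply(tempmat, 1, function(x) {
  inds <- x != 0               # TRUE where the row is non-zero
  first <- which.max(inds)     # position of the first non-zero value
  # need at least one non-zero value, and it must not sit in the last column
  if (any(inds) && first != length(x))
    mean(Filter(function(f) f > 0, x[(first + 1):length(x)]))
  else
    NA
})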

tempmat
#   prod1vint1 prod1vint2 prod1vint3 prod1vint4  TN
#1:          0          0          0          0  NA
#2:          1          1          0          4 2.5
#3:          1          0          0          4 4.0
#4:          0          1          0          4 4.0
#5:          0          0          1          4 4.0
#6:          0          0          0          5  NA

Upvotes: 2
