ph7see

Reputation: 67

Find the mean from the permutations in R

There are balls with values 1 to 3 in a bag. I will draw all three balls at random without replacement. For the first ball, I pay the value of the ball multiplied by 1; for the second ball, its value multiplied by 2; and for the third ball, its value multiplied by 3. For example, if I draw 1, 2, 3, then my total payment is (1*1)+(2*2)+(3*3) = 14. I want to find the mean of all possible total payments.

So I had this code:

library(gtools)

N<-1:3
perms3 <- data.frame(permutations(n = 3, r = 3, v = N))
perms3$total_payment <- perms3$X1 * 1 + perms3$X2 * 2 + perms3$X3 * 3
mean(perms3$total_payment)

I would like to write a general function that works for any N, for example balls with values 1 to 5, or 1 to 10, and so on. I could adapt the code above to calculate the mean total payment like this:

N<-1:5
perms5 <- data.frame(permutations(n = 5, r = 5, v = N))
perms5$total_payment <- perms5$X1 * 1 + perms5$X2 * 2 + perms5$X3 * 3 + perms5$X4 * 4 + perms5$X5 * 5
mean(perms5$total_payment)

But I don't want to rewrite the code by hand for every N. Can you help me solve this problem?

Upvotes: 3

Views: 271

Answers (5)

Joseph Wood

Reputation: 7597

This can be distilled to a constant-time solution using a little math. In short, we are simply finding the expected value of the total payment.

TL;DR

sum(1:n) * (n + 1) / 2

Which is equal to:

(n * (n + 1) / 2) * (n + 1) / 2   =   n * (n + 1)^2 / 4

constantTimeMean <- function(n) n * (n + 1)^2 / 4

constantTimeMean(5)
[1] 45

Explanation

Let (x_1, x_2, ..., x_n) be a permutation of the numbers 1 through n. Multiply each x_i by i and sum like so:

x_1 * 1 + x_2 * 2 + ... + x_n * n

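Since we are taking all permutations, each index i is equally likely to be multiplied by each of the numbers 1 through n, so the expected value drawn at any position is the average of 1 through n, i.e. (n + 1) / 2. We also note that if we remove the coefficients, the sum of each permutation is constant (i.e. sum(1:n)), and the coefficients themselves sum to the same constant. Thus, by linearity of expectation, all we need to do is calculate the average value of 1 through n and multiply it by the sum of 1 through n.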

The closed form expression of the sum of 1 through n is given by:

 (n * (n + 1) / 2)

Multiplying this by the average value, (n + 1) / 2, we obtain:

n * (n + 1)^2 / 4
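
As a quick sanity check of the derivation, here is a minimal sketch that compares the closed form against brute-force enumeration for small n. The helpers all_perms and brute_force_mean are hypothetical names introduced only for this illustration (base R, no packages); constantTimeMean is redefined so the block runs on its own.

```r
# Hypothetical helper: all permutations of v as rows of a matrix (base R only).
all_perms <- function(v) {
  if (length(v) <= 1) return(matrix(v, nrow = 1))
  do.call(rbind, lapply(seq_along(v), function(i) {
    # Fix v[i] in the first column, permute the remaining values
    cbind(v[i], all_perms(v[-i]))
  }))
}

# Hypothetical helper: mean total payment by enumerating every permutation.
brute_force_mean <- function(n) {
  perms <- all_perms(seq_len(n))
  mean(perms %*% seq_len(n))  # weight column i by i, sum each row, average
}

# Closed form from the answer, repeated here so the sketch is self-contained
constantTimeMean <- function(n) n * (n + 1)^2 / 4

ns <- 1:6
brute  <- sapply(ns, brute_force_mean)
closed <- sapply(ns, constantTimeMean)

# Stops with an error if the two approaches disagree for any n in 1:6
stopifnot(isTRUE(all.equal(brute, closed)))
```

Enumeration is only practical for small n here; the point is just to confirm that both approaches agree where they can be compared.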

This is nice, because generating all permutations gets out of hand really fast. For example, what if we set N = 15 or even N = 4321? For N = 15, that's factorial(15) = 1.307674e+12 permutations, so generating them is already out of the question. For N = 4321, factorial(4321) returns Inf; using the gmp package, we see that the true value has over 13000 decimal digits: gmp::log10.bigz(gmp::factorialZ(4321)) ~= 13834.99. However, with the formula above, it's no problem:

system.time(print(constantTimeMean(15)))
[1] 960
user  system elapsed 
   0       0       0


system.time(print(constantTimeMean(4321)))
[1] 20178728641
user  system elapsed 
   0       0       0 
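
When enumeration is infeasible, one way to cross-check the formula is to sample random permutations instead of generating all of them. This is only a rough sketch: sampled_mean is a hypothetical helper, and the seed and the number of draws are arbitrary choices, not anything from the answer above.

```r
set.seed(42)  # arbitrary seed, only for reproducibility

# Hypothetical helper: estimate the mean total payment from random
# permutations rather than from all n! of them.
sampled_mean <- function(n, draws = 10000) {
  weights <- seq_len(n)
  # sample(n) returns one random permutation of 1:n
  totals <- replicate(draws, sum(sample(n) * weights))
  mean(totals)
}

estimate <- sampled_mean(15)
exact    <- 15 * (15 + 1)^2 / 4  # closed form, 960

c(estimate = estimate, exact = exact)
```

The estimate should land close to 960, with the gap shrinking as draws increases. It is a statistical check, not a proof.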

Upvotes: 7

user12728748

Reputation: 8506

If you care about speed, you could try the Rfast implementation:

# fastest previous proposition, for reference  
func <- function(N) {
    Ns <- seq_len(N)
    mean(gtools::permutations(n = N, r = N, v = Ns) %*% matrix(seq_len(N)))
}

# implementation using Rfast
func_u <- function(n){
    sn <- seq_len(n)
    mean(tcrossprod(Rfast::permutation(sn), t(sn)))
}

microbenchmark::microbenchmark(
    f_3 = func(3),
    u3 = func_u(3),
    f_7 = func(7),
    u7 = func_u(7)
)
#> Unit: microseconds
#>  expr       min          lq        mean      median          uq        max neval cld
#>   f_3   168.345    187.7160    661.2309    217.8845    244.7845  44466.821   100  a 
#>    u3    35.434     45.3930    127.6996     52.6240     90.3450   6398.212   100  a 
#>   f_7 47170.752 111422.4390 112419.3058 113008.3590 114360.2590 126243.638   100   b
#>    u7   234.751    271.7305    882.8380    298.1155    336.3765  41195.978   100  a 

Created on 2020-04-09 by the reprex package (v0.3.0)

Upvotes: 2

r2evans

Reputation: 160437

An alternative to Ronak Shah's function:

func <- function(N) {
  Ns <- seq_len(N)
  mean(gtools::permutations(n = N, r = N, v = Ns) %*% matrix(Ns))
}
func(3)
# [1] 12
func(5)
# [1] 45
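
To make the matrix product in func concrete, here is a small sketch for N = 3 with the six permutations written out by hand, so it does not need gtools. The variable names perms, weights and totals are just for illustration.

```r
# The six permutations of 1:3, one per row
perms <- matrix(c(1, 2, 3,
                  1, 3, 2,
                  2, 1, 3,
                  2, 3, 1,
                  3, 1, 2,
                  3, 2, 1),
                ncol = 3, byrow = TRUE)

weights <- matrix(1:3)       # 3 x 1 column: multiplier for each draw position

totals <- perms %*% weights  # 6 x 1 column: one total payment per permutation
drop(totals)
# [1] 14 13 13 11 11 10
mean(totals)
# [1] 12
```

Each row of the product is the weighted sum x_1 * 1 + x_2 * 2 + x_3 * 3 for one permutation, which is exactly the per-row total the question computes column by column.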

This method has the advantage of handling the weighting and summing in a single matrix multiplication. The speed improvement tends to even out for larger N. The benchmark below also includes R.Schifini's suggestion to use apply (as get_mean_b), though in general rowSums is faster than the more generic apply:

microbenchmark::microbenchmark(
  ronak_3  = get_mean(3),
  ronak_3b = get_mean_b(3),
  akrun_3  = akrun(3),
  r2_3     = func(3),
  ronak_5  = get_mean(5),
  ronak_5b = get_mean_b(5),
  akrun_5  = akrun(5),
  r2_5     = func(5),
  ronak_7  = get_mean(7),
  ronak_7b = get_mean_b(7),
  akrun_7  = akrun(7),
  r2_7     = func(7)
)
# Unit: microseconds
#      expr       min         lq       mean     median         uq        max neval
#   ronak_3   438.001   577.5010   684.8250   639.3510   752.7010   1769.601   100
#  ronak_3b   241.901   310.0005   386.5211   352.0010   423.1515   1202.001   100
#   akrun_3   202.601   274.4510   484.4809   297.0005   365.2010  13570.301   100
#      r2_3    87.601   110.4510   132.0599   125.3505   150.9010    218.000   100
#   ronak_5  1338.101  1689.3010  2085.9439  1774.6510  1949.9510  25789.601   100
#  ronak_5b  1208.101  1545.5000  1813.0931  1643.9015  1831.6510   5187.100   100
#   akrun_5  1004.301  1291.5010  1459.4920  1376.2010  1526.7010   3422.901   100
#      r2_5   924.601  1097.8510  1334.1570  1161.7510  1308.2010   5304.501   100
#   ronak_7 35273.101 46720.0505 59103.9000 54075.6015 64263.3005 118192.401   100
#  ronak_7b 43330.700 56615.3005 70568.5350 62788.4515 74308.0505 213410.001   100
#   akrun_7 34402.701 44957.6015 57026.5051 52982.6010 62273.2010 131092.001   100
#      r2_7 35018.401 43930.4510 58400.5710 51515.6510 61678.9510 167691.602   100

Upvotes: 2

akrun

Reputation: 887078

We can use crossprod:

library(gtools)

get_mean <- function(n) {
    perms <- data.frame(permutations(n = n, r = n, v = seq_len(n)))
    # crossprod(t(perms), w) is the same as perms %*% w
    mean(crossprod(t(perms), seq_len(n)))
}
get_mean(3)
#[1] 12
get_mean(5)
#[1] 45

Upvotes: 1

Ronak Shah

Reputation: 388962

You could write a function to calculate that.

library(gtools)

get_mean <- function(n) {
   perms <- data.frame(permutations(n = n, r = n, v = seq_len(n)))
   mean(rowSums(perms * as.list(seq_len(n))))
}

get_mean(3)
#[1] 12

get_mean(5)
#[1] 45
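
The key step in get_mean is perms * as.list(seq_len(n)): when a data frame is multiplied by a list with one element per column, each column is multiplied by the matching list element. A tiny sketch on a made-up two-row data frame (the values are arbitrary, chosen only to show the column-wise scaling):

```r
# Two arbitrary rows, three columns named like the gtools output
df <- data.frame(X1 = c(1, 2), X2 = c(3, 4), X3 = c(5, 6))

# One list element per column: X1 * 1, X2 * 2, X3 * 3
scaled <- df * as.list(1:3)
scaled

# Row totals of the scaled columns: 1 + 6 + 15 and 2 + 8 + 18
rowSums(scaled)
```

The list must have exactly as many elements as the data frame has columns, which is why get_mean passes as.list(seq_len(n)) for an n-column data frame.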

Upvotes: 2
