I have been using the DirichletMultinomial R package to cluster a dataset. I would now like to use the fitted model to predict cluster membership for another dataset. As a sanity check, I first ran predict on the original dataset, and the result surprised me.
On the Twins dataset included in the DirichletMultinomial package, mixture and predict give almost identical results:
> library(DirichletMultinomial)
> library(parallel)  # provides mclapply
> fl <- system.file(package="DirichletMultinomial", "extdata","Twins.csv")
> count <- t(as.matrix(read.csv(fl, row.names=1)))
> fit <- mclapply(1:7, dmn, count=count, verbose=TRUE)
> best <- fit[[4]]
> head(mixture(best))
                [,1]         [,2]         [,3]         [,4]
TS1.2   9.999914e-01 8.533980e-06 3.306707e-08 2.117430e-11
TS10.2  3.775494e-08 9.996731e-01 2.847409e-10 3.268341e-04
TS100.2 7.215444e-09 7.954710e-13 1.174677e-01 8.825323e-01
TS100   5.973009e-01 7.881327e-02 3.238844e-01 1.413423e-06
TS101.2 1.110644e-13 2.606683e-19 1.472201e-06 9.999985e-01
TS103.2 3.736814e-04 9.996260e-01 4.361933e-10 3.245371e-07
> head(predict(best,count))
                [,1]         [,2]         [,3]         [,4]
TS1.2   9.999914e-01 8.533549e-06 3.306516e-08 2.116933e-11
TS10.2  3.774837e-08 9.996732e-01 2.847388e-10 3.268028e-04
TS100.2 7.214920e-09 7.953494e-13 1.174489e-01 8.825511e-01
TS100   5.972792e-01 7.881798e-02 3.239015e-01 1.413737e-06
TS101.2 1.110848e-13 2.606804e-19 1.472259e-06 9.999985e-01
TS103.2 3.736395e-04 9.996260e-01 4.361901e-10 3.244978e-07
But on my own dataset there are discrepancies, especially in the last two rows shown (S5 and S6), and I do not understand where they come from. Maybe I have misunderstood how the method works, but I would have expected mixture and predict to give approximately the same results. What am I missing?
> fit <- mclapply(1:7, dmn, count=data_mat, verbose=TRUE)
> best <- fit[[5]]
> head(mixture(best))
           [,1]         [,2]         [,3]        [,4]         [,5]
S1 1.478612e-05 9.982146e-01 3.268764e-04 0.001440704 3.017367e-06
S2 1.434083e-04 1.830344e-05 1.108533e-05 0.999827203 8.179015e-11
S3 4.216772e-13 7.048083e-09 1.467059e-01 0.853294032 6.232012e-08
S4 1.099027e-07 8.253738e-01 1.486238e-01 0.025929117 7.317699e-05
S5 3.990831e-09 4.338591e-02 5.659807e-01 0.347391684 4.324167e-02
S6 6.712974e-01 1.970030e-04 2.243382e-08 0.328505564 6.026012e-11
> head(predict(best,data_mat))
           [,1]         [,2]         [,3]       [,4]         [,5]
S1 1.694860e-05 9.982972e-01 2.966473e-04 0.00138649 2.690955e-06
S2 1.551721e-04 1.844655e-05 1.091220e-05 0.99981547 7.825475e-11
S3 4.472069e-13 6.991368e-09 1.426672e-01 0.85733272 6.133705e-08
S4 1.439618e-07 8.320968e-01 1.405944e-01 0.02723980 6.891784e-05
S5 1.478314e-08 2.397801e-02 1.150241e-01 0.84848208 1.251575e-02
S6 8.047456e-01 2.614642e-04 7.240597e-09 0.19499295 1.107943e-10
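To put a number on the discrepancy, here is a minimal base-R sketch that compares two posterior matrices. The helper compare_posteriors is my own hypothetical function, not part of DirichletMultinomial, and the two small matrices just re-enter the S5 and S6 rows printed above.

```r
# Hypothetical helper (not part of DirichletMultinomial): compare two
# posterior-probability matrices with samples in rows, components in columns.
compare_posteriors <- function(p_fit, p_pred) {
  stopifnot(identical(dim(p_fit), dim(p_pred)))
  # Hard assignment = index of the most probable component per sample
  hard_fit  <- apply(p_fit,  1, which.max)
  hard_pred <- apply(p_pred, 1, which.max)
  list(
    max_abs_diff = max(abs(p_fit - p_pred)),            # largest elementwise gap
    changed      = names(which(hard_fit != hard_pred))  # samples whose cluster flips
  )
}

# S5 and S6 rows copied from the mixture(best) output above
p_fit <- rbind(
  S5 = c(3.990831e-09, 4.338591e-02, 5.659807e-01, 0.347391684, 4.324167e-02),
  S6 = c(6.712974e-01, 1.970030e-04, 2.243382e-08, 0.328505564, 6.026012e-11)
)
# S5 and S6 rows copied from the predict(best, data_mat) output above
p_pred <- rbind(
  S5 = c(1.478314e-08, 2.397801e-02, 1.150241e-01, 0.84848208, 1.251575e-02),
  S6 = c(8.047456e-01, 2.614642e-04, 7.240597e-09, 0.19499295, 1.107943e-10)
)

res <- compare_posteriors(p_fit, p_pred)
print(res)
```

On these two rows, the most probable component for S5 moves from 3 to 4, while S6 stays in component 1 with a noticeably different probability. On the full objects the same check would be compare_posteriors(mixture(best), predict(best, data_mat)), assuming both matrices have the samples in the same row order.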