Manoj Selvam

Reputation: 77

How to integrate ALS in my spark pipeline to implement Non-negative matrix factorization?

I'm using Spark MLlib to train a Naive Bayes classifier model. I create a pipeline that indexes my string features, then normalizes them and applies PCA for dimensionality reduction, after which I train the Naive Bayes model. When I run the pipeline I get negative values in the PCA components vector. On googling I found that I should apply NMF (non-negative matrix factorization) to obtain non-negative vectors, and that ALS implements NMF via the method .setNonnegative(true), but I don't know how to integrate ALS into my pipeline after PCA. Any help appreciated. Thanks.

Here is the code:

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.NaiveBayes;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.Normalizer;
import org.apache.spark.ml.feature.PCA;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;

public class NBTrainPCA {
    public static void main(String args[]){
        try{
            SparkConf conf = new SparkConf().setAppName("NBTrain");
            SparkContext scc = new SparkContext(conf);
            scc.setLogLevel("ERROR");
            JavaSparkContext sc = new JavaSparkContext(scc);
            SQLContext sqlc = new SQLContext(scc);
            DataFrame traindata = sqlc.read().format("parquet").load(args[0]).filter("user_email!='NA' and user_email!='00' and user_email!='0ed709b5bec77b6bff96ea5b5e334a8e5' and user_email is not null  and ip is not null  and region_code is not null and city is not null and browser_name is not null and os_name is not null");
            traindata.registerTempTable("master");
            //DataFrame data = sqlc.sql("select user_email,user_device,ip,country_code,region_code,city,zip_code,time_zone,browser_name,browser_manf,os_name,os_manf from master where user_email!='NA' and user_email is not null and user_device is not null and ip is not null and country_code is not null and region_code is not null and city is not null and browser_name is not null and browser_manf is not null and zip_code is not null and time_zone is not null and os_name is not null and os_manf is not null");
            StringIndexerModel emailIndexer = new StringIndexer()
              .setInputCol("user_email")
              .setOutputCol("email_index")
              .setHandleInvalid("skip")
              .fit(traindata);
            StringIndexer udevIndexer = new StringIndexer()
              .setInputCol("user_device")
              .setOutputCol("udev_index")
              .setHandleInvalid("skip");
            StringIndexer ipIndexer = new StringIndexer()
              .setInputCol("ip")
              .setOutputCol("ip_index")
              .setHandleInvalid("skip");
            StringIndexer ccodeIndexer = new StringIndexer()
              .setInputCol("country_code")
              .setOutputCol("ccode_index")
              .setHandleInvalid("skip");
            StringIndexer rcodeIndexer = new StringIndexer()
              .setInputCol("region_code")
              .setOutputCol("rcode_index")
              .setHandleInvalid("skip");
            StringIndexer cyIndexer = new StringIndexer()
              .setInputCol("city")
              .setOutputCol("cy_index")
              .setHandleInvalid("skip");
            StringIndexer zpIndexer = new StringIndexer()
              .setInputCol("zip_code")
              .setOutputCol("zp_index")
              .setHandleInvalid("skip");
            StringIndexer tzIndexer = new StringIndexer()
              .setInputCol("time_zone")
              .setOutputCol("tz_index")
              .setHandleInvalid("skip");
            StringIndexer bnIndexer = new StringIndexer()
              .setInputCol("browser_name")
              .setOutputCol("bn_index")
              .setHandleInvalid("skip");
            StringIndexer bmIndexer = new StringIndexer()
              .setInputCol("browser_manf")
              .setOutputCol("bm_index")
              .setHandleInvalid("skip");
            StringIndexer bvIndexer = new StringIndexer()
              .setInputCol("browser_version")
              .setOutputCol("bv_index")
              .setHandleInvalid("skip");
            StringIndexer onIndexer = new StringIndexer()
              .setInputCol("os_name")
              .setOutputCol("on_index")
              .setHandleInvalid("skip");
            StringIndexer omIndexer = new StringIndexer()
              .setInputCol("os_manf")
              .setOutputCol("om_index")
              .setHandleInvalid("skip");
            VectorAssembler assembler = new VectorAssembler()
              .setInputCols(new String[]{ "udev_index","ip_index","ccode_index","rcode_index","cy_index","zp_index","tz_index","bn_index","bm_index","bv_index","on_index","om_index"})
              .setOutputCol("ffeatures");
            Normalizer normalizer = new Normalizer()
              .setInputCol("ffeatures")
              .setOutputCol("sfeatures")
              .setP(1.0);
            PCA pca = new PCA()
                .setInputCol("sfeatures")
                .setOutputCol("pcafeatures")
                .setK(5);
            NaiveBayes nbcl = new NaiveBayes()
            .setFeaturesCol("pcafeatures")
            .setLabelCol("email_index")
            .setSmoothing(1.0);
            IndexToString is = new IndexToString()
            .setInputCol("prediction")
            .setOutputCol("op")
            .setLabels(emailIndexer.labels());
            Pipeline pipeline = new Pipeline()
              .setStages(new PipelineStage[] {emailIndexer,udevIndexer,ipIndexer,ccodeIndexer,rcodeIndexer,cyIndexer,zpIndexer,tzIndexer,bnIndexer,bmIndexer,bvIndexer,onIndexer,omIndexer,assembler,normalizer,pca,nbcl,is});
            PipelineModel model = pipeline.fit(traindata);
            //DataFrame chidata = model.transform(data);
            //chidata.write().format("com.databricks.spark.csv").save(args[1]);
            model.write().overwrite().save(args[1]);
            sc.close();
            }
            catch(Exception e){
                // Don't swallow failures silently; at least log them.
                e.printStackTrace();
            }
    }
}

Upvotes: 0

Views: 935

Answers (1)

Dr VComas

Reputation: 735

I would suggest reading a bit about PCA so you get a better feel for what it is doing. Here are some links:

https://stats.stackexchange.com/questions/26352/interpreting-positive-and-negative-signs-of-the-elements-of-pca-eigenvectors

https://stats.stackexchange.com/questions/2691/making-sense-of-principal-component-analysis-eigenvectors-eigenvalues

On integrating ALS into your pipeline: it seems like you just want to plug one thing after the other. It's better to understand what each stage does and what it is used for, because ALS and PCA are quite different things. ALS performs matrix factorization, using alternating least squares to minimize the reconstruction error; it does not find principal components to transform the data or reduce its dimensionality.
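To make the mismatch concrete, here is a minimal sketch of how ALS is actually configured in Spark ML. Note that it is a recommender estimator expecting (user, item, rating) columns, not a feature-vector column, which is why it cannot simply follow PCA as a pipeline stage; the column names below are illustrative, not from your data.

```java
import org.apache.spark.ml.recommendation.ALS;

// ALS factorizes a user-item rating matrix; it does not transform a
// feature-vector column the way PCA or Normalizer do.
ALS als = new ALS()
  .setUserCol("userId")      // hypothetical integer user id column
  .setItemCol("itemId")      // hypothetical integer item id column
  .setRatingCol("rating")    // hypothetical rating column
  .setRank(10)
  .setNonnegative(true);     // constrains both factor matrices to be non-negative
```

So .setNonnegative(true) makes the *learned factor matrices* non-negative; it says nothing about the sign of PCA output.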

BTW: I do not see any problem with getting negative values in the PCA components vector; you can check this in the links above. You are applying a linear transformation to the data, so the new vectors are simply the result of that transformation. I hope it helps.
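That said, if the underlying goal is just to feed non-negative features into NaiveBayes (which does require them), one option is to rescale the PCA output instead of factorizing it. A sketch, assuming you insert this stage between `pca` and `nbcl` in your `setStages(...)` array:

```java
import org.apache.spark.ml.feature.MinMaxScaler;

// MinMaxScaler rescales each component into [0, 1] by default, so the
// resulting vectors are non-negative and safe for NaiveBayes.
MinMaxScaler scaler = new MinMaxScaler()
  .setInputCol("pcafeatures")   // the PCA output column from your pipeline
  .setOutputCol("nnfeatures");  // rescaled, non-negative features
// ...and NaiveBayes would then use .setFeaturesCol("nnfeatures")
```

Whether rescaled principal components are a sensible input for a multinomial Naive Bayes model is a separate question worth thinking about, for the same reasons discussed in the links above.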

Upvotes: 0
