mkproject
mkproject

Reputation: 21

Prediction with Apache Spark ML

I'm new to Apache Spark ML. I would like to get a prediction for a balance by age and country. As input I have a CSV file in the following format:

RowNumber,Age,Country,Balance

The model is built and can also be trained against the test data. Everything works so far.

My problem now is when I want to make a prediction for a new customer record

Dataset<Row> newCustomer = spark.createDataFrame(Collections.singletonList(
            new Customer(28, ‘Germany’)), Customer.class);
Dataset<Row> newCustomerPrediction = model.transform(newCustomer);

I get the following error message :

java.lang.IllegalArgumentException: CountryIndex does not exist. Available: age, country.

How can I get a prediction for the new dataset?

public static void main(String[] args) {

    SparkSession spark = SparkSession
        .builder()
        .master("local[*]") 
        .appName("JavaGeneralizedLinearRegressionExample")
        .getOrCreate();

    Dataset<Row> data = spark.read()
            .option("header", "true")
            .option("inferSchema", "true")
            .option("delimiter", ",")  // oder "," je nach Dateiformat
            .csv("/data/testdaten_v4.csv");

    StringIndexer countryIndexer = new StringIndexer()
            .setInputCol("Country")
            .setOutputCol("CountryIndex")
            .setHandleInvalid("skip");
    OneHotEncoder countryEncoder = new OneHotEncoder()
            .setInputCol("CountryIndex")
            .setOutputCol("CountryVec");

    VectorAssembler assembler = new VectorAssembler()
            .setInputCols(new String[]{"Age", "CountryVec"}) // andere Features hinzufügen falls nötig
            .setOutputCol("features");

    StandardScaler scaler = new StandardScaler()
            .setInputCol("features")
            .setOutputCol("scaledFeatures");

    LinearRegression lr = new LinearRegression()
            .setLabelCol("Balance")
            .setFeaturesCol("scaledFeatures")
            .setMaxIter(100)
            .setRegParam(0.3)
            .setElasticNetParam(0.8);

    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{countryIndexer, countryEncoder, assembler, scaler, lr});

     PipelineModel model = pipeline.fit(data);
     
    Dataset<Row>[] splits = data.randomSplit(new double[]{0.8, 0.2}, 42);
    Dataset<Row> trainData = splits[0];
    Dataset<Row> testData = splits[1];

    Dataset<Row> predictions = model.transform(testData);
    predictions.select("Age", "Country", "Balance", "prediction").show();
    
    RegressionEvaluator evaluator = new RegressionEvaluator()
            .setLabelCol("Balance")
            .setPredictionCol("prediction")
            .setMetricName("rmse");
    double rmse = evaluator.evaluate(predictions);
    
    
    Dataset<Row> newCustomer = spark.createDataFrame(Collections.singletonList(
            new Customer(28, "Germany")), Customer.class);
    Dataset<Row> newCustomerPrediction = model.transform(newCustomer);
    newCustomerPrediction.select("prediction").show();

    spark.stop();
}

public static class Customer {
    private int Age;
    private String Country;
     
    public Customer(int age, String country) {
        this.Age = age;
        this.Country = country;
    }

    public int getAge() { return Age; }
    public String getCountry() { return Country; } 
}

Upvotes: 0

Views: 24

Answers (0)

Related Questions