Reputation: 21
I'm new to Apache Spark ML. I would like to predict a customer's balance from their age and country. As input I have a CSV file with the following columns:
RowNumber,Age,Country,Balance
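For illustration, a few made-up rows in this format (not my real data) could look like this:

RowNumber,Age,Country,Balance
1,34,Germany,12500.0
2,41,France,8300.5
3,29,Spain,15100.0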
The model is built and can also be evaluated against the test data. Everything works so far.
My problem arises when I want to make a prediction for a new customer record:
Dataset<Row> newCustomer = spark.createDataFrame(Collections.singletonList(
new Customer(28, "Germany")), Customer.class);
Dataset<Row> newCustomerPrediction = model.transform(newCustomer);
I get the following error message:
java.lang.IllegalArgumentException: CountryIndex does not exist. Available: age, country.
How can I get a prediction for the new dataset? My complete code is below.
import java.util.Collections;

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.feature.OneHotEncoder;
import org.apache.spark.ml.feature.StandardScaler;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public static void main(String[] args) {
SparkSession spark = SparkSession
.builder()
.master("local[*]")
.appName("JavaGeneralizedLinearRegressionExample")
.getOrCreate();
Dataset<Row> data = spark.read()
.option("header", "true")
.option("inferSchema", "true")
.option("delimiter", ",") // oder "," je nach Dateiformat
.csv("/data/testdaten_v4.csv");
// Feature pipeline: Country -> CountryIndex -> CountryVec, then assembled with Age and scaled
StringIndexer countryIndexer = new StringIndexer()
.setInputCol("Country")
.setOutputCol("CountryIndex")
.setHandleInvalid("skip");
OneHotEncoder countryEncoder = new OneHotEncoder()
.setInputCol("CountryIndex")
.setOutputCol("CountryVec");
VectorAssembler assembler = new VectorAssembler()
.setInputCols(new String[]{"Age", "CountryVec"}) // add further features here if needed
.setOutputCol("features");
StandardScaler scaler = new StandardScaler()
.setInputCol("features")
.setOutputCol("scaledFeatures");
LinearRegression lr = new LinearRegression()
.setLabelCol("Balance")
.setFeaturesCol("scaledFeatures")
.setMaxIter(100)
.setRegParam(0.3)
.setElasticNetParam(0.8);
Pipeline pipeline = new Pipeline()
.setStages(new PipelineStage[]{countryIndexer, countryEncoder, assembler, scaler, lr});
// Split the data first, then fit the pipeline on the training portion only
Dataset<Row>[] splits = data.randomSplit(new double[]{0.8, 0.2}, 42);
Dataset<Row> trainData = splits[0];
Dataset<Row> testData = splits[1];
PipelineModel model = pipeline.fit(trainData);
Dataset<Row> predictions = model.transform(testData);
predictions.select("Age", "Country", "Balance", "prediction").show();
RegressionEvaluator evaluator = new RegressionEvaluator()
.setLabelCol("Balance")
.setPredictionCol("prediction")
.setMetricName("rmse");
double rmse = evaluator.evaluate(predictions);
System.out.println("RMSE on test data = " + rmse);
// Prediction for a new customer record (this is where the error occurs)
Dataset<Row> newCustomer = spark.createDataFrame(Collections.singletonList(
new Customer(28, "Germany")), Customer.class);
Dataset<Row> newCustomerPrediction = model.transform(newCustomer);
newCustomerPrediction.select("prediction").show();
spark.stop();
}
public static class Customer {
private int Age;
private String Country;
public Customer(int age, String country) {
this.Age = age;
this.Country = country;
}
public int getAge() { return Age; }
public String getCountry() { return Country; }
}
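For reference, the columns Spark derives from the Customer bean can be inspected with a small diagnostic sketch (using the newCustomer dataset from the code above):

// Diagnostic only: show the columns produced from the Customer bean.
// The getters getAge()/getCountry() yield the properties "age" and "country",
// which matches the "Available: age, country" part of the error message.
newCustomer.printSchema();
newCustomer.show();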
Upvotes: 0
Views: 24