HHH

Reputation: 6475

How to save the result of a machine learning model into a file in Spark

I've used KMeans from Spark MLlib, as the code below shows, but I don't know how to save the results of the clustering to a file. For example, how can I store the cluster centers in a file, or store the points that belong to the same cluster?

import java.util.regex.Pattern;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

public final class JavaKMeans {

  // Parses one line of space-separated numbers into a dense MLlib vector.
  private static class ParsePoint implements Function<String, Vector> {
    private static final Pattern SPACE = Pattern.compile(" ");

    @Override
    public Vector call(String line) {
      String[] tok = SPACE.split(line);
      double[] point = new double[tok.length];
      for (int i = 0; i < tok.length; ++i) {
        point[i] = Double.parseDouble(tok[i]);
      }
      return Vectors.dense(point);
    }
  }

  public static void main(String[] args) {
    if (args.length < 3) {
      System.err.println(
        "Usage: JavaKMeans <input_file> <k> <max_iterations> [<runs>]");
      System.exit(1);
    }
    String inputFile = args[0];
    int k = Integer.parseInt(args[1]);
    int iterations = Integer.parseInt(args[2]);
    int runs = 1;

    if (args.length >= 4) {
      runs = Integer.parseInt(args[3]);
    }
    SparkConf sparkConf = new SparkConf().setAppName("JavaKMeans");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);
    JavaRDD<String> lines = sc.textFile(inputFile);

    // Parse each input line into a point.
    JavaRDD<Vector> points = lines.map(new ParsePoint());

    // Train the model using k-means|| initialization.
    KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs, KMeans.K_MEANS_PARALLEL());

    System.out.println("Cluster centers:");
    for (Vector center : model.clusterCenters()) {
      System.out.println(" " + center);
    }
    // Sum of squared distances from each point to its nearest center.
    double cost = model.computeCost(points.rdd());
    System.out.println("Cost: " + cost);

    sc.stop();
  }
}

Upvotes: 0

Views: 1240

Answers (1)

Thomas Jungblut

Reputation: 20969

Each cluster center is an MLlib Vector, so you can call its toArray() method to get the underlying double[] and write that array to a file.

A simple Java 8 version might look like this (Files.write throws IOException, so the enclosing method must handle or declare it):

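import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.stream.Collectors;

// clusterCenters() returns a Vector[] array, so wrap it with Arrays.stream()
Files.write(
     Paths.get("some_file.txt"),
     Arrays.stream(model.clusterCenters())
       .map(a -> Arrays.toString(a.toArray()))
       .collect(Collectors.toList())
 );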

Without Java 8 streams (Java 7), the same can be written as:

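import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

List<String> lines = new ArrayList<>();
for (Vector center : model.clusterCenters()) {
    lines.add(Arrays.toString(center.toArray()));
}
Files.write(Paths.get("some_file.txt"), lines, Charset.defaultCharset());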
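Storing the points that belong to each cluster (the second part of the question) needs one more step: pairing every point with the index of its nearest center. In Spark that index comes from `model.predict(...)`, which also accepts a `JavaRDD<Vector>` and returns a `JavaRDD<Integer>`, and the labeled points can then be written with `saveAsTextFile`. The sketch below shows the same idea in plain Java with no Spark dependency, assuming the default Euclidean distance. `ClusterFileWriter`, `nearestCenter`, `groupByCluster`, `writeClusters` and the `cluster-<i>.txt` file names are all hypothetical, and the centers and points in `main` are made-up stand-ins for `model.clusterCenters()` and the parsed input.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical helper class, not part of Spark.
public class ClusterFileWriter {

    // Index of the closest center by squared Euclidean distance
    // (ranks centers the same way as plain Euclidean distance).
    static int nearestCenter(double[] point, double[][] centers) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int c = 0; c < centers.length; c++) {
            double dist = 0.0;
            for (int i = 0; i < point.length; i++) {
                double d = point[i] - centers[c][i];
                dist += d * d;
            }
            if (dist < bestDist) {
                bestDist = dist;
                best = c;
            }
        }
        return best;
    }

    // Groups points by cluster index; TreeMap keeps clusters in index order.
    static Map<Integer, List<double[]>> groupByCluster(List<double[]> points, double[][] centers) {
        Map<Integer, List<double[]>> clusters = new TreeMap<>();
        for (double[] p : points) {
            clusters.computeIfAbsent(nearestCenter(p, centers), k -> new ArrayList<>()).add(p);
        }
        return clusters;
    }

    // Writes one file per non-empty cluster (cluster-0.txt, ...), one point per line.
    static void writeClusters(Map<Integer, List<double[]>> clusters, Path dir) throws IOException {
        Files.createDirectories(dir);
        for (Map.Entry<Integer, List<double[]>> e : clusters.entrySet()) {
            List<String> lines = new ArrayList<>();
            for (double[] p : e.getValue()) {
                lines.add(Arrays.toString(p));
            }
            Files.write(dir.resolve("cluster-" + e.getKey() + ".txt"), lines);
        }
    }

    public static void main(String[] args) throws IOException {
        // Made-up data standing in for model.clusterCenters() and the input points.
        double[][] centers = {{0.0, 0.0}, {10.0, 10.0}};
        List<double[]> points = Arrays.asList(
            new double[]{0.5, 0.2}, new double[]{9.0, 11.0}, new double[]{-1.0, 0.0});

        Map<Integer, List<double[]>> clusters = groupByCluster(points, centers);
        Path out = Files.createTempDirectory("kmeans-clusters");
        writeClusters(clusters, out);

        // Print each written file: cluster 0 gets two points, cluster 1 gets one.
        for (Integer c : clusters.keySet()) {
            System.out.println("cluster-" + c + ": "
                + Files.readAllLines(out.resolve("cluster-" + c + ".txt")));
        }
    }
}
```

Each file uses the same `Arrays.toString` line format as the cluster-centers file. This version holds every point in memory on one machine; for large inputs the Spark route (`predict` plus `saveAsTextFile`) keeps the data distributed instead.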

Upvotes: 1
