Reputation: 1372
I want to count the records in a Java RDD, grouped by a field of the object.
I have an Entity class with name and state as member variables. The Entity class looks like this:
import java.io.Serializable;
import lombok.AllArgsConstructor;
import lombok.Getter;

@Getter
@AllArgsConstructor
public class Entity implements Serializable {
    private final String name;
    private final String state;
}
I have a JavaRDD of Entity objects. I want to determine how many objects are present for each state in this RDD.
My current approach uses LongAccumulator: iterate through each record in the RDD, read the state field, and increment the corresponding accumulator. The code I have tried so far is:
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.util.LongAccumulator;
import java.util.ArrayList;
import java.util.List;
import lombok.extern.slf4j.Slf4j;

@Slf4j
public class CountRDD {
    public static void main(String[] args) {
        String applicationName = CountRDD.class.getName();
        SparkConf sparkConf = new SparkConf().setAppName(applicationName).setMaster("local");
        JavaSparkContext javaSparkContext = new JavaSparkContext(sparkConf);
        javaSparkContext.setLogLevel("INFO");

        Entity entity1 = new Entity("a1", "s1");
        Entity entity2 = new Entity("a2", "s2");
        Entity entity3 = new Entity("a3", "s1");
        Entity entity4 = new Entity("a4", "s2");
        Entity entity5 = new Entity("a5", "s1");
        List<Entity> entityList = new ArrayList<>();
        entityList.add(entity1);
        entityList.add(entity2);
        entityList.add(entity3);
        entityList.add(entity4);
        entityList.add(entity5);
        JavaRDD<Entity> entityJavaRDD = javaSparkContext.parallelize(entityList, 1);

        // One accumulator per known state value.
        LongAccumulator s1Accumulator = javaSparkContext.sc().longAccumulator("s1");
        LongAccumulator s2Accumulator = javaSparkContext.sc().longAccumulator("s2");

        // Increment the matching accumulator for each record.
        entityJavaRDD.foreach(entity -> {
            if (entity != null) {
                String state = entity.getState();
                if ("s1".equalsIgnoreCase(state)) {
                    s1Accumulator.add(1);
                } else if ("s2".equalsIgnoreCase(state)) {
                    s2Accumulator.add(1);
                }
            }
        });

        log.info("Final values for input entity RDD are following");
        log.info("s1Accumulator = {} ", s1Accumulator.value());
        log.info("s2Accumulator = {} ", s2Accumulator.value());
    }
}
The above code works and produces this output: s1Accumulator = 3 and s2Accumulator = 2.
The limitation of the above code is that all permissible values of state must be known before execution, and an accumulator must be maintained for each one. This makes the code unwieldy when there are many distinct states.
Another approach I can think of is to create a new pair RDD of String (state) and Integer (count): apply the mapToPair transformation on the input RDD, and get the counts from this newly created RDD, as sketched below.
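A minimal sketch of that approach, reusing entityJavaRDD from the code above (I use Long rather than Integer for the counts, to match the reduceByKey output here):

import org.apache.spark.api.java.JavaPairRDD;
import scala.Tuple2;

// Map each entity to a (state, 1) pair, then sum the counts per state.
JavaPairRDD<String, Long> stateCounts = entityJavaRDD
        .mapToPair(entity -> new Tuple2<>(entity.getState(), 1L))
        .reduceByKey(Long::sum);

// The distinct states form a small result, so it is safe to collect it to the driver.
stateCounts.collectAsMap()
        .forEach((state, count) -> log.info("{} = {}", state, count));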
Any other thoughts on how I can approach this problem?
Upvotes: 0
Views: 705
Reputation: 131
As mentioned in the comments, you can groupBy on the state field and then call count on it; this will give you the count for each state. You don't need accumulators.
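If you want to stay on the RDD API, here is a one-line sketch of the same idea (using countByValue rather than a literal groupBy; entityJavaRDD is the RDD from the question):

// Map each record to its state, then count the occurrences of each distinct value.
java.util.Map<String, Long> countsByState =
        entityJavaRDD.map(Entity::getState).countByValue();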
As a side note, jobs run with significantly better performance if you avoid lambda functions and use DataFrames (that is, Dataset<Row>). DataFrames provide better query optimization and code-generation capabilities than RDDs and have vectorized (meaning: very fast) functions for most processing use cases.
The Dataset API javadoc has a DataFrame groupBy example in the description: https://spark.apache.org/docs/2.4.5/api/java/org/apache/spark/sql/Dataset.html
It is preferable to read data as DataFrames to begin with, but you can convert existing RDDs and JavaRDDs with SparkSession.createDataFrame.
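A minimal sketch of that route, assuming the entityJavaRDD from the question is in scope (createDataFrame infers the schema from the Entity bean via its Lombok-generated getters):

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

SparkSession spark = SparkSession.builder()
        .appName("CountDataFrame")
        .master("local")
        .getOrCreate();

// Convert the JavaRDD to a DataFrame; the schema (name, state) is inferred from the bean.
Dataset<Row> entityDf = spark.createDataFrame(entityJavaRDD, Entity.class);

// Group by the state column and count each group.
entityDf.groupBy("state").count().show();
// Expected output for the question's data (row order may vary):
// |state|count|
// |   s1|    3|
// |   s2|    2|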
Upvotes: 2