Reputation: 1
I am testing the TableAggregateFunction in flink as Java udf and Python udf. The Java udf emit function is called after intermediate aggregation values, where as for python udf the emit function call is only getting called once at end of the accumulation or different from the number of times emit function is called in java udf.
The java code.
package com.company.test;
import org.apache.flink.table.api.*;
import org.apache.flink.types.Row;
import static org.apache.flink.table.api.Expressions.$;
import static org.apache.flink.table.api.Expressions.call;
public class AggregateExample {
public static void main(String[] args) throws Exception {
EnvironmentSettings settings = EnvironmentSettings
.newInstance()
.inStreamingMode()
.build();
TableEnvironment env = TableEnvironment.create(settings);
Table orders = env.fromValues(
DataTypes.of("ROW<userid STRING, count INT>"),
// before onboarding
Row.of("user1", 2),
Row.of("user1", 1),
Row.of("user1", 3),
Row.of("user2", 1),
Row.of("user3", 12),
Row.of("user3", 1),
Row.of("user3", 5));
//env.createTemporaryView("Orders", orders);
// call function "inline" without registration in Table API
Table rs1 = orders.groupBy($("userid"))
.flatAggregate(call(Top2.class, $("count")))
.select($("userid"), $("f0"), $("f1"));
rs1.execute().print();
// call function "inline" without registration in Table API
// but use an alias for a better naming of Tuple2's fields
Table rs2 = orders.groupBy($("userid"))
.flatAggregate(call(Top2.class, $("count")).as("count", "rank"))
.select($("userid"), $("count"), $("rank"));
rs2.execute().print();
// register function
env.createTemporarySystemFunction("Top2", Top2.class);
// call registered function in Table API
Table rs3 = orders.groupBy($("userid"))
.flatAggregate(call("Top2", $("count")).as("count", "rank"))
.select($("userid"), $("count"), $("rank"));
rs3.execute().print();
env.getConfig().getConfiguration().setString("python.files", "<base_dir>/src/signals/python/udfs.py");
env.getConfig().getConfiguration().setString("python.client.executable", "python");
env.getConfig().getConfiguration().setString("python.executable", "python");
//register function to call python class
env.executeSql("create temporary system function top2_python as 'udfs.top2' language python");
;
Table rs4 = orders.groupBy($("userid"))
.flatAggregate(call("top2_python", $("count")).as("count", "rank"))
.select($("userid"), $("count"), $("rank"));
rs4.execute().print();
}
}
Top2.java
package com.company.test;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.table.functions.TableAggregateFunction;
import org.apache.flink.util.Collector;
public class Top2 extends TableAggregateFunction<Tuple2<Integer, Integer>, Top2Accumulator> {
@Override
public Top2Accumulator createAccumulator() {
Top2Accumulator acc = new Top2Accumulator();
acc.first = Integer.MIN_VALUE;
acc.second = Integer.MIN_VALUE;
return acc;
}
public void accumulate(Top2Accumulator acc, Integer value) {
if (value > acc.first) {
acc.second = acc.first;
acc.first = value;
} else if (value > acc.second) {
acc.second = value;
}
}
public void merge(Top2Accumulator acc, Iterable<Top2Accumulator> it) {
for (Top2Accumulator otherAcc : it) {
accumulate(acc, otherAcc.first);
accumulate(acc, otherAcc.second);
}
}
public void emitValue(Top2Accumulator acc, Collector<Tuple2<Integer, Integer>> out) {
// emit the value and rank
if (acc.first != Integer.MIN_VALUE) {
out.collect(Tuple2.of(acc.first, 1));
}
if (acc.second != Integer.MIN_VALUE) {
out.collect(Tuple2.of(acc.second, 2));
}
}
}
Top2Accumulator.java
package com.company.test;
public class Top2Accumulator {
public Integer first;
public Integer second;
}
udfs.py
rom pyflink.common import Row
from pyflink.table import DataTypes, TableEnvironment, EnvironmentSettings
from pyflink.table.udf import udtaf, TableAggregateFunction
class Top2(TableAggregateFunction):
def emit_value(self, accumulator):
yield Row(accumulator[0],1)
yield Row(accumulator[1],2)
def create_accumulator(self):
return [0, 0]
def accumulate(self, accumulator, value):
if value > accumulator[0] :
accumulator[1] = accumulator[0]
accumulator[0] = value
elif value > accumulator[1] :
accumulator[1] = value
def get_accumulator_type(self):
return DataTypes.ARRAY(DataTypes.BIGINT())
def get_result_type(self):
return DataTypes.ROW(
[DataTypes.FIELD("count", DataTypes.BIGINT()),DataTypes.FIELD("rank", DataTypes.BIGINT())])
top2 = udtaf(Top2(), result_type=DataTypes.ROW([DataTypes.FIELD("count", DataTypes.BIGINT()),DataTypes.FIELD("rank", DataTypes.BIGINT())]))
The output result
[INFO] BUILD SUCCESS
[INFO] ------------------------------------------------------------------------
[INFO] Total time: 6.359 s
[INFO] Finished at: 2024-05-14T17:41:52+02:00
[INFO] ------------------------------------------------------------------------
Job has been submitted with JobID 759f53fd7e9014acc2f18d9973768d52
+----+--------------------------------+-------------+-------------+
| op | userid | f0 | f1 |
+----+--------------------------------+-------------+-------------+
| +I | user1 | 2 | 1 |
| -D | user1 | 2 | 1 |
| +I | user1 | 2 | 1 |
| +I | user1 | 1 | 2 |
| -D | user1 | 2 | 1 |
| -D | user1 | 1 | 2 |
| +I | user1 | 3 | 1 |
| +I | user1 | 2 | 2 |
| +I | user2 | 1 | 1 |
| +I | user3 | 12 | 1 |
| -D | user3 | 12 | 1 |
| +I | user3 | 12 | 1 |
| +I | user3 | 1 | 2 |
| -D | user3 | 12 | 1 |
| -D | user3 | 1 | 2 |
| +I | user3 | 12 | 1 |
| +I | user3 | 5 | 2 |
+----+--------------------------------+-------------+-------------+
17 rows in set
Job has been submitted with JobID 91160b46db3925f1eb446a70ea8bcd1e
+----+--------------------------------+-------------+-------------+
| op | userid | count | rank |
+----+--------------------------------+-------------+-------------+
| +I | user1 | 2 | 1 |
| -D | user1 | 2 | 1 |
| +I | user1 | 2 | 1 |
| +I | user1 | 1 | 2 |
| -D | user1 | 2 | 1 |
| -D | user1 | 1 | 2 |
| +I | user1 | 3 | 1 |
| +I | user1 | 2 | 2 |
| +I | user2 | 1 | 1 |
| +I | user3 | 12 | 1 |
| -D | user3 | 12 | 1 |
| +I | user3 | 12 | 1 |
| +I | user3 | 1 | 2 |
| -D | user3 | 12 | 1 |
| -D | user3 | 1 | 2 |
| +I | user3 | 12 | 1 |
| +I | user3 | 5 | 2 |
+----+--------------------------------+-------------+-------------+
17 rows in set
Job has been submitted with JobID ea686c8d55e121eec3a3effae804c4ee
+----+--------------------------------+-------------+-------------+
| op | userid | count | rank |
+----+--------------------------------+-------------+-------------+
| +I | user1 | 2 | 1 |
| -D | user1 | 2 | 1 |
| +I | user1 | 2 | 1 |
| +I | user1 | 1 | 2 |
| -D | user1 | 2 | 1 |
| -D | user1 | 1 | 2 |
| +I | user1 | 3 | 1 |
| +I | user1 | 2 | 2 |
| +I | user2 | 1 | 1 |
| +I | user3 | 12 | 1 |
| -D | user3 | 12 | 1 |
| +I | user3 | 12 | 1 |
| +I | user3 | 1 | 2 |
| -D | user3 | 12 | 1 |
| -D | user3 | 1 | 2 |
| +I | user3 | 12 | 1 |
| +I | user3 | 5 | 2 |
+----+--------------------------------+-------------+-------------+
17 rows in set
Job has been submitted with JobID 4fc3bcf9bcaba7f693eaaeeaf75e7cfe
+----+--------------------------------+----------------------+----------------------+
| op | userid | count | rank |
+----+--------------------------------+----------------------+----------------------+
| +I | user1 | 2 | 1 |
| +I | user1 | 0 | 2 |
| -D | user1 | 2 | 1 |
| -D | user1 | 0 | 2 |
| +I | user1 | 3 | 1 |
| +I | user1 | 2 | 2 |
| +I | user2 | 1 | 1 |
| +I | user2 | 0 | 2 |
| +I | user3 | 12 | 1 |
| +I | user3 | 5 | 2 |
+----+--------------------------------+----------------------+----------------------+
10 rows in set
The output from the python is 10 rows where as from java udf is 17 rows. The python table aggregate is missing the intermediate emits like java. Is there a programmatic configuration which will make python udf and java udf behavior the same.
Upvotes: 0
Views: 62