moudi

Reputation: 147

Why doesn't a UDF recognize the columns of the dataframe?

Suppose I have the following dataframe:

+-----------------+---------------------+
|       document1  |   document2        |
+-----------------+---------------------+
|    word1 word2  |   word2 word3       |
+-----------------+---------------------+

I need to add a new column to this dataframe, called intersection, which represents the intersection similarity between document1 and document2.

How can I manipulate the values in these columns? I defined a function called intersection that takes two strings as input, but I cannot apply it to the column types. I think I should use a UDF. How can I do that in Java? Note that I am using Spark 2.3.0.
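For illustration, inter is roughly along these lines; the exact formula does not matter for the question (assume the usual java.util imports):

// Illustrative sketch only: some measure of word overlap between the two documents.
// Here it simply counts the words the two documents have in common.
private static double inter(String doc1, String doc2) {
    Set<String> words1 = new HashSet<>(Arrays.asList(doc1.split(" ")));
    Set<String> words2 = new HashSet<>(Arrays.asList(doc2.split(" ")));
    words1.retainAll(words2); // keep only the common words
    return words1.size();
}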

I tried the following:

SparkSession spark = SparkSession.builder().appName("spark session example").master("local[*]")
        .config("spark.sql.warehouse.dir", "/file:C:/tempWarehouse")
        .config("spark.sql.caseSensitive", "true")
        .getOrCreate();

sqlContext.udf().register("intersection", new UDF2<String, String, Double>() {
    @Override
    public Double call(String arg, String arg2) throws Exception {
        double key = inter(arg, arg2);
        return key;
    }
}, DataTypes.DoubleType);

v.registerTempTable("v_table");

Dataset<Row> df = spark.sql("select v_table.document, v_table.document1, "
        + "intersection(v_table.document, v_table.document1) as RowKey1,"
        + " from v_table");
df.show();

but I get the following exception:

    INFO SparkSqlParser: Parsing command: select v_table.document, v_table.document1, intersection(v_table.document, v_table.document1) as RowKey1, from v_table
Exception in thread "main" org.apache.spark.sql.AnalysisException: cannot resolve '`v_table.document`' given input columns: []; line 1 pos 7

If I delete + ", intersection(v.doc1, v.doc2) as RowKey1," from the query, the select works fine. Any suggestions? Also, how can I use the same approach directly on the dataframe instead of going through SQL as I do here?

The schema of "v", from v.printSchema(), is:

root
 |-- document: string (nullable = true)
 |-- document1: string (nullable = true)

Upvotes: 1

Views: 261

Answers (1)

Highbrainer

Reputation: 760

I think I would work the other way around.

Transform your dataset into two datasets of words: one for document1 and one for document2. First split each line into an array of words, then explode it. Then all you have to do is keep the intersection.

Something like this:

// Single-row example dataset with the two document columns
Dataset<Row> ds = spark.sql("select 'word1 word2' as document1, 'word2 word3' as document2");
ds.show();

// One dataset of words per document column: split on spaces, then explode
Dataset<Row> ds1 = ds.select(functions.explode(functions.split(ds.col("document1"), " ")).as("word"));
Dataset<Row> ds2 = ds.select(functions.explode(functions.split(ds.col("document2"), " ")).as("word"));

// The intersection is simply the words present in both datasets
Dataset<Row> intersection = ds1.join(ds2, ds1.col("word").equalTo(ds2.col("word"))).select(ds1.col("word").as("Common words"));
intersection.show();

Output:

+-----------+-----------+
|  document1|  document2|
+-----------+-----------+
|word1 word2|word2 word3|
+-----------+-----------+

+------------+
|Common words|
+------------+
|       word2|
+------------+

Anyway, if your goal is 'only' to call a custom UDF on two columns, here's how I would do it:

1. Create your UDF

UDF2<String, String, String> intersection = new UDF2<String, String, String>() {
    @Override
    public String call(String arg, String arg2) throws Exception {
        String key = inter(arg, arg2);
        return key;
    }

    private String inter(String arg1, String arg2) {
        // this part of the implementation is just to stay in line with the previous part of this answer
        List<String> list1 = Arrays.asList(arg1.split(" "));
        return Stream.of(arg2.split(" ")).filter(list1::contains).collect(Collectors.joining(" "));
    }
};

2. Register and use it!

Pure Java

UserDefinedFunction intersect = functions.udf(intersection, DataTypes.StringType);      

Dataset<Row> ds1 = ds.select(ds.col("document1"), ds.col("document2"), intersect.apply(ds.col("document1"), ds.col("document2")));
ds1.show();

SQL

// Requires the dataset to be registered as a temp view first:
// ds.createOrReplaceTempView("v_table");
spark.sqlContext().udf().register("intersect", intersection, DataTypes.StringType);
Dataset<Row> df = spark.sql("select document1, document2, "
        + "intersect(document1, document2) as RowKey1"
        + " from v_table");
df.show();

Output

+-----------+-----------+-------+
|  document1|  document2|RowKey1|
+-----------+-----------+-------+
|word1 word2|word2 word3|  word2|
+-----------+-----------+-------+

Complete code

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.types.DataTypes;

public class StackOverflowUDF {
    public static void main(String args[]) {
        SparkSession spark = SparkSession.builder().appName("JavaWordCount").master("local").getOrCreate();

        Dataset<Row> ds = spark.sql("select 'word1 word2' as document1, 'word2 word3' as document2");
        ds.show();

        UDF2<String, String, String> intersection = new UDF2<String, String, String>() {
            @Override
            public String call(String arg, String arg2) throws Exception {
                String key = inter(arg, arg2);
                return key;
            }

            private String inter(String arg1, String arg2) {
                List<String> list1 = Arrays.asList(arg1.split(" "));
                return Stream.of(arg2.split(" ")).filter(list1::contains).collect(Collectors.joining(" "));
            }
        };

        UserDefinedFunction intersect = functions.udf(intersection, DataTypes.StringType);

        Dataset<Row> ds1 = ds.select(ds.col("document1"), ds.col("document2"),
                intersect.apply(ds.col("document1"), ds.col("document2")));
        ds1.show();
        ds1.printSchema();

        ds.createOrReplaceTempView("v_table");

        spark.sqlContext().udf().register("intersect", intersection, DataTypes.StringType);
        Dataset<Row> df = spark
                .sql("select document1, document2, " + "intersect(document1, document2) as RowKey1" + " from v_table");
        df.show();
        spark.stop();

    }
}
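One last note: if, as in the question, the UDF should return a Double score rather than the common words themselves, the same registration pattern applies with DataTypes.DoubleType. A minimal sketch, assuming the score is simply shared words divided by distinct words (adapt the formula as needed; it also needs java.util.Set and java.util.HashSet on top of the imports above):

// Sketch only: same registration pattern, but with a Double return type.
// The formula (shared words / distinct words) is an assumption; replace it
// with whatever "intersection similarity" means in your case.
UDF2<String, String, Double> intersectionScore = new UDF2<String, String, Double>() {
    @Override
    public Double call(String arg, String arg2) throws Exception {
        Set<String> w1 = new HashSet<>(Arrays.asList(arg.split(" ")));
        Set<String> w2 = new HashSet<>(Arrays.asList(arg2.split(" ")));
        Set<String> union = new HashSet<>(w1);
        union.addAll(w2);
        w1.retainAll(w2); // w1 now holds the common words
        return union.isEmpty() ? 0.0 : (double) w1.size() / union.size();
    }
};

spark.sqlContext().udf().register("intersectionScore", intersectionScore, DataTypes.DoubleType);
spark.sql("select document1, document2, intersectionScore(document1, document2) as score from v_table").show();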

HTH!

Upvotes: 2
