Learnis
Learnis

Reputation: 566

Spark: combine multiple rows into a single row based on a specific column without a groupBy operation

I have a Spark DataFrame like the one below, with about 7k columns.

+---+----+----+----+----+----+----+
| id|   1|   2|   3|sf_1|sf_2|sf_3|
+---+----+----+----+----+----+----+
|  2|null|null|null| 102| 202| 302|
|  4|null|null|null| 104| 204| 304|
|  1|null|null|null| 101| 201| 301|
|  3|null|null|null| 103| 203| 303|
|  1|  11|  21|  31|null|null|null|
|  2|  12|  22|  32|null|null|null|
|  4|  14|  24|  34|null|null|null|
|  3|  13|  23|  33|null|null|null|
+---+----+----+----+----+----+----+

I want to transform the DataFrame as shown below by merging the rows that contain nulls. With a groupBy operation I'm able to merge them into a single row, but the performance of this aggregation is very poor because my table has 7k columns.

import pyspark.sql.functions as F

# Collapse all rows that share an id into one row: for every other column keep
# the first non-null value found in that id's group. The .alias(x) keeps the
# original column name; without it Spark names each result column after the
# aggregate expression rather than after the source column.
(df.groupBy('id')
   .agg(*[F.first(x, ignorenulls=True).alias(x) for x in df.columns if x != 'id'])
   .show())
+---+----+----+----+----+----+----+
| id|   1|   2|   3|sf_1|sf_2|sf_3|
+---+----+----+----+----+----+----+
|  1|  11|  21|  31| 101| 201| 301|
|  2|  12|  22|  32| 102| 202| 302|
|  4|  14|  24|  34| 104| 204| 304|
|  3|  13|  23|  33| 103| 203| 303|
+---+----+----+----+----+----+----+

Are there any other recommendations, optimizations, or more efficient ways of doing this? Thanks.

Update 1: after trying the self-join approach, I get the following error:

---------------------------------------------------------------------------
Py4JJavaError                             Traceback (most recent call last)
<ipython-input-17-b7de100341cc> in <module>
     15 """.format(table_name, query, join_key)
     16 
---> 17 spark.sql(final_query).dropDuplicates().filter(filters).count()

~/quartic/spark-3.0.0-bin-hadoop2.7/python/pyspark/sql/dataframe.py in count(self)
    583         2
    584         """
--> 585         return int(self._jdf.count())
    586 
    587     @ignore_unicode_prefix

~/quartic/spark-3.0.0-bin-hadoop2.7/python/lib/py4j-0.10.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
   1303         answer = self.gateway_client.send_command(command)
   1304         return_value = get_return_value(
-> 1305             answer, self.gateway_client, self.target_id, self.name)
   1306 
   1307         for temp_arg in temp_args:

~/quartic/spark-3.0.0-bin-hadoop2.7/python/pyspark/sql/utils.py in deco(*a, **kw)
    129     def deco(*a, **kw):
    130         try:
--> 131             return f(*a, **kw)
    132         except py4j.protocol.Py4JJavaError as e:
    133             converted = convert_exception(e.java_exception)

~/quartic/spark-3.0.0-bin-hadoop2.7/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
    326                 raise Py4JJavaError(
    327                     "An error occurred while calling {0}{1}{2}.\n".
--> 328                     format(target_id, ".", name), value)
    329             else:
    330                 raise Py4JError(

Py4JJavaError: An error occurred while calling o148.count.
: java.lang.StackOverflowError
    at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:35)
    at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
    at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:38)
    at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
    at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
    at scala.collection.mutable.ListBuffer.$plus$plus$eq(ListBuffer.scala:184)
    at scala.collection.mutable.ListBuffer.$plus$plus$eq(ListBuffer.scala:47)
    at scala.collection.generic.GenericCompanion.apply(GenericCompanion.scala:53)
    at org.apache.spark.sql.catalyst.expressions.BinaryExpression.children(Expression.scala:533)
    at org.apache.spark.sql.catalyst.trees.TreeNode.containsChild$lzycompute(TreeNode.scala:115)
    at org.apache.spark.sql.catalyst.trees.TreeNode.containsChild(TreeNode.scala:115)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:349)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:330)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$1(TreeNode.scala:330)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:399)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:237)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:397)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:350)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:330)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$1(TreeNode.scala:330)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:399)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:237)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:397)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:350)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:330)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$1(TreeNode.scala:330)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:399)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:237)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:397)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:350)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:330)

Upvotes: 0

Views: 1874

Answers (2)

Manoj Kumar Dhakad
Manoj Kumar Dhakad

Reputation: 1892

You can use a self-join like below.

from pyspark.sql.types import IntegerType, StructField, StructType

# Sample input: every id occurs twice, once with only col_1..col_3 filled in
# and once with only sf_1..sf_3 filled in.
values_arr = [
    (2, None, None, None, 102, 202, 302),
    (4, None, None, None, 104, 204, 304),
    (1, None, None, None, 101, 201, 301),
    (3, None, None, None, 103, 203, 303),
    (1, 11, 21, 31, None, None, None),
    (2, 12, 22, 32, None, None, None),
    (4, 14, 24, 34, None, None, None),
    (3, 13, 23, 33, None, None, None),
]

sc = spark.sparkContext
rdd = sc.parallelize(values_arr)

# All seven columns are nullable integers, so derive the schema from the names.
schema = StructType(
    [
        StructField(column, IntegerType(), True)
        for column in ("id", "col_1", "col_2", "col_3", "sf_1", "sf_2", "sf_3")
    ]
)

df = spark.createDataFrame(rdd, schema)
df.show()

//Sample Input

+---+-----+-----+-----+----+----+----+
| id|col_1|col_2|col_3|sf_1|sf_2|sf_3|
+---+-----+-----+-----+----+----+----+
|  2| null| null| null| 102| 202| 302|
|  4| null| null| null| 104| 204| 304|
|  1| null| null| null| 101| 201| 301|
|  3| null| null| null| 103| 203| 303|
|  1|   11|   21|   31|null|null|null|
|  2|   12|   22|   32|null|null|null|
|  4|   14|   24|   34|null|null|null|
|  3|   13|   23|   33|null|null|null|
+---+-----+-----+-----+----+----+----+

//Solution
# createOrReplaceTempView (unlike createTempView) does not fail with
# "view already exists" when this cell is run a second time.
df.createOrReplaceTempView("my_table")

# Join each id's col_* row (alias r) with its sf_* row (alias l). The two
# IS NOT NULL checks select one row of each kind, assuming each id has exactly
# one row of each kind (true for the sample data above).
query = (
    "select l.id as id, r.col_1 as col_1, r.col_2 as col_2, r.col_3 as col_3, "
    "l.sf_1 as sf_1, l.sf_2 as sf_2, l.sf_3 as sf_3 "
    "from my_table l, my_table r "
    "where l.id = r.id and r.col_1 is not null and l.sf_1 is not null"
)

spark.sql(query).show()

//Sample output: 
+---+-----+-----+-----+----+----+----+
| id|col_1|col_2|col_3|sf_1|sf_2|sf_3|
+---+-----+-----+-----+----+----+----+
|  1|   11|   21|   31| 101| 201| 301|
|  3|   13|   23|   33| 103| 203| 303|
|  4|   14|   24|   34| 104| 204| 304|
|  2|   12|   22|   32| 102| 202| 302|
+---+-----+-----+-----+----+----+----+

Upvotes: -1

Manish
Manish

Reputation: 1157

You can try this solution. Let me know if it is fast.

from pyspark.sql.types import IntegerType, StructField, StructType

# Sample input: every id occurs twice, once with only col1..col3 filled in
# and once with only sf_1..sf_3 filled in.
values = [
    (2, None, None, None, 102, 202, 302),
    (4, None, None, None, 104, 204, 304),
    (1, None, None, None, 101, 201, 301),
    (3, None, None, None, 103, 203, 303),
    (1, 11, 21, 31, None, None, None),
    (2, 12, 22, 32, None, None, None),
    (4, 14, 24, 34, None, None, None),
    (3, 13, 23, 33, None, None, None),
]

sc = spark.sparkContext
rdd = sc.parallelize(values)

# All seven columns are nullable integers, so derive the schema from the names.
schema = StructType(
    [
        StructField(column, IntegerType(), True)
        for column in ("id", "col1", "col2", "col3", "sf_1", "sf_2", "sf_3")
    ]
)

data = spark.createDataFrame(rdd, schema)
data.show()
# +---+----+----+----+----+----+----+
# | id|col1|col2|col3|sf_1|sf_2|sf_3|
# +---+----+----+----+----+----+----+
# |  2|null|null|null| 102| 202| 302|
# |  4|null|null|null| 104| 204| 304|
# |  1|null|null|null| 101| 201| 301|
# |  3|null|null|null| 103| 203| 303|
# |  1|  11|  21|  31|null|null|null|
# |  2|  12|  22|  32|null|null|null|
# |  4|  14|  24|  34|null|null|null|
# |  3|  13|  23|  33|null|null|null|
# +---+----+----+----+----+----+----+

data.createOrReplaceTempView("data")
join_key = 'id'
table_name = 'data'


def quote_ident(name):
    """Return *name* as a backtick-quoted Spark SQL identifier.

    Quoting is needed for column names such as ``1`` that are not valid bare
    identifiers; embedded backticks are escaped by doubling them.
    """
    return "`" + name.replace("`", "``") + "`"


# Every column except the join key is merged from the two self-joined copies,
# wherever that key sits in the column list.
value_columns = [c for c in data.columns if c != join_key]

# coalesce(a.x, b.x) is equivalent to "case when a.x is null then b.x else a.x end".
query = ",\n\t".join(
    [quote_ident(join_key)]
    + ["coalesce(a.{0}, b.{0}) as {0}".format(quote_ident(c)) for c in value_columns]
)
final_query = """
SELECT a.{1}
FROM {0} a INNER JOIN {0} b ON a.{2} = b.{2}
""".format(table_name, query, quote_ident(join_key))
print(final_query)
# SELECT a.`id`,
#     coalesce(a.`col1`, b.`col1`) as `col1`,
#     coalesce(a.`col2`, b.`col2`) as `col2`,
#     coalesce(a.`col3`, b.`col3`) as `col3`,
#     coalesce(a.`sf_1`, b.`sf_1`) as `sf_1`,
#     coalesce(a.`sf_2`, b.`sf_2`) as `sf_2`,
#     coalesce(a.`sf_3`, b.`sf_3`) as `sf_3`
# FROM data a INNER JOIN data b ON a.`id` = b.`id`

# Keep only fully merged rows. dropna(how="any", subset=...) is a single flat
# "at least N non-nulls" predicate. A hand-built chain of thousands of
# "x IS NOT NULL AND ..." terms instead nests one AND node per column, and
# Catalyst's recursive tree traversal overflows the JVM stack
# (java.lang.StackOverflowError) on wide tables.
# Filtering before dropDuplicates gives the same result but shuffles fewer rows.
(spark.sql(final_query)
    .dropna(how="any", subset=value_columns)
    .dropDuplicates()
    .show())
# +---+----+----+----+----+----+----+
# | id|col1|col2|col3|sf_1|sf_2|sf_3|
# +---+----+----+----+----+----+----+
# |  1|  11|  21|  31| 101| 201| 301|
# |  3|  13|  23|  33| 103| 203| 303|
# |  4|  14|  24|  34| 104| 204| 304|
# |  2|  12|  22|  32| 102| 202| 302|
# +---+----+----+----+----+----+----+

Upvotes: 2

Related Questions