Reputation: 566
I have a Spark DataFrame like the one below, with 7k columns.
+---+----+----+----+----+----+----+
| id| 1| 2| 3|sf_1|sf_2|sf_3|
+---+----+----+----+----+----+----+
| 2|null|null|null| 102| 202| 302|
| 4|null|null|null| 104| 204| 304|
| 1|null|null|null| 101| 201| 301|
| 3|null|null|null| 103| 203| 303|
| 1| 11| 21| 31|null|null|null|
| 2| 12| 22| 32|null|null|null|
| 4| 14| 24| 34|null|null|null|
| 3| 13| 23| 33|null|null|null|
+---+----+----+----+----+----+----+
I want to transform the DataFrame into the form below by merging the rows that share an id (each row fills in the other's null columns). With a groupBy I can merge them into a single row, but the aggregation performs very poorly because my table has 7k columns.
import pyspark.sql.functions as F
# Merge the rows of each id by taking the first non-null value of every
# non-key column. In the sample data each id has at most one non-null value
# per column, so "first" is unambiguous here.
# .alias(x) keeps the original column names; without it Spark names each
# result column after the aggregate expression instead of "1", "sf_1", ...,
# which would not match the output table shown below.
(df.groupBy('id')
   .agg(*[F.first(x, ignorenulls=True).alias(x) for x in df.columns if x != 'id'])
   .show())
+---+----+----+----+----+----+----+
| id| 1| 2| 3|sf_1|sf_2|sf_3|
+---+----+----+----+----+----+----+
| 1| 11| 21| 31| 101| 201| 301|
| 2| 12| 22| 32| 102| 202| 302|
| 4| 14| 24| 34| 104| 204| 304|
| 3| 13| 23| 33| 103| 203| 303|
+---+----+----+----+----+----+----+
Are there any other recommendations, optimizations, or more efficient ways of doing this? Thanks.
Update 1: after trying the self-join approach, I get the following error:
---------------------------------------------------------------------------
Py4JJavaError Traceback (most recent call last)
<ipython-input-17-b7de100341cc> in <module>
15 """.format(table_name, query, join_key)
16
---> 17 spark.sql(final_query).dropDuplicates().filter(filters).count()
~/quartic/spark-3.0.0-bin-hadoop2.7/python/pyspark/sql/dataframe.py in count(self)
583 2
584 """
--> 585 return int(self._jdf.count())
586
587 @ignore_unicode_prefix
~/quartic/spark-3.0.0-bin-hadoop2.7/python/lib/py4j-0.10.9-src.zip/py4j/java_gateway.py in __call__(self, *args)
1303 answer = self.gateway_client.send_command(command)
1304 return_value = get_return_value(
-> 1305 answer, self.gateway_client, self.target_id, self.name)
1306
1307 for temp_arg in temp_args:
~/quartic/spark-3.0.0-bin-hadoop2.7/python/pyspark/sql/utils.py in deco(*a, **kw)
129 def deco(*a, **kw):
130 try:
--> 131 return f(*a, **kw)
132 except py4j.protocol.Py4JJavaError as e:
133 converted = convert_exception(e.java_exception)
~/quartic/spark-3.0.0-bin-hadoop2.7/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name)
326 raise Py4JJavaError(
327 "An error occurred while calling {0}{1}{2}.\n".
--> 328 format(target_id, ".", name), value)
329 else:
330 raise Py4JError(
Py4JJavaError: An error occurred while calling o148.count.
: java.lang.StackOverflowError
at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:35)
at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:38)
at scala.collection.generic.Growable.$plus$plus$eq(Growable.scala:62)
at scala.collection.generic.Growable.$plus$plus$eq$(Growable.scala:53)
at scala.collection.mutable.ListBuffer.$plus$plus$eq(ListBuffer.scala:184)
at scala.collection.mutable.ListBuffer.$plus$plus$eq(ListBuffer.scala:47)
at scala.collection.generic.GenericCompanion.apply(GenericCompanion.scala:53)
at org.apache.spark.sql.catalyst.expressions.BinaryExpression.children(Expression.scala:533)
at org.apache.spark.sql.catalyst.trees.TreeNode.containsChild$lzycompute(TreeNode.scala:115)
at org.apache.spark.sql.catalyst.trees.TreeNode.containsChild(TreeNode.scala:115)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:349)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:330)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$1(TreeNode.scala:330)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:399)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:237)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:397)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:350)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:330)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$1(TreeNode.scala:330)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:399)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:237)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:397)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:350)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:330)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$1(TreeNode.scala:330)
at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:399)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:237)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:397)
at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:350)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:330)
Upvotes: 0
Views: 1874
Reputation: 1892
You can use a self-join like below.
from pyspark.sql.types import IntegerType, StructField, StructType

# Sample input: every id occurs twice. One row carries only sf_1..sf_3 and
# the other carries only col_1..col_3.
values_arr = [
    (2, None, None, None, 102, 202, 302),
    (4, None, None, None, 104, 204, 304),
    (1, None, None, None, 101, 201, 301),
    (3, None, None, None, 103, 203, 303),
    (1, 11, 21, 31, None, None, None),
    (2, 12, 22, 32, None, None, None),
    (4, 14, 24, 34, None, None, None),
    (3, 13, 23, 33, None, None, None),
]

sc = spark.sparkContext
rdd = sc.parallelize(values_arr)

# All seven columns are nullable integers.
schema = StructType([
    StructField(field_name, IntegerType(), True)
    for field_name in ("id", "col_1", "col_2", "col_3", "sf_1", "sf_2", "sf_3")
])

df = spark.createDataFrame(rdd, schema)
df.show()
//Sample Input
+---+-----+-----+-----+----+----+----+
| id|col_1|col_2|col_3|sf_1|sf_2|sf_3|
+---+-----+-----+-----+----+----+----+
| 2| null| null| null| 102| 202| 302|
| 4| null| null| null| 104| 204| 304|
| 1| null| null| null| 101| 201| 301|
| 3| null| null| null| 103| 203| 303|
| 1| 11| 21| 31|null|null|null|
| 2| 12| 22| 32|null|null|null|
| 4| 14| 24| 34|null|null|null|
| 3| 13| 23| 33|null|null|null|
+---+-----+-----+-----+----+----+----+
//Solution
# createOrReplaceTempView (rather than createTempView) lets this cell be
# re-run: createTempView raises an AnalysisException once "my_table" is
# already registered.
df.createOrReplaceTempView("my_table")
# Pair each id's sf_* row (alias l, sf_1 not null) with its col_* row
# (alias r, col_1 not null) and take every column from the side that holds it.
# NOTE(review): the select list is hand-written for this 7-column sample;
# for thousands of columns it would have to be generated from df.columns.
query = (
    "select l.id as id, "
    "r.col_1 as col_1, r.col_2 as col_2, r.col_3 as col_3, "
    "l.sf_1 as sf_1, l.sf_2 as sf_2, l.sf_3 as sf_3 "
    "from my_table l, my_table r "
    "where l.id = r.id and r.col_1 is not null and l.sf_1 is not null"
)
spark.sql(query).show()
//Sample output:
+---+-----+-----+-----+----+----+----+
| id|col_1|col_2|col_3|sf_1|sf_2|sf_3|
+---+-----+-----+-----+----+----+----+
| 1| 11| 21| 31| 101| 201| 301|
| 3| 13| 23| 33| 103| 203| 303|
| 4| 14| 24| 34| 104| 204| 304|
| 2| 12| 22| 32| 102| 202| 302|
+---+-----+-----+-----+----+----+----+
Upvotes: -1
Reputation: 1157
You can try this solution. Let me know whether it is faster.
from pyspark.sql.types import IntegerType, StructField, StructType

# Sample input: every id occurs twice. One row carries only sf_1..sf_3 and
# the other carries only col1..col3.
values = [
    (2, None, None, None, 102, 202, 302),
    (4, None, None, None, 104, 204, 304),
    (1, None, None, None, 101, 201, 301),
    (3, None, None, None, 103, 203, 303),
    (1, 11, 21, 31, None, None, None),
    (2, 12, 22, 32, None, None, None),
    (4, 14, 24, 34, None, None, None),
    (3, 13, 23, 33, None, None, None),
]

sc = spark.sparkContext
rdd = sc.parallelize(values)

# All seven columns are nullable integers.
schema = StructType([
    StructField(field_name, IntegerType(), True)
    for field_name in ("id", "col1", "col2", "col3", "sf_1", "sf_2", "sf_3")
])

data = spark.createDataFrame(rdd, schema)
data.show()
# +---+----+----+----+----+----+----+
# | id|col1|col2|col3|sf_1|sf_2|sf_3|
# +---+----+----+----+----+----+----+
# |  2|null|null|null| 102| 202| 302|
# |  4|null|null|null| 104| 204| 304|
# |  1|null|null|null| 101| 201| 301|
# |  3|null|null|null| 103| 203| 303|
# |  1|  11|  21|  31|null|null|null|
# |  2|  12|  22|  32|null|null|null|
# |  4|  14|  24|  34|null|null|null|
# |  3|  13|  23|  33|null|null|null|
# +---+----+----+----+----+----+----+
data.createOrReplaceTempView("data")
join_key = 'id'
table_name = 'data'

# Every column except the join key is merged across the two join sides.
# Computed independently of column position, so the join key may sit anywhere
# in data.columns.
value_columns = [c for c in data.columns if c != join_key]

# SELECT list: "`id`, coalesce(a.`c`, b.`c`) AS `c`, ...".
# coalesce(a.c, b.c) is equivalent to
# "case when a.c is null then b.c else a.c end".
# Identifiers are back-quoted because column names such as 1, 2, 3 (as in the
# question's table) are not valid bare SQL identifiers.
query = "`{0}`".format(join_key) + "".join(
    ",\n\t coalesce(a.`{0}`, b.`{0}`) AS `{0}`".format(c) for c in value_columns
)
final_query = """
SELECT a.{1}
FROM {0} a INNER JOIN {0} b ON a.`{2}` = b.`{2}`
""".format(table_name, query, join_key)
print(final_query)
# SELECT a.`id`,
#    coalesce(a.`col1`, b.`col1`) AS `col1`,
#    coalesce(a.`col2`, b.`col2`) AS `col2`,
#    coalesce(a.`col3`, b.`col3`) AS `col3`,
#    coalesce(a.`sf_1`, b.`sf_1`) AS `sf_1`,
#    coalesce(a.`sf_2`, b.`sf_2`) AS `sf_2`,
#    coalesce(a.`sf_3`, b.`sf_3`) AS `sf_3`
# FROM data a INNER JOIN data b ON a.`id` = b.`id`

# Keep only fully merged rows (every value column non-null), then drop the
# duplicates produced by the symmetric (a, b) / (b, a) join pairs.
# dropna(how="any", subset=...) is a single flat predicate. A filter string of
# thousands of "AND col IS NOT NULL" terms becomes a deeply nested binary tree
# in Catalyst, which matches the recursive BinaryExpression/transformUp frames
# in the java.lang.StackOverflowError reported for the 7k-column table.
# Filtering before dropDuplicates yields the same rows but dedups fewer of them.
spark.sql(final_query).dropna(how="any", subset=value_columns).dropDuplicates().show()
# +---+----+----+----+----+----+----+
# | id|col1|col2|col3|sf_1|sf_2|sf_3|
# +---+----+----+----+----+----+----+
# |  1|  11|  21|  31| 101| 201| 301|
# |  3|  13|  23|  33| 103| 203| 303|
# |  4|  14|  24|  34| 104| 204| 304|
# |  2|  12|  22|  32| 102| 202| 302|
# +---+----+----+----+----+----+----+
Upvotes: 2