pain
pain

Reputation: 329

SparkException: Can't zip RDDs with unequal numbers of partitions: List(2, 1)

Possible steps to reproduce (a minimal sketch follows the list):

  1. Run spark.sql multiple times, producing a list of DataFrames [d1, d2, d3, d4]
  2. Combine [d1, d2, d3, d4] into a single DataFrame d5 by calling Dataset#unionByName
  3. Run d5.groupBy("c1").pivot("c2").agg(concat_ws(", ", collect_list("value"))), producing DataFrame d6
  4. Join DataFrame d6 with another DataFrame d7
  5. Call an action such as count to trigger the Spark job
  6. The exception is thrown

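Here is roughly what that pipeline looks like (a minimal sketch with made-up in-memory data in place of the real spark.sql sources; only the column names c1, c2 and value come from the steps above):

import org.apache.spark.sql.functions.{collect_list, concat_ws}
import spark.implicits._

// steps 1-2: several small frames combined with unionByName
val d1 = Seq(("a", "x", "1")).toDF("c1", "c2", "value")
val d2 = Seq(("a", "y", "2")).toDF("c1", "c2", "value")
val d3 = Seq(("b", "x", "3")).toDF("c1", "c2", "value")
val d4 = Seq(("b", "y", "4")).toDF("c1", "c2", "value")
val d5 = d1.unionByName(d2).unionByName(d3).unionByName(d4)

// step 3: pivot and aggregate
val d6 = d5.groupBy("c1").pivot("c2").agg(concat_ws(", ", collect_list("value")))

// step 4: join with another frame
val d7 = Seq(("a", "foo"), ("b", "bar")).toDF("c1", "extra")
val joined = d6.join(d7, Seq("c1"))

// step 5: the action that triggers the job (and, with AQE on, the exception)
joined.count()
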
stack trace:

org.apache.spark.SparkException: Exception thrown in awaitResult:
at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:226)
at org.apache.spark.sql.execution.adaptive.QueryStage.executeChildStages(QueryStage.scala:88)
at org.apache.spark.sql.execution.adaptive.QueryStage.prepareExecuteStage(QueryStage.scala:136)
at org.apache.spark.sql.execution.adaptive.QueryStage.executeCollect(QueryStage.scala:242)
at org.apache.spark.sql.Dataset$$anonfun$count$1.apply(Dataset.scala:2837)
at org.apache.spark.sql.Dataset$$anonfun$count$1.apply(Dataset.scala:2836)
at org.apache.spark.sql.Dataset$$anonfun$52.apply(Dataset.scala:3441)
at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:92)
at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:139)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:87)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$withAction(Dataset.scala:3440)
at org.apache.spark.sql.Dataset.count(Dataset.scala:2836)
at java.util.concurrent.FutureTask.run(FutureTask.java:266)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Caused by: java.lang.IllegalArgumentException: Can't zip RDDs with unequal numbers of partitions: List(2, 1)
at org.apache.spark.rdd.ZippedPartitionsBaseRDD.getPartitions(ZippedPartitionsRDD.scala:57)
at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:273)
at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:269)
at scala.Option.getOrElse(Option.scala:121)
at org.apache.spark.rdd.RDD.partitions(RDD.scala:269)
at org.apache.spark.rdd.MapPartitionsRDD.getPartitions(MapPartitionsRDD.scala:49)
at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:273)
at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:269)
at scala.Option.getOrElse(Option.scala:121)
at org.apache.spark.rdd.RDD.partitions(RDD.scala:269)
at org.apache.spark.rdd.MapPartitionsRDD.getPartitions(MapPartitionsRDD.scala:49)
at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:273)
at org.apache.spark.rdd.RDD$$anonfun$partitions$2.apply(RDD.scala:269)
at scala.Option.getOrElse(Option.scala:121)
at org.apache.spark.rdd.RDD.partitions(RDD.scala:269)
at org.apache.spark.ShuffleDependency.<init>(Dependency.scala:94)
at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.prepareShuffleDependency(ShuffleExchangeExec.scala:361)
at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.prepareShuffleDependency(ShuffleExchangeExec.scala:69)
at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec.eagerExecute(ShuffleExchangeExec.scala:112)
at org.apache.spark.sql.execution.adaptive.ShuffleQueryStage.executeStage(QueryStage.scala:284)
at org.apache.spark.sql.execution.adaptive.QueryStage.doExecute(QueryStage.scala:236)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:137)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:133)
at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:161)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:158)
at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:133)
at org.apache.spark.sql.execution.adaptive.QueryStage$$anonfun$8$$anonfun$apply$2$$anonfun$apply$3.apply(QueryStage.scala:81)
at org.apache.spark.sql.execution.adaptive.QueryStage$$anonfun$8$$anonfun$apply$2$$anonfun$apply$3.apply(QueryStage.scala:81)
at org.apache.spark.sql.execution.SQLExecution$.withExecutionIdAndJobDesc(SQLExecution.scala:157)
at org.apache.spark.sql.execution.adaptive.QueryStage$$anonfun$8$$anonfun$apply$2.apply(QueryStage.scala:80)
at org.apache.spark.sql.execution.adaptive.QueryStage$$anonfun$8$$anonfun$apply$2.apply(QueryStage.scala:78)
at scala.concurrent.impl.Future$PromiseCompletingRunnable.liftedTree1$1(Future.scala:24)
at scala.concurrent.impl.Future$PromiseCompletingRunnable.run(Future.scala:24)

There are four points to note:

  1. I never call zip (or anything like it) in my code
  2. When I set the parameter "spark.sql.adaptive.enabled" to "false", the error disappears
  3. Others have encountered this problem: https://github.com/Intel-bigdata/spark-adaptive/issues/73
  4. Spark version: 2.4.7

Unfortunately, I can't share the full code. I removed some sensitive information; the snippet below contains the main execution logic.

Another discovery: if I run the task with spark-shell instead of spark-submit, the error disappears even when "spark.sql.adaptive.enabled" is set to "true".
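
One thing worth checking (a debugging sketch added here, not output from my original run) is what each launch mode actually ends up using:

// print the effective settings in both the spark-shell and spark-submit runs
println(spark.version)
println(spark.conf.get("spark.sql.adaptive.enabled", "<unset>"))
println(spark.conf.get("spark.sql.shuffle.partitions", "<unset>"))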

val tagTableId = "customer_tag"
val tagMeta = Map(
    "t1" -> (
    "tagId" -> "t1",
    "tagName" -> "t1",
    "valueType" -> "String",
    "valueNumType" -> "multi"
    ),
    "t2" -> (
    "tagId" -> "t2",
    "tagName" -> "t2",
    "valueType" -> "String",
    "valueNumType" -> "multi"
    ),
    "t3" -> (
    "tagId" -> "t3",
    "tagName" -> "t3",
    "valueType" -> "String",
    "valueNumType" -> "multi"
    ),
    "t4" -> (
    "tagId" -> "t4",
    "tagName" -> "t4",
    "valueType" -> "String",
    "valueNumType" -> "multi"
    ),
    "t5" -> (
    "tagId" -> "t5",
    "tagName" -> "t5",
    "valueType" -> "String",
    "valueNumType" -> "single"
    ),
    "t6" -> (
    "tagId" -> "t6",
    "tagName" -> "t6",
    "valueType" -> "String",
    "valueNumType" -> "single"
    ),
    "t7" -> (
    "tagId" -> "t7",
    "tagName" -> "t7",
    "valueType" -> "String",
    "valueNumType" -> "multi"
    ),
    "t8" -> (
    "tagId" -> "t8",
    "tagName" -> "t8",
    "valueType" -> "String",
    "valueNumType" -> "single"
    ),
    "t9" -> (
    "tagId" -> "t9",
    "tagName" -> "t9",
    "valueType" -> "String",
    "valueNumType" -> "single"
    ),
    "t10" -> (
    "tagId" -> "t10",
    "tagName" -> "t10",
    "valueType" -> "String",
    "valueNumType" -> "multi"
    )
)
// bucket the tag ids by value type
val textTagIds = ArrayBuffer[String]()
val numTagIds = ArrayBuffer[String]()
val dateTagIds = ArrayBuffer[String]()
val dateTimeTagIds = ArrayBuffer[String]()

tagMeta.foreach { case (tagId, meta) =>
    // meta is a Tuple4 of key -> value pairs; the third pair holds the valueType
    val valueType = meta._3._2

    valueType match {
    case "String" =>
        textTagIds += tagId
    case "Number" =>
        numTagIds += tagId
    case "Date" =>
        dateTagIds += tagId
    case "DateTime" =>
        dateTimeTagIds += tagId
    case _ =>
        throw new RuntimeException(s"invalid valueType: $valueType")
    }
}

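// keep one identity value per (_uid, _type) and pivot the identity types into columns, one row per _uid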
val identitySql = "SELECT _uid, _type, _value, row_number() over(partition by _uid, _type order by _value desc) as rn FROM customer_identity WHERE _type IN ('membership_id')"
var oneDs = spark.sql(identitySql)
oneDs.createOrReplaceTempView("u")
oneDs = spark.sql(s"SELECT _uid, _type, _value FROM u WHERE rn <= 1")
    .groupBy("_uid")
    .pivot("_type")
    .agg(collect_list("_value").as("_value"))
oneDs.createOrReplaceTempView("u")

var textFrame: DataFrame = null
var numFrame: DataFrame = null
var dateFrame: DataFrame = null
var datetimeFrame: DataFrame = null

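// build one DataFrame per value type, each selecting the matching value column from the tag table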
if (textTagIds.nonEmpty) {
    val tagIdsText = textTagIds.mkString("', '")
    val sql = s"SELECT _profile_id, tag_id, _value_text AS value, _weight AS weight FROM $tagTableId WHERE tag_id IN ('$tagIdsText')"
    textFrame = spark.sql(sql)
}

if (numTagIds.nonEmpty) {
    val tagIdsText = numTagIds.mkString("', '")
    val sql = s"SELECT _profile_id, tag_id, _value_num AS value, _weight AS weight FROM $tagTableId WHERE tag_id IN ('$tagIdsText')"
    numFrame = spark.sql(sql)
}

if (dateTagIds.nonEmpty) {
    val tagIdsText = dateTagIds.mkString("', '")
    val sql = s"SELECT _profile_id, tag_id, _value_date AS value, _weight AS weight FROM $tagTableId WHERE tag_id IN ('$tagIdsText')"
    dateFrame = spark.sql(sql).withColumn("value", date_format(col("value"), "yyyy-MM-dd"))
}

if (dateTimeTagIds.nonEmpty) {
    val tagIdsText = dateTimeTagIds.mkString("', '")
    val sql = s"SELECT _profile_id, tag_id, _value_date AS value, _weight AS weight FROM $tagTableId WHERE tag_id IN ('$tagIdsText')"
    datetimeFrame = spark.sql(sql).withColumn("value", date_format(col("value"), "yyyy-MM-dd'T'HH:mm:ss'Z'"))
}

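// merge the non-null per-type frames into a single frame with unionByName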
var tagFrame: DataFrame = textFrame

if (tagFrame == null) {
    tagFrame = numFrame
} else if (numFrame != null) {
    tagFrame = tagFrame.unionByName(numFrame)
}

if (tagFrame == null) {
    tagFrame = dateFrame
} else if (dateFrame != null) {
    tagFrame = tagFrame.unionByName(dateFrame)
}

if (tagFrame == null) {
    tagFrame = datetimeFrame
} else if (datetimeFrame != null) {
    tagFrame = tagFrame.unionByName(datetimeFrame)
}

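// small tag metadata frame (tag_id -> tag_name) that is joined back onto the tag data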
val structType = StructType(Seq(
    StructField("tag_id", StringType),
    StructField("tag_name", StringType)
))
val rows = tagMeta.map { case (tagId, meta) =>
    // the second pair in the Tuple4 holds the tagName
    val tagName = meta._2._2
    RowFactory.create(tagId, tagName.replace("'", "\\'"))
}.toList
val tagMetaFrame = spark.createDataFrame(rows.asJava, structType)
tagFrame.createOrReplaceTempView("t")
tagMetaFrame.createOrReplaceTempView("m")

var sql = s"SELECT t._profile_id AS `_profile_id`, t.tag_id, m.tag_name, t.value, t.weight FROM t JOIN m ON t.tag_id = m.tag_id"
var dataFrame = spark.sql(sql)

dataFrame.createOrReplaceTempView("t")
sql = s"SELECT u.*, t.* FROM t LEFT JOIN u ON t._profile_id = u._uid"
dataFrame = spark.sql(sql).drop("_uid")
dataFrame.createOrReplaceTempView("t")

val orderedColumns = Array(s"`_profile_id`") ++ dataFrame.columns.filter(column => column != "_profile_id").map(column => s"`$column`")
sql = s"select ${orderedColumns.mkString(",")} from t"
dataFrame = spark.sql(sql)
val total = dataFrame.count()

println(total)

Upvotes: 0

Views: 458

Answers (2)

Ged
Ged

Reputation: 18003

Adaptive Query Execution (AQE) is a layer on top of Spark's Catalyst optimizer that modifies the execution plan on the fly. This is clearly a bug in AQE for the version of Spark you are running. Turn AQE off.
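
For example (a sketch; the appName is a placeholder, and how you set the flag depends on how you launch the job):

import org.apache.spark.sql.SparkSession

// disable AQE when building the session ...
val spark = SparkSession.builder()
  .appName("job")
  .config("spark.sql.adaptive.enabled", "false")
  .getOrCreate()

// ... or flip it on an existing session before running the query
spark.conf.set("spark.sql.adaptive.enabled", "false")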

zip works on RDD partitions only when all the RDDs involved have the same number of partitions; otherwise you get this error. That's a given.
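
A standalone illustration of that constraint (nothing to do with your query, just the raw RDD behaviour):

val a = sc.parallelize(1 to 4, numSlices = 2)  // 2 partitions
val b = sc.parallelize(1 to 4, numSlices = 1)  // 1 partition
a.zip(b).count()  // IllegalArgumentException: Can't zip RDDs with unequal numbers of partitions: List(2, 1)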

  • If you did not call zip yourself in your code,
  • and the problem goes away with AQE turned off,
  • then, when AQE is turned on and this error occurs, by definition AQE is doing something during optimization that triggers the bug.

Upvotes: 1

Matt Andruff
Matt Andruff

Reputation: 5125

Spark is selecting an optimization (spark.sql.adaptive.enabled) that it should not be. You should run this query with spark.sql.adaptive.enabled = false, as you are already doing. There may be settings you could adjust that would let you run this with spark.sql.adaptive.enabled set to true, but do you actually need that optimization, and do you know which corner case you are hitting? Until you need to optimize this query, I suggest you simply leave spark.sql.adaptive.enabled = false.
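
If you want to keep the workaround scoped to just this job, something like this works (a sketch; dataFrame stands for the query from the question, and restoring the flag only matters if other queries in the same session should keep AQE on):

val previous = spark.conf.get("spark.sql.adaptive.enabled", "false")
spark.conf.set("spark.sql.adaptive.enabled", "false")
try {
  val total = dataFrame.count()   // the query that hits the bug
  println(total)
} finally {
  spark.conf.set("spark.sql.adaptive.enabled", previous)
}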

Upvotes: 1
