Reputation: 77
I am having difficulty generating unique sequence numbers to replace the null values in a column of a table. The table is obtained by joining two other tables, and the column is the primary key, where the null values are to be replaced with unique sequence values. I tried using accumulators, but I run into trouble when the program runs on a multi-node cluster.
val joined = csv2.join(csv, csv2("ACCT_PRDCT_CD") === csv("ACCT_PRDCT_CD"), "left_outer")

joined.filter("ACCT_CO_NO is null").show

// Flag the rows whose ACCT_CO_NO is null (0 marks a missing value).
val k = joined.withColumn("Acc_flag", when($"ACCT_CO_NO".isNull, 0).otherwise($"ACCT_CO_NO"))

// Mutable counter on the driver; each executor works on its own serialized
// copy, so the generated values repeat across partitions on a multi-node cluster.
var a = 1

def generate(s: Int): Int = {
  if (s == 0) {
    a = a + 1
    a
  } else {
    s
  }
}

val generateNum = udf(generate(_: Int))
val newjoined = k.withColumn("n", generateNum($"Acc_flag"))
Upvotes: 0
Views: 924
Reputation: 22439
If I understand your requirement correctly, consider using monotonically_increasing_id or RDD's zipWithIndex. To avoid collisions, the generated sequence numbers are then shifted by a number greater than the maximum existing column value before replacing the nulls.
import org.apache.spark.sql.functions._
val dfL = Seq(
  (1, "a"),
  (2, "b"),
  (3, "c"),
  (4, "d"),
  (5, "e"),
  (6, "f")
).toDF("c1", "c2")

val dfR = Seq(
  (1, 100L),
  (2, 200L),
  (3, 300L)
).toDF("c1", "c2")

val c2max = dfR.select(max($"c2")).first.getLong(0)
// c2max: Long = 300

val dfJoined = dfL.join(dfR, Seq("c1"), "left").
  select(dfL("c1"), dfR("c2"))
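For reference, the left join leaves c2 null for the three unmatched keys:

// dfJoined.show
// +---+----+
// | c1|  c2|
// +---+----+
// |  1| 100|
// |  2| 200|
// |  3| 300|
// |  4|null|
// |  5|null|
// |  6|null|
// +---+----+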
METHOD 1: using monotonically_increasing_id
dfJoined.withColumn("c2x",
  when(col("c2").isNotNull, col("c2")).
    otherwise(monotonically_increasing_id() + c2max + 1)
).show
// +---+----+-----------+
// | c1| c2| c2x|
// +---+----+-----------+
// | 1| 100| 100|
// | 2| 200| 200|
// | 3| 300| 300|
// | 4|null|25769804077|
// | 5|null|34359738669|
// | 6|null|42949673261|
// +---+----+-----------+
Note that the generated sequence numbers aren't necessarily consecutive: monotonically_increasing_id encodes the partition ID in the upper 31 bits of each 64-bit value, so the IDs jump between partitions.
METHOD 2: using RDD's zipWithIndex
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
val rdd = dfJoined.rdd.zipWithIndex.
  map{ case (row: Row, idx: Long) => Row.fromSeq(row.toSeq :+ idx) }

spark.createDataFrame(rdd,
  StructType(dfJoined.schema.fields :+ StructField("idx", LongType))
).select($"c1", $"c2",
  when(col("c2").isNotNull, col("c2")).otherwise($"idx" + c2max + 1).as("c2x")
).show
// +---+----+---+
// | c1| c2|c2x|
// +---+----+---+
// | 1| 100|100|
// | 2| 200|200|
// | 3| 300|300|
// | 4|null|304|
// | 5|null|305|
// | 6|null|306|
// +---+----+---+
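Alternatively, a window-based sketch can produce strictly consecutive numbers without dropping to the RDD API, by keeping a running count of the null rows. This assumes c1 gives a stable ordering, and the unpartitioned window pulls all rows through a single partition (Spark will warn about this), so it only suits modest data sizes:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Running count of null rows seen so far in c1 order: nulls get 1, 2, 3, ...
val w = Window.orderBy("c1")

dfJoined.withColumn("c2x",
  when($"c2".isNotNull, $"c2").
    otherwise(sum(when($"c2".isNull, 1L).otherwise(0L)).over(w) + c2max)
).show
// the null rows become 301, 302, 303 -- consecutive and collision-free

Like METHOD 2, this trades parallelism for consecutive values; the offset by c2max again guarantees no collision with the existing IDs.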
Upvotes: 1