Reputation: 57
I grouped a DataFrame and summed a float column, and the result is not what I expected.
It is not specific to grouping: the same thing happens whenever I cast a float to a double.
Here is an example.
>>> from pyspark.sql.functions import *
>>> from pyspark.sql.types import *
>>> schema = StructType([ \
... StructField("firstname",StringType(),True), \
... StructField("middlename",StringType(),True), \
... StructField("v",FloatType(),True)])
>>>
>>> df = spark.createDataFrame([["a","b",1.12],["a","b",2.23],["a","c",7.78]],schema=schema)
>>> df.show()
+---------+----------+----+
|firstname|middlename| v|
+---------+----------+----+
| a| b|1.12|
| a| b|2.23|
| a| c|7.78|
+---------+----------+----+
>>> df.groupBy("firstname","middlename").agg(sum("v")).show()
+---------+----------+-----------------+
|firstname|middlename| sum(v)|
+---------+----------+-----------------+
| a| b|3.350000023841858|
| a| c| 7.78000020980835|
+---------+----------+-----------------+
>>> df.groupBy("firstname","middlename").agg(sum("v").cast("float")).show()
+---------+----------+---------------------+
|firstname|middlename|CAST(sum(v) AS FLOAT)|
+---------+----------+---------------------+
| a| b| 3.35|
| a| c| 7.78|
+---------+----------+---------------------+
>>> df.select(col("v"), col("v").cast("double")).show()
+----+------------------+
| v| v|
+----+------------------+
|1.12|1.1200000047683716|
|2.23|2.2300000190734863|
|7.78| 7.78000020980835|
+----+------------------+
I think this comes from the difference in precision between the two types (4 bytes vs. 8 bytes), but it still looks like a bug to me, because the value of a float should be preserved when it is cast to a double.
As shown above, one workaround is to cast the sum back to float after grouping, but that does not feel clean to me.
Is there a better solution for this?
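On the precision point, here is a small sketch in plain Python (no Spark) of what I believe is happening. It assumes Spark's FloatType is an IEEE 754 32-bit float and emulates it with the standard `struct` module; `to_float32` is a helper name I made up for this illustration.

```python
import struct

def to_float32(x: float) -> float:
    """Round a 64-bit Python float to the nearest 32-bit float,
    then return that value widened back to 64 bits."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

# What a 32-bit float column would hold for the literal 1.12 (assumption:
# IEEE 754 single precision), viewed as a 64-bit double.
widened = to_float32(1.12)

# The widened value is not the double nearest to 1.12 ...
print(widened == 1.12)

# ... but narrowing it again gives back the same 32-bit value,
# so the widening step itself did not change anything.
print(to_float32(widened) == widened)

# Summing the widened values carries the 32-bit rounding error along.
total = to_float32(1.12) + to_float32(2.23)
print(total == 3.35)
```

If that assumption holds, the cast to double does preserve the float's value; the extra digits were already in the float and only become visible at double precision.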
Upvotes: 0
Views: 3340
Reputation: 57
I found an answer: cast column v to string, and then to double, before aggregating.
For example:
>>> from pyspark.sql import functions as F
>>> df.withColumn("v", F.col("v").cast("string").cast("double")) \
... .groupBy("firstname","middlename").agg(F.sum("v")).show()
+---------+----------+------+
|firstname|middlename|sum(v)|
+---------+----------+------+
| a| b| 3.35|
| a| c| 7.78|
+---------+----------+------+
>>> df.withColumn("v", F.col("v").cast("string").cast("double")) \
... .groupBy("firstname","middlename").agg(F.sum("v")).printSchema()
root
|-- firstname: string (nullable = true)
|-- middlename: string (nullable = true)
|-- sum(v): double (nullable = true)
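My understanding of why the string round trip helps, sketched in plain Python without Spark: a float prints as the shortest decimal that still identifies it, and parsing that short decimal as a double gives the double nearest to the decimal rather than the widened float. The helpers `to_float32` and `shortest_float32_repr` are names I made up for this sketch, and they only emulate a 32-bit float with `struct`; they are not Spark's actual formatting code.

```python
import struct

def to_float32(x: float) -> float:
    """Round a 64-bit Python float to the nearest 32-bit float (as a double)."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def shortest_float32_repr(x: float) -> str:
    """Shortest decimal string that still round-trips through a 32-bit float."""
    f32 = to_float32(x)
    # 9 significant digits are always enough to identify a 32-bit float.
    for digits in range(1, 10):
        text = f"{f32:.{digits}g}"
        if to_float32(float(text)) == f32:
            return text
    return repr(f32)

# Direct widening keeps the 32-bit rounding error visible in the double.
direct = to_float32(1.12)

# Going through the short decimal string lands on the double nearest to 1.12.
via_string = float(shortest_float32_repr(1.12))

print(shortest_float32_repr(1.12))
print(direct == 1.12, via_string == 1.12)
```

Note that this changes the stored value slightly (it snaps to the nearest "nice" decimal), so it is a presentation-friendly workaround rather than a more exact computation.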
Upvotes: 1