Shilpa

Reputation: 101

How do I convert a unicode list contained in a pyspark dataframe column into a list of floats?

I have created a dataframe as shown below:

   import ast
   from pyspark.sql.functions import udf, col
   values = [(u"['2','4','713']",10),(u"['12','245']",20),(u"['101','12']",30)]
   df = sqlContext.createDataFrame(values,['list','A'])
   df.show()
   +---------------+---+
   |           list|  A|
   +---------------+---+
   |['2','4','713']| 10|
   |   ['12','245']| 20|
   |   ['101','12']| 30|
   +---------------+---+

**How can I convert the above dataframe so that each element in the list is a float and sits inside a proper list?**
I tried the following:

   def df_amp_conversion(df_modelamp):
      string_list_to_list = udf(lambda row: ast.literal_eval(str(row)))
      df_modelamp  = df_modelamp.withColumn('float_list',string_list_to_list(col("list")))

   df2 = df_amp_conversion(df)

But the data remains the same, without any change. I don't want to convert the dataframe to pandas or use collect(), as both are memory intensive. If possible, please suggest an optimal solution. I am using pyspark.

Upvotes: 1

Views: 1440

Answers (2)

OmG

Reputation: 18838

I can produce the correct result in Python 3 with a small change to the definition of df_amp_conversion: you didn't return the value of df_modelamp! This code works properly for me:

import ast
from pyspark.sql.functions import udf, col
values = [(u"['2','4','713']",10),(u"['12','245']",20),(u"['101','12']",30)]

df = sqlContext.createDataFrame(values,['list','A'])


def df_amp_conversion(df_modelamp):
    string_list_to_list = udf(lambda row: ast.literal_eval(str(row)))
    df_modelamp  = df_modelamp.withColumn('float_list',string_list_to_list(col("list")))
    return df_modelamp

df2 = df_amp_conversion(df)
df2.show()

#    +---------------+---+-----------+
#    |           list|  A| float_list|
#    +---------------+---+-----------+
#    |['2','4','713']| 10|[2, 4, 713]|
#    |   ['12','245']| 20|  [12, 245]|
#    |   ['101','12']| 30|  [101, 12]|
#    +---------------+---+-----------+

Upvotes: 0

user10944437

Reputation:

That's because you forgot to specify the udf's return type:

udf(lambda row: ast.literal_eval(str(row)), "array<integer>")
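For example, a minimal sketch assuming the df built in the question (the name string_list_to_floats is illustrative; since the question asks for floats, this variant declares array<float> and converts each element explicitly):

import ast
from pyspark.sql.functions import udf, col

# Declaring a return type makes Spark produce a real array column
# instead of the stringified Python list a typeless udf yields.
string_list_to_floats = udf(
    lambda s: [float(x) for x in ast.literal_eval(s)], "array<float>")

df.withColumn("float_list", string_list_to_floats(col("list"))).show()

# +---------------+---+-----------------+
# |           list|  A|       float_list|
# +---------------+---+-----------------+
# |['2','4','713']| 10|[2.0, 4.0, 713.0]|
# |   ['12','245']| 20|    [12.0, 245.0]|
# |   ['101','12']| 30|    [101.0, 12.0]|
# +---------------+---+-----------------+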

Though something like this would be more efficient:

from pyspark.sql.functions import regexp_replace, split

df = spark.createDataFrame(["""u'[23,4,77,890,4]"""], "string").toDF("list")

df.select(split(
    regexp_replace("list", "^u'\\[|\\]$", ""), ","
).cast("array<integer>").alias("list")).show()

# +-------------------+
# |               list|
# +-------------------+
# |[23, 4, 77, 890, 4]|
# +-------------------+
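Since the question asks for floats, the same trick can cast to array<float> instead; a sketch under the same setup:

from pyspark.sql.functions import regexp_replace, split

# Strip the u'[ prefix and ] suffix, split on commas, then cast the
# resulting array<string> straight to array<float>.
df.select(split(
    regexp_replace("list", "^u'\\[|\\]$", ""), ","
).cast("array<float>").alias("list")).show(truncate=False)

# +-----------------------------+
# |list                         |
# +-----------------------------+
# |[23.0, 4.0, 77.0, 890.0, 4.0]|
# +-----------------------------+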

Upvotes: 2
