Reputation: 2439
Say I have a very basic Spark DataFrame that consists of a couple of columns, one of which contains a value that I want to modify.
|| value || lang ||
| 3 | en |
| 4 | ua |
Say I want to produce, for each class, a new column in which a float is added to the given value. (This is not very relevant to the final question; in reality I do a prediction with sklearn there, but for simplicity let's assume we are just adding. The point is that I modify the value in some way.) So, given a dict classes={'1':2.0, '2':3.0},
I would like, for each class, a column holding the value from the DataFrame plus the value of the class, and then save the result to a CSV:
class_1.csv
|| value || lang || my_class || modified ||
| 3 | en | 1 | 5.0 | # this is 3+2.0
| 4 | ua | 1 | 6.0 | # this is 4+2.0
class_2.csv
|| value || lang || my_class || modified ||
| 3 | en | 2 | 6.0 | # this is 3+3.0
| 4 | ua | 2 | 7.0 | # this is 4+3.0
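To make the arithmetic in the two tables concrete, here is a plain-Python sketch (no Spark involved) that builds the same rows from the sample data. The helper name expected_rows is made up for illustration, and the sample rows are simply the ones from the tables above:

```python
from itertools import product

rows = [(3, "en"), (4, "ua")]      # the sample (value, lang) rows
classes = {"1": 2.0, "2": 3.0}     # class label -> amount to add


def expected_rows(rows, classes):
    """Return {class label: [(value, lang, my_class, modified), ...]}."""
    out = {label: [] for label in classes}
    # Pair every row with every class, like a cross join would.
    for (value, lang), (label, delta) in product(rows, classes.items()):
        out[label].append((value, lang, label, float(value + delta)))
    return out


expected = expected_rows(rows, classes)
```

Each key of expected corresponds to one of the class_N.csv files described above.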
So far I have the following code, which works and modifies the value for each defined class, but it does so with a for loop and I am looking for a more efficient way to do it:
import pyspark
from pyspark import SparkConf, SparkContext
from pyspark.sql import functions as F
from pyspark.sql.types import FloatType
from pyspark.sql.functions import udf
from pyspark.sql.functions import lit
# create session and context
spark = pyspark.sql.SparkSession.builder.master("yarn").appName("SomeApp").getOrCreate()
conf = SparkConf().setAppName('Some_App').setMaster("local[*]")
sc = SparkContext.getOrCreate(conf)
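# header/inferSchema are assumed here so that a numeric "value" column exists
my_df = spark.read.csv("some_file.csv", header=True, inferSchema=True)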
# modify the value here
def do_stuff_to_column(value, separate_class):
    # do stuff to column, let's pretend we just add a specific value per specific class that is read from a dictionary
    class_dict = {'1':2.0, '2':3.0} # would be loaded from somewhere
    return float(value+class_dict[separate_class])
# iterate over each given class later
class_dict = {'1':2.0, '2':3.0} # in reality have more than 10 classes
# create a udf function
udf_modify = udf(do_stuff_to_column, FloatType())
# loop over each class
for my_class in class_dict:
    # create the column first with lit
    my_df2 = my_df.withColumn("my_class", lit(my_class))
    # modify using udf function
    my_df2 = my_df2.withColumn("modified", udf_modify("value","my_class"))
    # write to csv now
    my_df2.write.format("csv").save("class_"+my_class+".csv")
So the question is: is there a better/faster way of doing this than in a for loop?
Upvotes: 0
Views: 642
Reputation: 1721
I would use some form of join, in this case crossJoin. Here's a MWE:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(3, 'en'), (4, 'ua')], ['value', 'lang'])
classes = spark.createDataFrame([(1, 2.), (2, 3.)], ['class_key', 'class_value'])
res = df.crossJoin(classes).withColumn('modified', F.col('value') + F.col('class_value'))
res.show()
For saving as separate CSVs, I think there is no better way than to use a loop.
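That said, one alternative that may be worth trying is the writer's partitionBy, which writes one subdirectory per distinct key in a single job instead of one file per loop iteration. A minimal sketch, assuming res is the cross-joined DataFrame from the MWE; the function name write_per_class and the out_dir argument are hypothetical names introduced here:

```python
def write_per_class(res, out_dir, key="class_key"):
    """Write `res` (assumed to be a pyspark.sql.DataFrame) as CSV,
    split into one subdirectory per distinct value of `key`."""
    (
        res.write
        .partitionBy(key)          # one subdirectory per class, e.g. class_key=1/
        .mode("overwrite")         # replace any previous output in out_dir
        .option("header", True)    # write column names as the first line
        .csv(out_dir)
    )
```

Calling write_per_class(res, "classes_out") should give directories such as classes_out/class_key=1/ rather than files named class_1.csv. Note that the partition column is encoded in the directory name and left out of the file contents, so if the class also has to appear as a column inside each file, it would need to be duplicated under another name before writing.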
Upvotes: 1