Bowen Peng

Reputation: 1825

PySpark: how to use `StringIndexer` to label-encode a string array column

As we know, StringIndexer can do the job of LabelEncoder() on a string column, but doing the same on a string array column is not easy to implement.

# input
df.show()

+--------------------------------------+
|                                  tags|
+--------------------------------------+
|        [industry, display, Merchants]|
|    [smart, swallow, game, Experience]|
|             [social, picture, social]|
|        [default, game, us, adventure]|
| [financial management, loan, product]|
|       [system, profile, optimization]|

...
# After applying LabelEncoder() to the `tags` column
...

+--------------------------------------+
|                                  tags|
+--------------------------------------+
|                             [0, 1, 2]|
|                          [3, 4, 4, 5]|
|                             [6, 7, 6]|
|                         [8, 4, 9, 10]|
|                          [11, 12, 13]|
|                          [14, 15, 16]|

Upvotes: 1

Views: 1655

Answers (2)

dominic

Reputation: 654

You can create a class that explodes the array column, applies the StringIndexer, and collects the indexes back into a list. The benefit of wrapping this in a class, rather than running the transformations step by step, is that it can be used in a pipeline or saved once fitted.

A class doing all the transformations and applying a StringIndexer:

from pyspark.ml import Estimator, Model
from pyspark.ml.feature import StringIndexer
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
from pyspark.sql.functions import monotonically_increasing_id


class ArrayStringIndexerModel(Model, DefaultParamsReadable, DefaultParamsWritable):

    def __init__(self, indexer, inputCol: str, outputCol: str):
        super(ArrayStringIndexerModel, self).__init__()
        self.indexer = indexer
        self.inputCol = inputCol
        self.outputCol = outputCol

    def _transform(self, df: DataFrame) -> DataFrame:

        # Adding a temporary, always increasing row id (as in fit)
        df = df.withColumn("id_temp_added", monotonically_increasing_id())

        # Exploding "inputCol" and saving to a new dataframe (as in fit)
        df2 = df.withColumn('inputExpl', F.explode(self.inputCol)).select('id_temp_added', 'inputExpl')

        # Transforming with the fitted indexer
        indexed_df = self.indexer.transform(df2)

        # Converting the indexes back to an array per row
        # (note: collect_list does not guarantee element order after a shuffle)
        indexed_df = indexed_df.groupby('id_temp_added').agg(F.collect_list(F.col(self.outputCol)).alias(self.outputCol))

        # Joining to the main dataframe
        df = df.join(indexed_df, on='id_temp_added', how='left')

        # Dropping the temporary id column
        df = df.drop('id_temp_added')

        return df


class ArrayStringIndexer(Estimator, DefaultParamsReadable, DefaultParamsWritable):
    """
    A custom Estimator which fits a StringIndexer on an array-of-strings column
    (explodes the array and fits StringIndexer; the fitted model aggregates back)
    """

    def __init__(self, inputCol: str, outputCol: str):
        super(ArrayStringIndexer, self).__init__()
        self.inputCol = inputCol
        self.outputCol = outputCol

    def _fit(self, df: DataFrame) -> ArrayStringIndexerModel:
        # Adding a temporary, always increasing row id
        df = df.withColumn("id_temp_added", monotonically_increasing_id())

        # Exploding "inputCol" and saving to a new dataframe
        df2 = df.withColumn('inputExpl', F.explode(self.inputCol)).select('id_temp_added', 'inputExpl')

        # Fitting a StringIndexer on the exploded input column
        indexer = StringIndexer(inputCol='inputExpl', outputCol=self.outputCol)
        indexer = indexer.fit(df2)

        # Returns an ArrayStringIndexerModel with the fitted StringIndexer, input and output columns
        return ArrayStringIndexerModel(indexer=indexer, inputCol=self.inputCol, outputCol=self.outputCol)

How to use the class:

tags_indexer = ArrayStringIndexer(inputCol="tags", outputCol="tagsIndex")
tags_indexer.fit(df).transform(df).show()
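
Note that the resulting indexes will not match the numbering in the question's example: StringIndexer orders labels by descending frequency by default (the most frequent tag gets 0.0), and the output values are doubles. A minimal plain-Python sketch of the explode, index, collect logic, just to show what to expect. The helper names `fit_labels` and `transform_rows` are made up for illustration, and the alphabetical tie-break between equally frequent tags is an assumption of this sketch:

```python
from collections import Counter

def fit_labels(rows):
    """Learn a tag -> index map from all array elements (the 'explode' + 'fit' steps)."""
    counts = Counter(tag for tags in rows for tag in tags)
    # Most frequent first; ties broken alphabetically (an assumption of this sketch).
    ordered = sorted(counts, key=lambda tag: (-counts[tag], tag))
    return {tag: float(i) for i, tag in enumerate(ordered)}

def transform_rows(rows, mapping):
    """Replace each tag with its index, keeping the array shape (the 'collect' step)."""
    return [[mapping[tag] for tag in tags] for tags in rows]

rows = [
    ["industry", "display", "Merchants"],
    ["social", "picture", "social"],
]
mapping = fit_labels(rows)
encoded = transform_rows(rows, mapping)
print(encoded)  # prints [[3.0, 2.0, 1.0], [0.0, 4.0, 0.0]]
```

Here "social" appears twice, so it gets 0.0; the remaining tags are numbered by the tie-break rule rather than by order of first appearance.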

Upvotes: 0

chlebek

Reputation: 2451

This is Scala, but the Python version will be very similar:

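import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.functions.{collect_list, explode, monotonically_increasing_id}
import spark.implicits._  // enables the 'id symbol syntax; assumes a SparkSession named spark

// add a unique id to each row, then explode the tags into one row per tag
val df2 = df.withColumn("id", monotonically_increasing_id()).select('id, explode('tags).as("tag"))

val indexer = new StringIndexer()
  .setInputCol("tag")
  .setOutputCol("tagIndex")

val indexed = indexer.fit(df2).transform(df2)

// in the final step, collect the tag indexes back into an array per row
val dfFinal = indexed.groupBy('id).agg(collect_list('tagIndex))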

Upvotes: 2
