Alan

Reputation: 469

__getnewargs__ error while using udf in Pyspark

There is a dataframe with 2 columns (db and tb): db stands for the database and tb stands for the tableName of that database.

   +--------------------+--------------------+
   |            database|           tableName|
   +--------------------+--------------------+
   |aaaaaaaaaaaaaaaaa...|    tttttttttttttttt|
   |bbbbbbbbbbbbbbbbb...|    rrrrrrrrrrrrrrrr|
   |aaaaaaaaaaaaaaaaa...|  ssssssssssssssssss|

I have the following method in python:

 def _get_tb_db(db, tb):
      df = spark.sql("select * from {}.{}".format(db, tb))
      return df.dtypes

and this udf:

 test = udf(lambda db, tb: _get_tb_db(db, tb), StringType())

while running this:

 df = df.withColumn("dtype", test(col("db"), col("tb")))

there is following error:

 pickle.PicklingError: Could not serialize object: Py4JError: An 
 error occurred while calling o58.__getnewargs__. Trace:
 py4j.Py4JException: Method __getnewargs__([]) does not exist

I found a related discussion on Stack Overflow: Spark __getnewargs__ error, but I am still not sure how to resolve this issue. Is the error because I am creating another dataframe inside the UDF?

Similar to the solution in the link, I tried this:

       cols = copy.deepcopy(df.columns)
       df = df.withColumn("dtype", scanning(cols[0], cols[1]))

but I still get the error.

Any solution?

Upvotes: 0

Views: 270

Answers (1)

jxc

Reputation: 13998

The error means that you cannot use a Spark dataframe inside a UDF. But since the dataframe containing your database and table names is most likely small, it's enough to use a plain Python for loop. Below are some methods which might help get your data:

from pyspark.sql import Row

# assume dfs is the df containing database names and table names
dfs.printSchema()
root
 |-- database: string (nullable = true)
 |-- tableName: string (nullable = true)

Method-1: use df.dtypes

Run the SQL select * from database.tableName limit 1 to generate a df and return its dtypes, converted into a string so it fits a StringType() column.

data = []
DRow = Row('database', 'tableName', 'dtypes')
for row in dfs.collect():
  try:
    dtypes = spark.sql('select * from `{}`.`{}` limit 1'.format(row.database, row.tableName)).dtypes
    data.append(DRow(row.database, row.tableName, str(dtypes)))
  except Exception as e:
    print("ERROR from {}.{}: [{}]".format(row.database, row.tableName, e))
    pass

df_dtypes = spark.createDataFrame(data)
# DataFrame[database: string, tableName: string, dtypes: string]

Note:

  • using dtypes instead of str(dtypes) will give the following schema, where _1 and _2 are col_name and col_dtype respectively:

    root
     |-- database: string (nullable = true)
     |-- tableName: string (nullable = true)
     |-- dtypes: array (nullable = true)
     |    |-- element: struct (containsNull = true)
     |    |    |-- _1: string (nullable = true)
     |    |    |-- _2: string (nullable = true)
    
  • using this method, each table will have only one row. With the next two methods, each col_type of a table gets its own row (see the sketch right after this note for normalizing the array schema above into the same shape).
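
If you do keep dtypes as the array of structs shown above (instead of str(dtypes)), a minimal sketch for flattening it into one row per column, assuming df_dtypes was built that way, could be:

from pyspark.sql.functions import explode, col

# assumes df_dtypes was created with `dtypes` (array of structs) instead of str(dtypes)
df_normalized = dfs_exploded = df_dtypes \
    .select('database', 'tableName', explode('dtypes').alias('c')) \
    .select(
          'database'
        , 'tableName'
        , col('c._1').alias('col_name')
        , col('c._2').alias('col_dtype'))
# DataFrame[database: string, tableName: string, col_name: string, col_dtype: string]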

Method-2: use describe

You can also retrieve this information by running spark.sql("describe tableName"), which returns a dataframe directly, and then use a reduce function to union the results from all tables.

from functools import reduce

def get_df_dtypes(db, tb):
  try:
    return spark.sql('desc `{}`.`{}`'.format(db, tb)) \
                .selectExpr(
                      '"{}" as `database`'.format(db)
                    , '"{}" as `tableName`'.format(tb)
                    , 'col_name'
                    , 'data_type')
  except Exception as e:
    print("ERROR from {}.{}: [{}]".format(db, tb, e))
    pass

# an example table:
get_df_dtypes('default', 'tbl_df1').show()
+--------+---------+--------+--------------------+
|database|tableName|col_name|           data_type|
+--------+---------+--------+--------------------+
| default|  tbl_df1| array_b|array<struct<a:st...|
| default|  tbl_df1| array_d|       array<string>|
| default|  tbl_df1|struct_c|struct<a:double,b...|
+--------+---------+--------+--------------------+

# use reduce function to union all tables into one df
df_dtypes = reduce(lambda d1, d2: d1.union(d2), [ get_df_dtypes(row.database, row.tableName) for row in dfs.collect() ])
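
Note that get_df_dtypes returns None for a table it could not describe, which would break the union. A small guard, filtering out the failed tables first, might look like this:

# collect the per-table dataframes first and drop the ones that failed (None)
frames = [ get_df_dtypes(row.database, row.tableName) for row in dfs.collect() ]
df_dtypes = reduce(lambda d1, d2: d1.union(d2), [ f for f in frames if f is not None ])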

Method-3: use spark.catalog.listColumns()

Use spark.catalog.listColumns(), which returns a list of pyspark.sql.catalog.Column objects (named tuples); retrieve name and dataType and merge the data. The resulting dataframe is normalized, with col_name and col_dtype in their own columns (same as with Method-2).

data = []
DRow = Row('database', 'tableName', 'col_name', 'col_dtype')
for row in dfs.select('database', 'tableName').collect():
  try:
    for col in spark.catalog.listColumns(row.tableName, row.database):
      data.append(DRow(row.database, row.tableName, col.name, col.dataType))
  except Exception as e:
    print("ERROR from {}.{}: [{}]".format(row.database, row.tableName, e))
    pass

df_dtypes = spark.createDataFrame(data)
# DataFrame[database: string, tableName: string, col_name: string, col_dtype: string]
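
As a quick check before running the full loop, you can inspect a single table directly (using the default.tbl_df1 example table from Method-2):

# print the column names and types of one table
for c in spark.catalog.listColumns('tbl_df1', 'default'):
  print(c.name, c.dataType)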

A Note: different Spark distributions/versions might return different results from describe tbl_name and other metadata commands, so make sure the correct column names are used in the queries.

Upvotes: 1
