borderline_coder

Reputation: 302

Selecting 'Exclusive Rows' from a PySpark Dataframe

I have a PySpark dataframe like this:

+----------+-----+
|account_no|types|
+----------+-----+
|         1|    K|
|         1|    A|
|         1|    S|
|         2|    M|
|         2|    D|
|         2|    S|
|         3|    S|
|         3|    S|
|         4|    M|
|         5|    K|
|         1|    S|
|         6|    S|
+----------+-----+

and I am trying to pick the account numbers that have exclusively type 'S'. For example, even though account 1 has type 'S', I will not pick it because it also has other types. But I will pick 3 and 6, because their only type is 'S'.

What I am doing right now is:
- First, find all accounts that have 'K' and remove them; in this example that removes 1 and 5.
- Second, find all accounts that have 'D' and remove them, which removes 2.
- Third, find all accounts that have 'M' and remove them, which removes 4 (2 also has 'M', but it was already removed in step 2).
- Fourth, find all accounts that have 'A' and remove them.

So now 1, 2, 4 and 5 are removed, and I am left with 3 and 6, which have exclusively 'S'.

But this is a long process. How do I optimize it? Thank you.

Upvotes: 1

Views: 263

Answers (4)

thebot

Reputation: 91

You can simply count the number of distinct types each account has and then keep the 'S' rows of accounts that have exactly one distinct type.

Here is my code for it:

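from pyspark.sql import SparkSession
from pyspark.sql.functions import col, countDistinct

spark = SparkSession.builder.getOrCreate()

# Types are lowercase in this example; use 'S' for the question's data
data = [(1, 'k'),
        (1, 'a'),
        (1, 's'),
        (2, 'm'),
        (2, 'd'),
        (2, 's'),
        (3, 's'),
        (3, 's'),
        (4, 'm'),
        (5, 'k'),
        (1, 's'),
        (6, 's')]

# Drop duplicate (account_no, types) rows up front
df = spark.createDataFrame(data, ['account_no', 'types']).distinct()

# Count distinct types per account, join the count back onto the rows,
# and keep only the 's' rows of accounts that have a single distinct type
exclusive_s_accounts = (df.groupBy('account_no').agg(countDistinct('types').alias('distinct_count'))
                        .join(df, 'account_no')
                        .where((col('types') == 's') & (col('distinct_count') == 1))
                        .drop('distinct_count'))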

Upvotes: 1

anky

Reputation: 75080

Another alternative is to count the distinct types over a window and then filter where the distinct count is 1 and types == 'S'. To keep the original row order, you can assign a monotonically increasing id and then orderBy it.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

W = Window.partitionBy('account_no')

# countDistinct cannot be used over a window; approx_count_distinct can
out = (df.withColumn("idx", F.monotonically_increasing_id())
         .withColumn("Distinct", F.approx_count_distinct(F.col("types")).over(W))
         .orderBy("idx")
         .filter("Distinct == 1 AND types == 'S'")
         .drop('idx', 'Distinct'))

out.show()

+----------+-----+
|account_no|types|
+----------+-----+
|         3|    S|
|         3|    S|
|         6|    S|
+----------+-----+

Upvotes: 2

H Roy

Reputation: 635

Another alternative approach is to collect all the types of each account into one column and then filter out the accounts that have non-'S' values.

from pyspark.sql import functions as F
from pyspark.sql.functions import col, collect_list, concat_ws

df = spark.read.csv("/Users/Downloads/account.csv", header=True, inferSchema=True, sep=",")

# One row per account, with all of its types joined into a comma-separated string
type_df = (df.groupBy("account_no")
             .agg(concat_ws(",", collect_list("types")).alias("all_types"))
             .select(col("account_no"), col("all_types")))
type_df.show()

+----------+---------+
|account_no|all_types|
+----------+---------+
|         1|  K,A,S,S|
|         6|        S|
|         3|      S,S|
|         5|        K|
|         4|        M|
|         2|    M,D,S|
+----------+---------+

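Then flag the accounts that contain any of the other types, using a regular expression (S_status is true when the account has a non-'S' type):

only_s_df = type_df.withColumn("S_status", F.col("all_types").rlike("K|A|M|D"))
only_s_df.show()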
+----------+---------+----------+
|account_no|all_types|S_status  |
+----------+---------+----------+
|         1|  K,A,S,S|      true|
|         6|        S|     false|
|         3|      S,S|     false|
|         5|        K|      true|
|         4|        M|      true|
|         2|    M,D,S|      true|
+----------+---------+----------+

From here, keep the rows where S_status is false (accounts 3 and 6) for further processing.

Upvotes: 0

murtihash

Reputation: 8410

One way to do this is with window functions. First, count the number of 'S' rows in each account_no group. Then compare that to the total number of rows in the group; in the filter, keep the rows where the two counts match.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window.partitionBy("account_no")
w1 = Window.partitionBy("account_no").orderBy("types")

# sum_S: number of 'S' rows per account
# total: number of rows per account (the largest row_number in the partition)
df.withColumn("sum_S", F.sum(F.when(F.col("types") == 'S', F.lit(1)).otherwise(F.lit(0))).over(w))\
  .withColumn("total", F.max(F.row_number().over(w1)).over(w))\
  .filter('total = sum_S').drop("total", "sum_S").show()

#+----------+-----+
#|account_no|types|
#+----------+-----+
#|         6|    S|
#|         3|    S|
#|         3|    S|
#+----------+-----+

Upvotes: 2
