flck

Reputation: 29

How to drop rows based on a condition with groupBy in PySpark

I have a dataframe called df which contains the following:

accountname |   clustername |   namespace   |   cost
account1    |   cluster_1_1 |   ns_1_1      |   10
account1    |   cluster_1_1 |   ns_1_2      |   11
account1    |   cluster_1_1 |   infra       |   12
account1    |   cluster_1_2 |   infra       |   12
account2    |   cluster_2_1 |   infra       |   13
account3    |   cluster_3_1 |   ns_3_1      |   10
account3    |   cluster_3_1 |   ns_3_2      |   11
account3    |   cluster_3_1 |   infra       |   12

df is grouped by the accountname field. I need to apply a filter on the clustername field within each accountname that does the following: when a clustername has more than one row for its accountname, remove the row where namespace = infra; when a clustername has only one row within its accountname, keep it. Something like this:

accountname |   clustername |   namespace   |   cost
account1    |   cluster_1_1 |   ns_1_1      |   10
account1    |   cluster_1_1 |   ns_1_2      |   11
account1    |   cluster_1_2 |   infra       |   12
account2    |   cluster_2_1 |   infra       |   13
account3    |   cluster_3_1 |   ns_3_1      |   10
account3    |   cluster_3_1 |   ns_3_2      |   11

Since cluster_1_1 had more than one row and one of them had the value "infra" in namespace, that row was eliminated. But cluster_1_2 and cluster_2_1 each had only one row, so they are preserved. My code is something like this:

from pyspark.sql import SparkSession
from pyspark.sql import Row

spark = SparkSession \
    .builder \
    .appName("Python Spark SQL basic example") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()

# Build sample rows matching the table above
fields = Row('accountname','clustername','namespace','cost')
s1 = fields("account1","cluster_1_1","ns_1_1",10)
s2 = fields("account1","cluster_1_1","ns_1_2",11)
s3 = fields("account1","cluster_1_1","infra",12)
s4 = fields("account1","cluster_1_2","infra",12)
s5 = fields("account2","cluster_2_1","infra",13)
s6 = fields("account3","cluster_3_1","ns_3_1",10)
s7 = fields("account3","cluster_3_1","ns_3_2",11)
s8 = fields("account3","cluster_3_1","infra",12)

fieldsData=[s1,s2,s3,s4,s5,s6,s7,s8]
df=spark.createDataFrame(fieldsData)
df.show()

Thanks in advance.

Upvotes: 0

Views: 847

Answers (1)

kites

Reputation: 1405

Check this out: you can first calculate the count of rows per clustername using a window function partitioned by accountname and clustername, and then filter with the negated condition to drop rows having a count greater than 1 and namespace = 'infra'.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# window over each (accountname, clustername) group
w = Window.partitionBy("accountname", "clustername")

df.withColumn("count", F.count("clustername").over(w))\
    .filter(~((F.col("count") > 1) & (F.col("namespace") == 'infra')))\
    .drop("count").orderBy(F.col("accountname")).show()

+-----------+-----------+---------+----+
|accountname|clustername|namespace|cost|
+-----------+-----------+---------+----+
|   account1|cluster_1_1|   ns_1_1|  10|
|   account1|cluster_1_1|   ns_1_2|  11|
|   account1|cluster_1_2|    infra|  12|
|   account2|cluster_2_1|    infra|  13|
|   account3|cluster_3_1|   ns_3_1|  10|
|   account3|cluster_3_1|   ns_3_2|  11|
+-----------+-----------+---------+----+
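If you prefer to avoid a window function, a groupBy/join variant should give the same result. This is just a sketch assuming the same df as above; the counts DataFrame and column names are illustrative:

from pyspark.sql import functions as F

# count the rows in each (accountname, clustername) group
counts = df.groupBy("accountname", "clustername") \
    .agg(F.count("*").alias("count"))

# join the counts back and apply the same negated filter
df.join(counts, ["accountname", "clustername"]) \
    .filter(~((F.col("count") > 1) & (F.col("namespace") == "infra"))) \
    .drop("count") \
    .orderBy("accountname") \
    .show()

The window version avoids the extra join, so it is usually the simpler choice here.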

Upvotes: 1
