Reputation: 29
I have a dataframe called df which contains the following:
accountname | clustername | namespace | cost
account1 | cluster_1_1 | ns_1_1 | 10
account1 | cluster_1_1 | ns_1_2 | 11
account1 | cluster_1_1 | infra | 12
account1 | cluster_1_2 | infra | 12
account2 | cluster_2_1 | infra | 13
account3 | cluster_3_1 | ns_3_1 | 10
account3 | cluster_3_1 | ns_3_2 | 11
account3 | cluster_3_1 | infra | 12
df is grouped by the accountname field. I need to filter by the clustername field within each accountname as follows: when a clustername has more than one row for its accountname, remove the row where namespace = infra; when a clustername has only one row within its accountname, keep it. Something like this:
accountname | clustername | namespace | cost
account1 | cluster_1_1 | ns_1_1 | 10
account1 | cluster_1_1 | ns_1_2 | 11
account1 | cluster_1_2 | infra | 12
account2 | cluster_2_1 | infra | 13
account3 | cluster_3_1 | ns_3_1 | 10
account3 | cluster_3_1 | ns_3_2 | 11
Since cluster_1_1 has more than one row and one of them has the value "infra" in namespace, that row is removed. But cluster_1_2 and cluster_2_1 only have one row each, so they are preserved. My code is something like this:
from pyspark.sql import SparkSession, Row

spark = SparkSession \
    .builder \
    .appName("Python Spark SQL basic example") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()

# build the sample DataFrame from Row objects
fields = Row('accountname', 'clustername', 'namespace', 'cost')
s1 = fields("account1", "cluster_1_1", "ns_1_1", 10)
s2 = fields("account1", "cluster_1_1", "ns_1_2", 11)
s3 = fields("account1", "cluster_1_1", "infra", 12)
s4 = fields("account1", "cluster_1_2", "infra", 12)
s5 = fields("account2", "cluster_2_1", "infra", 13)
s6 = fields("account3", "cluster_3_1", "ns_3_1", 10)
s7 = fields("account3", "cluster_3_1", "ns_3_2", 11)
s8 = fields("account3", "cluster_3_1", "infra", 12)

fieldsData = [s1, s2, s3, s4, s5, s6, s7, s8]
df = spark.createDataFrame(fieldsData)
df.show()
Thanks in advance.
Upvotes: 0
Views: 847
Reputation: 1405
Check this out. You can first calculate the row count per clustername using a window function partitioned by accountname and clustername, and then negate the filter condition so that rows having a count greater than 1 and namespace = 'infra' are dropped:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# count the rows in each (accountname, clustername) group
w = Window.partitionBy("accountname", "clustername")

# keep a row unless its group has more than one row and the row's namespace is 'infra'
df.withColumn("count", F.count("clustername").over(w))\
    .filter(~((F.col("count") > 1) & (F.col("namespace") == 'infra')))\
    .drop("count").orderBy(F.col("accountname")).show()
+-----------+-----------+---------+----+
|accountname|clustername|namespace|cost|
+-----------+-----------+---------+----+
| account1|cluster_1_1| ns_1_1| 10|
| account1|cluster_1_1| ns_1_2| 11|
| account1|cluster_1_2| infra| 12|
| account2|cluster_2_1| infra| 13|
| account3|cluster_3_1| ns_3_1| 10|
| account3|cluster_3_1| ns_3_2| 11|
+-----------+-----------+---------+----+
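If you prefer to avoid window functions, an equivalent approach (just a sketch, built on the same df as above) is to aggregate the counts with a groupBy and join them back before applying the same filter:

from pyspark.sql import functions as F

# count the rows in each (accountname, clustername) group
cluster_counts = df.groupBy("accountname", "clustername").agg(F.count("*").alias("cnt"))

# join the counts back and drop the infra row only from multi-row groups
df.join(cluster_counts, ["accountname", "clustername"])\
    .filter(~((F.col("cnt") > 1) & (F.col("namespace") == 'infra')))\
    .drop("cnt").orderBy("accountname").show()

The window version avoids the extra join, so it is usually the simpler of the two.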
Upvotes: 1