Hakim

Reputation: 51

Count distinct values with conditions

I have a dataframe like the one below:

+-----------+------------+-------------+-----------+
| id_doctor | id_patient | consumption | type_drug |
+-----------+------------+-------------+-----------+
| d1        | p1         |        12.0 | bhd       |
| d1        | p2         |        10.0 | lsd       |
| d1        | p1         |         6.0 | bhd       |
| d1        | p1         |        14.0 | carboxyl  |
| d2        | p1         |        12.0 | bhd       |
| d2        | p1         |        13.0 | bhd       |
| d2        | p2         |        12.0 | lsd       |
| d2        | p1         |         6.0 | bhd       |
| d2        | p2         |        12.0 | bhd       |
+-----------+------------+-------------+-----------+

I want to count distinct patients that take bhd with a consumption < 16.0 for each doctor.

I tried the following query, but it doesn't work:

dataframe.groupBy(col("id_doctor")).agg(
    countDistinct(col("id_patient")).where(
        col("type_drug") == "bhd" & col("consumption") < 16.0
    )
)

Upvotes: 4

Views: 11536

Answers (3)

mish1818

Reputation: 259

Another solution in PySpark without adding another column:

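from pyspark.sql import functions as F
from pyspark.sql.functions import col

dataframe.groupBy("id_doctor").agg(
    F.countDistinct(
        # Each comparison needs its own parentheses because & binds tighter than == and <.
        # Rows that fail the condition become null, and countDistinct ignores nulls.
        F.when(
            (col("type_drug") == "bhd") & (col("consumption") < 16.0),
            col("id_patient"),
        ).otherwise(None)
    )
)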
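A plain-Python model can make the null-skipping behavior concrete. The sketch below is not Spark: it replays the sample rows from the question, maps rows that fail the condition to None (standing in for the null produced by when), and counts only the non-None values per doctor, on the assumption that this mirrors how countDistinct treats nulls. The name count_distinct_when is hypothetical.

```python
from collections import defaultdict

# Sample rows from the question: (id_doctor, id_patient, consumption, type_drug)
ROWS = [
    ("d1", "p1", 12.0, "bhd"),
    ("d1", "p2", 10.0, "lsd"),
    ("d1", "p1", 6.0, "bhd"),
    ("d1", "p1", 14.0, "carboxyl"),
    ("d2", "p1", 12.0, "bhd"),
    ("d2", "p1", 13.0, "bhd"),
    ("d2", "p2", 12.0, "lsd"),
    ("d2", "p1", 6.0, "bhd"),
    ("d2", "p2", 12.0, "bhd"),
]


def count_distinct_when(rows):
    """Model of groupBy(id_doctor).agg(countDistinct(when(cond, id_patient)))."""
    seen = defaultdict(set)
    for doctor, patient, consumption, drug in rows:
        # when(cond, id_patient): the patient id if the condition holds, None (null) otherwise
        value = patient if drug == "bhd" and consumption < 16.0 else None
        bucket = seen[doctor]  # every doctor gets a group, even without matches
        if value is not None:  # nulls are skipped, as countDistinct does
            bucket.add(value)
    return {doctor: len(values) for doctor, values in seen.items()}


print(count_distinct_when(ROWS))  # {'d1': 1, 'd2': 2}
```

d1 only ever has p1 on bhd below 16.0, while d2 has both p1 and p2, which is where the counts 1 and 2 come from.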

Upvotes: 10

Paweł Kaczorowski

Reputation: 1562

And a solution without adding an additional column (Scala):

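import org.apache.spark.sql.functions.{col, countDistinct, when}

dataframe
    .groupBy("id_doctor")
    .agg(
        // when() needs a value argument; without otherwise(), non-matching rows become null,
        // and countDistinct skips nulls
        countDistinct(when(col("type_drug") === "bhd" && col("consumption") < 16.0, col("id_patient")))
    )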

Upvotes: 2

Steven

Reputation: 15258

Just apply where to your dataframe before grouping. Note that this version drops every id_doctor whose count would be 0:

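from pyspark.sql.functions import col, countDistinct

# Filter first (each comparison parenthesized, since & binds tighter than == and <),
# then count distinct patients per doctor
dataframe.where(
    (col("type_drug") == "bhd") & (col("consumption") < 16.0)
).groupBy(
    col("id_doctor")
).agg(
    countDistinct(col("id_patient"))
)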

With this syntax, you keep all the doctors; those with no matching rows get a count of 0:

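from pyspark.sql import functions as F
from pyspark.sql.functions import col, countDistinct

# fg holds id_patient when the condition holds and null otherwise
dataframe.withColumn(
    "fg",
    F.when(
        (col("type_drug") == "bhd")
        & (col("consumption") < 16.0),
        col("id_patient")
    )
).groupBy(
    col("id_doctor")
).agg(
    # countDistinct ignores nulls, so doctors without matches get 0
    countDistinct(col("fg"))
)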
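The difference between the two versions can be sketched in plain Python (a model of the logic, not Spark). It uses a few rows from the question plus a hypothetical doctor d3 who has no bhd rows; filter_then_count and flag_then_count are hypothetical names for the where-first version and the withColumn version.

```python
from collections import defaultdict

# Rows: (id_doctor, id_patient, consumption, type_drug).
# d1 and d2 rows are taken from the question; d3 is a hypothetical doctor with no bhd rows.
ROWS = [
    ("d1", "p1", 12.0, "bhd"),
    ("d1", "p2", 10.0, "lsd"),
    ("d2", "p1", 12.0, "bhd"),
    ("d2", "p2", 12.0, "bhd"),
    ("d3", "p3", 5.0, "lsd"),
]


def matches(drug, consumption):
    return drug == "bhd" and consumption < 16.0


def filter_then_count(rows):
    """Model of where(cond).groupBy(id_doctor).agg(countDistinct(id_patient))."""
    groups = defaultdict(set)
    for doctor, patient, consumption, drug in rows:
        if matches(drug, consumption):  # non-matching rows are dropped before grouping
            groups[doctor].add(patient)
    return {doctor: len(patients) for doctor, patients in groups.items()}


def flag_then_count(rows):
    """Model of withColumn("fg", when(cond, id_patient)) followed by countDistinct("fg")."""
    groups = defaultdict(set)
    for doctor, patient, consumption, drug in rows:
        bucket = groups[doctor]  # every doctor keeps a group
        if matches(drug, consumption):  # otherwise fg is null, and nulls are not counted
            bucket.add(patient)
    return {doctor: len(patients) for doctor, patients in groups.items()}


print(filter_then_count(ROWS))  # {'d1': 1, 'd2': 2}
print(flag_then_count(ROWS))    # {'d1': 1, 'd2': 2, 'd3': 0}
```

Filtering first removes d3's only row before grouping, so d3 never forms a group; the flag column keeps every row, so d3 appears with a count of 0.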

Upvotes: 3
