Reputation: 737
I have a PySpark DataFrame:
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
simpleData = [
    ("person0", 10, 10),
    ("person1", 1, 1),
    ("person2", 1, 0),
    ("person3", 5, 1),
]
columns = ["persons_name", "A", "B"]
exp = spark.createDataFrame(data=simpleData, schema=columns)
exp.printSchema()
exp.show()
It looks like this:
root
|-- persons_name: string (nullable = true)
|-- A: long (nullable = true)
|-- B: long (nullable = true)
+------------+---+---+
|persons_name|  A|  B|
+------------+---+---+
|     person0| 10| 10|
|     person1|  1|  1|
|     person2|  1|  0|
|     person3|  5|  1|
+------------+---+---+
Now I want to apply a threshold of 2 to the values in columns A and B, so that any value less than the threshold becomes 0 and any value greater than the threshold becomes 1.
The final result should look something like this:
+------------+---+---+
|persons_name|  A|  B|
+------------+---+---+
|     person0|  1|  1|
|     person1|  0|  0|
|     person2|  0|  0|
|     person3|  1|  0|
+------------+---+---+
How can I achieve this?
Upvotes: 0
Views: 1073
Reputation: 42392
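from pyspark.sql import functions as F
threshold = 2
# Keep persons_name; A and B become 1 if above the threshold, else 0
result = exp.select(
    'persons_name',
    *[(F.col(c) > F.lit(threshold)).cast('int').alias(c) for c in ['A', 'B']]
)
result.show()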
Upvotes: 2