Bookamp

Reputation: 682

Split spark dataframe by column value and get x number of rows per column value in the result

I have the following Spark dataframe, and I am trying to split it up by column value and return a new dataframe containing at most x rows per column value.

Suppose that this is the dataframe I have:

from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, StructType, StructField, IntegerType, DoubleType

spark = SparkSession.builder.master('local').getOrCreate()


schema = StructType([
    StructField("INDEX", IntegerType(), True),
    StructField("SYMBOL", StringType(), True),
    StructField("DATETIMETS", StringType(), True),
    StructField("PRICE", DoubleType(), True),
    StructField("SIZE", IntegerType(), True),
])

df = spark\
    .createDataFrame(
        data=[(0,'A','2002-12-01 9:30:20',19.75,30200),
             (1,'A','2002-12-02 9:31:20',29.75,30200),             
             (2,'A','2004-12-03 10:36:20',3.0,30200),
             (3,'A','2006-12-06 22:41:20',24.0,30200),
             (4,'A','2006-12-08 22:42:20',60.0,30200),
             (5,'B','2002-12-09 9:30:20',15.75,30200),
             (6,'B','2002-12-12 9:31:20',49.75,30200),             
             (7,'C','2004-11-02 10:36:20',6.0,30200),
             (8,'C','2007-12-02 22:41:20',50.0,30200),
             (9,'D','2008-12-02 22:42:20',60.0,30200),
             (10,'E','2052-12-02 9:30:20',14.75,30200),
             (11,'A','2062-12-02 9:31:20',12.75,30200),             
             (12,'A','2007-12-02 11:36:20',5.0,30200),
             (13,'A','2008-12-02 22:41:20',40.0,30200),
             (14,'A','2008-12-02 22:42:20',50.0,30200)],
        schema=schema)

Say I want at most two rows per symbol, i.e. I want to create a new dataframe with the following data:

Resulting dataframe: the first two rows for each SYMBOL (the rows with INDEX 0, 1, 5, 6, 7, 8, 9 and 10).

Is there a way to do this other than looping through each symbol and filtering with a 'where' clause?

Upvotes: 0

Views: 991

Answers (1)

akuiper

Reputation: 214927

Here is one option taking the first two rows from each SYMBOL:

# Group the rows by SYMBOL, keep at most the first two rows of each group, then convert back to a DataFrame
df.rdd.groupBy(lambda r: r['SYMBOL']).flatMap(lambda x: list(x[1])[:2]).toDF().show()

+-----+------+-------------------+-----+-----+
|INDEX|SYMBOL|         DATETIMETS|PRICE| SIZE|
+-----+------+-------------------+-----+-----+
|    0|     A| 2002-12-01 9:30:20|19.75|30200|
|    1|     A| 2002-12-02 9:31:20|29.75|30200|
|   10|     E| 2052-12-02 9:30:20|14.75|30200|
|    9|     D|2008-12-02 22:42:20| 60.0|30200|
|    7|     C|2004-11-02 10:36:20|  6.0|30200|
|    8|     C|2007-12-02 22:41:20| 50.0|30200|
|    5|     B| 2002-12-09 9:30:20|15.75|30200|
|    6|     B| 2002-12-12 9:31:20|49.75|30200|
+-----+------+-------------------+-----+-----+
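If you would rather stay in the DataFrame API, a window function gives the same result. This is a minimal sketch, not part of the original answer; it assumes Spark 2.x+ and uses INDEX to define the row order within each SYMBOL:

from pyspark.sql import Window
from pyspark.sql.functions import row_number

# Number the rows within each SYMBOL partition, ordered by INDEX,
# then keep only the first two rows per symbol
w = Window.partitionBy('SYMBOL').orderBy('INDEX')
df.withColumn('rn', row_number().over(w)) \
    .filter('rn <= 2') \
    .drop('rn') \
    .show()

Unlike the RDD groupBy approach, this does not materialise each whole group as a Python list on a single executor, so it scales better when one symbol has many rows.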

Upvotes: 1
