How to use the window function to get a single row from each group in Apache Spark

This article is a part of my "100 data engineering tutorials in 100 days" challenge. (5/100)

In this article, we will group a Spark DataFrame by a key and extract a single row from each group. I will write the code using PySpark, but the Scala API looks almost the same.

The first thing we need is an example DataFrame. Let’s imagine that we have a DataFrame of financial product sales that contains the product category, the salesperson’s name, and the number of products sold.

+---------------+---------+--------+
|       category|     name|how_many|
+---------------+---------+--------+
|      insurance|   Janusz|       0|
|savings account|  Grażyna|       1|
|    credit card|Sebastian|       0|
|       mortgage|   Janusz|       2|
|   term deposit|   Janusz|       4|
|      insurance|  Grażyna|       2|
|savings account|   Janusz|       5|
|    credit card|Sebastian|       2|
|       mortgage|Sebastian|       4|
|   term deposit|   Janusz|       9|
|      insurance|  Grażyna|       3|
|savings account|  Grażyna|       1|
|savings account|Sebastian|       0|
|savings account|Sebastian|       2|
|    credit card|Sebastian|       1|
+---------------+---------+--------+

I want to get the name of the person who sold the most products in each category.

Using the Window Function

We can get the desired outcome using a window function. The window specification partitions the DataFrame by the category and sorts the rows within each partition in descending order by the how_many column. After that, we will use the row_number function over that window to get each row's position in its group.

# imports
from pyspark.sql.functions import col, row_number
from pyspark.sql.window import Window

# code
window = Window \
    .partitionBy(col('category')) \
    .orderBy(col('how_many').desc())

df \
    .withColumn(
        'position_in_group',
        row_number().over(window)
    )

As a result, we get the following DataFrame:

+---------------+---------+--------+-----------------+
|       category|     name|how_many|position_in_group|
+---------------+---------+--------+-----------------+
|savings account|   Janusz|       5|                1|
|savings account|Sebastian|       2|                2|
|savings account|  Grażyna|       1|                3|
|savings account|  Grażyna|       1|                4|
|savings account|Sebastian|       0|                5|
|   term deposit|   Janusz|       9|                1|
|   term deposit|   Janusz|       4|                2|
|       mortgage|Sebastian|       4|                1|
|       mortgage|   Janusz|       2|                2|
|    credit card|Sebastian|       2|                1|
|    credit card|Sebastian|       1|                2|
|    credit card|Sebastian|       0|                3|
|      insurance|  Grażyna|       3|                1|
|      insurance|  Grażyna|       2|                2|
|      insurance|   Janusz|       0|                3|
+---------------+---------+--------+-----------------+


In the end, we use the where function to filter out the rows that are not the first in their respective groups, and use select to keep only the category and name columns. The full solution looks like this:

window = Window \
    .partitionBy(col('category')) \
    .orderBy(col('how_many').desc())

df \
    .withColumn(
        'position_in_group',
        row_number().over(window)
    ) \
    .where(col('position_in_group') == 1) \
    .select('category', 'name')

Here is the result we want:

+---------------+---------+
|       category|     name|
+---------------+---------+
|savings account|   Janusz|
|   term deposit|   Janusz|
|       mortgage|Sebastian|
|    credit card|Sebastian|
|      insurance|  Grażyna|
+---------------+---------+

Remember to share on social media!
If you like this text, please share it on Facebook/Twitter/LinkedIn/Reddit or other social media.

If you want to contact me, send me a message on LinkedIn or Twitter.

Would you like to have a call and talk? Please schedule a meeting using this link.


Bartosz Mikulski * data/machine learning engineer * conference speaker * co-founder of Software Craft Poznan & Poznan Scala User Group