How to use the window function to get a single row from each group in Apache Spark

In this article, we will group a Spark DataFrame by a key and extract a single row from each group. I will write the code using PySpark, but the Scala API looks almost the same.

The first thing we need is an example DataFrame. Let’s imagine that we have a DataFrame of financial product sales that contains the product category, the salesperson’s name, and the number of products sold.

+---------------+---------+--------+
|       category|     name|how_many|
+---------------+---------+--------+
|      insurance|   Janusz|       0|
|savings account|  Grażyna|       1|
|    credit card|Sebastian|       0|
|       mortgage|   Janusz|       2|
|   term deposit|   Janusz|       4|
|      insurance|  Grażyna|       2|
|savings account|   Janusz|       5|
|    credit card|Sebastian|       2|
|       mortgage|Sebastian|       4|
|   term deposit|   Janusz|       9|
|      insurance|  Grażyna|       3|
|savings account|  Grażyna|       1|
|savings account|Sebastian|       0|
|savings account|Sebastian|       2|
|    credit card|Sebastian|       1|
+---------------+---------+--------+
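If you want to follow along, here is a minimal sketch that creates this DataFrame. I assume a plain local SparkSession; the rows match the table above:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# the example sales data from the table above
df = spark.createDataFrame(
    [
        ('insurance', 'Janusz', 0),
        ('savings account', 'Grażyna', 1),
        ('credit card', 'Sebastian', 0),
        ('mortgage', 'Janusz', 2),
        ('term deposit', 'Janusz', 4),
        ('insurance', 'Grażyna', 2),
        ('savings account', 'Janusz', 5),
        ('credit card', 'Sebastian', 2),
        ('mortgage', 'Sebastian', 4),
        ('term deposit', 'Janusz', 9),
        ('insurance', 'Grażyna', 3),
        ('savings account', 'Grażyna', 1),
        ('savings account', 'Sebastian', 0),
        ('savings account', 'Sebastian', 2),
        ('credit card', 'Sebastian', 1),
    ],
    ['category', 'name', 'how_many'],
)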

I want to get the name of the person who sold the most products in each category.

Using the Window Function

We can get the desired outcome using a window function. First, we define a window that partitions the DataFrame by the category and orders the rows within each partition by the how_many column in descending order. After that, we use the row_number function over that window to get each row's position within its group.

# imports
from pyspark.sql.functions import col, row_number
from pyspark.sql.window import Window

# code
window = Window \
    .partitionBy(col('category')) \
    .orderBy(col('how_many').desc())

df \
    .withColumn(
        'position_in_group',
        row_number().over(window)
    ) \
    .show()

As a result, we get the following DataFrame:

+---------------+---------+--------+-----------------+
|       category|     name|how_many|position_in_group|
+---------------+---------+--------+-----------------+
|savings account|   Janusz|       5|                1|
|savings account|Sebastian|       2|                2|
|savings account|  Grażyna|       1|                3|
|savings account|  Grażyna|       1|                4|
|savings account|Sebastian|       0|                5|
|   term deposit|   Janusz|       9|                1|
|   term deposit|   Janusz|       4|                2|
|       mortgage|Sebastian|       4|                1|
|       mortgage|   Janusz|       2|                2|
|    credit card|Sebastian|       2|                1|
|    credit card|Sebastian|       1|                2|
|    credit card|Sebastian|       0|                3|
|      insurance|  Grażyna|       3|                1|
|      insurance|  Grażyna|       2|                2|
|      insurance|   Janusz|       0|                3|
+---------------+---------+--------+-----------------+

Finally, we use the where function to filter out the rows that are not first in their respective groups, and select to keep only the category and name columns. The full solution looks like this:

window = Window \
    .partitionBy(col('category')) \
    .orderBy(col('how_many').desc())

df \
    .withColumn(
        'position_in_group',
        row_number().over(window)
    ) \
    .where(col('position_in_group') == 1) \
    .select('category', 'name') \
    .show()

Here is the result we want:

+---------------+---------+
|       category|     name|
+---------------+---------+
|savings account|   Janusz|
|   term deposit|   Janusz|
|       mortgage|Sebastian|
|    credit card|Sebastian|
|      insurance|  Grażyna|
+---------------+---------+
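One caveat worth noting: row_number assigns distinct positions even to rows that tie on how_many, so when there is a tie for first place, the query keeps one of the tied rows arbitrarily. If you wanted to keep all tied top sellers instead, you could swap row_number for rank; here is a sketch of that variant, reusing the window and df defined above:

from pyspark.sql.functions import rank

# rank gives tied rows the same position, so a tie for the top spot
# keeps every tied row after the filter
df \
    .withColumn(
        'position_in_group',
        rank().over(window)
    ) \
    .where(col('position_in_group') == 1) \
    .select('category', 'name') \
    .show()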