Calculating the cumulative sum of a group using Apache Spark

In this article, I am going to explain how to calculate the cumulative sum of values grouped by another column.

First, I have to read the CSV file. That part is going to be a little bit tricky because, in my file, semicolons are used as the field separator, the comma is the decimal separator, and dates are in the “day-month-year” format. The default settings of Spark are not sufficient to deal with such a file, so I have to specify every parameter myself.
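For illustration, a single row of such a file might look like this (a made-up example; only the first column, the date, and the sixth column, the amount, matter here):

02-11-2018;...;...;...;...;-49,99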

import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.FloatType
import spark.implicits._

val fileContent = spark.read.option("delimiter", ";").csv(filePath)

val data = fileContent
  .select(
    // parse dates given as "day-month-year"
    to_date($"_c0", "dd-MM-yyyy").as("date"),
    // turn the decimal comma into a decimal point before casting
    regexp_replace($"_c5", ",", "\\.").cast(FloatType).as("amount")
  )

I loaded the history of my credit card transactions. I am interested only in the money I have spent, so I decided to keep only transactions with a negative amount. My “data” variable became this:

val data = fileContent
  .select(
    to_date($"_c0", "dd-MM-yyyy").as("date"),
    regexp_replace($"_c5", ",", "\\.").cast(FloatType).as("amount")
  )
  .where($"amount" < 0) // keep only the outgoing payments
  .select(
    $"date",
    abs($"amount").as("spent") // store the spending as a positive number
  )

I want to group the payments by year and month. There is no single function which returns both, so I had to get the year and the month separately, then concatenate the results.

val withMonth = data.withColumn("yearWithMonth", concat(year($"date"), month($"date")))
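A side note: the concatenated key is not zero-padded, so January 2018 becomes “20181” while November becomes “201811”. The keys are still unambiguous because the year always has four digits, but if you prefer a sortable, zero-padded key, date_format can build one directly. This is my own variation, not the approach used in the rest of this article:

// Alternative: a zero-padded, sortable key such as "2018-01"
val withMonth = data.withColumn("yearWithMonth", date_format($"date", "yyyy-MM"))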

Now, it is time to define the window used to calculate the cumulative sum. I want the payments to be grouped by year and month, so that column becomes my partition key.

val window = Window
  .partitionBy($"yearWithMonth")
... //this is not all the code you need, look below

As the second step, I specify the ordering, because I want the payments within each partition to be sorted by date.

val window = Window
  .partitionBy($"yearWithMonth")
  .orderBy($"date".asc)
... //this is not all the code you need, look below

Finally, I use the rowsBetween function to specify the window frame. (Note that you should NOT use the rangeBetween function here, because it works on the actual values of the ordering column, not on row positions; in this case we want to accumulate by position within the partition. A short example after the final window definition illustrates the difference.)

I want the sum of payments between the start of the month (the first payment within the month) and the current day.

import org.apache.spark.sql.expressions.Window

val window = Window
  .partitionBy($"yearWithMonth")
  .orderBy($"date".asc)
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)
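To illustrate the difference between rowsBetween and rangeBetween, here is a minimal, self-contained sketch of mine (the sample data is made up and is not part of the original example):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._

// Two payments on the same day, one on the next day.
val sample = Seq(
  ("2018-11-02", 10.0),
  ("2018-11-02", 20.0),
  ("2018-11-03", 5.0)
).toDF("dateString", "spent")
  .withColumn("date", to_date($"dateString", "yyyy-MM-dd"))

val byRows = Window.orderBy($"date".asc)
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)
val byRange = Window.orderBy($"date".asc)
  .rangeBetween(Window.unboundedPreceding, Window.currentRow)

// rowsBetween accumulates row by row: 10.0, 30.0, 35.0.
// rangeBetween treats rows with the same date as peers, so both
// 2018-11-02 rows share one frame and one sum: 30.0, 30.0, 35.0.
sample
  .withColumn("sumByRows", sum($"spent").over(byRows))
  .withColumn("sumByRange", sum($"spent").over(byRange))
  .show()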

Now, I can use the sum function with the specified window to get the cumulative sum.

withMonth.withColumn("spentPerMonth", sum($"spent").over(window))
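To actually look at the running totals, I can keep the result and display it sorted by date; a minimal sketch reusing the names defined above:

val result = withMonth.withColumn("spentPerMonth", sum($"spent").over(window))

result
  .select($"date", $"spent", $"spentPerMonth")
  .orderBy($"date")
  .show()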

Then I made a terrible mistake: I calculated the sum of all my credit card transactions since the day I had opened the bank account… I did not want to know that…

Bartosz Mikulski is a data scientist and software/data engineer, a conference speaker, an organizer of School of A.I. meetups in Poznań, and a co-founder of Software Craftsmanship Poznan and the Poznan Scala User Group.