Calculating the cumulative sum of a group using Apache Spark
In this article, I am going to explain how to calculate the cumulative sum of values grouped by another column.
First, I have to read the CSV file. That part is a little tricky because, in my file, semicolons separate the fields, the comma is the decimal separator, and dates use the “day-month-year” format. Spark's default settings cannot parse such a file, so I have to set the delimiter explicitly and then convert the date and amount columns myself. Because I read the file without a header, Spark names the columns _c0, _c1, and so on.
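    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types.FloatType
    import spark.implicits._ // enables the $"column" syntax

    val fileContent = spark.read.option("delimiter", ";").csv(filePath)
    val data = fileContent
      .select(
        to_date($"_c0", "dd-MM-yyyy").as("date"),
        regexp_replace($"_c5", ",", "\\.").cast(FloatType).as("amount")
      )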
I loaded the history of my credit card transactions. I am interested only in the money I have spent, so I keep only the transactions with a negative amount and store their absolute value. My “data” variable became this:
    val data = fileContent
      .select(
        to_date($"_c0", "dd-MM-yyyy").as("date"),
        regexp_replace($"_c5", ",", "\\.").cast(FloatType).as("amount")
      )
      .where($"amount" < 0)
      .select(
        $"date",
        abs($"amount").as("spent")
      )
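To check the parsing without the real file, here is a minimal sketch that runs the same transformations on a few inline rows. The rows, the local session, and the names raw, parsed, and spentOnly are my assumptions for illustration; only the two columns used above are present.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{abs, regexp_replace, to_date}
import org.apache.spark.sql.types.FloatType

// Local session for the sketch; in spark-shell a `spark` session already exists.
val spark = SparkSession.builder().master("local[1]").appName("parse-sketch").getOrCreate()
import spark.implicits._

// Hypothetical rows standing in for the CSV file (dates as dd-MM-yyyy, comma as decimal separator).
val raw = Seq(
  ("03-01-2019", "-12,50"),
  ("04-01-2019", "100,00"),
  ("15-02-2019", "-7,25")
).toDF("_c0", "_c5")

// Parse the date and turn the decimal comma into a dot before casting.
val parsed = raw.select(
  to_date($"_c0", "dd-MM-yyyy").as("date"),
  regexp_replace($"_c5", ",", "\\.").cast(FloatType).as("amount")
)

// Keep only the money spent (negative amounts) as positive numbers.
val spentOnly = parsed
  .where($"amount" < 0)
  .select($"date", abs($"amount").as("spent"))

spentOnly.show()
```

The positive row (100,00) should disappear, and the two remaining rows should show 12.5 and 7.25 in the spent column.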
I want to group the payments by year and month. I get the year and the month separately and concatenate the results into a single key. (Calling date_format($"date", "yyyy-MM") would produce a similar key in one step, with a zero-padded month.)
    val withMonth = data.withColumn("yearWithMonth", concat(year($"date"), month($"date")))
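To see what the grouping key looks like, here is a small sketch that puts the concatenated key next to one built with date_format. The two dates and the column names concatKey and formattedKey are made up for the comparison.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{concat, date_format, month, to_date, year}

// Local session for the sketch; in spark-shell a `spark` session already exists.
val spark = SparkSession.builder().master("local[1]").appName("month-key-sketch").getOrCreate()
import spark.implicits._

// Hypothetical dates, used only to compare the two keys.
val dates = Seq("2019-01-03", "2019-11-20").toDF("raw")
  .select(to_date($"raw").as("date"))

val keys = dates.select(
  $"date",
  // Year and month are cast to strings and glued together: 20191, 201911
  concat(year($"date"), month($"date")).as("concatKey"),
  // Formatted in one call, month zero-padded: 2019-01, 2019-11
  date_format($"date", "yyyy-MM").as("formattedKey")
)

keys.show()
```

Both keys separate the months correctly, so either one works as a partition column. The formatted key is only easier to read and sorts correctly as text.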
Now, it is time to define the window used to calculate the cumulative sum. I want the payments to be grouped by year and month, so the “yearWithMonth” column becomes my partition key.
    import org.apache.spark.sql.expressions.Window

    val window = Window
      .partitionBy($"yearWithMonth")
      // ... this is not all the code you need, look below
Next, I specify the order, because I want the payments within each month to be sorted by date.
    val window = Window
      .partitionBy($"yearWithMonth")
      .orderBy($"date".asc)
      // ... this is not all the code you need, look below
Finally, I use the rowsBetween function to specify the window frame. Note that you should NOT use the rangeBetween function here: it defines the frame by the value of the ordering column, not by the position of the row, so all payments made on the same date would get the same running total. In this case, we want the frame to be defined by position within the partition.
I want the sum of payments from the start of the month (the first payment within the month) up to and including the current row.
    val window = Window
      .partitionBy($"yearWithMonth")
      .orderBy($"date".asc)
      .rowsBetween(Window.unboundedPreceding, Window.currentRow)
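The difference between rowsBetween and rangeBetween shows up only when two payments share the same date. Here is a minimal sketch with made-up payments; for brevity it skips the partitioning (so Spark may warn about moving all data to a single partition), and the names byRows, byRange, rowsSum, and rangeSum are mine.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{sum, to_date}

// Local session for the sketch; in spark-shell a `spark` session already exists.
val spark = SparkSession.builder().master("local[1]").appName("frame-sketch").getOrCreate()
import spark.implicits._

// Hypothetical payments; the last two share the same date.
val payments = Seq(
  ("2019-01-03", 10.0),
  ("2019-01-05", 20.0),
  ("2019-01-05", 30.0)
).toDF("rawDate", "spent")
  .select(to_date($"rawDate").as("date"), $"spent")

val ordered = Window.orderBy($"date".asc)
// Frame by position: everything up to the current row.
val byRows = ordered.rowsBetween(Window.unboundedPreceding, Window.currentRow)
// Frame by value: everything whose date is not later than the current row's date.
val byRange = ordered.rangeBetween(Window.unboundedPreceding, Window.currentRow)

val compared = payments
  .withColumn("rowsSum", sum($"spent").over(byRows))
  .withColumn("rangeSum", sum($"spent").over(byRange))

compared.show()
```

With rangeBetween, both payments from 2019-01-05 get 60.0, because each one's frame includes the other. With rowsBetween, the total grows row by row and ends at 60.0. Keep in mind that the order of rows with equal dates is not guaranteed, so the intermediate value may be 30.0 or 40.0.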
Now, I can use the sum function with the specified window to get the cumulative sum.
    withMonth.withColumn("spentPerMonth", sum($"spent").over(window))
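To see the whole thing working, here is a minimal end-to-end sketch on a few made-up payments spread over two months. The data and the local session are assumptions for illustration; the window definition and the column names follow the steps described in this article.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{concat, month, sum, to_date, year}

// Local session for the sketch; in spark-shell a `spark` session already exists.
val spark = SparkSession.builder().master("local[1]").appName("cumsum-sketch").getOrCreate()
import spark.implicits._

// Hypothetical payments: two in January, two in February.
val data = Seq(
  ("2019-01-03", 10.0),
  ("2019-01-10", 5.0),
  ("2019-02-01", 7.0),
  ("2019-02-15", 3.0)
).toDF("rawDate", "spent")
  .select(to_date($"rawDate").as("date"), $"spent")

val withMonth = data.withColumn("yearWithMonth", concat(year($"date"), month($"date")))

val window = Window
  .partitionBy($"yearWithMonth")
  .orderBy($"date".asc)
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

val result = withMonth.withColumn("spentPerMonth", sum($"spent").over(window))

// spentPerMonth, ordered by date: 10.0, 15.0, 7.0, 10.0
result.orderBy($"date").show()
```

The running total restarts at the first payment of February, because every month is a separate partition.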
Then, I made a terrible mistake. I calculated the sum of all my credit card transactions since the day when I had opened the bank account… I did not want to know that…