Window functions with PySpark


Workflow

The flow for using window functions in PySpark is simple (a quick sketch follows these steps):

  1. Create a window
  2. Apply a function on the window
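
As a minimal, self-contained sketch of those two steps (the rest of this post walks through each in detail; df_preview is a throwaway example):

from pyspark.sql import SparkSession
from pyspark.sql.window import Window
import pyspark.sql.functions as F

spark = SparkSession.builder.master("local[*]").getOrCreate()
df_preview = spark.createDataFrame([("Austin", 100), ("Austin", 200)], ["location", "sales_amount"])

window_location = Window.partitionBy("location")  # 1. create a window
df_preview.withColumn("avg", F.avg("sales_amount").over(window_location)).show()  # 2. apply a function on the window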

I'm using Spark in Jupyter. I used this code block to set things up:

from IPython.core.interactiveshell import InteractiveShell
# Show the result of every expression in a cell, not just the last one
InteractiveShell.ast_node_interactivity = 'all'

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Local Spark session using all available cores
spark = SparkSession.builder.master("local[*]").getOrCreate()

For the examples, I'll be using this data:

data = [
    (1, "Alice", "Austin", 100),
    (2, "Bob", "Austin", 200),
    (3, "Chris", "Austin", 300),
    (4, "Dave", "Toronto", 400),
    (5, "Elisa", "Toronto", 300),
    (6, "Fabrice", "Toronto", 200),
    (7, "Girard", "Toronto", 100),
    (8, "Hal", "Tokyo", 50),
    (9, "Ignis", "Tokyo", 100),
    (10, "John", "Tokyo", 100),
]
schema = ["id", "name", "location", "sales_amount"]
df = spark.createDataFrame(data, schema)

df.show()
+---+-------+--------+------------+
| id|   name|location|sales_amount|
+---+-------+--------+------------+
|  1|  Alice|  Austin|         100|
|  2|    Bob|  Austin|         200|
|  3|  Chris|  Austin|         300|
|  4|   Dave| Toronto|         400|
|  5|  Elisa| Toronto|         300|
|  6|Fabrice| Toronto|         200|
|  7| Girard| Toronto|         100|
|  8|    Hal|   Tokyo|          50|
|  9|  Ignis|   Tokyo|         100|
| 10|   John|   Tokyo|         100|
+---+-------+--------+------------+

Creating the window

There are two steps:

  1. Set the column(s) on which you'll partition the window
  2. (optional) Set the column(s) to use for ordering the rows within each partition

Unordered window

This is a window defined without orderBy().

Syntax:

from pyspark.sql.window import Window
window_location = Window.partitionBy(col("col1"), col("col2"), col("col3"), ...)

Example:

from pyspark.sql.window import Window
window_location = Window.partitionBy(col("location"))

Ordered window

Simply add orderBy() to the created window.

Syntax:

from pyspark.sql.window import Window
window_location = Window.partitionBy(col("col1"), col("col2"), col("col3"), ...).orderBy(col("orderCol1"), col("orderCol2"), col("orderCol3"), ...)

Example:

from pyspark.sql.window import Window
window_location = Window.partitionBy(col("location")).orderBy(col("sales_amount"))
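
You can also order a partition in descending order by calling .desc() on the column:

from pyspark.sql.window import Window
from pyspark.sql.functions import col

window_location = Window.partitionBy(col("location")).orderBy(col("sales_amount").desc())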

Applying the window function

Window functions can be roughly divided into three categories:

  1. Aggregate functions - don't require ordered windows
    1. avg()
    2. sum()
    3. min()
    4. max()
  2. Ranking functions - require ordered windows
    1. row_number()
    2. rank()
    3. dense_rank()
    4. percent_rank()
    5. ntile(int)
  3. Analytical functions - require ordered windows
    1. cume_dist()
    2. lag(col_name, int)
    3. lead(col_name, int)

Aggregate functions

Aggregate window functions don't require ordered windows, so you can define the window without .orderBy().

In practice, always use aggregate window functions without orderBy(). With an ordered window, Spark's default frame runs from the start of the partition to the current row, so you get rolling aggregates instead of aggregates over the entire window (more details in this article - Window functions - inner behaviors and optimization).
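
A quick sketch of what that looks like: the same sum() applied over an ordered window yields a running total per partition.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

ordered_window = Window.partitionBy("location").orderBy("sales_amount")
# Running total: Austin yields 100, 300, 600 instead of 600 on every row
df.withColumn("running_sum", F.sum("sales_amount").over(ordered_window)).show()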

avg()

from pyspark.sql.window import Window
import pyspark.sql.functions as F

window_location = Window.partitionBy(col("location"))
df_avg = df.withColumn("avg", F.avg(col("sales_amount")).over(window_location))
df_avg.show()
+---+-------+--------+------------+-----------------+
| id|   name|location|sales_amount|              avg|
+---+-------+--------+------------+-----------------+
|  1|  Alice|  Austin|         100|            200.0|
|  2|    Bob|  Austin|         200|            200.0|
|  3|  Chris|  Austin|         300|            200.0|
|  4|   Dave| Toronto|         400|            250.0|
|  5|  Elisa| Toronto|         300|            250.0|
|  6|Fabrice| Toronto|         200|            250.0|
|  7| Girard| Toronto|         100|            250.0|
|  8|    Hal|   Tokyo|          50|83.33333333333333|
|  9|  Ignis|   Tokyo|         100|83.33333333333333|
| 10|   John|   Tokyo|         100|83.33333333333333|
+---+-------+--------+------------+-----------------+

sum()

from pyspark.sql.window import Window
import pyspark.sql.functions as F

window_location = Window.partitionBy(col("location"))
df_sum = df.withColumn("sum", F.sum(col("sales_amount")).over(window_location))
df_sum.show()
+---+-------+--------+------------+----+
| id|   name|location|sales_amount| sum|
+---+-------+--------+------------+----+
|  1|  Alice|  Austin|         100| 600|
|  2|    Bob|  Austin|         200| 600|
|  3|  Chris|  Austin|         300| 600|
|  4|   Dave| Toronto|         400|1000|
|  5|  Elisa| Toronto|         300|1000|
|  6|Fabrice| Toronto|         200|1000|
|  7| Girard| Toronto|         100|1000|
|  8|    Hal|   Tokyo|          50| 250|
|  9|  Ignis|   Tokyo|         100| 250|
| 10|   John|   Tokyo|         100| 250|
+---+-------+--------+------------+----+

min()

from pyspark.sql.window import Window
import pyspark.sql.functions as F

window_location = Window.partitionBy(col("location"))
df_min = df.withColumn("min", F.min(col("sales_amount")).over(window_location))
df_min.show()
+---+-------+--------+------------+---+
| id|   name|location|sales_amount|min|
+---+-------+--------+------------+---+
|  1|  Alice|  Austin|         100|100|
|  2|    Bob|  Austin|         200|100|
|  3|  Chris|  Austin|         300|100|
|  4|   Dave| Toronto|         400|100|
|  5|  Elisa| Toronto|         300|100|
|  6|Fabrice| Toronto|         200|100|
|  7| Girard| Toronto|         100|100|
|  8|    Hal|   Tokyo|          50| 50|
|  9|  Ignis|   Tokyo|         100| 50|
| 10|   John|   Tokyo|         100| 50|
+---+-------+--------+------------+---+

max()

from pyspark.sql.window import Window
import pyspark.sql.functions as F

window_location = Window.partitionBy(col("location"))
df_max = df.withColumn("max", F.max(col("sales_amount")).over(window_location))
df_max.show()
+---+-------+--------+------------+---+
| id|   name|location|sales_amount|max|
+---+-------+--------+------------+---+
|  1|  Alice|  Austin|         100|300|
|  2|    Bob|  Austin|         200|300|
|  3|  Chris|  Austin|         300|300|
|  4|   Dave| Toronto|         400|400|
|  5|  Elisa| Toronto|         300|400|
|  6|Fabrice| Toronto|         200|400|
|  7| Girard| Toronto|         100|400|
|  8|    Hal|   Tokyo|          50|100|
|  9|  Ignis|   Tokyo|         100|100|
| 10|   John|   Tokyo|         100|100|
+---+-------+--------+------------+---+

Ranking functions

Ranking window functions need the window to be ordered, so when creating a window for a ranking function you must specify orderBy(). If you don't, Spark SQL will throw an AnalysisException.

Example:

AnalysisException: Window function row_number() requires window to be ordered, please add ORDER BY clause. For example SELECT row_number()(value_expr) OVER (PARTITION BY window_partition ORDER BY window_ordering) from table.
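
You can reproduce this by applying row_number() over a window defined without orderBy():

from pyspark.sql.window import Window
import pyspark.sql.functions as F

# No orderBy(): row_number() has no defined ordering, so Spark raises AnalysisException
unordered_window = Window.partitionBy("location")
df.withColumn("row_number", F.row_number().over(unordered_window)).show()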

row_number()

row_number() gives each row in the partition a unique, consecutive number, even when values tie.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

window_location = Window.partitionBy("location").orderBy("sales_amount")
df_row_number = df.withColumn("row_number", F.row_number().over(window_location))
df_row_number.show()
+---+-------+--------+------------+----------+
| id|   name|location|sales_amount|row_number|
+---+-------+--------+------------+----------+
|  1|  Alice|  Austin|         100|         1|
|  2|    Bob|  Austin|         200|         2|
|  3|  Chris|  Austin|         300|         3|
|  7| Girard| Toronto|         100|         1|
|  6|Fabrice| Toronto|         200|         2|
|  5|  Elisa| Toronto|         300|         3|
|  4|   Dave| Toronto|         400|         4|
|  8|    Hal|   Tokyo|          50|         1|
|  9|  Ignis|   Tokyo|         100|         2|
| 10|   John|   Tokyo|         100|         3|
+---+-------+--------+------------+----------+

rank()

rank() gives tied rows the same rank (Ignis and John both get 2) and skips the subsequent rank(s) after a tie.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

window_location = Window.partitionBy("location").orderBy("sales_amount")
df_rank = df.withColumn("rank", F.rank().over(window_location))
df_rank.show()
+---+-------+--------+------------+----+
| id|   name|location|sales_amount|rank|
+---+-------+--------+------------+----+
|  1|  Alice|  Austin|         100|   1|
|  2|    Bob|  Austin|         200|   2|
|  3|  Chris|  Austin|         300|   3|
|  7| Girard| Toronto|         100|   1|
|  6|Fabrice| Toronto|         200|   2|
|  5|  Elisa| Toronto|         300|   3|
|  4|   Dave| Toronto|         400|   4|
|  8|    Hal|   Tokyo|          50|   1|
|  9|  Ignis|   Tokyo|         100|   2|
| 10|   John|   Tokyo|         100|   2|
+---+-------+--------+------------+----+

dense_rank()

dense_rank() works like rank() but leaves no gaps after ties.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

window_location = Window.partitionBy("location").orderBy("sales_amount")
df_dense_rank = df.withColumn("dense_rank", F.dense_rank().over(window_location))
df_dense_rank.show()
+---+-------+--------+------------+----------+
| id|   name|location|sales_amount|dense_rank|
+---+-------+--------+------------+----------+
|  1|  Alice|  Austin|         100|         1|
|  2|    Bob|  Austin|         200|         2|
|  3|  Chris|  Austin|         300|         3|
|  7| Girard| Toronto|         100|         1|
|  6|Fabrice| Toronto|         200|         2|
|  5|  Elisa| Toronto|         300|         3|
|  4|   Dave| Toronto|         400|         4|
|  8|    Hal|   Tokyo|          50|         1|
|  9|  Ignis|   Tokyo|         100|         2|
| 10|   John|   Tokyo|         100|         2|
+---+-------+--------+------------+----------+

percent_rank()

percent_rank() computes (rank - 1) / (number of rows in the partition - 1).

from pyspark.sql.window import Window
import pyspark.sql.functions as F

window_location = Window.partitionBy("location").orderBy("sales_amount")
df_percent_rank = df.withColumn("percent_rank", F.percent_rank().over(window_location))
df_percent_rank.show()
+---+-------+--------+------------+------------------+
| id|   name|location|sales_amount|      percent_rank|
+---+-------+--------+------------+------------------+
|  1|  Alice|  Austin|         100|               0.0|
|  2|    Bob|  Austin|         200|               0.5|
|  3|  Chris|  Austin|         300|               1.0|
|  7| Girard| Toronto|         100|               0.0|
|  6|Fabrice| Toronto|         200|0.3333333333333333|
|  5|  Elisa| Toronto|         300|0.6666666666666666|
|  4|   Dave| Toronto|         400|               1.0|
|  8|    Hal|   Tokyo|          50|               0.0|
|  9|  Ignis|   Tokyo|         100|               0.5|
| 10|   John|   Tokyo|         100|               0.5|
+---+-------+--------+------------+------------------+

ntile(int)

ntile(n) splits each ordered partition into n roughly equal buckets and returns the bucket number (here n = 2).

from pyspark.sql.window import Window
import pyspark.sql.functions as F

window_location = Window.partitionBy("location").orderBy("sales_amount")
df_ntile = df.withColumn("ntile", F.ntile(2).over(window_location))
df_ntile.show()
+---+-------+--------+------------+-----+
| id|   name|location|sales_amount|ntile|
+---+-------+--------+------------+-----+
|  1|  Alice|  Austin|         100|    1|
|  2|    Bob|  Austin|         200|    1|
|  3|  Chris|  Austin|         300|    2|
|  7| Girard| Toronto|         100|    1|
|  6|Fabrice| Toronto|         200|    1|
|  5|  Elisa| Toronto|         300|    2|
|  4|   Dave| Toronto|         400|    2|
|  8|    Hal|   Tokyo|          50|    1|
|  9|  Ignis|   Tokyo|         100|    1|
| 10|   John|   Tokyo|         100|    2|
+---+-------+--------+------------+-----+

Analytical functions

Analytical window functions need the window to be ordered, so when creating a window for an analytical function you must specify orderBy(). If you don't, Spark SQL will throw an AnalysisException.

Example:

AnalysisException: Window function cume_dist() requires window to be ordered, please add ORDER BY clause. For example SELECT cume_dist()(value_expr) OVER (PARTITION BY window_partition ORDER BY window_ordering) from table.

cume_dist()

cume_dist() returns the cumulative distribution: the fraction of rows in the partition with a value less than or equal to the current row's.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

window_location = Window.partitionBy(col("location")).orderBy("sales_amount")
df_cume_dist = df.withColumn("cume_dist", F.cume_dist().over(window_location))
df_cume_dist.show()
+---+-------+--------+------------+------------------+
| id|   name|location|sales_amount|         cume_dist|
+---+-------+--------+------------+------------------+
|  1|  Alice|  Austin|         100|0.3333333333333333|
|  2|    Bob|  Austin|         200|0.6666666666666666|
|  3|  Chris|  Austin|         300|               1.0|
|  7| Girard| Toronto|         100|              0.25|
|  6|Fabrice| Toronto|         200|               0.5|
|  5|  Elisa| Toronto|         300|              0.75|
|  4|   Dave| Toronto|         400|               1.0|
|  8|    Hal|   Tokyo|          50|0.3333333333333333|
|  9|  Ignis|   Tokyo|         100|               1.0|
| 10|   John|   Tokyo|         100|               1.0|
+---+-------+--------+------------+------------------+

lag(col_name, int)

lag(col, n) returns the value of col from n rows before the current row in the partition, or NULL if no such row exists.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

window_location = Window.partitionBy(col("location")).orderBy("sales_amount")
df_lag = df.withColumn("lag", F.lag(col("sales_amount"), 1).over(window_location))
df_lag.show()
+---+-------+--------+------------+----+
| id|   name|location|sales_amount| lag|
+---+-------+--------+------------+----+
|  1|  Alice|  Austin|         100|NULL|
|  2|    Bob|  Austin|         200| 100|
|  3|  Chris|  Austin|         300| 200|
|  7| Girard| Toronto|         100|NULL|
|  6|Fabrice| Toronto|         200| 100|
|  5|  Elisa| Toronto|         300| 200|
|  4|   Dave| Toronto|         400| 300|
|  8|    Hal|   Tokyo|          50|NULL|
|  9|  Ignis|   Tokyo|         100|  50|
| 10|   John|   Tokyo|         100| 100|
+---+-------+--------+------------+----+

lead(col_name, int)

lead(col, n) returns the value of col from n rows after the current row in the partition, or NULL if no such row exists.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

window_location = Window.partitionBy(col("location")).orderBy("sales_amount")
df_lead = df.withColumn("lead", F.lead(col("sales_amount"), 1).over(window_location))
df_lead.show()
+---+-------+--------+------------+----+
| id|   name|location|sales_amount|lead|
+---+-------+--------+------------+----+
|  1|  Alice|  Austin|         100| 200|
|  2|    Bob|  Austin|         200| 300|
|  3|  Chris|  Austin|         300|NULL|
|  7| Girard| Toronto|         100| 200|
|  6|Fabrice| Toronto|         200| 300|
|  5|  Elisa| Toronto|         300| 400|
|  4|   Dave| Toronto|         400|NULL|
|  8|    Hal|   Tokyo|          50| 100|
|  9|  Ignis|   Tokyo|         100| 100|
| 10|   John|   Tokyo|         100|NULL|
+---+-------+--------+------------+----+
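
Both lag() and lead() also take an optional third argument: a default value to return instead of NULL when no row exists at the given offset. A small sketch:

from pyspark.sql.window import Window
import pyspark.sql.functions as F

window_location = Window.partitionBy("location").orderBy("sales_amount")
# The third argument (0) replaces the NULL in the first row of each partition
df.withColumn("lag", F.lag("sales_amount", 1, 0).over(window_location)).show()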

That's it. Enjoy.
