
Python PySpark pivot() Function

Last Updated : 24 Sep, 2024

The pivot() function in PySpark is a powerful method for reshaping a DataFrame: it takes the unique values of a specified column and turns them into separate columns of a new DataFrame, aggregating the remaining data in the process.

In this article, we will go through a detailed example of how to use the pivot() function in PySpark, covering its usage step by step.

Introduction to PySpark pivot()

In PySpark, pivot() is part of the DataFrame API; it is called on grouped data, i.e., on the result of groupBy(). It converts rows into columns by specifying:

  • A column whose unique values will become new columns.
  • An aggregation method, such as sum() or avg(), that is applied to the values falling under each new column.

The syntax of the pivot() function is:

df.groupBy(grouping_columns).pivot(pivot_column, [values])

Where:

  • pivot_column: The column whose unique values will become the column headers.
  • values: An optional list of values from pivot_column to include as columns. If not specified, Spark first computes all distinct values of pivot_column and uses them, which adds an extra pass over the data (see the sketch below).
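For instance, assuming the sales DataFrame df that is built in the next section, the two calls below produce the same pivoted result; the second lists the pivot values explicitly, which lets Spark skip the extra job needed to discover them:

Python
# Sketch only: `df` is the sales DataFrame created in the next section.

# Let Spark discover the distinct Region values on its own
pivot_all = df.groupBy("Product").pivot("Region").sum("Sales")

# List the pivot values explicitly (avoids the distinct-value scan)
pivot_listed = df.groupBy("Product").pivot("Region", ["East", "West"]).sum("Sales")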

Example Data:

Let's consider the following DataFrame, which contains sales data of different products in various regions:

Product    Region    Sales
A          East        100
A          West        150
B          East        200
B          West        250
C          East        300
C          West        350

We will pivot this data so that each region becomes a column, with sales as the values.

Step-by-Step Implementation

1. Initial DataFrame Setup

First, let's create a PySpark DataFrame for the sales data:

Python
from pyspark.sql import SparkSession

# Create a SparkSession
spark = SparkSession.builder.appName("PySpark Pivot Example").getOrCreate()

# Sample data
data = [
    ("A", "East", 100),
    ("A", "West", 150),
    ("B", "East", 200),
    ("B", "West", 250),
    ("C", "East", 300),
    ("C", "West", 350)
]

# Create a DataFrame
columns = ["Product", "Region", "Sales"]
df = spark.createDataFrame(data, columns)

# Show the initial DataFrame
df.show()

Output:

(Screenshot: Creating a PySpark DataFrame)

2. Using pivot()

Now we will use the pivot() function to reorganize the data. We want to pivot on the Region column, so that East and West become separate columns. We will aggregate the sales by Product.

Python
# ...

# Pivot the DataFrame
pivot_df = df.groupBy("Product").pivot("Region").sum("Sales")

# Show the pivoted DataFrame
pivot_df.show()

Output:

(Screenshot: Using PySpark Pivot Function)

In this output:

  • Each Product is in a single row.
  • The East and West regions have become columns.
  • The values in the new columns are the aggregated (summed) Sales for each Product in that region.

3. Aggregating with pivot()

We can also apply additional aggregation functions during the pivot process. For example, if we had multiple rows for each Product and Region, we could use avg(), min(), max(), or other aggregate functions.

In the data below, some product-region combinations appear more than once, so the entries are aggregated during the pivot; the example computes both the sum and the average.

Python
# ...

# Sample data with multiple entries for the same product-region combination
data_with_duplicates = [
    ("A", "East", 100),
    ("A", "East", 50),
    ("A", "West", 150),
    ("B", "East", 200),
    ("B", "West", 250),
    ("B", "West", 100),
    ("C", "East", 300),
    ("C", "West", 350)
]

# Create a new DataFrame with duplicate data
df_with_duplicates = spark.createDataFrame(data_with_duplicates, columns)

# Pivot and sum the sales
pivot_df_with_aggregation_sum = df_with_duplicates.groupBy("Product").pivot("Region").sum("Sales")

# Show the result
pivot_df_with_aggregation_sum.show()

# Pivot and average the sales
pivot_df_with_aggregation_avg = df_with_duplicates.groupBy("Product").pivot("Region").avg("Sales")

# Show the result
pivot_df_with_aggregation_avg.show()

Output:

(Screenshot: Aggregating with Pivot in PySpark)

Here, for Product A, the sales from two entries in the East region have been summed (100 + 50 = 150). For Product B, the sales in the West region have also been aggregated (250 + 100 = 350).

In the averaged result, the two East entries for Product A give (100 + 50)/2 = 75.0, and the two West entries for Product B give (250 + 100)/2 = 175.0.
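If several summaries are needed at once, pivot() can also be combined with agg(). The following sketch (an addition to the article's example, reusing the df_with_duplicates DataFrame defined above) computes both the total and the average sales per region in a single pivot:

Python
from pyspark.sql import functions as F

# Apply two aggregations in one pivot; Spark names the resulting columns
# East_total, East_average, West_total and West_average.
pivot_multi_agg = (
    df_with_duplicates
    .groupBy("Product")
    .pivot("Region")
    .agg(F.sum("Sales").alias("total"), F.avg("Sales").alias("average"))
)

# Show the result
pivot_multi_agg.show()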

Conclusion

The pivot() function in PySpark is a powerful tool for transforming data. It allows us to convert row-based data into column-based data by pivoting on a specific column's values. In this article, we demonstrated how to pivot data using PySpark, with a focus on sales data by region. Additionally, we showed how to apply aggregation methods like sum() during the pivot process. By utilizing pivot(), we can restructure our DataFrame to make it more suitable for further analysis or reporting.

When using pivot(), keep in mind:

  • We should choose the pivot column based on how the data is structured and how the result will be consumed; the sketch after this list pivots the same data on Product instead of Region.
  • Ensure that the aggregation method used suits our needs, whether it be sum(), avg(), or another method.
  • Use groupBy() before pivot() to ensure the data is grouped correctly.
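To illustrate the first point, the same data can just as easily be pivoted the other way around; which orientation is right depends on how the result will be used. A minimal sketch, reusing the df DataFrame from the step-by-step example:

Python
# Pivot on Product instead of Region: each product becomes a column,
# and each region becomes a row.
pivot_by_product = df.groupBy("Region").pivot("Product").sum("Sales")
pivot_by_product.show()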
