
Compression of 2D Segment Tree in Python

Last Updated : 01 Jun, 2024

Compressing a 2D segment tree in Python means reducing its memory usage by storing only the information that queries actually need. The techniques covered here are segment tree compression, which aggregates values inside nodes, and lazy propagation, which defers updates so that range updates stay cheap. This tutorial walks through compressing a 2D segment tree in Python.

A 2D segment tree is a data structure for efficiently querying and updating rectangular regions of a 2D array. It extends the 1D segment tree: each node stores an aggregate (in this article, the sum) of one rectangular sub-region, and its children split that rectangle into smaller ones.
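To make "querying a rectangular region" concrete, here is a minimal sketch of the naive query that a 2D segment tree is meant to speed up, using the same inclusive row and column bounds as the code later in this article. The helper name region_sum is hypothetical. It takes time proportional to the rectangle's area, whereas the tree answers the same question by combining precomputed node sums.

```python
def region_sum(matrix, start_row, end_row, start_col, end_col):
    """Hypothetical helper: sum of matrix[r][c] for start_row <= r <= end_row
    and start_col <= c <= end_col (inclusive bounds), by direct iteration."""
    return sum(
        matrix[r][c]
        for r in range(start_row, end_row + 1)
        for c in range(start_col, end_col + 1)
    )


matrix = [
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
]
print(region_sum(matrix, 0, 1, 0, 1))  # 1 + 2 + 4 + 5 = 12
print(region_sum(matrix, 0, 2, 0, 2))  # 45
```

Every node value printed by the complete code in Step 4 should equal region_sum over that node's rows and columns.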

Need for Compression of 2D Segment Tree

When the input is large or memory is a concern, compressing the 2D segment tree can reduce its memory footprint, most noticeably when the array contains large regions that hold a single repeated value, while keeping queries and updates efficient.
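As a rough sense of scale, assume a square n × n matrix with n = 2^k and the quadrant split used in the complete code of Step 4, so that every internal node has four children. The uncompressed tree then has one root, four nodes at the next level, and so on down to the n^2 single-cell leaves:

```latex
N(n) = \sum_{i=0}^{k} 4^{i} = \frac{4^{k+1} - 1}{3} = \frac{4n^{2} - 1}{3}, \qquad n = 2^{k}
```

For n = 1024 that is (4 · 1024^2 - 1) / 3 = 1,398,101 nodes, whereas a fully compressed tree for a matrix holding one repeated value keeps only the root.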

Techniques for Compression of 2D Segment Tree

  1. Segment Tree Compression: Instead of storing individual elements in the segment tree nodes, aggregate or summarize the information to reduce memory usage.
  2. Lazy Propagation: Delay the updates in the segment tree nodes until necessary, minimizing the number of node updates required.
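The code in this article does not implement lazy propagation, so here is a minimal sketch of the technique on a 1D segment tree (range add, range sum), assuming that a 1D example is enough to convey the idea; extending it to the 2D tree is not shown. The class and method names are hypothetical. A fully covered node absorbs the update immediately and records it in lazy, and the pending value is pushed to its children only when a later operation needs to descend past that node.

```python
class LazySegmentTree:
    """Hypothetical 1D segment tree with range add and range sum (assumes a non-empty list)."""

    def __init__(self, values):
        self.n = len(values)
        self.sums = [0] * (4 * self.n)  # sum of each node's range; root is index 1
        self.lazy = [0] * (4 * self.n)  # pending per-element add not yet given to children
        self._build(1, 0, self.n - 1, values)

    def _build(self, node, lo, hi, values):
        if lo == hi:
            self.sums[node] = values[lo]
            return
        mid = (lo + hi) // 2
        self._build(2 * node, lo, mid, values)
        self._build(2 * node + 1, mid + 1, hi, values)
        self.sums[node] = self.sums[2 * node] + self.sums[2 * node + 1]

    def _apply(self, node, lo, hi, delta):
        # Update this node's sum now and remember the delta for its children.
        self.sums[node] += delta * (hi - lo + 1)
        self.lazy[node] += delta

    def _push(self, node, lo, hi):
        # Hand the pending delta down only when the children are about to be visited.
        if self.lazy[node]:
            mid = (lo + hi) // 2
            self._apply(2 * node, lo, mid, self.lazy[node])
            self._apply(2 * node + 1, mid + 1, hi, self.lazy[node])
            self.lazy[node] = 0

    def add(self, left, right, delta, node=1, lo=0, hi=None):
        """Add delta to every element in [left, right] (inclusive)."""
        if hi is None:
            hi = self.n - 1
        if right < lo or hi < left:
            return
        if left <= lo and hi <= right:
            self._apply(node, lo, hi, delta)  # fully covered: stop here and stay lazy
            return
        self._push(node, lo, hi)
        mid = (lo + hi) // 2
        self.add(left, right, delta, 2 * node, lo, mid)
        self.add(left, right, delta, 2 * node + 1, mid + 1, hi)
        self.sums[node] = self.sums[2 * node] + self.sums[2 * node + 1]

    def query(self, left, right, node=1, lo=0, hi=None):
        """Return the sum of elements in [left, right] (inclusive)."""
        if hi is None:
            hi = self.n - 1
        if right < lo or hi < left:
            return 0
        if left <= lo and hi <= right:
            return self.sums[node]
        self._push(node, lo, hi)
        mid = (lo + hi) // 2
        return (self.query(left, right, 2 * node, lo, mid)
                + self.query(left, right, 2 * node + 1, mid + 1, hi))


tree = LazySegmentTree([1, 2, 3, 4, 5])
tree.add(1, 3, 10)       # array is now [1, 12, 13, 14, 5]
print(tree.query(0, 4))  # 45
print(tree.query(2, 2))  # 13
```

The range add touches only the nodes on the boundary of [1, 3]; the interior nodes keep their pending delta until a query reaches below them.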

Python Implementation for Compression of 2D Segment Tree

Let's create a compressed version of the 2D segment tree by storing aggregated (sum) values in the nodes and then merging redundant nodes. Lazy propagation is not part of the code below.

Step 1: Define the Segment Tree Node

First, let's define a class to represent each node of the 2D segment tree.

Python
class SegmentTreeNode:
    def __init__(self, value, start_row, end_row, start_col, end_col):
        self.value = value
        self.start_row = start_row
        self.end_row = end_row
        self.start_col = start_col
        self.end_col = end_col
        self.left = None
        self.right = None

Step 2: Build the 2D Segment Tree

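Next, build the tree recursively. Each node stores the sum of the cells in its rectangle. Because the node class from Step 1 has only left and right children, this version halves the rectangle along one dimension at a time: it splits rows first and, once a node covers a single row, splits columns. (The complete code in Step 4 instead splits each node into up to four quadrants.)

Python
def build_segment_tree(matrix, start_row, end_row, start_col, end_col):
    # Base case: a single cell becomes a leaf holding that cell's value.
    if start_row == end_row and start_col == end_col:
        return SegmentTreeNode(matrix[start_row][start_col], start_row, end_row, start_col, end_col)

    node = SegmentTreeNode(0, start_row, end_row, start_col, end_col)
    if start_row != end_row:
        # Several rows: split into a top half and a bottom half.
        mid_row = (start_row + end_row) // 2
        node.left = build_segment_tree(matrix, start_row, mid_row, start_col, end_col)
        node.right = build_segment_tree(matrix, mid_row + 1, end_row, start_col, end_col)
    else:
        # A single row: split into a left half and a right half.
        mid_col = (start_col + end_col) // 2
        node.left = build_segment_tree(matrix, start_row, end_row, start_col, mid_col)
        node.right = build_segment_tree(matrix, start_row, end_row, mid_col + 1, end_col)

    # An internal node stores the sum of its two halves.
    node.value = node.left.value + node.right.value
    return node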

Step 3: Compress the Tree

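After building the tree, compress it by collapsing every subtree whose rectangle holds a single repeated value. The collapsed node keeps its sum, which is still enough to answer a query on any part of that region (each cell's value is the sum divided by the area), so its children can be dropped.

Python
def node_area(node):
    # Number of cells covered by the node's rectangle.
    return (node.end_row - node.start_row + 1) * (node.end_col - node.start_col + 1)

def compress_segment_tree(node):
    # Leaves (single cells or already-collapsed regions) need no work.
    if node.left is None and node.right is None:
        return

    # Compress bottom-up so uniform halves are collapsed first.
    compress_segment_tree(node.left)
    compress_segment_tree(node.right)

    left, right = node.left, node.right
    # A childless node always covers one repeated value. If both halves are
    # childless with the same per-cell value (cross-multiplied to stay exact),
    # the rectangle is uniform: drop the children and keep node.value (the sum).
    if (left.left is None and left.right is None
            and right.left is None and right.right is None
            and left.value * node_area(right) == right.value * node_area(left)):
        node.left = None
        node.right = None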

Step 4: Complete Code

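Finally, the complete program below generalizes the node to hold a list of children (up to four, one per quadrant), builds the tree for a 3×3 matrix, compresses it, and prints the resulting structure.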

Python
class SegmentTreeNode:
    def __init__(self, value, start_row, end_row, start_col, end_col):
        """
        Initialize a Segment Tree Node.

        Args:
            value (int): The value of the segment represented by this node.
            start_row (int): The starting row index of the segment.
            end_row (int): The ending row index of the segment.
            start_col (int): The starting column index of the segment.
            end_col (int): The ending column index of the segment.
        """
        self.value = value
        self.start_row = start_row
        self.end_row = end_row
        self.start_col = start_col
        self.end_col = end_col
        self.children = []

def build_segment_tree(matrix, start_row, end_row, start_col, end_col):
    """
    Recursively build a Segment Tree for a 2D matrix.

    Args:
        matrix (list): The 2D matrix to build the segment tree for.
        start_row (int): The starting row index of the current segment.
        end_row (int): The ending row index of the current segment.
        start_col (int): The starting column index of the current segment.
        end_col (int): The ending column index of the current segment.

    Returns:
        SegmentTreeNode: The root node of the constructed segment tree.
    """
    if start_row == end_row and start_col == end_col:
        return SegmentTreeNode(matrix[start_row][start_col], start_row, end_row, start_col, end_col)

    node = SegmentTreeNode(0, start_row, end_row, start_col, end_col)
    mid_row = (start_row + end_row) // 2
    mid_col = (start_col + end_col) // 2

    if start_row != end_row:
        if start_col != end_col:
            node.children.append(build_segment_tree(matrix, start_row, mid_row, start_col, mid_col))
            node.children.append(build_segment_tree(matrix, start_row, mid_row, mid_col + 1, end_col))
            node.children.append(build_segment_tree(matrix, mid_row + 1, end_row, start_col, mid_col))
            node.children.append(build_segment_tree(matrix, mid_row + 1, end_row, mid_col + 1, end_col))
        else:
            node.children.append(build_segment_tree(matrix, start_row, mid_row, start_col, end_col))
            node.children.append(build_segment_tree(matrix, mid_row + 1, end_row, start_col, end_col))
    else:
        node.children.append(build_segment_tree(matrix, start_row, end_row, start_col, mid_col))
        node.children.append(build_segment_tree(matrix, start_row, end_row, mid_col + 1, end_col))

    node.value = sum(child.value for child in node.children)
    return node

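def compress_segment_tree(node):
    """
    Recursively compress a Segment Tree by collapsing uniform regions.

    A childless node always covers a region that holds a single repeated
    value (a single cell, or a region collapsed earlier). When every child
    is such a region with the same per-cell value, the node's whole region
    is uniform, so its children are dropped. node.value already holds the
    region's sum and is left unchanged.

    Args:
        node (SegmentTreeNode): The root node of the segment tree to compress.
    """
    if not node.children:
        return

    for child in node.children:
        compress_segment_tree(child)

    def area(n):
        # Number of cells covered by a node's rectangle.
        return (n.end_row - n.start_row + 1) * (n.end_col - n.start_col + 1)

    first = node.children[0]
    # Compare per-cell values by cross-multiplying, which stays exact for integers.
    if all(not child.children and child.value * area(first) == first.value * area(child)
           for child in node.children):
        node.children = []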

# Example usage
matrix = [
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
]

root = build_segment_tree(matrix, 0, len(matrix) - 1, 0, len(matrix[0]) - 1)
compress_segment_tree(root)

def print_tree(node, level=0):
    """
    Print the structure of the segment tree.

    Args:
        node (SegmentTreeNode): The root node of the segment tree.
        level (int): The level of the current node in the tree.
    """
    print(' ' * level * 4, node.value, (node.start_row, node.end_row), (node.start_col, node.end_col))
    for child in node.children:
        print_tree(child, level + 1)

print_tree(root)

Output
 45 (0, 2) (0, 2)
     12 (0, 1) (0, 1)
         1 (0, 0) (0, 0)
         2 (0, 0) (1, 1)
         4 (1, 1) (0, 0)
         5 (1, 1) (1, 1)
     9 (0, 1) (2, 2)
         3 (0, 0) (2, 2)
         6 (1,...
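The program above builds and compresses the tree but never queries it. The following self-contained sketch, assuming integer cell values, reuses the quadrant layout of Step 4 and adds a hypothetical query_sum function together with area and count_nodes helpers (none of these three names come from the article). Compression collapses a node only when its whole rectangle holds one repeated value, so when a query partially overlaps a collapsed node, the answer is value × overlap ÷ area. Note that this quadrant layout can visit O(n) nodes per query on an n × n matrix in the worst case, for example a one-row strip.

```python
class SegmentTreeNode:
    def __init__(self, value, start_row, end_row, start_col, end_col):
        self.value = value  # sum of all cells in the node's rectangle
        self.start_row, self.end_row = start_row, end_row
        self.start_col, self.end_col = start_col, end_col
        self.children = []


def area(node):
    """Number of cells covered by a node's rectangle."""
    return (node.end_row - node.start_row + 1) * (node.end_col - node.start_col + 1)


def build_segment_tree(matrix, r1, r2, c1, c2):
    """Quadrant build in the style of Step 4: split every dimension longer than one."""
    if r1 == r2 and c1 == c2:
        return SegmentTreeNode(matrix[r1][c1], r1, r2, c1, c2)
    node = SegmentTreeNode(0, r1, r2, c1, c2)
    mr, mc = (r1 + r2) // 2, (c1 + c2) // 2
    rows = [(r1, mr), (mr + 1, r2)] if r1 != r2 else [(r1, r2)]
    cols = [(c1, mc), (mc + 1, c2)] if c1 != c2 else [(c1, c2)]
    for top, bottom in rows:
        for left, right in cols:
            node.children.append(build_segment_tree(matrix, top, bottom, left, right))
    node.value = sum(child.value for child in node.children)
    return node


def compress_segment_tree(node):
    """Collapse a node whose children are uniform regions with one shared per-cell value."""
    if not node.children:
        return
    for child in node.children:
        compress_segment_tree(child)
    first = node.children[0]
    if all(not child.children and child.value * area(first) == first.value * area(child)
           for child in node.children):
        node.children = []  # node.value already holds the region's sum


def count_nodes(node):
    """Total number of nodes in the (possibly compressed) tree."""
    return 1 + sum(count_nodes(child) for child in node.children)


def query_sum(node, r1, r2, c1, c2):
    """Sum of cells in rows r1..r2 and columns c1..c2 (inclusive)."""
    # Intersect the query rectangle with the node's rectangle.
    top, bottom = max(r1, node.start_row), min(r2, node.end_row)
    left, right = max(c1, node.start_col), min(c2, node.end_col)
    if top > bottom or left > right:
        return 0  # no overlap
    overlap = (bottom - top + 1) * (right - left + 1)
    if overlap == area(node):
        return node.value  # node lies entirely inside the query
    if not node.children:
        # Compressed uniform region: each cell holds node.value / area(node).
        # Exact for integer cells, since node.value is a multiple of the area.
        return node.value * overlap // area(node)
    return sum(query_sum(child, r1, r2, c1, c2) for child in node.children)


matrix = [
    [3, 3, 1, 2],
    [3, 3, 4, 5],
    [0, 0, 7, 7],
    [0, 0, 7, 7],
]
root = build_segment_tree(matrix, 0, 3, 0, 3)
print(count_nodes(root))            # 21 before compression
compress_segment_tree(root)
print(count_nodes(root))            # 9 after compression
print(query_sum(root, 0, 3, 0, 3))  # 52
print(query_sum(root, 1, 2, 0, 1))  # 3 + 3 + 0 + 0 = 6
```

Three of the four quadrants in this matrix are uniform, so each collapses to a single node, while the mixed top-right quadrant keeps its four leaves; queries still return the same sums as direct iteration over the matrix.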


