Segment Tree

Faster Range Queries

In this article, we introduce the Segment Tree, a data structure that supports fast range queries together with point updates. Suppose we have an array, and we want to know the sum of a particular range of numbers while also updating individual elements when necessary. With a plain array, an update takes O(1) time, but a sum query can take up to O(n), since it may loop through the entire array. A segment tree performs both operations in O(log(n)) time.
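For contrast, here is a minimal sketch of the plain-array baseline described above. The class name `NaiveRangeSum` is hypothetical and not part of the original implementation. It shows where the O(1) update and the O(n) query come from.

```python
class NaiveRangeSum:
    """Hypothetical plain-array baseline: O(1) point update, O(n) range-sum query."""

    def __init__(self, arr):
        self.arr = list(arr)  # copy so the caller's list is left untouched

    def update(self, idx, val):
        # Overwriting a single element takes constant time.
        self.arr[idx] = val

    def query(self, left, right):
        # Summing the inclusive range [left, right] may touch every element.
        return sum(self.arr[left:right + 1])


naive = NaiveRangeSum([5, 8, 6, 3, 2, 7, 2, 6])
print(naive.query(1, 4))  # 8 + 6 + 3 + 2 = 19
naive.update(2, 10)
print(naive.query(1, 4))  # 8 + 10 + 3 + 2 = 23
```

A segment tree gives up the constant-time update in exchange for making both operations logarithmic.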

Array to Tree

Segment trees work by breaking the array down into a binary tree in which each node represents a contiguous segment of the array. The root covers the whole array; each node's segment is cut in half, and the two halves are assigned to its left and right children, until every leaf covers a single element. Here is a graphic to give you an idea of how this tree looks. Note that every node has [i,j] displayed, which shows the interval covered by that particular node in the tree.
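The splitting rule can be sketched in a few lines. The helper `segments` below is hypothetical: it walks the same midpoint splits the implementation uses and records each node's [i, j] interval, numbering nodes with the 1-indexed 2 * n / 2 * n + 1 scheme the implementation relies on.

```python
def segments(left, right, node=1, out=None):
    """Hypothetical helper: collect (node_index, left, right) for every node of a
    segment tree over indices left..right, splitting each segment at its midpoint."""
    if out is None:
        out = []
    out.append((node, left, right))
    if left < right:
        mid = (left + right) // 2
        segments(left, mid, node * 2, out)           # left half goes to the left child
        segments(mid + 1, right, node * 2 + 1, out)  # right half goes to the right child
    return out


for node, i, j in segments(0, 5):
    print(node, [i, j])
```

For a 6-element array, the root (node 1) covers [0, 5], the tree has 2 * 6 - 1 = 11 nodes, and every leaf covers exactly one index.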

Query and Update

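Suppose we want to compute the sum of an array over an interval. After building the tree, we answer the query by moving down from the root. A node whose segment lies completely within the query interval contributes its stored sum, and we stop descending there; a node whose segment does not overlap the interval at all contributes nothing; and a node that only partially overlaps passes the query on to both of its children. Adding up the stored sums of the fully covered segments gives the sum of the interval.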
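The descent above can be sketched as a function that returns the fully covered segments instead of summing them. `covering_segments` is a hypothetical name. To keep the sketch independent of a stored tree, it sums each piece straight from the array; an actual segment tree reads the precomputed sum stored at that node instead.

```python
def covering_segments(left, right, query_left, query_right, node=1):
    """Hypothetical helper: return (node, left, right) for the nodes whose segments
    lie fully inside [query_left, query_right]; together they tile the query range."""
    if right < query_left or left > query_right:
        return []  # disjoint: contributes nothing
    if query_left <= left and right <= query_right:
        return [(node, left, right)]  # fully inside: stop descending
    mid = (left + right) // 2  # partial overlap: ask both children
    return (covering_segments(left, mid, query_left, query_right, node * 2)
            + covering_segments(mid + 1, right, query_left, query_right, node * 2 + 1))


arr = [5, 8, 6, 3, 2, 7, 2, 6]
pieces = covering_segments(0, len(arr) - 1, 1, 6)
print(pieces)  # [(9, 1, 1), (5, 2, 3), (6, 4, 5), (14, 6, 6)]
print(sum(sum(arr[i:j + 1]) for _, i, j in pieces))  # 28, same as sum(arr[1:7])
```

Only four segments are needed to cover the six-element range [1, 6], which is why the query stays logarithmic.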

Updating the tree works in a similar fashion. Suppose we want to change the value at a particular index of the array. We recursively work our way down to the leaf node whose segment contains only that index and update it. Then, as the recursive calls return, we recompute the sum stored in every ancestor of that leaf, since each of their segments contains the changed index. Only one node per level changes, so an update touches O(log(n)) nodes.
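A minimal sketch of which nodes an update touches, using a hypothetical helper `update_path`. It follows the same left/right decision as the recursive update and lists the node indices from the root down to the leaf; the recursive version recomputes these same nodes in reverse order as its calls return.

```python
def update_path(left, right, idx, node=1):
    """Hypothetical helper: node indices visited when updating position idx,
    root first and leaf last."""
    path = [node]
    while left < right:
        mid = (left + right) // 2
        if idx <= mid:
            node, right = node * 2, mid         # idx is in the left half
        else:
            node, left = node * 2 + 1, mid + 1  # idx is in the right half
        path.append(node)
    return path


print(update_path(0, 7, 5))  # [1, 3, 6, 13]
```

For an 8-element array, the path has log2(8) + 1 = 4 nodes: the leaf plus its three ancestors.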

Now that we have a visual understanding of what our data structure looks like, let's put it into code. One implementation detail that simplifies things is that we can store the tree in a linear array instead of linked nodes: if n is the index of the current node, its left child is at 2 * n and its right child is at 2 * n + 1. For this example, we assume we want to calculate the sum of the array over an interval. Note that the input array can be either 0-indexed or 1-indexed; that is mostly personal preference. The tree array itself, however, must be 1-indexed (root at index 1) for these child formulas to work, because with the root at index 0 its "left child" 2 * 0 would be the root itself. An array of 4 * n slots is always large enough to hold the tree for n elements.
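The index arithmetic can be checked directly. The helpers below (`left_child`, `right_child`, `parent`, `max_node_index`) are hypothetical names for illustration. The last one finds the largest array slot a tree over [left, right] uses, which lets us confirm that the 4 * n allocation in the implementation is enough.

```python
def left_child(n):
    return 2 * n


def right_child(n):
    return 2 * n + 1


def parent(n):
    return n // 2  # both 2n and 2n + 1 map back to n


def max_node_index(left, right, node=1):
    """Hypothetical helper: largest slot used by a tree over [left, right] rooted at 1."""
    if left == right:
        return node
    mid = (left + right) // 2
    return max(max_node_index(left, mid, left_child(node)),
               max_node_index(mid + 1, right, right_child(node)))


# The implementation allocates 4 * len(arr) slots; check that this always suffices.
for size in range(1, 200):
    assert max_node_index(0, size - 1) < 4 * size
```

Slot 0 is never used with this layout. The bound holds because the tree's depth is ceil(log2(n)), so no index reaches 2 ** (ceil(log2(n)) + 1), which is less than 4 * n.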

For the Interview

Segment trees are only useful for problems involving range queries, and the implementation is rather tricky to get right. However, knowing that this data structure exists and how it works would likely impress the interviewer. It's good to know, but focus on the more core patterns if you are short on time. Consider this extra credit.

Implementation

class segment_tree:
    def __init__(self, arr):
        self.tree = [0] * (4 * len(arr))
        # cur_left, cur_right and idx may be 0- or 1-indexed as long as they agree with each other;
        # here they are 0-indexed because the input array is.
        # "cur" always starts at 1 because the root lives at index 1 of the 1-indexed tree array.
        for i in range(len(arr)):
            self.update(1, 0, len(arr) - 1, i, arr[i])

    def update(self, cur, cur_left, cur_right, idx, val):
        # a leaf is reached when the left and right bounds are equal: store the new value there
        if cur_left == cur_right and cur_left == idx:
            self.tree[cur] = val
        else:
            # midpoint where the segment is cut in half
            cur_mid = (cur_left + cur_right) // 2
            # cur * 2 is the left child and cur * 2 + 1 is the right child
            if idx <= cur_mid:
                self.update(cur * 2, cur_left, cur_mid, idx, val)
            else:
                self.update(cur * 2 + 1, cur_mid + 1, cur_right, idx, val)
            # after the child is updated, recompute this node's sum from its children
            self.tree[cur] = self.tree[cur * 2] + self.tree[cur * 2 + 1]

    def query(self, cur, cur_left, cur_right, query_left, query_right):
        # if the current segment starts after the query ends, or ends before the query starts,
        # it lies entirely outside the query range and contributes 0
        if cur_left > query_right or cur_right < query_left:
            return 0
        # the current segment lies entirely inside the query range: return its stored sum
        elif query_left <= cur_left and cur_right <= query_right:
            return self.tree[cur]
        # the segment partially overlaps the query range, so we must query both children
        cur_mid = (cur_left + cur_right) // 2
        return self.query(cur * 2, cur_left, cur_mid, query_left, query_right) + self.query(cur * 2 + 1, cur_mid + 1, cur_right, query_left, query_right)

public static class SegmentTree {
    int[] tree;

    public SegmentTree(int[] arr) {
        tree = new int[4 * arr.length];
        // curLeft, curRight and idx may be 0- or 1-indexed as long as they agree with each other;
        // here they are 0-indexed because the input array is.
        // "cur" always starts at 1 because the root lives at index 1 of the 1-indexed tree array.
        for (int i = 0; i < arr.length; i++) {
            update(1, 0, arr.length - 1, i, arr[i]);
        }
    }

    void update(int cur, int curLeft, int curRight, int idx, int val) {
        // a leaf is reached when the left and right bounds are equal: store the new value there
        if (curLeft == curRight && curLeft == idx) {
            tree[cur] = val;
        }
        else {
            // midpoint where the segment is cut in half
            int curMid = (curLeft + curRight) / 2;
            // cur * 2 is the left child and cur * 2 + 1 is the right child
            if (idx <= curMid) update(cur * 2, curLeft, curMid, idx, val);
            else update(cur * 2 + 1, curMid + 1, curRight, idx, val);
            // after the child is updated, recompute this node's sum from its children
            tree[cur] = tree[cur * 2] + tree[cur * 2 + 1];
        }
    }

    int query(int cur, int curLeft, int curRight, int queryLeft, int queryRight) {
        // if the current segment starts after the query ends, or ends before the query starts,
        // it lies entirely outside the query range and contributes 0
        if (curLeft > queryRight || curRight < queryLeft) {
            return 0;
        }
        // the current segment lies entirely inside the query range: return its stored sum
        else if (queryLeft <= curLeft && curRight <= queryRight) {
            return tree[cur];
        }
        // the segment partially overlaps the query range, so we must query both children
        int curMid = (curLeft + curRight) / 2;
        return query(cur * 2, curLeft, curMid, queryLeft, queryRight) + query(cur * 2 + 1, curMid + 1, curRight, queryLeft, queryRight);
    }
}

class SegmentTree {
    constructor(arr) {
        this.tree = Array(4 * arr.length).fill(0);
        // curLeft, curRight and idx may be 0- or 1-indexed as long as they agree with each other;
        // here they are 0-indexed because the input array is.
        // "cur" always starts at 1 because the root lives at index 1 of the 1-indexed tree array.
        for (let i = 0; i < arr.length; i++) {
            this.update(1, 0, arr.length - 1, i, arr[i]);
        }
    }

    update(cur, curLeft, curRight, idx, val) {
        // a leaf is reached when the left and right bounds are equal: store the new value there
        if (curLeft === curRight && curLeft === idx) {
            this.tree[cur] = val;
        } else {
            // midpoint where the segment is cut in half
            const curMid = Math.floor((curLeft + curRight) / 2);
            // cur * 2 is the left child and cur * 2 + 1 is the right child
            if (idx <= curMid) {
                this.update(cur * 2, curLeft, curMid, idx, val);
            } else {
                this.update(cur * 2 + 1, curMid + 1, curRight, idx, val);
            }
            // after the child is updated, recompute this node's sum from its children
            this.tree[cur] = this.tree[cur * 2] + this.tree[cur * 2 + 1];
        }
    }

    query(cur, curLeft, curRight, queryLeft, queryRight) {
        // if the current segment starts after the query ends, or ends before the query starts,
        // it lies entirely outside the query range and contributes 0
        if (curLeft > queryRight || curRight < queryLeft) return 0;
        // the current segment lies entirely inside the query range: return its stored sum
        else if (queryLeft <= curLeft && curRight <= queryRight) {
            return this.tree[cur];
        }
        // the segment partially overlaps the query range, so we must query both children
        const curMid = Math.floor((curLeft + curRight) / 2);
        return this.query(cur * 2, curLeft, curMid, queryLeft, queryRight) + this.query(cur * 2 + 1, curMid + 1, curRight, queryLeft, queryRight);
    }
}

#include <vector>

struct SegmentTree {
    // std::vector owns its memory, so the tree is freed automatically (the raw new[] leaked)
    std::vector<int> tree;

    SegmentTree(const std::vector<int>& arr) : tree(4 * arr.size(), 0) {
        // cur_left, cur_right and idx may be 0- or 1-indexed as long as they agree with each other;
        // here they are 0-indexed because the input array is.
        // "cur" always starts at 1 because the root lives at index 1 of the 1-indexed tree array.
        int n = static_cast<int>(arr.size());
        for (int i = 0; i < n; i++) {
            update(1, 0, n - 1, i, arr[i]);
        }
    }

    void update(int cur, int cur_left, int cur_right, int idx, int val) {
        // a leaf is reached when the left and right bounds are equal: store the new value there
        if (cur_left == cur_right && cur_left == idx) {
            tree[cur] = val;
        } else {
            // midpoint where the segment is cut in half
            int cur_mid = (cur_left + cur_right) / 2;
            // cur * 2 is the left child and cur * 2 + 1 is the right child
            if (idx <= cur_mid) {
                update(cur * 2, cur_left, cur_mid, idx, val);
            } else {
                update(cur * 2 + 1, cur_mid + 1, cur_right, idx, val);
            }
            // after the child is updated, recompute this node's sum from its children
            tree[cur] = tree[cur * 2] + tree[cur * 2 + 1];
        }
    }

    int query(int cur, int cur_left, int cur_right, int query_left, int query_right) {
        // if the current segment starts after the query ends, or ends before the query starts,
        // it lies entirely outside the query range and contributes 0
        if (cur_left > query_right || cur_right < query_left) return 0;
        // the current segment lies entirely inside the query range: return its stored sum
        else if (query_left <= cur_left && cur_right <= query_right) {
            return tree[cur];
        }
        // the segment partially overlaps the query range, so we must query both children
        int cur_mid = (cur_left + cur_right) / 2;
        return query(cur * 2, cur_left, cur_mid, query_left, query_right) + query(cur * 2 + 1, cur_mid + 1, cur_right, query_left, query_right);
    }
};
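The constructors above build the tree with one update per element, which costs O(n log(n)) in total. A common alternative, sketched here as an assumption rather than as part of the original implementation, fills the same 1-indexed layout in a single recursive pass, visiting each node once for O(n) construction. The function name `build` is hypothetical.

```python
def build(arr):
    """Hypothetical helper: return a 1-indexed sum segment tree for arr, filled in O(n)."""
    tree = [0] * (4 * len(arr))

    def fill(cur, cur_left, cur_right):
        if cur_left == cur_right:
            tree[cur] = arr[cur_left]  # a leaf holds a single element
            return
        cur_mid = (cur_left + cur_right) // 2
        fill(cur * 2, cur_left, cur_mid)
        fill(cur * 2 + 1, cur_mid + 1, cur_right)
        tree[cur] = tree[cur * 2] + tree[cur * 2 + 1]  # parent = sum of its children

    if arr:
        fill(1, 0, len(arr) - 1)
    return tree


tree = build([5, 8, 6, 3, 2, 7, 2, 6])
print(tree[1])  # the root holds the total: 39
```

The resulting list has the same layout as `self.tree` in the Python class above, so it could replace the update loop in the constructor.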
