
Segment Tree

Why Segment Trees Exist

A segment tree is built for one common situation: we need to answer many range queries on an array, and the array values can change between queries. For example, we may repeatedly ask for the sum of arr[left..right] (both ends inclusive), then update one index, then ask for another range sum.

A plain array gives fast point updates in O(1), but each range-sum query can cost O(n) because we may scan most of the array. Prefix sums give O(1) range-sum queries, but after one update, many prefix values become invalid and fixing them can cost O(n). A segment tree balances both needs: range query in O(log n) and point update in O(log n). (Range updates are also possible, usually with lazy propagation, which is a follow-up topic.)
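To make the trade-off concrete, here is a minimal Python sketch of the two baselines. The class names PlainArraySums and PrefixSums are hypothetical, chosen only for illustration, and both treat a range [left, right] as inclusive on both ends. The loop inside PrefixSums.update is the O(n) repair cost that a segment tree is designed to avoid.

```python
class PlainArraySums:
    """Hypothetical baseline: point update O(1), range-sum query O(n)."""

    def __init__(self, arr):
        self.arr = list(arr)

    def update(self, idx, val):
        self.arr[idx] = val  # a single write

    def query(self, left, right):
        # Scans every element in [left, right], inclusive.
        return sum(self.arr[left:right + 1])


class PrefixSums:
    """Hypothetical baseline: range-sum query O(1), point update O(n)."""

    def __init__(self, arr):
        self.arr = list(arr)
        # prefix[k] holds arr[0] + ... + arr[k - 1], so prefix[0] == 0.
        self.prefix = [0] * (len(arr) + 1)
        for k, x in enumerate(arr):
            self.prefix[k + 1] = self.prefix[k] + x

    def update(self, idx, val):
        delta = val - self.arr[idx]
        self.arr[idx] = val
        # Every prefix that includes index idx is now stale and must shift by delta.
        for k in range(idx + 1, len(self.prefix)):
            self.prefix[k] += delta

    def query(self, left, right):
        return self.prefix[right + 1] - self.prefix[left]


if __name__ == "__main__":
    data = [5, 8, 6, 3, 2, 7, 2, 6]
    a, p = PlainArraySums(data), PrefixSums(data)
    print(a.query(2, 5), p.query(2, 5))  # 18 18  (6 + 3 + 2 + 7)
    a.update(3, 10)
    p.update(3, 10)
    print(a.query(2, 5), p.query(2, 5))  # 25 25  (6 + 10 + 2 + 7)
```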

How the Array Becomes a Tree

The root node represents the full interval of the array. Each internal node splits its interval into two halves, and its children represent those halves. Leaves represent single indices. If this is a sum segment tree, each node stores the sum of its interval. That means each parent can be rebuilt from its two children.

In the diagram, every node shows [i, j], the index interval covered by that node.
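The parent-from-children relationship can be sketched as a recursive build. This is an illustration under stated assumptions, not the article's constructor: the hypothetical build function builds both children first and then sums them, which takes O(n) total, whereas the implementation below inserts elements one at a time with point updates (O(n log n)). It uses the same 1-indexed layout (children of cur at 2 * cur and 2 * cur + 1) and also records each node's [i, j] interval, like the labels in the diagram.

```python
def build(arr, tree, intervals, cur, cur_left, cur_right):
    """Store sum(arr[cur_left..cur_right]) at tree[cur] and record the node's interval."""
    intervals[cur] = (cur_left, cur_right)
    if cur_left == cur_right:
        # Leaf: the interval is a single index.
        tree[cur] = arr[cur_left]
        return
    cur_mid = (cur_left + cur_right) // 2
    build(arr, tree, intervals, 2 * cur, cur_left, cur_mid)
    build(arr, tree, intervals, 2 * cur + 1, cur_mid + 1, cur_right)
    # A parent is rebuilt from its two children.
    tree[cur] = tree[2 * cur] + tree[2 * cur + 1]


if __name__ == "__main__":
    arr = [5, 8, 6, 3, 2, 7, 2, 6]
    tree = [0] * (4 * len(arr))
    intervals = {}
    build(arr, tree, intervals, 1, 0, len(arr) - 1)
    for cur in sorted(intervals):
        i, j = intervals[cur]
        print(f"node {cur}: [{i}, {j}] sum={tree[cur]}")
```

For eight elements this visits 15 nodes: one root, internal nodes that halve their interval, and eight single-index leaves.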

Query and Update

To query a range sum, we recursively walk the tree and compare each node's interval with the query interval. When a node is completely outside the query, it contributes 0. When a node is completely inside the query, we use its stored sum directly. When it partially overlaps, we recurse into both children and add their results. On any level of the tree, at most two nodes can partially overlap the query (the ones containing its left and right boundaries), so at most four nodes are visited per level. The tree has O(log n) levels, so a query takes O(log n) time.
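The per-level bound can be checked empirically. The sketch below uses hypothetical helper names build and query, plus a visits counter added purely for instrumentation. It runs every possible query on a 100-element array, checks each answer against Python's built-in sum, and records the largest number of nodes visited on any single level. The usual argument predicts at most four: only the nodes containing the query's two endpoints can partially overlap it, and each of those sends the recursion into two children.

```python
from collections import Counter


def build(arr, tree, cur, lo, hi):
    """Store sum(arr[lo..hi]) at tree[cur], using the 1-indexed layout."""
    if lo == hi:
        tree[cur] = arr[lo]
        return
    mid = (lo + hi) // 2
    build(arr, tree, 2 * cur, lo, mid)
    build(arr, tree, 2 * cur + 1, mid + 1, hi)
    tree[cur] = tree[2 * cur] + tree[2 * cur + 1]


def query(tree, cur, lo, hi, ql, qr, depth, visits):
    """Sum over [ql, qr]; visits[depth] counts the nodes touched on each level."""
    visits[depth] += 1
    if lo > qr or hi < ql:          # no overlap: contributes 0
        return 0
    if ql <= lo and hi <= qr:       # full overlap: use the stored sum
        return tree[cur]
    mid = (lo + hi) // 2            # partial overlap: ask both children
    return (query(tree, 2 * cur, lo, mid, ql, qr, depth + 1, visits)
            + query(tree, 2 * cur + 1, mid + 1, hi, ql, qr, depth + 1, visits))


if __name__ == "__main__":
    arr = list(range(1, 101))
    n = len(arr)
    tree = [0] * (4 * n)
    build(arr, tree, 1, 0, n - 1)
    worst = 0
    for ql in range(n):
        for qr in range(ql, n):
            visits = Counter()
            assert query(tree, 1, 0, n - 1, ql, qr, 0, visits) == sum(arr[ql:qr + 1])
            worst = max(worst, max(visits.values()))
    print("most nodes visited on one level:", worst)
```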

Point update follows the same top-down path idea. We recurse to the leaf for index idx and assign the new value there. Then we return upward and recompute every parent on that path from its two children. Only one root-to-leaf path changes, so update also takes O(log n).
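A small sketch can make the "one root-to-leaf path" claim visible. The function below is a hypothetical standalone version of a point update that also appends every tree index it touches to a path list. It assumes 0 <= idx < n, as the article's code does, and uses the same 1-indexed layout.

```python
def update(tree, cur, lo, hi, idx, val, path):
    """Set position idx to val; append every tree index on the root-to-leaf path to `path`."""
    path.append(cur)
    if lo == hi:                    # leaf: assumes 0 <= idx < n, so this leaf is idx
        tree[cur] = val
        return
    mid = (lo + hi) // 2
    if idx <= mid:
        update(tree, 2 * cur, lo, mid, idx, val, path)
    else:
        update(tree, 2 * cur + 1, mid + 1, hi, idx, val, path)
    # On the way back up: rebuild this parent from its two children.
    tree[cur] = tree[2 * cur] + tree[2 * cur + 1]


if __name__ == "__main__":
    arr = [5, 8, 6, 3, 2, 7, 2, 6]
    n = len(arr)
    tree = [0] * (4 * n)
    for i, x in enumerate(arr):
        update(tree, 1, 0, n - 1, i, x, [])
    path = []
    update(tree, 1, 0, n - 1, 5, 10, path)
    print(path)       # [1, 3, 6, 13]: one node per level, each the parent of the next
    print(tree[1])    # 42, the old total 39 with arr[5] changed from 7 to 10
```

Only the four nodes on that path change; every other stored sum is untouched.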

We usually store the segment tree in a flat array instead of explicit tree nodes. In this layout, if a node is stored at index cur, its left child is at 2 * cur, its right child at 2 * cur + 1, and its parent at cur // 2. This article's code uses a 1-indexed tree array, with the root at index 1, so that the child math stays simple. The input array itself is still 0-indexed. The tree array is allocated with 4 * n slots, which is always enough for an input of length n. The implementation below supports range-sum queries and point updates.
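The index arithmetic can be spelled out with a few tiny helpers. The names left, right, parent and max_index_used are hypothetical. The sketch follows the same recursive midpoint splits as the implementation below and reports the largest tree-array index any node lands on, which is why the code allocates 4 * n slots rather than 2 * n.

```python
def left(cur):
    return 2 * cur       # left child in the 1-indexed layout


def right(cur):
    return 2 * cur + 1   # right child


def parent(cur):
    return cur // 2      # both children map back to the same parent


def max_index_used(cur, lo, hi):
    """Largest tree-array index reached when node `cur` covers [lo, hi]."""
    if lo == hi:
        return cur
    mid = (lo + hi) // 2
    return max(max_index_used(left(cur), lo, mid),
               max_index_used(right(cur), mid + 1, hi))


if __name__ == "__main__":
    assert parent(left(7)) == 7 and parent(right(7)) == 7
    print(max_index_used(1, 0, 5))   # 13, already past a 2 * n = 12 slot array
    for n in range(1, 513):
        assert max_index_used(1, 0, n - 1) < 4 * n
    print("4 * n slots suffice for every n from 1 to 512")
```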

For the Interview

Segment trees are most useful when a problem mixes frequent range queries with frequent updates. The implementation has many interval boundary cases, so it is easy to make off-by-one mistakes. In interviews, even explaining when to use a segment tree and how query/update work can already be valuable. If you are short on time, treat this as an advanced topic after core patterns.

Implementation

class SegmentTree:
    def __init__(self, arr):
        # 4 * n slots are always enough for a recursive segment tree over n elements.
        self.tree = [0] * (4 * len(arr))
        # The tree array is 1-indexed: the root is always cur = 1.
        # The interval bounds cur_left, cur_right and the position idx are 0-indexed,
        # matching the input array, so the root covers [0, len(arr) - 1].
        # Inserting each element with a point update builds the tree in O(n log n).
        for i in range(len(arr)):
            self.update(1, 0, len(arr) - 1, i, arr[i])

    def update(self, cur, cur_left, cur_right, idx, val):
        # Base case: a leaf covering exactly idx; store the new value there.
        if cur_left == cur_right and cur_left == idx:
            self.tree[cur] = val
        else:
            # Split the current interval at its midpoint.
            cur_mid = (cur_left + cur_right) // 2
            # Left child (2 * cur) covers [cur_left, cur_mid]; right child (2 * cur + 1) covers [cur_mid + 1, cur_right].
            if idx <= cur_mid:
                self.update(cur * 2, cur_left, cur_mid, idx, val)
            else:
                self.update(cur * 2 + 1, cur_mid + 1, cur_right, idx, val)
            # On the way back up, recompute this node's sum from its two children.
            self.tree[cur] = self.tree[cur * 2] + self.tree[cur * 2 + 1]

    def query(self, cur, cur_left, cur_right, query_left, query_right):
        # No overlap: the node's interval lies entirely left or right of the query, so it contributes 0.
        if cur_left > query_right or cur_right < query_left:
            return 0
        # Full overlap: the node's interval lies inside the query, so return its stored sum.
        elif query_left <= cur_left and cur_right <= query_right:
            return self.tree[cur]
        # Partial overlap: part of the interval is inside the query and part is not, so query both children.
        cur_mid = (cur_left + cur_right) // 2
        return (self.query(cur * 2, cur_left, cur_mid, query_left, query_right)
                + self.query(cur * 2 + 1, cur_mid + 1, cur_right, query_left, query_right))

// Declared top-level here; when nesting it inside another class, declare it as "static class SegmentTree".
public class SegmentTree {
    int[] tree;

    public SegmentTree(int[] arr) {
        // 4 * n slots are always enough for a recursive segment tree over n elements.
        tree = new int[4 * arr.length];
        // The tree array is 1-indexed: the root is always cur = 1.
        // The interval bounds curLeft, curRight and the position idx are 0-indexed,
        // matching the input array, so the root covers [0, arr.length - 1].
        // Inserting each element with a point update builds the tree in O(n log n).
        for (int i = 0; i < arr.length; i++) {
            update(1, 0, arr.length - 1, i, arr[i]);
        }
    }

    void update(int cur, int curLeft, int curRight, int idx, int val) {
        // Base case: a leaf covering exactly idx; store the new value there.
        if (curLeft == curRight && curLeft == idx) {
            tree[cur] = val;
        }
        else {
            // Split the current interval at its midpoint.
            int curMid = (curLeft + curRight) / 2;
            // Left child (2 * cur) covers [curLeft, curMid]; right child (2 * cur + 1) covers [curMid + 1, curRight].
            if (idx <= curMid) update(cur * 2, curLeft, curMid, idx, val);
            else update(cur * 2 + 1, curMid + 1, curRight, idx, val);
            // On the way back up, recompute this node's sum from its two children.
            tree[cur] = tree[cur * 2] + tree[cur * 2 + 1];
        }
    }

    int query(int cur, int curLeft, int curRight, int queryLeft, int queryRight) {
        // No overlap: the node's interval lies entirely left or right of the query, so it contributes 0.
        if (curLeft > queryRight || curRight < queryLeft) {
            return 0;
        }
        // Full overlap: the node's interval lies inside the query, so return its stored sum.
        else if (queryLeft <= curLeft && curRight <= queryRight) {
            return tree[cur];
        }
        // Partial overlap: part of the interval is inside the query and part is not, so query both children.
        int curMid = (curLeft + curRight) / 2;
        return query(cur * 2, curLeft, curMid, queryLeft, queryRight) + query(cur * 2 + 1, curMid + 1, curRight, queryLeft, queryRight);
    }
}

class SegmentTree {
    constructor(arr) {
        // 4 * n slots are always enough for a recursive segment tree over n elements.
        this.tree = Array(4 * arr.length).fill(0);
        // The tree array is 1-indexed: the root is always cur = 1.
        // The interval bounds curLeft, curRight and the position idx are 0-indexed,
        // matching the input array, so the root covers [0, arr.length - 1].
        // Inserting each element with a point update builds the tree in O(n log n).
        for (let i = 0; i < arr.length; i++) {
            this.update(1, 0, arr.length - 1, i, arr[i]);
        }
    }

    update(cur, curLeft, curRight, idx, val) {
        // Base case: a leaf covering exactly idx; store the new value there.
        if (curLeft === curRight && curLeft === idx) {
            this.tree[cur] = val;
        } else {
            // Split the current interval at its midpoint.
            const curMid = Math.floor((curLeft + curRight) / 2);
            // Left child (2 * cur) covers [curLeft, curMid]; right child (2 * cur + 1) covers [curMid + 1, curRight].
            if (idx <= curMid) {
                this.update(cur * 2, curLeft, curMid, idx, val);
            } else {
                this.update(cur * 2 + 1, curMid + 1, curRight, idx, val);
            }
            // On the way back up, recompute this node's sum from its two children.
            this.tree[cur] = this.tree[cur * 2] + this.tree[cur * 2 + 1];
        }
    }

    query(cur, curLeft, curRight, queryLeft, queryRight) {
        // No overlap: the node's interval lies entirely left or right of the query, so it contributes 0.
        if (curLeft > queryRight || curRight < queryLeft) return 0;
        // Full overlap: the node's interval lies inside the query, so return its stored sum.
        else if (queryLeft <= curLeft && curRight <= queryRight) {
            return this.tree[cur];
        }
        // Partial overlap: part of the interval is inside the query and part is not, so query both children.
        const curMid = Math.floor((curLeft + curRight) / 2);
        return this.query(cur * 2, curLeft, curMid, queryLeft, queryRight) + this.query(cur * 2 + 1, curMid + 1, curRight, queryLeft, queryRight);
    }
}

#include <vector>

struct SegmentTree {
    // std::vector owns the storage, so no manual delete[] is needed.
    std::vector<int> tree;

    // 4 * n slots are always enough for a recursive segment tree over n elements.
    SegmentTree(const std::vector<int>& arr) : tree(4 * arr.size(), 0) {
        // The tree array is 1-indexed: the root is always cur = 1.
        // The interval bounds cur_left, cur_right and the position idx are 0-indexed,
        // matching the input array, so the root covers [0, n - 1].
        // Inserting each element with a point update builds the tree in O(n log n).
        int n = static_cast<int>(arr.size());
        for (int i = 0; i < n; i++) {
            update(1, 0, n - 1, i, arr[i]);
        }
    }

    void update(int cur, int cur_left, int cur_right, int idx, int val) {
        // Base case: a leaf covering exactly idx; store the new value there.
        if (cur_left == cur_right && cur_left == idx) {
            tree[cur] = val;
        } else {
            // Split the current interval at its midpoint.
            int cur_mid = (cur_left + cur_right) / 2;
            // Left child (2 * cur) covers [cur_left, cur_mid]; right child (2 * cur + 1) covers [cur_mid + 1, cur_right].
            if (idx <= cur_mid) {
                update(cur * 2, cur_left, cur_mid, idx, val);
            } else {
                update(cur * 2 + 1, cur_mid + 1, cur_right, idx, val);
            }
            // On the way back up, recompute this node's sum from its two children.
            tree[cur] = tree[cur * 2] + tree[cur * 2 + 1];
        }
    }

    int query(int cur, int cur_left, int cur_right, int query_left, int query_right) {
        // No overlap: the node's interval lies entirely left or right of the query, so it contributes 0.
        if (cur_left > query_right || cur_right < query_left) return 0;
        // Full overlap: the node's interval lies inside the query, so return its stored sum.
        else if (query_left <= cur_left && cur_right <= query_right) {
            return tree[cur];
        }
        // Partial overlap: part of the interval is inside the query and part is not, so query both children.
        int cur_mid = (cur_left + cur_right) / 2;
        return query(cur * 2, cur_left, cur_mid, query_left, query_right) + query(cur * 2 + 1, cur_mid + 1, cur_right, query_left, query_right);
    }
};

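Because the boundary cases are easy to get wrong, a randomized cross-check against brute force is a cheap safety net. The sketch below repeats the article's Python implementation in condensed form (named SegmentTree here) so that it runs on its own; cross_check is a hypothetical helper that applies random point updates and range queries to both the tree and a plain list and asserts that every answer matches. Callers always start at the root: cur = 1, cur_left = 0, cur_right = n - 1.

```python
import random


class SegmentTree:
    # Condensed copy of the Python implementation; assumes 0 <= idx < n on update.
    def __init__(self, arr):
        n = len(arr)
        self.tree = [0] * (4 * n)
        for i, x in enumerate(arr):
            self.update(1, 0, n - 1, i, x)

    def update(self, cur, cur_left, cur_right, idx, val):
        if cur_left == cur_right:
            self.tree[cur] = val
            return
        cur_mid = (cur_left + cur_right) // 2
        if idx <= cur_mid:
            self.update(cur * 2, cur_left, cur_mid, idx, val)
        else:
            self.update(cur * 2 + 1, cur_mid + 1, cur_right, idx, val)
        self.tree[cur] = self.tree[cur * 2] + self.tree[cur * 2 + 1]

    def query(self, cur, cur_left, cur_right, query_left, query_right):
        if cur_left > query_right or cur_right < query_left:
            return 0
        if query_left <= cur_left and cur_right <= query_right:
            return self.tree[cur]
        cur_mid = (cur_left + cur_right) // 2
        return (self.query(cur * 2, cur_left, cur_mid, query_left, query_right)
                + self.query(cur * 2 + 1, cur_mid + 1, cur_right, query_left, query_right))


def cross_check(trials=200, seed=0):
    """Hypothetical helper: compare random updates and queries against a plain list."""
    rng = random.Random(seed)
    for _ in range(trials):
        n = rng.randint(1, 30)
        arr = [rng.randint(-50, 50) for _ in range(n)]
        st = SegmentTree(arr)
        for _ in range(50):
            if rng.random() < 0.5:
                idx, val = rng.randrange(n), rng.randint(-50, 50)
                arr[idx] = val
                st.update(1, 0, n - 1, idx, val)
            else:
                left = rng.randrange(n)
                right = rng.randrange(left, n)
                expected = sum(arr[left:right + 1])
                got = st.query(1, 0, n - 1, left, right)
                assert got == expected, (arr, left, right, got, expected)
    return True


if __name__ == "__main__":
    st = SegmentTree([5, 8, 6, 3, 2, 7, 2, 6])
    print(st.query(1, 0, 7, 2, 5))   # 18 (6 + 3 + 2 + 7)
    st.update(1, 0, 7, 3, 10)
    print(st.query(1, 0, 7, 2, 5))   # 25 (6 + 10 + 2 + 7)
    print(cross_check())             # True if no mismatch was found
```

Running a check like this after every change to the interval logic tends to surface off-by-one errors quickly, since small random arrays hit single-element, full-range and edge ranges often.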