
Segment Tree Introduction

A segment tree stores information about intervals of an array. It supports two operations efficiently: update one element, and query an aggregate value over a contiguous range. In this introduction, the aggregate is sum, so the operations are update(k, value) and sum(arr[left:right]).

The key constraint is that updates and range queries are interleaved. An update can change any index k, not just the end of the array. A query can ask for any interval [left, right]. We want both operations to be much faster than scanning or rebuilding the whole array.

Why not the plain array?

Store the values in an array. Updating one index is O(1) — one write. But sum(arr[left:right]) walks every element between left and right, so each range query costs O(n). With 1,000,000 elements and a range covering half of them, that is 500,000 additions per query.
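Here is a minimal Python sketch of this baseline. The class name `PlainArraySum` and the method `range_sum` are hypothetical. `range_sum(left, right)` uses the same half-open convention as `arr[left:right]`.

```python
class PlainArraySum:
    """Baseline: O(1) point update, O(n) range sum over arr[left:right]."""

    def __init__(self, arr):
        self.arr = list(arr)

    def update(self, k, value):
        self.arr[k] = value  # a single write

    def range_sum(self, left, right):
        # walks every element in the half-open range [left, right)
        total = 0
        for i in range(left, right):
            total += self.arr[i]
        return total
```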

Why not a prefix-sum array?

Precompute prefix[i] = arr[0] + arr[1] + ... + arr[i-1]. Now sum(arr[left:right]) = prefix[right] - prefix[left] in O(1). But when arr[k] changes, every prefix[i] with i > k is stale. Rebuilding them costs O(n) per update. The prefix array bought cheap queries by making updates expensive.
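The trade-off can be sketched the same way. `PrefixSum` and `_refresh` are hypothetical names. `update` must rewrite every `prefix[i]` with `i > k`, while `range_sum` needs only one subtraction.

```python
class PrefixSum:
    """prefix[i] = arr[0] + ... + arr[i-1]: O(1) range sum, O(n) update."""

    def __init__(self, arr):
        self.arr = list(arr)
        self.prefix = [0] * (len(self.arr) + 1)
        self._refresh(0)

    def _refresh(self, k):
        # every prefix[i] with i > k includes arr[k], so all of them are rebuilt
        for i in range(k + 1, len(self.arr) + 1):
            self.prefix[i] = self.prefix[i - 1] + self.arr[i - 1]

    def update(self, k, value):
        self.arr[k] = value
        self._refresh(k)  # O(n - k) work per update

    def range_sum(self, left, right):
        # sum(arr[left:right]) in O(1)
        return self.prefix[right] - self.prefix[left]
```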

Both structures optimize one side. A plain array is cheap to update and slow to query; a prefix-sum array is cheap to query and slow to update. We want both operations cheap on the same structure.

Caching sums on nested intervals

The reason sum(arr[left:right]) is slow on the plain array is that we recompute the sum from scratch every time. If we had already cached the sum of arr[0:4] and the sum of arr[4:8], the query sum(arr[0:8]) would be one addition. The trick is to choose which intervals to cache so that any query can be covered by a few of them — and so that any update invalidates only a few of them.

Start with the whole array as one interval. Split it in half and cache the sum of each half. Split each half in half again and cache those. Continue until each interval holds a single element. Every cached interval is the union of its two children, so its sum is the sum of its children's sums — the caching is self-consistent.
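Here is a small sketch of the halving scheme. It assumes inclusive intervals `(i, j)`, matching the diagram labels, and a non-empty array. `cached_intervals` is a hypothetical helper. It returns a dict that maps each interval to its cached sum, and each sum is computed from the sums of its two halves.

```python
def cached_intervals(arr):
    """Return {(i, j): sum of arr[i..j]} for every interval produced by halving.

    Assumes arr is non-empty.
    """
    cache = {}

    def split(i, j):
        if i == j:
            cache[(i, j)] = arr[i]  # single element: nothing left to split
            return arr[i]
        mid = (i + j) // 2
        # a cached interval's sum is the sum of its two children's sums
        total = split(i, mid) + split(mid + 1, j)
        cache[(i, j)] = total
        return total

    split(0, len(arr) - 1)
    return cache
```

For `arr = [2, 1, 5, 3]`, it yields seven intervals: the root `(0, 3)`, the halves `(0, 1)` and `(2, 3)`, and four single-element leaves.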

This nested split is a binary tree. The root represents the full array. Each internal node represents an interval and stores its sum; its two children represent the left and right halves. Leaves represent single indices. In the diagram below, every node is labeled [i, j] — the index range it covers.

Segment tree nested intervals

The height of the tree is ⌈log₂ n⌉, because each split halves the interval. That height is the budget for every operation below.
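You can check the height claim directly. `tree_height` is a hypothetical helper that follows the same midpoint split. `(n - 1).bit_length()` is an exact integer form of ⌈log₂ n⌉ for n ≥ 1, which avoids floating-point `log2`.

```python
def tree_height(n):
    """Depth of the deepest leaf when [0, n-1] is halved recursively (n >= 1)."""
    def depth(lo, hi):
        if lo == hi:
            return 0
        mid = (lo + hi) // 2
        return 1 + max(depth(lo, mid), depth(mid + 1, hi))
    return depth(0, n - 1)


print([tree_height(n) for n in (1, 2, 3, 4, 5, 8, 9)])  # [0, 1, 2, 2, 3, 3, 4]
print(all(tree_height(n) == (n - 1).bit_length() for n in range(1, 200)))  # True
```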

Range query

To compute the sum of arr[query_left..query_right], inclusive on both ends, we recursively walk the tree and compare each node's interval against the query interval. Three cases cover every node we visit. When the node's interval lies entirely outside the query, it contributes 0 and we stop. When it lies entirely inside the query, we return its cached sum directly, with no recursion. When the two intervals partially overlap, we recurse into both children and add their results.

Range query cases: outside, inside, overlap

The second case is the payoff. Once a node lies fully inside the query, its entire subtree (potentially half the array) collapses to one lookup. Only nodes that straddle query_left or query_right keep recursing. On any level, at most two nodes can straddle a boundary (one per boundary), so only O(log n) nodes do real work.
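To see the three cases in action, here is a hypothetical `traced_query`. It mirrors the `query` method from the implementation below, but it also logs each visited node's depth, interval, and case. `build` is a helper for this sketch only. It fills the flat 1-based layout described later and assumes a non-empty array.

```python
def build(arr):
    """1-based flat sum tree (helper for this sketch; assumes arr is non-empty)."""
    tree = [0] * (4 * len(arr))

    def fill(cur, cur_left, cur_right):
        if cur_left == cur_right:
            tree[cur] = arr[cur_left]
            return
        cur_mid = (cur_left + cur_right) // 2
        fill(cur * 2, cur_left, cur_mid)
        fill(cur * 2 + 1, cur_mid + 1, cur_right)
        tree[cur] = tree[cur * 2] + tree[cur * 2 + 1]

    fill(1, 0, len(arr) - 1)
    return tree


def traced_query(tree, cur, cur_left, cur_right, query_left, query_right, log, depth=0):
    """Same three cases as query(); appends (depth, interval, case) for each visited node."""
    interval = (cur_left, cur_right)
    if cur_left > query_right or cur_right < query_left:
        log.append((depth, interval, "outside"))
        return 0
    if query_left <= cur_left and cur_right <= query_right:
        log.append((depth, interval, "inside"))
        return tree[cur]
    log.append((depth, interval, "overlap"))
    cur_mid = (cur_left + cur_right) // 2
    return (traced_query(tree, cur * 2, cur_left, cur_mid, query_left, query_right, log, depth + 1)
            + traced_query(tree, cur * 2 + 1, cur_mid + 1, cur_right, query_left, query_right, log, depth + 1))


log = []
total = traced_query(build([2, 1, 5, 3]), 1, 0, 3, 1, 3, log)
for entry in log:
    print(entry)
print(total)  # 9
```

On the running example, the log shows the root and `(0, 1)` as `overlap`, `(0, 0)` as `outside`, and `(1, 1)` and `(2, 3)` as `inside`. Counting `overlap` entries per depth across many queries is a quick way to check the at-most-two-per-level claim.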

Point update

Updating a single index follows one root-to-leaf path. Start at the root; at each node, recurse into whichever child contains idx — the left child if idx ≤ mid, otherwise the right child. When we reach the leaf for idx, write the new value. On the way back up, every parent on that path recomputes its sum from its two children: tree[parent] = tree[left_child] + tree[right_child]. Nodes off the path are untouched — their cached sums are still correct.

Point update recomputes parents on the way up

One path, O(log n) nodes updated.
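Here is a sketch that makes the single path visible. `update_with_path` is a hypothetical variant of `update` that appends every flat-array index it visits. Comparing the array before and after the update shows that only those indices change.

```python
def update_with_path(tree, cur, cur_left, cur_right, idx, val, path):
    """update() with 1-based flat indexing, plus a log of every node it visits."""
    path.append(cur)
    if cur_left == cur_right:
        tree[cur] = val
        return
    cur_mid = (cur_left + cur_right) // 2
    if idx <= cur_mid:
        update_with_path(tree, cur * 2, cur_left, cur_mid, idx, val, path)
    else:
        update_with_path(tree, cur * 2 + 1, cur_mid + 1, cur_right, idx, val, path)
    # recompute this parent from its two children (one of which just changed)
    tree[cur] = tree[cur * 2] + tree[cur * 2 + 1]


arr = [2, 1, 5, 3]
tree = [0] * (4 * len(arr))
for i, v in enumerate(arr):
    update_with_path(tree, 1, 0, len(arr) - 1, i, v, [])

before = list(tree)
path = []
update_with_path(tree, 1, 0, len(arr) - 1, 2, 0, path)  # set arr[2] = 0
changed = [i for i in range(len(tree)) if tree[i] != before[i]]
print(path)     # [1, 3, 6]
print(changed)  # [1, 3, 6]
```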

Implementation

Laying the tree out as a flat array

Allocating a real tree of pointer nodes works, but there is a simpler layout: put the tree inside a flat array tree[] using the heap indexing familiar from binary heaps. If a node sits at index cur, its left child is at 2 * cur and its right child is at 2 * cur + 1. We use 1-based indexing for the tree array so the child math stays clean; the input array is still 0-indexed.

Segment tree flat array layout
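Here is the index arithmetic in a few lines of Python. The helper names are hypothetical. `layout(n)` maps each used flat index to the inclusive interval it covers. This makes it easy to check that slot 0 is never used and that children land at `2 * cur` and `2 * cur + 1`.

```python
def left_child(cur):
    return 2 * cur


def right_child(cur):
    return 2 * cur + 1


def parent(cur):
    return cur // 2


def layout(n):
    """Map each 1-based flat index to the inclusive interval [i, j] it covers (n >= 1)."""
    slots = {}

    def place(cur, cur_left, cur_right):
        slots[cur] = (cur_left, cur_right)
        if cur_left == cur_right:
            return
        cur_mid = (cur_left + cur_right) // 2
        place(left_child(cur), cur_left, cur_mid)
        place(right_child(cur), cur_mid + 1, cur_right)

    place(1, 0, n - 1)
    return slots


print(layout(4))
# {1: (0, 3), 2: (0, 1), 4: (0, 0), 5: (1, 1), 3: (2, 3), 6: (2, 2), 7: (3, 3)}
```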

A safe size for tree[] is 4 * n. The tight bound is 2 * next_power_of_two(n), but 4n is always enough and avoids computing powers of two.
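Here is a quick empirical check of both bounds, assuming the same midpoint split. `max_index_used` and `next_power_of_two` are hypothetical helpers. The loop confirms that the largest index touched stays below `2 * next_power_of_two(n)`, which is itself at most `4 * n`.

```python
def max_index_used(n):
    """Largest flat index the 1-based recursive layout touches for n >= 1 leaves."""
    best = 0

    def walk(cur, cur_left, cur_right):
        nonlocal best
        best = max(best, cur)
        if cur_left == cur_right:
            return
        cur_mid = (cur_left + cur_right) // 2
        walk(cur * 2, cur_left, cur_mid)
        walk(cur * 2 + 1, cur_mid + 1, cur_right)

    walk(1, 0, n - 1)
    return best


def next_power_of_two(n):
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


for n in range(1, 513):
    # an array of size 2 * next_power_of_two(n) holds every index used; 4n is never smaller
    assert max_index_used(n) < 2 * next_power_of_two(n) <= 4 * n
```

When n is a power of two, the tree is complete and the largest index is exactly `2 * n - 1`, so the `2 * next_power_of_two(n)` size is fully used.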

The recursive functions

Each recursive call carries three pieces of state: cur (the index of the current node in the flat tree), and cur_left, cur_right (the interval this node covers). update(cur, cur_left, cur_right, idx, val) walks down to the leaf for idx and recomputes sums on the way back. query(cur, cur_left, cur_right, query_left, query_right) returns the sum of values in [query_left, query_right] restricted to this node's interval, using the three cases above.

Segment tree recursive decomposition logic

The constructor builds the tree by calling update once per input index. A dedicated bottom-up build runs in O(n), but the n calls to update total O(n log n), which is fine for an intro.

```python
class SegmentTree:
    def __init__(self, arr):
        # 4n is a safe upper bound for the flat tree array
        self.tree = [0] * (4 * len(arr))
        for i in range(len(arr)):
            self.update(1, 0, len(arr) - 1, i, arr[i])

    # walk down to the leaf for idx, then recompute sums on the way up
    def update(self, cur, cur_left, cur_right, idx, val):
        if cur_left == cur_right:
            self.tree[cur] = val
            return
        cur_mid = (cur_left + cur_right) // 2
        if idx <= cur_mid:
            self.update(cur * 2, cur_left, cur_mid, idx, val)
        else:
            self.update(cur * 2 + 1, cur_mid + 1, cur_right, idx, val)
        self.tree[cur] = self.tree[cur * 2] + self.tree[cur * 2 + 1]

    # sum of arr[query_left..query_right]
    def query(self, cur, cur_left, cur_right, query_left, query_right):
        # current interval sits entirely outside the query
        if cur_left > query_right or cur_right < query_left:
            return 0
        # current interval sits entirely inside the query
        if query_left <= cur_left and cur_right <= query_right:
            return self.tree[cur]
        # partial overlap: recurse into both children
        cur_mid = (cur_left + cur_right) // 2
        return (self.query(cur * 2, cur_left, cur_mid, query_left, query_right)
                + self.query(cur * 2 + 1, cur_mid + 1, cur_right, query_left, query_right))
```
```java
// Written as a nested class (for example, inside a Solution class).
// A top-level class cannot be declared static, so drop `static` if you move it to its own file.
public static class SegmentTree {
    int[] tree;

    public SegmentTree(int[] arr) {
        // 4n is a safe upper bound for the flat tree array
        tree = new int[4 * arr.length];
        for (int i = 0; i < arr.length; i++) {
            update(1, 0, arr.length - 1, i, arr[i]);
        }
    }

    // walk down to the leaf for idx, then recompute sums on the way up
    void update(int cur, int curLeft, int curRight, int idx, int val) {
        if (curLeft == curRight) {
            tree[cur] = val;
            return;
        }
        int curMid = (curLeft + curRight) / 2;
        if (idx <= curMid) update(cur * 2, curLeft, curMid, idx, val);
        else update(cur * 2 + 1, curMid + 1, curRight, idx, val);
        tree[cur] = tree[cur * 2] + tree[cur * 2 + 1];
    }

    // sum of arr[queryLeft..queryRight]
    int query(int cur, int curLeft, int curRight, int queryLeft, int queryRight) {
        // current interval sits entirely outside the query
        if (curLeft > queryRight || curRight < queryLeft) return 0;
        // current interval sits entirely inside the query
        if (queryLeft <= curLeft && curRight <= queryRight) return tree[cur];
        // partial overlap: recurse into both children
        int curMid = (curLeft + curRight) / 2;
        return query(cur * 2, curLeft, curMid, queryLeft, queryRight)
             + query(cur * 2 + 1, curMid + 1, curRight, queryLeft, queryRight);
    }
}
```
```javascript
class SegmentTree {
    constructor(arr) {
        // 4n is a safe upper bound for the flat tree array
        this.tree = Array(4 * arr.length).fill(0);
        for (let i = 0; i < arr.length; i++) {
            this.update(1, 0, arr.length - 1, i, arr[i]);
        }
    }

    // walk down to the leaf for idx, then recompute sums on the way up
    update(cur, curLeft, curRight, idx, val) {
        if (curLeft === curRight) {
            this.tree[cur] = val;
            return;
        }
        const curMid = Math.floor((curLeft + curRight) / 2);
        if (idx <= curMid) {
            this.update(cur * 2, curLeft, curMid, idx, val);
        } else {
            this.update(cur * 2 + 1, curMid + 1, curRight, idx, val);
        }
        this.tree[cur] = this.tree[cur * 2] + this.tree[cur * 2 + 1];
    }

    // sum of arr[queryLeft..queryRight]
    query(cur, curLeft, curRight, queryLeft, queryRight) {
        // current interval sits entirely outside the query
        if (curLeft > queryRight || curRight < queryLeft) return 0;
        // current interval sits entirely inside the query
        if (queryLeft <= curLeft && curRight <= queryRight) return this.tree[cur];
        // partial overlap: recurse into both children
        const curMid = Math.floor((curLeft + curRight) / 2);
        return this.query(cur * 2, curLeft, curMid, queryLeft, queryRight)
             + this.query(cur * 2 + 1, curMid + 1, curRight, queryLeft, queryRight);
    }
}
```
```cpp
#include <vector>

struct SegmentTree {
    std::vector<int> tree;

    SegmentTree(const std::vector<int>& arr) : tree(4 * arr.size(), 0) {
        // 4n is a safe upper bound for the flat tree array
        int n = (int)arr.size();
        for (int i = 0; i < n; i++) {
            update(1, 0, n - 1, i, arr[i]);
        }
    }

    // walk down to the leaf for idx, then recompute sums on the way up
    void update(int cur, int cur_left, int cur_right, int idx, int val) {
        if (cur_left == cur_right) {
            tree[cur] = val;
            return;
        }
        int cur_mid = (cur_left + cur_right) / 2;
        if (idx <= cur_mid) update(cur * 2, cur_left, cur_mid, idx, val);
        else update(cur * 2 + 1, cur_mid + 1, cur_right, idx, val);
        tree[cur] = tree[cur * 2] + tree[cur * 2 + 1];
    }

    // sum of arr[query_left..query_right]
    int query(int cur, int cur_left, int cur_right, int query_left, int query_right) {
        // current interval sits entirely outside the query
        if (cur_left > query_right || cur_right < query_left) return 0;
        // current interval sits entirely inside the query
        if (query_left <= cur_left && cur_right <= query_right) return tree[cur];
        // partial overlap: recurse into both children
        int cur_mid = (cur_left + cur_right) / 2;
        return query(cur * 2, cur_left, cur_mid, query_left, query_right)
             + query(cur * 2 + 1, cur_mid + 1, cur_right, query_left, query_right);
    }
};
```
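The constructors above call update once per index, for O(n log n) total. One way to get the O(n) build mentioned earlier is a recursive fill that writes each node exactly once, children before parent. `build` is a hypothetical helper that uses the same 1-based layout as the Python class. It assumes a non-empty array, because with n = 0 the root interval [0, -1] would recurse forever.

```python
def build(tree, arr, cur, cur_left, cur_right):
    """Fill tree[cur] and its whole subtree from arr[cur_left..cur_right].

    Each node is written exactly once, children before their parent, so the
    total work is O(n). Call only with a non-empty arr.
    """
    if cur_left == cur_right:
        tree[cur] = arr[cur_left]
        return
    cur_mid = (cur_left + cur_right) // 2
    build(tree, arr, cur * 2, cur_left, cur_mid)
    build(tree, arr, cur * 2 + 1, cur_mid + 1, cur_right)
    tree[cur] = tree[cur * 2] + tree[cur * 2 + 1]


arr = [2, 1, 5, 3]
tree = [0] * (4 * len(arr))
build(tree, arr, 1, 0, len(arr) - 1)
print(tree[1:8])  # [11, 3, 8, 2, 1, 5, 3]
```

In the Python class, the constructor loop could then become `if arr: build(self.tree, arr, 1, 0, len(arr) - 1)`.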

Tracing the running example

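Take arr = [2, 1, 5, 3], indices 0..3. After building, the tree holds the sums shown below, with each node labeled [i, j] = sum. The root is [0, 3] = 11, and its children are [0, 1] = 3 and [2, 3] = 8. The leaves are [0, 0] = 2, [1, 1] = 1, [2, 2] = 5, and [3, 3] = 3.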

Tracing segment tree query and update

Call query(1, 3), shorthand for query(1, 0, 3, 1, 3) in the code above, to get the sum of arr[1..3]. We start at the root [0, 3]. The query [1, 3] partially overlaps it, so we recurse into both children. The left child covers [0, 1], which also partially overlaps [1, 3], so we recurse into its children: [0, 0] lies entirely outside the query and returns 0, and [1, 1] lies entirely inside and returns 1. The left subtree contributes 1. The right child covers [2, 3], which lies entirely inside [1, 3], so we return its cached sum 8 directly without recursing deeper. The root adds its children's results: 1 + 8 = 9, matching arr[1] + arr[2] + arr[3] = 1 + 5 + 3.

Call update(2, 0), shorthand for update(1, 0, 3, 2, 0), to set arr[2] = 0. We walk the root-to-leaf path: root [0, 3] → right child [2, 3] → left child [2, 2]. Write 0 at the leaf. On the way back, [2, 3] recomputes as 0 + 3 = 3, and the root recomputes as 3 + 3 = 6. The left subtree ([0, 1] and its children) is never visited, and its cached sums are still correct.

A follow-up query(1, 3) now returns 1 + 3 = 4, reflecting the update.
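To replay this trace, here is a hypothetical wrapper, `SumSegmentTree`. It uses the same recursion as the Python class above, but its public methods `assign` and `range_sum` fill in the root arguments `(1, 0, n - 1)` so callers pass only array indices. The printed slices are `tree[1:8]`, the seven nodes of the four-element tree in flat-array order.

```python
class SumSegmentTree:
    """Hypothetical wrapper: callers pass array indices; root arguments are filled in."""

    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        for i, v in enumerate(arr):
            self.assign(i, v)

    def assign(self, idx, val):
        self._update(1, 0, self.n - 1, idx, val)

    def range_sum(self, left, right):
        # inclusive on both ends, like query() in the article
        return self._query(1, 0, self.n - 1, left, right)

    def _update(self, cur, cur_left, cur_right, idx, val):
        if cur_left == cur_right:
            self.tree[cur] = val
            return
        cur_mid = (cur_left + cur_right) // 2
        if idx <= cur_mid:
            self._update(cur * 2, cur_left, cur_mid, idx, val)
        else:
            self._update(cur * 2 + 1, cur_mid + 1, cur_right, idx, val)
        self.tree[cur] = self.tree[cur * 2] + self.tree[cur * 2 + 1]

    def _query(self, cur, cur_left, cur_right, query_left, query_right):
        if cur_left > query_right or cur_right < query_left:
            return 0
        if query_left <= cur_left and cur_right <= query_right:
            return self.tree[cur]
        cur_mid = (cur_left + cur_right) // 2
        return (self._query(cur * 2, cur_left, cur_mid, query_left, query_right)
                + self._query(cur * 2 + 1, cur_mid + 1, cur_right, query_left, query_right))


st = SumSegmentTree([2, 1, 5, 3])
print(st.tree[1:8])        # [11, 3, 8, 2, 1, 5, 3]
print(st.range_sum(1, 3))  # 9
st.assign(2, 0)
print(st.tree[1:8])        # [6, 3, 3, 2, 1, 0, 3]
print(st.range_sum(1, 3))  # 4
```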

What about performance?

Both query and update run in O(log n) because each operation visits a bounded number of nodes per level, and the tree has ⌈log₂ n⌉ levels. For update, each level visits exactly one node — the one on the root-to-leaf path for idx. For query, each level visits at most four nodes: the two that span the query boundaries plus their immediate siblings. Everything else collapses into a single cached read.
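You can check the per-level bound by brute force. `visits_per_level` is a hypothetical instrumented copy of `query` that counts visited nodes at each depth, and `build` is a helper for this sketch that assumes a non-empty array. The loop tries every range on a 100-element array.

```python
def build(arr):
    """1-based flat sum tree (helper for this sketch; assumes arr is non-empty)."""
    tree = [0] * (4 * len(arr))

    def fill(cur, cur_left, cur_right):
        if cur_left == cur_right:
            tree[cur] = arr[cur_left]
            return
        cur_mid = (cur_left + cur_right) // 2
        fill(cur * 2, cur_left, cur_mid)
        fill(cur * 2 + 1, cur_mid + 1, cur_right)
        tree[cur] = tree[cur * 2] + tree[cur * 2 + 1]

    fill(1, 0, len(arr) - 1)
    return tree


def visits_per_level(tree, n, query_left, query_right):
    """Same three cases as query(); returns (sum, {depth: nodes visited at that depth})."""
    counts = {}

    def query(cur, cur_left, cur_right, depth):
        counts[depth] = counts.get(depth, 0) + 1
        if cur_left > query_right or cur_right < query_left:
            return 0
        if query_left <= cur_left and cur_right <= query_right:
            return tree[cur]
        cur_mid = (cur_left + cur_right) // 2
        return (query(cur * 2, cur_left, cur_mid, depth + 1)
                + query(cur * 2 + 1, cur_mid + 1, cur_right, depth + 1))

    return query(1, 0, n - 1, 0), counts


arr = list(range(100))
tree = build(arr)
worst = 0
for ql in range(len(arr)):
    for qr in range(ql, len(arr)):
        total, counts = visits_per_level(tree, len(arr), ql, qr)
        assert total == sum(arr[ql:qr + 1])
        worst = max(worst, max(counts.values()))
print("most nodes visited on one level:", worst)
```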

Space is O(n). The flat tree array of size 4n is an over-allocation for simplicity; the tree itself holds fewer than 2n useful entries.
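Here is a one-function check of the node count, assuming the same midpoint split. `count_nodes` (hypothetical) counts the nodes the recursion creates. The count is always 2n - 1, because there are n leaves and every internal node has exactly two children.

```python
def count_nodes(n):
    """Number of nodes in the segment tree over [0, n-1], for n >= 1."""
    def walk(lo, hi):
        if lo == hi:
            return 1
        mid = (lo + hi) // 2
        return 1 + walk(lo, mid) + walk(mid + 1, hi)
    return walk(0, n - 1)


print([count_nodes(n) for n in (1, 2, 3, 4, 5)])  # [1, 3, 5, 7, 9]
```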

For the interview

Reach for a segment tree when a problem mixes frequent range queries and frequent point updates, and either a plain array or a prefix-sum array leaves one side slow. Signals in the problem statement: "sum / min / max / gcd over a range," "after each update," or queries interleaved with updates.

The operation does not have to be sum. Any associative operation with a neutral element works — min, max, gcd, bitwise OR. Change the merge rule (tree[parent] = left + right becomes tree[parent] = min(left, right), and so on) and the "outside the query" return value (0 for sum, +∞ for min), and the same structure answers a different question. Range Maximum Query, the next article in this section, applies exactly this swap.
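Here is a minimal sketch of that swap, assuming the operation is passed in as a function. `GenericSegmentTree`, `merge`, and `identity` are hypothetical names. The recursion is the article's, with `+` replaced by `merge` and `return 0` replaced by `return identity`. The examples use only standard-library callables: `min`, `max`, `math.gcd`, and `operator.or_`.

```python
import math
import operator


class GenericSegmentTree:
    """merge must be associative, and identity must be its neutral element."""

    def __init__(self, arr, merge, identity):
        self.n = len(arr)
        self.merge = merge
        self.identity = identity
        self.tree = [identity] * (4 * self.n)
        for i, v in enumerate(arr):
            self.update(1, 0, self.n - 1, i, v)

    def update(self, cur, cur_left, cur_right, idx, val):
        if cur_left == cur_right:
            self.tree[cur] = val
            return
        cur_mid = (cur_left + cur_right) // 2
        if idx <= cur_mid:
            self.update(cur * 2, cur_left, cur_mid, idx, val)
        else:
            self.update(cur * 2 + 1, cur_mid + 1, cur_right, idx, val)
        # the merge rule replaces "+"
        self.tree[cur] = self.merge(self.tree[cur * 2], self.tree[cur * 2 + 1])

    def query(self, cur, cur_left, cur_right, query_left, query_right):
        if cur_left > query_right or cur_right < query_left:
            return self.identity  # replaces "return 0"
        if query_left <= cur_left and cur_right <= query_right:
            return self.tree[cur]
        cur_mid = (cur_left + cur_right) // 2
        return self.merge(
            self.query(cur * 2, cur_left, cur_mid, query_left, query_right),
            self.query(cur * 2 + 1, cur_mid + 1, cur_right, query_left, query_right),
        )


mins = GenericSegmentTree([2, 1, 5, 3], min, math.inf)
maxs = GenericSegmentTree([2, 1, 5, 3], max, -math.inf)
gcds = GenericSegmentTree([12, 18, 8, 30], math.gcd, 0)
ors = GenericSegmentTree([1, 2, 4, 8], operator.or_, 0)
print(mins.query(1, 0, 3, 0, 2))  # 1
print(maxs.query(1, 0, 3, 1, 3))  # 5
print(gcds.query(1, 0, 3, 0, 1))  # 6
print(ors.query(1, 0, 3, 1, 2))   # 6
```

The identity also matters during construction, because unfilled slots hold it and get merged before every leaf is set.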

The implementation has many interval boundary cases, so off-by-one mistakes are easy. In an interview, explaining when to reach for a segment tree and how query and update decompose the range is often enough to move forward; you can then fill in the code with care. Range updates with lazy propagation are a follow-up topic layered on the same shape.
