
3820. Pythagorean Distance Nodes in a Tree

Medium · Tree · Breadth-First Search

Problem Description

You are given an integer n and an undirected tree with n nodes numbered from 0 to n - 1. The tree is represented by a 2D array edges of length n - 1, where edges[i] = [uᵢ, vᵢ] indicates an undirected edge between uᵢ and vᵢ.

You are also given three distinct target nodes x, y, and z.

For any node u in the tree:

  • Let dx be the distance from u to node x.
  • Let dy be the distance from u to node y.
  • Let dz be the distance from u to node z.

The node u is called special if the three distances dx, dy, and dz form a Pythagorean Triplet.

Your task is to return an integer denoting the number of special nodes in the tree.

A Pythagorean triplet consists of three integers a, b, and c which, when sorted in ascending order, satisfy a² + b² = c².

The distance between two nodes in a tree is the number of edges on the unique path between them.
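The definition translates directly into a small predicate. This is a minimal sketch; the name `is_pythagorean_triplet` is a hypothetical helper of ours, not part of the problem's interface. Read literally, the definition also accepts a zero entry such as (0, 3, 3), because 0² + 3² = 3².

```python
def is_pythagorean_triplet(p: int, q: int, r: int) -> bool:
    """Return True if p, q, r, once sorted ascending as (a, b, c), satisfy a² + b² = c²."""
    a, b, c = sorted((p, q, r))  # the largest value plays the role of c
    return a * a + b * b == c * c


print(is_pythagorean_triplet(5, 3, 4))  # True: 3² + 4² = 5²
print(is_pythagorean_triplet(1, 1, 2))  # False: 1 + 1 != 4
```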


How We Pick the Algorithm

Why BFS?

This problem maps to BFS through a short path in the full flowchart.

Graph or tree? yes → BFS for levels/layers? yes → BFS

Computing distances from three target nodes to all nodes via BFS, then checking Pythagorean triplet conditions at each node.


First, let's pin down the algorithm using the Flowchart. Here's a step-by-step walkthrough:

Is it a graph?

  • Yes: The problem provides a tree with n nodes and n - 1 edges, which is a special kind of graph defined by nodes and connections.

Is it a tree?

  • Yes: The problem explicitly states we are given an undirected tree with n nodes.

DFS

  • For tree problems the flowchart's default branch leads to DFS. Here, however, we need the distance (number of edges) from each target node to every other node, which is exactly what BFS produces: nodes are discovered layer by layer, and a node's layer is its distance from the source. Because the path between any two nodes of a tree is unique, a DFS that tracks depth would produce the same distances; BFS is simply the most direct way to compute single-source distances in an unweighted graph.

Conclusion: Because the tree is unweighted and we must measure the distance from each of the three target nodes x, y, and z to all other nodes, the flowchart leads us to the Breadth-First Search pattern. We run BFS from each target node to build three distance arrays, then check each node for the Pythagorean triplet condition.
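Because the path between two nodes of a tree is unique, a depth-tracking DFS also yields exact distances; this only holds for trees, since on a graph with cycles DFS depth is not a shortest-path distance. A minimal sketch, assuming the adjacency-list format used later in the solution (`graph[u]` lists the neighbors of u); `tree_distances_dfs` is a hypothetical name.

```python
from typing import List


def tree_distances_dfs(graph: List[List[int]], source: int) -> List[int]:
    """Distances from `source` to every node of a tree, via iterative DFS.

    In a tree each node is first reached from its parent on the unique path
    to `source`, so depth at discovery equals distance.
    """
    dist = [-1] * len(graph)  # -1 marks "not visited yet"
    dist[source] = 0
    stack = [source]
    while stack:
        u = stack.pop()
        for v in graph[u]:
            if dist[v] == -1:  # in a tree, the only visited neighbor of u is its parent
                dist[v] = dist[u] + 1
                stack.append(v)
    return dist


graph = [[1], [0, 2, 3], [1], [1, 4], [3]]  # the tree from the Example Walkthrough
print(tree_distances_dfs(graph, 0))  # [0, 1, 2, 2, 3]
```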

Intuition

The key observation is that to decide whether a node u is special, we only need three pieces of information: the distance from u to x, the distance from u to y, and the distance from u to z. Once we have these three values, checking whether they form a Pythagorean triplet is a simple constant-time test.

So the problem naturally splits into two parts: first compute all the distances, and then check the condition for every node.

How do we compute distances? Since the tree is unweighted (each edge counts as 1), the distance from a fixed source node to all other nodes is exactly the BFS layer in which each node is discovered. This means a single BFS starting from a node i gives us the distance from i to every node in the tree in O(n) time. We don't need to repeat this per pair of nodes — one BFS per source covers all destinations at once.
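To see that "BFS layer" and "distance" coincide, it helps to print the layers themselves. A small sketch (the helper name `bfs_layers` is hypothetical) run on the tree from the Example Walkthrough: the index of each layer is the distance of every node in it.

```python
from collections import deque
from typing import List


def bfs_layers(graph: List[List[int]], source: int) -> List[List[int]]:
    """Group nodes by the BFS layer in which they are discovered from `source`."""
    seen = [False] * len(graph)
    seen[source] = True
    layers = []
    queue = deque([source])
    while queue:
        layer = []
        for _ in range(len(queue)):  # exactly the nodes of the current layer
            u = queue.popleft()
            layer.append(u)
            for v in graph[u]:
                if not seen[v]:
                    seen[v] = True
                    queue.append(v)
        layers.append(layer)
    return layers


graph = [[1], [0, 2, 3], [1], [1, 4], [3]]
for d, layer in enumerate(bfs_layers(graph, 0)):
    print(d, layer)  # 0 [0] / 1 [1] / 2 [2, 3] / 3 [4]
```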

Because we care about distances measured from the three fixed targets x, y, and z, we run BFS exactly three times: once from each target. This produces three distance arrays d1, d2, and d3, where d1[u], d2[u], and d3[u] are the distances from u to x, y, and z respectively.

After that, we walk through every node u and look at its triple (d1[u], d2[u], d3[u]). The Pythagorean condition requires the largest of the three values to be the hypotenuse, so we sort the three numbers to identify the smallest a, the middle b, and the largest c, then simply test whether a*a + b*b == c*c. Each node that satisfies this is counted as special.

This two-phase idea — precompute distances with BFS, then enumerate and test each node — keeps the whole solution efficient at O(n), since both the three BFS passes and the final scan over all nodes are linear in the size of the tree.
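One way to gain confidence in the two-phase idea is to compare it with a deliberately slow reference on small random trees. This is a sketch under stated assumptions: the helper names (`count_fast`, `count_slow`, `random_tree`) are hypothetical, and the slow version runs one BFS from every node, which is O(n²) and only meant for testing.

```python
import random
from collections import deque
from typing import List


def bfs(graph: List[List[int]], src: int) -> List[int]:
    """Unweighted shortest-path distances from src (every node is reachable in a tree)."""
    dist = [-1] * len(graph)
    dist[src] = 0
    queue = deque([src])
    while queue:
        u = queue.popleft()
        for v in graph[u]:
            if dist[v] == -1:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


def is_special(p: int, q: int, r: int) -> bool:
    a, b, c = sorted((p, q, r))
    return a * a + b * b == c * c


def count_fast(graph: List[List[int]], x: int, y: int, z: int) -> int:
    """Two-phase approach: three BFS passes, then one linear scan."""
    d1, d2, d3 = bfs(graph, x), bfs(graph, y), bfs(graph, z)
    return sum(is_special(a, b, c) for a, b, c in zip(d1, d2, d3))


def count_slow(graph: List[List[int]], x: int, y: int, z: int) -> int:
    """Reference: one BFS from every node u, reading u's distances to x, y, z."""
    total = 0
    for u in range(len(graph)):
        d = bfs(graph, u)
        total += is_special(d[x], d[y], d[z])
    return total


def random_tree(n: int, rng: random.Random) -> List[List[int]]:
    """Random tree: attach each node i > 0 to a uniformly chosen earlier node."""
    graph = [[] for _ in range(n)]
    for i in range(1, n):
        p = rng.randrange(i)
        graph[i].append(p)
        graph[p].append(i)
    return graph


rng = random.Random(0)
for _ in range(200):
    n = rng.randint(3, 12)
    g = random_tree(n, rng)
    x, y, z = rng.sample(range(n), 3)
    assert count_fast(g, x, y, z) == count_slow(g, x, y, z)
print("fast and slow counts agree")
```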


Solution Approach

We follow a BFS + Enumeration strategy. The implementation breaks down into three clear steps: build the graph, compute distances, then check each node.

Step 1: Build the adjacency list

We first construct an adjacency list g from the given edges, where g[u] stores all nodes adjacent to node u. Since the tree is undirected, each edge [u, v] is added in both directions:

g = [[] for _ in range(n)]
for u, v in edges:
    g[u].append(v)
    g[v].append(u)
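Wrapped as a self-contained function (the name `build_adjacency` is hypothetical), the same step can be checked on the walkthrough's edge list. The list comprehension matters: `[[]] * n` would create n references to one shared list, so every append would show up under every node.

```python
from typing import List


def build_adjacency(n: int, edges: List[List[int]]) -> List[List[int]]:
    """Adjacency list of an undirected graph: each edge is stored in both directions."""
    g = [[] for _ in range(n)]  # n independent lists, not [[]] * n
    for u, v in edges:
        g[u].append(v)
        g[v].append(u)
    return g


print(build_adjacency(5, [[0, 1], [1, 2], [1, 3], [3, 4]]))
# [[1], [0, 2, 3], [1], [1, 4], [3]]
```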

Step 2: Compute distances with BFS

We define a function bfs(i) that calculates the distances from node i to all other nodes. We use a queue (deque) to perform Breadth-First Search and maintain a distance array dist, where dist[j] is the distance from i to j. Initially, dist[i] = 0 and every other entry is set to inf (math.inf). During the traversal, whenever going through the current node u gives a neighbor v a shorter distance than the one recorded, we set dist[v] = dist[u] + 1 and push v into the queue. In BFS order this happens exactly once per node, the first time it is reached:

def bfs(i: int) -> List[int]:
    q = deque([i])
    dist = [inf] * n
    dist[i] = 0
    while q:
        for _ in range(len(q)):
            u = q.popleft()
            for v in g[u]:
                if dist[v] > dist[u] + 1:
                    dist[v] = dist[u] + 1
                    q.append(v)
    return dist

We call bfs(x), bfs(y), and bfs(z) to obtain three distance arrays d1, d2, and d3, holding the distances from each node to x, y, and z respectively.
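A self-contained variant of `bfs(i)` makes this step testable on its own. Two assumptions of this sketch: it takes the adjacency list `g` as a parameter instead of closing over it, and it drops the per-layer inner loop, which does not change the distances. It is run on the example tree from the walkthrough with x = 0, y = 2, z = 4.

```python
from collections import deque
from math import inf
from typing import List


def bfs(g: List[List[int]], i: int) -> list:
    """Distances from node i to every node; inf would mark unreached nodes (none in a tree)."""
    dist = [inf] * len(g)
    dist[i] = 0
    q = deque([i])
    while q:
        u = q.popleft()
        for v in g[u]:
            if dist[v] > dist[u] + 1:  # true only on the first visit in BFS order
                dist[v] = dist[u] + 1
                q.append(v)
    return dist


g = [[1], [0, 2, 3], [1], [1, 4], [3]]
d1, d2, d3 = bfs(g, 0), bfs(g, 2), bfs(g, 4)
print(d1)  # [0, 1, 2, 2, 3]
print(d2)  # [2, 1, 0, 2, 3]
print(d3)  # [3, 2, 3, 1, 0]
```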

Step 3: Enumerate every node and test the condition

Finally, we iterate over all nodes by zipping the three distance arrays together. For each node we get its triple (a, b, c). To apply the Pythagorean check, we need the values sorted so the largest is treated as the hypotenuse. We compute the sum s = a + b + c, then take a = min(a, b, c) and c = max(a, b, c), and recover the middle value with b = s - a - c. We then test whether a*a + b*b == c*c, incrementing the answer for each node that passes:

ans = 0
for a, b, c in zip(d1, d2, d3):
    s = a + b + c
    a, c = min(a, b, c), max(a, b, c)
    b = s - a - c
    if a * a + b * b == c * c:
        ans += 1
return ans
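Why does b = s - a - c recover the middle value even when two distances are equal? Removing one copy of the minimum and one copy of the maximum from the sum leaves exactly the remaining element, ties included. A quick exhaustive check over small triples (`middle_by_sum` is a hypothetical helper name):

```python
from itertools import product


def middle_by_sum(a: int, b: int, c: int) -> int:
    """Middle value of three numbers, recovered as sum - min - max (no sort needed)."""
    return (a + b + c) - min(a, b, c) - max(a, b, c)


# Compare against sorting for every triple with entries in 0..6,
# including ties such as (3, 3, 0) from the walkthrough.
for t in product(range(7), repeat=3):
    assert middle_by_sum(*t) == sorted(t)[1], t
print("sum - min - max always equals the sorted middle value")
```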

Complexity Analysis

  • Time complexity: O(n). Each BFS visits all nodes and edges in O(n) (a tree has n - 1 edges), and we run it three times. The final enumeration over all nodes is also O(n).
  • Space complexity: O(n). The adjacency list, the BFS queue, and each distance array all use linear space.

Example Walkthrough

Let's trace through a small concrete example to see how the BFS + Enumeration approach works.

Input setup

n = 5
edges = [[0, 1], [1, 2], [1, 3], [3, 4]]
x = 0, y = 2, z = 4

The tree looks like this:

        0
        |
        1
       / \
      2   3
          |
          4

Step 1: Build the adjacency list

Adding each edge in both directions gives:

g[0] = [1]
g[1] = [0, 2, 3]
g[2] = [1]
g[3] = [1, 4]
g[4] = [3]

Step 2: Compute distances with three BFS passes

We run BFS from each target node. Each pass labels every node with its distance from the source.

d1 = bfs(x=0) — distances from node 0:

node:  0  1  2  3  4
dist:  0  1  2  2  3

(Node 1 is 1 step from node 0, nodes 2 and 3 are 2 steps away, and node 4 is 3 steps away.)

d2 = bfs(y=2) — distances from node 2:

node:  0  1  2  3  4
dist:  2  1  0  2  3

d3 = bfs(z=4) — distances from node 4:

node:  0  1  2  3  4
dist:  3  2  3  1  0

Step 3: Enumerate every node and test the Pythagorean condition

For each node u, we read the triple (d1[u], d2[u], d3[u]), sort to identify a ≤ b ≤ c, then check a² + b² == c².

node  (d1, d2, d3)  sorted (a, b, c)  a² + b²     c²  special?
0     (0, 2, 3)     (0, 2, 3)         0 + 4 = 4   9   no
1     (1, 1, 2)     (1, 1, 2)         1 + 1 = 2   4   no
2     (2, 0, 3)     (0, 2, 3)         0 + 4 = 4   9   no
3     (2, 2, 1)     (1, 2, 2)         1 + 4 = 5   4   no
4     (3, 3, 0)     (0, 3, 3)         0 + 9 = 9   9   yes

Let's confirm how the code recovers the middle value for node 4 with triple (3, 3, 0):

  • s = 3 + 3 + 0 = 6
  • a = min(3, 3, 0) = 0, c = max(3, 3, 0) = 3
  • b = s - a - c = 6 - 0 - 3 = 3
  • Check: 0² + 3² = 9 == 3² ✓ → counted

Result

Only node 4 satisfies the condition, so the answer is:

ans = 1

This walkthrough demonstrates the full pipeline: build the graph once, run BFS three times to fill d1, d2, d3, then scan every node in linear time to count those whose three distances form a Pythagorean triplet.
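When debugging, it helps to see which nodes were counted, not just how many. The following is a variant of the same pipeline that returns the special nodes themselves; `special_node_list` is a hypothetical name, and it uses -1 instead of inf as the "unvisited" marker.

```python
from collections import deque
from typing import List


def special_node_list(n: int, edges: List[List[int]], x: int, y: int, z: int) -> List[int]:
    """Return the special nodes themselves (not just their count)."""
    g = [[] for _ in range(n)]
    for u, v in edges:
        g[u].append(v)
        g[v].append(u)

    def bfs(src: int) -> List[int]:
        dist = [-1] * n
        dist[src] = 0
        q = deque([src])
        while q:
            u = q.popleft()
            for v in g[u]:
                if dist[v] == -1:
                    dist[v] = dist[u] + 1
                    q.append(v)
        return dist

    d1, d2, d3 = bfs(x), bfs(y), bfs(z)
    result = []
    for u in range(n):
        a, b, c = sorted((d1[u], d2[u], d3[u]))
        if a * a + b * b == c * c:
            result.append(u)
    return result


print(special_node_list(5, [[0, 1], [1, 2], [1, 3], [3, 4]], 0, 2, 4))  # [4]
```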

Solution Implementation

from collections import deque
from math import inf
from typing import List


class Solution:
    def specialNodes(
        self, n: int, edges: List[List[int]], x: int, y: int, z: int
    ) -> int:
        # Build an undirected adjacency list for the graph.
        graph = [[] for _ in range(n)]
        for u, v in edges:
            graph[u].append(v)
            graph[v].append(u)

        def bfs(start: int) -> List[int]:
            # Compute the shortest distance (number of edges) from `start`
            # to every other node using breadth-first search.
            queue = deque([start])
            dist = [inf] * n
            dist[start] = 0
            while queue:
                # Process all nodes currently in the queue layer by layer.
                for _ in range(len(queue)):
                    u = queue.popleft()
                    for v in graph[u]:
                        # Relax the distance if a shorter path is found.
                        if dist[v] > dist[u] + 1:
                            dist[v] = dist[u] + 1
                            queue.append(v)
            return dist

        # Distances from each of the three source nodes x, y, z.
        dist_x = bfs(x)
        dist_y = bfs(y)
        dist_z = bfs(z)

        ans = 0
        # For every node, examine its distances to x, y and z.
        for dx, dy, dz in zip(dist_x, dist_y, dist_z):
            total = dx + dy + dz
            # Sort the three distances: `low` (smallest), `high` (largest),
            # and `mid` (the remaining one).
            low, high = min(dx, dy, dz), max(dx, dy, dz)
            mid = total - low - high
            # Count the node if the three distances form a Pythagorean triple.
            if low * low + mid * mid == high * high:
                ans += 1
        return ans

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;

class Solution {
    // Adjacency list representation of the graph.
    private List<Integer>[] g;
    // Total number of nodes in the graph.
    private int n;
    // A large value used to represent "infinity" (unreachable distance).
    // Using MAX_VALUE / 2 avoids integer overflow when computing dist[u] + 1.
    private final int inf = Integer.MAX_VALUE / 2;

    /**
     * Counts the number of nodes whose distances to three given source nodes
     * (x, y, z) can form a right triangle (i.e., satisfy the Pythagorean theorem).
     *
     * @param n     the number of nodes
     * @param edges the undirected edges of the graph
     * @param x     the first source node
     * @param y     the second source node
     * @param z     the third source node
     * @return the count of nodes satisfying the right-triangle condition
     */
    public int specialNodes(int n, int[][] edges, int x, int y, int z) {
        this.n = n;

        // Initialize the adjacency list, one list per node.
        g = new ArrayList[n];
        Arrays.setAll(g, k -> new ArrayList<>());

        // Build the undirected graph by adding both directions for each edge.
        for (int[] e : edges) {
            int u = e[0], v = e[1];
            g[u].add(v);
            g[v].add(u);
        }

        // Compute shortest distances from each of the three source nodes
        // to all other nodes using BFS (graph is unweighted).
        int[] d1 = bfs(x);
        int[] d2 = bfs(y);
        int[] d3 = bfs(z);

        int ans = 0;
        // For each node, check whether its three distances form a right triangle.
        for (int i = 0; i < n; i++) {
            // Use long to avoid overflow when squaring distances.
            long[] a = new long[] {d1[i], d2[i], d3[i]};
            // Sort so that a[2] is the largest (the potential hypotenuse).
            Arrays.sort(a);
            // Pythagorean check: a[0]^2 + a[1]^2 == a[2]^2.
            if (a[0] * a[0] + a[1] * a[1] == a[2] * a[2]) {
                ++ans;
            }
        }
        return ans;
    }

    /**
     * Performs a breadth-first search from the given source node and returns
     * the shortest distance (in number of edges) from the source to every node.
     *
     * @param i the source node
     * @return an array where dist[v] is the shortest distance from i to v,
     *         or inf if v is unreachable
     */
    private int[] bfs(int i) {
        // Distance array initialized to infinity for all nodes.
        int[] dist = new int[n];
        Arrays.fill(dist, inf);

        // Standard BFS queue.
        Deque<Integer> q = new ArrayDeque<>();
        dist[i] = 0;
        q.add(i);

        // Process the queue level by level.
        while (!q.isEmpty()) {
            // Iterate over all nodes currently at the same BFS level.
            for (int k = q.size(); k > 0; --k) {
                int u = q.poll();
                // Relax each neighbor: update distance if a shorter path is found.
                for (int v : g[u]) {
                    if (dist[v] > dist[u] + 1) {
                        dist[v] = dist[u] + 1;
                        q.add(v);
                    }
                }
            }
        }
        return dist;
    }
}

#include <algorithm>
#include <array>
#include <climits>
#include <queue>
#include <vector>
using namespace std;

class Solution {
private:
    vector<vector<int>> graph;  // Adjacency list representation of the graph
    int numNodes;               // Total number of nodes in the graph
    const int kInf = INT_MAX / 2;  // A large value representing "unreachable" distance

    // Performs BFS from a given source node and returns the shortest distances
    // from the source to every other node (since the graph is unweighted).
    vector<int> bfs(int source) {
        vector<int> dist(numNodes, kInf);  // Initialize all distances as infinite
        queue<int> q;
        dist[source] = 0;  // Distance to the source itself is 0
        q.push(source);

        while (!q.empty()) {
            // Process the current BFS level layer by layer
            for (int levelSize = q.size(); levelSize > 0; --levelSize) {
                int u = q.front();
                q.pop();
                // Relax all neighbors of the current node
                for (int v : graph[u]) {
                    if (dist[v] > dist[u] + 1) {
                        dist[v] = dist[u] + 1;
                        q.push(v);
                    }
                }
            }
        }
        return dist;
    }

public:
    int specialNodes(int n, vector<vector<int>>& edges, int x, int y, int z) {
        this->numNodes = n;
        graph.assign(n, {});

        // Build the undirected graph from the edge list
        for (auto& edge : edges) {
            int u = edge[0], v = edge[1];
            graph[u].push_back(v);
            graph[v].push_back(u);
        }

        // Compute shortest distances from each of the three special sources
        vector<int> distFromX = bfs(x);
        vector<int> distFromY = bfs(y);
        vector<int> distFromZ = bfs(z);

        int ans = 0;
        for (int i = 0; i < n; ++i) {
            // Collect the three distances for node i
            array<long long, 3> sides = {
                static_cast<long long>(distFromX[i]),
                static_cast<long long>(distFromY[i]),
                static_cast<long long>(distFromZ[i])};

            // Sort so that the largest distance is treated as the hypotenuse
            sort(sides.begin(), sides.end());

            // Check whether the three distances satisfy the Pythagorean theorem
            if (sides[0] * sides[0] + sides[1] * sides[1] == sides[2] * sides[2]) {
                ++ans;
            }
        }
        return ans;
    }
};

/**
 * Counts the number of "special" nodes in an undirected graph.
 *
 * A node `i` is considered special when the three shortest-path distances
 * from `i` to the three given source nodes (x, y, z) can form a right
 * triangle. In other words, after sorting the three distances ascending as
 * [a, b, c], the node is special if a^2 + b^2 === c^2.
 *
 * @param n     Number of nodes (labeled 0 .. n-1).
 * @param edges List of undirected edges, each given as [u, v].
 * @param x     First source node.
 * @param y     Second source node.
 * @param z     Third source node.
 * @returns     The count of special nodes.
 */
function specialNodes(
    n: number,
    edges: number[][],
    x: number,
    y: number,
    z: number,
): number {
    // Build the adjacency list for the undirected graph.
    const graph: number[][] = Array.from({ length: n }, () => []);
    for (const [u, v] of edges) {
        graph[u].push(v);
        graph[v].push(u);
    }

    // A sentinel value representing "unreachable / infinite distance".
    const inf = 1e9;

    /**
     * Performs a breadth-first search from the given start node and
     * returns the shortest distances to every other node.
     *
     * @param start The source node to start the BFS from.
     * @returns     An array where index `i` holds the distance to node `i`.
     */
    const bfs = (start: number): number[] => {
        const dist: number[] = Array(n).fill(inf);
        let queue: number[] = [start];
        dist[start] = 0;

        // Process the graph layer by layer.
        while (queue.length) {
            const nextQueue: number[] = [];
            for (const u of queue) {
                for (const v of graph[u]) {
                    // Relax the edge if a shorter path is found.
                    if (dist[v] > dist[u] + 1) {
                        dist[v] = dist[u] + 1;
                        nextQueue.push(v);
                    }
                }
            }
            queue = nextQueue;
        }

        return dist;
    };

    // Compute shortest distances from each of the three source nodes.
    const distFromX = bfs(x);
    const distFromY = bfs(y);
    const distFromZ = bfs(z);

    let ans = 0;
    for (let i = 0; i < n; i++) {
        // Gather the three distances and sort them ascending.
        const triple = [distFromX[i], distFromY[i], distFromZ[i]];
        triple.sort((p, q) => p - q);

        // Check the Pythagorean condition: a^2 + b^2 === c^2.
        if (triple[0] * triple[0] + triple[1] * triple[1] === triple[2] * triple[2]) {
            ans++;
        }
    }

    return ans;
}

Time and Space Complexity

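Time Complexity

The time complexity of the code is O(n), where n is the number of nodes in the tree.

Breakdown:

  • Graph construction: we iterate over edges to build the adjacency list. Since the input is a tree, it has n - 1 edges, so building the graph takes O(n).

  • BFS phase: the bfs function traverses the whole graph from a starting node. Each node is enqueued at most once, and each edge is examined a constant number of times (an undirected edge is stored in both directions). A tree has n nodes and n - 1 edges, so a single BFS takes O(n). The code runs BFS three times, from x, y, and z, for a total of 3 * O(n) = O(n).

  • Counting phase: zip(d1, d2, d3) walks over each node's three distances and does a constant amount of comparison and arithmetic per node, so this phase takes O(n).

Combining these parts, the total time complexity is O(n).

Space Complexity

The space complexity of the code is O(n), where n is the number of nodes in the tree.

Breakdown:

  • Adjacency list g: stores every node and its neighbors; a tree has n - 1 edges, so it uses O(n) space.

  • Distance arrays: the dist array inside bfs has length n, and the three calls produce d1, d2, and d3, using 3 * O(n) = O(n) space in total.

  • Queue q: during BFS the queue holds at most O(n) nodes.

Combining these parts, the total space complexity is O(n).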


Common Pitfalls

Pitfall 1: Mishandling Triplets That Contain a Zero Distance

Zero distances come from the three target nodes x, y, and z themselves. When u equals one of the targets, one of its distances is 0. For u = x we have dx = 0, and the triple is (0, dy, dz).

After sorting, the check becomes:

low = 0,  mid = dy,  high = dz   (assuming dy <= dz)
0*0 + dy*dy == dz*dz  →  dy*dy == dz*dz  →  dy == dz

So a target node passes the check exactly when it is equidistant from the other two targets. Under the definition in the problem statement (three integers that, sorted ascending, satisfy a² + b² = c²), such a triple is valid, and it is precisely why node 4, with distances (3, 3, 0), is counted in the walkthrough. The mistake is applying a different definition from the one stated: the usual number-theory definition requires a Pythagorean triplet to consist of three positive integers, and under that stricter reading a "triangle" with a zero-length side would be rejected.

Why this matters: Only the three target nodes can have a zero distance, so the two readings differ by at most three nodes. That gap is easy to miss in testing, yet it changes the answer whenever a target is equidistant from the other two.

Solution: Follow the problem statement as written; the reference solutions count zero-distance triples. Only if a variant explicitly requires positive side lengths should you reject any triple containing a zero before the Pythagorean test:

ans = 0
for dx, dy, dz in zip(dist_x, dist_y, dist_z):
    total = dx + dy + dz
    low, high = min(dx, dy, dz), max(dx, dy, dz)
    mid = total - low - high
    # Reject degenerate triplets where any distance is 0.
    if low == 0:
        continue
    if low * low + mid * mid == high * high:
        ans += 1
return ans
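The two readings of the definition differ only at the target nodes, so it can help to make the choice explicit. A sketch, run on the distance arrays from the walkthrough; `count_special` and its `allow_zero` flag are hypothetical names, not part of the problem's interface.

```python
from typing import List


def count_special(d1: List[int], d2: List[int], d3: List[int], allow_zero: bool = True) -> int:
    """Count nodes whose distance triple is Pythagorean.

    allow_zero=True follows the problem statement as written (sorted a² + b² = c²);
    allow_zero=False models the stricter "positive integers only" reading.
    """
    ans = 0
    for triple in zip(d1, d2, d3):
        a, b, c = sorted(triple)
        if a == 0 and not allow_zero:
            continue  # a zero distance only occurs when the node is x, y or z itself
        if a * a + b * b == c * c:
            ans += 1
    return ans


# Distance arrays from the walkthrough (x = 0, y = 2, z = 4).
d1, d2, d3 = [0, 1, 2, 2, 3], [2, 1, 0, 2, 3], [3, 2, 3, 1, 0]
print(count_special(d1, d2, d3))                    # 1: node 4 has (0, 3, 3)
print(count_special(d1, d2, d3, allow_zero=False))  # 0
```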

Pitfall 2: Misordering the Pythagorean Equation (Wrong Hypotenuse)

A frequent mistake is writing the check using the raw, unsorted distances, e.g.:

if dx * dx + dy * dy == dz * dz:   # WRONG

The Pythagorean relation a² + b² = c² only holds with c as the largest value (the hypotenuse). A fixed-position check such as dx*dx + dy*dy == dz*dz silently assumes that dz is the largest distance, so it misses every special node whose largest distance is dx or dy. You must identify the maximum and treat it as the hypotenuse.

Solution: Always sort (or extract min/max) first, then square. The provided code does this correctly via:

total = dx + dy + dz
low, high = min(dx, dy, dz), max(dx, dy, dz)
mid = total - low - high
if low * low + mid * mid == high * high:
    ans += 1
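A concrete triple makes the failure mode visible. The function names below are hypothetical, and the triple (5, 3, 4) is chosen for illustration, not taken from any test case.

```python
def fixed_position_check(dx: int, dy: int, dz: int) -> bool:
    """WRONG: assumes dz is always the hypotenuse."""
    return dx * dx + dy * dy == dz * dz


def sorted_check(dx: int, dy: int, dz: int) -> bool:
    """Correct: whichever distance is largest plays the role of c."""
    low, high = min(dx, dy, dz), max(dx, dy, dz)
    mid = dx + dy + dz - low - high
    return low * low + mid * mid == high * high


# A node 5 edges from x, 3 from y and 4 from z is special (3² + 4² = 5²),
# but the fixed-position check misses it because the hypotenuse is dx.
print(fixed_position_check(5, 3, 4))  # False
print(sorted_check(5, 3, 4))          # True
```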

Pitfall 3: Resetting the BFS Distance Array on Every Layer (Performance / Correctness Confusion)

The inner for _ in range(len(queue)) loop is used here for layer-by-layer processing, but in this problem layered processing is unnecessary — a plain BFS that pops one node at a time gives identical distances. Some implementations mistakenly recompute len(queue) inside the loop or reinitialize dist per layer, breaking the traversal.

Solution: Either keep the layered loop exactly as written (correct), or simplify to a flat BFS:

def bfs(start: int) -> List[int]:
    queue = deque([start])
    dist = [inf] * n
    dist[start] = 0
    while queue:
        u = queue.popleft()
        for v in graph[u]:
            if dist[v] > dist[u] + 1:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist

Both produce correct shortest distances in a tree. If you keep the layered version, read the layer size once before popping (as range(len(queue)) does) and allocate dist once, before the loop starts.
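The equivalence of the two styles is easy to check empirically. A sketch that compares them on random trees; the function names `bfs_layered` and `bfs_flat` are hypothetical, and the random-tree construction (attach node i to a random earlier node) is an assumption of this test, not part of the original solution.

```python
import random
from collections import deque
from math import inf
from typing import List


def bfs_layered(graph: List[List[int]], start: int) -> list:
    queue, dist = deque([start]), [inf] * len(graph)
    dist[start] = 0
    while queue:
        for _ in range(len(queue)):  # len(queue) is evaluated once per layer
            u = queue.popleft()
            for v in graph[u]:
                if dist[v] > dist[u] + 1:
                    dist[v] = dist[u] + 1
                    queue.append(v)
    return dist


def bfs_flat(graph: List[List[int]], start: int) -> list:
    queue, dist = deque([start]), [inf] * len(graph)
    dist[start] = 0
    while queue:
        u = queue.popleft()
        for v in graph[u]:
            if dist[v] > dist[u] + 1:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist


rng = random.Random(1)
for _ in range(300):
    n = rng.randint(1, 15)
    graph = [[] for _ in range(n)]
    for i in range(1, n):  # attach node i to a random earlier node -> a random tree
        p = rng.randrange(i)
        graph[i].append(p)
        graph[p].append(i)
    start = rng.randrange(n)
    assert bfs_layered(graph, start) == bfs_flat(graph, start)
print("layered and flat BFS agree")
```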
