LeetCode 230. Kth Smallest Element in a BST

Problem Explanation and Solution Approach

This problem applies the properties and traversals of a Binary Search Tree (BST). A BST has one key property: for each node, every value in its left subtree is less than the node's value, and every value in its right subtree is greater than it.

With this property in mind, one straightforward solution is an in-order traversal (left subtree, then the node, then the right subtree), which visits the values in increasing order. The kth smallest value is then the element at index k-1 of the resulting list, since list indices start at 0.
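A minimal Python sketch of this naïve approach is shown below. The `TreeNode` class mirrors the node shape LeetCode supplies (`val`, `left`, `right`) and is defined here only so the snippet runs on its own. `inorder_values` and `kth_smallest_naive` are hypothetical helper names, and the example tree is assumed to be LeetCode's Example 2, `[5,3,6,2,4,null,null,1]`.

```python
from typing import List, Optional


class TreeNode:
    """Minimal binary-tree node with the same fields LeetCode uses (val, left, right)."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def inorder_values(root: Optional[TreeNode]) -> List[int]:
    """Collect values as left subtree -> node -> right subtree, i.e. in ascending order."""
    out: List[int] = []

    def visit(node: Optional[TreeNode]) -> None:
        if node is None:
            return
        visit(node.left)
        out.append(node.val)
        visit(node.right)

    visit(root)
    return out


def kth_smallest_naive(root: Optional[TreeNode], k: int) -> int:
    """k is 1-based, so the kth smallest value sits at index k - 1."""
    return inorder_values(root)[k - 1]


example = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
print(inorder_values(example))          # [1, 2, 3, 4, 5, 6]
print(kth_smallest_naive(example, 3))   # 3
```

Appending to a shared list keeps the traversal linear. Building the list by concatenating sublists at every level would copy values repeatedly.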

It's important to note that this naïve approach costs O(N) time and O(N) extra space for the list on every query, which adds up if the operation is performed often. Instead of building the list, we can search using subtree sizes. Let leftCount be the number of nodes in the root's left subtree. If leftCount equals k-1, the kth smallest node is the root itself. If k is less than or equal to leftCount, it lies in the left subtree. Otherwise it lies in the right subtree, where we look for the (k-leftCount-1)th smallest node, because the root and its whole left subtree have been ruled out. As written, this version recounts subtrees on every call, so it is not asymptotically faster than the in-order traversal. Repeated queries become cheap only if each node also stores the size of its subtree.
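The counting rule pays off for frequent queries once subtree sizes are stored instead of recomputed. The sketch below is our own assumption and not part of the original solution. The hypothetical `SizedNode` class caches the `size` of its subtree. `insert` keeps those sizes current, assuming values are distinct. `kth_smallest` applies the same three-way comparison in O(H) time, where H is the tree height.

```python
from typing import Optional


class SizedNode:
    """Hypothetical BST node that also caches the size of its subtree."""

    def __init__(self, val: int) -> None:
        self.val = val
        self.left: Optional["SizedNode"] = None
        self.right: Optional["SizedNode"] = None
        self.size = 1  # number of nodes in the subtree rooted here


def size(node: Optional[SizedNode]) -> int:
    return node.size if node is not None else 0


def insert(root: Optional[SizedNode], val: int) -> SizedNode:
    """Insert val (assumed not already present) and update size on the path."""
    if root is None:
        return SizedNode(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    root.size += 1
    return root


def kth_smallest(root: Optional[SizedNode], k: int) -> int:
    """Same three-way rule as the counting approach, but leftCount is read in O(1)."""
    node = root
    while node is not None:
        left_count = size(node.left)
        if left_count == k - 1:
            return node.val          # exactly k-1 smaller values
        if k <= left_count:
            node = node.left         # answer is in the left subtree
        else:
            k -= left_count + 1      # skip the left subtree and this node
            node = node.right
    raise ValueError("k is larger than the number of nodes")


root: Optional[SizedNode] = None
for v in [5, 3, 6, 2, 4, 1]:
    root = insert(root, v)
print(kth_smallest(root, 3))  # 3
```

Deletions would need the same size bookkeeping. Keeping H small also requires a self-balancing tree. This sketch covers neither.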

Example

Let's consider the following BST (LeetCode's Example 2, root = [5,3,6,2,4,null,null,1]):

        5
       / \
      3   6
     / \
    2   4
   /
  1

Suppose we want the 3rd smallest element (k = 3), which is 3 in this case:

  • At the root 5, the left subtree holds 1, 2, 3 and 4, so leftCount = 4. Since k = 3 ≤ leftCount, the answer is in the left subtree, and we recurse into node 3 with k unchanged.
  • At node 3, the left subtree holds 1 and 2, so leftCount = 2. Since leftCount = k - 1, node 3 itself is the answer.

So, the 3rd smallest element is 3.
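The walkthrough can be reproduced with a small tracing sketch. It uses the same count-then-descend logic as the solutions in the next section, rewritten as a loop so that each step can be recorded. `TreeNode`, `count_nodes`, and `kth_smallest_traced` are illustrative names. The tree is assumed to be LeetCode's Example 2, `[5,3,6,2,4,null,null,1]`.

```python
from typing import List, Optional


class TreeNode:
    """Minimal binary-tree node, assumed to match LeetCode's definition."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def count_nodes(node: Optional[TreeNode]) -> int:
    """Number of nodes in the subtree rooted at node."""
    if node is None:
        return 0
    return 1 + count_nodes(node.left) + count_nodes(node.right)


def kth_smallest_traced(root: Optional[TreeNode], k: int, trace: List[str]) -> int:
    """Count-then-descend search that logs leftCount and k at every visited node."""
    node = root
    while node is not None:
        left_count = count_nodes(node.left)
        trace.append(f"at {node.val}: leftCount={left_count}, k={k}")
        if left_count == k - 1:
            return node.val          # exactly k-1 smaller values
        if k <= left_count:
            node = node.left         # answer is in the left subtree
        else:
            k -= left_count + 1      # skip the left subtree and this node
            node = node.right
    raise ValueError("k is larger than the number of nodes")


example = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
steps: List[str] = []
print(kth_smallest_traced(example, 3, steps))  # 3
print("\n".join(steps))
# at 5: leftCount=4, k=3
# at 3: leftCount=2, k=3
```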

Solution

Python

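```python
# TreeNode (fields: val, left, right) is provided by the LeetCode judge.
class Solution:
    def kthSmallest(self, root: TreeNode, k: int) -> int:
        def countNodes(node: TreeNode) -> int:
            # Number of nodes in the subtree rooted at `node`.
            if not node:
                return 0
            return 1 + countNodes(node.left) + countNodes(node.right)

        leftCount = countNodes(root.left)
        if leftCount == k - 1:
            # Exactly k-1 smaller values: the root is the kth smallest.
            return root.val
        elif leftCount >= k:
            # The kth smallest lies in the left subtree.
            return self.kthSmallest(root.left, k)
        else:
            # Skip the left subtree and the root, then search the right subtree.
            return self.kthSmallest(root.right, k - 1 - leftCount)
```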

Java

```java
public class Solution {
    public int kthSmallest(TreeNode root, int k) {
        int leftCount = countNodes(root.left);
        if (leftCount == k - 1) {
            return root.val;
        } else if (leftCount >= k) {
            return kthSmallest(root.left, k);
        } else {
            return kthSmallest(root.right, k - 1 - leftCount);
        }
    }

    private int countNodes(TreeNode node) {
        if (node == null) return 0;
        return 1 + countNodes(node.left) + countNodes(node.right);
    }
}
```

JavaScript

```javascript
class Solution {
  kthSmallest(root, k) {
    function countNodes(node) {
      if (!node) return 0;
      return 1 + countNodes(node.left) + countNodes(node.right);
    }

    const leftCount = countNodes(root.left);
    if (leftCount === k - 1) {
      return root.val;
    } else if (leftCount >= k) {
      return this.kthSmallest(root.left, k);
    } else {
      return this.kthSmallest(root.right, k - 1 - leftCount);
    }
  }
}
```

C++

```cpp
class Solution {
public:
    static int countNodes(TreeNode* root) {
        return root ? 1 + countNodes(root->left) + countNodes(root->right) : 0;
    }

    int kthSmallest(TreeNode* root, int k) {
        int leftCount = countNodes(root->left);
        if (leftCount == k - 1) return root->val;
        else if (leftCount >= k) return kthSmallest(root->left, k);
        else return kthSmallest(root->right, k - 1 - leftCount);
    }
};
```

C#

```csharp
public class Solution {
    public int KthSmallest(TreeNode root, int k) {
        int leftCount = CountNodes(root.left);
        if (leftCount == k - 1) {
            return root.val;
        } else if (leftCount >= k) {
            return KthSmallest(root.left, k);
        } else {
            return KthSmallest(root.right, k - 1 - leftCount);
        }
    }

    private int CountNodes(TreeNode node) {
        if (node == null) return 0;
        return 1 + CountNodes(node.left) + CountNodes(node.right);
    }
}
```

This solution works by first counting the nodes in the left subtree. Then based on the value of this count, we determine if the kth smallest node should be in the left subtree, the root itself, or the right subtree:

  • If leftCount equals k-1, exactly k-1 nodes are smaller than the root, so the root is the kth smallest node.
  • If k is less than or equal to leftCount, the kth smallest node lies in the left subtree, so we recursively search that subtree with the same k.
  • Otherwise (leftCount < k-1), the kth smallest node lies in the right subtree. Before recursing into it, we subtract leftCount and 1 (for the root) from k.

The helper countNodes returns the number of nodes in a given subtree. It recursively counts the left and right subtrees and adds 1 for the current node.

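The running time depends on the shape of the tree. On a balanced tree, each call counts a left subtree roughly half the size of the previous one, so the total work is about N/2 + N/4 + ... < N, which is O(N). On a skewed tree, however, countNodes revisits the same nodes at every level of the descent. The worst case is O(N^2), for example a tree where every node has only a left child and k = 1. The extra space is O(H) for the recursion stack, where H is the tree height, and H reaches O(N) when the tree degenerates into a chain.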
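The cost depends heavily on tree shape, because countNodes revisits nodes at every level of the descent. The sketch below instruments the same count-then-descend recursion and tallies how many nodes countNodes touches. `CountingSolver`, `balanced`, and `left_chain` are hypothetical helpers built for this comparison. It runs k = 1 on two 127-node BSTs: one perfectly balanced and one where every node has only a left child.

```python
from typing import Optional


class TreeNode:
    """Minimal binary-tree node, assumed to match LeetCode's definition."""

    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


class CountingSolver:
    """Runs the count-then-descend algorithm and tallies nodes visited by count_nodes."""

    def __init__(self) -> None:
        self.visits = 0

    def count_nodes(self, node: Optional[TreeNode]) -> int:
        if node is None:
            return 0
        self.visits += 1
        return 1 + self.count_nodes(node.left) + self.count_nodes(node.right)

    def kth_smallest(self, root: TreeNode, k: int) -> int:
        left_count = self.count_nodes(root.left)
        if left_count == k - 1:
            return root.val
        if k <= left_count:
            return self.kth_smallest(root.left, k)
        return self.kth_smallest(root.right, k - 1 - left_count)


def balanced(lo: int, hi: int) -> Optional[TreeNode]:
    """Perfectly balanced BST holding lo..hi (when hi - lo + 1 is 2^h - 1)."""
    if lo > hi:
        return None
    mid = (lo + hi) // 2
    return TreeNode(mid, balanced(lo, mid - 1), balanced(mid + 1, hi))


def left_chain(n: int) -> TreeNode:
    """BST n -> n-1 -> ... -> 1 in which every node has only a left child."""
    root = TreeNode(1)
    for v in range(2, n + 1):
        root = TreeNode(v, left=root)
    return root


results = {}
for name, tree in [("balanced", balanced(1, 127)), ("left chain", left_chain(127))]:
    solver = CountingSolver()
    answer = solver.kth_smallest(tree, 1)
    results[name] = (answer, solver.visits)
    print(name, answer, solver.visits)
# balanced 1 120
# left chain 1 8001
```

On the balanced tree the visits sum to 63 + 31 + 15 + 7 + 3 + 1 = 120, which is below N. On the chain they sum to 126 + 125 + ... + 1 = 8001, which is the quadratic case.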

Hence, using this technique, we can effectively find the kth smallest element in a Binary Search Tree across different programming languages.

