LeetCode 230: Kth Smallest Element in a BST

Problem link: https://leetcode.com/problems/kth-smallest-element-in-a-bst/

Recursive

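An inorder traversal of a BST visits the nodes in ascending order of value, so the value of the k-th node visited inorder is our answer. We can stop the traversal as soon as that node is reached.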

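Time: \mathcal{O}(\texttt{tree-height} + k), space: \mathcal{O}(\texttt{tree-height}). The traversal first descends to the leftmost (smallest) node and then visits k nodes.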

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        """Return the k-th smallest value (1-indexed) in the BST rooted at ``root``.

        An inorder traversal of a BST visits nodes in ascending order of value,
        so the k-th node visited is the answer. The traversal stops as soon as
        that node has been found.

        Args:
            root: Root of the binary search tree; may be None (empty tree).
            k: 1-based rank of the value to return.

        Returns:
            The value stored in the k-th smallest node.

        Raises:
            ValueError: If ``k < 1`` or ``k`` exceeds the number of nodes.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")

        remaining = k  # how many more nodes must be visited to reach the answer

        # Recursion depth equals the tree height; a very skewed tree can
        # exceed Python's default recursion limit.
        def inorder(node) -> Optional[TreeNode]:
            """Return the target node if it lies in ``node``'s subtree, else None."""
            nonlocal remaining
            if node is None:
                return None
            found = inorder(node.left)
            if found is not None:
                return found
            remaining -= 1
            if remaining == 0:
                return node
            return inorder(node.right)

        kth = inorder(root)
        if kth is None:
            # The whole tree was visited without reaching rank k.
            raise ValueError(f"k={k} exceeds the number of nodes in the tree")
        return kth.val

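Follow-up: If we run the k-th smallest query often, we can keep track of the size (number of nodes) of each subtree. Then, to find the k-th smallest, we only need to walk the single path from the root to the k-th smallest node. At each node, let L be the size of its left subtree. If k = L + 1, this node is the answer. If k \le L, we continue into the left subtree. Otherwise, we continue into the right subtree with k reduced by L + 1. (The code below recomputes the sizes on every call; a real augmented tree would keep them up to date across insertions and deletions.)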

Initial computation of all subtree sizes, time: \mathcal{O}(n), space: \mathcal{O}(n).

Updating subtree sizes after an insertion or deletion (only nodes on the affected root-to-leaf path change), time: \mathcal{O}(\texttt{tree-height}), space: \mathcal{O}(\texttt{tree-height}).

Per query time: \mathcal{O}(\texttt{tree-height}), space: \mathcal{O}(\texttt{tree-height}).

# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def count_nodes(self, root, node_count) -> int:
        """Record the size of every subtree under ``root`` in ``node_count``.

        Args:
            root: Root of the (sub)tree to measure; may be None.
            node_count: Dict updated in place; after the call it maps every
                node in the tree to the number of nodes in its subtree.

        Returns:
            The number of nodes in the tree rooted at ``root`` (0 for None).
        """
        if root is None:
            return 0

        # Iterative post-order traversal: a node's size is computed only after
        # both children's sizes are known. The explicit stack avoids hitting
        # Python's recursion limit on deep (skewed) trees.
        stack = [(root, False)]
        while stack:
            node, children_done = stack.pop()
            if children_done:
                left = node_count[node.left] if node.left is not None else 0
                right = node_count[node.right] if node.right is not None else 0
                node_count[node] = left + 1 + right
                continue
            # Re-push the node so it is finalized after its children.
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))

        return node_count[root]

    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        """Return the k-th smallest value (1-indexed) using subtree sizes.

        All subtree sizes are computed first. Then a single root-to-target
        walk finds the answer: at each node, the size of its left subtree
        tells whether the target is this node, lies in the left subtree, or
        lies in the right subtree (with k reduced by the nodes skipped).

        Args:
            root: Root of the binary search tree; may be None (empty tree).
            k: 1-based rank of the value to return.

        Returns:
            The value stored in the k-th smallest node.

        Raises:
            ValueError: If ``k < 1`` or ``k`` exceeds the number of nodes.
        """
        node_count = {}
        total = self.count_nodes(root, node_count)
        if not 1 <= k <= total:
            raise ValueError(f"k must be between 1 and {total}, got {k}")

        # Invariant: 1 <= k <= size of the subtree rooted at `node`, so the
        # walk always ends on a real node.
        node = root
        while True:
            left_count = node_count.get(node.left, 0)
            if k == left_count + 1:
                return node.val
            if k <= left_count:
                node = node.left
            else:
                k -= left_count + 1
                node = node.right

Leave a comment