Date and Time: Jun 8, 2024, 6:16 PM (EST)
Link: https://leetcode.com/problems/kth-smallest-element-in-a-bst/
Given the root of a binary search tree, and an integer k, return the kth smallest value (1-indexed) of all the values of the nodes in the tree.
Example 1:
Input: root = [3, 1, 4, null, 2], k = 1
Output: 1
Example 2:
Input: root = [5, 3, 6, 2, 4, null, null, 1], k = 3
Output: 3
By the BST property, an in-order DFS traversal visits the nodes in ascending order. So we append each value to res during the traversal, and the element at index k-1 of res is the kth smallest value.
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        # DFS then store in an array, by BST property, the array is sorted
        res = []

        def dfs(node):
            if node is None:
                return
            dfs(node.left)
            res.append(node.val)
            dfs(node.right)

        dfs(root)
        return res[k-1]
Time Complexity: O(n), every node is visited once.
Space Complexity: O(n), res stores all n values (plus the recursion stack).
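To check the in-order idea against the two examples outside of LeetCode, here is a minimal sketch. The TreeNode class mirrors the LeetCode definition; build_tree and inorder_values are hypothetical helper names introduced only for this illustration, and build_tree assumes the input is a LeetCode-style level-order list.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: build a tree from a level-order list with None gaps."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are this node's left and right children.
        if i < len(values) and values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def inorder_values(root: Optional[TreeNode]) -> List[int]:
    """Collect node values with a recursive in-order traversal."""
    res = []

    def dfs(node):
        if node is None:
            return
        dfs(node.left)
        res.append(node.val)
        dfs(node.right)

    dfs(root)
    return res


example1 = inorder_values(build_tree([3, 1, 4, None, 2]))
example2 = inorder_values(build_tree([5, 3, 6, 2, 4, None, None, 1]))
print(example1, example1[1 - 1])  # sorted values, then the answer for k = 1
print(example2, example2[3 - 1])  # sorted values, then the answer for k = 3
```

Both lists come out sorted, so indexing at k-1 gives the expected outputs 1 and 3.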
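Checking whether count == k means we no longer need to store every value, which reduces the space complexity a little. In Example 1, after we find 1 we return to 3 without visiting 2. Note that the ancestor calls still run to completion, so the remaining nodes are visited even though res is already set. Remember to return res at the end of kthSmallest: the return statement inside dfs() only exits the current recursive call and does not pass the value back to the caller. res stores a single value instead of a list, which is where the space saving comes from.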
# Definition for a binary tree node.
# class TreeNode:
#     def __init__(self, val=0, left=None, right=None):
#         self.val = val
#         self.left = left
#         self.right = right
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        count, res = 0, None

        def dfs(node):
            nonlocal count, res
            if not node:
                return
            dfs(node.left)
            count += 1
            if count == k:
                res = node.val
                return node.val
            dfs(node.right)

        dfs(root)
        return res
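Time Complexity: O(n), the traversal can still visit every node.
Space Complexity: O(h), where h is the height of the tree (recursion stack only); O(n) in the worst case of a skewed tree.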
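Because a plain return inside dfs() only unwinds one call, the recursion above keeps going after the kth value is found. One possible variant, which is not part of the original solution, has dfs return a boolean so that every ancestor call can stop as well. The function name kth_smallest_early_stop and the visited counter are assumptions added here purely to make the effect observable.

```python
from typing import Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_early_stop(root: Optional[TreeNode], k: int) -> Tuple[Optional[int], int]:
    """Return (kth smallest value, number of nodes whose value was examined)."""
    count, res, visited = 0, None, 0

    def dfs(node) -> bool:
        nonlocal count, res, visited
        if not node:
            return False
        # If the answer was found in the left subtree, stop immediately.
        if dfs(node.left):
            return True
        visited += 1
        count += 1
        if count == k:
            res = node.val
            return True
        # Otherwise the outcome depends on the right subtree.
        return dfs(node.right)

    dfs(root)
    return res, visited


# Example 1 tree: root = [3, 1, 4, null, 2]
root = TreeNode(3, TreeNode(1, None, TreeNode(2)), TreeNode(4))
print(kth_smallest_early_stop(root, 1))
print(kth_smallest_early_stop(root, 3))
```

With this sketch, only the first k nodes in sorted order have their values examined, so the work is roughly O(h + k) rather than O(n).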
Perform the DFS iteratively with a stack and a while loop. First push the current node and all of its left descendants onto stack, then pop from stack and increment n to compare it with k. Lastly, move curr to the popped node's right child, which may be None, and repeat.
class Solution:
    def kthSmallest(self, root: Optional[TreeNode], k: int) -> int:
        curr = root
        stack = []
        n = 0
        while stack or curr:
            while curr:
                stack.append(curr)
                curr = curr.left
            curr = stack.pop()
            n += 1
            if n == k:
                return curr.val
            curr = curr.right
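To see how the stack evolves, the sketch below runs the same loop on the Example 2 tree and records, after each pop, the popped value and what is left on the stack. The function name kth_smallest_trace and the trace list are assumptions introduced for illustration only.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def kth_smallest_trace(root: Optional[TreeNode], k: int) -> Tuple[Optional[int], List[Tuple[int, List[int]]]]:
    """Iterative in-order traversal that also records (popped value, remaining stack)."""
    curr, stack, n = root, [], 0
    trace = []
    while stack or curr:
        # Go as far left as possible, remembering the path on the stack.
        while curr:
            stack.append(curr)
            curr = curr.left
        curr = stack.pop()
        n += 1
        trace.append((curr.val, [node.val for node in stack]))
        if n == k:
            return curr.val, trace
        # Continue with the right subtree of the popped node (may be None).
        curr = curr.right
    return None, trace


# Example 2 tree: root = [5, 3, 6, 2, 4, null, null, 1]
root = TreeNode(5, TreeNode(3, TreeNode(2, TreeNode(1)), TreeNode(4)), TreeNode(6))
answer, trace = kth_smallest_trace(root, 3)
print(answer)
for popped, remaining in trace:
    print(popped, remaining)
```

For k = 3 the loop pops 1, 2 and 3 in that order and returns before 4, 5 or 6 are ever popped.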