Delete Node in a BST
update Sep 1, 2017 17:11
Given a root node reference of a BST and a key, delete the node with the given key in the BST. Return the root node reference (possibly updated) of the BST.
Basically, the deletion can be divided into two stages:
- Search for a node to remove.
- If the node is found, delete the node.
Note: Time complexity should be O(height of tree).
Example:
root = [5,3,6,2,4,null,7]
key = 3
    5
   / \
  3   6
 / \   \
2   4   7
Given key to delete is 3. So we find the node with value 3 and delete it.
One valid answer is [5,4,6,2,null,null,7], shown in the following BST.
    5
   / \
  4   6
 /     \
2       7
Another valid answer is [5,2,6,null,4,null,7].
    5
   / \
  2   6
   \   \
    4   7
Basic Idea:
这里有一个关于在 BST 中删除节点的算法介绍;
我们知道,在BST中删除节点分为三种情况:
- target node 没有子节点: 此时可以直接删除
- target node 有一个子节点: 直接用子节点替换target节点
- target node 有两个子节点:这种情况最为复杂。方法是用其右子树中最小元素(successor)复制后代替target,然后递归删除之前复制过的successor;
具体实现的时候,因为没有parent指针,需要另外写一个函数 remove(parentNode, targetNode),返回替换掉target的新node,用来应付root被移除的情况;
Python Code (新写法,模块化):
class Solution:
    def deleteNode(self, root, key):
        """
        Delete the node holding ``key`` from the BST rooted at ``root``.

        Runs in O(height of tree): one root-to-node search, plus at most one
        walk down the target's right subtree to find its in-order successor.

        :type root: TreeNode
        :type key: int
        :rtype: TreeNode -- the (possibly new) root; ``root`` itself when
            ``key`` is absent, ``None`` when the tree becomes empty.
        """

        def findNode(key):
            # Return (parent, node) for the first node on the search path whose
            # value equals key. node is None when key is absent; parent is None
            # when node is the root.
            parent, node = None, root
            while node is not None and node.val != key:
                parent = node
                node = node.right if node.val < key else node.left
            return parent, node

        def findRightMin(node):
            # Return (parent, leftmost node) of node's right subtree, i.e. the
            # in-order successor of node. Caller guarantees node.right exists.
            parent, cur = node, node.right
            while cur.left is not None:
                parent, cur = cur, cur.left
            return parent, cur

        def replaceChild(parent, old, new):
            # Re-point parent's link to `old` at `new`. Identity (`is`) is used
            # so that nodes with equal values are never confused. A None parent
            # means `old` is the root; the caller handles that via the return value.
            if parent is None:
                return
            if parent.left is old:
                parent.left = new
            else:
                parent.right = new

        def remove(parent, target):
            # Unlink target (parent is None when target is the root) and return
            # the node that now occupies target's position.
            if target.left is not None and target.right is not None:
                # Two children: copy the successor's value into target, then
                # delete the successor node itself. It has no left child, so
                # the recursive call takes the simple branch below.
                minParent, minNode = findRightMin(target)
                target.val = minNode.val
                remove(minParent, minNode)
                return target
            # Zero or one child: the child (possibly None) moves up.
            child = target.left if target.left is not None else target.right
            replaceChild(parent, target, child)
            return child

        parent, target = findNode(key)
        if target is None:  # empty tree or key not present
            return root
        replacement = remove(parent, target)
        # If the root itself was removed, whatever replaced it is the new root.
        return replacement if parent is None else root
Java Code (最初的写法):
class Solution {
    /**
     * Deletes the node whose value equals {@code key} from the BST and returns
     * the (possibly new) root. Returns {@code root} unchanged when the key is absent.
     */
    public TreeNode deleteNode(TreeNode root, int key) {
        // Walk down to the key, remembering its parent. When the key sits at the
        // root, parent stays equal to root; remove() then rewires the discarded
        // old root, and the replacement it returns becomes the new root.
        TreeNode parent = root;
        TreeNode target = root;
        while (target != null && target.val != key) {
            parent = target;
            target = key < target.val ? target.left : target.right;
        }
        if (target == null) {
            return root;
        }
        TreeNode replacement = remove(parent, target);
        return target == root ? replacement : root;
    }

    /**
     * Unlinks target from parent and returns the node that now occupies
     * target's position: null for a leaf, the only child, or a fresh copy of
     * the in-order successor when target has two children.
     */
    private TreeNode remove(TreeNode parent, TreeNode target) {
        boolean hasTwoChildren = target.left != null && target.right != null;
        if (!hasTwoChildren) {
            // Leaf or single child: the child (possibly null) moves up.
            TreeNode child = target.left != null ? target.left : target.right;
            replaceChild(parent, target, child);
            return child;
        }

        // Two children: the in-order successor is the leftmost node of the right subtree.
        TreeNode succParent = target;
        TreeNode succ = target.right;
        for (; succ.left != null; succ = succ.left) {
            succParent = succ;
        }

        // Put a copy of the successor in target's place...
        TreeNode copy = new TreeNode(succ.val);
        copy.left = target.left;
        copy.right = target.right;
        replaceChild(parent, target, copy);

        // ...then delete the original successor. If its parent was target itself,
        // that parent has just been replaced by the copy.
        remove(succParent == target ? copy : succParent, succ);
        return copy;
    }

    /** Re-points whichever of parent's links references oldChild (left first) at newChild. */
    private static void replaceChild(TreeNode parent, TreeNode oldChild, TreeNode newChild) {
        if (parent.left == oldChild) parent.left = newChild;
        else parent.right = newChild;
    }
}
update Dec 25, 2017 20:28
Update
还是要建立分模块设计的思想。比如这道题,另外写两个函数 findRightMin() 和 findNode(),可以令主函数 remove() 更加易读,写的时候也方便。另外注意一点,当待删除节点左右都有孩子的时候,选择用 leftMax 或者 rightMin 替换都可以。
实现的时候,先考虑target node没有孩子的情况,再考虑有两个孩子的情况,再考虑只有一个孩子的情况,这样写逻辑最简单。另外写一个 remove(parent, node) 是为了处理待删除点是 root 的情况。
Python实现见前面。
Java 实现示意:
/**
* public class TreeNode {
* public int key;
* public TreeNode left;
* public TreeNode right;
* public TreeNode(int key) {
* this.key = key;
* }
* }
*/
public class Solution {
    /**
     * Deletes the node with the given key from the BST rooted at root and
     * returns the resulting root. The tree is returned unchanged when the key
     * is not found.
     */
    public TreeNode delete(TreeNode root, int key) {
        // A sentinel sits above the real root so the root has a parent like any other node.
        TreeNode sentinel = new TreeNode(Integer.MAX_VALUE);
        sentinel.right = root;

        // Locate the node to delete, tracking its parent.
        TreeNode parent = sentinel;
        TreeNode node = root;
        while (node != null && node.key != key) {
            parent = node;
            node = (node.key < key) ? node.right : node.left;
        }
        if (node == null) {
            return root;
        }

        // Work out what takes the deleted node's place.
        TreeNode replacement;
        if (node.left != null && node.right != null) {
            // Two children: a copy of the right subtree's minimum takes over, and
            // that minimum is removed from the right subtree recursively.
            TreeNode successor = node.right;
            while (successor.left != null) {
                successor = successor.left;
            }
            replacement = new TreeNode(successor.key);
            replacement.left = node.left;
            replacement.right = delete(node.right, replacement.key);
        } else {
            // No child (null) or exactly one child, which moves up.
            replacement = (node.left != null) ? node.left : node.right;
        }

        if (parent.left == node) parent.left = replacement;
        else parent.right = replacement;
        return sentinel.right;
    }
}
update 2018-06-10 14:16:39
C++ Solution
class Solution {
void removeNode(TreeNode* root, int key) {
TreeNode* parent = root;
TreeNode* curr = root->left; // real root
bool left = true;
while (curr && curr->val != key) {
if (curr->val < key) {
parent = curr;
left = false;
curr = curr->right;
} else if (curr->val > key) {
parent = curr;
left = true;
curr = curr->left;
}
}
if (! curr) return;
if (curr->left == nullptr && curr->right == nullptr) {
if (left) parent->left = nullptr;
else parent->right = nullptr;
} else if (curr->left == nullptr || curr->right == nullptr) {
if (curr->left == nullptr) {
if (left) parent->left = curr->right;
else parent->right = curr->right;
} else {
if (left) parent->left = curr->left;
else parent->right = curr->left;
}
} else {
TreeNode* rightMin = findMin(curr->right);
TreeNode* newCurr = new TreeNode(rightMin->val);
removeNode(root, rightMin->val);
newCurr->left = curr->left;
newCurr->right = curr->right;
if (left) parent->left = newCurr;
else parent->right = newCurr;
}
}
TreeNode* findMin(TreeNode* node) {
while (node->left) node = node->left;
return node;
}
public:
TreeNode* deleteNode(TreeNode* root, int key) {
TreeNode* dummy = new TreeNode(0);
dummy->left = root;
removeNode(dummy, key);
return dummy->left;
}
};