数据结构之手写 TreeMap
本篇文章将逐步带你手写一个 TreeMap
,使用Java
语言实现。
1. 迭代改递归练习
首先我们进行迭代改递归练习,帮助我们理解递归。
打印数组中所有元素
使用迭代实现
public static void printArr(int[] arr){
for(int i = 0; i < arr.length; i++){
System.out.println(arr[i]);
}
}
使用递归实现:(注意递归要有终止条件)
public static void printArr(int[] arr){
//传入初始索引
printArr(arr, 0);
}
//定义:打印arr[i] 以及之后的所有元素
private static void printArr(int[] arr, int i) {
//base case 递归结束条件
if(i == arr.length){
return;
}
//打印arr[i]
System.out.println(arr[i]);
//i++
//然后打印arr[i + 1]
printArr(arr, i + 1);
}
倒序打印所有元素?
- 根据
printArr(int[] arr, int i)
定义,我们只需要调换相关语句顺序即可:
//i++
//然后打印arr[i + 1]
printArr(arr, i + 1);
//打印arr[i]
System.out.println(arr[i]);
- 改变
base case
和初始传入值:
//传入初始索引
printArr(arr, arr.length - 1);
_________________
//base case 递归结束条件
if(i < 0){
return;
}
//打印arr[i]
System.out.println(arr[i]);
//i--
//然后打印arr[i - 1]
printArr(arr, i - 1);
搜索数组中指定值返回其索引,如果找不到返回-1
使用迭代实现:
public static int search(int[] arr, int target) {
for(int i = 1; i < arr.length; i++){
if(arr[i] == target){
return i;
}
}
return -1;
}
使用递归实现:
public static int search(int[] arr, int target) {
return search(arr, target, 0);
}
//定义:在arr[i]以及arr[i]之后寻找target
private static int search(int[] arr, int target, int i) {
//base case
if(i == arr.length){
return -1;
}
//base case
if(arr[i] == target){
return i;
}
//当前找不到,那么在i+1之后寻找
return search(arr, target, i + 1);
}
2. 手写单链表 RecursiveList
之后我们需要手写一个单向链表使用纯递归实现,来为后续实现 TreeMap
打下基础。
首先看初始代码
/**
* 单链表递归实现
*/
public class RecursiveList<E> {
// 单链表链表节点
private static class Node<E> {
E val;
//指向下一个节点
Node<E> next;
Node(E val) {
this.val = val;
}
}
//头节点
private Node<E> first = null;
//长度
private int size = 0;
public RecursiveList() {
}
/***** 增 *****/
//在头部插入节点
public void addFirst(E e) {
}
public void addLast(E e) {
}
public void add(int index, E e) {
}
/***** 删 *****/
public E removeFirst() {
}
public void removeLast() {
}
public void remove(int index) {
}
/***** 查 *****/
public E get(int index){
}
public E getFirst() {
}
public E getLast() {
}
/***** 改 *****/
public E set(int index, E element) {
}
/***** 其他工具函数 *****/
public int size() {
return size;
}
public boolean isEmpty() {
return size == 0;
}
private boolean isElementIndex(int index) {
return index >= 0 && index < size;
}
private boolean isPositionIndex(int index) {
return index >= 0 && index <= size;
}
/**
* 检查 index 索引位置是否可以存在元素
*/
private void checkElementIndex(int index) {
if (!isElementIndex(index))
throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size);
}
/**
* 检查 index 索引位置是否可以添加元素
*/
private void checkPositionIndex(int index) {
if (!isPositionIndex(index))
throw new IndexOutOfBoundsException("Index: " + index + ", Size: " + size);
}
//返回 index 对应的 Node
//注意:请保证传入的 index 是合法的
private Node<E> getNode(int index){
Node<E> p = first;
for(int i = 0; i < index; i++){
p = p.next;
}
return p;
}
}
我们首先实现
xxxFirst
相关的方法,因为这些方法不涉及for
循环。
addFirst
在链表头部添加元素
public void addFirst(E e) {
//首先做出这个节点
Node<E> x = new Node<>(e);
x.next = first;
first = x;
size++;
}
removeFirst
删除头部节点
public E removeFirst() {
//首先判断链表是否为空
if(isEmpty()){
throw new NoSuchElementException();
}
E deleteVal = first.val;
// 删除
first = first.next;
return deleteVal;
}
getFirst
获取头部节点
public E getFirst() {
if(isEmpty()){
throw new NoSuchElementException();
}
return first.val;
}
然后我们实现
getNode
方法,根据索引获取对应的 Node 节点
起初我们是通过 for
循环实现的
//返回 index 对应的 Node
//注意:请保证传入的 index 是合法的
private Node<E> getNode(int index){
Node<E> p = first;
for(int i = 0; i < index; i++){
p = p.next;
}
return p;
}
那么如何通过递归实现:我们可以仿照我们最初的 search
函数迭代改递归:
// 返回「从 node 开始的第 index 个链表节点」
private Node<E> getNode(Node<E> node, int index) {
// base case
if (index == 0) {
return node;
}
// 返回 「从 node.next 开始的第 index - 1 个链表节点」
return getNode(node.next, index - 1);
}
那么 getNode(int index)
如何写?
private Node<E> getNode(int index){
//返回从first 开始的第index个节点
return getNode(first, index);
}
基于
getNode
我们就可以实现get
和set
方法了
public E get(int index){
//检查索引是否越界
checkElementIndex(index);
Node<E> node = getNode(index);
return node.val;
}
public E set(int index, E element) {
checkElementIndex(index);
Node<E> p = getNode(index);
E oldVal = p.val;
p.val = element;
return oldVal;
}
然后我们实现
getLast
方法
我们可以通过 getNode(size -1)
来获取最后一个节点
public E getLast() {
if (isEmpty()) {
throw new NoSuchElementException();
}
return getNode(size -1).val;
}
那么我们能不能通过递归来实现呢?
public E getLast() {
if (isEmpty()) {
throw new NoSuchElementException();
}
return getLast(first);
}
// 返回 node 之后最后一个节点
private E getLast(Node<E> node){
// base case
if(node.next == null){
//如果该节点的下一个节点为空,则其为尾节点
return node.val;
}
//不满足base case 继续寻找下一个节点
return getLast(node.next);
}
修改链表结构
removeLast
remove
addLast
add
- 实现
removeLast
凡是修改链表结构的,都需要有一个返回值:
public void removeLast() {
if (isEmpty()) {
throw new NoSuchElementException();
}
first = removeLast(first);
size--;
}
// x -> y -> null
private Node<E> removeLast(Node<E> node) {
//base case
if (node.next == null) {
// node 就是最后一个节点,让自己直接消失
return null;
}
node.next = removeLast(node.next);
return node;
}
为什么这样写?
我们从后往前看,也就是找到最后一个节点时,我们返回 null
,然后被其上一个节点的 next
指针接收,一直返回,最后就只删除了最后一个节点。
2. 实现
remove(int index)
首先我们在 remove(int index)
,书写函数 remove(first, index)
,该函数表明删除 first
节点之后的第 index
个节点。
public void remove(int index) {
checkElementIndex(index);
first = remove(first, index);
size--;
}
base case
结束条件,也就是 index == 0
时找到了这个节点。那么我们应该返回什么呢?
- 我们应该返回
node.next
,也就是返回当前节点之后的节点,与当前节点之前的节点连接起来。 - 然后
node.next = remove(node.next, index - 1)
接收返回的值。
//删除node之后的第 index 个节点
private Node<E> remove(Node<E> node, int index) {
//base case
if (index == 0) {
return node.next;
}
node.next = remove(node.next, index - 1);
return node;
}
- 实现
addLast
public void addLast(E e) {
first = addLast(first, e);
size++;
}
我们需要在最后插入一个节点,也就是当 node == null
时,返回新添加的节点,然后进行接收即可。
// a -> b -> c -> d -> e -> null
private Node<E> addLast(Node<E> node, E e) {
//说明指向最后一个节点的next
if (node == null) {
return new Node<>(e);
}
//返回被接受,连接
node.next = addLast(node.next, e);
return node;
}
- 实现
add
,在指定索引处插入节点
public void add(int index, E e) {
checkPositionIndex(index);
if (index == size) {
addLast(e);
return;
}
first = add(first, index, e);
size++;
}
当 index == 0
时,说明找到了这个插入点,然后插入节点的next
指针需要指向当前插入节点(也就是连接插入点之后的节点),然后进行返回。与插入点前的节点连接。
// a -> b -> c -> null
private Node<E> add(Node<E> node, int index, E e) {
if (index == 0) {
//假设在b处插入节点
Node<E> x = new Node<>(e);
//需要把b节点之后的链表连接到x.next上
x.next = node;
//然后返回这段链表
return x;
}
node.next = add(node.next, index - 1, e);
return node;
}
3. 实现TreeMap
3.1. TreeMap原理及特性
TreeMap
底层是基于二叉查找树BST,有一个特点,根节点的值比左子树要大,比右子树要小。
为什么要使用BST?
- 因为BST左小右大的特性,我们可以进行二分搜索。比如我们搜索2,先从根节点开始查找,2比6小,那么搜索左子树,否则搜索左子树,直到找到该节点。
- 通常来说BST的查询效率很高,如果所有的节点都接到左子树上,那么就退化成单链表(不是自平衡的),那么搜索的效率就是
O(n)
了。 - Java的TreeMap底层使用的是红黑树,也就是自平衡的二叉搜索树。
我们基于什么来实现?
- 我们基于
BST
实现,不考虑退化的情况,也就是不使用红黑树。
3.2. 初次实现
首先看初始代码
为什么 K extends Comparable<K> ?
- 我们是根据
Key
的大小,将key
存入到左右子树当中的。所以key
必须是可比较的。
public class MyTreeMap<K extends Comparable<K>, V> {
private class TreeNode {
K key;
V val;
TreeNode left, right;
TreeNode(K key, V val) {
this.key = key;
this.val = val;
this.size = 1;
left = right = null;
}
}
private TreeNode root = null;
public MyTreeMap() {
}
/***** 增/改 *****/
// 添加 key -> val 键值对,如果键 key 已存在,则将值修改为 val
public V put(K key, V val) {
}
/***** 删 *****/
// 删除 key 并返回对应的 val
public V remove(K key) {
}
// 删除并返回 BST 中最小的那个 key
public void removeMin() {
}
// 删除并返回 BST 中最大的那个 key
public void removeMax() {
}
/***** 查 *****/
// 返回 key 对应的 val,如果 key 不存在,则返回 null
public V get(K key) {
}
// 返回小于等于 key 的最大的键
public K floorKey(K key) {
}
// 返回大于等于 key 的最小的键
public K ceilingKey(K key) {
}
// 返回小于 key 的键的个数
public int rank(K key) {
}
// 返回索引为 i 的键,i 从 0 开始计算
public K select(int i) {
}
// 返回 BST 中最大的键
public K maxKey() {
}
// 返回 BST 中最小的键
public K minKey() {
}
// 判断 key 是否存在 Map 中
public boolean containsKey(K key) {
}
/***** 工具函数 *****/
public boolean isEmpty() {
return size == 0;
}
}
我们首先实现
get(K key)
方法
public V get(K key){
//进行判断
if(key == null){
throw new IllegalArgumentException("key is null");
}
//定义get(root, key)方法
TreeNode node = get(root, key);
if(node == null){
return null;
}
return node.val;
}
我们需要定义 get(TreeNode node, K key)
,从根节点开始遍历,找到指定 key
对应的节点。
- 首先我们需要比较当前节点的
Key
与指定Key
的大小。node.key.compareTo(key)
- 如果
node == null
,说明没找到。 - 如果当前根节点
key
小于指定Key
,那么就搜索右子树;如果大于就搜索左子树。 - 剩下的情况就是找到了。
//获取node根节点下指定key对应的节点
private TreeNode get(TreeNode node, K key) {
if(node == null){
return null;
}
int i = node.key.compareTo(key);
// i < 0 说明当前根节点小于 key
if(i < 0){
//说明在右子树
return get(node.right, key);
}
if(i > 0){
// i > 0 说明在左子树
return get(node.left, key);
}
// 剩下的情况就是找到了
return node;
}
然后我们实现
containsKey(K key)
方法,判断key
是否存在。 我们可以直接使用get(K key)
方法,判断其返回值是否为空即可。
public boolean containsKey(K key){
if(key == null){
throw new IllegalArgumentException("key is null");
}
//直接使用上面的方法
TreeNode x = get(root, key);
return x != null;
}
实现
put(K key, V val)
方法
有两种情况:key
存在修改或者 key
不存在插入。
public V put(K key, V val){
if(key == null){
throw new IllegalArgumentException("key is null");
}
V oldVal = get(key);
//首先获取Key,如果oldVal不存在那么就相当于新增 size++
if(oldVal == null){
size++;
}
//为什么赋值给root 因为新增改变了二叉树
root = put(root, key, val);
return oldVal;
}
首先我们需要将当前节点 key
与插入 key
进行比较,如果 cmp > 0
说明在左子树,cmp < 0
说明在右子树,如果 cmp == 0
说明找到了,修改其值即可。
如果找不到也就是 node == null
时,直接返回新节点即可。因为 node.left
或者 node.right
其一会进行接收。
private TreeNode put(TreeNode node, K key, V val) {
//2.找不到,判断是否为空,如果为空,说明没有找到
if(node == null){
//为什么直接返回,当找不到的时候
//只能在左子树或者右子树搜索
//找不到返回值会被左右子树接收。
return new TreeNode(key, val);
}
// 1.先找是否存在
int cmp = node.key.compareTo(key);
if(cmp > 0){
node.left = put(node.left, key, val);
}else if(cmp < 0){
node.right = put(node.right, key, val);
}else{
// node.key == key 找到了
node.val = val;
}
return node;
}
删除最小
Key
和最大Key
removeMin()
removeMax()
我们观察这个树,可以发现最左子树就是最小值,最右子树就是最大值。
那么我们只需要判断
node.left == null
和 node.right == null
就找到了最左子树和最右子树。
- 实现
removeMin()
删除最左子树。
首先我们需要从根节点查找到最左子树。
//删除最小值,也就是删除最左子树
public void removeMin() {
if (isEmpty()) {
throw new NoSuchElementException();
}
root = removeMin(root);
size--;
}
当 node.left == null
说明已经走到最左侧节点,当前节点没有左子树,当前节点就是最左子树了。
可以返回 null 吗?
- 不可以,如果当前节点有右子树,则会将当前节点和右子树一起删除。
- 所以应该返回其右子树,让其覆盖当前节点即可,我们不用操心右子树是否为空。同时也维护了
BST
的性质。
private TreeNode removeMin(TreeNode node) {
if(node.left == null){
return node.right;
}
node.left = removeMin(node.left);
return node;
}
- 实现
removeMax()
删除最右子树。
参照 removeMax实现可得
:
public void removeMax(){
if (isEmpty()) {
throw new NoSuchElementException();
}
root = removeMax(root);
size--;
}
private TreeNode removeMax(TreeNode node) {
if(node.right == null){
return node.left;
}
node.right = removeMax(node.right);
return node;
}
3.3. TreeMap的删除
实现
remove(K key)
方法,删除指定key
对应的节点。
// 删除 key 并返回对应的 val
public V remove(K key){
// 检查参数有效性
if(key == null){
throw new IllegalArgumentException("key is null");
}
if(!containsKey(key)){
return null;
}
// 获取旧值
V deleteVal = get(key);
// 删除节点
root = remove(root, key);
size--;
return deleteVal;
}
然后我们定义 remove(TreeNode node, K key)
方法,表明删除当前 node
节点下对应二叉树中节点 key
等于指定 key
。
private TreeNode remove(TreeNode node, K key) {
}
我们应该怎么做?
- 首先将当前节点
key
和需要删除的key
进行比较。 - 如果 node.key > key 则查询左子树。
- 如果 node.key < key 则查询右子树。
那么相等的情况,我们应该怎么删除呢?会涉及三种情况?
- node 是叶子节点,左右子树都是 null
- node 左右子树有一个非空
- node 左右子树都不为空
node 是叶子节点,左右子树都是 null ,这个我们只需返回 null 即可。可以写如下的代码:
左右子树有一个非空怎么办呢?
假如我们要删除的节点是11,其左子树不为空,那么我们只需要返回其左子树即可,根据图示可以发现,BST的性质没有改变。
同理,如果删除节点的右子树不为空,那么只需要返回其右子树即可。我们可以写如下代码:
node 左右子树都不为空怎么办呢?
假如我们要删除的节点是根节点
6
。那么我们应该怎么才能维持BST的性质呢?
- 我们可以移动左子树的最大节点5至根节点,或者移动右子树的最小节点7至根节点来维持BST的性质。
- 我们可以移动左子树的最大节点5至根节点:
确实可以维护BST的性质。
疑问?
-
难道删除节点的左子树的最大节点不能有左右子树吗,那样不是变成多于二叉的树了吗?
- 首先最大节点肯定没有右子树,如果有那么其不是最大节点。
- 那么它会有左子树吗,可能有左子树,那么我们应该怎么办呢?
我们可以把最大节点的左子树接到其父节点上,然后将从当前位置删除。记得保留当前节点的值,然后将该节点的左右指针指向删除节点指向的左右子树。这样就可以维护BST的性质了。
- 移动右子树的最小节点7至根节点来维持BST的性质。
疑问?
- 我们右子树的最小节点,一定没有左子树,但是其可能有右子树,那么我们应该怎么办呢?
我们只需要把右子树接到最小节点的父节点上,然后将最小节点从当前位置删除,然后保留最小节点的值,然后将其左右指针指向删除节点的左右子树。这样就维护了BST的性质。
完整代码实现
- 首先我们需要实现两个方法,找到以某个节点为根节点的BST的最大节点和最小节点:
// 以当前节点为根节点的BST的最大节点
private TreeNode maxNode(TreeNode p) {
while (p.right != null) {
p = p.right;
}
return p;
}
// 以当前节点为根节点的BST的最小节点
private TreeNode minNode(TreeNode p) {
while (p.left != null) {
p = p.left;
}
return p;
}
- 然后实现我们的方法:前面的章节我们已经实现了
removeMin
和removeMax
方法了。
private TreeNode remove(TreeNode node, K key) {
// 进行比较
int cmp = node.key.compareTo(key);
if(cmp > 0){
// node.key > key 去左子树找
node.left = remove(node.left, key);
}else if(cmp < 0){
// node.key < key 去右子树找
node.right = remove(node.right, key);
}else{
if(node.left == null && node.right == null){
// 左右子树都为空
return null;
}else if(node.left != null && node.right == null){
//左子树不为空
return node.left;
}else if(node.left == null && node.right != null){
// 右子树不为空
return node.right;
}
// 剩下的情况的就是左右子树不为空的情况了。
// 我们有两种方案实现
// 1. 找到当前节点的前驱节点,也就是左子树的最大值
//首先我们需要找到这个最大节点
TreeNode leftMax = maxNode(node.left);
//然后通过我们的 removeMax删除这个节点,让其父节点的左指针指向这个删除这个节点后的左子树
node.left = removeMax(node.left);
leftMax.left = node.left;
leftMax.right = node.right;
// // 2. 找到当前节点的后继节点,也就是右子树的最小值
// TreeNode rightMin = minNode(node.right);
// node.right = removeMin(node.right);
// rightMin.left = node.left;
// rightMin.right = node.right;
}
return node;
}
3.4. 实现floorKey和ceilingKey方法
// 返回小于等于 key 的最大的键
public K floorKey(K key) {
}
// 返回大于等于 key 的最小的键
public K ceilingKey(K key) {
}
首先我们实现 floorKey
方法:
public K floorKey(K key) {
// 参数检查
if (key == null) {
throw new IllegalArgumentException("key is null");
}
if (isEmpty()) {
throw new NoSuchElementException();
}
//创建辅助函数
TreeNode x = floorKey(root, key);
return x.key;
}
注意:我们要查找的是小于等于key的最大的键,如果key存在于BST中,那么就返回key。我们只需要考虑key不存在的情况。
比如我们要查找小于等于16的最大的键,我们一直向下查找,发现找不到,那么我们返回其父节点就可以了。
实现:逻辑类似于
get
方法,只需要修改右子树处的相关代码即可。
private TreeNode floorKey(TreeNode node, K key) {
if(node == null){
return null;
}
int i = node.key.compareTo(key);
// i < 0 说明当前根节点小于 key
if(i < 0){
//说明在右子树
TreeNode x = floorKey(node.right, key);
// 当发现在右子树查找时,返回 null 说明找不到,那么返回其父节点即可,也就是返回当前节点 node。
if(x == null){
return node;
}
}
if(i > 0){
// i > 0 说明在左子树
return floorKey(node.left, key);
}
// 剩下的情况就是找到了
return node;
}
同理实现:ceilingKey
// 返回大于等于 key 的最小的键
public K ceilingKey(K key) {
if (key == null) {
throw new IllegalArgumentException("key is null");
}
if (isEmpty()) {
throw new NoSuchElementException();
}
TreeNode x = ceilingKey(root, key);
return x.key;
}
private TreeNode ceilingKey(TreeNode node, K key) {
if(node == null){
return null;
}
int i = node.key.compareTo(key);
// i < 0 说明当前根节点小于 key
if(i < 0){
//说明在右子树
return ceilingKey(node.right, key);
}
if(i > 0){
// i > 0 说明在左子树
TreeNode x = ceilingKey(node.left, key);
if(x == null){
return node;
}
return x;
}
// 剩下的情况就是找到了
return node;
}
3.5. 实现 keys 相关方法
// 从小到大返回所有键
public Iterable<K> keys() {
}
// 从小到大返回闭区间 [min, max] 中的键
public Iterable<K> keys(K min, K max) {
}
首先我们实现从小到大返回所有键,根据BST的特性,那么通过中序遍历(左 根 右)BST,就可以返回从小到大的键了。
// 从小到大返回所有键
public Iterable<K> keys() {
if (isEmpty()) {
return new LinkedList<>();
}
LinkedList<K> list = new LinkedList<>();
traverse(root, list);
return list;
}
// 中序遍历 BST
private void traverse(TreeNode node, LinkedList<K> list) {
if (node == null) {
return;
}
// 先遍历left
traverse(node.left, list);
// 中序遍历
list.addLast(node.key);
// 再遍历 right
traverse(node.right, list);
}
实现从小到大返回闭区间 [min, max] 中的键,我们只需要在中序遍历的时候添加值得时候,判断当前节点是否在 [min, max] 范围内。
// 从小到大返回闭区间 [min, max] 中的键
public Iterable<K> keys(K min, K max) {
if (min == null) throw new IllegalArgumentException("min is null");
if (max == null) throw new IllegalArgumentException("max is null");
LinkedList<K> list = new LinkedList<>();
traverse(root, list, min, max);
return list;
}
// 中序遍历 BST
private void traverse(TreeNode node, LinkedList<K> list, K min, K max) {
if (node == null) {
return;
}
int cmpMin = min.compareTo(node.key);
int cmpMax = max.compareTo(node.key);
traverse(node.left, list);
// 中序遍历 min <= node.key <= max
if (cmpMin <= 0 && cmpMax >= 0) {
list.addLast(node.key);
}
traverse(node.right, list);
}
但是这样的效率很低?
-
因为我们没有BST的性质。
-
当
min
大于等于当前node
的时候,那么我们就不需要遍历当前node
的左子树了,因为左子树的节点都小于node
,自然也小于min
。 -
当
max
小于等于当前node
的时候,那么我们就不需要遍历当前node
的右子树了,因为右子树的节点都大于node
,自然也大于max
。
优化后的代码:
private void traverse(TreeNode node, LinkedList<K> list, K min, K max) {
if (node == null) {
return;
}
int cmpMin = min.compareTo(node.key);
int cmpMax = max.compareTo(node.key);
if (cmpMin < 0) {
// min < node.key 才进行遍历
traverse(node.left, list);
}
// 中序遍历 min <= node.key <= max
if (cmpMin <= 0 && cmpMax >= 0) {
list.addLast(node.key);
}
if (cmpMax > 0) {
// max > node.key 才进行遍历
traverse(node.right, list);
}
}
3.6. 实现 select 和 rank 方法
// 返回小于 key 的键的个数
public int rank(K key) {
}
// 返回索引为 i 的键,i 从 0 开始计算
public K select(int i) {
}
实现
rank
方法,根据BST的性质,小于key
的键的个数,就是统计其左子树节点的个数。
我们可以通过前序遍历、中序遍历、后序遍历来统计节点个数,但是这样的时间效率很高。
有没有一种方法可以快速统计节点的个数呢?
- 我们可以在
TreeNode
类中添加size
属性,用来记录以当前节点为根的BST
有多少个节点。然后移除外部的size
属性。
- 然后需要修改相应的工具函数:
- 然后修改改变BST结构的方法,
put, removeXxx
:
同上图改造其他方法即可。
正式实现
rank
方法
// 返回小于 key 的键的个数
public int rank(K key) {
if (key == null) {
throw new IllegalArgumentException();
}
return rank(root, key);
}
创建辅助函数:
- 首先我们需要比较当前节点
key
与指定key
的大小。 - 如果
key < node.key
,根据BST
的性质,说明node
和node.right
都大于key
。我们只需要查找左子树就好了。 - 如果
key > node.key
,根据BST
的性质,说明node
和node.left
都是比key
小的,我们需要返回左子树的个数和当前节点个数1,然后查找右子树返回其个数即可。 - 如果相等,说明
node
节点的左子树满足,直接返回个数即可。
// 返回以 node 为根的 BST 中小于 key 的键的个数
private int rank(TreeNode node, K key) {
int cmp = key.compareTo(node.key);
if (cmp < 0) {
// key < node.key
// 和 node 以及 node.right 没啥关系了
// 因为它们太大了
return rank(node.left, key);
} else if (cmp > 0) {
// key > node.key
// node 和 node.left 左子树都是比 key 小的
return size(node.left) + 1 + rank(node.right, key);
} else {
// key == node.key
return size(node.left);
}
}
实现
select
方法,返回索引为i
的键,i
从0
开始计算返回索引为
i
的键,是什么意思呢?
- 就是我们通过中序遍历,从小到大排序的节点。索引
0
就是返回最小的元素。
// 返回索引为 i 的键,i 从 0 开始计算
public K select(int i) {
if (i < 0 || i >= size()) {
throw new IllegalArgumentException();
}
TreeNode x = select(root, i);
return x.key;
}
定义辅助函数:
- 首先我们需要计算出当前节点的索引,那么如何计算呢,通过
size(node.left)
方法计算,因为根据BST的性质,左小右大,其左子树节点个数就是其索引。
2. 然后比较当前节点索引
n
和 i
的大小:
- 如果
n > i
,那么我们只能在当前节点左子树查找,因为当前节点右子树索引大于n
也大于i
。 - 如果
n < i
,那么我们只能在当前节点右子树查找,因为当前节点左子树索引都小于n
。- 为什么是
i - n -1
? - 如果我们需要查找索引为10的节点,当前节点索引为
5
,那么我们肯定要查找右子树。如果索引还是传入i=10,那么下一轮重新计算node
的索引,就是 0 似乎没什么问题;如果计算下一轮,那么索引变为了3,i还是10,总数是4个,在这个小的子树中,肯定找不到索引为10的节点。 i - n - 1 => 10 - 5 - 1 = 4
计算出其在右子树的索引,然后来查找。
- 为什么是
- 如果
n = i
,说明找到了直接返回即可。
// 返回以 node 为根的 BST 中索引为 i 的那个节点
private TreeNode select(TreeNode node, int i) {
int n = size(node.left);
if (n > i) {
// n == 10, i == 3
return select(node.left, i);
} else if (n < i) {
// n == 3, i == 10
return select(node.right, i - n - 1);
} else {
// i == n
// node 就是索引为 i 的那个节点
return node;
}
}
转载自:https://juejin.cn/post/7215213377094090808