数据结构之手写字典树
今天我们来实现字典树,或称之为前缀树、Trie
树,又称单词查找树或键树,是一种树形结构,是一种哈希树的变种。典型应用是用于统计和排序大量的字符串(但不限于字符串),所以经常被搜索引擎系统用于文本词频统计。它的优点是:利用字符串的公共前缀来减少查询时间,最大限度地减少无谓的字符串比较。
1. 字典树介绍
字典树的性质:
- 根节点不包含字符,除根节点外每一个节点都只包含一个字符。
- 从根节点到某一节点,路径上经过的字符连接起来,为该节点对应的字符串。
- 每个节点的所有子节点包含的字符都不相同。
比如我们存储字符串xy,xyz,xya,abc
。
初始代码介绍
为什么不使用泛型呢?
- 因为字典树操作对象是字符串。
public class TrieSet {
// ASCII 码个数
private static final int R = 256;
// 当前存在集合中的元素个数
private int size = 0;
private static class TrieNode {
// 记录下一个字符的指针
// 比如 next['a'] != null,说明这里有一个字符 'a'
// 为什么每个节点有256个孩子,因为字符的ASCII编码有256个
TrieNode[] next = new TrieNode[R];
// 标志该节点是否是最后一个节点
boolean isEnd = false;
}
private TrieNode root = null;
public TrieSet(){}
public boolean add(String key){}
public boolean remove(String key) {}
public boolean contains(String key) {}
public boolean startsWith(String prefix) {}
private TrieNode getNode(TrieNode node, String key) {}
}
2. 实现
2.1. 初步实现
实现
add
方法
// 在 set 中添加 key,如果成功插入则返回 true
public boolean add(String key){
if(key == null){
throw new IllegalArgumentException("key is null");
}
if(contains(key)){
return false;
}
// 真正的插入方法
root = add(root, key, 0);
return true;
}
// 从 node 开始插入 key[i...]
private TrieNode add(TrieNode node, String key, int i) {
if(node == null){
node = new TrieNode();
}
if(i == key.length()){
node.isEnd = true;
return node;
}
char c = key.charAt(i);
node.next[c] = add(node.next[c], key, i + 1);
return node;
}
为什么这样写?
以插入 add("xy"),add("xyz")
为例。
首先我们插入 add("xy")
:传入 root
,xy
字符串、初始索引 0
。
- 首先判断节点是否为
null
,第一次root
为null
。 - 那么会创建一个新的
TrieNode
节点,也就是包含一个next[256]
的空数组,和一个isEnd
变量。 - 然后判断i是否等于字符串长度,这时候显然不相等。
- 然后执行
key.charAt(i)
,获取字符串第一个字符x
。 - 执行
node.next[x] = add(node.next[x], key, i + 1);
- 执行下一次
add
方法,这时候node.next[x]
为空,然后会创建一个新数组指向node.next[x]
- 当前索引不等于字符串长度,继续朝下执行,执行
key.charAt(i)
,获取第二个字符y。 - 执行
node.next[y] = add(node.next[y], key, i + 1);
- 执行下一次
add
方法,这时候node.next[y]
为空,然后会创建一个新数组指向node.next[y]
- 当前索引等于字符串长度,设置当前节点的
isEnd = true
。 - 然后层层返回。
然后插入 add("xyz")
- 还是从
root
开始,此时root
不为null
。 - 通过索引获取第一个字符
x
。 - 执行
node.next[x] = add(node.next[x], key, i + 1);
- 执行下一次
add
方法,这时候node.next[x]
不为空。 - 通过索引获取第二个字符y。
- 执行
node.next[y] = add(node.next[y], key, i + 1);
- 执行下一次
add
方法,这时候node.next[y]
不为空 - 通过索引获取第三个字符z。
- 执行
node.next[z] = add(node.next[z], key, i + 1);
- 此时
node.next[z]
等于空,创建一个新的节点。 - 此时索引等于字符串长度,设置
isEnd = true
,然后返回node
。 - 向上层层返回。
通过以下过程,也就是每次按从左到右的顺序插入字符串中每个字符,并创建当前字符的数组。
逻辑上相当于一条链表:
- next 指针是一个大小为256的数组,节点的值是isEnd代表是否是一个字符串的末尾。
- 通过字符的索引,找到对应的下一个数组。
实现
getNode(TrieNode node, String key)
从node
节点开始搜索,返回指定字符串的尾节点。
可能的搜索情况有三种:
- 搜索尾节点不为
null
,且isEnd = true
,比如搜索字符串xy
- 搜索尾节点不为
null
,且isEnd = false
,虽然搜索到了这个节点,也存在,但是isEnd = false
,代表这个字符串没有存入集合中。 - 搜索尾节点为
null
,比如搜索xa
,代表该字符串不存在。
因为我们是获取尾节点,只需要在遍历的时候判断,当前指针是否为空,如果为空,则这个尾节点不存在。判断字符串是否包含存在于集合中,在 contains
函数中实现。
private TrieNode getNode(TrieNode node, String key) {
TrieNode p = node;
for(int i = 0; i < key.length(); i++){
char c = key.charAt(i);
if(p == null){
return null;
}
p = p.next[c];
}
return p;
}
实现
contains(String key)
方法,基于getNode
方法。
public boolean contains(String key) {
if(key == null){
throw new IllegalArgumentException("key is null");
}
TrieNode x = getNode(root, key);
// 要存在这个路径,且是字符串末尾,代表存在这个字符串
return x != null && x.isEnd;
}
实现
startsWith(String prefix)
判断字符串前缀是否存在于集合中,基于getNode
实现,只需存在前缀路径即可。
public boolean startsWith(String prefix) {
if(prefix == null){
throw new IllegalArgumentException("key is null");
}
TrieNode x = getNode(root, prefix);
// x 不为空,代表存在这个前缀路径
return x != null;
}
实现
remove
方法
public boolean remove(String key) {
if(key == null){
throw new IllegalArgumentException("key is null");
}
if(!contains(key)){
return false;
}
root = remove(root, key, 0);
size--;
return true;
}
private TrieNode remove(TrieNode node, String key, int i) {
if(i == key.length()){
node.isEnd = false;
}else{
char c = key.charAt(i);
node.next[c] = remove(node.next[c], key, i + 1);
}
if(node.isEnd) return node;
for(int j = 0; j < R; i++){
if(node.next[j] != null){
return node;
}
}
return null;
}
分析:以删除 xy
字符串为例。
- 首先从根节点开始,判断当前索引i是否等于字符串长度。如果等于将
isEnd
置为false
,代表逻辑删除该字符串。 - 显然不等于,进入
else
语句中,此时的c.next[x]
指向下一个数组。 - 然后一直向下,指向
3
这个数组,也就是索引i
等于字符串长度的时候,将isEnd
置为false
。 - 然后判断
isEnd
是否为true
,为true
,返回node
,代表这个节点不能删。 - 然后执行
for
循环,判断当前节点数组存储的元素是否都为空,显然最后的数组存储的元素全为null
。如果有一个不为null
,代表当前节点还有其他子节点。 - 如果前两项都不满足,代表当前节点没有子节点,那么就可以返回
null
(代表删除当前节点),此null
指向上一个节点也就是y
,此时next[y] == null
- 然后继续执行
5 6 7
步
- 层层向上判断,然后发现
y
指向null
,删除当前节点。返回null
,此时x
指向null
,然后删除当前节点。 - 最后返回
null
,也就是当前树变为空了。
注意
- 只要当前节点的
isEnd
不等于false
,或者当前节点的next
数组有一个元素不为空,就不能删除当前节点。 - 以下代码就是来判断上述情况的。
if(node.isEnd) return node;
for(int j = 0; j < R; i++){
if(node.next[j] != null){
return node;
}
}
return null;
为什么要这么麻烦,直接将 isEnd
置为 false
不就行了
- 因为
startWith(String prefix)
会判断指定前缀是否在集合中。但是如果我们只进行逻辑删除,假设说集合中只存储xyz
,然后删除xyz
,只是将isEnd
置为false
,然后判断xy
是否在集合中,这个时候就会返回true
,因为我们没有真正的删除。 - 假设说我们存储
xy
,又存储xyz
,此时删除了xyz
字符串,实际上只删除了z
,因为xy
和xyz
共用xy
,所以我们不能删除xy
。此时判断x
是否在集合中,返回true
是正确的。
2.2. 实现匹配相关的方法
// 在 trie 存储的字符串中,寻找 query 的最长前缀
public String longestPrefixOf(String query){}
// 在 trie 存储的字符串中,寻找 query 的最短前缀
public String shortestPrefixOf(String query){}
// 搜索前缀为 prefix 的所有字符串
public List<String> keysWithPrefix(String prefix){}
// 通配符 匹配任意字符串
public List<String> keyWithPattern(String pattern){}
// 通配符 匹配任意字符
public boolean matches(String pattern){}
2.2.1. 实现查找最长最短前缀方法
首先实现查找字符串最短前缀方法
shortestPrefixOf(String query)
假设说,我们存储在集合中的字符串为 ab abc abcde
,查 abcd
的最短前缀,也就是 ab
。
如何查找?
- 从根开始遍历,当遍历到第一个节点的
isEnd = true
时,即为我们要查找的字符串的最短前缀。
public String shortestPrefixOf(String query){
if(query == null){
throw new IllegalArgumentException("key is null");
}
// 从根节点开始遍历
TrieNode p = root;
int i = 0;
while(i < query.length() && p != null){
// isEnd = true,说明找到了
if(p.isEnd){
break;
}
// 否则继续向下查找
char c = query.charAt(i);
i++;
p = p.next[c];
}
// 返回结果
if(p != null && p.isEnd){
// 前闭后开区间
return query.substring(0, i);
}
// 否则返回 null
return null;
}
为什么是返回结果是 query.subString(0, i)
?
-
因为满足条件时,p指针指向第一个
isEnd = true
的节点,i 指向下一个next
。 -
又因为
subString(0, i)
,是前闭后后开区间,所以返回的是正确的结果。
实现查找字符串最长前缀方法
longestPrefixOf(String query)
假设说,我们存储在集合中的字符串为 ab abc abcde
,查 abcd
的最长前缀,也就是 abc
。也就距离 abcd
最近的字符串。
我们只能从根节点开始查找,那么如何获取距离 query
最近的前缀字符串呢?
- 我们可以再添加一个索引指针
len
,让其初始为len = i = 0
。 - 然后依次向下遍历,首先让
p, i
依次向下移动。 p.isEnd = true
时,移动len
指针,使len = i
。- 重复
2, 3
步,直到遍历结束。 - 那么此时
len
索引指针保存的就是最后一个isEnd = true
的下一个索引了。
public String shortestPrefixOf(String query){
if(query == null){
throw new IllegalArgumentException("key is null");
}
// 从根节点开始遍历
TrieNode p = root;
int i = 0;
int len = 0;
while(i < query.length() && p != null){
// isEnd = true
if(p.isEnd){
// 移动 len 指针至 i 的位置
len = i;
}
// 否则继续向下查找
char c = query.charAt(i);
i++;
p = p.next[c];
}
// 防止遍历结束时,p指针指向的节点的 isEnd 也为 true
if(p != null && p.isEnd){
len = i;
}
// 否则结果
return query.substring(0 ,i);
}
*为什么当 p != null && p.isEnd
时,需要 len = i
操作呢?
- 因为当遍历结束时,也就是
p
指针指向abcd
字符串的末尾时,可能abcd
字符串也存储在集合中了,那么此时isEnd = true
,也就是说最长前缀是它本身。那么此时需要移动len
指针至i
处,才能保证最长前缀正确。
2.2.2. 实现 keysWithPrefix
方法
// 返回前缀为 prefix 的所有字符串
public List<String> keysWithPrefix(String prefix){}
实现思路:
假如说,我们返回以 a
为前缀的所有字符串:
- 我们需要从根节点开始遍历,找到匹配
prefix
的最后节点,然后遍历prefix.next
为起点的多叉树,找到所有符合条件的字符串即可。
// 搜索前缀为 prefix 的所有字符串
public List<String> keysWithPrefix(String prefix){
LinkedList<String> res = new LinkedList<>();
// 找到匹配 prefix 的那个节点
TrieNode x = getNode(root, prefix);
if(x == null){
return res;
}
// DFS 遍历多叉树
traverse(x, new StringBuilder(prefix), res);
return res;
}
private void traverse(TrieNode node, StringBuilder path, LinkedList<String> res) {
}
主要实现 DFS
遍历多叉树 traverse
方法:
private void traverse(TrieNode node, StringBuilder path, LinkedList<String> res) {
if(node == null){
return;
}
if(node.isEnd){
// 是一个字符串,添加
res.add(path.toString());
}
for(char c = 0; c < R; c++){
path.append(c);
traverse(node.next[c], path, res);
path.deleteCharAt(path.length() - 1);
}
}
- 我们从
a
节点开始遍历多叉树 - 判断当前是否为
null
,然后判断isEnd
是否为true
,为true
说明是一个以a
为前缀的字符串。 - 然后开始遍历以
a.next
数组,也就是256
个分叉。首先遍历ASCII
码等于0
的分叉, 朝路径中追加ASCII = 0
的字符。此时path = a + 0
。 - 然后开始遍历这个分叉,发现
node == null
,说明没有这个分叉,然后直接返回,不需要遍历。因为此时path = a + 0
,我们需要维护路径的正确性,所以需要删除ASCII = 0
的字符。 - 然后一直循环。
- 直到
ASCII = 98
,也就是字符b
,然后遍历这个分叉,向path
追加字符b
,遇到isEnd = true
,向res
中添加ab
。 - 然后一直遍历分叉,直到尾部,然后层层返回,重新回到
a
处。 - 然后继续遍历,直到遍历完毕。
2.2.3. 实现 keyWithPattern
方法
// 通配符 . 匹配任意字符串
public List<String> keyWithPattern(String pattern){
List<String> res = new LinkedList<>();
traverse(root, new StringBuilder(), pattern, 0, res);
return res;
}
实现思路:假如说匹配字符串为 a.e
,.
占一位。
- 那么我们可以匹配
ade
,afe
- 首先判断当前节点是否为
null
,如果为null
,代表没有这个分叉。 - 然后当
i == pattern.length()
时,然后进一步判断是否是一个字符串,如果是则添加。 - 然后判断当前字符是否为
.
,如果为.
,则需要遍历当前节点的所有分叉,因为.
代表任意字符。 - 如果不是
.
,只遍历当前节点即可。
// 定义:在以 node 为根的 `trie` 树中匹配 pattern[i..]
private void traverse(TrieNode node, StringBuilder path, String pattern, int i, List<String> res) {
if(node == null){
return;
}
// pattern 匹配完成
if(i == pattern.length()){
if(node.isEnd){
res.add(path.toString());
}
return;
}
char c = pattern.charAt(i);
if(c == '.'){
for (char j = 0; j < R; j++){
path.append(j);
traverse(node.next[j], path, pattern, i + 1, res);
path.deleteCharAt(path.length() - 1);
}
}else{
path.append(c);
traverse(node.next[c], path, pattern, i + 1, res);
path.deleteCharAt(path.length() - 1);
}
}
2.2.4. 实现 matches
方法
// 通配符 . 匹配任意字符
public boolean matches(String pattern){}
这个方法与 keyWithPattern
方法实现思路类似。我们可以直接调用 keyWithPattern
方法,通过判断集合是否存储有元素,即可实现是否包含有匹配 pattern
的字符串了。但是由于 keyWithPattern
返回的是一个集合,当匹配到的字符串很多时,有性能损失。
所以我们需要重新实现,只需要对实现 keyWitrhPattern
方法的 traverse
进行改造即可:
// 通配符 . 匹配任意字符
public boolean matches(String pattern){
if(pattern == null){
throw new IllegalArgumentException("key is null");
}
return matches(root, pattern, 0);
}
private boolean matches(TrieNode node, String pattern, int i) {
if(node == null){
return false;
}
// pattern 匹配完成
if(i == pattern.length()){
return node.isEnd;
}
char c = pattern.charAt(i);
if(c == '.'){
for (char j = 0; j < R; j++){
if(matches(node.next[j], pattern, i + 1)){
return true;
}
}
}
//没有遇到通配符
return matches(node.next[c], pattern, i + 1);
}
转载自:https://juejin.cn/post/7215963267502850107