数据结构之手写字典树
今天我们来实现字典树,或称之为前缀树、Trie 树,又称单词查找树或键树,是一种树形结构,是一种哈希树的变种。典型应用是用于统计和排序大量的字符串(但不限于字符串),所以经常被搜索引擎系统用于文本词频统计。它的优点是:利用字符串的公共前缀来减少查询时间,最大限度地减少无谓的字符串比较。
1. 字典树介绍
字典树的性质:
- 根节点不包含字符,除根节点外每一个节点都只包含一个字符。
- 从根节点到某一节点,路径上经过的字符连接起来,为该节点对应的字符串。
- 每个节点的所有子节点包含的字符都不相同。

比如我们存储字符串xy,xyz,xya,abc 。
初始代码介绍
为什么不使用泛型呢?
- 因为字典树操作对象是字符串。
public class TrieSet {
// ASCII 码个数
private static final int R = 256;
// 当前存在集合中的元素个数
private int size = 0;
private static class TrieNode {
// 记录下一个字符的指针
// 比如 next['a'] != null,说明这里有一个字符 'a'
// 为什么每个节点有256个孩子,因为字符的ASCII编码有256个
TrieNode[] next = new TrieNode[R];
// 标志该节点是否是最后一个节点
boolean isEnd = false;
}
private TrieNode root = null;
public TrieSet(){}
public boolean add(String key){}
public boolean remove(String key) {}
public boolean contains(String key) {}
public boolean startsWith(String prefix) {}
private TrieNode getNode(TrieNode node, String key) {}
}
2. 实现
2.1. 初步实现
实现
add方法
// 在 set 中添加 key,如果成功插入则返回 true
public boolean add(String key){
if(key == null){
throw new IllegalArgumentException("key is null");
}
if(contains(key)){
return false;
}
// 真正的插入方法
root = add(root, key, 0);
return true;
}
// 从 node 开始插入 key[i...]
private TrieNode add(TrieNode node, String key, int i) {
if(node == null){
node = new TrieNode();
}
if(i == key.length()){
node.isEnd = true;
return node;
}
char c = key.charAt(i);
node.next[c] = add(node.next[c], key, i + 1);
return node;
}
为什么这样写?
以插入 add("xy"),add("xyz") 为例。

首先我们插入 add("xy") :传入 root,xy 字符串、初始索引 0。
- 首先判断节点是否为
null,第一次root为null。 - 那么会创建一个新的
TrieNode节点,也就是包含一个next[256]的空数组,和一个isEnd变量。 - 然后判断i是否等于字符串长度,这时候显然不相等。
- 然后执行
key.charAt(i),获取字符串第一个字符x。 - 执行
node.next[x] = add(node.next[x], key, i + 1); - 执行下一次
add方法,这时候node.next[x]为空,然后会创建一个新数组指向node.next[x] - 当前索引不等于字符串长度,继续朝下执行,执行
key.charAt(i),获取第二个字符y。 - 执行
node.next[y] = add(node.next[y], key, i + 1); - 执行下一次
add方法,这时候node.next[y]为空,然后会创建一个新数组指向node.next[y] - 当前索引等于字符串长度,设置当前节点的
isEnd = true。 - 然后层层返回。
然后插入 add("xyz")
- 还是从
root开始,此时root不为null。 - 通过索引获取第一个字符
x。 - 执行
node.next[x] = add(node.next[x], key, i + 1); - 执行下一次
add方法,这时候node.next[x]不为空。 - 通过索引获取第二个字符y。
- 执行
node.next[y] = add(node.next[y], key, i + 1); - 执行下一次
add方法,这时候node.next[y]不为空 - 通过索引获取第三个字符z。
- 执行
node.next[z] = add(node.next[z], key, i + 1); - 此时
node.next[z]等于空,创建一个新的节点。 - 此时索引等于字符串长度,设置
isEnd = true,然后返回node。 - 向上层层返回。
通过以下过程,也就是每次按从左到右的顺序插入字符串中每个字符,并创建当前字符的数组。
逻辑上相当于一条链表:
- next 指针是一个大小为256的数组,节点的值是isEnd代表是否是一个字符串的末尾。
- 通过字符的索引,找到对应的下一个数组。

实现
getNode(TrieNode node, String key)从node节点开始搜索,返回指定字符串的尾节点。
可能的搜索情况有三种:
- 搜索尾节点不为
null,且isEnd = true,比如搜索字符串xy - 搜索尾节点不为
null,且isEnd = false,虽然搜索到了这个节点,也存在,但是isEnd = false,代表这个字符串没有存入集合中。 - 搜索尾节点为
null,比如搜索xa,代表该字符串不存在。
因为我们是获取尾节点,只需要在遍历的时候判断,当前指针是否为空,如果为空,则这个尾节点不存在。判断字符串是否包含存在于集合中,在 contains 函数中实现。
private TrieNode getNode(TrieNode node, String key) {
TrieNode p = node;
for(int i = 0; i < key.length(); i++){
char c = key.charAt(i);
if(p == null){
return null;
}
p = p.next[c];
}
return p;
}
实现
contains(String key)方法,基于getNode方法。
public boolean contains(String key) {
if(key == null){
throw new IllegalArgumentException("key is null");
}
TrieNode x = getNode(root, key);
// 要存在这个路径,且是字符串末尾,代表存在这个字符串
return x != null && x.isEnd;
}
实现
startsWith(String prefix)判断字符串前缀是否存在于集合中,基于getNode实现,只需存在前缀路径即可。
public boolean startsWith(String prefix) {
if(prefix == null){
throw new IllegalArgumentException("key is null");
}
TrieNode x = getNode(root, prefix);
// x 不为空,代表存在这个前缀路径
return x != null;
}
实现
remove方法
public boolean remove(String key) {
if(key == null){
throw new IllegalArgumentException("key is null");
}
if(!contains(key)){
return false;
}
root = remove(root, key, 0);
size--;
return true;
}
private TrieNode remove(TrieNode node, String key, int i) {
if(i == key.length()){
node.isEnd = false;
}else{
char c = key.charAt(i);
node.next[c] = remove(node.next[c], key, i + 1);
}
if(node.isEnd) return node;
for(int j = 0; j < R; i++){
if(node.next[j] != null){
return node;
}
}
return null;
}
分析:以删除 xy 字符串为例。

- 首先从根节点开始,判断当前索引i是否等于字符串长度。如果等于将
isEnd置为false,代表逻辑删除该字符串。 - 显然不等于,进入
else语句中,此时的c.next[x]指向下一个数组。 - 然后一直向下,指向
3这个数组,也就是索引i等于字符串长度的时候,将isEnd置为false。 - 然后判断
isEnd是否为true,为true,返回node,代表这个节点不能删。 - 然后执行
for循环,判断当前节点数组存储的元素是否都为空,显然最后的数组存储的元素全为null。如果有一个不为null,代表当前节点还有其他子节点。 - 如果前两项都不满足,代表当前节点没有子节点,那么就可以返回
null(代表删除当前节点),此null指向上一个节点也就是y,此时next[y] == null - 然后继续执行
5 6 7步

- 层层向上判断,然后发现
y指向null,删除当前节点。返回null,此时x指向null,然后删除当前节点。 - 最后返回
null,也就是当前树变为空了。
注意
- 只要当前节点的
isEnd不等于false,或者当前节点的next数组有一个元素不为空,就不能删除当前节点。 - 以下代码就是来判断上述情况的。
if(node.isEnd) return node;
for(int j = 0; j < R; i++){
if(node.next[j] != null){
return node;
}
}
return null;
为什么要这么麻烦,直接将 isEnd 置为 false 不就行了
- 因为
startWith(String prefix)会判断指定前缀是否在集合中。但是如果我们只进行逻辑删除,假设说集合中只存储xyz,然后删除xyz,只是将isEnd置为false,然后判断xy是否在集合中,这个时候就会返回true,因为我们没有真正的删除。 - 假设说我们存储
xy,又存储xyz,此时删除了xyz字符串,实际上只删除了z,因为xy和xyz共用xy,所以我们不能删除xy。此时判断x是否在集合中,返回true是正确的。
2.2. 实现匹配相关的方法
// 在 trie 存储的字符串中,寻找 query 的最长前缀
public String longestPrefixOf(String query){}
// 在 trie 存储的字符串中,寻找 query 的最短前缀
public String shortestPrefixOf(String query){}
// 搜索前缀为 prefix 的所有字符串
public List<String> keysWithPrefix(String prefix){}
// 通配符 匹配任意字符串
public List<String> keyWithPattern(String pattern){}
// 通配符 匹配任意字符
public boolean matches(String pattern){}
2.2.1. 实现查找最长最短前缀方法
首先实现查找字符串最短前缀方法
shortestPrefixOf(String query)

假设说,我们存储在集合中的字符串为 ab abc abcde,查 abcd 的最短前缀,也就是 ab 。
如何查找?
- 从根开始遍历,当遍历到第一个节点的
isEnd = true时,即为我们要查找的字符串的最短前缀。
public String shortestPrefixOf(String query){
if(query == null){
throw new IllegalArgumentException("key is null");
}
// 从根节点开始遍历
TrieNode p = root;
int i = 0;
while(i < query.length() && p != null){
// isEnd = true,说明找到了
if(p.isEnd){
break;
}
// 否则继续向下查找
char c = query.charAt(i);
i++;
p = p.next[c];
}
// 返回结果
if(p != null && p.isEnd){
// 前闭后开区间
return query.substring(0, i);
}
// 否则返回 null
return null;
}
为什么是返回结果是 query.subString(0, i) ?
-
因为满足条件时,p指针指向第一个
isEnd = true的节点,i 指向下一个next。 -
又因为
subString(0, i),是前闭后后开区间,所以返回的是正确的结果。
实现查找字符串最长前缀方法
longestPrefixOf(String query)
假设说,我们存储在集合中的字符串为 ab abc abcde,查 abcd 的最长前缀,也就是 abc 。也就距离 abcd 最近的字符串。
我们只能从根节点开始查找,那么如何获取距离 query 最近的前缀字符串呢?

- 我们可以再添加一个索引指针
len,让其初始为len = i = 0。 - 然后依次向下遍历,首先让
p, i依次向下移动。 p.isEnd = true时,移动len指针,使len = i。- 重复
2, 3步,直到遍历结束。 - 那么此时
len索引指针保存的就是最后一个isEnd = true的下一个索引了。
public String shortestPrefixOf(String query){
if(query == null){
throw new IllegalArgumentException("key is null");
}
// 从根节点开始遍历
TrieNode p = root;
int i = 0;
int len = 0;
while(i < query.length() && p != null){
// isEnd = true
if(p.isEnd){
// 移动 len 指针至 i 的位置
len = i;
}
// 否则继续向下查找
char c = query.charAt(i);
i++;
p = p.next[c];
}
// 防止遍历结束时,p指针指向的节点的 isEnd 也为 true
if(p != null && p.isEnd){
len = i;
}
// 否则结果
return query.substring(0 ,i);
}
*为什么当 p != null && p.isEnd 时,需要 len = i 操作呢?
- 因为当遍历结束时,也就是
p指针指向abcd字符串的末尾时,可能abcd字符串也存储在集合中了,那么此时isEnd = true,也就是说最长前缀是它本身。那么此时需要移动len指针至i处,才能保证最长前缀正确。

2.2.2. 实现 keysWithPrefix 方法
// 返回前缀为 prefix 的所有字符串
public List<String> keysWithPrefix(String prefix){}
实现思路:

假如说,我们返回以 a 为前缀的所有字符串:
- 我们需要从根节点开始遍历,找到匹配
prefix的最后节点,然后遍历prefix.next为起点的多叉树,找到所有符合条件的字符串即可。
// 搜索前缀为 prefix 的所有字符串
public List<String> keysWithPrefix(String prefix){
LinkedList<String> res = new LinkedList<>();
// 找到匹配 prefix 的那个节点
TrieNode x = getNode(root, prefix);
if(x == null){
return res;
}
// DFS 遍历多叉树
traverse(x, new StringBuilder(prefix), res);
return res;
}
private void traverse(TrieNode node, StringBuilder path, LinkedList<String> res) {
}
主要实现 DFS 遍历多叉树 traverse 方法:
private void traverse(TrieNode node, StringBuilder path, LinkedList<String> res) {
if(node == null){
return;
}
if(node.isEnd){
// 是一个字符串,添加
res.add(path.toString());
}
for(char c = 0; c < R; c++){
path.append(c);
traverse(node.next[c], path, res);
path.deleteCharAt(path.length() - 1);
}
}

- 我们从
a节点开始遍历多叉树 - 判断当前是否为
null,然后判断isEnd是否为true,为true说明是一个以a为前缀的字符串。 - 然后开始遍历以
a.next数组,也就是256个分叉。首先遍历ASCII码等于0的分叉, 朝路径中追加ASCII = 0的字符。此时path = a + 0。 - 然后开始遍历这个分叉,发现
node == null,说明没有这个分叉,然后直接返回,不需要遍历。因为此时path = a + 0,我们需要维护路径的正确性,所以需要删除ASCII = 0的字符。 - 然后一直循环。
- 直到
ASCII = 98,也就是字符b,然后遍历这个分叉,向path追加字符b,遇到isEnd = true,向res中添加ab。 - 然后一直遍历分叉,直到尾部,然后层层返回,重新回到
a处。 - 然后继续遍历,直到遍历完毕。
2.2.3. 实现 keyWithPattern 方法
// 通配符 . 匹配任意字符串
public List<String> keyWithPattern(String pattern){
List<String> res = new LinkedList<>();
traverse(root, new StringBuilder(), pattern, 0, res);
return res;
}
实现思路:假如说匹配字符串为 a.e,. 占一位。

- 那么我们可以匹配
ade,afe - 首先判断当前节点是否为
null,如果为null,代表没有这个分叉。 - 然后当
i == pattern.length()时,然后进一步判断是否是一个字符串,如果是则添加。 - 然后判断当前字符是否为
.,如果为.,则需要遍历当前节点的所有分叉,因为.代表任意字符。 - 如果不是
.,只遍历当前节点即可。
// 定义:在以 node 为根的 `trie` 树中匹配 pattern[i..]
private void traverse(TrieNode node, StringBuilder path, String pattern, int i, List<String> res) {
if(node == null){
return;
}
// pattern 匹配完成
if(i == pattern.length()){
if(node.isEnd){
res.add(path.toString());
}
return;
}
char c = pattern.charAt(i);
if(c == '.'){
for (char j = 0; j < R; j++){
path.append(j);
traverse(node.next[j], path, pattern, i + 1, res);
path.deleteCharAt(path.length() - 1);
}
}else{
path.append(c);
traverse(node.next[c], path, pattern, i + 1, res);
path.deleteCharAt(path.length() - 1);
}
}
2.2.4. 实现 matches 方法
// 通配符 . 匹配任意字符
public boolean matches(String pattern){}
这个方法与 keyWithPattern 方法实现思路类似。我们可以直接调用 keyWithPattern 方法,通过判断集合是否存储有元素,即可实现是否包含有匹配 pattern 的字符串了。但是由于 keyWithPattern 返回的是一个集合,当匹配到的字符串很多时,有性能损失。
所以我们需要重新实现,只需要对实现 keyWitrhPattern 方法的 traverse 进行改造即可:
// 通配符 . 匹配任意字符
public boolean matches(String pattern){
if(pattern == null){
throw new IllegalArgumentException("key is null");
}
return matches(root, pattern, 0);
}
private boolean matches(TrieNode node, String pattern, int i) {
if(node == null){
return false;
}
// pattern 匹配完成
if(i == pattern.length()){
return node.isEnd;
}
char c = pattern.charAt(i);
if(c == '.'){
for (char j = 0; j < R; j++){
if(matches(node.next[j], pattern, i + 1)){
return true;
}
}
}
//没有遇到通配符
return matches(node.next[c], pattern, i + 1);
}
转载自:https://juejin.cn/post/7215963267502850107