likes
comments
collection
share

数据结构之手写字典树

作者站长头像
站长
· 阅读数 17

今天我们来实现字典树,或称之为前缀树、Trie 树,又称单词查找树或键树,是一种树形结构,是一种哈希树的变种。典型应用是用于统计和排序大量的字符串(但不限于字符串),所以经常被搜索引擎系统用于文本词频统计。它的优点是:利用字符串的公共前缀来减少查询时间,最大限度地减少无谓的字符串比较。

1. 字典树介绍

字典树的性质:

  1. 根节点不包含字符,除根节点外每一个节点都只包含一个字符。
  2. 从根节点到某一节点,路径上经过的字符连接起来,为该节点对应的字符串。
  3. 每个节点的所有子节点包含的字符都不相同。

数据结构之手写字典树

比如我们存储字符串xy,xyz,xya,abc

初始代码介绍

为什么不使用泛型呢?

  • 因为字典树操作对象是字符串。
public class TrieSet {
​
    // ASCII 码个数
    private static final int R = 256;
​
    // 当前存在集合中的元素个数
    private int size = 0;
​
    private static class TrieNode {
        // 记录下一个字符的指针
        // 比如 next['a'] != null,说明这里有一个字符 'a'
​
        // 为什么每个节点有256个孩子,因为字符的ASCII编码有256个
        TrieNode[] next = new TrieNode[R];
​
        // 标志该节点是否是最后一个节点
        boolean isEnd = false;
​
    }
​
    private TrieNode root = null;
​
    public TrieSet(){}
​
    public boolean add(String key){}
​
​
    public boolean remove(String key) {}
​
​
    public boolean contains(String key) {}
​
    public boolean startsWith(String prefix) {}
​
    private TrieNode getNode(TrieNode node, String key) {}
​
​
}

2. 实现

2.1. 初步实现

实现 add 方法

// 在 set 中添加 key,如果成功插入则返回 true
public boolean add(String key){
    if(key == null){
        throw new IllegalArgumentException("key is null");
    }
    if(contains(key)){
        return false;
    }
    
    // 真正的插入方法
    root = add(root, key, 0);
    return true;
}
​
// 从 node 开始插入 key[i...]
private TrieNode add(TrieNode node, String key, int i) {
        if(node == null){
            node = new TrieNode();
        }
        if(i == key.length()){
            node.isEnd = true;
            return node;
        }
        char c = key.charAt(i);
        node.next[c] = add(node.next[c], key, i + 1);
        return node;
}

为什么这样写?

以插入 add("xy"),add("xyz") 为例。 数据结构之手写字典树

首先我们插入 add("xy") :传入 rootxy 字符串、初始索引 0

  • 首先判断节点是否为 null ,第一次 rootnull
  • 那么会创建一个新的 TrieNode 节点,也就是包含一个 next[256] 的空数组,和一个 isEnd 变量。
  • 然后判断i是否等于字符串长度,这时候显然不相等。
  • 然后执行 key.charAt(i) ,获取字符串第一个字符 x
  • 执行 node.next[x] = add(node.next[x], key, i + 1);
  • 执行下一次 add 方法,这时候 node.next[x] 为空,然后会创建一个新数组指向 node.next[x]
  • 当前索引不等于字符串长度,继续朝下执行,执行 key.charAt(i),获取第二个字符y。
  • 执行 node.next[y] = add(node.next[y], key, i + 1);
  • 执行下一次add 方法,这时候 node.next[y] 为空,然后会创建一个新数组指向 node.next[y]
  • 当前索引等于字符串长度,设置当前节点的 isEnd = true
  • 然后层层返回。

数据结构之手写字典树 然后插入 add("xyz")

  • 还是从 root 开始,此时 root 不为 null
  • 通过索引获取第一个字符 x
  • 执行 node.next[x] = add(node.next[x], key, i + 1);
  • 执行下一次 add 方法,这时候 node.next[x] 不为空。
  • 通过索引获取第二个字符y。
  • 执行 node.next[y] = add(node.next[y], key, i + 1);
  • 执行下一次 add 方法,这时候 node.next[y] 不为空
  • 通过索引获取第三个字符z。
  • 执行 node.next[z] = add(node.next[z], key, i + 1);
  • 此时 node.next[z] 等于空,创建一个新的节点。
  • 此时索引等于字符串长度,设置 isEnd = true ,然后返回 node
  • 向上层层返回。

通过以下过程,也就是每次按从左到右的顺序插入字符串中每个字符,并创建当前字符的数组。

逻辑上相当于一条链表:

  • next 指针是一个大小为256的数组,节点的值是isEnd代表是否是一个字符串的末尾。
  • 通过字符的索引,找到对应的下一个数组。

数据结构之手写字典树

实现 getNode(TrieNode node, String key)node 节点开始搜索,返回指定字符串的尾节点。

数据结构之手写字典树 可能的搜索情况有三种:

  1. 搜索尾节点不为 null ,且 isEnd = true ,比如搜索字符串 xy
  2. 搜索尾节点不为 null ,且 isEnd = false ,虽然搜索到了这个节点,也存在,但是 isEnd = false ,代表这个字符串没有存入集合中。
  3. 搜索尾节点为 null,比如搜索 xa ,代表该字符串不存在。

因为我们是获取尾节点,只需要在遍历的时候判断,当前指针是否为空,如果为空,则这个尾节点不存在。判断字符串是否包含存在于集合中,在 contains 函数中实现。

private TrieNode getNode(TrieNode node, String key) {
    TrieNode p = node;
    for(int i = 0; i < key.length(); i++){
        char c = key.charAt(i);
        if(p == null){
            return null;
        }
        p = p.next[c];
    }
    return p;
}

实现 contains(String key) 方法,基于 getNode 方法。

public boolean contains(String key) {
    if(key == null){
        throw new IllegalArgumentException("key is null");
    }
    TrieNode x = getNode(root, key);
    // 要存在这个路径,且是字符串末尾,代表存在这个字符串
    return x != null && x.isEnd;
}

实现 startsWith(String prefix) 判断字符串前缀是否存在于集合中,基于 getNode 实现,只需存在前缀路径即可。

public boolean startsWith(String prefix) {
    if(prefix == null){
        throw new IllegalArgumentException("key is null");
    }
    TrieNode x = getNode(root, prefix);
    // x 不为空,代表存在这个前缀路径
    return x != null;
}

实现 remove 方法

public boolean remove(String key) {
    if(key == null){
        throw new IllegalArgumentException("key is null");
    }
    if(!contains(key)){
        return false;
    }
    root = remove(root, key, 0);
    size--;
    return true;
}
​
private TrieNode remove(TrieNode node, String key, int i) {
    if(i == key.length()){
        node.isEnd = false;
    }else{
        char c = key.charAt(i);
        node.next[c] = remove(node.next[c], key, i + 1);
    }
    
    if(node.isEnd) return node;
    
    for(int j = 0; j < R; i++){
        if(node.next[j] != null){
            return node;
        }
    }
    
    return null;
}

分析:以删除 xy 字符串为例。 数据结构之手写字典树

  1. 首先从根节点开始,判断当前索引i是否等于字符串长度。如果等于将 isEnd 置为 false ,代表逻辑删除该字符串。
  2. 显然不等于,进入 else 语句中,此时的 c.next[x] 指向下一个数组。
  3. 然后一直向下,指向 3 这个数组,也就是索引 i等于字符串长度的时候,将 isEnd 置为 false
  4. 然后判断 isEnd 是否为 true,为 true ,返回 node ,代表这个节点不能删。
  5. 然后执行 for 循环,判断当前节点数组存储的元素是否都为空,显然最后的数组存储的元素全为null 。如果有一个不为 null ,代表当前节点还有其他子节点。
  6. 如果前两项都不满足,代表当前节点没有子节点,那么就可以返回 null (代表删除当前节点),此 null 指向上一个节点也就是 y ,此时 next[y] == null
  7. 然后继续执行 5 6 7

数据结构之手写字典树

  1. 层层向上判断,然后发现 y 指向 null ,删除当前节点。返回 null ,此时 x 指向 null ,然后删除当前节点。
  2. 最后返回 null ,也就是当前树变为空了。

注意

  • 只要当前节点的 isEnd 不等于 false ,或者当前节点的 next 数组有一个元素不为空,就不能删除当前节点。
  • 以下代码就是来判断上述情况的。
if(node.isEnd) return node;
for(int j = 0; j < R; i++){
    if(node.next[j] != null){
        return node;
    }
}
return null;

为什么要这么麻烦,直接将 isEnd 置为 false 不就行了

  • 因为 startWith(String prefix) 会判断指定前缀是否在集合中。但是如果我们只进行逻辑删除,假设说集合中只存储 xyz,然后删除 xyz ,只是将 isEnd 置为 false ,然后判断 xy 是否在集合中,这个时候就会返回 true ,因为我们没有真正的删除。
  • 假设说我们存储 xy ,又存储 xyz ,此时删除了 xyz 字符串,实际上只删除了 z ,因为xyxyz 共用 xy ,所以我们不能删除 xy 。此时判断 x是否在集合中,返回 true 是正确的。

2.2. 实现匹配相关的方法

// 在 trie 存储的字符串中,寻找 query 的最长前缀
public String longestPrefixOf(String query){}

// 在 trie 存储的字符串中,寻找 query 的最短前缀
public String shortestPrefixOf(String query){}

// 搜索前缀为 prefix 的所有字符串
public List<String> keysWithPrefix(String prefix){}

// 通配符 匹配任意字符串
public List<String> keyWithPattern(String pattern){}

// 通配符 匹配任意字符

public boolean matches(String pattern){}

2.2.1. 实现查找最长最短前缀方法

首先实现查找字符串最短前缀方法 shortestPrefixOf(String query)

数据结构之手写字典树

假设说,我们存储在集合中的字符串为 ab abc abcde,查 abcd 的最短前缀,也就是 ab

如何查找?

  • 从根开始遍历,当遍历到第一个节点的 isEnd = true 时,即为我们要查找的字符串的最短前缀。
public String shortestPrefixOf(String query){
    if(query == null){
        throw new IllegalArgumentException("key is null");
    }
    // 从根节点开始遍历
    TrieNode p = root;
    int i = 0;
    while(i < query.length() && p != null){
        // isEnd = true,说明找到了
        if(p.isEnd){
            break;
        }
        // 否则继续向下查找
        char c = query.charAt(i);
        i++;
        p = p.next[c];
    }
    
    // 返回结果
    if(p != null && p.isEnd){
        // 前闭后开区间
        return query.substring(0, i);
    }

    // 否则返回 null
    return null;
}

为什么是返回结果是 query.subString(0, i)

  • 因为满足条件时,p指针指向第一个 isEnd = true 的节点,i 指向下一个 next

  • 又因为 subString(0, i) ,是前闭后后开区间,所以返回的是正确的结果。

实现查找字符串最长前缀方法 longestPrefixOf(String query)

假设说,我们存储在集合中的字符串为 ab abc abcde,查 abcd 的最长前缀,也就是 abc 。也就距离 abcd 最近的字符串。

我们只能从根节点开始查找,那么如何获取距离 query 最近的前缀字符串呢?

数据结构之手写字典树

  1. 我们可以再添加一个索引指针 len,让其初始为 len = i = 0
  2. 然后依次向下遍历,首先让 p, i 依次向下移动。
  3. p.isEnd = true 时,移动 len 指针,使 len = i
  4. 重复 2, 3 步,直到遍历结束。
  5. 那么此时 len 索引指针保存的就是最后一个 isEnd = true 的下一个索引了。
public String shortestPrefixOf(String query){
    if(query == null){
        throw new IllegalArgumentException("key is null");
    }
    // 从根节点开始遍历
    TrieNode p = root;
    int i = 0;
    int len = 0;
    while(i < query.length() && p != null){
        // isEnd = true
        if(p.isEnd){
            // 移动 len 指针至 i 的位置
            len = i;
        }
        // 否则继续向下查找
        char c = query.charAt(i);
        i++;
        p = p.next[c];
    }

    // 防止遍历结束时,p指针指向的节点的 isEnd 也为 true
    if(p != null && p.isEnd){
        len = i;
    }

    // 否则结果
    return query.substring(0 ,i);
}

*为什么当 p != null && p.isEnd 时,需要 len = i 操作呢?

  • 因为当遍历结束时,也就是 p 指针指向 abcd 字符串的末尾时,可能 abcd 字符串也存储在集合中了,那么此时 isEnd = true ,也就是说最长前缀是它本身。那么此时需要移动 len 指针至 i 处,才能保证最长前缀正确。

数据结构之手写字典树

2.2.2. 实现 keysWithPrefix 方法

// 返回前缀为 prefix 的所有字符串
public List<String> keysWithPrefix(String prefix){}

实现思路:

数据结构之手写字典树

假如说,我们返回以 a 为前缀的所有字符串:

  • 我们需要从根节点开始遍历,找到匹配 prefix 的最后节点,然后遍历 prefix.next 为起点的多叉树,找到所有符合条件的字符串即可。
// 搜索前缀为 prefix 的所有字符串
public List<String> keysWithPrefix(String prefix){
    LinkedList<String> res = new LinkedList<>();
    // 找到匹配 prefix 的那个节点
    TrieNode x = getNode(root, prefix);
    if(x == null){
        return res;
    }
    
    // DFS 遍历多叉树
    traverse(x, new StringBuilder(prefix), res);
    return res;
}

private void traverse(TrieNode node, StringBuilder path, LinkedList<String> res) {

}

主要实现 DFS 遍历多叉树 traverse 方法:

private void traverse(TrieNode node, StringBuilder path, LinkedList<String> res) {
    if(node == null){
        return;
    }

    if(node.isEnd){
        // 是一个字符串,添加
        res.add(path.toString());
    }

    for(char c = 0; c < R; c++){
        path.append(c);
        traverse(node.next[c], path, res);
        path.deleteCharAt(path.length() - 1);
    }
}

数据结构之手写字典树

  • 我们从 a 节点开始遍历多叉树
  • 判断当前是否为 null,然后判断 isEnd 是否为 true,为 true 说明是一个以 a 为前缀的字符串。
  • 然后开始遍历以 a.next 数组,也就是 256 个分叉。首先遍历 ASCII 码等于 0 的分叉, 朝路径中追加 ASCII = 0 的字符。此时 path = a + 0
  • 然后开始遍历这个分叉,发现 node == null ,说明没有这个分叉,然后直接返回,不需要遍历。因为此时 path = a + 0 ,我们需要维护路径的正确性,所以需要删除 ASCII = 0 的字符。
  • 然后一直循环。
  • 直到 ASCII = 98,也就是字符 b,然后遍历这个分叉,向 path 追加字符 b ,遇到 isEnd = true,向 res 中添加 ab
  • 然后一直遍历分叉,直到尾部,然后层层返回,重新回到 a 处。
  • 然后继续遍历,直到遍历完毕。

2.2.3. 实现 keyWithPattern 方法

// 通配符 . 匹配任意字符串
public List<String> keyWithPattern(String pattern){
    List<String> res = new LinkedList<>();
    traverse(root, new StringBuilder(), pattern, 0, res);
    return res;
}

实现思路:假如说匹配字符串为 a.e. 占一位。 数据结构之手写字典树

  • 那么我们可以匹配 adeafe
  • 首先判断当前节点是否为 null,如果为 null,代表没有这个分叉。
  • 然后当 i == pattern.length() 时,然后进一步判断是否是一个字符串,如果是则添加。
  • 然后判断当前字符是否为 .,如果为 .,则需要遍历当前节点的所有分叉,因为 . 代表任意字符。
  • 如果不是 .,只遍历当前节点即可。
// 定义:在以 node 为根的 `trie` 树中匹配 pattern[i..]
private void traverse(TrieNode node, StringBuilder path, String pattern, int i, List<String> res) {
    if(node == null){
        return;
    }

    // pattern 匹配完成
    if(i == pattern.length()){
        if(node.isEnd){
            res.add(path.toString());
        }
        return;
    }
    char c = pattern.charAt(i);
    if(c == '.'){
        for (char j = 0; j < R; j++){
            path.append(j);
            traverse(node.next[j], path, pattern, i + 1, res);
            path.deleteCharAt(path.length() - 1);
        }
    }else{
        path.append(c);
        traverse(node.next[c], path, pattern, i + 1, res);
        path.deleteCharAt(path.length() - 1);
    }

}

2.2.4. 实现 matches 方法

// 通配符 . 匹配任意字符
public boolean matches(String pattern){}

这个方法与 keyWithPattern 方法实现思路类似。我们可以直接调用 keyWithPattern 方法,通过判断集合是否存储有元素,即可实现是否包含有匹配 pattern 的字符串了。但是由于 keyWithPattern 返回的是一个集合,当匹配到的字符串很多时,有性能损失。

所以我们需要重新实现,只需要对实现 keyWitrhPattern 方法的 traverse 进行改造即可:

// 通配符 . 匹配任意字符
public boolean matches(String pattern){
    if(pattern == null){
        throw new IllegalArgumentException("key is null");
    }
    return matches(root, pattern, 0);
}

private boolean matches(TrieNode node, String pattern, int i) {
    if(node == null){
        return false;
    }

    // pattern 匹配完成
    if(i == pattern.length()){
        return node.isEnd;
    }
    char c = pattern.charAt(i);
    if(c == '.'){
        for (char j = 0; j < R; j++){
            if(matches(node.next[j], pattern, i + 1)){
                return true;
            }
        }
    }
    //没有遇到通配符
    return matches(node.next[c], pattern, i + 1);
}

字典树完整代码