likes
comments
collection
share

手撕 LFU 缓存

作者站长头像
站长
· 阅读数 25

大家好,我是 方圆。LFU 的缩写是 Least Frequently Used,简单理解则是将使用最少的元素移除,如果存在多个使用次数最小的元素,那么则需要移除最近不被使用的元素。LFU 缓存在 LeetCode 上是一道困难的题目,实现起来并不容易,所以决定整理和记录一下。如果大家想要找刷题路线的话,可以参考 Github: LeetCode

LFU 缓存

LeetCode 原题:460. LFU 缓存 困难

思路:我们需要维护两个 HashMap,分别为 keyNodeMapaccessNodeMap,它们的职责如下:

  • keyNodeMap: key 为 put 的值的 key 值,value 为该 key 对应的节点,那么我们便可以通过这个 map 以 O(1) 的时间复杂度 get 到对应的值

  • accessNodeMap: key 为访问次数,各个节点被 put 和 get 都会使节点 accessNum 访问次数加 1,value 为该访问次数下的 循环双向链表的头节点,通过双向链表我们能以 O(1) 的时间复杂度将节点移除。我们定义 在相同访问次数下,越早插入的节点越靠近双向链表的尾端,在进行节点移除时,会将尾节点移除。

为了更好的理解两个 map 与链表节点的关系,我们用下图对容量为 3 的缓存进行表示,其中绿色代表链表节点,节点中各个值对应的字段为 key, value, accessNum

手撕 LFU 缓存

除此之外,我们要定义一个 minAccessNum 的字段来维护当前缓存中最小的访问次数,这样我们就能够在时间复杂度为 O(1) 的情况下在 accessNodeMap 中获取到对应访问次数的双向链表。

大致的方向确定了,我们需要再想一下具体的实现:

get 方法:我们首先去 keyNodeMap 中拿,没有的话返回 -1 即可。如果有对应的 key 的话,那么我们需要将对应节点的访问次数加 1,并需要改变它所在 accessNodeMap 中的位置:首先需要断开它与原链表的连接,之后加入到新的链表中,如果在 accessNodeMap 中有对应次数的链表,那么我们需要把它插入到该链表的 头节点;如果没有对应访问次数的双向链表的话,我们需要创建该访问次数的链表,并以该节点为头节点,维护在 accessNodeMap 中。这里需要注意,我们要对 minAccessNum 进行 更新,如果该节点的访问次数和 minAccessNum 相等,并且该节点在原来链表删除后,该访问次数下的链表中不存在其他任何节点,那么 minAccessNum 也要加 1。

put 方法:我们同样也需要在 keyNodeMap 中判断是否存在,存在的话需要将值进行覆盖,之后的处理逻辑与 get 方法一致。如果不存在的话,我们这里需要判断缓存的容量 是否足够,足够的话比较简单,先将其 put 到 keyNodeMap 中,再在 accessNodeMap 中将其插入到 key 为 1 的双向链表的头节点即可,这里要注意更改 minAccessNum 为 1,因为新插入的节点一定是访问次数最少的;如果不够的话那么先要 将最少使用的节点移除(在两个 map 中都要移除),在 accessNodeMap 中进行移除时,需要根据 minAccessNum 获取对应的双向链表,移除它的尾节点。在尾节点移除完之后,执行的逻辑和上述容量足够时执行插入节点的逻辑一致。

具体实现已经比较清楚了,直接上代码吧,大家可以关注一下注释信息:

class LFUCache {

    /**
     * 双向链表节点
     */
    static class Node {

        Node left;

        Node right;

        int key;

        int value;

        int accessNum;

        public Node(int key, int value, int accessNum) {
            this.key = key;
            this.value = value;
            this.accessNum = accessNum;
        }

    }

    private HashMap<Integer, Node> keyNodeMap;

    private HashMap<Integer, Node> accessNodeMap;

    private int minAccessNum;

    private int capacity;

    public LFUCache(int capacity) {
        keyNodeMap = new HashMap<>(capacity);
        accessNodeMap = new HashMap<>(capacity);
        minAccessNum = 0;
        this.capacity = capacity;
    }

    public int get(int key) {
        if (keyNodeMap.containsKey(key)) {
            Node node = keyNodeMap.get(key);
            // 如果所在链表只有一个节点的话,那么直接将该访问次数的链表删掉
            if (node == node.right) {
                accessNodeMap.remove(node.accessNum);
                // 维护缓存中最小的访问次数
                if (minAccessNum == node.accessNum) {
                    minAccessNum++;
                }
            } else {
                // 断开与原链表的连接
                node.left.right = node.right;
                node.right.left = node.left;
                // 如果该节点是头节点的话,那么需要替换为它的下一个节点作为头节点
                if (node == accessNodeMap.get(node.accessNum)) {
                    accessNodeMap.put(node.accessNum, node.right);
                }
            }

            // 增加后的访问次数链表看看有没有
            node.accessNum++;
            if (accessNodeMap.containsKey(node.accessNum)) {
                Node target = accessNodeMap.get(node.accessNum);
                // 插入头节点
                insertHead(node, target);
            } else {
                // 没有的话,直接 put 即可
                accessNodeMap.put(node.accessNum, node);
                // 单节点循环链表
                node.left = node;
                node.right = node;
            }

            return node.value;
        } else {
            return -1;
        }
    }

    public void put(int key, int value) {
        if (keyNodeMap.containsKey(key)) {
            Node node = keyNodeMap.get(key);
            node.value = value;
            // 执行get方法
            get(key);
        } else {
            Node node = new Node(key, value, 1);
            if (keyNodeMap.size() == capacity) {
                // 容量不够需要将最少使用的节点移除
                Node oldNodeHead = accessNodeMap.get(minAccessNum);

                Node tail = oldNodeHead.left;
                // 如果所在链表只有一个节点的话,那么直接将该访问次数的链表删掉
                if (tail.right == tail) {
                    accessNodeMap.remove(tail.accessNum);
                } else {
                    // 断开与原链表的连接
                    tail.left.right = tail.right;
                    tail.right.left = tail.left;
                    // 如果该节点是头节点的话,那么需要替换为它的下一个节点作为头节点
                    if (oldNodeHead == accessNodeMap.get(tail.accessNum)) {
                        accessNodeMap.put(tail.accessNum, tail.right);
                    }
                }
                keyNodeMap.remove(tail.key);
            }
            // 这样就有有足够的容量了
            keyNodeMap.put(key, node);

            // 是否有对应的链表
            if (accessNodeMap.containsKey(node.accessNum)) {
                // 插入头节点
                insertHead(node, accessNodeMap.get(node.accessNum));
            } else {
                // 没有对应的链表 直接插入
                accessNodeMap.put(node.accessNum, node);
                node.left = node;
                node.right = node;
            }
            minAccessNum = 1;
        }
    }

    private void insertHead(Node node, Node target) {
        // 拼接到该链表头,并构建循环双向链表
        node.right = target;
        node.left = target.left;
        target.left.right = node;
        target.left = node;
        // 覆盖头节点
        accessNodeMap.put(node.accessNum, node);
    }

}

需要注意的是:

  1. 因为我们维护的是循环双向链表,所以在插入头节点时注意尾节点和头节点的引用关系

  2. 因为我们在 accessNodeMap 中维护的是头节点,所以当我们将链表的头结点进行移除时,需要将头节点的下一个节点作为新的头节点保存在 accessNodeMap

针对第二点我们可以做一个优化,每当第一次生成双向链表的时候,我们创建一个哨兵节点作为头节点,那么这样我们就无需在头节点被移除后再将新的头节点插入 accessNodeMap 中进行覆盖了,始终保持 accessNodeMap 中 value 值保存的是哨兵节点,最终代码如下:

class LFUCache {
    /**
     * 双向链表节点
     */
    static class Node {

        int key, value;

        Node pre, next;

        int accessNum;

        public Node(int key, int value, int accessNum) {
            this.key = key;
            this.value = value;
            this.accessNum = accessNum;
        }
    }

    /**
     * 记录访问最小的值
     */
    private int minAccessNum;

    private final int capacity;

    private final HashMap<Integer, Node> accessNodeMap;

    private final HashMap<Integer, Node> keyNodeMap;

    public LFUCache(int capacity) {
        this.capacity = capacity;
        accessNodeMap = new HashMap<>(capacity);
        keyNodeMap = new HashMap<>(capacity);

        // 初始化访问次数为 1 的哨兵节点
        minAccessNum = 1;
        accessNodeMap.put(minAccessNum, initialSentinelNode(minAccessNum));
    }

    public int get(int key) {
        if (keyNodeMap.containsKey(key)) {
            Node node = keyNodeMap.get(key);
            // 找到新的位置
            insertIntoNextSentinel(node);

            return node.value;
        }

        return -1;
    }

    public void put(int key, int value) {
        if (keyNodeMap.containsKey(key)) {
            Node node = keyNodeMap.get(key);
            node.value = value;

            insertIntoNextSentinel(node);
        } else {
            if (keyNodeMap.size() == capacity) {
                // 移除最老的节点
                removeEldest();
            }
            // 新加进来的肯定是最小访问次数 1
            minAccessNum = 1;
            Node newNode = new Node(key, value, minAccessNum);

            // 插入到头节点
            insertIntoHead(newNode, accessNodeMap.get(minAccessNum));
            keyNodeMap.put(key, newNode);
        }
    }

    /**
     * 插入下一个链表中
     */
    private void insertIntoNextSentinel(Node node) {
        // 在原来的位置移除
        remove(node);
        // 尝试更新 minAccessNum
        tryToIncreaseMinAccessNum(node.accessNum++);
        // 获取增加 1 后的哨兵节点
        Node nextCacheSentinel = getSpecificAccessNumSentinel(node.accessNum);
        // 插入该链表的头节点
        insertIntoHead(node, nextCacheSentinel);
    }

    /**
     * 在原链表中移除
     */
    private void remove(Node node) {
        node.pre.next = node.next;
        node.next.pre = node.pre;
        node.next = null;
        node.pre = null;
    }

    /**
     * 尝试更新 minAccessNum
     */
    private void tryToIncreaseMinAccessNum(int accessNum) {
        // 原访问次数的哨兵节点
        Node originSentinel = accessNodeMap.get(accessNum);
        // 如果只剩下哨兵节点的话,需要看看需不需要把 minAccessNum 增加 1
        if (originSentinel.next == originSentinel && originSentinel.accessNum == minAccessNum) {
            minAccessNum++;
        }
    }

    /**
     * 获取指定访问次数的哨兵节点
     */
    private Node getSpecificAccessNumSentinel(int accessNum) {
        if (accessNodeMap.containsKey(accessNum)) {
            return accessNodeMap.get(accessNum);
        } else {
            // 没有的话得初始化一个
            Node nextCacheSentinel = initialSentinelNode(accessNum);
            accessNodeMap.put(accessNum, nextCacheSentinel);

            return nextCacheSentinel;
        }
    }

    /**
     * 生成具体访问次数的哨兵节点
     */
    private Node initialSentinelNode(int accessNum) {
        Node sentinel = new Node(-1, -1, accessNum);
        // 双向循环链表
        sentinel.next = sentinel;
        sentinel.pre = sentinel;

        return sentinel;
    }

    /**
     * 插入头节点
     */
    private void insertIntoHead(Node node, Node nextCacheSentinel) {
        node.next = nextCacheSentinel.next;
        nextCacheSentinel.next.pre = node;
        nextCacheSentinel.next = node;
        node.pre = nextCacheSentinel;
    }

    /**
     * 如果容量满了的话,需要把 minAccessNum 访问次数的尾巴节点先移除掉
     */
    private void removeEldest() {
        Node minSentinel = accessNodeMap.get(minAccessNum);

        Node tail = minSentinel.pre;
        tail.pre.next = tail.next;
        minSentinel.pre = tail.pre;
        keyNodeMap.remove(tail.key);
    }
}

巨人的肩膀