手把手教你写LRU算法

一. 介绍

什么是缓存? 我们为什么需要它?

缓存是操作系统按顺序维护的页面缓冲区,以避免昂贵的内存访问操作。 通常,缓存比主内存快得多。 由于高速缓存的大小与主内存相比非常小,因此有可能需要在高速缓存和主内存之间交换页面。 每当在高速缓存中找不到需要从主存储器中导入的页面时,该页面就称为高速缓存未命中。 LRU缓存实现是一种FIFO策略,用于最大程度地减少缓存未命中。

当引用了缓存中不存在的新页面时,如果缓存已满,则使用最近最少使用高速缓存方案来删除最近最少使用的页面。 如果页面在缓存中,则将其移到缓存的开头。 因此,缓存按引用顺序存储页面。

区间插入过程描述

LRU缓存使用以下数据结构实现:

  1. 双向链表:使用双向链表将高速缓存实现,链表大小就是缓存的大小。 最近使用的页面在链表的前端附近,最少使用的页面在链表的后端附近。

  2. HashMap:它将节点的地址存储在缓存中。 它有助于在O(1)时间中找到节点地址。

具体的使用方法如下:

CustomizedLRUCache cache = new CustomizedLRUCache(2);
cache.put(1, 1); // 缓存是 {1=1}
cache.put(2, 2); // 缓存是 {1=1, 2=2}
cache.get(1);    // 返回 1
cache.put(3, 3); // 该操作会使得关键字 2 作废,缓存是 {1=1, 3=3}
cache.get(2);    // 返回 -1 (未找到)
cache.put(4, 4); // 该操作会使得关键字 1 作废,缓存是 {4=4, 3=3}
cache.get(1);    // 返回 -1 (未找到)
cache.get(3);    // 返回 3
cache.get(4);    // 返回 4

算法流程

当引用页面时:

  1. 检查页面是否在缓存中,如果页面在内存中,移除节点并将其置于链表的最前面。
  2. 如果页面不在缓存中,如果队列未满,则在队列的前面添加一个新节点,并更新哈希中的相应节点地址。 如果队列已满,请从队列的后面移除节点,然后将新节点添加到队列的前面。

二. 实现

我们可以借助Java原有的数据结构LinkedHashMap来实现,也可以自定义数据结构来实现。当然对于后者,我是更推荐的!先来介绍一下自定义的数据结构实现方法。

(1) DoubleLinkedList + HashMap

这里我们将页面定义为Node类型的数据结构,其中Node类中包含键和值两个属性:

class Node {
    int key, val;
    Node pre, next;

    public Node(int key, int val) {
        this.key = key;
        this.val = val;
    }
}

定义好了页面,我们现在来定义存储页面的链表,双向链表:

class DoubleLinkedList {
    private Node head, tail;
    private int capacity;

    public DoubleLinkedList() {
        this.head = new Node(0, 0);
        this.tail = new Node(0, 0);
        this.head.next = this.tail;
        this.tail.pre = this.head;
        this.capacity = 0;
    }

    // 在链表的最后添加节点
    public void addLast(Node node) {
        node.pre = tail.pre;
        node.next = tail;
        tail.pre.next = node;
        tail.pre = node;
        capacity += 1;
    }

    // 删除某个节点
    public void remove(Node node) {
        node.pre.next = node.next;
        node.next.pre = node.pre;
        capacity -= 1;
    }

    // 移除首节点(后续的换出最少使用的节点使用)
    public Node removeFirst() {
        if (head.next == tail)
            return null;
        Node first = head.next;
        remove(first);
        return first;
    }

    // 返回当前容器的大小
    public int size() {
        return capacity;
    }
}

所以,我们有了基础的数据结构。有了双向链表的实现,我们只需要在 LRU 算法中把它和哈希表结合起来即可,先搭出代码框架:

class CustomizedLRUCache {
    private DoubleLinkedList cache;
    private HashMap<Integer, Node> map;
    private int capacity;

    public CustomizedLRUCache(int capacity) {
        cache = new DoubleLinkedList();
        map = new HashMap<>();
        this.capacity = capacity;
    }

    /**
     * LRUCache的put()方法
     *
     * @param key
     * @param value
     */
    public void put(int key, int value) {
        ...
    }

    /**
     * LRUCache的get()方法
     *
     * @param key
     * @return
     */
    public int get(int key) {
        ...
    }
}

先不慌去实现 LRU 算法的 getput 方法。由于我们要同时维护一个双链表 cache 和一个哈希表 map,很容易漏掉一些操作,比如说删除某个 key 时,在 cache 中删除了对应的 Node,但是却忘记在 map 中删除 key

解决这种问题的有效方法是:在这两种数据结构之上提供一层抽象 API

说的有点玄幻,实际上很简单,就是尽量让 LRU 的主方法 getput 避免直接操作 mapcache 的细节。我们可以先实现下面几个函数:

/**
* 将已有的键刷新为访问过
*
* @param key
*/
private void makeRecently(int key) {
    Node x = map.get(key);
    cache.remove(x);
    cache.addLast(x);
}

/**
* 使新的节点变为最新访问
*
* @param key
* @param val
*/
private void addRecently(int key, int val) {
    Node x = new Node(key, val);
    cache.addLast(x);
    map.put(key, x);
}

/**
* 遍历当前map中的页面
*/
public void currentState() {
    Iterator<Integer> iterator = map.keySet().iterator();
    while (iterator.hasNext()) {
        int key = iterator.next();
        Node node = map.get(key);
        System.out.print("[" + key + "->" + node.val + "]");
    }
    System.out.println();
}

然后我们就来实现getput方法

/**
     * LRUCache的get()方法
     *
     * @param key
     * @return
     */
public int get(int key) {
    if (!map.containsKey(key))
        return -1;
    makeRecently(key);
    return map.get(key).val;
}
/**
     * LRUCache的put()方法
     *
     * @param key
     * @param value
     */
public void put(int key, int value) {
    // 更新已有的键值
    if (map.containsKey(key)) {
        Node x = map.get(key);
        cache.remove(x);
        map.remove(key);
        addRecently(key, value);
        return;
    }
    // 缓存已满,置换出最少使用的键值对
    if (capacity <= cache.size()) {
        Node node = cache.removeFirst();
        map.remove(node.key);
    }
    // 添加新的键值对
    addRecently(key, value);
}

最后将代码进行整合,得到如下代码:

import java.util.HashMap;
import java.util.Iterator;

class Node {
    int key, val;
    Node pre, next;

    public Node(int key, int val) {
        this.key = key;
        this.val = val;
    }
}

class DoubleLinkedList {
    private Node head, tail;
    private int capacity;

    public DoubleLinkedList() {
        this.head = new Node(0, 0);
        this.tail = new Node(0, 0);
        this.head.next = this.tail;
        this.tail.pre = this.head;
        this.capacity = 0;
    }

    public void addLast(Node node) {
        node.pre = tail.pre;
        node.next = tail;
        tail.pre.next = node;
        tail.pre = node;
        capacity += 1;
    }

    public void remove(Node node) {
        node.pre.next = node.next;
        node.next.pre = node.pre;
        capacity -= 1;
    }

    public Node removeFirst() {
        if (head.next == tail)
            return null;
        Node first = head.next;
        remove(first);
        return first;
    }

    public int size() {
        return capacity;
    }
}

class CustomizedLRUCache {
    private DoubleLinkedList cache;
    private HashMap<Integer, Node> map;
    private int capacity;

    public CustomizedLRUCache(int capacity) {
        cache = new DoubleLinkedList();
        map = new HashMap<>();
        this.capacity = capacity;
    }

    /**
     * LRUCache的put()方法
     *
     * @param key
     * @param value
     */
    public void put(int key, int value) {
        // 更新已有的键值
        if (map.containsKey(key)) {
            Node x = map.get(key);
            cache.remove(x);
            map.remove(key);
            addRecently(key, value);
            return;
        }
        // 缓存已满,置换出最少使用的键值对
        if (capacity <= cache.size()) {
            Node node = cache.removeFirst();
            map.remove(node.key);
        }
        // 添加新的键值对
        addRecently(key, value);
    }

    /**
     * LRUCache的get()方法
     *
     * @param key
     * @return
     */
    public int get(int key) {
        if (!map.containsKey(key))
            return -1;
        makeRecently(key);
        return map.get(key).val;
    }

    /**
     * 将已有的键刷新为访问过
     *
     * @param key
     */
    private void makeRecently(int key) {
        Node x = map.get(key);
        cache.remove(x);
        cache.addLast(x);
    }

    /**
     * 使新的节点变为最新访问
     *
     * @param key
     * @param val
     */
    private void addRecently(int key, int val) {
        Node x = new Node(key, val);
        cache.addLast(x);
        map.put(key, x);
    }

    public void currentState() {
        Iterator<Integer> iterator = map.keySet().iterator();
        while (iterator.hasNext()) {
            int key = iterator.next();
            Node node = map.get(key);
            System.out.print("[" + key + "->" + node.val + "]");
        }
        System.out.println();
    }
}

编写main方法对算法进行测试:

public class Main {
    public static void main(String[] args) {
        CustomizedLRUCache cache = new CustomizedLRUCache(2);
        cache.put(1, 1);
        cache.put(2, 2);
        cache.currentState();
        cache.put(1, 3);
        cache.currentState();
        cache.put(3, 3);
        cache.currentState();
        cache.put(4, 4);
        cache.currentState();
        cache.get(3);
        cache.put(1, 1);
        cache.currentState();
    }
}

使用cache.currentState();来输出链表中的状态,这样可以实时看到链表中数据的变化。

输出的结果我们可以看到:

[1->1][2->2]
[1->3][2->2]
[1->3][3->3]
[3->3][4->4]
[1->1][3->3]

(2) LinkedHashMap

import java.util.LinkedHashMap;

class LRUCache {

    private LinkedHashMap<Integer, Integer> cache;
    private int capacity;
    public LRUCache(int capacity) {
        this.capacity = capacity;
        cache = new LinkedHashMap<>();
    }

    public int get(int key) {
        if(!cache.containsKey(key))
            return -1;
        makeRecently(key);
        return cache.get(key);
    }

    public void put(int key, int value) {
        if(cache.containsKey(key)){
            cache.put(key, value);
            makeRecently(key);
            return;
        }
        if(capacity <= cache.size()){
            int oldKey = cache.keySet().iterator().next();
            cache.remove(oldKey);
        }
        cache.put(key, value);
    }

    public void makeRecently(int key){
        int val = cache.get(key);
        cache.remove(key);
        cache.put(key, val);
    }
}

这个方法实现还是比较简单!


本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!