likes
comments
collection
share

同事有话说:如何高效的统计出Bitmap中1的个数

作者站长头像
站长
· 阅读数 34

今天,我活泼可爱的同事们在群里欢快地讨论着一个问题:如何高效的统计出Bitmap中1的个数?那么问题来了,Bitmap是什么?

Bitmap

A Bitmap or bitset is an array of zeros and ones. A bit in a bitset can be set to either 0 or 1, and each position in the array is referred to as an offset. Operations such as logical AND, OR, XOR, etc. and other bitwise operations are fair game for Bitmaps.

以上摘自文章 REDIS BITMAPS – FAST, EASY, REALTIME METRICS ,发现还是原汁原味的好理解(我自己翻译的都是什么鬼),其实就是一个二进制数组,数组元素只能是0或者1,你可以对位图执行逻辑与、或、异或等位运算。

所以,其实我们可以把问题简化一下:

如何高效的统计出二进制数中1的个数

这个就好理解多了嘛,马上安排。俗话说的好,“没什么好想法的时候,直接暴破”,于是就有了第一种解法:

public static int count(int n) {
    int count = 0;
    while (n > 0) {
        if ((n&1) == 1) { // 计算最低位是不是为1
            count++;
        }
        n >>>= 1; // 无符号右移1位,把最低位挤出去
    }
    return count;
}

这个解法简单又粗暴,时间复杂度为O(N),N就是二进制数字的长度,比如二进制数字100100,N就等于6.于是问题又来了,如果这个二进制数长度很长,比如10000,但是1的数量只有10个,使用第一种解法的话我们需要循环计算10000次,但实际执行count++的操作只有10次,剩余的9990次似乎有些浪费,我们能不能做到1有多少个就循环计算多少遍?于是第二种解法应运而生:

public static int count1(int n) {
    int count = 0;
    // 每次将二进制数的最低位 1 变为 0,直到该二进制数变到 0 为止
    while (n > 0) {
        count++;
        n &= n-1;
    }
    return count;
}

其实就是利用了n&(n-1)能够剔除最低位1(变为0),具体的解释可以看下面这张图,我也不知道图的出处,群里面保存的,图侵删... 同事有话说:如何高效的统计出Bitmap中1的个数 这种解法相对于第一种解法来说效率有所提升,由于每次都是n&(n-1),初始时n有几个1就循环几次即可。但是问题双来了,如果这个二进制数长度非常长的同时,1的数量又非常多,比如10111......1110中间省略100万个1,那按照解法2不还是得循环超过百万次,不能忍吧,所以还有没有更好的解法?

如果对Redis熟悉的童鞋,可能会联想到Redis也提供了Bitmap的数据结构,它提供了一个叫做bitcount的命令可以用来获取字符串从start字节到end字节比特位值为1的数量(Redis的Bitmap底层结构就是字符串,即sds结构,感兴趣的推荐阅读 Redis 5设计与源码分析 第2章,还有老钱的这篇 Redis 字符串精致的内部结构 也值得一看,介绍Redis Bitmap的文章能google到很多,可以看下这篇 Redis中bitmap的妙用 ),所以Redis是怎么实现计算Bitmap中比特位值为1的数量的?

Redis bitcount 命令

这里需要先介绍一下 variable-precision swar 算法。本节是对 Redis 5 设计与源码分析 一书 11.5.4 "bitcount命令"小章节的整理,我其实更推荐你直接去翻书看一下。

variable-precision swar 算法

十六进制二进制备注
0x555555550101 0101 0101 0101 0101 0101 0101 0101奇数位为1,偶数位为0
0x333333330011 0011 0011 0011 0011 0011 0011 0011每两位为1,两位为0
0x0F0F0F0F0000 1111 0000 1111 0000 1111 0000 1111每个字节低4位为1,高4位为0
0x010101010000 0001 0000 0001 0000 0001 0000 0001每个字节最后一位为1

观察上表中的几个数,以这几个数作为掩码参与计算:

int swar(uint32_t i) {
	// 第一步:计算每2位二进制数中1的个数
    i = (i & 0x55555555) + ((i>>1) & 0x55555555);
    // 第二步:计算每4位二进制数中1的个数
     i = (i & 0x33333333) + ((i>>2) & 0x33333333);
     // 第三步:计算每8位二进制数中1的个数
     i = (i & 0x0F0F0F0F) + ((i>>4) & 0x0F0F0F0F);
     // 第四步:将每8位当做一个int8的整数,然后相加求和
     i = (i * 0x01010101) >> 24;
     return i;
}

我本来想尝试解释下这个算法的原理,但感觉写来写去都不如国外这个大神的解释容易理解 How does this algorithm to count the number of set bits in a 32-bit integer work? 墙裂推荐大家去仔细阅读下这个回答。它的思想就是分而治之,最后合流汇总,跟归并排序的思想有点相像。

源码实现

源码传送门:github.com/redis/redis…

bitcount 命令的实现采用了查表法和swar算法相结合的方式计算。所谓“查表法”是定义一个数组,数组的各项为十进制 0 ~ 255 中所含1的数量,定义如下:

static const unsigned char bitsinbyte[256] = {0, 1, 1, 2, 1, 2, 2, 3, 1, ......}

另外,CPU一次性可以读取8个字节的内存值,因为swar算法一次性是处理4字节的内容,所以要先将非4字节整数倍地址的字节特殊处理,处理方法如下:

while ((unsigned long) p & 3 && count) { // CPU一次性读取8字节,如果4字节跨了两个8字节,需要读取两次才可以读取,所以考虑4字节对齐,只需读取一次就可以读取4字节数据。
	bits += bitsinbyte[*p++]; // 查表法直接获取当前值中1的数量,有没有让你联想到java的Integer类有个类似的缓存机制
    count--;  // 待处理字节数--
}

如果你对 “字节对齐” 概念感兴趣的话,这里有篇文章值得一看 C语言字节对齐问题详解

当处理完前面最多3个可能的字节之后,便采用swar算法来获取1的数量:

p4 = (uint32_t*) p; // 4字节
while (count>=28) { // 每次处理28个字节
	uint32_t aux1, aux2, aux3, aux4, aux5, aux6, aux7;
    
    aux1 = *p4++; // 一次读取4字节
    aux2 = *p4++; // 一次读取4字节
    ...
    aux7 = *p4++;
    count -= 28; // 当前共处理了 4*7=28 个字节,所以总长度需要减28字节
    
    aux1 = aux1 - ((aux1 >> 1) & 0x55555555);
    aux1 = (aux1 & 0x33333333) + ((aux1 >> 2) & 0x33333333);
    aux2 = aux2 - ((aux2 >> 1) & 0x55555555);
    aux2 = (aux2 & 0x33333333) + ((aux2 >> 2) & 0x33333333);
    aux3 = aux3 - ((aux3 >> 1) & 0x55555555);
    aux3 = (aux3 & 0x33333333) + ((aux3 >> 2) & 0x33333333);
    aux4 = aux4 - ((aux4 >> 1) & 0x55555555);
    aux4 = (aux4 & 0x33333333) + ((aux4 >> 2) & 0x33333333);
    aux5 = aux5 - ((aux5 >> 1) & 0x55555555);
    aux5 = (aux5 & 0x33333333) + ((aux5 >> 2) & 0x33333333);
    aux6 = aux6 - ((aux6 >> 1) & 0x55555555);
    aux6 = (aux6 & 0x33333333) + ((aux6 >> 2) & 0x33333333);
    aux7 = aux7 - ((aux7 >> 1) & 0x55555555);
    aux7 = (aux7 & 0x33333333) + ((aux7 >> 2) & 0x33333333);
    bits += ((((aux1 + (aux1 >> 4)) & 0x0F0F0F0F) +
                ((aux2 + (aux2 >> 4)) & 0x0F0F0F0F) +
                ((aux3 + (aux3 >> 4)) & 0x0F0F0F0F) +
                ((aux4 + (aux4 >> 4)) & 0x0F0F0F0F) +
                ((aux5 + (aux5 >> 4)) & 0x0F0F0F0F) +
                ((aux6 + (aux6 >> 4)) & 0x0F0F0F0F) +
                ((aux7 + (aux7 >> 4)) & 0x0F0F0F0F))* 0x01010101) >> 24;
}

当count的数量小于28之后,便可以用查表法计算出剩余的二进制1的数量了。

p = (unsigned char*)p4;
while(count--) bits += bitsinbyte[*p++];
return bits;

后记

分析完毕,把总结扔进群聊里,同事们直呼内行,并忍不住点了个赞🐶

这里也照例给一下参考到的资源:

转载自:https://juejin.cn/post/6876043235765174286
评论
请登录