在准备大语言模型的语料的过程中,一个很重要的步骤就是去除重复的语料。特别是像CommonCrawl这样的语料来源,质量参差不齐,重复的比例也非常高。 去重处理往往是基于所谓的Jaccard相似度,也就是看文章A和文章B中重复的词的占比,用集合的方式表示也就是

\[J(A, B) = \frac{|A \cap B|}{| A \cup B |} = \frac{|A \cap B|}{| A | + | B | - |A \cap B|}\]

比如我们可以选取0.8作为阈值,即两篇文章有80%的内容重合,就只保留其中一篇。注意这个相似度是不考虑词出现的顺序的。 这样做通常可以提升语言模型的训练效果,具体可以参考Deduplicating Training Data Makes Language Models Better 这篇论文。


Minhash近似

计算两篇文章Jaccard相似度最简单的方法就是直接计算两篇文章里出现词语的交并集。 这种方法的缺点也很明显,如果文章很长的话,计算量就会变大。 对于小规模的语料这不是什么问题,而对于海量数据而言,无论计算还是内存消耗都将变得不可承受。 Minhash可以实现把一篇文章用一个较短的signature表示,这个signature有个很好的性质: 两个minhash signature的Jaccard相似度和原始文本的Jaccard相似度在概率上是一致的。 这就实现了用较小的计算和内存开销来检测相似文章。

具体为什么可以用两个较小的signature来实现对Jaccard相似度的计算呢?下面我们就来看一看minhash算法的原理。

Minhash算法会取一个随机哈希函数,将文章里每个词转换成一个int类型的哈希值,然后选取所有哈希值中最小的那个。 经典实现里,每个哈希函数都会产生一个最小哈希值,最后把这些哈希值拼起来就组成了这个文章的signature。

举个例子,假设两篇文章的词集合分别如下,

  • Set A: (“Apple”, “Fruit”, “Banana”, “Grape”, “Melon”, “Strawberry”)
  • Set B: (“Fruit”, “Cherry”, “Apple”, “Melon”, “Pear”, “Cucumber”, “Blueberry”)

其中加粗的3个词是两篇文章所共有的,总共有10个唯一的词,所以它们的Jaccard相似度就是3/10。

现在来看一下,如果我们的minhash只使用一个哈希函数,对两个集合进行操作,有多大的可能性这两个集合的minhash签名是一样的呢?

Union(A, B) = (“Apple”, “Fruit”, “Banana”, “Grape”, “Melon”, “Strawberry”, “Cherry”, “Pear”, “Cucumber”, “Blueberry”)

考虑到minhash操作一次相当于对这两个文章的并集做随机shuffle(随机哈希),然后取第一个元素。 那么对上面这个集合随机shuffle,有多大可能性会选到其中一个加粗的元素呢?答案是3/10。 这也就是说,单个minhash值相同的概率等于这两个集合的Jaccard相似度。

进一步说,如果生成多个minhash值,组成一个signature向量,那么只要数一下两个向量中对应位置有几个值相等, 除以向量长度,就能得到对原始集合Jaccard相似度的一个近似。 (因为如果Jaccard相似度是3/10,我们就会期望signature向量里的N个元素里会有$\frac{3}{10} N$个元素是一致的。) 从这里也可以看出,这个向量越大,那么近似也就越准确。

(以上这个直观例子的解释取自MinHash Tutorial with Python Code 这篇文章。)

实际操作中通常会使用N-gram先组合多个词,比如Falcon用到了5-gram, Gopher用了13-gram。


Minhash的高效计算

计算多个minhash signature最简单的方式就是使用多个哈希函数,比如说,如果我们要生成一个长度为128的signature向量, 我们就使用128个哈希函数,每个函数都对集合内每个元素做一次哈希操作,选取最小值,最后组成这个signature。这也是Spark中MinHashLSH的做法。 这个方法的问题在于,哈希操作通常比较耗时,不太适合应用于大规模文档特别是这个signature向量比较大的情况。 另一种相对高效的方法是只使用一个哈希函数,将这个哈希函数的结果$h(x)$用一系列的随机permutation(\(a\)和\(b\)) 映射到多个hash值,最后再取最小值:

\[\begin{align*} & ph(x_i) = [a \cdot h(x_i) + b] \% c \\ & Minhash_{doc} = \min_i ph(x_i) \end{align*}\]

这也是datasketch使用的算法。 实测中基于permutation的算法会比使用多个哈希函数快一倍左右。


LSH去重和性能优化

现在我们有了一个对Jaccard相似度的低维近似,意味着两个文档相似度的计算量将大大降低。 然而,下一个问题是在文档数量巨大的情况下,计算两两minhash相似度仍会是不可行的。 如果我们有3亿个文档,那么两两计算将到达3亿平方这样的量级。

解决这个问题的基本思路是先找到可能相似的一组文档,然后在这个文档集合内部做进一步的两两相似度计算。 比如说,我们可以对minhash值做一个倒排索引, 每次考察一个新文档的时候,从索引中找到至少有一个minhash值相同的一组文档,再去和这里面的每个文档做相似度计算。 Datasketch提供了建立minhash索引的工具。Spark的MinHashLSH本质上也是使用了这样的方法。 略不同的是,LSH不是去建立索引,而是把包含相同值的文档shuffle到同一个桶里,然后在这每个桶里可以做进一步的两两计算。

“至少一个Minhash值相同”可能仍然导致太多的候选pair,很自然的想法是选取signature里多个连续的minhash值, 只有这些连续值相同的文档才会被选中去做进一步的相似度计算。这样必然会牺牲一些recall,也就是某些本来相似的文档并没有被选到一起。 实际上有一个公式来计算这个概率。假设signature向量被分成了\(b\)份(band),每一个band包含\(r\)个minhash值, (也就是signature向量总长为\(b\times r\)),并且假设两个文档的Jaccard相似度是\(s\),

  1. 其中某一个band里\(r\)个值完全一致的概率是\(s^r\)
  2. band里至少有一个minhash值不相同的概率是\(1-s^r\)
  3. 所有\(b\)个band里,每个band都不相同的概率是\((1-s^r)^b\)
  4. 所以,这两个文档至少有一个band相同的概率就是\(1-(1-s^r)^b\)

(进一步详细的解释和理论可以参考Finding Similar Items 这个课程材料。)

依据公式,下面这个表格展示了不同的\(b\),\(r\)组合在各种相似度\(s\)下产生候选pair的概率

Similarity r=3, b=10 r=6, b=10 r=8, b=15 r=10, b=50 r=12, b=100 r=20, b=450
0.9 99.99% 99.94% 99.97% 99.99% 99.99% 100%
0.8 99.92% 95.22% 93.64% 99.66% 99.92% 99.46%
0.75 99.58% 85.91% 79.45% 94.49% 96.00% 76.05%
0.7 98.50% 71.40% 58.96% 76.13% 75.19% 30.17%
0.6 91.22% 37.99% 22.43% 26.15% 19.58% 1.63%
0.5 73.69% 14.57% 5.7% 4.77% 2.41% 0.04%
0.4 48.38% 4.02% 0.97% 0.52% 0.17% < 0.001%
0.3 23.94% 0.72% 0.1% 0.03% 0.005% < 0.001%
0.2 7.71% 0.06% 0.004% < 0.001% < 0.001% < 0.001%
0.1 0.99% < 0.001% < 0.001% < 0.001% < 0.001% < 0.001%

可以看到,如果signature向量(\(r\times b\))比较小,通常会产生比较多的假阳性, 也就是说,即使两个文档的相似度很低,它们仍然有很大概率成为一对候选pair。 比如上面表格中r=3, b=10(signature向量长度\(3\times 10 = 30\))的例子,即使两个文档的相似度低至0.4, 仍然有48%的概率会成为一对候选pair。这就意味着无论是基于mapreduce还是倒排索引的实现, 不仅需要一次额外的检查确认候选pair的真实相似度,而且候选集的数量可能非常大,在倒排索引里这意味着大量的内存占用, 在mapreduce程序中还意味着分桶不均和数据倾斜。这些在实际操作中都会导致非常严重的问题。

一个技巧是在mapreduce中引入salt。如果对某一个band key进行group会产生太多的数据, 可以选择加入一个salt key把这个组进一步切分。 这样一来,尽管这个band key会损失一些candidate, 但由于基于group key的去重方法具有传递性(如果A和B相似,B和C相似,那么A和C也会被判定为相似), 所以只要salt的取值范围不要太大,实际操作中并不会损失太多的recall。 如果想进一步提高recall的话,可以iteratively运行多次, 由于每次数据会被assign到不一样的salt,最终也会收敛到无法进一步去重的状态。

另一种做法是使用较大的signature向量,在对band key做group之后不再检查这一组中两两signature的真实相似度, 而只是简单地保留其中一个文档。这样会错杀一些实际并不相似的文档,但只要r和b选取合适,这样的损失是可以接受的。 比如上面的表格中r=20, b=450的情况,相似度为0.8的文档被group到一个桶中的概率高达99.46%, 而相似度为0.4的两个文档会被group到一起的概率则小到可以忽略不计。那么这种情况下错杀(false positive) 和丢失(false negative)的比例都会非常小。实用上是可以接受的。


总结

本文介绍了Minhash的基本原理,以及在大数据下如何进行快速有效的去重。 当然在十亿乃至百亿规模的数据面前,再高效的算法也是需要算力来支撑的。 Falcon团队在论文中提到他们用于处理数据的vCPU数量达到了10,000-20,000个。 相比后续model training的开销这可能不算什么,但仍旧是相当可观的。