perturbed:【自然语言处理】【可解释性】Perturbed Masking:分析和解释BERT的无参数探针

Perturbed Masking:分析和解释BERT的无参数探针 Perturbed Masking:Parameter-free Probing for Analyzing and Interpreting BERT

相关博客
【自然语言处理】【可解释性】NKB:用于预训练Transformers的神经知识银行
【自然语言处理】【可解释性】自注意力归因:解释Transformer内部的信息交互
【深度学习】【积分梯度】深度网络的公理归因(Axiomatic Attribution for Deep Networks)
【自然语言处理】【可解释性】Perturbed Masking:分析和解释BERT的无参数探针
【机器学习】【可解释性】LIME
【自然语言处理】【聚类】TELL:可解释神经聚类
【自然语言处理】【Prompt】语言模型即知识库(Language Models as Knowledge Bases)
【自然语言处理】【Prompt】P-tuning

一、简介

  • 近些年,预训练语言模型 ELMo、BERT、XLNet \text{ELMo、BERT、XLNet} ELMo、BERT、XLNet在各种下游任务中都实现了SOTA。为了更深入的了解预训练语言模型,许多探针任务被设计出来。探针(probe)通常是一个简单的神经网络(具有少量的额外参数),其使用预训练语言模型输出的特征向量,并执行某些简单任务(需要标注数据)。通常来说,探针的表现可以间接的衡量预训练语言模型生成向量的表现。
  • 基于探针的方法最大的缺点损失需要引入额外的参数,这将使最终的结果难以解释。因此,难以区分是预训练语言模型捕获了语义信息还是探针任务学习到了下游任务的知识,并将其编码至引入的额外参数中。
  • 本文提出了一种称为Perturbed Masking的无参数探针,其能用于分析和解释预训练语言模型。
  • Perturbed Masking通过改进 MLM \text{MLM} MLM任务的目标函数来衡量单词 x j x_j xj​对于预测 x i x_i xi​的重要性。

二、贡献

  • 引入了新的无参数探针技术Perturbed Masking,其能用来估计单词间相关性和全局语法信息抽取;
  • 在大量语言任务中评估了Perturbed Masking的有效性;
  • 将基于Perturbed Masking诱导出度依赖结构输入至下游任务中,并与解析器提供的依赖结构进行比较,发现论文中方法的性能与解析器相当,甚至更好。

三、Perturbed Masking

1. Token扰动(Token Perturbation)

一个句子可以表示为token列表: x = [ x 1 , … , x T ] \textbf{x}=[x_1,\dots,x_T] x=[x1​,…,xT​]。BERT将每个 x i x_i xi​映射为上下文表示 H θ ( x ) i H_\theta(\textbf{x})_i Hθ​(x)i​,参数 θ \theta θ表示网络参数。我们的目标是设计一个能够捕获上下文单词 x j x_j xj​对预测 x i x_i xi​影响的函数 f ( x i , x j ) f(x_i,x_j) f(xi​,xj​)。

论文提出了一种两阶段的方法来实现这个目标。首先,使用 [MASK] \text{[MASK]} [MASK]来替换单词 x i x_i xi​,并将替换后的句子 x ∖ { x i } \textbf{x}\setminus\{x_i\} x∖{xi​}输入至BERT,得到单词 x i x_i xi​的向量表示 H θ ( x ∖ { x i } ) i H_\theta(\textbf{x}\setminus\{x_i\})_i Hθ​(x∖{xi​})i​。其次,为了计算 x j ∈ x ∖ { x i } x_j\in\textbf{x}\setminus\{x_i\} xj​∈x∖{xi​}对 H θ ( x ∖ { x i } ) i H_\theta(\textbf{x}\setminus\{x_i\})_i Hθ​(x∖{xi​})i​的影响,论文进一步将 x j x_j xj​替换为 [ M A S K ] [MASK] [MASK]来获得第二个句子 x ∖ { x i , x j } \textbf{x}\setminus\{x_i,x_j\} x∖{xi​,xj​}。类似地,得到 x i x_i xi​的新向量表示 H θ ( x ∖ { x i , x j } ) i H_\theta(\textbf{x}\setminus\{x_i,x_j\})_i Hθ​(x∖{xi​,xj​})i​。

基于两个向量表示,论文定义了 f ( x i , x j ) f(x_i,x_j) f(xi​,xj​):
f ( x i , x j ) = d ( H θ ( x ∖ { x i } ) i , H θ ( x ∖ { x i , x j } ) i ) f(x_i,x_j)=d(H_\theta(\textbf{x}\setminus\{x_i\})_i,H_\theta(\textbf{x}\setminus\{x_i,x_j\})_i) f(xi​,xj​)=d(Hθ​(x∖{xi​})i​,Hθ​(x∖{xi​,xj​})i​)
其中, d ( x,y ) d(\textbf{x,y}) d(x,y)是衡量两向量距离的函数。论文中设计了两种 d ( x,y ) d(\textbf{x,y}) d(x,y):

  • Dist:向量 x \textbf{x} x和 y \textbf{y} y的欧几里得距离;
  • Prob: d ( x,y ) = a ( x ) x i − a ( y ) x i d(\textbf{x,y})=a(\textbf{x})_{x_i}-a(\textbf{y})_{x_i} d(x,y)=a(x)xi​​−a(y)xi​​;

其中, a ( ⋅ ) a(\cdot) a(⋅)会将一个向量映射为词表中单词的概率分布, a ( x ) x i a(\textbf{x})_{x_i} a(x)xi​​表示基于 x \textbf{x} x预测单词 x i x_i xi​的概率。

通过对 x \textbf{x} x中的单词 x i , x j ∈ x x_i,x_j\in\textbf{x} xi​,xj​∈x重复执行两阶段的扰动,并计算 f ( x i , x j ) f(x_i,x_j) f(xi​,xj​),可以得到一个影响矩阵( impact   matrix \textbf{impact matrix} impact matrix) F \mathcal{F} F,其中 F i , j ∈ R T × T \mathcal{F}_{i,j}\in\mathbb{R}^{T\times T} Fi,j​∈RT×T。

论文后续会从 F \mathcal{F} F中抽取语法树,并与benchmark进行比较。但是,由于BERT是基于byte-pair进行编码的,因此可能会将一个词划分为多个token(或者成为子词)。为了能够更好的评估论文的方法,针对影响矩阵做了一些调整。

  • 在每次进行扰动是,会将同一个词的多个token均替换为 [ M A S K ] [MASK] [MASK]。
  • 被拆分词的影响分数为拆分后所有token的影响分数平均值。

2. Span扰动(Span Perturbation)

上面是token级别的扰动,其可以直接扩展至span级别的扰动(这里的span可以是短语、子句或者段落)。

论文将文档 D D D建模为 N N N个无覆盖文本片段(span) D = [ e 1 , e 2 , … , e N ] D=[e_1,e_2,\dots,e_N] D=[e1​,e2​,…,eN​],其中每个文本片段 e i e_i ei​均包含一个token序列 e i = [ x 1 i , x 2 i , … , x M i ] e_i=[x_1^i,x_2^i,\dots,x_M^i] ei​=[x1i​,x2i​,…,xMi​]。

对于span级别的扰动,不再是将单个token替换为 [ M A S K ] [MASK] [MASK],而是将整个文本片段(span)替换为一组 [ M A S K ] [MASK] [MASK]。我们通过将文本片段中所有token的向量表示进行平均来获得整个文本片段的向量表示。

类似地, e j e_j ej​对 e i e_i ei​的影响为
f ( e i , e j ) = d ( H θ ( D ∖ { e i } ) i , H θ ( D ∖ { e i , e j } ) i ) f(e_i,e_j)=d(H_\theta(D\setminus\{e_i\})_i,H_\theta(D\setminus\{e_i,e_j\})_i) f(ei​,ej​)=d(Hθ​(D∖{ei​})i​,Hθ​(D∖{ei​,ej​})i​)
其中, d d d是Dist函数。

四、使用影响图(Impact Map)进行可视化

在讨论语法现象前,先分析一下从例句中导出的影响矩阵。这里使用影响图(Impact Map)来指代影响矩阵所表示的热力图。

1. 设定(Setup)

将来自 English Parallel Universal Dependencies(PUD) treebank of the CoNLL 2017 Shared Task \text{English Parallel Universal Dependencies(PUD) treebank of the CoNLL 2017 Shared Task} English Parallel Universal Dependencies(PUD) treebank of the CoNLL 2017 Shared Task的1000个句子作为样本输入BERT,并抽取影响矩阵。图1是其中一个样本的影响图。

2.依赖(Dependency)

可以发现影响图(impact map)中包含了许多的条纹(stripes)。以单词 different \text{different} different为例,能够在主对角线上观察到清晰的垂直条纹。这说明单词 different \text{different} different的出现强烈影响了之前单词的出现,这能够通过影响图倒数第二列中深色像素块表现出来。观察的这种现象与真实的依赖树一致。同理,在单词 transitions \text{transitions} transitions和 Hill \text{Hill} Hill上也能观察到相似的模式,这也促成了从影响矩阵中提取依赖树的想法。

3. 层次结构(Constituency)

图2展示了从例句中使用 Standford CoreNLP \text{Standford CoreNLP} Standford CoreNLP抽取的层次结构树(Constituency Tree)。在这个例句中,单词 meida \text{meida} meida和 on \text{on} on均与 transitions \text{transitions} transitions相邻。但是,相较于 on \text{on} on,从树结构上看 media \text{media} media和 transitions \text{transitions} transitions更加接近。在语法不知情的情况下, media \text{media} media和 on \text{on} on对 transitions \text{transitions} transitions预测具有相同的影响。但是,在例子中 media \text{media} media对 transitions \text{transitions} transitions的影响显著大于 on \text{on} on。

4. 其他结构

沿着影响图的对角线,可以看到词被聚集为4个连续的块,并且这些块具有特定的意图。此外,中间的两块具有较强的相互影响,从而形成一个更大的短语。

五、语法探针

1. Dependency探针

(略)

2. Constituency探针

由顶向下的解析方法

给定一个句子的token序列: x = [ x 1 , … , x T ] \textbf{x}=[x_1,\dots,x_T] x=[x1​,…,xT​]及其对应的影响矩阵 F \mathcal{F} F。解析的目标是找到最好的划分位置 k k k,然后将矩阵划分为 ( ( x < k ) , ( x k , ( x > k ) ) ) ((\textbf{x}_{<k}),(x_k,(\textbf{x}_{>k}))) ((x<k​),(xk​,(x>k​))),,其中 x < k = [ x 1 , … , x k − 1 ] \textbf{x}_{<k}=[x_1,\dots,x_{k-1}] x<k​=[x1​,…,xk−1​]。最优的划分位置能够保证每个划分后的成分具有最大的平均影响(impact)且成分间的单词影响要尽量小。可以按照下面的优化方式为成分 x = [ x i , x i + 1 , … , x j ] \textbf{x}=[x_i,x_{i+1},\dots,x_j] x=[xi​,xi+1​,…,xj​]决定最优的划分位置 k k k:
arg max k F i , … , k i , … , k + F k + 1 , … , j k + 1 , … , j − F i , … , k k + 1 , … , j − F k + 1 , … , j i , … , k \mathop{\text{arg max}}_k\quad \mathcal{F}_{i,\dots,k}^{i,\dots,k}+\mathcal{F}_{k+1,\dots,j}^{k+1,\dots,j}-\mathcal{F}_{i,\dots,k}^{k+1,\dots,j}-\mathcal{F}_{k+1,\dots,j}^{i,\dots,k} arg maxk​Fi,…,ki,…,k​+Fk+1,…,jk+1,…,j​−Fi,…,kk+1,…,j​−Fk+1,…,ji,…,k​
其中, F i , … , k i , … , k = ∑ a = i k ∑ b = i k f ( x a , x b ) 2 ( k − i ) \mathcal{F}_{i,\dots,k}^{i,\dots,k}=\frac{\sum_{a=i}^k\sum_{b=i}^k f(x_a,x_b)}{2(k-i)} Fi,…,ki,…,k​=2(k−i)∑a=ik​∑b=ik​f(xa​,xb​)​。

递归的划分 x < k \textbf{x}_{<k} x<k​和 x > k \textbf{x}_{>k} x>k​,直至称为单词。

相关推荐

相关文章