2019ACL,SJTU & ByteDance,这是一篇融合了图表示学习来做多跳推理的文章。
本文作者提出的模型叫做DFGN,作者首先谈到HotpotQA这种类型的数据集带给人们两大挑战:
总结一下DFGN模型,模型从问题中的实体出发,根据paragraph构建起一张与问题实体相关的动态的entity graph,然后fusion模块会对entity graph进行建模并完成实体与文本之间的信息传递,document的向量表示也随之更新。上述的过程不断的迭代,模型就得到了一条reasoning chain,最终得到答案。
上图就是DFGN模型的整体架构,可以看出模型主要分为五大模块:
这个模块主要是用于过滤噪声段落,本文之前采用了先前的工作。用BERT来对所有的句子进行编码,做一个句子分类任务。作者把所有包含至少一条supporting fact的段落视为正例。在inference阶段,所有预测得分高于0.1的段落被选取出来,拼接到一起得到 C C C。
对于实体图,本文采取的方法是先对 C C C进行NER,提取出所有的候选实体,然后开始连边。边有三种类型:
对于问题和段落的编码,本文直接采用BERT,然后再经过一层bi-attention,得到 Q 0 ∈ R L × 2 d 2 Q_{0}\in{R^{L\times 2d_{2}}} Q0∈RL×2d2和 C 0 ∈ R M × 2 d 2 C_{0}\in{R{^{M\times 2d_{2}}}} C0∈RM×2d2。
fusion模块是本文的核心,主要包含三个子模块:
这一模块作者也称作是Tok2Ent,实现方法是用一个binary mask M M M, M i j = 1 M_{ij}=1 Mij=1表示文本中的第 i i i个token出现在第 j j j个实体的span里。然后用一个mean-max pooling得到实体的embedding E t − 1 ∈ R 2 d 2 × N E_{t-1}\in{R^{2d_{2}\times N}} Et−1∈R2d2×N。
对于图结构的建模本文采用的是GAT模型。但在这之前,作者先设计了一个soft mask,来得到Entity Graph中所有与query相关的实体,我觉得这个mask也是实现本文Introduction部分提到的dynamic local entity graph的关键。
q
~
(
t
−
1
)
=
M
e
a
n
P
o
o
l
i
n
g
(
Q
(
t
−
1
)
)
γ
i
(
t
)
=
q
~
(
t
−
1
)
V
t
e
i
(
t
−
1
)
/
d
2
m
(
t
)
=
σ
[
γ
1
(
t
)
,
γ
2
(
t
)
,
…
,
γ
1
(
t
)
]
E
~
(
t
−
1
)
=
m
(
t
)
⋅
E
(
t
−
1
)
\widetilde{q}^{(t-1)}\ =\ MeanPooling(Q^{(t-1)})\\ \gamma^{(t)}_{i}\ =\ \widetilde{q}^{(t-1)}V_{t}e^{(t-1)}_{i}/\sqrt{d_{2}}\\ m^{(t)}\ =\ \sigma[\gamma_{1}^{(t)},\ \gamma_{2}^{(t)}, \dots,\ \gamma_{1}^{(t)}] \\ \widetilde{E}^{(t-1)}=m^{(t)} \cdot E^{(t-1)}
q
(t−1) = MeanPooling(Q(t−1))γi(t) = q
(t−1)Vtei(t−1)/d2m(t) = σ[γ1(t), γ2(t),…, γ1(t)]E
(t−1)=m(t)⋅E(t−1)
V
t
V_{t}
Vt是一个linear projection,可以看出这个mask的计算是通过attention + sigmoid来实现的。这里的mask是可训练的。
得到了mask后的实体向量表示,接下来套用GAT模型。
h
i
(
t
)
=
U
e
~
i
(
t
−
1
)
+
b
s
i
,
j
(
t
)
=
L
e
a
k
y
R
e
L
u
(
W
t
T
[
h
i
(
t
)
;
h
j
(
t
)
]
)
α
i
j
(
t
)
=
e
x
p
(
s
i
,
j
(
t
)
)
∑
k
e
x
p
(
s
i
,
k
(
t
)
)
h^{(t)}_{i}\ =\ U\widetilde{e}^{(t-1)}_{i}+b\\ s^{(t)}_{i,j}\ =\ LeakyReLu(W^{T}_{t}[h^{(t)}_{i};h^{(t)}_{j}])\\ \alpha^{(t)}_{ij}\ =\ \frac{exp({s^{(t)}_{i,j}})}{\sum_{k}exp(s^{(t)}_{i,k})}
hi(t) = Ue
i(t−1)+bsi,j(t) = LeakyReLu(WtT[hi(t);hj(t)])αij(t) = ∑kexp(si,k(t))exp(si,j(t))
得到attention weight之后更新实体的向量表示:
e
i
(
t
)
=
R
e
L
u
(
∑
j
∈
N
i
α
j
,
i
(
t
)
h
j
(
t
)
)
e^{(t)}_{i}\ =\ ReLu(\sum_{j \in N_{i}}\alpha^{(t)}_{j,i}h^{(t)}_{j})
ei(t) = ReLu(j∈Ni∑αj,i(t)hj(t))
首先,作者对query进行了更新,因为当前时间步所访问到的新实体可能成为下一个时间步的start entity,因此对query的更新是必要的。更新的方式是Bi-Attention。
Q
(
t
)
=
B
i
−
A
t
t
e
n
t
i
o
n
(
Q
(
t
−
1
)
,
E
(
t
)
)
Q^{(t)}\ =\ Bi-Attention(Q^{(t-1)},E^{(t)})
Q(t) = Bi−Attention(Q(t−1),E(t))
接下来是信息的“反向传播”,即从graph传递到document,因此这一模块也被作者成为Graph2Doc。具体做法是,仍然使用Entity Graph Constructor中的
M
M
M矩阵来对实体进行过滤,然后用LSTM得到更新后的document
C
(
t
)
=
L
S
T
M
(
C
(
t
−
1
)
,
M
E
(
t
)
)
C^{(t)}\ =\ LSTM(C^{(t-1)},\ ME^{(t)})
C(t) = LSTM(C(t−1), ME(t))
HotpotQA一般有四个预测值:是否为supporting fact、answer start、answer end、question type。而本文的预测模块也是一个创新点,作者使用了级联的LSTM结构,四个LSTM层
F
i
F_{i}
Fi叠在一起
O
s
u
p
=
F
0
(
[
C
(
t
)
]
)
O
s
t
a
r
t
=
F
1
(
[
C
(
t
)
,
O
s
u
p
]
)
O
e
n
d
=
F
2
(
[
C
(
t
)
,
O
s
t
a
r
t
,
O
s
u
p
]
)
O
t
y
p
e
=
F
3
(
[
C
(
t
)
,
O
e
n
d
,
O
s
u
p
]
)
O_{sup}\ =\ F_{0}([C^{(t)}])\\ O_{start}\ =\ F_{1}([C^{(t)},\ O_{sup}])\\ O_{end}\ =\ F_{2}([C^{(t)},\ O_{start},\ O_{sup}])\\ O_{type}\ =\ F_{3}([C^{(t)},\ O_{end},\ O_{sup}])
Osup = F0([C(t)])Ostart = F1([C(t), Osup])Oend = F2([C(t), Ostart, Osup])Otype = F3([C(t), Oend, Osup])
而损失函数也是四者相加
L
=
L
s
t
a
r
t
+
L
e
n
d
+
λ
s
L
s
u
p
+
λ
t
L
t
y
p
e
L\ =\ L_{start}\ +\ L_{end}\ +\ \lambda_{s}L_{sup}\ +\ \lambda_{t}L_{type}
L = Lstart + Lend + λsLsup + λtLtype
再HotpotQA上的结果
消融实验
作者还做了case study
因篇幅问题不能全部显示,请点此查看更多更全内容