
Recap: Multi-Head Attention

For example, in PyTorch we can simply use the nn.TransformerEncoderLayer or nn.TransformerDecoderLayer class, which contains a multi-head attention block and a feed-forward network (FFN), with residual connections around these two sub-layers plus layer normalization.

from torch import nn

# Encoder: stack num_layers copies of a TransformerEncoderLayer
encoder_layer = nn.TransformerEncoderLayer(hidden_dim, num_head, dim_feedforward, dropout, activation)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
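For a runnable sketch of how these two classes fit together (the hyperparameter values and the dummy input below are illustrative, not from the original snippet):

import torch
from torch import nn

# Illustrative hyperparameters
hidden_dim, num_head, dim_feedforward, dropout, activation = 512, 8, 2048, 0.1, "relu"
num_layers = 6

encoder_layer = nn.TransformerEncoderLayer(
    hidden_dim, num_head, dim_feedforward, dropout, activation, batch_first=True
)
transformer = nn.TransformerEncoder(encoder_layer, num_layers)

src = torch.rand(2, 10, hidden_dim)   # (batch, seq_len, d_model) since batch_first=True
out = transformer(src)                # output keeps the input shape
print(out.shape)                      # torch.Size([2, 10, 512])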

In the PyTorch source for nn.TransformerEncoderLayer you can see the multi-head attention module, MultiheadAttention:

class TransformerEncoderLayer(Module):
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)

        # We can't test self.activation in forward() in TorchScript,
        # so stash some information about it instead.
        if activation is F.relu or isinstance(activation, torch.nn.ReLU):
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        self.activation = activation

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, 'activation'):
            self.activation = F.relu

    def forward(
            self,
            src: Tensor,
            src_mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            is_causal: bool = False) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            is_causal: If specified, applies a causal mask as src_mask.
              Default: ``False``.
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(src_mask),
            other_name="src_mask",
            target_type=src.dtype
        )

        src_mask = F._canonical_mask(
            mask=src_mask,
            mask_name="src_mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        why_not_sparsity_fast_path = ''
        if not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif self.training:
            why_not_sparsity_fast_path = "training is enabled"
        elif not self.self_attn.batch_first:
            why_not_sparsity_fast_path = "self_attn.batch_first was not True"
        elif not self.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
        elif not self.activation_relu_or_gelu:
            why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
        elif not (self.norm1.eps == self.norm2.eps):
            why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
        elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None):
            why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input"
        elif self.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = "num_head is odd"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"

        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                self.self_attn.in_proj_weight,
                self.self_attn.in_proj_bias,
                self.self_attn.out_proj.weight,
                self.self_attn.out_proj.bias,
                self.norm1.weight,
                self.norm1.bias,
                self.norm2.weight,
                self.norm2.bias,
                self.linear1.weight,
                self.linear1.bias,
                self.linear2.weight,
                self.linear2.bias,
            )

            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif not all((x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")

            if not why_not_sparsity_fast_path:
                merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src)
                return torch._transformer_encoder_layer_fwd(
                    src,
                    self.self_attn.embed_dim,
                    self.self_attn.num_heads,
                    self.self_attn.in_proj_weight,
                    self.self_attn.in_proj_bias,
                    self.self_attn.out_proj.weight,
                    self.self_attn.out_proj.bias,
                    self.activation_relu_or_gelu == 2,
                    self.norm_first,
                    self.norm1.eps,
                    self.norm1.weight,
                    self.norm1.bias,
                    self.norm2.weight,
                    self.norm2.bias,
                    self.linear1.weight,
                    self.linear1.bias,
                    self.linear2.weight,
                    self.linear2.bias,
                    merged_mask,
                    mask_type,
                )

        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
            x = self.norm2(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False, is_causal=is_causal)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

1. Self-Attention in the Encoder

As the code above shows, the encoder uses a self-attention function. In self-attention, Q, K, and V usually come from the same sequence, e.g. "i love large language model": for each token, the Q, K, and V matrices are obtained by passing the same input X through three different linear transformations, whose weight matrices W_q, W_k, W_v are learned parameters (a minimal projection sketch follows the code excerpt below).

# self-attention block
def _sa_block(self, x: Tensor,
              attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
    x = self.self_attn(x, x, x,
                       attn_mask=attn_mask,
                       key_padding_mask=key_padding_mask,
                       need_weights=False, is_causal=is_causal)[0]
    return self.dropout1(x)
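As a minimal sketch of the three projections described above (the dimension d_model and the bias-free nn.Linear layers are illustrative assumptions):

import torch
from torch import nn

d_model = 512                      # illustrative embedding size
x = torch.rand(10, d_model)        # 10 token embeddings from the same sequence

# Three separate learned projections W_q, W_k, W_v applied to the same input X
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

Q, K, V = W_q(x), W_k(x), W_v(x)   # each of shape (10, d_model)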

To gather the context needed when encoding token $x_i$, the query vector at position $i$ is dotted with the key vectors at the other positions, giving the matching scores $\boldsymbol{q}_i \cdot \boldsymbol{k}_1, \boldsymbol{q}_i \cdot \boldsymbol{k}_2, \ldots, \boldsymbol{q}_i \cdot \boldsymbol{k}_t$. To keep overly large scores from pushing the subsequent Softmax into regions with tiny gradients and hurting convergence, the scores are divided by the scaling factor $\sqrt{d}$. The scaled scores are then normalized by Softmax into probabilities (a weight matrix) and multiplied with the value vectors of the other positions, aggregating the context we want to attend to while minimizing interference from irrelevant information. The full computation, which is simply a weighted sum over the values, is:

$$
\boldsymbol{Z}=\operatorname{Attention}(\boldsymbol{Q}, \boldsymbol{K}, \boldsymbol{V})=\operatorname{Softmax}\left(\frac{\boldsymbol{Q} \boldsymbol{K}^{\top}}{\sqrt{d}}\right) \boldsymbol{V}
$$

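A minimal single-head sketch of this formula (names are illustrative; recent PyTorch releases also provide torch.nn.functional.scaled_dot_product_attention for the same computation):

import math
import torch
import torch.nn.functional as F

def attention(Q, K, V, mask=None):
    """Z = Softmax(Q K^T / sqrt(d)) V, i.e. a weighted sum over the values."""
    d = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d)   # matching scores q_i . k_j, scaled by sqrt(d)
    if mask is not None:
        scores = scores.masked_fill(mask, float("-inf"))
    weights = F.softmax(scores, dim=-1)               # normalize each row into probabilities
    return weights @ V                                # aggregate the value vectors

# Reusing Q, K, V from the projection sketch above:
# Z = attention(Q, K, V)                              # shape (10, d_model)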
2. Cross-Attention in the Decoder

The decoder contains both a self-attention module and a cross-attention module:

Self-attention: Q, K, and V all come from the decoder's output so far, similar to the encoder.
Cross-attention: the query comes from the decoder's current output, while the encoder's output serves as key and value, so the decoder can take the input sequence into account at every generation step (as sketched right below).
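As a hedged illustration of this call pattern with nn.MultiheadAttention (batch size, sequence lengths, and d_model below are made-up values):

import torch
from torch import nn

d_model, nhead = 512, 8
cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)

memory = torch.rand(2, 10, d_model)   # encoder output: provides the keys and values
tgt = torch.rand(2, 7, d_model)       # decoder states so far: provide the queries

# query = tgt, key = value = memory, matching _mha_block in the source below
out, _ = cross_attn(tgt, memory, memory, need_weights=False)
print(out.shape)                      # torch.Size([2, 7, 512])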

class TransformerDecoderLayer(Module):
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                                 **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        tgt_is_causal: bool = False,
        memory_is_causal: bool = False,
    ) -> Tensor:
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        x = tgt
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
            x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
            x = x + self._ff_block(self.norm3(x))
        else:
            x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
            x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
            x = self.norm3(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           is_causal=is_causal,
                           need_weights=False)[0]
        return self.dropout1(x)

    # multihead attention block
    def _mha_block(self, x: Tensor, mem: Tensor,
                   attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.multihead_attn(x, mem, mem,
                                attn_mask=attn_mask,
                                key_padding_mask=key_padding_mask,
                                is_causal=is_causal,
                                need_weights=False)[0]
        return self.dropout2(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)

From the source above, after the self-attention computation in _sa_block, the decoder's forward applies cross-attention via self._mha_block. Its first argument (self.norm2(x) in the pre-norm branch) is the representation of the target sequence (used as Q), and its second argument memory is the encoder output (used as K and V), which lets the model capture the dependencies between the input sequence and the target sequence.

x = tgt
if self.norm_first:
    x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
    x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
    x = x + self._ff_block(self.norm3(x))
else:
    x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
    x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
    x = self.norm3(x + self._ff_block(x))

return x
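A runnable usage sketch of the decoder with illustrative sizes (the causal tgt_mask is built by hand here):

import torch
from torch import nn

d_model, nhead = 512, 8
decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)

memory = torch.rand(2, 10, d_model)   # encoder output (K and V for cross-attention)
tgt = torch.rand(2, 7, d_model)       # target-side embeddings (Q for cross-attention)

# Additive causal mask so position i cannot attend to positions > i in self-attention
tgt_mask = torch.triu(torch.full((7, 7), float("-inf")), diagonal=1)

out = decoder(tgt, memory, tgt_mask=tgt_mask)
print(out.shape)                      # torch.Size([2, 7, 512])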

Note

LLaMA 2 uses GQA for its attention. The three mechanisms (MHA, MQA, GQA) are summarized below; a code sketch covering all three follows the GQA subsection.

MHA (Multi-Head Attention)

MHA (Multi-Head Attention) is the standard multi-head attention mechanism, with h Query, Key, and Value matrices. The Key and Value weights are not shared across attention heads.

MQA (Multi-Query Attention)

MQA (Multi-Query Attention, "Fast Transformer Decoding: One Write-Head is All You Need") is a multi-query variant of attention, also used for autoregressive decoding. Unlike MHA, MQA shares a single Key and Value matrix across all heads, while each head keeps its own Query parameters, which greatly reduces the Key and Value parameter count.

GQA (Grouped-Query Attention)

GQA (Grouped-Query Attention, "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints") is grouped-query attention: the query heads are divided into G groups, and each group shares one Key and Value matrix. GQA-G denotes grouped-query attention with G groups. GQA-1 has a single group, hence a single Key and Value, and is equivalent to MQA; GQA-H, with as many groups as heads, is equivalent to MHA. GQA thus sits between MHA and MQA, and lets multiple heads share the KV cache.
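The relationship between the three variants can be made concrete with a minimal self-contained sketch (all names, shapes, and the grouped_query_attention helper are illustrative assumptions, not a library API, and the output projection is omitted). Setting num_kv_heads = 1 recovers MQA, and num_kv_heads = num_heads recovers MHA:

import math
import torch
import torch.nn.functional as F

def grouped_query_attention(x, W_q, W_k, W_v, num_heads, num_kv_heads):
    """Minimal GQA sketch: num_kv_heads=1 -> MQA, num_kv_heads=num_heads -> MHA."""
    B, T, d_model = x.shape
    head_dim = d_model // num_heads
    group = num_heads // num_kv_heads                       # query heads per shared K/V head

    q = (x @ W_q).view(B, T, num_heads, head_dim).transpose(1, 2)      # (B, H, T, hd)
    k = (x @ W_k).view(B, T, num_kv_heads, head_dim).transpose(1, 2)   # (B, H_kv, T, hd)
    v = (x @ W_v).view(B, T, num_kv_heads, head_dim).transpose(1, 2)

    # Expand the shared K/V so each group of query heads reuses the same K/V head
    k = k.repeat_interleave(group, dim=1)                   # (B, H, T, hd)
    v = v.repeat_interleave(group, dim=1)

    scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)  # (B, H, T, T)
    out = F.softmax(scores, dim=-1) @ v                     # (B, H, T, hd)
    return out.transpose(1, 2).reshape(B, T, d_model)

# Example: 8 query heads sharing 2 K/V heads (GQA-2)
B, T, d_model, H, H_kv = 2, 5, 64, 8, 2
x = torch.rand(B, T, d_model)
W_q = torch.rand(d_model, d_model)
W_k = torch.rand(d_model, (d_model // H) * H_kv)
W_v = torch.rand(d_model, (d_model // H) * H_kv)
print(grouped_query_attention(x, W_q, W_k, W_v, H, H_kv).shape)  # torch.Size([2, 5, 64])

Because only num_kv_heads Key/Value heads are stored, the KV cache shrinks by a factor of num_heads / num_kv_heads compared with MHA.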

