Как модель трансформатора вычисляет самовнимание?

В модели трансформатора https://arxiv.org/pdf/1706.03762.pdf есть самовнимание, которое вычисляется с помощью softmax на векторах Query (Q) и Key (K):

Я пытаюсь понять умножение матриц:

Q = batch_size x seq_length x embed_size

K = batch_size x seq_length x embed_size

QK^T = batch_size x seq_length x seq_length

Softmax QK^T = Softmax (batch_size x seq_length x seq_length)

Как вычисляется softmax, если на каждый элемент пакета приходится seq_length x seq_length значений?

Ссылка на вычисления Pytorch будет очень полезной.

Ваше здоровье!


person tstseby    schedule 13.05.2020    source источник
comment
Этот ресурс может оказаться полезным (и он содержит код на PyTorch) nlp. sea.harvard.edu/2018/04/03/attention.html   -  person Gabriela Melo    schedule 13.05.2020


Ответы (2)


Как вычисляется softmax, если на каждый элемент пакета приходится seq_length x seq_length значений?

Softmax выполняется относительно последней оси (torch.nn.Softmax(dim=-1)(tensor), где tensor имеет форму batch_size x seq_length x seq_length), чтобы получить вероятность обращения к каждому элементу для каждого элемента во входной последовательности.


Предположим, у нас есть текстовая последовательность «Thinking Machines», поэтому у нас есть матрица формы «2 x 2» (где seq_length = 2) после выполнения QK^T.

Я использую следующую иллюстрацию (ссылка), чтобы объяснить вычисление самовнимания. Как вы знаете, сначала выполняется масштабированное скалярное произведение QK^T/square_root(d_k), а затем вычисляется softmax для каждого элемента последовательности.

Здесь Softmax выполняется для первого элемента последовательности «Мышление». Необработанный результат 14 and 12 превращается в вероятность 0.88 and 0.12 с помощью softmax. Эта вероятность указывает на то, что токен «Мышление» будет обслуживать себя с вероятностью 88%, а токен «Машины» - с вероятностью 12%. Точно так же вероятность внимания вычисляется и для токена «Машины».

введите здесь описание изображения


Примечание. Я настоятельно рекомендую прочитать эту отличную статью о Transformer. Для реализации вы можете взглянуть на OpenNMT.

person Wasi Ahmad    schedule 13.05.2020

Умножение QKᵀ - это групповое умножение матриц - оно выполняет отдельное умножение seq_length x embed_size на embed_size x seq_length batch_size раз. Каждый из них дает результат размера seq_length x seq_length, поэтому мы получаем QKᵀ, имеющий форму batch_size x seq_length x seq_length.

Предлагаемый ресурс Габриэлы Мело использует следующий код PyTorch для этой операции :

torch.matmul(query, key.transpose(-2, -1))

Это работает, потому что torch.matmul выполняет пакетное матричное умножение, когда вход имеет как минимум 3 измерения (см. https://pytorch.org/docs/stable/torch.html#torch.matmul).

person ziedaniel1    schedule 13.05.2020