Перевод с RNN

Часть 4: Двунаправленный и внимание RNN

В полном руководстве по НЛП с фастаем

Перейдите по ссылке на всю серию, нажав здесь: Полное руководство по НЛП с фастай

В этом посте будет собрано все, что мы узнали до этого момента, а затем представлен перевод с помощью RNN.

Это захватывающе, потому что результаты нашего путешествия по изучению НЛП можно резюмировать следующим образом:

В восторге???
Начнем…

Перевод с RNN

В этом посте мы займемся переводом. Мы будем переводить с французского на английский, и, чтобы наша задача оставалась управляемой, мы ограничимся переводом вопросов.

Эта задача является примером последовательности для последовательности (seq2seq). Seq2seq может быть более сложной задачей, чем классификация, поскольку выходные данные имеют переменную длину (и обычно отличаются от длины входных данных.

Французско-английские параллельные тексты с http://www.statmt.org/wmt15/translation-task.html. Он был создан Крисом Каллисон-Бёрчем, который просканировал миллионы веб-страниц, а затем использовал набор простых эвристик для преобразования URL-адресов на французском языке в URL-адреса на английском языке (т. е. заменив fr на en и около 40 других рукописных правила) и предположить, что эти документы являются переводами друг друга.

Перевод намного сложнее в прямом PyTorch: https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html

In [1]:from fastai.text import *

Загрузите и предварительно обработайте наши данные

Мы начнем с сокращения исходного набора данных до вопросов. Вам нужно выполнить это только один раз, раскомментируйте для запуска. Набор данных можно скачать здесь.

In [2]:path = Config().data_path()
In [3]:
# ! wget https://s3.amazonaws.com/fast-ai-nlp/giga-fren.tgz -P {path}
In [4]:# ! tar xf {path}/giga-fren.tgz -C {path}
In [3]:
path = Config().data_path()/'giga-fren'
path.ls()
Out[3]:
[PosixPath('/home/racheltho/.fastai/data/giga-fren/models'),
...
In [6]:# with open(path/'giga-fren.release2.fixed.fr') as f: fr = f.read().split('\n')
In [7]:# with open(path/'giga-fren.release2.fixed.en') as f: en = f.read().split('\n')

Мы будем использовать регулярное выражение для выбора вопросов, находя строки в наборе данных на английском языке, которые начинаются с «Wh» и заканчиваются вопросительным знаком. Вам нужно запустить эти строки только один раз:

In [8]:
# re_eq = re.compile('^(Wh[^?.!]+\?)')
# re_fq = re.compile('^([^?.!]+\?)')
# en_fname = path/'giga-fren.release2.fixed.en'
# fr_fname = path/'giga-fren.release2.fixed.fr'
In [9]:
# lines = ((re_eq.search(eq), re_fq.search(fq)) 
#         for eq, fq in zip(open(en_fname, encoding='utf-8'), open(fr_fname, encoding='utf-8')))
# qs = [(e.group(), f.group()) for e,f in lines if e and f]
In [10]:
# qs = [(q1,q2) for q1,q2 in qs]
# df = pd.DataFrame({'fr': [q[1] for q in qs], 'en': [q[0] for q in qs]}, columns = ['en', 'fr'])
# df.to_csv(path/'questions_easy.csv', index=False)
In [11]: path.ls()
Out[11]:
[PosixPath('/home/racheltho/.fastai/data/giga-fren/models'),
...

Загрузите наши данные в DataBunch

Теперь наши вопросы выглядят так:

In [4]:
df = pd.read_csv(path/'questions_easy.csv')
df.head()

Для простоты мы пишем все строчными буквами.

In [5]:
df['en'] = df['en'].apply(lambda x:x.lower())
df['fr'] = df['fr'].apply(lambda x:x.lower())

Во-первых, нам нужно будет сопоставить входные данные и цели в пакете: они имеют разную длину, поэтому нам нужно добавить отступы, чтобы длина последовательности была одинаковой;

In [7]:
def seq2seq_collate(samples, pad_idx=1, pad_first=True, backwards=False):
    "Function that collect samples and adds padding. Flips token order if needed"
    samples = to_data(samples)
    max_len_x,max_len_y = max([len(s[0]) for s in samples]),max([len(s[1]) for s in samples])
    res_x = torch.zeros(len(samples), max_len_x).long() + pad_idx
    res_y = torch.zeros(len(samples), max_len_y).long() + pad_idx
    if backwards: pad_first = not pad_first
    for i,s in enumerate(samples):
        if pad_first: 
            res_x[i,-len(s[0]):],res_y[i,-len(s[1]):] = LongTensor(s[0]),LongTensor(s[1])
        else:         
            res_x[i,:len(s[0]):],res_y[i,:len(s[1]):] = LongTensor(s[0]),LongTensor(s[1])
    if backwards: res_x,res_y = res_x.flip(1),res_y.flip(1)
    return res_x,res_y

Затем мы создаем специальный DataBunch, который использует эту функцию сопоставления.

In [8]:doc(Dataset)
In [9]:doc(DataLoader)
In [6]:doc(DataBunch)
In [20]:
class Seq2SeqDataBunch(TextDataBunch):
    "Create a `TextDataBunch` suitable for training an RNN classifier."
    @classmethod
    def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', bs:int=32, val_bs:int=None, pad_idx=1,
               dl_tfms=None, pad_first=False, device:torch.device=None, no_check:bool=False, backwards:bool=False, **dl_kwargs) -> DataBunch:
        "Function that transform the `datasets` in a `DataBunch` for classification. Passes `**dl_kwargs` on to `DataLoader()`"
        datasets = cls._init_ds(train_ds, valid_ds, test_ds)
        val_bs = ifnone(val_bs, bs)
        collate_fn = partial(seq2seq_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards)
        train_sampler = SortishSampler(datasets[0].x, key=lambda t: len(datasets[0][t][0].data), bs=bs//2)
        train_dl = DataLoader(datasets[0], batch_size=bs, sampler=train_sampler, drop_last=True, **dl_kwargs)
        dataloaders = [train_dl]
        for ds in datasets[1:]:
            lengths = [len(t) for t in ds.x.items]
            sampler = SortSampler(ds.x, key=lengths.__getitem__)
            dataloaders.append(DataLoader(ds, batch_size=val_bs, sampler=sampler, **dl_kwargs))
        return cls(*dataloaders, path=path, device=device, collate_fn=collate_fn, no_check=no_check)
In [ ]:SortishSampler??

И подкласс TextList, который будет использовать этот класс DataBunch в вызове .databunch и будет использовать TextList для маркировки (поскольку нашими целями являются другие тексты).

In [21]:
class Seq2SeqTextList(TextList):
    _bunch = Seq2SeqDataBunch
    _label_cls = TextList

Это все, что нам нужно для использования API блока данных!

In [22]:
src = Seq2SeqTextList.from_df(df, path = path, cols='fr').split_by_rand_pct(seed=42).label_from_df(cols='en', label_cls=TextList)
In [23]:
np.percentile([len(o) for o in src.train.x.items] + [len(o) for o in src.valid.x.items], 90)
Out[23]:28.0
In [24]:
np.percentile([len(o) for o in src.train.y.items] + [len(o) for o in src.valid.y.items], 90)
Out[24]:23.0

Мы удаляем элементы, в которых одна из целей имеет длину более 30 токенов.

In [25]: src = src.filter_by_func(lambda x,y: len(x) > 30 or len(y) > 30)
In [26]:len(src.train) + len(src.valid)
Out[26]:48352
In [27]:data = src.databunch()
In [28]:data.save()
In [29]:data
Out[29]:
Seq2SeqDataBunch;
Train: LabelList (38706 items)
x: Seq2SeqTextList
xxbos qu’est - ce que la lumière ?
...
Test: None
In [30]:path
Out[30]:PosixPath('/home/racheltho/.fastai/data/giga-fren')
In [31]:data = load_data(path)
In [32]:data.show_batch()

Создайте нашу модель

Предварительно обученные встраивания

Вам нужно будет загрузить вложения слов (векторы сканирования) из документации fastText. FastText имеет предварительно обученные векторы слов для 157 языков, обученные на Common Crawl и Википедии. Эти модели были обучены с помощью CBOW.
Если вам нужно освежить в памяти информацию о встраивании слов, вы можете ознакомиться с моим нежным введением в этом семинаре по встраиванию слов с сопроводительным репозиторием github.

Дополнительная информация о CBOW (непрерывный пакет слов и пропуск граммов):

Чтобы установить FastText:

$ git clone https://github.com/facebookresearch/fastText.git
$ cd fastText
$ pip install .
In [33]: import fastText as ft

Строки для загрузки векторов слов нужно запустить только один раз:

In [60]:
# ! wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz -P {path}
# ! wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fr.300.bin.gz -P {path}
In [61]:
# gunzip {path} / cc.en.300.bin.gz
# gunzip {path} / cc.fr.300.bin.gz
In [34]:
fr_vecs = ft.load_model(str((path/'cc.fr.300.bin')))
en_vecs = ft.load_model(str((path/'cc.en.300.bin')))

Мы создаем модуль встраивания с предварительно обученными векторами и случайными данными для недостающих частей.

In [35]:
def create_emb(vecs, itos, em_sz=300, mult=1.):
    emb = nn.Embedding(len(itos), em_sz, padding_idx=1)
    wgts = emb.weight.data
    vec_dic = {w:vecs.get_word_vector(w) for w in vecs.get_words()}
    miss = []
    for i,w in enumerate(itos):
        try: wgts[i] = tensor(vec_dic[w])
        except: miss.append(w)
    return emb
In [36]:
emb_enc = create_emb(fr_vecs, data.x.vocab.itos)
emb_dec = create_emb(en_vecs, data.y.vocab.itos)
In [37]:emb_enc.weight.size(), emb_dec.weight.size()
Out[37]:(torch.Size([11336, 300]), torch.Size([8152, 300]))
In [38]:model_path = Config().model_path()
In [39]:
torch.save(emb_enc, model_path/'fr_emb.pth')
torch.save(emb_dec, model_path/'en_emb.pth')
In [40]:
emb_enc = torch.load(model_path/'fr_emb.pth')
emb_dec = torch.load(model_path/'en_emb.pth')

Наша модель

Вопрос обзора: какие два типа чисел существуют в глубоком обучении?

Кодеры и декодеры

Сама по себе модель состоит из энкодера и декодера

Кодировщик представляет собой рекуррентную нейронную сеть, и мы передаем ему наше входное предложение, производя вывод (который мы пока отбрасываем) и скрытое состояние. Скрытое состояние — это активация, исходящая из RNN.

Затем это скрытое состояние передается декодеру (другой RNN), который использует его в сочетании с прогнозируемыми выходными данными для получения перевода. Мы зацикливаемся до тех пор, пока декодер не создаст токен заполнения (или до 30 итераций, чтобы убедиться, что это не бесконечный цикл в начале обучения).

Мы будем использовать GRU для нашего кодировщика и отдельный GRU для нашего декодера. Другие варианты — использовать LSTM или QRNN (см. здесь). GRU, LSTM и QRNN решают проблему отсутствия в RNN долговременной памяти.

Ссылки:

In [43]:
class Seq2SeqRNN(nn.Module):
    def __init__(self, emb_enc, emb_dec, 
                    nh, out_sl, 
                    nl=2, bos_idx=0, pad_idx=1):
        super().__init__()
        self.nl,self.nh,self.out_sl = nl,nh,out_sl
        self.bos_idx,self.pad_idx = bos_idx,pad_idx
        self.em_sz_enc = emb_enc.embedding_dim
        self.em_sz_dec = emb_dec.embedding_dim
        self.voc_sz_dec = emb_dec.num_embeddings
                 
        self.emb_enc = emb_enc
        self.emb_enc_drop = nn.Dropout(0.15)
        self.gru_enc = nn.GRU(self.em_sz_enc, nh, num_layers=nl,
                              dropout=0.25, batch_first=True)
        self.out_enc = nn.Linear(nh, self.em_sz_dec, bias=False)
        
        self.emb_dec = emb_dec
        self.gru_dec = nn.GRU(self.em_sz_dec, self.em_sz_dec, num_layers=nl,
                              dropout=0.1, batch_first=True)
        self.out_drop = nn.Dropout(0.35)
        self.out = nn.Linear(self.em_sz_dec, self.voc_sz_dec)
        self.out.weight.data = self.emb_dec.weight.data
        
    def encoder(self, bs, inp):
        h = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        _, h = self.gru_enc(emb, h)
        h = self.out_enc(h)
        return h
    
    def decoder(self, dec_inp, h):
        emb = self.emb_dec(dec_inp).unsqueeze(1)
        outp, h = self.gru_dec(emb, h)
        outp = self.out(self.out_drop(outp[:,0]))
        return h, outp
        
    def forward(self, inp):
        bs, sl = inp.size()
        h = self.encoder(bs, inp)
        dec_inp = inp.new_zeros(bs).long() + self.bos_idx
        
        res = []
        for i in range(self.out_sl):
            h, outp = self.decoder(dec_inp, h)
            dec_inp = outp.max(1)[1]
            res.append(outp)
            if (dec_inp==self.pad_idx).all(): break
        return torch.stack(res, dim=1)
    
    def initHidden(self, bs): return one_param(self).new_zeros(self.nl, bs, self.nh)
In [44]:xb,yb = next(iter(data.valid_dl))
In [45]:xb.shape
Out[45]:torch.Size([64, 30])
In [46]:rnn = Seq2SeqRNN(emb_enc, emb_dec, 256, 30)
In [47]:rnn
Out[47]:
Seq2SeqRNN(
  (emb_enc): Embedding(11336, 300, padding_idx=1)
  (emb_enc_drop): Dropout(p=0.15)
  (gru_enc): GRU(300, 256, num_layers=2, batch_first=True, dropout=0.25)
  (out_enc): Linear(in_features=256, out_features=300, bias=False)
  (emb_dec): Embedding(8152, 300, padding_idx=1)
  (gru_dec): GRU(300, 300, num_layers=2, batch_first=True, dropout=0.1)
  (out_drop): Dropout(p=0.35)
  (out): Linear(in_features=300, out_features=8152, bias=True)
)

In [48]: len(xb[0])
Out[48]:30
In [51]:h = rnn.encoder(64, xb.cpu())
In [52]:h.size()
Out[52]:torch.Size([2, 64, 300])

Поля потерь выводятся и настраиваются таким образом, чтобы они имели одинаковый размер, прежде чем использовать обычную сглаженную версию кросс-энтропии. Мы делаем то же самое для точности.

In [53]:
def seq2seq_loss(out, targ, pad_idx=1):
    bs,targ_len = targ.size()
    _,out_len,vs = out.size()
    if targ_len>out_len: out  = F.pad(out,  (0,0,0,targ_len-out_len,0,0), value=pad_idx)
    if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx)
    return CrossEntropyFlat()(out, targ)

Обучите нашу модель

In [54]:learn = Learner(data, rnn, loss_func=seq2seq_loss)
In [55]:learn.lr_find()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [56]:learn.recorder.plot()

In [57]:learn.fit_one_cycle(4, 1e-2)

Освободим немного оперативной памяти

In [58]:
del fr_vecs
del en_vecs

Поскольку потери не очень интерпретируемы, давайте также посмотрим на точность. Опять же, мы добавим заполнение, чтобы выходные данные и цель имели одинаковую длину.

In [59]:
def seq2seq_acc(out, targ, pad_idx=1):
    bs,targ_len = targ.size()
    _,out_len,vs = out.size()
    if targ_len>out_len: out  = F.pad(out,  (0,0,0,targ_len-out_len,0,0), value=pad_idx)
    if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx)
    out = out.argmax(2)
    return (out==targ).float().mean()

Метрика Bleu (см. специальный блокнот)

В переводе обычно используется метрика BLEU.
Отличный пост от Rachael Tatman: Оценка вывода текста в NLP: BLEU на свой страх и риск

In [60]:
class NGram():
    def __init__(self, ngram, max_n=5000): self.ngram,self.max_n = ngram,max_n
    def __eq__(self, other):
        if len(self.ngram) != len(other.ngram): return False
        return np.all(np.array(self.ngram) == np.array(other.ngram))
    def __hash__(self): return int(sum([o * self.max_n**i for i,o in enumerate(self.ngram)]))

In [61]:
def get_grams(x, n, max_n=5000):
    return x if n==1 else [NGram(x[i:i+n], max_n=max_n) for i in range(len(x)-n+1)]
In [62]:
def get_correct_ngrams(pred, targ, n, max_n=5000):
    pred_grams,targ_grams = get_grams(pred, n, max_n=max_n),get_grams(targ, n, max_n=max_n)
    pred_cnt,targ_cnt = Counter(pred_grams),Counter(targ_grams)
    return sum([min(c, targ_cnt[g]) for g,c in pred_cnt.items()]),len(pred_grams)
In [63]:
class CorpusBLEU(Callback):
    def __init__(self, vocab_sz):
        self.vocab_sz = vocab_sz
        self.name = 'bleu'
    
    def on_epoch_begin(self, **kwargs):
        self.pred_len,self.targ_len,self.corrects,self.counts = 0,0,[0]*4,[0]*4
    
    def on_batch_end(self, last_output, last_target, **kwargs):
        last_output = last_output.argmax(dim=-1)
        for pred,targ in zip(last_output.cpu().numpy(),last_target.cpu().numpy()):
            self.pred_len += len(pred)
            self.targ_len += len(targ)
            for i in range(4):
                c,t = get_correct_ngrams(pred, targ, i+1, max_n=self.vocab_sz)
                self.corrects[i] += c
                self.counts[i]   += t
    
    def on_epoch_end(self, last_metrics, **kwargs):
        precs = [c/t for c,t in zip(self.corrects,self.counts)]
        len_penalty = exp(1 - self.targ_len/self.pred_len) if self.pred_len < self.targ_len else 1
        bleu = len_penalty * ((precs[0]*precs[1]*precs[2]*precs[3]) ** 0.25)
        return add_metrics(last_metrics, bleu)

Обучение с метриками

In [64]:
learn = Learner(data, rnn, loss_func=seq2seq_loss, metrics=[seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))])
In [65]:learn.lr_find()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [66]:learn.recorder.plot()

In [67]:
learn.fit_one_cycle(4, 1e-2)
learn.fit_one_cycle(4, 1e-3)

Насколько хороша наша модель? Давайте посмотрим несколько прогнозов.

In [68]:
def get_predictions(learn, ds_type=DatasetType.Valid):
    learn.model.eval()
    inputs, targets, outputs = [],[],[]
    with torch.no_grad():
        for xb,yb in progress_bar(learn.dl(ds_type)):
            out = learn.model(xb)
            for x,y,z in zip(xb,yb,out):
                inputs.append(learn.data.train_ds.x.reconstruct(x))
                targets.append(learn.data.train_ds.y.reconstruct(y))
                outputs.append(learn.data.train_ds.y.reconstruct(z.argmax(1)))
    return inputs, targets, outputs
In [88]: inputs, targets, outputs = get_predictions(learn)

100.00% [151/151 00:24<00:00]

In [89]:inputs[700], targets[700], outputs[700]
Out[89]:(Text xxbos quels sont les résultats prévus à court et à ...
In [90]:inputs[701], targets[701], outputs[701]
Out[90]:(Text xxbos de quel(s ) xxunk ) a - t - on besoin pour xx...
In [91]:inputs[2513], targets[2513], outputs[2513]
Out[91]:(Text xxbos de quelles façons l'expérience et les capaci...
In [92]:inputs[4000], targets[4000], outputs[4000]
Out[92]:(Text xxbos qu'est - ce que la maladie de xxunk -...

Обычно он начинается хорошо, но в конце вопроса прерывается повторяющимися словами.

Принуждение учителя

Один из способов помочь обучению — помочь декодеру, скармливая ему реальные цели вместо его прогнозов (если он начнет с неправильных слов, очень маловероятно, что он даст нам правильный перевод). Мы делаем это все время в начале, затем постепенно уменьшаем количество принуждения учителя.

In [83]:
class TeacherForcing(LearnerCallback):
    
    def __init__(self, learn, end_epoch):
        super().__init__(learn)
        self.end_epoch = end_epoch
    
    def on_batch_begin(self, last_input, last_target, train, **kwargs):
        if train: return {'last_input': [last_input, last_target]}
    
    def on_epoch_begin(self, epoch, **kwargs):
        self.learn.model.pr_force = 1 - epoch/self.end_epoch

Мы добавим следующий код в наш метод forward:

if (targ is not None) and (random.random()<self.pr_force):
        if i>=targ.shape[1]: break
        dec_inp = targ[:,i]

Кроме того, forward будет принимать дополнительный аргумент target.

In [88]:
class Seq2SeqRNN_tf(nn.Module):
    def __init__(self, emb_enc, emb_dec, nh, out_sl, nl=2, bos_idx=0, pad_idx=1):
        super().__init__()
        self.nl,self.nh,self.out_sl = nl,nh,out_sl
        self.bos_idx,self.pad_idx = bos_idx,pad_idx
        self.em_sz_enc = emb_enc.embedding_dim
        self.em_sz_dec = emb_dec.embedding_dim
        self.voc_sz_dec = emb_dec.num_embeddings
                 
        self.emb_enc = emb_enc
        self.emb_enc_drop = nn.Dropout(0.15)
        self.gru_enc = nn.GRU(self.em_sz_enc, nh, num_layers=nl,
                              dropout=0.25, batch_first=True)
        self.out_enc = nn.Linear(nh, self.em_sz_dec, bias=False)
        
        self.emb_dec = emb_dec
        self.gru_dec = nn.GRU(self.em_sz_dec, self.em_sz_dec, num_layers=nl,
                              dropout=0.1, batch_first=True)
        self.out_drop = nn.Dropout(0.35)
        self.out = nn.Linear(self.em_sz_dec, self.voc_sz_dec)
        self.out.weight.data = self.emb_dec.weight.data
        self.pr_force = 0.
        
    def encoder(self, bs, inp):
        h = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        _, h = self.gru_enc(emb, h)
        h = self.out_enc(h)
        return h
    
    def decoder(self, dec_inp, h):
        emb = self.emb_dec(dec_inp).unsqueeze(1)
        outp, h = self.gru_dec(emb, h)
        outp = self.out(self.out_drop(outp[:,0]))
        return h, outp
            
    def forward(self, inp, targ=None):
        bs, sl = inp.size()
        h = self.encoder(bs, inp)
        dec_inp = inp.new_zeros(bs).long() + self.bos_idx
        
        res = []
        for i in range(self.out_sl):
            h, outp = self.decoder(dec_inp, h)
            res.append(outp)
            dec_inp = outp.max(1)[1]
            if (dec_inp==self.pad_idx).all(): break
            if (targ is not None) and (random.random()<self.pr_force):
                if i>=targ.shape[1]: continue
                dec_inp = targ[:,i]
        return torch.stack(res, dim=1)
    def initHidden(self, bs): return one_param(self).new_zeros(self.nl, bs, self.nh)
In [90]:
emb_enc = torch.load(model_path/'fr_emb.pth')
emb_dec = torch.load(model_path/'en_emb.pth')
In [91]:
rnn_tf = Seq2SeqRNN_tf(emb_enc, emb_dec, 256, 30)
learn = Learner(data, rnn_tf, loss_func=seq2seq_loss, metrics=[seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))],
               callback_fns=partial(TeacherForcing, end_epoch=3))
In [74]:
learn.lr_find()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
In [75]:learn.recorder.plot()

In [92]:learn.fit_one_cycle(6, 3e-3)

In [77]:inputs, targets, outputs = get_predictions(learn)

100.00% [151/151 00:23<00:00]

In [78]:inputs[700],targets[700],outputs[700]
Out[78]:(Text xxbos qui a le pouvoir de modifier ...
In [79]:inputs[2513], targets[2513], outputs[2513]
Out[79]:(Text xxbos quelles sont les deux tendances qui ont...
In [80]:inputs[4000], targets[4000], outputs[4000]
Out[80]:(Text xxbos où les aires marines nationales de conser...

Время перейти к ВНИМАНИЕ!?

"Эта последовательная природа [RNN] препятствует распараллеливанию в обучающих примерах, что становится критическим при более длинных последовательностях, поскольку ограничения памяти ограничивают пакетную обработку примеров".

Перевод Seq2Seq с вниманием

Внимание — это метод, который использует выходные данные нашего кодировщика: вместо того, чтобы полностью его отбрасывать, мы используем его с нашим скрытым состоянием, чтобы обращать внимание на определенные слова во входном предложении для прогнозов в выходном предложении. В частности, мы вычисляем веса внимания, а затем добавляем к входным данным декодера линейную комбинацию выходных данных кодировщика с этими весами внимания.

Хорошая иллюстрация внимания исходит из этого поста в блоге Джея Аламмара (визуализация изначально из Tensor2Tensor Notebook):

Вторая вещь, которая может помочь, — это использование двунаправленной модели для кодировщика. Мы устанавливаем параметр bidrectional равным True для нашего кодировщика GRU и удваиваем количество входных данных для уровня линейного вывода кодировщика.

Кроме того, теперь нам нужно установить наше скрытое состояние:

hid = hid.view(2,self.n_layers, bs, self.n_hid).permute(1,2,0,3).contiguous()
hid = self.out_enc(self.hid_dp(hid).view(self.n_layers, bs, 2*self.n_hid))

Код для повторного запуска с самого начала

In [1]:from fastai.text import *
In [2]:path = Config().data_path()
In [3]:path = Config().data_path()/'giga-fren'
In [4]:
def seq2seq_collate(samples:BatchSamples, pad_idx:int=1, pad_first:bool=True, backwards:bool=False) -> Tuple[LongTensor, LongTensor]:
    "Function that collect samples and adds padding. Flips token order if needed"
    samples = to_data(samples)
    max_len_x,max_len_y = max([len(s[0]) for s in samples]),max([len(s[1]) for s in samples])
    res_x = torch.zeros(len(samples), max_len_x).long() + pad_idx
    res_y = torch.zeros(len(samples), max_len_y).long() + pad_idx
    if backwards: pad_first = not pad_first
    for i,s in enumerate(samples):
        if pad_first: 
            res_x[i,-len(s[0]):],res_y[i,-len(s[1]):] = LongTensor(s[0]),LongTensor(s[1])
        else:         
            res_x[i,:len(s[0])],res_y[i,:len(s[1])] = LongTensor(s[0]),LongTensor(s[1])
    if backwards: res_x,res_y = res_x.flip(1),res_y.flip(1)
    return res_x,res_y
class Seq2SeqDataBunch(TextDataBunch):
    "Create a `TextDataBunch` suitable for training an RNN classifier."
    @classmethod
    def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', bs:int=32, val_bs:int=None, pad_idx=1,
               dl_tfms=None, pad_first=False, device:torch.device=None, no_check:bool=False, backwards:bool=False, **dl_kwargs) -> DataBunch:
        "Function that transform the `datasets` in a `DataBunch` for classification. Passes `**dl_kwargs` on to `DataLoader()`"
        datasets = cls._init_ds(train_ds, valid_ds, test_ds)
        val_bs = ifnone(val_bs, bs)
        collate_fn = partial(seq2seq_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards)
        train_sampler = SortishSampler(datasets[0].x, key=lambda t: len(datasets[0][t][0].data), bs=bs//2)
        train_dl = DataLoader(datasets[0], batch_size=bs, sampler=train_sampler, drop_last=True, **dl_kwargs)
        dataloaders = [train_dl]
        for ds in datasets[1:]:
            lengths = [len(t) for t in ds.x.items]
            sampler = SortSampler(ds.x, key=lengths.__getitem__)
            dataloaders.append(DataLoader(ds, batch_size=val_bs, sampler=sampler, **dl_kwargs))
        return cls(*dataloaders, path=path, device=device, collate_fn=collate_fn, no_check=no_check)
class Seq2SeqTextList(TextList):
    _bunch = Seq2SeqDataBunch
    _label_cls = TextList
data = load_data(path)
model_path = Config().model_path()
emb_enc = torch.load(model_path/'fr_emb.pth')
emb_dec = torch.load(model_path/'en_emb.pth')
def seq2seq_loss(out, targ, pad_idx=1):
    bs,targ_len = targ.size()
    _,out_len,vs = out.size()
    if targ_len>out_len: out  = F.pad(out,  (0,0,0,targ_len-out_len,0,0), value=pad_idx)
    if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx)
    return CrossEntropyFlat()(out, targ)
def seq2seq_acc(out, targ, pad_idx=1):
    bs,targ_len = targ.size()
    _,out_len,vs = out.size()
    if targ_len>out_len: out  = F.pad(out,  (0,0,0,targ_len-out_len,0,0), value=pad_idx)
    if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx)
    out = out.argmax(2)
    return (out==targ).float().mean()
class NGram():
    def __init__(self, ngram, max_n=5000): self.ngram,self.max_n = ngram,max_n
    def __eq__(self, other):
        if len(self.ngram) != len(other.ngram): return False
        return np.all(np.array(self.ngram) == np.array(other.ngram))
    def __hash__(self): return int(sum([o * self.max_n**i for i,o in enumerate(self.ngram)]))
def get_grams(x, n, max_n=5000):
    return x if n==1 else [NGram(x[i:i+n], max_n=max_n) for i in range(len(x)-n+1)]
def get_correct_ngrams(pred, targ, n, max_n=5000):
    pred_grams,targ_grams = get_grams(pred, n, max_n=max_n),get_grams(targ, n, max_n=max_n)
    pred_cnt,targ_cnt = Counter(pred_grams),Counter(targ_grams)
    return sum([min(c, targ_cnt[g]) for g,c in pred_cnt.items()]),len(pred_grams)
def get_predictions(learn, ds_type=DatasetType.Valid):
    learn.model.eval()
    inputs, targets, outputs = [],[],[]
    with torch.no_grad():
        for xb,yb in progress_bar(learn.dl(ds_type)):
            out = learn.model(xb)
            for x,y,z in zip(xb,yb,out):
                inputs.append(learn.data.train_ds.x.reconstruct(x))
                targets.append(learn.data.train_ds.y.reconstruct(y))
                outputs.append(learn.data.train_ds.y.reconstruct(z.argmax(1)))
    return inputs, targets, outputs
class CorpusBLEU(Callback):
    def __init__(self, vocab_sz):
        self.vocab_sz = vocab_sz
        self.name = 'bleu'
    
    def on_epoch_begin(self, **kwargs):
        self.pred_len,self.targ_len,self.corrects,self.counts = 0,0,[0]*4,[0]*4
    
    def on_batch_end(self, last_output, last_target, **kwargs):
        last_output = last_output.argmax(dim=-1)
        for pred,targ in zip(last_output.cpu().numpy(),last_target.cpu().numpy()):
            self.pred_len += len(pred)
            self.targ_len += len(targ)
            for i in range(4):
                c,t = get_correct_ngrams(pred, targ, i+1, max_n=self.vocab_sz)
                self.corrects[i] += c
                self.counts[i]   += t
    
    def on_epoch_end(self, last_metrics, **kwargs):
        precs = [c/t for c,t in zip(self.corrects,self.counts)]
        len_penalty = exp(1 - self.targ_len/self.pred_len) if self.pred_len < self.targ_len else 1
        bleu = len_penalty * ((precs[0]*precs[1]*precs[2]*precs[3]) ** 0.25)
        return add_metrics(last_metrics, bleu)
class TeacherForcing(LearnerCallback):
    def __init__(self, learn, end_epoch):
        super().__init__(learn)
        self.end_epoch = end_epoch
    
    def on_batch_begin(self, last_input, last_target, train, **kwargs):
        if train: return {'last_input': [last_input, last_target]}
    
    def on_epoch_begin(self, epoch, **kwargs):
        self.learn.model.pr_force = 1 - epoch/self.end_epoch

Реализация внимания

class Seq2SeqRNN_attn(nn.Module):
    def __init__(self, emb_enc, emb_dec, nh, out_sl, nl=2, bos_idx=0, pad_idx=1):
        super().__init__()
        self.nl,self.nh,self.out_sl,self.pr_force = nl,nh,out_sl,1
        self.bos_idx,self.pad_idx = bos_idx,pad_idx
        self.emb_enc,self.emb_dec = emb_enc,emb_dec
        self.emb_sz_enc,self.emb_sz_dec = emb_enc.embedding_dim,emb_dec.embedding_dim
        self.voc_sz_dec = emb_dec.num_embeddings
                 
        self.emb_enc_drop = nn.Dropout(0.15)
        self.gru_enc = nn.GRU(self.emb_sz_enc, nh, num_layers=nl, dropout=0.25, 
                              batch_first=True, bidirectional=True)
        self.out_enc = nn.Linear(2*nh, self.emb_sz_dec, bias=False)
        
        self.gru_dec = nn.GRU(self.emb_sz_dec + 2*nh, self.emb_sz_dec, num_layers=nl,
                              dropout=0.1, batch_first=True)
        self.out_drop = nn.Dropout(0.35)
        self.out = nn.Linear(self.emb_sz_dec, self.voc_sz_dec)
        self.out.weight.data = self.emb_dec.weight.data
        
        self.enc_att = nn.Linear(2*nh, self.emb_sz_dec, bias=False)
        self.hid_att = nn.Linear(self.emb_sz_dec, self.emb_sz_dec)
        self.V =  self.init_param(self.emb_sz_dec)
        
    def encoder(self, bs, inp):
        h = self.initHidden(bs)
        emb = self.emb_enc_drop(self.emb_enc(inp))
        enc_out, hid = self.gru_enc(emb, 2*h)
        
        pre_hid = hid.view(2, self.nl, bs, self.nh).permute(1,2,0,3).contiguous()
        pre_hid = pre_hid.view(self.nl, bs, 2*self.nh)
        hid = self.out_enc(pre_hid)
        return hid,enc_out
    
    def decoder(self, dec_inp, hid, enc_att, enc_out):
        hid_att = self.hid_att(hid[-1])
        # we have put enc_out and hid through linear layers
        u = torch.tanh(enc_att + hid_att[:,None])
        # we want to learn the importance of each time step
        attn_wgts = F.softmax(u @ self.V, 1)
        # weighted average of enc_out (which is the output at every time step)
        ctx = (attn_wgts[...,None] * enc_out).sum(1)
        emb = self.emb_dec(dec_inp)
        # concatenate decoder embedding with context (we could have just
        # used the hidden state that came out of the decoder, if we weren't
        # using attention)
        outp, hid = self.gru_dec(torch.cat([emb, ctx], 1)[:,None], hid)
        outp = self.out(self.out_drop(outp[:,0]))
        return hid, outp
        
    def show(self, nm,v):
        if False: print(f"{nm}={v[nm].shape}")
        
    def forward(self, inp, targ=None):
        bs, sl = inp.size()
        hid,enc_out = self.encoder(bs, inp)
#        self.show("hid",vars())
        dec_inp = inp.new_zeros(bs).long() + self.bos_idx
        enc_att = self.enc_att(enc_out)
        
        res = []
        for i in range(self.out_sl):
            hid, outp = self.decoder(dec_inp, hid, enc_att, enc_out)
            res.append(outp)
            dec_inp = outp.max(1)[1]
            if (dec_inp==self.pad_idx).all(): break
            if (targ is not None) and (random.random()<self.pr_force):
                if i>=targ.shape[1]: continue
                dec_inp = targ[:,i]
        return torch.stack(res, dim=1)
    def initHidden(self, bs): return one_param(self).new_zeros(2*self.nl, bs, self.nh)
    def init_param(self, *sz): return nn.Parameter(torch.randn(sz)/math.sqrt(sz[0]))
hid=torch.Size([2, 64, 300])
dec_inp=torch.Size([64])
enc_att=torch.Size([64, 30, 300])
hid_att=torch.Size([64, 300])
u=torch.Size([64, 30, 300])
attn_wgts=torch.Size([64, 30])
enc_out=torch.Size([64, 30, 512])
ctx=torch.Size([64, 512])
emb=torch.Size([64, 300])

model = Seq2SeqRNN_attn(emb_enc, emb_dec, 256, 30)
learn = Learner(data, model, loss_func=seq2seq_loss, metrics=[seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))],
                callback_fns=partial(TeacherForcing, end_epoch=30))
learn.lr_find()
learn.recorder.plot()
LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.

learn.fit_one_cycle(15, 3e-3)

In [146]:inputs, targets, outputs = get_predictions(learn)

100.00% [151/151 00:34<00:00]

In [147]:inputs[700], targets[700], outputs[700]
Out[147]:(Text xxbos qui a le pouvoir de modifier le règlement ...
In [148]:inputs[701], targets[701], outputs[701]
Out[148]:(Text xxbos ´ ` ou sont xxunk leurs grandes convictions...
In [149]:inputs[4002], targets[4002], outputs[4002]
Out[149]:(Text xxbos quelles ressources votre communauté possède...

Конец

Кредиты:

https://www.fast.ai/