Перевод с RNN
Часть 4: Двунаправленный и внимание RNN
В полном руководстве по НЛП с фастаем
Перейдите по ссылке на всю серию, нажав здесь: Полное руководство по НЛП с фастай
В этом посте будет собрано все, что мы узнали до этого момента, а затем представлен перевод с помощью RNN.
Это захватывающе, потому что результаты нашего путешествия по изучению НЛП можно резюмировать следующим образом:
В восторге???
Начнем…
Перевод с RNN
В этом посте мы займемся переводом. Мы будем переводить с французского на английский, и, чтобы наша задача оставалась управляемой, мы ограничимся переводом вопросов.
Эта задача является примером последовательности для последовательности (seq2seq). Seq2seq может быть более сложной задачей, чем классификация, поскольку выходные данные имеют переменную длину (и обычно отличаются от длины входных данных.
Французско-английские параллельные тексты с http://www.statmt.org/wmt15/translation-task.html. Он был создан Крисом Каллисон-Бёрчем, который просканировал миллионы веб-страниц, а затем использовал набор простых эвристик для преобразования URL-адресов на французском языке в URL-адреса на английском языке (т. е. заменив fr на en и около 40 других рукописных правила) и предположить, что эти документы являются переводами друг друга.
Перевод намного сложнее в прямом PyTorch: https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
In [1]:from fastai.text import *
Загрузите и предварительно обработайте наши данные
Мы начнем с сокращения исходного набора данных до вопросов. Вам нужно выполнить это только один раз, раскомментируйте для запуска. Набор данных можно скачать здесь.
In [2]:path = Config().data_path() In [3]: # ! wget https://s3.amazonaws.com/fast-ai-nlp/giga-fren.tgz -P {path} In [4]:# ! tar xf {path}/giga-fren.tgz -C {path} In [3]: path = Config().data_path()/'giga-fren' path.ls() Out[3]: [PosixPath('/home/racheltho/.fastai/data/giga-fren/models'), ... In [6]:# with open(path/'giga-fren.release2.fixed.fr') as f: fr = f.read().split('\n') In [7]:# with open(path/'giga-fren.release2.fixed.en') as f: en = f.read().split('\n')
Мы будем использовать регулярное выражение для выбора вопросов, находя строки в наборе данных на английском языке, которые начинаются с «Wh» и заканчиваются вопросительным знаком. Вам нужно запустить эти строки только один раз:
In [8]: # re_eq = re.compile('^(Wh[^?.!]+\?)') # re_fq = re.compile('^([^?.!]+\?)') # en_fname = path/'giga-fren.release2.fixed.en' # fr_fname = path/'giga-fren.release2.fixed.fr' In [9]: # lines = ((re_eq.search(eq), re_fq.search(fq)) # for eq, fq in zip(open(en_fname, encoding='utf-8'), open(fr_fname, encoding='utf-8'))) # qs = [(e.group(), f.group()) for e,f in lines if e and f] In [10]: # qs = [(q1,q2) for q1,q2 in qs] # df = pd.DataFrame({'fr': [q[1] for q in qs], 'en': [q[0] for q in qs]}, columns = ['en', 'fr']) # df.to_csv(path/'questions_easy.csv', index=False) In [11]: path.ls() Out[11]: [PosixPath('/home/racheltho/.fastai/data/giga-fren/models'), ...
Загрузите наши данные в DataBunch
Теперь наши вопросы выглядят так:
In [4]: df = pd.read_csv(path/'questions_easy.csv') df.head()
Для простоты мы пишем все строчными буквами.
In [5]: df['en'] = df['en'].apply(lambda x:x.lower()) df['fr'] = df['fr'].apply(lambda x:x.lower())
Во-первых, нам нужно будет сопоставить входные данные и цели в пакете: они имеют разную длину, поэтому нам нужно добавить отступы, чтобы длина последовательности была одинаковой;
In [7]: def seq2seq_collate(samples, pad_idx=1, pad_first=True, backwards=False): "Function that collect samples and adds padding. Flips token order if needed" samples = to_data(samples) max_len_x,max_len_y = max([len(s[0]) for s in samples]),max([len(s[1]) for s in samples]) res_x = torch.zeros(len(samples), max_len_x).long() + pad_idx res_y = torch.zeros(len(samples), max_len_y).long() + pad_idx if backwards: pad_first = not pad_first for i,s in enumerate(samples): if pad_first: res_x[i,-len(s[0]):],res_y[i,-len(s[1]):] = LongTensor(s[0]),LongTensor(s[1]) else: res_x[i,:len(s[0]):],res_y[i,:len(s[1]):] = LongTensor(s[0]),LongTensor(s[1]) if backwards: res_x,res_y = res_x.flip(1),res_y.flip(1) return res_x,res_y
Затем мы создаем специальный DataBunch
, который использует эту функцию сопоставления.
In [8]:doc(Dataset) In [9]:doc(DataLoader) In [6]:doc(DataBunch) In [20]: class Seq2SeqDataBunch(TextDataBunch): "Create a `TextDataBunch` suitable for training an RNN classifier." @classmethod def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', bs:int=32, val_bs:int=None, pad_idx=1, dl_tfms=None, pad_first=False, device:torch.device=None, no_check:bool=False, backwards:bool=False, **dl_kwargs) -> DataBunch: "Function that transform the `datasets` in a `DataBunch` for classification. Passes `**dl_kwargs` on to `DataLoader()`" datasets = cls._init_ds(train_ds, valid_ds, test_ds) val_bs = ifnone(val_bs, bs) collate_fn = partial(seq2seq_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards) train_sampler = SortishSampler(datasets[0].x, key=lambda t: len(datasets[0][t][0].data), bs=bs//2) train_dl = DataLoader(datasets[0], batch_size=bs, sampler=train_sampler, drop_last=True, **dl_kwargs) dataloaders = [train_dl] for ds in datasets[1:]: lengths = [len(t) for t in ds.x.items] sampler = SortSampler(ds.x, key=lengths.__getitem__) dataloaders.append(DataLoader(ds, batch_size=val_bs, sampler=sampler, **dl_kwargs)) return cls(*dataloaders, path=path, device=device, collate_fn=collate_fn, no_check=no_check) In [ ]:SortishSampler??
И подкласс TextList
, который будет использовать этот класс DataBunch
в вызове .databunch
и будет использовать TextList
для маркировки (поскольку нашими целями являются другие тексты).
In [21]: class Seq2SeqTextList(TextList): _bunch = Seq2SeqDataBunch _label_cls = TextList
Это все, что нам нужно для использования API блока данных!
In [22]: src = Seq2SeqTextList.from_df(df, path = path, cols='fr').split_by_rand_pct(seed=42).label_from_df(cols='en', label_cls=TextList) In [23]: np.percentile([len(o) for o in src.train.x.items] + [len(o) for o in src.valid.x.items], 90) Out[23]:28.0 In [24]: np.percentile([len(o) for o in src.train.y.items] + [len(o) for o in src.valid.y.items], 90) Out[24]:23.0
Мы удаляем элементы, в которых одна из целей имеет длину более 30 токенов.
In [25]: src = src.filter_by_func(lambda x,y: len(x) > 30 or len(y) > 30) In [26]:len(src.train) + len(src.valid) Out[26]:48352 In [27]:data = src.databunch() In [28]:data.save() In [29]:data Out[29]: Seq2SeqDataBunch; Train: LabelList (38706 items) x: Seq2SeqTextList xxbos qu’est - ce que la lumière ? ... Test: None In [30]:path Out[30]:PosixPath('/home/racheltho/.fastai/data/giga-fren') In [31]:data = load_data(path) In [32]:data.show_batch()
Создайте нашу модель
Предварительно обученные встраивания
Вам нужно будет загрузить вложения слов (векторы сканирования) из документации fastText. FastText имеет предварительно обученные векторы слов для 157 языков, обученные на Common Crawl и Википедии. Эти модели были обучены с помощью CBOW.
Если вам нужно освежить в памяти информацию о встраивании слов, вы можете ознакомиться с моим нежным введением в этом семинаре по встраиванию слов с сопроводительным репозиторием github.
Дополнительная информация о CBOW (непрерывный пакет слов и пропуск граммов):
- учебник по быстрому тексту
- "Переполнение стека"
Чтобы установить FastText:
$ git clone https://github.com/facebookresearch/fastText.git
$ cd fastText
$ pip install .
In [33]: import fastText as ft
Строки для загрузки векторов слов нужно запустить только один раз:
In [60]: # ! wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz -P {path} # ! wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.fr.300.bin.gz -P {path} In [61]: # gunzip {path} / cc.en.300.bin.gz # gunzip {path} / cc.fr.300.bin.gz In [34]: fr_vecs = ft.load_model(str((path/'cc.fr.300.bin'))) en_vecs = ft.load_model(str((path/'cc.en.300.bin')))
Мы создаем модуль встраивания с предварительно обученными векторами и случайными данными для недостающих частей.
In [35]: def create_emb(vecs, itos, em_sz=300, mult=1.): emb = nn.Embedding(len(itos), em_sz, padding_idx=1) wgts = emb.weight.data vec_dic = {w:vecs.get_word_vector(w) for w in vecs.get_words()} miss = [] for i,w in enumerate(itos): try: wgts[i] = tensor(vec_dic[w]) except: miss.append(w) return emb In [36]: emb_enc = create_emb(fr_vecs, data.x.vocab.itos) emb_dec = create_emb(en_vecs, data.y.vocab.itos) In [37]:emb_enc.weight.size(), emb_dec.weight.size() Out[37]:(torch.Size([11336, 300]), torch.Size([8152, 300])) In [38]:model_path = Config().model_path() In [39]: torch.save(emb_enc, model_path/'fr_emb.pth') torch.save(emb_dec, model_path/'en_emb.pth') In [40]: emb_enc = torch.load(model_path/'fr_emb.pth') emb_dec = torch.load(model_path/'en_emb.pth')
Наша модель
Вопрос обзора: какие два типа чисел существуют в глубоком обучении?
Кодеры и декодеры
Сама по себе модель состоит из энкодера и декодера
Кодировщик представляет собой рекуррентную нейронную сеть, и мы передаем ему наше входное предложение, производя вывод (который мы пока отбрасываем) и скрытое состояние. Скрытое состояние — это активация, исходящая из RNN.
Затем это скрытое состояние передается декодеру (другой RNN), который использует его в сочетании с прогнозируемыми выходными данными для получения перевода. Мы зацикливаемся до тех пор, пока декодер не создаст токен заполнения (или до 30 итераций, чтобы убедиться, что это не бесконечный цикл в начале обучения).
Мы будем использовать GRU для нашего кодировщика и отдельный GRU для нашего декодера. Другие варианты — использовать LSTM или QRNN (см. здесь). GRU, LSTM и QRNN решают проблему отсутствия в RNN долговременной памяти.
Ссылки:
In [43]: class Seq2SeqRNN(nn.Module): def __init__(self, emb_enc, emb_dec, nh, out_sl, nl=2, bos_idx=0, pad_idx=1): super().__init__() self.nl,self.nh,self.out_sl = nl,nh,out_sl self.bos_idx,self.pad_idx = bos_idx,pad_idx self.em_sz_enc = emb_enc.embedding_dim self.em_sz_dec = emb_dec.embedding_dim self.voc_sz_dec = emb_dec.num_embeddings self.emb_enc = emb_enc self.emb_enc_drop = nn.Dropout(0.15) self.gru_enc = nn.GRU(self.em_sz_enc, nh, num_layers=nl, dropout=0.25, batch_first=True) self.out_enc = nn.Linear(nh, self.em_sz_dec, bias=False) self.emb_dec = emb_dec self.gru_dec = nn.GRU(self.em_sz_dec, self.em_sz_dec, num_layers=nl, dropout=0.1, batch_first=True) self.out_drop = nn.Dropout(0.35) self.out = nn.Linear(self.em_sz_dec, self.voc_sz_dec) self.out.weight.data = self.emb_dec.weight.data def encoder(self, bs, inp): h = self.initHidden(bs) emb = self.emb_enc_drop(self.emb_enc(inp)) _, h = self.gru_enc(emb, h) h = self.out_enc(h) return h def decoder(self, dec_inp, h): emb = self.emb_dec(dec_inp).unsqueeze(1) outp, h = self.gru_dec(emb, h) outp = self.out(self.out_drop(outp[:,0])) return h, outp def forward(self, inp): bs, sl = inp.size() h = self.encoder(bs, inp) dec_inp = inp.new_zeros(bs).long() + self.bos_idx res = [] for i in range(self.out_sl): h, outp = self.decoder(dec_inp, h) dec_inp = outp.max(1)[1] res.append(outp) if (dec_inp==self.pad_idx).all(): break return torch.stack(res, dim=1) def initHidden(self, bs): return one_param(self).new_zeros(self.nl, bs, self.nh) In [44]:xb,yb = next(iter(data.valid_dl)) In [45]:xb.shape Out[45]:torch.Size([64, 30]) In [46]:rnn = Seq2SeqRNN(emb_enc, emb_dec, 256, 30) In [47]:rnn Out[47]: Seq2SeqRNN( (emb_enc): Embedding(11336, 300, padding_idx=1) (emb_enc_drop): Dropout(p=0.15) (gru_enc): GRU(300, 256, num_layers=2, batch_first=True, dropout=0.25) (out_enc): Linear(in_features=256, out_features=300, bias=False) (emb_dec): Embedding(8152, 300, padding_idx=1) (gru_dec): GRU(300, 300, num_layers=2, batch_first=True, dropout=0.1) (out_drop): Dropout(p=0.35) (out): Linear(in_features=300, out_features=8152, bias=True) ) In [48]: len(xb[0]) Out[48]:30 In [51]:h = rnn.encoder(64, xb.cpu()) In [52]:h.size() Out[52]:torch.Size([2, 64, 300])
Поля потерь выводятся и настраиваются таким образом, чтобы они имели одинаковый размер, прежде чем использовать обычную сглаженную версию кросс-энтропии. Мы делаем то же самое для точности.
In [53]: def seq2seq_loss(out, targ, pad_idx=1): bs,targ_len = targ.size() _,out_len,vs = out.size() if targ_len>out_len: out = F.pad(out, (0,0,0,targ_len-out_len,0,0), value=pad_idx) if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx) return CrossEntropyFlat()(out, targ)
Обучите нашу модель
In [54]:learn = Learner(data, rnn, loss_func=seq2seq_loss) In [55]:learn.lr_find() LR Finder is complete, type {learner_name}.recorder.plot() to see the graph. In [56]:learn.recorder.plot()
In [57]:learn.fit_one_cycle(4, 1e-2)
Освободим немного оперативной памяти
In [58]: del fr_vecs del en_vecs
Поскольку потери не очень интерпретируемы, давайте также посмотрим на точность. Опять же, мы добавим заполнение, чтобы выходные данные и цель имели одинаковую длину.
In [59]: def seq2seq_acc(out, targ, pad_idx=1): bs,targ_len = targ.size() _,out_len,vs = out.size() if targ_len>out_len: out = F.pad(out, (0,0,0,targ_len-out_len,0,0), value=pad_idx) if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx) out = out.argmax(2) return (out==targ).float().mean()
Метрика Bleu (см. специальный блокнот)
В переводе обычно используется метрика BLEU.
Отличный пост от Rachael Tatman: Оценка вывода текста в NLP: BLEU на свой страх и риск
In [60]: class NGram(): def __init__(self, ngram, max_n=5000): self.ngram,self.max_n = ngram,max_n def __eq__(self, other): if len(self.ngram) != len(other.ngram): return False return np.all(np.array(self.ngram) == np.array(other.ngram)) def __hash__(self): return int(sum([o * self.max_n**i for i,o in enumerate(self.ngram)])) In [61]: def get_grams(x, n, max_n=5000): return x if n==1 else [NGram(x[i:i+n], max_n=max_n) for i in range(len(x)-n+1)] In [62]: def get_correct_ngrams(pred, targ, n, max_n=5000): pred_grams,targ_grams = get_grams(pred, n, max_n=max_n),get_grams(targ, n, max_n=max_n) pred_cnt,targ_cnt = Counter(pred_grams),Counter(targ_grams) return sum([min(c, targ_cnt[g]) for g,c in pred_cnt.items()]),len(pred_grams) In [63]: class CorpusBLEU(Callback): def __init__(self, vocab_sz): self.vocab_sz = vocab_sz self.name = 'bleu' def on_epoch_begin(self, **kwargs): self.pred_len,self.targ_len,self.corrects,self.counts = 0,0,[0]*4,[0]*4 def on_batch_end(self, last_output, last_target, **kwargs): last_output = last_output.argmax(dim=-1) for pred,targ in zip(last_output.cpu().numpy(),last_target.cpu().numpy()): self.pred_len += len(pred) self.targ_len += len(targ) for i in range(4): c,t = get_correct_ngrams(pred, targ, i+1, max_n=self.vocab_sz) self.corrects[i] += c self.counts[i] += t def on_epoch_end(self, last_metrics, **kwargs): precs = [c/t for c,t in zip(self.corrects,self.counts)] len_penalty = exp(1 - self.targ_len/self.pred_len) if self.pred_len < self.targ_len else 1 bleu = len_penalty * ((precs[0]*precs[1]*precs[2]*precs[3]) ** 0.25) return add_metrics(last_metrics, bleu)
Обучение с метриками
In [64]: learn = Learner(data, rnn, loss_func=seq2seq_loss, metrics=[seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))]) In [65]:learn.lr_find() LR Finder is complete, type {learner_name}.recorder.plot() to see the graph. In [66]:learn.recorder.plot()
In [67]: learn.fit_one_cycle(4, 1e-2) learn.fit_one_cycle(4, 1e-3)
Насколько хороша наша модель? Давайте посмотрим несколько прогнозов.
In [68]: def get_predictions(learn, ds_type=DatasetType.Valid): learn.model.eval() inputs, targets, outputs = [],[],[] with torch.no_grad(): for xb,yb in progress_bar(learn.dl(ds_type)): out = learn.model(xb) for x,y,z in zip(xb,yb,out): inputs.append(learn.data.train_ds.x.reconstruct(x)) targets.append(learn.data.train_ds.y.reconstruct(y)) outputs.append(learn.data.train_ds.y.reconstruct(z.argmax(1))) return inputs, targets, outputs In [88]: inputs, targets, outputs = get_predictions(learn)
100.00% [151/151 00:24<00:00]
In [89]:inputs[700], targets[700], outputs[700] Out[89]:(Text xxbos quels sont les résultats prévus à court et à ... In [90]:inputs[701], targets[701], outputs[701] Out[90]:(Text xxbos de quel(s ) xxunk ) a - t - on besoin pour xx... In [91]:inputs[2513], targets[2513], outputs[2513] Out[91]:(Text xxbos de quelles façons l'expérience et les capaci... In [92]:inputs[4000], targets[4000], outputs[4000] Out[92]:(Text xxbos qu'est - ce que la maladie de xxunk -...
Обычно он начинается хорошо, но в конце вопроса прерывается повторяющимися словами.
Принуждение учителя
Один из способов помочь обучению — помочь декодеру, скармливая ему реальные цели вместо его прогнозов (если он начнет с неправильных слов, очень маловероятно, что он даст нам правильный перевод). Мы делаем это все время в начале, затем постепенно уменьшаем количество принуждения учителя.
In [83]: class TeacherForcing(LearnerCallback): def __init__(self, learn, end_epoch): super().__init__(learn) self.end_epoch = end_epoch def on_batch_begin(self, last_input, last_target, train, **kwargs): if train: return {'last_input': [last_input, last_target]} def on_epoch_begin(self, epoch, **kwargs): self.learn.model.pr_force = 1 - epoch/self.end_epoch
Мы добавим следующий код в наш метод forward
:
if (targ is not None) and (random.random()<self.pr_force):
if i>=targ.shape[1]: break
dec_inp = targ[:,i]
Кроме того, forward
будет принимать дополнительный аргумент target
.
In [88]: class Seq2SeqRNN_tf(nn.Module): def __init__(self, emb_enc, emb_dec, nh, out_sl, nl=2, bos_idx=0, pad_idx=1): super().__init__() self.nl,self.nh,self.out_sl = nl,nh,out_sl self.bos_idx,self.pad_idx = bos_idx,pad_idx self.em_sz_enc = emb_enc.embedding_dim self.em_sz_dec = emb_dec.embedding_dim self.voc_sz_dec = emb_dec.num_embeddings self.emb_enc = emb_enc self.emb_enc_drop = nn.Dropout(0.15) self.gru_enc = nn.GRU(self.em_sz_enc, nh, num_layers=nl, dropout=0.25, batch_first=True) self.out_enc = nn.Linear(nh, self.em_sz_dec, bias=False) self.emb_dec = emb_dec self.gru_dec = nn.GRU(self.em_sz_dec, self.em_sz_dec, num_layers=nl, dropout=0.1, batch_first=True) self.out_drop = nn.Dropout(0.35) self.out = nn.Linear(self.em_sz_dec, self.voc_sz_dec) self.out.weight.data = self.emb_dec.weight.data self.pr_force = 0. def encoder(self, bs, inp): h = self.initHidden(bs) emb = self.emb_enc_drop(self.emb_enc(inp)) _, h = self.gru_enc(emb, h) h = self.out_enc(h) return h def decoder(self, dec_inp, h): emb = self.emb_dec(dec_inp).unsqueeze(1) outp, h = self.gru_dec(emb, h) outp = self.out(self.out_drop(outp[:,0])) return h, outp def forward(self, inp, targ=None): bs, sl = inp.size() h = self.encoder(bs, inp) dec_inp = inp.new_zeros(bs).long() + self.bos_idx res = [] for i in range(self.out_sl): h, outp = self.decoder(dec_inp, h) res.append(outp) dec_inp = outp.max(1)[1] if (dec_inp==self.pad_idx).all(): break if (targ is not None) and (random.random()<self.pr_force): if i>=targ.shape[1]: continue dec_inp = targ[:,i] return torch.stack(res, dim=1) def initHidden(self, bs): return one_param(self).new_zeros(self.nl, bs, self.nh) In [90]: emb_enc = torch.load(model_path/'fr_emb.pth') emb_dec = torch.load(model_path/'en_emb.pth') In [91]: rnn_tf = Seq2SeqRNN_tf(emb_enc, emb_dec, 256, 30) learn = Learner(data, rnn_tf, loss_func=seq2seq_loss, metrics=[seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))], callback_fns=partial(TeacherForcing, end_epoch=3)) In [74]: learn.lr_find() LR Finder is complete, type {learner_name}.recorder.plot() to see the graph. In [75]:learn.recorder.plot()
In [92]:learn.fit_one_cycle(6, 3e-3)
In [77]:inputs, targets, outputs = get_predictions(learn)
100.00% [151/151 00:23<00:00]
In [78]:inputs[700],targets[700],outputs[700] Out[78]:(Text xxbos qui a le pouvoir de modifier ... In [79]:inputs[2513], targets[2513], outputs[2513] Out[79]:(Text xxbos quelles sont les deux tendances qui ont... In [80]:inputs[4000], targets[4000], outputs[4000] Out[80]:(Text xxbos où les aires marines nationales de conser...
Время перейти к ВНИМАНИЕ!?
"Эта последовательная природа [RNN] препятствует распараллеливанию в обучающих примерах, что становится критическим при более длинных последовательностях, поскольку ограничения памяти ограничивают пакетную обработку примеров".
Перевод Seq2Seq с вниманием
Внимание — это метод, который использует выходные данные нашего кодировщика: вместо того, чтобы полностью его отбрасывать, мы используем его с нашим скрытым состоянием, чтобы обращать внимание на определенные слова во входном предложении для прогнозов в выходном предложении. В частности, мы вычисляем веса внимания, а затем добавляем к входным данным декодера линейную комбинацию выходных данных кодировщика с этими весами внимания.
Хорошая иллюстрация внимания исходит из этого поста в блоге Джея Аламмара (визуализация изначально из Tensor2Tensor Notebook):
Вторая вещь, которая может помочь, — это использование двунаправленной модели для кодировщика. Мы устанавливаем параметр bidrectional
равным True
для нашего кодировщика GRU и удваиваем количество входных данных для уровня линейного вывода кодировщика.
Кроме того, теперь нам нужно установить наше скрытое состояние:
hid = hid.view(2,self.n_layers, bs, self.n_hid).permute(1,2,0,3).contiguous()
hid = self.out_enc(self.hid_dp(hid).view(self.n_layers, bs, 2*self.n_hid))
Код для повторного запуска с самого начала
In [1]:from fastai.text import * In [2]:path = Config().data_path() In [3]:path = Config().data_path()/'giga-fren' In [4]: def seq2seq_collate(samples:BatchSamples, pad_idx:int=1, pad_first:bool=True, backwards:bool=False) -> Tuple[LongTensor, LongTensor]: "Function that collect samples and adds padding. Flips token order if needed" samples = to_data(samples) max_len_x,max_len_y = max([len(s[0]) for s in samples]),max([len(s[1]) for s in samples]) res_x = torch.zeros(len(samples), max_len_x).long() + pad_idx res_y = torch.zeros(len(samples), max_len_y).long() + pad_idx if backwards: pad_first = not pad_first for i,s in enumerate(samples): if pad_first: res_x[i,-len(s[0]):],res_y[i,-len(s[1]):] = LongTensor(s[0]),LongTensor(s[1]) else: res_x[i,:len(s[0])],res_y[i,:len(s[1])] = LongTensor(s[0]),LongTensor(s[1]) if backwards: res_x,res_y = res_x.flip(1),res_y.flip(1) return res_x,res_y class Seq2SeqDataBunch(TextDataBunch): "Create a `TextDataBunch` suitable for training an RNN classifier." @classmethod def create(cls, train_ds, valid_ds, test_ds=None, path:PathOrStr='.', bs:int=32, val_bs:int=None, pad_idx=1, dl_tfms=None, pad_first=False, device:torch.device=None, no_check:bool=False, backwards:bool=False, **dl_kwargs) -> DataBunch: "Function that transform the `datasets` in a `DataBunch` for classification. Passes `**dl_kwargs` on to `DataLoader()`" datasets = cls._init_ds(train_ds, valid_ds, test_ds) val_bs = ifnone(val_bs, bs) collate_fn = partial(seq2seq_collate, pad_idx=pad_idx, pad_first=pad_first, backwards=backwards) train_sampler = SortishSampler(datasets[0].x, key=lambda t: len(datasets[0][t][0].data), bs=bs//2) train_dl = DataLoader(datasets[0], batch_size=bs, sampler=train_sampler, drop_last=True, **dl_kwargs) dataloaders = [train_dl] for ds in datasets[1:]: lengths = [len(t) for t in ds.x.items] sampler = SortSampler(ds.x, key=lengths.__getitem__) dataloaders.append(DataLoader(ds, batch_size=val_bs, sampler=sampler, **dl_kwargs)) return cls(*dataloaders, path=path, device=device, collate_fn=collate_fn, no_check=no_check) class Seq2SeqTextList(TextList): _bunch = Seq2SeqDataBunch _label_cls = TextList data = load_data(path) model_path = Config().model_path() emb_enc = torch.load(model_path/'fr_emb.pth') emb_dec = torch.load(model_path/'en_emb.pth') def seq2seq_loss(out, targ, pad_idx=1): bs,targ_len = targ.size() _,out_len,vs = out.size() if targ_len>out_len: out = F.pad(out, (0,0,0,targ_len-out_len,0,0), value=pad_idx) if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx) return CrossEntropyFlat()(out, targ) def seq2seq_acc(out, targ, pad_idx=1): bs,targ_len = targ.size() _,out_len,vs = out.size() if targ_len>out_len: out = F.pad(out, (0,0,0,targ_len-out_len,0,0), value=pad_idx) if out_len>targ_len: targ = F.pad(targ, (0,out_len-targ_len,0,0), value=pad_idx) out = out.argmax(2) return (out==targ).float().mean() class NGram(): def __init__(self, ngram, max_n=5000): self.ngram,self.max_n = ngram,max_n def __eq__(self, other): if len(self.ngram) != len(other.ngram): return False return np.all(np.array(self.ngram) == np.array(other.ngram)) def __hash__(self): return int(sum([o * self.max_n**i for i,o in enumerate(self.ngram)])) def get_grams(x, n, max_n=5000): return x if n==1 else [NGram(x[i:i+n], max_n=max_n) for i in range(len(x)-n+1)] def get_correct_ngrams(pred, targ, n, max_n=5000): pred_grams,targ_grams = get_grams(pred, n, max_n=max_n),get_grams(targ, n, max_n=max_n) pred_cnt,targ_cnt = Counter(pred_grams),Counter(targ_grams) return sum([min(c, targ_cnt[g]) for g,c in pred_cnt.items()]),len(pred_grams) def get_predictions(learn, ds_type=DatasetType.Valid): learn.model.eval() inputs, targets, outputs = [],[],[] with torch.no_grad(): for xb,yb in progress_bar(learn.dl(ds_type)): out = learn.model(xb) for x,y,z in zip(xb,yb,out): inputs.append(learn.data.train_ds.x.reconstruct(x)) targets.append(learn.data.train_ds.y.reconstruct(y)) outputs.append(learn.data.train_ds.y.reconstruct(z.argmax(1))) return inputs, targets, outputs class CorpusBLEU(Callback): def __init__(self, vocab_sz): self.vocab_sz = vocab_sz self.name = 'bleu' def on_epoch_begin(self, **kwargs): self.pred_len,self.targ_len,self.corrects,self.counts = 0,0,[0]*4,[0]*4 def on_batch_end(self, last_output, last_target, **kwargs): last_output = last_output.argmax(dim=-1) for pred,targ in zip(last_output.cpu().numpy(),last_target.cpu().numpy()): self.pred_len += len(pred) self.targ_len += len(targ) for i in range(4): c,t = get_correct_ngrams(pred, targ, i+1, max_n=self.vocab_sz) self.corrects[i] += c self.counts[i] += t def on_epoch_end(self, last_metrics, **kwargs): precs = [c/t for c,t in zip(self.corrects,self.counts)] len_penalty = exp(1 - self.targ_len/self.pred_len) if self.pred_len < self.targ_len else 1 bleu = len_penalty * ((precs[0]*precs[1]*precs[2]*precs[3]) ** 0.25) return add_metrics(last_metrics, bleu) class TeacherForcing(LearnerCallback): def __init__(self, learn, end_epoch): super().__init__(learn) self.end_epoch = end_epoch def on_batch_begin(self, last_input, last_target, train, **kwargs): if train: return {'last_input': [last_input, last_target]} def on_epoch_begin(self, epoch, **kwargs): self.learn.model.pr_force = 1 - epoch/self.end_epoch
Реализация внимания
class Seq2SeqRNN_attn(nn.Module): def __init__(self, emb_enc, emb_dec, nh, out_sl, nl=2, bos_idx=0, pad_idx=1): super().__init__() self.nl,self.nh,self.out_sl,self.pr_force = nl,nh,out_sl,1 self.bos_idx,self.pad_idx = bos_idx,pad_idx self.emb_enc,self.emb_dec = emb_enc,emb_dec self.emb_sz_enc,self.emb_sz_dec = emb_enc.embedding_dim,emb_dec.embedding_dim self.voc_sz_dec = emb_dec.num_embeddings self.emb_enc_drop = nn.Dropout(0.15) self.gru_enc = nn.GRU(self.emb_sz_enc, nh, num_layers=nl, dropout=0.25, batch_first=True, bidirectional=True) self.out_enc = nn.Linear(2*nh, self.emb_sz_dec, bias=False) self.gru_dec = nn.GRU(self.emb_sz_dec + 2*nh, self.emb_sz_dec, num_layers=nl, dropout=0.1, batch_first=True) self.out_drop = nn.Dropout(0.35) self.out = nn.Linear(self.emb_sz_dec, self.voc_sz_dec) self.out.weight.data = self.emb_dec.weight.data self.enc_att = nn.Linear(2*nh, self.emb_sz_dec, bias=False) self.hid_att = nn.Linear(self.emb_sz_dec, self.emb_sz_dec) self.V = self.init_param(self.emb_sz_dec) def encoder(self, bs, inp): h = self.initHidden(bs) emb = self.emb_enc_drop(self.emb_enc(inp)) enc_out, hid = self.gru_enc(emb, 2*h) pre_hid = hid.view(2, self.nl, bs, self.nh).permute(1,2,0,3).contiguous() pre_hid = pre_hid.view(self.nl, bs, 2*self.nh) hid = self.out_enc(pre_hid) return hid,enc_out def decoder(self, dec_inp, hid, enc_att, enc_out): hid_att = self.hid_att(hid[-1]) # we have put enc_out and hid through linear layers u = torch.tanh(enc_att + hid_att[:,None]) # we want to learn the importance of each time step attn_wgts = F.softmax(u @ self.V, 1) # weighted average of enc_out (which is the output at every time step) ctx = (attn_wgts[...,None] * enc_out).sum(1) emb = self.emb_dec(dec_inp) # concatenate decoder embedding with context (we could have just # used the hidden state that came out of the decoder, if we weren't # using attention) outp, hid = self.gru_dec(torch.cat([emb, ctx], 1)[:,None], hid) outp = self.out(self.out_drop(outp[:,0])) return hid, outp def show(self, nm,v): if False: print(f"{nm}={v[nm].shape}") def forward(self, inp, targ=None): bs, sl = inp.size() hid,enc_out = self.encoder(bs, inp) # self.show("hid",vars()) dec_inp = inp.new_zeros(bs).long() + self.bos_idx enc_att = self.enc_att(enc_out) res = [] for i in range(self.out_sl): hid, outp = self.decoder(dec_inp, hid, enc_att, enc_out) res.append(outp) dec_inp = outp.max(1)[1] if (dec_inp==self.pad_idx).all(): break if (targ is not None) and (random.random()<self.pr_force): if i>=targ.shape[1]: continue dec_inp = targ[:,i] return torch.stack(res, dim=1) def initHidden(self, bs): return one_param(self).new_zeros(2*self.nl, bs, self.nh) def init_param(self, *sz): return nn.Parameter(torch.randn(sz)/math.sqrt(sz[0])) hid=torch.Size([2, 64, 300]) dec_inp=torch.Size([64]) enc_att=torch.Size([64, 30, 300]) hid_att=torch.Size([64, 300]) u=torch.Size([64, 30, 300]) attn_wgts=torch.Size([64, 30]) enc_out=torch.Size([64, 30, 512]) ctx=torch.Size([64, 512]) emb=torch.Size([64, 300]) model = Seq2SeqRNN_attn(emb_enc, emb_dec, 256, 30) learn = Learner(data, model, loss_func=seq2seq_loss, metrics=[seq2seq_acc, CorpusBLEU(len(data.y.vocab.itos))], callback_fns=partial(TeacherForcing, end_epoch=30)) learn.lr_find() learn.recorder.plot() LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.
learn.fit_one_cycle(15, 3e-3)
In [146]:inputs, targets, outputs = get_predictions(learn)
100.00% [151/151 00:34<00:00]
In [147]:inputs[700], targets[700], outputs[700] Out[147]:(Text xxbos qui a le pouvoir de modifier le règlement ... In [148]:inputs[701], targets[701], outputs[701] Out[148]:(Text xxbos ´ ` ou sont xxunk leurs grandes convictions... In [149]:inputs[4002], targets[4002], outputs[4002] Out[149]:(Text xxbos quelles ressources votre communauté possède...