Как использовать предварительно обученный вектор встраивания слов BERT для точной настройки (инициализации) других сетей?

Когда я занимался классификацией с помощью textcnn, у меня был опыт тонкой настройки textcnn с использованием предварительно обученного встраивания слов с помощью Word2Vec и fasttext. И я использую этот процесс:

  1. Создайте слой встраивания в textcnn
  2. Загрузите матрицу встраивания слов, используемых на этот раз Word2Vec или fasttext
  3. Поскольку значение вектора слоя внедрения будет изменяться во время обучения, сеть находится в стадии тонкой настройки.

Недавно я тоже хочу попробовать BERT для этого. Я подумал: «Поскольку должно быть несколько различий в использовании предварительно обученного встраивания BERT для начального уровня встраивания и тонкой настройки других сетей, это должно быть легко!» Но на самом деле вчера я пытался весь день и до сих пор не могу этого сделать.
Я обнаружил, что, поскольку встраивание BERT является контекстным встраиванием, особенно при извлечении вложений слов, вектор каждого слова из каждого предложения будет варьироваться, поэтому кажется, что нет никакого способа использовать это вложение для инициализации уровня встраивания другой сети, как обычно ...

Наконец, я придумал один метод «точной настройки» в виде следующих шагов:

  1. Во-первых, не определяйте слой встраивания в textcnn.
  2. Вместо использования уровня внедрения в части обучения сети я сначала передаю токены последовательности в предварительно обученную модель BERT и получаю вложения слов для каждого предложения.
  3. Поместите вложение слова BERT из 2. в textcnn и обучите сеть textcnn.

Используя этот метод, я, наконец, смог тренироваться, но, если серьезно подумать, я не думаю, что вообще занимаюсь тонкой настройкой ...
Потому что, как вы можете видеть, каждый раз, когда я начинаю новый цикл тренировки, вложение слов, сгенерированное с помощью BERT, всегда является одним и тем же вектором, поэтому просто ввод этих неизмененных векторов в textcnn не позволит вообще настроить textcnn, верно?

ОБНОВЛЕНИЕ: Я придумал новый метод использования вложений BERT и «обучения» BERT и textcnn вместе.
Некоторая часть моего кода:

    BERTmodel = AutoModel.from_pretrained('bert- 
                base-uncased',output_hidden_states=True).to(device)
    TextCNNmodel = TextCNN(EMBD_DIM, CLASS_NUM, KERNEL_NUM, 
                   KERNEL_SIZES).to(device)
    optimizer = torch.optim.Adam(TextCNNmodel.parameters(), lr=LR)
    loss_func = nn.CrossEntropyLoss()
  for epoch in range(EPOCH):
    TextCNNmodel.train()
    BERTmodel.train()
    for step, (token_batch, seg_batch, y_batch) in enumerate(train_loader):
        token_batch = token_batch.to(device)
        y_batch = y_batch.to(device)

        BERToutputs = BERTmodel(token_batch)
        # I want to use the second-to-last hidden layer as the embedding, so
        x_batch = BERToutputs[2][-2]

        output = TextCNNmodel(x_batch)
        output = output.squeeze()
        loss = loss_func(output, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Я думаю, что при включении BERTmodel.train () и удалении torch.no_grad () при встраивании градиент потерь также может быть обратным к BERTmodel. Процесс обучения TextCNNmodel также прошел гладко.
Чтобы использовать эту модель позже, я сохранил параметры как TextCNNmodel, так и BERTmodel.
Затем, чтобы поэкспериментировать, действительно ли BERTmodel обучается и изменяется, в другой программе я загружаю BERTModel и ввожу предложение, чтобы проверить это действительно ли обучалась модель BERTM.
Однако я обнаружил, что результат (встраивание) исходной модели 'bert-base-uncased' и моей 'BERTmodel' совпадают, что разочаровывает ...
Я действительно не понимаю, почему BERTmodel часть не менял ...


person Peter Nova    schedule 20.01.2021    source источник
comment
Вы можете показать код? Я предполагаю, что вы передали оптимизатору только параметры нижестоящей модели и параметры BERT.   -  person Jindřich    schedule 20.01.2021
comment
Спасибо за ваш комментарий! @ Jindřich Я загружаю часть своего кода и мыслей. Пожалуйста, обратитесь к нему!   -  person Peter Nova    schedule 20.01.2021
comment
Ой, похоже, я нашёл брешь в своём коде ... Может я не делал оптимизатор для BERTmodel, поэтому BERT не менял? Верно?..   -  person Peter Nova    schedule 20.01.2021


Ответы (1)


Здесь я хотел бы поблагодарить @ Jindřich, спасибо за важный совет!
Я думаю, что почти у цели, когда использую обновленный код версии, но я забыл установить оптимизатор для BERTmodel.
После того, как я установил оптимизатор и снова провел процесс обучения, на этот раз, когда я загружаю свою BERTmodel, я обнаружил, что выходные данные (встраивание) исходной модели 'bert-base-uncased' и моей 'BERTmodel', наконец, отличаются, что означает, что эта BERTmodel изменена и должна быть настроена.
Вот мои окончательные коды, надеюсь, что это может Вам тоже поможет.

    BERTmodel = AutoModel.from_pretrained('bert- 
                base-uncased',output_hidden_states=True).to(device)
    TextCNNmodel = TextCNN(EMBD_DIM, CLASS_NUM, KERNEL_NUM, 
                   KERNEL_SIZES).to(device)
    optimizer = torch.optim.Adam(TextCNNmodel.parameters(), lr=LR)
    optimizer_bert = torch.optim.Adamw(BERTmodel.parameters(), lr=2e-5, weight_decay=1e-2)
    loss_func = nn.CrossEntropyLoss()
  for epoch in range(EPOCH):
    TextCNNmodel.train()
    BERTmodel.train()
    for step, (token_batch, seg_batch, y_batch) in enumerate(train_loader):
        token_batch = token_batch.to(device)
        y_batch = y_batch.to(device)

        BERToutputs = BERTmodel(token_batch)
        # I want to use the second-to-last hidden layer as the embedding, so
        x_batch = BERToutputs[2][-2]

        output = TextCNNmodel(x_batch)
        output = output.squeeze()
        loss = loss_func(output, y_batch)

        optimizer.zero_grad()
        optimizer_bert.zero_grad()
        loss.backward()
        optimizer.step()
        optimizer_bert.step()

Я продолжу свои эксперименты, чтобы увидеть, действительно ли моя модель BERT настраивается.

person Peter Nova    schedule 20.01.2021