Как обучить модель Huggingface TFT5ForConditionalGeneration?

Мой код выглядит следующим образом:

batch_size=8
sequence_length=25
vocab_size=100
import tensorflow as tf
from transformers import T5Config, TFT5ForConditionalGeneration
configT5 = T5Config(
    vocab_size=vocab_size,
    d_ff =512, 
)  
model = TFT5ForConditionalGeneration(configT5)

model.compile(
    optimizer = tf.keras.optimizers.Adam(),
    loss = tf.keras.losses.SparseCategoricalCrossentropy()
)
input = tf.random.uniform([batch_size,sequence_length],0,vocab_size,dtype=tf.int32)
labels = tf.random.uniform([batch_size,sequence_length],0,vocab_size,dtype=tf.int32)
input = {'inputs': input, 'decoder_input_ids': input}
model.fit(input, labels)

Выдает ошибку:

логиты и метки должны иметь одно и то же первое измерение, форму логитов [1600,64] и форму меток [200] [[node sparse_categorical_crossentropy_3 / SparseSoftmaxCrossEntropyWithLogits / SparseSoftmaxCrossEntropyWithLogits (определено в C: \ Google \ Users \ FA.Col.Cab.PROJ \ Блокноты \ PoetryTransformer \ эксперименты \ TFT5.py: 30)]] [Op: __ inference_train_function_25173] Стек вызовов функций: train_function

Я не понимаю - почему модель возвращает тензор [1600, 64]. Согласно https://huggingface.co/transformers/model_doc/t5.html#tft5forconditionalgeneration модель возвращает [batch_size, sequence_len, vocab_size].


person Andrey    schedule 26.08.2020    source источник


Ответы (1)


Вызвать fit() невозможно из-за нестандартной сигнатуры call() метода TFT5ForConditionalGeneration. Мне нужно переопределить train_step(), чтобы TFT5 заработал. См. Здесь - https://colab.research.google.com/github/snapthat/TF-T5-text-to-text/blob/master/snapthatT5/notebooks/TF-T5-Datasets%20Training.ipynb#scrollTo=cgxRVn34Z0wb

person Andrey    schedule 26.08.2020