Потеря кроссэнтропии Pytorch с 3D-входом

У меня есть сеть, которая выводит трехмерный тензор размера (batch_size, max_len, num_classes). Моя гордая правда в форме (batch_size, max_len). Если я выполню однократное кодирование этикеток, они будут иметь форму (batch_size, max_len, num_classes), т.е. значения в max_len являются целыми числами в диапазоне [0, num_classes]. Поскольку исходный код слишком длинный, я написал более простую версию, воспроизводящую исходную ошибку.

criterion = nn.CrossEntropyLoss()
batch_size = 32
max_len = 350
num_classes = 1000
pred = torch.randn([batch_size, max_len, num_classes])
label = torch.randint(0, num_classes,[batch_size, max_len])
pred = nn.Softmax(dim = 2)(pred)
criterion(pred, label)

форма пред и метка соответственно torch.Size([32, 350, 1000]) и torch.Size([32, 350])

Произошла ошибка

ValueError: ожидаемый размер цели (32, 1000), есть torch.Size ([32, 350, 1000])

Если я горячо кодирую метки для вычисления потерь

x = nn.functional.one_hot(label)
criterion(pred, x)

это вызовет следующую ошибку

ValueError: ожидаемый размер цели (32, 1000), есть torch.Size ([32, 350, 1000])


person Hari Krishnan    schedule 29.08.2020    source источник


Ответы (1)


Из документации Pytorch CrossEntropyLoss ожидает форму входных данных быть (N, C, ...), поэтому второе измерение - это всегда количество классов. Ваш код должен работать, если вы измените форму preds на размер (batch_size, num_classes, max_len).

person Kevin    schedule 29.08.2020
comment
Даже если вы быстро закодируете метки, он выдаст ошибку при передаче в CrossEntropyLoss - person Hari Krishnan; 29.08.2020
comment
Извините, я думаю, что нашел проблему и обновил свой ответ. - person Kevin; 29.08.2020