RuntimeError: Ожидаемый объект скалярного типа Long, но получил скалярный тип Float для аргумента № 2 'mat2', как это исправить?


import torch.nn as nn 
import torch 
import torch.optim as optim
import itertools

class net1(nn.Module):
    def __init__(self):
        super(net1,self).__init__()

        self.pipe = nn.Sequential(
            nn.Linear(10,10),
            nn.ReLU()
        )

    def forward(self,x):
        return self.pipe(x.long())

class net2(nn.Module):
    def __init__(self):
        super(net2,self).__init__()

        self.pipe = nn.Sequential(
            nn.Linear(10,20),
            nn.ReLU(),
            nn.Linear(20,10)
        )

    def forward(self,x):
        return self.pipe(x.long())



netFIRST = net1()
netSECOND = net2()

learning_rate = 0.001

opt = optim.Adam(itertools.chain(netFIRST.parameters(),netSECOND.parameters()), lr=learning_rate)

epochs = 1000

x = torch.tensor([1,2,3,4,5,6,7,8,9,10],dtype=torch.long)
y = torch.tensor([10,9,8,7,6,5,4,3,2,1],dtype=torch.long)


for epoch in range(epochs):
    opt.zero_grad()

    prediction = netSECOND(netFIRST(x))
    loss = (y.long() - prediction)**2
    loss.backward()

    print(loss)
    print(prediction)
    opt.step()

ошибка:

строка 49, в прогнозе = netSECOND (netFIRST (x))

строка 1371, линейная; output = input.matmul (weight.t ())

RuntimeError: ожидаемый объект скалярного типа Long, но получил скалярный тип Float для аргумента # 2 'mat2'

Я действительно не понимаю, что делаю не так. Я всячески старался перевернуть все на Long. Я действительно не понимаю, как печатать на pytorch. В прошлый раз я пробовал что-то с одним слоем, и это заставило меня использовать тип int. Может ли кто-нибудь объяснить, как типизация устанавливается в pytorch и как предотвращать и исправлять такие ошибки? Я имею в виду заранее огромное спасибо, эта проблема действительно беспокоит меня, и я не могу ее исправить, что бы я ни пытался.


person hal9000    schedule 29.09.2019    source источник


Ответы (1)


Веса - плавающие, входные - длинные. Это не разрешено. На самом деле, я не думаю, что torch поддерживает что-либо еще, кроме Float в нейронных сетях.

Если вы удалите все вызовы на long и определите свой ввод как числа с плавающей запятой, он будет работать (да, я пробовал).

(Затем вы получите еще одну несвязанную ошибку: вам нужно суммировать свой убыток)

person Martino    schedule 29.09.2019
comment
Неважно, разобрался, как это работает. Мне нужно было удалить все dtype long, а затем изменить dtype на torch.float и суммирование в проигрыше, конечно. - person hal9000; 29.09.2019
comment
всегда пожалуйста. Отметьте ответ как принятый (щелкните галочку), если он был удовлетворительным. - person Martino; 30.09.2019