У меня проблемы с загрузкой моей модели в Google Colab. вот код:
Я прикрепил код ниже
Я попытался изменить имя statedict, и это в основном не помогает, я пытаюсь сохранить свою модель для последующего использования, но это становится чрезвычайно сложно, поскольку я не могу правильно сохранить и загрузить ее. Пожалуйста, помогите мне с проблемой. После раздела кода вы также найдете ошибку, которую я прикрепил ниже.
вот код
from zipfile import ZipFile
file_name = 'data.zip'
with ZipFile(file_name, 'r') as zip:
zip.extractall()
from zipfile import ZipFile
file_name = 'results.zip'
with ZipFile(file_name, 'r') as zip:
zip.extractall()
!pip install tensorflow-gpu
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
batchSize = 64
imageSize = 64
transform = transforms.Compose([transforms.Resize(imageSize), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])
dataset = dset.CIFAR10(root = './data', download = True, transform = transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size = batchSize, shuffle = True, num_workers = 2)
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
class G(nn.Module):
def __init__(self):
super(G, self).__init__()
self.main = nn.Sequential(
nn.ConvTranspose2d(100, 512, 4, 1, 0, bias = False),
nn.BatchNorm2d(512),
nn.ReLU(True),
nn.ConvTranspose2d(512, 256, 4, 2, 1, bias = False),
nn.BatchNorm2d(256),
nn.ReLU(True),
nn.ConvTranspose2d(256, 128, 4, 2, 1, bias = False),
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias = False),
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.ConvTranspose2d(64, 3, 4, 2, 1, bias = False),
nn.Tanh()
)
def forward(self, input):
output = self.main(input)
return output
netG = G()
netG.load_state_dict(torch.load('generator.pth'))
netG.eval()
#netG.apply(weights_init)
class D(nn.Module):
def __init__(self):
super(D, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(3, 64, 4, 2, 1, bias = False),
nn.LeakyReLU(0.2, inplace = True),
nn.Conv2d(64, 128, 4, 2, 1, bias = False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace = True),
nn.Conv2d(128, 256, 4, 2, 1, bias = False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2, inplace = True),
nn.Conv2d(256, 512, 4, 2, 1, bias = False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2, inplace = True),
nn.Conv2d(512, 1, 4, 1, 0, bias = False),
nn.Sigmoid()
)
def forward(self, input):
output = self.main(input)
return output.view(-1)
netD = D()
netD.load_state_dict(torch.load('discriminator.pth'))
netD.eval()
#netD.apply(weights_init)
criterion = nn.BCELoss()
checkpoint = torch.load('discriminator.pth')
optimizerD = optim.Adam(netD.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizerD.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
errD = checkpoint['loss']
checkpoint1 = torch.load('genrator.pth')
optimizerG = optim.Adam(netG.parameters(), lr = 0.0002, betas = (0.5, 0.999))
optimizerG.load_state_dict(checkpoint1['optimizer_state_dict'])
errG = checkpoint1['loss']
k = epoch
for j in range(k, 10):
for i, data in enumerate(dataloader, 0):
netD.zero_grad()
real, _ = data
input = Variable(real)
target = Variable(torch.ones(input.size()[0]))
output = netD(input)
errD_real = criterion(output, target)
noise = Variable(torch.randn(input.size()[0], 100, 1, 1))
fake = netG(noise)
target = Variable(torch.zeros(input.size()[0]))
output = netD(fake.detach())
errD_fake = criterion(output, target)
errD = errD_real + errD_fake
errD.backward()
optimizerD.step()
netG.zero_grad()
target = Variable(torch.ones(input.size()[0]))
output = netD(fake)
errG = criterion(output, target)
errG.backward()
optimizerG.step()
print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f' % (epoch+1, 10, i+1, len(dataloader), errD.data, errG.data))
if i % 100 == 0:
vutils.save_image(real, '%s/real_samples.png' % "./results", normalize = True)
fake = netG(noise)
vutils.save_image(fake.data, '%s/fake_samples_epoch_%03d.png' % ("./results", epoch+1), normalize = True)
torch.save({
'epoch': epoch,
'model_state_dict': netD.state_dict(),
'optimizer_state_dict': optimizerD.state_dict(),
'loss': errD
}, 'discriminator.pth')
torch.save({
'epoch': epoch,
'model_state_dict': netG.state_dict(),
'optimizer_state_dict': optimizerG.state_dict(),
'loss': errG
}, 'generator.pth')
вот ошибка
RuntimeError Traceback (most recent call last)
<ipython-input-23-3e55546152c7> in <module>()
26 # Creating the generator
27 netG = G()
---> 28 netG.load_state_dict(torch.load('generator.pth'))
29 netG.eval()
30 #netG.apply(weights_init)
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
767 if len(error_msgs) > 0:
768 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 769 self.__class__.__name__, "\n\t".join(error_msgs)))
770
771 def _named_members(self, get_members_fn, prefix='', recurse=True):
RuntimeError: Error(s) in loading state_dict for G:
Missing key(s) in state_dict: "main.0.weight", "main.1.weight", "main.1.bias", "main.1.running_mean", "main.1.running_var", "main.3.weight", "main.4.weight", "main.4.bias", "main.4.running_mean", "main.4.running_var", "main.6.weight", "main.7.weight", "main.7.bias", "main.7.running_mean", "main.7.running_var", "main.9.weight", "main.10.weight", "main.10.bias", "main.10.running_mean", "main.10.running_var", "main.12.weight".
Unexpected key(s) in state_dict: "epoch", "model_state_dict", "optimizer_state_dict", "loss".