import torch
import torchvision
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.utils.tensorboard as tensorboard
import tqdm
Данные MNIST
# MNIST training data.
# ToTensor() yields pixels in [0, 1]; Normalize(mean=0.5, std=0.5) maps them to [-1, 1],
# the same range as the generator's Tanh output. Without it, any generated image containing
# negative pixels is trivially distinguishable from a real digit.
BATCH_SIZE = 64
NUM_WORKERS = 4
data_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
data = torchvision.datasets.MNIST(
    root="./data", train=True, transform=data_transform, download=True
)
# Worker count comes from the NUM_WORKERS constant above.
data_loader = torch.utils.data.DataLoader(
    dataset=data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)
num_batches = len(data_loader)  # batches per epoch (ceil(60000 / 64) = 938)
Визуализируйте цифры MNIST
# Show one batch of training digits as an 8-column grid.
x, y = next(iter(data_loader))
# normalize=True min-max scales the batch to [0, 1] before tiling, so imshow shows it
# correctly whether the loader yields pixels in [0, 1] or in [-1, 1].
grid = torchvision.utils.make_grid(x, nrow=8, normalize=True)
plt.figure(figsize=(12, 12))
# make_grid returns (C, H, W); imshow expects (H, W, C).
plt.imshow(grid.permute(1, 2, 0), cmap="gray")
plt.show()
![](https://i2.wp.com/miro.medium.com/1*iv4Up5hJB7CWI6KPtQW68w.png)
# Report the batch shapes: images (N, C, H, W) and labels (N,).
print(f"{x.shape} {y.shape}")
torch.Size([64, 1, 28, 28]) torch.Size([64])
Модель
1. Модель дискриминатора
class Discriminator(torch.nn.Module):
    """MLP discriminator mapping a flattened image to a probability in (0, 1).

    Args:
        in_features: size of the flattened input (784 = 28 * 28 for MNIST).
        n_out: number of outputs of the final layer.
        hidden_dims: widths of the hidden layers, in order; ``None`` (the default)
            selects ``[1024, 512, 256]``. A ``None`` default avoids sharing one
            mutable list between all instances.
    """

    def __init__(self, in_features=784, n_out=1, hidden_dims=None):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [1024, 512, 256]
        modules = []
        # hidden layers: each one feeds the next, so track the running width
        for h_dim in hidden_dims:
            modules.append(self.nn_layer(in_features=in_features, out_features=h_dim))
            in_features = h_dim
        # output layer; Sigmoid squashes the logits into (0, 1) probabilities
        modules.append(nn.Sequential(
            nn.Linear(in_features=in_features, out_features=n_out),
            nn.Sigmoid(),
        ))
        self.out = nn.Sequential(*modules)

    def nn_layer(self, in_features, out_features):
        """Build one hidden block: Linear -> LeakyReLU(0.2) -> Dropout(0.2)."""
        return nn.Sequential(
            nn.Linear(in_features=in_features, out_features=out_features),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
        )

    def forward(self, x):
        """Return probabilities of shape (N, n_out) for a batch ``x``.

        Inputs with more than two dimensions, e.g. (N, 1, 28, 28) images, are
        flattened to (N, C*H*W) first; 1-D and 2-D inputs pass through unchanged.
        """
        if x.dim() > 2:
            x = torch.flatten(x, start_dim=1)
        return self.out(x)
2. Модель генератора
class Generator(torch.nn.Module):
    """MLP generator mapping a noise vector to a flattened image with values in [-1, 1].

    Args:
        in_features: length of the input noise vector (the noise dimension).
        out_features: size of the flattened output image (784 = 28 * 28 for MNIST).
        hidden_dims: widths of the hidden layers, in order; ``None`` (the default)
            selects ``[256, 512, 1024]``. A ``None`` default avoids sharing one
            mutable list between all instances.
    """

    def __init__(self, in_features=100, out_features=784, hidden_dims=None):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [256, 512, 1024]
        modules = []
        # hidden layers: each one feeds the next, so track the running width
        for h_dim in hidden_dims:
            modules.append(self.nn_layer(in_features=in_features, out_features=h_dim))
            in_features = h_dim
        # output layer; Tanh bounds every pixel to [-1, 1]
        modules.append(nn.Sequential(
            nn.Linear(in_features=in_features, out_features=out_features),
            nn.Tanh(),
        ))
        self.out = nn.Sequential(*modules)

    def nn_layer(self, in_features, out_features):
        """Build one hidden block: Linear -> LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Linear(in_features=in_features, out_features=out_features),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Map noise of shape (N, in_features) to images of shape (N, out_features)."""
        return self.out(x)
Обучение
# --- Training configuration ---
# First CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# TensorBoard logger; constructed without arguments, so it uses its default log directory.
writer = tensorboard.SummaryWriter()

# Hyperparameters
lr = 0.0003      # Adam learning rate, shared by both optimizers
epochs = 100     # full passes over the training set
noise_dim = 100  # generator input length; matches Generator's default in_features

# Models, loss and optimizers. The generator is built first, then the discriminator.
gen, disc = Generator().to(device), Discriminator().to(device)
criterion = nn.BCELoss().to(device)  # binary cross-entropy on the discriminator's sigmoid output
d_optimizer, g_optimizer = (torch.optim.Adam(net.parameters(), lr=lr) for net in (disc, gen))
Step: 100%|█████████▉| 936/938 [00:49<00:00, 22.08it/s]
# GAN training: one discriminator update followed by one generator update per batch.
# Make sure both nets are in training mode (Dropout active in the discriminator), even if
# eval() was called on them earlier.
gen.train()
disc.train()
steps = 0  # global step counter across epochs; x-axis of the step-wise TensorBoard scalars
epoch_progress = tqdm.tqdm(total=epochs, desc="Epoch", position=0)
for epoch in range(epochs):
    g_epoch_loss = []
    d_epoch_loss = []
    # Own line (position=1) and removed at epoch end (leave=False), so the step bar does
    # not overwrite the epoch bar.
    step_progress = tqdm.tqdm(total=num_batches, desc="Step", position=1, leave=False)
    for images, _ in data_loader:
        batch, channel, h, w = images.shape
        # flatten (N, C, H, W) images into (N, C*H*W) vectors for the MLP discriminator
        images = images.view(-1, channel * h * w).to(device)
        gt_real = torch.ones((batch, 1), device=device)
        gt_fake = torch.zeros((batch, 1), device=device)

        # -- TRAIN DISCRIMINATOR --
        d_optimizer.zero_grad()
        # real data
        d_loss_real = criterion(disc(images), gt_real)
        d_loss_real.backward()
        # fake data generated from Gaussian noise; detach() keeps this backward pass from
        # computing gradients for the generator, which this step does not update
        noise = torch.randn((batch, noise_dim), device=device)
        fake = gen(noise)
        d_loss_fake = criterion(disc(fake.detach()), gt_fake)
        d_loss_fake.backward()
        d_optimizer.step()
        d_loss_value = d_loss_real.item() + d_loss_fake.item()

        # -- TRAIN GENERATOR --
        # Reuse `fake`: the generator has not been stepped yet and has no random layers,
        # so gen(noise) would return the same tensor. Labelling fakes as real gives the
        # non-saturating generator loss -log D(G(z)).
        g_optimizer.zero_grad()
        g_loss = criterion(disc(fake), gt_real)
        g_loss.backward()
        g_optimizer.step()
        g_loss_value = g_loss.item()

        # Separate tags: logging both losses under one tag merges them into a single curve.
        writer.add_scalar("step_wise_loss/discriminator", d_loss_value, steps)
        writer.add_scalar("step_wise_loss/generator", g_loss_value, steps)
        g_epoch_loss.append(g_loss_value)
        d_epoch_loss.append(d_loss_value)
        if steps % 5000 == 0:
            # tqdm.write prints without corrupting the active progress bars
            tqdm.tqdm.write(
                f'epoch {epoch} | step {steps} | d_loss {d_loss_value} | g_loss {g_loss_value}'
            )
        step_progress.update(1)
        steps += 1
    step_progress.close()
    g_loss_avg = sum(g_epoch_loss) / len(g_epoch_loss)
    d_loss_avg = sum(d_epoch_loss) / len(d_epoch_loss)
    writer.add_scalar("epoch_wise_loss/generator", g_loss_avg, epoch)
    writer.add_scalar("epoch_wise_loss/discriminator", d_loss_avg, epoch)
    epoch_progress.update(1)
epoch_progress.close()
writer.flush()  # push pending events to disk; the writer stays open for later use
Step: 1%| | 5/938 [00:00<01:00, 15.44it/s]
epoch 0 | step 0 | d_loss 1.39686918258667 | g_loss 0.7280533313751221
Step: 34%|███▎ | 316/938 [00:12<00:27, 22.55it/s]
epoch 5 | step 5000 | d_loss 1.1108969449996948 | g_loss 1.825584888458252
Step: 67%|██████▋ | 625/938 [00:24<00:12, 24.82it/s]
epoch 10 | step 10000 | d_loss 0.5064196586608887 | g_loss 1.7891119718551636
Step: 100%|█████████▉| 934/938 [00:37<00:00, 25.54it/s]
epoch 15 | step 15000 | d_loss 0.6235991716384888 | g_loss 3.147066116333008
Step: 33%|███▎ | 307/938 [00:12<00:24, 25.38it/s]
epoch 21 | step 20000 | d_loss 0.816655158996582 | g_loss 1.828172206878662
Step: 66%|██████▌ | 615/938 [00:27<00:14, 22.83it/s]
epoch 26 | step 25000 | d_loss 1.167340636253357 | g_loss 2.020320415496826
Step: 99%|█████████▊| 925/938 [00:44<00:00, 14.78it/s]
epoch 31 | step 30000 | d_loss 0.5045021772384644 | g_loss 2.1074342727661133
Step: 32%|███▏ | 297/938 [00:13<00:30, 20.99it/s]
epoch 37 | step 35000 | d_loss 0.5185412764549255 | g_loss 3.7087910175323486
Step: 65%|██████▍ | 607/938 [00:33<00:20, 16.46it/s]
epoch 42 | step 40000 | d_loss 0.6276625394821167 | g_loss 2.220449686050415
Step: 98%|█████████▊| 917/938 [00:39<00:01, 19.59it/s]
epoch 47 | step 45000 | d_loss 0.43490731716156006 | g_loss 1.9939409494400024
Step: 31%|███ | 291/938 [00:10<00:22, 28.20it/s]
epoch 53 | step 50000 | d_loss 0.516788125038147 | g_loss 2.9438328742980957
Step: 64%|██████▍ | 601/938 [00:34<00:13, 24.51it/s]
epoch 58 | step 55000 | d_loss 0.677204966545105 | g_loss 2.088684558868408
Step: 97%|█████████▋| 910/938 [00:34<00:01, 26.59it/s]
epoch 63 | step 60000 | d_loss 0.6085101962089539 | g_loss 3.6483681201934814
Step: 30%|███ | 282/938 [00:12<00:32, 20.38it/s]
epoch 69 | step 65000 | d_loss 0.3501152992248535 | g_loss 3.606964588165283
Step: 63%|██████▎ | 593/938 [00:27<00:14, 24.26it/s]
epoch 74 | step 70000 | d_loss 0.559038519859314 | g_loss 3.5148210525512695
Step: 96%|█████████▋| 903/938 [00:49<00:01, 23.07it/s]
epoch 79 | step 75000 | d_loss 0.4314274787902832 | g_loss 4.849875450134277
Step: 29%|██▉ | 273/938 [00:15<00:33, 19.90it/s]t]
epoch 85 | step 80000 | d_loss 0.46567481756210327 | g_loss 3.6981945037841797
Step: 62%|██████▏ | 584/938 [00:32<00:19, 17.83it/s]t]
epoch 90 | step 85000 | d_loss 0.2274802029132843 | g_loss 3.7563960552215576
Step: 95%|█████████▌| 893/938 [00:38<00:02, 15.92it/s]t]
epoch 95 | step 90000 | d_loss 0.1932884156703949 | g_loss 3.552668809890747
Step: 100%|██████████| 938/938 [00:53<00:00, 25.95it/s]it]
Протестируйте сеть-генератор
# Sample a batch of images from the trained generator.
gen.eval()
with torch.no_grad():  # inference only: no autograd graph is needed
    # Create the noise on the model's device; CPU noise fed to a CUDA model raises a
    # device-mismatch error.
    noise = torch.randn((BATCH_SIZE, noise_dim), device=device)
    # Reshape flat 784-vectors into (N, 1, 28, 28) images.
    # NOTE: this rebinds `data`, which earlier referred to the MNIST training dataset.
    data = gen(noise).view(-1, 1, 28, 28)
print(data.size())
torch.Size([64, 1, 28, 28])
# Tile the generated samples into an 8-column grid.
# Generator outputs come from Tanh and lie in [-1, 1]; map them to [0, 1] so imshow does not
# clip them ("Clipping input data to the valid range for imshow").
samples = (data.detach().cpu() + 1) / 2
grid = torchvision.utils.make_grid(samples, nrow=8)
plt.figure(figsize=(24, 24))
# make_grid returns (C, H, W); imshow expects (H, W, C).
plt.imshow(grid.permute(1, 2, 0), cmap="gray")
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<matplotlib.image.AxesImage at 0x7f988b505290>
![](https://i2.wp.com/miro.medium.com/1*G_OG5DXZdlZoid7_e4pzuQ.png)
Качество сгенерированных изображений пока недостаточное. Возможные улучшения:
- Обучайте дольше (больше эпох).
- Добавьте больше слоёв.
- Нормализуйте значения пикселей в диапазон [-1, 1], чтобы реальные изображения имели тот же диапазон, что и выход Tanh генератора.
- Подберите скорость обучения и оптимизатор.
- Попробуйте изменить функции активации генератора, например замените LeakyReLU на Tanh.