import torch
import torchvision
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.utils.tensorboard as tensorboard
import tqdm

Данные MNIST

# Mini-batch size used for training and for sampling the generator at the end.
# 64 matches the printed batch shape torch.Size([64, 1, 28, 28]).
BATCH_SIZE = 64

# ToTensor turns each PIL image into a float tensor of shape (1, 28, 28) scaled to [0, 1].
data_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])
data = torchvision.datasets.MNIST(root="./data", train=True, transform=data_transform, download=True)
data_loader = torch.utils.data.DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)
# Number of mini-batches per epoch (used for the per-epoch progress bar).
num_batches = len(data_loader)

Визуализируйте цифры MNIST

# Take the first shuffled batch: x holds the images, y the digit labels.
x, y = next(iter(data_loader))

# Tile the whole batch into a single image, 8 digits per row.
grid = torchvision.utils.make_grid(x, nrow=8)

# make_grid returns channels-first (C, H, W); matplotlib expects (H, W, C).
plt.figure(figsize=(12, 12))
plt.imshow(grid.permute(1, 2, 0), cmap="gray")

print(*(t.shape for t in (x, y)))
torch.Size([64, 1, 28, 28]) torch.Size([64])


1. Модель дискриминатора

class Discriminator(torch.nn.Module):
    """MLP discriminator that scores flattened images as real (1) or fake (0).

    Args:
        in_features: size of each flattened input sample (784 = 28 * 28).
        n_out: number of output units; 1 gives a single real/fake probability.
        hidden_dims: widths of the hidden layers, in order. ``None`` (the
            default) selects ``[1024, 512, 256]``.
    """

    def __init__(self, in_features=784, n_out=1, hidden_dims=None):
        super().__init__()
        # None instead of a list literal avoids a shared mutable default.
        if hidden_dims is None:
            hidden_dims = [1024, 512, 256]
        modules = []
        # Hidden layers: each block maps the previous width to h_dim.
        for h_dim in hidden_dims:
            modules.append(self.nn_layer(in_features=in_features, out_features=h_dim))
            in_features = h_dim
        # Output layer. Sigmoid turns the logit into a probability, which
        # nn.BCELoss (used for training) requires as its input.
        modules.append(nn.Linear(in_features=in_features, out_features=n_out))
        modules.append(nn.Sigmoid())
        self.out = nn.Sequential(*modules)

    def nn_layer(self, in_features, out_features):
        """Return one hidden block: Linear followed by LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Linear(in_features=in_features, out_features=out_features),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Map a (batch, in_features) tensor to (batch, n_out) probabilities."""
        return self.out(x)

2. Модель генератора

class Generator(torch.nn.Module):
    """MLP generator that maps a noise vector to a flattened image.

    Args:
        in_features: dimension of the input noise vector.
        out_features: size of each flattened output image (784 = 28 * 28).
        hidden_dims: widths of the hidden layers, in order. ``None`` (the
            default) selects ``[256, 512, 1024]``.
    """

    def __init__(self, in_features=100, out_features=784, hidden_dims=None):
        super().__init__()
        # None instead of a list literal avoids a shared mutable default.
        if hidden_dims is None:
            hidden_dims = [256, 512, 1024]
        modules = []
        # Hidden layers: each block maps the previous width to h_dim.
        for h_dim in hidden_dims:
            modules.append(self.nn_layer(in_features=in_features, out_features=h_dim))
            in_features = h_dim
        # Output layer. Sigmoid keeps generated pixels in [0, 1].
        # NOTE(review): assumes real images are in [0, 1] (e.g. ToTensor without
        # Normalize); if the data is normalized to [-1, 1], use nn.Tanh() instead.
        modules.append(nn.Linear(in_features=in_features, out_features=out_features))
        modules.append(nn.Sigmoid())
        self.out = nn.Sequential(*modules)

    def nn_layer(self, in_features, out_features):
        """Return one hidden block: Linear followed by LeakyReLU(0.2)."""
        return nn.Sequential(
            nn.Linear(in_features=in_features, out_features=out_features),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Map a (batch, in_features) noise tensor to (batch, out_features) images."""
        return self.out(x)


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
writer = tensorboard.SummaryWriter()
lr = 0.0003
epochs = 100
noise_dim = 100
## Create models
gen = Generator().to(device)
disc = Discriminator().to(device)
criterion = nn.BCELoss().to(device)
d_optimizer = torch.optim.Adam(disc.parameters(), lr=lr)
g_optimizer = torch.optim.Adam(gen.parameters(), lr=lr)
steps = 0
epoch_progress = tqdm.tqdm(total=epochs, desc="Epoch", position=0)

for epoch in range(epochs):
    g_epoch_loss = []
    d_epoch_loss = []
    # Inner bar on its own row (position=1), removed when the epoch finishes.
    step_progress = tqdm.tqdm(total=num_batches, desc="Step", position=1, leave=False)
    for images, _ in data_loader:
        batch = images.size(0)
        # Flatten (batch, C, H, W) -> (batch, C*H*W) for the MLP discriminator.
        images = images.view(batch, -1).to(device)
        gt_real = torch.ones((batch, 1), device=device)
        gt_fake = torch.zeros((batch, 1), device=device)

        # -- TRAIN DISCRIMINATOR --
        # Fake data is generated from Gaussian noise.
        noise = torch.randn((batch, noise_dim), device=device)
        fake_images = gen(noise)
        d_loss_real = criterion(disc(images), gt_real)
        # detach(): the discriminator update must not backpropagate into the generator.
        d_loss_fake = criterion(disc(fake_images.detach()), gt_fake)
        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # -- TRAIN GENERATOR --
        # Fakes are labelled as real: the generator is rewarded for fooling the discriminator.
        g_loss = criterion(disc(fake_images), gt_real)
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        d_loss_value = d_loss.item()
        g_loss_value = g_loss.item()
        d_epoch_loss.append(d_loss_value)
        g_epoch_loss.append(g_loss_value)
        # Separate tags so the two curves are not written into a single series.
        writer.add_scalar("step_wise_loss/discriminator", d_loss_value, steps)
        writer.add_scalar("step_wise_loss/generator", g_loss_value, steps)
        if steps % 5000 == 0:
            # tqdm.write prints without corrupting the active progress bars.
            tqdm.tqdm.write(f'epoch {epoch} |  step  {steps} | d_loss {d_loss_value} | g_loss {g_loss_value}')
        steps += 1
        step_progress.update(1)
    step_progress.close()
    # Guard against an empty loader, which would otherwise divide by zero.
    if g_epoch_loss:
        g_loss_avg = sum(g_epoch_loss) / len(g_epoch_loss)
        d_loss_avg = sum(d_epoch_loss) / len(d_epoch_loss)
        writer.add_scalar("epoch_wise_loss/generator", g_loss_avg, epoch)
        writer.add_scalar("epoch_wise_loss/discriminator", d_loss_avg, epoch)
    epoch_progress.update(1)
epoch_progress.close()
# Make sure all logged scalars reach disk; the writer stays usable afterwards.
writer.flush()
epoch 0 |  step  0 | d_loss 1.39686918258667 | g_loss 0.7280533313751221

epoch 5 |  step  5000 | d_loss 1.1108969449996948 | g_loss 1.825584888458252

epoch 10 |  step  10000 | d_loss 0.5064196586608887 | g_loss 1.7891119718551636

epoch 15 |  step  15000 | d_loss 0.6235991716384888 | g_loss 3.147066116333008

epoch 21 |  step  20000 | d_loss 0.816655158996582 | g_loss 1.828172206878662

epoch 26 |  step  25000 | d_loss 1.167340636253357 | g_loss 2.020320415496826

epoch 31 |  step  30000 | d_loss 0.5045021772384644 | g_loss 2.1074342727661133

epoch 37 |  step  35000 | d_loss 0.5185412764549255 | g_loss 3.7087910175323486

epoch 42 |  step  40000 | d_loss 0.6276625394821167 | g_loss 2.220449686050415

epoch 47 |  step  45000 | d_loss 0.43490731716156006 | g_loss 1.9939409494400024

epoch 53 |  step  50000 | d_loss 0.516788125038147 | g_loss 2.9438328742980957

epoch 58 |  step  55000 | d_loss 0.677204966545105 | g_loss 2.088684558868408

epoch 63 |  step  60000 | d_loss 0.6085101962089539 | g_loss 3.6483681201934814

epoch 69 |  step  65000 | d_loss 0.3501152992248535 | g_loss 3.606964588165283

epoch 74 |  step  70000 | d_loss 0.559038519859314 | g_loss 3.5148210525512695

epoch 79 |  step  75000 | d_loss 0.4314274787902832 | g_loss 4.849875450134277

epoch 85 |  step  80000 | d_loss 0.46567481756210327 | g_loss 3.6981945037841797

epoch 90 |  step  85000 | d_loss 0.2274802029132843 | g_loss 3.7563960552215576

epoch 95 |  step  90000 | d_loss 0.1932884156703949 | g_loss 3.552668809890747

Протестируйте сеть-генератор

# Sample a batch of digits from the trained generator.
# Noise must be on the generator's device; no autograd graph is needed for inference.
with torch.no_grad():
    noise = torch.randn((BATCH_SIZE, noise_dim), device=device)
    # Un-flatten to (batch, 1, 28, 28) and move to CPU for plotting.
    # Named `samples` so the MNIST dataset variable `data` is not overwritten.
    samples = gen(noise).view(-1, 1, 28, 28).cpu()
# samples.shape -> torch.Size([BATCH_SIZE, 1, 28, 28])
grid = torchvision.utils.make_grid(samples, nrow=8)
plt.figure(figsize=(24, 24))
# make_grid returns channels-first (C, H, W); matplotlib expects (H, W, C).
plt.imshow(grid.permute(1, 2, 0), cmap="gray")
Качество сгенерированных изображений пока недостаточное. Возможные исправления:

  • Обучать дольше (больше эпох)
  • Добавить больше слоев
  • Нормализовать значения пикселей в диапазон [-1, 1]
  • Подобрать скорость обучения и оптимизатор
  • Изменить функции активации генератора: заменить LeakyReLU на Tanh