Как загрузить изображение в генератор в GAN Pytorch

Итак, я тренирую модель DCGAN в pytorch на наборе данных celeba (люди). А вот и архитектура генератора:

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

Итак, после тренировки я хочу проверить, что выдает генератор, если я передам закрытое изображение, например: введите здесь описание изображения (размер: 64X64)

Но, как вы могли догадаться, изображение имеет 3 канала, и мой генератор принимает скрытый вектор из 100 каналов при запуске, так как же правильно передать это изображение генератору и проверить вывод. (Я ожидаю, что генератор попытается сгенерировать только скрытую часть изображения). Если вам нужен справочный код, попробуйте этот демонстрационный файл pytorch. Я изменил этот файл в соответствии со своими потребностями, так что для ссылки это поможет.


person Prithvi Raj Kanaujia    schedule 24.03.2021    source источник


Ответы (1)


Вы просто не можете этого сделать. Как вы сказали, ваша сеть ожидает 100-мерный ввод, который обычно выбирается из стандартного нормального распределения:

введите здесь описание изображения

Таким образом, работа генератора состоит в том, чтобы взять этот случайный вектор и сгенерировать изображение размером 3x64x64, неотличимое от реального изображения. Входные данные представляют собой случайный 100-мерный вектор, выбранный из стандартного нормального распределения. Я не вижу способа ввести ваше изображение в текущую сеть без изменения архитектуры и переобучения новой модели. Если вы хотите попробовать новую модель, вы можете изменить ввод на закрытые изображения, применить некоторые конв. / linear Layers, чтобы уменьшить размеры до 100, а остальную часть сети оставить прежней. Таким образом, сеть будет пытаться научиться генерировать изображения не из скрытого вектора, а из вектора признаков, извлеченного из скрытых изображений. Это может или не может работать.

РЕДАКТИРОВАТЬ Я решил попробовать и посмотреть, может ли сеть обучаться с этим типом условных входных векторов вместо скрытых векторов. Я использовал пример учебника, который вы связали, и добавил пару изменений. Сначала новая сеть для получения ввода и уменьшения ее до 100 измерений:

class ImageTransformer(nn.Module):
    def __init__(self):
        super(ImageTransformer, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 1, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True)
        )            

        self.linear = nn.Linear(32*32, 100)

    def forward(self, input):
        out = self.main(input).view(input.shape[0], -1)
        return self.linear(out).view(-1, 100, 1, 1)

Простой слой свертки + relu + линейный слой для сопоставления со 100 измерениями на выходе. Обратите внимание, что вы можете попробовать гораздо лучшую сеть здесь в качестве лучшего экстрактора функций, я просто хотел провести простой тест.

fixed_input = next(iter(dataloader))[0][0:64, :, : ,:]
fixed_input[:, :, 20:44, 20:44] = torch.tensor(np.zeros((24,24), dtype = np.float32))
fixed_input = fixed_input.to(device)

Вот как я модифицирую тензор, чтобы добавить черное пятно поверх ввода. Только что сэмплировал партию, чтобы создать фиксированный вход для отслеживания процесса, как это было сделано в учебнике со случайным вектором.

# Create the generator
netG = Generator().to(device)
netD = Discriminator().to(device)
netT = ImageTransformer().to(device)

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.2.
netG.apply(weights_init)
netD.apply(weights_init)
netT.apply(weights_init)

# Print the model
print(netG)
print(netD)
print(netT)

Большинство шагов одинаковы, просто создан экземпляр новой сети трансформатора. Затем, наконец, слегка модифицируется обучающий цикл, в котором генератору не подаются случайные векторы, а выдаются выходные данные новой сети преобразователя.

img_list = []
G_losses = []
D_losses = []
iters = 0

for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        transformed = data[0].detach().clone()
        transformed[:, :, 20:44, 20:44] = torch.tensor(np.zeros((24,24), dtype = np.float32))
        transformed = transformed.to(device)
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        fake = netT(transformed)
        fake = netG(fake)
        label.fill_(fake_label)
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netT(fixed_input)
                fake = netG(fake).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

Тренировка прошла нормально с точки зрения снижения потерь и т. д. Наконец, вот что я получил после 5-ти эпохальной тренировки:

входы

выводы

Так что же говорит нам этот результат? Поскольку входные данные генератора не были взяты случайным образом из нормального распределения, генератор не смог изучить распределение лиц, чтобы создать различный диапазон выходных лиц. А поскольку вход представляет собой условный вектор признаков, диапазон выходных изображений ограничен. Таким образом, для генератора требуются случайные входные данные, хотя он научился удалять патчи :)

person yutasrobot    schedule 24.03.2021
comment
Во-первых, спасибо за всю работу, которую вы сделали, я очень ценю это. Первоначально мне просто нужен был способ передать изображение для создания, которое вы сделали через ImageTransformer. Я даже думал об этом, но просто не был уверен. В любом случае, спасибо за это :) Во-вторых, я обучил генератор случайным входным данным, а затем моей целью было проверить, как генератор работает с этими перекрытыми изображениями, и посмотреть, какие улучшения я могу внести, чтобы генерировать более качественные и разнообразные изображения. - person Prithvi Raj Kanaujia; 25.03.2021
comment
Таким образом, я мог бы добиться хороших результатов. Также спасибо за автоматизированный способ установки патчей. Сначала я делал это вручную :'( . Но теперь это будет легко. Кроме того, пожалуйста, поделитесь, если у вас есть какие-либо другие сомнения или идеи относительно того, как я это планировал... Я был бы очень признателен. - person Prithvi Raj Kanaujia; 25.03.2021
comment
Кроме того, я не очень хорошо разбираюсь в экстракторе функций, поэтому, если вы знаете что-нибудь, что может мне помочь, это было бы здорово. - person Prithvi Raj Kanaujia; 25.03.2021
comment
и, кстати... я пытался передать изображения в генератор через ImageTransformer... так как мой генератор обучен набору случайных входных данных от генератора, я подумал, что передача этих изображений не будет проблемой... но как вы сказал, что генератор научился удалять исправление, но он не может генерировать различные выходные данные или выходные данные, которые почти похожи на исходное изображение без исправления, вероятно, потому, что мы подаем условный вектор признаков, и он, вероятно, привязан к локальным минимумам. где ассортимент ограничен... - person Prithvi Raj Kanaujia; 25.03.2021
comment
любые предложения, что можно сделать? - person Prithvi Raj Kanaujia; 25.03.2021
comment
К сожалению, ничего не могу придумать. Это было лучшее, что я мог сделать - person yutasrobot; 26.03.2021
comment
Хорошо, без проблем. Я приму ваш ответ, поскольку вы решили мою первоначальную проблему. Благодарность :) - person Prithvi Raj Kanaujia; 26.03.2021
comment
Я попробовал сеть Inception v3 в качестве экстрактора признаков, но она выводит тензор такого размера [10,2048] , где 10 — это количество изображений. Но, как мы уже говорили, генератор принимает на вход размеры [batch_size,100,1,1] Итак, какие изменения я должен внести в ImageTransformer? - person Prithvi Raj Kanaujia; 26.03.2021