Добавить информацию о классе в модель генератора в keras

Я хочу использовать условные GAN с целью создания изображений для одного домена (обозначенного как domain A), а также с входными изображениями из второго домена (обозначенного как domain B), а также с информацией о классе. Оба домена связаны с одной и той же информацией метки (каждое изображение домена А связано с изображением домена Б и определенной меткой). Мой генератор до сих пор в Керасе выглядит следующим образом:

def generator_model_v2():

   global BATCH_SIZE
   inputs = Input((IN_CH, img_cols, img_rows))   
   e1 = BatchNormalization(mode=0)(inputs)
   e2 = Flatten()(e1)
   e3 = BatchNormalization(mode=0)(e2)
   e4 = Dense(1024, activation="relu")(e3)
   e5 = BatchNormalization(mode=0)(e4)
   e6 = Dense(512, activation="relu")(e5)
   e7 = BatchNormalization(mode=0)(e6)
   e8 = Dense(512, activation="relu")(e7)
   e9 = BatchNormalization(mode=0)(e8)
   e10 = Dense(IN_CH * img_cols *img_rows, activation="relu")(e9)
   e11  = Reshape((3, 28, 28))(e10)
   e12 = BatchNormalization(mode=0)(e11)
   e13 = Activation('tanh')(e12)

   model = Model(input=inputs, output=e13)
   return model

Пока что мой генератор принимает в качестве входных данных изображения из domain A (и область для вывода изображений из domain B). Я хочу каким-то образом также ввести информацию о классе для входного домена A с возможностью создания изображений того же класса для домена B. Как я могу добавить информацию о метке после выравнивания. Таким образом, вместо того, чтобы иметь размер ввода 1x1024, например, 1x1025. Могу ли я использовать второй вход для информации о классе в генераторе. И если да, то как я могу вызвать генератор из процедуры обучения GAN?

Процедура обучения:

discriminator_and_classifier_on_generator = generator_containing_discriminator_and_classifier(
    generator, discriminator, classifier)
generator.compile(loss=generator_l1_loss, optimizer=g_optim)
discriminator_and_classifier_on_generator.compile(
    loss=[generator_l1_loss, discriminator_on_generator_loss, "categorical_crossentropy"],
    optimizer="rmsprop")
discriminator.compile(loss=discriminator_loss, optimizer=d_optim) # rmsprop
classifier.compile(loss="categorical_crossentropy", optimizer=c_optim)

for epoch in range(30):
    for index in range(int(X_train.shape[0] / BATCH_SIZE)):
        image_batch = Y_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
        label_batch = LABEL_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]  # replace with your data here
        generated_images = generator.predict(X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE])
        real_pairs = np.concatenate((X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], image_batch),axis=1)
        fake_pairs = np.concatenate((X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], generated_images), axis=1)
        X = np.concatenate((real_pairs, fake_pairs))
        y = np.concatenate((np.ones((100, 1, 64, 64)), np.zeros((100, 1, 64, 64))))
        d_loss = discriminator.train_on_batch(X, y)
        discriminator.trainable = False
        c_loss = classifier.train_on_batch(image_batch, label_batch)
        classifier.trainable = False
        g_loss = discriminator_and_classifier_on_generator.train_on_batch(
            X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], 
            [image_batch, np.ones((100, 1, 64, 64)), label_batch])
        discriminator.trainable = True
        classifier.trainable = True

Код представляет собой реализацию условных dcgans (с добавлением классификатора поверх дискриминатора). А функции сети таковы:

def generator_containing_discriminator_and_classifier(generator, discriminator, classifier):
   inputs = Input((IN_CH, img_cols, img_rows))
   x_generator = generator(inputs)
   merged = merge([inputs, x_generator], mode='concat', concat_axis=1)
   discriminator.trainable = False
   x_discriminator = discriminator(merged)
   classifier.trainable = False
   x_classifier = classifier(x_generator)
   model = Model(input=inputs, output=[x_generator, x_discriminator, x_classifier])
   return model

def generator_containing_discriminator(generator, discriminator):
   inputs = Input((IN_CH, img_cols, img_rows))
   x_generator = generator(inputs)
   merged = merge([inputs, x_generator], mode='concat',concat_axis=1)
   discriminator.trainable = False
   x_discriminator = discriminator(merged)
   model = Model(input=inputs, output=[x_generator,x_discriminator])
   return model



Ответы (1)


Сначала, следуя предложению, которое дано в условных генеративно-состязательных сетях, вы должны определить второй вход . Затем просто соедините два входных вектора и обработайте этот объединенный вектор.

def generator_model_v2():
    input_image = Input((IN_CH, img_cols, img_rows)) 
    input_conditional = Input((n_classes))  
    e0 = Flatten()(input_image) 
    e1 = Concatenate()([e0, input_conditional])   
    e2 = BatchNormalization(mode=0)(e1)
    e3 = BatchNormalization(mode=0)(e2)
    e4 = Dense(1024, activation="relu")(e3)
    e5 = BatchNormalization(mode=0)(e4)
    e6 = Dense(512, activation="relu")(e5)
    e7 = BatchNormalization(mode=0)(e6)
    e8 = Dense(512, activation="relu")(e7)
    e9 = BatchNormalization(mode=0)(e8)
    e10 = Dense(IN_CH * img_cols *img_rows, activation="relu")(e9)
    e11  = Reshape((3, 28, 28))(e10)
    e12 = BatchNormalization(mode=0)(e11)
    e13 = Activation('tanh')(e12)

    model = Model(input=[input_image, input_conditional] , output=e13)
    return model

Затем вам нужно передать метки классов во время обучения, а также в сеть:

classifier.train_on_batch((image_batch, class_batch), label_batch)
person zimmerrol    schedule 27.08.2018
comment
Но мне нужно предоставить тот же ввод label_batch для сети генератора? Поэтому мне нужно добавить метки в g_loss = DISCTOR_AND_CLASSifier_on_generator.train_on_batch( X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE, :, :, :], [image_batch, np.ones((100, 1, 64, 64) ), label_batch]) - person Jose Ramon; 27.08.2018
comment
Я предполагаю, что мне нужно воплотить информацию о классе во время этапа обучения только в generate_images = generator.predict(X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]) и g_loss = diversator_and_classifier_on_generator.train_on_batch(...), Правильно? - person Jose Ramon; 27.08.2018