В этой истории мы используем учебную модель CycleGAN в TensorFlow Core. Сначала обучите учебную модель в Colaboratory.
## run all cells in colab to this line.
for epoch in range(EPOCHS):
start = time.time()
n = 0
for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
train_step(image_x, image_y)
if n % 10 == 0:
print ('.', end='')
n+=1
clear_output(wait=True)
# Using a consistent image (sample_horse) so that the progress of the model
# is clearly visible.
generate_images(generator_g, sample_horse)
if (epoch + 1) % 5 == 0:
ckpt_save_path = ckpt_manager.save()
print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
ckpt_save_path))
print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
time.time()-start))
Затем вставьте новые ячейки и запустите конвертер.
1. Установите CoreMLTools и TFCoreML.
!pip install --upgrade coremltools !pip install --upgrade tfcoreml
2, восстановить контрольные точки.
checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(generator_g=generator_g,
generator_f=generator_f,
discriminator_x=discriminator_x,
discriminator_y=discriminator_y,
generator_g_optimizer=generator_g_optimizer,
generator_f_optimizer=generator_f_optimizer,
discriminator_x_optimizer=discriminator_x_optimizer,
discriminator_y_optimizer=discriminator_y_optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)
# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
ckpt.restore(ckpt_manager.latest_checkpoint)
print ('Latest checkpoint restored!!')
3. Сохраните generator_g (это генератор «horse2zebra») как формат «Сохраненной модели» для временного использования.
generator_g.save('./savedmodel')
4. Запустите конвертер.
import tfcoreml input_name = generator.inputs[0].name.split(':')[0] print(input_name) #Check input_name. keras_output_node_name = generator_g.outputs[0].name.split(':')[0] graph_output_node_name = keras_output_node_name.split('/')[-1] mlmodel = tfcoreml.convert('./savedmodel', input_name_shape_dict={input_name: (1, 256, 256, 3)}, output_feature_names=[graph_output_node_name], minimum_ios_deployment_target='13', image_input_names=input_name, image_scale=2/ 255.0, red_bias=-1, green_bias=-1, blue_bias=-1 ) mlmodel.save('./cyclegan.mlmodel')
Теперь вы можете использовать CycleGAN в своем проекте iOS.
import Vision lazy var coreMLRequest:VNCoreMLRequest = { let model = try! VNCoreMLModel(for: pix2pix().model) let request = VNCoreMLRequest(model: model, completionHandler: self.coreMLCompletionHandler0) return request }() let handler = VNImageRequestHandler(ciImage: ciimage,options: [:]) DispatchQueue.global(qos: .userInitiated).async { try? handler.perform([coreMLRequest]) }
Для визуализации multiArray как изображения очень удобны Помощники CoreML г-на Холланса.
Пожалуйста, подпишитесь на мой Twitter. Https://twitter.com/JackdeS11 И, пожалуйста, хлопайте в ладоши 👏.