Я пытаюсь преобразовать модель pytorch с несколькими сетями в ONNX и столкнулся с некоторой проблемой.
Репозиторий git: https://github.com/InterDigitalInc/HRFAE.
Тренерский класс:
class Trainer(nn.Module):
def __init__(self, config):
super(Trainer, self).__init__()
# Load Hyperparameters
self.config = config
# Networks
self.enc = Encoder()
self.dec = Decoder()
self.mlp_style = Mod_Net()
self.dis = Dis_PatchGAN()
...
Вот как обученная модель обрабатывает изображение:
def gen_encode(self, x_a, age_a, age_b=0, training=False, target_age=0):
if target_age:
self.target_age = target_age
age_modif = self.target_age*torch.ones(age_a.size()).type_as(age_a)
else:
age_modif = self.random_age(age_a, diff_val=25)
# Generate modified image
self.content_code_a, skip_1, skip_2 = self.enc(x_a)
style_params_a = self.mlp_style(age_a)
style_params_b = self.mlp_style(age_modif)
x_a_recon = self.dec(self.content_code_a, style_params_a, skip_1, skip_2)
x_a_modif = self.dec(self.content_code_a, style_params_b, skip_1, skip_2)
return x_a_recon, x_a_modif, age_modif
И вот как я конвертировал в onnx:
enc = Encoder()
dec = Decoder()
mlp = Mod_Net()
layers = [enc, mlp, dec]
model = torch.nn.Sequential(*layers)
# here is my confusion: how do I specify the inputs of each layer??
# E.g. one of the outputs of 'enc' layer should be input of 'mlp' layer,
# or the outputs of 'enc' layer should be part of inputs of 'dec' layer...
params = torch.load('./logs/001/checkpoint')
model[0].load_state_dict(params['enc_state_dict'])
model[1].load_state_dict(params['mlp_style_state_dict'])
model[2].load_state_dict(params['dec_state_dict'])
torch.onnx.export(model, torch.randn([1, 3, 1024, 1024]), 'trained_hrfae.onnx', do_constant_folding=True)
Может быть, код конвертируемой части неправильный? Может кто поможет, большое спасибо!
#20210629-11:52GMT Редактировать:
Я обнаружил, что существует ограничение на использование torch.nn.Sequential. Вывод предыдущего слоя в Sequential должен соответствовать последнему вводу. Так что мой код вообще не должен работать, потому что вывод слоя enc не согласуется с вводом слоя mlp.
Может ли кто-нибудь помочь, как преобразовать этот тип модели pytorch в onnx? Большое спасибо, еще раз :)