преобразовать модель pytorch с несколькими сетями в onnx

Я пытаюсь преобразовать модель pytorch с несколькими сетями в ONNX и столкнулся с некоторой проблемой.

Репозиторий git: https://github.com/InterDigitalInc/HRFAE.

Тренерский класс:

class Trainer(nn.Module):
    def __init__(self, config):
        super(Trainer, self).__init__()
        # Load Hyperparameters
        self.config = config
        # Networks
        self.enc = Encoder()
        self.dec = Decoder()
        self.mlp_style = Mod_Net()
        self.dis = Dis_PatchGAN()
        ...

Вот как обученная модель обрабатывает изображение:

    def gen_encode(self, x_a, age_a, age_b=0, training=False, target_age=0):
        if target_age:
            self.target_age = target_age
            age_modif = self.target_age*torch.ones(age_a.size()).type_as(age_a)
        else:
            age_modif = self.random_age(age_a, diff_val=25)

        # Generate modified image
        self.content_code_a, skip_1, skip_2 = self.enc(x_a)
        style_params_a = self.mlp_style(age_a)
        style_params_b = self.mlp_style(age_modif)
        
        x_a_recon = self.dec(self.content_code_a, style_params_a, skip_1, skip_2)
        x_a_modif = self.dec(self.content_code_a, style_params_b, skip_1, skip_2)
        
        return x_a_recon, x_a_modif, age_modif

И вот как я конвертировал в onnx:

enc = Encoder()
dec = Decoder()
mlp = Mod_Net()
layers = [enc, mlp, dec]
model = torch.nn.Sequential(*layers)  
# here is my confusion: how do I specify the inputs of each layer?? 
# E.g. one of the outputs of 'enc' layer should be input of 'mlp' layer, 
# or the outputs of 'enc' layer should be part of inputs of 'dec' layer...

params = torch.load('./logs/001/checkpoint')  
model[0].load_state_dict(params['enc_state_dict'])
model[1].load_state_dict(params['mlp_style_state_dict'])
model[2].load_state_dict(params['dec_state_dict'])

torch.onnx.export(model, torch.randn([1, 3, 1024, 1024]), 'trained_hrfae.onnx', do_constant_folding=True)  

Может быть, код конвертируемой части неправильный? Может кто поможет, большое спасибо!

#20210629-11:52GMT Редактировать:

Я обнаружил, что существует ограничение на использование torch.nn.Sequential. Вывод предыдущего слоя в Sequential должен соответствовать последнему вводу. Так что мой код вообще не должен работать, потому что вывод слоя enc не согласуется с вводом слоя mlp.

Может ли кто-нибудь помочь, как преобразовать этот тип модели pytorch в onnx? Большое спасибо, еще раз :)


person ZZ Shao    schedule 29.06.2021    source источник
comment
Пожалуйста, не прячьте сообщение об ошибке в самом коде; опубликуйте полную трассировку ошибки — посмотрите, как создать минимально воспроизводимый пример.   -  person desertnaut    schedule 29.06.2021
comment
@desertnaut спасибо за совет. но я думаю, что главная проблема - это сам код. сообщение об ошибке предназначено только для комментариев. Буду иметь в виду на будущее.   -  person ZZ Shao    schedule 29.06.2021
comment
я не следую; конечно проблема в коде (где еще может быть?). Возможно, вы захотите запомнить это только на будущее, но также имейте в виду, что ваш вопрос как таковой подлежит закрытию из-за отсутствия подробностей и/или MRE.   -  person desertnaut    schedule 29.06.2021
comment
Ладно, я понял. Я имею в виду, что сообщение об ошибке используется, чтобы помочь объяснить ситуацию, поэтому я принимаю их в комментариях. Прямо сейчас я нахожу другую проблему с кодом, я собираюсь ее отредактировать. Заодно уберу ошибку в коде, как вы предложили. Спасибо.   -  person ZZ Shao    schedule 29.06.2021


Ответы (1)


После исследований и попыток я нашел метод, который может быть правильным:

Преобразуйте каждую сеть (кодировщик, Mod_Net, декодер) в модель onnx и обработайте их ввод/вывод в последнем логическом процессе или любой другой процедуре (например, преобразование в модель tflite).

Я пытаюсь портировать на Android, используя этот метод.

#Edit 20210705-03:52GMT#

Другой подход может быть лучше: написать новую сеть, объединяющую три сети. Я доказал, что результат такой же, как у исходной модели pytorch.

class HRFAE(nn.Module):
    def __init__(self):
        super(HRFAE, self).__init__()
        self.enc = Encoder()
        self.mlp_style = Mod_Net()
        self.dec = Decoder()

    def forward(self, x, age_modif):
        content_code_a, skip_1, skip_2 = self.enc(x)
        style_params_b = self.mlp_style(age_modif)

        x_a_modif = self.dec(content_code_a, style_params_b, skip_1, skip_2)

        return x_a_modif

а затем преобразовать использовать следующее:

net = HRFAE()

params = torch.load('./logs/002/checkpoint')
net.enc.load_state_dict(params['enc_state_dict'])
net.mlp_style.load_state_dict(params['mlp_style_state_dict'])
net.dec.load_state_dict(params['dec_state_dict'])

net.eval()
torch.onnx.export(net, (torch.randn([1, 3, 512, 512]), torch.randn([1]).type(torch.long)), 'test_hrfae.onnx')

Это должен быть ответ.

person ZZ Shao    schedule 02.07.2021