Не удалось сохранить модель tf.keras с классификатором Bert (huggingface)

Я обучаю бинарный классификатор, который использует Берта (обнимающее лицо). Модель выглядит так:

def get_model(lr=0.00001):
    inp_bert = Input(shape=(512), dtype="int32")
    bert = TFBertModel.from_pretrained('bert-base-multilingual-cased')(inp_bert)[0]
    doc_encodings = tf.squeeze(bert[:, 0:1, :], axis=1)
    out = Dense(1, activation="sigmoid")(doc_encodings)
    model = Model(inp_bert, out)
    adam = optimizers.Adam(lr=lr)
    model.compile(optimizer=adam, loss="binary_crossentropy", metrics=["accuracy"])
    return model

После точной настройки для моей задачи классификации я хочу сохранить модель.

model.save("best_model.h5")

Однако это вызывает NotImplementedError:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
<ipython-input-55-8c5545f0cd9b> in <module>()
----> 1 model.save("best_spam.h5")
      2 # import transformers

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options)
    973     """
    974     saving.save_model(self, filepath, overwrite, include_optimizer, save_format,
--> 975                       signatures, options)
    976 
    977   def save_weights(self, filepath, overwrite=True, save_format=None):

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options)
    110           'or using `save_weights`.')
    111     hdf5_format.save_model_to_hdf5(
--> 112         model, filepath, overwrite, include_optimizer)
    113   else:
    114     saved_model_save.save(model, filepath, overwrite, include_optimizer,

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/hdf5_format.py in save_model_to_hdf5(model, filepath, overwrite, include_optimizer)
     97 
     98   try:
---> 99     model_metadata = saving_utils.model_metadata(model, include_optimizer)
    100     for k, v in model_metadata.items():
    101       if isinstance(v, (dict, list, tuple)):

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
    163   except NotImplementedError as e:
    164     if require_config:
--> 165       raise e
    166 
    167   metadata = dict(

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
    160   model_config = {'class_name': model.__class__.__name__}
    161   try:
--> 162     model_config['config'] = model.get_config()
    163   except NotImplementedError as e:
    164     if require_config:

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_config(self)
    885     if not self._is_graph_network:
    886       raise NotImplementedError
--> 887     return copy.deepcopy(get_network_config(self))
    888 
    889   @classmethod

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_network_config(network, serialize_layer_fn)
   1940           filtered_inbound_nodes.append(node_data)
   1941 
-> 1942     layer_config = serialize_layer_fn(layer)
   1943     layer_config['name'] = layer.name
   1944     layer_config['inbound_nodes'] = filtered_inbound_nodes

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
    138   if hasattr(instance, 'get_config'):
    139     return serialize_keras_class_and_config(instance.__class__.__name__,
--> 140                                             instance.get_config())
    141   if hasattr(instance, '__name__'):
    142     return instance.__name__

~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_config(self)
    884   def get_config(self):
    885     if not self._is_graph_network:
--> 886       raise NotImplementedError
    887     return copy.deepcopy(get_network_config(self))
    888 

NotImplementedError: 

Мне известно, что huggingface предоставляет метод model.save_pretrained () для TFBertModel, но я предпочитаю обернуть его в tf.keras.Model, поскольку я планирую добавить в эту сеть другие компоненты / функции. Может кто подскажет решение по сохранению актуальной модели?


person boltzsman    schedule 09.01.2020    source источник
comment
Из нескольких обсуждений на странице tenorflow в GIT, я думаю, что это проблема с tensorflow 2.0, попробуйте обновить / понизить версию tenorflow.   -  person Ashwin Geet D'Sa    schedule 09.01.2020
comment
Также model.save("model_name",save_format='tf') должен работать   -  person Ashwin Geet D'Sa    schedule 09.01.2020
comment
model.save (имя_модели, save_format = 'tf') решил мою проблему. Спасибо! Если вы разместите свой комментарий в качестве ответа, я его приму.   -  person boltzsman    schedule 09.01.2020


Ответы (1)


Это действительно проблема с tensorflow 2.0.

Пожалуйста, используйте: model.save("model_name",save_format='tf')

Кроме того, вы также можете попробовать обновить или понизить версию tenorflow.

person Ashwin Geet D'Sa    schedule 10.01.2020