Как кэшировать и перебирать набор данных неизвестного размера?

При добавлении шага .cache() в мой конвейер набора данных последующие эпохи обучения по-прежнему загружают данные из сетевого хранилища.

У меня есть набор данных в сетевом хранилище. Я хочу кэшировать его, но не повторять: эпоха обучения должна проходить через весь набор данных. Вот мой конвейер построения набора данных:

return tf.data.Dataset.list_files(
        file_pattern
    ).interleave(
        tf.data.TFRecordDataset,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    ).shuffle(
        buffer_size=2048
    ).batch(
        batch_size=2048,
        drop_remainder=True,
    ).cache(
    ).map(
        map_func=_parse_example_batch,
        num_parallel_calls=tf.data.experimental.AUTOTUNE
    ).prefetch(
        buffer_size=32
    )

Если я использую его как есть, набор данных загружается в каждую эпоху. Чтобы этого избежать, я должен добавить в конвейер шаг .repeat() и использовать ключевое слово steps_per_epoch функции model.fit. Однако мне неизвестен размер полного набора данных, поэтому я не могу передать правильное значение steps_per_epoch.

Как правильно кэшировать и использовать набор данных неизвестного размера?

Спасибо.


Редактировать

Читая некоторый код TF, я (повторно) обнаружил _7 _. Кажется, это именно то, что я ищу, то есть повторять несколько раз по одному и тому же набору данных (используя кеш после первой итерации). Однако это устарело и больше не является частью основного API в TF2.

Инструкция по обновлению состоит в том, чтобы вручную перебрать набор данных с помощью for ... in dataset. Разве это не то, что делает функция keras.Model.fit? Должен ли я писать цикл обучения вручную, чтобы получить преимущества кеширования?

Добрый.


person AlexisBRENON    schedule 17.09.2019    source источник


Ответы (2)


В TF2.0 вам не нужен .repeat(). К

последующие эпохи обучения по-прежнему загружают данные из сетевого хранилища.

Думаю, вы запутались с сообщением filling up shuffle buffer. Это происходит перед каждой эпохой, если вы используете функцию shuffle(). Может, попробуй без shuffle(), просто чтобы увидеть разницу. Кроме того, я бы посоветовал вам использовать cache() после map() и до batch().

ИЗМЕНИТЬ

заполнение буфера перемешивания

это сообщение, которое вы получаете при использовании функции shuffle. Вы все еще можете shuffle() набор данных после использования cache(). Посмотрите здесь. Кроме того, если я правильно понял, вы загружаете полученный набор данных из map() к вашей модели для обучения, тогда вам следует cache() этот набор данных, а не другой, потому что обучение будет проводиться на нем. Для подсчета количества элементов в вашем наборе данных вы можете использовать следующий код

num_elements = 0
for element in dataset: # tf.dataset type
  num_elements += 1
print ('Total number of elements in the file: ',num_elements)

Теперь, ныряя в этот num_elements своим batch_size, вы получите steps_per_epoch

person Rishabh Sahrawat    schedule 18.09.2019
comment
Спасибо, но я не получаю filling up shuffle buffer (я думаю, что если я кэширую после shuffle(), то shuffle() не происходит в каждую эпоху). Я знаю, что он загружает данные, потому что итерация очень медленная (по сравнению с тем, когда набор данных кэшируется (с использованием материалов repeat и steps_per_epoch)). Мой процессор не перегружен, и map() мой набор данных стал больше, поэтому я cache() перед этим. - person AlexisBRENON; 18.09.2019
comment
Я отредактировал свой ответ на ваши вопросы в комментарии. - person Rishabh Sahrawat; 18.09.2019
comment
Спасибо за подробности, но я не думаю, что у меня есть какие-либо проблемы с шагом в случайном порядке. Мой главный вопрос: как использовать cache() с fit(), не требуя repeat() и steps_per_epoch? Я хочу кэшировать свой набор данных до map(), потому что после него размер набора данных умножается на 8! Не имеет значения, нужно ли каждый раз пересчитывать map(). Фрагмент кода, которым вы делитесь для подсчета элементов, выполняется примерно за 4 часа. Я не хочу ждать на этот раз перед тренировкой. - person AlexisBRENON; 18.09.2019

Хорошие новости! В окончательном выпуске v2.0.0 это поведение исправлено.

Вот фрагмент кода, чтобы выделить различное поведение.

import time

import tensorflow as tf
import tensorflow.keras as keras

# Simple layer that just print its inputs
class Print(keras.layers.Layer):

       def compute_output_signature(self, input_signature):
              return input_signature

       def call(self, inputs, **kwargs):
              tf.print(inputs)
              return inputs

# Generator returning incremented values each time it is re-initialized
generator_list = [0]
def generator():
       v = generator_list[-1]
       generator_list.append(v+1)
       tf.print("Generating samples with value {}".format(v))
       time.sleep(2)
       for i in range(2):
              yield (tf.constant([v]), tf.constant(v))


def main():
       model_input = keras.layers.Input(shape=(1,))
       model_output = Print()(model_input)
       model = keras.Model(inputs=model_input, outputs=model_output)
       model.compile("adam", loss="mae")

       ds = tf.data.Dataset.from_generator(
              generator, (tf.int64, tf.int64), ([1], [])
       )
       cached_ds = ds.cache()

       tf.print("Fit")
       model.fit(
              cached_ds,
              epochs=3,
              verbose=2
       )

       tf.print("For ... in ...")
       for i in range(3):
              for x, y in cached_ds:
                     model(x)

if __name__ == '__main__':
    main()

Вот результат с tenorflow 2.0.0-b1 (используется на платформе Google AI):

Fit
Epoch 1/3
Generating samples with value 0
# sleep 2s
2019-10-03 15:45:32.718522: W tensorflow/compiler/jit/mark_for_compilation_pass.cc:1483] (One-time warning): Not using XLA:CPU for cluster because envvar TF_XLA_FLAGS=--tf_xla_cpu_global_jit was not set.  If you want XLA:CPU, either set that envvar, or use experimental_jit_scope to enable XLA:CPU.  To confirm that XLA is active, pass --vmodule=xla_compilation_cache=1 (as a proper command-line flag, not via TF_XLA_FLAGS) or set the envvar XLA_FLAGS=--xla_hlo_profile.
[[0]]
[[0]]
2/2 - 2s - loss: 0.0000e+00
Generating samples with value 1
# sleep 2s
Epoch 2/3
[[1]]
[[1]]
2/2 - 2s - loss: 0.0000e+00
Epoch 3/3
2019-10-03 15:45:34.774195: W tensorflow/core/kernels/data/cache_dataset_ops.cc:815] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Generating samples with value 2
# sleep 2s
[[2]]
[[2]]
2019-10-03 15:45:36.782046: W tensorflow/core/kernels/data/cache_dataset_ops.cc:815] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2/2 - 2s - loss: 0.0000e+00
For ... in ...
Generating samples with value 3
# sleep 2s
[3]
[3]
Generating samples with value 4
# sleep 2s
[4]
[4]
Generating samples with value 5
# sleep 2s
[5]
[5]

Вы можете видеть, что значение тензора увеличивается для каждой эпохи, и инструкция сна выполняется каждый раз. Более того, мы получаем предупреждение об усеченном итераторе ...

Теперь с tenorflow 2.0.0:

Fit
Epoch 1/3
WARNING:tensorflow:The list of trainable weights is empty. Make sure that you are not setting model.trainable to False before compiling the model.
Generating samples with value 0
# sleep 2s
[[0]]
[[0]]
2019-10-03 15:49:59.587796: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
     [[{{node IteratorGetNext}}]]
2/2 - 2s - loss: 0.0000e+00
Epoch 2/3
[[0]]
[[0]]
2019-10-03 15:49:59.598144: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
     [[{{node IteratorGetNext}}]]
2/2 - 0s - loss: 0.0000e+00
Epoch 3/3
[[0]]
[[0]]
2019-10-03 15:49:59.605260: W tensorflow/core/common_runtime/base_collective_executor.cc:216] BaseCollectiveExecutor::StartAbort Out of range: End of sequence
     [[{{node IteratorGetNext}}]]
For ... in ...
2/2 - 0s - loss: 0.0000e+00
[0]
[0]
[0]
[0]
[0]
[0]

И вуаля! Функция генератора выполняется только один раз, без сна и всегда с тем же значением тензора. У меня просто есть несколько предупреждений об окончании последовательности, но я могу это поддержать!

Добрый.

person AlexisBRENON    schedule 03.10.2019