Является ли Tensorflow Dataset.from_generator устаревшим в tensorflow 2.0? Выдает ошибку устаревания tf.py_func

Когда я создаю набор данных tf из генератора и пытаюсь запустить код tf2.0, он предупреждает меня сообщением об ограничении.

Код:

import tensorflow as tf

from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model


def my_function():
    import numpy as np
    for i in range(1000):
        yield np.random.random(size=(28, 28, 1)), [1.0]


train_ds = tf.data.Dataset.from_generator(my_function, output_types=(tf.float32, tf.float32)).batch(32)


class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

    # def __call__(self, *args, **kwargs):
    #     return super().__call(*args,**kwargs)


model = MyModel()

loss_object = tf.keras.losses.SparseCategoricalCrossentropy()

optimizer = tf.keras.optimizers.Adam()

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')


@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)


EPOCHS = 5

for epoch in range(EPOCHS):
    for images, labels in train_ds:
        train_step(images, labels)
    template = 'Epoch {}, Loss: {}, Accuracy: {}'
    print(template.format(epoch + 1,
                          train_loss.result(),
                          train_accuracy.result() * 100))

Предупреждение:

........
Instructions for updating:
tf.py_func is deprecated in TF V2. Instead, there are two
    options available in V2. ........

Я хотел бы передать данные в модель из потокового ввода с помощью API набора данных (с предварительной выборкой). Несмотря на то, что это все еще возможно в текущей альфа-версии, будет ли она удалена позже?

Заменит ли tenorflow tf.py_func, используемый в наборе данных генератора, на что-то новое или весь API генератора dataset_from будет удален?


person Himaprasoon    schedule 08.05.2019    source источник


Ответы (1)


Нет, tf.data.Dataset.from_generator не будет считаться устаревшим в TensorFlow 2.0. Вы видите предупреждающее сообщение, которое используется для информирования пользователей о будущих изменениях. Если вам нужно использовать py_func напрямую, самый простой способ - использовать tf.compat.v1.py_func. У TF2.0 есть собственная оболочка, которая называется tf.py_function.

person Sharky    schedule 08.05.2019
comment
Я не хочу использовать функцию tf.py_func. Функция Dataset.from_generator внутренне использует py_func. Заменят ли они эту реализацию на что-то другое? Я просто хочу знать, является ли from_generator перспективным (по крайней мере, в ближайшем будущем). - person Himaprasoon; 08.05.2019
comment
Он включен в TF 2.0. Я предполагаю, что заменим py_func в более поздней стабильной версии - person Sharky; 09.05.2019