изменения в примере кода:
import tensorflow as tf
#print(tf.__version__)#want 2.2.0
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import cv2
import PIL
from tensorflow.keras import layers
import time
from IPython import display
DATADIR = "C:\image generator\marudata"
Category = ["meow"]
path = os.path.join(DATADIR)
for img in os.listdir(path):
img_array = cv2.imread(os.path.join(path,img), cv2.IMREAD_GRAYSCALE)
train_images = img_array
train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]
#127.5 = 255/2 (0 is black 255 is white)
BUFFER_SIZE = 13056000
BATCH_SIZE = 32
###
def make_generator_model():
model = tf.keras.Sequential()
model.add(layers.Dense(5*6*13600, use_bias=False, input_shape=(100,)))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Reshape((5, 6, 13600)))
assert model.output_shape == (None, 5, 6, 13600) # Note: None is the batch size
model.add(layers.Conv2DTranspose(256, (5, 5), strides=(1, 1), padding='same', use_bias=False))#transpose layer is the inverse of a normal filter layer.
#so strides act more like 1/2 instead of 2
assert model.output_shape == (None, 5, 6, 256)
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Conv2DTranspose(128, (5, 5), strides=(5, 2), padding='same', use_bias=False))
assert model.output_shape == (None, 25, 12, 128)
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU())
model.add(layers.Conv2DTranspose(1, (5, 5), strides=(34, 40), padding='same', use_bias=False, activation='tanh'))
assert model.output_shape == (None, 850, 480, 1)
return model
###useing this to cut up code to reduce wall of text size
def make_discriminator_model():
model = tf.keras.Sequential()
model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
input_shape=[850, 480, 1]))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))
model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
model.add(layers.LeakyReLU())
model.add(layers.Dropout(0.3))
model.add(layers.Flatten())
model.add(layers.Dense(1))
return model
ошибка
Traceback (most recent call last):
File "C:\image generator\image_generator.py", line 197, in <module>
train(train_dataset, EPOCHS)
File "C:\image generator\image_generator.py", line 162, in train
train_step(image_batch)
File "C:\Users\will\miniconda3\lib\site-packages\tensorflow\python\eager\def_function.py", line 580, in __call__
result = self._call(*args, **kwds)
File "C:\Users\will\miniconda3\lib\site-packages\tensorflow\python\eager\def_function.py", line 627, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "C:\Users\will\miniconda3\lib\site-packages\tensorflow\python\eager\def_function.py", line 506, in _initialize
*args, **kwds))
File "C:\Users\will\miniconda3\lib\site-packages\tensorflow\python\eager\function.py", line 2446, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "C:\Users\will\miniconda3\lib\site-packages\tensorflow\python\eager\function.py", line 2777, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "C:\Users\will\miniconda3\lib\site-packages\tensorflow\python\eager\function.py", line 2667, in _create_graph_function
capture_by_value=self._capture_by_value),
File "C:\Users\will\miniconda3\lib\site-packages\tensorflow\python\framework\func_graph.py", line 981, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "C:\Users\will\miniconda3\lib\site-packages\tensorflow\python\eager\def_function.py", line 441, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "C:\Users\will\miniconda3\lib\site-packages\tensorflow\python\framework\func_graph.py", line 968, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
C:\image generator\image_generator.py:145 train_step *
real_output = discriminator(images, training=True)
C:\Users\will\miniconda3\lib\site-packages\tensorflow\python\keras\engine\base_layer.py:886 __call__ **
self.name)
C:\Users\will\miniconda3\lib\site-packages\tensorflow\python\keras\engine\input_spec.py:180 assert_input_compatibility
str(x.shape.as_list()))
ValueError: Input 0 of layer sequential_1 is incompatible with the layer: expected ndim=4, found ndim=2. Full shape received: [32, 850]
строки 138-197
@tf.function
def train_step(images):
noise = tf.random.normal([BATCH_SIZE, noise_dim])
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise, training=True)
real_output = discriminator(images, training=True)
fake_output = discriminator(generated_images, training=True)
gen_loss = generator_loss(fake_output)
disc_loss = discriminator_loss(real_output, fake_output)
gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
for epoch in range(epochs):
start = time.time()
for image_batch in dataset:
train_step(image_batch)#line 162
# Produce images for the GIF as we go
display.clear_output(wait=True)
generate_and_save_images(generator,
epoch + 1,
seed)
# Save the model every 10 epochs
if (epoch + 1) % 10 == 0:
checkpoint.save(file_prefix = checkpoint_prefix)
print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
# Generate after the final epoch
display.clear_output(wait=True)
generate_and_save_images(generator,
epochs,
seed)
def generate_and_save_images(model, epoch, test_input):
# Notice `training` is set to False.
# This is so all layers run in inference mode (batchnorm).
predictions = model(test_input, training=False)
fig = plt.figure(figsize=(4,4))
for i in range(predictions.shape[0]):
plt.subplot(4, 4, i+1)
plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
plt.axis('off')
plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
#plt.show()
train(train_dataset, EPOCHS)#line 197
все остальное должно по-прежнему быть точной копией кода в предоставленной ссылке.