Название | 120 практических задач |
---|---|
Автор произведения | Джейд Картер |
Жанр | |
Серия | |
Издательство | |
Год выпуска | 2024 |
isbn |
```python
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.model_selection import train_test_split
# Пусть 'landscapes' – это директория с изображениями
image_dir = 'path_to_landscape_images'
image_size = (128, 128) # Размер изображения для нейронной сети
def load_images(image_dir, image_size):
images = []
for filename in os.listdir(image_dir):
if filename.endswith(".jpg") or filename.endswith(".png"):
img_path = os.path.join(image_dir, filename)
img = Image.open(img_path).resize(image_size)
img = np.array(img)
images.append(img)
return np.array(images)
images = load_images(image_dir, image_size)
images = (images – 127.5) / 127.5 # Нормализация изображений в диапазон [-1, 1]
train_images, test_images = train_test_split(images, test_size=0.2)
```
2. Построение модели GAN
Генеративно-состязательная сеть состоит из двух частей: генератора и дискриминатора.
```python
import tensorflow as tf
from tensorflow.keras import layers
# Генератор
def build_generator():
model = tf.keras.Sequential()
model.add(layers.Dense(256, activation='relu', input_shape=(100,)))
model.add(layers.BatchNormalization())
model.add(layers.Dense(512, activation='relu'))
model.add(layers.BatchNormalization())
model.add(layers.Dense(1024, activation='relu'))
model.add(layers.BatchNormalization())
model.add(layers.Dense(np.prod(image_size) * 3, activation='tanh'))
model.add(layers.Reshape((image_size[0], image_size[1], 3)))
return model
# Дискриминатор
def build_discriminator():
model = tf.keras.Sequential()
model.add(layers.Flatten(input_shape=image_size + (3,)))
model.add(layers.Dense(512, activation='relu'))
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
return model
# Сборка модели GAN
generator = build_generator()
discriminator = build_discriminator()
discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
gan_input = layers.Input(shape=(100,))
generated_image = generator(gan_input)
discriminator.trainable = False
gan_output = discriminator(generated_image)
gan = tf.keras.Model(gan_input, gan_output)
gan.compile(optimizer='adam', loss='binary_crossentropy')
```
3. Обучение модели
```python
import tensorflow as tf
# Гиперпараметры
epochs = 10000
batch_size = 64
sample_interval = 200
latent_dim = 100
# Генерация меток
real_labels = np.ones((batch_size, 1))
fake_labels = np.zeros((batch_size, 1))
for epoch in range(epochs):
# Обучение дискриминатора
idx = np.random.randint(0, train_images.shape[0], batch_size)
real_images = train_images[idx]
noise = np.random.normal(0, 1, (batch_size, latent_dim))
fake_images = generator.predict(noise)
d_loss_real = discriminator.train_on_batch(real_images, real_labels)
d_loss_fake = discriminator.train_on_batch(fake_images, fake_labels)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
# Обучение генератора
noise = np.random.normal(0, 1, (batch_size, latent_dim))
g_loss = gan.train_on_batch(noise, real_labels)
# Печать прогресса
if epoch % sample_interval == 0:
print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100*d_loss[1]}] [G loss: {g_loss}]")
sample_images(generator)
def sample_images(generator, image_grid_rows=4, image_grid_columns=4):
noise = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, latent_dim))
gen_images = generator.predict(noise)
gen_images = 0.5 * gen_images + 0.5
fig, axs = plt.subplots(image_grid_rows, image_grid_columns, figsize=(10, 10))
cnt = 0
for i in range(image_grid_rows):
for j in range(image_grid_columns):
axs[i,j].imshow(gen_images[cnt])
axs[i,j].axis('off')
cnt += 1
plt.show()
```
4. Генерация изображений
После завершения обучения, можно использовать генератор для создания новых изображений ландшафтов.
```python
noise = np.random.normal(0, 1, (1, latent_dim))
generated_image = generator.predict(noise)
generated_image = 0.5 * generated_image + 0.5 # Возвращение значений к диапазону [0, 1]
plt.imshow(generated_image[0])
plt.axis('off')
plt.show()
```
Этот