Название | Нейросети: создание и оптимизация будущего |
---|---|
Автор произведения | Джеймс Девис |
Жанр | |
Серия | |
Издательство | |
Год выпуска | 2025 |
isbn |
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms
# Определение модели (например, простой полносвязной сети)
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = x.view(-1, 28 * 28) # Преобразование входа
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
# Настройка данных (например, MNIST)
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
# Инициализация модели, функции потерь и оптимизатора
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9) # Начальная скорость обучения 0.1
# Планировщик скорости обучения: уменьшаем LR каждые 5 эпох на фактор 0.5
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)
# Процесс обучения
num_epochs = 15
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
for inputs, labels in train_loader:
optimizer.zero_grad() # Сброс градиентов
outputs = model(inputs) # Прямой проход
loss = criterion(outputs, labels) # Вычисление потерь
loss.backward() # Обратное распространение
optimizer.step() # Обновление весов
running_loss += loss.item()
# Обновление скорости обучения по планировщику
scheduler.step()
# Вывод информации об эпохе
print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader):.4f}, LR: {scheduler.get_last_lr()[0]:.5f}")
```
Объяснение кода
1. Инициализация оптимизатора: Используется `SGD` (стохастический градиентный спуск) с начальной скоростью обучения ( 0.1 ) и моментом ( 0.9 ).
2. Планировщик скорости обучения: Планировщик `StepLR` уменьшает скорость обучения на фактор ( gamma = 0.5 ) каждые 5 эпох. Вывод текущего значения скорости обучения в конце каждой эпохи с помощью `scheduler.get_last_lr()`.
3. Прогресс скорости обучения: Сначала скорость обучения высокая (( 0.1 )) для быстрого уменьшения потерь, затем она постепенно уменьшается, что позволяет более точно достичь минимума функции потерь.
Этот подход показывает, как управлять скоростью обучения для повышения стабильности и эффективности процесса обучения.
2. Момент (Momentum)
Момент (momentum) – это метод, используемый в алгоритмах оптимизации для улучшения процесса обновления весов модели. Он добавляет инерцию к изменениям параметров, что позволяет ускорять движение в правильном направлении и снижать влияние шумов в данных или градиентах. В традиционном стохастическом градиентном спуске (SGD) обновление весов выполняется только на основе текущего градиента, что может приводить к хаотичным движениям или замедлению в негладких областях функции потерь. Момент решает эту проблему, учитывая также направление предыдущих шагов, добавляя «память» об истории обновлений.
Главное преимущество использования момента заключается в ускорении сходимости, особенно в условиях, когда функция потерь имеет вытянутую форму (например, в долинах с высокой кривизной вдоль одной оси и малой вдоль другой). Без момента модель