Глубокое обучение с подкреплением - проблема CartPole

Я попытался реализовать самый простой алгоритм Deep Q Learning. Я думаю, что реализовал это правильно и знаю, что Deep Q Learning борется с расхождениями, но награда очень быстро уменьшается, а потери расходятся. Я был бы признателен, если бы кто-нибудь помог мне указать правильные гиперпараметры или если бы я неправильно реализовал алгоритм. Я пробовал много комбинаций гиперпараметров, а также менял сложность QNet.

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import collections
import numpy as np
import matplotlib.pyplot as plt
import gym
from torch.nn.modules.linear import Linear
from torch.nn.modules.loss import MSELoss


class ReplayBuffer:
  def __init__(self, max_replay_size, batch_size):
    self.max_replay_size = max_replay_size
    self.batch_size      = batch_size
    self.buffer          = collections.deque()


def push(self, *transition):
    if len(self.buffer) == self.max_replay_size:
        self.buffer.popleft()
    self.buffer.append(transition)


def sample_batch(self):
    indices = np.random.choice(len(self.buffer), self.batch_size, replace = False)
    batch   = [self.buffer[index] for index in indices]
    
    state, action, reward, next_state, done = zip(*batch)
    
    state      = np.array(state)
    action     = np.array(action)
    reward     = np.array(reward)
    next_state = np.array(next_state)
    done       = np.array(done)
    
    return state, action, reward, next_state, done


def __len__(self):
    return len(self.buffer)


class QNet(nn.Module):
  def __init__(self, state_dim, action_dim):
    super(QNet, self).__init__()

    self.linear1 = Linear(in_features = state_dim, out_features = 64)
    self.linear2 = Linear(in_features = 64, out_features = action_dim)


  def forward(self, x):
    x = self.linear1(x)
    x = F.relu(x)
    x = self.linear2(x)
    return x


def train(replay_buffer, model, target_model, discount_factor, mse, optimizer):
  state, action, reward, next_state, _ = replay_buffer.sample_batch()
  state, next_state = torch.tensor(state, dtype = torch.float), torch.tensor(next_state, 
  dtype = torch.float)

  # Compute Q Value and Target Q Value
  q_values = model(state).gather(1, torch.tensor(action, dtype = torch.int64).unsqueeze(-1))

  with torch.no_grad():
    max_next_q_values = target_model(next_state).detach().max(1)[0]
    q_target_value = torch.tensor(reward, dtype = torch.float) + discount_factor * 
                     max_next_q_values

  optimizer.zero_grad()
  loss = mse(q_values, q_target_value.unsqueeze(1))
  loss.backward()
  optimizer.step()

  return loss.item()


def main():
  # Define Hyperparameters and Parameters
  EPISODES        = 10000
  MAX_REPLAY_SIZE = 10000
  BATCH_SIZE      = 32
  EPSILON         = 1.0
  MIN_EPSILON     = 0.05
  DISCOUNT_FACTOR = 0.95
  DECAY_RATE      = 0.99
  LEARNING_RATE   = 1e-3
  SYNCHRONISATION = 33
  EVALUATION      = 32

  # Initialize Environment, Model, Target-Model, Optimizer, Loss Function and Replay Buffer
  env = gym.make("CartPole-v0")

  model        = QNet(state_dim = env.observation_space.shape[0], action_dim = 
                 env.action_space.n)
  target_model = QNet(state_dim = env.observation_space.shape[0], action_dim = 
                 env.action_space.n)
  target_model.load_state_dict(model.state_dict())

  optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE)
  mse       = MSELoss()

  replay_buffer = ReplayBuffer(max_replay_size = MAX_REPLAY_SIZE, batch_size = BATCH_SIZE)

  while len(replay_buffer) != MAX_REPLAY_SIZE:
    state = env.reset()
    done  = False
    while done != True:
        action = env.action_space.sample()

        next_state, reward, done, _ = env.step(action)

        replay_buffer.push(state, action, reward, next_state, done)

        state = next_state

  # Begin with the Main Loop where the QNet is trained
  count_until_synchronisation = 0
  count_until_evaluation      = 0
  history = {'Episode': [], 'Reward': [], 'Loss': []}
  for episode in range(EPISODES):
    total_reward = 0.0
    total_loss   = 0.0
    state        = env.reset()
    iterations   = 0
    done         = False
    while done != True:
        count_until_synchronisation += 1
        count_until_evaluation      += 1

        # Take an action
        if np.random.rand(1) < EPSILON:
            action = env.action_space.sample()
        else:
            with torch.no_grad():
                output = model(torch.tensor(state, dtype = torch.float)).numpy()
            action = np.argmax(output)

        # Observe new state and reward + store into replay_buffer
        next_state, reward, done, _ = env.step(action)
        total_reward += reward

        replay_buffer.push(state, action, reward, next_state, done)

        state = next_state

        if count_until_synchronisation % SYNCHRONISATION == 0:
            target_model.load_state_dict(model.state_dict())

        if count_until_evaluation % EVALUATION == 0:
            loss = train(replay_buffer = replay_buffer, model = model, target_model = 
                         target_model, discount_factor = DISCOUNT_FACTOR,
                         mse = mse, optimizer = optimizer)
            total_loss += loss

        iterations += 1

    print (f"Episode {episode} is concluded in {iterations} iterations with a total reward 
           of {total_reward}")

    if EPSILON > MIN_EPSILON:
        EPSILON *= DECAY_RATE

    history['Episode'].append(episode)
    history['Reward'].append(total_reward)
    history['Loss'].append(total_loss)

# Plot the Loss + Reward per Episode
fig, ax = plt.subplots(figsize = (10, 6))
ax.plot(history['Episode'], history['Reward'], label = "Reward")
ax.set_xlabel('Episodes', fontsize = 15)
ax.set_ylabel('Total Reward per Episode', fontsize = 15)
plt.legend(prop = {'size': 15})
plt.show()

fig, ax = plt.subplots(figsize = (10, 6))
ax.plot(history['Episode'], history['Loss'], label = "Loss")
ax.set_xlabel('Episodes', fontsize = 15)
ax.set_ylabel('Total Loss per Episode', fontsize = 15)
plt.legend(prop = {'size': 15})
plt.show()


if __name__ == "__main__":
  main()

person EnTDeS    schedule 25.05.2021    source источник


Ответы (2)


Ваш код выглядит нормально, я думаю, что ваши гиперпараметры не идеальны. Я бы изменил две, потенциально три вещи:

  • Если не ошибаюсь, вы обновляете свою целевую сеть каждые 32 шага. Я думаю, это слишком мало. В исходной статье Мних и др., они выполняют жесткое обновление каждые 10 тысяч шагов. Подумайте об этом: целевая сеть используется для расчета потерь, вы по существу меняете функцию потерь каждые 32 шага, что будет более одного раза за эпизод.
  • Размер вашего буфера воспроизведения довольно мал. Я бы установил его на 100 тыс. Или 1 млн, даже если это больше, чем вы собираетесь тренироваться. Если буфер воспроизведения слишком мал, вы потеряете старые переходы, что может привести к тому, что ваша сеть забудет то, что она уже изучила. Не уверен, насколько это драматично для тележки, но, возможно, стоит попробовать ...
  • Скорость обучения также может быть ниже, я использую 1-e4 с RMSProp. Обычно изменение оптимизатора также может дать разные результаты.

Надеюсь это поможет. Удачи :)

person frietz58    schedule 31.05.2021

Ваш код выглядит хорошо и хорошо написан, гиперпарамы кажутся разумными (за исключением, может быть, частоты обновления, которая может быть слишком низкой), я думаю, что сеть Q довольно мала с одним плотным слоем.

Более глубокая модель, вероятно, могла бы работать лучше (хотя, вероятно, не более 3-4 слоев), но вы сказали, что уже пробовали разные размеры сетей.

Еще одна вещь, которая приходит в голову, - это целевое обновление. Каждые n шагов вы выполняете жесткое обновление; мягкое обновление может немного помочь, но я бы не стал на это рассчитывать.

Вы также можете попробовать немного снизить скорость обучения, но я полагаю, вы уже это сделали.


Мои предложения:

  • попробуйте реже обновлять цели
  • попробуйте более крупный (более глубокий, что-то вроде 2/3 плотных слоев с 32 узлами), если вы еще не
  • посмотреть обновления soft target (усреднение поляка и тд)
  • попробуйте свою реализацию в другом простом тренажерном зале и проверьте, остается ли его поведение таким же.

К сожалению, DQN не идеален и не может сойтись для многих проблем, но он должен быть в состоянии решить тележку.

person gekrone    schedule 30.05.2021