У меня вопрос о входе и выходе (слое) DQN.
e.g
Две точки: P1 (x1, y1) и P2 (x2, y2)
P1 должен идти в сторону P2.
У меня есть следующая информация:
- Текущая позиция P1 (x / y)
- Текущая позиция P2 (x / y)
- Расстояние до P1-P2 (x / y)
- Направление на P1-P2 (x / y)
P1 имеет 4 возможных действия:
- Up
- Вниз
- Левый
- Верно
Как мне настроить входной и выходной слой?
- 4 входных узла
- 4 выходных узла
Это верно? Что мне делать с выводом? На выходе я получил 4 массива по 4 значения в каждом. Правильно ли выполняет argmax на выходе?
Редактировать:
Вход / состояние:
# Current position P1
state_pos = [x_POS, y_POS]
state_pos = np.asarray(state_pos, dtype=np.float32)
# Current position P2
state_wp = [wp_x, wp_y]
state_wp = np.asarray(state_wp, dtype=np.float32)
# Distance P1 - P2
state_dist_wp = [wp_x - x_POS, wp_y - y_POS]
state_dist_wp = np.asarray(state_dist_wp, dtype=np.float32)
# Direction P1 - P2
distance = [wp_x - x_POS, wp_y - y_POS]
norm = math.sqrt(distance[0] ** 2 + distance[1] ** 2)
state_direction_wp = [distance[0] / norm, distance[1] / norm]
state_direction_wp = np.asarray(state_direction_wp, dtype=np.float32)
state = [state_pos, state_wp, state_dist_wp, state_direction_wp]
state = np.array(state)
Сеть:
def __init__(self):
self.q_net = self._build_dqn_model()
self.epsilon = 1
def _build_dqn_model(self):
q_net = Sequential()
q_net.add(Dense(4, input_shape=(4,2), activation='relu', kernel_initializer='he_uniform'))
q_net.add(Dense(128, activation='relu', kernel_initializer='he_uniform'))
q_net.add(Dense(128, activation='relu', kernel_initializer='he_uniform'))
q_net.add(Dense(4, activation='linear', kernel_initializer='he_uniform'))
rms = tf.optimizers.RMSprop(lr = 1e-4)
q_net.compile(optimizer=rms, loss='mse')
return q_net
def random_policy(self, state):
return np.random.randint(0, 4)
def collect_policy(self, state):
if np.random.random() < self.epsilon:
return self.random_policy(state)
return self.policy(state)
def policy(self, state):
# Here I get 4 arrays with 4 values each as output
action_q = self.q_net(state)