Torch: NN обрабатывает текст и числовой ввод

У меня есть следующая архитектура NN:

Часть 1:

nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> (7) -> output]
  (1): nn.TemporalConvolution
  (2): nn.TemporalMaxPooling
  (3): nn.TemporalConvolution
  (4): nn.TemporalMaxPooling
  (5): nn.Reshape(14336)
  (6): nn.Dropout(0.500000)
  (7): nn.Linear(14336 -> 128)
}

Часть 2:

nn.Sequential {
  [input -> (1) -> (2) -> (3) -> output]
  (1): nn.Linear(4 -> 8)
  (2): nn.ReLU
  (3): nn.Linear(8 -> 4)
}

Что я хотел бы сделать, так это использовать вывод этих двух частей в качестве ввода для другой части:

nn.Sequential {
  [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
  (1): nn.Linear(132 -> 32)
  (2): nn.ReLU
  (3): nn.Linear(32 -> 32)
  (4): nn.ReLU
  (5): nn.Linear(32 -> 2)
  (6): nn.LogSoftMax
}

Обратите внимание, что часть 1 имеет 128 выходов, часть 2 — 4 и, наконец, часть 3 — 132 входа. Итак, в основном то, что я хочу, - это сеть, которая принимает два типа ввода (часть 1 для текста, часть 2 для числового вектора) и использует обе эти информации в третьем слое для классификации 2 классов.

Я просмотрел различные контейнеры, но ничего не похоже на то, что мне нужно. В частности, я просмотрел nn.Parallel, но из документов похоже, что он делает что-то совершенно другое (один и тот же ввод для двух разных модулей). Первая проблема заключается в том, как должны выглядеть входные данные для сети (поскольку каждая часть использует свой тип тензора, я подумал, что подойдет простая таблица (массив), первый элемент которой будет двумерным тензором, а второй — одномерным тензором. ) и как подключить его выходы к другой сети, чтобы я мог использовать вызовы вперед/назад, как обычно.

Есть ли способ, как это сделать?

Спасибо!


person PeterK    schedule 17.09.2015    source источник


Ответы (1)


Вам нужны nn.ParallelTable и nn.JoinTable.

local parallel = nn.ParallelTable()
parallel:add(part1)
parallel:add(part2)

local net = nn.Sequential()
net:add(parallel)                   -- (A)
net:add(nn.JoinTable(1))            -- (B)
net:add(part3)                      -- (C)

(A):

parallel возьмет таблицу из 2 тензоров (в вашем случае текст и числа), перенаправит первый тензор в part1, второй тензор в part2 и выведет оба результата в другую таблицу из 2 тензоров.

(B):

Следующий nn.JoinTable берет эту таблицу в качестве входных данных и объединяет 2 тензора в один. Возможно, вам придется поиграть с параметром, обрабатывающим измерение конкатенации (в моем примере 1), в зависимости от формы ваших тензоров.

(C):

Наконец, вы можете добавить третью часть вашей сети, используя конкатенированный тензор в качестве входных данных.

person mbrenon    schedule 17.09.2015
comment
Потрясающе, спасибо! Завтра попробую и отпишусь о результатах. - person PeterK; 17.09.2015