Недостаточно памяти в потоке при двойном использовании torch.serialize

Я пытаюсь добавить параллельный загрузчик данных в torch-dataframe, чтобы добавить совместимость с torchnet. Я использовал tnt.ParallelDatasetIterator и изменил так, что:

  1. Базовый пакет загружается вне потоков
  2. Пакет сериализуется и отправляется в поток
  3. В потоке пакет десериализуется и преобразует пакетные данные в тензоры
  4. Тензоры возвращаются в таблице с ключами input и target для соответствия tnt.Engine.

Проблема возникает при втором вызове enque с ошибкой: .../torch_distro/install/bin/luajit: not enough memory. В настоящее время я работаю только с mnist с адаптированным mnist-example. Цикл enque теперь выглядит так (с отладочным выводом памяти):

-- `samplePlaceholder` stands in for samples which have been
-- filtered out by the `filter` function
local samplePlaceholder = {}

-- The enque does the main loop
local idx = 1
local function enqueue()
  while idx <= size and threads:acceptsjob() do
    local batch, reset = self.dataset:get_batch(batch_size)

    if (reset) then
      idx = size + 1
    else
      idx = idx + 1
    end

    if (batch) then
      local serialized_batch = torch.serialize(batch)

      -- In the parallel section only the to_tensor is run in parallel
      --  this should though be the computationally expensive operation
      threads:addjob(
        function(argList)
          io.stderr:write("\n Start");
          io.stderr:write("\n 1: " ..tostring(collectgarbage("count")))
          local origIdx, serialized_batch, samplePlaceholder = unpack(argList)

          io.stderr:write("\n 2: " ..tostring(collectgarbage("count")))
          local batch = torch.deserialize(serialized_batch)
          serialized_batch = nil

          collectgarbage()
          collectgarbage()

          io.stderr:write("\n 3: " .. tostring(collectgarbage("count")))
          batch = transform(batch)

          io.stderr:write("\n 4: " .. tostring(collectgarbage("count")))
          local sample = samplePlaceholder
          if (filter(batch)) then
            sample = {}
            sample.input, sample.target = batch:to_tensor()
          end
          io.stderr:write("\n 5: " ..tostring(collectgarbage("count")))

          collectgarbage()
          collectgarbage()
          io.stderr:write("\n 6: " ..tostring(collectgarbage("count")))

          io.stderr:write("\n End \n");
          return {
            sample,
            origIdx
          }
        end,
        function(argList)
          sample, sampleOrigIdx = unpack(argList)
        end,
        {idx, serialized_batch, samplePlaceholder}
      )
    end
  end
end

Я побрызгал collectgarbage, а также попытался удалить все ненужные объекты. Вывод памяти довольно прост:

 Start
 1: 374840.87695312
 2: 374840.94433594
 3: 372023.79101562
 4: 372023.85839844
 5: 372075.41308594
 6: 372023.73632812
 End 

Функция, которая зацикливает enque, является тривиальной неупорядоченной функцией (ошибка памяти возникает во втором enque и ):

iterFunction = function()
  while threads:hasjob() do
    enqueue()
    threads:dojob()
    if threads:haserror() then
      threads:synchronize()
    end
    enqueue()

    if table.exact_length(sample) > 0 then
      return sample
    end
  end
end

person Max Gordon    schedule 17.07.2016    source источник


Ответы (1)


Таким образом, проблема заключалась в torch.serialize, где функция в настройке связывала весь набор данных с функцией. При добавлении:

serialized_batch = nil
collectgarbage()
collectgarbage()

Проблема была решена. Я также хотел знать, что занимает так много места, и виновником оказалось то, что я определил функцию в среде с большим набором данных, который переплелся с функцией, значительно увеличив размер. Здесь исходное определение локальных данных

mnist = require 'mnist'
local dataset = mnist[mode .. 'dataset']()

-- PROBLEMATIC LINE BELOW --
local ext_resource = dataset.data:reshape(dataset.data:size(1),
  dataset.data:size(2) * dataset.data:size(3)):double()

-- Create a Dataframe with the label. The actual images will be loaded
--  as an external resource
local df = Dataframe(
  Df_Dict{
    label = dataset.label:totable(),
    row_id = torch.range(1, dataset.data:size(1)):totable()
  })

-- Since the mnist package already has taken care of the data
--  splitting we create a single subsetter
df:create_subsets{
  subsets = Df_Dict{core = 1},
  class_args = Df_Tbl({
    batch_args = Df_Tbl({
      label = Df_Array("label"),
      data = function(row)
        return ext_resource[row.row_id]
      end
    })
  })
}

получается, что удаление выделенной строки уменьшает использование памяти с 358 Мб до 0,0008 Мб! Код, который я использовал для тестирования производительности, был таким:

local mem = {}
table.insert(mem, collectgarbage("count"))

local ser_data = torch.serialize(batch.dataset)
table.insert(mem, collectgarbage("count"))

local ser_retriever = torch.serialize(batch.batchframe_defaults.data)
table.insert(mem, collectgarbage("count"))

local ser_raw_retriever = torch.serialize(function(row)
  return ext_resource[row.row_id]
end)
table.insert(mem, collectgarbage("count"))

local serialized_batch = torch.serialize(batch)
table.insert(mem, collectgarbage("count"))

for i=2,#mem do
  print(i-1, (mem[i] - mem[i-1])/1024)
end

Который первоначально произвел вывод:

1   0.0082607269287109  
2   358.23344707489 
3   0.0017471313476562  
4   358.90182781219 

и после исправления:

1   0.0094480514526367  
2   0.00080204010009766 
3   0.00090408325195312 
4   0.010146141052246

Я попытался использовать setfenv для функции, но это не решило проблему. При отправке сериализованных данных в поток по-прежнему снижается производительность, но основная проблема решена, и без дорогостоящего средства извлечения данных функция становится значительно меньше.

person Max Gordon    schedule 17.07.2016