Предложения по оптимизации простого Scala foldLeft по нескольким значениям?

Я повторно реализую некоторый код (простой байесовский алгоритм вывода, но это не очень важно) с Java на Scala. Я хотел бы реализовать его максимально производительным способом, сохраняя при этом код чистым и функциональным, максимально избегая изменчивости.

Вот фрагмент кода Java:

    // initialize
    double lP  = Math.log(prior);
    double lPC = Math.log(1-prior);

    // accumulate probabilities from each annotation object into lP and lPC
    for (Annotation annotation : annotations) {
        float prob = annotation.getProbability();
        if (isValidProbability(prob)) {
            lP  += logProb(prob);
            lPC += logProb(1 - prob);
        }
    } 

Довольно просто, правда? Поэтому я решил использовать методы Scala foldLeft и map для моей первой попытки. Поскольку у меня есть два значения, которые я накапливаю, аккумулятор представляет собой кортеж:

    val initial  = (math.log(prior), math.log(1-prior))
    val probs    = annotations map (_.getProbability)
    val (lP,lPC) = probs.foldLeft(initial) ((r,p) => {
      if(isValidProbability(p)) (r._1 + logProb(p), r._2 + logProb(1-p)) else r
    })

К сожалению, этот код работает примерно в 5 раз медленнее, чем Java (с использованием простой и неточной метрики; просто вызвал код 10000 раз в цикле). Один недостаток довольно очевиден; мы проходим списки дважды, один раз в вызове map, а другой - в foldLeft. Итак, вот версия, которая проходит по списку один раз.

    val (lP,lPC) = annotations.foldLeft(initial) ((r,annotation) => {
      val  p = annotation.getProbability
      if(isValidProbability(p)) (r._1 + logProb(p), r._2 + logProb(1-p)) else r
    })

Это лучше! Он работает примерно в 3 раза хуже, чем код Java. Моя следующая догадка заключалась в том, что создание всех новых кортежей на каждом этапе свертки, вероятно, требует определенных затрат. Поэтому я решил попробовать версию, которая дважды просматривает список, но без создания кортежей.

    val lP = annotations.foldLeft(math.log(prior)) ((r,annotation) => {
       val  p = annotation.getProbability
       if(isValidProbability(p)) r + logProb(p) else r
    })
    val lPC = annotations.foldLeft(math.log(1-prior)) ((r,annotation) => {
      val  p = annotation.getProbability
      if(isValidProbability(p)) r + logProb(1-p) else r
    })

Это работает примерно так же, как и предыдущая версия (в 3 раза медленнее, чем версия для Java). На самом деле это не удивительно, но я был полон надежд.

Итак, мой вопрос: есть ли более быстрый способ реализовать этот фрагмент Java в Scala, сохранив при этом чистый код Scala, избегая ненужной изменчивости и следуя идиомам Scala? Я ожидаю, что в конечном итоге я буду использовать этот код в параллельной среде, поэтому ценность сохранения неизменности может перевесить более низкую производительность в одном потоке.


person Raj B    schedule 02.02.2012    source источник
comment
У вас есть ленивые структуры данных в scala? В таком случае вы сможете избежать многократных проходов.   -  person Marcin    schedule 02.02.2012


Ответы (5)


Во-первых, часть вашего штрафа может быть связана с типом используемой вами коллекции. Но по большей части это, вероятно, создание объекта, которого вы на самом деле не избежите, запустив цикл дважды, поскольку числа должны быть заключены в рамку.

Вместо этого вы можете создать изменяемый класс, который накапливает для вас значения:

class LogOdds(var lp: Double = 0, var lpc: Double = 0) {
  def *=(p: Double) = {
    if (isValidProbability(p)) {
      lp += logProb(p)
      lpc += logProb(1-p)
    }
    this  // Pass self on so we can fold over the operation
  }
  def toTuple = (lp, lpc)
}

Теперь, хотя вы можете использовать это небезопасно, в этом нет необходимости. Фактически, вы можете просто сложить его.

annotations.foldLeft(new LogOdds()) { (r,ann) => r *= ann.getProbability } toTuple

Если вы используете этот паттерн, вся изменчивая небезопасность скрывается внутри складки; он никогда не ускользнет.

Теперь вы не можете выполнить параллельное сгибание, но вы можете выполнить агрегирование, которое похоже на складывание с дополнительной операцией по объединению частей. Итак, вы добавляете метод

def **(lo: LogOdds) = new LogOdds(lp + lo.lp, lpc + lo.lpc)

в LogOdds, а затем

annotations.aggregate(new LogOdds())(
  (r,ann) => r *= ann.getProbability,
  (l,r) => l**r
).toTuple

и тебе будет хорошо идти.

(Не стесняйтесь использовать для этого нематематические символы, но поскольку вы в основном умножаете вероятности, символ умножения, похоже, скорее дает интуитивное представление о том, что происходит, чем включает в себя Вероятность или что-то подобное.)

person Rex Kerr    schedule 02.02.2012
comment
Почему он не может сделать параллельную складку? Он просто добавляет значения, что является коммутативным и ассоциативным. - person Daniel C. Sobral; 02.02.2012
comment
@ DanielC.Sobral - Потому что ему нужно сделать foldLeft ((U,T)=>U), а не просто сбросить ((U,U)=>U), и foldLeft нельзя разумно накапливать параллельно. Вот почему aggregate существует. - person Rex Kerr; 02.02.2012
comment
@Rex - Я тоже не понимаю. Если вы сначала выполните фильтр по валидности (и проигнорируете инициализацию lp и lpc, что является простым дополнением), это выглядит ассоциативным для меня. Вы можете произвольно распараллелить Foldable[A : Monoid].sum - person oxbow_lakes; 02.02.2012
comment
@oxbow_lakes - Вы можете map, затем filter, затем fold, или можете aggregate. Один шаг обычно быстрее трех. aggregate - это тоже параллельная операция. - person Rex Kerr; 02.02.2012
comment
Ах, хорошо, я забыл про шаг map. - person Daniel C. Sobral; 02.02.2012

Вы можете реализовать хвостовой рекурсивный метод, который будет преобразован компилятором в цикл while, следовательно, должен быть таким же быстрым, как версия Java. Или вы можете просто использовать цикл - против него нет закона, если он просто использует локальные переменные в методе (см., Например, обширное использование в исходном коде коллекций Scala).

def calc(lst: List[Annotation], lP: Double = 0, lPC: Double = 0): (Double, Double) = {
  if (lst.isEmpty) (lP, lPC)
  else {
    val prob = lst.head.getProbability
    if (isValidProbability(prob)) 
      calc(lst.tail, lP + logProb(prob), lPC + logProb(1 - prob))
    else 
      calc(lst.tail, lP, lPC)
  }
}

Преимущество сворачивания заключается в том, что его можно распараллеливать, что может привести к тому, что он будет быстрее, чем версия Java на многоядерной машине (см. Другие ответы).

person Luigi Plinge    schedule 02.02.2012
comment
List не может эффективно распараллеливаться - person oxbow_lakes; 02.02.2012
comment
@oxbow, что имеет смысл; при распараллеливании лучше убедиться, что вы используете класс с быстрым произвольным доступом, например Vector. - person Luigi Plinge; 02.02.2012

В качестве побочного примечания: вы можете избежать двойного идиоматического обхода списка, используя _ 1_:

val probs = annotations.view.map(_.getProbability).filter(isValidProbability)

val (lP, lPC) = ((logProb(prior), logProb(1 - prior)) /: probs) {
   case ((pa, ca), p) => (pa + logProb(p), ca + logProb(1 - p))
}

Это, вероятно, не даст вам большей производительности, чем ваша третья версия, но мне она кажется более элегантной.

person Travis Brown    schedule 02.02.2012
comment
Спасибо за предложение regd view (), особенно с примером. - person Raj B; 03.02.2012
comment
Вероятно, есть некоторые накладные расходы на создание ленивой структуры данных, необходимой для представления. Мои простые сравнительные эксперименты показывают, что это связано с расходами. Но я полностью согласен с аспектом элегантности :) - person Raj B; 03.02.2012

Во-первых, давайте рассмотрим проблему производительности: нет способа реализовать его так же быстро, как Java, кроме как с помощью циклов while. По сути, JVM не может оптимизировать цикл Scala в той степени, в которой он оптимизирует цикл Java. Причины этого вызывают беспокойство даже среди людей, занимающихся JVM, потому что они также мешают параллельной работе с библиотеками.

Теперь, возвращаясь к производительности Scala, вы также можете использовать .view, чтобы избежать создания новой коллекции на этапе map, но я думаю, что этап map всегда приведет к снижению производительности. Дело в том, что вы конвертируете коллекцию в одну, параметризованную на Double, которую нужно упаковывать и распаковывать.

Однако есть один возможный способ его оптимизации: сделать его параллельным. Если вы вызываете .par на annotations, чтобы сделать его параллельной коллекцией, вы можете затем использовать fold:

val parAnnot = annotations.par
val lP = parAnnot.map(_.getProbability).fold(math.log(prior)) ((r,p) => {
   if(isValidProbability(p)) r + logProb(p) else r
})
val lPC = parAnnot.map(_.getProbability).fold(math.log(1-prior)) ((r,p) => {
  if(isValidProbability(p)) r + logProb(1-p) else r
})

Чтобы избежать отдельного map шага, используйте aggregate вместо fold, как предлагает Рекс.

В качестве бонусных баллов вы можете использовать Future, чтобы оба вычисления выполнялись параллельно. Я подозреваю, что вы получите лучшую производительность, если вернете кортежи и запустите их за один раз. Вам нужно будет протестировать этот материал, чтобы увидеть, что работает лучше.

В параллельных коллекциях это может сначала filter окупиться за действительные аннотации. Или, может быть, collect.

val parAnnot = annottions.par.view map (_.getProbability) filter (isValidProbability(_)) force;

or

val parAnnot = annotations.par collect { case annot if isValidProbability(annot.getProbability) => annot.getProbability }

Во всяком случае, эталон.

person Daniel C. Sobral    schedule 02.02.2012

В настоящее время невозможно взаимодействовать с библиотекой коллекций scala без бокса. Итак, что такое примитивные double в Java, будет постоянно упаковываться и распаковываться в операции fold, даже если вы не упаковываете их в Tuple2 (который является специализированным, но, конечно, вы уже платите накладные расходы на создание новых объектов каждый раз).

person oxbow_lakes    schedule 02.02.2012
comment
Действительно раздражает, что самый низкий уровень (тот, у которого больше всего итераций) всегда должен работать с массивами примитивных типов, если вам нужна производительность. Возможны ли какие-то абстракции? - person ziggystar; 02.02.2012