В продолжении scala, как написать цикл в форме CPS?

Я пытаюсь реализовать пример по адресу:

https://portal.klewel.com/watch/webcast/scala-days-2019/talk/37/

используя продолжение scala:

object ReverseGrad_CPSImproved {

  import scala.util.continuations._

  case class Num(
      x: Double,
      var d: Double = 0.0
  ) {

    def +(that: Num) = shift { (cont: Num => Unit) =>
      val y = Num(x + that.x)

      cont(y)

      this.d += y.d
      that.d += y.d
    }

    def *(that: Num) = shift { (cont: Num => Unit) =>
      val y = Num(x * that.x)

      cont(y)

      this.d += that.x * y.d
      that.d += this.x * y.d
    }
  }

  object Num {

    implicit def fromX(x: Double): Num = Num(x)
  }

  def grad(f: Num => Num @cps[Unit])(x: Double): Double = {

    val _x = Num(x)
    reset { f(_x).d = 1.0 }

    _x.d
  }
}

Это работает, пока я использую простое выражение:

  it("simple") {

    val fn = { x: Num =>
      val result = (x + 3) * (x + 4)

      result
    }

    val gg = grad(fn)(3)

    println(gg)
  }

Но как только я начал использовать цикл, все развалилось:


  it("benchmark") {

    import scala.util.continuations._

    for (i <- 1 to 20) {

      val n = Math.pow(2, i).toInt

      val fn = { x: Num =>
        var result = x + 1

        for (j <- 2 to n) {
          result = result * (x + j)
        }

        result
      }

      val nanoFrom = System.nanoTime()
      val gg = grad(fn)(3)
      val nanoTo = System.nanoTime()

      println(s"diff = $gg,\t time = ${nanoTo - nanoFrom}")
    }
  }


[Error] /home/peng/git-spike/scalaspike/meta/src/test/scala/com/tribbloids/spike/meta/multistage/lms/ReverseGrad_CPSImproved.scala:78: found cps expression in non-cps position
one error found

У меня сложилось впечатление, что библиотека продолжения должна иметь собственную реализацию цикла, которую можно переписать в рекурсию, но я не могу найти ее нигде в последней версии (scala 2.12). Какой самый простой способ использовать цикл в этом случае?


person tribbloid    schedule 20.10.2020    source источник
comment
Почему вы используете CPS, если вы все равно мутируете состояние (var d: Double = 0.0, this.d += that.x * y.d)?   -  person Mateusz Kubuszok    schedule 20.10.2020


Ответы (1)


В CPS вам нужно переписать свой код, чтобы вы НЕ выполняли вложенный/итеративный/рекурсивный вызов в том же контексте, а вместо этого выполняли только один шаг вычисления и передавали частичный результат вперед.

Например. если вы хотите вычислить произведение чисел от A до B, вы можете реализовать это следующим образом:

import scala.util.continuations._

case class Num(toDouble: Double) {

  def get = shift { cont: (Num => Num) =>
    cont(this)
  } 

  def +(num: Num) = reset {
    val a  = num.get
    Num(toDouble + a.toDouble)
  }

  def *(num: Num) = reset {
    val a  = num.get
    Num(toDouble * a.toDouble)
  }
}

// type annotation required because of recursive call
def product(from: Int, to: Int): Num @cps[Num] = reset { 
  if (from > to) Num(1.toDouble)
  else Num(from.toDouble) * product(from + 1, to)
}

def run: Num = reset {
  product(2, 10)
}

println(run)

(см. этот scastie).

Наиболее интересен этот фрагмент:

reset {
  if (from > to) Num(1.toDouble)
  else Num(from.toDouble) * product(from + 1, to)
}

Здесь компилятор (плагин) переписывает это примерно так:

input: (Num => Num) => {
  if (from > to) Num(1.toDouble)
  else {
    Num(from.toDouble) * product(from + 1, to) // this is virtually (Num => Num) => Num function!
  } (input)
}

Компилятор может это сделать, потому что:

  • it observes the content of shift and reset calls
    • both create something that takes some parameter A and returns intermediate result B (usable in e.g. inside this or another reset) and final result C (what you get when you run the final result of the composition) (denoted as A @ cpsParam[B, C] - if B =:= C you can use a type alias A @ cps[A])
    • reset помогает не сойти с ума с передачей параметров, поскольку он обрабатывает параметр (A в A @ cpsParam[B, C]) и передает его всем вложенным вызовам CPS и получает промежуточный результат (таким образом, B в A @ cpsParam[B, C]) и заставляет весь блок возвращать окончательный результат - C A @ cpsParam[B, C]
    • shift поднимает функцию (A => B) => C в A @ cpsParam[B, C]
  • когда он видит, что возвращаемый тип Input @cpsParam[Output1, Output2], он знает, что должен переписать код, чтобы ввести параметр и передать его туда

На практике это немного сложнее, но это в основном все.

Тем временем вы делаете свое

        for (j <- 2 to n) {
          result = result * (x + j)
        }

вне этого контекста, где компилятор не выполняет никаких преобразований. Вы должны, по крайней мере, составить все эти операции CPS в пределах reset. (Кроме того, вы запускаете вещи в цикле и мутации, которые также можно делегировать CPS).

Тем не менее, CPS (как в: эта конкретная реализация) мертв. Он был удален в Scala 2.13, его никто не поддерживает, и с помощью какой-нибудь монады на основе батута (например, Cont из Cats) гораздо проще понять, поэтому единственное место, где я все еще вижу это, это устаревшие курсы или статьи об исторических мелочах.

person Mateusz Kubuszok    schedule 20.10.2020