Memoizers in Scala


Today I will talk about anonymous functions and memoizers.

Here is a simple, self-sufficient anonymous function (I think it is the simplest possible):
val fun = (a:Int) => a + 1
What's wrong with it? Nothing, except that it recomputes the value every time it is called. When all it does is add 1 to its argument, that is not a problem, but what if we were doing something expensive?

Wouldn't it be cool to be able to do something like the following?
val fun = mem( (a:Int) => a + 1 )
And Scala permits us to do exactly this:
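def mem[A,B](f:A=>B) = new Function[A,B] {
  import scala.collection.mutable.Map
  // results computed so far, keyed by argument
  private val cache:Map[A,B] = Map()
  // f(v) is evaluated only if v is not in the cache yet
  def apply(v:A):B = cache.getOrElseUpdate(v, f(v))
}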

val fun = mem( (a:Int) => a + 1 )
Now we have a function that stores all computed results in a Map and does not recompute them. There are several minor problems with this approach - for example, it is not thread-safe: the mutable Map is not synchronized, so sharing this function between threads can cause values to be recomputed occasionally, or worse, leave the map in an inconsistent state.
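One possible way to address the thread-safety concern is to back the cache with a concurrent map. The sketch below is only an illustration, not part of the original helper: the name "memSafe" and the call counter are made up for this example, and it assumes scala.collection.concurrent.TrieMap from the standard library. Two threads may still race to compute the same key, but the map itself stays consistent.

```scala
import scala.collection.concurrent.TrieMap
import java.util.concurrent.atomic.AtomicInteger

// Hypothetical thread-safe variant of mem: same shape, but backed by a TrieMap.
def memSafe[A, B](f: A => B): A => B = new Function[A, B] {
  private val cache = TrieMap.empty[A, B]
  // Concurrent callers may still both evaluate f(v) for the same v,
  // but the cache cannot be corrupted by simultaneous updates.
  def apply(v: A): B = cache.getOrElseUpdate(v, f(v))
}

// Count how often the wrapped function body actually runs.
val calls = new AtomicInteger(0)
val slowInc = memSafe { (a: Int) =>
  calls.incrementAndGet()
  a + 1
}

val first = slowInc(41)  // computes the result and stores it
val second = slowInc(41) // served from the cache, body does not run again
```

On a single thread the body runs once for the two calls above, which is the same caching behaviour as the plain "mem".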

Another problem is multiple-argument functions - but it can be easily rectified with .tupled:
val add = (a:Int, b:Int) => a + b
val addMem = mem(add) // won't compile
val addMem2 = mem(add.tupled) // compiles!
So, this works nicely for simple use cases.
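If the tuple-shaped call site is inconvenient, the two-argument shape can be restored with Function.untupled from the standard library. A minimal sketch, with "mem" repeated so the snippet is self-contained; the names "addMemTupled" and "addMem" are only illustrative:

```scala
import scala.collection.mutable

def mem[A, B](f: A => B): A => B = new Function[A, B] {
  private val cache = mutable.Map.empty[A, B]
  def apply(v: A): B = cache.getOrElseUpdate(v, f(v))
}

val add = (a: Int, b: Int) => a + b

// The memoized function takes a single (Int, Int) tuple, which is also the cache key...
val addMemTupled: ((Int, Int)) => Int = mem(add.tupled)
// ...and Function.untupled gives back the two-argument call shape.
val addMem: (Int, Int) => Int = Function.untupled(addMemTupled)

val viaTuple = addMemTupled((2, 3))
val viaArgs = addMem(2, 3) // same cache entry as the call above
```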

But when we add a bit more complexity, we find a much bigger problem lurking here - if you look closer, you will see that you cannot make a recursive function this way: referring to "fun" from inside the function passed to "mem()" results in an "illegal forward reference" complaint from the compiler.
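Wrapping an already-recursive named function does not help either, because its inner calls never go through the cache. The sketch below illustrates this with a call counter; "slowFib", "fibMem" and "calls" are hypothetical names used only for this example.

```scala
import scala.collection.mutable

def mem[A, B](f: A => B): A => B = new Function[A, B] {
  private val cache = mutable.Map.empty[A, B]
  def apply(v: A): B = cache.getOrElseUpdate(v, f(v))
}

var calls = 0
// A named recursive function: the inner calls go straight to slowFib,
// so they bypass whatever cache is wrapped around it.
def slowFib(n: Int): Int = {
  calls += 1
  if (n < 2) 1 else slowFib(n - 1) + slowFib(n - 2)
}

val fibMem = mem[Int, Int](slowFib)

val r1 = fibMem(10)
val callsAfterFirst = calls  // the whole call tree was evaluated
val r2 = fibMem(10)
val callsAfterSecond = calls // unchanged: the top-level result was cached
val r3 = fibMem(9)
val callsAfterThird = calls  // grew again: 9 was computed internally but never cached
```

Only the top-level argument ends up in the cache, so the exponential blow-up inside the recursion is still there.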

Are we out of luck here? Thankfully, no. But to explain how we can get a recursive memoized function, I'll need to venture from the "memoizers" topic to the "recursion" topic for a while.

There is a technique for creating anonymous recursive functions, called the "Y combinator" (or, more generally, a "fixed-point combinator"). It is most commonly used in languages that do not allow self-referential definitions, such as the lambda calculus.

Here is an example of this combinator in Scala (taken from this SO question):
def fix[A,B](f:(A=>B)=>(A=>B)):(A=>B) = f(fix(f))(_)

val factorial = fix[Int,Int](f => a => if (a < 2) 1 else f(a-1) * a)
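Basically, it is a simple trick to pass a function "into itself" - the "f" argument inside the lambda is actually the "a => ..." function being defined. Notice that the "fix(f)" part is evaluated lazily: the trailing "(_)" turns the body into a new function, so "fix(f)" runs only when that function is called, and the definition does not recurse forever. The "fix" function cannot be tail-call optimized, though, so very deep recursion can still overflow the stack.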
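To see what "fix" does, it can help to unroll it by hand. In the sketch below, "factStep" is the same lambda as above, only named, and "stop" is a hypothetical placeholder for the layer we never reach; both names are made up for this illustration.

```scala
def fix[A, B](f: (A => B) => (A => B)): A => B = f(fix(f))(_)

// The "open" form of factorial: recursion goes through the `self` parameter.
val factStep: (Int => Int) => (Int => Int) =
  self => a => if (a < 2) 1 else self(a - 1) * a

val factorial = fix(factStep)
val viaFix = factorial(5)

// Unrolling by hand: each layer receives the next layer as `self`.
// fix simply keeps supplying new layers on demand.
val stop: Int => Int = _ => sys.error("recursion went deeper than the unrolling")
val twoLayers = factStep(factStep(stop))
val viaUnrolling = twoLayers(2) // needs only two layers, so `stop` is never called
```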

Now, given that trick, we can make a memoizer for self-recursive functions:
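def memr[A,B](f:(A=>B)=>(A=>B)):(A=>B) = new Function[A,B] {
  import scala.collection.mutable.Map
  private val cache:Map[A,B] = Map()
  // every recursive call goes through result, and therefore through the cache
  private def result(v:A):B = cache.getOrElseUpdate(v, fix(f)(v))
  // same as the Y combinator above, but the function passed back into f is result
  private def fix(f:(A=>B)=>(A=>B)):(A=>B) = f(result)(_)
  def apply(v:A):B = result(v)
}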

val fun = memr[Int,Int](f => a => if (a < 2) 1 else f(a - 1) + f(a - 2))
Perfect.

Here, I simply routed the "fix(f)" part of the Y combinator through the "result" function, and applied memoization there.
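A quick way to check that the recursive calls really hit the cache is to count how often the function body runs. The snippet below is a sketch with a hypothetical counter ("bodyRuns"); "memr" is repeated in the same shape as above so that the example is self-contained.

```scala
import scala.collection.mutable

def memr[A, B](f: (A => B) => (A => B)): A => B = new Function[A, B] {
  private val cache = mutable.Map.empty[A, B]
  private def result(v: A): B = cache.getOrElseUpdate(v, fix(f)(v))
  private def fix(f: (A => B) => (A => B)): A => B = f(result)(_)
  def apply(v: A): B = result(v)
}

var bodyRuns = 0
val fib = memr[Int, Int] { self => n =>
  bodyRuns += 1 // counts evaluations of the body, not cache hits
  if (n < 2) 1 else self(n - 1) + self(n - 2)
}

val fib30 = fib(30)
val runsAfterFirst = bodyRuns  // one evaluation per distinct argument 0..30
val again = fib(30)
val runsAfterSecond = bodyRuns // unchanged, the second call is a pure cache hit
```

Without memoization the same call would evaluate the body millions of times; here each argument from 0 to 30 is evaluated exactly once.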

This approach covers most of the possible use cases. But of course, the problem of thread safety still remains :)

There is another way to improve our helper function. Really, since we care so much about our processor's well-being, how could we let all the work be done by only one of its cores? We surely need to use them all.

Sadly, there seems to be no simple way to turn a plain sequential function into a parallel one, because the control flow inside the function itself is sequential. So we will need to tweak the insides of the function itself.

I used Akka and its futures for this, but I think other similar libraries can work out as well.

Here's our refined function:
// Assumption: the futures come from scala.concurrent (Scala 2.10+);
// adjust the imports if you use a different futures library.
import java.util.concurrent.Executors
import scala.concurrent.{ExecutionContext, Future}

def parmemr[A,B](f:((A=>Future[B]),ExecutionContext)=>(A=>Future[B])):(A=>Future[B]) = 
  new Function[A,Future[B]] {
    // execution context for our Futures, adjust the pool size to your core count
    implicit val ec = ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(4)) 
    private val cache = scala.collection.mutable.Map[A,Future[B]]()
    private def result(v:A):Future[B] = cache.getOrElseUpdate(v, fix(f)(v))
    private def fix(f:((A=>Future[B]),ExecutionContext)=>(A=>Future[B])):(A=>Future[B]) = f(result,ec)(_)
    def apply(v:A):Future[B] = result(v) 
  }
And its usage:
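// Assumption: Await and the duration syntax come from scala.concurrent.
import scala.concurrent.Await
import scala.concurrent.duration._

val parfib = parmemr[Int,Int]{(f,ec) => a =>
  implicit val e = ec // need to provide Execution Context for our Futures
  if (a < 2) Future(1)
  else {
    val f1 = f(a - 1)
    val f2 = f(a - 2)
    for {
      x <- f1
      y <- f2
    } yield { x + y }
  }
}
Await.result(parfib(N), 1.minute) // N is the argument you want to compute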
Et voilà, you now have a recursive, memoized, parallelized, anonymous function :) Hope this helps you in some way.
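For anyone who wants to try the idea without setting up a thread pool, here is a self-contained sketch under a few assumptions that differ from the version above: it uses scala.concurrent with ExecutionContext.global instead of a dedicated pool (threads created by Executors.newFixedThreadPool are non-daemon and can keep the JVM alive until the pool is shut down), it keeps the cache in a TrieMap so that concurrent callers cannot corrupt it, and it uses BigInt to avoid overflow. The name "parmemrSafe" is hypothetical.

```scala
import scala.collection.concurrent.TrieMap
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

// Hypothetical variant of parmemr: the cache is a concurrent TrieMap and the
// ExecutionContext is left to the caller instead of being created inside.
def parmemrSafe[A, B](f: (A => Future[B]) => (A => Future[B])): A => Future[B] =
  new Function[A, Future[B]] {
    private val cache = TrieMap.empty[A, Future[B]]
    // Recursive calls made through `this` re-enter apply and hit the cache.
    def apply(v: A): Future[B] = cache.getOrElseUpdate(v, f(this)(v))
  }

// Assumption: the global execution context is good enough for this demo.
implicit val ec: ExecutionContext = ExecutionContext.global

val parfibSafe = parmemrSafe[Int, BigInt] { self => n =>
  if (n < 2) Future.successful(BigInt(1))
  else {
    val f1 = self(n - 1)
    val f2 = self(n - 2)
    for {
      x <- f1
      y <- f2
    } yield x + y
  }
}

val fib30 = Await.result(parfibSafe(30), 1.minute)
```

What is cached here is the Future itself, so a second caller asking for the same argument gets the same (possibly still running) computation rather than starting a new one.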

P.S. The whole buildable example project is located on GitHub.

Comments

  1. I don't think using threadpool/Future can solve the thread safety issue here. See my discussion at SO: http://stackoverflow.com/a/20462893/2073130
