Ref + Deferred = Set Once Concurrent Cache

18 minute read Published: 2022-02-06

This blog post is about building a concurrent cache. That by itself is uninteresting; there are many ways to do it, including a straight-up concurrent hashmap in Java. The interesting part is the requirements I had: the expensive function that computes a value must run at most once per key, and concurrent callers for the same key should wait for, and share, that single result.

It's enough to do this within a single running service, so I only need to coordinate within that one system. If I had to scale this out and coordinate between systems, I'd drag in something like Zookeeper. But within one process it's easy enough to build the above.

Last year I had this problem to solve in production and managed to crank out a solution in cats-effect. The part that really surprised me is how easy it was to put something together, despite the fact that I was dealing with both (a) writing my own data structure (sort of) and (b) messing with concurrency. I managed to figure it out from the docs.

I was very surprised at how easy this was; I think that speaks to the great libraries available in Scala. We're going to use cats-effect, but there are other solutions that are just as easy using ZIO, ZIO-STM, Cats-STM, or even Akka actors. Perhaps I'll blog about those another day.

I need this to coordinate a bunch of independent external processes. The basic design: keep a per-key state machine in a shared, concurrent map; the first caller for a key wins the right to run the expensive setter, and everyone else waits on the winner's result.

Ref

From the docs, a Ref is an asynchronous, concurrent mutable reference. You use it to share mutable state between threads/fibers. We are going to store our shared mapping of Id -> Thing inside of a Ref.

abstract class Ref[F[_], A] {
  def get: F[A]
  def set(a: A): F[Unit]
  def modify[B](f: A => (A, B)): F[B]
  // ... and more
}

// Specialized on IO 

abstract class Ref[A] {
  def get: IO[A]
  def set(a: A): IO[Unit]
  def modify[B](f: A => (A, B)): IO[B]
  // ... and more
}

There is a compare-and-set under the hood when we try to modify the shared mapping. Because an attempt to set the value may fail and be retried, the function modifying the Ref may be called multiple times. Thus, whatever is in the Ref should be immutable. If the inside is an immutable List, who cares how many times you prepend to it: each retry just recomputes the result. But if it's a mutable list, you may prepend your element many times. Cats-effect ships with some additional immutable data structures (like queues) for this reason.
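As a minimal sketch (this push helper is illustrative, not from the post), the function you hand to update or modify must be pure so that a CAS retry can safely re-run it:

import cats.effect.{IO, Ref}

// Safe: prepending to an immutable List is pure, so a CAS retry
// simply recomputes the new list from the freshly read old value.
def push(ref: Ref[IO, List[Int]], i: Int): IO[Unit] =
  ref.update(i :: _)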

The textbook example for Ref is a shared counter between threads. We'll do a variation below: a shared sum computed across multiple threads. First, a version that doesn't work because there is no coordination between the threads. We are trying to sum the numbers 1 to 1000 across 10 different worker threads:

import cats.effect.{IO, IOApp}
import cats.syntax.all._

object VarFail extends IOApp.Simple {

  var counter: Int = 0

  override def run: IO[Unit] = {

    val cases = Range(1, 1000).inclusive.toList
    val expectedSum = cases.sum

    for {
      _ <- cases.parTraverseN(10)(
        i =>
          IO {
            counter = counter + i // not atomic: read-modify-write race
          }
      )
      _ <- IO.println(s"expected $expectedSum got $counter MATCHES ---> ${expectedSum === counter}")
    } yield ()
  }
}

Running this reveals that the counter sometimes isn't updated properly. This is no surprise: the read-modify-write isn't atomic, and normally you would put a mutex around the variable or reach for some Java synchronization primitive.
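For comparison, here's a minimal sketch of that traditional fix using a plain JDK atomic (shown for contrast; it's not the approach this post takes):

import java.util.concurrent.atomic.AtomicInteger

// AtomicInteger turns the read-modify-write into one atomic step,
// so concurrent adds are never lost.
val counter = new AtomicInteger(0)

def add(i: Int): Unit = {
  counter.addAndGet(i)
  ()
}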

If we build this with a Ref[IO, Int] instead:

import cats.effect.{IO, IOApp, Ref}
import cats.syntax.all._

object RefExample extends IOApp.Simple {

  val counter: IO[Ref[IO, Int]] = Ref.of[IO, Int](0)

  def increment(i: Int, ref: Ref[IO, Int]): IO[Unit] =
    ref.update(state => state + i)

  override def run: IO[Unit] = {
    val cases = Range(1, 1000).inclusive.toList
    val expectedSum = cases.sum
    for {
      ref <- counter
      _ <- cases.parTraverseN(10)(i => increment(i, ref))
      state <- ref.get
      _ <- IO.println(s"expected $expectedSum got $state MATCHES ---> ${expectedSum === state}")
    } yield ()
  }
}

We get the answer we expect.

So cool. We have a way to safely share state between our threads. But we still have the problem that we only want something to run once. Once again, Ref comes to the rescue. We can use Option to get some surprisingly useful behavior. In the next example, we'll use Ref[IO, Option[Int]]. If the value is None, we can lock in a value. If the value is already Some, whoever gets it is too late. Said another way, we can get a handle on which thread "wins" the setting and which threads lose it. We also use modify rather than update. Modify wants a function of A => (A, B): this lets us set a value and, critically, lets us emit some value downstream as part of setting it. In the following example, if we are the "winning" thread we'll print out that we won (our downstream IO). Otherwise, we print out that we lost because a value was already set. Note the use of flatten: it matters because we are emitting an IO downstream, and since modify itself returns an IO, we would otherwise end up with an IO[IO[...]].

import cats.effect.{IO, IOApp, Ref}
import cats.syntax.all._

object RefWinExample extends IOApp.Simple {

  val refOption: IO[Ref[IO, Option[Int]]] = Ref.of[IO, Option[Int]](None)

  def doStuff(i: Int, ref: Ref[IO, Option[Int]]): IO[Unit] =
    ref.modify {
      case None => Some(i) -> IO.println(s"****** i: $i -> I won the write! Setting saved value to $i")
      case Some(alreadySet) =>
        Some(alreadySet) -> IO.println(s"i: $i -> Lost, there was an existing value set to $alreadySet")
    }.flatten

  override def run: IO[Unit] = {
    val cases = Range(1, 10).inclusive.toList
    for {
      ref <- refOption
      _ <- cases.parTraverseN(10)(i => doStuff(i, ref))
    } yield ()
  }
}

Running this multiple times will result in a different "thread" winning each run. We don't know which one in advance, but we do know it will be exactly one of them (the one that performs the None -> Some transition).

That is as far as we can go with Ref: it has no functionality for synchronization, so a losing thread has no way to wait for the winner's value. To solve the rest, we need something better.

Enter...

Deferred

Deferred is a purely functional synchronization primitive which represents a single value which may not yet be available. It is created empty and fulfilled later.

abstract class Deferred[F[_], A] {
  def get: F[A]
  def complete(a: A): F[Boolean]
}

// specialized on IO 

abstract class Deferred[A] {
  def get: IO[A]
  def complete(a: A): IO[Boolean]
}

The interesting parts for us: get semantically blocks the calling fiber until a value has been set, and complete sets the value exactly once, returning false if someone else already completed it.
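A minimal standalone sketch of that handshake (the names and the 100 ms sleep are just for illustration):

import cats.effect.{Deferred, IO, IOApp}
import cats.syntax.all._
import scala.concurrent.duration._

object DeferredExample extends IOApp.Simple {
  override def run: IO[Unit] =
    Deferred[IO, Int].flatMap { d =>
      val setter = IO.sleep(100.millis) *> d.complete(42).void
      val waiter = d.get.flatMap(i => IO.println(s"got $i")) // blocks the fiber, not a thread
      (setter, waiter).parTupled.void
    }
}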

This combo of Ref and Deferred is everywhere when you dig under the hood in Cats-effect. These two little primitives give you a surprising amount of power. See the implementations of Semaphore and Queue for instance.

So how is this useful for our cache case?

We will need one more piece, and that's where finite state machines come in.

This is how we get around the "I only want my very expensive, side-effecting setter function to run once" requirement. We store states in our cache, and only run the effect after a state transition has completed.

That's still pretty abstract, so a concrete example will help.

We have two things to define: states, and the transitions between them, or, as I call them in the example, Operations.

First, let's define our cache:

trait SetOnceCache[K, V] {
  def get(k: K): IO[V]
}

In the actual implementation we will store everything in a Map. We'll use MapRef, a Ref-style interface for maps: mapref(k) hands you a Ref for that key's slot. You could also use a bare immutable Scala Map inside a single Ref, but in production I wanted the performance of the former, since I have many, many keys. The key type is easy. But what is the value of our map? It's our states! Well, our optional state, because we need the None -> Some trick to determine the winning thread.
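Here's a minimal standalone sketch of MapRef (the String/Int types and the "some-key" literal are just for illustration):

import cats.effect.IO
import cats.effect.std.MapRef

// MapRef is essentially K => Ref[F, Option[V]]: each key gets its own
// Ref-shaped view into the underlying concurrent map.
val program: IO[Option[Int]] =
  MapRef.ofScalaConcurrentTrieMap[IO, String, Int].flatMap { mapref =>
    val slot = mapref("some-key") // Ref[IO, Option[Int]]
    slot.set(Some(1)) *> slot.get
  }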

That leaves two states: Running (someone is computing the value) and Done (the value is ready). We don't need a state for "value doesn't exist" because that's handled by the None case and the Map itself.

sealed trait State[V] 
case class Done[V](value: IO[V]) 
  extends State[V]
case class Running[V](getWait: Deferred[IO, Outcome[IO, Throwable, V]])
  extends State[V]

If we're done, here's the value. Otherwise our state is Running and we will be waiting for some deferred to complete.

Now we have a bunch of Operations. These are our transitions between the states.

sealed trait Operation[V]

// if you get this transition
// you are the thread responsible for completing the deferred 
case class Set[V](set: Deferred[IO, Outcome[IO, Throwable, V]])
  extends Operation[V]

// if you get this transition
// someone else is responsible for completing the deferred
// and you can wait on it if you want to get the value eventually
case class Wait[V](getWait: Deferred[IO, Outcome[IO, Throwable, V]])
  extends Operation[V]


// if you get this transition
// there is already a value for you 
case class Completed[V](value: IO[V]) extends Operation[V]

Cool. The rest is just pattern matching. Lots and lots of pattern matching. Given the above, we'll implement this in about twenty lines of Scala.

// To make a SetOnceCache we need a function that computes the value from the key when it doesn't exist yet.
// This would be my expensive call (hit the object store, hit the DB, emit metrics, etc.) that I pass in.
// We leave it as a simple f because that also simplifies testing: the cache doesn't need to know any of
// that, it just needs to run f only once per key.

object SetOnceCache {
  def make[K, V](f: K => IO[V]): IO[SetOnceCache[K, V]] =
    MapRef
      .ofScalaConcurrentTrieMap[IO, K, State[V]]
      .map(
        mapref =>
          new SetOnceCache[K, V] {
            def get(k: K): IO[V] = Deferred[IO, Outcome[IO, Throwable, V]].flatMap {
              wait =>
                mapref(k)
                  .modify[Operation[V]] {
                    case None => (Running(wait).some, Set(wait))
                    case s@Some(Running(wait)) => (s, Wait(wait))
                    case s@Some(Done(v)) => (s, Completed(v))
                  }
                  // ... continued below

This is mostly just annotation. Every caller first creates a fresh Deferred, then modify on that key's entry decides which transition the caller gets: if the slot is None, we install Running with our Deferred and receive the Set transition (we won); if someone else is already Running, we receive Wait with their Deferred; and if the state is Done, we receive Completed with the value.

Continuing on, we flatMap this, and now the elements we're dealing with are the transitions. More pattern matching:

                  .flatMap {
                    case Set(_) =>
                      f(k).guaranteeCase {
                        case s@Outcome.Succeeded(fa) =>
                          mapref(k).modify {
                            case Some(Running(wait)) => (Done(fa).some, wait.complete(s).void)
                            case s@Some(Done(_)) => (s, IO.unit) // technically unreachable but type checker :(
                            case None => (None, IO.unit) // technically unreachable but type checker :(
                          }.flatten // <- don't forget these flattens

One and only one of the callers will end up with the Set transition: that caller runs f(k), and guaranteeCase ensures that whatever the outcome, we update the state and complete the Deferred so the waiters are released. On success we store Done and complete the Deferred with the successful outcome.

Continuing on, we need to deal with what to do if f(k) didn't succeed:

                        case s =>
                          mapref(k).modify {
                            case Some(Running(wait)) =>
                              (None, wait.complete(s).void) // clears the cached value but returns a failure
                            case s@Some(Done(_)) => (s, IO.unit)
                            case None => (None, IO.unit)
                          }.flatten
                      }

If the setter fails, just die: someone upstream will have to retry this key, or there was some reason we couldn't cache it. Note that we also reset the entry to None, so a later call can attempt the set again.

Continuing on, we need to deal with the other transitions. What if we are a thread that doesn't have to Set the value but is instead waiting on it?

                    case Wait(wait) => 
                      wait.get.flatMap { // `.get` "soft" blocks on IO but not on an actual thread
                        case Outcome.Succeeded(fa) => fa
                        case Outcome.Errored(e) => IO.raiseError(e)
                        case Outcome.Canceled() => IO.raiseError(new Throwable("Someone cancelled the setter"))
                      }

If we got the Wait transition, we just wait on the Deferred and embed its outcome: the value on success, the re-raised error on failure, or an error if the setter was cancelled.

Finally our last transition, which is the value already existing in our state:

                    case Completed(v) => v

In this case we just return the completed value.

Putting it all together:

import cats.effect.{Deferred, IO, Outcome}
import cats.effect.std.MapRef
import cats.syntax.all._

trait SetOnceCache[K, V] {
  def get(k: K): IO[V]
}

object SetOnceCache {

  private sealed trait State[V]

  private case class Running[V](getWait: Deferred[IO, Outcome[IO, Throwable, V]]) extends State[V]

  private case class Done[V](value: IO[V]) extends State[V]

  private sealed trait Operation[V] extends Product with Serializable

  private case class Set[V](set: Deferred[IO, Outcome[IO, Throwable, V]]) extends Operation[V]

  private case class Wait[V](getWait: Deferred[IO, Outcome[IO, Throwable, V]]) extends Operation[V]

  private case class Completed[V](value: IO[V]) extends Operation[V]

  def make[K, V](f: K => IO[V]): IO[SetOnceCache[K, V]] =
    MapRef
      .ofScalaConcurrentTrieMap[IO, K, State[V]]
      .map(
        mapref =>
          new SetOnceCache[K, V] {
            def get(k: K): IO[V] = Deferred[IO, Outcome[IO, Throwable, V]].flatMap {
              wait =>
                mapref(k)
                  .modify[Operation[V]] {
                    case None => (Running(wait).some, Set(wait))
                    case s@Some(Running(wait)) => (s, Wait(wait))
                    case s@Some(Done(v)) => (s, Completed(v))
                  }
                  .flatMap {
                    case Set(_) =>
                      f(k).guaranteeCase {
                        case s@Outcome.Succeeded(fa) =>
                          mapref(k).modify {
                            case Some(Running(wait)) => (Done(fa).some, wait.complete(s).void)
                            case s@Some(Done(_)) => (s, IO.unit) // technically unreachable but type checker :(
                            case None => (None, IO.unit) // technically unreachable but type checker :(
                          }.flatten // <- don't forget these flattens
                        case s =>
                          mapref(k).modify {
                            case Some(Running(wait)) =>
                              (None, wait.complete(s).void) // clears the cached value but returns a failure
                            case s@Some(Done(_)) => (s, IO.unit)
                            case None => (None, IO.unit)
                          }.flatten
                      }
                    case Wait(wait) => 
                      wait.get.flatMap { // `.get` "soft" blocks on IO but not on an actual thread
                        case Outcome.Succeeded(fa) => fa
                        case Outcome.Errored(e) => IO.raiseError(e)
                        case Outcome.Canceled() => IO.raiseError(new Throwable("Someone cancelled the setter"))
                      }
                    case Completed(v) => v

                  }
            }
          }
      )
}

This is all we need in the simple case, and it was my first pass at the problem. The more complicated case also deals with cancellation: what happens if the fiber running the setter gets cancelled partway through. That doesn't add much more complexity (some uncancelable and poll and a bit of cleanup) to the simple structure above.
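For a taste, here's a minimal, self-contained sketch of the uncancelable/poll shape (the effects are illustrative placeholders, not the actual cache code):

import cats.effect.IO
import scala.concurrent.duration._

// Mask cancellation for the book-keeping, then use poll to re-open a
// cancellable window around the slow work, cleaning up on cancellation.
val guarded: IO[Int] =
  IO.uncancelable { poll =>
    val bookKeeping = IO.println("must not be interrupted")
    val slowWork    = IO.sleep(10.millis).as(42)
    bookKeeping *> poll(slowWork).onCancel(IO.println("cleaning up"))
  }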

Now let's use it!

import cats.effect.{IO, IOApp, Ref}
import cats.syntax.all._
import scala.concurrent.duration._

object SimpleExample extends IOApp.Simple {

  // doesn't handle deletions or timeouts well, but you get the idea

  override def run: IO[Unit] = {

    val actionFunctionRecorder = Ref[IO].of(List.empty[Int])
    val result = actionFunctionRecorder.flatMap { recorder =>
      val someExpensiveCall: Int => IO[String] = i =>
        for {
          _ <- IO.sleep(100.milliseconds)
          _ <- recorder.update(xs => i +: xs)
          _ <- IO.println(s"*******RECORDED: $i")
        } yield s"id: $i"

      val keys = Range(1, 25).inclusive.toList
      val processes = Range(1, 25).inclusive.toList
      // We run 25 different keys, with 25 concurrent gets against each key, all at the same time
      val ids = for {
        k <- keys
        p <- processes
      } yield (k, p)

      val cache = SetOnceCache.make(someExpensiveCall)

      for {
        cache <- cache
        _ <- ids.parTraverse {
          case (key, process) => IO.println(s"key $key -> process $process") >> cache.get(key)
        }
        calledIds <- recorder.get
        _ <- IO.println("finished\n\n\n------------\n")
      } yield calledIds.sorted
    }
    result.flatMap(xs => IO.println(xs))
  }
}

We make a Ref to record how many times the setting function is called. As you can see from the recorded list, the setter only ever runs once per key.

This still blows me away! We built some genuinely tricky concurrency machinery out of what honestly amounts to a bunch of case classes and pattern matches. You see this pattern everywhere, and it's useful.

Take, for instance, Single Fibered from the Davenverse:

    Prepares a function so that the resulting function is single-fibered.
    This means that no matter how many fibers are executing this function,
    for any specific key only 1 will be running at the same time; others
    will share the result of that running computation. As soon as that
    computation completes, other computations will again be able to run
    the function again.

It's a simpler case than mine: they just want to make sure that multiple invocations of the same function aren't running at the same time; they aren't interested in caching the result:

  def singleFiberedFunction[F[_]: Concurrent, K, V](
    state: K => Ref[F, Option[F[Outcome[F, Throwable, V]]]],
    f: K => F[V]
  ) = {
    {(k: K) => 
      Deferred[F, Outcome[F, Throwable, V]].flatMap{d => 
        Concurrent[F].uncancelable{poll => 
          state(k)
            .modify{
              case s@Some(out) => s -> 
                poll(out)
                  .flatMap(embedError(_))
              case None => 
                Some(d.get) -> 
                  Concurrent[F].guaranteeCase(poll(f(k))){
                    o => state(k).set(None) >> d.complete(o).void 
                  }
            }.flatten
        }
      }
    }
  }

Note how the Ref holds an Option, again so you can get a handle on the None -> Some case. If you squint, you can also see where the Deferred gets completed, after which they clear out their state (they only care that multiple invocations don't run at the same time; they aren't interested in caching the result). There is added complexity because they are also dealing with cancellation.

All said and done, pretty fun! I hope this helps you and demystifies what a lot of the cats-effect standard library is doing under the hood, since this pattern shows up everywhere. You can also see it in the implementation of Semaphore, although without the necessity of the Deferred synchronization.