This blog post is about building a concurrent cache. That by itself is uninteresting; there are many ways to do it, including a straight-up concurrent hashmap in Java. The interesting part is the requirements I had:
- in my system, many processes may try to create the same key at the same time
- my setting function is heavily side-effecting, with multiple remote calls (object store, DB, emitting some metrics, etc.)
- the setting function should only happen once no matter how many processes are trying to set the key at the same time
It's enough to do this within a single running service, so I only need to coordinate within that system. If I ever have to scale this out and coordinate between systems, I'll drag in something like ZooKeeper. But the above is easy enough to build.
Last year I had this problem to solve in production and managed to crank out a solution in cats-effect. The part that really surprised me is how easy it was to put something together, despite the fact that I was dealing with both (a) writing my own data structure (sort of) and (b) messing with concurrency. I managed to figure it out with the docs and the following:
- this excellent talk on ref + deferred, and its slides
- a little help from the Typelevel Discord in the advanced case (dealing with cancellation)
That ease speaks to the great libraries available in Scala. We're going to be using cats-effect, but there are other solutions that are just as easy using ZIO, ZIO-STM, Cats-STM, or even Akka actors. Perhaps I'll blog about those another day.
I need this to coordinate a bunch of independent external processes. The basic design (a sketch of the wiring follows the list):
- an http4s server
- a GET request comes in, checks if the value is in the cache, and returns `Option[Thing]`
- a POST request comes in:
  - checks if the value is in the cache, in which case it returns `Thing`
  - otherwise runs our effectful function to compute `Thing` and sets `Thing` in the cache
- multiple GET/POST requests can be happening concurrently for the same `Id` to get `Thing`
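To make that shape concrete, here is a rough sketch of the POST side of such a server. This is a hypothetical `Thing` type and route layout, wired to the `SetOnceCache` we build later in this post; the GET side would need a peek method that the simple cache below doesn't expose, so it's omitted.

```scala
import cats.effect.IO
import org.http4s.HttpRoutes
import org.http4s.dsl.io._

final case class Thing(id: Int) // hypothetical payload type

def routes(cache: SetOnceCache[Int, Thing]): HttpRoutes[IO] =
  HttpRoutes.of[IO] {
    // POST is get-or-create: the cache guarantees the expensive setting
    // function runs only once per id, however many requests race on it
    case POST -> Root / "thing" / IntVar(id) =>
      cache.get(id).flatMap(thing => Ok(thing.toString))
  }
```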
Ref
From the docs, a `Ref` is an asynchronous, concurrent mutable reference. You use it to share mutable state between threads/fibers. We are going to store our shared mapping of `Id -> Thing` inside of a `Ref`:
```scala
abstract class Ref[F[_], A] {
  def get: F[A]
  def set(a: A): F[Unit]
  def modify[B](f: A => (A, B)): F[B]
  // ... and more
}

// Specialized on IO
abstract class Ref[A] {
  def get: IO[A]
  def set(a: A): IO[Unit]
  def modify[B](f: A => (A, B)): IO[B]
  // ... and more
}
```
- it's always initialized with a value
- it's basically a functional wrapper over `AtomicReference`

There is a compare-and-set under the hood when we try to modify the shared mapping.
Because an attempt to set the value may fail, the function modifying the `Ref` may be called multiple times. Thus, whatever is in the `Ref` should be immutable. If the inside is an immutable `List`, who cares how many times you prepend to an immutable list; each retry just rebuilds the same value. But if it's a mutable list, you may prepend your element many times. Cats-effect ships with some additional immutable data structures (like queues) for this reason. The sketch below makes the retries visible.
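This is a minimal sketch of my own (not part of the cache we're building): the final value is always correct because the `Ref` holds an immutable `Int`, but a side-effecting call counter shows the update function being invoked more than 1000 times under contention, because the underlying compare-and-set retries it.

```scala
import java.util.concurrent.atomic.AtomicInteger
import cats.effect.{IO, IOApp, Ref}
import cats.effect.implicits._
import cats.syntax.all._

object ModifyRetries extends IOApp.Simple {
  override def run: IO[Unit] = {
    val calls = new AtomicInteger(0) // counts invocations of the update function
    for {
      ref <- Ref.of[IO, Int](0)
      _   <- List.fill(1000)(()).parTraverseN(10) { _ =>
               // this side effect is exactly what you should NOT put in a Ref update:
               // when the compare-and-set fails, the function is re-run on a fresh snapshot
               ref.update { state => calls.incrementAndGet(); state + 1 }
             }
      n <- ref.get
      _ <- IO.println(s"final value: $n (always 1000), update invocations: ${calls.get()} (often > 1000)")
    } yield ()
  }
}
```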
The textbook example for `Ref` is a shared counter between threads. We'll do a variation below: a shared sum computed across multiple threads. First, a version that doesn't work because there is no coordination between threads. We are trying to sum the numbers 1 to 1000 across 10 different worker threads:
```scala
import cats.effect.{IO, IOApp}
import cats.effect.implicits._
import cats.syntax.all._

object VarFail extends IOApp.Simple {
  var counter: Int = 0

  override def run: IO[Unit] = {
    val cases = Range(1, 1000).inclusive.toList
    val expectedSum = cases.sum
    for {
      _ <- cases.parTraverseN(10)(i => IO { counter = counter + i })
      _ <- IO.println(s"expected $expectedSum got $counter MATCHES ---> ${expectedSum === counter}")
    } yield ()
  }
}
```
Running this reveals that the counter sometimes ends up with the wrong value: updates get lost because the threads race on the variable. This is not a surprise, and normally you would put a mutex around the variable or use some Java synchronization primitive, as in the sketch below.
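For comparison, here is a minimal sketch of that traditional fix (my own illustration, not something we'll use going forward): guard the variable with a plain JVM monitor.

```scala
import cats.effect.{IO, IOApp}
import cats.effect.implicits._
import cats.syntax.all._

object SynchronizedSum extends IOApp.Simple {
  private val lock = new Object
  var counter: Int = 0

  override def run: IO[Unit] = {
    val cases = Range(1, 1000).inclusive.toList
    for {
      // every update takes the JVM monitor, so no increments are lost
      _ <- cases.parTraverseN(10)(i => IO(lock.synchronized { counter = counter + i }))
      _ <- IO.println(s"expected ${cases.sum} got $counter")
    } yield ()
  }
}
```

It works, but it's exactly the kind of side-effecting, blocking locking a functional program would rather avoid.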
If we build this with a `Ref[IO, Int]` instead:
```scala
import cats.effect.{IO, IOApp, Ref}
import cats.effect.implicits._
import cats.syntax.all._

object RefExample extends IOApp.Simple {
  val counter: IO[Ref[IO, Int]] = Ref.of[IO, Int](0)

  def increment(i: Int, ref: Ref[IO, Int]): IO[Unit] =
    ref.update(state => state + i)

  override def run: IO[Unit] = {
    val cases = Range(1, 1000).inclusive.toList
    val expectedSum = cases.sum
    for {
      ref   <- counter
      _     <- cases.parTraverseN(10)(i => increment(i, ref))
      state <- ref.get
      _     <- IO.println(s"expected $expectedSum got $state MATCHES ---> ${expectedSum === state}")
    } yield ()
  }
}
```
We get the answer we expect.
So cool. We have a way to (safely) share state between our threads. But we still have the problem that we only want something to run once. Once again, `Ref` comes to the rescue. We can use `Option` to get some surprisingly useful behavior. In the next example, we'll use a `Ref[IO, Option[Int]]`. If the value is `None`, then we can lock in a value. If the value is already a `Some`, whoever gets it is too late.
Said another way, we can get a handle on which thread "wins" the setting and which threads lose it. We also use `modify` rather than `update`. `modify` wants a function `A => (A, B)`. This lets us set a value and, critically, lets us emit some value downstream as part of setting it. In the following example, if we are the "winning" thread, we'll print out that we won (our downstream `IO`). Otherwise, we print out that we lost because a value was already set. Note the use of `flatten`: it's important, since we are emitting an `IO` downstream and `modify` itself returns an `IO`, so you end up with an `IO[IO[...]]`:
```scala
import cats.effect.{IO, IOApp, Ref}
import cats.effect.implicits._
import cats.syntax.all._

object RefWinExample extends IOApp.Simple {
  val refOption: IO[Ref[IO, Option[Int]]] = Ref.of[IO, Option[Int]](None)

  def doStuff(i: Int, ref: Ref[IO, Option[Int]]): IO[Unit] =
    ref.modify {
      case None =>
        Some(i) -> IO.println(s"****** i: $i -> I won the write! Setting saved value to $i")
      case Some(alreadySet) =>
        Some(alreadySet) -> IO.println(s"i: $i -> Lost, there was an existing value set to $alreadySet")
    }.flatten

  override def run: IO[Unit] = {
    val cases = Range(1, 10).inclusive.toList
    for {
      ref <- refOption
      _   <- cases.parTraverseN(10)(i => doStuff(i, ref))
    } yield ()
  }
}
```
Running this multiple times will result in a different "thread" winning. We don't know which one, but we do know that exactly one of them will win: the one that performs the `None -> Some` transition.
That is as far as we can go with `Ref`. To solve the rest, we need something better: `Ref` has no functionality for synchronization.
Enter...
Deferred
Deferred is a purely functional synchronization primitive which represents a single value which may not yet be available. It is created empty and fulfilled later.
```scala
abstract class Deferred[F[_], A] {
  def get: F[A]
  def complete(a: A): F[Boolean]
}

// specialized on IO
abstract class Deferred[A] {
  def get: IO[A]
  def complete(a: A): IO[Boolean]
}
```
The interesting parts for us:
- `get` semantically blocks (it's cats-effect, so we're not blocking a native thread) until someone completes it
- someone needs to `complete` it. When they do, everyone "blocked" on it is notified and gets the value

A minimal sketch of this behavior follows.
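Here is a small toy example of my own: three fibers block on `get`, one fiber completes the `Deferred`, and all three wake up with the same value.

```scala
import scala.concurrent.duration._
import cats.effect.{Deferred, IO, IOApp}
import cats.syntax.all._

object DeferredDemo extends IOApp.Simple {
  override def run: IO[Unit] =
    Deferred[IO, Int].flatMap { d =>
      // `get` semantically blocks each fiber until the Deferred is completed
      val waiter = (n: Int) => d.get.flatMap(v => IO.println(s"waiter $n got $v"))
      // one fiber completes it; every waiter is then released with the value
      val completer = IO.sleep(100.millis) *> d.complete(42).void
      (List(1, 2, 3).parTraverse(waiter).void, completer).parTupled.void
    }
}
```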
This combo of `Ref` and `Deferred` is everywhere once you dig under the hood of cats-effect. These two little primitives give you a surprising amount of power. See the implementations of `Semaphore` and `Queue`, for instance.
So how is this useful for our cache case?
- we know that one thread will "win" the `None -> Some` case
- if we have something like a `Ref[IO, Deferred[...]]`, then the thread that wins the `None -> Some` can be responsible for completing it
- all the threads that lost can wait on the winner via `get` on the deferred
We will need another piece. And that's where finite state machines come in:
- there is a set of states
- there are concurrent inputs
- there are transitions between the states on inputs
- actions are run after the transition is done
This is how we get around "I only want my very expensive, side-effecting setting function to run once". We store states in our cache, and only run the effect after a state transition has completed.
I'm still talking in the abstract, so a concrete example will help. We have two things to define: states and transitions, or, as I call them in the example, `Operation`s.
First, let's define our cache:
```scala
trait SetOnceCache[K, V] {
  def get(k: K): IO[V]
}
```
In our actual implementation, we will store everything in a map. We'll use `MapRef`, which is a `Ref` implementation optimized for maps. You could also use a bare immutable Scala `Map`, but in production I wanted the performance of the former, as I will have many, many keys. The key is easy. But what is the value of our map? It's our states! Well, our optional state, because we need the `None -> Some` trick to determine the winning thread.
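Before the states, a quick sketch of the `MapRef` shape we lean on (toy keys and values, my own illustration): applying it to a key hands you a `Ref[IO, Option[V]]` for that key, so the per-key `None -> Some` trick works unchanged.

```scala
import cats.effect.{IO, IOApp}
import cats.effect.std.MapRef

object MapRefDemo extends IOApp.Simple {
  override def run: IO[Unit] =
    for {
      mapref <- MapRef.ofScalaConcurrentTrieMap[IO, String, Int]
      _      <- mapref("a").set(Some(1)) // mapref(k) is a Ref[IO, Option[Int]] for key k
      hit    <- mapref("a").get          // Some(1)
      miss   <- mapref("missing").get    // None: absent keys read as None
      _      <- IO.println(s"a = $hit, missing = $miss")
    } yield ()
}
```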
This leaves two states, then:
- `Done` with some value (i.e. the setting function has been run and we have a value to return)
- `Running` (i.e. the value isn't available yet, but someone is computing it)

We don't need a state for "value doesn't exist", because that's handled by the `None` case and the `Map` itself.
```scala
sealed trait State[V]
case class Done[V](value: IO[V]) extends State[V]
case class Running[V](getWait: Deferred[IO, Outcome[IO, Throwable, V]]) extends State[V]
```
If we're done, here's the value. Otherwise our state is Running and we will be waiting for some deferred to complete.
Now we have a bunch of Operations. These are our transitions between the states.
```scala
sealed trait Operation[V]

// if you get this transition,
// you are the thread responsible for completing the deferred
case class Set[V](set: Deferred[IO, Outcome[IO, Throwable, V]]) extends Operation[V]

// if you get this transition,
// someone else is responsible for completing the deferred,
// and you can wait on it if you want to get the value eventually
case class Wait[V](getWait: Deferred[IO, Outcome[IO, Throwable, V]]) extends Operation[V]

// if you get this transition,
// there is already a value for you
case class Completed[V](value: IO[V]) extends Operation[V]
```
Cool. The rest is just pattern matching. Lots and lots of pattern matching. Given the above, we can implement the whole thing in about twenty lines of Scala.
```scala
// To make a SetOnceCache we need some function that sets the value as a
// function of the key if it does not exist. This would be my expensive call
// (object store, DB, metrics, etc.) that I pass in. We leave it as a simple f
// because it also simplifies testing: the cache doesn't need to know any of
// that, it just needs to run it once per key.
object SetOnceCache {
  def make[K, V](f: K => IO[V]): IO[SetOnceCache[K, V]] =
    MapRef
      .ofScalaConcurrentTrieMap[IO, K, State[V]]
      .map(mapref =>
        new SetOnceCache[K, V] {
          def get(k: K): IO[V] =
            Deferred[IO, Outcome[IO, Throwable, V]].flatMap { wait =>
              mapref(k)
                .modify[Operation[V]] {
                  case None                    => (Running(wait).some, Set(wait))
                  case s @ Some(Running(wait)) => (s, Wait(wait))
                  case s @ Some(Done(v))       => (s, Completed(v))
                }
```
This is mostly just annotating the code:
- we make a `MapRef` backed by a Scala concurrent `TrieMap`
- we then use this to back our implementation of `SetOnceCache`
- in order to do anything, you first need to make a `Deferred`. We may chuck this deferred out in the majority of cases, but it's cheap to make:
  - if the value doesn't exist in the cache at all, we're going to need it
- we see the same call to `modify` in order to do the `None -> Some` trick
- if we are the thread that gets `None`, then we update the state to say "someone is fulfilling the value right now and here is the `Deferred` to wait on"
  - we emit `Set` downstream for our transition
- if we instead pull `Running` out of the map, then we are a thread that lost. Someone else is already handling the value for this key
  - we do not update the state in the map
  - we emit `Wait` downstream for our transition
- if we pull the state from the map and it is `Done`:
  - we do not update the state in the map (it's already done with a value!)
  - we emit `Completed` downstream for our transition
Continuing on, we'll `flatMap` this, and now the elements we are dealing with are the transitions. More pattern matching:
```scala
                .flatMap {
                  case Set(_) =>
                    f(k).guaranteeCase {
                      case s @ Outcome.Succeeded(fa) =>
                        mapref(k).modify {
                          case Some(Running(wait)) => (Done(fa).some, wait.complete(s).void)
                          case s @ Some(Done(_))   => (s, IO.unit) // technically unreachable but type checker :(
                          case None                => (None, IO.unit) // technically unreachable but type checker :(
                        }.flatten // <- don't forget these flattens
```
One and only one of the processes will end up with the `Set` transition:
- we run our setting function
- we run `guaranteeCase`, which hands us the `Outcome` (succeeded, errored, or canceled)
- if the setting function succeeded, we modify our state:
  - if the state is `Running`, replace it with `Done` and the value we got from our setting function `f`
  - also complete the deferred inside of `Running` with that value
- the remaining cases here are just to satisfy the type checker in the simple case
  - they do matter in more complex versions of the cache, which account for deletion of keys and cancellation of IOs
Continuing on, we need to deal with what happens if `f(k)` didn't succeed:
```scala
                      case s =>
                        mapref(k).modify {
                          case Some(Running(wait)) => (None, wait.complete(s).void) // clears the cached value but propagates the failure
                          case s @ Some(Done(_))   => (s, IO.unit)
                          case None                => (None, IO.unit)
                        }.flatten
                    }
```
- if we didn't succeed, basically: nuke this value from the cache
- complete the `Deferred` with the failed outcome (this will fail everyone who is waiting on this value)
- again, the other cases aren't interesting in the basic case

If the setter fails, just die. Someone upstream will have to retry this key, or there was some reason why we couldn't cache it. A sketch of such a retry is below.
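Because a failed setter clears the key, retrying is as simple as calling `get` again. Here is a hedged sketch of what that upstream retry could look like; `retryWithBackoff` is a hypothetical helper, not part of the cache.

```scala
import scala.concurrent.duration._
import cats.effect.IO

// hypothetical helper: retry an IO with exponential backoff
def retryWithBackoff[A](io: IO[A], retries: Int, delay: FiniteDuration): IO[A] =
  io.handleErrorWith { err =>
    if (retries > 0) IO.sleep(delay) *> retryWithBackoff(io, retries - 1, delay * 2)
    else IO.raiseError(err)
  }

// usage: a fresh cache.get(key) re-runs the setter because the failed key was removed
// retryWithBackoff(cache.get(key), retries = 3, delay = 100.millis)
```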
Continuing on, we need to deal with the other transitions. What if we are a thread that doesn't have to `Set` the value but is instead waiting on it?
```scala
                  case Wait(wait) =>
                    wait.get.flatMap { // `get` "softly" blocks the fiber, not a native thread
                      case Outcome.Succeeded(fa) => fa
                      case Outcome.Errored(e)    => IO.raiseError(e)
                      case Outcome.Canceled()    => IO.raiseError(new Throwable("Someone cancelled the setter"))
                    }
```
If we got the `Wait` transition, well, we just wait on the deferred:
- if that deferred succeeds, then we return the value
- otherwise we raise an error; something went wrong
- the canceled case doesn't really show up in the simple case, but it's there
Finally, our last transition, which is the value already existing in our state:

```scala
                  case Completed(v) => v
```
In this case we just return the completed value.
Putting it all together:
```scala
import cats.effect.{Deferred, IO, Outcome}
import cats.effect.std.MapRef
import cats.syntax.all._

trait SetOnceCache[K, V] {
  def get(k: K): IO[V]
}

object SetOnceCache {
  private sealed trait State[V]
  private case class Running[V](getWait: Deferred[IO, Outcome[IO, Throwable, V]]) extends State[V]
  private case class Done[V](value: IO[V]) extends State[V]

  private sealed trait Operation[V] extends Product with Serializable
  private case class Set[V](set: Deferred[IO, Outcome[IO, Throwable, V]]) extends Operation[V]
  private case class Wait[V](getWait: Deferred[IO, Outcome[IO, Throwable, V]]) extends Operation[V]
  private case class Completed[V](value: IO[V]) extends Operation[V]

  def make[K, V](f: K => IO[V]): IO[SetOnceCache[K, V]] =
    MapRef
      .ofScalaConcurrentTrieMap[IO, K, State[V]]
      .map(mapref =>
        new SetOnceCache[K, V] {
          def get(k: K): IO[V] =
            Deferred[IO, Outcome[IO, Throwable, V]].flatMap { wait =>
              mapref(k)
                .modify[Operation[V]] {
                  case None                    => (Running(wait).some, Set(wait))
                  case s @ Some(Running(wait)) => (s, Wait(wait))
                  case s @ Some(Done(v))       => (s, Completed(v))
                }
                .flatMap {
                  case Set(_) =>
                    f(k).guaranteeCase {
                      case s @ Outcome.Succeeded(fa) =>
                        mapref(k).modify {
                          case Some(Running(wait)) => (Done(fa).some, wait.complete(s).void)
                          case s @ Some(Done(_))   => (s, IO.unit) // technically unreachable but type checker :(
                          case None                => (None, IO.unit) // technically unreachable but type checker :(
                        }.flatten // <- don't forget these flattens
                      case s =>
                        mapref(k).modify {
                          case Some(Running(wait)) => (None, wait.complete(s).void) // clears the cached value but propagates the failure
                          case s @ Some(Done(_))   => (s, IO.unit)
                          case None                => (None, IO.unit)
                        }.flatten
                    }
                  case Wait(wait) =>
                    wait.get.flatMap { // `get` "softly" blocks the fiber, not a native thread
                      case Outcome.Succeeded(fa) => fa
                      case Outcome.Errored(e)    => IO.raiseError(e)
                      case Outcome.Canceled()    => IO.raiseError(new Throwable("Someone cancelled the setter"))
                    }
                  case Completed(v) => v
                }
            }
        }
      )
}
```
This is all we need in the simple case, and it was my first pass at this. The more complicated case deals with:
- timeouts
  - you do not want your setting function to hang indefinitely; if it does, everyone will wait forever on a `Deferred` that never gets completed
- deletions and cancellations
  - what happens if someone deletes a key while you are setting it? (some of those "unreachable" cases come into play)
  - what happens if someone cancels the setting function while you are setting it?
The above don't add that much more complexity (some `uncancelable` and `poll` and a bit of cleanup) to the simple structure above. The timeout case in particular can be handled by wrapping the setter, sketched below.
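A minimal sketch of that wrapping, my addition rather than the post's full solution, under the assumption that a timed-out setter should simply fail the key (the failure path above then clears it from the cache):

```scala
import scala.concurrent.duration._
import cats.effect.IO

// wrap the setter so a hung call fails the key instead of
// leaving every waiter blocked on the Deferred forever
def withTimeout[K, V](f: K => IO[V], limit: FiniteDuration): K => IO[V] =
  k => f(k).timeout(limit)

// usage: SetOnceCache.make(withTimeout(someExpensiveCall, 30.seconds))
```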
Now let's use it!
```scala
import scala.concurrent.duration._
import cats.effect.{IO, IOApp, Ref}
import cats.syntax.all._

object SimpleExample extends IOApp.Simple {
  // doesn't handle deletions or timeouts well, but you get the idea
  override def run: IO[Unit] = {
    val actionFunctionRecorder = Ref[IO].of(List.empty[Int])
    val result = actionFunctionRecorder.flatMap { recorder =>
      val someExpensiveCall: Int => IO[String] = i =>
        for {
          _ <- IO.sleep(100.milliseconds)
          _ <- recorder.update(xs => i +: xs)
          _ <- IO.println(s"*******RECORDED: $i")
        } yield s"id: $i"

      val keys = Range(1, 25).inclusive.toList
      val processes = Range(1, 25).inclusive.toList
      // We run 25 different ids with 25 concurrent gets per key, all at the same time
      val ids = for {
        k <- keys
        p <- processes
      } yield (k, p)

      val cache = SetOnceCache.make(someExpensiveCall)
      for {
        cache     <- cache
        _         <- ids.parTraverse { case (key, process) =>
                       IO.println(s"key $key -> process $process") >> cache.get(key)
                     }
        calledIds <- recorder.get
        _         <- IO.println("finished\n\n\n------------\n")
      } yield calledIds.sorted
    }
    result.flatMap(xs => IO.println(xs))
  }
}
```
We make a `Ref` to record each time the setting function is called. As you can see when you run it, the setter is only ever run once per key.
This still blows me away! We were able to build out some complicated concurrency machinery with what honestly amounts to a bunch of case classes and pattern matches. You see this pattern everywhere, and it's useful.
Take, for instance, singleFibered from the Davenverse:

```scala
/**
 * Prepares a function so that the resulting function is single-fibered.
 * This means that no matter how many fibers are executing this function,
 * for any specific key only 1 will be running at the same time; others
 * will share the result of that running computation. As soon as that computation
 * completes, other computations will again be able to run the function again.
 */
```
It's a simpler case than mine: they just want to make sure multiple invocations of the same function aren't running at the same time for a given key, but they aren't interested in caching the result:
```scala
def singleFiberedFunction[F[_]: Concurrent, K, V](
    state: K => Ref[F, Option[F[Outcome[F, Throwable, V]]]],
    f: K => F[V]
) = { (k: K) =>
  Deferred[F, Outcome[F, Throwable, V]].flatMap { d =>
    Concurrent[F].uncancelable { poll =>
      state(k)
        .modify {
          case s @ Some(out) =>
            s -> poll(out).flatMap(embedError(_))
          case None =>
            Some(d.get) -> Concurrent[F].guaranteeCase(poll(f(k))) { o =>
              state(k).set(None) >> d.complete(o).void
            }
        }
        .flatten
    }
  }
}
```
Note how the `Ref` holds an `Option`, again so you can get a handle on the `None -> Some` case. If you squint, you can also see where the `Deferred` gets completed, after which they clear out their cache (they only want to ensure multiple things don't run at the same time, not to cache the result). There is added complexity because they are also dealing with cancellation.
All said and done, pretty fun! I hope this helps you and demystifies what a lot of the cats-effect standard library is doing under the hood, because this pattern shows up everywhere. You can also see it in the implementation of `Semaphore`, although without the necessity of the `Deferred` synchronization.