Towards an applicative-for macro

Some time ago, with the release of Haskell's Haxl library, its authors realized that the do notation in Haskell could use an improvement. In its current form, the do notation can be used to "flatten" deep nestings of >>= applications. For example, instead of writing something like this:

findVal :: String -> Maybe Int
findVal key = ... -- some function definition

sum :: Maybe Int
sum = 
    findVal "key1" >>= \val1 -> 
        findVal "key2" >>= \val2 -> 
            return (val1 + val2)

you can write:

sum :: Maybe Int
sum = do val1 <- findVal "key1"
         val2 <- findVal "key2"
         return (val1 + val2)

This makes it easier to read the expression and to understand what's going on. Haskell's compiler simply desugars this notation into >>= applications. The do notation is especially readable because you can interpret it as follows: on the left side of a <- you find the "extracted" value of the monadic value on the right side.

As you may have noticed in the previous example we are not using the full power of the Monad typeclass. We are just calling >>= because we want to have the values inside the monadic values in the same place so we can express our desired computation: (val1 + val2). But Monad is more powerful than that! For example we can express something like this:

do val1 <- findVal "key1"
   val2 <- findVal ("key" ++ show val1)
   return val2

In this case the second monadic value depends on the result of the first monadic value. Unlike the previous example this can only be achieved with something as powerful as Monad.

When programming we want to use the least powerful abstraction and in functional programming this translates into using the least specific typeclass. The least powerful abstraction that allows us to join values in different contexts into one is Applicative. We can rewrite the first example like this:

add x y = x + y

sum = add <$> (findVal "key1") <*> (findVal "key2")

If you are thinking this is less readable than the do version, you are on to something. As it turns out, for the Maybe monad there is not much difference between the two styles because both alternatives do pretty much the same thing. But there are other monads for which it is preferable to use Applicative when possible. The first one that comes to mind is precisely Haxl's Fetch monad. Fetch is a concurrency monad for fetching data from remote sources. With the Applicative instance for Fetch, independent data fetches can be done concurrently. Moreover, if they access the same data source, the queries can be batched.

For instance in the following situation we are using Applicative so the fetches will be done concurrently and if they are fetches for the same data type then they will be batched:

fetchData :: String -> Fetch Data
fetchData key = ... -- build fetch value

join x y = (x,y)

dataTuple :: Fetch (Data,Data)
dataTuple = join <$> fetchData "key1" <*> fetchData "key2"

In contrast, suppose we try to express the same thing with a do expression like this:

do data1 <- fetchData "key1"
   data2 <- fetchData "key2"
   return (data1,data2)

Here the do expression is desugared into >>= applications, so the fetches will be executed sequentially and not concurrently. This is because the first monadic value must be computed before the one in the second line. Thus, by trying to use a more intuitive and readable notation, we are incurring a performance loss.

But the problem is not just readability. Whenever you change a codebase, you must notice whether the change now allows you to use Applicative instead of Monad. This may happen if, for example, you change a single line of a long do expression.

Wouldn't it be cool if do expressions used Monad only when strictly necessary and Applicative wherever possible? This would allow us to always use the do notation without having to worry about whether we are using the correct typeclass. This is precisely the topic of Haxl's follow-up paper. There is a proposal in the Haskell community to include this behaviour in GHC. This kind of transformation would allow us to treat values that have an Applicative or Monad instance with the same universal notation: the do expression.

You may be asking "doesn't this break the 'law' that says that Applicative and Monad should behave consistently?" More concretely I think that law is described as fulfilling the following equation:

(fa >>= \a -> 
    fb >>= \b -> 
        return (f a b))  ==  (f <$> fa <*> fb)

Well, the applicative-do transformation would break this law for types that return different values at each side of this equation. But it's still a useful transformation if it's more efficient to use Applicative than Monad to compute the same value.

But in the other cases there is another answer: sometimes it's worth breaking the laws if your intentions are pure. This is the case for Fetch, where the value returned may differ in shape but in the end should reflect the same data.

This post is about trying to do the same transformation in Scala. First, remember that Scala's for comprehensions are similar to Haskell's do expressions: they just call flatMap for each <- except for the last one which will be a map call. The first example that we previously wrote in Haskell can be translated to Scala like this:

def findVal(key: String): Option[Int] = ???

for {
    val1 <- findVal("key1")
    val2 <- findVal("key2")
} yield val1 + val2
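
To make the desugaring concrete, here is a small sketch that writes the same computation both ways. The body of findVal (a lookup in a hypothetical in-memory table) is an assumption made only so the example runs; the text above leaves it unspecified.

```scala
// Hypothetical lookup table, used only to make the example runnable.
val table: Map[String, Int] = Map("key1" -> 1, "key2" -> 2)

def findVal(key: String): Option[Int] = table.get(key)

// The for comprehension as written in the text.
val sugared: Option[Int] = for {
    val1 <- findVal("key1")
    val2 <- findVal("key2")
} yield val1 + val2

// Roughly what the compiler produces: flatMap for every generator
// except the last one, which becomes a map.
val desugared: Option[Int] =
    findVal("key1").flatMap(val1 => findVal("key2").map(val2 => val1 + val2))
```

Both values are computed by the same chain of calls, which is why a macro that receives the for comprehension only ever sees the flatMap/map form.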

Let's start with our very own definition of a Validation, which will be similar to Scalaz's ValidationNel or Cats' ValidatedNel, though less general:

In [1]:
sealed trait Validation[+A] {
    def flatMap[B](f: A => Validation[B]): Validation[B] = this match {
        case Success(value) => f(value)
        case Failure(error) => Failure(error)
    }
    def map[B](f: A => B): Validation[B] = flatMap(a => Success(f(a)))
}
case class Success[A](value: A) extends Validation[A]
case class Failure(errors: List[String]) extends Validation[Nothing]
defined trait Validation
defined class Success
defined class Failure

Now let's define the Applicative typeclass and then describe the instance for Validation:

In [2]:
trait Applicative[F[_]] {
    def pure[A](a: A): F[A]
    def map2[A,B,C](fa: F[A], fb: F[B])(f: (A,B) => C): F[C]
}

implicit object ValidationApplicative extends Applicative[Validation] {
    def pure[A](a: A) = Success(a)
    def map2[A,B,C](va: Validation[A], vb: Validation[B])(f: (A,B) => C): Validation[C] = {
        (va, vb) match {
            case (Success(a) , Success(b) ) => Success(f(a,b))
            case (Failure(ea), Failure(eb)) => Failure(ea ++ eb)
            case (Failure(ea), _          ) => Failure(ea)
            case (_          , Failure(eb)) => Failure(eb)            
        }
    }
}
defined trait Applicative
defined object ValidationApplicative

Our definition of Applicative is a little bit different from the usual formulation which describes a function ap (with type F[A => B] => F[A] => F[B]). As it turns out both formulations are equivalent: you can convince yourself by implementing one in terms of the other.
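
If you would rather check the equivalence than take it on faith, the following sketch derives each operation from the other. The method name map2ViaAp and the Option instance are choices made for this illustration only; they are not part of the code used in the rest of the post.

```scala
trait Applicative[F[_]] {
    def pure[A](a: A): F[A]
    def map2[A, B, C](fa: F[A], fb: F[B])(f: (A, B) => C): F[C]

    // ap expressed with map2: pair the wrapped function with its argument and apply it.
    def ap[A, B](ff: F[A => B])(fa: F[A]): F[B] =
        map2(ff, fa)((f, a) => f(a))

    // map2 recovered from ap and pure, to show the other direction.
    def map2ViaAp[A, B, C](fa: F[A], fb: F[B])(f: (A, B) => C): F[C] = {
        val partial: F[B => C] = ap[A, B => C](pure(f.curried))(fa)
        ap(partial)(fb)
    }
}

// Option instance, chosen only to have something small to experiment with.
object OptionApplicative extends Applicative[Option] {
    def pure[A](a: A): Option[A] = Some(a)
    def map2[A, B, C](fa: Option[A], fb: Option[B])(f: (A, B) => C): Option[C] =
        (fa, fb) match {
            case (Some(a), Some(b)) => Some(f(a, b))
            case _                  => None
        }
}
```

For example, OptionApplicative.map2ViaAp(Option(1), Option(2))(_ + _) and OptionApplicative.map2(Option(1), Option(2))(_ + _) should agree.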

Another thing to notice is that when flatMap fails it fails with the error of the first Validation value (there is, in fact, no other way to do this, because the second value is only produced by a function that needs the first one's result). In contrast, the Applicative instance says that if both values are Failures then we can accumulate the errors.

Let's see an example. First with flatMap:

In [3]:
val v1: Validation[Int] = Failure(List("error1"))
val v2: Validation[Int] = Failure(List("error2"))

val withFlatMap = for {
    x <- v1
    y <- v2
} yield x + y
v1: Validation[Int] = Failure(List(error1))
v2: Validation[Int] = Failure(List(error2))
withFlatMap: Validation[Int] = Failure(List(error1))

In this case the error in the result value is the one in the first validation. But given that the value v2 doesn't depend on the value x couldn't we also report the failure of v2? Let's see what Applicative can do:

In [4]:
val withApplicative = ValidationApplicative.map2(v1,v2)(_ + _)
withApplicative: Validation[Int] = Failure(List(error1, error2))

Unlike withFlatMap, this one returns both errors. And if both v1 and v2 were successful, both expressions would return the same value.

Now, let's imagine we have a web form with a bunch of fields, each of which has to be validated. When a submitted form contains errors, we don't want to annoy the user by reporting only the first one. We would like to report as many independent errors as possible. Suppose we write the validation like this:

for {
    okFirstName <- validateFirstNameField
    okLastName  <- validateLastNameField
    okFullName  <- validateFullName(okFirstName, okLastName)
    okAge       <- validateAge
} yield NewUserData(okFirstName, okLastName, okFullName, okAge)

This for comprehension sequences every validation, which is a mistake here: as soon as one validation fails, only that first error is returned. We can use Applicative and some syntactic sugar like the one in Scalaz to get something like this:

for {
    (okFirstName, okLastName) <- (validateFirstNameField |@| 
                                  validateLastNameField).tupled
    (okFullName, okAge)       <- (validateFullName(okFirstName, okLastName) |@|
                                  validateAge).tupled
} yield NewUserData(okFirstName, okLastName, okFullName, okAge)

This works but it may be less readable. More importantly this is coupled to the current computation structure. You can imagine what may happen with more fields and more complex dependencies between those fields.
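
As a rough illustration of the hand-written applicative style without any library sugar, the sketch below validates three independent fields by nesting map2. The validators, the error messages and the NewUser class are all hypothetical and exist only for this example; map2 follows the same accumulation rule as ValidationApplicative above.

```scala
sealed trait Validation[+A]
case class Success[A](value: A) extends Validation[A]
case class Failure(errors: List[String]) extends Validation[Nothing]

// Same accumulation rule as ValidationApplicative.map2 in the post.
def map2[A, B, C](va: Validation[A], vb: Validation[B])(f: (A, B) => C): Validation[C] =
    (va, vb) match {
        case (Success(a), Success(b))   => Success(f(a, b))
        case (Failure(ea), Failure(eb)) => Failure(ea ++ eb)
        case (Failure(ea), _)           => Failure(ea)
        case (_, Failure(eb))           => Failure(eb)
    }

// Hypothetical field validators, invented for this example.
def validateName(field: String, raw: String): Validation[String] =
    if (raw.trim.nonEmpty) Success(raw.trim) else Failure(List(s"$field is empty"))

def validateAge(raw: Int): Validation[Int] =
    if (raw >= 0) Success(raw) else Failure(List("age is negative"))

case class NewUser(firstName: String, lastName: String, age: Int)

// Three independent fields: nest map2 so that every failure is kept.
def validateUser(first: String, last: String, age: Int): Validation[NewUser] =
    map2(
        map2(validateName("first name", first), validateName("last name", last))((f, l) => (f, l)),
        validateAge(age)
    ) { case ((f, l), a) => NewUser(f, l, a) }
```

A step that depends on earlier results, such as validateFullName above, would still need flatMap; only the independent steps can be combined this way, and the nesting has to be reworked by hand every time the dependencies change.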

Towards an applicative-for macro

It would be very useful if this could be done automatically by the compiler. In Scala, for comprehension syntax is desugared in an early phase of the compiler, so by the time a macro inspects the code it has already been turned into a nested sequence of flatMap and map calls. Let's see if we can build a macro that replaces those flatMap and map calls with Applicative.map2 calls when possible.

We are going to start with a very simple example:

In [5]:
val v1: Validation[Int] = Failure(List("error1"))
val v2: Validation[Int] = Failure(List("error2"))

for {
    x <- v1
    y <- v2
} yield x + y
v1: Validation[Int] = Failure(List(error1))
v2: Validation[Int] = Failure(List(error2))
res4_2: Validation[Int] = Failure(List(error1))

Let's inspect the tree generated by this for comprehension:

In [6]:
import scala.reflect.runtime.universe._

val tree = reify {
    for {
        x <- v1
        y <- v2
    } yield x + y
}.tree

showRaw(tree)
import scala.reflect.runtime.universe._
tree: reflect.runtime.package.universe.Tree = cmd5.$ref$cmd4.v1.flatMap(((x) => cmd5.$ref$cmd4.v2.map(((y) => x.$plus(y)))))
res5_2: String = """
Apply(Select(Select(Select(Ident(TermName("cmd5")), TermName("$ref$cmd4")), TermName("v1")), TermName("flatMap")), List(Function(List(ValDef(Modifiers(PARAM), TermName("x"), TypeTree(), EmptyTree)), Apply(Select(Select(Select(Ident(TermName("cmd5")), TermName("$ref$cmd4")), TermName("v2")), TermName("map")), List(Function(List(ValDef(Modifiers(PARAM), TermName("y"), TypeTree(), EmptyTree)), Apply(Select(Ident(TermName("x")), TermName("$plus")), List(Ident(TermName("y"))))))))))
"""

That's a lot. The part that interests us is the flatMap application:

In [7]:
val Apply(Select(firstMonadicValue, TermName("flatMap")), List(functionDef)) = tree
firstMonadicValue: Tree = cmd5.$ref$cmd4.v1
functionDef: Tree = ((x) => cmd5.$ref$cmd4.v2.map(((y) => x.$plus(y))))

Now we must separate the function into its argument and its body:

In [8]:
showRaw(functionDef)
val Function(List(firstArgumentTerm), functionBody) = functionDef 
//               ^ only works for functions of arity one, which works for map and flatMap
res7_0: String = """
Function(List(ValDef(Modifiers(PARAM), TermName("x"), TypeTree(), EmptyTree)), Apply(Select(Select(Select(Ident(TermName("cmd5")), TermName("$ref$cmd4")), TermName("v2")), TermName("map")), List(Function(List(ValDef(Modifiers(PARAM), TermName("y"), TypeTree(), EmptyTree)), Apply(Select(Ident(TermName("x")), TermName("$plus")), List(Ident(TermName("y"))))))))
"""
firstArgumentTerm: ValDef = val x = _
functionBody: Tree = cmd5.$ref$cmd4.v2.map(((y) => x.$plus(y)))

If the functionBody calls map or flatMap over some expression, then we must determine whether that expression uses firstArgumentTerm. If it does, we can't do anything: one expression depends on the other and flatMap is the right choice. But if it doesn't, that's an opportunity to use Applicative's map2 instead of flatMap. Let's first define a function usesTerm that indicates whether a term is used in some expression:

In [9]:
def usesTerm(term: ValDef, exp: Tree): Boolean = {
    val ValDef(_,termName,_,_) = term
    exp.find{
        case Ident(_termName) if termName == _termName => 
            true
        case _ => 
            false
    }.isDefined
}
defined function usesTerm
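
To get a feel for how this check behaves, here is a small self-contained sketch that runs it against a few hand-written trees. It assumes Scala 2 with scala-reflect on the classpath; the three sample expressions are made up for the illustration, and the function is restated with exists instead of find only to keep the block short.

```scala
import scala.reflect.runtime.universe._

def usesTerm(term: ValDef, exp: Tree): Boolean = {
    val ValDef(_, termName, _, _) = term
    exp.exists {
        case Ident(name) => name == termName
        case _           => false
    }
}

// A lambda parameter named x, built by hand in the shape shown by showRaw above.
val x = ValDef(Modifiers(Flag.PARAM), TermName("x"), TypeTree(), EmptyTree)

val independent = q"v2"                     // does not mention x
val dependent   = q"if (x > 0) v2 else v1"  // mentions x
val shadowing   = q"v2.map(x => x + 1)"     // mentions a different, inner x
```

Note that the check is purely name-based: the shadowing expression is reported as using x even though its x is a new binding. That errs on the safe side, since the worst outcome is keeping a flatMap that could have been a map2.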

Let's separate the functionBody into two parts: the second monadic value and the next function definition:

In [10]:
showRaw(functionBody)
val Apply(Select(secondMonadicValue, TermName("map")), List(secondFunctionDef)) = functionBody
showRaw(firstArgumentTerm)
res9_0: String = """
Apply(Select(Select(Select(Ident(TermName("cmd5")), TermName("$ref$cmd4")), TermName("v2")), TermName("map")), List(Function(List(ValDef(Modifiers(PARAM), TermName("y"), TypeTree(), EmptyTree)), Apply(Select(Ident(TermName("x")), TermName("$plus")), List(Ident(TermName("y")))))))
"""
secondMonadicValue: Tree = cmd5.$ref$cmd4.v2
secondFunctionDef: Tree = ((y) => x.$plus(y))
res9_2: String = """
ValDef(Modifiers(PARAM), TermName("x"), TypeTree(), EmptyTree)
"""

And we are interested in answering the question: is the "extracted" value for the first monad (firstArgumentTerm) being used when defining the second monadic value?

In [11]:
usesTerm(firstArgumentTerm, secondMonadicValue)
res10: Boolean = false

As we can see, computing the second term doesn't need the function argument. So we want to transform this flatMap->map call into a call to Applicative's map2. For this we will need to extract the innermost expression of the for comprehension, that is, the x+y expression. After that we will have to build the map2 call, passing the appropriate arguments. First, let's extract that expression from secondFunctionDef:

In [12]:
showRaw(secondFunctionDef)
val Function(List(secondArgumentTerm), innerExpr) = secondFunctionDef
res11_0: String = """
Function(List(ValDef(Modifiers(PARAM), TermName("y"), TypeTree(), EmptyTree)), Apply(Select(Ident(TermName("x")), TermName("$plus")), List(Ident(TermName("y")))))
"""
secondArgumentTerm: ValDef = val y = _
innerExpr: Tree = x.$plus(y)

These are all the ingredients we need:

In [13]:
println(firstMonadicValue)
println(firstArgumentTerm)
println(secondMonadicValue)
println(secondArgumentTerm)
println(innerExpr)
cmd5.$ref$cmd4.v1
val x = _
cmd5.$ref$cmd4.v2
val y = _
x.$plus(y)

Finally let's combine them with ValidationApplicative.map2. We will describe our desired expression with a quasiquote:

In [14]:
val result = q"""
ValidationApplicative.map2(
    $firstMonadicValue,
    $secondMonadicValue
)(
    ${Function(List(firstArgumentTerm, secondArgumentTerm), innerExpr)}
)
"""
showRaw(result)
result: Tree = ValidationApplicative.map2(cmd5.$ref$cmd4.v1, cmd5.$ref$cmd4.v2)(((x, y) => x.$plus(y)))
res13_1: String = """
Apply(Apply(Select(Ident(TermName("ValidationApplicative")), TermName("map2")), List(Select(Select(Ident(TermName("cmd5")), TermName("$ref$cmd4")), TermName("v1")), Select(Select(Ident(TermName("cmd5")), TermName("$ref$cmd4")), TermName("v2")))), List(Function(List(ValDef(Modifiers(PARAM), TermName("x"), TypeTree(), EmptyTree), ValDef(Modifiers(PARAM), TermName("y"), TypeTree(), EmptyTree)), Apply(Select(Ident(TermName("x")), TermName("$plus")), List(Ident(TermName("y")))))))
"""

Unfortunately I haven't found a way to compile and execute expressions from a REPL. There exists the function compile in the ToolBox trait but I haven't found a way to use it from a REPL. That'd be useful for a faster development process. If you know how, please tell me!

Anyway, what comes next is just the code I wrote in an sbt project. You can find the full working code here. Putting it all together, here is our macro implementation:

def app_for_impl(c: Context)(valid: c.Expr[M]): c.Expr[M] = {
    import c.universe._

    val Apply(
        TypeApply(Select(firstMonadicValue, TermName("flatMap")),_), 
        List(functionDef)
    ) = valid.tree
    val Function(List(firstArgumentTerm), functionBody) = functionDef
    val Apply(
            TypeApply(Select(secondMonadicValue, TermName("map")),_), 
            List(secondFunctionDef)
    ) = functionBody
    if(usesTerm(c.universe)(firstArgumentTerm, secondMonadicValue)) {
      valid
    } else {
      val Function(List(secondArgumentTerm), innerExpr) = secondFunctionDef
      c.Expr(q"""_root_.appfor.Validation.applicativeInstance.map2(
          $firstMonadicValue,
          $secondMonadicValue
      )(
          ${Function(List(firstArgumentTerm, secondArgumentTerm), innerExpr)}
      )""")
    }
}

As you may have noticed, there are some differences with respect to what we did above: when matching the function calls, the first argument has a different shape (it's wrapped in a TypeApply node). Other than that, it's mostly the same. I don't know the cause of these differences with respect to our REPL / Jupyter session. I'm just starting to learn macros!

Let's see the macro in action:

val v1: Validation[Int] = Failure(List("error1"))
val v2: Validation[Int] = Failure(List("error2"))

val resultWhenApplicative = app_for {
    for {
      x <- v1
      y <- v2
    } yield x + y
}
println(resultWhenApplicative)
// > Failure(List(error1, error2))

This is what we wanted! We used a for comprehension but our macro detected that map2 could be used and thus both errors were reported!

Now let's test it with an expression that really needs flatMap:

val resultWhenMonad = app_for {
    for {
      x <- v1
      y <- if(x>0) v2 else v1
    } yield x + y
}

println(resultWhenMonad)
// Failure(List(error1))

It works, just like a normal for comprehension!

What's next?

This macro is not fully functional, though. It doesn't account for a lot of situations:

  • The pattern matches are unsafe and don't produce good error messages when they fail.
  • This works only for two-level expressions, but for comprehensions can be longer.
  • This macro doesn't account for more subtle uses of a monad result like pattern matches. For example:
case class Person(firstName: String, lastName: String, age: Int)

for {
    person @ Person(_,_,age) <- validatePerson
    _                        <- if(age < 18) validateMinor(person) else validateAdult(person)
} yield something
  • There is the issue of generalizing this for any type and not just for our own Validation.
  • Also I'm not sure if the pattern matches are as general as they can be.

And I may be missing some other things.
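
As a possible starting point for the first item in the list, the sketch below replaces the irrefutable val patterns with a single match that returns an Option, so an unexpected tree shape yields None instead of a MatchError. It assumes Scala 2 with scala-reflect and works on the untyped shape shown by reify earlier; the name toMap2 is made up, and the typechecked trees a real macro receives (with their TypeApply wrappers) would need additional cases.

```scala
import scala.reflect.runtime.universe._

def usesTerm(term: ValDef, exp: Tree): Boolean =
    exp.exists {
        case Ident(name) => name == term.name
        case _           => false
    }

// Some(rewritten tree) when the shape is flatMap -> map and the second
// generator does not mention the first bound name; None in every other case.
def toMap2(tree: Tree): Option[Tree] = tree match {
    case Apply(
            Select(first, TermName("flatMap")),
            List(Function(List(x), Apply(
                Select(second, TermName("map")),
                List(Function(List(y), inner)))))
         ) if !usesTerm(x, second) =>
        Some(q"ValidationApplicative.map2($first, $second)(${Function(List(x, y), inner)})")
    case _ =>
        None
}
```

A caller can then fall back to the original tree with toMap2(tree).getOrElse(tree), or report a proper compile-time message when it gets None.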

The next steps will be to fix these problems one by one.
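
For the generalization item, one direction (only a sketch, not what the macro currently does) is to have the generated code call a helper that looks up the Applicative instance implicitly instead of naming the Validation instance directly. The helper name map2Of and the Option instance below are hypothetical and only serve to show a second type plugging in.

```scala
trait Applicative[F[_]] {
    def pure[A](a: A): F[A]
    def map2[A, B, C](fa: F[A], fb: F[B])(f: (A, B) => C): F[C]
}

object Applicative {
    // Instance for Option, only to show that any type with an instance works.
    implicit val optionApplicative: Applicative[Option] = new Applicative[Option] {
        def pure[A](a: A): Option[A] = Some(a)
        def map2[A, B, C](fa: Option[A], fb: Option[B])(f: (A, B) => C): Option[C] =
            (fa, fb) match {
                case (Some(a), Some(b)) => Some(f(a, b))
                case _                  => None
            }
    }
}

// Hypothetical entry point that a generated tree could call: the compiler
// picks the instance from the type of the arguments.
def map2Of[F[_], A, B, C](fa: F[A], fb: F[B])(f: (A, B) => C)(implicit F: Applicative[F]): F[C] =
    F.map2(fa, fb)(f)
```

With this in place, map2Of(Option(1), Option(2))(_ + _) resolves the Option instance on its own, and a type without an Applicative instance fails at compile time rather than silently changing behaviour.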