Some time ago, after releasing Haskell's Haxl library, its authors realized that Haskell's do notation could use an improvement. In its current form, do notation can be used to "flatten" deeply nested >>= applications. For example, instead of writing something like this:
findVal :: String -> Maybe Int
findVal key = ... -- some function definition
sum :: Maybe Int
sum =
  findVal "key1" >>= \val1 ->
  findVal "key2" >>= \val2 ->
  return (val1 + val2)
you can write:
sum :: Maybe Int
sum = do val1 <- findVal "key1"
         val2 <- findVal "key2"
         return (val1 + val2)
This makes the expression easier to read and understand. Haskell's compiler simply desugars this notation into >>= applications. The do notation is especially readable because you can interpret it as follows: on the left side of a <- you will find the "extracted" value of the monadic value on the right side.
As you may have noticed, in the previous example we are not using the full power of the Monad typeclass. We are just calling >>= because we want to have the values inside the monadic values in the same place, so we can express our desired computation: (val1 + val2). But Monad is more powerful than that! For example, we can express something like this:
do val1 <- findVal "key1"
   val2 <- findVal ("key" ++ show val1)
   return val2
In this case the second monadic value depends on the result of the first one. Unlike the previous example, this can only be achieved with something as powerful as Monad.
When programming, we want to use the least powerful abstraction that does the job, and in functional programming this translates into using the least specific typeclass. The least powerful abstraction that allows us to combine values in different contexts into one is Applicative. We can rewrite the first example like this:
add x y = x + y
sum = add <$> (findVal "key1") <*> (findVal "key2")
If you are thinking this is less readable than the do version, you are on to something. As it turns out, for the Maybe monad there is not much difference between the two styles, because both alternatives do pretty much the same thing. But there are other monads for which it is preferable to use Applicative when possible. The first one that comes to my mind is precisely Haxl's Fetch monad. Fetch is a concurrency monad for fetching data from remote sources. With the Applicative instance for Fetch, independent data fetches can be done concurrently. Even more, if they access the same data source, the queries can be batched.
For instance, in the following situation we are using Applicative, so the fetches will be done concurrently, and if they are fetches for the same data type, they will be batched:
fetchData :: String -> Fetch Data
fetchData key = ... -- build fetch value
join x y = (x,y)
dataTuple :: Fetch (Data,Data)
dataTuple = join <$> fetchData "key1" <*> fetchData "key2"
In contrast, suppose we try to express the same thing with a do expression like this:
do data1 <- fetchData "key1"
   data2 <- fetchData "key2"
   return (data1,data2)
In this situation the do expression is desugared into >>= applications, so the fetches will be executed sequentially rather than concurrently. This is because with >>= the second fetch is produced by a function that receives the result of the first one, so the first monadic value must be computed before the one in the second line. Thus, by trying to use a more intuitive and readable notation, we incur a performance loss.
But the problem is not just readability. When you are making changes to a codebase, you must be aware that a change may allow you to use Applicative instead of Monad. This may happen if, for example, you change a line of a long do expression.
Wouldn't it be cool if do expressions used Monad only when strictly necessary and Applicative where possible? This would allow us to always use do notation without having to worry about whether we are using the correct typeclass. This is precisely the topic of Haxl's follow-up paper. There is a proposal in the Haskell community to include this behaviour in GHC. This kind of transformation would allow us to treat values that have an Applicative or Monad instance with the same universal notation: the do expression.
You may be asking: "doesn't this break the 'law' that says that Applicative and Monad should behave consistently?" More concretely, I think that law is described by the following equation:
fa >>= \a -> fb >>= \b -> return (f a b) == f <$> fa <*> fb
Well, for types whose two sides of this equation return different values, the applicative-do transformation would change what the program computes. But it's still a useful transformation when it's more efficient to use Applicative than Monad to compute the same value. For the other cases there is another answer: sometimes it's worth breaking the laws if your intentions are pure. This is the case of Fetch, in which the value returned may differ in shape but in the end should reflect the same data.
This post is about trying to do the same transformation in Scala. First, remember that Scala's for comprehensions are similar to Haskell's do expressions: they just call flatMap for each <-, except for the last one, which becomes a map call. The first example that we wrote in Haskell can be translated to Scala like this:
def findVal(key: String): Option[Int] = ???
for {
  val1 <- findVal("key1")
  val2 <- findVal("key2")
} yield val1 + val2
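To make the translation concrete, here is a small sketch of what the compiler does with that for comprehension: every generator except the last becomes a flatMap, and the last one becomes a map. The Map-backed body of findVal is a hypothetical stand-in for the elided definition, and the object name is ours.

```scala
object ForDesugaringSketch {
  // Hypothetical stand-in for the elided findVal body: a fixed lookup table.
  private val table: Map[String, Int] = Map("key1" -> 1, "key2" -> 2)

  def findVal(key: String): Option[Int] = table.get(key)

  // The for comprehension from the text.
  val sugared: Option[Int] = for {
    val1 <- findVal("key1")
    val2 <- findVal("key2")
  } yield val1 + val2

  // Its desugared form: flatMap for the first generator, map for the last one.
  val desugared: Option[Int] =
    findVal("key1").flatMap(val1 => findVal("key2").map(val2 => val1 + val2))

  def main(args: Array[String]): Unit = {
    println(sugared)   // Some(3)
    println(desugared) // Some(3)
  }
}
```

This nesting of flatMap and map is the shape a macro will see later on.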
Let's start with our very own definition of Validation, which will be similar to Scalaz's ValidationNel or Cats' ValidatedNel, though less general:
sealed trait Validation[+A] {
  def flatMap[B](f: A => Validation[B]): Validation[B] = this match {
    case Success(value) => f(value)
    case Failure(error) => Failure(error)
  }
  def map[B](f: A => B): Validation[B] = flatMap(a => Success(f(a)))
}
case class Success[A](value: A) extends Validation[A]
case class Failure(errors: List[String]) extends Validation[Nothing]
Now let's define the Applicative typeclass and then describe the instance for Validation:
trait Applicative[F[_]] {
  def pure[A](a: A): F[A]
  def map2[A,B,C](fa: F[A], fb: F[B])(f: (A,B) => C): F[C]
}
implicit object ValidationApplicative extends Applicative[Validation] {
  def pure[A](a: A) = Success(a)
  def map2[A,B,C](va: Validation[A], vb: Validation[B])(f: (A,B) => C): Validation[C] = {
    (va, vb) match {
      case (Success(a) , Success(b) ) => Success(f(a,b))
      case (Failure(ea), Failure(eb)) => Failure(ea ++ eb)
      case (Failure(ea), _          ) => Failure(ea)
      case (_          , Failure(eb)) => Failure(eb)
    }
  }
}
Our definition of Applicative is a little different from the usual formulation, which describes a function ap (with type F[A => B] => F[A] => F[B]). As it turns out, both formulations are equivalent: you can convince yourself by implementing one in terms of the other.
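Here is one way to do that exercise, as a minimal sketch. The trait names ApStyle and Map2Style and the Option instances are ours, used only for illustration: each trait declares one primitive and derives the other.

```scala
object ApMap2Sketch {
  // Applicative defined in terms of `ap` (the usual formulation); map2 is derived.
  trait ApStyle[F[_]] {
    def pure[A](a: A): F[A]
    def ap[A, B](ff: F[A => B])(fa: F[A]): F[B]

    // Lift the curried function with pure, then feed it both arguments with ap.
    def map2[A, B, C](fa: F[A], fb: F[B])(f: (A, B) => C): F[C] =
      ap(ap(pure(f.curried))(fa))(fb)
  }

  // Applicative defined in terms of `map2` (the formulation used in this post); ap is derived.
  trait Map2Style[F[_]] {
    def pure[A](a: A): F[A]
    def map2[A, B, C](fa: F[A], fb: F[B])(f: (A, B) => C): F[C]

    // Pair the wrapped function with the wrapped argument, then apply one to the other.
    def ap[A, B](ff: F[A => B])(fa: F[A]): F[B] =
      map2(ff, fa)((g, a) => g(a))
  }

  // Option instances for both styles, only to exercise the derived methods.
  val optionAp: ApStyle[Option] = new ApStyle[Option] {
    def pure[A](a: A): Option[A] = Some(a)
    def ap[A, B](ff: Option[A => B])(fa: Option[A]): Option[B] =
      (ff, fa) match {
        case (Some(g), Some(a)) => Some(g(a))
        case _                  => None
      }
  }

  val optionMap2: Map2Style[Option] = new Map2Style[Option] {
    def pure[A](a: A): Option[A] = Some(a)
    def map2[A, B, C](fa: Option[A], fb: Option[B])(f: (A, B) => C): Option[C] =
      (fa, fb) match {
        case (Some(a), Some(b)) => Some(f(a, b))
        case _                  => None
      }
  }
}
```

Going from ap to map2 needs pure to lift the curried function into F; going from map2 to ap needs nothing extra.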
Another thing to notice is that when flatMap fails, it fails with the error of the first Validation value. There is, in fact, no other way to do it: without a successful value, the function passed to flatMap can't even be called to produce the second Validation. In contrast, the Applicative instance says that if both values are Failures, we can accumulate their errors.
Let's see an example, first with flatMap:
val v1: Validation[Int] = Failure(List("error1"))
val v2: Validation[Int] = Failure(List("error2"))
val withFlatMap = for {
  x <- v1
  y <- v2
} yield x + y
In this case the error in the result is the one from the first validation. But given that v2 doesn't depend on the value x, couldn't we also report the failure of v2? Let's see what Applicative can do:
val withApplicative = ValidationApplicative.map2(v1,v2)(_ + _)
Unlike withFlatMap, this one returns both errors. And if both v1 and v2 were successful, both expressions would return the same value.
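Here is a self-contained sketch that puts the definitions above together and compares both results side by side. The wrapping object and the helper names viaFlatMap and viaMap2 are ours, added only so the snippet runs on its own.

```scala
object ValidationDemo {
  sealed trait Validation[+A] {
    def flatMap[B](f: A => Validation[B]): Validation[B] = this match {
      case Success(value)  => f(value)
      case Failure(errors) => Failure(errors)
    }
    def map[B](f: A => B): Validation[B] = flatMap(a => Success(f(a)))
  }
  final case class Success[A](value: A) extends Validation[A]
  final case class Failure(errors: List[String]) extends Validation[Nothing]

  // Error-accumulating combination, same logic as ValidationApplicative.map2.
  def map2[A, B, C](va: Validation[A], vb: Validation[B])(f: (A, B) => C): Validation[C] =
    (va, vb) match {
      case (Success(a), Success(b))   => Success(f(a, b))
      case (Failure(ea), Failure(eb)) => Failure(ea ++ eb)
      case (Failure(ea), _)           => Failure(ea)
      case (_, Failure(eb))           => Failure(eb)
    }

  // Monadic version: stops at the first failure.
  def viaFlatMap(v1: Validation[Int], v2: Validation[Int]): Validation[Int] =
    for {
      x <- v1
      y <- v2
    } yield x + y

  // Applicative version: both validations are always inspected.
  def viaMap2(v1: Validation[Int], v2: Validation[Int]): Validation[Int] =
    map2(v1, v2)(_ + _)

  def main(args: Array[String]): Unit = {
    val v1: Validation[Int] = Failure(List("error1"))
    val v2: Validation[Int] = Failure(List("error2"))
    println(viaFlatMap(v1, v2))                 // Failure(List(error1))
    println(viaMap2(v1, v2))                    // Failure(List(error1, error2))
    println(viaFlatMap(Success(1), Success(2))) // Success(3)
    println(viaMap2(Success(1), Success(2)))    // Success(3)
  }
}
```

Note that this Validation deliberately does not satisfy the consistency equation from the Haskell discussion: both sides agree on successes but differ on failures, which is exactly why choosing map2 over flatMap matters here.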
Now let's imagine we have a web form with a bunch of fields, each of which has to be validated. When the form is submitted and contains errors, we don't want to bother the user by reporting only the first one. We would like to report as many independent errors as possible:
for {
  okFirstName <- validateFirstNameField
  okLastName  <- validateLastNameField
  okFullName  <- validateFullName(okFirstName, okLastName)
  okAge       <- validateAge
} yield NewUserData(okFirstName, okLastName, okFullName, okAge)
If we used a for comprehension that sequences every validation, we would be making a mistake: as soon as one validation fails, only that error is returned. We can use Applicative and some syntactic sugar, like the one in Scalaz, to get something like this:
for {
  (okFirstName, okLastName) <- (validateFirstNameField |@|
                                validateLastNameField).tupled
  (okFullName, okAge) <- (validateFullName(okFirstName, okLastName) |@|
                          validateAge).tupled
} yield NewUserData(okFirstName, okLastName, okFullName, okAge)
This works, but it may be less readable. More importantly, it is coupled to the current structure of the computation. You can imagine what may happen with more fields and more complex dependencies between those fields.
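For comparison, here is a sketch of a similar grouping written by hand with map2 and flatMap, without Scalaz. The validators, their rules and their error messages are hypothetical, made up for illustration. The code has to spell out which steps are independent (map2) and which one depends on earlier results (flatMap), which is the coupling to the computation structure mentioned above.

```scala
object FormValidationSketch {
  sealed trait Validation[+A] {
    def flatMap[B](f: A => Validation[B]): Validation[B] = this match {
      case Success(value)  => f(value)
      case Failure(errors) => Failure(errors)
    }
    def map[B](f: A => B): Validation[B] = flatMap(a => Success(f(a)))
  }
  final case class Success[A](value: A) extends Validation[A]
  final case class Failure(errors: List[String]) extends Validation[Nothing]

  // Same error-accumulating map2 as ValidationApplicative in the text.
  def map2[A, B, C](va: Validation[A], vb: Validation[B])(f: (A, B) => C): Validation[C] =
    (va, vb) match {
      case (Success(a), Success(b))   => Success(f(a, b))
      case (Failure(ea), Failure(eb)) => Failure(ea ++ eb)
      case (Failure(ea), _)           => Failure(ea)
      case (_, Failure(eb))           => Failure(eb)
    }

  final case class NewUserData(firstName: String, lastName: String, fullName: String, age: Int)

  // Hypothetical validators, for illustration only.
  def validateFirstName(s: String): Validation[String] =
    if (s.nonEmpty) Success(s) else Failure(List("first name is empty"))
  def validateLastName(s: String): Validation[String] =
    if (s.nonEmpty) Success(s) else Failure(List("last name is empty"))
  def validateFullName(first: String, last: String): Validation[String] =
    if (first.length + last.length <= 20) Success(s"$first $last")
    else Failure(List("full name is too long"))
  def validateAge(age: Int): Validation[Int] =
    if (age >= 0) Success(age) else Failure(List("age is negative"))

  def validateForm(first: String, last: String, age: Int): Validation[NewUserData] = {
    // First and last name are independent: map2 can report both errors.
    val names = map2(validateFirstName(first), validateLastName(last))((f, l) => (f, l))
    // The full name depends on both names, so this step needs flatMap.
    val withFullName = names.flatMap { case (f, l) =>
      validateFullName(f, l).map(full => (f, l, full))
    }
    // The age depends on nothing else, so it is combined with map2 as well.
    map2(withFullName, validateAge(age)) { (nameParts, a) =>
      val (f, l, full) = nameParts
      NewUserData(f, l, full, a)
    }
  }
}
```

In this arrangement the age check is combined last with map2, so its error is still reported when the name checks fail; grouping it inside the flatMap step, as in the Scalaz version above, would hide it in that case.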
It would be very useful if this could be done automatically by the compiler. In Scala, for comprehensions are desugared very early, by the parser, so when a macro inspects this code it will already be a nested sequence of flatMaps and maps. Let's see if we can build a macro that replaces flatMaps and maps with Applicative.map2 calls when possible.
We are going to start with a very simple example:
val v1: Validation[Int] = Failure(List("error1"))
val v2: Validation[Int] = Failure(List("error2"))
for {
  x <- v1
  y <- v2
} yield x + y
Let's inspect the tree generated by this for comprehension:
import scala.reflect.runtime.universe._
val tree = reify {
  for {
    x <- v1
    y <- v2
  } yield x + y
}.tree
showRaw(tree)
That's a lot. The part that interests us is the flatMap application:
val Apply(Select(firstMonadicValue, TermName("flatMap")), List(functionDef)) = tree
Now we must separate the function into its argument and its body:
showRaw(functionDef)
val Function(List(firstArgumentTerm), functionBody) = functionDef
// ^ only matches one-argument functions, which is all that map and flatMap take
If the functionBody calls map or flatMap over some expression, then we must determine whether that expression uses firstArgumentTerm. If it does, we can't do anything: one expression depends on the other, and flatMap is the right choice. If not, that's an opportunity to use Applicative's map2 instead of flatMap. Let's first define a function usesTerm that indicates whether a term is used in some expression:
def usesTerm(term: ValDef, exp: Tree): Boolean = {
  val ValDef(_, termName, _, _) = term
  // Look for any identifier in `exp` with the same name as the term
  exp.find {
    case Ident(name) if name == termName =>
      true
    case _ =>
      false
  }.isDefined
}
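To try usesTerm in isolation, here is a sketch that uses Tree#exists and the parameter's own name instead of find. The object name, the hand-built parameter x and the sample expressions are ours; the sample trees are parsed by quasiquotes, so v1 and v2 are just free identifiers and nothing needs to be defined.

```scala
import scala.reflect.runtime.universe._

object UsesTermSketch {
  // Same check as usesTerm above: does `exp` mention an identifier named like `term`?
  def usesTerm(term: ValDef, exp: Tree): Boolean =
    exp.exists {
      case Ident(name) => name == term.name
      case _           => false
    }

  // A lambda parameter `x: Int`, built by hand instead of extracted from a for comprehension.
  val x: ValDef = ValDef(Modifiers(Flag.PARAM), TermName("x"), Ident(TypeName("Int")), EmptyTree)

  // Untyped trees parsed by quasiquotes.
  val independent: Tree = q"v2"
  val dependent: Tree   = q"if (x > 0) v2 else v1"
}
```

Because the check compares names only, a nested lambda that rebinds x also counts as a use. That errs on the safe side: the macro would simply keep flatMap.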
Let's separate the functionBody into two parts: the second monadic value and the next function definition:
showRaw(functionBody)
val Apply(Select(secondMonadicValue, TermName("map")), List(secondFunctionDef)) = functionBody
showRaw(firstArgumentTerm)
And we are interested in answering this question: is the "extracted" value of the first monad (firstArgumentTerm) used when defining the second monadic value?
usesTerm(firstArgumentTerm, secondMonadicValue)
As we can see, usesTerm returns false: computing the second monadic value doesn't need the function argument. So we want to transform this flatMap followed by map into a call to Applicative's map2. For this we need to extract the innermost expression of the for comprehension, that is, the x + y expression. After that we will build the map2 call, passing the appropriate arguments. First, let's extract that expression from secondFunctionDef:
showRaw(secondFunctionDef)
val Function(List(secondArgumentTerm), innerExpr) = secondFunctionDef
These are all the ingredients we need:
println(firstMonadicValue)
println(firstArgumentTerm)
println(secondMonadicValue)
println(secondArgumentTerm)
println(innerExpr)
Finally, let's combine them with ValidationApplicative.map2. We will describe our desired expression with a quasiquote:
val result = q"""
  ValidationApplicative.map2(
    $firstMonadicValue,
    $secondMonadicValue
  )(
    ${Function(List(firstArgumentTerm, secondArgumentTerm), innerExpr)}
  )
"""
showRaw(result)
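Here is a sketch of the whole check-and-rewrite step as a single function over untyped trees parsed with quasiquotes, so it can be tried without defining v1, v2 or ValidationApplicative. The object name, the function rewrite and the sample trees are ours. The output tree refers to ValidationApplicative by name only, so it is printed with showCode rather than evaluated.

```scala
import scala.reflect.runtime.universe._

object Map2RewriteSketch {
  // Name-based check, as in usesTerm above.
  def usesTerm(term: ValDef, exp: Tree): Boolean =
    exp.exists {
      case Ident(name) => name == term.name
      case _           => false
    }

  // Rewrites `fa.flatMap(x => fb.map(y => body))` into
  // `ValidationApplicative.map2(fa, fb)((x, y) => body)` when fb does not mention x.
  // Any other tree is returned unchanged.
  def rewrite(tree: Tree): Tree = tree match {
    case Apply(
          Select(fa, TermName("flatMap")),
          List(Function(List(x), Apply(Select(fb, TermName("map")), List(Function(List(y), body)))))
        ) if !usesTerm(x, fb) =>
      q"ValidationApplicative.map2($fa, $fb)(${Function(List(x, y), body)})"
    case other => other
  }

  val independent: Tree = q"v1.flatMap((x: Int) => v2.map((y: Int) => x + y))"
  val dependent: Tree   = q"v1.flatMap((x: Int) => (if (x > 0) v2 else v1).map((y: Int) => x + y))"

  def main(args: Array[String]): Unit = {
    println(showCode(rewrite(independent))) // a map2 call
    println(showCode(rewrite(dependent)))   // the original flatMap/map chain
  }
}
```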
Unfortunately, I haven't found a way to compile and execute expressions from a REPL. There is a compile function in the ToolBox trait, but I haven't found a way to use it from a REPL. That would make for a faster development process. If you know how, please tell me!
Anyway, what comes next is just the code I wrote in an sbt project. You can find the full working code here. Putting it all together, here is our macro implementation:
def app_for_impl(c: Context)(valid: c.Expr[M]): c.Expr[M] = {
  import c.universe._
  val Apply(
    TypeApply(Select(firstMonadicValue, TermName("flatMap")), _),
    List(functionDef)
  ) = valid.tree
  val Function(List(firstArgumentTerm), functionBody) = functionDef
  val Apply(
    TypeApply(Select(secondMonadicValue, TermName("map")), _),
    List(secondFunctionDef)
  ) = functionBody
  if (usesTerm(c.universe)(firstArgumentTerm, secondMonadicValue)) {
    valid
  } else {
    val Function(List(secondArgumentTerm), innerExpr) = secondFunctionDef
    c.Expr(q"""_root_.appfor.Validation.applicativeInstance.map2(
      $firstMonadicValue,
      $secondMonadicValue
    )(
      ${Function(List(firstArgumentTerm, secondArgumentTerm), innerExpr)}
    )""")
  }
}
As you may have noticed, there are some differences from what we did above: when matching the function calls, the function being applied has a different shape (the Select is wrapped in a TypeApply node). Other than that, it's mostly the same as before. I don't know the cause of these differences with respect to our REPL / Jupyter session. I'm just starting to learn macros!
Let's see the macro in action:
val v1: Validation[Int] = Failure(List("error1"))
val v2: Validation[Int] = Failure(List("error2"))
val resultWhenApplicative = app_for {
  for {
    x <- v1
    y <- v2
  } yield x + y
}
println(resultWhenApplicative)
// > Failure(List(error1, error2))
This is what we wanted! We used a for comprehension, but our macro detected that map2 could be used, and thus both errors were reported!
Now let's test it with an expression that really needs flatMap:
val resultWhenMonad = app_for {
  for {
    x <- v1
    y <- if (x > 0) v2 else v1
  } yield x + y
}
println(resultWhenMonad)
// > Failure(List(error1))
It works, just like a normal for comprehension!
This macro is not fully functional, though. It doesn't account for a lot of situations, such as this one:
case class Person(firstName: String, lastName: String, age: Int)
for {
  person @ Person(_, _, age) <- validatePerson
  _ <- if (age < 18) validateMinor(person) else validateAdult(person)
} yield something
It also only works for Validation. And I may be missing some other things.
The next steps will be to fix each of these problems, one by one.