Monad laws in Scala

Trevor Hartman

The three Monad laws may seem pretty abstract at first, but they're quite practical. Let's try to internalize the laws by running through two of Scala's most popular monads, List and Option, and making sure they obey them. We could use something like ScalaCheck to check these laws more rigorously, but the purpose of this post is to help internalize them, so we'll verify them by hand, using our intuition.

Here are the laws, from the Monad laws page on the HaskellWiki:

  1. Left identity: return a >>= f ≡ f a
  2. Right identity: m >>= return ≡ m
  3. Associativity: (m >>= f) >>= g ≡ m >>= (\x -> f x >>= g)
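The HaskellWiki also states these laws in terms of Kleisli composition (>=>), where associativity reads (f >=> g) >=> h ≡ f >=> (g >=> h). Here is a rough Scala sketch of that view for List-returning functions. The helper name `kleisli` and the function h are our own illustrative choices, not standard-library names.

```scala
object KleisliSketch {
  // Hypothetical helper playing the role of Haskell's >=> for List-returning
  // functions: run f, then feed every one of its results into g.
  def kleisli[A, B, C](f: A => List[B], g: B => List[C]): A => List[C] =
    a => f(a).flatMap(g)

  // List's version of return, written as a function value.
  val unit: Int => List[Int] = x => List(x)

  val f: Int => List[Int] = x => List(x - 1, x, x + 1)
  val g: Int => List[Int] = x => List(x, -x)
  val h: Int => List[Int] = x => List(x * 10)

  // (f >=> g) >=> h  versus  f >=> (g >=> h)
  val leftGrouped: Int => List[Int]  = kleisli(kleisli(f, g), h)
  val rightGrouped: Int => List[Int] = kleisli(f, kleisli(g, h))
}
```

In this form, the two identity laws say that `kleisli(unit, f)` and `kleisli(f, unit)` both behave exactly like `f`, and associativity says the grouping of composed functions doesn't matter.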

In Haskell, return is used to "inject a value into the monadic type". In Scala, we do this with constructors and factory methods such as List(a), Some(a), or Option(a) (unless you're using Scalaz, in which case you probably already know everything this post has to offer).

The >>= operator in Haskell corresponds to Scala's flatMap method.
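To tie the two vocabularies together, here's a minimal type-class sketch. The trait `Monad`, its method `pure`, and the `MonadLaws` and `Instances` objects are hypothetical names chosen for illustration (this is not Scalaz's or Cats' actual API): `pure` plays the role of return, `flatMap` plays the role of >>=, and each law method transcribes one of the HaskellWiki equations.

```scala
// A minimal, hypothetical Monad type class for illustration only.
trait Monad[M[_]] {
  def pure[A](a: A): M[A]                          // Haskell's return
  def flatMap[A, B](ma: M[A])(f: A => M[B]): M[B]  // Haskell's >>=
}

object MonadLaws {
  // return a >>= f  ≡  f a
  def leftIdentity[M[_], A, B](a: A, f: A => M[B])(implicit M: Monad[M]): Boolean =
    M.flatMap(M.pure(a))(f) == f(a)

  // m >>= return  ≡  m
  def rightIdentity[M[_], A](m: M[A])(implicit M: Monad[M]): Boolean =
    M.flatMap(m)((a: A) => M.pure(a)) == m

  // (m >>= f) >>= g  ≡  m >>= (\x -> f x >>= g)
  def associativity[M[_], A, B, C](m: M[A], f: A => M[B], g: B => M[C])(implicit M: Monad[M]): Boolean =
    M.flatMap(M.flatMap(m)(f))(g) == M.flatMap(m)((x: A) => M.flatMap(f(x))(g))
}

// Instances that simply delegate to the standard library's List and Option.
object Instances {
  implicit val listMonad: Monad[List] = new Monad[List] {
    def pure[A](a: A): List[A] = List(a)
    def flatMap[A, B](ma: List[A])(f: A => List[B]): List[B] = ma.flatMap(f)
  }

  implicit val optionMonad: Monad[Option] = new Monad[Option] {
    def pure[A](a: A): Option[A] = Some(a)
    def flatMap[A, B](ma: Option[A])(f: A => Option[B]): Option[B] = ma.flatMap(f)
  }
}
```

Each law check compares values with `==`, so this only works for monads whose values have meaningful structural equality, which is true of List and Option.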

List

The List monad deals with the context of non-determinism; that is, it represents multiple possible values at once. When we run multiple lists through a for comprehension, we end up with all combinations of values from each list.


for (x <- List(1, 2, 3); y <- List('a', 'b', 'c')) yield (x, y)
=> List((1,a), (1,b), (1,c), (2,a), (2,b), (2,c), (3,a), (3,b), (3,c))
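Under the hood, the compiler rewrites a for comprehension like this one into flatMap and map calls, which is why flatMap (>>=) is the method the laws are about. The object name `Desugar` is just a container for this sketch; the hand-desugared version should produce the same list.

```scala
object Desugar {
  // The for comprehension from above.
  val viaFor: List[(Int, Char)] =
    for (x <- List(1, 2, 3); y <- List('a', 'b', 'c')) yield (x, y)

  // Roughly what the compiler generates: the outer generator becomes
  // flatMap and the last generator (the one with the yield) becomes map.
  val viaFlatMap: List[(Int, Char)] =
    List(1, 2, 3).flatMap(x => List('a', 'b', 'c').map(y => (x, y)))
}
```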

Now the laws. Let's set up two simple functions, f and g, both of type Int => List[Int].


// Let f be a function that takes an Int and produces a List of its
// neighboring Ints along with itself:
val f: (Int => List[Int]) = x => List(x - 1, x, x + 1)

// Let g be a function that takes an Int x
// and produces a List containing +x and -x
val g: (Int => List[Int]) = x => List(x, -x)

Left identity


val a = 2
val lhs = List(a).flatMap(f)
=> List(1, 2, 3)

val rhs = f(a)
=> List(1, 2, 3)

lhs == rhs
=> true

Right identity


val m = List(2)

val lhs = m.flatMap(List(_))
=> List(2)

val rhs = m
=> List(2)

lhs == rhs
=> true

Associativity


val m = List(1, 2)

val lhs = m.flatMap(f).flatMap(g)
=> List(0, 0, 1, -1, 2, -2, 1, -1, 2, -2, 3, -3)
// Sidenote: now do you see what is meant by non-determinism?

val rhs = m.flatMap(x => f(x).flatMap(g))
=> List(0, 0, 1, -1, 2, -2, 1, -1, 2, -2, 3, -3)

lhs == rhs
=> true

Looks good to me.
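Checking a single input is a good start, but the laws are supposed to hold for every input. Without pulling in ScalaCheck, a quick approximation is to loop over a fixed range of sample inputs and check each law for all of them. This is a hand-rolled stand-in, not ScalaCheck's property-based testing, and the sample range and the `ListLawCheck` name are arbitrary choices for this sketch.

```scala
object ListLawCheck {
  val f: Int => List[Int] = x => List(x - 1, x, x + 1)
  val g: Int => List[Int] = x => List(x, -x)

  // A small, fixed set of sample Ints (ScalaCheck would generate random ones).
  val samples: List[Int] = (-20 to 20).toList

  // Sample lists of every length from 41 down to 0, including the empty list.
  val sampleLists: List[List[Int]] = samples.inits.toList

  // return a >>= f  ≡  f a, for every sample a
  val leftIdentityHolds: Boolean =
    samples.forall(a => List(a).flatMap(f) == f(a))

  // m >>= return  ≡  m, for every sample list m
  val rightIdentityHolds: Boolean =
    sampleLists.forall(m => m.flatMap(List(_)) == m)

  // (m >>= f) >>= g  ≡  m >>= (x => f x >>= g), for every sample list m
  val associativityHolds: Boolean =
    sampleLists.forall(m => m.flatMap(f).flatMap(g) == m.flatMap(x => f(x).flatMap(g)))
}
```

Including the empty list is worthwhile: `Nil.flatMap(...)` is always `Nil`, so it exercises the "no values at all" case that a single hand-picked example can miss.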

Option

Let's create new test functions f and g of type Int => Option[Int]. Given the type signature, it's natural to think of f and g as partial functions that are only defined on certain inputs.


// If x is at least 10, return Some(2x); otherwise None
val f: (Int => Option[Int]) = x => if (x < 10) None else Some(x * 2)

// If x is greater than 50, return Some(x + 1); otherwise None
val g: (Int => Option[Int]) = x => if (x > 50) Some(x + 1) else None

For the sake of testing our laws, the particular implementations of these functions don't matter as long as the types line up: the laws have to hold for any f and g of type Int => Option[Int].

Left identity


val a = 30
val lhs = Option(a).flatMap(f)
=> Some(60)

val rhs = f(a)
=> Some(60)

lhs == rhs
=> true

Right identity


val m = Option(30)

val lhs = m.flatMap(Option(_))
=> Some(30)

val rhs = m
=> Some(30)

lhs == rhs
=> true
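One subtlety with using Option(...) as return: Option.apply turns null into None, while Some(...) wraps whatever it's given. For Int this never comes up, but with a reference type such as String, right identity can fail if Option(_) plays the role of return, whereas Some(_) keeps the law intact. A small demonstration (the `NullSubtlety` name is just a container for this sketch):

```scala
object NullSubtlety {
  // Some(null) is a legal, if unusual, Option[String].
  val m: Option[String] = Some(null)

  // Using Option.apply as "return": the null becomes None, so m >>= return != m.
  val viaOptionApply: Option[String] = m.flatMap(Option(_))

  // Using Some as "return": the value is wrapped unchanged, so the law holds.
  val viaSome: Option[String] = m.flatMap(Some(_))
}
```

This is one reason Some(a) is usually considered the more faithful Scala counterpart of return for Option.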

Associativity


val m = Option(30)

val lhs = m.flatMap(f).flatMap(g)
=> Some(61)

val rhs = m.flatMap(x => f(x).flatMap(g))
=> Some(61)

lhs == rhs
=> true
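All three Option checks above stay on the Some path. Since the laws have to hold for every input, it's worth also trying values where f or g returns None, as well as a starting value that is already None. A quick sketch reusing the same f and g (the `OptionNonePaths` object and the chosen inputs are our own):

```scala
object OptionNonePaths {
  val f: Int => Option[Int] = x => if (x < 10) None else Some(x * 2)
  val g: Int => Option[Int] = x => if (x > 50) Some(x + 1) else None

  // Left identity where f short-circuits: 5 is less than 10, so f(5) is None.
  val leftIdentity: Boolean = Some(5).flatMap(f) == f(5)

  // Right identity starting from None.
  val rightIdentity: Boolean = (None: Option[Int]).flatMap(Some(_)) == None

  // Associativity where f succeeds but g fails: f(20) is Some(40), and 40 is not > 50.
  val assocWhenGFails: Boolean =
    Some(20).flatMap(f).flatMap(g) == Some(20).flatMap(x => f(x).flatMap(g))

  // Associativity starting from None: both sides stay None.
  val assocFromNone: Boolean = {
    val m: Option[Int] = None
    m.flatMap(f).flatMap(g) == m.flatMap(x => f(x).flatMap(g))
  }
}
```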

The end

I hope this post helped you internalize the Monad laws. If you need more practice, continue this exercise in your REPL for the Try and Either monads, or better yet: create your own Monad and verify that it obeys the laws!
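As a starting point for the "create your own Monad" exercise, here is one possible sketch: a hypothetical `Logged[A]` type (a value paired with a log of messages, similar in spirit to a Writer monad) with a hand-written flatMap, checked against the same three laws. None of these names come from a library.

```scala
// A hypothetical, hand-rolled monad: a value paired with a log of messages.
final case class Logged[A](value: A, log: List[String]) {
  // >>= : run f on our value, then append f's log to ours.
  def flatMap[B](f: A => Logged[B]): Logged[B] = {
    val next = f(value)
    Logged(next.value, log ++ next.log)
  }

  // Transform the value without touching the log.
  def map[B](f: A => B): Logged[B] = Logged(f(value), log)
}

object Logged {
  // return: wrap a value with an empty log.
  def pure[A](a: A): Logged[A] = Logged(a, Nil)
}

object LoggedLaws {
  val f: Int => Logged[Int] = x => Logged(x * 2, List(s"doubled $x"))
  val g: Int => Logged[Int] = x => Logged(x + 1, List(s"incremented $x"))

  val a = 3
  val m: Logged[Int] = Logged(3, List("start"))

  val leftIdentity: Boolean  = Logged.pure(a).flatMap(f) == f(a)
  val rightIdentity: Boolean = m.flatMap(x => Logged.pure(x)) == m
  val associativity: Boolean =
    m.flatMap(f).flatMap(g) == m.flatMap(x => f(x).flatMap(g))
}
```

The laws hold here because list concatenation (++) is associative and Nil is its identity. If flatMap dropped `next.log`, for example, left identity would fail, since `Logged.pure(a).flatMap(f)` would lose the log that `f(a)` produces.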

Further reading