Functors in Scala

The functional programming style eschews loops and replaces them with tail-recursive functions that express the logic simply, without mutable loop variables.

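Another mechanism that we will study now is that of "functors" such as map, filter and fold. (Strictly speaking, map, filter and fold are higher-order functions, and a functor is a type such as List that supports a map operation; these notes use the word loosely.)

  • These operations allow us to manipulate Lists of objects.
  • They also apply to other data structures in Scala such as Maps.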

Before we begin with functors, let us study different ways to write functions in Scala, including a convenient notation for anonymous functions.

In [1]:
// P1: Let us start with a function to multiply every element of a list by two

def multiplyEachEltByTwo(lst: List[Int], accList: List[Int] = Nil): List[Int] = lst match {
    case Nil => accList
    case hd::tail => {
        val newAccList = accList ++ List(2 * hd)
        // We add 2 * hd at the end. Why?
        multiplyEachEltByTwo(tail, newAccList)
    }
}
Out[1]:
defined function multiplyEachEltByTwo
In [2]:
// P2: We would like to remove all the even elements from a list, returning a new list with just the odd elements.

def removeEvenNumbers(lst: List[Int], accList: List[Int] = Nil): List[Int] = lst match {
    case Nil => accList
    case hd::tail => {
        val newAccList = if (hd % 2 == 0) { 
            accList 
        } else { 
            accList ++ List(hd) 
        }
        removeEvenNumbers(tail, newAccList)
    }
}
Out[2]:
defined function removeEvenNumbers
In [3]:
// P3: We wish to sum up the elements of the list

def sumOfList(lst: List[Int], sum: Int = 0): Int = lst match {
    case Nil => sum
    case hd::tail => sumOfList(tail, sum + hd)
}
Out[3]:
defined function sumOfList
In [6]:
// The main function chains all three: drop the even numbers, double the rest, then sum them up

def processList(lst: List[Int]): Int = {
    val intermediateList = removeEvenNumbers(lst)
    val intermediateList2 = multiplyEachEltByTwo(intermediateList)
    sumOfList(intermediateList2)
}
Out[6]:
defined function processList
In [7]:
removeEvenNumbers(List(1,3,4,5,6,7,8))
multiplyEachEltByTwo(List(1,2,3,4))
sumOfList ((1 to 100).toList)

processList((1 to 20).toList)
Out[7]:
res7_0: List[Int] = List(1, 3, 5, 7)
res7_1: List[Int] = List(2, 4, 6, 8)
res7_2: Int = 5050
res7_3: Int = 200

This is somewhat painful: we need to write three separate recursive functions, each repeating the same list-traversal boilerplate, to do the job.

Is there something better we can do?
Yes: let us recognize three patterns of operations we would like to achieve:

  • Map: apply a function f to every element of a list.
  • Filter: keep just those elements of the list that satisfy a "predicate".
  • Fold (or reduce): combine the elements of the list, one at a time, into a single accumulated result.
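As a preview, here is a minimal sketch of the three patterns using Scala's built-in List methods, written with the anonymous-function syntax explained in the next section. The list xs and the names doubled, odds and total are illustrative only; each line does the job of one of P1 to P3 above.

```scala
val xs: List[Int] = List(1, 3, 4, 5, 6, 7, 8)

// Map: apply x => 2 * x to every element (the job of multiplyEachEltByTwo)
val doubled: List[Int] = xs.map(x => 2 * x)

// Filter: keep only the odd elements (the job of removeEvenNumbers)
val odds: List[Int] = xs.filter(x => x % 2 != 0)

// Fold: combine all elements into a sum, starting from 0 (the job of sumOfList)
val total: Int = xs.foldLeft(0)((acc, x) => acc + x)
```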

Anonymous Functions in Scala

Before we look closer at these operations, let us first familiarize ourselves with anonymous functions in Scala. It is often cumbersome to define a named function just so that we can pass it as an argument to another function. Instead, we will use "anonymous" functions.

In [8]:
def multiplyByTwo(x: Int): Int = x * 2
Out[8]:
defined function multiplyByTwo
In [9]:
// Here are two other ways to write the same thing

val f : Int => Int = (x) => x * 2
Out[9]:
f: Int => Int = ammonite.$sess.cmd9$Helper$$Lambda$3151/0x0000000801565830@3d3886ab
In [10]:
val f2: Int => Int = _ * 2

// The `_` here is a placeholder for the (single) argument.
// The type annotation on f2 matters: without it, Scala cannot infer the type of `_`.
Out[10]:
f2: Int => Int = ammonite.$sess.cmd10$Helper$$Lambda$3157/0x000000080156a628@5ce9a0f4

Anonymous functions, also known as lambda expressions or function literals, are functions defined without a name. They are particularly useful when you need to pass a function as an argument to another function.

In Scala, they can be written in several ways:

  • Standard syntax: (parameters) => expression
  • Shortened syntax using underscore (_) for parameters
  • Block syntax for multi-line functions: (parameters) => { statements }

These functions can be assigned to variables, passed as arguments, or returned from other functions, making them powerful tools for functional programming in Scala.
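The three syntaxes above, together with passing a function as an argument and returning one from a function, can be sketched as follows. All names (double, triple, describeParity, applyTwice, adder) are illustrative, not from any library.

```scala
// Standard syntax: (parameters) => expression
val double: Int => Int = (x: Int) => x * 2

// Underscore syntax: each _ stands for the next parameter
val triple: Int => Int = _ * 3

// Block syntax: several statements; the value of the last expression is the result
val describeParity: Int => String = (x: Int) => {
  val parity = if (x % 2 == 0) "even" else "odd"
  s"$x is $parity"
}

// Passing a function as an argument: f is applied twice
def applyTwice(f: Int => Int, x: Int): Int = f(f(x))

// Returning a function: the returned function remembers n
def adder(n: Int): Int => Int = (x: Int) => x + n
```

For example, applyTwice(double, 5) evaluates double(double(5)), which is 20, and adder(10) is a function that adds 10 to its argument.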

In the above example, f is bound to a function that takes in an argument x and returns x * 2. You can pass the expression (x) => x * 2 in any context you wish without giving it a name as we will see.

In [11]:
// Here is another succinct version:
// Scala infers the type of x + x from that of x, and from that it infers the type of g.

val g = (x: String) => x + x
Out[11]:
g: String => String = ammonite.$sess.cmd11$Helper$$Lambda$3165/0x000000080156b2f8@4264c1dd
In [12]:
val g: String => String = x => x + x // OK: Scala infers the type of x from the type given to g
Out[12]:
g: String => String = ammonite.$sess.cmd12$Helper$$Lambda$3167/0x0000000801574000@733830b6
In [12]:
val g2 = x => x + x // BAD: Scala has no way of knowing what x is. It can be a String, Int, Double, ...
-- [E081] Type Error: cmd13.sc:1:9 ---------------------------------------------
1 |val g2 = x => x + x // BAD: Scala has no way of knowing what x is. It can be a String, Int, Double, ...
  |         ^
  |         Missing parameter type
  |
  |         I could not infer the type of the parameter x
Compilation Failed

Anonymous functions can take multiple arguments.

In [13]:
val addFun = (x: Int, y:Int) => x + y
Out[13]:
addFun: (Int, Int) => Int = ammonite.$sess.cmd13$Helper$$Lambda$3266/0x0000000801574cb0@27b23138
In [14]:
// Compare: this function takes ONE argument, a pair (tuple) of Ints
val addFun = (x: (Int, Int)) => x._1 + x._2
Out[14]:
addFun: ((Int, Int)) => Int = ammonite.$sess.cmd14$Helper$$Lambda$3279/0x0000000801575b60@24dc858c
In [15]:
val addFun: (Int, Int) => Int = _ + _  // First _ is the first argument and second _ is the second argument.
Out[15]:
addFun: (Int, Int) => Int = ammonite.$sess.cmd15$Helper$$Lambda$3281/0x0000000801576810@39078a9b
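The difference between the two-argument addFun and the tuple version shows up at the call site. A minimal sketch, with the hypothetical names addTwo and addPair:

```scala
// Two parameters: called with two arguments
val addTwo: (Int, Int) => Int = (x, y) => x + y

// One parameter that happens to be a pair (a tuple): called with one tuple
val addPair: ((Int, Int)) => Int = p => p._1 + p._2

val r1: Int = addTwo(2, 3)    // 5
val r2: Int = addPair((2, 3)) // 5; the inner parentheses build the tuple
```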

Last but not least, when an anonymous function immediately pattern-matches on its argument, you can leave out the match statement and write only the case clauses.

In [16]:
sealed trait MyList 
case object MyNil extends MyList
case class MyCons(x: Int, l: MyList) extends MyList
Out[16]:
defined trait MyList
defined object MyNil
defined class MyCons
In [17]:
val anonIsEmptyFun: MyList => Boolean = (x) => { x match {
    case MyNil => true
    case MyCons(_, _) => false
}}
Out[17]:
anonIsEmptyFun: MyList => Boolean = ammonite.$sess.cmd17$Helper$$Lambda$3399/0x00000008015ca208@4272035a
In [18]:
val anonIsEmptyFun: MyList => Boolean = {
    case MyNil => true
    case MyCons(_, _) => false
}
Out[18]:
anonIsEmptyFun: MyList => Boolean = ammonite.$sess.cmd18$Helper$$Lambda$3404/0x00000008015caec0@6ce23e45

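In other words, when you have the pattern

 (x: Type) => x match {
     case .. =>
     case .. =>
     ...
 }

you can instead simply say

 {
     case .. =>
     case .. =>
     ...
 }

without writing (x: Type) => x match. This shorthand needs an expected function type, such as the MyList => Boolean annotation above, because otherwise Scala has no way to infer the type being matched on.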
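Here is a short sketch of where the expected type comes from in practice, repeating the MyList definitions so the block stands alone. The names describeList, lists and descriptions are illustrative. The shorthand is most often passed straight to a method such as map, where the element type of the list supplies the expected type.

```scala
sealed trait MyList
case object MyNil extends MyList
case class MyCons(x: Int, l: MyList) extends MyList

// The declared type MyList => String tells Scala what is being matched on
val describeList: MyList => String = {
  case MyNil        => "empty"
  case MyCons(x, _) => s"starts with $x"
}

// Passed directly to map: the expected type MyList => ? comes from the list
val lists: List[MyList] = List(MyNil, MyCons(1, MyNil), MyCons(7, MyCons(8, MyNil)))
val descriptions: List[String] = lists.map {
  case MyNil        => "empty"
  case MyCons(x, _) => s"starts with $x"
}
```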

Map, Filter and Fold (Reduce) Operations

In functional programming, iterating with for-loops and while-loops is largely replaced by operations on data structures such as map, filter and fold. In this lecture, we provide a brief overview with some examples. We show how many kinds of loops (or, equivalently, recursive functions) can be systematically replaced by these operations.

Map operation

The idea of a map operation is to apply a function $f$ to every member of a container (e.g., a list, array or map) and return a new container holding the results.

Example 1

We have a list List(1, 3, 4, 5, 6, 110, 12, 2). We wish to compute the square of each element in the list and make a new list with the result.

In [19]:
def recursivelySquareEachElt(l: List[Int], acc: List[Int] = Nil): List[Int] = l match {
    case Nil => acc.reverse // acc holds the squares in reverse order
    case hd::tail => recursivelySquareEachElt(tail, (hd*hd)::acc) // prepend the square of hd
}
Out[19]:
defined function recursivelySquareEachElt
In [21]:
recursivelySquareEachElt(List(10, -3, -3, 14, 2))
Out[21]:
res21: List[Int] = List(100, 9, 9, 196, 4)
In [22]:
recursivelySquareEachElt(List(1, 3, 4, 5, 6, 110, 12, 2), Nil)
Out[22]:
res22: List[Int] = List(1, 9, 16, 25, 36, 12100, 144, 4)

Now the same computation using the map method on lists.

In [23]:
def squareEachElt(l: List[Int]): List[Int] = l.map((x: Int) => x * x)
// (x: Int) => x * x is an anonymous function that squares its argument.
Out[23]:
defined function squareEachElt
In [24]:
squareEachElt(List(1, 3, 4, 5, 6, 110, 12, 2))
Out[24]:
res24: List[Int] = List(1, 9, 16, 25, 36, 12100, 144, 4)

l.map(f) applies the function f to each element of the list l.

  • First of all, the elements of the list must have some type, say A.
  • Next, the function f must be of type A => B.

Last but not least, l.map(f) applies f to every element in the list and returns a new list of type List[B].
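A small sketch where A and B differ, so the result list has a different element type from the input; the names words and lengths are illustrative.

```scala
val words: List[String] = List("Cat", "Dog", "World")

// Here A = String and B = Int, so the function has type String => Int
// and the result has type List[Int]
val lengths: List[Int] = words.map((w: String) => w.length)
```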

Here is a recursive definition of this function. Can you make it tail recursive?

In [25]:
def listMap[A,B](lst: List[A], fun: A => B): List[B] = lst match {
    case Nil => Nil
    case hd :: tail => fun(hd) :: listMap(tail, fun)  // :: is the cons operator in Scala.
}
Out[25]:
defined function listMap
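One possible answer to the exercise, sketched with the same accumulate-then-reverse idea used by recursivelySquareEachElt; the name listMapTR is hypothetical. The @tailrec annotation asks the compiler to reject the definition if the recursive call is not in tail position.

```scala
import scala.annotation.tailrec

def listMapTR[A, B](lst: List[A], fun: A => B): List[B] = {
  // loop carries the results computed so far, in reverse order
  @tailrec
  def loop(rest: List[A], acc: List[B]): List[B] = rest match {
    case Nil        => acc.reverse                 // restore the original order once, at the end
    case hd :: tail => loop(tail, fun(hd) :: acc)  // the recursive call is the last action
  }
  loop(lst, Nil)
}
```

Prepending with :: takes constant time, while appending with ++ copies the accumulator, so building the list backwards and reversing once keeps the whole traversal linear.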
In [26]:
def sayHelloTo(l: List[String]): List[String] = l.map( x => ("Hello "+ x)) // Type of x is inferred by Scala
Out[26]:
defined function sayHelloTo
In [27]:
sayHelloTo(List("Cat", "Dog", "World"))
Out[27]:
res27: List[String] = List("Hello Cat", "Hello Dog", "Hello World")

Filter operation

Just like we have used map to apply a function to each element and make a new container, we use filter to remove all elements that do not satisfy a predicate.

Predicate: A predicate is a function that takes in a value and returns true/false.

l.filter(c) returns a new list containing only the elements of l that satisfy the predicate c; all other elements are dropped.
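A predicate can also be a named function value. The sketch below (the names isEven, nums, evens and oddNums are illustrative) also uses List's filterNot, which keeps the elements for which the predicate is false.

```scala
// A predicate: a function from an element to Boolean
val isEven: Int => Boolean = x => x % 2 == 0

val nums: List[Int] = List(1, 2, 3, 4, 5, 6)
val evens: List[Int] = nums.filter(isEven)      // keeps elements where isEven is true
val oddNums: List[Int] = nums.filterNot(isEven) // keeps elements where isEven is false
```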

In [28]:
def retainAllMultiplesOfThree(l: List[Int], acc: List[Int] = Nil): List[Int] = l match {
    case Nil => acc
    case hd :: tail => {
        val newAcc = if (hd % 3 == 0) { acc ++ List(hd) } else { acc }
        retainAllMultiplesOfThree(tail, newAcc)
    }
}
Out[28]:
defined function retainAllMultiplesOfThree
In [29]:
def retainAllMultiplesOfThree(l: List[Int]): List[Int] = {
    l.filter( x => x%3 == 0 )
}
Out[29]:
defined function retainAllMultiplesOfThree
In [30]:
retainAllMultiplesOfThree(List(10, 15, 18, 12, 3, 1, 5, 7, 8, 14))
Out[30]:
res30: List[Int] = List(15, 18, 12, 3)

Here is how the filter operation can be defined for a list whose elements have an arbitrary type A:

In [31]:
def filterList[A] (lst: List[A], filterFun: A => Boolean): List[A] = lst match {
    case Nil => Nil
    case head :: tail => {
        if (filterFun(head)){
            head :: filterList(tail, filterFun)
        } else {
            filterList(tail, filterFun)
        }
    }
}

// This is not tail recursive. Why? Can you make it tail recursive?
Out[31]:
defined function filterList
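filterList is not tail recursive because, after the recursive call returns, head :: (result) still has to be evaluated. One possible tail-recursive version, sketched with a hypothetical name filterListTR, again collects the kept elements in reverse and reverses once at the end.

```scala
import scala.annotation.tailrec

def filterListTR[A](lst: List[A], filterFun: A => Boolean): List[A] = {
  @tailrec
  def loop(rest: List[A], acc: List[A]): List[A] = rest match {
    case Nil => acc.reverse // kept elements were prepended, so reverse them
    case head :: tail =>
      // Keep head only if it satisfies the predicate; either way, recurse in tail position
      val newAcc = if (filterFun(head)) head :: acc else acc
      loop(tail, newAcc)
  }
  loop(lst, Nil)
}
```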

Fold Operations

Fold/reduce operations combine all the elements of a container, one at a time, into a single accumulated result. Take a list $[l_1, l_2, \ldots, l_n]$.

We wish to sum up the numbers in the list. In an imperative style, this is achieved with a loop and an accumulator:

acc = 0
for each item in List
   acc = acc + item
return acc

We can also do it with the foldLeft operator.

As an example, recall the recursive sum of the elements of a list from above.

In [32]:
def recSumOfList(lst: List[Int], sum: Int = 0): Int = lst match {
    case Nil => sum
    case hd::tail => recSumOfList(tail, sum + hd)
}
Out[32]:
defined function recSumOfList

Fold is a tricky operation to wrap one's head around. Scala lists provide two versions of fold.

list.foldLeft (startVal) (fun)

For a list [l1, l2, l3, ..., ln], the call computes the following unrolled expression:

fun(...fun(fun(fun(startVal, l1), l2), l3)..., ln)

This is equivalent to the following Scala code:

var acc = startVal
for (lj <- list)
   acc = fun(acc, lj) // Very important: acc is the first argument and lj is the second argument.

list.foldRight (startVal) (fun)

This iterates over the list from right to left. That is, for a list [l1, l2, l3, ..., ln], the call computes the following unrolled expression:

fun(l1, fun(l2, ...fun(ln-1, fun(ln, startVal))...))

This is equivalent to the following scala code:

var acc = startVal
for (lj <- list.reverse) // Note list is iterated in reverse
   acc = fun(lj, acc) // Very important: acc is the second argument for foldRight
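One way to see both unrolled expressions is to let fun build a string describing each call instead of computing a number. In this sketch, startVal is "s" and the names letters, leftTrace and rightTrace are illustrative.

```scala
val letters: List[String] = List("a", "b", "c")

// foldLeft: the accumulator is the FIRST argument of fun
val leftTrace: String = letters.foldLeft("s")((acc, x) => s"f($acc,$x)")

// foldRight: the accumulator is the SECOND argument of fun
val rightTrace: String = letters.foldRight("s")((x, acc) => s"f($x,$acc)")
```

Here leftTrace is "f(f(f(s,a),b),c)" and rightTrace is "f(a,f(b,f(c,s)))", matching the two unrolled expressions above.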

The fold function has two arguments: startVal and fun.

Why don't we write list.foldLeft(startVal, fun)? Because foldLeft is defined with two separate parameter lists, a way of writing functions with multiple arguments in Scala called curried syntax.

https://alvinalexander.com/scala/fp-book/partially-applied-functions-currying-in-scala

We will talk about currying in detail later on (in a few weeks) and it has nothing to do with Indian cuisine.
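To make the curried shape concrete, here is a sketch of a user-defined fold with three parameter lists; the name myFoldLeft is hypothetical and this is not the standard library's implementation. One practical benefit of this shape is that, by the time Scala type-checks fun, it already knows A and B from the earlier parameter lists, so a call like myFoldLeft(xs)(0)(_ + _) needs no type annotations.

```scala
import scala.annotation.tailrec

// Three parameter lists: the list, the starting value, and the combining function
@tailrec
def myFoldLeft[A, B](list: List[A])(startVal: B)(fun: (B, A) => B): B = list match {
  case Nil        => startVal
  case hd :: tail => myFoldLeft(tail)(fun(startVal, hd))(fun) // new accumulator is fun(acc, hd)
}
```

Because the function is the last parameter list, a multi-line function can also be passed in braces, as in myFoldLeft(xs)(0) { (acc, x) => acc + x }.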

In [33]:
// Fold left with initial value of accumulator = 0
// Every time we have a new list element x and accumulator value acc, update acc by acc + x

def sumList(l: List[Int]): Int = l.foldLeft (0) ((acc, x) => acc + x )
Out[33]:
defined function sumList
In [34]:
sumList(List(1, 2, 3,4, 5, 6, 7, 8, 9, 10))
Out[34]:
res34: Int = 55
In [35]:
def sumFromRight(l: List[Int]) : Int = l.foldRight (0) ((x, acc) => x + acc)
Out[35]:
defined function sumFromRight
In [36]:
sumFromRight((1 to 10).toList)
Out[36]:
res36: Int = 55
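Both folds give 55 for the sum because + is associative and 0 is its identity. With a non-associative function such as subtraction, the grouping matters and the two folds disagree. A minimal sketch with illustrative names:

```scala
val nums: List[Int] = List(1, 2, 3)

// foldLeft groups from the left:   ((0 - 1) - 2) - 3
val leftDiff: Int = nums.foldLeft(0)((acc, x) => acc - x)

// foldRight groups from the right: 1 - (2 - (3 - 0))
val rightDiff: Int = nums.foldRight(0)((x, acc) => x - acc)
```

Here leftDiff is -6 while rightDiff is 2.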
In [37]:
// Let us now write a function `reverseList`

def reverseList(l: List[Int]): List[Int] =
    l.foldLeft (Nil) ((listSoFar: List[Int], elt: Int) => {
        elt::listSoFar // prepend elt, so the first element of l ends up last
    })
Out[37]:
defined function reverseList

In general, it is good practice to specify the type of the accumulator in foldLeft, as foldLeft[List[Int]] does below. Last but not least, note that the anonymous function passed to a fold can be written in case-pattern form.

In [38]:
def reverseList(l: List[Int]): List[Int] = l.foldLeft[List[Int]] (Nil) {
    case (listSoFar: List[Int], elt: Int) => elt::listSoFar
}
Out[38]:
defined function reverseList
In [27]:
reverseList((1 to 10).toList)
Out[27]:
res27: List[Int] = List(10, 9, 8, 7, 6, 5, 4, 3, 2, 1)
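The accumulator of a fold need not be a list or a number. As a sketch with a hypothetical countWords function, here the accumulator is an immutable Map from words to counts, and its type is fixed by the explicitly typed starting value Map.empty[String, Int].

```scala
def countWords(words: List[String]): Map[String, Int] =
  words.foldLeft(Map.empty[String, Int]) { (counts, w) =>
    // Return a new map in which the count for w is one higher (missing words start at 0)
    counts.updated(w, counts.getOrElse(w, 0) + 1)
  }
```

For example, countWords(List("cat", "dog", "cat")) yields Map("cat" -> 2, "dog" -> 1).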
In [39]:
// Putting it all together: rewrite processList using only filter, map and fold (reduce) operations

def processList(lst: List[Int]): Int = {
    lst
        .filter(x => x % 2 != 0)
        .map(x => x * 2)       
        .foldLeft(0)(_ + _)
}
Out[39]:
defined function processList
In [40]:
processList((1 to 20).toList)
Out[40]:
res40: Int = 200

A one-liner

In [44]:
def processList(l: List[Int]): Int = l.filter(x => x%2 != 0).map(x => x*2).foldLeft(0)(_ + _)
Out[44]:
defined function processList
In [45]:
processList((1 to 20).toList)
Out[45]:
res45: Int = 200
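Since fold is also called reduce in these notes, note that Scala lists provide a separate reduce method: a fold that uses the first element as the starting accumulator instead of a given startVal. A short sketch with illustrative names; reduceOption is the safe variant for lists that may be empty, because reduce has no first element to start from on an empty list.

```scala
val nums: List[Int] = (1 to 20).toList

// reduce: the first element serves as the starting accumulator
val total: Int = nums.reduce(_ + _)

// reduceOption returns None for an empty list instead of failing
val emptyTotal: Option[Int] = List.empty[Int].reduceOption(_ + _)
val someTotal: Option[Int] = List(1, 2, 3).reduceOption(_ + _)
```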