Solving Dynamic Programming problems with Functional Programming

(This notebook is intended to go along with this blog post)

First, let's define some testing code:

In [1]:
object Tester {

    /** A knapsack instance together with the selection a solver must match.
      *
      * `expectedItems(k)` is multiplied by `values(k)` to obtain the expected
      * profit; in the cases below it is 1 for a chosen item and 0 otherwise.
      */
    case class TestCase(maxWeight: Int, values: Vector[Int], weights: Vector[Int], expectedItems: Vector[Int])

    /** Checks one case: `f(maxWeight, values, weights)` must equal the profit of
      * `expectedItems`. Prints "Test Passed!" on success; otherwise `assert`
      * throws an `AssertionError` carrying both numbers.
      *
      * @param f        solver taking (maxWeight, values, weights) and returning a total value
      * @param testCase the instance to check
      */
    def testKnapsack(f: (Int, Vector[Int], Vector[Int]) => Int)(testCase: TestCase): Unit = {
        val pairs = testCase.values zip testCase.expectedItems
        val expectedProfit = pairs.foldLeft(0) { case (total, (itemValue, taken)) =>
            total + itemValue * taken
        }
        val returned = f(testCase.maxWeight, testCase.values, testCase.weights)
        assert(returned == expectedProfit, s"Expected $expectedProfit got $returned")
        println("Test Passed!")
    }

    // Reference instances; `expectedItems` marks the selection whose profit a solver must reach.
    val testCase1 = TestCase(
        maxWeight = 165,
        values = Vector(92, 57, 49, 68, 60, 43, 67, 84, 87, 72),
        weights = Vector(23, 31, 29, 44, 53, 38, 63, 85, 89, 82),
        expectedItems = Vector(1, 1, 1, 1, 0, 1, 0, 0, 0, 0)
    )

    val testCase2 = TestCase(
        maxWeight = 26,
        values = Vector(24, 13, 23, 15, 16),
        weights = Vector(12, 7, 11, 8, 9),
        expectedItems = Vector(0, 1, 1, 1, 0)
    )

    val testCase3 = TestCase(
        maxWeight = 190,
        values = Vector(50, 50, 64, 46, 50, 5),
        weights = Vector(56, 59, 80, 64, 75, 17),
        expectedItems = Vector(1, 1, 0, 0, 1, 0)
    )

    val testCase4 = TestCase(
        maxWeight = 50,
        values = Vector(70, 20, 39, 37, 7, 5, 10),
        weights = Vector(31, 10, 20, 19, 4, 3, 6),
        expectedItems = Vector(1, 0, 0, 1, 0, 0, 0)
    )

    val testCase5 = TestCase(
        maxWeight = 104,
        values = Vector(350, 400, 450, 20, 70, 8, 5, 5),
        weights = Vector(25, 35, 45, 5, 25, 3, 2, 2),
        expectedItems = Vector(1, 0, 1, 1, 1, 0, 1, 1)
    )

    val testCase6 = TestCase(
        maxWeight = 170,
        values = Vector(442, 525, 511, 593, 546, 564, 617),
        weights = Vector(41, 50, 49, 59, 55, 57, 60),
        expectedItems = Vector(0, 1, 0, 1, 0, 0, 1)
    )

    /** Runs `f` against all six cases in order; the first mismatch aborts with an `AssertionError`. */
    def test(f: (Int, Vector[Int], Vector[Int]) => Int): Unit = {
        val allCases = List(testCase1, testCase2, testCase3, testCase4, testCase5, testCase6)
        allCases.foreach(testCase => testKnapsack(f)(testCase))
    }
}
Out[1]:
defined object Tester

The imperative approach

The usual imperative approach relies on a mutable data structure (here, a 2D table of partial solutions):

In [2]:
/** Bottom-up 0/1 knapsack over a mutable (n+1) x (maxWeight+1) table.
  *
  * `solutions(i)(j)` is the best total value obtainable from the first `i`
  * items with capacity `j`; row 0 (no items) stays 0. Column 0 is computed
  * like every other column rather than left at 0, so items of weight 0 are
  * counted correctly.
  *
  * @param maxWeight knapsack capacity; must be non-negative
  * @param value     value of each item
  * @param weight    weight of each item, aligned with `value`; assumed non-negative
  * @return the maximum total value of a subset whose total weight is at most `maxWeight`
  * @throws IllegalArgumentException if `maxWeight` is negative or the vectors differ in length
  */
def knapsack(maxWeight: Int, value: Vector[Int], weight: Vector[Int]): Int = {
    require(maxWeight >= 0, "maxWeight must be non-negative")
    require(value.length == weight.length, "value and weight must have the same length")
    val n = value.length
    val solutions: Array[Array[Int]] = Array.fill(n + 1, maxWeight + 1)(0)
    (1 to n) foreach { i =>
        // Start at capacity 0: an item of weight 0 fits even there.
        (0 to maxWeight) foreach { j =>
            solutions(i)(j) = if (j - weight(i - 1) >= 0) {
                Math.max(solutions(i - 1)(j), solutions(i - 1)(j - weight(i - 1)) + value(i - 1))
            } else {
                solutions(i - 1)(j)
            }
        }
    }
    solutions(n)(maxWeight)
}


// Check the definition above against every Tester case; each success prints "Test Passed!".
Tester.test(knapsack)
Test Passed!
Test Passed!
Test Passed!
Test Passed!
Test Passed!
Test Passed!
Out[2]:
defined function knapsack

Each row of the table depends only on the row above it, so we really only need to keep the last row:

In [3]:
/** 0/1 knapsack that keeps only the most recent row of the DP table.
  *
  * Row `i` depends only on row `i - 1`, so each iteration builds a fresh row
  * from the previous one and then replaces it. Capacity 0 is computed like
  * every other capacity rather than left at 0, so items of weight 0 are
  * counted correctly.
  *
  * @param maxWeight knapsack capacity; must be non-negative
  * @param value     value of each item
  * @param weight    weight of each item, aligned with `value`; assumed non-negative
  * @return the maximum total value of a subset whose total weight is at most `maxWeight`
  * @throws IllegalArgumentException if `maxWeight` is negative or the vectors differ in length
  */
def knapsack(maxWeight: Int, value: Vector[Int], weight: Vector[Int]): Int = {
    require(maxWeight >= 0, "maxWeight must be non-negative")
    require(value.length == weight.length, "value and weight must have the same length")
    val n = value.length
    var solutions: Array[Int] = Array.fill(maxWeight + 1)(0)
    (1 to n) foreach { i =>
        val newSolutions = Array.fill(maxWeight + 1)(0)
        // Start at capacity 0: an item of weight 0 fits even there.
        (0 to maxWeight) foreach { j =>
            newSolutions(j) = if (j - weight(i - 1) >= 0) {
                Math.max(solutions(j), solutions(j - weight(i - 1)) + value(i - 1))
            } else {
                solutions(j)
            }
        }
        solutions = newSolutions
    }
    solutions(maxWeight)
}

// Re-run all Tester cases against the row-at-a-time version above.
Tester.test(knapsack)
Test Passed!
Test Passed!
Test Passed!
Test Passed!
Test Passed!
Test Passed!
Out[3]:
defined function knapsack

A functional (but complex) approach

We could use the State monad to thread the table through the computation instead of mutating it:

In [4]:
import $ivy.`org.typelevel::cats:0.7.2`
import cats._, cats.instances.all._, cats.syntax.traverse._, cats.syntax.foldable._
import cats.data.State

/** Returns a copy of `solutions` whose cell at row `i`, column `j` holds `newVal`.
  *
  * The grid argument sits in its own parameter list so that
  * `setSolution(i, j, v)` can be passed wherever a
  * `Vector[Vector[Int]] => Vector[Vector[Int]]` is expected (e.g. `State.modify`).
  */
def setSolution(i: Int, j: Int, newVal: Int)
               (solutions: Vector[Vector[Int]]): Vector[Vector[Int]] = {
    val updatedRow = solutions(i).updated(j, newVal)
    solutions.updated(i, updatedRow)
}

/** 0/1 knapsack written with the cats `State` monad over an immutable table.
  *
  * The state is the (n+1) x (maxWeight+1) table of partial solutions; every
  * step reads the table with `State.get` and writes one cell with
  * `setSolution`. Column 0 is filled like every other column, so items of
  * weight 0 are counted correctly.
  *
  * @param maxWeight knapsack capacity; must be non-negative
  * @param value     value of each item
  * @param weight    weight of each item, aligned with `value`; assumed non-negative
  * @return the maximum total value of a subset whose total weight is at most `maxWeight`
  * @throws IllegalArgumentException if `maxWeight` is negative or the vectors differ in length
  */
def knapsack(maxWeight: Int, value: Vector[Int], weight: Vector[Int]): Int = {
    require(maxWeight >= 0, "maxWeight must be non-negative")
    require(value.length == weight.length, "value and weight must have the same length")
    val n = value.length
    val initialState: Vector[Vector[Int]] = Vector.fill(n + 1, maxWeight + 1)(0)
    val st: State[Vector[Vector[Int]], Unit] = (1 to n).toList.traverseU_ { i =>
        // Start at capacity 0: an item of weight 0 fits even there.
        (0 to maxWeight).toList.traverseU_ { j =>
            for {
                solutions <- State.get[Vector[Vector[Int]]]
                newVal = if (j - weight(i - 1) >= 0) {
                    Math.max(solutions(i - 1)(j), solutions(i - 1)(j - weight(i - 1)) + value(i - 1))
                } else {
                    solutions(i - 1)(j)
                }
                _ <- State.modify(setSolution(i, j, newVal))
            } yield ()
        }
    }
    val solution = st.runS(initialState).value
    solution(n)(maxWeight)
}

// Re-run all Tester cases against the State-monad (full table) version above.
Tester.test(knapsack)
Test Passed!
Test Passed!
Test Passed!
Test Passed!
Test Passed!
Test Passed!
Out[4]:
import $ivy.$                          

import cats._, cats.instances.all._, cats.syntax.traverse._, cats.syntax.foldable._

import cats.data.State


defined function setSolution
defined function knapsack

Once again, we only need to keep the last row:

In [5]:
/** 0/1 knapsack with the cats `State` monad, keeping only the latest DP row as state.
  *
  * Step `i` reads the row for the first `i - 1` items and replaces it with the
  * row for the first `i` items. Capacity 0 is computed like every other
  * capacity rather than pinned to 0, so items of weight 0 are counted correctly.
  *
  * @param maxWeight knapsack capacity; must be non-negative
  * @param value     value of each item
  * @param weight    weight of each item, aligned with `value`; assumed non-negative
  * @return the maximum total value of a subset whose total weight is at most `maxWeight`
  * @throws IllegalArgumentException if `maxWeight` is negative or the vectors differ in length
  */
def knapsack(maxWeight: Int, value: Vector[Int], weight: Vector[Int]): Int = {
    require(maxWeight >= 0, "maxWeight must be non-negative")
    require(value.length == weight.length, "value and weight must have the same length")
    val n = value.length
    val initialState: Vector[Int] = Vector.fill(maxWeight + 1)(0)
    val st: State[Vector[Int], Unit] = (1 to n).toList.traverseU_ { i =>
        for {
            solutions <- State.get[Vector[Int]]
            // Includes capacity 0: an item of weight 0 fits even there.
            newSolutions = (0 to maxWeight).map { j: Int =>
                if (j - weight(i - 1) >= 0) {
                    Math.max(solutions(j), solutions(j - weight(i - 1)) + value(i - 1))
                } else {
                    solutions(j)
                }
            }.toVector
            _ <- State.set(newSolutions)
        } yield ()
    }
    val solution = st.runS(initialState).value
    solution(maxWeight)
}

// Re-run all Tester cases against the State-monad (single row) version above.
Tester.test(knapsack)
Test Passed!
Test Passed!
Test Passed!
Test Passed!
Test Passed!
Test Passed!
Out[5]:
defined function knapsack

A functional and simple approach

There is a simpler way to do it that is still functional: just fold over the rows:

In [6]:
/** 0/1 knapsack as a `foldLeft` over items, threading one immutable DP row.
  *
  * Each fold step turns the row for the first `i - 1` items into the row for
  * the first `i` items. Capacity 0 is computed like every other capacity
  * rather than pinned to 0, so items of weight 0 are counted correctly.
  *
  * @param maxWeight knapsack capacity; must be non-negative
  * @param value     value of each item
  * @param weight    weight of each item, aligned with `value`; assumed non-negative
  * @return the maximum total value of a subset whose total weight is at most `maxWeight`
  * @throws IllegalArgumentException if `maxWeight` is negative or the vectors differ in length
  */
def knapsack(maxWeight: Int, value: Vector[Int], weight: Vector[Int]): Int = {
    require(maxWeight >= 0, "maxWeight must be non-negative")
    require(value.length == weight.length, "value and weight must have the same length")
    val n = value.length
    val firstRow = Vector.fill(maxWeight + 1)(0)
    val lastRow = (1 to n).foldLeft(firstRow) { (upperRow, i) =>
        // Start at capacity 0: an item of weight 0 fits even there.
        (0 to maxWeight).map { j =>
            if (j - weight(i - 1) >= 0) {
                Math.max(
                    upperRow(j),
                    upperRow(j - weight(i - 1)) + value(i - 1)
                )
            } else {
                upperRow(j)
            }
        }.toVector
    }
    lastRow(maxWeight)
}

// Re-run all Tester cases against the foldLeft version above.
Tester.test(knapsack)
Test Passed!
Test Passed!
Test Passed!
Test Passed!
Test Passed!
Test Passed!
Out[6]:
defined function knapsack

When the dependencies have another shape

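In this other problem (finding the minimal path sum from the top-left to the bottom-right corner of a grid, moving only right or down), the dependencies have another shape: each cell depends both on the cell above it and on the cell to its left. How can we do it this time?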

In [7]:
// 5x5 example grid for `minPath`; the assertion at the end of this cell checks its result.
val values = Vector(
    Vector(131, 673, 234, 103, 18 ),
    Vector(201, 96 , 342, 965, 150),
    Vector(630, 803, 746, 422, 111),
    Vector(537, 699, 497, 121, 956),
    Vector(805, 732, 524,  37, 331)
)

/** Running totals of `nums`: element k is `nums(0) + ... + nums(k)`. */
def accumulativeSum(nums: Vector[Int]): Vector[Int] =
    nums.scanLeft(0)(_ + _).drop(1)

/** Minimal path sum from the top-left to the bottom-right cell of a grid,
  * moving only right or down.
  *
  * Rows are processed with `foldLeft` (each row depends on the one above) and
  * each row is built with `scanLeft` (each cell also depends on its left
  * neighbour), so only one row of partial sums is kept at a time.
  *
  * @param values non-empty rectangular grid of cell values
  * @return the smallest total of cell values along any right/down path
  * @throws IllegalArgumentException if the grid has no cells or its rows differ in length
  */
def minPath(values: Vector[Vector[Int]]): Int = {
    require(values.nonEmpty && values.head.nonEmpty, "grid must contain at least one cell")
    val n = values.head.length
    // Ragged rows would otherwise be silently truncated or fail with an index error.
    require(values.forall(_.length == n), "all rows must have the same length")
    // The first row can only be reached by moving right, the first column only by moving down.
    val firstRow = accumulativeSum(values.head)
    val firstColumn = accumulativeSum(values.map(_.head))
    val lastRow = (1 until values.length).foldLeft(firstRow) { (upperRow, i) =>
        (1 until n).scanLeft(firstColumn(i)) { (leftSolution, j) =>
            Math.min(leftSolution, upperRow(j)) + values(i)(j)
        }.toVector
    }
    lastRow(n - 1)
}

// Expected minimal right/down path sum for the 5x5 example grid defined at the top of this cell.
assert(minPath(values) == 2427)
Out[7]:
values: Vector[Vector[Int]] = Vector(
  Vector(131, 673, 234, 103, 18),
  Vector(201, 96, 342, 965, 150),
  Vector(630, 803, 746, 422, 111),
  Vector(537, 699, 497, 121, 956),
  Vector(805, 732, 524, 37, 331)
)
defined function accumulativeSum
defined function minPath