jhu

Reputation: 460

Backtracking with state

The list monad provides an excellent abstraction for backtracking in search problems. However, the problem I am facing now is one which involves state in addition to backtracking. (It also involves constraints related to previous choices made in the search path, but I will attack that issue later.)

The following simplified example illustrates the problem. The function sumTo is given a nonnegative integer and a list of pairs of integers. The first element of each pair is a positive integer; the second is the number of copies of that integer available. The search problem is to express the first argument as a sum of integers from the list while respecting these counts. For example, here the integer 8 is expressed in different ways as a sum drawn from five 1s, three 2s and two 4s, with the additional constraint that every summand must be even (so the 1s cannot be used).

λ> sumTo 8 [(1,5), (4,2), (2,3)]
[[4,4],[4,2,2],[2,2,4],[2,4,2]]
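To make the constraints concrete, here is a small checker that accepts a candidate answer only if it satisfies all three of them. The name isValid is hypothetical and not part of the question's code; it is just a sketch of the specification the search has to meet.

```haskell
-- Hypothetical checker for the problem statement: a candidate list of
-- summands is valid when it adds up to n, every summand is even, and no
-- value is used more often than its count allows.
isValid :: Int -> [(Int, Int)] -> [Int] -> Bool
isValid n opts xs =
  sum xs == n
    && all even xs
    && all withinCount xs
  where
    -- how many times value v occurs in the candidate
    used v = length (filter (== v) xs)
    -- how many copies of v are available (0 if v is not listed)
    available v = sum [c | (u, c) <- opts, u == v]
    withinCount v = used v <= available v

main :: IO ()
main = do
  let opts = [(1, 5), (4, 2), (2, 3)]
  print (map (isValid 8 opts) [[4, 4], [2, 4, 2], [2, 2, 2, 2], [1, 1, 2, 4]])
  -- prints [True,True,False,False]
```

[2,2,2,2] fails only because it needs four 2s when three are available, which is exactly the stateful part of the problem.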

The following is my current recursive solution to the problem.

sumTo :: Int -> [(Int, Int)] -> [[Int]]
sumTo = go []
  where
    go :: [(Int, Int)] -> Int -> [(Int, Int)] -> [[Int]]
    go _ 0 _ = [[]] -- base case: success
    go _ _ [] = [] -- base case: out of options, failure 
    -- recursion step: use the first option if it has copies left and
    -- is suitable; append to these the cases where the current option
    -- is not used at this point
    go prevOpts n (opt@(val,cnt):opts) =
      (if cnt > 0 && val <= n && even val
       then map (val:) $ go [] (n - val) $ (val,cnt-1):(prevOpts ++ opts)
       else [])
      ++ go (opt:prevOpts) n opts

While the function seems to work, it is much more complicated than a version without state written in the list monad:

import Control.Monad (guard)

sumToN :: Int -> [Int] -> [[Int]]
sumToN 0 _ = [[]]
sumToN n opts = do
  val <- opts
  guard $ val <= n
  guard $ even val
  map (val:) $ sumToN (n - val) opts

Since it has no count constraints, this version finds one additional solution, [2,2,2,2], which uses four 2s.

λ> sumToN 8 [1, 4, 2]
[[4,4],[4,2,2],[2,4,2],[2,2,4],[2,2,2,2]]
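One way to recover the count constraints without any state is "generate and test": run the stateless search and then discard solutions that use a value too often. The sketch below is only an illustration (sumToFiltered is a hypothetical name); it reproduces the constrained results, but it rejects over-used values only after a complete sum has been built, so it can explore many branches that the stateful versions prune early.

```haskell
import Control.Monad (guard)

-- The stateless search from the question, unchanged apart from the import.
sumToN :: Int -> [Int] -> [[Int]]
sumToN 0 _ = [[]]
sumToN n opts = do
  val <- opts
  guard $ val <= n
  guard $ even val
  map (val:) $ sumToN (n - val) opts

-- Hypothetical generate-and-test variant: enumerate sums ignoring the
-- counts, then keep only those that respect every count.
sumToFiltered :: Int -> [(Int, Int)] -> [[Int]]
sumToFiltered n opts = filter respectsCounts (sumToN n (map fst opts))
  where
    respectsCounts xs =
      and [length (filter (== v) xs) <= c | (v, c) <- opts]

main :: IO ()
main = print (sumToFiltered 8 [(1, 5), (4, 2), (2, 3)])
-- prints [[4,4],[4,2,2],[2,4,2],[2,2,4]]
```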

Now I am wondering whether some higher-order abstraction, such as StateT or something similar, could be used to simplify backtracking with state constraints of this kind.

Upvotes: 4

Views: 545

Answers (2)

castletheperson

Reputation: 33506

It's not much work to add the StateT monad transformer to your clean solution. You just need to add a layer of wrapping and unwrapping to lift the values into the StateT type, and then take them back out using evalStateT.

Your code would also benefit from internally using a more specialized type for the opts than [(Int, Int)]. MultiSet (from the multiset package) would be a good choice, since it automatically keeps track of how many copies of each element remain.

Here's a tested example of what it could look like:

import Control.Monad (guard)
import Control.Monad.State (StateT, evalStateT, get, modify)
import Control.Monad.Trans.Class (lift)
import Data.MultiSet (MultiSet, fromOccurList, distinctElems, delete)

sumToN :: Int -> [(Int, Int)] -> [[Int]]
sumToN nStart optsStart =
    evalStateT (go nStart) (fromOccurList optsStart)
  where
    go :: Int -> StateT (MultiSet Int) [] [Int]
    go 0 = return []
    go n = do
        val <- lift . distinctElems =<< get
        guard (val <= n && even val)
        modify (delete val)
        (val:) <$> go (n - val)

λ> sumToN 8 [(1,5), (4,2), (2,3)]
[[2,2,4],[2,4,2],[4,2,2],[4,4]]

And actually, the StateT isn't benefiting us very much here. You could refactor it to take the MultiSet Int as a parameter and it would work just as well.

import Control.Monad (guard)
import Data.MultiSet (fromOccurList, distinctElems, delete)

sumToN :: Int -> [(Int, Int)] -> [[Int]]
sumToN nStart optsStart =
    go nStart (fromOccurList optsStart)
  where
    go 0 _    = return []
    go n opts = do
        val <- distinctElems opts
        guard (val <= n && even val)
        (val:) <$> go (n - val) (delete val opts)
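If the multiset package is not available, the same parameter-passing structure works with Data.Map.Strict from containers, mapping each value to its remaining count. This is a swapped-in data structure, not part of the answer above, and sumToMap and useOne are hypothetical names; a minimal sketch:

```haskell
import Control.Monad (guard)
import qualified Data.Map.Strict as Map

-- Sketch assuming Data.Map.Strict in place of Data.MultiSet: each key is a
-- value, each entry its remaining count; keys are dropped at zero.
sumToMap :: Int -> [(Int, Int)] -> [[Int]]
sumToMap nStart optsStart =
    go nStart (Map.fromListWith (+) [(v, c) | (v, c) <- optsStart, c > 0])
  where
    go :: Int -> Map.Map Int Int -> [[Int]]
    go 0 _    = return []
    go n opts = do
        val <- Map.keys opts          -- distinct values still available
        guard (val <= n && even val)
        (val:) <$> go (n - val) (useOne val opts)

    -- Consume one copy of val, removing the key when no copies remain.
    useOne :: Int -> Map.Map Int Int -> Map.Map Int Int
    useOne = Map.update (\c -> if c > 1 then Just (c - 1) else Nothing)

main :: IO ()
main = print (sumToMap 8 [(1, 5), (4, 2), (2, 3)])
-- prints [[2,2,4],[2,4,2],[4,2,2],[4,4]]
```

Because Map.keys returns keys in ascending order, the results come out in the same order as with distinctElems.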

Upvotes: 4

Li-yao Xia

Reputation: 33519

There are two versions below: the first uses only lists, the second uses StateT.

import Control.Monad (guard)
import Control.Monad.State (StateT (..), evalStateT)

The list type is the type of nondeterministic computations.

Given a multiset of elements (in compact form, as a list of (element, nb_copies) pairs), we can pick any one of them and return it together with the updated set. The result is a pair (Int, [(Int, Int)]). As a regular function, pick returns all possible results of that action.

Internally, we can also read the definition from an "imperative" point of view. If the list is empty, there is nothing to pick (the empty list is the failing computation). Otherwise, there is at least one element x (counts are assumed positive, so i > 0). Then we either pick one x (pickx), or we pick one element from the rest (pickxs), being careful to put (x, i) back in front of the remaining set afterwards.

pick :: [(Int, Int)] -> [(Int, [(Int, Int)])]
pick [] = []
pick ((x, i) : xs) = pickx ++ pickxs
  where
    pickx = if i == 1 then [ (x, xs) ] else [ (x, (x, i-1) : xs) ]
    pickxs = do
      (x', xs') <- pick xs
      return (x', (x, i) : xs')
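To see what pick returns, here is a small run on a two-entry set (the input is an arbitrary choice for illustration; pick is repeated so the block compiles on its own). The value 4 has two copies, so picking it decrements its count; 2 has a single copy, so picking it drops its entry.

```haskell
-- pick from the answer above, repeated so this block is self-contained.
pick :: [(Int, Int)] -> [(Int, [(Int, Int)])]
pick [] = []
pick ((x, i) : xs) = pickx ++ pickxs
  where
    pickx = if i == 1 then [ (x, xs) ] else [ (x, (x, i-1) : xs) ]
    pickxs = do
      (x', xs') <- pick xs
      return (x', (x, i) : xs')

main :: IO ()
main = mapM_ print (pick [(4, 2), (2, 1)])
-- prints:
-- (4,[(4,1),(2,1)])
-- (2,[(4,2)])
```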

Then sumTo is defined as follows: if n = 0 then the only solution is the empty sum ([]) and we return it. Otherwise, we pick one element i from the set, check its validity, and recursively look for a solution for n-i with the updated set.

sumTo :: Int -> [(Int, Int)] -> [[Int]]
sumTo = go
  where
    go 0 _ = return []
    go n xs = do
      (i, xs') <- pick xs
      guard $ i <= n
      guard $ even i
      s' <- go (n-i) xs'
      return (i : s')

Now, threading the set around by hand can be tedious. StateT turns a type of computation into a stateful one. [] is nondeterministic computation, so StateT s [] is stateful nondeterministic computation, with state type s. Here the state will be the set of remaining elements.
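The key property for backtracking is that in StateT s [] every branch carries its own copy of the state, so an update made in one branch is never seen by its siblings. A minimal sketch, assuming the StateT from the transformers package (the type that mtl re-exports); branches is a hypothetical name:

```haskell
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State (StateT, modify, runStateT)

-- Each alternative produced by lift [1, 2, 3] continues with its own
-- state, so the modify in one branch does not leak into the others.
branches :: StateT Int [] Int
branches = do
  x <- lift [1, 2, 3]   -- nondeterministic choice
  modify (+ x)          -- update the state in this branch only
  return x

main :: IO ()
main = print (runStateT branches 10)
-- prints [(1,11),(2,12),(3,13)]
```

Every branch starts from 10, which is exactly the "undo on backtrack" behaviour the count bookkeeping needs.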

Interestingly, pick can directly be interpreted as such a stateful computation. The intuition is that executing pickState removes an element from the state, which updates the state, and returns that element. pickState automatically fails if there are no more elements.

pickState :: StateT [(Int, Int)] [] Int
pickState = StateT pick
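To see the state being threaded, the sketch below runs pickState twice in a row on a set with one 4 and one 2 (pickTwo and the input are hypothetical, chosen for illustration). The second pick only sees what the first one left, so (4,4) and (2,2) never appear.

```haskell
import Control.Monad.Trans.State (StateT (..))

-- pick and pickState from the answer above, repeated for self-containment.
pick :: [(Int, Int)] -> [(Int, [(Int, Int)])]
pick [] = []
pick ((x, i) : xs) = pickx ++ pickxs
  where
    pickx = if i == 1 then [ (x, xs) ] else [ (x, (x, i-1) : xs) ]
    pickxs = do
      (x', xs') <- pick xs
      return (x', (x, i) : xs')

pickState :: StateT [(Int, Int)] [] Int
pickState = StateT pick

-- Hypothetical helper: two picks in sequence, the second one running
-- on the set left over by the first.
pickTwo :: StateT [(Int, Int)] [] (Int, Int)
pickTwo = (,) <$> pickState <*> pickState

main :: IO ()
main = print (runStateT pickTwo [(4, 1), (2, 1)])
-- prints [((4,2),[]),((2,4),[])]
```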

Then we repeatedly pick elements until we reach 0.

sumToState :: Int -> StateT [(Int, Int)] [] [Int]
sumToState = go
  where
    go 0 = return []
    go n = do
      i <- pickState
      guard $ i <= n
      guard $ even i
      s' <- go (n-i)
      return (i : s')

main = do
  let n = 8
      xs = [(1, 5), (4, 2), (2, 3)]
  print $ sumTo n xs
  print $ evalStateT (sumToState n) xs


Upvotes: 5
