ナップサック問題をHaskellとScalaで

ナップサック問題という大昔からある有名な問題があります。怪盗が重量制限のあるナップサックにできるだけ物を詰め込んで、詰め込んだ価値を最大化する問題です。

そのための、教科書的な解法は、動的計画法を使うことです。プログラミングコンテスト(IOIやICPCなど)では非常に良く出るアルゴリズムです。

まずは、教科書的なボトムアップの解法。Scalaで書いています。

object KnapsackBottomUp extends Application {
  val goods = List((3,1), (4,2), (5,3))
  val n = Integer.parseInt(Console.readLine)
  val solved = new Array[int](n + 1)
  for(weight <- 1 to n) {
    solved(weight) = goods.map(
        g => if(weight < g._1) 0 else solved(weight - g._1) + g._2).
      reduceLeft((a:int,b:int) => Math.max(a,b))
  }
  println(solved(n))
}

goods は(重さ、価値)です。標準入力から重量制限を読み込み、標準出力に総価値を返します。solvedという配列に、1からnまで順番に代入していき、総重量が大きいのは、過去に解いた solved のより小さい重量制限を使って解きます。高校の数学的帰納法のように部分を解いて、そこから、だんだんと伸ばしていくテクニックです。これを闇雲にやると莫大な計算量になりますが、前から順番にとくと非常に高速に解けます。これが、教科書的な動的計画法

次は、それをトップダウンに変えた物。Scalaです。

object KnapsackTopDown extends Application {
  val goods = List((3,1), (4,2), (5,3))
  val n = Integer.parseInt(Console.readLine)
  val cache = new Array[int](n + 1)
  def calc(weight:int):int = {
    if(cache(weight) == 0) 
      cache(weight) = goods.map(
          g => if(weight < g._1) 0 else calc(weight - g._1) + g._2).
        reduceLeft((a:int,b:int) => Math.max(a,b))
    return cache(weight)
  }
  println(calc(n))
}

自分が最終的に求めたい総重量から、1つ1つ品物を試していきながら、部分問題に分割していっています。分割統治法です。こちらの方が、より素直な解法であり、英語版のWikipediaではまずこちらが紹介されています。ただし、この方法を使うには、一度計算した結果をキャッシュしないと、計算量が激増します。キャッシュを素直に作ると可変Mapです。今回は、高速化のために配列をMap代わりに使っています。ただし、可変Mapは本来は純粋関数型言語では許されていない物です。

さて、次は、純粋関数型言語 HaskellHaskell は可変Mapが作るのがやりづらい代わりに、一度関数を呼び出したら、その結果をキャッシュしてくれます(正確には遅延評価しているだけだけど)。参照透過性、つまり「引数が決まれば必ず返値が決まる」が成り立つからなせる技です。(この部分、うそです)

上記のトップダウンコードを Haskell で書きます。

goods = [(3,1), (4,2), (5,3)]
calc weight = maximum $ map 
  (\(w,v) -> if weight < w then 0 else calc(weight - w) + v) goods
main = do n <- readLn
          print $ calc n

cache というのは無くなっています。

でも、ここで問題発生です!これ、n=45あたりから、物凄くメモリを消費するようになります。遅延評価なため、まず、計算結果の「木」を作ると思うんですが、それが巨大になりすぎるようです。どうすればいいの?

追記:Memoization - にゃあさんの戯言日記 これが答え?こんな複雑な方法しかないの?それなら、純粋関数型言語つかえねぇ〜!(この問題に関しては)非純粋の方がいいや。

ちなみに、ボトムアップの方は普通にn=1000でも処理してくれます。(みずしまさんが再帰を使わずにコメント欄に書いてくださいました)

goods = [(3,1), (4,2), (5,3)]
calc weight n solved =
  if weight > n then (solved !! 0)
  else calc (weight + 1) n (va : solved)
  where 
    va = maximum $ map 
      (\(w,v) -> if weight < w then 0 else (solved !! (w - 1)) + v) 
      goods
main = do n <- readLn
          print $ calc 1 n [0]

うーん。トップダウンだと末尾再帰にならないことが色々と問題を生んでいる気がする…

さらに追記:みずしまさんのコメント欄に書かれたコード。このコードは凄いです!遅延評価だとトップダウンだし、正格評価(非遅延評価)だと、ボトムアップになります。

import Array
goods = [(3,1), (4,2), (5,3)]
calc n = c ! n
  where 
    c = array (0, n) [(weight, maximum $ map 
          (\(w, v) -> if weight < w then 0 else c ! (weight - w) + v) 
          goods) | weight <- [0..n]]
main = do n <- readLn
          print $ calc n