型推論の実装

http://www.geocities.jp/lethevert/clean/gettingStarted13.html の内容（型推論）を実装してみました（GeoCities Japan は 2019 年にサービスを終了したため、現在このリンクは切れています）。実装はかなりいんちきで、いろいろバグっていると思います。ごめんなさい。

f a b c
    | a > 0     = c b
    | otherwise = b

上記の Haskell の関数 f の型を f :: Int -> a -> (a -> a) -> a と推論してくれます。ちょっと感動。

import scala.collection.mutable._

/**
 * A toy, hard-coded type inferencer.
 *
 * It infers the type of `f` in
 *
 *   f a b c
 *       | a > 0     = c b
 *       | otherwise = b
 *
 * in three phases, all run from this object's initializer (it extends
 * `Application`):
 *   1. collect "type equations": pairs of expressions that must share a type;
 *   2. propagate known types along those equations until nothing changes;
 *   3. build a function type for every symbol that is applied as a function.
 * The inferred type of `f` is printed at the end. The implementation is
 * deliberately simplistic (no real unification) and only understands the
 * constructs used by this one equation.
 */
object TypeInference extends Application {
  // Expressions. `exprType` is declared on Expr, not in the case classes'
  // parameter lists, so it does not take part in their generated
  // equals/hashCode: structurally equal expressions are the same map/set key
  // even when they are different instances.
  class Expr(var exprType:Type)
  // Application of the function `name` to the arguments `vars`.
  case class Func(name:Symbol, vars:List[Expr]) extends Expr(TypeAny)
  // A variable or function name.
  case class Symbol(name:String) extends Expr(TypeAny)
  // One guarded alternative: `| cond = body`.
  case class Guard(cond:Expr, body:Expr) extends Expr(TypeAny)
  // The guarded alternatives that make up a function body.
  case class GuardList(list:List[Guard]) extends Expr(TypeAny)
  // Literals, whose types are known up front.
  case class IntValue(v:int) extends Expr(TypeInt)
  case class BoolValue(v:boolean) extends Expr(TypeBool)
  
  // Counter used to name type variables 'a', 'b', ... in creation order
  // (past 'z' the names continue into the following Unicode characters).
  var nameNo = 0
  class Type
  // A type variable; each instance takes the next name when it is constructed.
  case class TypeAny extends Type {
    // Give this TypeAny a name
    val name:String = "" + ('a' + nameNo).asInstanceOf[char]
    nameNo += 1
    override def toString() = "TypeAny(" + name + ")"
    // Type variables are compared by name only.
    override def equals(t:Any) = t match {
      case that:TypeAny => name == that.name
      case _ => false
    }
    override def hashCode() = name.hashCode
  }
  case class TypeInt extends Type
  case class TypeBool extends Type
  // A function type: the last element is the result type, the preceding
  // elements are the argument types in order.
  case class TypeFunc(typeList:List[Type]) extends Type
  
  // The equation `left = right` whose types are inferred.
  class Equation(val left:Expr, val right:Expr) {
    override def toString() = "Equation(" + left + "," + right + ")"
  }
  val symbolA = Symbol("a")
  val symbolB = Symbol("b")
  val symbolC = Symbol("c")
  val symbolF = Symbol("f")
  // Encodes:  f a b c | a > 0 = c b | True = b   (`otherwise` is written as True)
  val eq = new Equation(
      Func(symbolF, List(symbolA, symbolB, symbolC)),
      GuardList(List(
          Guard(
              Func(Symbol(">"), List(symbolA, IntValue(0))), 
              Func(symbolC, List(symbolB))),
          Guard(
              BoolValue(true),
              symbolB))))
  //println(eq);
              
  // ---- Phase 1: build the type equations ----
  //println("----")
  // typeEq(x) holds every expression that must have the same type as x.
  // addTypeEq keeps the relation symmetric.
  var typeEq = Map[Expr, Set[Expr]]()
  // Record that ex1 and ex2 must have the same type (in both directions).
  def addTypeEq(ex1:Expr, ex2:Expr) {
    if(typeEq.contains(ex1)) {
      typeEq(ex1) += ex2
    } else {
      typeEq(ex1) = Set(ex2)
    }
    if(typeEq.contains(ex2)) {
      typeEq(ex2) += ex1
    } else {
      typeEq(ex2) = Set(ex1)
    }
  }
  // Add the type equations implied by `ex`, recursing into guards:
  //  - `x > y`: both operands have the same type;
  //  - a guard list: the list and each of its guards share one type;
  //  - a guard: it has the type of its body, and its condition is Bool.
  // Other expressions contribute nothing.
  def fillTypeMap(ex:Expr) {
    ex match {
      case f:Func => {
        if(f.name.name == ">") {
          addTypeEq(f.vars(0), f.vars(1))
        }
      }
      case gl:GuardList => {
        for(g1 <- gl.list) {
          addTypeEq(gl, g1)
          // Recurse into each guard exactly once, outside the pairwise loop,
          // so that a list with a single guard (which has no distinct pair)
          // still gets that guard's equations.
          fillTypeMap(g1)
          for(g2 <- gl.list if g1 != g2) {
            addTypeEq(g1, g2)
            //println("==================")
            //printTypeEq()
            //println("==================")
          }
        }
      }
      case g:Guard => {
        addTypeEq(g, g.body)
        fillTypeMap(g.cond)
        fillTypeMap(g.body)
        // Not really the place for this, but fix the condition's type to
        // Bool while we are here.
        g.cond.exprType = TypeBool
      }
      case _ =>
    }
  }
  addTypeEq(eq.left, eq.right)
  fillTypeMap(eq.left)
  fillTypeMap(eq.right)
  //println("typeEq = ");
  // Debug helper: dump every entry of typeEq.
  def printTypeEq() {
    for((k,v) <- typeEq) {
      println(k + " -> " + v)
    }
  }
  //printTypeEq();
  
  // ---- Collect every sub-expression of the equation ----
  //println("----")
  var exprSet = Set[Expr]()
  var exprCount = 0
  // Visit number of each expression. Structurally equal expressions share an
  // entry, so a later visit overwrites an earlier number. Not read by the
  // code below.
  var exprNo = Map[Expr, int]()
  // Add `ex` and, recursively, all of its sub-expressions to exprSet.
  def addExprSet(ex:Expr) {
    exprSet += ex
    exprNo(ex) = exprCount
    exprCount += 1
    ex match {
      case f:Func => {
        addExprSet(f.name)
        for(v <- f.vars) {
          addExprSet(v)
        }
      }
      case g:Guard => {
        addExprSet(g.cond)
        addExprSet(g.body)
      }
      case gl:GuardList => gl.list.foreach(addExprSet(_))
      case _ =>
    }
  }
  addExprSet(eq.left)
  addExprSet(eq.right)
  //println("exprSet = ")
  //for(ex <- exprSet) {
  //  println(ex + " -> " + ex.exprType)
  //}
  
  // ---- Phase 2: rewrite expression -> type until nothing more changes ----
  // For every equation x ~ y with differing types: if y is still a type
  // variable it takes x's type; if neither is a type variable, that is a type
  // error. (The case "x is a variable, y is not" is handled from y's side,
  // because typeEq is symmetric.)
  //println("----")
  var isFound = true
  while(isFound) {
    isFound = false
    for(ex <- exprSet) {
      if(typeEq.contains(ex)) {
        for(ex2 <- typeEq(ex)) {
          if(ex.exprType != ex2.exprType) {
            //println(ex + "," + ex.exprType + " -- " + ex2 + "," + ex2.exprType)
            if(ex2.exprType.isInstanceOf[TypeAny]) {
              ex2.exprType = ex.exprType;
              isFound = true
              //println("new " + ex2 + "," + ex2.exprType)
            } else {
              if(!ex.exprType.isInstanceOf[TypeAny]) {
                error("型エラー");
              }
            }
          }
        }
      }
    }
  }
  // ---- Phase 3: give every applied symbol a function type ----
  // For each application `g v1 ... vn`, g's type becomes
  // TypeFunc(type(v1), ..., type(vn), type of the application). Passes repeat
  // until no symbol's type changes, because a rebuilt function type can feed
  // into another application's argument list.
  // NOTE(review): an argument that is itself an application contributes its
  // function symbol's type rather than the application's (result) type; this
  // looks suspicious but is not exercised by `eq`.
  // NOTE(review): if one symbol is applied at several sites that produce
  // different TypeFuncs, the sites overwrite each other on every pass and this
  // loop never terminates; real unification would report a type error instead.
  var hasChange = true;
  while(hasChange) {
    hasChange = false
    for(ex <- exprSet) {
      ex match {
        case f:Func => { 
          val newType = TypeFunc(f.vars.map({
            ex2 => ex2 match {
              case f2:Func => f2.name.exprType
              case _ => ex2.exprType
            }
          }) ::: List(f.exprType))
          if(f.name.exprType != newType) {
            hasChange = true
          }
          f.name.exprType = newType
        }
        case _ =>
      }
    }
  }
  //println("----")
  //for(ex <- exprSet) {
  //  println(ex + " -> " + ex.exprType)
  //}
  // Print the inferred type of f.
  println(symbolF.exprType)  
}