Infer – HackerRank Solution

In this post, we will solve Infer HackerRank Solution. This problem (Infer) is a part of HackerRank Functional Programming series.

Task

If we know that one is of type int and id is of type forall[a] a -> a, we can infer that id(one) is of type int.

A function fun x y -> x has a generic type of forall[a b] (a, b) -> a.

Let’s write a program to help us infer the type of expression in a given envrionment!

First, we define the syntax of expression:

ident : [_A-Za-z][_A-Za-z0-9]*              // variable names

expr : "let " ident " = " expr " in " expr  // variable defination
     | "fun " argList " -> " expr           // function defination
     | simpleExpr

argList : { 0 or more ident seperated by ' ' }

simpleExpr : '(' expr ')'
           | ident
           | simpleExpr '(' paramList ')'   // function calling

paramList : { 0 or more expr seperated ", " }

Then, we define the syntax of type:

ty : "() -> " ty                            // function without arguments
   | '(' tyList ") -> " ty                  // uncurry function
   | "forall[" argList "]" ty               // generic type
   | simpleTy " -> " ty                     // curry function
   | simpleTy

tyList : { 1 or more ty seperated by ", " }

simpleTy : '(' ty ')'
         | ident
         | simpleTy '[' tyList ']'          // such as list[int]

Hint in parsing:

  • Spacing is strict.
  • Pay attention to avoid dead loop.

Type of given expression should be infered in an environment. The environment is consisted of a set of functions with types:

head: forall[a] list[a] -> a
tail: forall[a] list[a] -> list[a]
nil: forall[a] list[a]
cons: forall[a] (a, list[a]) -> list[a]
cons_curry: forall[a] a -> list[a] -> list[a]
map: forall[a b] (a -> b, list[a]) -> list[b]
map_curry: forall[a b] (a -> b) -> list[a] -> list[b]
one: int
zero: int
succ: int -> int
plus: (int, int) -> int
eq: forall[a] (a, a) -> bool
eq_curry: forall[a] a -> a -> bool
not: bool -> bool
true: bool
false: bool
pair: forall[a b] (a, b) -> pair[a, b]
pair_curry: forall[a b] a -> b -> pair[a, b]
first: forall[a b] pair[a, b] -> a
second: forall[a b] pair[a, b] -> b
id: forall[a] a -> a
const: forall[a b] a -> b -> a
apply: forall[a b] (a -> b, a) -> b
apply_curry: forall[a b] (a -> b) -> a -> b
choose: forall[a] (a, a) -> a
choose_curry: forall[a] a -> a -> a

Sample Input #00

let x = id in x

Sample Output #00

forall[a] a -> a

Explanation #00

x is just id in the environment.

Sample Input #01

fun x -> let y = fun z -> z in y

Sample Output #01

forall[a b] a -> b -> b

Explanation #01

Function with variables which are not bounded in the environment should be generic function. The variable names appear in forall[] should be from a to z subject to their appearance order in type body.

Sample Input #02

choose(fun x y -> x, fun x y -> y)

Sample Output #02

forall[a] (a, a) -> a

Explanation #02

The type of choose is forall[a] (a, a) -> a. So x and y should be of the same type.

Sample Input #03

fun f -> let x = fun g y -> let _ = g(y) in eq(f, g) in x

Sample Output #03

forall[a b] (a -> b) -> (a -> b, a) -> bool

Explanation #03

The longest test case.

Final note:

All given expression are valid, non-recursive and can be infered successfully in given environment. But an optional requirement is that your program should fail on incomplete uncurry version function calling. For example, choose_curry(one) should be infered as int -> int but choose(one) just fail in infering.

Solution – Infer – HackerRank Solution

Scala

import java.util.Scanner

import scala.collection.mutable

trait Expression

object Expression {

  case class FunctionDefinition(params: List[String], body: Expression) extends Expression

  case class Variable(name: String) extends Expression

  case class FunctionCall(function: Expression, params: List[Expression]) extends Expression

  case class VariableDefinition(name: String, definition: Expression, body: Expression) extends Expression

}

object Parser {

  def toLexemes(s: String): List[Lexeme] = {
    val lexemeNames = Map(
      "let" -> Let,
      "in" -> In,
      "fun" -> Fun,
      "forall" -> Forall,
      "=" -> Equal,
      "->" -> Arrow,
      "," -> Comma,
      ":" -> Colon,
      "(" -> OpenRoundBracket,
      ")" -> CloseRoundBracket,
      "[" -> OpenSquareBracket,
      "]" -> CloseSquareBracket
    )

    def parseLexeme(index: Int, predicate: Char => Boolean): (Int, String) = {
      @scala.annotation.tailrec
      def inner(index: Int, sb: StringBuilder = new StringBuilder): (Int, String) = if (index < s.length) {
        val c = s(index)
        if (predicate(c)) inner(index + 1, sb.append(c)) else (index, sb.toString())
      } else (index, sb.toString())

      inner(index)
    }

    case class State(index: Int = 0, acc: List[Lexeme] = Nil)

    @scala.annotation.tailrec
    def inner(state: State): State = if (state.index < s.length) {
      val c = s(state.index)

      val nextState = c match {
        case d: Char if d.isLetter || d == '_' =>
          val (nextIndex, lexemeName) = parseLexeme(state.index, c => c.isLetterOrDigit || c == '_')
          val lexeme = lexemeNames.getOrElse(lexemeName, Identifier(lexemeName))
          State(nextIndex, lexeme :: state.acc)

        case '-' if state.index + 1 < s.length && s(state.index + 1) == '>' =>
          State(state.index + 2, Arrow :: state.acc)

        case d: Char if lexemeNames.contains(d.toString) =>
          val lexeme = lexemeNames(d.toString)
          State(index = state.index + 1, lexeme :: state.acc)

        case _ => state.copy(index = state.index + 1)
      }

      inner(nextState)
    } else state

    inner(State()).acc.reverse
  }

  def toExpression(lexemes: List[Lexeme]): Expression = lexemes match {
    case Let :: Identifier(name) :: Equal :: rest =>
      val (forExpr1, forExpr2) = findCloseBracket(rest, Let, In)
      Expression.VariableDefinition(name, toExpression(forExpr1), toExpression(forExpr2))

    case Fun :: rest =>
      val (forArgs, forExpr) = rest.span(_ != Arrow)
      val args = forArgs.map { case Identifier(name) => name }
      Expression.FunctionDefinition(args, toExpression(forExpr.tail))

    case OpenRoundBracket :: rest =>
      val (forExpr, nextRest) = findCloseBracket(rest, OpenRoundBracket, CloseRoundBracket)
      val expr = toExpression(forExpr)
      if (nextRest.isEmpty) {
        expr
      } else {
        val params = nextRest.tail.init.splitBy(Set(Comma), Set(Comma)).map(toExpression)
        Expression.FunctionCall(expr, params)
      }

    case Identifier(name) :: nextRest =>
      val expr = Expression.Variable(name)
      if (nextRest.isEmpty) {
        expr
      } else {
        @scala.annotation.tailrec
        def toFunctionCall(expr: Expression, nextRest: List[Lexeme]): Expression.FunctionCall = {
          val (forParams, rest2) = findCloseBracket(nextRest.tail, OpenRoundBracket, CloseRoundBracket)
          val params = forParams.splitBy(Set(Comma), Set(Comma)).map(toExpression)

          if (rest2.isEmpty) Expression.FunctionCall(expr, params) else toFunctionCall(Expression.FunctionCall(expr, params), rest2)
        }

        toFunctionCall(expr, nextRest)
      }

    case _ => throw new Exception("Wrong expression")
  }

  def loadEnvironment(s: String): Map[String, Inferer.Type] = {
    def parseLine(s: String): (String, Inferer.Type) = {
      val lexemes = toLexemes(s)
      lexemes match {
        case Identifier(name) :: Colon :: Forall :: rest =>
          val (forParams, nextRest) = rest.tail.span(_ != CloseSquareBracket)
          val params = forParams.collect { case Identifier(a) => a }
          val generics = params.map(name => name -> Inferer.nextVariable).toMap
          name -> parseType(generics, nextRest.tail)

        case Identifier(name) :: Colon :: rest =>
          name -> parseType(Map(), rest)

        case _ =>
          throw new Exception("Wrong type")
      }
    }

    s.split("\n").map(parseLine).toMap
  }

  private def findCloseBracket(lexemes: List[Lexeme], openBracket: Lexeme, closeBracket: Lexeme): (List[Lexeme], List[Lexeme]) = {
    case class Acc(lexemes: List[Lexeme] = Nil, bracketCount: Int = 1)
    @scala.annotation.tailrec
    def inner(lexemes: List[Lexeme], acc: Acc = Acc()): (List[Lexeme], List[Lexeme]) = {
      lexemes match {
        case (lex@`closeBracket`) :: lexemes =>
          if (acc.bracketCount == 1) (acc.lexemes.reverse, lexemes)
          else inner(lexemes, Acc(lex :: acc.lexemes, acc.bracketCount - 1))
        case (lex@`openBracket`) :: lexemes => inner(lexemes, Acc(lex :: acc.lexemes, acc.bracketCount + 1))
        case lex :: lexemes => inner(lexemes, Acc(lex :: acc.lexemes, acc.bracketCount))
        case Nil => throw new Exception("Wrong expression")
      }
    }

    inner(lexemes)
  }

  private def parseType(generics: Map[String, Inferer.Variable], lexemes: List[Lexeme]): Inferer.Type = {
    val res = lexemes match {
      case OpenRoundBracket :: rest =>
        val (forSubTypes, nextRest) = findCloseBracket(rest, OpenRoundBracket, CloseRoundBracket)
        val tys = forSubTypes.splitBy(Set(Comma), Set(Comma)).map(lexemes => parseType(generics, lexemes))
        Inferer.Function(tys, parseType(generics, nextRest.tail))

      case lexemes =>
        val (first, rest) = lexemes.span(_ != Arrow)
        val (id, params) = first.span(_ != OpenSquareBracket)

        val name = id.head.asInstanceOf[Identifier].name
        val firstType = if (params.isEmpty) { //Identifier
          generics.getOrElse(name, Inferer.Operation(name, Seq()))
        } else {
          Inferer.Operation(name, params.collect { case Identifier(a) => generics(a) })
        }

        if (rest.isEmpty)
          firstType
        else
          Inferer.Function(List(firstType), parseType(generics, rest.tail))
    }
    res
  }

  trait Lexeme

  case class Identifier(name: String) extends Lexeme

  case object Let extends Lexeme

  case object In extends Lexeme

  case object Fun extends Lexeme

  case object Forall extends Lexeme

  case object Equal extends Lexeme

  case object Arrow extends Lexeme

  case object Comma extends Lexeme

  implicit class Splitter(lexemes: List[Lexeme]) {
    def splitBy(separators: Set[Lexeme], drop: Set[Lexeme] = Set()): List[List[Lexeme]] = {
      case class Acc(current: List[Lexeme] = Nil, data: List[List[Lexeme]] = Nil)
      val acc = lexemes.foldLeft(Acc())((acc, lex) => {
        val nextCurrent = if (drop.contains(lex)) acc.current else lex :: acc.current
        if (separators.contains(lex)) Acc(Nil, nextCurrent.reverse :: acc.data)
        else Acc(nextCurrent, acc.data)
      })

      (if (acc.current.isEmpty) acc.data else acc.current.reverse :: acc.data).reverse
    }
  }

  case object Colon extends Lexeme

  case object OpenRoundBracket extends Lexeme

  case object CloseRoundBracket extends Lexeme

  case object OpenSquareBracket extends Lexeme

  case object CloseSquareBracket extends Lexeme

}

object Inferer {

  type Environment = Map[String, Type]
  private val arrow = "->"
  private var nextName = 'a'
  private var nextId = 0

  def Function(from: List[Type], to: Type): Operation = Operation(arrow, from.:+(to))

  def nextUniqueName: String = {
    val res = nextName
    nextName = (nextName + 1).toChar
    res.toString
  }

  def nextVariable: Variable = {
    val result = nextId
    nextId += 1
    Variable(result)
  }

  def infer(expression: Expression, environment: Environment, variables: mutable.Set[Variable] = mutable.Set()): Type = expression match {
    case Expression.VariableDefinition(name, definition, body) =>
      val definitionType = infer(definition, environment, variables)
      val nextEnvironment = environment + (name -> definitionType)
      infer(body, nextEnvironment, variables)

    case Expression.FunctionDefinition(params, body) =>
      val paramTypes = params.map(_ => nextVariable)
      val nextEnvironment = params.zip(paramTypes)
        .foldLeft(environment) { case (acc, (v, t)) => acc + (v -> t) }
      val nextVariables = paramTypes.foldLeft(variables)((acc, t) => acc.union(mutable.Set(t)))
      val resultType = infer(body, nextEnvironment, nextVariables)
      Function(paramTypes, resultType)

    case Expression.FunctionCall(function, params) =>
      val functionType = infer(function, environment, variables)
      val paramTypes = params.map(infer(_, environment, variables))
      val resultType = nextVariable
      unify(Function(paramTypes, resultType), functionType)
      resultType

    case Expression.Variable(name) =>
      inferVariableType(environment(name), variables)
  }

  def unify(type0: Type, type1: Type): Unit = {
    (prune(type0), prune(type1)) match {
      case (a: Variable, b) => if (a != b) {
        if (contains(b, a)) throw new Exception("recursion detected")
        a.typeOpt = Some(b)
      }

      case (a: Operation, b: Variable) =>
        unify(b, a)

      case (a: Operation, b: Operation) =>
        if (a.name != b.name || a.params.length != b.params.length)
          throw new Exception(s"Type mismatch: $a =/= $b")
        for (i <- a.params.indices)
          unify(a.params(i), b.params(i))

      case _ => throw new Exception("Wrong types.")
    }
  }

  def prune(t: Type): Type = t match {
    case v: Variable if v.typeOpt.isDefined =>
      val res = prune(v.typeOpt.get)
      v.typeOpt = Some(res)
      res
    case _ => t
  }

  def inferVariableType(t: Type, variables: mutable.Set[Variable]): Type = {
    val map = mutable.Map[Variable, Variable]()

    def inner(t: Type): Type = {
      prune(t) match {
        case v: Variable =>
          if (contains(variables, v)) v
          else map.getOrElseUpdate(v, nextVariable) // generic
        case Operation(name, args) =>
          Operation(name, args.map(inner))
      }
    }

    inner(t)
  }

  def contains(t: Type, v: Variable): Boolean = {
    prune(t) match {
      case `v` => true
      case Operation(_, params) => contains(params, v)
      case _ => false
    }
  }

  def contains(types: Iterable[Type], v: Variable): Boolean = types.exists(contains(_, v))

  private def asString(t: Type, variables: mutable.Set[String]): String = t match {
    case v: Variable => v.typeOpt match {
      case Some(t) =>
        asString(t, variables)
      case None =>
        variables.add(v.name)
        v.name
    }

    case Operation(name, params) =>
      if (params.isEmpty) name
      else if (name == arrow) {
        val tempLeft = params.take(params.length - 1).map(asString(_, variables)).mkString(", ")
        val parentheses = params match {
          case Operation(`arrow`, _) :: _ => true
          case (v: Inferer.Variable) :: _ if (v.typeOpt match {
            case Some(op: Operation) if op.name == arrow => true
            case _ => false
          }) => true
          case _ :: _ :: Nil => false
          case _ => true
        }

        val left = if (parentheses) s"($tempLeft)" else tempLeft
        val right = asString(params.last, variables)
        s"$left -> $right"
      } else {
        s"$name[${params.map(asString(_, variables)).mkString(", ")}]"
      }
  }

  trait Type {
    override def toString: String = {
      val variables = mutable.Set[String]()
      val s = asString(this, variables)
      if (variables.isEmpty) s else s"forall[${variables.toList.sorted.mkString(" ")}] $s"
    }
  }

  case class Variable(id: Int) extends Type {
    lazy val name: String = nextUniqueName
    var typeOpt: Option[Type] = None
  }

  case class Operation(name: String, params: Seq[Type]) extends Type

}

object Solution {
  private val envString =
    """head: forall[a] list[a] -> a
      |tail: forall[a] list[a] -> list[a]
      |nil: forall[a] list[a]
      |cons: forall[a] (a, list[a]) -> list[a]
      |cons_curry: forall[a] a -> list[a] -> list[a]
      |map: forall[a b] (a -> b, list[a]) -> list[b]
      |map_curry: forall[a b] (a -> b) -> list[a] -> list[b]
      |one: int
      |zero: int
      |succ: int -> int
      |plus: (int, int) -> int
      |eq: forall[a] (a, a) -> bool
      |eq_curry: forall[a] a -> a -> bool
      |not: bool -> bool
      |true: bool
      |false: bool
      |pair: forall[a b] (a, b) -> pair[a, b]
      |pair_curry: forall[a b] a -> b -> pair[a, b]
      |first: forall[a b] pair[a, b] -> a
      |second: forall[a b] pair[a, b] -> b
      |id: forall[a] a -> a
      |const: forall[a b] a -> b -> a
      |apply: forall[a b] (a -> b, a) -> b
      |apply_curry: forall[a b] (a -> b) -> a -> b
      |choose: forall[a] (a, a) -> a
      |choose_curry: forall[a] a -> a -> a
      |""".stripMargin

  def main(args: Array[String]): Unit = {
    val sc = new Scanner(System.in)
    val s = sc.nextLine
    solve(s)
  }

  def solve(s: String): Unit = {
    val lexemes = Parser.toLexemes(s)
    val expression = Parser.toExpression(lexemes)
    val environment: Inferer.Environment = Parser.loadEnvironment(envString)
    println(Inferer.infer(expression, environment))
  }
}

Note: This problem (Infer) is generated by HackerRank but the solution is provided by CodingBroz. This tutorial is only for Educational and Learning purpose.

Leave a Comment

Your email address will not be published. Required fields are marked *