Infer – HackerRank Solution

In this post, we will solve Infer HackerRank Solution. This problem (Infer) is a part of HackerRank Functional Programming series.

If we know that `one` is of type `int` and `id` is of type `forall[a] a -> a`, we can infer that `id(one)` is of type `int`.

A function `fun x y -> x` has a generic type of `forall[a b] (a, b) -> a`.

Let’s write a program to help us infer the type of expression in a given envrionment!

First, we define the syntax of expression:

``````ident : [_A-Za-z][_A-Za-z0-9]*              // variable names

expr : "let " ident " = " expr " in " expr  // variable defination
| "fun " argList " -> " expr           // function defination
| simpleExpr

argList : { 0 or more ident seperated by ' ' }

simpleExpr : '(' expr ')'
| ident
| simpleExpr '(' paramList ')'   // function calling

paramList : { 0 or more expr seperated ", " }``````

Then, we define the syntax of type:

``````ty : "() -> " ty                            // function without arguments
| '(' tyList ") -> " ty                  // uncurry function
| "forall[" argList "]" ty               // generic type
| simpleTy " -> " ty                     // curry function
| simpleTy

tyList : { 1 or more ty seperated by ", " }

simpleTy : '(' ty ')'
| ident
| simpleTy '[' tyList ']'          // such as list[int]``````

Hint in parsing:

• Spacing is strict.
• Pay attention to avoid dead loop.

Type of given expression should be infered in an environment. The environment is consisted of a set of functions with types:

``````head: forall[a] list[a] -> a
tail: forall[a] list[a] -> list[a]
nil: forall[a] list[a]
cons: forall[a] (a, list[a]) -> list[a]
cons_curry: forall[a] a -> list[a] -> list[a]
map: forall[a b] (a -> b, list[a]) -> list[b]
map_curry: forall[a b] (a -> b) -> list[a] -> list[b]
one: int
zero: int
succ: int -> int
plus: (int, int) -> int
eq: forall[a] (a, a) -> bool
eq_curry: forall[a] a -> a -> bool
not: bool -> bool
true: bool
false: bool
pair: forall[a b] (a, b) -> pair[a, b]
pair_curry: forall[a b] a -> b -> pair[a, b]
first: forall[a b] pair[a, b] -> a
second: forall[a b] pair[a, b] -> b
id: forall[a] a -> a
const: forall[a b] a -> b -> a
apply: forall[a b] (a -> b, a) -> b
apply_curry: forall[a b] (a -> b) -> a -> b
choose: forall[a] (a, a) -> a
choose_curry: forall[a] a -> a -> a``````

Sample Input #00

``let x = id in x``

Sample Output #00

``forall[a] a -> a``

Explanation #00

`x` is just `id` in the environment.

Sample Input #01

``fun x -> let y = fun z -> z in y``

Sample Output #01

``forall[a b] a -> b -> b``

Explanation #01

Function with variables which are not bounded in the environment should be generic function. The variable names appear in `forall[]` should be from `a` to `z` subject to their appearance order in type body.

Sample Input #02

``choose(fun x y -> x, fun x y -> y)``

Sample Output #02

``forall[a] (a, a) -> a``

Explanation #02

The type of `choose` is `forall[a] (a, a) -> a`. So `x` and `y` should be of the same type.

Sample Input #03

``fun f -> let x = fun g y -> let _ = g(y) in eq(f, g) in x``

Sample Output #03

``forall[a b] (a -> b) -> (a -> b, a) -> bool``

Explanation #03

The longest test case.

Final note:

All given expression are valid, non-recursive and can be infered successfully in given environment. But an optional requirement is that your program should fail on incomplete uncurry version function calling. For example, `choose_curry(one)` should be infered as `int -> int` but `choose(one)` just fail in infering.

Solution – Infer – HackerRank Solution

Scala

```import java.util.Scanner

import scala.collection.mutable

trait Expression

object Expression {

case class FunctionDefinition(params: List[String], body: Expression) extends Expression

case class Variable(name: String) extends Expression

case class FunctionCall(function: Expression, params: List[Expression]) extends Expression

case class VariableDefinition(name: String, definition: Expression, body: Expression) extends Expression

}

object Parser {

def toLexemes(s: String): List[Lexeme] = {
val lexemeNames = Map(
"let" -> Let,
"in" -> In,
"fun" -> Fun,
"forall" -> Forall,
"=" -> Equal,
"->" -> Arrow,
"," -> Comma,
":" -> Colon,
"(" -> OpenRoundBracket,
")" -> CloseRoundBracket,
"[" -> OpenSquareBracket,
"]" -> CloseSquareBracket
)

def parseLexeme(index: Int, predicate: Char => Boolean): (Int, String) = {
@scala.annotation.tailrec
def inner(index: Int, sb: StringBuilder = new StringBuilder): (Int, String) = if (index < s.length) {
val c = s(index)
if (predicate(c)) inner(index + 1, sb.append(c)) else (index, sb.toString())
} else (index, sb.toString())

inner(index)
}

case class State(index: Int = 0, acc: List[Lexeme] = Nil)

@scala.annotation.tailrec
def inner(state: State): State = if (state.index < s.length) {
val c = s(state.index)

val nextState = c match {
case d: Char if d.isLetter || d == '_' =>
val (nextIndex, lexemeName) = parseLexeme(state.index, c => c.isLetterOrDigit || c == '_')
val lexeme = lexemeNames.getOrElse(lexemeName, Identifier(lexemeName))
State(nextIndex, lexeme :: state.acc)

case '-' if state.index + 1 < s.length && s(state.index + 1) == '>' =>
State(state.index + 2, Arrow :: state.acc)

case d: Char if lexemeNames.contains(d.toString) =>
val lexeme = lexemeNames(d.toString)
State(index = state.index + 1, lexeme :: state.acc)

case _ => state.copy(index = state.index + 1)
}

inner(nextState)
} else state

inner(State()).acc.reverse
}

def toExpression(lexemes: List[Lexeme]): Expression = lexemes match {
case Let :: Identifier(name) :: Equal :: rest =>
val (forExpr1, forExpr2) = findCloseBracket(rest, Let, In)
Expression.VariableDefinition(name, toExpression(forExpr1), toExpression(forExpr2))

case Fun :: rest =>
val (forArgs, forExpr) = rest.span(_ != Arrow)
val args = forArgs.map { case Identifier(name) => name }
Expression.FunctionDefinition(args, toExpression(forExpr.tail))

case OpenRoundBracket :: rest =>
val (forExpr, nextRest) = findCloseBracket(rest, OpenRoundBracket, CloseRoundBracket)
val expr = toExpression(forExpr)
if (nextRest.isEmpty) {
expr
} else {
val params = nextRest.tail.init.splitBy(Set(Comma), Set(Comma)).map(toExpression)
Expression.FunctionCall(expr, params)
}

case Identifier(name) :: nextRest =>
val expr = Expression.Variable(name)
if (nextRest.isEmpty) {
expr
} else {
@scala.annotation.tailrec
def toFunctionCall(expr: Expression, nextRest: List[Lexeme]): Expression.FunctionCall = {
val (forParams, rest2) = findCloseBracket(nextRest.tail, OpenRoundBracket, CloseRoundBracket)
val params = forParams.splitBy(Set(Comma), Set(Comma)).map(toExpression)

if (rest2.isEmpty) Expression.FunctionCall(expr, params) else toFunctionCall(Expression.FunctionCall(expr, params), rest2)
}

toFunctionCall(expr, nextRest)
}

case _ => throw new Exception("Wrong expression")
}

def loadEnvironment(s: String): Map[String, Inferer.Type] = {
def parseLine(s: String): (String, Inferer.Type) = {
val lexemes = toLexemes(s)
lexemes match {
case Identifier(name) :: Colon :: Forall :: rest =>
val (forParams, nextRest) = rest.tail.span(_ != CloseSquareBracket)
val params = forParams.collect { case Identifier(a) => a }
val generics = params.map(name => name -> Inferer.nextVariable).toMap
name -> parseType(generics, nextRest.tail)

case Identifier(name) :: Colon :: rest =>
name -> parseType(Map(), rest)

case _ =>
throw new Exception("Wrong type")
}
}

s.split("\n").map(parseLine).toMap
}

private def findCloseBracket(lexemes: List[Lexeme], openBracket: Lexeme, closeBracket: Lexeme): (List[Lexeme], List[Lexeme]) = {
case class Acc(lexemes: List[Lexeme] = Nil, bracketCount: Int = 1)
@scala.annotation.tailrec
def inner(lexemes: List[Lexeme], acc: Acc = Acc()): (List[Lexeme], List[Lexeme]) = {
lexemes match {
case (lex@`closeBracket`) :: lexemes =>
if (acc.bracketCount == 1) (acc.lexemes.reverse, lexemes)
else inner(lexemes, Acc(lex :: acc.lexemes, acc.bracketCount - 1))
case (lex@`openBracket`) :: lexemes => inner(lexemes, Acc(lex :: acc.lexemes, acc.bracketCount + 1))
case lex :: lexemes => inner(lexemes, Acc(lex :: acc.lexemes, acc.bracketCount))
case Nil => throw new Exception("Wrong expression")
}
}

inner(lexemes)
}

private def parseType(generics: Map[String, Inferer.Variable], lexemes: List[Lexeme]): Inferer.Type = {
val res = lexemes match {
case OpenRoundBracket :: rest =>
val (forSubTypes, nextRest) = findCloseBracket(rest, OpenRoundBracket, CloseRoundBracket)
val tys = forSubTypes.splitBy(Set(Comma), Set(Comma)).map(lexemes => parseType(generics, lexemes))
Inferer.Function(tys, parseType(generics, nextRest.tail))

case lexemes =>
val (first, rest) = lexemes.span(_ != Arrow)
val (id, params) = first.span(_ != OpenSquareBracket)

val firstType = if (params.isEmpty) { //Identifier
generics.getOrElse(name, Inferer.Operation(name, Seq()))
} else {
Inferer.Operation(name, params.collect { case Identifier(a) => generics(a) })
}

if (rest.isEmpty)
firstType
else
Inferer.Function(List(firstType), parseType(generics, rest.tail))
}
res
}

trait Lexeme

case class Identifier(name: String) extends Lexeme

case object Let extends Lexeme

case object In extends Lexeme

case object Fun extends Lexeme

case object Forall extends Lexeme

case object Equal extends Lexeme

case object Arrow extends Lexeme

case object Comma extends Lexeme

implicit class Splitter(lexemes: List[Lexeme]) {
def splitBy(separators: Set[Lexeme], drop: Set[Lexeme] = Set()): List[List[Lexeme]] = {
case class Acc(current: List[Lexeme] = Nil, data: List[List[Lexeme]] = Nil)
val acc = lexemes.foldLeft(Acc())((acc, lex) => {
val nextCurrent = if (drop.contains(lex)) acc.current else lex :: acc.current
if (separators.contains(lex)) Acc(Nil, nextCurrent.reverse :: acc.data)
else Acc(nextCurrent, acc.data)
})

(if (acc.current.isEmpty) acc.data else acc.current.reverse :: acc.data).reverse
}
}

case object Colon extends Lexeme

case object OpenRoundBracket extends Lexeme

case object CloseRoundBracket extends Lexeme

case object OpenSquareBracket extends Lexeme

case object CloseSquareBracket extends Lexeme

}

object Inferer {

type Environment = Map[String, Type]
private val arrow = "->"
private var nextName = 'a'
private var nextId = 0

def Function(from: List[Type], to: Type): Operation = Operation(arrow, from.:+(to))

def nextUniqueName: String = {
val res = nextName
nextName = (nextName + 1).toChar
res.toString
}

def nextVariable: Variable = {
val result = nextId
nextId += 1
Variable(result)
}

def infer(expression: Expression, environment: Environment, variables: mutable.Set[Variable] = mutable.Set()): Type = expression match {
case Expression.VariableDefinition(name, definition, body) =>
val definitionType = infer(definition, environment, variables)
val nextEnvironment = environment + (name -> definitionType)
infer(body, nextEnvironment, variables)

case Expression.FunctionDefinition(params, body) =>
val paramTypes = params.map(_ => nextVariable)
val nextEnvironment = params.zip(paramTypes)
.foldLeft(environment) { case (acc, (v, t)) => acc + (v -> t) }
val nextVariables = paramTypes.foldLeft(variables)((acc, t) => acc.union(mutable.Set(t)))
val resultType = infer(body, nextEnvironment, nextVariables)
Function(paramTypes, resultType)

case Expression.FunctionCall(function, params) =>
val functionType = infer(function, environment, variables)
val paramTypes = params.map(infer(_, environment, variables))
val resultType = nextVariable
unify(Function(paramTypes, resultType), functionType)
resultType

case Expression.Variable(name) =>
inferVariableType(environment(name), variables)
}

def unify(type0: Type, type1: Type): Unit = {
(prune(type0), prune(type1)) match {
case (a: Variable, b) => if (a != b) {
if (contains(b, a)) throw new Exception("recursion detected")
a.typeOpt = Some(b)
}

case (a: Operation, b: Variable) =>
unify(b, a)

case (a: Operation, b: Operation) =>
if (a.name != b.name || a.params.length != b.params.length)
throw new Exception(s"Type mismatch: \$a =/= \$b")
for (i <- a.params.indices)
unify(a.params(i), b.params(i))

case _ => throw new Exception("Wrong types.")
}
}

def prune(t: Type): Type = t match {
case v: Variable if v.typeOpt.isDefined =>
val res = prune(v.typeOpt.get)
v.typeOpt = Some(res)
res
case _ => t
}

def inferVariableType(t: Type, variables: mutable.Set[Variable]): Type = {
val map = mutable.Map[Variable, Variable]()

def inner(t: Type): Type = {
prune(t) match {
case v: Variable =>
if (contains(variables, v)) v
else map.getOrElseUpdate(v, nextVariable) // generic
case Operation(name, args) =>
Operation(name, args.map(inner))
}
}

inner(t)
}

def contains(t: Type, v: Variable): Boolean = {
prune(t) match {
case `v` => true
case Operation(_, params) => contains(params, v)
case _ => false
}
}

def contains(types: Iterable[Type], v: Variable): Boolean = types.exists(contains(_, v))

private def asString(t: Type, variables: mutable.Set[String]): String = t match {
case v: Variable => v.typeOpt match {
case Some(t) =>
asString(t, variables)
case None =>
v.name
}

case Operation(name, params) =>
if (params.isEmpty) name
else if (name == arrow) {
val tempLeft = params.take(params.length - 1).map(asString(_, variables)).mkString(", ")
val parentheses = params match {
case Operation(`arrow`, _) :: _ => true
case (v: Inferer.Variable) :: _ if (v.typeOpt match {
case Some(op: Operation) if op.name == arrow => true
case _ => false
}) => true
case _ :: _ :: Nil => false
case _ => true
}

val left = if (parentheses) s"(\$tempLeft)" else tempLeft
val right = asString(params.last, variables)
s"\$left -> \$right"
} else {
s"\$name[\${params.map(asString(_, variables)).mkString(", ")}]"
}
}

trait Type {
override def toString: String = {
val variables = mutable.Set[String]()
val s = asString(this, variables)
if (variables.isEmpty) s else s"forall[\${variables.toList.sorted.mkString(" ")}] \$s"
}
}

case class Variable(id: Int) extends Type {
lazy val name: String = nextUniqueName
var typeOpt: Option[Type] = None
}

case class Operation(name: String, params: Seq[Type]) extends Type

}

object Solution {
private val envString =
|tail: forall[a] list[a] -> list[a]
|nil: forall[a] list[a]
|cons: forall[a] (a, list[a]) -> list[a]
|cons_curry: forall[a] a -> list[a] -> list[a]
|map: forall[a b] (a -> b, list[a]) -> list[b]
|map_curry: forall[a b] (a -> b) -> list[a] -> list[b]
|one: int
|zero: int
|succ: int -> int
|plus: (int, int) -> int
|eq: forall[a] (a, a) -> bool
|eq_curry: forall[a] a -> a -> bool
|not: bool -> bool
|true: bool
|false: bool
|pair: forall[a b] (a, b) -> pair[a, b]
|pair_curry: forall[a b] a -> b -> pair[a, b]
|first: forall[a b] pair[a, b] -> a
|second: forall[a b] pair[a, b] -> b
|id: forall[a] a -> a
|const: forall[a b] a -> b -> a
|apply: forall[a b] (a -> b, a) -> b
|apply_curry: forall[a b] (a -> b) -> a -> b
|choose: forall[a] (a, a) -> a
|choose_curry: forall[a] a -> a -> a
|""".stripMargin

def main(args: Array[String]): Unit = {
val sc = new Scanner(System.in)
val s = sc.nextLine
solve(s)
}

def solve(s: String): Unit = {
val lexemes = Parser.toLexemes(s)
val expression = Parser.toExpression(lexemes)