hamzaremmal/amy

View on GitHub
compiler/src/main/scala/amyc/analyzer/Transformer.scala

Summary

Maintainability
A
0 mins
Test Coverage
B
89%
package amyc.analyzer

import amyc.*
import amyc.core.*
import amyc.core.Symbols.*
import amyc.core.StdDefinitions.*
import amyc.ast.{NominalTreeModule as N, SymbolicTreeModule as S}

object Transformer {

  /**
    * Translates a nominal program into its symbolic form.
    *
    * Every module is transformed with [[transformModule]], and the source
    * positions of the nominal trees are copied onto the symbolic ones.
    *
    * @param p the nominal program
    * @return the symbolic program, positioned at `p`
    */
  def transformProgram(p: N.Program)(using Context): S.Program =
    val symMods = for mod <- p.modules yield
      transformModule(mod).setPos(mod)
    S.Program(symMods).setPos(p)

  /**
    * Translates a nominal module into its symbolic form.
    *
    * The module symbol is looked up in the symbol table. Each definition is
    * transformed with [[transformDef]], and the optional top-level
    * expression is transformed in the scope the context provides for the
    * module.
    *
    * @param mod the nominal module definition
    * @return the symbolic module definition
    */
  def transformModule(mod: N.ModuleDef)(using Context): S.ModuleDef =
    val N.ModuleDef(name, defs, optExpr) = mod
    val symName = symbols.getModule(name).getOrElse {
      reporter.fatal(s"Cannot find symbol for module $name", mod)
    }
    val symDefs = for d <- defs yield transformDef(d, name)
    val symExpr = optExpr.map(transformExpr(_)(name, ctx.scope(symName), ctx))
    S.ModuleDef(symName, symDefs, symExpr)

  /**
    * Translates a nominal type tree into a symbolic one.
    *
    * The unqualified built-in names `Unit`, `Boolean`, `Int` and `String`
    * map to the standard definitions. Any other class type is looked up in
    * the symbol table, in `inModule` when the name is unqualified.
    * Function types are translated component-wise.
    *
    * @param tt       the nominal type tree
    * @param inModule the module in which unqualified type names are resolved
    * @return the symbolic type tree, positioned at `tt`
    */
  def transformType(tt: N.TypeTree, inModule: String)(using Context): S.TypeTree = {
    val res: S.TypeTree = tt match
      case N.FunctionTypeTree(params, rte) =>
        S.FunctionTypeTree(params.map(transformType(_, inModule)), transformType(rte, inModule))
      case N.ClassTypeTree(N.QualifiedName(None, name)) =>
        name match
          case "Unit" => S.ClassTypeTree(stdDef.UnitType)
          case "Boolean" => S.ClassTypeTree(stdDef.BooleanType)
          case "Int" => S.ClassTypeTree(stdDef.IntType)
          case "String" => S.ClassTypeTree(stdDef.StringType)
          case _ =>
            symbols.getType(inModule, name) map S.ClassTypeTree.apply getOrElse {
              reporter.fatal(s"Could not find type $name", tt)
            }
      case N.ClassTypeTree(qn@N.QualifiedName(pre, name)) =>
        symbols.getType(pre getOrElse inModule, name) map S.ClassTypeTree.apply getOrElse {
          reporter.fatal(s"Could not find type $qn", tt)
        }
      case N.TTypeTree(tpe) =>
        S.TTypeTree(tpe)
    // Keep the source position so later phases can report errors on the
    // type itself (e.g. the declared type of a `val` in a let binding).
    res.setPos(tt)
  }

  /**
    * Translates a nominal function definition into a symbolic one.
    *
    * The function symbol is taken from the symbol table. Duplicate
    * parameter names are rejected. The body is transformed in a fresh scope
    * that contains only the function's parameters.
    *
    * @param fd     the nominal function definition
    * @param module the module that owns the function
    * @return the symbolic function definition, positioned at `fd`
    */
  def transformFunDef(fd: N.FunDef, module: String)(using Context): S.FunDef = {
    val N.FunDef(name, params, retType, body) = fd
    val sym = symbols.function(module, name)

    params.groupBy(_.name).foreach { case (pname, ps) =>
      if (ps.size > 1) {
        reporter.fatal(s"Two parameters named $pname in function ${fd.name}", fd)
      }
    }

    // Pair each syntactic parameter with its symbol. The symbol's type is
    // positioned at the parameter's written type.
    val newParams = params zip sym.param map {
      case (pd@N.ParamDef(_, tt), paramSym) =>
        S.ParamDef(paramSym, paramSym.tpe.setPos(tt)).setPos(pd)
    }

    val paramsMap = sym.param.map(s => (s.name, s)).toMap

    S.FunDef(
      sym,
      newParams,
      sym.rte.setPos(retType),
      transformExpr(body)(module, Scope.fresh.withParams(paramsMap), ctx)
    ).setPos(fd)
  }

  /**
    * Translates a nominal class or function definition into a symbolic one.
    *
    * @param df     the nominal definition (abstract class, case class or function)
    * @param module the module that owns the definition
    * @return the symbolic definition, positioned at `df`
    */
  def transformDef(df: N.ClassOrFunDef, module: String)(using Context): S.ClassOrFunDef = {
    df match {
      case N.AbstractClassDef(name) =>
        // Report a proper error rather than crashing with `None.get`.
        val sym = symbols.getType(module, name).getOrElse {
          reporter.fatal(s"Cannot find symbol for type $name", df)
        }
        S.AbstractClassDef(sym)
      case N.CaseClassDef(name, params, _) =>
        val sym = symbols.constructor(module, name)
        val newParams = params zip sym.param map {
          case (pd@N.ParamDef(_, tt), paramSym) =>
            S.ParamDef(paramSym, paramSym.tpe.setPos(tt)).setPos(pd)
        }
        S.CaseClassDef(sym, newParams, sym.parent)
      case fd: N.FunDef =>
        transformFunDef(fd, module)
    }
  }.setPos(df)

  /**
    * Translates a nominal expression into a symbolic one.
    *
    * Variables are resolved against `scope`. Function and constructor names
    * are resolved against `scope` and then the symbol table, in `module`
    * when unqualified. Infix calls are desugared into calls to functions of
    * module `unnamed`. Any name that cannot be resolved is reported through
    * `reporter.fatal`.
    *
    * @param expr    the nominal expression
    * @param module  the enclosing module, used for unqualified names
    * @param scope   the parameters and local variables visible at `expr`
    * @param context the compiler context
    * @return the symbolic expression, positioned at `expr`
    */
  def transformExpr(expr: N.Expr)
                   (implicit module: String, scope : Scope, context: Context): S.Expr = {
    val res = expr match {
      case N.Variable(name) =>
        scope.resolve(name) match
          case Some(id) => S.Variable(id)
          case _ => reporter.fatal(s"Variable $name not found", expr)
      case N.FunRef(N.QualifiedName(pre, name)) =>
        // Bind the qualifier as `pre`: naming it `module` would shadow the
        // implicit enclosing-module parameter. An unqualified reference
        // resolves in the enclosing module, like `N.Call` below.
        val owner = pre.getOrElse(module)
        val sym = symbols.getFunction(owner, name).getOrElse {
          reporter.fatal(s"Function $owner.$name not found", expr)
        }
        S.FunRef(sym)
      case N.IntLiteral(value) =>
        S.IntLiteral(value)
      case N.BooleanLiteral(value) =>
        S.BooleanLiteral(value)
      case N.StringLiteral(value) =>
        S.StringLiteral(value)
      case N.UnitLiteral() =>
        S.UnitLiteral()
      case N.InfixCall(lhs, op, rhs) =>
        // Desugar infix calls to calls of the operator in module `unnamed`.
        // The synthetic call inherits this expression's position so that
        // resolution errors point at the source.
        val call = N.Call(N.QualifiedName(Some("unnamed"), op), lhs :: rhs :: Nil).setPos(expr)
        transformExpr(call)
      case N.Not(e) =>
        S.Not(transformExpr(e))
      case N.Neg(e) =>
        S.Neg(transformExpr(e))
      case N.Call(qname, args) =>
        val owner = qname.module.getOrElse(module)
        val name = qname.name
        // Lookup order: local scope, then constructors, then functions.
        val entry = scope.resolve(name) orElse {
            symbols.getConstructor(owner, name)
          } orElse {
            symbols.getFunction(owner, name)
          }
        entry match {
          case None =>
            reporter.fatal(s"Function or constructor $qname not found", expr)
          case Some(sym: ApplicationSymbol) =>
            if (sym.param.size != args.size) {
              reporter.fatal(s"Wrong number of arguments for function/constructor $qname", expr)
            }
            S.Call(sym, args.map(transformExpr(_)))
          case Some(sym: Symbol) =>
            S.Call(sym, args.map(transformExpr(_)))
        }
      case N.Sequence(e1, e2) =>
        S.Sequence(transformExpr(e1), transformExpr(e2))
      case N.Let(vd, value, body) =>
        if (scope.isLocal(vd.name)) {
          reporter.fatal(s"Variable redefinition: ${vd.name}", vd)
        }
        if (scope.isParam(vd.name)) {
          reporter.warning(s"Local variable ${vd.name} shadows function parameter", vd)
        }
        val sym = LocalSymbol(Identifier.fresh(vd.name))
        val tpe = transformType(vd.tt, module)
        // The bound value is transformed in the outer scope; only the body
        // sees the new local.
        S.Let(
          S.ParamDef(sym, tpe).setPos(vd),
          transformExpr(value),
          transformExpr(body)(module, scope.withLocal(vd.name, sym), ctx)
        )
      case N.Ite(cond, thenn, elze) =>
        S.Ite(transformExpr(cond), transformExpr(thenn), transformExpr(elze))
      case N.Match(scrut, cases) =>
        // Transforms one case. Its right-hand side is transformed in the
        // scope extended with the identifiers its pattern binds.
        def transformCase(cse: N.MatchCase) = {
          val N.MatchCase(pat, rhs) = cse
          val (newPat, caseScope) = transformPattern(pat)
          S.MatchCase(newPat, transformExpr(rhs)(module, caseScope, ctx).setPos(rhs)).setPos(cse)
        }

        // Transforms a pattern and returns it together with the scope
        // extended by every identifier the pattern binds.
        def transformPattern(pat: N.Pattern): (S.Pattern, Scope) = {
          val (newPat, newScope): (S.Pattern, Scope) = pat match {
            case N.WildcardPattern() =>
              (S.WildcardPattern(), scope)
            case N.IdPattern(name) =>
              if (scope.isLocal(name)) {
                reporter.fatal(s"Pattern identifier $name already defined", pat)
              }
              if (scope.isParam(name)) {
                reporter.warning("Suspicious shadowing by an Id Pattern", pat)
              }
              // `Foo` alone binds a fresh variable. Warn if the user probably
              // meant the nullary constructor `Foo()`.
              symbols.getConstructor(module, name) match {
                case Some(sym : ConstructorSymbol) if sym.param.isEmpty =>
                  reporter.warning(s"There is a nullary constructor in this module called '$name'. Did you mean '$name()'?", pat)
                case _ =>
              }
              val sym = LocalSymbol(Identifier.fresh(name))
              (S.IdPattern(sym), scope.withLocal(name, sym))
            case N.LiteralPattern(lit) =>
              (S.LiteralPattern(transformExpr(lit).asInstanceOf[S.Literal[_]]), scope)
            case N.CaseClassPattern(constr, args) =>
              val sym = symbols
                .getConstructor(constr.module.getOrElse(module), constr.name)
                .getOrElse {
                  reporter.fatal(s"Constructor $constr not found", pat)
                }
              if (sym.param.size != args.size) {
                reporter.fatal(s"Wrong number of args for constructor $constr", pat)
              }
              val (newPatts, moreLocals0) = (args map transformPattern).unzip
              val moreLocals = if moreLocals0.nonEmpty then
                // Locals inherited from the outer scope appear as identical
                // (name, symbol) pairs in every sub-scope and collapse in the
                // set. A name left with several pairs was therefore bound by
                // more than one sub-pattern.
                // TODO HR : This check here should be refactored (inefficient)
                moreLocals0.toSet.flatMap(_.locals.map(identity)).groupBy(_._1).foreach { case (name, pairs) =>
                  if (pairs.size > 1) {
                    reporter.fatal(s"Multiple definitions of $name in pattern", pat)
                  }
                }
                moreLocals0.reduce(Scope.combine)
              else
                scope
              (S.CaseClassPattern(sym, newPatts), moreLocals)
          }
          (newPat.setPos(pat), newScope)
        }

        S.Match(transformExpr(scrut), cases map transformCase)
      case N.Error(msg) =>
        S.Error(transformExpr(msg))
      case N.EmptyExpr() =>
        S.EmptyExpr()
    }
    res.setPos(expr)
  }

}