hamzaremmal/amy

View on GitHub
compiler/src/main/scala/amyc/backend/wasm/gen/WASMCodeGenerator.scala

Summary

Maintainability
A
0 mins
Test Coverage
A
92%
package amyc.backend.wasm.gen

import amyc.*
import amyc.ast.*
import amyc.ast.SymbolicTreeModule.{Call as AmyCall, *}
import amyc.core.*
import amyc.core.StdDefinitions.*
import amyc.core.StdTypes.*
import amyc.core.Symbols.*
import amyc.backend.wasm.*
import amyc.backend.wasm.Instructions.*
import amyc.backend.wasm.Modules.*
import amyc.backend.wasm.Types.{result, typeuse}
import amyc.backend.wasm.builtin.BuiltIn.*
import amyc.backend.wasm.builtin.amy.*
import amyc.backend.wasm.builtin.amy.Boolean.mkBoolean
import amyc.backend.wasm.builtin.amy.Unit.mkUnit
import amyc.backend.wasm.handlers.{LocalsHandler, ModuleHandler}
import amyc.backend.wasm.utils.*
import amyc.utils.Pipeline

object WASMCodeGenerator extends Pipeline[Program, Module] :

  override val name: String = "WASMCodeGenerator"
  override def run(program: Program)(using Context): Module =
    given ModuleHandler = ModuleHandler(program.modules.last.name.name)
    val fn = wasmFunctions ++ (program.modules flatMap cgModule)
    Module(
      program.modules.last.name.name,
      Global(mh.static_boundary) :: Nil,
      defaultImports,
      Some(mh.table),
      mh.strpool,
      fn
    )

  // ==============================================================================================
  // ===================================== TRANSFORMER ============================================
  // ==============================================================================================

  /**
    *
    * @param moduleDef
    * @param Context
    * @param ModuleHandler
    * @return
    */
  def cgModule(moduleDef: ModuleDef)(using Context)(using ModuleHandler): List[Function] =
    val ModuleDef(name, defs, optExpr) = moduleDef
    defs.collect {
      case cd@CaseClassDef(sym, params, _) =>
        Function.forDefinition(cd){
          val index = lh.constructor(sym)
          val l = lh.getFreshLocal // ref to object, should contain the name

          // Dynamically allocate space for the object in memory
          lh.mh.dynamic_alloc(params.size) <:>
          local.set(l) <:>
          local.get(l) <:>
          i32.const(index) <:>
          i32.store <:> {
            // HR: Store each of the constructor parameter
            for offset <- (0 to params.size).toList yield
              adtField(local.get(l), offset) <:> // Compute the offset to store in
              local.get(offset) <:> // Fetch the data to store
              i32.store
          } <:>
          local.get(l)
        }
    } ++
    // Generate code for all functions
    defs.collect {
      case fd: FunDef if !(fd.name.asInstanceOf[FunctionSymbol] is "native") =>
        cgFunction(fd, Some(result(i32)))
    } ++
      // Generate code for the "main" function, which contains the module expression
      optExpr.toList.map { expr =>
        val sym = symbols.addFunction(name.asInstanceOf, "main", Nil, Nil, TTypeTree(stdType.IntType))
        val mainFd = FunDef(sym, Nil, ClassTypeTree(stdDef.IntType), expr)
        cgFunction(mainFd, None)
      }

  /**
    * Generate code for a function in module 'owner'
    * @param fd
    * @param owner
    * @param result
    * @param Context
    * @param ModuleHandler
    * @return
    */
  def cgFunction(fd: FunDef, result : Option[result])(using Context)(using ModuleHandler): Function = {
    // Note: We create the wasm function name from a combination of
    // module and function name, since we put everything in the same wasm module.
    Function.forDefinition(fd, result) {
      val body = cgExpr(fd.body)
      withComment(fd.toString) {
        if result.isEmpty then
          body <:> drop // Main functions do not return a value,
        // so we need to drop the value generated by their body
        else
          body
      }
    }
  }

  /**
    * Generate code for an expression expr.
    * @param expr
    * @param Context
    * @param LocalsHandler
    * @return
    */
  def cgExpr(expr: Expr)(using Context)(using LocalsHandler): Code = {
    expr match {
      case Variable(name) =>
        val idx = lh.fetch(name)
        if idx == -1 then
          reporter.error(s"cannot find $name")
        local.get(idx)
      case FunRef(ref: FunctionSymbol) => i32.const(lh.function(ref))
      case IntLiteral(i) => i32.const(i)
      case BooleanLiteral(b) => mkBoolean(b)
      case StringLiteral(s) =>
        i32.const(lh.mh.string(s))
      case UnitLiteral() => mkUnit
      case InfixCall(_, op, _) =>
        reporter.fatal(s"Cannot generate wasm code for operator, should not appear here $op")
      case Not(e) =>
        cgExpr(e) <:> i32.eqz
      case Neg(e) =>
        i32.const(0) <:>
        cgExpr(e) <:>
        i32.sub
      case AmyCall(sym: ConstructorSymbol, args) =>
        args.map(cgExpr) <:>
        call(fullName(sym))
      case AmyCall(qname: FunctionSymbol, args) =>
        val defn = stdDef(using ctx)
        qname match
          case defn.binop_&& =>
            and(cgExpr(args.head), cgExpr(args(1)))
          case defn.binop_|| =>
            or(cgExpr(args.head), cgExpr(args(1)))
          case _ =>
            args.map(cgExpr) <:>
            call(fullName(qname))
      case AmyCall(sym: (LocalSymbol | ParameterSymbol), args) =>
        args.map(cgExpr) <:>
        local.get(lh.fetch(sym)) <:>
        call_indirect(typeuse(mkFunTypeName(args.size)))
      case Sequence(e1, e2) =>
        cgExpr(e1) <:>
        drop <:>
        cgExpr(e2)
      case Let(pdf, value, body) =>
        val idx = lh.getFreshLocal(pdf.name)
        cgExpr(value) <:>
        local.set(idx) <:>
        cgExpr(body)
      case Ite(cond, thenn, elze) =>
        ift(cgExpr(cond), cgExpr(thenn), cgExpr(elze))
      case Match(scrut, cases) =>
        val l = lh.getFreshLocal

        cgExpr(scrut) <:>
        local.set(l) <:> {
          for
            c <- cases
            cond = matchAndBind(c.pat)
          yield
            local.get(l) <:>
              cond <:>
              `if`(None, Some(result(i32))) <:>
              cgExpr(c.expr) <:>
              `else`()
          // Else here become we are building a big if else bloc.
          // Last bloc will be concatenated with the Match error below and the
          // match error case there
        } <:>
          i32.const(lh.mh.string(s"Match error!$scrut")) <:>
          cases.map(_ => end) // HR: Autant de End que de cases
      case Error(msg) =>
        error(cgExpr(msg))
      case _ =>
        reporter.fatal(s"Cannot generate wasm code for $expr", expr.position)
    }
  }

  // ==============================================================================================
  // =============================== GENERATE PATTERN MATCHING ====================================
  // ==============================================================================================

  /**
    * Checks if a value matches a pattern.
    * Assumes value is on top of stack (and CONSUMES it)
    * Returns the code to check the value
    * @param pat
    * @param Context
    * @param LocalsHandler
    * @return
    */
  def matchAndBind(pat: Pattern)(using Context)(using LocalsHandler): Code =
    pat match
      case WildcardPattern() =>
        // HR : We return true as this pattern will be executed if encountered
        drop <:> mkBoolean(true)
      case IdPattern(id) =>
        // HR : We return true as this pattern will be executed if encountered
        val idLocal = lh.getFreshLocal(id)
        local.set(idLocal) <:> mkBoolean(true)
      case LiteralPattern(lit) =>
        cgExpr(lit) <:> i32.eq
      case CaseClassPattern(constr, args) =>
        val idx = lh.getFreshLocal
        val code = args.map(matchAndBind).zipWithIndex.map {
          p => adtField(local.get(idx), p._2) <:> i32.load <:> p._1
        }
        // HR : Setting the constructor index we are checking as a local variable
        local.set(idx) <:>
          ift({
            // HR : First check if the primary constructor is the same
            local.get(idx) <:>
            i32.load <:>
            i32.const(lh.constructor(constr)) <:>
            i32.eq
          }, {
            // HR : Check if all the pattern applies
            // HR : if the constructor has no parameters the foldLeft returns true
            code.foldLeft(mkBoolean(true))((acc, c) => and(c._1, acc))
          }, {
            mkBoolean(false) // HR : Signal that the scrut does not follow this pattern
          })