compiler.go

Summary

Maintainability
A
2 hrs
Test Coverage
// Copyright (c) 2020-2023 Ozan Hacıbekiroğlu.
// Use of this source code is governed by a MIT License
// that can be found in the LICENSE file.

package ugo

import (
    "errors"
    "fmt"
    "io"
    "os"
    "reflect"

    "github.com/ozanh/ugo/parser"
    "github.com/ozanh/ugo/token"
)

var (
    // DefaultCompilerOptions holds default Compiler options.
    DefaultCompilerOptions = CompilerOptions{
        OptimizerMaxCycle: 100,
        OptimizeConst:     true,
        OptimizeExpr:      true,
    }
    // TraceCompilerOptions holds Compiler options to print trace output
    // to stdout for Parser, Optimizer, Compiler.
    TraceCompilerOptions = CompilerOptions{
        Trace:             os.Stdout,
        TraceParser:       true,
        TraceCompiler:     true,
        TraceOptimizer:    true,
        OptimizerMaxCycle: 1<<8 - 1,
        OptimizeConst:     true,
        OptimizeExpr:      true,
    }
)

// errSkip is a sentinel error for compiler.
var errSkip = errors.New("skip")

type (

    // Compiler compiles the AST into a bytecode.
    Compiler struct {
        parent        *Compiler
        file          *parser.SourceFile
        constants     []Object
        constsCache   map[Object]int
        cfuncCache    map[uint32][]int
        symbolTable   *SymbolTable
        instructions  []byte
        sourceMap     map[int]int
        moduleMap     *ModuleMap
        moduleStore   *moduleStore
        modulePath    string
        variadic      bool
        loops         []*loopStmts
        loopIndex     int
        tryCatchIndex int
        iotaVal       int
        opts          CompilerOptions
        trace         io.Writer
        indent        int
    }

    // CompilerOptions represents customizable options for Compile().
    CompilerOptions struct {
        ModuleMap         *ModuleMap
        ModulePath        string
        Constants         []Object
        SymbolTable       *SymbolTable
        Trace             io.Writer
        TraceParser       bool
        TraceCompiler     bool
        TraceOptimizer    bool
        OptimizerMaxCycle int
        OptimizeConst     bool
        OptimizeExpr      bool
        moduleStore       *moduleStore
        constsCache       map[Object]int
    }

    // CompilerError represents a compiler error.
    CompilerError struct {
        FileSet *parser.SourceFileSet
        Node    parser.Node
        Err     error
    }

    // moduleStoreItem represents indexes of a single module.
    moduleStoreItem struct {
        typ           int
        constantIndex int
        moduleIndex   int
    }

    // moduleStore represents modules indexes and total count that are defined
    // while compiling.
    moduleStore struct {
        count int
        store map[string]moduleStoreItem
    }

    // loopStmts represents a loopStmts construct that the compiler uses to
    // track the current loopStmts.
    loopStmts struct {
        continues         []int
        breaks            []int
        lastTryCatchIndex int
    }
)

func (e *CompilerError) Error() string {
    filePos := e.FileSet.Position(e.Node.Pos())
    return fmt.Sprintf("Compile Error: %s\n\tat %s", e.Err.Error(), filePos)
}

func (e *CompilerError) Unwrap() error {
    return e.Err
}

// NewCompiler creates a new Compiler object.
func NewCompiler(file *parser.SourceFile, opts CompilerOptions) *Compiler {
    if opts.SymbolTable == nil {
        opts.SymbolTable = NewSymbolTable()
    }

    if opts.constsCache == nil {
        opts.constsCache = make(map[Object]int)
        for i := range opts.Constants {
            switch opts.Constants[i].(type) {
            case Int, Uint, String, Bool, Float, Char, *UndefinedType:
                opts.constsCache[opts.Constants[i]] = i
            }
        }
    }

    if opts.moduleStore == nil {
        opts.moduleStore = newModuleStore()
    }

    var trace io.Writer
    if opts.TraceCompiler {
        trace = opts.Trace
    }

    return &Compiler{
        file:          file,
        constants:     opts.Constants,
        constsCache:   opts.constsCache,
        cfuncCache:    make(map[uint32][]int),
        symbolTable:   opts.SymbolTable,
        sourceMap:     make(map[int]int),
        moduleMap:     opts.ModuleMap,
        moduleStore:   opts.moduleStore,
        modulePath:    opts.ModulePath,
        loopIndex:     -1,
        tryCatchIndex: -1,
        iotaVal:       -1,
        opts:          opts,
        trace:         trace,
    }
}

// Compile compiles given script to Bytecode.
func Compile(script []byte, opts CompilerOptions) (*Bytecode, error) {
    fileSet := parser.NewFileSet()
    moduleName := opts.ModulePath

    if moduleName == "" {
        moduleName = "(main)"
    }

    srcFile := fileSet.AddFile(moduleName, -1, len(script))
    var trace io.Writer
    if opts.TraceParser {
        trace = opts.Trace
    }

    p := parser.NewParser(srcFile, script, trace)
    pf, err := p.ParseFile()
    if err != nil {
        return nil, err
    }

    compiler := NewCompiler(srcFile, opts)
    compiler.SetGlobalSymbolsIndex()

    if opts.OptimizeConst || opts.OptimizeExpr {
        err := compiler.optimize(pf)
        if err != nil && err != errSkip {
            return nil, err
        }
    }

    if err := compiler.Compile(pf); err != nil {
        return nil, err
    }

    bc := compiler.Bytecode()
    if bc.Main.NumLocals > 256 {
        return nil, ErrSymbolLimit
    }
    return bc, nil
}

// SetGlobalSymbolsIndex sets index of a global symbol. This is only required
// when a global symbol is defined in SymbolTable and provided to compiler.
// Otherwise, caller needs to append the constant to Constants, set the symbol
// index and provide it to the Compiler. This should be called before
// Compiler.Compile call.
func (c *Compiler) SetGlobalSymbolsIndex() {
    symbols := c.symbolTable.Symbols()
    for _, s := range symbols {
        if s.Scope == ScopeGlobal && s.Index == -1 {
            s.Index = c.addConstant(String(s.Name))
        }
    }
}

// optimize runs the Optimizer and returns Optimizer object and error from Optimizer.
// Note:If optimizer cannot run for some reason, a nil optimizer and errSkip
// error will be returned.
func (c *Compiler) optimize(file *parser.File) error {
    if c.opts.OptimizerMaxCycle < 1 {
        return errSkip
    }

    optim := NewOptimizer(file, c.symbolTable, c.opts)

    if err := optim.Optimize(); err != nil {
        return err
    }

    c.opts.OptimizerMaxCycle -= optim.Total()
    return nil
}

// Bytecode returns compiled Bytecode ready to run in VM.
func (c *Compiler) Bytecode() *Bytecode {
    var lastOp Opcode
    var operands = make([]int, 0, 4)
    var jumpPos = make(map[int]struct{})
    var offset int
    var i int

    for i < len(c.instructions) {
        lastOp = c.instructions[i]
        numOperands := OpcodeOperands[lastOp]
        operands, offset = ReadOperands(
            numOperands,
            c.instructions[i+1:],
            operands,
        )

        if lastOp == OpJump || lastOp == OpJumpFalsy ||
            lastOp == OpAndJump || lastOp == OpOrJump {
            jumpPos[operands[0]] = struct{}{}
        }

        delete(jumpPos, i)
        i += offset + 1
    }

    if lastOp != OpReturn || len(jumpPos) > 0 {
        c.emit(nil, OpReturn, 0)
    }

    return &Bytecode{
        FileSet:   c.file.Set(),
        Constants: c.constants,
        Main: &CompiledFunction{
            NumParams:    c.symbolTable.NumParams(),
            NumLocals:    c.symbolTable.MaxSymbols(),
            Variadic:     c.variadic,
            Instructions: c.instructions,
            SourceMap:    c.sourceMap,
        },
        NumModules: c.moduleStore.count,
    }
}

// Compile compiles parser.Node and builds Bytecode.
func (c *Compiler) Compile(node parser.Node) error {
    if c.trace != nil {
        if node != nil {
            defer untracec(tracec(c, fmt.Sprintf("%s (%s)",
                node.String(), reflect.TypeOf(node).Elem().Name())))
        } else {
            defer untracec(tracec(c, "<nil>"))
        }
    }

    switch node := node.(type) {
    case *parser.File:
        for _, stmt := range node.Stmts {
            if err := c.Compile(stmt); err != nil {
                return err
            }
        }
    case *parser.ExprStmt:
        if err := c.Compile(node.Expr); err != nil {
            return err
        }
        c.emit(node, OpPop)
    case *parser.IncDecStmt:
        op := token.AddAssign
        if node.Token == token.Dec {
            op = token.SubAssign
        }
        return c.compileAssignStmt(
            node,
            []parser.Expr{node.Expr},
            []parser.Expr{&parser.IntLit{Value: 1, ValuePos: node.TokenPos}},
            token.Var,
            op,
        )
    case *parser.ParenExpr:
        return c.Compile(node.Expr)
    case *parser.BinaryExpr:
        if node.Token == token.LAnd || node.Token == token.LOr {
            return c.compileLogical(node)
        }
        return c.compileBinaryExpr(node)
    case *parser.IntLit:
        c.emit(node, OpConstant, c.addConstant(Int(node.Value)))
    case *parser.UintLit:
        c.emit(node, OpConstant, c.addConstant(Uint(node.Value)))
    case *parser.FloatLit:
        c.emit(node, OpConstant, c.addConstant(Float(node.Value)))
    case *parser.BoolLit:
        if node.Value {
            c.emit(node, OpTrue)
        } else {
            c.emit(node, OpFalse)
        }
    case *parser.StringLit:
        c.emit(node, OpConstant, c.addConstant(String(node.Value)))
    case *parser.CharLit:
        c.emit(node, OpConstant, c.addConstant(Char(node.Value)))
    case *parser.UndefinedLit:
        c.emit(node, OpNull)
    case *parser.UnaryExpr:
        return c.compileUnaryExpr(node)
    case *parser.IfStmt:
        return c.compileIfStmt(node)
    case *parser.TryStmt:
        return c.compileTryStmt(node)
    case *parser.CatchStmt:
        return c.compileCatchStmt(node)
    case *parser.FinallyStmt:
        return c.compileFinallyStmt(node)
    case *parser.ThrowStmt:
        return c.compileThrowStmt(node)
    case *parser.ForStmt:
        return c.compileForStmt(node)
    case *parser.ForInStmt:
        return c.compileForInStmt(node)
    case *parser.BranchStmt:
        return c.compileBranchStmt(node)
    case *parser.BlockStmt:
        return c.compileBlockStmt(node)
    case *parser.DeclStmt:
        return c.compileDeclStmt(node)
    case *parser.AssignStmt:
        return c.compileAssignStmt(node,
            node.LHS, node.RHS, token.Var, node.Token)
    case *parser.Ident:
        return c.compileIdent(node)
    case *parser.ArrayLit:
        return c.compileArrayLit(node)
    case *parser.MapLit:
        return c.compileMapLit(node)
    case *parser.SelectorExpr: // selector on RHS side
        return c.compileSelectorExpr(node)
    case *parser.IndexExpr:
        return c.compileIndexExpr(node)
    case *parser.SliceExpr:
        return c.compileSliceExpr(node)
    case *parser.FuncLit:
        return c.compileFuncLit(node)
    case *parser.ReturnStmt:
        return c.compileReturnStmt(node)
    case *parser.CallExpr:
        return c.compileCallExpr(node)
    case *parser.ImportExpr:
        return c.compileImportExpr(node)
    case *parser.CondExpr:
        return c.compileCondExpr(node)
    case *parser.EmptyStmt:
    case nil:
    default:
        return c.errorf(node, `%[1]T "%[1]v" not implemented`, node)
    }
    return nil
}

func (c *Compiler) changeOperand(opPos int, operand ...int) {
    op := c.instructions[opPos]
    inst := make([]byte, 0, 8)
    inst, err := MakeInstruction(inst, op, operand...)
    if err != nil {
        panic(err)
    }
    c.replaceInstruction(opPos, inst)
}

func (c *Compiler) replaceInstruction(pos int, inst []byte) {
    copy(c.instructions[pos:], inst)
    if c.trace != nil {
        printTrace(c.indent, c.trace, fmt.Sprintf("REPLC %s",
            FormatInstructions(c.instructions[pos:], pos)[0]))
    }
}

func (c *Compiler) addConstant(obj Object) (index int) {
    defer func() {
        if c.trace != nil {
            printTrace(c.indent, c.trace,
                fmt.Sprintf("CONST %04d %v", index, obj))
        }
    }()

    switch obj.(type) {
    case Int, Uint, String, Bool, Float, Char, *UndefinedType:
        i, ok := c.constsCache[obj]
        if ok {
            index = i
            return
        }
    case *CompiledFunction:
        return c.addCompiledFunction(obj)
    default:
        // unhashable types cannot be stored in constsCache, append them to constants slice
        // and return index
        index = len(c.constants)
        c.constants = append(c.constants, obj)
        return
    }

    index = len(c.constants)
    c.constants = append(c.constants, obj)
    c.constsCache[obj] = index
    return
}

func (c *Compiler) addCompiledFunction(obj Object) (index int) {
    // Currently, caching compiled functions is only effective for functions
    // used in const declarations.
    // e.g.
    // const (
    //     f = func() { return 1 }
    //     g
    // )
    //
    cf := obj.(*CompiledFunction)
    key := cf.hash32()
    arr, ok := c.cfuncCache[key]
    if ok {
        for _, idx := range arr {
            f := c.constants[idx].(*CompiledFunction)
            if f.identical(cf) && f.equalSourceMap(cf) {
                return idx
            }
        }
    }
    index = len(c.constants)
    c.constants = append(c.constants, obj)
    c.cfuncCache[key] = append(c.cfuncCache[key], index)
    return
}

func (c *Compiler) emit(node parser.Node, opcode Opcode, operands ...int) int {
    filePos := parser.NoPos
    if node != nil {
        filePos = node.Pos()
    }

    inst := make([]byte, 0, 8)
    inst, err := MakeInstruction(inst, opcode, operands...)
    if err != nil {
        panic(err)
    }

    pos := c.addInstruction(inst)
    c.sourceMap[pos] = int(filePos)

    if c.trace != nil {
        printTrace(c.indent, c.trace, fmt.Sprintf("EMIT  %s",
            FormatInstructions(c.instructions[pos:], pos)[0]))
    }
    return pos
}

func (c *Compiler) addInstruction(b []byte) int {
    posNewIns := len(c.instructions)
    c.instructions = append(c.instructions, b...)
    return posNewIns
}

func (c *Compiler) checkCyclicImports(node parser.Node, modulePath string) error {
    if c.modulePath == modulePath {
        return c.errorf(node, "cyclic module import: %s", modulePath)
    } else if c.parent != nil {
        return c.parent.checkCyclicImports(node, modulePath)
    }
    return nil
}

func (c *Compiler) addModule(name string, typ, constantIndex int) moduleStoreItem {
    moduleIndex := c.moduleStore.count
    c.moduleStore.count++
    c.moduleStore.store[name] = moduleStoreItem{
        typ:           typ,
        constantIndex: constantIndex,
        moduleIndex:   moduleIndex,
    }
    return c.moduleStore.store[name]
}

func (c *Compiler) getModule(name string) (moduleStoreItem, bool) {
    indexes, ok := c.moduleStore.store[name]
    return indexes, ok
}

func (c *Compiler) baseModuleMap() *ModuleMap {
    if c.parent == nil {
        return c.moduleMap
    }
    return c.parent.baseModuleMap()
}

func (c *Compiler) compileModule(
    node parser.Node,
    modulePath string,
    moduleMap *ModuleMap,
    src []byte,
) (int, error) {
    var err error
    if err = c.checkCyclicImports(node, modulePath); err != nil {
        return 0, err
    }

    modFile := c.file.Set().AddFile(modulePath, -1, len(src))
    var trace io.Writer
    if c.opts.TraceParser {
        trace = c.trace
    }

    p := parser.NewParser(modFile, src, trace)
    var file *parser.File
    file, err = p.ParseFile()
    if err != nil {
        return 0, err
    }

    symbolTable := NewSymbolTable().
        DisableBuiltin(c.symbolTable.DisabledBuiltins()...)

    fork := c.fork(modFile, modulePath, moduleMap, symbolTable)
    err = fork.optimize(file)
    if err != nil && err != errSkip {
        return 0, c.error(node, err)
    }

    if err = fork.Compile(file); err != nil {
        return 0, err
    }

    bc := fork.Bytecode()
    if bc.Main.NumLocals > 256 {
        return 0, c.error(node, ErrSymbolLimit)
    }

    c.constants = bc.Constants
    index := c.addConstant(bc.Main)
    return index, nil
}

func (c *Compiler) enterLoop() *loopStmts {
    loop := &loopStmts{lastTryCatchIndex: c.tryCatchIndex}
    c.loops = append(c.loops, loop)
    c.loopIndex++

    if c.trace != nil {
        printTrace(c.indent, c.trace, "LOOPE", c.loopIndex)
    }
    return loop
}

func (c *Compiler) leaveLoop() {
    if c.trace != nil {
        printTrace(c.indent, c.trace, "LOOPL", c.loopIndex)
    }
    c.loops = c.loops[:len(c.loops)-1]
    c.loopIndex--
}

func (c *Compiler) currentLoop() *loopStmts {
    if c.loopIndex >= 0 {
        return c.loops[c.loopIndex]
    }
    return nil
}

func (c *Compiler) fork(
    file *parser.SourceFile,
    modulePath string,
    moduleMap *ModuleMap,
    symbolTable *SymbolTable,
) *Compiler {
    child := NewCompiler(file, CompilerOptions{
        ModuleMap:         moduleMap,
        ModulePath:        modulePath,
        Constants:         c.constants,
        SymbolTable:       symbolTable,
        Trace:             c.trace,
        TraceParser:       c.opts.TraceParser,
        TraceCompiler:     c.opts.TraceCompiler,
        TraceOptimizer:    c.opts.TraceOptimizer,
        OptimizerMaxCycle: c.opts.OptimizerMaxCycle,
        OptimizeConst:     c.opts.OptimizeConst,
        OptimizeExpr:      c.opts.OptimizeExpr,
        moduleStore:       c.moduleStore,
        constsCache:       c.constsCache,
    })

    child.parent = c
    child.cfuncCache = c.cfuncCache

    if modulePath == c.modulePath {
        child.indent = c.indent
    }
    return child
}

func (c *Compiler) error(node parser.Node, err error) error {
    return &CompilerError{
        FileSet: c.file.Set(),
        Node:    node,
        Err:     err,
    }
}

func (c *Compiler) errorf(
    node parser.Node,
    format string,
    args ...interface{},
) error {
    return &CompilerError{
        FileSet: c.file.Set(),
        Node:    node,
        Err:     fmt.Errorf(format, args...),
    }
}

func printTrace(indent int, trace io.Writer, a ...interface{}) {
    const (
        dots = ". . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . "
        n    = len(dots)
    )

    i := 2 * indent
    for i > n {
        _, _ = fmt.Fprint(trace, dots)
        i -= n
    }

    _, _ = fmt.Fprint(trace, dots[0:i])
    _, _ = fmt.Fprintln(trace, a...)
}

func tracec(c *Compiler, msg string) *Compiler {
    printTrace(c.indent, c.trace, msg, "{")
    c.indent++
    return c
}

func untracec(c *Compiler) {
    c.indent--
    printTrace(c.indent, c.trace, "}")
}

// MakeInstruction returns a bytecode for an Opcode and the operands.
//
// Provide "buf" slice which is a returning value to reduce allocation or nil
// to create new byte slice. This is implemented to reduce compilation
// allocation that resulted in -15% allocation, +2% speed in compiler.
// It takes ~8ns/op with zero allocation.
//
// Returning error is required to identify bugs faster when VM and Opcodes are
// under heavy development.
//
// Warning: Unknown Opcode causes panic!
func MakeInstruction(buf []byte, op Opcode, args ...int) ([]byte, error) {
    operands := OpcodeOperands[op]
    if len(operands) != len(args) {
        return buf, fmt.Errorf(
            "MakeInstruction: %s expected %d operands, but got %d",
            OpcodeNames[op], len(operands), len(args),
        )
    }

    buf = append(buf[:0], op)
    switch op {
    case OpConstant, OpMap, OpArray, OpGetGlobal, OpSetGlobal, OpJump,
        OpJumpFalsy, OpAndJump, OpOrJump, OpStoreModule:
        buf = append(buf, byte(args[0]>>8))
        buf = append(buf, byte(args[0]))
        return buf, nil
    case OpLoadModule, OpSetupTry:
        buf = append(buf, byte(args[0]>>8))
        buf = append(buf, byte(args[0]))
        buf = append(buf, byte(args[1]>>8))
        buf = append(buf, byte(args[1]))
        return buf, nil
    case OpClosure:
        buf = append(buf, byte(args[0]>>8))
        buf = append(buf, byte(args[0]))
        buf = append(buf, byte(args[1]))
        return buf, nil
    case OpCall, OpCallName:
        buf = append(buf, byte(args[0]))
        buf = append(buf, byte(args[1]))
        return buf, nil
    case OpGetBuiltin, OpReturn, OpBinaryOp, OpUnary, OpGetIndex, OpGetLocal,
        OpSetLocal, OpGetFree, OpSetFree, OpGetLocalPtr, OpGetFreePtr, OpThrow,
        OpFinalizer, OpDefineLocal:
        buf = append(buf, byte(args[0]))
        return buf, nil
    case OpEqual, OpNotEqual, OpNull, OpTrue, OpFalse, OpPop, OpSliceIndex,
        OpSetIndex, OpIterInit, OpIterNext, OpIterKey, OpIterValue,
        OpSetupCatch, OpSetupFinally, OpNoOp:
        return buf, nil
    default:
        return buf, &Error{
            Name:    "MakeInstruction",
            Message: fmt.Sprintf("unknown Opcode %d %s", op, OpcodeNames[op]),
        }
    }
}

// FormatInstructions returns string representation of bytecode instructions.
func FormatInstructions(b []byte, posOffset int) []string {
    var out []string
    var operands = make([]int, 0, 4)
    var offset int
    var i int

    for i < len(b) {
        numOperands := OpcodeOperands[b[i]]
        operands, offset = ReadOperands(numOperands, b[i+1:], operands)

        switch len(numOperands) {
        case 0:
            out = append(out, fmt.Sprintf("%04d %-7s",
                posOffset+i, OpcodeNames[b[i]]))
        case 1:
            out = append(out, fmt.Sprintf("%04d %-7s %-5d",
                posOffset+i, OpcodeNames[b[i]], operands[0]))
        case 2:
            out = append(out, fmt.Sprintf("%04d %-7s %-5d %-5d",
                posOffset+i, OpcodeNames[b[i]],
                operands[0], operands[1]))
        }
        i += 1 + offset
    }
    return out
}

// IterateInstructions iterate instructions and call given function for each instruction.
// Note: Do not use operands slice in callback, it is reused for less allocation.
func IterateInstructions(insts []byte,
    fn func(pos int, opcode Opcode, operands []int, offset int) bool) {
    operands := make([]int, 0, 4)
    var offset int

    for i := 0; i < len(insts); i++ {
        numOperands := OpcodeOperands[insts[i]]
        operands, offset = ReadOperands(numOperands, insts[i+1:], operands)
        if !fn(i, insts[i], operands, offset) {
            break
        }
        i += offset
    }
}

func newModuleStore() *moduleStore {
    return &moduleStore{
        store: make(map[string]moduleStoreItem),
    }
}

func (ms *moduleStore) reset() *moduleStore {
    ms.count = 0
    for k := range ms.store {
        delete(ms.store, k)
    }
    return ms
}