goruby/goruby

View on GitHub
evaluator/evaluator.go

Summary

Maintainability
F
6 days
Test Coverage
package evaluator

import (
    "fmt"
    "sort"
    "strings"

    "github.com/goruby/goruby/ast"
    "github.com/goruby/goruby/object"
    "github.com/pkg/errors"
)

// callContext wraps an object.CallContext so that any evaluation requested
// through it is routed back into this package's Eval function.
type callContext struct {
	object.CallContext
}

// Eval evaluates node in env by delegating to the package-level Eval.
func (cc *callContext) Eval(node ast.Node, env object.Environment) (object.RubyObject, error) {
	return Eval(node, env)
}

// rubyObjects is a transient RubyObject holding the values produced by
// evaluating an ast.ExpressionList (for example the right-hand side of a
// multiple assignment). It has no Ruby type and no Ruby class of its own.
type rubyObjects []object.RubyObject

// Inspect renders each element with its own Inspect, separated by ", ".
func (r rubyObjects) Inspect() string {
	var b strings.Builder
	for i, obj := range r {
		if i > 0 {
			b.WriteString(", ")
		}
		b.WriteString(obj.Inspect())
	}
	return b.String()
}

// Type returns the empty object.Type: rubyObjects is not a real Ruby value.
func (r rubyObjects) Type() object.Type { return "" }

// Class returns nil because rubyObjects has no Ruby class.
func (r rubyObjects) Class() object.RubyClass { return nil }

// expandToArrayIfNeeded turns a rubyObjects value (the result of evaluating
// an expression list) into a Ruby Array holding the same elements. Any other
// object is returned unchanged.
func expandToArrayIfNeeded(obj object.RubyObject) object.RubyObject {
	if list, isList := obj.(rubyObjects); isList {
		return object.NewArray(list...)
	}
	return obj
}

// Eval evaluates node in env and returns the resulting Ruby object. It
// dispatches on the concrete AST node type and recurses into child nodes.
//
// A return statement evaluates to an *object.ReturnValue wrapper (evalProgram
// unwraps it). Comments and a nil node evaluate to (nil, nil). Errors from
// child evaluation are returned with context attached via github.com/pkg/errors,
// and an unknown node type yields an error.
func Eval(node ast.Node, env object.Environment) (object.RubyObject, error) {
	switch node := node.(type) {
	// Statements
	case *ast.Program:
		return evalProgram(node.Statements, env)
	case *ast.ExpressionStatement:
		return Eval(node.Expression, env)
	case *ast.ReturnStatement:
		val, err := Eval(node.ReturnValue, env)
		if err != nil {
			return nil, errors.WithMessage(err, "eval of return statement")
		}
		return &object.ReturnValue{Value: val}, nil
	case *ast.BlockStatement:
		return evalBlockStatement(node, env)

	// Literals
	case *ast.IntegerLiteral:
		return object.NewInteger(node.Value), nil
	case *ast.Boolean:
		return nativeBoolToBooleanObject(node.Value), nil
	case *ast.Nil:
		return object.NIL, nil
	case *ast.Self:
		self, _ := env.Get("self")
		return self, nil
	case *ast.Keyword__FILE__:
		return &object.String{Value: node.Filename}, nil
	case *ast.InstanceVariable:
		self, _ := env.Get("self")
		selfObj := self.(*object.Self)
		selfAsEnv, ok := selfObj.RubyObject.(object.Environment)
		if !ok {
			return nil, errors.WithStack(
				object.NewSyntaxError(
					fmt.Errorf("instance variable not allowed for %s", selfObj.Name),
				),
			)
		}
		// An instance variable that was never assigned reads as nil.
		val, ok := selfAsEnv.Get(node.String())
		if !ok {
			return object.NIL, nil
		}
		return val, nil
	case *ast.Identifier:
		return evalIdentifier(node, env)
	case *ast.Global:
		val, ok := env.Get(node.Value)
		if !ok {
			return object.NIL, nil
		}
		return val, nil
	case *ast.StringLiteral:
		return &object.String{Value: node.Value}, nil
	case *ast.SymbolLiteral:
		switch value := node.Value.(type) {
		case *ast.Identifier:
			return &object.Symbol{Value: value.Value}, nil
		case *ast.StringLiteral:
			str, err := Eval(value, env)
			if err != nil {
				return nil, errors.WithMessage(err, "eval symbol literal string")
			}
			if str, ok := str.(*object.String); ok {
				return &object.Symbol{Value: str.Value}, nil
			}
			// A string literal should always evaluate to *object.String;
			// report a broken invariant as an error instead of crashing.
			return nil, errors.WithStack(
				fmt.Errorf("error while parsing SymbolLiteral: expected *object.String, got %T", str),
			)
		default:
			return nil, errors.WithStack(
				object.NewSyntaxError(fmt.Errorf("malformed symbol AST: %T", value)),
			)
		}
	case *ast.FunctionLiteral:
		context, _ := env.Get("self")
		_, inClassOrModule := context.(*object.Self).RubyObject.(object.Environment)
		if node.Receiver != nil {
			rec, err := Eval(node.Receiver, env)
			if err != nil {
				return nil, errors.WithMessage(err, "eval function receiver")
			}
			context = rec
			_, recIsEnv := context.(object.Environment)
			if recIsEnv || inClassOrModule {
				inClassOrModule = true
				context = context.Class().(object.RubyClassObject)
			}
		}
		params := make([]*object.FunctionParameter, len(node.Parameters))
		for i, param := range node.Parameters {
			def, err := Eval(param.Default, env)
			if err != nil {
				return nil, errors.WithMessage(err, "eval function literal param")
			}
			params[i] = &object.FunctionParameter{Name: param.Name.Value, Default: def}
		}
		function := &object.Function{
			Parameters: params,
			Env:        env,
			Body:       node.Body,
		}
		extended := object.AddMethod(context, node.Name.Value, function)
		if node.Receiver != nil && !inClassOrModule {
			// Rebind the receiver's name to the object AddMethod returned, in
			// the environment EnvStat locates for it (presumably where the
			// receiver variable was defined — verify against object.EnvStat).
			envInfo, _ := object.EnvStat(env, context)
			envInfo.Env().Set(node.Receiver.Value, extended)
		}
		return &object.Symbol{Value: node.Name.Value}, nil
	case *ast.BlockExpression:
		return &object.Proc{
			Parameters: node.Parameters,
			Body:       node.Body,
			Env:        env,
		}, nil
	case *ast.ArrayLiteral:
		elements, err := evalExpressions(node.Elements, env)
		if err != nil {
			return nil, errors.WithMessage(err, "eval array literal")
		}
		return &object.Array{Elements: elements}, nil
	case *ast.HashLiteral:
		var hash object.Hash
		for k, v := range node.Map {
			key, err := Eval(k, env)
			if err != nil {
				return nil, errors.WithMessage(err, "eval hash key")
			}
			value, err := Eval(v, env)
			if err != nil {
				return nil, errors.WithMessage(err, "eval hash value")
			}
			hash.Set(key, value)
		}
		return &hash, nil
	case ast.ExpressionList:
		var objects []object.RubyObject
		for _, e := range node {
			obj, err := Eval(e, env)
			if err != nil {
				return nil, errors.WithMessage(err, "eval expression list")
			}
			objects = append(objects, obj)
		}
		return rubyObjects(objects), nil

	// Expressions
	case *ast.Assignment:
		right, err := Eval(node.Right, env)
		if err != nil {
			return nil, errors.WithMessage(err, "eval right hand Assignment side")
		}

		switch left := node.Left.(type) {
		case *ast.IndexExpression:
			indexLeft, err := Eval(left.Left, env)
			if err != nil {
				return nil, errors.WithMessage(err, "eval left hand Assignment side: eval left side of IndexExpression")
			}
			index, err := Eval(left.Index, env)
			if err != nil {
				return nil, errors.WithMessage(err, "eval left hand Assignment side: eval right side of IndexExpression")
			}
			return evalIndexExpressionAssignment(indexLeft, index, expandToArrayIfNeeded(right))
		case *ast.InstanceVariable:
			self, _ := env.Get("self")
			selfObj := self.(*object.Self)
			selfAsEnv, ok := selfObj.RubyObject.(object.Environment)
			if !ok {
				return nil, errors.Wrap(
					object.NewSyntaxError(fmt.Errorf("instance variable not allowed for %s", selfObj.Name)),
					"eval left hand Assignment side",
				)
			}

			right = expandToArrayIfNeeded(right)
			selfAsEnv.Set(left.String(), right)
			return right, nil
		case *ast.Identifier:
			right = expandToArrayIfNeeded(right)
			env.Set(left.Value, right)
			return right, nil
		case *ast.Global:
			right = expandToArrayIfNeeded(right)
			env.SetGlobal(left.Value, right)
			return right, nil
		case ast.ExpressionList:
			// Multiple assignment: distribute the right-hand values over the
			// targets, padding with nil when there are fewer values than
			// targets.
			values := []object.RubyObject{right}
			if list, ok := right.(rubyObjects); ok {
				values = list
			}
			for len(values) < len(left) {
				values = append(values, object.NIL)
			}
			for i, exp := range left {
				switch target := exp.(type) {
				case *ast.InstanceVariable:
					self, _ := env.Get("self")
					selfObj := self.(*object.Self)
					selfAsEnv, ok := selfObj.RubyObject.(object.Environment)
					if !ok {
						return nil, errors.Wrap(
							object.NewSyntaxError(fmt.Errorf("instance variable not allowed for %s", selfObj.Name)),
							"eval left hand Assignment side",
						)
					}
					selfAsEnv.Set(target.String(), values[i])
				case *ast.IndexExpression:
					indexLeft, err := Eval(target.Left, env)
					if err != nil {
						return nil, errors.WithMessage(err, "eval left hand Assignment side: eval left side of IndexExpression")
					}
					index, err := Eval(target.Index, env)
					if err != nil {
						return nil, errors.WithMessage(err, "eval left hand Assignment side: eval right side of IndexExpression")
					}
					// Propagate failures (e.g. unsupported target or bad
					// index) instead of silently dropping the assignment.
					if _, err := evalIndexExpressionAssignment(indexLeft, index, values[i]); err != nil {
						return nil, errors.WithMessage(err, "eval left hand Assignment side")
					}
				case *ast.Global:
					// Same as single assignment: globals go through SetGlobal,
					// not the local scope.
					env.SetGlobal(target.Value, values[i])
				default:
					env.Set(exp.String(), values[i])
				}
			}
			return expandToArrayIfNeeded(right), nil
		default:
			return nil, errors.WithStack(
				object.NewSyntaxError(fmt.Errorf("Assignment not supported to %T", node.Left)),
			)
		}
	case *ast.ModuleExpression:
		module, ok := env.Get(node.Name.Value)
		if !ok {
			module = object.NewModule(node.Name.Value, env)
		}
		moduleEnv, ok := module.(object.Environment)
		if !ok {
			// The name is already bound to something that cannot hold a
			// module body (e.g. `Foo = 1; module Foo; end`).
			return nil, errors.WithStack(
				object.NewException("%s is not a module", node.Name.Value),
			)
		}
		moduleEnv.Set("self", &object.Self{RubyObject: module, Name: node.Name.Value})
		bodyReturn, err := Eval(node.Body, moduleEnv)
		if err != nil {
			return nil, errors.WithMessage(err, "eval Module body")
		}
		// self is read back because evaluating the body may rebind it.
		selfObject, _ := moduleEnv.Get("self")
		self := selfObject.(*object.Self)
		env.Set(node.Name.Value, self.RubyObject)
		return bodyReturn, nil
	case *ast.ClassExpression:
		superClassName := "Object"
		if node.SuperClass != nil {
			superClassName = node.SuperClass.Value
		}
		superClass, ok := env.Get(superClassName)
		if !ok {
			return nil, errors.Wrap(
				object.NewUninitializedConstantNameError(superClassName),
				"eval class superclass",
			)
		}
		class, ok := env.Get(node.Name.Value)
		if !ok {
			superClassObject, isClass := superClass.(object.RubyClassObject)
			if !isClass {
				return nil, errors.WithStack(
					object.NewException("superclass must be a Class: %s", superClassName),
				)
			}
			class = object.NewClass(node.Name.Value, superClassObject, env)
		}
		classEnv, ok := class.(object.Environment)
		if !ok {
			// The name is already bound to something that cannot hold a
			// class body (e.g. `Foo = 1; class Foo; end`).
			return nil, errors.WithStack(
				object.NewException("%s is not a class", node.Name.Value),
			)
		}
		classEnv.Set("self", &object.Self{RubyObject: class, Name: node.Name.Value})
		bodyReturn, err := Eval(node.Body, classEnv)
		if err != nil {
			return nil, errors.WithMessage(err, "eval class body")
		}
		// self is read back because evaluating the body may rebind it.
		selfObject, _ := classEnv.Get("self")
		self := selfObject.(*object.Self)
		env.Set(node.Name.Value, self.RubyObject)
		return bodyReturn, nil
	case *ast.ContextCallExpression:
		context, err := Eval(node.Context, env)
		if err != nil {
			return nil, errors.WithMessage(err, "eval method call receiver")
		}
		// A missing receiver evaluates to nil; the call then targets self.
		if context == nil {
			context, _ = env.Get("self")
		}
		args, err := evalExpressions(node.Arguments, env)
		if err != nil {
			return nil, errors.WithMessage(err, "eval method call arguments")
		}
		if node.Block != nil {
			block, err := Eval(node.Block, env)
			if err != nil {
				return nil, errors.WithMessage(err, "eval method call block")
			}
			// The block is passed as the trailing argument.
			args = append(args, block)
		}
		callContext := &callContext{object.NewCallContext(env, context)}
		return object.Send(callContext, node.Function.Value, args...)
	case *ast.YieldExpression:
		selfObject, _ := env.Get("self")
		self := selfObject.(*object.Self)
		if self.Block == nil {
			return nil, errors.WithStack(object.NewNoBlockGivenLocalJumpError())
		}
		args, err := evalExpressions(node.Arguments, env)
		if err != nil {
			return nil, errors.WithMessage(err, "eval yield arguments")
		}
		callContext := &callContext{object.NewCallContext(env, self)}
		return self.Block.Call(callContext, args...)
	case *ast.IndexExpression:
		left, err := Eval(node.Left, env)
		if err != nil {
			return nil, errors.WithMessage(err, "eval IndexExpression left side")
		}
		index, err := Eval(node.Index, env)
		if err != nil {
			return nil, errors.WithMessage(err, "eval IndexExpression index")
		}
		return evalIndexExpression(left, index)
	case *ast.PrefixExpression:
		right, err := Eval(node.Right, env)
		if err != nil {
			return nil, errors.WithMessage(err, "eval prefix right side")
		}
		return evalPrefixExpression(node.Operator, right)
	case *ast.InfixExpression:
		left, err := Eval(node.Left, env)
		if err != nil {
			return nil, errors.WithMessage(err, "eval operator left side")
		}

		// Control operators skip the right side when it need not be
		// evaluated and left is already truthy.
		if node.IsControlExpression() && !node.MustEvaluateRight() && isTruthy(left) {
			return left, nil
		}

		right, err := Eval(node.Right, env)
		if err != nil {
			return nil, errors.WithMessage(err, "eval operator right side")
		}
		if node.IsControlExpression() {
			return right, nil
		}
		// Ordinary operators are method calls on the left operand.
		context := &callContext{object.NewCallContext(env, left)}
		return object.Send(context, node.Operator, right)
	case *ast.ConditionalExpression:
		return evalConditionalExpression(node, env)
	case *ast.ScopedIdentifier:
		self, _ := env.Get("self")
		outer, ok := env.Get(node.Outer.Value)
		if !ok {
			return nil, errors.Wrap(
				object.NewUndefinedLocalVariableOrMethodNameError(self, node.Outer.Value),
				"eval scope outer",
			)
		}
		outerEnv, ok := outer.(object.Environment)
		if !ok {
			return nil, errors.Wrap(
				object.NewUndefinedLocalVariableOrMethodNameError(self, node.Outer.Value),
				"eval scope outer",
			)
		}
		inner, err := Eval(node.Inner, outerEnv)
		if err != nil {
			return nil, errors.WithMessage(err, "eval scope inner")
		}
		return inner, nil
	case *ast.ExceptionHandlingBlock:
		bodyReturn, err := Eval(node.TryBody, env)
		if err == nil {
			return bodyReturn, nil
		}
		return handleException(err, node.Rescues, env)

	case *ast.Comment:
		// ignore comments
		return nil, nil

	case nil:
		return nil, nil
	default:
		err := object.NewException("Unknown AST: %T", node)
		return nil, errors.WithStack(err)
	}
}

// evalProgram evaluates top-level statements in order, skipping comments.
// It stops at the first error (wrapped with "eval program statement"). It also
// stops at the first statement that yields an *object.ReturnValue, whose
// inner value becomes the program result. Otherwise it returns the value of
// the last evaluated statement, or nil when no statement ran.
func evalProgram(stmts []ast.Statement, env object.Environment) (object.RubyObject, error) {
	var last object.RubyObject
	for _, stmt := range stmts {
		if _, isComment := stmt.(*ast.Comment); isComment {
			continue
		}
		evaluated, err := Eval(stmt, env)
		if err != nil {
			return nil, errors.WithMessage(err, "eval program statement")
		}
		if ret, isReturn := evaluated.(*object.ReturnValue); isReturn {
			return ret.Value, nil
		}
		last = evaluated
	}
	return last, nil
}

// evalExpressions evaluates exps left to right and collects their values.
// The first error is returned as-is. An empty exps yields a nil slice.
func evalExpressions(exps []ast.Expression, env object.Environment) ([]object.RubyObject, error) {
	var values []object.RubyObject
	for i := range exps {
		value, err := Eval(exps[i], env)
		if err != nil {
			return nil, err
		}
		values = append(values, value)
	}
	return values, nil
}

// evalPrefixExpression applies the unary operator to an already evaluated
// operand. Only "!" and "-" are supported; anything else is an error.
func evalPrefixExpression(operator string, right object.RubyObject) (object.RubyObject, error) {
	if operator == "!" {
		return evalBangOperatorExpression(right), nil
	}
	if operator == "-" {
		return evalMinusPrefixOperatorExpression(right)
	}
	return nil, errors.WithStack(object.NewException("unknown operator: %s%s", operator, right.Type()))
}

// evalBangOperatorExpression implements `!`: it yields TRUE when right is
// FALSE or NIL, and FALSE for every other value.
func evalBangOperatorExpression(right object.RubyObject) object.RubyObject {
	if right == object.FALSE || right == object.NIL {
		return object.TRUE
	}
	return object.FALSE
}

// evalMinusPrefixOperatorExpression negates an *object.Integer operand. Any
// other operand type yields an "unknown operator" error.
func evalMinusPrefixOperatorExpression(right object.RubyObject) (object.RubyObject, error) {
	integer, ok := right.(*object.Integer)
	if !ok {
		return nil, errors.WithStack(object.NewException("unknown operator: -%s", right.Type()))
	}
	return &object.Integer{Value: -integer.Value}, nil
}

// evalConditionalExpression evaluates a conditional. The condition's
// truthiness, inverted when ce.IsNegated() reports true (presumably for
// `unless`), selects the consequence. Otherwise the alternative is evaluated
// when present, and the result is NIL when it is absent. An error from the
// condition is returned as-is.
func evalConditionalExpression(ce *ast.ConditionalExpression, env object.Environment) (object.RubyObject, error) {
	condition, err := Eval(ce.Condition, env)
	if err != nil {
		return nil, err
	}
	if isTruthy(condition) != ce.IsNegated() {
		return Eval(ce.Consequence, env)
	}
	if ce.Alternative == nil {
		return object.NIL, nil
	}
	return Eval(ce.Alternative, env)
}

// evalIndexExpressionAssignment performs `left[index] = right` and returns
// right.
//
// For an *object.Array, index must be an *object.Integer:
//   - A negative index counts from the end, matching reads in
//     evalArrayIndexExpression. One reaching before the first element is an
//     error.
//   - A non-negative index past the end grows the array, filling the gap
//     with NIL.
//
// For an *object.Hash the value is stored under index. Any other target type
// yields an error.
func evalIndexExpressionAssignment(left, index, right object.RubyObject) (object.RubyObject, error) {
	switch target := left.(type) {
	case *object.Array:
		integer, ok := index.(*object.Integer)
		if !ok {
			return nil, errors.Wrap(
				object.NewImplicitConversionTypeError(integer, index),
				"eval array index",
			)
		}
		idx := integer.Value
		if idx < 0 {
			// Indexing the slice with a negative value would panic; translate
			// it relative to the end instead and reject what is still negative.
			idx += int64(len(target.Elements))
			if idx < 0 {
				return nil, errors.Wrap(
					object.NewException(
						"index %d too small for array; minimum: -%d",
						integer.Value, len(target.Elements),
					),
					"eval array index",
				)
			}
		}
		// Grow the array with nils so that idx becomes addressable.
		for int64(len(target.Elements)) <= idx {
			target.Elements = append(target.Elements, object.NIL)
		}
		target.Elements[idx] = right
		return right, nil
	case *object.Hash:
		target.Set(index, right)
		return right, nil
	default:
		return nil, errors.Wrap(
			object.NewException("assignment target not supported: %s", left.Type()),
			"eval IndexExpression Assignment",
		)
	}
}

// evalIndexExpression reads left[index] for Array and Hash receivers. Any
// other receiver type yields an "index operator not supported" error.
func evalIndexExpression(left, index object.RubyObject) (object.RubyObject, error) {
	if arr, isArray := left.(*object.Array); isArray {
		return evalArrayIndexExpression(arr, index), nil
	}
	if hash, isHash := left.(*object.Hash); isHash {
		return evalHashIndexExpression(hash, index), nil
	}
	return nil, errors.WithStack(object.NewException("index operator not supported: %s", left.Type()))
}

// evalArrayIndexExpression returns arrayObject[index]. A negative index counts
// from the end of the array, and any out-of-range index yields NIL.
// NOTE(review): index is asserted to *object.Integer without a check, so any
// other index type panics.
func evalArrayIndexExpression(arrayObject *object.Array, index object.RubyObject) object.RubyObject {
	position := index.(*object.Integer).Value
	length := int64(len(arrayObject.Elements))
	if position < 0 {
		position += length
	}
	if position < 0 || position >= length {
		return object.NIL
	}
	return arrayObject.Elements[position]
}

// evalHashIndexExpression returns the value stored under index in hash, or
// NIL when the key is absent.
func evalHashIndexExpression(hash *object.Hash, index object.RubyObject) object.RubyObject {
	if value, found := hash.Get(index); found {
		return value
	}
	return object.NIL
}

// evalBlockStatement evaluates block's statements in order and returns the
// value of the last one. The result is NIL when the block is empty or the
// last value is nil.
//
// A value whose Type() is RETURN_VALUE_OBJ ends evaluation immediately and is
// returned still wrapped. The first error is returned as-is.
func evalBlockStatement(block *ast.BlockStatement, env object.Environment) (object.RubyObject, error) {
	var last object.RubyObject
	for _, stmt := range block.Statements {
		value, err := Eval(stmt, env)
		if err != nil {
			return nil, err
		}
		if value != nil && value.Type() == object.RETURN_VALUE_OBJ {
			return value, nil
		}
		last = value
	}
	if last == nil {
		return object.NIL, nil
	}
	return last, nil
}

// evalIdentifier resolves a bare identifier:
//   - A name bound in env wins.
//   - An unbound name for which node.IsConstant() is true yields an
//     uninitialized-constant error.
//   - Anything else is sent to self as a method call without arguments.
func evalIdentifier(node *ast.Identifier, env object.Environment) (object.RubyObject, error) {
	if bound, found := env.Get(node.Value); found {
		return bound, nil
	}
	if node.IsConstant() {
		return nil, errors.Wrap(
			object.NewUninitializedConstantNameError(node.Value),
			"eval identifier",
		)
	}

	self, _ := env.Get("self")
	// NOTE(review): any Send failure is reported as "undefined local variable
	// or method". This also masks errors raised inside a method that does
	// exist; confirm this is intended.
	result, sendErr := object.Send(&callContext{object.NewCallContext(env, self)}, node.Value)
	if sendErr != nil {
		return nil, errors.Wrap(
			object.NewUndefinedLocalVariableOrMethodNameError(self, node.Value),
			"eval ident as method call",
		)
	}
	return result, nil
}

// unwrapReturnValue returns the value carried by an *object.ReturnValue, or
// obj itself when obj is not one.
func unwrapReturnValue(obj object.RubyObject) object.RubyObject {
	ret, isReturn := obj.(*object.ReturnValue)
	if !isReturn {
		return obj
	}
	return ret.Value
}

// handleException tries to rescue err with one of the given rescue clauses.
//
// Clauses that name exception classes are tried in order. A clause matches
// when one of its classes is the raised exception's class or one of its
// superclasses, as listed by getAncestors. A clause without classes is a
// catch-all; it is consulted only after every class clause failed, and only
// rescues exceptions whose ancestry includes StandardError. When the matching
// clause binds a variable (`rescue Foo => e`), the exception is assigned to it
// in a scope created for the rescue body.
//
// It returns the rescue body's result, or (nil, err) when no clause matches or
// err does not wrap a Ruby exception object.
func handleException(err error, rescues []*ast.RescueBlock, env object.Environment) (object.RubyObject, error) {
	if len(rescues) == 0 {
		return nil, err
	}
	// Eval wraps errors with context via pkg/errors; the Ruby exception
	// object is the innermost cause. A bare type assertion on err would
	// panic for any wrapped error.
	errorObject, ok := errors.Cause(err).(object.RubyObject)
	if !ok {
		return nil, err
	}

	// Sorted so that membership can be tested with a binary search.
	// SearchStrings only returns an insertion point, so the element found
	// there must be compared as well.
	ancestors := getAncestors(errorObject)
	sort.Strings(ancestors)
	isA := func(className string) bool {
		i := sort.SearchStrings(ancestors, className)
		return i < len(ancestors) && ancestors[i] == className
	}

	rescueEnv := object.WithScopedLocalVariables(env)
	var catchAll *ast.RescueBlock
	for _, r := range rescues {
		if len(r.ExceptionClasses) == 0 {
			catchAll = r
			continue
		}
		for _, cl := range r.ExceptionClasses {
			if !isA(cl.Value) {
				continue
			}
			// Bind the variable only for the clause that actually matched.
			if r.Exception != nil {
				rescueEnv.Set(r.Exception.Value, errorObject)
			}
			return Eval(r.Body, rescueEnv)
		}
	}

	// A bare rescue only catches StandardError and its descendants.
	if catchAll == nil || !isA("StandardError") {
		return nil, err
	}
	if catchAll.Exception != nil {
		rescueEnv.Set(catchAll.Exception.Value, errorObject)
	}
	return Eval(catchAll.Body, rescueEnv)
}

// getAncestors returns class names walking up the superclass chain, nearest
// first. The walk starts at obj itself when obj is a RubyClass, otherwise at
// obj.Class(). At each step the current object's class (or the object itself
// when it is a RubyClass) contributes its Name(). The walk stops when
// SuperClass() returns nil.
func getAncestors(obj object.RubyObject) []string {
	var names []string
	current := obj
	for {
		class := current.Class()
		if asClass, isClass := current.(object.RubyClass); isClass {
			class = asClass
		}
		names = append(names, class.Name())

		parent := class.SuperClass()
		if parent == nil {
			return names
		}
		current = parent.(object.RubyClassObject)
	}
}

// isTruthy applies Ruby truthiness: only NIL and FALSE are falsy, and every
// other value (including TRUE) is truthy.
func isTruthy(obj object.RubyObject) bool {
	return obj != object.NIL && obj != object.FALSE
}

// IsError reports whether obj is non-nil and its Type() equals
// object.EXCEPTION_OBJ.
func IsError(obj object.RubyObject) bool {
	return obj != nil && obj.Type() == object.EXCEPTION_OBJ
}

// nativeBoolToBooleanObject maps a Go bool onto the shared object.TRUE and
// object.FALSE singletons.
func nativeBoolToBooleanObject(input bool) object.RubyObject {
	if !input {
		return object.FALSE
	}
	return object.TRUE
}