sunboyy/repogen

View on GitHub
internal/spec/parser.go

Summary

Maintainability
A
3 hrs
Test Coverage
package spec

import (
    "go/types"
    "strconv"

    "github.com/fatih/camelcase"
    "github.com/sunboyy/repogen/internal/code"
)

// ParseInterfaceMethod builds the repository MethodSpec for one method
// declared on a repository interface.
//
// namedStruct is the model type; its underlying type must be a struct, or this
// function panics on the type assertion below. method supplies both the name
// (whose camel-case tokens select the operation) and the signature. pkg is
// currently unused.
func ParseInterfaceMethod(pkg *types.Package, namedStruct *types.Named,
    method *types.Func) (MethodSpec, error) {

    structType := namedStruct.Underlying().(*types.Struct)
    signature := method.Type().(*types.Signature)

    return interfaceMethodParser{
        NamedStruct:      namedStruct,
        UnderlyingStruct: structType,
        MethodName:       method.Name(),
        Signature:        signature,
    }.Parse()
}

// interfaceMethodParser carries everything needed to turn a single repository
// interface method into a MethodSpec.
type interfaceMethodParser struct {
    // NamedStruct is the model type; pointers to it are the expected insert
    // arguments and find results.
    NamedStruct      *types.Named
    // UnderlyingStruct is NamedStruct's underlying struct type, used to resolve
    // field names referenced by query and sort tokens.
    UnderlyingStruct *types.Struct
    // MethodName is the interface method's name; its camel-case tokens encode
    // the operation, query, and sort.
    MethodName       string
    // Signature holds the method's parameter and result types, which are
    // validated against the parsed operation.
    Signature        *types.Signature
}

// Parse determines the operation encoded by the method and bundles it with the
// method's name and signature. On failure it returns a zero MethodSpec.
func (p interfaceMethodParser) Parse() (MethodSpec, error) {
    op, err := p.parseMethod()
    if err != nil {
        return MethodSpec{}, err
    }

    spec := MethodSpec{
        Name:      p.MethodName,
        Signature: p.Signature,
        Operation: op,
    }
    return spec, nil
}

// parseMethod splits the method name into camel-case tokens and dispatches on
// the leading verb. The remaining tokens are handed to the verb-specific
// parser; an unrecognized verb yields an unknown-operation error.
func (p interfaceMethodParser) parseMethod() (Operation, error) {
    tokens := camelcase.Split(p.MethodName)
    verb, rest := tokens[0], tokens[1:]

    switch verb {
    case "Insert":
        return p.parseInsertOperation(rest)
    case "Find":
        return p.parseFindOperation(rest)
    case "Update":
        return p.parseUpdateOperation(rest)
    case "Delete":
        return p.parseDeleteOperation(rest)
    case "Count":
        return p.parseCountOperation(rest)
    default:
        return nil, NewUnknownOperationError(verb)
    }
}

// parseInsertOperation builds an InsertOperation for a method of the form
// Insert...(ctx context.Context, entity) (result, error).
//
// The result type selects the mode (see extractInsertReturns). In QueryModeOne
// the entity must be a pointer to the model. In QueryModeMany it must be a
// slice of model pointers. The tokens after "Insert" are not interpreted.
// It returns ErrContextParamRequired when the first parameter is not a
// context, and ErrInvalidParam when the parameter list is not exactly
// (ctx, entity) or the entity has the wrong type.
func (p interfaceMethodParser) parseInsertOperation(tokens []string) (Operation, error) {
    mode, err := p.extractInsertReturns(p.Signature.Results())
    if err != nil {
        return nil, err
    }

    if err := p.validateContextParam(); err != nil {
        return nil, err
    }

    params := p.Signature.Params()
    // validateContextParam only guarantees at least one parameter. Without this
    // check a method declaring just the context would panic on params.At(1).
    if params.Len() != 2 {
        return nil, ErrInvalidParam
    }

    entityType := params.At(1).Type()
    pointerType := types.NewPointer(p.NamedStruct)
    switch {
    case mode == QueryModeOne && !types.Identical(entityType, pointerType):
        return nil, ErrInvalidParam
    case mode == QueryModeMany && !types.Identical(entityType, types.NewSlice(pointerType)):
        return nil, ErrInvalidParam
    }

    return InsertOperation{
        Mode: mode,
    }, nil
}

// extractInsertReturns validates that an insert method returns
// (interface{}, error) or ([]interface{}, error). It returns QueryModeOne or
// QueryModeMany respectively. Any other shape yields an error identifying the
// offending result position.
func (p interfaceMethodParser) extractInsertReturns(returns *types.Tuple) (QueryMode, error) {
    if returns.Len() != 2 {
        return "", NewOperationReturnCountUnmatchedError(2)
    }

    if errType := returns.At(1).Type(); !types.Identical(errType, code.TypeError) {
        return "", NewUnsupportedReturnError(errType, 1)
    }

    resultType := returns.At(0).Type()
    if iface, ok := resultType.(*types.Interface); ok && iface.Empty() {
        return QueryModeOne, nil
    }
    if slice, ok := resultType.(*types.Slice); ok {
        if elem, ok := slice.Elem().(*types.Interface); ok && elem.Empty() {
            return QueryModeMany, nil
        }
    }

    return "", NewUnsupportedReturnError(resultType, 0)
}

// parseFindOperation builds a FindOperation from the tokens after "Find".
//
// The tokens are processed in this order: an optional "Top N" limit (rejected
// for single-result finds), the query part, and an optional "Order By" sort
// part. The query arguments are then validated against the method parameters.
func (p interfaceMethodParser) parseFindOperation(tokens []string) (Operation, error) {
    mode, err := p.extractModelOrSliceReturns(p.Signature.Results())
    if err != nil {
        return nil, err
    }

    limit, remaining, err := p.parseFindTop(tokens)
    switch {
    case err != nil:
        return nil, err
    case limit != 0 && mode == QueryModeOne:
        return nil, ErrLimitOnFindOne
    }

    queryTokens, sortTokens := p.splitQueryAndSortTokens(remaining)

    query, err := p.parseQuery(queryTokens, 1)
    if err != nil {
        return nil, err
    }

    sorts, err := p.parseSort(sortTokens)
    if err != nil {
        return nil, err
    }

    if err = p.validateQueryOnlyParams(query); err != nil {
        return nil, err
    }

    return FindOperation{Mode: mode, Query: query, Sorts: sorts, Limit: limit}, nil
}

// parseFindTop consumes an optional leading "Top", <number> pair.
//
// Without a "Top" prefix it returns a zero limit and the tokens unchanged.
// With one, it returns the parsed limit and the tokens after the number. It
// returns ErrLimitAmountRequired when the number is missing or not an integer,
// and ErrLimitNonPositive when it is zero or negative.
func (p interfaceMethodParser) parseFindTop(tokens []string) (int, []string, error) {
    if len(tokens) == 0 || tokens[0] != "Top" {
        return 0, tokens, nil
    }

    if len(tokens) == 1 {
        return 0, nil, ErrLimitAmountRequired
    }

    amount, convErr := strconv.Atoi(tokens[1])
    switch {
    case convErr != nil:
        return 0, nil, ErrLimitAmountRequired
    case amount <= 0:
        return 0, nil, ErrLimitNonPositive
    }

    return amount, tokens[2:], nil
}

// parseSort turns the sort part of a find method name into Sort values.
//
// rawTokens is expected to begin with "Order", "By", as produced by
// splitQueryAndSortTokens. The tokens after that pair are split on "And", and
// each group becomes one Sort. An empty input means "no sorting" and yields
// nil.
func (p interfaceMethodParser) parseSort(rawTokens []string) ([]Sort, error) {
    if len(rawTokens) == 0 {
        return nil, nil
    }

    fieldGroups, ok := splitByAnd(rawTokens[2:])
    if !ok {
        return nil, NewInvalidSortError(rawTokens)
    }

    sorts := make([]Sort, 0, len(fieldGroups))
    for _, group := range fieldGroups {
        s, err := p.parseSortToken(group)
        if err != nil {
            return nil, err
        }
        sorts = append(sorts, s)
    }

    return sorts, nil
}

// parseSortToken interprets one sort group. A trailing "Asc" or "Desc" sets
// the ordering and is stripped before field resolution, but only when at
// least one field token precedes it. Groups without a suffix sort ascending.
func (p interfaceMethodParser) parseSortToken(t []string) (Sort, error) {
    if len(t) > 1 {
        fieldTokens := t[:len(t)-1]
        switch t[len(t)-1] {
        case "Asc":
            return p.createSort(fieldTokens, OrderingAscending)
        case "Desc":
            return p.createSort(fieldTokens, OrderingDescending)
        }
    }
    return p.createSort(t, OrderingAscending)
}

// createSort resolves the field tokens against the model struct and pairs the
// resulting field reference with the requested ordering. Unresolvable tokens
// produce a struct-field-not-found error.
func (p interfaceMethodParser) createSort(t []string, ordering Ordering) (Sort, error) {
    fieldRef, found := resolveStructField(p.UnderlyingStruct, t)
    if found {
        return Sort{FieldReference: fieldRef, Ordering: ordering}, nil
    }
    return Sort{}, NewStructFieldNotFoundError(t)
}

// splitQueryAndSortTokens splits find-method tokens at the first adjacent
// "Order", "By" pair.
//
// Tokens before the pair are returned as query tokens, copied into a new
// slice. The pair and everything after it are returned as sort tokens, or nil
// when no such pair exists. A trailing "Order" with nothing after it (e.g. a
// field named Order in FindByOrder) is an ordinary query token.
func (p interfaceMethodParser) splitQueryAndSortTokens(tokens []string) ([]string, []string) {
    var queryTokens []string

    for i, token := range tokens {
        // The lookahead must stay in bounds: "Order" can be the last token.
        if token == "Order" && i+1 < len(tokens) && tokens[i+1] == "By" {
            return queryTokens, tokens[i:]
        }
        queryTokens = append(queryTokens, token)
    }

    return queryTokens, nil
}

// extractModelOrSliceReturns validates that a find method returns
// (*Model, error) or ([]*Model, error), where Model is p.NamedStruct. It
// returns QueryModeOne or QueryModeMany respectively. Any other shape yields an
// error identifying the offending result position.
func (p interfaceMethodParser) extractModelOrSliceReturns(returns *types.Tuple) (QueryMode, error) {
    if returns.Len() != 2 {
        return "", NewOperationReturnCountUnmatchedError(2)
    }

    if errType := returns.At(1).Type(); !types.Identical(errType, code.TypeError) {
        return "", NewUnsupportedReturnError(errType, 1)
    }

    isModelPointer := func(t types.Type) bool {
        ptr, ok := t.(*types.Pointer)
        return ok && types.Identical(ptr.Elem(), p.NamedStruct)
    }

    resultType := returns.At(0).Type()
    if isModelPointer(resultType) {
        return QueryModeOne, nil
    }
    if slice, ok := resultType.(*types.Slice); ok && isModelPointer(slice.Elem()) {
        return QueryModeMany, nil
    }

    return "", NewUnsupportedReturnError(resultType, 0)
}

// splitByAnd partitions tokens into groups separated by "And". For example,
// ["First", "Name", "And", "Age"] becomes [["First", "Name"], ["Age"]].
//
// It reports false (with a nil result) whenever a group would be empty: an
// empty input, or an "And" at the start, at the end, or directly after
// another "And". Each returned group is a fresh copy and never aliases tokens.
func splitByAnd(tokens []string) ([][]string, bool) {
    var groups [][]string
    start := 0

    // Step one position past the end so the final group is closed by the same
    // code path as groups that end at an explicit "And".
    for i := 0; i <= len(tokens); i++ {
        if i < len(tokens) && tokens[i] != "And" {
            continue
        }
        if i == start {
            return nil, false
        }
        groups = append(groups, append([]string(nil), tokens[start:i]...))
        start = i + 1
    }

    return groups, true
}

// splitUpdateAndQueryTokens splits update-method tokens at the first "By" or
// "All". Tokens before it are returned as update tokens, copied into a new
// slice. That token and everything after it are returned as query tokens, or
// nil when neither keyword appears.
func (p interfaceMethodParser) splitUpdateAndQueryTokens(tokens []string) ([]string, []string) {
    var updateTokens []string

    for i, token := range tokens {
        switch token {
        case "By", "All":
            return updateTokens, tokens[i:]
        }
        updateTokens = append(updateTokens, token)
    }

    return updateTokens, nil
}

// parseDeleteOperation builds a DeleteOperation from the tokens after
// "Delete". A bool result selects QueryModeOne and an int result selects
// QueryModeMany. The query arguments are validated against the parameters
// that follow the context.
func (p interfaceMethodParser) parseDeleteOperation(tokens []string) (Operation, error) {
    mode, err := p.extractIntOrBoolReturns(p.Signature.Results())
    if err != nil {
        return nil, err
    }

    query, err := p.parseQuery(tokens, 1)
    if err != nil {
        return nil, err
    }

    if err = p.validateQueryOnlyParams(query); err != nil {
        return nil, err
    }

    return DeleteOperation{Mode: mode, Query: query}, nil
}

// parseCountOperation builds a CountOperation from the tokens after "Count".
// The method must return (int, error), and its query arguments are validated
// against the parameters that follow the context.
func (p interfaceMethodParser) parseCountOperation(tokens []string) (Operation, error) {
    err := p.validateCountReturns(p.Signature.Results())
    if err != nil {
        return nil, err
    }

    query, err := p.parseQuery(tokens, 1)
    if err != nil {
        return nil, err
    }

    err = p.validateQueryOnlyParams(query)
    if err != nil {
        return nil, err
    }

    return CountOperation{Query: query}, nil
}

// validateCountReturns checks that a count method returns exactly
// (int, error). It reports the first mismatching result position.
func (p interfaceMethodParser) validateCountReturns(returns *types.Tuple) error {
    if returns.Len() != 2 {
        return NewOperationReturnCountUnmatchedError(2)
    }

    if got := returns.At(0).Type(); !types.Identical(got, code.TypeInt) {
        return NewUnsupportedReturnError(got, 0)
    }

    if got := returns.At(1).Type(); !types.Identical(got, code.TypeError) {
        return NewUnsupportedReturnError(got, 1)
    }

    return nil
}

// extractIntOrBoolReturns validates that a method returns (bool, error) or
// (int, error). It returns QueryModeOne for bool and QueryModeMany for int.
// Any other shape yields an error identifying the offending result position.
func (p interfaceMethodParser) extractIntOrBoolReturns(returns *types.Tuple) (QueryMode, error) {
    if returns.Len() != 2 {
        return "", NewOperationReturnCountUnmatchedError(2)
    }

    if errType := returns.At(1).Type(); !types.Identical(errType, code.TypeError) {
        return "", NewUnsupportedReturnError(errType, 1)
    }

    resultType := returns.At(0).Type()
    if basic, isBasic := resultType.(*types.Basic); isBasic {
        switch {
        case types.Identical(basic, code.TypeBool):
            return QueryModeOne, nil
        case types.Identical(basic, code.TypeInt):
            return QueryModeMany, nil
        }
    }

    return "", NewUnsupportedReturnError(resultType, 0)
}

// validateQueryOnlyParams checks the parameter list of a method whose only
// inputs are a context followed by the query arguments. It checks the context
// first, then the query arguments starting at parameter index 1.
func (p interfaceMethodParser) validateQueryOnlyParams(querySpec QuerySpec) error {
    err := p.validateContextParam()
    if err == nil {
        err = p.validateQueryFromParams(p.Signature.Params(), 1, querySpec)
    }
    return err
}

// validateContextParam requires the first parameter's type to print as
// "context.Context". A method with no parameters fails the check. On failure
// it returns ErrContextParamRequired.
func (p interfaceMethodParser) validateContextParam() error {
    params := p.Signature.Params()
    if params.Len() > 0 && params.At(0).Type().String() == "context.Context" {
        return nil
    }
    return ErrContextParamRequired
}

// validateQueryFromParams checks that params[startIndex:] supplies exactly the
// arguments querySpec needs, in predicate order, with matching types.
//
// The True/False comparators may only be applied to bool fields. For every
// argument a comparator consumes, the parameter type must be identical to the
// type the comparator derives from the referenced field's type.
func (p interfaceMethodParser) validateQueryFromParams(params *types.Tuple, startIndex int, querySpec QuerySpec) error {
    if querySpec.NumberOfArguments() != params.Len()-startIndex {
        return ErrInvalidParam
    }

    paramIndex := startIndex
    for _, predicate := range querySpec.Predicates {
        comparator := predicate.Comparator

        if comparator == ComparatorTrue || comparator == ComparatorFalse {
            field := predicate.FieldReference.ReferencedField()
            if !types.Identical(field.Var.Type(), code.TypeBool) {
                return NewIncompatibleComparatorError(comparator, field)
            }
        }

        for remaining := comparator.NumberOfArguments(); remaining > 0; remaining-- {
            requiredType := comparator.ArgumentTypeFromFieldType(
                predicate.FieldReference.ReferencedField().Var.Type(),
            )
            actualType := params.At(paramIndex).Type()
            if !types.Identical(actualType, requiredType) {
                return NewArgumentTypeNotMatchedError(predicate.FieldReference.ReferencingCode(),
                    requiredType, actualType)
            }
            paramIndex++
        }
    }

    return nil
}

// parseQuery delegates query-token parsing to a queryParser bound to the
// model's underlying struct. paramIndex is passed through unchanged.
func (p interfaceMethodParser) parseQuery(queryTokens []string, paramIndex int) (QuerySpec, error) {
    qp := queryParser{UnderlyingStruct: p.UnderlyingStruct}
    return qp.parseQuery(queryTokens, paramIndex)
}