pkg/infrastructure/ast/ast_java/ast_api_java/java_api_listener.go

Summary

Maintainability
A
25 mins
Test Coverage
package ast_api_java

import (
    "github.com/antlr/antlr4/runtime/Go/antlr/v4"
    "github.com/modernizing/coca/languages/java"
    api_domain2 "github.com/modernizing/coca/pkg/domain/api_domain"
    "github.com/modernizing/coca/pkg/domain/core_domain"
    "reflect"
    "strings"
)

var jClassNodes []core_domain.CodeDataStruct

var hasEnterClass = false
var isSpringRestController = false
var hasEnterRestController = false
var baseApiUrl string
var localVars = make(map[string]string)

var currentRestAPI api_domain2.RestAPI
var restAPIs []api_domain2.RestAPI
var currentClz string
var currentPkg string

var identMap map[string]core_domain.CodeDataStruct
var imports []string
var currentImplements = ""

func NewJavaAPIListener(jIdentMap map[string]core_domain.CodeDataStruct, diMap map[string]string) *JavaAPIListener {
    isSpringRestController = false
    currentClz = ""
    currentPkg = ""
    currentImplements = ""

    imports = nil
    restAPIs = nil

    identMap = jIdentMap

    params := make(map[string]string)
    currentRestAPI = api_domain2.RestAPI{MethodParams: params}
    return &JavaAPIListener{}
}

type JavaAPIListener struct {
    parser.BaseJavaParserListener
}

func (s *JavaAPIListener) EnterImportDeclaration(ctx *parser.ImportDeclarationContext) {
    importText := ctx.QualifiedName().GetText()
    imports = append(imports, importText)
}

func (s *JavaAPIListener) EnterPackageDeclaration(ctx *parser.PackageDeclarationContext) {
    currentPkg = ctx.QualifiedName().GetText()
}

func (s *JavaAPIListener) EnterClassDeclaration(ctx *parser.ClassDeclarationContext) {
    hasEnterClass = true
    if ctx.Identifier() != nil {
        currentClz = ctx.Identifier().GetText()
    }

    if ctx.IMPLEMENTS() != nil {
        currentImplements = ctx.TypeList(0).GetText()
    }
}

func (s *JavaAPIListener) ExitClassDeclaration(ctx *parser.ClassDeclarationContext) {
    hasEnterClass = false
}

func (s *JavaAPIListener) EnterAnnotation(ctx *parser.AnnotationContext) {
    if ctx.QualifiedName() == nil {
        return
    }

    annotationName := ctx.QualifiedName().GetText()
    if annotationName == "RestController" || annotationName == "Controller" {
        isSpringRestController = true
    }

    if !isSpringRestController {
        return
    }

    if !hasEnterClass {
        buildBaseApiUrlString(annotationName, ctx)
    }

    notAPI := annotationName == "RequestMapping" || annotationName == "GetMapping" || annotationName == "PutMapping" || annotationName == "PostMapping" || annotationName == "DeleteMapping"
    if !notAPI {
        return
    }

    hasEnterRestController = true
    uri := ""
    if ctx.ElementValue() != nil {
        uri = baseApiUrl + ctx.ElementValue().GetText()
    } else {
        uri = baseApiUrl
    }

    uriRemoveQuote := strings.ReplaceAll(uri, "\"", "")

    currentRestAPI = api_domain2.RestAPI{Uri: uriRemoveQuote}
    if annotationName != "RequestMapping" {
        if hasEnterClass {
            addApiMethod(annotationName)
        }

        return
    }

    if ctx.ElementValuePairs() != nil {
        allValuePair := ctx.ElementValuePairs().(*parser.ElementValuePairsContext).AllElementValuePair()
        for _, valuePair := range allValuePair {
            pair := valuePair.(*parser.ElementValuePairContext)
            if pair.Identifier().GetText() == "method" {
                addApiMethod(pair.ElementValue().GetText())
            }
            if pair.Identifier().GetText() == "value" {
                text := pair.ElementValue().GetText()
                currentRestAPI.Uri = baseApiUrl + text[1:len(text)-1]
            }
        }
    }
}

func buildBaseApiUrlString(annotationName string, ctx *parser.AnnotationContext) {
    // 类声明处的注解
    if annotationName == "RequestMapping" {
        if ctx.ElementValuePairs() != nil {
            allValuePair := ctx.ElementValuePairs().(*parser.ElementValuePairsContext).AllElementValuePair()
            for _, valuePair := range allValuePair {
                pair := valuePair.(*parser.ElementValuePairContext)
                if pair.Identifier().GetText() == "value" {
                    text := pair.ElementValue().GetText()
                    baseApiUrl = text[1 : len(text)-1]
                }
            }
        } else if ctx.ElementValue() != nil {
            text := ctx.ElementValue().GetText()
            baseApiUrl = text[1 : len(text)-1]
        } else {
            baseApiUrl = "/"
        }
    }
}

func addApiMethod(annotationName string) {
    switch annotationName {
    case
        "GetMapping",
        "RequestMethod.GET",
        "GET":
        currentRestAPI.HttpMethod = "GET"

    case
        "PutMapping",
        "RequestMethod.PUT",
        "PUT":
        currentRestAPI.HttpMethod = "PUT"

    case
        "PostMapping",
        "RequestMethod.POST",
        "POST":
        currentRestAPI.HttpMethod = "POST"

    case
        "DeleteMapping",
        "RequestMethod.DELETE",
        "DELETE":

        currentRestAPI.HttpMethod = "DELETE"
    }
}

var requestBodyClass string

func (s *JavaAPIListener) EnterMethodDeclaration(ctx *parser.MethodDeclarationContext) {
    methodName := ctx.Identifier().GetText()

    if currentImplements != "" {
        if buildApiForInterfaceAnnotation(methodName) {
            return
        }
    }

    if hasEnterRestController && ctx.FormalParameters() != nil {
        if ctx.FormalParameters().GetChild(0) == nil || ctx.FormalParameters().GetChild(1) == nil {
            return
        }

        currentRestAPI.PackageName = currentPkg
        currentRestAPI.ClassName = currentClz
        currentRestAPI.MethodName = methodName
        if ctx.FormalParameters().GetText() == "()" {
            currentRestAPI.RequestBodyClass = requestBodyClass
            hasEnterRestController = false
            requestBodyClass = ""
            restAPIs = append(restAPIs, currentRestAPI)
        } else {
            buildRestApiWithParameters(ctx)
        }
    }

    methodBody := ctx.MethodBody()
    blockContext := methodBody.GetChild(0)
    if reflect.TypeOf(blockContext).String() == "*parser.BlockContext" {
        filterMethodCall(blockContext)
    }
}

func buildApiForInterfaceAnnotation(methodName string) bool {
    var superClz = ""
    for index := range imports {
        imp := imports[index]
        if strings.HasSuffix(imp, "."+currentImplements) {
            superClz = imp
        }
        // TODO: 支持 interface 在同一个包内
    }

    if _, ok := identMap[superClz]; ok {
        for _, method := range identMap[superClz].Functions {
            if method.Name == methodName {
                for _, annotation := range method.Annotations {
                    if annotation.Name == "ServiceMethod" {
                        currentRestAPI.PackageName = currentPkg
                        currentRestAPI.ClassName = currentClz
                        currentRestAPI.MethodName = methodName

                        restAPIs = append(restAPIs, currentRestAPI)
                        return true
                    }
                }
            }
        }
    }
    return false
}

func filterMethodCall(blockContext antlr.Tree) {
    blcStatement := blockContext.(*parser.BlockContext).AllBlockStatement()
    for _, rangeStatement := range blcStatement {
        if reflect.TypeOf(rangeStatement.GetChild(0)).String() == "*parser.StatementContext" {
            statement := rangeStatement.GetChild(0).(*parser.StatementContext)
            if reflect.TypeOf(statement.GetChild(0)).String() == "*parser.ExpressionContext" {
                express := statement.GetChild(0).(*parser.ExpressionContext)
                reflect.TypeOf(express.GetChild(0))
            }
        }
    }
}

func buildRestApiWithParameters(ctx *parser.MethodDeclarationContext) {
    parameterList := ctx.FormalParameters().GetChild(1).(*parser.FormalParameterListContext)
    formalParameter := parameterList.AllFormalParameter()
    for _, param := range formalParameter {
        paramContext := param.(*parser.FormalParameterContext)

        modifiers := paramContext.AllVariableModifier()
        hasRequestBody := false
        for _, modifier := range modifiers {
            childType := reflect.TypeOf(modifier.GetChild(0))
            if childType.String() == "*parser.AnnotationContext" {
                qualifiedName := modifier.GetChild(0).(*parser.AnnotationContext).QualifiedName().GetText()
                if qualifiedName == "RequestBody" {
                    hasRequestBody = true
                }
            }
        }

        paramType := paramContext.TypeType().GetText()
        paramValue := paramContext.VariableDeclaratorId().(*parser.VariableDeclaratorIdContext).Identifier().GetText()

        if hasRequestBody {
            requestBodyClass = paramType
        }

        localVars[paramValue] = paramType
    }
    currentRestAPI.RequestBodyClass = requestBodyClass

    buildMethodParameters(requestBodyClass)

    hasEnterRestController = false
    requestBodyClass = ""
    restAPIs = append(restAPIs, currentRestAPI)
}

func buildMethodParameters(requestBodyClass string) {
    params := make(map[string]string)
    for _, clz := range jClassNodes {
        if clz.NodeName == requestBodyClass {
            for _, field := range clz.Fields {
                params[field.TypeValue] = field.TypeType
            }
        }
    }

    currentRestAPI.MethodParams = params
}

func (s *JavaAPIListener) AppendClasses(classes []core_domain.CodeDataStruct) {
    jClassNodes = classes
}

func (s *JavaAPIListener) GetClassApis() []api_domain2.RestAPI {
    return restAPIs
}