akito0107/generr

View on GitHub
generator.go

Summary

Maintainability
C
1 day
Test Coverage
package generr

import (
    "bytes"
    "fmt"
    "go/ast"
    "go/format"
    "go/token"
    "io"
    "strconv"

    "github.com/iancoleman/strcase"
    "github.com/pkg/errors"
    "golang.org/x/tools/imports"
)

// Generator assembles a generated Go source file as an *ast.File.
// The Append* methods fill in the package clause, imports and declarations,
// and Out renders the result.
type Generator struct {
	pkgName string        // name written in the package clause of the output
	f       *ast.File     // file under construction
	ts      *ast.TypeSpec // spec the code is derived from; the Append*Function/Implementation helpers reject non-interface types
}

// NewGenerator returns a Generator for package pkgName that will derive its
// declarations from ts. Each Generator starts with its own empty *ast.File.
func NewGenerator(pkgName string, ts *ast.TypeSpec) *Generator {
	g := new(Generator)
	g.pkgName = pkgName
	g.ts = ts
	g.f = new(ast.File)
	return g
}

// Generate prepares the output file. Currently it only sets the package
// clause; imports and declarations are added by the other Append* methods.
// The returned error is always nil.
func (g *Generator) Generate() error {
	g.AppendPackage()
	return nil
}

// AppendPackage sets the package clause of the generated file to the
// package name given to NewGenerator, replacing any previous value.
func (g *Generator) AppendPackage() {
	g.f.Name = ast.NewIdent(g.pkgName)
}

// AppendPkgErrorImportSpec adds an import of "github.com/pkg/errors" to the
// generated file, which code emitted with errors.Cause depends on.
//
// go/printer only prints File.Decls, so the import is emitted as an import
// GenDecl placed at the front of Decls (imports must precede every other
// declaration, and this method may be called after declarations were
// appended). The spec is also recorded in File.Imports. Calling the method
// again when the import is already present is a no-op.
func (g *Generator) AppendPkgErrorImportSpec() {
	// ImportSpec.Path is a string literal, so its Value must carry the
	// quotes; an unquoted value would print as invalid Go.
	quoted := strconv.Quote("github.com/pkg/errors")
	for _, is := range g.f.Imports {
		if is.Path != nil && is.Path.Value == quoted {
			return
		}
	}

	spec := &ast.ImportSpec{
		Path: &ast.BasicLit{
			Kind:  token.STRING,
			Value: quoted,
		},
	}
	g.f.Imports = append(g.f.Imports, spec)

	decl := &ast.GenDecl{
		Tok:   token.IMPORT,
		Specs: []ast.Spec{spec},
	}
	g.f.Decls = append([]ast.Decl{decl}, g.f.Decls...)
}

// AppendCheckFunction builds the Is<Name>(err error) predicate for the
// generator's type spec and appends it to the file's declarations. When
// withCause is true the generated predicate first unwraps err with
// errors.Cause. If building fails, the error is returned and the file is
// left unchanged.
func (g *Generator) AppendCheckFunction(withCause bool) error {
	decls, err := appendCheckFunction(g.ts, withCause)
	if err == nil {
		g.f.Decls = append(g.f.Decls, decls...)
	}
	return err
}

// AppendErrorImplementation builds a concrete error type implementing the
// generator's interface spec and appends its declarations to the file.
// typename and message are forwarded unchanged to appendErrorImplementation;
// empty strings select its defaults. If building fails, the error is
// returned and the file is left unchanged.
func (g *Generator) AppendErrorImplementation(typename, message string) error {
	decls, err := appendErrorImplementation(g.ts, typename, message)
	if err == nil {
		g.f.Decls = append(g.f.Decls, decls...)
	}
	return err
}

// Out renders the generated file and writes it to w.
//
// The output is the "Code generated ... DO NOT EDIT." header line followed
// by the file built so far, printed with go/format and then post-processed
// by goimports (imports.Process), which adds missing imports, removes unused
// ones and groups them. Nothing is written to w until printing and import
// processing have both succeeded. Errors are wrapped with the stage that
// failed.
func (g *Generator) Out(w io.Writer) error {
	var buf bytes.Buffer
	buf.WriteString("// Code generated by \"generr\"; DO NOT EDIT.\n")

	if err := format.Node(&buf, token.NewFileSet(), g.f); err != nil {
		return errors.Wrap(err, "printing generated code")
	}

	// A non-nil *Options leaves TabIndent and TabWidth at their zero values;
	// imports.Process only applies tab indentation and width 8 when Options
	// is nil. Set them explicitly so the output uses gofmt's layout.
	opts := &imports.Options{Comments: true, TabIndent: true, TabWidth: 8}
	src, err := imports.Process("", buf.Bytes(), opts)
	if err != nil {
		return errors.Wrap(err, "processing imports of generated code")
	}

	if _, err := io.Copy(w, bytes.NewReader(src)); err != nil {
		return errors.Wrap(err, "writing generated code")
	}
	return nil
}

// appendCheckFunction builds the declaration of a predicate
// Is<Name>(err error) for the interface described by ts, where <Name> is
// ts.Name converted with strcase.ToCamel. The predicate type-asserts err to
// the interface and reports in its first (bool) result whether the
// assertion succeeded.
//
// If the interface's first method has results, the predicate returns them
// after the bool: on success they hold what that method returned, otherwise
// their zero values. For example, for an interface with method
// Code() (code int) and withCause == true, the generated body is shaped like
//
//	err = errors.Cause(err)
//	var code int
//	if e, ok := err.(<ts.Name>); ok {
//		code = e.Code()
//		return true, code
//	}
//	return false, code
//
// When withCause is true the generated code references errors.Cause, so
// the output needs github.com/pkg/errors imported.
//
// Only the first method of the interface is used, and all of its results
// must be named because the names become local variables. An error is
// returned if ts is not an interface, declares no methods, its first entry
// is not a method, or a result is unnamed.
//
// NOTE(review): result names "err", "e" or "ok" collide with identifiers the
// generated function declares itself and would produce code that does not
// compile; they are not rejected here.
func appendCheckFunction(ts *ast.TypeSpec, withCause bool) ([]ast.Decl, error) {
	it, ok := ts.Type.(*ast.InterfaceType)
	if !ok {
		return nil, errors.Errorf("type %+v is not a interface", ts.Type)
	}
	// Guard the List[0] access: an empty interface has no method to build
	// the predicate from.
	if it.Methods == nil || len(it.Methods.List) == 0 {
		return nil, errors.Errorf("interface %s has no methods", ts.Name.Name)
	}
	method := it.Methods.List[0]
	ft, ok := method.Type.(*ast.FuncType)
	// An embedded interface has no names and a non-FuncType type.
	if !ok || len(method.Names) == 0 {
		return nil, errors.Errorf("type %+v has no function", it)
	}
	var resultsList []*ast.Field
	if ft.Results != nil {
		resultsList = ft.Results.List
	}

	// Result types of the predicate: a leading bool, then one entry per
	// result name of the method (appended below).
	rtTypes := []*ast.Field{
		{
			Type: ast.NewIdent("bool"),
		},
	}

	// The asserted value is only bound to a name when the method must be
	// called to fill the extra results.
	assignStr := "_"
	var bodyStmt []ast.Stmt

	if withCause {
		// err = errors.Cause(err)
		bodyStmt = append(bodyStmt, &ast.AssignStmt{
			Tok: token.ASSIGN,
			Lhs: []ast.Expr{
				ast.NewIdent("err"),
			},
			Rhs: []ast.Expr{
				&ast.CallExpr{
					Fun: &ast.SelectorExpr{
						X:   ast.NewIdent("errors"),
						Sel: ast.NewIdent("Cause"),
					},
					Args: []ast.Expr{
						ast.NewIdent("err"),
					},
				},
			},
		})
	}

	var ifbodyStmt []ast.Stmt
	var returnExprs []ast.Expr

	if len(resultsList) != 0 {
		var list []ast.Expr
		assignStr = "e"

		for _, r := range resultsList {
			if len(r.Names) == 0 {
				return nil, errors.Errorf("method %s of %s has an unnamed result; all results must be named", method.Names[0].Name, ts.Name.Name)
			}
			// A field such as (line, col int) declares several results that
			// share one type; every name is a separate result.
			for _, n := range r.Names {
				rtTypes = append(rtTypes, &ast.Field{
					Type: r.Type,
				})
				// var <n> <type>
				bodyStmt = append(bodyStmt, &ast.DeclStmt{
					Decl: &ast.GenDecl{
						Tok: token.VAR,
						Specs: []ast.Spec{
							&ast.ValueSpec{
								Names: []*ast.Ident{n},
								Type:  r.Type,
							},
						},
					},
				})
				list = append(list, n)
				returnExprs = append(returnExprs, n)
			}
		}

		// <names...> = e.<Method>()
		ifbodyStmt = append(ifbodyStmt, &ast.AssignStmt{
			Tok: token.ASSIGN,
			Lhs: list,
			Rhs: []ast.Expr{&ast.CallExpr{
				Fun: &ast.SelectorExpr{
					X:   ast.NewIdent("e"),
					Sel: method.Names[0],
				},
			}},
		})
	}

	// return true, <names...>
	ifresults := []ast.Expr{ast.NewIdent("true")}
	ifresults = append(ifresults, returnExprs...)
	ifbodyStmt = append(ifbodyStmt, &ast.ReturnStmt{
		Results: ifresults,
	})

	// if <assignStr>, ok := err.(<Interface>); ok { ... }
	bodyStmt = append(bodyStmt, &ast.IfStmt{
		Init: &ast.AssignStmt{
			Lhs: []ast.Expr{
				ast.NewIdent(assignStr),
				ast.NewIdent("ok"),
			},
			Rhs: []ast.Expr{
				&ast.TypeAssertExpr{
					X:    ast.NewIdent("err"),
					Type: ts.Name,
				},
			},
			Tok: token.DEFINE,
		},
		Cond: ast.NewIdent("ok"),
		Body: &ast.BlockStmt{
			List: ifbodyStmt,
		},
	})

	// return false, <names...> (zero values)
	bodyresults := []ast.Expr{ast.NewIdent("false")}
	bodyresults = append(bodyresults, returnExprs...)
	bodyStmt = append(bodyStmt, &ast.ReturnStmt{
		Results: bodyresults,
	})

	name := "Is" + strcase.ToCamel(ts.Name.Name)

	decls := []ast.Decl{
		&ast.FuncDecl{
			Name: ast.NewIdent(name),
			Type: &ast.FuncType{
				Params: &ast.FieldList{
					List: []*ast.Field{
						{
							Names: []*ast.Ident{ast.NewIdent("err")},
							Type:  ast.NewIdent("error"),
						},
					},
				},
				Results: &ast.FieldList{
					List: rtTypes,
				},
			},
			Body: &ast.BlockStmt{
				List: bodyStmt,
			},
		},
	}

	return decls, nil
}

// appendErrorImplementation builds a concrete error type that implements the
// interface described by ts. It returns three declarations:
//
//   - a struct type named typename (or strcase.ToCamel(ts.Name) when
//     typename is empty) with one field per result of the interface's first
//     method, named after the result in CamelCase;
//   - that method, on a pointer receiver e, returning the fields in order;
//   - an Error() string method on a pointer receiver e.
//
// Error() is built from mes, or from a default when mes is empty:
//
//   - If the method declares results, Error() returns fmt.Sprintf(mes,
//     fields...). The default is "<ts.Name> <Field1>: %v <Field2>: %v ...".
//     A custom mes must contain matching verbs; this is not checked.
//   - Otherwise Error() returns mes verbatim via fmt.Sprint. The default is
//     ts.Name.
//
// Only the first method is used, and all of its results must be named. An
// error is returned if ts is not an interface, declares no methods, its
// first entry is not a method, or a result is unnamed.
func appendErrorImplementation(ts *ast.TypeSpec, typename, mes string) ([]ast.Decl, error) {
	it, ok := ts.Type.(*ast.InterfaceType)
	if !ok {
		return nil, errors.Errorf("type %+v is not a interface", ts.Type)
	}
	// Guard the List[0] access: an empty interface has no method to
	// implement.
	if it.Methods == nil || len(it.Methods.List) == 0 {
		return nil, errors.Errorf("interface %s has no methods", ts.Name.Name)
	}
	method := it.Methods.List[0]
	ft, ok := method.Type.(*ast.FuncType)
	// An embedded interface has no names and a non-FuncType type.
	if !ok || len(method.Names) == 0 {
		return nil, errors.Errorf("type %+v has no function", it)
	}

	name := typename
	if name == "" {
		name = strcase.ToCamel(ts.Name.Name)
	}

	var (
		fields  []*ast.Field // struct fields, one per result name
		rtTypes []*ast.Field // result types of the accessor method
		rtExprs []ast.Expr   // e.<Field> selectors, in result order
	)

	// Default Error() message; grows one " <Field>: %v" per result.
	message := ts.Name.Name
	if ft.Results != nil {
		for _, f := range ft.Results.List {
			if len(f.Names) == 0 {
				return nil, errors.Errorf("method %s of %s has an unnamed result; all results must be named", method.Names[0].Name, ts.Name.Name)
			}
			// A field such as (line, col int) declares several results that
			// share one type; every name becomes its own struct field.
			for _, n := range f.Names {
				camelName := strcase.ToCamel(n.Name)
				fields = append(fields, &ast.Field{
					Names: []*ast.Ident{ast.NewIdent(camelName)},
					Type:  f.Type,
				})
				rtTypes = append(rtTypes, &ast.Field{
					Type: f.Type,
				})
				rtExprs = append(rtExprs, &ast.SelectorExpr{
					X:   ast.NewIdent("e"),
					Sel: ast.NewIdent(camelName),
				})
				message = fmt.Sprintf("%s %s: %s", message, camelName, "%v")
			}
		}
	}
	// A caller-supplied message is honoured whether or not the method has
	// results.
	if mes == "" {
		mes = message
	}

	msgLit := &ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(mes)}
	var errorReturnExpr ast.Expr
	if ft.Results == nil {
		// fmt.Sprint("<mes>")
		errorReturnExpr = &ast.CallExpr{
			Fun: &ast.SelectorExpr{
				X:   ast.NewIdent("fmt"),
				Sel: ast.NewIdent("Sprint"),
			},
			Args: []ast.Expr{msgLit},
		}
	} else {
		// fmt.Sprintf("<mes>", e.Field1, e.Field2, ...)
		errorReturnExpr = &ast.CallExpr{
			Fun: &ast.SelectorExpr{
				X:   ast.NewIdent("fmt"),
				Sel: ast.NewIdent("Sprintf"),
			},
			Args: append([]ast.Expr{msgLit}, rtExprs...),
		}
	}

	// Both methods use a pointer receiver named e; build a fresh node per
	// method rather than sharing one.
	receiver := func() *ast.FieldList {
		return &ast.FieldList{
			List: []*ast.Field{
				{
					Names: []*ast.Ident{ast.NewIdent("e")},
					Type:  &ast.StarExpr{X: ast.NewIdent(name)},
				},
			},
		}
	}

	var decls []ast.Decl

	// type <name> struct { <fields> }
	decls = append(decls, &ast.GenDecl{
		Tok: token.TYPE,
		Specs: []ast.Spec{
			&ast.TypeSpec{
				Name: ast.NewIdent(name),
				Type: &ast.StructType{
					Fields: &ast.FieldList{
						List: fields,
					},
				},
			},
		},
	})

	// func (e *<name>) <Method>() (<types>) { return e.Field1, ... }
	decls = append(decls, &ast.FuncDecl{
		Recv: receiver(),
		Name: method.Names[0],
		Type: &ast.FuncType{
			Params: &ast.FieldList{},
			Results: &ast.FieldList{
				List: rtTypes,
			},
		},
		Body: &ast.BlockStmt{
			List: []ast.Stmt{
				&ast.ReturnStmt{
					Results: rtExprs,
				},
			},
		},
	})

	// func (e *<name>) Error() string { return <errorReturnExpr> }
	decls = append(decls, &ast.FuncDecl{
		Recv: receiver(),
		Name: ast.NewIdent("Error"),
		Type: &ast.FuncType{
			Params: &ast.FieldList{},
			Results: &ast.FieldList{
				List: []*ast.Field{
					{
						Type: ast.NewIdent("string"),
					},
				},
			},
		},
		Body: &ast.BlockStmt{
			List: []ast.Stmt{
				&ast.ReturnStmt{
					Results: []ast.Expr{
						errorReturnExpr,
					},
				},
			},
		},
	})

	return decls, nil
}