generator.go
package generr
import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/token"
"io"
"strconv"
"github.com/iancoleman/strcase"
"github.com/pkg/errors"
"golang.org/x/tools/imports"
)
// Generator builds a Go source file as an in-memory ast.File. The Append*
// methods add pieces to that file and Out renders it.
type Generator struct {
// pkgName is the name AppendPackage writes into the file's package clause.
pkgName string
// f is the file under construction; Out formats it.
f *ast.File
// ts is the type spec given to NewGenerator; AppendCheckFunction and
// AppendErrorImplementation derive their declarations from it.
ts *ast.TypeSpec
}
// NewGenerator returns a Generator that targets the package named pkgName and
// derives its declarations from ts. The generator starts out with an empty
// ast.File.
func NewGenerator(pkgName string, ts *ast.TypeSpec) *Generator {
	gen := new(Generator)
	gen.pkgName = pkgName
	gen.f = new(ast.File)
	gen.ts = ts
	return gen
}
// Generate prepares the output file. It currently only sets the package
// clause through AppendPackage and always returns nil.
func (g *Generator) Generate() error {
g.AppendPackage()
return nil
}
// AppendPackage sets the package clause of the file under construction to the
// package name the Generator was created with.
func (g *Generator) AppendPackage() {
	g.f.Name = ast.NewIdent(g.pkgName)
}
// AppendPkgErrorImportSpec records an import of github.com/pkg/errors on the
// file under construction. Any previously recorded import specs are replaced.
//
// The path is stored as a proper string literal: go/ast expects
// BasicLit.Value to hold the literal's source text (quotes included) together
// with Kind token.STRING.
//
// NOTE(review): go/printer renders a file from its Decls, so populating
// Imports alone presumably does not emit an import line; the import seen in
// the output appears to come from imports.Process in Out. Confirm before
// relying on this method for the rendered import.
func (g *Generator) AppendPkgErrorImportSpec() {
	g.f.Imports = []*ast.ImportSpec{
		{
			Path: &ast.BasicLit{
				Kind:  token.STRING,
				Value: strconv.Quote("github.com/pkg/errors"),
			},
		},
	}
}
// AppendCheckFunction builds the check-function declarations for the
// generator's type spec and appends them to the file under construction.
// withCause is forwarded unchanged to appendCheckFunction. When building
// fails the file is left untouched and the error is returned.
func (g *Generator) AppendCheckFunction(withCause bool) error {
	decls, err := appendCheckFunction(g.ts, withCause)
	if err == nil {
		g.f.Decls = append(g.f.Decls, decls...)
	}
	return err
}
// AppendErrorImplementation builds the struct type and method declarations
// implementing the generator's interface and appends them to the file under
// construction. typename and message are forwarded unchanged to
// appendErrorImplementation. When building fails the file is left untouched
// and the error is returned.
func (g *Generator) AppendErrorImplementation(typename, message string) error {
	decls, err := appendErrorImplementation(g.ts, typename, message)
	if err == nil {
		g.f.Decls = append(g.f.Decls, decls...)
	}
	return err
}
// Out renders the file under construction and writes the result to w.
//
// The output starts with the "Code generated ... DO NOT EDIT." header, is
// followed by the formatted AST, and is passed through imports.Process (with
// comments preserved) before being written. The first error from formatting,
// import processing or writing is returned as is.
func (g *Generator) Out(w io.Writer) error {
	src := new(bytes.Buffer)
	// Writes to a bytes.Buffer do not fail, so the header needs no error check.
	src.WriteString("// Code generated by \"generr\"; DO NOT EDIT.\n")

	if err := format.Node(src, token.NewFileSet(), g.f); err != nil {
		return err
	}

	opts := imports.Options{Comments: true}
	processed, err := imports.Process("", src.Bytes(), &opts)
	if err != nil {
		return err
	}

	_, err = io.Copy(w, bytes.NewReader(processed))
	return err
}
// appendCheckFunction builds the declaration of an "Is<TypeName>" helper for
// the interface described by ts. Only the first method of the interface is
// inspected. The generated function has the shape
//
//	func IsFoo(err error) (bool, A, B) {
//		err = errors.Cause(err) // only when withCause is true
//		var a A
//		var b B
//		if e, ok := err.(Foo); ok {
//			a, b = e.Method()
//			return true, a, b
//		}
//		return false, a, b
//	}
//
// where a and b are the named results of the method. When the method has no
// results the helper returns only the bool and the type assertion binds "_".
//
// An error is returned when ts is not an interface type, when the interface
// has no methods, when its first element is not a method, or when one of the
// method's results is unnamed (the result names are reused as local variable
// names in the generated code).
func appendCheckFunction(ts *ast.TypeSpec, withCause bool) ([]ast.Decl, error) {
	it, ok := ts.Type.(*ast.InterfaceType)
	if !ok {
		return nil, errors.Errorf("type %+v is not a interface", ts.Type)
	}
	// An empty interface would otherwise panic on the List[0] access below.
	if it.Methods == nil || len(it.Methods.List) == 0 {
		return nil, errors.Errorf("type %+v has no function", it)
	}
	method := it.Methods.List[0]
	ft, ok := method.Type.(*ast.FuncType)
	if !ok {
		return nil, errors.Errorf("type %+v has no function", it)
	}

	var resultsList []*ast.Field
	if ft.Results != nil {
		resultsList = ft.Results.List
	}
	for _, r := range resultsList {
		if len(r.Names) == 0 {
			return nil, errors.Errorf("method %s of type %s has an unnamed result; results must be named", method.Names[0].Name, ts.Name.Name)
		}
	}

	// The first result of the generated function is always the bool verdict.
	rtTypes := []*ast.Field{
		{Type: ast.NewIdent("bool")},
	}

	var bodyStmt []ast.Stmt
	if withCause {
		// err = errors.Cause(err)
		bodyStmt = append(bodyStmt, &ast.AssignStmt{
			Tok: token.ASSIGN,
			Lhs: []ast.Expr{ast.NewIdent("err")},
			Rhs: []ast.Expr{
				&ast.CallExpr{
					Fun: &ast.SelectorExpr{
						X:   ast.NewIdent("errors"),
						Sel: ast.NewIdent("Cause"),
					},
					Args: []ast.Expr{ast.NewIdent("err")},
				},
			},
		})
	}

	assignStr := "_"
	var ifbodyStmt []ast.Stmt
	var returnExprs []ast.Expr
	if len(resultsList) != 0 {
		assignStr = "e"
		for _, r := range resultsList {
			// A single field may declare several names ("a, b int"); each one
			// needs its own variable and its own slot in the result list.
			for _, n := range r.Names {
				rtTypes = append(rtTypes, &ast.Field{Type: r.Type})
				// var <n> <type>
				bodyStmt = append(bodyStmt, &ast.DeclStmt{
					Decl: &ast.GenDecl{
						Tok: token.VAR,
						Specs: []ast.Spec{
							&ast.ValueSpec{
								Names: []*ast.Ident{n},
								Type:  r.Type,
							},
						},
					},
				})
				returnExprs = append(returnExprs, n)
			}
		}
		// <names...> = e.<Method>()
		ifbodyStmt = append(ifbodyStmt, &ast.AssignStmt{
			Tok: token.ASSIGN,
			Lhs: append([]ast.Expr(nil), returnExprs...),
			Rhs: []ast.Expr{&ast.CallExpr{
				Fun: &ast.SelectorExpr{
					X:   ast.NewIdent("e"),
					Sel: method.Names[0],
				},
			}},
		})
	}

	// return true, <names...>
	ifresults := []ast.Expr{ast.NewIdent("true")}
	ifresults = append(ifresults, returnExprs...)
	ifbodyStmt = append(ifbodyStmt, &ast.ReturnStmt{Results: ifresults})

	// if <e|_>, ok := err.(<Type>); ok { ... }
	bodyStmt = append(bodyStmt, &ast.IfStmt{
		Init: &ast.AssignStmt{
			Lhs: []ast.Expr{
				ast.NewIdent(assignStr),
				ast.NewIdent("ok"),
			},
			Tok: token.DEFINE,
			Rhs: []ast.Expr{
				&ast.TypeAssertExpr{
					X:    ast.NewIdent("err"),
					Type: ts.Name,
				},
			},
		},
		Cond: ast.NewIdent("ok"),
		Body: &ast.BlockStmt{List: ifbodyStmt},
	})

	// return false, <names...>
	bodyresults := []ast.Expr{ast.NewIdent("false")}
	bodyresults = append(bodyresults, returnExprs...)
	bodyStmt = append(bodyStmt, &ast.ReturnStmt{Results: bodyresults})

	name := "Is" + strcase.ToCamel(ts.Name.Name)
	decls := []ast.Decl{
		&ast.FuncDecl{
			Name: ast.NewIdent(name),
			Type: &ast.FuncType{
				Params: &ast.FieldList{
					List: []*ast.Field{
						{
							Names: []*ast.Ident{ast.NewIdent("err")},
							Type:  ast.NewIdent("error"),
						},
					},
				},
				Results: &ast.FieldList{List: rtTypes},
			},
			Body: &ast.BlockStmt{List: bodyStmt},
		},
	}
	return decls, nil
}
// appendErrorImplementation builds the declarations of a concrete error type
// implementing the interface described by ts. Only the first method of the
// interface is inspected. Three declarations are returned, in this order:
//
//   - a struct type with one field per named result of the method (the field
//     name is the camel-cased result name),
//   - the interface method on the pointer receiver, returning those fields,
//   - an Error() string method on the pointer receiver.
//
// typename names the struct; when empty, the camel-cased interface name is
// used. mes is the error message: with results it is used as the fmt.Sprintf
// format applied to the fields, without results it is emitted through
// fmt.Sprint. When mes is empty the message defaults to the interface name
// followed by " <Field>: %v" for every field.
//
// An error is returned when ts is not an interface type, when the interface
// has no methods, when its first element is not a method, or when one of the
// method's results is unnamed (the result names become the field names).
func appendErrorImplementation(ts *ast.TypeSpec, typename, mes string) ([]ast.Decl, error) {
	it, ok := ts.Type.(*ast.InterfaceType)
	if !ok {
		return nil, errors.Errorf("type %+v is not a interface", ts.Type)
	}
	// An empty interface would otherwise panic on the List[0] access below.
	if it.Methods == nil || len(it.Methods.List) == 0 {
		return nil, errors.Errorf("type %+v has no function", it)
	}
	method := it.Methods.List[0]
	ft, ok := method.Type.(*ast.FuncType)
	if !ok {
		return nil, errors.Errorf("type %+v has no function", it)
	}

	name := typename
	if name == "" {
		name = strcase.ToCamel(ts.Name.Name)
	}

	var (
		fields  []*ast.Field
		rtTypes []*ast.Field
		rtExprs []ast.Expr
	)
	message := ts.Name.Name
	if ft.Results != nil {
		for _, f := range ft.Results.List {
			if len(f.Names) == 0 {
				return nil, errors.Errorf("method %s of type %s has an unnamed result; results must be named", method.Names[0].Name, ts.Name.Name)
			}
			// A single field may declare several names ("a, b int"); each one
			// becomes its own struct field and its own returned value.
			for _, n := range f.Names {
				camelName := strcase.ToCamel(n.Name)
				fields = append(fields, &ast.Field{
					Names: []*ast.Ident{ast.NewIdent(camelName)},
					Type:  f.Type,
				})
				rtTypes = append(rtTypes, &ast.Field{Type: f.Type})
				rtExprs = append(rtExprs, &ast.SelectorExpr{
					X:   ast.NewIdent("e"),
					Sel: ast.NewIdent(camelName),
				})
				message = fmt.Sprintf("%s %s: %s", message, camelName, "%v")
			}
		}
	}
	// A caller-supplied message wins whether or not the method has results.
	if mes == "" {
		mes = message
	}

	// Without fields there is nothing to format, so the message is emitted
	// verbatim through fmt.Sprint; otherwise it is the Sprintf format string.
	printFn := "Sprint"
	args := []ast.Expr{
		&ast.BasicLit{Kind: token.STRING, Value: strconv.Quote(mes)},
	}
	if len(rtExprs) != 0 {
		printFn = "Sprintf"
		args = append(args, rtExprs...)
	}
	errorReturnExpr := &ast.CallExpr{
		Fun: &ast.SelectorExpr{
			X:   ast.NewIdent("fmt"),
			Sel: ast.NewIdent(printFn),
		},
		Args: args,
	}

	// recv builds a fresh "(e *<name>)" receiver list for each method so the
	// two declarations do not share AST nodes.
	recv := func() *ast.FieldList {
		return &ast.FieldList{
			List: []*ast.Field{
				{
					Names: []*ast.Ident{ast.NewIdent("e")},
					Type:  &ast.StarExpr{X: ast.NewIdent(name)},
				},
			},
		}
	}

	decls := []ast.Decl{
		// type <name> struct { <fields> }
		&ast.GenDecl{
			Tok: token.TYPE,
			Specs: []ast.Spec{
				&ast.TypeSpec{
					Name: ast.NewIdent(name),
					Type: &ast.StructType{
						Fields: &ast.FieldList{List: fields},
					},
				},
			},
		},
		// func (e *<name>) <Method>() (<types>) { return e.<fields> }
		&ast.FuncDecl{
			Recv: recv(),
			Name: method.Names[0],
			Type: &ast.FuncType{
				Params:  &ast.FieldList{},
				Results: &ast.FieldList{List: rtTypes},
			},
			Body: &ast.BlockStmt{
				List: []ast.Stmt{
					&ast.ReturnStmt{Results: rtExprs},
				},
			},
		},
		// func (e *<name>) Error() string { return fmt.Sprint[f](...) }
		&ast.FuncDecl{
			Recv: recv(),
			Name: ast.NewIdent("Error"),
			Type: &ast.FuncType{
				Params: &ast.FieldList{},
				Results: &ast.FieldList{
					List: []*ast.Field{
						{Type: ast.NewIdent("string")},
					},
				},
			},
			Body: &ast.BlockStmt{
				List: []ast.Stmt{
					&ast.ReturnStmt{
						Results: []ast.Expr{errorReturnExpr},
					},
				},
			},
		},
	}
	return decls, nil
}