sunboyy/repogen

View on GitHub
main.go

Summary

Maintainability
A
2 hrs
Test Coverage
package main

import (
    "errors"
    "flag"
    "fmt"
    "log"
    "os"
    "path/filepath"

    "github.com/sunboyy/repogen/internal/generator"
    "golang.org/x/tools/go/packages"
)

// Command-line constants shown to the user.
const (
    // usageText is the banner printed before the list of supported flags.
    usageText = `repogen generates MongoDB repository implementation from repository interface

  Find more information at: https://github.com/sunboyy/repogen

Supported options:`

    // version indicates the version of repogen.
    version = "v0.4-next"
)

// main parses the command-line flags, generates the repository
// implementation for the requested model and interface, and writes the
// generated code to the -dest file, or to standard output when -dest is empty.
//
// With -version it only prints the version. Missing -model or -repo prints the
// usage and exits through log.Fatal. Generation and I/O failures panic.
func main() {
    flag.Usage = printUsage

    versionPtr := flag.Bool("version", false, "print version of repogen")
    pkgPtr := flag.String(
        "pkg",
        ".",
        "package directory to scan for model struct and repository interface",
    )
    destPtr := flag.String("dest", "", "destination file")
    modelPtr := flag.String("model", "", "model struct name")
    repoPtr := flag.String("repo", "", "repository interface name")
    modelPkgPtr := flag.String(
        "model-pkg",
        "",
        "package directory to scan for model struct. If not set, will fallback to -pkg.",
    )
    destPkgPtr := flag.String(
        "dest-pkg",
        "",
        "destination package path. If not set, will consider as in the same package as repository interface.",
    )
    flag.Parse()

    if *versionPtr {
        printVersion()
        return
    }

    if *modelPtr == "" {
        printUsage()
        log.Fatal("-model flag required")
    }
    if *repoPtr == "" {
        printUsage()
        log.Fatal("-repo flag required")
    }

    request := GenerationRequest{
        Pkg:       *pkgPtr,
        ModelName: *modelPtr,
        RepoName:  *repoPtr,
        Dest:      *destPtr,
        ModelPkg:  *modelPkgPtr,
        DestPkg:   *destPkgPtr,
    }
    code, err := generateFromRequest(request)
    if err != nil {
        panic(err)
    }

    dest := os.Stdout
    if *destPtr != "" {
        // Create the parent directories so -dest may point into a new folder.
        if err := os.MkdirAll(filepath.Dir(*destPtr), os.ModePerm); err != nil {
            panic(err)
        }
        file, err := os.Create(*destPtr)
        if err != nil {
            panic(err)
        }
        dest = file
    }

    _, err = dest.WriteString(code)
    if dest != os.Stdout {
        // Close explicitly instead of deferring so that a failure to flush the
        // generated file is reported; the write error, if any, takes priority.
        if closeErr := dest.Close(); err == nil {
            err = closeErr
        }
    }
    if err != nil {
        panic(err)
    }
}

// GenerationRequest carries the command-line options that describe one
// repository generation run.
type GenerationRequest struct {
    // Pkg is the package pattern scanned for the repository interface. It is
    // also the fallback for ModelPkg and DestPkg when those are empty.
    Pkg       string
    // ModelName is the model struct name.
    ModelName string
    // RepoName is the repository interface name.
    RepoName  string
    // Dest is the destination file given on the command line.
    Dest      string
    // ModelPkg is the package pattern scanned for the model struct; empty
    // means Pkg is used.
    ModelPkg  string
    // DestPkg is the destination package pattern; empty means Pkg is used.
    DestPkg   string
}

// printUsage prints the usage banner followed by the defaults of every
// defined flag. Both parts are written to the flag package's configured
// output (standard error unless overridden) so that the help text is not
// split between standard output and standard error.
func printUsage() {
    fmt.Fprintln(flag.CommandLine.Output(), usageText)
    flag.PrintDefaults()
}

// printVersion writes the repogen version, followed by a newline, to standard
// output.
func printVersion() {
    fmt.Fprintln(os.Stdout, version)
}

// Sentinel errors reported while resolving a package pattern to a single
// package.
var (
    // errNoPackageFound is returned when a pattern matches no package.
    errNoPackageFound = errors.New("no package found")
    // errUnsupportMultiplePkgs is returned when a pattern matches more than
    // one package. The message is kept on a single line so that it does not
    // carry a newline and source indentation into logs.
    errUnsupportMultiplePkgs = errors.New(
        "multiple packages are not supported, " +
            "please specify the package ID or directory path that only contains one package",
    )
    // errMissingPackageName is returned when the matched package has an empty
    // name.
    errMissingPackageName = errors.New("missing package name")
)

// generateFromRequest resolves the interface, model and destination packages
// named by request, loads them with name and type information, and passes
// their types together with the model and repository names to
// generator.GenerateRepositoryImpl, returning its result.
//
// An empty ModelPkg or DestPkg falls back to Pkg. An error is returned when a
// pattern cannot be resolved to a single package, when loading fails, or when
// one of the resolved packages is absent from the loaded set.
func generateFromRequest(request GenerationRequest) (string, error) {
    cfg := packages.Config{
        Mode: packages.NeedName | packages.NeedTypes,
    }
    if request.ModelPkg == "" {
        request.ModelPkg = request.Pkg
    }
    if request.DestPkg == "" {
        request.DestPkg = request.Pkg
    }
    intfPkgID, err := getPkgID(request.Pkg)
    if err != nil {
        return "", err
    }
    modelPkgID, err := getPkgID(request.ModelPkg)
    if err != nil {
        return "", err
    }
    destPkgID, err := getPkgID(request.DestPkg)
    if err != nil {
        return "", err
    }
    pkgs, err := packages.Load(&cfg, intfPkgID, modelPkgID, destPkgID)
    if err != nil {
        return "", err
    }
    pkgM := packagesToMap(pkgs)
    // Loading by ID is not guaranteed to hand back a package under the same
    // ID; report that as an error rather than dereferencing a nil entry.
    for _, id := range []string{modelPkgID, intfPkgID, destPkgID} {
        if pkgM[id] == nil {
            return "", fmt.Errorf("%w: %s", errNoPackageFound, id)
        }
    }
    return generator.GenerateRepositoryImpl(
        pkgM[modelPkgID].Types,
        pkgM[intfPkgID].Types,
        pkgM[destPkgID].Types,
        request.ModelName,
        request.RepoName,
    )
}

// getPkgID loads the packages matching pattern and returns the ID of the one
// package it matches. It fails when loading fails, when the pattern matches
// zero or several packages, or when the matched package has no name.
func getPkgID(pattern string) (string, error) {
    loaded, err := packages.Load(nil, pattern)
    if err != nil {
        return "", err
    }

    switch count := len(loaded); {
    case count < 1:
        return "", errNoPackageFound
    case count > 1:
        return "", errUnsupportMultiplePkgs
    }

    only := loaded[0]
    // A package without any Go file has an empty name; reject it up front
    // instead of failing later on the missing field.
    if only.Name == "" {
        return "", fmt.Errorf("%w on %s", errMissingPackageName, pattern)
    }
    return only.ID, nil
}

// packagesToMap indexes pkgs by package ID. When several packages share an
// ID, the one appearing last in pkgs is kept.
func packagesToMap(pkgs []*packages.Package) map[string]*packages.Package {
    byID := make(map[string]*packages.Package, len(pkgs))
    for i := range pkgs {
        byID[pkgs[i].ID] = pkgs[i]
    }
    return byID
}