synapsecns/sanguine

View on GitHub
tools/modulecopier/internal/copy.go

Summary

Maintainability
A
2 hrs
Test Coverage
package internal

import (
    "bytes"
    "fmt"
    "github.com/markbates/pkger"
    "github.com/thoas/go-funk"
    "go/ast"
    "go/format"
    "go/parser"
    "go/printer"
    "go/token"
    "golang.org/x/tools/go/ast/astutil"
    "io"
    "io/fs"
    "os"
    "path"
    "path/filepath"
    "strings"
)

// CopyModule copies the go files of a module path to a destination.
//
// toCopy is resolved with pkger.Info and then walked with pkger.Walk. Every
// visited file whose base name appears in the resolved package's GoFiles or
// TestGoFiles is passed to copyGoFile, which rewrites it for packageName and
// writes it into dest. All other files are skipped.
//
// An error is returned if toCopy cannot be resolved, if the walk reports an
// error, or if any individual file fails to copy.
func CopyModule(toCopy, dest, packageName string) error {
	// walk through the dir, see: https://github.com/markbates/pkger/blob/09e9684b656b/examples/app/main.go#L29
	pkgInfo, err := pkger.Info(toCopy)
	if err != nil {
		return fmt.Errorf("could not resolve %s: %w", toCopy, err)
	}

	// Collect the go files to copy into a fresh slice: appending directly to
	// pkgInfo.GoFiles could write into a backing array we do not own.
	goFiles := make([]string, 0, len(pkgInfo.GoFiles)+len(pkgInfo.TestGoFiles))
	goFiles = append(goFiles, pkgInfo.GoFiles...)
	goFiles = append(goFiles, pkgInfo.TestGoFiles...)

	err = pkger.Walk(toCopy, func(filePath string, fileInfo fs.FileInfo, walkErr error) error {
		if walkErr != nil {
			return fmt.Errorf("error while walking: %w", walkErr)
		}
		// if it's not one of the package's go files, skip it
		if !funk.ContainsString(goFiles, fileInfo.Name()) {
			return nil
		}

		return copyGoFile(filePath, packageName, dest, fileInfo)
	})
	if err != nil {
		return fmt.Errorf("error while copying: %w", err)
	}

	return nil
}

// copyGoFile copies a single go file into dest.
//
// The contents are produced by getUpdatedFileContents (generation header
// added, package renamed to packageName) and written to a file in dest whose
// name is derived from info.Name() via getFileName. An existing file with the
// same name is truncated.
func copyGoFile(filePath, packageName, dest string, info fs.FileInfo) error {
	fileContents, err := getUpdatedFileContents(filePath, packageName)
	if err != nil {
		return fmt.Errorf("could not get updated file contents: %w", err)
	}

	newFile := filepath.Join(dest, getFileName(info.Name()))
	//nolint: gosec
	f, err := os.Create(newFile)
	if err != nil {
		return fmt.Errorf("could not open file %s: %w", newFile, err)
	}

	// write the contents to the file
	if _, err = f.Write(fileContents); err != nil {
		// The write error is the one worth reporting; close only to release
		// the descriptor.
		_ = f.Close()
		return fmt.Errorf("could not write to file: %w", err)
	}

	if err = f.Close(); err != nil {
		return fmt.Errorf("could not close file: %w", err)
	}

	return nil
}

// CopyFile copies a single go file. This will not bring dependencies.
//
// fileToCopy is a slash-separated module path whose last element is the .go
// file to copy (everything before it is handed to pkger.Walk). The file is
// rewritten for packageName and written into dest by copyGoFile.
//
// An error is returned if the last element is not a .go file, if the walk or
// the copy fails, or if no file with the requested name was encountered.
func CopyFile(fileToCopy, dest, packageName string) error {
	// first things first, pkger operates on go modules, so we need to trim
	modulePath := path.Dir(fileToCopy)
	fileName := path.Base(fileToCopy)

	// make sure the last element is a file
	if ext := filepath.Ext(fileName); ext != ".go" {
		return fmt.Errorf("must specify a .go file after module, got %s", ext)
	}

	// found records whether the target was seen at all, so that a wrong file
	// name does not silently succeed without copying anything.
	found := false

	err := pkger.Walk(modulePath, func(filePath string, fileInfo fs.FileInfo, walkErr error) error {
		if walkErr != nil {
			return fmt.Errorf("error while walking: %w", walkErr)
		}

		// only copy the target file
		if fileInfo.Name() != fileName {
			return nil
		}

		found = true
		return copyGoFile(filePath, packageName, dest, fileInfo)
	})
	if err != nil {
		return fmt.Errorf("error while copying: %w", err)
	}

	if !found {
		return fmt.Errorf("could not find file %s in %s", fileName, modulePath)
	}

	return nil
}

// getFileName gets the new file name. "_gen" is added before the extension in
// the case of non tests (foo.go -> foo_gen.go) and before "_test" plus the
// extension in the case of tests (foo_test.go -> foo_gen_test.go).
//
// A file only counts as a test when its name, minus the extension, ends in
// "_test". A "_test" elsewhere in the name (foo_test_helper.go,
// foo_testing.go) is left untouched so a regular file is never renamed into a
// test file.
func getFileName(originalName string) string {
	const (
		genSuffix  = "_gen"
		testSuffix = "_test"
	)

	ext := filepath.Ext(originalName)
	base := strings.TrimSuffix(originalName, ext)

	// for tests, move the _test marker behind the _gen marker so the result
	// still ends in _test.<ext>
	if strings.HasSuffix(base, testSuffix) {
		return strings.TrimSuffix(base, testSuffix) + genSuffix + testSuffix + ext
	}

	return base + genSuffix + ext
}

// getUpdatedFileContents adds the generation header and rewrites the package name.
//
// The file at path is opened through pkger and read fully. The header from
// makeGeneratedHeader is prepended, the result is parsed, the package clause
// is renamed to newPackageName, and the tree is printed and gofmt-formatted.
// The formatted source is returned; on any error the returned slice is nil.
func getUpdatedFileContents(path, newPackageName string) (fileContents []byte, err error) {
	file, err := pkger.Open(path)
	if err != nil {
		return nil, fmt.Errorf("could not open file at %s: %w", path, err)
	}
	defer func() {
		// the file is only read, so a close error carries no useful information
		_ = file.Close()
	}()

	fileContents, err = io.ReadAll(file)
	if err != nil {
		return nil, fmt.Errorf("could not read file %s: %w", path, err)
	}

	// prepend the header to the file
	fileContents = append([]byte(makeGeneratedHeader(path)+"\n\n"), fileContents...)

	// rename the package by modifying the ast
	fset := token.NewFileSet()

	fileAst, err := parser.ParseFile(fset, filepath.Base(path), fileContents, parser.ParseComments)
	if err != nil {
		return nil, fmt.Errorf("could not parse ast. This could indicate an invalid source file: %w", err)
	}

	// Only the package clause must be renamed. Apply visits children in struct
	// field order, so the first identifier the post-order callback sees is the
	// file's package name (only the doc comment group precedes it, and it
	// holds no identifiers). Returning false right after the replacement
	// stops the traversal, leaving every other identifier untouched.
	newAst := astutil.Apply(fileAst, nil, func(cursor *astutil.Cursor) bool {
		ident, ok := cursor.Node().(*ast.Ident)
		if !ok {
			return true
		}

		cursor.Replace(&ast.Ident{
			NamePos: ident.NamePos,
			Name:    newPackageName,
			Obj:     ident.Obj,
		})
		return false
	})

	var fileBuffer bytes.Buffer
	if err = printer.Fprint(&fileBuffer, fset, newAst); err != nil {
		return nil, fmt.Errorf("could not write resulting ast: %w", err)
	}

	// TODO: use golangci-lint
	formatted, err := format.Source(fileBuffer.Bytes())
	if err != nil {
		return nil, fmt.Errorf("could not format: %w", err)
	}

	return formatted, nil
}

// makeGeneratedHeader makes the code generation header: a single line comment
// that names origin as the source the file was copied from.
// note: this must conform to https://github.com/golangci/golangci-lint/blob/1fb67fe448da8a3fb525ecef28decceb23b42d7a/pkg/result/processors/autogenerated_exclude.go#L76
// to bypass linters.
func makeGeneratedHeader(origin string) string {
	return fmt.Sprintf("// Code copied from %s for testing by synapse modulecopier DO NOT EDIT.", origin)
}