vorteil/vorteil

View on GitHub
pkg/cli/main.go

Summary

Maintainability
A
2 hrs
Test Coverage
F
0%
package cli

/**
 * SPDX-License-Identifier: Apache-2.0
 * Copyright 2020 vorteil.io Pty Ltd
 */

import (
    "errors"
    "fmt"
    "io/ioutil"
    "net/http"
    "net/url"
    "os"
    "path/filepath"
    "runtime"
    "strings"

    "github.com/sisatech/tablewriter"
    "github.com/spf13/cobra"
    "github.com/spf13/pflag"
    "github.com/vorteil/vorteil/pkg/elog"
    "github.com/vorteil/vorteil/pkg/vcfg"
    "github.com/vorteil/vorteil/pkg/vdisk"
    "github.com/vorteil/vorteil/pkg/vio"
    "github.com/vorteil/vorteil/pkg/vpkg"
    "github.com/vorteil/vorteil/pkg/vproj"
)

// Build metadata for the binary. The values below are placeholder defaults.
// NOTE(review): presumably overridden at build time (e.g. via -ldflags) —
// confirm against the build scripts.
var (
    release = "0.0.0"
    commit  = ""
    date    = "Thu, 01 Jan 1970 00:00:00 +0000"
)

// Each executed command may record an error message and a status code.
// Both variables are written by SetError.
var errorStatusCode int
var errorStatusMessage error

// SetError records err and code in the package-level errorStatusMessage and
// errorStatusCode variables, replacing whatever was stored before.
func SetError(err error, code int) {
    errorStatusCode, errorStatusMessage = code, err
}

// isEmptyDir reports whether path can be listed as a directory and contains
// no entries. Any error while listing (missing path, not a directory,
// permission problem) yields false.
func isEmptyDir(path string) bool {
    entries, err := ioutil.ReadDir(path)
    return err == nil && len(entries) == 0
}

// isNotExist reports whether os.Stat on path fails with a "does not exist"
// error. Successful stats and any other kind of failure return false.
func isNotExist(path string) bool {
    if _, err := os.Stat(path); err != nil {
        return os.IsNotExist(err)
    }
    return false
}

// checkValidNewDirOutput accepts path as an output directory when it is an
// existing empty directory. In every other case the decision is delegated to
// checkValidNewFileOutput with the same arguments.
func checkValidNewDirOutput(path string, force bool, dest, flag string) error {
    if isEmptyDir(path) {
        return nil
    }
    return checkValidNewFileOutput(path, force, dest, flag)
}

// checkValidNewFileOutput verifies that path can be used as a fresh output.
//
// If path does not exist, nil is returned. If it exists and force is false,
// an error is returned that names dest and suggests flag for overwriting.
// If it exists and force is true, path is removed recursively and its parent
// directory is (re)created.
func checkValidNewFileOutput(path string, force bool, dest, flag string) error {
    if isNotExist(path) {
        return nil
    }

    if !force {
        return fmt.Errorf("%s '%s' already exists (you can use '%s' to force an overwrite)", dest, path, flag)
    }

    if err := os.RemoveAll(path); err != nil {
        return fmt.Errorf("failed to delete existing %s '%s': %w", dest, path, err)
    }

    if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil {
        return fmt.Errorf("failed to create parent directory for %s '%s': %w", dest, path, err)
    }

    return nil
}

// parseImageFormat converts s into a vdisk.Format via vdisk.ParseFormat.
// On failure the returned error wraps the parse error and lists the strings
// from vdisk.AllFormatStrings as suggestions.
func parseImageFormat(s string) (vdisk.Format, error) {
    format, err := vdisk.ParseFormat(s)
    if err == nil {
        return format, nil
    }

    options := strings.Join(vdisk.AllFormatStrings(), ", ")
    return format, fmt.Errorf("%w -- try one of these: %s", err, options)
}

// sourceType classifies what a package source argument refers to.
type sourceType string

// Every constant carries the sourceType type explicitly. An explicit
// `= "..."` expression stops Go's implicit repetition of the previous type,
// so without the annotation these would be untyped string constants.
const (
    sourceURL     sourceType = "URL"
    sourceFile    sourceType = "File"
    sourceDir     sourceType = "Dir"
    sourceINVALID sourceType = "INVALID"
)

// readSourcePath splits a "path[:target]" argument into its path and target.
//
// src is first made absolute. The text after the last ':' is the target and
// the text before the first ':' is the path. When there is no ':' or nothing
// follows the last one, the absolute path is returned unchanged with an empty
// target. On Windows the first ':' belongs to the drive letter, so the path
// is rebuilt from the first two pieces, and a value with exactly one ':' has
// no target.
func readSourcePath(src string) (string, string, error) {
    abs, err := filepath.Abs(src)
    if err != nil {
        return "", "", err
    }

    parts := strings.Split(abs, ":")
    last := parts[len(parts)-1]

    // No separator at all, or nothing after the final one: no target given.
    if len(parts) == 1 || last == "" {
        return abs, "", nil
    }

    if runtime.GOOS != "windows" {
        return parts[0], last, nil
    }

    // Windows: re-attach the drive letter that the split tore off.
    drivePath := fmt.Sprintf("%s:%s", parts[0], parts[1])
    if len(parts) == 2 {
        // The only colon was the drive separator.
        return drivePath, "", nil
    }
    return drivePath, last, nil
}

// getSourceType classifies src as a URL, a file or a directory.
//
// src counts as a URL when it parses as a request URI and has a non-empty
// scheme, host and path. Otherwise it is split with readSourcePath and the
// path part is stat'ed: a directory yields sourceDir, a non-directory yields
// sourceFile (or an error if a target was supplied). Anything else returns
// sourceINVALID together with the stat error.
func getSourceType(src string) (sourceType, error) {
    // URL check comes first so that remote sources never touch the filesystem.
    if _, reqErr := url.ParseRequestURI(src); reqErr == nil {
        u, parseErr := url.Parse(src)
        if parseErr == nil && u.Scheme != "" && u.Host != "" && u.Path != "" {
            return sourceURL, nil
        }
    }

    path, target, err := readSourcePath(src)
    if err != nil {
        return sourceINVALID, err
    }

    info, err := os.Stat(path)
    if os.IsNotExist(err) || info == nil {
        // Source is unknown and thus invalid.
        return sourceINVALID, err
    }

    if info.IsDir() {
        return sourceDir, nil
    }

    if target != "" {
        return sourceINVALID, errors.New("Targetable runs are unable to be used on packages")
    }
    return sourceFile, nil
}

// checkIfNewVRepo sends a GET request with the "Info-Request: True" header to
// "<scheme>://<host>/info", where scheme and host are taken from src.
// It returns the value of the "Vorteil-Repository" response header when the
// server answers 200 OK, and an error for any other status or any failure to
// parse src, build the request or perform it.
func checkIfNewVRepo(src string) (string, error) {
    parsed, err := url.Parse(src)
    if err != nil {
        return "", err
    }

    infoURL := fmt.Sprintf("%s://%s/info", parsed.Scheme, parsed.Host)
    req, err := http.NewRequest("GET", infoURL, nil)
    if err != nil {
        return "", err
    }
    req.Header.Set("Info-Request", "True")

    client := &http.Client{}
    resp, err := client.Do(req)
    if err != nil {
        return "", err
    }
    defer resp.Body.Close()

    if resp.StatusCode == http.StatusOK {
        return resp.Header.Get("Vorteil-Repository"), nil
    }
    return "", errors.New("not a new vorteil repository")
}

// getReaderURL downloads the package at src and returns a vpkg.Reader for it.
//
// The host is first probed with checkIfNewVRepo. When that probe returns
// "True", checkAuthentication is asked for a token and, if one is available,
// it is sent as a bearer token with the download request. Any response other
// than 200 OK is returned as an error carrying the HTTP status text.
//
// On success the response body stays open: it is handed to vpkg.Load wrapped
// in a progress-reporting reader. On every failure path the body is closed
// before returning.
func getReaderURL(src string) (vpkg.Reader, error) {

    newVrepo, err := checkIfNewVRepo(src)
    if err != nil {
        return nil, err
    }
    client := &http.Client{}

    req, err := http.NewRequest("GET", src, nil)
    if err != nil {
        return nil, err
    }
    if newVrepo == "True" {
        token, err := checkAuthentication()
        if err != nil {
            return nil, err
        }

        if token != "" {
            req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token))
        }
    }

    resp, err := client.Do(req)
    if err != nil {
        if resp != nil {
            resp.Body.Close()
        }
        return nil, err
    }
    if resp.StatusCode != http.StatusOK {
        // The body is not going to be consumed; close it so the underlying
        // connection is released instead of leaking.
        resp.Body.Close()
        return nil, errors.New(resp.Status)
    }

    var p elog.Progress
    if resp.ContentLength == -1 {
        // Unknown length: the progress is created with a total of 0 and is
        // finished when this function returns.
        // NOTE(review): on the vpkg.Load error path below this results in
        // Finish(false) followed by the deferred Finish(true) — confirm that
        // elog.Progress tolerates being finished twice.
        p = log.NewProgress("Downloading package", "", 0)
        defer p.Finish(true)
    } else {
        p = log.NewProgress("Downloading package", "KiB", resp.ContentLength)
    }

    pkgr, err := vpkg.Load(p.ProxyReader(resp.Body))
    if err != nil {
        resp.Body.Close()
        p.Finish(false)
        return nil, err
    }

    return pkgr, nil
}

// getBuilderURL fetches the package at src with getReaderURL and turns the
// resulting reader into a vpkg.Builder. If the builder cannot be created the
// reader is closed before the error is returned. argName is not used.
func getBuilderURL(argName, src string) (vpkg.Builder, error) {
    reader, err := getReaderURL(src)
    if err != nil {
        return nil, err
    }

    builder, err := vpkg.NewBuilderFromReader(reader)
    if err == nil {
        return builder, nil
    }

    reader.Close()
    return nil, err
}

// getReaderFile opens the file at src and loads it as a package.
//
// A missing file is reported with a short "failed to resolve" message rather
// than the raw OS error; every other open error is returned unchanged. If
// vpkg.Load fails, the file is closed before the error is returned. On
// success the file is left open for the returned reader.
func getReaderFile(src string) (vpkg.Reader, error) {
    f, err := os.Open(src)
    if err != nil {
        if !os.IsNotExist(err) {
            return nil, err
        }
        // Quote the path and drop the stray trailing space so the message
        // matches the other "failed to resolve" errors in this file.
        return nil, fmt.Errorf("failed to resolve '%s'", src)
    }

    pkgr, err := vpkg.Load(f)
    if err != nil {
        f.Close()
        return nil, err
    }
    return pkgr, nil
}
// getBuilderFile loads the package file at src with getReaderFile and turns
// the resulting reader into a vpkg.Builder. If the builder cannot be created
// the reader is closed before the error is returned. argName is not used.
func getBuilderFile(argName, src string) (vpkg.Builder, error) {
    reader, err := getReaderFile(src)
    if err != nil {
        return nil, err
    }

    builder, err := vpkg.NewBuilderFromReader(reader)
    if err == nil {
        return builder, nil
    }

    reader.Close()
    return nil, err
}

// getBuilderDir builds a vpkg.Builder from a project directory.
//
// src is split into a project path and a target name with vproj.Split. The
// project is loaded from the path, the named target is looked up, and a
// builder is created from that target. A "does not exist" error while loading
// the project is replaced by a "failed to resolve" message naming argName and
// src; all other errors are returned as they are.
func getBuilderDir(argName, src string) (vpkg.Builder, error) {
    path, target := vproj.Split(src)

    proj, err := vproj.LoadProject(path)
    if err != nil {
        if os.IsNotExist(err) {
            return nil, fmt.Errorf("failed to resolve %s '%s'", argName, src)
        }
        return nil, err
    }

    var ptgt *vproj.Target
    if ptgt, err = proj.Target(target); err != nil {
        return nil, err
    }

    return ptgt.NewBuilder()
}

// getPackageReader returns a vpkg.Reader for src, choosing the loader from
// the result of getSourceType: URLs go through getReaderURL and files through
// getReaderFile. Every other source type, directories included, produces a
// "failed to resolve" error naming argName and src.
func getPackageReader(argName, src string) (vpkg.Reader, error) {
    sType, err := getSourceType(src)
    if err != nil {
        return nil, err
    }

    switch sType {
    case sourceURL:
        return getReaderURL(src)
    case sourceFile:
        return getReaderFile(src)
    }

    return nil, fmt.Errorf("failed to resolve %s '%s'", argName, src)
}
// getPackageBuilder returns a vpkg.Builder for src, choosing the constructor
// from the result of getSourceType: getBuilderURL for URLs, getBuilderFile
// for files and getBuilderDir for directories. Any other source type produces
// a "failed to resolve" error naming argName and src.
//
// TODO: check for vrepo strings
func getPackageBuilder(argName, src string) (vpkg.Builder, error) {
    sType, err := getSourceType(src)
    if err != nil {
        return nil, err
    }

    switch sType {
    case sourceURL:
        return getBuilderURL(argName, src)
    case sourceFile:
        return getBuilderFile(argName, src)
    case sourceDir:
        return getBuilderDir(argName, src)
    }

    return nil, fmt.Errorf("failed to resolve %s '%s'", argName, src)
}

// Package-level values used when modifying a package builder.
// NOTE(review): presumably populated from command-line flags registered
// through vcfgFlags — confirm where each one is bound.
var (
    // flagIcon is the path of an icon file; when non-empty,
    // modifyPackageBuilder opens it lazily and sets it on the builder.
    flagIcon             string
    // flagVCFG lists VCFG file paths that mergeVCFGFlagValues loads and
    // merges into the builder, in order.
    flagVCFG             []string
    flagInfoDate         string
    flagInfoURL          string
    flagSystemFilesystem string
    flagSystemOutputMode string
    flagSysctl           []string
    flagVMDiskSize       string
    flagVMInodes         string
    flagVMRAM            string
    // overrideVCFG is merged into the builder last by mergeVCFGFlagValues,
    // after all files from flagVCFG.
    overrideVCFG         vcfg.VCFG
)

// addModifyFlags adds the vcfgFlags collection to the flag set f by calling
// its AddTo method.
func addModifyFlags(f *pflag.FlagSet) {
    vcfgFlags.AddTo(f)
}

// mergeVCFGFlagValues merges configuration into the builder that b points to.
// Each path in flagVCFG is opened, loaded as a VCFG file and merged in order;
// afterwards overrideVCFG is merged on top. The first error encountered stops
// the process and is returned.
func mergeVCFGFlagValues(b *vpkg.Builder) error {
    for _, path := range flagVCFG {
        f, err := vio.Open(path)
        if err != nil {
            return err
        }

        cfg, err := vcfg.LoadFile(f)
        if err != nil {
            return err
        }

        if err = (*b).MergeVCFG(cfg); err != nil {
            return err
        }
    }

    // The override values go last so they take effect after every file.
    return (*b).MergeVCFG(&overrideVCFG)
}

// modifyPackageBuilder applies the package-level modification settings to b.
// In order it: validates vcfgFlags, sets the icon from flagIcon when that is
// non-empty, merges VCFG files and overrides via mergeVCFGFlagValues, and
// finally runs handleFileInjections. It returns the first error encountered.
func modifyPackageBuilder(b vpkg.Builder) error {
    if err := vcfgFlags.Validate(); err != nil {
        return err
    }

    if flagIcon != "" {
        icon, err := vio.LazyOpen(flagIcon)
        if err != nil {
            return err
        }
        b.SetIcon(icon)
    }

    if err := mergeVCFGFlagValues(&b); err != nil {
        return err
    }

    return handleFileInjections(b)
}

// NumbersMode selects how PrintableSize values are rendered:
// 0 is the short form, 1 is decimal and 2 is hexadecimal.
var NumbersMode int

// SetNumbersMode lower-cases and trims s, then sets NumbersMode to the
// matching value. The empty string and "short" select 0, "dec"/"decimal"
// select 1 and "hex"/"hexadecimal" select 2. Any other input returns an error
// and leaves NumbersMode untouched.
func SetNumbersMode(s string) error {
    key := strings.TrimSpace(strings.ToLower(s))

    modes := map[string]int{
        "":            0,
        "short":       0,
        "dec":         1,
        "decimal":     1,
        "hex":         2,
        "hexadecimal": 2,
    }

    mode, known := modes[key]
    if !known {
        return fmt.Errorf("numbers mode must be one of 'dec', 'hex', or 'short'")
    }

    NumbersMode = mode
    return nil
}

// SetNumberModeFlagCMD reads the string flag "numbers" from cmd and passes
// its value to SetNumbersMode.
//
// An error from looking up the flag is returned unchanged. An error from
// SetNumbersMode is wrapped with a message mentioning --numbers.
func SetNumberModeFlagCMD(cmd *cobra.Command) error {
    numbers, err := cmd.Flags().GetString("numbers")
    if err != nil {
        return err
    }

    if err := SetNumbersMode(numbers); err != nil {
        // %w (rather than %v) keeps the underlying error reachable through
        // errors.Is / errors.As / errors.Unwrap; the message text is the same.
        return fmt.Errorf("couldn't parse value of --numbers: %w", err)
    }

    return nil
}

// PrintableSize wraps int so that its textual form follows the global
// NumbersMode.
type PrintableSize int

// String renders c according to NumbersMode.
//
// Mode 0 (short) divides the value by 1024 for as long as it divides evenly,
// up to three times, and appends "K", "M" or "G" accordingly; zero is
// rendered as "0". Mode 1 prints plain decimal and mode 2 prints hexadecimal
// with a "0x" prefix. Any other mode panics.
func (c PrintableSize) String() string {
    value := int(c)

    switch NumbersMode {
    case 0:
        if value == 0 {
            return "0"
        }
        suffixes := []string{"", "K", "M", "G"}
        idx := 0
        // Stop at the largest suffix or as soon as a division would lose data.
        for idx < len(suffixes)-1 && value%1024 == 0 {
            value /= 1024
            idx++
        }
        return fmt.Sprintf("%d%s", value, suffixes[idx])
    case 1:
        return fmt.Sprintf("%d", value)
    case 2:
        return fmt.Sprintf("%#x", value)
    }

    panic("invalid NumbersMode")
}

// PlainTable writes rows to standard output as a left-aligned table without
// borders or column separators. It panics when vals is empty.
//
// Only vals[1:] is printed; the first row is never appended to the table.
// NOTE(review): presumably vals[0] is a header row that is intentionally
// hidden in plain output — confirm against the callers.
func PlainTable(vals [][]string) {
    if len(vals) == 0 {
        panic(errors.New("no rows provided"))
    }

    out := tablewriter.NewWriter(os.Stdout)
    out.SetAlignment(tablewriter.ALIGN_LEFT)
    out.SetBorder(false)
    out.SetColumnSeparator("")

    for _, row := range vals[1:] {
        out.Append(row)
    }

    out.Render()
}