go/zeusclient/zeusclient.go

Summary

Maintainability
C
1 day
Test Coverage
package zeusclient

import (
    "io"
    "net"
    "os"
    "os/signal"
    "strconv"
    "strings"
    "sync"
    "syscall"

    "github.com/burke/ttyutils"
    "github.com/burke/zeus/go/messages"
    slog "github.com/burke/zeus/go/shinylog"
    "github.com/burke/zeus/go/unixsocket"
    "github.com/burke/zeus/go/zerror"
    "github.com/kr/pty"
)

// Control bytes that Run looks for in the raw input stream when the output is
// a terminal. Despite the sig* names these are input byte values, not signal
// numbers; each one is translated into the signal noted beside it.
const (
    sigInt  = 0x03 // ETX (Ctrl-C): relayed as SIGINT
    sigQuit = 0x1c // FS (Ctrl-\): relayed as SIGQUIT
    sigTstp = 0x1a // SUB (Ctrl-Z): relayed as SIGTSTP
)

// terminatingSignals is the set of signals that Run relays to the remote
// command before exiting itself. The list was generated with:
//
//     man signal | grep 'terminate process' | awk '{print $2}' | xargs -I '{}' echo -n "syscall.{}, "
//
// SIGKILL is part of that output and is kept here, although a process cannot
// catch it, so registering it with signal.Notify has no effect.
var terminatingSignals = []os.Signal{
    syscall.SIGHUP,
    syscall.SIGINT,
    syscall.SIGKILL,
    syscall.SIGPIPE,
    syscall.SIGALRM,
    syscall.SIGTERM,
    syscall.SIGXCPU,
    syscall.SIGXFSZ,
    syscall.SIGVTALRM,
    syscall.SIGPROF,
    syscall.SIGUSR1,
    syscall.SIGUSR2,
}

// Run runs the main Zeus command.
func Run(args []string, input io.Reader, output *os.File, stderr *os.File) int {
    if os.Getenv("RAILS_ENV") != "" {
        println("Warning: Specifying a Rails environment via RAILS_ENV has no effect for commands run with zeus.")
        println("As a safety precaution to protect you from nuking your development database,")
        println("Zeus will now cowardly refuse to proceed. Please unset RAILS_ENV and try again.")
        return 1
    }

    // setup stdout
    localStdout, remoteStdout, outputIsTerminal, err := socketsForOutput(output)
    if err != nil {
        slog.ErrorString(err.Error() + "\r")
        return 1
    }
    defer localStdout.Close()
    defer remoteStdout.Close()

    // setup stderr
    localStderr, remoteStderr, _, err := socketsForOutput(stderr)
    if err != nil {
        slog.ErrorString(err.Error() + "\r")
        return 1
    }
    defer localStderr.Close()
    defer remoteStderr.Close()

    addr, err := net.ResolveUnixAddr("unixgram", unixsocket.ZeusSockName())
    if err != nil {
        slog.ErrorString(err.Error() + "\r")
        return 1
    }

    conn, err := net.DialUnix("unix", nil, addr)
    if err != nil {
        zerror.ErrorCantConnectToMaster()
        return 1
    }
    usock := unixsocket.New(conn)

    msg := messages.CreateCommandAndArgumentsMessage(args, os.Getpid())
    usock.WriteMessage(msg)
    err = sendCommandLineArguments(usock, args)
    if err != nil {
        slog.ErrorString(err.Error() + "\r")
        return 1
    }

    usock.WriteFD(int(remoteStdout.Fd()))
    usock.WriteFD(int(remoteStderr.Fd()))

    msg, err = usock.ReadMessage()
    if err != nil {
        slog.ErrorString(err.Error() + "\r")
        return 1
    }

    parts := strings.Split(msg, "\000")
    commandPid, err := strconv.Atoi(parts[0])
    defer func() {
        if commandPid > 0 {
            // Just in case.
            syscall.Kill(commandPid, 9)
        }
    }()

    if err != nil {
        slog.ErrorString(err.Error() + "\r")
        return 1
    }

    if outputIsTerminal {
        c := make(chan os.Signal, 1)
        handledSignals := append(append(terminatingSignals, syscall.SIGWINCH), syscall.SIGCONT)
        signal.Notify(c, handledSignals...)
        go func() {
            for sig := range c {
                if sig == syscall.SIGCONT {
                    syscall.Kill(commandPid, syscall.SIGCONT)
                } else if sig == syscall.SIGWINCH {
                    ttyutils.MirrorWinsize(output, localStdout)
                    syscall.Kill(commandPid, syscall.SIGWINCH)
                } else { // member of terminatingSignals
                    print("\r")
                    syscall.Kill(commandPid, sig.(syscall.Signal))
                    os.Exit(1)
                }
            }
        }()
    }

    exitStatus := -1
    if len(parts) > 2 {
        exitStatus, err = strconv.Atoi(parts[0])
        if err != nil {
            slog.ErrorString(err.Error() + "\r")
            return 1
        }
    }

    var endOfIO sync.WaitGroup

    oldTerminalStateStdout, err := forwardOutput(localStdout, output, endOfIO)
    if oldTerminalStateStdout != nil {
        defer ttyutils.RestoreTerminalState(output.Fd(), oldTerminalStateStdout)
    }
    if err != nil {
        slog.ErrorString(err.Error() + "\r")
        return 1
    }

    oldTerminalStateStderr, err := forwardOutput(localStderr, stderr, endOfIO)
    if oldTerminalStateStderr != nil {
        defer ttyutils.RestoreTerminalState(stderr.Fd(), oldTerminalStateStderr)
    }
    if err != nil {
        slog.ErrorString(err.Error() + "\r")
        return 1
    }

    go func() {
        endOfIO.Add(1)
        buf := make([]byte, 8192)
        for {
            n, err := input.Read(buf)
            if err != nil {
                endOfIO.Done()
                break
            }
            if outputIsTerminal {
                for i := 0; i < n; i++ {
                    switch buf[i] {
                    case sigInt:
                        syscall.Kill(commandPid, syscall.SIGINT)
                    case sigQuit:
                        syscall.Kill(commandPid, syscall.SIGQUIT)
                    case sigTstp:
                        syscall.Kill(commandPid, syscall.SIGTSTP)
                        syscall.Kill(os.Getpid(), syscall.SIGTSTP)
                    }
                }
            }
            localStdout.Write(buf[:n])
        }
    }()

    endOfIO.Wait()

    if exitStatus == -1 {
        msg, err = usock.ReadMessage()
        if err != nil {
            slog.ErrorString(err.Error() + "\r")
            return 1
        }
        parts := strings.Split(msg, "\000")
        exitStatus, err = strconv.Atoi(parts[0])
        if err != nil {
            slog.ErrorString(err.Error() + "\r")
            return 1
        }
    }

    return exitStatus
}

// sendCommandLineArguments creates a stream socket pair, sends one end to the
// master over usock and then, in the background, writes every argument after
// args[0] to the other end, each terminated by a NUL byte.
//
// It returns an error when the socket pair cannot be created or the
// descriptor cannot be sent. A failure of the background write is only
// logged, because the function has already returned by then.
func sendCommandLineArguments(usock *unixsocket.Usock, args []string) error {
    master, slave, err := unixsocket.Socketpair(syscall.SOCK_STREAM)
    if err != nil {
        return err
    }

    // Check the send itself: without the slave descriptor the master has
    // nothing to read the arguments from.
    err = usock.WriteFD(int(slave.Fd()))
    // The local copy of the slave end is not needed after the send attempt.
    slave.Close()
    if err != nil {
        master.Close()
        return err
    }

    go func() {
        defer master.Close()

        var payload []byte
        // Guard the slice expression so an empty args cannot panic here.
        if len(args) > 1 {
            for _, arg := range args[1:] {
                payload = append(payload, arg...)
                payload = append(payload, 0)
            }
        }

        if _, err := master.Write(payload); err != nil {
            slog.ErrorString("Could not send arguments across: " +
                err.Error() + "\r")
        }
    }()

    return nil
}

// socketsForOutput creates the pair of descriptors used to carry one output
// stream. When out is a terminal the pair comes from pty.Open, otherwise from
// a stream socket pair. outIsTerminal reports which of the two was chosen, and
// err is whatever the chosen constructor returned.
func socketsForOutput(out *os.File) (local, remote *os.File, outIsTerminal bool, err error) {
    if !ttyutils.IsTerminal(out.Fd()) {
        first, second, pairErr := unixsocket.Socketpair(syscall.SOCK_STREAM)
        return first, second, false, pairErr
    }

    first, second, openErr := pty.Open()
    return first, second, true, openErr
}

// forwardOutput mirrors the window size of to onto from, switches to into raw
// mode when it is a terminal, and starts a goroutine that copies everything
// read from from into to until a read fails.
//
// oldTerminalState is the value returned by ttyutils.MakeTerminalRaw and is
// only set when to is a terminal. If entering raw mode fails, that error is
// returned and no goroutine is started.
//
// NOTE(review): signalEnd is received by value, so the Add and Done calls
// below operate on this function's own copy and are invisible to the caller's
// WaitGroup. Making them effective requires a *sync.WaitGroup parameter, which
// would change the signature.
func forwardOutput(from, to *os.File, signalEnd sync.WaitGroup) (oldTerminalState *ttyutils.Termios, err error) {
    ttyutils.MirrorWinsize(to, from)

    if ttyutils.IsTerminal(to.Fd()) {
        oldTerminalState, err = ttyutils.MakeTerminalRaw(to.Fd())
        if err != nil {
            return oldTerminalState, err
        }
    }

    go func() {
        signalEnd.Add(1)
        defer signalEnd.Done()

        // One buffer serves the whole loop; Write has finished with it
        // before the next Read overwrites it.
        buf := make([]byte, 1024)
        for {
            n, readErr := from.Read(buf)
            // Deliver whatever was read before acting on the error, so data
            // returned together with an error is not lost.
            if n > 0 {
                to.Write(buf[:n])
            }
            if readErr != nil {
                return
            }
        }
    }()

    return oldTerminalState, err
}