cmd/rpc.go
// Copyright © 2016 Asteris, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package cmd
import (
"encoding/json"
"io"
"net"
"net/url"
"strings"
"time"
"google.golang.org/grpc/metadata"
"github.com/asteris-llc/converge/graph"
"github.com/asteris-llc/converge/helpers/logging"
"github.com/asteris-llc/converge/rpc"
"github.com/asteris-llc/converge/rpc/pb"
"github.com/pkg/errors"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"golang.org/x/net/context"
)
const (
rpcAddrFlagName = "rpc-addr"
rpcLocalAddrName = "local-addr"
rpcEnableLocalName = "local"
)
func registerRPCFlags(flags *pflag.FlagSet) {
flags.String(rpcTokenFlagName, "", "token for RPC")
flags.Bool(rpcNoTokenFlagName, false, "don't use or generate an RPC token")
flags.String(rpcAddrFlagName, addrServer, "address for RPC connection")
}
func registerLocalRPCFlags(flags *pflag.FlagSet) {
flags.String(rpcLocalAddrName, addrServerLocal, "address for local RPC connection")
flags.Bool(rpcEnableLocalName, false, "self host RPC")
}
func maybeStartSelfHostedRPC(ctx context.Context) error {
if getLocal() {
go startRPC(ctx)
var err error
for i := 0; i < 5; i++ {
_, err = net.Dial("tcp", getServerURL().Host)
if err == nil {
return nil
}
time.Sleep(100 * time.Millisecond)
}
return err
}
return nil
}
func startRPC(ctx context.Context) error {
// set context for logging
logger := logging.GetLogger(ctx).WithField("component", "rpc")
ctx = logging.WithLogger(ctx, logger)
loc := getServerURL()
// set up security options
if !usingSSL() {
logger.Warning("no SSL config in use, server will accept unencrypted connections")
}
// create server
server := &rpc.Server{
Security: getSecurityConfig(),
ResourceRoot: viper.GetString("root"),
EnableBinaryDownload: viper.GetBool("self-serve"),
}
return server.Listen(ctx, loc)
}
func getRPCExecutorClient(ctx context.Context, security *rpc.Security) (pb.ExecutorClient, error) {
return rpc.NewExecutorClient(ctx, getServerURL().Host, security)
}
func getRPCGrapherClient(ctx context.Context, security *rpc.Security) (*rpc.GrapherClient, error) {
return rpc.NewGrapherClient(ctx, getServerURL().Host, security)
}
func getInfoClient(ctx context.Context, security *rpc.Security) (*rpc.InfoClient, error) {
return rpc.NewInfoClient(ctx, getServerURL().Host, security)
}
type recver interface {
Recv() (*pb.StatusResponse, error)
}
func iterateOverStream(stream recver, cb func(*pb.StatusResponse)) error {
for {
resp, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
return errors.Wrap(err, "error getting status response")
}
cb(resp)
}
return nil
}
type headerer interface {
Header() (metadata.MD, error)
}
func getMeta(stream headerer) ([]*graph.Edge, error) {
meta, err := stream.Header()
if err != nil {
return nil, errors.Wrap(err, "error getting RPC header")
}
var edges []*graph.Edge
if blobs, ok := meta["edges"]; ok {
for _, blob := range blobs {
var out []*graph.Edge
err := json.Unmarshal([]byte(blob), &out)
if err != nil {
return nil, errors.Wrap(err, "could not deserialize edge metadata")
}
edges = append(edges, out...)
}
}
return edges, nil
}
// More getters
func setLocal(local bool) { viper.Set(rpcEnableLocalName, local) }
func getLocal() bool { return viper.GetBool(rpcEnableLocalName) }
func getLocalAddr() string { return viper.GetString(rpcLocalAddrName) }
func getRPCAddr() string { return viper.GetString(rpcAddrFlagName) }
func getServerURL() *url.URL {
out := new(url.URL)
if getLocal() {
out.Host = getLocalAddr()
} else {
out.Host = getRPCAddr()
}
// set host to localhost, if not set
if strings.HasPrefix(out.Host, ":") {
out.Host = "127.0.0.1" + out.Host
}
// set protocol
if usingSSL() {
out.Scheme = "https"
} else {
out.Scheme = "http"
}
return out
}