pkg/vault/ledger/vault.go
package ledger
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
"time"
"github.com/ecadlabs/gotez/v2/crypt"
"github.com/ecadlabs/signatory/pkg/config"
"github.com/ecadlabs/signatory/pkg/errors"
"github.com/ecadlabs/signatory/pkg/vault"
"github.com/ecadlabs/signatory/pkg/vault/ledger/ledger"
"github.com/ecadlabs/signatory/pkg/vault/ledger/tezosapp"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v3"
)
// defaultCloseAfter is the delay the worker uses when Config.CloseAfter is
// zero: an open device handle is closed this long after it was (re)opened.
const defaultCloseAfter = time.Second * 10
// devRequest is the message type consumed by the device worker goroutine from
// Vault.req. The unexported marker method limits implementations to request
// types declared in this package.
type devRequest interface {
devRequest()
}
// getKeyReq asks the worker to read the public key for key from the device.
// The worker replies on exactly one of res or err.
type getKeyReq struct {
key *keyID
res chan<- *ledgerKey
err chan<- error
}
// devRequest marks getKeyReq as a worker request.
func (g *getKeyReq) devRequest() {}
// signReq asks the worker to sign data with key on the device.
// The worker replies on exactly one of sig or err.
type signReq struct {
key *keyID
data []byte
sig chan<- crypt.Signature
err chan<- error
}
// devRequest marks signReq as a worker request.
func (s *signReq) devRequest() {}
// keyID identifies a device key by BIP32 path and derivation type. Values
// produced by parseKeyID carry absolute paths that include the Tezos root
// prefix (tezosapp.TezosBIP32Root).
type keyID struct {
path tezosapp.BIP32
dt tezosapp.DerivationType
}
// Vault is a Ledger signer backend.
// All device access goes through a single worker goroutine (started by New)
// that serves requests sent on req one at a time.
type Vault struct {
config Config
// keys holds the parsed entries of config.Keys, in configuration order.
keys []*keyID
// req carries requests to the worker goroutine.
req chan devRequest
// scanner is used by the worker to open the device for config.ID.
scanner *scanner
}
// Config represents Ledger signer backend configuration
type Config struct {
// ID is passed to the scanner when opening the device (presumably a device
// identifier — confirm against scanner.open). It is also reported as the
// vault instance name and used as a prefix in error messages.
ID string `yaml:"id"`
// Keys lists the exposed keys as "<derivation type>/<BIP32 path>"; paths not
// starting with the Tezos root are treated as relative to it.
Keys []string `yaml:"keys"`
// CloseAfter is how long after opening the device handle is closed again;
// zero selects defaultCloseAfter.
CloseAfter time.Duration `yaml:"close_after"`
// Transport is "usb" or empty for USB HID, or "tcp://[model@]host:port".
Transport string `yaml:"transport"`
}
// ledgerKey is a public key read from the device together with the key ID it
// belongs to. SignMessage only accepts keys of this type.
type ledgerKey struct {
id *keyID
pub crypt.PublicKey
}
// PublicKey returns the public key as reported by the device.
func (l *ledgerKey) PublicKey() crypt.PublicKey {
	return l.pub
}

// ID returns the textual key ID ("<derivation type>/<BIP32 path>").
func (l *ledgerKey) ID() string {
	return l.id.String()
}
// ledgerIterator walks v.keys in order, reading each public key from the
// device only when Next is called.
type ledgerIterator struct {
// ctx is passed to every device lookup made by Next.
ctx context.Context
v *Vault
// idx is the index in v.keys of the next key to fetch.
idx int
}
// Next reads the public key of the next configured key from the device and
// returns vault.ErrDone once all keys have been visited. The position only
// advances on success, so a failed key is retried by the following call.
func (l *ledgerIterator) Next() (key vault.StoredKey, err error) {
	if l.idx == len(l.v.keys) {
		return nil, vault.ErrDone
	}
	current := l.v.keys[l.idx]
	if key, err = l.v.getPublicKey(l.ctx, current); err != nil {
		return nil, err
	}
	l.idx++
	return key, nil
}
// getPublicKey asks the worker goroutine for the public key of id and waits
// for the answer. Both enqueueing the request and waiting for the reply are
// bounded by ctx; device errors are prefixed with the vault ID.
func (v *Vault) getPublicKey(ctx context.Context, id *keyID) (vault.StoredKey, error) {
	// Reply channels are buffered so the worker can always deliver its answer,
	// even if this caller has already returned because ctx was cancelled.
	res := make(chan *ledgerKey, 1)
	errCh := make(chan error, 1)
	req := &getKeyReq{
		key: id,
		res: res,
		err: errCh,
	}

	// v.req has a bounded buffer; while the worker is busy with other requests
	// a plain send could block indefinitely, ignoring the caller's deadline.
	select {
	case v.req <- req:
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	select {
	case pk := <-res:
		return pk, nil
	case err := <-errCh:
		return nil, fmt.Errorf("(Ledger/%s): %w", v.config.ID, err)
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// GetPublicKey parses id ("<derivation type>/<BIP32 path>") and returns the
// matching public key read from the device. A malformed id is reported as an
// HTTP 400 error.
func (v *Vault) GetPublicKey(ctx context.Context, id string) (vault.StoredKey, error) {
	kid, parseErr := parseKeyID(id)
	if parseErr == nil {
		return v.getPublicKey(ctx, kid)
	}
	wrapped := fmt.Errorf("(Ledger/%s): %w", v.config.ID, parseErr)
	return nil, errors.Wrap(wrapped, http.StatusBadRequest)
}
// ListPublicKeys returns an iterator over the configured keys. Nothing is read
// from the device until Next is called; every lookup uses ctx.
func (v *Vault) ListPublicKeys(ctx context.Context) vault.StoredKeysIterator {
	it := ledgerIterator{v: v, ctx: ctx}
	return &it
}
// SignMessage signs digest on the device with key, which must be a key
// produced by this backend; any other key type is rejected as an HTTP 400
// error. Both enqueueing the request and waiting for the signature are bounded
// by ctx; device errors are prefixed with the vault ID.
func (v *Vault) SignMessage(ctx context.Context, digest []byte, key vault.StoredKey) (crypt.Signature, error) {
	pk, ok := key.(*ledgerKey)
	if !ok {
		return nil, errors.Wrap(fmt.Errorf("(Ledger/%s): not a Ledger key: %T", v.config.ID, key), http.StatusBadRequest)
	}

	// Reply channels are buffered so the worker can always deliver its answer,
	// even if this caller has already returned because ctx was cancelled.
	res := make(chan crypt.Signature, 1)
	errCh := make(chan error, 1)
	req := &signReq{
		key:  pk.id,
		data: digest,
		sig:  res,
		err:  errCh,
	}

	// v.req has a bounded buffer; while the worker is busy with other requests
	// a plain send could block indefinitely, ignoring the caller's deadline.
	select {
	case v.req <- req:
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	select {
	case sig := <-res:
		return sig, nil
	case err := <-errCh:
		return nil, fmt.Errorf("(Ledger/%s): %w", v.config.ID, err)
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// Name reports the backend type; it is always "Ledger".
func (v *Vault) Name() string { return "Ledger" }
// VaultName reports this instance's configured ID (Config.ID).
func (v *Vault) VaultName() string { return v.config.ID }
// worker owns the device handle and serves requests from v.req one at a time,
// so the device is never used concurrently. The handle is opened lazily by the
// first request and closed closeAfter after it was (re)opened; a later request
// opens it again. The loop never returns.
func (v *Vault) worker() {
	var (
		dev *tezosapp.App
		t   *time.Timer
		// tch is t.C while the close timer is armed and nil otherwise; a nil
		// channel is never ready, which disables the close case below.
		tch <-chan time.Time
	)

	closeAfter := v.config.CloseAfter
	if closeAfter <= 0 {
		closeAfter = defaultCloseAfter
	}

	// stopTimer disarms the close timer.
	stopTimer := func() {
		if t != nil && !t.Stop() {
			// The timer already fired. If its tick was consumed by the select
			// below, t.C is empty and a blocking receive would hang the worker
			// forever, so only drain a pending tick without blocking.
			select {
			case <-t.C:
			default:
			}
		}
		tch = nil
	}

	// armTimer (re)starts the close timer for a full closeAfter period.
	armTimer := func() {
		stopTimer()
		if t == nil {
			t = time.NewTimer(closeAfter)
		} else {
			t.Reset(closeAfter)
		}
		tch = t.C
	}

	// openDev ensures dev is open. With reopen set, an existing handle is closed
	// and replaced by a fresh one. On failure dev is left nil and the timer is
	// disarmed, so the close case never runs without a device.
	openDev := func(reopen bool) error {
		if dev != nil {
			if !reopen {
				return nil
			}
			// Best effort: the stale handle is discarded whatever Close reports.
			_ = dev.Close()
			dev = nil
		}
		d, err := v.scanner.open(v.config.ID)
		if err != nil {
			stopTimer()
			return err
		}
		dev = d
		armTimer()
		return nil
	}

	for {
		select {
		case req := <-v.req:
			switch r := req.(type) {
			case *getKeyReq:
				if err := openDev(false); err != nil {
					r.err <- err
					break
				}
				pub, err := dev.GetPublicKey(r.key.dt, r.key.path, false)
				if err != nil {
					r.err <- err
					break
				}
				r.res <- &ledgerKey{
					pub: pub,
					id:  r.key,
				}

			case *signReq:
				// Sign with the current handle and, if that fails, once more with a
				// freshly opened one in case the Ledger was reset since it was opened.
				var (
					sig crypt.Signature
					err error
				)
				for attempt := 0; attempt < 2; attempt++ {
					if err = openDev(attempt > 0); err != nil {
						break
					}
					if sig, err = dev.Sign(r.key.dt, r.key.path, r.data); err == nil {
						break
					}
				}
				if err != nil {
					r.err <- err
				} else {
					r.sig <- sig
				}
			}

		case <-tch:
			// The tick has been received, so the timer is no longer armed.
			tch = nil
			if dev != nil {
				if err := dev.Close(); err != nil {
					log.Errorf("(Ledger/%s): %v", v.config.ID, err)
				}
				// A handle that has been closed (even unsuccessfully) is not reused;
				// the next request opens a new one and re-arms the timer.
				dev = nil
			}
		}
	}
}
// New returns a new Ledger signer for conf and starts its device worker
// goroutine, which runs for the rest of the process lifetime. Every entry of
// conf.Keys must be a valid key ID (see parseKeyID), and conf.Transport must
// name a supported transport. ctx is currently unused.
func New(ctx context.Context, conf *Config) (*Vault, error) {
	if conf == nil {
		return nil, errors.New("(Ledger): config is missing")
	}

	keys := make([]*keyID, len(conf.Keys))
	for i, k := range conf.Keys {
		kid, err := parseKeyID(k)
		if err != nil {
			// Name the offending key and vault, like the other errors of this backend.
			return nil, fmt.Errorf("(Ledger/%s): key %q: %w", conf.ID, k, err)
		}
		keys[i] = kid
	}

	sc, err := getScanner(conf.Transport)
	if err != nil {
		return nil, fmt.Errorf("(Ledger/%s): %w", conf.ID, err)
	}

	v := &Vault{
		config:  *conf,
		keys:    keys,
		req:     make(chan devRequest, 10),
		scanner: sc,
	}
	go v.worker()
	return v, nil
}
// parseKeyID parses a key ID of the form "<derivation type>/<BIP32 path>".
// Every path component must be hardened. A path that does not already begin
// with the Tezos root (tezosapp.TezosBIP32Root) is treated as relative to it,
// and the root key itself is rejected.
func parseKeyID(s string) (*keyID, error) {
	dtStr, pathStr, ok := strings.Cut(s, "/")
	if !ok {
		return nil, fmt.Errorf("error parsing key id: %s", s)
	}
	dt, err := tezosapp.DerivationTypeFromString(dtStr)
	if err != nil {
		return nil, err
	}
	path := tezosapp.ParseBIP32(pathStr)
	if path == nil {
		return nil, fmt.Errorf("error parsing key path: %s", pathStr)
	}
	for _, c := range path {
		if c&tezosapp.BIP32H == 0 {
			return nil, errors.New("only hardened derivation is supported")
		}
	}

	root := tezosapp.TezosBIP32Root
	if len(path) < 2 || path[0] != root[0] || path[1] != root[1] {
		// Build a fresh slice rather than appending to the package-level root:
		// if the root ever had spare capacity, every key would write its suffix
		// into the same backing array and overwrite the others' paths.
		abs := make(tezosapp.BIP32, 0, len(root)+len(path))
		abs = append(abs, root...)
		path = append(abs, path...)
	}
	if len(path) == 2 {
		return nil, errors.New("root key isn't allowed to use")
	}
	return &keyID{
		dt:   dt,
		path: path,
	}, nil
}
// String renders the key ID as "<derivation type>/<BIP32 path>".
func (k *keyID) String() string {
	dt, path := k.dt.String(), k.path.String()
	return dt + "/" + path
}
// hidTransport and hidScanner form the shared USB HID scanner. getScanner
// returns a pointer to this same hidScanner for every vault whose transport is
// "usb" or empty.
var (
hidTransport = ledger.USBHIDTransport{}
hidScanner = scanner{
tr: &hidTransport,
}
)
// getScanner selects the device scanner for a transport setting:
//   - "" or "usb" (or a usb:// URL with a host) returns the shared USB HID scanner;
//   - "tcp://[model@]host:port" returns a new TCP scanner, passing the optional
//     URL user name as the device model.
//
// Any other value is rejected.
func getScanner(transport string) (*scanner, error) {
	u, err := url.Parse(transport)
	if err != nil {
		return nil, err
	}

	kind := transport
	if u.Scheme != "" && u.Host != "" {
		kind = u.Scheme
	}

	switch kind {
	case "usb", "":
		return &hidScanner, nil
	case "tcp":
		model := ""
		if u.User != nil {
			model = u.User.Username()
		}
		tcp := &ledger.TCPTransport{Addr: u.Host, Model: model}
		return &scanner{tr: tcp}, nil
	}
	return nil, fmt.Errorf("undefined transport: %s", kind)
}
// init registers the "ledger" vault factory and the Ledger CLI command.
func init() {
	vault.RegisterVault("ledger", func(ctx context.Context, node *yaml.Node) (vault.Vault, error) {
		if node == nil || node.Kind == 0 {
			return nil, errors.New("(Ledger): config is missing")
		}
		var conf Config
		if err := node.Decode(&conf); err != nil {
			return nil, err
		}
		if err := config.Validator().Struct(&conf); err != nil {
			return nil, err
		}
		v, err := New(ctx, &conf)
		if err != nil {
			// Return an untyped nil: returning New's nil *Vault directly would
			// yield a non-nil vault.Vault interface alongside the error.
			return nil, err
		}
		return v, nil
	})
	vault.RegisterCommand(newLedgerCommand())
}