pkg/vault/ledger/ledger/usbhid.go
package ledger
import (
"crypto/rand"
"errors"
"fmt"
"math/big"
"runtime"
"github.com/karalabe/hid"
)
// USB identification and HID framing parameters used by this transport.
const (
// ledgerUSBVendorID is the USB vendor ID passed to hid.Enumerate when listing devices.
ledgerUSBVendorID = 0x2c97
// ledgerUsagePage is the HID usage page isValidInterface matches on darwin and windows.
ledgerUsagePage = 0xffa0
// headerSize is the packet header length: channel (2 bytes), command (1 byte), sequence (2 bytes).
headerSize = 5
// packetSize is the size of the fixed buffer written to and read from the device.
packetSize = 64
// chunkSize is the number of payload bytes left in one packet after the header.
chunkSize = packetSize - headerSize
)
// Command identifiers stored in the third byte of every packet header.
const (
// cmdPing is the command Ping sends; writeCommand transmits it as a single packet with no payload.
cmdPing = 2
// cmdAPDU marks packets carrying a length-prefixed APDU, split across packets when needed.
cmdAPDU = 5
)
// USBHIDTransport is a USB HID transport backend. It carries no state of its
// own: Enumerate lists attached devices and Open returns an Exchanger bound to
// one of them.
type USBHIDTransport struct{}
// isValidInterface reports whether d is the HID interface this transport
// should talk to. On darwin and windows the decision is made by usage page;
// on every other OS the interface number must be 0.
func isValidInterface(d *hid.DeviceInfo) bool {
	switch runtime.GOOS {
	case "darwin", "windows":
		return d.UsagePage == ledgerUsagePage
	default:
		return d.Interface == 0
	}
}
// Enumerate returns a list of attached Ledger devices.
//
// Every HID interface reported for the Ledger vendor ID is filtered through
// isValidInterface. Each remaining entry is returned with its path and, when
// the product ID matches an entry of ledgerDevices (either the legacy product
// ID or the high byte of the product ID), the matching device description.
// The returned error is always nil.
func (u *USBHIDTransport) Enumerate() ([]*DeviceInfo, error) {
	found := hid.Enumerate(ledgerUSBVendorID, 0)
	out := make([]*DeviceInfo, 0, len(found))
	for i := range found {
		hidInfo := &found[i]
		if !isValidInterface(hidInfo) {
			continue
		}
		entry := &DeviceInfo{Path: hidInfo.Path}
		// The high byte of the product ID identifies the device model.
		model := uint8((hidInfo.ProductID >> 8) & 0xff)
		for _, known := range ledgerDevices {
			if known.LegacyUSBProductID == hidInfo.ProductID || known.ProductIDMM == model {
				entry.DeviceInfo = known
				break
			}
		}
		out = append(out, entry)
	}
	return out, nil
}
// usbHIDRoundTripper exchanges framed commands with a single opened HID device.
type usbHIDRoundTripper struct {
// channel is written into every outgoing packet header; Exchange and Ping
// compare the channel of replies against it.
channel uint16
// dev is the opened HID device used for all reads and writes.
dev *hid.Device
}
// packet is one HID frame: a 5-byte header followed by payload bytes.
type packet struct {
// channel is stored big-endian in header bytes 0-1.
channel uint16
// cmd is header byte 2.
cmd uint8
// seq is the packet index within a command, stored big-endian in header bytes 3-4.
seq uint16
// data is the payload that follows the header.
data []byte
}
// writePacket serializes p into one fixed-size HID frame and writes it to the
// device.
//
// The frame starts with a headerSize-byte header (big-endian channel, command
// byte, big-endian sequence number) followed by the payload; unused trailing
// bytes stay zero. A payload longer than chunkSize does not fit into a single
// frame and is rejected with an error instead of being silently truncated.
// Errors from the device write are wrapped with the "ledger:" prefix.
func (u *usbHIDRoundTripper) writePacket(p *packet) error {
	if len(p.data) > chunkSize {
		return fmt.Errorf("ledger: packet payload is too long: %d", len(p.data))
	}
	var pkt [packetSize]byte
	pkt[0] = uint8((p.channel >> 8) & 0xff)
	pkt[1] = uint8(p.channel & 0xff)
	pkt[2] = p.cmd
	pkt[3] = uint8((p.seq >> 8) & 0xff)
	pkt[4] = uint8(p.seq & 0xff)
	copy(pkt[headerSize:], p.data)
	if _, err := u.dev.Write(pkt[:]); err != nil {
		return fmt.Errorf("ledger: %w", err)
	}
	return nil
}
// readPacket reads one HID frame from the device and decodes its header.
//
// It returns an error when the device read fails (wrapped with the "ledger:"
// prefix) or when fewer than headerSize bytes were read. The returned
// packet's data holds only the payload bytes that were actually read, so a
// short read never exposes unread, zero-filled buffer space as payload.
func (u *usbHIDRoundTripper) readPacket() (*packet, error) {
	var pkt [packetSize]byte
	sz, err := u.dev.Read(pkt[:])
	if err != nil {
		return nil, fmt.Errorf("ledger: %w", err)
	}
	pl := pkt[:sz]
	if len(pl) < headerSize {
		return nil, fmt.Errorf("ledger: packet is too short: %d", sz)
	}
	return &packet{
		channel: uint16(pl[0])<<8 | uint16(pl[1]),
		cmd:     pl[2],
		seq:     uint16(pl[3])<<8 | uint16(pl[4]),
		// Slice the bytes that were read, not the whole backing array.
		data: pl[headerSize:],
	}, nil
}
// writeCommand sends cmd to the device on the round tripper's channel.
//
// For any command other than cmdAPDU a single packet with an empty payload is
// written and data is not transmitted. For cmdAPDU, data is prefixed with its
// length as a big-endian 16-bit value and the result is split into packets of
// at most chunkSize payload bytes, numbered from sequence 0.
//
// An APDU longer than 0xffff bytes cannot be described by the 16-bit length
// prefix and is rejected with an error before anything is written. The first
// packet write error aborts the transfer and is returned.
func (u *usbHIDRoundTripper) writeCommand(cmd uint8, data []byte) error {
	pkt := packet{
		channel: u.channel,
		cmd:     cmd,
	}
	if cmd != cmdAPDU {
		return u.writePacket(&pkt)
	}
	if len(data) > 0xffff {
		return fmt.Errorf("ledger: APDU is too long: %d", len(data))
	}

	buf := make([]byte, 2, len(data)+2)
	buf[0] = uint8((len(data) >> 8) & 0xff)
	buf[1] = uint8(len(data) & 0xff)
	buf = append(buf, data...)

	// buf always holds at least the two length bytes, so at least one packet
	// is sent even for an empty APDU.
	for seq, off := 0, 0; off < len(buf); seq++ {
		end := off + chunkSize
		if end > len(buf) {
			end = len(buf)
		}
		pkt.seq = uint16(seq)
		pkt.data = buf[off:end]
		if err := u.writePacket(&pkt); err != nil {
			return err
		}
		off = end
	}
	return nil
}
// readCommand reads packets from the device until one complete command has
// been reassembled, and returns its channel, command id and payload.
//
// The first packet fixes the channel and command. For cmdAPDU its payload
// starts with a big-endian 16-bit total length, and packets are read until
// that many bytes have been collected; bytes beyond the announced length are
// dropped. For any other command the expected length is zero, so the command
// completes after the first packet with an empty payload.
//
// An error is returned when a packet read fails, when the first APDU packet
// is too short to hold the length, or when a packet's sequence number,
// command or channel differs from what is expected.
func (u *usbHIDRoundTripper) readCommand() (channel uint16, cmd uint8, data []byte, err error) {
	data = []byte{}
	total := 0
	for next := uint16(0); ; next++ {
		var in *packet
		if in, err = u.readPacket(); err != nil {
			return channel, cmd, data, err
		}
		payload := in.data

		if next == 0 {
			cmd, channel = in.cmd, in.channel
			if cmd == cmdAPDU {
				if len(payload) < 2 {
					err = fmt.Errorf("ledger: packet is too short: %d", len(payload))
					return channel, cmd, data, err
				}
				total = int(payload[0])<<8 | int(payload[1])
				payload = payload[2:]
			}
		}

		// Subsequent packets must arrive in order and carry the same channel
		// and command ids as the first one.
		switch {
		case in.seq != next:
			err = fmt.Errorf("ledger: invalid packet index: %d", in.seq)
		case in.cmd != cmd:
			err = fmt.Errorf("ledger: unexpected command: %d", in.cmd)
		case in.channel != channel:
			err = fmt.Errorf("ledger: unexpected channel: %d", in.channel)
		}
		if err != nil {
			return channel, cmd, data, err
		}

		// Keep only as many bytes as are still missing from the command.
		if missing := total - len(data); len(payload) > missing {
			payload = payload[:missing]
		}
		data = append(data, payload...)
		if len(data) == total {
			return channel, cmd, data, nil
		}
	}
}
// Exchange sends req to the device as an APDU command and returns the parsed
// reply.
//
// It fails when writing the request or reading the reply fails, when the
// reply arrives on a different channel, when the reply is not an APDU
// command, or when parseAPDUResponse cannot parse the reply payload.
func (u *usbHIDRoundTripper) Exchange(req *APDUCommand) (*APDUResponse, error) {
	if err := u.writeCommand(cmdAPDU, req.Bytes()); err != nil {
		return nil, err
	}

	replyChannel, replyCmd, payload, err := u.readCommand()
	switch {
	case err != nil:
		return nil, err
	case replyChannel != u.channel:
		return nil, fmt.Errorf("ledger: invalid channel in reply: %d", replyChannel)
	case replyCmd != cmdAPDU:
		return nil, fmt.Errorf("ledger: invalid command: %d", replyCmd)
	}

	resp := parseAPDUResponse(payload)
	if resp == nil {
		return nil, errors.New("ledger: error parsing APDU response")
	}
	return resp, nil
}
// Ping sends a ping command and inspects the reply.
//
// A ping reply on the round tripper's channel yields nil; a ping reply on
// another channel is an error. If the device answers with an APDU instead,
// the reply is parsed and APDUError(SW) built from its status word is
// returned. Any other reply command is reported as invalid.
func (u *usbHIDRoundTripper) Ping() error {
	if err := u.writeCommand(cmdPing, nil); err != nil {
		return err
	}
	replyChannel, replyCmd, payload, err := u.readCommand()
	if err != nil {
		return err
	}

	switch replyCmd {
	case cmdPing:
		if replyChannel != u.channel {
			return fmt.Errorf("ledger: invalid channel in reply: %d", replyChannel)
		}
		return nil
	case cmdAPDU:
		resp := parseAPDUResponse(payload)
		if resp == nil {
			return errors.New("ledger: error parsing APDU response")
		}
		return APDUError(resp.SW)
	default:
		return fmt.Errorf("ledger: invalid command: %d", replyCmd)
	}
}
// Close closes the underlying HID device and returns the error it reports, if any.
func (u *usbHIDRoundTripper) Close() error {
return u.dev.Close()
}
// Open returns a new Exchanger bound to the HID device at path.
//
// When path is empty, the first device returned by Enumerate is used; an
// error is returned if no device is attached. After the device is opened a
// random 16-bit channel id is drawn from crypto/rand for the new connection.
// If drawing the channel id fails, the freshly opened device is closed before
// returning so that the handle is not leaked. Errors from opening the device
// or from the random source are wrapped with the "ledger:" prefix.
func (u *USBHIDTransport) Open(path string) (Exchanger, error) {
	if path == "" {
		devs, err := u.Enumerate()
		if err != nil {
			return nil, err
		}
		if len(devs) == 0 {
			return nil, errors.New("ledger: no Ledger devices found")
		}
		path = devs[0].Path
	}

	dev, err := hid.DeviceInfo{Path: path}.Open()
	if err != nil {
		return nil, fmt.Errorf("ledger: %w", err)
	}

	n, err := rand.Int(rand.Reader, big.NewInt(0x10000))
	if err != nil {
		// The device is already open; release it. The close error is
		// deliberately dropped in favour of the original failure.
		_ = dev.Close()
		return nil, fmt.Errorf("ledger: %w", err)
	}

	return &usbHIDRoundTripper{
		dev:     dev,
		channel: uint16(n.Uint64()),
	}, nil
}
// Compile-time check that *USBHIDTransport implements Transport.
var _ Transport = (*USBHIDTransport)(nil)