dennis-tra/pcp

View on GitHub
pkg/node/transfer.go

Summary

Maintainability
B
4 hrs
Test Coverage
package node

import (
    "archive/tar"
    "context"
    "fmt"
    "io"
    "os"
    "path/filepath"
    "sync"

    "github.com/pkg/errors"

    "github.com/libp2p/go-libp2p-core/network"
    "github.com/libp2p/go-libp2p-core/peer"
    progress "github.com/schollz/progressbar/v3"

    "github.com/dennis-tra/pcp/internal/log"
    "github.com/dennis-tra/pcp/pkg/crypt"
)

// pattern: /protocol-name/request-or-response-message/version
const (
    ProtocolTransfer = "/pcp/transfer/0.2.0"
)

// TransferProtocol encapsulates data necessary to fulfill its protocol.
type TransferProtocol struct {
    node *Node
    lk   sync.RWMutex
    th   TransferHandler
}

type TransferHandler interface {
    HandleFile(*tar.Header, io.Reader)
    Done()
}

func (t *TransferProtocol) RegisterTransferHandler(th TransferHandler) {
    log.Debugln("Registering transfer handler")
    t.lk.Lock()
    defer t.lk.Unlock()
    t.th = th
    t.node.SetStreamHandler(ProtocolTransfer, t.onTransfer)
}

func (t *TransferProtocol) UnregisterTransferHandler() {
    log.Debugln("Unregistering transfer handler")
    t.lk.Lock()
    defer t.lk.Unlock()
    t.node.RemoveStreamHandler(ProtocolTransfer)
    t.th = nil
}

// New TransferProtocol initializes a new TransferProtocol object with all
// fields set to their default values.
func NewTransferProtocol(node *Node) *TransferProtocol {
    return &TransferProtocol{node: node, lk: sync.RWMutex{}}
}

// onTransfer is called when the peer initiates a file transfer.
func (t *TransferProtocol) onTransfer(s network.Stream) {
    defer t.th.Done()
    defer t.node.ResetOnShutdown(s)()

    // Get PAKE session key for stream decryption
    sKey, found := t.node.GetSessionKey(s.Conn().RemotePeer())
    if !found {
        log.Warningln("Received transfer from unauthenticated peer:", s.Conn().RemotePeer())
        s.Reset() // Tell peer to go away
        return
    }

    // Read initialization vector from stream. This is sent first from our peer.
    iv, err := t.node.ReadBytes(s)
    if err != nil {
        log.Warningln("Could not read stream initialization vector", err)
        s.Reset() // Stream is probably broken anyways
        return
    }

    t.lk.RLock()
    defer func() {
        if err = s.Close(); err != nil {
            log.Warningln(err)
        }
        t.lk.RUnlock()
    }()

    // Decrypt the stream
    sd, err := crypt.NewStreamDecrypter(sKey, iv, s)
    if err != nil {
        log.Warningln("Could not instantiate stream decrypter", err)
        return
    }

    // Drain tar archive
    tr := tar.NewReader(sd)
    for {
        hdr, err := tr.Next()
        if err == io.EOF {
            break // End of archive
        } else if err != nil {
            log.Warningln("Error reading next tar element", err)
            return
        }
        t.th.HandleFile(hdr, tr)
    }

    // Read file hash from the stream and check if it matches
    hash, err := t.node.ReadBytes(s)
    if err != nil {
        log.Warningln("Could not read hash", err)
        return
    }

    // Check if hashes match
    if err = sd.Authenticate(hash); err != nil {
        log.Warningln("Could not authenticate received data", err)
        return
    }
}

// Transfer can be called to transfer the given payload to the given peer. The PushRequest is used for displaying
// the progress to the user. This function returns when the bytes where transmitted and we have received an
// acknowledgment.
func (t *TransferProtocol) Transfer(ctx context.Context, peerID peer.ID, basePath string) error {
    // Open a new stream to our peer.
    s, err := t.node.NewStream(ctx, peerID, ProtocolTransfer)
    if err != nil {
        return err
    }

    defer s.Close()
    defer t.node.ResetOnShutdown(s)()

    base, err := os.Stat(basePath)
    if err != nil {
        return err
    }

    // Get PAKE session key for stream encryption
    sKey, found := t.node.GetSessionKey(peerID)
    if !found {
        return fmt.Errorf("session key not found to encrypt data transfer")
    }

    // Initialize new stream encrypter
    se, err := crypt.NewStreamEncrypter(sKey, s)
    if err != nil {
        return err
    }

    // Send the encryption initialization vector to our peer.
    // Does not need to be encrypted, just unique.
    // https://en.wikipedia.org/wiki/Block_cipher_mode_of_operation#Initialization_vector_.28IV.29
    _, err = t.node.WriteBytes(s, se.InitializationVector())
    if err != nil {
        return err
    }

    tw := tar.NewWriter(se)
    err = filepath.Walk(basePath, func(path string, info os.FileInfo, err error) error {
        log.Debugln("Preparing file for transmission:", path)
        if err != nil {
            log.Debugln("Error walking file:", err)
            return err
        }

        hdr, err := tar.FileInfoHeader(info, "")
        if err != nil {
            return errors.Wrapf(err, "error writing tar file info header %s: %s", path, err)
        }

        // To preserve directory structure in the tar ball.
        hdr.Name, err = relPath(basePath, base.IsDir(), path)
        if err != nil {
            return errors.Wrapf(err, "error building relative path: %s (%v) %s", basePath, base.IsDir(), path)
        }

        if err = tw.WriteHeader(hdr); err != nil {
            return errors.Wrap(err, "error writing tar header")
        }

        // Continue as all information was written above with WriteHeader.
        if info.IsDir() {
            return nil
        }

        f, err := os.Open(path)
        if err != nil {
            return errors.Wrapf(err, "error opening file for taring at: %s", path)
        }
        defer f.Close()

        bar := progress.DefaultBytes(info.Size(), info.Name())
        if _, err = io.Copy(io.MultiWriter(tw, bar), f); err != nil {
            return err
        }

        return nil
    })
    if err != nil {
        return err
    }

    if err = tw.Close(); err != nil {
        log.Debugln("Error closing tar ball", err)
    }

    // Send the hash of all sent data, so our recipient can check the data.
    _, err = t.node.WriteBytes(s, se.Hash())
    if err != nil {
        return errors.Wrap(err, "error writing final hash to stream")
    }

    return t.node.WaitForEOF(s)
}

// relPath builds the path structure for the tar archive - this will be the structure as it is received.
func relPath(basePath string, baseIsDir bool, targetPath string) (string, error) {
    if baseIsDir {
        rel, err := filepath.Rel(basePath, targetPath)
        if err != nil {
            return "", err
        }
        return filepath.Clean(filepath.Join(filepath.Base(basePath), rel)), nil
    } else {
        return filepath.Base(basePath), nil
    }
}