dennis-tra/pcp

View on GitHub
pkg/node/push.go

Summary

Maintainability
A
35 mins
Test Coverage
package node

import (
    "context"
    "sync"

    "github.com/libp2p/go-libp2p-core/network"
    "github.com/libp2p/go-libp2p-core/peer"

    "github.com/dennis-tra/pcp/internal/log"
    p2p "github.com/dennis-tra/pcp/pkg/pb"
)

// ProtocolPushRequest is the protocol ID used for push requests. It is the ID
// the stream handler is registered under and the ID outgoing push request
// streams are opened with.
// Pattern: /protocol-name/request-or-response-message/version
const ProtocolPushRequest = "/pcp/push/0.0.1"

// PushProtocol implements the push request exchange on top of a Node. It
// answers incoming push requests via a registered PushRequestHandler and
// sends outgoing push requests to remote peers.
type PushProtocol struct {
    // node is the Node used to open streams, register stream handlers and
    // read/send messages.
    node *Node
    // lk guards prh.
    lk   sync.RWMutex
    // prh is the currently registered handler; it is nil after
    // UnregisterPushRequestHandler has been called.
    prh  PushRequestHandler
}

// PushRequestHandler decides whether an incoming push request is accepted.
type PushRequestHandler interface {
    // HandlePushRequest is invoked for a push request received from an
    // authenticated peer. The returned bool is sent back to the peer as the
    // accept decision; if a non-nil error is returned it is logged and the
    // request is rejected.
    HandlePushRequest(*p2p.PushRequest) (bool, error)
}

// NewPushProtocol returns a PushProtocol bound to the given node. No push
// request handler is registered yet; the mutex starts out in its (unlocked)
// zero value.
func NewPushProtocol(node *Node) *PushProtocol {
    proto := new(PushProtocol)
    proto.node = node
    return proto
}

// RegisterPushRequestHandler stores prh as the handler consulted for incoming
// push requests and installs onPushRequest as the node's stream handler for
// ProtocolPushRequest. Both steps happen while holding the write lock.
func (p *PushProtocol) RegisterPushRequestHandler(prh PushRequestHandler) {
    log.Debugln("Registering push request handler")

    p.lk.Lock()
    defer p.lk.Unlock()

    // Store the handler first so it is in place before streams can arrive.
    p.prh = prh
    streamHandler := p.onPushRequest
    p.node.SetStreamHandler(ProtocolPushRequest, streamHandler)
}

// UnregisterPushRequestHandler removes the node's stream handler for
// ProtocolPushRequest and clears the stored push request handler. Both steps
// happen while holding the write lock.
func (p *PushProtocol) UnregisterPushRequestHandler() {
    log.Debugln("Unregistering push request handler")

    p.lk.Lock()
    defer p.lk.Unlock()

    // Stop accepting new streams before dropping the handler.
    p.node.RemoveStreamHandler(ProtocolPushRequest)
    p.prh = nil
}

// onPushRequest handles an incoming stream on ProtocolPushRequest.
//
// It resets the stream if the remote peer is not authenticated. Otherwise it
// reads the PushRequest, asks the registered PushRequestHandler whether to
// accept it, sends the decision back as a PushResponse and finally waits for
// EOF on the stream. Errors along the way are logged and end the exchange.
//
// The handler is copied out under the read lock and the lock is released
// before the handler or any network I/O runs. This way a slow handler or peer
// cannot block RegisterPushRequestHandler/UnregisterPushRequestHandler, and a
// handler that calls either of them cannot deadlock. If no handler is
// registered (e.g. it was unregistered while this stream was in flight), the
// request is rejected instead of calling through a nil handler.
func (p *PushProtocol) onPushRequest(s network.Stream) {
    defer s.Close()
    // ResetOnShutdown is invoked right away; the function it returns is
    // deferred until this handler exits.
    defer p.node.ResetOnShutdown(s)()

    if !p.node.IsAuthenticated(s.Conn().RemotePeer()) {
        log.Infoln("Received push request from unauthenticated peer")
        s.Reset() // Tell peer to go away
        return
    }

    req := &p2p.PushRequest{}
    if err := p.node.Read(s, req); err != nil {
        log.Infoln(err)
        return
    }
    log.Debugln("Received push request", req.Name, req.Size)

    // Snapshot the handler; do not hold the lock across the handler call or
    // the network round trips below.
    p.lk.RLock()
    prh := p.prh
    p.lk.RUnlock()

    accept := false
    if prh == nil {
        log.Infoln("Received push request but no push request handler is registered")
        // Fall through and tell peer we won't handle the request
    } else {
        var err error
        if accept, err = prh.HandlePushRequest(req); err != nil {
            log.Infoln(err)
            accept = false
            // Fall through and tell peer we won't handle the request
        }
    }

    if err := p.node.Send(s, p2p.NewPushResponse(accept)); err != nil {
        log.Infoln(err)
        return
    }

    if err := p.node.WaitForEOF(s); err != nil {
        log.Infoln(err)
    }
}

// SendPushRequest asks the peer identified by peerID whether it accepts the
// described file or directory. It opens a new stream for ProtocolPushRequest,
// sends a PushRequest built from filename, size and isDir, and reads the
// peer's PushResponse.
//
// It returns the response's Accept flag. If opening the stream, sending the
// request or reading the response fails, it returns false together with that
// error. The stream is closed when the function returns.
func (p *PushProtocol) SendPushRequest(ctx context.Context, peerID peer.ID, filename string, size int64, isDir bool) (bool, error) {
    stream, err := p.node.NewStream(ctx, peerID, ProtocolPushRequest)
    if err != nil {
        return false, err
    }
    defer stream.Close()

    log.Debugln("Sending push request", filename, size)
    request := p2p.NewPushRequest(filename, size, isDir)
    if sendErr := p.node.Send(stream, request); sendErr != nil {
        return false, sendErr
    }

    response := &p2p.PushResponse{}
    if readErr := p.node.Read(stream, response); readErr != nil {
        return false, readErr
    }

    return response.Accept, nil
}