johnsonjh/gfsmux

session.go
package gfsmux

import (
    "container/heap"
    "encoding/binary"
    "errors"
    "io"
    "net"
    "sync"
    "sync/atomic"
    "time"
)

const (
    defaultAcceptBacklog = 2048
)

var (
    // ErrInvalidProtocol indicates an invalid protocol version or a bad negotiation.
    ErrInvalidProtocol = errors.New("invalid protocol")
    // ErrConsumed is a protocol error: the peer reported consuming more data than was sent, indicating desync.
    ErrConsumed = errors.New("peer consumed more than sent")
    // ErrGoAway is returned when the stream ID space is exhausted; a new Connection should be started.
    ErrGoAway = errors.New("stream id overflows, should start a new Connection")
    // ErrTimeout is returned when an operation exceeds its deadline.
    ErrTimeout = &timeoutError{}
    // ErrWouldBlock is returned when an operation would block on I/O.
    ErrWouldBlock = errors.New("operation would block on IO")
)

var _ net.Error = &timeoutError{}

type timeoutError struct{}

func (e *timeoutError) Error() string   { return "timeout" }
func (e *timeoutError) Timeout() bool   { return true }
func (e *timeoutError) Temporary() bool { return true }
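
// Callers can detect the timeout generically through the net.Error
// interface instead of comparing against ErrTimeout directly. A minimal
// caller-side sketch (hypothetical code, not part of this file):
//
//    if ne, ok := err.(net.Error); ok && ne.Timeout() {
//        // back off and retry; ErrTimeout reports Timeout() == true
//    }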

// WriteRequest is a prioritized frame write queued for the send loop.
type WriteRequest struct {
    Prio   uint64
    frame  Frame
    result chan writeResult
}

type writeResult struct {
    n   int
    err error
}

type buffersWriter interface {
    WriteBuffers(v [][]byte) (n int, err error)
}

// Session defines a multiplexed Connection for streams
type Session struct {
    Conn io.ReadWriteCloser

    Config           *Config
    nextStreamID     uint32 // next stream identifier
    nextStreamIDLock sync.Mutex

    bucket       int32         // token bucket
    bucketNotify chan struct{} // used for waiting for tokens

    streams    map[uint32]*Stream // all streams in this session
    streamLock sync.Mutex         // locks streams

    die     chan struct{} // flag session has died
    dieOnce sync.Once

    // socket error handling
    socketReadError      atomic.Value
    socketWriteError     atomic.Value
    chSocketReadError    chan struct{}
    chSocketWriteError   chan struct{}
    socketReadErrorOnce  sync.Once
    socketWriteErrorOnce sync.Once

    // smux protocol errors
    protoError     atomic.Value
    chProtoError   chan struct{}
    protoErrorOnce sync.Once

    chAccepts chan *Stream

    dataReady int32 // flag data has arrived

    goAway int32 // flag id exhausted

    deadline atomic.Value

    shaper chan WriteRequest // a shaper for writing
    writes chan WriteRequest
}

func newSession(Config *Config, Conn io.ReadWriteCloser, client bool) *Session {
    s := new(Session)
    s.die = make(chan struct{})
    s.Conn = Conn
    s.Config = Config
    s.streams = make(map[uint32]*Stream)
    s.chAccepts = make(chan *Stream, defaultAcceptBacklog)
    s.bucket = int32(Config.MaxReceiveBuffer)
    s.bucketNotify = make(chan struct{}, 1)
    s.shaper = make(chan WriteRequest)
    s.writes = make(chan WriteRequest)
    s.chSocketReadError = make(chan struct{})
    s.chSocketWriteError = make(chan struct{})
    s.chProtoError = make(chan struct{})

    if client {
        s.nextStreamID = 1
    } else {
        s.nextStreamID = 0
    }

    go s.shaperLoop()
    go s.recvLoop()
    go s.sendLoop()
    if !Config.KeepAliveDisabled {
        go s.keepalive()
    }
    return s
}
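
// A minimal lifecycle sketch. It assumes this package exposes a
// DefaultConfig() helper and Client/Server-style constructors wrapping
// newSession (as in upstream smux); the address and error handling here
// are illustrative only:
//
//    conn, err := net.Dial("tcp", "127.0.0.1:9000")
//    if err != nil { /* handle */ }
//    s := newSession(DefaultConfig(), conn, true) // client side: odd stream IDs
//    defer s.Close()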

// OpenStream is used to create a new stream
func (s *Session) OpenStream() (*Stream, error) {
    if s.IsClosed() {
        return nil, io.ErrClosedPipe
    }

    // generate stream id
    s.nextStreamIDLock.Lock()
    if s.goAway > 0 {
        s.nextStreamIDLock.Unlock()
        return nil, ErrGoAway
    }

    s.nextStreamID += 2
    Sid := s.nextStreamID
    if Sid == Sid%2 { // stream-id overflows
        s.goAway = 1
        s.nextStreamIDLock.Unlock()
        return nil, ErrGoAway
    }
    s.nextStreamIDLock.Unlock()

    stream := newStream(Sid, s.Config.MaxFrameSize, s)

    if _, err := s.WriteFrame(NewFrame(byte(s.Config.Version), CmdSyn, Sid)); err != nil {
        return nil, err
    }

    s.streamLock.Lock()
    defer s.streamLock.Unlock()
    select {
    case <-s.chSocketReadError:
        return nil, s.socketReadError.Load().(error)
    case <-s.chSocketWriteError:
        return nil, s.socketWriteError.Load().(error)
    case <-s.die:
        return nil, io.ErrClosedPipe
    default:
        s.streams[Sid] = stream
        return stream, nil
    }
}
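
// Opening a stream and writing to it, as a sketch (assumes a live *Session s;
// error handling elided for brevity):
//
//    stream, err := s.OpenStream()
//    if err != nil { /* session closed, ID space exhausted, or write failed */ }
//    _, _ = stream.Write([]byte("hello"))
//    _ = stream.Close()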

// Open returns a generic ReadWriteCloser
func (s *Session) Open() (io.ReadWriteCloser, error) {
    return s.OpenStream()
}

// AcceptStream is used to block until the next available stream
// is ready to be accepted.
func (s *Session) AcceptStream() (*Stream, error) {
    var deadline <-chan time.Time
    if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() {
        timer := time.NewTimer(time.Until(d))
        defer timer.Stop()
        deadline = timer.C
    }

    select {
    case stream := <-s.chAccepts:
        return stream, nil
    case <-deadline:
        return nil, ErrTimeout
    case <-s.chSocketReadError:
        return nil, s.socketReadError.Load().(error)
    case <-s.chProtoError:
        return nil, s.protoError.Load().(error)
    case <-s.die:
        return nil, io.ErrClosedPipe
    }
}
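
// A typical server-side accept loop, as a sketch; handleStream is a
// hypothetical caller-defined handler:
//
//    for {
//        stream, err := s.AcceptStream()
//        if err != nil {
//            break // session died, deadline expired, or protocol error
//        }
//        go handleStream(stream)
//    }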

// Accept returns a generic ReadWriteCloser instead of a *Stream
func (s *Session) Accept() (io.ReadWriteCloser, error) {
    return s.AcceptStream()
}

// Close is used to close the session and all streams.
func (s *Session) Close() error {
    var once bool
    s.dieOnce.Do(func() {
        close(s.die)
        once = true
    })

    if once {
        s.streamLock.Lock()
        for k := range s.streams {
            s.streams[k].sessionClose()
        }
        s.streamLock.Unlock()
        return s.Conn.Close()
    }
    return io.ErrClosedPipe
}

// notifyBucket notifies recvLoop that tokens are available in the bucket
func (s *Session) notifyBucket() {
    select {
    case s.bucketNotify <- struct{}{}:
    default:
    }
}

func (s *Session) notifyReadError(err error) {
    s.socketReadErrorOnce.Do(func() {
        s.socketReadError.Store(err)
        close(s.chSocketReadError)
    })
}

func (s *Session) notifyWriteError(err error) {
    s.socketWriteErrorOnce.Do(func() {
        s.socketWriteError.Store(err)
        close(s.chSocketWriteError)
    })
}

func (s *Session) notifyProtoError(err error) {
    s.protoErrorOnce.Do(func() {
        s.protoError.Store(err)
        close(s.chProtoError)
    })
}

// IsClosed does a safe check to see if we have shut down
func (s *Session) IsClosed() bool {
    select {
    case <-s.die:
        return true
    default:
        return false
    }
}

// NumStreams returns the number of currently open streams
func (s *Session) NumStreams() int {
    if s.IsClosed() {
        return 0
    }
    s.streamLock.Lock()
    defer s.streamLock.Unlock()
    return len(s.streams)
}

// SetDeadline sets a deadline used by Accept* calls.
// A zero time value disables the deadline.
func (s *Session) SetDeadline(t time.Time) error {
    s.deadline.Store(t)
    return nil
}
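
// For example, to bound a single accept to five seconds (values illustrative):
//
//    _ = s.SetDeadline(time.Now().Add(5 * time.Second))
//    stream, err := s.AcceptStream() // yields ErrTimeout once the deadline passes
//    _ = s.SetDeadline(time.Time{})  // zero value disables the deadline again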

// LocalAddr satisfies the net.Conn interface
func (s *Session) LocalAddr() net.Addr {
    if ts, ok := s.Conn.(interface {
        LocalAddr() net.Addr
    }); ok {
        return ts.LocalAddr()
    }
    return nil
}

// RemoteAddr satisfies the net.Conn interface
func (s *Session) RemoteAddr() net.Addr {
    if ts, ok := s.Conn.(interface {
        RemoteAddr() net.Addr
    }); ok {
        return ts.RemoteAddr()
    }
    return nil
}

// streamClosed notifies the session that a stream has closed
func (s *Session) streamClosed(Sid uint32) {
    s.streamLock.Lock()
    // return remaining tokens to the bucket
    if n := s.streams[Sid].recycleTokens(); n > 0 {
        if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
            s.notifyBucket()
        }
    }
    delete(s.streams, Sid)
    s.streamLock.Unlock()
}

// returnTokens is called by a stream to return tokens after a read
func (s *Session) returnTokens(n int) {
    if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
        s.notifyBucket()
    }
}
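
// Worked example of the session-wide token bucket: with a 4 MiB
// MaxReceiveBuffer the bucket starts at 4194304 tokens. recvLoop subtracts
// the byte length of every buffered CmdPsh payload, and stream reads return
// the same counts here; whenever the balance drops to zero or below, recvLoop
// stops reading from the Connection until notifyBucket fires, bounding the
// total unread data across all streams.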

// recvLoop keeps reading from the underlying Connection while tokens are available
func (s *Session) recvLoop() {
    var hdr rawHeader
    var updHdr updHeader

    for {
        for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() {
            select {
            case <-s.bucketNotify:
            case <-s.die:
                return
            }
        }

        // read the header first
        if _, err := io.ReadFull(s.Conn, hdr[:]); err == nil {
            atomic.StoreInt32(&s.dataReady, 1)
            if hdr.Version() != byte(s.Config.Version) {
                s.notifyProtoError(ErrInvalidProtocol)
                return
            }
            Sid := hdr.StreamID()
            switch hdr.Cmd() {
            case CmdNop:
            case CmdSyn:
                s.streamLock.Lock()
                if _, ok := s.streams[Sid]; !ok {
                    stream := newStream(Sid, s.Config.MaxFrameSize, s)
                    s.streams[Sid] = stream
                    select {
                    case s.chAccepts <- stream:
                    case <-s.die:
                    }
                }
                s.streamLock.Unlock()
            case CmdFin:
                s.streamLock.Lock()
                if stream, ok := s.streams[Sid]; ok {
                    stream.fin()
                    stream.notifyReadEvent()
                }
                s.streamLock.Unlock()
            case CmdPsh:
                if hdr.Length() > 0 {
                    newbuf := defaultAllocator.Get(int(hdr.Length()))
                    if written, err := io.ReadFull(s.Conn, newbuf); err == nil {
                        s.streamLock.Lock()
                        if stream, ok := s.streams[Sid]; ok {
                            stream.pushBytes(newbuf)
                            atomic.AddInt32(&s.bucket, -int32(written))
                            stream.notifyReadEvent()
                        }
                        s.streamLock.Unlock()
                    } else {
                        s.notifyReadError(err)
                        return
                    }
                }
            case CmdUpd:
                if _, err := io.ReadFull(s.Conn, updHdr[:]); err == nil {
                    s.streamLock.Lock()
                    if stream, ok := s.streams[Sid]; ok {
                        stream.update(updHdr.Consumed(), updHdr.Window())
                    }
                    s.streamLock.Unlock()
                } else {
                    s.notifyReadError(err)
                    return
                }
            default:
                s.notifyProtoError(ErrInvalidProtocol)
                return
            }
        } else {
            s.notifyReadError(err)
            return
        }
    }
}

func (s *Session) keepalive() {
    tickerPing := time.NewTicker(s.Config.KeepAliveInterval)
    tickerTimeout := time.NewTicker(s.Config.KeepAliveTimeout)
    defer tickerPing.Stop()
    defer tickerTimeout.Stop()
    for {
        select {
        case <-tickerPing.C:
            s.WriteFrameInternal(NewFrame(byte(s.Config.Version), CmdNop, 0), tickerPing.C, 0)
            s.notifyBucket() // force a signal to the recvLoop
        case <-tickerTimeout.C:
            if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) {
                // recvLoop may block while bucket is 0, in this case,
                // session should not be closed.
                if atomic.LoadInt32(&s.bucket) > 0 {
                    s.Close()
                    return
                }
            }
        case <-s.die:
            return
        }
    }
}

// shaperLoop shapes the sending sequence among streams
func (s *Session) shaperLoop() {
    var reqs ShaperHeap
    var next WriteRequest
    var chWrite chan WriteRequest

    for {
        if len(reqs) > 0 {
            chWrite = s.writes
            next = heap.Pop(&reqs).(WriteRequest)
        } else {
            chWrite = nil
        }

        select {
        case <-s.die:
            return
        case r := <-s.shaper:
            if chWrite != nil { // next is valid, reshape
                heap.Push(&reqs, next)
            }
            heap.Push(&reqs, r)
        case chWrite <- next:
        }
    }
}
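
// Sketch of the resulting policy, assuming ShaperHeap orders requests by
// ascending Prio (as in upstream smux): with requests of Prio 8, 0, and 3
// queued simultaneously, sendLoop receives them as 0, 3, 8, so low-Prio
// control frames are never starved by a burst of stream data.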

func (s *Session) sendLoop() {
    var buf []byte
    var n int
    var err error
    var vec [][]byte // vector for writeBuffers

    bw, ok := s.Conn.(buffersWriter)
    if ok {
        buf = make([]byte, HeaderSize)
        vec = make([][]byte, 2)
    } else {
        buf = make([]byte, (1<<16)+HeaderSize)
    }

    for {
        select {
        case <-s.die:
            return
        case request := <-s.writes:
            buf[0] = request.frame.Ver
            buf[1] = request.frame.Cmd
            binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.Data)))
            binary.LittleEndian.PutUint32(buf[4:], request.frame.Sid)

            if len(vec) > 0 {
                vec[0] = buf[:HeaderSize]
                vec[1] = request.frame.Data
                n, err = bw.WriteBuffers(vec)
            } else {
                copy(buf[HeaderSize:], request.frame.Data)
                n, err = s.Conn.Write(buf[:HeaderSize+len(request.frame.Data)])
            }

            n -= HeaderSize
            if n < 0 {
                n = 0
            }

            result := writeResult{
                n:   n,
                err: err,
            }

            request.result <- result
            close(request.result)

            // store Conn error
            if err != nil {
                s.notifyWriteError(err)
                return
            }
        }
    }
}
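
// Worked encoding example, following the header writes above (HeaderSize is
// 8: Ver, Cmd, little-endian uint16 length, little-endian uint32 stream ID).
// Assuming Version 2 and a CmdPsh byte value of 2 (the upstream smux values),
// a push of "hello" on stream 3 goes out as:
//
//    02 02 05 00 03 00 00 00 68 65 6c 6c 6f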

// WriteFrame writes the frame to the underlying Connection
// and returns the number of bytes written if successful
func (s *Session) WriteFrame(f Frame) (n int, err error) {
    return s.WriteFrameInternal(f, nil, 0)
}

// WriteFrameInternal supports a deadline channel, as used by keepalive
func (s *Session) WriteFrameInternal(f Frame, deadline <-chan time.Time, Prio uint64) (int, error) {
    req := WriteRequest{
        Prio:   Prio,
        frame:  f,
        result: make(chan writeResult, 1),
    }
    select {
    case s.shaper <- req:
    case <-s.die:
        return 0, io.ErrClosedPipe
    case <-s.chSocketWriteError:
        return 0, s.socketWriteError.Load().(error)
    case <-deadline:
        return 0, ErrTimeout
    }

    select {
    case result := <-req.result:
        return result.n, result.err
    case <-s.die:
        return 0, io.ErrClosedPipe
    case <-s.chSocketWriteError:
        return 0, s.socketWriteError.Load().(error)
    case <-deadline:
        return 0, ErrTimeout
    }
}
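
// Sketch of a deadline-bounded write, mirroring how keepalive passes
// tickerPing.C above (duration illustrative):
//
//    timer := time.NewTimer(time.Second)
//    defer timer.Stop()
//    n, err := s.WriteFrameInternal(NewFrame(byte(s.Config.Version), CmdNop, 0), timer.C, 0)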