johnsonjh/gfsmux

View on GitHub
stream.go

Summary

Maintainability
D
2 days
Test Coverage
package gfsmux

import (
    "encoding/binary"
    "io"
    "net"
    "sync"
    "sync/atomic"
    "time"
)

// Stream implements net.Conn for one logical stream carried by a Session.
// Received payloads are queued by the session (pushBytes) and drained by
// Read/WriteTo; outgoing data is split into frames of at most frameSize bytes
// and handed to the session for transmission (Write).
type Stream struct {
    id   uint32   // stream identifier placed in every frame this stream sends
    sess *Session // owning session: config, token accounting and frame I/O

    // Receive queue. buffers[i] is the not-yet-read remainder of heads[i];
    // heads keeps the original allocation so it can be handed back to
    // defaultAllocator once the data has been fully consumed.
    buffers [][]byte
    heads   [][]byte

    bufferLock sync.Mutex // guards buffers, heads, numRead and incr
    frameSize  int        // maximum payload size of one outgoing PSH frame

    // notify a read event (capacity 1 in newStream; sends never block)
    chReadEvent chan struct{}

    // closed exactly once (via dieOnce) when the stream is closed locally
    // or by its session
    die     chan struct{}
    dieOnce sync.Once

    // FIN command: closed exactly once (via finEventOnce) by fin()
    chFinEvent   chan struct{}
    finEventOnce sync.Once

    // deadlines: each holds a time.Time; the zero time disables it
    readDeadline  atomic.Value
    writeDeadline atomic.Value

    // per stream sliding window control
    numRead    uint32 // v2: total bytes consumed by the application (under bufferLock)
    numWritten uint32 // v1: frames written (third WriteFrameInternal argument); v2: bytes written (atomic)
    incr       uint32 // v2: bytes consumed since the last window update was sent (under bufferLock)

    // UPD (Update) command state, written by update() with atomics
    peerConsumed uint32        // num of bytes remote peer has consumed
    peerWindow   uint32        // peer window; starts at initialPeerWindow, updated by peer
    chUpdate     chan struct{} // notify remote data consumed and window update (capacity 1)
}

// newStream builds a Stream with the given id and frame size, owned by sess.
// chReadEvent and chUpdate each buffer a single pending notification; die and
// chFinEvent are unbuffered and are signalled by closing them.
func newStream(id uint32, frameSize int, sess *Session) *Stream {
	return &Stream{
		id:          id,
		sess:        sess,
		frameSize:   frameSize,
		chReadEvent: make(chan struct{}, 1),
		chUpdate:    make(chan struct{}, 1),
		die:         make(chan struct{}),
		chFinEvent:  make(chan struct{}),
		// assumed peer window until the peer sends its first UPD
		peerWindow: initialPeerWindow,
	}
}

// ID reports the identifier this stream was created with.
func (s *Stream) ID() uint32 { return s.id }

// Read implements net.Conn. It retries tryRead until data is returned or a
// non-ErrWouldBlock error occurs, parking in waitRead in between; any error
// from waitRead (EOF, timeout, close, session failure) ends the read.
func (s *Stream) Read(b []byte) (n int, err error) {
	for {
		n, err = s.tryRead(b)
		if err != ErrWouldBlock {
			return n, err
		}
		if werr := s.waitRead(); werr != nil {
			return 0, werr
		}
	}
}

// tryRead is the non-blocking form of Read. It copies from the oldest queued
// buffer into b, returning ErrWouldBlock when nothing is queued and the stream
// is open, or io.EOF once the stream has died. Version 2 sessions go through
// tryReadv2, which also performs flow-control accounting.
func (s *Stream) tryRead(b []byte) (n int, err error) {
	if s.sess.Config.Version == 2 {
		return s.tryReadv2(b)
	}
	if len(b) == 0 {
		return 0, nil
	}

	s.bufferLock.Lock()
	if len(s.buffers) != 0 {
		front := s.buffers[0]
		n = copy(b, front)
		if rest := front[n:]; len(rest) > 0 {
			s.buffers[0] = rest
		} else {
			// Fully drained: drop the reference and recycle the original allocation.
			s.buffers[0] = nil
			s.buffers = s.buffers[1:]
			defaultAllocator.Put(s.heads[0])
			s.heads = s.heads[1:]
		}
	}
	s.bufferLock.Unlock()

	if n > 0 {
		s.sess.returnTokens(n)
		return n, nil
	}

	select {
	case <-s.die:
		return 0, io.EOF
	default:
	}
	return 0, ErrWouldBlock
}

// tryReadv2 is the non-blocking read for protocol version 2. In addition to
// copying queued data into b, it accumulates the consumed byte count and sends
// a window update (UPD) when that count reaches half of MaxStreamBuffer since
// the last update, or when numRead equals this read's size (the first read).
func (s *Stream) tryReadv2(b []byte) (n int, err error) {
	if len(b) == 0 {
		return 0, nil
	}

	var notifyConsumed uint32
	s.bufferLock.Lock()
	if len(s.buffers) != 0 {
		front := s.buffers[0]
		n = copy(b, front)
		if rest := front[n:]; len(rest) > 0 {
			s.buffers[0] = rest
		} else {
			s.buffers[0] = nil
			s.buffers = s.buffers[1:]
			defaultAllocator.Put(s.heads[0]) // whole allocation consumed: recycle it
			s.heads = s.heads[1:]
		}
	}

	// Acknowledging after half a buffer (rather than a full one) lets a reader
	// that keeps consuming avoid stalling the sender while the UPD is in transit.
	consumed := uint32(n)
	s.numRead += consumed
	s.incr += consumed
	if s.incr >= uint32(s.sess.Config.MaxStreamBuffer/2) || s.numRead == consumed {
		notifyConsumed = s.numRead
		s.incr = 0
	}
	s.bufferLock.Unlock()

	if n == 0 {
		select {
		case <-s.die:
			return 0, io.EOF
		default:
			return 0, ErrWouldBlock
		}
	}

	s.sess.returnTokens(n)
	if notifyConsumed == 0 {
		return n, nil
	}
	return n, s.sendWindowUpdate(notifyConsumed)
}

// WriteTo implements io.WriterTo. It moves queued buffers into w, one buffer
// per w.Write call, until w fails or waitRead reports an error (EOF, deadline,
// close, session failure). Version 2 sessions go through writeTov2.
func (s *Stream) WriteTo(w io.Writer) (n int64, err error) {
	if s.sess.Config.Version == 2 {
		return s.writeTov2(w)
	}

	for {
		var buf, head []byte
		s.bufferLock.Lock()
		if len(s.buffers) > 0 {
			buf, head = s.buffers[0], s.heads[0]
			// Clear the popped slots so the queue's backing arrays do not
			// keep the buffer alive.
			s.buffers[0], s.heads[0] = nil, nil
			s.buffers = s.buffers[1:]
			s.heads = s.heads[1:]
		}
		s.bufferLock.Unlock()

		if buf == nil {
			if ew := s.waitRead(); ew != nil {
				return n, ew
			}
			continue
		}

		nw, ew := w.Write(buf)
		s.sess.returnTokens(len(buf))
		// buf may be only the unread remainder of head (after an earlier
		// partial Read), so recycle the original allocation, as tryRead does.
		defaultAllocator.Put(head)
		if nw > 0 {
			n += int64(nw)
		}
		if ew != nil {
			return n, ew
		}
	}
}

// writeTov2 is WriteTo for protocol version 2. Besides moving queued buffers
// into w, it accumulates consumed bytes and sends window updates (UPD) under
// the same policy as tryReadv2: once half of MaxStreamBuffer has been consumed
// since the last update, or when numRead equals this buffer's length.
func (s *Stream) writeTov2(w io.Writer) (n int64, err error) {
	for {
		var notifyConsumed uint32
		var buf, head []byte
		s.bufferLock.Lock()
		if len(s.buffers) > 0 {
			buf, head = s.buffers[0], s.heads[0]
			// Clear the popped slots so the queue's backing arrays do not
			// keep the buffer alive.
			s.buffers[0], s.heads[0] = nil, nil
			s.buffers = s.buffers[1:]
			s.heads = s.heads[1:]
		}
		consumed := uint32(len(buf))
		s.numRead += consumed
		s.incr += consumed
		if s.incr >= uint32(s.sess.Config.MaxStreamBuffer/2) || s.numRead == consumed {
			notifyConsumed = s.numRead
			s.incr = 0
		}
		s.bufferLock.Unlock()

		if buf == nil {
			if ew := s.waitRead(); ew != nil {
				return n, ew
			}
			continue
		}

		nw, ew := w.Write(buf)
		s.sess.returnTokens(len(buf))
		// buf may be only the unread remainder of head (after an earlier
		// partial Read), so recycle the original allocation.
		defaultAllocator.Put(head)
		if nw > 0 {
			n += int64(nw)
		}
		if ew != nil {
			return n, ew
		}

		if notifyConsumed > 0 {
			if eu := s.sendWindowUpdate(notifyConsumed); eu != nil {
				return n, eu
			}
		}
	}
}

// sendWindowUpdate writes a UPD frame whose payload is two little-endian
// uint32 values: the consumed byte count, then Config.MaxStreamBuffer as the
// local window. The write is bounded by the read deadline, if one is set.
func (s *Stream) sendWindowUpdate(consumed uint32) error {
	var deadline <-chan time.Time
	if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
		t := time.NewTimer(time.Until(d))
		defer t.Stop()
		deadline = t.C
	}

	frame := NewFrame(byte(s.sess.Config.Version), CmdUpd, s.id)

	var hdr updHeader
	binary.LittleEndian.PutUint32(hdr[:], consumed)
	binary.LittleEndian.PutUint32(hdr[4:], uint32(s.sess.Config.MaxStreamBuffer))
	frame.Data = hdr[:]

	_, err := s.sess.WriteFrameInternal(frame, deadline, 0)
	return err
}

// waitRead blocks until a reader should retry or give up. It returns nil on a
// read-event notification, or on FIN while data is still queued. Otherwise it
// returns one of:
//   - io.EOF on FIN with an empty queue;
//   - the session's socket-read or protocol error;
//   - ErrTimeout when the read deadline passes;
//   - io.ErrClosedPipe once the stream has died.
func (s *Stream) waitRead() error {
	var deadline <-chan time.Time
	if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
		t := time.NewTimer(time.Until(d))
		defer t.Stop()
		deadline = t.C
	}

	select {
	case <-s.chReadEvent:
		return nil
	case <-s.chFinEvent:
		// BUG(xtaci): Fix for https://github.com/xtaci/smux/issues/82
		// Data still queued after FIN is reported as readable (nil) so the
		// caller drains it before it sees io.EOF.
		s.bufferLock.Lock()
		queued := len(s.buffers)
		s.bufferLock.Unlock()
		if queued > 0 {
			return nil
		}
		return io.EOF
	case <-s.sess.chSocketReadError:
		return s.sess.socketReadError.Load().(error)
	case <-s.sess.chProtoError:
		return s.sess.protoError.Load().(error)
	case <-deadline:
		return ErrTimeout
	case <-s.die:
		return io.ErrClosedPipe
	}
}

// Write implements net.Conn. b is split into frames of at most frameSize
// bytes that are handed to the session one at a time; the whole call is
// bounded by the write deadline, if set. Version 2 sessions go through
// writeV2, which also honours the peer's window.
//
// Behavior when multiple concurrent goroutines write is not deterministic,
// so the frames will interleave in random ways.
func (s *Stream) Write(b []byte) (n int, err error) {
	if s.sess.Config.Version == 2 {
		return s.writeV2(b)
	}

	var deadline <-chan time.Time
	if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
		timer := time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}

	// check if stream has closed
	select {
	case <-s.die:
		return 0, io.ErrClosedPipe
	default:
	}

	// frame split and transmit
	frame := NewFrame(byte(s.sess.Config.Version), CmdPsh, s.id)
	for bts := b; len(bts) > 0; {
		sz := len(bts)
		if sz > s.frameSize {
			sz = s.frameSize
		}
		frame.Data = bts[:sz]
		bts = bts[sz:]
		// In v1, numWritten counts frames and its pre-increment value is
		// passed to WriteFrameInternal (NOTE(review): presumably as an
		// ordering/priority key). Concurrent writers are permitted, so
		// claim the value atomically, as writeV2 does, instead of a racy ++.
		seq := atomic.AddUint32(&s.numWritten, 1) - 1
		nw, ew := s.sess.WriteFrameInternal(frame, deadline, uint64(seq))
		n += nw
		if ew != nil {
			return n, ew
		}
	}

	return n, nil
}

// writeV2 is Write for protocol version 2. It splits b into frames of at most
// frameSize bytes and keeps the bytes in flight (numWritten - peerConsumed)
// within the peer's advertised window (peerWindow). When the window is
// exhausted it blocks until one of these happens:
//   - a window update arrives;
//   - FIN is received;
//   - the stream dies;
//   - the write deadline passes;
//   - the session reports a socket write error.
//
// On every return path, n is the sum of the byte counts WriteFrameInternal
// reported, so partial progress is never hidden behind an error.
func (s *Stream) writeV2(b []byte) (n int, err error) {
	// check for empty input
	if len(b) == 0 {
		return 0, nil
	}

	// check if stream has closed
	select {
	case <-s.die:
		return 0, io.ErrClosedPipe
	default:
	}

	// create write deadline timer
	var deadline <-chan time.Time
	if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
		timer := time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}

	// frame split and transmit process
	frame := NewFrame(byte(s.sess.Config.Version), CmdPsh, s.id)
	for {
		// per stream sliding window control
		// [.... [consumed... numWritten] ... win... ]
		// [.... [consumed...................+rmtwnd]]
		//
		// The difference is taken in uint32 and reinterpreted as int32, which
		// stays correct across counter wrap-around:
		//   uint32(0) - uint32(math.MaxUint32) == 1
		//   int32(uint32(0) - uint32(1))       == -1
		// A negative result means the peer claims to have consumed more than
		// was written: treat it as misbehavior.
		inflight := int32(atomic.LoadUint32(&s.numWritten) - atomic.LoadUint32(&s.peerConsumed))
		if inflight < 0 {
			return n, ErrConsumed
		}

		if win := int32(atomic.LoadUint32(&s.peerWindow)) - inflight; win > 0 {
			// Compare as int: int32(len(b)) would overflow for payloads over
			// 2 GiB and let the whole payload bypass the window.
			var bts []byte
			if int(win) >= len(b) {
				bts, b = b, nil
			} else {
				bts, b = b[:win], b[win:]
			}

			for len(bts) > 0 {
				sz := len(bts)
				if sz > s.frameSize {
					sz = s.frameSize
				}
				frame.Data = bts[:sz]
				bts = bts[sz:]
				nw, ew := s.sess.WriteFrameInternal(frame, deadline, uint64(atomic.LoadUint32(&s.numWritten)))
				atomic.AddUint32(&s.numWritten, uint32(sz))
				n += nw
				if ew != nil {
					return n, ew
				}
			}
		}

		if len(b) == 0 {
			return n, nil
		}

		// Data remains but the window is exhausted: block until something
		// changes. This blocking is how flow control propagates to the caller.
		select {
		case <-s.chFinEvent: // if fin arrived, future window update is impossible
			return n, io.EOF
		case <-s.die:
			return n, io.ErrClosedPipe
		case <-deadline:
			return n, ErrTimeout
		case <-s.sess.chSocketWriteError:
			return n, s.sess.socketWriteError.Load().(error)
		case <-s.chUpdate:
		}
	}
}

// Close implements net.Conn. On the first call it marks the stream as dead,
// sends a FIN frame, notifies the session via streamClosed, and returns the
// error from sending the FIN. If the stream was already closed, by an earlier
// Close or by sessionClose, it returns io.ErrClosedPipe.
func (s *Stream) Close() error {
	firstClose := false
	s.dieOnce.Do(func() {
		close(s.die)
		firstClose = true
	})
	if !firstClose {
		return io.ErrClosedPipe
	}

	_, err := s.sess.WriteFrame(NewFrame(byte(s.sess.Config.Version), CmdFin, s.id))
	s.sess.streamClosed(s.id)
	return err
}

// GetDieCh exposes the stream's die channel as receive-only; it becomes
// readable once the stream is closed, by Close or by sessionClose.
func (s *Stream) GetDieCh() <-chan struct{} { return s.die }

// SetReadDeadline sets the read deadline as defined by
// net.Conn.SetReadDeadline; a zero time disables it. It also posts a read
// event, so a Read blocked in waitRead loops round and picks up the new
// deadline. It always returns nil.
func (s *Stream) SetReadDeadline(t time.Time) error {
	s.readDeadline.Store(t)
	s.notifyReadEvent()
	return nil
}

// SetWriteDeadline sets the write deadline as defined by
// net.Conn.SetWriteDeadline; a zero time disables it. Write reads the deadline
// once on entry, so a Write that is already blocked keeps the deadline it
// started with. It always returns nil.
func (s *Stream) SetWriteDeadline(t time.Time) error {
	s.writeDeadline.Store(t)
	return nil
}

// SetDeadline sets both the read and the write deadline, as defined by
// net.Conn.SetDeadline; a zero time disables both. It returns the first
// error reported by SetReadDeadline or SetWriteDeadline.
func (s *Stream) SetDeadline(t time.Time) error {
	if err := s.SetReadDeadline(t); err != nil {
		return err
	}
	return s.SetWriteDeadline(t)
}

// sessionClose marks the stream as dead without sending a FIN. It is
// presumably invoked by the owning session when the session shuts down. It
// shares dieOnce with Close, so it has no effect if the stream is already dead.
func (s *Stream) sessionClose() {
	s.dieOnce.Do(func() { close(s.die) })
}

// LocalAddr satisfies net.Conn. It returns the local address of the session's
// underlying connection when that connection provides one, and nil otherwise.
func (s *Stream) LocalAddr() net.Addr {
	type localAddresser interface{ LocalAddr() net.Addr }
	conn, ok := s.sess.Conn.(localAddresser)
	if !ok {
		return nil
	}
	return conn.LocalAddr()
}

// RemoteAddr satisfies net.Conn. It returns the remote address of the
// session's underlying connection when that connection provides one, and nil
// otherwise.
func (s *Stream) RemoteAddr() net.Addr {
	type remoteAddresser interface{ RemoteAddr() net.Addr }
	conn, ok := s.sess.Conn.(remoteAddresser)
	if !ok {
		return nil
	}
	return conn.RemoteAddr()
}

// pushBytes appends buf to the receive queue, recording it both as readable
// data and as the allocation head to recycle later. Its results are always
// (0, nil).
func (s *Stream) pushBytes(buf []byte) (written int, err error) {
	s.bufferLock.Lock()
	defer s.bufferLock.Unlock()
	s.buffers = append(s.buffers, buf)
	s.heads = append(s.heads, buf)
	return 0, nil
}

// recycleTokens empties the receive queue, handing every original allocation
// back to defaultAllocator. It returns the number of unread bytes discarded,
// which the caller can convert back into tokens.
func (s *Stream) recycleTokens() (n int) {
	s.bufferLock.Lock()
	defer s.bufferLock.Unlock()
	for i, pending := range s.buffers {
		n += len(pending)
		defaultAllocator.Put(s.heads[i])
	}
	s.buffers, s.heads = nil, nil
	return n
}

// notifyReadEvent posts a read notification without ever blocking. Because
// chReadEvent buffers a single signal, a notification that is already pending
// absorbs this one.
func (s *Stream) notifyReadEvent() {
	select {
	case s.chReadEvent <- struct{}{}:
	default: // a notification is already pending
	}
}

// update records a window update (UPD) from the peer. It stores the peer's
// consumed byte count and advertised window, then wakes a writer blocked in
// writeV2 via chUpdate. The wake-up send is dropped if one is already pending.
func (s *Stream) update(consumed, window uint32) {
	atomic.StoreUint32(&s.peerConsumed, consumed)
	atomic.StoreUint32(&s.peerWindow, window)
	select {
	case s.chUpdate <- struct{}{}:
	default: // a wake-up is already pending
	}
}

// fin records that a FIN has been received for this stream. It closes
// chFinEvent exactly once, which wakes readers blocked in waitRead and
// writers blocked in writeV2.
func (s *Stream) fin() {
	s.finEventOnce.Do(func() { close(s.chFinEvent) })
}