johnsonjh/jleveldb
leveldb/batch.go
// Copyright © 2012, Suryandaru Triandana <syndtr@gmail.com>
// Copyright © 2021, Jeffrey H. Johnson <trnsz@pobox.com>
//
// All rights reserved.
//
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

package leveldb

import (
    "encoding/binary"
    "fmt"
    "io"

    "github.com/johnsonjh/jleveldb/leveldb/errors"
    "github.com/johnsonjh/jleveldb/leveldb/memdb"
    "github.com/johnsonjh/jleveldb/leveldb/storage"
)

// ErrBatchCorrupted records the reason of a batch corruption. This error
// will be wrapped with errors.ErrCorrupted.
type ErrBatchCorrupted struct {
    Reason string
}

func (e *ErrBatchCorrupted) Error() string {
    return fmt.Sprintf("leveldb: batch corrupted: %s", e.Reason)
}

func newErrBatchCorrupted(reason string) error {
    return errors.NewErrCorrupted(storage.FileDesc{}, &ErrBatchCorrupted{reason})
}
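
// Illustrative sketch (not part of this file): how a caller might detect a
// corrupted batch. It assumes the errors.IsCorrupted helper of this fork
// behaves as in upstream goleveldb; variable names are hypothetical.
//
//     if err := batch.Load(data); err != nil {
//         if errors.IsCorrupted(err) {
//             // The wrapped cause may be an *ErrBatchCorrupted carrying a
//             // human-readable Reason such as "invalid records length".
//             log.Printf("discarding corrupted batch: %v", err)
//         }
//     }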

const (
    // batchHeaderLen is the size of the batch header: an 8-byte sequence
    // number followed by a 4-byte record count, both little-endian.
    batchHeaderLen = 8 + 4
    // batchGrowRec is the record count above which buffer growth becomes
    // proportionally less aggressive (see grow).
    batchGrowRec = 3000
)

// BatchReplay wraps basic batch operations.
type BatchReplay interface {
    Put(key, value []byte)
    Delete(key []byte)
}
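
// Illustrative sketch: a minimal BatchReplay implementation that counts the
// operations it sees. The opCounter type is hypothetical and only shows how
// a caller could satisfy the interface before passing it to (*Batch).Replay.
//
//     type opCounter struct {
//         puts, deletes int
//     }
//
//     func (c *opCounter) Put(key, value []byte) { c.puts++ }
//     func (c *opCounter) Delete(key []byte)     { c.deletes++ }
//
//     // usage: err := batch.Replay(&opCounter{})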

// batchIndex locates a single record's key and value inside Batch.data.
type batchIndex struct {
    keyType            keyType
    keyPos, keyLen     int
    valuePos, valueLen int
}

func (index batchIndex) k(data []byte) []byte {
    return data[index.keyPos : index.keyPos+index.keyLen]
}

func (index batchIndex) v(data []byte) []byte {
    if index.valueLen != 0 {
        return data[index.valuePos : index.valuePos+index.valueLen]
    }
    return nil
}

// Batch is a write batch.
type Batch struct {
    data  []byte
    index []batchIndex

    // internalLen is the sum of the key/value pair lengths plus 8 bytes of
    // internal-key overhead per record.
    internalLen int
}

// grow ensures the data buffer has room for at least n more bytes,
// over-allocating proportionally to the current size.
func (b *Batch) grow(n int) {
    o := len(b.data)
    if cap(b.data)-o < n {
        div := 1
        if len(b.index) > batchGrowRec {
            div = len(b.index) / batchGrowRec
        }
        ndata := make([]byte, o, o+n+o/div)
        copy(ndata, b.data)
        b.data = ndata
    }
}

// appendRec encodes a single record (key type, key and, for puts, value) at
// the end of the data buffer and records its offsets in the index.
func (b *Batch) appendRec(kt keyType, key, value []byte) {
    n := 1 + binary.MaxVarintLen32 + len(key)
    if kt == keyTypeVal {
        n += binary.MaxVarintLen32 + len(value)
    }
    b.grow(n)
    index := batchIndex{keyType: kt}
    o := len(b.data)
    data := b.data[:o+n]
    data[o] = byte(kt)
    o++
    o += binary.PutUvarint(data[o:], uint64(len(key)))
    index.keyPos = o
    index.keyLen = len(key)
    o += copy(data[o:], key)
    if kt == keyTypeVal {
        o += binary.PutUvarint(data[o:], uint64(len(value)))
        index.valuePos = o
        index.valueLen = len(value)
        o += copy(data[o:], value)
    }
    b.data = data[:o]
    b.index = append(b.index, index)
    b.internalLen += index.keyLen + index.valueLen + 8
}
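
// For reference, each record appended above is laid out back to back in
// b.data as:
//
//     keyType (1 byte) | keyLen (uvarint) | key bytes
//                      | valueLen (uvarint) | value bytes   (keyTypeVal only)
//
// decodeBatch below parses exactly this per-record layout.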

// Put appends a 'put operation' for the given key/value pair to the batch.
// It is safe to modify the contents of the arguments after Put returns but
// not before.
func (b *Batch) Put(key, value []byte) {
    b.appendRec(keyTypeVal, key, value)
}

// Delete appends a 'delete operation' for the given key to the batch.
// It is safe to modify the contents of the argument after Delete returns but
// not before.
func (b *Batch) Delete(key []byte) {
    b.appendRec(keyTypeDel, key, nil)
}
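
// Illustrative sketch: building a batch and applying it in one write. The db
// variable and its Write method are assumed to follow the usual
// goleveldb-style API of this fork; only Put, Delete and Len are defined in
// this file.
//
//     batch := new(Batch)
//     batch.Put([]byte("alpha"), []byte("1"))
//     batch.Put([]byte("beta"), []byte("2"))
//     batch.Delete([]byte("gamma"))
//     // batch.Len() == 3
//     // err := db.Write(batch, nil)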

// Dump dumps the batch contents. The returned slice can be loaded back into
// a batch using the Load method.
// The returned slice is not a copy, so its contents should not be modified.
func (b *Batch) Dump() []byte {
    return b.data
}

// Load loads the given slice into the batch. Previous contents of the batch
// will be discarded.
// The given slice will not be copied and will be used as the batch buffer,
// so it is not safe to modify the contents of the slice.
func (b *Batch) Load(data []byte) error {
    return b.decode(data, -1)
}
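
// Illustrative sketch: persisting a batch and restoring it with the
// Dump/Load pair. Note that both methods share the underlying buffer rather
// than copying it, so the raw slice must not be modified afterwards.
//
//     raw := batch.Dump() // serialized records, no header
//     restored := new(Batch)
//     if err := restored.Load(raw); err != nil {
//         // raw was not a valid batch encoding
//     }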

// Replay replays batch contents.
func (b *Batch) Replay(r BatchReplay) error {
    for _, index := range b.index {
        switch index.keyType {
        case keyTypeVal:
            r.Put(index.k(b.data), index.v(b.data))
        case keyTypeDel:
            r.Delete(index.k(b.data))
        }
    }
    return nil
}

// Len returns the number of records in the batch.
func (b *Batch) Len() int {
    return len(b.index)
}

// Reset resets the batch.
func (b *Batch) Reset() {
    b.data = b.data[:0]
    b.index = b.index[:0]
    b.internalLen = 0
}

// replayInternal iterates over all records, passing each record's position,
// key type, key and value to fn; it stops at the first error.
func (b *Batch) replayInternal(fn func(i int, kt keyType, k, v []byte) error) error {
    for i, index := range b.index {
        if err := fn(i, index.keyType, index.k(b.data), index.v(b.data)); err != nil {
            return err
        }
    }
    return nil
}

// append concatenates the contents of p onto b, rebasing the copied index
// entries so they point into b's data buffer.
func (b *Batch) append(p *Batch) {
    ob := len(b.data)
    oi := len(b.index)
    b.data = append(b.data, p.data...)
    b.index = append(b.index, p.index...)
    b.internalLen += p.internalLen

    // Rebase the offsets of the copied index entries.
    if ob != 0 {
        for ; oi < len(b.index); oi++ {
            index := &b.index[oi]
            index.keyPos += ob
            if index.valueLen != 0 {
                index.valuePos += ob
            }
        }
    }
}

// decode replaces the batch contents with the records encoded in data,
// optionally verifying the record count against expectedLen (a negative
// expectedLen skips the check).
func (b *Batch) decode(data []byte, expectedLen int) error {
    b.data = data
    b.index = b.index[:0]
    b.internalLen = 0
    err := decodeBatch(data, func(i int, index batchIndex) error {
        b.index = append(b.index, index)
        b.internalLen += index.keyLen + index.valueLen + 8
        return nil
    })
    if err != nil {
        return err
    }
    if expectedLen >= 0 && len(b.index) != expectedLen {
        return newErrBatchCorrupted(fmt.Sprintf("invalid records length: %d vs %d", expectedLen, len(b.index)))
    }
    return nil
}

// putMem writes all records into the memdb, assigning sequence numbers
// starting at seq.
func (b *Batch) putMem(seq uint64, mdb *memdb.DB) error {
    var ik []byte
    for i, index := range b.index {
        ik = makeInternalKey(ik, index.k(b.data), seq+uint64(i), index.keyType)
        if err := mdb.Put(ik, index.v(b.data)); err != nil {
            return err
        }
    }
    return nil
}

// newBatch allocates an empty Batch; the interface{} return type makes it
// usable as a pool New function.
func newBatch() interface{} {
    return &Batch{}
}

// MakeBatch returns an empty batch with a preallocated buffer of n bytes.
func MakeBatch(n int) *Batch {
    return &Batch{data: make([]byte, 0, n)}
}
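
// Illustrative sketch: when the approximate encoded size of the writes is
// known up front, preallocating avoids repeated buffer growth. The size used
// here is arbitrary.
//
//     batch := MakeBatch(64 << 10) // room for roughly 64 KiB of records
//     batch.Put(key, value)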

// decodeBatch walks the encoded records in data (which must not include the
// batch header) and invokes fn with the index of each decoded record.
func decodeBatch(data []byte, fn func(i int, index batchIndex) error) error {
    var index batchIndex
    for i, o := 0, 0; o < len(data); i++ {
        // Key type.
        index.keyType = keyType(data[o])
        if index.keyType > keyTypeVal {
            return newErrBatchCorrupted(fmt.Sprintf("bad record: invalid type %#x", uint(index.keyType)))
        }
        o++

        // Key.
        x, n := binary.Uvarint(data[o:])
        o += n
        if n <= 0 || o+int(x) > len(data) {
            return newErrBatchCorrupted("bad record: invalid key length")
        }
        index.keyPos = o
        index.keyLen = int(x)
        o += index.keyLen

        // Value.
        if index.keyType == keyTypeVal {
            x, n = binary.Uvarint(data[o:])
            o += n
            if n <= 0 || o+int(x) > len(data) {
                return newErrBatchCorrupted("bad record: invalid value length")
            }
            index.valuePos = o
            index.valueLen = int(x)
            o += index.valueLen
        } else {
            index.valuePos = 0
            index.valueLen = 0
        }

        if err := fn(i, index); err != nil {
            return err
        }
    }
    return nil
}

// decodeBatchToMem decodes a headered batch payload directly into the memdb,
// validating the sequence number and record count along the way.
func decodeBatchToMem(data []byte, expectSeq uint64, mdb *memdb.DB) (seq uint64, batchLen int, err error) {
    seq, batchLen, err = decodeBatchHeader(data)
    if err != nil {
        return 0, 0, err
    }
    if seq < expectSeq {
        return 0, 0, newErrBatchCorrupted("invalid sequence number")
    }
    data = data[batchHeaderLen:]
    var ik []byte
    var decodedLen int
    err = decodeBatch(data, func(i int, index batchIndex) error {
        if i >= batchLen {
            return newErrBatchCorrupted("invalid records length")
        }
        ik = makeInternalKey(ik, index.k(data), seq+uint64(i), index.keyType)
        if err := mdb.Put(ik, index.v(data)); err != nil {
            return err
        }
        decodedLen++
        return nil
    })
    if err == nil && decodedLen != batchLen {
        err = newErrBatchCorrupted(fmt.Sprintf("invalid records length: %d vs %d", batchLen, decodedLen))
    }
    return
}

// encodeBatchHeader writes the 12-byte batch header (little-endian sequence
// number and record count) into dst, growing it if needed.
func encodeBatchHeader(dst []byte, seq uint64, batchLen int) []byte {
    dst = ensureBuffer(dst, batchHeaderLen)
    binary.LittleEndian.PutUint64(dst, seq)
    binary.LittleEndian.PutUint32(dst[8:], uint32(batchLen))
    return dst
}
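
// Illustrative sketch of the header round trip; the values are arbitrary.
//
//     hdr := encodeBatchHeader(nil, 42, 3)
//     // hdr[0:8]  holds 42 (little-endian uint64 sequence number)
//     // hdr[8:12] holds 3  (little-endian uint32 record count)
//     seq, n, err := decodeBatchHeader(hdr)
//     // seq == 42, n == 3, err == nil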

// decodeBatchHeader reads the sequence number and record count from the
// 12-byte batch header.
func decodeBatchHeader(data []byte) (seq uint64, batchLen int, err error) {
    if len(data) < batchHeaderLen {
        return 0, 0, newErrBatchCorrupted("too short")
    }

    seq = binary.LittleEndian.Uint64(data)
    batchLen = int(binary.LittleEndian.Uint32(data[8:]))
    if batchLen < 0 {
        return 0, 0, newErrBatchCorrupted("invalid records length")
    }
    return
}

// batchesLen returns the total number of records across all batches.
func batchesLen(batches []*Batch) int {
    batchLen := 0
    for _, batch := range batches {
        batchLen += batch.Len()
    }
    return batchLen
}

// writeBatchesWithHeader writes a single batch header covering all the given
// batches, followed by each batch's raw record data.
func writeBatchesWithHeader(wr io.Writer, batches []*Batch, seq uint64) error {
    if _, err := wr.Write(encodeBatchHeader(nil, seq, batchesLen(batches))); err != nil {
        return err
    }
    for _, batch := range batches {
        if _, err := wr.Write(batch.data); err != nil {
            return err
        }
    }
    return nil
}
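
// Illustrative sketch: flushing several batches as one headered payload,
// much as a journal writer would. bytes.Buffer stands in for the real
// writer and b1, b2 and seq are hypothetical; this helper is unexported, so
// the sketch assumes package-internal use.
//
//     var buf bytes.Buffer
//     if err := writeBatchesWithHeader(&buf, []*Batch{b1, b2}, seq); err != nil {
//         // handle the write error
//     }
//     // buf now holds a 12-byte header followed by b1's and b2's records.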