repository.go

Summary

Maintainability
D
1 day
Test Coverage
A
100%
package rel

import (
    "context"
    "errors"
    "reflect"
    "runtime"
    "strings"
)

// Repository for interacting with database.
// Every operation resolves its adapter through the context, so calls made
// inside Transaction automatically use the transaction's connection.
type Repository interface {
    // Adapter used in this repository.
    Adapter(ctx context.Context) Adapter

    // Instrumentation defines callback to be used as instrumenter.
    Instrumentation(instrumenter Instrumenter)

    // Ping database.
    Ping(ctx context.Context) error

    // Iterate through a collection of entities from database in batches.
    // This function returns iterator that can be used to loop all entities.
    // Limit, Offset and Sort queries are automatically ignored.
    Iterate(ctx context.Context, query Query, option ...IteratorOption) Iterator

    // Aggregate over the given field.
    // Supported aggregate: count, sum, avg, max, min.
    // Any select, group, offset, limit and sort query will be ignored automatically.
    // If complex aggregation is needed, consider using All instead.
    Aggregate(ctx context.Context, query Query, aggregate string, field string) (int, error)

    // MustAggregate over the given field.
    // Supported aggregate: count, sum, avg, max, min.
    // Any select, group, offset, limit and sort query will be ignored automatically.
    // If complex aggregation is needed, consider using All instead.
    // It'll panic if any error occurred.
    MustAggregate(ctx context.Context, query Query, aggregate string, field string) int

    // Count entities that match the query.
    Count(ctx context.Context, collection string, queriers ...Querier) (int, error)

    // MustCount entities that match the query.
    // It'll panic if any error occurred.
    MustCount(ctx context.Context, collection string, queriers ...Querier) int

    // Find an entity that matches the query.
    // If no result found, it'll return not found error.
    Find(ctx context.Context, entity any, queriers ...Querier) error

    // MustFind an entity that matches the query.
    // If no result found, it'll panic.
    MustFind(ctx context.Context, entity any, queriers ...Querier)

    // FindAll entities that match the query.
    FindAll(ctx context.Context, entities any, queriers ...Querier) error

    // MustFindAll entities that match the query.
    // It'll panic if any error occurred.
    MustFindAll(ctx context.Context, entities any, queriers ...Querier)

    // FindAndCountAll entities that match the query.
    // This is a convenient method that combines FindAll and Count. It's useful when dealing with queries related to pagination.
    // Limit and Offset property will be ignored when performing count query.
    FindAndCountAll(ctx context.Context, entities any, queriers ...Querier) (int, error)

    // MustFindAndCountAll entities that match the query.
    // This is a convenient method that combines FindAll and Count. It's useful when dealing with queries related to pagination.
    // Limit and Offset property will be ignored when performing count query.
    // It'll panic if any error occurred.
    MustFindAndCountAll(ctx context.Context, entities any, queriers ...Querier) int

    // Insert an entity to database.
    Insert(ctx context.Context, entity any, mutators ...Mutator) error

    // MustInsert an entity to database.
    // It'll panic if any error occurred.
    MustInsert(ctx context.Context, entity any, mutators ...Mutator)

    // InsertAll entities.
    // Does not support application cascade insert.
    InsertAll(ctx context.Context, entities any, mutators ...Mutator) error

    // MustInsertAll entities.
    // It'll panic if any error occurred.
    // Does not support application cascade insert.
    MustInsertAll(ctx context.Context, entities any, mutators ...Mutator)

    // Update an entity in database.
    // Returns not found error when no row was updated.
    Update(ctx context.Context, entity any, mutators ...Mutator) error

    // MustUpdate an entity in database.
    // It'll panic if any error occurred.
    MustUpdate(ctx context.Context, entity any, mutators ...Mutator)

    // UpdateAny entities that match the query.
    // Returns number of updated entities and error.
    UpdateAny(ctx context.Context, query Query, mutates ...Mutate) (int, error)

    // MustUpdateAny entities that match the query.
    // It'll panic if any error occurred.
    // Returns number of updated entities.
    MustUpdateAny(ctx context.Context, query Query, mutates ...Mutate) int

    // Delete an entity.
    Delete(ctx context.Context, entity any, mutators ...Mutator) error

    // MustDelete an entity.
    // It'll panic if any error occurred.
    MustDelete(ctx context.Context, entity any, mutators ...Mutator)

    // DeleteAll entities.
    // Does not support application cascade delete.
    DeleteAll(ctx context.Context, entities any) error

    // MustDeleteAll entities.
    // It'll panic if any error occurred.
    // Does not support application cascade delete.
    MustDeleteAll(ctx context.Context, entities any)

    // DeleteAny entities that match the query.
    // Returns number of deleted entities and error.
    DeleteAny(ctx context.Context, query Query) (int, error)

    // MustDeleteAny entities that match the query.
    // It'll panic if any error occurred.
    // Returns number of deleted entities.
    MustDeleteAny(ctx context.Context, query Query) int

    // Preload association with given query.
    // This function can accept either a pointer to a struct or a pointer to a slice of structs;
    // it panics if entities is not a pointer.
    // If association is already loaded, this will do nothing.
    // To force preloading even though association is already loaded, add `Reload(true)` as query.
    Preload(ctx context.Context, entities any, field string, queriers ...Querier) error

    // MustPreload association with given query.
    // This function can accept either a struct or a slice of structs.
    // It'll panic if any error occurred.
    MustPreload(ctx context.Context, entities any, field string, queriers ...Querier)

    // Exec raw statement.
    // Returns last inserted id, rows affected and error.
    Exec(ctx context.Context, statement string, args ...any) (int, int, error)

    // MustExec raw statement.
    // Returns last inserted id and rows affected.
    // It'll panic if any error occurred.
    MustExec(ctx context.Context, statement string, args ...any) (int, int)

    // Transaction performs transaction with given function argument.
    // Transaction scope/connection is automatically passed using context.
    Transaction(ctx context.Context, fn func(ctx context.Context) error) error
}

// repository is the Repository implementation defined in this file.
type repository struct {
    // rootAdapter is handed to fetchContext by every operation; fetchContext
    // (defined elsewhere) presumably falls back to it when ctx carries no
    // transaction-scoped adapter — TODO confirm.
    rootAdapter  Adapter
    // instrumenter receives Observe calls from the public operations; it is
    // replaced via Instrumentation.
    instrumenter Instrumenter
}

// Adapter returns the adapter that fetchContext resolves for ctx, given the
// repository's root adapter as the fallback.
func (r repository) Adapter(ctx context.Context) Adapter {
    wrapped := fetchContext(ctx, r.rootAdapter)
    return wrapped.adapter
}

// Instrumentation stores instrumenter on the repository and forwards it to the
// root adapter, so both layers report through the same callback.
func (r *repository) Instrumentation(instrumenter Instrumenter) {
    root := r.rootAdapter
    r.instrumenter = instrumenter
    root.Instrumentation(instrumenter)
}

// Ping checks database connectivity through the root adapter (not any
// transaction adapter carried by ctx).
func (r *repository) Ping(ctx context.Context) error {
    root := r.rootAdapter
    return root.Ping(ctx)
}

// Iterate builds an Iterator over query using the context and adapter that
// fetchContext resolves for ctx.
func (r repository) Iterate(ctx context.Context, query Query, options ...IteratorOption) Iterator {
    cw := fetchContext(ctx, r.rootAdapter)
    return newIterator(cw.ctx, cw.adapter, query, options)
}

// Aggregate runs the aggregate function over field for the rows matching
// query. The operation is reported to the instrumenter as "rel-aggregate".
func (r repository) Aggregate(ctx context.Context, query Query, aggregate string, field string) (int, error) {
    done := r.instrumenter.Observe(ctx, "rel-aggregate", "aggregating entities")
    defer done(nil)

    return r.aggregate(fetchContext(ctx, r.rootAdapter), query, aggregate, field)
}

// aggregate clears the group, sort, offset and limit parts of query (query is
// a copy, so the caller's value is untouched) and passes it to the adapter's
// Aggregate.
func (r repository) aggregate(cw contextWrapper, query Query, aggregate string, field string) (int, error) {
    query.SortQuery = nil
    query.OffsetQuery = 0
    query.LimitQuery = 0
    query.GroupQuery = GroupQuery{}

    return cw.adapter.Aggregate(cw.ctx, query, aggregate, field)
}

// MustAggregate is Aggregate, panicking (via must) on error.
func (r repository) MustAggregate(ctx context.Context, query Query, aggregate string, field string) int {
    value, err := r.Aggregate(ctx, query, aggregate, field)
    must(err)

    return value
}

// Count returns the number of rows in collection that match queriers, using a
// count(*) aggregate.
func (r repository) Count(ctx context.Context, collection string, queriers ...Querier) (int, error) {
    done := r.instrumenter.Observe(ctx, "rel-count", "aggregating entities")
    defer done(nil)

    cw := fetchContext(ctx, r.rootAdapter)
    query := Build(collection, queriers...)

    return r.aggregate(cw, query, "count", "*")
}

// MustCount is Count, panicking (via must) on error.
func (r repository) MustCount(ctx context.Context, collection string, queriers ...Querier) int {
    total, err := r.Count(ctx, collection, queriers...)
    must(err)

    return total
}

// Find loads a single row matching queriers into entity. The query is built
// against entity's table and populated from its metadata.
func (r repository) Find(ctx context.Context, entity any, queriers ...Querier) error {
    done := r.instrumenter.Observe(ctx, "rel-find", "finding a entity")
    defer done(nil)

    cw := fetchContext(ctx, r.rootAdapter)
    doc := NewDocument(entity)
    query := Build(doc.Table(), queriers...).Populate(doc.Meta())

    return r.find(cw, doc, query)
}

// MustFind is Find, panicking (via must) on error.
func (r repository) MustFind(ctx context.Context, entity any, queriers ...Querier) {
    err := r.Find(ctx, entity, queriers...)
    must(err)
}

// find applies the default scope to query, fetches at most one row, scans it
// into doc and then runs each preload listed in the query. The scan step is
// reported to the instrumenter as "rel-scan-one" with its error.
func (r repository) find(cw contextWrapper, doc *Document, query Query) error {
    query = r.withDefaultScope(doc.meta, query, true)

    cur, err := cw.adapter.Query(cw.ctx, query.Limit(1))
    if err != nil {
        return err
    }

    scanDone := r.instrumenter.Observe(cw.ctx, "rel-scan-one", "scanning a entity")
    err = scanOne(cur, doc)
    scanDone(err)
    if err != nil {
        return err
    }

    for _, field := range query.PreloadQuery {
        if err := r.preload(cw, doc, field, nil); err != nil {
            return err
        }
    }

    return nil
}

// FindAll resets entities and fills it with every row matching queriers.
func (r repository) FindAll(ctx context.Context, entities any, queriers ...Querier) error {
    done := r.instrumenter.Observe(ctx, "rel-find-all", "finding all entities")
    defer done(nil)

    cw := fetchContext(ctx, r.rootAdapter)
    col := NewCollection(entities)
    query := Build(col.Table(), queriers...).Populate(col.Meta())
    col.Reset()

    return r.findAll(cw, col, query)
}

// MustFindAll is FindAll, panicking (via must) on error.
func (r repository) MustFindAll(ctx context.Context, entities any, queriers ...Querier) {
    err := r.FindAll(ctx, entities, queriers...)
    must(err)
}

// findAll applies the default scope to query, scans every matching row into
// col and then runs each preload listed in the query. The scan step is
// reported to the instrumenter as "rel-scan-all" with its error.
func (r repository) findAll(cw contextWrapper, col *Collection, query Query) error {
    query = r.withDefaultScope(col.meta, query, true)

    cur, err := cw.adapter.Query(cw.ctx, query)
    if err != nil {
        return err
    }

    scanDone := r.instrumenter.Observe(cw.ctx, "rel-scan-all", "scanning all entities")
    err = scanAll(cur, col)
    scanDone(err)
    if err != nil {
        return err
    }

    for _, field := range query.PreloadQuery {
        if err := r.preload(cw, col, field, nil); err != nil {
            return err
        }
    }

    return nil
}

// FindAndCountAll loads every row matching queriers into entities and returns
// the total matching count. The count reuses the same query, and aggregate
// strips its limit and offset, so it reflects all matches, not just the page.
func (r repository) FindAndCountAll(ctx context.Context, entities any, queriers ...Querier) (int, error) {
    done := r.instrumenter.Observe(ctx, "rel-find-and-count-all", "finding all entities")
    defer done(nil)

    cw := fetchContext(ctx, r.rootAdapter)
    col := NewCollection(entities)
    query := Build(col.Table(), queriers...).Populate(col.Meta())
    col.Reset()

    if err := r.findAll(cw, col, query); err != nil {
        return 0, err
    }

    countQuery := r.withDefaultScope(col.meta, query, false)
    return r.aggregate(cw, countQuery, "count", "*")
}

// MustFindAndCountAll is FindAndCountAll, panicking (via must) on error.
func (r repository) MustFindAndCountAll(ctx context.Context, entities any, queriers ...Querier) int {
    total, err := r.FindAndCountAll(ctx, entities, queriers...)
    must(err)

    return total
}

// Insert writes entity to the database. A nil entity is a no-op. When the
// mutation cascades and carries association changes, the whole write runs
// inside a transaction.
func (r repository) Insert(ctx context.Context, entity any, mutators ...Mutator) error {
    done := r.instrumenter.Observe(ctx, "rel-insert", "inserting a entity")
    defer done(nil)

    if entity == nil {
        return nil
    }

    cw := fetchContext(ctx, r.rootAdapter)
    doc := NewDocument(entity)
    mutation := Apply(doc, mutators...)

    write := func(cw contextWrapper) error {
        return r.insert(cw, doc, mutation)
    }

    // Cascade is a named bool type; convert so it can be combined with a bool.
    if !mutation.IsAssocEmpty() && bool(mutation.Cascade) {
        return r.transaction(cw, write)
    }

    return write(cw)
}

// insert writes doc with mutation. With cascade enabled, belongs-to
// associations are saved before the row, and has-one / has-many associations
// after it. When the table has exactly one primary field, the value returned
// by the adapter is written back into doc. Adapter errors go through the
// mutation's ErrorFunc.
func (r repository) insert(cw contextWrapper, doc *Document, mutation Mutation) error {
    var (
        pFields = doc.PrimaryFields()
        query   = Build(doc.Table())
        pField  string
    )

    if len(pFields) == 1 {
        pField = pFields[0]
    }

    if mutation.Cascade {
        if err := r.saveBelongsTo(cw, doc, &mutation); err != nil {
            return err
        }
    }

    pValue, err := cw.adapter.Insert(cw.ctx, query, pField, mutation.Mutates, mutation.OnConflict)
    if err != nil {
        return mutation.ErrorFunc.transform(err)
    }

    if pField != "" {
        doc.SetValue(pField, pValue)
    }

    if !mutation.Cascade {
        return nil
    }

    if err := r.saveHasOne(cw, doc, &mutation); err != nil {
        return err
    }

    return r.saveHasMany(cw, doc, &mutation, true)
}

// MustInsert is Insert, panicking (via must) on error.
func (r repository) MustInsert(ctx context.Context, entity any, mutators ...Mutator) {
    err := r.Insert(ctx, entity, mutators...)
    must(err)
}

// InsertAll bulk-inserts entities; nil is a no-op. The caller's mutators are
// applied to the first document only, because insertAll reads OnConflict and
// ErrorFunc from the first mutation; the others get a plain Apply.
func (r repository) InsertAll(ctx context.Context, entities any, mutators ...Mutator) error {
    done := r.instrumenter.Observe(ctx, "rel-insert-all", "inserting multiple entities")
    defer done(nil)

    if entities == nil {
        return nil
    }

    cw := fetchContext(ctx, r.rootAdapter)
    col := NewCollection(entities)
    muts := make([]Mutation, col.Len())

    if len(muts) > 0 {
        muts[0] = Apply(col.Get(0), mutators...)
    }
    for i := 1; i < len(muts); i++ {
        muts[i] = Apply(col.Get(i))
    }

    return r.insertAll(cw, col, muts)
}

// MustInsertAll is InsertAll, panicking (via must) on error.
func (r repository) MustInsertAll(ctx context.Context, entities any, mutators ...Mutator) {
    err := r.InsertAll(ctx, entities, mutators...)
    must(err)
}

// insertAll bulk-inserts col using one mutation per element (same order as
// col). The column list is the union of every mutation's fields; OnConflict
// and ErrorFunc are taken from the first mutation. When the table has exactly
// one primary field, the returned ids are written back element by element.
//
// TODO: support assocs
func (r repository) insertAll(cw contextWrapper, col *Collection, mutation []Mutation) error {
    if len(mutation) == 0 {
        return nil
    }

    first := mutation[0]
    var (
        pField      string
        pFields     = col.PrimaryFields()
        query       = Build(col.Table())
        fields      = make([]string, 0, len(first.Mutates))
        seen        = make(map[string]struct{}, len(first.Mutates))
        bulkMutates = make([]map[string]Mutate, 0, len(mutation))
    )

    // TODO: bypassable if it's predictable.
    for i := range mutation {
        mutates := mutation[i].Mutates
        for field := range mutates {
            if _, dup := seen[field]; dup {
                continue
            }
            seen[field] = struct{}{}
            fields = append(fields, field)
        }
        bulkMutates = append(bulkMutates, mutates)
    }

    if len(pFields) == 1 {
        pField = pFields[0]
    }

    ids, err := cw.adapter.InsertAll(cw.ctx, query, pField, fields, bulkMutates, first.OnConflict)
    if err != nil {
        return first.ErrorFunc.transform(err)
    }

    if pField == "" {
        return nil
    }

    for i := range ids {
        col.Get(i).SetValue(pField, ids[i])
    }

    return nil
}

// Update writes the changes described by mutators to entity's row; nil is a
// no-op. The row filter is captured before the mutators are applied. When the
// mutation cascades and carries association changes, the whole write runs
// inside a transaction.
func (r repository) Update(ctx context.Context, entity any, mutators ...Mutator) error {
    done := r.instrumenter.Observe(ctx, "rel-update", "updating a entity")
    defer done(nil)

    if entity == nil {
        return nil
    }

    cw := fetchContext(ctx, r.rootAdapter)
    doc := NewDocument(entity)
    filter := filterDocument(doc)
    mutation := Apply(doc, mutators...)

    save := func(cw contextWrapper) error {
        return r.update(cw, doc, mutation, filter)
    }

    // Cascade is a named bool type; convert so it can be combined with a bool.
    if !mutation.IsAssocEmpty() && bool(mutation.Cascade) {
        return r.transaction(cw, save)
    }

    return save(cw)
}

// lockVersion returns doc's "lock_version" value and true when doc has the
// HasVersioning flag and the operation is not unscoped. A lock_version that is
// missing or not an int yields 0. Otherwise it returns (0, false).
func (r repository) lockVersion(doc Document, unscoped Unscoped) (int, bool) {
    if unscoped {
        return 0, false
    }

    if !doc.Flag(HasVersioning) {
        return 0, false
    }

    raw, _ := doc.Value("lock_version")
    version, _ := raw.(int)

    return version, true
}

// update writes mutation to doc's row (restricted by filter). With cascade
// enabled, belongs-to associations are saved first, and has-one / has-many
// associations afterwards. The row UPDATE is skipped when the mutation has no
// field mutates.
func (r repository) update(cw contextWrapper, doc *Document, mutation Mutation, filter FilterQuery) error {
    if mutation.Cascade {
        if err := r.saveBelongsTo(cw, doc, &mutation); err != nil {
            return err
        }
    }

    if !mutation.IsMutatesEmpty() {
        if err := r.applyMutates(cw, doc, mutation, filter); err != nil {
            return err
        }
    }

    if !mutation.Cascade {
        return nil
    }

    if err := r.saveHasOne(cw, doc, &mutation); err != nil {
        return err
    }

    return r.saveHasMany(cw, doc, &mutation, false)
}

// applyMutates issues the UPDATE for doc's pending mutates, restricted by
// filter together with the Unscoped and Cascade values carried by mutation.
//
// For versioned documents (see lockVersion) it applies
// Set("lock_version", version+1) to doc and the mutation and adds a
// lockVersion(version) condition to the query. If the method returns an error,
// the deferred hook writes the original version back into doc.
// An update that affects zero rows is reported as NotFoundError (presumably
// this is also how a lock_version mismatch surfaces — confirm against
// lockVersion's definition). When mutation.Reload is set, doc is re-read via
// find using the base queries (without the version condition).
func (r repository) applyMutates(cw contextWrapper, doc *Document, mutation Mutation, filter FilterQuery) (dbErr error) {
    var (
        baseQueries = []Querier{filter, mutation.Unscoped, mutation.Cascade}
        queries     = baseQueries
    )

    if version, ok := r.lockVersion(*doc, mutation.Unscoped); ok {
        Set("lock_version", version+1).Apply(doc, &mutation)
        // baseQueries comes from a composite literal (len == cap), so this
        // append allocates a new array and baseQueries stays version-free.
        queries = append(queries, lockVersion(version))
        // Roll doc's lock_version back if anything below fails.
        defer func() {
            if dbErr != nil {
                doc.SetValue("lock_version", version)
            }
        }()
    }

    var (
        pField string
        query  = r.withDefaultScope(doc.meta, Build(doc.Table(), queries...).Populate(doc.Meta()), false)
    )

    // Only a single-column primary key is passed to the adapter.
    if len(doc.meta.primaryField) == 1 {
        pField = doc.PrimaryField()
    }

    if updatedCount, err := cw.adapter.Update(cw.ctx, query, pField, mutation.Mutates); err != nil {
        return mutation.ErrorFunc.transform(err)
    } else if updatedCount == 0 {
        return NotFoundError{}
    }

    if mutation.Reload {
        // NOTE(review): UsePrimary presumably routes the read to the primary
        // database rather than a replica — confirm against Query.UsePrimary.
        baseQuery := r.withDefaultScope(doc.meta, Build(doc.Table(), baseQueries...).Populate(doc.Meta()), false)
        if err := r.find(cw, doc, baseQuery.UsePrimary()); err != nil {
            return err
        }
    }

    return nil
}

// MustUpdate is Update, panicking (via must) on error.
func (r repository) MustUpdate(ctx context.Context, entity any, mutators ...Mutator) {
    err := r.Update(ctx, entity, mutators...)
    must(err)
}

// saveBelongsTo persists every autosaved belongs-to association of doc that
// has a pending mutation (only the first mutation of each is used). A loaded
// association is updated in place. Otherwise it is inserted, and its foreign
// value is copied into doc's reference field and added to mutation.
//
// TODO: support deletion
func (r repository) saveBelongsTo(cw contextWrapper, doc *Document, mutation *Mutation) error {
    for _, field := range doc.BelongsTo() {
        assoc := doc.Association(field)
        assocMuts, changed := mutation.Assoc[field]

        if !assoc.Autosave() || !changed || len(assocMuts.Mutations) == 0 {
            continue
        }

        assocDoc, loaded := assoc.Document()
        assocMut := assocMuts.Mutations[0]

        if !loaded {
            if err := r.insert(cw, assocDoc, assocMut); err != nil {
                return err
            }

            rField, fValue := assoc.ReferenceField(), assoc.ForeignValue()
            mutation.Add(Set(rField, fValue))
            doc.SetValue(rField, fValue)
            continue
        }

        filter, err := filterBelongsTo(assoc)
        if err != nil {
            return err
        }

        if err = r.update(cw, assocDoc, assocMut, filter); err != nil {
            return err
        }
    }

    return nil
}

// saveHasOne persists every autosaved has-one association of doc that has a
// pending mutation (only the first mutation of each is used). A loaded
// association whose foreign key is absent or already set is updated in place.
// Otherwise the foreign key is set to doc's reference value and the
// association is inserted.
//
// TODO: support deletion
func (r repository) saveHasOne(cw contextWrapper, doc *Document, mutation *Mutation) error {
    for _, field := range doc.HasOne() {
        assoc := doc.Association(field)
        assocMuts, changed := mutation.Assoc[field]

        if !assoc.Autosave() || !changed || len(assocMuts.Mutations) == 0 {
            continue
        }

        assocDoc, loaded := assoc.Document()
        assocMut := assocMuts.Mutations[0]

        updatable := loaded && (assoc.ForeignField() == "" || !isZero(assoc.ForeignValue()))
        if !updatable {
            fField, rValue := assoc.ForeignField(), assoc.ReferenceValue()
            assocMut.Add(Set(fField, rValue))
            assocDoc.SetValue(fField, rValue)

            if err := r.insert(cw, assocDoc, assocMut); err != nil {
                return err
            }
            continue
        }

        filter, err := filterHasOne(assoc, assocDoc)
        if err != nil {
            return err
        }

        if err = r.update(cw, assocDoc, assocMut, filter); err != nil {
            return err
        }
    }

    return nil
}

// saveHasMany persists the autosaved has-many associations of doc that have a
// pending mutation. It expects the has-many mutations to be ordered the same as
// the records in the association collection.
//
// When insertion is false (doc is being updated), stale children are deleted
// first:
//   - if DeletedIDs is nil (the association is being replaced, as produced by
//     structset), every child referencing doc is deleted;
//   - otherwise only the listed DeletedIDs are deleted.
//
// Children that have both a primary value and a foreign value are then updated
// and swapped, together with their mutation, to the front of the collection.
// The remaining tail gets the foreign key set and is bulk-inserted.
//
// It returns a ConstraintError when an existing child's foreign value differs
// from doc's reference value. It panics if the mutation count does not match
// the collection length.
func (r repository) saveHasMany(cw contextWrapper, doc *Document, mutation *Mutation, insertion bool) error {
    for _, field := range doc.HasMany() {
        var (
            assoc              = doc.Association(field)
            assocMuts, changed = mutation.Assoc[field]
        )

        if !assoc.Autosave() || !changed {
            continue
        }

        var (
            col, _     = assoc.Collection()
            table      = col.Table()
            fField     = assoc.ForeignField()
            rValue     = assoc.ReferenceValue()
            muts       = assocMuts.Mutations
            deletedIDs = assocMuts.DeletedIDs
        )

        // this shouldn't happen unless there's bug in the mutator.
        if len(muts) != col.Len() {
            panic("rel: invalid mutator")
        }

        if !insertion {
            var (
                filter = Eq(fField, rValue)
            )

            if deletedIDs == nil {
                // if it's nil, then clear old association (used by structset).
                if _, err := r.deleteAny(cw, col.meta.flag, Build(table, filter).Populate(col.Meta())); err != nil {
                    return err
                }
            } else if len(deletedIDs) > 0 {
                filter = filter.AndIn(col.PrimaryField(), deletedIDs...)
                if _, err := r.deleteAny(cw, col.meta.flag, Build(table, filter).Populate(col.Meta())); err != nil {
                    return err
                }
            }
        }

        // Update existing children in place, compacting them to the front of
        // col/muts; the tail [updateCount:] is left for bulk insertion.
        updateCount := 0
        for i := range muts {
            var (
                assocDoc = col.Get(i)
            )

            // When deleted IDs is nil, it's assumed that association will be replaced.
            // hence any update request is ignored here.
            var fValue, _ = assocDoc.Value(fField)
            if deletedIDs != nil && !isZero(assocDoc.PrimaryValue()) && !isZero(fValue) {
                var (
                    filter = filterDocument(assocDoc).AndEq(fField, rValue)
                )

                if rValue != fValue {
                    return ConstraintError{
                        Key:  fField,
                        Type: ForeignKeyConstraint,
                        Err:  errors.New("rel: inconsistent has many ref and fk"),
                    }
                }

                if updateCount < i {
                    col.Swap(updateCount, i)
                    muts[i], muts[updateCount] = muts[updateCount], muts[i]

                    // The record assocDoc was read from now lives at updateCount,
                    // while slot i holds the element swapped out of there.
                    // Re-fetch it so the update (and any lock_version bump,
                    // reload or nested cascade it performs on the document) acts
                    // on the record that filter and muts[updateCount] belong to.
                    assocDoc = col.Get(updateCount)
                }

                if err := r.update(cw, assocDoc, muts[updateCount], filter); err != nil {
                    return err
                }

                updateCount++
            } else {
                muts[i].Add(Set(fField, rValue))
                assocDoc.SetValue(fField, rValue)
            }
        }

        if len(muts)-updateCount > 0 {
            var (
                insertMuts = muts
                insertCol  = col
            )

            if updateCount > 0 {
                insertMuts = muts[updateCount:]
                insertCol = col.Slice(updateCount, len(muts))
            }

            if err := r.insertAll(cw, insertCol, insertMuts); err != nil {
                return err
            }
        }

    }

    return nil
}

// UpdateAny applies mutates to every row matching query and returns the
// number of updated rows. Mutates are keyed by field, so a later mutate on
// the same field replaces an earlier one. With no mutates, nothing is sent to
// the adapter and (0, nil) is returned.
func (r repository) UpdateAny(ctx context.Context, query Query, mutates ...Mutate) (int, error) {
    done := r.instrumenter.Observe(ctx, "rel-update-any", "updating multiple entities")
    defer done(nil)

    cw := fetchContext(ctx, r.rootAdapter)

    byField := make(map[string]Mutate, len(mutates))
    for _, m := range mutates {
        byField[m.Field] = m
    }

    if len(byField) == 0 {
        return 0, nil
    }

    return cw.adapter.Update(cw.ctx, query, "", byField)
}

// MustUpdateAny is UpdateAny, panicking (via must) on error.
func (r repository) MustUpdateAny(ctx context.Context, query Query, mutates ...Mutate) int {
    count, err := r.UpdateAny(ctx, query, mutates...)
    must(err)

    return count
}

// Delete removes entity's row. When the mutators request cascading, the
// delete (including associated rows) runs inside a transaction.
func (r repository) Delete(ctx context.Context, entity any, mutators ...Mutator) error {
    done := r.instrumenter.Observe(ctx, "rel-delete", "deleting a entity")
    defer done(nil)

    cw := fetchContext(ctx, r.rootAdapter)
    doc := NewDocument(entity)
    mutation := applyMutators(nil, false, false, mutators...)

    remove := func(cw contextWrapper) error {
        return r.delete(cw, doc, filterDocument(doc), mutation)
    }

    if mutation.Cascade {
        return r.transaction(cw, remove)
    }

    return remove(cw)
}

// delete removes doc's row (matched by filter, plus a lock_version condition
// for versioned, scoped documents). With cascade enabled, has-one and
// has-many children are removed first, and belongs-to parents after the row
// is gone. Zero affected rows is reported as NotFoundError.
func (r repository) delete(cw contextWrapper, doc *Document, filter FilterQuery, mutation Mutation) error {
    filters := []Querier{filter, mutation.Unscoped}
    if version, ok := r.lockVersion(*doc, mutation.Unscoped); ok {
        filters = append(filters, lockVersion(version))
    }

    query := Build(doc.Table(), filters...).Populate(doc.Meta())

    if mutation.Cascade {
        if err := r.deleteHasOne(cw, doc, true); err != nil {
            return err
        }

        if err := r.deleteHasMany(cw, doc); err != nil {
            return err
        }
    }

    deletedCount, err := r.deleteAny(cw, doc.meta.flag, query)
    if err != nil {
        return err
    }

    if deletedCount == 0 {
        return NotFoundError{}
    }

    if mutation.Cascade {
        return r.deleteBelongsTo(cw, doc, true)
    }

    return nil
}

// deleteBelongsTo deletes the loaded parent of every autosaved belongs-to
// association of doc, passing cascade through to each nested delete.
func (r repository) deleteBelongsTo(cw contextWrapper, doc *Document, cascade Cascade) error {
    for _, field := range doc.BelongsTo() {
        assoc := doc.Association(field)
        if !assoc.Autosave() {
            continue
        }

        assocDoc, loaded := assoc.Document()
        if !loaded {
            continue
        }

        filter, err := filterBelongsTo(assoc)
        if err != nil {
            return err
        }

        if err = r.delete(cw, assocDoc, filter, Mutation{Cascade: cascade}); err != nil {
            return err
        }
    }

    return nil
}

// deleteHasOne deletes the loaded child of every autosaved has-one
// association of doc, passing cascade through to each nested delete.
func (r repository) deleteHasOne(cw contextWrapper, doc *Document, cascade Cascade) error {
    for _, field := range doc.HasOne() {
        assoc := doc.Association(field)
        if !assoc.Autosave() {
            continue
        }

        assocDoc, loaded := assoc.Document()
        if !loaded {
            continue
        }

        filter, err := filterHasOne(assoc, assocDoc)
        if err != nil {
            return err
        }

        if err = r.delete(cw, assocDoc, filter, Mutation{Cascade: cascade}); err != nil {
            return err
        }
    }

    return nil
}

// deleteHasMany deletes the children of every autosaved has-many association
// of doc whose collection is loaded and non-empty. The delete is restricted to
// rows referencing doc and matching filterCollection(col) (presumably the
// loaded children's primary keys — confirm against filterCollection).
func (r repository) deleteHasMany(cw contextWrapper, doc *Document) error {
    for _, field := range doc.HasMany() {
        assoc := doc.Association(field)
        if !assoc.Autosave() {
            continue
        }

        col, loaded := assoc.Collection()
        if !loaded || col.Len() == 0 {
            continue
        }

        var (
            table  = col.Table()
            fField = assoc.ForeignField()
            rValue = assoc.ReferenceValue()
            filter = Eq(fField, rValue).And(filterCollection(col))
        )

        // The query targets the child table, so populate it from the child
        // collection's meta (as saveHasMany and DeleteAll do), not the parent's.
        query := Build(table, filter).Populate(col.Meta())
        if _, err := r.deleteAny(cw, col.meta.flag, query); err != nil {
            return err
        }
    }

    return nil
}

// MustDelete is Delete, panicking (via must) on error.
func (r repository) MustDelete(ctx context.Context, entity any, mutators ...Mutator) {
    err := r.Delete(ctx, entity, mutators...)
    must(err)
}

// DeleteAll removes every row in entities with a single query; an empty
// collection is a no-op. Associations are not cascaded.
func (r repository) DeleteAll(ctx context.Context, entities any) error {
    done := r.instrumenter.Observe(ctx, "rel-delete-all", "deleting entities")
    defer done(nil)

    cw := fetchContext(ctx, r.rootAdapter)
    col := NewCollection(entities)
    if col.Len() == 0 {
        return nil
    }

    query := Build(col.Table(), filterCollection(col)).Populate(col.Meta())
    _, err := r.deleteAny(cw, col.meta.flag, query)

    return err
}

// MustDeleteAll is DeleteAll, panicking (via must) on error.
func (r repository) MustDeleteAll(ctx context.Context, entities any) {
    err := r.DeleteAll(ctx, entities)
    must(err)
}

// DeleteAny removes every row matching query and returns the affected count.
// It passes the Invalid document flag to deleteAny; presumably that flag has
// none of the soft-delete bits, so this is always a hard delete (confirm).
func (r repository) DeleteAny(ctx context.Context, query Query) (int, error) {
    done := r.instrumenter.Observe(ctx, "rel-delete-any", "deleting multiple entities")
    defer done(nil)

    return r.deleteAny(fetchContext(ctx, r.rootAdapter), Invalid, query)
}

// MustDeleteAny is DeleteAny, panicking (via must) on error.
func (r repository) MustDeleteAny(ctx context.Context, query Query) int {
    count, err := r.DeleteAny(ctx, query)
    must(err)

    return count
}

// deleteAny deletes the rows matching query and returns the affected count.
// If flag has HasDeletedAt and/or HasDeleted, the rows are soft-deleted with
// an UPDATE instead:
//   - deleted_at is set to Now() (HasDeletedAt);
//   - deleted is set to true (HasDeleted), and updated_at to Now() when the
//     model has HasUpdatedAt but not HasDeletedAt;
//   - lock_version is incremented when the model has HasVersioning.
func (r repository) deleteAny(cw contextWrapper, flag DocumentFlag, query Query) (int, error) {
    withDeletedAt := flag.Is(HasDeletedAt)
    withDeleted := flag.Is(HasDeleted)

    if !withDeletedAt && !withDeleted {
        return cw.adapter.Delete(cw.ctx, query)
    }

    mutates := make(map[string]Mutate, 1)
    if withDeletedAt {
        mutates["deleted_at"] = Set("deleted_at", Now())
    }
    if withDeleted {
        mutates["deleted"] = Set("deleted", true)
        if !withDeletedAt && flag.Is(HasUpdatedAt) {
            mutates["updated_at"] = Set("updated_at", Now())
        }
    }
    if flag.Is(HasVersioning) {
        mutates["lock_version"] = Inc("lock_version")
    }

    return cw.adapter.Update(cw.ctx, query, "", mutates)
}

// Preload loads the association named by field into entities, which must be a
// pointer to a struct or to a slice of structs. It panics if entities is not
// a pointer (including a nil interface).
func (r repository) Preload(ctx context.Context, entities any, field string, queriers ...Querier) error {
    finish := r.instrumenter.Observe(ctx, "rel-preload", "preloading associations")
    defer finish(nil)

    var (
        sl slice
        cw = fetchContext(ctx, r.rootAdapter)
        rt = reflect.TypeOf(entities)
    )

    // reflect.TypeOf(nil) returns a nil Type; calling Kind on it would be a
    // nil-pointer dereference, so guard it to get the descriptive panic.
    if rt == nil || rt.Kind() != reflect.Ptr {
        panic("rel: entity parameter must be a pointer.")
    }

    if rt.Elem().Kind() == reflect.Slice {
        sl = NewCollection(entities)
    } else {
        sl = NewDocument(entities)
    }

    return r.preload(cw, sl, field, queriers)
}

// preload loads the association at the dot-separated path field for every
// document in entities and scans the fetched rows into the matching targets
// collected by mapPreloadTargets.
//
// Foreign keys are queried in chunks of at most maxInClause ids, so a single
// IN clause never exceeds that size. It returns nil without querying when
// there are no targets, or when every target is already loaded and the query
// does not request a reload.
func (r repository) preload(cw contextWrapper, entities slice, field string, queriers []Querier) error {
    // Maximum number of ids placed in one IN clause.
    const maxInClause = 999

    var (
        path                                             = strings.Split(field, ".")
        targets, table, keyField, keyType, ddata, loaded = r.mapPreloadTargets(entities, path)
        ids                                              = r.targetIDs(targets)
    )

    for len(ids) > 0 {
        size := maxInClause
        if len(ids) < size {
            size = len(ids)
        }

        idsChunk := ids[:size]
        ids = ids[size:]

        // Cap queriers at its length so append always copies into a fresh
        // backing array. Appending directly could write into spare capacity of
        // the caller's variadic slice, which races when that slice is shared
        // between goroutines.
        chunkQueriers := append(queriers[:len(queriers):len(queriers)], In(keyField, idsChunk...))

        query := Build(table, chunkQueriers...).Populate(entities.Meta())
        if len(targets) == 0 || (loaded && !bool(query.ReloadQuery)) {
            return nil
        }

        cur, err := cw.adapter.Query(cw.ctx, r.withDefaultScope(ddata, query, false))
        if err != nil {
            return err
        }

        scanFinish := r.instrumenter.Observe(cw.ctx, "rel-scan-multi", "scanning all entities to multiple targets")
        // Note: Calling scanMulti multiple times with the same targets works
        // only if the cursor of each execution only contains a new set of keys.
        // That is the case here, as each select uses a disjoint set of ids.
        err = scanMulti(cur, keyField, keyType, targets)
        scanFinish(err)
        if err != nil {
            return err
        }
    }

    return nil
}

// MustPreload is like Preload but panics on error instead of returning it.
func (r repository) MustPreload(ctx context.Context, entities any, field string, queriers ...Querier) {
    err := r.Preload(ctx, entities, field, queriers...)
    must(err)
}

// mapPreloadTargets walks the association path starting from every document
// in sl and collects the slices that should receive preloaded rows.
//
// The walk is depth-first, using an explicit stack of frames. Intermediate
// associations that are not loaded are skipped. For the last path segment,
// each association with a non-nil reference value contributes its target (a
// Collection for HasMany, otherwise a lazy Document). Reset is called on the
// target, and it is then appended under that reference value.
//
// Returns:
//   - mapTarget: reference value -> target slices sharing that value
//   - table:     table of the first target found
//   - keyField:  foreign field of the first matching association; the caller
//     filters on it with In()
//   - keyType:   reflect.Type of the first reference value
//   - meta:      DocumentMeta of the first target (Document or Collection)
//   - loaded:    true only if every collected target reported itself loaded
//     (true as well when nothing was collected)
func (r repository) mapPreloadTargets(sl slice, path []string) (map[any][]slice, string, string, reflect.Type, DocumentMeta, bool) {
    // frame is one pending document to visit: index is its position in path.
    type frame struct {
        index int
        doc   *Document
    }

    var (
        table     string
        keyField  string
        keyType   reflect.Type
        meta      DocumentMeta
        loaded    = true
        mapTarget = make(map[any][]slice)
        stack     = make([]frame, sl.Len())
    )

    // init stack: every root document starts at the first path segment.
    for i := 0; i < len(stack); i++ {
        stack[i] = frame{index: 0, doc: sl.Get(i)}
    }

    for len(stack) > 0 {
        // Pop the last frame (LIFO) and resolve its association for the
        // current path segment.
        var (
            n      = len(stack) - 1
            top    = stack[n]
            assocs = top.doc.Association(path[top.index])
        )

        stack = stack[:n]

        if top.index == len(path)-1 {
            // Last segment: this association is a preload target.
            var (
                target       slice
                targetLoaded bool
                ref          = assocs.ReferenceValue()
            )

            // Nothing to match against without a reference value.
            if ref == nil {
                continue
            }

            if assocs.Type() == HasMany {
                target, targetLoaded = assocs.Collection()
            } else {
                target, targetLoaded = assocs.LazyDocument()
            }

            // NOTE(review): Reset presumably clears previously loaded data so the
            // scan starts empty — confirm against slice.Reset implementations.
            target.Reset()
            mapTarget[ref] = append(mapTarget[ref], target)
            loaded = loaded && targetLoaded

            // Query metadata is taken from the first target seen.
            if table == "" {
                table = target.Table()
                keyField = assocs.ForeignField()
                keyType = reflect.TypeOf(ref)

                if doc, ok := target.(*Document); ok {
                    meta = doc.meta
                }

                if col, ok := target.(*Collection); ok {
                    meta = col.meta
                }
            }
        } else {
            // Intermediate segment: descend into already-loaded children only.
            if assocs.Type() == HasMany {
                // This `loaded` deliberately shadows the outer one; it only
                // gates descent and never affects the returned flag.
                var (
                    col, loaded = assocs.Collection()
                )

                if !loaded {
                    continue
                }

                // After the pop, len(stack) == n, so the new frames occupy
                // indexes n .. n+col.Len()-1.
                stack = append(stack, make([]frame, col.Len())...)
                for i := 0; i < col.Len(); i++ {
                    stack[n+i] = frame{
                        index: top.index + 1,
                        doc:   col.Get(i),
                    }
                }
            } else {
                if doc, loaded := assocs.LazyDocument(); loaded {
                    stack = append(stack, frame{
                        index: top.index + 1,
                        doc:   doc,
                    })
                }
            }
        }

    }

    return mapTarget, table, keyField, keyType, meta, loaded
}

// targetIDs returns the keys of targets (the reference values collected by
// mapPreloadTargets) as a slice, in map iteration order.
func (r repository) targetIDs(targets map[any][]slice) []any {
    ids := make([]any, 0, len(targets))
    for id := range targets {
        ids = append(ids, id)
    }

    return ids
}

// withDefaultScope applies the soft-delete default scope described by meta to
// query. If the document has HasDeleted, it adds Eq("deleted", false);
// otherwise, if it has HasDeletedAt, it adds Nil("deleted_at"). Unscoped
// queries are returned untouched.
//
// When preload is true and the query cascades, the document's own preload list
// (meta.preload) is placed in front of query.PreloadQuery.
func (r repository) withDefaultScope(meta DocumentMeta, query Query, preload bool) Query {
    if query.UnscopedQuery {
        return query
    }

    switch {
    case meta.flag.Is(HasDeleted):
        query = query.Where(Eq("deleted", false))
    case meta.flag.Is(HasDeletedAt):
        query = query.Where(Nil("deleted_at"))
    }

    if !preload || !bool(query.CascadeQuery) {
        return query
    }

    // Copy meta.preload into a fresh backing array before appending, so the
    // shared DocumentMeta is never written to (avoids a data race). This is the
    // same technique `slices.Clone` uses in go 1.23:
    // https://cs.opensource.google/go/go/+/refs/tags/go1.23.3:src/slices/slices.go;l=350
    inherited := append(meta.preload[:0:0], meta.preload...)
    query.PreloadQuery = append(inherited, query.PreloadQuery...)

    return query
}

// Exec runs a raw statement through the adapter returned by r.Adapter(ctx).
// It returns the last inserted id, the number of rows affected, and any error
// reported by the adapter.
func (r repository) Exec(ctx context.Context, stmt string, args ...any) (int, int, error) {
    adapter := r.Adapter(ctx)
    lastInsertedID, rowsAffected, err := adapter.Exec(ctx, stmt, args)

    return int(lastInsertedID), int(rowsAffected), err
}

// MustExec runs a raw statement like Exec and returns the last inserted id and
// the number of rows affected. It panics if any error occurred.
func (r repository) MustExec(ctx context.Context, stmt string, args ...any) (int, int) {
    id, affected, err := r.Exec(ctx, stmt, args...)
    must(err)

    return id, affected
}

// Transaction runs fn inside a database transaction. fn receives a context
// bound to the transaction; see transaction for commit/rollback rules.
func (r repository) Transaction(ctx context.Context, fn func(ctx context.Context) error) error {
    finish := r.instrumenter.Observe(ctx, "rel-transaction", "transaction")
    defer finish(nil)

    cw := fetchContext(ctx, r.rootAdapter)
    run := func(trx contextWrapper) error {
        return fn(trx.ctx)
    }

    return r.transaction(cw, run)
}

// transaction runs fn inside a transaction started with cw.adapter.Begin.
//
// fn receives a contextWrapper bound to the transaction adapter returned by
// Begin. How the transaction ends depends on fn:
//   - fn returns nil: the transaction is committed and Commit's error is returned.
//   - fn returns an error: the transaction is rolled back (Rollback's error is
//     discarded) and fn's error is returned.
//   - fn panics with a value implementing error (other than runtime.Error): the
//     transaction is rolled back and that value is returned as the error.
//   - fn panics with a runtime.Error or a non-error value: the transaction is
//     rolled back and the panic is re-raised with the same value.
//
// If Begin fails, its error is returned and fn is not called.
func (r repository) transaction(cw contextWrapper, fn func(cw contextWrapper) error) error {
    adp, err := cw.adapter.Begin(cw.ctx)
    if err != nil {
        return err
    }

    // Rebind the wrapper to the transaction adapter so fn and the
    // commit/rollback below operate on the transaction.
    cw = wrapContext(cw.ctx, adp)

    // The anonymous function scopes the deferred handler so that it runs (and
    // can recover a panic from fn) before err is returned below.
    func() {
        defer func() {
            if p := recover(); p != nil {
                _ = cw.adapter.Rollback(cw.ctx)

                // runtime.Error also implements error, so it must be matched
                // first to keep genuine runtime faults panicking.
                switch e := p.(type) {
                case runtime.Error:
                    panic(e)
                case error:
                    err = e
                default:
                    panic(e)
                }
            } else if err != nil {
                _ = cw.adapter.Rollback(cw.ctx)
            } else {
                err = cw.adapter.Commit(cw.ctx)
            }
        }()

        err = fn(cw)
    }()

    return err
}

// New creates a Repository backed by adapter. DefaultLogger is set as the
// instrumenter and is then also registered through Instrumentation.
func New(adapter Adapter) Repository {
    r := &repository{rootAdapter: adapter, instrumenter: DefaultLogger}
    r.Instrumentation(DefaultLogger)

    return r
}