datacoves/dbt-coves

View on GitHub
dbt_coves/tasks/data_sync/sql_database/helpers.py

Summary

Maintainability
A
3 hrs
Test Coverage
"""SQL database source helpers"""

import operator
from typing import Any, Callable, Iterator, List, Optional, Union

import dlt
from dlt.common.configuration.specs import BaseConfiguration, configspec
from dlt.common.typing import TDataItem
from dlt.sources.credentials import ConnectionStringCredentials
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine

from .schema_types import SelectAny, Table, table_to_columns


class TableLoader:
    def __init__(
        self,
        engine: Engine,
        table: Table,
        chunk_size: int = 1000,
        incremental: Optional[dlt.sources.incremental[Any]] = None,
    ) -> None:
        self.engine = engine
        self.table = table
        self.chunk_size = chunk_size
        self.incremental = incremental
        if incremental:
            try:
                self.cursor_column = table.c[incremental.cursor_path]
            except KeyError as e:
                raise KeyError(
                    f"Cursor column '{incremental.cursor_path}' does not exist in table '{table.name}'"
                ) from e
            self.last_value = incremental.last_value
            self.end_value = incremental.end_value
            self.row_order = getattr(self.incremental, "row_order", None)
        else:
            self.cursor_column = None
            self.last_value = None
            self.end_value = None
            self.row_order = None

    def make_query(self) -> SelectAny:
        table = self.table
        query = table.select()
        if not self.incremental:
            return query
        last_value_func = self.incremental.last_value_func

        # generate where
        if last_value_func is max:  # Query ordered and filtered according to last_value function
            filter_op = operator.ge
            filter_op_end = operator.lt
        elif last_value_func is min:
            filter_op = operator.le
            filter_op_end = operator.gt
        else:  # Custom last_value, load everything and let incremental handle filtering
            return query

        if self.last_value is not None:
            query = query.where(filter_op(self.cursor_column, self.last_value))
            if self.end_value is not None:
                query = query.where(filter_op_end(self.cursor_column, self.end_value))

        # generate order by from declared row order
        order_by = None
        if self.row_order == "asc":
            order_by = self.cursor_column.asc()
        elif self.row_order == "desc":
            order_by = self.cursor_column.desc()
        if order_by is not None:
            query = query.order_by(order_by)

        return query

    def load_rows(self) -> Iterator[List[TDataItem]]:
        query = self.make_query()
        with self.engine.connect() as conn:
            result = conn.execution_options(yield_per=self.chunk_size).execute(query)
            for partition in result.partitions(size=self.chunk_size):
                yield [dict(row._mapping) for row in partition]


def table_rows(
    engine: Engine,
    table: Table,
    chunk_size: int,
    incremental: Optional[dlt.sources.incremental[Any]] = None,
    detect_precision_hints: bool = False,
    defer_table_reflect: bool = False,
    table_adapter_callback: Callable[[Table], None] = None,
) -> Iterator[TDataItem]:
    if defer_table_reflect:
        table = Table(table.name, table.metadata, autoload_with=engine, extend_existing=True)
        if table_adapter_callback:
            table_adapter_callback(table)
        # set the primary_key in the incremental
        if incremental and incremental.primary_key is None:
            primary_key = get_primary_key(table)
            if primary_key is not None:
                incremental.primary_key = primary_key
        # yield empty record to set hints
        yield dlt.mark.with_hints(
            [],
            dlt.mark.make_hints(
                primary_key=get_primary_key(table),
                columns=table_to_columns(table) if detect_precision_hints else None,
            ),
        )

    loader = TableLoader(engine, table, incremental=incremental, chunk_size=chunk_size)
    yield from loader.load_rows()


def engine_from_credentials(credentials: Union[ConnectionStringCredentials, Engine, str]) -> Engine:
    if isinstance(credentials, Engine):
        return credentials
    if isinstance(credentials, ConnectionStringCredentials):
        credentials = credentials.to_native_representation()
    return create_engine(credentials)


def get_primary_key(table: Table) -> List[str]:
    """Create primary key or return None if no key defined"""
    primary_key = [c.name for c in table.primary_key]
    return primary_key if len(primary_key) > 0 else None


@configspec
class SqlDatabaseTableConfiguration(BaseConfiguration):
    incremental: Optional[dlt.sources.incremental] = None  # type: ignore[type-arg]


@configspec
class SqlTableResourceConfiguration(BaseConfiguration):
    credentials: ConnectionStringCredentials
    table: str
    incremental: Optional[dlt.sources.incremental] = None  # type: ignore[type-arg]
    schema: Optional[str]


__source_name__ = "sql_database"