"""
Generate a random sample of rows from a relational database that preserves
referential integrity - so long as constraints are defined, all parent rows
will exist for child rows.

Good for creating test/development databases from production.  It's slow,
but how often do you need to generate a test/development database?

Usage::

    rdbms-subsetter <source SQLAlchemy connection string> <destination connection string> <fraction of rows to use>

Example::

    rdbms-subsetter postgresql://:@/bigdb postgresql://:@/littledb 0.05

Valid SQLAlchemy connection strings are described
`here <https://docs.sqlalchemy.org/en/latest/core/engines.html#database-urls>`_.

``rdbms-subsetter`` promises that each child row will have whatever parent rows are
required by its foreign keys.  It will also *try* to include most child rows belonging
to each parent row (up to the supplied ``--children`` parameter, default 3 each), but it
can't make any promises.  (Demanding all children can lead to infinite propagation in
thoroughly interlinked databases, as every child record demands new parent records,
which demand new child records, which demand new parent records...
so increase ``--children`` with caution.)
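
For example, to allow up to five child rows per parent::

    rdbms-subsetter postgresql://:@/bigdb postgresql://:@/littledb 0.05 --children=5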

When row counts in your tables vary wildly (tens to billions, for example),
consider using the ``-l`` flag, which reduces row counts by a logarithmic formula.  If ``f`` is
the fraction specified, and ``-l`` is set, and the original table has ``n`` rows,
then each new table's row target will be::

    math.pow(10, math.log10(n)*f)

A fraction of ``0.5`` seems to produce good results, converting 10 rows to 3,
1,000,000 to 1,000, and 1,000,000,000 to 31,622.
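
For instance, with ``f = 0.5``::

    >>> import math
    >>> int(math.pow(10, math.log10(1000000) * 0.5))
    1000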

Rows are selected randomly, but for tables with a single primary key column, you
can force rdbms-subsetter to include specific rows (and their dependencies) with
``--force=<tablename>:<primary key value>``.  The immediate children of these rows
are also exempted from the ``--children`` limit.
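
For example, to force the row with primary key 42 from a hypothetical
``users`` table into the destination::

    rdbms-subsetter postgresql://:@/bigdb postgresql://:@/littledb 0.05 --force=users:42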

rdbms-subsetter only performs the inserts; it's your responsibility to set
up the target database first, with its foreign key constraints.  The easiest
way to do this is with your RDBMS's dump utility.  For example, for PostgreSQL,

::

    pg_dump --schema-only -f schemadump.sql bigdb
    createdb littledb
    psql -f schemadump.sql littledb

You can pull rows from a non-default schema by passing ``--schema=<name>``.
Currently the target database must contain the corresponding tables in its own
schema of the same name (moving between schemas of different names is not yet
supported).
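
For example, to pull rows from an illustrative ``payroll`` schema::

    rdbms-subsetter --schema=payroll postgresql://:@/bigdb postgresql://:@/littledb 0.05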

Mixed-case table names will probably create bad results in rdbms-subsetter,
and in the rest of your life, for that matter.  Don't do it.
"""
import argparse
import fnmatch
import functools
import json
import logging
import math
import random
import types
from collections import OrderedDict, deque

import sqlalchemy as sa
from blinker import signal
from sqlalchemy.engine.reflection import Inspector

from dialects.postgres import fix_postgres_array_of_enum

# Python 2's ``input`` evaluates what it reads; alias ``raw_input`` so both
# versions read a plain string
try:
    input = raw_input
except NameError:
    pass

__version__ = '0.2.6.2'

SIGNAL_ROW_ADDED = 'row_added'


def _find_n_rows(self, estimate=False):
    self.n_rows = 0
    if estimate:
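        # Ask the planner's statistics for a cheap approximate count
        # (pg_class.reltuples on PostgreSQL, all_tables.num_rows on Oracle),
        # falling back to an exact COUNT(*) below if that fails.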
        try:
            if self.db.engine.driver in ('psycopg2', 'pg8000', ):
                schema = (self.schema + '.') if self.schema else ''
                qry = """SELECT reltuples FROM pg_class
                         WHERE oid = lower('%s%s')::regclass""" % (
                    schema, self.name.lower())
            elif 'oracle' in self.db.engine.driver:
                qry = """SELECT num_rows FROM all_tables
                         WHERE LOWER(table_name)='%s'""" % self.name.lower()
            else:
                raise NotImplementedError(
                    "No approximation known for driver %s" %
                    self.db.engine.driver)
            self.n_rows = self.db.conn.execute(qry).fetchone()[0]
        except Exception as e:
            logging.debug("failed to get approximate rowcount for %s\n%s" %
                          (self.name, str(e)))
    if not self.n_rows:
        self.n_rows = self.db.conn.execute(self.count()).fetchone()[0]


def _random_row_func(self):
    dialect = self.bind.engine.dialect.name
    if 'mysql' in dialect or 'mssql' in dialect:
        return sa.sql.func.rand()
    elif 'oracle' in dialect:
        return sa.sql.func.dbms_random.value()
    else:
        return sa.sql.func.random()


def _random_row_gen_fn(self):
    """
    Random sample of *approximate* size n
    """
    if self.n_rows:
        while True:
            n = self.target.n_rows_desired
            if self.n_rows > 1000:
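                # Filtering on random_row_func() < fraction yields roughly n
                # rows in a single scan, far cheaper than the full random sort
                # used for small tables below.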
                fraction = n / float(self.n_rows)
                qry = sa.sql.select([self, ]).where(self.random_row_func() <
                                                    fraction)
                results = self.db.conn.execute(qry).fetchall()
                # we may stop wanting rows at any point, so shuffle them so as not to
                # skew the sample toward those near the beginning
                random.shuffle(results)
                for row in results:
                    yield row
            else:
                qry = sa.sql.select([self, ]).order_by(
                    self.random_row_func()).limit(n)
                for row in self.db.conn.execute(qry):
                    yield row


def _next_row(self):
    if self.target.required:
        return self.target.required.popleft()
    elif self.target.requested:
        return self.target.requested.popleft()
    else:
        try:
            return (next(self.random_rows), False)  # not prioritized
        except StopIteration:
            return None


def _filtered_by(self, **kw):
    slct = sa.sql.select([self, ])
    slct = slct.where(sa.sql.and_((self.c[k] == v) for (k, v) in kw.items()))
    return slct


def _pk_val(self, row):
    if self.pk:
        return row[self.pk[0]]
    else:
        return None


def _by_pk(self, pk):
    pk_name = self.db.inspector.get_primary_keys(self.name, self.schema)[0]
    slct = self.filtered_by(**{pk_name: pk})
    return self.db.conn.execute(slct).fetchone()


def _completeness_score(self):
    """Scores how close a target table is to being filled enough to quit"""
    table = ((self.schema + ".") if self.schema else "") + self.name
    fetch_all = self.fetch_all
    requested = len(self.requested)
    required = len(self.required)
    n_rows = float(self.n_rows)
    n_rows_desired = float(self.n_rows_desired)
    logging.debug("%s.fetch_all      = %s", table, fetch_all)
    logging.debug("%s.requested      = %d", table, requested)
    logging.debug("%s.required       = %d", table, required)
    logging.debug("%s.n_rows         = %d", table, n_rows)
    logging.debug("%s.n_rows_desired = %d", table, n_rows_desired)
    if fetch_all:
        if n_rows < n_rows_desired:
            return 1 + (n_rows or 1) - (n_rows_desired or 1)
    result = 0 - (requested / (n_rows or 1)) - required
    if not self.required:  # anything in `required` queue disqualifies
        result += (n_rows / (n_rows_desired or 1))**0.33
    return result


def _table_matches_any_pattern(schema, table, patterns):
    """Test if the table `<schema>.<table>` matches any of the provided patterns.

    Will attempt to match both `schema.table` and just `table` against each pattern.

    Params:
        - schema.      Name of the schema the table belongs to.
        - table.       Name of the table.
        - patterns.    The patterns to try.
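
    Illustrative examples::

        >>> _table_matches_any_pattern('public', 'users', ['users'])
        True
        >>> _table_matches_any_pattern('public', 'users', ['public.*'])
        True
        >>> _table_matches_any_pattern('audit', 'users', ['public.*'])
        False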
    """
    qual_name = '{}.{}'.format(schema, table)
    return any(fnmatch.fnmatch(qual_name, each) or fnmatch.fnmatch(table, each)
               for each in patterns)


def _import_modules(import_list):
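    """Import each named module so its import-time side effects register.

    A signal-handler module passed via ``--import`` might look like this
    (an illustrative sketch, not shipped with rdbms-subsetter)::

        from blinker import signal

        def log_row(sender, source_row=None, target_db=None,
                    target_table=None, prioritized=False):
            print('added row to %s' % target_table.name)

        signal('row_added').connect(log_row)
    """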
    for module_name in import_list:
        __import__(module_name)


class Db(object):
    def __init__(self, sqla_conn, args, schemas=[None]):
        self.args = args
        self.sqla_conn = sqla_conn
        self.schemas = schemas
        self.engine = sa.create_engine(sqla_conn)
        self.inspector = Inspector(bind=self.engine)
        self.conn = self.engine.connect()
        self.tables = OrderedDict()

        for schema in self.schemas:
            meta = sa.MetaData(
                bind=self.engine)  # excised schema=schema to prevent errors
            meta.reflect(schema=schema)
            for tbl in meta.sorted_tables:
                if args.tables and not _table_matches_any_pattern(
                        tbl.schema, tbl.name, self.args.tables):
                    continue
                if _table_matches_any_pattern(tbl.schema, tbl.name,
                                              self.args.exclude_tables):
                    continue
                tbl.db = self

                if self.engine.name == 'postgresql':
                    fix_postgres_array_of_enum(self.conn, tbl)

                # TODO: Replace all these monkeypatches with an instance assignment
                tbl.find_n_rows = types.MethodType(_find_n_rows, tbl)
                tbl.random_row_func = types.MethodType(_random_row_func, tbl)
                tbl.fks = self.inspector.get_foreign_keys(tbl.name,
                                                          schema=tbl.schema)
                tbl.pk = self.inspector.get_primary_keys(tbl.name,
                                                         schema=tbl.schema)
                if not tbl.pk:
                    tbl.pk = [
                        d['name']
                        for d in self.inspector.get_columns(tbl.name,
                                                            schema=tbl.schema)
                    ]
                tbl.filtered_by = types.MethodType(_filtered_by, tbl)
                tbl.by_pk = types.MethodType(_by_pk, tbl)
                tbl.pk_val = types.MethodType(_pk_val, tbl)
                tbl.child_fks = []
                estimate_rows = not _table_matches_any_pattern(
                    tbl.schema, tbl.name, self.args.full_tables)
                tbl.find_n_rows(estimate=estimate_rows)
                self.tables[(tbl.schema, tbl.name)] = tbl
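        # "constraints" are extra FK-like relationships read from the --config
        # JSON file, keyed by table name or schema-qualified table name; an
        # illustrative entry:
        #   {"constraints": {"orders": [
        #       {"referred_schema": null, "referred_table": "users",
        #        "referred_columns": ["id"], "constrained_columns": ["user_id"]}]}}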
        all_constraints = args.config.get('constraints', {})
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            qualified = "{}.{}".format(tbl_schema, tbl_name)
            if qualified in all_constraints:
                constraints = all_constraints[qualified]
            else:
                constraints = all_constraints.get(tbl_name, [])
            tbl.constraints = constraints
            for fk in (tbl.fks + constraints):
                fk['constrained_schema'] = tbl_schema
                fk['constrained_table'] = tbl_name  # TODO: check against constrained_table
                self.tables[(fk['referred_schema'], fk['referred_table']
                             )].child_fks.append(fk)

    def __repr__(self):
        return "Db('%s')" % self.sqla_conn

    def assign_target(self, target_db):
        for ((tbl_schema, tbl_name), tbl) in self.tables.items():
            tbl._random_row_gen_fn = types.MethodType(_random_row_gen_fn, tbl)
            tbl.random_rows = tbl._random_row_gen_fn()
            tbl.next_row = types.MethodType(_next_row, tbl)
            target = target_db.tables[(tbl_schema, tbl_name)]
            target.requested = deque()
            target.required = deque()
            target.pending = dict()
            target.done = set()
            target.fetch_all = False
            if _table_matches_any_pattern(tbl.schema, tbl.name,
                                          self.args.full_tables):
                target.n_rows_desired = tbl.n_rows
                target.fetch_all = True
            else:
                if tbl.n_rows:
                    if self.args.logarithmic:
                        target.n_rows_desired = int(math.pow(10, math.log10(
                            tbl.n_rows) * self.args.fraction)) or 1
                    else:
                        target.n_rows_desired = int(tbl.n_rows *
                                                    self.args.fraction) or 1
                else:
                    target.n_rows_desired = 0
            target.source = tbl
            tbl.target = target
            target.completeness_score = types.MethodType(_completeness_score,
                                                         target)
            logging.debug("assigned methods to %s" % target.name)

    def confirm(self):
        message = []
        for (tbl_schema, tbl_name) in sorted(self.tables, key=lambda t: t[1]):
            tbl = self.tables[(tbl_schema, tbl_name)]
            message.append("Create %d rows from %d in %s.%s" %
                           (tbl.target.n_rows_desired, tbl.n_rows,
                            tbl_schema or '', tbl_name))
        print("\n".join(sorted(message)))
        if self.args.yes:
            return True
        response = input("Proceed? (Y/n) ").strip().lower()
        return (not response) or (response[0] == 'y')

    def create_row_in(self, source_row, target_db, target, prioritized=False):
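        """Copy ``source_row`` into ``target``, preserving referential integrity.

        Parent rows demanded by foreign keys are created first, recursively;
        then up to ``--children`` rows from each child table are queued.
        """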
        logging.debug('create_row_in %s:%s ' %
                      (target.name, target.pk_val(source_row)))

        pks = hashable((source_row[key] for key in target.pk))
        row_exists = pks in target.pending or pks in target.done
        logging.debug("Row exists? %s" % str(row_exists))
        if row_exists and not prioritized:
            return

        if not row_exists:
            # make sure that all required rows are in parent table(s)
            for fk in target.fks:
                target_parent = target_db.tables[(fk['referred_schema'], fk[
                    'referred_table'])]
                slct = sa.sql.select([target_parent, ])
                any_non_null_key_columns = False
                for (parent_col, child_col) in zip(fk['referred_columns'],
                                                   fk['constrained_columns']):
                    slct = slct.where(target_parent.c[parent_col] ==
                                      source_row[child_col])
                    if source_row[child_col] is not None:
                        any_non_null_key_columns = True
                        break
                if any_non_null_key_columns:
                    target_parent_row = target_db.conn.execute(slct).first()
                    if not target_parent_row:
                        source_parent_row = self.conn.execute(slct).first()
                        self.create_row_in(source_parent_row, target_db,
                                           target_parent)

            # make sure that all referenced rows are in referenced table(s)
            for constraint in target.constraints:
                target_referred = target_db.tables[(constraint[
                    'referred_schema'], constraint['referred_table'])]
                slct = sa.sql.select([target_referred, ])
                any_non_null_key_columns = False
                for (referred_col, constrained_col) in zip(
                        constraint['referred_columns'],
                        constraint['constrained_columns']):
                    slct = slct.where(target_referred.c[referred_col] ==
                                      source_row[constrained_col])
                    if source_row[constrained_col] is not None:
                        any_non_null_key_columns = True
                        break
                if any_non_null_key_columns:
                    target_referred_row = target_db.conn.execute(slct).first()
                    if not target_referred_row:
                        source_referred_row = self.conn.execute(slct).first()
                        # because constraints aren't enforced like real FKs, the referred row isn't guaranteed to exist
                        if source_referred_row:
                            self.create_row_in(source_referred_row, target_db,
                                               target_referred)

            pks = hashable((source_row[key] for key in target.pk))
            target.n_rows += 1

            if self.args.buffer == 0:
                target_db.insert_one(target, pks, source_row)
            else:
                target.pending[pks] = source_row
            signal(SIGNAL_ROW_ADDED).send(self,
                                          source_row=source_row,
                                          target_db=target_db,
                                          target_table=target,
                                          prioritized=prioritized)

        for child_fk in target.child_fks:
            child = self.tables[(child_fk['constrained_schema'], child_fk[
                'constrained_table'])]
            slct = sa.sql.select([child])
            for (child_col, this_col) in zip(child_fk['constrained_columns'],
                                             child_fk['referred_columns']):
                slct = slct.where(child.c[child_col] == source_row[this_col])
            if not prioritized:
                slct = slct.limit(self.args.children)
            for (n, desired_row) in enumerate(self.conn.execute(slct)):
                if prioritized:
                    child.target.required.append((desired_row, prioritized))
                elif n == 0:
                    child.target.requested.appendleft(
                        (desired_row, prioritized))
                else:
                    child.target.requested.append((desired_row, prioritized))

    @property
    def pending(self):
        return functools.reduce(
            lambda count, table: count + len(table.pending),
            self.tables.values(), 0)

    def insert_one(self, table, pk, values):
        self.conn.execute(table.insert(), values)
        table.done.add(pk)

    def flush(self):
        for table in self.tables.values():
            if not table.pending:
                continue
            self.conn.execute(table.insert(), list(table.pending.values()))
            table.done = table.done.union(table.pending.keys())
            table.pending = dict()

    def create_subset_in(self, target_db):

        for (tbl_name, pks) in self.args.force_rows.items():
            if '.' in tbl_name:
                (tbl_schema, tbl_name) = tbl_name.split('.', 1)
            else:
                tbl_schema = None
            source = self.tables[(tbl_schema, tbl_name)]
            for pk in pks:
                source_row = source.by_pk(pk)
                if source_row:
                    self.create_row_in(source_row,
                                       target_db,
                                       source.target,
                                       prioritized=True)
                else:
                    logging.warn("requested %s:%s not found in source db,"
                                 "could not create" % (source.name, pk))

        while True:
            targets = sorted(target_db.tables.values(),
                             key=lambda t: t.completeness_score())
            try:
                target = targets.pop(0)
                while not target.source.n_rows:
                    target = targets.pop(0)
            except IndexError:  # pop failure, no more tables
                break
            logging.debug("total n_rows in target: %d" %
                          sum((t.n_rows for t in target_db.tables.values())))
            logging.debug("target tables with 0 n_rows: %s" % ", ".join(
                t.name for t in target_db.tables.values() if not t.n_rows))
            logging.info("lowest completeness score (in %s) at %f" %
                         (target.name, target.completeness_score()))
            if target.completeness_score() > 0.97:
                break
            (source_row, prioritized) = target.source.next_row()
            self.create_row_in(source_row,
                               target_db,
                               target,
                               prioritized=prioritized)

            if target_db.pending > self.args.buffer > 0:
                target_db.flush()

        if self.args.buffer > 0:
            target_db.flush()


def update_sequences(source, target, schemas, tables, exclude_tables):
    """Set database sequence values to match the source db

       Needed to avoid subsequent unique key violations after DB build.
       Currently only implemented for postgresql -> postgresql."""

    if source.engine.name != 'postgresql' or target.engine.name != 'postgresql':
        return
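    # Find every sequence owned by a table column (pg_depend deptype 'a' is
    # the automatic dependency created for serial/identity columns) and build
    # a per-sequence "SELECT last_value" query to execute below.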
    qry = """SELECT 'SELECT last_value FROM ' || n.nspname ||
                     '.' || s.relname || ';' AS qry,
                    n.nspname || '.' || s.relname AS qual_name,
                    n.nspname AS schema, t.relname AS table
             FROM pg_class s
             JOIN pg_depend d ON (d.objid=s.oid AND d.classid='pg_class'::regclass AND d.refclassid='pg_class'::regclass)
             JOIN pg_class t ON (t.oid=d.refobjid)
             JOIN pg_namespace n ON (n.oid=t.relnamespace)
             WHERE s.relkind='S' AND d.deptype='a'"""

    for (seq_qry, qual_name, schema, table) in list(source.conn.execute(qry)):
        if schema not in schemas:
            continue
        if tables and not _table_matches_any_pattern(schema, table, tables):
            continue
        if _table_matches_any_pattern(schema, table, exclude_tables):
            continue
        (lastval, ) = source.conn.execute(seq_qry).first()
        nextval = int(lastval) + 1
        updater = "ALTER SEQUENCE %s RESTART WITH %s;" % (qual_name, nextval)
        target.conn.execute(updater)
    target.conn.execute('commit')


def fraction(n):
    n = float(n)
    if 0 <= n <= 1:
        return n
    raise argparse.ArgumentTypeError(
        'Fraction must be between 0 and 1, inclusive')


all_loglevels = "CRITICAL, FATAL, ERROR, DEBUG, INFO, WARN, WARNING"


def loglevel(raw):
    try:
        return int(raw)
    except ValueError:
        upper = raw.upper()
        if upper in all_loglevels.split(', '):
            return getattr(logging, upper)
        raise NotImplementedError('log level "%s" not one of %s' %
                                  (raw, all_loglevels))


argparser = argparse.ArgumentParser(
    description='Generate consistent subset of a database')
argparser.add_argument('source',
                       help='SQLAlchemy connection string for data origin',
                       type=str)
argparser.add_argument(
    'dest',
    help='SQLAlchemy connection string for data destination',
    type=str)
argparser.add_argument(
    'fraction',
    help='Proportion of rows to create in dest (0.0 to 1.0)',
    type=fraction)
argparser.add_argument(
    '-l',
    '--logarithmic',
    help='Cut row numbers logarithmically; try 0.5 for fraction',
    action='store_true')
argparser.add_argument(
    '-b',
    '--buffer',
    help=
    'Number of records to store in buffer before flush; use 0 for no buffer',
    type=int,
    default=1000)
argparser.add_argument('--loglevel',
                       type=loglevel,
                       help='log level (%s)' % all_loglevels,
                       default='INFO')
argparser.add_argument(
    '-f',
    '--force',
    help='<table name>:<primary_key_val> to force into dest',
    type=str.lower,
    action='append')
argparser.add_argument(
    '-c',
    '--children',
    help='Max number of child rows to attempt to pull for each parent row',
    type=int,
    default=3)
argparser.add_argument('--schema',
                       help='Non-default schema to include',
                       type=str,
                       action='append',
                       default=[])
argparser.add_argument('--config',
                       help='Path to configuration .json file',
                       type=argparse.FileType('r'))
argparser.add_argument('--table',
                       '-t',
                       dest='tables',
                       help='Include the named table(s) only',
                       type=str,
                       action='append',
                       default=[])
argparser.add_argument(
    '--exclude-table',
    '-T',
    dest='exclude_tables',
    help=
    'Tables to exclude. When both -t and -T are given, the behavior is to include just the tables that match at least one -t switch but no -T switches.',
    type=str,
    action='append',
    default=[])
argparser.add_argument('--full-table',
                       '-F',
                       dest='full_tables',
                       help='Tables to include every row of',
                       type=str,
                       action='append',
                       default=[])
argparser.add_argument(
    '--import',
    '-i',
    dest='import_list',
    help='Dotted module name to import; e.g. custom.signalhandler',
    type=str,
    action='append',
    default=[])
argparser.add_argument('-y',
                       '--yes',
                       help='Proceed without stopping for confirmation',
                       action='store_true')

log_format = "%(asctime)s %(levelname)-5s %(message)s"


def merge_config_args(args):
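    """Fold list-valued settings from the --config JSON file into the CLI args.

    An illustrative config file::

        {"tables": ["users*"], "schemas": ["audit"], "full_tables": ["states"]}
    """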
    args.tables.extend(args.config.get("tables", []))
    args.schema.extend(args.config.get("schemas", []))
    args.full_tables.extend(args.config.get("full_tables", []))


def generate():
    args = argparser.parse_args()
    _import_modules(args.import_list)
    args.force_rows = {}
    for force_row in (args.force or []):
        (table_name, pk) = force_row.split(':')
        if table_name not in args.force_rows:
            args.force_rows[table_name] = []
        args.force_rows[table_name].append(pk)
    logging.getLogger().setLevel(args.loglevel)
    logging.basicConfig(format=log_format)

    args.config = json.load(args.config) if args.config else {}
    merge_config_args(args)
    schemas = args.schema + [None, ]
    source = Db(args.source, args, schemas)
    target = Db(args.dest, args, schemas)
    if set(source.tables.keys()) != set(target.tables.keys()):
        raise Exception('Source and target databases have different tables')
    source.assign_target(target)
    if source.confirm():
        source.create_subset_in(target)
    update_sequences(source, target, schemas, args.tables, args.exclude_tables)


def hashable(raw):
    """If `raw` contains nested lists, convert them to tuples

    >>> hashable(('a', 'b', 'c'))
    ('a', 'b', 'c')
    >>> hashable(('a', ['b', 'c'], 'd'))
    ('a', ('b', 'c'), 'd')
    """

    result = tuple(hashable(itm) if isinstance(itm, list) else itm
                   for itm in raw)
    return result


if __name__ == '__main__':
    generate()