datahuborg/datahub

View on GitHub
src/core/db/backend/pg.py

Summary

Maintainability
F
2 wks
Test Coverage
import re
import os
import errno
import shutil
import hashlib
from collections import namedtuple
from uuid import uuid4
import psycopg2
import core.db.query_rewriter
from psycopg2.extensions import AsIs
from psycopg2.pool import ThreadedConnectionPool
from psycopg2 import errorcodes
from core.db.licensemanager import LicenseManager

from core.db.errors import PermissionDenied
from config import settings

"""
DataHub internal APIs for postgres repo_base
"""
HOST = settings.DATABASES['default']['HOST']
PORT = 5432

if settings.DATABASES['default']['PORT'] != '':
    try:
        PORT = int(settings.DATABASES['default']['PORT'])
    except:
        pass

# Maintain a separate db connection pool for each (user, password, database)
# tuple.
connection_pools = {}
PoolKey = namedtuple('PoolKey', 'user, password, repo_base')


def _pool_for_credentials(user, password, repo_base, create_if_missing=True):
    pool_key = PoolKey(user, password, repo_base)
    # Create a new pool if one doesn't exist or if the existing one has been
    # closed. Normally a pool should only be closed during testing, to force
    # all hanging connections to a database to be closed.
    if pool_key not in connection_pools or connection_pools[pool_key].closed:
        if create_if_missing is False:
            return None
        # Maintains at least 1 connection.
        # Raises "PoolError: connection pool exausted" if a thread tries
        # holding onto than 10 connections to a single database.
        connection_pools[pool_key] = ThreadedConnectionPool(
            0,
            10,
            user=user,
            password=password,
            host=HOST,
            port=PORT,
            database=repo_base)
    return connection_pools[pool_key]


def _close_all_connections(repo_base):
    for key, pool in connection_pools.iteritems():
        if repo_base == key.repo_base and not pool.closed:
            pool.closeall()


def _convert_pg_exception(e):
    # Convert some psycopg2 errors into exceptions meaningful to
    # Django.
    if (e.pgcode == errorcodes.INSUFFICIENT_PRIVILEGE):
        raise PermissionDenied()
    if (e.pgcode == errorcodes.INVALID_PARAMETER_VALUE or
            e.pgcode == errorcodes.UNDEFINED_OBJECT):
        raise ValueError("Invalid parameter in query.")
    if e.pgcode == errorcodes.INVALID_SCHEMA_NAME:
        error = ('Repo not found. '
                 'You must specify a repo in your query. '
                 'i.e. select * from REPO_NAME.TABLE_NAME. ')
        raise LookupError(error)
    if e.pgcode == errorcodes.UNDEFINED_TABLE:
        raise LookupError("Table or view not found.")
    if e.pgcode == errorcodes.DUPLICATE_SCHEMA:
        raise ValueError("A repo with that name already exists.")
    if e.pgcode == errorcodes.DUPLICATE_TABLE:
        raise ValueError("A table with that name already exists.")
    raise e


class PGBackend:

    def __init__(self, user, password, host=HOST, port=PORT, repo_base=None):
        self.user = user
        self.password = password
        self.host = host
        self.port = port
        self.repo_base = repo_base
        self.connection = None

        # row level security is enabled unless the user is a superuser
        self.row_level_security = bool(
            user != settings.DATABASES['default']['USER'])

        # We only need a query rewriter if RLS is enabled
        if self.row_level_security:
            self.query_rewriter = core.db.query_rewriter.SQLQueryRewriter(
                self.repo_base, self.user)

        self.__open_connection__()

    def __del__(self):
        self.close_connection()

    def __open_connection__(self):
        pool = _pool_for_credentials(self.user, self.password, self.repo_base)
        self.connection = pool.getconn()
        self.connection.set_isolation_level(
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)

    def change_repo_base(self, repo_base):
        self.close_connection()
        self.repo_base = repo_base
        self.__open_connection__()

    def close_connection(self):
        pool = _pool_for_credentials(self.user, self.password, self.repo_base,
                                     create_if_missing=False)
        if self.connection and pool and not pool.closed:
            pool.putconn(self.connection, close=True)
            self.connection = None

    def _check_for_injections(self, noun):
        """
        Raises ValueError if the proposed noun is invalid.

        Valid nouns contain only alphanumeric characters and underscores, and
        must not begin or end with an underscore.
        """
        invalid_noun_msg = (
            "Usernames and repo names may only contain "
            "alphanumeric characters and underscores, must begin with a "
            "letter, and must not begin or end with an underscore."
        )

        regex = r'^(?![\_\d])[\w\_]+(?<![\_])$'
        valid_pattern = re.compile(regex)
        matches = valid_pattern.match(noun)

        if matches is None:
            raise ValueError(invalid_noun_msg)

    def _validate_table_name(self, noun):
        """
        Raises ValueError if the proposed table name is invalid.

        Valid table names contain only alphanumeric characters and underscores.
        """
        invalid_noun_msg = (
            "Table names may only contain "
            "alphanumeric characters and underscores, must begin with a "
            "letter, and must not begin or end with an underscore."
        )

        regex = r'^(?![\d])[\w\_]+(?<![\_])$'
        valid_pattern = re.compile(regex)
        matches = valid_pattern.match(noun)

        if matches is None:
            raise ValueError(invalid_noun_msg)

    def create_repo(self, repo):
        """Creates a postgres schema for the user."""
        self._check_for_injections(repo)

        query = 'CREATE SCHEMA IF NOT EXISTS %s AUTHORIZATION %s'
        params = (AsIs(repo), AsIs(self.user))

        res = self.execute_sql(query, params)
        return res['status']

    def list_repos(self):
        query = ('SELECT schema_name AS repo_name '
                 'FROM information_schema.schemata '
                 'WHERE schema_owner != %s')
        params = (settings.DATABASES['default']['USER'],)
        res = self.execute_sql(query, params)
        return [t[0] for t in res['tuples']]

    def rename_repo(self, repo, new_name):
        self._check_for_injections(repo)
        self._check_for_injections(new_name)

        query = 'ALTER SCHEMA %s RENAME TO %s'
        params = (AsIs(repo), AsIs(new_name))
        res = self.execute_sql(query, params)
        return res['status']

    def delete_repo(self, repo, force=False):
        """Deletes a repo and the folder the user's repo files are in."""
        self._check_for_injections(repo)

        # drop the schema
        query = 'DROP SCHEMA %s %s'
        params = (AsIs(repo), AsIs('CASCADE') if force is True else AsIs(''))
        res = self.execute_sql(query, params)
        return res['status']

    def add_collaborator(
            self, repo, collaborator, db_privileges=[], license_id=None):
        # check that all repo names, usernames, and privileges passed aren't
        # sql injections
        self._check_for_injections(repo)
        self._check_for_injections(collaborator)
        for privilege in db_privileges:
            self._check_for_injections(privilege)

        if license_id:
            query = (
                    'BEGIN;'
                    'GRANT USAGE ON SCHEMA %s TO %s;'
                    'COMMIT;')
            privileges_str = ', '.join(db_privileges)
            params = [repo, collaborator, privileges_str, repo,
                      collaborator, repo, privileges_str, collaborator]
            params = tuple(map(lambda x: AsIs(x), params))
            res = self.execute_sql(query, params)
            return res['status']
        else:
            query = ('BEGIN;'
                     'GRANT USAGE ON SCHEMA %s TO %s;'
                     'GRANT %s ON ALL TABLES IN SCHEMA %s TO %s;'
                     'ALTER DEFAULT PRIVILEGES IN SCHEMA %s '
                     'GRANT %s ON TABLES TO %s;'
                     'COMMIT;'
                     )

            privileges_str = ', '.join(db_privileges)
            params = [repo, collaborator, privileges_str, repo,
                      collaborator, repo, privileges_str, collaborator]
            params = tuple(map(lambda x: AsIs(x), params))
            res = self.execute_sql(query, params)
            return res['status']

    def add_collaborator_to_license_view(
            self, repo, collaborator, view, db_privileges=[]):
        # check that all repo names, usernames, and privileges passed aren't
        # sql injections
        self._check_for_injections(repo)
        self._check_for_injections(collaborator)
        for privilege in db_privileges:
            self._check_for_injections(privilege)

        query = ('BEGIN;'
                 'GRANT %s ON %s.%s TO %s;'
                 'ALTER DEFAULT PRIVILEGES IN SCHEMA %s '
                 'GRANT %s ON %s.%s TO %s;'
                 'COMMIT;'
                 )

        privileges_str = ', '.join(db_privileges)
        params = [
                privileges_str, repo, view,
                collaborator, repo, privileges_str,
                repo, view, collaborator]

        params = tuple(map(lambda x: AsIs(x), params))
        res = self.execute_sql(query, params)
        return res['status']

    def delete_collaborator(self, repo, collaborator):
        self._check_for_injections(repo)
        self._check_for_injections(collaborator)

        query = ('BEGIN;'
                 'REVOKE ALL ON ALL TABLES IN SCHEMA %s FROM %s CASCADE;'
                 'REVOKE ALL ON SCHEMA %s FROM %s CASCADE;'
                 'ALTER DEFAULT PRIVILEGES IN SCHEMA %s '
                 'REVOKE ALL ON TABLES FROM %s;'
                 'COMMIT;'
                 )
        params = [repo, collaborator, repo, collaborator, repo, collaborator]
        params = tuple(map(lambda x: AsIs(x), params))

        res = self.execute_sql(query, params)
        return res['status']

    def create_license_view(self, repo_base, repo,
                            table, view_sql, license_id):
        view_name = table.lower() + "_license_view_"+str(license_id)
        res = self.create_view(repo, view_name, view_sql)

        return res

    def delete_license_view(self, repo_base, repo, license_view):
        res = self.delete_view(repo, license_view)

        return res

    def get_view_sql(self, repo_base, repo, table, view_params, license_id):
        # create view based on license
        license = LicenseManager.find_license_by_id(license_id)

        pii_def = license.pii_def

        if license.pii_removed:
            # remove columns
            query = ('SELECT column_name FROM information_schema.columns '
                     'WHERE table_schema = %s'
                     'AND table_name = %s'
                     )
            params = (repo, table)
            res = self.execute_sql(query, params)

            columns = [t[0] for t in res['tuples']]

            all_columns = set(columns)
            removed_columns = set(view_params['removed-columns'])
            columns_to_show = list(all_columns - removed_columns)
            # if columns_to_show < 1:
            #     #error
            #     pass

            query = 'SELECT {} FROM {}.{}'

            columns_query = ""
            for i in range(len(columns_to_show)):
                columns_query += columns_to_show[i]
                if i < len(columns_to_show) - 1:
                    columns_query += ","

            query = query.format(columns_query, repo, table)

            return query

    def create_table(self, repo, table, params):
        # check for injections
        self._check_for_injections(repo)
        self._validate_table_name(table)
        param_values = []
        for obj in params:
            param_values += obj.values()
        for value in param_values:
            self._check_for_injections(value)

        query = ('CREATE TABLE %s.%s (%s)')

        table_params = ''
        for obj in params:
            table_params += obj['column_name']
            table_params += ' '
            table_params += obj['data_type']
            table_params += ', '

        table_params = table_params[:-2]

        params = (AsIs(repo), AsIs(table), AsIs(table_params))
        res = self.execute_sql(query, params)

        return res['status']

    def list_tables(self, repo):
        self._check_for_injections(repo)

        all_repos = self.list_repos()

        if repo not in all_repos:
            raise LookupError('Invalid repository name: %s' % (repo))

        query = ('SELECT table_name FROM information_schema.tables '
                 'WHERE table_schema = %s AND table_type = \'BASE TABLE\';'
                 )
        params = (repo,)
        res = self.execute_sql(query, params)
        return [t[0] for t in res['tuples']]

    def describe_table(self, repo, table, detail=False):
        query = ("SELECT %s "
                 "FROM information_schema.columns "
                 "WHERE table_schema = %s and table_name = %s;")
        params = None
        if detail:
            params = (AsIs('*'), repo, table)
        else:
            params = (AsIs('column_name, data_type'), repo, table)

        res = self.execute_sql(query, params)
        return res['tuples']

    def list_table_permissions(self, repo, table):
        query = ("select privilege_type from "
                 "information_schema.role_table_grants where table_schema=%s "
                 "and table_name=%s and grantee=%s")
        params = (repo, table, self.user)
        res = self.execute_sql(query, params)
        return res['tuples']

    def create_view(self, repo, view, sql):
        self._check_for_injections(repo)
        self._validate_table_name(view)
        query = ('CREATE VIEW %s.%s AS (%s)')

        params = (
                AsIs(repo), AsIs(view),
                AsIs(sql))

        res = self.execute_sql(query, params)

        return res['status']

    def list_views(self, repo):
        self._check_for_injections(repo)

        all_repos = self.list_repos()
        if repo not in all_repos:
            raise LookupError('Invalid repository name: %s' % (repo))

        query = ('SELECT table_name FROM information_schema.tables '
                 'WHERE table_schema = %s '
                 'AND table_type = \'VIEW\';')

        params = (repo,)

        res = self.execute_sql(query, params)

        return [t[0] for t in res['tuples']]

    def delete_view(self, repo, view, force=False):
        self._check_for_injections(repo)
        self._validate_table_name(view)

        force_param = 'RESTRICT'
        if force:
            force_param = 'CASCADE'

        query = ('DROP VIEW %s.%s.%s %s')
        params = (AsIs(self.repo_base), AsIs(repo), AsIs(view),
                  AsIs(force_param))

        res = self.execute_sql(query, params)
        return res['status']

    def describe_view(self, repo, view, detail=False):
        query = ("SELECT %s "
                 "FROM information_schema.columns "
                 "WHERE table_schema = %s and table_name = %s;")

        params = None
        if detail:
            params = (AsIs('*'), repo, view)
        else:
            params = (AsIs('column_name, data_type'), repo, view)

        res = self.execute_sql(query, params)

        return res['tuples']

    def delete_table(self, repo, table, force=False):
        self._check_for_injections(repo)
        self._validate_table_name(table)

        force_param = 'RESTRICT'
        if force:
            force_param = 'CASCADE'

        query = ('DROP TABLE %s.%s.%s %s')
        params = (AsIs(self.repo_base), AsIs(repo), AsIs(table),
                  AsIs(force_param))

        res = self.execute_sql(query, params)
        return res['status']

    def clone_table(self, repo, table, new_table):
        self._validate_table_name(table)
        self._validate_table_name(new_table)

        query = 'CREATE TABLE %s.%s AS SELECT * FROM %s.%s'
        params = (AsIs(repo), AsIs(new_table), AsIs(repo), AsIs(table))

        res = self.execute_sql(query, params)
        return res['status']

    def get_schema(self, repo, table):
        self._check_for_injections(repo)
        self._validate_table_name(table)

        query = ('SELECT column_name, data_type '
                 'FROM information_schema.columns '
                 'WHERE table_name = %s '
                 'AND table_schema = %s;'
                 )

        params = (table, repo)
        res = self.execute_sql(query, params)

        if res['row_count'] < 1:
            raise NameError("Invalid reference: '%s.%s'.\n" % (repo, table))

        # return will look like [('id', 'integer'), ('words', 'text')]
        return res['tuples']

    def explain_query(self, query):
        """
        returns the number of rows, the cost (in time) to execute,
        and the width (bytes) of rows outputted
        """
        # if it's a select query, return a different set of defaults
        select_query = bool((query.split()[0]).lower() == 'select')

        if not select_query:
            response = {'num_rows': 1, 'time_cost': 0, 'byte_width': 0}
            return response

        query = 'EXPLAIN %s' % (query)
        res = self.execute_sql(query)

        num_rows = re.match(r'.*rows=(\d+).*', res['tuples'][0][0]).group(1)
        byte_width = re.match(r'.*width=(\d+).*', res['tuples'][0][0]).group(1)

        time_cost_re = re.match(
            r'.*cost=(\d+.\d+)..(\d+.\d+)*', res['tuples'][0][0])
        time_cost = (float(time_cost_re.group(1)),
                     float(time_cost_re.group(2)))

        response = {'num_rows': int(num_rows),
                    'time_cost': time_cost,
                    'byte_width': int(byte_width)
                    }
        return response

    def limit_and_offset_select_query(self, query, limit, offset):
        query = query.strip().rstrip(';')

        # is it a select query?
        select_query = False
        if (query.split()[0]).lower() == 'select':
            select_query = True

        # return select query
        if select_query:
            query = ('select * from ( %s ) '
                     'as BXCQWVPEMWVKFBEBNKZSRPYBSB '
                     'LIMIT %s OFFSET %s;'
                     % (query, limit, offset))

        return {'select_query': select_query, 'query': query}

    def select_table_query(self, repo_base, repo, table):
        dh_table_name = '%s.%s.%s' % (repo_base, repo, table)
        query = 'SELECT * FROM %s;' % (dh_table_name)
        return query

    def execute_sql(self, query, params=None):
        result = {
            'status': False,
            'row_count': 0,
            'tuples': [],
            'fields': []
        }

        query = query.strip()
        cur = self.connection.cursor()

        try:
            sql_query = cur.mogrify(query, params)
            if self.row_level_security:
                sql_query = self.query_rewriter.apply_row_level_security(
                    sql_query)
            cur.execute(sql_query)
        except psycopg2.Error as e:
            # Convert some psycopg2 errors into exceptions meaningful to
            # Django.
            _convert_pg_exception(e)

        # if cur.execute() failed, this will print it.
        try:
            result['tuples'] = cur.fetchall()
        except psycopg2.ProgrammingError:
            # print "possible psycopg2.ProgrammingError in pg.execute_sql: "
            # print(e)
            pass

        result['status'] = True
        result['row_count'] = cur.rowcount
        if cur.description:
            result['fields'] = [
                {'name': col[0], 'type': col[1]} for col in cur.description]

        cur.close()
        return result

    def user_exists(self, username):
        query = "SELECT 1 FROM pg_roles WHERE rolname=%s"
        params = (username,)
        result = self.execute_sql(query, params)
        return (result['row_count'] > 0)

    def database_exists(self, db_name):
        query = "SELECT 1 FROM pg_database WHERE datname=%s"
        params = (db_name,)
        result = self.execute_sql(query, params)
        return (result['row_count'] > 0)

    def create_user(self, username, password, create_db=True):
        self._check_for_injections(username)

        query = ('CREATE ROLE %s WITH LOGIN '
                 'NOCREATEDB NOCREATEROLE NOCREATEUSER PASSWORD %s')
        params = (AsIs(username), password)
        self.execute_sql(query, params)

        # Don't do this in the case of the public user.
        if username != settings.PUBLIC_ROLE:
            query = ('GRANT %s to %s')
            params = (AsIs(settings.PUBLIC_ROLE), AsIs(username))
            self.execute_sql(query, params)

        if create_db:
            return self.create_user_database(username)

    def create_user_database(self, username):
        # lines need to be executed seperately because
        # "CREATE DATABASE cannot be executed from a
        # function or multi-command string"
        self._check_for_injections(username)

        query = 'CREATE DATABASE %s; '
        params = (AsIs(username),)
        self.execute_sql(query, params)

        query = 'ALTER DATABASE %s OWNER TO %s; '
        params = (AsIs(username), AsIs(username))
        return self.execute_sql(query, params)

    def remove_user(self, username):
        self._check_for_injections(username)

        query = 'DROP ROLE %s;'
        params = (AsIs(username),)
        return self.execute_sql(query, params)

    def drop_owned_by(self, username):
        self._check_for_injections(username)
        query = 'DROP OWNED BY %s CASCADE;' % (username)
        params = (AsIs(username), )
        return self.execute_sql(query, params)

    def list_all_users(self):
        query = 'SELECT usename FROM pg_catalog.pg_user WHERE usename != %s'
        params = (self.user,)
        res = self.execute_sql(query, params)
        user_tuples = res['tuples']

        all_users_list = []
        for user_tuple in user_tuples:
            all_users_list.append(user_tuple[0])

        return all_users_list

    def list_all_databases(self):
        query = ('SELECT datname FROM pg_database where datname NOT IN '
                 ' (%s, \'template1\', \'template0\', '
                 ' \'datahub\', \'test_datahub\', \'postgres\');'
                 )
        params = (self.user, )
        res = self.execute_sql(query, params)
        db_tuples = res['tuples']

        all_db_list = []
        for db_tuple in db_tuples:
            all_db_list.append(db_tuple[0])

        return all_db_list

    def remove_database(self, database, revoke_collaborators=True):
        self._check_for_injections(database)

        # remove collaborator access to the database
        if revoke_collaborators:
            all_users = self.list_all_users()

            for user in all_users:
                query = "REVOKE ALL ON DATABASE %s FROM %s;"
                params = (AsIs(database), AsIs(user))
                self.execute_sql(query, params)

        # Make sure to close all extant connections to this database or the
        # drop will fail.
        _close_all_connections(database)

        # drop database
        query = 'DROP DATABASE %s;'
        params = (AsIs(database),)
        try:
            return self.execute_sql(query, params)
        except psycopg2.ProgrammingError as e:
            print(e)
            print('this probably happened because the postgres role'
                  'exists, but a database of the same name does not.')

    def change_password(self, username, password):
        self._check_for_injections(username)
        query = 'ALTER ROLE %s WITH PASSWORD %s;'
        params = (AsIs(username), password)
        return self.execute_sql(query, params)

    def list_collaborators(self, repo):
        query = 'SELECT unnest(nspacl) FROM pg_namespace WHERE nspname=%s;'
        params = (repo, )
        res = self.execute_sql(query, params)

        # postgres privileges
        # r -- SELECT ("read")
        # w -- UPDATE ("write")
        # a -- INSERT ("append")
        # d -- DELETE
        # D -- TRUNCATE
        # x -- REFERENCES
        # t -- TRIGGER
        # X -- EXECUTE
        # U -- USAGE
        # C -- CREATE
        # c -- CONNECT
        # T -- TEMPORARY
        # arwdDxt -- ALL PRIVILEGES (for tables, varies for other objects)
        # * -- grant option for preceding privilege
        # /yyyy -- role that granted this privilege

        collaborators = []
        for row in res['tuples']:
            # for reference, rows look like this:
            # ('username=UC/repo_base',)

            collab_obj = {}
            username = row[0].split('=')[0].strip()
            permissions = row[0].split('=')[1].split('/')[0]

            collab_obj['username'] = username
            collab_obj['db_permissions'] = permissions

            collaborators.append(collab_obj)

        return collaborators

    def has_base_privilege(self, login, privilege):
        """
        returns True or False for whether the user has privileges for the
        repo_base (database)
        """
        query = 'SELECT has_database_privilege(%s, %s);'
        params = (login, privilege)
        res = self.execute_sql(query, params)
        return res['tuples'][0][0]

    def has_repo_db_privilege(self, login, repo, privilege):
        """
        returns True or False for whether the use has privileges for the
        repo (schema)
        """
        query = 'SELECT has_schema_privilege(%s, %s, %s);'
        params = (login, repo, privilege)
        res = self.execute_sql(query, params)
        return res['tuples'][0][0]

    def has_table_privilege(self, login, table, privilege):
        query = 'SELECT has_table_privilege(%s, %s, %s);'
        params = (login, table, privilege)
        res = self.execute_sql(query, params)
        return res['tuples'][0][0]

    def has_column_privilege(self, login, table, column, privilege):
        query = 'SELECT has_column_privilege(%s, %s, %s, %s);'
        params = (login, table, column, privilege)
        res = self.execute_sql(query, params)
        return res['tuples'][0][0]

    def export_table(self, table_name, file_path, file_format='CSV',
                     delimiter=',', header=True):
        words = table_name.split('.')
        for word in words[:-1]:
            self._check_for_injections(word)
        self._validate_table_name(words[-1])

        self._check_for_injections(file_format)

        query = 'SELECT * FROM %s' % table_name
        self.export_query(
            query,
            file_path,
            file_format=file_format,
            delimiter=delimiter,
            header=header)

    def export_view(self, view_name, file_path, file_format='CSV',
                    delimiter=',', header=True):
        words = view_name.split('.')
        for word in words[:-1]:
            self._check_for_injections(word)
        self._validate_table_name(words[-1])

        self._check_for_injections(file_format)

        query = 'SELECT * FROM %s' % view_name
        self.export_query(
            query,
            file_path,
            file_format=file_format,
            delimiter=delimiter,
            header=header)

    def export_query(self, query, file_path, file_format='CSV',
                     delimiter=',', header=True):
        """
        Runs a query as the current user and saves the result to a file.

        query can be a sql query or table reference.
        """
        header_option = 'HEADER' if header else ''
        query = query.split(';')[0].strip()

        self._check_for_injections(file_format)
        self._check_for_injections(header_option)

        meta_query = 'COPY (%s) TO STDOUT WITH %s %s DELIMITER %s;'
        params = (AsIs(query), AsIs(file_format),
                  AsIs(header_option), delimiter)

        cur = self.connection.cursor()
        query = cur.mogrify(meta_query, params)

        # Store pending exports in a temporary location so they're aren't
        # discoverable while being exported.
        tmp_path = '/tmp/user_exports/{0}-{1}'.format(
            uuid4().hex, hashlib.sha256(query).hexdigest())
        try:
            os.makedirs('/tmp/user_exports')
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise e

        try:
            with open(tmp_path, 'w') as f:
                cur.copy_expert(query, f)
        except psycopg2.Error as e:
            # Delete the temporary files of failed exports.
            os.remove(tmp_path)
            _convert_pg_exception(e)
        finally:
            cur.close()
        # Move successful exports into the user's data folder.
        # os.rename() would fail here if /tmp and /user_data are stored on
        # different filesystems, so use shutil.move() instead.
        shutil.move(tmp_path, file_path)

    def import_file(self, table_name, file_path, file_format='CSV',
                    delimiter=',', header=True, encoding='ISO-8859-1',
                    quote_character='"'):

        header_option = 'HEADER' if header else ''

        words = table_name.split('.')
        for word in words[:-1]:
            self._check_for_injections(word)
        self._validate_table_name(words[-1])
        self._check_for_injections(file_format)

        query = 'COPY %s FROM %s WITH %s %s DELIMITER %s ENCODING %s QUOTE %s;'
        params = (AsIs(table_name), file_path, AsIs(file_format),
                  AsIs(header_option), delimiter, encoding, quote_character)
        try:
            self.execute_sql(query, params)
        except Exception as e:
            self.execute_sql('DROP TABLE IF EXISTS %s', (AsIs(table_name),))
            raise ImportError(e)

            # Try importing using dbtruck. Was never enabled by anant.
            # RogerTangos 2015-12-09
            # return self.import_file_w_dbtruck(table_name, file_path)

    def import_rows(self, repo, table, rows, delimiter=',', header=False):
        # if there was a header, remove it
        if header:
            rows = rows[1:len(rows)]

        query = 'INSERT INTO %s.%s values '
        params = [AsIs(repo), AsIs(table)]

        # prepare query
        all_row_array = []
        for row in rows:

            # split the string into an array
            row = row.split(delimiter)

            # add the objects to params
            params += row

            # turn every item in the array into a %s
            # and make a string out of it
            row_array = ['%s' for c in row]
            row_array = ', '.join(row_array)
            row_array = '(' + row_array + ')'

            all_row_array.append(row_array)

        all_row_string = ', '.join(all_row_array)

        # finalize the query and params
        query += all_row_string
        params = tuple(params)

        res = self.execute_sql(query, params)
        return res['status']

    def import_file_w_dbtruck(self, table_name, file_path):
        # dbtruck is not tested for safety. At all. It's currently disabled
        # in the project RogerTangos 2015-12-09
        from dbtruck.dbtruck import import_datafiles
        # from dbtruck.util import get_logger
        from dbtruck.exporters.pg import PGMethods

        dbsettings = {
            'dbname': self.repo_base,
            'hostname': self.host,
            'username': self.user,
            'password': self.password,
            'port': self.port,
        }

        create_new = True
        errfile = None

        return import_datafiles([file_path], create_new, table_name, errfile,
                                PGMethods, **dbsettings)
    # Methods for Licenses

    def create_license_schema(self):
        public_role = settings.PUBLIC_ROLE
        schema = settings.LICENSE_SCHEMA
        self._check_for_injections(public_role)
        self._check_for_injections(schema)
        query = 'CREATE SCHEMA IF NOT EXISTS %s AUTHORIZATION %s'
        params = (AsIs(schema), AsIs(public_role))

        return self.execute_sql(query, params)

    def create_license_table(self):
        schema = settings.LICENSE_SCHEMA
        table = settings.LICENSE_TABLE
        public_role = settings.PUBLIC_ROLE

        self._check_for_injections(schema)
        self._validate_table_name(table)
        self._check_for_injections(public_role)

        query = ('CREATE TABLE IF NOT EXISTS %s.%s'
                 '(license_id serial primary key,'
                 'license_name VARCHAR(40),'
                 'pii_def VARCHAR(100) NOT NULL,'
                 'pii_removed boolean NOT NULL,'
                 'pii_anonymized boolean NOT NULL);')
        params = (AsIs(schema), AsIs(table))
        self.execute_sql(query, params)

        # grant the public role access to the table
        query = ('GRANT ALL ON %s.%s to %s;')
        params = (AsIs(schema), AsIs(table), AsIs(public_role))

        return self.execute_sql(query, params)

    def create_license_link_table(self):
        schema = settings.LICENSE_LINK_SCHEMA
        table = settings.LICENSE_LINK_TABLE
        public_role = settings.PUBLIC_ROLE

        self._check_for_injections(schema)
        self._validate_table_name(table)
        self._check_for_injections(public_role)

        query = ('CREATE TABLE IF NOT EXISTS %s.%s '
                 '(license_link_id serial primary key,'
                 'repo_base VARCHAR(40) NOT NULL,'
                 'repo VARCHAR(40) NOT NULL,'
                 'license_id integer NOT NULL);')
        params = (AsIs(schema), AsIs(table))
        self.execute_sql(query, params)

        query = ('GRANT ALL ON %s.%s to %s;')
        params = (AsIs(schema), AsIs(table), AsIs(public_role))

        return self.execute_sql(query, params)

    def create_license(
            self, license_name, pii_def, pii_anonymized, pii_removed):
        '''
        Creates a new license
        '''
        query = (
                'INSERT INTO dh_public.license '
                '(license_name, pii_def, pii_anonymized, pii_removed) '
                'values (%s, %s, %s, %s)')
        params = (license_name, pii_def, pii_anonymized, pii_removed)

        res = self.execute_sql(query, params)

        return res['status']

    def create_license_link(self, repo_base, repo, license_id):
        '''
        Creates a new license
        '''

        # check if link already exists
        query = ('SELECT license_link_id, repo_base, repo, license_id '
                 'FROM %s.%s where '
                 'repo_base = %s and repo = %s and license_id = %s;')
        params = (
                AsIs(settings.LICENSE_SCHEMA),
                AsIs(settings.LICENSE_LINK_TABLE),
                repo_base, repo, license_id)

        res = self.execute_sql(query, params)

        if res['tuples']:
            return res['status']

        query = (
                'INSERT INTO dh_public.license_link '
                '(repo_base, repo, license_id) '
                'values (%s, %s, %s)')
        params = (repo_base, repo, license_id)

        res = self.execute_sql(query, params)

        return res['status']

    def find_license_links(self, license_id):
        '''
        finds all license_links associated with a given license_id
        '''
        query = ('SELECT license_link_id, repo_base, repo, license_id '
                 'FROM %s.%s ;')
        params = (
                AsIs(settings.LICENSE_SCHEMA),
                AsIs(settings.LICENSE_LINK_TABLE))
        res = self.execute_sql(query, params)

        if not res['tuples']:
            return []

        return res['tuples']

    def find_license_links_by_repo(self, repo_base, repo):
        query = ('SELECT license_link_id, repo_base, repo, license_id '
                 'FROM %s.%s where repo_base = %s and repo = %s;')
        params = (
                AsIs(settings.LICENSE_SCHEMA),
                AsIs(settings.LICENSE_LINK_TABLE),
                repo_base, repo)
        res = self.execute_sql(query, params)

        if not res['tuples']:
            return []

        return res['tuples']

    def find_licenses(self):
        '''
        find all licenses
        '''
        query = (
                'SELECT license_id, license_name, pii_def, '
                'pii_anonymized, pii_removed FROM %s.%s;')
        params = (AsIs(settings.LICENSE_SCHEMA), AsIs(settings.LICENSE_TABLE))

        res = self.execute_sql(query, params)
        return res['tuples']

    def find_license_by_id(self, license_id):
        query = (
                'SELECT license_id, license_name, pii_def, '
                'pii_anonymized, pii_removed '
                'FROM %s.%s where license_id= %s;')
        params = (
                AsIs(settings.LICENSE_SCHEMA),
                AsIs(settings.LICENSE_TABLE),
                license_id)

        res = self.execute_sql(query, params)

        # return None if the list is empty
        if not res['tuples']:
            return None

        # else, return the policy
        return res['tuples'][0]

    # Below methods can only be called from the RLSSecurityManager #

    def create_security_policy_schema(self):
        public_role = settings.PUBLIC_ROLE
        schema = settings.POLICY_SCHEMA
        self._check_for_injections(public_role)
        self._check_for_injections(schema)

        query = 'CREATE SCHEMA IF NOT EXISTS %s AUTHORIZATION %s'
        params = (AsIs(schema), AsIs(public_role))
        return self.execute_sql(query, params)

    def create_security_policy_table(self):
        schema = settings.POLICY_SCHEMA
        table = settings.POLICY_TABLE
        public_role = settings.PUBLIC_ROLE

        self._check_for_injections(schema)
        self._validate_table_name(table)
        self._check_for_injections(public_role)

        query = ('CREATE TABLE IF NOT EXISTS %s.%s'
                 '('
                 'policy_id serial primary key,'
                 'policy VARCHAR(80) NOT NULL,'
                 'policy_type VARCHAR(80) NOT NULL,'
                 'grantee VARCHAR(80) NOT NULL,'
                 'grantor VARCHAR(80) NOT NULL,'
                 'table_name VARCHAR(80) NOT NULL,'
                 'repo VARCHAR(80) NOT NULL,'
                 'repo_base VARCHAR(80) NOT NULL'
                 ');')
        params = (AsIs(schema), AsIs(table))
        self.execute_sql(query, params)

        # create indexes for faster seraching
        query = ('create index grantee_index on '
                 'dh_public.policy using hash(grantee); '

                 'create index grantor_index on '
                 'dh_public.policy using hash(grantor); '

                 'create index table_name_index on '
                 'dh_public.policy using hash(table_name); '

                 'create index repo_index on '
                 'dh_public.policy using hash(repo); '

                 'create index repo_base_index on '
                 'dh_public.policy using hash(repo_base);')

        # postgres 9.4 doesn't support IF NOT EXISTS when creating indexes
        # so it's possible for tests to attempt to create duplicate indexes
        # This catches that exception
        try:
            self.execute_sql(query)
        except:
            pass

        # grant the public role access to the table
        query = ('GRANT ALL ON %s.%s to %s;')
        params = (AsIs(schema), AsIs(table), AsIs(public_role))
        return self.execute_sql(query, params)

    def create_security_policy(self, policy, policy_type, grantee, grantor,
                               repo_base, repo, table):
        '''
        Creates a new security policy in the policy table if the policy
        does not yet exist.
        '''

        # disallow semicolons in policy. This helps prevent the policy creator
        # from shooting themself in the foot with an attempted sql injection.
        # Note that we don't actually _need_ to do this. The parameters are all
        # escaped in RLS methods executed by the superuser, so there's not a
        # really a risk of a user acquiring root access.
        if ';' in policy:
            raise ValueError("\';'s are disallowed in the policy field")

        query = ('INSERT INTO dh_public.policy (policy, policy_type, grantee, '
                 'grantor, table_name, repo, repo_base) values '
                 '(%s, %s, %s, %s, %s, %s, %s)')
        params = (policy, policy_type, grantee, grantor, table, repo,
                  repo_base)

        res = self.execute_sql(query, params)

        return res['status']

    def find_all_security_policies(self, username):
        params = (username, username)

        query = ('SELECT policy_id, policy, policy_type, grantee, grantor '
                 'FROM dh_public.policy WHERE grantee = %s or '
                 'grantor = %s')

        res = self.execute_sql(query, params)
        return res['tuples']

    def find_security_policies(self, repo_base, repo=None, table=None,
                               policy_id=None, policy=None, policy_type=None,
                               grantee=None, grantor=None):
        '''
        Returns a list of all security polices that match the inputs specied
        by the user.
        '''
        query = ('SELECT policy_id, policy, policy_type, grantee, grantor '
                 'repo_base, repo, table_name '
                 'FROM %s.%s WHERE ')
        params = [AsIs(settings.POLICY_SCHEMA), AsIs(settings.POLICY_TABLE)]
        conditions = []

        # append mandatory passed-in conditions
        conditions.append('repo_base = %s')
        params.append(repo_base)

        # append optional conditions
        if repo:
            conditions.append('repo = %s')
            params.append(repo)
        if table:
            conditions.append('table_name = %s')
            params.append(table)
        if policy_id:
            conditions.append('policy_id = %s')
            params.append(policy_id)
        if policy:
            conditions.append('policy = %s')
            params.append(policy)
        if policy_type:
            conditions.append('policy_type = %s')
            params.append(policy_type)
        if grantee:
            conditions.append('grantee = %s')
            params.append(grantee)
        if grantor:
            conditions.append('grantor = %s')
            params.append(grantor)

        conditions = " and ".join(conditions)
        params = tuple(params)
        query += conditions

        res = self.execute_sql(query, params)

        return res['tuples']

    def find_security_policy_by_id(self, policy_id):
        '''
        Returns the security policy that has a policy_id matching the input
        specified by the user.
        '''
        query = ('SELECT policy_id, policy, policy_type, grantee, grantor, '
                 'repo_base, repo, table_name '
                 'FROM dh_public.policy WHERE policy_id = %s')
        params = (policy_id,)
        res = self.execute_sql(query, params)

        # return None if the list is empty
        if not res['tuples']:
            return None

        # else, return the policy
        return res['tuples'][0]

    def update_security_policy(self, policy_id, new_policy, new_policy_type,
                               new_grantee):
        '''
        Updates an existing security policy based on the inputs specified
        by the user.
        '''
        query = ('UPDATE dh_public.policy '
                 'SET policy = %s, policy_type = %s, '
                 'grantee = %s '
                 'WHERE policy_id = %s')
        params = (new_policy, new_policy_type, new_grantee, policy_id)

        res = self.execute_sql(query, params)
        return res['status']

    def remove_security_policy(self, policy_id):
        '''
        Removes the security policy from the policy table with a policy_id
        matching the one specified.
        '''
        query = 'DELETE FROM dh_public.policy WHERE policy_id = %s'
        params = (policy_id,)
        res = self.execute_sql(query, params)
        return res['status']

    def can_user_access_rls_table(self,
                                  username,
                                  permissions=['SELECT', 'UPDATE', 'INSERT']):
        '''
        Returns True if the has been granted specified type(s) of access to
        select/update/insert into the RLS policy table. Else, returns false.

        This must be executed from a connection to the
        settings.POLICY_DB database. Otherwise, it will check the wrong
        database, and (most likely) return fFalse
        '''
        query = ("SELECT exists("
                 "SELECT * FROM %s.%s where grantee=lower(%s) and (")

        conditions = ["lower(policy_type)=lower(%s)"] * len(permissions)
        conditions = " or ".join(conditions)
        query += conditions + "))"

        params = (AsIs(settings.POLICY_SCHEMA),
                  AsIs(settings.POLICY_TABLE),
                  username) + tuple(permissions)

        res = self.execute_sql(query, params)
        return res['tuples'][0][0]