datahuborg/datahub

View on GitHub
src/core/db/query_rewriter.py

Summary

Maintainability
F
3 days
Test Coverage
import re

import sqlparse

from inventory.models import Collaborator
from core.db.rlsmanager import RowLevelSecurityManager
from config import settings


class SQLQueryRewriter:
    '''
    Rewrites SQL queries so that they honor the row level security policies
    returned by RowLevelSecurityManager for the current user.

    * SELECT / EXPLAIN: every table named after FROM or a JOIN keyword is
      replaced by a subquery keeping only the rows its select policies allow.
    * UPDATE: the table's update policies are ANDed into the WHERE clause.
    * INSERT: the query is returned unchanged if the table has no insert
      policy or its first one is INSERT='True'; otherwise an Exception is
      raised.
    * Any other statement is returned unchanged.
    '''

    # Matches the "(" that opens a parenthesized SELECT, e.g. "IN ( SELECT".
    # Matching a bare "select" substring would also hit names such as
    # "myselect.secret" or "selected_col"; a table token mistaken for a
    # subquery is never rewritten and would skip row level security.
    _SUBQUERY_START = re.compile(r'\(\s*select\b', re.IGNORECASE)

    def __init__(self, repo_base, user):
        '''
        repo_base: repo base used for table references of the form
            repo.table (i.e. when the query does not name a repo base).
        user: name of the user policies are looked up for; it also replaces
            the USERNAME placeholder in rewritten SELECT/UPDATE queries.
        '''
        self.repo_base = repo_base
        self.user = user

    def _quoted_user(self):
        '''
        Returns self.user as a SQL string literal. Single quotes are doubled
        so a name containing one cannot terminate the literal early.
        '''
        return "'%s'" % self.user.replace("'", "''")

    def _is_whitespace(self, token):
        '''
        True if the token consists only of whitespace. Whitespace tokens are
        not always a single space (newlines, tabs, runs of spaces); treating
        them as real tokens would make e.g. a newline after FROM be taken for
        the table reference and leave the actual table unfiltered.
        '''
        return not unicode(token).strip()

    def _is_comment(self, token):
        '''True if the token is a SQL comment (single token or group).'''
        return (isinstance(token, sqlparse.sql.Comment) or
                token.ttype in sqlparse.tokens.Comment)

    def extract_table_info(self, table_string):
        '''
        Parses a table reference into [repo_name, table_name, repo_base].

        Valid references are repo.table (repo_base is then None, meaning the
        rewriter's own repo base is used) or repo_base.repo.table. Trailing
        whitespace is ignored.

        Raises an Exception for any other form (e.g. an unqualified table
        name); this method never returns None.
        '''
        table_info = table_string.rstrip().split('.')
        if len(table_info) == 2:
            return [table_info[0], table_info[1], None]
        if len(table_info) == 3:
            return [table_info[1], table_info[2], table_info[0]]

        raise Exception('Error parsing %s: missing schema name' % table_string)

    def extract_table_string(self, table):
        '''
        Splits a table reference into its table information and its alias.
        The first whitespace-separated word is the table; the remaining
        words form the alias part. For example:

            (1) repo_name.table_name
            (2) repo_name.table_name AS alias_name
            (3) repo_name.table_name alias_name

        returns:

            (1) ([repo_name, table_name, None], '')
            (2) ([repo_name, table_name, None], 'AS alias_name')
            (3) ([repo_name, table_name, None], 'alias_name')

        The third list element is the repo base when the reference has the
        form repo_base.repo_name.table_name. Returns None for an empty or
        whitespace-only string; raises (via extract_table_info) when the
        first word is not a qualified table name.
        '''
        # split() (not split(' ')) so tabs/newlines also separate the alias.
        words = table.split()
        if not words:
            return None
        return (self.extract_table_info(words[0]), ' '.join(words[1:]))

    def extract_table_token(self, token):
        '''
        Takes in a token and returns a list of (table_info, alias_info) pairs,
        as produced by extract_table_string, one for each table in the token.
        There may be multiple tables in the token because SQLParse parses all
        text after the FROM token and before the next SQL key word in the
        query as the table name. For example, in

            "SELECT * from repo1.table1 as tbl1, repo2.table2 as tbl2 where..."

        "repo1.table1 as tbl1, repo2.table2 as tbl2" falls into one token.
        Empty comma-separated entries are skipped.
        '''
        table_list = []
        for table in unicode(token).split(','):
            table_info = self.extract_table_string(table.strip())
            if table_info is not None:
                table_list.append(table_info)
        return table_list

    def contains_subquery(self, token):
        '''
        Returns True if the token is a group that contains a parenthesized
        SELECT, False otherwise.
        '''
        if not token.is_group():
            return False
        return self._SUBQUERY_START.search(unicode(token)) is not None

    def extract_subquery(self, token):
        '''
        Takes in a token that contains a subquery and returns a tuple of the
        form (string_before_subquery, subquery_string, string_after_subquery).

        The subquery starts right after the "(" that opens the parenthesized
        SELECT and ends at the last ")" of the token. The first element keeps
        that "(" and the last element starts with the closing ")". If no ")"
        follows the opening "(", the subquery runs to the end of the token.
        '''
        text = unicode(token)
        match = self._SUBQUERY_START.search(text)
        subquery_start = match.start() if match else text.find('(')
        subquery_end = text.rfind(')')
        if subquery_end <= subquery_start:
            subquery_end = len(text)
        return (text[:subquery_start + 1],
                text[subquery_start + 1:subquery_end],
                text[subquery_end:])

    def process_subquery(self, token):
        '''
        Takes in a token and processes the subquery that it contains: the
        subquery is separated from the text around it, row level security is
        applied to it, and the pieces are joined back together.
        '''
        before, subquery, after = self.extract_subquery(token)
        # Concatenate instead of %-formatting: the surrounding SQL may contain
        # literal '%' characters (e.g. LIKE 'a%'), which break formatting.
        return before + self.apply_row_level_security(subquery.strip()) + after

    def apply_row_level_security(self, query):
        '''
        Applies row level security to query, choosing the rewrite from the
        statement's first keyword: INSERT, UPDATE, or SELECT/EXPLAIN. Leading
        whitespace and comments are skipped when looking for that keyword so
        they cannot be used to slip a SELECT past the rewriter. Empty queries
        and statements of any other type are returned unchanged.
        '''
        # NOTE(review): other statement types (including WITH ... SELECT and
        # a SELECT wrapped in parentheses) pass through unmodified; confirm
        # callers do not rely on this method to filter them.
        statements = sqlparse.parse(query)
        if not statements:
            return query

        keyword = ''
        for token in statements[0].tokens:
            if not (self._is_whitespace(token) or self._is_comment(token)):
                keyword = unicode(token).lower()
                break

        if keyword == 'insert':
            return self.apply_row_level_security_insert(query)
        if keyword == 'update':
            return self.apply_row_level_security_update(query)
        if keyword in ('explain', 'select'):
            return self.apply_row_level_security_base(query)
        return query

    def apply_row_level_security_insert(self, query):
        '''
        Takes in an insert SQL query and applies security policies related to
        the insert access type to it. Currently, we only support one type
        of insert permission -- which is that the user making the insert call
        has permission to insert into the specified table.

        # Insert into repo.table values (...)
        # Insert into repo.table values (select * from ....)

        The table following INTO is looked up (a column list glued to it,
        as in repo.table(a, b), is ignored). If it has no insert policies, or
        its first insert policy is INSERT='True', the query is returned with
        any subqueries rewritten. Otherwise an Exception is raised. An
        Exception is also raised when no target table can be found.
        '''
        tokens = sqlparse.parse(query)[0].tokens
        prev_token = None
        result = ''
        table = None

        for token in tokens:
            text = unicode(token)
            is_whitespace = self._is_whitespace(token)

            if (not is_whitespace and prev_token is not None and
                    unicode(prev_token).lower() == 'into'):
                table = self.extract_table_string(text.split('(')[0])

            if self.contains_subquery(token):
                result += self.process_subquery(token)
            else:
                result += text
            if not is_whitespace:
                prev_token = token

        if table is None:
            raise Exception(
                'Could not determine the target table of: %s' % query)

        repo, table_name, repo_base = table[0]
        policies = self.find_table_policies(
            table_name, repo, 'insert', repo_base)
        if not policies or policies[0] == "INSERT='True'":
            return result

        raise Exception('User does not have insert access on %s' % table_name)

    def _restrict_where(self, where_text, condition):
        '''
        Rewrites a WHERE clause "WHERE <cond>" into
        "WHERE (<cond>) AND <condition>", keeping any trailing whitespace.

        Parenthesizing <cond> keeps an OR inside it from bypassing condition
        ("a OR b AND p" parses as "a OR (b AND p)"). If <cond> contains "--",
        the closing parenthesis goes on a new line so that a trailing line
        comment cannot swallow the appended condition. where_text is expected
        to start with the WHERE keyword.
        '''
        keyword_length = len('WHERE')
        keyword = where_text[:keyword_length]
        body = where_text[keyword_length:]
        stripped = body.rstrip()
        trailing = body[len(stripped):]
        closing = '\n)' if '--' in stripped else ')'
        return '%s (%s%s AND %s%s' % (
            keyword, stripped.strip(), closing, condition, trailing)

    def apply_row_level_security_update(self, query):
        '''
        Takes in an update SQL query and applies security policies related to
        the update access type to it.

        Every update policy of the target table is ANDed into the WHERE
        clause; a WHERE clause is added when the query has none. The original
        condition is parenthesized so an OR in it cannot bypass the policies.
        Semicolons are removed from the query, subqueries are rewritten and
        USERNAME is replaced by the quoted user name.
        '''
        tokens = sqlparse.parse(query.replace(";", ''))[0].tokens
        prev_token = None
        pieces = []
        where_index = None
        table = None

        for token in tokens:
            if isinstance(token, sqlparse.sql.Where):
                where_index = len(pieces)

            if self.contains_subquery(token):
                pieces.append(self.process_subquery(token))
                continue

            text = unicode(token)
            is_whitespace = self._is_whitespace(token)
            if (not is_whitespace and prev_token is not None and
                    unicode(prev_token).lower() == 'update'):
                # extract_table_string drops an alias ("repo.t AS x").
                table_string = self.extract_table_string(text)
                table = table_string[0] if table_string else None
            pieces.append(text)
            if not is_whitespace:
                prev_token = token

        if table is not None:
            policies = self.find_table_policies(
                table[1], table[0], "update", table[2])
            if policies:
                condition = ' AND '.join('(%s)' % policy
                                         for policy in policies)
                if where_index is not None:
                    pieces[where_index] = self._restrict_where(
                        pieces[where_index], condition)
                else:
                    # Without a WHERE clause an appended "AND policy" would
                    # become part of the SET expression; start a WHERE
                    # instead, on a new line if a "--" comment may precede it.
                    separator = '\n' if '--' in ''.join(pieces) else ' '
                    pieces.append('%sWHERE %s' % (separator, condition))

        return ''.join(pieces).replace("USERNAME", self._quoted_user())

    def is_postgres_catalog(self, token):
        '''
        Returns True if every comma-separated entry of the token starts with
        "pg_" (PostgreSQL catalog names such as pg_tables), in which case the
        token is passed through untouched. A list mixing catalog and regular
        tables is not treated as catalog, otherwise the regular tables would
        escape row level security.
        '''
        return all(piece.strip().startswith('pg_')
                   for piece in unicode(token).split(','))

    def need_query_rewrite(self, prev_token):
        '''
        Returns True when the token following prev_token names tables to be
        rewritten: prev_token is FROM or any JOIN keyword (JOIN, INNER JOIN,
        LEFT OUTER JOIN, CROSS JOIN, ...). Whitespace inside the keyword is
        normalized, so "LEFT  JOIN" matches as well.
        '''
        words = unicode(prev_token).lower().split()
        return words == ['from'] or words[-1:] == ['join']

    def _filtered_table(self, repo, table, repo_base):
        '''
        Returns "(SELECT * FROM [repo_base.]repo.table[ WHERE p1 OR p2 ...])",
        the rows of the table visible under the user's select policies. With
        no select policy the subquery is unfiltered.
        '''
        if repo_base is not None:
            subquery = '(SELECT * FROM %s.%s.%s' % (repo_base, repo, table)
        else:
            subquery = '(SELECT * FROM %s.%s' % (repo, table)

        policies = self.find_table_policies(table, repo, "select", repo_base)
        if policies:
            subquery += ' WHERE ' + ' OR '.join(policies)
        return subquery + ")"

    def apply_row_level_security_base(self, query):
        '''
        Takes in a SELECT (or EXPLAIN) query and applies row level security
        to it. Every table named after FROM or a JOIN keyword is replaced
        with a subquery that only returns the rows the user is allowed to see
        (see _filtered_table).

        Tables without an alias get one (repo name + table name), and later
        references to "repo.table." / "repo.table " (and the three-part
        "repo_base.repo.table" forms) are renamed to that alias. Parenthesized
        subqueries are rewritten recursively, PostgreSQL catalog references
        are left alone, and USERNAME is replaced by the quoted user name.
        '''
        tokens = sqlparse.parse(query)[0].tokens
        # Entries: (names the table may be referenced by, alias, offset in
        # result after which those references must be renamed).
        replace_list = []
        prev_token = None
        result = ''

        for token in tokens:
            if self.contains_subquery(token):
                result += self.process_subquery(token)
                prev_token = token
                continue

            if self.is_postgres_catalog(token):
                result += unicode(token)
                prev_token = token
                continue

            is_whitespace = self._is_whitespace(token)
            if (prev_token is None or is_whitespace or
                    not self.need_query_rewrite(prev_token)):
                result += unicode(token)
                if not is_whitespace:
                    prev_token = token
                continue

            tables = self.extract_table_token(token)
            for index, (table_info, alias_info) in enumerate(tables):
                repo, table, repo_base = table_info
                if index:
                    result += ", "
                result += self._filtered_table(repo, table, repo_base)

                # The table is now a subquery, so it needs an alias. Without a
                # user-supplied one we create it and rename later references.
                if alias_info != "":
                    result += " %s" % alias_info
                else:
                    alias_name = repo + table
                    result += " AS %s" % alias_name
                    names = [repo + "." + table]
                    if repo_base is not None:
                        # Rename the longer form first so "base.repo.t.col"
                        # does not become "base.alias.col".
                        names.insert(0, "%s.%s.%s" % (repo_base, repo, table))
                    replace_list.append((names, alias_name, len(result)))

            prev_token = token

        # Apply from the last offset backwards so each rename cannot shift
        # the offsets recorded for earlier tables.
        for names, alias_name, offset in reversed(replace_list):
            head, tail = result[:offset], result[offset:]
            for name in names:
                tail = tail.replace(name + ".", alias_name + ".")
                tail = tail.replace(name + " ", alias_name + " ")
            result = head + tail

        return result.replace("USERNAME", self._quoted_user())

    def find_table_policies(self, table, repo, policytype, repo_base):
        '''
        Look up policies associated with the table and repo and returns a
        list of the policy strings that apply to the user: those granted to
        the user, those granted to settings.RLS_ALL and, when the user is not
        found among the repo's collaborators, those granted to
        settings.RLS_PUBLIC. repo_base defaults to the rewriter's repo base.
        '''
        if repo_base is None:
            repo_base = self.repo_base

        def lookup(grantee):
            return RowLevelSecurityManager.find_security_policies(
                repo_base=repo_base,
                repo=repo,
                table=table,
                policy_type=policytype,
                grantee=grantee,
                safe=False)

        # policies that are meant to apply to specific users
        user_policies = lookup(self.user)
        # policies that are meant to apply to all users
        all_policies = lookup(settings.RLS_ALL)

        # People collaborating on this repo
        collaborators = Collaborator.objects.filter(repo_base=repo_base,
                                                    repo_name=repo)

        # If the user is not explicitly granted access, also load the
        # public_policies.
        # NOTE(review): self.user is a username string (it is concatenated
        # into SQL in this class) while collaborators yields Collaborator
        # rows, so this membership test presumably never matches and public
        # policies are always loaded. Confirm the Collaborator schema before
        # changing it: dropping policies widens access.
        public_policies = []
        if self.user not in collaborators:
            public_policies = lookup(settings.RLS_PUBLIC)

        security_policies = user_policies + all_policies + public_policies
        return [policy_tuple.policy for policy_tuple in security_policies]