datahuborg/datahub

View on GitHub
src/core/test/test_query_rewriter.py

Summary

Maintainability
B
5 hrs
Test Coverage
from core.db.query_rewriter import SQLQueryRewriter

from django.db.models import signals
from django.test import TestCase

import factory
import sqlparse
from mock import patch


class QueryRewriter(TestCase):
    """Tests all the query rewriter operations in query_rewriter.py.

    Every test runs against a ``SQLQueryRewriter`` built in ``setUp`` for a
    fixed repo base and user, with ``DataHubConnection`` mocked out. Tests
    that depend on security policies stub ``find_table_policies`` so the
    policy list is fully controlled by the test.
    """

    # Dotted prefix shared by every SQLQueryRewriter method patched below.
    _REWRITER_PATH = 'core.db.query_rewriter.SQLQueryRewriter.'

    # In the "select * from <target>" queries used here, index 6 is expected
    # to be the sqlparse token right after FROM (keyword, whitespace, "*",
    # whitespace, FROM, whitespace, target).
    _FROM_TARGET_INDEX = 6

    @factory.django.mute_signals(signals.pre_save)
    def setUp(self):
        """Create the rewriter under test and mock the DB connection."""
        self.repo_base = "test_repobase"
        self.user = "test_user"
        self.query_rewriter = SQLQueryRewriter(self.repo_base, self.user)

        self.mock_connection = self.create_patch(
            'core.db.manager.DataHubConnection')

    def create_patch(self, name):
        """Start a patch on *name* and return the mock that replaces it.

        The patch is stopped automatically when the test finishes.
        """
        patcher = patch(name)
        mocked = patcher.start()
        self.addCleanup(patcher.stop)
        return mocked

    def _stub_rewriter_method(self, method_name, return_value):
        """Patch a SQLQueryRewriter method to return *return_value*.

        Returns the mock so a test can change ``return_value`` later
        instead of stacking a second patch on the same target.
        """
        stub = self.create_patch(self._REWRITER_PATH + method_name)
        stub.return_value = return_value
        return stub

    @staticmethod
    def _token_at(query, index=_FROM_TARGET_INDEX):
        """Return token *index* of the first statement parsed from *query*."""
        return sqlparse.parse(query)[0].tokens[index]

    def _assert_rewrites(self, rewrite, cases):
        """Assert ``rewrite(query) == expected`` for each pair in *cases*."""
        for query, expected in cases:
            self.assertEqual(rewrite(query), expected)

    def test_extract_table_info(self):
        """Two- and three-part names are split; other shapes raise."""
        extract = self.query_rewriter.extract_table_info

        self.assertEqual(extract("repo.table"), ["repo", "table", None])
        self.assertEqual(
            extract("repobase.repo.table"), ["repo", "table", "repobase"])

        # The concrete exception type is not visible from this module, so
        # the check stays as broad as the original try/except was.
        for invalid_token in ("testtable", "table1.table2.table3.table4"):
            with self.assertRaises(Exception):
                extract(invalid_token)

    def test_extract_table_token(self):
        """Table tokens yield one (table info, trailing text) per table."""
        cases = [
            ("SELECT * from repo1.table1 as tbl1",
             [(["repo1", "table1", None], "as tbl1")]),
            (("SELECT * from repo1.table1 as tbl1, repo2.table2 as tbl2 "
              "where ... "),
             [(["repo1", "table1", None], "as tbl1"),
              (["repo2", "table2", None], "as tbl2 where ...")]),
            ("SELECT * from repo1.table1 tbl1, repo2.table2 tbl2 where ... ",
             [(["repo1", "table1", None], "tbl1"),
              (["repo2", "table2", None], "tbl2")]),
        ]
        for query, expected in cases:
            token = self._token_at(query)
            self.assertEqual(
                self.query_rewriter.extract_table_token(token), expected)

    def test_extract_table_string(self):
        """Table strings split into table info plus the stripped alias."""
        cases = [
            ("repo.table", (["repo", "table", None], '')),
            ("repo.table test", (["repo", "table", None], 'test')),
            ("repo.table as test", (["repo", "table", None], 'as test')),
            ("repobase.repo.table test ",
             (["repo", "table", "repobase"], 'test')),
            ("repobase.repo.table as test ",
             (["repo", "table", "repobase"], 'as test')),
        ]
        self._assert_rewrites(self.query_rewriter.extract_table_string, cases)

        # Broad on purpose: the exception type raised is not visible here.
        with self.assertRaises(Exception):
            self.query_rewriter.extract_table_string("invalidtable")

    def test_contains_subquery(self):
        """Only the parenthesised token is reported as a subquery."""
        query = ("select * from (select * from repo.table where "
                 "repo.table.test = 'True')")
        tokens = sqlparse.parse(query)[0].tokens

        self.assertEqual(
            self.query_rewriter.contains_subquery(
                tokens[self._FROM_TARGET_INDEX]),
            True)
        # Token 0 is the leading token of the statement, not a subquery.
        self.assertEqual(
            self.query_rewriter.contains_subquery(tokens[0]), False)

    def test_extract_subquery(self):
        """A subquery token splits into (prefix, inner query, suffix)."""
        query = ("select * from (select * from repo.table where "
                 "repo.table.test='True')")
        expected = ('(', ('select * from repo.table where '
                          'repo.table.test=\'True\''), ')')
        self.assertEqual(
            self.query_rewriter.extract_subquery(self._token_at(query)),
            expected)

    def test_process_subquery(self):
        """The table inside a subquery is wrapped with its policies."""
        # NOTE(review): the policy below has an unbalanced quote
        # ("tester='Alice"). The expected value mirrors it, so the test is
        # self-consistent; both are kept as-is. Confirm whether intended.
        self._stub_rewriter_method('find_table_policies', ["tester='Alice"])

        subquery_token = self._token_at(
            "select * from (select * from repo.table)")
        expected = ("(select * from (SELECT * FROM repo.table WHERE "
                    "tester='Alice) AS repotable)")
        self.assertEqual(
            self.query_rewriter.process_subquery(subquery_token), expected)

    def test_apply_row_level_security_base(self):
        """SELECT queries get each table wrapped in a filtered subselect.

        Runs first with a single policy, then with two policies, which the
        expected strings show joined by OR.
        """
        rewrite = self.query_rewriter.apply_row_level_security_base
        policies = self._stub_rewriter_method(
            'find_table_policies', ["tester='Alice'"])

        single_policy_cases = [
            ("select * from repo.table",
             ("select * from (SELECT * FROM repo.table WHERE "
              "tester='Alice') AS repotable")),
            ("select * from hola.orders limit 3",
             ("select * from (SELECT * FROM hola.orders WHERE "
              "tester='Alice') AS holaorders limit 3")),
        ]
        self._assert_rewrites(rewrite, single_policy_cases)

        # Reuse the existing stub rather than patching the same target again.
        policies.return_value = ["tester='Alice'", "tester='Bob'"]

        multi_policy_cases = [
            # Comma join with bare aliases.
            (("select * from hola.orders o, hola.customer t where "
              "o.customerid=t.customerid order by customer"),
             ("select * from (SELECT * FROM hola.orders "
              "WHERE tester='Alice' OR tester='Bob') o, "
              "(SELECT * FROM hola.customer WHERE tester='Alice' "
              "OR tester='Bob') t where o.customerid=t.customerid"
              " order by customer")),
            # Explicit join. The original asserted this identical case three
            # times; one copy is kept.
            (("select * from test.orders right join test.customer on "
              "test.orders.customerid=test.customer.customerid"),
             ("select * from (SELECT * FROM test.orders WHERE "
              "tester='Alice' OR tester='Bob') AS testorders "
              "right join (SELECT * FROM test.customer WHERE "
              "tester='Alice' OR tester='Bob') AS testcustomer "
              "on testorders.customerid=testcustomer.customerid")),
            # Qualified column reference in WHERE.
            ("select * from test.orders where test.orders.customerid='1'",
             ("select * from (SELECT * FROM test.orders WHERE "
              "tester='Alice' OR tester='Bob') AS testorders "
              "where testorders.customerid='1'")),
            # Aggregate with GROUP BY.
            (("select count(*), visible from hola.grades_file "
              "group by visible"),
             ("select count(*), visible from (SELECT * FROM "
              "hola.grades_file WHERE tester='Alice' OR "
              "tester='Bob') AS holagrades_file group by visible")),
            # Nested subqueries in FROM.
            (("select * from (select * from "
              "(select * from hola.orders) as i) as o"),
             ("select * from (select * from (select * from "
              "(SELECT * FROM hola.orders WHERE tester='Alice' "
              "OR tester='Bob') AS holaorders) as i) as o")),
            # Subquery in WHERE.
            (("select * from hola.orders where customerid = "
              "(select customerid from hola.customer where customerid='3')"),
             ("select * from (SELECT * FROM hola.orders WHERE "
              "tester='Alice' OR tester='Bob') AS holaorders "
              "where customerid = (select customerid from "
              "(SELECT * FROM hola.customer WHERE tester='Alice' "
              "OR tester='Bob') AS holacustomer where "
              "customerid='3')")),
            # Mix of aliased and unaliased tables.
            (("select * from hola.orders as t, hola.orders_2, "
              "hola.customer where t.customerid=hola.orders_2.customerid "
              "and hola.orders_2.customerid=hola.customer.customerid"),
             ("select * from (SELECT * FROM hola.orders WHERE "
              "tester='Alice' OR tester='Bob') as t, "
              "(SELECT * FROM hola.orders_2 WHERE tester='Alice' "
              "OR tester='Bob') AS holaorders_2, "
              "(SELECT * FROM hola.customer WHERE tester='Alice' "
              "OR tester='Bob') AS holacustomer where "
              "t.customerid=holaorders_2.customerid and "
              "holaorders_2.customerid=holacustomer.customerid")),
        ]
        self._assert_rewrites(rewrite, multi_policy_cases)

    def test_apply_row_level_security_update(self):
        """UPDATE gets policies ANDed on; no policies leaves it unchanged."""
        query = ("update hola.grades_file set firstname='Alice' "
                 "where lastname='Abby'")
        rewrite = self.query_rewriter.apply_row_level_security_update

        policies = self._stub_rewriter_method(
            'find_table_policies', ["count > 10"])
        self.assertEqual(
            rewrite(query),
            ("update hola.grades_file set firstname='Alice' "
             "where lastname='Abby' AND count > 10"))

        policies.return_value = []
        self.assertEqual(
            rewrite(query),
            ("update hola.grades_file set firstname='Alice' "
             "where lastname='Abby'"))

    def test_apply_row_level_security_insert(self):
        """INSERT passes through unless a policy says INSERT='False'."""
        query = "insert into repo.table values (a,b,c)"
        rewrite = self.query_rewriter.apply_row_level_security_insert

        policies = self._stub_rewriter_method(
            'find_table_policies', ["INSERT='True'"])
        self.assertEqual(rewrite(query), query)

        # Broad on purpose: the exception type raised is not visible here.
        policies.return_value = ["INSERT='False'"]
        with self.assertRaises(Exception):
            rewrite(query)

        policies.return_value = []
        self.assertEqual(rewrite(query), query)

    def test_apply_row_level_security(self):
        """The entry point dispatches to the handler for each query type."""
        # (handler to stub, query to send, marker the stub returns)
        dispatch_cases = [
            ('apply_row_level_security_base',
             "select * from repo.table",
             "RLS for select called"),
            ('apply_row_level_security_insert',
             "insert into repo.table values (a,b,c)",
             "RLS for insert called"),
            ('apply_row_level_security_update',
             ("update repo.table set firstname='Alice' "
              "where lastname='Abby'"),
             "RLS for update called"),
        ]

        # Stub every handler before dispatching anything, so a misrouted
        # query returns the wrong marker instead of running real code.
        for handler_name, _, marker in dispatch_cases:
            self._stub_rewriter_method(handler_name, marker)

        for _, query, marker in dispatch_cases:
            self.assertEqual(
                self.query_rewriter.apply_row_level_security(query), marker)