src/tests/catwalk_tests/test_protected_groups_generators.py
from datetime import datetime, date
import testing.postgresql
from sqlalchemy.engine import create_engine
from unittest.mock import MagicMock
from triage.component.catwalk.protected_groups_generators import ProtectedGroupsGenerator
def create_demographics_table(db_engine, data):
db_engine.execute(
"""drop table if exists demographics;
create table demographics (person_id int, event_date date, race text, sex text, age_bucket int)
"""
)
for event in data:
db_engine.execute(
"insert into demographics values (%s, %s, %s, %s, %s)", event
)
def create_cohort_table(db_engine, data):
db_engine.execute(
"create table cohort_abcdef (entity_id int, as_of_date timestamp)"
)
for event in data:
db_engine.execute(
"insert into cohort_abcdef values (%s, %s)", event
)
def default_demographics():
return [
(1, datetime(2015, 12, 30), 'aa', 'male', 1),
(1, datetime(2016, 2, 1), 'aa', 'male', 1),
(1, datetime(2016, 3, 1), 'aa', 'female', 1),
(2, datetime(2015, 12, 30), 'wh', 'male', 3),
(2, datetime(2016, 3, 1), 'wh', 'male', 3),
(3, datetime(2015, 12, 30), 'aa', 'male', 1),
(3, datetime(2016, 3, 1), 'aa', 'male', 1),
(5, datetime(2016, 2, 1), 'wh', 'female', 2),
(5, datetime(2016, 3, 1), 'wh', 'female', 2),
]
def default_cohort():
return [
(1, datetime(2016, 1, 1)),
(1, datetime(2016, 3, 1)),
(1, datetime(2016, 4, 1)),
(2, datetime(2016, 1, 1)),
(2, datetime(2016, 3, 1)),
(2, datetime(2016, 4, 1)),
(3, datetime(2016, 1, 1)),
(3, datetime(2016, 3, 1)),
(3, datetime(2016, 4, 1)),
(4, datetime(2016, 1, 1)),
(4, datetime(2016, 3, 1)),
(4, datetime(2016, 4, 1)),
(5, datetime(2016, 1, 1)),
(5, datetime(2016, 3, 1)),
(5, datetime(2016, 4, 1)),
]
def assert_data(table_generator):
expected_output = [
(1, date(2016, 1, 1), 'aa', 'male', '1', 'abcdef'),
(1, date(2016, 3, 1), 'aa', 'male', '1', 'abcdef'),
(1, date(2016, 4, 1), 'aa', 'female', '1', 'abcdef'),
(2, date(2016, 1, 1), 'wh', 'male', '3', 'abcdef'),
(2, date(2016, 3, 1), 'wh', 'male', '3', 'abcdef'),
(2, date(2016, 4, 1), 'wh', 'male', '3', 'abcdef'),
(3, date(2016, 1, 1), 'aa', 'male', '1', 'abcdef'),
(3, date(2016, 3, 1), 'aa', 'male', '1', 'abcdef'),
(3, date(2016, 4, 1), 'aa', 'male', '1', 'abcdef'),
(4, date(2016, 1, 1), None, None, None, 'abcdef'),
(4, date(2016, 3, 1), None, None, None, 'abcdef'),
(4, date(2016, 4, 1), None, None, None, 'abcdef'),
(5, date(2016, 1, 1), None, None, None, 'abcdef'),
(5, date(2016, 3, 1), 'wh', 'female', '2', 'abcdef'),
(5, date(2016, 4, 1), 'wh', 'female', '2', 'abcdef'),
]
results = list(
table_generator.db_engine.execute(
f"""
select entity_id, as_of_date, race, sex, age_bucket, cohort_hash
from {table_generator.protected_groups_table_name}
order by entity_id, as_of_date
"""
)
)
assert results == expected_output
def test_protected_groups_generator_replace():
demographics_data = default_demographics()
cohort_data = default_cohort()
with testing.postgresql.Postgresql() as postgresql:
engine = create_engine(postgresql.url())
create_demographics_table(engine, demographics_data)
create_cohort_table(engine, cohort_data)
table_generator = ProtectedGroupsGenerator(
from_obj="demographics",
attribute_columns=['race', 'sex', 'age_bucket'],
entity_id_column="person_id",
knowledge_date_column="event_date",
db_engine=engine,
protected_groups_table_name="protected_groups_abcdef",
replace=True
)
as_of_dates = [
datetime(2016, 1, 1),
datetime(2016, 3, 1),
datetime(2016, 4, 1),
]
table_generator.generate_all_dates(
as_of_dates,
cohort_table_name='cohort_abcdef',
cohort_hash='abcdef'
)
assert_data(table_generator)
table_generator.generate_all_dates(
as_of_dates,
cohort_table_name='cohort_abcdef',
cohort_hash='abcdef'
)
assert_data(table_generator)
def test_protected_groups_generator_noreplace():
demographics_data = default_demographics()
cohort_data = default_cohort()
with testing.postgresql.Postgresql() as postgresql:
engine = create_engine(postgresql.url())
create_demographics_table(engine, demographics_data)
create_cohort_table(engine, cohort_data)
table_generator = ProtectedGroupsGenerator(
from_obj="demographics",
attribute_columns=['race', 'sex', 'age_bucket'],
entity_id_column="person_id",
knowledge_date_column="event_date",
db_engine=engine,
protected_groups_table_name="protected_groups_abcdef",
replace=False
)
as_of_dates = [
datetime(2016, 1, 1),
datetime(2016, 3, 1),
datetime(2016, 4, 1),
]
table_generator.generate_all_dates(
as_of_dates,
cohort_table_name='cohort_abcdef',
cohort_hash='abcdef'
)
assert_data(table_generator)
table_generator.generate = MagicMock()
table_generator.generate_all_dates(
as_of_dates,
cohort_table_name='cohort_abcdef',
cohort_hash='abcdef'
)
table_generator.generate.assert_not_called()
assert_data(table_generator)
def test_as_dataframe():
attribute_columns = ['race', 'sex', 'age_bucket']
demographics_data = default_demographics()
cohort_data = default_cohort()
with testing.postgresql.Postgresql() as postgresql:
engine = create_engine(postgresql.url())
create_demographics_table(engine, demographics_data)
create_cohort_table(engine, cohort_data)
table_generator = ProtectedGroupsGenerator(
from_obj="demographics",
attribute_columns=attribute_columns,
entity_id_column="person_id",
knowledge_date_column="event_date",
db_engine=engine,
protected_groups_table_name="protected_groups_abcdef",
replace=True
)
as_of_dates = [
datetime(2016, 1, 1),
datetime(2016, 3, 1),
datetime(2016, 4, 1),
]
table_generator.generate_all_dates(
as_of_dates,
cohort_table_name='cohort_abcdef',
cohort_hash='abcdef'
)
protected_df = table_generator.as_dataframe(
as_of_dates,
cohort_hash='abcdef'
)
assert(protected_df.shape == (15, 3))
assert(set(attribute_columns).issubset(protected_df.columns))
for attr_col in attribute_columns:
assert(protected_df[attr_col].dtype == 'object')