dallinger/config.py
from __future__ import unicode_literals
import io
import json
import logging
import os
import sys
from collections import deque
from contextlib import contextmanager
from pathlib import Path
import six
from setuptools.dist import strtobool
from six.moves import configparser
# Module-level logger.
# NOTE(review): named with ``__file__`` rather than the conventional ``__name__``;
# changing it would rename the logger.
logger = logging.getLogger(__file__)

# Sentinel default for ``Configuration.get``: distinguishes "no default supplied"
# from an explicit ``default=None``.
marker = object()

# File name of the per-experiment config file, looked up in the working directory.
LOCAL_CONFIG = "config.txt"

# Substrings that mark a config key as sensitive (used by ``Configuration.is_sensitive``).
SENSITIVE_KEY_NAMES = (
    "access_id",
    "access_key",
    "password",
    "secret",
    "token",
)
def is_valid_json(value):
    """Config validator: check that ``value`` is a JSON document.

    The parsed result is discarded. Invalid input makes ``json.loads`` raise a
    ``ValueError`` (``json.JSONDecodeError``), which propagates to the caller;
    ``Configuration.extend`` annotates such errors with the offending key/value.

    :returns: ``None`` when ``value`` parses successfully.
    """
    json.loads(value)  # raises on malformed input; we only care that it parses
    return None
# Built-in configuration keys, registered by ``get_config`` and by
# ``Configuration._reset(register_defaults=True)``. Experiments can register more
# keys through their ``extra_parameters`` hooks.
#
# Each entry is the positional-argument tuple for ``Configuration.register``:
#     (key, type, synonyms[, sensitive[, validators]])
#
# NOTE(review): ConfigParser lower-cases option names by default, so the upper-case
# ``EXPERIMENT_CLASS_NAME`` can presumably only be supplied via the environment,
# not via config.txt — confirm.
default_keys = (
    # These are the keys allowed in a dallinger experiment config.txt file.
    ("activate_recruiter_on_start", bool, []),
    ("ad_group", six.text_type, []),
    ("approve_requirement", int, []),
    ("assign_qualifications", bool, []),
    ("auto_recruit", bool, []),
    ("aws_access_key_id", six.text_type, ["AWS_ACCESS_KEY_ID"], True),
    (
        "aws_region",
        six.text_type,
        ["AWS_REGION", "AWS_DEFAULT_REGION", "aws_default_region"],
    ),
    ("aws_secret_access_key", six.text_type, ["AWS_SECRET_ACCESS_KEY"], True),
    ("base_payment", float, []),
    ("base_port", int, []),
    ("browser_exclude_rule", six.text_type, []),
    ("clock_on", bool, []),
    ("contact_email_on_error", six.text_type, []),
    ("chrome-path", six.text_type, []),
    ("dallinger_develop_directory", six.text_type, []),
    ("dallinger_email_address", six.text_type, []),
    ("dashboard_password", six.text_type, [], True),
    ("dashboard_user", six.text_type, [], True),
    ("database_size", six.text_type, []),
    ("database_url", six.text_type, [], True),
    ("debug_recruiter", six.text_type, []),
    ("description", six.text_type, []),
    ("disable_when_duration_exceeded", bool, []),
    ("duration", float, []),
    ("dyno_type", six.text_type, []),
    ("dyno_type_web", six.text_type, []),
    ("dyno_type_worker", six.text_type, []),
    ("enable_global_experiment_registry", bool, []),
    ("EXPERIMENT_CLASS_NAME", six.text_type, []),
    ("group_name", six.text_type, []),
    ("heroku_app_id_root", six.text_type, []),
    ("heroku_auth_token", six.text_type, [], True),
    ("heroku_python_version", six.text_type, []),
    ("heroku_team", six.text_type, ["team"]),
    ("host", six.text_type, []),
    ("id", six.text_type, []),
    ("infrastructure_debug_details", six.text_type, [], False),
    ("keywords", six.text_type, []),
    ("language", six.text_type, []),
    ("lifetime", int, []),
    ("lock_table_when_creating_participant", bool, []),
    ("logfile", six.text_type, []),
    ("loglevel", int, []),
    ("mode", six.text_type, []),
    ("mturk_qualification_blocklist", six.text_type, ["qualification_blacklist"]),
    ("mturk_qualification_requirements", six.text_type, [], False, [is_valid_json]),
    ("num_dynos_web", int, []),
    ("num_dynos_worker", int, []),
    ("organization_name", six.text_type, []),
    ("port", int, ["PORT"]),
    ("prolific_api_token", six.text_type, ["PROLIFIC_RESEARCHER_API_TOKEN"], True),
    ("prolific_api_version", six.text_type, []),
    ("prolific_estimated_completion_minutes", int, []),
    ("prolific_maximum_allowed_minutes", int, []),
    ("prolific_recruitment_config", six.text_type, [], False, [is_valid_json]),
    ("protected_routes", six.text_type, [], False, [is_valid_json]),
    ("recruiter", six.text_type, []),
    ("recruiters", six.text_type, []),
    ("redis_size", six.text_type, []),
    ("replay", bool, []),
    ("sentry", bool, []),
    ("smtp_host", six.text_type, []),
    ("smtp_username", six.text_type, []),
    ("smtp_password", six.text_type, ["dallinger_email_password"], True),
    ("threads", six.text_type, []),
    ("title", six.text_type, []),
    ("question_max_length", int, []),
    ("us_only", bool, []),
    ("webdriver_type", six.text_type, []),
    ("webdriver_url", six.text_type, []),
    ("whimsical", bool, []),
    ("worker_multiplier", float, []),
    ("docker_image_base_name", six.text_type, []),
    ("docker_image_name", six.text_type, []),
    ("docker_volumes", six.text_type, []),
)
class Configuration(object):
    """Layered, typed configuration store for Dallinger experiments.

    Keys must be registered (``register``) with a type before values can be stored
    for them. Values live in ``self.data``, a deque of dict "layers": ``extend``
    pushes a new layer on the left, and ``get`` returns the value from the
    left-most layer that contains the key, so later layers override earlier ones.
    """

    SUPPORTED_TYPES = {six.binary_type, six.text_type, int, float, bool}
    # Guards so each parameter-registration hook in ``register_extra_parameters``
    # runs at most once per instance; cleared again by ``_reset``.
    _experiment_params_loaded = False
    _module_params_loaded = False

    # Instance attributes that must never be looked up as config keys by
    # ``__getattr__`` (see the comment there).
    _internal_attrs = frozenset(
        ("data", "ready", "types", "synonyms", "validators", "sensitive")
    )

    def __init__(self):
        self._reset()

    def set(self, key, value):
        """Store one ``key``/``value`` pair as a new, highest-priority layer."""
        return self.extend({key: value})

    def clear(self):
        """Drop every stored layer and mark the configuration as not loaded."""
        self.data = deque()
        self.ready = False

    def _reset(self, register_defaults=False):
        """Clear all stored values, registrations and hook guards.

        :param register_defaults: when true, register every entry of
            ``default_keys`` again afterwards.
        """
        self.clear()
        self.types = {}
        self.synonyms = {}
        self.validators = {}
        self.sensitive = set()
        self._experiment_params_loaded = False
        self._module_params_loaded = False
        if register_defaults:
            for registration in default_keys:
                self.register(*registration)

    def extend(self, mapping, cast_types=False, strict=False):
        """Push ``mapping`` as a new highest-priority layer.

        Each key is first translated through the registered synonyms. Keys that
        are still unregistered are skipped, or rejected when ``strict`` is true.

        :param mapping: key/value pairs to add.
        :param cast_types: when true, coerce each value to its registered type.
            A text value of the form ``"file:<path>"`` is first replaced by the
            UTF-8 contents of that file, and ``bool`` keys go through
            ``strtobool`` before conversion.
        :param strict: when true, unregistered keys raise ``KeyError``.
        :raises KeyError: for an unregistered key when ``strict`` is true.
        :raises TypeError: when a value is not an instance of its registered type.
        :raises ValueError: re-raised from a validator, annotated with
            ``dallinger_config_key`` and ``dallinger_config_value``.
        """
        normalized_mapping = {}
        for key, value in mapping.items():
            key = self.synonyms.get(key, key)
            test_deprecation(key)
            if key not in self.types:
                # This key hasn't been registered, we ignore it
                if strict:
                    raise_invalid_key_error(key)
                continue
            expected_type = self.types.get(key)
            if cast_types:
                if isinstance(value, six.text_type) and value.startswith("file:"):
                    # Load this value from a file
                    _, filename = value.split(":", 1)
                    with io.open(filename, "rt", encoding="utf-8") as source_file:
                        value = source_file.read()
                try:
                    if expected_type == bool:
                        value = strtobool(value)
                    value = expected_type(value)
                except ValueError:
                    # Leave the value as-is; the isinstance check below reports it.
                    pass
            if not isinstance(value, expected_type):
                raise TypeError(
                    "Got {value} for {key}, expected {expected_type}".format(
                        value=repr(value), key=key, expected_type=expected_type
                    )
                )
            for validator in self.validators.get(key, []):
                try:
                    validator(value)
                except ValueError as e:
                    # Annotate the exception with more info
                    e.dallinger_config_key = key
                    e.dallinger_config_value = value
                    raise
            normalized_mapping[key] = value
        # The layer is only pushed once every key has been validated.
        self.data.extendleft([normalized_mapping])

    @contextmanager
    def override(self, *args, **kwargs):
        """Temporarily push a layer for the duration of a ``with`` block.

        Accepts the same arguments as ``extend``. The pushed layer is removed
        when the block exits, including when the block raises.
        """
        self.extend(*args, **kwargs)
        try:
            yield self
        finally:
            self.data.popleft()

    # Keys whose value may change at runtime.
    # NOTE(review): not consulted by ``get``, which hard-codes "auto_recruit".
    changeable_params = ["auto_recruit"]

    def get(self, key, default=marker):
        """Return the value of ``key`` from the highest-priority layer that sets it.

        ``auto_recruit`` is looked up in Redis (``dallinger.db.redis_conn``)
        first; a stored Redis value wins over the layers and is returned even
        before the configuration is loaded. Text values are returned with
        surrounding whitespace stripped.

        :param default: returned when no layer sets ``key``.
        :raises RuntimeError: if the configuration has not been loaded.
        :raises KeyError: if no layer sets ``key`` and no ``default`` was given.
        """
        # For now this is limited to "auto_recruit", but in the future it can be extended
        # to other parameters as well
        if key == "auto_recruit":
            from dallinger.db import redis_conn

            auto_recruit = redis_conn.get("auto_recruit")
            if auto_recruit is not None:
                return bool(int(auto_recruit))
        if not self.ready:
            raise RuntimeError("Config not loaded")
        for layer in self.data:
            try:
                value = layer[key]
            except KeyError:
                continue
            if isinstance(value, six.text_type):
                value = value.strip()
            return value
        if default is marker:
            raise KeyError(
                f"The following config parameter was not set: {key}. Consider setting it in "
                "config.txt or in ~/.dallingerconfig."
            )
        return default

    def __getitem__(self, key):
        return self.get(key)

    def __setitem__(self, key, value):
        return self.extend({key: value})

    def __getattr__(self, key):
        """Expose config values as attributes (``config.title``).

        Only reached when normal attribute lookup fails.

        :raises AttributeError: if ``key`` has no value, is a dunder name, or is
            one of this object's own (not yet initialised) internal attributes.
        """
        # Dunder lookups (e.g. ``__setstate__`` probed by copy/pickle) and our own
        # instance state are never config keys. On an instance whose ``__init__``
        # has not run, ``self.get`` reads ``self.ready``, which would re-enter
        # this method and recurse until RecursionError.
        if (key.startswith("__") and key.endswith("__")) or key in self._internal_attrs:
            raise AttributeError(key)
        try:
            return self.get(key)
        except KeyError as err:
            raise AttributeError(key) from err

    def as_dict(self, include_sensitive=False):
        """Return a dict of every registered key that currently has a value.

        Keys in the registered ``sensitive`` set are omitted unless
        ``include_sensitive`` is true. This uses only the registered set, not the
        substring match performed by ``is_sensitive``.
        """
        d = {}
        for key in self.types:
            if key not in self.sensitive or include_sensitive:
                try:
                    d[key] = self.get(key)
                except KeyError:
                    pass
        return d

    def is_sensitive(self, key):
        """Return True if ``key`` was registered as sensitive or contains any of
        ``SENSITIVE_KEY_NAMES`` as a substring."""
        if key in self.sensitive:
            return True
        # Also, does a sensitive string appear within the key?
        return any(name in key for name in SENSITIVE_KEY_NAMES)

    def register(self, key, type_, synonyms=None, sensitive=False, validators=None):
        """Register ``key`` so values can be stored for it.

        :param key: canonical key name.
        :param type_: one of ``SUPPORTED_TYPES``; stored values must be instances of it.
        :param synonyms: alternative names that ``extend`` translates to ``key``.
        :param sensitive: add ``key`` to the sensitive set (omitted by ``as_dict``
            unless requested).
        :param validators: callables run by ``extend`` on every new value; a
            ``ValueError`` they raise is annotated and re-raised.
        :raises KeyError: if ``key`` is already registered.
        :raises TypeError: if ``type_`` is not supported.
        """
        if synonyms is None:
            synonyms = set()
        if key in self.types:
            raise KeyError("Config key {} is already registered".format(key))
        if type_ not in self.SUPPORTED_TYPES:
            raise TypeError("{type} is not a supported type".format(type=type_))
        self.types[key] = type_
        for synonym in synonyms:
            self.synonyms[synonym] = key
        if validators:
            self.validators[key] = validators
        if sensitive:
            self.sensitive.add(key)

    def load_from_file(self, filename):
        """Read an INI-style file and push all of its options as one layer.

        Options from every section are merged into a single mapping, cast to
        their registered types, and unknown keys raise ``KeyError`` (strict mode).
        ``ConfigParser.read`` silently skips files it cannot open, which yields an
        empty layer.
        """
        parser = configparser.ConfigParser()
        parser.read(filename)
        data = {}
        for section in parser.sections():
            data.update(dict(parser.items(section)))
        self.extend(data, cast_types=True, strict=True)

    def write(self, filter_sensitive=False, directory=None):
        """Write the effective configuration to ``config.txt``.

        Layers are replayed oldest first, so a key set in several layers ends up
        with its highest-priority value. All keys go into one ``[Parameters]``
        section.

        :param filter_sensitive: when true, skip keys for which ``is_sensitive`` is true.
        :param directory: target directory; defaults to the current working directory.
        """
        parser = configparser.ConfigParser()
        parser.add_section("Parameters")
        for layer in reversed(self.data):
            for k, v in layer.items():
                if filter_sensitive and self.is_sensitive(k):
                    continue
                parser.set("Parameters", k, six.text_type(v))
        directory = directory or os.getcwd()
        destination = os.path.join(directory, LOCAL_CONFIG)
        with open(destination, "w") as fp:
            parser.write(fp)

    def load_from_environment(self):
        """Push every registered key (or synonym) found in ``os.environ`` as a layer."""
        self.extend(os.environ, cast_types=True)

    def load_defaults(self):
        """Load default configuration values.

        Layers are pushed in this order, each overriding the previous one: the
        packaged ``global_config_defaults.txt``, the packaged
        ``local_config_defaults.txt``, the user's ``~/.dallingerconfig`` and, when
        ``experiment.py`` exists in the working directory, the experiment class's
        ``config_defaults()``.
        """
        # Apply extra parameters before loading the configs
        if experiment_available():
            # In practice, experiment_available should only return False in tests
            self.register_extra_parameters()

        global_config_name = ".dallingerconfig"
        global_config = os.path.expanduser(os.path.join("~/", global_config_name))
        defaults_folder = os.path.join(os.path.dirname(__file__), "default_configs")
        local_defaults_file = os.path.join(defaults_folder, "local_config_defaults.txt")
        global_defaults_file = os.path.join(
            defaults_folder, "global_config_defaults.txt"
        )

        # Load the configuration, with local parameters overriding global ones.
        for config_file in [global_defaults_file, local_defaults_file, global_config]:
            self.load_from_file(config_file)

        if experiment_available():
            self.load_experiment_config_defaults()

    def load(self):
        """Load the full configuration and mark it ready for ``get``.

        On top of ``load_defaults``, pushes ``config.txt`` from the working
        directory (if present) and then matching environment variables, which
        therefore take the highest priority.
        """
        self.load_defaults()
        local_config = os.path.join(os.getcwd(), LOCAL_CONFIG)
        if os.path.exists(local_config):
            self.load_from_file(local_config)
        self.load_from_environment()
        self.ready = True

    def register_extra_parameters(self):
        """Run experiment-specific parameter-registration hooks.

        Makes the working directory importable as ``dallinger_experiment``, then
        calls, each at most once per instance:

        * the experiment class's ``extra_parameters`` attribute, if present;
        * a module-level ``extra_parameters`` function, looked up in
          ``dallinger_experiment.experiment``, then
          ``dallinger_experiment.dallinger_experiment``, then
          ``dallinger_experiment``.
        """
        initialize_experiment_package(os.getcwd())
        extra_parameters = None

        # Load the experiment class; its ``extra_parameters`` hook (if any)
        # runs experiment-specific parameter registrations.
        from dallinger.experiment import load

        exp_klass = load()
        exp_params = getattr(exp_klass, "extra_parameters", None)
        if exp_params is not None and not self._experiment_params_loaded:
            exp_params()
            self._experiment_params_loaded = True

        try:
            from dallinger_experiment.experiment import extra_parameters
        except ImportError:
            try:
                from dallinger_experiment.dallinger_experiment import extra_parameters
            except ImportError:
                try:
                    from dallinger_experiment import extra_parameters
                except ImportError:
                    pass
        if extra_parameters is not None and not self._module_params_loaded:
            extra_parameters()
            self._module_params_loaded = True

    def load_experiment_config_defaults(self):
        """Push the experiment class's ``config_defaults()`` as a new layer.

        Uses strict mode, so an unregistered key in the defaults raises ``KeyError``.
        """
        from dallinger.experiment import load

        exp_klass = load()
        self.extend(exp_klass.config_defaults(), strict=True)
# Process-wide configuration instance, created lazily by ``get_config``.
config = None


def get_config():
    """Return the shared configuration object, creating it on first use.

    When ``experiment.py`` exists in the working directory, the experiment
    class's ``config_class()`` chooses the class to instantiate; otherwise
    ``Configuration`` is used. Every entry of ``default_keys`` is registered on
    the new instance.

    The instance is cached only after registration succeeds. A failure therefore
    propagates to every caller instead of leaving a half-registered object behind
    for later callers.
    """
    global config
    if config is None:
        if experiment_available():
            from dallinger.experiment import load

            exp_klass = load()
            config_class = exp_klass.config_class()
        else:
            config_class = Configuration
        new_config = config_class()
        for registration in default_keys:
            new_config.register(*registration)
        config = new_config
    return config
def initialize_experiment_package(path):
    """Make the specified directory importable as the `dallinger_experiment` package.

    Creates an empty ``__init__.py`` in ``path`` if none exists. Then, unless a
    ``dallinger_experiment`` module is already registered, imports the directory
    as a package by temporarily putting its parent directory first on
    ``sys.path``, and registers the result in ``sys.modules`` under the name
    ``dallinger_experiment``. The temporary ``sys.path`` entry is always removed,
    even if the import fails.

    :param path: directory containing the experiment code.
    :raises Exception: if the imported package's ``__path__`` does not include
        ``path`` (for example when a module with the same name was already
        imported from another location).
    """
    # Create __init__.py if it doesn't exist (needed for Python 2)
    init_py = os.path.join(path, "__init__.py")
    if not os.path.exists(init_py):
        open(init_py, "a").close()

    # Retain already set experiment module
    if sys.modules.get("dallinger_experiment") is not None:
        return

    dirname = os.path.dirname(path)
    basename = os.path.basename(path)
    sys.path.insert(0, dirname)
    try:
        package = __import__(basename)
    finally:
        # Undo the temporary sys.path change even when the import raises, so a
        # failed call does not leave the parent directory importable.
        sys.path.pop(0)

    if Path(path) not in [Path(p) for p in package.__path__]:
        raise Exception(
            "Package was not imported from the requested path! ({} not in {})".format(
                path, package.__path__
            )
        )
    sys.modules["dallinger_experiment"] = package
    package.__package__ = "dallinger_experiment"
    package.__name__ = "dallinger_experiment"
def experiment_available():
    """Report whether an ``experiment.py`` file exists in the current working directory.

    :returns: ``True`` if the relative path ``experiment.py`` exists, else ``False``.
    """
    experiment_file = Path("experiment.py")
    found = experiment_file.exists()
    return found
def raise_invalid_key_error(key):
    """Raise ``KeyError`` reporting that ``key`` is not a recognised config key.

    The removed ``prolific_reward_cents`` key gets a migration hint that points to
    ``base_payment`` instead of the generic message.

    :raises KeyError: always.
    """
    if key != "prolific_reward_cents":
        raise KeyError("{} is not a valid configuration key".format(key))
    raise KeyError(
        "The 'prolific_reward_cents' config variable has been removed. "
        "Use 'base_payment' instead to set base compensation for participants. "
        "Note that base_payment is written in terms of the base unit for the currency, "
        "not in cents. So, if your prolific_reward_cents was originally set to 50, "
        "then you should set your base_payment to 0.5."
    )
def test_deprecation(key):
if key == "prolific_maximum_allowed_minutes":
import warnings
warnings.simplefilter("always", DeprecationWarning)
warnings.warn(
"The 'prolific_maximum_allowed_minutes' config variable has no effect "
+ "as it is currently ignored by the Prolific API.",
DeprecationWarning,
)