# dvc/config.py
"""DVC config objects."""
import os
import re
from contextlib import contextmanager
from functools import partial
from typing import TYPE_CHECKING, Optional
from funcy import compact, memoize, re_find
from dvc.exceptions import DvcException, NotDvcRepoError
from dvc.log import logger
from .utils.objects import cached_property
if TYPE_CHECKING:
from dvc.fs import FileSystem
from dvc.types import DictStrAny
# Rebind the imported project logger to a child logger named after this module.
logger = logger.getChild(__name__)
class ConfigError(DvcException):
    """DVC config exception.

    The message passed in is prefixed with ``config file error: `` before it
    is handed to the base exception.
    """

    def __init__(self, msg):
        prefixed = f"config file error: {msg}"
        super().__init__(prefixed)
class RemoteConfigError(ConfigError):
    """Base class for the remote-related config errors defined in this module."""
class NoRemoteError(RemoteConfigError):
    """Remote config error; going by its name, signals that no remote is set.

    NOTE(review): raise sites are outside this module - confirm exact meaning.
    """
class RemoteNotFoundError(RemoteConfigError):
    """Remote config error; going by its name, a named remote was not found.

    NOTE(review): raise sites are outside this module - confirm exact meaning.
    """
class MachineConfigError(ConfigError):
    """Base class for the machine-related config errors defined in this module."""
class NoMachineError(MachineConfigError):
    """Machine config error; going by its name, signals that no machine is set.

    NOTE(review): raise sites are outside this module - confirm exact meaning.
    """
class MachineNotFoundError(MachineConfigError):
    """Machine config error; going by its name, a named machine was not found.

    NOTE(review): raise sites are outside this module - confirm exact meaning.
    """
@memoize
def get_compiled_schema():
    """Return the voluptuous ``Schema`` built from the project's ``SCHEMA``.

    The function is wrapped in ``memoize``; imports are kept local to the call.
    """
    from voluptuous import Schema

    from .config_schema import SCHEMA as raw_schema

    compiled = Schema(raw_schema)
    return compiled
def to_bool(value):
    """Pass ``value`` through the config schema's ``Bool`` and return the result."""
    from .config_schema import Bool as coerce_bool

    return coerce_bool(value)
class Config(dict):
    """Class that manages configuration files for a DVC repo.

    The instance itself is the merged configuration mapping: ``load()`` fills
    it from the files of every level in ``LEVELS`` order, so later levels
    shadow earlier ones.

    Args:
        dvc_dir (str): optional path to `.dvc` directory, that is used to
            access repo-specific configs like .dvc/config and
            .dvc/config.local.
        local_dvc_dir (str): optional directory that holds `config.local`.
            Falls back to `dvc_dir` when neither `fs` nor `local_dvc_dir`
            is given.
        validate (bool): optional flag to tell dvc if it should validate the
            config or just load it as is. 'True' by default.
        fs (FileSystem): optional filesystem used for the repo-level config;
            a ``LocalFileSystem`` is used when omitted.
        config (dict): optional overrides merged on top of the loaded files.
        remote (str): optional remote name stored as ``core.remote``.
        remote_config (dict): optional settings merged into the selected
            remote's section.

    Raises:
        ConfigError: thrown if config has an invalid format.
    """

    SYSTEM_LEVELS = ("system", "global")
    REPO_LEVELS = ("repo", "local")
    # In the order they shadow each other
    LEVELS = SYSTEM_LEVELS + REPO_LEVELS

    CONFIG = "config"
    CONFIG_LOCAL = "config.local"

    def __init__(
        self,
        dvc_dir: Optional[str] = None,
        local_dvc_dir: Optional[str] = None,
        validate: bool = True,
        fs: Optional["FileSystem"] = None,
        config: Optional["DictStrAny"] = None,
        remote: Optional[str] = None,
        remote_config: Optional["DictStrAny"] = None,
    ):
        from dvc.fs import LocalFileSystem

        dvc_dir = os.fspath(dvc_dir) if dvc_dir else None
        self.dvc_dir = dvc_dir

        # `wfs` always points at the local filesystem; `fs` may be a custom
        # one and is only used for the repo-level config (see `_get_fs`).
        self.wfs = LocalFileSystem()
        self.fs = fs or self.wfs

        if dvc_dir:
            self.dvc_dir = self.fs.abspath(dvc_dir)

        self.local_dvc_dir = local_dvc_dir
        if not fs and not local_dvc_dir:
            self.local_dvc_dir = dvc_dir

        self.load(
            validate=validate, config=config, remote=remote, remote_config=remote_config
        )

    @classmethod
    def from_cwd(cls, fs: Optional["FileSystem"] = None, **kwargs):
        """Build a Config for the DVC dir found by ``Repo.find_dvc_dir``.

        When ``NotDvcRepoError`` is raised, the config is created with
        ``dvc_dir=None``. Extra keyword arguments go to the constructor.
        """
        from dvc.repo import Repo

        try:
            dvc_dir = Repo.find_dvc_dir(fs=fs)
        except NotDvcRepoError:
            dvc_dir = None

        return cls(dvc_dir=dvc_dir, fs=fs, **kwargs)

    @classmethod
    def get_dir(cls, level):
        """Return the config directory for the "global" or "system" level."""
        from dvc.dirs import global_config_dir, system_config_dir

        assert level in ("global", "system")

        if level == "global":
            return global_config_dir()
        if level == "system":
            return system_config_dir()

    @cached_property
    def files(self) -> dict[str, str]:
        """Map each available level name to the path of its config file.

        "system" and "global" are always present; "repo" is added only when
        ``dvc_dir`` is set and "local" only when ``local_dvc_dir`` is set.
        """
        files = {
            level: os.path.join(self.get_dir(level), self.CONFIG)
            for level in ("system", "global")
        }

        if self.dvc_dir is not None:
            files["repo"] = self.fs.join(self.dvc_dir, self.CONFIG)

        if self.local_dvc_dir is not None:
            files["local"] = self.wfs.join(self.local_dvc_dir, self.CONFIG_LOCAL)

        return files

    @staticmethod
    def init(dvc_dir):
        """Initializes dvc config.

        Opens ``<dvc_dir>/config`` in "w+" mode (creating or truncating it)
        and builds the Config while the file is open.

        Args:
            dvc_dir (str): path to .dvc directory.

        Returns:
            dvc.config.Config: config object.
        """
        config_file = os.path.join(dvc_dir, Config.CONFIG)
        with open(config_file, "w+", encoding="utf-8"):
            return Config(dvc_dir)

    def merge(self, config):
        """Recursively merge ``config`` into this object (module-level ``merge``)."""
        merge(self, config)

    def load(
        self,
        validate: bool = True,
        config: Optional["DictStrAny"] = None,
        remote: Optional[str] = None,
        remote_config: Optional["DictStrAny"] = None,
    ):
        """Loads config from all the config files.

        The merged result replaces the current contents of this object. The
        replacement happens only after every step that can fail has
        succeeded, so a failed load leaves the previous contents untouched.

        Args:
            validate: validate the merged config against the schema.
            config: optional overrides merged on top of the loaded files.
            remote: optional remote name stored as ``core.remote``.
            remote_config: optional settings merged into the section of
                ``remote`` (or of the configured ``core.remote``).

        Raises:
            ConfigError: thrown if config has an invalid format.
            ValueError: ``remote_config`` was given but no remote name is
                available.
        """
        conf = self.load_config_to_level()

        if config is not None:
            merge(conf, config)

        if validate:
            conf = self.validate(conf)

        if remote:
            conf["core"]["remote"] = remote

        if remote_config:
            remote = remote or conf["core"].get("remote")
            if not remote:
                raise ValueError("Missing remote name")

            merge(conf, {"remote": {remote: remote_config}})

        # Swap contents last: clearing earlier would leave the object empty
        # if any of the steps above raised.
        self.clear()
        self.update(conf)

    def _get_fs(self, level):
        """Return the filesystem that holds the config file for ``level``."""
        # NOTE: this might be a Gitfs, which doesn't see things outside of
        # the repo.
        return self.fs if level == "repo" else self.wfs

    @staticmethod
    def load_file(path, fs=None) -> dict:
        """Parse the config file at ``path`` into a nested dict.

        Keys are lower-cased and ``remote``/``machine``/``db`` named sections
        are regrouped by ``_parse_named``.

        Args:
            path: path of the config file.
            fs: filesystem to open it with; ``dvc.fs.localfs`` by default.

        Raises:
            ConfigError: the file cannot be decoded or parsed.
        """
        from configobj import ConfigObj, ConfigObjError

        from dvc.fs import localfs

        fs = fs or localfs

        with fs.open(path) as fobj:
            try:
                conf_obj = ConfigObj(fobj)
            except (UnicodeDecodeError, ConfigObjError) as exc:
                raise ConfigError(str(exc)) from exc

        return _parse_named(_lower_keys(conf_obj.dict()))

    def _load_config(self, level):
        """Load the file of ``level``; a missing file yields an empty dict."""
        filename = self.files[level]
        fs = self._get_fs(level)

        try:
            return self.load_file(filename, fs=fs)
        except FileNotFoundError:
            return {}

    def _save_config(self, level, conf_dict):
        """Write ``conf_dict`` to the config file of ``level``.

        The parent directory is created through the level's filesystem and
        named sections are packed back into ``'<kind> "<name>"'`` headers.
        """
        from configobj import ConfigObj

        filename = self.files[level]
        fs = self._get_fs(level)

        logger.debug("Writing '%s'.", filename)

        fs.makedirs(os.path.dirname(filename))

        config = ConfigObj(_pack_named(conf_dict))
        with fs.open(filename, "wb") as fobj:
            config.write(fobj)
        config.filename = filename

    def load_one(self, level):
        """Load a single level with its path options resolved.

        Every top-level key of the compiled schema is guaranteed to exist in
        the result (as an empty dict when the file does not define it).
        """
        conf = self._load_config(level)
        conf = self._load_paths(conf, self.files[level])

        # Auto-verify sections
        for key in get_compiled_schema().schema:
            conf.setdefault(key, {})

        return conf

    @staticmethod
    def _resolve(conf_dir, path):
        """Resolve a path option read from a config file in ``conf_dir``.

        URLs (``scheme://``) and absolute paths are returned unchanged. A
        path that becomes absolute after user expansion is returned as an
        ``ExpPath`` built from the expanded and the original text. Anything
        else is joined onto ``conf_dir``, made absolute and returned as a
        ``RelPath``.
        """
        from .config_schema import ExpPath, RelPath

        if re.match(r"\w+://", path):
            return path

        if os.path.isabs(path):
            return path

        # on windows convert slashes to backslashes
        # to have path compatible with abs_conf_dir
        if os.path.sep == "\\" and "/" in path:
            path = path.replace("/", "\\")

        expanded = os.path.expanduser(path)
        if os.path.isabs(expanded):
            return ExpPath(expanded, path)

        return RelPath(os.path.abspath(os.path.join(conf_dir, path)))

    @classmethod
    def _load_paths(cls, conf, filename):
        """Apply ``_resolve`` to the path options of ``conf``.

        Paths are resolved against the absolute directory of ``filename``.
        """
        conf_dir = os.path.abspath(os.path.dirname(filename))
        resolve = partial(cls._resolve, conf_dir)
        return Config._map_dirs(conf, resolve)

    @staticmethod
    def _to_relpath(conf_dir, path):
        """Convert a path option to the form written to a config file.

        URLs are returned unchanged, an ``ExpPath`` gives back its
        ``def_path``, and a path that user expansion would alter is only
        converted with ``localfs.as_posix``. A ``RelPath`` or non-absolute
        path is passed through ``relpath(path, conf_dir)`` and then
        ``localfs.as_posix``. Other absolute paths are returned unchanged.
        """
        from dvc.fs import localfs
        from dvc.utils import relpath

        from .config_schema import ExpPath, RelPath

        if re.match(r"\w+://", path):
            return path

        if isinstance(path, ExpPath):
            return path.def_path

        if os.path.expanduser(path) != path:
            return localfs.as_posix(path)

        if isinstance(path, RelPath) or not os.path.isabs(path):
            path = relpath(path, conf_dir)
            return localfs.as_posix(path)

        return path

    @staticmethod
    def _save_paths(conf, filename):
        """Apply ``_to_relpath`` (relative to ``filename``'s dir) to ``conf``."""
        conf_dir = os.path.dirname(filename)
        rel = partial(Config._to_relpath, conf_dir)
        return Config._map_dirs(conf, rel)

    @staticmethod
    def _map_dirs(conf, func):
        """Run ``func`` over every path-like option of ``conf``.

        The options are listed in a voluptuous schema (``cache.dir`` plus
        selected keys of each ``remote`` and ``machine`` section). Keys not
        listed are allowed through via ``ALLOW_EXTRA``. Returns what the
        schema call returns.
        """
        from voluptuous import ALLOW_EXTRA, Schema

        dirs_schema = {
            "cache": {"dir": func},
            "remote": {
                str: {
                    "url": func,
                    "gdrive_user_credentials_file": func,
                    "gdrive_service_account_json_file_path": func,
                    "credentialpath": func,
                    "keyfile": func,
                    "cert_path": func,
                    "key_path": func,
                }
            },
            "machine": {
                str: {
                    "startup_script": func,
                    "setup_script": func,
                }
            },
        }
        return Schema(dirs_schema, extra=ALLOW_EXTRA)(conf)

    def load_config_to_level(self, level=None):
        """Merge the configs of all levels that precede ``level`` in ``LEVELS``.

        With ``level=None`` every level is merged. Levels without an entry
        in ``files`` are skipped.
        """
        merged_conf: dict = {}
        for merge_level in self.LEVELS:
            if merge_level == level:
                break
            if merge_level in self.files:
                merge(merged_conf, self.load_one(merge_level))
        return merged_conf

    def read(self, level=None):
        """Return the merged config, or only ``level``'s config when given."""
        # NOTE: we read from a merged config by default, same as git config
        if level is None:
            return self.load_config_to_level()
        return self.load_one(level)

    @contextmanager
    def edit(self, level=None, validate=True):
        """Context manager that yields one level's config for modification.

        On normal exit, the yielded dict has its paths converted back to
        their saved form. It is optionally validated after being merged
        over the levels that precede it, written to the level's file, and
        this object is reloaded. Nothing after the ``yield`` runs if the
        managed body raises.

        Args:
            level: level to edit; "repo" when omitted.
            validate: validate the merged result before saving.

        Raises:
            ConfigError: a repo-level edit was requested without a DVC dir,
                or validation failed.
        """
        # NOTE: we write to repo config by default, same as git config
        level = level or "repo"
        if self.dvc_dir is None and level in self.REPO_LEVELS:
            raise ConfigError("Not inside a DVC repo")

        conf = self.load_one(level)
        yield conf

        conf = self._save_paths(conf, self.files[level])

        merged_conf = self.load_config_to_level(level)
        merge(merged_conf, conf)

        if validate:
            self.validate(merged_conf)

        self._save_config(level, conf)
        self.load(validate=validate)

    @staticmethod
    def validate(data):
        """Validate ``data`` with the compiled schema and return the result.

        Raises:
            ConfigError: the schema rejected ``data``; the original
                voluptuous exception is not chained.
        """
        from voluptuous import Invalid

        try:
            return get_compiled_schema()(data)
        except Invalid as exc:
            raise ConfigError(str(exc)) from None
def _parse_named(conf):
    """Regroup ``'<kind> "<name>"'`` section headers under their kind.

    Sections whose header matches ``remote``, ``machine`` or ``db`` followed
    by a quoted name are stored as ``result[kind][name]``. Every other
    section is copied under its own header. The three kinds are always
    present in the result.
    """
    parsed: dict[str, dict] = {"remote": {}, "machine": {}, "db": {}}

    for header, body in conf.items():
        found = re_find(r'^\s*(remote|machine|db)\s*"(.*)"\s*$', header)
        if not found:
            parsed[header] = body
            continue
        kind, name = found
        parsed[kind][name] = body

    return parsed
def _pack_named(conf):
    """Flatten named sections back into ``'<kind> "<name>"'`` headers.

    Starts from ``compact(conf)`` and, for each of ``remote``, ``machine``
    and ``db``, adds one flattened entry per named section of ``conf``, then
    removes the grouping key itself.
    """
    # Drop empty sections
    packed = compact(conf)

    # Transform remote.name -> 'remote "name"'
    for kind in ("remote", "machine", "db"):
        packed.update(
            {f'{kind} "{name}"': body for name, body in conf[kind].items()}
        )
        packed.pop(kind, None)

    return packed
def merge(into, update):
    """Merges second dict into first recursively.

    ``into`` is modified in place. Where both sides hold a dict under the
    same key the two are merged; otherwise the value from ``update``
    replaces (or adds) the entry.
    """
    for key, incoming in update.items():
        existing = into.get(key)
        if not (isinstance(existing, dict) and isinstance(incoming, dict)):
            into[key] = incoming
            continue
        merge(existing, incoming)
def _lower_keys(data):
return {
k.lower(): _lower_keys(v) if isinstance(v, dict) else v for k, v in data.items()
}