ejhigson/dyPolyChord

View on GitHub
dyPolyChord/output_processing.py

Summary

Maintainability
A
0 mins
Test Coverage
#!/usr/bin/env python
"""
Functions for loading and processing dyPolyChord dynamic nested sampling runs.
"""
import os
import warnings
import numpy as np
import nestcheck.ns_run_utils
import nestcheck.data_processing
import nestcheck.io_utils as iou


def settings_root(likelihood_name, prior_name, ndim, **kwargs):
    """
    Returns a standard string containing information about settings.

    Parameters
    ----------
    likelihood_name: str
    prior_name: str
    ndim: int
    prior_scale: float or int
    dynamic_goal: float, int or None
    nlive_const: int
    nrepeats: int
    ninit: int, optional
        Must be supplied whenever dynamic_goal is not None.
    init_step: int, optional
        Must be supplied whenever dynamic_goal is neither None nor 0.

    Returns
    -------
    root: str
        Settings string in which every '.' has been replaced by '_'.

    Raises
    ------
    KeyError
        If prior_scale, dynamic_goal, nlive_const or nrepeats is not given.
    TypeError
        If any unexpected keyword argument is given.
    AssertionError
        If ninit or init_step is needed for the given dynamic_goal but is
        None.
    """
    prior_scale = kwargs.pop('prior_scale')
    dynamic_goal = kwargs.pop('dynamic_goal')
    nlive_const = kwargs.pop('nlive_const')
    nrepeats = kwargs.pop('nrepeats')
    ninit = kwargs.pop('ninit', None)
    init_step = kwargs.pop('init_step', None)
    if kwargs:
        raise TypeError('Unexpected **kwargs: {0}'.format(kwargs))
    root = '{}_{}_{}_dg{}'.format(likelihood_name, prior_name, prior_scale,
                                  dynamic_goal)
    if dynamic_goal is not None:
        # The checks below raise explicitly instead of using assert
        # statements so that they are not stripped when Python runs with -O
        # (which would silently produce roots containing 'None').
        # AssertionError is kept as the exception type for backward
        # compatibility with callers of the previous assert-based checks.
        if ninit is None:
            raise AssertionError(
                'ninit must be given when dynamic_goal is not None')
        root += '_{}init'.format(ninit)
        if dynamic_goal != 0:
            if init_step is None:
                raise AssertionError(
                    'init_step must be given when dynamic_goal is not None '
                    'or 0')
            root += '_{}is'.format(init_step)
    root += '_{}d_{}nlive_{}nrepeats'.format(ndim, nlive_const, nrepeats)
    # Remove decimal points (e.g. from float settings) from the root.
    return root.replace('.', '_')


def process_dypolychord_run(file_root, base_dir, **kwargs):
    """
    Read the initial and dynamic run outputs of a dyPolyChord calculation
    and merge them into a single run in the nestcheck format.

    Parameters
    ----------
    file_root: str
        Root of the output file names; '_init', '_dyn' and '_dyn_info' are
        appended to it to locate the individual outputs.
    base_dir: str
        Directory containing the output files.
    dynamic_goal: float
        If equal to 0, every thread of the dynamic run is required to start
        at logl=-inf.
    dup_assert: bool, optional
        Whether to throw an AssertionError if there are duplicate point
        loglikelihood values.
    dup_warn: bool, optional
        Whether to give a UserWarning if there are duplicate point
        loglikelihood values.

    Returns
    -------
    run: dict
        Nested sampling run in nestcheck format (see
        http://nestcheck.readthedocs.io/en/latest/api.html for more
        information). Its 'output' entry holds file_root, base_dir and,
        when it could be read for both runs, the total 'nlike'.

    Raises
    ------
    TypeError
        If any unexpected keyword argument is given.
    AssertionError
        If the initial run (or, when dynamic_goal == 0, the dynamic run)
        contains threads which do not start at -inf.
    """
    dynamic_goal = kwargs.pop('dynamic_goal')
    dup_opts = {'dup_assert': kwargs.pop('dup_assert', False),
                'dup_warn': kwargs.pop('dup_warn', False)}
    if kwargs:
        raise TypeError('Unexpected **kwargs: {0}'.format(kwargs))
    load_run = nestcheck.data_processing.process_polychord_run
    # Initial exploratory run: all of its threads must begin at -inf.
    init = load_run(file_root + '_init', base_dir, **dup_opts)
    init_starts = init['thread_min_max'][:, 0]
    assert np.all(init_starts == -np.inf), (
        'Initial run contains threads not starting at -inf.\n'
        'thread_min_max=' + str(init['thread_min_max']))
    # Dynamic run and the extra information stored alongside it.
    dyn = load_run(file_root + '_dyn', base_dir, **dup_opts)
    info_path = os.path.join(base_dir, file_root + '_dyn_info')
    dyn_info = iou.pickle_load(info_path)
    if dynamic_goal == 0:
        # With dynamic_goal == 0 nlive should only decrease, so every
        # dynamic thread should start by sampling.
        assert np.all(dyn['thread_min_max'][:, 0] == -np.inf), (
            str(dyn['thread_min_max']))
    resumed = 'resume_ndead' in dyn_info
    if resumed:
        # The dynamic run began by resuming the initial run part way
        # through, so points present in both runs must be removed.
        run = combine_resumed_dyn_run(init, dyn, dyn_info['resume_ndead'])
    else:
        # No samples are shared between the two runs' files, so the
        # standard nestcheck combination can be used directly.
        run = nestcheck.ns_run_utils.combine_ns_runs([init, dyn])
    run_output = {'file_root': file_root, 'base_dir': base_dir}
    try:
        nlike = init['output']['nlike'] + dyn['output']['nlike']
        if resumed:
            # Do not double count likelihood calls made before the resume.
            nlike = nlike - dyn_info['resume_nlike']
    except KeyError:
        pass  # protect from error reading nlike from .stats file
    else:
        run_output['nlike'] = nlike
    run['output'] = run_output
    # Confirm the combined run has the properties nestcheck expects.
    nestcheck.ns_run_utils.check_ns_run(run, **dup_opts)
    return run


def combine_resumed_dyn_run(init, dyn, resume_ndead):
    """
    Merge initial run and dynamic run which was resumed from it, including
    removing duplicate points present in both runs.

    Note that ``init`` is modified in place: its 'theta', 'nlive_array',
    'logl', 'thread_labels' and 'thread_min_max' entries are overwritten
    while the duplicate points are removed and the threads are relabelled.

    Parameters
    ----------
    init: dict
        Initial exploratory run in nestcheck format (see
        http://nestcheck.readthedocs.io/en/latest/api.html for more
        information).
    dyn: dict
        Dynamic run in nestcheck format.
    resume_ndead: int
        The number of dead points present when dyn was resumed from init.

    Returns
    -------
    run: dict
        Combined run in nestcheck format.

    Raises
    ------
    AssertionError
        If the first resume_ndead logl values of init and dyn differ, or if
        the relabelled init threads do not match init['thread_min_max'].

    Warns
    -----
    UserWarning
        If the first remaining point of an init thread (expected to be a
        live point at resume) has a logl value not present in dyn['logl'].
    """
    # The points which were dead at the resume appear at the start of both
    # runs; check they agree, then drop them from init.
    assert np.array_equal(
        init['logl'][:resume_ndead], dyn['logl'][:resume_ndead]), (
            'The first {0} points should be the same'.format(resume_ndead))
    init['theta'] = init['theta'][resume_ndead:, :]
    for key in ['nlive_array', 'logl', 'thread_labels']:
        init[key] = init[key][resume_ndead:]
    # We also need to remove the points that were live when the resume file was
    # written, as these show up as samples in both dyn and init
    live_inds = []  # indexes (into the init arrays) of points to delete
    empty_thread_inds = []  # thread indexes with no points left after deletion
    # NOTE(review): using i as the row of thread_min_max assumes every init
    # thread still has at least one point after the slicing above, so that
    # the i-th unique label corresponds to the i-th row - confirm.
    for i, th_lab in enumerate(np.unique(init['thread_labels'])):
        th_inds = np.where(init['thread_labels'] == th_lab)[0]
        # The first remaining point of the thread is the candidate live point
        live_logl = init['logl'][th_inds[0]]
        # The thread's remaining samples now start from this logl value
        init['thread_min_max'][i, 0] = live_logl
        if np.where(dyn['logl'] == live_logl)[0].shape[0] > 0:
            # Point is also in dyn, so mark init's copy for removal
            live_inds.append(th_inds[0])
            if th_inds.shape[0] == 1:
                # This was the thread's only point, so the thread will be
                # empty once the point is removed
                empty_thread_inds.append(i)
        else:
            warnings.warn(
                ('Expected live point at resume should be present in dynamic '
                 'run. If there are no further errors, this warning can be '
                 'ignored.\nlogl={}, th_lab={}, inds={}, init samples (after '
                 'removing first resume_ndead)={}, unique threads in init={}, '
                 'dyn samples={}, resume_ndead={}.').format(
                     str(live_logl), th_lab,
                     np.where(dyn['logl'] == live_logl),
                     init['logl'].shape[0],
                     np.unique(init['thread_labels']).shape[0],
                     dyn['logl'].shape[0], resume_ndead), UserWarning)
    # Remove the live points at resume from init
    init['theta'] = np.delete(init['theta'], live_inds, axis=0)
    for key in ['nlive_array', 'logl', 'thread_labels']:
        init[key] = np.delete(init[key], live_inds)
    # Deal with the case that one of the threads is now empty
    if empty_thread_inds:
        # remove any empty threads from thread_min_max
        init['thread_min_max'] = np.delete(
            init['thread_min_max'], empty_thread_inds, axis=0)
        # Now we need to reorder the thread labels to avoid gaps
        thread_labels_new = np.full(init['thread_labels'].shape, np.nan)
        for i, th_lab in enumerate(np.unique(init['thread_labels'])):
            inds = np.where(init['thread_labels'] == th_lab)[0]
            thread_labels_new[inds] = i
            # Check the newly relabelled thread label matches thread_min_max
            assert init['thread_min_max'][i, 0] <= init['logl'][inds[0]]
            assert init['thread_min_max'][i, 1] == init['logl'][inds[-1]]
        # Every point must have been given a new label (no nan left over)
        assert np.all(~np.isnan(thread_labels_new))
        init['thread_labels'] = thread_labels_new.astype(int)
    # Add the init threads to dyn with new labels that continue on from the dyn
    # labels
    init['thread_labels'] += dyn['thread_min_max'].shape[0]
    run = nestcheck.ns_run_utils.combine_threads(
        nestcheck.ns_run_utils.get_run_threads(dyn) +
        nestcheck.ns_run_utils.get_run_threads(init),
        assert_birth_point=False)
    return run


def clean_extra_output(root_name):
    """Delete the intermediate files written by dyPolyChord for the initial
    and dynamic runs, so that only the combined run's output files (in
    PolyChord format) remain.

    Parameters
    ----------
    root_name: str
        File root. Equivalent to os.path.join(base_dir, file_root).

    Raises
    ------
    OSError
        If the dyn_info pickle or any .stats, dead-birth or dead file of the
        initial or dynamic run cannot be removed. A .resume file which
        cannot be removed is ignored.
    """
    os.remove(root_name + '_dyn_info.pkl')
    for run_part in ('init', 'dyn'):
        prefix = root_name + '_' + run_part
        for suffix in ('.stats', '_dead-birth.txt', '_dead.txt'):
            os.remove(prefix + suffix)
        # Any .resume file left over is no longer needed: by the time this
        # function is called both the initial and dynamic runs have
        # finished. Such a file may not exist, so failure here is ignored.
        try:
            os.remove(prefix + '.resume')
        except OSError:
            pass