BondGraphTools/BondGraphTools

View on GitHub
BondGraphTools/fileio.py

Summary

Maintainability
B
6 hrs
Test Coverage
A
94%
"""The file save/load interface and file format data model.

This module provides the basic IO functionality such as saving and loading
to file.
"""
# todo:
#    As the file format matures, we probably want to move to a object
#    oriented loader that tracks the (as yet undefined) schema.


import logging
import pathlib
import yaml

from .compound import BondGraph
from .actions import connect, new, expose
from .exceptions import InvalidComponentException

logger = logging.getLogger(__name__)

FILE_VERSION = "0.1"


def save(model, filename):
    """Save the model to file.

    Args:
        model: The model to be saved
        filename: The file to save to

    """
    model_directory = _build_model_directory(model)

    models = {}
    templates = {}

    for uri, sub_model in model_directory.items():
        models.update({
            uri: _build_model_data(sub_model, templates)
        })
    data = {
        "version": FILE_VERSION,
        "root": model.name,
        "models": models
    }

    with open(filename, 'w') as filestream:
        yaml.dump(data, filestream, default_flow_style=False)


def _build_model_directory(model):

    try:
        if not model.parent:
            uri = "/"
        else:
            _, uri = model.uri.split(":")

        directory = {uri: model}

        for c in model.components:
            directory.update(_build_model_directory(c))
        return directory
    except AttributeError:
        return {}


def _build_model_data(model, templates):
    components = []
    out = {}
    for c in model.components:
        if isinstance(c, BondGraph):
            _, uri = c.uri.split(":")
            components.append(
                f"{c.name} {uri}"
            )
        else:
            components.append(
                _build_component_string(c)
            )
    out.update({"components": components})

    netlist = []
    for tail, head in model.bonds:
        netlist.append(f"{tail} {head}")

    if netlist:
        out.update({"netlist": netlist})

    ports = []
    for port in model.ports:
        (c1, e), (c2, f) = model._port_map[port]  # noqa
        if c1 == c2:
            ports.append(f"{c1.name} {port.name}")
        else:
            raise NotImplementedError
    if ports:
        out.update({"ports": ports})

    return out


def _build_component_string(component):

    out_str = f"{component.name} {component.template}"
    logger.debug("Trying to serialise: %s", out_str)
    try:
        for param, value in component.params.items():
            logger.debug("Param: %s, %s", param, value)
            if isinstance(value, (int, float)):
                out_str += f" {param}={value}"
            elif isinstance(value, dict):
                try:
                    v = value["value"]
                    if isinstance(v, (float, int)):
                        out_str += f" {param}={v}"
                except KeyError as ex:
                    logger.debug("Skipping: %s ", str(ex))
                    pass

    except AttributeError:
        pass
    logger.debug("Saving component string: %s", out_str)
    return out_str


def load(file_name, model=None, as_name=None):
    """Load a model from file.

    Args:
        file_name (str or Path): The file to load.

    Returns:
        An instance of `BondGraph`

    Raises:
        `NotImplementedError` for incorrect file version.

    """
    if isinstance(file_name, pathlib.Path):
        file_name = str(file_name)

    with open(file_name, 'r') as f:
        data = yaml.load(f, Loader=yaml.SafeLoader)

    version = str(data['version'])
    if version == FILE_VERSION:
        return _builder(data, model, as_name)
    else:
        raise NotImplementedError


def _builder(data, model=None, as_name=None):
    if not model:
        root = "/"
    else:
        root = model
    models = data['models']

    def _build(model_name, template_name):
        model_data = models[template_name]
        netlist = model_data["netlist"]

        logger.debug("%s: Trying to build", model_name)

        model = new(name=model_name)

        for comp_string in model_data["components"]:
            logger.debug("%s: building", comp_string)

            try:
                comp = _base_component_from_str(comp_string)
            except (ValueError, KeyError):
                name, sub_model = comp_string.split(" ")
                comp = _build(name, sub_model)

            model.add(comp)
            logger.debug("%s components complete", model_name)

        _wire(model, netlist)

        try:
            io_ports = model_data["ports"]
            _expose(model, io_ports)
        except KeyError:
            logger.debug("No ports on model ")

        return model

    def _wire(model, netlist):
        logger.debug("%s: trying to wire", model.name)

        def get_port(port_string):

            tokens = iter(port_string.split('.'))

            c = next(tokens)
            try:
                comp, = (comp for comp in model.components if comp.name == c)
            except ValueError:
                raise InvalidComponentException(
                    f"Could not find component {c} in model {model.name}")
            try:
                t2 = next(tokens)
            except StopIteration:
                return comp
            try:
                _ = next(tokens)
            except StopIteration:

                logger.debug("Tyring to get port %s, %s", str(comp), str(t2))
                try:
                    t2 = int(t2)
                except ValueError:
                    pass
                port = comp.get_port(t2)
                logger.debug("Got %s", str(port))
                return port

            else:
                raise NotImplementedError

        for bond_string in netlist:
            logger.debug("%s: bond %s", model.name, bond_string)
            tail_str, head_str = bond_string.split()
            tail = get_port(tail_str)
            head = get_port(head_str)
            connect(tail, head)

    def _expose(model, _io_ports):
        for port_string in _io_ports:
            component_name, port_label = port_string.split(" ")
            comp, = {
                c for c in model.components if c.name == component_name
            }
            expose(comp, port_label)

    def _parse_build_args(in_args):

        if not in_args:
            return [], {}

        arg = in_args[0]
        args, kwargs = _parse_build_args(in_args[1:])

        try:
            k, v = arg.split("=")
        except ValueError:
            k = None
            v = arg

        if v.isnumeric():
            try:
                v = int(v)
            except ValueError:
                v = float(v)
        if not k:
            args.append(v)
        else:
            kwargs.update({k: v})

        return args, kwargs

    def _base_component_from_str(string):
        label, tempate, *build_args = string.split()
        args, kwargs = _parse_build_args(build_args)
        library, component = tempate.split("/")

        comp = new(name=label,
                   library=library,
                   component=component,
                   value=args)

        for k, v in kwargs.items():
            comp.set_param(k, v)
        return comp

    out = _build(root, root)

    if as_name:
        out.name = as_name
    elif not model:
        out.name = data["root"]

    return out