Source code for calliope.core.attrdict

"""
Copyright (C) 2013-2018 Calliope contributors listed in AUTHORS.
Licensed under the Apache 2.0 License (see LICENSE file).

attrdict.py
~~~~~~~~~~~

Implements the AttrDict class (a subclass of regular dict)
used for managing model configuration.

"""

import io

import numpy as np
import ruamel.yaml as ruamel_yaml

from calliope.core.util.tools import relative_path
from calliope.core.util.logging import logger


class __Missing(object):
    def __nonzero__(self):
        return False


_MISSING = __Missing()


def _yaml_load(src):
    """Load YAML from a file object or path with useful parser errors"""
    if not isinstance(src, str):
        try:
            src_name = src.name
        except AttributeError:
            src_name = '<yaml stringio>'
        # Force-load file streams as that allows the parser to print
        # much more context when it encounters an error
        src = src.read()
    else:
        src_name = '<yaml string>'
    try:
        result = ruamel_yaml.safe_load(src)
        if not isinstance(result, dict):
            raise ValueError('Could not parse {} as YAML'.format(src_name))
        return result
    except ruamel_yaml.YAMLError:
        logger.error(
            'Parser error when reading YAML '
            'from {}.'.format(src_name)
        )
        raise


[docs]class AttrDict(dict): """ A subclass of ``dict`` with key access by attributes:: d = AttrDict({'a': 1, 'b': 2}) d.a == 1 # True Includes a range of additional methods to read and write to YAML, and to deal with nested keys. """ __getattr__ = dict.__getitem__ __setattr__ = dict.__setitem__ def __init__(self, source_dict=None): super().__init__() if source_dict is not None: if isinstance(source_dict, dict): self.init_from_dict(source_dict) else: raise ValueError('Must pass a dict to AttrDict')
[docs] def copy(self): """Override copy method so that it returns an AttrDict""" return AttrDict(self.as_dict().copy())
[docs] def init_from_dict(self, d): """ Initialize a new AttrDict from the given dict. Handles any nested dicts by turning them into AttrDicts too:: d = AttrDict({'a': 1, 'b': {'x': 1, 'y': 2}}) d.b.x == 1 # True """ for k, v in d.items(): # First, keys must be strings, not ints if isinstance(k, int): k = str(k) # Now, assign to the key, handling nested AttrDicts properly if isinstance(v, dict): self.set_key(k, AttrDict(v)) elif isinstance(v, list): # Modifying the list in-place so that if it is a modified # list subclass, e.g. CommentedSeq, it is not killed for i in range(len(v)): if isinstance(v[i], dict): v[i] = AttrDict(v[i]) self.set_key(k, v) else: self.set_key(k, v)
@classmethod def _resolve_imports(cls, loaded, resolve_imports, base_path=None): if isinstance(resolve_imports, bool) and resolve_imports is True and 'import' in loaded: loaded_dict = loaded elif isinstance(resolve_imports, str) and resolve_imports + '.import' in loaded.keys_nested(): loaded_dict = loaded.get_key(resolve_imports) else: # Return right away if no importing to be done return loaded # If we end up here, we have something to import imports = loaded_dict.get_key('import') if not isinstance(imports, list): raise ValueError('`import` must be a list.') for k in imports: if base_path: path = relative_path(base_path, k) else: path = k imported = cls.from_yaml(path) # loaded is added to imported (i.e. it takes precedence) imported.union(loaded_dict) loaded_dict = imported # 'import' key itself is no longer needed loaded_dict.del_key('import') if isinstance(resolve_imports, str): loaded.set_key(resolve_imports, loaded_dict) else: loaded = loaded_dict return loaded
[docs] @classmethod def from_yaml(cls, f, resolve_imports=True): """ Returns an AttrDict initialized from the given path or file object ``f``, which must point to a YAML file. If ``resolve_imports`` is True, top-level ``import:`` statements are resolved recursively, else they are treated like any other key. If ``resolve_imports`` is a string, such as ``foobar``, import statements underneath that key are resolved, i.e. ``foobar.import:``. When resolving import statements, anything defined locally overrides definitions in the imported file. """ if isinstance(f, str): with open(f, 'r') as src: loaded = cls(_yaml_load(src)) else: loaded = cls(_yaml_load(f)) loaded = cls._resolve_imports(loaded, resolve_imports, base_path=f) return loaded
[docs] @classmethod def from_yaml_string(cls, string, resolve_imports=True): """ Returns an AttrDict initialized from the given string, which must be valid YAML. """ loaded = cls(_yaml_load(string)) loaded = cls._resolve_imports(loaded, resolve_imports) return loaded
[docs] def set_key(self, key, value): """ Set the given ``key`` to the given ``value``. Handles nested keys, e.g.:: d = AttrDict() d.set_key('foo.bar', 1) d.foo.bar == 1 # True """ if isinstance(value, dict) and not isinstance(value, AttrDict): value = AttrDict(value) if '.' in key: key, remainder = key.split('.', 1) try: self[key].set_key(remainder, value) except KeyError: self[key] = AttrDict() self[key].set_key(remainder, value) except AttributeError: if self[key] is None: # If the value is None, we replace it self[key] = AttrDict() self[key].set_key(remainder, value) # Else there is probably something there, and we don't just # want to overwrite so stop and warn the user else: raise KeyError('Cannot set nested key on non-dict key.') else: self[key] = value
[docs] def get_key(self, key, default=_MISSING): """ Looks up the given ``key``. Like set_key(), deals with nested keys. If default is anything but ``_MISSING``, the given default is returned if the key does not exist. """ if '.' in key: # Nested key of form "foo.bar" key, remainder = key.split('.', 1) if default != _MISSING: try: value = self[key].get_key(remainder, default) except KeyError: # subdict exists, but doesn't contain key return default except AttributeError: # key points to non-dict thing, so no get_key attribute return default else: value = self[key].get_key(remainder) else: # Single, non-nested key of form "foo" if default != _MISSING: return self.get(key, default) else: return self[key] return value
[docs] def del_key(self, key): """Delete the given key. Properly deals with nested keys.""" if '.' in key: key, remainder = key.split('.', 1) try: del self[key][remainder] except KeyError: self[key].del_key(remainder) # If we removed the last subkey, delete the parent key too if len(self[key].keys()) == 0: del self[key] else: del self[key]
[docs] def as_dict(self, flat=False): """ Return the AttrDict as a pure dict (with nested dicts if necessary). """ if flat: return self.as_dict_flat() else: return self.as_dict_nested()
def as_dict_nested(self): d = {} for k, v in self.items(): if isinstance(v, AttrDict): d[k] = v.as_dict() elif isinstance(v, list): d[k] = [ i if not isinstance(i, AttrDict) else i.as_dict() for i in v] else: d[k] = v return d def as_dict_flat(self): d = {} keys = self.keys_nested() for k in keys: d[k] = self.get_key(k) return d
[docs] def to_yaml(self, path=None): """ Saves the AttrDict to the ``path`` as a YAML file, or returns a YAML string if ``path`` is None. """ result = self.copy() yaml_ = ruamel_yaml.YAML() yaml_.indent = 2 yaml_.block_seq_indent = 0 # Numpy objects should be converted to regular Python objects, # so that they are properly displayed in the resulting YAML output for k in result.keys_nested(): # Convert numpy numbers to regular python ones v = result.get_key(k) if isinstance(v, np.floating): result.set_key(k, float(v)) elif isinstance(v, np.integer): result.set_key(k, int(v)) # Lists are turned into seqs so that they are formatted nicely elif isinstance(v, list): result.set_key(k, yaml_.seq(v)) result = result.as_dict() if path is not None: with open(path, 'w') as f: yaml_.dump(result, f) else: stream = io.StringIO() yaml_.dump(result, stream) return stream.getvalue()
[docs] def keys_nested(self, subkeys_as='list'): """ Returns all keys in the AttrDict, sorted, including the keys of nested subdicts (which may be either regular dicts or AttrDicts). If ``subkeys_as='list'`` (default), then a list of all keys is returned, in the form ``['a', 'b.b1', 'b.b2']``. If ``subkeys_as='dict'``, a list containing keys and dicts of subkeys is returned, in the form ``['a', {'b': ['b1', 'b2']}]``. """ keys = [] for k, v in sorted(self.items()): # Check if dict instance (which AttrDict is too), # and for non-emptyness of the dict if isinstance(v, dict) and v: if subkeys_as == 'list': keys.extend([ k + '.' + kk for kk in v.keys_nested() ]) elif subkeys_as == 'dict': keys.append({k: v.keys_nested( subkeys_as=subkeys_as) }) else: keys.append(k) return keys
[docs] def union( self, other, allow_override=False, allow_replacement=False, allow_subdict_override_with_none=False): """ Merges the AttrDict in-place with the passed ``other`` AttrDict. Keys in ``other`` take precedence, and nested keys are properly handled. If ``allow_override`` is False, a KeyError is raised if other tries to redefine an already defined key. If ``allow_replacement``, allow "_REPLACE_" key to replace an entire sub-dict. If ``allow_subdict_override_with_none`` is False (default), a key of the form ``this.that: None`` in other will be ignored if subdicts exist in self like ``this.that.foo: 1``, rather than wiping them. """ self_keys = self.keys_nested() other_keys = other.keys_nested() if allow_replacement: WIPE_KEY = '_REPLACE_' override_keys = [k for k in other_keys if WIPE_KEY not in k] wipe_keys = [k.split('.' + WIPE_KEY)[0] for k in other_keys if WIPE_KEY in k] else: override_keys = other_keys wipe_keys = [] for k in override_keys: if not allow_override and k in self_keys: raise KeyError('Key defined twice: {}'.format(k)) else: other_value = other.get_key(k) # If other value is None, and would overwrite an entire subdict, # we skip it if not (other_value is None and isinstance(self.get_key(k, None), AttrDict)): self.set_key(k, other_value) for k in wipe_keys: self.set_key(k, other.get_key(k + '.' + WIPE_KEY))