"""A module for the Parameter and Parameters classes.
Parameter defines the name and value of each force field parameter, and whether
it is fixed, has constraints or is tied.
Parameters inherits from dict and implements a number of methods for filtering
a collection of Parameter objects.
"""
from __future__ import annotations
import ast
import logging
import operator
import re
import warnings
import weakref
from collections.abc import Iterable
from itertools import chain, count
from typing import TYPE_CHECKING, Any, Callable, NoReturn, Optional, Union
import numpy as np
from MDMC.common.decorators import repr_decorator, unit_decorator, unit_decorator_getter
if TYPE_CHECKING:
from MDMC.MD.interactions import Interaction
[docs]
@repr_decorator('ID', 'type', 'value', 'unit', 'fixed', 'constraints',
                'interactions_name', 'functions_name', 'tied')
class Parameter:
    """
    A force field parameter which can be fixed or constrained within limits

    The value of a parameter cannot be set if ``fixed==True``.

    Parameters
    ----------
    value : float
        The value of the parameter.
    name : str
        The name of the parameter.
    fixed : bool
        Whether or not the value can be changed.
    constraints : tuple
        The closed range of the ``Parameter.value``, (lower, upper).
        ``constraints`` must have the same units as ``value``.
    **settings
        ``unit`` (`str`)
            The unit. If this is not provided then the unit will be taken from
            the object passed as ``value``.

    Attributes
    ----------
    ID : int
        A unique identifier drawn from a class-wide counter.
    name : str
        The ``name`` passed to the constructor with the ``ID`` appended,
        e.g. ``'sigma (#3)'``.
    type : str
        The ``name`` exactly as passed to the constructor.
    """

    # each Parameter has a unique ID, so they can be distinguished
    _ID_generator = count(start=1, step=1)

    def __init__(self, value, name, fixed=False, constraints=None, **settings):
        self.ID = self._generate_ID()
        self.name = name + f" (#{self.ID})"
        self.type = name
        self.unit = settings.get('unit', getattr(value, 'unit', None))
        # The tie state is created first because the ``value`` getter and
        # setter (used just below) consult ``tied``.
        self._tie = None
        self._tie_parameter = None
        self.constraints = constraints
        # ``fixed`` is assigned only after the initial value has been stored,
        # otherwise a fixed Parameter could never receive its value.
        self.value = value
        self.fixed = fixed
        self.interactions_name = None
        self.functions_name = None
        self._interactions = []

    @property
    def value(self) -> float:
        """
        Get or set the value of the ``Parameter``

        The value will not be changed if it is ``fixed`` or ``tied``, or if it
        is set outside the bounds of ``constraints``.

        Returns
        -------
        float
            The value of the ``Parameter``. If the ``Parameter`` is ``tied``,
            this is the value computed from the tie expression.

        Warns
        -----
        warnings.warn
            If the ``Parameter`` is ``fixed``.
        warnings.warn
            If the ``Parameter`` is ``tied``.

        Raises
        ------
        ValueError
            If a new value lies outside ``constraints``.
        """
        # Evaluate the tie expression once and reuse the result
        tie_value = self._tie_or_none()
        if tie_value is not None:
            return tie_value
        return self._value

    @value.setter
    @unit_decorator(unit=None)
    def value(self, value: float) -> None:
        # ``fixed`` does not exist yet while ``__init__`` stores the first value
        if getattr(self, 'fixed', False):
            warnings.warn("Unable to change fixed parameter")
        elif self.tied:
            warnings.warn("Unable to change tied parameter")
        else:
            if self.constraints is not None:
                self.validate_value(value, self.constraints)
            self._value = value

    @property
    @unit_decorator_getter(unit=None)
    def constraints(self) -> tuple:
        """
        Get or set the constraint of the ``Parameter``

        Returns
        -------
        tuple
            The closed range of the ``Parameter.value``

        Raises
        ------
        ValueError
            If the constraint is not a 2-element ``(lower, upper)`` sequence
            with ``lower <= upper``, or if the current value lies outside the
            new constraints.
        """
        return self._constraints

    @constraints.setter
    def constraints(self, constraints: tuple) -> None:
        if constraints is not None:
            # Must be exactly (lower, upper) with lower <= upper
            if len(constraints) != 2 or constraints[0] > constraints[1]:
                raise ValueError("Constaints must be (lower, upper)")
            # ``value`` does not exist yet while ``__init__`` sets constraints
            if hasattr(self, 'value'):
                self.validate_value(self.value, constraints)
        self._constraints = constraints

    @property
    def interactions(self) -> list:
        """
        Get or append to the parent ``Interaction`` objects for this
        ``Parameter``

        Assigning to this property appends an ``Interaction``; it does not
        replace the existing ones. Only weak references are stored, so
        ``Interaction`` objects that no longer exist are left out of the
        returned list.

        Returns
        -------
        list
            All parent ``Interaction`` objects that are still alive

        Raises
        ------
        ValueError
            If an added interaction name is not consistent with existing
            interaction names
        ValueError
            If an added ``Interaction`` has a function name not consistent with
            the function names of an existing ``Interaction``
        """
        live_interactions = []
        for reference in self._interactions:
            interaction = reference()
            if interaction is not None:
                live_interactions.append(interaction)
        return live_interactions

    @interactions.setter
    def interactions(self, interaction: 'Interaction') -> None:
        # Test if interaction is of the same type as any interactions already
        # stored
        if self.interactions_name:
            if interaction.name != self.interactions_name:
                raise ValueError('Added interaction name is not consistent with'
                                 ' existing interaction names')
            if interaction.function_name != self.functions_name:
                raise ValueError('Added function name is not consistent with'
                                 ' existing function names')
        else:
            self.interactions_name = interaction.name
            self.functions_name = interaction.function_name
        self._interactions.append(weakref.ref(interaction))

    @property
    def tie(self) -> Union[float, None]:
        """
        Get the ``value`` of the ``Parameter`` that this ``Parameter`` is tied
        to, transformed by the tie expression

        Returns
        -------
        float or None
            The result of the tie expression, or `None` if this ``Parameter``
            has no tie
        """
        # pylint: disable=eval-used
        # eval use is generally bad
        # but the safe alternative (ast.literal_eval) creates malformed nodes.
        # SECURITY NOTE: the expression passed to ``set_tie`` is executed
        # verbatim, so it must never come from untrusted input.
        if self._tie is None:
            return None
        return eval(compile(self._tie, '', 'eval'))

    def _tie_or_none(self) -> Union[float, None]:
        """Return ``tie``, or `None` if it is unset or cannot be evaluated."""
        # An AttributeError is treated as "not tied". This covers the
        # tied-to Parameter having been garbage collected, because the weak
        # reference then returns None and ``None.value`` fails.
        try:
            return self.tie
        except AttributeError:
            return None

    @property
    def tied(self) -> bool:
        """
        Get whether this ``Parameter`` is tied

        Returns
        -------
        bool
            `True` if this ``Parameter`` is tied to another ``Parameter``, else
            `False`
        """
        return self._tie_or_none() is not None

    def set_tie(self, parameter: Parameter, expr: str) -> None:
        """
        This ``ties`` the ``Parameter.value`` to the ``value`` of another
        ``Parameter``

        Parameters
        ----------
        parameter : Parameter
            The ``Parameter`` to tie to
        expr : str
            A mathematical expression appended to ``parameter.value``

        Raises
        ------
        ValueError
            If ``parameter`` is this ``Parameter``, or is (directly or
            indirectly) tied to it, because evaluating the value would then
            recurse forever
        SyntaxError
            If ``expr`` does not form a valid expression

        Examples
        --------
        To set ``p2.value`` to ``p1.value * 2``::

            >>> p2.set_tie(p1, "* 2")
        """
        # Walk the existing chain of ties from ``parameter`` to detect cycles;
        # ``seen`` guards against cycles that do not involve ``self``.
        linked, seen = parameter, set()
        while linked is not None and id(linked) not in seen:
            if linked is self:
                raise ValueError("Cannot tie a Parameter to itself, directly "
                                 "or through a chain of tied Parameters")
            seen.add(id(linked))
            reference = getattr(linked, '_tie_parameter', None)
            has_tie = getattr(linked, '_tie', None) is not None
            linked = reference() if has_tie and reference is not None else None
        # Parse before mutating state so a SyntaxError leaves the tie unchanged
        tie = ast.parse('self._tie_parameter().value' + expr, mode='eval')
        self._tie_parameter = weakref.ref(parameter)
        self._tie = tie

    @classmethod
    def _generate_ID(cls) -> int:
        """Generates a unique ID for the Parameter that has just been created."""
        return next(cls._ID_generator)

    def __str__(self) -> str:
        if self.fixed:
            condition = 'Fixed '
        elif self.tied:
            condition = 'Tied '
        elif self.constraints is not None:
            condition = 'Constrained '
        else:
            condition = ''
        function = f'{self.functions_name} ' if self.functions_name else ''
        # ``value`` (not ``_value``) so tied Parameters show their effective value
        return f'{condition}{self.value} {function}{self.name}'

    def __getitem__(self, key):
        # Dict-style attribute access, e.g. parameter['value']
        return getattr(self, key)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    @staticmethod
    def validate_value(value: float, constraints: tuple) -> None:
        """
        Validates the ``Parameter.value`` by testing if it is within the
        ``constraints``

        Parameters
        ----------
        value : float
            The value of the ``Parameter``
        constraints : tuple
            A 2-tuple of the lower and upper constraints respectively.

        Raises
        ------
        ValueError
            If the ``value`` is not within the ``constraints``
        """
        if value < constraints[0] or value > constraints[1]:
            raise ValueError(f"Value must be within constraints, "
                             f"value is: {value}, constraints are: {constraints}")

    # comparison operator so parameters are always in the same order on MMC refinement headings
    def __lt__(self, other):
        return self.name < other.name
[docs]
class Parameters(dict):
    """
    A `dict-like` object where every element is a ``Parameter`` indexed by name,
    which contains a number of helper methods for filtering.

    Although ``Parameters`` is a `dict`, it should be treated like a `list` when writing to it;
    i.e. initialise it using a `list` and use `append` to add to it. These parameters can
    then be accessed by their name as a key.
    In short: Parameters writes like a list and reads like a dict.

    Parameters
    ----------
    init_parameters: ``Parameter`` or `list` of ``Parameter``s, optional, default None
        The initial ``Parameter`` objects that the ``Parameters`` object contains.
        A `tuple` of ``Parameter``s is also accepted.

    Attributes
    ----------
    as_array: np.ndarray
        A numpy array of the ``Parameter`` objects stored in this object,
        sorted alphabetically by name.
    """

    def __init__(self, init_parameters: Optional[Union[Parameter, 'list[Parameter]']] = None):
        super().__init__()
        if init_parameters is not None:
            # ``append`` validates the input via ``_check_input``
            self.append(init_parameters)

    def __setitem__(self, key: str, value: Parameter) -> NoReturn:
        # disable this method to ensure parameter keys are always the parameter name
        raise TypeError("Parameters should be added to using Parameters.append(parameter), "
                        "with a parameter or list of parameters as your argument.")

    def __reduce__(self):
        # The default reduction of a dict subclass restores items through
        # ``__setitem__``, which is disabled above. Without this, copy.copy,
        # copy.deepcopy and pickle raise TypeError. Rebuild via the
        # constructor instead.
        return (self.__class__, (list(self.values()),))

    def __getitem__(self, key: str) -> Union[Parameter, 'list[Parameter]']:
        """
        Get a ``Parameter`` by its full name, e.g. ``'sigma (#3)'``

        A name without an ID, e.g. ``'sigma'``, is also accepted. If exactly
        one ``Parameter`` has that name it is returned. If several do, a list
        of them sorted by ``ID`` is returned and a warning is issued.

        Raises
        ------
        KeyError
            If no ``Parameter`` matches ``key``
        """
        try:
            return super().__getitem__(key)
        except KeyError as error:
            # The key is escaped so regex metacharacters in it match literally,
            # and fullmatch ensures the whole stored name has the form
            # '<key> (#<ID>)'.
            pattern = re.compile(rf"{re.escape(str(key))} \(#[0-9]+\)")
            matching_parameters = [parameter for name, parameter in self.items()
                                   if pattern.fullmatch(name)]
            if not matching_parameters:
                raise KeyError(key) from error
            if len(matching_parameters) == 1:
                return matching_parameters[0]
            warnings.warn("Calling a parameter name with no ID returns a "
                          "list of all parameters with that name; "
                          "this may create inconsistent behaviour!",
                          stacklevel=2)
            return sorted(matching_parameters, key=lambda parameter: parameter.ID)

    def append(self, parameters: Union['list[Parameter]', Parameter]) -> None:
        """
        Appends a ``Parameter`` or list of ``Parameter``s to the dict,
        with the parameter name as its key.

        Parameters
        ----------
        parameters: ``Parameter`` or `list` of ``Parameter``s
            The parameter(s) to be added to the dict.

        Raises
        ------
        TypeError
            If ``parameters`` is not a ``Parameter`` or a list/tuple of them.
        """
        for parameter in self._check_input(parameters):
            super().__setitem__(parameter.name, parameter)

    @property
    def as_array(self) -> np.ndarray:
        """
        The parameters in the object as a sorted numpy array.

        Returns
        -------
        np.ndarray
            An array of the ``Parameter`` objects (not their values), sorted
            by ``name`` using plain string ordering.
        """
        return np.array(sorted(self.values(), key=lambda parameter: parameter.name))

    def filter(self, predicate: Callable[[Parameter], bool]) -> Parameters:
        """
        Filters using a predicate

        Parameters
        ----------
        predicate : function
            A function that returns a `bool` which takes a ``Parameter`` as an
            argument.

        Returns
        -------
        Parameters
            The ``Parameter`` objects which meet the condition of the predicate
        """
        return Parameters([parameter for parameter in self.values()
                           if predicate(parameter)])

    def filter_name(self, name: str) -> Parameters:
        """
        Filters by ``name``

        Parameters
        ----------
        name : str
            A substring of the ``name`` of the ``Parameter`` objects to return.

        Returns
        -------
        Parameters
            The ``Parameter`` objects whose ``name`` contains ``name``
        """
        return self.filter(lambda parameter: name in parameter.name)

    def filter_value(self, comparison: str, value: float) -> Parameters:
        """
        Filters by ``value``

        Parameters
        ----------
        comparison : str
            A `str` representing a comparison operator, ``'>'``, ``'<'``,
            ``'>='``, ``'<='``, ``'=='``, ``'!='``.
        value : float
            A `float` with which ``Parameter`` values are compared, using the
            ``comparison`` operator.

        Returns
        -------
        Parameters
            The ``Parameter`` objects which return a `True` when their values
            are compared with ``value`` using the ``comparison`` operator

        Raises
        ------
        KeyError
            If ``comparison`` is not one of the supported operators
        """
        ops = {'>': operator.gt,
               '<': operator.lt,
               '>=': operator.ge,
               '<=': operator.le,
               '==': operator.eq,
               '!=': operator.ne}
        # Look the operator up eagerly so an invalid comparison is reported
        # even when there are no parameters to compare
        try:
            compare = ops[comparison]
        except KeyError:
            raise KeyError(f"Unknown comparison {comparison!r}; expected one of "
                           f"{', '.join(repr(op) for op in ops)}") from None
        return self.filter(lambda parameter: compare(parameter.value, value))

    def filter_interaction(self, interaction_name: str) -> Parameters:
        """
        Filters based on the name of the ``Interaction`` of each ``Parameter``

        Parameters
        ----------
        interaction_name : str
            The name of the ``Interaction`` of ``Parameter`` objects to return,
            for example ``'Bond'``.

        Returns
        -------
        Parameters
            The ``Parameter`` objects which have an ``Interaction`` with the
            specified ``interaction_name``
        """
        return self.filter(lambda parameter: parameter.interactions_name == interaction_name)

    def filter_function(self, function_name: str) -> Parameters:
        """
        Filters based on the name of the ``InteractionFunction`` of each
        ``Parameter``

        Parameters
        ----------
        function_name : str
            The name of the ``InteractionFunction`` of ``Parameter`` objects to
            return, for example ``'LennardJones'`` or ``'HarmonicPotential'``.

        Returns
        -------
        Parameters
            The ``Parameter`` objects which have a ``function`` with the
            specified ``function_name``
        """
        return self.filter(lambda parameter: parameter.functions_name == function_name)

    def filter_atom_attribute(self, attribute: str, value: Union[str, float]) -> Parameters:
        """
        Filters based on the attribute of ``Atom`` objects which have each
        ``Parameter`` applied to them

        Parameters
        ----------
        attribute : str
            An attribute of an ``Atom``. Attributes to match to must be either
            `float` or str.
        value : str, float
            The value of the ``Atom`` ``attribute``.

        Returns
        -------
        Parameters
            The ``Parameter`` objects which are applied to an ``Atom`` object
            which has the specified ``value`` of the specified ``attribute``
        """
        def flatten(iterable):
            """Yield the leaves of arbitrarily nested iterables."""
            for element in iterable:
                # A str is Iterable and each of its characters is again a str,
                # so recursing into strings would never terminate
                if isinstance(element, Iterable) and not isinstance(element, (str, bytes)):
                    yield from flatten(element)
                else:
                    yield element

        # ``in`` on a generator short-circuits at the first match
        return self.filter(lambda parameter: value in (
            getattr(atom, attribute)
            for interaction in parameter.interactions
            for atom in flatten(interaction.atoms)))

    def filter_structure(self, structure_name: str) -> Parameters:
        """
        Filters based on the name of the ``Structure`` to which each
        ``Parameter`` applies

        Parameters
        ----------
        structure_name : str
            The name of a ``Structure``.

        Returns
        -------
        Parameters
            The ``Parameter`` objects which are applied to a ``Structure``
            which has the specified ``structure_name``
        """
        def applies_to_structure(parameter) -> bool:
            """
            Checks whether any atom of the parameter's interactions, or any of
            its ancestors up to the top level structure, has ``structure_name``
            """
            for interaction in parameter.interactions:
                for atom in chain.from_iterable(interaction.atoms):
                    # Walk up the parent chain until the top level structure
                    structure = atom
                    while True:
                        if structure.name == structure_name:
                            return True
                        if structure.top_level_structure == structure:
                            break
                        structure = structure.parent
            return False

        return self.filter(applies_to_structure)

    def log_parameters(self) -> None:
        """Logs all Parameters by ID"""
        logger = logging.getLogger(__name__)
        # NOTE(review): relies on Parameter exposing a ``repr()`` method
        # (presumably added by ``repr_decorator``) — confirm.
        msg = "List of all parameters with ID: \n" + "".join(
            f"{parameter.repr()}" for parameter in self.values())
        logger.info(msg)
        print("Details on which Parameter corresponds to each ID have been written to the log.")

    @staticmethod
    def _check_input(x: Any) -> 'list[Parameter]':
        """
        Ensures that input to a Parameters object is in the correct form.

        Raises an error if the input is not either a Parameter
        (in which case it is turned into a list, so it can be fed into an iteration loop)
        or a list/tuple of Parameters.

        Parameters
        ----------
        x: Any
            The object to be sanitised.

        Returns
        -------
        list[Parameter]
            A list of the Parameters in x if x is a list or tuple of
            Parameters, or [x] if x is a Parameter (so it can be iterated over)

        Raises
        ------
        TypeError
            If the object is not a Parameter or list of Parameters (including
            if it is a list that contains a non-Parameter object)
        """
        if isinstance(x, Parameter):
            return [x]
        if isinstance(x, (list, tuple)) and all(isinstance(i, Parameter) for i in x):
            return list(x)
        raise TypeError("Input into a Parameters object must be either a Parameter "
                        "or a list of Parameters.")