"""A module for Figure of Merits"""
from abc import ABC, abstractmethod
import numpy as np
from MDMC.common.decorators import repr_decorator
from MDMC.trajectory_analysis.observables.obs import Observable
@repr_decorator('weight', 'exp_obs', 'MD_obs', 'rescale_factor', 'auto_scale')
class ObservablePair:
    """
    Contains a pair of observables for calculating the FoM

    Checks the validity of observables when they are set.

    Parameters
    ----------
    exp_obs : Observable
        An ``Observable`` with ``Observable.origin == 'experiment'``
    MD_obs : Observable
        An ``Observable`` with ``Observable.origin == 'MD'``
    weight : float
        The relative weight of this pair on a total FoM
    rescale_factor: float, optional
        Factor applied to ``exp_obs`` when calculating the FoM to ensure it is
        on the same scale as ``MD_obs``. Default is `1.`.
    auto_scale: bool, optional
        If `True`, ``rescale_factor`` is set automatically to minimise the FoM
        for each step of the refinement, overriding a user specified value if
        set. Note that this process is purely statistical and does not account
        for physical effects that might impact the scaling. Default is `False`.
    """

    def __init__(self, exp_obs: Observable, MD_obs: Observable, weight: float,
                 rescale_factor: float = 1., auto_scale: bool = False):
        # exp_obs is assigned first: while it is validated MD_obs does not
        # exist yet, so only its origin is checked. Assigning MD_obs then
        # runs the full pairwise comparison against exp_obs.
        self.exp_obs = exp_obs
        self.MD_obs = MD_obs
        self.weight = weight
        self.rescale_factor = rescale_factor
        self.auto_scale = auto_scale

    @property
    def exp_obs(self) -> Observable:
        """
        Get or set the experimental ``Observable``

        Setting the ``Observable`` checks its validity (see ``validate_obs``).

        Returns
        -------
        Observable
            The experimental ``Observable``
        """
        return self._exp_obs

    @exp_obs.setter
    def exp_obs(self, exp_obs: Observable) -> None:
        self.validate_obs(exp_obs, 'experiment')
        self._exp_obs = exp_obs

    @property
    def MD_obs(self) -> Observable:
        """
        Get or set the MD ``Observable``

        Setting the ``Observable`` checks its validity (see ``validate_obs``).

        Returns
        -------
        Observable
            The MD ``Observable``
        """
        return self._MD_obs

    @MD_obs.setter
    def MD_obs(self, MD_obs: Observable) -> None:
        self.validate_obs(MD_obs, 'MD')
        self._MD_obs = MD_obs

    @property
    def weight(self) -> float:
        """
        Get or set the relative weight of this pair on a total FoM

        Returns
        -------
        float
            The relative weight

        Raises
        ------
        TypeError
            If ``weight`` is set with a value that cannot be converted to
            `float`
        AssertionError
            If ``weight`` is not a finite positive number
        """
        return self._weight

    @weight.setter
    def weight(self, weight: float) -> None:
        try:
            weight = float(weight)
        except (TypeError, ValueError) as error:
            # float() raises ValueError for e.g. 'abc' but TypeError for
            # e.g. None; report both with the same message.
            raise TypeError('weight must be a float') from error
        self.validate_weight(weight)
        self._weight = weight

    @property
    def n_averages(self) -> 'dict[str, int]':
        """
        The number of separate, complete dependent variable calculations we
        have been able to perform for the ``Observable``

        Returns
        -------
        dict
            Each key is a key of ``MD_obs.dependent_variables``, and the value
            is ``len`` of the corresponding entry
        """
        return {key: len(value)
                for key, value in self.MD_obs.dependent_variables.items()}

    @staticmethod
    def _require(condition: bool, message: str) -> None:
        """
        Raise ``AssertionError`` with ``message`` if ``condition`` is falsy

        Used instead of the ``assert`` statement so that validation still
        happens when Python runs with ``-O`` (which strips ``assert``).
        """
        if not condition:
            raise AssertionError(message)

    def validate_obs(self, obs: Observable, origin: str) -> None:
        """
        Performs checks to test the validity of an ``Observable``

        Tests that the ``Observable.origin`` is as expected. If the
        ``ObservablePair`` has another ``Observable`` (i.e. the other
        ``origin``), then this tests that the ``independent_variables`` are
        identical, the ``dependent_variables`` have the same keys and shapes,
        the ``errors`` have the same keys and shapes, and that the
        ``Observable`` objects are of the same type.

        Parameters
        ----------
        obs : Observable
            The ``Observable`` to validate
        origin : str
            The ``Observable.origin`` (``'experiment'`` or ``'MD'``)

        Raises
        ------
        AssertionError
            If the ``Observable.origin`` is not the same as the ``origin``
            parameter
        AssertionError
            If ``Observable`` does not have identical (same keys, shapes and
            values) ``independent_variables`` to any ``Observable`` of the
            other ``Observable.origin`` that already exists in the
            ``ObservablePair``
        AssertionError
            If ``Observable`` does not have ``dependent_variables`` with the
            same keys and shapes as any ``Observable`` of the other
            ``Observable.origin`` that already exists in the
            ``ObservablePair``
        AssertionError
            If ``Observable`` does not have ``errors`` with the same keys and
            shapes as any ``Observable`` of the other ``Observable.origin``
            that already exists in the ``ObservablePair``
        AssertionError
            If ``Observable`` is not an instance of the type of any
            ``Observable`` of the other ``Observable.origin`` that already
            exists in the ``ObservablePair``
        """
        # Check origin is correct
        self._require(obs.origin == origin,
                      'The observable does not have the correct origin')
        try:
            other_obs = self.exp_obs if obs.origin == 'MD' else self.MD_obs
        except AttributeError:
            # The other observable has not been assigned yet (e.g. while
            # __init__ is setting the first one), so there is nothing to
            # compare against.
            other_obs = None
        if other_obs is None:
            return

        # Independent variables must be identical. np.array_equal also
        # requires equal shapes, so arrays that merely broadcast against
        # each other (e.g. [1] vs [1, 1, 1]) are rejected.
        indep_e_mess = 'Independent variables must be identical'
        self._require(obs.independent_variables.keys()
                      == other_obs.independent_variables.keys(),
                      indep_e_mess)
        for key, values in obs.independent_variables.items():
            self._require(
                np.array_equal(values, other_obs.independent_variables[key]),
                indep_e_mess)

        # Try/except deals with empty observable case (no dependent
        # variables and errors)
        try:
            dep_e_mess = 'Dependent variables must have the same shape'
            self._require(obs.dependent_variables.keys()
                          == other_obs.dependent_variables.keys(),
                          dep_e_mess)
            for key, values in obs.dependent_variables.items():
                self._require(
                    np.shape(values)
                    == np.shape(other_obs.dependent_variables[key]),
                    dep_e_mess)
            err_e_mess = 'Errors must have the same shape'
            self._require(obs.errors.keys() == other_obs.errors.keys(),
                          err_e_mess)
            for key, values in obs.errors.items():
                self._require(
                    np.shape(values) == np.shape(other_obs.errors[key]),
                    err_e_mess)
        except AttributeError:
            pass

        self._require(isinstance(obs, type(other_obs)),
                      'Observables are not of the same type')

    @staticmethod
    def validate_weight(weight: float) -> None:
        """
        Performs checks to test the validity of the ``weight``

        Parameters
        ----------
        weight : float
            The ``weight`` to be validated

        Raises
        ------
        AssertionError
            If the ``weight`` is not positive, is infinite or is `NaN`
        """
        # The chained comparison is False for NaN, so NaN is rejected too.
        if not 0. < weight < float('inf'):
            raise AssertionError('Weight must be a finite positive float')

    def check_types(self) -> None:
        """
        Checks that ``Observable`` objects are of the same type

        Not implemented.
        """
        raise NotImplementedError

    def check_indep_var(self) -> None:
        """
        Checks that ``Observable`` objects have the same
        ``independent_variables`` and that they are finite

        Not implemented.
        """
        raise NotImplementedError

    def check_dep_var(self) -> None:
        """
        Checks that ``Observable`` objects have the same
        ``dependent_variables`` and that they are finite

        Not implemented.
        """
        raise NotImplementedError

    def check_errors(self) -> None:
        """
        Checks that an ``Observable`` has errors on the ``dependent_variable``
        and that these are `float` and not `NaN`

        Not implemented.
        """
        raise NotImplementedError

    def check_origin(self, origin: str) -> None:
        """
        Checks that the ``Observable.origin`` is correct

        Not implemented.

        Parameters
        ----------
        origin : str
            A string consisting of either ``'experiment'`` or ``'MD'``
        """
        raise NotImplementedError

    def calculate_difference(self) -> np.ndarray:
        """
        Assumes a single dependent variable for each ``Observable``

        Returns
        -------
        numpy.ndarray
            An array with the same dimensions as the ``dependent_variables`` of
            the ``exp_obs`` and ``MD_obs``. The array contains the difference
            between the ``dependent_variables`` taking the ``rescale_factor``
            into account (``exp * rescale_factor - MD``).
        """
        # np.array(*values) only works with exactly one dependent variable:
        # a second value would be passed as np.array's ``dtype`` argument.
        diff = (np.array(*self.exp_obs.dependent_variables.values())
                * self.rescale_factor
                - np.array(*self.MD_obs.dependent_variables.values()))
        return diff

    def calculate_errors(self) -> np.ndarray:
        """
        Assumes a single dependent variable error for each ``Observable``

        Returns
        -------
        numpy.ndarray
            An array with the same dimensions as the ``errors`` of the
            ``exp_obs`` and ``MD_obs``. The array contains the combination of
            the ``errors`` in quadrature, taking the ``rescale_factor`` into
            account.
        """
        errors = (self.calculate_exp_errors() ** 2
                  + np.array(*self.MD_obs.errors.values()) ** 2) ** 0.5
        return errors

    def calculate_exp_errors(self) -> np.ndarray:
        """
        Assumes a single dependent variable error for each ``Observable``.

        Calculates only the experimental errors.

        Returns
        -------
        numpy.ndarray
            An array with the same dimensions as the ``errors`` of the
            ``exp_obs``, multiplied by the ``rescale_factor``.
        """
        return np.array(*self.exp_obs.errors.values()) * self.rescale_factor