"""This module promotes reward components and termination conditions as
first-class objects. Those building blocks that can be plugged onto an existing
pipeline by composition to keep everything modular, from the task definition to
the low-level observers and controllers.
This modular approach allows for standardization of usual metrics. Overall, it
greatly reduces code duplication and bugs.
"""
from abc import ABC, abstractmethod
from enum import IntEnum
from typing import Tuple, Sequence, Callable, Union, Optional, Generic, TypeVar
import numpy as np
from .interfaces import InfoType, InterfaceJiminyEnv
from .quantities import QuantityCreator
ValueT = TypeVar('ValueT')
Number = Union[float, int, bool, complex]
ArrayOrScalar = Union[np.ndarray, np.number, Number]
ArrayLikeOrScalar = Union[ArrayOrScalar, Sequence[Union[Number, np.number]]]
class AbstractReward(ABC):
"""Abstract class from which all reward component must derived.
This goal of the agent is to maximize the expectation of the cumulative sum
of discounted reward over complete episodes. This holds true no matter if
its sign is always negative (aka. reward), always positive (aka. cost) or
indefinite (aka. objective).
Defining cost is allowed by not recommended. Although it encourages the
agent to achieve the task at hands as quickly as possible if success is the
only termination condition, it has the side-effect to give the opportunity
to the agent to maximize the return by killing itself whenever this is an
option, which is rarely the desired behavior. No restriction is enforced as
it may be limiting in some relevant cases, so it is up to the user to make
sure that its design makes sense overall.
"""
def __init__(self,
env: InterfaceJiminyEnv,
name: str) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param name: Desired name of the reward.
"""
self.env = env
self._name = name
@property
def name(self) -> str:
"""Name uniquely identifying a given reward component.
This name will be used as key for storing reward-specific monitoring
and debugging information in 'info' if the key does not already exist,
otherwise an exception will be raised.
"""
return self._name
@property
def is_terminal(self) -> Optional[bool]:
"""Whether the reward is terminal, non-terminal, or indefinite.
A reward is said to be "terminal" if only evaluated for the terminal
state of the MDP, "non-terminal" if evaluated for all states except the
terminal one, or indefinite if systematically evaluated no matter what.
All rewards are supposed to be indefinite unless stated otherwise by
overloading this method. The responsibility of evaluating the reward
only when necessary is delegated to `compute`. This allows for complex
evaluation logics beyond terminal or non-terminal without restriction.
.. note::
Truncation is not consider the same as termination. The reward to
not be evaluated in such a case, which means that it will never be
for such episodes.
"""
return None
@property
@abstractmethod
def is_normalized(self) -> bool:
"""Whether the reward is guaranteed to be normalized, ie it is in range
[0.0, 1.0].
"""
@abstractmethod
def compute(self, terminated: bool, info: InfoType) -> Optional[float]:
"""Compute the reward.
.. note::
Return value can be set to `None` to indicate that evaluation was
skipped for some reason, and therefore the reward must not be taken
into account when computing the total reward. This is useful when
the reward is undefined or simply inappropriate in the current
state of the environment.
.. warning::
It is the responsibility of the practitioner overloading this
method to honor flags 'is_terminal' (if not indefinite) and
'is_normalized'. Failing this, an exception will be raised.
:param terminated: Whether the episode has reached a terminal state of
the MDP at the current step.
:param info: Dictionary of extra information for monitoring. It will be
updated in-place for storing current value of the reward
in 'info' if it was truly evaluated.
:returns: Scalar value if the reward was evaluated, `None` otherwise.
"""
def __call__(self, terminated: bool, info: InfoType) -> float:
"""Return the reward associated with the current environment step.
For the corresponding MDP to be stationary, the computation of the
reward is supposed to involve only the transition from previous to
current state of the environment under the ongoing action.
.. note::
This method is a lightweight wrapper around `compute` to skip
evaluation depending on whether the current state and the reward
are terminal. If the reward was truly evaluated, then 'info' is
updated to store either custom debugging information if any or its
value otherwise. If the reward is not evaluated, then 'info' is
left as-is and 0.0 is returned.
.. warning::
This method is not meant to be overloaded.
:param terminated: Whether the episode has reached a terminal state of
the MDP at the current step.
:param info: Dictionary of extra information for monitoring. It will be
updated in-place for storing current value of the reward
in 'info' if it was truly evaluated.
"""
# Evaluate the reward and store extra information
reward_info: InfoType = {}
value = self.compute(terminated, reward_info)
# Early return if None, which means that the reward was not evaluated
if value is None:
return 0.0
# Make sure that terminal flag is honored
if self.is_terminal is not None and self.is_terminal ^ terminated:
raise ValueError("Flag 'is_terminal' not honored.")
# Make sure that the reward is scalar
assert np.ndim(value) == 0
# Make sure that the reward is normalized
if self.is_normalized and (value < 0.0 or value > 1.0):
raise ValueError(
"Reward not normalized in range [0.0, 1.0] as it ought to be.")
# Store its value as info
if self.name in info.keys():
raise KeyError(
f"Key '{self.name}' already reserved in 'info'. Impossible to "
"store value of reward component.")
if reward_info:
info[self.name] = reward_info
else:
info[self.name] = value
# Returning the reward
return value
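# --- Illustrative sketch (not part of the public API) ---
# A minimal concrete reward component, assuming a hypothetical environment
# attribute `robot_state.v` gathering the joint velocities. It illustrates the
# contract documented above: advertise 'is_terminal' and 'is_normalized', and
# return `None` from `compute` whenever evaluation must be skipped.
class _ExampleVelocityPenalty(AbstractReward):
    """Non-terminal reward decreasing with the norm of the joint velocities."""

    @property
    def is_terminal(self) -> Optional[bool]:
        # Only evaluated for non-terminal states
        return False

    @property
    def is_normalized(self) -> bool:
        # `compute` maps its return value to range [0.0, 1.0]
        return True

    def compute(self, terminated: bool, info: InfoType) -> Optional[float]:
        # Skip evaluation for terminal states, consistently with 'is_terminal'
        if terminated:
            return None
        velocity = self.env.robot_state.v  # assumed attribute of the env
        return float(np.exp(-np.linalg.norm(velocity)))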
class QuantityReward(AbstractReward, Generic[ValueT]):
"""Convenience class making it easy to derive reward components from
generic quantities.
All this class does is apply some user-specified post-processing to the
value of a given multi-variate quantity to return a floating-point scalar
value, possibly normalized between 0.0 and 1.0 if desired.
"""
def __init__(self,
env: InterfaceJiminyEnv,
name: str,
quantity: QuantityCreator[ValueT],
transform_fn: Optional[Callable[[ValueT], float]],
is_normalized: bool,
is_terminal: Optional[bool]) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param name: Desired name of the reward. This name will be used as key
for storing current value of the reward in 'info', and to
add the underlying quantity to the set of quantities already
managed by the environment. As a result, it must be unique,
otherwise an exception will be raised.
:param quantity: Tuple gathering the class of the underlying quantity
to use as reward after some post-processing, plus any
keyword-arguments of its constructor except 'env',
and 'parent'.
:param transform_fn: Transform function responsible for aggregating a
multi-variate quantity as a floating-point scalar
value to maximize. Typical examples are `np.min`,
`np.max`, and `lambda x: np.linalg.norm(x, ord=N)`.
This function is also responsible for rescaling
the transformed quantity to range [0.0, 1.0] if
the reward is advertised as normalized. The Radial
Basis Function (RBF) kernel is the most common
choice to derive a reward to maximize from errors
based on distance metrics (see
`radial_basis_function` for details). Pass `None`
to skip the transform entirely if unnecessary.
:param is_normalized: Whether the reward is guaranteed to be normalized
after applying transform function `transform_fn`.
:param is_terminal: Whether the reward is terminal, non-terminal or
indefinite. A terminal reward will be evaluated at
most once, at the end of each episode for which a
termination condition has been triggered. On the
contrary, a non-terminal reward will be evaluated
systematically except at the end of the episode.
Finally, an indefinite reward will be evaluated
systematically. The value 0.0 is returned and no
'info' will be stored when reward evaluation is
skipped.
"""
# Backup user argument(s)
self._transform_fn = transform_fn
self._is_normalized = is_normalized
self._is_terminal = is_terminal
# Call base implementation
super().__init__(env, name)
# Add quantity to the set of quantities managed by the environment
self.env.quantities[self.name] = quantity
# Keep track of the underlying quantity
self.data = self.env.quantities.registry[self.name]
def __del__(self) -> None:
try:
del self.env.quantities[self.name]
except Exception: # pylint: disable=broad-except
# This method must not fail under any circumstances
pass
@property
def is_terminal(self) -> Optional[bool]:
return self._is_terminal
@property
def is_normalized(self) -> bool:
return self._is_normalized
def compute(self, terminated: bool, info: InfoType) -> Optional[float]:
"""Compute the reward if necessary depending on whether the reward and
state are terminal. If so, then first evaluate the underlying quantity,
next apply post-processing if requested.
.. warning::
This method is not meant to be overloaded.
:returns: Scalar value if the reward was evaluated, `None` otherwise.
"""
# Early return depending on whether the reward and state are terminal
if self.is_terminal is not None and self.is_terminal ^ terminated:
return None
# Evaluate raw quantity
value = self.data.get()
# Early return if quantity is None
if value is None:
return None
# Apply some post-processing if requested
if self._transform_fn is not None:
value = self._transform_fn(value)
# Return the reward
return value
QuantityReward.name.__doc__ = \
"""Name uniquely identifying every reward.
It will be used as key not only for storing reward-specific monitoring
and debugging information in 'info', but also for adding the underlying
quantity to the ones already managed by the environment.
"""
class MixtureReward(AbstractReward):
"""Base class for aggregating multiple independent reward components as a
single one.
"""
components: Tuple[AbstractReward, ...]
"""List of all the reward components that must be aggregated together.
"""
def __init__(self,
env: InterfaceJiminyEnv,
name: str,
components: Sequence[AbstractReward],
reduce_fn: Callable[
[Tuple[Optional[float], ...]], Optional[float]],
is_normalized: bool) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param name: Desired name of the total reward.
:param components: Sequence of reward components to aggregate.
:param reduce_fn: Transform function responsible for aggregating all
the reward components that were evaluated. Typical
examples are cumulative product and weighted sum.
:param is_normalized: Whether the reward is guaranteed to be normalized
after applying reduction function `reduce_fn`.
"""
# Make sure that at least one reward component has been specified
if not components:
raise ValueError(
"At least one reward component must be specified.")
# Make sure that all reward components share the same environment
for reward in components:
if env is not reward.env:
raise ValueError(
"All reward components must share the same environment.")
# Backup some user argument(s)
self.components = tuple(components)
self._reduce_fn = reduce_fn
self._is_normalized = is_normalized
# Call base implementation
super().__init__(env, name)
# Determine whether the reward mixture is terminal
is_terminal = {reward.is_terminal for reward in self.components}
self._is_terminal: Optional[bool] = None
if len(is_terminal) == 1:
self._is_terminal = next(iter(is_terminal))
@property
def is_terminal(self) -> Optional[bool]:
"""Whether the reward is terminal, ie only evaluated at the end of an
episode if a termination condition has been triggered.
The cumulative reward is considered terminal if and only if all its
individual reward components are terminal.
"""
return self._is_terminal
@property
def is_normalized(self) -> bool:
return self._is_normalized
def compute(self, terminated: bool, info: InfoType) -> Optional[float]:
"""Evaluate each individual reward component for the current state of
the environment, then aggregate them into one.
"""
# Early return depending on whether the reward and state are terminal
if self.is_terminal is not None and self.is_terminal ^ terminated:
return None
# Compute all reward components
values = []
for reward in self.components:
# Evaluate reward
reward_info: InfoType = {}
value: Optional[float] = reward(terminated, reward_info)
# Clear reward value if the reward was never truly evaluated
if not reward_info:
value = None
# Append reward value and information
info.update(reward_info)
values.append(value)
# Aggregate all reward components in one
reward_total = self._reduce_fn(tuple(values))
return reward_total
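# --- Illustrative sketch (not part of the public API) ---
# One possible reduction function for `MixtureReward`: a weighted sum over the
# components that were actually evaluated, with weights renormalized so that
# the mixture stays within [0.0, 1.0] whenever all components are normalized.
# This is only one convention among many (cumulative product is another).
def _example_weighted_sum(weights: Sequence[float]) -> Callable[
        [Tuple[Optional[float], ...]], Optional[float]]:
    def reduce_fn(values: Tuple[Optional[float], ...]) -> Optional[float]:
        total, total_weight = 0.0, 0.0
        for weight, value in zip(weights, values):
            # Components that were skipped (`None`) are simply ignored
            if value is not None:
                total += weight * value
                total_weight += weight
        # Return `None` if no component was evaluated at all this step
        if total_weight == 0.0:
            return None
        return total / total_weight
    return reduce_fn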
class EpisodeState(IntEnum):
"""Specify the current state of the ongoing episode.
"""
CONTINUED = 0
"""No termination condition has been triggered this step.
"""
TERMINATED = 1
"""The terminal state has been reached.
"""
TRUNCATED = 2
"""A truncation condition has been triggered.
"""
class AbstractTerminationCondition(ABC):
"""Abstract class from which all termination conditions must derived.
Request the ongoing episode to stop immediately as soon as a termination
condition is triggered.
There are two cases: truncating the episode or reaching the terminal state.
In the former case, the agent is instructed to stop collecting samples from
the ongoing episode and move to the next one, without considering this as a
failure. As such, the reward-to-go that has not been observed will be
estimated via a value function estimator. This is already what happens
when collecting sample batches in the infinite horizon RL framework, except
that the episode is not resumed to collect the rest of the episode in the
following sample batched. In the case of a termination condition, the agent
is just as much instructed to move to the next episode, but also to
consider that it was an actual failure. This means that, unlike truncation
conditions, the reward-to-go is known to be exactly zero. This is usually
dramatic for the agent in the perspective of an infinite horizon reward,
even more as the maximum discounted reward grows larger as the discount
factor gets closer to one. As a result, the agent will avoid at all cost
triggering terminal conditions, to the point of becoming risk averse by
taking extra security margins lowering the average reward if necessary.
"""
def __init__(self,
env: InterfaceJiminyEnv,
name: str,
grace_period: float = 0.0,
*,
is_truncation: bool = False,
is_training_only: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param name: Desired name of the termination condition. This name will
be used as key for storing the current episode state from
the perspective of this specific condition in 'info', and
to add the underlying quantity to the set of quantities
already managed by the environment. As a result, it must
be unique, otherwise an exception will be raised.
:param grace_period: Grace period effective only at the very beginning
of the episode, during which the latter is bound
to continue whatever happens.
Optional: 0.0 by default.
:param is_truncation: Whether the episode should be considered
terminated or truncated whenever the termination
condition is triggered.
Optional: False by default.
:param is_training_only: Whether the termination condition should be
completely bypassed if the environment is in
evaluation mode.
Optional: False by default.
"""
self.env = env
self._name = name
self.grace_period = grace_period
self.is_truncation = is_truncation
self.is_training_only = is_training_only
@property
def name(self) -> str:
"""Name uniquely identifying a given termination condition.
This name will be used as key for storing termination
condition-specific monitoring information in 'info' if the key does not
already exist, otherwise an exception will be raised.
"""
return self._name
@abstractmethod
def compute(self, info: InfoType) -> bool:
"""Evaluate the termination condition at hands.
:param info: Dictionary of extra information for monitoring. It will be
updated in-place for storing terminated and truncated
flags in 'info' as a tri-state `EpisodeState` value.
"""
def __call__(self, info: InfoType) -> Tuple[bool, bool]:
"""Return whether the termination condition has been triggered.
For the corresponding MDP to be stationary, the condition to trigger
termination is supposed to involve only the transition from previous to
current state of the environment under the ongoing action.
.. note::
This method is a lightweight wrapper around `compute` to return two
boolean flags 'terminated', 'truncated' complying with Gym API.
'info' will be updated to store either custom debug information if
any, or a tri-state `EpisodeState` value otherwise.
.. warning::
This method is not meant to be overloaded.
:param info: Dictionary of extra information for monitoring. It will be
updated in-place for storing terminated and truncated
flags in 'info' as a tri-state `EpisodeState` value.
:returns: terminated and truncated flags.
"""
# Skip termination condition in eval mode or during grace period
termination_info: InfoType = {}
if (self.is_training_only and not self.env.is_training) or (
self.env.stepper_state.t < self.grace_period):
# Always continue
is_terminated, is_truncated = False, False
else:
# Evaluate the termination condition and store extra information
is_done = self.compute(termination_info)
is_terminated = is_done and not self.is_truncation
is_truncated = is_done and self.is_truncation
# Store episode state as info
if self.name in info.keys():
raise KeyError(
f"Key '{self.name}' already reserved in 'info'. Impossible to "
"store value of termination condition.")
if termination_info:
info[self.name] = termination_info
else:
if is_terminated:
episode_state = EpisodeState.TERMINATED
elif is_truncated:
episode_state = EpisodeState.TRUNCATED
else:
episode_state = EpisodeState.CONTINUED
info[self.name] = episode_state
# Returning terminated and truncated flags
return is_terminated, is_truncated
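# --- Illustrative sketch (not part of the public API) ---
# A minimal concrete termination condition, assuming a hypothetical environment
# attribute `robot_state.q` whose third entry is the height of the floating
# base. Since 'is_truncation' defaults to False, triggering it marks the
# episode as terminated (an actual failure) once the grace period has expired.
class _ExampleBaseHeightTermination(AbstractTerminationCondition):
    """Stop the episode if the floating base drops below a given height."""

    def compute(self, info: InfoType) -> bool:
        height = self.env.robot_state.q[2]  # assumed configuration layout
        return bool(height < 0.2)  # arbitrary threshold for illustration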
class QuantityTermination(AbstractTerminationCondition, Generic[ValueT]):
"""Convenience class making it easy to derive termination conditions from
generic quantities.
All this class does is check that all elements of a given quantity are
within bounds. If so, the episode continues, otherwise it is either
truncated or terminated according to the 'is_truncation' constructor
argument.
This only applies after the end of a grace period. Before that, the
episode continues no matter what.
"""
def __init__(self,
env: InterfaceJiminyEnv,
name: str,
quantity: QuantityCreator[Optional[ArrayOrScalar]],
low: Optional[ArrayLikeOrScalar],
high: Optional[ArrayLikeOrScalar],
grace_period: float = 0.0,
*,
is_truncation: bool = False,
is_training_only: bool = False) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param name: Desired name of the termination condition. This name will
be used as key for storing the current episode state from
the perspective of this specific condition in 'info', and
to add the underlying quantity to the set of quantities
already managed by the environment. As a result, it
must be unique, otherwise an exception will be raised.
:param quantity: Tuple gathering the class of the underlying quantity
to use as termination condition, plus any
keyword-arguments of its constructor except 'env',
and 'parent'.
:param low: Lower bound below which termination is triggered.
:param high: Upper bound above which termination is triggered.
:param grace_period: Grace period effective only at the very beginning
of the episode, during which the latter is bound
to continue whatever happens.
Optional: 0.0 by default.
:param is_truncation: Whether the episode should be considered
terminated or truncated whenever the termination
condition is triggered.
Optional: False by default.
:param is_training_only: Whether the termination condition should be
completely bypassed if the environment is in
evaluation mode.
Optional: False by default.
"""
# Backup user argument(s)
self.low = low
self.high = high
# Call base implementation
super().__init__(
env,
name,
grace_period,
is_truncation=is_truncation,
is_training_only=is_training_only)
# Add quantity to the set of quantities managed by the environment
self.env.quantities[self.name] = quantity
# Keep track of the underlying quantity
self.data = self.env.quantities.registry[self.name]
def __del__(self) -> None:
try:
del self.env.quantities[self.name]
except Exception: # pylint: disable=broad-except
# This method must not fail under any circumstances
pass
def compute(self, info: InfoType) -> bool:
"""Evaluate the termination condition.
The underlying quantity is first evaluated. The episode continues if
all the elements of its value are within bounds, otherwise the episode
is either truncated or terminated according to 'is_truncation'.
.. warning::
This method is not meant to be overloaded.
"""
# Evaluate the quantity
value = self.data.get()
# Check if the quantity is out-of-bounds.
# Note that it may be `None` if the quantity is ill-defined for the
# current simulation state, which triggers termination unconditionally.
is_done = value is None
is_done |= self.low is not None and bool(np.any(self.low > value))
is_done |= self.high is not None and bool(np.any(value > self.high))
return is_done
QuantityTermination.name.__doc__ = \
"""Name uniquely identifying every termination condition.
It will be used as key not only for storing termination condition-specific
monitoring and debugging information in 'info', but also for adding the
underlying quantity to the ones already managed by the environment.
"""