Source code for gym_jiminy.common.quantities.transform

"""Generic quantities that may be relevant for any kind of robot, regardless
of its topology (multiple or single branch, fixed or floating base...) and the
application (locomotion, grasping...).
"""
import sys
import warnings
from operator import sub
from dataclasses import dataclass
from types import EllipsisType
from typing import (
    Any, Optional, Sequence, Tuple, TypeVar, Union, Generic, ClassVar,
    Callable, Literal, List, overload, cast)

import numpy as np

from jiminy_py import tree
from jiminy_py.core import (  # pylint: disable=no-name-in-module
    array_copyto, multi_array_copyto)

from ..bases import InterfaceJiminyEnv, InterfaceQuantity, QuantityCreator
from ..bases.compositions import ArrayOrScalar
from ..utils import DataNested, build_reduce


# Generic type variables used to parametrize the quantities of this module.
# Which one stands for the value of the quantity itself and which one stands
# for the value of a wrapped quantity depends on the class, e.g.
# `StackedQuantity` wraps an `InterfaceQuantity[ValueT]` while
# `UnaryOpQuantity` wraps an `InterfaceQuantity[OtherValueT]`.
ValueT = TypeVar('ValueT')
OtherValueT = TypeVar('OtherValueT')
# Third value type, used by `BinaryOpQuantity` for its right-hand side operand
YetAnotherValueT = TypeVar('YetAnotherValueT')


@dataclass(unsafe_hash=True)
class StackedQuantity(
        InterfaceQuantity[OtherValueT], Generic[ValueT, OtherValueT]):
    """Keep track of a given quantity over time by automatically stacking its
    value once per environment step since last reset.

    .. note::
        A new entry is added to the stack right before evaluating the reward
        and termination conditions. Internal simulation steps, observer and
        controller updates are ignored.
    """

    quantity: InterfaceQuantity[ValueT]
    """Base quantity whose value must be stacked over time since last reset.
    """

    max_stack: int
    """Maximum number of values that keep in memory before starting to discard
    the oldest one (FIFO). `sys.maxsize` if unlimited.
    """

    as_array: bool
    """Whether to return data as a tuple or a contiguous N-dimensional array
    whose last dimension gathers the value of individual timesteps.
    """

    is_wrapping: bool
    """Whether to wrap the stack around (i.e. starting filling data back from
    the start when full) when full instead of shifting data to the left.
    """

    allow_update_graph: ClassVar[bool] = False
    """Disable dynamic computation graph update.
    """

    @overload
    def __init__(self: "StackedQuantity[ValueT, List[ValueT]]",
                 env: InterfaceJiminyEnv,
                 parent: Optional[InterfaceQuantity],
                 quantity: QuantityCreator[ValueT],
                 *,
                 max_stack: int,
                 is_wrapping: bool,
                 as_array: Literal[False]) -> None:
        ...

    @overload
    def __init__(
            self: "StackedQuantity[Union[np.ndarray, float], np.ndarray]",
            env: InterfaceJiminyEnv,
            parent: Optional[InterfaceQuantity],
            quantity: QuantityCreator[Union[np.ndarray, float]],
            *,
            max_stack: int,
            is_wrapping: bool,
            as_array: Literal[True]) -> None:
        ...

    def __init__(self,
                 env: InterfaceJiminyEnv,
                 parent: Optional[InterfaceQuantity],
                 quantity: QuantityCreator[Any],
                 *,
                 max_stack: int = sys.maxsize,
                 is_wrapping: bool = False,
                 as_array: bool = False) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param parent: Higher-level quantity from which this quantity is a
                       requirement if any, `None` otherwise.
        :param quantity: Tuple gathering the class of the quantity whose values
                         must be stacked, plus all its constructor keyword-
                         arguments except environment 'env' and 'parent'.
        :param max_stack: Maximum number of values that keep in memory before
                          starting to discard the oldest one (FIFO). It must be
                          strictly positive.
                          Optional: The maximum sequence length by default, ie
                          `sys.maxsize` (2^63 - 1).
        :param is_wrapping: Whether to wrap the stack around (i.e. starting
                            filling data back from the start when full) when
                            full instead of shifting data to the left. Note
                            that wrapping around is much faster for large
                            stack but does not preserve temporal ordering.
                            Optional: False by default.
        :param as_array: Whether to return data as a tuple or a contiguous
                         N-dimensional array whose last dimension gathers the
                         value of individual timesteps.
                         Optional: False by default.

        :raises ValueError: If 'max_stack' is smaller than 1.
        """
        # Make sure that the input arguments are valid.
        # An empty stack cannot hold any value: it would later fail with a
        # cryptic modulo-by-zero or out-of-range error when refreshing.
        if max_stack < 1:
            raise ValueError("'max_stack' must be strictly positive.")
        if max_stack > 10000 and (as_array and not is_wrapping):
            warnings.warn(
                "Very large stack length is strongly discourages for "
                "`as_array=True` and `is_wrapping=False`.")

        # Backup user arguments
        self.max_stack = max_stack
        self.is_wrapping = is_wrapping
        self.as_array = as_array

        # Call base implementation
        super().__init__(env,
                         parent,
                         requirements=dict(quantity=quantity),
                         auto_refresh=True)

        # Define specialized flattening operators for efficiency
        self._use_deepcopy = False
        self._dst_flat: List[np.ndarray] = []
        self._src_flat: List[np.ndarray] = []
        self._flatten_dst: Callable[[DataNested], None] = lambda data: None
        self._flatten_src: Callable[[DataNested], None] = lambda data: None

        # Allocate stack buffer.
        # Note that using a plain old list is more efficient than dequeue in
        # practice. Although front deletion is very fast compared to list,
        # casting deque to tuple or list is very slow, which ultimately
        # prevail. The matter gets worst as the maximum length gets longer.
        self._value_list: List[ValueT] = []

        # Continuous memory to store the whole stack if requested.
        # Note that it will be allocated lazily since the dimension of the
        # quantity is not known in advance.
        self._data = np.array([])

        # Define proxy to number of steps of current episode for fast access
        self._num_steps = np.array(-1)

        # Keep track of the last time the quantity has been stacked
        self._num_steps_prev = -1

    def initialize(self) -> None:
        """Reset the history and (re-)allocate the internal buffers based on
        the current value of the base quantity.

        :raises ValueError: If `as_array=True` but the value of the base
                            quantity is neither a scalar nor a `np.ndarray`.
        """
        # Call base implementation
        super().initialize()

        # Refresh proxy
        self._num_steps = self.env.num_steps

        # Clear stack buffer
        self._value_list.clear()

        # Get current value of base quantity
        value = self.quantity.get()

        # Try to define specialized operators based on value.
        # This would succeeded if and only if all leaves are `np.ndarray`.
        # Do not try again if deepcopy mode was previously enabled.
        if not self.as_array and not self._use_deepcopy:
            try:
                self._flatten_dst = build_reduce(fn=self._dst_flat.append,
                                                 op=None,
                                                 dataset=(),
                                                 space=value,
                                                 arity=1,
                                                 forward_bounds=False)
                self._flatten_src = build_reduce(fn=self._src_flat.append,
                                                 op=None,
                                                 dataset=(),
                                                 space=value,
                                                 arity=1,
                                                 forward_bounds=False)
            except AssertionError:
                # Falling back to generic deepcopy
                self._use_deepcopy = True

        # Initialize buffers if necessary
        if self.as_array:
            # Make sure that the value of the quantity is supported
            if not isinstance(value, (int, float, np.ndarray, np.number)):
                raise ValueError(
                    "'as_array=True' is only supported by quantities "
                    "returning N-dimensional arrays as value.")
            _value = np.asarray(value)

            # Allocate contiguous memory if necessary
            self._data = np.zeros(
                (*_value.shape, self.max_stack),
                order='F',
                dtype=_value.dtype)

        # Reset step counter
        self._num_steps_prev = -1

    def refresh(self) -> OtherValueT:
        """Stack the current value of the base quantity if a new environment
        step has been performed since the last update, then return the
        history.

        :returns: View of the contiguous buffer restricted to the timesteps
                  available so far if `as_array=True`, a tuple gathering all
                  the stacked values otherwise.

        :raises RuntimeError: If at least one environment step is missing in
                              the history.
        """
        # Check if there is anything to do
        must_refresh = True
        num_steps = self._num_steps.item()
        if self.env.is_simulation_running:
            # Early return if the stack if already up to date
            if num_steps == self._num_steps_prev:
                must_refresh = False
            # Make sure that no steps are missing in the stack
            elif num_steps != self._num_steps_prev + 1:
                raise RuntimeError(
                    "Previous step missing in the stack. Please reset the "
                    "environment after adding this quantity.")
        else:
            must_refresh = False

        # Extract contiguous slice of (future) available data if necessary
        if self.as_array:
            data = self._data
            num_stack = num_steps + 1
            if num_stack < self.max_stack:
                data = self._data[..., :num_stack]

        # Append current value of the quantity to the history buffer or update
        # aggregated continuous array directly if necessary.
        if must_refresh:
            # Get the current value of the quantity
            value = self.quantity.get()

            # Slot to overwrite when wrapping around.
            # It is safe to compute unconditionally since 'max_stack' is
            # guaranteed to be strictly positive.
            index = num_steps % self.max_stack

            # Append value to the history or aggregate data directly
            is_stack_full = num_steps >= self.max_stack
            if self.as_array:
                if self.is_wrapping:
                    array_copyto(data[..., index], value)
                else:
                    # Shift all available data one timestep to the left.
                    # Operate on (future) available data only for efficiency.
                    if is_stack_full:
                        array_copyto(data[..., :-1], data[..., 1:])

                    # Update most recent value in stack with the current one
                    array_copyto(data[..., -1], value)
            else:
                # Copy of the current value, while avoiding memory allocation
                # if possible for efficiency. Note that data must be
                # "deep-copied" to make sure it does not get altered afterward.
                # Once the stack is full, the memory of the entry about to be
                # discarded is recycled, unless deepcopy fallback is enabled.
                is_recycled = False
                if is_stack_full and not self._use_deepcopy:
                    buffer = self._value_list[index if self.is_wrapping else 0]
                    try:
                        self._dst_flat.clear()
                        self._flatten_dst(buffer)
                        self._src_flat.clear()
                        self._flatten_src(value)  # type: ignore[arg-type]
                        multi_array_copyto(self._dst_flat, self._src_flat)
                        is_recycled = True
                    except AssertionError:
                        # The value of the quantity has changed its memory
                        # layout wrt initialization. Enabling generic deepcopy
                        # fallback from now on.
                        self._use_deepcopy = True
                if not is_recycled:
                    buffer = tree.deepcopy(value)

                # Store the copied value while keeping the length of the stack
                # bounded by 'max_stack' in all cases, including when the
                # generic deepcopy fallback is enabled.
                if not is_stack_full:
                    self._value_list.append(buffer)
                elif self.is_wrapping:
                    # No-op if the entry has been updated in-place
                    self._value_list[index] = buffer
                else:
                    # Discard the oldest value (FIFO)
                    del self._value_list[0]
                    self._value_list.append(buffer)

            # Increment step counter
            self._num_steps_prev += 1

        # Return aggregate data if requested
        if self.as_array:
            return cast(OtherValueT, data)

        # Return the whole stack as a tuple to preserve the integrity of the
        # underlying container and make the API robust to internal changes.
        return cast(OtherValueT, tuple(self._value_list))
@dataclass(unsafe_hash=True)
class MaskedQuantity(InterfaceQuantity[np.ndarray]):
    """Extract a pre-defined set of elements from a given quantity whose value
    is a N-dimensional array along an axis.

    Elements are extracted by slicing whenever the indices can be written
    equivalently as a slice, i.e. they are evenly spaced in increasing order.
    Otherwise, they are extracted by copy. If 'axis' is `None`, the array is
    flattened following its own memory layout, which means that the result
    will be different between C- and F- contiguous arrays.
    """

    quantity: InterfaceQuantity[np.ndarray]
    """Base quantity whose elements must be extracted.
    """

    indices: Tuple[Union[int, EllipsisType], ...]
    """Indices of the elements to extract.
    """

    axis: Optional[int]
    """Axis over which to extract elements. `None` to consider flattened array.
    """

    def __init__(self,
                 env: InterfaceJiminyEnv,
                 parent: Optional[InterfaceQuantity],
                 quantity: QuantityCreator[np.ndarray],
                 keys: Union[Sequence[Union[int, EllipsisType]],
                             Sequence[bool]],
                 *,
                 axis: Optional[int] = 0) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param parent: Higher-level quantity from which this quantity is a
                       requirement if any, `None` otherwise.
        :param quantity: Tuple gathering the class of the quantity whose values
                         must be extracted, plus any keyword-arguments of its
                         constructor except 'env' and 'parent'.
        :param keys: Sequence of indices or boolean mask that will be used to
                     extract elements from the quantity along one axis.
                     Ellipsis can be specified to automatically extract any
                     indices in between surrounding indices or at both ends.
                     Ellipsis on the right end is only supported for
                     non-negative indices with unit stride.
        :param axis: Axis over which to extract elements. `None` to consider
                     flattened array.
                     Optional: First axis by default.

        :raises ValueError: If 'keys' is empty, mixes booleans with other
                            types, only consists of an ellipsis, or ends with
                            an ellipsis that cannot be converted to a slice.
        """
        # Convert boolean mask to indices
        if any(isinstance(e, (bool, np.bool_)) for e in keys):
            if not all(isinstance(e, (bool, np.bool_)) for e in keys):
                raise ValueError(
                    "Interleave boolean mask with ellipsis is not supported.")
            keys = tuple(np.flatnonzero(keys))  # type: ignore[arg-type]

        # Convert keys to tuple while removing consecutive ellipsis if any
        keys = tuple(
            e if e is Ellipsis else int(e)
            for e, _next in zip(keys, (*keys[1:], object()))
            if e is not Ellipsis or _next != e)

        # Replace intermediary ellipsis by indices if possible.
        # Note that it is important to do this substitution BEFORE storing
        # indices as attribute, otherwise masked quantities whose keys are
        # different but actually correspond to identical indices would be
        # identified as different and recomputed, e.g. (1, 2, 3) vs (..., 3).
        if any(e is Ellipsis for e in keys):
            # Iterate backward, so that the position of the ellipsis that have
            # not been processed yet is not affected by the substitution.
            for i in range(len(keys))[1:-1][::-1]:
                if keys[i] is Ellipsis:
                    indices = range(
                        keys[i - 1], keys[i + 1])  # type: ignore[arg-type]
                    keys = (*keys[:(i - 1)], *indices, *keys[(i + 1):])
            if len(keys) > 1 and keys[0] is Ellipsis:
                assert isinstance(keys[1], int)
                keys = (*range(0, keys[1]), *keys[1:])

        # Make sure that at least one index must be extracted
        if not keys:
            raise ValueError(
                "No indices to extract from quantity. Data would be empty.")

        # Make sure that the mask is not the identity
        if keys == (Ellipsis,):
            raise ValueError(
                "Specifying `keys=(...,)` is not allowed as it has no effect.")

        # Check if indices or ellipsis has been provided
        if not all((e is Ellipsis) or isinstance(e, int) for e in keys):
            raise ValueError(
                "Argument 'keys' invalid. It must either be a boolean mask, "
                "or a sequence of indices and ellipsis.")

        # Backup user arguments
        self.indices = keys
        self.axis = axis

        # Check if the indices are evenly spaced
        stride: Optional[int] = None
        keys_heads, key_tail = cast(Tuple[int, ...], keys[:-1]), keys[-1]
        if len(keys) == 1:
            stride = 1
        elif all(e >= 0 for e in keys if e is not Ellipsis):
            if key_tail is Ellipsis:
                spaces = np.array((*np.diff(keys_heads), 1))
            else:
                spaces = np.diff((*keys_heads, key_tail))
            try:
                (unique_stride,) = np.unique(spaces)
            except ValueError as e:
                if key_tail is Ellipsis:
                    raise ValueError(
                        "Ellipsis on the right end is only supported for "
                        "sequence of indices with constant stride.") from e
            else:
                # Only strictly increasing indices can be converted to slice
                # using 'key_tail + 1' as upper bound. Duplicated (null
                # stride) and decreasing (negative stride) indices must be
                # extracted by copy.
                if unique_stride > 0:
                    stride = int(unique_stride)
        elif key_tail is Ellipsis:
            # Indices are extracted by copy when some of them are negative,
            # which is not compatible with ellipsis.
            raise ValueError(
                "Ellipsis on the right end is not supported in conjunction "
                "with negative indices.")

        # Convert indices to slices if possible
        self._slices: Tuple[Union[slice, EllipsisType], ...] = ()
        if stride is not None:
            stop: Optional[int] = None
            if key_tail is not Ellipsis:
                stop = key_tail + 1
                if stop == 0:
                    # Extracting the last element through index -1. The upper
                    # bound must be left open, because 0 would be interpreted
                    # as the start of the array, yielding an empty view.
                    stop = None
            slice_ = slice(keys[0], stop, stride)
            if axis is None:
                self._slices = (slice_,)
            elif axis >= 0:
                self._slices = (*((slice(None),) * axis), slice_)
            else:
                self._slices = (
                    Ellipsis, slice_, *((slice(None),) * (- axis - 1)))

        # Call base implementation
        super().__init__(env,
                         parent,
                         requirements=dict(quantity=quantity),
                         auto_refresh=False)

    def refresh(self) -> np.ndarray:
        """Extract the pre-defined elements from the current value of the base
        quantity.

        :returns: Sliced array if the indices could be converted to a slice,
                  the array returned by `take` otherwise.
        """
        # Get current value of base quantity
        value = self.quantity.get()

        # Extract elements from quantity
        if not self._slices:
            # Note that `take` is faster than classical advanced indexing via
            # `operator[]` (`__getitem__`) because the latter is more generic.
            # Notably, `operator[]` supports boolean mask but `take` does not.
            return value.take(
                self.indices, self.axis)  # type: ignore[arg-type]
        if self.axis is None:
            # `ravel` must be used instead of `flat` to get a view that can
            # be sliced without copy.
            return value.ravel(order="K")[self._slices]
        return value[self._slices]
@dataclass(unsafe_hash=True)
class ConcatenatedQuantity(InterfaceQuantity[np.ndarray]):
    """Concatenate a set of quantities whose value are N-dimensional arrays
    along a given axis.

    All the quantities must have the same shape, except for the dimension
    corresponding to concatenation axis.

    .. note::
        For efficiency and convenience, built-in scalars and 0-D arrays are
        treated as 1D arrays. For instance, multiple floats can be concatenated
        as a vector.
    """

    quantities: Tuple[InterfaceQuantity[np.ndarray], ...]
    """Base quantities whose values must be concatenated.
    """

    axis: int
    """Axis over which to concatenate values.
    """

    def __init__(self,
                 env: InterfaceJiminyEnv,
                 parent: Optional[InterfaceQuantity],
                 quantities: Sequence[QuantityCreator[np.ndarray]],
                 *,
                 axis: int = 0) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param parent: Higher-level quantity from which this quantity is a
                       requirement if any, `None` otherwise.
        :param quantities: Sequence of tuples, each of which gathering the
                           class of the quantity whose values must be
                           extracted, plus any keyword-arguments of its
                           constructor except 'env' and 'parent'.
        :param axis: Axis over which to concatenate values. Negative values
                     are counted from the last axis.
                     Optional: First axis by default.

        :raises ValueError: If less than 2 quantities are specified.
        """
        # Make sure that the input arguments are valid before doing anything
        if len(quantities) < 2:
            raise ValueError(
                "Specifying less than 2 quantities is not allowed.")

        # Backup user arguments
        self.axis = axis

        # Call base implementation
        super().__init__(env,
                         parent,
                         requirements={
                             str(i): quantity
                             for i, quantity in enumerate(quantities)
                         },
                         auto_refresh=False)

        # Define proxies for fast access
        self.quantities = tuple(self.requirements.values())

        # Continuous memory to store the result
        # Note that it will be allocated lazily since the dimension of the
        # quantity is not known in advance.
        self._data = np.array([])

        # Store slices of data associated with each individual quantity
        self._data_slices: List[np.ndarray] = []

    def initialize(self) -> None:
        """Allocate the contiguous buffer storing the result, then split it
        into one view per quantity along the concatenation axis.
        """
        # Call base implementation
        super().initialize()

        # Get current value of all the quantities
        # Dealing with special case where value is a float, as it would impede
        # performance to force allocating a 1D array before concatenation.
        values = tuple(
            np.atleast_1d(quantity.get()) for quantity in self.quantities)

        # Allocate contiguous memory
        self._data = np.concatenate(values, axis=self.axis)

        # Compute slices of data.
        # Negative axis must be handled separately by relying on ellipsis to
        # index from the last dimension, otherwise the first axis would be
        # sliced instead of the requested one.
        self._data_slices.clear()
        idx_start = 0
        for data in values:
            idx_end = idx_start + data.shape[self.axis]
            slice_ = slice(idx_start, idx_end)
            key: Tuple[Any, ...]
            if self.axis >= 0:
                key = (*((slice(None),) * self.axis), slice_)
            else:
                key = (
                    Ellipsis, slice_, *((slice(None),) * (- self.axis - 1)))
            self._data_slices.append(self._data[key])
            idx_start = idx_end

    def refresh(self) -> np.ndarray:
        """Copy the current value of every quantity in its dedicated view of
        the contiguous buffer.

        :returns: The contiguous buffer itself, which is shared between calls.
        """
        # Refresh the contiguous buffer
        multi_array_copyto(self._data_slices,
                           [quantity.get() for quantity in self.quantities])

        return self._data
@dataclass(unsafe_hash=True)
class UnaryOpQuantity(InterfaceQuantity[ValueT],
                      Generic[ValueT, OtherValueT]):
    """Apply a given unary operator to a quantity.

    This quantity is useful to translate quantities from world frame to local
    odometry frame. It may also be used to convert multi-variate quantities as
    scalar, typically by computing the L^p-norm.
    """

    quantity: InterfaceQuantity[OtherValueT]
    """Quantity that will be forwarded to the unary operator.
    """

    op: Callable[[OtherValueT], ValueT]
    """Callable taking any value of the quantity as input argument.
    """

    def __init__(self,
                 env: InterfaceJiminyEnv,
                 parent: Optional[InterfaceQuantity],
                 quantity: QuantityCreator[OtherValueT],
                 op: Callable[[OtherValueT], ValueT]) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param parent: Higher-level quantity from which this quantity is a
                       requirement if any, `None` otherwise.
        :param quantity: Tuple gathering the class of the quantity whose value
                         must be passed as argument of the unary operator, plus
                         any keyword-arguments of its constructor except 'env'
                         and 'parent'.
        :param op: Any callable taking any value of the quantity as input
                   argument. For example `partial(np.linalg.norm, ord=2)` to
                   compute the L2-norm.
        """
        # The operator must be stored before calling the base constructor
        self.op = op

        # Register the operand as the one and only requirement
        requirements = {"quantity": quantity}
        super().__init__(
            env, parent, requirements=requirements, auto_refresh=False)

    def refresh(self) -> ValueT:
        """Evaluate the unary operator on the current value of the quantity.
        """
        operand = self.quantity.get()
        return self.op(operand)
@dataclass(unsafe_hash=True)
class BinaryOpQuantity(InterfaceQuantity[ValueT],
                       Generic[ValueT, OtherValueT, YetAnotherValueT]):
    """Apply a given binary operator between two quantities.

    This quantity is mainly useful for computing the error between the value of
    a given quantity evaluated at the current simulation state and the state at
    the current simulation time for the reference trajectory being selected.
    """

    quantity_left: InterfaceQuantity[OtherValueT]
    """Left-hand side quantity that will be forwarded to the binary operator.
    """

    quantity_right: InterfaceQuantity[YetAnotherValueT]
    """Right-hand side quantity that will be forwarded to the binary operator.
    """

    op: Callable[[OtherValueT, YetAnotherValueT], ValueT]
    """Callable taking left- and right-hand side quantities as input argument.
    """

    def __init__(self,
                 env: InterfaceJiminyEnv,
                 parent: Optional[InterfaceQuantity],
                 quantity_left: QuantityCreator[OtherValueT],
                 quantity_right: QuantityCreator[YetAnotherValueT],
                 op: Callable[[OtherValueT, YetAnotherValueT], ValueT]
                 ) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param parent: Higher-level quantity from which this quantity is a
                       requirement if any, `None` otherwise.
        :param quantity_left: Tuple gathering the class of the quantity that
                              must be passed to left-hand side of the binary
                              operator, plus all its constructor keyword-
                              arguments except environment 'env' and parent
                              'parent'.
        :param quantity_right: Quantity that must be passed to right-hand side
                               of the binary operator as a tuple
                               (class, keyword-arguments). See `quantity_left`
                               argument for details.
        :param op: Any callable taking the left- and right-hand side
                   quantities as input argument, in this order. For example
                   `operator.sub` to compute the difference.
        """
        # The operator must be stored before calling the base constructor
        self.op = op

        # Register both operands as requirements, left-hand side first
        requirements = {
            "quantity_left": quantity_left,
            "quantity_right": quantity_right,
        }
        super().__init__(
            env, parent, requirements=requirements, auto_refresh=False)

    def refresh(self) -> ValueT:
        """Evaluate the binary operator on the current value of the left- and
        right-hand side quantities.
        """
        value_left = self.quantity_left.get()
        value_right = self.quantity_right.get()
        return self.op(value_left, value_right)
@dataclass(unsafe_hash=True)
class MultiAryOpQuantity(InterfaceQuantity[ValueT]):
    """Apply a given n-ary operator to the values of a given set of quantities.
    """

    quantities: Tuple[InterfaceQuantity[Any], ...]
    """Sequence of quantities that will be forwarded to the n-ary operator in
    this exact order.
    """

    op: Callable[[Sequence[Any]], ValueT]
    """Callable taking the packed sequence of values for all the specified
    quantities as input argument.
    """

    def __init__(self,
                 env: InterfaceJiminyEnv,
                 parent: Optional[InterfaceQuantity],
                 quantities: Sequence[QuantityCreator[Any]],
                 op: Callable[[Sequence[Any]], ValueT]) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param parent: Higher-level quantity from which this quantity is a
                       requirement if any, `None` otherwise.
        :param quantities: Ordered sequence of n pairs, each gathering the
                           class of a quantity whose value must be passed as
                           argument of the n-ary operator, plus any
                           keyword-arguments of its constructor except 'env'
                           and 'parent'.
        :param op: Any callable taking the packed sequence of values for all
                   the quantities as input argument, in the exact order they
                   were originally specified.
        """
        # The operator must be stored before calling the base constructor
        self.op = op

        # Register each operand as a requirement named after its position
        requirements = {}
        for i, quantity in enumerate(quantities):
            requirements[f"quantity_{i}"] = quantity
        super().__init__(
            env, parent, requirements=requirements, auto_refresh=False)

        # Keep track of the instantiated quantities for identity check
        self.quantities = tuple(self.requirements.values())

    def refresh(self) -> ValueT:
        """Evaluate the n-ary operator on the list gathering the current value
        of all the quantities.
        """
        values = []
        for quantity in self.quantities:
            values.append(quantity.get())
        return self.op(values)
@dataclass(unsafe_hash=True)
class DeltaQuantity(InterfaceQuantity[ArrayOrScalar]):
    """Variation of a given quantity over the whole span of a horizon.

    If `bounds_only=False`, then the differences of the value of the quantity
    between successive timesteps is accumulated over a variable-length history
    bounded by 'max_stack', which is basically a sliding window. The total
    variation over this horizon is defined as the sum of all the successive
    differences stored in the history.

    If `bounds_only=True`, then the value of the quantity is accumulated over a
    variable-length history bounded by 'max_stack'. The total variation is
    simply computed as the difference between most recent and oldest values
    stored in the history.
    """

    quantity: InterfaceQuantity[
        Union[np.ndarray, Sequence[ArrayOrScalar]]]
    """Quantity from which to compute the total variation over the history.
    """

    op: Callable[[ArrayOrScalar, ArrayOrScalar], ArrayOrScalar]
    """Any callable taking as input argument the current and some previous
    value of the quantity in that exact order, and returning the signed
    difference between them.
    """

    max_stack: int
    """Time horizon over which to compute the variation.
    """

    bounds_only: bool
    """Whether to compute the total variation as the difference between the
    most recent and oldest value stored in the history, or the sum of
    differences between successive timesteps.
    """

    def __init__(
            self,
            env: InterfaceJiminyEnv,
            parent: Optional[InterfaceQuantity],
            quantity: QuantityCreator[ArrayOrScalar],
            horizon: Optional[float],
            *,
            op: Callable[[ArrayOrScalar, ArrayOrScalar], ArrayOrScalar] = sub,
            bounds_only: bool = True) -> None:
        """
        :param env: Base or wrapped jiminy environment.
        :param parent: Higher-level quantity from which this quantity is a
                       requirement if any, `None` otherwise.
        :param quantity: Tuple gathering the class of the quantity from which
                         to compute the variation, plus any keyword-arguments
                         of its constructor except 'env' and 'parent'.
        :param horizon: Horizon over which values of the quantity will be
                        stacked before computing the drift. `None` to consider
                        only two successive timesteps.
        :param op: Any callable taking as input argument the current and some
                   previous value of the quantity in that exact order, and
                   returning the signed difference between them. Typically,
                   the subtraction operation is appropriate for position in
                   Euclidean space, but not for orientation as it is important
                   to count turns.
                   Optional: `sub` by default.
        :param bounds_only: Whether to compute the total variation as the
                            difference between the most recent and oldest value
                            stored in the history, or the sum of differences
                            between successive timesteps.
                            Optional: True by default.
        """
        # Translate the horizon into a number of stacked values, assuming the
        # environment timestep is constant. Two values are necessary to span
        # a single timestep, which is the minimum.
        max_stack = 2
        if horizon is not None:
            num_intervals = int(np.ceil(horizon / env.step_dt))
            max_stack = max(num_intervals, 1) + 1

        # Backup some of the user-arguments
        self.bounds_only = bounds_only
        self.max_stack = max_stack
        self.op = op

        # Pick the history to maintain: raw values of the quantity if only the
        # bounds matter, differences between successive timesteps otherwise.
        if bounds_only:
            stack_kwargs = {
                "quantity": quantity,
                "max_stack": max_stack,
                "is_wrapping": False,
                "as_array": False,
            }
        else:
            step_delta = (DeltaQuantity, {
                "quantity": quantity,
                "horizon": None,
                "bounds_only": True,
                "op": op,
            })
            stack_kwargs = {
                "quantity": step_delta,
                "max_stack": max_stack - 1,
                "is_wrapping": True,
                "as_array": True,
            }
        quantity_stack: QuantityCreator = (StackedQuantity, stack_kwargs)

        # Call base implementation
        super().__init__(
            env,
            parent,
            requirements={"quantity_stack": quantity_stack},
            auto_refresh=False)

        # Keep track of the underlying quantity for equality check.
        # It is nested one level deeper when stacking successive differences.
        base_quantity = self.quantity_stack.quantity
        if not bounds_only:
            base_quantity = base_quantity.quantity
        self.quantity = base_quantity

    def refresh(self) -> ArrayOrScalar:
        """Compute the total variation of the quantity over the history.
        """
        history = self.quantity_stack.get()
        if not self.bounds_only:
            # Sum of the differences between successive timesteps
            return history.sum(axis=-1)

        # Difference between the most recent and oldest stored values
        newest, oldest = history[-1], history[0]
        return self.op(newest, oldest)