Source code for gym_jiminy.common.utils.spaces

"""Utilities operating over complex `Gym.Space`s associated with arbitrarily
nested data structure of `np.ndarray` and heavily optimized for speed.

They combine static control flow pre-computed for a given space, and possibly
some pre-allocated values, with Just-In-Time (JIT) compilation via Numba when
possible for optimal performance.
"""
import math
from functools import partial
from collections import OrderedDict
from collections.abc import Mapping, MutableMapping, Sequence, MutableSequence
from typing import (
    Any, Dict, Optional, Union, Sequence as SequenceT, Mapping as MappingT,
    Iterable, Tuple, Literal, SupportsFloat, TypeVar, Type, Callable,
    no_type_check, cast)

import numba as nb
import numpy as np
from numpy import typing as npt

import gymnasium as gym

from jiminy_py import tree
from jiminy_py.core import array_copyto  # pylint: disable=no-name-in-module


ValueT = TypeVar('ValueT')
ValueInT = TypeVar('ValueInT')
ValueOutT = TypeVar('ValueOutT')

StructNested = Union[MappingT[str, 'StructNested[ValueT]'],
                     Iterable['StructNested[ValueT]'],
                     ValueT]
FieldNested = StructNested[str]
DataNested = StructNested[np.ndarray]
DataNestedT = TypeVar('DataNestedT', bound=DataNested)

ArrayOrScalar = Union[np.ndarray, SupportsFloat]


@no_type_check
@nb.jit(nopython=True, cache=True, fastmath=True)
def _array_clip(value: np.ndarray,
                low: Optional[ArrayOrScalar],
                high: Optional[ArrayOrScalar]) -> np.ndarray:
    """Element-wise out-of-place clipping of array elements.

    The input array is never modified. A new array is returned in all cases,
    including when no bound is specified at all.

    :param value: Array holding values to clip. It may be 0-dimensional.
    :param low: Optional lower bound. `None` to disable lower clipping.
    :param high: Optional upper bound. `None` to disable upper clipping.

    :returns: Newly allocated array holding the clipped values.
    """
    # Note that in-place clipping is actually slower than out-of-place in
    # Numba when 'fastmath' compilation flag is set.

    # Short circuit if there is neither low or high bounds.
    # A copy is returned to stay consistent with the out-of-place contract.
    if low is None and high is None:
        return value.copy()

    # Generic case.
    # Note that chaining `np.minimum` with `np.maximum` yields to better
    # performance than `np.clip` when 'fastmath' compilation flag is set.
    if value.ndim:
        if low is not None and high is not None:
            return np.minimum(np.maximum(value, low), high)
        if low is not None:
            return np.maximum(value, low)
        # Only the upper bound is specified at this point
        return np.minimum(value, high)

    # Scalar case.
    # Strangely, calling '.item()' on Python scalars is supported by Numba.
    out = value.item()
    if low is not None:
        out = max(out, low.item())
    if high is not None:
        out = min(out, high.item())
    # Wrap the scalar result in a 0-dimensional array
    return np.array(out)


@no_type_check
@nb.jit(nopython=True, cache=True, fastmath=True)
def _array_contains(value: np.ndarray,
                    low: Optional[ArrayOrScalar],
                    high: Optional[ArrayOrScalar],
                    tol_abs: float,
                    tol_rel: float) -> bool:
    """Check that all array elements are within bounds, up to some tolerance
    threshold. If both absolute and relative tolerances are provided, then
    satisfying only one of the two criteria is considered sufficient.

    The effective tolerance is the absolute tolerance, or the maximum between
    the absolute tolerance and `(high - low) * tol_rel` if both bounds are
    specified and the relative tolerance is strictly positive.

    :param value: Array holding values to check. It may be 0-dimensional.
    :param low: Optional lower bound.
    :param high: Optional upper bound.
    :param tol_abs: Absolute tolerance.
    :param tol_rel: Relative tolerance. It will be ignored if either the lower
                    or upper is not specified.

    :returns: Whether all elements are within bounds up to tolerance. It is
              always `True` if no bound is specified.
    """
    # Generic N-dimensional case
    if value.ndim:
        # Element-wise tolerance, starting from the absolute tolerance
        tol_nd = np.full_like(value, tol_abs)
        if low is not None and high is not None and tol_rel > 0.0:
            # Keep the most permissive of the two tolerances for each element
            tol_nd = np.maximum((high - low) * tol_rel, tol_nd)
        # Reversed bound check because 'all' is always true for empty arrays
        if low is not None and not (low - tol_nd <= value).all():
            return False
        if high is not None and not (value <= high + tol_nd).all():
            return False
        return True

    # Scalar (0-dimensional) case, relying on plain Python scalars
    tol_0d = tol_abs
    if low is not None and high is not None and tol_rel > 0.0:
        tol_0d = max((high.item() - low.item()) * tol_rel, tol_0d)
    if low is not None and (low.item() - tol_0d > value.item()):
        return False
    if high is not None and (value.item() > high.item() + tol_0d):
        return False
    return True


def get_bounds(space: gym.Space
               ) -> Tuple[Optional[ArrayOrScalar], Optional[ArrayOrScalar]]:
    """Get the inclusive lower and upper bounds of a given `gym.Space` if any.

    Both bounds are inclusive, i.e. they are themselves valid values of the
    space. This is what `clip` and `contains` expect.

    :param space: `gym.Space` on which to operate.

    :returns: Lower and upper bounds as a tuple. `(None, None)` for spaces
              other than `gym.spaces.Box`, `gym.spaces.Discrete` and
              `gym.spaces.MultiDiscrete`.
    """
    if isinstance(space, gym.spaces.Box):
        return space.low, space.high
    if isinstance(space, gym.spaces.Discrete):
        # Valid values are `start, start + 1, ..., start + n - 1`, so the
        # largest admissible value is `start + n - 1`, not `n`.
        return space.start, space.start + space.n - 1
    if isinstance(space, gym.spaces.MultiDiscrete):
        # Older releases of the space do not expose any 'start' attribute, in
        # which case every component starts at zero.
        start = getattr(space, "start", 0)
        return start, start + space.nvec - 1
    return None, None
@no_type_check
def zeros(space: gym.Space[DataNestedT],
          dtype: npt.DTypeLike = None,
          enforce_bounds: bool = True) -> DataNestedT:
    """Allocate data structure from `gym.Space` and initialize it to zero.

    :param space: `gym.Space` on which to operate.
    :param dtype: Can be specified to overwrite original space dtype.
                  Optional: None by default
    :param enforce_bounds: Whether to clip the zero-initialized leaves to the
                           bounds of their respective space, so that the
                           returned value is within bounds even if zero is
                           not.
                           Optional: True by default.

    :returns: `OrderedDict` for `gym.spaces.Dict`, `tuple` for
              `gym.spaces.Tuple`, `np.ndarray` for all supported leaf spaces.

    :raises ValueError: if 'space' does not derive from `gym.Space`.
    :raises NotImplementedError: if the type of space is not supported.
    """
    # Note that it is not possible to take advantage of `jiminy_py.tree`
    # because the output type for collections (OrderedDict or Tuple) is not the
    # same as the input one (gym.Space).

    # Containers: recurse on every sub-space. Bounds are enforced leaf-wise by
    # the recursive calls, so 'enforce_bounds' must be forwarded and there is
    # no need to clip the container as a whole afterward.
    if isinstance(space, gym.spaces.Dict):
        value = OrderedDict()
        for field, subspace in space.spaces.items():
            value[field] = zeros(
                subspace, dtype=dtype, enforce_bounds=enforce_bounds)
        return value
    if isinstance(space, gym.spaces.Tuple):
        # 'spaces' is a plain tuple for `gym.spaces.Tuple`, not a mapping
        return tuple(
            zeros(subspace, dtype=dtype, enforce_bounds=enforce_bounds)
            for subspace in space.spaces)

    # Leaves
    value = None
    if isinstance(space, gym.spaces.Box):
        value = np.zeros(space.shape, dtype=dtype or space.dtype)
    elif isinstance(space, gym.spaces.Discrete):
        # Note that np.array of 0 dim is returned in order to be mutable
        value = np.array(0, dtype=dtype or np.int64)
    elif isinstance(space, gym.spaces.MultiDiscrete):
        value = np.zeros_like(space.nvec, dtype=dtype or np.int64)
    elif isinstance(space, gym.spaces.MultiBinary):
        value = np.zeros(space.n, dtype=dtype or np.int8)
    if value is not None:
        if enforce_bounds:
            value = clip(value, space)
        return value

    # Unsupported input
    if not isinstance(space, gym.Space):
        raise ValueError(
            "All spaces must derived from `gym.Space`, including tuple and "
            "dict containers.")
    raise NotImplementedError(
        f"Space of type {type(space)} is not supported.")
def fill(data: DataNested, fill_value: Union[float, int, np.number]) -> None:
    """Overwrite in-place all the leaves of a nested data structure with a
    given scalar value.

    :param data: Nested data structure whose leaves must be updated.
    :param fill_value: Scalar assigned to every element of every leaf.

    :raises ValueError: if some leaf does not provide a 'fill' method, which
                        means that it is not a `np.ndarray`.
    """
    for leaf in tree.flatten(data):
        try:
            leaf.fill(fill_value)
        except AttributeError as exc:
            raise ValueError(
                "Leaves of 'data' structure must have type `np.ndarray`."
                ) from exc
def copyto(dst: DataNested, src: DataNested) -> None:
    """Copy the leaves of an arbitrarily nested data structure of
    'np.ndarray' into a pre-allocated destination.

    No memory is allocated for the leaves, so the memory pointers of 'dst'
    remain unchanged. As a direct consequence, the destination must be
    pre-allocated beforehand, and only arrays of fixed shape are supported.

    .. note::
        Unlike the function returned by 'build_copyto', only the flattened
        data structures need to match, not the original ones. This means that
        the source and/or destination can be flattened already when provided.
        Beware values must be sorted by keys in case of nested dict.

    :param dst: Hierarchical data structure to update, possibly flattened.
    :param src: Hierarchical data to copy, possibly flattened.
    """
    dst_leaves = tree.flatten(dst)
    src_leaves = tree.flatten(src)
    # Leaves are paired in flattening order
    for dst_leaf, src_leaf in zip(dst_leaves, src_leaves):
        array_copyto(dst_leaf, src_leaf)
def copy(data: DataNestedT) -> DataNestedT:
    """Recursively shallow copy a nested data structure, so that containers
    are re-created while leaves are still references to the original ones.

    :param data: Hierarchical data structure to copy without allocation.

    :returns: Data structure rebuilt with the same layout as 'data' from its
              own flattened leaves.
    """
    leaves = tree.flatten(data)
    rebuilt = tree.unflatten_as(data, leaves)
    return cast(DataNestedT, rebuilt)
@no_type_check
def clip(data: DataNested, space: gym.Space[DataNested]) -> DataNested:
    """Clip all leaves of a nested data structure to the bounds of their
    respective `gym.Space`.

    .. note::
        None of the leaves of the returned data structure is sharing memory
        with the original one, even if clipping had no effect. This alleviates
        the need of calling 'deepcopy' afterward.

    :param data: Data to clip.
    :param space: `gym.Space` on which to operate.

    :returns: Clipped data, stored in containers of the same type as 'data'.
    """
    container_cls = type(data)

    # Mapping container: recurse on every field of the space
    if tree.issubclass_mapping(container_cls):
        clipped_fields = {}
        for key, subspace in space.spaces.items():
            clipped_fields[key] = clip(data[key], subspace)
        return container_cls(clipped_fields)

    # Sequence container: recurse on every item of the space
    if tree.issubclass_sequence(container_cls):
        clipped_items = []
        for idx, subspace in enumerate(space.spaces):
            clipped_items.append(clip(data[idx], subspace))
        return container_cls(clipped_items)

    # Leaf: clip out-of-place using the bounds of the space if any
    low, high = get_bounds(space)
    return _array_clip(data, low, high)
@no_type_check
def contains(data: DataNested,
             space: gym.Space[DataNested],
             tol_abs: float = 0.0,
             tol_rel: float = 0.0) -> bool:
    """Check whether all leaves of a nested data structure are within the
    bounds of their respective `gym.Space`, up to some tolerance threshold.
    If both absolute and relative tolerances are provided, then satisfying
    only one of the two criteria is considered sufficient.

    By design, it is always `True` for all spaces but `gym.spaces.Box`,
    `gym.spaces.Discrete` and `gym.spaces.MultiDiscrete`.

    :param data: Data structure to check.
    :param space: `gym.Space` on which to operate.
    :param tol_abs: Absolute tolerance.
    :param tol_rel: Relative tolerance.

    :returns: Whether every leaf is within bounds. Checking stops at the first
              leaf found out of bounds.
    """
    container_cls = type(data)

    # Mapping container: every field must be within bounds
    if tree.issubclass_mapping(container_cls):
        for key, subspace in space.spaces.items():
            if not contains(data[key], subspace, tol_abs, tol_rel):
                return False
        return True

    # Sequence container: every item must be within bounds
    if tree.issubclass_sequence(container_cls):
        for idx, subspace in enumerate(space.spaces):
            if not contains(data[idx], subspace, tol_abs, tol_rel):
                return False
        return True

    # Leaf: check against the bounds of the space if any
    low, high = get_bounds(space)
    return _array_contains(data, low, high, tol_abs, tol_rel)
@no_type_check
def build_reduce(fn: Callable[..., ValueInT],
                 op: Optional[Callable[[ValueOutT, ValueInT], ValueOutT]],
                 dataset: SequenceT[DataNested],
                 space: Optional[gym.Space[DataNested]],
                 arity: Optional[Literal[0, 1]],
                 *args: Any,
                 initializer: Optional[Callable[[], ValueOutT]] = None,
                 forward_bounds: bool = True) -> Callable[..., ValueOutT]:
    """Generate specialized callable applying transform and reduction on all
    leaves of given nested space.

    .. note::
        Original ordering of the leaves is preserved. More precisely, both
        transform and reduction will be applied recursively in keys order.

    .. warning::
        It is assumed without checking that all nested data structures are
        consistent together and with the space if provided. It holds true both
        data known at generation-time or runtime. Yet, it is only required for
        data provided at runtime if any to include the original data structure,
        so it may contain additional branches which will be ignored.

    .. warning::
        Providing additional data at runtime is supported but impede
        performance. Arity larger than 1 is not supported because the code path
        could not be fully specialized, causing dramatic slowdown.

    .. warning::
        There is no built-in 'short-circuit' mechanism, which means that it
        will go through all leaves systematically unless the reduction operator
        itself raises an exception.

    :param fn: Transform applied to every leaves of the nested data structures
               before performing the actual reduction. This function can
               perform in-place or out-of-place operations without restriction.
               `None` is not supported because it would be irrelevant. Note
               that if tracking the hierarchy during reduction is not
               necessary, then it would be way more efficient to first flatten
               the pre-allocated nested data structure once for all, and then
               perform reduction on this flattened view using the standard
               'functools.reduce' method. Still, flattening at runtime using
               'flatten' would still be much slower than a specialized nested
               reduction. It must not be a 'functools.partial' instance.
    :param op: Optional reduction operator applied cumulatively on all leaves
               after transform. See 'functools.reduce' documentation for
               details. `None` to only apply transform on all leaves without
               reduction. This is useful when apply in-place transform.
    :param dataset: Sequence of pre-allocated nested data structures. Optional
                    if the space is provided but hardly relevant.
    :param space: Container space on which to operate (eg `gym.spaces.Dict` or
                  `gym.spaces.Tuple`). Optional iif the nested data structure
                  is provided.
    :param arity: Arity of the generated callable. `None` to indicate that it
                  must be determined at runtime, which is slower.
    :param args: Extra arguments to systematically forward as transform input
                 for all leaves. Note that, as for Python built-ins methods,
                 keywords are not supported for the sake of efficiency.
    :param initializer: Function used to compute the initial value before
                        starting reduction. Optional if the reduction operator
                        has same input and output types. If `None`, then the
                        value corresponding to the first leaf after transform
                        will be used instead.
    :param forward_bounds: Whether to forward the lower and upper bounds of the
                           `gym.Space` associated with each leaf as transform
                           input. In this case, they will be added after the
                           data structure provided at runtime but before other
                           extra arguments if any. It is up to the user to make
                           sure all leaves have bounds, otherwise it will raise
                           an exception at generation-time. This argument is
                           ignored if no space is specified.

    :returns: Fully-specialized reduction callable.

    :raises TypeError: if neither a dataset nor the space is specified, if
                       'arity' is not 0, 1 or `None`, or if 'fn' is a
                       'functools.partial' instance.
    """
    # pylint: disable=unused-argument

    def _build_reduce(
            arity: Literal[0, 1],
            is_initialized: bool,
            fn_1: Union[Callable[..., ValueInT], Callable[..., ValueOutT]],
            field_1: Union[str, int],
            fn_2: Union[Callable[..., ValueInT], Callable[..., ValueOutT]],
            field_2: Union[str, int],
            ) -> Callable[..., ValueOutT]:
        """Internal method generating a specialized callable performing a
        single reduction operation on either leaf transform and/or already
        branch reduction.

        Whether 'fn_1' and 'fn_2' are leaf transforms or branch reductions is
        not passed as argument but inferred internally: a functor is
        considered a leaf transform if and only if its 'func' attribute is the
        user-specified transform 'fn' itself.

        :param arity: Arity of the generated callable.
        :param is_initialized: Whether the output has already been initialized
                               at this point. The first reduction is the only
                               one to initialize the output, either by calling
                               the initializer if provided or passing directly
                               the output of first transform call otherwise.
        :param fn_1: Leaf transform or branch reduction to call last.
        :param field_1: Pass the value corresponding this key as input argument
                        for nested data structure provided at runtime if any
                        iif callable 'fn_1' is a leaf transform.
        :param fn_2: Leaf transform or branch reduction to call first.
        :param field_2: Same as 'field_1' for callable 'fn_2'.

        :returns: Specialized branch reduction callable requiring passing the
                  current reduction output as input if some reduction operator
                  has been specified.
        """
        # Extract extra arguments from functor if necessary to preserve order.
        # Leaf transforms are expected to be 'partial' functors wrapping 'fn',
        # whose last positional argument is the tuple of extra arguments and
        # all the others are the pre-allocated data known at generation-time.
        is_out_1, is_out_2 = fn_1.func is not fn, fn_2.func is not fn
        if not is_out_1:
            fn_1, dataset, args_1 = fn_1.func, fn_1.args[:-1], fn_1.args[-1]
            has_args = bool(args_1)
            if arity == 0:
                # Everything is known at generation-time for nullary callable
                fn_1 = partial(fn_1, *dataset, *args_1)
            elif dataset:
                # Extra arguments must come after the data provided at runtime
                fn_1 = partial(fn_1, *dataset)
        if not is_out_2:
            fn_2, dataset, args_2 = fn_2.func, fn_2.args[:-1], fn_2.args[-1]
            has_args = bool(args_2)
            if arity == 0:
                fn_2 = partial(fn_2, *dataset, *args_2)
            elif dataset:
                fn_2 = partial(fn_2, *dataset)

        # Specialization if no op is specified.
        # In such a case, transforms are just called one after the other and
        # their output is discarded.
        if op is None:
            if arity == 0:
                def _reduce(fn_1, fn_2):
                    fn_2()
                    fn_1()
                return partial(_reduce, fn_1, fn_2)
            if is_out_1 and is_out_2:
                def _reduce(fn_1, fn_2, delayed):
                    fn_2(delayed)
                    fn_1(delayed)
                return partial(_reduce, fn_1, fn_2)
            if is_out_1 and not is_out_2:
                if has_args:  # pylint: disable=possibly-used-before-assignment
                    def _reduce(fn_1, fn_2, field_2, args_2, delayed):
                        fn_2(delayed[field_2], *args_2)
                        fn_1(delayed)
                    return partial(_reduce, fn_1, fn_2, field_2, args_2)

                def _reduce(fn_1, fn_2, field_2, delayed):
                    fn_2(delayed[field_2])
                    fn_1(delayed)
                return partial(_reduce, fn_1, fn_2, field_2)
            if not is_out_1 and is_out_2:
                if has_args:
                    def _reduce(fn_1, field_1, args_1, fn_2, delayed):
                        fn_2(delayed)
                        fn_1(delayed[field_1], *args_1)
                    return partial(_reduce, fn_1, field_1, args_1, fn_2)

                def _reduce(fn_1, field_1, fn_2, delayed):
                    fn_2(delayed)
                    fn_1(delayed[field_1])
                return partial(_reduce, fn_1, field_1, fn_2)
            # Both callables are leaf transforms at this point
            if has_args:
                def _reduce(
                        fn_1, field_1, args_1, fn_2, field_2, args_2, delayed):
                    fn_2(delayed[field_2], *args_2)
                    fn_1(delayed[field_1], *args_1)
                return partial(
                    _reduce, fn_1, field_1, args_1, fn_2, field_2, args_2)

            def _reduce(fn_1, field_1, fn_2, field_2, delayed):
                fn_2(delayed[field_2])
                fn_1(delayed[field_1])
            return partial(_reduce, fn_1, field_1, fn_2, field_2)

        # Specialization if op is specified.
        # Nullary callables first, for which no data is provided at runtime.
        if arity == 0:
            if is_initialized:
                if is_out_1 and is_out_2:
                    def _reduce(fn_1, fn_2, out):
                        return fn_1(fn_2(out))
                    return partial(_reduce, fn_1, fn_2)
                if is_out_1 and not is_out_2:
                    def _reduce(op, fn_1, fn_2, out):
                        return fn_1(op(out, fn_2()))
                elif not is_out_1 and is_out_2:
                    def _reduce(op, fn_1, fn_2, out):
                        return op(fn_2(out), fn_1())
                else:
                    def _reduce(op, fn_1, fn_2, out):
                        return op(op(out, fn_2()), fn_1())
                return partial(_reduce, op, fn_1, fn_2)
            # The output of the first leaf transform is used as initial value
            # since the output has not been initialized yet.
            if is_out_1 and not is_out_2:
                def _reduce(fn_1, fn_2, out):
                    return fn_1(fn_2())
                return partial(_reduce, fn_1, fn_2)
            if not is_out_1 and not is_out_2:
                def _reduce(op, fn_1, fn_2, out):
                    return op(fn_2(), fn_1())
                return partial(_reduce, op, fn_1, fn_2)
            # NOTE(review): if 'fn_2' is a branch reduction while the output
            # is not initialized, none of the branches above returns and the
            # unary specializations below are reached with zero arity -
            # confirm that this case cannot occur.

        # Unary callables, for which some data is provided at runtime
        if is_initialized:
            if is_out_1 and is_out_2:
                def _reduce(fn_1, fn_2, out, delayed):
                    return fn_1(fn_2(out, delayed), delayed)
                return partial(_reduce, fn_1, fn_2)
            if is_out_1 and not is_out_2:
                def _reduce(op, fn_1, fn_2, field_2, args_2, out, delayed):
                    return fn_1(
                        op(out, fn_2(delayed[field_2], *args_2)), delayed)
                return partial(_reduce, op, fn_1, fn_2, field_2, args_2)
            if not is_out_1 and is_out_2:
                def _reduce(op, fn_1, field_1, args_1, fn_2, out, delayed):
                    return op(
                        fn_2(out, delayed), fn_1(delayed[field_1], *args_1))
                return partial(_reduce, op, fn_1, field_1, args_1, fn_2)

            def _reduce(
                    op, fn_1, field_1, args_1, fn_2, field_2, args_2, out,
                    delayed):
                return op(op(out, fn_2(delayed[field_2], *args_2)),
                          fn_1(delayed[field_1], *args_1))
            return partial(
                _reduce, op, fn_1, field_1, args_1, fn_2, field_2, args_2)
        # The output of the first leaf transform is used as initial value
        if is_out_1 and not is_out_2:
            def _reduce(fn_1, fn_2, field_2, args_2, out, delayed):
                return fn_1(fn_2(delayed[field_2], *args_2), delayed)
            return partial(_reduce, fn_1, fn_2, field_2, args_2)

        def _reduce(  # type: ignore[no-redef]
                op, fn_1, field_1, args_1, fn_2, field_2, args_2, out,
                delayed):
            return op(fn_2(delayed[field_2], *args_2),
                      fn_1(delayed[field_1], *args_1))
        return partial(
            _reduce, op, fn_1, field_1, args_1, fn_2, field_2, args_2)

    def _build_forward(
            arity: Literal[0, 1],
            parent: Optional[Union[str, int]],
            is_initialized: bool,
            post_fn: Union[Callable[..., ValueInT], Callable[..., ValueOutT]],
            field: Optional[Union[str, int]],
            ) -> Union[Callable[..., ValueInT], Callable[..., ValueOutT]]:
        """Internal method generating a specialized callable forwarding the
        value associated with a given key for nested data structure provided at
        runtime if any as input argument of some leaf transform or branch
        reduction callable.

        If 'post_fn' is a leaf transform and there is no parent key, then the
        callable is not a reduction at this point, so doing it here since it is
        the very last moment before main entry-point returns.

        :param arity: Arity of the generated callable.
        :param parent: Parent key to forward.
        :param is_initialized: Whether the output has already been initialized.
        :param post_fn: Leaf transform or branch reduction.
        :param field: Key of the value to pass as input of the leaf transform
                      for nested data structure provided at runtime if any.
                      `None` to pass the whole data structure. It is only used
                      if 'post_fn' is a leaf transform with no parent key.

        :returns: Specialized key-forwarding callable.
        """
        # A functor is a leaf transform iif it is wrapping 'fn' directly
        is_out = post_fn.func is not fn
        if parent is None and not is_out:
            # Extract extra arguments from functor to preserve arguments order
            dataset, args = post_fn.args[:-1], post_fn.args[-1]
            post_fn = post_fn.func
            has_args = bool(args)
            if arity == 0:
                post_fn = partial(post_fn, *dataset, *args)
            elif dataset:
                post_fn = partial(post_fn, *dataset)

            # Specialization if no op is specified
            if op is None:
                if arity == 0:
                    def _forward(post_fn):
                        post_fn()
                    return partial(_forward, post_fn)
                if has_args:
                    if field is None:
                        def _forward(post_fn, args, delayed):
                            post_fn(delayed, *args)
                        return partial(_forward, post_fn, args)

                    def _forward(post_fn, field, args, delayed):
                        post_fn(delayed[field], *args)
                    return partial(_forward, post_fn, field, args)
                if field is None:
                    def _forward(post_fn, delayed):
                        post_fn(delayed)
                    return partial(_forward, post_fn)

                def _forward(post_fn, field, delayed):
                    post_fn(delayed[field])
                return partial(_forward, post_fn, field)

            # Specialization if op is specified
            if arity == 0:
                if is_initialized:
                    def _forward(op, post_fn, out):
                        return op(out, post_fn())
                    return partial(_forward, op, post_fn)

                # The placeholder output is ignored if not initialized
                def _forward(post_fn, out):
                    return post_fn()
                return partial(_forward, post_fn)
            if is_initialized:
                if field is None:
                    def _forward(op, post_fn, args, out, delayed):
                        return op(out, post_fn(delayed, *args))
                    return partial(_forward, op, post_fn, args)

                def _forward(op, post_fn, field, args, out, delayed):
                    return op(out, post_fn(delayed[field], *args))
                return partial(_forward, op, post_fn, field, args)
            if field is None:
                def _forward(post_fn, args, out, delayed):
                    return post_fn(delayed, *args)
                return partial(_forward, post_fn, args)

            def _forward(post_fn, field, args, out, delayed):
                return post_fn(delayed[field], *args)
            return partial(_forward, post_fn, field, args)

        # No key to forward for main entry-point of zero arity
        if parent is None or arity == 0:
            return post_fn

        # Forward key in all other cases
        if op is None:
            def _forward(post_fn, field, delayed):
                return post_fn(delayed[field])
        else:
            def _forward(post_fn, field, out, delayed):
                return post_fn(out, delayed[field])
        # The parent key is bound to the 'field' argument of the closure
        return partial(_forward, post_fn, parent)

    def _build_transform_and_reduce(
            arity: Literal[0, 1],
            parent: Optional[Union[str, int]],
            is_initialized: bool,
            dataset: SequenceT[DataNested],
            space: Optional[gym.Space[DataNested]]) -> Optional[
                Union[Callable[..., ValueInT], Callable[..., ValueOutT]]]:
        """Internal method for generating specialized callable applying
        transform and reduction on all leaves of a nested space recursively.

        :param arity: Arity of the generated callable.
        :param parent: Key of parent space mapping to space if any, `None`
                       otherwise.
        :param is_initialized: Whether the output has already been initialized
                               at this point. See `_build_reduce` for details.
        :param dataset: Sequence of possibly nested pre-allocated data.
        :param space: Possibly nested space on which to operate.

        :returns: Specialized transform if the space is a actually a leaf,
                  otherwise a specialized transform and reduction callable
                  still requiring passing the current reduction output as input
                  if some reduction operator has been specified. `None` if
                  nested data structure if empty.
        """
        # Determine top-level keys if nested data structure.
        # The space takes precedence over the data if both are available.
        keys: Optional[Union[SequenceT[int], SequenceT[str]]] = None
        space_or_data = space
        if space_or_data is None and dataset:
            space_or_data = dataset[0]
        if isinstance(space_or_data, Mapping):
            keys = space_or_data.keys()
        elif isinstance(space_or_data, Sequence):
            keys = range(len(space_or_data))
        else:
            assert isinstance(space_or_data, (gym.Space, np.ndarray))

        # Return specialized transform if leaf.
        # The extra arguments are stored as a single tuple in last position,
        # so that they can be extracted later on and placed after the data
        # provided at runtime if any.
        if keys is None:
            post_fn = fn if not dataset else partial(fn, *dataset)
            post_args = args
            if forward_bounds and space is not None:
                post_args = (*get_bounds(space), *post_args)
            post_fn = partial(post_fn, post_args)
            if parent is None:
                post_fn = _build_forward(
                    arity, parent, is_initialized, post_fn, None)
            return post_fn
        if not keys:
            return None

        # Generate transform and reduce method if branch
        field_prev, field, out_fn = None, None, None
        for field in keys:
            values = [data[field] for data in dataset]
            subspace = None if space is None else space[field]
            must_initialize = not is_initialized and len(keys) == 1
            post_fn = _build_transform_and_reduce(
                arity, field, not must_initialize, values, subspace)
            # Skip empty branches
            if post_fn is None:
                continue
            if out_fn is None:
                out_fn = post_fn
            else:
                out_fn = _build_reduce(
                    arity, is_initialized, post_fn, field, out_fn, field_prev)
                # The output is necessarily initialized after first reduction
                is_initialized = True
            field_prev = field
        if out_fn is None:
            return None
        return _build_forward(arity, parent, is_initialized, out_fn, field)

    def _dispatch(
            post_fn_0: Callable[[], ValueOutT],
            post_fn_1: Callable[[DataNested], ValueOutT],
            *delayed: Tuple[DataNested]) -> ValueOutT:
        """Internal method for handling unknown arity at generation-time.

        :param post_fn_0: Nullary specialized transform and reduce callable.
        :param post_fn_1: Unary specialized transform and reduce callable.
        :param delayed: Optional nested data structure any provided at runtime.

        :returns: Output of the nullary callable if no data is provided at
                  runtime, output of the unary callable applied on the first
                  positional argument otherwise.
        """
        if not delayed:
            return post_fn_0()
        return post_fn_1(delayed[0])

    def _build_init(
            arity: Literal[0, 1],
            post_fn: Callable[..., ValueOutT]) -> Callable[..., ValueOutT]:
        """Internal method generating a specialized callable initializing the
        output if a reduction operator and a dedicated initializer has been
        specified.

        :param arity: Arity of the generated callable.
        :param post_fn: Specialized transform and reduce callable. `None` if
                        the nested data structure is empty.

        :returns: Specialized transform and reduce callable only taking nested
                  data structures as input.
        """
        # Empty nested data structure: there is nothing to transform or reduce
        if post_fn is None:
            if initializer is None:
                return lambda *args, **kwargs: None
            return initializer
        # No output to initialize if there is no reduction operator
        if op is None:
            return post_fn
        # Pass `None` as placeholder for the output if there is no initializer
        if initializer is None:
            return partial(post_fn, None)
        if arity == 0:
            def _initialize(post_fn, initializer):
                return post_fn(initializer())
        else:
            def _initialize(post_fn, initializer, delayed):
                return post_fn(initializer(), delayed)
        return partial(_initialize, post_fn, initializer)

    # Check that the combination of input arguments are valid
    if space is None and not dataset:
        raise TypeError("At least one dataset or the space must be specified.")
    if arity not in (0, 1, None):
        raise TypeError("Arity must be either 0, 1 or `None`.")
    # Leaf transforms are told apart from branch reductions by checking
    # whether a functor is wrapping 'fn', which requires 'fn' not to be a
    # functor itself.
    if isinstance(fn, partial):
        raise TypeError("Transform function cannot be 'partial' instance.")

    # Generate transform and reduce callable of various arity if necessary
    all_fn = [None, None]
    for i in (0, 1):
        if arity is not None and i != arity:
            continue
        is_initialized = op is not None and initializer is not None
        all_fn[i] = _build_init(i, _build_transform_and_reduce(
            i, None, is_initialized, dataset, space))

    # Return callable of requested arity if specified, dynamic dispatch if not
    if arity is None:
        return partial(_dispatch, *all_fn)
    return all_fn[arity]
@no_type_check
def build_map(fn: Callable[..., ValueT],
              data: Optional[DataNested],
              space: Optional[gym.Space[DataNested]],
              arity: Optional[Literal[0, 1]],
              *args: Any,
              forward_bounds: bool = True
              ) -> Callable[[], StructNested[ValueT]]:
    """Generate specialized callable returning applying out-of-place transform
    to all leaves of given nested space.

    .. warning::
        This method systematically allocates memory to store the resulting
        nested data structure, which is costly. If pre-allocation is possible,
        it would more efficient to use `build_reduce` without operator instead.

    .. warning::
        Providing additional data at runtime is supported but impede
        performance. Arity larger than 1 is not supported because the code path
        could not be fully specialized, causing dramatic slowdown.

    .. warning::
        It is assumed without check that all nested data structures are
        consistent together and with the space if provided. It holds true both
        data known at generation-time or runtime. Yet, it is only required for
        data provided at runtime if any to include the original data structure,
        so it may contain additional branches which will be ignored.

    :param fn: Transform applied to every leaves of the nested data structures.
               This function is supposed to allocate its own memory while
               performing some out-of-place operations then return the outcome.
               It must not be a 'functools.partial' instance.
    :param data: Pre-allocated nested data structure. Optional iif the space is
                 provided. This enables generating specialized random sampling
                 methods for instance.
    :param space: `gym.spaces.Dict` on which to operate. Optional iif the
                  nested data structure is provided.
    :param arity: Arity of the generated callable. `None` to indicate that it
                  must be determined at runtime, which is slower.
    :param args: Extra arguments to systematically forward as transform input
                 for all leaves. Note that, as for Python built-ins methods,
                 keywords are not supported for the sake of efficiency.
    :param forward_bounds: Whether to forward the lower and upper bounds of the
                           `gym.Space` associated with each leaf as transform
                           input. In this case, they will be added after the
                           data structure provided at runtime but before other
                           extra arguments if any. It is up to the user to make
                           sure all leaves have bounds, otherwise it will raise
                           an exception at generation-time. This argument is
                           ignored if no space is specified.
                           Optional: `True` by default.

    :returns: Fully-specialized mapping callable.

    :raises TypeError: if neither the data nor the space is specified, if
                       'arity' is not 0, 1 or `None`, or if 'fn' is a
                       'functools.partial' instance.
    """
    def _build_setitem(
            arity: Literal[0, 1],
            self_fn: Optional[Callable[..., Dict[str, StructNested[ValueT]]]],
            value_fn: Callable[..., StructNested[ValueT]],
            key: Optional[Union[str, int]]
            ) -> Callable[..., Dict[str, StructNested[ValueT]]]:
        """Internal method generating a specialized item assignment callable
        responsible for populating a parent transformed nested data structure
        with either some child branch already transformed or some leaf to be
        transformed.

        This method aims to be composed with itself for recursively creating
        the whole transformed nested data structure.

        :param arity: Arity of the generated callable.
        :param self_fn: Parent branch transform. It is ignored if 'key' is
                        `None`.
        :param value_fn: Child leaf or branch transform.
        :param key: Field of the parent transformed nested data structure that
                    must be populated with the output of the child transform.
                    The parent is populated by item assignment if the key is a
                    string, by calling 'append' otherwise. `None` if there is
                    no parent, in which case the child transform is returned
                    without performing any assignment.

        :returns: Specialized item assignment callable.
        """
        # Extract extra arguments from functor if necessary to preserve order.
        # A functor is a leaf transform iif it is wrapping 'fn' directly, in
        # which case its last positional argument is the tuple of extra
        # arguments.
        is_out, has_args = False, False
        if isinstance(value_fn, partial):
            is_out = value_fn.func is not fn
            if not is_out:
                dataset, args = value_fn.args[:-1], value_fn.args[-1]
                value_fn = value_fn.func
                has_args = bool(args)
                if arity == 0:
                    # Everything is known at generation-time
                    value_fn = partial(value_fn, *dataset, *args)
                elif dataset:
                    # Extra arguments must come after data provided at runtime
                    value_fn = partial(value_fn, *dataset)
        is_mapping = isinstance(key, str)

        # Nullary callables, for which no data is provided at runtime
        if arity == 0:
            if key is None:
                return value_fn
            if is_mapping:
                def _setitem(self_fn, value_fn, key):
                    self = self_fn()
                    self[key] = value_fn()
                    return self
                return partial(_setitem, self_fn, value_fn, key)

            def _setitem(self_fn, value_fn):
                self = self_fn()
                self.append(value_fn())
                return self
            return partial(_setitem, self_fn, value_fn)

        # Unary callables for leaf transform having extra arguments
        if has_args:
            if key is None:
                def _setitem(value_fn, args, delayed):
                    return value_fn(delayed, *args)
                return partial(_setitem, value_fn, args)
            if is_mapping:
                def _setitem(self_fn, value_fn, key, args, delayed):
                    self = self_fn(delayed)
                    self[key] = value_fn(delayed[key], *args)
                    return self
                return partial(_setitem, self_fn, value_fn, key, args)

            def _setitem(self_fn, value_fn, key, args, delayed):
                self = self_fn(delayed)
                self.append(value_fn(delayed[key], *args))
                return self
            return partial(_setitem, self_fn, value_fn, key, args)

        # Unary callables in all other cases
        if key is None:
            return value_fn
        if is_mapping:
            def _setitem(self_fn, value_fn, key, delayed):
                self = self_fn(delayed)
                self[key] = value_fn(delayed[key])
                return self
            return partial(_setitem, self_fn, value_fn, key)

        def _setitem(  # type: ignore[no-redef]
                self_fn, value_fn, key, delayed):
            self = self_fn(delayed)
            self.append(value_fn(delayed[key]))
            return self
        return partial(_setitem, self_fn, value_fn, key)

    def _build_map(
            arity: Literal[0, 1],
            parent: Optional[str],
            data: Optional[DataNested],
            space: Optional[gym.Space[DataNested]]
            ) -> Callable[..., Dict[str, StructNested[ValueT]]]:
        """Internal method for generating specialized callable applying
        out-of-place transform to all leaves of given nested space.

        :param arity: Arity of the generated callable.
        :param parent: Key of parent space mapping to space if any, `None`
                       otherwise.
        :param data: Possibly nested pre-allocated data.
        :param space: Possibly nested space on which to operate.

        :returns: Specialized leaf or branch transform.
        """
        # Determine top-level keys if nested data structure, along with the
        # type of container in which the transformed values will be stored.
        # The data takes precedence over the space if both are available.
        keys: Optional[Union[SequenceT[int], SequenceT[str]]] = None
        space_or_data = data if data is not None else space
        if isinstance(space_or_data, Mapping):
            keys = space_or_data.keys()
            if isinstance(space_or_data, gym.spaces.Dict):
                if data is None:
                    container_cls = OrderedDict
                else:
                    container_cls = type(space_or_data)
            elif isinstance(space_or_data, MutableMapping):
                container_cls = type(space_or_data)
            else:
                # Fallback to 'dict' for immutable mappings
                container_cls = dict
        elif isinstance(space_or_data, Sequence):
            keys = range(len(space_or_data))
            if isinstance(space_or_data, gym.spaces.Tuple):
                if data is None:
                    container_cls = list
                else:
                    container_cls = type(space_or_data)
            elif isinstance(space_or_data, MutableSequence):
                container_cls = type(space_or_data)
            else:
                # Fallback to 'list' for immutable sequences
                container_cls = list
        else:
            assert isinstance(space_or_data, (gym.Space, np.ndarray))

        # Return specialized transform if leaf.
        # The extra arguments are stored as a single tuple in last position,
        # so that they can be extracted later on and placed after the data
        # provided at runtime if any.
        if keys is None:
            post_fn = fn if data is None else partial(fn, data)
            post_args = args
            if forward_bounds and space is not None:
                post_args = (*get_bounds(space), *post_args)
            post_fn = partial(post_fn, post_args)
            if parent is None:
                post_fn = _build_setitem(arity, None, post_fn, None)
            return post_fn

        # Create new empty container to all transformed values.
        # FIXME: Immutable containers should be instantiated at the end.
        def _create(cls: Type[ValueT], *args: Any) -> ValueT:
            # pylint: disable=unused-argument
            return cls()
        out_fn = partial(_create, container_cls)

        # Apply map recursively while preserving order using monadic operations
        for field in keys:
            value = None if data is None else data[field]
            subspace = None if space is None else space[field]
            post_fn = _build_map(arity, field, value, subspace)
            out_fn = _build_setitem(arity, out_fn, post_fn, field)
        return out_fn

    def _dispatch(
            post_fn_0: Callable[[], Dict[str, StructNested[ValueT]]],
            post_fn_1: Callable[
                [Dict[str, DataNested]], Dict[str, StructNested[ValueT]]],
            *delayed: Tuple[Dict[str, DataNested]]
            ) -> Dict[str, StructNested[ValueT]]:
        """Internal method for handling unknown arity at generation-time.

        :param post_fn_0: Nullary specialized map callable.
        :param post_fn_1: Unary specialized map callable.
        :param delayed: Optional nested data structure any provided at runtime.

        :returns: Output of the nullary callable if no data is provided at
                  runtime, output of the unary callable applied on the first
                  positional argument otherwise.
        """
        if not delayed:
            return post_fn_0()
        return post_fn_1(delayed[0])

    # Check that the combination of input arguments are valid
    if space is None and data is None:
        raise TypeError("At least data or space must be specified.")
    if arity not in (0, 1, None):
        raise TypeError("Arity must be either 0, 1 or `None`.")
    # Leaf transforms are told apart from branch transforms by checking
    # whether a functor is wrapping 'fn', which requires 'fn' not to be a
    # functor itself.
    if isinstance(fn, partial):
        raise TypeError("Transform function cannot be 'partial' instance.")

    # Generate transform and reduce callable of various arity if necessary
    all_fn = [None, None]
    for i in (0, 1):
        if arity is not None and i != arity:
            continue
        all_fn[i] = _build_map(i, None, data, space)

    # Return callable of requested arity if specified, dynamic dispatch if not
    if arity is None:
        return partial(_dispatch, *all_fn)
    return all_fn[arity]
def build_copyto(dst: DataNested) -> Callable[[DataNested], None]:
    """Generate specialized `copyto` method for a given pre-allocated
    destination.

    The returned callable is produced by `build_reduce`, using `array_copyto`
    as leaf transform and 'dst' as the only pre-allocated dataset. It is
    generated with arity 1, which means that one nested data structure is
    expected as input argument at runtime.

    :param dst: Nested data structure to be updated.

    :returns: Specialized callable taking nested data as single argument.
    """
    # The destination is the one and only dataset known at generation-time
    datasets = (dst,)
    return build_reduce(array_copyto, None, datasets, None, 1)
def build_clip(data: DataNested,
               space: gym.Space[DataNested]) -> Callable[[], DataNested]:
    """Generate specialized `clip` method for some pre-allocated nested data
    structure and corresponding space.

    The returned callable is produced by `build_map`, using `_array_clip` as
    leaf transform. It is generated with arity 0, which means that it does not
    expect any input argument at runtime.

    :param data: Nested data structure whose leaves must be clipped.
    :param space: `gym.Space` on which to operate.

    :returns: Specialized callable without input argument.
    """
    # Both data and space are known at generation-time, hence nullary callable
    clip_fn = build_map(_array_clip, data, space, 0)
    return clip_fn
def build_contains(data: DataNested,
                   space: gym.Space[DataNested],
                   tol_abs: float = 0.0,
                   tol_rel: float = 0.0) -> Callable[[], bool]:
    """Generate specialized `contains` method for some pre-allocated nested
    data structure and corresponding space.

    :param data: Pre-allocated nested data structure whose leaves must be
                 within bounds if defined and ignored otherwise.
    :param space: `gym.Space` on which to operate.
    :param tol_abs: Absolute tolerance forwarded to `_array_contains`.
                    Optional: 0.0 by default.
    :param tol_rel: Relative tolerance forwarded to `_array_contains`.
                    Optional: 0.0 by default.

    :returns: Specialized callable without input argument returning `True` if
              all leaves are within bounds of their respective space, `False`
              otherwise.
    """
    # Dedicated exception type used to abort the traversal of the leaves
    class ShortCircuitContains(Exception):
        """Internal exception involved in short-circuit mechanism.
        """

    @nb.jit(nopython=True, cache=True)
    def _contains_or_raises(value: np.ndarray,
                            low: Optional[ArrayOrScalar],
                            high: Optional[ArrayOrScalar],
                            tol_abs: float,
                            tol_rel: float) -> bool:
        """Thin wrapper around `_array_contains` raising an exception if the
        test fails instead of returning `False`.

        Raising enables a short-circuit mechanism that aborts checking the
        remaining leaves if any. It speeds up scenarios where at least one
        leaf does not meet the requirements, but also scenarios where all
        tests pass, since it is no longer necessary to specify the reduction
        operator 'operator.and' to keep track of the result.

        :param value: Array holding values to check.
        :param low: Lower bound.
        :param high: Upper bound.
        :param tol_abs: Absolute tolerance.
        :param tol_rel: Relative tolerance.
        """
        if _array_contains(value, low, high, tol_abs, tol_rel):
            return True
        raise ShortCircuitContains("Short-circuit exception.")

    def _exception_handling(out_fn: Callable[[], bool]) -> bool:
        """Internal method for short-circuit exception handling.

        :param out_fn: Specialized contain callable raising short-circuit
                       exception as soon as one leaf fails the test.

        :returns: `True` if all leaves are within bounds of their respective
                  space, `False` otherwise.
        """
        try:
            out_fn()
        except ShortCircuitContains:
            # At least one leaf failed the test
            return False
        else:
            return True

    # Generate the raw callable first, then wrap it to catch the exception
    contains_fn = build_reduce(
        _contains_or_raises, None, (data,), space, 0, tol_abs, tol_rel)
    return partial(_exception_handling, contains_fn)
def build_normalize(space: gym.Space[DataNested],
                    dst: DataNested,
                    src: Optional[DataNested] = None,
                    *,
                    is_reversed: bool = False) -> Callable[..., None]:
    """Generate a normalization or de-normalization method specialized for a
    given pre-allocated destination.

    .. note::
        The generated method applies element-wise normalization or
        de-normalization to all elements of the leaf spaces having finite
        bounds. For those that do not, it simply copies the value from 'src'
        to 'dst'. Normalization maps the lower bound to -1 and the upper
        bound to 1.

    .. warning::
        This method requires all leaf spaces to have type `gym.spaces.Box`
        with dtype 'np.floating'.

    :param space: Original (de-normalized) `gym.Space` on which to operate.
    :param dst: Nested data structure to update.
    :param src: Normalized nested data if 'is_reversed' is True, original data
                (de-normalized) otherwise. `None` to pass it at runtime.
                Optional: `None` by default.
    :param is_reversed: True to de-normalize, False to normalize.

    :returns: Specialized callable without input argument if 'src' has been
              specified, expecting it as single input argument otherwise.
    """
    @nb.jit(nopython=True, cache=True)
    def _array_normalize(dst: np.ndarray,
                         src: np.ndarray,
                         low: np.ndarray,
                         high: np.ndarray,
                         is_reversed: bool) -> None:
        """Element-wise normalization or de-normalization of array.

        :param dst: Pre-allocated array into which the result must be stored.
        :param src: Input array.
        :param low: Lower bound.
        :param high: Upper bound.
        :param is_reversed: True to de-normalize, False to normalize.
        """
        bounds_and_values = zip(low.flat, high.flat, src.flat)
        for idx, (lower, upper, value) in enumerate(bounds_and_values):
            if not (np.isfinite(lower) and np.isfinite(upper)):
                # Plain copy if at least one of the bounds is not finite
                dst.flat[idx] = value
            elif is_reversed:
                dst.flat[idx] = (lower + upper - value * (lower - upper)) / 2
            else:
                dst.flat[idx] = (lower + upper - 2 * value) / (lower - upper)

    # Make sure that all leaves are `gym.space.Box` with `floating` dtype
    for leaf_space in tree.flatten(space):
        assert isinstance(leaf_space, gym.spaces.Box)
        assert np.issubdtype(leaf_space.dtype, np.floating)

    # The source is part of the pre-allocated dataset only if already known
    dataset = [dst] if src is None else [dst, src]

    # Arity is 0 if the source is known at generation-time, 1 otherwise
    arity = 2 - len(dataset)
    return build_reduce(
        _array_normalize, None, dataset, space, arity, is_reversed)
def build_flatten(data_nested: DataNested,
                  data_flat: Optional[DataNested] = None,
                  *,
                  is_reversed: Optional[bool] = None
                  ) -> Callable[..., None]:
    """Generate a flattening or un-flattening method specialized for some
    pre-allocated nested data.

    .. note::
        Multi-dimensional leaf spaces are supported. Values will be flattened
        in 1D vectors using 'C' order (row-major). It ignores the actual memory
        layout of the leaves of 'data_nested' and they are not required to have
        the same dtype as 'data_flat'.

    :param data_nested: Nested data structure.
    :param data_flat: Flat array consistent with the nested data structure.
                      It can only be omitted if `is_reversed` is `True`, in
                      which case it must be provided at runtime.
                      Optional: `None` by default.
    :param is_reversed: True to update 'data_nested' from 'data_flat'
                        (un-flattening), False to update 'data_flat' from
                        'data_nested' (flattening).
                        Optional: True if 'data_flat' is not specified, False
                        otherwise.

    :returns: Specialized callable without input argument if 'data_flat' has
              been specified, expecting it as single input argument otherwise.
    """
    # Make sure that the input arguments are valid
    if is_reversed is None:
        is_reversed = data_flat is None
    assert is_reversed or data_flat is not None

    # Flatten nested data while preserving leaves ordering
    data_leaves = tree.flatten(data_nested)

    # Compute slices to split destination in accordance with nested data.
    # It will be passed to `build_reduce` as an input dataset. It is kind of
    # hacky since only passing `DataNested` instances is officially supported,
    # but it is currently the easiest way to keep track of some internal state
    # and specify leaf-specific constants.
    flat_slices = []
    idx_start = 0
    for data in data_leaves:
        idx_end = idx_start + max(math.prod(data.shape), 1)
        flat_slices.append((idx_start, idx_end))
        idx_start = idx_end

    @nb.jit(nopython=True, cache=True)
    def _flatten(data: np.ndarray,
                 flat_slice: Tuple[int, int],
                 data_flat: np.ndarray,
                 is_reversed: bool) -> None:
        """Synchronize the flatten and un-flatten representation of the data
        associated with the same leaf space.

        In practice, it assigns the value of a 1D array slice to some multi-
        dimensional array, or the other way around.

        :param data: Multi-dimensional array that will be either updated or
                     copied as a whole depending on 'is_reversed'.
        :param flat_slice: Start and stop indices of the slice of 'data_flat'
                           to synchronize with 'data'.
        :param data_flat: 1D array whose slice 'flat_slice' will be either
                          copied or updated depending on 'is_reversed'.
        :param is_reversed: True to update the multi-dimensional array 'data'
                            by copying the value from slice 'flat_slice' of
                            vector 'data_flat', False for doing the contrary.
        """
        # For some reason, passing a slice as input argument is much slower
        # in numba than creating it inside the method.
        if is_reversed:
            data.ravel()[:] = data_flat[slice(*flat_slice)]
        else:
            data_flat[slice(*flat_slice)] = data.ravel()

    # The flat data is forwarded as extra argument if already known, which
    # yields a nullary callable. It is expected at runtime otherwise.
    args = (is_reversed,)
    if data_flat is not None:
        args = (data_flat, *args)  # type: ignore[assignment]
    out_fn = build_reduce(
        _flatten, None, (data_leaves, flat_slices), None, 2 - len(args), *args)

    if data_flat is None:
        def _repeat(out_fn: Callable[[DataNested], None],
                    n_leaves: int,
                    delayed: DataNested) -> None:
            """Dispatch flattened data provided at runtime to each transform
            '_flatten' specialized for all leaves of the original nested space.

            In practice, it simply repeats the flattened data as many times as
            the number of leaves of the original nested space before passing
            them altogether in a tuple as input argument of a function.

            :param out_fn: Flattening or un-flattening method already
                           specialized for a given pre-allocated nested data.
            :param n_leaves: Total number of leaves in original nested space.
            :param delayed: Flattened data provided at runtime.
            """
            out_fn((delayed,) * n_leaves)

        out_fn = partial(_repeat, out_fn, len(data_leaves))
    return out_fn