"""Module for probability distributions over Box spaces."""
from typing import Any, Sequence, SupportsFloat, Type
import numpy as np
import torch as th
from gymnasium import spaces
from numpy.typing import NDArray
from torch.distributions import TransformedDistribution
from torch.distributions.transforms import AffineTransform, SigmoidTransform
[docs]
class BoxDist(spaces.Box):
"""Probability distribution for Box spaces."""
def __init__(
self,
low: SupportsFloat | NDArray[Any],
high: SupportsFloat | NDArray[Any],
shape: Sequence[int] | None = None,
dtype: type[np.floating[Any]] | type[np.integer[Any]] = np.float32,
seed: int | np.random.Generator | None = None,
dist: None | Type[th.distributions.Distribution] = None,
):
"""Initialize BoxDist with bounds, shape, dtype, and base distribution."""
super().__init__(low, high, shape, dtype, seed)
self.base_dist = dist or th.distributions.Normal
def transforms(self, device: th.device) -> list:
"""Return list of transforms to map base distribution to Box bounds.
Returns
-------
list
List of transforms to map the base distribution to Box bounds.
"""
t_low = th.tensor(self.low, device=device)
t_high = th.tensor(self.high, device=device)
range_value = t_high - t_low
offset = t_low
transforms: list = []
if self.base_dist != th.distributions.Beta:
transforms.append(SigmoidTransform())
transforms.append(AffineTransform(loc=offset, scale=range_value, event_dim=1))
return transforms
[docs]
def __call__(self, loc: th.Tensor, scale: th.Tensor) -> th.distributions.Distribution:
"""Generate a transformed probability distribution.
Construct a base distribution, apply a sequence of transformations, and return the resulting
transformed distribution.
:param loc: A tensor specifying the location parameters for the base distribution.
:param scale: A tensor specifying the scale parameters for the base distribution.
:return: A transformed distribution object derived from the specified base distribution and
transformations.
Returns
-------
th.distributions.Distribution
A transformed distribution object derived from the specified base distribution and transformations.
"""
dist = self.base_dist(loc, scale, validate_args=True) # type: ignore
transforms = self.transforms(loc.device)
transformed_dist = TransformedDistribution(dist, transforms, validate_args=True)
return transformed_dist
@classmethod
def from_space(cls, space: spaces.Box) -> "BoxDist":
"""Create a BoxDist from a gymnasium Box space.
Returns
-------
BoxDist
An instance of BoxDist created from the given gymnasium Box space.
"""
low = space.low
high = space.high
dtype = space.dtype
shape = space.shape
return cls(low=low, high=high, shape=shape, dtype=dtype) # type: ignore