"""Module ``prob_spaces.multi_discrete``: a MultiDiscrete space that builds masked categorical distributions."""

from typing import Any

import numpy as np
import torch as th
from gymnasium import spaces
from numpy.typing import NDArray

from prob_spaces.dists.categorical import CategoricalDist, MaskedCategorical


class MultiDiscreteDist(spaces.MultiDiscrete):
    """A gymnasium ``MultiDiscrete`` space that can build a masked categorical distribution.

    Probabilities for the sub-actions are laid out on a padded grid of shape
    ``(*nvec.shape, prob_last_dim)``. Columns of that grid at or beyond a
    sub-action's own ``nvec`` entry are invalid and are disabled by
    :attr:`internal_mask`.
    """

    def __init__(
        self,
        nvec: NDArray[np.integer[Any]] | list[int],
        dtype: str | type[np.integer[Any]] = np.int64,
        seed: int | np.random.Generator | None = None,
        start: NDArray[np.integer[Any]] | list[int] | None = None,
    ):
        """Create the space and precompute its validity mask.

        :param nvec: Number of choices for each sub-action.
        :param dtype: Integer dtype forwarded to ``spaces.MultiDiscrete``.
        :param seed: Seed or generator forwarded to ``spaces.MultiDiscrete``.
        :param start: Optional per-sub-action start offsets, forwarded to
            ``spaces.MultiDiscrete`` and later to the distribution.
        """
        super().__init__(nvec, dtype, seed, start)
        # Boolean array of shape (*nvec.shape, prob_last_dim); True marks a valid choice.
        self.internal_mask = self._internal_mask()

    @property
    def prob_last_dim(self) -> int:
        """Size of the padded last axis of the probability grid: ``max(nvec) + 1``."""
        return int(np.max(self.nvec)) + 1

    def _internal_mask(self) -> NDArray[np.bool_]:
        """Return the boolean validity mask of shape ``(*nvec.shape, prob_last_dim)``.

        Entry ``[..., k]`` is True iff ``k < |nvec[...]|``, i.e. column ``k`` is a
        real choice for that sub-action rather than padding.
        """
        # A single broadcast comparison replaces the dense index grid. It avoids
        # ``np.bool`` (absent in NumPy 1.24-1.26) and cannot overflow when
        # ``nvec`` has a small integer dtype.
        choice_index = np.arange(self.prob_last_dim)
        sizes = np.abs(self.nvec)[..., np.newaxis]
        return sizes > choice_index

    def __call__(
        self, prob: th.Tensor, mask: th.Tensor | None = None
    ) -> MaskedCategorical:
        """Build a masked categorical distribution from a flat probability tensor.

        ``prob`` is reshaped to ``(*nvec.shape, prob_last_dim)``. The optional
        user ``mask`` is AND-ed with :attr:`internal_mask`, so padded columns are
        always disabled. The result is handed to ``CategoricalDist`` together
        with this space's ``start`` offsets.

        :param prob: Tensor with exactly ``prod(nvec.shape) * prob_last_dim``
            elements (``reshape`` raises otherwise).
        :type prob: th.Tensor
        :param mask: Optional boolean mask that must broadcast against the
            reshaped probabilities. Defaults to None (all choices allowed).
        :type mask: th.Tensor, optional
        :return: The distribution built by ``CategoricalDist``.
        :rtype: MaskedCategorical
        """
        probs = prob.reshape(*self.nvec.shape, self.prob_last_dim)
        # ``as_tensor`` avoids a copy on CPU; this is safe because ``logical_and``
        # below always returns a fresh tensor.
        internal = th.as_tensor(self.internal_mask, dtype=th.bool, device=probs.device)
        if mask is None:
            mask = th.ones_like(probs, dtype=th.bool, device=probs.device)
        mask = th.logical_and(mask, internal)
        # NOTE(review): annotated as MaskedCategorical but builds CategoricalDist;
        # presumably compatible types - confirm in prob_spaces.dists.categorical.
        return CategoricalDist(probs, mask=mask, start=self.start)

    @classmethod
    def from_space(cls, space: spaces.MultiDiscrete) -> "MultiDiscreteDist":
        """Convert a gymnasium space to a MultiDiscreteDist.

        Copies ``nvec``, ``dtype`` and ``start``; no seed is forwarded.
        """
        return cls(nvec=space.nvec, dtype=space.dtype, start=space.start)  # type: ignore