Source code for prob_spaces.discrete

"""Module for probability distributions over Discrete spaces."""

from typing import Optional

import torch as th
from gymnasium import spaces

from prob_spaces.dists.categorical import CategoricalDist, MaskedCategorical


class DiscreteDist(spaces.Discrete):
    """Discrete space that can be called to build a categorical distribution.

    Extends :class:`gymnasium.spaces.Discrete`. Calling an instance with a
    tensor builds a distribution through ``CategoricalDist``, forwarding the
    space's ``start`` value and an optional validity mask.
    """

    def __call__(
        self,
        prob: th.Tensor,
        mask: Optional[th.Tensor] = None,
    ) -> MaskedCategorical:
        """Compute and return a masked categorical distribution.

        Parameters
        ----------
        prob
            Tensor representing the probabilities for each category. It is
            reshaped to ``self.n`` before use, so it must hold exactly
            ``self.n`` elements.
        mask
            Tensor specifying which categories are valid. When omitted, a
            boolean tensor of ones shaped like the reshaped ``prob`` is used,
            i.e. no category is masked out. A caller-supplied mask is passed
            through unchanged.

        Returns
        -------
        MaskedCategorical
            The distribution constructed from the reshaped probabilities, the
            mask, and the space's ``start`` value.
        """
        probs = prob.reshape(self.n)  # type: ignore
        if mask is None:
            # ones_like already allocates on the same device as ``probs``,
            # so only the dtype needs overriding.
            mask = th.ones_like(probs, dtype=th.bool)
        return CategoricalDist(probs, mask=mask, start=self.start)

    @classmethod
    def from_space(cls, space: spaces.Discrete) -> "DiscreteDist":
        """Create a DiscreteDist from a gymnasium Discrete space.

        Parameters
        ----------
        space
            The gymnasium Discrete space whose ``n`` and ``start`` are copied.

        Returns
        -------
        DiscreteDist
            An instance of DiscreteDist created from the given gymnasium
            Discrete space.
        """
        # NOTE(review): ``space.dtype`` is deliberately not forwarded (it was
        # disabled in the original code); confirm whether the supported
        # gymnasium versions accept a ``dtype`` argument before enabling it.
        return cls(n=space.n, start=space.start)