MultiDiscrete Space
The MultiDiscreteDist class extends the Gymnasium MultiDiscrete space to create categorical distributions
for multiple discrete variables.
Overview
MultiDiscreteDist allows you to create probability distributions for action spaces with multiple
discrete components, each with its own cardinality.
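Multi-discrete action heads are commonly modeled as independent categoricals, one per component, so that the joint log-probability of an action is the sum of the per-component log-probabilities. The sketch below shows that general idea in plain NumPy. It assumes independence between components, which this page does not confirm for MultiDiscreteDist. The names `component_probs`, `sample_action` and `joint_log_prob` are hypothetical and not part of prob_spaces.

```python
import numpy as np

# Hypothetical per-variable probability vectors for nvec = [2, 3, 4];
# each component has its own cardinality.
component_probs = [
    np.array([0.5, 0.5]),
    np.array([0.2, 0.3, 0.5]),
    np.array([0.1, 0.2, 0.3, 0.4]),
]


def sample_action(probs_list, rng):
    """Draw each component independently from its own categorical."""
    return np.array([rng.choice(len(p), p=p) for p in probs_list])


def joint_log_prob(probs_list, action):
    """Assuming independence, the joint log-probability is the sum of
    the per-component log-probabilities."""
    return float(sum(np.log(p[a]) for p, a in zip(probs_list, action)))


rng = np.random.default_rng(0)
action = sample_action(component_probs, rng)
print(action, joint_log_prob(component_probs, action))
```

Each entry of `action` lies in `[0, nvec[i])`. The action `[1, 2, 3]`, for example, has probability 0.5 * 0.5 * 0.4 = 0.1.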
API Reference
- class prob_spaces.multi_discrete.MultiDiscreteDist(nvec: numpy.ndarray[tuple[int, ...], numpy.dtype[numpy.integer[Any]]] | list[int], dtype: str | type[numpy.integer[Any]] = numpy.int64, seed: int | numpy.random.Generator | None = None, start: numpy.ndarray[tuple[int, ...], numpy.dtype[numpy.integer[Any]]] | list[int] | None = None)
Bases: gymnasium.spaces.MultiDiscrete
- __call__(prob: Tensor, mask: Tensor = None) → MaskedCategorical
Builds a MaskedCategorical distribution from a probability tensor and an optional mask. The method reshapes prob to match the nvec dimensions, combines the optional user-supplied mask with the internal mask, and passes the reshaped probabilities and combined mask to MaskedCategorical.
- Parameters:
prob (th.Tensor) – A tensor containing probabilities to be reshaped and used in constructing the distribution.
mask (th.Tensor, optional) – An optional boolean tensor for masking specific probabilities before creating the distribution. Defaults to None.
- Returns:
A MaskedCategorical distribution object created with reshaped probabilities and combined masking information.
- Return type:
MaskedCategorical
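The reshape-and-combine steps described for `__call__` can be sketched in NumPy as follows. This is a hypothetical re-creation for illustration, not the library's code. It assumes a 1-D nvec, a padded width of `max(nvec) + 1` (the shape used in the usage examples), and that "combining" masks means a logical AND. The helper name `combine_masks` is invented here.

```python
import numpy as np


def combine_masks(prob, nvec, mask=None):
    """Hypothetical sketch of the __call__ preprocessing.

    Reshapes `prob` to (batch, len(nvec), width) and ANDs an optional
    user mask with an internal mask marking out-of-range categories.
    Assumes a 1-D nvec and width = max(nvec) + 1.
    """
    nvec = np.asarray(nvec)
    width = int(nvec.max()) + 1
    # Reshape a flat (or already shaped) array to the per-variable layout.
    prob = np.asarray(prob).reshape(-1, nvec.size, width)
    # internal[i, k] is True when category k is valid for variable i.
    internal = np.arange(width) < nvec[:, None]
    combined = np.broadcast_to(internal, prob.shape).copy()
    if mask is not None:
        # The user mask broadcasts over the batch dimension.
        combined &= np.asarray(mask, dtype=bool)
    return prob, combined


prob = np.ones((4, 15))  # batch of 4, each a flat 3 * 5 block
user_mask = np.ones((3, 5), dtype=bool)
user_mask[0, 1] = False  # disallow action 1 for the first variable
p, m = combine_masks(prob, [2, 3, 4], user_mask)
print(p.shape, m[0].astype(int))
```

A logical AND means a category stays available only if both the internal mask (it is within range) and the user mask (it is permitted) allow it.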
Key Attributes
- nvec: Array of integers giving the number of values for each discrete variable.
- start: Optional starting indices for each variable.
- internal_mask: Automatically generated mask that ensures only valid actions can be selected.
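To illustrate what `internal_mask` and `start` plausibly encode, the NumPy sketch below builds a validity mask from nvec and maps sampled indices to values. Two assumptions apply and neither is confirmed by this page: the padded width is `max(nvec) + 1` as in the usage examples, and a sampled index `k` maps to the value `start + k`, following `gymnasium.spaces.MultiDiscrete` semantics. The `start` values and `indices` below are made up for the example.

```python
import numpy as np

nvec = np.array([2, 3, 4])
start = np.array([0, 1, -2])  # hypothetical offsets, Gymnasium-style

width = int(nvec.max()) + 1  # assumed padded width, as in the usage examples
# True where category index k is a legal choice for variable i.
internal_mask = np.arange(width)[None, :] < nvec[:, None]
print(internal_mask.astype(int))

# Assumed mapping: a per-variable index k corresponds to the value start + k,
# so variable i takes values in [start[i], start[i] + nvec[i]).
indices = np.array([1, 2, 3])
values = start + indices
print(values)  # [1 3 1]
```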
Usage Examples
Basic usage:
import numpy as np
import torch as th
from prob_spaces.multi_discrete import MultiDiscreteDist
# Create a multi-discrete space with 3 variables:
# - First variable has 2 possible values (0, 1)
# - Second variable has 3 possible values (0, 1, 2)
# - Third variable has 4 possible values (0, 1, 2, 3)
nvec = np.array([2, 3, 4])
space = MultiDiscreteDist(nvec=nvec)
# Create a probability tensor covering all variables
# Its shape should be nvec.shape + (max(nvec) + 1,)
# In this case: (3, 5)
probs = th.ones((3, 5))
# Create a distribution
dist = space(probs)
# Sample an action
action = dist.sample()
# Compute log probability
log_prob = dist.log_prob(action)
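The shape rule stated in the comments of the basic example can be computed directly from nvec. The helper `expected_prob_shape` below is hypothetical and not part of prob_spaces. It simply encodes `nvec.shape + (max(nvec) + 1,)`. The second call assumes that a multi-dimensional nvec is accepted, which the "(nvec shape)" wording suggests but does not state outright.

```python
import numpy as np


def expected_prob_shape(nvec):
    """Hypothetical helper: nvec.shape + (max(nvec) + 1,)."""
    nvec = np.asarray(nvec)
    return nvec.shape + (int(nvec.max()) + 1,)


print(expected_prob_shape([2, 3, 4]))         # (3, 5)
print(expected_prob_shape([[2, 3], [4, 5]]))  # (2, 2, 6)
```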
With masking:
import numpy as np
import torch as th
from prob_spaces.multi_discrete import MultiDiscreteDist
nvec = np.array([2, 3, 4])
space = MultiDiscreteDist(nvec=nvec)
probs = th.ones((3, 5))
# Create a mask to disallow certain actions
mask = th.ones((3, 5), dtype=th.bool)
mask[0, 1] = False # Disallow action 1 for first variable
# Create a distribution with the mask
dist = space(probs, mask=mask)
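As a worked example of what the masked distribution above should assign, the NumPy sketch below reproduces the same inputs and computes effective per-variable probabilities. It rests on two assumptions: the internal mask marks category `k` as valid for variable `i` only when `k < nvec[i]`, and MaskedCategorical renormalizes the probability mass over unmasked categories. Under these assumptions the first variable can only take action 0.

```python
import numpy as np

nvec = np.array([2, 3, 4])
width = int(nvec.max()) + 1          # 5, matching probs of shape (3, 5)
probs = np.ones((3, width))

user_mask = np.ones((3, width), dtype=bool)
user_mask[0, 1] = False              # disallow action 1 for the first variable

# Assumed internal mask: category k is valid for variable i iff k < nvec[i].
internal_mask = np.arange(width) < nvec[:, None]
combined = internal_mask & user_mask

# Zero out disallowed entries and renormalize each row.
effective = np.where(combined, probs, 0.0)
effective /= effective.sum(axis=-1, keepdims=True)
print(effective.round(3))
```

The rows come out as `[1, 0, 0, 0, 0]`, `[1/3, 1/3, 1/3, 0, 0]` and `[1/4, 1/4, 1/4, 1/4, 0]`.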