MultiDiscrete Space

The MultiDiscreteDist class extends the Gymnasium MultiDiscrete space to create categorical distributions for multiple discrete variables.

Overview

MultiDiscreteDist allows you to create probability distributions for action spaces with multiple discrete components, each with its own cardinality.

API Reference

class prob_spaces.multi_discrete.MultiDiscreteDist(nvec: numpy.ndarray[tuple[int, ...], numpy.dtype[numpy.integer]] | list[int], dtype: str | type[numpy.integer] = numpy.int64, seed: int | numpy.random.Generator | None = None, start: numpy.ndarray[tuple[int, ...], numpy.dtype[numpy.integer]] | list[int] | None = None)

Bases: MultiDiscrete

__call__(prob: Tensor, mask: Tensor = None) -> MaskedCategorical

Builds a MaskedCategorical distribution from a probability tensor and an optional mask. The method reshapes prob to match the dimensions given by nvec, combines the optional user-supplied mask with the space's internal mask, and passes the reshaped probabilities and the combined mask to MaskedCategorical.

Parameters:
  • prob (th.Tensor) – A tensor containing probabilities to be reshaped and used in constructing the distribution.

  • mask (th.Tensor, optional) – An optional boolean tensor that marks which entries are allowed; entries set to False are excluded from the distribution. Defaults to None (no additional masking).

Returns:

A MaskedCategorical distribution built from the reshaped probabilities and the combined (user and internal) mask.

Return type:

MaskedCategorical
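The steps described for __call__ (reshape, combine masks, build a distribution) can be sketched in plain Python. This is a hypothetical re-implementation for illustration only: masked_row_probs is not part of prob_spaces, it takes a flat list instead of a torch tensor, and it assumes that the internal mask keeps the first nvec[i] entries of row i, that the padded row width is max(nvec) + 1 (matching the (3, 5) shape in the usage example), and that masked entries are zeroed and each row renormalised. The real class delegates this work to MaskedCategorical.

```python
from typing import Optional, Sequence


def masked_row_probs(
    prob: Sequence[float],
    nvec: Sequence[int],
    mask: Optional[Sequence[bool]] = None,
) -> list[list[float]]:
    """Hypothetical sketch of the steps __call__ describes (not the library code)."""
    width = max(nvec) + 1  # assumed padded row width
    n = len(nvec)
    if len(prob) != n * width:
        raise ValueError(f"expected {n * width} values, got {len(prob)}")
    # 1. Reshape the flat input into one row per discrete variable.
    rows = [list(prob[i * width:(i + 1) * width]) for i in range(n)]
    # 2. Internal mask (assumed): only the first nvec[i] entries of row i are valid.
    internal = [[j < k for j in range(width)] for k in nvec]
    # 3. Optional user mask, flat and laid out like prob; None allows everything.
    user = (
        [[True] * width for _ in range(n)]
        if mask is None
        else [list(mask[i * width:(i + 1) * width]) for i in range(n)]
    )
    out = []
    for row, im, um in zip(rows, internal, user):
        # An entry survives only if both masks allow it.
        kept = [p if (a and b) else 0.0 for p, a, b in zip(row, im, um)]
        total = sum(kept)
        if total <= 0:
            raise ValueError("every entry of a variable is masked out")
        # 4. Renormalise so each variable's distribution sums to 1.
        out.append([p / total for p in kept])
    return out
```

With a flat list of 15 ones and nvec = [2, 3, 4], the first row becomes [0.5, 0.5, 0.0, 0.0, 0.0]: the padding columns are removed by the internal mask before renormalisation.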

Key Attributes

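  • nvec: Array of integers giving the number of values (cardinality) of each discrete variable

  • start: Optional per-variable offset; as in Gymnasium's MultiDiscrete, variable i takes values in [start[i], start[i] + nvec[i]) and start defaults to all zeros

  • internal_mask: Mask generated automatically from nvec so that padded entries beyond each variable's cardinality can never be sampled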
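How nvec and start relate to the padded probability tensor can be illustrated with two small helpers. Both are hypothetical (build_internal_mask and indices_to_action are not prob_spaces functions): the first assumes the internal mask has row width max(nvec) + 1, as implied by the usage examples, and the second applies Gymnasium's start convention to turn per-variable category indices into action values. This page does not state whether MultiDiscreteDist applies the start offset to sampled actions itself.

```python
from typing import Optional, Sequence


def build_internal_mask(nvec: Sequence[int]) -> list[list[bool]]:
    """Hypothetical helper: row i is True for its nvec[i] valid indices and
    False for the padding up to the assumed common width max(nvec) + 1."""
    width = max(nvec) + 1
    return [[j < k for j in range(width)] for k in nvec]


def indices_to_action(
    indices: Sequence[int], start: Optional[Sequence[int]] = None
) -> list[int]:
    """Hypothetical helper: map category indices to MultiDiscrete values.
    Following Gymnasium, variable i takes values in [start[i], start[i] + nvec[i])."""
    if start is None:
        start = [0] * len(indices)  # Gymnasium's default offset
    return [s + idx for s, idx in zip(start, indices)]
```

For nvec = [2, 3, 4] the mask rows contain 2, 3 and 4 True entries respectively, and with start = [-1, 0, 5] the indices [1, 2, 3] correspond to the values [0, 2, 8].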

Usage Examples

Basic usage:

import numpy as np
import torch as th
from prob_spaces.multi_discrete import MultiDiscreteDist

# Create a multi-discrete space with 3 variables:
# - First variable has 2 possible values (0, 1)
# - Second variable has 3 possible values (0, 1, 2)
# - Third variable has 4 possible values (0, 1, 2, 3)
nvec = np.array([2, 3, 4])
space = MultiDiscreteDist(nvec=nvec)

# Create probabilities for each variable (uniform here).
# The tensor shape should be nvec.shape + (nvec.max() + 1,),
# which in this case is (3, 5).
probs = th.ones((3, 5))

# Create a distribution
dist = space(probs)

# Sample an action
action = dist.sample()

# Compute log probability
log_prob = dist.log_prob(action)
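Because the input above is a tensor of ones, each variable ends up uniform over its valid values once the padding is masked out, so the per-variable probabilities are 1/2, 1/3 and 1/4. The pure-Python sketch below reproduces sampling and log-probability for that case. sample_and_log_prob is a hypothetical helper, and the rows are written out by hand under the assumption that masked entries get probability 0. Whether dist.log_prob returns per-variable values or their sum depends on MaskedCategorical's batch semantics, so the sketch computes both.

```python
import math
import random
from typing import Sequence


def sample_and_log_prob(
    rows: Sequence[Sequence[float]], rng: random.Random
) -> tuple[list[int], list[float]]:
    """Hypothetical helper: draw one index per variable from normalised rows
    and return the per-variable log-probabilities of the drawn indices."""
    # Zero-weight (masked) entries are never chosen by random.choices.
    action = [rng.choices(range(len(row)), weights=row)[0] for row in rows]
    log_probs = [math.log(row[a]) for row, a in zip(rows, action)]
    return action, log_probs


# Assumed rows for nvec = [2, 3, 4] after masking a (3, 5) tensor of ones.
rows = [
    [1 / 2, 1 / 2, 0, 0, 0],
    [1 / 3, 1 / 3, 1 / 3, 0, 0],
    [1 / 4, 1 / 4, 1 / 4, 1 / 4, 0],
]
action, log_probs = sample_and_log_prob(rows, random.Random(0))
# Independent variables: the joint log-probability is the sum, log(1/24) here.
joint = sum(log_probs)
```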

With masking:

import numpy as np
import torch as th
from prob_spaces.multi_discrete import MultiDiscreteDist

nvec = np.array([2, 3, 4])
space = MultiDiscreteDist(nvec=nvec)

probs = th.ones((3, 5))

# Create a mask to disallow certain actions
mask = th.ones((3, 5), dtype=th.bool)
mask[0, 1] = False  # Disallow action 1 for first variable

# Create a distribution with the mask
dist = space(probs, mask=mask)

# Sample an action; its first component can no longer be 1
action = dist.sample()
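When several actions need to be disallowed, it can be convenient to build the mask from (variable, value) pairs instead of setting entries by hand. The helper below is a hypothetical sketch (disallow_mask is not part of prob_spaces). It assumes the mask has the padded shape (len(nvec), max(nvec) + 1) used in the example, and that a value maps to column value - start[i]; this page does not say how the library indexes the mask when start is non-zero.

```python
from typing import Iterable, Optional, Sequence


def disallow_mask(
    nvec: Sequence[int],
    disallowed: Iterable[tuple[int, int]],
    start: Optional[Sequence[int]] = None,
) -> list[list[bool]]:
    """Hypothetical helper: return an all-True mask of shape
    (len(nvec), max(nvec) + 1) with False at each disallowed (variable, value)."""
    width = max(nvec) + 1  # assumed padded width, as in the example above
    if start is None:
        start = [0] * len(nvec)
    mask = [[True] * width for _ in nvec]
    for var, value in disallowed:
        idx = value - start[var]  # assumed value -> column mapping
        if not 0 <= idx < nvec[var]:
            raise ValueError(f"value {value} is out of range for variable {var}")
        mask[var][idx] = False
    return mask
```

disallow_mask([2, 3, 4], [(0, 1)]) reproduces the hand-built mask from the example; it can be turned into a tensor with th.tensor(..., dtype=th.bool) before being passed to the space.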