# Source code for akro.discrete
"""A space representing a selection between a finite number of items."""
import numpy as np
from akro.space import Space
class Discrete(Space):
    """A space made of the integers {0, 1, ..., n-1}.

    Args:
        n: Number of elements in the space. It is stored as given; no
            validation is performed by the constructor.
    """

    def __init__(self, n):
        self._n = n

    @property
    def n(self):
        """Return the number of elements in the Discrete space."""
        return self._n

    def sample(self):
        """Uniformly sample a random element of this space.

        Returns:
            The result of ``np.random.randint(self.n)``, i.e. an integer
            in ``[0, n)``.
        """
        return np.random.randint(self.n)

    def contains(self, x):
        """Return a boolean specifying if x is a valid member of this space.

        ``x`` is a member when, after conversion with ``np.asarray``, it is
        a scalar (shape ``()``) of signed or unsigned integer dtype whose
        value lies in ``[0, n)``.

        Args:
            x: Candidate value.

        Returns:
            bool: True if ``x`` belongs to the space, False otherwise.
        """
        x = np.asarray(x)
        # Kind 'i' is signed and 'u' is unsigned integer; floats, bools and
        # non-scalar arrays are never members.
        if x.shape != () or x.dtype.kind not in 'iu':
            return False
        return bool(0 <= x < self.n)

    def __repr__(self):
        """Compute a representation of the space."""
        return "Discrete(%d)" % self.n

    def __eq__(self, other):
        """Compare two Discrete Spaces for equality.

        Returns:
            False when ``other`` is not a Discrete; otherwise whether both
            spaces have the same ``n``.
        """
        if not isinstance(other, Discrete):
            return False
        return self.n == other.n

    def flatten(self, x):
        """Return a flattened (one-hot) encoding of the observation x.

        Args:
            x: Index to encode; used directly to index an array of
                length ``n``.

        Returns:
            np.ndarray: Array of ``n`` zeros with a 1 at position ``x``.
        """
        ret = np.zeros(self.n)
        ret[x] = 1
        return ret

    def unflatten(self, x):
        """Return an unflattened observation x.

        Args:
            x: Flattened observation, presumably a one-hot vector as
                produced by ``flatten``.

        Returns:
            The index of the first nonzero entry along the first axis of
            ``x``.

        Raises:
            IndexError: If ``x`` has no nonzero entry.
        """
        return np.nonzero(x)[0][0]

    def flatten_n(self, x):
        """Return flattened (one-hot) encodings of the observations x.

        Args:
            x: Sequence of indices, one per observation.

        Returns:
            np.ndarray: Array of shape ``(len(x), n)`` in which row ``i``
            is zero except for a 1 in column ``x[i]``.
        """
        ret = np.zeros((len(x), self.n))
        # Fancy indexing sets one column per row in a single assignment.
        ret[np.arange(len(x)), x] = 1
        return ret

    def unflatten_n(self, x):
        """Return unflattened observations xs.

        Args:
            x: Two-dimensional array of flattened observations, presumably
                one one-hot row per observation as produced by
                ``flatten_n``.

        Returns:
            np.ndarray: Column indices of the nonzero entries of ``x``.
        """
        return np.nonzero(x)[1]

    @property
    def flat_dim(self):
        """Return the length of the flattened vector of the space."""
        return self.n

    def weighted_sample(self, weights):
        """Compute a weighted sample of the elements in the Discrete Space.

        A single uniform random value from ``np.random.rand()`` is compared
        against the cumulative sum of ``weights``; the sample is the number
        of cumulative sums strictly below that value, clamped to ``n - 1``.

        Args:
            weights: Per-element weights, presumably probabilities that
                sum to 1 -- confirm against callers; they are not
                normalized here.

        Returns:
            The sampled index, at most ``n - 1``.
        """
        # An array of the weights, cumulatively summed.
        cumulative = np.cumsum(weights)
        threshold = np.random.rand()
        # Counting the cumulative sums below the threshold gives the index
        # of the first cumulative weight that reaches the random value.
        idx = np.sum(cumulative < threshold)
        # Clamp in case the threshold exceeds every cumulative weight.
        return min(idx, self.n - 1)

    @property
    def default_value(self):
        """Return the default value of the space.

        This is always just 0.
        """
        return 0

    def __hash__(self):
        """Hash the Discrete Space by its number of elements."""
        return hash(self.n)

    def new_tensor_variable(self, name, extra_dims):
        """Create a tensor variable given the name and extra dimensions.

        Not implemented by this class.

        Args:
            name: Name of the variable.
            extra_dims: Extra dimensions in the front.

        Returns:
            The created tensor variable.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError