# Source code for akro.discrete
"""A space representing a selection between a finite number of items."""
import gym.spaces
import numpy as np
from akro import tf, theano
from akro.requires import requires_tf, requires_theano
from akro.space import Space
class Discrete(gym.spaces.Discrete, Space):
    """A space of integers ``{0, 1, ..., n-1}``.

    Elements are flattened to one-hot vectors of length ``n``
    (see :meth:`flatten`).
    """

    def flatten(self, x):
        """Return the one-hot encoding of an element of this space.

        Args:
            x (int): An element of this space, i.e. an index in ``[0, n)``.

        Returns:
            np.ndarray: A float vector of length ``n`` that is 1 at index
                ``x`` and 0 elsewhere.
        """
        ret = np.zeros(self.n)
        ret[x] = 1
        return ret

    def unflatten(self, x):
        """Return the element encoded by a one-hot vector.

        Args:
            x (np.ndarray): A one-hot vector of length ``n``, e.g. as
                produced by :meth:`flatten`.

        Returns:
            np.integer: The index of the first nonzero entry of ``x``
                (a scalar, not an array).

        Raises:
            IndexError: If ``x`` has no nonzero entry.
        """
        return np.nonzero(x)[0][0]

    def flatten_n(self, xs):
        """Return one-hot encodings of a batch of elements.

        Args:
            xs (:obj:`Iterable`): Elements of this space (indices in
                ``[0, n)``). Must support ``len``.

        Returns:
            np.ndarray: A float array of shape ``(len(xs), n)`` whose row
                ``i`` is the one-hot encoding of ``xs[i]``.
        """
        ret = np.zeros((len(xs), self.n))
        # Fancy indexing sets exactly one column per row.
        ret[np.arange(len(xs)), xs] = 1
        return ret

    def unflatten_n(self, xs):
        """Return the elements encoded by a batch of one-hot vectors.

        Args:
            xs (np.ndarray): A 2-D array whose rows are one-hot vectors of
                length ``n``, e.g. as produced by :meth:`flatten_n`.

        Returns:
            np.ndarray: The column index of every nonzero entry of ``xs``,
                in row-major order. For well-formed one-hot rows, this is
                one element per row.

        Note:
            Rows with no nonzero entry contribute nothing, and rows with
            several nonzero entries contribute several indices. In both
            cases the result no longer lines up row-for-row with ``xs``.
        """
        # np.nonzero returns (row_indices, col_indices); the columns are
        # the encoded elements.
        return np.nonzero(xs)[1]

    @property
    def flat_dim(self):
        """Return the length of the flattened (one-hot) vector, i.e. ``n``."""
        return self.n

    def weighted_sample(self, weights):
        """Draw one element of the space with probability proportional to weights.

        Args:
            weights (:obj:`list`): One non-negative weight per element; there
                must be exactly ``n`` of them. They need not sum to 1; they
                are normalized here.

        Returns:
            np.integer: A single sampled index in ``[0, n)``.
        """
        assert len(weights) == self.n
        weights = np.asarray(weights)
        return np.random.choice(self.n, p=weights / weights.sum())

    def concat(self, other):
        """Concatenate with another space of the same type.

        Args:
            other (Space): A space to be concatenated with this space.

        Returns:
            Space: A concatenated space.

        Raises:
            NotImplementedError: Always; Discrete spaces do not support
                concatenation.
        """
        raise NotImplementedError

    def __hash__(self):
        """Hash the Discrete Space.

        Returns:
            int: A hash of the value n.
        """
        return hash(self.n)

    @requires_tf
    def to_tf_placeholder(self, name, batch_dims):
        """Create a tensor placeholder from the Space object.

        Args:
            name (str): Name of the placeholder.
            batch_dims (int): Number of leading batch dimensions (each of
                unknown size) to prepend to the flattened shape.

        Returns:
            tf.Tensor: A placeholder of dtype ``self.dtype`` and shape
                ``[None] * batch_dims + [n]``.
        """
        return tf.compat.v1.placeholder(dtype=self.dtype,
                                        shape=[None] * batch_dims +
                                        [self.flat_dim],
                                        name=name)

    @requires_theano
    def to_theano_tensor(self, name, batch_dims):
        """Create a theano tensor from the Space object.

        Args:
            name (str): Name of the variable.
            batch_dims (int): Number of leading batch dimensions to add in
                front of the flattened dimension.

        Returns:
            theano.tensor.TensorVariable: A symbolic tensor of dtype
                ``self.dtype`` with ``batch_dims + 1`` dimensions, none of
                them broadcastable.
        """
        return theano.tensor.TensorType(self.dtype,
                                        (False, ) * (batch_dims + 1))(name)