diff --git a/python/paddle/distribution/__init__.py b/python/paddle/distribution/__init__.py
index 6ae25d2012d6d..c56da5805ad66 100644
--- a/python/paddle/distribution/__init__.py
+++ b/python/paddle/distribution/__init__.py
@@ -33,6 +33,8 @@
 from paddle.distribution.uniform import Uniform
 from paddle.distribution.laplace import Laplace
 from paddle.distribution.geometric import Geometric
+from paddle.distribution.binomial import Binomial
+from paddle.distribution.poisson import Poisson
 
 __all__ = [
     'Bernoulli',
@@ -55,6 +57,8 @@
     'LogNormal',
     'Gumbel',
     'Geometric',
+    'Binomial',
+    'Poisson',
 ]
 
 __all__.extend(transform.__all__)
diff --git a/python/paddle/distribution/binomial.py b/python/paddle/distribution/binomial.py
new file mode 100644
index 0000000000000..9bf5ec41faaad
--- /dev/null
+++ b/python/paddle/distribution/binomial.py
@@ -0,0 +1,268 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections.abc import Sequence
+
+import paddle
+from paddle.distribution import distribution
+
+
+class Binomial(distribution.Distribution):
+    r"""
+    The Binomial distribution with size `total_count` and `probs` parameters.
+
+    In probability theory and statistics, the binomial distribution is the most basic discrete probability distribution defined on :math:`[0, n] \cap \mathbb{N}`.
+    It can be viewed as the number of heads obtained when a potentially unfair coin is tossed :math:`n` times, and its random variable can be viewed as the
+    sum of a series of independent Bernoulli trials.
+
+    The probability mass function (pmf) is
+
+    .. math::
+
+        pmf(x; n, p) = \frac{n!}{x!(n-x)!}p^{x}(1-p)^{n-x}
+
+    In the above equation:
+
+    * :math:`total\_count = n`: is the size, meaning the total number of Bernoulli experiments.
+    * :math:`probs = p`: is the probability of the event happening in one Bernoulli experiment.
+
+    Args:
+        total_count(int|Tensor): The size of the Binomial distribution, which should be greater than 0, meaning the number of independent Bernoulli
+            trials with probability parameter :math:`p`. The data type will be converted to a 1-D Tensor with the paddle global default dtype if the input
+            :attr:`probs` is not a Tensor; otherwise it will be converted to the same dtype as :attr:`probs`.
+        probs(float|Tensor): The probability of the Binomial distribution, which should reside in [0, 1], meaning the probability of success
+            for each individual Bernoulli trial. If the input data type is float, it will be converted to a 1-D Tensor with the paddle global default dtype.
+
+    Examples:
+        .. 
code-block:: python
+
+            >>> import paddle
+            >>> from paddle.distribution import Binomial
+            >>> paddle.set_device('cpu')
+            >>> paddle.seed(100)
+            >>> rv = Binomial(100, paddle.to_tensor([0.3, 0.6, 0.9]))
+
+            >>> print(rv.sample([2]))
+            Tensor(shape=[2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
+            [[31., 62., 93.],
+             [29., 54., 91.]])
+
+            >>> print(rv.mean)
+            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
+            [30.00000191, 60.00000381, 90.        ])
+
+            >>> print(rv.entropy())
+            Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
+            [2.94053698, 3.00781751, 2.51124287])
+    """
+
+    def __init__(self, total_count, probs):
+        self.dtype = paddle.get_default_dtype()
+        self.total_count, self.probs = self._to_tensor(total_count, probs)
+
+        if not self._check_constraint(self.total_count, self.probs):
+            raise ValueError(
+                'Every element of input parameter `total_count` should be greater than or equal to one, and `probs` should be greater than or equal to zero and less than or equal to one.'
+            )
+        if self.total_count.shape == []:
+            batch_shape = (1,)
+        else:
+            batch_shape = self.total_count.shape
+        super().__init__(batch_shape)
+
+    def _to_tensor(self, total_count, probs):
+        """Convert the input parameters into Tensors if they are not, and broadcast them.
+
+        Returns:
+            Tuple[Tensor, Tensor]: converted total_count and probs.
+        """
+        # convert type
+        if isinstance(probs, float):
+            probs = paddle.to_tensor(probs, dtype=self.dtype)
+        else:
+            self.dtype = probs.dtype
+        if isinstance(total_count, int):
+            total_count = paddle.to_tensor(total_count, dtype=self.dtype)
+        else:
+            total_count = paddle.cast(total_count, dtype=self.dtype)
+
+        # broadcast tensor
+        return paddle.broadcast_tensors([total_count, probs])
+
+    def _check_constraint(self, total_count, probs):
+        """Check the constraints on the input parameters.
+
+        Args:
+            total_count (Tensor)
+            probs (Tensor)
+
+        Returns:
+            bool: pass or not.
+        """
+        total_count_check = (total_count >= 1).all()
+        probability_check = (probs >= 0).all() * (probs <= 1).all()
+        return total_count_check and probability_check
+
+    @property
+    def mean(self):
+        """Mean of the binomial distribution.
+
+        Returns:
+            Tensor: mean value.
+        """
+        return self.total_count * self.probs
+
+    @property
+    def variance(self):
+        """Variance of the binomial distribution.
+
+        Returns:
+            Tensor: variance value.
+        """
+        return self.total_count * self.probs * (1 - self.probs)
+
+    def sample(self, shape=()):
+        """Generate binomial samples of the specified shape. The final shape would be ``shape+batch_shape``.
+
+        Args:
+            shape (Sequence[int], optional): Prepended shape of the generated samples.
+
+        Returns:
+            Tensor: Sampled data with shape `sample_shape` + `batch_shape`. The returned data type is the same as `probs`.
+        """
+        if not isinstance(shape, Sequence):
+            raise TypeError('sample shape must be a Sequence object.')
+
+        with paddle.set_grad_enabled(False):
+            shape = tuple(shape)
+            batch_shape = tuple(self.batch_shape)
+            output_shape = tuple(shape + batch_shape)
+            output_size = paddle.broadcast_to(
+                self.total_count, shape=output_shape
+            )
+            output_prob = paddle.broadcast_to(self.probs, shape=output_shape)
+            sample = paddle.binomial(
+                paddle.cast(output_size, dtype="int32"), output_prob
+            )
+            return paddle.cast(sample, self.dtype)
+
+    def entropy(self):
+        r"""Shannon entropy in nats.
+
+        The entropy is
+
+        .. math::
+
+            \mathcal{H}(X) = - \sum_{x \in \Omega} p(x) \log{p(x)}
+
+        In the above equation:
+
+        * :math:`\Omega`: is the support of the distribution.
+
+        Returns:
+            Tensor: Shannon entropy of the binomial distribution. The data type is the same as `probs`.
+        """
+        values = self._enumerate_support()
+        log_prob = self.log_prob(values)
+        return -(paddle.exp(log_prob) * log_prob).sum(0)
+
+    def _enumerate_support(self):
+        """Return the support of the binomial distribution [0, 1, ..., n].
+
+        Returns:
+            Tensor: the support of the binomial distribution.
+        """
+        values = paddle.arange(
+            1 + paddle.max(self.total_count), dtype=self.dtype
+        )
+        values = values.reshape((-1,) + (1,) * len(self.batch_shape))
+        return values
+
+    def log_prob(self, value):
+        """Log probability density/mass function.
+
+        Args:
+            value (Tensor): The input tensor.
+
+        Returns:
+            Tensor: log probability. The data type is the same as `probs`.
+        """
+        value = paddle.cast(value, dtype=self.dtype)
+
+        # log of the binomial coefficient
+        log_comb = (
+            paddle.lgamma(self.total_count + 1.0)
+            - paddle.lgamma(self.total_count - value + 1.0)
+            - paddle.lgamma(value + 1.0)
+        )
+        eps = paddle.finfo(self.probs.dtype).eps
+        probs = paddle.clip(self.probs, min=eps, max=1 - eps)
+        # log_p
+        return paddle.nan_to_num(
+            (
+                log_comb
+                + value * paddle.log(probs)
+                + (self.total_count - value) * paddle.log(1 - probs)
+            ),
+            neginf=-eps,
+        )
+
+    def prob(self, value):
+        """Probability density/mass function.
+
+        Args:
+            value (Tensor): The input tensor.
+
+        Returns:
+            Tensor: probability. The data type is the same as `probs`.
+        """
+        return paddle.exp(self.log_prob(value))
+
+    def kl_divergence(self, other):
+        r"""The KL-divergence between two binomial distributions with the same :attr:`total_count`.
+
+        The KL-divergence is
+
+        .. math::
+
+            KL\_divergence(n_1, p_1, n_2, p_2) = \sum_x p_1(x) \log{\frac{p_1(x)}{p_2(x)}}
+
+        where the two probability mass functions are
+
+        .. math::
+
+            p_1(x) = \frac{n_1!}{x!(n_1-x)!}p_1^{x}(1-p_1)^{n_1-x}
+
+        .. math::
+
+            p_2(x) = \frac{n_2!}{x!(n_2-x)!}p_2^{x}(1-p_2)^{n_2-x}
+
+        Args:
+            other (Binomial): instance of ``Binomial``.
+
+        Returns:
+            Tensor: kl-divergence between two binomial distributions. The data type is the same as `probs`.
+
+        """
+        if not (paddle.equal(self.total_count, other.total_count)).all():
+            raise ValueError(
+                "KL divergence of two binomial distributions should share the same `total_count` and `batch_shape`."
+ ) + support = self._enumerate_support() + log_prob_1 = self.log_prob(support) + log_prob_2 = other.log_prob(support) + return ( + paddle.multiply( + paddle.exp(log_prob_1), + (paddle.subtract(log_prob_1, log_prob_2)), + ) + ).sum(0) diff --git a/python/paddle/distribution/kl.py b/python/paddle/distribution/kl.py index 27e12a4309c2e..ecec0f425d2d6 100644 --- a/python/paddle/distribution/kl.py +++ b/python/paddle/distribution/kl.py @@ -17,6 +17,7 @@ import paddle from paddle.distribution.bernoulli import Bernoulli from paddle.distribution.beta import Beta +from paddle.distribution.binomial import Binomial from paddle.distribution.categorical import Categorical from paddle.distribution.cauchy import Cauchy from paddle.distribution.continuous_bernoulli import ContinuousBernoulli @@ -28,6 +29,7 @@ from paddle.distribution.lognormal import LogNormal from paddle.distribution.multivariate_normal import MultivariateNormal from paddle.distribution.normal import Normal +from paddle.distribution.poisson import Poisson from paddle.distribution.uniform import Uniform from paddle.framework import in_dynamic_mode @@ -167,6 +169,11 @@ def _kl_beta_beta(p, q): ) +@register_kl(Binomial, Binomial) +def _kl_binomial_binomial(p, q): + return p.kl_divergence(q) + + @register_kl(Dirichlet, Dirichlet) def _kl_dirichlet_dirichlet(p, q): return ( @@ -269,5 +276,10 @@ def _kl_lognormal_lognormal(p, q): return p._base.kl_divergence(q._base) +@register_kl(Poisson, Poisson) +def _kl_poisson_poisson(p, q): + return p.kl_divergence(q) + + def _sum_rightmost(value, n): return value.sum(list(range(-n, 0))) if n > 0 else value diff --git a/python/paddle/distribution/poisson.py b/python/paddle/distribution/poisson.py new file mode 100644 index 0000000000000..4cd50962f085d --- /dev/null +++ b/python/paddle/distribution/poisson.py @@ -0,0 +1,275 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Sequence + +import paddle +from paddle.distribution import distribution + + +class Poisson(distribution.Distribution): + r""" + The Poisson distribution with occurrence rate parameter: `rate`. + + In probability theory and statistics, the Poisson distribution is the most basic discrete probability + distribution defined on the nonnegative integer set, which is used to describe the probability distribution of the number of random + events occurring per unit time. + + The probability mass function (pmf) is + + .. math:: + + pmf(x; \lambda) = \frac{e^{-\lambda} \cdot \lambda^x}{x!} + + In the above equation: + + * :math:`rate = \lambda`: is the mean occurrence rate. + + Args: + rate(int|float|Tensor): The mean occurrence rate of Poisson distribution which should be greater than 0, meaning the expected occurrence + times of an event in a fixed time interval. If the input data type is int or float, the data type of `rate` will be converted to a + 1-D Tensor with paddle global default dtype. + + Examples: + .. 
code-block:: python
+
+            >>> import paddle
+            >>> from paddle.distribution import Poisson
+            >>> paddle.set_device('cpu')
+            >>> paddle.seed(100)
+            >>> rv = Poisson(paddle.to_tensor(30.0))
+
+            >>> print(rv.sample([3]))
+            Tensor(shape=[3, 1], dtype=float32, place=Place(cpu), stop_gradient=True,
+            [[35.],
+             [35.],
+             [30.]])
+
+            >>> print(rv.mean)
+            Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
+            30.)
+
+            >>> print(rv.entropy())
+            Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
+            [3.11671066])
+
+            >>> rv1 = Poisson(paddle.to_tensor([[30.,40.],[8.,5.]]))
+            >>> rv2 = Poisson(paddle.to_tensor([[1000.,40.],[7.,10.]]))
+            >>> print(rv1.kl_divergence(rv2))
+            Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
+            [[864.80285645, 0.          ],
+             [0.06825157  , 1.53426421  ]])
+    """
+
+    def __init__(self, rate):
+        self.dtype = paddle.get_default_dtype()
+        self.rate = self._to_tensor(rate)
+
+        if not self._check_constraint(self.rate):
+            raise ValueError(
+                'Every element of input parameter `rate` should be nonnegative.'
+            )
+        if self.rate.shape == []:
+            batch_shape = (1,)
+        else:
+            batch_shape = self.rate.shape
+        super().__init__(batch_shape)
+
+    def _to_tensor(self, rate):
+        """Convert the input parameter into a tensor.
+
+        Returns:
+            Tensor: converted rate.
+        """
+        # convert type
+        if isinstance(rate, (float, int)):
+            rate = paddle.to_tensor([rate], dtype=self.dtype)
+        else:
+            self.dtype = rate.dtype
+        return rate
+
+    def _check_constraint(self, value):
+        """Check the constraint on the input parameter.
+
+        Args:
+            value (Tensor)
+
+        Returns:
+            bool: pass or not.
+        """
+        return (value >= 0).all()
+
+    @property
+    def mean(self):
+        """Mean of the poisson distribution.
+
+        Returns:
+            Tensor: mean value.
+        """
+        return self.rate
+
+    @property
+    def variance(self):
+        """Variance of the poisson distribution.
+
+        Returns:
+            Tensor: variance value.
+        """
+        return self.rate
+
+    def sample(self, shape=()):
+        """Generate poisson samples of the specified shape. The final shape would be ``shape+batch_shape``.
+
+        Args:
+            shape (Sequence[int], optional): Prepended shape of the generated samples.
+
+        Returns:
+            Tensor: Sampled data with shape `sample_shape` + `batch_shape`.
+        """
+        if not isinstance(shape, Sequence):
+            raise TypeError('sample shape must be a Sequence object.')
+
+        shape = tuple(shape)
+        batch_shape = tuple(self.batch_shape)
+        output_shape = tuple(shape + batch_shape)
+        output_rate = paddle.broadcast_to(self.rate, shape=output_shape)
+
+        with paddle.no_grad():
+            return paddle.poisson(output_rate)
+
+    def entropy(self):
+        r"""Shannon entropy in nats.
+
+        The entropy is
+
+        .. math::
+
+            \mathcal{H}(X) = - \sum_{x \in \Omega} p(x) \log{p(x)}
+
+        In the above equation:
+
+        * :math:`\Omega`: is the support of the distribution.
+
+        Returns:
+            Tensor: Shannon entropy of the poisson distribution. The data type is the same as `rate`.
+        """
+        values = self._enumerate_bounded_support(self.rate).reshape(
+            (-1,) + (1,) * len(self.batch_shape)
+        )
+        log_prob = self.log_prob(values)
+        proposed = -(paddle.exp(log_prob) * log_prob).sum(0)
+        mask = paddle.cast(
+            paddle.not_equal(
+                self.rate, paddle.to_tensor(0.0, dtype=self.dtype)
+            ),
+            dtype=self.dtype,
+        )
+        return paddle.multiply(proposed, mask)
+
+    def _enumerate_bounded_support(self, rate):
+        """Generate a bounded approximation of the support. Approximate the Poisson r.v. by a
+        Normal r.v. with mu = rate and sigma = sqrt(rate); then, by the 30-sigma rule, generate a
+        bounded approximation of the support.
+
+        Args:
+            rate (float): rate of one poisson r.v.
+
+        Returns:
+            Tensor: the bounded approximation of the support.
+        """
+        s_max = (
+            paddle.sqrt(paddle.max(rate))
+            if paddle.greater_equal(
+                paddle.max(rate), paddle.to_tensor(1.0, dtype=self.dtype)
+            )
+            else paddle.ones_like(rate, dtype=self.dtype)
+        )
+        upper = paddle.max(paddle.cast(rate + 30 * s_max, dtype="int32"))
+        values = paddle.arange(0, upper, dtype=self.dtype)
+        return values
+
+    def log_prob(self, value):
+        """Log probability density/mass function.
+
+        Args:
+            value (Tensor): The input tensor.
+
+        Returns:
+            Tensor: log probability. The data type is the same as `rate`.
+        """
+        value = paddle.cast(value, dtype=self.dtype)
+        if not self._check_constraint(value):
+            raise ValueError(
+                'Every element of input parameter `value` should be nonnegative.'
+            )
+        eps = paddle.finfo(self.rate.dtype).eps
+        return paddle.nan_to_num(
+            (
+                -self.rate
+                + value * paddle.log(self.rate)
+                - paddle.lgamma(value + 1)
+            ),
+            neginf=-eps,
+        )
+
+    def prob(self, value):
+        """Probability density/mass function.
+
+        Args:
+            value (Tensor): The input tensor.
+
+        Returns:
+            Tensor: probability. The data type is the same as `rate`.
+        """
+        return paddle.exp(self.log_prob(value))
+
+    def kl_divergence(self, other):
+        r"""The KL-divergence between two poisson distributions with the same `batch_shape`.
+
+        The KL-divergence is
+
+        .. math::
+
+            KL\_divergence(\lambda_1, \lambda_2) = \sum_x p_1(x) \log{\frac{p_1(x)}{p_2(x)}}
+
+        where the two probability mass functions are
+
+        .. math::
+
+            p_1(x) = \frac{e^{-\lambda_1} \cdot \lambda_1^x}{x!}
+
+        .. math::
+
+            p_2(x) = \frac{e^{-\lambda_2} \cdot \lambda_2^x}{x!}
+
+        Args:
+            other (Poisson): instance of ``Poisson``.
+
+        Returns:
+            Tensor: kl-divergence between two poisson distributions. The data type is the same as `rate`.
+
+        """
+
+        if self.batch_shape != other.batch_shape:
+            raise ValueError(
+                "KL divergence of two poisson distributions should share the same `batch_shape`."
+            )
+        rate_max = paddle.max(paddle.maximum(self.rate, other.rate))
+        support_max = self._enumerate_bounded_support(rate_max)
+        a_max = paddle.max(support_max)
+        common_support = paddle.arange(0, a_max, dtype=self.dtype).reshape(
+            (-1,) + (1,) * len(self.batch_shape)
+        )
+
+        log_prob_1 = self.log_prob(common_support)
+        log_prob_2 = other.log_prob(common_support)
+        return (paddle.exp(log_prob_1) * (log_prob_1 - log_prob_2)).sum(0)
diff --git a/test/distribution/test_distribution_binomial.py b/test/distribution/test_distribution_binomial.py
new file mode 100644
index 0000000000000..79d860b5786b6
--- /dev/null
+++ b/test/distribution/test_distribution_binomial.py
@@ -0,0 +1,208 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
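# --- Illustrative aside (not part of the diff): `Binomial.entropy()` above
# computes the exact entropy by enumerating the full support [0, n] and
# reducing with -sum(p * log p). A minimal NumPy/SciPy sketch of the same
# idea, with hypothetical parameters `n` and `p`:
def binomial_entropy_by_enumeration(n, p):
    import numpy as np
    import scipy.stats

    support = np.arange(n + 1)  # full support {0, 1, ..., n}
    log_pmf = scipy.stats.binom.logpmf(support, n, p)
    return -(np.exp(log_pmf) * log_pmf).sum()
# e.g. binomial_entropy_by_enumeration(100, 0.3) should match
# scipy.stats.binom.entropy(100, 0.3), which the tests below rely on.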
+ +import unittest + +import numpy as np +import parameterize +import scipy.stats +from distribution import config + +import paddle +from paddle.distribution.binomial import Binomial + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'total_count', 'probs'), + [ + ( + 'one-dim', + 1000, + np.array([0.4]).astype('float32'), + ), + ( + 'multi-dim-total_count-probability', + parameterize.xrand((2, 1), min=1, max=100).astype('int32'), + parameterize.xrand((2, 3), dtype='float64', min=0.3, max=1), + ), + ], +) +class TestBinomial(unittest.TestCase): + def setUp(self): + self._dist = Binomial( + total_count=paddle.to_tensor(self.total_count), + probs=paddle.to_tensor(self.probs), + ) + + def test_mean(self): + mean = self._dist.mean + self.assertEqual(mean.numpy().dtype, self.probs.dtype) + np.testing.assert_allclose( + mean, + self._np_mean(), + rtol=config.RTOL.get(str(self.probs.dtype)), + atol=config.ATOL.get(str(self.probs.dtype)), + ) + + def test_variance(self): + var = self._dist.variance + self.assertEqual(var.numpy().dtype, self.probs.dtype) + np.testing.assert_allclose( + var, + self._np_variance(), + rtol=config.RTOL.get(str(self.probs.dtype)), + atol=config.ATOL.get(str(self.probs.dtype)), + ) + + def test_entropy(self): + entropy = self._dist.entropy() + self.assertEqual(entropy.numpy().dtype, self.probs.dtype) + np.testing.assert_allclose( + entropy, + self._np_entropy(), + rtol=config.RTOL.get(str(self.probs.dtype)), + atol=config.ATOL.get(str(self.probs.dtype)), + ) + + def test_sample(self): + sample_shape = () + samples = self._dist.sample(sample_shape) + self.assertEqual( + tuple(samples.shape), + sample_shape + self._dist.batch_shape + self._dist.event_shape, + ) + + sample_shape = (5000,) + samples = self._dist.sample(sample_shape) + sample_mean = samples.mean(axis=0) + sample_variance = samples.var(axis=0) + + np.testing.assert_allclose( + sample_mean, self._dist.mean, atol=0, rtol=0.20 + ) + np.testing.assert_allclose( + sample_variance, self._dist.variance, atol=0, rtol=0.20 + ) + + def _np_variance(self): + return scipy.stats.binom.var(self.total_count, self.probs) + + def _np_mean(self): + return scipy.stats.binom.mean(self.total_count, self.probs) + + def _np_entropy(self): + return scipy.stats.binom.entropy(self.total_count, self.probs) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'total_count', 'probs', 'value'), + [ + ( + 'value-same-shape', + 1000, + np.array([0.12, 0.3, 0.85]).astype('float64'), + np.array([2.0, 55.0, 999.0]).astype('float64'), + ), + ( + 'value-broadcast-shape', + 10, + np.array([[0.3, 0.7], [0.5, 0.5]]), + np.array([[[4.0, 6], [8, 2]], [[2.0, 4], [9, 7]]]), + ), + ], +) +class TestBinomialProbs(unittest.TestCase): + def setUp(self): + self._dist = Binomial( + total_count=self.total_count, + probs=paddle.to_tensor(self.probs), + ) + + def test_prob(self): + np.testing.assert_allclose( + self._dist.prob(paddle.to_tensor(self.value)), + scipy.stats.binom.pmf(self.value, self.total_count, self.probs), + rtol=config.RTOL.get(str(self.probs.dtype)), + atol=config.ATOL.get(str(self.probs.dtype)), + ) + + def test_log_prob(self): + np.testing.assert_allclose( + self._dist.log_prob(paddle.to_tensor(self.value)), + scipy.stats.binom.logpmf(self.value, self.total_count, self.probs), + rtol=config.RTOL.get(str(self.probs.dtype)), + atol=config.ATOL.get(str(self.probs.dtype)), + ) + + +@parameterize.place(config.DEVICES) 
+@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'n_1', 'p_1', 'n_2', 'p_2'), + [ + ( + 'one-dim-probability', + np.array([3333]), + parameterize.xrand((1,), dtype='float32', min=0, max=1), + np.array([3333]), + parameterize.xrand((1,), dtype='float32', min=0, max=1), + ), + ( + 'multi-dim-probability', + np.array([25, 25, 25]), + parameterize.xrand((2, 3), dtype='float64', min=0, max=1), + np.array([25, 25, 25]), + parameterize.xrand((2, 3), dtype='float64', min=0, max=1), + ), + ], +) +class TestBinomialKL(unittest.TestCase): + def setUp(self): + self._dist1 = Binomial( + total_count=paddle.to_tensor(self.n_1), + probs=paddle.to_tensor(self.p_1), + ) + self._dist2 = Binomial( + total_count=paddle.to_tensor(self.n_2), + probs=paddle.to_tensor(self.p_2), + ) + + def test_kl_divergence(self): + kl0 = self._dist1.kl_divergence(self._dist2) + kl1 = self.kl_divergence(self._dist1, self._dist2) + + self.assertEqual(tuple(kl0.shape), self.p_1.shape) + self.assertEqual(tuple(kl1.shape), self.p_1.shape) + np.testing.assert_allclose( + kl0, + kl1, + rtol=config.RTOL.get(str(self.p_1.dtype)), + atol=config.ATOL.get(str(self.p_1.dtype)), + ) + + def kl_divergence(self, dist1, dist2): + support = np.arange(1 + self.n_1.max(), dtype=self.p_1.dtype) + support = support.reshape((-1,) + (1,) * len(self.p_1.shape)) + log_prob_1 = scipy.stats.binom.logpmf( + support, dist1.total_count, dist1.probs + ) + log_prob_2 = scipy.stats.binom.logpmf( + support, dist2.total_count, dist2.probs + ) + return (np.exp(log_prob_1) * (log_prob_1 - log_prob_2)).sum(0) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/distribution/test_distribution_binomial_static.py b/test/distribution/test_distribution_binomial_static.py new file mode 100644 index 0000000000000..360dc411c4554 --- /dev/null +++ b/test/distribution/test_distribution_binomial_static.py @@ -0,0 +1,245 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
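# --- Illustrative aside (not part of the diff): when two Binomials share the
# same `total_count` n, the enumerated KL computed by `kl_divergence` has a
# closed form, because each distribution factors into n independent Bernoulli
# trials:
#     KL = n * (p1*log(p1/p2) + (1 - p1)*log((1 - p1)/(1 - p2)))
# A hypothetical NumPy cross-check:
def binomial_kl_closed_form(n, p1, p2):
    import numpy as np

    return n * (p1 * np.log(p1 / p2) + (1 - p1) * np.log((1 - p1) / (1 - p2)))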
+
+import unittest
+
+import numpy as np
+import parameterize
+import scipy.stats
+from distribution import config
+
+import paddle
+from paddle.distribution.binomial import Binomial
+
+paddle.enable_static()
+
+
+@parameterize.place(config.DEVICES)
+@parameterize.parameterize_cls(
+    (parameterize.TEST_CASE_NAME, 'total_count', 'probs'),
+    [
+        (
+            'one-dim',
+            np.array([1000]),
+            parameterize.xrand((1,), dtype='float32', min=0, max=1),
+        ),
+        (
+            'multi-dim',
+            np.array([100]),
+            parameterize.xrand((1, 3), dtype='float64', min=0, max=1),
+        ),
+    ],
+)
+class TestBinomial(unittest.TestCase):
+    def setUp(self):
+        startup_program = paddle.static.Program()
+        main_program = paddle.static.Program()
+        executor = paddle.static.Executor(self.place)
+        with paddle.static.program_guard(main_program, startup_program):
+            probs = paddle.static.data(
+                'probs', self.probs.shape, self.probs.dtype
+            )
+            total_count = paddle.static.data(
+                'total_count', self.total_count.shape, self.total_count.dtype
+            )
+            dist = Binomial(total_count, probs)
+            mean = dist.mean
+            var = dist.variance
+            entropy = dist.entropy()
+            large_samples = dist.sample(shape=(1000,))
+            fetch_list = [mean, var, entropy, large_samples]
+            feed = {
+                'probs': self.probs,
+                'total_count': self.total_count,
+            }
+
+        executor.run(startup_program)
+        [
+            self.mean,
+            self.var,
+            self.entropy,
+            self.large_samples,
+        ] = executor.run(main_program, feed=feed, fetch_list=fetch_list)
+
+    def test_mean(self):
+        self.assertEqual(str(self.mean.dtype).split('.')[-1], self.probs.dtype)
+        np.testing.assert_allclose(
+            self.mean,
+            self._np_mean(),
+            rtol=config.RTOL.get(str(self.probs.dtype)),
+            atol=config.ATOL.get(str(self.probs.dtype)),
+        )
+
+    def test_variance(self):
+        self.assertEqual(str(self.var.dtype).split('.')[-1], self.probs.dtype)
+        np.testing.assert_allclose(
+            self.var,
+            self._np_variance(),
+            rtol=config.RTOL.get(str(self.probs.dtype)),
+            atol=config.ATOL.get(str(self.probs.dtype)),
+        )
+
+    def test_entropy(self):
+        self.assertEqual(
+            str(self.entropy.dtype).split('.')[-1], self.probs.dtype
+        )
+        np.testing.assert_allclose(
+            self.entropy,
+            self._np_entropy(),
+            rtol=config.RTOL.get(str(self.probs.dtype)),
+            atol=config.ATOL.get(str(self.probs.dtype)),
+        )
+
+    def test_sample(self):
+        self.assertEqual(
+            str(self.large_samples.dtype).split('.')[-1], self.probs.dtype
+        )
+        sample_mean = self.large_samples.mean(axis=0)
+        sample_variance = self.large_samples.var(axis=0)
+        np.testing.assert_allclose(sample_mean, self.mean, atol=0, rtol=0.20)
+        np.testing.assert_allclose(sample_variance, self.var, atol=0, rtol=0.20)
+
+    def _np_variance(self):
+        return scipy.stats.binom.var(self.total_count, self.probs)
+
+    def _np_mean(self):
+        return scipy.stats.binom.mean(self.total_count, self.probs)
+
+    def _np_entropy(self):
+        return scipy.stats.binom.entropy(self.total_count, self.probs)
+
+
+@parameterize.place(config.DEVICES)
+@parameterize.parameterize_cls(
+    (parameterize.TEST_CASE_NAME, 'total_count', 'probs', 'value'),
+    [
+        (
+            'value-same-shape',
+            np.array([10]).astype('int64'),
+            np.array([0.2, 0.3, 0.5]).astype('float64'),
+            np.array([2.0, 3.0, 5.0]).astype('float64'),
+        ),
+        (
+            'value-broadcast-shape',
+            np.array([10]),
+            np.array([[0.3, 0.7], [0.5, 0.5]]),
+            np.array([[[4.0, 6.0], [8.0, 2.0]], [[2.0, 4.0], [9.0, 7.0]]]),
+        ),
+    ],
+)
+class TestBinomialProbs(unittest.TestCase):
+    def setUp(self):
+        startup_program = paddle.static.Program()
+        main_program = paddle.static.Program()
+        executor = 
paddle.static.Executor(self.place) + + with paddle.static.program_guard(main_program, startup_program): + total_count = paddle.static.data( + 'total_count', self.total_count.shape, self.total_count.dtype + ) + probs = paddle.static.data( + 'probs', self.probs.shape, self.probs.dtype + ) + value = paddle.static.data( + 'value', self.value.shape, self.value.dtype + ) + dist = Binomial(total_count, probs) + pmf = dist.prob(value) + feed = { + 'total_count': self.total_count, + 'probs': self.probs, + 'value': self.value, + } + fetch_list = [pmf] + + executor.run(startup_program) + [self.pmf] = executor.run( + main_program, feed=feed, fetch_list=fetch_list + ) + + def test_prob(self): + np.testing.assert_allclose( + self.pmf, + scipy.stats.binom.pmf(self.value, self.total_count, self.probs), + rtol=config.RTOL.get(str(self.probs.dtype)), + atol=config.ATOL.get(str(self.probs.dtype)), + ) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'n_1', 'p_1', 'n_2', 'p_2'), + [ + ( + 'multi-dim-probability', + np.array([32]), + parameterize.xrand((1, 2), dtype='float64', min=0, max=1), + np.array([32]), + parameterize.xrand((1, 2), dtype='float64', min=0, max=1), + ), + ], +) +class TestBinomialKL(unittest.TestCase): + def setUp(self): + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + executor = paddle.static.Executor(self.place) + + with paddle.static.program_guard(main_program, startup_program): + n_1 = paddle.static.data('n_1', self.n_1.shape, self.n_1.dtype) + p_1 = paddle.static.data('p_1', self.p_1.shape, self.p_1.dtype) + n_2 = paddle.static.data('n_2', self.n_2.shape, self.n_2.dtype) + p_2 = paddle.static.data('p_2', self.p_2.shape, self.p_2.dtype) + dist1 = Binomial(n_1, p_1) + dist2 = Binomial(n_2, p_2) + kl_dist1_dist2 = dist1.kl_divergence(dist2) + feed = { + 'n_1': self.n_1, + 'p_1': self.p_1, + 'n_2': self.n_2, + 'p_2': self.p_2, + } + fetch_list = [kl_dist1_dist2] + + executor.run(startup_program) + [self.kl_dist1_dist2] = executor.run( + main_program, feed=feed, fetch_list=fetch_list + ) + + def test_kl_divergence(self): + kl0 = self.kl_dist1_dist2 + kl1 = self.kl_divergence_scipy() + + self.assertEqual(tuple(kl0.shape), self.p_1.shape) + self.assertEqual(tuple(kl1.shape), self.p_1.shape) + np.testing.assert_allclose( + kl0, + kl1, + rtol=config.RTOL.get(str(self.p_1.dtype)), + atol=config.ATOL.get(str(self.p_1.dtype)), + ) + + def kl_divergence_scipy(self): + support = np.arange(1 + self.n_1.max(), dtype=self.p_1.dtype) + support = support.reshape((-1,) + (1,) * len(self.p_1.shape)) + log_prob_1 = scipy.stats.binom.logpmf(support, self.n_1, self.p_1) + log_prob_2 = scipy.stats.binom.logpmf(support, self.n_2, self.p_2) + return (np.exp(log_prob_1) * (log_prob_1 - log_prob_2)).sum(0) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/distribution/test_distribution_poisson.py b/test/distribution/test_distribution_poisson.py new file mode 100644 index 0000000000000..039b4ee040d24 --- /dev/null +++ b/test/distribution/test_distribution_poisson.py @@ -0,0 +1,193 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import parameterize
+import scipy.stats
+from distribution import config
+
+import paddle
+from paddle.distribution import Poisson
+
+
+@parameterize.place(config.DEVICES)
+@parameterize.parameterize_cls(
+    (parameterize.TEST_CASE_NAME, 'rate'),
+    [
+        ('one-dim', np.array([100.0]).astype('float64')),
+        # boundary case and extreme case (`scipy.stats.poisson.entropy` cannot converge for very extreme cases such as rate=10000.0)
+        ('multi-dim', np.array([0.0, 1000.0]).astype('float32')),
+    ],
+)
+class TestPoisson(unittest.TestCase):
+    def setUp(self):
+        self._dist = Poisson(rate=paddle.to_tensor(self.rate))
+
+    def test_mean(self):
+        mean = self._dist.mean
+        self.assertEqual(mean.numpy().dtype, self.rate.dtype)
+        np.testing.assert_allclose(
+            mean,
+            scipy.stats.poisson.mean(self.rate),
+            rtol=config.RTOL.get(str(self.rate.dtype)),
+            atol=config.ATOL.get(str(self.rate.dtype)),
+        )
+
+    def test_variance(self):
+        var = self._dist.variance
+        self.assertEqual(var.numpy().dtype, self.rate.dtype)
+        np.testing.assert_allclose(
+            var,
+            scipy.stats.poisson.var(self.rate),
+            rtol=config.RTOL.get(str(self.rate.dtype)),
+            atol=config.ATOL.get(str(self.rate.dtype)),
+        )
+
+    def test_entropy(self):
+        entropy = self._dist.entropy()
+        self.assertEqual(entropy.numpy().dtype, self.rate.dtype)
+        np.testing.assert_allclose(
+            entropy,
+            scipy.stats.poisson.entropy(self.rate),
+            rtol=config.RTOL.get(str(self.rate.dtype)),
+            atol=config.ATOL.get(str(self.rate.dtype)),
+        )
+
+    def test_sample(self):
+        sample_shape = ()
+        samples = self._dist.sample(sample_shape)
+        self.assertEqual(samples.numpy().dtype, self.rate.dtype)
+        self.assertEqual(
+            tuple(samples.shape),
+            sample_shape + self._dist.batch_shape + self._dist.event_shape,
+        )
+
+        sample_shape = (5000,)
+        samples = self._dist.sample(sample_shape)
+        sample_mean = samples.mean(axis=0)
+        sample_variance = samples.var(axis=0)
+
+        np.testing.assert_allclose(
+            sample_mean, self._dist.mean, atol=0, rtol=0.20
+        )
+        np.testing.assert_allclose(
+            sample_variance, self._dist.variance, atol=0, rtol=0.20
+        )
+
+
+@parameterize.place(config.DEVICES)
+@parameterize.parameterize_cls(
+    (parameterize.TEST_CASE_NAME, 'rate', 'value'),
+    [
+        (
+            'value-same-shape',
+            np.array(1000).astype('float32'),
+            np.array(1100).astype('float32'),
+        ),
+        (
+            'value-broadcast-shape',
+            np.array(10).astype('float64'),
+            np.array([2.0, 3.0, 5.0, 10.0, 20.0]).astype('float64'),
+        ),
+    ],
+)
+class TestPoissonProbs(unittest.TestCase):
+    def setUp(self):
+        self._dist = Poisson(rate=paddle.to_tensor(self.rate))
+
+    def test_prob(self):
+        np.testing.assert_allclose(
+            self._dist.prob(paddle.to_tensor(self.value)),
+            scipy.stats.poisson.pmf(self.value, self.rate),
+            rtol=config.RTOL.get(str(self.rate.dtype)),
+            atol=config.ATOL.get(str(self.rate.dtype)),
+        )
+
+    def test_log_prob(self):
+        np.testing.assert_allclose(
+            self._dist.log_prob(paddle.to_tensor(self.value)),
+            scipy.stats.poisson.logpmf(self.value, self.rate),
+            rtol=config.RTOL.get(str(self.rate.dtype)),
+            atol=config.ATOL.get(str(self.rate.dtype)),
+        )
+
+
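# --- Illustrative aside (not part of the diff): Poisson KL also admits a
# closed form, KL(lambda_1 || lambda_2) = lambda_1*log(lambda_1/lambda_2)
# + lambda_2 - lambda_1, which the truncated-support enumeration in
# `Poisson.kl_divergence` approximates. A hypothetical NumPy cross-check for
# the KL test class below:
def poisson_kl_closed_form(rate_1, rate_2):
    import numpy as np

    return rate_1 * np.log(rate_1 / rate_2) + rate_2 - rate_1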
+@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'rate_1', 'rate_2'), + [ + ( + 'one-dim', + parameterize.xrand((1,), min=1, max=20) + .astype('int32') + .astype('float64'), + parameterize.xrand((1,), min=1, max=20) + .astype('int32') + .astype('float64'), + ), + ( + 'multi-dim', + parameterize.xrand((5, 3), min=1, max=20) + .astype('int32') + .astype('float32'), + parameterize.xrand((5, 3), min=1, max=20) + .astype('int32') + .astype('float32'), + ), + ], +) +class TestPoissonKL(unittest.TestCase): + def setUp(self): + self._dist1 = Poisson(rate=paddle.to_tensor(self.rate_1)) + self._dist2 = Poisson(rate=paddle.to_tensor(self.rate_2)) + + def test_kl_divergence(self): + kl0 = self._dist1.kl_divergence(self._dist2) + kl1 = self.kl_divergence_scipy() + + self.assertEqual(tuple(kl0.shape), self._dist1.batch_shape) + self.assertEqual(tuple(kl1.shape), self._dist1.batch_shape) + np.testing.assert_allclose( + kl0, + kl1, + rtol=config.RTOL.get(str(self.rate_1.dtype)), + atol=config.ATOL.get(str(self.rate_1.dtype)), + ) + + def kl_divergence_scipy(self): + rate_max = np.max(np.maximum(self.rate_1, self.rate_2)) + rate_min = np.min(np.minimum(self.rate_1, self.rate_2)) + support_max = self.enumerate_bounded_support(rate_max) + support_min = self.enumerate_bounded_support(rate_min) + a_min = np.min(support_min) + a_max = np.max(support_max) + common_support = np.arange( + a_min, a_max, dtype=self.rate_1.dtype + ).reshape((-1,) + (1,) * len(self.rate_1.shape)) + log_prob_1 = scipy.stats.poisson.logpmf(common_support, self.rate_1) + log_prob_2 = scipy.stats.poisson.logpmf(common_support, self.rate_2) + return (np.exp(log_prob_1) * (log_prob_1 - log_prob_2)).sum(0) + + def enumerate_bounded_support(self, rate): + s = np.sqrt(rate) + upper = int(rate + 30 * s) + lower = int(np.clip(rate - 30 * s, a_min=0, a_max=rate)) + values = np.arange(lower, upper, dtype=self.rate_1.dtype) + return values + + +if __name__ == '__main__': + unittest.main() diff --git a/test/distribution/test_distribution_poisson_static.py b/test/distribution/test_distribution_poisson_static.py new file mode 100644 index 0000000000000..199217aecde16 --- /dev/null +++ b/test/distribution/test_distribution_poisson_static.py @@ -0,0 +1,234 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
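# --- Illustrative aside (not part of the diff): the `enumerate_bounded_support`
# helpers in these tests mirror `Poisson._enumerate_bounded_support`. Treating
# Poisson(rate) as approximately Normal(rate, sqrt(rate)), the mass beyond
# rate + 30*sqrt(rate) is negligible, so a finite window suffices. A minimal
# NumPy sketch under that assumption:
def bounded_support(rate):
    import numpy as np

    # guard small rates with sigma = 1, as the implementation does
    s = np.sqrt(rate) if rate >= 1.0 else 1.0
    return np.arange(0, int(rate + 30 * s))  # [0, rate + 30*sigma)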
+
+import unittest
+
+import numpy as np
+import parameterize
+import scipy.stats
+from distribution import config
+
+import paddle
+from paddle.distribution import Poisson
+
+paddle.enable_static()
+
+
+@parameterize.place(config.DEVICES)
+@parameterize.parameterize_cls(
+    (parameterize.TEST_CASE_NAME, 'rate'),
+    [
+        ('one-dim', np.array([1000.0]).astype('float32')),
+        (
+            'multi-dim',
+            parameterize.xrand((2,), min=1, max=20)
+            .astype('int32')
+            .astype('float64'),
+        ),
+    ],
+)
+class TestPoisson(unittest.TestCase):
+    def setUp(self):
+        startup_program = paddle.static.Program()
+        main_program = paddle.static.Program()
+        executor = paddle.static.Executor(self.place)
+        with paddle.static.program_guard(main_program, startup_program):
+            rate = paddle.static.data('rate', self.rate.shape, self.rate.dtype)
+            dist = Poisson(rate)
+            mean = dist.mean
+            var = dist.variance
+            entropy = dist.entropy()
+            mini_samples = dist.sample(shape=())
+            large_samples = dist.sample(shape=(1000,))
+            fetch_list = [mean, var, entropy, mini_samples, large_samples]
+            feed = {'rate': self.rate}
+
+        executor.run(startup_program)
+        [
+            self.mean,
+            self.var,
+            self.entropy,
+            self.mini_samples,
+            self.large_samples,
+        ] = executor.run(main_program, feed=feed, fetch_list=fetch_list)
+
+    def test_mean(self):
+        self.assertEqual(str(self.mean.dtype).split('.')[-1], self.rate.dtype)
+        np.testing.assert_allclose(
+            self.mean,
+            self._np_mean(),
+            rtol=config.RTOL.get(str(self.rate.dtype)),
+            atol=config.ATOL.get(str(self.rate.dtype)),
+        )
+
+    def test_variance(self):
+        self.assertEqual(str(self.var.dtype).split('.')[-1], self.rate.dtype)
+        np.testing.assert_allclose(
+            self.var,
+            self._np_variance(),
+            rtol=config.RTOL.get(str(self.rate.dtype)),
+            atol=config.ATOL.get(str(self.rate.dtype)),
+        )
+
+    def test_entropy(self):
+        self.assertEqual(
+            str(self.entropy.dtype).split('.')[-1], self.rate.dtype
+        )
+        np.testing.assert_allclose(
+            self.entropy,
+            self._np_entropy(),
+            rtol=config.RTOL.get(str(self.rate.dtype)),
+            atol=config.ATOL.get(str(self.rate.dtype)),
+        )
+
+    def test_sample(self):
+        self.assertEqual(
+            str(self.mini_samples.dtype).split('.')[-1], self.rate.dtype
+        )
+        sample_mean = self.large_samples.mean(axis=0)
+        sample_variance = self.large_samples.var(axis=0)
+        np.testing.assert_allclose(sample_mean, self.mean, atol=0, rtol=0.20)
+        np.testing.assert_allclose(sample_variance, self.var, atol=0, rtol=0.20)
+
+    def _np_variance(self):
+        return self.rate
+
+    def _np_mean(self):
+        return self.rate
+
+    def _np_entropy(self):
+        return scipy.stats.poisson.entropy(self.rate)
+
+
+@parameterize.place(config.DEVICES)
+@parameterize.parameterize_cls(
+    (parameterize.TEST_CASE_NAME, 'rate', 'value'),
+    [
+        (
+            'value-same-shape',
+            np.array(1000).astype('float32'),
+            np.array(1100).astype('float32'),
+        ),
+        (
+            'value-broadcast-shape',
+            np.array(10).astype('float64'),
+            np.array([2.0, 3.0]).astype('float64'),
+        ),
+    ],
+)
+class TestPoissonProbs(unittest.TestCase):
+    def setUp(self):
+        startup_program = paddle.static.Program()
+        main_program = paddle.static.Program()
+        executor = paddle.static.Executor(self.place)
+
+        with paddle.static.program_guard(main_program, startup_program):
+            rate = paddle.static.data('rate', self.rate.shape, self.rate.dtype)
+            value = paddle.static.data(
+                'value', self.value.shape, self.value.dtype
+            )
+            dist = Poisson(rate)
+            pmf = dist.prob(value)
+            feed = {'rate': self.rate, 'value': self.value}
+            fetch_list = [pmf]
+
+        executor.run(startup_program)
+        [self.pmf] = 
executor.run( + main_program, feed=feed, fetch_list=fetch_list + ) + + def test_prob(self): + np.testing.assert_allclose( + self.pmf, + scipy.stats.poisson.pmf(self.value, self.rate), + rtol=config.RTOL.get(str(self.rate.dtype)), + atol=config.ATOL.get(str(self.rate.dtype)), + ) + + +@parameterize.place(config.DEVICES) +@parameterize.parameterize_cls( + (parameterize.TEST_CASE_NAME, 'rate_1', 'rate_2'), + [ + ( + 'multi-dim', + parameterize.xrand((2, 3), min=1, max=20) + .astype('int32') + .astype('float32'), + parameterize.xrand((2, 3), min=1, max=20) + .astype('int32') + .astype('float32'), + ), + ], +) +class TestPoissonKL(unittest.TestCase): + def setUp(self): + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + executor = paddle.static.Executor(self.place) + + with paddle.static.program_guard(main_program, startup_program): + rate_1 = paddle.static.data('rate_1', self.rate_1.shape) + rate_2 = paddle.static.data('rate_2', self.rate_2.shape) + dist1 = Poisson(rate_1) + dist2 = Poisson(rate_2) + kl_dist1_dist2 = dist1.kl_divergence(dist2) + feed = {'rate_1': self.rate_1, 'rate_2': self.rate_2} + fetch_list = [kl_dist1_dist2] + + executor.run(startup_program) + [self.kl_dist1_dist2] = executor.run( + main_program, feed=feed, fetch_list=fetch_list + ) + + def test_kl_divergence(self): + kl0 = self.kl_dist1_dist2 + kl1 = self.kl_divergence_scipy() + + self.assertEqual(tuple(kl0.shape), self.rate_1.shape) + self.assertEqual(tuple(kl1.shape), self.rate_1.shape) + np.testing.assert_allclose( + kl0, + kl1, + rtol=config.RTOL.get(str(self.rate_1.dtype)), + atol=config.ATOL.get(str(self.rate_1.dtype)), + ) + + def kl_divergence_scipy(self): + rate_max = np.max(np.maximum(self.rate_1, self.rate_2)) + rate_min = np.min(np.minimum(self.rate_1, self.rate_2)) + support_max = self.enumerate_bounded_support(rate_max) + support_min = self.enumerate_bounded_support(rate_min) + a_min = np.min(support_min) + a_max = np.max(support_max) + common_support = np.arange( + a_min, a_max, dtype=self.rate_1.dtype + ).reshape((-1,) + (1,) * len(self.rate_1.shape)) + log_prob_1 = scipy.stats.poisson.logpmf(common_support, self.rate_1) + log_prob_2 = scipy.stats.poisson.logpmf(common_support, self.rate_2) + return (np.exp(log_prob_1) * (log_prob_1 - log_prob_2)).sum(0) + + def enumerate_bounded_support(self, rate): + s = np.sqrt(rate) + upper = int(rate + 30 * s) + lower = int(np.clip(rate - 30 * s, a_min=0, a_max=rate)) + values = np.arange(lower, upper, dtype=self.rate_1.dtype) + return values + + +if __name__ == '__main__': + unittest.main()
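# --- Illustrative usage (not part of the diff): with the `register_kl`
# entries added to kl.py above, the functional KL API dispatches to the new
# `kl_divergence` methods. A minimal dynamic-mode sketch with hypothetical
# rates:
def functional_kl_usage_sketch():
    import paddle
    from paddle.distribution import Poisson
    from paddle.distribution.kl import kl_divergence

    p = Poisson(paddle.to_tensor([10.0, 20.0]))
    q = Poisson(paddle.to_tensor([12.0, 18.0]))
    # equivalent to p.kl_divergence(q), via the Poisson/Poisson registration
    return kl_divergence(p, q)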