diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index e22df441fd860..8340024595139 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -769,12 +769,6 @@ void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x,
     target_rank = std::max(target_rank, input_ddim.size());
   }
 
-  PADDLE_ENFORCE_GT(target_rank,
-                    0,
-                    errors::InvalidArgument("BroadcastTensorsOp requires at "
-                                            "least one input tensor to have "
-                                            "rank greater than zero"));
-
   std::vector<int64_t> target_dims(target_rank, 0);
   // 2. Output dim(axis=x) = max(Inputs dim(axis=x))
   for (int index = 0; index < target_rank; index++) {
diff --git a/paddle/phi/kernels/funcs/eigen/broadcast.cc b/paddle/phi/kernels/funcs/eigen/broadcast.cc
index c806cdeaad60b..e0b074548c91d 100644
--- a/paddle/phi/kernels/funcs/eigen/broadcast.cc
+++ b/paddle/phi/kernels/funcs/eigen/broadcast.cc
@@ -37,6 +37,7 @@ struct EigenBroadcast {
                    OutType out,
                    InType in,
                    const Array& bcast) {
+    // Eigen::TensorMap.broadcast does not support 0D
     out.device(dev) = in.broadcast(bcast);
   }
 
diff --git a/paddle/phi/kernels/impl/broadcast_tensors_kernel_impl.h b/paddle/phi/kernels/impl/broadcast_tensors_kernel_impl.h
index d0b7825d15ed3..c61b10d5a2199 100644
--- a/paddle/phi/kernels/impl/broadcast_tensors_kernel_impl.h
+++ b/paddle/phi/kernels/impl/broadcast_tensors_kernel_impl.h
@@ -18,6 +18,7 @@
 
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/broadcast_tensors_kernel.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
@@ -48,9 +49,9 @@ void ApplyBroadcast(const Context& ctx,
   // expanded dims: "new_input_dims_vec"
   Eigen::DSizes<Eigen::DenseIndex, OutRank> bcast_dims;
   std::vector<int64_t> new_input_dims_vec(out_rank);
-  for (int j = 0; j < out_rank; j++) {
-    int out_axis = out_rank - j - 1;
-    int in_axis = in_rank - j - 1;
+  for (int i = 0; i < out_rank; i++) {
+    int in_axis = in_rank - i - 1;
+    int out_axis = out_rank - i - 1;
 
     bcast_dims[out_axis] = output_dims[out_axis];
     new_input_dims_vec[out_axis] = 1;
@@ -101,12 +102,18 @@ void BroadcastTensorsKernel(const Context& ctx,
   for (size_t i = 0; i < num_ins; i++) {
     int out_rank = out_tensors[i]->dims().size();
     switch (out_rank) {
-      SWITCH_OUT_RANK_CASE(1)
-      SWITCH_OUT_RANK_CASE(2)
-      SWITCH_OUT_RANK_CASE(3)
-      SWITCH_OUT_RANK_CASE(4)
-      SWITCH_OUT_RANK_CASE(5)
-      SWITCH_OUT_RANK_CASE(6)
+      case 0: {
+        const DenseTensor* src = in_tensors[i];
+        DenseTensor* dst = out_tensors[i];
+        phi::Copy(ctx, *src, src->place(), false, dst);
+        break;
+      }
+      SWITCH_OUT_RANK_CASE(1)
+      SWITCH_OUT_RANK_CASE(2)
+      SWITCH_OUT_RANK_CASE(3)
+      SWITCH_OUT_RANK_CASE(4)
+      SWITCH_OUT_RANK_CASE(5)
+      SWITCH_OUT_RANK_CASE(6)
       default: {
         PADDLE_THROW(phi::errors::InvalidArgument(
             "Target tensor rank out of range"
diff --git a/python/paddle/distribution/beta.py b/python/paddle/distribution/beta.py
index 7c3e94056f430..07cbf9155c701 100644
--- a/python/paddle/distribution/beta.py
+++ b/python/paddle/distribution/beta.py
@@ -60,13 +60,13 @@ class Beta(exponential_family.ExponentialFamily):
             # scale input
             beta = paddle.distribution.Beta(alpha=0.5, beta=0.5)
             print(beta.mean)
-            # Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            # Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
             # [0.50000000])
             print(beta.variance)
-            # Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            # Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
             # [0.12500000])
             print(beta.entropy())
-            # Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            # Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
             # [0.12500000])
 
             # tensor input with broadcast
@@ -84,10 +84,10 @@ class Beta(exponential_family.ExponentialFamily):
 
     def __init__(self, alpha, beta):
         if isinstance(alpha, numbers.Real):
-            alpha = paddle.full(shape=[1], fill_value=alpha)
+            alpha = paddle.full(shape=[], fill_value=alpha)
 
         if isinstance(beta, numbers.Real):
-            beta = paddle.full(shape=[1], fill_value=beta)
+            beta = paddle.full(shape=[], fill_value=beta)
 
         self.alpha, self.beta = paddle.broadcast_tensors([alpha, beta])
diff --git a/python/paddle/distribution/dirichlet.py b/python/paddle/distribution/dirichlet.py
index a726675249e12..ffa94ee37cdfc 100644
--- a/python/paddle/distribution/dirichlet.py
+++ b/python/paddle/distribution/dirichlet.py
@@ -62,10 +62,10 @@ class Dirichlet(exponential_family.ExponentialFamily):
 
             dirichlet = paddle.distribution.Dirichlet(paddle.to_tensor([1., 2., 3.]))
             print(dirichlet.entropy())
-            # Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            # Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            # [-1.24434423])
             print(dirichlet.prob(paddle.to_tensor([.3, .5, .6])))
-            # Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            # Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
             # [10.80000114])
 
     """
diff --git a/python/paddle/distribution/distribution.py b/python/paddle/distribution/distribution.py
index 299bd7348b03a..ae532f45f0aaa 100644
--- a/python/paddle/distribution/distribution.py
+++ b/python/paddle/distribution/distribution.py
@@ -134,7 +134,11 @@ def _extend_shape(self, sample_shape):
         Returns:
             Tensor: generated sample data shape
         """
-        return sample_shape + self._batch_shape + self._event_shape
+        return (
+            tuple(sample_shape)
+            + tuple(self._batch_shape)
+            + tuple(self._event_shape)
+        )
 
     def _validate_args(self, *args):
         """
@@ -173,11 +177,11 @@ def _to_tensor(self, *args):
         tmp = 0.0
 
         for arg in args:
-            if isinstance(arg, float):
-                arg = [arg]
-            if not isinstance(arg, (list, tuple, np.ndarray, tensor.Variable)):
+            if not isinstance(
+                arg, (float, list, tuple, np.ndarray, tensor.Variable)
+            ):
                 raise TypeError(
-                    "Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}".format(
+                    "Type of input args must be float, list, tuple, numpy.ndarray or Tensor, but received type {}".format(
                         type(arg)
                     )
                 )
diff --git a/python/paddle/distribution/gumbel.py b/python/paddle/distribution/gumbel.py
index 067f87ca8375a..c02d017f29e52 100644
--- a/python/paddle/distribution/gumbel.py
+++ b/python/paddle/distribution/gumbel.py
@@ -61,7 +61,7 @@ class Gumbel(TransformedDistribution):
            dist.cdf(value)
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [0.54523915])
            dist.entropy()
-           # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [1.57721567])
+           # Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, [1.57721567])
            dist.rsample([2])
            # Tensor(shape=[2, 1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [[0.80463481], [0.91893655]])
 
diff --git a/python/paddle/distribution/kl.py b/python/paddle/distribution/kl.py
index 010f781d041cc..ac3b94d4ebd66 100644
--- a/python/paddle/distribution/kl.py
+++ b/python/paddle/distribution/kl.py
@@ -56,7 +56,7 @@ def kl_divergence(p, q):
 
            q = paddle.distribution.Beta(alpha=0.3, beta=0.7)
            print(paddle.distribution.kl_divergence(p, q))
-           # Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+           # Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            # [0.21193528])
 
    """
diff --git a/python/paddle/distribution/laplace.py b/python/paddle/distribution/laplace.py
index 5c047aebfd9b2..2d9d1ba5c5370 100644
--- a/python/paddle/distribution/laplace.py
+++ b/python/paddle/distribution/laplace.py
@@ -48,7 +48,7 @@ class Laplace(distribution.Distribution):
 
            m = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0]))
            m.sample()  # Laplace distributed with loc=0, scale=1
-           # Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
+           # Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
            # [3.68546247])
    """
@@ -209,7 +209,7 @@ def entropy(self):
 
                m = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0]))
                m.entropy()
-               # Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
+               # Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
                # [1.69314718])
        """
        return 1 + paddle.log(2 * self.scale)
@@ -304,14 +304,10 @@ def sample(self, shape=()):
 
                m = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0]))
                m.sample()  # Laplace distributed with loc=0, scale=1
-               # Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
+               # Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
                # [3.68546247])
        """
-        if not isinstance(shape, tuple):
-            raise TypeError(
-                f'Expected shape should be tuple[int], but got {type(shape)}'
-            )
-
+        shape = shape if isinstance(shape, tuple) else tuple(shape)
         with paddle.no_grad():
             return self.rsample(shape)
 
@@ -336,22 +332,16 @@ def rsample(self, shape):
        """
        eps = self._get_eps()
-        shape = self._extend_shape(shape) or (1,)
+        shape = self._extend_shape(shape)
        uniform = paddle.uniform(
            shape=shape,
            min=float(np.nextafter(-1, 1)) + eps / 2,
            max=1.0 - eps / 2,
            dtype=self.loc.dtype,
        )
-
-        if len(self.scale.shape) == 0 and len(self.loc.shape) == 0:
-            loc, scale, uniform = paddle.broadcast_tensors(
-                [self.loc, self.scale, uniform]
-            )
-        else:
-            loc, scale = self.loc, self.scale
-
-        return loc - scale * uniform.sign() * paddle.log1p(-uniform.abs())
+        return self.loc - self.scale * uniform.sign() * paddle.log1p(
+            -uniform.abs()
+        )
 
    def _get_eps(self):
        """
@@ -410,7 +400,7 @@ def kl_divergence(self, other):
 
                m1 = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0]))
                m2 = paddle.distribution.Laplace(paddle.to_tensor([1.0]), paddle.to_tensor([0.5]))
                m1.kl_divergence(m2)
-               # Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
+               # Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
                # [1.04261160])
        """
diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py
index f437be3ab8ac5..c69a8a6cf9113 100644
--- a/python/paddle/distribution/lognormal.py
+++ b/python/paddle/distribution/lognormal.py
@@ -72,13 +72,13 @@ class LogNormal(TransformedDistribution):
            sample = lognormal_a.sample((2, ))
            # a random tensor created by lognormal distribution with shape: [2, 1]
            entropy = lognormal_a.entropy()
-           # [1.4189385] with shape: [1]
+           # [1.4189385] with shape: []
            lp = lognormal_a.log_prob(value_tensor)
            # [-0.72069150] with shape: [1]
            p = lognormal_a.probs(value_tensor)
            # [0.48641577] with shape: [1]
            kl = lognormal_a.kl_divergence(lognormal_b)
-           # [0.34939718] with shape: [1]
+           # [0.34939718] with shape: []
    """
 
    def __init__(self, loc, scale):
diff --git a/python/paddle/distribution/multinomial.py b/python/paddle/distribution/multinomial.py
index 585cb2152d948..c57511da5c672 100644
--- a/python/paddle/distribution/multinomial.py
+++ b/python/paddle/distribution/multinomial.py
@@ -166,7 +166,7 @@ def entropy(self):
            Tensor: entropy value
        """
        n = paddle.full(
-            shape=[1], fill_value=self.total_count, dtype=self.probs.dtype
+            shape=[], fill_value=self.total_count, dtype=self.probs.dtype
        )
        support = paddle.arange(
            self.total_count + 1, dtype=self.probs.dtype
diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py
index 22ebab0ed4b62..a3ac0cf5c3df5 100644
--- a/python/paddle/distribution/normal.py
+++ b/python/paddle/distribution/normal.py
@@ -77,13 +77,13 @@ class Normal(distribution.Distribution):
           sample = normal_a.sample([2])
           # a random tensor created by normal distribution with shape: [2, 1]
           entropy = normal_a.entropy()
-          # [1.4189385] with shape: [1]
+          # [1.4189385] with shape: []
           lp = normal_a.log_prob(value_tensor)
           # [-1.2389386] with shape: [1]
           p = normal_a.probs(value_tensor)
           # [0.28969154] with shape: [1]
           kl = normal_a.kl_divergence(normal_b)
-          # [0.34939718] with shape: [1]
+          # [0.34939718] with shape: []
    """
 
    def __init__(self, loc, scale, name=None):
@@ -101,7 +101,6 @@ def __init__(self, loc, scale, name=None):
            'Normal',
        )
 
-        self.batch_size_unknown = False
        self.all_arg_is_float = False
        self.name = name if name is not None else 'Normal'
        self.dtype = 'float32'
@@ -112,7 +111,6 @@ def __init__(self, loc, scale, name=None):
            scale = float(scale)
 
        if self._validate_args(loc, scale):
-            self.batch_size_unknown = True
            self.loc = loc
            self.scale = scale
            self.dtype = convert_dtype(loc.dtype)
@@ -174,8 +172,7 @@ def sample(self, shape=(), seed=0):
            shape = list(shape)
        batch_shape = list((self.loc + self.scale).shape)
        name = self.name + '_sample'
-
-        if self.batch_size_unknown:
+        if -1 in batch_shape:
            output_shape = shape + batch_shape
            zero_tmp = tensor.fill_constant_batch_size_like(
                self.loc + self.scale, batch_shape + shape, self.dtype, 0.0
@@ -236,9 +233,12 @@ def entropy(self):
        """
        name = self.name + '_entropy'
        batch_shape = list((self.loc + self.scale).shape)
-        zero_tmp = tensor.fill_constant_batch_size_like(
-            self.loc + self.scale, batch_shape, self.dtype, 0.0
-        )
+        if -1 in batch_shape:
+            zero_tmp = tensor.fill_constant_batch_size_like(
+                self.loc + self.scale, batch_shape, self.dtype, 0.0
+            )
+        else:
+            zero_tmp = paddle.full(batch_shape, 0.0, self.dtype)
        return paddle.add(
            0.5 + zero_tmp,
            0.5 * math.log(2 * math.pi) + paddle.log((self.scale + zero_tmp)),
diff --git a/python/paddle/distribution/transform.py b/python/paddle/distribution/transform.py
index 40cd4b6627b38..85575b3c61a15 100644
--- a/python/paddle/distribution/transform.py
+++ b/python/paddle/distribution/transform.py
@@ -368,7 +368,7 @@ class AbsTransform(Transform):
            # Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            # [1., 0., 1.])
 
-            print(abs.inverse(paddle.to_tensor(1.)))
+            print(abs.inverse(paddle.to_tensor([1.])))
            # (Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            # [-1.]), Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            # [1.]))
@@ -380,7 +380,7 @@ class AbsTransform(Transform):
            # 0.))
 
            #Special case handling of 0.
-            print(abs.inverse(paddle.to_tensor(0.)))
+            print(abs.inverse(paddle.to_tensor([0.])))
            # (Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            # [0.]), Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            # [0.]))
diff --git a/python/paddle/distribution/uniform.py b/python/paddle/distribution/uniform.py
index 3c6bd83e9a6f1..824d1c7549693 100644
--- a/python/paddle/distribution/uniform.py
+++ b/python/paddle/distribution/uniform.py
@@ -84,7 +84,7 @@ class Uniform(distribution.Distribution):
           sample = uniform.sample([2])
           # a random tensor created by uniform distribution with shape: [2, 1]
           entropy = uniform.entropy()
-          # [0.6931472] with shape: [1]
+          # [0.6931472] with shape: []
           lp = uniform.log_prob(value_tensor)
           # [-0.6931472] with shape: [1]
           p = uniform.probs(value_tensor)
@@ -117,7 +117,6 @@ def __init__(self, low, high, name=None):
            high = float(high)
 
        if self._validate_args(low, high):
-            self.batch_size_unknown = True
            self.low = low
            self.high = high
            self.dtype = convert_dtype(low.dtype)
@@ -159,7 +158,7 @@ def sample(self, shape, seed=0):
        name = self.name + '_sample'
        batch_shape = list((self.low + self.high).shape)
-        if self.batch_size_unknown:
+        if -1 in batch_shape:
            output_shape = shape + batch_shape
            zero_tmp = tensor.fill_constant_batch_size_like(
                self.low + self.high, batch_shape + shape, self.dtype, 0.0
diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py
index f7d7060fb9d5a..57264b5f8972a 100644
--- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py
+++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py
@@ -69,7 +69,7 @@ def test_sample(self):
 
    def test_rsample(self):
        shape = [5, 10, 8]
-        expected_shape = (5, 10, 8, 1)
+        expected_shape = (5, 10, 8)
        data = self._t.rsample(shape)
        self.assertEqual(tuple(data.shape), expected_shape)
        self.assertEqual(data.dtype, self.base.loc.dtype)
diff --git a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
index 233d69d1fd045..e14638b86ea6e 100644
--- a/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
+++ b/python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
@@ -768,6 +768,44 @@ def test_broadcast_to(self):
        self.assertEqual(out3.grad.shape, [3, 3])
        np.testing.assert_allclose(out3.grad, 1.0)
 
+    def test_broadcast_tensors(self):
+        # 1) x is 0D, y is 0D
+        x1 = paddle.full([], 2.0)
+        x1.stop_gradient = False
+        x2 = paddle.full([], 2.0)
+        x2.stop_gradient = False
+        out1, out2 = paddle.broadcast_tensors([x1, x2])
+        # backward has a bug now
+        # out1.backward()
+
+        self.assertEqual(out1.shape, [])
+        self.assertEqual(out2.shape, [])
+        # self.assertEqual(x1.grad.shape, [])
+
+        # 2) x is ND, y is 0D
+        x1 = paddle.full([2, 3], 2.0)
+        x1.stop_gradient = False
+        x2 = paddle.full([], 2.0)
+        x2.stop_gradient = False
+        out1, out2 = paddle.broadcast_tensors([x1, x2])
+        # out1.backward()
+
+        self.assertEqual(out1.shape, [2, 3])
+        self.assertEqual(out2.shape, [2, 3])
+        # self.assertEqual(x1.grad.shape, [2, 3])
+
+        # 3) x is 0D, y is ND
+        x1 = paddle.full([], 2.0)
+        x1.stop_gradient = False
+        x2 = paddle.full([2, 3], 2.0)
+        x2.stop_gradient = False
+        out1, out2 = paddle.broadcast_tensors([x1, x2])
+        # out1.backward()
+
+        self.assertEqual(out1.shape, [2, 3])
+        self.assertEqual(out2.shape, [2, 3])
+        # self.assertEqual(x1.grad.shape, [2, 3])
+
    def test_broadcast_shape(self):
        x = []
        y = [3, 5]
@@ -3542,6 +3580,37 @@ def _test_shape(self):
        self.assertEqual(res[0].shape, (0))
        np.testing.assert_array_equal(res[0], np.array([]))
 
+    def test_broadcast_tensors(self):
+        # 1) x is 0D, y is 0D
+        x1 = paddle.full([], 2.0)
+        x1.stop_gradient = False
+        x2 = paddle.full([], 2.0)
+        x2.stop_gradient = False
+        out1, out2 = paddle.broadcast_tensors([x1, x2])
+
+        self.assertEqual(out1.shape, ())
+        self.assertEqual(out2.shape, ())
+
+        # 2) x is ND, y is 0D
+        x1 = paddle.full([2, 3], 2.0)
+        x1.stop_gradient = False
+        x2 = paddle.full([], 2.0)
+        x2.stop_gradient = False
+        out1, out2 = paddle.broadcast_tensors([x1, x2])
+
+        self.assertEqual(out1.shape, (2, 3))
+        self.assertEqual(out2.shape, (2, 3))
+
+        # 3) x is 0D, y is ND
+        x1 = paddle.full([], 2.0)
+        x1.stop_gradient = False
+        x2 = paddle.full([2, 3], 2.0)
+        x2.stop_gradient = False
+        out1, out2 = paddle.broadcast_tensors([x1, x2])
+
+        self.assertEqual(out1.shape, (2, 3))
+        self.assertEqual(out2.shape, (2, 3))
+
 
 # Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
 class TestNoBackwardAPI(unittest.TestCase):
@@ -4116,5 +4185,129 @@ def test_static(self):
        paddle.disable_static()
 
 
+class TestDistribution(unittest.TestCase):
+    def setUp(self):
+        self.x = paddle.full([], 2.0)
+
+    def test_Categorical(self):
+        logits = paddle.rand([6])
+        d = paddle.distribution.Categorical(logits)
+        self.assertEqual(d.sample([]).shape, [])
+        self.assertEqual(d.probs(paddle.full([], 2, dtype='int64')).shape, [])
+        self.assertEqual(
+            d.log_prob(paddle.full([], 2, dtype='int64')).shape, []
+        )
+        # because it uses paddle.sum
+        # self.assertEqual(d.entropy().shape, [])
+
+    def test_Normal(self):
+        normal = paddle.distribution.Normal(0.0, 3.0)
+        self.assertEqual(normal.sample([]).shape, [])
+        self.assertEqual(normal.rsample([]).shape, [])
+        self.assertEqual(normal.mean.shape, [])
+        self.assertEqual(normal.variance.shape, [])
+        self.assertEqual(normal.probs(self.x).shape, [])
+        self.assertEqual(normal.log_prob(self.x).shape, [])
+        self.assertEqual(normal.entropy().shape, [])
+
+        normal = paddle.distribution.Normal(
+            paddle.full([], 0.0), paddle.full([], 3.0)
+        )
+        self.assertEqual(normal.sample([]).shape, [])
+        self.assertEqual(normal.rsample([]).shape, [])
+        self.assertEqual(normal.mean.shape, [])
+        self.assertEqual(normal.variance.shape, [])
+        self.assertEqual(normal.probs(self.x).shape, [])
+        self.assertEqual(normal.log_prob(self.x).shape, [])
+        self.assertEqual(normal.entropy().shape, [])
+
+    def test_Uniform(self):
+        uniform = paddle.distribution.Uniform(0.0, 1.0)
+        self.assertEqual(uniform.sample([]).shape, [])
+        self.assertEqual(uniform.probs(self.x).shape, [])
+        self.assertEqual(uniform.log_prob(self.x).shape, [])
+        self.assertEqual(uniform.entropy().shape, [])
+
+        uniform = paddle.distribution.Uniform(
+            paddle.full([], 0.0), paddle.full([], 1.0)
+        )
+        self.assertEqual(uniform.sample([]).shape, [])
+        self.assertEqual(uniform.probs(self.x).shape, [])
+        self.assertEqual(uniform.log_prob(self.x).shape, [])
+        self.assertEqual(uniform.entropy().shape, [])
+
+    def test_Beta(self):
+        beta = paddle.distribution.Beta(alpha=0.5, beta=0.5)
+        self.assertEqual(beta.sample([]).shape, [])
+        self.assertEqual(beta.mean.shape, [])
+        self.assertEqual(beta.variance.shape, [])
+        # because they use paddle.sum
+        # self.assertEqual(beta.prob(self.x).shape, [])
+        # self.assertEqual(beta.log_prob(self.x).shape, [])
+        # self.assertEqual(beta.entropy().shape, [])
+
+    def test_kl_divergence(self):
+        p = paddle.distribution.Beta(alpha=0.5, beta=0.5)
+        q = paddle.distribution.Beta(alpha=0.2, beta=1.0)
+        kl = paddle.distribution.kl_divergence(p, q)
+        self.assertEqual(kl.shape, [])
+
+    def test_TransformedDistribution(self):
+        d = paddle.distribution.TransformedDistribution(
+            paddle.distribution.Normal(0.0, 1.0),
+            [
+                paddle.distribution.AffineTransform(
+                    paddle.full([], 1.0), paddle.full([], 2.0)
+                )
+            ],
+        )
+        self.assertEqual(d.sample([]).shape, [])
+        self.assertEqual(d.rsample([]).shape, [])
+        self.assertEqual(d.prob(self.x).shape, [])
+        self.assertEqual(d.log_prob(self.x).shape, [])
+
+    def test_Laplace(self):
+        d = paddle.distribution.Laplace(0.0, 1.0)
+        self.assertEqual(d.sample([]).shape, [])
+        self.assertEqual(d.rsample([]).shape, [])
+        self.assertEqual(d.mean.shape, [])
+        self.assertEqual(d.stddev.shape, [])
+        self.assertEqual(d.variance.shape, [])
+        self.assertEqual(d.prob(self.x).shape, [])
+        self.assertEqual(d.log_prob(self.x).shape, [])
+        self.assertEqual(d.cdf(self.x).shape, [])
+        self.assertEqual(d.icdf(self.x).shape, [])
+        self.assertEqual(d.entropy().shape, [])
+
+    def test_LogNormal(self):
+        d = paddle.distribution.LogNormal(0.0, 1.0)
+        self.assertEqual(d.sample([]).shape, [])
+        self.assertEqual(d.mean.shape, [])
+        self.assertEqual(d.variance.shape, [])
+        self.assertEqual(d.entropy().shape, [])
+        self.assertEqual(d.probs(self.x).shape, [])
+
+    def test_Gumbel(self):
+        d = paddle.distribution.Gumbel(0.0, 1.0)
+        self.assertEqual(d.sample([]).shape, [])
+        self.assertEqual(d.rsample([]).shape, [])
+        self.assertEqual(d.mean.shape, [])
+        self.assertEqual(d.variance.shape, [])
+        self.assertEqual(d.stddev.shape, [])
+        self.assertEqual(d.prob(self.x).shape, [])
+        self.assertEqual(d.log_prob(self.x).shape, [])
+        self.assertEqual(d.cdf(self.x).shape, [])
+        self.assertEqual(d.entropy().shape, [])
+
+    def test_Multinomial(self):
+        d = paddle.distribution.Multinomial(
+            10, paddle.to_tensor([0.2, 0.3, 0.5])
+        )
+        # because they use paddle.sum
+        # self.assertEqual(d.prob(self.x).shape, [])
+        # self.assertEqual(d.log_prob(self.x).shape, [])
+        # self.assertEqual(d.entropy().shape, [])
+
+
 if __name__ == "__main__":
     unittest.main()
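A minimal usage sketch of the behavior this patch enables, assuming a Paddle build that already contains the change: with the rank-0 check removed from BroadcastTensorsInferMeta and the case 0 copy path added to the kernel, paddle.broadcast_tensors accepts 0-D inputs. Two 0-D inputs stay 0-D, and a 0-D input broadcast together with an N-D input takes the N-D shape, matching the tests above.

    import paddle

    x = paddle.full([], 2.0)      # 0-D tensor
    y = paddle.full([2, 3], 1.0)  # 2-D tensor

    # 0-D with 0-D: both outputs remain 0-D
    a, b = paddle.broadcast_tensors([x, x])
    print(a.shape, b.shape)       # [] []

    # 0-D with 2-D: the 0-D input is broadcast up to [2, 3]
    c, d = paddle.broadcast_tensors([x, y])
    print(c.shape, d.shape)       # [2, 3] [2, 3]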