Skip to content

Commit

Permalink
[Zero-Dim] Support broadcast_tensors input 0D and distribution API ou…
Browse files Browse the repository at this point in the history
…tput 0D
  • Loading branch information
zhwesky2010 committed Mar 21, 2023
1 parent 0bb7c00 commit b60e325
Show file tree
Hide file tree
Showing 16 changed files with 254 additions and 66 deletions.
6 changes: 0 additions & 6 deletions paddle/phi/infermeta/multiary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -769,12 +769,6 @@ void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x,
target_rank = std::max(target_rank, input_ddim.size());
}

PADDLE_ENFORCE_GT(target_rank,
0,
errors::InvalidArgument("BroadcastTensorsOp requires at "
"least one input tensor to have "
"rank greater than zero"));

std::vector<int64_t> target_dims(target_rank, 0);
// 2. Output dim(axis=x) = max(Inputs dim(axis=x))
for (int index = 0; index < target_rank; index++) {
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/funcs/eigen/broadcast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ struct EigenBroadcast<Eigen::DefaultDevice, T, Rank> {
OutType out,
InType in,
const Array& bcast) {
// Eigen::TensorMap.broadcast not support 0D
out.device(dev) = in.broadcast(bcast);
}

Expand Down
25 changes: 16 additions & 9 deletions paddle/phi/kernels/impl/broadcast_tensors_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/broadcast_tensors_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
Expand Down Expand Up @@ -48,9 +49,9 @@ void ApplyBroadcast(const Context& ctx,
// expanded dims: "new_input_dims_vec"
Eigen::DSizes<Eigen::DenseIndex, OutRank> bcast_dims;
std::vector<int64_t> new_input_dims_vec(out_rank);
for (int j = 0; j < out_rank; j++) {
int out_axis = out_rank - j - 1;
int in_axis = in_rank - j - 1;
for (int i = 0; i < out_rank; i++) {
int in_axis = in_rank - i - 1;
int out_axis = out_rank - i - 1;

bcast_dims[out_axis] = output_dims[out_axis];
new_input_dims_vec[out_axis] = 1;
Expand Down Expand Up @@ -101,12 +102,18 @@ void BroadcastTensorsKernel(const Context& ctx,
for (size_t i = 0; i < num_ins; i++) {
int out_rank = out_tensors[i]->dims().size();
switch (out_rank) {
SWITCH_OUT_RANK_CASE(1)
SWITCH_OUT_RANK_CASE(2)
SWITCH_OUT_RANK_CASE(3)
SWITCH_OUT_RANK_CASE(4)
SWITCH_OUT_RANK_CASE(5)
SWITCH_OUT_RANK_CASE(6)
case 0: {
const DenseTensor* src = in_tensors[i];
DenseTensor* dst = out_tensors[i];
phi::Copy(ctx, *src, src->place(), false, dst);
break;
}
SWITCH_OUT_RANK_CASE(1)
SWITCH_OUT_RANK_CASE(2)
SWITCH_OUT_RANK_CASE(3)
SWITCH_OUT_RANK_CASE(4)
SWITCH_OUT_RANK_CASE(5)
SWITCH_OUT_RANK_CASE(6)
default: {
PADDLE_THROW(phi::errors::InvalidArgument(
"Target tensor rank out of range"
Expand Down
10 changes: 5 additions & 5 deletions python/paddle/distribution/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ class Beta(exponential_family.ExponentialFamily):
# scale input
beta = paddle.distribution.Beta(alpha=0.5, beta=0.5)
print(beta.mean)
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [0.50000000])
print(beta.variance)
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [0.12500000])
print(beta.entropy())
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [0.12500000])
# tensor input with broadcast
Expand All @@ -84,10 +84,10 @@ class Beta(exponential_family.ExponentialFamily):

def __init__(self, alpha, beta):
if isinstance(alpha, numbers.Real):
alpha = paddle.full(shape=[1], fill_value=alpha)
alpha = paddle.full(shape=[], fill_value=alpha)

if isinstance(beta, numbers.Real):
beta = paddle.full(shape=[1], fill_value=beta)
beta = paddle.full(shape=[], fill_value=beta)

self.alpha, self.beta = paddle.broadcast_tensors([alpha, beta])

Expand Down
4 changes: 2 additions & 2 deletions python/paddle/distribution/dirichlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ class Dirichlet(exponential_family.ExponentialFamily):
dirichlet = paddle.distribution.Dirichlet(paddle.to_tensor([1., 2., 3.]))
print(dirichlet.entropy())
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [-1.24434423])
print(dirichlet.prob(paddle.to_tensor([.3, .5, .6])))
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [10.80000114])
"""
Expand Down
14 changes: 9 additions & 5 deletions python/paddle/distribution/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,11 @@ def _extend_shape(self, sample_shape):
Returns:
Tensor: generated sample data shape
"""
return sample_shape + self._batch_shape + self._event_shape
return (
tuple(sample_shape)
+ tuple(self._batch_shape)
+ tuple(self._event_shape)
)

def _validate_args(self, *args):
"""
Expand Down Expand Up @@ -173,11 +177,11 @@ def _to_tensor(self, *args):
tmp = 0.0

for arg in args:
if isinstance(arg, float):
arg = [arg]
if not isinstance(arg, (list, tuple, np.ndarray, tensor.Variable)):
if not isinstance(
arg, (float, list, tuple, np.ndarray, tensor.Variable)
):
raise TypeError(
"Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}".format(
"Type of input args must be float, list, tuple, numpy.ndarray or Tensor, but received type {}".format(
type(arg)
)
)
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/distribution/gumbel.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class Gumbel(TransformedDistribution):
dist.cdf(value)
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [0.54523915])
dist.entropy()
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [1.57721567])
# Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, [1.57721567])
dist.rsample([2])
# Tensor(shape=[2, 1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [[0.80463481], [0.91893655]])
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/distribution/kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def kl_divergence(p, q):
q = paddle.distribution.Beta(alpha=0.3, beta=0.7)
print(paddle.distribution.kl_divergence(p, q))
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [0.21193528])
"""
Expand Down
28 changes: 9 additions & 19 deletions python/paddle/distribution/laplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Laplace(distribution.Distribution):
m = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0]))
m.sample() # Laplace distributed with loc=0, scale=1
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
# [3.68546247])
"""
Expand Down Expand Up @@ -209,7 +209,7 @@ def entropy(self):
m = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0]))
m.entropy()
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
# [1.69314718])
"""
return 1 + paddle.log(2 * self.scale)
Expand Down Expand Up @@ -304,14 +304,10 @@ def sample(self, shape=()):
m = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0]))
m.sample() # Laplace distributed with loc=0, scale=1
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
# [3.68546247])
"""
if not isinstance(shape, tuple):
raise TypeError(
f'Expected shape should be tuple[int], but got {type(shape)}'
)

shape = shape if isinstance(shape, tuple) else tuple(shape)
with paddle.no_grad():
return self.rsample(shape)

Expand All @@ -336,22 +332,16 @@ def rsample(self, shape):
"""

eps = self._get_eps()
shape = self._extend_shape(shape) or (1,)
shape = self._extend_shape(shape)
uniform = paddle.uniform(
shape=shape,
min=float(np.nextafter(-1, 1)) + eps / 2,
max=1.0 - eps / 2,
dtype=self.loc.dtype,
)

if len(self.scale.shape) == 0 and len(self.loc.shape) == 0:
loc, scale, uniform = paddle.broadcast_tensors(
[self.loc, self.scale, uniform]
)
else:
loc, scale = self.loc, self.scale

return loc - scale * uniform.sign() * paddle.log1p(-uniform.abs())
return self.loc - self.scale * uniform.sign() * paddle.log1p(
-uniform.abs()
)

def _get_eps(self):
"""
Expand Down Expand Up @@ -410,7 +400,7 @@ def kl_divergence(self, other):
m1 = paddle.distribution.Laplace(paddle.to_tensor([0.0]), paddle.to_tensor([1.0]))
m2 = paddle.distribution.Laplace(paddle.to_tensor([1.0]), paddle.to_tensor([0.5]))
m1.kl_divergence(m2)
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
# [1.04261160])
"""

Expand Down
4 changes: 2 additions & 2 deletions python/paddle/distribution/lognormal.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ class LogNormal(TransformedDistribution):
sample = lognormal_a.sample((2, ))
# a random tensor created by lognormal distribution with shape: [2, 1]
entropy = lognormal_a.entropy()
# [1.4189385] with shape: [1]
# [1.4189385] with shape: []
lp = lognormal_a.log_prob(value_tensor)
# [-0.72069150] with shape: [1]
p = lognormal_a.probs(value_tensor)
# [0.48641577] with shape: [1]
kl = lognormal_a.kl_divergence(lognormal_b)
# [0.34939718] with shape: [1]
# [0.34939718] with shape: []
"""

def __init__(self, loc, scale):
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/distribution/multinomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def entropy(self):
Tensor: entropy value
"""
n = paddle.full(
shape=[1], fill_value=self.total_count, dtype=self.probs.dtype
shape=[], fill_value=self.total_count, dtype=self.probs.dtype
)
support = paddle.arange(
self.total_count + 1, dtype=self.probs.dtype
Expand Down
18 changes: 9 additions & 9 deletions python/paddle/distribution/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ class Normal(distribution.Distribution):
sample = normal_a.sample([2])
# a random tensor created by normal distribution with shape: [2, 1]
entropy = normal_a.entropy()
# [1.4189385] with shape: [1]
# [1.4189385] with shape: []
lp = normal_a.log_prob(value_tensor)
# [-1.2389386] with shape: [1]
p = normal_a.probs(value_tensor)
# [0.28969154] with shape: [1]
kl = normal_a.kl_divergence(normal_b)
# [0.34939718] with shape: [1]
# [0.34939718] with shape: []
"""

def __init__(self, loc, scale, name=None):
Expand All @@ -101,7 +101,6 @@ def __init__(self, loc, scale, name=None):
'Normal',
)

self.batch_size_unknown = False
self.all_arg_is_float = False
self.name = name if name is not None else 'Normal'
self.dtype = 'float32'
Expand All @@ -112,7 +111,6 @@ def __init__(self, loc, scale, name=None):
scale = float(scale)

if self._validate_args(loc, scale):
self.batch_size_unknown = True
self.loc = loc
self.scale = scale
self.dtype = convert_dtype(loc.dtype)
Expand Down Expand Up @@ -174,8 +172,7 @@ def sample(self, shape=(), seed=0):
shape = list(shape)
batch_shape = list((self.loc + self.scale).shape)
name = self.name + '_sample'

if self.batch_size_unknown:
if -1 in batch_shape:
output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape + shape, self.dtype, 0.0
Expand Down Expand Up @@ -236,9 +233,12 @@ def entropy(self):
"""
name = self.name + '_entropy'
batch_shape = list((self.loc + self.scale).shape)
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape, self.dtype, 0.0
)
if -1 in batch_shape:
zero_tmp = tensor.fill_constant_batch_size_like(
self.loc + self.scale, batch_shape, self.dtype, 0.0
)
else:
zero_tmp = paddle.full(batch_shape, 0.0, self.dtype)
return paddle.add(
0.5 + zero_tmp,
0.5 * math.log(2 * math.pi) + paddle.log((self.scale + zero_tmp)),
Expand Down
4 changes: 2 additions & 2 deletions python/paddle/distribution/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ class AbsTransform(Transform):
# Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [1., 0., 1.])
print(abs.inverse(paddle.to_tensor(1.)))
print(abs.inverse(paddle.to_tensor([1.])))
# (Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [-1.]), Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [1.]))
Expand All @@ -380,7 +380,7 @@ class AbsTransform(Transform):
# 0.))
#Special case handling of 0.
print(abs.inverse(paddle.to_tensor(0.)))
print(abs.inverse(paddle.to_tensor([0.])))
# (Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [0.]), Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [0.]))
Expand Down
5 changes: 2 additions & 3 deletions python/paddle/distribution/uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class Uniform(distribution.Distribution):
sample = uniform.sample([2])
# a random tensor created by uniform distribution with shape: [2, 1]
entropy = uniform.entropy()
# [0.6931472] with shape: [1]
# [0.6931472] with shape: []
lp = uniform.log_prob(value_tensor)
# [-0.6931472] with shape: [1]
p = uniform.probs(value_tensor)
Expand Down Expand Up @@ -117,7 +117,6 @@ def __init__(self, low, high, name=None):
high = float(high)

if self._validate_args(low, high):
self.batch_size_unknown = True
self.low = low
self.high = high
self.dtype = convert_dtype(low.dtype)
Expand Down Expand Up @@ -159,7 +158,7 @@ def sample(self, shape, seed=0):

name = self.name + '_sample'
batch_shape = list((self.low + self.high).shape)
if self.batch_size_unknown:
if -1 in batch_shape:
output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like(
self.low + self.high, batch_shape + shape, self.dtype, 0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_sample(self):

def test_rsample(self):
shape = [5, 10, 8]
expected_shape = (5, 10, 8, 1)
expected_shape = (5, 10, 8)
data = self._t.rsample(shape)
self.assertEqual(tuple(data.shape), expected_shape)
self.assertEqual(data.dtype, self.base.loc.dtype)
Expand Down
Loading

0 comments on commit b60e325

Please sign in to comment.