Skip to content

Commit

Permalink
fix codestyle
Browse files Browse the repository at this point in the history
  • Loading branch information
cocoshe committed Oct 19, 2023
1 parent 40ea529 commit 0ef396b
Show file tree
Hide file tree
Showing 2 changed files with 469 additions and 287 deletions.
88 changes: 49 additions & 39 deletions python/paddle/tensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3734,13 +3734,15 @@ def cdist(
)


def histogramdd(sample, bins=10, range=None, density=False, weights=None, name=None):
def histogramdd(
sample, bins=10, range=None, density=False, weights=None, name=None
):
r"""
Computes a multi-dimensional histogram of the values in a tensor.
Interprets the elements of an input tensor whose innermost dimension has size N as a collection of N-dimensional points. Maps each of the points into a set of N-dimensional bins and returns the number of points (or total weight) in each bin.
The input ``sample`` must be a tensor with at least 2 dimensions. If it has shape (M, N), each of its M rows defines a point in N-dimensional space. If it has three or more dimensions, all but the last dimension are flattened.
Each dimension is independently associated with its own strictly increasing sequence of bin edges. Bin edges may be specified explicitly by passing a sequence of 1D tensors. Alternatively, bin edges may be constructed automatically by passing a sequence of integers specifying the number of equal-width bins in each dimension.
Each dimension is independently associated with its own strictly increasing sequence of bin edges. Bin edges may be specified explicitly by passing a sequence of 1D tensors. Alternatively, bin edges may be constructed automatically by passing a sequence of integers specifying the number of equal-width bins in each dimension.
Args:
sample (Tensor): The input tensor.
bins (Tensor[], int[], or int): If Tensor[], defines the sequences of bin edges. If int[], defines the number of equal-width bins in each dimension. If int, defines the number of equal-width bins for all dimensions.
Expand Down Expand Up @@ -3782,10 +3784,11 @@ def histogramdd(sample, bins=10, range=None, density=False, weights=None, name=N
"""

def __check_sample(sample):
assert len(sample.shape) >= 2, (
"input sample must be a tensor with at least 2 dimensions."
)
assert (
len(sample.shape) >= 2
), "input sample must be a tensor with at least 2 dimensions."
check_variable_and_dtype(
sample,
'sample',
Expand All @@ -3795,6 +3798,7 @@ def __check_sample(sample):
],
'histogramdd',
)

def __check_bins(bins, sample): # when Tensor[], check dtype
for bins_tensor in bins:
bins_tensor = paddle.to_tensor(bins_tensor)
Expand All @@ -3807,12 +3811,13 @@ def __check_bins(bins, sample): # when Tensor[], check dtype
],
'histogramdd',
)
assert bins_tensor.dtype == sample.dtype, (
"When bins is Tensor[], the dtype of bins must be the same as sample.\n"
)
assert (
bins_tensor.dtype == sample.dtype
), "When bins is Tensor[], the dtype of bins must be the same as sample.\n"

def __check_weights(sample, weights):
if weights is None: return
if weights is None:
return
sample_shape, weights_shape = sample.shape, weights.shape
assert len(sample_shape) == len(weights_shape) + 1, (
"if weight tensor is provided,"
Expand All @@ -3832,17 +3837,20 @@ def __check_weights(sample, weights):
],
'histogramdd',
)
assert weights.dtype == sample.dtype, (
"The dtype of weights must be the same as sample.\n"
)
assert (
weights.dtype == sample.dtype
), "The dtype of weights must be the same as sample.\n"

def __check_range(D, range):
if range is None: return
if range is None:
return
check_type(range, 'range', (list, tuple), 'histogramdd')
assert D * 2 == len(range), (
"The length of range list must be %d\n" % (D * 2)
assert D * 2 == len(range), "The length of range list must be %d\n" % (
D * 2
)

check_type(density, 'density', bool, 'histogramdd')

__check_sample(sample)
# weights
__check_weights(sample, weights)
Expand All @@ -3867,19 +3875,17 @@ def __check_range(D, range):
range[:, 0] = minv
range[:, 1] = maxv
else:
range = paddle.static.setitem(
range, (slice(None), 0), minv
)
range = paddle.static.setitem(
range, (slice(None), 1), maxv
)
else: range = paddle.to_tensor(range, dtype=paddle.float32).reshape([D, 2])
range = paddle.static.setitem(range, (slice(None), 0), minv)
range = paddle.static.setitem(range, (slice(None), 1), maxv)
else:
range = paddle.to_tensor(range, dtype=paddle.float32).reshape([D, 2])
# bins to edges
edges = []
hist_shape = []
dedges = []
if isinstance(bins, (int, list)): # int or int[]
if isinstance(bins, int): bins = [bins] * D
if isinstance(bins, (int, list)): # int or int[]
if isinstance(bins, int):
bins = [bins] * D
assert len(bins) == D, (
"The length of bins must be %d when bins is a list.\n" % D
)
Expand All @@ -3891,26 +3897,28 @@ def __check_range(D, range):
e = paddle.linspace(r[0], r[1], bins[idx] + 1, 'float32')
edges.append(e)
dedges.append(e.diff())
elif isinstance(bins, tuple): # tuple with D tensors for each innermost dimension
elif isinstance(
bins, tuple
): # tuple with D tensors for each innermost dimension
__check_bins(bins, sample)
for bin in bins:
bin = paddle.to_tensor(bin)
edges.append(bin)
dedges.append(bin.diff())
else:
raise ValueError(
"Input bins must be Tensor[], int[], or int."
)
raise ValueError("Input bins must be Tensor[], int[], or int.")
hist_shape = [edge.shape[0] + 1 for edge in edges]
index_list = []
# edges shape: [D, linspaced]
# index_list shape: [D, N]
for idx, edge in enumerate(edges):
edge = paddle.to_tensor(edge)
index_list.append(paddle.searchsorted(edge, reshaped_input[:, idx], right=True))
index_list.append(
paddle.searchsorted(edge, reshaped_input[:, idx], right=True)
)
index_list = paddle.to_tensor(index_list)
for i in _range(D):
on_edge = (reshaped_input[:, i] == edges[i][-1])
on_edge = reshaped_input[:, i] == edges[i][-1]
if paddle.in_dynamic_mode():
index_list[i][on_edge] -= 1
else:
Expand All @@ -3920,21 +3928,23 @@ def __check_range(D, range):
index_list = tuple(index_list)
lut = paddle.arange(paddle.to_tensor(hist_shape).prod()).reshape(hist_shape)
flattened_index = lut[index_list]
hist = paddle.bincount(flattened_index, reshaped_weights, minlength=paddle.to_tensor(hist_shape).prod())
hist = paddle.bincount(
flattened_index,
reshaped_weights,
minlength=paddle.to_tensor(hist_shape).prod(),
)
hist = hist.reshape(hist_shape)
hist = hist.astype('float32')

core = D*(slice(1, -1),)
core = D * (slice(1, -1),)
hist = hist[core]

if density:
s = hist.sum()
for i in _range(D):
shape = D*[1]
shape = D * [1]
shape[i] = hist_shape[i] - 2
hist = hist / dedges[i].reshape(shape)
hist /= s

return (hist, edges)


return (hist, edges)
Loading

0 comments on commit 0ef396b

Please sign in to comment.