Skip to content

Commit

Permalink
fix conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
NKNaN committed Dec 19, 2023
1 parent 70d5abf commit 3fc4ad6
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions python/paddle/distribution/kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import paddle
from paddle.distribution.bernoulli import Bernoulli
from paddle.distribution.beta import Beta
from paddle.distribution.binomial import Binomial
from paddle.distribution.categorical import Categorical
from paddle.distribution.cauchy import Cauchy
from paddle.distribution.continuous_bernoulli import ContinuousBernoulli
Expand All @@ -28,6 +29,7 @@
from paddle.distribution.lognormal import LogNormal
from paddle.distribution.multivariate_normal import MultivariateNormal
from paddle.distribution.normal import Normal
from paddle.distribution.poisson import Poisson
from paddle.distribution.uniform import Uniform
from paddle.framework import in_dynamic_mode

Expand Down Expand Up @@ -167,6 +169,11 @@ def _kl_beta_beta(p, q):
)


@register_kl(Binomial, Binomial)
def _kl_binomial_binomial(p, q):
return p.kl_divergence(q)


@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
return (
Expand Down Expand Up @@ -269,5 +276,10 @@ def _kl_lognormal_lognormal(p, q):
return p._base.kl_divergence(q._base)


@register_kl(Poisson, Poisson)
def _kl_poisson_poisson(p, q):
return p.kl_divergence(q)


def _sum_rightmost(value, n):
return value.sum(list(range(-n, 0))) if n > 0 else value

0 comments on commit 3fc4ad6

Please sign in to comment.