Skip to content

Commit

Permalink
Fix CI issues
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Jul 19, 2024
1 parent 085d48d commit 96088f6
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 15 deletions.
2 changes: 1 addition & 1 deletion lampe/inference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from .amnre import *
from .bnre import *
from .cnre import *
from .dcp import *
from .fmpe import *
from .mcmc import *
from .npe import *
from .nre import *
from .dcp import *
28 changes: 16 additions & 12 deletions lampe/inference/dcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@
compiler installed. For more information please refer to the `installation section <https://github.com/teddykoker/torchsort?tab=readme-ov-file#install>`_
in the official repository.
Please keep in mind, that compiler's CUDA version has to match that of pytorch!
Please keep in mind, that compiler's CUDA version has to match that of PyTorch!
For certain Python + CUDA + pytorch versions combinations there are
`pre-build wheels <https://github.com/teddykoker/torchsort?tab=readme-ov-file#pre-built-wheels>`_
available.
Below is an example of how to install `torchsort` in a `conda environment <https://conda.io/projects/conda/en/latest/user-guide/getting-started.html>`_
with Python 3.11 + CUDA 12.1.1 + pytorch 2.0.0:
Below is an example of how to install `torchsort` in a `conda environment
<https://conda.io/projects/conda/en/latest/user-guide/getting-started.html>`_ with
Python 3.11 + CUDA 12.1.1 + PyTorch 2.0.0:
.. code-block:: bash
Expand All @@ -64,14 +64,18 @@
'DCPNPELoss',
]

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchsort

from torch import Tensor, Size
from typing import *
from torch import Size, Tensor
from torch.distributions import Distribution
from typing import Tuple

try:
import torchsort
except ImportError:
pass


class STEhardtanh(torch.autograd.Function):
Expand All @@ -93,7 +97,7 @@ class DCPNRELoss(nn.Module):
.. math::
l & = \frac{1}{2N} \sum_{i = 1}^N
\ell(d_\phi(\theta_i, x_i)) + \ell(1 - d_\phi(\theta_{i+1}, x_i)) \\
& + \lambda 1/M \sum_{j=1}^M (\text{ECP}(1 - \alpha_j) - (1 - \alpha_j))^2
& + \lambda \frac{1}{M} \sum_{j=1}^M (\text{ECP}(1 - \alpha_j) - (1 - \alpha_j))^2
where :math:`\ell(p) = -\log p` is the negative log-likelihood and
:math:`\text{ECP}(1 - \alpha_j)` is the Expected Coverage Probability at
Expand All @@ -116,8 +120,8 @@ class DCPNRELoss(nn.Module):
def __init__(
self,
estimator: nn.Module,
prior: torch.distributions.Distribution,
proposal: torch.distributions.Distribution = None,
prior: Distribution,
proposal: Distribution = None,
lmbda: float = 5.0,
n_samples: int = 16,
calibration: bool = False,
Expand Down Expand Up @@ -203,7 +207,7 @@ class DCPNPELoss(nn.Module):
.. math::
l & = \frac{1}{N} \sum_{i = 1}^N -\log p_\phi(\theta_i | x_i) \\
& + \lambda 1/M \sum_{j=1}^M (\text{ECP}(1 - \alpha_j) - (1 - \alpha_j))^2
& + \lambda \frac{1}{M} \sum_{j=1}^M (\text{ECP}(1 - \alpha_j) - (1 - \alpha_j))^2
where :math:`\ell(p) = -\log p` is the negative log-likelihood and
:math:`\text{ECP}(1 - \alpha_j)` is the Expected Coverage Probability at
Expand Down
2 changes: 0 additions & 2 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

def test_NRE():
estimator = NRE(3, 5)
prior = torch.distributions.MultivariateNormal(torch.zeros(3), torch.eye(3))

# Non-batched
theta, x = randn(3), randn(5)
Expand Down Expand Up @@ -128,7 +127,6 @@ def test_NPE():

def test_NPELoss():
estimator = NPE(3, 5)
prior = torch.distributions.MultivariateNormal(torch.zeros(3), torch.eye(3))

losses = [
NPELoss(estimator),
Expand Down

0 comments on commit 96088f6

Please sign in to comment.