Skip to content

Commit

Permalink
Add negloglik calculations for general distributions (#5)
Browse files Browse the repository at this point in the history
* update losses

* update negloglik class to generalize to any distribution

* fix negloglik loss

* update poetry file
  • Loading branch information
gmgeorg committed Apr 25, 2024
1 parent 9cab675 commit 54d66df
Show file tree
Hide file tree
Showing 9 changed files with 276 additions and 93 deletions.
117 changes: 116 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ python = "^3.10"
numpy = ">=1.11.0"
pandas = ">=1.0.0"
tensorflow = ">=2.11.0"
tensorflow_probability = ">=0.18.0"
tf-keras = ">=2.14.1"
tqdm = ">=4.62"
tensorflow-addons = ">=0.15.0"
pypress = { git = "https://github.com/gmgeorg/pypress.git", rev = "v0.0.2" }
Expand Down
85 changes: 26 additions & 59 deletions pypsps/keras/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,66 +5,9 @@
import warnings

import tensorflow as tf
import math
from pypsps import utils


"""
import tensorflow_probability as tfp
distrs = []
for i in range(3):
distrs.append(tfp.distributions.Normal(
np.float32(y_pred_j[i, 0]), np.float32(y_pred_j[i, 1]), validate_args=False, allow_nan_stats=True, name='Normal'
))
def negloglik(y, rv_y):
print(y, rv_y.parameters)
return -rv_y.log_prob(y)
# Compare to tfp loglik
[negloglik(y_true_j[i], d) for i, d in enumerate(distrs)]
"""


def _negloglik(y: tf.Tensor, mu, sigma) -> tf.Tensor:
    """Computes negative log-likelihood of data y ~ Normal(mu, sigma).

    Evaluates, element-wise,
        0.5 * log(2 * pi) + log(sigma) + 0.5 * ((y - mu) / sigma) ** 2

    Args:
      y: observed values.
      mu: location (mean) of the Normal distribution.
      sigma: scale (standard deviation) of the Normal distribution.

    Returns:
      Tensor with the element-wise negative log-likelihood; its shape is the
      broadcast of y, mu and sigma.
    """
    # Keep the normalizing constant a plain Python float: it then adopts the
    # dtype of the tensors it is added to, instead of being pinned to float32
    # (which breaks float64 inputs with a dtype mismatch).
    half_log_two_pi = 0.5 * math.log(2.0 * math.pi)
    standardized = (y - mu) / sigma
    return half_log_two_pi + tf.math.log(sigma) + 0.5 * tf.square(standardized)


@tf.keras.utils.register_keras_serializable(package="pypsps")
class NegloglikNormal(tf.keras.losses.Loss):
    """Computes the negative log-likelihood of y ~ N(mu, sigma^2).

    Column 0 of ``y_pred`` is used as the mean and column 1 as the scale of
    the Normal distribution.
    """

    def call(self, y_true, y_pred):
        """Implements the loss function call.

        Args:
          y_true: observed values.
          y_pred: 2-D predictions; ``y_pred[:, 0]`` is the mean and
            ``y_pred[:, 1]`` is the scale.

        Returns:
          Element-wise losses for reduction 'none', their sum for 'sum', and
          their mean for 'sum_over_batch_size' / 'auto'.

        Raises:
          NotImplementedError: for any other reduction.
        """
        y_pred_mu = y_pred[:, 0]
        y_pred_scale = y_pred[:, 1]

        losses = _negloglik(y_true, y_pred_mu, y_pred_scale)
        if self.reduction == tf.keras.losses.Reduction.NONE:
            return losses
        if self.reduction == tf.keras.losses.Reduction.SUM:
            return tf.reduce_sum(losses)
        # AUTO is the default reduction of tf.keras.losses.Loss; treat it like
        # SUM_OVER_BATCH_SIZE so that a default-constructed loss is usable.
        if self.reduction in (
            tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
            tf.keras.losses.Reduction.AUTO,
        ):
            return tf.reduce_mean(losses)
        raise NotImplementedError(
            f"reduction='{self.reduction}' is not implemented"
        )


def negloglik_normal_each(y_true, y_pred):
    """Return the element-wise negative log-likelihood of y ~ Normal(mu, sigma^2).

    Column 0 of ``y_pred`` is taken as the mean and column 1 as the scale.
    """
    location, scale = y_pred[:, 0], y_pred[:, 1]
    return _negloglik(y_true, location, scale)


def negloglik_normal(y_true, y_pred):
    """Return the Normal negative log-likelihood summed over all elements."""
    per_element = negloglik_normal_each(y_true, y_pred)
    return tf.reduce_sum(per_element)


@tf.keras.utils.register_keras_serializable(package="pypsps")
class OutcomeLoss(tf.keras.losses.Loss):
"""Computes outcome loss for a pypsps model with multi-output predictions.
Expand Down Expand Up @@ -139,7 +82,21 @@ def call(self, y_true, y_pred):

@tf.keras.utils.register_keras_serializable(package="pypsps")
class CausalLoss(tf.keras.losses.Loss):
"""PSPS causal loss is the sum of outcome loss + treatment loss."""
"""PSPS causal loss is the sum of outcome loss + treatment loss.
Causal loss from PSPS is based on the joint distribution P(outcome, treatment | features)
which decomposes into
Pr(Y, T | X) = Pr(Y | T, X) * Pr(T | X)
which in log-likelihood terms is
loglik(Y, T; X) = loglik(Y; T, X) + alpha * loglik(T; X)
where alpha = 1 (by default). See Eq (10) in
https://proceedings.mlr.press/v177/kelly22a/kelly22a.pdf
for details (in paper lambda == alpha).
"""

def __init__(
self,
Expand All @@ -152,6 +109,17 @@ def __init__(
] = None,
**kwargs
):
"""Initializes the causal loss class.
Args:
outcome_loss: instance of an outcome loss
treatment_loss: instance of a treatment loss
alpha: penalty parameter for the treatment loss. Defaults to 1.0 so
that total causal loss equals the joint log-likelihood.
outcome_loss_weight: weight of outcome loss; defaults to 1.0.
predictive_states_regularizer: optional; user can define a predictive
state regularizer.
"""
super().__init__(**kwargs)
assert isinstance(outcome_loss, OutcomeLoss)
assert isinstance(treatment_loss, TreatmentLoss)
Expand Down Expand Up @@ -184,7 +152,6 @@ def call(self, y_true, y_pred):
loss_outcome = self._outcome_loss(y_true, y_pred)
loss_treatment = self._treatment_loss(y_true, y_pred)

# print(loss_treatment, loss_outcome)
total_loss = (
self._outcome_loss_weight * loss_outcome + self._alpha * loss_treatment
)
Expand Down
9 changes: 5 additions & 4 deletions pypsps/keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import pypress
import pypress.keras.layers
import pypress.keras.regularizers
from . import losses
from . import layers
from . import metrics

from . import losses, layers, metrics
from pypsps.keras import neglogliks


tfk = tf.keras
Expand All @@ -35,7 +35,8 @@ def _build_binary_continuous_causal_loss(
) -> losses.CausalLoss:
"""Builds an example of binary treatment & continuous outcome causal loss."""
psps_outcome_loss = losses.OutcomeLoss(
loss=losses.NegloglikNormal(reduction="none"), reduction="sum_over_batch_size"
loss=neglogliks.NegloglikNormal(reduction="none"),
reduction="sum_over_batch_size",
)
psps_treat_loss = losses.TreatmentLoss(
loss=tf.keras.losses.BinaryCrossentropy(reduction="none"),
Expand Down
71 changes: 71 additions & 0 deletions pypsps/keras/neglogliks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Module to implement distributions and log-likelihood loss fcts."""

import tensorflow as tf
import math
import tensorflow_probability as tfp
import numpy as np

tfd = tfp.distributions


@tf.keras.utils.register_keras_serializable(package="pypsps")
class NegloglikLoss(tf.keras.losses.Loss):
    """Computes the negative log-likelihood of y ~ Distribution.

    This is a general purpose class for any tfd.Distribution whose parameters
    can be passed positionally: column j of ``y_pred`` becomes the j-th
    positional argument of the distribution constructor.

    NOTE(review): ``get_config`` is not overridden here, so the distribution
    constructor is presumably not part of the serialized config -- confirm
    before relying on save/load round-trips.
    """

    def __init__(self, distribution_constructor: tfd.Distribution, **kwargs):
        """Initializes the loss.

        Args:
          distribution_constructor: callable that is invoked with one
            positional argument per column of ``y_pred`` and returns an object
            with a ``log_prob`` method (e.g. a tfd.Distribution class).
          **kwargs: passed on to ``tf.keras.losses.Loss``.
        """
        self._distribution_constructor = distribution_constructor
        super().__init__(**kwargs)

    def call(self, y_true, y_pred):
        """Implements the loss function call.

        Args:
          y_true: observed values passed to ``log_prob``.
          y_pred: 2-D array/tensor of distribution parameters, one parameter
            per column. The number of columns must be statically known.

        Returns:
          Element-wise losses for reduction 'none', their sum for 'sum', and
          their mean for 'sum_over_batch_size' / 'auto'.

        Raises:
          NotImplementedError: for any other reduction.
        """
        # Split into one tensor per column, dropping ONLY the column axis.
        # (A bare tf.squeeze would also drop the batch axis for batch size 1.)
        y_pred_cols = tf.unstack(y_pred, axis=1)
        distr = self._distribution_constructor(*y_pred_cols)
        losses = -distr.log_prob(y_true)

        if self.reduction == tf.keras.losses.Reduction.NONE:
            return losses
        if self.reduction == tf.keras.losses.Reduction.SUM:
            return tf.reduce_sum(losses)
        if self.reduction in (
            tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
            tf.keras.losses.Reduction.AUTO,
        ):
            return tf.reduce_mean(losses)
        raise NotImplementedError(
            f"reduction='{self.reduction}' is not implemented"
        )


def _negloglik(y: tf.Tensor, mu: tf.Tensor, sigma: tf.Tensor) -> tf.Tensor:
    """Computes negative log-likelihood of data y ~ Normal(mu, sigma).

    Evaluates, element-wise,
        0.5 * log(2 * pi) + log(sigma) + 0.5 * ((y - mu) / sigma) ** 2

    Args:
      y: observed values.
      mu: location (mean) of the Normal distribution.
      sigma: scale (standard deviation) of the Normal distribution.

    Returns:
      Tensor with the element-wise negative log-likelihood; its shape is the
      broadcast of y, mu and sigma.
    """
    # A Python float (rather than a float32 tf constant) takes on the dtype of
    # the tensor it is combined with, so float64 inputs work as well.
    log_norm_const = 0.5 * math.log(2.0 * math.pi)
    z = (y - mu) / sigma
    return log_norm_const + tf.math.log(sigma) + 0.5 * tf.square(z)


@tf.keras.utils.register_keras_serializable(package="pypsps")
class NegloglikNormal(tf.keras.losses.Loss):
    """Computes the negative log-likelihood of y ~ N(mu, sigma^2).

    Column 0 of ``y_pred`` is used as the mean and column 1 as the scale of
    the Normal distribution.
    """

    def call(self, y_true, y_pred):
        """Implements the loss function call.

        Args:
          y_true: observed values.
          y_pred: 2-D predictions; ``y_pred[:, 0]`` is the mean and
            ``y_pred[:, 1]`` is the scale.

        Returns:
          Element-wise losses for reduction 'none', their sum for 'sum', and
          their mean for 'sum_over_batch_size' / 'auto'.

        Raises:
          NotImplementedError: for any other reduction.
        """
        y_pred_mu = y_pred[:, 0]
        y_pred_scale = y_pred[:, 1]

        # NOTE(review): y_true is combined with 1-D slices of y_pred; this
        # presumably expects y_true to be 1-D as well -- confirm with callers.
        losses = _negloglik(y_true, y_pred_mu, y_pred_scale)
        if self.reduction == tf.keras.losses.Reduction.NONE:
            return losses
        if self.reduction == tf.keras.losses.Reduction.SUM:
            return tf.reduce_sum(losses)
        if self.reduction in (
            tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE,
            tf.keras.losses.Reduction.AUTO,
        ):
            return tf.reduce_mean(losses)
        # Interpolate the reduction into the message; passing it as a second
        # exception argument (logging style) leaves the '%s' unformatted.
        raise NotImplementedError(
            f"reduction='{self.reduction}' is not implemented"
        )
Loading

0 comments on commit 54d66df

Please sign in to comment.