Skip to content

Commit

Permalink
move to all relative imports and add default loss fcts (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
gmgeorg committed Apr 26, 2024
1 parent e425aef commit 1b929cb
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pypsps/datasets/kang_schafer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

import pandas as pd
from pypsps.datasets import base
from . import base


class KangSchafer(base.BaseSimulator):
Expand Down
3 changes: 2 additions & 1 deletion pypsps/datasets/lalonde.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
See also here for details
https://rugg2.github.io/Lalonde%20dataset%20-%20Causal%20Inference.html
"""

from typing import Dict
import os
import pandas as pd

from pypsps.datasets import base
from . import base


_BASE_URL = "http://www.nber.org/~rdehejia/data"
Expand Down
2 changes: 1 addition & 1 deletion pypsps/datasets/lunceford_davidian.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd
import enum

from pypsps.datasets import base
from . import base


class Association(enum.Enum):
Expand Down
18 changes: 13 additions & 5 deletions pypsps/keras/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import warnings

import tensorflow as tf
from pypsps import utils
from .. import utils
from . import neglogliks


@tf.keras.utils.register_keras_serializable(package="pypsps")
Expand Down Expand Up @@ -100,8 +101,14 @@ class CausalLoss(tf.keras.losses.Loss):

def __init__(
self,
outcome_loss: OutcomeLoss,
treatment_loss: TreatmentLoss,
outcome_loss: OutcomeLoss = OutcomeLoss(
loss=neglogliks.NegloglikNormal(reduction="none"),
reduction="sum_over_batch_size",
),
treatment_loss: TreatmentLoss = TreatmentLoss(
loss=tf.keras.losses.BinaryCrossentropy(reduction="none"),
reduction="sum_over_batch_size",
),
alpha: float = 1.0,
outcome_loss_weight: float = 1.0,
predictive_states_regularizer: Optional[
Expand All @@ -112,8 +119,9 @@ def __init__(
"""Initializes the causal loss class.
Args:
outcome_loss: instance of an outcome loss
treatment_loss: instance of a treatment loss
outcome_loss: instance of an outcome loss; defaults to a Normal log-likelihood.
treatment_loss: instance of a treatment loss; defaults to binary treatment loss
(ie binary cross entropy).
alpha: penalty parameter for the treatment loss. Defaults to 1.0 so
that total causal loss equals the joint log-likelihood.
outcome_loss_weight: weight of outcome loss; defaults to 1.0.
Expand Down
2 changes: 1 addition & 1 deletion pypsps/keras/metrics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Module for metrics from pypsps predictions."""

import tensorflow as tf
from pypsps import utils
from .. import utils


@tf.keras.utils.register_keras_serializable(package="pypsps")
Expand Down
3 changes: 1 addition & 2 deletions pypsps/keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import pypress.keras.layers
import pypress.keras.regularizers

from . import losses, layers, metrics
from pypsps.keras import neglogliks
from . import losses, layers, metrics, neglogliks


tfk = tf.keras
Expand Down

0 comments on commit 1b929cb

Please sign in to comment.