diff --git a/nets/models/feedforward.py b/nets/models/feedforward.py index 6dcefe8..3d261d0 100644 --- a/nets/models/feedforward.py +++ b/nets/models/feedforward.py @@ -64,7 +64,7 @@ def __init__( class MLP(eqx.Module): """Multi-layer perceptron.""" - layers: tuple[Linear, ...] + layers: tuple[enn.Linear, ...] dropouts: tuple[enn.Dropout, ...] activation: Callable final_activation: Callable diff --git a/nets/simulators/online_sgd.py b/nets/simulators/online_sgd.py index ccd9ef9..efc496d 100644 --- a/nets/simulators/online_sgd.py +++ b/nets/simulators/online_sgd.py @@ -234,7 +234,7 @@ def simulate( exemplar_noise_scale: float, # Sampler params. sampler_cls: type[samplers.EpochSampler], # TODO(eringrant): Use `SingletonSampler`. -) -> tuple[pd.DataFrame, ...]: +) -> tuple[eqx.Module, pd.DataFrame]: """Simulate in-context learning of classification tasks.""" logging.info(f"Using JAX backend: {jax.default_backend()}\n") diff --git a/py.typed b/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/setup.cfg b/setup.cfg index acde8df..1e390aa 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,5 +30,8 @@ install_requires = equinox>=0.11.0 optax>=0.1.5 +[options.package_data] +nets = py.typed + [options.packages.find] include = nets*