Skip to content

Commit

Permalink
belated updates
Browse files Browse the repository at this point in the history
  • Loading branch information
eringrant committed May 30, 2024
1 parent 1c36bd3 commit 529903d
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion nets/models/feedforward.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
class MLP(eqx.Module):
"""Multi-layer perceptron."""

layers: tuple[Linear, ...]
layers: tuple[enn.Linear, ...]
dropouts: tuple[enn.Dropout, ...]
activation: Callable
final_activation: Callable
Expand Down
2 changes: 1 addition & 1 deletion nets/simulators/online_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def simulate(
exemplar_noise_scale: float,
# Sampler params.
sampler_cls: type[samplers.EpochSampler], # TODO(eringrant): Use `SingletonSampler`.
) -> tuple[pd.DataFrame, ...]:
) -> tuple[eqx.Module, pd.DataFrame]:
"""Simulate in-context learning of classification tasks."""
logging.info(f"Using JAX backend: {jax.default_backend()}\n")

Expand Down
Empty file added py.typed
Empty file.
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,8 @@ install_requires =
equinox>=0.11.0
optax>=0.1.5

[options.package_data]
nets = py.typed

[options.packages.find]
include = nets*

0 comments on commit 529903d

Please sign in to comment.