diff --git a/nets/models/feedforward.py b/nets/models/feedforward.py index 9b5b881..6dcefe8 100644 --- a/nets/models/feedforward.py +++ b/nets/models/feedforward.py @@ -108,7 +108,7 @@ def __init__( for key, ins, outs, drop in zip( hidden_keys, - (in_features,) + hidden_features[:-1], + (in_features, *tuple(hidden_features[:-1])), hidden_features, dropout_probs[:-1], strict=True,