
Fix/reproducibility RNN #118

Merged (15 commits merged into develop, Jul 3, 2020)
Conversation

guillaumeraille (Contributor)

Fixes #DARTS-123.

Summary

Adds the possibility to specify a random_state at model creation for the RNN model, using the same API as sklearn, for easy usage across the whole DARTS library.
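The sklearn convention referenced here accepts `None`, an `int` seed, or a `RandomState` instance. A standalone sketch of that normalization (modeled on sklearn's `check_random_state` helper; illustration only, not darts code):

```python
import numpy as np

def check_random_state(seed):
    # sklearn convention: None -> a fresh RNG, int -> a seeded RNG,
    # RandomState instance -> used as-is
    if seed is None:
        return np.random.RandomState()
    if isinstance(seed, (int, np.integer)):
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        return seed
    raise ValueError(f"{seed!r} cannot be used to seed a RandomState")

# same seed -> identical draws, hence reproducible models
a = check_random_state(42).randint(0, 100, size=3)
b = check_random_state(42).randint(0, 100, size=3)
assert (a == b).all()
```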

Other Information

@guillaumeraille guillaumeraille changed the base branch from master to develop June 30, 2020 13:52
```python
"""

kwargs['output_length'] = output_length
kwargs['input_size'] = input_size
kwargs['output_size'] = output_size

# TODO: make it a util function? -> reusable in other torch models that need a fixed seed...
# set the random seed
```
Contributor Author

I think it should be a util function used in every torch model that needs a fixed seed. What do you think?

Contributor Author

Probably better as part of the superclass.

@TheMP (Contributor), Jul 1, 2020

Agreed. I think it could be one of the kwargs in ForecastingModel and set there if possible; we just need to make sure that fixing the seed in one class will not leak outside the scope of the current instance and affect all of the others.
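The leakage concern can be illustrated with a dependency-free sketch. `with_forked_rng` below is a hypothetical stdlib-`random` analogue of what `torch.random.fork_rng()` plus `manual_seed` provide: a snapshot-seed-restore pattern that keeps the seeding local to one call.

```python
import random

def with_forked_rng(seed, fn):
    # snapshot the global RNG state, seed it locally, run fn,
    # then restore the snapshot so the seeding cannot leak out
    state = random.getstate()
    try:
        random.seed(seed)
        return fn()
    finally:
        random.setstate(state)

random.seed(0)
expected = [random.random() for _ in range(3)]

random.seed(0)
observed = [random.random()]
with_forked_rng(999, random.random)       # draws under its own local seed
observed += [random.random(), random.random()]

assert observed == expected  # the outer stream was not disturbed
```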

Contributor

Perhaps it is enough to set it in TorchForecastingModel (at least for now)

@pennfranc (Contributor)

This was definitely missing, thanks! Just one thing: do you think it would be possible to add this functionality to the superclass TorchForecastingModel instead of RNNModel? That way the TCN will automatically inherit it too, as well as all future torch-based models. Sorry, I should have raised this idea this morning already.

@guillaumeraille (Contributor, Author)

> This was definitely missing, thanks! Just one thing: do you think it would be possible to add this functionality to the superclass TorchForecastingModel instead of RNNModel? That way the TCN will automatically inherit it too, as well as all future torch-based models. Sorry, I should have raised this idea this morning already.

I was thinking about that, but then it is supposed to be there only for models that have some randomness (probably most of the torch-implemented models will); that's why I proposed to implement it as a util function used only on a selection of torch models. If you think they will probably all need it, then yes, we should move it. What do you think?

@pennfranc (Contributor)

> > This was definitely missing, thanks! Just one thing: do you think it would be possible to add this functionality to the superclass TorchForecastingModel instead of RNNModel? That way the TCN will automatically inherit it too, as well as all future torch-based models. Sorry, I should have raised this idea this morning already.
>
> I was thinking about that, but then it is supposed to be there only for models that have some randomness (probably most of the torch-implemented models will); that's why I proposed to implement it as a util function used only on a selection of torch models. If you think they will probably all need it, then yes, we should move it. What do you think?

Hmm, yeah, I see what you mean. To be honest I'm not sure what's best. Any ideas @hrzn?

@guillaumeraille (Contributor, Author)

> > > This was definitely missing, thanks! Just one thing: do you think it would be possible to add this functionality to the superclass TorchForecastingModel instead of RNNModel? That way the TCN will automatically inherit it too, as well as all future torch-based models. Sorry, I should have raised this idea this morning already.
> >
> > I was thinking about that, but then it is supposed to be there only for models that have some randomness (probably most of the torch-implemented models will); that's why I proposed to implement it as a util function used only on a selection of torch models. If you think they will probably all need it, then yes, we should move it. What do you think?
>
> Hmm, yeah, I see what you mean. To be honest I'm not sure what's best. Any ideas @hrzn?

After some thought, I think adding it to the superclass is better, as it will cover most of the use cases. If an inherited model is deterministic it will still work, and if you really want to enforce that you can't specify a random_state, you can add a check in the inherited model's __init__.
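A guard of that kind could look like this minimal sketch (class names hypothetical, not darts code):

```python
class Parent:
    # stand-in for TorchForecastingModel
    def __init__(self, random_state=None):
        self.random_state = random_state

class DeterministicModel(Parent):
    def __init__(self, random_state=None, **kwargs):
        # this model has no randomness, so enforce that random_state
        # cannot be specified
        if random_state is not None:
            raise ValueError("DeterministicModel is deterministic; "
                             "random_state is not supported")
        super().__init__(**kwargs)

DeterministicModel()  # fine; passing random_state=42 raises ValueError
```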

@@ -145,12 +149,22 @@ def __init__(self,
Sizes of hidden layers connecting the last hidden layer of the RNN module to the output, if any.
dropout
Fraction of neurons affected by Dropout.
random_state
Control the randomness of the weights initialization. Check this
`link <https://scikit-learn.org/stable/glossary.html#term-random-state>`_ for more details.
Contributor

Please correct me if I am wrong, but I think that random_state in sklearn affects only the function it is passed to. Here, however, I see that the torch seed will be set by random_state for all torch-related pseudorandom number generation.

@guillaumeraille (Contributor, Author), Jun 30, 2020

Yes, you are absolutely right. I didn't find a cleaner way to avoid the side effect; do you know of any?


Contributor

This looks good. It would need to be added to the fit function as well; I think we use shuffle=True in there, but it looks possible with fork_rng. What do you think about just using manual_seed before using the model, rather than in the model itself? I feel like it would be much simpler.

Contributor

+1 for manual_seed

Contributor Author

> This looks good. It would need to be added to the fit function as well; I think we use shuffle=True in there, but it looks possible with fork_rng. What do you think about just using manual_seed before using the model, rather than in the model itself? I feel like it would be much simpler.

It would be simpler indeed, but then we would not have a unified API: models from sklearn would need to be passed a random_state, while torch models would need torch.manual_seed before usage for reproducibility.

@guillaumeraille commented Jul 1, 2020

In order to make sure there is no side effect, and to provide the end user with the same API as sklearn to fix a random state, I propose the following. Let me know what you think @Kostiiii, @TheMP, @pennfranc, @hrzn (might be too much?).

```python
import numpy as np
import torch
from functools import wraps

# in darts.utils.torch
MAX_TORCH_SEED_VALUE = (1 << 63) - 1

def random_method(decorated):
    # decorator added by the darts developer to each method that uses a
    # random number generator (rng)
    @wraps(decorated)
    def decorator(self, *args, **kwargs):
        if hasattr(self, "_random_instance"):
            # parent class has been initialized already, so it holds a
            # random instance -> use it
            with torch.random.fork_rng():
                torch.random.manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
                return decorated(self, *args, **kwargs)
        elif "random_state" in kwargs:
            # parent class has not been initialized but a random_state was
            # provided as an argument -> use it
            self._random_instance = np.random.RandomState(kwargs["random_state"])
            with torch.random.fork_rng():
                torch.random.manual_seed(self._random_instance.randint(0, high=MAX_TORCH_SEED_VALUE))
                return decorated(self, *args, **kwargs)
        else:
            # parent class has not been initialized and no random_state was
            # provided -> default randomness (not reproducible)
            return decorated(self, *args, **kwargs)
    return decorator

# parent = TorchForecastingModel
class Parent:
    def __init__(self, random_state=None):
        if not hasattr(self, "_random_instance"):
            # a random instance is associated with the model and used in
            # each method that requires randomness
            self._random_instance = np.random.RandomState(random_state)

# children = a specific model (e.g. RNNModel, GRU, ...); the darts developer
# adds @random_method to each method that uses an rng
class Children(Parent):
    @random_method
    def __init__(self, **kwargs):
        print("create some model with random initial weights: {}".format(torch.randn(5)))
        super().__init__(**kwargs)

    @random_method
    def fit(self):
        print("train model with randomized batches {}".format(torch.randn(5)))
```

Usage for a darts user:

```python
children = Children(...params, random_state=42)
children.fit(some_data)
children.predict(...)
```

[EDIT]
Actually it also works without the Parent class code, and that is probably better, as it can be applied and generalised to any method using torch.
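For illustration, here is a dependency-free variant of the same decorator pattern, with torch swapped for the stdlib `random` module (all names here are hypothetical, not darts code). It shows the key property: two instances created with the same `random_state` reproduce the same draws, while the global RNG state is restored after each call.

```python
import random
from functools import wraps

MAX_SEED_VALUE = (1 << 32) - 1

def random_method(decorated):
    # draw a per-call seed from the model's own RNG, snapshot the global
    # state, seed it, run the method, then restore the snapshot
    @wraps(decorated)
    def decorator(self, *args, **kwargs):
        if not hasattr(self, "_random_instance"):
            self._random_instance = random.Random(kwargs.get("random_state"))
        state = random.getstate()
        try:
            random.seed(self._random_instance.randint(0, MAX_SEED_VALUE))
            return decorated(self, *args, **kwargs)
        finally:
            random.setstate(state)
    return decorator

class Model:
    @random_method
    def __init__(self, random_state=None):
        self.init_weights = [random.random() for _ in range(3)]

    @random_method
    def fit(self):
        return [random.random() for _ in range(3)]

m1, m2 = Model(random_state=42), Model(random_state=42)
assert m1.init_weights == m2.init_weights  # same seed -> same init
assert m1.fit() == m2.fit()                # ...and same training draws
```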

@hrzn (Contributor) commented Jul 1, 2020

The decorator approach looks quite neat @guillaumeraille, I think you can go for it.

@guillaumeraille guillaumeraille merged commit 7af942f into develop Jul 3, 2020
@LeoTafti LeoTafti deleted the fix/reproducabilityRNN branch October 15, 2020 08:30