
[core] historic scalers do not properly handle transfer learning #1035

Open
jmoralez opened this issue Jun 12, 2024 · 0 comments
jmoralez commented Jun 12, 2024

What happened + What you expected to happen

When performing transfer learning, if we have historic scalers we should refit them on the new data, so that the applied scaling has the expected statistics. Currently we reuse the scalers fitted on the original data, which produce statistics different from those expected.
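To illustrate the mechanism (this is a minimal sketch with plain NumPy, not neuralforecast's internal scaler code): a standard scaler fitted on the original series carries that series' mean and std, so applying it to a shifted new series produces values far from the zero-mean, unit-variance inputs the model was trained on, while refitting on the new data recovers the expected statistics.

```python
import numpy as np

def fit_standard(y):
    """Fit a standard scaler: return (mean, std) of the series."""
    return y.mean(), y.std()

def transform(y, mean, std):
    """Apply standard scaling with the given statistics."""
    return (y - mean) / std

rng = np.random.default_rng(0)
y = rng.normal(size=100)   # original training series
y2 = y + 100               # "new" series used for transfer learning

# Current behavior: reuse the scaler fitted on the original data.
mean, std = fit_standard(y)
scaled_reused = transform(y2, mean, std)   # mean ~ 100, far off-distribution

# Expected behavior: refit the scaler on the new data.
mean2, std2 = fit_standard(y2)
scaled_refit = transform(y2, mean2, std2)  # mean ~ 0, as the model expects

print(scaled_reused.mean(), scaled_refit.mean())
```

The reused scaler feeds the model inputs centered around ~100 standard deviations away from its training distribution, which matches the large RMSE gap in the reproduction script below.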

Versions / Dependencies

1.7.2

Reproduction script

```python
import logging
import os

from neuralforecast import NeuralForecast
from neuralforecast.models import LSTM
from utilsforecast.data import generate_series
from utilsforecast.losses import rmse

logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)
os.environ['NIXTLA_ID_AS_COL'] = '1'

# data
series = generate_series(10, min_length=200, max_length=500)
h = 7
valid = series.groupby('unique_id', observed=True).tail(h)
train = series.drop(valid.index)
train2 = train.copy()
train2['y'] += 100
valid2 = valid.copy()
valid2['y'] += 100

# training
nf = NeuralForecast(
    models=[LSTM(input_size=2 * h, h=h, scaler_type=None, max_steps=50, val_check_steps=1, enable_progress_bar=False)],
    freq='D',
    local_scaler_type='standard',
)
nf.fit(train)

# predictions, these should be the same
preds = nf.predict()
preds2 = nf.predict(df=train2)

# comparison
def evaluate(preds, valid):
    return rmse(preds.merge(valid), models=['LSTM'])['LSTM'].mean()

evaluate(preds, valid), evaluate(preds2, valid2)
# (0.7740806216300222, 99.65305363444284)
```

Issue Severity

High: It blocks me from completing my task.

@jmoralez jmoralez added the bug label Jun 12, 2024
@jmoralez jmoralez self-assigned this Jun 12, 2024