Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: use load_model() function #231

Merged
merged 4 commits into from
Jan 24, 2022
Merged

Conversation

YuriCat
Copy link
Contributor

@YuriCat YuriCat commented Jan 9, 2022

load_model() is more natural than get_model().

Especially for ONNX models, the model can be loaded with just the model_path.
Therefore, it makes more sense to have the model_path as the first argument, as in build_agent(), but I'm not sure.

@ikki407
Copy link
Member

ikki407 commented Jan 13, 2022

IMHO, it's natural to pass the model first and then the weights (load_model(model, model_path)).
How about using a model as an argument?

E.g.

model = env.net()
model = load_model(model_path, model)
# or
model = load_model(env.net(), model_path)

@YuriCat
Copy link
Contributor Author

YuriCat commented Jan 13, 2022

I was worried that unnecessary torch would be imported when we loaded ONNX models.
However, as long as the neural net is defined in the environment file, torch is already imported when an environment instance is created.

@ikki407
Copy link
Member

ikki407 commented Jan 13, 2022

Ah, I got it!!

@ikki407
Copy link
Member

ikki407 commented Jan 13, 2022

So, if the definition of the neural network is separated from the environment module, load_model() can return an ONNX model without importing PyTorch when the extension is .onnx.

@YuriCat
Copy link
Contributor Author

YuriCat commented Jan 13, 2022

That's right. This torch lazy import will be written by whoever needs it, and I think your idea is cool.

@@ -277,10 +277,9 @@ def network_match_acception(n, env_args, num_agents, port):
return agents_list


def get_model(env, model_path):
def load_model(model_path, model):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

@ikki407
Copy link
Member

ikki407 commented Jan 24, 2022

LGTM

@ikki407 ikki407 merged commit 57a1967 into DeNA:develop Jan 24, 2022
@YuriCat YuriCat deleted the feature/load_model branch February 11, 2022 17:56
This pull request was closed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants