Issue unpickling a wrapped keras model into a Shiny app #1479

Open
sarahxj opened this issue Jun 24, 2024 · 1 comment

sarahxj commented Jun 24, 2024

I am building a Shiny app to host a pretrained model so that users can run the model and view its outputs without dealing with the code. The model itself is a Keras model built by a function named model(). It is wrapped in a scikeras KerasClassifier, which is in turn wrapped in a hiclass LocalClassifierPerParentNode; the result is then trained and pickled. Code mockup of how the model is built:

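import pickle

import tensorflow as tf
from hiclass import LocalClassifierPerParentNode
from scikeras.wrappers import KerasClassifier

def model():
    # ............. (model construction elided in the mockup)
    return model

# Remaining KerasClassifier arguments ("args" in the mockup) are elided
scikeras_model = KerasClassifier(model=model)
scikeras_model_lcppn = LocalClassifierPerParentNode(scikeras_model)

# x_train / y_train: training data (not shown)
scikeras_model_lcppn_res = scikeras_model_lcppn.fit(x_train, y_train)

# Write the fitted wrapper and close the file handle afterwards
with open('filename.sav', 'wb') as f:
    pickle.dump(scikeras_model_lcppn_res, f)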
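A minimal, stdlib-only sketch of what that final pickle.dump stores, assuming (as with scikit-learn-style wrappers such as KerasClassifier) that the fitted object keeps a reference to the `model` function. Pickle saves a module-level function as a module name plus an attribute name, not as code, and looks it up again by that name when loading. The module `train_script` and the dict standing in for the fitted wrapper are hypothetical; in this issue the defining module would be `__main__`.

```python
import pickle
import pickletools
import sys
import types

# Hypothetical stand-in for the module that defined model() when the object
# was pickled; in the issue this role is played by __main__.
train_script = types.ModuleType("train_script")
exec(
    "def model():\n"
    "    return 'a Keras model would be built here'\n",
    train_script.__dict__,
)
sys.modules["train_script"] = train_script

# Stand-in for the fitted wrapper: it only holds a reference to model().
fitted = {"model": train_script.model}
payload = pickle.dumps(fitted)

# The disassembly shows the strings 'train_script' and 'model' and a
# STACK_GLOBAL opcode: a lookup by name, not the function's code.
pickletools.dis(payload)

# Loading works while train_script.model can be found again by name...
restored = pickle.loads(payload)
print(restored["model"]())

# ...and fails once that attribute is missing from the module.
del train_script.model
try:
    pickle.loads(payload)
except AttributeError as exc:
    print("AttributeError:", exc)
```

The last step is the same kind of lookup failure as the error reported here, with `__main__` in place of `train_script`.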

Outside of a Shiny app, the model can be successfully pickled and unpickled as long as the model() function is imported where it is unpickled. I have created a model.py file that contains the model() function definition and import it from this file. I have built a Shiny app and attempted to load the model into the app, but I get the following error:

AttributeError: Can't get attribute 'model' on <module '__main__' (built-in)>

This suggests that the script is not picking up the imported model() function, even though it is explicitly imported (from model import model) on the preceding line. I have also tried defining the model() function directly in the app code instead of importing it, and I get the same error.
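The error names `__main__`, which suggests the object was pickled from a script that was run directly, so pickle recorded the function as `__main__.model`. Both `from model import model` and a local definition bind `model` in the app module's own namespace, whereas unpickling looks it up on whichever module is `__main__` in the running process; in a Shiny app that is presumably the server's entry point rather than the app file. One possible workaround, sketched here and not confirmed in this thread, is to override the documented `pickle.Unpickler.find_class` hook. `MainRemapUnpickler`, `load_remapped` and the `replacements` parameter are hypothetical names, and the hand-written bytes are a protocol-0 pickle containing only a reference to `__main__.model`.

```python
import io
import pickle


class MainRemapUnpickler(pickle.Unpickler):
    """Unpickler that resolves selected ``__main__`` names from a mapping.

    ``replacements`` (hypothetical parameter) maps names that were pickled
    as ``__main__.<name>`` to the objects the app actually imports.
    """

    def __init__(self, file, replacements):
        super().__init__(file)
        self.replacements = replacements

    def find_class(self, module, name):
        # Redirect only the names we were given; defer everything else
        # to the normal module/attribute lookup.
        if module == "__main__" and name in self.replacements:
            return self.replacements[name]
        return super().find_class(module, name)


def load_remapped(data, replacements):
    """Unpickle ``data`` (bytes), remapping ``__main__`` names as above."""
    return MainRemapUnpickler(io.BytesIO(data), replacements).load()


def model():
    """Stand-in for the function the app gets via ``from model import model``."""
    return "a Keras model would be built here"


# Protocol-0 pickle holding only a reference to __main__.model: a GLOBAL
# opcode ("c" + module + newline + name + newline) followed by STOP (".").
main_ref = b"c__main__\nmodel\n."

restored = load_remapped(main_ref, {"model": model})
print(restored())
```

In the app this would look like `with open('filename.sav', 'rb') as f: clf = MainRemapUnpickler(f, {'model': model}).load()`. A blunter alternative is `import __main__; __main__.model = model` before calling `pickle.load`, at the cost of mutating a global module. Loading the object once in a plain script that does `from model import model` and dumping it again should instead record the reference as `model.model`, which resolves wherever `model.py` is importable.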

I have also tried pickling just the Keras model, unpickling it in the app, and then wrapping and training it there. This avoids the AttributeError above because model() no longer has to be called, but it requires training the model in the app, which defeats the purpose of the app.

I have tested this in both Shiny Express and Shiny Core and experience the same issue in both. Since I only experience this error within a Shiny app and the model unpickles successfully outside of a Shiny app, I believe this is a Shiny-specific issue.

@gadenbuie
Collaborator

Hi @sarahxj and thanks for the question. Do you think your issue could be generalized to a problem with importing the model module, or is it specifically about the pickled model? I ask because it's much easier to provide help when we have a complete, runnable example.

It would probably be easiest for both of us if you could create a reproducible example that doesn't require scikeras or a pickled model. Alternatively, a small model example could work too. You can use https://shinylive.io/py as a place to develop and share the example if it is small enough.
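A stdlib-only reproduction along those lines might look like the sketch below. It assumes, as the error message suggests, that the training script defined model() itself and was run directly. The file names are hypothetical, and `python -c "import app"` stands in for a server that imports the app file as a module instead of running it as `__main__`. The script loads the same pickle from `app.py` both ways.

```python
import pathlib
import subprocess
import sys
import tempfile
import textwrap

# Three tiny files mirroring the setup in the issue (names are hypothetical):
#   train.py - defines model() itself and pickles an object referring to it
#   model.py - the importable copy of model()
#   app.py   - imports model() and unpickles, as the Shiny app does
FILES = {
    "train.py": """
        import pickle

        def model():
            return "stand-in for a keras model"

        with open("obj.pkl", "wb") as f:
            pickle.dump({"model": model}, f)
    """,
    "model.py": """
        def model():
            return "stand-in for a keras model"
    """,
    "app.py": """
        import pickle
        from model import model

        with open("obj.pkl", "rb") as f:
            obj = pickle.load(f)
        print("loaded:", obj["model"]())
    """,
}


def run(args, cwd):
    """Run the current interpreter with ``args`` inside directory ``cwd``."""
    return subprocess.run(
        [sys.executable, *args], cwd=cwd, capture_output=True, text=True
    )


with tempfile.TemporaryDirectory() as tmp:
    for name, src in FILES.items():
        (pathlib.Path(tmp) / name).write_text(textwrap.dedent(src))

    trained = run(["train.py"], tmp)  # pickle records __main__.model
    as_script = run(["app.py"], tmp)  # app.py is __main__: lookup succeeds
    as_module = run(["-c", "import app"], tmp)  # app is imported, not __main__

print(as_script.stdout.strip())
print(as_module.stderr.strip())
```

Comparing the two runs shows whether the failure depends only on how the app file is loaded; the second traceback should end with an AttributeError of the same form as the one reported in the issue.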
