Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement If operator #306

Merged
merged 1 commit into from
Aug 18, 2024
Merged

Implement If operator #306

merged 1 commit into from
Aug 18, 2024

Conversation

robertknight
Copy link
Owner

@robertknight robertknight commented Aug 16, 2024

Implement the If operator. This is the first control flow operator that is being supported in RTen, so most of the changes are the infrastructure needed to support operators which run subgraphs as part of their execution.

TODO:

  • Tests for If operator deserialization
  • Tests for nested subgraphs
  • Rework CaptureEnv::get_input to avoid using a local value node if it is a capture
  • Rework the hack in the model converter that stores a Graph object in an IfAttrsT

Deferred to subsequent PRs

  • Ensure that constant propagation in rten-generate doesn't eliminate inputs that are needed as captured values on subsequent runs. This affects the encoder_hidden_states input in TrOCR (see Add TrOCR example #304).
  • Investigate error when running Silero VAD model that uses the If operator extensively. From an initial glance it looks like the issue is that planning doesn't take into account dependencies an operator may have due to captures in its subgraphs

This was referenced Aug 16, 2024
@robertknight robertknight force-pushed the if-operator branch 8 times, most recently from 82141a2 to 496f0e3 Compare August 18, 2024 15:17
@robertknight robertknight mentioned this pull request Aug 18, 2024
8 tasks
Add infrastructure to support control flow operators which execute subgraphs and
use it to implement the `If` operator.

The implementation of control flow operators follows ONNX by supporting
graphs as operator attributes. These graphs may refer to (_capture_)
values from the parent graph at runtime, like a closure. The `If`
operator has two subgraphs which take no inputs but instead use captures
to access values from the parent graph.

Operator inputs in RTen are represented as a list of graph-local node IDs rather
than names. Therefore an operator's input list cannot directly refer to captured
values by name. Instead when an operator input is encountered that is a capture,
a value node is created in the subgraph to represent it and the ID of that value
node is added to a list of `captures` for the graph.

At runtime when resolving the value of a node it is first looked up by ID
in inputs and locals for the current graph and if that fails, a lookup by name
is performed in the capture environment.

 - Add `If` operator to RTen model schema

 - Support deserializing operators with attribute values that are graphs

 - Extend `Operator` trait with a `run_subgraph` method that is like `run` but
   takes a capture environment as an extra argument and returns a `RunError`
   instead of an `OpError`.

 - Implement `If` operator that uses `run_subgraph`

 - Support running operators with subgraphs in `Graph::run_plan`
@robertknight robertknight marked this pull request as ready for review August 18, 2024 17:20
@robertknight robertknight merged commit 69f9ab2 into main Aug 18, 2024
2 checks passed
@robertknight robertknight deleted the if-operator branch August 18, 2024 17:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant