
Releases: patrick-kidger/quax

Quax v0.0.3

02 Feb 11:11
809bcb0

This release is an attempt to start bringing some stability to the wild west that is Quax! :D

Documentation

We now have some shiny new documentation available at https://docs.kidger.site/quax ! Go check it out, including the tutorials for writing your own custom rules using Quax.

Design

There have been a few outstanding design questions with Quax that this release aims to resolve. This is really the highlight of this release.

Pass Values across quaxify boundaries

Where we used to be a bit laid-back about this, we are now quite careful: the pattern is always that you should create your custom value, and that you should then pass it across a quaxify boundary. For example, this is okay:

import quax
import quax.examples.lora as lora

x = lora.LoraArray(...)
quax.quaxify(some_function)(x)

but this is not:

@quax.quaxify
def some_function(...):
    x = lora.LoraArray(...)

Unless you're about to pass that x into a nested quaxify, that is.

Previously, if a quaxified argument encountered another Value at runtime, then Quax would automatically quaxify that value for us. The problem was that this autoquaxification only happened on primitive binds, not before running any other traced code. For example, jnp.where(..., MyValue(), ...) would not work, as this argument must be an array, and MyValue() hasn't yet interacted with a quaxified argument and been wrapped in a Quax tracer.

If the previous paragraph is all technical gobbledygook to you, then the summary is that this removes a major "gotcha".
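To make the gotcha concrete, here is a toy model in plain Python. Nothing in it is a Quax API: Value, Tracer, bind, where and transform are all hypothetical stand-ins, and the real machinery is more involved. It only sketches the ordering problem, assuming a library function that validates its inputs before reaching any primitive bind.

```python
# Toy model only: none of these names are Quax APIs.

class Value:
    """Hypothetical stand-in for a custom array-ish value."""
    def __init__(self, payload):
        self.payload = payload


class Tracer:
    """Hypothetical stand-in for the wrapper applied at the boundary."""
    def __init__(self, value):
        self.value = value


def bind(name, *args):
    """Stand-in for a primitive bind: the only place a rule could run."""
    unwrapped = [a.value if isinstance(a, Tracer) else a for a in args]
    return (name, [type(u).__name__ for u in unwrapped])


def where(cond, x, y):
    """Library-style function that checks its inputs *before* binding."""
    for arg in (cond, x, y):
        if not isinstance(arg, (bool, int, float, Tracer)):
            raise TypeError(f"expected an array-like, got {type(arg).__name__}")
    return bind("select", cond, x, y)


def transform(fn):
    """Wrap every Value argument in a Tracer at the boundary."""
    def wrapper(*args):
        wrapped = [Tracer(a) if isinstance(a, Value) else a for a in args]
        return fn(*wrapped)
    return wrapper


# Value passed across the boundary: wrapped before `where` inspects it.
ok = transform(lambda v: where(True, v, 0.0))(Value(1.0))

# Value created inside the transformed function: never wrapped, so the
# check in `where` rejects it before any primitive bind is reached.
try:
    transform(lambda: where(True, Value(1.0), 0.0))()
    failed = False
except TypeError:
    failed = True
```

In this sketch the first call succeeds and the second raises, which mirrors why the rule is now "create the value outside, then pass it across the boundary".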

Values no longer support __add__ etc.

In line with the previous change: there is no longer a reason to want to write something like MyValue() + 1, as this would be an operation performed without being wrapped in a quax.quaxify ! So all such dunder methods have been removed for safety.

No more DenseArrayValue

This is just a case of simplifying the API a bit. Instead of writing rules that look like:

@quax.register(...)
def _(x: quax.DenseArrayValue, y: SomeValue):
    ...

you should write:

from jaxtyping import ArrayLike

@quax.register(...)
def _(x: ArrayLike, y: SomeValue):
    ...

In particular, this means you no longer need to think carefully about whether a normal JAX array has been wrapped into a DenseArrayValue -- instead, just use them like normal.

Previously, we had to think separately about being "in normal JAX code" (with ArrayLikes and quax.Values), and "currently writing a Quax rule" (with DenseArrayValues and quax.Values). This unifies these two things, so as to simplify the mental reasoning a bit. It also means we can remove the quax.quaxify(..., unwrap_builtin_values=...) argument we had to use to toggle between these two regimes.
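As a rough illustration of the unified picture, here is a toy dispatch registry in plain Python. It is not Quax's dispatch machinery: register, dispatch, ARRAY_LIKE and this SomeValue class are hypothetical names made up for the sketch, and the types are passed to the decorator explicitly rather than read from annotations. The point is only that the rule receives the plain array-like argument directly, with no wrapper type in between.

```python
# Toy model only: a hand-rolled registry, not Quax's dispatch machinery.

# Hypothetical stand-in for an "array-like" type: plain scalars and lists.
ARRAY_LIKE = (bool, int, float, complex, list)


class SomeValue:
    """Hypothetical custom value type."""


_rules = []  # list of (argument types, rule function) pairs


def register(*types):
    """Register a rule for a given tuple of argument types."""
    def decorator(fn):
        _rules.append((types, fn))
        return fn
    return decorator


def dispatch(*args):
    """Call the first rule whose types match every argument."""
    for types, fn in _rules:
        if len(types) == len(args) and all(
            isinstance(a, t) for a, t in zip(args, types)
        ):
            return fn(*args)
    raise NotImplementedError("no matching rule")


@register(ARRAY_LIKE, SomeValue)
def _(x, y):
    # `x` arrives as the plain value: there is nothing to unwrap first.
    return ("rule", type(x).__name__, type(y).__name__)


result = dispatch(2.0, SomeValue())
```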

Disabled dynamic tracing

(As per the discussion in #2.)

Quax will no longer perform dynamic tracing -- that is, it will only run on a primitive bind if one of the arguments to that bind is downstream of an input passed to a quax.quaxify. It will no longer run on all primitive binds that happen to occur inside the quax.quaxify-wrapped function.

This removes a spooky-action-at-a-distance. Previously, it was possible to quax.register a primitive acting only on ArrayLikes, and thus change the behaviour of that primitive even in normal non-Quax usage. For example, our very own zero library did this, as a way to create symbolic zeros from operations like "broadcast(0, shape)". This caused all kinds of havoc, with random Zeros showing up in places that weren't expected, and which the author of some other type did not know to expect!
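The new behaviour can be sketched with another toy model in plain Python. Again, none of these names are Quax internals: Tracer, bind, transform and the two rule functions are hypothetical, and the sketch assumes that "downstream of an input" can be modelled as "at least one argument is a Tracer".

```python
# Toy model only: these names are illustrative, not Quax internals.

class Tracer:
    """Marks a value as downstream of an input to the transform."""
    def __init__(self, value):
        self.value = value


calls = []  # records which binds were handled by the overloaded rule


def default_impl(name, *args):
    """Ordinary behaviour of a primitive."""
    return ("default", name)


def overloaded_rule(name, *args):
    """Custom rule; its output stays marked as downstream."""
    calls.append(name)
    return Tracer(("overloaded", name))


def bind(name, *args):
    """Run the overloaded rule only if some argument is a Tracer."""
    if any(isinstance(a, Tracer) for a in args):
        return overloaded_rule(name, *args)
    return default_impl(name, *args)


def transform(fn):
    """Wrap every input in a Tracer at the boundary."""
    def wrapper(*args):
        return fn(*(Tracer(a) for a in args))
    return wrapper


@transform
def f(x):
    unrelated = bind("broadcast", 0, (3,))  # no traced argument: default path
    related = bind("mul", x, 2)             # downstream of x: rule runs
    chained = bind("add", related, 1)       # still downstream of x
    return unrelated, related, chained


unrelated, related, chained = f(1.0)
```

Here the broadcast call is left alone even though it happens inside the transformed function, while the two operations that depend on x are intercepted.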

Features

  • All examples are now built into the library: check out quax.examples.{lora, zeros, named, ...} !
  • Added quax.quaxify(..., filter_spec): for the advanced user who only wants to quaxify a few arguments, when working with nested quax.quaxifys. (Which is realistically probably just me at this point!) It is now possible to easily specify which arguments should be quaxified. This is just a nice to have; you could previously work around this with tree un/flattening, or capturing values via closure.
  • Better debugging by giving names to overloaded primitive rules. (Thanks @nstarman! #1)

Full Changelog: v0.0.2...v0.0.3

quax v0.0.2

13 Dec 05:42

Autogenerated release notes as follows:

Full Changelog: https://github.com/patrick-kidger/quax/commits/v0.0.2