Support common array interfaces in python for zero-copy data sharing #5057
Thanks for writing this up! My initial thought was triggered by a question someone asked: can they use Taichi's GPU data with numpy without incurring a D2H copy? If we can interoperate with either numba or JAX, we can provide users with GPU numpy for free. Not sure whether numba or JAX is more similar to numpy, though :-)
@k-ye Yeah, sharing CPU data with numpy (or GPU data with numba) is possible since we have the physical pointer address (https://github.com/taichi-dev/taichi/blob/master/taichi/program/program.cpp#L567) for the CPU and CUDA backends. Implementing the above array interfaces (or DLPack) should do the trick.
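To illustrate the CPU side of this, here is a minimal sketch of how a raw pointer can be exposed through numpy's `__array_interface__` protocol so that `np.asarray` consumes it zero-copy. The `PointerArray` wrapper is hypothetical (it is not Taichi API), and it only stores the pointer without owning the memory:

```python
import ctypes
import numpy as np

class PointerArray:
    """Hypothetical wrapper: exposes a raw CPU pointer via __array_interface__.
    Note it does NOT own the underlying memory; the caller must keep it alive."""
    def __init__(self, ptr, shape, typestr):
        self._ptr = ptr
        self._shape = shape
        self._typestr = typestr

    @property
    def __array_interface__(self):
        return {
            "version": 3,
            "shape": self._shape,
            "typestr": self._typestr,
            "data": (self._ptr, False),  # (raw pointer, read_only flag)
        }

# Back the pointer with a ctypes buffer so the example is self-contained.
buf = (ctypes.c_float * 4)(1.0, 2.0, 3.0, 4.0)
wrapper = PointerArray(ctypes.addressof(buf), (4,), np.dtype(np.float32).str)

view = np.asarray(wrapper)  # zero-copy view over the same memory
view[0] = 42.0
assert buf[0] == 42.0       # writes go straight to the original buffer
```

The GPU story is analogous with `__cuda_array_interface__`, which adds a `stream` key and expects a device pointer instead of a host one.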
I wonder how the lifetime/ownership problem is resolved in this case?
@k-ye I believe the ownership is shared in this case as when you do
I see, I wasn't aware that torch used the CPython-layer API.
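The shared-ownership behavior can be observed on the CPU side: when numpy consumes an object's `__array_interface__`, it stores that object in the resulting view's `.base` attribute, so the producer stays alive for as long as the view does. A small sketch (the `Producer` class is hypothetical):

```python
import numpy as np

class Producer:
    """Hypothetical producer that owns some memory and exports it."""
    def __init__(self):
        self._data = np.ones(3, dtype=np.float32)  # stands in for owned memory

    @property
    def __array_interface__(self):
        return self._data.__array_interface__

p = Producer()
view = np.asarray(p)
assert view.base is p   # numpy keeps a reference to the producer

del p                   # only drops our name; the view's .base still holds it
assert view.sum() == 3.0  # memory is still valid
```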
I just wanted to ask if there is any update on this issue? I didn't find any information on DLPack in the Taichi docs. In particular, I'm interested in using Taichi with JAX, which has already implemented DLPack support. Do you maybe plan to implement support for JAX arrays via taichi.ndarrays or DLPack?
Are there any updates on this?
By common array interfaces, I mean numpy's __array_interface__ (CPU) and numba's __cuda_array_interface__ (GPU).
Note that an alternative to this is DLPack (https://dmlc.github.io/dlpack/latest/python_spec.html#syntax-for-data-interchange-with-dlpack), and we already have an issue for it: #4534. To be honest, I haven't explored the pros and cons of these two interfaces myself, but this is something to consider before implementing. Reference: https://data-apis.org/array-api/latest/design_topics/data_interchange.html
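For comparison, the DLPack path already works between libraries that implement it. For example, numpy (1.22+) can consume any object exposing `__dlpack__` zero-copy via `np.from_dlpack` (here numpy plays both producer and consumer, just to keep the example self-contained):

```python
import numpy as np

a = np.arange(4, dtype=np.float32)  # producer: ndarray implements __dlpack__
b = np.from_dlpack(a)               # consumer: zero-copy import via DLPack

a[0] = 99.0
assert b[0] == 99.0                 # b sees the change: same underlying buffer
```

A Taichi ndarray implementing `__dlpack__` / `__dlpack_device__` would plug into the same machinery on the torch/JAX side.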
Implementation-wise it shouldn't be too hard; one reference is https://github.com/pytorch/pytorch/pull/11984/files.
I believe JAX and PyTorch don't support importing from __cuda_array_interface__ yet. In other words, if you create a torch tensor / JAX DeviceArray and then use its __cuda_array_interface__ in Taichi, that's totally fine. But if you create a Taichi ndarray and want to use its __cuda_array_interface__ in torch/JAX, I believe that's not yet supported. Numba does support that, though. google/jax#1100
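On the consumer side, the usual pattern is to probe for the GPU interface first and fall back to the CPU one, which is roughly what numba's importers do. A hypothetical sketch (the `zero_copy_view` helper is illustrative, not any library's API):

```python
import numpy as np

def zero_copy_view(obj):
    """Hypothetical consumer-side dispatch: prefer the CUDA interface,
    fall back to the CPU array interface via np.asarray."""
    if hasattr(obj, "__cuda_array_interface__"):
        # A real consumer would wrap the device pointer here (e.g. numba's
        # cuda.as_cuda_array); out of scope for this CPU-only sketch.
        raise NotImplementedError("CUDA consumer not shown in this sketch")
    return np.asarray(obj)  # zero-copy via __array_interface__

x = np.arange(3)
v = zero_copy_view(x)
v[0] = 7
assert x[0] == 7  # same underlying buffer, no copy was made
```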
Also, once we support these interfaces in Taichi, we should use them when numpy/torch/paddle tensors are passed as external arrays to Taichi kernels, to clean things up.
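The cleanup could work because everything a kernel launcher needs (pointer, shape, dtype) is already in the interface dict, so one code path could serve all producers. A minimal sketch (`external_array_info` is a hypothetical helper, not Taichi's actual binding code):

```python
import numpy as np

def external_array_info(arr):
    """Hypothetical helper: extract the raw pointer, shape, and dtype from
    any object exposing __array_interface__, as a kernel launcher would."""
    iface = arr.__array_interface__
    ptr, read_only = iface["data"]
    return ptr, iface["shape"], iface["typestr"]

x = np.zeros((2, 3), dtype=np.float32)
ptr, shape, typestr = external_array_info(x)

assert ptr == x.ctypes.data              # same address numpy reports
assert shape == (2, 3)
assert np.dtype(typestr) == np.float32   # typestr round-trips to the dtype
```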
cc: @k-ye