
Support common array interfaces in python for zero-copy data sharing #5057

Open
ailzhang opened this issue May 30, 2022 · 7 comments
Labels
feature request Suggest an idea on this project

Comments

@ailzhang
Contributor

ailzhang commented May 30, 2022

By common array interfaces, I mean NumPy's `__array_interface__` and the CUDA array interface (`__cuda_array_interface__`).

Note that an alternative is DLPack (https://dmlc.github.io/dlpack/latest/python_spec.html#syntax-for-data-interchange-with-dlpack), and we already have an issue for it: #4534. To be honest, I haven't explored the pros and cons of these two interfaces myself, but this is something to consider before implementing. Reference: https://data-apis.org/array-api/latest/design_topics/data_interchange.html

Implementation-wise it shouldn't be too hard; one reference is https://github.com/pytorch/pytorch/pull/11984/files.
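To make the protocol concrete, here is a minimal, hypothetical sketch of what exposing `__array_interface__` on a Taichi-like container could look like. `FakeNdarray` and its fields are illustrative stand-ins, not Taichi's real API; the interface dict itself follows NumPy's array interface protocol (version 3).

```python
# Hypothetical sketch: exposing NumPy's __array_interface__ on a
# Taichi-like container so np.asarray() can wrap it zero-copy.
# "FakeNdarray" is an illustrative stand-in, not Taichi's real API.
import ctypes
import numpy as np

class FakeNdarray:
    def __init__(self, n):
        # Stands in for a host allocation owned by the runtime.
        self._buf = (ctypes.c_float * n)()
        self._n = n

    @property
    def __array_interface__(self):
        return {
            "shape": (self._n,),
            "typestr": "<f4",  # little-endian float32
            # (raw pointer, read-only flag) per the protocol
            "data": (ctypes.addressof(self._buf), False),
            "version": 3,
        }

arr = FakeNdarray(4)
view = np.asarray(arr)   # zero-copy: wraps the same memory
view[:] = [1, 2, 3, 4]
print(list(arr._buf))    # the writes are visible through the original buffer
```

A GPU-side `__cuda_array_interface__` property would look almost identical, except `data` holds a device pointer.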

I believe JAX and PyTorch don't support importing from `__cuda_array_interface__` yet. In other words, if you create a torch tensor / JAX DeviceArray and then consume its `__cuda_array_interface__` in Taichi, that's totally fine. But if you create a Taichi ndarray and want to consume its `__cuda_array_interface__` in torch/JAX, I believe that's not yet supported. Numba does support it, though. See google/jax#1100.

Also, once we support these interfaces in Taichi, we should use them when numpy/torch/paddle tensors are passed as external arrays to Taichi kernels, to clean things up.
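The clean-up above could amount to dispatching on the standard protocols instead of special-casing each framework. A hedged sketch (`describe_external_array` and its return shape are illustrative, not Taichi's actual code):

```python
# Illustrative sketch (not Taichi's actual code) of dispatching kernel
# arguments on the standard interchange protocols.
import numpy as np

def describe_external_array(obj):
    """Return (protocol, pointer, shape) for a zero-copy-capable object."""
    if hasattr(obj, "__cuda_array_interface__"):   # Numba, CuPy, torch CUDA tensors
        iface = obj.__cuda_array_interface__
        return ("cuda", iface["data"][0], iface["shape"])
    if hasattr(obj, "__array_interface__"):        # NumPy and compatible host arrays
        iface = obj.__array_interface__
        return ("cpu", iface["data"][0], iface["shape"])
    if hasattr(obj, "__dlpack__"):                 # DLPack producers
        return ("dlpack", None, None)
    raise TypeError(f"cannot share memory with {type(obj)!r}")

a = np.zeros((2, 3), dtype=np.float32)
proto, ptr, shape = describe_external_array(a)
print(proto, shape)   # -> cpu (2, 3)
```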

cc: @k-ye

@ailzhang ailzhang added the feature request Suggest an idea on this project label May 30, 2022
@k-ye
Member

k-ye commented May 30, 2022

Thanks for writing this up! My initial thought was triggered by someone asking whether they can use Taichi's GPU data with numpy without incurring a D2H copy. If we can interoperate with either Numba or JAX, we can provide users with GPU numpy for free. Not sure whether Numba or JAX is more similar to numpy, though :-)

@ailzhang
Contributor Author

@k-ye Yes, sharing CPU data with numpy (or GPU data with Numba) is possible since we have the physical pointer address for the cpu and cuda backends: https://github.com/taichi-dev/taichi/blob/master/taichi/program/program.cpp#L567. Implementing the array interfaces above (or DLPack) should do the trick.
Side note: currently our to_numpy is a deep copy by default; note that torch's Tensor.numpy() shares the underlying storage if the source and target devices are both cpu. We could add a to_numpy(..., copy=False) so that users can control whether the result shares the underlying storage. The implementation shouldn't be hard.
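A minimal sketch of what that `copy=` flag could look like, assuming a CPU-backed container whose host buffer is reachable from Python; `HostNdarray` and its fields are hypothetical stand-ins:

```python
# Hypothetical sketch of a copy= flag on to_numpy(); the class and its
# host buffer are illustrative stand-ins for a CPU-backed Taichi ndarray.
import ctypes
import numpy as np

class HostNdarray:
    def __init__(self, n):
        self._buf = (ctypes.c_float * n)()

    def to_numpy(self, copy=True):
        # Zero-copy wrapper over the host buffer.
        view = np.ctypeslib.as_array(self._buf)
        # copy=True keeps today's deep-copy semantics as the default.
        return view.copy() if copy else view

x = HostNdarray(3)
shared = x.to_numpy(copy=False)
owned = x.to_numpy()
shared[0] = 7.0
print(x._buf[0], owned[0])   # 7.0 0.0 -> only the shared view writes through
```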

@k-ye
Member

k-ye commented May 30, 2022

> note torch's Tensor.numpy() shares the underlying storage if the source and target devices are both cpu.

I wonder how the lifetime ownership problem is resolved in this case?

@ailzhang
Contributor Author

@k-ye I believe ownership is shared in this case: when you call numpy(), PyTorch actually creates a new Python tensor from the storage, and numpy steals a reference to it and sets it as the array's base. https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/tensor_numpy.cpp#L160-L165 In other words, creating a numpy array out of a tensor's existing storage increases the reference count of the tensor.
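The same base mechanism can be seen inside NumPy itself, which is a minimal illustration of the lifetime model described above: a view holds a reference to its producer via `.base`, so sharing storage also shares ownership.

```python
# Minimal illustration of the lifetime mechanism: a NumPy view keeps a
# reference to its producer via .base, so the storage outlives the name.
import numpy as np

src = np.arange(4)
view = src[:]            # zero-copy view
assert view.base is src  # the view pins the source array alive
del src                  # storage survives: view.base still owns it
view[0] = 99
print(view.base[0])      # 99
```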

@k-ye
Member

k-ye commented May 30, 2022

I see, I wasn't aware that torch used the CPython-layer API.

@salykova

salykova commented May 9, 2023

Hi @ailzhang @k-ye

I just wanted to ask if there is any update on this issue? I didn't find any information on DLPack in the Taichi docs. In particular, I'm interested in using Taichi with JAX, which has already implemented DLPack support. Do you plan to support JAX arrays via taichi.ndarray or DLPack?

@mehdiataei

Are there any updates on this?

Projects
Status: Backlog
Development

No branches or pull requests

4 participants