From 2d99b41528715cfc3aafe96316f5d3bfb29bda00 Mon Sep 17 00:00:00 2001 From: listerily Date: Mon, 10 Jul 2023 19:31:58 +0800 Subject: [PATCH] [lang] [ir] Support argpack nesting ghstack-source-id: 7bde32b1715ed831bd853ab3616b52e522705201 Pull Request resolved: https://github.com/taichi-dev/taichi/pull/8273 --- python/taichi/lang/argpack.py | 17 +++++++--- python/taichi/lang/kernel_impl.py | 8 ++--- taichi/codegen/llvm/codegen_llvm.cpp | 17 +++++++--- taichi/codegen/llvm/codegen_llvm.h | 4 +-- taichi/ir/type_factory.cpp | 14 ++++++++ taichi/ir/type_factory.h | 4 +++ taichi/program/argpack.cpp | 32 +++++++++++++++---- taichi/program/argpack.h | 12 ++++--- taichi/program/launch_context_builder.cpp | 5 ++- taichi/program/program.cpp | 2 +- taichi/program/program_impl.h | 4 +-- taichi/python/export_lang.cpp | 4 +++ taichi/runtime/amdgpu/kernel_launcher.cpp | 16 +++++++--- taichi/runtime/cpu/kernel_launcher.cpp | 15 ++++++--- taichi/runtime/cuda/kernel_launcher.cpp | 21 ++++++++---- taichi/runtime/llvm/llvm_runtime_executor.cpp | 4 +-- taichi/runtime/llvm/llvm_runtime_executor.h | 2 +- .../llvm/snode_tree_buffer_manager.cpp | 2 +- .../runtime/program_impls/llvm/llvm_program.h | 4 +-- tests/cpp/aot/llvm/graph_aot_test.cpp | 8 ++--- tests/cpp/aot/llvm/kernel_aot_test.cpp | 6 ++-- tests/python/test_argpack.py | 1 - 22 files changed, 143 insertions(+), 59 deletions(-) diff --git a/python/taichi/lang/argpack.py b/python/taichi/lang/argpack.py index ead309ee9e695..a88b241394f89 100644 --- a/python/taichi/lang/argpack.py +++ b/python/taichi/lang/argpack.py @@ -198,7 +198,7 @@ def _write_to_device(self, needed, provided, v, index): if isinstance(needed, ArgPackType): if not isinstance(v, ArgPack): raise TaichiRuntimeTypeError.get(index, str(needed), str(provided)) - pass # Do nothing + self.__argpack.set_arg_nested_argpack(index, v.__argpack) else: # Note: do not use sth like "needed == f32". That would be slow. 
if id(needed) in primitive_types.real_type_ids: @@ -293,7 +293,14 @@ def __init__(self, **kwargs): elements.append([dtype.dtype, k]) elif isinstance(dtype, ArgPackType): self.members[k] = dtype - raise TaichiSyntaxError("ArgPack nesting is not supported currently.") + elements.append( + [ + _ti_core.DataType( + _ti_core.get_type_factory_instance().get_struct_type_for_argpack_ptr(dtype.dtype) + ), + k, + ] + ) elif isinstance(dtype, MatrixType): # Convert MatrixType to StructType if dtype.ndim == 1: @@ -314,9 +321,9 @@ def __init__(self, **kwargs): dtype = cook_dtype(dtype) self.members[k] = dtype elements.append([dtype, k]) - if len(elements) == 0: - # Use i32 as a placeholder for empty argpacks - elements.append([primitive_types.i32, k]) + if len(elements) == 0: + # Use i32 as a placeholder for empty argpacks + elements.append([primitive_types.i32, k]) self.dtype = _ti_core.get_type_factory_instance().get_argpack_type(elements) def __call__(self, *args, **kwargs): diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index ef9a40ef72515..5260208dcfbe3 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -863,14 +863,14 @@ def recursive_set_args(needed, provided, v, indices): if not isinstance(v, (float, int, np.floating, np.integer)): raise TaichiRuntimeTypeError.get(indices, needed.to_string(), provided) if in_argpack: - return 0 + return 1 launch_ctx.set_arg_float(indices, float(v)) return 1 if id(needed) in primitive_types.integer_type_ids: if not isinstance(v, (int, np.integer)): raise TaichiRuntimeTypeError.get(indices, needed.to_string(), provided) if in_argpack: - return 0 + return 1 if is_signed(cook_dtype(needed)): launch_ctx.set_arg_int(indices, int(v)) else: @@ -908,12 +908,12 @@ def recursive_set_args(needed, provided, v, indices): return 1 if isinstance(needed, MatrixType): if in_argpack: - return 0 + return 1 set_arg_matrix(indices, v, needed) return 1 if isinstance(needed, StructType): if in_argpack: - return 0 + return 1 if not isinstance(v, needed): raise TaichiRuntimeTypeError(f"Argument {provided} cannot be converted into required type {needed}") needed.set_kernel_struct_args(v, launch_ctx, indices) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 525a2e2ee3be0..9e487302ce928 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -2883,14 +2883,12 @@ void TaskCodeGenLLVM::set_struct_to_buffer( current_element, current_index); } -llvm::Value *TaskCodeGenLLVM::get_argpack_arg(std::vector arg_id, +llvm::Value *TaskCodeGenLLVM::get_argpack_arg(const std::vector &arg_id, int arg_depth, bool create_load) { const std::vector indices_l(arg_id.begin(), arg_id.begin() + arg_depth); const std::vector indices_r(arg_id.begin() + arg_depth, arg_id.end()); - auto indices_data_ptr = indices_l; - indices_data_ptr.push_back(TypeFactory::DATA_PTR_POS_IN_ARGPACK); - auto data_ptr_value = get_struct_arg(indices_data_ptr, true); + llvm::Value *data_ptr_value; auto argpack_iterator = std::find_if(current_callable->argpack_types.begin(), current_callable->argpack_types.end(), @@ -2898,6 +2896,15 @@ llvm::Value *TaskCodeGenLLVM::get_argpack_arg(std::vector arg_id, TI_ASSERT(argpack_iterator != current_callable->argpack_types.end()); const auto *argpack_type = (*argpack_iterator).second; auto *arg_type = argpack_type->get_element_type(indices_r); + if (arg_depth > 1) { + auto key = arg_id; + key.back() = TypeFactory::DATA_PTR_POS_IN_ARGPACK; + 
data_ptr_value = get_argpack_arg(key, arg_depth - 1, true); + } else { + auto indices_data_ptr = indices_l; + indices_data_ptr.push_back(TypeFactory::DATA_PTR_POS_IN_ARGPACK); + data_ptr_value = get_struct_arg(indices_data_ptr, true); + } std::vector gep_index; gep_index.reserve(indices_r.size()); gep_index.push_back(tlctx->get_constant(0)); @@ -2912,7 +2919,7 @@ llvm::Value *TaskCodeGenLLVM::get_argpack_arg(std::vector arg_id, return builder->CreateLoad(tlctx->get_data_type(arg_type), gep); } -llvm::Value *TaskCodeGenLLVM::get_struct_arg(std::vector index, +llvm::Value *TaskCodeGenLLVM::get_struct_arg(const std::vector &index, bool create_load) { auto *args_ptr = get_args_ptr(current_callable, get_context()); auto *args_type = current_callable->args_type; diff --git a/taichi/codegen/llvm/codegen_llvm.h b/taichi/codegen/llvm/codegen_llvm.h index d07a40fe32f7c..00d025b61d1e6 100644 --- a/taichi/codegen/llvm/codegen_llvm.h +++ b/taichi/codegen/llvm/codegen_llvm.h @@ -87,11 +87,11 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder { llvm::Value *get_arg(int i); - llvm::Value *get_argpack_arg(std::vector index, + llvm::Value *get_argpack_arg(const std::vector &index, int arg_depth, bool create_load); - llvm::Value *get_struct_arg(std::vector index, bool create_load); + llvm::Value *get_struct_arg(const std::vector &index, bool create_load); llvm::Value *get_args_ptr(const Callable *callable, llvm::Value *context); diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index 9949b258d4ac7..8dabf3dcf7f9f 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -69,6 +69,20 @@ const Type *TypeFactory::get_argpack_type( return argpack_types_[key].get(); } +const Type *TypeFactory::get_struct_type_for_argpack_ptr( + DataType dt, + const std::string &layout) { + auto *type_inner = + this->get_struct_type(dt->get_type()->as()->elements(), + layout) + ->as(); + auto *type_pointer = + this->get_pointer_type(const_cast(type_inner), false); + auto *type_outter = + this->get_struct_type({{type_pointer, "data_ptr"}})->as(); + return type_outter; +} + Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) { std::lock_guard _(pointer_mut_); diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index ad8e3a7ba422e..12e3fa35a0cf9 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -30,6 +30,10 @@ class TypeFactory { const std::vector &elements, const std::string &layout = "none"); + const Type *get_struct_type_for_argpack_ptr( + DataType dt, + const std::string &layout = "none"); + const Type *get_ndarray_struct_type(DataType dt, int ndim, bool needs_grad = false); diff --git a/taichi/program/argpack.cpp b/taichi/program/argpack.cpp index 61e86f917b433..7e47b983bb366 100644 --- a/taichi/program/argpack.cpp +++ b/taichi/program/argpack.cpp @@ -45,7 +45,7 @@ DataType ArgPack::get_data_type() const { TypedConstant ArgPack::read(const std::vector &I) const { prog_->synchronize(); size_t offset = dtype->as()->get_element_offset(I); - DataType element_dt = dtype->as()->get_element_type(I); + DataType element_dt = get_element_dt(I); size_t size = data_type_size(element_dt); taichi::lang::Device::AllocParams alloc_params; alloc_params.host_write = false; @@ -73,9 +73,9 @@ TypedConstant ArgPack::read(const std::vector &I) const { return data; } -void ArgPack::write(const std::vector &I, TypedConstant val) { +void ArgPack::write(const std::vector &I, TypedConstant val) const { size_t offset = 
dtype->as()->get_element_offset(I); - DataType element_dt = dtype->as()->get_element_type(I); + DataType element_dt = get_element_dt(I); size_t size = data_type_size(element_dt); if (element_dt->is_primitive(PrimitiveTypeID::f16)) { uint16_t float16 = fp16_ieee_from_fp32_value(val.val_f32); @@ -105,19 +105,39 @@ void ArgPack::write(const std::vector &I, TypedConstant val) { prog_->synchronize(); } -void ArgPack::set_arg_int(const std::vector &i, int64 val) { +void ArgPack::set_arg_int(const std::vector &i, int64 val) const { DataType element_dt = dtype->as()->get_element_type(i); write(i, TypedConstant(element_dt, val)); } -void ArgPack::set_arg_float(const std::vector &i, float64 val) { +void ArgPack::set_arg_float(const std::vector &i, float64 val) const { DataType element_dt = dtype->as()->get_element_type(i); write(i, TypedConstant(element_dt, val)); } -void ArgPack::set_arg_uint(const std::vector &i, uint64 val) { +void ArgPack::set_arg_uint(const std::vector &i, uint64 val) const { DataType element_dt = dtype->as()->get_element_type(i); write(i, TypedConstant(element_dt, val)); } +void ArgPack::set_arg_nested_argpack(int i, const ArgPack &val) const { + const std::vector indices = {i, TypeFactory::DATA_PTR_POS_IN_ARGPACK}; + DataType element_dt = get_element_dt(indices); + write(indices, + TypedConstant(element_dt, val.get_device_allocation_ptr_as_int())); +} + +void ArgPack::set_arg_nested_argpack_ptr(int i, intptr_t val) const { + const std::vector indices = {i, TypeFactory::DATA_PTR_POS_IN_ARGPACK}; + DataType element_dt = get_element_dt(indices); + write(indices, TypedConstant(element_dt, val)); +} + +DataType ArgPack::get_element_dt(const std::vector &i) const { + auto dt = dtype->as()->get_element_type(i); + if (dt->is()) + return PrimitiveType::u64; + return dt; +} + } // namespace taichi::lang diff --git a/taichi/program/argpack.h b/taichi/program/argpack.h index 382f463832709..feccc8b36a214 100644 --- a/taichi/program/argpack.h +++ b/taichi/program/argpack.h @@ -27,14 +27,18 @@ class TI_DLL_EXPORT ArgPack { std::size_t get_nelement() const; TypedConstant read(const std::vector &I) const; - void write(const std::vector &I, TypedConstant val); - void set_arg_int(const std::vector &i, int64 val); - void set_arg_uint(const std::vector &i, uint64 val); - void set_arg_float(const std::vector &i, float64 val); + void write(const std::vector &I, TypedConstant val) const; + void set_arg_int(const std::vector &i, int64 val) const; + void set_arg_uint(const std::vector &i, uint64 val) const; + void set_arg_float(const std::vector &i, float64 val) const; + void set_arg_nested_argpack(int i, const ArgPack &val) const; + void set_arg_nested_argpack_ptr(int i, intptr_t val) const; ~ArgPack(); private: Program *prog_{nullptr}; + + DataType get_element_dt(const std::vector &i) const; }; } // namespace taichi::lang diff --git a/taichi/program/launch_context_builder.cpp b/taichi/program/launch_context_builder.cpp index 5a9ff9ba38e25..02b1ac47a87d5 100644 --- a/taichi/program/launch_context_builder.cpp +++ b/taichi/program/launch_context_builder.cpp @@ -252,7 +252,10 @@ void LaunchContextBuilder::set_arg_ndarray(const std::vector &arg_id, void LaunchContextBuilder::set_arg_argpack(const std::vector &arg_id, const ArgPack &argpack) { argpack_ptrs[arg_id] = &argpack; - set_argpack_ptr(arg_id, argpack.get_device_allocation_ptr_as_int()); + if (arg_id.size() == 1) { + // Only set ptr to arg buffer if this argpack is not nested + set_argpack_ptr(arg_id, argpack.get_device_allocation_ptr_as_int()); + 
} // TODO: Consider renaming this method to `set_device_allocation_type` set_array_device_allocation_type(arg_id, DevAllocType::kArgPack); } diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index 4bfd1aaab976e..89587a5231ff4 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -465,7 +465,7 @@ intptr_t Program::get_ndarray_data_ptr_as_int(const Ndarray *ndarray) { compile_config().arch == Arch::amdgpu) { // For the LLVM backends, device allocation is a physical pointer. data_ptr = - program_impl_->get_ndarray_alloc_info_ptr(ndarray->ndarray_alloc_); + program_impl_->get_device_alloc_info_ptr(ndarray->ndarray_alloc_); } return reinterpret_cast(data_ptr); diff --git a/taichi/program/program_impl.h b/taichi/program/program_impl.h index 14df0c7fac341..ff481156a62ff 100644 --- a/taichi/program/program_impl.h +++ b/taichi/program/program_impl.h @@ -112,9 +112,9 @@ class ProgramImpl { } // TODO: Move to Runtime Object - virtual uint64_t *get_ndarray_alloc_info_ptr(const DeviceAllocation &alloc) { + virtual uint64_t *get_device_alloc_info_ptr(const DeviceAllocation &alloc) { TI_ERROR( - "get_ndarray_alloc_info_ptr() not implemented on the current backend"); + "get_device_alloc_info_ptr() not implemented on the current backend"); return nullptr; } diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 44b81bf1ad7a0..66cdd20f67364 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -592,6 +592,7 @@ void export_lang(py::module &m) { .def("set_arg_float", &ArgPack::set_arg_float) .def("set_arg_int", &ArgPack::set_arg_int) .def("set_arg_uint", &ArgPack::set_arg_uint) + .def("set_arg_nested_argpack", &ArgPack::set_arg_nested_argpack) .def_readonly("dtype", &ArgPack::dtype); py::enum_(m, "Format") @@ -1251,6 +1252,9 @@ void export_lang(py::module &m) { .def("get_ndarray_struct_type", &TypeFactory::get_ndarray_struct_type, py::arg("dt"), py::arg("ndim"), py::arg("needs_grad"), py::return_value_policy::reference) + .def("get_struct_type_for_argpack_ptr", + &TypeFactory::get_struct_type_for_argpack_ptr, py::arg("dt"), + py::arg("layout") = "none", py::return_value_policy::reference) .def( "get_argpack_type", [&](TypeFactory *factory, diff --git a/taichi/runtime/amdgpu/kernel_launcher.cpp b/taichi/runtime/amdgpu/kernel_launcher.cpp index 00102b490eb86..8d60869bcdbcc 100644 --- a/taichi/runtime/amdgpu/kernel_launcher.cpp +++ b/taichi/runtime/amdgpu/kernel_launcher.cpp @@ -59,7 +59,7 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, DeviceAllocation devalloc = executor->allocate_memory_on_device( arr_sz, (uint64 *)device_result_buffer); device_ptrs[data_ptr_idx] = - executor->get_ndarray_alloc_info_ptr(devalloc); + executor->get_device_alloc_info_ptr(devalloc); transfers[data_ptr_idx] = {data_ptr, devalloc}; AMDGPUDriver::get_instance().memcpy_host_to_device( @@ -71,7 +71,7 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, // Ndarray DeviceAllocation *ptr = static_cast(data_ptr); // Unwrapped raw ptr on device - device_ptrs[data_ptr_idx] = executor->get_ndarray_alloc_info_ptr(*ptr); + device_ptrs[data_ptr_idx] = executor->get_device_alloc_info_ptr(*ptr); ctx.set_ndarray_ptrs(key, (uint64)device_ptrs[data_ptr_idx], (uint64)ctx.array_ptrs[grad_ptr_idx]); @@ -82,8 +82,16 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, auto *argpack = ctx.argpack_ptrs[key]; auto argpack_ptr = argpack->get_device_allocation(); device_ptrs[data_ptr_idx] = - executor->get_ndarray_alloc_info_ptr(argpack_ptr); - 
ctx.set_argpack_ptr(key, (uint64)device_ptrs[data_ptr_idx]); + executor->get_device_alloc_info_ptr(argpack_ptr); + if (key.size() == 1) { + ctx.set_argpack_ptr(key, (uint64)device_ptrs[data_ptr_idx]); + } else { + auto key_parent = key; + key_parent.pop_back(); + auto *argpack_parent = ctx.argpack_ptrs[key_parent]; + argpack_parent->set_arg_nested_argpack_ptr( + key.back(), (uint64)device_ptrs[data_ptr_idx]); + } } } if (transfers.size() > 0) { diff --git a/taichi/runtime/cpu/kernel_launcher.cpp b/taichi/runtime/cpu/kernel_launcher.cpp index abcd409c68e84..063d0aa0fde66 100644 --- a/taichi/runtime/cpu/kernel_launcher.cpp +++ b/taichi/runtime/cpu/kernel_launcher.cpp @@ -34,14 +34,14 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, ctx.array_runtime_sizes[key] > 0) { DeviceAllocation *ptr = static_cast(ctx.array_ptrs[data_ptr_idx]); - uint64 host_ptr = (uint64)executor->get_ndarray_alloc_info_ptr(*ptr); + uint64 host_ptr = (uint64)executor->get_device_alloc_info_ptr(*ptr); ctx.set_array_device_allocation_type( key, LaunchContextBuilder::DevAllocType::kNone); auto grad_ptr = ctx.array_ptrs[grad_ptr_idx]; uint64 host_ptr_grad = grad_ptr == nullptr ? 0 - : (uint64)executor->get_ndarray_alloc_info_ptr( + : (uint64)executor->get_device_alloc_info_ptr( *static_cast(grad_ptr)); ctx.set_ndarray_ptrs(key, host_ptr, host_ptr_grad); } @@ -51,8 +51,15 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, auto *argpack = ctx.argpack_ptrs[key]; auto argpack_ptr = argpack->get_device_allocation(); uint64 host_ptr = - (uint64)executor->get_ndarray_alloc_info_ptr(argpack_ptr); - ctx.set_argpack_ptr(key, host_ptr); + (uint64)executor->get_device_alloc_info_ptr(argpack_ptr); + if (key.size() == 1) { + ctx.set_argpack_ptr(key, host_ptr); + } else { + auto key_parent = key; + key_parent.pop_back(); + auto *argpack_parent = ctx.argpack_ptrs[key_parent]; + argpack_parent->set_arg_nested_argpack_ptr(key.back(), host_ptr); + } } } for (auto task : launcher_ctx.task_funcs) { diff --git a/taichi/runtime/cuda/kernel_launcher.cpp b/taichi/runtime/cuda/kernel_launcher.cpp index c2a70eb6f4d4e..7005dae973ff8 100644 --- a/taichi/runtime/cuda/kernel_launcher.cpp +++ b/taichi/runtime/cuda/kernel_launcher.cpp @@ -83,7 +83,7 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, DeviceAllocation devalloc = executor->allocate_memory_on_device( arr_sz, (uint64 *)device_result_buffer); device_ptrs[data_ptr_idx] = - executor->get_ndarray_alloc_info_ptr(devalloc); + executor->get_device_alloc_info_ptr(devalloc); transfers[data_ptr_idx] = {data_ptr, devalloc}; CUDADriver::get_instance().memcpy_host_to_device( @@ -93,7 +93,7 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, executor->allocate_memory_on_device( arr_sz, (uint64 *)device_result_buffer); device_ptrs[grad_ptr_idx] = - executor->get_ndarray_alloc_info_ptr(grad_devalloc); + executor->get_device_alloc_info_ptr(grad_devalloc); transfers[grad_ptr_idx] = {grad_ptr, grad_devalloc}; CUDADriver::get_instance().memcpy_host_to_device( @@ -109,12 +109,11 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, // Ndarray DeviceAllocation *ptr = static_cast(data_ptr); // Unwrapped raw ptr on device - device_ptrs[data_ptr_idx] = executor->get_ndarray_alloc_info_ptr(*ptr); + device_ptrs[data_ptr_idx] = executor->get_device_alloc_info_ptr(*ptr); if (grad_ptr != nullptr) { ptr = static_cast(grad_ptr); - device_ptrs[grad_ptr_idx] = - executor->get_ndarray_alloc_info_ptr(*ptr); + device_ptrs[grad_ptr_idx] = executor->get_device_alloc_info_ptr(*ptr); } else { 
device_ptrs[grad_ptr_idx] = nullptr; } @@ -128,8 +127,16 @@ void KernelLauncher::launch_llvm_kernel(Handle handle, auto *argpack = ctx.argpack_ptrs[key]; auto argpack_ptr = argpack->get_device_allocation(); device_ptrs[data_ptr_idx] = - executor->get_ndarray_alloc_info_ptr(argpack_ptr); - ctx.set_argpack_ptr(key, (uint64)device_ptrs[data_ptr_idx]); + executor->get_device_alloc_info_ptr(argpack_ptr); + if (key.size() == 1) { + ctx.set_argpack_ptr(key, (uint64)device_ptrs[data_ptr_idx]); + } else { + auto key_parent = key; + key_parent.pop_back(); + auto *argpack_parent = ctx.argpack_ptrs[key_parent]; + argpack_parent->set_arg_nested_argpack_ptr( + key.back(), (uint64)device_ptrs[data_ptr_idx]); + } } } if (transfers.size() > 0) { diff --git a/taichi/runtime/llvm/llvm_runtime_executor.cpp b/taichi/runtime/llvm/llvm_runtime_executor.cpp index d54cc59d18011..180692b11df77 100644 --- a/taichi/runtime/llvm/llvm_runtime_executor.cpp +++ b/taichi/runtime/llvm/llvm_runtime_executor.cpp @@ -500,7 +500,7 @@ void LlvmRuntimeExecutor::deallocate_memory_on_device(DeviceAllocation handle) { void LlvmRuntimeExecutor::fill_ndarray(const DeviceAllocation &alloc, std::size_t size, uint32_t data) { - auto ptr = get_ndarray_alloc_info_ptr(alloc); + auto ptr = get_device_alloc_info_ptr(alloc); if (config_.arch == Arch::cuda) { #if defined(TI_WITH_CUDA) CUDADriver::get_instance().memsetd32((void *)ptr, data, size); @@ -518,7 +518,7 @@ void LlvmRuntimeExecutor::fill_ndarray(const DeviceAllocation &alloc, } } -uint64_t *LlvmRuntimeExecutor::get_ndarray_alloc_info_ptr( +uint64_t *LlvmRuntimeExecutor::get_device_alloc_info_ptr( const DeviceAllocation &alloc) { if (config_.arch == Arch::cuda) { #if defined(TI_WITH_CUDA) diff --git a/taichi/runtime/llvm/llvm_runtime_executor.h b/taichi/runtime/llvm/llvm_runtime_executor.h index ed3b2ed139cb7..35e9dfa142f4a 100644 --- a/taichi/runtime/llvm/llvm_runtime_executor.h +++ b/taichi/runtime/llvm/llvm_runtime_executor.h @@ -55,7 +55,7 @@ class LlvmRuntimeExecutor { void check_runtime_error(uint64 *result_buffer); - uint64_t *get_ndarray_alloc_info_ptr(const DeviceAllocation &alloc); + uint64_t *get_device_alloc_info_ptr(const DeviceAllocation &alloc); const CompileConfig &get_config() const { return config_; diff --git a/taichi/runtime/llvm/snode_tree_buffer_manager.cpp b/taichi/runtime/llvm/snode_tree_buffer_manager.cpp index 07aba49791a3b..e6ec9581d14b0 100644 --- a/taichi/runtime/llvm/snode_tree_buffer_manager.cpp +++ b/taichi/runtime/llvm/snode_tree_buffer_manager.cpp @@ -14,7 +14,7 @@ Ptr SNodeTreeBufferManager::allocate(std::size_t size, uint64 *result_buffer) { auto devalloc = runtime_exec_->allocate_memory_on_device(size, result_buffer); snode_tree_id_to_device_alloc_[snode_tree_id] = devalloc; - return (Ptr)runtime_exec_->get_ndarray_alloc_info_ptr(devalloc); + return (Ptr)runtime_exec_->get_device_alloc_info_ptr(devalloc); } void SNodeTreeBufferManager::destroy(SNodeTree *snode_tree) { diff --git a/taichi/runtime/program_impls/llvm/llvm_program.h b/taichi/runtime/program_impls/llvm/llvm_program.h index 98d760abb6095..7badfed26ef26 100644 --- a/taichi/runtime/program_impls/llvm/llvm_program.h +++ b/taichi/runtime/program_impls/llvm/llvm_program.h @@ -111,8 +111,8 @@ class LlvmProgramImpl : public ProgramImpl { runtime_exec_->finalize(); } - uint64_t *get_ndarray_alloc_info_ptr(const DeviceAllocation &alloc) override { - return runtime_exec_->get_ndarray_alloc_info_ptr(alloc); + uint64_t *get_device_alloc_info_ptr(const DeviceAllocation &alloc) override { + return 
runtime_exec_->get_device_alloc_info_ptr(alloc); } void fill_ndarray(const DeviceAllocation &alloc, diff --git a/tests/cpp/aot/llvm/graph_aot_test.cpp b/tests/cpp/aot/llvm/graph_aot_test.cpp index 095769d8ce549..54cbc92a33464 100644 --- a/tests/cpp/aot/llvm/graph_aot_test.cpp +++ b/tests/cpp/aot/llvm/graph_aot_test.cpp @@ -76,9 +76,9 @@ TEST(LlvmCGraph, RunGraphCpu) { exec.synchronize(); auto *data_0 = reinterpret_cast( - exec.get_ndarray_alloc_info_ptr(devalloc_arr_0)); + exec.get_device_alloc_info_ptr(devalloc_arr_0)); auto *data_1 = reinterpret_cast( - exec.get_ndarray_alloc_info_ptr(devalloc_arr_1)); + exec.get_device_alloc_info_ptr(devalloc_arr_1)); for (int i = 0; i < ArrLength; i++) { EXPECT_EQ(data_0[i], 3 * i + base0 + base1 + base2); } @@ -146,7 +146,7 @@ TEST(LlvmCGraph, RunGraphCuda) { std::vector cpu_data(ArrLength); auto *data_0 = reinterpret_cast( - exec.get_ndarray_alloc_info_ptr(devalloc_arr_0)); + exec.get_device_alloc_info_ptr(devalloc_arr_0)); CUDADriver::get_instance().memcpy_device_to_host( (void *)cpu_data.data(), (void *)data_0, ArrLength * sizeof(int32_t)); @@ -156,7 +156,7 @@ TEST(LlvmCGraph, RunGraphCuda) { } auto *data_1 = reinterpret_cast( - exec.get_ndarray_alloc_info_ptr(devalloc_arr_1)); + exec.get_device_alloc_info_ptr(devalloc_arr_1)); CUDADriver::get_instance().memcpy_device_to_host( (void *)cpu_data.data(), (void *)data_1, ArrLength * sizeof(int32_t)); diff --git a/tests/cpp/aot/llvm/kernel_aot_test.cpp b/tests/cpp/aot/llvm/kernel_aot_test.cpp index b51dc10f36a34..6c983aed4e6a4 100644 --- a/tests/cpp/aot/llvm/kernel_aot_test.cpp +++ b/tests/cpp/aot/llvm/kernel_aot_test.cpp @@ -58,8 +58,8 @@ TEST(LlvmAotTest, CpuKernel) { } k_run->launch(builder); - auto *data = reinterpret_cast( - exec.get_ndarray_alloc_info_ptr(arr_devalloc)); + auto *data = + reinterpret_cast(exec.get_device_alloc_info_ptr(arr_devalloc)); for (int i = 0; i < kArrLen; ++i) { EXPECT_EQ(data[i], i + vec[0]); } @@ -106,7 +106,7 @@ TEST(LlvmAotTest, CudaKernel) { k_run->launch(builder); auto *data = reinterpret_cast( - exec.get_ndarray_alloc_info_ptr(arr_devalloc)); + exec.get_device_alloc_info_ptr(arr_devalloc)); std::vector cpu_data(kArrLen); CUDADriver::get_instance().memcpy_device_to_host( diff --git a/tests/python/test_argpack.py b/tests/python/test_argpack.py index 534902816baad..1c89f30988863 100644 --- a/tests/python/test_argpack.py +++ b/tests/python/test_argpack.py @@ -41,7 +41,6 @@ def foo(p1: pack_type1, p2: pack_type2) -> ti.f32: assert foo(pack1, pack2) == test_utils.approx(1 * 2.1 + 2.0, rel=1e-3) -@pytest.mark.skip(reason="nested argpacks not supported currently") @test_utils.test() def test_argpack_nested(): arr = ti.ndarray(dtype=ti.math.vec3, shape=(4, 4))
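
How the nesting is wired up, per the diff above: a nested argpack is stored inside its parent as a one-member struct whose single field, "data_ptr", points at the child pack's device buffer (see TypeFactory::get_struct_type_for_argpack_ptr and TypeFactory::DATA_PTR_POS_IN_ARGPACK). At launch time only top-level packs are written into the kernel argument buffer (the key.size() == 1 branch in the kernel launchers); nested packs instead have their device pointers written into the parent pack through ArgPack::set_arg_nested_argpack_ptr, and TaskCodeGenLLVM::get_argpack_arg resolves a nested member by walking that pointer chain recursively.

The sketch below only illustrates the user-facing behavior this patch enables; it is not taken from the patch. It assumes the ti.types.argpack(...) factory exercised in tests/python/test_argpack.py, and the member names (base, scale, offset, inner) as well as the nested access p.inner.base are invented for the example.

    import taichi as ti

    ti.init(arch=ti.cpu)

    # Hypothetical packs: an argpack type is used directly as a member of
    # another argpack type, which is what this patch starts to allow.
    inner_pack = ti.types.argpack(base=ti.i32, scale=ti.f32)
    outer_pack = ti.types.argpack(inner=inner_pack, offset=ti.f32)


    @ti.kernel
    def combine(p: outer_pack) -> ti.f32:
        # Nested members are assumed to be reachable through the parent pack.
        return p.inner.base * p.inner.scale + p.offset


    inner = inner_pack(base=2, scale=1.5)
    outer = outer_pack(inner=inner, offset=0.5)
    print(combine(outer))  # expected: 3.5, under the assumptions above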