[lang] [ir] Support argpack nesting #8273

Closed · wants to merge 3 commits
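For context, this PR lets one argpack be passed as a member of another. A minimal usage sketch of what the change enables (hypothetical example; `ti.types.argpack` is the existing API, while the nesting shown here assumes this PR's semantics):

```python
import taichi as ti

ti.init(arch=ti.cpu)

inner_tmpl = ti.types.argpack(scale=ti.f32)
outer_tmpl = ti.types.argpack(inner=inner_tmpl, offset=ti.i32)

inner = inner_tmpl(scale=2.0)
outer = outer_tmpl(inner=inner, offset=3)

@ti.kernel
def compute(pack: outer_tmpl) -> ti.f32:
    # The nested member is reached through the inner pack's data pointer.
    return pack.inner.scale + pack.offset

assert compute(outer) == 5.0
```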
17 changes: 12 additions & 5 deletions python/taichi/lang/argpack.py
@@ -198,7 +198,7 @@ def _write_to_device(self, needed, provided, v, index):
         if isinstance(needed, ArgPackType):
             if not isinstance(v, ArgPack):
                 raise TaichiRuntimeTypeError.get(index, str(needed), str(provided))
-            pass  # Do nothing
+            self.__argpack.set_arg_nested_argpack(index, v.__argpack)
         else:
             # Note: do not use sth like "needed == f32". That would be slow.
             if id(needed) in primitive_types.real_type_ids:
@@ -293,7 +293,14 @@ def __init__(self, **kwargs):
                 elements.append([dtype.dtype, k])
             elif isinstance(dtype, ArgPackType):
                 self.members[k] = dtype
-                raise TaichiSyntaxError("ArgPack nesting is not supported currently.")
+                elements.append(
+                    [
+                        _ti_core.DataType(
+                            _ti_core.get_type_factory_instance().get_struct_type_for_argpack_ptr(dtype.dtype)
+                        ),
+                        k,
+                    ]
+                )
             elif isinstance(dtype, MatrixType):
                 # Convert MatrixType to StructType
                 if dtype.ndim == 1:
@@ -314,9 +321,9 @@ def __init__(self, **kwargs):
                 dtype = cook_dtype(dtype)
                 self.members[k] = dtype
                 elements.append([dtype, k])
-                if len(elements) == 0:
-                    # Use i32 as a placeholder for empty argpacks
-                    elements.append([primitive_types.i32, k])
+        if len(elements) == 0:
+            # Use i32 as a placeholder for empty argpacks
+            elements.append([primitive_types.i32, k])
         self.dtype = _ti_core.get_type_factory_instance().get_argpack_type(elements)

     def __call__(self, *args, **kwargs):
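The `__init__` hunk above is the heart of the Python-side change: a nested `ArgPackType` member no longer raises but contributes a pointer-sized wrapper element, and the empty-pack `i32` placeholder now runs after the loop instead of inside it. A toy model of the layout decision (illustrative only; plain Python stands in for the `_ti_core` type objects):

```python
def build_elements(members: dict) -> list:
    """Toy model of ArgPackType.__init__ after this change."""
    elements = []
    for name, dtype in members.items():
        if isinstance(dtype, dict):  # stand-in for a nested ArgPackType
            # The child pack is not inlined; the parent holds one
            # pointer-sized field, struct {data_ptr: *child_struct}.
            elements.append(({"data_ptr": dtype}, name))
        else:
            elements.append((dtype, name))
    if not elements:
        # Placeholder for empty packs -- now correctly outside the loop.
        elements.append(("i32", "placeholder"))
    return elements

print(build_elements({"inner": {"scale": "f32"}, "offset": "i32"}))
```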
8 changes: 4 additions & 4 deletions python/taichi/lang/kernel_impl.py
@@ -863,14 +863,14 @@ def recursive_set_args(needed, provided, v, indices):
             if not isinstance(v, (float, int, np.floating, np.integer)):
                 raise TaichiRuntimeTypeError.get(indices, needed.to_string(), provided)
             if in_argpack:
-                return 0
+                return 1
             launch_ctx.set_arg_float(indices, float(v))
             return 1
         if id(needed) in primitive_types.integer_type_ids:
             if not isinstance(v, (int, np.integer)):
                 raise TaichiRuntimeTypeError.get(indices, needed.to_string(), provided)
             if in_argpack:
-                return 0
+                return 1
             if is_signed(cook_dtype(needed)):
                 launch_ctx.set_arg_int(indices, int(v))
             else:
@@ -908,12 +908,12 @@ def recursive_set_args(needed, provided, v, indices):
             return 1
         if isinstance(needed, MatrixType):
             if in_argpack:
-                return 0
+                return 1
             set_arg_matrix(indices, v, needed)
             return 1
         if isinstance(needed, StructType):
             if in_argpack:
-                return 0
+                return 1
             if not isinstance(v, needed):
                 raise TaichiRuntimeTypeError(f"Argument {provided} cannot be converted into required type {needed}")
             needed.set_kernel_struct_args(v, launch_ctx, indices)
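The four `return 0` → `return 1` flips all say the same thing: a member inside an argpack now reports one consumed slot to the recursive walk, even though its bytes are written through the pack's own buffer rather than `launch_ctx`. One plausible reading of why this matters for nesting — the caller advances its per-member index by the returned count, so a `0` would stall the cursor (toy model, not the real signatures):

```python
def set_members(members, values, base, in_argpack=True):
    """Toy model of the index bookkeeping around recursive_set_args."""
    idx = 0
    for needed, v in zip(members, values):
        # With the old `return 0`, idx never advanced inside a pack, so
        # every member would land on the same element index.
        idx += set_one(needed, v, base + [idx], in_argpack)
    return idx

def set_one(needed, v, indices, in_argpack):
    if in_argpack:
        return 1  # value lives in the pack's buffer; just advance the cursor
    return 1      # top-level: also writes through launch_ctx (elided here)

assert set_members(["f32", "i32"], [1.0, 2], [0]) == 2
```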
17 changes: 12 additions & 5 deletions taichi/codegen/llvm/codegen_llvm.cpp
@@ -2883,21 +2883,28 @@ void TaskCodeGenLLVM::set_struct_to_buffer(
                            current_element, current_index);
 }

-llvm::Value *TaskCodeGenLLVM::get_argpack_arg(std::vector<int> arg_id,
+llvm::Value *TaskCodeGenLLVM::get_argpack_arg(const std::vector<int> &arg_id,
                                               int arg_depth,
                                               bool create_load) {
   const std::vector<int> indices_l(arg_id.begin(), arg_id.begin() + arg_depth);
   const std::vector<int> indices_r(arg_id.begin() + arg_depth, arg_id.end());
-  auto indices_data_ptr = indices_l;
-  indices_data_ptr.push_back(TypeFactory::DATA_PTR_POS_IN_ARGPACK);
-  auto data_ptr_value = get_struct_arg(indices_data_ptr, true);
+  llvm::Value *data_ptr_value;
   auto argpack_iterator =
       std::find_if(current_callable->argpack_types.begin(),
                    current_callable->argpack_types.end(),
                    [&](const auto &kv) { return kv.first == indices_l; });
   TI_ASSERT(argpack_iterator != current_callable->argpack_types.end());
   const auto *argpack_type = (*argpack_iterator).second;
   auto *arg_type = argpack_type->get_element_type(indices_r);
+  if (arg_depth > 1) {
+    auto key = arg_id;
+    key.back() = TypeFactory::DATA_PTR_POS_IN_ARGPACK;
+    data_ptr_value = get_argpack_arg(key, arg_depth - 1, true);
+  } else {
+    auto indices_data_ptr = indices_l;
+    indices_data_ptr.push_back(TypeFactory::DATA_PTR_POS_IN_ARGPACK);
+    data_ptr_value = get_struct_arg(indices_data_ptr, true);
+  }
   std::vector<llvm::Value *> gep_index;
   gep_index.reserve(indices_r.size());
   gep_index.push_back(tlctx->get_constant(0));
@@ -2912,7 +2919,7 @@ llvm::Value *TaskCodeGenLLVM::get_argpack_arg(std::vector<int> arg_id,
   return builder->CreateLoad(tlctx->get_data_type(arg_type), gep);
 }

-llvm::Value *TaskCodeGenLLVM::get_struct_arg(std::vector<int> index,
+llvm::Value *TaskCodeGenLLVM::get_struct_arg(const std::vector<int> &index,
                                              bool create_load) {
   auto *args_ptr = get_args_ptr(current_callable, get_context());
   auto *args_type = current_callable->args_type;
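The new branch makes address resolution recursive: for a member `arg_depth` packs deep, codegen first resolves the enclosing pack's `data_ptr` one level up (replacing the trailing index with `DATA_PTR_POS_IN_ARGPACK`), bottoming out at the kernel argument struct. A toy model with dicts standing in for the GEP chains (illustrative only):

```python
def get_argpack_arg(arg_id, depth, args):
    """Toy model of TaskCodeGenLLVM::get_argpack_arg."""
    if depth > 1:
        # Recursive case: this pack's base address is itself an element
        # of the enclosing pack, at the data_ptr slot.
        base = get_argpack_arg(arg_id[:-1] + ["data_ptr"], depth - 1, args)
    else:
        # Depth 1: the top-level pack's data_ptr sits in the kernel args.
        base = args[arg_id[0]]["data_ptr"]
    val = base
    for idx in arg_id[depth:]:  # indices_r: the path inside this pack
        val = val[idx]
    return val

# Arg 0 is an outer pack whose member 0 is an inner pack with member 0 == 42.
args = {0: {"data_ptr": {0: {"data_ptr": {0: 42}}}}}
assert get_argpack_arg([0, 0, 0], 2, args) == 42
```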
4 changes: 2 additions & 2 deletions taichi/codegen/llvm/codegen_llvm.h
@@ -87,11 +87,11 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

   llvm::Value *get_arg(int i);

-  llvm::Value *get_argpack_arg(std::vector<int> index,
+  llvm::Value *get_argpack_arg(const std::vector<int> &index,
                                int arg_depth,
                                bool create_load);

-  llvm::Value *get_struct_arg(std::vector<int> index, bool create_load);
+  llvm::Value *get_struct_arg(const std::vector<int> &index, bool create_load);

   llvm::Value *get_args_ptr(const Callable *callable, llvm::Value *context);

14 changes: 14 additions & 0 deletions taichi/ir/type_factory.cpp
@@ -69,6 +69,20 @@ const Type *TypeFactory::get_argpack_type(
   return argpack_types_[key].get();
 }

+const Type *TypeFactory::get_struct_type_for_argpack_ptr(
+    DataType dt,
+    const std::string &layout) {
+  auto *type_inner =
+      this->get_struct_type(dt->get_type()->as<ArgPackType>()->elements(),
+                            layout)
+          ->as<StructType>();
+  auto *type_pointer =
+      this->get_pointer_type(const_cast<StructType *>(type_inner), false);
+  auto *type_outter =
+      this->get_struct_type({{type_pointer, "data_ptr"}})->as<StructType>();
+  return type_outter;
+}
+
 Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) {
   std::lock_guard<std::mutex> _(pointer_mut_);

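`get_struct_type_for_argpack_ptr` is the wrapper builder used by the Python change above: given an argpack type it returns `struct {data_ptr: *struct{...pack elements...}}`, so a parent pack stores only the child's base address. A toy byte-level picture (assuming 64-bit pointers; the address is made up):

```python
import struct

# The parent's element for a nested pack is a single 8-byte pointer field.
child_device_address = 0x7F00DEADBEEF  # hypothetical allocation address
wrapper_bytes = struct.pack("<Q", child_device_address)
assert len(wrapper_bytes) == 8  # sizeof(struct {data_ptr}) on LP64
```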
4 changes: 4 additions & 0 deletions taichi/ir/type_factory.h
@@ -30,6 +30,10 @@ class TypeFactory {
       const std::vector<AbstractDictionaryMember> &elements,
       const std::string &layout = "none");

+  const Type *get_struct_type_for_argpack_ptr(
+      DataType dt,
+      const std::string &layout = "none");
+
   const Type *get_ndarray_struct_type(DataType dt,
                                       int ndim,
                                       bool needs_grad = false);
32 changes: 26 additions & 6 deletions taichi/program/argpack.cpp
@@ -45,7 +45,7 @@ DataType ArgPack::get_data_type() const {
 TypedConstant ArgPack::read(const std::vector<int> &I) const {
   prog_->synchronize();
   size_t offset = dtype->as<ArgPackType>()->get_element_offset(I);
-  DataType element_dt = dtype->as<ArgPackType>()->get_element_type(I);
+  DataType element_dt = get_element_dt(I);
   size_t size = data_type_size(element_dt);
   taichi::lang::Device::AllocParams alloc_params;
   alloc_params.host_write = false;
@@ -73,9 +73,9 @@ TypedConstant ArgPack::read(const std::vector<int> &I) const {
   return data;
 }

-void ArgPack::write(const std::vector<int> &I, TypedConstant val) {
+void ArgPack::write(const std::vector<int> &I, TypedConstant val) const {
   size_t offset = dtype->as<ArgPackType>()->get_element_offset(I);
-  DataType element_dt = dtype->as<ArgPackType>()->get_element_type(I);
+  DataType element_dt = get_element_dt(I);
   size_t size = data_type_size(element_dt);
   if (element_dt->is_primitive(PrimitiveTypeID::f16)) {
     uint16_t float16 = fp16_ieee_from_fp32_value(val.val_f32);
@@ -105,19 +105,39 @@ void ArgPack::write(const std::vector<int> &I, TypedConstant val) {
   prog_->synchronize();
 }

-void ArgPack::set_arg_int(const std::vector<int> &i, int64 val) {
+void ArgPack::set_arg_int(const std::vector<int> &i, int64 val) const {
   DataType element_dt = dtype->as<ArgPackType>()->get_element_type(i);
   write(i, TypedConstant(element_dt, val));
 }

-void ArgPack::set_arg_float(const std::vector<int> &i, float64 val) {
+void ArgPack::set_arg_float(const std::vector<int> &i, float64 val) const {
   DataType element_dt = dtype->as<ArgPackType>()->get_element_type(i);
   write(i, TypedConstant(element_dt, val));
 }

-void ArgPack::set_arg_uint(const std::vector<int> &i, uint64 val) {
+void ArgPack::set_arg_uint(const std::vector<int> &i, uint64 val) const {
   DataType element_dt = dtype->as<ArgPackType>()->get_element_type(i);
   write(i, TypedConstant(element_dt, val));
 }

+void ArgPack::set_arg_nested_argpack(int i, const ArgPack &val) const {
+  const std::vector<int> indices = {i, TypeFactory::DATA_PTR_POS_IN_ARGPACK};
+  DataType element_dt = get_element_dt(indices);
+  write(indices,
+        TypedConstant(element_dt, val.get_device_allocation_ptr_as_int()));
+}
+
+void ArgPack::set_arg_nested_argpack_ptr(int i, intptr_t val) const {
+  const std::vector<int> indices = {i, TypeFactory::DATA_PTR_POS_IN_ARGPACK};
+  DataType element_dt = get_element_dt(indices);
+  write(indices, TypedConstant(element_dt, val));
+}
+
+DataType ArgPack::get_element_dt(const std::vector<int> &i) const {
+  auto dt = dtype->as<ArgPackType>()->get_element_type(i);
+  if (dt->is<PointerType>())
+    return PrimitiveType::u64;
+  return dt;
+}
+
 }  // namespace taichi::lang
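The `get_element_dt` helper is what lets `read`/`write` handle the new pointer field: a pointer-typed element (the nested pack's `data_ptr`) is moved as a plain `u64`, which is presumably why `set_arg_nested_argpack` can funnel the child's device address through `TypedConstant`. A one-line toy rendering:

```python
def element_dt(declared: str) -> str:
    """Toy model of ArgPack::get_element_dt: pointers travel as u64."""
    return "u64" if declared.startswith("*") else declared

assert element_dt("*struct{scale: f32}") == "u64"
assert element_dt("f32") == "f32"
```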
12 changes: 8 additions & 4 deletions taichi/program/argpack.h
@@ -27,14 +27,18 @@ class TI_DLL_EXPORT ArgPack {
   std::size_t get_nelement() const;

   TypedConstant read(const std::vector<int> &I) const;
-  void write(const std::vector<int> &I, TypedConstant val);
-  void set_arg_int(const std::vector<int> &i, int64 val);
-  void set_arg_uint(const std::vector<int> &i, uint64 val);
-  void set_arg_float(const std::vector<int> &i, float64 val);
+  void write(const std::vector<int> &I, TypedConstant val) const;
+  void set_arg_int(const std::vector<int> &i, int64 val) const;
+  void set_arg_uint(const std::vector<int> &i, uint64 val) const;
+  void set_arg_float(const std::vector<int> &i, float64 val) const;
+  void set_arg_nested_argpack(int i, const ArgPack &val) const;
+  void set_arg_nested_argpack_ptr(int i, intptr_t val) const;

   ~ArgPack();

  private:
   Program *prog_{nullptr};
+
+  DataType get_element_dt(const std::vector<int> &i) const;
 };
 }  // namespace taichi::lang
5 changes: 4 additions & 1 deletion taichi/program/launch_context_builder.cpp
@@ -252,7 +252,10 @@ void LaunchContextBuilder::set_arg_ndarray(const std::vector<int> &arg_id,
 void LaunchContextBuilder::set_arg_argpack(const std::vector<int> &arg_id,
                                            const ArgPack &argpack) {
   argpack_ptrs[arg_id] = &argpack;
-  set_argpack_ptr(arg_id, argpack.get_device_allocation_ptr_as_int());
+  if (arg_id.size() == 1) {
+    // Only set ptr to arg buffer if this argpack is not nested
+    set_argpack_ptr(arg_id, argpack.get_device_allocation_ptr_as_int());
+  }
   // TODO: Consider renaming this method to `set_device_allocation_type`
   set_array_device_allocation_type(arg_id, DevAllocType::kArgPack);
 }
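So after this change only a top-level pack (`arg_id.size() == 1`) publishes its device pointer into the kernel's argument buffer; a nested pack's pointer goes into its parent's `data_ptr` slot instead, via the launcher patch-up shown below. The decision rule, as a tiny sketch:

```python
def writes_to_arg_buffer(arg_id) -> bool:
    """Toy model of the new guard: only top-level packs hit the arg buffer."""
    return len(arg_id) == 1

assert writes_to_arg_buffer([0])         # kernel argument 0
assert not writes_to_arg_buffer([0, 2])  # member 2 nested inside argument 0
```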
2 changes: 1 addition & 1 deletion taichi/program/program.cpp
@@ -465,7 +465,7 @@ intptr_t Program::get_ndarray_data_ptr_as_int(const Ndarray *ndarray) {
       compile_config().arch == Arch::amdgpu) {
     // For the LLVM backends, device allocation is a physical pointer.
     data_ptr =
-        program_impl_->get_ndarray_alloc_info_ptr(ndarray->ndarray_alloc_);
+        program_impl_->get_device_alloc_info_ptr(ndarray->ndarray_alloc_);
   }

   return reinterpret_cast<intptr_t>(data_ptr);
4 changes: 2 additions & 2 deletions taichi/program/program_impl.h
@@ -112,9 +112,9 @@ class ProgramImpl {
   }

   // TODO: Move to Runtime Object
-  virtual uint64_t *get_ndarray_alloc_info_ptr(const DeviceAllocation &alloc) {
+  virtual uint64_t *get_device_alloc_info_ptr(const DeviceAllocation &alloc) {
     TI_ERROR(
-        "get_ndarray_alloc_info_ptr() not implemented on the current backend");
+        "get_device_alloc_info_ptr() not implemented on the current backend");
     return nullptr;
   }

4 changes: 4 additions & 0 deletions taichi/python/export_lang.cpp
@@ -592,6 +592,7 @@ void export_lang(py::module &m) {
       .def("set_arg_float", &ArgPack::set_arg_float)
       .def("set_arg_int", &ArgPack::set_arg_int)
       .def("set_arg_uint", &ArgPack::set_arg_uint)
+      .def("set_arg_nested_argpack", &ArgPack::set_arg_nested_argpack)
       .def_readonly("dtype", &ArgPack::dtype);

   py::enum_<BufferFormat>(m, "Format")
@@ -1251,6 +1252,9 @@ void export_lang(py::module &m) {
       .def("get_ndarray_struct_type", &TypeFactory::get_ndarray_struct_type,
           py::arg("dt"), py::arg("ndim"), py::arg("needs_grad"),
           py::return_value_policy::reference)
+      .def("get_struct_type_for_argpack_ptr",
+           &TypeFactory::get_struct_type_for_argpack_ptr, py::arg("dt"),
+           py::arg("layout") = "none", py::return_value_policy::reference)
       .def(
           "get_argpack_type",
           [&](TypeFactory *factory,
16 changes: 12 additions & 4 deletions taichi/runtime/amdgpu/kernel_launcher.cpp
@@ -59,7 +59,7 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
           DeviceAllocation devalloc = executor->allocate_memory_on_device(
               arr_sz, (uint64 *)device_result_buffer);
           device_ptrs[data_ptr_idx] =
-              executor->get_ndarray_alloc_info_ptr(devalloc);
+              executor->get_device_alloc_info_ptr(devalloc);
           transfers[data_ptr_idx] = {data_ptr, devalloc};

           AMDGPUDriver::get_instance().memcpy_host_to_device(
@@ -71,7 +71,7 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
         // Ndarray
         DeviceAllocation *ptr = static_cast<DeviceAllocation *>(data_ptr);
         // Unwrapped raw ptr on device
-        device_ptrs[data_ptr_idx] = executor->get_ndarray_alloc_info_ptr(*ptr);
+        device_ptrs[data_ptr_idx] = executor->get_device_alloc_info_ptr(*ptr);

         ctx.set_ndarray_ptrs(key, (uint64)device_ptrs[data_ptr_idx],
                              (uint64)ctx.array_ptrs[grad_ptr_idx]);
@@ -82,8 +82,16 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
         auto *argpack = ctx.argpack_ptrs[key];
         auto argpack_ptr = argpack->get_device_allocation();
         device_ptrs[data_ptr_idx] =
-            executor->get_ndarray_alloc_info_ptr(argpack_ptr);
-        ctx.set_argpack_ptr(key, (uint64)device_ptrs[data_ptr_idx]);
+            executor->get_device_alloc_info_ptr(argpack_ptr);
+        if (key.size() == 1) {
+          ctx.set_argpack_ptr(key, (uint64)device_ptrs[data_ptr_idx]);
+        } else {
+          auto key_parent = key;
+          key_parent.pop_back();
+          auto *argpack_parent = ctx.argpack_ptrs[key_parent];
+          argpack_parent->set_arg_nested_argpack_ptr(
+              key.back(), (uint64)device_ptrs[data_ptr_idx]);
+        }
       }
     }
     if (transfers.size() > 0) {
15 changes: 11 additions & 4 deletions taichi/runtime/cpu/kernel_launcher.cpp
@@ -34,14 +34,14 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
           ctx.array_runtime_sizes[key] > 0) {
         DeviceAllocation *ptr =
             static_cast<DeviceAllocation *>(ctx.array_ptrs[data_ptr_idx]);
-        uint64 host_ptr = (uint64)executor->get_ndarray_alloc_info_ptr(*ptr);
+        uint64 host_ptr = (uint64)executor->get_device_alloc_info_ptr(*ptr);
         ctx.set_array_device_allocation_type(
             key, LaunchContextBuilder::DevAllocType::kNone);

         auto grad_ptr = ctx.array_ptrs[grad_ptr_idx];
         uint64 host_ptr_grad =
             grad_ptr == nullptr ? 0
-                                : (uint64)executor->get_ndarray_alloc_info_ptr(
+                                : (uint64)executor->get_device_alloc_info_ptr(
                                       *static_cast<DeviceAllocation *>(grad_ptr));
         ctx.set_ndarray_ptrs(key, host_ptr, host_ptr_grad);
       }
@@ -51,8 +51,15 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
         auto *argpack = ctx.argpack_ptrs[key];
         auto argpack_ptr = argpack->get_device_allocation();
         uint64 host_ptr =
-            (uint64)executor->get_ndarray_alloc_info_ptr(argpack_ptr);
-        ctx.set_argpack_ptr(key, host_ptr);
+            (uint64)executor->get_device_alloc_info_ptr(argpack_ptr);
+        if (key.size() == 1) {
+          ctx.set_argpack_ptr(key, host_ptr);
+        } else {
+          auto key_parent = key;
+          key_parent.pop_back();
+          auto *argpack_parent = ctx.argpack_ptrs[key_parent];
+          argpack_parent->set_arg_nested_argpack_ptr(key.back(), host_ptr);
+        }
       }
     }
     for (auto task : launcher_ctx.task_funcs) {
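Both LLVM launchers (AMDGPU above, CPU here) now follow the same sequence for each pack: resolve its `DeviceAllocation` to a raw address via the renamed `get_device_alloc_info_ptr`, then either publish it through `set_argpack_ptr` (top-level) or write it into the parent pack with `set_arg_nested_argpack_ptr` (nested). A self-contained simulation of that patch-up (stub classes, illustrative only; assumes parents are visited before children):

```python
class PackStub:
    """Stand-in for taichi::lang::ArgPack (toy model)."""
    def __init__(self, addr):
        self.addr = addr   # the device allocation's raw address
        self.slots = {}    # member index -> patched child address
    def set_arg_nested_argpack_ptr(self, i, ptr):
        self.slots[i] = ptr

# ctx.argpack_ptrs analogue: keys are arg-id paths, ordered parent-first.
argpack_ptrs = {(0,): PackStub(0x1000), (0, 0): PackStub(0x2000)}
arg_buffer = {}
for key, pack in argpack_ptrs.items():
    if len(key) == 1:
        arg_buffer[key] = pack.addr  # set_argpack_ptr: top-level pack
    else:
        argpack_ptrs[key[:-1]].set_arg_nested_argpack_ptr(key[-1], pack.addr)

assert arg_buffer == {(0,): 0x1000}
assert argpack_ptrs[(0,)].slots == {0: 0x2000}
```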