Skip to content

Commit

Permalink
[lang] [ir] Support argpack nesting
Browse files Browse the repository at this point in the history
ghstack-source-id: 7bde32b1715ed831bd853ab3616b52e522705201
Pull Request resolved: #8273
  • Loading branch information
listerily authored and Taichi Gardener committed Jul 13, 2023
1 parent 22a32e3 commit 2d99b41
Show file tree
Hide file tree
Showing 22 changed files with 143 additions and 59 deletions.
17 changes: 12 additions & 5 deletions python/taichi/lang/argpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def _write_to_device(self, needed, provided, v, index):
if isinstance(needed, ArgPackType):
if not isinstance(v, ArgPack):
raise TaichiRuntimeTypeError.get(index, str(needed), str(provided))
pass # Do nothing
self.__argpack.set_arg_nested_argpack(index, v.__argpack)
else:
# Note: do not use sth like "needed == f32". That would be slow.
if id(needed) in primitive_types.real_type_ids:
Expand Down Expand Up @@ -293,7 +293,14 @@ def __init__(self, **kwargs):
elements.append([dtype.dtype, k])
elif isinstance(dtype, ArgPackType):
self.members[k] = dtype
raise TaichiSyntaxError("ArgPack nesting is not supported currently.")
elements.append(
[
_ti_core.DataType(
_ti_core.get_type_factory_instance().get_struct_type_for_argpack_ptr(dtype.dtype)
),
k,
]
)
elif isinstance(dtype, MatrixType):
# Convert MatrixType to StructType
if dtype.ndim == 1:
Expand All @@ -314,9 +321,9 @@ def __init__(self, **kwargs):
dtype = cook_dtype(dtype)
self.members[k] = dtype
elements.append([dtype, k])
if len(elements) == 0:
# Use i32 as a placeholder for empty argpacks
elements.append([primitive_types.i32, k])
if len(elements) == 0:
# Use i32 as a placeholder for empty argpacks
elements.append([primitive_types.i32, k])
self.dtype = _ti_core.get_type_factory_instance().get_argpack_type(elements)

def __call__(self, *args, **kwargs):
Expand Down
8 changes: 4 additions & 4 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,14 +863,14 @@ def recursive_set_args(needed, provided, v, indices):
if not isinstance(v, (float, int, np.floating, np.integer)):
raise TaichiRuntimeTypeError.get(indices, needed.to_string(), provided)
if in_argpack:
return 0
return 1
launch_ctx.set_arg_float(indices, float(v))
return 1
if id(needed) in primitive_types.integer_type_ids:
if not isinstance(v, (int, np.integer)):
raise TaichiRuntimeTypeError.get(indices, needed.to_string(), provided)
if in_argpack:
return 0
return 1
if is_signed(cook_dtype(needed)):
launch_ctx.set_arg_int(indices, int(v))
else:
Expand Down Expand Up @@ -908,12 +908,12 @@ def recursive_set_args(needed, provided, v, indices):
return 1
if isinstance(needed, MatrixType):
if in_argpack:
return 0
return 1
set_arg_matrix(indices, v, needed)
return 1
if isinstance(needed, StructType):
if in_argpack:
return 0
return 1
if not isinstance(v, needed):
raise TaichiRuntimeTypeError(f"Argument {provided} cannot be converted into required type {needed}")
needed.set_kernel_struct_args(v, launch_ctx, indices)
Expand Down
17 changes: 12 additions & 5 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2883,21 +2883,28 @@ void TaskCodeGenLLVM::set_struct_to_buffer(
current_element, current_index);
}

llvm::Value *TaskCodeGenLLVM::get_argpack_arg(std::vector<int> arg_id,
llvm::Value *TaskCodeGenLLVM::get_argpack_arg(const std::vector<int> &arg_id,
int arg_depth,
bool create_load) {
const std::vector<int> indices_l(arg_id.begin(), arg_id.begin() + arg_depth);
const std::vector<int> indices_r(arg_id.begin() + arg_depth, arg_id.end());
auto indices_data_ptr = indices_l;
indices_data_ptr.push_back(TypeFactory::DATA_PTR_POS_IN_ARGPACK);
auto data_ptr_value = get_struct_arg(indices_data_ptr, true);
llvm::Value *data_ptr_value;
auto argpack_iterator =
std::find_if(current_callable->argpack_types.begin(),
current_callable->argpack_types.end(),
[&](const auto &kv) { return kv.first == indices_l; });
TI_ASSERT(argpack_iterator != current_callable->argpack_types.end());
const auto *argpack_type = (*argpack_iterator).second;
auto *arg_type = argpack_type->get_element_type(indices_r);
if (arg_depth > 1) {
auto key = arg_id;
key.back() = TypeFactory::DATA_PTR_POS_IN_ARGPACK;
data_ptr_value = get_argpack_arg(key, arg_depth - 1, true);
} else {
auto indices_data_ptr = indices_l;
indices_data_ptr.push_back(TypeFactory::DATA_PTR_POS_IN_ARGPACK);
data_ptr_value = get_struct_arg(indices_data_ptr, true);
}
std::vector<llvm::Value *> gep_index;
gep_index.reserve(indices_r.size());
gep_index.push_back(tlctx->get_constant(0));
Expand All @@ -2912,7 +2919,7 @@ llvm::Value *TaskCodeGenLLVM::get_argpack_arg(std::vector<int> arg_id,
return builder->CreateLoad(tlctx->get_data_type(arg_type), gep);
}

llvm::Value *TaskCodeGenLLVM::get_struct_arg(std::vector<int> index,
llvm::Value *TaskCodeGenLLVM::get_struct_arg(const std::vector<int> &index,
bool create_load) {
auto *args_ptr = get_args_ptr(current_callable, get_context());
auto *args_type = current_callable->args_type;
Expand Down
4 changes: 2 additions & 2 deletions taichi/codegen/llvm/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,11 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

llvm::Value *get_arg(int i);

llvm::Value *get_argpack_arg(std::vector<int> index,
llvm::Value *get_argpack_arg(const std::vector<int> &index,
int arg_depth,
bool create_load);

llvm::Value *get_struct_arg(std::vector<int> index, bool create_load);
llvm::Value *get_struct_arg(const std::vector<int> &index, bool create_load);

llvm::Value *get_args_ptr(const Callable *callable, llvm::Value *context);

Expand Down
14 changes: 14 additions & 0 deletions taichi/ir/type_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ const Type *TypeFactory::get_argpack_type(
return argpack_types_[key].get();
}

const Type *TypeFactory::get_struct_type_for_argpack_ptr(
    DataType dt,
    const std::string &layout) {
  // Materialize the argpack's element list as a struct, then wrap a pointer
  // to it inside a one-field outer struct. A nested argpack therefore
  // occupies a single pointer-sized "data_ptr" slot in its parent.
  const auto *inner_struct =
      this->get_struct_type(dt->get_type()->as<ArgPackType>()->elements(),
                            layout)
          ->as<StructType>();
  // get_pointer_type takes a mutable Type*; the struct itself is not
  // modified, so the const_cast is safe here.
  auto *inner_ptr =
      this->get_pointer_type(const_cast<StructType *>(inner_struct), false);
  return this->get_struct_type({{inner_ptr, "data_ptr"}})->as<StructType>();
}

Type *TypeFactory::get_pointer_type(Type *element, bool is_bit_pointer) {
std::lock_guard<std::mutex> _(pointer_mut_);

Expand Down
4 changes: 4 additions & 0 deletions taichi/ir/type_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class TypeFactory {
const std::vector<AbstractDictionaryMember> &elements,
const std::string &layout = "none");

const Type *get_struct_type_for_argpack_ptr(
DataType dt,
const std::string &layout = "none");

const Type *get_ndarray_struct_type(DataType dt,
int ndim,
bool needs_grad = false);
Expand Down
32 changes: 26 additions & 6 deletions taichi/program/argpack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ DataType ArgPack::get_data_type() const {
TypedConstant ArgPack::read(const std::vector<int> &I) const {
prog_->synchronize();
size_t offset = dtype->as<ArgPackType>()->get_element_offset(I);
DataType element_dt = dtype->as<ArgPackType>()->get_element_type(I);
DataType element_dt = get_element_dt(I);
size_t size = data_type_size(element_dt);
taichi::lang::Device::AllocParams alloc_params;
alloc_params.host_write = false;
Expand Down Expand Up @@ -73,9 +73,9 @@ TypedConstant ArgPack::read(const std::vector<int> &I) const {
return data;
}

void ArgPack::write(const std::vector<int> &I, TypedConstant val) {
void ArgPack::write(const std::vector<int> &I, TypedConstant val) const {
size_t offset = dtype->as<ArgPackType>()->get_element_offset(I);
DataType element_dt = dtype->as<ArgPackType>()->get_element_type(I);
DataType element_dt = get_element_dt(I);
size_t size = data_type_size(element_dt);
if (element_dt->is_primitive(PrimitiveTypeID::f16)) {
uint16_t float16 = fp16_ieee_from_fp32_value(val.val_f32);
Expand Down Expand Up @@ -105,19 +105,39 @@ void ArgPack::write(const std::vector<int> &I, TypedConstant val) {
prog_->synchronize();
}

void ArgPack::set_arg_int(const std::vector<int> &i, int64 val) {
void ArgPack::set_arg_int(const std::vector<int> &i, int64 val) const {
DataType element_dt = dtype->as<ArgPackType>()->get_element_type(i);
write(i, TypedConstant(element_dt, val));
}

void ArgPack::set_arg_float(const std::vector<int> &i, float64 val) {
void ArgPack::set_arg_float(const std::vector<int> &i, float64 val) const {
DataType element_dt = dtype->as<ArgPackType>()->get_element_type(i);
write(i, TypedConstant(element_dt, val));
}

void ArgPack::set_arg_uint(const std::vector<int> &i, uint64 val) {
void ArgPack::set_arg_uint(const std::vector<int> &i, uint64 val) const {
DataType element_dt = dtype->as<ArgPackType>()->get_element_type(i);
write(i, TypedConstant(element_dt, val));
}

void ArgPack::set_arg_nested_argpack(int i, const ArgPack &val) const {
const std::vector<int> indices = {i, TypeFactory::DATA_PTR_POS_IN_ARGPACK};
DataType element_dt = get_element_dt(indices);
write(indices,
TypedConstant(element_dt, val.get_device_allocation_ptr_as_int()));
}

void ArgPack::set_arg_nested_argpack_ptr(int i, intptr_t val) const {
const std::vector<int> indices = {i, TypeFactory::DATA_PTR_POS_IN_ARGPACK};
DataType element_dt = get_element_dt(indices);
write(indices, TypedConstant(element_dt, val));
}

// Resolve the storage dtype for the element addressed by `i`.
// Pointer-typed elements (nested argpack data pointers) are stored as
// raw u64 addresses; everything else keeps its declared type.
DataType ArgPack::get_element_dt(const std::vector<int> &i) const {
  const auto element_type = dtype->as<ArgPackType>()->get_element_type(i);
  if (element_type->is<PointerType>()) {
    return PrimitiveType::u64;
  }
  return element_type;
}

} // namespace taichi::lang
12 changes: 8 additions & 4 deletions taichi/program/argpack.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,18 @@ class TI_DLL_EXPORT ArgPack {
std::size_t get_nelement() const;

TypedConstant read(const std::vector<int> &I) const;
void write(const std::vector<int> &I, TypedConstant val);
void set_arg_int(const std::vector<int> &i, int64 val);
void set_arg_uint(const std::vector<int> &i, uint64 val);
void set_arg_float(const std::vector<int> &i, float64 val);
void write(const std::vector<int> &I, TypedConstant val) const;
void set_arg_int(const std::vector<int> &i, int64 val) const;
void set_arg_uint(const std::vector<int> &i, uint64 val) const;
void set_arg_float(const std::vector<int> &i, float64 val) const;
void set_arg_nested_argpack(int i, const ArgPack &val) const;
void set_arg_nested_argpack_ptr(int i, intptr_t val) const;

~ArgPack();

private:
Program *prog_{nullptr};

DataType get_element_dt(const std::vector<int> &i) const;
};
} // namespace taichi::lang
5 changes: 4 additions & 1 deletion taichi/program/launch_context_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,10 @@ void LaunchContextBuilder::set_arg_ndarray(const std::vector<int> &arg_id,
// Register `argpack` as the argument at `arg_id` and record its device
// allocation type.
// (Diff artifact fixed: the rendered page kept the deleted unconditional
// set_argpack_ptr call next to its guarded replacement, which would set the
// pointer even for nested argpacks; only the guarded call is kept.)
void LaunchContextBuilder::set_arg_argpack(const std::vector<int> &arg_id,
                                           const ArgPack &argpack) {
  argpack_ptrs[arg_id] = &argpack;
  if (arg_id.size() == 1) {
    // Only set ptr to arg buffer if this argpack is not nested; nested
    // argpacks get their pointer written into the parent pack instead.
    set_argpack_ptr(arg_id, argpack.get_device_allocation_ptr_as_int());
  }
  // TODO: Consider renaming this method to `set_device_allocation_type`
  set_array_device_allocation_type(arg_id, DevAllocType::kArgPack);
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ intptr_t Program::get_ndarray_data_ptr_as_int(const Ndarray *ndarray) {
compile_config().arch == Arch::amdgpu) {
// For the LLVM backends, device allocation is a physical pointer.
data_ptr =
program_impl_->get_ndarray_alloc_info_ptr(ndarray->ndarray_alloc_);
program_impl_->get_device_alloc_info_ptr(ndarray->ndarray_alloc_);
}

return reinterpret_cast<intptr_t>(data_ptr);
Expand Down
4 changes: 2 additions & 2 deletions taichi/program/program_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ class ProgramImpl {
}

// TODO: Move to Runtime Object
virtual uint64_t *get_ndarray_alloc_info_ptr(const DeviceAllocation &alloc) {
virtual uint64_t *get_device_alloc_info_ptr(const DeviceAllocation &alloc) {
TI_ERROR(
"get_ndarray_alloc_info_ptr() not implemented on the current backend");
"get_device_alloc_info_ptr() not implemented on the current backend");
return nullptr;
}

Expand Down
4 changes: 4 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,7 @@ void export_lang(py::module &m) {
.def("set_arg_float", &ArgPack::set_arg_float)
.def("set_arg_int", &ArgPack::set_arg_int)
.def("set_arg_uint", &ArgPack::set_arg_uint)
.def("set_arg_nested_argpack", &ArgPack::set_arg_nested_argpack)
.def_readonly("dtype", &ArgPack::dtype);

py::enum_<BufferFormat>(m, "Format")
Expand Down Expand Up @@ -1251,6 +1252,9 @@ void export_lang(py::module &m) {
.def("get_ndarray_struct_type", &TypeFactory::get_ndarray_struct_type,
py::arg("dt"), py::arg("ndim"), py::arg("needs_grad"),
py::return_value_policy::reference)
.def("get_struct_type_for_argpack_ptr",
&TypeFactory::get_struct_type_for_argpack_ptr, py::arg("dt"),
py::arg("layout") = "none", py::return_value_policy::reference)
.def(
"get_argpack_type",
[&](TypeFactory *factory,
Expand Down
16 changes: 12 additions & 4 deletions taichi/runtime/amdgpu/kernel_launcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
DeviceAllocation devalloc = executor->allocate_memory_on_device(
arr_sz, (uint64 *)device_result_buffer);
device_ptrs[data_ptr_idx] =
executor->get_ndarray_alloc_info_ptr(devalloc);
executor->get_device_alloc_info_ptr(devalloc);
transfers[data_ptr_idx] = {data_ptr, devalloc};

AMDGPUDriver::get_instance().memcpy_host_to_device(
Expand All @@ -71,7 +71,7 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
// Ndarray
DeviceAllocation *ptr = static_cast<DeviceAllocation *>(data_ptr);
// Unwrapped raw ptr on device
device_ptrs[data_ptr_idx] = executor->get_ndarray_alloc_info_ptr(*ptr);
device_ptrs[data_ptr_idx] = executor->get_device_alloc_info_ptr(*ptr);

ctx.set_ndarray_ptrs(key, (uint64)device_ptrs[data_ptr_idx],
(uint64)ctx.array_ptrs[grad_ptr_idx]);
Expand All @@ -82,8 +82,16 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
auto *argpack = ctx.argpack_ptrs[key];
auto argpack_ptr = argpack->get_device_allocation();
device_ptrs[data_ptr_idx] =
executor->get_ndarray_alloc_info_ptr(argpack_ptr);
ctx.set_argpack_ptr(key, (uint64)device_ptrs[data_ptr_idx]);
executor->get_device_alloc_info_ptr(argpack_ptr);
if (key.size() == 1) {
ctx.set_argpack_ptr(key, (uint64)device_ptrs[data_ptr_idx]);
} else {
auto key_parent = key;
key_parent.pop_back();
auto *argpack_parent = ctx.argpack_ptrs[key_parent];
argpack_parent->set_arg_nested_argpack_ptr(
key.back(), (uint64)device_ptrs[data_ptr_idx]);
}
}
}
if (transfers.size() > 0) {
Expand Down
15 changes: 11 additions & 4 deletions taichi/runtime/cpu/kernel_launcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
ctx.array_runtime_sizes[key] > 0) {
DeviceAllocation *ptr =
static_cast<DeviceAllocation *>(ctx.array_ptrs[data_ptr_idx]);
uint64 host_ptr = (uint64)executor->get_ndarray_alloc_info_ptr(*ptr);
uint64 host_ptr = (uint64)executor->get_device_alloc_info_ptr(*ptr);
ctx.set_array_device_allocation_type(
key, LaunchContextBuilder::DevAllocType::kNone);

auto grad_ptr = ctx.array_ptrs[grad_ptr_idx];
uint64 host_ptr_grad =
grad_ptr == nullptr ? 0
: (uint64)executor->get_ndarray_alloc_info_ptr(
: (uint64)executor->get_device_alloc_info_ptr(
*static_cast<DeviceAllocation *>(grad_ptr));
ctx.set_ndarray_ptrs(key, host_ptr, host_ptr_grad);
}
Expand All @@ -51,8 +51,15 @@ void KernelLauncher::launch_llvm_kernel(Handle handle,
auto *argpack = ctx.argpack_ptrs[key];
auto argpack_ptr = argpack->get_device_allocation();
uint64 host_ptr =
(uint64)executor->get_ndarray_alloc_info_ptr(argpack_ptr);
ctx.set_argpack_ptr(key, host_ptr);
(uint64)executor->get_device_alloc_info_ptr(argpack_ptr);
if (key.size() == 1) {
ctx.set_argpack_ptr(key, host_ptr);
} else {
auto key_parent = key;
key_parent.pop_back();
auto *argpack_parent = ctx.argpack_ptrs[key_parent];
argpack_parent->set_arg_nested_argpack_ptr(key.back(), host_ptr);
}
}
}
for (auto task : launcher_ctx.task_funcs) {
Expand Down
Loading

0 comments on commit 2d99b41

Please sign in to comment.