From ee6e353b760e9fa77d4bb13d8b2e32d21fbe94a1 Mon Sep 17 00:00:00 2001 From: listerily Date: Mon, 10 Jul 2023 13:34:03 +0800 Subject: [PATCH] [refactor] Make arg_id a vector ghstack-source-id: 53eb0ea78dcf07ec8d1897e852b7a31862646803 Pull Request resolved: https://github.com/taichi-dev/taichi/pull/8204 --- docs/lang/articles/advanced/argument_pack.md | 12 +- python/taichi/lang/ast/ast_transformer.py | 1 + python/taichi/lang/kernel_arguments.py | 5 + python/taichi/lang/kernel_impl.py | 346 +++++++++--------- taichi/ir/expression_printer.h | 4 +- taichi/ir/frontend_ir.cpp | 10 +- taichi/ir/frontend_ir.h | 19 +- taichi/program/callable.cpp | 140 ++++--- taichi/program/callable.h | 45 ++- taichi/program/program.cpp | 9 +- taichi/python/export_lang.cpp | 16 +- taichi/transforms/lower_ast.cpp | 4 +- tests/cpp/ir/frontend_type_inference_test.cpp | 11 +- tests/python/test_argpack.py | 6 + 14 files changed, 356 insertions(+), 272 deletions(-) diff --git a/docs/lang/articles/advanced/argument_pack.md b/docs/lang/articles/advanced/argument_pack.md index ee99ba77900e9..01e9790f60247 100644 --- a/docs/lang/articles/advanced/argument_pack.md +++ b/docs/lang/articles/advanced/argument_pack.md @@ -38,12 +38,12 @@ view_params = view_params_tmpl( Once argument packs are created and initialized, they can be easily used as kernel parameters. Simply pass them to the kernel, and Taichi will intelligently cache them (if their values remain unchanged) across multiple kernel calls, optimizing performance. ```python cont -@ti.kernel -def p(view_params: view_params_tmpl) -> ti.f32: - return view_params.far - - -print(p(view_params)) # 1.0 +# @ti.kernel +# def p(view_params: view_params_tmpl) -> ti.f32: +# return view_params.far +# +# +# print(p(view_params)) # 1.0 ``` ## Limitations diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index e29a98df4f93f..088efbc70ecfe 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -652,6 +652,7 @@ def transform_as_kernel(): for i, arg in enumerate(args.args): if isinstance(ctx.func.arguments[i].annotation, ArgPackType): d = {} + kernel_arguments.push_argpack_arg(ctx.func.arguments[i].name) for j, (name, anno) in enumerate(ctx.func.arguments[i].annotation.members.items()): d[name] = decl_and_create_variable(anno, name, ctx.arg_features[i][j]) ctx.create_variable(arg.arg, kernel_arguments.decl_argpack_arg(ctx.func.arguments[i].annotation, d)) diff --git a/python/taichi/lang/kernel_arguments.py b/python/taichi/lang/kernel_arguments.py index 67434788d6348..9626cf79ba5ee 100644 --- a/python/taichi/lang/kernel_arguments.py +++ b/python/taichi/lang/kernel_arguments.py @@ -94,7 +94,12 @@ def decl_struct_arg(structtype, name): return structtype.from_taichi_object(arg_load) +def push_argpack_arg(name): + impl.get_runtime().compiling_callable.insert_argpack_param_and_push(name) + + def decl_argpack_arg(argpacktype, member_dict): + impl.get_runtime().compiling_callable.pop_argpack_stack() return argpacktype.from_taichi_object(member_dict) diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 0795989949fbd..271630cd6feda 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -678,202 +678,188 @@ def launch_kernel(self, t_kernel, *args): max_arg_num = 64 exceed_max_arg_num = False for i, val in enumerate(args): - _needed = self.arguments[i].annotation - if isinstance(_needed, template): + needed = self.arguments[i].annotation + if isinstance(needed, template): continue - needed_list, provided_list = [], [] - - def flatten_argpack(argpack, argpack_type): - for j, (name, anno) in enumerate(argpack_type.members.items()): - if isinstance(anno, ArgPackType): - flatten_argpack(argpack[name], anno) - else: - needed_list.append(anno) - provided_list.append(argpack[name]) - - if isinstance(_needed, ArgPackType) and isinstance(val, ArgPack): - flatten_argpack(val, _needed) - else: - needed_list, provided_list = [_needed], [val] - - for j, _v in enumerate(needed_list): - needed, provided, v = _v, type(provided_list[j]), provided_list[j] - if actual_argument_slot >= max_arg_num: - exceed_max_arg_num = True - break - # Note: do not use sth like "needed == f32". That would be slow. - if id(needed) in primitive_types.real_type_ids: - if not isinstance(v, (float, int, np.floating, np.integer)): - raise TaichiRuntimeTypeError.get(i, needed.to_string(), provided) - launch_ctx.set_arg_float(actual_argument_slot, float(v)) - elif id(needed) in primitive_types.integer_type_ids: - if not isinstance(v, (int, np.integer)): - raise TaichiRuntimeTypeError.get(i, needed.to_string(), provided) - if is_signed(cook_dtype(needed)): - launch_ctx.set_arg_int(actual_argument_slot, int(v)) - else: - launch_ctx.set_arg_uint(actual_argument_slot, int(v)) - elif isinstance(needed, sparse_matrix_builder): - # Pass only the base pointer of the ti.types.sparse_matrix_builder() argument - launch_ctx.set_arg_uint(actual_argument_slot, v._get_ndarray_addr()) - elif isinstance(needed, ndarray_type.NdarrayType) and isinstance(v, taichi.lang._ndarray.Ndarray): - v_primal = v.arr - v_grad = v.grad.arr if v.grad else None - if v_grad is None: - launch_ctx.set_arg_ndarray(actual_argument_slot, v_primal) - else: - launch_ctx.set_arg_ndarray_with_grad(actual_argument_slot, v_primal, v_grad) - elif isinstance(needed, texture_type.TextureType) and isinstance(v, taichi.lang._texture.Texture): - launch_ctx.set_arg_texture(actual_argument_slot, v.tex) - elif isinstance(needed, texture_type.RWTextureType) and isinstance(v, taichi.lang._texture.Texture): - launch_ctx.set_arg_rw_texture(actual_argument_slot, v.tex) - elif isinstance(needed, ndarray_type.NdarrayType): - # Element shapes are already specialized in Taichi codegen. - # The shape information for element dims are no longer needed. - # Therefore we strip the element shapes from the shape vector, - # so that it only holds "real" array shapes. - is_soa = needed.layout == Layout.SOA - array_shape = v.shape - if functools.reduce(operator.mul, array_shape, 1) > np.iinfo(np.int32).max: - warnings.warn( - "Ndarray index might be out of int32 boundary but int64 indexing is not supported yet." + provided, v = type(val), val + if actual_argument_slot >= max_arg_num: + exceed_max_arg_num = True + break + # Note: do not use sth like "needed == f32". That would be slow. + if id(needed) in primitive_types.real_type_ids: + if not isinstance(v, (float, int, np.floating, np.integer)): + raise TaichiRuntimeTypeError.get(i, needed.to_string(), provided) + launch_ctx.set_arg_float(actual_argument_slot, float(v)) + elif id(needed) in primitive_types.integer_type_ids: + if not isinstance(v, (int, np.integer)): + raise TaichiRuntimeTypeError.get(i, needed.to_string(), provided) + if is_signed(cook_dtype(needed)): + launch_ctx.set_arg_int(actual_argument_slot, int(v)) + else: + launch_ctx.set_arg_uint(actual_argument_slot, int(v)) + elif isinstance(needed, sparse_matrix_builder): + # Pass only the base pointer of the ti.types.sparse_matrix_builder() argument + launch_ctx.set_arg_uint(actual_argument_slot, v._get_ndarray_addr()) + elif isinstance(needed, ndarray_type.NdarrayType) and isinstance(v, taichi.lang._ndarray.Ndarray): + v_primal = v.arr + v_grad = v.grad.arr if v.grad else None + if v_grad is None: + launch_ctx.set_arg_ndarray(actual_argument_slot, v_primal) + else: + launch_ctx.set_arg_ndarray_with_grad(actual_argument_slot, v_primal, v_grad) + elif isinstance(needed, texture_type.TextureType) and isinstance(v, taichi.lang._texture.Texture): + launch_ctx.set_arg_texture(actual_argument_slot, v.tex) + elif isinstance(needed, texture_type.RWTextureType) and isinstance(v, taichi.lang._texture.Texture): + launch_ctx.set_arg_rw_texture(actual_argument_slot, v.tex) + elif isinstance(needed, ndarray_type.NdarrayType): + # Element shapes are already specialized in Taichi codegen. + # The shape information for element dims are no longer needed. + # Therefore we strip the element shapes from the shape vector, + # so that it only holds "real" array shapes. + is_soa = needed.layout == Layout.SOA + array_shape = v.shape + if functools.reduce(operator.mul, array_shape, 1) > np.iinfo(np.int32).max: + warnings.warn( + "Ndarray index might be out of int32 boundary but int64 indexing is not supported yet." + ) + if needed.dtype is None or id(needed.dtype) in primitive_types.type_ids: + element_dim = 0 + else: + element_dim = needed.dtype.ndim + array_shape = v.shape[element_dim:] if is_soa else v.shape[:-element_dim] + if isinstance(v, np.ndarray): + if v.flags.c_contiguous: + launch_ctx.set_arg_external_array_with_shape( + actual_argument_slot, int(v.ctypes.data), v.nbytes, array_shape, 0 + ) + elif v.flags.f_contiguous: + # TODO: A better way that avoids copying is saving strides info. + tmp = np.ascontiguousarray(v) + # Purpose: DO NOT GC |tmp|! + tmps.append(tmp) + + def callback(original, updated): + np.copyto(original, np.asfortranarray(updated)) + + callbacks.append(functools.partial(callback, v, tmp)) + launch_ctx.set_arg_external_array_with_shape( + actual_argument_slot, int(tmp.ctypes.data), tmp.nbytes, array_shape, 0 ) - if needed.dtype is None or id(needed.dtype) in primitive_types.type_ids: - element_dim = 0 else: - element_dim = needed.dtype.ndim - array_shape = v.shape[element_dim:] if is_soa else v.shape[:-element_dim] - if isinstance(v, np.ndarray): - if v.flags.c_contiguous: - launch_ctx.set_arg_external_array_with_shape( - actual_argument_slot, int(v.ctypes.data), v.nbytes, array_shape, 0 - ) - elif v.flags.f_contiguous: - # TODO: A better way that avoids copying is saving strides info. - tmp = np.ascontiguousarray(v) - # Purpose: DO NOT GC |tmp|! - tmps.append(tmp) - - def callback(original, updated): - np.copyto(original, np.asfortranarray(updated)) - - callbacks.append(functools.partial(callback, v, tmp)) - launch_ctx.set_arg_external_array_with_shape( - actual_argument_slot, int(tmp.ctypes.data), tmp.nbytes, array_shape, 0 - ) - else: + raise ValueError( + "Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) " + "before passing it into taichi kernel." + ) + elif has_pytorch(): + import torch # pylint: disable=C0415 + + if isinstance(v, torch.Tensor): + if not v.is_contiguous(): raise ValueError( - "Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) " - "before passing it into taichi kernel." + "Non contiguous tensors are not supported, please call tensor.contiguous() before " + "passing it into taichi kernel." ) - elif has_pytorch(): - import torch # pylint: disable=C0415 - - if isinstance(v, torch.Tensor): - if not v.is_contiguous(): - raise ValueError( - "Non contiguous tensors are not supported, please call tensor.contiguous() before " - "passing it into taichi kernel." - ) - taichi_arch = self.runtime.prog.config().arch - - def get_call_back(u, v): - def call_back(): - u.copy_(v) - - return call_back - - # FIXME: only allocate when launching grad kernel - if v.requires_grad and v.grad is None: - v.grad = torch.zeros_like(v) - - tmp = v - if str(v.device).startswith("cuda") and taichi_arch != _ti_core.Arch.cuda: - # Getting a torch CUDA tensor on Taichi non-cuda arch: - # We just replace it with a CPU tensor and by the end of kernel execution we'll use the - # callback to copy the values back to the original CUDA tensor. - host_v = v.to(device="cpu", copy=True) - tmp = host_v + taichi_arch = self.runtime.prog.config().arch + + def get_call_back(u, v): + def call_back(): + u.copy_(v) + + return call_back + + # FIXME: only allocate when launching grad kernel + if v.requires_grad and v.grad is None: + v.grad = torch.zeros_like(v) + + tmp = v + if str(v.device).startswith("cuda") and taichi_arch != _ti_core.Arch.cuda: + # Getting a torch CUDA tensor on Taichi non-cuda arch: + # We just replace it with a CPU tensor and by the end of kernel execution we'll use the + # callback to copy the values back to the original CUDA tensor. + host_v = v.to(device="cpu", copy=True) + tmp = host_v + callbacks.append(get_call_back(v, host_v)) + + launch_ctx.set_arg_external_array_with_shape( + actual_argument_slot, + int(tmp.data_ptr()), + tmp.element_size() * tmp.nelement(), + array_shape, + int(v.grad.data_ptr()) if v.grad is not None else 0, + ) + else: + raise TaichiRuntimeTypeError.get(i, needed.to_string(), v) + elif has_paddle(): + import paddle # pylint: disable=C0415 + + if isinstance(v, paddle.Tensor): + # For now, paddle.fluid.core.Tensor._ptr() is only available on develop branch + def get_call_back(u, v): + def call_back(): + u.copy_(v, False) + + return call_back + + tmp = v.value().get_tensor() + taichi_arch = self.runtime.prog.config().arch + if v.place.is_gpu_place(): + if taichi_arch != _ti_core.Arch.cuda: + # Paddle cuda tensor on Taichi non-cuda arch + host_v = v.cpu() + tmp = host_v.value().get_tensor() callbacks.append(get_call_back(v, host_v)) - - launch_ctx.set_arg_external_array_with_shape( - actual_argument_slot, - int(tmp.data_ptr()), - tmp.element_size() * tmp.nelement(), - array_shape, - int(v.grad.data_ptr()) if v.grad is not None else 0, - ) + elif v.place.is_cpu_place(): + if taichi_arch == _ti_core.Arch.cuda: + # Paddle cpu tensor on Taichi cuda arch + gpu_v = v.cuda() + tmp = gpu_v.value().get_tensor() + callbacks.append(get_call_back(v, gpu_v)) else: - raise TaichiRuntimeTypeError.get(i, needed.to_string(), v) - elif has_paddle(): - import paddle # pylint: disable=C0415 - - if isinstance(v, paddle.Tensor): - # For now, paddle.fluid.core.Tensor._ptr() is only available on develop branch - def get_call_back(u, v): - def call_back(): - u.copy_(v, False) - - return call_back - - tmp = v.value().get_tensor() - taichi_arch = self.runtime.prog.config().arch - if v.place.is_gpu_place(): - if taichi_arch != _ti_core.Arch.cuda: - # Paddle cuda tensor on Taichi non-cuda arch - host_v = v.cpu() - tmp = host_v.value().get_tensor() - callbacks.append(get_call_back(v, host_v)) - elif v.place.is_cpu_place(): - if taichi_arch == _ti_core.Arch.cuda: - # Paddle cpu tensor on Taichi cuda arch - gpu_v = v.cuda() - tmp = gpu_v.value().get_tensor() - callbacks.append(get_call_back(v, gpu_v)) - else: - # Paddle do support many other backends like XPU, NPU, MLU, IPU - raise TaichiRuntimeTypeError( - f"Taichi do not support backend {v.place} that Paddle support" - ) - launch_ctx.set_arg_external_array_with_shape( - actual_argument_slot, int(tmp._ptr()), v.element_size() * v.size, array_shape, 0 - ) - else: - raise TaichiRuntimeTypeError.get(i, needed.to_string(), v) + # Paddle do support many other backends like XPU, NPU, MLU, IPU + raise TaichiRuntimeTypeError(f"Taichi do not support backend {v.place} that Paddle support") + launch_ctx.set_arg_external_array_with_shape( + actual_argument_slot, int(tmp._ptr()), v.element_size() * v.size, array_shape, 0 + ) else: raise TaichiRuntimeTypeError.get(i, needed.to_string(), v) + else: + raise TaichiRuntimeTypeError.get(i, needed.to_string(), v) - elif isinstance(needed, MatrixType): - if needed.dtype in primitive_types.real_types: + elif isinstance(needed, MatrixType): + if needed.dtype in primitive_types.real_types: - def cast_func(x): - if not isinstance(x, (int, float, np.integer, np.floating)): - raise TaichiRuntimeTypeError.get(i, needed.dtype.to_string(), type(x)) - return float(x) + def cast_func(x): + if not isinstance(x, (int, float, np.integer, np.floating)): + raise TaichiRuntimeTypeError.get(i, needed.dtype.to_string(), type(x)) + return float(x) - elif needed.dtype in primitive_types.integer_types: + elif needed.dtype in primitive_types.integer_types: - def cast_func(x): - if not isinstance(x, (int, np.integer)): - raise TaichiRuntimeTypeError.get(i, needed.dtype.to_string(), type(x)) - return int(x) + def cast_func(x): + if not isinstance(x, (int, np.integer)): + raise TaichiRuntimeTypeError.get(i, needed.dtype.to_string(), type(x)) + return int(x) - else: - raise ValueError(f"Matrix dtype {needed.dtype} is not integer type or real type.") + else: + raise ValueError(f"Matrix dtype {needed.dtype} is not integer type or real type.") - if needed.ndim == 2: - v = [cast_func(v[i, j]) for i in range(needed.n) for j in range(needed.m)] - else: - v = [cast_func(v[i]) for i in range(needed.n)] - v = needed(*v) - needed.set_kernel_struct_args(v, launch_ctx, (actual_argument_slot,)) - elif isinstance(needed, StructType): - if not isinstance(v, needed): - raise TaichiRuntimeTypeError.get(i, str(needed), provided) - needed.set_kernel_struct_args(v, launch_ctx, (actual_argument_slot,)) + if needed.ndim == 2: + v = [cast_func(v[i, j]) for i in range(needed.n) for j in range(needed.m)] else: - raise ValueError(f"Argument type mismatch. Expecting {needed}, got {type(v)}.") - actual_argument_slot += 1 + v = [cast_func(v[i]) for i in range(needed.n)] + v = needed(*v) + needed.set_kernel_struct_args(v, launch_ctx, (actual_argument_slot,)) + elif isinstance(needed, StructType): + if not isinstance(v, needed): + raise TaichiRuntimeTypeError.get(i, str(needed), provided) + needed.set_kernel_struct_args(v, launch_ctx, (actual_argument_slot,)) + elif isinstance(needed, ArgPackType): + if not isinstance(v, needed): + raise TaichiRuntimeTypeError.get(i, str(needed), provided) + raise NotImplementedError + else: + raise ValueError(f"Argument type mismatch. Expecting {needed}, got {type(v)}.") + actual_argument_slot += 1 if exceed_max_arg_num: raise TaichiRuntimeError( diff --git a/taichi/ir/expression_printer.h b/taichi/ir/expression_printer.h index 4908aa4f593d3..7cc8ac8069941 100644 --- a/taichi/ir/expression_printer.h +++ b/taichi/ir/expression_printer.h @@ -37,11 +37,11 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter { void visit(ArgLoadExpression *expr) override { emit(fmt::format("arg{}[{}] (dt={})", expr->create_load ? "load" : "addr", - expr->arg_id, data_type_name(expr->dt))); + fmt::join(expr->arg_id, ", "), data_type_name(expr->dt))); } void visit(TexturePtrExpression *expr) override { - emit(fmt::format("(Texture *)(arg[{}])", expr->arg_id)); + emit(fmt::format("(Texture *)(arg[{}])", fmt::join(expr->arg_id, ", "))); } void visit(TextureOpExpression *expr) override { diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 7ff03e2a0041e..190906e180a6a 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -152,7 +152,7 @@ void ArgLoadExpression::type_check(const CompileConfig *) { void ArgLoadExpression::flatten(FlattenContext *ctx) { auto arg_load = - std::make_unique(arg_id, dt, is_ptr, create_load); + std::make_unique(arg_id[0], dt, is_ptr, create_load); arg_load->ret_type = ret_type; ctx->push_back(std::move(arg_load)); stmt = ctx->back_stmt(); @@ -162,7 +162,7 @@ void TexturePtrExpression::type_check(const CompileConfig *config) { } void TexturePtrExpression::flatten(FlattenContext *ctx) { - ctx->push_back(arg_id, PrimitiveType::f32, /*is_ptr=*/true, + ctx->push_back(arg_id[0], PrimitiveType::f32, /*is_ptr=*/true, /*create_load*/ true); ctx->push_back(ctx->back_stmt(), num_dims, is_storage, format, lod); @@ -610,7 +610,7 @@ void ExternalTensorExpression::flatten(FlattenContext *ctx) { TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim, needs_grad); type = TypeFactory::get_instance().get_pointer_type((Type *)type); - auto ptr = Stmt::make(arg_id, type, /*is_ptr=*/true, + auto ptr = Stmt::make(arg_id[0], type, /*is_ptr=*/true, /*create_load=*/false); ptr->tb = tb; @@ -1231,7 +1231,7 @@ void ExternalTensorShapeAlongAxisExpression::type_check(const CompileConfig *) { void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) { auto temp = ptr.cast(); TI_ASSERT(0 <= axis && axis < temp->ndim); - ctx->push_back(axis, temp->arg_id); + ctx->push_back(axis, temp->arg_id[0]); stmt = ctx->back_stmt(); } @@ -1245,7 +1245,7 @@ void ExternalTensorBasePtrExpression::type_check(const CompileConfig *) { void ExternalTensorBasePtrExpression::flatten(FlattenContext *ctx) { auto tensor = ptr.cast(); - ctx->push_back(tensor->arg_id, is_grad); + ctx->push_back(tensor->arg_id[0], is_grad); stmt = ctx->back_stmt(); stmt->ret_type = ret_type; } diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 07b7bfa0e7c1d..cec4b7a5a3c56 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -314,7 +314,7 @@ class FrontendReturnStmt : public Stmt { class ArgLoadExpression : public Expression { public: - int arg_id; + const std::vector arg_id; DataType dt; bool is_ptr; @@ -324,7 +324,7 @@ class ArgLoadExpression : public Expression { */ bool create_load; - ArgLoadExpression(int arg_id, + ArgLoadExpression(const std::vector &arg_id, DataType dt, bool is_ptr = false, bool create_load = true) @@ -346,7 +346,7 @@ class Texture; class TexturePtrExpression : public Expression { public: - int arg_id; + const std::vector arg_id; int num_dims; bool is_storage{false}; @@ -354,7 +354,7 @@ class TexturePtrExpression : public Expression { BufferFormat format{BufferFormat::unknown}; int lod{0}; - explicit TexturePtrExpression(int arg_id, int num_dims) + explicit TexturePtrExpression(const std::vector &arg_id, int num_dims) : arg_id(arg_id), num_dims(num_dims), is_storage(false), @@ -362,7 +362,10 @@ class TexturePtrExpression : public Expression { lod(0) { } - TexturePtrExpression(int arg_id, int num_dims, BufferFormat format, int lod) + TexturePtrExpression(const std::vector &arg_id, + int num_dims, + BufferFormat format, + int lod) : arg_id(arg_id), num_dims(num_dims), is_storage(true), @@ -474,14 +477,14 @@ class ExternalTensorExpression : public Expression { public: DataType dt; int ndim; - int arg_id; + std::vector arg_id; bool needs_grad{false}; bool is_grad{false}; BoundaryMode boundary{BoundaryMode::kUnsafe}; ExternalTensorExpression(const DataType &dt, int ndim, - int arg_id, + const std::vector &arg_id, bool needs_grad = false, BoundaryMode boundary = BoundaryMode::kUnsafe) { init(dt, ndim, arg_id, needs_grad, boundary); @@ -513,7 +516,7 @@ class ExternalTensorExpression : public Expression { void init(const DataType &dt, int ndim, - int arg_id, + const std::vector &arg_id, bool needs_grad, BoundaryMode boundary) { this->dt = dt; diff --git a/taichi/program/callable.cpp b/taichi/program/callable.cpp index 716f68d200c41..9c9d0ad97331b 100644 --- a/taichi/program/callable.cpp +++ b/taichi/program/callable.cpp @@ -8,11 +8,12 @@ Callable::Callable() = default; Callable::~Callable() = default; -int Callable::insert_scalar_param(const DataType &dt, const std::string &name) { - parameter_list.emplace_back(dt->get_compute_type(), /*is_array=*/false); - parameter_list.back().name = name; - parameter_list.back().ptype = ParameterType::kScalar; - return (int)parameter_list.size() - 1; +std::vector Callable::insert_scalar_param(const DataType &dt, + const std::string &name) { + auto p = Parameter(dt->get_compute_type(), /*is_array=*/false); + p.name = name; + p.ptype = ParameterType::kScalar; + return add_parameter(p); } int Callable::insert_ret(const DataType &dt) { @@ -20,20 +21,20 @@ int Callable::insert_ret(const DataType &dt) { return (int)rets.size() - 1; } -int Callable::insert_arr_param(const DataType &dt, - int total_dim, - std::vector element_shape, - const std::string &name) { - parameter_list.emplace_back(dt->get_compute_type(), /*is_array=*/true, - /*size=*/0, total_dim, element_shape); - parameter_list.back().name = name; - return (int)parameter_list.size() - 1; +std::vector Callable::insert_arr_param(const DataType &dt, + int total_dim, + std::vector element_shape, + const std::string &name) { + auto p = Parameter(dt->get_compute_type(), /*is_array=*/true, 0, total_dim, + element_shape); + p.name = name; + return add_parameter(p); } -int Callable::insert_ndarray_param(const DataType &dt, - int ndim, - const std::string &name, - bool needs_grad) { +std::vector Callable::insert_ndarray_param(const DataType &dt, + int ndim, + const std::string &name, + bool needs_grad) { // Transform ndarray param to a struct type with a pointer to `dt`. std::vector element_shape{}; auto dtype = dt; @@ -45,43 +46,95 @@ int Callable::insert_ndarray_param(const DataType &dt, // If we could avoid using parameter_list in codegen it'll be fine auto *type = TypeFactory::get_instance().get_ndarray_struct_type(dtype, ndim, needs_grad); - parameter_list.emplace_back(type, /*is_array=*/true, - /*size=*/0, ndim + element_shape.size(), - element_shape, BufferFormat::unknown, needs_grad); - parameter_list.back().name = name; - parameter_list.back().ptype = ParameterType::kNdarray; - return (int)parameter_list.size() - 1; + auto p = Parameter(type, /*is_array=*/true, 0, ndim + element_shape.size(), + element_shape, BufferFormat::unknown, needs_grad); + p.name = name; + p.ptype = ParameterType::kNdarray; + return add_parameter(p); } -int Callable::insert_texture_param(int total_dim, const std::string &name) { +std::vector Callable::insert_texture_param(int total_dim, + const std::string &name) { // FIXME: we shouldn't abuse is_array for texture parameters // FIXME: using rwtexture struct type for texture parameters because C-API // does not distinguish between texture and rwtexture. auto *type = TypeFactory::get_instance().get_rwtexture_struct_type(); - parameter_list.emplace_back(type, /*is_array=*/true, 0, total_dim, - std::vector{}); - parameter_list.back().name = name; - parameter_list.back().ptype = ParameterType::kTexture; - return (int)parameter_list.size() - 1; + auto p = Parameter(type, /*is_array=*/true, 0, total_dim, std::vector{}); + p.name = name; + p.ptype = ParameterType::kTexture; + return add_parameter(p); } -int Callable::insert_pointer_param(const DataType &dt, - const std::string &name) { - parameter_list.emplace_back(dt->get_compute_type(), /*is_array=*/true); - parameter_list.back().name = name; - return (int)parameter_list.size() - 1; +std::vector Callable::insert_pointer_param(const DataType &dt, + const std::string &name) { + auto p = Parameter(dt->get_compute_type(), /*is_array=*/true); + p.name = name; + return add_parameter(p); } -int Callable::insert_rw_texture_param(int total_dim, - BufferFormat format, - const std::string &name) { +std::vector Callable::insert_rw_texture_param(int total_dim, + BufferFormat format, + const std::string &name) { // FIXME: we shouldn't abuse is_array for texture parameters auto *type = TypeFactory::get_instance().get_rwtexture_struct_type(); - parameter_list.emplace_back(type, /*is_array=*/true, 0, total_dim, - std::vector{}, format); - parameter_list.back().name = name; - parameter_list.back().ptype = ParameterType::kRWTexture; - return (int)parameter_list.size() - 1; + auto p = Parameter(type, /*is_array=*/true, 0, total_dim, std::vector{}, + format); + p.name = name; + p.ptype = ParameterType::kRWTexture; + return add_parameter(p); +} + +std::vector Callable::insert_argpack_param_and_push( + const std::string &name) { + TI_ASSERT(temp_argpack_stack_.size() == temp_indices_stack_.size() && + temp_argpack_name_stack_.size() == temp_indices_stack_.size()); + if (temp_argpack_stack_.size() > 0) { + temp_indices_stack_.push_back(temp_argpack_stack_.top().size()); + } else { + temp_indices_stack_.push_back(parameter_list.size()); + } + temp_argpack_stack_.push(std::vector()); + temp_argpack_name_stack_.push(name); + return temp_indices_stack_; +} + +void Callable::pop_argpack_stack() { + // Compile argpack members to a struct. + TI_ASSERT(temp_argpack_stack_.size() > 0 && temp_indices_stack_.size() > 0 && + temp_argpack_name_stack_.size() > 0); + std::vector argpack_params = temp_argpack_stack_.top(); + std::vector members; + members.reserve(argpack_params.size()); + for (int i = 0; i < argpack_params.size(); i++) { + auto ¶m = argpack_params[i]; + members.push_back( + {param.is_array && !param.get_dtype()->is() + ? TypeFactory::get_instance().get_pointer_type(param.get_dtype()) + : (const Type *)param.get_dtype(), + fmt::format("arg_{}_{}", fmt::join(temp_indices_stack_, "_"), i)}); + } + auto *type = + TypeFactory::get_instance().get_struct_type(members)->as(); + auto p = Parameter(DataType(type), false); + p.name = temp_argpack_name_stack_.top(); + add_parameter(p); + // Pop stacks + temp_argpack_stack_.pop(); + temp_indices_stack_.pop_back(); + temp_argpack_name_stack_.pop(); +} + +std::vector Callable::add_parameter(const Parameter ¶m) { + TI_ASSERT(temp_argpack_stack_.size() == temp_indices_stack_.size() && + temp_argpack_name_stack_.size() == temp_indices_stack_.size()); + if (temp_argpack_stack_.size() == 0) { + parameter_list.push_back(param); + return std::vector{(int)parameter_list.size() - 1}; + } + temp_argpack_stack_.top().push_back(param); + std::vector ret = temp_indices_stack_; + ret.push_back(temp_argpack_stack_.top().size() - 1); + return ret; } void Callable::finalize_rets() { @@ -98,6 +151,9 @@ void Callable::finalize_rets() { } void Callable::finalize_params() { + TI_ASSERT(temp_argpack_stack_.size() == 0 && + temp_indices_stack_.size() == 0 && + temp_argpack_name_stack_.size() == 0); std::vector members; members.reserve(parameter_list.size()); for (int i = 0; i < parameter_list.size(); i++) { diff --git a/taichi/program/callable.h b/taichi/program/callable.h index 8b886e1b8ee24..edede8418d166 100644 --- a/taichi/program/callable.h +++ b/taichi/program/callable.h @@ -3,6 +3,8 @@ #include "taichi/rhi/device.h" #include "taichi/util/lang_util.h" +#include + namespace taichi::lang { class Program; @@ -125,20 +127,27 @@ class TI_DLL_EXPORT Callable : public CallableBase { Callable(); virtual ~Callable(); - int insert_scalar_param(const DataType &dt, const std::string &name = ""); - int insert_arr_param(const DataType &dt, - int total_dim, - std::vector element_shape, - const std::string &name = ""); - int insert_ndarray_param(const DataType &dt, - int ndim, - const std::string &name = "", - bool needs_grad = false); - int insert_texture_param(int total_dim, const std::string &name = ""); - int insert_pointer_param(const DataType &dt, const std::string &name = ""); - int insert_rw_texture_param(int total_dim, - BufferFormat format, - const std::string &name = ""); + std::vector insert_scalar_param(const DataType &dt, + const std::string &name = ""); + std::vector insert_arr_param(const DataType &dt, + int total_dim, + std::vector element_shape, + const std::string &name = ""); + std::vector insert_ndarray_param(const DataType &dt, + int ndim, + const std::string &name = "", + bool needs_grad = false); + std::vector insert_texture_param(int total_dim, + const std::string &name = ""); + std::vector insert_pointer_param(const DataType &dt, + const std::string &name = ""); + std::vector insert_rw_texture_param(int total_dim, + BufferFormat format, + const std::string &name = ""); + + std::vector insert_argpack_param_and_push(const std::string &name = ""); + + void pop_argpack_stack(); int insert_ret(const DataType &dt); @@ -147,6 +156,14 @@ class TI_DLL_EXPORT Callable : public CallableBase { void finalize_params(); [[nodiscard]] virtual std::string get_name() const = 0; + + private: + std::vector add_parameter(const Parameter ¶m); + // Note: These stacks are used for inserting params inside argpacks. When + // we call finalize_params(), all of them are required to be empty then. + std::stack> temp_argpack_stack_; + std::vector temp_indices_stack_; + std::stack temp_argpack_name_stack_; }; } // namespace taichi::lang diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index f21579baf20b7..4bef97a845608 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -276,7 +276,8 @@ Kernel &Program::get_snode_reader(SNode *snode) { auto &ker = kernel([snode, this](Kernel *kernel) { ExprGroup indices; for (int i = 0; i < snode->num_active_indices; i++) { - auto argload_expr = Expr::make(i, PrimitiveType::i32); + auto argload_expr = Expr::make(std::vector{i}, + PrimitiveType::i32); argload_expr->type_check(&this->compile_config()); indices.push_back(std::move(argload_expr)); } @@ -301,7 +302,8 @@ Kernel &Program::get_snode_writer(SNode *snode) { auto &ker = kernel([snode, this](Kernel *kernel) { ExprGroup indices; for (int i = 0; i < snode->num_active_indices; i++) { - auto argload_expr = Expr::make(i, PrimitiveType::i32); + auto argload_expr = Expr::make(std::vector{i}, + PrimitiveType::i32); argload_expr->type_check(&this->compile_config()); indices.push_back(std::move(argload_expr)); } @@ -309,7 +311,8 @@ Kernel &Program::get_snode_writer(SNode *snode) { auto expr = builder.expr_subscript(Expr(snode_to_fields_.at(snode)), indices); auto argload_expr = Expr::make( - snode->num_active_indices, snode->dt->get_compute_type()); + std::vector{snode->num_active_indices}, + snode->dt->get_compute_type()); argload_expr->type_check(&this->compile_config()); builder.insert_assignment(expr, argload_expr, expr->tb); }); diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index ba26a566cef0e..9123cf9e2b328 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -735,6 +735,9 @@ void export_lang(py::module &m) { .def("insert_texture_param", &Kernel::insert_texture_param) .def("insert_pointer_param", &Kernel::insert_pointer_param) .def("insert_rw_texture_param", &Kernel::insert_rw_texture_param) + .def("insert_argpack_param_and_push", + &Kernel::insert_argpack_param_and_push) + .def("pop_argpack_stack", &Kernel::pop_argpack_stack) .def("insert_ret", &Kernel::insert_ret) .def("finalize_rets", &Kernel::finalize_rets) .def("finalize_params", &Kernel::finalize_params) @@ -981,14 +984,15 @@ void export_lang(py::module &m) { Stmt::make); m.def("make_arg_load_expr", - Expr::make, + Expr::make &, + const DataType &, bool, bool>, "arg_id"_a, "dt"_a, "is_ptr"_a = false, "create_load"_a = true); m.def("make_reference", Expr::make); m.def("make_external_tensor_expr", - Expr::make); + Expr::make &, bool, const BoundaryMode &>); m.def("make_external_tensor_grad_expr", Expr::make); @@ -1004,9 +1008,11 @@ void export_lang(py::module &m) { m.def("make_const_expr_fp", Expr::make); - m.def("make_texture_ptr_expr", Expr::make); + m.def("make_texture_ptr_expr", + Expr::make &, int>); m.def("make_rw_texture_ptr_expr", - Expr::make); + Expr::make &, int, + const BufferFormat &, int>); auto &&texture = py::enum_(m, "TextureOpType", py::arithmetic()); diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp index ea2d13a96a444..4491fb67d018e 100644 --- a/taichi/transforms/lower_ast.cpp +++ b/taichi/transforms/lower_ast.cpp @@ -253,14 +253,14 @@ class LowerAST : public IRVisitor { std::vector shape; if (stmt->external_tensor.is()) { auto tensor = stmt->external_tensor.cast(); - arg_id = tensor->arg_id; + arg_id = tensor->arg_id[0]; for (int i = 0; i < tensor->ndim; i++) { shape.push_back( fctx.push_back(i, arg_id)); } } else if (stmt->external_tensor.is()) { auto rw_texture = stmt->external_tensor.cast(); - arg_id = rw_texture->arg_id; + arg_id = rw_texture->arg_id[0]; for (size_t i = 0; i < rw_texture->num_dims; ++i) { shape.emplace_back( fctx.push_back(i, arg_id)); diff --git a/tests/cpp/ir/frontend_type_inference_test.cpp b/tests/cpp/ir/frontend_type_inference_test.cpp index dd21a5ec7b92c..f9f791e37dbe3 100644 --- a/tests/cpp/ir/frontend_type_inference_test.cpp +++ b/tests/cpp/ir/frontend_type_inference_test.cpp @@ -14,7 +14,8 @@ TEST(FrontendTypeInference, Const) { } TEST(FrontendTypeInference, ArgLoad) { - auto arg_load_u64 = Expr::make(2, PrimitiveType::u64); + auto arg_load_u64 = + Expr::make(std::vector{2}, PrimitiveType::u64); arg_load_u64->type_check(nullptr); EXPECT_EQ(arg_load_u64->ret_type, PrimitiveType::u64); } @@ -154,8 +155,8 @@ TEST(FrontendTypeInference, GlobalPtr_ExternalTensor) { auto index = value(2); index->type_check(nullptr); - auto external_tensor = - Expr::make(PrimitiveType::u16, 1, 0, 0); + auto external_tensor = Expr::make( + PrimitiveType::u16, 1, std::vector{0}, 0); auto global_ptr = ast_builder->expr_subscript(external_tensor, ExprGroup(index)); EXPECT_THROW(global_ptr->type_check(nullptr), TaichiTypeError); @@ -207,8 +208,8 @@ TEST(FrontendTypeInference, SNodeOp) { } TEST(FrontendTypeInference, ExternalTensorShapeAlongAxis) { - auto external_tensor = - Expr::make(PrimitiveType::u64, 1, 0, 0); + auto external_tensor = Expr::make( + PrimitiveType::u64, 1, std::vector{0}, 0); auto shape = Expr::make(external_tensor, 0); shape->type_check(nullptr); diff --git a/tests/python/test_argpack.py b/tests/python/test_argpack.py index 1c89f30988863..79ab7ff12eced 100644 --- a/tests/python/test_argpack.py +++ b/tests/python/test_argpack.py @@ -4,6 +4,7 @@ from tests import test_utils +@pytest.mark.skip(reason="Temporarily disabled argpack functionalities") @test_utils.test() def test_argpack_basic(): pack_type = ti.types.argpack(a=ti.i32, b=bool, c=ti.f32) @@ -23,6 +24,7 @@ def foo(pack: pack_type) -> ti.f32: assert foo(pack2) == test_utils.approx(2 + 2.1, rel=1e-3) +@pytest.mark.skip(reason="Temporarily disabled argpack functionalities") @test_utils.test() def test_argpack_multiple(): arr = ti.ndarray(dtype=ti.math.vec3, shape=(4, 4)) @@ -41,6 +43,7 @@ def foo(p1: pack_type1, p2: pack_type2) -> ti.f32: assert foo(pack1, pack2) == test_utils.approx(1 * 2.1 + 2.0, rel=1e-3) +@pytest.mark.skip(reason="Temporarily disabled argpack functionalities") @test_utils.test() def test_argpack_nested(): arr = ti.ndarray(dtype=ti.math.vec3, shape=(4, 4)) @@ -68,6 +71,7 @@ def h(x: pack_type) -> int: assert h(pack) == 233 +@pytest.mark.skip(reason="Temporarily disabled argpack functionalities") @test_utils.test() def test_argpack_as_return(): pack_type = ti.types.argpack(a=ti.i32, b=bool) @@ -81,6 +85,7 @@ def foo(pack: pack_type) -> pack_type: foo() +@pytest.mark.skip(reason="Temporarily disabled argpack functionalities") @test_utils.test() def test_argpack_as_struct_type_element(): with pytest.raises(ValueError, match="Invalid data type "): @@ -89,6 +94,7 @@ def test_argpack_as_struct_type_element(): print(struct_with_argpack_inside) +@pytest.mark.skip(reason="Temporarily disabled argpack functionalities") @test_utils.test() def test_argpack_with_ndarray(): arr = ti.ndarray(dtype=ti.math.vec3, shape=(4, 4))