diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py
index 6e5ab0f69eda93..7d3514e57a29e1 100644
--- a/python/taichi/lang/ast/ast_transformer.py
+++ b/python/taichi/lang/ast/ast_transformer.py
@@ -649,6 +649,7 @@ def transform_as_kernel():
         for i, arg in enumerate(args.args):
             if isinstance(ctx.func.arguments[i].annotation, ArgPackType):
                 d = {}
+                kernel_arguments.push_argpack_arg(ctx.func.arguments[i].name)
                 for j, (name, anno) in enumerate(ctx.func.arguments[i].annotation.members.items()):
                     d[name] = decl_and_create_variable(anno, name, ctx.arg_features[i][j])
                 ctx.create_variable(arg.arg, kernel_arguments.decl_argpack_arg(ctx.func.arguments[i].annotation, d))
diff --git a/python/taichi/lang/kernel_arguments.py b/python/taichi/lang/kernel_arguments.py
index 94f237622a7255..8afef16afe05ae 100644
--- a/python/taichi/lang/kernel_arguments.py
+++ b/python/taichi/lang/kernel_arguments.py
@@ -56,7 +56,7 @@ def decl_scalar_arg(dtype, name):
         arg_id = impl.get_runtime().compiling_callable.insert_pointer_param(dtype, name)
     else:
         arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(dtype, name)
-    return Expr(_ti_core.make_arg_load_expr(arg_id, dtype, is_ref))
+    return Expr(_ti_core.make_arg_load_expr((arg_id,), dtype, is_ref))
 
 
 def get_type_for_kernel_args(dtype, name):
@@ -83,18 +83,23 @@ def get_type_for_kernel_args(dtype, name):
 def decl_matrix_arg(matrixtype, name):
     arg_type = get_type_for_kernel_args(matrixtype, name)
     arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(arg_type, name)
-    arg_load = Expr(_ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False))
+    arg_load = Expr(_ti_core.make_arg_load_expr((arg_id,), arg_type, create_load=False))
     return matrixtype.from_taichi_object(arg_load)
 
 
 def decl_struct_arg(structtype, name):
     arg_type = get_type_for_kernel_args(structtype, name)
     arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(arg_type, name)
-    arg_load = Expr(_ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False))
+    arg_load = Expr(_ti_core.make_arg_load_expr((arg_id,), arg_type, create_load=False))
     return structtype.from_taichi_object(arg_load)
 
 
+def push_argpack_arg(name):
+    impl.get_runtime().compiling_callable.insert_argpack_param_and_push(name)
+
+
 def decl_argpack_arg(argpacktype, member_dict):
+    impl.get_runtime().compiling_callable.pop_argpack_stack()
     return argpacktype.from_taichi_object(member_dict)
 
 
@@ -103,25 +108,25 @@ def decl_sparse_matrix(dtype, name):
     ptr_type = cook_dtype(u64)
     # Treat the sparse matrix argument as a scalar since we only need to pass in the base pointer
     arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(ptr_type, name)
-    return SparseMatrixProxy(_ti_core.make_arg_load_expr(arg_id, ptr_type, False), value_type)
+    return SparseMatrixProxy(_ti_core.make_arg_load_expr((arg_id,), ptr_type, False), value_type)
 
 
 def decl_ndarray_arg(element_type, ndim, name, needs_grad, boundary):
     arg_id = impl.get_runtime().compiling_callable.insert_ndarray_param(element_type, ndim, name, needs_grad)
-    return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, arg_id, needs_grad, boundary))
+    return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, (arg_id,), needs_grad, boundary))
 
 
 def decl_texture_arg(num_dimensions, name):
     # FIXME: texture_arg doesn't have element_shape so better separate them
     arg_id = impl.get_runtime().compiling_callable.insert_texture_param(num_dimensions, name)
-    return TextureSampler(_ti_core.make_texture_ptr_expr(arg_id, num_dimensions), num_dimensions)
+    return TextureSampler(_ti_core.make_texture_ptr_expr((arg_id,), num_dimensions), num_dimensions)
 
 
 def decl_rw_texture_arg(num_dimensions, buffer_format, lod, name):
     # FIXME: texture_arg doesn't have element_shape so better separate them
     arg_id = impl.get_runtime().compiling_callable.insert_rw_texture_param(num_dimensions, buffer_format, name)
     return RWTextureAccessor(
-        _ti_core.make_rw_texture_ptr_expr(arg_id, num_dimensions, buffer_format, lod),
+        _ti_core.make_rw_texture_ptr_expr((arg_id,), num_dimensions, buffer_format, lod),
         num_dimensions,
     )
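Taken together, the two hooks above give argpack declaration a bracketed protocol: `push_argpack_arg` opens a pack scope before its members are declared, and `decl_argpack_arg` closes it via `pop_argpack_stack`. For orientation, a minimal user-level sketch of the construct this serves; the constructor call form is an assumption based on the tests at the end of this diff, which are temporarily skipped in this revision:

```python
import taichi as ti

ti.init()

# Mirrors the skipped tests below; members are declared in order.
pack_type = ti.types.argpack(a=ti.i32, c=ti.f32)

@ti.kernel
def foo(pack: pack_type) -> ti.f32:
    # Compiling this kernel calls push_argpack_arg("pack"), declares each
    # member, then decl_argpack_arg(...) pops the argpack stack.
    return pack.a + pack.c

pack = pack_type(a=2, c=2.1)  # hypothetical instantiation, as in the tests
# foo(pack)  # launching still raises NotImplementedError in this revision
```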
diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py
index c6fe985cb34ee2..fc5e75df7df70f 100644
--- a/python/taichi/lang/kernel_impl.py
+++ b/python/taichi/lang/kernel_impl.py
@@ -633,202 +633,188 @@ def launch_kernel(self, t_kernel, *args):
         max_arg_num = 64
         exceed_max_arg_num = False
         for i, val in enumerate(args):
-            _needed = self.arguments[i].annotation
-            if isinstance(_needed, template):
+            needed = self.arguments[i].annotation
+            if isinstance(needed, template):
                 continue
-            needed_list, provided_list = [], []
-
-            def flatten_argpack(argpack, argpack_type):
-                for j, (name, anno) in enumerate(argpack_type.members.items()):
-                    if isinstance(anno, ArgPackType):
-                        flatten_argpack(argpack[name], anno)
-                    else:
-                        needed_list.append(anno)
-                        provided_list.append(argpack[name])
-
-            if isinstance(_needed, ArgPackType) and isinstance(val, ArgPack):
-                flatten_argpack(val, _needed)
-            else:
-                needed_list, provided_list = [_needed], [val]
-
-            for j, _v in enumerate(needed_list):
-                needed, provided, v = _v, type(provided_list[j]), provided_list[j]
-                if actual_argument_slot >= max_arg_num:
-                    exceed_max_arg_num = True
-                    break
-                # Note: do not use sth like "needed == f32". That would be slow.
-                if id(needed) in primitive_types.real_type_ids:
-                    if not isinstance(v, (float, int, np.floating, np.integer)):
-                        raise TaichiRuntimeTypeError.get(i, needed.to_string(), provided)
-                    launch_ctx.set_arg_float(actual_argument_slot, float(v))
-                elif id(needed) in primitive_types.integer_type_ids:
-                    if not isinstance(v, (int, np.integer)):
-                        raise TaichiRuntimeTypeError.get(i, needed.to_string(), provided)
-                    if is_signed(cook_dtype(needed)):
-                        launch_ctx.set_arg_int(actual_argument_slot, int(v))
-                    else:
-                        launch_ctx.set_arg_uint(actual_argument_slot, int(v))
-                elif isinstance(needed, sparse_matrix_builder):
-                    # Pass only the base pointer of the ti.types.sparse_matrix_builder() argument
-                    launch_ctx.set_arg_uint(actual_argument_slot, v._get_ndarray_addr())
-                elif isinstance(needed, ndarray_type.NdarrayType) and isinstance(v, taichi.lang._ndarray.Ndarray):
-                    v_primal = v.arr
-                    v_grad = v.grad.arr if v.grad else None
-                    if v_grad is None:
-                        launch_ctx.set_arg_ndarray(actual_argument_slot, v_primal)
-                    else:
-                        launch_ctx.set_arg_ndarray_with_grad(actual_argument_slot, v_primal, v_grad)
-                elif isinstance(needed, texture_type.TextureType) and isinstance(v, taichi.lang._texture.Texture):
-                    launch_ctx.set_arg_texture(actual_argument_slot, v.tex)
-                elif isinstance(needed, texture_type.RWTextureType) and isinstance(v, taichi.lang._texture.Texture):
-                    launch_ctx.set_arg_rw_texture(actual_argument_slot, v.tex)
-                elif isinstance(needed, ndarray_type.NdarrayType):
-                    # Element shapes are already specialized in Taichi codegen.
-                    # The shape information for element dims are no longer needed.
-                    # Therefore we strip the element shapes from the shape vector,
-                    # so that it only holds "real" array shapes.
-                    is_soa = needed.layout == Layout.SOA
-                    array_shape = v.shape
-                    if functools.reduce(operator.mul, array_shape, 1) > np.iinfo(np.int32).max:
-                        warnings.warn(
-                            "Ndarray index might be out of int32 boundary but int64 indexing is not supported yet."
-                        )
-                    if needed.dtype is None or id(needed.dtype) in primitive_types.type_ids:
-                        element_dim = 0
-                    else:
-                        element_dim = needed.dtype.ndim
-                    array_shape = v.shape[element_dim:] if is_soa else v.shape[:-element_dim]
-                    if isinstance(v, np.ndarray):
-                        if v.flags.c_contiguous:
-                            launch_ctx.set_arg_external_array_with_shape(
-                                actual_argument_slot, int(v.ctypes.data), v.nbytes, array_shape, 0
-                            )
-                        elif v.flags.f_contiguous:
-                            # TODO: A better way that avoids copying is saving strides info.
-                            tmp = np.ascontiguousarray(v)
-                            # Purpose: DO NOT GC |tmp|!
-                            tmps.append(tmp)
-
-                            def callback(original, updated):
-                                np.copyto(original, np.asfortranarray(updated))
-
-                            callbacks.append(functools.partial(callback, v, tmp))
-                            launch_ctx.set_arg_external_array_with_shape(
-                                actual_argument_slot, int(tmp.ctypes.data), tmp.nbytes, array_shape, 0
-                            )
-                        else:
-                            raise ValueError(
-                                "Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) "
-                                "before passing it into taichi kernel."
-                            )
-                    elif has_pytorch():
-                        import torch  # pylint: disable=C0415
-
-                        if isinstance(v, torch.Tensor):
-                            if not v.is_contiguous():
-                                raise ValueError(
-                                    "Non contiguous tensors are not supported, please call tensor.contiguous() before "
-                                    "passing it into taichi kernel."
-                                )
-                            taichi_arch = self.runtime.prog.config().arch
-
-                            def get_call_back(u, v):
-                                def call_back():
-                                    u.copy_(v)
-
-                                return call_back
-
-                            # FIXME: only allocate when launching grad kernel
-                            if v.requires_grad and v.grad is None:
-                                v.grad = torch.zeros_like(v)
-
-                            tmp = v
-                            if str(v.device).startswith("cuda") and taichi_arch != _ti_core.Arch.cuda:
-                                # Getting a torch CUDA tensor on Taichi non-cuda arch:
-                                # We just replace it with a CPU tensor and by the end of kernel execution we'll use the
-                                # callback to copy the values back to the original CUDA tensor.
-                                host_v = v.to(device="cpu", copy=True)
-                                tmp = host_v
-                                callbacks.append(get_call_back(v, host_v))
-
-                            launch_ctx.set_arg_external_array_with_shape(
-                                actual_argument_slot,
-                                int(tmp.data_ptr()),
-                                tmp.element_size() * tmp.nelement(),
-                                array_shape,
-                                int(v.grad.data_ptr()) if v.grad is not None else 0,
-                            )
-                        else:
-                            raise TaichiRuntimeTypeError.get(i, needed.to_string(), v)
-                    elif has_paddle():
-                        import paddle  # pylint: disable=C0415
-
-                        if isinstance(v, paddle.Tensor):
-                            # For now, paddle.fluid.core.Tensor._ptr() is only available on develop branch
-                            def get_call_back(u, v):
-                                def call_back():
-                                    u.copy_(v, False)
-
-                                return call_back
-
-                            tmp = v.value().get_tensor()
-                            taichi_arch = self.runtime.prog.config().arch
-                            if v.place.is_gpu_place():
-                                if taichi_arch != _ti_core.Arch.cuda:
-                                    # Paddle cuda tensor on Taichi non-cuda arch
-                                    host_v = v.cpu()
-                                    tmp = host_v.value().get_tensor()
-                                    callbacks.append(get_call_back(v, host_v))
-                            elif v.place.is_cpu_place():
-                                if taichi_arch == _ti_core.Arch.cuda:
-                                    # Paddle cpu tensor on Taichi cuda arch
-                                    gpu_v = v.cuda()
-                                    tmp = gpu_v.value().get_tensor()
-                                    callbacks.append(get_call_back(v, gpu_v))
-                            else:
-                                # Paddle do support many other backends like XPU, NPU, MLU, IPU
-                                raise TaichiRuntimeTypeError(
-                                    f"Taichi do not support backend {v.place} that Paddle support"
-                                )
-                            launch_ctx.set_arg_external_array_with_shape(
-                                actual_argument_slot, int(tmp._ptr()), v.element_size() * v.size, array_shape, 0
-                            )
-                        else:
-                            raise TaichiRuntimeTypeError.get(i, needed.to_string(), v)
-                    else:
-                        raise TaichiRuntimeTypeError.get(i, needed.to_string(), v)
-
-                elif isinstance(needed, MatrixType):
-                    if needed.dtype in primitive_types.real_types:
-
-                        def cast_func(x):
-                            if not isinstance(x, (int, float, np.integer, np.floating)):
-                                raise TaichiRuntimeTypeError.get(i, needed.dtype.to_string(), type(x))
-                            return float(x)
-
-                    elif needed.dtype in primitive_types.integer_types:
-
-                        def cast_func(x):
-                            if not isinstance(x, (int, np.integer)):
-                                raise TaichiRuntimeTypeError.get(i, needed.dtype.to_string(), type(x))
-                            return int(x)
-
-                    else:
-                        raise ValueError(f"Matrix dtype {needed.dtype} is not integer type or real type.")
-
-                    if needed.ndim == 2:
-                        v = [cast_func(v[i, j]) for i in range(needed.n) for j in range(needed.m)]
-                    else:
-                        v = [cast_func(v[i]) for i in range(needed.n)]
-                    v = needed(*v)
-                    needed.set_kernel_struct_args(v, launch_ctx, (actual_argument_slot,))
-                elif isinstance(needed, StructType):
-                    if not isinstance(v, needed):
-                        raise TaichiRuntimeTypeError.get(i, str(needed), provided)
-                    needed.set_kernel_struct_args(v, launch_ctx, (actual_argument_slot,))
-                else:
-                    raise ValueError(f"Argument type mismatch. Expecting {needed}, got {type(v)}.")
-                actual_argument_slot += 1
+            provided, v = type(val), val
+            if actual_argument_slot >= max_arg_num:
+                exceed_max_arg_num = True
+                break
+            # Note: do not use sth like "needed == f32". That would be slow.
+            if id(needed) in primitive_types.real_type_ids:
+                if not isinstance(v, (float, int, np.floating, np.integer)):
+                    raise TaichiRuntimeTypeError.get(i, needed.to_string(), provided)
+                launch_ctx.set_arg_float(actual_argument_slot, float(v))
+            elif id(needed) in primitive_types.integer_type_ids:
+                if not isinstance(v, (int, np.integer)):
+                    raise TaichiRuntimeTypeError.get(i, needed.to_string(), provided)
+                if is_signed(cook_dtype(needed)):
+                    launch_ctx.set_arg_int(actual_argument_slot, int(v))
+                else:
+                    launch_ctx.set_arg_uint(actual_argument_slot, int(v))
+            elif isinstance(needed, sparse_matrix_builder):
+                # Pass only the base pointer of the ti.types.sparse_matrix_builder() argument
+                launch_ctx.set_arg_uint(actual_argument_slot, v._get_ndarray_addr())
+            elif isinstance(needed, ndarray_type.NdarrayType) and isinstance(v, taichi.lang._ndarray.Ndarray):
+                v_primal = v.arr
+                v_grad = v.grad.arr if v.grad else None
+                if v_grad is None:
+                    launch_ctx.set_arg_ndarray(actual_argument_slot, v_primal)
+                else:
+                    launch_ctx.set_arg_ndarray_with_grad(actual_argument_slot, v_primal, v_grad)
+            elif isinstance(needed, texture_type.TextureType) and isinstance(v, taichi.lang._texture.Texture):
+                launch_ctx.set_arg_texture(actual_argument_slot, v.tex)
+            elif isinstance(needed, texture_type.RWTextureType) and isinstance(v, taichi.lang._texture.Texture):
+                launch_ctx.set_arg_rw_texture(actual_argument_slot, v.tex)
+            elif isinstance(needed, ndarray_type.NdarrayType):
+                # Element shapes are already specialized in Taichi codegen.
+                # The shape information for element dims are no longer needed.
+                # Therefore we strip the element shapes from the shape vector,
+                # so that it only holds "real" array shapes.
+                is_soa = needed.layout == Layout.SOA
+                array_shape = v.shape
+                if functools.reduce(operator.mul, array_shape, 1) > np.iinfo(np.int32).max:
+                    warnings.warn(
+                        "Ndarray index might be out of int32 boundary but int64 indexing is not supported yet."
+                    )
+                if needed.dtype is None or id(needed.dtype) in primitive_types.type_ids:
+                    element_dim = 0
+                else:
+                    element_dim = needed.dtype.ndim
+                array_shape = v.shape[element_dim:] if is_soa else v.shape[:-element_dim]
+                if isinstance(v, np.ndarray):
+                    if v.flags.c_contiguous:
+                        launch_ctx.set_arg_external_array_with_shape(
+                            actual_argument_slot, int(v.ctypes.data), v.nbytes, array_shape, 0
+                        )
+                    elif v.flags.f_contiguous:
+                        # TODO: A better way that avoids copying is saving strides info.
+                        tmp = np.ascontiguousarray(v)
+                        # Purpose: DO NOT GC |tmp|!
+                        tmps.append(tmp)
+
+                        def callback(original, updated):
+                            np.copyto(original, np.asfortranarray(updated))
+
+                        callbacks.append(functools.partial(callback, v, tmp))
+                        launch_ctx.set_arg_external_array_with_shape(
+                            actual_argument_slot, int(tmp.ctypes.data), tmp.nbytes, array_shape, 0
+                        )
+                    else:
+                        raise ValueError(
+                            "Non contiguous numpy arrays are not supported, please call np.ascontiguousarray(arr) "
+                            "before passing it into taichi kernel."
+                        )
+                elif has_pytorch():
+                    import torch  # pylint: disable=C0415
+
+                    if isinstance(v, torch.Tensor):
+                        if not v.is_contiguous():
+                            raise ValueError(
+                                "Non contiguous tensors are not supported, please call tensor.contiguous() before "
+                                "passing it into taichi kernel."
+                            )
+                        taichi_arch = self.runtime.prog.config().arch
+
+                        def get_call_back(u, v):
+                            def call_back():
+                                u.copy_(v)
+
+                            return call_back
+
+                        # FIXME: only allocate when launching grad kernel
+                        if v.requires_grad and v.grad is None:
+                            v.grad = torch.zeros_like(v)
+
+                        tmp = v
+                        if str(v.device).startswith("cuda") and taichi_arch != _ti_core.Arch.cuda:
+                            # Getting a torch CUDA tensor on Taichi non-cuda arch:
+                            # We just replace it with a CPU tensor and by the end of kernel execution we'll use the
+                            # callback to copy the values back to the original CUDA tensor.
+                            host_v = v.to(device="cpu", copy=True)
+                            tmp = host_v
+                            callbacks.append(get_call_back(v, host_v))
+
+                        launch_ctx.set_arg_external_array_with_shape(
+                            actual_argument_slot,
+                            int(tmp.data_ptr()),
+                            tmp.element_size() * tmp.nelement(),
+                            array_shape,
+                            int(v.grad.data_ptr()) if v.grad is not None else 0,
+                        )
+                    else:
+                        raise TaichiRuntimeTypeError.get(i, needed.to_string(), v)
+                elif has_paddle():
+                    import paddle  # pylint: disable=C0415
+
+                    if isinstance(v, paddle.Tensor):
+                        # For now, paddle.fluid.core.Tensor._ptr() is only available on develop branch
+                        def get_call_back(u, v):
+                            def call_back():
+                                u.copy_(v, False)
+
+                            return call_back
+
+                        tmp = v.value().get_tensor()
+                        taichi_arch = self.runtime.prog.config().arch
+                        if v.place.is_gpu_place():
+                            if taichi_arch != _ti_core.Arch.cuda:
+                                # Paddle cuda tensor on Taichi non-cuda arch
+                                host_v = v.cpu()
+                                tmp = host_v.value().get_tensor()
+                                callbacks.append(get_call_back(v, host_v))
+                        elif v.place.is_cpu_place():
+                            if taichi_arch == _ti_core.Arch.cuda:
+                                # Paddle cpu tensor on Taichi cuda arch
+                                gpu_v = v.cuda()
+                                tmp = gpu_v.value().get_tensor()
+                                callbacks.append(get_call_back(v, gpu_v))
+                        else:
+                            # Paddle do support many other backends like XPU, NPU, MLU, IPU
+                            raise TaichiRuntimeTypeError(f"Taichi do not support backend {v.place} that Paddle support")
+                        launch_ctx.set_arg_external_array_with_shape(
+                            actual_argument_slot, int(tmp._ptr()), v.element_size() * v.size, array_shape, 0
+                        )
+                    else:
+                        raise TaichiRuntimeTypeError.get(i, needed.to_string(), v)
+                else:
+                    raise TaichiRuntimeTypeError.get(i, needed.to_string(), v)
+
+            elif isinstance(needed, MatrixType):
+                if needed.dtype in primitive_types.real_types:
+
+                    def cast_func(x):
+                        if not isinstance(x, (int, float, np.integer, np.floating)):
+                            raise TaichiRuntimeTypeError.get(i, needed.dtype.to_string(), type(x))
+                        return float(x)
+
+                elif needed.dtype in primitive_types.integer_types:
+
+                    def cast_func(x):
+                        if not isinstance(x, (int, np.integer)):
+                            raise TaichiRuntimeTypeError.get(i, needed.dtype.to_string(), type(x))
+                        return int(x)
+
+                else:
+                    raise ValueError(f"Matrix dtype {needed.dtype} is not integer type or real type.")
+
+                if needed.ndim == 2:
+                    v = [cast_func(v[i, j]) for i in range(needed.n) for j in range(needed.m)]
+                else:
+                    v = [cast_func(v[i]) for i in range(needed.n)]
+                v = needed(*v)
+                needed.set_kernel_struct_args(v, launch_ctx, (actual_argument_slot,))
+            elif isinstance(needed, StructType):
+                if not isinstance(v, needed):
+                    raise TaichiRuntimeTypeError.get(i, str(needed), provided)
+                needed.set_kernel_struct_args(v, launch_ctx, (actual_argument_slot,))
+            elif isinstance(needed, ArgPackType):
+                if not isinstance(v, needed):
+                    raise TaichiRuntimeTypeError.get(i, str(needed), provided)
+                raise NotImplementedError
+            else:
+                raise ValueError(f"Argument type mismatch. Expecting {needed}, got {type(v)}.")
+            actual_argument_slot += 1
 
         if exceed_max_arg_num:
             raise TaichiRuntimeError(
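The external-array branch above enforces contiguity for NumPy inputs: C-contiguous arrays are passed by base pointer, Fortran-ordered arrays are transparently copied in and written back through a callback, and anything else is rejected with the `ValueError` shown. A small usage sketch (hypothetical kernel, assuming a CPU build of Taichi):

```python
import numpy as np
import taichi as ti

ti.init(arch=ti.cpu)

@ti.kernel
def double(arr: ti.types.ndarray(dtype=ti.f32, ndim=2)):
    for i, j in arr:
        arr[i, j] *= 2.0

a = np.ones((4, 4), dtype=np.float32)   # C-contiguous: passed by pointer
double(a)

b = np.asfortranarray(a)                # F-contiguous: copied in, copied back
double(b)

c = np.ones((4, 8), dtype=np.float32)[:, ::2]  # strided view: rejected
# double(c)  # raises ValueError; fix with np.ascontiguousarray(c)
```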
"load" : "addr", - expr->arg_id, data_type_name(expr->dt))); + fmt::join(expr->arg_id, ", "), data_type_name(expr->dt))); } void visit(TexturePtrExpression *expr) override { - emit(fmt::format("(Texture *)(arg[{}])", expr->arg_id)); + emit(fmt::format("(Texture *)(arg[{}])", fmt::join(expr->arg_id, ", "))); } void visit(TextureOpExpression *expr) override { diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index fc3b9bfbe859f3..a1aed4db9d0d6d 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -152,7 +152,7 @@ void ArgLoadExpression::type_check(const CompileConfig *) { void ArgLoadExpression::flatten(FlattenContext *ctx) { auto arg_load = - std::make_unique(arg_id, dt, is_ptr, create_load); + std::make_unique(arg_id[0], dt, is_ptr, create_load); arg_load->ret_type = ret_type; ctx->push_back(std::move(arg_load)); stmt = ctx->back_stmt(); @@ -162,7 +162,7 @@ void TexturePtrExpression::type_check(const CompileConfig *config) { } void TexturePtrExpression::flatten(FlattenContext *ctx) { - ctx->push_back(arg_id, PrimitiveType::f32, /*is_ptr=*/true, + ctx->push_back(arg_id[0], PrimitiveType::f32, /*is_ptr=*/true, /*create_load*/ true); ctx->push_back(ctx->back_stmt(), num_dims, is_storage, format, lod); @@ -609,7 +609,7 @@ void ExternalTensorExpression::flatten(FlattenContext *ctx) { auto type = TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim, needs_grad); - auto ptr = Stmt::make(arg_id, type, /*is_ptr=*/true, + auto ptr = Stmt::make(arg_id[0], type, /*is_ptr=*/true, /*create_load=*/false); ptr->tb = tb; @@ -1230,7 +1230,7 @@ void ExternalTensorShapeAlongAxisExpression::type_check(const CompileConfig *) { void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) { auto temp = ptr.cast(); TI_ASSERT(0 <= axis && axis < temp->ndim); - ctx->push_back(axis, temp->arg_id); + ctx->push_back(axis, temp->arg_id[0]); stmt = ctx->back_stmt(); } diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index f123c64c087c88..44834e2452150a 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -314,7 +314,7 @@ class FrontendReturnStmt : public Stmt { class ArgLoadExpression : public Expression { public: - int arg_id; + const std::vector arg_id; DataType dt; bool is_ptr; @@ -324,7 +324,7 @@ class ArgLoadExpression : public Expression { */ bool create_load; - ArgLoadExpression(int arg_id, + ArgLoadExpression(const std::vector &arg_id, DataType dt, bool is_ptr = false, bool create_load = true) @@ -346,7 +346,7 @@ class Texture; class TexturePtrExpression : public Expression { public: - int arg_id; + const std::vector arg_id; int num_dims; bool is_storage{false}; @@ -354,7 +354,7 @@ class TexturePtrExpression : public Expression { BufferFormat format{BufferFormat::unknown}; int lod{0}; - explicit TexturePtrExpression(int arg_id, int num_dims) + explicit TexturePtrExpression(const std::vector &arg_id, int num_dims) : arg_id(arg_id), num_dims(num_dims), is_storage(false), @@ -362,7 +362,10 @@ class TexturePtrExpression : public Expression { lod(0) { } - TexturePtrExpression(int arg_id, int num_dims, BufferFormat format, int lod) + TexturePtrExpression(const std::vector &arg_id, + int num_dims, + BufferFormat format, + int lod) : arg_id(arg_id), num_dims(num_dims), is_storage(true), @@ -474,14 +477,14 @@ class ExternalTensorExpression : public Expression { public: DataType dt; int ndim; - int arg_id; + std::vector arg_id; bool needs_grad{false}; bool is_grad{false}; BoundaryMode boundary{BoundaryMode::kUnsafe}; 
ExternalTensorExpression(const DataType &dt, int ndim, - int arg_id, + const std::vector &arg_id, bool needs_grad = false, BoundaryMode boundary = BoundaryMode::kUnsafe) { init(dt, ndim, arg_id, needs_grad, boundary); @@ -512,7 +515,7 @@ class ExternalTensorExpression : public Expression { void init(const DataType &dt, int ndim, - int arg_id, + const std::vector &arg_id, bool needs_grad, BoundaryMode boundary) { this->dt = dt; diff --git a/taichi/program/callable.cpp b/taichi/program/callable.cpp index 716f68d200c415..44c4f5121031b8 100644 --- a/taichi/program/callable.cpp +++ b/taichi/program/callable.cpp @@ -8,11 +8,12 @@ Callable::Callable() = default; Callable::~Callable() = default; -int Callable::insert_scalar_param(const DataType &dt, const std::string &name) { - parameter_list.emplace_back(dt->get_compute_type(), /*is_array=*/false); - parameter_list.back().name = name; - parameter_list.back().ptype = ParameterType::kScalar; - return (int)parameter_list.size() - 1; +std::vector Callable::insert_scalar_param(const DataType &dt, + const std::string &name) { + auto p = Parameter(dt->get_compute_type(), /*is_array=*/false); + p.name = name; + p.ptype = ParameterType::kScalar; + return add_parameter(p); } int Callable::insert_ret(const DataType &dt) { @@ -20,20 +21,20 @@ int Callable::insert_ret(const DataType &dt) { return (int)rets.size() - 1; } -int Callable::insert_arr_param(const DataType &dt, - int total_dim, - std::vector element_shape, - const std::string &name) { - parameter_list.emplace_back(dt->get_compute_type(), /*is_array=*/true, - /*size=*/0, total_dim, element_shape); - parameter_list.back().name = name; - return (int)parameter_list.size() - 1; +std::vector Callable::insert_arr_param(const DataType &dt, + int total_dim, + std::vector element_shape, + const std::string &name) { + auto p = Parameter(dt->get_compute_type(), /*is_array=*/true, 0, total_dim, + element_shape); + p.name = name; + return add_parameter(p); } -int Callable::insert_ndarray_param(const DataType &dt, - int ndim, - const std::string &name, - bool needs_grad) { +std::vector Callable::insert_ndarray_param(const DataType &dt, + int ndim, + const std::string &name, + bool needs_grad) { // Transform ndarray param to a struct type with a pointer to `dt`. std::vector element_shape{}; auto dtype = dt; @@ -45,43 +46,96 @@ int Callable::insert_ndarray_param(const DataType &dt, // If we could avoid using parameter_list in codegen it'll be fine auto *type = TypeFactory::get_instance().get_ndarray_struct_type(dtype, ndim, needs_grad); - parameter_list.emplace_back(type, /*is_array=*/true, - /*size=*/0, ndim + element_shape.size(), - element_shape, BufferFormat::unknown, needs_grad); - parameter_list.back().name = name; - parameter_list.back().ptype = ParameterType::kNdarray; - return (int)parameter_list.size() - 1; + auto p = Parameter(type, /*is_array=*/true, 0, ndim + element_shape.size(), + element_shape, BufferFormat::unknown, needs_grad); + p.name = name; + p.ptype = ParameterType::kNdarray; + return add_parameter(p); } -int Callable::insert_texture_param(int total_dim, const std::string &name) { +std::vector Callable::insert_texture_param(int total_dim, + const std::string &name) { // FIXME: we shouldn't abuse is_array for texture parameters // FIXME: using rwtexture struct type for texture parameters because C-API // does not distinguish between texture and rwtexture. 
auto *type = TypeFactory::get_instance().get_rwtexture_struct_type(); - parameter_list.emplace_back(type, /*is_array=*/true, 0, total_dim, - std::vector{}); - parameter_list.back().name = name; - parameter_list.back().ptype = ParameterType::kTexture; - return (int)parameter_list.size() - 1; + auto p = Parameter(type, /*is_array=*/true, 0, total_dim, std::vector{}); + p.name = name; + p.ptype = ParameterType::kTexture; + return add_parameter(p); } -int Callable::insert_pointer_param(const DataType &dt, - const std::string &name) { - parameter_list.emplace_back(dt->get_compute_type(), /*is_array=*/true); - parameter_list.back().name = name; - return (int)parameter_list.size() - 1; +std::vector Callable::insert_pointer_param(const DataType &dt, + const std::string &name) { + auto p = Parameter(dt->get_compute_type(), /*is_array=*/true); + p.name = name; + return add_parameter(p); } -int Callable::insert_rw_texture_param(int total_dim, - BufferFormat format, - const std::string &name) { +std::vector Callable::insert_rw_texture_param(int total_dim, + BufferFormat format, + const std::string &name) { // FIXME: we shouldn't abuse is_array for texture parameters auto *type = TypeFactory::get_instance().get_rwtexture_struct_type(); - parameter_list.emplace_back(type, /*is_array=*/true, 0, total_dim, - std::vector{}, format); - parameter_list.back().name = name; - parameter_list.back().ptype = ParameterType::kRWTexture; - return (int)parameter_list.size() - 1; + auto p = Parameter(type, /*is_array=*/true, 0, total_dim, std::vector{}, + format); + p.name = name; + p.ptype = ParameterType::kRWTexture; + return add_parameter(p); +} + +std::vector Callable::insert_argpack_param_and_push( + const std::string &name) { + TI_ASSERT(temp_argpack_stack_.size() == temp_indices_stack_.size() && + temp_argpack_name_stack_.size() == temp_indices_stack_.size()); + if (temp_argpack_stack_.size() > 0) { + temp_indices_stack_.push_back(temp_argpack_stack_.top().size()); + } else { + temp_indices_stack_.push_back(parameter_list.size()); + } + temp_argpack_stack_.push(std::vector()); + temp_argpack_name_stack_.push(name); + return temp_indices_stack_; +} + +void Callable::pop_argpack_stack() { + // Compile argpack members to a struct. + TI_ASSERT(temp_argpack_stack_.size() > 0 && temp_indices_stack_.size() > 0 && + temp_argpack_name_stack_.size() > 0); + std::vector argpack_params = temp_argpack_stack_.top(); + std::vector members; + members.reserve(argpack_params.size()); + for (int i = 0; i < argpack_params.size(); i++) { + auto ¶m = argpack_params[i]; + members.push_back( + {param.is_array && !param.get_dtype()->is() + ? 
TypeFactory::get_instance().get_pointer_type(param.get_dtype()) + : (const Type *)param.get_dtype(), + fmt::format("arg_{}_{}", fmt::join(temp_indices_stack_, "_"), i)}); + } + auto *type = + TypeFactory::get_instance().get_struct_type(members)->as(); + auto p = Parameter(DataType(type), false); + p.name = temp_argpack_name_stack_.top(); + add_parameter(p); + // Pop stacks + temp_argpack_stack_.pop(); + temp_indices_stack_.pop_back(); + temp_argpack_name_stack_.pop(); +} + +std::vector Callable::add_parameter(const Parameter ¶m) { + TI_ASSERT(temp_argpack_stack_.size() == temp_indices_stack_.size() && + temp_argpack_name_stack_.size() == temp_indices_stack_.size()); + if (temp_argpack_stack_.size() == 0) { + parameter_list.push_back(param); + return std::vector{(int)parameter_list.size() - 1}; + } + TI_ASSERT(temp_argpack_stack_.size() > 0); + temp_argpack_stack_.top().push_back(param); + std::vector ret = temp_indices_stack_; + ret.push_back(temp_argpack_stack_.top().size() - 1); + return ret; } void Callable::finalize_rets() { @@ -98,6 +152,9 @@ void Callable::finalize_rets() { } void Callable::finalize_params() { + TI_ASSERT(temp_argpack_stack_.size() == 0 && + temp_indices_stack_.size() == 0 && + temp_argpack_name_stack_.size() == 0); std::vector members; members.reserve(parameter_list.size()); for (int i = 0; i < parameter_list.size(); i++) { diff --git a/taichi/program/callable.h b/taichi/program/callable.h index 47d1598b15eba3..93ff472f791db4 100644 --- a/taichi/program/callable.h +++ b/taichi/program/callable.h @@ -3,6 +3,8 @@ #include "taichi/rhi/device.h" #include "taichi/util/lang_util.h" +#include + namespace taichi::lang { class Program; @@ -124,20 +126,27 @@ class TI_DLL_EXPORT Callable : public CallableBase { Callable(); virtual ~Callable(); - int insert_scalar_param(const DataType &dt, const std::string &name = ""); - int insert_arr_param(const DataType &dt, - int total_dim, - std::vector element_shape, - const std::string &name = ""); - int insert_ndarray_param(const DataType &dt, - int ndim, - const std::string &name = "", - bool needs_grad = false); - int insert_texture_param(int total_dim, const std::string &name = ""); - int insert_pointer_param(const DataType &dt, const std::string &name = ""); - int insert_rw_texture_param(int total_dim, - BufferFormat format, - const std::string &name = ""); + std::vector insert_scalar_param(const DataType &dt, + const std::string &name = ""); + std::vector insert_arr_param(const DataType &dt, + int total_dim, + std::vector element_shape, + const std::string &name = ""); + std::vector insert_ndarray_param(const DataType &dt, + int ndim, + const std::string &name = "", + bool needs_grad = false); + std::vector insert_texture_param(int total_dim, + const std::string &name = ""); + std::vector insert_pointer_param(const DataType &dt, + const std::string &name = ""); + std::vector insert_rw_texture_param(int total_dim, + BufferFormat format, + const std::string &name = ""); + + std::vector insert_argpack_param_and_push(const std::string &name = ""); + + void pop_argpack_stack(); int insert_ret(const DataType &dt); @@ -146,6 +155,14 @@ class TI_DLL_EXPORT Callable : public CallableBase { void finalize_params(); [[nodiscard]] virtual std::string get_name() const = 0; + + private: + std::vector add_parameter(const Parameter ¶m); + // Note: These stacks are used for inserting params inside argpacks. When + // we call finalize_params(), all of them are required to be empty then. 
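To make the index bookkeeping in `insert_argpack_param_and_push` and `add_parameter` concrete, here is a toy re-implementation of the same logic in Python (illustrative only, not the real API): a top-level parameter yields a one-element index, while a parameter added inside an open argpack yields the pack's index path plus the member's position.

```python
class Bookkeeper:
    def __init__(self):
        self.params = []    # parameter_list
        self.packs = []     # temp_argpack_stack_
        self.indices = []   # temp_indices_stack_

    def push_argpack(self):
        # Mirrors insert_argpack_param_and_push: record where the pack lives.
        self.indices.append(len(self.packs[-1]) if self.packs else len(self.params))
        self.packs.append([])

    def add(self, param):
        # Mirrors add_parameter: return the full index path of the new param.
        if not self.packs:
            self.params.append(param)
            return [len(self.params) - 1]
        self.packs[-1].append(param)
        return self.indices + [len(self.packs[-1]) - 1]

bk = Bookkeeper()
assert bk.add("x") == [0]      # plain parameter: slot 0
bk.push_argpack()
assert bk.add("a") == [1, 0]   # member 0 of the pack occupying slot 1
assert bk.add("b") == [1, 1]
# pop_argpack_stack (not modeled here) would fold the open pack into a single
# struct-typed parameter.
```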
diff --git a/taichi/program/callable.h b/taichi/program/callable.h
index 47d1598b15eba3..93ff472f791db4 100644
--- a/taichi/program/callable.h
+++ b/taichi/program/callable.h
@@ -3,6 +3,8 @@
 #include "taichi/rhi/device.h"
 #include "taichi/util/lang_util.h"
 
+#include <stack>
+
 namespace taichi::lang {
 
 class Program;
@@ -124,20 +126,27 @@ class TI_DLL_EXPORT Callable : public CallableBase {
   Callable();
   virtual ~Callable();
 
-  int insert_scalar_param(const DataType &dt, const std::string &name = "");
-  int insert_arr_param(const DataType &dt,
-                       int total_dim,
-                       std::vector<int> element_shape,
-                       const std::string &name = "");
-  int insert_ndarray_param(const DataType &dt,
-                           int ndim,
-                           const std::string &name = "",
-                           bool needs_grad = false);
-  int insert_texture_param(int total_dim, const std::string &name = "");
-  int insert_pointer_param(const DataType &dt, const std::string &name = "");
-  int insert_rw_texture_param(int total_dim,
-                              BufferFormat format,
-                              const std::string &name = "");
+  std::vector<int> insert_scalar_param(const DataType &dt,
+                                       const std::string &name = "");
+  std::vector<int> insert_arr_param(const DataType &dt,
+                                    int total_dim,
+                                    std::vector<int> element_shape,
+                                    const std::string &name = "");
+  std::vector<int> insert_ndarray_param(const DataType &dt,
+                                        int ndim,
+                                        const std::string &name = "",
+                                        bool needs_grad = false);
+  std::vector<int> insert_texture_param(int total_dim,
+                                        const std::string &name = "");
+  std::vector<int> insert_pointer_param(const DataType &dt,
+                                        const std::string &name = "");
+  std::vector<int> insert_rw_texture_param(int total_dim,
+                                           BufferFormat format,
+                                           const std::string &name = "");
+
+  std::vector<int> insert_argpack_param_and_push(const std::string &name = "");
+
+  void pop_argpack_stack();
 
   int insert_ret(const DataType &dt);
 
@@ -146,6 +155,14 @@ class TI_DLL_EXPORT Callable : public CallableBase {
   void finalize_params();
 
   [[nodiscard]] virtual std::string get_name() const = 0;
+
+ private:
+  std::vector<int> add_parameter(const Parameter &param);
+  // Note: These stacks are used for inserting params inside argpacks. When
+  // we call finalize_params(), all of them are required to be empty then.
+  std::stack<std::vector<Parameter>> temp_argpack_stack_;
+  std::vector<int> temp_indices_stack_;
+  std::stack<std::string> temp_argpack_name_stack_;
 };
 
 }  // namespace taichi::lang
diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp
index f21579baf20b70..4bef97a8456087 100644
--- a/taichi/program/program.cpp
+++ b/taichi/program/program.cpp
@@ -276,7 +276,8 @@ Kernel &Program::get_snode_reader(SNode *snode) {
   auto &ker = kernel([snode, this](Kernel *kernel) {
     ExprGroup indices;
     for (int i = 0; i < snode->num_active_indices; i++) {
-      auto argload_expr = Expr::make<ArgLoadExpression>(i, PrimitiveType::i32);
+      auto argload_expr = Expr::make<ArgLoadExpression>(std::vector<int>{i},
+                                                        PrimitiveType::i32);
       argload_expr->type_check(&this->compile_config());
       indices.push_back(std::move(argload_expr));
     }
@@ -301,7 +302,8 @@ Kernel &Program::get_snode_writer(SNode *snode) {
   auto &ker = kernel([snode, this](Kernel *kernel) {
     ExprGroup indices;
     for (int i = 0; i < snode->num_active_indices; i++) {
-      auto argload_expr = Expr::make<ArgLoadExpression>(i, PrimitiveType::i32);
+      auto argload_expr = Expr::make<ArgLoadExpression>(std::vector<int>{i},
+                                                        PrimitiveType::i32);
       argload_expr->type_check(&this->compile_config());
       indices.push_back(std::move(argload_expr));
     }
@@ -309,7 +311,8 @@ Kernel &Program::get_snode_writer(SNode *snode) {
     auto expr =
         builder.expr_subscript(Expr(snode_to_fields_.at(snode)), indices);
     auto argload_expr = Expr::make<ArgLoadExpression>(
-        snode->num_active_indices, snode->dt->get_compute_type());
+        std::vector<int>{snode->num_active_indices},
+        snode->dt->get_compute_type());
     argload_expr->type_check(&this->compile_config());
     builder.insert_assignment(expr, argload_expr, expr->tb);
   });
diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp
index 922b99589dda49..5abcf9f755e822 100644
--- a/taichi/python/export_lang.cpp
+++ b/taichi/python/export_lang.cpp
@@ -684,6 +684,9 @@ void export_lang(py::module &m) {
       .def("insert_texture_param", &Kernel::insert_texture_param)
       .def("insert_pointer_param", &Kernel::insert_pointer_param)
      .def("insert_rw_texture_param", &Kernel::insert_rw_texture_param)
+      .def("insert_argpack_param_and_push",
+           &Kernel::insert_argpack_param_and_push)
+      .def("pop_argpack_stack", &Kernel::pop_argpack_stack)
       .def("insert_ret", &Kernel::insert_ret)
       .def("finalize_rets", &Kernel::finalize_rets)
       .def("finalize_params", &Kernel::finalize_params)
@@ -929,14 +932,15 @@ void export_lang(py::module &m) {
         Stmt::make<FrontendExprStmt, const Expr &>);
 
   m.def("make_arg_load_expr",
-        Expr::make<ArgLoadExpression, int, const DataType &, bool, bool>,
+        Expr::make<ArgLoadExpression, const std::vector<int> &,
+                   const DataType &, bool, bool>,
         "arg_id"_a, "dt"_a, "is_ptr"_a = false, "create_load"_a = true);
 
   m.def("make_reference", Expr::make<ReferenceExpression, const Expr &>);
 
   m.def("make_external_tensor_expr",
-        Expr::make<ExternalTensorExpression, const DataType &, int, int, bool, const BoundaryMode &>);
+        Expr::make<ExternalTensorExpression, const DataType &, int,
+                   const std::vector<int> &, bool, const BoundaryMode &>);
 
   m.def("make_external_tensor_grad_expr",
         Expr::make<ExternalTensorExpression, const Expr &>);
@@ -952,9 +956,11 @@ void export_lang(py::module &m) {
   m.def("make_const_expr_fp",
         Expr::make<ConstExpression, const DataType &, float64>);
 
-  m.def("make_texture_ptr_expr", Expr::make<TexturePtrExpression, int, int>);
+  m.def("make_texture_ptr_expr",
+        Expr::make<TexturePtrExpression, const std::vector<int> &, int>);
   m.def("make_rw_texture_ptr_expr",
-        Expr::make<TexturePtrExpression, int, int, const BufferFormat &, int>);
+        Expr::make<TexturePtrExpression, const std::vector<int> &, int,
+                   const BufferFormat &, int>);
 
   auto &&texture =
       py::enum_<TextureOpType>(m, "TextureOpType", py::arithmetic());
diff --git a/taichi/transforms/lower_ast.cpp b/taichi/transforms/lower_ast.cpp
index d3fab5c7663d80..88964cf200f660 100644
--- a/taichi/transforms/lower_ast.cpp
+++ b/taichi/transforms/lower_ast.cpp
@@ -253,14 +253,14 @@ class LowerAST : public IRVisitor {
     std::vector<Stmt *> shape;
     if (stmt->external_tensor.is<ExternalTensorExpression>()) {
       auto tensor = stmt->external_tensor.cast<ExternalTensorExpression>();
-      arg_id = tensor->arg_id;
+      arg_id = tensor->arg_id[0];
      for (int i = 0; i < tensor->ndim; i++) {
         shape.push_back(
             fctx.push_back<ExternalTensorShapeAlongAxisStmt>(i, arg_id));
       }
     } else if (stmt->external_tensor.is<TexturePtrExpression>()) {
       auto rw_texture = stmt->external_tensor.cast<TexturePtrExpression>();
-      arg_id = rw_texture->arg_id;
+      arg_id = rw_texture->arg_id[0];
       for (size_t i = 0; i < rw_texture->num_dims; ++i) {
         shape.emplace_back(
             fctx.push_back<ExternalTensorShapeAlongAxisStmt>(i, arg_id));
diff --git a/tests/cpp/ir/frontend_type_inference_test.cpp b/tests/cpp/ir/frontend_type_inference_test.cpp
index dd21a5ec7b92ce..f9f791e37dbe3f 100644
--- a/tests/cpp/ir/frontend_type_inference_test.cpp
+++ b/tests/cpp/ir/frontend_type_inference_test.cpp
@@ -14,7 +14,8 @@ TEST(FrontendTypeInference, Const) {
 }
 
 TEST(FrontendTypeInference, ArgLoad) {
-  auto arg_load_u64 = Expr::make<ArgLoadExpression>(2, PrimitiveType::u64);
+  auto arg_load_u64 =
+      Expr::make<ArgLoadExpression>(std::vector<int>{2}, PrimitiveType::u64);
   arg_load_u64->type_check(nullptr);
   EXPECT_EQ(arg_load_u64->ret_type, PrimitiveType::u64);
 }
@@ -154,8 +155,8 @@ TEST(FrontendTypeInference, GlobalPtr_ExternalTensor) {
   auto index = value<int32>(2);
   index->type_check(nullptr);
-  auto external_tensor =
-      Expr::make<ExternalTensorExpression>(PrimitiveType::u16, 1, 0, 0);
+  auto external_tensor = Expr::make<ExternalTensorExpression>(
+      PrimitiveType::u16, 1, std::vector<int>{0}, 0);
   auto global_ptr =
       ast_builder->expr_subscript(external_tensor, ExprGroup(index));
   EXPECT_THROW(global_ptr->type_check(nullptr), TaichiTypeError);
@@ -207,8 +208,8 @@ TEST(FrontendTypeInference, SNodeOp) {
 }
 
 TEST(FrontendTypeInference, ExternalTensorShapeAlongAxis) {
-  auto external_tensor =
-      Expr::make<ExternalTensorExpression>(PrimitiveType::u64, 1, 0, 0);
+  auto external_tensor = Expr::make<ExternalTensorExpression>(
+      PrimitiveType::u64, 1, std::vector<int>{0}, 0);
   auto shape =
       Expr::make<ExternalTensorShapeAlongAxisExpression>(external_tensor, 0);
   shape->type_check(nullptr);
diff --git a/tests/python/test_argpack.py b/tests/python/test_argpack.py
index 1c89f309888638..79ab7ff12eced0 100644
--- a/tests/python/test_argpack.py
+++ b/tests/python/test_argpack.py
@@ -4,6 +4,7 @@
 from tests import test_utils
 
 
+@pytest.mark.skip(reason="Temporarily disabled argpack functionalities")
 @test_utils.test()
 def test_argpack_basic():
     pack_type = ti.types.argpack(a=ti.i32, b=bool, c=ti.f32)
@@ -23,6 +24,7 @@ def foo(pack: pack_type) -> ti.f32:
     assert foo(pack2) == test_utils.approx(2 + 2.1, rel=1e-3)
 
 
+@pytest.mark.skip(reason="Temporarily disabled argpack functionalities")
 @test_utils.test()
 def test_argpack_multiple():
     arr = ti.ndarray(dtype=ti.math.vec3, shape=(4, 4))
@@ -41,6 +43,7 @@ def foo(p1: pack_type1, p2: pack_type2) -> ti.f32:
     assert foo(pack1, pack2) == test_utils.approx(1 * 2.1 + 2.0, rel=1e-3)
 
 
+@pytest.mark.skip(reason="Temporarily disabled argpack functionalities")
 @test_utils.test()
 def test_argpack_nested():
     arr = ti.ndarray(dtype=ti.math.vec3, shape=(4, 4))
@@ -68,6 +71,7 @@ def h(x: pack_type) -> int:
     assert h(pack) == 233
 
 
+@pytest.mark.skip(reason="Temporarily disabled argpack functionalities")
 @test_utils.test()
 def test_argpack_as_return():
     pack_type = ti.types.argpack(a=ti.i32, b=bool)
@@ -81,6 +85,7 @@ def foo(pack: pack_type) -> pack_type:
         foo()
 
 
+@pytest.mark.skip(reason="Temporarily disabled argpack functionalities")
 @test_utils.test()
 def test_argpack_as_struct_type_element():
     with pytest.raises(ValueError, match="Invalid data type "):
@@ -89,6 +94,7 @@ def test_argpack_as_struct_type_element():
         print(struct_with_argpack_inside)
 
 
+@pytest.mark.skip(reason="Temporarily disabled argpack functionalities")
 @test_utils.test()
 def test_argpack_with_ndarray():
     arr = ti.ndarray(dtype=ti.math.vec3, shape=(4, 4))
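One detail worth noting from `pop_argpack_stack` earlier in this diff: the struct members it synthesizes are named after the full index path, via `fmt::format("arg_{}_{}", fmt::join(temp_indices_stack_, "_"), i)`. A short Python mirror of that naming rule (illustrative only, not the real implementation):

```python
def member_name(index_path, i):
    # Mirrors fmt::format("arg_{}_{}", fmt::join(index_path, "_"), i)
    return "arg_{}_{}".format("_".join(map(str, index_path)), i)

assert member_name([1], 0) == "arg_1_0"       # member 0 of the pack at slot 1
assert member_name([1, 2], 3) == "arg_1_2_3"  # member 3 of a nested pack
```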