[refactor] Make arg_id a vector
ghstack-source-id: 53eb0ea78dcf07ec8d1897e852b7a31862646803
Pull Request resolved: #8204
listerily authored and Taichi Gardener committed Jul 11, 2023
1 parent fee507e commit ee6e353
Showing 14 changed files with 356 additions and 272 deletions.
12 changes: 6 additions & 6 deletions docs/lang/articles/advanced/argument_pack.md
@@ -38,12 +38,12 @@ view_params = view_params_tmpl(
Once argument packs are created and initialized, they can be easily used as kernel parameters. Simply pass them to the kernel, and Taichi will intelligently cache them (if their values remain unchanged) across multiple kernel calls, optimizing performance.

```python cont
@ti.kernel
def p(view_params: view_params_tmpl) -> ti.f32:
return view_params.far


print(p(view_params)) # 1.0
# @ti.kernel
# def p(view_params: view_params_tmpl) -> ti.f32:
# return view_params.far
#
#
# print(p(view_params)) # 1.0
```

## Limitations
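For context, `view_params_tmpl` and `view_params` are created earlier in that document, outside the shown hunk. A minimal sketch of such a definition, assuming a Taichi build with argument-pack support; every field except `far` (which the commented-out kernel reads) is a guess:

```python
import taichi as ti

ti.init()

# Assumed argpack template; the real definition sits above the shown hunk.
# `far` is taken from the kernel example, which returns view_params.far == 1.0.
view_params_tmpl = ti.types.argpack(view2world=ti.math.mat4, far=ti.f32)

view_params = view_params_tmpl(
    view2world=ti.math.mat4(1.0),
    far=1.0,
)
```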
1 change: 1 addition & 0 deletions python/taichi/lang/ast/ast_transformer.py
@@ -652,6 +652,7 @@ def transform_as_kernel():
for i, arg in enumerate(args.args):
if isinstance(ctx.func.arguments[i].annotation, ArgPackType):
d = {}
kernel_arguments.push_argpack_arg(ctx.func.arguments[i].name)
for j, (name, anno) in enumerate(ctx.func.arguments[i].annotation.members.items()):
d[name] = decl_and_create_variable(anno, name, ctx.arg_features[i][j])
ctx.create_variable(arg.arg, kernel_arguments.decl_argpack_arg(ctx.func.arguments[i].annotation, d))
5 changes: 5 additions & 0 deletions python/taichi/lang/kernel_arguments.py
@@ -94,7 +94,12 @@ def decl_struct_arg(structtype, name):
return structtype.from_taichi_object(arg_load)


def push_argpack_arg(name):
impl.get_runtime().compiling_callable.insert_argpack_param_and_push(name)


def decl_argpack_arg(argpacktype, member_dict):
impl.get_runtime().compiling_callable.pop_argpack_stack()
return argpacktype.from_taichi_object(member_dict)


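The new `push_argpack_arg` is meant to be paired with `decl_argpack_arg`, bracketing the member declarations of one argpack parameter (see the `transform_as_kernel` hunk above); on the C++ side they forward to `insert_argpack_param_and_push` and `pop_argpack_stack` on the compiling callable. A sketch of the call order with stand-in stubs — the stubs and the member spec are illustrative, not Taichi's API:

```python
# Stand-in stubs that only log the call order; the real helpers forward to the
# compiling Callable (insert_argpack_param_and_push / pop_argpack_stack).
def push_argpack_arg(name):
    print(f"open argpack scope: {name}")

def decl_argpack_arg(argpack_type, member_dict):
    print(f"close argpack scope, members: {list(member_dict)}")
    return member_dict  # the real helper wraps the members into an argpack object

def decl_member(annotation, name):
    print(f"  declare member {name}: {annotation}")
    return name

# Call order used for one argpack-typed kernel parameter:
members = {"view2world": "mat4", "far": "f32"}  # illustrative member spec
push_argpack_arg("view_params")                               # 1. open the scope
d = {n: decl_member(anno, n) for n, anno in members.items()}  # 2. declare members
pack = decl_argpack_arg(members, d)                           # 3. close the scope
```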
346 changes: 166 additions & 180 deletions python/taichi/lang/kernel_impl.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions taichi/ir/expression_printer.h
@@ -37,11 +37,11 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {

void visit(ArgLoadExpression *expr) override {
emit(fmt::format("arg{}[{}] (dt={})", expr->create_load ? "load" : "addr",
expr->arg_id, data_type_name(expr->dt)));
fmt::join(expr->arg_id, ", "), data_type_name(expr->dt)));
}

void visit(TexturePtrExpression *expr) override {
emit(fmt::format("(Texture *)(arg[{}])", expr->arg_id));
emit(fmt::format("(Texture *)(arg[{}])", fmt::join(expr->arg_id, ", ")));
}

void visit(TextureOpExpression *expr) override {
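With `arg_id` now a vector, the printer joins the indices with `fmt::join`, so an argpack member prints with its full index path. A rough Python rendering of the resulting string, with the indices and dtype chosen for illustration:

```python
# e.g. an f32 member at position 0 of the argpack occupying slot 1
arg_id = [1, 0]
print(f"argload[{', '.join(map(str, arg_id))}] (dt=f32)")  # argload[1, 0] (dt=f32)
```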
10 changes: 5 additions & 5 deletions taichi/ir/frontend_ir.cpp
@@ -152,7 +152,7 @@ void ArgLoadExpression::type_check(const CompileConfig *) {

void ArgLoadExpression::flatten(FlattenContext *ctx) {
auto arg_load =
std::make_unique<ArgLoadStmt>(arg_id, dt, is_ptr, create_load);
std::make_unique<ArgLoadStmt>(arg_id[0], dt, is_ptr, create_load);
arg_load->ret_type = ret_type;
ctx->push_back(std::move(arg_load));
stmt = ctx->back_stmt();
@@ -162,7 +162,7 @@ void TexturePtrExpression::type_check(const CompileConfig *config) {
}

void TexturePtrExpression::flatten(FlattenContext *ctx) {
ctx->push_back<ArgLoadStmt>(arg_id, PrimitiveType::f32, /*is_ptr=*/true,
ctx->push_back<ArgLoadStmt>(arg_id[0], PrimitiveType::f32, /*is_ptr=*/true,
/*create_load*/ true);
ctx->push_back<TexturePtrStmt>(ctx->back_stmt(), num_dims, is_storage, format,
lod);
@@ -610,7 +610,7 @@ void ExternalTensorExpression::flatten(FlattenContext *ctx) {
TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim, needs_grad);
type = TypeFactory::get_instance().get_pointer_type((Type *)type);

auto ptr = Stmt::make<ArgLoadStmt>(arg_id, type, /*is_ptr=*/true,
auto ptr = Stmt::make<ArgLoadStmt>(arg_id[0], type, /*is_ptr=*/true,
/*create_load=*/false);

ptr->tb = tb;
@@ -1231,7 +1231,7 @@ void ExternalTensorShapeAlongAxisExpression::type_check(const CompileConfig *) {
void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) {
auto temp = ptr.cast<ExternalTensorExpression>();
TI_ASSERT(0 <= axis && axis < temp->ndim);
ctx->push_back<ExternalTensorShapeAlongAxisStmt>(axis, temp->arg_id);
ctx->push_back<ExternalTensorShapeAlongAxisStmt>(axis, temp->arg_id[0]);
stmt = ctx->back_stmt();
}

@@ -1245,7 +1245,7 @@ void ExternalTensorBasePtrExpression::type_check(const CompileConfig *) {

void ExternalTensorBasePtrExpression::flatten(FlattenContext *ctx) {
auto tensor = ptr.cast<ExternalTensorExpression>();
ctx->push_back<ExternalTensorBasePtrStmt>(tensor->arg_id, is_grad);
ctx->push_back<ExternalTensorBasePtrStmt>(tensor->arg_id[0], is_grad);
stmt = ctx->back_stmt();
stmt->ret_type = ret_type;
}
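These flatten paths take `arg_id[0]`, relying on the fact that a top-level (non-argpack) parameter now carries a one-element index vector, so element 0 is exactly the old flat integer index. A small Python illustration of that assumption (the index value is made up):

```python
# What insert_*_param returns for the 4th top-level parameter (assumed value):
arg_id = [3]
legacy_index = arg_id[0]  # recovers the flat index the Stmt constructors still expect
assert legacy_index == 3
```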
19 changes: 11 additions & 8 deletions taichi/ir/frontend_ir.h
@@ -314,7 +314,7 @@ class FrontendReturnStmt : public Stmt {

class ArgLoadExpression : public Expression {
public:
int arg_id;
const std::vector<int> arg_id;
DataType dt;
bool is_ptr;

@@ -324,7 +324,7 @@ class ArgLoadExpression : public Expression {
*/
bool create_load;

ArgLoadExpression(int arg_id,
ArgLoadExpression(const std::vector<int> &arg_id,
DataType dt,
bool is_ptr = false,
bool create_load = true)
@@ -346,23 +346,26 @@ class Texture;

class TexturePtrExpression : public Expression {
public:
int arg_id;
const std::vector<int> arg_id;
int num_dims;
bool is_storage{false};

// Optional, for storage textures
BufferFormat format{BufferFormat::unknown};
int lod{0};

explicit TexturePtrExpression(int arg_id, int num_dims)
explicit TexturePtrExpression(const std::vector<int> &arg_id, int num_dims)
: arg_id(arg_id),
num_dims(num_dims),
is_storage(false),
format(BufferFormat::rgba8),
lod(0) {
}

TexturePtrExpression(int arg_id, int num_dims, BufferFormat format, int lod)
TexturePtrExpression(const std::vector<int> &arg_id,
int num_dims,
BufferFormat format,
int lod)
: arg_id(arg_id),
num_dims(num_dims),
is_storage(true),
@@ -474,14 +477,14 @@ class ExternalTensorExpression : public Expression {
public:
DataType dt;
int ndim;
int arg_id;
std::vector<int> arg_id;
bool needs_grad{false};
bool is_grad{false};
BoundaryMode boundary{BoundaryMode::kUnsafe};

ExternalTensorExpression(const DataType &dt,
int ndim,
int arg_id,
const std::vector<int> &arg_id,
bool needs_grad = false,
BoundaryMode boundary = BoundaryMode::kUnsafe) {
init(dt, ndim, arg_id, needs_grad, boundary);
@@ -513,7 +516,7 @@ class ExternalTensorExpression : public Expression {

void init(const DataType &dt,
int ndim,
int arg_id,
const std::vector<int> &arg_id,
bool needs_grad,
BoundaryMode boundary) {
this->dt = dt;
140 changes: 98 additions & 42 deletions taichi/program/callable.cpp
@@ -8,32 +8,33 @@ Callable::Callable() = default;

Callable::~Callable() = default;

int Callable::insert_scalar_param(const DataType &dt, const std::string &name) {
parameter_list.emplace_back(dt->get_compute_type(), /*is_array=*/false);
parameter_list.back().name = name;
parameter_list.back().ptype = ParameterType::kScalar;
return (int)parameter_list.size() - 1;
std::vector<int> Callable::insert_scalar_param(const DataType &dt,
const std::string &name) {
auto p = Parameter(dt->get_compute_type(), /*is_array=*/false);
p.name = name;
p.ptype = ParameterType::kScalar;
return add_parameter(p);
}

int Callable::insert_ret(const DataType &dt) {
rets.emplace_back(dt->get_compute_type());
return (int)rets.size() - 1;
}

int Callable::insert_arr_param(const DataType &dt,
int total_dim,
std::vector<int> element_shape,
const std::string &name) {
parameter_list.emplace_back(dt->get_compute_type(), /*is_array=*/true,
/*size=*/0, total_dim, element_shape);
parameter_list.back().name = name;
return (int)parameter_list.size() - 1;
std::vector<int> Callable::insert_arr_param(const DataType &dt,
int total_dim,
std::vector<int> element_shape,
const std::string &name) {
auto p = Parameter(dt->get_compute_type(), /*is_array=*/true, 0, total_dim,
element_shape);
p.name = name;
return add_parameter(p);
}

int Callable::insert_ndarray_param(const DataType &dt,
int ndim,
const std::string &name,
bool needs_grad) {
std::vector<int> Callable::insert_ndarray_param(const DataType &dt,
int ndim,
const std::string &name,
bool needs_grad) {
// Transform ndarray param to a struct type with a pointer to `dt`.
std::vector<int> element_shape{};
auto dtype = dt;
@@ -45,43 +46,95 @@ int Callable::insert_ndarray_param(const DataType &dt,
// If we could avoid using parameter_list in codegen it'll be fine
auto *type = TypeFactory::get_instance().get_ndarray_struct_type(dtype, ndim,
needs_grad);
parameter_list.emplace_back(type, /*is_array=*/true,
/*size=*/0, ndim + element_shape.size(),
element_shape, BufferFormat::unknown, needs_grad);
parameter_list.back().name = name;
parameter_list.back().ptype = ParameterType::kNdarray;
return (int)parameter_list.size() - 1;
auto p = Parameter(type, /*is_array=*/true, 0, ndim + element_shape.size(),
element_shape, BufferFormat::unknown, needs_grad);
p.name = name;
p.ptype = ParameterType::kNdarray;
return add_parameter(p);
}

int Callable::insert_texture_param(int total_dim, const std::string &name) {
std::vector<int> Callable::insert_texture_param(int total_dim,
const std::string &name) {
// FIXME: we shouldn't abuse is_array for texture parameters
// FIXME: using rwtexture struct type for texture parameters because C-API
// does not distinguish between texture and rwtexture.
auto *type = TypeFactory::get_instance().get_rwtexture_struct_type();
parameter_list.emplace_back(type, /*is_array=*/true, 0, total_dim,
std::vector<int>{});
parameter_list.back().name = name;
parameter_list.back().ptype = ParameterType::kTexture;
return (int)parameter_list.size() - 1;
auto p = Parameter(type, /*is_array=*/true, 0, total_dim, std::vector<int>{});
p.name = name;
p.ptype = ParameterType::kTexture;
return add_parameter(p);
}

int Callable::insert_pointer_param(const DataType &dt,
const std::string &name) {
parameter_list.emplace_back(dt->get_compute_type(), /*is_array=*/true);
parameter_list.back().name = name;
return (int)parameter_list.size() - 1;
std::vector<int> Callable::insert_pointer_param(const DataType &dt,
const std::string &name) {
auto p = Parameter(dt->get_compute_type(), /*is_array=*/true);
p.name = name;
return add_parameter(p);
}

int Callable::insert_rw_texture_param(int total_dim,
BufferFormat format,
const std::string &name) {
std::vector<int> Callable::insert_rw_texture_param(int total_dim,
BufferFormat format,
const std::string &name) {
// FIXME: we shouldn't abuse is_array for texture parameters
auto *type = TypeFactory::get_instance().get_rwtexture_struct_type();
parameter_list.emplace_back(type, /*is_array=*/true, 0, total_dim,
std::vector<int>{}, format);
parameter_list.back().name = name;
parameter_list.back().ptype = ParameterType::kRWTexture;
return (int)parameter_list.size() - 1;
auto p = Parameter(type, /*is_array=*/true, 0, total_dim, std::vector<int>{},
format);
p.name = name;
p.ptype = ParameterType::kRWTexture;
return add_parameter(p);
}

std::vector<int> Callable::insert_argpack_param_and_push(
const std::string &name) {
TI_ASSERT(temp_argpack_stack_.size() == temp_indices_stack_.size() &&
temp_argpack_name_stack_.size() == temp_indices_stack_.size());
if (temp_argpack_stack_.size() > 0) {
temp_indices_stack_.push_back(temp_argpack_stack_.top().size());
} else {
temp_indices_stack_.push_back(parameter_list.size());
}
temp_argpack_stack_.push(std::vector<Parameter>());
temp_argpack_name_stack_.push(name);
return temp_indices_stack_;
}

void Callable::pop_argpack_stack() {
// Compile argpack members to a struct.
TI_ASSERT(temp_argpack_stack_.size() > 0 && temp_indices_stack_.size() > 0 &&
temp_argpack_name_stack_.size() > 0);
std::vector<Parameter> argpack_params = temp_argpack_stack_.top();
std::vector<StructMember> members;
members.reserve(argpack_params.size());
for (int i = 0; i < argpack_params.size(); i++) {
auto &param = argpack_params[i];
members.push_back(
{param.is_array && !param.get_dtype()->is<StructType>()
? TypeFactory::get_instance().get_pointer_type(param.get_dtype())
: (const Type *)param.get_dtype(),
fmt::format("arg_{}_{}", fmt::join(temp_indices_stack_, "_"), i)});
}
auto *type =
TypeFactory::get_instance().get_struct_type(members)->as<StructType>();
auto p = Parameter(DataType(type), false);
p.name = temp_argpack_name_stack_.top();
add_parameter(p);
// Pop stacks
temp_argpack_stack_.pop();
temp_indices_stack_.pop_back();
temp_argpack_name_stack_.pop();
}

std::vector<int> Callable::add_parameter(const Parameter &param) {
TI_ASSERT(temp_argpack_stack_.size() == temp_indices_stack_.size() &&
temp_argpack_name_stack_.size() == temp_indices_stack_.size());
if (temp_argpack_stack_.size() == 0) {
parameter_list.push_back(param);
return std::vector<int>{(int)parameter_list.size() - 1};
}
temp_argpack_stack_.top().push_back(param);
std::vector<int> ret = temp_indices_stack_;
ret.push_back(temp_argpack_stack_.top().size() - 1);
return ret;
}

void Callable::finalize_rets() {
@@ -98,6 +151,9 @@ void Callable::finalize_rets() {
}

void Callable::finalize_params() {
TI_ASSERT(temp_argpack_stack_.size() == 0 &&
temp_indices_stack_.size() == 0 &&
temp_argpack_name_stack_.size() == 0);
std::vector<StructMember> members;
members.reserve(parameter_list.size());
for (int i = 0; i < parameter_list.size(); i++) {
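The core of the change is how `add_parameter` and `insert_argpack_param_and_push` assign hierarchical indices: a top-level parameter gets a one-element index, while an argpack member gets the enclosing pack's index prefix plus its position inside the pack. Below is a toy Python model of just that indexing (a sketch, not Taichi's classes; it deliberately skips the struct-type construction that the real `pop_argpack_stack` performs):

```python
# Toy model of the index scheme implemented by Callable::add_parameter and
# Callable::insert_argpack_param_and_push in this commit.
class ToyCallable:
    def __init__(self):
        self.parameter_list = []  # finalized top-level parameters
        self.pack_stack = []      # members of argpacks currently being built
        self.index_stack = []     # index prefix of the open argpacks

    def insert_argpack_param_and_push(self, name):
        # Reserve the position the pack will occupy in its enclosing scope.
        if self.pack_stack:
            self.index_stack.append(len(self.pack_stack[-1]))
        else:
            self.index_stack.append(len(self.parameter_list))
        self.pack_stack.append([])
        return list(self.index_stack)

    def add_parameter(self, param):
        if not self.pack_stack:
            self.parameter_list.append(param)
            return [len(self.parameter_list) - 1]
        self.pack_stack[-1].append(param)
        return self.index_stack + [len(self.pack_stack[-1]) - 1]

    def pop_argpack_stack(self):
        # The real pop_argpack_stack also folds the collected members into a
        # single struct-typed parameter; here we only close the scope.
        members = self.pack_stack.pop()
        self.index_stack.pop()
        return members


c = ToyCallable()
print(c.add_parameter("x"))                      # [0]
print(c.insert_argpack_param_and_push("pack"))   # [1]  (reserved slot)
print(c.add_parameter("far"))                    # [1, 0]
print(c.add_parameter("near"))                   # [1, 1]
print(c.pop_argpack_stack())                     # ['far', 'near']
```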