Skip to content

Commit

Permalink
[refactor] Make arg_id a vector
Browse files Browse the repository at this point in the history
ghstack-source-id: 5d0d0a2d28f4b99cad3197d0fd7843f7f88ddda7
Pull Request resolved: #8204
  • Loading branch information
listerily committed Jun 20, 2023
1 parent c7bf8b0 commit 929cb89
Show file tree
Hide file tree
Showing 13 changed files with 357 additions and 272 deletions.
1 change: 1 addition & 0 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,7 @@ def transform_as_kernel():
for i, arg in enumerate(args.args):
if isinstance(ctx.func.arguments[i].annotation, ArgPackType):
d = {}
kernel_arguments.push_argpack_arg(ctx.func.arguments[i].name)
for j, (name, anno) in enumerate(ctx.func.arguments[i].annotation.members.items()):
d[name] = decl_and_create_variable(anno, name, ctx.arg_features[i][j])
ctx.create_variable(arg.arg, kernel_arguments.decl_argpack_arg(ctx.func.arguments[i].annotation, d))
Expand Down
19 changes: 12 additions & 7 deletions python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def decl_scalar_arg(dtype, name):
arg_id = impl.get_runtime().compiling_callable.insert_pointer_param(dtype, name)
else:
arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(dtype, name)
return Expr(_ti_core.make_arg_load_expr(arg_id, dtype, is_ref))
return Expr(_ti_core.make_arg_load_expr((arg_id,), dtype, is_ref))


def get_type_for_kernel_args(dtype, name):
Expand All @@ -83,18 +83,23 @@ def get_type_for_kernel_args(dtype, name):
def decl_matrix_arg(matrixtype, name):
arg_type = get_type_for_kernel_args(matrixtype, name)
arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(arg_type, name)
arg_load = Expr(_ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False))
arg_load = Expr(_ti_core.make_arg_load_expr((arg_id,), arg_type, create_load=False))
return matrixtype.from_taichi_object(arg_load)


def decl_struct_arg(structtype, name):
arg_type = get_type_for_kernel_args(structtype, name)
arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(arg_type, name)
arg_load = Expr(_ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False))
arg_load = Expr(_ti_core.make_arg_load_expr((arg_id,), arg_type, create_load=False))
return structtype.from_taichi_object(arg_load)


def push_argpack_arg(name):
impl.get_runtime().compiling_callable.insert_argpack_param_and_push(name)


def decl_argpack_arg(argpacktype, member_dict):
impl.get_runtime().compiling_callable.pop_argpack_stack()
return argpacktype.from_taichi_object(member_dict)


Expand All @@ -103,25 +108,25 @@ def decl_sparse_matrix(dtype, name):
ptr_type = cook_dtype(u64)
# Treat the sparse matrix argument as a scalar since we only need to pass in the base pointer
arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(ptr_type, name)
return SparseMatrixProxy(_ti_core.make_arg_load_expr(arg_id, ptr_type, False), value_type)
return SparseMatrixProxy(_ti_core.make_arg_load_expr((arg_id,), ptr_type, False), value_type)


def decl_ndarray_arg(element_type, ndim, name, needs_grad, boundary):
arg_id = impl.get_runtime().compiling_callable.insert_ndarray_param(element_type, ndim, name, needs_grad)
return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, arg_id, needs_grad, boundary))
return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, (arg_id,), needs_grad, boundary))


def decl_texture_arg(num_dimensions, name):
# FIXME: texture_arg doesn't have element_shape so better separate them
arg_id = impl.get_runtime().compiling_callable.insert_texture_param(num_dimensions, name)
return TextureSampler(_ti_core.make_texture_ptr_expr(arg_id, num_dimensions), num_dimensions)
return TextureSampler(_ti_core.make_texture_ptr_expr((arg_id,), num_dimensions), num_dimensions)


def decl_rw_texture_arg(num_dimensions, buffer_format, lod, name):
# FIXME: texture_arg doesn't have element_shape so better separate them
arg_id = impl.get_runtime().compiling_callable.insert_rw_texture_param(num_dimensions, buffer_format, name)
return RWTextureAccessor(
_ti_core.make_rw_texture_ptr_expr(arg_id, num_dimensions, buffer_format, lod),
_ti_core.make_rw_texture_ptr_expr((arg_id,), num_dimensions, buffer_format, lod),
num_dimensions,
)

Expand Down
346 changes: 166 additions & 180 deletions python/taichi/lang/kernel_impl.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions taichi/ir/expression_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {

void visit(ArgLoadExpression *expr) override {
emit(fmt::format("arg{}[{}] (dt={})", expr->create_load ? "load" : "addr",
expr->arg_id, data_type_name(expr->dt)));
fmt::join(expr->arg_id, ", "), data_type_name(expr->dt)));
}

void visit(TexturePtrExpression *expr) override {
emit(fmt::format("(Texture *)(arg[{}])", expr->arg_id));
emit(fmt::format("(Texture *)(arg[{}])", fmt::join(expr->arg_id, ", ")));
}

void visit(TextureOpExpression *expr) override {
Expand Down
8 changes: 4 additions & 4 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ void ArgLoadExpression::type_check(const CompileConfig *) {

void ArgLoadExpression::flatten(FlattenContext *ctx) {
auto arg_load =
std::make_unique<ArgLoadStmt>(arg_id, dt, is_ptr, create_load);
std::make_unique<ArgLoadStmt>(arg_id[0], dt, is_ptr, create_load);
arg_load->ret_type = ret_type;
ctx->push_back(std::move(arg_load));
stmt = ctx->back_stmt();
Expand All @@ -162,7 +162,7 @@ void TexturePtrExpression::type_check(const CompileConfig *config) {
}

void TexturePtrExpression::flatten(FlattenContext *ctx) {
ctx->push_back<ArgLoadStmt>(arg_id, PrimitiveType::f32, /*is_ptr=*/true,
ctx->push_back<ArgLoadStmt>(arg_id[0], PrimitiveType::f32, /*is_ptr=*/true,
/*create_load*/ true);
ctx->push_back<TexturePtrStmt>(ctx->back_stmt(), num_dims, is_storage, format,
lod);
Expand Down Expand Up @@ -609,7 +609,7 @@ void ExternalTensorExpression::flatten(FlattenContext *ctx) {
auto type =
TypeFactory::get_instance().get_ndarray_struct_type(dt, ndim, needs_grad);

auto ptr = Stmt::make<ArgLoadStmt>(arg_id, type, /*is_ptr=*/true,
auto ptr = Stmt::make<ArgLoadStmt>(arg_id[0], type, /*is_ptr=*/true,
/*create_load=*/false);

ptr->tb = tb;
Expand Down Expand Up @@ -1230,7 +1230,7 @@ void ExternalTensorShapeAlongAxisExpression::type_check(const CompileConfig *) {
void ExternalTensorShapeAlongAxisExpression::flatten(FlattenContext *ctx) {
auto temp = ptr.cast<ExternalTensorExpression>();
TI_ASSERT(0 <= axis && axis < temp->ndim);
ctx->push_back<ExternalTensorShapeAlongAxisStmt>(axis, temp->arg_id);
ctx->push_back<ExternalTensorShapeAlongAxisStmt>(axis, temp->arg_id[0]);
stmt = ctx->back_stmt();
}

Expand Down
19 changes: 11 additions & 8 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ class FrontendReturnStmt : public Stmt {

class ArgLoadExpression : public Expression {
public:
int arg_id;
const std::vector<int> arg_id;
DataType dt;
bool is_ptr;

Expand All @@ -324,7 +324,7 @@ class ArgLoadExpression : public Expression {
*/
bool create_load;

ArgLoadExpression(int arg_id,
ArgLoadExpression(const std::vector<int> &arg_id,
DataType dt,
bool is_ptr = false,
bool create_load = true)
Expand All @@ -346,23 +346,26 @@ class Texture;

class TexturePtrExpression : public Expression {
public:
int arg_id;
const std::vector<int> arg_id;
int num_dims;
bool is_storage{false};

// Optional, for storage textures
BufferFormat format{BufferFormat::unknown};
int lod{0};

explicit TexturePtrExpression(int arg_id, int num_dims)
explicit TexturePtrExpression(const std::vector<int> &arg_id, int num_dims)
: arg_id(arg_id),
num_dims(num_dims),
is_storage(false),
format(BufferFormat::rgba8),
lod(0) {
}

TexturePtrExpression(int arg_id, int num_dims, BufferFormat format, int lod)
TexturePtrExpression(const std::vector<int> &arg_id,
int num_dims,
BufferFormat format,
int lod)
: arg_id(arg_id),
num_dims(num_dims),
is_storage(true),
Expand Down Expand Up @@ -474,14 +477,14 @@ class ExternalTensorExpression : public Expression {
public:
DataType dt;
int ndim;
int arg_id;
std::vector<int> arg_id;
bool needs_grad{false};
bool is_grad{false};
BoundaryMode boundary{BoundaryMode::kUnsafe};

ExternalTensorExpression(const DataType &dt,
int ndim,
int arg_id,
const std::vector<int> &arg_id,
bool needs_grad = false,
BoundaryMode boundary = BoundaryMode::kUnsafe) {
init(dt, ndim, arg_id, needs_grad, boundary);
Expand Down Expand Up @@ -512,7 +515,7 @@ class ExternalTensorExpression : public Expression {

void init(const DataType &dt,
int ndim,
int arg_id,
const std::vector<int> &arg_id,
bool needs_grad,
BoundaryMode boundary) {
this->dt = dt;
Expand Down
141 changes: 99 additions & 42 deletions taichi/program/callable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,33 @@ Callable::Callable() = default;

Callable::~Callable() = default;

int Callable::insert_scalar_param(const DataType &dt, const std::string &name) {
parameter_list.emplace_back(dt->get_compute_type(), /*is_array=*/false);
parameter_list.back().name = name;
parameter_list.back().ptype = ParameterType::kScalar;
return (int)parameter_list.size() - 1;
std::vector<int> Callable::insert_scalar_param(const DataType &dt,
const std::string &name) {
auto p = Parameter(dt->get_compute_type(), /*is_array=*/false);
p.name = name;
p.ptype = ParameterType::kScalar;
return add_parameter(p);
}

int Callable::insert_ret(const DataType &dt) {
rets.emplace_back(dt->get_compute_type());
return (int)rets.size() - 1;
}

int Callable::insert_arr_param(const DataType &dt,
int total_dim,
std::vector<int> element_shape,
const std::string &name) {
parameter_list.emplace_back(dt->get_compute_type(), /*is_array=*/true,
/*size=*/0, total_dim, element_shape);
parameter_list.back().name = name;
return (int)parameter_list.size() - 1;
std::vector<int> Callable::insert_arr_param(const DataType &dt,
int total_dim,
std::vector<int> element_shape,
const std::string &name) {
auto p = Parameter(dt->get_compute_type(), /*is_array=*/true, 0, total_dim,
element_shape);
p.name = name;
return add_parameter(p);
}

int Callable::insert_ndarray_param(const DataType &dt,
int ndim,
const std::string &name,
bool needs_grad) {
std::vector<int> Callable::insert_ndarray_param(const DataType &dt,
int ndim,
const std::string &name,
bool needs_grad) {
// Transform ndarray param to a struct type with a pointer to `dt`.
std::vector<int> element_shape{};
auto dtype = dt;
Expand All @@ -45,43 +46,96 @@ int Callable::insert_ndarray_param(const DataType &dt,
// If we could avoid using parameter_list in codegen it'll be fine
auto *type = TypeFactory::get_instance().get_ndarray_struct_type(dtype, ndim,
needs_grad);
parameter_list.emplace_back(type, /*is_array=*/true,
/*size=*/0, ndim + element_shape.size(),
element_shape, BufferFormat::unknown, needs_grad);
parameter_list.back().name = name;
parameter_list.back().ptype = ParameterType::kNdarray;
return (int)parameter_list.size() - 1;
auto p = Parameter(type, /*is_array=*/true, 0, ndim + element_shape.size(),
element_shape, BufferFormat::unknown, needs_grad);
p.name = name;
p.ptype = ParameterType::kNdarray;
return add_parameter(p);
}

int Callable::insert_texture_param(int total_dim, const std::string &name) {
std::vector<int> Callable::insert_texture_param(int total_dim,
const std::string &name) {
// FIXME: we shouldn't abuse is_array for texture parameters
// FIXME: using rwtexture struct type for texture parameters because C-API
// does not distinguish between texture and rwtexture.
auto *type = TypeFactory::get_instance().get_rwtexture_struct_type();
parameter_list.emplace_back(type, /*is_array=*/true, 0, total_dim,
std::vector<int>{});
parameter_list.back().name = name;
parameter_list.back().ptype = ParameterType::kTexture;
return (int)parameter_list.size() - 1;
auto p = Parameter(type, /*is_array=*/true, 0, total_dim, std::vector<int>{});
p.name = name;
p.ptype = ParameterType::kTexture;
return add_parameter(p);
}

int Callable::insert_pointer_param(const DataType &dt,
const std::string &name) {
parameter_list.emplace_back(dt->get_compute_type(), /*is_array=*/true);
parameter_list.back().name = name;
return (int)parameter_list.size() - 1;
std::vector<int> Callable::insert_pointer_param(const DataType &dt,
const std::string &name) {
auto p = Parameter(dt->get_compute_type(), /*is_array=*/true);
p.name = name;
return add_parameter(p);
}

int Callable::insert_rw_texture_param(int total_dim,
BufferFormat format,
const std::string &name) {
std::vector<int> Callable::insert_rw_texture_param(int total_dim,
BufferFormat format,
const std::string &name) {
// FIXME: we shouldn't abuse is_array for texture parameters
auto *type = TypeFactory::get_instance().get_rwtexture_struct_type();
parameter_list.emplace_back(type, /*is_array=*/true, 0, total_dim,
std::vector<int>{}, format);
parameter_list.back().name = name;
parameter_list.back().ptype = ParameterType::kRWTexture;
return (int)parameter_list.size() - 1;
auto p = Parameter(type, /*is_array=*/true, 0, total_dim, std::vector<int>{},
format);
p.name = name;
p.ptype = ParameterType::kRWTexture;
return add_parameter(p);
}

std::vector<int> Callable::insert_argpack_param_and_push(
const std::string &name) {
TI_ASSERT(temp_argpack_stack_.size() == temp_indices_stack_.size() &&
temp_argpack_name_stack_.size() == temp_indices_stack_.size());
if (temp_argpack_stack_.size() > 0) {
temp_indices_stack_.push_back(temp_argpack_stack_.top().size());
} else {
temp_indices_stack_.push_back(parameter_list.size());
}
temp_argpack_stack_.push(std::vector<Parameter>());
temp_argpack_name_stack_.push(name);
return temp_indices_stack_;
}

void Callable::pop_argpack_stack() {
// Compile argpack members to a struct.
TI_ASSERT(temp_argpack_stack_.size() > 0 && temp_indices_stack_.size() > 0 &&
temp_argpack_name_stack_.size() > 0);
std::vector<Parameter> argpack_params = temp_argpack_stack_.top();
std::vector<StructMember> members;
members.reserve(argpack_params.size());
for (int i = 0; i < argpack_params.size(); i++) {
auto &param = argpack_params[i];
members.push_back(
{param.is_array && !param.get_dtype()->is<StructType>()
? TypeFactory::get_instance().get_pointer_type(param.get_dtype())
: (const Type *)param.get_dtype(),
fmt::format("arg_{}_{}", fmt::join(temp_indices_stack_, "_"), i)});
}
auto *type =
TypeFactory::get_instance().get_struct_type(members)->as<StructType>();
auto p = Parameter(DataType(type), false);
p.name = temp_argpack_name_stack_.top();
add_parameter(p);
// Pop stacks
temp_argpack_stack_.pop();
temp_indices_stack_.pop_back();
temp_argpack_name_stack_.pop();
}

std::vector<int> Callable::add_parameter(const Parameter &param) {
TI_ASSERT(temp_argpack_stack_.size() == temp_indices_stack_.size() &&
temp_argpack_name_stack_.size() == temp_indices_stack_.size());
if (temp_argpack_stack_.size() == 0) {
parameter_list.push_back(param);
return std::vector<int>{(int)parameter_list.size() - 1};
}
TI_ASSERT(temp_argpack_stack_.size() > 0);
temp_argpack_stack_.top().push_back(param);
std::vector<int> ret = temp_indices_stack_;
ret.push_back(temp_argpack_stack_.top().size() - 1);
return ret;
}

void Callable::finalize_rets() {
Expand All @@ -98,6 +152,9 @@ void Callable::finalize_rets() {
}

void Callable::finalize_params() {
TI_ASSERT(temp_argpack_stack_.size() == 0 &&
temp_indices_stack_.size() == 0 &&
temp_argpack_name_stack_.size() == 0);
std::vector<StructMember> members;
members.reserve(parameter_list.size());
for (int i = 0; i < parameter_list.size(); i++) {
Expand Down
Loading

0 comments on commit 929cb89

Please sign in to comment.