Skip to content

Commit

Permalink
[lang] Instantiate a runtime ArgPack object when a python ArgPack is …
Browse files Browse the repository at this point in the history
…created

ghstack-source-id: d1dc8a846ff1202b84fc47c270d0e88e7dabeca3
Pull Request resolved: #8241
  • Loading branch information
listerily committed Jun 28, 2023
1 parent 3e71d16 commit 3b943e7
Show file tree
Hide file tree
Showing 13 changed files with 275 additions and 11 deletions.
25 changes: 15 additions & 10 deletions python/taichi/lang/argpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class ArgPack:
Args:
annotations (Dict[str, Union[Dict, Matrix, Struct]]): \
The keys and types for `ArgPack` members.
dtype (ArgPackType): \
The ArgPackType class of this ArgPack object.
entries (Dict[str, Union[Dict, Matrix, Struct]]): \
The keys and corresponding values for `ArgPack` members.
Expand All @@ -40,7 +42,7 @@ class ArgPack:

_instance_count = 0

def __init__(self, annotations, *args, **kwargs):
def __init__(self, annotations, dtype, *args, **kwargs):
# converts dicts to argument packs
if len(args) == 1 and kwargs == {} and isinstance(args[0], dict):
self.__entries = args[0]
Expand All @@ -56,7 +58,12 @@ def __init__(self, annotations, *args, **kwargs):
for k, v in self.__entries.items():
self.__entries[k] = v if in_python_scope() else impl.expr_init(v)
self._register_members()
self.__dtype = None
self.__dtype = dtype
self.__argpack = impl.get_runtime().prog.create_argpack(self.__dtype)

def __del__(self):
    """Release the runtime ArgPack that ``__init__`` created for this object.

    ``__init__`` can raise before the runtime argpack handle is assigned; in
    that case the attribute does not exist and there is nothing to release.
    The ``impl`` / runtime / program checks guard against those objects
    already being gone when this finalizer runs.
    """
    # ``self.__argpack`` is name-mangled to ``_ArgPack__argpack``; look it up
    # defensively so a half-constructed object does not raise here.
    argpack = getattr(self, "_ArgPack__argpack", None)
    if argpack is None or impl is None:
        return
    runtime = impl.get_runtime()
    if runtime is not None and runtime.prog is not None:
        runtime.prog.delete_argpack(argpack)

@property
def keys(self):
Expand Down Expand Up @@ -181,7 +188,7 @@ class _IntermediateArgPack(ArgPack):
entries (Dict[str, Union[Expr, Matrix, Struct]]): keys and values for struct members.
"""

def __init__(self, annotations, *args, **kwargs):
def __init__(self, annotations, dtype, *args, **kwargs):
# converts dicts to argument packs
if len(args) == 1 and kwargs == {} and isinstance(args[0], dict):
self._ArgPack__entries = args[0]
Expand All @@ -195,7 +202,8 @@ def __init__(self, annotations, *args, **kwargs):
raise TaichiSyntaxError("ArgPack annotations keys not equals to entries keys.")
self._ArgPack__annotations = annotations
self._register_members()
self._ArgPack__dtype = None
self._ArgPack__dtype = dtype
self._ArgPack__argpack = impl.get_runtime().prog.create_argpack(dtype)


class ArgPackType(CompoundType):
Expand Down Expand Up @@ -263,10 +271,8 @@ def __call__(self, *args, **kwargs):

d[name] = data

entries = ArgPack(self.members, d)
entries._ArgPack__dtype = self.dtype
entries = ArgPack(self.members, self.dtype, d)
pack = self.cast(entries)
pack._ArgPack__dtype = self.dtype
return pack

def __instancecheck__(self, instance):
Expand Down Expand Up @@ -308,8 +314,7 @@ def cast(self, pack):
entries[k] = int(v) if dtype in primitive_types.integer_types else float(v)
else:
entries[k] = ops.cast(pack._ArgPack__entries[k], dtype)
pack = ArgPack(self.members, entries)
pack._ArgPack__dtype = self.dtype
pack = ArgPack(self.members, self.dtype, entries)
return pack

def from_taichi_object(self, arg_load_dict: dict):
Expand All @@ -318,7 +323,7 @@ def from_taichi_object(self, arg_load_dict: dict):
for index, pair in enumerate(items):
name, dtype = pair
d[name] = arg_load_dict[name]
pack = _IntermediateArgPack(self.members, d)
pack = _IntermediateArgPack(self.members, self.dtype, d)
pack._ArgPack__dtype = self.dtype
return pack

Expand Down
49 changes: 49 additions & 0 deletions taichi/program/argpack.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#include <numeric>

#include "taichi/program/argpack.h"
#include "taichi/program/program.h"

#ifdef TI_WITH_LLVM
#include "taichi/runtime/llvm/llvm_context.h"
#include "taichi/runtime/program_impls/llvm/llvm_program.h"
#endif

namespace taichi::lang {

// Builds an ArgPack bound to |prog|.
//
// The ArgPackType carried by |type| is handed to the program together with its
// own layout string; the program answers with the type to use and the size of
// the allocation to make. That type is stored in |dtype| and the allocation
// handle in |argpack_alloc_|.
ArgPack::ArgPack(Program *prog, const DataType type) : prog_(prog) {
  auto *requested_type = type->get_type()->as<ArgPackType>();
  auto [resolved_type, allocation_size] =
      prog->get_argpack_type_with_data_layout(requested_type,
                                              requested_type->get_layout());
  dtype = DataType(resolved_type);
  argpack_alloc_ =
      prog->allocate_memory_on_device(allocation_size, prog->result_buffer);
}

// Releases the device allocation held by this pack.
//
// Deallocation is skipped when the pack is not bound to a program, and also
// when the allocation handle carries no device: calling through a null
// |device| pointer would otherwise crash.
ArgPack::~ArgPack() {
  if (prog_ != nullptr && argpack_alloc_.device != nullptr) {
    argpack_alloc_.device->dealloc_memory(argpack_alloc_);
  }
}

intptr_t ArgPack::get_device_allocation_ptr_as_int() const {
// taichi's own argpack's ptr points to its |DeviceAllocation| on the
// specified device.
return reinterpret_cast<intptr_t>(&argpack_alloc_);
}

// Returns a copy of the allocation handle held by this pack.
DeviceAllocation ArgPack::get_device_allocation() const {
  return this->argpack_alloc_;
}

// Returns how many elements the pack's ArgPackType declares.
std::size_t ArgPack::get_nelement() const {
  const auto *pack_type = dtype->as<ArgPackType>();
  return pack_type->elements().size();
}

// Returns the data type stored for this pack (set by the constructor).
DataType ArgPack::get_data_type() const {
  return this->dtype;
}

} // namespace taichi::lang
35 changes: 35 additions & 0 deletions taichi/program/argpack.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once

#include <cstdint>
#include <vector>

#include "taichi/inc/constants.h"
#include "taichi/ir/type_utils.h"
#include "taichi/rhi/device.h"

namespace taichi::lang {

class Program;

class TI_DLL_EXPORT ArgPack {
public:
/* Constructs a ArgPack managed by Program.
* Memory allocation and deallocation is handled by Program.
*/
explicit ArgPack(Program *prog,
const DataType type);

DeviceAllocation argpack_alloc_{kDeviceNullAllocation};
DataType dtype;

DataType get_data_type() const;
intptr_t get_device_allocation_ptr_as_int() const;
DeviceAllocation get_device_allocation() const;
std::size_t get_nelement() const;

~ArgPack();

private:
Program *prog_{nullptr};
};
}
25 changes: 25 additions & 0 deletions taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,13 @@ Ndarray *Program::create_ndarray(const DataType type,
return arr_ptr;
}

// Creates an ArgPack of |type| owned by this Program.
//
// Ownership stays in |argpacks_| (keyed by the object's address); the caller
// receives a non-owning pointer that can later be passed to delete_argpack().
ArgPack *Program::create_argpack(const DataType type) {
  auto owned_pack = std::make_unique<ArgPack>(this, type);
  ArgPack *const handle = owned_pack.get();
  argpacks_.emplace(handle, std::move(owned_pack));
  return handle;
}

void Program::delete_ndarray(Ndarray *ndarray) {
// [Note] Ndarray memory deallocation
// Ndarray's memory allocation is managed by Taichi and Python can control
Expand All @@ -416,6 +423,24 @@ void Program::delete_ndarray(Ndarray *ndarray) {
}
}

void Program::delete_argpack(ArgPack *argpack) {
  // [Note] Argpack memory deallocation
  // Argpack memory is owned by Taichi; Python only controls it indirectly.
  // When an argpack is garbage-collected in Python, Python asks Taichi to free
  // it, but Taichi only does so once **no pending kernel still needs the
  // argpack**. After `ti.reset()`, every argpack allocated by this program is
  // gone and must no longer be used from Python.
  // This is not the ideal design: argpacks should be managed by the Taichi
  // runtime rather than by this giant Program class, and be freed when:
  // - Python GC has signalled that the argpack is no longer useful, and
  // - all kernels using it have been executed.
  const auto entry = argpacks_.find(argpack);
  if (entry == argpacks_.end()) {
    // Not an argpack owned by this program; nothing to do.
    return;
  }
  if (program_impl_->used_in_kernel(argpack->argpack_alloc_.alloc_id)) {
    // Still referenced by a kernel; keep it alive.
    return;
  }
  // Dropping the map entry destroys the owning unique_ptr and thus the pack.
  argpacks_.erase(entry);
}

Texture *Program::create_texture(BufferFormat buffer_format,
const std::vector<int> &shape) {
if (shape.size() == 1) {
Expand Down
14 changes: 13 additions & 1 deletion taichi/program/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "taichi/ir/type_factory.h"
#include "taichi/ir/snode.h"
#include "taichi/util/lang_util.h"
#include "taichi/program/argpack.h"
#include "taichi/program/program_impl.h"
#include "taichi/program/callable.h"
#include "taichi/program/function.h"
Expand Down Expand Up @@ -255,6 +256,8 @@ class TI_DLL_EXPORT Program {
ExternalArrayLayout layout = ExternalArrayLayout::kNull,
bool zero_fill = false);

// Creates an ArgPack of the given type that is owned by this Program and
// returns a non-owning pointer to it.
ArgPack *create_argpack(const DataType type);

std::string get_kernel_return_data_layout() {
return program_impl_->get_kernel_return_data_layout();
};
Expand All @@ -269,8 +272,16 @@ class TI_DLL_EXPORT Program {
return program_impl_->get_struct_type_with_data_layout(old_ty, layout);
}

// Asks the backend implementation for the ArgPackType to use for |old_ty|
// under |layout|, paired with a size value. The result is forwarded unchanged.
std::pair<const ArgPackType *, size_t> get_argpack_type_with_data_layout(
    const ArgPackType *old_ty,
    const std::string &layout) {
  auto type_and_size =
      program_impl_->get_argpack_type_with_data_layout(old_ty, layout);
  return type_and_size;
}

void delete_ndarray(Ndarray *ndarray);

// Destroys an ArgPack owned by this Program. The call is a no-op when the
// pointer is not registered here or when the pack's allocation is still
// reported as used in a kernel.
void delete_argpack(ArgPack *argpack);

Texture *create_texture(BufferFormat buffer_format,
const std::vector<int> &shape);

Expand Down Expand Up @@ -335,8 +346,9 @@ class TI_DLL_EXPORT Program {
static std::atomic<int> num_instances_;
bool finalized_{false};

// TODO: Move ndarrays_ and textures_ to be managed by runtime
// TODO: Move ndarrays_, argpacks_ and textures_ to be managed by runtime
std::unordered_map<void *, std::unique_ptr<Ndarray>> ndarrays_;
std::unordered_map<void *, std::unique_ptr<ArgPack>> argpacks_;
std::vector<std::unique_ptr<Texture>> textures_;
};

Expand Down
6 changes: 6 additions & 0 deletions taichi/program/program_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,12 @@ class ProgramImpl {
return {old_ty, 0};
}

// Default implementation: performs no layout computation. It hands back
// |old_ty| untouched together with a size of 0; |layout| is ignored.
virtual std::pair<const ArgPackType *, size_t>
get_argpack_type_with_data_layout(const ArgPackType *old_ty,
                                  const std::string &layout) {
  return std::pair<const ArgPackType *, size_t>(old_ty, 0);
}

KernelCompilationManager &get_kernel_compilation_manager();

KernelLauncher &get_kernel_launcher();
Expand Down
6 changes: 6 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,12 @@ void export_lang(py::module &m) {
py::arg("layout") = ExternalArrayLayout::kNull,
py::arg("zero_fill") = false, py::return_value_policy::reference)
.def("delete_ndarray", &Program::delete_ndarray)
.def("create_argpack",
[](Program *program, const DataType &dt) -> ArgPack * {
return program->create_argpack(dt);
},
py::arg("dt"), py::return_value_policy::reference)
.def("delete_argpack", &Program::delete_argpack)
.def(
"create_texture",
[&](Program *program, BufferFormat fmt, const std::vector<int> &shape)
Expand Down
74 changes: 74 additions & 0 deletions taichi/runtime/gfx/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -856,5 +856,79 @@ GfxRuntime::get_struct_type_with_data_layout_impl(
bytes, align};
}

// Public wrapper around get_argpack_type_with_data_layout_impl(): returns the
// laid-out type and its total size, dropping the alignment component.
std::pair<const lang::ArgPackType *, size_t>
GfxRuntime::get_argpack_type_with_data_layout(const lang::ArgPackType *old_ty,
                                              const std::string &layout) {
  auto type_size_align = get_argpack_type_with_data_layout_impl(old_ty, layout);
  return {std::get<0>(type_size_align), std::get<1>(type_size_align)};
}

// Computes member offsets, total size and alignment of |old_ty| for the
// two-character layout string |layout|.
//
// layout[0] == '4' selects the "430" rules (is_430); any other value selects
// the alternative rules, under which tensors and the final size/alignment are
// padded to multiples of 4 floats. layout[1] == 'b' indicates that pointer
// members occupy 64 bits; otherwise a 32-bit placeholder is used.
//
// Returns {new ArgPackType with updated member types/offsets, total bytes,
// alignment}.
std::tuple<const lang::ArgPackType *, size_t, size_t>
GfxRuntime::get_argpack_type_with_data_layout_impl(
    const lang::ArgPackType *old_ty,
    const std::string &layout) {
  TI_TRACE("get_argpack_type_with_data_layout: {}", layout);
  TI_ASSERT(layout.size() == 2);
  const bool is_430 = layout[0] == '4';
  const bool has_buffer_ptr = layout[1] == 'b';
  // Work on a copy: member types and offsets are rewritten below.
  auto members = old_ty->elements();
  size_t bytes = 0;
  size_t align = 0;
  for (size_t i = 0; i < members.size(); ++i) {
    auto &member = members[i];
    size_t member_align = 0;
    size_t member_size = 0;
    if (auto struct_type = member.type->cast<lang::StructType>()) {
      // Nested structs are laid out recursively with the same layout string.
      auto [new_ty, size, member_align_] =
          get_struct_type_with_data_layout_impl(struct_type, layout);
      member.type = new_ty;
      member_align = member_align_;
      member_size = size;
    } else if (auto tensor_type = member.type->cast<lang::TensorType>()) {
      const size_t element_size =
          data_type_size_gfx(tensor_type->get_element_type());
      const size_t num_elements = tensor_type->get_num_elements();
      if (!is_430) {
        // Two-element tensors align to 2 elements, everything else to 4.
        if (num_elements == 2) {
          member_align = element_size * 2;
        } else {
          member_align = element_size * 4;
        }
        // NOTE(review): the size equals the alignment here, i.e. at most
        // 4 * element_size bytes regardless of num_elements -- confirm that
        // tensors with more than 4 elements never reach this path.
        member_size = member_align;
      } else {
        member_align = element_size;
        member_size = num_elements * element_size;
      }
    } else if (member.type->cast<PointerType>()) {
      if (has_buffer_ptr) {
        member_size = sizeof(uint64_t);
        member_align = member_size;
      } else {
        // Use u32 as placeholder
        member_size = sizeof(uint32_t);
        member_align = member_size;
      }
    } else {
      TI_ASSERT(member.type->is<PrimitiveType>());
      member_size = data_type_size_gfx(member.type);
      member_align = member_size;
    }
    // Place the member at the next offset satisfying its alignment.
    bytes = align_up(bytes, member_align);
    member.offset = bytes;
    bytes += member_size;
    align = std::max(align, member_align);
  }

  if (!is_430) {
    // Outside the 430 rules the whole pack is padded to a multiple of 4 floats.
    align = align_up(align, sizeof(float) * 4);
    bytes = align_up(bytes, 4 * sizeof(float));
  }
  TI_TRACE(" total_bytes={}", bytes);
  return {TypeFactory::get_instance()
              .get_argpack_type(members, layout)
              ->as<lang::ArgPackType>(),
          bytes, align};
}

} // namespace gfx
} // namespace taichi::lang
8 changes: 8 additions & 0 deletions taichi/runtime/gfx/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ class TI_DLL_EXPORT GfxRuntime {
get_struct_type_with_data_layout_impl(const lang::StructType *old_ty,
const std::string &layout);

// Returns the ArgPackType laid out according to |layout| together with its
// total size (the alignment computed by the _impl variant is dropped).
static std::pair<const lang::ArgPackType *, size_t>
get_argpack_type_with_data_layout(const lang::ArgPackType *old_ty,
const std::string &layout);

// Same computation as above, but returns {laid-out type, total size,
// alignment} as a tuple.
static std::tuple<const lang::ArgPackType *, size_t, size_t>
get_argpack_type_with_data_layout_impl(const lang::ArgPackType *old_ty,
const std::string &layout);

private:
friend class taichi::lang::gfx::SNodeTreeManager;

Expand Down
Loading

0 comments on commit 3b943e7

Please sign in to comment.