Skip to content

Commit

Permalink
[Bug] Fix Erroneous handling of ndarray in real function in CFG (#8245)
Browse files Browse the repository at this point in the history
Issue: #

`arg_id` doesn't make sense when two statements come from different
callables. Therefore we cannot decide whether a `ExternalPtrStmt` in the
real function points to the same ndarray as another `ExternalPtrStmt` in
the kernel by checking the `arg_id` of them.

Instead, we can assume that all ndarrays passed into a real function can
be modified. When a ndarray is passed into a real function, a
`ExternalTensorBasePtrStmt` is inserted into the argument of the real
function, which will appear in the store destinations of the
`FuncCallStmt`. Then, we can add support for it in the alias analysis.
We also don't include the `ExternalPtrStmt`s in the real functions in
the store destination sets as they don't make any sense.

### Brief Summary

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at ba828dc</samp>

Add support for real functions writing to external arrays or tensors.
Update alias analysis and store destination collection to handle
external pointers. Add a test case for the new feature.

### Walkthrough

<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at ba828dc</samp>

* Add a special case for alias analysis of external tensor base pointers
([link](https://github.com/taichi-dev/taichi/pull/8245/files?diff=unified&w=0#diff-de599d158682f1f0209f39aa58631d6df8d8ded1eacb00d7d6c5200ef7391793R17-R38))
* Skip external pointers and matrix pointers from external pointers when
gathering store destinations of a function
([link](https://github.com/taichi-dev/taichi/pull/8245/files?diff=unified&w=0#diff-0bfbe49ff08844a76d5d2e1c5b81c2cf813be4a9089422b997bc380ec9a68eadL65-R75))
* Add a test case for real functions writing to external arrays or
tensors in `test_ndarray.py`
([link](https://github.com/taichi-dev/taichi/pull/8245/files?diff=unified&w=0#diff-ca3c8d1edb25b6a7f4affbb79b2e3e74f73b3757e5d465258ce42ea9eb09fbc0R1132-R1150))
  • Loading branch information
lin-hitonami committed Jun 30, 2023
1 parent b53733b commit 9c5fb98
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 1 deletion.
24 changes: 24 additions & 0 deletions taichi/analysis/alias_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,30 @@ AliasResult alias_analysis(Stmt *var1, Stmt *var2) {
if (!var1 || !var2)
return AliasResult::different;

if (var1->is<ExternalTensorBasePtrStmt>() ||
var2->is<ExternalTensorBasePtrStmt>()) {
auto *base = var1->cast<ExternalTensorBasePtrStmt>();
Stmt *other = var2;
if (!base) {
base = var2->cast<ExternalTensorBasePtrStmt>();
other = var1;
}
auto *external_ptr = other->cast<ExternalPtrStmt>();
if (!external_ptr) {
if (auto *matrix_ptr = other->cast<MatrixPtrStmt>()) {
external_ptr = matrix_ptr->origin->cast<ExternalPtrStmt>();
}
if (!external_ptr)
return AliasResult::different;
}
if (base->is_grad != external_ptr->is_grad)
return AliasResult::different;
if (base->arg_id == external_ptr->base_ptr->as<ArgLoadStmt>()->arg_id) {
return AliasResult::uncertain;
}
return AliasResult::different;
}

// TODO: further optimize with offset inside MatrixPtrStmt
// If at least one of var1 and var2 is local, they will be treated here.
auto retrieve_local = [&](Stmt *var) {
Expand Down
12 changes: 11 additions & 1 deletion taichi/analysis/gather_func_store_dests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,17 @@ class GatherFuncStoreDests : public BasicStmtVisitor {
return;
}
auto result = irpass::analysis::get_store_destination(stmt);
results_.insert(result.begin(), result.end());
for (const auto &dest : result) {
if (dest->is<ExternalPtrStmt>()) {
continue;
}
if (auto matrix_ptr = dest->cast<MatrixPtrStmt>()) {
if (matrix_ptr->origin->is<ExternalPtrStmt>()) {
continue;
}
}
results_.insert(dest);
}
}

void visit(FuncCallStmt *stmt) override {
Expand Down
19 changes: 19 additions & 0 deletions tests/python/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1129,3 +1129,22 @@ def test(x: ti.types.ndarray(ndim=1)) -> vec3:
x = ti.Vector.ndarray(3, ti.f32, shape=(1))
x[0] = vec3(1, 2, 3)
assert (test(x) == vec3(1, 2, 3)).all()


@test_utils.test(arch=[ti.cpu, ti.cuda])
def test_real_func_write_ndarray_cfg():
@ti.experimental.real_func
def bar(a: ti.types.ndarray(ndim=1)):
a[0] = vec3(1)

@ti.kernel
def foo(
a: ti.types.ndarray(ndim=1),
):
a[0] = vec3(3)
bar(a)
a[0] = vec3(3)

a = ti.Vector.ndarray(3, float, shape=(2,))
foo(a)
assert (a[0] == vec3(3)).all()

0 comments on commit 9c5fb98

Please sign in to comment.