[Prim][PIR] add index_sample op forward prim (#61825)
* add index_sample decomp

* index_sample support dynamic shape

* update code

* update code
kevincheng2 committed Feb 22, 2024
1 parent 957b1dd commit 60902c7
Showing 8 changed files with 142 additions and 2 deletions.
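For context on the op being decomposed: paddle.index_sample(x, index) gathers, for every row i of x, the columns listed in row i of index, i.e. out[i, j] = x[i, index[i, j]]. A minimal NumPy reference sketch of that behavior (illustrative only, not code from this commit):

import numpy as np

def index_sample_ref(x, index):
    # out[i, j] = x[i, index[i, j]]: each row of `index` picks columns from
    # the matching row of `x`.
    rows = np.arange(x.shape[0]).reshape(-1, 1)
    return x[rows, index]

x = np.arange(12, dtype=np.float32).reshape(3, 4)
index = np.array([[0, 3], [1, 1], [2, 0]], dtype=np.int64)
print(index_sample_ref(x, index))  # [[0. 3.] [5. 5.] [10. 8.]]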
@@ -30,6 +30,7 @@
"gelu",
"hardswish",
"group_norm",
"index_sample",
"index_select",
"instance_norm",
"layer_norm",
@@ -61,6 +62,7 @@
"gelu",
"hardswish",
"group_norm",
"index_sample",
"index_select",
"instance_norm",
"layer_norm",
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/op_build_gen.py
@@ -27,6 +27,7 @@
'InterpolateInferMeta',
'DeformableConvInferMeta',
'MatrixNMSInferMeta',
'IndexSampleInferMeta',
}

_PREPARE_DATA_WITH_VECTOR_INT64_MTTABLE_ATTRIBUTE = {'FrobeniusNormOp'}
10 changes: 10 additions & 0 deletions paddle/fluid/primitive/backend/manual/manual_eager_prim_backend.cc
@@ -35,6 +35,16 @@ Tensor full<Tensor>(const IntArray& shape,
}
}

template <>
Tensor arange_with_tensor<Tensor>(const Tensor& start,
                                  const Tensor& end,
                                  const Tensor& step,
                                  DataType dtype,
                                  Place place) {
  VLOG(4) << "Eager Prim API arange_ad_func call";
  return ::arange_ad_func(start, end, step, dtype, place);
}

} // namespace backend
} // namespace primitive
} // namespace paddle
7 changes: 7 additions & 0 deletions paddle/fluid/primitive/backend/manual/manual_prim_backend.h
@@ -41,6 +41,13 @@ Tensor reshape_with_tensor(const Tensor& x, const Tensor& shape);
template <typename T>
Tensor expand_with_tensor(const Tensor& x, const Tensor& shape);

template <typename T>
Tensor arange_with_tensor(const Tensor& start,
                          const Tensor& end,
                          const Tensor& step,
                          DataType dtype = DataType::FLOAT64,
                          Place place = CPUPlace());

} // namespace backend
} // namespace primitive
} // namespace paddle
@@ -68,6 +68,24 @@ Tensor expand_with_tensor<LazyTensor>(const Tensor& x, const Tensor& shape) {
  return out;
}

template <>
Tensor arange_with_tensor<LazyTensor>(const Tensor& start,
                                      const Tensor& end,
                                      const Tensor& step,
                                      DataType dtype,
                                      Place place) {
  pir::Value start_val =
      std::static_pointer_cast<LazyTensor>(start.impl())->value();
  pir::Value end_val =
      std::static_pointer_cast<LazyTensor>(end.impl())->value();
  pir::Value step_val =
      std::static_pointer_cast<LazyTensor>(step.impl())->value();
  auto op_res =
      paddle::dialect::arange(start_val, end_val, step_val, dtype, place);
  Tensor out(std::make_shared<LazyTensor>(op_res));
  return out;
}

} // namespace backend
} // namespace primitive
} // namespace paddle
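The LazyTensor specialization above lowers arange_with_tensor to paddle::dialect::arange with tensor-valued start/end/step, which is what lets the decomposition build a row-index sequence whose length is only known at run time. A rough dynamic-graph analogue in Python (a sketch assuming paddle.arange accepts tensor bounds; not code from this commit):

import paddle

index = paddle.randint(0, 2048, shape=[300, 2048], dtype='int64')
# Under dynamic shapes the leading dimension is not a compile-time constant,
# so the sequence bound comes from a tensor produced by paddle.shape.
num_rows = paddle.shape(index)[0]
row_ids = paddle.arange(0, num_rows, 1, dtype='int64')
print(row_ids.shape)  # [300]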
25 changes: 25 additions & 0 deletions paddle/fluid/primitive/composite/composite.h
@@ -1001,6 +1001,31 @@ Tensor embedding_decomp(const Tensor& x,
  return res;
}

template <typename T>
Tensor index_sample_decomp(const Tensor& x, const Tensor& index) {
  std::vector<int64_t> tmp_shape{-1, 1};
  auto index_dim = get_slice<T>(shape<T>(index), 0);
  auto start =
      backend::full_with_tensor<T>(shape<T>(index_dim), 0, index_dim.dtype());
  auto step =
      backend::full_with_tensor<T>(shape<T>(index_dim), 1, index_dim.dtype());
  auto arange_tmp = reshape<T>(
      backend::arange_with_tensor<T>(start, index_dim, step, index.dtype()),
      tmp_shape);

  auto index_res = reshape<T>(
      backend::expand_with_tensor<T>(arange_tmp, shape<T>(index)), tmp_shape);
  auto index_ = reshape<T>(index, tmp_shape);
  auto concat_res = concat<T>({index_res, index_}, 1);
  auto res = backend::reshape<T>(gather_nd<T>(x, concat_res), shape<T>(index));

  if (res.dtype() != x.dtype()) {
    return cast<T>(res, x.dtype());
  } else {
    return res;
  }
}

} // namespace details

} // namespace primitive
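index_sample_decomp rewrites the op in terms of primitives: build row ids with arange, broadcast them to the shape of index, pair each row id with the sampled column, and feed the (row, col) pairs to gather_nd. A NumPy mirror of those steps (an illustrative sketch, not code from this commit):

import numpy as np

def index_sample_decomp_ref(x, index):
    n, k = index.shape
    # arange + expand: one row id per entry of `index`, flattened to (-1, 1).
    row_ids = np.broadcast_to(np.arange(n).reshape(-1, 1), (n, k)).reshape(-1, 1)
    cols = index.reshape(-1, 1)
    # concat along axis 1 gives the (row, col) coordinates passed to gather_nd.
    coords = np.concatenate([row_ids, cols], axis=1)
    gathered = x[coords[:, 0], coords[:, 1]]
    return gathered.reshape(n, k)

x = np.arange(12, dtype=np.float32).reshape(3, 4)
index = np.array([[0, 3], [1, 1], [2, 0]], dtype=np.int64)
print(index_sample_decomp_ref(x, index))  # matches paddle.index_sample on the same inputs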
10 changes: 8 additions & 2 deletions test/legacy_test/test_index_sample_op.py
@@ -26,7 +26,9 @@
class TestIndexSampleOp(OpTest):
    def setUp(self):
        self.op_type = "index_sample"
        self.prim_op_type = "comp"
        self.python_api = paddle.index_sample
        self.public_python_api = paddle.index_sample
        self.config()
        xnp = np.random.random(self.x_shape).astype(self.x_type)
        if self.x_type == np.complex64 or self.x_type == np.complex128:
@@ -47,7 +49,7 @@ def setUp(self):
        self.outputs = {'Out': out}

    def test_check_output(self):
        self.check_output(check_pir=True)
        self.check_output(check_pir=True, check_prim_pir=True)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_pir=True)
@@ -158,7 +160,9 @@ def config(self):
class TestIndexSampleBF16Op(OpTest):
    def setUp(self):
        self.op_type = "index_sample"
        self.prim_op_type = "comp"
        self.python_api = paddle.index_sample
        self.public_python_api = paddle.index_sample
        self.config()
        xnp = np.random.random(self.x_shape).astype(self.x_type)
        indexnp = np.random.randint(
@@ -177,7 +181,9 @@ def setUp(self):
        self.place = core.CUDAPlace(0)

    def test_check_output(self):
        self.check_output_with_place(self.place, check_pir=True)
        self.check_output_with_place(
            self.place, check_pir=True, check_prim_pir=True
        )

    def test_check_grad(self):
        self.check_grad_with_place(self.place, ['X'], 'Out', check_pir=True)
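check_prim_pir=True asks OpTest to also run the op through its composite (decomposed) form under PIR and compare against the same expected outputs. A rough end-to-end sketch of that comparison outside OpTest (my assumption about the flow; the core import path in particular may differ from what the test file uses):

import numpy as np
import paddle
from paddle.base import core  # assumed import path for the prim switch

def net(x, index):
    return paddle.index_sample(x, index)

x = paddle.rand([3, 4])
index = paddle.to_tensor(np.array([[0, 3], [1, 1], [2, 0]], dtype=np.int64))

eager_out = net(x, index)

core._set_prim_all_enabled(True)   # decompose index_sample into primitive ops
prim_out = paddle.jit.to_static(net)(x, index)
core._set_prim_all_enabled(False)

np.testing.assert_allclose(eager_out.numpy(), prim_out.numpy(), rtol=1e-6)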
71 changes: 71 additions & 0 deletions test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py
@@ -80,6 +80,10 @@ def tile_net2(x):
    return y


def index_sample_net(x, index):
    return paddle.index_sample(x, index)


class TestPrimOne(unittest.TestCase):
    def setUp(self):
        np.random.seed(2023)
@@ -198,5 +202,72 @@ def setUp(self):
        self.enable_cinn = False


class TestPrimTwo(unittest.TestCase):
    def setUp(self):
        np.random.seed(2023)
        self.shape_x = [300, 4096]
        self.shape_y = [300, 2048]
        self.dtype_x = "float32"
        self.dtype_y = int
        self.init_x_shape = [None, 4096]
        self.init_y_shape = [None, 2048]
        self.x = np.random.random(self.shape_x).astype(self.dtype_x)
        self.y = np.random.random(self.shape_y).astype(self.dtype_y)
        self.net = index_sample_net
        self.necessary_ops = "pd_op.index_sample"
        self.enable_cinn = False

    def base_net(self, flag=None):
        x = paddle.to_tensor(self.x)
        y = paddle.to_tensor(self.y)
        if flag == "prim":
            core._set_prim_all_enabled(True)
            fn = apply_to_static(
                self.net,
                use_cinn=self.enable_cinn,
                input_spec=[
                    InputSpec(shape=self.init_x_shape, dtype=self.dtype_x),
                    InputSpec(shape=self.init_y_shape, dtype=self.dtype_y),
                ],
            )
            fn.eval()
        else:
            fn = self.net
        res = fn(x, y)

        if flag == "prim":
            ops = [
                op.name()
                for op in fn.program_cache.last()[-1][-1]
                .infer_program.program.global_block()
                .ops
            ]
            assert self.necessary_ops not in ops
            core._set_prim_all_enabled(False)
        return res

    def test_prim_all_dynamic(self):
        res_ref = self.base_net()
        res = self.base_net("prim")
        for ref, actual in zip(res_ref, res):
            np.testing.assert_allclose(ref, actual, rtol=1e-6)


class TestPrimTwoIndexSample(TestPrimTwo):
    def setUp(self):
        np.random.seed(2023)
        self.shape_x = [300, 4096]
        self.shape_y = [300, 2048]
        self.dtype_x = "float32"
        self.dtype_y = int
        self.init_x_shape = [None, 4096]
        self.init_y_shape = [300, 2048]
        self.x = np.random.random(self.shape_x).astype(self.dtype_x)
        self.y = np.random.random(self.shape_y).astype(self.dtype_y)
        self.net = index_sample_net
        self.necessary_ops = "pd_op.index_sample"
        self.enable_cinn = False


if __name__ == "__main__":
    unittest.main()
