Skip to content

Commit

Permalink
[GPU] Update range shape infer with epsilon
Browse files Browse the repository at this point in the history
  • Loading branch information
kelvinchoi-intel committed Sep 24, 2024
1 parent 1e9768b commit 610a225
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/core/shape_inference/include/range_shape_inference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,13 @@ std::vector<TRShape> range_shape_infer(const Node* op,
span = stop - start;
}

double strided = ceil(fabs(span) / fabs(step));
uint32_t strided = static_cast<uint32_t>(ceil(fabs(span) / fabs(step)));
const double epsilon = 1e-06;
if (!output_is_integral && (strided - 1) * step >= span - epsilon) {
strided -= 1;
}

output_shapes[0] = TRShape{static_cast<uint32_t>(strided)};
output_shapes[0] = TRShape{strided};
} else {
output_shapes[0] = ov::PartialShape::dynamic(1);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ TEST_P(range_si_test, shape_infer) {

INSTANTIATE_TEST_SUITE_P(smoke, range_si_test,
testing::ValuesIn(std::vector<range_si_test_params>{
{ov::PartialShape{}, ov::PartialShape{39}, data_types::f32, {0.0f, 1.0f, 0.025641024f}},
{ov::PartialShape{}, ov::PartialShape{7}, data_types::i32, {2, 23, 3}},
{ov::PartialShape{}, ov::PartialShape{7}, data_types::i8, {2, 23, 3}},
{ov::PartialShape{}, ov::PartialShape{7}, data_types::u8, {2, 23, 3}},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ INSTANTIATE_TEST_SUITE_P(range_gpu_test,
smoke_range_test,
testing::ValuesIn(
range_test_param_generator()
.simple_params(float_types, 0, 1.0f, 0.025641024f)
.simple_params(general_types, 2, 23, 3)
.simple_params(general_types, 1, 21, 2)
.simple_params(float_types, 1, 2.5f, 0.5f)
Expand Down

0 comments on commit 610a225

Please sign in to comment.