diff --git a/dali/benchmark/operator_bench.h b/dali/benchmark/operator_bench.h index c584e3667e1..ef1996cc103 100644 --- a/dali/benchmark/operator_bench.h +++ b/dali/benchmark/operator_bench.h @@ -88,7 +88,8 @@ class OperatorBench : public DALIBenchmark { void RunGPU(benchmark::State &st, const OpSpec &op_spec, int batch_size = 128, TensorListShape<> shape = uniform_list_shape(128, {1080, 1920, 3}), TensorLayout layout = "HWC", - bool fill_in_data = false) { + bool fill_in_data = false, + int64_t sync_each_n = -1) { assert(layout.size() == shape.size()); auto op_ptr = InstantiateOperator(op_spec); @@ -117,29 +118,36 @@ class OperatorBench : public DALIBenchmark { Setup>(op_ptr, op_spec, ws, batch_size); op_ptr->Run(ws); CUDA_CALL(cudaStreamSynchronize(0)); + + int64_t batches = 0; for (auto _ : st) { op_ptr->Run(ws); - CUDA_CALL(cudaStreamSynchronize(0)); - - int num_batches = st.iterations() + 1; - st.counters["FPS"] = benchmark::Counter(batch_size * num_batches, - benchmark::Counter::kIsRate); + batches++; + if (sync_each_n > 0 && batches % sync_each_n == 0) { + CUDA_CALL(cudaStreamSynchronize(0)); + } } + + st.ResumeTiming(); + CUDA_CALL(cudaStreamSynchronize(0)); + st.PauseTiming(); + st.counters["FPS"] = benchmark::Counter(batch_size * st.iterations(), + benchmark::Counter::kIsRate); } template void RunGPU(benchmark::State &st, const OpSpec &op_spec, int batch_size = 128, TensorShape<> shape = {1080, 1920, 3}, TensorLayout layout = "HWC", - bool fill_in_data = false) { + bool fill_in_data = false, int64_t sync_each_n = -1) { RunGPU(st, op_spec, batch_size, - uniform_list_shape(batch_size, shape), layout, fill_in_data); + uniform_list_shape(batch_size, shape), layout, fill_in_data, sync_each_n); } template void RunGPU(benchmark::State& st, const OpSpec &op_spec, int batch_size = 128, int H = 1080, int W = 1920, int C = 3, - bool fill_in_data = false) { - RunGPU(st, op_spec, batch_size, {H, W, C}, "HWC", fill_in_data); + bool fill_in_data = false, int64_t sync_each_n = -1) { + RunGPU(st, op_spec, batch_size, {H, W, C}, "HWC", fill_in_data, sync_each_n); } };