Skip to content

Commit

Permalink
Port more things from NVIDIA#5302
Browse files Browse the repository at this point in the history
Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
  • Loading branch information
klecki committed Mar 4, 2024
1 parent cc23c51 commit 5f4282d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 40 deletions.
13 changes: 11 additions & 2 deletions dali/test/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,7 +13,16 @@
# limitations under the License.

if (BUILD_TEST)
# get all the test srcs
# Get all the test srcs, make it a part of gtest binary
file(GLOB tmp *.cc *.cu *.h)
adjust_source_file_language_property("${tmp}")
set(DALI_TEST_SRCS ${DALI_TEST_SRCS} ${tmp} PARENT_SCOPE)

# Additionally build the operators as loadable library, so it can be imported as plugin in Python
set(lib_name "testoperatorplugin")
add_library(${lib_name} SHARED ${tmp})
target_link_libraries(${lib_name} PRIVATE ${CUDART_LIB})
target_link_libraries(${lib_name} PUBLIC dali)

set_target_properties(${lib_name} PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${TEST_BINARY_DIR})
endif()
58 changes: 20 additions & 38 deletions dali/test/python/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import nvidia.dali as dali
import nvidia.dali.types as types
from nvidia.dali.backend_impl import TensorListGPU, TensorGPU, TensorListCPU
from nvidia.dali import plugin_manager

import functools
import inspect
Expand Down Expand Up @@ -187,42 +188,6 @@ def get_absdiff(left, right):
return _get_absdiff(left, right)


def dump_as_core_artifacts(image_info, lhs, rhs, iter=None, sample_idx=None):
import_numpy()
import_pil()

from pathlib import Path

path = (
"/opt/dali"
if os.path.exists("/opt/dali") and os.access("/opt/dali", os.W_OK)
else os.getcwd()
)
Path(f"{path}/core_artifacts").mkdir(parents=True, exist_ok=True)

image_info = image_info.replace("/", "_")
image_info = image_info.replace(" ", "_")
if iter is not None:
image_info = image_info + f"_iter{iter}"
if sample_idx is not None:
image_info = image_info + f"_sample_idx{sample_idx}"

try:
save_image(lhs, f"{path}/core_artifacts/{image_info}.lhs.png")
save_image(rhs, f"{path}/core_artifacts/{image_info}.rhs.png")
except Exception as e:
print(f"Tried to save images but got an error: {e}")

try:
# save arrays on artifact folder
import numpy as np

np.save(f"{path}/core_artifacts/{image_info}.lhs.npy", lhs)
np.save(f"{path}/core_artifacts/{image_info}.rhs.npy", rhs)
except Exception as e:
print(f"Tried to save arrays but got an error: {e}")


# If the `max_allowed_error` is not None, it's checked instead of comparing mean error with `eps`.
def check_batch(
batch1,
Expand Down Expand Up @@ -328,8 +293,15 @@ def _verify_batch_size(batch):
error_msg += f"\nLHS data source: {batch1[i].source_info()}"
if hasattr(batch2[i], "source_info"):
error_msg += f"\nRHS data source: {batch2[i].source_info()}"

dump_as_core_artifacts(batch1[i].source_info(), left, right, sample_idx=i)
try:
save_image(left, "err_1.png")
save_image(right, "err_2.png")
except: # noqa:722
print("Batch at {} can't be saved as an image".format(i))
print("left: \n", left)
print("right: \n", right)
np.save("err_1.npy", left)
np.save("err_2.npy", right)
assert False, error_msg


Expand Down Expand Up @@ -972,3 +944,13 @@ def tested_ops(self):
return set(_tested_ops)

return SignOff()


def load_test_operator_plugin():
"""Load plugin containing the test operators from: `dali/test/operators`."""
test_bin_dir = os.path.dirname(dali.__file__) + "/test"
try:
plugin_manager.load_library(test_bin_dir + "/libtestoperatorplugin.so")
except RuntimeError:
# in conda "libtestoperatorplugin" lands inside lib/ dir
plugin_manager.load_library("libtestoperatorplugin.so")

0 comments on commit 5f4282d

Please sign in to comment.