Skip to content

Commit

Permalink
Replaced assertion in BN forward folding tests from direct tensors co…
Browse files Browse the repository at this point in the history
…mparison to norm MSE to prevent occasional failures (#762)

---------

Co-authored-by: Ofir Gordon <Ofir.Gordon@altair-semi.com>
  • Loading branch information
ofirgo and Ofir Gordon committed Aug 7, 2023
1 parent 524ed1f commit 89f8812
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 11 deletions.
11 changes: 9 additions & 2 deletions tests/common_tests/helpers/tensors_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,17 @@ def norm_similarity(a, b):


def normalized_mse(a, b, norm_factor=None):
batch_size = a.shape[0]

a = np.reshape(a, [batch_size, -1])
b = np.reshape(b, [batch_size, -1])

if norm_factor is None:
norm_factor = np.square(np.abs(a)).mean()
norm_factor = np.square(np.abs(a)).mean(axis=-1)
norm_factor = np.reshape(norm_factor, [batch_size, 1])

lsb_error = (np.abs(a - b)**2 / norm_factor)
return np.mean(lsb_error), np.std(lsb_error), np.max(lsb_error), np.min(lsb_error)
return np.mean(lsb_error, axis=-1), np.std(lsb_error, axis=-1), np.max(lsb_error, axis=-1), np.min(lsb_error, axis=-1)


def tensor_norm(a):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from tests.keras_tests.tpc_keras import get_16bit_tpc
from tests.keras_tests.feature_networks_tests.base_keras_feature_test import BaseKerasFeatureNetworkTest
import numpy as np
from tests.common_tests.helpers.tensors_compare import cosine_similarity
from tests.common_tests.helpers.tensors_compare import cosine_similarity, normalized_mse
from tests.keras_tests.utils import get_layers_from_model_by_type

keras = tf.keras
Expand Down Expand Up @@ -229,7 +229,7 @@ class BNForwardFoldingTest(BaseKerasFeatureNetworkTest):
test that the BN isn't folded
"""
def __init__(self, unit_test, test_layer, conversion_applied, add_bn=False, is_dwconv=False):
super().__init__(unit_test=unit_test, experimental_exporter=True)
super().__init__(unit_test=unit_test, experimental_exporter=True, val_batch_size=2)
self.test_layer = test_layer
self.conversion_applied = conversion_applied
self.add_bn = add_bn
Expand Down Expand Up @@ -268,8 +268,17 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
sum([isinstance(l, tf.keras.layers.DepthwiseConv2D) for l in quantized_model.layers]))
else:
is_bn_in_model = any([isinstance(l, tf.keras.layers.BatchNormalization) for l in quantized_model.layers])

self.unit_test.assertTrue(self.conversion_applied is not is_bn_in_model)

# Checking on multiple inputs to reduce probability for numeric error that will randomly fail the test
self.unit_test.assertEqual(input_x[0].shape[0], 2, "Expecting batch of size 2 for BN folding test.")

out_float = float_model(input_x)
out_quant = quantized_model(input_x)
self.unit_test.assertTrue(np.isclose(out_quant, out_float, rtol=1e-4).all())

norm_mse, _, max_error, _ = normalized_mse(out_float.numpy(), out_quant.numpy())

self.unit_test.assertTrue(np.isclose(norm_mse[0], 0, atol=1e-5) or np.isclose(norm_mse[1], 0, atol=1e-5))
self.unit_test.assertTrue(np.isclose(max_error[0], 0, atol=1e-4) or np.isclose(max_error[1], 0, atol=1e-4))

5 changes: 3 additions & 2 deletions tests/pytorch_tests/model_tests/base_pytorch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ def __init__(self,
unit_test,
float_reconstruction_error=1e-7,
convert_to_fx=True,
experimental_exporter=True):
experimental_exporter=True,
val_batch_size=1):

super().__init__(unit_test)
super().__init__(unit_test, val_batch_size=val_batch_size)
self.float_reconstruction_error = float_reconstruction_error
self.convert_to_fx = convert_to_fx
self.experimental_exporter = experimental_exporter
Expand Down
17 changes: 13 additions & 4 deletions tests/pytorch_tests/model_tests/feature_models/bn_folding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np
from model_compression_toolkit.core.pytorch.utils import set_model, to_torch_tensor, \
torch_tensor_to_numpy
from tests.common_tests.helpers.tensors_compare import normalized_mse
from tests.pytorch_tests.model_tests.base_pytorch_test import BasePytorchTest

"""
Expand Down Expand Up @@ -103,7 +104,7 @@ class BNForwardFoldingNetTest(BasePytorchTest):
test that the BN isn't folded
"""
def __init__(self, unit_test, test_layer, fold_applied=True, add_bn=False, is_dw=False):
super().__init__(unit_test, float_reconstruction_error=1e-6)
super().__init__(unit_test, float_reconstruction_error=1e-6, val_batch_size=2)
self.test_layer = test_layer
self.fold_applied = fold_applied
self.add_bn = add_bn
Expand All @@ -122,8 +123,6 @@ def compare(self, quantized_models, float_model, input_x=None, quantization_info
set_model(float_model)
quant_model = quantized_models['no_quantization']
set_model(quant_model)
out_float = torch_tensor_to_numpy(float_model(*input_x))
out_quant = torch_tensor_to_numpy(quant_model(*input_x))

if self.is_dw:
is_bn_in_model = (sum([type(module) is torch.nn.Conv2d for name, module in float_model.named_modules()]) ==
Expand All @@ -132,4 +131,14 @@ def compare(self, quantized_models, float_model, input_x=None, quantization_info
is_bn_in_model = torch.nn.BatchNorm2d in [type(module) for name, module in quant_model.named_modules()]

self.unit_test.assertTrue(self.fold_applied is not is_bn_in_model)
self.unit_test.assertTrue(np.isclose(out_quant, out_float, atol=1e-6, rtol=1e-4).all())

# Checking on multiple inputs to reduce probability for numeric error that will randomly fail the test
self.unit_test.assertEqual(input_x[0].shape[0], 2, "Expecting batch of size 2 for BN folding test.")

out_float = torch_tensor_to_numpy(float_model(*input_x))
out_quant = torch_tensor_to_numpy(quant_model(*input_x))

norm_mse, _, max_error, _ = normalized_mse(out_float, out_quant)

self.unit_test.assertTrue(np.isclose(norm_mse[0], 0, atol=1e-5) or np.isclose(norm_mse[1], 0, atol=1e-5))
self.unit_test.assertTrue(np.isclose(max_error[0], 0, atol=1e-4) or np.isclose(max_error[1], 0, atol=1e-4))

0 comments on commit 89f8812

Please sign in to comment.