
Convert RandomZoom to backend-agnostic and improve affine_transform #574

Merged: 18 commits merged into keras-team:main from update-randomzoom on Jul 26, 2023

Conversation

@james77777778 (Contributor) commented Jul 21, 2023

Related to keras-team/keras#18442

EDITED:
please see #574 (comment)

=====OLD=====
The expected output in the correctness tests has changed with the torch backend. The root cause might come from the inconsistency of the interpolation.

# zoom -0.5
# output of tf, numpy and jax
[[ 6.  7.  7.  8.  8.]
 [11. 12. 12. 13. 13.]
 [11. 12. 12. 13. 13.]
 [16. 17. 17. 18. 18.]
 [16. 17. 17. 18. 18.]]
# output of torch
tensor([[ 6.,  6.,  7.,  7.,  8.],
        [ 6.,  6.,  7.,  7.,  8.],
        [11., 11., 12., 12., 13.],
        [11., 11., 12., 12., 13.],
        [16., 16., 17., 17., 18.]])

# zoom 0.5, 0.8
# output of tf, numpy and jax
[[ 0.  0.  0.  0.  0.]
 [ 0.  5.  7.  9.  0.]
 [ 0. 10. 12. 14.  0.]
 [ 0. 20. 22. 24.  0.]
 [ 0.  0.  0.  0.  0.]]
# output of torch
tensor([[ 0.,  0.,  0.,  0.,  0.],
        [ 0.,  6.,  7.,  9.,  0.],
        [ 0., 11., 12., 14.,  0.],
        [ 0., 21., 22., 24.,  0.],
        [ 0.,  0.,  0.,  0.,  0.]])

@fchollet (Member) commented

The root cause might come from the inconsistency of the interpolation.

Is it because of the behavior of affine_transform? Can we fix the issue there? Or otherwise pick an interpolation that matches (we have the choice between several interpolations)?

@james77777778 (Contributor, Author) commented

Is it because of the behavior of affine_transform? Can we fix the issue there?

I believe so. I can dig deeper to identify the source of the differences.
However, I'm not confident I can fix it, since tnn.grid_sample is implemented inside torch's backend, which makes it hard to follow.

Or otherwise pick an interpolation that matches (we have the choice between several interpolations)?

The correctness tests fail with both nearest and bilinear interpolation:

# bilinear (JAX, TF, NumPy)
# zoom -0.5
[[ 6.   6.5  7.   7.5  8. ]
 [ 8.5  9.   9.5 10.  10.5]
 [11.  11.5 12.  12.5 13. ]
 [13.5 14.  14.5 15.  15.5]
 [16.  16.5 17.  17.5 18. ]]

# zoom 0.5, 0.8
[[ 0.5999999   0.20000005  2.          3.7999997   3.4       ]
 [ 3.1         2.7         4.5         6.2999997   5.9       ]
 [10.6        10.2        12.         13.799999   13.4       ]
 [18.1        17.7        19.5        21.3        20.9       ]
 [20.6        20.2        22.         23.8        23.4       ]]

# bilinear (Torch)
# zoom -0.5
tensor([[ 4.5000,  5.0000,  5.5000,  6.0000,  6.5000],
        [ 7.0000,  7.5000,  8.0000,  8.5000,  9.0000],
        [ 9.5000, 10.0000, 10.5000, 11.0000, 11.5000],
        [12.0000, 12.5000, 13.0000, 13.5000, 14.0000],
        [14.5000, 15.0000, 15.5000, 16.0000, 16.5000]])
# zoom 0.5, 0.8
tensor([[ 0.2000,  0.6000,  2.4000,  4.0000,  3.0000],
        [ 3.9500,  4.3500,  6.1500,  7.7500,  6.7500],
        [11.4500, 11.8500, 13.6500, 15.2500, 14.2500],
        [18.9500, 19.3500, 21.1500, 22.7500, 21.7500],
        [18.9500, 19.3500, 21.1500, 22.7500, 21.7500]])

@fchollet (Member) commented

So the root cause is that interpolation is implemented differently in torch (specifically only torch)? This is a bit surprising because there is an exact pixel-level algorithm that corresponds to each interpolation mode.

Is the difference visually noticeable on real images? The magnitude of the differences in test arrays looks quite significant.

If we conclude that the difference is small, we can roll with it and make a note in the docs that torch has numerical differences. If the difference is significant, we should explore alternative solutions.

@james77777778 (Contributor, Author) commented

there is an exact pixel-level algorithm that corresponds to each interpolation mode.

Actually, when I was implementing affine_transform, I noticed that the differences were significant. As a result, the test is set to assertRaises when the backend is torch:

if backend.backend() == "torch":
    # TODO: cannot pass with torch backend
    with self.assertRaises(AssertionError):
        self.assertAllClose(ref_out, out, atol=0.3)
else:
    self.assertAllClose(ref_out, out, atol=0.3)

Is the difference visually noticeable on real images? The magnitude of the differences in test arrays looks quite significant.

Here is a comparison on an image of size (600, 512). I think the difference is not visually noticeable.

  • TF (same as JAX and NumPy): tf_affine (image)
  • Torch: torch_affine (image)

Using the code below with KERAS_BACKEND=tensorflow/torch respectively:

import matplotlib.cbook as cbook
import matplotlib.pyplot as plt

from keras_core.layers import RandomZoom

with cbook.get_sample_data("grace_hopper.jpg") as image_file:
    image = plt.imread(image_file)

layer = RandomZoom(
    height_factor=(-0.5, -0.5),
    width_factor=(-0.5, -0.5),
    fill_mode="constant",
    interpolation="nearest",
)
output1 = layer(image)

layer = RandomZoom(
    height_factor=(0.5, 0.5),
    width_factor=(0.8, 0.8),
    fill_mode="constant",
    interpolation="nearest",
)
output2 = layer(image)

fig, ax_dict = plt.subplot_mosaic([["A", "B", "C"]], figsize=(12, 4))
ax_dict["A"].set_title("Original")
ax_dict["A"].imshow(image)
ax_dict["B"].set_title("Zoom In")
ax_dict["B"].imshow(output1 / 255.0)
ax_dict["C"].set_title("Zoom Out")
ax_dict["C"].imshow(output2 / 255.0)
fig.tight_layout(h_pad=0.1, w_pad=0.1)
plt.savefig("affine.png")

If we conclude that the difference is small, we can roll with it and make a note in the docs that torch has numerical differences. If the difference is significant, we should explore alternative solutions.

It seems that:

  • the difference is quite large for a small numeric array
  • the difference is not noticeable in a high-resolution image

@fchollet (Member) commented

Thanks for the detailed info! I think we should go with option 1:

we can roll with it and make a note in the docs that torch has numerical differences.

We should make sure to include the note in the affine_transform docstring and the RandomZoom docstring. For RandomZoom it will largely not matter since it's an augmentation layer, but folks using affine_transform are definitely going to need to know about it.

@james77777778 (Contributor, Author) commented Jul 24, 2023

Thanks for the detailed info! I think we should go with option 1:

We should make sure to include the note in the affine_transform docstring and the RandomZoom docstring. For RandomZoom it will largely not matter since it's an augmentation layer, but folks using affine_transform are definitely going to need to know about it.

I'm currently working on making the interpolation in affine_transform more consistent between torch and the other backends.
Please wait a bit.

@fchollet (Member) commented

Sounds good!

@james77777778 (Contributor, Author) commented Jul 25, 2023

@fchollet

After refactoring tnn.affine_grid and examining the details of tnn.grid_sample, I can summarize the following table:

fill_mode   scipy (numpy)   tensorflow   torch                   jax
constant    ✓               ✓            ✓                       ✓
nearest     ✓               ✓            ✓                       ✓
wrap        ✓               ✓            X                       ✓
mirror      ✓               X            ✓ (called reflection)   ✓
reflect     ✓               ✓            X                       ✓

This table can be verified with the updated keras_core/ops/image_test.py.

Some notes:

  • The numerical result of tensorflow with fill_mode=wrap is not consistent with scipy.
  • We must use align_corners=True in tnn.grid_sample to get the same result as scipy's map_coordinates.
  • padding_mode=reflection in tnn.grid_sample is actually fill_mode=mirror in the other backends (it took me a lot of time to realize this lol). See the sketch below.
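
To illustrate the last two notes, here is a minimal sketch (my own comparison, not the keras_core implementation) that checks tnn.grid_sample with align_corners=True and padding_mode="reflection" against scipy's map_coordinates with mode="mirror" for a simple translation:

import numpy as np
import scipy.ndimage
import torch
import torch.nn.functional as tnn

h = w = 5
image = np.arange(h * w, dtype="float32").reshape(h, w)
shift = 1.5  # sample the input at (row + 1.5, col + 1.5)

# scipy reference: bilinear (order=1) with mirror filling
rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
coords = np.stack([rows + shift, cols + shift])
ref = scipy.ndimage.map_coordinates(image, coords, order=1, mode="mirror")

# torch: grid_sample expects coordinates normalized to [-1, 1];
# with align_corners=True, pixel i maps to 2 * i / (size - 1) - 1
x = torch.from_numpy(image)[None, None]  # (N, C, H, W)
grid_x = 2.0 * torch.from_numpy(cols + shift).float() / (w - 1) - 1.0
grid_y = 2.0 * torch.from_numpy(rows + shift).float() / (h - 1) - 1.0
grid = torch.stack([grid_x, grid_y], dim=-1)[None]  # (N, H, W, 2), (x, y) order
out = tnn.grid_sample(
    x, grid, mode="bilinear", padding_mode="reflection", align_corners=True
)[0, 0]

print(np.allclose(ref, out.numpy(), atol=1e-5))  # expected to print True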

Now, I have a question:
The preprocessing layers such as RandomTranslation, RandomZoom and RandomRotation have a default fill_mode=reflect. This value works well for most of the backends, but not for torch.

I think there are two options:

  1. choose another default value like constant (scipy's default) or nearest
  2. As the visual difference is not noticeable in preprocessing layers, we can override the fill_mode to mirror when using torch

What should I do?

Review comment on these diff lines:

interpolation=interpolation.upper(),
fill_mode=fill_mode.upper(),
coordinates = _compute_affine_transform_coordinates(x, transform)
ref_out = scipy.ndimage.map_coordinates(

@james77777778 (Contributor, Author):

Here I replaced tf.raw_ops.ImageProjectiveTransformV3 with scipy.ndimage.map_coordinates because scipy has more fill_mode options to test all backends.

@fchollet (Member) commented Jul 25, 2023

I think there are two options:
choose another default value like constant (scipy's default) or nearest
As the visual difference is not noticeable in preprocessing layers, we can override the fill_mode to mirror when using torch

Thanks for the analysis! I think we can make backend.torch.image.affine_transform redirect "reflect" to "mirror", and mention this in the docstrings.

We cannot pick nearest or constant as the default in RandomZoom/etc., because that would diminish the benefits of data augmentation. It's important that "new" areas in the augmented images be filled with content that is in-distribution and that minimizes visual discontinuity. Otherwise you train on information-free pixels (at best) or OOD pixels (at worst).

@james77777778 changed the title from "Convert RandomZoom to the backend-agnostic implementation" to "Convert RandomZoom to backend-agnostic and improve affine_transform" on Jul 26, 2023
@james77777778 (Contributor, Author) commented

I think we can make backend.torch.image.affine_transform redirect "reflect" to "mirror", and mention this in the docstrings.

I have updated the PR with the solution you mentioned. All tests passed.

We cannot pick nearest or constant as the default in RandomZoom/etc., because that would diminish the benefits of data augmentation. It's important that "new" areas in the augmented images be filled with content that is in-distribution and that minimizes visual discontinuity. Otherwise you train on information-free pixels (at best) or OOD pixels (at worst).

Thanks for the clarification! Now I understand that the default reflect plays an important role in data augmentation.
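
For illustration, a minimal sketch of the "redirect reflect to mirror" idea for a torch-backed affine_transform (a hypothetical helper, not the actual keras_core source):

def _to_grid_sample_padding_mode(fill_mode):
    # Hypothetical mapping from the common fill_mode names onto the padding
    # modes that tnn.grid_sample actually supports (zeros / border / reflection).
    # grid_sample's "reflection" (with align_corners=True) behaves like scipy's
    # "mirror", so "reflect" is redirected to it, with a docstring note about
    # the numerical difference.
    mapping = {
        "constant": "zeros",
        "nearest": "border",
        "mirror": "reflection",
        "reflect": "reflection",  # redirected for the torch backend
    }
    if fill_mode not in mapping:
        raise ValueError(
            f"fill_mode={fill_mode} is not supported by the torch backend."
        )
    return mapping[fill_mode]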

@fchollet (Member) left a review comment

Awesome work -- thank you for the great contribution!

@fchollet merged commit b6b4376 into keras-team:main on Jul 26, 2023 (6 checks passed).
@james77777778 deleted the update-randomzoom branch on July 26, 2023 at 03:47.