Some kind of offset in TexturesUV Map, leads to faulty texture on mesh #1847

Closed
jonasmrc opened this issue Aug 1, 2024 · 7 comments

@jonasmrc

jonasmrc commented Aug 1, 2024

Hello everyone,

I am seeing some unexpected behavior with my code. I trained a NeRF-style model (PermutoSDF) on a dataset. PermutoSDF includes a script that extracts a mesh from the trained model, but this mesh does not have any textures. I would therefore like to optimize the textures with PyTorch3D. To do so, I create a UV map in Blender for the mesh I got out of PermutoSDF, using Blender's Smart UV Project function.

Then I optimize the texture of the resulting mesh with the following code:

import torch
from torch.utils.data import DataLoader
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.renderer import (
    FoVPerspectiveCameras,
    BlendParams,
    PointLights,
    RasterizationSettings, 
    MeshRenderer, 
    MeshRasterizer,  
    SoftPhongShader,
)


#Custom created Dataset with NERF images and their poses and intrinsics
dataset = Datasetloader.Permuto2PyTorch3D_Dataset(path_to_dataset)
train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

raster_settings = RasterizationSettings(
            image_size=img_width, #rendered images will be size of (img_width, img_width)
            blur_radius=0.0, 
            faces_per_pixel=10,
            bin_size=0,
            perspective_correct=False,
    )

#load Mesh
mesh = load_objs_as_meshes([path_to_mesh], device=device)
texture = mesh.textures.maps_padded() #get TextureUV Map
texture.requires_grad = True

#optimize TextureUV Map only
optimizer = torch.optim.Adam([texture], lr=0.01)
l1 = torch.nn.L1Loss()

for i in range(epochs):
    for img, R, T in train_dataloader:
        img = img.to(device)
        optimizer.zero_grad()
        cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=torch.rad2deg(img_fov_from_dataset))  # fov is given in degrees by default; assumes the dataset stores it as a tensor in radians
        lights = PointLights(device=device, location=loc)

        renderer = MeshRenderer(
                    rasterizer=MeshRasterizer(
                        cameras=cameras, 
                        raster_settings=raster_settings
                    ),
                    shader=SoftPhongShader(
                        device=device, 
                        cameras=cameras,
                        lights=lights,
                        blend_params=BlendParams(background_color=(0.0, 0.0, 0.0)) #black Background
                    ),
                )

        meshes = mesh.extend(img.shape[0])
        images = renderer(meshes) #size (batch, img_width, img_width, 4 (RGBA))

        #center crop images to size (img_height, img_width)
        top = (img_width - img_height)/2
        bottom = (img_width + img_height)/2
        images = images[:, int(top):int(bottom), :, :] #size (batch, img_height, img_width, 4 (RGBA))

        loss_rgb = l1(images[..., :3], img[..., :3]) #ignore alpha channel
        loss_rgb.backward()
        optimizer.step()

        #clip TextureMap values to stay in range [0, 1]
        with torch.no_grad():
            texture[texture > 1] = 1
            texture[texture < 0] = 0

This gives me some unexpected results. It looks like there is some kind of offset within the TexturesUV map:
[Image: uvmap]
I would expect the triangles to be shifted a little to the left (see, for example, the red circle in the image). Some triangles look correct, though, so the offset does not apply to all of them. This leads to unwanted results in the texture of my mesh:
[Image: texture]
As you can see, many regions of the mesh are grey (the initial color of my texture map) when they are supposed to be green.
I don't think this is an issue with my mesh structure, because I can manually paint grey spots green without any problem. The texture mapping within PyTorch3D also seems to be right, because the (incorrect) resulting texture map is mapped onto my mesh correctly.

What is wrong is the resulting texture map. I checked the gradients of the TexturesUV map, and they appear to be offset as well, so it looks like something is wrong in the gradient calculation. I visualized the gradients with this code:

...
loss_rgb.backward()
plt.imsave("grad.png", torch.abs(texture.grad[0, ..., :3] / texture.grad[0, ..., :3].max()).cpu().numpy())
optimizer.step()

[Image: gradients]

I also tried optimizing TexturesVertex instead of TexturesUV, as shown in your tutorial. This worked fine, and there were no grey spots after training. But TexturesVertex does not offer enough resolution, so I would prefer to optimize a higher-resolution TexturesUV map instead.

If any more information or explanation is necessary, please let me know.
Thank you for your help!

@bottler
Contributor

bottler commented Aug 1, 2024

In the last image, have you plotted every triangle in the mesh on the uv map in grey, or only some of them? You should not have gradients in unused portions of the map. I wonder if you are using the right conventions for u and v? Can you plot the TexturesUV with one of the functions in https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/vis/texture_vis.py ?

@jonasmrc
Author

jonasmrc commented Aug 2, 2024

I've plotted every triangle of my mesh (I just zoomed in so it is easier to see), and yes, it is very odd that there are gradients in unused portions of the map. I am not an expert on meshes, but I think I am using the right conventions for u and v. I created the UV map with Blender's Smart UV Project and exported my mesh immediately afterwards. So as long as nothing is wrong with Blender's UV handling, my mesh should be fine too.

I attached my plotted TexturesUV (it is bigger than 10 MB, so here is a link to it: https://uni-bonn.sciebo.de/s/6ca9gDnqUKgZy4K).
You might want to zoom in so it is easier to see all the details. The offset is most noticeable at the right and at the top. The middle part of the UV map seems to be OK.

I also attach the gradient variant with the UV points plotted:
[Image: uv]

@bottler
Contributor

bottler commented Aug 2, 2024

I think there's no guarantee that blender's conventions for u and v are the same as PyTorch3D's. This should be checked. The two images look the same way round, and you can see the red dots are at least roughly lined up with the colored area. I don't know what's happening here.
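
One way to check the convention question is to convert a few UV coordinates to pixel positions by hand and compare them with where Blender draws the same vertices. The sketch below is a pure-Python illustration, not PyTorch3D code. `uv_to_pixel` is a hypothetical helper, and it assumes the convention that, as far as I understand the TexturesUV documentation, PyTorch3D uses: (0, 0) is the bottom-left corner of the texture map and (1, 1) is the top-right, while image rows are counted from the top. It also assumes that u = 0 and u = 1 map to the centers of the first and last texel.

```python
def uv_to_pixel(u, v, width, height, flip_v=True):
    """Map a UV coordinate in [0, 1] to continuous (column, row) image coordinates.

    flip_v=True assumes v = 0 is the bottom of the texture, so v must be
    flipped because image rows are counted from the top.
    Assumption: u = 0 / u = 1 hit the centers of the first / last texel.
    """
    if flip_v:
        v = 1.0 - v
    col = u * (width - 1)
    row = v * (height - 1)
    return col, row


if __name__ == "__main__":
    # Compare both conventions for the corners of a small 5x5 map.
    for u, v in [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]:
        flipped = uv_to_pixel(u, v, 5, 5, flip_v=True)
        unflipped = uv_to_pixel(u, v, 5, 5, flip_v=False)
        print(f"uv=({u}, {v})  v flipped: {flipped}  v not flipped: {unflipped}")
```

If the two tools disagreed about the direction of v, the whole UV layout would appear mirrored top to bottom rather than slightly shifted, which makes a flipped axis easy to tell apart from a sub-texel offset.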

@jonasmrc
Author

In this comment I describe how to reproduce my issue with your tutorial. I recreated the UV map of the cow mesh used in your tutorial, intentionally choosing some unfavorable parameters for Blender's Smart UV Project. As a result, the UV map is admittedly not the best possible one. However, I would expect a gradient-descent-based method to still be able to deal with this UV map. The resulting mesh is attached here:
cow.zip. You can use this code to reproduce the texture map:

import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

from tqdm import tqdm

# Util function for loading meshes
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.renderer import (
    FoVPerspectiveCameras,
    BlendParams,
    PointLights,
    RasterizationSettings, 
    MeshRenderer, 
    MeshRasterizer,  
    SoftPhongShader,
    look_at_view_transform
)

class Cow_Dataset(Dataset):
    def __init__(self, imgs, R, T):
        self.imgs = imgs
        self.R = R
        self.T = T

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        return self.imgs[idx], self.R[idx], self.T[idx]

if __name__ == "__main__":
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    # Set paths
    DATA_DIR = "./data"
    obj_filename = os.path.join(DATA_DIR, "cow_mesh/cow.obj")

    # Load obj file
    mesh = load_objs_as_meshes([obj_filename], device=device)

    # We scale normalize and center the target mesh to fit in a sphere of radius 1 
    # centered at (0,0,0). (scale, center) will be used to bring the predicted mesh 
    # to its original center and scale.  Note that normalizing the target mesh, 
    # speeds up the optimization but is not necessary!
    verts = mesh.verts_packed()
    N = verts.shape[0]
    center = verts.mean(0)
    scale = max((verts - center).abs().max(0)[0])
    mesh.offset_verts_(-center)
    mesh.scale_verts_((1.0 / float(scale)))




    num_views = 20

    # Get a batch of viewing angles. 
    azim = np.array([0, 90, 180, 270])
    elev = np.array([-80, -40, 0, 40, 80])
    azim, elev = np.meshgrid(azim, elev)
    azim = torch.Tensor(azim).flatten()
    elev = torch.Tensor(elev).flatten()

    # Place a point light in front of the object. As mentioned above, the front of 
    # the cow is facing the -z direction. 
    lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])
    
    # Initialize an OpenGL perspective camera that represents a batch of different 
    # viewing angles. All the cameras helper methods support mixed type inputs and 
    # broadcasting. So we can view the camera from the a distance of dist=2.7, and 
    # then specify elevation and azimuth angles for each viewpoint as tensors. 
    R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)
    cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
    # lights = PointLights(device=device, location=T)

    # Define the settings for rasterization and shading. Here we set the output 
    # image to be of size 128X128. As we are rendering images for visualization 
    # purposes only we will set faces_per_pixel=1 and blur_radius=0.0. Refer to 
    # rasterize_meshes.py for explanations of these parameters.  We also leave 
    # bin_size and max_faces_per_bin to their default values of None, which sets 
    # their values using heuristics and ensures that the faster coarse-to-fine 
    # rasterization method is used.  Refer to docs/notes/renderer.md for an 
    # explanation of the difference between naive and coarse-to-fine rasterization. 
    raster_settings = RasterizationSettings(
        image_size=1024, 
        blur_radius=0.0, 
        faces_per_pixel=1, 
    )

    # Create a Phong renderer by composing a rasterizer and a shader. The textured 
    # Phong shader will interpolate the texture uv coordinates for each vertex, 
    # sample from a texture image and apply the Phong lighting model
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras, 
            raster_settings=raster_settings
        ),
        shader=SoftPhongShader(
            device=device, 
            cameras=cameras,
            lights=lights,
            blend_params=BlendParams(background_color=(0.0, 0.0, 0.0))
        )
    )

    # Create a batch of meshes by repeating the cow mesh and associated textures. 
    # Meshes has a useful `extend` method which allows us do this very easily. 
    # This also extends the textures. 
    meshes = mesh.extend(num_views)

    # Render the cow mesh from each viewing angle
    target_images = renderer(meshes, cameras=cameras, lights=lights)

    # Our multi-view cow dataset will be represented by these 2 lists of tensors,
    # each of length num_views.
    target_rgb = [target_images[i, ..., :3] for i in range(num_views)]



    obj_filename = os.path.join(DATA_DIR, "cow.obj")
    mesh = load_objs_as_meshes([obj_filename], device=device)

    # We scale normalize and center the target mesh to fit in a sphere of radius 1 
    # centered at (0,0,0). (scale, center) will be used to bring the predicted mesh 
    # to its original center and scale.  Note that normalizing the target mesh, 
    # speeds up the optimization but is not necessary!
    verts = mesh.verts_packed()
    N = verts.shape[0]
    center = verts.mean(0)
    scale = max((verts - center).abs().max(0)[0])
    mesh.offset_verts_(-center)
    mesh.scale_verts_((1.0 / float(scale)))


    dataset = Cow_Dataset(target_rgb, R, T)
    train_dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

    raster_settings = RasterizationSettings(
            image_size=1024,
            blur_radius=0.0, 
            faces_per_pixel=1,
            bin_size=0,
            perspective_correct=False,
    )

    # Initial texture is red
    texture = mesh.textures.maps_padded()
    texture[:] = 0
    texture[..., 0] = 1
    texture.requires_grad = True

    optimizer = torch.optim.Adam([texture], lr=0.01)
    l1 = torch.nn.L1Loss()


    for i in (pbar := tqdm(range(50), unit="epochs")):

        for img, R, T in train_dataloader:
            img = img.to(device)
            optimizer.zero_grad()

            cameras = FoVPerspectiveCameras(device=device, R=R, T=T)

            
            renderer = MeshRenderer(
                rasterizer=MeshRasterizer(
                    cameras=cameras, 
                    raster_settings=raster_settings
                ),
                shader=SoftPhongShader(
                    device=device, 
                    cameras=cameras,
                    lights=lights,
                    blend_params=BlendParams(background_color=(0.0, 0.0, 0.0))
                ),
            )

            meshes = mesh.extend(img.shape[0])
            images = renderer(meshes)

            
            loss_rgb = l1(images[..., :3], img[..., :3])
           
            # save texturemap
            plt.imsave("tex_map.png", mesh.textures.maps_padded()[0].cpu().detach().numpy())
            # visualize current prediction
            plt.imsave("pred.png", images[0, ..., :3].cpu().detach().numpy())

            # Optimization step
            loss_rgb.backward()
            optimizer.step()

            with torch.no_grad():
                texture[texture > 1] = 1
                texture[texture < 0] = 0

            with torch.no_grad():
                pbar.set_description(f"loss={loss_rgb.item():.5f}, min={torch.min(mesh.textures.maps_padded()):.3f}, max={torch.max(mesh.textures.maps_padded()):.3f}")

Just unzip the cow mesh into the data directory from your tutorial and run this code. Afterwards you can replace the tex_map.png of the cow mesh with the tex_map.png saved by this code. Then you can view the mesh in, for example, Blender or MeshLab.

[Image: grafik]
As you can see, there are red lines between the triangles (red was the initial color of the texture, which makes the issue easier to see).

Now let's look at the resulting texture UV map. When I zoom into the left part of the texture map, it seems to be offset slightly to the right:
[Image: grafik0]

The right part seems to be offset slightly to the left:
[Image: grafik1]

At the top of the texture map there is a downward offset:
[Image: grafik2]

And at the bottom there is an upward offset:
[Image: grafik3]

I am not quite sure, but maybe the scale is very slightly off somewhere in the PyTorch3D rendering? Or maybe some calculation is not precise enough, e.g. an integer division that should be a float division, or float16 where float32 is needed?

I hope you are able to reproduce my issue. If you need any further information, please let me know.

@bottler
Contributor

bottler commented Aug 23, 2024

Without looking at your code in detail, I have an idea: can you check what you are doing in terms of align_corners conventions? There could be an inconsistency. As a first step, can you just try changing the value when you create the TexturesUV?
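
To make the align_corners question concrete, here is a small pure-Python sketch of the two coordinate conventions as they are defined for `torch.nn.functional.grid_sample`-style sampling. The helper `texel_position` is hypothetical and written only for illustration; it is not PyTorch3D's implementation. With `align_corners=True`, u = 0 and u = 1 land on the centers of the first and last texel. With `align_corners=False`, they land on the outer edges of those texels. The difference between the two positions works out to 0.5 - u texels.

```python
def texel_position(u, size, align_corners):
    """Continuous texel index for a normalized coordinate u in [0, 1].

    Texel centers sit at integer indices 0 .. size - 1.
    """
    if align_corners:
        # u = 0 -> center of texel 0, u = 1 -> center of texel size - 1
        return u * (size - 1)
    # u = 0 -> left edge of texel 0 (-0.5), u = 1 -> right edge of the last texel
    return u * size - 0.5


if __name__ == "__main__":
    size = 1024  # same map width as in the reproduction script above
    for u in (0.0, 0.25, 0.5, 0.75, 1.0):
        on = texel_position(u, size, True)
        off = texel_position(u, size, False)
        print(f"u={u:.2f}  True: {on:8.2f}  False: {off:8.2f}  diff: {on - off:+.2f}")
```

The difference is +0.5 texel at u = 0, zero at u = 0.5 and -0.5 texel at u = 1. A mismatch between the convention used to lay out the UV map and the one used to sample it would therefore shift content toward the center of the map, which would be consistent with the pattern reported above: the left part shifted right, the right part shifted left, and the middle roughly correct.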

@jonasmrc
Author

I used the default value for align_corners, so it should be True. I'll change it to False now and see what happens.

@jonasmrc
Author

Here is the solution!

1. Set align_corners in TexturesUV to False:

    texture = mesh.textures.maps_padded()
    texture[:] = 0.5
    texture.requires_grad = True
    mesh.textures.align_corners = False  # set on the TexturesUV object, not on the maps tensor

2. Set perspective_correct in RasterizationSettings from False to None or True:

raster_settings = RasterizationSettings(
    image_size=1024,
    blur_radius=0.0,
    faces_per_pixel=1,
    bin_size=None,
    perspective_correct=None,
)

This fixes my problem. @bottler thank you for your help :)
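
For readers wondering what the perspective_correct setting changes: with perspective-correct interpolation, the barycentric coordinates computed in screen space are reweighted by the vertex depths before attributes such as UV coordinates are interpolated. The sketch below uses the textbook formula in pure Python. `perspective_correct_barycentrics` and `interpolate` are hypothetical helpers for illustration and are not taken from PyTorch3D's source. As far as I understand the RasterizationSettings documentation, `None` means the correction is applied when the camera is a perspective camera.

```python
def perspective_correct_barycentrics(bary, depths):
    """Reweight screen-space barycentric coordinates by vertex depth.

    bary:   (b0, b1, b2) screen-space barycentric coordinates, summing to 1
    depths: (z0, z1, z2) positive camera-space depths of the three vertices
    """
    weighted = [b / z for b, z in zip(bary, depths)]
    total = sum(weighted)
    return [w / total for w in weighted]


def interpolate(weights, values):
    """Interpolate one scalar vertex attribute (e.g. the u coordinate)."""
    return sum(w * v for w, v in zip(weights, values))


if __name__ == "__main__":
    bary = (0.5, 0.5, 0.0)    # pixel halfway along the edge between vertex 0 and 1
    depths = (1.0, 3.0, 1.0)  # vertex 1 is three times farther away
    u_per_vertex = (0.0, 1.0, 0.0)

    corrected = perspective_correct_barycentrics(bary, depths)
    print("affine u:             ", interpolate(bary, u_per_vertex))
    print("perspective-correct u:", interpolate(corrected, u_per_vertex))
```

When all three depths are equal, the corrected weights are identical to the screen-space ones. The two results diverge as the depth varies more across a triangle, so a perspective camera with the correction switched off samples the texture at slightly different UV positions than one with it switched on.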
