Skip to content

Commit

Permalink
flat shading fix
Browse files Browse the repository at this point in the history
Summary:
Make flat shading differentiable again

Currently test fails with P130944403 which looks weird.

Reviewed By: nikhilaravi

Differential Revision: D21567106

fbshipit-source-id: 65995b64739e08397b3d021b65625e3c377cd1a5
  • Loading branch information
gkioxari authored and facebook-github-bot committed May 14, 2020
1 parent 728179e commit a0e14ca
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions pytorch3d/renderer/mesh/shading.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,21 @@ def flat_shading(meshes, fragments, lights, cameras, materials, texels) -> torch
face_normals = meshes.faces_normals_packed() # (V, 3)
faces_verts = verts[faces]
face_coords = faces_verts.mean(dim=-2) # (F, 3, XYZ) mean xyz across verts
pixel_coords = face_coords[fragments.pix_to_face]
pixel_normals = face_normals[fragments.pix_to_face]

# Replace empty pixels in pix_to_face with 0 in order to interpolate.
mask = fragments.pix_to_face == -1
pix_to_face = fragments.pix_to_face.clone()
pix_to_face[mask] = 0

N, H, W, K = pix_to_face.shape
idx = pix_to_face.view(N * H * W * K, 1).expand(N * H * W * K, 3)

# gather pixel coords
pixel_coords = face_coords.gather(0, idx).view(N, H, W, K, 3)
pixel_coords[mask] = 0.0
# gather pixel normals
pixel_normals = face_normals.gather(0, idx).view(N, H, W, K, 3)
pixel_normals[mask] = 0.0

# Calculate the illumination at each face
ambient, diffuse, specular = _apply_lighting(
Expand Down

0 comments on commit a0e14ca

Please sign in to comment.