Skip to content

Commit

Permalink
working aside from sort indices
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanebert committed Dec 10, 2023
1 parent 9d45ca3 commit 4153291
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 87 deletions.
5 changes: 5 additions & 0 deletions examples/editor/src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const progressIndicator = document.getElementById("progress-indicator") as HTMLP
const scene = new SPLAT.Scene();
const camera = scene.findObjectOfType(SPLAT.Camera) as SPLAT.Camera;
camera.data.setSize(canvas.clientWidth, canvas.clientHeight);

const gridPass = new GridPass(camera, canvas);
const axisPass = new AxisPass(camera);
const shaderPasses = [gridPass, axisPass];
Expand All @@ -22,6 +23,10 @@ let mode = "";
async function main() {
const url = "https://huggingface.co/datasets/dylanebert/3dgs/resolve/main/bonsai/bonsai-7k-mini.splat";
const splat = await SPLAT.Loader.LoadAsync(url, scene, (progress) => (progressIndicator.value = progress * 100));
const splat2 = await SPLAT.Loader.LoadAsync(url, scene, (progress) => (progressIndicator.value = progress * 100));
splat2.position = new SPLAT.Vector3(2, 0, 0);
splat2.rotation = SPLAT.Quaternion.FromEuler(new SPLAT.Vector3(0, Math.PI / 2, 0));
splat2.scale = new SPLAT.Vector3(0.5, 0.5, 0.5);
progressDialog.close();
renderer.backgroundColor = new SPLAT.Color32(64, 64, 64, 255);

Expand Down
47 changes: 38 additions & 9 deletions src/renderers/WebGLRenderer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import { Camera } from "../cameras/Camera";
import { ObjectAddedEvent, ObjectChangedEvent } from "../events/Events";
import { Splat } from "../splats/Splat";
import { Color32 } from "../math/Color32";
import { Matrix4 } from "../math/Matrix4";

export class WebGLRenderer {
private _domElement: HTMLCanvasElement;
Expand Down Expand Up @@ -56,18 +55,21 @@ export class WebGLRenderer {
let vertexShader: WebGLShader;
let fragmentShader: WebGLShader;
let texture: WebGLTexture;
let transformsTexture: WebGLTexture;

let u_projection: WebGLUniformLocation;
let u_viewport: WebGLUniformLocation;
let u_focal: WebGLUniformLocation;
let u_view: WebGLUniformLocation;
let u_texture: WebGLUniformLocation;
let u_transform: WebGLUniformLocation;
let u_transforms: WebGLUniformLocation;

let positionAttribute: number;
let indexAttribute: number;
let transformIndexAttribute: number;

let indexBuffer: WebGLBuffer;
let transformIndexBuffer: WebGLBuffer;
let vertexBuffer: WebGLBuffer;
let centerBuffer: WebGLBuffer;
let colorBuffer: WebGLBuffer;
Expand Down Expand Up @@ -147,8 +149,12 @@ export class WebGLRenderer {
u_view = gl.getUniformLocation(this._program, "view") as WebGLUniformLocation;
gl.uniformMatrix4fv(u_view, false, activeCamera.data.viewMatrix.buffer);

u_transform = gl.getUniformLocation(this._program, "transform") as WebGLUniformLocation;
gl.uniformMatrix4fv(u_transform, false, new Matrix4().buffer);
transformsTexture = gl.createTexture() as WebGLTexture;
gl.activeTexture(gl.TEXTURE1);
gl.bindTexture(gl.TEXTURE_2D, transformsTexture);

u_transforms = gl.getUniformLocation(this._program, "u_transforms") as WebGLUniformLocation;
gl.uniform1i(u_transforms, 1);

const triangleVertices = new Float32Array([-2, -2, 2, -2, 2, 2, -2, 2]);
vertexBuffer = gl.createBuffer() as WebGLBuffer;
Expand All @@ -160,6 +166,7 @@ export class WebGLRenderer {
gl.vertexAttribPointer(positionAttribute, 2, gl.FLOAT, false, 0, 0);

texture = gl.createTexture() as WebGLTexture;
gl.activeTexture(gl.TEXTURE0);
gl.bindTexture(gl.TEXTURE_2D, texture);

u_texture = gl.getUniformLocation(this._program, "u_texture") as WebGLUniformLocation;
Expand All @@ -172,6 +179,14 @@ export class WebGLRenderer {
gl.vertexAttribIPointer(indexAttribute, 1, gl.INT, 0, 0);
gl.vertexAttribDivisor(indexAttribute, 1);

transformIndexBuffer = gl.createBuffer() as WebGLBuffer;
transformIndexAttribute = gl.getAttribLocation(this._program, "transformIndex");
gl.enableVertexAttribArray(transformIndexAttribute);
gl.bindBuffer(gl.ARRAY_BUFFER, transformIndexBuffer);
gl.vertexAttribIPointer(transformIndexAttribute, 1, gl.INT, 0, 0);
gl.vertexAttribDivisor(transformIndexAttribute, 1);

gl.activeTexture(gl.TEXTURE0);
gl.bindTexture(gl.TEXTURE_2D, texture);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);
Expand All @@ -189,9 +204,6 @@ export class WebGLRenderer {
renderData.buffer,
);

gl.bindTexture(gl.TEXTURE_2D, texture);
gl.activeTexture(gl.TEXTURE0);

for (const shaderPass of shaderPasses) {
shaderPass.init(this);
}
Expand Down Expand Up @@ -283,6 +295,7 @@ export class WebGLRenderer {

if (renderData.dirty) {
renderData.rebuild();
gl.activeTexture(gl.TEXTURE0);
gl.bindTexture(gl.TEXTURE_2D, texture);
gl.texImage2D(
gl.TEXTURE_2D,
Expand Down Expand Up @@ -311,10 +324,26 @@ export class WebGLRenderer {
gl.uniformMatrix4fv(u_view, false, activeCamera.data.viewMatrix.buffer);
gl.uniform2fv(u_viewport, new Float32Array([canvas.width, canvas.height]));
gl.uniform2fv(u_focal, new Float32Array([activeCamera.data.fx, activeCamera.data.fy]));
const matrix = renderData.transforms.slice(0, 16);
gl.uniformMatrix4fv(u_transform, false, matrix);
gl.bindBuffer(gl.ARRAY_BUFFER, transformIndexBuffer);
gl.bufferData(gl.ARRAY_BUFFER, renderData.transformIndices, gl.STATIC_DRAW);
gl.activeTexture(gl.TEXTURE1);
gl.bindTexture(gl.TEXTURE_2D, transformsTexture);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
gl.texImage2D(
gl.TEXTURE_2D,
0,
gl.RGBA32F,
renderData.transformsWidth,
renderData.transformsHeight,
0,
gl.RGBA,
gl.FLOAT,
renderData.transforms,
);
gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer);
gl.vertexAttribPointer(positionAttribute, 2, gl.FLOAT, false, 0, 0);
gl.activeTexture(gl.TEXTURE0);
gl.bindTexture(gl.TEXTURE_2D, texture);
gl.drawArraysInstanced(gl.TRIANGLE_FAN, 0, 4, renderData.vertexCount);
}
Expand Down
4 changes: 2 additions & 2 deletions src/renderers/webgl/passes/FadeInPass.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ class FadeInPass implements ShaderPass {
active = true;
activeRenderer = renderer;

u_useDepthFade = renderer.gl.getUniformLocation(renderer.program, "u_useDepthFade") as WebGLUniformLocation;
u_useDepthFade = renderer.gl.getUniformLocation(renderer.program, "useDepthFade") as WebGLUniformLocation;
activeRenderer.gl.uniform1i(u_useDepthFade, 1);

u_depthFade = renderer.gl.getUniformLocation(renderer.program, "u_depthFade") as WebGLUniformLocation;
u_depthFade = renderer.gl.getUniformLocation(renderer.program, "depthFade") as WebGLUniformLocation;
activeRenderer.gl.uniform1f(u_depthFade, value);
};

Expand Down
25 changes: 17 additions & 8 deletions src/renderers/webgl/shaders/vertex.glsl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,30 @@ uniform mat4 projection, view;
uniform vec2 focal;
uniform vec2 viewport;
uniform mat4 transform;
uniform sampler2D u_transforms;
uniform bool u_useDepthFade;
uniform float u_depthFade;
uniform bool useDepthFade;
uniform float depthFade;
in vec2 position;
in int index;
in int transformIndex;
out vec4 vColor;
out vec2 vPosition;
void main () {
int matrixRow = transformIndex * 4;
mat4 transformationMatrix = mat4(
texelFetch(u_transforms, ivec2(0, matrixRow), 0),
texelFetch(u_transforms, ivec2(0, matrixRow + 1), 0),
texelFetch(u_transforms, ivec2(0, matrixRow + 2), 0),
texelFetch(u_transforms, ivec2(0, matrixRow + 3), 0)
);
uvec4 cen = texelFetch(u_texture, ivec2((uint(index) & 0x3ffu) << 1, uint(index) >> 10), 0);
vec4 cam = view * transform * vec4(uintBitsToFloat(cen.xyz), 1);
vec4 worldPosition = transformationMatrix * vec4(uintBitsToFloat(cen.xyz), 1);
vec4 cam = view * worldPosition;
vec4 pos2d = projection * cam;
float clip = 1.2 * pos2d.w;
Expand All @@ -33,7 +43,6 @@ void main () {
return;
}
uvec4 cov = texelFetch(u_texture, ivec2(((uint(index) & 0x3ffu) << 1) | 1u, uint(index) >> 10), 0);
vec2 u1 = unpackHalf2x16(cov.x), u2 = unpackHalf2x16(cov.y), u3 = unpackHalf2x16(cov.z);
mat3 Vrk = mat3(u1.x, u1.y, u2.x, u1.y, u2.y, u3.x, u2.x, u3.x, u3.y);
Expand All @@ -44,7 +53,7 @@ void main () {
0., 0., 0.
);
mat3 T = transpose(mat3(view)) * J;
mat3 T = transpose(mat3(view * transformationMatrix)) * J;
mat3 cov2d = transpose(T) * Vrk * T;
float mid = (cov2d[0][0] + cov2d[1][1]) / 2.0;
Expand All @@ -61,13 +70,13 @@ void main () {
float scalingFactor = 1.0;
if(u_useDepthFade) {
if(useDepthFade) {
float depthNorm = (pos2d.z / pos2d.w + 1.0) / 2.0;
float near = 0.1; float far = 100.0;
float normalizedDepth = (2.0 * near) / (far + near - depthNorm * (far - near));
float start = max(normalizedDepth - 0.1, 0.0);
float end = min(normalizedDepth + 0.1, 1.0);
scalingFactor = clamp((u_depthFade - start) / (end - start), 0.0, 1.0);
scalingFactor = clamp((depthFade - start) / (end - start), 0.0, 1.0);
}
vec2 vCenter = vec2(pos2d) / pos2d.w;
Expand Down
49 changes: 34 additions & 15 deletions src/renderers/webgl/utils/RenderData.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,12 @@ import { Quaternion } from "../../../math/Quaternion";
import { Vector3 } from "../../../math/Vector3";

class RenderData {
private _indices: Map<Splat, number> = new Map<Splat, number>();
private _offsets: Map<Splat, number>;
private _objectIndices: Map<Splat, number>;
private _transformIndices: Uint32Array;
private _transforms: Float32Array;
private _transformsWidth: number;
private _transformsHeight: number;
private _buffer: Uint32Array;
private _width: number;
private _height: number;
Expand Down Expand Up @@ -55,18 +59,21 @@ class RenderData {
};

let vertexCount = 0;
this._objectIndices = new Map<Splat, number>();
this._offsets = new Map<Splat, number>();
for (const object of scene.objects) {
if (object instanceof Splat) {
this._indices.set(object, vertexCount);
this._objectIndices.set(object, this._objectIndices.size);
this._offsets.set(object, vertexCount);
vertexCount += object.data.vertexCount;
this._dirty.add(object);
}
}

const transformsTextureWidth = 4;
const transformsTextureHeight = this._indices.size * 4;
this._transforms = new Float32Array(transformsTextureWidth * transformsTextureHeight);
for (const [splat, index] of this._indices) {
this._transformsWidth = 1;
this._transformsHeight = this._offsets.size * 4;
this._transforms = new Float32Array(this._transformsWidth * this._transformsHeight * 4);
for (const [splat, index] of this._objectIndices) {
const transform = splat.matrix.buffer;
for (let j = 0; j < 16; j++) {
this._transforms[16 * index + j] = transform[j];
Expand All @@ -78,20 +85,24 @@ class RenderData {
this._height = Math.ceil((2 * this.vertexCount) / this.width);
this._buffer = new Uint32Array(this.width * this.height * 4);
this._positions = new Float32Array(this.vertexCount * 3);
this._transformIndices = new Uint32Array(this.vertexCount);

const data_f = new Float32Array(this._buffer.buffer);
const data_c = new Uint8Array(this._buffer.buffer);

const build = (splat: Splat) => {
const data = splat.data;
const positions = splat.worldPositions;
const rotations = splat.worldRotations;
const scales = splat.worldScales;
const positions = data.positions;
const rotations = data.rotations;
const scales = data.scales;
const colors = data.colors;
for (let i = 0; i < splat.data.vertexCount; i++) {
const offset = this._indices.get(splat) as number;
const offset = this._offsets.get(splat) as number;
const index = offset + i;

const objectIndex = this._objectIndices.get(splat) as number;
this._transformIndices[index] = objectIndex;

data_f[8 * index + 0] = positions[3 * i + 0];
data_f[8 * index + 1] = positions[3 * i + 1];
data_f[8 * index + 2] = positions[3 * i + 2];
Expand Down Expand Up @@ -134,7 +145,7 @@ class RenderData {
};

this.markDirty = (splat: Splat) => {
const index = this._indices.get(splat) as number;
const index = this._objectIndices.get(splat) as number;
const transform = splat.matrix.buffer;
for (let j = 0; j < 16; j++) {
this._transforms[16 * index + j] = transform[j];
Expand All @@ -156,14 +167,22 @@ class RenderData {
this.rebuild();
}

get indices() {
return this._indices;
}

get transforms() {
return this._transforms;
}

get transformsWidth() {
return this._transformsWidth;
}

get transformsHeight() {
return this._transformsHeight;
}

get transformIndices() {
return this._transformIndices;
}

get buffer() {
return this._buffer;
}
Expand Down
5 changes: 4 additions & 1 deletion src/renderers/webgl/utils/SortWorker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ async function initWasm() {
wasmModule = await loadWasm();
}

let sortData: { positions: Float32Array; vertexCount: number };
let sortData: {
positions: Float32Array;
vertexCount: number;
};
let allocatedVertexCount: number;
let viewProj: Matrix4;
let sortRunning = false;
Expand Down
52 changes: 0 additions & 52 deletions src/splats/Splat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ import { SplatData } from "./SplatData";
import { Object3D } from "../core/Object3D";
import { Vector3 } from "../math/Vector3";
import { Quaternion } from "../math/Quaternion";
import { Matrix3 } from "../math/Matrix3";

class Splat extends Object3D {
private _data: SplatData;
Expand Down Expand Up @@ -36,57 +35,6 @@ class Splat extends Object3D {
get dirty() {
return this._dirty;
}

get worldPositions() {
const result = new Float32Array(this.data.vertexCount * 3);
const R = Matrix3.RotationFromQuaternion(this.rotation).buffer;
for (let i = 0; i < this.data.vertexCount; i++) {
const x = this.data.positions[3 * i + 0];
const y = this.data.positions[3 * i + 1];
const z = this.data.positions[3 * i + 2];

result[3 * i + 0] = R[0] * x + R[1] * y + R[2] * z;
result[3 * i + 1] = R[3] * x + R[4] * y + R[5] * z;
result[3 * i + 2] = R[6] * x + R[7] * y + R[8] * z;

result[3 * i + 0] = result[3 * i + 0] * this.scale.x + this.position.x;
result[3 * i + 1] = result[3 * i + 1] * this.scale.y + this.position.y;
result[3 * i + 2] = result[3 * i + 2] * this.scale.z + this.position.z;
}

return result;
}

get worldRotations() {
const result = new Float32Array(this.data.vertexCount * 4);
for (let i = 0; i < this.data.vertexCount; i++) {
const currentRotation = new Quaternion(
this.data.rotations[4 * i + 1],
this.data.rotations[4 * i + 2],
this.data.rotations[4 * i + 3],
this.data.rotations[4 * i + 0],
);

const newRot = this.rotation.multiply(currentRotation);
result[4 * i + 1] = newRot.x;
result[4 * i + 2] = newRot.y;
result[4 * i + 3] = newRot.z;
result[4 * i + 0] = newRot.w;
}

return result;
}

get worldScales() {
const result = new Float32Array(this.data.vertexCount * 3);
for (let i = 0; i < this.data.vertexCount; i++) {
result[3 * i + 0] = this.scale.x * this.data.scales[3 * i + 0];
result[3 * i + 1] = this.scale.y * this.data.scales[3 * i + 1];
result[3 * i + 2] = this.scale.z * this.data.scales[3 * i + 2];
}

return result;
}
}

export { Splat };

0 comments on commit 4153291

Please sign in to comment.