Skip to content

Commit

Permalink
removed unused chunks
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanebert committed Mar 22, 2024
1 parent c17d5b1 commit 3047d7b
Show file tree
Hide file tree
Showing 10 changed files with 263 additions and 111 deletions.
86 changes: 11 additions & 75 deletions examples/4d/src/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ const scene = new SPLAT.Scene();
const camera = new SPLAT.Camera();
const controls = new SPLAT.OrbitControls(camera, canvas);

async function loadSplatV(url: string) {
async function loadSplatv(url: string) {
const req = await fetch(url, {
mode: "cors",
credentials: "omit",
Expand All @@ -19,93 +19,30 @@ async function loadSplatV(url: string) {
throw new Error(req.status + " Unable to load " + req.url);
}

const halfToFloat = (half: number): number => {
const sign = (half & 0x8000) >> 15;
const exp = (half & 0x7c00) >> 10;
const frac = half & 0x03ff;

if (exp === 0) {
return (sign ? -1 : 1) * Math.pow(2, -14) * (frac / 1024);
} else if (exp === 0x1f) {
return frac ? NaN : (sign ? -1 : 1) * Infinity;
}

return (sign ? -1 : 1) * Math.pow(2, exp - 15) * (1 + frac / 1024);
};

const unpackHalf2x16 = (value: number) => {
const h1 = value & 0xffff;
const h2 = (value >> 16) & 0xffff;
return [halfToFloat(h1), halfToFloat(h2)];
};

const processSplatBuffer = (buffer: Uint32Array) => {
const rowLength = 64;
const vertexCount = buffer.byteLength / rowLength;
const positions = new Float32Array(vertexCount * 3);
const rotations = new Float32Array(vertexCount * 4);
const scales = new Float32Array(vertexCount * 3);
const colors = new Uint8Array(vertexCount * 4);

const f_buffer = new Float32Array(buffer.buffer);
const u_buffer = new Uint8Array(buffer.buffer);

for (let i = 0; i < vertexCount; i++) {
positions[3 * i + 0] = f_buffer[16 * i + 0];
positions[3 * i + 1] = f_buffer[16 * i + 1];
positions[3 * i + 2] = f_buffer[16 * i + 2];

const rotXY = buffer[16 * i + 3]; // Half packed 2x16
const rotZW = buffer[16 * i + 4]; // Half packed 2x16
const scaleXY = buffer[16 * i + 5]; // Half packed 2x16
const scaleZ_ = buffer[16 * i + 6]; // Half packed 2x16

const [rotX, rotY] = unpackHalf2x16(rotXY);
const [rotZ, rotW] = unpackHalf2x16(rotZW);
const [scaleX, scaleY] = unpackHalf2x16(scaleXY);
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const [scaleZ, _] = unpackHalf2x16(scaleZ_);

rotations[4 * i + 0] = rotX;
rotations[4 * i + 1] = rotY;
rotations[4 * i + 2] = rotZ;
rotations[4 * i + 3] = rotW;

scales[3 * i + 0] = scaleX;
scales[3 * i + 1] = scaleY;
scales[3 * i + 2] = scaleZ;

colors[4 * i + 0] = u_buffer[64 * i + 28 + 0];
colors[4 * i + 1] = u_buffer[64 * i + 28 + 1];
colors[4 * i + 2] = u_buffer[64 * i + 28 + 2];
colors[4 * i + 3] = u_buffer[64 * i + 28 + 3];
}

const data = new SPLAT.SplatData(vertexCount, positions, rotations, scales, colors);
const splat = new SPLAT.Splat(data);
scene.addObject(splat);
};

const handleChunk = (
chunk: { size: number; type: string },
buffer: Uint8Array,
chunks: { size: number; type: string }[],
remaining: number,
) => {
if (remaining) return;

if (chunk.type === "magic") {
if (!remaining && chunk.type === "magic") {
const intView = new Int32Array(buffer.buffer);
if (intView[0] !== 0x674b) {
throw new Error("Invalid splatv file");
}
chunks.push({ size: intView[1], type: "chunks" });
} else if (chunk.type === "chunks") {
} else if (!remaining && chunk.type === "chunks") {
for (const chunk of JSON.parse(new TextDecoder("utf-8").decode(buffer))) {
chunks.push(chunk);
}
} else if (chunk.type === "splat") {
processSplatBuffer(new Uint32Array(buffer.buffer));
if (remaining) {
progressIndicator.value = 100 - (100 * remaining) / chunk.size;
} else {
const data = SPLAT.SplatvData.Deserialize(buffer);
const splatv = new SPLAT.Splatv(data);
scene.addObject(splatv);
}
}
};

Expand Down Expand Up @@ -140,8 +77,7 @@ async function loadSplatV(url: string) {

async function main() {
const url = "https://huggingface.co/cakewalk/splat-data/resolve/main/flame.splatv";
loadSplatV(url);
// await SPLAT.Loader.LoadAsync(url, scene, (progress) => (progressIndicator.value = progress * 100));
await loadSplatv(url);
progressDialog.close();

const handleResize = () => {
Expand Down
2 changes: 2 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
export { Object3D } from "./core/Object3D";
export { SplatData } from "./splats/SplatData";
export { SplatvData } from "./splats/SplatvData";
export { Splat } from "./splats/Splat";
export { Splatv } from "./splats/Splatv";
export { CameraData } from "./cameras/CameraData";
export { Camera } from "./cameras/Camera";
export { Scene } from "./core/Scene";
Expand Down
8 changes: 1 addition & 7 deletions src/renderers/webgl/programs/RenderProgram.ts
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ class RenderProgram extends ShaderProgram {
private _outlineColor: Color32 = new Color32(255, 165, 0, 255);
private _renderData: RenderData | null = null;
private _depthIndex: Uint32Array = new Uint32Array();
private _chunks: Uint8Array | null = null;
private _splatTexture: WebGLTexture | null = null;

protected _initialize: () => void;
Expand Down Expand Up @@ -221,9 +220,8 @@ class RenderProgram extends ShaderProgram {
worker = new SortWorker();
worker.onmessage = (e) => {
if (e.data.depthIndex) {
const { depthIndex, chunks } = e.data;
const { depthIndex } = e.data;
this._depthIndex = depthIndex;
this._chunks = chunks;
gl.bindBuffer(gl.ARRAY_BUFFER, indexBuffer);
gl.bufferData(gl.ARRAY_BUFFER, depthIndex, gl.STATIC_DRAW);
}
Expand Down Expand Up @@ -532,10 +530,6 @@ class RenderProgram extends ShaderProgram {
return this._depthIndex;
}

get chunks() {
return this._chunks;
}

get splatTexture() {
return this._splatTexture;
}
Expand Down
181 changes: 181 additions & 0 deletions src/renderers/webgl/programs/VideoRenderProgram.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
import SortWorker from "web-worker:../utils/SortWorker.ts";

import { SplatvData } from "../../../index";
import { WebGLRenderer } from "../../WebGLRenderer";
import { ShaderPass } from "../passes/ShaderPass";
import { ShaderProgram } from "./ShaderProgram";

const vertexShaderSource = /* glsl */ `#version 300 es
precision highp float;
precision highp int;
uniform highp usampler2D u_texture;
uniform mat4 projection, view;
uniform vec2 focal;
uniform vec2 viewport;
uniform float time;
in vec2 position;
in int index;
out vec4 vColor;
out vec2 vPosition;
void main () {
gl_Position = vec4(0.0, 0.0, 2.0, 1.0);
uvec4 motion1 = texelFetch(u_texture, ivec2(((uint(index) & 0x3ffu) << 2) | 3u, uint(index) >> 10), 0);
vec2 trbf = unpackHalf2x16(motion1.w);
float dt = time - trbf.x;
float topacity = exp(-1.0 * pow(dt / trbf.y, 2.0));
if(topacity < 0.02) return;
uvec4 motion0 = texelFetch(u_texture, ivec2(((uint(index) & 0x3ffu) << 2) | 2u, uint(index) >> 10), 0);
uvec4 static0 = texelFetch(u_texture, ivec2(((uint(index) & 0x3ffu) << 2), uint(index) >> 10), 0);
vec2 m0 = unpackHalf2x16(motion0.x), m1 = unpackHalf2x16(motion0.y), m2 = unpackHalf2x16(motion0.z),
m3 = unpackHalf2x16(motion0.w), m4 = unpackHalf2x16(motion1.x);
vec4 trot = vec4(unpackHalf2x16(motion1.y).xy, unpackHalf2x16(motion1.z).xy) * dt;
vec3 tpos = (vec3(m0.xy, m1.x) * dt + vec3(m1.y, m2.xy) * dt*dt + vec3(m3.xy, m4.x) * dt*dt*dt);
vec4 cam = view * vec4(uintBitsToFloat(static0.xyz) + tpos, 1);
vec4 pos = projection * cam;
float clip = 1.2 * pos.w;
if (pos.z < -clip || pos.x < -clip || pos.x > clip || pos.y < -clip || pos.y > clip) return;
uvec4 static1 = texelFetch(u_texture, ivec2(((uint(index) & 0x3ffu) << 2) | 1u, uint(index) >> 10), 0);
vec4 rot = vec4(unpackHalf2x16(static0.w).xy, unpackHalf2x16(static1.x).xy) + trot;
vec3 scale = vec3(unpackHalf2x16(static1.y).xy, unpackHalf2x16(static1.z).x);
rot /= sqrt(dot(rot, rot));
mat3 S = mat3(scale.x, 0.0, 0.0, 0.0, scale.y, 0.0, 0.0, 0.0, scale.z);
mat3 R = mat3(
1.0 - 2.0 * (rot.z * rot.z + rot.w * rot.w), 2.0 * (rot.y * rot.z - rot.x * rot.w), 2.0 * (rot.y * rot.w + rot.x * rot.z),
2.0 * (rot.y * rot.z + rot.x * rot.w), 1.0 - 2.0 * (rot.y * rot.y + rot.w * rot.w), 2.0 * (rot.z * rot.w - rot.x * rot.y),
2.0 * (rot.y * rot.w - rot.x * rot.z), 2.0 * (rot.z * rot.w + rot.x * rot.y), 1.0 - 2.0 * (rot.y * rot.y + rot.z * rot.z));
mat3 M = S * R;
mat3 Vrk = 4.0 * transpose(M) * M;
mat3 J = mat3(
focal.x / cam.z, 0., -(focal.x * cam.x) / (cam.z * cam.z),
0., -focal.y / cam.z, (focal.y * cam.y) / (cam.z * cam.z),
0., 0., 0.
);
mat3 T = transpose(mat3(view)) * J;
mat3 cov2d = transpose(T) * Vrk * T;
float mid = (cov2d[0][0] + cov2d[1][1]) / 2.0;
float radius = length(vec2((cov2d[0][0] - cov2d[1][1]) / 2.0, cov2d[0][1]));
float lambda1 = mid + radius, lambda2 = mid - radius;
if(lambda2 < 0.0) return;
vec2 diagonalVector = normalize(vec2(cov2d[0][1], lambda1 - cov2d[0][0]));
vec2 majorAxis = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector;
vec2 minorAxis = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x);
uint rgba = static1.w;
vColor =
clamp(pos.z/pos.w+1.0, 0.0, 1.0) *
vec4(1.0, 1.0, 1.0, topacity) *
vec4(
(rgba) & 0xffu,
(rgba >> 8) & 0xffu,
(rgba >> 16) & 0xffu,
(rgba >> 24) & 0xffu) / 255.0;
vec2 vCenter = vec2(pos) / pos.w;
gl_Position = vec4(
vCenter
+ position.x * majorAxis / viewport
+ position.y * minorAxis / viewport, 0.0, 1.0);
vPosition = position;
}
`;

const fragmentShaderSource = /* glsl */ `#version 300 es
precision highp float;
in vec4 vColor;
in vec2 vPosition;
out vec4 fragColor;
void main () {
float A = -dot(vPosition, vPosition);
if (A < -4.0) discard;
float B = exp(A) * vColor.a;
fragColor = vec4(B * vColor.rgb, B);
}
`;

class VideoRenderProgram extends ShaderProgram {
private _renderData: SplatvData | null = null;
private _depthIndex: Uint32Array = new Uint32Array();
private _splatTexture: WebGLTexture | null = null;

protected _initialize: () => void;
protected _resize: () => void;
protected _render: () => void;
protected _dispose: () => void;

constructor(renderer: WebGLRenderer, passes: ShaderPass[]) {
super(renderer, passes);

const canvas = renderer.canvas;
const gl = renderer.gl;

let worker: Worker;

let u_projection: WebGLUniformLocation;
let u_viewport: WebGLUniformLocation;
let u_focal: WebGLUniformLocation;
let u_view: WebGLUniformLocation;
let u_texture: WebGLUniformLocation;
let u_time: WebGLUniformLocation;

let positionAttribute: number;
let indexAttribute: number;

let vertexBuffer: WebGLBuffer;
let indexBuffer: WebGLBuffer;

this._resize = () => {
if (!this._camera) return;

this._camera.data.setSize(canvas.width, canvas.height);
this._camera.update();

u_projection = gl.getUniformLocation(this.program, "projection") as WebGLUniformLocation;
gl.uniformMatrix4fv(u_projection, false, this._camera.data.projectionMatrix.buffer);

u_viewport = gl.getUniformLocation(this.program, "viewport") as WebGLUniformLocation;
gl.uniform2fv(u_viewport, new Float32Array([canvas.width, canvas.height]));
};

const createWorker = () => {
worker = new SortWorker();
worker.onmessage = (e) => {
if (e.data.depthIndex) {
const { depthIndex } = e.data;
this._depthIndex = depthIndex;
gl.bindBuffer(gl.ARRAY_BUFFER, indexBuffer);
gl.bufferData(gl.ARRAY_BUFFER, depthIndex, gl.STATIC_DRAW);
}
};
};
}

protected _getVertexSource(): string {
return vertexShaderSource;
}

protected _getFragmentSource(): string {
return fragmentShaderSource;
}
}

export { VideoRenderProgram };
12 changes: 1 addition & 11 deletions src/renderers/webgl/utils/SortWorker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ let viewProjPtr: number;
let transformsPtr: number;
let transformIndicesPtr: number;
let positionsPtr: number;
let chunksPtr: number;
let depthBufferPtr: number;
let depthIndexPtr: number;
let startsPtr: number;
Expand Down Expand Up @@ -48,7 +47,6 @@ const allocateBuffers = async () => {
wasmModule._free(viewProjPtr);
wasmModule._free(transformIndicesPtr);
wasmModule._free(positionsPtr);
wasmModule._free(chunksPtr);
wasmModule._free(depthBufferPtr);
wasmModule._free(depthIndexPtr);
wasmModule._free(startsPtr);
Expand All @@ -60,7 +58,6 @@ const allocateBuffers = async () => {
viewProjPtr = wasmModule._malloc(16 * 4);
transformIndicesPtr = wasmModule._malloc(allocatedVertexCount * 4);
positionsPtr = wasmModule._malloc(3 * allocatedVertexCount * 4);
chunksPtr = wasmModule._malloc(allocatedVertexCount);
depthBufferPtr = wasmModule._malloc(allocatedVertexCount * 4);
depthIndexPtr = wasmModule._malloc(allocatedVertexCount * 4);
startsPtr = wasmModule._malloc(allocatedVertexCount * 4);
Expand Down Expand Up @@ -99,7 +96,6 @@ const runSort = () => {
transformIndicesPtr,
sortData.vertexCount,
positionsPtr,
chunksPtr,
depthBufferPtr,
depthIndexPtr,
startsPtr,
Expand All @@ -109,13 +105,7 @@ const runSort = () => {
const depthIndex = new Uint32Array(wasmModule.HEAPU32.buffer, depthIndexPtr, sortData.vertexCount);
const detachedDepthIndex = new Uint32Array(depthIndex.slice().buffer);

const chunks = new Uint8Array(wasmModule.HEAPU8.buffer, chunksPtr, sortData.vertexCount);
const detachedChunks = new Uint8Array(chunks.slice().buffer);

self.postMessage({ depthIndex: detachedDepthIndex, chunks: detachedChunks }, [
detachedDepthIndex.buffer,
detachedChunks.buffer,
]);
self.postMessage({ depthIndex: detachedDepthIndex }, [detachedDepthIndex.buffer]);

lock = false;
};
Expand Down
Loading

0 comments on commit 3047d7b

Please sign in to comment.