Skip to content

Commit

Permalink
temp - testing SPIR-V codegen fix
Browse files Browse the repository at this point in the history
  • Loading branch information
jrprice committed Jun 4, 2024
1 parent 2678ab3 commit 6cf71c1
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 53 deletions.
134 changes: 81 additions & 53 deletions src/nbody.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { mat4, vec3 } from 'gl-matrix'
import shaders from './shaders.wgsl'
import { mat4, vec3 } from "gl-matrix";
import shaders from "./shaders.wgsl";

// Simulation parameters.
let numBodies;
Expand All @@ -26,25 +26,26 @@ let positionsOut: GPUBuffer = null;
let velocities: GPUBuffer = null;
let renderParams: GPUBuffer = null;
let computeBindGroup: GPUBindGroup = null;
let othergroup: GPUBindGroup = null;
let renderBindGroup: GPUBindGroup = null;

const init = async () => {
// Initialize the WebGPU device.
const powerPref = <HTMLSelectElement>document.getElementById('powerpref');
const powerPref = <HTMLSelectElement>document.getElementById("powerpref");
const adapter = await navigator.gpu.requestAdapter({
powerPreference: <GPUPowerPreference>powerPref.selectedOptions[0].value,
});
device = await adapter.requestDevice()
device = await adapter.requestDevice();
queue = device.queue;

// Set up the canvas context.
canvas = <HTMLCanvasElement>document.getElementById('canvas');
canvasContext = canvas.getContext('webgpu');
}
canvas = <HTMLCanvasElement>document.getElementById("canvas");
canvasContext = canvas.getContext("webgpu");
};

// Generate WGSL shader source.
function getShaders() {
let preamble = ''
let preamble = "";
preamble += `const kWorkgroupSize = ${workgroupSize};\n`;
preamble += `const kNumBodies = ${numBodies};\n`;
return preamble + shaders;
Expand All @@ -69,8 +70,7 @@ const updateRenderParams = async () => {
// Generate the view projection matrix.
let projectionMatrix = mat4.create();
let viewProjectionMatrix = mat4.create();
mat4.perspectiveZO(projectionMatrix,
1.0, canvas.width / canvas.height, 0.1, 50.0);
mat4.perspectiveZO(projectionMatrix, 1.0, canvas.width / canvas.height, 0.1, 50.0);
mat4.translate(viewProjectionMatrix, viewProjectionMatrix, eyePosition);
mat4.multiply(viewProjectionMatrix, projectionMatrix, viewProjectionMatrix);

Expand All @@ -79,7 +79,7 @@ const updateRenderParams = async () => {
let viewProjectionMatrixHost = new Float32Array(renderParamsHost);
viewProjectionMatrixHost.set(viewProjectionMatrix);
queue.writeBuffer(renderParams, 0, renderParamsHost);
}
};

function initPipelines() {
// Reset pipelines.
Expand All @@ -104,42 +104,44 @@ function initPipelines() {
const positionsAttribute: GPUVertexAttribute = {
shaderLocation: 0,
offset: 0,
format: 'float32x4',
format: "float32x4",
};
const positionsLayout: GPUVertexBufferLayout = {
attributes: [positionsAttribute],
arrayStride: 4 * 4,
stepMode: 'instance',
stepMode: "instance",
};
renderPipeline = device.createRenderPipeline({
vertex: {
module: module,
entryPoint: 'vs_main',
entryPoint: "vs_main",
buffers: [positionsLayout],
},
fragment: {
module: module,
entryPoint: 'fs_main',
targets: [{
format: navigator.gpu.getPreferredCanvasFormat(),
blend: {
color: {
operation: "add",
srcFactor: "one",
dstFactor: "one",
},
alpha: {
operation: "add",
srcFactor: "one",
dstFactor: "one",
entryPoint: "fs_main",
targets: [
{
format: navigator.gpu.getPreferredCanvasFormat(),
blend: {
color: {
operation: "add",
srcFactor: "one",
dstFactor: "one",
},
alpha: {
operation: "add",
srcFactor: "one",
dstFactor: "one",
},
},
}
}],
},
],
},
primitive: {
frontFace: 'cw',
cullMode: 'none',
topology: 'triangle-list',
frontFace: "cw",
cullMode: "none",
topology: "triangle-list",
},
layout: "auto",
});
Expand All @@ -148,9 +150,9 @@ function initPipelines() {
computePipeline = device.createComputePipeline({
compute: {
module: module,
entryPoint: 'cs_main',
entryPoint: "cs_main",
},
layout: "auto"
layout: "auto",
});
}

Expand All @@ -159,25 +161,25 @@ function initBodies() {
positionsIn = device.createBuffer({
size: numBodies * 4 * 4,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.VERTEX,
mappedAtCreation: true
mappedAtCreation: true,
});
positionsOut = device.createBuffer({
size: numBodies * 4 * 4,
usage: GPUBufferUsage.STORAGE | GPUBufferUsage.VERTEX,
mappedAtCreation: false
mappedAtCreation: false,
});
velocities = device.createBuffer({
size: numBodies * 4 * 4,
usage: GPUBufferUsage.STORAGE,
mappedAtCreation: false
mappedAtCreation: false,
});

// Generate initial positions on the surface of a sphere.
const kRadius = 0.6;
let positions = new Float32Array(positionsIn.getMappedRange());
for (let i = 0; i < numBodies; i++) {
let longitude = 2.0 * Math.PI * Math.random();
let latitude = Math.acos((2.0 * Math.random() - 1.0));
let latitude = Math.acos(2.0 * Math.random() - 1.0);
positions[i * 4 + 0] = kRadius * Math.sin(latitude) * Math.cos(longitude);
positions[i * 4 + 1] = kRadius * Math.sin(latitude) * Math.sin(longitude);
positions[i * 4 + 2] = kRadius * Math.cos(latitude);
Expand All @@ -203,7 +205,7 @@ function draw() {
const timeSinceLastLog = now - lastFpsUpdateTime;
if (timeSinceLastLog >= kFpsUpdateInterval) {
const fps = numFramesSinceFpsUpdate / (timeSinceLastLog / 1000.0);
document.getElementById("fps").innerHTML = fps.toFixed(1) + ' FPS';
document.getElementById("fps").innerHTML = fps.toFixed(1) + " FPS";
lastFpsUpdateTime = performance.now();
numFramesSinceFpsUpdate = 0;
}
Expand All @@ -215,9 +217,9 @@ function draw() {
// Update render parameters based on key presses.
if (currentKey) {
let zInc = 0.025;
if (currentKey.key == 'ArrowUp') {
if (currentKey.key == "ArrowUp") {
eyePosition[2] += zInc;
} else if (currentKey.key == 'ArrowDown') {
} else if (currentKey.key == "ArrowDown") {
eyePosition[2] -= zInc;
}
updateRenderParams();
Expand Down Expand Up @@ -250,6 +252,31 @@ function draw() {
],
});

let image_a = device.createTexture({
format: "r32uint",
size: [16, 16],
usage: GPUTextureUsage.STORAGE_BINDING,
});
let image_b = device.createTexture({
format: "r32uint",
size: [16, 16],
usage: GPUTextureUsage.STORAGE_BINDING,
});

othergroup = device.createBindGroup({
layout: computePipeline.getBindGroupLayout(1),
entries: [
{
binding: 0,
resource: image_a.createView(),
},
{
binding: 1,
resource: image_b.createView(),
},
],
});

// Create the bind group for the compute shader.
renderBindGroup = device.createBindGroup({
layout: renderPipeline.getBindGroupLayout(0),
Expand All @@ -268,6 +295,7 @@ function draw() {
const computePassEncoder = commandEncoder.beginComputePass();
computePassEncoder.setPipeline(computePipeline);
computePassEncoder.setBindGroup(0, computeBindGroup);
computePassEncoder.setBindGroup(1, othergroup);
computePassEncoder.dispatchWorkgroups(numBodies / workgroupSize);
computePassEncoder.end();

Expand All @@ -282,7 +310,7 @@ function draw() {
view: colorTextureView,
loadOp: "clear",
clearValue: { r: 0, g: 0, b: 0.1, a: 1 },
storeOp: 'store'
storeOp: "store",
};
const renderPassEncoder = commandEncoder.beginRenderPass({
colorAttachments: [colorAttachment],
Expand Down Expand Up @@ -317,43 +345,43 @@ const reset = async () => {
initPipelines();

paused = false;
}
};

function pause() {
paused = !paused;
document.getElementById("pause").innerText = paused ? 'Unpause' : 'Pause';
document.getElementById("pause").innerText = paused ? "Unpause" : "Pause";
}

reset();
draw();

// Set up button onclick handlers.
document.querySelector('#reset').addEventListener('click', reset);
document.querySelector('#pause').addEventListener('click', pause);
document.querySelector("#reset").addEventListener("click", reset);
document.querySelector("#pause").addEventListener("click", pause);

// Automatically reset when the number of bodies is changed.
document.querySelector('#numbodies').addEventListener('change', reset);
document.querySelector("#numbodies").addEventListener("change", reset);

// Automatically reset when the power preference is changed.
document.querySelector('#powerpref').addEventListener('change', () => {
document.querySelector("#powerpref").addEventListener("change", () => {
device = null;
computePipeline = null;
reset();
});

// Automatically rebuild the pipelines when the workgroup size is changed.
document.querySelector('#wgsize').addEventListener('change', initPipelines);
document.querySelector("#wgsize").addEventListener("change", initPipelines);

// Add an event handler to update render parameters when the window is resized.
window.addEventListener('resize', updateRenderParams);
window.addEventListener("resize", updateRenderParams);

// Handle key presses for user controls.
document.addEventListener('keydown', (e: KeyboardEvent) => {
if (e.key == ' ') {
document.addEventListener("keydown", (e: KeyboardEvent) => {
if (e.key == " ") {
pause();
}
currentKey = e;
});
document.addEventListener('keyup', (e: KeyboardEvent) => {
document.addEventListener("keyup", (e: KeyboardEvent) => {
currentKey = null;
});
6 changes: 6 additions & 0 deletions src/shaders.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ fn computeForce(ipos : vec4<f32>,
return coeff * d;
}

@group(1) @binding(0) var image_a: texture_storage_2d<r32uint,read>;
@group(1) @binding(1) var image_b: texture_storage_2d<r32uint,write>;

@compute @workgroup_size(kWorkgroupSize)
fn cs_main(
@builtin(global_invocation_id) gid : vec3<u32>,
Expand All @@ -39,6 +42,9 @@ fn cs_main(
velocity = velocity + force * kDelta;
velocities[idx] = velocity;

_ = image_a;
_ = image_b;

// Update position.
positionsOut[idx] = pos + velocity * kDelta;
}
Expand Down

0 comments on commit 6cf71c1

Please sign in to comment.