Skip to content

Commit

Permalink
chore: Update to TypeGPU 0.2
Browse files Browse the repository at this point in the history
  • Loading branch information
iwoplaza committed Nov 4, 2024
1 parent 68ff192 commit d249abf
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 114 deletions.
2 changes: 1 addition & 1 deletion apps/example/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"react-native-wgpu": "*",
"teapot": "^1.0.0",
"three": "0.168.0",
"typegpu": "^0.1.2",
"typegpu": "^0.2.0",
"wgpu-matrix": "^3.0.2"
},
"devDependencies": {
Expand Down
129 changes: 56 additions & 73 deletions apps/example/src/ComputeBoids/ComputeBoids.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,33 @@ type BoidsOptions = {
cohesionStrength: number;
};

const Parameters = struct({
separationDistance: f32,
separationStrength: f32,
alignmentDistance: f32,
alignmentStrength: f32,
cohesionDistance: f32,
cohesionStrength: f32,
});

const TriangleData = struct({
position: vec2f,
velocity: vec2f,
});

const TriangleDataArray = (n: number) => arrayOf(TriangleData, n);

const renderBindGroupLayout = tgpu.bindGroupLayout({
trianglePos: { storage: TriangleDataArray },
colorPalette: { uniform: vec3f },
});

const computeBindGroupLayout = tgpu.bindGroupLayout({
currentTrianglePos: { storage: TriangleDataArray },
nextTrianglePos: { storage: TriangleDataArray, access: 'mutable' },
params: { uniform: Parameters },
});

const colorPresets = {
plumTree: vec3f(1.0, 2.0, 1.0),
jeans: vec3f(2.0, 1.5, 1.0),
Expand Down Expand Up @@ -69,28 +96,20 @@ export function ComputeBoids() {
);

const ref = useWebGPU(({ context, device, presentationFormat }) => {
const root = tgpu.initFromDevice({ device });

context.configure({
device,
format: presentationFormat,
alphaMode: "premultiplied",
});

const params = struct({
separationDistance: f32,
separationStrength: f32,
alignmentDistance: f32,
alignmentStrength: f32,
cohesionDistance: f32,
cohesionStrength: f32,
});

const paramsBuffer = tgpu
.createBuffer(params, presets.default)
.$device(device)
.$usage(tgpu.Storage);
const paramsBuffer = root
.createBuffer(Parameters, presets.default)
.$usage("uniform");

const triangleSize = 0.03;
const triangleVertexBuffer = tgpu
const triangleVertexBuffer = root
.createBuffer(arrayOf(f32, 6), [
0.0,
triangleSize,
Expand All @@ -99,42 +118,33 @@ export function ComputeBoids() {
triangleSize / 2,
-triangleSize / 2,
])
.$device(device)
.$usage(tgpu.Vertex);
.$usage("vertex");

const triangleAmount = 1000;
const triangleInfoStruct = struct({
position: vec2f,
velocity: vec2f,
});
const trianglePosBuffers = Array.from({ length: 2 }, () =>
tgpu
.createBuffer(arrayOf(triangleInfoStruct, triangleAmount))
.$device(device)
.$usage(tgpu.Storage, tgpu.Uniform),
root.createBuffer(TriangleDataArray(triangleAmount)).$usage("storage")
);

randomizePositions.current = () => {
const positions = Array.from({ length: triangleAmount }, () => ({
position: vec2f(Math.random() * 2 - 1, Math.random() * 2 - 1),
velocity: vec2f(Math.random() * 0.1 - 0.05, Math.random() * 0.1 - 0.05),
}));
tgpu.write(trianglePosBuffers[0], positions);
tgpu.write(trianglePosBuffers[1], positions);
trianglePosBuffers[0].write(positions);
trianglePosBuffers[1].write(positions);
};
randomizePositions.current();

const colorPaletteBuffer = tgpu
const colorPaletteBuffer = root
.createBuffer(vec3f, colorPresets.plumTree)
.$device(device)
.$usage(tgpu.Uniform);
.$usage("uniform");

updateColorPreset.current = (newColorPreset: ColorPresets) => {
tgpu.write(colorPaletteBuffer, colorPresets[newColorPreset]);
colorPaletteBuffer.write(colorPresets[newColorPreset]);
};

updateParams.current = (newOptions: BoidsOptions) => {
tgpu.write(paramsBuffer, newOptions);
paramsBuffer.write(newOptions);
};

const renderModule = device.createShaderModule({
Expand All @@ -146,7 +156,9 @@ export function ComputeBoids() {
});

const pipeline = device.createRenderPipeline({
layout: "auto",
layout: device.createPipelineLayout({
bindGroupLayouts: [root.unwrap(renderBindGroupLayout)],
}),
vertex: {
module: renderModule,
buffers: [
Expand Down Expand Up @@ -176,55 +188,26 @@ export function ComputeBoids() {
});

const computePipeline = device.createComputePipeline({
layout: "auto",
layout: device.createPipelineLayout({
bindGroupLayouts: [root.unwrap(computeBindGroupLayout)],
}),
compute: {
module: computeModule,
},
});

const renderBindGroups = [0, 1].map((idx) =>
device.createBindGroup({
layout: pipeline.getBindGroupLayout(0),
entries: [
{
binding: 0,
resource: {
buffer: trianglePosBuffers[idx].buffer,
},
},
{
binding: 1,
resource: {
buffer: colorPaletteBuffer.buffer,
},
},
],
renderBindGroupLayout.populate({
trianglePos: trianglePosBuffers[idx],
colorPalette: colorPaletteBuffer,
}),
);

const computeBindGroups = [0, 1].map((idx) =>
device.createBindGroup({
layout: computePipeline.getBindGroupLayout(0),
entries: [
{
binding: 0,
resource: {
buffer: trianglePosBuffers[idx].buffer,
},
},
{
binding: 1,
resource: {
buffer: trianglePosBuffers[1 - idx].buffer,
},
},
{
binding: 2,
resource: {
buffer: paramsBuffer.buffer,
},
},
],
computeBindGroupLayout.populate({
currentTrianglePos: trianglePosBuffers[idx],
nextTrianglePos: trianglePosBuffers[1 - idx],
params: paramsBuffer,
}),
);

Expand All @@ -251,7 +234,7 @@ export function ComputeBoids() {
computePass.setPipeline(computePipeline);
computePass.setBindGroup(
0,
even ? computeBindGroups[0] : computeBindGroups[1],
root.unwrap(even ? computeBindGroups[0] : computeBindGroups[1])
);
computePass.dispatchWorkgroups(triangleAmount);
computePass.end();
Expand All @@ -261,7 +244,7 @@ export function ComputeBoids() {
passEncoder.setVertexBuffer(0, triangleVertexBuffer.buffer);
passEncoder.setBindGroup(
0,
even ? renderBindGroups[1] : renderBindGroups[0],
root.unwrap(even ? renderBindGroups[1] : renderBindGroups[0])
);
passEncoder.draw(3, triangleAmount);
passEncoder.end();
Expand Down
6 changes: 3 additions & 3 deletions apps/example/src/ComputeBoids/Shaders.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ export const renderCode = /* wgsl */ `
@location(1) color : vec4f,
};
@binding(0) @group(0) var<uniform> trianglePos : array<TriangleData, ${triangleAmount}>;
@binding(0) @group(0) var<storage> trianglePos : array<TriangleData>;
@binding(1) @group(0) var<uniform> colorPalette : vec3f;
@vertex
Expand Down Expand Up @@ -67,9 +67,9 @@ export const computeCode = /* wgsl */ `
cohesion_strength : f32,
};
@binding(0) @group(0) var<uniform> currentTrianglePos : array<TriangleData, ${triangleAmount}>;
@binding(0) @group(0) var<storage> currentTrianglePos : array<TriangleData>;
@binding(1) @group(0) var<storage, read_write> nextTrianglePos : array<TriangleData>;
@binding(2) @group(0) var<storage> params : Parameters;
@binding(2) @group(0) var<uniform> params : Parameters;
@compute @workgroup_size(1)
fn mainCompute(@builtin(global_invocation_id) gid: vec3u) {
Expand Down
74 changes: 42 additions & 32 deletions apps/example/src/GradientTiles/GradientTiles.tsx
Original file line number Diff line number Diff line change
@@ -1,29 +1,49 @@
import { useEffect, useState } from "react";
import { useEffect, useMemo, useState } from "react";
import { Button, PixelRatio, StyleSheet, Text, View } from "react-native";
import { Canvas, useDevice, useGPUContext } from "react-native-wgpu";
import { struct, u32 } from "typegpu/data";
import tgpu from "typegpu";
import tgpu, { type TgpuBindGroup, type TgpuBuffer } from "typegpu";

import { vertWGSL, fragWGSL } from "./gradientWgsl";
import { vertWGSL, fragWGSL } from './gradientWgsl';

const Span = struct({
x: u32,
y: u32,
});

const bindGroupLayout = tgpu.bindGroupLayout({
span: { uniform: Span },
});

interface RenderingState {
pipeline: GPURenderPipeline;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
spanBuffer: any;
bindGroup: GPUBindGroup;
spanBuffer: TgpuBuffer<typeof Span>;
bindGroup: TgpuBindGroup<(typeof bindGroupLayout)['entries']>;
}

function useRoot() {
const { device } = useDevice();

return useMemo(
() => (device ? tgpu.initFromDevice({ device }) : null),
[device]
);
}

export function GradientTiles() {
const presentationFormat = navigator.gpu.getPreferredCanvasFormat();
const [state, setState] = useState<null | RenderingState>(null);
const [spanX, setSpanX] = useState(4);
const [spanY, setSpanY] = useState(4);
const { device } = useDevice();
const root = useRoot();
const { device = null } = root ?? {};
const { ref, context } = useGPUContext();

useEffect(() => {
if (!device || !context || state !== null) {
if (!device || !root || !context || state !== null) {
return;
}

const canvas = context.canvas as HTMLCanvasElement;
canvas.width = canvas.clientWidth * PixelRatio.get();
canvas.height = canvas.clientHeight * PixelRatio.get();
Expand All @@ -32,18 +52,14 @@ export function GradientTiles() {
format: presentationFormat,
});

const Span = struct({
x: u32,
y: u32,
});

const spanBuffer = tgpu
const spanBuffer = root
.createBuffer(Span, { x: 10, y: 10 })
.$device(device)
.$usage(tgpu.Uniform);
.$usage("uniform");

const pipeline = device.createRenderPipeline({
layout: "auto",
layout: device.createPipelineLayout({
bindGroupLayouts: [root.unwrap(bindGroupLayout)],
}),
vertex: {
module: device.createShaderModule({
code: vertWGSL,
Expand All @@ -64,24 +80,18 @@ export function GradientTiles() {
},
});

const bindGroup = device.createBindGroup({
layout: pipeline.getBindGroupLayout(0),
entries: [
{
binding: 0,
resource: {
buffer: spanBuffer.buffer,
},
},
],
const bindGroup = bindGroupLayout.populate({
span: spanBuffer,
});

setState({ bindGroup, pipeline, spanBuffer });
}, [context, device, presentationFormat, state]);
}, [context, device, root, presentationFormat, state]);

useEffect(() => {
if (!context || !device || !state) {
if (!context || !device || !root || !state) {
return;
}

const { bindGroup, pipeline, spanBuffer } = state;
const textureView = context.getCurrentTexture().createView();
const renderPassDescriptor: GPURenderPassDescriptor = {
Expand All @@ -95,18 +105,18 @@ export function GradientTiles() {
],
};

tgpu.write(spanBuffer, { x: spanX, y: spanY });
spanBuffer.write({ x: spanX, y: spanY });

const commandEncoder = device.createCommandEncoder();
const passEncoder = commandEncoder.beginRenderPass(renderPassDescriptor);
passEncoder.setPipeline(pipeline);
passEncoder.setBindGroup(0, bindGroup);
passEncoder.setBindGroup(0, root.unwrap(bindGroup));
passEncoder.draw(4);
passEncoder.end();

device.queue.submit([commandEncoder.finish()]);
context.present();
}, [context, device, spanX, spanY, state]);
}, [context, device, root, spanX, spanY, state]);

return (
<View style={style.container}>
Expand Down
Loading

0 comments on commit d249abf

Please sign in to comment.