Skip to content

Commit 9d2d7eb

Browse files
authored
TSL: Introduce ShaderCallNode & tslFn improvements (#26824)
* Add ShaderCallNode * cleanup * Use tslFn as default
1 parent 45505ed commit 9d2d7eb

File tree

7 files changed

+82
-45
lines changed

7 files changed

+82
-45
lines changed

examples/jsm/nodes/procedural/CheckerNode.js

+2-2
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ class CheckerNode extends TempNode {
2525

2626
}
2727

28-
generate( builder ) {
28+
construct() {
2929

30-
return checkerShaderNode( { uv: this.uvNode } ).build( builder );
30+
return checkerShaderNode( { uv: this.uvNode } );
3131

3232
}
3333

examples/jsm/nodes/shadernode/ShaderNode.js

+56-18
Original file line numberDiff line numberDiff line change
@@ -196,42 +196,86 @@ const ShaderNodeImmutable = function ( NodeClass, ...params ) {
196196

197197
};
198198

199-
class ShaderNodeInternal extends Node {
199+
class ShaderCallNodeInternal extends Node {
200200

201-
constructor( jsFunc ) {
201+
constructor( shaderNode, inputNodes ) {
202202

203203
super();
204204

205-
this._jsFunc = jsFunc;
205+
this.shaderNode = shaderNode;
206+
this.inputNodes = inputNodes;
206207

207208
}
208209

209-
call( inputs, stack, builder ) {
210+
getNodeType( builder ) {
210211

211-
inputs = nodeObjects( inputs );
212+
const { outputNode } = builder.getNodeProperties( this );
212213

213-
return nodeObject( this._jsFunc( inputs, stack, builder ) );
214+
return outputNode ? outputNode.getNodeType( builder ) : super.getNodeType( builder );
214215

215216
}
216217

217-
getNodeType( builder ) {
218+
call( builder ) {
218219

219-
const { outputNode } = builder.getNodeProperties( this );
220+
const { shaderNode, inputNodes } = this;
220221

221-
return outputNode ? outputNode.getNodeType( builder ) : super.getNodeType( builder );
222+
const jsFunc = shaderNode.jsFunc;
223+
const outputNode = inputNodes !== null ? jsFunc( nodeObjects( inputNodes ), builder.stack, builder ) : jsFunc( builder.stack, builder );
224+
225+
return nodeObject( outputNode );
222226

223227
}
224228

225229
construct( builder ) {
226230

227231
builder.addStack();
228232

229-
builder.stack.outputNode = nodeObject( this._jsFunc( builder.stack, builder ) );
233+
builder.stack.outputNode = this.call( builder );
230234

231235
return builder.removeStack();
232236

233237
}
234238

239+
generate( builder, output ) {
240+
241+
const { outputNode } = builder.getNodeProperties( this );
242+
243+
if ( outputNode === null ) {
244+
245+
// TSL: It's recommended to use `tslFn` in construct() pass.
246+
247+
return this.call( builder ).build( builder, output );
248+
249+
}
250+
251+
return super.generate( builder, output );
252+
253+
}
254+
255+
}
256+
257+
class ShaderNodeInternal extends Node {
258+
259+
constructor( jsFunc ) {
260+
261+
super();
262+
263+
this.jsFunc = jsFunc;
264+
265+
}
266+
267+
call( inputs = null ) {
268+
269+
return nodeObject( new ShaderCallNodeInternal( this, inputs ) );
270+
271+
}
272+
273+
construct() {
274+
275+
return this.call();
276+
277+
}
278+
235279
}
236280

237281
const bools = [ false, true ];
@@ -349,15 +393,9 @@ export const shader = ( jsFunc ) => { // @deprecated, r154
349393

350394
export const tslFn = ( jsFunc ) => {
351395

352-
let shaderNode = null;
396+
const shaderNode = new ShaderNode( jsFunc );
353397

354-
return ( ...params ) => {
355-
356-
if ( shaderNode === null ) shaderNode = new ShaderNode( jsFunc );
357-
358-
return shaderNode.call( ...params );
359-
360-
};
398+
return ( inputs ) => shaderNode.call( inputs );
361399

362400
};
363401

examples/jsm/nodes/utils/LoopNode.js

+8-8
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import Node, { addNodeClass } from '../core/Node.js';
22
import { expression } from '../code/ExpressionNode.js';
33
import { bypass } from '../core/BypassNode.js';
4-
import { context as contextNode } from '../core/ContextNode.js';
4+
import { context } from '../core/ContextNode.js';
55
import { addNodeElement, nodeObject, nodeArray } from '../shadernode/ShaderNode.js';
66

77
class LoopNode extends Node {
@@ -65,13 +65,11 @@ class LoopNode extends Node {
6565

6666
const properties = this.getProperties( builder );
6767

68-
const context = { tempWrite: false };
68+
const contextData = { tempWrite: false };
6969

7070
const params = this.params;
7171
const stackNode = properties.stackNode;
7272

73-
const returnsSnippet = properties.returnsNode ? properties.returnsNode.build( builder ) : '';
74-
7573
for ( let i = 0, l = params.length - 1; i < l; i ++ ) {
7674

7775
const param = params[ i ];
@@ -82,7 +80,7 @@ class LoopNode extends Node {
8280
if ( param.isNode ) {
8381

8482
start = '0';
85-
end = param.generate( builder, 'int' );
83+
end = param.build( builder, 'int' );
8684
direction = 'forward';
8785

8886
} else {
@@ -92,10 +90,10 @@ class LoopNode extends Node {
9290
direction = param.direction;
9391

9492
if ( typeof start === 'number' ) start = start.toString();
95-
else if ( start && start.isNode ) start = start.generate( builder, 'int' );
93+
else if ( start && start.isNode ) start = start.build( builder, 'int' );
9694

9795
if ( typeof end === 'number' ) end = end.toString();
98-
else if ( end && end.isNode ) end = end.generate( builder, 'int' );
96+
else if ( end && end.isNode ) end = end.build( builder, 'int' );
9997

10098
if ( start !== undefined && end === undefined ) {
10199

@@ -159,7 +157,9 @@ class LoopNode extends Node {
159157

160158
}
161159

162-
const stackSnippet = contextNode( stackNode, context ).build( builder, 'void' );
160+
const stackSnippet = context( stackNode, contextData ).build( builder, 'void' );
161+
162+
const returnsSnippet = properties.returnsNode ? properties.returnsNode.build( builder ) : '';
163163

164164
builder.removeFlowTab().addFlowCode( '\n' + builder.tab + stackSnippet );
165165

examples/webgpu_audio_processing.html

+3-4
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
<script type="module">
3232

3333
import * as THREE from 'three';
34-
import { ShaderNode, uniform, storage, instanceIndex, float, texture, viewportTopLeft, color } from 'three/nodes';
34+
import { tslFn, uniform, storage, instanceIndex, float, texture, viewportTopLeft, color } from 'three/nodes';
3535

3636
import { GUI } from 'three/addons/libs/lil-gui.module.min.js';
3737

@@ -136,11 +136,10 @@
136136

137137
// compute (shader-node)
138138

139-
const computeShaderNode = new ShaderNode( ( stack ) => {
139+
const computeShaderFn = tslFn( ( stack ) => {
140140

141141
const index = float( instanceIndex );
142142

143-
144143
// pitch
145144

146145
const time = index.mul( pitch );
@@ -171,7 +170,7 @@
171170

172171
// compute
173172

174-
computeNode = computeShaderNode.compute( waveBuffer.length );
173+
computeNode = computeShaderFn().compute( waveBuffer.length );
175174

176175

177176
// gui

examples/webgpu_compute.html

+5-5
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
<script type="module">
2727

2828
import * as THREE from 'three';
29-
import { ShaderNode, uniform, storage, attribute, float, vec2, vec3, color, instanceIndex, PointsNodeMaterial } from 'three/nodes';
29+
import { tslFn, uniform, storage, attribute, float, vec2, vec3, color, instanceIndex, PointsNodeMaterial } from 'three/nodes';
3030

3131
import { GUI } from 'three/addons/libs/lil-gui.module.min.js';
3232

@@ -74,7 +74,7 @@
7474

7575
// create function
7676

77-
const computeShaderNode = new ShaderNode( ( stack ) => {
77+
const computeShaderFn = tslFn( ( stack ) => {
7878

7979
const particle = particleBufferNode.element( instanceIndex );
8080
const velocity = velocityBufferNode.element( instanceIndex );
@@ -98,10 +98,10 @@
9898

9999
// compute
100100

101-
computeNode = computeShaderNode.compute( particleNum );
101+
computeNode = computeShaderFn().compute( particleNum );
102102
computeNode.onInit = ( { renderer } ) => {
103103

104-
const precomputeShaderNode = new ShaderNode( ( stack ) => {
104+
const precomputeShaderNode = tslFn( ( stack ) => {
105105

106106
const particleIndex = float( instanceIndex );
107107

@@ -117,7 +117,7 @@
117117

118118
} );
119119

120-
renderer.compute( precomputeShaderNode.compute( particleNum ) );
120+
renderer.compute( precomputeShaderNode().compute( particleNum ) );
121121

122122
};
123123

examples/webgpu_compute_particles.html

+7-7
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
<script type="module">
2727

2828
import * as THREE from 'three';
29-
import { ShaderNode, uniform, texture, instanceIndex, float, vec3, storage, SpriteNodeMaterial } from 'three/nodes';
29+
import { tslFn, uniform, texture, instanceIndex, float, vec3, storage, SpriteNodeMaterial } from 'three/nodes';
3030

3131
import WebGPU from 'three/addons/capabilities/WebGPU.js';
3232
import WebGPURenderer from 'three/addons/renderers/webgpu/WebGPURenderer.js';
@@ -83,7 +83,7 @@
8383

8484
// compute
8585

86-
const computeInit = new ShaderNode( ( stack ) => {
86+
const computeInit = tslFn( ( stack ) => {
8787

8888
const position = positionBuffer.element( instanceIndex );
8989
const color = colorBuffer.element( instanceIndex );
@@ -98,11 +98,11 @@
9898

9999
stack.assign( color, vec3( randX, randY, randZ ) );
100100

101-
} ).compute( particleCount );
101+
} )().compute( particleCount );
102102

103103
//
104104

105-
const computeUpdate = new ShaderNode( ( stack ) => {
105+
const computeUpdate = tslFn( ( stack ) => {
106106

107107
const position = positionBuffer.element( instanceIndex );
108108
const velocity = velocityBuffer.element( instanceIndex );
@@ -128,7 +128,7 @@
128128

129129
} );
130130

131-
computeParticles = computeUpdate.compute( particleCount );
131+
computeParticles = computeUpdate().compute( particleCount );
132132

133133
// create nodes
134134

@@ -179,7 +179,7 @@
179179

180180
// click event
181181

182-
const computeHit = new ShaderNode( ( stack ) => {
182+
const computeHit = tslFn( ( stack ) => {
183183

184184
const position = positionBuffer.element( instanceIndex );
185185
const velocity = velocityBuffer.element( instanceIndex );
@@ -193,7 +193,7 @@
193193

194194
stack.assign( velocity, velocity.add( direction.mul( relativePower ) ) );
195195

196-
} ).compute( particleCount );
196+
} )().compute( particleCount );
197197

198198
//
199199

examples/webgpu_materials.html

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import * as THREE from 'three';
3030
import * as Nodes from 'three/nodes';
3131

32-
import { tslFn, wgslFn, attribute, positionLocal, positionWorld, normalLocal, normalWorld, normalView, color, texture, uv, float, vec2, vec3, vec4, oscSine, triplanarTexture, viewportBottomLeft, js, string, global, loop, MeshBasicNodeMaterial, NodeObjectLoader } from 'three/nodes';
32+
import { tslFn, wgslFn, positionLocal, positionWorld, normalLocal, normalWorld, normalView, color, texture, uv, float, vec2, vec3, vec4, oscSine, triplanarTexture, viewportBottomLeft, js, string, global, loop, MeshBasicNodeMaterial, NodeObjectLoader } from 'three/nodes';
3333

3434
import WebGPU from 'three/addons/capabilities/WebGPU.js';
3535
import WebGPURenderer from 'three/addons/renderers/webgpu/WebGPURenderer.js';

0 commit comments

Comments
 (0)