Skip to content

TSL: Introduce ShaderCallNode & tslFn improvements #26824

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/jsm/nodes/procedural/CheckerNode.js
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ class CheckerNode extends TempNode {

}

generate( builder ) {
construct() {

return checkerShaderNode( { uv: this.uvNode } ).build( builder );
return checkerShaderNode( { uv: this.uvNode } );

}

Expand Down
74 changes: 56 additions & 18 deletions examples/jsm/nodes/shadernode/ShaderNode.js
Original file line number Diff line number Diff line change
Expand Up @@ -196,42 +196,86 @@ const ShaderNodeImmutable = function ( NodeClass, ...params ) {

};

class ShaderNodeInternal extends Node {
class ShaderCallNodeInternal extends Node {

constructor( jsFunc ) {
constructor( shaderNode, inputNodes ) {

super();

this._jsFunc = jsFunc;
this.shaderNode = shaderNode;
this.inputNodes = inputNodes;

}

call( inputs, stack, builder ) {
getNodeType( builder ) {

inputs = nodeObjects( inputs );
const { outputNode } = builder.getNodeProperties( this );

return nodeObject( this._jsFunc( inputs, stack, builder ) );
return outputNode ? outputNode.getNodeType( builder ) : super.getNodeType( builder );

}

getNodeType( builder ) {
call( builder ) {

const { outputNode } = builder.getNodeProperties( this );
const { shaderNode, inputNodes } = this;

return outputNode ? outputNode.getNodeType( builder ) : super.getNodeType( builder );
const jsFunc = shaderNode.jsFunc;
const outputNode = inputNodes !== null ? jsFunc( nodeObjects( inputNodes ), builder.stack, builder ) : jsFunc( builder.stack, builder );

return nodeObject( outputNode );

}

construct( builder ) {

builder.addStack();

builder.stack.outputNode = nodeObject( this._jsFunc( builder.stack, builder ) );
builder.stack.outputNode = this.call( builder );

return builder.removeStack();

}

generate( builder, output ) {

const { outputNode } = builder.getNodeProperties( this );

if ( outputNode === null ) {

// TSL: It's recommended to use `tslFn` in construct() pass.

return this.call( builder ).build( builder, output );

}

return super.generate( builder, output );

}

}

class ShaderNodeInternal extends Node {

constructor( jsFunc ) {

super();

this.jsFunc = jsFunc;

}

call( inputs = null ) {

return nodeObject( new ShaderCallNodeInternal( this, inputs ) );

}

construct() {

return this.call();

}

}

const bools = [ false, true ];
Expand Down Expand Up @@ -349,15 +393,9 @@ export const shader = ( jsFunc ) => { // @deprecated, r154

export const tslFn = ( jsFunc ) => {

let shaderNode = null;
const shaderNode = new ShaderNode( jsFunc );

return ( ...params ) => {

if ( shaderNode === null ) shaderNode = new ShaderNode( jsFunc );

return shaderNode.call( ...params );

};
return ( inputs ) => shaderNode.call( inputs );

};

Expand Down
16 changes: 8 additions & 8 deletions examples/jsm/nodes/utils/LoopNode.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import Node, { addNodeClass } from '../core/Node.js';
import { expression } from '../code/ExpressionNode.js';
import { bypass } from '../core/BypassNode.js';
import { context as contextNode } from '../core/ContextNode.js';
import { context } from '../core/ContextNode.js';
import { addNodeElement, nodeObject, nodeArray } from '../shadernode/ShaderNode.js';

class LoopNode extends Node {
Expand Down Expand Up @@ -65,13 +65,11 @@ class LoopNode extends Node {

const properties = this.getProperties( builder );

const context = { tempWrite: false };
const contextData = { tempWrite: false };

const params = this.params;
const stackNode = properties.stackNode;

const returnsSnippet = properties.returnsNode ? properties.returnsNode.build( builder ) : '';

for ( let i = 0, l = params.length - 1; i < l; i ++ ) {

const param = params[ i ];
Expand All @@ -82,7 +80,7 @@ class LoopNode extends Node {
if ( param.isNode ) {

start = '0';
end = param.generate( builder, 'int' );
end = param.build( builder, 'int' );
direction = 'forward';

} else {
Expand All @@ -92,10 +90,10 @@ class LoopNode extends Node {
direction = param.direction;

if ( typeof start === 'number' ) start = start.toString();
else if ( start && start.isNode ) start = start.generate( builder, 'int' );
else if ( start && start.isNode ) start = start.build( builder, 'int' );

if ( typeof end === 'number' ) end = end.toString();
else if ( end && end.isNode ) end = end.generate( builder, 'int' );
else if ( end && end.isNode ) end = end.build( builder, 'int' );

if ( start !== undefined && end === undefined ) {

Expand Down Expand Up @@ -159,7 +157,9 @@ class LoopNode extends Node {

}

const stackSnippet = contextNode( stackNode, context ).build( builder, 'void' );
const stackSnippet = context( stackNode, contextData ).build( builder, 'void' );

const returnsSnippet = properties.returnsNode ? properties.returnsNode.build( builder ) : '';

builder.removeFlowTab().addFlowCode( '\n' + builder.tab + stackSnippet );

Expand Down
7 changes: 3 additions & 4 deletions examples/webgpu_audio_processing.html
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
<script type="module">

import * as THREE from 'three';
import { ShaderNode, uniform, storage, instanceIndex, float, texture, viewportTopLeft, color } from 'three/nodes';
import { tslFn, uniform, storage, instanceIndex, float, texture, viewportTopLeft, color } from 'three/nodes';

import { GUI } from 'three/addons/libs/lil-gui.module.min.js';

Expand Down Expand Up @@ -136,11 +136,10 @@

// compute (shader-node)

const computeShaderNode = new ShaderNode( ( stack ) => {
const computeShaderFn = tslFn( ( stack ) => {

const index = float( instanceIndex );


// pitch

const time = index.mul( pitch );
Expand Down Expand Up @@ -171,7 +170,7 @@

// compute

computeNode = computeShaderNode.compute( waveBuffer.length );
computeNode = computeShaderFn().compute( waveBuffer.length );


// gui
Expand Down
10 changes: 5 additions & 5 deletions examples/webgpu_compute.html
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
<script type="module">

import * as THREE from 'three';
import { ShaderNode, uniform, storage, attribute, float, vec2, vec3, color, instanceIndex, PointsNodeMaterial } from 'three/nodes';
import { tslFn, uniform, storage, attribute, float, vec2, vec3, color, instanceIndex, PointsNodeMaterial } from 'three/nodes';

import { GUI } from 'three/addons/libs/lil-gui.module.min.js';

Expand Down Expand Up @@ -74,7 +74,7 @@

// create function

const computeShaderNode = new ShaderNode( ( stack ) => {
const computeShaderFn = tslFn( ( stack ) => {

const particle = particleBufferNode.element( instanceIndex );
const velocity = velocityBufferNode.element( instanceIndex );
Expand All @@ -98,10 +98,10 @@

// compute

computeNode = computeShaderNode.compute( particleNum );
computeNode = computeShaderFn().compute( particleNum );
computeNode.onInit = ( { renderer } ) => {

const precomputeShaderNode = new ShaderNode( ( stack ) => {
const precomputeShaderNode = tslFn( ( stack ) => {

const particleIndex = float( instanceIndex );

Expand All @@ -117,7 +117,7 @@

} );

renderer.compute( precomputeShaderNode.compute( particleNum ) );
renderer.compute( precomputeShaderNode().compute( particleNum ) );

};

Expand Down
14 changes: 7 additions & 7 deletions examples/webgpu_compute_particles.html
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
<script type="module">

import * as THREE from 'three';
import { ShaderNode, uniform, texture, instanceIndex, float, vec3, storage, SpriteNodeMaterial } from 'three/nodes';
import { tslFn, uniform, texture, instanceIndex, float, vec3, storage, SpriteNodeMaterial } from 'three/nodes';

import WebGPU from 'three/addons/capabilities/WebGPU.js';
import WebGPURenderer from 'three/addons/renderers/webgpu/WebGPURenderer.js';
Expand Down Expand Up @@ -83,7 +83,7 @@

// compute

const computeInit = new ShaderNode( ( stack ) => {
const computeInit = tslFn( ( stack ) => {

const position = positionBuffer.element( instanceIndex );
const color = colorBuffer.element( instanceIndex );
Expand All @@ -98,11 +98,11 @@

stack.assign( color, vec3( randX, randY, randZ ) );

} ).compute( particleCount );
} )().compute( particleCount );

//

const computeUpdate = new ShaderNode( ( stack ) => {
const computeUpdate = tslFn( ( stack ) => {

const position = positionBuffer.element( instanceIndex );
const velocity = velocityBuffer.element( instanceIndex );
Expand All @@ -128,7 +128,7 @@

} );

computeParticles = computeUpdate.compute( particleCount );
computeParticles = computeUpdate().compute( particleCount );

// create nodes

Expand Down Expand Up @@ -179,7 +179,7 @@

// click event

const computeHit = new ShaderNode( ( stack ) => {
const computeHit = tslFn( ( stack ) => {

const position = positionBuffer.element( instanceIndex );
const velocity = velocityBuffer.element( instanceIndex );
Expand All @@ -193,7 +193,7 @@

stack.assign( velocity, velocity.add( direction.mul( relativePower ) ) );

} ).compute( particleCount );
} )().compute( particleCount );

//

Expand Down
2 changes: 1 addition & 1 deletion examples/webgpu_materials.html
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import * as THREE from 'three';
import * as Nodes from 'three/nodes';

import { tslFn, wgslFn, attribute, positionLocal, positionWorld, normalLocal, normalWorld, normalView, color, texture, uv, float, vec2, vec3, vec4, oscSine, triplanarTexture, viewportBottomLeft, js, string, global, loop, MeshBasicNodeMaterial, NodeObjectLoader } from 'three/nodes';
import { tslFn, wgslFn, positionLocal, positionWorld, normalLocal, normalWorld, normalView, color, texture, uv, float, vec2, vec3, vec4, oscSine, triplanarTexture, viewportBottomLeft, js, string, global, loop, MeshBasicNodeMaterial, NodeObjectLoader } from 'three/nodes';

import WebGPU from 'three/addons/capabilities/WebGPU.js';
import WebGPURenderer from 'three/addons/renderers/webgpu/WebGPURenderer.js';
Expand Down