@@ -42,24 +42,14 @@ const lrnProgramMetadata = {
42
42
inputTypes : [ TextureType . unpacked ]
43
43
} ;
44
44
45
- function getOutputExpression ( attributes : LrnAttributes ) : string {
46
- let expression = `float(${ attributes . bias } ) + float(${ attributes . alpha } ) * square_sum` ;
47
- if ( attributes . beta === 0.5 ) {
48
- expression = `inversesqrt(${ expression } )` ;
49
- } else if ( attributes . beta === 1.0 ) {
50
- expression = `1.0/(${ expression } )` ;
51
- } else {
52
- expression = `exp(log(${ expression } )) * float(-${ attributes . beta } )` ;
53
- }
54
- return `x * ${ expression } ` ;
55
- }
56
-
57
45
function createLrnProgramInfo ( inputs : Tensor [ ] , attributes : LrnAttributes ) : ProgramInfo {
58
46
const C = inputs [ 0 ] . dims [ 1 ] ;
59
-
60
47
const rank = inputs [ 0 ] . dims . length ;
61
48
const from = - Math . floor ( ( attributes . size - 1 ) / 2 ) ;
62
49
const to = Math . ceil ( ( attributes . size - 1 ) / 2 ) ;
50
+ const alpha = `float(${ attributes . alpha } ) / float(${ attributes . size } )` ;
51
+ const bias = `float(${ attributes . bias } )` ;
52
+ const beta = `float(${ attributes . beta } )` ;
63
53
64
54
const shaderSource = `
65
55
float process(int indices[${ rank } ]) {
@@ -75,8 +65,7 @@ function createLrnProgramInfo(inputs: Tensor[], attributes: LrnAttributes): Prog
75
65
square_sum += j * j;
76
66
}
77
67
}
78
-
79
- return ${ getOutputExpression ( attributes ) } ;
68
+ return x / pow(${ bias } + ${ alpha } * square_sum, ${ beta } );
80
69
}` ;
81
70
return {
82
71
...lrnProgramMetadata ,
0 commit comments