Skip to content

Commit 2863235

Browse files
authored
Add bitwiseAnd() ops (#7645)
* Test commit * Support ensureShape in tfjs * Applied changes * Updates doc * Updates doc * Updates doc * Add an example * add dif length and null shape tests * fix the lint error * Update executor tests * Update the header * Implement bitwiseAnd() ops and add kernel in Webcpu backEnd * Fix format * Change to only support int32 type
1 parent f3b00b9 commit 2863235

File tree

16 files changed

+195
-7
lines changed

16 files changed

+195
-7
lines changed

.vscode/settings.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
"editor.tabSize": 2,
1919
"editor.insertSpaces": true,
2020
"[typescript]": {
21-
"editor.formatOnSave": true
21+
"editor.formatOnSave": true,
22+
"editor.defaultFormatter": "xaver.clang-format"
2223
},
2324
"[javascript]": {
2425
"editor.formatOnSave": true
+32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {BitwiseAnd, KernelConfig} from '@tensorflow/tfjs-core';
19+
20+
import {createSimpleBinaryKernelImpl} from '../utils/binary_impl';
21+
import {binaryKernelFunc} from '../utils/binary_utils';
22+
23+
export const bitwiseAndImpl =
24+
createSimpleBinaryKernelImpl(((a: number, b: number) => a & b));
25+
26+
export const bitwiseAnd = binaryKernelFunc(BitwiseAnd, bitwiseAndImpl);
27+
28+
export const bitwiseAndConfig: KernelConfig = {
29+
kernelName: BitwiseAnd,
30+
backendName: 'cpu',
31+
kernelFunc: bitwiseAnd
32+
};

tfjs-backend-cpu/src/register_all_kernels.ts

+2
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ import {batchMatMulConfig} from './kernels/BatchMatMul';
4242
import {batchNormConfig} from './kernels/BatchNorm';
4343
import {batchToSpaceNDConfig} from './kernels/BatchToSpaceND';
4444
import {bincountConfig} from './kernels/Bincount';
45+
import {bitwiseAndConfig} from './kernels/BitwiseAnd';
4546
import {broadcastArgsConfig} from './kernels/BroadcastArgs';
4647
import {castConfig} from './kernels/Cast';
4748
import {ceilConfig} from './kernels/Ceil';
@@ -215,6 +216,7 @@ const kernelConfigs: KernelConfig[] = [
215216
batchNormConfig,
216217
batchToSpaceNDConfig,
217218
bincountConfig,
219+
bitwiseAndConfig,
218220
broadcastArgsConfig,
219221
castConfig,
220222
ceilConfig,

tfjs-backend-cpu/src/shared.ts

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
export {simpleAbsImpl} from './kernels/Abs';
2020
export {addImpl} from './kernels/Add';
2121
export {bincountImpl, bincountReduceImpl} from './kernels/Bincount_impl';
22+
export {bitwiseAndImpl} from './kernels/BitwiseAnd';
2223
export {castImpl} from './kernels/Cast';
2324
export {ceilImpl} from './kernels/Ceil';
2425
export {concatImpl} from './kernels/Concat_impl';

tfjs-backend-wasm/src/setup_test.ts

+6
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,12 @@ const TEST_FILTERS: TestFilter[] = [
290290
{startsWith: 'logicalNot '},
291291
{startsWith: 'logicalOr '},
292292
{startsWith: 'logicalXor '},
293+
{
294+
startsWith: 'bitwiseAnd',
295+
excludes: [
296+
'bitwiseAnd',
297+
]
298+
},
293299
{
294300
startsWith: 'tile ',
295301
excludes: [

tfjs-backend-webgl/src/setup_test.ts

+4-2
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@ import './index';
2222
import '@tensorflow/tfjs-core/dist/public/chained_ops/register_all_chained_ops';
2323
// tslint:disable-next-line: no-imports-from-dist
2424
import '@tensorflow/tfjs-core/dist/register_all_gradients';
25-
import {registerTestEnvs} from './backend_webgl_test_registry';
25+
2626
// tslint:disable-next-line: no-imports-from-dist
2727
import {parseTestEnvFromKarmaFlags, setTestEnvs, setupTestFilters, TEST_ENVS, TestFilter} from '@tensorflow/tfjs-core/dist/jasmine_util';
2828

29+
import {registerTestEnvs} from './backend_webgl_test_registry';
30+
2931
registerTestEnvs();
3032

3133
const TEST_FILTERS: TestFilter[] = [];
@@ -34,7 +36,7 @@ const customInclude = (testName: string) => {
3436
'isBrowser: false', 'dilation gradient',
3537
'throws when index is out of bound',
3638
// otsu tests for threshold op is failing on windows
37-
'method otsu'
39+
'method otsu', 'bitwiseAnd'
3840
];
3941
for (const subStr of toExclude) {
4042
if (testName.includes(subStr)) {

tfjs-backend-webgpu/src/setup_test.ts

+6
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,12 @@ const TEST_FILTERS: TestFilter[] = [
146146
'calling .data() twice works',
147147
],
148148
},
149+
{
150+
startsWith: 'bitwiseAnd',
151+
excludes: [
152+
'bitwiseAnd',
153+
],
154+
},
149155

150156
// exclude unsupported kernels and to be fixed cases
151157
{

tfjs-converter/docs/supported_ops.md

+1
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@
201201

202202
|Tensorflow Op Name|Tensorflow.js Op Name|
203203
|---|---|
204+
|BitwiseAnd|bitwiseAnd|
204205
|Equal|equal|
205206
|Greater|greater|
206207
|GreaterEqual|greaterEqual|

tfjs-converter/python/tensorflowjs/op_list/logical.json

+17-1
Original file line numberDiff line numberDiff line change
@@ -267,5 +267,21 @@
267267
"notSupported": true
268268
}
269269
]
270+
},
271+
{
272+
"tfOpName": "BitwiseAnd",
273+
"category": "logical",
274+
"inputs": [
275+
{
276+
"start": 0,
277+
"name": "x",
278+
"type": "tensor"
279+
},
280+
{
281+
"start": 1,
282+
"name": "y",
283+
"type": "tensor"
284+
}
285+
]
270286
}
271-
]
287+
]

tfjs-converter/src/operations/executors/logical_executor.ts

+7-2
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ import {InternalOpExecutor, Node} from '../types';
2626
import {getParamValue} from './utils';
2727

2828
export const executeOp: InternalOpExecutor =
29-
(node: Node, tensorMap: NamedTensorsMap,
30-
context: ExecutionContext, ops = tfOps): Tensor[] => {
29+
(node: Node, tensorMap: NamedTensorsMap, context: ExecutionContext,
30+
ops = tfOps): Tensor[] => {
3131
switch (node.op) {
3232
case 'Equal': {
3333
return [ops.equal(
@@ -80,6 +80,11 @@ export const executeOp: InternalOpExecutor =
8080
getParamValue('a', node, tensorMap, context) as Tensor,
8181
getParamValue('b', node, tensorMap, context) as Tensor)];
8282
}
83+
case 'BitwiseAnd': {
84+
return [ops.bitwiseAnd(
85+
getParamValue('a', node, tensorMap, context) as Tensor,
86+
getParamValue('b', node, tensorMap, context) as Tensor)];
87+
}
8388
default:
8489
throw TypeError(`Node type ${node.op} is not implemented`);
8590
}

tfjs-converter/src/operations/executors/logical_executor_test.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ describe('logical', () => {
5454

5555
([
5656
'Equal', 'NotEqual', 'Greater', 'GreaterEqual', 'Less', 'LessEqual',
57-
'LogicalAnd', 'LogicalOr'
57+
'LogicalAnd', 'LogicalOr', 'BitwiseAnd'
5858
] as const )
5959
.forEach(op => {
6060
it('should call tfOps.' + op, () => {

tfjs-core/src/kernel_names.ts

+3
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ export interface BincountAttrs {
138138
size: number;
139139
}
140140

141+
export const BitwiseAnd = 'BitwiseAnd';
142+
export type BitwiseAndInputs = BinaryInputs;
143+
141144
export const BroadcastTo = 'BroadcastTo';
142145
export type BroadcastToInputs = Pick<NamedTensorInfoMap, 'x'>;
143146
export interface BroadCastToAttrs {

tfjs-core/src/ops/bitwise_and.ts

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import {ENGINE} from '../engine';
19+
import {BitwiseAnd, BitwiseAndInputs} from '../kernel_names';
20+
import {Tensor} from '../tensor';
21+
import {NamedTensorMap} from '../tensor_types';
22+
import {convertToTensor} from '../tensor_util_env';
23+
import {Rank} from '../types';
24+
import {arraysEqual} from '../util_base';
25+
26+
import {op} from './operation';
27+
28+
/**
29+
* Bitwise `AND` operation for input tensors.
30+
*
31+
* Given two input tensors, returns a new tensor
32+
* with the `AND` calculated values.
33+
*
34+
* The method supports int32 values
35+
*
36+
*
37+
* ```js
38+
* const x = tf.tensor1d([0, 5, 3, 14], 'int32');
39+
* const y = tf.tensor1d([5, 0, 7, 11], 'int32');
40+
* tf.bitwiseAnd(x, y).print();
41+
* ```
42+
*
43+
* @param x The input tensor to be calculated.
44+
* @param y The input tensor to be calculated.
45+
*
46+
* @doc {heading: 'Operations', subheading: 'Logical'}
47+
*/
48+
function bitwiseAnd_<R extends Rank>(x: Tensor, y: Tensor): Tensor<R> {
49+
const $x = convertToTensor(x, 'x', 'bitwiseAnd');
50+
const $y = convertToTensor(y, 'y', 'bitwiseAnd');
51+
52+
if (!arraysEqual($x.shape, $y.shape)) {
53+
throw new Error(`BitwiseAnd: Tensors must have the same shape. x: ${
54+
$x.shape}, y: ${$y.shape}`);
55+
}
56+
if ($x.dtype !== 'int32' || $y.dtype !== 'int32') {
57+
throw new Error(
58+
`BitwiseAnd: Only supports 'int32' values in tensor, found type of x: ${
59+
$x.dtype} and type of y: ${$y.dtype}`);
60+
}
61+
62+
const inputs: BitwiseAndInputs = {a: $x, b: $y};
63+
return ENGINE.runKernel(BitwiseAnd, inputs as unknown as NamedTensorMap);
64+
}
65+
export const bitwiseAnd = /* @__PURE__ */ op({bitwiseAnd_});

tfjs-core/src/ops/bitwise_and_test.ts

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
import * as tf from '../index';
18+
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';
19+
import {expectArraysClose} from '../test_util';
20+
import {bitwiseAnd} from './bitwise_and';
21+
22+
describeWithFlags('bitwiseAnd', ALL_ENVS, () => {
23+
it('a bitwiseAnd b', async () => {
24+
const a = tf.tensor1d([0, 5, 3, 14], 'int32');
25+
const b = tf.tensor1d([5, 0, 7, 11], 'int32');
26+
27+
const res = bitwiseAnd(a, b);
28+
expectArraysClose(await res.data(), [0, 0, 3, 10]);
29+
});
30+
31+
it('different shape', () => {
32+
const a = tf.tensor1d([0, 5, 3, 14]);
33+
const b = tf.tensor1d([5, 0, 7]);
34+
35+
expect(() => bitwiseAnd(a, b))
36+
.toThrowError(/BitwiseAnd: Tensors must have the same shape/);
37+
});
38+
39+
it('wrong type', () => {
40+
const a = tf.tensor1d([0, 1, 3, 14], 'float32');
41+
const b = tf.tensor1d([5, 0, 7, 12], 'float32');
42+
43+
expect(() => bitwiseAnd(a, b))
44+
.toThrowError(/BitwiseAnd: Only supports 'int32' values in tensor/);
45+
});
46+
});

tfjs-core/src/ops/ops.ts

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ export {batchNorm2d} from './batchnorm2d';
3939
export {batchNorm3d} from './batchnorm3d';
4040
export {batchNorm4d} from './batchnorm4d';
4141
export {bincount} from './bincount';
42+
export {bitwiseAnd} from './bitwise_and';
4243
export {broadcastArgs} from './broadcast_args';
4344
export {broadcastTo} from './broadcast_to';
4445
export {buffer} from './buffer';

tfjs-node/src/run_tests.ts

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ jasmine_util.setTestEnvs([{
4848

4949
const IGNORE_LIST: string[] = [
5050
// Always ignore version tests:
51+
'bitwiseAnd',
5152
'version version',
5253
'unreliable is true due to both auto gc and string tensors',
5354
'unreliable is true due to auto gc',

0 commit comments

Comments
 (0)