Skip to content

Commit 06269a3

Browse files
authored
[js/webgpu] allow uint8 tensors for webgpu (#19545)
### Description allow uint8 tensors for webgpu
1 parent 4874a41 commit 06269a3

File tree

3 files changed

+4
-3
lines changed

3 files changed

+4
-3
lines changed

js/common/lib/tensor-impl.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ export class Tensor implements TensorInterface {
103103
}
104104
case 'gpu-buffer': {
105105
if ((type !== 'float32' && type !== 'float16' && type !== 'int32' && type !== 'int64' && type !== 'uint32' &&
106-
type !== 'bool')) {
106+
type !== 'uint8' && type !== 'bool')) {
107107
throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`);
108108
}
109109
this.gpuBufferData = arg0.gpuBuffer;

js/common/lib/tensor.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ export declare namespace Tensor {
135135
/**
136136
* supported data types for constructing a tensor from a WebGPU buffer
137137
*/
138-
export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'bool';
138+
export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'uint8'|'bool';
139139

140140
/**
141141
* represent where the tensor data is stored

js/web/lib/wasm/wasm-common.ts

+2-1
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,8 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro
169169
* Check whether the given tensor type is supported by GPU buffer
170170
*/
171171
export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => type === 'float32' ||
172-
type === 'int32' || type === 'int64' || type === 'bool' || type === 'float16' || type === 'uint32';
172+
type === 'float16' || type === 'int32' || type === 'int64' || type === 'uint32' || type === 'uint8' ||
173+
type === 'bool';
173174

174175
/**
175176
* Map string data location to integer value

0 commit comments

Comments
 (0)