Skip to content

Added Additional Operators to support for Requantizer; Resolved assertion failed due to unequal shapes #3041

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tensorflow/lite/micro/tools/model_transforms_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def clear_resource_variable_buffers(model):
buffer_idx = tensor.buffer
if (tensor.type != schema_fb.TensorType.RESOURCE
and buffer_idx not in multi_subgraph_resource_buffers
and model.buffers[buffer_idx].data != []):
and np.array(model.buffers[buffer_idx].data).size != 0):
# if the entire initialization subgraph has not been cleared, we cannot
# make any additional changes to the flatbuffer
return
Expand Down
32 changes: 29 additions & 3 deletions tensorflow/lite/micro/tools/requantize_flatbuffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,25 +71,51 @@
requantize_flatbuffer_utils.requantize_fully_connected,
schema_py_generated.BuiltinOperator.TRANSPOSE_CONV:
requantize_flatbuffer_utils.requantize_transpose_conv,
schema_py_generated.BuiltinOperator.GREATER:
requantize_flatbuffer_utils.requantize_greater,
schema_py_generated.BuiltinOperator.SUB:
requantize_flatbuffer_utils.requantize_sub,
schema_py_generated.BuiltinOperator.MUL:
requantize_flatbuffer_utils.requantize_mul,
schema_py_generated.BuiltinOperator.BATCH_MATMUL:
requantize_flatbuffer_utils.requantize_batch_matmul,
schema_py_generated.BuiltinOperator.SELECT_V2:
requantize_flatbuffer_utils.requantize_select_v2,
schema_py_generated.BuiltinOperator.CONCATENATION:
requantize_flatbuffer_utils.requantize_concatenation,
schema_py_generated.BuiltinOperator.ADD:
requantize_flatbuffer_utils.requantize_add,
schema_py_generated.BuiltinOperator.PAD:
requantize_flatbuffer_utils.requantize_pad,
}

# List of tested simple operators (no weight and bias, e.g., reshape) see tensorflow/lite/schema/schema.fbs for op code names
_TESTED_SIMPLE_OPS = [
schema_py_generated.BuiltinOperator.ADD,
schema_py_generated.BuiltinOperator.CONCATENATION,
schema_py_generated.BuiltinOperator.DEQUANTIZE,
schema_py_generated.BuiltinOperator.LEAKY_RELU,
schema_py_generated.BuiltinOperator.LOGISTIC,
schema_py_generated.BuiltinOperator.MEAN,
schema_py_generated.BuiltinOperator.MUL,
schema_py_generated.BuiltinOperator.PAD,
schema_py_generated.BuiltinOperator.QUANTIZE,
schema_py_generated.BuiltinOperator.RESHAPE,
schema_py_generated.BuiltinOperator.RSQRT,
schema_py_generated.BuiltinOperator.SQRT,
schema_py_generated.BuiltinOperator.SQUARED_DIFFERENCE,
schema_py_generated.BuiltinOperator.STRIDED_SLICE,
schema_py_generated.BuiltinOperator.SUB,
schema_py_generated.BuiltinOperator.CALL_ONCE,
schema_py_generated.BuiltinOperator.VAR_HANDLE,
schema_py_generated.BuiltinOperator.READ_VARIABLE,
schema_py_generated.BuiltinOperator.ASSIGN_VARIABLE,
schema_py_generated.BuiltinOperator.FLOOR_DIV,
schema_py_generated.BuiltinOperator.CAST,
schema_py_generated.BuiltinOperator.COS,
schema_py_generated.BuiltinOperator.SIN,
schema_py_generated.BuiltinOperator.UNPACK,
schema_py_generated.BuiltinOperator.TRANSPOSE,
schema_py_generated.BuiltinOperator.SPLIT,
schema_py_generated.BuiltinOperator.RESIZE_NEAREST_NEIGHBOR,
schema_py_generated.BuiltinOperator.SELECT_V2,
]

_SUPPORTED_OPS = set(
Expand Down
95 changes: 95 additions & 0 deletions tensorflow/lite/micro/tools/requantize_flatbuffer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,101 @@ def requantize_fully_connected(tensors, buffers, op):
set_bias_type_int64(buffers, input_tensor, weight_tensor, bias_tensor)


def requantize_greater(tensors, buffers, op):
# Indices are from tensorflow/lite/micro/kernels/fully_connected_common.cc
input1_tensor = tensors[op.inputs[0]]
# weight stays the same, no change needed
input2_tensor = tensors[op.inputs[1]]
output_tensor = tensors[op.outputs[0]]

change_activation_tensor_8to16(input1_tensor, buffers)
change_activation_tensor_8to16(input2_tensor, buffers)
change_activation_tensor_8to16(output_tensor, buffers)


def requantize_select_v2(tensors, buffers, op):
# Indices are from tensorflow/lite/micro/kernels/fully_connected_common.cc
input1_tensor = tensors[op.inputs[0]]
# weight stays the same, no change needed
input2_tensor = tensors[op.inputs[1]]
input3_tensor = tensors[op.inputs[2]]
output_tensor = tensors[op.outputs[0]]

change_activation_tensor_8to16(input1_tensor, buffers)
change_activation_tensor_8to16(input2_tensor, buffers)
change_activation_tensor_8to16(input3_tensor, buffers)
change_activation_tensor_8to16(output_tensor, buffers)


def requantize_pad(tensors, buffers, op):
# Indices are from tensorflow/lite/micro/kernels/fully_connected_common.cc
input1_tensor = tensors[op.inputs[0]]
# weight stays the same, no change needed
input2_tensor = tensors[op.inputs[1]]
# input3_tensor = tensors[op.inputs[2]]
output_tensor = tensors[op.outputs[0]]

change_activation_tensor_8to16(input1_tensor, buffers)
change_activation_tensor_8to16(input2_tensor, buffers)
# change_activation_tensor_8to16(input3_tensor, buffers)
change_activation_tensor_8to16(output_tensor, buffers)


def requantize_sub(tensors, buffers, op):
# Indices are from tensorflow/lite/micro/kernels/sub.cc
input1_tensor = tensors[op.inputs[0]]
input2_tensor = tensors[op.inputs[1]]
output_tensor = tensors[op.outputs[0]]

change_activation_tensor_8to16(input1_tensor, buffers)
change_activation_tensor_8to16(input2_tensor, buffers)
change_activation_tensor_8to16(output_tensor, buffers)


def requantize_batch_matmul(tensors, buffers, op):
# Indices are from tensorflow/lite/micro/kernels/sub.cc
input1_tensor = tensors[op.inputs[0]]
input2_tensor = tensors[op.inputs[1]]
output_tensor = tensors[op.outputs[0]]

change_activation_tensor_8to16(input1_tensor, buffers)
change_activation_tensor_8to16(input2_tensor, buffers)
change_activation_tensor_8to16(output_tensor, buffers)


def requantize_concatenation(tensors, buffers, op):
# Indices are from tensorflow/lite/micro/kernels/sub.cc
input1_tensor = tensors[op.inputs[0]]
input2_tensor = tensors[op.inputs[1]]
output_tensor = tensors[op.outputs[0]]

change_activation_tensor_8to16(input1_tensor, buffers)
change_activation_tensor_8to16(input2_tensor, buffers)
change_activation_tensor_8to16(output_tensor, buffers)


def requantize_mul(tensors, buffers, op):
# Indices are from tensorflow/lite/micro/kernels/sub.cc
input1_tensor = tensors[op.inputs[0]]
input2_tensor = tensors[op.inputs[1]]
output_tensor = tensors[op.outputs[0]]

change_activation_tensor_8to16(input1_tensor, buffers)
change_activation_tensor_8to16(input2_tensor, buffers)
change_activation_tensor_8to16(output_tensor, buffers)


def requantize_add(tensors, buffers, op):
# Indices are from tensorflow/lite/micro/kernels/sub.cc
input1_tensor = tensors[op.inputs[0]]
input2_tensor = tensors[op.inputs[1]]
output_tensor = tensors[op.outputs[0]]

change_activation_tensor_8to16(input1_tensor, buffers)
change_activation_tensor_8to16(input2_tensor, buffers)
change_activation_tensor_8to16(output_tensor, buffers)


def requantize_unidirectional_sequence_lstm(tensors, buffers, op):
"""Requantize the unidirectonal sequance lstm op from int8 to int16 """
input_tensor = tensors[op.inputs[0]]
Expand Down
Loading