def custom_write_calibration_table(calibration_cache, dir="."):
    """
    Helper function to write calibration table to files.

    Three files are written into ``dir``:

    * ``calibration.json``        - JSON dump of the whole cache.
    * ``calibration.flatbuffers`` - ``TrtTable`` of ``key -> value`` strings.
    * ``calibration.cache``       - plain text, one ``"<key> <value>"`` line
      per entry.

    For the last two files the value stored for a key is
    ``str(max(highest, lowest))``, where ``highest`` and ``lowest`` are read
    from the entry's ``to_dict()`` result and default to 0 when absent.

    Args:
        calibration_cache: Mapping of tensor name to an object exposing
            ``to_dict()`` (for example ``TensorData`` or a thin wrapper
            around a plain dict).
        dir: Directory the three files are written to. Defaults to the
            current working directory.
    """

    import json
    import logging
    import os

    import flatbuffers
    import numpy as np

    import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue
    import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable
    from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData

    logging.info("calibration cache: %s", calibration_cache)

    class MyEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, (TensorData, TensorsData)):
                return obj.to_dict()
            if isinstance(obj, np.ndarray):
                return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"}
            if isinstance(obj, CalibrationMethod):
                return {"CLS": obj.__class__.__name__, "value": str(obj)}
            # Wrapper classes created by callers (e.g. a class defined locally
            # inside another function) cannot be named from this scope, so
            # they are recognised by their ``to_dict`` method instead of by
            # an isinstance check that would raise NameError.
            to_dict = getattr(obj, "to_dict", None)
            if callable(to_dict):
                return to_dict()
            return json.JSONEncoder.default(self, obj)

    def _as_float(raw):
        # numpy scalars / 0-d arrays expose ``item()``; plain numbers do not.
        return float(raw.item() if hasattr(raw, "item") else raw)

    json_data = json.dumps(calibration_cache, cls=MyEncoder)

    with open(os.path.join(dir, "calibration.json"), "w") as file:
        file.write(json_data)  # use `json.loads` to do the reverse

    # Compute the (key, value) table once; both the FlatBuffers file and the
    # plain-text file are rendered from it, in sorted key order.
    zero = np.array(0)
    table = []
    for key in sorted(calibration_cache):
        d_values = calibration_cache[key].to_dict()
        floats = [
            _as_float(d_values.get("highest", zero)),
            _as_float(d_values.get("lowest", zero)),
        ]
        table.append((key, str(max(floats))))

    # Serialize data using FlatBuffers
    builder = flatbuffers.Builder(1024)
    key_value_list = []

    for key, value in table:
        # Strings must be created before the table that references them is
        # started.
        flat_key = builder.CreateString(key)
        flat_value = builder.CreateString(value)

        KeyValue.KeyValueStart(builder)
        KeyValue.KeyValueAddKey(builder, flat_key)
        KeyValue.KeyValueAddValue(builder, flat_value)
        key_value_list.append(KeyValue.KeyValueEnd(builder))

    TrtTable.TrtTableStartDictVector(builder, len(key_value_list))
    for key_value in key_value_list:
        builder.PrependUOffsetTRelative(key_value)
    main_dict = builder.EndVector()

    TrtTable.TrtTableStart(builder)
    TrtTable.TrtTableAddDict(builder, main_dict)
    cal_table = TrtTable.TrtTableEnd(builder)

    builder.Finish(cal_table)
    buf = builder.Output()

    with open(os.path.join(dir, "calibration.flatbuffers"), "wb") as file:
        file.write(buf)

    # Deserialize data (for validation)
    # Environment values are always strings, so only "1" can enable this.
    if os.environ.get("QUANTIZATION_DEBUG") == "1":
        cal_table = TrtTable.TrtTable.GetRootAsTrtTable(buf, 0)
        for i in range(cal_table.DictLength()):
            key_value = cal_table.Dict(i)
            logging.info(key_value.Key())
            logging.info(key_value.Value())

    # write plain text
    with open(os.path.join(dir, "calibration.cache"), "w") as file:
        file.writelines(f"{key} {value}\n" for key, value in table)
float(v.range_value[1])) + print("Writing calibration table") + try: + write_calibration_table(json_compute_range) + except AttributeError as e: + class TensorDataWrapper: + def __init__(self, data_dict): + self.data_dict = data_dict + + def to_dict(self): + return self.data_dict + + def __repr__(self): + return repr(self.data_dict) + + def __serializable__(self): + return self.data_dict + + calibration_data = {} + for k, v in compute_range.data.items(): + if hasattr(v, 'to_dict'): + tensor_dict = v.to_dict() + processed_dict = {} + for dk, dv in tensor_dict.items(): + if isinstance(dv, np.ndarray): + processed_dict[dk] = dv.item() if dv.size == 1 else dv.tolist() + elif isinstance(dv, np.number): + processed_dict[dk] = dv.item() + else: + processed_dict[dk] = dv + calibration_data[k] = TensorDataWrapper(processed_dict) + else: + calibration_data[k] = v + + print("Using custom calibration table function") + custom_write_calibration_table(calibration_data) - write_calibration_table(json_compute_range) print("Calibration is done. Calibration cache is saved to calibration.json") model_quants = model_quants + "_int8"