From 7554eb09653008afe3a15900ded7adfb994a59c7 Mon Sep 17 00:00:00 2001
From: Ted Themistokleous
Date: Sat, 4 Jan 2025 16:08:59 +0000
Subject: [PATCH 1/6] Update tests to use fp8

---
 .../resnet50/e2e_migraphx_resnet_example.py   | 95 ++++++++++++++-----
 .../migraphx/e2e_migraphx_bert_example.py     | 18 +++-
 2 files changed, 88 insertions(+), 25 deletions(-)

diff --git a/quantization/image_classification/migraphx/resnet50/e2e_migraphx_resnet_example.py b/quantization/image_classification/migraphx/resnet50/e2e_migraphx_resnet_example.py
index 52dd29802..fc91fcb67 100644
--- a/quantization/image_classification/migraphx/resnet50/e2e_migraphx_resnet_example.py
+++ b/quantization/image_classification/migraphx/resnet50/e2e_migraphx_resnet_example.py
@@ -13,6 +13,13 @@
 def parse_input_args():
     parser = argparse.ArgumentParser()
 
+    parser.add_argument(
+        "--model",
+        required=False,
+        default='./resnet50-v2-7.onnx',
+        help='Target model file. Default is ./resnet50-v2-7.onnx',
+    )
+
     parser.add_argument(
         "--fp16",
         action="store_true",
@@ -29,6 +36,14 @@
         help='Perform no quantization',
     )
 
+    parser.add_argument(
+        "--fp8",
+        action="store_true",
+        required=False,
+        default=False,
+        help='Perform fp8 quantization instead of int8',
+    )
+
     parser.add_argument(
         "--image_dir",
         required=False,
@@ -48,6 +63,29 @@
         help='Size of images for calibration',
         type=int)
 
+    parser.add_argument(
+        "--exhaustive_tune",
+        action="store_true",
+        required=False,
+        default=False,
+        help='Enable MIGraphX Exhaustive tune before compile. Default False',
+    )
+
+    parser.add_argument(
+        "--cache",
+        action="store_true",
+        required=False,
+        default=True,
+        help='cache the compiled model between runs. Saves quantization and compile time. Default true',
+    )
+
+    parser.add_argument(
+        "--cache_name",
+        required=False,
+        default="./cached_model.mxr",
+        help='Name and path of the compiled model cache. 
Default: ./cached_model.mxr', + ) + return parser.parse_args() class ImageNetDataReader(CalibrationDataReader): @@ -255,6 +293,7 @@ class ImageClassificationEvaluator: def __init__(self, model_path, synset_id, + flags, data_reader: CalibrationDataReader, providers=["MIGraphXExecutionProvider"]): ''' @@ -276,10 +315,21 @@ def get_result(self): def predict(self): sess_options = onnxruntime.SessionOptions() - sess_options.log_severity_level = 0 - sess_options.log_verbosity_level = 0 + sess_options.log_severity_level = 2 + sess_options.log_verbosity_level = 2 sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL - session = onnxruntime.InferenceSession(self.model_path, sess_options=sess_options, providers=self.providers) + session = onnxruntime.InferenceSession(self.model_path, sess_options=sess_options, + providers=[("MIGraphXExecutionProvider", + {"migraphx_fp8_enable": flags.fp8 and not flags.fp32, + "migraphx_int8_enable": not (flags.fp8 or flags.fp32), + "migraphx_fp16_enable": flags.fp16 and not flags.fp32, + "migraphx_int8_calibration_table_name": flags.calibration_table, + "migraphx_use_native_calibration_table": flags.native_calibration_table, + "migraphx_save_compiled_model": flags.cache, + "migraphx_save_model_path": flags.cache_name, + "migraphx_load_compiled_model": flags.cache, + "migraphx_load_model_path": flags.cache_name, + "migraphx_exhaustive_tune": flags.exhaustive_tune})]) inference_outputs_list = [] while True: @@ -362,21 +412,31 @@ def get_dataset_size(dataset_path, calibration_dataset_size): flags = parse_input_args() # Dataset settings - model_path = "./resnet50-v2-7.onnx" + model_path = flags.model ilsvrc2012_dataset_path = flags.image_dir augmented_model_path = "./augmented_model.onnx" batch_size = flags.batch calibration_dataset_size = 0 if flags.fp32 else flags.cal_size # Size of dataset for calibration + precision="" + + if not (flags.fp8 or flags.fp32): + precision = precision + "_int8" + + if flags.fp8 and not flags.fp32: + precision = precision + "_fp8" + + if flags.fp16 and not flags.fp32: + precision = "_fp16" + precision + calibration_table_generation_enable = False if not flags.fp32: - # INT8 calibration setting calibration_table_generation_enable = True # Enable/Disable INT8 calibration - - # MIGraphX EP INT8 settings - os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "1" # Enable INT8 precision - os.environ["ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"] = "calibration.flatbuffers" # Calibration table name - os.environ["ORT_MIGRAPHX_INT8_NATIVE_CALIBRATION_TABLE"] = "0" # Calibration table name + flags.calibration_table = "calibration_cal"+ str(flags.cal_size) + precision + ".flatbuffers" + flags.native_calibration_table = "False" + if os.path.isfile("./" + flags.calibration_table): + calibration_table_generation = False + print("Found previous calibration: " + flags.calibration_table + "Skipping generating table") execution_provider = ["MIGraphXExecutionProvider"] @@ -406,15 +466,11 @@ def get_dataset_size(dataset_path, calibration_dataset_size): for keys, values in cal_tensors.data.items(): serial_cal_tensors[keys] = [float(x[0]) for x in values.range_value] - print("Writing calibration table") + print("Writing calibration table to:" + flags.calibration_table) write_calibration_table(serial_cal_tensors) + os.rename("./calibration.flatbuffers", flags.calibration_table) print("Write complete") - if flags.fp16: - os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "1" - else: - os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "0" - # Run prediction 
in MIGraphX EP
     data_reader = ImageNetDataReader(ilsvrc2012_dataset_path,
                                      start_index=calibration_dataset_size,
@@ -427,14 +483,9 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
     synset_id = data_reader.get_synset_id(ilsvrc2012_dataset_path, calibration_dataset_size,
                                           prediction_dataset_size)  # Generate synset id
     print("Prepping Evalulator")
-    evaluator = ImageClassificationEvaluator(new_model_path, synset_id, data_reader, providers=execution_provider)
+    evaluator = ImageClassificationEvaluator(new_model_path, synset_id, flags, data_reader, providers=execution_provider)
     print("Performing Predictions")
     evaluator.predict()
     print("Read out answer")
     result = evaluator.get_result()
     evaluator.evaluate(result)
-
-    #Set OS flags to off to ensure we don't interfere with other test runs
-
-    os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "0"
-    os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "0"
diff --git a/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py b/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py
index cd891ff94..d8c061c0b 100644
--- a/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py
+++ b/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py
@@ -297,6 +297,14 @@ def parse_input_args():
         help='Perform int8 quantization on the model before running inference',
     )
 
+    parser.add_argument(
+        "--fp8",
+        action="store_true",
+        required=False,
+        default=False,
+        help='Perform fp8 quantization on the model before running inference',
+    )
+
     parser.add_argument(
         "--ep",
         action="store",
@@ -434,7 +442,7 @@ def output_run_config(flags, samples):
     print ("filename:" + flags.model)
     print ("Samples: " + str(samples) + " Batch size: " + str(flags.batch))
     print ("Sequence length: " + str(flags.seq_len))
-    print ("Model Quantization: fp16:" + str(flags.fp16) + " int8:" + str(flags.int8))
+    print ("Model Quantization: fp16:" + str(flags.fp16) + " int8:" + str(flags.int8) + "fp8:" + str(flags.fp8))
     if flags.int8:
         if flags.ort_quant:
             print ("Quantizer: Onnxruntime")
@@ -525,7 +533,7 @@ def output_run_config(flags, samples):
 
     model_quants = ""
 
-    if flags.int8:
+    if flags.int8 or flags.fp8:
         model = onnx.load_model(model_path)
 
         # Generate INT8 calibration cache
@@ -585,12 +593,16 @@ def output_run_config(flags, samples):
         qdq_model_path = model_path
         print("Int8 Quantization Done with " + cal_ep)
         #Quantize with MIGraphX's INT8 quantizer instead
-        os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "1" # Enable MIGRAPHX INT8 precision
+        if flags.int8:
+            os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "1" # Enable MIGRAPHX INT8 precision
+        else:
+            os.environ["ORT_MIGRAPHX_FP8_ENABLE"] = "1" # Enable MIGRAPHX INT8 precision
         os.environ["ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"] = "calibration.flatbuffers" # Calibration table name
         os.environ["ORT_MIGRAPHX_INT8_NATIVE_CALIBRATION_TABLE"] = "0" # Calibration table name
 else:
     qdq_model_path = model_path
     os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "0" # Disable MIGRAPHX INT8 precision
+    os.environ["ORT_MIGRAPHX_FP8_ENABLE"] = "0" # Disable MIGRAPHX INT8 precision
 
 # No fp16 cal needed, MIGraphX will handle that through Onnxruntime & MIGraphX Execution Provider during compile
 if flags.fp16:

From e61da258025264751b508e72e33bd9b7ac20a7e3 Mon Sep 17 00:00:00 2001
From: Ted Themistokleous
Date: Thu, 6 Mar 2025 14:56:24 -0600
Subject: [PATCH 2/6] Update bert model for fp8

---
 .../migraphx/e2e_migraphx_bert_example.py | 127 ++++++++++--------
 1 file changed, 74 insertions(+), 53 deletions(-)

diff --git a/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py 
b/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py index d8c061c0b..fe718d627 100644 --- a/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py +++ b/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py @@ -323,6 +323,15 @@ def parse_input_args(): help='The desired execution provider [MIGraphX, ROCm, CPU] for int8 quantization; Default is MIGraphX', ) + parser.add_argument( + "--calibration_table", + action="store", + required=False, + default="bert_calibration_table_100_int8.flatbuffers", + type=str, + help='use a previously created calibration table" default is bert_calibration_table_100_int8.flatbuffers', + ) + parser.add_argument( "--model", action="store", @@ -533,42 +542,58 @@ def output_run_config(flags, samples): model_quants = "" - if flags.int8 or flags.fp8: - model = onnx.load_model(model_path) - - # Generate INT8 calibration cache - print("Calibration data compute starts with " + str(cal_ep)) - calibrator = create_calibrator(model_path, op_types_to_quantize, augmented_model_path=augmented_model_path, calibrate_method=CalibrationMethod.Percentile) - calibrator.set_execution_providers([cal_ep]) - - ''' - We can use one data reader to do data pre-processing, however, - some machines don't have sufficient memory to hold all dataset and all intermediate output, - especially using 'Entropy' or 'Percentile' calibrator which collects histogram for tensors. - So let multiple data readers to handle different stride of dataset to avoid OOM. - ''' - stride = 10 - #for i in range(0, calib_num, stride): - data_reader = BertDataReader(model_path, input_dataset, input_tokens, batch_size, sequence_lengths[-1], flags.query_len, doc_stride[-1], start_index=0, end_index=calib_num) - calibrator.collect_data(data_reader) + if flags.int8 and flags.fp8: + print("INT8 and FP8 quantization is mutually exclusive for calibration") + exit() - compute_range = calibrator.compute_data() + precision="" + if flags.int8: + precision = precision + "_int8" - # ORT returns data as return TensorsData(cal, self.collector.compute_collection_result()) - # Need to fix this for serialization but also convert values to float from float32 in order for JSON to correctly - # write out calibration table - json_compute_range = {} - for k, v in compute_range.data.items(): - json_compute_range[k] = (float(v.range_value[0]), float(v.range_value[1])) + if flags.fp8 : + precision = precision + "_fp8" + if flags.int8 or flags.fp8: + model = onnx.load_model(model_path) + native_calibration_table = "False" - write_calibration_table(json_compute_range) - print("Calibration is done. Calibration cache is saved to calibration.json") + calibration_table = "bert_calibration_table_"+ str(flags.cal_num) + precision + ".flatbuffers" - model_quants = model_quants + "_int8" + if os.path.isfile("./" + calibration_table): + print("Found previous calibration: " + flags.calibration_table + "Skipping generating table") + else: + # Generate INT8 calibration cache + print("Calibration data compute starts with " + str(cal_ep)) + calibrator = create_calibrator(model_path, op_types_to_quantize, augmented_model_path=augmented_model_path, calibrate_method=CalibrationMethod.Percentile) + calibrator.set_execution_providers([cal_ep]) + + ''' + We can use one data reader to do data pre-processing, however, + some machines don't have sufficient memory to hold all dataset and all intermediate output, + especially using 'Entropy' or 'Percentile' calibrator which collects histogram for tensors. 
+ So let multiple data readers to handle different stride of dataset to avoid OOM. + ''' + stride = 10 + #for i in range(0, calib_num, stride): + data_reader = BertDataReader(model_path, input_dataset, input_tokens, batch_size, sequence_lengths[-1], flags.query_len, doc_stride[-1], start_index=0, end_index=calib_num) + calibrator.collect_data(data_reader) + + compute_range = calibrator.compute_data() + + # ORT returns data as return TensorsData(cal, self.collector.compute_collection_result()) + # Need to fix this for serialization but also convert values to float from float32 in order for JSON to correctly + # write out calibration table + json_compute_range = {} + for k, v in compute_range.data.items(): + json_compute_range[k] = [float(x[0]) for x in values.range_value] + + write_calibration_table(json_compute_range) + print("Calibration is done. Calibration cache is saved to calibration.json") + + model_quants = model_quants + precision if flags.ort_quant: - print("Int8 Quantization Done with Onnxruntime Quantizer") + print(precision + " Quantization Done with Onnxruntime Quantizer") mode = QuantizationMode.QLinearOps # In TRT, it recommended to add QDQ pair to inputs of Add node followed by ReduceMean node. # Mirroring here what TRT does in MIGraphX Quantization to be able to perform an apples to apples comparison @@ -591,44 +616,40 @@ def output_run_config(flags, samples): print("QDQ model is saved to ", qdq_model_path) else: qdq_model_path = model_path - print("Int8 Quantization Done with " + cal_ep) - #Quantize with MIGraphX's INT8 quantizer instead - if flags.int8: - os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "1" # Enable MIGRAPHX INT8 precision - else: - os.environ["ORT_MIGRAPHX_FP8_ENABLE"] = "1" # Enable MIGRAPHX INT8 precision - os.environ["ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"] = "calibration.flatbuffers" # Calibration table name - os.environ["ORT_MIGRAPHX_INT8_NATIVE_CALIBRATION_TABLE"] = "0" # Calibration table name + print(precision + " Quantization Done with " + cal_ep) + #Quantize with MIGraphX's INT8/FP8 quantizer instead else: qdq_model_path = model_path - os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "0" # Disable MIGRAPHX INT8 precision - os.environ["ORT_MIGRAPHX_FP8_ENABLE"] = "0" # Disable MIGRAPHX INT8 precision # No fp16 cal needed, MIGraphX will handle that through Onnxruntime & MIGraphX Execution Provider during compile if flags.fp16: - os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "1" # Enable MIGRAPHX FP16 precision model_quants = model_quants + "_fp16" - else: - os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "0" # Disable MIGRAPHX FP16 precision if flags.save_load: model_name = str(qdq_model_path) + "_s" + str(flags.seq_len) + "_b" + str(flags.batch) + str(model_quants) + ".mxr" print("save load model from " + str(model_name)) - os.environ["ORT_MIGRAPHX_SAVE_COMPILED_MODEL"] = "1" - os.environ["ORT_MIGRAPHX_LOAD_COMPILED_MODEL"] = "1" - os.environ["ORT_MIGRAPHX_SAVE_COMPILE_PATH"] = model_name - os.environ["ORT_MIGRAPHX_LOAD_COMPILE_PATH"] = model_name - - # QDQ model inference and get SQUAD prediction - batch_size = flags.batch - data_reader = BertDataReader(qdq_model_path, input_dataset, input_tokens, batch_size, sequence_lengths[-1], flags.query_len, doc_stride[-1], end_index=samples) - sess_options = onnxruntime.SessionOptions() - if flags.ort_verbose: + + # QDQ model inference and get SQUAD prediction + batch_size = flags.batch + data_reader = BertDataReader(qdq_model_path, input_dataset, input_tokens, batch_size, sequence_lengths[-1], flags.query_len, doc_stride[-1], 
end_index=samples) + sess_options = onnxruntime.SessionOptions() + if flags.ort_verbose: sess_options.log_severity_level = 0 sess_options.log_verbosity_level = 0 sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL - ort_session = onnxruntime.InferenceSession(qdq_model_path, sess_options=sess_options, providers=[ep]) + ort_session = onnxruntime.InferenceSession(qdq_model_path, sess_options=sess_options, + providers=[("MIGraphXExecutionProvider", + {"migraphx_fp8_enable": flags.fp8 and not flags.fp32, + "migraphx_int8_enable": not (flags.fp8 or flags.fp32), + "migraphx_fp16_enable": flags.fp16 and not flags.fp32, + "migraphx_int8_calibration_table_name": calibration_table, + "migraphx_use_native_calibration_table": native_calibration_table, + "migraphx_save_compiled_model": flags.save_load, + "migraphx_save_model_path": model_name, + "migraphx_load_compiled_model": flags.save_load, + "migraphx_load_model_path": model_name, + "migraphx_exhaustive_tune": flags.exhaustive_tune})]) print("Running Inferences") latency = [] #Used for timing information From cc2f00ee53537a5a5a0e67cab2df134db2a81dbe Mon Sep 17 00:00:00 2001 From: Ahsan Saghir Date: Tue, 25 Mar 2025 16:54:55 -0500 Subject: [PATCH 3/6] Fix 'tuple' object has no attribute 'to_dict' for bert --- .../migraphx/e2e_migraphx_bert_example.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py b/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py index fe718d627..2a516dad3 100644 --- a/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py +++ b/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py @@ -580,17 +580,21 @@ def output_run_config(flags, samples): compute_range = calibrator.compute_data() - # ORT returns data as return TensorsData(cal, self.collector.compute_collection_result()) - # Need to fix this for serialization but also convert values to float from float32 in order for JSON to correctly - # write out calibration table - json_compute_range = {} + print("Writing calibration table") + try: + write_calibration_table(json_compute_range) + except AttributeError as e: + calibration_table = {} for k, v in compute_range.data.items(): - json_compute_range[k] = [float(x[0]) for x in values.range_value] + min_val = float(v.range_value[0]) if hasattr(v.range_value[0], 'item') else float(v.range_value[0]) + max_val = float(v.range_value[1]) if hasattr(v.range_value[1], 'item') else float(v.range_value[1]) + calibration_table[k] = [min_val, max_val] - write_calibration_table(json_compute_range) - print("Calibration is done. Calibration cache is saved to calibration.json") + with open("calibration.flatbuffers", "w") as f: + json.dump(calibration_table, f) + print("Calibration is done. 
Calibration cache is saved to calibration.json") - model_quants = model_quants + precision + model_quants = model_quants + precision if flags.ort_quant: print(precision + " Quantization Done with Onnxruntime Quantizer") @@ -634,8 +638,8 @@ def output_run_config(flags, samples): data_reader = BertDataReader(qdq_model_path, input_dataset, input_tokens, batch_size, sequence_lengths[-1], flags.query_len, doc_stride[-1], end_index=samples) sess_options = onnxruntime.SessionOptions() if flags.ort_verbose: - sess_options.log_severity_level = 0 - sess_options.log_verbosity_level = 0 + sess_options.log_severity_level = 0 + sess_options.log_verbosity_level = 0 sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL ort_session = onnxruntime.InferenceSession(qdq_model_path, sess_options=sess_options, From a908e29c8d8c86cf33cfd5b657bfa8a8ed0194db Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 11 Apr 2025 15:52:45 +0000 Subject: [PATCH 4/6] Update script for batch 1 cal, reusing cal, and handle provider options better Seems to clean up and fix a few odd states we were getting into --- .../migraphx/e2e_migraphx_bert_example.py | 75 ++++++++++--------- 1 file changed, 39 insertions(+), 36 deletions(-) diff --git a/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py b/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py index 2a516dad3..b62468587 100644 --- a/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py +++ b/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py @@ -327,9 +327,9 @@ def parse_input_args(): "--calibration_table", action="store", required=False, - default="bert_calibration_table_100_int8.flatbuffers", + default="./bert_calibration_table_100_int8.flatbuffers", type=str, - help='use a previously created calibration table" default is bert_calibration_table_100_int8.flatbuffers', + help='use a previously created calibration table" default is ./bert_calibration_table_100_int8.flatbuffers', ) parser.add_argument( @@ -451,7 +451,7 @@ def output_run_config(flags, samples): print ("filename:" + flags.model) print ("Samples: " + str(samples) + " Batch size: " + str(flags.batch)) print ("Sequence length: " + str(flags.seq_len)) - print ("Model Quantization: fp16:" + str(flags.fp16) + " int8:" + str(flags.int8) + "fp8:" + str(flags.fp8)) + print ("Model Quantization: fp16:" + str(flags.fp16) + " int8:" + str(flags.int8) + " fp8:" + str(flags.fp8)) if flags.int8: if flags.ort_quant: print ("Quantizer: Onnxruntime") @@ -541,6 +541,7 @@ def output_run_config(flags, samples): samples = flags.batch model_quants = "" + provider_args = {} if flags.int8 and flags.fp8: print("INT8 and FP8 quantization is mutually exclusive for calibration") @@ -549,19 +550,23 @@ def output_run_config(flags, samples): precision="" if flags.int8: precision = precision + "_int8" + provider_args["migraphx_int8_enable"] = str(True) if flags.fp8 : precision = precision + "_fp8" + provider_args["migraphx_fp8_enable"] = str(True) if flags.int8 or flags.fp8: model = onnx.load_model(model_path) - native_calibration_table = "False" - calibration_table = "bert_calibration_table_"+ str(flags.cal_num) + precision + ".flatbuffers" - - if os.path.isfile("./" + calibration_table): + if os.path.isfile("./" + flags.calibration_table): print("Found previous calibration: " + flags.calibration_table + "Skipping generating table") + provider_args["migraphx_int8_calibration_table_name"] = str(flags.calibration_table) else: + calibration_table_name = 
"bert_calibration_table_"+ str(flags.cal_num) + precision + ".flatbuffers" + print("Unable to find " + flags.calibration_table + " Generating Table: " + calibration_table_name) + provider_args["migraphx_int8_calibration_table_name"] = calibration_table_name + # Generate INT8 calibration cache print("Calibration data compute starts with " + str(cal_ep)) calibrator = create_calibrator(model_path, op_types_to_quantize, augmented_model_path=augmented_model_path, calibrate_method=CalibrationMethod.Percentile) @@ -575,24 +580,26 @@ def output_run_config(flags, samples): ''' stride = 10 #for i in range(0, calib_num, stride): - data_reader = BertDataReader(model_path, input_dataset, input_tokens, batch_size, sequence_lengths[-1], flags.query_len, doc_stride[-1], start_index=0, end_index=calib_num) + data_reader = BertDataReader(model_path, input_dataset, input_tokens, 1, sequence_lengths[-1], flags.query_len, doc_stride[-1], start_index=0, end_index=calib_num) calibrator.collect_data(data_reader) compute_range = calibrator.compute_data() - print("Writing calibration table") - try: - write_calibration_table(json_compute_range) - except AttributeError as e: calibration_table = {} - for k, v in compute_range.data.items(): - min_val = float(v.range_value[0]) if hasattr(v.range_value[0], 'item') else float(v.range_value[0]) - max_val = float(v.range_value[1]) if hasattr(v.range_value[1], 'item') else float(v.range_value[1]) - calibration_table[k] = [min_val, max_val] - - with open("calibration.flatbuffers", "w") as f: + print("Writing calibration table") + try: + write_calibration_table(calibration_table) + except AttributeError as e: + calibration_table = {} + for k, v in compute_range.data.items(): + min_val = float(v.range_value[0]) if hasattr(v.range_value[0], 'item') else float(v.range_value[0]) + max_val = float(v.range_value[1]) if hasattr(v.range_value[1], 'item') else float(v.range_value[1]) + calibration_table[k] = [min_val, max_val] + + with open(flags.calibration_table, "w") as f: json.dump(calibration_table, f) - print("Calibration is done. Calibration cache is saved to calibration.json") + print("Calibration is done. 
Calibration cache is saved to " + calibration_table_name) + provider_args["migraphx_int8_calibration_table_name"] = calibration_table_name model_quants = model_quants + precision @@ -628,32 +635,28 @@ def output_run_config(flags, samples): # No fp16 cal needed, MIGraphX will handle that through Onnxruntime & MIGraphX Execution Provider during compile if flags.fp16: model_quants = model_quants + "_fp16" + provider_args["migraphx_fp16_enable"] = str(True) + model_name = "" if flags.save_load: model_name = str(qdq_model_path) + "_s" + str(flags.seq_len) + "_b" + str(flags.batch) + str(model_quants) + ".mxr" print("save load model from " + str(model_name)) + provider_args["migraphx_save_compiled_model"] = flags.save_load + provider_args["migraphx_load_compiled_model"] = flags.save_load + provider_args["migraphx_save_model_path"] = model_name + provider_args["migraphx_load_model_path"] = model_name # QDQ model inference and get SQUAD prediction - batch_size = flags.batch - data_reader = BertDataReader(qdq_model_path, input_dataset, input_tokens, batch_size, sequence_lengths[-1], flags.query_len, doc_stride[-1], end_index=samples) - sess_options = onnxruntime.SessionOptions() - if flags.ort_verbose: - sess_options.log_severity_level = 0 - sess_options.log_verbosity_level = 0 + batch_size = flags.batch + data_reader = BertDataReader(qdq_model_path, input_dataset, input_tokens, batch_size, sequence_lengths[-1], flags.query_len, doc_stride[-1], end_index=samples) + sess_options = onnxruntime.SessionOptions() + if flags.ort_verbose: + sess_options.log_severity_level = 0 + sess_options.log_verbosity_level = 0 sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL ort_session = onnxruntime.InferenceSession(qdq_model_path, sess_options=sess_options, - providers=[("MIGraphXExecutionProvider", - {"migraphx_fp8_enable": flags.fp8 and not flags.fp32, - "migraphx_int8_enable": not (flags.fp8 or flags.fp32), - "migraphx_fp16_enable": flags.fp16 and not flags.fp32, - "migraphx_int8_calibration_table_name": calibration_table, - "migraphx_use_native_calibration_table": native_calibration_table, - "migraphx_save_compiled_model": flags.save_load, - "migraphx_save_model_path": model_name, - "migraphx_load_compiled_model": flags.save_load, - "migraphx_load_model_path": model_name, - "migraphx_exhaustive_tune": flags.exhaustive_tune})]) + providers=[("MIGraphXExecutionProvider", provider_args)]) print("Running Inferences") latency = [] #Used for timing information From 2f53c137702fb3f3553da0f0f76936e23ff9de83 Mon Sep 17 00:00:00 2001 From: Ahsan Saghir Date: Tue, 25 Mar 2025 16:54:55 -0500 Subject: [PATCH 5/6] Fix 'tuple' object has no attribute 'to_dict' for bert Use custom_write_calibration_table for migraphx --- .../migraphx/e2e_migraphx_bert_example.py | 141 +++++++++++++++++- 1 file changed, 137 insertions(+), 4 deletions(-) diff --git a/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py b/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py index b62468587..2064420f1 100644 --- a/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py +++ b/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py @@ -277,6 +277,109 @@ def get_op_nodes_not_followed_by_specific_op(model, op1, op2): return not_selected_op1_nodes +def custom_write_calibration_table(calibration_cache, dir="."): + """ + Helper function to write calibration table to files. 
+ """ + + import json + import logging + import flatbuffers + import numpy as np + + import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue + import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable + from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData + + logging.info(f"calibration cache: {calibration_cache}") + + class MyEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, (TensorData, TensorsData)): + return obj.to_dict() + if isinstance(obj, TensorDataWrapper): + return obj.data_dict + if isinstance(obj, np.ndarray): + return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"} + if isinstance(obj, CalibrationMethod): + return {"CLS": obj.__class__.__name__, "value": str(obj)} + return json.JSONEncoder.default(self, obj) + + json_data = json.dumps(calibration_cache, cls=MyEncoder) + + with open(os.path.join(dir, "calibration.json"), "w") as file: + file.write(json_data) # use `json.loads` to do the reverse + + # Serialize data using FlatBuffers + zero = np.array(0) + builder = flatbuffers.Builder(1024) + key_value_list = [] + + for key in sorted(calibration_cache.keys()): + values = calibration_cache[key] + d_values = values.to_dict() + + highest = d_values.get("highest", zero) + lowest = d_values.get("lowest", zero) + + highest_val = highest.item() if hasattr(highest, "item") else float(highest) + lowest_val = lowest.item() if hasattr(lowest, "item") else float(lowest) + + floats = [float(highest_val), float(lowest_val)] + + value = str(max(floats)) + + flat_key = builder.CreateString(key) + flat_value = builder.CreateString(value) + + KeyValue.KeyValueStart(builder) + KeyValue.KeyValueAddKey(builder, flat_key) + KeyValue.KeyValueAddValue(builder, flat_value) + key_value = KeyValue.KeyValueEnd(builder) + + key_value_list.append(key_value) + + + TrtTable.TrtTableStartDictVector(builder, len(key_value_list)) + for key_value in key_value_list: + builder.PrependUOffsetTRelative(key_value) + main_dict = builder.EndVector() + + TrtTable.TrtTableStart(builder) + TrtTable.TrtTableAddDict(builder, main_dict) + cal_table = TrtTable.TrtTableEnd(builder) + + builder.Finish(cal_table) + buf = builder.Output() + + with open(os.path.join(dir, "calibration.flatbuffers"), "wb") as file: + file.write(buf) + + # Deserialize data (for validation) + if os.environ.get("QUANTIZATION_DEBUG", 0) in (1, "1"): + cal_table = TrtTable.TrtTable.GetRootAsTrtTable(buf, 0) + dict_len = cal_table.DictLength() + for i in range(dict_len): + key_value = cal_table.Dict(i) + logging.info(key_value.Key()) + logging.info(key_value.Value()) + + # write plain text + with open(os.path.join(dir, "calibration.cache"), "w") as file: + for key in sorted(calibration_cache.keys()): + values = calibration_cache[key] + d_values = values.to_dict() + highest = d_values.get("highest", zero) + lowest = d_values.get("lowest", zero) + + highest_val = highest.item() if hasattr(highest, "item") else float(highest) + lowest_val = lowest.item() if hasattr(lowest, "item") else float(lowest) + + floats = [float(highest_val), float(lowest_val)] + + value = key + " " + str(max(floats)) + file.write(value) + file.write("\n") def parse_input_args(): parser = argparse.ArgumentParser() @@ -585,22 +688,52 @@ def output_run_config(flags, samples): compute_range = calibrator.compute_data() + calibration_table = {} print("Writing calibration table") try: write_calibration_table(calibration_table) except AttributeError as e: - calibration_table = {} 
+ class TensorDataWrapper: + def __init__(self, data_dict): + self.data_dict = data_dict + + def to_dict(self): + return self.data_dict + + def __repr__(self): + return repr(self.data_dict) + + def __serializable__(self): + return self.data_dict + + calibration_data = {} for k, v in compute_range.data.items(): - min_val = float(v.range_value[0]) if hasattr(v.range_value[0], 'item') else float(v.range_value[0]) - max_val = float(v.range_value[1]) if hasattr(v.range_value[1], 'item') else float(v.range_value[1]) - calibration_table[k] = [min_val, max_val] + if hasattr(v, 'to_dict'): + tensor_dict = v.to_dict() + processed_dict = {} + for dk, dv in tensor_dict.items(): + if isinstance(dv, np.ndarray): + processed_dict[dk] = dv.item() if dv.size == 1 else dv.tolist() + elif isinstance(dv, np.number): + processed_dict[dk] = dv.item() + else: + processed_dict[dk] = dv + calibration_data[k] = TensorDataWrapper(processed_dict) + else: + calibration_data[k] = v + + print("Using custom calibration table function") + custom_write_calibration_table(calibration_data) + + print("Calibration is done. Calibration cache is saved to calibration.json") with open(flags.calibration_table, "w") as f: json.dump(calibration_table, f) print("Calibration is done. Calibration cache is saved to " + calibration_table_name) provider_args["migraphx_int8_calibration_table_name"] = calibration_table_name + model_quants = model_quants + precision if flags.ort_quant: From f139fe8b2b2e0bf7aa62560868c6d1a6a540cfc5 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Sun, 13 Apr 2025 00:23:47 +0000 Subject: [PATCH 6/6] Use updated data and flag for model save loading --- .../migraphx/e2e_migraphx_bert_example.py | 86 +++++++++---------- 1 file changed, 39 insertions(+), 47 deletions(-) diff --git a/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py b/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py index 2064420f1..11df2761e 100644 --- a/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py +++ b/quantization/nlp/bert/migraphx/e2e_migraphx_bert_example.py @@ -277,7 +277,7 @@ def get_op_nodes_not_followed_by_specific_op(model, op1, op2): return not_selected_op1_nodes -def custom_write_calibration_table(calibration_cache, dir="."): +def custom_write_calibration_table(calibration_cache, filename): """ Helper function to write calibration table to files. 
""" @@ -307,7 +307,7 @@ def default(self, obj): json_data = json.dumps(calibration_cache, cls=MyEncoder) - with open(os.path.join(dir, "calibration.json"), "w") as file: + with open(filename, "w") as file: file.write(json_data) # use `json.loads` to do the reverse # Serialize data using FlatBuffers @@ -352,7 +352,7 @@ def default(self, obj): builder.Finish(cal_table) buf = builder.Output() - with open(os.path.join(dir, "calibration.flatbuffers"), "wb") as file: + with open(filename, "wb") as file: file.write(buf) # Deserialize data (for validation) @@ -365,7 +365,7 @@ def default(self, obj): logging.info(key_value.Value()) # write plain text - with open(os.path.join(dir, "calibration.cache"), "w") as file: + with open(filename + ".cache", "w") as file: for key in sorted(calibration_cache.keys()): values = calibration_cache[key] d_values = values.to_dict() @@ -430,9 +430,9 @@ def parse_input_args(): "--calibration_table", action="store", required=False, - default="./bert_calibration_table_100_int8.flatbuffers", + default="bert_calibration_table_100_int8.flatbuffers", type=str, - help='use a previously created calibration table" default is ./bert_calibration_table_100_int8.flatbuffers', + help='use a previously created calibration table" default is bert_calibration_table_100_int8.flatbuffers', ) parser.add_argument( @@ -661,7 +661,7 @@ def output_run_config(flags, samples): if flags.int8 or flags.fp8: model = onnx.load_model(model_path) - + provider_args["migraphx_int8_calibration_table_name"] = str(flags.calibration_table) if os.path.isfile("./" + flags.calibration_table): print("Found previous calibration: " + flags.calibration_table + "Skipping generating table") provider_args["migraphx_int8_calibration_table_name"] = str(flags.calibration_table) @@ -691,45 +691,40 @@ def output_run_config(flags, samples): calibration_table = {} print("Writing calibration table") - try: - write_calibration_table(calibration_table) - except AttributeError as e: - class TensorDataWrapper: - def __init__(self, data_dict): - self.data_dict = data_dict - - def to_dict(self): - return self.data_dict - - def __repr__(self): - return repr(self.data_dict) - - def __serializable__(self): - return self.data_dict - - calibration_data = {} - for k, v in compute_range.data.items(): - if hasattr(v, 'to_dict'): - tensor_dict = v.to_dict() - processed_dict = {} - for dk, dv in tensor_dict.items(): - if isinstance(dv, np.ndarray): - processed_dict[dk] = dv.item() if dv.size == 1 else dv.tolist() - elif isinstance(dv, np.number): - processed_dict[dk] = dv.item() - else: - processed_dict[dk] = dv - calibration_data[k] = TensorDataWrapper(processed_dict) - else: - calibration_data[k] = v - - print("Using custom calibration table function") - custom_write_calibration_table(calibration_data) + class TensorDataWrapper: + def __init__(self, data_dict): + self.data_dict = data_dict + + def to_dict(self): + return self.data_dict + + def __repr__(self): + return repr(self.data_dict) + + def __serializable__(self): + return self.data_dict + + calibration_data = {} + for k, v in compute_range.data.items(): + if hasattr(v, 'to_dict'): + tensor_dict = v.to_dict() + processed_dict = {} + for dk, dv in tensor_dict.items(): + if isinstance(dv, np.ndarray): + processed_dict[dk] = dv.item() if dv.size == 1 else dv.tolist() + elif isinstance(dv, np.number): + processed_dict[dk] = dv.item() + else: + processed_dict[dk] = dv + calibration_data[k] = TensorDataWrapper(processed_dict) + else: + calibration_data[k] = v + + print("Using custom 
calibration table function") + custom_write_calibration_table(calibration_data, calibration_table_name) print("Calibration is done. Calibration cache is saved to calibration.json") - with open(flags.calibration_table, "w") as f: - json.dump(calibration_table, f) print("Calibration is done. Calibration cache is saved to " + calibration_table_name) provider_args["migraphx_int8_calibration_table_name"] = calibration_table_name @@ -774,10 +769,7 @@ def __serializable__(self): if flags.save_load: model_name = str(qdq_model_path) + "_s" + str(flags.seq_len) + "_b" + str(flags.batch) + str(model_quants) + ".mxr" print("save load model from " + str(model_name)) - provider_args["migraphx_save_compiled_model"] = flags.save_load - provider_args["migraphx_load_compiled_model"] = flags.save_load - provider_args["migraphx_save_model_path"] = model_name - provider_args["migraphx_load_model_path"] = model_name + provider_args["migraphx_model_cache_dir"] = model_name # QDQ model inference and get SQUAD prediction batch_size = flags.batch
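
For reference, the session setup that both example scripts converge on after this series can be reduced to the sketch below: quantization mode, calibration table, and compiled-model caching are passed as MIGraphXExecutionProvider provider options rather than through the old ORT_MIGRAPHX_* environment variables. This is a minimal illustration only; the model path, calibration table, and cache path are placeholders, and the option keys are the ones used in the patches above (patch 6 additionally replaces the save/load path pair with migraphx_model_cache_dir).

    # Minimal sketch (not part of the patches): drive the MIGraphX EP entirely
    # through provider options, as the updated examples do.
    import onnxruntime

    model_path = "./model.onnx"                       # placeholder
    calibration_table = "./calibration.flatbuffers"   # placeholder, output of the calibration step
    compiled_cache = "./cached_model.mxr"             # placeholder

    provider_options = {
        "migraphx_fp16_enable": True,
        "migraphx_int8_enable": True,     # int8 and fp8 are mutually exclusive in the scripts
        "migraphx_fp8_enable": False,
        "migraphx_int8_calibration_table_name": calibration_table,
        "migraphx_use_native_calibration_table": False,
        "migraphx_save_compiled_model": True,
        "migraphx_save_model_path": compiled_cache,
        "migraphx_load_compiled_model": True,
        "migraphx_load_model_path": compiled_cache,
        "migraphx_exhaustive_tune": False,
    }

    sess_options = onnxruntime.SessionOptions()
    sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL

    session = onnxruntime.InferenceSession(
        model_path,
        sess_options=sess_options,
        providers=[("MIGraphXExecutionProvider", provider_options)],
    )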
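
The churn around write_calibration_table in patches 3 through 5 stems from calibrator.compute_data() returning a TensorsData object whose per-tensor entries expose range_value as numpy values rather than plain tuples. Below is a hedged sketch of the reduction both scripts perform before serializing; the helper name and output file are illustrative and are not the custom_write_calibration_table added in patch 5.

    # Illustrative helper (names are assumptions, not taken from the patches):
    # flatten the calibrator output into JSON-serializable [min, max] float pairs.
    import json
    import numpy as np

    def flatten_ranges(compute_range):
        """compute_range is the TensorsData returned by calibrator.compute_data().

        range_value entries may be Python floats, numpy scalars, or single-element
        arrays depending on the onnxruntime version, so normalize via numpy.
        """
        flat = {}
        for name, tensor_data in compute_range.data.items():
            lo, hi = tensor_data.range_value
            flat[name] = [float(np.asarray(lo).ravel()[0]),
                          float(np.asarray(hi).ravel()[0])]
        return flat

    # Usage, assuming `calibrator` was built with create_calibrator(...) and fed
    # through collect_data() as in the examples above:
    #   ranges = flatten_ranges(calibrator.compute_data())
    #   with open("calibration_ranges.json", "w") as f:   # illustrative file name
    #       json.dump(ranges, f)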