@@ -5,6 +5,7 @@
 # license information.
 # --------------------------------------------------------------------------
 import abc
+import copy
 import itertools
 import os
 import uuid
@@ -21,6 +22,48 @@
 from .quant_utils import apply_plot, load_model_with_shape_infer, smooth_distribution
 
 
+def rel_entr(pk: np.ndarray, qk: np.ndarray) -> np.ndarray:
+    """
+    See https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.rel_entr.html#scipy.special.rel_entr.
+    Python implementation.
+    """
+    res = np.empty(pk.shape, dtype=pk.dtype)
+    res[:] = pk[:] * np.log(pk[:] / qk[:])
+    c1 = (pk > 0) & (qk > 0)
+    res[~c1] = np.inf  # pk > 0 and qk <= 0 -> inf; assign before the zero mask below
+    c2 = (pk == 0) & (qk >= 0)
+    res[c2] = 0  # scipy convention: rel_entr(0, qk) = 0 for qk >= 0
+    return res
+
+
+def entropy(
+    pk: np.ndarray,
+    qk: np.ndarray,
+    base: Optional[float] = None,
+    axis: int = 0,
+) -> np.ndarray:
+    """
+    Simplified version of entropy.
+    Source: https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.entropy.html.
+    This avoids taking a dependency on scipy just for this function.
+    """
+    assert base is None or base > 0, f"base={base} must be a positive number or `None`."
+    assert qk is not None, "qk is None"
+
+    pk = np.asarray(pk).astype(np.float32)
+    pk = 1.0 * pk / np.sum(pk, axis=axis, keepdims=True)
+
+    qk = np.asarray(qk).astype(np.float32)
+    pk, qk = np.broadcast_arrays(pk, qk)
+    qk = 1.0 * qk / np.sum(qk, axis=axis, keepdims=True)
+    vec = rel_entr(pk, qk)
+
+    s = np.sum(vec, axis=axis)
+    if base is not None:
+        s /= np.log(base)
+    return s.astype(pk.dtype)
+
+
 class TensorData:
     _allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges", "bins"])
     _floats = frozenset(["avg", "std", "lowest", "highest", "hist_edges"])
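A quick sanity check for the scipy-free pair above (not part of the change itself; it assumes scipy is available for comparison):

    import numpy as np
    from scipy.stats import entropy as scipy_entropy

    p = np.array([0.1, 0.4, 0.5], dtype=np.float32)
    q = np.array([0.2, 0.2, 0.6], dtype=np.float32)
    # Both should agree (about 0.1168 nats for these inputs); the local
    # version returns float32 while scipy returns float64.
    print(entropy(p, q))
    print(scipy_entropy(p, q))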
@@ -708,8 +751,8 @@ def collect_absolute_value(self, name_to_arr):
                 min_value = np.min(data_arr_np)
                 max_value = np.max(data_arr_np)
             else:
-                min_value = 0
-                max_value = 0
+                min_value = np.array(0, dtype=data_arr_np.dtype)
+                max_value = np.array(0, dtype=data_arr_np.dtype)
 
             data_arr_np = np.absolute(data_arr_np)  # only consider absolute value
 
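The `np.array(0, dtype=data_arr_np.dtype)` change (mirrored in `collect_value` below) keeps the tensor's dtype on the empty-input path; a bare Python `0` gets promoted later, typically to int64 or float64. A minimal illustration, independent of this file:

    import numpy as np

    data = np.zeros((0,), dtype=np.float16)  # empty activation batch
    untyped = np.array(0)                    # what a bare 0 becomes downstream
    typed = np.array(0, dtype=data.dtype)
    print(untyped.dtype)  # int64 (platform dependent), breaks dtype checks
    print(typed.dtype)    # float16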
@@ -725,6 +768,8 @@ def collect_absolute_value(self, name_to_arr):
                 old_histogram = self.histogram_dict[tensor]
                 old_min = old_histogram[2]
                 old_max = old_histogram[3]
+                assert hasattr(old_min, "dtype"), f"old_min should be a numpy array but is {type(old_min)}"
+                assert hasattr(old_max, "dtype"), f"old_max should be a numpy array but is {type(old_max)}"
                 old_hist = old_histogram[0]
                 old_hist_edges = old_histogram[1]
                 temp_amax = np.max(data_arr_np)
@@ -757,7 +802,7 @@ def collect_value(self, name_to_arr):
                 min_value = np.array(0, dtype=data_arr.dtype)
                 max_value = np.array(0, dtype=data_arr.dtype)
 
-            threshold = max(abs(min_value), abs(max_value))
+            threshold = np.array(max(abs(min_value), abs(max_value)), dtype=data_arr.dtype)
 
             if tensor in self.histogram_dict:
                 old_histogram = self.histogram_dict[tensor]
@@ -809,7 +854,7 @@ def merge_histogram(self, old_histogram, data_arr, new_min, new_max, new_threshold):
     def compute_collection_result(self):
         if not self.histogram_dict or len(self.histogram_dict) == 0:
             raise ValueError("Histogram has not been collected. Please run collect() first.")
-        print(f"Finding optimal threshold for each tensor using {self.method} algorithm ...")
+        print(f"Finding optimal threshold for each tensor using {self.method!r} algorithm ...")
 
         if self.method == "entropy":
            return self.compute_entropy()
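The `!r` only changes the log line: it formats the value with repr(), so the method name prints quoted and an unexpected value (e.g. None) is unambiguous:

    method = "entropy"
    print(f"using {method} algorithm ...")    # using entropy algorithm ...
    print(f"using {method!r} algorithm ...")  # using 'entropy' algorithm ...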
@@ -938,7 +983,14 @@ def compute_distribution(self):
             assert avg_coef.dtype != np.float64
             assert std_coef.dtype != np.float64
             assert hist_edges.dtype != np.float64
-            thresholds_dict[tensor] = TensorData(avg=avg_coef, std=std_coef, hist=hist, hist_edges=hist_edges)
+            thresholds_dict[tensor] = TensorData(
+                avg=avg_coef,
+                std=std_coef,
+                hist=hist,
+                hist_edges=hist_edges,
+                lowest=hist_edges.min(),
+                highest=hist_edges.max(),
+            )
 
             # Plot histogram for debug only
             if os.environ.get("QUANTIZATION_DEBUG", 0) in (1, "1"):
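Passing `lowest`/`highest` explicitly keeps them in the dtype of `hist_edges`, which the asserts above guarantee is not float64. A sketch of the invariant, assuming float16 calibration edges:

    import numpy as np

    hist_edges = np.linspace(-1.0, 1.0, 129).astype(np.float16)
    lowest, highest = hist_edges.min(), hist_edges.max()
    print(lowest.dtype, highest.dtype)  # float16 float16, no float64 upcast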
@@ -952,18 +1004,15 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):
         `q` is a truncated version of the original distribution.
         Ref: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
         """
-        import copy
-
-        from scipy.stats import entropy
-
         hist = histogram[0]
         hist_edges = histogram[1]
         num_bins = hist.size
         zero_bin_index = num_bins // 2
         num_half_quantized_bin = num_quantized_bins // 2
 
+        dtype = histogram[1].dtype
         kl_divergence = np.zeros(zero_bin_index - num_half_quantized_bin + 1)
-        thresholds = [(0, 0) for i in range(kl_divergence.size)]
+        thresholds = [(np.array(0, dtype=dtype), np.array(0, dtype=dtype)) for i in range(kl_divergence.size)]
 
         # <------------ num bins ---------------->
         # <--- quantized bins ---->
@@ -983,10 +1032,7 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):
             start_index = zero_bin_index - i
             end_index = zero_bin_index + i + 1 if (zero_bin_index + i + 1) <= num_bins else num_bins
 
-            thresholds[i - num_half_quantized_bin] = (
-                float(hist_edges[start_index]),
-                float(hist_edges[end_index]),
-            )
+            thresholds[i - num_half_quantized_bin] = (hist_edges[start_index], hist_edges[end_index])
 
             sliced_distribution = copy.deepcopy(hist[start_index:end_index])
 
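Dropping the `float(...)` casts is the same dtype story: `float()` always yields a 64-bit Python float, while indexing `hist_edges` keeps the numpy scalar and its dtype. For example:

    import numpy as np

    edges = np.array([0.1, 0.2], dtype=np.float16)
    print(float(edges[0]))  # 0.0999755859375, now a 64-bit Python float
    print(edges[0].dtype)   # float16, preserved without the cast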
@@ -1020,15 +1066,15 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):
 
                 norm = sum(nonzeros[start:end])
                 if norm != 0:
-                    q[start:end] = float(quantized_bins[index]) / float(norm)
+                    q[start:end] = quantized_bins[index] / norm
 
             p = smooth_distribution(p)
             q = smooth_distribution(q)
-
-            if isinstance(q, np.ndarray):
-                kl_divergence[i - num_half_quantized_bin] = entropy(p, q)
+            if p is None or q is None:
+                div = np.array(np.inf, dtype=dtype)
             else:
-                kl_divergence[i - num_half_quantized_bin] = float("inf")
+                div = np.array(entropy(p, q), dtype=dtype)
+            kl_divergence[i - num_half_quantized_bin] = div
 
         min_kl_divergence_idx = np.argmin(kl_divergence)
         optimal_threshold = thresholds[min_kl_divergence_idx]
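The `None` guard handles `smooth_distribution` returning `None` when a slice cannot be smoothed (the old `isinstance(q, np.ndarray)` test caught the same case indirectly); such a candidate threshold is discarded with an infinite divergence instead of raising. The local `entropy` is consistent with that convention:

    import numpy as np

    p = np.array([0.5, 0.5], dtype=np.float32)
    q = np.array([1.0, 0.0], dtype=np.float32)  # q empty where p has mass
    print(entropy(p, q))  # inf, via the rel_entr mask for pk > 0, qk == 0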
@@ -1038,6 +1084,8 @@ def get_entropy_threshold(self, histogram, num_quantized_bins):
             optimal_threshold = (min_value, optimal_threshold[1])
         if optimal_threshold[1] > max_value:
             optimal_threshold = (optimal_threshold[0], max_value)
+        assert hasattr(optimal_threshold[0], "dtype")
+        assert hasattr(optimal_threshold[1], "dtype")
         return optimal_threshold
 
 