@@ -497,6 +497,7 @@ def compute_metrics(p: EvalPrediction):
497
497
label_list = label_list ,
498
498
model = model ,
499
499
num_labels = num_labels ,
500
+ config = config ,
500
501
)
501
502
id_to_label = {id_ : label for label , id_ in label_to_id .items ()}
502
503
@@ -754,7 +755,7 @@ def _get_tokenized_and_preprocessed_raw_datasets(
754
755
# Some models have set the order of the labels to use, so let's make sure
755
756
# we do use it
756
757
label_to_id = _get_label_to_id (
757
- data_args , is_regression , label_list , model , num_labels
758
+ data_args , is_regression , label_list , model , num_labels , config = config
758
759
)
759
760
760
761
if label_to_id is not None :
@@ -842,15 +843,16 @@ def preprocess_function(examples):
842
843
return tokenized_datasets , raw_datasets
843
844
844
845
845
- def _get_label_to_id (data_args , is_regression , label_list , model , num_labels ):
846
+ def _get_label_to_id (data_args , is_regression , label_list , model , num_labels , config ):
846
847
label_to_id = None
848
+ config = model .config if model else config
847
849
if (
848
- model . config .label2id != PretrainedConfig (num_labels = num_labels ).label2id
850
+ config .label2id != PretrainedConfig (num_labels = num_labels ).label2id
849
851
and data_args .task_name is not None
850
852
and not is_regression
851
853
):
852
854
# Some have all caps in their config, some don't.
853
- label_name_to_id = {k .lower (): v for k , v in model . config .label2id .items ()}
855
+ label_name_to_id = {k .lower (): v for k , v in config .label2id .items ()}
854
856
if list (sorted (label_name_to_id .keys ())) == list (sorted (label_list )):
855
857
label_to_id = {
856
858
i : int (label_name_to_id [label_list [i ]]) for i in range (num_labels )
0 commit comments