Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 641c830

Browse files
authored
[Bug-fix] Text classifiation config (#1545)
1 parent d823867 commit 641c830

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/sparseml/transformers/text_classification.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ def compute_metrics(p: EvalPrediction):
497497
label_list=label_list,
498498
model=model,
499499
num_labels=num_labels,
500+
config=config,
500501
)
501502
id_to_label = {id_: label for label, id_ in label_to_id.items()}
502503

@@ -754,7 +755,7 @@ def _get_tokenized_and_preprocessed_raw_datasets(
754755
# Some models have set the order of the labels to use, so let's make sure
755756
# we do use it
756757
label_to_id = _get_label_to_id(
757-
data_args, is_regression, label_list, model, num_labels
758+
data_args, is_regression, label_list, model, num_labels, config=config
758759
)
759760

760761
if label_to_id is not None:
@@ -842,15 +843,16 @@ def preprocess_function(examples):
842843
return tokenized_datasets, raw_datasets
843844

844845

845-
def _get_label_to_id(data_args, is_regression, label_list, model, num_labels):
846+
def _get_label_to_id(data_args, is_regression, label_list, model, num_labels, config):
846847
label_to_id = None
848+
config = model.config if model else config
847849
if (
848-
model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
850+
config.label2id != PretrainedConfig(num_labels=num_labels).label2id
849851
and data_args.task_name is not None
850852
and not is_regression
851853
):
852854
# Some have all caps in their config, some don't.
853-
label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
855+
label_name_to_id = {k.lower(): v for k, v in config.label2id.items()}
854856
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
855857
label_to_id = {
856858
i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)

0 commit comments

Comments
 (0)