Commit 7a3ad2f

move the registration of the perplexity eval function where it belongs
1 parent d4cdd98 commit 7a3ad2f

File tree

src/deepsparse/evaluation/evaluator.py
src/deepsparse/evaluation/utils.py
tests/deepsparse/evaluation/test_evaluator.py

3 files changed: +13 -7


src/deepsparse/evaluation/evaluator.py (-3)

@@ -16,9 +16,6 @@
 from typing import List, Optional, Union
 
 from deepsparse import Pipeline
-from deepsparse.evaluation.integrations.perplexity import (  # noqa
-    integration_eval as integration_eval_perplexity,
-)
 from deepsparse.evaluation.registry import EvaluationRegistry
 from deepsparse.evaluation.results import Result
 from deepsparse.evaluation.utils import create_pipeline
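
The import removed above existed purely for its side effect (hence the # noqa): importing deepsparse.evaluation.integrations.perplexity is what registers integration_eval with EvaluationRegistry. Since the evaluator presumably resolves integrations by name through the registry rather than by importing them directly, the module-level import can go once registration is triggered elsewhere. A hypothetical sketch of that lookup side of the pattern; REGISTRY and run_integration are illustrative names, not deepsparse internals:

# Illustrative only: name-based lookup keeps an evaluator decoupled from the
# concrete integration modules. REGISTRY and run_integration are made-up
# names for this sketch, not deepsparse internals.
from typing import Any, Callable, Dict

REGISTRY: Dict[str, Callable[..., Any]] = {}


def run_integration(name: str, **kwargs) -> Any:
    # Whoever populated the registry (e.g. by importing an integration
    # module for its side effect) is irrelevant at lookup time.
    try:
        eval_fn = REGISTRY[name]
    except KeyError:
        raise ValueError(f"unknown integration: {name!r}") from None
    return eval_fn(**kwargs)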

src/deepsparse/evaluation/utils.py (+5 -1)

@@ -42,10 +42,14 @@ def potentially_check_dependency_import(integration_name: str) -> bool:
     :return: True if the dependency is installed, False otherwise
     """
 
-    if integration_name.replace("_", "-") == LM_EVALUATION_HARNESS:
+    if integration_name == LM_EVALUATION_HARNESS:
         from deepsparse.evaluation.integrations import try_import_lm_evaluation_harness
 
         try_import_lm_evaluation_harness()
+    if integration_name == PERPLEXITY:
+        from deepsparse.evaluation.integrations.perplexity import (  # noqa F401
+            integration_eval,
+        )
 
     return True
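
This is where the registration now lives: the noqa-marked import is never used directly; it exists so that importing deepsparse.evaluation.integrations.perplexity executes the module and thereby registers integration_eval, and it now runs only when the caller actually asks for the perplexity integration. A minimal sketch of decorator-based registration as an import side effect, assuming the registry exposes a register decorator; SimpleRegistry and eval_perplexity are illustrative names, not deepsparse code:

# Minimal sketch: registration happens when the defining module is imported.
# SimpleRegistry and eval_perplexity are illustrative, not deepsparse APIs.
from typing import Any, Callable, Dict


class SimpleRegistry:
    _registry: Dict[str, Callable[..., Any]] = {}

    @classmethod
    def register(cls, name: str) -> Callable:
        def decorator(fn: Callable) -> Callable:
            cls._registry[name] = fn  # runs at import time of the defining module
            return fn

        return decorator

    @classmethod
    def get(cls, name: str) -> Callable[..., Any]:
        return cls._registry[name]


# In the integration module: merely importing the module makes "perplexity"
# resolvable, which is why utils.py only needs the bare noqa-marked import.
@SimpleRegistry.register("perplexity")
def eval_perplexity(model, datasets, limit=None):
    ...

Deferring the import this way keeps evaluator.py free of integration-specific imports, which appears to be what the commit message means by moving the registration where it belongs.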

tests/deepsparse/evaluation/test_evaluator.py (+8 -3)

@@ -115,20 +115,25 @@ def test_evaluate_pipeline_without_kv_cache(
     not try_import_lm_evaluation_harness(raise_error=False),
     reason="lm_evaluation_harness not installed",
 )
-def test_evaluation_llm_evaluation_harness_integration_name(
+def test_evaluation_llm_evaluation_harness(
     model_path,
-    datasets,
 ):
     assert evaluate(
         model=model_path,
         # testing only on hellaswag dataset
         # to avoid long running time
-        datasets=datasets[0],
+        datasets="hellaswag",
         limit=1,
         integration="lm_evaluation_harness",
     )
 
 
+def test_evaluation_perplexity(model_path):
+    assert evaluate(
+        model=model_path, datasets="openai_humaneval", limit=1, integration="perplexity"
+    )
+
+
 @pytest.mark.parametrize("type_serialization", ["json", "yaml"])
 @pytest.mark.skipif(
     tuple(map(int, sys.version.split(".")[:2])) < (3, 10),
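
The new test exercises the lazy registration end to end by requesting the perplexity integration through evaluate. A usage sketch mirroring it; the model path below is a placeholder (the real test takes it from a model_path fixture), and evaluate is assumed to be importable from deepsparse.evaluation.evaluator, the module touched in the first hunk:

# Mirrors test_evaluation_perplexity above. "path/to/model" is a placeholder
# for whatever the model_path fixture supplies in the real test suite.
from deepsparse.evaluation.evaluator import evaluate  # assumed import path

result = evaluate(
    model="path/to/model",  # placeholder, not an actual checkpoint
    datasets="openai_humaneval",
    limit=1,
    integration="perplexity",
)
print(result)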
