Skip to content

Commit c1e71e4

Browse files
committed
.
1 parent dcb0ab5 commit c1e71e4

File tree

7 files changed

+128
-26
lines changed

7 files changed

+128
-26
lines changed

src/lighteval/config/lighteval_config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,11 @@ class LightEvalConfig:
101101
class FullNanotronConfig:
102102
lighteval_config: LightEvalConfig
103103
nanotron_config: "Config"
104+
105+
@property
106+
def generation_parameters(self):
107+
# Return the generation parameters from the lighteval config
108+
# or create default generation parameters if none are set
109+
if self.lighteval_config.generation:
110+
return self.lighteval_config.generation
111+
return GenerationArgs()

src/lighteval/main_nanotron.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -42,17 +42,17 @@ def nanotron(
4242
checkpoint_config_path: Annotated[
4343
str, Option(help="Path to the nanotron checkpoint YAML or python config file, potentially on s3.")
4444
],
45-
lighteval_config_path: Annotated[str, Option(help="Path to a YAML config to be used for the evaluation.")],
45+
lighteval_config_path: Annotated[str, Option(help="Path to a YAML config to be used for the evaluation.")] = None,
4646
cache_dir: Annotated[str, Option(help="Cache directory for datasets and models.")] = CACHE_DIR,
4747
):
4848
"""
4949
Evaluate models using nanotron as backend.
5050
"""
5151
from nanotron.config import Config, get_config_from_file
52+
from nanotron.config.parallelism_config import ParallelismArgs
5253

53-
from lighteval.config.lighteval_config import FullNanotronConfig, LightEvalConfig
54+
from lighteval.config.lighteval_config import FullNanotronConfig, LightEvalConfig, LightEvalLoggingArgs, LightEvalTasksArgs
5455
from lighteval.logging.evaluation_tracker import EvaluationTracker
55-
from lighteval.logging.hierarchical_logger import htrack_block
5656
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
5757
from lighteval.utils.imports import NO_NANOTRON_ERROR_MSG, is_nanotron_available
5858
from lighteval.utils.utils import EnvConfig
@@ -61,23 +61,38 @@ def nanotron(
6161

6262
if not is_nanotron_available():
6363
raise ImportError(NO_NANOTRON_ERROR_MSG)
64+
65+
# Create nanotron config
66+
if not checkpoint_config_path.endswith(".yaml"):
67+
raise ValueError("The checkpoint path should point to a YAML file")
68+
69+
model_config = get_config_from_file(
70+
checkpoint_config_path,
71+
config_class=Config,
72+
model_config_class=None,
73+
skip_unused_config_keys=True,
74+
skip_null_keys=True,
75+
)
6476

65-
with htrack_block("Load nanotron config"):
66-
# Create nanotron config
67-
if not checkpoint_config_path.endswith(".yaml"):
68-
raise ValueError("The checkpoint path should point to a YAML file")
69-
70-
model_config = get_config_from_file(
71-
checkpoint_config_path,
72-
config_class=Config,
73-
model_config_class=None,
74-
skip_unused_config_keys=True,
75-
skip_null_keys=True,
76-
)
77-
78-
# We are getting an type error, because the get_config_from_file is not correctly typed,
77+
# Create or use default lighteval config
78+
if lighteval_config_path is not None:
7979
lighteval_config: LightEvalConfig = get_config_from_file(lighteval_config_path, config_class=LightEvalConfig) # type: ignore
80-
nanotron_config = FullNanotronConfig(lighteval_config, model_config)
80+
else:
81+
# Create default config with minimal required parameters
82+
default_logging = LightEvalLoggingArgs(
83+
output_dir="./eval_results"
84+
)
85+
default_tasks = LightEvalTasksArgs(
86+
tasks="lighteval|agieval:aqua-rat|5|0"
87+
)
88+
default_parallelism = ParallelismArgs(dp=1, pp=1, tp=1)
89+
lighteval_config = LightEvalConfig(
90+
logging=default_logging,
91+
tasks=default_tasks,
92+
parallelism=default_parallelism
93+
)
94+
95+
nanotron_config = FullNanotronConfig(lighteval_config, model_config)
8196

8297
evaluation_tracker = EvaluationTracker(
8398
output_dir=lighteval_config.logging.output_dir,

src/lighteval/models/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# MIT License
2+
3+
# Copyright (c) 2024 The HuggingFace Team
4+
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
12+
# The above copyright notice and this permission notice shall be included in all
13+
# copies or substantial portions of the Software.
14+
15+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
# SOFTWARE.
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# MIT License
2+
3+
# Copyright (c) 2024 The HuggingFace Team
4+
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
12+
# The above copyright notice and this permission notice shall be included in all
13+
# copies or substantial portions of the Software.
14+
15+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
# SOFTWARE.

src/lighteval/models/nanotron/nanotron_model.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,14 @@ def tok_decode(self, tokens: torch.LongTensor) -> List[str]:
343343
return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)
344344

345345
def _model_call(self, inputs: torch.Tensor) -> torch.Tensor:
346-
return self.model(inputs)
346+
position_ids = (
347+
torch.arange(
348+
inputs.shape[1], device=inputs.device, dtype=torch.int32
349+
)
350+
.unsqueeze(0)
351+
.repeat(inputs.shape[0], 1)
352+
)
353+
return self.model(inputs, position_ids)
347354

348355
def homogeneize_ending_conditions(self, ending_condition: tuple | dict | list | str) -> tuple[list, int]:
349356
"""Ending conditions are submitted in several possible formats.
@@ -711,14 +718,14 @@ def _loglikelihood_single_token(
711718
inputs, padding_length=max_context, max_context=max_context, full_attention_masks=True
712719
)
713720
# batched_inputs, batch_attention, input_lengths, truncated, padded
714-
715-
out = self.model(input_ids=batch_model.input_ids, input_mask=batch_model.input_mask)
721+
position_ids = torch.arange(batch_model.input_ids.shape[1], device=self.device, dtype=torch.int32).unsqueeze(0).repeat(batch_model.input_ids.shape[0], 1)
722+
out = self.model(input_ids=batch_model.input_ids, position_ids=position_ids)
716723

717724
if dist.get_rank(self.parallel_context.pp_pg) == self.output_pp_rank:
718725
# This process got outputs
719726

720727
# Gather all the output accross TP
721-
out = out.transpose(0, 1).contiguous() # [batch, seq_length, vocab]
728+
out = out.view(*batch_model.input_ids.shape, -1).contiguous() # [batch, seq_length, vocab]
722729

723730
gathered_out = [torch.zeros_like(out) for _ in range(self.parallel_context.tp_pg.size())]
724731
dist.all_gather(gathered_out, out, group=self.parallel_context.tp_pg, async_op=False)
@@ -944,7 +951,8 @@ def _loglikelihood_tokens(
944951
)
945952
# batched_inputs, batch_attention, input_lengths, truncated, padded
946953
with torch.no_grad():
947-
out = self.model(input_ids=batch_model.input_ids, input_mask=batch_model.input_mask)
954+
position_ids = torch.arange(batch_model.input_ids.shape[1], device=self.device, dtype=torch.int32).unsqueeze(0).repeat(batch_model.input_ids.shape[0], 1)
955+
out = self.model(input_ids=batch_model.input_ids, position_ids=position_ids)
948956

949957
if dist.get_rank(self.parallel_context.pp_pg) == self.output_pp_rank:
950958
# This process got outputs
@@ -954,7 +962,7 @@ def _loglikelihood_tokens(
954962
dist.all_gather(gathered_out, out, group=self.parallel_context.tp_pg, async_op=False)
955963
out = torch.cat(gathered_out, dim=-1)
956964

957-
out = out.transpose(0, 1) # [batch, seq_length, vocab]
965+
out = out.view(*batch_model.input_ids.shape, -1) # [batch, seq_length, vocab]
958966
multi_logits = F.log_softmax(out, dim=-1) # [batch, padding_length, vocab]
959967

960968
logits_sum = []
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# MIT License
2+
3+
# Copyright (c) 2024 The HuggingFace Team
4+
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
12+
# The above copyright notice and this permission notice shall be included in all
13+
# copies or substantial portions of the Software.
14+
15+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
# SOFTWARE.
22+
23+
# Import and re-export the NanotronLightevalModel class from the nanotron module
24+
from lighteval.models.nanotron.nanotron_model import NanotronLightevalModel
25+
26+
__all__ = ["NanotronLightevalModel"]

src/lighteval/pipeline.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
from nanotron.parallel.context import ParallelContext
7373
from nanotron.utils import local_ranks_zero_first
7474

75-
from lighteval.models.nanotron_model import NanotronLightevalModel
75+
# from lighteval.models.nanotron import NanotronLightevalModel
7676

7777

7878
import logging
@@ -188,16 +188,19 @@ def _init_model(self, model_config, model):
188188
logger.info("--- LOADING MODEL ---")
189189
if model_config is not None:
190190
if self.parallel_context:
191+
from lighteval.models.nanotron_model import NanotronLightevalModel
192+
191193
return NanotronLightevalModel(
192194
checkpoint_path=os.path.dirname(self.pipeline_parameters.nanotron_checkpoint_path)
193195
if self.pipeline_parameters.nanotron_checkpoint_path
194196
else "",
195-
nanotron_config=self.model_config,
197+
nanotron_config=model_config,
196198
parallel_context=self.parallel_context,
197199
debug_one_layer_model=False,
198200
model_class=None,
199201
env_config=self.pipeline_parameters.env_config,
200202
)
203+
# return None
201204
else:
202205
return load_model(config=model_config, env_config=self.pipeline_parameters.env_config)
203206
if isinstance(model, TransformersModel):

0 commit comments

Comments
 (0)