
Commit 704f0a1

Make prefix logic configurable and polish docs
1 parent b80a8b6 commit 704f0a1

5 files changed: +81, -13 lines changed

README.md (+23, -2)
@@ -99,14 +99,35 @@ This project explores two kinds of input for commit message completion task: dif
 2. Choose one of available model configs or add your own.
 3. Note that you have to define missing parameters from [`InputConfig`](conf/data/input_config.py). You can do it via CLI or just rewrite them. Below is the example how to define parameters via CLI.

-To launch training of model defined as `XXXModelConfig` and registered via `ConfigStore.store(name="XXX", group="model", node=XXXModelConfig)`, run the following command:
+To launch training of a model defined as `XXXModelConfig` and registered via `ConfigStore.store(name="XXX", group="model", node=XXXModelConfig)`, run the following command (with actual values instead of the X's):
 ```
 python train.py +model=XXX ++input.train_with_history=X ++input.encoder_input_type=X
 ```

 #### Additional steps for RACE model

-> :construction: Experiments with RACE model require slightly different procedure. It will be described in this section.
+Experiments with the RACE model require a slightly different procedure.
+
+1. Fine-tune the CodeT5 model. Refer to the instructions above for details.
+
+2. Use the encoder from the fine-tuned CodeT5 checkpoint to perform retrieval.
+
+   Define the configuration in [`conf/retrieval_config.py`](conf/retrieval_config.py). You have to either provide a local path to a checkpoint in `ckpt_path` or use a W&B artifact.
+   In the latter case, the artifact name will be inferred from the model configuration.
+
+   An example with a local path:
+   ```
+   python retrieve.py ++ckpt_path=<local_path>
+   ```
+   An example with a W&B artifact:
+   ```
+   python retrieve.py +model=codet5 ++input.train_with_history=X ++input.encoder_input_type=X
+   ```
+3. Initialize RACE with the fine-tuned CodeT5 weights and use the retrieved examples to train the model.
+   Refer to the instructions above for details.
+
+> :construction: Currently, downloading retrieved predictions and the fine-tuned CodeT5 checkpoint is only possible with W&B.

 ### Step 4: Evaluate

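For readers adding their own model config, a minimal sketch of what the `ConfigStore.store(...)` registration mentioned above might look like; the config class, its fields, and the chosen name below are illustrative stand-ins, not configs that actually exist in this repo:

```python
from dataclasses import dataclass

from hydra.core.config_store import ConfigStore


@dataclass
class MyModelConfig:  # hypothetical stand-in for `XXXModelConfig`
    name_or_path: str = "Salesforce/codet5-base"  # Hugging Face checkpoint to fine-tune
    learning_rate: float = 2e-5


# Register the config under the `model` group so it can be selected with `+model=my_model`.
cs = ConfigStore.instance()
cs.store(name="my_model", group="model", node=MyModelConfig)
```

With such a registration in place, training would be launched as `python train.py +model=my_model ++input.train_with_history=true ++input.encoder_input_type=diff`; the `input.*` values here are only examples of how the X's get filled in, the valid values are defined by the repo's `InputConfig`.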

compute_metrics.py (+41, -10)
@@ -18,14 +18,18 @@
 random.seed(42)


-def load_predictions(run: wandb.wandb_sdk.wandb_run.Run, cfg: MetricsConfig) -> str:
-    input_artifact = run.use_artifact(
+def load_predictions(cfg: MetricsConfig) -> str:
+    """Load predictions from W&B artifact.
+
+    Args:
+        cfg: Config; all information about artifact should be provided in corresponding fields there.
+
+    Returns:
+        Local path to downloaded predictions.
+    """
+    input_artifact = wandb.use_artifact(
         f"{cfg.logger.artifact_config.project}/{cfg.logger.artifact_config.name}:{cfg.logger.artifact_config.version}"
     )
-    if "tags" in input_artifact.metadata:
-        run.tags = ["new_prefix_logic"] + (
-            ["only_filtered" if cfg.filter.fit_filters else "only_unfiltered"] if cfg.filter.use_filtering else []
-        )

     input_artifact.get_path(cfg.logger.artifact_config.artifact_path).download(
         root=hydra.utils.to_absolute_path(
@@ -43,8 +47,24 @@ def load_predictions(run: wandb.wandb_sdk.wandb_run.Run, cfg: MetricsConfig) ->


 def add_single_example(
-    line: Dict[str, str], full_metrics: EvaluationMetrics, prefix_metrics: Dict[int, EvaluationMetrics]
+    line: Dict[str, str],
+    full_metrics: EvaluationMetrics,
+    prefix_metrics: Dict[int, EvaluationMetrics],
+    include_short: bool,
 ) -> None:
+    """Adds a single example to metrics.
+
+    * Compute the usual metrics between full prediction and full target.
+    * Compute the metrics between all prefixes of prediction and target,
+      `prefix_metrics` keys are used to determine the numbers of tokens in prefixes.
+
+    Args:
+        line: Current example, expected to include keys `Prediction` and `Target`.
+        full_metrics: A class for calculating metrics between full prediction and full target.
+        prefix_metrics: A dictionary where key `i` corresponds to metrics for prefixes of `i` tokens.
+        include_short: False to only consider messages with >= i tokens when computing metrics for prefixes of i tokens,
+          True to include all messages.
+    """
     prediction = line["Prediction"].strip()
     target = line["Target"].strip()

@@ -60,6 +80,8 @@ def add_single_example(
     target_tokens = target.split()

     for i in prefix_metrics:
+        if not include_short and len(target_tokens) < i:
+            break
         pred_prefix_i = " ".join(pred_tokens[:i])
         target_prefix_i = " ".join(target_tokens[:i])
         prefix_metrics[i].add_batch(predictions=[pred_prefix_i], references=[target_prefix_i])
@@ -80,8 +102,10 @@ def main(cfg: MetricsConfig):
             name=cfg.logger.artifact_config.name,
             config=OmegaConf.to_container(cfg, resolve=True),  # type: ignore[arg-type]
             job_type="metrics" if not cfg.filter.use_filtering else "filter_metrics",
+            tags=(["new_prefix_logic"] if cfg.include_short else [])
+            + (["only_filtered" if cfg.filter.fit_filters else "only_unfiltered"] if cfg.filter.use_filtering else []),
         )  # type: ignore[assignment]
-        cfg.preds_path = load_predictions(run=run, cfg=cfg)
+        cfg.preds_path = load_predictions(cfg)
     elif cfg.preds_path:
         cfg.preds_path = to_absolute_path(cfg.preds_path)
     else:
@@ -102,7 +126,9 @@ def main(cfg: MetricsConfig):
     if not cfg.filter.use_filtering:
         with jsonlines.open(cfg.preds_path, "r") as reader:
             for line in tqdm(reader, desc="Computing metrics"):
-                add_single_example(line, full_metrics=full_metrics, prefix_metrics=prefix_metrics)
+                add_single_example(
+                    line, full_metrics=full_metrics, prefix_metrics=prefix_metrics, include_short=cfg.include_short
+                )

     # or define filters configuration to control what subset will be considered
     else:
@@ -156,7 +182,12 @@ def include_example(filters_line: Dict[str, str]) -> bool:
                 and i in subset_ids
                 and include_example(filters_line)
             ):
-                add_single_example(input_line, full_metrics=full_metrics, prefix_metrics=prefix_metrics)
+                add_single_example(
+                    input_line,
+                    full_metrics=full_metrics,
+                    prefix_metrics=prefix_metrics,
+                    include_short=cfg.include_short,
+                )

     # -----------------------
     # -  compute results    -
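To illustrate the behavior that the new `include_short` flag toggles, here is a small self-contained sketch of the prefix loop; it uses plain token lists and a dictionary of prefix pairs instead of the repo's `EvaluationMetrics` class, which is not shown in this diff:

```python
from typing import Dict, List, Tuple


def collect_prefix_pairs(
    prediction: str, target: str, prefix_lengths: List[int], include_short: bool
) -> Dict[int, Tuple[str, str]]:
    """Return the (prediction prefix, target prefix) pair that would be scored for each prefix length."""
    pred_tokens = prediction.split()
    target_tokens = target.split()
    pairs = {}
    for i in prefix_lengths:
        # With include_short=False, a target shorter than i tokens is skipped for prefix length i
        # (and, since lengths are increasing, for all longer prefixes as well).
        if not include_short and len(target_tokens) < i:
            break
        pairs[i] = (" ".join(pred_tokens[:i]), " ".join(target_tokens[:i]))
    return pairs


# Example: a 3-token target contributes to prefix lengths 1-3 only when include_short=False,
# but to every requested length when include_short=True.
print(collect_prefix_pairs("fix typo in readme", "fix typo everywhere", [1, 2, 3, 5], include_short=False))
print(collect_prefix_pairs("fix typo in readme", "fix typo everywhere", [1, 2, 3, 5], include_short=True))
```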

conf/metrics_config.py (+3)
@@ -83,10 +83,13 @@ class MetricsConfig:

     Attributes:
         preds_path: Local path to model predictions. Instead of this, you can also define configuration for loading artifact at WandbMetricConfig.
+        include_short: False to only consider messages with >= i tokens when computing metrics for prefixes of i tokens,
+          True to include all messages.
         max_n_tokens: Maximum number of tokens (for prefix-level metrics).
     """

     preds_path: Optional[str] = None
+    include_short: bool = False
     max_n_tokens: int = 15
     filter: FilterConfig = field(default_factory=FilterConfig)
     logger: WandbMetricConfig = field(default_factory=WandbMetricConfig)
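A quick sketch of how the new field behaves under Hydra/OmegaConf overrides; the stripped-down dataclass below only mirrors the two scalar fields and is not the repo's full `MetricsConfig`:

```python
from dataclasses import dataclass

from omegaconf import OmegaConf


@dataclass
class MetricsConfigSketch:
    include_short: bool = False  # flag introduced in this commit
    max_n_tokens: int = 15


# A CLI override such as `++include_short=true` boils down to merging a dotlist on top of the defaults.
cfg = OmegaConf.merge(OmegaConf.structured(MetricsConfigSketch), OmegaConf.from_dotlist(["include_short=true"]))
print(cfg.include_short)  # True
```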

retrieve.py (+5, -1)
@@ -42,7 +42,11 @@ def download_artifact(cfg: RetrievalConfig, run: wandb.wandb_sdk.wandb_run.Run,


 def export_model_checkpoint(cfg: RetrievalConfig) -> str:
-    """Helper function to export model weights in Transformers format from Lightning checkpoint."""
+    """Helper function to export model weights in a Transformers format from Lightning checkpoint.
+
+    Returns:
+        A local path to directory with checkpoint in a Transformers format.
+    """
     logging.info(f"Checkpoint path: {cfg.ckpt_path}")

     module = CMCModule.load_from_checkpoint(
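For context, exporting weights "in a Transformers format" from a Lightning checkpoint typically amounts to something like the sketch below; the `.model` and `.tokenizer` attributes on the Lightning module are assumptions made for illustration, not necessarily how `CMCModule` stores its underlying objects:

```python
import os

from pytorch_lightning import LightningModule


def export_to_transformers_dir(module: LightningModule, output_dir: str) -> str:
    """Save the wrapped Hugging Face model (and, if present, its tokenizer) as a reloadable directory."""
    os.makedirs(output_dir, exist_ok=True)
    module.model.save_pretrained(output_dir)  # assumes the module keeps a `transformers` model in `.model`
    if hasattr(module, "tokenizer"):
        module.tokenizer.save_pretrained(output_dir)
    return output_dir


# Usage sketch: restore the Lightning checkpoint first, then export.
# module = CMCModule.load_from_checkpoint("path/to/checkpoint.ckpt")
# export_to_transformers_dir(module, "exported_codet5")
```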

train.py (+9)
@@ -21,6 +21,15 @@


 def get_world_size(accelerator: str, devices: Any) -> int:
+    """Determines world size for all possible ways of defining number of devices in Lightning.
+
+    Args:
+        accelerator: Argument for `pytorch_lightning.trainer`, corresponds to a device type.
+        devices: Argument for `pytorch_lightning.trainer`, corresponds to a number of devices/specific devices to use.
+
+    Returns:
+        World size.
+    """
     if accelerator == "cpu":
         return 1
     elif accelerator == "gpu":
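The body of `get_world_size` is only partially visible in this hunk. As a rough guide, resolving Lightning's `devices` argument to a world size usually looks something like the simplified sketch below; this is an assumption for illustration, not the repo's exact logic:

```python
from typing import Any

import torch


def get_world_size_sketch(accelerator: str, devices: Any) -> int:
    if accelerator == "cpu":
        return 1
    if accelerator == "gpu":
        if devices in ("auto", -1, "-1"):
            return torch.cuda.device_count()  # use every visible GPU
        if isinstance(devices, int):
            return devices  # e.g. devices=4 -> 4 processes
        if isinstance(devices, (list, tuple)):
            return len(devices)  # e.g. devices=[0, 1] -> 2 processes
        if isinstance(devices, str):
            # e.g. "2" -> 2 processes, "0,1" -> the two listed GPUs
            return int(devices) if devices.isdigit() else len([d for d in devices.split(",") if d.strip()])
    raise ValueError(f"Unsupported accelerator/devices combination: {accelerator}, {devices}")
```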
