
Commit 7963f9c

[float8] add float8 training benchmarking scripts (#1802)
* add float8 training benchmarking scripts
* move to benchmarks/float8/training
1 parent 8f93751 commit 7963f9c

File tree

3 files changed, +122 -0 lines changed

benchmarks/float8/README.md

Lines changed: 18 additions & 0 deletions

# Float8 training benchmarking

The `float8_training_benchmark.sh` script in this directory can be used to launch a Llama3 8b training run with [torchtitan](https://github.com/pytorch/torchtitan), then parse the logs to calculate the median tokens/sec and peak memory usage for you.

## Usage

Example: `TORCHTITAN_ROOT=${HOME}/torchtitan FLOAT8_RECIPE=rowwise ./float8_training_benchmark.sh`

Training parameters can be configured via environment variables.

- Required:
  - `TORCHTITAN_ROOT`
- Optional:
  - `FLOAT8_RECIPE`: "rowwise" or "tensorwise". If set, float8 training with the given recipe is used; otherwise, bf16 mixed precision training is used.
  - `BATCH_SIZE`: defaults to 1.
  - `STEPS`: defaults to 100.

**NOTE**: `torch.compile` and FSDP2 are always used. Other forms of parallelism supported in torchtitan are not yet supported in this script.
float8_training_benchmark.sh

Lines changed: 47 additions & 0 deletions

#!/bin/bash
# This script can be used to launch a torchtitan float8 training run
# with the given parameters.

# script arguments
BATCH_SIZE=${BATCH_SIZE:-1}
STEPS=${STEPS:-100}

# temporary log file which is deleted after performance data is parsed out and metrics are calculated
LOG_FILE="/tmp/float8_training_log.txt"

# validate user has specified torchtitan root directory
if [ -z "${TORCHTITAN_ROOT}" ]; then
    echo "Error: TORCHTITAN_ROOT environment variable is not set. Please set it before running this script."
    echo "Usage: TORCHTITAN_ROOT=<directory> ./float8_training_benchmark.sh"
    echo "Optional parameters configurable via environment variables:"
    echo " * FLOAT8_RECIPE: \"rowwise\" or \"tensorwise\". If set, use float8 training with the specified recipe. Otherwise, use bf16 mixed precision training."
    echo " * BATCH_SIZE: defaults to 1."
    echo " * STEPS: defaults to 100."
    exit 1
fi

# build float8 args if a recipe was specified
if [ -n "${FLOAT8_RECIPE}" ]; then
    FLOAT8_ARGS="--model.converters=float8 --float8.recipe_name=${FLOAT8_RECIPE}"
fi

# remember current directory to return to it later
original_dir=$(pwd)

# navigate to torchtitan root dir
cd "${TORCHTITAN_ROOT}"

echo "float8 args: ${FLOAT8_ARGS}"

# run the training command with the specified arguments
CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" "${TORCHTITAN_ROOT}"/run_train.sh --training.steps="${STEPS}" --training.batch_size="${BATCH_SIZE}" --training.compile ${FLOAT8_ARGS} 2>&1 | tee "${LOG_FILE}"

# return to original working directory
cd "$original_dir"

# parse logs to calculate top-line metrics
python parse_torchtitan_logs.py --log-file "${LOG_FILE}"

# clean up logs
rm "${LOG_FILE}"
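The benchmark script above is configured entirely through environment variables (`TORCHTITAN_ROOT`, `FLOAT8_RECIPE`, `BATCH_SIZE`, `STEPS`). As a minimal sketch of driving it programmatically, a hypothetical helper (`build_benchmark_env`, not part of this commit) could assemble that environment before invoking the script via `subprocess.run`:

```python
import os


def build_benchmark_env(torchtitan_root, float8_recipe=None, batch_size=1, steps=100):
    """Assemble env vars for float8_training_benchmark.sh (hypothetical helper)."""
    env = dict(os.environ)
    env["TORCHTITAN_ROOT"] = torchtitan_root
    env["BATCH_SIZE"] = str(batch_size)
    env["STEPS"] = str(steps)
    if float8_recipe is not None:
        # "rowwise" or "tensorwise"; leaving it unset falls back to bf16 training
        env["FLOAT8_RECIPE"] = float8_recipe
    return env


# e.g. subprocess.run(["./float8_training_benchmark.sh"], env=build_benchmark_env(...))
```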
parse_torchtitan_logs.py

Lines changed: 57 additions & 0 deletions

#!/usr/bin/env python3
"""
Script which can be used to parse the log file generated by torchtitan
and calculate the training performance metrics (median tokens/second and
peak memory usage).

Usage:
    python parse_torchtitan_logs.py --log-file <log_file_path>
"""

import os
import re
import statistics
from argparse import ArgumentParser, Namespace


def main(args: Namespace):
    print("\n=====================================================")
    print(" Calculating training performance metrics")
    print("=====================================================")

    log_pattern = re.compile(r"step: (\d+).*?memory: ([\d.]+)GiB.*?tps: ([\d,]+)")

    assert os.path.exists(args.log_file), f"{args.log_file} does not exist"

    with open(args.log_file, "r") as f:
        log_data = f.read()

    matches = log_pattern.findall(log_data)

    tokens_per_second = []
    max_memory_usage = 0.0
    for match in matches:
        step = int(match[0])
        memory_usage = float(match[1])
        tps = float(match[2].replace(",", ""))

        # update peak memory usage
        max_memory_usage = max(max_memory_usage, memory_usage)

        # collect tokens per second, excluding step 1 which has initialization overhead
        if step != 1:
            tokens_per_second.append(tps)

    # calculate median tokens per second
    median_tps = statistics.median(tokens_per_second) if tokens_per_second else 0

    print(f"Median Tokens/Second (excluding step 1): {median_tps}")
    print(f"Max Memory Usage: {max_memory_usage} GiB")


if __name__ == "__main__":
    argparser = ArgumentParser()
    argparser.add_argument(
        "--log-file", type=str, required=True, help="torchtitan log file"
    )
    args = argparser.parse_args()
    main(args)
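To illustrate what the parser extracts, here is a small example applying the same `log_pattern` to hypothetical log lines (assumed for illustration; the real torchtitan log format may contain additional fields between the captured ones):

```python
import re
import statistics

# same pattern as in parse_torchtitan_logs.py
log_pattern = re.compile(r"step: (\d+).*?memory: ([\d.]+)GiB.*?tps: ([\d,]+)")

# hypothetical log lines, for illustration only
sample_log = "\n".join([
    "step: 1 memory: 47.52GiB tps: 1,200",
    "step: 2 memory: 47.60GiB tps: 3,400",
    "step: 3 memory: 47.58GiB tps: 3,600",
])

matches = log_pattern.findall(sample_log)
# skip step 1 (initialization overhead), mirroring the script above
tps = [float(t.replace(",", "")) for step, _, t in matches if int(step) != 1]
peak_mem = max(float(mem) for _, mem, _ in matches)

print(statistics.median(tps))  # median over steps 2 and 3
print(peak_mem)
```

Note that the comma-grouped `tps` values (e.g. `3,400`) must be de-grouped before conversion to `float`, which is why the script strips commas first.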
