-
Notifications
You must be signed in to change notification settings - Fork 316
/
Copy pathpretrain_xxl.gin
32 lines (23 loc) · 992 Bytes
/
pretrain_xxl.gin
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from __gin__ import dynamic_registration
from t5x import partitioning
# Model (has to be imported first so that optimizer and vocab can be overridden)
include "t5x/examples/scalable_t5/mt5/xxl.gin"
# Architecture-specific configs
include "t5x/examples/scalable_t5/umt5/architectures/encoder_decoder.gin"
# Run mode
include "t5x/examples/scalable_t5/umt5/runs/pretraining_common.gin"
# Optimizer
include "t5x/examples/scalable_t5/umt5/optimizer/adafactor_momentum_nofactor.gin"
# Vocabulary
include "t5x/examples/scalable_t5/umt5/vocab.gin"
# Partitioning
# Block-style binding: the indented line below belongs to the
# `partitioning.PjitPartitioner:` header and is equivalent to
# `partitioning.PjitPartitioner.model_parallel_submesh = (1, 1, 8, 1)`.
# The indentation is required for the binding to be read as part of the block.
# NOTE(review): the 4-tuple requests 8-way model parallelism along the third
# mesh axis — confirm the axis meaning for the target hardware in t5x.partitioning.
partitioning.PjitPartitioner:
  model_parallel_submesh = (1, 1, 8, 1)
# Task configurations
# The two mixture/task names are marked %gin.REQUIRED: this file gives them no
# value, so they must be bound elsewhere (e.g. on the command line or in a
# config that includes this one).
MIXTURE_OR_TASK_NAME = %gin.REQUIRED
TRAIN_EVAL_MIXTURE_OR_TASK_NAME = %gin.REQUIRED
# Per-feature sequence lengths: 1024 for "inputs", 229 for "targets".
TASK_FEATURE_LENGTHS = {"inputs": 1024, "targets": 229}
USE_CACHED_TASKS = True
TRAIN_STEPS = 1_000_000
# Logical axis rules (partitioning settings, not task configuration).
# NOTE(review): presumably selects 1-D activation sharding and 2-D parameter
# sharding — confirm against t5x.partitioning.standard_logical_axis_rules.
partitioning.standard_logical_axis_rules.activation_partitioning_dims = 1
partitioning.standard_logical_axis_rules.parameter_partitioning_dims = 2