train_actor.py
import ray
import time
import numpy as np

from queue import Queue

from mesh_transformer.util import head_print
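

# NetworkRunner is a Ray actor pinned to one TPU host (resources={"tpu": 1}).
# The driver sends (operation, payload) tuples through input_q and blocks on
# output_q for the reply, so all JAX/haiku state stays inside the process
# that owns the TPU.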
@ray.remote(resources={"tpu": 1})
class NetworkRunner(object):
    def __init__(self, mesh_shape, network_builder):
        self.mesh_shape = mesh_shape
        self.network_builder = network_builder

        self.input_q = Queue(maxsize=1)
        self.output_q = Queue(maxsize=1)
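        # maxsize=1 gives request/response semantics: each proxy method below
        # enqueues exactly one request and waits for its single reply, so the
        # run() loop handles one operation at a time.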

    def run(self):
        print("jax runtime initialization starting")
        import jax
        from jax.experimental.maps import thread_resources, ResourceEnv, Mesh
        import haiku as hk

        # jax.experimental.maps.EXPERIMENTAL_SPMD_LOWERING = True

        # Start from a clean global resource environment (an empty mesh) before
        # any devices are attached; the real ('dp', 'mp') mesh is installed below.
        thread_resources.env = ResourceEnv(Mesh(np.empty((), dtype=object), ()), ())

        start = time.time()
        jax.devices()

        import warnings
        warnings.filterwarnings("ignore")
        warnings.filterwarnings("ignore", category=ResourceWarning)

        # Only the head host keeps its warnings visible; worker hosts stay quiet.
        if jax.host_id() == 0:
            warnings.filterwarnings("default")

        head_print(f"jax devices: {jax.device_count()}")
        head_print(f"jax runtime initialized in {time.time() - start:.06}s")
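
        # Arrange every device into a 2D mesh with axes 'dp' (data parallel)
        # and 'mp' (model parallel), as requested by mesh_shape.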
        devices = np.array(jax.devices()).reshape(self.mesh_shape)

        with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
            start = time.time()
            network = self.network_builder()
            head_print(f"Initialized in {time.time() - start:.06}s")
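
            # Serve forever: each request is an (operation, payload) tuple put
            # on input_q by one of the proxy methods below; the matching result
            # is returned on output_q.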
            while True:
                operation, input = self.input_q.get()

                if operation == "train":
                    self.output_q.put(network.train(input))
                elif operation == "eval":
                    self.output_q.put(network.eval(input))
                elif operation == "generate":
                    self.output_q.put(network.generate(*input))
                elif operation == "write_ckpt":
                    path, shard = input
                    network.write_ckpt(path, shard)
                    self.output_q.put(None)
                elif operation == "load_ckpt":
                    network.load_ckpt(input)
                    self.output_q.put(network.state["step"][0])
                elif operation == "get_params":
                    self.output_q.put(hk.data_structures.tree_size(network.state['params']))
                elif operation == "move_params":
                    # only needed for inference, otherwise first train step does this
                    local_shards = max(jax.local_device_count() // self.mesh_shape[1], 1)

                    # delete the optimizer states otherwise it OOMs for some reason
                    # TODO: use ShardedDeviceArray or something to get around this for bigger models
                    del network.state["opt_state"]

                    network.state = network.move_xmap(network.state, np.zeros(local_shards))
                    self.output_q.put(None)
                else:
                    raise Exception("Not implemented")

    def get_params(self):
        self.input_q.put(("get_params", None))
        return self.output_q.get()

    def train(self, sample):
        self.input_q.put(("train", sample))
        return self.output_q.get()

    def eval(self, sample):
        self.input_q.put(("eval", sample))
        return self.output_q.get()

    def generate(self, input):
        self.input_q.put(("generate", input))
        return self.output_q.get()

    def write_ckpt(self, path, shard):
        self.input_q.put(("write_ckpt", (path, shard)))
        return self.output_q.get()

    def load_ckpt(self, path):
        self.input_q.put(("load_ckpt", path))
        return self.output_q.get()

    def move_params(self):
        self.input_q.put(("move_params", None))
        return self.output_q.get()
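

# Usage sketch (added; not part of the original file). A hypothetical driver,
# assuming a `build_network` callable that returns the network object and a
# `batch` from your data loader; the actor needs max_concurrency > 1 so run()
# can loop forever on one thread while the proxy methods are served on another:
#
#   ray.init()
#   runner = NetworkRunner.options(max_concurrency=2).remote((1, 8), build_network)
#   runner.run.remote()                          # start the dispatch loop
#   n_params = ray.get(runner.get_params.remote())
#   loss = ray.get(runner.train.remote(batch))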