
Commit 153684e: "update before add batch"
Parent: 28e956f

48 files changed (+4578, -446 lines)

Agents/MAD5PG/actors_test.py (+4 -4)

@@ -2,9 +2,9 @@
 sys.path.append(r"/home/neardws/Documents/Game-Theoretic-Deep-Reinforcement-Learning/")
 from environment_loop import EnvironmentLoop
 from absl.testing import absltest
-from Agents.MAD4PG import actors
+from Agents.MAD5PG import actors
 from Environment.environment import make_environment_spec
-from Agents.MAD4PG.networks import make_policy_network
+from Agents.MAD5PG.networks import make_policy_network
 from Experiment.make_environment import get_default_environment
 
 class ActorTest(absltest.TestCase):
@@ -13,11 +13,11 @@ class ActorTest(absltest.TestCase):
     def test_feedforward(self):
 
         time_slots, task_list, vehicle_list, edge_list, distance_matrix, channel_condition_matrix, \
-            vehicle_index_within_edges, environment_config, environment = get_default_environment()
+            vehicle_index_within_edges, environment_config, environment = get_default_environment(for_mad5pg=True)
 
         env_spec = make_environment_spec(environment)
 
-        policy_networks = [make_policy_network(env_spec.edge_actions) for _ in range(environment_config.edge_number)]
+        policy_networks = make_policy_network(env_spec.edge_actions)
 
         actor = actors.FeedForwardActor(
             policy_networks=policy_networks,

Agents/MAD5PG/agent.py (+20 -16)

@@ -4,7 +4,7 @@
 
 import copy
 import dataclasses
-from typing import Callable, Iterator, List, Optional, Union, Sequence
+from typing import Iterator, List, Optional, Union, Sequence
 import acme
 from acme import adders
 from acme import core
@@ -23,7 +23,7 @@
 import sonnet as snt
 import launchpad as lp
 import functools
-import dm_env
+from Utilities.FileOperator import load_obj
 from Agents.MAD5PG.networks import make_default_MAD3PGNetworks, MAD3PGNetwork
 from environment_loop import EnvironmentLoop
 
@@ -57,17 +57,17 @@ class MAD3PGConfig:
         accelerator: 'TPU', 'GPU', or 'CPU'. If omitted, the first available accelerator type from ['TPU', 'GPU', 'CPU'] will be selected.
     """
     discount: float = 0.996
-    batch_size: int = 512
+    batch_size: int = 256
     prefetch_size: int = 4
     target_update_period: int = 100
-    variable_update_period: int = 1000
+    variable_update_period: int = 500
     policy_optimizers: Optional[snt.Optimizer] = None
     critic_optimizers: Optional[snt.Optimizer] = None
     min_replay_size: int = 1000
    max_replay_size: int = 1000000
-    samples_per_insert: Optional[float] = 1.0
-    n_step: int = 1
-    sigma: float = 0.3
+    samples_per_insert: Optional[float] = 32.0
+    n_step: int = 5
+    sigma: float = 0.5
     clipping: bool = True
     replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE
     counter: Optional[counting.Counter] = None
@@ -104,6 +104,7 @@ def __init__(
         if networks is None:
             online_networks = make_default_MAD3PGNetworks(
                 action_spec=environment_spec.edge_actions,
+                sigma=self._config.sigma,
             )
         else:
             online_networks = networks
@@ -118,7 +119,7 @@ def __init__(
         target_networks.init(self._environment_spec)
 
         # Create the behavior policy.
-        policy_networks = online_networks.make_policy(self._environment_spec, self._config.sigma)
+        policy_networks = online_networks.make_policy()
 
         # Create the replay server and grab its address.
         replay_tables = self.make_replay_tables(self._environment_spec)
@@ -289,7 +290,7 @@ class MultiAgentDistributedDDPG:
     def __init__(
         self,
         config: MAD3PGConfig,
-        environment_factory: Callable[[bool], dm_env.Environment],
+        environment_file_name: str,
         environment_spec,
         networks: Optional[MAD3PGNetwork] = None,
         num_actors: int = 1,
@@ -311,11 +312,14 @@ def __init__(
         self._log_every = log_every
         self._networks = networks
         self._environment_spec = environment_spec
-        self._environment_factory = environment_factory
+        self._environment_file_name = environment_file_name
         # Create the agent.
+
+        environment = load_obj(environment_file_name)
+
         self._agent = MAD3PGAgent(
             config=self._config,
-            environment=self._environment_factory(False),
+            environment=environment,
             environment_spec=self._environment_spec,
             networks=self._networks,
         )
@@ -379,10 +383,10 @@ def actor(
 
         networks.init(self._environment_spec)
 
-        policy_networks = networks.make_policy(environment_spec=self._environment_spec, sigma=self._config.sigma)
+        policy_networks = networks.make_policy()
 
         # Create the environment
-        environment = self._environment_factory(False)
+        environment = load_obj(self._environment_file_name)
 
         # Create the agent.
         actor = self._agent.make_actor(
@@ -395,7 +399,7 @@ def actor(
         counter = counting.Counter(counter, 'actor')
         logger = loggers.make_default_logger(
             'actor',
-            save_data=False,
+            save_data=True,
             time_delta=self._log_every,
             steps_key='actor_steps')
 
@@ -420,10 +424,10 @@ def evaluator(
         networks = self._networks
         networks.init(self._environment_spec)
 
-        policy_networks = networks.make_policy(self._environment_spec)
+        policy_networks = networks.make_policy()
 
         # Make the environment
-        environment = self._environment_factory(True)
+        environment = load_obj(self._environment_file_name)
 
         # Create the agent.
         actor = self._agent.make_actor(
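
Taken together, the agent.py changes mean a caller now builds the distributed agent from a serialized environment file (loaded with load_obj) instead of an environment factory, and the exploration sigma is baked into the networks at construction time rather than passed to make_policy(). Below is a minimal usage sketch under those assumptions; the pickle path is hypothetical and load_obj is only assumed to unpickle a previously saved environment object.

# Sketch only: illustrative path and values, not part of the commit.
from Utilities.FileOperator import load_obj
from Environment.environment import make_environment_spec
from Agents.MAD5PG.agent import MultiAgentDistributedDDPG, MAD3PGConfig

environment_file_name = "/tmp/environment.pkl"   # hypothetical path
environment = load_obj(environment_file_name)    # assumed to unpickle a saved environment
environment_spec = make_environment_spec(environment)

config = MAD3PGConfig(
    batch_size=256,              # new default (was 512)
    variable_update_period=500,  # new default (was 1000)
    samples_per_insert=32.0,     # new default (was 1.0)
    n_step=5,                    # new default (was 1)
    sigma=0.5,                   # now forwarded to make_default_MAD3PGNetworks
)

agent = MultiAgentDistributedDDPG(
    config=config,
    environment_file_name=environment_file_name,  # replaces environment_factory
    environment_spec=environment_spec,
    num_actors=1,
)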

Agents/MAD5PG/agent_test.py (+3 -3)

@@ -6,8 +6,8 @@
 import launchpad as lp
 from absl.testing import absltest
 from Environment.environment import vehicularNetworkEnv, make_environment_spec
-from Agents.MAD4PG.networks import make_default_MAD3PGNetworks
-from Agents.MAD4PG.agent import MultiAgentDistributedDDPG, MAD3PGConfig
+from Agents.MAD5PG.networks import make_default_MAD3PGNetworks
+from Agents.MAD5PG.agent import MultiAgentDistributedDDPG, MAD3PGConfig
 from Experiment.make_environment import get_default_environment
 
@@ -18,7 +18,7 @@ def test_control_suite(self):
         """Tests that the agent can run on the control suite without crashing."""
 
         time_slots, task_list, vehicle_list, edge_list, distance_matrix, channel_condition_matrix, \
-            vehicle_index_within_edges, environment_config, environment = get_default_environment()
+            vehicle_index_within_edges, environment_config, environment = get_default_environment(for_mad5pg=True)
 
         spec = make_environment_spec(environment)
 

Agents/MAD5PG/learning.py (+38 -20)

@@ -154,11 +154,11 @@ def __init__(
             'critic_optimizer': self._critic_optimizers,
             'num_steps': self._num_steps,
         })
-        object_to_save = dict()
-        object_to_save['policy'] = self._policy_networks
-        object_to_save['critic_mean'] = snt.Sequential([self._critic_networks, acme_nets.StochasticMeanHead()])
-        self._snapshotter = tf2_savers.Snapshotter(
-            objects_to_save=object_to_save)
+        # object_to_save = dict()
+        # object_to_save['policy'] = self._policy_networks
+        # object_to_save['critic_mean'] = snt.Sequential([self._critic_networks, acme_nets.StochasticMeanHead()])
+        # self._snapshotter = tf2_savers.Snapshotter(
+        #     objects_to_save=object_to_save)
 
         # Do not record timestamps until after the first learning step is done.
         # This is to avoid including the time it takes for actors to come online and
@@ -195,7 +195,6 @@ def _step(self, sample) -> Dict[str, tf.Tensor]:
             # a_t_list.append(a_t)
 
             # edge_a_t = tf.concat([a_t_list[i] for i in range(len(self._target_observation_networks))], axis=1)
-
         a_t_list = []
         for i in range(self._edge_number):
             observation = transitions.next_observation[:, i, :]
@@ -205,7 +204,7 @@ def _step(self, sample) -> Dict[str, tf.Tensor]:
             a_t_list.append(a_t)
 
         edge_next_a_t = tf.concat([a_t_list[i] for i in range(self._edge_number)], axis=1)
-
+        edge_next_a_t = tf.reshape(edge_next_a_t, [batch_size, self._edge_number, self._edge_action_size])
 
         for edge_index in range(self._edge_number):
 
@@ -218,8 +217,20 @@ def _step(self, sample) -> Dict[str, tf.Tensor]:
             o_t = tree.map_structure(tf.stop_gradient, o_t)
 
             # Critic learning.
-            q_tm1 = self._critic_networks(o_tm1, tf.reshape(transitions.action, shape=[batch_size, -1]))
-            q_t = self._target_critic_networks(o_t, tf.reshape(edge_next_a_t, shape=[batch_size, -1]))
+            critic_actions = tf2_utils.batch_concat([
+                transitions.action[:, : edge_index, :],
+                transitions.action[:, edge_index + 1 :, :],
+                transitions.action[:, edge_index, :],
+            ])
+            q_tm1 = self._critic_networks(o_tm1, tf.reshape(critic_actions, shape=[batch_size, -1]))
+
+
+            critic_actions = tf2_utils.batch_concat([
+                edge_next_a_t[:, : edge_index, :],
+                edge_next_a_t[:, edge_index + 1 :, :],
+                edge_next_a_t[:, edge_index, :],
+            ])
+            q_t = self._target_critic_networks(o_t, tf.reshape(critic_actions, shape=[batch_size, -1]))
 
             # Critic loss.
             critic_loss = losses.categorical(q_tm1, transitions.reward[:, edge_index],
@@ -231,26 +242,33 @@ def _step(self, sample) -> Dict[str, tf.Tensor]:
             critic_losses.append(critic_loss)
 
             # Actor learning
-            if edge_index == 0:
-                dpg_a_t = self._policy_networks(o_t)
-            else:
-                dpg_a_t = tf.reshape(edge_next_a_t, shape=[batch_size, self._edge_number, self._edge_action_size])[:, 0, :]
-                for i in range(self._edge_number):
-                    if i != 0 and i != edge_index:
-                        dpg_a_t = tf.concat([dpg_a_t, tf.reshape(edge_next_a_t, shape=[batch_size, self._edge_number, self._edge_action_size])[:, i, :]], axis=1)
-                    elif i != 0 and i == edge_index:
-                        dpg_a_t = tf.concat([dpg_a_t, self._policy_networks(o_t)], axis=1)
+            policy_a_t = self._policy_networks(o_t)
+
+            dpg_a_t = tf2_utils.batch_concat([
+                edge_next_a_t[:, : edge_index, :],
+                edge_next_a_t[:, edge_index + 1 :, :],
+                policy_a_t,
+            ])
+
+            # if edge_index == 0:
+            #     dpg_a_t = policy_a_t
+            # else:
+            #     dpg_a_t = edge_next_a_t[:, 0, :]
+            #     for i in range(self._edge_number):
+            #         if i != 0 and i != edge_index:
+            #             dpg_a_t = tf.concat([dpg_a_t, edge_next_a_t[:, i, :]], axis=1)
+            #         elif i != 0 and i == edge_index:
+            #             dpg_a_t = tf.concat([dpg_a_t, policy_a_t], axis=1)
 
             dpg_z_t = self._critic_networks(o_t, dpg_a_t)
             dpg_q_t = dpg_z_t.mean()
-
             # Actor loss. If clipping is true use dqda clipping and clip the norm.
             dqda_clipping = 1.0 if self._clipping else None
             # myapp.debug(f"dpg_q_t: {np.array(dpg_q_t)}")
             # myapp.debug(f"dpg_a_t: {np.array(dpg_a_t)}")
             policy_loss = losses.dpg(
                 dpg_q_t,
-                dpg_a_t,
+                policy_a_t,
                 tape=tape,
                 dqda_clipping=dqda_clipping,
                 clip_norm=self._clipping)
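
The key change in the learning step is the per-edge action layout: when updating edge edge_index, the actions of all edges are concatenated with that edge's own action placed last, and losses.dpg now receives policy_a_t so gradients flow only through the updated edge's policy output. A standalone sketch of that reordering follows, with toy shapes chosen purely for illustration (batch_size, edge_number, and edge_action_size here are made up).

# Illustrative shapes only; not taken from the repository's configuration.
import tensorflow as tf
from acme.tf import utils as tf2_utils

batch_size, edge_number, edge_action_size = 4, 3, 2
actions = tf.random.uniform([batch_size, edge_number, edge_action_size])

edge_index = 1  # the edge whose critic/policy is being updated
reordered = tf2_utils.batch_concat([
    actions[:, :edge_index, :],       # actions of edges before edge_index
    actions[:, edge_index + 1:, :],   # actions of edges after edge_index
    actions[:, edge_index, :],        # this edge's own action, placed last
])
# batch_concat flattens each piece per batch element and concatenates them,
# so the result has shape [batch_size, edge_number * edge_action_size].
assert reordered.shape == (batch_size, edge_number * edge_action_size)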

Agents/MAD5PG/multiplexers.py (new file, +80)

@@ -0,0 +1,80 @@
+# python3
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Multiplexers are networks that take multiple inputs."""
+
+from typing import Callable, Optional, Union
+
+from acme import types
+from acme.tf import utils as tf2_utils
+
+import sonnet as snt
+import tensorflow as tf
+import tensorflow_probability as tfp
+
+tfd = tfp.distributions
+TensorTransformation = Union[snt.Module, Callable[[types.NestedTensor],
+                                                  tf.Tensor]]
+
+
+class CriticMultiplexer(snt.Module):
+  """Module connecting a critic torso to (transformed) observations/actions.
+
+  This takes as input a `critic_network`, an `observation_network`, and an
+  `action_network` and returns another network whose outputs are given by
+  `critic_network(observation_network(o), action_network(a))`.
+
+  The observations and actions passed to this module are assumed to have a
+  batch dimension that match.
+
+  Notes:
+  - Either the `observation_` or `action_network` can be `None`, in which case
+    the observation or action, resp., are passed to the critic network as is.
+  - If all `critic_`, `observation_` and `action_network` are `None`, this
+    module reduces to a simple `tf2_utils.batch_concat()`.
+  """
+
+  def __init__(self,
+               critic_network: Optional[TensorTransformation] = None,
+               observation_network: Optional[TensorTransformation] = None,
+               action_network: Optional[TensorTransformation] = None):
+    self._critic_network = critic_network
+    self._observation_network = observation_network
+    self._action_network = action_network
+    super().__init__(name='critic_multiplexer')
+
+  def __call__(self,
+               observation: types.NestedTensor,
+               action: types.NestedTensor) -> tf.Tensor:
+
+    # Maybe transform observations and actions before feeding them on.
+    if self._observation_network:
+      observation = self._observation_network(observation)
+    if self._action_network:
+      action = self._action_network(action)
+
+    if hasattr(observation, 'dtype') and hasattr(action, 'dtype'):
+      if observation.dtype != action.dtype:
+        # Observation and action must be the same type for concat to work
+        action = tf.cast(action, observation.dtype)
+
+    # Concat observations and actions, with one batch dimension.
+    outputs = tf2_utils.batch_concat([observation, action])
+
+    # Maybe transform output before returning.
+    if self._critic_network:
+      outputs = self._critic_network(outputs)
+
+    return outputs
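
The new module mirrors Acme's CriticMultiplexer (note the DeepMind license header) vendored into Agents/MAD5PG. A small usage sketch with arbitrary sizes, showing the two paths described in the docstring: pure concatenation, and feeding the concatenation through a critic torso.

# Illustrative only; tensor sizes and the MLP layout are arbitrary choices.
import sonnet as snt
import tensorflow as tf
from Agents.MAD5PG.multiplexers import CriticMultiplexer

observation = tf.ones([8, 10])  # [batch, obs_dim]
action = tf.ones([8, 3])        # [batch, act_dim]

plain = CriticMultiplexer()
print(plain(observation, action).shape)  # (8, 13): batch_concat of obs and action

with_torso = CriticMultiplexer(critic_network=snt.nets.MLP([64, 1]))
print(with_torso(observation, action).shape)  # (8, 1)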
