"""Generic actor implementation, using TensorFlow and Sonnet."""

from typing import Optional

from acme import adders
from acme import core
from acme import types
from acme.tf import utils as tf2_utils
from acme.tf import variable_utils as tf2_variable_utils
import dm_env
import numpy as np
import sonnet as snt
import tensorflow as tf
import tensorflow_probability as tfp

from Log.logger import myapp

tfd = tfp.distributions


class FeedForwardActor(core.Actor):
  """A feed-forward actor.

  An actor based on a feed-forward policy which takes non-batched observations
  and outputs non-batched actions. It also allows adding experiences to replay
  and updating the weights from the policy on the learner.
  """

  def __init__(
      self,
      policy_networks: snt.Module,
      edge_number: int,
      edge_action_size: int,
      adder: Optional[adders.Adder] = None,
      variable_client: Optional[tf2_variable_utils.VariableClient] = None,
  ):
    """Initializes the actor.

    Args:
      policy_networks: A module which takes a (batched) edge observation and
        outputs the corresponding edge action, or a distribution over actions.
      edge_number: number of edges; the policy network is applied once per
        edge observation.
      edge_action_size: size of the action vector produced for each edge.
      adder: the adder object which allows adding experiences to a
        dataset/replay buffer.
      variable_client: object which allows copying weights from the learner
        copy of the policy to the actor copy (in case they are separate).
    """
    # Store these for later use.
    self._adder = adder
    self._variable_client = variable_client
    self._policy_networks = policy_networks

    self._edge_number = edge_number
    self._edge_action_size = edge_action_size

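  # NOTE: experimental_relax_shapes=True lets tf.function relax traced input
  # shapes so the graph is reused rather than retraced when observation shapes
  # vary between calls (newer TensorFlow releases spell this reduce_retracing).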
  @tf.function(experimental_relax_shapes=True)
  def _policy(
      self,
      observations: types.NestedTensor,
  ) -> types.NestedTensor:
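    # The observation is treated as a matrix with one row per edge: the same
    # policy network is applied to each row independently, and the per-edge
    # actions are stacked into a single [edge_number, edge_action_size] tensor.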
    # myapp.debug(f"observations: {np.array(observations)}")
    edge_actions = []
    for i in range(self._edge_number):
      # myapp.debug(f"i: {i}")
      edge_observation = observations[i, :]
      # myapp.debug(f"edge_observation: {np.array(edge_observation)}")
      # Add a dummy batch dimension and as a side effect convert numpy to TF.
      edge_batched_observation = tf2_utils.add_batch_dim(edge_observation)
      # myapp.debug(f"edge_batched_observation: {edge_batched_observation}")
      # Compute the policy, conditioned on the edge observation; sample if the
      # network returns a distribution.
      edge_policy = self._policy_networks(edge_batched_observation)
      edge_action = (
          edge_policy.sample()
          if isinstance(edge_policy, tfd.Distribution) else edge_policy)
      # myapp.debug(f"edge_action: {edge_action}")
      edge_actions.append(edge_action)

    # Stack the per-edge actions and drop the dummy batch dimension.
    edge_actions = tf.convert_to_tensor(edge_actions, dtype=tf.float64)
    # myapp.debug(f"edge_actions: {edge_actions}")
    action = tf.reshape(
        edge_actions, [self._edge_number, self._edge_action_size])
    # myapp.debug(f"action: {action}")
    return action

  def select_action(self, observation: types.NestedArray) -> types.NestedArray:
    # Pass the observation through the policy network.
    action = self._policy(
        observations=tf.convert_to_tensor(observation, dtype=tf.float64))
    # Return the [edge_number, edge_action_size] tensor of per-edge actions.
    return action

  def observe_first(self, timestep: dm_env.TimeStep):
    if self._adder:
      self._adder.add_first(timestep)

  def observe(self, action: types.NestedArray, next_timestep: dm_env.TimeStep):
    if self._adder:
      self._adder.add(action, next_timestep)

  def update(self, wait: bool = False):
    if self._variable_client:
      self._variable_client.update(wait)
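

if __name__ == "__main__":
  # Minimal usage sketch (illustrative only, not part of the actor API): the
  # observation size and the two-layer Sonnet policy below are assumptions
  # made purely for this example.
  edge_number, edge_action_size, obs_size = 3, 2, 4

  # A small feed-forward policy applied to one edge observation at a time.
  toy_policy = snt.Sequential([
      snt.Linear(32), tf.nn.relu,
      snt.Linear(edge_action_size), tf.tanh,
  ])

  actor = FeedForwardActor(
      policy_networks=toy_policy,
      edge_number=edge_number,
      edge_action_size=edge_action_size,
  )

  # One observation row per edge; np.random.rand yields float64, which matches
  # the float64 cast inside select_action.
  observation = np.random.rand(edge_number, obs_size)
  action = actor.select_action(observation)
  print("sampled action shape:", action.shape)  # [edge_number, edge_action_size]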