from mlagents.trainers.exception import UnityTrainerException


- class MultiHeadAttention(torch.nn.Module):
+ def get_zero_entities_mask(observations: List[torch.Tensor]) -> List[torch.Tensor]:
    """
-     Multi Head Attention module. We do not use the regular Torch implementation since
-     Barracuda does not support some operators it uses.
-     Takes as input to the forward method 3 tensors:
-     - query: of dimensions (batch_size, number_of_queries, embedding_size)
-     - key: of dimensions (batch_size, number_of_keys, embedding_size)
-     - value: of dimensions (batch_size, number_of_keys, embedding_size)
-     The forward method will return 2 tensors:
-     - The output: (batch_size, number_of_queries, embedding_size)
-     - The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
+     Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
+     all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
+     layer to mask the padding observations.
    """
+     with torch.no_grad():
+         # Generate the masking tensors for each entities tensor (mask only if all zeros)
+         key_masks: List[torch.Tensor] = [
+             (torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations
+         ]
+     return key_masks
+
+
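For context on what the new helper returns, here is a minimal sketch (the tensor values and sizes are made up, and the function is assumed to be imported from this module):

import torch

# One observation tensor: batch of 1, 3 entity slots, 4 features each.
# The second slot is all zeros, i.e. a padded entity.
obs = [
    torch.tensor(
        [[[1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 0.0, 0.0], [5.0, 6.0, 7.0, 8.0]]]
    )
]
masks = get_zero_entities_mask(obs)
print(masks[0])  # tensor([[0., 1., 0.]]) -> 1 marks the padded (all-zero) entity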
+ class MultiHeadAttention(torch.nn.Module):

    NEG_INF = -1e6

    def __init__(self, embedding_size: int, num_heads: int):
+         """
+         Multi Head Attention module. We do not use the regular Torch implementation since
+         Barracuda does not support some operators it uses.
+         Takes as input to the forward method 3 tensors:
+         - query: of dimensions (batch_size, number_of_queries, embedding_size)
+         - key: of dimensions (batch_size, number_of_keys, embedding_size)
+         - value: of dimensions (batch_size, number_of_keys, embedding_size)
+         The forward method will return 2 tensors:
+         - The output: (batch_size, number_of_queries, embedding_size)
+         - The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
+         :param embedding_size: The size of the embeddings that will be generated (should be
+             divisible by num_heads)
+         :param total_max_elements: The maximum total number of entities that can be passed to
+             the module
+         :param num_heads: The number of heads of the attention module
+         """
        super().__init__()
        self.n_heads = num_heads
        self.head_size: int = embedding_size // self.n_heads
@@ -82,7 +101,7 @@ def forward(
        return value_attention, att
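To make the shape contract described in the docstring above concrete, here is a small sketch (sizes are illustrative; only the constructor shown in this hunk and the documented tensor shapes are assumed, since the full forward signature is outside this diff):

import torch

embedding_size, num_heads = 64, 2
mha = MultiHeadAttention(embedding_size=embedding_size, num_heads=num_heads)
print(mha.head_size)  # 32, i.e. embedding_size // num_heads

batch_size, n_queries, n_keys = 8, 5, 7
query = torch.rand(batch_size, n_queries, embedding_size)   # (b, n_q, emb)
key = torch.rand(batch_size, n_keys, embedding_size)        # (b, n_k, emb)
value = torch.rand(batch_size, n_keys, embedding_size)      # (b, n_k, emb)
# Per the docstring, the forward pass maps these to an output of shape
# (batch_size, n_queries, embedding_size) and an attention matrix of shape
# (batch_size, num_heads, n_queries, n_keys).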


- class EntityEmbeddings(torch.nn.Module):
+ class EntityEmbedding(torch.nn.Module):
    """
    A module used to embed entities before passing them to a self-attention block.
    Used in conjunction with ResidualSelfAttention to encode information about a self
@@ -92,95 +111,69 @@ class EntityEmbeddings(torch.nn.Module):

    def __init__(
        self,
-         x_self_size: int,
-         entity_sizes: List[int],
+         entity_size: int,
+         entity_num_max_elements: Optional[int],
        embedding_size: int,
-         entity_num_max_elements: Optional[List[int]] = None,
-         concat_self: bool = True,
    ):
        """
-         Constructs an EntityEmbeddings module.
+         Constructs an EntityEmbedding module.
        :param x_self_size: Size of "self" entity.
-         :param entity_sizes: List of sizes for other entities. Should be of length
-             equivalent to the number of entities.
-         :param embedding_size: Embedding size for entity encoders.
-         :param entity_num_max_elements: Maximum elements in an entity, None for unrestricted.
+         :param entity_size: Size of other entities.
+         :param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted.
            Needs to be assigned in order for model to be exportable to ONNX and Barracuda.
-         :param concat_self: Whether to concatenate x_self to entites. Set True for ego-centric
+         :param embedding_size: Embedding size for the entity encoder.
+         :param concat_self: Whether to concatenate x_self to entities. Set True for ego-centric
            self-attention.
        """
        super().__init__()
-         self.self_size: int = x_self_size
-         self.entity_sizes: List[int] = entity_sizes
-         self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes)
+         self.self_size: int = 0
+         self.entity_size: int = entity_size
+         self.entity_num_max_elements: int = -1
        if entity_num_max_elements is not None:
            self.entity_num_max_elements = entity_num_max_elements
-
-         self.concat_self: bool = concat_self
-         # If not concatenating self, input to encoder is just entity size
-         if not concat_self:
-             self.self_size = 0
+         self.embedding_size = embedding_size
        # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
-         self.ent_encoders = torch.nn.ModuleList(
-             [
-                 LinearEncoder(
-                     self.self_size + ent_size,
-                     1,
-                     embedding_size,
-                     kernel_init=Initialization.Normal,
-                     kernel_gain=(0.125 / embedding_size) ** 0.5,
-                 )
-                 for ent_size in self.entity_sizes
-             ]
+         self.self_ent_encoder = LinearEncoder(
+             self.entity_size,
+             1,
+             self.embedding_size,
+             kernel_init=Initialization.Normal,
+             kernel_gain=(0.125 / self.embedding_size) ** 0.5,
        )
-         self.embedding_norm = LayerNorm()

-     def forward(
-         self, x_self: torch.Tensor, entities: List[torch.Tensor]
-     ) -> Tuple[torch.Tensor, int]:
-         if self.concat_self:
-             # Concatenate all observations with self
-             self_and_ent: List[torch.Tensor] = []
-             for num_entities, ent in zip(self.entity_num_max_elements, entities):
-                 if num_entities < 0:
-                     if exporting_to_onnx.is_exporting():
-                         raise UnityTrainerException(
-                             "Trying to export an attention mechanism that doesn't have a set max \
-                             number of elements."
-                         )
-                     num_entities = ent.shape[1]
-                 expanded_self = x_self.reshape(-1, 1, self.self_size)
-                 expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
-                 self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
-         else:
-             self_and_ent = entities
-         # Encode and concatenate entites
-         encoded_entities = torch.cat(
-             [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
-             dim=1,
+     def add_self_embedding(self, size: int) -> None:
+         self.self_size = size
+         self.self_ent_encoder = LinearEncoder(
+             self.self_size + self.entity_size,
+             1,
+             self.embedding_size,
+             kernel_init=Initialization.Normal,
+             kernel_gain=(0.125 / self.embedding_size) ** 0.5,
        )
-         encoded_entities = self.embedding_norm(encoded_entities)
-         return encoded_entities

-     @staticmethod
-     def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
-         """
-         Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
-         all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
-         layer to mask the padding observations.
-         """
-         with torch.no_grad():
-             # Generate the masking tensors for each entities tensor (mask only if all zeros)
-             key_masks: List[torch.Tensor] = [
-                 (torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations
-             ]
-         return key_masks
+     def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
+         if self.self_size > 0:
+             num_entities = self.entity_num_max_elements
+             if num_entities < 0:
+                 if exporting_to_onnx.is_exporting():
+                     raise UnityTrainerException(
+                         "Trying to export an attention mechanism that doesn't have a set max \
+                         number of elements."
+                     )
+                 num_entities = entities.shape[1]
+             expanded_self = x_self.reshape(-1, 1, self.self_size)
+             expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
+             # Concatenate all observations with self
+             entities = torch.cat([expanded_self, entities], dim=2)
+         # Encode entities
+         encoded_entities = self.self_ent_encoder(entities)
+         return encoded_entities


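Before the ResidualSelfAttention changes below, here is a usage sketch of the refactored EntityEmbedding (all sizes are illustrative; the expected output shape follows from the LinearEncoder construction above and is an inference, not something asserted in this diff):

import torch

# Up to 5 entities of size 6, each embedded into 32 dimensions.
embedding = EntityEmbedding(entity_size=6, entity_num_max_elements=5, embedding_size=32)
# Opt in to ego-centric attention: a 10-dim "self" vector gets concatenated to every entity.
embedding.add_self_embedding(10)

x_self = torch.rand(8, 10)      # (batch_size, self_size)
entities = torch.rand(8, 5, 6)  # (batch_size, num_entities, entity_size)
encoded = embedding(x_self, entities)
print(encoded.shape)            # expected: torch.Size([8, 5, 32])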
class ResidualSelfAttention(torch.nn.Module):
    """
    Residual self attention, inspired by https://arxiv.org/pdf/1909.07528.pdf. Can be used
-     with an EntityEmbeddings module, to apply multi head self attention to encode information
+     with an EntityEmbedding module, to apply multi head self attention to encode information
    about a "Self" and a list of relevant "Entities".
    """

@@ -189,7 +182,7 @@ class ResidualSelfAttention(torch.nn.Module):
    def __init__(
        self,
        embedding_size: int,
-         entity_num_max_elements: Optional[List[int]] = None,
+         entity_num_max_elements: Optional[int] = None,
        num_heads: int = 4,
    ):
        """
@@ -205,8 +198,7 @@ def __init__(
        super().__init__()
        self.max_num_ent: Optional[int] = None
        if entity_num_max_elements is not None:
-             _entity_num_max_elements = entity_num_max_elements
-             self.max_num_ent = sum(_entity_num_max_elements)
+             self.max_num_ent = entity_num_max_elements

        self.attention = MultiHeadAttention(
            num_heads=num_heads, embedding_size=embedding_size
@@ -237,11 +229,14 @@ def __init__(
            kernel_init=Initialization.Normal,
            kernel_gain=(0.125 / embedding_size) ** 0.5,
        )
+         self.embedding_norm = LayerNorm()
        self.residual_norm = LayerNorm()

    def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor:
        # Gather the maximum number of entities information
        mask = torch.cat(key_masks, dim=1)
+
+         inp = self.embedding_norm(inp)
        # Feed to self attention
        query = self.fc_q(inp)  # (b, n_q, emb)
        key = self.fc_k(inp)  # (b, n_k, emb)
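Taken together, the refactor wires up as follows. This is a sketch based only on the signatures visible in this diff (the sizes are made up, and the shape of the attended output depends on pooling done later in forward, which is not shown here):

import torch

embedding_size, max_entities, entity_size, self_size = 32, 5, 6, 10
embedding = EntityEmbedding(entity_size, max_entities, embedding_size)
embedding.add_self_embedding(self_size)
rsa = ResidualSelfAttention(embedding_size, entity_num_max_elements=max_entities, num_heads=4)

x_self = torch.rand(8, self_size)
entities = torch.rand(8, max_entities, entity_size)
entities[:, -1, :] = 0.0                    # pretend the last entity slot is padding
masks = get_zero_entities_mask([entities])  # 1.0 wherever an entity is all zeros
encoded = embedding(x_self, entities)       # (batch, max_entities, embedding_size)
attended = rsa(encoded, masks)              # masked self-attention over the entities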