
Commit 78e052b

Author: Ervin Teng (committed)
Use attention tests from master
1 parent 492fd17 commit 78e052b

File tree

2 files changed: +137 −95 lines changed


ml-agents/mlagents/trainers/tests/torch/test_attention.py (+59 −12)
@@ -1,12 +1,14 @@
+import pytest
 from mlagents.torch_utils import torch
 import numpy as np

 from mlagents.trainers.torch.utils import ModelUtils
 from mlagents.trainers.torch.layers import linear_layer, LinearEncoder
 from mlagents.trainers.torch.attention import (
     MultiHeadAttention,
-    EntityEmbeddings,
+    EntityEmbedding,
     ResidualSelfAttention,
+    get_zero_entities_mask,
 )

@@ -71,7 +73,7 @@ def generate_input_helper(pattern):
     input_1 = generate_input_helper(masking_pattern_1)
     input_2 = generate_input_helper(masking_pattern_2)

-    masks = EntityEmbeddings.get_masks([input_1, input_2])
+    masks = get_zero_entities_mask([input_1, input_2])
     assert len(masks) == 2
     masks_1 = masks[0]
     masks_2 = masks[1]
@@ -83,13 +85,60 @@ def generate_input_helper(pattern):
         assert masks_2[0, 1] == 0 if i % 2 == 0 else 1


+@pytest.mark.parametrize("mask_value", [0, 1])
+def test_all_masking(mask_value):
+    # We make sure that a mask of all zeros or all ones will not trigger an error
+    np.random.seed(1336)
+    torch.manual_seed(1336)
+    size, n_k, = 3, 5
+    embedding_size = 64
+    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
+    entity_embeddings.add_self_embedding(size)
+    transformer = ResidualSelfAttention(embedding_size, n_k)
+    l_layer = linear_layer(embedding_size, size)
+    optimizer = torch.optim.Adam(
+        list(entity_embeddings.parameters())
+        + list(transformer.parameters())
+        + list(l_layer.parameters()),
+        lr=0.001,
+        weight_decay=1e-6,
+    )
+    batch_size = 20
+    for _ in range(5):
+        center = torch.rand((batch_size, size))
+        key = torch.rand((batch_size, n_k, size))
+        with torch.no_grad():
+            # create the target : The key closest to the query in euclidean distance
+            distance = torch.sum(
+                (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
+            )
+            argmin = torch.argmin(distance, dim=1)
+            target = []
+            for i in range(batch_size):
+                target += [key[i, argmin[i], :]]
+            target = torch.stack(target, dim=0)
+            target = target.detach()
+
+        embeddings = entity_embeddings(center, key)
+        masks = [torch.ones_like(key[:, :, 0]) * mask_value]
+        prediction = transformer.forward(embeddings, masks)
+        prediction = l_layer(prediction)
+        prediction = prediction.reshape((batch_size, size))
+        error = torch.mean((prediction - target) ** 2, dim=1)
+        error = torch.mean(error) / 2
+        optimizer.zero_grad()
+        error.backward()
+        optimizer.step()
+
+
 def test_predict_closest_training():
     np.random.seed(1336)
     torch.manual_seed(1336)
     size, n_k, = 3, 5
     embedding_size = 64
-    entity_embeddings = EntityEmbeddings(size, [size], embedding_size, [n_k])
-    transformer = ResidualSelfAttention(embedding_size, [n_k])
+    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
+    entity_embeddings.add_self_embedding(size)
+    transformer = ResidualSelfAttention(embedding_size, n_k)
     l_layer = linear_layer(embedding_size, size)
     optimizer = torch.optim.Adam(
         list(entity_embeddings.parameters())
@@ -114,8 +163,8 @@ def test_predict_closest_training():
             target = torch.stack(target, dim=0)
             target = target.detach()

-        embeddings = entity_embeddings(center, [key])
-        masks = EntityEmbeddings.get_masks([key])
+        embeddings = entity_embeddings(center, key)
+        masks = get_zero_entities_mask([key])
         prediction = transformer.forward(embeddings, masks)
         prediction = l_layer(prediction)
         prediction = prediction.reshape((batch_size, size))
@@ -135,14 +184,12 @@ def test_predict_minimum_training():
     n_k = 5
     size = n_k + 1
     embedding_size = 64
-    entity_embeddings = EntityEmbeddings(
-        size, [size], embedding_size, [n_k], concat_self=False
-    )
+    entity_embedding = EntityEmbedding(size, n_k, embedding_size)  # no self
     transformer = ResidualSelfAttention(embedding_size)
     l_layer = LinearEncoder(embedding_size, 2, n_k)
     loss = torch.nn.CrossEntropyLoss()
     optimizer = torch.optim.Adam(
-        list(entity_embeddings.parameters())
+        list(entity_embedding.parameters())
         + list(transformer.parameters())
         + list(l_layer.parameters()),
         lr=0.001,
@@ -166,8 +213,8 @@ def test_predict_minimum_training():
         sliced_oh = onehots[:, : num + 1]
         inp = torch.cat([inp, sliced_oh], dim=2)

-        embeddings = entity_embeddings(inp, [inp])
-        masks = EntityEmbeddings.get_masks([inp])
+        embeddings = entity_embedding(inp, inp)
+        masks = get_zero_entities_mask([inp])
         prediction = transformer(embeddings, masks)
         prediction = l_layer(prediction)
         ce = loss(prediction, argmin)
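
Taken together, the test changes show the migration pattern: the list-based `EntityEmbeddings(x_self_size, [entity_size], embedding_size, [n_k])` constructor becomes `EntityEmbedding(entity_size, n_k, embedding_size)` with an explicit `add_self_embedding(...)` call for the ego-centric case, and the static `EntityEmbeddings.get_masks` is replaced by the module-level `get_zero_entities_mask`. Below is a minimal usage sketch of the new calling pattern; the variable names and sizes are illustrative only and not part of the commit.

from mlagents.torch_utils import torch
from mlagents.trainers.torch.attention import (
    EntityEmbedding,
    ResidualSelfAttention,
    get_zero_entities_mask,
)

# Illustrative sizes, not from the commit.
batch_size, n_entities, entity_size, embedding_size = 4, 5, 3, 64

x_self = torch.rand((batch_size, entity_size))                 # "self" observation
entities = torch.rand((batch_size, n_entities, entity_size))   # entity observations (zero rows = padding)

embedding = EntityEmbedding(entity_size, n_entities, embedding_size)
embedding.add_self_embedding(entity_size)                      # ego-centric: concatenate x_self onto each entity
attention = ResidualSelfAttention(embedding_size, n_entities)

masks = get_zero_entities_mask([entities])                     # 1 where an entity row is all zeros
encoded = embedding(x_self, entities)                          # (batch_size, n_entities, embedding_size)
pooled = attention(encoded, masks)                             # (batch_size, embedding_size), per the tests above
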

ml-agents/mlagents/trainers/torch/attention.py (+78 −83)
@@ -10,22 +10,41 @@
 from mlagents.trainers.exception import UnityTrainerException


-class MultiHeadAttention(torch.nn.Module):
+def get_zero_entities_mask(observations: List[torch.Tensor]) -> List[torch.Tensor]:
     """
-    Multi Head Attention module. We do not use the regular Torch implementation since
-    Barracuda does not support some operators it uses.
-    Takes as input to the forward method 3 tensors:
-    - query: of dimensions (batch_size, number_of_queries, embedding_size)
-    - key: of dimensions (batch_size, number_of_keys, embedding_size)
-    - value: of dimensions (batch_size, number_of_keys, embedding_size)
-    The forward method will return 2 tensors:
-    - The output: (batch_size, number_of_queries, embedding_size)
-    - The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
+    Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
+    all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
+    layer to mask the padding observations.
     """
+    with torch.no_grad():
+        # Generate the masking tensors for each entities tensor (mask only if all zeros)
+        key_masks: List[torch.Tensor] = [
+            (torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations
+        ]
+    return key_masks
+
+
+class MultiHeadAttention(torch.nn.Module):

     NEG_INF = -1e6

     def __init__(self, embedding_size: int, num_heads: int):
+        """
+        Multi Head Attention module. We do not use the regular Torch implementation since
+        Barracuda does not support some operators it uses.
+        Takes as input to the forward method 3 tensors:
+        - query: of dimensions (batch_size, number_of_queries, embedding_size)
+        - key: of dimensions (batch_size, number_of_keys, embedding_size)
+        - value: of dimensions (batch_size, number_of_keys, embedding_size)
+        The forward method will return 2 tensors:
+        - The output: (batch_size, number_of_queries, embedding_size)
+        - The attention matrix: (batch_size, num_heads, number_of_queries, number_of_keys)
+        :param embedding_size: The size of the embeddings that will be generated (should be
+        dividable by the num_heads)
+        :param total_max_elements: The maximum total number of entities that can be passed to
+        the module
+        :param num_heads: The number of heads of the attention module
+        """
         super().__init__()
         self.n_heads = num_heads
         self.head_size: int = embedding_size // self.n_heads
@@ -82,7 +101,7 @@ def forward(
         return value_attention, att


-class EntityEmbeddings(torch.nn.Module):
+class EntityEmbedding(torch.nn.Module):
     """
     A module used to embed entities before passing them to a self-attention block.
     Used in conjunction with ResidualSelfAttention to encode information about a self
@@ -92,95 +111,69 @@ class EntityEmbeddings(torch.nn.Module):

     def __init__(
         self,
-        x_self_size: int,
-        entity_sizes: List[int],
+        entity_size: int,
+        entity_num_max_elements: Optional[int],
         embedding_size: int,
-        entity_num_max_elements: Optional[List[int]] = None,
-        concat_self: bool = True,
     ):
         """
-        Constructs an EntityEmbeddings module.
+        Constructs an EntityEmbedding module.
         :param x_self_size: Size of "self" entity.
-        :param entity_sizes: List of sizes for other entities. Should be of length
-        equivalent to the number of entities.
-        :param embedding_size: Embedding size for entity encoders.
-        :param entity_num_max_elements: Maximum elements in an entity, None for unrestricted.
+        :param entity_size: Size of other entities.
+        :param entity_num_max_elements: Maximum elements for a given entity, None for unrestricted.
         Needs to be assigned in order for model to be exportable to ONNX and Barracuda.
-        :param concat_self: Whether to concatenate x_self to entites. Set True for ego-centric
+        :param embedding_size: Embedding size for the entity encoder.
+        :param concat_self: Whether to concatenate x_self to entities. Set True for ego-centric
         self-attention.
         """
         super().__init__()
-        self.self_size: int = x_self_size
-        self.entity_sizes: List[int] = entity_sizes
-        self.entity_num_max_elements: List[int] = [-1] * len(entity_sizes)
+        self.self_size: int = 0
+        self.entity_size: int = entity_size
+        self.entity_num_max_elements: int = -1
         if entity_num_max_elements is not None:
             self.entity_num_max_elements = entity_num_max_elements
-
-        self.concat_self: bool = concat_self
-        # If not concatenating self, input to encoder is just entity size
-        if not concat_self:
-            self.self_size = 0
+        self.embedding_size = embedding_size
         # Initialization scheme from http://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf
-        self.ent_encoders = torch.nn.ModuleList(
-            [
-                LinearEncoder(
-                    self.self_size + ent_size,
-                    1,
-                    embedding_size,
-                    kernel_init=Initialization.Normal,
-                    kernel_gain=(0.125 / embedding_size) ** 0.5,
-                )
-                for ent_size in self.entity_sizes
-            ]
+        self.self_ent_encoder = LinearEncoder(
+            self.entity_size,
+            1,
+            self.embedding_size,
+            kernel_init=Initialization.Normal,
+            kernel_gain=(0.125 / self.embedding_size) ** 0.5,
         )
-        self.embedding_norm = LayerNorm()

-    def forward(
-        self, x_self: torch.Tensor, entities: List[torch.Tensor]
-    ) -> Tuple[torch.Tensor, int]:
-        if self.concat_self:
-            # Concatenate all observations with self
-            self_and_ent: List[torch.Tensor] = []
-            for num_entities, ent in zip(self.entity_num_max_elements, entities):
-                if num_entities < 0:
-                    if exporting_to_onnx.is_exporting():
-                        raise UnityTrainerException(
-                            "Trying to export an attention mechanism that doesn't have a set max \
-                            number of elements."
-                        )
-                    num_entities = ent.shape[1]
-                expanded_self = x_self.reshape(-1, 1, self.self_size)
-                expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
-                self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
-        else:
-            self_and_ent = entities
-        # Encode and concatenate entites
-        encoded_entities = torch.cat(
-            [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
-            dim=1,
+    def add_self_embedding(self, size: int) -> None:
+        self.self_size = size
+        self.self_ent_encoder = LinearEncoder(
+            self.self_size + self.entity_size,
+            1,
+            self.embedding_size,
+            kernel_init=Initialization.Normal,
+            kernel_gain=(0.125 / self.embedding_size) ** 0.5,
         )
-        encoded_entities = self.embedding_norm(encoded_entities)
-        return encoded_entities

-    @staticmethod
-    def get_masks(observations: List[torch.Tensor]) -> List[torch.Tensor]:
-        """
-        Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
-        all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
-        layer to mask the padding observations.
-        """
-        with torch.no_grad():
-            # Generate the masking tensors for each entities tensor (mask only if all zeros)
-            key_masks: List[torch.Tensor] = [
-                (torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in observations
-            ]
-            return key_masks
+    def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
+        if self.self_size > 0:
+            num_entities = self.entity_num_max_elements
+            if num_entities < 0:
+                if exporting_to_onnx.is_exporting():
+                    raise UnityTrainerException(
+                        "Trying to export an attention mechanism that doesn't have a set max \
+                        number of elements."
+                    )
+                num_entities = entities.shape[1]
+            expanded_self = x_self.reshape(-1, 1, self.self_size)
+            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
+            # Concatenate all observations with self
+            entities = torch.cat([expanded_self, entities], dim=2)
+        # Encode entities
+        encoded_entities = self.self_ent_encoder(entities)
+        return encoded_entities


 class ResidualSelfAttention(torch.nn.Module):
     """
     Residual self attentioninspired from https://arxiv.org/pdf/1909.07528.pdf. Can be used
-    with an EntityEmbeddings module, to apply multi head self attention to encode information
+    with an EntityEmbedding module, to apply multi head self attention to encode information
     about a "Self" and a list of relevant "Entities".
     """

@@ -189,7 +182,7 @@ class ResidualSelfAttention(torch.nn.Module):
     def __init__(
         self,
         embedding_size: int,
-        entity_num_max_elements: Optional[List[int]] = None,
+        entity_num_max_elements: Optional[int] = None,
         num_heads: int = 4,
     ):
         """
@@ -205,8 +198,7 @@ def __init__(
         super().__init__()
         self.max_num_ent: Optional[int] = None
         if entity_num_max_elements is not None:
-            _entity_num_max_elements = entity_num_max_elements
-            self.max_num_ent = sum(_entity_num_max_elements)
+            self.max_num_ent = entity_num_max_elements

         self.attention = MultiHeadAttention(
             num_heads=num_heads, embedding_size=embedding_size
@@ -237,11 +229,14 @@ def __init__(
             kernel_init=Initialization.Normal,
             kernel_gain=(0.125 / embedding_size) ** 0.5,
         )
+        self.embedding_norm = LayerNorm()
         self.residual_norm = LayerNorm()

     def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor:
         # Gather the maximum number of entities information
         mask = torch.cat(key_masks, dim=1)
+
+        inp = self.embedding_norm(inp)
         # Feed to self attention
         query = self.fc_q(inp)  # (b, n_q, emb)
         key = self.fc_k(inp)  # (b, n_k, emb)
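
Two notes on this file. First, `embedding_norm` moves out of the embedding module and into `ResidualSelfAttention`, which now normalizes `inp` at the start of `forward`, so callers no longer get a pre-normalized embedding from `EntityEmbedding`. Second, `get_zero_entities_mask` flags an entity as padding when its squared sum along dimension 2 falls below 0.01. A minimal sketch of that mask semantics follows; the tensor values are illustrative and not from the commit.

from mlagents.torch_utils import torch
from mlagents.trainers.torch.attention import get_zero_entities_mask

# Batch of 1 with three entities of size 2; the last entity is all-zero padding.
entities = torch.tensor([[[0.5, 0.2], [0.1, 0.9], [0.0, 0.0]]])
masks = get_zero_entities_mask([entities])
print(masks[0])  # tensor([[0., 0., 1.]]) -- 1 marks the padded (all-zero) entity
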
