
Commit 9a5a141

Henry Peteet authored and GitHub Enterprise committed

Upgrade pre-commit tools (#12)

1 parent 1bde547 · commit 9a5a141

22 files changed: +73 −56 lines changed

.pre-commit-config.yaml

+12 −10

@@ -1,30 +1,32 @@
 repos:
 - repo: https://github.com/python/black
-  rev: 19.3b0
+  rev: 22.1.0
   hooks:
   - id: black
     exclude: >
      (?x)^(
        .*_pb2.py|
+       .*_pb2.pyi|
        .*_pb2_grpc.py
      )$
 
 - repo: https://github.com/pre-commit/mirrors-mypy
-  rev: v0.761
+  rev: v0.931
   hooks:
   - id: mypy
     name: mypy-ml-agents
     files: "ml-agents/.*"
-    args: [--ignore-missing-imports, --disallow-incomplete-defs]
+    args: [--ignore-missing-imports, --disallow-incomplete-defs, --no-strict-optional]
+    additional_dependencies: [types-PyYAML, types-attrs, types-protobuf, types-setuptools]
   - id: mypy
     name: mypy-ml-agents-envs
     files: "ml-agents-envs/.*"
     # Exclude protobuf files and don't follow them when imported
     exclude: ".*_pb2.py"
-    args: [--ignore-missing-imports, --disallow-incomplete-defs]
-
+    args: [--ignore-missing-imports, --disallow-incomplete-defs, --no-strict-optional]
+    additional_dependencies: [types-PyYAML, types-attrs, types-protobuf, types-setuptools]
 - repo: https://gitlab.com/pycqa/flake8
-  rev: 3.8.1
+  rev: 3.9.2
   hooks:
   - id: flake8
     exclude: >
@@ -36,7 +38,7 @@ repos:
     additional_dependencies: [flake8-comprehensions==3.2.2, flake8-tidy-imports==4.1.0, flake8-bugbear==20.1.4]
 
 - repo: https://github.com/asottile/pyupgrade
-  rev: v2.7.0
+  rev: v2.31.0
   hooks:
   - id: pyupgrade
     args: [--py3-plus, --py36-plus]
@@ -47,7 +49,7 @@ repos:
      )$
 
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v2.5.0
+  rev: v4.1.0
   hooks:
   - id: mixed-line-ending
     exclude: >
@@ -68,12 +70,12 @@ repos:
     exclude: \.yamato/.*
 
 - repo: https://github.com/pre-commit/pygrep-hooks
-  rev: v1.4.2
+  rev: v1.9.0
   hooks:
   - id: python-check-mock-methods
 
 - repo: https://github.com/mattlqx/pre-commit-search-and-replace
-  rev: v1.0.3
+  rev: v1.0.5
   hooks:
   - id: search-and-replace
     types: [markdown]
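Note on the mypy bump: starting with mypy 0.900, third-party type stubs are no longer bundled with mypy itself, which is presumably why this hunk adds stub packages (types-PyYAML, types-attrs, types-protobuf, types-setuptools) as `additional_dependencies` in the hook's isolated environment. A minimal sketch of the kind of code those stubs cover (hypothetical snippet, not from this repo; assumes PyYAML is installed):

```python
# Under mypy >= 0.900, importing yaml without the types-PyYAML stub package
# produces a "missing library stubs" error; with it, safe_load is typed.
import yaml

config = yaml.safe_load("lr: 0.001\nbatch_size: 64")
print(config["lr"], config["batch_size"])
```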

ml-agents-envs/mlagents_envs/envs/unity_pettingzoo_base_env.py

+9 −3
@@ -253,9 +253,15 @@ def _batch_update(self, behavior_name):
         self._current_action[behavior_name] = self._create_empty_actions(
             behavior_name, len(current_batch[0])
         )
-        agents, obs, dones, rewards, cumulative_rewards, infos, id_map = _unwrap_batch_steps(
-            current_batch, behavior_name
-        )
+        (
+            agents,
+            obs,
+            dones,
+            rewards,
+            cumulative_rewards,
+            infos,
+            id_map,
+        ) = _unwrap_batch_steps(current_batch, behavior_name)
         self._live_agents += agents
         self._agents += agents
         self._observations.update(obs)
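This hunk shows Black 22's handling of long unpacking assignments: rather than splitting the call arguments, it wraps the target list in parentheses, one name per line. A self-contained sketch of the new spelling (`unwrap` is a hypothetical stand-in for `_unwrap_batch_steps`):

```python
def unwrap():
    # Hypothetical stand-in returning seven values, like _unwrap_batch_steps.
    return ["agent-0"], {}, {}, {}, {}, {}, {}

# Black 22 formats a too-long target list like this:
(
    agents,
    obs,
    dones,
    rewards,
    cumulative_rewards,
    infos,
    id_map,
) = unwrap()
```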

ml-agents-envs/mlagents_envs/registry/binary_utils.py

+2 −2
@@ -137,7 +137,7 @@ def download_and_extract_zip(url: str, name: str) -> None:
     try:
         request = urllib.request.urlopen(url, timeout=30)
     except urllib.error.HTTPError as e:  # type: ignore
-        e.msg += " " + url
+        e.reason = f"{e.reason} {url}"
         raise
     zip_size = int(request.headers["content-length"])
     zip_file_path = os.path.join(zip_dir, str(uuid.uuid4()) + ".zip")
@@ -193,7 +193,7 @@ def load_remote_manifest(url: str) -> Dict[str, Any]:
     try:
         request = urllib.request.urlopen(url, timeout=30)
     except urllib.error.HTTPError as e:  # type: ignore
-        e.msg += " " + url
+        e.reason = f"{e.reason} {url}"
         raise
     manifest_path = os.path.join(tmp_dir, str(uuid.uuid4()) + ".yaml")
     with open(manifest_path, "wb") as manifest:
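Both hunks swap the mutation of `HTTPError.msg` for `HTTPError.reason`, apparently to satisfy the newer urllib type stubs pulled in by the mypy upgrade. An alternative sketch that attaches the failing URL by re-raising a fresh error instead of mutating the caught one (`open_with_context` is a hypothetical helper, not part of the diff):

```python
import urllib.error
import urllib.request


def open_with_context(url: str):
    """Open a URL, attaching the URL to any HTTP error for easier debugging."""
    try:
        return urllib.request.urlopen(url, timeout=30)
    except urllib.error.HTTPError as e:
        # Re-raise a copy whose message carries the failing URL.
        raise urllib.error.HTTPError(url, e.code, f"{e.msg} {url}", e.hdrs, e.fp) from e
```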

ml-agents-envs/mlagents_envs/side_channel/side_channel_manager.py

+2 −2
@@ -21,7 +21,7 @@ def process_side_channel_message(self, data: bytes) -> None:
         try:
             channel_id = uuid.UUID(bytes_le=bytes(data[offset : offset + 16]))
             offset += 16
-            message_len, = struct.unpack_from("<i", data, offset)
+            (message_len,) = struct.unpack_from("<i", data, offset)
             offset = offset + 4
             message_data = data[offset : offset + message_len]
             offset = offset + message_len
@@ -63,7 +63,7 @@ def generate_side_channel_messages(self) -> bytearray:
 
     @staticmethod
     def _get_side_channels_dict(
-        side_channels: Optional[List[SideChannel]]
+        side_channels: Optional[List[SideChannel]],
     ) -> Dict[uuid.UUID, SideChannel]:
         """
         Converts a list of side channels into a dictionary of channel_id to SideChannel
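The `(message_len,)` change is Black 22's preferred spelling for single-element unpacking; `struct.unpack_from` always returns a tuple, even for a single field, and the parentheses make that explicit. A quick self-contained check:

```python
import struct

buf = struct.pack("<i", 42)
# unpack_from returns a 1-tuple here; the parenthesized target unpacks it.
(message_len,) = struct.unpack_from("<i", buf, 0)
assert message_len == 42
```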

ml-agents-envs/setup.py

+2 −1
@@ -59,5 +59,6 @@ def run(self):
         "numpy==1.21.2",
     ],
     python_requires=">=3.7.2,<3.9.10",
-    cmdclass={"verify": VerifyVersionCommand},
+    # TODO: Remove this once mypy stops having spurious setuptools issues.
+    cmdclass={"verify": VerifyVersionCommand},  # type: ignore
 )
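The `# type: ignore` works around mypy's setuptools stubs rejecting a custom `Command` subclass in `cmdclass`; the TODO marks it as temporary. A minimal sketch of the pattern being silenced (`VerifyVersionCommand` below is a simplified stand-in, not the project's actual implementation):

```python
from setuptools import Command


class VerifyVersionCommand(Command):
    """Simplified stand-in for the project's version-verification command."""

    user_options: list = []

    def initialize_options(self) -> None:
        pass

    def finalize_options(self) -> None:
        pass

    def run(self) -> None:
        print("version ok")


# The trailing comment silences mypy on this line only, not the whole module:
cmdclass = {"verify": VerifyVersionCommand}  # type: ignore
```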

ml-agents/mlagents/trainers/behavior_id_utils.py

+5 −5
@@ -42,11 +42,11 @@ def from_name_behavior_id(name_behavior_id: str) -> "BehaviorIdentifiers":
 
 def create_name_behavior_id(name: str, team_id: int) -> str:
     """
-        Reconstructs fully qualified behavior name from name and team_id
-        :param name: brain name
-        :param team_id: team ID
-        :return: name_behavior_id
-        """
+    Reconstructs fully qualified behavior name from name and team_id
+    :param name: brain name
+    :param team_id: team ID
+    :return: name_behavior_id
+    """
     return name + "?team=" + str(team_id)
 
 
ml-agents/mlagents/trainers/buffer.py

+1 −3
@@ -264,9 +264,7 @@ def __init__(self):
         )
 
     def __str__(self):
-        return ", ".join(
-            ["'{}' : {}".format(k, str(self[k])) for k in self._fields.keys()]
-        )
+        return ", ".join([f"'{k}' : {str(self[k])}" for k in self._fields.keys()])
 
     def reset_agent(self) -> None:
         """

ml-agents/mlagents/trainers/environment_parameter_manager.py

+4 −1
@@ -165,7 +165,10 @@ def update_lessons(
         ):
             behavior_to_consider = lesson.completion_criteria.behavior
             if behavior_to_consider in trainer_steps:
-                must_increment, new_smoothing = lesson.completion_criteria.need_increment(
+                (
+                    must_increment,
+                    new_smoothing,
+                ) = lesson.completion_criteria.need_increment(
                     float(trainer_steps[behavior_to_consider])
                     / float(trainer_max_steps[behavior_to_consider]),
                     trainer_reward_buffer[behavior_to_consider],

ml-agents/mlagents/trainers/stats.py

+1 −1
@@ -33,7 +33,7 @@ def _dict_to_str(param_dict: Dict[str, Any], num_tabs: int) -> str:
         [
             "\t"
             + " " * num_tabs
-            + "{}:\t{}".format(x, _dict_to_str(param_dict[x], num_tabs + 1))
+            + f"{x}:\t{_dict_to_str(param_dict[x], num_tabs + 1)}"
             for x in param_dict
         ]
     )

ml-agents/mlagents/trainers/tests/test_trainer_controller.py

+3 −3
@@ -71,7 +71,7 @@ def take_step_sideeffect(env):
 
 
 def test_start_learning_trains_forever_if_no_train_model(
-    trainer_controller_with_start_learning_mocks
+    trainer_controller_with_start_learning_mocks,
 ):
     tc, trainer_mock = trainer_controller_with_start_learning_mocks
     tc.train_model = False
@@ -88,7 +88,7 @@ def test_start_learning_trains_forever_if_no_train_model(
 
 
 def test_start_learning_trains_until_max_steps_then_saves(
-    trainer_controller_with_start_learning_mocks
+    trainer_controller_with_start_learning_mocks,
 ):
     tc, trainer_mock = trainer_controller_with_start_learning_mocks
 
@@ -120,7 +120,7 @@ def trainer_controller_with_take_step_mocks(basic_trainer_controller):
 
 
 def test_advance_adds_experiences_to_trainer_and_trains(
-    trainer_controller_with_take_step_mocks
+    trainer_controller_with_take_step_mocks,
 ):
     tc, trainer_mock = trainer_controller_with_take_step_mocks
 

ml-agents/mlagents/trainers/tests/torch/test_action_model.py

+1 −1
@@ -12,7 +12,7 @@
 
 
 def create_action_model(inp_size, act_size, deterministic=False):
-    mask = torch.ones([1, act_size ** 2])
+    mask = torch.ones([1, act_size**2])
     action_spec = ActionSpec(act_size, tuple(act_size for _ in range(act_size)))
     action_model = ActionModel(inp_size, action_spec, deterministic=deterministic)
     return action_model, mask
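This and several later hunks reflect Black 22's hugging of the power operator: spaces around `**` are dropped when both operands are simple (names, numbers, attribute access) but kept when an operand is a larger expression, which is why `(safe_norm - 1) ** 2` elsewhere in this diff keeps its spaces. The spacing is purely cosmetic:

```python
act_size = 4
# Both spellings evaluate identically; Black 22 prefers the hugged form
# when the operands are simple:
assert act_size**2 == act_size ** 2 == 16
```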

ml-agents/mlagents/trainers/tests/torch/test_attention.py

+8 −2
@@ -90,7 +90,10 @@ def test_all_masking(mask_value):
     # We make sure that a mask of all zeros or all ones will not trigger an error
     np.random.seed(1336)
     torch.manual_seed(1336)
-    size, n_k, = 3, 5
+    size, n_k, = (
+        3,
+        5,
+    )
     embedding_size = 64
     entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
     entity_embeddings.add_self_embedding(size)
@@ -134,7 +137,10 @@ def test_all_masking(mask_value):
 def test_predict_closest_training():
     np.random.seed(1336)
     torch.manual_seed(1336)
-    size, n_k, = 3, 5
+    size, n_k, = (
+        3,
+        5,
+    )
     embedding_size = 64
     entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
     entity_embeddings.add_self_embedding(size)
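The odd-looking expansion of `size, n_k, = 3, 5` is most likely Black's magic trailing comma at work: the stray comma after `n_k` makes Black treat the assignment as already exploded, so the right-hand side is split one element per line. Dropping the stray comma would keep it compact:

```python
# With the stray comma after n_k, Black 22 explodes the tuple:
size, n_k, = (
    3,
    5,
)
# Without it, the assignment stays on one line:
size, n_k = 3, 5
assert (size, n_k) == (3, 5)
```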

ml-agents/mlagents/trainers/tests/torch/test_policy.py

+3 −3
@@ -138,7 +138,7 @@ def test_sample_actions(rnn, visual, discrete):
 
 def test_step_overflow():
     policy = create_policy_mock(TrainerSettings())
-    policy.set_step(2 ** 31 - 1)
-    assert policy.get_current_step() == 2 ** 31 - 1  # step = 2147483647
+    policy.set_step(2**31 - 1)
+    assert policy.get_current_step() == 2**31 - 1  # step = 2147483647
     policy.increment_step(3)
-    assert policy.get_current_step() == 2 ** 31 + 2  # step = 2147483650
+    assert policy.get_current_step() == 2**31 + 2  # step = 2147483650

ml-agents/mlagents/trainers/torch/attention.py

+3 −3
@@ -39,7 +39,7 @@ def get_zero_entities_mask(entities: List[torch.Tensor]) -> List[torch.Tensor]:
 
     # Generate the masking tensors for each entities tensor (mask only if all zeros)
     key_masks: List[torch.Tensor] = [
-        (torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in entities
+        (torch.sum(ent**2, axis=2) < 0.01).float() for ent in entities
     ]
     return key_masks
 
@@ -101,11 +101,11 @@ def forward(
         qk = torch.matmul(query, key)  # (b, h, n_q, n_k)
 
         if key_mask is None:
-            qk = qk / (self.embedding_size ** 0.5)
+            qk = qk / (self.embedding_size**0.5)
         else:
             key_mask = key_mask.reshape(b, 1, 1, n_k)
             qk = (1 - key_mask) * qk / (
-                self.embedding_size ** 0.5
+                self.embedding_size**0.5
             ) + key_mask * self.NEG_INF
 
         att = torch.softmax(qk, dim=3)  # (b, h, n_q, n_k)

ml-agents/mlagents/trainers/torch/components/bc/module.py

+3 −1
@@ -33,7 +33,9 @@ def __init__(
         self._anneal_steps = settings.steps
         self.current_lr = policy_learning_rate * settings.strength
 
-        learning_rate_schedule: ScheduleType = ScheduleType.LINEAR if self._anneal_steps > 0 else ScheduleType.CONSTANT
+        learning_rate_schedule: ScheduleType = (
+            ScheduleType.LINEAR if self._anneal_steps > 0 else ScheduleType.CONSTANT
+        )
         self.decay_learning_rate = ModelUtils.DecayedValue(
             learning_rate_schedule, self.current_lr, 1e-10, self._anneal_steps
         )

ml-agents/mlagents/trainers/torch/components/reward_providers/gail_reward_provider.py

+5 −5
@@ -183,10 +183,10 @@ def compute_loss(
         kl_loss = torch.mean(
             -torch.sum(
                 1
-                + (self._z_sigma ** 2).log()
-                - 0.5 * expert_mu ** 2
-                - 0.5 * policy_mu ** 2
-                - (self._z_sigma ** 2),
+                + (self._z_sigma**2).log()
+                - 0.5 * expert_mu**2
+                - 0.5 * policy_mu**2
+                - (self._z_sigma**2),
                 dim=1,
             )
         )
@@ -255,6 +255,6 @@ def compute_gradient_magnitude(
         estimate = self._estimator(hidden).squeeze(1).sum()
         gradient = torch.autograd.grad(estimate, encoder_input, create_graph=True)[0]
         # Norm's gradient could be NaN at 0. Use our own safe_norm
-        safe_norm = (torch.sum(gradient ** 2, dim=1) + self.EPSILON).sqrt()
+        safe_norm = (torch.sum(gradient**2, dim=1) + self.EPSILON).sqrt()
         gradient_mag = torch.mean((safe_norm - 1) ** 2)
         return gradient_mag

ml-agents/mlagents/trainers/torch/distributions.py

+2 −2
@@ -70,7 +70,7 @@ def deterministic_sample(self):
         return self.mean
 
     def log_prob(self, value):
-        var = self.std ** 2
+        var = self.std**2
         log_scale = torch.log(self.std + EPSILON)
         return (
             -((value - self.mean) ** 2) / (2 * var + EPSILON)
@@ -84,7 +84,7 @@ def pdf(self, value):
 
     def entropy(self):
         return torch.mean(
-            0.5 * torch.log(2 * math.pi * math.e * self.std ** 2 + EPSILON),
+            0.5 * torch.log(2 * math.pi * math.e * self.std**2 + EPSILON),
             dim=1,
             keepdim=True,
         )  # Use equivalent behavior to TF
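For reference (not part of the diff), the expressions in `log_prob` and `entropy` match the closed forms for a univariate Gaussian: log p(x) = -(x - mu)^2 / (2 sigma^2) - log(sigma) - (1/2) log(2 pi), and H = (1/2) log(2 pi e sigma^2). A plain-Python sanity check of both, ignoring the EPSILON guards:

```python
import math

mu, sigma, x = 0.0, 0.5, 0.3
# Entropy of a 1-D Gaussian, matching the expression in entropy() above:
entropy = 0.5 * math.log(2 * math.pi * math.e * sigma**2)
# Log-density, matching log_prob() above:
log_prob = (
    -((x - mu) ** 2) / (2 * sigma**2) - math.log(sigma) - 0.5 * math.log(2 * math.pi)
)
print(entropy, log_prob)
```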

ml-agents/mlagents/trainers/trainer/rl_trainer.py

+1 −1
@@ -137,7 +137,7 @@ def create_model_saver(
         return model_saver
 
     def _policy_mean_reward(self) -> Optional[float]:
-        """ Returns the mean episode reward for the current policy. """
+        """Returns the mean episode reward for the current policy."""
         rewards = self.cumulative_returns_since_policy_update
         if len(rewards) == 0:
             return None

ml-agents/setup.py

+2 −1
@@ -89,5 +89,6 @@ def run(self):
             "default=mlagents.plugins.stats_writer:get_default_stats_writers"
         ],
     },
-    cmdclass={"verify": VerifyVersionCommand},
+    # TODO: Remove this once mypy stops having spurious setuptools issues.
+    cmdclass={"verify": VerifyVersionCommand},  # type: ignore
 )

ml-agents/tests/yamato/scripts/run_llapi.py

+1 −3
@@ -47,9 +47,7 @@ def test_run_environment(env_name):
     print("Is there a visual observation ?", vis_obs)
 
     # Examine the state space for the first observation for the first agent
-    print(
-        "First Agent observation looks like: \n{}".format(decision_steps.obs[0][0])
-    )
+    print(f"First Agent observation looks like: \n{decision_steps.obs[0][0]}")
 
     for _episode in range(10):
         env.reset()

ml-agents/tests/yamato/yamato_utils.py

+1 −1
@@ -24,7 +24,7 @@ def get_base_path():
 
 
 def get_base_output_path():
-    """"
+    """ "
     Returns the artifact folder to use for yamato jobs.
     """
     return os.path.join(get_base_path(), "artifacts")

utils/validate_release_links.py

+2 −2
@@ -188,7 +188,7 @@ def check_file(
                 new_file.write(line)
             else:
                 bad_lines.append(f"{filename}: {line}")
-                new_line = re.sub(r"release_[0-9]+", fr"{release_tag}", line)
+                new_line = re.sub(r"release_[0-9]+", rf"{release_tag}", line)
                 new_line = update_pip_install_line(new_line, package_version)
                 new_file.write(new_line)
     if bad_lines:
@@ -235,7 +235,7 @@ def main():
     print(f"Python package version: {package_version}")
     release_allow_pattern = re.compile(f"{release_tag}(_docs)?")
     pip_allow_pattern = re.compile(
-        fr"python -m pip install (-q )?mlagents(_envs)?=={package_version}"
+        rf"python -m pip install (-q )?mlagents(_envs)?=={package_version}"
     )
     bad_lines = check_all_files(
         release_allow_pattern, release_tag, pip_allow_pattern, package_version
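Both hunks are pyupgrade normalizing string-prefix order: `fr"..."` becomes `rf"..."`. The two prefixes are identical at runtime, so the change is purely stylistic:

```python
import re

release_tag = "release_19"  # hypothetical value for illustration
line = "see the release_17 docs"
# fr"{release_tag}" and rf"{release_tag}" compile to the same string:
assert re.sub(r"release_[0-9]+", rf"{release_tag}", line) == "see the release_19 docs"
```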
