Skip to content

Commit 8de21ab

Browse files
committed
added start instances
1 parent b888213 commit 8de21ab

File tree

3 files changed

+103
-45
lines changed

3 files changed

+103
-45
lines changed

packages/aws-library/src/aws_library/ec2/client.py

Lines changed: 47 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import botocore.exceptions
99
from aiobotocore.session import ClientCreatorContext
1010
from aiocache import cached # type: ignore[import-untyped]
11-
from pydantic import ByteSize, PositiveInt, parse_obj_as
11+
from pydantic import ByteSize, PositiveInt
1212
from servicelib.logging_utils import log_context
1313
from settings_library.ec2 import EC2Settings
1414
from types_aiobotocore_ec2 import EC2Client
@@ -29,7 +29,7 @@
2929
EC2Tags,
3030
Resources,
3131
)
32-
from .utils import compose_user_data
32+
from .utils import compose_user_data, ec2_instance_data_from_aws_instance
3333

3434
_logger = logging.getLogger(__name__)
3535

@@ -193,27 +193,17 @@ async def launch_instances(
193193
await waiter.wait(InstanceIds=instance_ids)
194194
_logger.info("instances %s exists now.", instance_ids)
195195

196-
# get the private IPs
196+
# NOTE: waiting for pending ensure we get all the IPs back
197197
described_instances = await self.client.describe_instances(
198198
InstanceIds=instance_ids
199199
)
200+
assert "Instances" in described_instances["Reservations"][0] # nosec
200201
instance_datas = [
201-
EC2InstanceData(
202-
launch_time=instance["LaunchTime"],
203-
id=instance["InstanceId"],
204-
aws_private_dns=instance["PrivateDnsName"],
205-
aws_public_ip=instance.get("PublicIpAddress", None),
206-
type=instance["InstanceType"],
207-
state=instance["State"]["Name"],
208-
tags=parse_obj_as(
209-
EC2Tags, {tag["Key"]: tag["Value"] for tag in instance["Tags"]}
210-
),
211-
resources=instance_config.type.resources,
212-
)
213-
for instance in described_instances["Reservations"][0]["Instances"]
202+
await ec2_instance_data_from_aws_instance(self, i)
203+
for i in described_instances["Reservations"][0]["Instances"]
214204
]
215205
_logger.info(
216-
"%s is available, happy computing!!",
206+
"%s is pending now, happy computing!!",
217207
f"{instance_datas=}",
218208
)
219209
return instance_datas
@@ -245,38 +235,51 @@ async def get_instances(
245235
all_instances = []
246236
for reservation in instances["Reservations"]:
247237
assert "Instances" in reservation # nosec
248-
for instance in reservation["Instances"]:
249-
assert "LaunchTime" in instance # nosec
250-
assert "InstanceId" in instance # nosec
251-
assert "PrivateDnsName" in instance # nosec
252-
assert "InstanceType" in instance # nosec
253-
assert "State" in instance # nosec
254-
assert "Name" in instance["State"] # nosec
255-
ec2_instance_types = await self.get_ec2_instance_capabilities(
256-
{instance["InstanceType"]}
257-
)
258-
assert len(ec2_instance_types) == 1 # nosec
259-
assert "Tags" in instance # nosec
260-
all_instances.append(
261-
EC2InstanceData(
262-
launch_time=instance["LaunchTime"],
263-
id=instance["InstanceId"],
264-
aws_private_dns=instance["PrivateDnsName"],
265-
aws_public_ip=instance.get("PublicIpAddress", None),
266-
type=instance["InstanceType"],
267-
state=instance["State"]["Name"],
268-
resources=ec2_instance_types[0].resources,
269-
tags=parse_obj_as(
270-
EC2Tags,
271-
{tag["Key"]: tag["Value"] for tag in instance["Tags"]},
272-
),
273-
)
274-
)
238+
all_instances.extend(
239+
[
240+
await ec2_instance_data_from_aws_instance(self, i)
241+
for i in reservation["Instances"]
242+
]
243+
)
275244
_logger.debug(
276245
"received: %s instances with %s", f"{len(all_instances)}", f"{state_names=}"
277246
)
278247
return all_instances
279248

249+
async def start_instances(
250+
self, instance_datas: Iterable[EC2InstanceData]
251+
) -> list[EC2InstanceData]:
252+
try:
253+
instance_ids = [i.id for i in instance_datas]
254+
with log_context(
255+
_logger,
256+
logging.INFO,
257+
msg=f"starting instances {instance_ids}",
258+
):
259+
await self.client.start_instances(InstanceIds=instance_ids)
260+
# wait for the instance to be in a pending state
261+
# NOTE: reference to EC2 states https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-instance-lifecycle.html
262+
waiter = self.client.get_waiter("instance_exists")
263+
await waiter.wait(InstanceIds=instance_ids)
264+
_logger.info("instances %s exists now.", instance_ids)
265+
# NOTE: waiting for pending ensure we get all the IPs back
266+
aws_instances = await self.client.describe_instances(
267+
InstanceIds=instance_ids
268+
)
269+
assert len(aws_instances["Reservations"]) == 1 # nosec
270+
assert "Instances" in aws_instances["Reservations"][0] # nosec
271+
return [
272+
await ec2_instance_data_from_aws_instance(self, i)
273+
for i in aws_instances["Reservations"][0]["Instances"]
274+
]
275+
except botocore.exceptions.ClientError as exc:
276+
if (
277+
exc.response.get("Error", {}).get("Code", "")
278+
== "InvalidInstanceID.NotFound"
279+
):
280+
raise EC2InstanceNotFoundError from exc
281+
raise # pragma: no cover
282+
280283
async def stop_instances(self, instance_datas: Iterable[EC2InstanceData]) -> None:
281284
try:
282285
with log_context(

packages/aws-library/src/aws_library/ec2/utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
11
from textwrap import dedent
2+
from typing import TYPE_CHECKING, cast
3+
4+
from types_aiobotocore_ec2.type_defs import InstanceTypeDef
5+
6+
from .models import EC2InstanceData, EC2Tags
7+
8+
if TYPE_CHECKING:
9+
from .client import SimcoreEC2API
210

311

412
def compose_user_data(docker_join_bash_command: str) -> str:
@@ -8,3 +16,30 @@ def compose_user_data(docker_join_bash_command: str) -> str:
816
{docker_join_bash_command}
917
"""
1018
)
19+
20+
21+
async def ec2_instance_data_from_aws_instance(
22+
ec2_client: "SimcoreEC2API",
23+
instance: InstanceTypeDef,
24+
) -> EC2InstanceData:
25+
assert "LaunchTime" in instance # nosec
26+
assert "InstanceId" in instance # nosec
27+
assert "PrivateDnsName" in instance # nosec
28+
assert "InstanceType" in instance # nosec
29+
assert "State" in instance # nosec
30+
assert "Name" in instance["State"] # nosec
31+
ec2_instance_types = await ec2_client.get_ec2_instance_capabilities(
32+
{instance["InstanceType"]}
33+
)
34+
assert len(ec2_instance_types) == 1 # nosec
35+
assert "Tags" in instance # nosec
36+
return EC2InstanceData(
37+
launch_time=instance["LaunchTime"],
38+
id=instance["InstanceId"],
39+
aws_private_dns=instance["PrivateDnsName"],
40+
aws_public_ip=instance.get("PublicIpAddress", None),
41+
type=instance["InstanceType"],
42+
state=instance["State"]["Name"],
43+
resources=ec2_instance_types[0].resources,
44+
tags=cast(EC2Tags, {tag["Key"]: tag["Value"] for tag in instance["Tags"]}),
45+
)

packages/aws-library/tests/test_ec2_client.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import random
77
from collections.abc import AsyncIterator, Callable
8+
from dataclasses import fields
89
from typing import cast, get_args
910

1011
import botocore.exceptions
@@ -343,7 +344,7 @@ async def test_get_instances(
343344
assert not instance_received
344345

345346

346-
async def test_stop_instances(
347+
async def test_stop_start_instances(
347348
simcore_ec2_api: SimcoreEC2API,
348349
ec2_client: EC2Client,
349350
faker: Faker,
@@ -389,6 +390,25 @@ async def test_stop_instances(
389390
expected_state="stopped",
390391
)
391392

393+
# start the instances now
394+
started_instances = await simcore_ec2_api.start_instances(created_instances)
395+
await _assert_instances_in_ec2(
396+
ec2_client,
397+
expected_num_reservations=1,
398+
expected_num_instances=num_instances,
399+
expected_instance_type=ec2_instance_config.type,
400+
expected_tags=ec2_instance_config.tags,
401+
expected_state="running",
402+
)
403+
# the public IPs change when the instances are stopped and started
404+
for s, c in zip(started_instances, created_instances, strict=True):
405+
# the rest shall be the same
406+
for f in fields(EC2InstanceData):
407+
if f.name == "aws_public_ip":
408+
assert getattr(s, f.name) != getattr(c, f.name)
409+
else:
410+
assert getattr(s, f.name) == getattr(c, f.name)
411+
392412

393413
async def test_terminate_instance(
394414
simcore_ec2_api: SimcoreEC2API,

0 commit comments

Comments
 (0)