
Commit 75872b6

Support new GCP CPU series (#2685)

* Support new GCP CPU series
* Handle Persistent disk volumes with new series
* Fix tests
1 parent a6f857c commit 75872b6

File tree

8 files changed: +90 −38 lines


src/dstack/_internal/core/backends/aws/compute.py

Lines changed: 10 additions & 3 deletions
@@ -611,9 +611,12 @@ def delete_volume(self, volume: Volume):
             raise e
         logger.debug("Deleted EBS volume %s", volume.configuration.name)
 
-    def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData:
+    def attach_volume(
+        self, volume: Volume, provisioning_data: JobProvisioningData
+    ) -> VolumeAttachmentData:
         ec2_client = self.session.client("ec2", region_name=volume.configuration.region)
 
+        instance_id = provisioning_data.instance_id
         device_names = aws_resources.list_available_device_names(
             ec2_client=ec2_client, instance_id=instance_id
         )
@@ -646,9 +649,12 @@ def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData:
         logger.debug("Attached EBS volume %s to instance %s", volume.volume_id, instance_id)
         return VolumeAttachmentData(device_name=device_name)
 
-    def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
+    def detach_volume(
+        self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False
+    ):
         ec2_client = self.session.client("ec2", region_name=volume.configuration.region)
 
+        instance_id = provisioning_data.instance_id
         logger.debug("Detaching EBS volume %s from instance %s", volume.volume_id, instance_id)
         attachment_data = get_or_error(volume.get_attachment_data_for_instance(instance_id))
         try:
@@ -667,9 +673,10 @@ def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
             raise e
         logger.debug("Detached EBS volume %s from instance %s", volume.volume_id, instance_id)
 
-    def is_volume_detached(self, volume: Volume, instance_id: str) -> bool:
+    def is_volume_detached(self, volume: Volume, provisioning_data: JobProvisioningData) -> bool:
         ec2_client = self.session.client("ec2", region_name=volume.configuration.region)
 
+        instance_id = provisioning_data.instance_id
         logger.debug("Getting EBS volume %s status", volume.volume_id)
         response = ec2_client.describe_volumes(VolumeIds=[volume.volume_id])
         volumes_infos = response.get("Volumes")
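
For context, here is a rough, self-contained sketch (not dstack's exact code) of how a detach check like `is_volume_detached()` can be built on boto3's `describe_volumes`: the volume counts as detached once it no longer reports an attachment to the given instance.

```python
import boto3


def ebs_volume_detached(region: str, volume_id: str, instance_id: str) -> bool:
    # A volume that still lists an attachment to this instance (including one
    # in the "detaching" state) is not yet safe to reattach elsewhere.
    ec2_client = boto3.client("ec2", region_name=region)
    response = ec2_client.describe_volumes(VolumeIds=[volume_id])
    for volume_info in response.get("Volumes", []):
        for attachment in volume_info.get("Attachments", []):
            if attachment.get("InstanceId") == instance_id:
                return False
    return True
```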

src/dstack/_internal/core/backends/base/compute.py

Lines changed: 7 additions & 3 deletions
@@ -336,7 +336,9 @@ def delete_volume(self, volume: Volume):
         """
         raise NotImplementedError()
 
-    def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData:
+    def attach_volume(
+        self, volume: Volume, provisioning_data: JobProvisioningData
+    ) -> VolumeAttachmentData:
         """
         Attaches a volume to the instance.
         If the volume is not found, it should raise `ComputeError()`.
@@ -345,15 +347,17 @@ def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData:
         """
         raise NotImplementedError()
 
-    def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
+    def detach_volume(
+        self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False
+    ):
         """
         Detaches a volume from the instance.
         Implement only if compute may return `VolumeProvisioningData.detachable`.
         Otherwise, volumes should be detached on instance termination.
         """
         raise NotImplementedError()
 
-    def is_volume_detached(self, volume: Volume, instance_id: str) -> bool:
+    def is_volume_detached(self, volume: Volume, provisioning_data: JobProvisioningData) -> bool:
         """
         Checks if a volume was detached from the instance.
         If `detach_volume()` may fail to detach volume,
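
A minimal sketch of what the updated interface asks of a backend, using simplified stand-ins for dstack's `Volume`, `JobProvisioningData`, and `VolumeAttachmentData` models (the real classes live elsewhere in the codebase). The point of the change: each method now receives the whole provisioning payload and unpacks `instance_id` itself, so backends can also inspect instance type or backend-specific data.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Volume:  # simplified stand-in
    volume_id: str


@dataclass
class JobProvisioningData:  # simplified stand-in
    instance_id: str
    backend_data: Optional[str] = None


@dataclass
class VolumeAttachmentData:  # simplified stand-in
    device_name: str


class ExampleCompute:
    def attach_volume(
        self, volume: Volume, provisioning_data: JobProvisioningData
    ) -> VolumeAttachmentData:
        # The instance ID is derived inside the method; future checks need
        # no further signature changes.
        instance_id = provisioning_data.instance_id
        print(f"attaching {volume.volume_id} to {instance_id}")
        return VolumeAttachmentData(device_name=f"/dev/vol-{volume.volume_id}")

    def detach_volume(
        self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False
    ) -> None:
        # Only needed if the backend reports volumes as detachable.
        pass

    def is_volume_detached(
        self, volume: Volume, provisioning_data: JobProvisioningData
    ) -> bool:
        return True
```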

src/dstack/_internal/core/backends/gcp/compute.py

Lines changed: 43 additions & 20 deletions
@@ -649,32 +649,41 @@ def delete_volume(self, volume: Volume):
             pass
         logger.debug("Deleted persistent disk for volume %s", volume.name)
 
-    def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData:
+    def attach_volume(
+        self, volume: Volume, provisioning_data: JobProvisioningData
+    ) -> VolumeAttachmentData:
+        instance_id = provisioning_data.instance_id
         logger.debug(
             "Attaching persistent disk for volume %s to instance %s",
             volume.volume_id,
             instance_id,
         )
+        if not gcp_resources.instance_type_supports_persistent_disk(
+            provisioning_data.instance_type.name
+        ):
+            raise ComputeError(
+                f"Instance type {provisioning_data.instance_type.name} does not support Persistent disk volumes"
+            )
+
         zone = get_or_error(volume.provisioning_data).availability_zone
+        is_tpu = _is_tpu_provisioning_data(provisioning_data)
         try:
             disk = self.disk_client.get(
                 project=self.config.project_id,
                 zone=zone,
                 disk=volume.volume_id,
             )
             disk_url = disk.self_link
+        except google.api_core.exceptions.NotFound:
+            raise ComputeError("Persistent disk not found")
 
-        # This method has no information if the instance is a TPU or a VM,
-        # so we first try to see if there is a TPU with such name
-        try:
+        try:
+            if is_tpu:
                 get_node_request = tpu_v2.GetNodeRequest(
                     name=f"projects/{self.config.project_id}/locations/{zone}/nodes/{instance_id}",
                 )
                 tpu_node = self.tpu_client.get_node(get_node_request)
-        except google.api_core.exceptions.NotFound:
-            tpu_node = None
 
-        if tpu_node is not None:
                 # Python API to attach a disk to a TPU is not documented,
                 # so we follow the code from the gcloud CLI:
                 # https://github.com/twistedpair/google-cloud-sdk/blob/26ab5a281d56b384cc25750f3279a27afe5b499f/google-cloud-sdk/lib/googlecloudsdk/command_lib/compute/tpus/tpu_vm/util.py#L113
@@ -711,7 +720,6 @@ def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData:
             attached_disk.auto_delete = False
             attached_disk.device_name = f"pd-{volume.volume_id}"
             device_name = attached_disk.device_name
-
             operation = self.instances_client.attach_disk(
                 project=self.config.project_id,
                 zone=zone,
@@ -720,31 +728,33 @@ def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData:
             )
             gcp_resources.wait_for_extended_operation(operation, "persistent disk attachment")
         except google.api_core.exceptions.NotFound:
-            raise ComputeError("Persistent disk or instance not found")
+            raise ComputeError("Disk or instance not found")
         logger.debug(
             "Attached persistent disk for volume %s to instance %s", volume.volume_id, instance_id
         )
         return VolumeAttachmentData(device_name=device_name)
 
-    def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
+    def detach_volume(
+        self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False
+    ):
+        instance_id = provisioning_data.instance_id
         logger.debug(
             "Detaching persistent disk for volume %s from instance %s",
             volume.volume_id,
             instance_id,
         )
         zone = get_or_error(volume.provisioning_data).availability_zone
         attachment_data = get_or_error(volume.get_attachment_data_for_instance(instance_id))
-        # This method has no information if the instance is a TPU or a VM,
-        # so we first try to see if there is a TPU with such name
-        try:
-            get_node_request = tpu_v2.GetNodeRequest(
-                name=f"projects/{self.config.project_id}/locations/{zone}/nodes/{instance_id}",
-            )
-            tpu_node = self.tpu_client.get_node(get_node_request)
-        except google.api_core.exceptions.NotFound:
-            tpu_node = None
+        is_tpu = _is_tpu_provisioning_data(provisioning_data)
+        if is_tpu:
+            try:
+                get_node_request = tpu_v2.GetNodeRequest(
+                    name=f"projects/{self.config.project_id}/locations/{zone}/nodes/{instance_id}",
+                )
+                tpu_node = self.tpu_client.get_node(get_node_request)
+            except google.api_core.exceptions.NotFound:
+                raise ComputeError("Instance not found")
 
-        if tpu_node is not None:
             source_disk = (
                 f"projects/{self.config.project_id}/zones/{zone}/disks/{volume.volume_id}"
             )
@@ -815,6 +825,11 @@ def _filter(offer: InstanceOffer) -> bool:
         if _is_tpu(offer.instance.name) and not _is_single_host_tpu(offer.instance.name):
             return False
         for family in [
+            "m4-",
+            "c4-",
+            "n4-",
+            "h3-",
+            "n2-",
             "e2-medium",
             "e2-standard-",
             "e2-highmem-",
@@ -1001,3 +1016,11 @@ def _get_tpu_data_disk_for_volume(project_id: str, volume: Volume) -> tpu_v2.AttachedDisk:
         mode=tpu_v2.AttachedDisk.DiskMode.READ_WRITE,
     )
     return attached_disk
+
+
+def _is_tpu_provisioning_data(provisioning_data: JobProvisioningData) -> bool:
+    is_tpu = False
+    if provisioning_data.backend_data:
+        backend_data_dict = json.loads(provisioning_data.backend_data)
+        is_tpu = backend_data_dict.get("is_tpu", False)
+    return is_tpu
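
The new `_is_tpu_provisioning_data()` helper replaces the old probe-for-a-TPU-node approach: whether the instance is a TPU is now read from the JSON `backend_data` stored on the provisioning payload. A self-contained sketch of the behavior, using a simplified stand-in for dstack's model:

```python
import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class JobProvisioningData:  # simplified stand-in for dstack's model
    instance_id: str
    backend_data: Optional[str] = None


def _is_tpu_provisioning_data(provisioning_data: JobProvisioningData) -> bool:
    # Missing backend_data or a missing "is_tpu" key both mean "not a TPU".
    is_tpu = False
    if provisioning_data.backend_data:
        backend_data_dict = json.loads(provisioning_data.backend_data)
        is_tpu = backend_data_dict.get("is_tpu", False)
    return is_tpu


assert _is_tpu_provisioning_data(JobProvisioningData("tpu-node", '{"is_tpu": true}'))
assert not _is_tpu_provisioning_data(JobProvisioningData("vm", '{"zone": "us-central1-a"}'))
assert not _is_tpu_provisioning_data(JobProvisioningData("vm", None))
```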

src/dstack/_internal/core/backends/gcp/resources.py

Lines changed: 18 additions & 2 deletions
@@ -140,7 +140,10 @@ def create_instance_struct(
     initialize_params = compute_v1.AttachedDiskInitializeParams()
     initialize_params.source_image = image_id
     initialize_params.disk_size_gb = disk_size
-    initialize_params.disk_type = f"zones/{zone}/diskTypes/pd-balanced"
+    if instance_type_supports_persistent_disk(machine_type):
+        initialize_params.disk_type = f"zones/{zone}/diskTypes/pd-balanced"
+    else:
+        initialize_params.disk_type = f"zones/{zone}/diskTypes/hyperdisk-balanced"
     disk.initialize_params = initialize_params
     instance.disks = [disk]
 
@@ -421,7 +424,7 @@ def wait_for_extended_operation(
 
     if operation.error_code:
         # Write only debug logs here.
-        # The unexpected errors will be propagated and logged appropriatly by the caller.
+        # The unexpected errors will be propagated and logged appropriately by the caller.
         logger.debug(
             "Error during %s: [Code: %s]: %s",
             verbose_name,
@@ -462,3 +465,16 @@ def get_placement_policy_resource_name(
     placement_policy: str,
 ) -> str:
     return f"projects/{project_id}/regions/{region}/resourcePolicies/{placement_policy}"
+
+
+def instance_type_supports_persistent_disk(instance_type_name: str) -> bool:
+    return not any(
+        instance_type_name.startswith(series)
+        for series in [
+            "m4-",
+            "c4-",
+            "n4-",
+            "h3-",
+            "v6e",
+        ]
+    )
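
To illustrate the effect of the new helper together with the `create_instance_struct` change above: series that don't take Persistent Disk get a `hyperdisk-balanced` boot disk instead. A small sketch — the function body mirrors the diff; the machine-type names are just examples:

```python
def instance_type_supports_persistent_disk(instance_type_name: str) -> bool:
    # Newer series (m4, c4, n4, h3) and the v6e TPU don't support pd-* disks.
    return not any(
        instance_type_name.startswith(series)
        for series in ["m4-", "c4-", "n4-", "h3-", "v6e"]
    )


for machine_type in ["n2-standard-4", "c4-standard-8", "v6e-1"]:
    disk_type = (
        "pd-balanced"
        if instance_type_supports_persistent_disk(machine_type)
        else "hyperdisk-balanced"
    )
    print(f"{machine_type}: {disk_type}")
# n2-standard-4: pd-balanced
# c4-standard-8: hyperdisk-balanced
# v6e-1: hyperdisk-balanced
```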

src/dstack/_internal/core/backends/local/compute.py

Lines changed: 4 additions & 2 deletions
@@ -110,8 +110,10 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData:
     def delete_volume(self, volume: Volume):
         pass
 
-    def attach_volume(self, volume: Volume, instance_id: str):
+    def attach_volume(self, volume: Volume, provisioning_data: JobProvisioningData):
         pass
 
-    def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
+    def detach_volume(
+        self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False
+    ):
         pass

src/dstack/_internal/server/background/tasks/process_submitted_jobs.py

Lines changed: 3 additions & 3 deletions
@@ -659,7 +659,7 @@ async def _attach_volumes(
                     backend=backend,
                     volume_model=volume_model,
                     instance=instance,
-                    instance_id=job_provisioning_data.instance_id,
+                    jpd=job_provisioning_data,
                 )
                 job_runtime_data.volume_names.append(volume.name)
                 break  # attach next mount point
@@ -685,7 +685,7 @@ async def _attach_volume(
     backend: Backend,
     volume_model: VolumeModel,
     instance: InstanceModel,
-    instance_id: str,
+    jpd: JobProvisioningData,
 ):
     compute = backend.compute()
     assert isinstance(compute, ComputeWithVolumeSupport)
@@ -697,7 +697,7 @@ async def _attach_volume(
     attachment_data = await common_utils.run_async(
         compute.attach_volume,
         volume=volume,
-        instance_id=instance_id,
+        provisioning_data=jpd,
     )
     volume_attachment_model = VolumeAttachmentModel(
         volume=volume_model,

src/dstack/_internal/server/services/jobs/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -470,20 +470,20 @@ async def _detach_volume_from_job_instance(
         await run_async(
             compute.detach_volume,
             volume=volume,
-            instance_id=jpd.instance_id,
+            provisioning_data=jpd,
             force=False,
         )
         # For some backends, the volume may be detached immediately
         detached = await run_async(
             compute.is_volume_detached,
             volume=volume,
-            instance_id=jpd.instance_id,
+            provisioning_data=jpd,
         )
     else:
         detached = await run_async(
             compute.is_volume_detached,
             volume=volume,
-            instance_id=jpd.instance_id,
+            provisioning_data=jpd,
         )
     if not detached and _should_force_detach_volume(job_model, job_spec.stop_duration):
         logger.info(
@@ -494,7 +494,7 @@ async def _detach_volume_from_job_instance(
         await run_async(
             compute.detach_volume,
             volume=volume,
-            instance_id=jpd.instance_id,
+            provisioning_data=jpd,
             force=True,
         )
         # Let the next iteration check if force detach worked
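
Condensed, the detach flow these call sites implement looks roughly like the sketch below (the `first_attempt` and `deadline_passed` parameter names are illustrative, not dstack's): a graceful detach first, polling across iterations, and a force detach only after the stop-duration deadline passes.

```python
def try_detach_volume(compute, volume, jpd, first_attempt: bool, deadline_passed: bool) -> bool:
    """Returns True once the volume is fully detached from the instance."""
    if first_attempt:
        compute.detach_volume(volume=volume, provisioning_data=jpd, force=False)
    # For some backends, the volume may already be detached at this point.
    detached = compute.is_volume_detached(volume=volume, provisioning_data=jpd)
    if not detached and deadline_passed:
        # Graceful detach didn't finish in time: force it, and let the next
        # iteration verify whether the force detach worked.
        compute.detach_volume(volume=volume, provisioning_data=jpd, force=True)
    return detached
```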

src/tests/_internal/server/background/tasks/test_process_terminating_jobs.py

Lines changed: 1 addition & 1 deletion
@@ -190,7 +190,7 @@ async def test_force_detaches_job_volumes(self, session: AsyncSession):
         m.assert_awaited_once()
         backend_mock.compute.return_value.detach_volume.assert_called_once_with(
             volume=volume_model_to_volume(volume),
-            instance_id=job_provisioning_data.instance_id,
+            provisioning_data=job_provisioning_data,
             force=True,
         )
         backend_mock.compute.return_value.is_volume_detached.assert_called_once()
