Support new GCP CPU series #2685

Merged 3 commits on May 26, 2025
13 changes: 10 additions & 3 deletions src/dstack/_internal/core/backends/aws/compute.py
@@ -611,9 +611,12 @@ def delete_volume(self, volume: Volume):
raise e
logger.debug("Deleted EBS volume %s", volume.configuration.name)

def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData:
def attach_volume(
self, volume: Volume, provisioning_data: JobProvisioningData
) -> VolumeAttachmentData:
ec2_client = self.session.client("ec2", region_name=volume.configuration.region)

instance_id = provisioning_data.instance_id
device_names = aws_resources.list_available_device_names(
ec2_client=ec2_client, instance_id=instance_id
)
@@ -646,9 +649,12 @@ def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentDat
logger.debug("Attached EBS volume %s to instance %s", volume.volume_id, instance_id)
return VolumeAttachmentData(device_name=device_name)

def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
def detach_volume(
self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False
):
ec2_client = self.session.client("ec2", region_name=volume.configuration.region)

instance_id = provisioning_data.instance_id
logger.debug("Detaching EBS volume %s from instance %s", volume.volume_id, instance_id)
attachment_data = get_or_error(volume.get_attachment_data_for_instance(instance_id))
try:
@@ -667,9 +673,10 @@ def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
raise e
logger.debug("Detached EBS volume %s from instance %s", volume.volume_id, instance_id)

def is_volume_detached(self, volume: Volume, instance_id: str) -> bool:
def is_volume_detached(self, volume: Volume, provisioning_data: JobProvisioningData) -> bool:
ec2_client = self.session.client("ec2", region_name=volume.configuration.region)

instance_id = provisioning_data.instance_id
logger.debug("Getting EBS volume %s status", volume.volume_id)
response = ec2_client.describe_volumes(VolumeIds=[volume.volume_id])
volumes_infos = response.get("Volumes")
10 changes: 7 additions & 3 deletions src/dstack/_internal/core/backends/base/compute.py
@@ -336,7 +336,9 @@ def delete_volume(self, volume: Volume):
"""
raise NotImplementedError()

def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData:
def attach_volume(
self, volume: Volume, provisioning_data: JobProvisioningData
) -> VolumeAttachmentData:
"""
Attaches a volume to the instance.
If the volume is not found, it should raise `ComputeError()`.
@@ -345,15 +347,17 @@ def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentDat
"""
raise NotImplementedError()

def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
def detach_volume(
self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False
):
"""
Detaches a volume from the instance.
Implement only if compute may return `VolumeProvisioningData.detachable`.
Otherwise, volumes should be detached on instance termination.
"""
raise NotImplementedError()

def is_volume_detached(self, volume: Volume, instance_id: str) -> bool:
def is_volume_detached(self, volume: Volume, provisioning_data: JobProvisioningData) -> bool:
"""
Checks if a volume was detached from the instance.
If `detach_volume()` may fail to detach volume,
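For orientation, here is a minimal sketch of how a backend implements the updated interface, receiving the whole `JobProvisioningData` instead of a bare instance id. The `ExampleCompute` class and its body are illustrative and not part of this PR; the import paths are assumed from the repository layout.

```python
from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport
from dstack._internal.core.models.runs import JobProvisioningData
from dstack._internal.core.models.volumes import Volume, VolumeAttachmentData


class ExampleCompute(ComputeWithVolumeSupport):
    """Illustrative backend showing only the new volume-method signatures."""

    def attach_volume(
        self, volume: Volume, provisioning_data: JobProvisioningData
    ) -> VolumeAttachmentData:
        # provisioning_data.instance_id replaces the old instance_id argument;
        # instance_type and backend_data are also available here.
        return VolumeAttachmentData(device_name=f"/dev/disk/by-id/{volume.volume_id}")

    def detach_volume(
        self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False
    ):
        pass

    def is_volume_detached(self, volume: Volume, provisioning_data: JobProvisioningData) -> bool:
        return True
```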
63 changes: 43 additions & 20 deletions src/dstack/_internal/core/backends/gcp/compute.py
@@ -649,32 +649,41 @@ def delete_volume(self, volume: Volume):
pass
logger.debug("Deleted persistent disk for volume %s", volume.name)

def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentData:
def attach_volume(
self, volume: Volume, provisioning_data: JobProvisioningData
) -> VolumeAttachmentData:
instance_id = provisioning_data.instance_id
logger.debug(
"Attaching persistent disk for volume %s to instance %s",
volume.volume_id,
instance_id,
)
if not gcp_resources.instance_type_supports_persistent_disk(
provisioning_data.instance_type.name
):
raise ComputeError(
f"Instance type {provisioning_data.instance_type.name} does not support Persistent disk volumes"
)

zone = get_or_error(volume.provisioning_data).availability_zone
is_tpu = _is_tpu_provisioning_data(provisioning_data)
try:
disk = self.disk_client.get(
project=self.config.project_id,
zone=zone,
disk=volume.volume_id,
)
disk_url = disk.self_link
except google.api_core.exceptions.NotFound:
raise ComputeError("Persistent disk found")

# This method has no information if the instance is a TPU or a VM,
# so we first try to see if there is a TPU with such name
try:
try:
if is_tpu:
get_node_request = tpu_v2.GetNodeRequest(
name=f"projects/{self.config.project_id}/locations/{zone}/nodes/{instance_id}",
)
tpu_node = self.tpu_client.get_node(get_node_request)
except google.api_core.exceptions.NotFound:
tpu_node = None

if tpu_node is not None:
# Python API to attach a disk to a TPU is not documented,
# so we follow the code from the gcloud CLI:
# https://github.com/twistedpair/google-cloud-sdk/blob/26ab5a281d56b384cc25750f3279a27afe5b499f/google-cloud-sdk/lib/googlecloudsdk/command_lib/compute/tpus/tpu_vm/util.py#L113
@@ -711,7 +720,6 @@ def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentDat
attached_disk.auto_delete = False
attached_disk.device_name = f"pd-{volume.volume_id}"
device_name = attached_disk.device_name

operation = self.instances_client.attach_disk(
project=self.config.project_id,
zone=zone,
@@ -720,31 +728,33 @@ def attach_volume(self, volume: Volume, instance_id: str) -> VolumeAttachmentDat
)
gcp_resources.wait_for_extended_operation(operation, "persistent disk attachment")
except google.api_core.exceptions.NotFound:
raise ComputeError("Persistent disk or instance not found")
raise ComputeError("Disk or instance not found")
logger.debug(
"Attached persistent disk for volume %s to instance %s", volume.volume_id, instance_id
)
return VolumeAttachmentData(device_name=device_name)

def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
def detach_volume(
self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False
):
instance_id = provisioning_data.instance_id
logger.debug(
"Detaching persistent disk for volume %s from instance %s",
volume.volume_id,
instance_id,
)
zone = get_or_error(volume.provisioning_data).availability_zone
attachment_data = get_or_error(volume.get_attachment_data_for_instance(instance_id))
# This method has no information if the instance is a TPU or a VM,
# so we first try to see if there is a TPU with such name
try:
get_node_request = tpu_v2.GetNodeRequest(
name=f"projects/{self.config.project_id}/locations/{zone}/nodes/{instance_id}",
)
tpu_node = self.tpu_client.get_node(get_node_request)
except google.api_core.exceptions.NotFound:
tpu_node = None
is_tpu = _is_tpu_provisioning_data(provisioning_data)
if is_tpu:
try:
get_node_request = tpu_v2.GetNodeRequest(
name=f"projects/{self.config.project_id}/locations/{zone}/nodes/{instance_id}",
)
tpu_node = self.tpu_client.get_node(get_node_request)
except google.api_core.exceptions.NotFound:
raise ComputeError("Instance not found")

if tpu_node is not None:
source_disk = (
f"projects/{self.config.project_id}/zones/{zone}/disks/{volume.volume_id}"
)
@@ -815,6 +825,11 @@ def _filter(offer: InstanceOffer) -> bool:
if _is_tpu(offer.instance.name) and not _is_single_host_tpu(offer.instance.name):
return False
for family in [
"m4-",
"c4-",
"n4-",
"h3-",
"n2-",
"e2-medium",
"e2-standard-",
"e2-highmem-",
@@ -1001,3 +1016,11 @@ def _get_tpu_data_disk_for_volume(project_id: str, volume: Volume) -> tpu_v2.Att
mode=tpu_v2.AttachedDisk.DiskMode.READ_WRITE,
)
return attached_disk


def _is_tpu_provisioning_data(provisioning_data: JobProvisioningData) -> bool:
is_tpu = False
if provisioning_data.backend_data:
backend_data_dict = json.loads(provisioning_data.backend_data)
is_tpu = backend_data_dict.get("is_tpu", False)
return is_tpu
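
The TPU check above no longer probes the TPU API by instance name; it reads the `is_tpu` flag from the JSON-encoded `backend_data` that the GCP backend stores on `JobProvisioningData` at provisioning time. A small self-contained sketch of that round trip (the `zone` field is purely illustrative):

```python
import json
from typing import Optional


def is_tpu_backend_data(backend_data: Optional[str]) -> bool:
    # Mirrors _is_tpu_provisioning_data: missing or empty backend_data means "not a TPU".
    if not backend_data:
        return False
    return json.loads(backend_data).get("is_tpu", False)


# At provisioning time the backend records the flag roughly like this:
recorded = json.dumps({"is_tpu": True, "zone": "us-central2-b"})

assert is_tpu_backend_data(recorded) is True
assert is_tpu_backend_data(None) is False
```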
20 changes: 18 additions & 2 deletions src/dstack/_internal/core/backends/gcp/resources.py
@@ -140,7 +140,10 @@ def create_instance_struct(
initialize_params = compute_v1.AttachedDiskInitializeParams()
initialize_params.source_image = image_id
initialize_params.disk_size_gb = disk_size
initialize_params.disk_type = f"zones/{zone}/diskTypes/pd-balanced"
if instance_type_supports_persistent_disk(machine_type):
initialize_params.disk_type = f"zones/{zone}/diskTypes/pd-balanced"
else:
initialize_params.disk_type = f"zones/{zone}/diskTypes/hyperdisk-balanced"
disk.initialize_params = initialize_params
instance.disks = [disk]

@@ -421,7 +424,7 @@ def wait_for_extended_operation(

if operation.error_code:
# Write only debug logs here.
# The unexpected errors will be propagated and logged appropriatly by the caller.
# The unexpected errors will be propagated and logged appropriately by the caller.
logger.debug(
"Error during %s: [Code: %s]: %s",
verbose_name,
@@ -462,3 +465,16 @@ def get_placement_policy_resource_name(
placement_policy: str,
) -> str:
return f"projects/{project_id}/regions/{region}/resourcePolicies/{placement_policy}"


def instance_type_supports_persistent_disk(instance_type_name: str) -> bool:
return not any(
instance_type_name.startswith(series)
for series in [
"m4-",
"c4-",
"n4-",
"h3-",
"v6e",
]
)
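
A quick usage sketch of the new helper: the series added to the offer filter above (M4, C4, N4, H3) only support Hyperdisk, so `instance_type_supports_persistent_disk()` both blocks Persistent Disk volume attachment for them and lets `create_instance_struct()` switch the boot disk to `hyperdisk-balanced`:

```python
from dstack._internal.core.backends.gcp.resources import instance_type_supports_persistent_disk

# Ordinary GCP machine type names, used here only for illustration.
for machine_type in ["n2-standard-4", "c4-standard-8", "n4-standard-2", "h3-standard-88"]:
    disk_type = (
        "pd-balanced"
        if instance_type_supports_persistent_disk(machine_type)
        else "hyperdisk-balanced"
    )
    print(machine_type, "->", disk_type)

# n2-standard-4 -> pd-balanced
# c4-standard-8 -> hyperdisk-balanced
# n4-standard-2 -> hyperdisk-balanced
# h3-standard-88 -> hyperdisk-balanced
```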
6 changes: 4 additions & 2 deletions src/dstack/_internal/core/backends/local/compute.py
@@ -110,8 +110,10 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData:
def delete_volume(self, volume: Volume):
pass

def attach_volume(self, volume: Volume, instance_id: str):
def attach_volume(self, volume: Volume, provisioning_data: JobProvisioningData):
pass

def detach_volume(self, volume: Volume, instance_id: str, force: bool = False):
def detach_volume(
self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False
):
pass
@@ -659,7 +659,7 @@ async def _attach_volumes(
backend=backend,
volume_model=volume_model,
instance=instance,
instance_id=job_provisioning_data.instance_id,
jpd=job_provisioning_data,
)
job_runtime_data.volume_names.append(volume.name)
break # attach next mount point
@@ -685,7 +685,7 @@ async def _attach_volume(
backend: Backend,
volume_model: VolumeModel,
instance: InstanceModel,
instance_id: str,
jpd: JobProvisioningData,
):
compute = backend.compute()
assert isinstance(compute, ComputeWithVolumeSupport)
@@ -697,7 +697,7 @@ async def _attach_volume(
attachment_data = await common_utils.run_async(
compute.attach_volume,
volume=volume,
instance_id=instance_id,
provisioning_data=jpd,
)
volume_attachment_model = VolumeAttachmentModel(
volume=volume_model,
8 changes: 4 additions & 4 deletions src/dstack/_internal/server/services/jobs/__init__.py
@@ -470,20 +470,20 @@ async def _detach_volume_from_job_instance(
await run_async(
compute.detach_volume,
volume=volume,
instance_id=jpd.instance_id,
provisioning_data=jpd,
force=False,
)
# For some backends, the volume may be detached immediately
detached = await run_async(
compute.is_volume_detached,
volume=volume,
instance_id=jpd.instance_id,
provisioning_data=jpd,
)
else:
detached = await run_async(
compute.is_volume_detached,
volume=volume,
instance_id=jpd.instance_id,
provisioning_data=jpd,
)
if not detached and _should_force_detach_volume(job_model, job_spec.stop_duration):
logger.info(
@@ -494,7 +494,7 @@ async def _detach_volume_from_job_instance(
await run_async(
compute.detach_volume,
volume=volume,
instance_id=jpd.instance_id,
provisioning_data=jpd,
force=True,
)
# Let the next iteration check if force detach worked
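Condensed, the detach logic above is: try a regular detach, check `is_volume_detached()`, and only force-detach once the job's `stop_duration` has been exceeded, leaving the next processing iteration to verify the result. A simplified synchronous sketch (the helper name and `should_force` flag are illustrative; the real code wraps the compute calls in `run_async` and also tracks whether a detach was already initiated):

```python
def detach_with_fallback(compute, volume, jpd, should_force: bool) -> bool:
    # Regular detach first; some backends detach immediately.
    compute.detach_volume(volume=volume, provisioning_data=jpd, force=False)
    detached = compute.is_volume_detached(volume=volume, provisioning_data=jpd)
    if not detached and should_force:
        # Last resort once stop_duration is exceeded; the next iteration
        # re-checks whether the force detach worked.
        compute.detach_volume(volume=volume, provisioning_data=jpd, force=True)
    return detached
```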
@@ -190,7 +190,7 @@ async def test_force_detaches_job_volumes(self, session: AsyncSession):
m.assert_awaited_once()
backend_mock.compute.return_value.detach_volume.assert_called_once_with(
volume=volume_model_to_volume(volume),
instance_id=job_provisioning_data.instance_id,
provisioning_data=job_provisioning_data,
force=True,
)
backend_mock.compute.return_value.is_volume_detached.assert_called_once()