Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 69 additions & 1 deletion sagemaker-core/src/sagemaker/core/helper/session_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1985,7 +1985,75 @@ def expand_role(self, role):
if "/" in role:
return role
return self.boto_session.resource("iam").Role(role).arn


# ========================================
# Hub Operations
# ========================================

def describe_hub_content(
    self, hub_name, hub_content_name, hub_content_version, hub_content_type, **kwargs
):
    """Fetch the description of a single hub content entry.

    Thin wrapper over the SageMaker client's ``describe_hub_content`` call.

    Args:
        hub_name (str): The name or ARN of the hub.
        hub_content_name (str): The name of the hub content.
        hub_content_version (str): The version of the hub content.
        hub_content_type (str): The type of hub content (Model, ModelReference, Notebook).
        **kwargs: Extra keyword arguments, forwarded unchanged to the client call.

    Returns:
        dict: Response from the DescribeHubContent API.
    """
    describe = self.sagemaker_client.describe_hub_content
    return describe(
        HubName=hub_name,
        HubContentName=hub_content_name,
        HubContentVersion=hub_content_version,
        HubContentType=hub_content_type,
        **kwargs,
    )

def list_hub_content_versions(self, hub_name, hub_content_name, hub_content_type, **kwargs):
    """Return the versions available for one hub content entry.

    Args:
        hub_name (str): The name or ARN of the hub.
        hub_content_name (str): The name of the hub content.
        hub_content_type (str): The type of hub content.
        **kwargs: Only ``next_token`` is read; when it is truthy it is sent
            as ``NextToken`` for pagination. Any other keyword is ignored.

    Returns:
        dict: Response from the ListHubContentVersions API.
    """
    token = kwargs.get("next_token")
    # Only include the pagination key when a token was actually supplied.
    pagination = {"NextToken": token} if token else {}
    return self.sagemaker_client.list_hub_content_versions(
        HubName=hub_name,
        HubContentName=hub_content_name,
        HubContentType=hub_content_type,
        **pagination,
    )

def list_hub_contents(self, hub_name, hub_content_type, **kwargs):
    """Return the contents of a SageMaker Hub for one content type.

    Args:
        hub_name (str): The name or ARN of the hub.
        hub_content_type (str): The type of hub content to list.
        **kwargs: Only ``next_token`` is read; when it is truthy it is sent
            as ``NextToken`` for pagination. Any other keyword is ignored.

    Returns:
        dict: Response from the ListHubContents API.
    """
    token = kwargs.get("next_token")
    # Only include the pagination key when a token was actually supplied.
    pagination = {"NextToken": token} if token else {}
    return self.sagemaker_client.list_hub_contents(
        HubName=hub_name,
        HubContentType=hub_content_type,
        **pagination,
    )


def _expand_container_def(c_def):
"""Placeholder docstring"""
Expand Down
10 changes: 10 additions & 0 deletions sagemaker-serve/src/sagemaker/serve/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3895,6 +3895,16 @@ def from_jumpstart_config(
mb_instance.resource_requirements = resource_requirements
mb_instance.model_kms_key = model_kms_key
mb_instance.hub_name = jumpstart_config.hub_name
if mb_instance.hub_name and not getattr(mb_instance, "hub_arn", None):
from sagemaker.core.jumpstart.hub.utils import (
generate_hub_arn_for_init_kwargs,
)

mb_instance.hub_arn = generate_hub_arn_for_init_kwargs(
hub_name=mb_instance.hub_name,
region=mb_instance.region,
session=mb_instance.sagemaker_session,
)
mb_instance.config_name = jumpstart_config.inference_config_name
mb_instance.accept_eula = jumpstart_config.accept_eula
mb_instance.tolerate_vulnerable_model = tolerate_vulnerable_model
Expand Down
7 changes: 6 additions & 1 deletion sagemaker-serve/src/sagemaker/serve/model_builder_servers.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,15 +975,20 @@ def _build_for_jumpstart(self) -> Model:
self.secret_key = ""

# Get JumpStart model configuration
init_kwargs = get_init_kwargs(
init_kwargs_params = dict(
model_id=self.model,
model_version=self.model_version or "*",
region=self.region,
instance_type=self.instance_type,
sagemaker_session=self.sagemaker_session,
tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None),
tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None),
config_name=getattr(self, "config_name", None),
)
hub_arn = getattr(self, "hub_arn", None)
if hub_arn:
init_kwargs_params["hub_arn"] = hub_arn
init_kwargs = get_init_kwargs(**init_kwargs_params)

# Configure image URI and environment variables
self.image_uri = self.image_uri or init_kwargs.image_uri
Expand Down
7 changes: 6 additions & 1 deletion sagemaker-serve/src/sagemaker/serve/model_builder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,14 +629,19 @@ def _detect_jumpstart_image(self) -> None:
ValueError: If image URI cannot be determined or JumpStart lookup fails.
"""
try:
init_kwargs = get_init_kwargs(
detect_kwargs = dict(
model_id=self.model,
model_version=getattr(self, "model_version", None) or "*",
region=self.region,
instance_type=getattr(self, "instance_type", None),
sagemaker_session=getattr(self, "sagemaker_session", None),
tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None),
tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None),
)
hub_arn = getattr(self, "hub_arn", None)
if hub_arn:
detect_kwargs["hub_arn"] = hub_arn
init_kwargs = get_init_kwargs(**detect_kwargs)

self.image_uri = init_kwargs.get("image_uri")
if not self.image_uri:
Expand Down
206 changes: 206 additions & 0 deletions sagemaker-serve/tests/integ/test_private_hub_artifact_resolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Integration test for private hub artifact resolution.

Verifies that ModelBuilder.from_jumpstart_config with hub_name resolves
model artifacts through the private hub content reference, rather than
falling back to the public JumpStart S3 cache.

This test creates its own private hub, adds a ModelReference, runs the
build flow, and tears everything down afterward. Skips gracefully if
the test environment lacks permissions to create hubs.
"""
from __future__ import absolute_import

import os
import uuid
import time
import logging

import boto3
import pytest
from botocore.exceptions import ClientError

from sagemaker.serve.model_builder import ModelBuilder
from sagemaker.core.jumpstart.configs import JumpStartConfig
from sagemaker.train.configs import Compute

# Module-level logger used by the fixtures and tests below for progress messages.
logger = logging.getLogger(__name__)

# Use a small/fast JumpStart model for the test
TEST_MODEL_ID = "huggingface-llm-gemma-2b"
# Version string passed as model_version to JumpStartConfig in the tests.
TEST_MODEL_VERSION = "*"
# Instance type handed to Compute(...) when constructing the ModelBuilder.
TEST_INSTANCE_TYPE = "ml.g5.xlarge"
# Region for all boto3 clients/sessions; overridable via the TEST_REGION
# environment variable, defaulting to us-east-1.
TEST_REGION = os.environ.get("TEST_REGION", "us-east-1")

# Prefix of the generated hub name; the private_hub fixture appends a random
# 8-character hex suffix.
HUB_NAME_PREFIX = "sdk-integ-test-private-hub"


def _delete_private_hub(sm, hub_name):
    """Best-effort removal of a hub's ModelReference entries and the hub itself.

    Failures are logged as warnings and never raised, so cleanup cannot mask
    the outcome of the test (or skip) that triggered it.
    """
    logger.info("Cleaning up hub: %s", hub_name)
    try:
        contents = sm.list_hub_contents(
            HubName=hub_name,
            HubContentType="ModelReference",
        )
        for content in contents.get("HubContentSummaries", []):
            try:
                sm.delete_hub_content_reference(
                    HubName=hub_name,
                    HubContentName=content["HubContentName"],
                    HubContentType="ModelReference",
                )
            except Exception as e:
                logger.warning("Failed to delete content ref: %s", e)

        sm.delete_hub(HubName=hub_name)
        logger.info("Hub %s deleted", hub_name)
    except Exception as e:
        logger.warning("Hub cleanup failed: %s", e)


@pytest.fixture(scope="module")
def private_hub():
    """Create a private hub with a ModelReference, tear down after tests.

    Yields:
        str: The name of the newly created hub.

    The fixture skips (rather than fails) when the hub cannot be created, does
    not reach ``InService``, or the ModelReference does not become
    ``Available``. Once ``create_hub`` has succeeded, cleanup always runs,
    including on those skip paths, so a partially set-up hub is not leaked.
    """
    sm = boto3.client("sagemaker", region_name=TEST_REGION)
    hub_name = f"{HUB_NAME_PREFIX}-{uuid.uuid4().hex[:8]}"

    # --- Setup ---
    logger.info("Creating private hub: %s", hub_name)
    try:
        sm.create_hub(
            HubName=hub_name,
            HubDescription="SDK integration test for private hub artifact resolution",
            Tags=[
                {"Key": "Purpose", "Value": "sdk-integ-test"},
                {"Key": "AutoCleanup", "Value": "true"},
            ],
        )
    except ClientError as e:
        # Nothing was created, so there is nothing to clean up.
        pytest.skip(f"Cannot create hub (likely missing permissions): {e}")

    # From here on the hub exists: every exit path (skip, error, normal
    # teardown) must go through the finally block below.
    try:
        # Wait for hub to be ready
        for _ in range(30):
            resp = sm.describe_hub(HubName=hub_name)
            if resp["HubStatus"] == "InService":
                break
            time.sleep(2)
        else:
            pytest.skip(f"Hub {hub_name} did not reach InService state")

        # Add a ModelReference to the public JumpStart content
        public_arn = (
            f"arn:aws:sagemaker:{TEST_REGION}:aws:hub-content/"
            f"SageMakerPublicHub/Model/{TEST_MODEL_ID}"
        )
        logger.info("Creating hub content reference to: %s", public_arn)
        try:
            sm.create_hub_content_reference(
                HubName=hub_name,
                SageMakerPublicHubContentArn=public_arn,
            )
        except ClientError as e:
            pytest.skip(f"Cannot create hub content reference: {e}")

        # Wait for content reference to be available
        for _ in range(60):
            try:
                contents = sm.list_hub_contents(
                    HubName=hub_name,
                    HubContentType="ModelReference",
                )
                summaries = contents.get("HubContentSummaries", [])
                if any(
                    s["HubContentName"] == TEST_MODEL_ID
                    and s.get("HubContentStatus") == "Available"
                    for s in summaries
                ):
                    break
            except Exception as e:
                # Keep polling on transient errors, but leave a trace of them.
                logger.debug("Polling hub contents failed, will retry: %s", e)
            time.sleep(3)
        else:
            pytest.skip(
                f"ModelReference for {TEST_MODEL_ID} not available in "
                f"hub {hub_name} after 3 minutes"
            )

        yield hub_name
    finally:
        # --- Teardown ---
        _delete_private_hub(sm, hub_name)


@pytest.fixture(scope="module")
def sagemaker_session():
    """Provide a SageMaker ``Session`` bound to a boto3 session in TEST_REGION."""
    from sagemaker.core.helper.session_helper import Session

    regional_boto_session = boto3.Session(region_name=TEST_REGION)
    return Session(boto_session=regional_boto_session)


@pytest.fixture(scope="module")
def execution_role():
    """Build an execution role ARN from the caller's AWS account id.

    The account id comes from STS ``get_caller_identity``; the role name is
    hard-coded to ``Admin``.
    NOTE(review): assumes a role named ``Admin`` exists in the account - confirm.
    """
    sts_client = boto3.client("sts", region_name=TEST_REGION)
    account_id = sts_client.get_caller_identity()["Account"]
    return f"arn:aws:iam::{account_id}:role/Admin"


@pytest.mark.slow_test
def test_from_jumpstart_config_derives_hub_arn(private_hub, sagemaker_session):
    """Check that from_jumpstart_config sets a hub_arn containing the hub name."""
    config = JumpStartConfig(
        model_id=TEST_MODEL_ID,
        model_version=TEST_MODEL_VERSION,
        hub_name=private_hub,
    )
    compute = Compute(instance_type=TEST_INSTANCE_TYPE)

    builder = ModelBuilder.from_jumpstart_config(
        jumpstart_config=config,
        compute=compute,
        sagemaker_session=sagemaker_session,
    )

    derived_arn = builder.hub_arn
    assert (
        derived_arn is not None
    ), f"hub_arn is None after from_jumpstart_config with hub_name={private_hub}"
    # The derived ARN must embed the name of the hub it was built from.
    assert private_hub in derived_arn
    logger.info("hub_arn correctly derived: %s", derived_arn)


@pytest.mark.slow_test
def test_build_resolves_artifacts_via_private_hub(private_hub, execution_role, sagemaker_session):
    """Check that build() does not leave model data pointing at the public cache."""
    config = JumpStartConfig(
        model_id=TEST_MODEL_ID,
        model_version=TEST_MODEL_VERSION,
        hub_name=private_hub,
        accept_eula=True,
    )
    compute = Compute(instance_type=TEST_INSTANCE_TYPE)

    builder = ModelBuilder.from_jumpstart_config(
        jumpstart_config=config,
        role_arn=execution_role,
        compute=compute,
        sagemaker_session=sagemaker_session,
    )
    builder.build()

    # After build, model data URI must NOT reference the public cache
    model_data = getattr(builder, "s3_model_data_url", None)
    assert model_data is not None, "No model_data found after build()"

    model_data_str = str(model_data)
    points_to_public_cache = "jumpstart-cache-prod" in model_data_str
    assert not points_to_public_cache, (
        f"Model data still points to public JumpStart cache: "
        f"{model_data_str}. Expected private hub artifact resolution."
    )
    logger.info("Model data resolved to: %s", model_data_str)
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,7 @@ def test_build_passes_config_name_to_get_init_kwargs(
model_version="*",
region=self.builder.region,
instance_type=self.builder.instance_type,
sagemaker_session=self.builder.sagemaker_session,
tolerate_vulnerable_model=None,
tolerate_deprecated_model=None,
config_name="lmi-optimized",
Expand Down Expand Up @@ -979,6 +980,7 @@ def test_build_passes_none_config_name_when_not_set(
model_version="*",
region=self.builder.region,
instance_type=self.builder.instance_type,
sagemaker_session=self.builder.sagemaker_session,
tolerate_vulnerable_model=None,
tolerate_deprecated_model=None,
config_name=None,
Expand Down
Loading
Loading