From ae1f6cfb9dedd34c90154dda37ccd165db407d88 Mon Sep 17 00:00:00 2001
From: Liam Neal Reilly
Date: Wed, 1 Jul 2026 22:41:20 +0000
Subject: [PATCH] fix: ModelBuilder resolves private hub artifacts correctly

Fix three defects that cause ModelBuilder to ignore the private hub when
resolving model artifacts, forcing the execution role to access the
public JumpStart S3 cache bucket.

Defect 1: from_jumpstart_config sets hub_name AFTER __init__ has already
called _initialize_jumpstart_config(), which takes the else branch and
sets hub_arn = None.
Fix: derive hub_arn inline after setting hub_name.

Defect 2: _build_for_jumpstart does not forward hub_arn or
sagemaker_session to get_init_kwargs, causing model data to resolve from
the public catalog.
Fix: pass hub_arn (when set) and sagemaker_session to all get_init_kwargs
call sites in the build path.

Defect 3: The v3 Session class
(sagemaker.core.helper.session_helper.Session) is missing hub API methods
(describe_hub_content, list_hub_content_versions, list_hub_contents) that
the JumpStart cache calls during hub content resolution.
Fix: add these methods as thin wrappers around the boto3 sagemaker_client
calls.

Impact: Customers deploying from private hubs via ModelBuilder no longer
need to grant their execution role s3:GetObject on the public JumpStart
cache bucket.

Testing: - 8 new unit tests covering hub_arn derivation, forwarding, and e2e flow - 2 existing unit tests updated to match new call signatures --- .../sagemaker/core/helper/session_helper.py | 70 +++- .../src/sagemaker/serve/model_builder.py | 10 + .../sagemaker/serve/model_builder_servers.py | 7 +- .../sagemaker/serve/model_builder_utils.py | 7 +- .../test_private_hub_artifact_resolution.py | 206 ++++++++++ .../servers/test_model_builder_servers.py | 2 + .../test_private_hub_artifact_resolution.py | 371 ++++++++++++++++++ 7 files changed, 670 insertions(+), 3 deletions(-) create mode 100644 sagemaker-serve/tests/integ/test_private_hub_artifact_resolution.py create mode 100644 sagemaker-serve/tests/unit/test_private_hub_artifact_resolution.py diff --git a/sagemaker-core/src/sagemaker/core/helper/session_helper.py b/sagemaker-core/src/sagemaker/core/helper/session_helper.py index 4d33c9c064..ecdd4b95eb 100644 --- a/sagemaker-core/src/sagemaker/core/helper/session_helper.py +++ b/sagemaker-core/src/sagemaker/core/helper/session_helper.py @@ -1985,7 +1985,75 @@ def expand_role(self, role): if "/" in role: return role return self.boto_session.resource("iam").Role(role).arn - + + # ======================================== + # Hub Operations + # ======================================== + + def describe_hub_content( + self, hub_name, hub_content_name, hub_content_version, hub_content_type, **kwargs + ): + """Describe hub content in a SageMaker Hub. + + Args: + hub_name (str): The name or ARN of the hub. + hub_content_name (str): The name of the hub content. + hub_content_version (str): The version of the hub content. + hub_content_type (str): The type of hub content (Model, ModelReference, Notebook). + + Returns: + dict: Response from the DescribeHubContent API. 
+ """ + return self.sagemaker_client.describe_hub_content( + HubName=hub_name, + HubContentName=hub_content_name, + HubContentVersion=hub_content_version, + HubContentType=hub_content_type, + **kwargs, + ) + + def list_hub_content_versions(self, hub_name, hub_content_name, hub_content_type, **kwargs): + """List versions of hub content in a SageMaker Hub. + + Args: + hub_name (str): The name or ARN of the hub. + hub_content_name (str): The name of the hub content. + hub_content_type (str): The type of hub content. + **kwargs: Additional arguments (e.g., next_token for pagination). + + Returns: + dict: Response from the ListHubContentVersions API. + """ + request = { + "HubName": hub_name, + "HubContentName": hub_content_name, + "HubContentType": hub_content_type, + } + next_token = kwargs.get("next_token") + if next_token: + request["NextToken"] = next_token + return self.sagemaker_client.list_hub_content_versions(**request) + + def list_hub_contents(self, hub_name, hub_content_type, **kwargs): + """List hub contents in a SageMaker Hub. + + Args: + hub_name (str): The name or ARN of the hub. + hub_content_type (str): The type of hub content to list. + **kwargs: Additional arguments (e.g., next_token for pagination). + + Returns: + dict: Response from the ListHubContents API. 
+ """ + request = { + "HubName": hub_name, + "HubContentType": hub_content_type, + } + next_token = kwargs.get("next_token") + if next_token: + request["NextToken"] = next_token + return self.sagemaker_client.list_hub_contents(**request) + def _expand_container_def(c_def): """Placeholder docstring""" diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index b7dc98b768..b7be863d9b 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -3895,6 +3895,16 @@ def from_jumpstart_config( mb_instance.resource_requirements = resource_requirements mb_instance.model_kms_key = model_kms_key mb_instance.hub_name = jumpstart_config.hub_name + if mb_instance.hub_name and not getattr(mb_instance, "hub_arn", None): + from sagemaker.core.jumpstart.hub.utils import ( + generate_hub_arn_for_init_kwargs, + ) + + mb_instance.hub_arn = generate_hub_arn_for_init_kwargs( + hub_name=mb_instance.hub_name, + region=mb_instance.region, + session=mb_instance.sagemaker_session, + ) mb_instance.config_name = jumpstart_config.inference_config_name mb_instance.accept_eula = jumpstart_config.accept_eula mb_instance.tolerate_vulnerable_model = tolerate_vulnerable_model diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py index ecbc270540..3ca6b40f6d 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py @@ -975,15 +975,20 @@ def _build_for_jumpstart(self) -> Model: self.secret_key = "" # Get JumpStart model configuration - init_kwargs = get_init_kwargs( + init_kwargs_params = dict( model_id=self.model, model_version=self.model_version or "*", region=self.region, instance_type=self.instance_type, + sagemaker_session=self.sagemaker_session, tolerate_vulnerable_model=getattr(self, 
"tolerate_vulnerable_model", None), tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None), config_name=getattr(self, "config_name", None), ) + hub_arn = getattr(self, "hub_arn", None) + if hub_arn: + init_kwargs_params["hub_arn"] = hub_arn + init_kwargs = get_init_kwargs(**init_kwargs_params) # Configure image URI and environment variables self.image_uri = self.image_uri or init_kwargs.image_uri diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py index 2ec9ef6475..e58ea4d7ad 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_utils.py @@ -629,14 +629,19 @@ def _detect_jumpstart_image(self) -> None: ValueError: If image URI cannot be determined or JumpStart lookup fails. """ try: - init_kwargs = get_init_kwargs( + detect_kwargs = dict( model_id=self.model, model_version=getattr(self, "model_version", None) or "*", region=self.region, instance_type=getattr(self, "instance_type", None), + sagemaker_session=getattr(self, "sagemaker_session", None), tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None), tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None), ) + hub_arn = getattr(self, "hub_arn", None) + if hub_arn: + detect_kwargs["hub_arn"] = hub_arn + init_kwargs = get_init_kwargs(**detect_kwargs) self.image_uri = init_kwargs.get("image_uri") if not self.image_uri: diff --git a/sagemaker-serve/tests/integ/test_private_hub_artifact_resolution.py b/sagemaker-serve/tests/integ/test_private_hub_artifact_resolution.py new file mode 100644 index 0000000000..4292b6b560 --- /dev/null +++ b/sagemaker-serve/tests/integ/test_private_hub_artifact_resolution.py @@ -0,0 +1,206 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). 
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Integration test for private hub artifact resolution. + +Verifies that ModelBuilder.from_jumpstart_config with hub_name resolves +model artifacts through the private hub content reference, rather than +falling back to the public JumpStart S3 cache. + +This test creates its own private hub, adds a ModelReference, runs the +build flow, and tears everything down afterward. Skips gracefully if +the test environment lacks permissions to create hubs. +""" +from __future__ import absolute_import + +import os +import uuid +import time +import logging + +import boto3 +import pytest +from botocore.exceptions import ClientError + +from sagemaker.serve.model_builder import ModelBuilder +from sagemaker.core.jumpstart.configs import JumpStartConfig +from sagemaker.train.configs import Compute + +logger = logging.getLogger(__name__) + +# Use a small/fast JumpStart model for the test +TEST_MODEL_ID = "huggingface-llm-gemma-2b" +TEST_MODEL_VERSION = "*" +TEST_INSTANCE_TYPE = "ml.g5.xlarge" +TEST_REGION = os.environ.get("TEST_REGION", "us-east-1") + +HUB_NAME_PREFIX = "sdk-integ-test-private-hub" + + +@pytest.fixture(scope="module") +def private_hub(): + """Create a private hub with a ModelReference, tear down after tests.""" + sm = boto3.client("sagemaker", region_name=TEST_REGION) + hub_name = f"{HUB_NAME_PREFIX}-{uuid.uuid4().hex[:8]}" + + # --- Setup --- + logger.info("Creating private hub: %s", hub_name) + try: + sm.create_hub( + HubName=hub_name, + HubDescription="SDK integration test for private hub artifact resolution", + Tags=[ + {"Key": "Purpose", 
"Value": "sdk-integ-test"}, + {"Key": "AutoCleanup", "Value": "true"}, + ], + ) + except ClientError as e: + pytest.skip(f"Cannot create hub (likely missing permissions): {e}") + + # Wait for hub to be ready + for _ in range(30): + resp = sm.describe_hub(HubName=hub_name) + if resp["HubStatus"] == "InService": + break + time.sleep(2) + else: + pytest.skip(f"Hub {hub_name} did not reach InService state") + + # Add a ModelReference to the public JumpStart content + public_arn = ( + f"arn:aws:sagemaker:{TEST_REGION}:aws:hub-content/" + f"SageMakerPublicHub/Model/{TEST_MODEL_ID}" + ) + logger.info("Creating hub content reference to: %s", public_arn) + try: + sm.create_hub_content_reference( + HubName=hub_name, + SageMakerPublicHubContentArn=public_arn, + ) + except ClientError as e: + pytest.skip(f"Cannot create hub content reference: {e}") + + # Wait for content reference to be available + for _ in range(60): + try: + contents = sm.list_hub_contents( + HubName=hub_name, + HubContentType="ModelReference", + ) + summaries = contents.get("HubContentSummaries", []) + if any( + s["HubContentName"] == TEST_MODEL_ID and s.get("HubContentStatus") == "Available" + for s in summaries + ): + break + except Exception: + pass + time.sleep(3) + else: + pytest.skip( + f"ModelReference for {TEST_MODEL_ID} not available in " + f"hub {hub_name} after 3 minutes" + ) + + yield hub_name + + # --- Teardown --- + logger.info("Cleaning up hub: %s", hub_name) + try: + contents = sm.list_hub_contents( + HubName=hub_name, + HubContentType="ModelReference", + ) + for content in contents.get("HubContentSummaries", []): + try: + sm.delete_hub_content_reference( + HubName=hub_name, + HubContentName=content["HubContentName"], + HubContentType="ModelReference", + ) + except Exception as e: + logger.warning("Failed to delete content ref: %s", e) + + sm.delete_hub(HubName=hub_name) + logger.info("Hub %s deleted", hub_name) + except Exception as e: + logger.warning("Hub cleanup failed: %s", e) + + 
+@pytest.fixture(scope="module") +def sagemaker_session(): + """Create a SageMaker session using default credentials.""" + from sagemaker.core.helper.session_helper import Session + + return Session(boto_session=boto3.Session(region_name=TEST_REGION)) + + +@pytest.fixture(scope="module") +def execution_role(): + """Get a SageMaker execution role from the caller's identity.""" + sts = boto3.client("sts", region_name=TEST_REGION) + identity = sts.get_caller_identity() + account_id = identity["Account"] + return f"arn:aws:iam::{account_id}:role/Admin" + + +@pytest.mark.slow_test +def test_from_jumpstart_config_derives_hub_arn(private_hub, sagemaker_session): + """Verify from_jumpstart_config correctly derives hub_arn from hub_name.""" + js_config = JumpStartConfig( + model_id=TEST_MODEL_ID, + model_version=TEST_MODEL_VERSION, + hub_name=private_hub, + ) + + mb = ModelBuilder.from_jumpstart_config( + jumpstart_config=js_config, + compute=Compute(instance_type=TEST_INSTANCE_TYPE), + sagemaker_session=sagemaker_session, + ) + + assert ( + mb.hub_arn is not None + ), f"hub_arn is None after from_jumpstart_config with hub_name={private_hub}" + assert private_hub in mb.hub_arn + logger.info("hub_arn correctly derived: %s", mb.hub_arn) + + +@pytest.mark.slow_test +def test_build_resolves_artifacts_via_private_hub(private_hub, execution_role, sagemaker_session): + """Verify build() resolves model data through the private hub.""" + js_config = JumpStartConfig( + model_id=TEST_MODEL_ID, + model_version=TEST_MODEL_VERSION, + hub_name=private_hub, + accept_eula=True, + ) + + mb = ModelBuilder.from_jumpstart_config( + jumpstart_config=js_config, + role_arn=execution_role, + compute=Compute(instance_type=TEST_INSTANCE_TYPE), + sagemaker_session=sagemaker_session, + ) + + mb.build() + + # After build, model data URI must NOT reference the public cache + model_data = getattr(mb, "s3_model_data_url", None) + assert model_data is not None, "No model_data found after build()" + + 
model_data_str = str(model_data) + assert "jumpstart-cache-prod" not in model_data_str, ( + f"Model data still points to public JumpStart cache: " + f"{model_data_str}. Expected private hub artifact resolution." + ) + logger.info("Model data resolved to: %s", model_data_str) diff --git a/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py b/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py index 053b14f416..4fc1a558ba 100644 --- a/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py +++ b/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py @@ -948,6 +948,7 @@ def test_build_passes_config_name_to_get_init_kwargs( model_version="*", region=self.builder.region, instance_type=self.builder.instance_type, + sagemaker_session=self.builder.sagemaker_session, tolerate_vulnerable_model=None, tolerate_deprecated_model=None, config_name="lmi-optimized", @@ -979,6 +980,7 @@ def test_build_passes_none_config_name_when_not_set( model_version="*", region=self.builder.region, instance_type=self.builder.instance_type, + sagemaker_session=self.builder.sagemaker_session, tolerate_vulnerable_model=None, tolerate_deprecated_model=None, config_name=None, diff --git a/sagemaker-serve/tests/unit/test_private_hub_artifact_resolution.py b/sagemaker-serve/tests/unit/test_private_hub_artifact_resolution.py new file mode 100644 index 0000000000..4fdee7e6ef --- /dev/null +++ b/sagemaker-serve/tests/unit/test_private_hub_artifact_resolution.py @@ -0,0 +1,371 @@ +""" +Unit tests for private hub artifact resolution fix. + +Tests two defects: +1. from_jumpstart_config sets hub_name after __init__ already ran + _initialize_jumpstart_config(), leaving hub_arn as None. +2. _build_for_jumpstart does not forward hub_arn to get_init_kwargs, + so model data resolves from the public catalog instead of the private hub. 
+""" + +import unittest +from unittest.mock import Mock, patch + +from sagemaker.serve.model_builder import ModelBuilder +from sagemaker.serve.mode.function_pointers import Mode +from sagemaker.core.training.configs import Compute +from sagemaker.core.jumpstart.configs import JumpStartConfig + + +MOCK_ROLE_ARN = "arn:aws:iam::123456789012:role/SageMakerRole" +MOCK_HUB_NAME = "my-private-hub" +MOCK_HUB_ARN = "arn:aws:sagemaker:us-east-1:123456789012:hub/my-private-hub" +MOCK_MODEL_ID = "huggingface-llm-phi-4-mini-instruct" +MOCK_MODEL_VERSION = "1.1.0" + + +def _mock_session(): + """Create a mock session that won't trigger real AWS calls.""" + session = Mock() + session.boto_region_name = "us-east-1" + session.sagemaker_config = None + session.boto_session = Mock() + session.boto_session.region_name = "us-east-1" + return session + + +# Common patch to prevent __init__ from making real S3/API calls during +# instance type auto-detection and model ID validation. +_PATCH_IS_JS = patch.object(ModelBuilder, "_is_jumpstart_model_id", return_value=False) + + +class TestFromJumpStartConfigHubArnDerivation(unittest.TestCase): + """Test that from_jumpstart_config correctly derives hub_arn from hub_name.""" + + @_PATCH_IS_JS + @patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs", return_value={}) + @patch("sagemaker.core.jumpstart.utils.validate_model_id_and_get_type", return_value=None) + @patch( + "sagemaker.core.jumpstart.hub.utils.generate_hub_arn_for_init_kwargs", + return_value=MOCK_HUB_ARN, + ) + def test_hub_arn_derived_when_hub_name_set( + self, mock_generate_arn, mock_validate, mock_deploy_kwargs, mock_is_js + ): + """hub_arn must be derived after hub_name is assigned in from_jumpstart_config.""" + js_config = JumpStartConfig( + model_id=MOCK_MODEL_ID, + model_version=MOCK_MODEL_VERSION, + hub_name=MOCK_HUB_NAME, + ) + + mb = ModelBuilder.from_jumpstart_config( + jumpstart_config=js_config, + role_arn=MOCK_ROLE_ARN, + 
sagemaker_session=_mock_session(), + ) + + # The key assertion: hub_arn is populated, proving from_jumpstart_config + # derives it inline right after hub_name is assigned + self.assertEqual(mb.hub_name, MOCK_HUB_NAME) + self.assertEqual(mb.hub_arn, MOCK_HUB_ARN) + # generate_hub_arn_for_init_kwargs must have been called (arguments are not checked) + mock_generate_arn.assert_called() + + @_PATCH_IS_JS + @patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs", return_value={}) + @patch( + "sagemaker.core.jumpstart.hub.utils.generate_hub_arn_for_init_kwargs", + return_value=MOCK_HUB_ARN, + ) + @patch("sagemaker.core.jumpstart.utils.validate_model_id_and_get_type", return_value=None) + def test_hub_arn_populated_end_to_end( + self, mock_validate, mock_generate_arn, mock_deploy_kwargs, mock_is_js + ): + """End-to-end: hub_arn is correctly populated when hub_name is specified.""" + js_config = JumpStartConfig( + model_id=MOCK_MODEL_ID, + model_version=MOCK_MODEL_VERSION, + hub_name=MOCK_HUB_NAME, + ) + + mb = ModelBuilder.from_jumpstart_config( + jumpstart_config=js_config, + role_arn=MOCK_ROLE_ARN, + sagemaker_session=_mock_session(), + ) + + self.assertEqual(mb.hub_name, MOCK_HUB_NAME) + self.assertEqual(mb.hub_arn, MOCK_HUB_ARN) + mock_generate_arn.assert_called() + + @_PATCH_IS_JS + @patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs", return_value={}) + @patch("sagemaker.core.jumpstart.utils.validate_model_id_and_get_type", return_value=None) + def test_hub_arn_is_none_when_no_hub_name(self, mock_validate, mock_deploy_kwargs, mock_is_js): + """hub_arn should remain None when hub_name is not provided.""" + js_config = JumpStartConfig( + model_id=MOCK_MODEL_ID, + model_version=MOCK_MODEL_VERSION, + ) + + mb = ModelBuilder.from_jumpstart_config( + jumpstart_config=js_config, + role_arn=MOCK_ROLE_ARN, + sagemaker_session=_mock_session(), + ) + + self.assertIsNone(mb.hub_name) + self.assertIsNone(mb.hub_arn) + + +class 
TestBuildForJumpStartForwardsHubArn(unittest.TestCase): + """Test that _build_for_jumpstart forwards hub_arn to get_init_kwargs.""" + + def setUp(self): + self.mock_session = _mock_session() + + @_PATCH_IS_JS + @patch("sagemaker.core.jumpstart.utils.validate_model_id_and_get_type", return_value=None) + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + def test_hub_arn_forwarded_to_get_init_kwargs( + self, mock_prepare, mock_create, mock_get_kwargs, mock_validate, mock_is_js + ): + """get_init_kwargs must receive hub_arn so model data resolves via private hub.""" + mock_init_kwargs = Mock() + mock_init_kwargs.image_uri = ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.27.0-lmi10.0.0-cu124" + ) + mock_init_kwargs.env = {} + mock_init_kwargs.model_data = { + "S3DataSource": { + "S3Uri": "s3://my-private-hub-bucket/artifacts/model.tar.gz", + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + mock_init_kwargs.enable_network_isolation = None + mock_get_kwargs.return_value = mock_init_kwargs + + mock_model = Mock() + mock_create.return_value = mock_model + + builder = ModelBuilder( + model=MOCK_MODEL_ID, + role_arn=MOCK_ROLE_ARN, + sagemaker_session=self.mock_session, + mode=Mode.SAGEMAKER_ENDPOINT, + ) + builder._optimizing = False + builder.hub_name = MOCK_HUB_NAME + builder.hub_arn = MOCK_HUB_ARN + builder.model_version = MOCK_MODEL_VERSION + + builder._build_for_jumpstart() + + # Verify hub_arn was passed to get_init_kwargs + mock_get_kwargs.assert_called_once() + call_kwargs = mock_get_kwargs.call_args + actual_hub_arn = call_kwargs.kwargs.get("hub_arn") or call_kwargs[1].get("hub_arn") + self.assertEqual(actual_hub_arn, MOCK_HUB_ARN) + + @_PATCH_IS_JS + @patch("sagemaker.core.jumpstart.utils.validate_model_id_and_get_type", return_value=None) + 
@patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + def test_hub_arn_none_when_no_private_hub( + self, mock_prepare, mock_create, mock_get_kwargs, mock_validate, mock_is_js + ): + """When no private hub is configured, hub_arn should be None (public catalog).""" + mock_init_kwargs = Mock() + mock_init_kwargs.image_uri = ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.27.0-lmi10.0.0-cu124" + ) + mock_init_kwargs.env = {} + mock_init_kwargs.model_data = "s3://jumpstart-cache-prod-us-east-1/models/model.tar.gz" + mock_init_kwargs.enable_network_isolation = None + mock_get_kwargs.return_value = mock_init_kwargs + + mock_model = Mock() + mock_create.return_value = mock_model + + builder = ModelBuilder( + model=MOCK_MODEL_ID, + role_arn=MOCK_ROLE_ARN, + sagemaker_session=self.mock_session, + mode=Mode.SAGEMAKER_ENDPOINT, + ) + builder._optimizing = False + builder.model_version = MOCK_MODEL_VERSION + + builder._build_for_jumpstart() + + # Verify hub_arn is NOT passed when no private hub (public catalog) + mock_get_kwargs.assert_called_once() + call_kwargs = mock_get_kwargs.call_args + actual_hub_arn = call_kwargs.kwargs.get("hub_arn") + self.assertIsNone(actual_hub_arn) + + @_PATCH_IS_JS + @patch("sagemaker.core.jumpstart.utils.validate_model_id_and_get_type", return_value=None) + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + def test_private_hub_resolves_non_public_model_data( + self, mock_prepare, mock_create, mock_get_kwargs, mock_validate, mock_is_js + ): + """With hub_arn set, model_data should resolve to private hub bucket, not public cache.""" + private_s3_uri = "s3://my-private-hub-bucket/hub-content/artifacts/model.tar.gz" + 
mock_init_kwargs = Mock() + mock_init_kwargs.image_uri = ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.27.0-lmi10.0.0-cu124" + ) + mock_init_kwargs.env = {} + mock_init_kwargs.model_data = { + "S3DataSource": { + "S3Uri": private_s3_uri, + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + mock_init_kwargs.enable_network_isolation = None + mock_get_kwargs.return_value = mock_init_kwargs + + mock_model = Mock() + mock_create.return_value = mock_model + + builder = ModelBuilder( + model=MOCK_MODEL_ID, + role_arn=MOCK_ROLE_ARN, + sagemaker_session=self.mock_session, + mode=Mode.SAGEMAKER_ENDPOINT, + ) + builder._optimizing = False + builder.hub_name = MOCK_HUB_NAME + builder.hub_arn = MOCK_HUB_ARN + builder.model_version = MOCK_MODEL_VERSION + + builder._build_for_jumpstart() + + # Confirm model data does NOT point to public JumpStart cache + self.assertNotIn("jumpstart-cache-prod", builder.s3_model_data_url) + self.assertEqual(builder.s3_model_data_url, private_s3_uri) + + +class TestDetectJumpStartImageForwardsHubArn(unittest.TestCase): + """Test that _detect_jumpstart_image forwards hub_arn to get_init_kwargs.""" + + def setUp(self): + self.mock_session = _mock_session() + + @_PATCH_IS_JS + @patch("sagemaker.core.jumpstart.utils.validate_model_id_and_get_type", return_value=None) + @patch("sagemaker.serve.model_builder_utils.get_init_kwargs") + def test_hub_arn_forwarded_in_detect_jumpstart_image( + self, mock_get_kwargs, mock_validate, mock_is_js + ): + """_detect_jumpstart_image must pass hub_arn so private hub images resolve correctly.""" + mock_init_kwargs = Mock() + mock_init_kwargs.image_uri = ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.27.0-lmi10.0.0-cu124" + ) + mock_init_kwargs.get = lambda k: mock_init_kwargs.image_uri if k == "image_uri" else None + mock_get_kwargs.return_value = mock_init_kwargs + + builder = ModelBuilder( + model=MOCK_MODEL_ID, + role_arn=MOCK_ROLE_ARN, + 
sagemaker_session=self.mock_session, + ) + builder.hub_arn = MOCK_HUB_ARN + builder.model_version = MOCK_MODEL_VERSION + + builder._detect_jumpstart_image() + + mock_get_kwargs.assert_called_once() + call_kwargs = mock_get_kwargs.call_args + actual_hub_arn = call_kwargs.kwargs.get("hub_arn") or call_kwargs[1].get("hub_arn") + self.assertEqual(actual_hub_arn, MOCK_HUB_ARN) + + +class TestEndToEndPrivateHubFlow(unittest.TestCase): + """Integration-style test: from_jumpstart_config with hub_name -> _build_for_jumpstart.""" + + @_PATCH_IS_JS + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + @patch("sagemaker.serve.model_builder._retrieve_model_deploy_kwargs", return_value={}) + @patch("sagemaker.core.jumpstart.utils.validate_model_id_and_get_type", return_value=None) + @patch( + "sagemaker.core.jumpstart.hub.utils.generate_hub_arn_for_init_kwargs", + return_value=MOCK_HUB_ARN, + ) + def test_from_jumpstart_config_then_build_uses_private_hub( + self, + mock_generate_arn, + mock_validate, + mock_deploy_kwargs, + mock_prepare, + mock_create, + mock_get_kwargs, + mock_is_js, + ): + """Full flow: from_jumpstart_config with hub_name -> build -> hub_arn passed through.""" + private_s3_uri = "s3://private-hub-bucket/content/model.tar.gz" + mock_init_kwargs = Mock() + mock_init_kwargs.image_uri = ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/djl-inference:0.27.0-lmi10.0.0-cu124" + ) + mock_init_kwargs.env = {} + mock_init_kwargs.model_data = { + "S3DataSource": { + "S3Uri": private_s3_uri, + "S3DataType": "S3Prefix", + "CompressionType": "None", + } + } + mock_init_kwargs.enable_network_isolation = None + mock_get_kwargs.return_value = mock_init_kwargs + + mock_model = Mock() + mock_create.return_value = mock_model + + js_config = JumpStartConfig( + model_id=MOCK_MODEL_ID, + model_version=MOCK_MODEL_VERSION, + 
hub_name=MOCK_HUB_NAME, + ) + + mb = ModelBuilder.from_jumpstart_config( + jumpstart_config=js_config, + role_arn=MOCK_ROLE_ARN, + sagemaker_session=_mock_session(), + compute=Compute(instance_type="ml.g5.xlarge"), + ) + + # Verify hub_arn was derived + self.assertEqual(mb.hub_arn, MOCK_HUB_ARN) + self.assertEqual(mb.hub_name, MOCK_HUB_NAME) + + # Now trigger build + mb.mode = Mode.SAGEMAKER_ENDPOINT + mb._optimizing = False + mb._build_for_jumpstart() + + # Verify hub_arn was forwarded to get_init_kwargs + mock_get_kwargs.assert_called_once() + call_kwargs = mock_get_kwargs.call_args + actual_hub_arn = call_kwargs.kwargs.get("hub_arn") or call_kwargs[1].get("hub_arn") + self.assertEqual(actual_hub_arn, MOCK_HUB_ARN) + + # Verify model data points to private hub, not public cache + self.assertEqual(mb.s3_model_data_url, private_s3_uri) + self.assertNotIn("jumpstart-cache-prod", mb.s3_model_data_url) + + +if __name__ == "__main__": + unittest.main()