Report suspended user functions to plugins as PENDING, not FAILED.

A SuspendExecution raised inside a wrapped user function is normal durable
control flow. The wrapper in state.py now calls the new
PluginExecutor.on_user_function_suspend method, which dispatches a
UserFunctionEndInfo built by from_start_info_suspended (outcome PENDING,
error None). UserFunctionOutcome.from_error no longer inspects the error
type: None maps to SUCCEEDED and any error maps to FAILED.

NOTE(review): this patch was recovered from a whitespace-collapsed copy.
Line breaks and Python indentation were reconstructed to match the hunk
header line counts; confirm context-line indentation against the target
files before applying.

diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py
index dd09bc1..2855008 100644
--- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py
+++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/plugin.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import contextlib
 import datetime
 import functools
@@ -7,7 +9,6 @@
 from enum import Enum
 from typing import Any, Callable, MutableMapping
 
-from aws_durable_execution_sdk_python.exceptions import SuspendExecution
 from aws_durable_execution_sdk_python.identifier import OperationIdentifier
 from aws_durable_execution_sdk_python.lambda_service import (
     DurableExecutionInvocationOutput,
@@ -54,13 +55,10 @@ class UserFunctionOutcome(Enum):
     PENDING = "PENDING"
 
     @classmethod
-    def from_error(cls, error: ErrorObject | None) -> "UserFunctionOutcome":
+    def from_error(cls, error: ErrorObject | None) -> UserFunctionOutcome:
         if error is None:
             return cls(cls.SUCCEEDED)
-        elif error.type == SuspendExecution.__name__:
-            return cls(cls.PENDING)
-        else:
-            return cls(cls.FAILED)
+        return cls(cls.FAILED)
 
 
 @dataclass(frozen=True)
@@ -86,7 +84,7 @@ class UserFunctionEndInfo(OperationInfo):
     @classmethod
     def from_start_info(
         cls, start_info: UserFunctionStartInfo, error: ErrorObject | None
-    ) -> "UserFunctionEndInfo":
+    ) -> UserFunctionEndInfo:
         return UserFunctionEndInfo(
             operation_id=start_info.operation_id,
             operation_type=start_info.operation_type,
@@ -101,6 +99,24 @@ def from_start_info(
             error=error,
         )
 
+    @classmethod
+    def from_start_info_suspended(
+        cls, start_info: UserFunctionStartInfo
+    ) -> UserFunctionEndInfo:
+        return cls(
+            operation_id=start_info.operation_id,
+            operation_type=start_info.operation_type,
+            sub_type=start_info.sub_type,
+            name=start_info.name,
+            parent_id=start_info.parent_id,
+            start_time=start_info.start_time,
+            is_replay_children=start_info.is_replay_children,
+            attempt=start_info.attempt,
+            outcome=UserFunctionOutcome.PENDING,
+            end_time=datetime.datetime.now(datetime.UTC),
+            error=None,
+        )
+
 
 @dataclass(frozen=True)
 class InvocationInfo:
@@ -310,6 +326,16 @@ def on_user_function_end(self, start_info: UserFunctionStartInfo, error) -> None
             UserFunctionEndInfo.from_start_info(start_info, error), sync=True
         )
 
+    def on_user_function_suspend(self, start_info: UserFunctionStartInfo) -> None:
+        """Execute any registered plugins when an operation's user function suspends.
+
+        A suspension is normal durable control flow, not a failure, so the
+        operation is reported with a PENDING outcome and no error.
+        """
+        self.execute_plugins(
+            UserFunctionEndInfo.from_start_info_suspended(start_info), sync=True
+        )
+
     def on_operation_action(self, update: OperationUpdate):
         """Execute any registered plugins for a given operation when an update is checkpointed
 
diff --git a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py
index c24bd96..c073ca7 100644
--- a/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py
+++ b/packages/aws-durable-execution-sdk-python/src/aws_durable_execution_sdk_python/state.py
@@ -952,13 +952,8 @@ def wrapper(*args, **kwargs):
                 result = user_function(*args, **kwargs)
                 self._plugin_executor.on_user_function_end(start_info, None)
                 return result
-            except SuspendExecution as e:
-                self._plugin_executor.on_user_function_end(
-                    start_info,
-                    ErrorObject(
-                        type=type(e).__name__, message=None, data=None, stack_trace=None
-                    ),
-                )
+            except SuspendExecution:
+                self._plugin_executor.on_user_function_suspend(start_info)
                 raise
             except Exception as e:
                 self._plugin_executor.on_user_function_end(
diff --git a/packages/aws-durable-execution-sdk-python/tests/plugin_test.py b/packages/aws-durable-execution-sdk-python/tests/plugin_test.py
index c402d35..7d38fb5 100644
--- a/packages/aws-durable-execution-sdk-python/tests/plugin_test.py
+++ b/packages/aws-durable-execution-sdk-python/tests/plugin_test.py
@@ -783,5 +783,54 @@ def on_operation_attempt_end(self, info):
 # endregion Helper Classes
 
 
+# region Suspend Outcome Tests
+class _CapturingPlugin(DurableInstrumentationPlugin):
+    """Captures the UserFunctionEndInfo passed to on_user_function_end."""
+
+    def __init__(self) -> None:
+        self.user_function_end_infos: list[UserFunctionEndInfo] = []
+
+    def on_user_function_end(self, info: UserFunctionEndInfo) -> None:
+        self.user_function_end_infos.append(info)
+
+
+class TestUserFunctionOutcomeFromError(unittest.TestCase):
+    def test_none_error_is_succeeded(self):
+        self.assertEqual(
+            UserFunctionOutcome.from_error(None), UserFunctionOutcome.SUCCEEDED
+        )
+
+    def test_error_is_failed(self):
+        self.assertEqual(
+            UserFunctionOutcome.from_error(ERROR), UserFunctionOutcome.FAILED
+        )
+
+
+class TestUserFunctionEndInfoSuspended(unittest.TestCase):
+    def test_from_start_info_suspended_is_pending_without_error(self):
+        info = UserFunctionEndInfo.from_start_info_suspended(USER_FUNCTION_START_INFO)
+        self.assertEqual(info.outcome, UserFunctionOutcome.PENDING)
+        self.assertIsNone(info.error)
+        self.assertEqual(info.operation_id, USER_FUNCTION_START_INFO.operation_id)
+        self.assertEqual(info.name, USER_FUNCTION_START_INFO.name)
+        self.assertEqual(info.attempt, USER_FUNCTION_START_INFO.attempt)
+
+
+class TestPluginExecutorOnUserFunctionSuspend(unittest.TestCase):
+    def test_suspend_dispatches_pending_outcome(self):
+        plugin = _CapturingPlugin()
+        executor = PluginExecutor(plugins=[plugin])
+        with executor.run():
+            executor.on_user_function_suspend(USER_FUNCTION_START_INFO)
+        self.assertEqual(len(plugin.user_function_end_infos), 1)
+        info = plugin.user_function_end_infos[0]
+        self.assertEqual(info.outcome, UserFunctionOutcome.PENDING)
+        self.assertIsNone(info.error)
+        self.assertEqual(info.operation_id, USER_FUNCTION_START_INFO.operation_id)
+
+
+# endregion Suspend Outcome Tests
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/packages/aws-durable-execution-sdk-python/tests/state_test.py b/packages/aws-durable-execution-sdk-python/tests/state_test.py
index f026d35..ab7b1d7 100644
--- a/packages/aws-durable-execution-sdk-python/tests/state_test.py
+++ b/packages/aws-durable-execution-sdk-python/tests/state_test.py
@@ -19,6 +19,7 @@
     DurableApiErrorCategory,
     GetExecutionStateError,
     OrphanedChildException,
+    TimedSuspendExecution,
 )
 from aws_durable_execution_sdk_python.identifier import OperationIdentifier
 from aws_durable_execution_sdk_python.lambda_service import (
@@ -41,6 +42,8 @@
 from aws_durable_execution_sdk_python.plugin import (
     DurableInstrumentationPlugin,
     PluginExecutor,
+    UserFunctionEndInfo,
+    UserFunctionOutcome,
 )
 from aws_durable_execution_sdk_python.state import (
     CheckpointBatcherConfig,
@@ -4199,6 +4202,45 @@ def on_operation_end(self, info):
     executor.shutdown(wait=True)
 
 
+def test_wrap_user_function_suspend_reports_pending_outcome():
+    """A user function that suspends is reported as PENDING with no error.
+
+    Regression: a timed suspend (TimedSuspendExecution) raised inside a wrapped
+    user function (e.g. a child context that waits) must not be surfaced to
+    plugins as a FAILED outcome. It is normal durable control flow.
+    """
+    captured: list[UserFunctionEndInfo] = []
+
+    class _CapturingPlugin(DurableInstrumentationPlugin):
+        def on_user_function_end(self, info: UserFunctionEndInfo) -> None:
+            captured.append(info)
+
+    plugin_executor = PluginExecutor(plugins=[_CapturingPlugin()])
+    with plugin_executor.run():
+        state = ExecutionState(
+            durable_execution_arn="test_arn",
+            initial_checkpoint_token="token123",  # noqa: S106
+            operations={},
+            service_client=create_autospec(spec=LambdaClient),
+            plugin_executor=plugin_executor,
+        )
+
+        def suspends(_: object) -> None:
+            raise TimedSuspendExecution.from_delay("waiting", 5)
+
+        op_id = OperationIdentifier(
+            operation_id="op-1", sub_type=OperationSubType.STEP, name="step"
+        )
+        wrapped = state.wrap_user_function(suspends, op_id, attempt=1)
+
+        with pytest.raises(TimedSuspendExecution):
+            wrapped(None)
+
+    assert len(captured) == 1
+    assert captured[0].outcome is UserFunctionOutcome.PENDING
+    assert captured[0].error is None
+
+
 def test_plugin_executor_not_called_for_pending_operations():
     """Test that plugin_executor.on_operation_update fires on_user_function_end for PENDING operations."""
     mock_client = create_autospec(spec=LambdaClient)