Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import contextlib
import datetime
import functools
Expand All @@ -7,7 +9,6 @@
from enum import Enum
from typing import Any, Callable, MutableMapping

from aws_durable_execution_sdk_python.exceptions import SuspendExecution
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
from aws_durable_execution_sdk_python.lambda_service import (
DurableExecutionInvocationOutput,
Expand Down Expand Up @@ -54,13 +55,10 @@ class UserFunctionOutcome(Enum):
PENDING = "PENDING"

@classmethod
def from_error(cls, error: ErrorObject | None) -> "UserFunctionOutcome":
def from_error(cls, error: ErrorObject | None) -> UserFunctionOutcome:
if error is None:
return cls(cls.SUCCEEDED)
elif error.type == SuspendExecution.__name__:
return cls(cls.PENDING)
else:
return cls(cls.FAILED)
return cls(cls.FAILED)


@dataclass(frozen=True)
Expand All @@ -86,7 +84,7 @@ class UserFunctionEndInfo(OperationInfo):
@classmethod
def from_start_info(
cls, start_info: UserFunctionStartInfo, error: ErrorObject | None
) -> "UserFunctionEndInfo":
) -> UserFunctionEndInfo:
return UserFunctionEndInfo(
operation_id=start_info.operation_id,
operation_type=start_info.operation_type,
Expand All @@ -101,6 +99,24 @@ def from_start_info(
error=error,
)

@classmethod
def from_start_info_suspended(
cls, start_info: UserFunctionStartInfo
) -> UserFunctionEndInfo:
return cls(
operation_id=start_info.operation_id,
operation_type=start_info.operation_type,
sub_type=start_info.sub_type,
name=start_info.name,
parent_id=start_info.parent_id,
start_time=start_info.start_time,
is_replay_children=start_info.is_replay_children,
attempt=start_info.attempt,
outcome=UserFunctionOutcome.PENDING,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On SuspendExecution, the SDK calls CheckpointDurableExecution with STARTED, so we might want to keep this as STARTED; later it will emit completed from the list of operations from the backend.

end_time=datetime.datetime.now(datetime.UTC),
error=None,
)


@dataclass(frozen=True)
class InvocationInfo:
Expand Down Expand Up @@ -310,6 +326,16 @@ def on_user_function_end(self, start_info: UserFunctionStartInfo, error) -> None
UserFunctionEndInfo.from_start_info(start_info, error), sync=True
)

def on_user_function_suspend(self, start_info: UserFunctionStartInfo) -> None:
    """Notify registered plugins that an operation's user function suspended.

    Suspending is ordinary durable control flow rather than a failure, so
    plugins receive end info carrying a PENDING outcome and no error.
    """
    suspended_info = UserFunctionEndInfo.from_start_info_suspended(start_info)
    self.execute_plugins(suspended_info, sync=True)

def on_operation_action(self, update: OperationUpdate):
"""Execute any registered plugins for a given operation when an update is checkpointed

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -952,13 +952,8 @@ def wrapper(*args, **kwargs):
result = user_function(*args, **kwargs)
self._plugin_executor.on_user_function_end(start_info, None)
return result
except SuspendExecution as e:
self._plugin_executor.on_user_function_end(
start_info,
ErrorObject(
type=type(e).__name__, message=None, data=None, stack_trace=None
),
)
except SuspendExecution:
self._plugin_executor.on_user_function_suspend(start_info)
raise
except Exception as e:
self._plugin_executor.on_user_function_end(
Expand Down
49 changes: 49 additions & 0 deletions packages/aws-durable-execution-sdk-python/tests/plugin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,5 +783,54 @@ def on_operation_attempt_end(self, info):
# endregion Helper Classes


# region Suspend Outcome Tests
class _CapturingPlugin(DurableInstrumentationPlugin):
    """Test plugin that records each UserFunctionEndInfo it is handed."""

    def __init__(self) -> None:
        # Filled in call order by on_user_function_end.
        self.user_function_end_infos: list[UserFunctionEndInfo] = []

    def on_user_function_end(self, info: UserFunctionEndInfo) -> None:
        # Keep the info so tests can inspect what was dispatched.
        received = self.user_function_end_infos
        received.append(info)


class TestUserFunctionOutcomeFromError(unittest.TestCase):
    """Checks how UserFunctionOutcome.from_error maps its argument to an outcome."""

    def test_none_error_is_succeeded(self):
        # No error object at all is reported as SUCCEEDED.
        outcome = UserFunctionOutcome.from_error(None)
        self.assertEqual(outcome, UserFunctionOutcome.SUCCEEDED)

    def test_error_is_failed(self):
        # The shared ERROR fixture is reported as FAILED.
        outcome = UserFunctionOutcome.from_error(ERROR)
        self.assertEqual(outcome, UserFunctionOutcome.FAILED)


class TestUserFunctionEndInfoSuspended(unittest.TestCase):
    """Checks the end info built by UserFunctionEndInfo.from_start_info_suspended."""

    def test_from_start_info_suspended_is_pending_without_error(self):
        # A suspended end info is PENDING, has no error, and keeps the
        # identifying fields of the start info it was built from.
        start = USER_FUNCTION_START_INFO
        info = UserFunctionEndInfo.from_start_info_suspended(start)
        self.assertEqual(info.outcome, UserFunctionOutcome.PENDING)
        self.assertIsNone(info.error)
        for field_name in ("operation_id", "name", "attempt"):
            actual = getattr(info, field_name)
            expected = getattr(start, field_name)
            self.assertEqual(actual, expected)


class TestPluginExecutorOnUserFunctionSuspend(unittest.TestCase):
# Checks that PluginExecutor.on_user_function_suspend reaches a plugin's
# on_user_function_end hook with a PENDING, error-free end info.
def test_suspend_dispatches_pending_outcome(self):
# _CapturingPlugin stores every info passed to its on_user_function_end.
plugin = _CapturingPlugin()
executor = PluginExecutor(plugins=[plugin])
with executor.run():
executor.on_user_function_suspend(USER_FUNCTION_START_INFO)
# Exactly one end notification must have been captured.
self.assertEqual(len(plugin.user_function_end_infos), 1)
info = plugin.user_function_end_infos[0]
# The captured info is PENDING, carries no error, and refers to the
# same operation as the start info that was suspended.
self.assertEqual(info.outcome, UserFunctionOutcome.PENDING)
self.assertIsNone(info.error)
self.assertEqual(info.operation_id, USER_FUNCTION_START_INFO.operation_id)


# endregion Suspend Outcome Tests


if __name__ == "__main__":
unittest.main()
42 changes: 42 additions & 0 deletions packages/aws-durable-execution-sdk-python/tests/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
DurableApiErrorCategory,
GetExecutionStateError,
OrphanedChildException,
TimedSuspendExecution,
)
from aws_durable_execution_sdk_python.identifier import OperationIdentifier
from aws_durable_execution_sdk_python.lambda_service import (
Expand All @@ -41,6 +42,8 @@
from aws_durable_execution_sdk_python.plugin import (
DurableInstrumentationPlugin,
PluginExecutor,
UserFunctionEndInfo,
UserFunctionOutcome,
)
from aws_durable_execution_sdk_python.state import (
CheckpointBatcherConfig,
Expand Down Expand Up @@ -4199,6 +4202,45 @@ def on_operation_end(self, info):
executor.shutdown(wait=True)


def test_wrap_user_function_suspend_reports_pending_outcome():
"""A user function that suspends is reported as PENDING with no error.

Regression: a timed suspend (TimedSuspendExecution) raised inside a wrapped
user function (e.g. a child context that waits) must not be surfaced to
plugins as a FAILED outcome. It is normal durable control flow.
"""
# Every UserFunctionEndInfo delivered to the plugin, in arrival order.
captured: list[UserFunctionEndInfo] = []

# Minimal plugin whose only job is to record end infos into `captured`.
class _CapturingPlugin(DurableInstrumentationPlugin):
def on_user_function_end(self, info: UserFunctionEndInfo) -> None:
captured.append(info)

plugin_executor = PluginExecutor(plugins=[_CapturingPlugin()])
with plugin_executor.run():
# State starts with no operations and an autospec'd Lambda client.
state = ExecutionState(
durable_execution_arn="test_arn",
initial_checkpoint_token="token123",  # noqa: S106
operations={},
service_client=create_autospec(spec=LambdaClient),
plugin_executor=plugin_executor,
)

# User function that always raises a timed suspend instead of returning.
def suspends(_: object) -> None:
raise TimedSuspendExecution.from_delay("waiting", 5)

op_id = OperationIdentifier(
operation_id="op-1", sub_type=OperationSubType.STEP, name="step"
)
wrapped = state.wrap_user_function(suspends, op_id, attempt=1)

# The wrapper must let the suspend exception propagate to the caller.
with pytest.raises(TimedSuspendExecution):
wrapped(None)

# Exactly one end notification, reported as PENDING and carrying no error.
assert len(captured) == 1
assert captured[0].outcome is UserFunctionOutcome.PENDING
assert captured[0].error is None


def test_plugin_executor_not_called_for_pending_operations():
"""Test that plugin_executor.on_operation_update fires on_user_function_end for PENDING operations."""
mock_client = create_autospec(spec=LambdaClient)
Expand Down