diff --git a/libensemble/executors/__init__.py b/libensemble/executors/__init__.py index 563fa3352..fdcae5c87 100644 --- a/libensemble/executors/__init__.py +++ b/libensemble/executors/__init__.py @@ -1,4 +1,5 @@ from libensemble.executors.executor import Executor +from libensemble.executors.globus_compute_executor import GlobusComputeExecutor, GlobusComputeTask from libensemble.executors.mpi_executor import MPIExecutor -__all__ = ["Executor", "MPIExecutor"] +__all__ = ["Executor", "GlobusComputeExecutor", "GlobusComputeTask", "MPIExecutor"] diff --git a/libensemble/executors/globus_compute_executor.py b/libensemble/executors/globus_compute_executor.py new file mode 100644 index 000000000..6a2c718dc --- /dev/null +++ b/libensemble/executors/globus_compute_executor.py @@ -0,0 +1,279 @@ +import logging +import os +from concurrent.futures import Future, TimeoutError +from typing import Any + +from libensemble.executors.executor import Application, Executor, ExecutorException, Task, TimeoutExpired +from libensemble.utils.globus_compute import GCSession +from libensemble.utils.timer import TaskTimer + +logger = logging.getLogger(__name__) + + +class GlobusComputeTask(Task): + """A :class:`~libensemble.executors.executor.Task` wrapping a + ``concurrent.futures.Future`` returned by Globus Compute. + + Instead of managing a local subprocess, this task polls a remote + computation via the future's ``done()`` / ``result()`` APIs. + """ + + def __init__(self, future, app=None, app_args=None, workerid=None): + self.id = next(Task.newid) + self.reset() + self.timer = TaskTimer() + self.app = app + self.app_args = app_args + self.workerID = workerid + self._gc_future = future + + worker_name = f"_worker{self.workerID}" if self.workerID else "" + self.name = Task.prefix + f"_{app.name}{worker_name}_{self.id}" + self.stdout = "" + self.stderr = "" + self.workdir = None + self.dry_run = False + self.runline = None + self.run_attempts = 0 + self.env = {} + self.ngpus_req = 0 + + self.state = "RUNNING" + self.timer.start() + self.submit_time = self.timer.tstart + + def _check_poll(self): + if self.finished: + return False + return True + + def poll(self): + if not self._check_poll(): + return + if self._gc_future.done(): + try: + self._gc_future.result() + self.finished = True + self.success = True + self.state = "FINISHED" + except Exception: + self.finished = True + self.success = False + self.state = "FAILED" + self.calc_task_timing() + else: + self.state = "RUNNING" + self.runtime = self.timer.elapsed + + def wait(self, timeout=None): + if not self._check_poll(): + return + try: + self._gc_future.result(timeout=timeout) + self.finished = True + self.success = True + self.state = "FINISHED" + except TimeoutError: + raise TimeoutExpired(self.name, timeout) + except Exception: + self.finished = True + self.success = False + self.state = "FAILED" + self.calc_task_timing() + + def kill(self, wait_time=None): + self._gc_future.cancel() + self.state = "USER_KILLED" + self.finished = True + self.calc_task_timing() + + def result(self, timeout=None): + self.wait(timeout=timeout) + return self.state + + def running(self): + self.poll() + return self.state == "RUNNING" + + def done(self): + self.poll() + return self.finished + + def cancelled(self): + self.poll() + return self.state == "USER_KILLED" + + +class GlobusComputeExecutor(Executor): + """An :class:`~libensemble.executors.executor.Executor` that submits + Python callables to Globus Compute instead of launching local subprocesses. + + Usage in a top-level script:: + + from libensemble.executors.globus_compute_executor import GlobusComputeExecutor + + exctr = GlobusComputeExecutor(endpoint_id="...") + + Inside a simulator function:: + + task = info["executor"].submit(func=my_remote_func, app_args=...) + while not task.finished: + task.poll() + if info["executor"].manager_kill_received(): + task.kill() + break + time.sleep(0.1) + """ + + def __init__(self, endpoint_id: str): + self.manager_signal = None + self.default_apps: dict[str, Application | None] = {"sim": None, "gen": None} + self.apps: dict[str, Application] = {} + self.wait_time = 60 + self.list_of_tasks: list[GlobusComputeTask] = [] + self.workerID = None + self.comm = None + self.last_task = 0 + self.base_dir = os.getcwd() + + self.endpoint_id = endpoint_id + self._gc_executor = None + self._func_cache: dict[int, str] = {} + + def _ensure_gc(self): + if self._gc_executor is None: + self._gc_executor = GCSession.get_or_create_executor(self.endpoint_id) + return self._gc_executor + + def _get_func_id(self, func) -> str: + key = id(func) + if key in self._func_cache: + return self._func_cache[key] + executor = self._ensure_gc() + if executor is None: + raise RuntimeError( + "Globus Compute SDK is not installed. " "Install it with: pip install globus-compute-sdk" + ) + fid = executor.register_function(func) + self._func_cache[key] = fid + return fid + + def register_app( + self, + full_path: str, + app_name: str | None = None, + calc_type: str | None = None, + desc: str | None = None, + precedent: str = "", + pyobj: Any | None = None, + ) -> None: + """Register an application. + + If *pyobj* is provided the application is treated as a remote + Python callable. Otherwise the base-class behaviour applies + (local executable). + """ + if not app_name: + app_name = os.path.split(full_path)[1] + + app = Application(full_path, app_name, calc_type, desc, pyobj, precedent) + self.apps[app_name] = app + + if calc_type is not None: + if calc_type not in self.default_apps: + raise ExecutorException(f"Unrecognized calculation type {calc_type}") + self.default_apps[calc_type] = app + + def submit( + self, + calc_type: str | None = None, + app_name: str | None = None, + app_args: str | None = None, + func: Any = None, + stdout: str | None = None, + stderr: str | None = None, + dry_run: bool = False, + wait_on_start: bool = False, + **kwargs, + ) -> GlobusComputeTask: + """Submit a function or registered application to Globus Compute. + + Parameters + ---------- + calc_type : str, optional + Calculation type (``"sim"`` or ``"gen"``). Used with *app_name*. + app_name : str, optional + Name of a previously registered application. + app_args : str, optional + Arguments passed alongside the function. + func : Callable, optional + A Python callable to execute remotely. Takes precedence over + *app_name* / *calc_type*. + stdout, stderr : str, optional + Ignored (stubs for API compatibility). + dry_run : bool, optional + If True, return a task without actually submitting. + wait_on_start : bool, optional + If True, block until the task is reported as started. + + Returns + ------- + GlobusComputeTask + """ + if dry_run: + raise NotImplementedError("dry_run is not supported for GlobusComputeExecutor") + + if func is not None: + fid = self._get_func_id(func) + app = Application(full_path="", name=func.__name__, calc_type="sim", pyobj=func) + elif app_name is not None: + app = self.get_app(app_name) + if app.pyobj is not None: + fid = self._get_func_id(app.pyobj) + else: + raise ValueError( + f"Application '{app_name}' has no pyobj callable registered. " + "Use the `func=...` argument, or register an app with `pyobj=`." + ) + elif calc_type is not None: + app = self.default_app(calc_type) + if app.pyobj is not None: + fid = self._get_func_id(app.pyobj) + else: + raise ValueError( + f"Default {calc_type} app has no pyobj callable. " + "Use the `func=...` argument, or register an app with `pyobj=`." + ) + else: + raise ValueError("One of `func`, `app_name`, or `calc_type` must be provided") + + args = app_args + future: Future = self._ensure_gc().submit_to_registered_function(fid, args) + task = GlobusComputeTask(future, app=app, app_args=args, workerid=self.workerID) + self.list_of_tasks.append(task) + + if wait_on_start: + task.wait() + + return task + + def set_workerID(self, workerid) -> None: + """Sets the worker ID for this executor.""" + self.workerID = workerid + + def set_worker_info(self, comm=None, workerid=None) -> None: + """Sets worker info for this executor.""" + self.workerID = workerid + self.comm = comm + + def serial_setup(self): + pass + + def set_resources(self, resources): + pass + + def add_platform_info(self, platform_info=None): + pass + + def set_gen_procs_gpus(self, libE_info): + pass diff --git a/libensemble/tests/unit_tests/test_globus_compute.py b/libensemble/tests/unit_tests/test_globus_compute.py new file mode 100644 index 000000000..3dab45f62 --- /dev/null +++ b/libensemble/tests/unit_tests/test_globus_compute.py @@ -0,0 +1,258 @@ +from unittest import mock + +import pytest + +from libensemble.executors.globus_compute_executor import ( + GlobusComputeExecutor, + GlobusComputeTask, +) +from libensemble.utils.globus_compute import GCSession + + +class TestGCSession: + def setup_method(self): + GCSession.clear() + + def test_get_or_create_executor(self): + with mock.patch.object(GCSession, "_create_executor") as mock_create: + mock_exec = mock.MagicMock() + mock_create.return_value = mock_exec + + ex1 = GCSession.get_or_create_executor("ep-1") + assert ex1 is mock_exec + mock_create.assert_called_once_with("ep-1") + + ex2 = GCSession.get_or_create_executor("ep-1") + assert ex2 is mock_exec + mock_create.assert_called_once() + + def test_get_or_create_caches_func_id(self): + with mock.patch.object(GCSession, "_create_executor") as mock_create: + mock_exec = mock.MagicMock() + mock_exec.register_function.return_value = "fid-42" + mock_create.return_value = mock_exec + + def my_func(): + pass + + ex1, fid1 = GCSession.get_or_create("ep-1", my_func) + assert ex1 is mock_exec + assert fid1 == "fid-42" + mock_exec.register_function.assert_called_once_with(my_func) + + ex2, fid2 = GCSession.get_or_create("ep-1", my_func) + assert ex2 is mock_exec + assert fid2 == "fid-42" + mock_exec.register_function.assert_called_once() + + def test_register_function(self): + with mock.patch.object(GCSession, "_create_executor") as mock_create: + mock_exec = mock.MagicMock() + mock_exec.register_function.return_value = "fid-99" + mock_create.return_value = mock_exec + + def f(): + pass + + ex, fid = GCSession.register_function("ep-1", f) + assert ex is mock_exec + assert fid == "fid-99" + mock_exec.register_function.assert_called_once_with(f) + + def test_module_not_found_returns_none(self): + GCSession._create_executor = classmethod(lambda cls, eid: None) + + ex = GCSession.get_or_create_executor("ep-1") + assert ex is None + + ex, fid = GCSession.get_or_create("ep-1", lambda: None) + assert ex is None + assert fid is None + + def test_thread_safety(self): + import threading + + with mock.patch.object(GCSession, "_create_executor") as mock_create: + mock_exec = mock.MagicMock() + mock_exec.register_function.return_value = "fid" + mock_create.return_value = mock_exec + + errors = [] + + def access(): + try: + for _ in range(100): + GCSession.get_or_create_executor("ep-t") + GCSession.get_or_create("ep-t", lambda: None) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=access) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors, f"Thread safety errors: {errors}" + + +class TestGlobusComputeTask: + def make_task(self, future=None, app=None): + if app is None: + from libensemble.executors.executor import Application + + app = Application("", name="test_func", calc_type="sim", pyobj=lambda: None) + if future is None: + future = mock.MagicMock() + future.done.return_value = False + return GlobusComputeTask(future, app=app) + + def test_initial_state(self): + task = self.make_task() + assert task.state == "RUNNING" + assert not task.finished + assert task._gc_future is not None + + def test_poll_running(self): + future = mock.MagicMock() + future.done.return_value = False + task = self.make_task(future=future) + task.poll() + assert task.state == "RUNNING" + assert not task.finished + + def test_poll_finished_success(self): + future = mock.MagicMock() + future.done.return_value = True + future.result.return_value = None + task = self.make_task(future=future) + task.poll() + assert task.state == "FINISHED" + assert task.finished + assert task.success + + def test_poll_finished_failure(self): + future = mock.MagicMock() + future.done.return_value = True + future.result.side_effect = RuntimeError("boom") + task = self.make_task(future=future) + task.poll() + assert task.state == "FAILED" + assert task.finished + assert not task.success + + def test_wait_timeout(self): + future = mock.MagicMock() + future.result.side_effect = TimeoutError("timed out") + task = self.make_task(future=future) + with pytest.raises(Exception, match="timed out"): + task.wait(timeout=0.001) + + def test_kill(self): + future = mock.MagicMock() + task = self.make_task(future=future) + task.kill() + assert task.state == "USER_KILLED" + assert task.finished + future.cancel.assert_called_once() + + def test_running(self): + future = mock.MagicMock() + future.done.return_value = False + task = self.make_task(future=future) + assert task.running() + + def test_done(self): + future = mock.MagicMock() + future.done.return_value = True + future.result.return_value = None + task = self.make_task(future=future) + assert task.done() + + def test_not_done(self): + future = mock.MagicMock() + future.done.return_value = False + task = self.make_task(future=future) + assert not task.done() + + +class TestGlobusComputeExecutor: + def setup_method(self): + GCSession.clear() + + def test_init(self): + exctr = GlobusComputeExecutor(endpoint_id="ep-test") + assert exctr.endpoint_id == "ep-test" + assert exctr._gc_executor is None + assert exctr.workerID is None + + def test_submit_with_func(self): + with mock.patch.object(GCSession, "_create_executor") as mock_create: + mock_exec = mock.MagicMock() + mock_exec.register_function.return_value = "fid-xyz" + mock_create.return_value = mock_exec + + exctr = GlobusComputeExecutor(endpoint_id="ep-test") + exctr._ensure_gc() + + future_mock = mock.MagicMock() + mock_exec.submit_to_registered_function.return_value = future_mock + + def my_func(x): + return x * 2 + + task = exctr.submit(func=my_func, app_args="hello") + assert isinstance(task, GlobusComputeTask) + assert task._gc_future is future_mock + assert task.app is not None + assert task.app.name == "my_func" + + def test_submit_with_registered_app_pyobj(self): + with mock.patch.object(GCSession, "_create_executor") as mock_create: + mock_exec = mock.MagicMock() + mock_exec.register_function.return_value = "fid-app" + mock_create.return_value = mock_exec + + exctr = GlobusComputeExecutor(endpoint_id="ep-test") + exctr._ensure_gc() + + def app_func(): + return 42 + + exctr.register_app("/fake/path", app_name="myapp", calc_type="sim", pyobj=app_func) + + future_mock = mock.MagicMock() + mock_exec.submit_to_registered_function.return_value = future_mock + + task = exctr.submit(app_name="myapp") + assert isinstance(task, GlobusComputeTask) + assert task.app.name == "myapp" + + def test_submit_without_func_or_app_raises(self): + exctr = GlobusComputeExecutor(endpoint_id="ep-test") + with pytest.raises(ValueError): + exctr.submit() + + def test_register_function_caching(self): + with mock.patch.object(GCSession, "_create_executor") as mock_create: + mock_exec = mock.MagicMock() + mock_exec.register_function.return_value = "fid-cached" + mock_create.return_value = mock_exec + + exctr = GlobusComputeExecutor(endpoint_id="ep-test") + exctr._ensure_gc() + + def my_func(): + pass + + fid1 = exctr._get_func_id(my_func) + fid2 = exctr._get_func_id(my_func) + assert fid1 == fid2 + assert mock_exec.register_function.call_count == 1 + + def test_register_app_no_pyobj(self): + exctr = GlobusComputeExecutor(endpoint_id="ep-test") + exctr.register_app("/bin/echo", app_name="echo", calc_type="sim") + app = exctr.get_app("echo") + assert app.pyobj is None + assert app.full_path == "/bin/echo" diff --git a/libensemble/utils/globus_compute.py b/libensemble/utils/globus_compute.py new file mode 100644 index 000000000..c3341f304 --- /dev/null +++ b/libensemble/utils/globus_compute.py @@ -0,0 +1,87 @@ +import logging +import threading + +logger = logging.getLogger(__name__) + + +class GCSession: + """Per-process singleton cache for Globus Compute executors. + + Caches executor instances keyed by endpoint_id, ensuring only one + executor per endpoint per process. Thread-safe via ``threading.Lock``. + """ + + _instances: dict[str, tuple] = {} + _lock = threading.Lock() + + @classmethod + def get_or_create_executor(cls, endpoint_id: str): + """Get or create a cached executor for the given endpoint. + + Unlike :meth:`get_or_create`, this does **not** register a function. + """ + with cls._lock: + if endpoint_id in cls._instances: + return cls._instances[endpoint_id][0] + + executor = cls._create_executor(endpoint_id) + if executor is None: + return None + + cls._instances[endpoint_id] = (executor, None) + return executor + + @classmethod + def get_or_create(cls, endpoint_id: str, func): + """Get or create a cached ``(executor, func_id)`` pair. + + The first call for an endpoint creates the executor and registers + the callable. Subsequent calls return the cached pair (the + registered function is re-used). + """ + with cls._lock: + if endpoint_id in cls._instances: + executor, existing_fid = cls._instances[endpoint_id] + if existing_fid is not None: + return executor, existing_fid + func_id = executor.register_function(func) + cls._instances[endpoint_id] = (executor, func_id) + return executor, func_id + + executor = cls._create_executor(endpoint_id) + if executor is None: + return None, None + + func_id = executor.register_function(func) + cls._instances[endpoint_id] = (executor, func_id) + return executor, func_id + + @classmethod + def register_function(cls, endpoint_id: str, func): + """Register an additional function with an existing executor. + + Returns ``(executor, func_id)``. Unlike :meth:`get_or_create`, + this always registers and never caches the func_id (caller should + cache it themselves). + """ + executor = cls.get_or_create_executor(endpoint_id) + if executor is None: + return None, None + func_id = executor.register_function(func) + return executor, func_id + + @classmethod + def _create_executor(cls, endpoint_id: str): + try: + from globus_compute_sdk import Executor + except ModuleNotFoundError: + logger.warning("Globus Compute use detected but Globus Compute not importable. " "Is it installed?") + logger.warning("Running function evaluations normally on local resources.") + return None + return Executor(endpoint_id=endpoint_id) + + @classmethod + def clear(cls): + """Clear the cache (primarily for testing).""" + with cls._lock: + cls._instances.clear()