Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions transformer_engine/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import sys
import sysconfig
from typing import Optional, Tuple
import warnings


@functools.lru_cache(maxsize=None)
Expand Down Expand Up @@ -191,6 +192,29 @@ def load_framework_extension(framework: str) -> None:
sys.modules[module_name] = solib
spec.loader.exec_module(solib)

# Plugin system: if NVTE_ENABLE_PLUGIN=1, let plugin stub take over
# transformer_engine_torch and register original pybind as _nv for CUDA backend.
# Only applies to the PyTorch extension — JAX has no plugin stub.
if os.environ.get("NVTE_ENABLE_PLUGIN", "0") == "1" and framework == "torch":
_original_module = sys.modules.get(module_name)
try:
from transformer_engine_plugin_fl import load_plugins

sys.modules[module_name + "_nv"] = solib
load_plugins()
except Exception as e:
# Rollback to pre-plugin state if plugin failed to fully initialize
sys.modules.pop(module_name + "_nv", None)
if _original_module is not None:
sys.modules[module_name] = _original_module
else:
sys.modules.pop(module_name, None)
warnings.warn(
f"NVTE_ENABLE_PLUGIN=1 but plugin loading failed: {e}",
RuntimeWarning,
stacklevel=2,
)
Comment on lines +200 to +216

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Incomplete sys.modules rollback on load_plugins() failure

sys.modules[module_name] (i.e., transformer_engine_torch) is set to solib at line 192, before this try block. If load_plugins() partially succeeds — for example, replaces sys.modules["transformer_engine_torch"] with the plugin stub before raising a RuntimeError during backend registration — the except block pops _nv but leaves sys.modules["transformer_engine_torch"] pointing to the partially-initialized stub. TE then continues with a broken tex module even though the warning says plugin loading failed.

Capture the pre-attempt value of sys.modules.get(module_name) before calling load_plugins() and restore it in the except block alongside the _nv pop.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed



def sanity_checks_for_pypi_installation() -> None:
"""Ensure that package is installed correctly if using PyPI."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,16 @@
_dpa_fp8ds_reduce_amax = os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1") == "1"


# Plugin system: when NVTE_ENABLE_PLUGIN=1, let plugin-provided attention
# implementations exposed on `tex` take precedence over the native ones.
if os.environ.get("NVTE_ENABLE_PLUGIN", "0") == "1":
    # Keep a handle on the native class so it stays reachable after the override.
    _FlashAttentionNative = FlashAttention
    # Fall back to the native class when `tex` has no `flash_attention` attribute.
    FlashAttention = getattr(tex, "flash_attention", _FlashAttentionNative)
    _plugin_get_attention_backend = getattr(tex, "get_attention_backend", None)
    if _plugin_get_attention_backend is not None:
        # Save the native selector only once. If this module body is executed
        # again (e.g. importlib.reload), dpa_utils.get_attention_backend is
        # already the plugin version and must not clobber the saved original.
        if not hasattr(dpa_utils, "_original_get_attention_backend"):
            dpa_utils._original_get_attention_backend = dpa_utils.get_attention_backend
        dpa_utils.get_attention_backend = _plugin_get_attention_backend


__all__ = ["DotProductAttention"]


Expand Down