diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 40933f17a9..bbe13790a6 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -16,6 +16,7 @@ import sys import sysconfig from typing import Optional, Tuple +import warnings @functools.lru_cache(maxsize=None) @@ -191,6 +192,29 @@ def load_framework_extension(framework: str) -> None: sys.modules[module_name] = solib spec.loader.exec_module(solib) + # Plugin system: if NVTE_ENABLE_PLUGIN=1, let plugin stub take over + # transformer_engine_torch and register original pybind as _nv for CUDA backend. + # Only applies to the PyTorch extension — JAX has no plugin stub. + if os.environ.get("NVTE_ENABLE_PLUGIN", "0") == "1" and framework == "torch": + _original_module = sys.modules.get(module_name) + try: + from transformer_engine_plugin_fl import load_plugins + + sys.modules[module_name + "_nv"] = solib + load_plugins() + except Exception as e: + # Rollback to pre-plugin state if plugin failed to fully initialize + sys.modules.pop(module_name + "_nv", None) + if _original_module is not None: + sys.modules[module_name] = _original_module + else: + sys.modules.pop(module_name, None) + warnings.warn( + f"NVTE_ENABLE_PLUGIN=1 but plugin loading failed: {e}", + RuntimeWarning, + stacklevel=2, + ) + def sanity_checks_for_pypi_installation() -> None: """Ensure that package is installed correctly if using PyPI.""" diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 2dc42be18a..cb2e45d3c7 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -148,6 +148,16 @@ _dpa_fp8ds_reduce_amax = os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1") == "1" +# Plugin system: override FlashAttention and get_attention_backend if enabled +if os.environ.get("NVTE_ENABLE_PLUGIN", "0") == "1": + _FlashAttentionNative = FlashAttention + FlashAttention = getattr(tex, "flash_attention", _FlashAttentionNative) + _plugin_get_attention_backend = getattr(tex, "get_attention_backend", None) + if _plugin_get_attention_backend is not None: + dpa_utils._original_get_attention_backend = dpa_utils.get_attention_backend + dpa_utils.get_attention_backend = _plugin_get_attention_backend + + __all__ = ["DotProductAttention"]