diff --git a/ci/tools/validate-release-wheels b/ci/tools/validate-release-wheels index 5757ca17bc..397153ec77 100755 --- a/ci/tools/validate-release-wheels +++ b/ci/tools/validate-release-wheels @@ -9,11 +9,12 @@ from __future__ import annotations import argparse -import re import sys from collections import defaultdict from pathlib import Path +from check_release_notes import parse_version_from_tag + COMPONENT_TO_DISTRIBUTIONS: dict[str, set[str]] = { "cuda-core": {"cuda_core"}, "cuda-bindings": {"cuda_bindings"}, @@ -22,14 +23,16 @@ COMPONENT_TO_DISTRIBUTIONS: dict[str, set[str]] = { "all": {"cuda_core", "cuda_bindings", "cuda_pathfinder", "cuda_python"}, } -TAG_PATTERNS = ( - re.compile(r"^v(?P\d+\.\d+\.\d+)"), - re.compile(r"^cuda-core-v(?P\d+\.\d+\.\d+)"), - re.compile(r"^cuda-pathfinder-v(?P\d+\.\d+\.\d+)"), -) +COMPONENT_TO_TAG_COMPONENTS: dict[str, tuple[str, ...]] = { + "cuda-core": ("cuda-core",), + "cuda-bindings": ("cuda-bindings",), + "cuda-pathfinder": ("cuda-pathfinder",), + "cuda-python": ("cuda-python",), + "all": ("cuda-core", "cuda-bindings", "cuda-pathfinder", "cuda-python"), +} -def parse_args() -> argparse.Namespace: +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: parser = argparse.ArgumentParser( description=( "Validate that wheel versions match the release tag. " @@ -39,18 +42,21 @@ def parse_args() -> argparse.Namespace: parser.add_argument("git_tag", help="Release git tag (for example: v13.0.0)") parser.add_argument("component", choices=sorted(COMPONENT_TO_DISTRIBUTIONS.keys())) parser.add_argument("wheel_dir", help="Directory containing wheel files") - return parser.parse_args() + return parser.parse_args(argv) -def version_from_tag(tag: str) -> str: - for pattern in TAG_PATTERNS: - match = pattern.match(tag) - if match: - return match.group("version") +def version_from_tag(tag: str, component: str) -> str: + versions = { + version + for tag_component in COMPONENT_TO_TAG_COMPONENTS[component] + if (version := parse_version_from_tag(tag, tag_component)) is not None + } + if len(versions) == 1: + return versions.pop() raise ValueError( "Unsupported git tag format " - f"{tag!r}; expected tags beginning with vX.Y.Z, cuda-core-vX.Y.Z, " - "or cuda-pathfinder-vX.Y.Z." + f"{tag!r} for component {component!r}; expected vX.Y.Z, cuda-core-vX.Y.Z, " + "or cuda-pathfinder-vX.Y.Z with a valid release version." ) @@ -62,9 +68,14 @@ def parse_wheel_dist_and_version(path: Path) -> tuple[str, str]: return parts[0], parts[1] -def main() -> int: - args = parse_args() - expected_version = version_from_tag(args.git_tag) +def main(argv: list[str] | None = None) -> int: + args = parse_args(argv) + try: + expected_version = version_from_tag(args.git_tag, args.component) + except ValueError as exc: + print(f"Error: {exc}", file=sys.stderr) + return 1 + expected_distributions = COMPONENT_TO_DISTRIBUTIONS[args.component] wheel_dir = Path(args.wheel_dir)