Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 29 additions & 18 deletions ci/tools/validate-release-wheels
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from __future__ import annotations

import argparse
import re
import sys
from collections import defaultdict
from pathlib import Path

from check_release_notes import parse_version_from_tag

COMPONENT_TO_DISTRIBUTIONS: dict[str, set[str]] = {
"cuda-core": {"cuda_core"},
"cuda-bindings": {"cuda_bindings"},
Expand All @@ -22,14 +23,16 @@ COMPONENT_TO_DISTRIBUTIONS: dict[str, set[str]] = {
"all": {"cuda_core", "cuda_bindings", "cuda_pathfinder", "cuda_python"},
}

TAG_PATTERNS = (
re.compile(r"^v(?P<version>\d+\.\d+\.\d+)"),
re.compile(r"^cuda-core-v(?P<version>\d+\.\d+\.\d+)"),
re.compile(r"^cuda-pathfinder-v(?P<version>\d+\.\d+\.\d+)"),
)
COMPONENT_TO_TAG_COMPONENTS: dict[str, tuple[str, ...]] = {
"cuda-core": ("cuda-core",),
"cuda-bindings": ("cuda-bindings",),
"cuda-pathfinder": ("cuda-pathfinder",),
"cuda-python": ("cuda-python",),
"all": ("cuda-core", "cuda-bindings", "cuda-pathfinder", "cuda-python"),
}


def parse_args() -> argparse.Namespace:
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
parser = argparse.ArgumentParser(
description=(
"Validate that wheel versions match the release tag. "
Expand All @@ -39,18 +42,21 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("git_tag", help="Release git tag (for example: v13.0.0)")
parser.add_argument("component", choices=sorted(COMPONENT_TO_DISTRIBUTIONS.keys()))
parser.add_argument("wheel_dir", help="Directory containing wheel files")
return parser.parse_args()
return parser.parse_args(argv)


def version_from_tag(tag: str) -> str:
for pattern in TAG_PATTERNS:
match = pattern.match(tag)
if match:
return match.group("version")
def version_from_tag(tag: str, component: str) -> str:
versions = {
version
for tag_component in COMPONENT_TO_TAG_COMPONENTS[component]
if (version := parse_version_from_tag(tag, tag_component)) is not None
}
if len(versions) == 1:
return versions.pop()
raise ValueError(
"Unsupported git tag format "
f"{tag!r}; expected tags beginning with vX.Y.Z, cuda-core-vX.Y.Z, "
"or cuda-pathfinder-vX.Y.Z."
f"{tag!r} for component {component!r}; expected vX.Y.Z, cuda-core-vX.Y.Z, "
"or cuda-pathfinder-vX.Y.Z with a valid release version."
)


Expand All @@ -62,9 +68,14 @@ def parse_wheel_dist_and_version(path: Path) -> tuple[str, str]:
return parts[0], parts[1]


def main() -> int:
args = parse_args()
expected_version = version_from_tag(args.git_tag)
def main(argv: list[str] | None = None) -> int:
args = parse_args(argv)
try:
expected_version = version_from_tag(args.git_tag, args.component)
except ValueError as exc:
print(f"Error: {exc}", file=sys.stderr)
return 1

expected_distributions = COMPONENT_TO_DISTRIBUTIONS[args.component]
wheel_dir = Path(args.wheel_dir)

Expand Down