Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/groundlight/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

# Imports from our code
from .client import Groundlight
from .client import GroundlightClientError, ApiTokenError, NotFoundError
from .client import GroundlightClientError, ApiTokenError, NotFoundError, VLMVerificationResult
from .experimental_api import ExperimentalApi
from .binary_labels import Label
from .version import get_version
Expand Down
127 changes: 127 additions & 0 deletions src/groundlight/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import os
import time
import warnings
from dataclasses import dataclass
from functools import partial
from io import BufferedReader, BytesIO
from typing import Any, Callable, List, Optional, Tuple, Union

import requests
from groundlight_openapi_client import Configuration
from groundlight_openapi_client.api.detector_groups_api import DetectorGroupsApi
from groundlight_openapi_client.api.detectors_api import DetectorsApi
Expand Down Expand Up @@ -69,6 +71,22 @@ class ApiTokenError(GroundlightClientError):
pass


@dataclass
class VLMVerificationResult:
    """Outcome of a single VLM alert-verification request sent to the Groundlight cloud API.

    The first seven fields are required; the three usage/cost fields at the end
    are optional and default to ``None``.
    """

    # --- request identification ---
    id: str
    query: str
    model_id: str

    # --- model answer ---
    verdict: str  # one of "YES", "NO", "UNSURE"
    confidence: float  # in the range 0.0–1.0
    reasoning: str
    created_at: str

    # --- usage / cost accounting (``None`` when not supplied) ---
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    total_cost_usd: Optional[float] = None


class Groundlight: # pylint: disable=too-many-instance-attributes,too-many-public-methods
"""
Client for accessing the Groundlight cloud service. Provides methods to create visual detectors,
Expand Down Expand Up @@ -1060,6 +1078,115 @@ def ask_async( # noqa: PLR0913 # pylint: disable=too-many-arguments
inspection_id=inspection_id,
)

def ask_vlm(
    self,
    images: Union[
        "np.ndarray",
        List["np.ndarray"],
        str,
        bytes,
        "Image.Image",
        BytesIO,
        BufferedReader,
    ],
    query: str,
    model_id: Optional[str] = None,
    timeout: float = 15.0,
) -> VLMVerificationResult:
    """Verify one or two images against a natural-language query using a cloud VLM.

    Calls the Groundlight ``POST /v1/vlm-queries`` endpoint. The VLM runs in the
    Groundlight cloud (AWS Bedrock) — no local inference.

    **Example usage**::

        gl = Groundlight()

        # Single-image verification
        result = gl.ask_vlm(images=frame, query="Is there a fire?")
        if result.verdict == "YES":
            emit_alert()

        # Dual-image (full frame + ROI) for better context
        result = gl.ask_vlm(
            images=[full_frame, roi_crop],
            query="Is there a fire in the highlighted region?",
        )
        print(result.confidence, result.reasoning)

    :param images: One image, or a list/tuple of one or two images. When two images
        are provided the first is treated as the **full camera frame** and the second
        as the **cropped region of interest (ROI)**. Accepted formats per image:

        - filename (string) of a JPEG/PNG file
        - raw bytes or BytesIO / BufferedReader
        - numpy array (H, W, 3) in BGR order (OpenCV convention)
        - PIL Image

    :param query: Natural-language prompt describing what to verify, e.g.
        ``"Is there a fire visible in the image? Reason step by step."``
    :param model_id: Friendly alias of the VLM to use, e.g. ``"gpt-5.4"`` or
        ``"claude-sonnet-4.5"``. Must be one of the models supported by the
        server. Defaults to the server-configured default.
    :param timeout: Request timeout in seconds (default 15 s).

    :return: :class:`VLMVerificationResult` with ``verdict`` (``"YES"`` / ``"NO"`` /
        ``"UNSURE"``), ``confidence``, ``reasoning``, and token cost fields.
    :raises ValueError: If no image or more than two images are supplied.
    :raises requests.HTTPError: On non-2xx response from the server.
    """
    # Normalise: a single image becomes a one-element list; tuples are accepted
    # alongside lists so ``(frame, roi)`` is not mistaken for one image.
    if isinstance(images, (list, tuple)):
        images = list(images)
    else:
        images = [images]
    if not images:
        # Fail locally instead of sending a request that carries no image.
        raise ValueError("ask_vlm requires at least 1 image.")
    if len(images) > 2:
        raise ValueError("ask_vlm supports at most 2 images (full frame + ROI).")

    # Convert each image to a byte payload via the existing SDK utility and
    # attach it as a repeated "images" multipart part.
    image_files = [
        ("images", (f"image_{i}.jpg", parse_supported_image_types(img).read(), "image/jpeg"))
        for i, img in enumerate(images)
    ]

    # query and model_id are sent as multipart form fields (not query-string
    # params): the prompt can be long and must not end up in URLs or access logs.
    form_data = {"query": query}
    if model_id:
        form_data["model_id"] = model_id

    headers = {
        "x-api-token": self.api_client.configuration.api_key["ApiToken"],
        "X-Request-Id": f"ask_vlm_{int(time.time() * 1000)}",
        "x-sdk-language": "python",
    }

    # Join with exactly one slash, whether or not the configured endpoint
    # carries a trailing "/".
    url = f"{self.endpoint.rstrip('/')}/v1/vlm-queries"

    resp = requests.post(
        url,
        data=form_data,
        files=image_files,
        headers=headers,
        timeout=timeout,
        verify=self.api_client.configuration.verify_ssl,
    )
    resp.raise_for_status()
    data = resp.json()

    # ``or {}`` treats an explicit JSON null the same as a missing key, so a
    # partial response degrades to defaults instead of raising AttributeError.
    result_block = data.get("result") or {}
    cost_block = data.get("cost") or {}
    return VLMVerificationResult(
        id=data.get("id", ""),
        query=data.get("query", query),
        model_id=data.get("model_id", model_id or ""),
        verdict=result_block.get("verdict", "UNSURE"),
        confidence=float(result_block.get("confidence") or 0.0),
        reasoning=result_block.get("reasoning", ""),
        created_at=data.get("created_at", ""),
        input_tokens=cost_block.get("input_tokens"),
        output_tokens=cost_block.get("output_tokens"),
        total_cost_usd=cost_block.get("total_cost_usd"),
    )

def wait_for_confident_result(
self,
image_query: Union[ImageQuery, str],
Expand Down
114 changes: 114 additions & 0 deletions test/unit/test_ask_vlm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""Unit tests for Groundlight.ask_vlm — mocks HTTP, no live server needed."""

from unittest.mock import MagicMock, patch

import numpy as np
import pytest
from groundlight import Groundlight, VLMVerificationResult


@pytest.fixture
def gl(monkeypatch):
    """Provide a Groundlight client pointed at a fake endpoint, built without network access."""
    monkeypatch.setenv("GROUNDLIGHT_API_TOKEN", "api_fake_test_token")
    # __init__ performs a live /v1/me connectivity check; stub it out for the
    # duration of construction only.
    connectivity_stub = patch.object(Groundlight, "_verify_connectivity", return_value=None)
    with connectivity_stub:
        return Groundlight(endpoint="http://test-server/device-api/")


def _mock_response(
    verdict="YES", confidence=0.92, reasoning="Flames visible.", model_id="us.anthropic.claude-sonnet-4-5-20250929-v1:0"
):
    """Return a MagicMock standing in for a successful (201) VLM-query HTTP response.

    The keyword arguments are placed into the JSON payload's ``result`` block
    (``verdict``, ``confidence``, ``reasoning``) and top-level ``model_id``.
    """
    payload = {
        "id": "vlmq_test123",
        "type": "vlm_query",
        "created_at": "2025-06-17T00:00:00Z",
        "query": "Is there a fire?",
        "model_id": model_id,
        "result": {"verdict": verdict, "confidence": confidence, "reasoning": reasoning},
        "cost": {"input_tokens": 400, "output_tokens": 80, "total_cost_usd": 0.0015},
    }

    fake = MagicMock()
    fake.status_code = 201
    fake.json.return_value = payload
    # Explicit stub so raise_for_status() is a harmless no-op for the caller.
    fake.raise_for_status = MagicMock()
    return fake


class TestAskVlm:
    """Unit tests for ``Groundlight.ask_vlm`` with the HTTP layer mocked out."""

    @patch("groundlight.client.requests")
    def test_returns_vlm_verification_result(self, mock_requests, gl):
        """The JSON payload is mapped onto a VLMVerificationResult."""
        mock_requests.post.return_value = _mock_response()

        result = gl.ask_vlm(images=np.zeros((100, 100, 3), dtype=np.uint8), query="Is there a fire?")

        assert isinstance(result, VLMVerificationResult)
        assert result.verdict == "YES"
        assert result.confidence == pytest.approx(0.92)
        assert result.id == "vlmq_test123"
        assert result.input_tokens == 400
        assert result.total_cost_usd == pytest.approx(0.0015)

    @patch("groundlight.client.requests")
    def test_single_numpy_image_encoded_as_jpeg(self, mock_requests, gl):
        """A single numpy frame is sent as one ``images`` part tagged image/jpeg."""
        mock_requests.post.return_value = _mock_response()
        frame = np.zeros((480, 640, 3), dtype=np.uint8)

        gl.ask_vlm(images=frame, query="Is there a fire?")

        _, kwargs = mock_requests.post.call_args
        files = kwargs["files"]
        assert len(files) == 1
        assert files[0][0] == "images"
        _name, data, ctype = files[0][1]
        assert ctype == "image/jpeg"
        assert len(data) > 0  # bytes were produced

    @patch("groundlight.client.requests")
    def test_dual_images_sends_two_parts(self, mock_requests, gl):
        """Two images result in two multipart file parts."""
        mock_requests.post.return_value = _mock_response()
        frame = np.zeros((480, 640, 3), dtype=np.uint8)
        roi = np.zeros((120, 120, 3), dtype=np.uint8)

        gl.ask_vlm(images=[frame, roi], query="Is there a fire?")

        _, kwargs = mock_requests.post.call_args
        assert len(kwargs["files"]) == 2

    @patch("groundlight.client.requests")
    def test_query_and_model_id_sent_as_form_fields(self, mock_requests, gl):
        """query and model_id travel in the form body, not the URL."""
        mock_requests.post.return_value = _mock_response(model_id="nova-pro")

        gl.ask_vlm(images=np.zeros((100, 100, 3), dtype=np.uint8), query="Is there a fire?", model_id="nova-pro")

        _, kwargs = mock_requests.post.call_args
        # Text fields go in the multipart body, never the URL query string.
        assert kwargs["data"]["query"] == "Is there a fire?"
        assert kwargs["data"]["model_id"] == "nova-pro"
        assert "params" not in kwargs or not kwargs["params"]

    @patch("groundlight.client.requests")
    def test_no_model_id_omits_field(self, mock_requests, gl):
        """When model_id is not given, the form body carries no model_id key."""
        mock_requests.post.return_value = _mock_response()

        gl.ask_vlm(images=np.zeros((100, 100, 3), dtype=np.uint8), query="test")

        _, kwargs = mock_requests.post.call_args
        assert "model_id" not in kwargs["data"]
        assert kwargs["data"]["query"] == "test"

    def test_more_than_two_images_raises(self, gl):
        """More than two images is rejected before any request is made."""
        frame = np.zeros((100, 100, 3), dtype=np.uint8)
        with pytest.raises(ValueError, match="at most 2"):
            gl.ask_vlm(images=[frame, frame, frame], query="test")

    @patch("groundlight.client.requests")
    def test_bytes_image_accepted(self, mock_requests, gl):
        """Raw JPEG bytes are accepted as an image and forwarded to the server."""
        mock_requests.post.return_value = _mock_response()

        # Obtain genuinely valid JPEG bytes without extra dependencies: let the
        # SDK encode a numpy frame and capture what it would have uploaded.
        gl.ask_vlm(images=np.zeros((100, 100, 3), dtype=np.uint8), query="test")
        _, kwargs = mock_requests.post.call_args
        jpeg_bytes = kwargs["files"][0][1][1]
        mock_requests.post.reset_mock()  # keeps the configured return_value

        result = gl.ask_vlm(images=jpeg_bytes, query="test")

        assert isinstance(result, VLMVerificationResult)
        assert mock_requests.post.call_count == 1
        _, kwargs = mock_requests.post.call_args
        assert len(kwargs["files"]) == 1
        _name, data, ctype = kwargs["files"][0][1]
        assert ctype == "image/jpeg"
        assert len(data) > 0
Loading