From 4c6083e7b8ff711ba9186f67907353212ba41c93 Mon Sep 17 00:00:00 2001 From: Luke Schaefer Date: Tue, 12 May 2026 13:18:25 -0500 Subject: [PATCH 1/8] add eval capes to sdk --- docs/index.rst | 34 +++ nucleus/__init__.py | 67 ++++++ nucleus/data_transfer_object/evaluation_v2.py | 145 ++++++++++++ nucleus/evaluation_v2.py | 210 ++++++++++++++++++ tests/test_evaluation_v2.py | 170 ++++++++++++++ 5 files changed, 626 insertions(+) create mode 100644 nucleus/data_transfer_object/evaluation_v2.py create mode 100644 nucleus/evaluation_v2.py create mode 100644 tests/test_evaluation_v2.py diff --git a/docs/index.rst b/docs/index.rst index 698ef59a..a33f704f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,6 +12,40 @@ Scale Nucleus helps you: Nucleus is a new way—the right way—to develop ML models, helping us move away from the concept of one dataset and towards a paradigm of collections of scenarios. +.. _evaluations-v2: + +Evaluations V2 +-------------- + +Evaluation V2 runs COCO-style metrics against stored matches (``evaluation_match_v2``) for a **model run**. +Create an evaluation with :meth:`NucleusClient.create_evaluation_v2`; poll with +:meth:`nucleus.evaluation_v2.EvaluationV2.wait_for_completion`; then fetch aggregates via +:meth:`nucleus.evaluation_v2.EvaluationV2.charts` or per-row examples via +:meth:`nucleus.evaluation_v2.EvaluationV2.examples`. + +.. code-block:: python + + import nucleus + + client = nucleus.NucleusClient(api_key="YOUR_API_KEY") + evaluation = client.create_evaluation_v2( + model_run_id="run_xxx", + name="my-eval", + allowed_label_matches=[ + nucleus.AllowedLabelMatch( + ground_truth_label="car", + model_prediction_label="vehicle", + ), + ], + ) + evaluation.wait_for_completion() + charts = evaluation.charts(iou_threshold=0.5) + fps = evaluation.examples(match_type="FP", limit=20) + +The API uses REST endpoints ``/nucleus/modelRun/:id/evaluationsV2``, +``/nucleus/evaluationsV2/:id``, ``/nucleus/evaluationsV2/:id/charts``, and +``POST /nucleus/evaluationsV2/:id/examples``. + .. _installation: Installation diff --git a/nucleus/__init__.py b/nucleus/__init__.py index d7ee51db..6d675433 100644 --- a/nucleus/__init__.py +++ b/nucleus/__init__.py @@ -2,6 +2,7 @@ __all__ = [ "AsyncJob", + "AllowedLabelMatch", "EmbeddingsExportJob", "BoxAnnotation", "DeduplicationJob", @@ -17,6 +18,12 @@ "DatasetInfo", "DatasetItem", "DatasetItemRetrievalError", + "EvaluationV2", + "EvaluationV2Charts", + "EvaluationV2ExamplesPage", + "EvaluationV2FilterArgs", + "EvaluationV2MatchExample", + "EvaluationV2Status", "Frame", "Keypoint", "KeypointsAnnotation", @@ -129,6 +136,12 @@ ) from .data_transfer_object.dataset_details import DatasetDetails from .data_transfer_object.dataset_info import DatasetInfo +from .data_transfer_object.evaluation_v2 import ( + EvaluationV2Charts, + EvaluationV2ExamplesPage, + EvaluationV2FilterArgs, + EvaluationV2MatchExample, +) from .data_transfer_object.job_status import JobInfoRequestPayload from .dataset import Dataset from .dataset_item import DatasetItem @@ -138,6 +151,7 @@ DeduplicationStats, ) from .deprecation_warning import deprecated +from .evaluation_v2 import AllowedLabelMatch, EvaluationV2, EvaluationV2Status from .errors import ( DatasetItemRetrievalError, ModelCreationError, @@ -875,6 +889,59 @@ def commit_model_run( payload = {} return self.make_request(payload, f"modelRun/{model_run_id}/commit") + def create_evaluation_v2( + self, + model_run_id: str, + *, + name: Optional[str] = None, + allowed_label_matches: Optional[List[AllowedLabelMatch]] = None, + allowed_label_matches_id: Optional[str] = None, + ) -> EvaluationV2: + """Create an Evaluation V2 job for a model run. + + Starts a Temporal workflow that fills ``evaluation_match_v2``. Use + :meth:`EvaluationV2.wait_for_completion` then :meth:`EvaluationV2.charts` + or :meth:`EvaluationV2.examples` for results. + + Parameters: + model_run_id: Nucleus model run id (``run_*``). + name: Optional human-readable name. + allowed_label_matches: Optional explicit allowed label pairs; omit to use + the model run's default configuration. + allowed_label_matches_id: Optional existing allowed-label-matches config id. + + Returns: + :class:`EvaluationV2` loaded via ``GET /nucleus/evaluationsV2/:id``. + """ + payload: Dict[str, Any] = {} + if name is not None: + payload["name"] = name + if allowed_label_matches: + payload[ + "allowed_label_matches" + ] = [m.to_api_dict() for m in allowed_label_matches] + if allowed_label_matches_id is not None: + payload["allowed_label_matches_id"] = allowed_label_matches_id + result = self.make_request( + payload, f"modelRun/{model_run_id}/evaluationsV2" + ) + eval_id = result.get("evaluation_id") + if not eval_id: + raise RuntimeError(f"Unexpected create evaluation V2 response: {result}") + return self.get_evaluation_v2(str(eval_id)) + + def get_evaluation_v2(self, evaluation_id: str) -> EvaluationV2: + """Fetch a single Evaluation V2 row.""" + data = self.get(f"evaluationsV2/{evaluation_id}") + return EvaluationV2.from_json(data, self) + + def list_evaluations_v2(self, model_run_id: str) -> List[EvaluationV2]: + """List Evaluation V2 rows for a model run (newest first).""" + rows = self.get(f"modelRun/{model_run_id}/evaluationsV2") + if not isinstance(rows, list): + return [] + return [EvaluationV2.from_json(r, self) for r in rows] + @deprecated(msg="Prefer calling Dataset.info() directly.") def dataset_info(self, dataset_id: str): dataset = self.get_dataset(dataset_id) diff --git a/nucleus/data_transfer_object/evaluation_v2.py b/nucleus/data_transfer_object/evaluation_v2.py new file mode 100644 index 00000000..7524e32e --- /dev/null +++ b/nucleus/data_transfer_object/evaluation_v2.py @@ -0,0 +1,145 @@ +"""Pydantic models for Nucleus Evaluations V2 REST payloads.""" + +from typing import Any, Dict, List, Literal, Optional + +from nucleus.pydantic_base import DictCompatibleModel + + +class RangeNum(DictCompatibleModel): + min: Optional[float] = None + max: Optional[float] = None + + +class MetadataPredicate(DictCompatibleModel): + key: str + op: Literal["EQ", "IN", "GT", "LT"] + value: Optional[Any] = None + + +class EvaluationV2FilterArgs(DictCompatibleModel): + """Filter object for charts/examples calls (mirrors server evaluation_v2 SQL filters).""" + + confidence_range: Optional[RangeNum] = None + iou_range: Optional[RangeNum] = None + pred_labels: Optional[List[str]] = None + gt_labels: Optional[List[str]] = None + item_metadata: Optional[List[MetadataPredicate]] = None + prediction_metadata: Optional[List[MetadataPredicate]] = None + label_equality: Optional[Literal["EQ", "NEQ"]] = None + has_ground_truth: Optional[bool] = None + tide_background: Optional[bool] = None + + def to_api_filters(self) -> Dict[str, Any]: + """Serialize to camelCase keys expected by the GraphQL / REST layer.""" + d = self.dict(exclude_none=True) + # pydantic v1 uses snake_case fields; server expects camelCase in JSON filters + out: Dict[str, Any] = {} + if "confidence_range" in d: + out["confidenceRange"] = d["confidence_range"] + if "iou_range" in d: + out["iouRange"] = d["iou_range"] + if "pred_labels" in d: + out["predLabels"] = d["pred_labels"] + if "gt_labels" in d: + out["gtLabels"] = d["gt_labels"] + if "item_metadata" in d: + out["itemMetadata"] = d["item_metadata"] + if "prediction_metadata" in d: + out["predictionMetadata"] = d["prediction_metadata"] + if "label_equality" in d: + out["labelEquality"] = d["label_equality"] + if "has_ground_truth" in d: + out["hasGroundTruth"] = d["has_ground_truth"] + if "tide_background" in d: + out["tideBackground"] = d["tide_background"] + return out + + +class MapSummary(DictCompatibleModel): + mapAt50: Optional[float] = None + mapAt75: Optional[float] = None + mapAt5095: Optional[float] = None + + +class PerClassAp(DictCompatibleModel): + classLabel: str + ap: float + + +class ConfusionEntry(DictCompatibleModel): + gtLabel: str + predLabel: str + count: int + + +class ScoreHistogramBucket(DictCompatibleModel): + bucketMin: float + bucketMax: float + count: int + + +class TotalCounts(DictCompatibleModel): + tp: int + fp: int + fn: int + predsWithConfidence: int + + +class ApBySize(DictCompatibleModel): + small: Optional[float] = None + medium: Optional[float] = None + large: Optional[float] = None + + +class PrCurvePoint(DictCompatibleModel): + classLabel: str + recall: float + precision: float + + +class TideAttribution(DictCompatibleModel): + truePositive: int + localization: int + classification: int + both: int + duplicate: int + background: int + missed: int + + +class EvaluationV2Charts(DictCompatibleModel): + mapSummary: MapSummary + perClassAp: List[PerClassAp] + confusionMatrix: List[ConfusionEntry] + scoreHistogram: List[ScoreHistogramBucket] + computedIouRanges: List[float] + totalCounts: TotalCounts + apBySize: ApBySize + prCurve: List[PrCurvePoint] + tideAttribution: TideAttribution + + +class EvaluationV2MatchExample(DictCompatibleModel): + id: str + evaluation_id: str + dataset_item_id: str + model_prediction_id: Optional[str] = None + ground_truth_annotation_id: Optional[str] = None + pred_canonical_label: Optional[str] = None + gt_canonical_label: Optional[str] = None + pred_raw_label: Optional[str] = None + gt_raw_label: Optional[str] = None + iou: float + confidence: Optional[float] = None + true_positive: bool + match_type: str + gt_area: Optional[float] = None + item_metadata: Dict[str, Any] + prediction_metadata: Dict[str, Any] + prediction_row: Optional[Dict[str, Any]] = None + annotation_row: Optional[Dict[str, Any]] = None + + +class EvaluationV2ExamplesPage(DictCompatibleModel): + rows: List[EvaluationV2MatchExample] + total: int diff --git a/nucleus/evaluation_v2.py b/nucleus/evaluation_v2.py new file mode 100644 index 00000000..eaee103a --- /dev/null +++ b/nucleus/evaluation_v2.py @@ -0,0 +1,210 @@ +"""Nucleus Evaluation V2 — COCO-style metrics computed off ``evaluation_match_v2``.""" + +from __future__ import annotations + +import json +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from urllib.parse import urlencode + +import requests + +from nucleus.data_transfer_object.evaluation_v2 import ( + EvaluationV2Charts, + EvaluationV2ExamplesPage, + EvaluationV2FilterArgs, +) +from nucleus.errors import NucleusAPIError + +if TYPE_CHECKING: + from nucleus import NucleusClient + + +class EvaluationV2Status(str, Enum): + """Lifecycle states for ``nucleus.evaluation_v2.status``.""" + + PENDING = "pending" + COMPUTING = "computing" + SUCCEEDED = "succeeded" + FAILED = "failed" + CANCELLED = "cancelled" + + +@dataclass +class AllowedLabelMatch: + """Pair of labels that may match for IoU evaluation (snake_case JSON for the API).""" + + ground_truth_label: str + model_prediction_label: str + + def to_api_dict(self) -> Dict[str, str]: + return { + "ground_truth_label": self.ground_truth_label, + "model_prediction_label": self.model_prediction_label, + } + + +@dataclass +class EvaluationV2: + """A single Evaluation V2 run for a model run (``evalv2_*``).""" + + id: str + model_run_id: str + dataset_id: str + status: str + name: Optional[str] = None + temporal_workflow_id: Optional[str] = None + error_message: Optional[str] = None + created_at: Optional[str] = None + allowed_label_matches_id: Optional[str] = None + allowed_label_matches: Optional[List[AllowedLabelMatch]] = None + allowed_label_matches_name: Optional[str] = None + _client: Any = field(repr=False, default=None) + + @classmethod + def from_json( + cls, + payload: Dict[str, Any], + client: Optional["NucleusClient"] = None, + ) -> "EvaluationV2": + raw_matches = payload.get("allowed_label_matches") + matches: Optional[List[AllowedLabelMatch]] = None + if isinstance(raw_matches, list): + matches = [] + for m in raw_matches: + if not isinstance(m, dict): + continue + gt = m.get("groundTruthLabel") or m.get("ground_truth_label") + mp = m.get("modelPredictionLabel") or m.get("model_prediction_label") + if gt is not None and mp is not None: + matches.append( + AllowedLabelMatch( + ground_truth_label=str(gt), + model_prediction_label=str(mp), + ) + ) + + return cls( + id=str(payload["id"]), + model_run_id=str(payload["model_run_id"]), + dataset_id=str(payload["dataset_id"]), + status=str(payload["status"]), + name=payload.get("name"), + temporal_workflow_id=payload.get("temporal_workflow_id"), + error_message=payload.get("error_message"), + created_at=payload.get("created_at"), + allowed_label_matches_id=payload.get("allowed_label_matches_id"), + allowed_label_matches=matches, + allowed_label_matches_name=payload.get("allowed_label_matches_name"), + _client=client, + ) + + def refresh(self) -> "EvaluationV2": + """Reload this evaluation from ``GET /nucleus/evaluationsV2/:id``.""" + if self._client is None: + raise RuntimeError("EvaluationV2 has no client; use NucleusClient.get_evaluation_v2.") + data = self._client.get(f"evaluationsV2/{self.id}") + updated = EvaluationV2.from_json(data, self._client) + self.__dict__.update(updated.__dict__) + return self + + def wait_for_completion( + self, + timeout_sec: float = 600, + poll_interval: float = 5, + ) -> "EvaluationV2": + """Poll until status is terminal or ``timeout_sec`` elapses. + + Raises: + RuntimeError: on ``failed`` status or timeout. + """ + deadline = time.monotonic() + timeout_sec + terminal_ok = {"succeeded", "cancelled"} + while time.monotonic() < deadline: + self.refresh() + if self.status == "failed": + raise RuntimeError( + f"Evaluation {self.id} failed: {self.error_message or 'unknown'}" + ) + if self.status in terminal_ok: + return self + time.sleep(poll_interval) + raise RuntimeError( + f"Timed out after {timeout_sec}s waiting for evaluation {self.id} " + f"(last status: {self.status})" + ) + + def delete(self) -> None: + """Cancel workflow (best effort) and soft-delete (``204 No Content``).""" + if self._client is None: + raise RuntimeError("EvaluationV2 has no client.") + resp = self._client.make_request( + {}, + f"evaluationsV2/{self.id}", + requests_command=requests.delete, + return_raw_response=True, + ) + if resp.status_code != 204: + raise NucleusAPIError( + f"{self._client.endpoint}/evaluationsV2/{self.id}", + requests.delete, + resp, + ) + + def charts( + self, + iou_threshold: float = 0.5, + filters: Optional[Union[EvaluationV2FilterArgs, Dict[str, Any]]] = None, + query: Optional[str] = None, + ) -> EvaluationV2Charts: + """Aggregate metrics (mAP, confusion matrix, PR curve, TIDE, …).""" + if self._client is None: + raise RuntimeError("EvaluationV2 has no client.") + params: Dict[str, str] = {} + params["iouThreshold"] = str(iou_threshold) + if filters is not None: + if isinstance(filters, EvaluationV2FilterArgs): + filt_dict = filters.to_api_filters() + else: + filt_dict = filters + params["filters"] = json.dumps(filt_dict) + if query: + params["query"] = query + qs = urlencode(params) + route = f"evaluationsV2/{self.id}/charts?{qs}" + data = self._client.get(route) + return EvaluationV2Charts.parse_obj(data) + + def examples( + self, + match_type: str, + limit: int = 50, + offset: int = 0, + sort_by: Optional[str] = None, + sort_order: Optional[str] = None, + filters: Optional[Union[EvaluationV2FilterArgs, Dict[str, Any]]] = None, + query: Optional[str] = None, + ) -> EvaluationV2ExamplesPage: + """Paginated TP / FP / FN match rows with prediction and annotation blobs.""" + if self._client is None: + raise RuntimeError("EvaluationV2 has no client.") + payload: Dict[str, Any] = { + "match_type": match_type, + "limit": limit, + "offset": offset, + } + if sort_by is not None: + payload["sort_by"] = sort_by + if sort_order is not None: + payload["sort_order"] = sort_order + if filters is not None: + if isinstance(filters, EvaluationV2FilterArgs): + payload["filters"] = filters.to_api_filters() + else: + payload["filters"] = filters + if query: + payload["query"] = query + data = self._client.post(payload, f"evaluationsV2/{self.id}/examples") + return EvaluationV2ExamplesPage.parse_obj(data) diff --git a/tests/test_evaluation_v2.py b/tests/test_evaluation_v2.py new file mode 100644 index 00000000..c8429cf7 --- /dev/null +++ b/tests/test_evaluation_v2.py @@ -0,0 +1,170 @@ +"""Unit tests for Evaluations V2 client (no live API).""" + +from unittest.mock import MagicMock + +import requests + +from nucleus import AllowedLabelMatch, EvaluationV2, NucleusClient +from nucleus.data_transfer_object.evaluation_v2 import EvaluationV2Charts + + +def test_allowed_label_match_to_api_dict(): + m = AllowedLabelMatch(ground_truth_label="a", model_prediction_label="b") + assert m.to_api_dict() == { + "ground_truth_label": "a", + "model_prediction_label": "b", + } + + +def test_evaluation_v2_from_json_with_matches(): + client = NucleusClient(api_key="k") + payload = { + "id": "evalv2_1", + "model_run_id": "run_1", + "dataset_id": "ds_1", + "status": "pending", + "allowed_label_matches": [ + {"groundTruthLabel": "x", "modelPredictionLabel": "y"}, + ], + } + ev = EvaluationV2.from_json(payload, client) + assert ev.id == "evalv2_1" + assert ev.allowed_label_matches is not None + assert len(ev.allowed_label_matches) == 1 + assert ev.allowed_label_matches[0].ground_truth_label == "x" + + +def test_create_evaluation_v2_then_get(): + client = NucleusClient(api_key="test") + client.connection.make_request = MagicMock( + return_value={ + "evaluation_id": "evalv2_new", + "status": "pending", + "workflow_id": "w", + } + ) + client.connection.get = MagicMock( + return_value={ + "id": "evalv2_new", + "model_run_id": "run_1", + "dataset_id": "ds_1", + "status": "pending", + } + ) + + ev = client.create_evaluation_v2( + "run_1", + name="n1", + allowed_label_matches=[ + AllowedLabelMatch("gt", "pred"), + ], + ) + assert ev.id == "evalv2_new" + client.connection.make_request.assert_called_once() + client.connection.get.assert_called_once_with("evaluationsV2/evalv2_new") + + +def test_charts_get_query_string(): + client = MagicMock(spec=NucleusClient) + client.get.return_value = { + "mapSummary": {"mapAt50": 0.1, "mapAt75": 0.2, "mapAt5095": 0.15}, + "perClassAp": [], + "confusionMatrix": [], + "scoreHistogram": [], + "computedIouRanges": [], + "totalCounts": {"tp": 0, "fp": 0, "fn": 0, "predsWithConfidence": 0}, + "apBySize": {"small": None, "medium": None, "large": None}, + "prCurve": [], + "tideAttribution": { + "truePositive": 0, + "localization": 0, + "classification": 0, + "both": 0, + "duplicate": 0, + "background": 0, + "missed": 0, + }, + } + ev = EvaluationV2( + id="evalv2_1", + model_run_id="run_1", + dataset_id="ds_1", + status="succeeded", + _client=client, + ) + charts = ev.charts(iou_threshold=0.5) + assert isinstance(charts, EvaluationV2Charts) + call_route = client.get.call_args[0][0] + assert "evaluationsV2/evalv2_1/charts" in call_route + assert "iouThreshold=0.5" in call_route + + +def test_examples_post_body(): + client = MagicMock(spec=NucleusClient) + client.post.return_value = {"rows": [], "total": 0} + ev = EvaluationV2( + id="evalv2_1", + model_run_id="run_1", + dataset_id="ds_1", + status="succeeded", + _client=client, + ) + page = ev.examples("TP", limit=20, offset=5) + assert page.total == 0 + client.post.assert_called_once() + args, kwargs = client.post.call_args + payload, route = args + assert route == "evaluationsV2/evalv2_1/examples" + assert payload["match_type"] == "TP" + assert payload["limit"] == 20 + assert payload["offset"] == 5 + + +def test_wait_for_completion(): + client = NucleusClient(api_key="test") + client.connection.get = MagicMock( + side_effect=[ + { + "id": "evalv2_1", + "model_run_id": "run_1", + "dataset_id": "ds_1", + "status": "pending", + }, + { + "id": "evalv2_1", + "model_run_id": "run_1", + "dataset_id": "ds_1", + "status": "succeeded", + }, + ] + ) + ev = EvaluationV2( + id="evalv2_1", + model_run_id="run_1", + dataset_id="ds_1", + status="pending", + _client=client, + ) + ev.wait_for_completion(timeout_sec=5, poll_interval=0.01) + assert ev.status == "succeeded" + + +def test_delete_204(): + client = NucleusClient(api_key="test") + resp = MagicMock() + resp.status_code = 204 + client.connection.make_request = MagicMock(return_value=resp) + ev = EvaluationV2( + id="evalv2_1", + model_run_id="run_1", + dataset_id="ds_1", + status="succeeded", + _client=client, + ) + ev.delete() + assert client.connection.make_request.call_count == 1 + cargs = client.connection.make_request.call_args + assert cargs[0][0] == {} + assert cargs[0][1] == "evaluationsV2/evalv2_1" + assert cargs[0][2] is requests.delete + assert cargs[0][3] is True From 36f6b4aef3e244e5194aa52898232d6027619edd Mon Sep 17 00:00:00 2001 From: Luke Schaefer Date: Tue, 12 May 2026 13:49:22 -0500 Subject: [PATCH 2/8] Apply suggestion from @greptile-apps[bot] Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- nucleus/data_transfer_object/evaluation_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nucleus/data_transfer_object/evaluation_v2.py b/nucleus/data_transfer_object/evaluation_v2.py index 7524e32e..6150aee3 100644 --- a/nucleus/data_transfer_object/evaluation_v2.py +++ b/nucleus/data_transfer_object/evaluation_v2.py @@ -129,7 +129,7 @@ class EvaluationV2MatchExample(DictCompatibleModel): gt_canonical_label: Optional[str] = None pred_raw_label: Optional[str] = None gt_raw_label: Optional[str] = None - iou: float + iou: Optional[float] = None confidence: Optional[float] = None true_positive: bool match_type: str From 3caaf8d336ac1a7bec27e85fe3095d2419e01018 Mon Sep 17 00:00:00 2001 From: Luke Schaefer Date: Tue, 12 May 2026 13:49:34 -0500 Subject: [PATCH 3/8] Apply suggestion from @greptile-apps[bot] Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- nucleus/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nucleus/__init__.py b/nucleus/__init__.py index 6d675433..b3ad1297 100644 --- a/nucleus/__init__.py +++ b/nucleus/__init__.py @@ -916,7 +916,7 @@ def create_evaluation_v2( payload: Dict[str, Any] = {} if name is not None: payload["name"] = name - if allowed_label_matches: + if allowed_label_matches is not None: payload[ "allowed_label_matches" ] = [m.to_api_dict() for m in allowed_label_matches] From 13a91b2c057761b40ce9ce40a8829efff6837adf Mon Sep 17 00:00:00 2001 From: Luke Schaefer Date: Tue, 12 May 2026 14:03:04 -0500 Subject: [PATCH 4/8] run hooks --- nucleus/__init__.py | 12 +++++++----- nucleus/evaluation_v2.py | 20 +++++++++++++++----- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/nucleus/__init__.py b/nucleus/__init__.py index 6d675433..d3995ba8 100644 --- a/nucleus/__init__.py +++ b/nucleus/__init__.py @@ -151,7 +151,6 @@ DeduplicationStats, ) from .deprecation_warning import deprecated -from .evaluation_v2 import AllowedLabelMatch, EvaluationV2, EvaluationV2Status from .errors import ( DatasetItemRetrievalError, ModelCreationError, @@ -160,6 +159,7 @@ NotFoundError, NucleusAPIError, ) +from .evaluation_v2 import AllowedLabelMatch, EvaluationV2, EvaluationV2Status from .job import CustomerJobTypes from .model import Model from .model_run import ModelRun @@ -917,9 +917,9 @@ def create_evaluation_v2( if name is not None: payload["name"] = name if allowed_label_matches: - payload[ - "allowed_label_matches" - ] = [m.to_api_dict() for m in allowed_label_matches] + payload["allowed_label_matches"] = [ + m.to_api_dict() for m in allowed_label_matches + ] if allowed_label_matches_id is not None: payload["allowed_label_matches_id"] = allowed_label_matches_id result = self.make_request( @@ -927,7 +927,9 @@ def create_evaluation_v2( ) eval_id = result.get("evaluation_id") if not eval_id: - raise RuntimeError(f"Unexpected create evaluation V2 response: {result}") + raise RuntimeError( + f"Unexpected create evaluation V2 response: {result}" + ) return self.get_evaluation_v2(str(eval_id)) def get_evaluation_v2(self, evaluation_id: str) -> EvaluationV2: diff --git a/nucleus/evaluation_v2.py b/nucleus/evaluation_v2.py index eaee103a..94191f31 100644 --- a/nucleus/evaluation_v2.py +++ b/nucleus/evaluation_v2.py @@ -77,7 +77,9 @@ def from_json( if not isinstance(m, dict): continue gt = m.get("groundTruthLabel") or m.get("ground_truth_label") - mp = m.get("modelPredictionLabel") or m.get("model_prediction_label") + mp = m.get("modelPredictionLabel") or m.get( + "model_prediction_label" + ) if gt is not None and mp is not None: matches.append( AllowedLabelMatch( @@ -97,14 +99,18 @@ def from_json( created_at=payload.get("created_at"), allowed_label_matches_id=payload.get("allowed_label_matches_id"), allowed_label_matches=matches, - allowed_label_matches_name=payload.get("allowed_label_matches_name"), + allowed_label_matches_name=payload.get( + "allowed_label_matches_name" + ), _client=client, ) def refresh(self) -> "EvaluationV2": """Reload this evaluation from ``GET /nucleus/evaluationsV2/:id``.""" if self._client is None: - raise RuntimeError("EvaluationV2 has no client; use NucleusClient.get_evaluation_v2.") + raise RuntimeError( + "EvaluationV2 has no client; use NucleusClient.get_evaluation_v2." + ) data = self._client.get(f"evaluationsV2/{self.id}") updated = EvaluationV2.from_json(data, self._client) self.__dict__.update(updated.__dict__) @@ -156,7 +162,9 @@ def delete(self) -> None: def charts( self, iou_threshold: float = 0.5, - filters: Optional[Union[EvaluationV2FilterArgs, Dict[str, Any]]] = None, + filters: Optional[ + Union[EvaluationV2FilterArgs, Dict[str, Any]] + ] = None, query: Optional[str] = None, ) -> EvaluationV2Charts: """Aggregate metrics (mAP, confusion matrix, PR curve, TIDE, …).""" @@ -184,7 +192,9 @@ def examples( offset: int = 0, sort_by: Optional[str] = None, sort_order: Optional[str] = None, - filters: Optional[Union[EvaluationV2FilterArgs, Dict[str, Any]]] = None, + filters: Optional[ + Union[EvaluationV2FilterArgs, Dict[str, Any]] + ] = None, query: Optional[str] = None, ) -> EvaluationV2ExamplesPage: """Paginated TP / FP / FN match rows with prediction and annotation blobs.""" From aced4aab70192915fa0c67a9d3d2e35ba3d985a2 Mon Sep 17 00:00:00 2001 From: Luke Schaefer Date: Tue, 12 May 2026 14:05:32 -0500 Subject: [PATCH 5/8] Update nucleus/data_transfer_object/evaluation_v2.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- nucleus/data_transfer_object/evaluation_v2.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nucleus/data_transfer_object/evaluation_v2.py b/nucleus/data_transfer_object/evaluation_v2.py index 6150aee3..18607ddc 100644 --- a/nucleus/data_transfer_object/evaluation_v2.py +++ b/nucleus/data_transfer_object/evaluation_v2.py @@ -134,8 +134,8 @@ class EvaluationV2MatchExample(DictCompatibleModel): true_positive: bool match_type: str gt_area: Optional[float] = None - item_metadata: Dict[str, Any] - prediction_metadata: Dict[str, Any] + item_metadata: Optional[Dict[str, Any]] = None + prediction_metadata: Optional[Dict[str, Any]] = None prediction_row: Optional[Dict[str, Any]] = None annotation_row: Optional[Dict[str, Any]] = None From 866ac71918b65a229310592fa8b1a34df702e8ae Mon Sep 17 00:00:00 2001 From: Luke Schaefer Date: Tue, 12 May 2026 14:54:55 -0500 Subject: [PATCH 6/8] fix p1 --- nucleus/evaluation_v2.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/nucleus/evaluation_v2.py b/nucleus/evaluation_v2.py index 94191f31..3fdf5dca 100644 --- a/nucleus/evaluation_v2.py +++ b/nucleus/evaluation_v2.py @@ -76,10 +76,12 @@ def from_json( for m in raw_matches: if not isinstance(m, dict): continue - gt = m.get("groundTruthLabel") or m.get("ground_truth_label") - mp = m.get("modelPredictionLabel") or m.get( - "model_prediction_label" - ) + gt = m.get("groundTruthLabel") + if gt is None: + gt = m.get("ground_truth_label") + mp = m.get("modelPredictionLabel") + if mp is None: + mp = m.get("model_prediction_label") if gt is not None and mp is not None: matches.append( AllowedLabelMatch( From 658216319ac2413994a3d7ce03fc3559db647268 Mon Sep 17 00:00:00 2001 From: Luke Schaefer Date: Thu, 28 May 2026 17:31:01 -0500 Subject: [PATCH 7/8] address comments --- CHANGELOG.md | 5 + docs/index.rst | 12 +-- nucleus/__init__.py | 41 ++++--- nucleus/data_transfer_object/evaluation_v2.py | 65 ++++++----- nucleus/evaluation_v2.py | 78 +++++++++----- pyproject.toml | 2 +- tests/test_evaluation_v2.py | 101 +++++++++++++++++- 7 files changed, 231 insertions(+), 73 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 019af44e..486ff13a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,11 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.18.3](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.18.3) - 2026-05-28 + +### Added +- **Evaluations V2** client support for COCO-style metrics on model runs via stored `evaluation_match_v2` rows. `NucleusClient` exposes `create_evaluation_v2()`, `get_evaluation_v2()`, and `list_evaluations_v2()`. The `EvaluationV2` resource supports `wait_for_completion()`, `charts()` (mAP, confusion matrix, PR curve, TIDE, and related aggregates), `examples()` (paginated TP/FP/FN rows), `delete()`, and `refresh()`. `AllowedLabelMatch` configures allowed ground-truth / prediction label pairs; filter and response types include `EvaluationV2FilterArgs`, `EvaluationV2Charts`, `EvaluationV2ExamplesPage`, and `EvaluationV2MatchExample`. Sphinx docs cover the workflow under Evaluations V2. + ## [0.18.2](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.18.2) - 2026-05-08 ### Added diff --git a/docs/index.rst b/docs/index.rst index a33f704f..88ec8c3d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,10 +17,10 @@ Nucleus is a new way—the right way—to develop ML models, helping us move awa Evaluations V2 -------------- -Evaluation V2 runs COCO-style metrics against stored matches (``evaluation_match_v2``) for a **model run**. -Create an evaluation with :meth:`NucleusClient.create_evaluation_v2`; poll with -:meth:`nucleus.evaluation_v2.EvaluationV2.wait_for_completion`; then fetch aggregates via -:meth:`nucleus.evaluation_v2.EvaluationV2.charts` or per-row examples via +Evaluation V2 measures how well a **model run** matches ground-truth annotations. +Create a run with :meth:`NucleusClient.create_evaluation_v2`, wait with +:meth:`nucleus.evaluation_v2.EvaluationV2.wait_for_completion`, then read summary metrics with +:meth:`nucleus.evaluation_v2.EvaluationV2.charts` or individual matches with :meth:`nucleus.evaluation_v2.EvaluationV2.examples`. .. code-block:: python @@ -42,10 +42,6 @@ Create an evaluation with :meth:`NucleusClient.create_evaluation_v2`; poll with charts = evaluation.charts(iou_threshold=0.5) fps = evaluation.examples(match_type="FP", limit=20) -The API uses REST endpoints ``/nucleus/modelRun/:id/evaluationsV2``, -``/nucleus/evaluationsV2/:id``, ``/nucleus/evaluationsV2/:id/charts``, and -``POST /nucleus/evaluationsV2/:id/examples``. - .. _installation: Installation diff --git a/nucleus/__init__.py b/nucleus/__init__.py index 4e1ea270..8e551987 100644 --- a/nucleus/__init__.py +++ b/nucleus/__init__.py @@ -897,21 +897,20 @@ def create_evaluation_v2( allowed_label_matches: Optional[List[AllowedLabelMatch]] = None, allowed_label_matches_id: Optional[str] = None, ) -> EvaluationV2: - """Create an Evaluation V2 job for a model run. + """Create an evaluation for a model run. - Starts a Temporal workflow that fills ``evaluation_match_v2``. Use - :meth:`EvaluationV2.wait_for_completion` then :meth:`EvaluationV2.charts` - or :meth:`EvaluationV2.examples` for results. + The evaluation runs in the background. Call + :meth:`EvaluationV2.wait_for_completion`, then + :meth:`EvaluationV2.charts` or :meth:`EvaluationV2.examples` for results. Parameters: - model_run_id: Nucleus model run id (``run_*``). - name: Optional human-readable name. - allowed_label_matches: Optional explicit allowed label pairs; omit to use - the model run's default configuration. - allowed_label_matches_id: Optional existing allowed-label-matches config id. + model_run_id: Model run id (``run_*``). + name: Optional display name. + allowed_label_matches: Optional label pairs to treat as matches. + allowed_label_matches_id: Optional id of a saved label-match configuration. Returns: - :class:`EvaluationV2` loaded via ``GET /nucleus/evaluationsV2/:id``. + :class:`EvaluationV2`: The created evaluation. """ payload: Dict[str, Any] = {} if name is not None: @@ -933,15 +932,31 @@ def create_evaluation_v2( return self.get_evaluation_v2(str(eval_id)) def get_evaluation_v2(self, evaluation_id: str) -> EvaluationV2: - """Fetch a single Evaluation V2 row.""" + """Get an evaluation by id. + + Parameters: + evaluation_id: Evaluation id (``evalv2_*``). + + Returns: + :class:`EvaluationV2`. + """ data = self.get(f"evaluationsV2/{evaluation_id}") return EvaluationV2.from_json(data, self) def list_evaluations_v2(self, model_run_id: str) -> List[EvaluationV2]: - """List Evaluation V2 rows for a model run (newest first).""" + """List evaluations for a model run (newest first). + + Parameters: + model_run_id: Model run id (``run_*``). + + Returns: + List of :class:`EvaluationV2`. + """ rows = self.get(f"modelRun/{model_run_id}/evaluationsV2") if not isinstance(rows, list): - return [] + raise RuntimeError( + f"Unexpected list evaluations V2 response: {rows!r}" + ) return [EvaluationV2.from_json(r, self) for r in rows] @deprecated(msg="Prefer calling Dataset.info() directly.") diff --git a/nucleus/data_transfer_object/evaluation_v2.py b/nucleus/data_transfer_object/evaluation_v2.py index 18607ddc..a1abb443 100644 --- a/nucleus/data_transfer_object/evaluation_v2.py +++ b/nucleus/data_transfer_object/evaluation_v2.py @@ -1,10 +1,30 @@ -"""Pydantic models for Nucleus Evaluations V2 REST payloads.""" +"""Response and filter models for Evaluation V2.""" from typing import Any, Dict, List, Literal, Optional from nucleus.pydantic_base import DictCompatibleModel +def _snake_to_camel(name: str) -> str: + parts = name.split("_") + if len(parts) == 1: + return name + return parts[0] + "".join(part.capitalize() for part in parts[1:]) + + +def _camelize_filter_value(value: Any) -> Any: + if isinstance(value, dict): + return { + _snake_to_camel(key): ( + val if key == "value" else _camelize_filter_value(val) + ) + for key, val in value.items() + } + if isinstance(value, list): + return [_camelize_filter_value(item) for item in value] + return value + + class RangeNum(DictCompatibleModel): min: Optional[float] = None max: Optional[float] = None @@ -16,8 +36,21 @@ class MetadataPredicate(DictCompatibleModel): value: Optional[Any] = None +_FILTER_API_KEYS = { + "confidence_range": "confidenceRange", + "iou_range": "iouRange", + "pred_labels": "predLabels", + "gt_labels": "gtLabels", + "item_metadata": "itemMetadata", + "prediction_metadata": "predictionMetadata", + "label_equality": "labelEquality", + "has_ground_truth": "hasGroundTruth", + "tide_background": "tideBackground", +} + + class EvaluationV2FilterArgs(DictCompatibleModel): - """Filter object for charts/examples calls (mirrors server evaluation_v2 SQL filters).""" + """Optional filters for :meth:`nucleus.evaluation_v2.EvaluationV2.charts` and :meth:`nucleus.evaluation_v2.EvaluationV2.examples`.""" confidence_range: Optional[RangeNum] = None iou_range: Optional[RangeNum] = None @@ -30,29 +63,13 @@ class EvaluationV2FilterArgs(DictCompatibleModel): tide_background: Optional[bool] = None def to_api_filters(self) -> Dict[str, Any]: - """Serialize to camelCase keys expected by the GraphQL / REST layer.""" + """Return filters as a dict ready for API requests.""" d = self.dict(exclude_none=True) - # pydantic v1 uses snake_case fields; server expects camelCase in JSON filters - out: Dict[str, Any] = {} - if "confidence_range" in d: - out["confidenceRange"] = d["confidence_range"] - if "iou_range" in d: - out["iouRange"] = d["iou_range"] - if "pred_labels" in d: - out["predLabels"] = d["pred_labels"] - if "gt_labels" in d: - out["gtLabels"] = d["gt_labels"] - if "item_metadata" in d: - out["itemMetadata"] = d["item_metadata"] - if "prediction_metadata" in d: - out["predictionMetadata"] = d["prediction_metadata"] - if "label_equality" in d: - out["labelEquality"] = d["label_equality"] - if "has_ground_truth" in d: - out["hasGroundTruth"] = d["has_ground_truth"] - if "tide_background" in d: - out["tideBackground"] = d["tide_background"] - return out + return { + api_key: _camelize_filter_value(d[snake_key]) + for snake_key, api_key in _FILTER_API_KEYS.items() + if snake_key in d + } class MapSummary(DictCompatibleModel): diff --git a/nucleus/evaluation_v2.py b/nucleus/evaluation_v2.py index 3fdf5dca..4dd35385 100644 --- a/nucleus/evaluation_v2.py +++ b/nucleus/evaluation_v2.py @@ -1,4 +1,4 @@ -"""Nucleus Evaluation V2 — COCO-style metrics computed off ``evaluation_match_v2``.""" +"""Evaluation V2 — metrics and examples for a model run.""" from __future__ import annotations @@ -6,7 +6,7 @@ import time from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union from urllib.parse import urlencode import requests @@ -16,14 +16,12 @@ EvaluationV2ExamplesPage, EvaluationV2FilterArgs, ) -from nucleus.errors import NucleusAPIError - if TYPE_CHECKING: from nucleus import NucleusClient class EvaluationV2Status(str, Enum): - """Lifecycle states for ``nucleus.evaluation_v2.status``.""" + """Status of an Evaluation V2 run.""" PENDING = "pending" COMPUTING = "computing" @@ -32,9 +30,15 @@ class EvaluationV2Status(str, Enum): CANCELLED = "cancelled" +_TERMINAL_OK: Set[EvaluationV2Status] = { + EvaluationV2Status.SUCCEEDED, + EvaluationV2Status.CANCELLED, +} + + @dataclass class AllowedLabelMatch: - """Pair of labels that may match for IoU evaluation (snake_case JSON for the API).""" + """Ground-truth and prediction label pair that counts as a match.""" ground_truth_label: str model_prediction_label: str @@ -48,7 +52,7 @@ def to_api_dict(self) -> Dict[str, str]: @dataclass class EvaluationV2: - """A single Evaluation V2 run for a model run (``evalv2_*``).""" + """An Evaluation V2 run for a model run.""" id: str model_run_id: str @@ -61,7 +65,7 @@ class EvaluationV2: allowed_label_matches_id: Optional[str] = None allowed_label_matches: Optional[List[AllowedLabelMatch]] = None allowed_label_matches_name: Optional[str] = None - _client: Any = field(repr=False, default=None) + _client: Optional["NucleusClient"] = field(repr=False, default=None) @classmethod def from_json( @@ -108,7 +112,11 @@ def from_json( ) def refresh(self) -> "EvaluationV2": - """Reload this evaluation from ``GET /nucleus/evaluationsV2/:id``.""" + """Reload this evaluation from Nucleus. + + Returns: + self, with updated fields. + """ if self._client is None: raise RuntimeError( "EvaluationV2 has no client; use NucleusClient.get_evaluation_v2." @@ -123,20 +131,26 @@ def wait_for_completion( timeout_sec: float = 600, poll_interval: float = 5, ) -> "EvaluationV2": - """Poll until status is terminal or ``timeout_sec`` elapses. + """Wait until the evaluation finishes or is cancelled. + + Parameters: + timeout_sec: Maximum seconds to wait. + poll_interval: Seconds between status checks. + + Returns: + self, after a terminal status is reached. Raises: - RuntimeError: on ``failed`` status or timeout. + RuntimeError: If the evaluation fails or times out. """ deadline = time.monotonic() + timeout_sec - terminal_ok = {"succeeded", "cancelled"} while time.monotonic() < deadline: self.refresh() - if self.status == "failed": + if self.status == EvaluationV2Status.FAILED: raise RuntimeError( f"Evaluation {self.id} failed: {self.error_message or 'unknown'}" ) - if self.status in terminal_ok: + if self.status in _TERMINAL_OK: return self time.sleep(poll_interval) raise RuntimeError( @@ -145,21 +159,15 @@ def wait_for_completion( ) def delete(self) -> None: - """Cancel workflow (best effort) and soft-delete (``204 No Content``).""" + """Delete this evaluation.""" if self._client is None: raise RuntimeError("EvaluationV2 has no client.") - resp = self._client.make_request( + self._client.make_request( {}, f"evaluationsV2/{self.id}", requests_command=requests.delete, return_raw_response=True, ) - if resp.status_code != 204: - raise NucleusAPIError( - f"{self._client.endpoint}/evaluationsV2/{self.id}", - requests.delete, - resp, - ) def charts( self, @@ -169,7 +177,16 @@ def charts( ] = None, query: Optional[str] = None, ) -> EvaluationV2Charts: - """Aggregate metrics (mAP, confusion matrix, PR curve, TIDE, …).""" + """Return aggregate metrics for this evaluation. + + Parameters: + iou_threshold: IoU threshold for matching (default 0.5). + filters: Optional filters (:class:`EvaluationV2FilterArgs` or dict). + query: Optional query string to narrow results. + + Returns: + :class:`EvaluationV2Charts`: Summary metrics (mAP, confusion matrix, PR curve, etc.). + """ if self._client is None: raise RuntimeError("EvaluationV2 has no client.") params: Dict[str, str] = {} @@ -199,7 +216,20 @@ def examples( ] = None, query: Optional[str] = None, ) -> EvaluationV2ExamplesPage: - """Paginated TP / FP / FN match rows with prediction and annotation blobs.""" + """Return paginated true-positive, false-positive, or false-negative examples. + + Parameters: + match_type: ``"TP"``, ``"FP"``, or ``"FN"``. + limit: Page size (default 50). + offset: Row offset for pagination. + sort_by: Optional field to sort by. + sort_order: Optional sort direction (e.g. ``"asc"`` or ``"desc"``). + filters: Optional filters (:class:`EvaluationV2FilterArgs` or dict). + query: Optional query string to narrow results. + + Returns: + :class:`EvaluationV2ExamplesPage`: Matching rows and total count. + """ if self._client is None: raise RuntimeError("EvaluationV2 has no client.") payload: Dict[str, Any] = { diff --git a/pyproject.toml b/pyproject.toml index 772decb2..dd07937e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ ignore = ["E501", "E741", "E731", "F401"] # Easy ignore for getting it running [tool.poetry] name = "scale-nucleus" -version = "0.18.2" +version = "0.18.3" description = "The official Python client library for Nucleus, the Data Platform for AI" license = "MIT" authors = ["Scale AI Nucleus Team "] diff --git a/tests/test_evaluation_v2.py b/tests/test_evaluation_v2.py index c8429cf7..829e34a2 100644 --- a/tests/test_evaluation_v2.py +++ b/tests/test_evaluation_v2.py @@ -2,10 +2,45 @@ from unittest.mock import MagicMock +import pytest import requests from nucleus import AllowedLabelMatch, EvaluationV2, NucleusClient -from nucleus.data_transfer_object.evaluation_v2 import EvaluationV2Charts +from nucleus.data_transfer_object.evaluation_v2 import ( + EvaluationV2Charts, + EvaluationV2FilterArgs, + MetadataPredicate, + RangeNum, + _camelize_filter_value, +) + + +def test_evaluation_v2_filter_args_to_api_filters(): + filters = EvaluationV2FilterArgs( + confidence_range=RangeNum(min=0.1, max=0.9), + pred_labels=["cat"], + item_metadata=[MetadataPredicate(key="tier", op="EQ", value="gold")], + has_ground_truth=True, + ) + assert filters.to_api_filters() == { + "confidenceRange": {"min": 0.1, "max": 0.9}, + "predLabels": ["cat"], + "itemMetadata": [{"key": "tier", "op": "EQ", "value": "gold"}], + "hasGroundTruth": True, + } + + +def test_camelize_filter_value_nested_keys(): + assert _camelize_filter_value({"bucket_min": 1.0, "bucket_max": 2.0}) == { + "bucketMin": 1.0, + "bucketMax": 2.0, + } + + +def test_camelize_filter_value_preserves_predicate_value(): + assert _camelize_filter_value( + {"key": "k", "op": "EQ", "value": {"keep_snake": 1}} + ) == {"key": "k", "op": "EQ", "value": {"keep_snake": 1}} def test_allowed_label_match_to_api_dict(): @@ -34,6 +69,41 @@ def test_evaluation_v2_from_json_with_matches(): assert ev.allowed_label_matches[0].ground_truth_label == "x" +def test_list_evaluations_v2_empty(): + client = NucleusClient(api_key="test") + client.connection.get = MagicMock(return_value=[]) + result = client.list_evaluations_v2("run_1") + assert result == [] + client.connection.get.assert_called_once_with( + "modelRun/run_1/evaluationsV2" + ) + + +def test_list_evaluations_v2_returns_rows(): + client = NucleusClient(api_key="test") + client.connection.get = MagicMock( + return_value=[ + { + "id": "evalv2_1", + "model_run_id": "run_1", + "dataset_id": "ds_1", + "status": "succeeded", + }, + ] + ) + result = client.list_evaluations_v2("run_1") + assert len(result) == 1 + assert result[0].id == "evalv2_1" + assert result[0]._client is client + + +def test_list_evaluations_v2_invalid_response(): + client = NucleusClient(api_key="test") + client.connection.get = MagicMock(return_value={"evaluations": []}) + with pytest.raises(RuntimeError, match="Unexpected list evaluations V2"): + client.list_evaluations_v2("run_1") + + def test_create_evaluation_v2_then_get(): client = NucleusClient(api_key="test") client.connection.make_request = MagicMock( @@ -120,6 +190,30 @@ def test_examples_post_body(): assert payload["offset"] == 5 +def test_examples_with_filter_args(): + client = MagicMock(spec=NucleusClient) + client.post.return_value = {"rows": [], "total": 0} + ev = EvaluationV2( + id="evalv2_1", + model_run_id="run_1", + dataset_id="ds_1", + status="succeeded", + _client=client, + ) + filters = EvaluationV2FilterArgs( + confidence_range=RangeNum(min=0.1, max=0.9), + pred_labels=["cat"], + has_ground_truth=True, + ) + ev.examples("FP", limit=10, filters=filters) + payload = client.post.call_args[0][0] + assert payload["filters"] == { + "confidenceRange": {"min": 0.1, "max": 0.9}, + "predLabels": ["cat"], + "hasGroundTruth": True, + } + + def test_wait_for_completion(): client = NucleusClient(api_key="test") client.connection.get = MagicMock( @@ -149,10 +243,11 @@ def test_wait_for_completion(): assert ev.status == "succeeded" -def test_delete_204(): +@pytest.mark.parametrize("status_code", [200, 204]) +def test_delete_success(status_code): client = NucleusClient(api_key="test") resp = MagicMock() - resp.status_code = 204 + resp.status_code = status_code client.connection.make_request = MagicMock(return_value=resp) ev = EvaluationV2( id="evalv2_1", From f88b665808d40e52eacd0c75a083ecd74b82336c Mon Sep 17 00:00:00 2001 From: Luke Schaefer Date: Thu, 28 May 2026 18:21:23 -0500 Subject: [PATCH 8/8] fix lint --- nucleus/evaluation_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nucleus/evaluation_v2.py b/nucleus/evaluation_v2.py index 4dd35385..43f8a03c 100644 --- a/nucleus/evaluation_v2.py +++ b/nucleus/evaluation_v2.py @@ -16,6 +16,7 @@ EvaluationV2ExamplesPage, EvaluationV2FilterArgs, ) + if TYPE_CHECKING: from nucleus import NucleusClient