diff --git a/CHANGELOG.md b/CHANGELOG.md index d427a495..7c49a64b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,11 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.18.4](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.18.4) - 2026-05-28 + +### Added +- **Evaluations V2** client support for COCO-style metrics on model runs via stored `evaluation_match_v2` rows. `NucleusClient` exposes `create_evaluation_v2()`, `get_evaluation_v2()`, and `list_evaluations_v2()`. The `EvaluationV2` resource supports `wait_for_completion()`, `charts()` (mAP, confusion matrix, PR curve, TIDE, and related aggregates), `examples()` (paginated TP/FP/FN rows), `delete()`, and `refresh()`. `AllowedLabelMatch` configures allowed ground-truth / prediction label pairs; filter and response types include `EvaluationV2FilterArgs`, `EvaluationV2Charts`, `EvaluationV2ExamplesPage`, and `EvaluationV2MatchExample`. Sphinx docs cover the workflow under Evaluations V2. + ## [0.18.3](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.18.3) - 2026-05-18 ### Added diff --git a/docs/index.rst b/docs/index.rst index 698ef59a..88ec8c3d 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,6 +12,36 @@ Scale Nucleus helps you: Nucleus is a new way—the right way—to develop ML models, helping us move away from the concept of one dataset and towards a paradigm of collections of scenarios. +.. _evaluations-v2: + +Evaluations V2 +-------------- + +Evaluation V2 measures how well a **model run** matches ground-truth annotations. 
+Create an evaluation with :meth:`NucleusClient.create_evaluation_v2`, wait for it to finish with +:meth:`nucleus.evaluation_v2.EvaluationV2.wait_for_completion`, then read summary metrics with +:meth:`nucleus.evaluation_v2.EvaluationV2.charts` or individual matches with +:meth:`nucleus.evaluation_v2.EvaluationV2.examples`. + +.. code-block:: python + + import nucleus + + client = nucleus.NucleusClient(api_key="YOUR_API_KEY") + evaluation = client.create_evaluation_v2( + model_run_id="run_xxx", + name="my-eval", + allowed_label_matches=[ + nucleus.AllowedLabelMatch( + ground_truth_label="car", + model_prediction_label="vehicle", + ), + ], + ) + evaluation.wait_for_completion() + charts = evaluation.charts(iou_threshold=0.5) + fps = evaluation.examples(match_type="FP", limit=20) + .. _installation: Installation diff --git a/nucleus/__init__.py b/nucleus/__init__.py index d7ee51db..8e551987 100644 --- a/nucleus/__init__.py +++ b/nucleus/__init__.py @@ -1,7 +1,8 @@ -"""Nucleus Python SDK. """ +"""Nucleus Python SDK.""" __all__ = [ "AsyncJob", + "AllowedLabelMatch", "EmbeddingsExportJob", "BoxAnnotation", "DeduplicationJob", @@ -17,6 +18,12 @@ "DatasetInfo", "DatasetItem", "DatasetItemRetrievalError", + "EvaluationV2", + "EvaluationV2Charts", + "EvaluationV2ExamplesPage", + "EvaluationV2FilterArgs", + "EvaluationV2MatchExample", + "EvaluationV2Status", "Frame", "Keypoint", "KeypointsAnnotation", @@ -129,6 +136,12 @@ ) from .data_transfer_object.dataset_details import DatasetDetails from .data_transfer_object.dataset_info import DatasetInfo +from .data_transfer_object.evaluation_v2 import ( + EvaluationV2Charts, + EvaluationV2ExamplesPage, + EvaluationV2FilterArgs, + EvaluationV2MatchExample, +) from .data_transfer_object.job_status import JobInfoRequestPayload from .dataset import Dataset from .dataset_item import DatasetItem @@ -146,6 +159,7 @@ NotFoundError, NucleusAPIError, ) +from .evaluation_v2 import AllowedLabelMatch, EvaluationV2, EvaluationV2Status from .job import CustomerJobTypes 
from .model import Model from .model_run import ModelRun @@ -875,6 +889,76 @@ def commit_model_run( payload = {} return self.make_request(payload, f"modelRun/{model_run_id}/commit") + def create_evaluation_v2( + self, + model_run_id: str, + *, + name: Optional[str] = None, + allowed_label_matches: Optional[List[AllowedLabelMatch]] = None, + allowed_label_matches_id: Optional[str] = None, + ) -> EvaluationV2: + """Create an evaluation for a model run. + + The evaluation runs in the background. Call + :meth:`EvaluationV2.wait_for_completion`, then + :meth:`EvaluationV2.charts` or :meth:`EvaluationV2.examples` for results. + + Parameters: + model_run_id: Model run id (``run_*``). + name: Optional display name. + allowed_label_matches: Optional label pairs to treat as matches. + allowed_label_matches_id: Optional id of a saved label-match configuration. + + Returns: + :class:`EvaluationV2`: The created evaluation. + """ + payload: Dict[str, Any] = {} + if name is not None: + payload["name"] = name + if allowed_label_matches is not None: + payload["allowed_label_matches"] = [ + m.to_api_dict() for m in allowed_label_matches + ] + if allowed_label_matches_id is not None: + payload["allowed_label_matches_id"] = allowed_label_matches_id + result = self.make_request( + payload, f"modelRun/{model_run_id}/evaluationsV2" + ) + eval_id = result.get("evaluation_id") + if not eval_id: + raise RuntimeError( + f"Unexpected create evaluation V2 response: {result}" + ) + return self.get_evaluation_v2(str(eval_id)) + + def get_evaluation_v2(self, evaluation_id: str) -> EvaluationV2: + """Get an evaluation by id. + + Parameters: + evaluation_id: Evaluation id (``evalv2_*``). + + Returns: + :class:`EvaluationV2`. + """ + data = self.get(f"evaluationsV2/{evaluation_id}") + return EvaluationV2.from_json(data, self) + + def list_evaluations_v2(self, model_run_id: str) -> List[EvaluationV2]: + """List evaluations for a model run (newest first). 
+ + Parameters: + model_run_id: Model run id (``run_*``). + + Returns: + List of :class:`EvaluationV2`. + """ + rows = self.get(f"modelRun/{model_run_id}/evaluationsV2") + if not isinstance(rows, list): + raise RuntimeError( + f"Unexpected list evaluations V2 response: {rows!r}" + ) + return [EvaluationV2.from_json(r, self) for r in rows] + @deprecated(msg="Prefer calling Dataset.info() directly.") def dataset_info(self, dataset_id: str): dataset = self.get_dataset(dataset_id) diff --git a/nucleus/data_transfer_object/evaluation_v2.py b/nucleus/data_transfer_object/evaluation_v2.py new file mode 100644 index 00000000..a1abb443 --- /dev/null +++ b/nucleus/data_transfer_object/evaluation_v2.py @@ -0,0 +1,162 @@ +"""Response and filter models for Evaluation V2.""" + +from typing import Any, Dict, List, Literal, Optional + +from nucleus.pydantic_base import DictCompatibleModel + + +def _snake_to_camel(name: str) -> str: + parts = name.split("_") + if len(parts) == 1: + return name + return parts[0] + "".join(part.capitalize() for part in parts[1:]) + + +def _camelize_filter_value(value: Any) -> Any: + if isinstance(value, dict): + return { + _snake_to_camel(key): ( + val if key == "value" else _camelize_filter_value(val) + ) + for key, val in value.items() + } + if isinstance(value, list): + return [_camelize_filter_value(item) for item in value] + return value + + +class RangeNum(DictCompatibleModel): + min: Optional[float] = None + max: Optional[float] = None + + +class MetadataPredicate(DictCompatibleModel): + key: str + op: Literal["EQ", "IN", "GT", "LT"] + value: Optional[Any] = None + + +_FILTER_API_KEYS = { + "confidence_range": "confidenceRange", + "iou_range": "iouRange", + "pred_labels": "predLabels", + "gt_labels": "gtLabels", + "item_metadata": "itemMetadata", + "prediction_metadata": "predictionMetadata", + "label_equality": "labelEquality", + "has_ground_truth": "hasGroundTruth", + "tide_background": "tideBackground", +} + + +class 
EvaluationV2FilterArgs(DictCompatibleModel): + """Optional filters for :meth:`nucleus.evaluation_v2.EvaluationV2.charts` and :meth:`nucleus.evaluation_v2.EvaluationV2.examples`.""" + + confidence_range: Optional[RangeNum] = None + iou_range: Optional[RangeNum] = None + pred_labels: Optional[List[str]] = None + gt_labels: Optional[List[str]] = None + item_metadata: Optional[List[MetadataPredicate]] = None + prediction_metadata: Optional[List[MetadataPredicate]] = None + label_equality: Optional[Literal["EQ", "NEQ"]] = None + has_ground_truth: Optional[bool] = None + tide_background: Optional[bool] = None + + def to_api_filters(self) -> Dict[str, Any]: + """Return filters as a dict ready for API requests.""" + d = self.dict(exclude_none=True) + return { + api_key: _camelize_filter_value(d[snake_key]) + for snake_key, api_key in _FILTER_API_KEYS.items() + if snake_key in d + } + + +class MapSummary(DictCompatibleModel): + mapAt50: Optional[float] = None + mapAt75: Optional[float] = None + mapAt5095: Optional[float] = None + + +class PerClassAp(DictCompatibleModel): + classLabel: str + ap: float + + +class ConfusionEntry(DictCompatibleModel): + gtLabel: str + predLabel: str + count: int + + +class ScoreHistogramBucket(DictCompatibleModel): + bucketMin: float + bucketMax: float + count: int + + +class TotalCounts(DictCompatibleModel): + tp: int + fp: int + fn: int + predsWithConfidence: int + + +class ApBySize(DictCompatibleModel): + small: Optional[float] = None + medium: Optional[float] = None + large: Optional[float] = None + + +class PrCurvePoint(DictCompatibleModel): + classLabel: str + recall: float + precision: float + + +class TideAttribution(DictCompatibleModel): + truePositive: int + localization: int + classification: int + both: int + duplicate: int + background: int + missed: int + + +class EvaluationV2Charts(DictCompatibleModel): + mapSummary: MapSummary + perClassAp: List[PerClassAp] + confusionMatrix: List[ConfusionEntry] + scoreHistogram: 
List[ScoreHistogramBucket] + computedIouRanges: List[float] + totalCounts: TotalCounts + apBySize: ApBySize + prCurve: List[PrCurvePoint] + tideAttribution: TideAttribution + + +class EvaluationV2MatchExample(DictCompatibleModel): + id: str + evaluation_id: str + dataset_item_id: str + model_prediction_id: Optional[str] = None + ground_truth_annotation_id: Optional[str] = None + pred_canonical_label: Optional[str] = None + gt_canonical_label: Optional[str] = None + pred_raw_label: Optional[str] = None + gt_raw_label: Optional[str] = None + iou: Optional[float] = None + confidence: Optional[float] = None + true_positive: bool + match_type: str + gt_area: Optional[float] = None + item_metadata: Optional[Dict[str, Any]] = None + prediction_metadata: Optional[Dict[str, Any]] = None + prediction_row: Optional[Dict[str, Any]] = None + annotation_row: Optional[Dict[str, Any]] = None + + +class EvaluationV2ExamplesPage(DictCompatibleModel): + rows: List[EvaluationV2MatchExample] + total: int diff --git a/nucleus/evaluation_v2.py b/nucleus/evaluation_v2.py new file mode 100644 index 00000000..43f8a03c --- /dev/null +++ b/nucleus/evaluation_v2.py @@ -0,0 +1,253 @@ +"""Evaluation V2 — metrics and examples for a model run.""" + +from __future__ import annotations + +import json +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union +from urllib.parse import urlencode + +import requests + +from nucleus.data_transfer_object.evaluation_v2 import ( + EvaluationV2Charts, + EvaluationV2ExamplesPage, + EvaluationV2FilterArgs, +) + +if TYPE_CHECKING: + from nucleus import NucleusClient + + +class EvaluationV2Status(str, Enum): + """Status of an Evaluation V2 run.""" + + PENDING = "pending" + COMPUTING = "computing" + SUCCEEDED = "succeeded" + FAILED = "failed" + CANCELLED = "cancelled" + + +_TERMINAL_OK: Set[EvaluationV2Status] = { + EvaluationV2Status.SUCCEEDED, + 
EvaluationV2Status.CANCELLED, +} + + +@dataclass +class AllowedLabelMatch: + """Ground-truth and prediction label pair that counts as a match.""" + + ground_truth_label: str + model_prediction_label: str + + def to_api_dict(self) -> Dict[str, str]: + return { + "ground_truth_label": self.ground_truth_label, + "model_prediction_label": self.model_prediction_label, + } + + +@dataclass +class EvaluationV2: + """An Evaluation V2 run for a model run.""" + + id: str + model_run_id: str + dataset_id: str + status: str + name: Optional[str] = None + temporal_workflow_id: Optional[str] = None + error_message: Optional[str] = None + created_at: Optional[str] = None + allowed_label_matches_id: Optional[str] = None + allowed_label_matches: Optional[List[AllowedLabelMatch]] = None + allowed_label_matches_name: Optional[str] = None + _client: Optional["NucleusClient"] = field(repr=False, default=None) + + @classmethod + def from_json( + cls, + payload: Dict[str, Any], + client: Optional["NucleusClient"] = None, + ) -> "EvaluationV2": + raw_matches = payload.get("allowed_label_matches") + matches: Optional[List[AllowedLabelMatch]] = None + if isinstance(raw_matches, list): + matches = [] + for m in raw_matches: + if not isinstance(m, dict): + continue + gt = m.get("groundTruthLabel") + if gt is None: + gt = m.get("ground_truth_label") + mp = m.get("modelPredictionLabel") + if mp is None: + mp = m.get("model_prediction_label") + if gt is not None and mp is not None: + matches.append( + AllowedLabelMatch( + ground_truth_label=str(gt), + model_prediction_label=str(mp), + ) + ) + + return cls( + id=str(payload["id"]), + model_run_id=str(payload["model_run_id"]), + dataset_id=str(payload["dataset_id"]), + status=str(payload["status"]), + name=payload.get("name"), + temporal_workflow_id=payload.get("temporal_workflow_id"), + error_message=payload.get("error_message"), + created_at=payload.get("created_at"), + allowed_label_matches_id=payload.get("allowed_label_matches_id"), + 
allowed_label_matches=matches, + allowed_label_matches_name=payload.get( + "allowed_label_matches_name" + ), + _client=client, + ) + + def refresh(self) -> "EvaluationV2": + """Reload this evaluation from Nucleus. + + Returns: + self, with updated fields. + """ + if self._client is None: + raise RuntimeError( + "EvaluationV2 has no client; use NucleusClient.get_evaluation_v2." + ) + data = self._client.get(f"evaluationsV2/{self.id}") + updated = EvaluationV2.from_json(data, self._client) + self.__dict__.update(updated.__dict__) + return self + + def wait_for_completion( + self, + timeout_sec: float = 600, + poll_interval: float = 5, + ) -> "EvaluationV2": + """Wait until the evaluation finishes or is cancelled. + + Parameters: + timeout_sec: Maximum seconds to wait. + poll_interval: Seconds between status checks. + + Returns: + self, after a terminal status is reached. + + Raises: + RuntimeError: If the evaluation fails or times out. + """ + deadline = time.monotonic() + timeout_sec + while time.monotonic() < deadline: + self.refresh() + if self.status == EvaluationV2Status.FAILED: + raise RuntimeError( + f"Evaluation {self.id} failed: {self.error_message or 'unknown'}" + ) + if self.status in _TERMINAL_OK: + return self + time.sleep(poll_interval) + raise RuntimeError( + f"Timed out after {timeout_sec}s waiting for evaluation {self.id} " + f"(last status: {self.status})" + ) + + def delete(self) -> None: + """Delete this evaluation.""" + if self._client is None: + raise RuntimeError("EvaluationV2 has no client.") + self._client.make_request( + {}, + f"evaluationsV2/{self.id}", + requests_command=requests.delete, + return_raw_response=True, + ) + + def charts( + self, + iou_threshold: float = 0.5, + filters: Optional[ + Union[EvaluationV2FilterArgs, Dict[str, Any]] + ] = None, + query: Optional[str] = None, + ) -> EvaluationV2Charts: + """Return aggregate metrics for this evaluation. + + Parameters: + iou_threshold: IoU threshold for matching (default 0.5). 
+ filters: Optional filters (:class:`EvaluationV2FilterArgs` or dict). + query: Optional query string to narrow results. + + Returns: + :class:`EvaluationV2Charts`: Summary metrics (mAP, confusion matrix, PR curve, etc.). + """ + if self._client is None: + raise RuntimeError("EvaluationV2 has no client.") + params: Dict[str, str] = {} + params["iouThreshold"] = str(iou_threshold) + if filters is not None: + if isinstance(filters, EvaluationV2FilterArgs): + filt_dict = filters.to_api_filters() + else: + filt_dict = filters + params["filters"] = json.dumps(filt_dict) + if query: + params["query"] = query + qs = urlencode(params) + route = f"evaluationsV2/{self.id}/charts?{qs}" + data = self._client.get(route) + return EvaluationV2Charts.parse_obj(data) + + def examples( + self, + match_type: str, + limit: int = 50, + offset: int = 0, + sort_by: Optional[str] = None, + sort_order: Optional[str] = None, + filters: Optional[ + Union[EvaluationV2FilterArgs, Dict[str, Any]] + ] = None, + query: Optional[str] = None, + ) -> EvaluationV2ExamplesPage: + """Return paginated true-positive, false-positive, or false-negative examples. + + Parameters: + match_type: ``"TP"``, ``"FP"``, or ``"FN"``. + limit: Page size (default 50). + offset: Row offset for pagination. + sort_by: Optional field to sort by. + sort_order: Optional sort direction (e.g. ``"asc"`` or ``"desc"``). + filters: Optional filters (:class:`EvaluationV2FilterArgs` or dict). + query: Optional query string to narrow results. + + Returns: + :class:`EvaluationV2ExamplesPage`: Matching rows and total count. 
+ """ + if self._client is None: + raise RuntimeError("EvaluationV2 has no client.") + payload: Dict[str, Any] = { + "match_type": match_type, + "limit": limit, + "offset": offset, + } + if sort_by is not None: + payload["sort_by"] = sort_by + if sort_order is not None: + payload["sort_order"] = sort_order + if filters is not None: + if isinstance(filters, EvaluationV2FilterArgs): + payload["filters"] = filters.to_api_filters() + else: + payload["filters"] = filters + if query: + payload["query"] = query + data = self._client.post(payload, f"evaluationsV2/{self.id}/examples") + return EvaluationV2ExamplesPage.parse_obj(data) diff --git a/tests/test_evaluation_v2.py b/tests/test_evaluation_v2.py new file mode 100644 index 00000000..829e34a2 --- /dev/null +++ b/tests/test_evaluation_v2.py @@ -0,0 +1,265 @@ +"""Unit tests for Evaluations V2 client (no live API).""" + +from unittest.mock import MagicMock + +import pytest +import requests + +from nucleus import AllowedLabelMatch, EvaluationV2, NucleusClient +from nucleus.data_transfer_object.evaluation_v2 import ( + EvaluationV2Charts, + EvaluationV2FilterArgs, + MetadataPredicate, + RangeNum, + _camelize_filter_value, +) + + +def test_evaluation_v2_filter_args_to_api_filters(): + filters = EvaluationV2FilterArgs( + confidence_range=RangeNum(min=0.1, max=0.9), + pred_labels=["cat"], + item_metadata=[MetadataPredicate(key="tier", op="EQ", value="gold")], + has_ground_truth=True, + ) + assert filters.to_api_filters() == { + "confidenceRange": {"min": 0.1, "max": 0.9}, + "predLabels": ["cat"], + "itemMetadata": [{"key": "tier", "op": "EQ", "value": "gold"}], + "hasGroundTruth": True, + } + + +def test_camelize_filter_value_nested_keys(): + assert _camelize_filter_value({"bucket_min": 1.0, "bucket_max": 2.0}) == { + "bucketMin": 1.0, + "bucketMax": 2.0, + } + + +def test_camelize_filter_value_preserves_predicate_value(): + assert _camelize_filter_value( + {"key": "k", "op": "EQ", "value": {"keep_snake": 1}} + ) == {"key": 
"k", "op": "EQ", "value": {"keep_snake": 1}} + + +def test_allowed_label_match_to_api_dict(): + m = AllowedLabelMatch(ground_truth_label="a", model_prediction_label="b") + assert m.to_api_dict() == { + "ground_truth_label": "a", + "model_prediction_label": "b", + } + + +def test_evaluation_v2_from_json_with_matches(): + client = NucleusClient(api_key="k") + payload = { + "id": "evalv2_1", + "model_run_id": "run_1", + "dataset_id": "ds_1", + "status": "pending", + "allowed_label_matches": [ + {"groundTruthLabel": "x", "modelPredictionLabel": "y"}, + ], + } + ev = EvaluationV2.from_json(payload, client) + assert ev.id == "evalv2_1" + assert ev.allowed_label_matches is not None + assert len(ev.allowed_label_matches) == 1 + assert ev.allowed_label_matches[0].ground_truth_label == "x" + + +def test_list_evaluations_v2_empty(): + client = NucleusClient(api_key="test") + client.connection.get = MagicMock(return_value=[]) + result = client.list_evaluations_v2("run_1") + assert result == [] + client.connection.get.assert_called_once_with( + "modelRun/run_1/evaluationsV2" + ) + + +def test_list_evaluations_v2_returns_rows(): + client = NucleusClient(api_key="test") + client.connection.get = MagicMock( + return_value=[ + { + "id": "evalv2_1", + "model_run_id": "run_1", + "dataset_id": "ds_1", + "status": "succeeded", + }, + ] + ) + result = client.list_evaluations_v2("run_1") + assert len(result) == 1 + assert result[0].id == "evalv2_1" + assert result[0]._client is client + + +def test_list_evaluations_v2_invalid_response(): + client = NucleusClient(api_key="test") + client.connection.get = MagicMock(return_value={"evaluations": []}) + with pytest.raises(RuntimeError, match="Unexpected list evaluations V2"): + client.list_evaluations_v2("run_1") + + +def test_create_evaluation_v2_then_get(): + client = NucleusClient(api_key="test") + client.connection.make_request = MagicMock( + return_value={ + "evaluation_id": "evalv2_new", + "status": "pending", + "workflow_id": "w", + } 
+ ) + client.connection.get = MagicMock( + return_value={ + "id": "evalv2_new", + "model_run_id": "run_1", + "dataset_id": "ds_1", + "status": "pending", + } + ) + + ev = client.create_evaluation_v2( + "run_1", + name="n1", + allowed_label_matches=[ + AllowedLabelMatch("gt", "pred"), + ], + ) + assert ev.id == "evalv2_new" + client.connection.make_request.assert_called_once() + client.connection.get.assert_called_once_with("evaluationsV2/evalv2_new") + + +def test_charts_get_query_string(): + client = MagicMock(spec=NucleusClient) + client.get.return_value = { + "mapSummary": {"mapAt50": 0.1, "mapAt75": 0.2, "mapAt5095": 0.15}, + "perClassAp": [], + "confusionMatrix": [], + "scoreHistogram": [], + "computedIouRanges": [], + "totalCounts": {"tp": 0, "fp": 0, "fn": 0, "predsWithConfidence": 0}, + "apBySize": {"small": None, "medium": None, "large": None}, + "prCurve": [], + "tideAttribution": { + "truePositive": 0, + "localization": 0, + "classification": 0, + "both": 0, + "duplicate": 0, + "background": 0, + "missed": 0, + }, + } + ev = EvaluationV2( + id="evalv2_1", + model_run_id="run_1", + dataset_id="ds_1", + status="succeeded", + _client=client, + ) + charts = ev.charts(iou_threshold=0.5) + assert isinstance(charts, EvaluationV2Charts) + call_route = client.get.call_args[0][0] + assert "evaluationsV2/evalv2_1/charts" in call_route + assert "iouThreshold=0.5" in call_route + + +def test_examples_post_body(): + client = MagicMock(spec=NucleusClient) + client.post.return_value = {"rows": [], "total": 0} + ev = EvaluationV2( + id="evalv2_1", + model_run_id="run_1", + dataset_id="ds_1", + status="succeeded", + _client=client, + ) + page = ev.examples("TP", limit=20, offset=5) + assert page.total == 0 + client.post.assert_called_once() + args, kwargs = client.post.call_args + payload, route = args + assert route == "evaluationsV2/evalv2_1/examples" + assert payload["match_type"] == "TP" + assert payload["limit"] == 20 + assert payload["offset"] == 5 + + +def 
test_examples_with_filter_args(): + client = MagicMock(spec=NucleusClient) + client.post.return_value = {"rows": [], "total": 0} + ev = EvaluationV2( + id="evalv2_1", + model_run_id="run_1", + dataset_id="ds_1", + status="succeeded", + _client=client, + ) + filters = EvaluationV2FilterArgs( + confidence_range=RangeNum(min=0.1, max=0.9), + pred_labels=["cat"], + has_ground_truth=True, + ) + ev.examples("FP", limit=10, filters=filters) + payload = client.post.call_args[0][0] + assert payload["filters"] == { + "confidenceRange": {"min": 0.1, "max": 0.9}, + "predLabels": ["cat"], + "hasGroundTruth": True, + } + + +def test_wait_for_completion(): + client = NucleusClient(api_key="test") + client.connection.get = MagicMock( + side_effect=[ + { + "id": "evalv2_1", + "model_run_id": "run_1", + "dataset_id": "ds_1", + "status": "pending", + }, + { + "id": "evalv2_1", + "model_run_id": "run_1", + "dataset_id": "ds_1", + "status": "succeeded", + }, + ] + ) + ev = EvaluationV2( + id="evalv2_1", + model_run_id="run_1", + dataset_id="ds_1", + status="pending", + _client=client, + ) + ev.wait_for_completion(timeout_sec=5, poll_interval=0.01) + assert ev.status == "succeeded" + + +@pytest.mark.parametrize("status_code", [200, 204]) +def test_delete_success(status_code): + client = NucleusClient(api_key="test") + resp = MagicMock() + resp.status_code = status_code + client.connection.make_request = MagicMock(return_value=resp) + ev = EvaluationV2( + id="evalv2_1", + model_run_id="run_1", + dataset_id="ds_1", + status="succeeded", + _client=client, + ) + ev.delete() + assert client.connection.make_request.call_count == 1 + cargs = client.connection.make_request.call_args + assert cargs[0][0] == {} + assert cargs[0][1] == "evaluationsV2/evalv2_1" + assert cargs[0][2] is requests.delete + assert cargs[0][3] is True