diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..97dd996 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,36 @@ +name: Lint and Test + +on: + pull_request: + branches: [main, "feat/**", "fix/**"] + push: + branches: [main] + +jobs: + lint-and-test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + version: "0.6.14" + enable-cache: true + + - name: Install dependencies + run: uv sync --dev --all-extras + + - name: Create test env + run: cp tests/sample.env tests/.env + + - name: Tests (pytest) + run: uv run pytest tests/ -v diff --git a/pyproject.toml b/pyproject.toml index c45a8a9..a208e3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,15 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Python Modules", ] +[project.optional-dependencies] +clone = [ + "click>=8.1", + "rich>=13.7", +] + +[project.scripts] +unstract-clone = "unstract.clone.cli:main" + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/src/unstract/clone/README.md b/src/unstract/clone/README.md new file mode 100644 index 0000000..e31557c --- /dev/null +++ b/src/unstract/clone/README.md @@ -0,0 +1,61 @@ +# Cloning Organizations + +> [!NOTE] +> **Users are not cloned.** Two reasons: +> - The same user may not need access in every environment. +> - The same user may hold different roles across environments. +> +> **Groups _will_ be cloned** (upcoming — not yet implemented). Once available, an admin can add the right users to each group per environment. + +Clone an Unstract organization's configured resources into another organization (same deployment or different). Useful for environment promotion (DEV → QA → PROD) and for spinning up a fresh org from a known-good baseline. + +Cloned resources: adapters, connectors, custom tools, prompts, profiles, workflows, tool instances, workflow endpoints, tags, API deployments, pipelines, and Prompt Studio document files. The source org is left untouched. + +> **Full documentation, behavior notes, CLI reference, and sample report:** +> https://docs.unstract.com/unstract/unstract_platform/api_documentation/versions/cloning-orgs/ + +## Install + +From a clone of this repository: + +```bash +uv sync --all-extras +``` + +This pulls in the `clone` extra (`click`, `rich`) needed by the CLI. + +## Quickstart + +```bash +UNSTRACT_SRC_PLATFORM_KEY=src_pk_... \ +UNSTRACT_TGT_PLATFORM_KEY=tgt_pk_... \ +uv run python -m unstract.clone clone \ + --source-url https://source.example.com \ + --source-org my-source-org \ + --target-url https://target.example.com \ + --target-org my-target-org +``` + +Both keys must be **org admin Platform API keys**. + +> [!WARNING] +> Both keys grant broad access. Run from a trusted machine and rotate both keys after the clone completes. + +> [!NOTE] +> **Unstract Cloud free-trial adapters are not cloned.** Trial adapters are platform-owned and filtered out of the source listing. Prompt Studio projects whose default profile references them are skipped, and that cascades to dependent workflows, API deployments, and pipelines. Provision your own adapters on the target org and re-run the clone to bring the rest across. + +> [!NOTE] +> **OAuth-backed connectors need re-authorisation on target.** Connectors that use OAuth (e.g. Google Drive) are cloned without their refresh tokens — the Platform API never exposes them. Re-connect each one on the target after the clone. + +## Re-runs are safe + +If a phase fails partway, fix the cause and re-run the same command. Resources already on the target are detected by name and reused. There is no `--resume-from` flag — the target is the state. + +## Files + +The Prompt Studio document corpus is the only resource type with bytes on disk. Default cap per file is 25 MB; oversize files are reported for manual re-upload. Use `--skip-files` to skip bytes entirely (document records are still created). + +> [!WARNING] +> Run clones during low-activity windows. Concurrent uploads to the source org during a clone can create duplicate file records on the target. + +See the [public docs](https://docs.unstract.com/unstract/unstract_platform/api_documentation/versions/cloning-orgs/) for the full flag list, behavioral notes, and the format of the end-of-run report. diff --git a/src/unstract/clone/__init__.py b/src/unstract/clone/__init__.py new file mode 100644 index 0000000..c36300b --- /dev/null +++ b/src/unstract/clone/__init__.py @@ -0,0 +1,25 @@ +"""Cloning organizations over the Platform API. + +Migrates configured resources (adapters, connectors, custom tools, workflows, +etc.) from one Unstract org to another using two admin-issued Platform API +keys. The target deployment is the persistent state — re-runs reconcile +against existing target rows by natural key. +""" + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + OrgEndpoint, + RemapTable, +) +from unstract.clone.orchestrator import clone +from unstract.clone.report import CloneReport + +__all__ = [ + "CloneContext", + "CloneOptions", + "CloneReport", + "OrgEndpoint", + "RemapTable", + "clone", +] diff --git a/src/unstract/clone/__main__.py b/src/unstract/clone/__main__.py new file mode 100644 index 0000000..2eef45c --- /dev/null +++ b/src/unstract/clone/__main__.py @@ -0,0 +1,6 @@ +"""Entry point: ``python -m unstract.clone``.""" + +from unstract.clone.cli import main + +if __name__ == "__main__": + main() diff --git a/src/unstract/clone/cli.py b/src/unstract/clone/cli.py new file mode 100644 index 0000000..d2ed358 --- /dev/null +++ b/src/unstract/clone/cli.py @@ -0,0 +1,212 @@ +"""Click-based CLI for ``unstract.clone``. + +Single ``clone`` command. Platform keys can be passed via flags +(``--source-key`` / ``--target-key``) or env vars +(``UNSTRACT_SRC_PLATFORM_KEY`` / ``UNSTRACT_TGT_PLATFORM_KEY``) — env vars +are preferred so the key never lands in shell history. +""" + +from __future__ import annotations + +import logging +import re +import sys +from typing import Any + +import click + +from unstract.clone.context import ( + DEFAULT_CONCURRENCY, + DEFAULT_MAX_FILE_SIZE, + CloneOptions, + OrgEndpoint, +) +from unstract.clone.exceptions import CloneError +from unstract.clone.orchestrator import clone as run_clone + +_SIZE_UNITS: dict[str, int] = { + "B": 1, + "K": 1024, + "KB": 1024, + "M": 1024 * 1024, + "MB": 1024 * 1024, + "G": 1024 * 1024 * 1024, + "GB": 1024 * 1024 * 1024, +} +_SIZE_RE = re.compile(r"^\s*(\d+(?:\.\d+)?)\s*([A-Za-z]*)\s*$") + + +def _parse_size(value: str) -> int: + """Accept ``25``, ``25MB``, ``1.5GB`` etc. Returns bytes.""" + m = _SIZE_RE.match(value) + if not m: + raise click.BadParameter(f"can't parse size '{value}'") + num, unit = m.group(1), m.group(2).upper() or "B" + if unit not in _SIZE_UNITS: + raise click.BadParameter( + f"unknown size unit '{unit}'; use one of {sorted(_SIZE_UNITS)}" + ) + return int(float(num) * _SIZE_UNITS[unit]) + + +def _configure_logging(verbose: bool) -> None: + level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig( + level=level, + format="%(asctime)s %(levelname)-7s %(name)s: %(message)s", + datefmt="%H:%M:%S", + ) + + +def _split_csv(value: str | None) -> tuple[str, ...] | None: + if not value: + return None + return tuple(p.strip() for p in value.split(",") if p.strip()) + + +@click.group() +def cli() -> None: + """Cloning organizations over the Platform API.""" + + +@cli.command("clone") +@click.option("--source-url", required=True, help="Base URL of the source deployment") +@click.option( + "--source-org", required=True, help="Source organization_id (slug in the URL path)" +) +@click.option( + "--source-key", + envvar="UNSTRACT_SRC_PLATFORM_KEY", + required=True, + help="Source admin's Platform API key (or env UNSTRACT_SRC_PLATFORM_KEY)", +) +@click.option("--target-url", required=True, help="Base URL of the target deployment") +@click.option( + "--target-org", required=True, help="Target organization_id (slug in the URL path)" +) +@click.option( + "--target-key", + envvar="UNSTRACT_TGT_PLATFORM_KEY", + required=True, + help="Target admin's Platform API key (or env UNSTRACT_TGT_PLATFORM_KEY)", +) +@click.option( + "--dry-run", is_flag=True, help="Plan only — do not POST anything to target" +) +@click.option( + "--include", + default=None, + help="Comma-separated phase names to include (default: all)", +) +@click.option( + "--exclude", + default=None, + help="Comma-separated phase names to exclude", +) +@click.option( + "--on-name-conflict", + type=click.Choice(["adopt", "abort"]), + default="adopt", + show_default=True, + help="What to do when a like-named entity exists in target", +) +@click.option( + "--api-prefix", + default="api/v1", + show_default=True, + help="Backend URL prefix (matches deployment's PATH_PREFIX env)", +) +@click.option( + "--file-strategy", + type=click.Choice(["platform_api", "skip"]), + default="platform_api", + show_default=True, + help="How to move Prompt Studio document files. 'skip' = metadata only.", +) +@click.option( + "--max-file-size", + default="25MB", + show_default=True, + help="Per-file cap for the files phase. Oversize → reported, not aborted.", +) +@click.option( + "--skip-files", + is_flag=True, + help="Alias for --file-strategy=skip.", +) +@click.option( + "--concurrency", + type=click.IntRange(min=1, max=32), + default=DEFAULT_CONCURRENCY, + show_default=True, + help="Per-phase worker count. 1 = strictly sequential.", +) +@click.option("-v", "--verbose", is_flag=True, help="Debug logging") +def clone_cmd( + source_url: str, + source_org: str, + source_key: str, + target_url: str, + target_org: str, + target_key: str, + dry_run: bool, + include: str | None, + exclude: str | None, + on_name_conflict: str, + api_prefix: str, + file_strategy: str, + max_file_size: str, + skip_files: bool, + concurrency: int, + verbose: bool, +) -> None: + """Clone configured resources from one org to another.""" + _configure_logging(verbose) + + effective_strategy = "skip" if skip_files else file_strategy + try: + cap_bytes = _parse_size(max_file_size) + except click.BadParameter as e: + raise click.UsageError(str(e)) from e + + options = CloneOptions( + dry_run=dry_run, + include=_split_csv(include), + exclude=_split_csv(exclude) or (), + on_name_conflict=on_name_conflict, + verbose=verbose, + file_strategy=effective_strategy, + max_file_size=cap_bytes if cap_bytes is not None else DEFAULT_MAX_FILE_SIZE, + concurrency=concurrency, + ) + + source = OrgEndpoint( + base_url=source_url, + organization_id=source_org, + platform_key=source_key, + api_path_prefix=api_prefix, + ) + target = OrgEndpoint( + base_url=target_url, + organization_id=target_org, + platform_key=target_key, + api_path_prefix=api_prefix, + ) + + try: + report = run_clone(source, target, options) + except CloneError as e: + click.echo(f"Clone failed: {e}", err=True) + sys.exit(2) + + click.echo(report.render()) + if report.aborted or any(p.failed for p in report.phases): + sys.exit(1) + + +def main(argv: list[str] | None = None) -> Any: + return cli(args=argv, standalone_mode=True) + + +if __name__ == "__main__": + main() diff --git a/src/unstract/clone/client.py b/src/unstract/clone/client.py new file mode 100644 index 0000000..ad873da --- /dev/null +++ b/src/unstract/clone/client.py @@ -0,0 +1,502 @@ +"""Thin Platform API client for the clone subpackage. + +One ``PlatformClient`` instance per ``OrgEndpoint``. Methods are entity- +scoped (``list_adapters``, ``create_adapter``, ...) so call sites in phases +read like business logic, not HTTP plumbing. + +URL shape: ``{base_url}/{api_path_prefix}/unstract/{organization_id}//`` +Auth: ``Authorization: Bearer ``. +""" + +from __future__ import annotations + +import json as json_lib +import logging +from typing import Any + +import requests + +from unstract.clone.context import OrgEndpoint +from unstract.clone.exceptions import PlatformAPIError + +logger = logging.getLogger(__name__) + +DEFAULT_TIMEOUT = 60 + + +class PlatformClient: + """HTTP client scoped to a single org via its Platform API key.""" + + def __init__( + self, endpoint: OrgEndpoint, timeout: int = DEFAULT_TIMEOUT, verify: bool = True + ): + self.endpoint = endpoint + self.timeout = timeout + self.verify = verify + self._session = requests.Session() + self._session.headers.update( + { + "Authorization": f"Bearer {endpoint.platform_key}", + "Accept": "application/json", + } + ) + # Cache the OPTIONS-derived writable-field set per entity path. + # Backend serializer is the single source of truth; we read it once. + self._post_schema_cache: dict[str, frozenset[str]] = {} + + def close(self) -> None: + """Release the underlying HTTP connection pool.""" + self._session.close() + + def __enter__(self) -> "PlatformClient": + return self + + def __exit__(self, *exc: Any) -> None: + self.close() + + def _url(self, path: str) -> str: + base = self.endpoint.base_url.rstrip("/") + api_prefix = self.endpoint.api_path_prefix.strip("/") + prefix = f"/{api_prefix}/unstract/{self.endpoint.organization_id}/" + return base + prefix + path.lstrip("/") + + def _request( + self, + method: str, + path: str, + *, + params: dict[str, Any] | None = None, + json: Any = None, + files: dict[str, Any] | None = None, + data: dict[str, Any] | None = None, + ) -> Any: + url = self._url(path) + # Redact secrets from logs: only entity path + method, never body. + logger.debug("%s %s", method, url) + resp = self._session.request( + method, + url, + params=params, + json=json, + files=files, + data=data, + timeout=self.timeout, + verify=self.verify, + ) + if not 200 <= resp.status_code < 300: + raise PlatformAPIError( + f"{method} {path} returned {resp.status_code}", + status_code=resp.status_code, + body=resp.text[:2000], + ) + if resp.status_code == 204 or not resp.content: + return None + return resp.json() + + def get_post_schema(self, entity_path: str) -> frozenset[str]: + """Return the set of fields the backend's POST serializer accepts. + + Reads it from a DRF ``OPTIONS`` response (``actions.POST``) once + per path and caches the result. DRF ``SimpleMetadata`` already + excludes ``read_only`` fields from ``actions.POST``, so the + returned set is exactly the writable subset. + """ + cached = self._post_schema_cache.get(entity_path) + if cached is not None: + return cached + body = self._request("OPTIONS", entity_path) + actions = (body or {}).get("actions") or {} + post_block = actions.get("POST") or {} + writable = frozenset( + name for name, meta in post_block.items() if not meta.get("read_only") + ) + self._post_schema_cache[entity_path] = writable + return writable + + # ----- adapters ----- + + def list_adapters( + self, + *, + name: str | None = None, + adapter_type: str | None = None, + ) -> list[dict[str, Any]]: + """List adapters in this org, optionally filtered by name and/or type.""" + params: dict[str, Any] = {} + if name is not None: + params["adapter_name"] = name + if adapter_type is not None: + params["adapter_type"] = adapter_type + result = self._request("GET", "adapter/", params=params) + # DRF ModelViewSet.list returns a bare list (no pagination on this endpoint). + return result if isinstance(result, list) else result.get("results", []) + + def get_adapter(self, adapter_pk: str) -> dict[str, Any]: + return self._request("GET", f"adapter/{adapter_pk}/") + + def create_adapter(self, payload: dict[str, Any]) -> dict[str, Any]: + return self._request("POST", "adapter/", json=payload) + + # ----- connectors ----- + + def list_connectors( + self, + *, + name: str | None = None, + connector_type: str | None = None, + ) -> list[dict[str, Any]]: + """List connectors in this org, optionally filtered by name and/or type.""" + params: dict[str, Any] = {} + if name is not None: + params["connector_name"] = name + if connector_type is not None: + params["connector_type"] = connector_type + result = self._request("GET", "connector/", params=params) + return result if isinstance(result, list) else result.get("results", []) + + def get_connector(self, connector_pk: str) -> dict[str, Any]: + return self._request("GET", f"connector/{connector_pk}/") + + def create_connector(self, payload: dict[str, Any]) -> dict[str, Any]: + return self._request("POST", "connector/", json=payload) + + # ----- tags ----- + + def list_tags(self, *, name: str | None = None) -> list[dict[str, Any]]: + """List tags in this org, optionally filtered by exact name.""" + params: dict[str, Any] = {} + if name is not None: + params["name"] = name + result = self._request("GET", "tags/", params=params) + # Tags endpoint uses pagination — accept either bare list or paginated envelope. + return result if isinstance(result, list) else result.get("results", []) + + def create_tag(self, payload: dict[str, Any]) -> dict[str, Any]: + return self._request("POST", "tags/", json=payload) + + # ----- custom tools (prompt studio) ----- + + def list_custom_tools(self) -> list[dict[str, Any]]: + """List all prompt-studio projects in this org. No name filter.""" + result = self._request("GET", "prompt-studio/") + return result if isinstance(result, list) else result.get("results", []) + + def get_custom_tool(self, tool_id: str) -> dict[str, Any]: + """Fetch a single prompt-studio project (full serializer). + + Returns ``fields = "__all__"`` per ``CustomToolSerializer`` — + notably includes ``output`` (the default DocumentManager id the + FE binds to ``selectedDoc`` on load). + """ + return self._request("GET", f"prompt-studio/{tool_id}/") + + def update_custom_tool(self, tool_id: str, body: dict[str, Any]) -> dict[str, Any]: + """PATCH a prompt-studio project. Used to set ``output`` (the + default doc id) after the files phase populates DM rows.""" + return self._request("PATCH", f"prompt-studio/{tool_id}/", json=body) + + def list_profiles(self, tool_id: str) -> list[dict[str, Any]]: + """List ProfileManager rows for a tool. + + The clone reads this on the source only — to discover the + default profile's adapter UUIDs so they can be remapped to + target adapter ids for ``import_project``. + """ + result = self._request("GET", f"prompt-studio/prompt-studio-profile/{tool_id}/") + return result if isinstance(result, list) else result.get("results", []) + + def export_project(self, tool_id: str) -> dict[str, Any]: + """Export a prompt-studio project as a portable JSON blob. + + Bundles ``tool_metadata``, ``tool_settings``, + ``default_profile_settings``, ``prompts``, ``export_metadata`` in + one shot — feed straight into ``import_project`` or + ``sync_prompts`` on the target. + """ + return self._request("GET", f"prompt-studio/project-transfer/{tool_id}") + + def import_project( + self, + export_data: dict[str, Any], + adapter_ids: dict[str, str | None] | None = None, + ) -> dict[str, Any]: + """Import a prompt-studio project from an export blob. + + Backend creates the tool, builds the default ProfileManager from + the supplied target-org adapter ids, and imports all prompts in + one call. On name collision the backend silently uniquifies the + new tool's name — callers should pre-check via + ``list_custom_tools`` to avoid that. + + ``adapter_ids`` keys are the backend's form fields: + ``llm_adapter_id``, ``vector_db_adapter_id``, + ``embedding_adapter_id``, ``x2text_adapter_id``. All four + required to wire the profile; otherwise backend falls back to + a profile without adapters and flags ``needs_adapter_config``. + """ + tool_name = export_data.get("tool_metadata", {}).get("tool_name") or "export" + content = json_lib.dumps(export_data).encode() + files = {"file": (f"{tool_name}.json", content, "application/json")} + data: dict[str, Any] = {} + if adapter_ids: + for key in ( + "llm_adapter_id", + "vector_db_adapter_id", + "embedding_adapter_id", + "x2text_adapter_id", + ): + val = adapter_ids.get(key) + if val: + data[key] = val + return self._request( + "POST", + "prompt-studio/project-transfer/", + files=files, + data=data, + ) + + def sync_prompts( + self, + tool_id: str, + export_data: dict[str, Any], + *, + create_copy: bool = False, + ) -> dict[str, Any]: + """Rip-and-replace prompts on an existing target tool. + + Adopt path: target tool already exists with its own + adapter-bound profiles. This overwrites its prompt set (and + ``tool_settings``) from source; profiles and uploaded documents + are left untouched. + """ + payload = {"data": export_data, "create_copy": create_copy} + return self._request( + "POST", f"prompt-studio/{tool_id}/sync-prompts/", json=payload + ) + + def list_prompt_documents(self, tool_id: str) -> list[dict[str, Any]]: + """List DocumentManager rows for a tool. + + Used by FilesPhase for target-side idempotency and source-side + enumeration. Response items carry ``document_id``, + ``document_name``, and ``tool`` (per the serializer's + ``to_representation`` filter). + """ + result = self._request( + "GET", "prompt-studio/prompt-document/", params={"tool_id": tool_id} + ) + return result if isinstance(result, list) else result.get("results", []) + + def download_prompt_file(self, tool_id: str, document_id: str) -> dict[str, Any]: + """GET a Prompt Studio document by tool + DM row id. + + ``fetch_contents_ide`` resolves the filename internally from the + DocumentManager row, so the SDK passes the ``document_id`` it + already has from ``list_prompt_documents`` rather than reposting + the filename. Returns ``{"data": ..., "mime_type": ...}`` — + PDFs base64, text/csv utf-8, Excel placeholder. + """ + return self._request( + "GET", + f"prompt-studio/file/{tool_id}", + params={"document_id": document_id}, + ) + + def upload_prompt_file( + self, + tool_id: str, + file_name: str, + data: bytes, + mime_type: str, + ) -> dict[str, Any]: + """Upload a file into a target Prompt Studio tool. + + Backend writes bytes to storage and creates a ``DocumentManager`` + row. The DM model has ``UniqueConstraint(document_name, tool)``, + so callers must pre-check via ``list_prompt_documents`` to avoid + an IntegrityError → 500 on re-runs. + """ + files = {"file": (file_name, data, mime_type)} + return self._request("POST", f"prompt-studio/file/{tool_id}", files=files) + + def export_custom_tool(self, tool_id: str, *, force: bool = True) -> Any: + """Republish ``PromptStudioRegistry`` from the tool's current state. + + Called after import/sync so the registry row reflects the + freshly landed prompts. Required for ToolInstancePhase to find + a target registry id to remap. + """ + return self._request( + "POST", + f"prompt-studio/export/{tool_id}", + json={ + "is_shared_with_org": False, + "user_id": [], + "force_export": force, + }, + ) + + # ----- workflows ----- + + def list_workflows(self, *, name: str | None = None) -> list[dict[str, Any]]: + """List workflows in this org, optionally filtered by exact name.""" + params: dict[str, Any] = {} + if name is not None: + params["workflow_name"] = name + result = self._request("GET", "workflow/", params=params) + return result if isinstance(result, list) else result.get("results", []) + + def get_workflow(self, workflow_id: str) -> dict[str, Any]: + return self._request("GET", f"workflow/{workflow_id}/") + + def create_workflow(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create a workflow. Backend auto-creates empty WorkflowEndpoints for it.""" + return self._request("POST", "workflow/", json=payload) + + # ----- prompt studio registry ----- + + def list_registries( + self, *, custom_tool: str | None = None + ) -> list[dict[str, Any]]: + """List PromptStudioRegistry rows. The list endpoint returns nothing + unless a filter is supplied; pass ``custom_tool`` to look up the + registry id for a given tool. + """ + params: dict[str, Any] = {} + if custom_tool is not None: + params["custom_tool"] = custom_tool + result = self._request("GET", "prompt-studio/registry/", params=params) + return result if isinstance(result, list) else result.get("results", []) + + # ----- tool instances ----- + + def list_tool_instances( + self, *, workflow_id: str | None = None + ) -> list[dict[str, Any]]: + """List ToolInstance rows, optionally scoped to a workflow.""" + params: dict[str, Any] = {} + if workflow_id is not None: + params["workflow"] = workflow_id + result = self._request("GET", "tool_instance/", params=params) + return result if isinstance(result, list) else result.get("results", []) + + def create_tool_instance(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create a tool instance (max 1 per workflow). The backend overwrites + the ``metadata`` field with tool defaults — caller must PATCH after + create to transfer source metadata. + """ + return self._request("POST", "tool_instance/", json=payload) + + def update_tool_instance_metadata( + self, instance_id: str, metadata: dict[str, Any] + ) -> dict[str, Any]: + """PATCH a tool instance's metadata. Backend resolves adapter names + in the payload to local UUIDs via ``update_instance_metadata``. + """ + return self._request( + "PATCH", f"tool_instance/{instance_id}/", json={"metadata": metadata} + ) + + # ----- workflow endpoints ----- + + def list_workflow_endpoints( + self, *, workflow_id: str | None = None + ) -> list[dict[str, Any]]: + """List workflow endpoints, optionally filtered by workflow id. + + The backend auto-creates one SOURCE and one DESTINATION endpoint + per workflow, so a workflow filter typically returns exactly two + rows. + """ + params: dict[str, Any] = {} + if workflow_id is not None: + params["workflow"] = workflow_id + result = self._request("GET", "workflow/endpoint/", params=params) + return result if isinstance(result, list) else result.get("results", []) + + def update_workflow_endpoint( + self, endpoint_id: str, payload: dict[str, Any] + ) -> dict[str, Any]: + return self._request("PATCH", f"workflow/endpoint/{endpoint_id}/", json=payload) + + # ----- pipelines (ETL / TASK) ----- + + def list_pipelines( + self, + *, + name: str | None = None, + pipeline_type: str | None = None, + ) -> list[dict[str, Any]]: + """List pipelines in this org, optionally filtered by exact name + and/or pipeline_type (``ETL`` / ``TASK`` / ``APP``). + """ + params: dict[str, Any] = {} + if name is not None: + params["pipeline_name"] = name + if pipeline_type is not None: + params["type"] = pipeline_type + result = self._request("GET", "pipeline/", params=params) + return result if isinstance(result, list) else result.get("results", []) + + def get_pipeline(self, pipeline_id: str) -> dict[str, Any]: + return self._request("GET", f"pipeline/{pipeline_id}/") + + def create_pipeline(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create a pipeline. Backend force-sets ``active=True`` and auto-creates + a single active API key on the new pipeline. + """ + return self._request("POST", "pipeline/", json=payload) + + def update_pipeline( + self, pipeline_id: str, payload: dict[str, Any] + ) -> dict[str, Any]: + return self._request("PATCH", f"pipeline/{pipeline_id}/", json=payload) + + # ----- API deployments ----- + + def list_api_deployments( + self, + *, + api_name: str | None = None, + ) -> list[dict[str, Any]]: + """List API deployments in this org, optionally filtered by exact api_name.""" + params: dict[str, Any] = {} + if api_name is not None: + params["api_name"] = api_name + result = self._request("GET", "api/deployment/", params=params) + return result if isinstance(result, list) else result.get("results", []) + + def get_api_deployment(self, deployment_id: str) -> dict[str, Any]: + return self._request("GET", f"api/deployment/{deployment_id}/") + + def create_api_deployment(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create an API deployment. Backend auto-creates a single active key + and returns it in the response under ``api_key``. + """ + return self._request("POST", "api/deployment/", json=payload) + + def update_api_deployment( + self, deployment_id: str, payload: dict[str, Any] + ) -> dict[str, Any]: + return self._request("PATCH", f"api/deployment/{deployment_id}/", json=payload) + + # ----- API keys (per pipeline / deployment) ----- + + def list_pipeline_keys(self, pipeline_id: str) -> list[dict[str, Any]]: + """List API keys belonging to a pipeline.""" + result = self._request("GET", f"api/keys/pipeline/{pipeline_id}/") + return result if isinstance(result, list) else result.get("results", []) + + def list_api_deployment_keys(self, deployment_id: str) -> list[dict[str, Any]]: + """List API keys belonging to an API deployment.""" + result = self._request("GET", f"api/keys/api/{deployment_id}/") + return result if isinstance(result, list) else result.get("results", []) + + def create_api_key(self, payload: dict[str, Any]) -> dict[str, Any]: + """Create an extra API key tied to a pipeline or deployment. + + Used to mirror non-default keys (e.g. an additional rotated key) + on the target. The ``api_key`` UUID itself is server-generated + and cannot be carried over from source. + """ + return self._request("POST", "api/keys/api/", json=payload) diff --git a/src/unstract/clone/context.py b/src/unstract/clone/context.py new file mode 100644 index 0000000..833668f --- /dev/null +++ b/src/unstract/clone/context.py @@ -0,0 +1,109 @@ +"""Shared state passed between clone phases. + +Three top-level types: + +- ``OrgEndpoint`` — base URL + organization_id + Platform API key for one org. +- ``CloneOptions`` — run flags (dry-run, include/exclude, name-conflict). +- ``CloneContext`` — bundles source/target clients, options, and the + per-run ``RemapTable``. + +``RemapTable`` lives here too because every phase touches it. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from unstract.clone.client import PlatformClient + + +@dataclass(frozen=True) +class OrgEndpoint: + """One end of a clone: where to talk to and who to talk as. + + ``organization_id`` is the slug embedded in the URL path; the bearer + Platform API key must belong to this org. ``api_path_prefix`` matches + the deployment's URL prefix (defaults to ``api/v1``). + """ + + base_url: str + organization_id: str + platform_key: str + api_path_prefix: str = "api/v1" + + +DEFAULT_MAX_FILE_SIZE = 25 * 1024 * 1024 # 25 MB; oversize → manual-upload list +DEFAULT_CONCURRENCY = 4 + + +@dataclass +class CloneOptions: + """Per-run flags for ``clone()``.""" + + dry_run: bool = False + include: tuple[str, ...] | None = None + exclude: tuple[str, ...] = () + on_name_conflict: str = "adopt" # "adopt" | "abort" + verbose: bool = False + # "platform_api": download/upload via existing endpoints (default). + # "skip": metadata only; operator re-uploads via UI on target. + file_strategy: str = "platform_api" + max_file_size: int = DEFAULT_MAX_FILE_SIZE + # Per-phase worker fan-out. 1 = sequential (no executor). + concurrency: int = DEFAULT_CONCURRENCY + + def includes(self, phase_name: str) -> bool: + if self.include is not None and phase_name not in self.include: + return False + return phase_name not in self.exclude + + +class RemapTable: + """Maps source UUID -> target UUID, scoped per entity type. + + Built up in dependency order; consumed by the JSON walker before POST. + ``resolve_any`` lets the walker look up a UUID without knowing its + entity type — necessary because embedded references in JSON payloads + don't always carry an entity hint. + """ + + def __init__(self) -> None: + self._table: dict[str, dict[str, str]] = {} + + def record(self, entity: str, src_uuid: str, tgt_uuid: str) -> None: + self._table.setdefault(entity, {})[src_uuid] = tgt_uuid + + def resolve(self, entity: str, src_uuid: str) -> str | None: + return self._table.get(entity, {}).get(src_uuid) + + def resolve_any(self, src_uuid: str) -> str | None: + # Snapshot to avoid `RuntimeError: dictionary changed size during + # iteration` when a concurrent record() inserts a new entity bucket. + for mapping in list(self._table.values()): + hit = mapping.get(src_uuid) + if hit is not None: + return hit + return None + + def snapshot(self) -> dict[str, dict[str, str]]: + """Read-only snapshot for the post-run report.""" + return {entity: dict(m) for entity, m in self._table.items()} + + +@dataclass +class CloneContext: + """Shared state for one ``clone()`` invocation. + + Phases hold a reference to this and call ``ctx.source`` / ``ctx.target`` + to drive HTTP, ``ctx.remap`` to record UUID mappings. + """ + + source: PlatformClient + target: PlatformClient + options: CloneOptions + remap: RemapTable = field(default_factory=RemapTable) + # Source prompt_registry_ids whose CustomTool was skipped; used to + # cascade-skip dependent workflows downstream. + skipped_custom_tool_registry_ids: set[str] = field(default_factory=set) diff --git a/src/unstract/clone/exceptions.py b/src/unstract/clone/exceptions.py new file mode 100644 index 0000000..3933c1c --- /dev/null +++ b/src/unstract/clone/exceptions.py @@ -0,0 +1,25 @@ +"""Exceptions raised by the clone subpackage.""" + + +class CloneError(Exception): + """Base class for all clone errors.""" + + +class PlatformAPIError(CloneError): + """Raised when the Platform API returns a non-2xx response we can't recover from.""" + + def __init__( + self, message: str, status_code: int | None = None, body: str | None = None + ): + full_message = f"{message}\n body: {body}" if body else message + super().__init__(full_message) + self.status_code = status_code + self.body = body + + +class NameConflictError(CloneError): + """Raised when ``on_name_conflict='abort'`` and the target has a like-named entity.""" + + +class DependencyMissingError(CloneError): + """Raised when a phase references a source UUID that no prior phase has mapped.""" diff --git a/src/unstract/clone/orchestrator.py b/src/unstract/clone/orchestrator.py new file mode 100644 index 0000000..a7c81a0 --- /dev/null +++ b/src/unstract/clone/orchestrator.py @@ -0,0 +1,115 @@ +"""Top-level ``clone()`` entry point. + +Wires source/target ``PlatformClient`` instances, builds a +``CloneContext``, runs each phase in strict topological order, and +returns a ``CloneReport``. + +Phase order is owned here — phases must not call each other. Adding a new +entity type means: write a new ``Phase`` subclass and append it to +``PHASES`` at the right dependency position. +""" + +from __future__ import annotations + +import logging +import time + +from unstract.clone.client import PlatformClient +from unstract.clone.context import CloneContext, CloneOptions, OrgEndpoint +from unstract.clone.exceptions import CloneError +from unstract.clone.phases import ( + AdapterPhase, + APIDeploymentPhase, + ConnectorPhase, + CustomToolPhase, + FilesPhase, + PipelinePhase, + TagPhase, + ToolInstancePhase, + WorkflowEndpointPhase, + WorkflowPhase, +) +from unstract.clone.phases.base import Phase +from unstract.clone.report import CloneReport, Endpoint + +logger = logging.getLogger(__name__) + +# Strict dependency order. Each entry: (phase_name, phase_class). +# Adapter, connector, tag are independent leaf phases. Downstream phases +# (custom_tool, workflow, tool_instance, workflow_endpoint) land later +# and consume the remap entries these produce. Pipeline + api_deployment +# come last: both FK the workflow and api_deployment additionally +# requires endpoints to be configured before the serializer accepts it. +PHASES: list[tuple[str, type[Phase]]] = [ + ("adapter", AdapterPhase), + ("connector", ConnectorPhase), + ("tag", TagPhase), + ("custom_tool", CustomToolPhase), + ("files", FilesPhase), + ("workflow", WorkflowPhase), + ("tool_instance", ToolInstancePhase), + ("workflow_endpoint", WorkflowEndpointPhase), + ("pipeline", PipelinePhase), + ("api_deployment", APIDeploymentPhase), +] + + +def clone( + source: OrgEndpoint, + target: OrgEndpoint, + options: CloneOptions | None = None, +) -> CloneReport: + """Migrate configured resources from one org to another. + + Returns a ``CloneReport`` even on partial failure; raises only on + setup errors or ``on_name_conflict='abort'`` collisions. + """ + opts = options or CloneOptions() + src_client = PlatformClient(source) + tgt_client = PlatformClient(target) + try: + ctx = CloneContext( + source=src_client, + target=tgt_client, + options=opts, + ) + report = CloneReport( + source=Endpoint( + base_url=source.base_url, organization_id=source.organization_id + ), + target=Endpoint( + base_url=target.base_url, organization_id=target.organization_id + ), + ) + + run_started = time.perf_counter() + for name, phase_cls in PHASES: + if not opts.includes(name): + report.skipped_phases.append(name) + logger.info("Phase '%s' skipped (excluded)", name) + continue + logger.info("=== Phase: %s ===", name) + phase_started = time.perf_counter() + try: + phase_cls(ctx).run(report) + except CloneError as e: + report.aborted = True + report.abort_reason = str(e) + logger.error("Phase '%s' aborted: %s", name, e) + # Stamp duration even on abort so the report reflects time spent. + report.get_phase(name).duration_s = time.perf_counter() - phase_started + break + else: + report.get_phase(name).duration_s = time.perf_counter() - phase_started + logger.info( + "=== Phase '%s' done in %.2fs ===", + name, + report.get_phase(name).duration_s, + ) + + report.total_duration_s = time.perf_counter() - run_started + report.remap_snapshot = ctx.remap.snapshot() + return report + finally: + src_client.close() + tgt_client.close() diff --git a/src/unstract/clone/phases/__init__.py b/src/unstract/clone/phases/__init__.py new file mode 100644 index 0000000..03f0952 --- /dev/null +++ b/src/unstract/clone/phases/__init__.py @@ -0,0 +1,34 @@ +"""Per-entity clone phases. + +Each phase implements ``run(report)``, uses ``ctx.source`` / ``ctx.target`` +to drive HTTP, records ``ctx.remap`` entries for downstream phases. + +Dependency order is owned by ``orchestrator.clone`` — phases must NOT +call each other directly. +""" + +from unstract.clone.phases.adapter import AdapterPhase +from unstract.clone.phases.api_deployment import APIDeploymentPhase +from unstract.clone.phases.base import Phase +from unstract.clone.phases.connector import ConnectorPhase +from unstract.clone.phases.custom_tool import CustomToolPhase +from unstract.clone.phases.files import FilesPhase +from unstract.clone.phases.pipeline import PipelinePhase +from unstract.clone.phases.tag import TagPhase +from unstract.clone.phases.tool_instance import ToolInstancePhase +from unstract.clone.phases.workflow import WorkflowPhase +from unstract.clone.phases.workflow_endpoint import WorkflowEndpointPhase + +__all__ = [ + "APIDeploymentPhase", + "AdapterPhase", + "ConnectorPhase", + "CustomToolPhase", + "FilesPhase", + "Phase", + "PipelinePhase", + "TagPhase", + "ToolInstancePhase", + "WorkflowEndpointPhase", + "WorkflowPhase", +] diff --git a/src/unstract/clone/phases/adapter.py b/src/unstract/clone/phases/adapter.py new file mode 100644 index 0000000..522629f --- /dev/null +++ b/src/unstract/clone/phases/adapter.py @@ -0,0 +1,125 @@ +"""Migrate adapters from source org to target org. + +Reference implementation for the get-or-create pattern: list-by-name GET +against target, POST create if missing, record source->target UUID in the +remap table for downstream phases. + +Frictionless onboarding adapters are excluded — the backend's +service-account queryset already filters them out, so clone never +sees them. +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any + +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.base import Phase, build_post_payload +from unstract.clone.report import CloneReport, PhaseResult + +logger = logging.getLogger(__name__) + +ADAPTER_PATH = "adapter/" + + +class AdapterPhase(Phase): + name = "adapter" + + def run(self, report: CloneReport) -> PhaseResult: + result = report.get_phase(self.name) + try: + self._writable = self.ctx.target.get_post_schema(ADAPTER_PATH) + except Exception as e: + logger.exception("Failed to fetch target POST schema for adapter: %s", e) + result.failed += 1 + result.errors.append(f"OPTIONS adapter: {e}") + return result + try: + src_summaries = self.ctx.source.list_adapters() + except Exception as e: + logger.exception("Failed to list source adapters: %s", e) + result.failed += 1 + result.errors.append(f"list source adapters: {e}") + return result + + logger.info("Found %d adapter(s) in source org", len(src_summaries)) + self.parallel_map( + src_summaries, + lambda summary, lock: self._clone_one(summary, result, lock), + ) + return result + + def _clone_one( + self, summary: dict[str, Any], result: PhaseResult, lock: threading.Lock + ) -> None: + name = summary["adapter_name"] + atype = summary["adapter_type"] + src_id = summary["id"] + try: + src = self.ctx.source.get_adapter(src_id) + except Exception as e: + logger.exception( + "Failed to GET source adapter %s [%s] detail: %s", name, atype, e + ) + with lock: + result.failed += 1 + result.errors.append(f"GET source detail {name} [{atype}]: {e}") + return + + try: + existing = self.ctx.target.list_adapters(name=name, adapter_type=atype) + except Exception as e: + logger.exception( + "Failed to GET adapter %s [%s] on target: %s", name, atype, e + ) + with lock: + result.failed += 1 + result.errors.append(f"GET {name} [{atype}]: {e}") + return + + if existing: + tgt = existing[0] + if self.ctx.options.on_name_conflict == "abort": + raise NameConflictError( + f"adapter '{name}' [{atype}] already exists in target as {tgt['id']}" + ) + with lock: + result.adopted += 1 + logger.info( + "adopted adapter '%s' [%s] src=%s -> tgt=%s", + name, + atype, + src_id, + tgt["id"], + ) + elif self.ctx.options.dry_run: + with lock: + result.skipped += 1 + logger.info( + "[dry-run] would create adapter '%s' [%s] src=%s", name, atype, src_id + ) + return + else: + payload = build_post_payload(src, self._writable) + try: + tgt = self.ctx.target.create_adapter(payload) + except Exception as e: + logger.exception("Failed to create adapter %s [%s]: %s", name, atype, e) + with lock: + result.failed += 1 + result.errors.append(f"create {name} [{atype}]: {e}") + return + with lock: + result.created += 1 + logger.info( + "created adapter '%s' [%s] src=%s -> tgt=%s", + name, + atype, + src_id, + tgt["id"], + ) + + with lock: + self.ctx.remap.record("adapter", src_id, tgt["id"]) diff --git a/src/unstract/clone/phases/api_deployment.py b/src/unstract/clone/phases/api_deployment.py new file mode 100644 index 0000000..df55983 --- /dev/null +++ b/src/unstract/clone/phases/api_deployment.py @@ -0,0 +1,175 @@ +"""Migrate API deployments from source org to target org. + +APIDeployment FKs ``workflow`` — remap via the WorkflowPhase table. +Backend enforces one active deployment per workflow and one +``api_name`` per organization, so adopt-by-name is the only safe +re-run strategy. + +On create the backend auto-provisions a single active API key and +returns it on the response. Extra rotated keys on the source are NOT +mirrored (server-generated UUIDs can't be preserved; rotate +post-clone). +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any + +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.base import Phase, build_post_payload +from unstract.clone.report import CloneReport, PhaseResult +from unstract.clone.walker import remap_uuids + +logger = logging.getLogger(__name__) + +API_DEPLOYMENT_PATH = "api/deployment/" + + +class APIDeploymentPhase(Phase): + name = "api_deployment" + + def run(self, report: CloneReport) -> PhaseResult: + result = report.get_phase(self.name) + try: + self._writable = self.ctx.target.get_post_schema(API_DEPLOYMENT_PATH) + except Exception as e: + logger.exception( + "Failed to fetch target POST schema for api_deployment: %s", e + ) + result.failed += 1 + result.errors.append(f"OPTIONS api_deployment: {e}") + return result + + try: + src_deployments = self.ctx.source.list_api_deployments() + except Exception as e: + logger.exception("Failed to list source api_deployments: %s", e) + result.failed += 1 + result.errors.append(f"list source api_deployments: {e}") + return result + + logger.info("Found %d source API deployment(s)", len(src_deployments)) + self.parallel_map( + src_deployments, + lambda src, lock: self._clone_one(src, result, lock), + ) + return result + + def _clone_one( + self, src: dict[str, Any], result: PhaseResult, lock: threading.Lock + ) -> None: + api_name = src["api_name"] + src_id = src["id"] + src_wf_id = src.get("workflow") or src.get("workflow_id") + + if not src_wf_id: + logger.warning( + "source api_deployment '%s' has no workflow FK — skipping", api_name + ) + with lock: + result.skipped += 1 + return + + with lock: + tgt_wf_id = self.ctx.remap.resolve("workflow", src_wf_id) + if not tgt_wf_id: + logger.warning( + "no workflow remap for api_deployment '%s' (src workflow %s) — skipping", + api_name, + src_wf_id, + ) + with lock: + result.skipped += 1 + return + + try: + existing = self.ctx.target.list_api_deployments(api_name=api_name) + except Exception as e: + logger.exception( + "Failed to GET api_deployment %s on target: %s", api_name, e + ) + with lock: + result.failed += 1 + result.errors.append(f"GET {api_name}: {e}") + return + + if existing: + tgt = existing[0] + if self.ctx.options.on_name_conflict == "abort": + raise NameConflictError( + f"api_deployment '{api_name}' already exists in target as {tgt['id']}" + ) + with lock: + result.adopted += 1 + logger.info( + "adopted api_deployment '%s' src=%s -> tgt=%s", + api_name, + src_id, + tgt["id"], + ) + elif self.ctx.options.dry_run: + with lock: + result.skipped += 1 + logger.info( + "[dry-run] would create api_deployment '%s' src=%s", api_name, src_id + ) + return + else: + try: + full_src = self.ctx.source.get_api_deployment(src_id) + except Exception as e: + logger.exception( + "Failed to GET source api_deployment %s: %s", api_name, e + ) + with lock: + result.failed += 1 + result.errors.append(f"GET src api_deployment {api_name}: {e}") + return + remapped = remap_uuids(full_src, self.ctx.remap) + payload = build_post_payload(remapped, self._writable) + payload["workflow"] = tgt_wf_id + try: + tgt = self.ctx.target.create_api_deployment(payload) + except Exception as e: + logger.exception("Failed to create api_deployment %s: %s", api_name, e) + with lock: + result.failed += 1 + result.errors.append(f"create {api_name}: {e}") + return + with lock: + result.created += 1 + logger.info( + "created api_deployment '%s' src=%s -> tgt=%s", + api_name, + src_id, + tgt["id"], + ) + self._warn_if_extra_source_keys(src_id, api_name) + + with lock: + self.ctx.remap.record("api_deployment", src_id, tgt["id"]) + + def _warn_if_extra_source_keys(self, src_deployment_id: str, name: str) -> None: + try: + keys = self.ctx.source.list_api_deployment_keys(src_deployment_id) + except Exception as e: + # WARNING (not DEBUG) — the operator needs to know we couldn't + # check whether they have additional keys to recreate manually. + logger.warning( + "Could not list source keys for api_deployment %s " + "(extra-key check skipped; re-verify in source UI): %s", + name, + e, + ) + return + active = [k for k in keys if k.get("is_active")] + if len(active) > 1: + logger.warning( + "source api_deployment '%s' had %d active API keys; " + "target has only the auto-provisioned default — " + "re-create the rest manually if your clients depend on them", + name, + len(active), + ) diff --git a/src/unstract/clone/phases/base.py b/src/unstract/clone/phases/base.py new file mode 100644 index 0000000..e83e00e --- /dev/null +++ b/src/unstract/clone/phases/base.py @@ -0,0 +1,115 @@ +"""Base class for clone phases.""" + +from __future__ import annotations + +import logging +import threading +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable +from concurrent.futures import FIRST_EXCEPTION, Future, ThreadPoolExecutor, wait +from typing import Any, TypeVar + +from unstract.clone.context import CloneContext +from unstract.clone.exceptions import CloneError +from unstract.clone.report import CloneReport, PhaseResult + +T = TypeVar("T") + +logger = logging.getLogger(__name__) + +# DRF OPTIONS reports any ModelSerializer FK/M2M as writable, but the +# backend's perform_create overrides these server-side. Posting them is +# either noise (silently overwritten) or a 400 (when a source-org value +# doesn't validate against the target org). Strip them universally — +# the phase OPTIONS schema covers the entity-specific writable subset. +SERVER_MANAGED: frozenset[str] = frozenset( + { + "id", + "organization", + "created_by", + "created_by_email", + "modified_by", + "modified_by_email", + "created_at", + "modified_at", + "shared_users", + "workflow_owner", + } +) + + +def build_post_payload(src: dict[str, Any], writable: frozenset[str]) -> dict[str, Any]: + """Project ``src`` onto the writable schema, dropping server-managed + fields, ``None`` values, and empty strings (which DRF treats as blank + and rejects on required fields). + """ + keys = writable - SERVER_MANAGED + # Equality with `(None, "")` matched False and 0 too (Python: False == 0, + # 0 in (None, "") is False, but `0 not in (...)` falsely returns True). + # Explicit identity / equality checks preserve falsy-but-meaningful + # values like ``BooleanField`` False and numeric defaults. + return {k: src[k] for k in keys if k in src and src[k] is not None and src[k] != ""} + + +class Phase(ABC): + """Abstract phase. One subclass per entity type.""" + + name: str = "" + + def __init__(self, ctx: CloneContext): + self.ctx = ctx + + @abstractmethod + def run(self, report: CloneReport) -> PhaseResult: + """Migrate all entities of this phase's type. Idempotent across runs.""" + raise NotImplementedError + + def parallel_map( + self, + items: Iterable[T], + work_fn: Callable[[T, threading.Lock], None], + ) -> None: + """Fan ``work_fn(item, lock)`` across ``ctx.options.concurrency`` + threads. ``work_fn`` must hold ``lock`` while mutating shared + state. ``CloneError`` from any worker cancels the rest and + re-raises. ``concurrency <= 1`` skips the executor entirely. + """ + materialised = list(items) + if not materialised: + return + + concurrency = max(1, self.ctx.options.concurrency) + lock = threading.Lock() + + if concurrency == 1: + for item in materialised: + work_fn(item, lock) + return + + with ThreadPoolExecutor( + max_workers=concurrency, + thread_name_prefix=f"clone-{self.name}", + ) as pool: + futures: list[Future[None]] = [ + pool.submit(work_fn, item, lock) for item in materialised + ] + done, _ = wait(futures, return_when=FIRST_EXCEPTION) + clone_err: CloneError | None = None + other_err: BaseException | None = None + for fut in done: + if fut.cancelled(): + continue + exc = fut.exception() + if exc is None: + continue + if isinstance(exc, CloneError) and clone_err is None: + clone_err = exc + elif other_err is None: + other_err = exc + if clone_err is not None or other_err is not None: + for fut in futures: + fut.cancel() + if clone_err is not None: + raise clone_err + if other_err is not None: + raise other_err diff --git a/src/unstract/clone/phases/connector.py b/src/unstract/clone/phases/connector.py new file mode 100644 index 0000000..88215da --- /dev/null +++ b/src/unstract/clone/phases/connector.py @@ -0,0 +1,159 @@ +"""Migrate connectors from source org to target org. + +Same list -> per-id GET -> POST/adopt pattern as AdapterPhase. Two +connector-specific wrinkles: + +1. **Connectors with redacted metadata are skipped.** The backend + serializer strips ``connector_metadata`` for auto-provisioned rows + (e.g. Unstract Cloud Storage), so the SDK cannot reconstruct them + on the target. We detect this by inspecting the source GET response: + a falsy ``connector_metadata`` means the operator must rely on the + target's own provisioning (or re-create the row manually) — the + remap table records no entry for these. + +2. **OAuth ``connector_auth`` is stripped from responses.** Tokens are + stored in a sibling ``ConnectorAuth`` row that the public API never + exposes, so OAuth-backed connectors land on the target without + refresh tokens. Operator must re-authorise on target. +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any + +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.base import Phase, build_post_payload +from unstract.clone.report import CloneReport, PhaseResult + +logger = logging.getLogger(__name__) + +CONNECTOR_PATH = "connector/" + +# Backend POST serializer trips on these keys (connector_v2/serializers.py) +# by trying to refresh against the source user's social auth — guaranteed +# OAuthTimeOut on target. Detect here and skip ahead of POST. +_OAUTH_TOKEN_KEYS: frozenset[str] = frozenset({"access_token", "refresh_token"}) + + +def _has_oauth_tokens(metadata: dict[str, Any]) -> bool: + return any(metadata.get(k) for k in _OAUTH_TOKEN_KEYS) + + +class ConnectorPhase(Phase): + name = "connector" + + def run(self, report: CloneReport) -> PhaseResult: + result = report.get_phase(self.name) + try: + self._writable = self.ctx.target.get_post_schema(CONNECTOR_PATH) + except Exception as e: + logger.exception("Failed to fetch target POST schema for connector: %s", e) + result.failed += 1 + result.errors.append(f"OPTIONS connector: {e}") + return result + try: + src_summaries = self.ctx.source.list_connectors() + except Exception as e: + logger.exception("Failed to list source connectors: %s", e) + result.failed += 1 + result.errors.append(f"list source connectors: {e}") + return result + + logger.info("Found %d connector(s) in source org", len(src_summaries)) + self.parallel_map( + src_summaries, + lambda summary, lock: self._clone_one(summary, result, lock), + ) + return result + + def _clone_one( + self, summary: dict[str, Any], result: PhaseResult, lock: threading.Lock + ) -> None: + name = summary["connector_name"] + src_id = summary["id"] + + try: + src = self.ctx.source.get_connector(src_id) + except Exception as e: + logger.exception("Failed to GET source connector %s detail: %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"GET source detail {name}: {e}") + return + + metadata = src.get("connector_metadata") or {} + if not metadata: + logger.info( + "skipping connector '%s' (src=%s, catalog=%s) — source returned no metadata", + name, + src_id, + src.get("connector_id"), + ) + with lock: + result.skipped += 1 + return + + if _has_oauth_tokens(metadata): + logger.warning( + "skipping connector '%s' (src=%s, catalog=%s) — OAuth-backed; " + "re-authorise on target after the clone, then re-run to wire " + "dependent workflow endpoints.", + name, + src_id, + src.get("connector_id"), + ) + with lock: + result.skipped += 1 + return + + try: + existing = self.ctx.target.list_connectors(name=name) + except Exception as e: + logger.exception("Failed to GET connector %s on target: %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"GET {name}: {e}") + return + + if existing: + tgt = existing[0] + if self.ctx.options.on_name_conflict == "abort": + raise NameConflictError( + f"connector '{name}' already exists in target as {tgt['id']}" + ) + with lock: + result.adopted += 1 + logger.info( + "adopted connector '%s' src=%s -> tgt=%s", + name, + src_id, + tgt["id"], + ) + elif self.ctx.options.dry_run: + with lock: + result.skipped += 1 + logger.info("[dry-run] would create connector '%s' src=%s", name, src_id) + return + else: + payload = build_post_payload(src, self._writable) + try: + tgt = self.ctx.target.create_connector(payload) + except Exception as e: + logger.exception("Failed to create connector %s: %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"create {name}: {e}") + return + with lock: + result.created += 1 + logger.info( + "created connector '%s' src=%s -> tgt=%s", + name, + src_id, + tgt["id"], + ) + + with lock: + self.ctx.remap.record("connector", src_id, tgt["id"]) diff --git a/src/unstract/clone/phases/custom_tool.py b/src/unstract/clone/phases/custom_tool.py new file mode 100644 index 0000000..c43d64f --- /dev/null +++ b/src/unstract/clone/phases/custom_tool.py @@ -0,0 +1,409 @@ +"""Migrate prompt-studio projects via the project-transfer endpoints. + +For each source tool the phase: + +1. ``GET prompt-studio/project-transfer/{src_tool_id}`` — pulls a + portable JSON blob (tool_metadata, tool_settings, + default_profile_settings, prompts, export_metadata). +2. Decides fresh vs adopt by looking up the target tool by name. +3. **Fresh path**: reads source's default ProfileManager to learn the + adapter UUIDs the profile is bound to, remaps each via the running + ``adapter`` remap table, and POSTs the import as a multipart upload + with target-org adapter ids on the form. Backend creates the tool, + the default profile, and all prompts server-side in one call. +4. **Adopt path**: POSTs ``sync-prompts`` on the existing target tool. + Backend rip-and-replaces prompts + ``tool_settings`` and leaves the + target's locally-configured profiles + adapters untouched (which is + what the operator wants — they may have rewired adapters on target). +5. Republishes ``PromptStudioRegistry`` via the export action and + records the ``custom_tool`` + ``prompt_studio_registry`` remaps so + downstream ToolInstancePhase can rewrite ``ToolInstance.tool_id``. + +Adapter id discovery for the fresh path needs all four of LLM, +vector_db, embedding, x2text. If any source adapter can't be resolved +via the adapter remap, the tool is failed cleanly — we never want to +land a half-wired profile. +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any + +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.base import Phase +from unstract.clone.report import CloneReport, PhaseResult + +logger = logging.getLogger(__name__) + +_PROFILE_ADAPTER_FIELDS: tuple[tuple[str, str], ...] = ( + ("llm", "llm_adapter_id"), + ("vector_store", "vector_db_adapter_id"), + ("embedding_model", "embedding_adapter_id"), + ("x2text", "x2text_adapter_id"), +) + + +def _extract_adapter_name(value: Any) -> str | None: + """Adapter FKs serialise as the adapter NAME on the wire; tolerate a + nested-dict shape too. Never fall back to the UUID — list_adapters + matches by name and would silently miss. + """ + if isinstance(value, str): + return value or None + if isinstance(value, dict): + return value.get("adapter_name") or value.get("name") + return None + + +class CustomToolPhase(Phase): + name = "custom_tool" + + def run(self, report: CloneReport) -> PhaseResult: + result = report.get_phase(self.name) + try: + src_tools = self.ctx.source.list_custom_tools() + except Exception as e: + logger.exception("Failed to list source custom tools: %s", e) + result.failed += 1 + result.errors.append(f"list source custom tools: {e}") + return result + + logger.info("Found %d custom tool(s) in source org", len(src_tools)) + try: + target_tools = self.ctx.target.list_custom_tools() + except Exception as e: + logger.exception("Failed to list target tools: %s", e) + result.failed += 1 + result.errors.append(f"list target tools: {e}") + return result + + # Source's service-account view hides frictionless adapters; a + # profile-referenced name missing here flags a tool we can't migrate. + try: + self._src_adapter_names = { + a["adapter_name"] for a in self.ctx.source.list_adapters() + } + except Exception as e: + logger.exception("Failed to list source adapters: %s", e) + result.failed += 1 + result.errors.append(f"list source adapters for visibility check: {e}") + return result + + # Updated under lock when a fresh create lands so duplicate + # same-name source rows adopt instead of recreating. + target_by_name: dict[str, dict[str, Any]] = { + t["tool_name"]: t for t in target_tools + } + + self.parallel_map( + src_tools, + lambda summary, lock: self._clone_one( + summary, target_by_name, result, lock + ), + ) + return result + + def _clone_one( + self, + summary: dict[str, Any], + target_by_name: dict[str, dict[str, Any]], + result: PhaseResult, + lock: threading.Lock, + ) -> None: + tool_name = summary["tool_name"] + src_tool_id = summary["tool_id"] + + try: + export_data = self.ctx.source.export_project(src_tool_id) + except Exception as e: + logger.exception("Failed to export source tool '%s': %s", tool_name, e) + with lock: + result.failed += 1 + result.errors.append(f"export src tool {tool_name}: {e}") + return + + with lock: + match = target_by_name.get(tool_name) + + if match is not None: + tgt_tool_id = self._adopt( + match, export_data, result, tool_name, src_tool_id, lock + ) + else: + tgt_tool_id = self._create_fresh( + export_data, src_tool_id, tool_name, result, lock + ) + if tgt_tool_id is not None: + with lock: + target_by_name[tool_name] = { + "tool_id": tgt_tool_id, + "tool_name": tool_name, + } + + if tgt_tool_id is None: + return + + with lock: + self.ctx.remap.record("custom_tool", src_tool_id, tgt_tool_id) + + if self.ctx.options.dry_run: + return + + try: + self.ctx.target.export_custom_tool(tgt_tool_id) + logger.info( + "republished registry for tool '%s' tgt=%s", tool_name, tgt_tool_id + ) + except Exception as e: + logger.exception("Registry republish failed for tool %s: %s", tool_name, e) + with lock: + result.failed += 1 + result.errors.append(f"export {tool_name}: {e}") + return + + try: + src_regs = self.ctx.source.list_registries(custom_tool=src_tool_id) + tgt_regs = self.ctx.target.list_registries(custom_tool=tgt_tool_id) + except Exception as e: + logger.warning( + "registry remap lookup failed for tool '%s' " + "(downstream ToolInstance clone may skip): %s", + tool_name, + e, + ) + with lock: + result.failed += 1 + result.errors.append(f"registry remap lookup {tool_name}: {e}") + return + + if src_regs and tgt_regs: + with lock: + self.ctx.remap.record( + "prompt_studio_registry", + src_regs[0]["prompt_registry_id"], + tgt_regs[0]["prompt_registry_id"], + ) + + def _adopt( + self, + match: dict[str, Any], + export_data: dict[str, Any], + result: PhaseResult, + tool_name: str, + src_tool_id: str, + lock: threading.Lock, + ) -> str | None: + if self.ctx.options.on_name_conflict == "abort": + raise NameConflictError( + f"tool '{tool_name}' already exists in target as {match['tool_id']}" + ) + + tgt_tool_id = match["tool_id"] + if self.ctx.options.dry_run: + with lock: + result.skipped += 1 + logger.info( + "[dry-run] would sync prompts into adopted tool '%s' src=%s -> tgt=%s", + tool_name, + src_tool_id, + tgt_tool_id, + ) + return tgt_tool_id + + try: + self.ctx.target.sync_prompts(tgt_tool_id, export_data) + except Exception as e: + logger.exception("sync_prompts failed for tool %s: %s", tool_name, e) + with lock: + result.failed += 1 + result.errors.append(f"sync {tool_name}: {e}") + return None + + with lock: + result.adopted += 1 + logger.info( + "adopted tool '%s' src=%s -> tgt=%s (prompts re-synced)", + tool_name, + src_tool_id, + tgt_tool_id, + ) + return tgt_tool_id + + def _create_fresh( + self, + export_data: dict[str, Any], + src_tool_id: str, + tool_name: str, + result: PhaseResult, + lock: threading.Lock, + ) -> str | None: + if self.ctx.options.dry_run: + with lock: + result.skipped += 1 + logger.info( + "[dry-run] would import tool '%s' src=%s", tool_name, src_tool_id + ) + return None + + default_profile = self._source_default_profile(src_tool_id, tool_name) + if default_profile is None: + with lock: + result.failed += 1 + result.errors.append( + f"import {tool_name}: no default profile on source" + ) + return None + + invisible = self._invisible_source_adapter_names(default_profile) + if invisible: + self._register_frictionless_skip( + src_tool_id, tool_name, invisible, result, lock + ) + return None + + adapter_ids = self._resolve_target_adapter_ids( + default_profile, tool_name + ) + if adapter_ids is None: + with lock: + result.failed += 1 + result.errors.append( + f"import {tool_name}: missing target adapter remap for default profile" + ) + return None + + try: + tgt = self.ctx.target.import_project(export_data, adapter_ids=adapter_ids) + except Exception as e: + logger.exception("import_project failed for tool %s: %s", tool_name, e) + with lock: + result.failed += 1 + result.errors.append(f"import {tool_name}: {e}") + return None + + tgt_tool_id = tgt["tool_id"] + with lock: + result.created += 1 + logger.info( + "created tool '%s' src=%s -> tgt=%s (needs_adapter_config=%s)", + tool_name, + src_tool_id, + tgt_tool_id, + tgt.get("needs_adapter_config"), + ) + return tgt_tool_id + + def _source_default_profile( + self, src_tool_id: str, tool_name: str + ) -> dict[str, Any] | None: + try: + src_profiles = self.ctx.source.list_profiles(src_tool_id) + except Exception as e: + logger.exception( + "Failed to list source profiles for tool %s: %s", tool_name, e + ) + return None + + default = next( + (p for p in src_profiles if p.get("is_default")), + src_profiles[0] if src_profiles else None, + ) + if default is None: + logger.warning( + "source tool '%s' has no profiles to derive adapter ids from", + tool_name, + ) + return default + + def _invisible_source_adapter_names( + self, default_profile: dict[str, Any] + ) -> list[str]: + """Profile adapter names not in the source's visible adapter set + (typically frictionless) — these can't be migrated. + """ + missing: list[str] = [] + for src_field, _ in _PROFILE_ADAPTER_FIELDS: + adapter_name = _extract_adapter_name(default_profile.get(src_field)) + if adapter_name and adapter_name not in self._src_adapter_names: + missing.append(adapter_name) + return missing + + def _register_frictionless_skip( + self, + src_tool_id: str, + tool_name: str, + missing_adapters: list[str], + result: PhaseResult, + lock: threading.Lock, + ) -> None: + """Record the skip + source registry id so dependent workflows + cascade-skip downstream. + """ + logger.warning( + "skipping tool '%s' src=%s — default profile references adapters " + "not visible to the source service account (frictionless?): %s. " + "Wire equivalents on target and re-run.", + tool_name, + src_tool_id, + missing_adapters, + ) + try: + src_regs = self.ctx.source.list_registries(custom_tool=src_tool_id) + except Exception as e: + logger.warning( + "registry lookup failed for skipped tool '%s' — " + "downstream cascade-skip may not fire: %s", + tool_name, + e, + ) + src_regs = [] + with lock: + result.skipped += 1 + for reg in src_regs: + reg_id = reg.get("prompt_registry_id") + if reg_id: + self.ctx.skipped_custom_tool_registry_ids.add(reg_id) + + def _resolve_target_adapter_ids( + self, default_profile: dict[str, Any], tool_name: str + ) -> dict[str, str] | None: + """Source profile carries adapter NAMES (per serializer); resolve + each name to a target adapter UUID via ``list_adapters(name=...)``. + + Returns ``None`` if any of the four required adapters can't be + found on target — caller fails the tool. AdapterPhase preserves + names across orgs so this lookup should always hit when the + adapter clone ran cleanly. + """ + resolved: dict[str, str] = {} + for src_field, form_field in _PROFILE_ADAPTER_FIELDS: + adapter_name = _extract_adapter_name(default_profile.get(src_field)) + if not adapter_name: + logger.warning( + "source default profile for tool '%s' missing adapter '%s'", + tool_name, + src_field, + ) + return None + try: + matches = self.ctx.target.list_adapters(name=adapter_name) + except Exception as e: + logger.exception( + "list_adapters lookup failed for %s on tool '%s': %s", + adapter_name, + tool_name, + e, + ) + return None + if not matches: + logger.warning( + "no target adapter named '%s' for field %s on tool '%s'", + adapter_name, + src_field, + tool_name, + ) + return None + resolved[form_field] = matches[0]["id"] + return resolved diff --git a/src/unstract/clone/phases/files.py b/src/unstract/clone/phases/files.py new file mode 100644 index 0000000..22de854 --- /dev/null +++ b/src/unstract/clone/phases/files.py @@ -0,0 +1,507 @@ +"""Migrate Prompt Studio document files (the user-uploaded test corpus). + +Runs after ``CustomToolPhase`` — consumes the ``custom_tool`` remap to +know which source-tool to target-tool mapping to iterate. + +Default mode (``file_strategy='platform_api'``): + +1. For each ``(src_tool_id, tgt_tool_id)``, list source DM rows + target + DM rows once each. +2. For each source filename missing on target: download from source, decode + per mime, enforce the size cap, upload as multipart to target. +3. Oversize files → ``CloneReport.oversize_files``; mime types the + backend can't round-trip (Excel placeholder, etc) → + ``unsupported_files``; transport errors → ``failed_files``. + +Skip mode (``file_strategy='skip'``): + +- No download/upload. Source DM list is emitted into ``skipped_files`` so + the operator knows what to re-upload manually via UI. + +Per-file work fans out across ``ctx.options.concurrency`` workers. +""" + +from __future__ import annotations + +import base64 +import logging +import threading +import time +from dataclasses import dataclass +from typing import Any + +import requests + +from unstract.clone.exceptions import PlatformAPIError +from unstract.clone.phases.base import Phase +from unstract.clone.report import CloneReport, PhaseResult + +logger = logging.getLogger(__name__) + +_BASE64_MIMES: frozenset[str] = frozenset({"application/pdf"}) +_TEXT_MIMES: frozenset[str] = frozenset({"text/plain", "text/csv"}) + +_RETRYABLE_STATUS: frozenset[int] = frozenset({502, 503, 504}) +_MAX_RETRIES = 3 +_RETRY_BACKOFF_BASE_SECONDS = 1.0 + + +@dataclass +class _FileTask: + src_tool_id: str + tgt_tool_id: str + tool_name: str + file_name: str + src_document_id: str + + +class FilesPhase(Phase): + name = "files" + + def run(self, report: CloneReport) -> PhaseResult: + result = report.get_phase(self.name) + tool_remap = self.ctx.remap.snapshot().get("custom_tool", {}) + if not tool_remap: + logger.info("files phase: no custom_tool remap entries; nothing to do") + return result + + strategy = self.ctx.options.file_strategy + logger.info( + "files phase: strategy=%s tools=%d cap=%d bytes concurrency=%d", + strategy, + len(tool_remap), + self.ctx.options.max_file_size, + self.ctx.options.concurrency, + ) + + # Pass 1: build per-file task list sequentially (cheap). + file_tasks: list[_FileTask] = [] + cloned_tools: list[tuple[str, str, str, list[dict[str, Any]]]] = [] + for src_tool_id, tgt_tool_id in tool_remap.items(): + tool_name = self._lookup_tool_name(tgt_tool_id) or src_tool_id + try: + src_docs = self.ctx.source.list_prompt_documents(src_tool_id) + except Exception as e: + logger.exception( + "files: failed to list source DM rows for tool %s: %s", + tool_name, + e, + ) + result.failed += 1 + result.errors.append(f"list source docs {tool_name}: {e}") + continue + + if strategy == "skip": + self._emit_skip( + src_docs, src_tool_id, tgt_tool_id, tool_name, report, result + ) + continue + + tasks = self._build_tool_tasks( + src_tool_id, tgt_tool_id, tool_name, src_docs, report, result + ) + file_tasks.extend(tasks) + cloned_tools.append((src_tool_id, tgt_tool_id, tool_name, src_docs)) + + # Pass 2: download + upload each file in parallel. + if file_tasks: + self.parallel_map( + file_tasks, + lambda task, lock: self._clone_one_file(task, report, result, lock), + ) + + # Pass 3: set default doc per tool after all uploads land. + if not self.ctx.options.dry_run and strategy != "skip": + for src_tool_id, tgt_tool_id, tool_name, src_docs in cloned_tools: + self._ensure_default_doc(src_tool_id, tgt_tool_id, tool_name, src_docs) + + return result + + def _build_tool_tasks( + self, + src_tool_id: str, + tgt_tool_id: str, + tool_name: str, + src_docs: list[dict[str, Any]], + report: CloneReport, + result: PhaseResult, + ) -> list[_FileTask]: + try: + tgt_docs = self.ctx.target.list_prompt_documents(tgt_tool_id) + except Exception as e: + logger.exception( + "files: failed to list target DM rows for tool %s: %s", + tool_name, + e, + ) + result.failed += 1 + result.errors.append(f"list target docs {tool_name}: {e}") + return [] + target_names = {d["document_name"] for d in tgt_docs} + + tasks: list[_FileTask] = [] + for doc in src_docs: + file_name = doc.get("document_name") + src_document_id = doc.get("document_id") + if not file_name or not src_document_id: + result.skipped += 1 + result.errors.append( + f"malformed source DM row on tool={tool_name}: {doc!r}" + ) + logger.warning( + "files: skipping malformed source DM row on tool=%s: %r", + tool_name, + doc, + ) + continue + if file_name in target_names: + result.skipped += 1 + logger.info( + "files: skipping tool=%s file=%s — already present on target", + tool_name, + file_name, + ) + continue + if self.ctx.options.dry_run: + result.skipped += 1 + logger.info( + "[dry-run] files: would clone tool=%s file=%s", + tool_name, + file_name, + ) + continue + tasks.append( + _FileTask( + src_tool_id=src_tool_id, + tgt_tool_id=tgt_tool_id, + tool_name=tool_name, + file_name=file_name, + src_document_id=src_document_id, + ) + ) + return tasks + + def _clone_one_file( + self, + task: _FileTask, + report: CloneReport, + result: PhaseResult, + lock: threading.Lock, + ) -> None: + try: + payload = self._with_retry( + lambda: self.ctx.source.download_prompt_file( + task.src_tool_id, task.src_document_id + ), + op=f"download {task.tool_name}/{task.file_name}", + ) + except Exception as e: + logger.exception( + "files: download failed tool=%s file=%s: %s", + task.tool_name, + task.file_name, + e, + ) + with lock: + result.failed += 1 + report.failed_files.append( + { + "tool_id": task.tgt_tool_id, + "tool_name": task.tool_name, + "file_name": task.file_name, + "error": f"download: {e}", + } + ) + return + + mime = (payload or {}).get("mime_type") or "" + raw = self._decode_payload(payload, mime) + if raw is None: + logger.warning( + "files: unsupported mime tool=%s file=%s mime=%s", + task.tool_name, + task.file_name, + mime, + ) + with lock: + result.skipped += 1 + report.unsupported_files.append( + { + "tool_id": task.tgt_tool_id, + "tool_name": task.tool_name, + "file_name": task.file_name, + "mime_type": mime, + } + ) + return + + if len(raw) > self.ctx.options.max_file_size: + with lock: + result.skipped += 1 + report.oversize_files.append( + { + "tool_id": task.tgt_tool_id, + "tool_name": task.tool_name, + "file_name": task.file_name, + "size_bytes": len(raw), + "cap_bytes": self.ctx.options.max_file_size, + } + ) + logger.info( + "files: oversize tool=%s file=%s size=%d cap=%d", + task.tool_name, + task.file_name, + len(raw), + self.ctx.options.max_file_size, + ) + return + + try: + self._with_retry( + lambda: self.ctx.target.upload_prompt_file( + task.tgt_tool_id, task.file_name, raw, mime + ), + op=f"upload {task.tool_name}/{task.file_name}", + ) + except Exception as e: + logger.exception( + "files: upload failed tool=%s file=%s: %s", + task.tool_name, + task.file_name, + e, + ) + with lock: + result.failed += 1 + report.failed_files.append( + { + "tool_id": task.tgt_tool_id, + "tool_name": task.tool_name, + "file_name": task.file_name, + "error": f"upload: {e}", + } + ) + return + + with lock: + result.created += 1 + report.uploaded_files.append( + { + "tool_id": task.tgt_tool_id, + "tool_name": task.tool_name, + "file_name": task.file_name, + "size_bytes": len(raw), + "mime_type": mime, + } + ) + logger.info( + "files: uploaded tool=%s file=%s size=%d", + task.tool_name, + task.file_name, + len(raw), + ) + + def _emit_skip( + self, + src_docs: list[dict[str, Any]], + src_tool_id: str, + tgt_tool_id: str, + tool_name: str, + report: CloneReport, + result: PhaseResult, + ) -> None: + for doc in src_docs: + file_name = doc.get("document_name") + if not file_name: + continue + report.skipped_files.append( + { + "tool_id": tgt_tool_id, + "tool_name": tool_name, + "file_name": file_name, + "source_org_slug": self.ctx.source.endpoint.organization_id, + "source_tool_id": src_tool_id, + } + ) + result.skipped += 1 + logger.info( + "files: skip mode emitted %d filenames for tool=%s", + len(src_docs), + tool_name, + ) + + def _decode_payload( + self, payload: dict[str, Any] | None, mime: str + ) -> bytes | None: + if not payload: + return None + data_field = payload.get("data") + if data_field is None: + return None + if mime in _BASE64_MIMES: + if isinstance(data_field, bytes): + return base64.b64decode(data_field) + return base64.b64decode(data_field.encode()) + if mime in _TEXT_MIMES: + if isinstance(data_field, bytes): + return data_field + return data_field.encode("utf-8") + # Excel + unhandled types: BE returned a placeholder string, + # not real bytes. Round-trip would corrupt the file. + return None + + def _ensure_default_doc( + self, + src_tool_id: str, + tgt_tool_id: str, + tool_name: str, + src_docs: list[dict[str, Any]], + ) -> None: + """Set target ``CustomTool.output`` so the FE auto-selects a doc. + + Mirror source's chosen doc by filename when possible; fall back + to the first available target doc. Skip if target already has + ``output`` set — never override an operator's later choice on + re-runs. + """ + try: + tgt_tool = self.ctx.target.get_custom_tool(tgt_tool_id) + except Exception as e: + logger.warning( + "files: skipping default-doc set for tool=%s — fetch tgt failed: %s", + tool_name, + e, + ) + return + + if tgt_tool.get("output"): + logger.debug( + "files: target tool=%s already has default doc; leaving as-is", + tool_name, + ) + return + + try: + tgt_docs = self.ctx.target.list_prompt_documents(tgt_tool_id) + except Exception as e: + logger.warning( + "files: skipping default-doc set for tool=%s — list tgt docs failed: %s", + tool_name, + e, + ) + return + if not tgt_docs: + return + + chosen_id = self._pick_default_doc_id( + src_tool_id, src_docs, tgt_docs, tool_name + ) + if not chosen_id: + return + + try: + self.ctx.target.update_custom_tool(tgt_tool_id, {"output": chosen_id}) + logger.info( + "files: set default doc tool=%s doc_id=%s", tool_name, chosen_id + ) + except Exception as e: + logger.warning("files: PATCH default doc failed tool=%s: %s", tool_name, e) + + def _pick_default_doc_id( + self, + src_tool_id: str, + src_docs: list[dict[str, Any]], + tgt_docs: list[dict[str, Any]], + tool_name: str, + ) -> str | None: + try: + src_tool = self.ctx.source.get_custom_tool(src_tool_id) + src_output = src_tool.get("output") + except Exception as e: + logger.debug( + "files: source CustomTool fetch failed for tool=%s (%s); " + "falling back to first target doc", + tool_name, + e, + ) + src_output = None + + if src_output: + src_name = next( + ( + d.get("document_name") + for d in src_docs + if d.get("document_id") == src_output + ), + None, + ) + if src_name: + matched = next( + ( + d.get("document_id") + for d in tgt_docs + if d.get("document_name") == src_name + ), + None, + ) + if matched: + return matched + + return tgt_docs[0].get("document_id") + + def _lookup_tool_name(self, tgt_tool_id: str) -> str | None: + try: + tools = self.ctx.target.list_custom_tools() + except PlatformAPIError as e: + logger.warning( + "files: list_custom_tools failed during name lookup (%s); " + "log lines will fall back to tool ids", + e, + ) + return None + except (requests.ConnectionError, requests.Timeout) as e: + logger.warning( + "files: transport error during tool-name lookup (%s); " + "log lines will fall back to tool ids", + e, + ) + return None + for t in tools: + if t.get("tool_id") == tgt_tool_id: + return t.get("tool_name") + return None + + def _with_retry(self, fn: Any, *, op: str) -> Any: + last_exc: Exception | None = None + for attempt in range(1, _MAX_RETRIES + 1): + try: + return fn() + except PlatformAPIError as e: + last_exc = e + if e.status_code not in _RETRYABLE_STATUS or attempt == _MAX_RETRIES: + raise + sleep = _RETRY_BACKOFF_BASE_SECONDS * (2 ** (attempt - 1)) + logger.warning( + "files: retry %d/%d for %s after %d: sleeping %.1fs", + attempt, + _MAX_RETRIES, + op, + e.status_code, + sleep, + ) + time.sleep(sleep) + except (requests.ConnectionError, requests.Timeout) as e: + last_exc = e + if attempt == _MAX_RETRIES: + raise + sleep = _RETRY_BACKOFF_BASE_SECONDS * (2 ** (attempt - 1)) + logger.warning( + "files: retry %d/%d for %s after %s: sleeping %.1fs", + attempt, + _MAX_RETRIES, + op, + type(e).__name__, + sleep, + ) + time.sleep(sleep) + assert last_exc is not None + raise last_exc diff --git a/src/unstract/clone/phases/pipeline.py b/src/unstract/clone/phases/pipeline.py new file mode 100644 index 0000000..9892b1c --- /dev/null +++ b/src/unstract/clone/phases/pipeline.py @@ -0,0 +1,172 @@ +"""Migrate ETL/TASK pipelines from source org to target org. + +Pipelines FK ``workflow`` — the only entity remap needed. On create the +backend force-sets ``active=True`` and auto-provisions one active API +key per pipeline; if the source had additional rotated keys, those are +NOT mirrored (their UUIDs are server-generated and can't be preserved, +and operators rotate post-clone anyway). + +``DEFAULT`` (legacy) and ``APP`` pipeline types are skipped — DEFAULT is +dead code from the v1 era; APP is a Streamlit-style deployment whose +lifecycle isn't shaped like an ETL/TASK pipeline. +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any + +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.base import Phase, build_post_payload +from unstract.clone.report import CloneReport, PhaseResult +from unstract.clone.walker import remap_uuids + +logger = logging.getLogger(__name__) + +PIPELINE_PATH = "pipeline/" +_MIGRATABLE_TYPES: frozenset[str] = frozenset({"ETL", "TASK"}) + + +class PipelinePhase(Phase): + name = "pipeline" + + def run(self, report: CloneReport) -> PhaseResult: + result = report.get_phase(self.name) + try: + self._writable = self.ctx.target.get_post_schema(PIPELINE_PATH) + except Exception as e: + logger.exception("Failed to fetch target POST schema for pipeline: %s", e) + result.failed += 1 + result.errors.append(f"OPTIONS pipeline: {e}") + return result + + try: + src_pipelines = self.ctx.source.list_pipelines() + except Exception as e: + logger.exception("Failed to list source pipelines: %s", e) + result.failed += 1 + result.errors.append(f"list source pipelines: {e}") + return result + + migratable = [ + p for p in src_pipelines if p.get("pipeline_type") in _MIGRATABLE_TYPES + ] + skipped_types = len(src_pipelines) - len(migratable) + if skipped_types: + logger.info( + "Found %d source pipeline(s); skipping %d of unsupported type (DEFAULT/APP)", + len(src_pipelines), + skipped_types, + ) + else: + logger.info("Found %d source pipeline(s)", len(src_pipelines)) + + self.parallel_map( + migratable, + lambda src, lock: self._clone_one(src, result, lock), + ) + return result + + def _clone_one( + self, src: dict[str, Any], result: PhaseResult, lock: threading.Lock + ) -> None: + name = src["pipeline_name"] + src_id = src["id"] + src_wf_id = src.get("workflow") or src.get("workflow_id") + + if not src_wf_id: + logger.warning("source pipeline '%s' has no workflow FK — skipping", name) + with lock: + result.skipped += 1 + return + + with lock: + tgt_wf_id = self.ctx.remap.resolve("workflow", src_wf_id) + if not tgt_wf_id: + logger.warning( + "no workflow remap for pipeline '%s' (src workflow %s) — skipping", + name, + src_wf_id, + ) + with lock: + result.skipped += 1 + return + + try: + existing = self.ctx.target.list_pipelines(name=name) + except Exception as e: + logger.exception("Failed to GET pipeline %s on target: %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"GET {name}: {e}") + return + + if existing: + tgt = existing[0] + if self.ctx.options.on_name_conflict == "abort": + raise NameConflictError( + f"pipeline '{name}' already exists in target as {tgt['id']}" + ) + with lock: + result.adopted += 1 + logger.info( + "adopted pipeline '%s' src=%s -> tgt=%s", name, src_id, tgt["id"] + ) + elif self.ctx.options.dry_run: + with lock: + result.skipped += 1 + logger.info("[dry-run] would create pipeline '%s' src=%s", name, src_id) + return + else: + try: + full_src = self.ctx.source.get_pipeline(src_id) + except Exception as e: + logger.exception("Failed to GET source pipeline %s: %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"GET src pipeline {name}: {e}") + return + remapped = remap_uuids(full_src, self.ctx.remap) + payload = build_post_payload(remapped, self._writable) + payload["workflow"] = tgt_wf_id + try: + tgt = self.ctx.target.create_pipeline(payload) + except Exception as e: + logger.exception("Failed to create pipeline %s: %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"create {name}: {e}") + return + with lock: + result.created += 1 + logger.info( + "created pipeline '%s' src=%s -> tgt=%s", name, src_id, tgt["id"] + ) + self._warn_if_extra_source_keys(src_id, name) + + with lock: + self.ctx.remap.record("pipeline", src_id, tgt["id"]) + + def _warn_if_extra_source_keys(self, src_pipeline_id: str, name: str) -> None: + try: + keys = self.ctx.source.list_pipeline_keys(src_pipeline_id) + except Exception as e: + # WARNING (not DEBUG) — the operator needs to know we couldn't + # check whether they have additional keys to recreate manually. + logger.warning( + "Could not list source keys for pipeline %s " + "(extra-key check skipped; re-verify in source UI): %s", + name, + e, + ) + return + active = [k for k in keys if k.get("is_active")] + if len(active) > 1: + logger.warning( + "source pipeline '%s' had %d active API keys; " + "target has only the auto-provisioned default — " + "re-create the rest manually if your clients depend on them", + name, + len(active), + ) diff --git a/src/unstract/clone/phases/tag.py b/src/unstract/clone/phases/tag.py new file mode 100644 index 0000000..9cbca05 --- /dev/null +++ b/src/unstract/clone/phases/tag.py @@ -0,0 +1,97 @@ +"""Migrate tags from source org to target org. + +Tags are flat (``name`` + ``description``) with a per-org uniqueness +constraint on ``name``. No metadata, no encryption, no list-vs-detail +divergence — the simplest entity in the clone set. + +List endpoint paginates; ``PlatformClient.list_tags`` already unwraps +the envelope. +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any + +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.base import Phase, build_post_payload +from unstract.clone.report import CloneReport, PhaseResult + +logger = logging.getLogger(__name__) + +TAG_PATH = "tags/" + + +class TagPhase(Phase): + name = "tag" + + def run(self, report: CloneReport) -> PhaseResult: + result = report.get_phase(self.name) + try: + self._writable = self.ctx.target.get_post_schema(TAG_PATH) + except Exception as e: + logger.exception("Failed to fetch target POST schema for tag: %s", e) + result.failed += 1 + result.errors.append(f"OPTIONS tag: {e}") + return result + try: + src_tags = self.ctx.source.list_tags() + except Exception as e: + logger.exception("Failed to list source tags: %s", e) + result.failed += 1 + result.errors.append(f"list source tags: {e}") + return result + + logger.info("Found %d tag(s) in source org", len(src_tags)) + self.parallel_map( + src_tags, + lambda src, lock: self._clone_one(src, result, lock), + ) + return result + + def _clone_one( + self, src: dict[str, Any], result: PhaseResult, lock: threading.Lock + ) -> None: + name = src["name"] + src_id = src["id"] + + try: + existing = self.ctx.target.list_tags(name=name) + except Exception as e: + logger.exception("Failed to GET tag %s on target: %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"GET {name}: {e}") + return + + if existing: + tgt = existing[0] + if self.ctx.options.on_name_conflict == "abort": + raise NameConflictError( + f"tag '{name}' already exists in target as {tgt['id']}" + ) + with lock: + result.adopted += 1 + logger.info("adopted tag '%s' src=%s -> tgt=%s", name, src_id, tgt["id"]) + elif self.ctx.options.dry_run: + with lock: + result.skipped += 1 + logger.info("[dry-run] would create tag '%s' src=%s", name, src_id) + return + else: + payload = build_post_payload(src, self._writable) + try: + tgt = self.ctx.target.create_tag(payload) + except Exception as e: + logger.exception("Failed to create tag %s: %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"create {name}: {e}") + return + with lock: + result.created += 1 + logger.info("created tag '%s' src=%s -> tgt=%s", name, src_id, tgt["id"]) + + with lock: + self.ctx.remap.record("tag", src_id, tgt["id"]) diff --git a/src/unstract/clone/phases/tool_instance.py b/src/unstract/clone/phases/tool_instance.py new file mode 100644 index 0000000..293d206 --- /dev/null +++ b/src/unstract/clone/phases/tool_instance.py @@ -0,0 +1,231 @@ +"""Migrate ToolInstance rows from source org to target org. + +Each workflow holds at most one ToolInstance, enforced server-side +(``tool_instance_v2/serializers.py`` raises if a workflow already has one). +The row carries: + +- ``workflow`` FK — remapped from the WorkflowPhase remap table. +- ``tool_id`` (CharField, not FK) — a ``prompt_registry_id`` UUID. The + target's registry was rebuilt in CustomToolPhase, so we remap via the + ``prompt_studio_registry`` table populated there. +- ``metadata`` JSON — backend's ``create()`` discards the POST metadata + and rebuilds it from tool defaults. So we POST a bare instance, then + PATCH the metadata afterwards. Source metadata stores adapter values + as NAMES (via to_representation in source GET); on PATCH the backend's + ``update_metadata_with_adapter_instances`` resolves those names to + the target's adapter UUIDs. Names match across orgs because + AdapterPhase preserved them. +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any + +from unstract.clone.phases.base import Phase +from unstract.clone.report import CloneReport, PhaseResult + +logger = logging.getLogger(__name__) + +# Source backend's ToolInstanceSerializer.to_representation emits these +# sentinel strings when an adapter UUID/name in the stored metadata can +# no longer be resolved (deleted or renamed on source). Round-tripping +# them to target produces an AdapterNotFound on PATCH, so we detect and +# skip the metadata PATCH instead — the ToolInstance row exists with the +# backend's safe defaults and the operator can re-bind in the UI. +_BROKEN_ADAPTER_SENTINELS: tuple[str, ...] = ( + "NOT FOUND", + "[DELETED ADAPTER", + "[NEEDS UPDATE]", +) + +# Fields tied to the source row's own ids — never valid on the target. +# Always rewrite these with target values before PATCHing. +_SOURCE_IDENTITY_FIELDS: tuple[str, ...] = ( + "prompt_registry_id", + "tool_instance_id", + "tenant_id", +) + + +def _broken_adapter_keys(metadata: dict[str, Any]) -> list[str]: + broken: list[str] = [] + for key, value in metadata.items(): + if isinstance(value, str) and any( + s in value for s in _BROKEN_ADAPTER_SENTINELS + ): + broken.append(f"{key}={value!r}") + return broken + + +def _strip_source_identity(metadata: dict[str, Any]) -> dict[str, Any]: + return {k: v for k, v in metadata.items() if k not in _SOURCE_IDENTITY_FIELDS} + + +class ToolInstancePhase(Phase): + name = "tool_instance" + + def run(self, report: CloneReport) -> PhaseResult: + result = report.get_phase(self.name) + workflow_remap = self.ctx.remap.snapshot().get("workflow", {}) + if not workflow_remap: + logger.info("No workflows in remap; nothing to do for tool_instance phase") + return result + + self.parallel_map( + list(workflow_remap.items()), + lambda pair, lock: self._clone_workflow_tools( + pair[0], pair[1], result, lock + ), + ) + return result + + def _clone_workflow_tools( + self, + src_wf_id: str, + tgt_wf_id: str, + result: PhaseResult, + lock: threading.Lock, + ) -> None: + try: + src_instances = self.ctx.source.list_tool_instances(workflow_id=src_wf_id) + except Exception as e: + logger.exception( + "Failed to list source tool_instances for wf %s: %s", src_wf_id, e + ) + with lock: + result.failed += 1 + result.errors.append(f"list src tool_instances {src_wf_id}: {e}") + return + + if not src_instances: + return + if len(src_instances) > 1: + logger.warning( + "source workflow %s has %d tool_instances (expected ≤1) — migrating first only", + src_wf_id, + len(src_instances), + ) + + src_ti = src_instances[0] + src_ti_id = src_ti["id"] + src_tool_id = src_ti["tool_id"] + + with lock: + tgt_tool_id = self.ctx.remap.resolve("prompt_studio_registry", src_tool_id) + if not tgt_tool_id: + logger.warning( + "skipping tool_instance %s — no registry remap for tool_id %s " + "(custom tool likely unpublished on source)", + src_ti_id, + src_tool_id, + ) + with lock: + result.skipped += 1 + return + + try: + existing = self.ctx.target.list_tool_instances(workflow_id=tgt_wf_id) + except Exception as e: + logger.exception( + "Failed to list target tool_instances for wf %s: %s", tgt_wf_id, e + ) + with lock: + result.failed += 1 + result.errors.append(f"list tgt tool_instances {tgt_wf_id}: {e}") + return + + if existing: + tgt_ti = existing[0] + if self.ctx.options.dry_run: + with lock: + result.skipped += 1 + self.ctx.remap.record("tool_instance", src_ti_id, tgt_ti["id"]) + logger.info( + "[dry-run] would re-PATCH metadata on adopted tool_instance " + "src=%s -> tgt=%s (workflow %s)", + src_ti_id, + tgt_ti["id"], + tgt_wf_id, + ) + return + with lock: + result.adopted += 1 + logger.info( + "adopted tool_instance src=%s -> tgt=%s (workflow %s)", + src_ti_id, + tgt_ti["id"], + tgt_wf_id, + ) + elif self.ctx.options.dry_run: + with lock: + result.skipped += 1 + logger.info( + "[dry-run] would create tool_instance for tgt workflow %s " + "(src tool_instance %s)", + tgt_wf_id, + src_ti_id, + ) + return + else: + try: + tgt_ti = self.ctx.target.create_tool_instance( + {"workflow_id": tgt_wf_id, "tool_id": tgt_tool_id} + ) + except Exception as e: + logger.exception( + "Failed to create tool_instance for wf %s: %s", tgt_wf_id, e + ) + with lock: + result.failed += 1 + result.errors.append(f"create tool_instance {tgt_wf_id}: {e}") + return + with lock: + result.created += 1 + logger.info( + "created tool_instance src=%s -> tgt=%s (workflow %s)", + src_ti_id, + tgt_ti["id"], + tgt_wf_id, + ) + + src_metadata = src_ti.get("metadata") or {} + broken = _broken_adapter_keys(src_metadata) + if broken: + logger.warning( + "skipping metadata PATCH for tool_instance src=%s tgt=%s — " + "source metadata carries broken adapter refs %s; " + "row exists with backend defaults, re-bind in UI", + src_ti_id, + tgt_ti["id"], + broken, + ) + with lock: + result.skipped += 1 + result.errors.append( + f"stale adapter refs on src tool_instance {src_ti_id}: {broken}" + ) + else: + # PATCH overwrites the whole metadata dict — re-stamp target + # identity fields or the runtime sees them as empty. + patch_metadata = { + **_strip_source_identity(src_metadata), + "prompt_registry_id": tgt_tool_id, + "tool_instance_id": tgt_ti["id"], + } + try: + self.ctx.target.update_tool_instance_metadata( + tgt_ti["id"], patch_metadata + ) + except Exception as e: + logger.exception( + "Failed to PATCH tool_instance %s metadata: %s", tgt_ti["id"], e + ) + with lock: + result.failed += 1 + result.errors.append(f"patch metadata {tgt_ti['id']}: {e}") + return + + with lock: + self.ctx.remap.record("tool_instance", src_ti_id, tgt_ti["id"]) diff --git a/src/unstract/clone/phases/workflow.py b/src/unstract/clone/phases/workflow.py new file mode 100644 index 0000000..2612eaf --- /dev/null +++ b/src/unstract/clone/phases/workflow.py @@ -0,0 +1,160 @@ +"""Migrate workflows from source org to target org. + +Workflow rows themselves are simple — no required FKs to clone +entities, unique per ``(workflow_name, organization)``. The two +non-trivial bits: + +1. ``source_settings`` and ``destination_settings`` are JSON blobs that + embed connector UUIDs. The walker remaps them using the running + ``RemapTable`` (connectors already landed in the previous phase). + +2. Creating a workflow auto-creates empty ``WorkflowEndpoint`` rows + server-side. We don't touch those here — the dedicated + WorkflowEndpoint phase reconciles them after ToolInstance lands. +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any + +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.base import Phase, build_post_payload +from unstract.clone.report import CloneReport, PhaseResult +from unstract.clone.walker import remap_uuids + +logger = logging.getLogger(__name__) + +WORKFLOW_PATH = "workflow/" + + +class WorkflowPhase(Phase): + name = "workflow" + + def run(self, report: CloneReport) -> PhaseResult: + result = report.get_phase(self.name) + try: + self._writable = self.ctx.target.get_post_schema(WORKFLOW_PATH) + except Exception as e: + logger.exception("Failed to fetch target POST schema for workflow: %s", e) + result.failed += 1 + result.errors.append(f"OPTIONS workflow: {e}") + return result + + try: + src_workflows = self.ctx.source.list_workflows() + except Exception as e: + logger.exception("Failed to list source workflows: %s", e) + result.failed += 1 + result.errors.append(f"list source workflows: {e}") + return result + + # Built once so per-workflow cascade-skip checks stay O(1). + self._wf_to_src_tool_id = self._collect_wf_tool_map(result) + + logger.info("Found %d workflow(s) in source org", len(src_workflows)) + self.parallel_map( + src_workflows, + lambda src, lock: self._clone_one(src, result, lock), + ) + return result + + def _collect_wf_tool_map(self, result: PhaseResult) -> dict[str, str]: + """Map source workflow_id to its ToolInstance.tool_id; listed once + to avoid N+1 fetches. + """ + if not self.ctx.skipped_custom_tool_registry_ids: + return {} + try: + tis = self.ctx.source.list_tool_instances() + except Exception as e: + logger.warning( + "workflow phase: failed to list source tool_instances for " + "cascade-skip lookup (%s); proceeding without cascade", + e, + ) + return {} + mapping: dict[str, str] = {} + for ti in tis: + wf_id = ti.get("workflow") + tool_id = ti.get("tool_id") + if wf_id and tool_id: + mapping[wf_id] = tool_id + return mapping + + def _clone_one( + self, src: dict[str, Any], result: PhaseResult, lock: threading.Lock + ) -> None: + name = src["workflow_name"] + src_id = src["id"] + + src_tool_id = self._wf_to_src_tool_id.get(src_id) + if src_tool_id and src_tool_id in self.ctx.skipped_custom_tool_registry_ids: + logger.warning( + "skipping workflow '%s' src=%s — its tool was skipped in " + "custom_tool phase (frictionless adapter dependence)", + name, + src_id, + ) + with lock: + result.skipped += 1 + return + + try: + existing = self.ctx.target.list_workflows(name=name) + except Exception as e: + logger.exception("Failed to GET workflow %s on target: %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"GET {name}: {e}") + return + + if existing: + tgt = existing[0] + if self.ctx.options.on_name_conflict == "abort": + raise NameConflictError( + f"workflow '{name}' already exists in target as {tgt['id']}" + ) + with lock: + result.adopted += 1 + logger.info( + "adopted workflow '%s' src=%s -> tgt=%s", name, src_id, tgt["id"] + ) + elif self.ctx.options.dry_run: + with lock: + result.skipped += 1 + logger.info("[dry-run] would create workflow '%s' src=%s", name, src_id) + return + else: + # List endpoints serve stripped payloads (e.g. AdapterListSerializer + # omits adapter_metadata_b); workflow detail carries the JSON blobs + # source_settings / destination_settings that embed connector UUIDs. + try: + src_detail = self.ctx.source.get_workflow(src_id) + except Exception as e: + logger.exception( + "Failed to GET source workflow %s detail: %s", name, e + ) + with lock: + result.failed += 1 + result.errors.append(f"GET source detail {name}: {e}") + return + remapped = remap_uuids(src_detail, self.ctx.remap) + payload = build_post_payload(remapped, self._writable) + try: + tgt = self.ctx.target.create_workflow(payload) + except Exception as e: + logger.exception("Failed to create workflow %s: %s", name, e) + with lock: + result.failed += 1 + result.errors.append(f"create {name}: {e}") + return + with lock: + result.created += 1 + logger.info( + "created workflow '%s' src=%s -> tgt=%s", name, src_id, tgt["id"] + ) + + with lock: + self.ctx.remap.record("workflow", src_id, tgt["id"]) diff --git a/src/unstract/clone/phases/workflow_endpoint.py b/src/unstract/clone/phases/workflow_endpoint.py new file mode 100644 index 0000000..a9ffa7a --- /dev/null +++ b/src/unstract/clone/phases/workflow_endpoint.py @@ -0,0 +1,190 @@ +"""Migrate WorkflowEndpoint rows from source org to target org. + +The backend auto-creates one SOURCE and one DESTINATION endpoint per +workflow on workflow create (``perform_create`` in WorkflowViewSet), so +there's nothing to POST — we only PATCH the target's existing endpoints +with the source's connection_type, connector_instance, and configuration. + +Notes: +- ``workflow`` and ``endpoint_type`` are ``editable=False`` server-side + and aren't writable on PATCH. +- ``connector_instance`` FK is nullable; we remap via the connector + remap table populated in ConnectorPhase. +- ``configuration`` is a JSON blob that may embed connector UUIDs; + walker pass remaps them before PATCH. +- Source ``connector_instance`` arrives as a nested dict (per + ``WorkflowEndpointSerializer.connector_instance``); we extract its + ``id`` and remap. +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any + +from unstract.clone.phases.base import Phase +from unstract.clone.report import CloneReport, PhaseResult +from unstract.clone.walker import remap_uuids + +logger = logging.getLogger(__name__) + + +def _extract_connector_id(endpoint: dict[str, Any]) -> str | None: + """``connector_instance`` is a nested dict on GET; pull out the FK uuid.""" + ci = endpoint.get("connector_instance") + if isinstance(ci, dict): + return ci.get("id") + if isinstance(ci, str): + return ci + return None + + +class WorkflowEndpointPhase(Phase): + name = "workflow_endpoint" + + def run(self, report: CloneReport) -> PhaseResult: + result = report.get_phase(self.name) + workflow_remap = self.ctx.remap.snapshot().get("workflow", {}) + if not workflow_remap: + logger.info( + "No workflows in remap; nothing to do for workflow_endpoint phase" + ) + return result + + self.parallel_map( + list(workflow_remap.items()), + lambda pair, lock: self._clone_workflow_endpoints( + pair[0], pair[1], result, lock + ), + ) + return result + + def _clone_workflow_endpoints( + self, + src_wf_id: str, + tgt_wf_id: str, + result: PhaseResult, + lock: threading.Lock, + ) -> None: + try: + src_endpoints = self.ctx.source.list_workflow_endpoints( + workflow_id=src_wf_id + ) + except Exception as e: + logger.exception( + "Failed to list source endpoints for wf %s: %s", src_wf_id, e + ) + with lock: + result.failed += 1 + result.errors.append(f"list src endpoints {src_wf_id}: {e}") + return + + try: + tgt_endpoints = self.ctx.target.list_workflow_endpoints( + workflow_id=tgt_wf_id + ) + except Exception as e: + logger.exception( + "Failed to list target endpoints for wf %s: %s", tgt_wf_id, e + ) + with lock: + result.failed += 1 + result.errors.append(f"list tgt endpoints {tgt_wf_id}: {e}") + return + + tgt_by_type = {ep["endpoint_type"]: ep for ep in tgt_endpoints} + + for src_ep in src_endpoints: + etype = src_ep["endpoint_type"] + tgt_ep = tgt_by_type.get(etype) + if tgt_ep is None: + logger.warning( + "target workflow %s missing %s endpoint — skipping", + tgt_wf_id, + etype, + ) + with lock: + result.failed += 1 + result.errors.append( + f"missing tgt {etype} endpoint for wf {tgt_wf_id}" + ) + continue + + self._patch_endpoint(src_ep, tgt_ep, result, lock) + + def _patch_endpoint( + self, + src_ep: dict[str, Any], + tgt_ep: dict[str, Any], + result: PhaseResult, + lock: threading.Lock, + ) -> None: + src_ep_id = src_ep["id"] + tgt_ep_id = tgt_ep["id"] + etype = src_ep["endpoint_type"] + + if self.ctx.options.dry_run: + with lock: + result.skipped += 1 + logger.info( + "[dry-run] would PATCH %s endpoint src=%s -> tgt=%s", + etype, + src_ep_id, + tgt_ep_id, + ) + return + + src_conn_id = _extract_connector_id(src_ep) + tgt_conn_id: str | None = None + if src_conn_id: + with lock: + tgt_conn_id = self.ctx.remap.resolve("connector", src_conn_id) + if not tgt_conn_id: + logger.warning( + "skipping %s endpoint src=%s tgt=%s — source connector %s " + "has no target remap; would silently unset connector", + etype, + src_ep_id, + tgt_ep_id, + src_conn_id, + ) + with lock: + result.skipped += 1 + result.errors.append( + f"unmapped connector on {etype} endpoint {src_ep_id}: " + f"src_connector={src_conn_id}" + ) + return + + payload: dict[str, Any] = { + "configuration": remap_uuids( + src_ep.get("configuration") or {}, self.ctx.remap + ), + "connector_instance_id": tgt_conn_id, + } + src_connection_type = src_ep.get("connection_type") + if src_connection_type is not None: + payload["connection_type"] = src_connection_type + + try: + self.ctx.target.update_workflow_endpoint(tgt_ep_id, payload) + except Exception as e: + logger.exception( + "Failed to PATCH %s endpoint tgt=%s: %s", etype, tgt_ep_id, e + ) + with lock: + result.failed += 1 + result.errors.append(f"patch {etype} {tgt_ep_id}: {e}") + return + + with lock: + result.created += 1 + self.ctx.remap.record("workflow_endpoint", src_ep_id, tgt_ep_id) + logger.info( + "patched %s endpoint src=%s -> tgt=%s (connector %s)", + etype, + src_ep_id, + tgt_ep_id, + tgt_conn_id, + ) diff --git a/src/unstract/clone/report.py b/src/unstract/clone/report.py new file mode 100644 index 0000000..4bec96a --- /dev/null +++ b/src/unstract/clone/report.py @@ -0,0 +1,341 @@ +"""Structured report produced by ``clone()``. + +Tracks per-phase counts (created / adopted / skipped / failed) and a final +remap snapshot. Renders to a rich-formatted table when ``rich`` is +available; falls back to plain text otherwise. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass +class PhaseResult: + name: str + created: int = 0 + adopted: int = 0 + skipped: int = 0 + failed: int = 0 + errors: list[str] = field(default_factory=list) + duration_s: float = 0.0 + + +@dataclass +class Endpoint: + """Just enough about an endpoint for the report header — never carries the API key.""" + + base_url: str + organization_id: str + + +@dataclass +class CloneReport: + source: Endpoint | None = None + target: Endpoint | None = None + phases: list[PhaseResult] = field(default_factory=list) + skipped_phases: list[str] = field(default_factory=list) + remap_snapshot: dict[str, dict[str, str]] = field(default_factory=dict) + aborted: bool = False + abort_reason: str | None = None + total_duration_s: float = 0.0 + # Files-phase artifacts. Each entry carries enough context for an + # operator to act on it without cross-referencing the run log. + uploaded_files: list[dict[str, Any]] = field(default_factory=list) + skipped_files: list[dict[str, Any]] = field(default_factory=list) + oversize_files: list[dict[str, Any]] = field(default_factory=list) + unsupported_files: list[dict[str, Any]] = field(default_factory=list) + failed_files: list[dict[str, Any]] = field(default_factory=list) + + def get_phase(self, name: str) -> PhaseResult: + for p in self.phases: + if p.name == name: + return p + result = PhaseResult(name=name) + self.phases.append(result) + return result + + def render(self) -> str: + """Render as a rich table when available, otherwise plain text.""" + try: + from io import StringIO + + from rich.console import Console + from rich.table import Table + except ImportError: + return self._render_plain() + + buf = StringIO() + # force_terminal so ANSI codes survive the StringIO capture; the + # caller decides whether to strip them when printing to a non-tty. + console = Console( + file=buf, force_terminal=True, color_system="truecolor", width=100 + ) + # Actionable summary first so it doesn't scroll past the table. + self._render_failures_summary(console_print=console.print, rich=True) + self._render_endpoints(console.print) + table = Table(title="Clone Report", header_style="bold cyan") + table.add_column("Phase", style="bold", justify="left") + for col in ("Created", "Adopted", "Skipped", "Failed", "Time"): + table.add_column(col, justify="right") + + totals = {"created": 0, "adopted": 0, "skipped": 0, "failed": 0} + for p in self.phases: + phase_style = "red" if p.failed else ("yellow" if p.skipped else "green") + table.add_row( + f"[{phase_style}]{p.name}[/{phase_style}]", + self._fmt_count(p.created, "green"), + self._fmt_count(p.adopted, "green"), + self._fmt_count(p.skipped, "yellow"), + self._fmt_count(p.failed, "red"), + self._fmt_duration(p.duration_s), + ) + for k in totals: + totals[k] += getattr(p, k) + + table.add_section() + table.add_row( + "[bold]TOTAL[/bold]", + self._fmt_count(totals["created"], "green", bold=True), + self._fmt_count(totals["adopted"], "green", bold=True), + self._fmt_count(totals["skipped"], "yellow", bold=True), + self._fmt_count(totals["failed"], "red", bold=True), + self._fmt_duration(self.total_duration_s, bold=True), + ) + console.print(table) + if self.skipped_phases: + console.print( + f"[dim]Skipped phases:[/dim] {', '.join(self.skipped_phases)}" + ) + self._render_files_sections(console) + self._render_remap_summary(console_print=console.print) + if self.aborted: + console.print(f"[bold red]ABORTED:[/bold red] {self.abort_reason}") + elif totals["failed"]: + console.print( + f"[bold red]Completed with {totals['failed']} failure(s)[/bold red] — " + "see WARNING/ERROR log lines above for details" + ) + else: + console.print("[bold green]Completed successfully[/bold green]") + return buf.getvalue() + + @staticmethod + def _fmt_count(value: int, color: str, bold: bool = False) -> str: + """Dim a zero to keep the eye on non-zero cells; colour anything > 0.""" + if value == 0: + return "[dim]0[/dim]" + style = f"bold {color}" if bold else color + return f"[{style}]{value}[/{style}]" + + @staticmethod + def _fmt_duration(seconds: float, bold: bool = False) -> str: + if seconds <= 0: + return "[dim]—[/dim]" + if seconds < 60: + text = f"{seconds:.1f}s" + else: + mins, secs = divmod(seconds, 60) + text = f"{int(mins)}m{secs:.0f}s" + return f"[bold]{text}[/bold]" if bold else text + + @staticmethod + def _fmt_duration_plain(seconds: float) -> str: + if seconds <= 0: + return "—" + if seconds < 60: + return f"{seconds:.1f}s" + mins, secs = divmod(seconds, 60) + return f"{int(mins)}m{secs:.0f}s" + + def _render_plain(self) -> str: + lines: list[str] = [] + self._render_failures_summary(console_print=lines.append, rich=False) + lines.extend(["Clone Report", "=" * 60]) + self._render_endpoints(lines.append) + header = ( + f"{'Phase':<24}{'Created':>10}{'Adopted':>10}" + f"{'Skipped':>10}{'Failed':>10}{'Time':>10}" + ) + lines.append(header) + for p in self.phases: + lines.append( + f"{p.name:<24}{p.created:>10}{p.adopted:>10}" + f"{p.skipped:>10}{p.failed:>10}{self._fmt_duration_plain(p.duration_s):>10}" + ) + lines.append( + f"{'TOTAL':<64}{self._fmt_duration_plain(self.total_duration_s):>10}" + ) + if self.skipped_phases: + lines.append(f"Skipped phases: {', '.join(self.skipped_phases)}") + lines.extend(self._files_sections_plain()) + self._render_remap_summary(console_print=lines.append) + if self.aborted: + lines.append(f"ABORTED: {self.abort_reason}") + return "\n".join(lines) + + def as_dict(self) -> dict[str, Any]: + return { + "source": ( + { + "base_url": self.source.base_url, + "organization_id": self.source.organization_id, + } + if self.source + else None + ), + "target": ( + { + "base_url": self.target.base_url, + "organization_id": self.target.organization_id, + } + if self.target + else None + ), + "phases": [ + { + "name": p.name, + "created": p.created, + "adopted": p.adopted, + "skipped": p.skipped, + "failed": p.failed, + "errors": list(p.errors), + "duration_s": p.duration_s, + } + for p in self.phases + ], + "skipped_phases": list(self.skipped_phases), + "remap_snapshot": self.remap_snapshot, + "aborted": self.aborted, + "abort_reason": self.abort_reason, + "total_duration_s": self.total_duration_s, + "uploaded_files": list(self.uploaded_files), + "skipped_files": list(self.skipped_files), + "oversize_files": list(self.oversize_files), + "unsupported_files": list(self.unsupported_files), + "failed_files": list(self.failed_files), + } + + def _render_endpoints(self, console_print: Any) -> None: + if not self.source and not self.target: + return + src = self._fmt_endpoint(self.source) + tgt = self._fmt_endpoint(self.target) + console_print(f"Source: {src}") + console_print(f"Target: {tgt}") + + @staticmethod + def _fmt_endpoint(ep: Endpoint | None) -> str: + if ep is None: + return "?" + return f"{ep.organization_id} @ {ep.base_url}" + + def _render_remap_summary(self, console_print: Any) -> None: + """Summarise the remap snapshot. Full map is large and noisy, so + we only print per-entity counts here; the full mapping is emitted + at DEBUG and remains in ``as_dict()`` for programmatic consumers. + """ + if not self.remap_snapshot: + return + counts = ", ".join( + f"{entity}={len(mapping)}" + for entity, mapping in self.remap_snapshot.items() + if mapping + ) + if counts: + console_print(f"Remap entries: {counts}") + if logger.isEnabledFor(logging.DEBUG): + for entity, mapping in self.remap_snapshot.items(): + for src, tgt in mapping.items(): + logger.debug("remap %s %s -> %s", entity, src, tgt) + + def _render_files_sections(self, console: Any) -> None: + if self.uploaded_files: + console.print(f"[green]Files uploaded:[/green] {len(self.uploaded_files)}") + for header, rows in ( + ("Oversize files (manual upload required)", self.oversize_files), + ("Unsupported mime files (manual upload required)", self.unsupported_files), + ("Skipped files (operator action required)", self.skipped_files), + ("Failed files", self.failed_files), + ): + if not rows: + continue + console.print(f"[yellow]{header}:[/yellow]") + for row in rows: + console.print(f" - {self._describe_file_row(row)}") + + def _files_sections_plain(self) -> list[str]: + lines: list[str] = [] + if self.uploaded_files: + lines.append(f"Files uploaded: {len(self.uploaded_files)}") + for header, rows in ( + ("Oversize files (manual upload required)", self.oversize_files), + ("Unsupported mime files (manual upload required)", self.unsupported_files), + ("Skipped files (operator action required)", self.skipped_files), + ("Failed files", self.failed_files), + ): + if not rows: + continue + lines.append(f"{header}:") + for row in rows: + lines.append(f" - {self._describe_file_row(row)}") + return lines + + # Caps so a long traceback or many failures don't dominate the report. + _FAILURE_LINE_MAX_CHARS = 200 + _FAILURE_MAX_ROWS = 30 + + def _render_failures_summary(self, console_print: Any, rich: bool) -> None: + rows: list[tuple[str, str]] = [] + for p in self.phases: + for err in p.errors: + rows.append((p.name, err)) + if not rows: + return + header = "Failures (see WARNING/ERROR log lines above for full detail)" + if rich: + console_print(f"[red]{header}:[/red]") + else: + console_print(f"{header}:") + shown = rows[: self._FAILURE_MAX_ROWS] + for phase_name, err in shown: + truncated = self._truncate(err, self._FAILURE_LINE_MAX_CHARS) + if rich: + console_print( + f" - [bold cyan]{phase_name}[/bold cyan]: {truncated}", + highlight=False, + ) + else: + console_print(f" - {phase_name}: {truncated}") + remaining = len(rows) - len(shown) + if remaining > 0: + tail = f" ... +{remaining} more — see logs" + if rich: + console_print(f"[dim]{tail}[/dim]") + else: + console_print(tail) + + @staticmethod + def _truncate(text: str, limit: int) -> str: + text = text.replace("\n", " ") + if len(text) <= limit: + return text + return text[: limit - 1] + "…" + + @staticmethod + def _describe_file_row(row: dict[str, Any]) -> str: + tool = row.get("tool_name") or row.get("tool_id") or "?" + name = row.get("file_name", "?") + extras: list[str] = [] + if "size_bytes" in row: + extras.append(f"{row['size_bytes']} bytes") + if "mime_type" in row: + extras.append(row["mime_type"]) + if "error" in row: + extras.append(f"error={row['error']}") + suffix = f" ({', '.join(extras)})" if extras else "" + return f"tool={tool} file={name}{suffix}" diff --git a/src/unstract/clone/walker.py b/src/unstract/clone/walker.py new file mode 100644 index 0000000..eb9c401 --- /dev/null +++ b/src/unstract/clone/walker.py @@ -0,0 +1,32 @@ +"""JSON walker that rewrites embedded source UUIDs to target UUIDs. + +Used by phases whose payloads carry foreign-key UUIDs inside JSON fields +(e.g. ``tool_instance.metadata``). Unknown UUIDs pass through untouched so +we don't accidentally rewrite an unrelated identifier that just happens +to look like a UUID. +""" + +from __future__ import annotations + +import re +from typing import Any + +from unstract.clone.context import RemapTable + +UUID_RE = re.compile( + r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$", + re.IGNORECASE, +) + + +def remap_uuids(obj: Any, remap: RemapTable) -> Any: + """Walk a JSON-shaped value; replace any string that looks like a UUID + AND has a known mapping. Unknown UUIDs pass through untouched. + """ + if isinstance(obj, dict): + return {k: remap_uuids(v, remap) for k, v in obj.items()} + if isinstance(obj, list): + return [remap_uuids(v, remap) for v in obj] + if isinstance(obj, str) and UUID_RE.match(obj): + return remap.resolve_any(obj) or obj + return obj diff --git a/tests/clone/__init__.py b/tests/clone/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/clone/test_adapter_phase.py b/tests/clone/test_adapter_phase.py new file mode 100644 index 0000000..d2b0311 --- /dev/null +++ b/tests/clone/test_adapter_phase.py @@ -0,0 +1,163 @@ +"""Tests for ``AdapterPhase``. + +Uses an in-process fake ``PlatformClient`` to avoid real HTTP. Verifies: +- happy path: source has N adapters, target gets N POSTs, all remapped +- idempotency: re-run with target already populated → zero POSTs, all adopted +- dry-run: zero POSTs, all skipped +- on_name_conflict='abort' raises on existing +""" + +from __future__ import annotations + +import pytest + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + RemapTable, +) +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.adapter import AdapterPhase +from unstract.clone.report import CloneReport + + +class FakeClient: + """Minimal in-memory stand-in for ``PlatformClient``.""" + + # Mirrors DRF OPTIONS actions.POST writable fields for adapter. + POST_SCHEMA = frozenset( + {"adapter_id", "adapter_name", "adapter_type", "adapter_metadata", "description"} + ) + + def __init__(self, adapters: list[dict] | None = None): + # Stored as a list of dicts; mutated by create_adapter. + self.adapters: list[dict] = list(adapters or []) + self.posts: list[dict] = [] + self._next_id = 1 + + def get_post_schema(self, entity_path): + return self.POST_SCHEMA + + def list_adapters(self, *, name=None, adapter_type=None): + result = self.adapters + if name is not None: + result = [a for a in result if a["adapter_name"] == name] + if adapter_type is not None: + result = [a for a in result if a["adapter_type"] == adapter_type] + # Mimic AdapterListSerializer — strip adapter_metadata from list output. + return [{k: v for k, v in a.items() if k != "adapter_metadata"} for a in result] + + def get_adapter(self, adapter_pk): + for a in self.adapters: + if a["id"] == adapter_pk: + return a + raise KeyError(adapter_pk) + + def create_adapter(self, payload): + new = dict(payload) + new["id"] = f"tgt-{self._next_id:08d}-0000-0000-0000-000000000000" + self._next_id += 1 + self.adapters.append(new) + self.posts.append(new) + return new + + +def _src_adapter(id_, name, atype="LLM"): + return { + "id": id_, + "adapter_id": "openai-llm-v2", + "adapter_name": name, + "adapter_type": atype, + "adapter_metadata": {"api_key": "sk-secret", "model": "gpt-4"}, + "description": f"{name} desc", + } + + +def _ctx(source: FakeClient, target: FakeClient, **opt_overrides): + ctx = CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=RemapTable(), + ) + return ctx + + +def test_happy_path_creates_all_and_records_remap(): + src = FakeClient( + [ + _src_adapter("src-a", "OpenAI Prod"), + _src_adapter("src-b", "Mistral Stg", atype="EMBEDDING"), + ] + ) + tgt = FakeClient() + ctx = _ctx(src, tgt) + report = CloneReport() + + result = AdapterPhase(ctx).run(report) + + assert result.created == 2 + assert result.adopted == 0 + assert result.failed == 0 + assert len(tgt.posts) == 2 + assert ctx.remap.resolve("adapter", "src-a") == tgt.posts[0]["id"] + assert ctx.remap.resolve("adapter", "src-b") == tgt.posts[1]["id"] + + +def test_idempotency_zero_creates_on_rerun(): + src_adapters = [_src_adapter("src-a", "OpenAI Prod")] + src = FakeClient(src_adapters) + # Target pre-populated with the same name+type — simulates a prior run. + tgt = FakeClient( + [ + { + "id": "preexisting", + "adapter_id": "openai-llm-v2", + "adapter_name": "OpenAI Prod", + "adapter_type": "LLM", + "adapter_metadata": {}, + } + ] + ) + ctx = _ctx(src, tgt, on_name_conflict="adopt") + report = CloneReport() + + result = AdapterPhase(ctx).run(report) + + assert result.created == 0 + assert result.adopted == 1 + assert tgt.posts == [] # no new POSTs + assert ctx.remap.resolve("adapter", "src-a") == "preexisting" + + +def test_dry_run_makes_no_posts(): + src = FakeClient([_src_adapter("src-a", "OpenAI Prod")]) + tgt = FakeClient() + ctx = _ctx(src, tgt, dry_run=True) + report = CloneReport() + + result = AdapterPhase(ctx).run(report) + + assert result.skipped == 1 + assert result.created == 0 + assert tgt.posts == [] + + +def test_abort_on_name_conflict_raises(): + src = FakeClient([_src_adapter("src-a", "OpenAI Prod")]) + tgt = FakeClient( + [ + { + "id": "preexisting", + "adapter_id": "openai-llm-v2", + "adapter_name": "OpenAI Prod", + "adapter_type": "LLM", + "adapter_metadata": {}, + } + ] + ) + ctx = _ctx(src, tgt, on_name_conflict="abort") + report = CloneReport() + + with pytest.raises(NameConflictError): + AdapterPhase(ctx).run(report) diff --git a/tests/clone/test_api_deployment_phase.py b/tests/clone/test_api_deployment_phase.py new file mode 100644 index 0000000..dc25d7a --- /dev/null +++ b/tests/clone/test_api_deployment_phase.py @@ -0,0 +1,185 @@ +"""Tests for ``APIDeploymentPhase``. + +Coverage: +- happy path: source api_deployments created with workflow FK remapped. +- adopt by api_name on existing target deployment. +- skipped when workflow remap missing. +- dry-run is a no-op. +- abort raises ``NameConflictError``. +- extra source keys produce a warning, never a failure. +""" + +from __future__ import annotations + +import logging + +import pytest + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + RemapTable, +) +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.api_deployment import APIDeploymentPhase +from unstract.clone.report import CloneReport + +API_DEPLOYMENT_POST_SCHEMA = frozenset( + { + "display_name", + "description", + "workflow", + "is_active", + "api_name", + "shared_users", + "shared_to_org", + } +) + + +class FakeClient: + def __init__(self, deployments: list[dict] | None = None): + self.deployments: list[dict] = list(deployments or []) + self.posts: list[dict] = [] + self.keys_by_deployment: dict[str, list[dict]] = {} + self._next = 1 + + def get_post_schema(self, entity_path: str) -> frozenset[str]: + return API_DEPLOYMENT_POST_SCHEMA + + def list_api_deployments(self, *, api_name: str | None = None): + result = self.deployments + if api_name is not None: + result = [d for d in result if d["api_name"] == api_name] + return list(result) + + def get_api_deployment(self, deployment_id: str) -> dict: + for d in self.deployments: + if d["id"] == deployment_id: + return dict(d) + raise KeyError(deployment_id) + + def create_api_deployment(self, payload: dict) -> dict: + new = dict(payload) + new["id"] = f"tgt-dep-{self._next:04d}" + new["api_key"] = f"key-{self._next:04d}" + self._next += 1 + self.deployments.append(new) + self.posts.append(new) + return new + + def list_api_deployment_keys(self, deployment_id: str) -> list[dict]: + return list(self.keys_by_deployment.get(deployment_id, [])) + + +def _src_deployment( + id_: str, api_name: str, workflow_id: str, *, display_name: str | None = None +) -> dict: + return { + "id": id_, + "api_name": api_name, + "display_name": display_name or api_name, + "description": f"{api_name} desc", + "workflow": workflow_id, + "workflow_id": workflow_id, + "is_active": True, + "shared_users": [], + "shared_to_org": False, + } + + +def _ctx(source, target, *, remap=None, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=remap or RemapTable(), + ) + + +def test_happy_path_creates_deployment_with_remapped_workflow(): + src = FakeClient([_src_deployment("src-dep-1", "invoices_api", "wf-src-1")]) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + result = APIDeploymentPhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert result.failed == 0 + posted = tgt.posts[0] + assert posted["api_name"] == "invoices_api" + assert posted["workflow"] == "wf-tgt-1" + assert ctx.remap.resolve("api_deployment", "src-dep-1") == posted["id"] + + +def test_adopts_existing_deployment_by_api_name(): + src = FakeClient([_src_deployment("src-dep-1", "invoices_api", "wf-src-1")]) + tgt = FakeClient([{"id": "tgt-existing", "api_name": "invoices_api"}]) + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + result = APIDeploymentPhase(ctx).run(CloneReport()) + + assert result.adopted == 1 + assert result.created == 0 + assert tgt.posts == [] + assert ctx.remap.resolve("api_deployment", "src-dep-1") == "tgt-existing" + + +def test_skipped_when_workflow_remap_missing(): + src = FakeClient([_src_deployment("src-dep-1", "orphan", "wf-src-1")]) + tgt = FakeClient() + ctx = _ctx(src, tgt) # No workflow remap. + + result = APIDeploymentPhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert tgt.posts == [] + + +def test_dry_run_makes_no_writes(): + src = FakeClient([_src_deployment("src-dep-1", "invoices_api", "wf-src-1")]) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap, dry_run=True) + + result = APIDeploymentPhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert tgt.posts == [] + + +def test_abort_on_name_conflict_raises(): + src = FakeClient([_src_deployment("src-dep-1", "invoices_api", "wf-src-1")]) + tgt = FakeClient([{"id": "tgt-existing", "api_name": "invoices_api"}]) + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap, on_name_conflict="abort") + + with pytest.raises(NameConflictError): + APIDeploymentPhase(ctx).run(CloneReport()) + + +def test_extra_source_keys_log_warning_not_failure(caplog): + src = FakeClient([_src_deployment("src-dep-1", "invoices_api", "wf-src-1")]) + src.keys_by_deployment["src-dep-1"] = [ + {"id": "k1", "is_active": True}, + {"id": "k2", "is_active": True}, + ] + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + with caplog.at_level( + logging.WARNING, logger="unstract.clone.phases.api_deployment" + ): + result = APIDeploymentPhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert result.failed == 0 + assert any("2 active API keys" in r.message for r in caplog.records) diff --git a/tests/clone/test_base_helpers.py b/tests/clone/test_base_helpers.py new file mode 100644 index 0000000..727affa --- /dev/null +++ b/tests/clone/test_base_helpers.py @@ -0,0 +1,58 @@ +"""Tests for ``unstract.clone.phases.base`` helpers.""" + +from __future__ import annotations + +from unstract.clone.phases.base import SERVER_MANAGED, build_post_payload + + +def test_preserves_false_and_zero_values(): + """Booleans set to False and numeric 0 are legitimate field values. + + Earlier ``value not in (None, "")`` worked for None/"" but dropped + False and 0 too because of Python's ``False == 0 == in (None, "")`` + edge case. Regression guard. + """ + src = { + "is_active": False, + "retry_count": 0, + "rate_limit": 0.0, + "name": "demo", + } + writable = frozenset({"is_active", "retry_count", "rate_limit", "name"}) + + payload = build_post_payload(src, writable) + + assert payload == { + "is_active": False, + "retry_count": 0, + "rate_limit": 0.0, + "name": "demo", + } + + +def test_strips_none_and_empty_string_but_keeps_zero(): + src = {"a": None, "b": "", "c": 0, "d": False, "e": "kept"} + writable = frozenset({"a", "b", "c", "d", "e"}) + + payload = build_post_payload(src, writable) + + assert payload == {"c": 0, "d": False, "e": "kept"} + + +def test_drops_server_managed_keys_even_if_writable(): + src = {"id": "X", "name": "demo", "organization": "org", "created_by": "u"} + # All four are nominally writable but SERVER_MANAGED should win. + writable = frozenset(src.keys()) + + payload = build_post_payload(src, writable) + + assert payload == {"name": "demo"} + for key in SERVER_MANAGED & set(src.keys()): + assert key not in payload + + +def test_ignores_writable_keys_missing_from_src(): + src = {"present": 1} + writable = frozenset({"present", "absent"}) + + assert build_post_payload(src, writable) == {"present": 1} diff --git a/tests/clone/test_cli.py b/tests/clone/test_cli.py new file mode 100644 index 0000000..e4ac623 --- /dev/null +++ b/tests/clone/test_cli.py @@ -0,0 +1,129 @@ +"""Tests for the click CLI wiring in ``unstract.clone.cli``. + +Coverage: +- ``_parse_size`` accepts bare integers, K/M/G suffixes, decimals. +- ``--max-file-size 0`` propagates as 0 (force every file to manual list), + not the default cap — distinguished from the unparseable case. +""" + +from __future__ import annotations + +import pytest +from click.testing import CliRunner + +from unstract.clone.cli import _parse_size, cli +from unstract.clone.context import DEFAULT_MAX_FILE_SIZE, CloneOptions +from unstract.clone.report import CloneReport, Endpoint + + +def test_parse_size_bare_int_is_bytes(): + assert _parse_size("25") == 25 + + +def test_parse_size_accepts_kb_mb_gb_units(): + assert _parse_size("25MB") == 25 * 1024 * 1024 + assert _parse_size("1.5GB") == int(1.5 * 1024 * 1024 * 1024) + assert _parse_size("512K") == 512 * 1024 + + +def test_parse_size_zero_returns_zero(): + # Regression for `cap_bytes or DEFAULT` — must not coerce 0 to the + # default. CLI flag --max-file-size 0 means "every file goes to the + # oversize/manual-upload list". + assert _parse_size("0") == 0 + + +def test_parse_size_unknown_unit_raises(): + import click + + with pytest.raises(click.BadParameter): + _parse_size("10XB") + + +def test_parse_size_unparseable_raises(): + import click + + with pytest.raises(click.BadParameter): + _parse_size("not-a-size") + + +def test_cli_max_file_size_zero_propagates_to_options(monkeypatch): + captured: dict = {} + + def fake_clone(source, target, options=None): + captured["options"] = options + return CloneReport( + source=Endpoint( + base_url=source.base_url, organization_id=source.organization_id + ), + target=Endpoint( + base_url=target.base_url, organization_id=target.organization_id + ), + ) + + monkeypatch.setattr("unstract.clone.cli.run_clone", fake_clone) + + result = CliRunner().invoke( + cli, + [ + "clone", + "--source-url", + "http://src", + "--source-org", + "src", + "--source-key", + "sk", + "--target-url", + "http://tgt", + "--target-org", + "tgt", + "--target-key", + "tk", + "--max-file-size", + "0", + ], + ) + + assert result.exit_code == 0, result.output + opts: CloneOptions = captured["options"] + assert opts.max_file_size == 0 + + +def test_cli_max_file_size_default_when_flag_omitted(monkeypatch): + captured: dict = {} + + def fake_clone(source, target, options=None): + captured["options"] = options + return CloneReport( + source=Endpoint( + base_url=source.base_url, organization_id=source.organization_id + ), + target=Endpoint( + base_url=target.base_url, organization_id=target.organization_id + ), + ) + + monkeypatch.setattr("unstract.clone.cli.run_clone", fake_clone) + + result = CliRunner().invoke( + cli, + [ + "clone", + "--source-url", + "http://src", + "--source-org", + "src", + "--source-key", + "sk", + "--target-url", + "http://tgt", + "--target-org", + "tgt", + "--target-key", + "tk", + ], + ) + + assert result.exit_code == 0, result.output + opts: CloneOptions = captured["options"] + assert opts.max_file_size == DEFAULT_MAX_FILE_SIZE diff --git a/tests/clone/test_client.py b/tests/clone/test_client.py new file mode 100644 index 0000000..9fa3dca --- /dev/null +++ b/tests/clone/test_client.py @@ -0,0 +1,145 @@ +"""Tests for ``PlatformClient`` HTTP layer. + +Coverage: +- URL composition honours base_url, api_path_prefix, organization_id. +- Bearer auth header present on every request. +- Non-2xx response raises ``PlatformAPIError`` with status_code + body. +- 204 / empty body returns ``None`` instead of raising on .json(). +- ``get_post_schema`` parses DRF ``actions.POST`` and caches per path. +- ``close()`` shuts the underlying session; context manager works. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from unstract.clone.client import PlatformClient +from unstract.clone.context import OrgEndpoint +from unstract.clone.exceptions import PlatformAPIError + + +def _endpoint() -> OrgEndpoint: + return OrgEndpoint( + base_url="https://api.example.com", + organization_id="org_abc", + platform_key="plat-key-xyz", + ) + + +def _fake_response(status: int, payload=None, text: str = "") -> MagicMock: + resp = MagicMock() + resp.status_code = status + resp.text = text + resp.content = b"" if payload is None and not text else b"x" + resp.json.return_value = payload + return resp + + +def _client_with_mock( + payload=None, status: int = 200, text: str = "" +) -> tuple[PlatformClient, MagicMock]: + client = PlatformClient(_endpoint()) + mock_request = MagicMock(return_value=_fake_response(status, payload, text)) + client._session.request = mock_request + return client, mock_request + + +def test_url_composition_includes_org_and_api_prefix(): + client, mock_request = _client_with_mock(payload=[]) + client.list_adapters() + call = mock_request.call_args + assert call.args[0] == "GET" + assert call.args[1] == "https://api.example.com/api/v1/unstract/org_abc/adapter/" + + +def test_bearer_token_sent_on_session(): + client, _ = _client_with_mock(payload=[]) + assert client._session.headers["Authorization"] == "Bearer plat-key-xyz" + assert client._session.headers["Accept"] == "application/json" + + +def test_non_2xx_raises_platform_api_error_with_status_and_body(): + client, _ = _client_with_mock(status=404, text="not found") + with pytest.raises(PlatformAPIError) as exc_info: + client.list_adapters() + err = exc_info.value + assert err.status_code == 404 + assert "not found" in err.body + + +def test_500_with_long_body_truncated_to_2000_chars(): + big = "x" * 5000 + client, _ = _client_with_mock(status=500, text=big) + with pytest.raises(PlatformAPIError) as exc_info: + client.list_adapters() + assert len(exc_info.value.body) == 2000 + + +def test_204_no_content_returns_none(): + client = PlatformClient(_endpoint()) + resp = MagicMock() + resp.status_code = 204 + resp.content = b"" + client._session.request = MagicMock(return_value=resp) + assert client._request("DELETE", "tag/abc/") is None + + +def test_get_post_schema_parses_options_and_caches(): + options_body = { + "actions": { + "POST": { + "name": {"read_only": False}, + "id": {"read_only": True}, + "shared_to_org": {"read_only": False}, + # No read_only key → treated as writable. + "description": {}, + } + } + } + client, mock_request = _client_with_mock(payload=options_body) + writable = client.get_post_schema("adapter/") + assert writable == frozenset({"name", "shared_to_org", "description"}) + # second call hits cache — no extra HTTP. + writable2 = client.get_post_schema("adapter/") + assert writable2 is writable + assert mock_request.call_count == 1 + + +def test_get_post_schema_handles_missing_actions_block(): + client, _ = _client_with_mock(payload={}) + assert client.get_post_schema("connector/") == frozenset() + + +def test_close_shuts_session(): + client = PlatformClient(_endpoint()) + sess = client._session + sess.close = MagicMock() + client.close() + sess.close.assert_called_once() + + +def test_context_manager_closes_on_exit(): + with PlatformClient(_endpoint()) as client: + client._session.close = MagicMock() + sess_close = client._session.close + sess_close.assert_called_once() + + +def test_list_endpoint_unwraps_paginated_envelope(): + client, _ = _client_with_mock(payload={"results": [{"id": "a"}, {"id": "b"}]}) + items = client.list_tags() + assert [i["id"] for i in items] == ["a", "b"] + + +def test_list_endpoint_accepts_bare_list(): + client, _ = _client_with_mock(payload=[{"id": "a"}]) + items = client.list_tags() + assert items == [{"id": "a"}] + + +def test_options_response_with_null_body_still_yields_empty_schema(): + # Some deployments return 200 with no body on OPTIONS. + client, _ = _client_with_mock(payload=None, text="") + assert client.get_post_schema("pipeline/") == frozenset() diff --git a/tests/clone/test_connector_phase.py b/tests/clone/test_connector_phase.py new file mode 100644 index 0000000..28f17f2 --- /dev/null +++ b/tests/clone/test_connector_phase.py @@ -0,0 +1,200 @@ +"""Tests for ``ConnectorPhase``. + +Mirrors the adapter phase suite — happy path, idempotency, dry-run, +abort — plus connector-specific behavior: UCS auto-provisioned rows are +skipped without consulting the target. +""" + +from __future__ import annotations + +import pytest + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + RemapTable, +) +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.connector import ConnectorPhase +from unstract.clone.report import CloneReport + + +class FakeClient: + POST_SCHEMA = frozenset( + { + "connector_id", + "connector_name", + "connector_metadata", + "connector_version", + "connector_mode", + "connector_type", + "shared_to_org", + } + ) + + def __init__(self, connectors: list[dict] | None = None): + self.connectors: list[dict] = list(connectors or []) + self.posts: list[dict] = [] + self._next_id = 1 + + def get_post_schema(self, entity_path): + return self.POST_SCHEMA + + def list_connectors(self, *, name=None, connector_type=None): + result = self.connectors + if name is not None: + result = [c for c in result if c["connector_name"] == name] + if connector_type is not None: + result = [c for c in result if c.get("connector_type") == connector_type] + return list(result) + + def get_connector(self, connector_pk): + for c in self.connectors: + if c["id"] == connector_pk: + return c + raise KeyError(connector_pk) + + def create_connector(self, payload): + new = dict(payload) + new["id"] = f"tgt-{self._next_id:08d}-0000-0000-0000-000000000000" + self._next_id += 1 + self.connectors.append(new) + self.posts.append(new) + return new + + +def _src(id_, name, catalog_id="postgres|abc", ctype="INPUT"): + return { + "id": id_, + "connector_id": catalog_id, + "connector_name": name, + "connector_type": ctype, + "connector_version": "1.0", + "connector_metadata": {"host": "db.example.com", "password": "secret"}, + "shared_to_org": False, + } + + +def _ctx(source, target, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=RemapTable(), + ) + + +def test_happy_path_creates_all_and_records_remap(): + src = FakeClient([_src("src-a", "Prod PG"), _src("src-b", "Stg S3", "s3|xyz")]) + tgt = FakeClient() + ctx = _ctx(src, tgt) + report = CloneReport() + + result = ConnectorPhase(ctx).run(report) + + assert result.created == 2 + assert result.adopted == 0 + assert result.skipped == 0 + assert len(tgt.posts) == 2 + assert ctx.remap.resolve("connector", "src-a") == tgt.posts[0]["id"] + assert ctx.remap.resolve("connector", "src-b") == tgt.posts[1]["id"] + + +def test_redacted_metadata_connector_skipped(): + """Source returning empty metadata (redacted by backend) is unmigratable — + skipped with no POST and no remap entry.""" + redacted = _src("src-ucs", "User Storage") + redacted["connector_metadata"] = {} # backend redaction signal + src = FakeClient([redacted]) + tgt = FakeClient() + ctx = _ctx(src, tgt) + report = CloneReport() + + result = ConnectorPhase(ctx).run(report) + + assert result.skipped == 1 + assert result.created == 0 + assert tgt.posts == [] + assert ctx.remap.resolve("connector", "src-ucs") is None + + +def test_oauth_connector_skipped_before_post(): + """OAuth-backed connectors (metadata carries access_token/refresh_token) + would fail target POST with OAuthTimeOut — skip ahead of POST so the + operator re-authorises post-clone. + """ + oauth = _src("src-gdrive", "Unstract's google drive") + oauth["connector_metadata"] = { + "provider": "google-oauth2", + "uid": "src-user", + "access_token": "ya29.src-access", + "refresh_token": "1//src-refresh", + } + src = FakeClient([oauth]) + tgt = FakeClient() + ctx = _ctx(src, tgt) + + result = ConnectorPhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert result.created == 0 + assert result.failed == 0 + assert tgt.posts == [] + assert ctx.remap.resolve("connector", "src-gdrive") is None + + +def test_idempotency_zero_creates_on_rerun(): + src = FakeClient([_src("src-a", "Prod PG")]) + tgt = FakeClient( + [ + { + "id": "preexisting", + "connector_id": "postgres|abc", + "connector_name": "Prod PG", + "connector_type": "INPUT", + "connector_metadata": {}, + } + ] + ) + ctx = _ctx(src, tgt, on_name_conflict="adopt") + report = CloneReport() + + result = ConnectorPhase(ctx).run(report) + + assert result.created == 0 + assert result.adopted == 1 + assert tgt.posts == [] + assert ctx.remap.resolve("connector", "src-a") == "preexisting" + + +def test_dry_run_makes_no_posts(): + src = FakeClient([_src("src-a", "Prod PG")]) + tgt = FakeClient() + ctx = _ctx(src, tgt, dry_run=True) + report = CloneReport() + + result = ConnectorPhase(ctx).run(report) + + assert result.skipped == 1 + assert result.created == 0 + assert tgt.posts == [] + + +def test_abort_on_name_conflict_raises(): + src = FakeClient([_src("src-a", "Prod PG")]) + tgt = FakeClient( + [ + { + "id": "preexisting", + "connector_id": "postgres|abc", + "connector_name": "Prod PG", + "connector_type": "INPUT", + "connector_metadata": {}, + } + ] + ) + ctx = _ctx(src, tgt, on_name_conflict="abort") + report = CloneReport() + + with pytest.raises(NameConflictError): + ConnectorPhase(ctx).run(report) diff --git a/tests/clone/test_custom_tool_phase.py b/tests/clone/test_custom_tool_phase.py new file mode 100644 index 0000000..30edfb1 --- /dev/null +++ b/tests/clone/test_custom_tool_phase.py @@ -0,0 +1,383 @@ +"""Tests for ``CustomToolPhase`` — project-transfer + sync-prompts based. + +Coverage: +- fresh path: ``export_project`` on source → ``import_project`` on + target with adapter ids resolved by looking up each source-profile + adapter NAME against the target via ``list_adapters(name=...)``. +- adopt path: existing target tool with matching name → + ``sync_prompts`` overwrites prompts; no profile/adapter writes. +- registry remap recorded after ``export_custom_tool``. +- dry-run: no writes on either side. +- abort on name conflict when option is set. +- missing target adapter fails the tool cleanly. +""" + +from __future__ import annotations + +import pytest + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + RemapTable, +) +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.custom_tool import CustomToolPhase +from unstract.clone.report import CloneReport + +ADAPTER_NAMES = { + "llm": "gpt4", + "embedding_model": "ada-embed", + "vector_store": "pgvector", + "x2text": "llmw", +} +TGT_ADAPTER_IDS = { + "gpt4": "a1111111-1111-1111-1111-111111111111", + "ada-embed": "a2222222-2222-2222-2222-222222222222", + "pgvector": "a3333333-3333-3333-3333-333333333333", + "llmw": "a4444444-4444-4444-4444-444444444444", +} +SRC_REG = "55555555-5555-5555-5555-555555555555" + + +class FakeClient: + """In-memory stand-in for ``PlatformClient`` covering project-transfer.""" + + def __init__(self) -> None: + self.tools: dict[str, dict] = {} + self.profiles_by_tool: dict[str, list[dict]] = {} + self.export_blobs: dict[str, dict] = {} + self.registries_by_tool: dict[str, dict] = {} + self.adapters_by_name: dict[str, dict] = {} + # Call recorders. + self.import_calls: list[tuple[dict, dict | None]] = [] + self.sync_calls: list[tuple[str, dict, bool]] = [] + self.export_tool_calls: list[str] = [] + self._next = 1 + + def _mint(self, prefix: str) -> str: + s = f"tgt-{prefix}-{self._next:04d}" + self._next += 1 + return s + + # --- reads --- + def list_custom_tools(self) -> list[dict]: + return [ + {"tool_id": tid, "tool_name": t["tool_name"]} + for tid, t in self.tools.items() + ] + + def list_profiles(self, tool_id: str) -> list[dict]: + return list(self.profiles_by_tool.get(tool_id, [])) + + def export_project(self, tool_id: str) -> dict: + return self.export_blobs[tool_id] + + def list_adapters( + self, + *, + name: str | None = None, + adapter_type: str | None = None, + ) -> list[dict]: + if name is None: + return list(self.adapters_by_name.values()) + ad = self.adapters_by_name.get(name) + return [ad] if ad else [] + + def list_registries(self, *, custom_tool: str | None = None) -> list[dict]: + if custom_tool is None: + return list(self.registries_by_tool.values()) + reg = self.registries_by_tool.get(custom_tool) + return [reg] if reg else [] + + # --- writes --- + def import_project( + self, export_data: dict, adapter_ids: dict | None = None + ) -> dict: + self.import_calls.append((export_data, adapter_ids)) + tool_id = self._mint("tool") + tool_name = export_data["tool_metadata"]["tool_name"] + self.tools[tool_id] = {"tool_name": tool_name} + return { + "tool_id": tool_id, + "message": f"Project imported successfully as '{tool_name}'", + "needs_adapter_config": adapter_ids is None, + } + + def sync_prompts( + self, tool_id: str, export_data: dict, *, create_copy: bool = False + ) -> dict: + self.sync_calls.append((tool_id, export_data, create_copy)) + return { + "prompts_created": len(export_data.get("prompts", [])), + "prompts_deleted": 0, + "tool_settings_updated": True, + } + + def export_custom_tool(self, tool_id: str, *, force: bool = True) -> None: + self.export_tool_calls.append(tool_id) + self.registries_by_tool.setdefault( + tool_id, + {"prompt_registry_id": self._mint("registry"), "custom_tool": tool_id}, + ) + + +def _ctx(source, target, *, remap=None, **opt_overrides) -> CloneContext: + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=remap or RemapTable(), + ) + + +def _seed_target_adapters(target: FakeClient) -> None: + """ProfileManagerSerializer surfaces adapter NAMES — target must + expose name → id lookups for the phase to resolve them. + """ + for name, adapter_id in TGT_ADAPTER_IDS.items(): + target.adapters_by_name[name] = {"id": adapter_id, "adapter_name": name} + + +def _seed_source_adapters(source: FakeClient) -> None: + """Source-visible adapter set; phase uses it for frictionless detection.""" + for name in ADAPTER_NAMES.values(): + source.adapters_by_name[name] = {"id": f"src-{name}", "adapter_name": name} + + +def _src_default_profile(*, nested: bool = False) -> dict: + """Mirror the live ProfileManager serializer: adapter FKs render as + flat NAME strings. ``nested=True`` covers the alternate dict shape + in case backend behavior changes. + """ + if nested: + return { + "profile_id": "src-profile-1", + "profile_name": "Default", + "is_default": True, + "llm": {"adapter_name": ADAPTER_NAMES["llm"]}, + "embedding_model": {"adapter_name": ADAPTER_NAMES["embedding_model"]}, + "vector_store": {"adapter_name": ADAPTER_NAMES["vector_store"]}, + "x2text": {"adapter_name": ADAPTER_NAMES["x2text"]}, + } + return { + "profile_id": "src-profile-1", + "profile_name": "Default", + "is_default": True, + "llm": ADAPTER_NAMES["llm"], + "embedding_model": ADAPTER_NAMES["embedding_model"], + "vector_store": ADAPTER_NAMES["vector_store"], + "x2text": ADAPTER_NAMES["x2text"], + } + + +def _src_export_blob(tool_name: str) -> dict: + return { + "tool_metadata": { + "tool_name": tool_name, + "description": "x", + "author": "a", + "icon": None, + }, + "tool_settings": {"preamble": "p", "postamble": "q"}, + "default_profile_settings": { + "chunk_size": 1024, + "chunk_overlap": 128, + "retrieval_strategy": "simple", + "similarity_top_k": 3, + "section": "default", + "profile_name": "Default", + }, + "prompts": [ + { + "prompt_key": "field_a", + "prompt": "What is field_a?", + "sequence_number": 1, + } + ], + "export_metadata": {"exported_at": "2026-05-24T00:00:00Z"}, + } + + +def _preload_source_tool( + client: FakeClient, tool_id: str, tool_name: str, *, nested_profile: bool = False +) -> None: + client.tools[tool_id] = {"tool_name": tool_name} + client.profiles_by_tool[tool_id] = [_src_default_profile(nested=nested_profile)] + client.export_blobs[tool_id] = _src_export_blob(tool_name) + client.registries_by_tool[tool_id] = { + "prompt_registry_id": SRC_REG, + "custom_tool": tool_id, + } + + +def test_fresh_imports_with_name_resolved_adapter_ids_and_records_registry(): + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "Invoice Extractor") + _seed_source_adapters(src) + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt) + + result = CustomToolPhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert result.failed == 0 + # Exactly one import_project call with the right export blob + name-resolved adapter ids. + assert len(tgt.import_calls) == 1 + blob, adapter_ids = tgt.import_calls[0] + assert blob["tool_metadata"]["tool_name"] == "Invoice Extractor" + assert adapter_ids == { + "llm_adapter_id": TGT_ADAPTER_IDS["gpt4"], + "vector_db_adapter_id": TGT_ADAPTER_IDS["pgvector"], + "embedding_adapter_id": TGT_ADAPTER_IDS["ada-embed"], + "x2text_adapter_id": TGT_ADAPTER_IDS["llmw"], + } + # No sync_prompts on fresh path. + assert tgt.sync_calls == [] + # Registry republish fired exactly once. + assert len(tgt.export_tool_calls) == 1 + tgt_tool_id = tgt.export_tool_calls[0] + + # Remap records populated for downstream phases. + assert ctx.remap.resolve("custom_tool", "src-tool-x") == tgt_tool_id + tgt_reg_id = tgt.registries_by_tool[tgt_tool_id]["prompt_registry_id"] + assert ctx.remap.resolve("prompt_studio_registry", SRC_REG) == tgt_reg_id + + +def test_nested_adapter_dict_also_resolves(): + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "T", nested_profile=True) + _seed_source_adapters(src) + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt) + + CustomToolPhase(ctx).run(CloneReport()) + + _, adapter_ids = tgt.import_calls[0] + assert adapter_ids["llm_adapter_id"] == TGT_ADAPTER_IDS["gpt4"] + + +def test_adopt_path_calls_sync_prompts_and_skips_import(): + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "Invoice Extractor") + # Target already has the tool with the same name. + tgt.tools["tgt-existing"] = {"tool_name": "Invoice Extractor"} + + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt) + + result = CustomToolPhase(ctx).run(CloneReport()) + + assert result.adopted == 1 + assert result.created == 0 + # sync_prompts ran against the pre-existing target tool, not a new one. + assert len(tgt.sync_calls) == 1 + tool_id, blob, create_copy = tgt.sync_calls[0] + assert tool_id == "tgt-existing" + assert blob["tool_metadata"]["tool_name"] == "Invoice Extractor" + assert create_copy is False + # Import path never fired on adopt. + assert tgt.import_calls == [] + # Registry still republished against the adopted tool. + assert tgt.export_tool_calls == ["tgt-existing"] + assert ctx.remap.resolve("custom_tool", "src-tool-x") == "tgt-existing" + + +def test_abort_on_name_conflict_raises(): + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "Conflict") + tgt.tools["tgt-existing"] = {"tool_name": "Conflict"} + + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt, on_name_conflict="abort") + + with pytest.raises(NameConflictError): + CustomToolPhase(ctx).run(CloneReport()) + + assert tgt.sync_calls == [] + assert tgt.import_calls == [] + + +def test_dry_run_makes_no_writes(): + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "T") + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt, dry_run=True) + + result = CustomToolPhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert tgt.import_calls == [] + assert tgt.sync_calls == [] + assert tgt.export_tool_calls == [] + + +def test_dry_run_on_adopt_path_does_not_republish_registry(): + # Adopt path used to return tgt_tool_id even on dry-run, falling + # through to export_custom_tool (a real POST to the target). + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "Invoice Extractor") + tgt.tools["tgt-existing"] = {"tool_name": "Invoice Extractor"} + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt, dry_run=True) + + result = CustomToolPhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert tgt.sync_calls == [] + assert tgt.import_calls == [] + # Critical regression: registry republish must NOT fire on dry-run. + assert tgt.export_tool_calls == [] + # Remap still recorded so downstream dry-run output stays coherent. + assert ctx.remap.resolve("custom_tool", "src-tool-x") == "tgt-existing" + + +def test_frictionless_adapter_dependence_skips_tool_and_records_for_cascade(): + """Source profile references an adapter NAME the source's + service-account view can't list (frictionless). Tool is skipped + cleanly and source registry id is recorded for WorkflowPhase to + cascade-skip dependent workflows. + """ + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "Frictionless-Bound Tool") + # Source-visible adapters cover 3 of 4 — llm "gpt4" hidden. + for name in ("ada-embed", "pgvector", "llmw"): + src.adapters_by_name[name] = {"id": f"src-{name}", "adapter_name": name} + _seed_target_adapters(tgt) + ctx = _ctx(src, tgt) + + result = CustomToolPhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert result.failed == 0 + assert tgt.import_calls == [] + assert tgt.export_tool_calls == [] + assert ctx.remap.resolve("custom_tool", "src-tool-x") is None + assert SRC_REG in ctx.skipped_custom_tool_registry_ids + + +def test_missing_target_adapter_fails_tool_cleanly(): + src = FakeClient() + tgt = FakeClient() + _preload_source_tool(src, "src-tool-x", "T") + _seed_source_adapters(src) + # Only seed 3 of 4 adapters → x2text lookup misses on target. + for name in ("gpt4", "ada-embed", "pgvector"): + tgt.adapters_by_name[name] = {"id": TGT_ADAPTER_IDS[name], "adapter_name": name} + ctx = _ctx(src, tgt) + + result = CustomToolPhase(ctx).run(CloneReport()) + + assert result.failed == 1 + assert tgt.import_calls == [] + # Registry republish should NOT fire when the tool fails. + assert tgt.export_tool_calls == [] + # No custom_tool remap recorded. + assert ctx.remap.resolve("custom_tool", "src-tool-x") is None diff --git a/tests/clone/test_files_phase.py b/tests/clone/test_files_phase.py new file mode 100644 index 0000000..50f739e --- /dev/null +++ b/tests/clone/test_files_phase.py @@ -0,0 +1,525 @@ +"""Tests for ``FilesPhase``. + +Coverage: +- happy path: PDF + text/csv files uploaded with base64 + utf-8 decoding. +- target-side idempotency: filename already present → skip, no upload. +- oversize file → ``oversize_files`` entry, sibling files continue. +- unsupported mime (Excel placeholder) → ``unsupported_files`` entry. +- skip strategy → no uploads, source filenames listed in ``skipped_files``. +- dry-run → no uploads even for missing files. +- transient 503 → retried, eventual success. +- no custom_tool remap → no-op. +- listing failure on source aborts only that tool, others continue. +""" + +from __future__ import annotations + +import base64 +from typing import Any + +import pytest + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + OrgEndpoint, + RemapTable, +) +from unstract.clone.exceptions import PlatformAPIError +from unstract.clone.phases.files import FilesPhase +from unstract.clone.report import CloneReport + +SRC_ENDPOINT = OrgEndpoint( + base_url="http://src", organization_id="src-org", platform_key="src-key" +) +TGT_ENDPOINT = OrgEndpoint( + base_url="http://tgt", organization_id="tgt-org", platform_key="tgt-key" +) + + +class FakeClient: + def __init__( + self, + *, + endpoint: OrgEndpoint, + documents: dict[str, list[dict]] | None = None, + file_payloads: dict[tuple[str, str], dict] | None = None, + tools: list[dict] | None = None, + ): + self.endpoint = endpoint + # tool_id -> list of {document_name, document_id, tool} + self._documents: dict[str, list[dict]] = { + k: list(v) for k, v in (documents or {}).items() + } + # (tool_id, file_name) -> {"data": ..., "mime_type": ...} + self._file_payloads: dict[tuple[str, str], dict] = dict(file_payloads or {}) + self._tools = list(tools or []) + self.uploaded: list[dict[str, Any]] = [] + self.list_calls: list[str] = [] + self.download_calls: list[tuple[str, str]] = [] + # Configurable fault injection. + self.download_errors: dict[tuple[str, str], list[Exception]] = {} + self.upload_errors: dict[tuple[str, str], list[Exception]] = {} + self.list_errors: dict[str, Exception] = {} + self._next_id = 1 + + def list_prompt_documents(self, tool_id: str) -> list[dict]: + self.list_calls.append(tool_id) + if tool_id in self.list_errors: + raise self.list_errors[tool_id] + return [dict(d) for d in self._documents.get(tool_id, [])] + + def download_prompt_file(self, tool_id: str, document_id: str) -> dict: + # Tests key payloads + error queues by (tool_id, file_name) for + # readability; resolve the filename from the documents list. + file_name = next( + ( + d["document_name"] + for d in self._documents.get(tool_id, []) + if d.get("document_id") == document_id + ), + document_id, + ) + self.download_calls.append((tool_id, file_name)) + queue = self.download_errors.get((tool_id, file_name)) + if queue: + raise queue.pop(0) + return dict(self._file_payloads[(tool_id, file_name)]) + + def upload_prompt_file( + self, tool_id: str, file_name: str, data: bytes, mime_type: str + ) -> dict: + queue = self.upload_errors.get((tool_id, file_name)) + if queue: + raise queue.pop(0) + doc_id = f"doc-{self._next_id:04d}" + self._next_id += 1 + self.uploaded.append( + { + "tool_id": tool_id, + "file_name": file_name, + "data": data, + "mime_type": mime_type, + } + ) + self._documents.setdefault(tool_id, []).append( + {"document_id": doc_id, "document_name": file_name, "tool": tool_id} + ) + return {"document_id": doc_id} + + def list_custom_tools(self) -> list[dict]: + return list(self._tools) + + def get_custom_tool(self, tool_id: str) -> dict: + return dict(next((t for t in self._tools if t.get("tool_id") == tool_id), {})) + + def update_custom_tool(self, tool_id: str, body: dict) -> dict: + for t in self._tools: + if t.get("tool_id") == tool_id: + t.update(body) + return dict(t) + return {} + + +def _ctx( + src: FakeClient, tgt: FakeClient, *, remap: RemapTable | None = None, **opts +) -> CloneContext: + remap = remap or RemapTable() + return CloneContext( + source=src, + target=tgt, + options=CloneOptions(**opts), + remap=remap, + ) + + +def _doc(name: str) -> dict: + return {"document_id": f"src-{name}", "document_name": name, "tool": "ignored"} + + +def _pdf_payload(raw: bytes) -> dict: + return {"data": base64.b64encode(raw).decode(), "mime_type": "application/pdf"} + + +def _text_payload(text: str, mime: str = "text/plain") -> dict: + return {"data": text, "mime_type": mime} + + +def test_happy_path_uploads_pdf_and_text(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("invoice.pdf"), _doc("notes.txt")]}, + file_payloads={ + ("src-1", "invoice.pdf"): _pdf_payload(b"%PDF-FAKE"), + ("src-1", "notes.txt"): _text_payload("hello world"), + }, + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 2 + assert result.failed == 0 + assert {u["file_name"] for u in tgt.uploaded} == {"invoice.pdf", "notes.txt"} + pdf_upload = next(u for u in tgt.uploaded if u["file_name"] == "invoice.pdf") + assert pdf_upload["data"] == b"%PDF-FAKE" + assert pdf_upload["mime_type"] == "application/pdf" + txt_upload = next(u for u in tgt.uploaded if u["file_name"] == "notes.txt") + assert txt_upload["data"] == b"hello world" + assert len(report.uploaded_files) == 2 + assert all(u["tool_name"] == "demo" for u in report.uploaded_files) + + +def test_target_filename_present_is_skipped_no_download(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("invoice.pdf")]}, + file_payloads={("src-1", "invoice.pdf"): _pdf_payload(b"BYTES")}, + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, + documents={"tgt-1": [_doc("invoice.pdf")]}, + tools=[{"tool_id": "tgt-1", "tool_name": "demo"}], + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.skipped == 1 + assert result.created == 0 + assert tgt.uploaded == [] + assert src.download_calls == [] # pre-check guards the download + + +def test_oversize_file_is_recorded_and_siblings_continue(): + big = b"X" * 50 + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("big.pdf"), _doc("small.txt")]}, + file_payloads={ + ("src-1", "big.pdf"): _pdf_payload(big), + ("src-1", "small.txt"): _text_payload("ok"), + }, + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap, max_file_size=10) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 1 + # Oversize must bump skipped so the operator sees it surfaced in the + # phase counters, not only in the report's list. + assert result.skipped == 1 + assert result.failed == 0 + assert {u["file_name"] for u in tgt.uploaded} == {"small.txt"} + assert len(report.oversize_files) == 1 + over = report.oversize_files[0] + assert over["file_name"] == "big.pdf" + assert over["size_bytes"] == 50 + assert over["cap_bytes"] == 10 + + +def test_unsupported_mime_is_recorded_not_uploaded(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("sheet.xlsx")]}, + file_payloads={ + ("src-1", "sheet.xlsx"): { + "data": "Preview not available for Excel files. ...", + "mime_type": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + } + }, + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 0 + # Unsupported mimes must bump skipped so the run doesn't report green + # while leaving files unmoved. + assert result.skipped == 1 + assert result.failed == 0 + assert tgt.uploaded == [] + assert len(report.unsupported_files) == 1 + entry = report.unsupported_files[0] + assert entry["file_name"] == "sheet.xlsx" + assert entry["mime_type"].startswith("application/vnd.openxmlformats") + + +def test_malformed_source_dm_row_bumps_skipped_with_error(): + # Renamed-field or partial-serializer response: row lacks + # document_name/document_id. Must surface, not silently disappear. + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [{"tool": "src-1"}, _doc("ok.pdf")]}, + file_payloads={("src-1", "ok.pdf"): _pdf_payload(b"BYTES")}, + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 1 # the well-formed sibling still uploads. + assert result.skipped == 1 # the malformed row. + assert any("malformed source DM row" in e for e in result.errors) + + +def test_skip_strategy_emits_skipped_files_no_traffic(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf"), _doc("b.pdf")]}, + ) + tgt = FakeClient(endpoint=TGT_ENDPOINT) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap, file_strategy="skip") + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.skipped == 2 + assert tgt.uploaded == [] + assert src.download_calls == [] + names = {row["file_name"] for row in report.skipped_files} + assert names == {"a.pdf", "b.pdf"} + assert all(row["source_org_slug"] == "src-org" for row in report.skipped_files) + assert all(row["source_tool_id"] == "src-1" for row in report.skipped_files) + + +def test_dry_run_makes_no_writes_even_for_missing_files(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf")]}, + file_payloads={("src-1", "a.pdf"): _pdf_payload(b"X")}, + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap, dry_run=True) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.skipped == 1 + assert result.created == 0 + assert tgt.uploaded == [] + assert src.download_calls == [] + + +def test_transient_503_is_retried_then_succeeds(monkeypatch): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf")]}, + file_payloads={("src-1", "a.pdf"): _pdf_payload(b"OK")}, + ) + src.download_errors[("src-1", "a.pdf")] = [ + PlatformAPIError("flaky", status_code=503, body="") + ] + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + # Strip the backoff sleep so the test stays fast. + monkeypatch.setattr("unstract.clone.phases.files.time.sleep", lambda *_: None) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 1 + assert tgt.uploaded[0]["data"] == b"OK" + + +def test_no_custom_tool_remap_is_noop(): + src = FakeClient(endpoint=SRC_ENDPOINT) + tgt = FakeClient(endpoint=TGT_ENDPOINT) + ctx = _ctx(src, tgt) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 0 + assert result.skipped == 0 + assert src.list_calls == [] + + +def test_source_list_failure_isolates_to_that_tool(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-2": [_doc("ok.pdf")]}, + file_payloads={("src-2", "ok.pdf"): _pdf_payload(b"OK")}, + ) + src.list_errors["src-1"] = RuntimeError("source down for this tool") + tgt = FakeClient( + endpoint=TGT_ENDPOINT, + tools=[ + {"tool_id": "tgt-1", "tool_name": "broken"}, + {"tool_id": "tgt-2", "tool_name": "healthy"}, + ], + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + remap.record("custom_tool", "src-2", "tgt-2") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.failed == 1 + assert result.created == 1 + assert {u["file_name"] for u in tgt.uploaded} == {"ok.pdf"} + + +def test_upload_failure_records_failed_files_entry(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf")]}, + file_payloads={("src-1", "a.pdf"): _pdf_payload(b"X")}, + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + tgt.upload_errors[("tgt-1", "a.pdf")] = [ + PlatformAPIError("bad", status_code=400, body="bad") + ] + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.failed == 1 + assert result.created == 0 + assert len(report.failed_files) == 1 + entry = report.failed_files[0] + assert entry["file_name"] == "a.pdf" + assert "upload" in entry["error"] + + +@pytest.mark.parametrize( + "mime,raw", + [ + ("text/csv", "name,age\nalice,30"), + ("text/plain", "plain old text"), + ], +) +def test_text_mimes_round_trip_as_utf8(mime, raw): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("data")]}, + file_payloads={("src-1", "data"): _text_payload(raw, mime=mime)}, + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + report = CloneReport() + + result = FilesPhase(ctx).run(report) + + assert result.created == 1 + upload = tgt.uploaded[0] + assert upload["data"] == raw.encode("utf-8") + assert upload["mime_type"] == mime + + +def test_default_doc_mirrors_source_selection_by_filename(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf"), _doc("b.pdf")]}, + file_payloads={ + ("src-1", "a.pdf"): _pdf_payload(b"A"), + ("src-1", "b.pdf"): _pdf_payload(b"B"), + }, + # Source's selected doc is b.pdf (document_id="src-b.pdf"). + tools=[{"tool_id": "src-1", "tool_name": "demo", "output": "src-b.pdf"}], + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + FilesPhase(ctx).run(CloneReport()) + + # Target's CustomTool.output now points at b.pdf's new target doc id. + tgt_tool = next(t for t in tgt._tools if t["tool_id"] == "tgt-1") + output_id = tgt_tool["output"] + b_upload = next(d for d in tgt._documents["tgt-1"] if d["document_name"] == "b.pdf") + assert output_id == b_upload["document_id"] + + +def test_default_doc_falls_back_to_first_when_source_has_none(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf")]}, + file_payloads={("src-1", "a.pdf"): _pdf_payload(b"A")}, + # Source has no output set. + tools=[{"tool_id": "src-1", "tool_name": "demo"}], + ) + tgt = FakeClient( + endpoint=TGT_ENDPOINT, tools=[{"tool_id": "tgt-1", "tool_name": "demo"}] + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + FilesPhase(ctx).run(CloneReport()) + + tgt_tool = next(t for t in tgt._tools if t["tool_id"] == "tgt-1") + a_upload = next(d for d in tgt._documents["tgt-1"] if d["document_name"] == "a.pdf") + assert tgt_tool["output"] == a_upload["document_id"] + + +def test_default_doc_preserves_existing_target_choice(): + src = FakeClient( + endpoint=SRC_ENDPOINT, + documents={"src-1": [_doc("a.pdf")]}, + file_payloads={("src-1", "a.pdf"): _pdf_payload(b"A")}, + tools=[{"tool_id": "src-1", "tool_name": "demo", "output": "src-a.pdf"}], + ) + # Operator already picked a doc on target — re-run must not clobber. + tgt = FakeClient( + endpoint=TGT_ENDPOINT, + tools=[{"tool_id": "tgt-1", "tool_name": "demo", "output": "operator-pick"}], + ) + remap = RemapTable() + remap.record("custom_tool", "src-1", "tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + FilesPhase(ctx).run(CloneReport()) + + tgt_tool = next(t for t in tgt._tools if t["tool_id"] == "tgt-1") + assert tgt_tool["output"] == "operator-pick" diff --git a/tests/clone/test_orchestrator.py b/tests/clone/test_orchestrator.py new file mode 100644 index 0000000..b7a2626 --- /dev/null +++ b/tests/clone/test_orchestrator.py @@ -0,0 +1,159 @@ +"""End-to-end tests for the ``clone()`` orchestrator. + +Coverage: +- Phase ordering matches ``PHASES`` declaration. +- ``include`` / ``exclude`` route phases through ``skipped_phases``. +- ``CloneError`` raised by a phase aborts the run; subsequent phases skipped. +- Both ``PlatformClient`` instances are closed even when a phase aborts. +- ``RemapTable`` snapshot lands on the report. +""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from unstract.clone import orchestrator +from unstract.clone.context import CloneOptions, OrgEndpoint +from unstract.clone.exceptions import CloneError +from unstract.clone.phases.base import Phase +from unstract.clone.report import CloneReport, PhaseResult + + +class _RecordingPhase(Phase): + """Per-test phase factory; records invocation order on a shared list.""" + + invocations: list[str] = [] + name = "" + + def run(self, report: CloneReport) -> PhaseResult: + _RecordingPhase.invocations.append(self.name) + result = report.get_phase(self.name) + result.created += 1 + # Drop a remap entry so we can prove the snapshot lands on the report. + self.ctx.remap.record(self.name, f"src-{self.name}", f"tgt-{self.name}") + return result + + +def _make_phase(phase_name: str) -> type[Phase]: + return type( + f"FakePhase_{phase_name}", + (_RecordingPhase,), + {"name": phase_name}, + ) + + +@pytest.fixture(autouse=True) +def _reset_invocations(): + _RecordingPhase.invocations = [] + yield + _RecordingPhase.invocations = [] + + +@pytest.fixture +def fake_phases(): + """Replace PHASES with a small deterministic set for the test run.""" + fake = [ + ("adapter", _make_phase("adapter")), + ("connector", _make_phase("connector")), + ("workflow", _make_phase("workflow")), + ] + with patch.object(orchestrator, "PHASES", fake): + yield fake + + +def _src() -> OrgEndpoint: + return OrgEndpoint( + base_url="https://src.example.com", + organization_id="src_org", + platform_key="src-key", + ) + + +def _tgt() -> OrgEndpoint: + return OrgEndpoint( + base_url="https://tgt.example.com", + organization_id="tgt_org", + platform_key="tgt-key", + ) + + +def test_phases_run_in_declared_order(fake_phases): + with patch.object(orchestrator.PlatformClient, "close") as mock_close: + report = orchestrator.clone(_src(), _tgt()) + assert _RecordingPhase.invocations == ["adapter", "connector", "workflow"] + assert [p.name for p in report.phases] == ["adapter", "connector", "workflow"] + # Both clients must close (source + target) regardless of outcome. + assert mock_close.call_count == 2 + + +def test_include_filter_only_runs_listed_phases(fake_phases): + opts = CloneOptions(include=("connector",)) + with patch.object(orchestrator.PlatformClient, "close"): + report = orchestrator.clone(_src(), _tgt(), opts) + assert _RecordingPhase.invocations == ["connector"] + assert set(report.skipped_phases) == {"adapter", "workflow"} + + +def test_exclude_filter_skips_listed_phases(fake_phases): + opts = CloneOptions(exclude=("workflow",)) + with patch.object(orchestrator.PlatformClient, "close"): + report = orchestrator.clone(_src(), _tgt(), opts) + assert _RecordingPhase.invocations == ["adapter", "connector"] + assert report.skipped_phases == ["workflow"] + + +def test_clone_error_aborts_and_skips_subsequent_phases(): + class AbortingPhase(Phase): + name = "connector" + + def run(self, report: CloneReport) -> PhaseResult: + raise CloneError("name collision in 'connector'") + + fake = [ + ("adapter", _make_phase("adapter")), + ("connector", AbortingPhase), + ("workflow", _make_phase("workflow")), + ] + with ( + patch.object(orchestrator, "PHASES", fake), + patch.object(orchestrator.PlatformClient, "close") as mock_close, + ): + report = orchestrator.clone(_src(), _tgt()) + + assert _RecordingPhase.invocations == ["adapter"] + assert report.aborted is True + assert "name collision" in report.abort_reason + # Clients still close on abort. + assert mock_close.call_count == 2 + + +def test_unrelated_exception_propagates_but_still_closes_clients(): + class CrashingPhase(Phase): + name = "connector" + + def run(self, report: CloneReport) -> PhaseResult: + raise RuntimeError("boom") + + fake = [ + ("adapter", _make_phase("adapter")), + ("connector", CrashingPhase), + ] + with ( + patch.object(orchestrator, "PHASES", fake), + patch.object(orchestrator.PlatformClient, "close") as mock_close, + ): + with pytest.raises(RuntimeError, match="boom"): + orchestrator.clone(_src(), _tgt()) + assert mock_close.call_count == 2 + + +def test_remap_snapshot_populated_on_report(fake_phases): + with patch.object(orchestrator.PlatformClient, "close"): + report = orchestrator.clone(_src(), _tgt()) + assert report.remap_snapshot == { + "adapter": {"src-adapter": "tgt-adapter"}, + "connector": {"src-connector": "tgt-connector"}, + "workflow": {"src-workflow": "tgt-workflow"}, + } diff --git a/tests/clone/test_phase_concurrency.py b/tests/clone/test_phase_concurrency.py new file mode 100644 index 0000000..0179f1b --- /dev/null +++ b/tests/clone/test_phase_concurrency.py @@ -0,0 +1,291 @@ +"""Thread-safety checks for ``Phase.parallel_map``. + +Coverage: +- Many-item fan-out produces exact counts + remap entries with no loss. +- Sequential path (``concurrency=1``) skips the thread pool entirely + while preserving identical behaviour. +- ``CloneError`` raised inside a worker propagates out of ``parallel_map`` + so the orchestrator's abort handling engages. +- A non-``CloneError`` exception inside a worker still propagates. + +We use a fake client that holds a lock around its own mutable state and +injects a small sleep per HTTP call to force real interleaving between +workers, then assert the phase's lock-guarded code keeps counters and +the remap table consistent. +""" + +from __future__ import annotations + +import threading +import time + +import pytest + +from unstract.clone.context import CloneContext, CloneOptions, RemapTable +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.adapter import AdapterPhase +from unstract.clone.phases.tag import TagPhase +from unstract.clone.report import CloneReport + + +class _ThreadSafeAdapterClient: + """Adapter FakeClient with a lock around mutable state + per-call sleep + so workers actually interleave under ThreadPoolExecutor. + """ + + POST_SCHEMA = frozenset( + { + "adapter_id", + "adapter_name", + "adapter_type", + "adapter_metadata", + "description", + } + ) + + def __init__(self, adapters=None, sleep_seconds: float = 0.005): + self._adapters: list[dict] = list(adapters or []) + self.posts: list[dict] = [] + self._next_id = 1 + self._lock = threading.Lock() + self._sleep = sleep_seconds + + def get_post_schema(self, entity_path): + return self.POST_SCHEMA + + def list_adapters(self, *, name=None, adapter_type=None): + time.sleep(self._sleep) + with self._lock: + snap = list(self._adapters) + result = snap + if name is not None: + result = [a for a in result if a["adapter_name"] == name] + if adapter_type is not None: + result = [a for a in result if a["adapter_type"] == adapter_type] + return [{k: v for k, v in a.items() if k != "adapter_metadata"} for a in result] + + def get_adapter(self, adapter_pk): + time.sleep(self._sleep) + with self._lock: + for a in self._adapters: + if a["id"] == adapter_pk: + return dict(a) + raise KeyError(adapter_pk) + + def create_adapter(self, payload): + time.sleep(self._sleep) + with self._lock: + new = dict(payload) + new["id"] = f"tgt-{self._next_id:08d}-0000-0000-0000-000000000000" + self._next_id += 1 + self._adapters.append(new) + self.posts.append(new) + return new + + +def _src_adapter(id_, name, atype="LLM"): + return { + "id": id_, + "adapter_id": "openai-llm-v2", + "adapter_name": name, + "adapter_type": atype, + "adapter_metadata": {"api_key": "sk-secret", "model": "gpt-4"}, + "description": f"{name} desc", + } + + +def _ctx(source, target, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=RemapTable(), + ) + + +def test_parallel_map_preserves_counts_with_many_items(): + items = 50 + src = _ThreadSafeAdapterClient( + [_src_adapter(f"src-{i:03d}", f"adapter-{i:03d}") for i in range(items)] + ) + tgt = _ThreadSafeAdapterClient() + ctx = _ctx(src, tgt, concurrency=8) + report = CloneReport() + + result = AdapterPhase(ctx).run(report) + + assert result.created == items + assert result.adopted == 0 + assert result.skipped == 0 + assert result.failed == 0 + assert len(tgt.posts) == items + remap = ctx.remap.snapshot().get("adapter", {}) + assert len(remap) == items + # Every source id should be mapped to a fresh target id. + assert set(remap.keys()) == {f"src-{i:03d}" for i in range(items)} + assert len(set(remap.values())) == items + + +def test_concurrency_one_runs_sequentially_with_no_executor(monkeypatch): + """With concurrency=1 we should never hit ThreadPoolExecutor.""" + sentinel = {"executor_used": False} + + import unstract.clone.phases.base as base_mod + + original = base_mod.ThreadPoolExecutor + + class _Forbidden: + def __init__(self, *a, **kw): + sentinel["executor_used"] = True + raise AssertionError("ThreadPoolExecutor must not be used at concurrency=1") + + monkeypatch.setattr(base_mod, "ThreadPoolExecutor", _Forbidden) + src = _ThreadSafeAdapterClient( + [_src_adapter(f"src-{i}", f"a-{i}") for i in range(5)] + ) + tgt = _ThreadSafeAdapterClient() + ctx = _ctx(src, tgt, concurrency=1) + report = CloneReport() + + result = AdapterPhase(ctx).run(report) + + assert result.created == 5 + assert sentinel["executor_used"] is False + # restore for any other tests in same module (monkeypatch undoes on teardown). + base_mod.ThreadPoolExecutor = original # noqa: F841 + + +class _AbortingAdapterClient(_ThreadSafeAdapterClient): + """As parent, but ``list_adapters`` claims the named adapter already + exists on target — used to trigger NameConflictError when the phase + is run with ``on_name_conflict='abort'``.""" + + def list_adapters(self, *, name=None, adapter_type=None): + time.sleep(self._sleep) + return [ + { + "id": "tgt-existing-0001", + "adapter_name": name or "x", + "adapter_type": adapter_type or "LLM", + } + ] + + +def test_clone_error_in_worker_propagates_under_concurrency(): + src = _ThreadSafeAdapterClient( + [_src_adapter(f"src-{i}", f"clash-{i}") for i in range(10)] + ) + tgt = _AbortingAdapterClient() + ctx = _ctx(src, tgt, concurrency=4, on_name_conflict="abort") + report = CloneReport() + + with pytest.raises(NameConflictError): + AdapterPhase(ctx).run(report) + + +class _UnexpectedAdapterClient(_ThreadSafeAdapterClient): + """One of the GETs blows up with a non-Clone RuntimeError.""" + + def __init__(self, *a, fail_on_name: str, **kw): + super().__init__(*a, **kw) + self._fail_on_name = fail_on_name + + def get_adapter(self, adapter_pk): + snap = super().get_adapter(adapter_pk) + if snap["adapter_name"] == self._fail_on_name: + raise RuntimeError("transport boom") + return snap + + +def test_non_clone_exception_recorded_as_failed_not_raised(): + """Workers convert non-Clone errors into ``result.failed`` counts; + they don't escape the phase. (CloneError is the abort signal — + arbitrary exceptions are per-item failures.)""" + src = _UnexpectedAdapterClient( + adapters=[_src_adapter(f"src-{i}", f"adapter-{i}") for i in range(10)], + fail_on_name="adapter-3", + ) + tgt = _ThreadSafeAdapterClient() + ctx = _ctx(src, tgt, concurrency=4) + report = CloneReport() + + result = AdapterPhase(ctx).run(report) + + assert result.failed == 1 + # The other 9 still created successfully. + assert result.created == 9 + assert len(tgt.posts) == 9 + + +class _TagClient: + """Minimal tag fake with thread-safe state + per-call sleep.""" + + POST_SCHEMA = frozenset({"name", "description"}) + + def __init__(self, tags=None, sleep_seconds: float = 0.005): + self._tags: list[dict] = list(tags or []) + self.posts: list[dict] = [] + self._next_id = 1 + self._lock = threading.Lock() + self._sleep = sleep_seconds + + def get_post_schema(self, entity_path): + return self.POST_SCHEMA + + def list_tags(self, *, name=None): + time.sleep(self._sleep) + with self._lock: + snap = list(self._tags) + if name is not None: + snap = [t for t in snap if t["name"] == name] + return snap + + def create_tag(self, payload): + time.sleep(self._sleep) + with self._lock: + new = dict(payload) + new["id"] = f"tag-tgt-{self._next_id:04d}" + self._next_id += 1 + self._tags.append(new) + self.posts.append(new) + return new + + +def test_tag_phase_parallel_remap_table_consistent(): + """Distinct phase exercising the same parallel_map path — ensures the + helper isn't accidentally adapter-specific. + """ + src = _TagClient( + [{"id": f"tag-src-{i}", "name": f"tag-{i:03d}"} for i in range(30)] + ) + tgt = _TagClient() + ctx = _ctx(src, tgt, concurrency=8) + report = CloneReport() + + result = TagPhase(ctx).run(report) + + assert result.created == 30 + assert result.failed == 0 + remap = ctx.remap.snapshot().get("tag", {}) + assert len(remap) == 30 + # remap value uniqueness — no two source tags mapped to the same target id. + assert len(set(remap.values())) == 30 + + +def test_parallel_map_empty_input_no_executor(monkeypatch): + """No items → no thread pool, no work.""" + import unstract.clone.phases.base as base_mod + + class _Forbidden: + def __init__(self, *a, **kw): + raise AssertionError("Should not create pool for empty input") + + monkeypatch.setattr(base_mod, "ThreadPoolExecutor", _Forbidden) + src = _ThreadSafeAdapterClient([]) + tgt = _ThreadSafeAdapterClient() + ctx = _ctx(src, tgt, concurrency=8) + report = CloneReport() + + result = AdapterPhase(ctx).run(report) + assert result.created == 0 + assert result.adopted == 0 diff --git a/tests/clone/test_pipeline_phase.py b/tests/clone/test_pipeline_phase.py new file mode 100644 index 0000000..e69c3cb --- /dev/null +++ b/tests/clone/test_pipeline_phase.py @@ -0,0 +1,265 @@ +"""Tests for ``PipelinePhase``. + +Coverage: +- happy path: source ETL/TASK pipelines created with workflow FK remapped. +- DEFAULT and APP types are skipped (out of clone scope). +- adopt path on name conflict. +- skipped when workflow remap missing. +- dry-run is a no-op. +- abort raises ``NameConflictError``. +- extra source keys produce a warning, never a failure. +""" + +from __future__ import annotations + +import logging + +import pytest + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + RemapTable, +) +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.pipeline import PipelinePhase +from unstract.clone.report import CloneReport + +PIPELINE_POST_SCHEMA = frozenset( + { + "pipeline_name", + "workflow", + "pipeline_type", + "cron_string", + "app_id", + "app_icon", + "app_url", + "access_control_bundle_id", + "shared_users", + "shared_to_org", + } +) + + +class FakeClient: + def __init__(self, pipelines: list[dict] | None = None): + self.pipelines: list[dict] = list(pipelines or []) + self.posts: list[dict] = [] + self.keys_by_pipeline: dict[str, list[dict]] = {} + self._next = 1 + + def get_post_schema(self, entity_path: str) -> frozenset[str]: + return PIPELINE_POST_SCHEMA + + def list_pipelines( + self, *, name: str | None = None, pipeline_type: str | None = None + ): + result = self.pipelines + if name is not None: + result = [p for p in result if p["pipeline_name"] == name] + if pipeline_type is not None: + result = [p for p in result if p.get("pipeline_type") == pipeline_type] + return list(result) + + def get_pipeline(self, pipeline_id: str) -> dict: + for p in self.pipelines: + if p["id"] == pipeline_id: + return dict(p) + raise KeyError(pipeline_id) + + def create_pipeline(self, payload: dict) -> dict: + new = dict(payload) + new["id"] = f"tgt-pipeline-{self._next:04d}" + self._next += 1 + self.pipelines.append(new) + self.posts.append(new) + return new + + def list_pipeline_keys(self, pipeline_id: str) -> list[dict]: + return list(self.keys_by_pipeline.get(pipeline_id, [])) + + +def _src_pipeline( + id_: str, + name: str, + workflow_id: str, + *, + pipeline_type: str = "ETL", + cron_string: str | None = None, +) -> dict: + return { + "id": id_, + "pipeline_name": name, + "workflow": workflow_id, + "workflow_id": workflow_id, + "workflow_name": "wf", + "pipeline_type": pipeline_type, + "active": True, + "scheduled": cron_string is not None, + "cron_string": cron_string, + "app_id": None, + "app_icon": None, + "app_url": None, + "access_control_bundle_id": None, + "shared_users": [], + "shared_to_org": False, + } + + +def _ctx(source, target, *, remap=None, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=remap or RemapTable(), + ) + + +def test_happy_path_creates_pipeline_with_remapped_workflow(): + src = FakeClient([_src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1")]) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + result = PipelinePhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert result.failed == 0 + posted = tgt.posts[0] + assert posted["pipeline_name"] == "Daily Invoices" + assert posted["workflow"] == "wf-tgt-1" + assert ctx.remap.resolve("pipeline", "src-pl-1") == posted["id"] + + +def test_create_uses_per_id_get_not_stripped_list_payload(): + # list_pipelines can omit fields the create serializer expects. Phase + # must re-fetch the full record via get_pipeline before POSTing. + full = _src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1") + full["cron_string"] = "0 5 * * *" # only present on detail serializer. + stripped = {k: v for k, v in full.items() if k not in ("cron_string",)} + + class StripListFakeClient(FakeClient): + def list_pipelines(self, *, name=None, pipeline_type=None): + base = ( + [stripped] + if ( + (name is None or stripped["pipeline_name"] == name) + and ( + pipeline_type is None + or stripped["pipeline_type"] == pipeline_type + ) + ) + else [] + ) + return list(base) + + def get_pipeline(self, pipeline_id): + assert pipeline_id == full["id"] + return dict(full) + + src = StripListFakeClient([full]) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + PipelinePhase(ctx).run(CloneReport()) + + posted = tgt.posts[0] + # cron_string only existed on the detail GET — proves we did NOT + # POST the stripped list-item payload. + assert posted["cron_string"] == "0 5 * * *" + + +def test_default_and_app_pipeline_types_are_skipped(): + src = FakeClient( + [ + _src_pipeline( + "src-1", "default-legacy", "wf-src-1", pipeline_type="DEFAULT" + ), + _src_pipeline("src-2", "streamlit-app", "wf-src-1", pipeline_type="APP"), + _src_pipeline("src-3", "real-etl", "wf-src-1", pipeline_type="ETL"), + ] + ) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + result = PipelinePhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert len(tgt.posts) == 1 + assert tgt.posts[0]["pipeline_name"] == "real-etl" + + +def test_adopts_existing_pipeline_by_name(): + src = FakeClient([_src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1")]) + tgt = FakeClient([{"id": "tgt-existing", "pipeline_name": "Daily Invoices"}]) + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + result = PipelinePhase(ctx).run(CloneReport()) + + assert result.adopted == 1 + assert result.created == 0 + assert tgt.posts == [] + assert ctx.remap.resolve("pipeline", "src-pl-1") == "tgt-existing" + + +def test_skipped_when_workflow_remap_missing(): + src = FakeClient([_src_pipeline("src-pl-1", "Orphan", "wf-src-1")]) + tgt = FakeClient() + ctx = _ctx(src, tgt) # No workflow remap. + + result = PipelinePhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert result.failed == 0 + assert tgt.posts == [] + + +def test_dry_run_makes_no_writes(): + src = FakeClient([_src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1")]) + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap, dry_run=True) + + result = PipelinePhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert tgt.posts == [] + + +def test_abort_on_name_conflict_raises(): + src = FakeClient([_src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1")]) + tgt = FakeClient([{"id": "tgt-existing", "pipeline_name": "Daily Invoices"}]) + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap, on_name_conflict="abort") + + with pytest.raises(NameConflictError): + PipelinePhase(ctx).run(CloneReport()) + + +def test_extra_source_keys_log_warning_not_failure(caplog): + src = FakeClient([_src_pipeline("src-pl-1", "Daily Invoices", "wf-src-1")]) + src.keys_by_pipeline["src-pl-1"] = [ + {"id": "k1", "is_active": True}, + {"id": "k2", "is_active": True}, + {"id": "k3", "is_active": False}, + ] + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", "wf-src-1", "wf-tgt-1") + ctx = _ctx(src, tgt, remap=remap) + + with caplog.at_level(logging.WARNING, logger="unstract.clone.phases.pipeline"): + result = PipelinePhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert result.failed == 0 + assert any("2 active API keys" in r.message for r in caplog.records) diff --git a/tests/clone/test_remap_table.py b/tests/clone/test_remap_table.py new file mode 100644 index 0000000..6a13754 --- /dev/null +++ b/tests/clone/test_remap_table.py @@ -0,0 +1,36 @@ +"""Tests for ``RemapTable``.""" + +from unstract.clone.context import RemapTable + + +def test_record_and_resolve_per_entity(): + t = RemapTable() + t.record("adapter", "src-1", "tgt-1") + t.record("adapter", "src-2", "tgt-2") + t.record("connector", "src-1", "tgt-99") + + assert t.resolve("adapter", "src-1") == "tgt-1" + assert t.resolve("adapter", "src-2") == "tgt-2" + assert t.resolve("connector", "src-1") == "tgt-99" + + +def test_resolve_missing_returns_none(): + t = RemapTable() + assert t.resolve("adapter", "nope") is None + assert t.resolve_any("nope") is None + + +def test_resolve_any_searches_across_entities(): + t = RemapTable() + t.record("adapter", "src-a", "tgt-a") + t.record("workflow", "src-w", "tgt-w") + assert t.resolve_any("src-a") == "tgt-a" + assert t.resolve_any("src-w") == "tgt-w" + + +def test_snapshot_is_independent_copy(): + t = RemapTable() + t.record("adapter", "src-1", "tgt-1") + snap = t.snapshot() + t.record("adapter", "src-2", "tgt-2") + assert "src-2" not in snap["adapter"] diff --git a/tests/clone/test_tag_phase.py b/tests/clone/test_tag_phase.py new file mode 100644 index 0000000..f6086a9 --- /dev/null +++ b/tests/clone/test_tag_phase.py @@ -0,0 +1,109 @@ +"""Tests for ``TagPhase``. + +Tag is the simplest entity — no encryption, no list-vs-detail divergence. +Suite covers happy / idempotency / dry-run / abort. +""" + +from __future__ import annotations + +import pytest + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + RemapTable, +) +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.tag import TagPhase +from unstract.clone.report import CloneReport + + +class FakeClient: + POST_SCHEMA = frozenset({"name", "description"}) + + def __init__(self, tags: list[dict] | None = None): + self.tags: list[dict] = list(tags or []) + self.posts: list[dict] = [] + self._next_id = 1 + + def get_post_schema(self, entity_path): + return self.POST_SCHEMA + + def list_tags(self, *, name=None): + result = self.tags + if name is not None: + result = [t for t in result if t["name"] == name] + return list(result) + + def create_tag(self, payload): + new = dict(payload) + new["id"] = f"tgt-{self._next_id:08d}-0000-0000-0000-000000000000" + self._next_id += 1 + self.tags.append(new) + self.posts.append(new) + return new + + +def _src(id_, name): + return {"id": id_, "name": name, "description": f"{name} desc"} + + +def _ctx(source, target, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=RemapTable(), + ) + + +def test_happy_path_creates_all_and_records_remap(): + src = FakeClient([_src("src-a", "billing"), _src("src-b", "finance")]) + tgt = FakeClient() + ctx = _ctx(src, tgt) + report = CloneReport() + + result = TagPhase(ctx).run(report) + + assert result.created == 2 + assert result.adopted == 0 + assert len(tgt.posts) == 2 + assert ctx.remap.resolve("tag", "src-a") == tgt.posts[0]["id"] + assert ctx.remap.resolve("tag", "src-b") == tgt.posts[1]["id"] + + +def test_idempotency_zero_creates_on_rerun(): + src = FakeClient([_src("src-a", "billing")]) + tgt = FakeClient([{"id": "preexisting", "name": "billing", "description": "x"}]) + ctx = _ctx(src, tgt, on_name_conflict="adopt") + report = CloneReport() + + result = TagPhase(ctx).run(report) + + assert result.created == 0 + assert result.adopted == 1 + assert tgt.posts == [] + assert ctx.remap.resolve("tag", "src-a") == "preexisting" + + +def test_dry_run_makes_no_posts(): + src = FakeClient([_src("src-a", "billing")]) + tgt = FakeClient() + ctx = _ctx(src, tgt, dry_run=True) + report = CloneReport() + + result = TagPhase(ctx).run(report) + + assert result.skipped == 1 + assert result.created == 0 + assert tgt.posts == [] + + +def test_abort_on_name_conflict_raises(): + src = FakeClient([_src("src-a", "billing")]) + tgt = FakeClient([{"id": "preexisting", "name": "billing", "description": "x"}]) + ctx = _ctx(src, tgt, on_name_conflict="abort") + report = CloneReport() + + with pytest.raises(NameConflictError): + TagPhase(ctx).run(report) diff --git a/tests/clone/test_tool_instance_phase.py b/tests/clone/test_tool_instance_phase.py new file mode 100644 index 0000000..180c285 --- /dev/null +++ b/tests/clone/test_tool_instance_phase.py @@ -0,0 +1,248 @@ +"""Tests for ``ToolInstancePhase``. + +ToolInstance is unique among phases: +- The source list of "things to clone" comes from the workflow remap + table, not a top-level entity list. +- Create is a two-step dance (POST bare, PATCH metadata) because the + backend rebuilds metadata from defaults on POST. +""" + +from __future__ import annotations + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + RemapTable, +) +from unstract.clone.phases.tool_instance import ToolInstancePhase +from unstract.clone.report import CloneReport + + +class FakeClient: + def __init__(self) -> None: + # Keyed by workflow_id -> list of tool_instances. + self.instances: dict[str, list[dict]] = {} + self.create_calls: list[dict] = [] + self.patch_calls: list[tuple[str, dict]] = [] + self._next = 1 + + def _mint(self) -> str: + s = f"tgt-ti-{self._next:04d}" + self._next += 1 + return s + + def list_tool_instances(self, *, workflow_id: str | None = None) -> list[dict]: + if workflow_id is None: + return [ti for instances in self.instances.values() for ti in instances] + return list(self.instances.get(workflow_id, [])) + + def create_tool_instance(self, payload: dict) -> dict: + wf = payload["workflow_id"] + new = {**payload, "id": self._mint(), "metadata": {"defaults": True}} + self.instances.setdefault(wf, []).append(new) + self.create_calls.append(new) + return new + + def update_tool_instance_metadata(self, instance_id: str, metadata: dict) -> dict: + self.patch_calls.append((instance_id, metadata)) + for wf_instances in self.instances.values(): + for ti in wf_instances: + if ti["id"] == instance_id: + ti["metadata"] = metadata + return ti + raise KeyError(instance_id) + + +def _ctx(source, target, *, remap=None, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=remap or RemapTable(), + ) + + +def _src_ti(ti_id: str, wf_id: str, tool_id: str, metadata: dict) -> dict: + return { + "id": ti_id, + "workflow": wf_id, + "tool_id": tool_id, + "metadata": metadata, + "step": 1, + } + + +SRC_WF = "10000000-0000-0000-0000-000000000001" +TGT_WF = "20000000-0000-0000-0000-000000000001" +SRC_REG = "30000000-0000-0000-0000-000000000001" +TGT_REG = "40000000-0000-0000-0000-000000000001" + + +def _seed_remap() -> RemapTable: + remap = RemapTable() + remap.record("workflow", SRC_WF, TGT_WF) + remap.record("prompt_studio_registry", SRC_REG, TGT_REG) + return remap + + +def test_happy_path_creates_instance_then_patches_metadata(): + src = FakeClient() + src.instances[SRC_WF] = [ + _src_ti( + "src-ti-1", + SRC_WF, + SRC_REG, + { + "llm": "My OpenAI", + "embedding": "MyEmb", + # Identity fields that the backend populated server-side + # at source create time — must NOT cross the org boundary. + "tenant_id": "src-org", + "prompt_registry_id": "src-registry-uuid", + "tool_instance_id": "src-ti-1-pk", + }, + ) + ] + tgt = FakeClient() + ctx = _ctx(src, tgt, remap=_seed_remap()) + + result = ToolInstancePhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert result.failed == 0 + assert len(tgt.create_calls) == 1 + posted = tgt.create_calls[0] + assert posted["workflow_id"] == TGT_WF + assert posted["tool_id"] == TGT_REG + # PATCH carries source settings but stamps identity fields with + # target values — backend PATCH overwrites the whole metadata dict. + assert len(tgt.patch_calls) == 1 + patched_id, patched_metadata = tgt.patch_calls[0] + assert patched_id == posted["id"] + assert patched_metadata == { + "llm": "My OpenAI", + "embedding": "MyEmb", + "prompt_registry_id": TGT_REG, + "tool_instance_id": posted["id"], + } + assert ctx.remap.resolve("tool_instance", "src-ti-1") == posted["id"] + + +def test_skip_when_registry_remap_missing(): + src = FakeClient() + src.instances[SRC_WF] = [_src_ti("src-ti-1", SRC_WF, "unknown-reg", {})] + tgt = FakeClient() + remap = RemapTable() + remap.record("workflow", SRC_WF, TGT_WF) + # No prompt_studio_registry remap entry → SDK must skip. + ctx = _ctx(src, tgt, remap=remap) + + result = ToolInstancePhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert result.created == 0 + assert tgt.create_calls == [] + + +def test_adopt_existing_target_instance_and_repatch_metadata(): + src = FakeClient() + src_meta = {"llm": "My OpenAI"} + src.instances[SRC_WF] = [_src_ti("src-ti-1", SRC_WF, SRC_REG, src_meta)] + tgt = FakeClient() + tgt.instances[TGT_WF] = [ + {"id": "tgt-pre-ti", "workflow": TGT_WF, "tool_id": TGT_REG, "metadata": {}} + ] + ctx = _ctx(src, tgt, remap=_seed_remap()) + + result = ToolInstancePhase(ctx).run(CloneReport()) + + assert result.adopted == 1 + assert result.created == 0 + assert tgt.create_calls == [] + # PATCH still fires on adopt and stamps identity fields with target + # values so the runtime can resolve the registry. + assert tgt.patch_calls == [ + ( + "tgt-pre-ti", + { + "llm": "My OpenAI", + "prompt_registry_id": TGT_REG, + "tool_instance_id": "tgt-pre-ti", + }, + ) + ] + assert ctx.remap.resolve("tool_instance", "src-ti-1") == "tgt-pre-ti" + + +def test_no_op_when_no_workflows_in_remap(): + src = FakeClient() + tgt = FakeClient() + ctx = _ctx(src, tgt, remap=RemapTable()) + + result = ToolInstancePhase(ctx).run(CloneReport()) + + assert result.created == 0 + assert result.skipped == 0 + assert tgt.create_calls == [] + + +def test_broken_adapter_refs_bumps_skipped_and_records_error(): + src = FakeClient() + src.instances[SRC_WF] = [ + _src_ti( + "src-ti-1", + SRC_WF, + SRC_REG, + {"llm": "[DELETED ADAPTER] My OpenAI", "embedding": "MyEmb"}, + ) + ] + tgt = FakeClient() + ctx = _ctx(src, tgt, remap=_seed_remap()) + + result = ToolInstancePhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert result.skipped == 1 + assert tgt.patch_calls == [] + assert any("stale adapter refs" in e for e in result.errors) + + +def test_dry_run_does_not_create_or_patch(): + src = FakeClient() + src.instances[SRC_WF] = [_src_ti("src-ti-1", SRC_WF, SRC_REG, {"x": 1})] + tgt = FakeClient() + ctx = _ctx(src, tgt, remap=_seed_remap(), dry_run=True) + + result = ToolInstancePhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert tgt.create_calls == [] + assert tgt.patch_calls == [] + + +def test_dry_run_on_adopt_path_does_not_repatch_target(): + # Target already has a tool_instance for the target workflow. On a + # dry-run, we must NOT PATCH its metadata — the adopt branch used to + # fall through to the PATCH call. + src = FakeClient() + src.instances[SRC_WF] = [_src_ti("src-ti-1", SRC_WF, SRC_REG, {"llm": "My OpenAI"})] + tgt = FakeClient() + tgt.instances[TGT_WF] = [ + { + "id": "tgt-pre-ti", + "workflow": TGT_WF, + "tool_id": TGT_REG, + "metadata": {"existing": "untouched"}, + "step": 1, + } + ] + ctx = _ctx(src, tgt, remap=_seed_remap(), dry_run=True) + + result = ToolInstancePhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert result.adopted == 0 + assert tgt.create_calls == [] + assert tgt.patch_calls == [] + # Remap still gets recorded so downstream dry-run output is coherent. + assert ctx.remap.resolve("tool_instance", "src-ti-1") == "tgt-pre-ti" diff --git a/tests/clone/test_walker.py b/tests/clone/test_walker.py new file mode 100644 index 0000000..5ae0301 --- /dev/null +++ b/tests/clone/test_walker.py @@ -0,0 +1,54 @@ +"""Tests for ``remap_uuids``.""" + +from unstract.clone.context import RemapTable +from unstract.clone.walker import remap_uuids + +SRC_A = "11111111-1111-1111-1111-111111111111" +TGT_A = "22222222-2222-2222-2222-222222222222" +SRC_B = "33333333-3333-3333-3333-333333333333" +TGT_B = "44444444-4444-4444-4444-444444444444" +UNRELATED = "55555555-5555-5555-5555-555555555555" + + +def _populated_remap(): + t = RemapTable() + t.record("adapter", SRC_A, TGT_A) + t.record("workflow", SRC_B, TGT_B) + return t + + +def test_remaps_mapped_uuid_string(): + assert remap_uuids(SRC_A, _populated_remap()) == TGT_A + + +def test_leaves_unmapped_uuid_untouched(): + assert remap_uuids(UNRELATED, _populated_remap()) == UNRELATED + + +def test_leaves_non_uuid_string_alone(): + assert remap_uuids("hello-world", _populated_remap()) == "hello-world" + + +def test_remaps_inside_nested_dict_and_list(): + payload = { + "id": SRC_A, + "config": { + "refs": [SRC_B, "not-a-uuid", UNRELATED], + "nested": {"adapter_id": SRC_A}, + }, + "count": 42, + } + result = remap_uuids(payload, _populated_remap()) + assert result == { + "id": TGT_A, + "config": { + "refs": [TGT_B, "not-a-uuid", UNRELATED], + "nested": {"adapter_id": TGT_A}, + }, + "count": 42, + } + + +def test_handles_non_string_scalars(): + payload = {"a": 1, "b": True, "c": None, "d": 3.14} + assert remap_uuids(payload, _populated_remap()) == payload diff --git a/tests/clone/test_workflow_endpoint_phase.py b/tests/clone/test_workflow_endpoint_phase.py new file mode 100644 index 0000000..811488f --- /dev/null +++ b/tests/clone/test_workflow_endpoint_phase.py @@ -0,0 +1,242 @@ +"""Tests for ``WorkflowEndpointPhase``. + +WorkflowEndpoints are PATCH-only — backend auto-creates them on workflow +POST. Tests verify that the SDK: +- pairs source/target endpoints by ``endpoint_type``; +- remaps the embedded ``connector_instance`` UUID; +- walker-rewrites UUIDs nested in ``configuration``; +- silently leaves connector_instance_id null when no remap exists. +""" + +from __future__ import annotations + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + RemapTable, +) +from unstract.clone.phases.workflow_endpoint import WorkflowEndpointPhase +from unstract.clone.report import CloneReport + + +class FakeClient: + def __init__(self) -> None: + self.endpoints: dict[str, list[dict]] = {} + self.patch_calls: list[tuple[str, dict]] = [] + + def list_workflow_endpoints(self, *, workflow_id: str | None = None) -> list[dict]: + if workflow_id is None: + return [ep for eps in self.endpoints.values() for ep in eps] + return list(self.endpoints.get(workflow_id, [])) + + def update_workflow_endpoint(self, endpoint_id: str, payload: dict) -> dict: + self.patch_calls.append((endpoint_id, payload)) + for eps in self.endpoints.values(): + for ep in eps: + if ep["id"] == endpoint_id: + ep.update(payload) + return ep + raise KeyError(endpoint_id) + + +def _ctx(source, target, *, remap=None, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=remap or RemapTable(), + ) + + +SRC_WF = "10000000-0000-0000-0000-000000000001" +TGT_WF = "20000000-0000-0000-0000-000000000001" +SRC_CONN = "30000000-0000-0000-0000-000000000001" +TGT_CONN = "40000000-0000-0000-0000-000000000001" + + +def _src_endpoint(ep_id, etype, connector_id, configuration): + return { + "id": ep_id, + "workflow": SRC_WF, + "endpoint_type": etype, + "connection_type": "FILESYSTEM", + "configuration": configuration, + "connector_instance": {"id": connector_id, "connector_name": "src-conn"}, + } + + +def _tgt_endpoint(ep_id, etype): + return { + "id": ep_id, + "workflow": TGT_WF, + "endpoint_type": etype, + "connection_type": "", + "configuration": {}, + "connector_instance": None, + } + + +def _seed_remap() -> RemapTable: + remap = RemapTable() + remap.record("workflow", SRC_WF, TGT_WF) + remap.record("connector", SRC_CONN, TGT_CONN) + return remap + + +def test_pairs_endpoints_by_type_and_remaps_connector(): + src = FakeClient() + src.endpoints[SRC_WF] = [ + _src_endpoint( + "src-ep-source", + "SOURCE", + SRC_CONN, + {"connector_id": SRC_CONN, "path": "/in"}, + ), + _src_endpoint( + "src-ep-dest", + "DESTINATION", + SRC_CONN, + {"connector_id": SRC_CONN, "path": "/out"}, + ), + ] + tgt = FakeClient() + tgt.endpoints[TGT_WF] = [ + _tgt_endpoint("tgt-ep-source", "SOURCE"), + _tgt_endpoint("tgt-ep-dest", "DESTINATION"), + ] + ctx = _ctx(src, tgt, remap=_seed_remap()) + + result = WorkflowEndpointPhase(ctx).run(CloneReport()) + + assert result.created == 2 + assert result.failed == 0 + assert len(tgt.patch_calls) == 2 + + patches_by_id = dict(tgt.patch_calls) + src_patch = patches_by_id["tgt-ep-source"] + assert src_patch["connection_type"] == "FILESYSTEM" + assert src_patch["connector_instance_id"] == TGT_CONN + assert src_patch["configuration"]["connector_id"] == TGT_CONN + assert src_patch["configuration"]["path"] == "/in" + + dst_patch = patches_by_id["tgt-ep-dest"] + assert dst_patch["configuration"]["path"] == "/out" + assert dst_patch["connector_instance_id"] == TGT_CONN + + assert ctx.remap.resolve("workflow_endpoint", "src-ep-source") == "tgt-ep-source" + assert ctx.remap.resolve("workflow_endpoint", "src-ep-dest") == "tgt-ep-dest" + + +def test_endpoint_with_null_connection_type_omits_key_in_payload(): + # Source had connection_type=None (rare but legal on the model). + # Must NOT coerce to "" — backend treats blank as a validation + # failure on the enum. Omit the key entirely so backend keeps the + # existing target value. + src = FakeClient() + src.endpoints[SRC_WF] = [ + { + "id": "src-ep-source", + "workflow": SRC_WF, + "endpoint_type": "SOURCE", + "connection_type": None, + "configuration": {}, + "connector_instance": None, + } + ] + tgt = FakeClient() + tgt.endpoints[TGT_WF] = [_tgt_endpoint("tgt-ep-source", "SOURCE")] + ctx = _ctx(src, tgt, remap=_seed_remap()) + + result = WorkflowEndpointPhase(ctx).run(CloneReport()) + + assert result.failed == 0 + assert len(tgt.patch_calls) == 1 + _, payload = tgt.patch_calls[0] + assert "connection_type" not in payload + + +def test_endpoint_without_source_connector_patches_with_null(): + src = FakeClient() + src.endpoints[SRC_WF] = [ + { + "id": "src-ep-source", + "endpoint_type": "SOURCE", + "connection_type": "API", + "configuration": {"foo": "bar"}, + "connector_instance": None, + } + ] + tgt = FakeClient() + tgt.endpoints[TGT_WF] = [_tgt_endpoint("tgt-ep-source", "SOURCE")] + ctx = _ctx(src, tgt, remap=_seed_remap()) + + result = WorkflowEndpointPhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert len(tgt.patch_calls) == 1 + _, payload = tgt.patch_calls[0] + assert payload["connector_instance_id"] is None + assert payload["configuration"] == {"foo": "bar"} + + +def test_unknown_connector_uuid_skips_endpoint_and_flags_error(): + """Source had a connector but its remap is missing — patching with + connector=None would silently detach the endpoint on target. Skip + the PATCH and record an operator-visible error entry instead. + """ + src = FakeClient() + src.endpoints[SRC_WF] = [ + _src_endpoint( + "src-ep-source", + "SOURCE", + "unmapped-but-uuid-99999999-9999-9999-9999-999999999999"[:36], + {}, + ) + ] + tgt = FakeClient() + tgt.endpoints[TGT_WF] = [_tgt_endpoint("tgt-ep-source", "SOURCE")] + ctx = _ctx(src, tgt, remap=_seed_remap()) + + result = WorkflowEndpointPhase(ctx).run(CloneReport()) + + assert result.created == 0 + assert result.skipped == 1 + assert tgt.patch_calls == [] + assert any("unmapped connector" in e for e in result.errors) + + +def test_missing_target_endpoint_fails_loudly(): + src = FakeClient() + src.endpoints[SRC_WF] = [_src_endpoint("src-ep-source", "SOURCE", SRC_CONN, {})] + tgt = FakeClient() + tgt.endpoints[TGT_WF] = [] # No endpoints — anomaly. + ctx = _ctx(src, tgt, remap=_seed_remap()) + + result = WorkflowEndpointPhase(ctx).run(CloneReport()) + + assert result.failed == 1 + assert tgt.patch_calls == [] + + +def test_dry_run_makes_no_patches(): + src = FakeClient() + src.endpoints[SRC_WF] = [_src_endpoint("src-ep-source", "SOURCE", SRC_CONN, {})] + tgt = FakeClient() + tgt.endpoints[TGT_WF] = [_tgt_endpoint("tgt-ep-source", "SOURCE")] + ctx = _ctx(src, tgt, remap=_seed_remap(), dry_run=True) + + result = WorkflowEndpointPhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert tgt.patch_calls == [] + + +def test_no_workflows_in_remap_is_noop(): + src = FakeClient() + tgt = FakeClient() + ctx = _ctx(src, tgt, remap=RemapTable()) + + result = WorkflowEndpointPhase(ctx).run(CloneReport()) + + assert result.created == 0 + assert tgt.patch_calls == [] diff --git a/tests/clone/test_workflow_phase.py b/tests/clone/test_workflow_phase.py new file mode 100644 index 0000000..b29aca8 --- /dev/null +++ b/tests/clone/test_workflow_phase.py @@ -0,0 +1,191 @@ +"""Tests for ``WorkflowPhase``. + +Coverage: +- happy path: source workflow created on target, connector UUIDs in + ``source_settings`` / ``destination_settings`` rewritten via walker. +- idempotency: re-run on existing target adopts and doesn't duplicate. +- dry-run: no POST. +- abort on name conflict. +""" + +from __future__ import annotations + +import pytest + +from unstract.clone.context import ( + CloneContext, + CloneOptions, + RemapTable, +) +from unstract.clone.exceptions import NameConflictError +from unstract.clone.phases.workflow import WorkflowPhase +from unstract.clone.report import CloneReport + + +WORKFLOW_POST_SCHEMA = frozenset( + { + "workflow_name", + "description", + "is_active", + "deployment_type", + "source_settings", + "destination_settings", + "max_file_execution_count", + "shared_users", + "shared_to_org", + } +) + + +class FakeClient: + def __init__(self, workflows: list[dict] | None = None): + self.workflows: list[dict] = list(workflows or []) + self.posts: list[dict] = [] + self.tool_instances: list[dict] = [] + self._next_id = 1 + + def get_post_schema(self, entity_path: str) -> frozenset[str]: + return WORKFLOW_POST_SCHEMA + + def list_workflows(self, *, name: str | None = None): + result = self.workflows + if name is not None: + result = [w for w in result if w["workflow_name"] == name] + return list(result) + + def get_workflow(self, workflow_id: str) -> dict: + for w in self.workflows: + if w["id"] == workflow_id: + return dict(w) + raise KeyError(workflow_id) + + def list_tool_instances(self, *, workflow_id: str | None = None) -> list[dict]: + if workflow_id is None: + return list(self.tool_instances) + return [ti for ti in self.tool_instances if ti.get("workflow") == workflow_id] + + def create_workflow(self, payload: dict) -> dict: + new = dict(payload) + new["id"] = f"tgt-{self._next_id:08d}-0000-0000-0000-000000000000" + self._next_id += 1 + self.workflows.append(new) + self.posts.append(new) + return new + + +def _src(id_, name, *, source_settings=None, destination_settings=None): + return { + "id": id_, + "workflow_name": name, + "description": f"{name} desc", + "is_active": True, + "deployment_type": "DEFAULT", + "source_settings": source_settings or {}, + "destination_settings": destination_settings or {}, + "max_file_execution_count": None, + "shared_users": [], + "shared_to_org": False, + } + + +def _ctx(source, target, *, remap=None, **opt_overrides): + return CloneContext( + source=source, + target=target, + options=CloneOptions(**opt_overrides), + remap=remap or RemapTable(), + ) + + +def test_happy_path_creates_workflow_and_remaps_connector_uuids(): + src_conn = "11111111-1111-1111-1111-111111111111" + tgt_conn = "a1111111-1111-1111-1111-111111111111" + src = FakeClient( + [ + _src( + "wf-src-1", + "Invoice ETL", + source_settings={"connector_id": src_conn, "extras": {"a": 1}}, + destination_settings={"connector_id": src_conn}, + ) + ] + ) + tgt = FakeClient() + remap = RemapTable() + remap.record("connector", src_conn, tgt_conn) + ctx = _ctx(src, tgt, remap=remap) + + result = WorkflowPhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert result.failed == 0 + assert len(tgt.posts) == 1 + posted = tgt.posts[0] + # Walker rewrote both occurrences of the source connector UUID. + assert posted["source_settings"]["connector_id"] == tgt_conn + assert posted["destination_settings"]["connector_id"] == tgt_conn + # Unrelated nested data passes through untouched. + assert posted["source_settings"]["extras"] == {"a": 1} + + assert ctx.remap.resolve("workflow", "wf-src-1") == posted["id"] + + +def test_idempotent_rerun_adopts_existing_workflow(): + src = FakeClient([_src("wf-src-1", "Invoice ETL")]) + tgt = FakeClient( + [{"id": "wf-tgt-pre", "workflow_name": "Invoice ETL"}] + ) + ctx = _ctx(src, tgt, on_name_conflict="adopt") + + result = WorkflowPhase(ctx).run(CloneReport()) + + assert result.adopted == 1 + assert result.created == 0 + assert tgt.posts == [] + assert ctx.remap.resolve("workflow", "wf-src-1") == "wf-tgt-pre" + + +def test_dry_run_creates_nothing(): + src = FakeClient([_src("wf-src-1", "Invoice ETL")]) + tgt = FakeClient() + ctx = _ctx(src, tgt, dry_run=True) + + result = WorkflowPhase(ctx).run(CloneReport()) + + assert result.skipped == 1 + assert tgt.posts == [] + + +def test_abort_on_name_conflict_raises(): + src = FakeClient([_src("wf-src-1", "Invoice ETL")]) + tgt = FakeClient( + [{"id": "wf-tgt-pre", "workflow_name": "Invoice ETL"}] + ) + ctx = _ctx(src, tgt, on_name_conflict="abort") + + with pytest.raises(NameConflictError): + WorkflowPhase(ctx).run(CloneReport()) + + +def test_cascade_skip_when_workflow_tool_was_skipped(): + """Workflow whose ToolInstance references a registry id in the + cascade-skip set must not land on target. Re-runs after the operator + wires the missing adapter pick it up naturally. + """ + skipped_reg = "skipped-registry-id" + src = FakeClient([_src("wf-skipped", "Frictionless WF"), _src("wf-ok", "OK WF")]) + src.tool_instances = [ + {"workflow": "wf-skipped", "tool_id": skipped_reg}, + {"workflow": "wf-ok", "tool_id": "other-registry-id"}, + ] + tgt = FakeClient() + ctx = _ctx(src, tgt) + ctx.skipped_custom_tool_registry_ids.add(skipped_reg) + + result = WorkflowPhase(ctx).run(CloneReport()) + + assert result.created == 1 + assert result.skipped == 1 + assert [p["workflow_name"] for p in tgt.posts] == ["OK WF"] + assert ctx.remap.resolve("workflow", "wf-skipped") is None + assert ctx.remap.resolve("workflow", "wf-ok") is not None diff --git a/uv.lock b/uv.lock index 25478df..8710285 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.11" [[package]] @@ -102,6 +102,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" }, ] +[[package]] +name = "click" +version = "8.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9b/98/518d8e5081007684232226f475082b30087d0f585e8457db087298259f49/click-8.4.1.tar.gz", hash = "sha256:918b5633eddf6b41c32d4f454bf0de810065c74e3f7dbf8ee5452f8be88d3e96", size = 353007, upload-time = "2026-05-22T04:08:37.769Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/0d/67e5b4109ea4a837e80daa87c2c696711955e40449a97e8926672534def2/click-8.4.1-py3-none-any.whl", hash = "sha256:482be17c6991b8c19c5429a1e995d9b0efdbb63172824c41f99965dc0ade8ec2", size = 116639, upload-time = "2026-05-22T04:08:35.26Z" }, +] + [[package]] name = "colorama" version = "0.4.6" @@ -282,6 +294,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] +[[package]] +name = "markdown-it-py" +version = "4.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/ff/7841249c247aa650a76b9ee4bbaeae59370dc8bfd2f6c01f3630c35eb134/markdown_it_py-4.2.0.tar.gz", hash = "sha256:04a21681d6fbb623de53f6f364d352309d4094dd4194040a10fd51833e418d49", size = 82454, upload-time = "2026-05-07T12:08:28.36Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/81/4da04ced5a082363ecfa159c010d200ecbd959ae410c10c0264a38cac0f5/markdown_it_py-4.2.0-py3-none-any.whl", hash = "sha256:9f7ebbcd14fe59494226453aed97c1070d83f8d24b6fc3a3bcf9a38092641c4a", size = 91687, upload-time = "2026-05-07T12:08:27.182Z" }, +] + [[package]] name = "mbstrdecoder" version = "1.1.4" @@ -294,6 +318,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/30/ac/5ce64a1d4cce00390beab88622a290420401f1cabf05caf2fc0995157c21/mbstrdecoder-1.1.4-py3-none-any.whl", hash = "sha256:03dae4ec50ec0d2ff4743e63fdbd5e0022815857494d35224b60775d3d934a8c", size = 7933, upload-time = "2025-01-18T10:07:29.562Z" }, ] +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, +] + [[package]] name = "mypy" version = "1.10.1" @@ -593,6 +626,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, ] +[[package]] +name = "rich" +version = "15.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c0/8f/0722ca900cc807c13a6a0c696dacf35430f72e0ec571c4275d2371fca3e9/rich-15.0.0.tar.gz", hash = "sha256:edd07a4824c6b40189fb7ac9bc4c52536e9780fbbfbddf6f1e2502c31b068c36", size = 230680, upload-time = "2026-04-12T08:24:00.75Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl", hash = "sha256:33bd4ef74232fb73fe9279a257718407f169c09b78a87ad3d296f548e27de0bb", size = 310654, upload-time = "2026-04-12T08:24:02.83Z" }, +] + [[package]] name = "ruff" version = "0.15.1" @@ -757,6 +803,12 @@ dependencies = [ { name = "tenacity" }, ] +[package.optional-dependencies] +clone = [ + { name = "click" }, + { name = "rich" }, +] + [package.dev-dependencies] dev = [ { name = "docutils" }, @@ -789,9 +841,12 @@ test = [ [package.metadata] requires-dist = [ + { name = "click", marker = "extra == 'clone'", specifier = ">=8.1" }, { name = "requests", specifier = ">=2.32.3" }, + { name = "rich", marker = "extra == 'clone'", specifier = ">=13.7" }, { name = "tenacity", specifier = ">=8.2.0" }, ] +provides-extras = ["clone"] [package.metadata.requires-dev] dev = [