From 8988ac7d493728a8a76dcbc419a209af72e5e92d Mon Sep 17 00:00:00 2001 From: Waleed Latif Date: Fri, 29 May 2026 13:03:20 -0700 Subject: [PATCH 1/9] improvement(providers): harden OpenAI-compatible providers + add tests --- apps/sim/providers/fireworks/index.test.ts | 226 ++++++++++++ apps/sim/providers/fireworks/index.ts | 10 +- apps/sim/providers/fireworks/utils.ts | 13 +- apps/sim/providers/litellm/index.test.ts | 287 +++++++++++++++ apps/sim/providers/litellm/index.ts | 126 ++++++- apps/sim/providers/ollama/index.test.ts | 347 +++++++++++++++++ apps/sim/providers/ollama/index.ts | 45 ++- apps/sim/providers/openai/core.ts | 55 +-- apps/sim/providers/openrouter/index.test.ts | 345 +++++++++++++++++ apps/sim/providers/openrouter/index.ts | 9 +- apps/sim/providers/openrouter/utils.ts | 14 - apps/sim/providers/utils.ts | 53 +++ apps/sim/providers/vllm/index.test.ts | 296 +++++++++++++++ apps/sim/providers/vllm/index.ts | 388 ++++++++++---------- apps/sim/providers/vllm/utils.ts | 15 +- 15 files changed, 1911 insertions(+), 318 deletions(-) create mode 100644 apps/sim/providers/fireworks/index.test.ts create mode 100644 apps/sim/providers/litellm/index.test.ts create mode 100644 apps/sim/providers/ollama/index.test.ts create mode 100644 apps/sim/providers/openrouter/index.test.ts create mode 100644 apps/sim/providers/vllm/index.test.ts diff --git a/apps/sim/providers/fireworks/index.test.ts b/apps/sim/providers/fireworks/index.test.ts new file mode 100644 index 00000000000..68fba04c736 --- /dev/null +++ b/apps/sim/providers/fireworks/index.test.ts @@ -0,0 +1,226 @@ +/** + * @vitest-environment node + */ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { + mockCreate, + mockSupportsNativeStructuredOutputs, + mockPrepareToolsWithUsageControl, + mockExecuteTool, +} = vi.hoisted(() => ({ + mockCreate: vi.fn(), + mockSupportsNativeStructuredOutputs: vi.fn(), + mockPrepareToolsWithUsageControl: vi.fn(), + mockExecuteTool: vi.fn(), +})) + +vi.mock('openai', () => ({ + default: vi.fn().mockImplementation(() => ({ + chat: { completions: { create: mockCreate } }, + })), +})) + +vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 5 })) + +vi.mock('@/providers/models', () => ({ + getProviderModels: vi.fn().mockReturnValue([]), + getProviderDefaultModel: vi.fn().mockReturnValue('llama-v3p1-70b-instruct'), +})) + +vi.mock('@/providers/attachments', () => ({ + formatMessagesForProvider: vi.fn((messages) => messages), +})) + +vi.mock('@/providers/fireworks/utils', () => ({ + supportsNativeStructuredOutputs: mockSupportsNativeStructuredOutputs, + createReadableStreamFromOpenAIStream: vi.fn(() => ({}) as ReadableStream), + checkForForcedToolUsage: vi.fn(() => ({ hasUsedForcedTool: false, usedForcedTools: [] })), +})) + +vi.mock('@/providers/trace-enrichment', () => ({ + enrichLastModelSegmentFromChatCompletions: vi.fn(), +})) + +vi.mock('@/providers/utils', () => ({ + calculateCost: vi.fn().mockReturnValue({ input: 0, output: 0, total: 0 }), + generateSchemaInstructions: vi.fn(() => 'SCHEMA_INSTRUCTIONS'), + prepareToolExecution: vi.fn(() => ({ toolParams: { x: 1 }, executionParams: { x: 1 } })), + prepareToolsWithUsageControl: mockPrepareToolsWithUsageControl, + sumToolCosts: vi.fn().mockReturnValue(0), +})) + +vi.mock('@/tools', () => ({ executeTool: mockExecuteTool })) + +import { fireworksProvider } from '@/providers/fireworks/index' +import { ProviderError } from '@/providers/types' + +const textResponse = (content: string) => ({ + choices: [{ message: { content, tool_calls: [] } }], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, +}) + +const toolCallResponse = () => ({ + choices: [ + { + message: { + content: null, + tool_calls: [ + { id: 'call_1', type: 'function', function: { name: 'my_tool', arguments: '{"x":1}' } }, + ], + }, + }, + ], + usage: { prompt_tokens: 8, completion_tokens: 4, total_tokens: 12 }, +}) + +const toolDef = { + id: 'my_tool', + name: 'my_tool', + description: '', + params: {}, + parameters: { type: 'object', properties: {}, required: [] }, +} + +const callBody = (index: number) => mockCreate.mock.calls[index][0] +const lastCallBody = () => mockCreate.mock.calls.at(-1)?.[0] + +describe('fireworksProvider', () => { + beforeEach(() => { + vi.clearAllMocks() + mockSupportsNativeStructuredOutputs.mockResolvedValue(true) + mockPrepareToolsWithUsageControl.mockImplementation((tools) => ({ + tools, + toolChoice: 'auto', + forcedTools: [], + })) + mockExecuteTool.mockResolvedValue({ success: true, output: { ok: true } }) + }) + + const baseRequest = { + model: 'fireworks/llama-v3p1-70b-instruct', + systemPrompt: 'You are helpful.', + messages: [{ role: 'user' as const, content: 'Hello' }], + apiKey: 'fw-test-key', + } + + it('throws when the API key is missing', async () => { + await expect( + fireworksProvider.executeRequest({ ...baseRequest, apiKey: undefined }) + ).rejects.toThrow('API key is required for Fireworks') + }) + + it('returns content and token usage for a simple request', async () => { + mockCreate.mockResolvedValueOnce(textResponse('hi there')) + + const result = await fireworksProvider.executeRequest(baseRequest) + + expect(result).toMatchObject({ + content: 'hi there', + model: 'llama-v3p1-70b-instruct', + tokens: { input: 10, output: 5, total: 15 }, + }) + }) + + it('wraps API errors in a ProviderError', async () => { + mockCreate.mockRejectedValueOnce(new Error('boom')) + + await expect(fireworksProvider.executeRequest(baseRequest)).rejects.toBeInstanceOf( + ProviderError + ) + }) + + it('streams directly when there are no tools', async () => { + mockCreate.mockResolvedValueOnce({}) + + const result = await fireworksProvider.executeRequest({ ...baseRequest, stream: true }) + + expect(lastCallBody()).toMatchObject({ stream: true, stream_options: { include_usage: true } }) + expect(result).toHaveProperty('stream') + expect(result).toHaveProperty('execution') + }) + + it('sends a json_schema response_format with no strict field', async () => { + mockCreate.mockResolvedValueOnce(textResponse('{}')) + + await fireworksProvider.executeRequest({ + ...baseRequest, + responseFormat: { name: 'my_schema', schema: { type: 'object' }, strict: true }, + }) + + expect(lastCallBody().response_format).toEqual({ + type: 'json_schema', + json_schema: { name: 'my_schema', schema: { type: 'object' } }, + }) + expect(lastCallBody().response_format.json_schema).not.toHaveProperty('strict') + }) + + it('falls back to json_object with prompt instructions when native is unsupported', async () => { + mockSupportsNativeStructuredOutputs.mockResolvedValue(false) + mockCreate.mockResolvedValueOnce(textResponse('{}')) + + await fireworksProvider.executeRequest({ + ...baseRequest, + responseFormat: { name: 'my_schema', schema: { type: 'object' } }, + }) + + expect(lastCallBody().response_format).toEqual({ type: 'json_object' }) + expect(lastCallBody().messages.at(-1)).toEqual({ + role: 'user', + content: 'SCHEMA_INSTRUCTIONS', + }) + }) + + it('defers response_format to a final call when tools are active', async () => { + mockCreate + .mockResolvedValueOnce(textResponse('intermediate')) + .mockResolvedValueOnce(textResponse('{"done":true}')) + + await fireworksProvider.executeRequest({ + ...baseRequest, + responseFormat: { name: 'my_schema', schema: { type: 'object' } }, + tools: [toolDef], + }) + + expect(mockCreate).toHaveBeenCalledTimes(2) + expect(callBody(0).response_format).toBeUndefined() + expect(callBody(0).tools).toBeDefined() + expect(callBody(1).response_format).toEqual({ + type: 'json_schema', + json_schema: { name: 'my_schema', schema: { type: 'object' } }, + }) + expect(callBody(1).tools).toBeUndefined() + }) + + it('runs the tool loop and threads tool results back into the conversation', async () => { + mockCreate + .mockResolvedValueOnce(toolCallResponse()) + .mockResolvedValueOnce(textResponse('final answer')) + + const result = await fireworksProvider.executeRequest({ ...baseRequest, tools: [toolDef] }) + + expect(mockExecuteTool).toHaveBeenCalledWith('my_tool', { x: 1 }, expect.anything()) + expect(result).toMatchObject({ content: 'final answer' }) + expect((result as { toolCalls?: unknown[] }).toolCalls).toHaveLength(1) + + const followUpMessages = callBody(1).messages + expect(followUpMessages).toContainEqual( + expect.objectContaining({ role: 'assistant', tool_calls: expect.any(Array) }) + ) + expect(followUpMessages).toContainEqual( + expect.objectContaining({ role: 'tool', tool_call_id: 'call_1' }) + ) + }) + + it("forces tool_choice 'none' on the final streaming call after tools run", async () => { + mockCreate + .mockResolvedValueOnce(toolCallResponse()) + .mockResolvedValueOnce(textResponse('done')) + .mockResolvedValueOnce({}) + + await fireworksProvider.executeRequest({ ...baseRequest, stream: true, tools: [toolDef] }) + + expect(mockCreate).toHaveBeenCalledTimes(3) + expect(lastCallBody()).toMatchObject({ tool_choice: 'none', stream: true }) + }) +}) diff --git a/apps/sim/providers/fireworks/index.ts b/apps/sim/providers/fireworks/index.ts index 794ee3f0805..cdf355d3451 100644 --- a/apps/sim/providers/fireworks/index.ts +++ b/apps/sim/providers/fireworks/index.ts @@ -34,7 +34,7 @@ const logger = createLogger('FireworksProvider') /** * Applies structured output configuration to a payload based on model capabilities. - * Uses json_schema with strict mode for supported models, falls back to json_object with prompt instructions. + * Uses native json_schema for supported models, falls back to json_object with prompt instructions. */ async function applyResponseFormat( targetPayload: any, @@ -51,7 +51,6 @@ async function applyResponseFormat( json_schema: { name: responseFormat.name || 'response_schema', schema: responseFormat.schema || responseFormat, - strict: responseFormat.strict !== false, }, } return messages @@ -469,7 +468,7 @@ export const fireworksProvider: ProviderConfig = { const streamingParams: ChatCompletionCreateParamsStreaming = { ...payload, messages: [...currentMessages], - tool_choice: 'auto', + tool_choice: 'none', stream: true, stream_options: { include_usage: true }, } @@ -652,8 +651,3 @@ export const fireworksProvider: ProviderConfig = { } }, } - -/** - * Enriches the last model segment with per-iteration content from a Chat - * Completions response: assistant text, tool calls, finish reason, token usage. - */ diff --git a/apps/sim/providers/fireworks/utils.ts b/apps/sim/providers/fireworks/utils.ts index 70444e07b69..631800cc67f 100644 --- a/apps/sim/providers/fireworks/utils.ts +++ b/apps/sim/providers/fireworks/utils.ts @@ -2,18 +2,11 @@ import type { ChatCompletionChunk } from 'openai/resources/chat/completions' import type { CompletionUsage } from 'openai/resources/completions' import { checkForForcedToolUsageOpenAI, createOpenAICompatibleStream } from '@/providers/utils' -/** - * Checks if a model supports native structured outputs (json_schema). - * Fireworks AI supports structured outputs across their inference API. - */ +/** Fireworks supports native json_schema structured outputs for all models on its inference API. */ export async function supportsNativeStructuredOutputs(_modelId: string): Promise { return true } -/** - * Creates a ReadableStream from a Fireworks streaming response. - * Uses the shared OpenAI-compatible streaming utility. - */ export function createReadableStreamFromOpenAIStream( openaiStream: AsyncIterable, onComplete?: (content: string, usage: CompletionUsage) => void @@ -21,10 +14,6 @@ export function createReadableStreamFromOpenAIStream( return createOpenAICompatibleStream(openaiStream, 'Fireworks', onComplete) } -/** - * Checks if a forced tool was used in a Fireworks response. - * Uses the shared OpenAI-compatible forced tool usage helper. - */ export function checkForForcedToolUsage( response: any, toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }, diff --git a/apps/sim/providers/litellm/index.test.ts b/apps/sim/providers/litellm/index.test.ts new file mode 100644 index 00000000000..653090a1160 --- /dev/null +++ b/apps/sim/providers/litellm/index.test.ts @@ -0,0 +1,287 @@ +/** + * @vitest-environment node + */ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { mockCreate, mockExecuteTool } = vi.hoisted(() => ({ + mockCreate: vi.fn(), + mockExecuteTool: vi.fn(), +})) + +vi.mock('openai', () => ({ + default: vi.fn().mockImplementation(() => ({ + chat: { completions: { create: mockCreate } }, + })), +})) + +vi.mock('@/tools', () => ({ executeTool: mockExecuteTool })) + +vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 20 })) + +vi.mock('@/lib/core/config/env', () => ({ + env: { LITELLM_BASE_URL: 'http://litellm.test', LITELLM_API_KEY: '' }, +})) + +vi.mock('@/stores/providers', () => ({ + useProvidersStore: { getState: () => ({ setProviderModels: vi.fn() }) }, +})) + +vi.mock('@/providers/models', () => ({ + getProviderModels: () => [], + getProviderDefaultModel: () => '', +})) + +vi.mock('@/providers/attachments', () => ({ + formatMessagesForProvider: (messages: unknown) => messages, +})) + +vi.mock('@/providers/trace-enrichment', () => ({ + enrichLastModelSegmentFromChatCompletions: vi.fn(), +})) + +vi.mock('@/providers/litellm/utils', () => ({ + createReadableStreamFromLiteLLMStream: vi.fn( + () => new ReadableStream({ start: (c) => c.close() }) + ), +})) + +vi.mock('@/providers/utils', () => ({ + calculateCost: vi.fn(() => ({ input: 0, output: 0, total: 0 })), + sumToolCosts: vi.fn(() => 0), + prepareToolExecution: vi.fn((_tool, toolArgs) => ({ + toolParams: toolArgs, + executionParams: toolArgs, + })), + prepareToolsWithUsageControl: vi.fn((tools) => ({ + tools, + toolChoice: 'auto', + forcedTools: [], + hasFilteredTools: false, + })), + trackForcedToolUsage: vi.fn(() => ({ hasUsedForcedTool: false, usedForcedTools: [] })), + enforceStrictSchema: vi.fn((schema) => ({ ...schema, additionalProperties: false })), +})) + +import { litellmProvider } from '@/providers/litellm' +import { ProviderError } from '@/providers/types' + +interface ChatOptions { + content?: string | null + toolCalls?: Array<{ id: string; function: { name: string; arguments: string } }> + usage?: { prompt_tokens: number; completion_tokens: number; total_tokens: number } +} + +function chat({ content = null, toolCalls, usage }: ChatOptions = {}) { + return { + choices: [ + { + message: { content, tool_calls: toolCalls }, + finish_reason: toolCalls ? 'tool_calls' : 'stop', + }, + ], + usage: usage ?? { prompt_tokens: 5, completion_tokens: 3, total_tokens: 8 }, + } +} + +function tool(name: string) { + return { id: name, name, description: 'd', parameters: {} } +} + +function run(request: Record) { + return litellmProvider.executeRequest!({ + model: 'litellm/llama-3', + messages: [{ role: 'user', content: 'Hi' }], + ...request, + } as never) as Promise +} + +const firstPayload = () => mockCreate.mock.calls[0][0] +const lastPayload = () => mockCreate.mock.calls.at(-1)![0] + +describe('litellmProvider.executeRequest', () => { + beforeEach(() => { + vi.clearAllMocks() + mockCreate.mockResolvedValue(chat({ content: 'hello' })) + mockExecuteTool.mockResolvedValue({ success: true, output: { ok: true } }) + }) + + it('assembles messages, strips the model prefix, and maps params', async () => { + const result = await run({ + systemPrompt: 'You are helpful.', + context: 'Some context', + temperature: 0.5, + maxTokens: 256, + }) + + const payload = firstPayload() + expect(payload.model).toBe('llama-3') + expect(payload.messages).toEqual([ + { role: 'system', content: 'You are helpful.' }, + { role: 'user', content: 'Some context' }, + { role: 'user', content: 'Hi' }, + ]) + expect(payload.temperature).toBe(0.5) + expect(payload.max_completion_tokens).toBe(256) + expect(result.content).toBe('hello') + expect(result.tokens).toEqual({ input: 5, output: 3, total: 8 }) + }) + + it('forwards reasoning_effort only when set to a non-default value', async () => { + await run({ reasoningEffort: 'high' }) + expect(firstPayload().reasoning_effort).toBe('high') + + mockCreate.mockClear() + await run({ reasoningEffort: 'auto' }) + expect(firstPayload().reasoning_effort).toBeUndefined() + + mockCreate.mockClear() + await run({}) + expect(firstPayload().reasoning_effort).toBeUndefined() + }) + + it('sanitizes the schema for strict response_format and passes it through otherwise', async () => { + await run({ responseFormat: { name: 'r', schema: { type: 'object', properties: {} } } }) + let rf = firstPayload().response_format + expect(rf.type).toBe('json_schema') + expect(rf.json_schema.strict).toBe(true) + expect(rf.json_schema.schema.additionalProperties).toBe(false) + + mockCreate.mockClear() + await run({ + responseFormat: { name: 'r', schema: { type: 'object', properties: {} }, strict: false }, + }) + rf = firstPayload().response_format + expect(rf.json_schema.strict).toBe(false) + expect(rf.json_schema.schema.additionalProperties).toBeUndefined() + }) + + it('defers response_format past the tool loop and keeps tools on the final call', async () => { + mockCreate + .mockResolvedValueOnce( + chat({ toolCalls: [{ id: 'c1', function: { name: 'known', arguments: '{"q":1}' } }] }) + ) + .mockResolvedValueOnce(chat({ content: 'mid' })) + .mockResolvedValueOnce(chat({ content: '{"answer":1}' })) + + const result = await run({ + tools: [tool('known')], + responseFormat: { name: 'r', schema: { type: 'object', properties: {} } }, + }) + + expect(firstPayload().response_format).toBeUndefined() + expect(firstPayload().tools).toBeDefined() + + const final = lastPayload() + expect(final.response_format.type).toBe('json_schema') + expect(final.tools).toBeDefined() + expect(final.tool_choice).toBe('auto') + expect(final.parallel_tool_calls).toBe(false) + expect(result.content).toBe('{"answer":1}') + }) + + it('defers response_format into the final streaming call while keeping tools', async () => { + mockCreate + .mockResolvedValueOnce( + chat({ toolCalls: [{ id: 'c1', function: { name: 'known', arguments: '{}' } }] }) + ) + .mockResolvedValueOnce(chat({ content: 'mid' })) + + const result = await run({ + stream: true, + tools: [tool('known')], + responseFormat: { name: 'r', schema: { type: 'object', properties: {} } }, + }) + + const final = lastPayload() + expect(final.stream).toBe(true) + expect(final.response_format.type).toBe('json_schema') + expect(final.tools).toBeDefined() + expect(final.parallel_tool_calls).toBe(false) + expect(result.execution.isStreaming).toBe(true) + }) + + it('threads assistant tool_calls and a named tool response, and reports toolCalls', async () => { + mockCreate + .mockResolvedValueOnce( + chat({ toolCalls: [{ id: 'c1', function: { name: 'known', arguments: '{}' } }] }) + ) + .mockResolvedValueOnce(chat({ content: 'done' })) + mockExecuteTool.mockResolvedValue({ success: true, output: { temp: 72 } }) + + const result = await run({ tools: [tool('known')] }) + + const followupMessages = mockCreate.mock.calls[1][0].messages + expect(followupMessages).toContainEqual({ + role: 'assistant', + content: null, + tool_calls: [{ id: 'c1', type: 'function', function: { name: 'known', arguments: '{}' } }], + }) + expect(followupMessages).toContainEqual({ + role: 'tool', + tool_call_id: 'c1', + name: 'known', + content: JSON.stringify({ temp: 72 }), + }) + expect(result.toolCalls).toHaveLength(1) + expect(result.content).toBe('done') + }) + + it('emits a stub tool response for an unanswered tool_call_id', async () => { + mockCreate + .mockResolvedValueOnce( + chat({ toolCalls: [{ id: 'cX', function: { name: 'ghost', arguments: '{}' } }] }) + ) + .mockResolvedValueOnce(chat({ content: 'recovered' })) + + await run({ tools: [tool('known')] }) + + expect(mockExecuteTool).not.toHaveBeenCalled() + const followupMessages = mockCreate.mock.calls[1][0].messages + const toolMsg = followupMessages.find((m: any) => m.role === 'tool' && m.tool_call_id === 'cX') + expect(toolMsg).toBeDefined() + expect(toolMsg.content).toContain('not available') + }) + + it('executes a tool with empty arguments without failing', async () => { + mockCreate + .mockResolvedValueOnce( + chat({ toolCalls: [{ id: 'c1', function: { name: 'ping', arguments: '' } }] }) + ) + .mockResolvedValueOnce(chat({ content: 'pong' })) + + await run({ tools: [tool('ping')] }) + + expect(mockExecuteTool).toHaveBeenCalledTimes(1) + const toolMsg = mockCreate.mock.calls[1][0].messages.find((m: any) => m.role === 'tool') + expect(toolMsg.content).not.toContain('"error":true') + }) + + it('stops the tool loop at MAX_TOOL_ITERATIONS', async () => { + mockCreate.mockResolvedValue( + chat({ toolCalls: [{ id: 'c1', function: { name: 'known', arguments: '{}' } }] }) + ) + + await run({ tools: [tool('known')] }) + + expect(mockCreate).toHaveBeenCalledTimes(1 + 20) + expect(mockExecuteTool).toHaveBeenCalledTimes(20) + }) + + it('returns a streaming execution when streaming without active tools', async () => { + const result = await run({ stream: true }) + + expect(firstPayload().stream).toBe(true) + expect(firstPayload().stream_options).toEqual({ include_usage: true }) + expect(result.stream).toBeInstanceOf(ReadableStream) + expect(result.execution.isStreaming).toBe(true) + }) + + it('wraps API errors in a ProviderError using the error envelope message', async () => { + mockCreate.mockRejectedValue({ + error: { message: 'rate limited', type: 'rate_limit_error', code: '429' }, + }) + + await expect(run({})).rejects.toBeInstanceOf(ProviderError) + await expect(run({})).rejects.toThrow('rate limited') + }) +}) diff --git a/apps/sim/providers/litellm/index.ts b/apps/sim/providers/litellm/index.ts index 33e363f0509..10f3be87d31 100644 --- a/apps/sim/providers/litellm/index.ts +++ b/apps/sim/providers/litellm/index.ts @@ -19,6 +19,7 @@ import type { import { ProviderError } from '@/providers/types' import { calculateCost, + enforceStrictSchema, prepareToolExecution, prepareToolsWithUsageControl, sumToolCosts, @@ -146,19 +147,29 @@ export const litellmProvider: ProviderConfig = { if (request.temperature !== undefined) payload.temperature = request.temperature if (request.maxTokens != null) payload.max_completion_tokens = request.maxTokens - if (request.responseFormat) { - payload.response_format = { - type: 'json_schema', - json_schema: { - name: request.responseFormat.name || 'response_schema', - schema: request.responseFormat.schema || request.responseFormat, - strict: request.responseFormat.strict !== false, - }, - } - - logger.info('Added JSON schema response format to LiteLLM request') + if (request.reasoningEffort !== undefined && request.reasoningEffort !== 'auto') { + payload.reasoning_effort = request.reasoningEffort } + const isStrictResponseFormat = request.responseFormat + ? request.responseFormat.strict !== false + : false + + const responseFormatPayload = request.responseFormat + ? { + type: 'json_schema' as const, + json_schema: { + name: request.responseFormat.name || 'response_schema', + // Strict mode requires additionalProperties:false and all-required keys; + // OpenAI-backed routes 400 without it. + schema: isStrictResponseFormat + ? enforceStrictSchema(request.responseFormat.schema || request.responseFormat) + : request.responseFormat.schema || request.responseFormat, + strict: isStrictResponseFormat, + }, + } + : undefined + let preparedTools: ReturnType | null = null let hasActiveTools = false @@ -184,6 +195,14 @@ export const litellmProvider: ProviderConfig = { } } + // response_format + tools conflict on some backends (Anthropic rejects the pair, + // vLLM guided decoding suppresses tool calls), so defer the format past the tool loop. + const deferResponseFormat = !!responseFormatPayload && hasActiveTools + if (responseFormatPayload && !deferResponseFormat) { + payload.response_format = responseFormatPayload + logger.info('Added JSON schema response format to LiteLLM request') + } + const providerStartTime = Date.now() const providerStartTimeISO = new Date(providerStartTime).toISOString() @@ -271,6 +290,7 @@ export const litellmProvider: ProviderConfig = { endTime: new Date().toISOString(), duration: Date.now() - providerStartTime, }, + isStreaming: true, }, } as StreamingExecution @@ -374,7 +394,9 @@ export const litellmProvider: ProviderConfig = { const toolName = toolCall.function.name try { - const toolArgs = JSON.parse(toolCall.function.arguments) + const toolArgs = toolCall.function.arguments + ? JSON.parse(toolCall.function.arguments) + : {} const tool = request.tools?.find((t) => t.id === toolName) if (!tool) return null @@ -429,6 +451,8 @@ export const litellmProvider: ProviderConfig = { })), }) + const respondedToolCallIds = new Set() + for (const settledResult of executionResults) { if (settledResult.status === 'rejected' || !settledResult.value) continue @@ -469,8 +493,26 @@ export const litellmProvider: ProviderConfig = { currentMessages.push({ role: 'tool', tool_call_id: toolCall.id, + name: toolName, content: JSON.stringify(resultContent), }) + respondedToolCallIds.add(toolCall.id) + } + + // Every tool_call needs a matching `tool` response or the next request 400s; + // stub any the model left unanswered (e.g. an unknown/filtered tool name). + for (const tc of toolCallsInResponse) { + if (respondedToolCallIds.has(tc.id)) continue + currentMessages.push({ + role: 'tool', + tool_call_id: tc.id, + name: tc.function.name, + content: JSON.stringify({ + error: true, + message: `Tool "${tc.function.name}" is not available`, + tool: tc.function.name, + }), + }) } const thisToolsTime = Date.now() - toolsStartTime @@ -555,6 +597,12 @@ export const litellmProvider: ProviderConfig = { stream: true, stream_options: { include_usage: true }, } + if (deferResponseFormat && responseFormatPayload) { + // Keep tools defined (Anthropic requires it once history holds tool results) and + // disable parallel calls (OpenAI's rule for strict outputs alongside tools). + streamingParams.response_format = responseFormatPayload + streamingParams.parallel_tool_calls = false + } const streamResponse = await litellm.chat.completions.create( streamingParams, request.abortSignal ? { signal: request.abortSignal } : undefined @@ -626,12 +674,64 @@ export const litellmProvider: ProviderConfig = { endTime: new Date().toISOString(), duration: Date.now() - providerStartTime, }, + isStreaming: true, }, } as StreamingExecution return streamingResult as StreamingExecution } + if (deferResponseFormat && responseFormatPayload) { + logger.info('Applying deferred JSON schema response format after tool processing') + + const finalFormatStartTime = Date.now() + const finalPayload: any = { + model: payload.model, + messages: currentMessages, + response_format: responseFormatPayload, + // Keep tools defined (Anthropic requires it once history holds tool results) and + // disable parallel calls (OpenAI's rule for strict outputs alongside tools). + tools: payload.tools, + tool_choice: 'auto', + parallel_tool_calls: false, + } + if (request.temperature !== undefined) finalPayload.temperature = request.temperature + if (request.maxTokens != null) finalPayload.max_completion_tokens = request.maxTokens + + currentResponse = await litellm.chat.completions.create( + finalPayload, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) + + const finalFormatEndTime = Date.now() + timeSegments.push({ + type: 'model', + name: request.model, + startTime: finalFormatStartTime, + endTime: finalFormatEndTime, + duration: finalFormatEndTime - finalFormatStartTime, + }) + modelTime += finalFormatEndTime - finalFormatStartTime + + const formattedContent = currentResponse.choices[0]?.message?.content + if (formattedContent) { + content = formattedContent.replace(/```json\n?|\n?```/g, '').trim() + } + + if (currentResponse.usage) { + tokens.input += currentResponse.usage.prompt_tokens || 0 + tokens.output += currentResponse.usage.completion_tokens || 0 + tokens.total += currentResponse.usage.total_tokens || 0 + } + + enrichLastModelSegmentFromChatCompletions( + timeSegments, + currentResponse, + currentResponse.choices[0]?.message?.tool_calls, + { model: request.model, provider: 'litellm' } + ) + } + const providerEndTime = Date.now() const providerEndTimeISO = new Date(providerEndTime).toISOString() const totalDuration = providerEndTime - providerStartTime @@ -660,7 +760,7 @@ export const litellmProvider: ProviderConfig = { let errorMessage = toError(error).message let errorType: string | undefined - let errorCode: number | undefined + let errorCode: string | number | undefined if (error && typeof error === 'object' && 'error' in error) { const litellmError = error.error as any diff --git a/apps/sim/providers/ollama/index.test.ts b/apps/sim/providers/ollama/index.test.ts new file mode 100644 index 00000000000..3b409ec8216 --- /dev/null +++ b/apps/sim/providers/ollama/index.test.ts @@ -0,0 +1,347 @@ +/** + * @vitest-environment node + */ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +type StreamUsage = { prompt_tokens: number; completion_tokens: number; total_tokens: number } + +const { mockCreate, mockExecuteTool, streamOnComplete, MockAPIError } = vi.hoisted(() => { + class MockAPIError extends Error { + status?: number + code?: string | null + type?: string + constructor(message: string, opts: { status?: number; code?: string; type?: string } = {}) { + super(message) + this.name = 'APIError' + this.status = opts.status + this.code = opts.code + this.type = opts.type + } + } + return { + mockCreate: vi.fn(), + mockExecuteTool: vi.fn(), + streamOnComplete: { + current: undefined as undefined | ((content: string, usage: StreamUsage) => void), + }, + MockAPIError, + } +}) + +vi.mock('openai', () => { + const OpenAI = vi.fn(() => ({ chat: { completions: { create: mockCreate } } })) + ;(OpenAI as unknown as { APIError: typeof MockAPIError }).APIError = MockAPIError + return { default: OpenAI } +}) + +vi.mock('@/lib/core/utils/urls', () => ({ getOllamaUrl: () => 'http://localhost:11434' })) +vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 20 })) +vi.mock('@/providers/attachments', () => ({ + formatMessagesForProvider: (messages: unknown) => messages, +})) +vi.mock('@/providers/trace-enrichment', () => ({ + enrichLastModelSegmentFromChatCompletions: vi.fn(), +})) +vi.mock('@/providers/ollama/utils', () => ({ + createReadableStreamFromOllamaStream: ( + _stream: unknown, + onComplete: (content: string, usage: StreamUsage) => void + ) => { + streamOnComplete.current = onComplete + return 'OLLAMA_STREAM' + }, +})) +vi.mock('@/providers/utils', () => ({ + calculateCost: () => ({ input: 0, output: 0, total: 0, pricing: null }), + prepareToolExecution: (_tool: unknown, args: Record) => ({ + toolParams: args, + executionParams: args, + }), + sumToolCosts: () => 0, +})) +vi.mock('@/tools', () => ({ executeTool: mockExecuteTool })) +vi.mock('@/stores/providers', () => ({ + useProvidersStore: { getState: () => ({ setProviderModels: vi.fn() }) }, +})) + +import { ollamaProvider } from '@/providers/ollama' +import type { ProviderRequest, ProviderResponse, ProviderToolConfig } from '@/providers/types' + +interface StreamingResult { + stream: string + execution: { + output: { + content: string + tokens: { input: number; output: number; total: number } + toolCalls?: { list: unknown[]; count: number } + } + } +} + +type ToolCallChunk = { id: string; type: 'function'; function: { name: string; arguments: string } } + +function completion( + opts: { content?: string | null; toolCalls?: ToolCallChunk[]; usage?: StreamUsage } = {} +) { + return { + choices: [{ message: { content: opts.content ?? null, tool_calls: opts.toolCalls } }], + usage: opts.usage ?? { prompt_tokens: 5, completion_tokens: 3, total_tokens: 8 }, + } +} + +function makeTool(id: string, usageControl?: 'auto' | 'force' | 'none'): ProviderToolConfig { + return { + id, + name: id, + description: `${id} tool`, + params: {}, + parameters: { type: 'object', properties: {}, required: [] }, + ...(usageControl ? { usageControl } : {}), + } +} + +const baseRequest: ProviderRequest = { + model: 'llama3.2', + messages: [{ role: 'user', content: 'hi' }], +} + +describe('ollamaProvider.executeRequest', () => { + beforeEach(() => { + vi.clearAllMocks() + streamOnComplete.current = undefined + mockCreate.mockResolvedValue(completion({ content: 'hello' })) + mockExecuteTool.mockResolvedValue({ success: true, output: { ok: true } }) + }) + + it('assembles system, context, then history in order and forwards params', async () => { + const result = (await ollamaProvider.executeRequest({ + ...baseRequest, + systemPrompt: 'be nice', + context: 'ctx', + temperature: 0.5, + maxTokens: 128, + })) as ProviderResponse + + expect(result).toMatchObject({ content: 'hello', model: 'llama3.2' }) + const payload = mockCreate.mock.calls[0][0] + expect(payload.messages).toEqual([ + { role: 'system', content: 'be nice' }, + { role: 'user', content: 'ctx' }, + { role: 'user', content: 'hi' }, + ]) + expect(payload.model).toBe('llama3.2') + expect(payload.temperature).toBe(0.5) + expect(payload.max_tokens).toBe(128) + }) + + it('returns content verbatim (keeps ```json fences) when no responseFormat', async () => { + const fenced = '```json\n{"a":1}\n```' + mockCreate.mockResolvedValue(completion({ content: fenced })) + const result = (await ollamaProvider.executeRequest(baseRequest)) as ProviderResponse + expect(result.content).toBe(fenced) + }) + + it('strips ```json fences and sends a json_schema response_format when requested', async () => { + mockCreate.mockResolvedValue(completion({ content: '```json\n{"a":1}\n```' })) + const result = (await ollamaProvider.executeRequest({ + ...baseRequest, + responseFormat: { name: 'r', schema: { type: 'object' }, strict: true }, + })) as ProviderResponse + expect(result.content).toBe('{"a":1}') + expect(mockCreate.mock.calls[0][0].response_format).toMatchObject({ + type: 'json_schema', + json_schema: { name: 'r', schema: { type: 'object' }, strict: true }, + }) + }) + + it('runs the tool loop: parses string args, feeds results back, then terminates', async () => { + mockCreate + .mockResolvedValueOnce( + completion({ + toolCalls: [ + { id: 'call_1', type: 'function', function: { name: 'mytool', arguments: '{"x":1}' } }, + ], + }) + ) + .mockResolvedValueOnce(completion({ content: 'done' })) + + const result = (await ollamaProvider.executeRequest({ + ...baseRequest, + tools: [makeTool('mytool')], + })) as ProviderResponse + + expect(mockExecuteTool).toHaveBeenCalledWith('mytool', { x: 1 }, expect.anything()) + expect(mockCreate).toHaveBeenCalledTimes(2) + expect(result.content).toBe('done') + expect(result.toolCalls).toEqual([ + expect.objectContaining({ name: 'mytool', success: true, arguments: { x: 1 } }), + ]) + expect(result.toolResults).toEqual([{ ok: true }]) + + const followUp = mockCreate.mock.calls[1][0].messages + expect(followUp).toContainEqual( + expect.objectContaining({ + role: 'assistant', + content: null, + tool_calls: [ + expect.objectContaining({ + id: 'call_1', + function: { name: 'mytool', arguments: '{"x":1}' }, + }), + ], + }) + ) + expect(followUp).toContainEqual({ + role: 'tool', + tool_call_id: 'call_1', + content: JSON.stringify({ ok: true }), + }) + }) + + it('records a failed tool result without aborting the loop', async () => { + mockExecuteTool.mockResolvedValue({ success: false, error: 'boom' }) + mockCreate + .mockResolvedValueOnce( + completion({ + toolCalls: [ + { id: 'call_1', type: 'function', function: { name: 'mytool', arguments: '{}' } }, + ], + }) + ) + .mockResolvedValueOnce(completion({ content: 'recovered' })) + + const result = (await ollamaProvider.executeRequest({ + ...baseRequest, + tools: [makeTool('mytool')], + })) as ProviderResponse + + expect(result.content).toBe('recovered') + expect(result.toolCalls?.[0]).toMatchObject({ name: 'mytool', success: false }) + const toolMsg = mockCreate.mock.calls[1][0].messages.find( + (m: { role: string }) => m.role === 'tool' + ) + expect(JSON.parse(toolMsg.content)).toMatchObject({ error: true, message: 'boom' }) + }) + + it('executes parallel tool calls from a single response', async () => { + mockExecuteTool + .mockResolvedValueOnce({ success: true, output: { from: 'a' } }) + .mockResolvedValueOnce({ success: true, output: { from: 'b' } }) + mockCreate + .mockResolvedValueOnce( + completion({ + toolCalls: [ + { id: 'call_a', type: 'function', function: { name: 'a', arguments: '{}' } }, + { id: 'call_b', type: 'function', function: { name: 'b', arguments: '{}' } }, + ], + }) + ) + .mockResolvedValueOnce(completion({ content: 'summary' })) + + const result = (await ollamaProvider.executeRequest({ + ...baseRequest, + tools: [makeTool('a'), makeTool('b')], + })) as ProviderResponse + + expect(mockExecuteTool).toHaveBeenCalledTimes(2) + expect(result.toolCalls?.map((c) => c.name)).toEqual(['a', 'b']) + const toolMsgs = mockCreate.mock.calls[1][0].messages.filter( + (m: { role: string }) => m.role === 'tool' + ) + expect(toolMsgs.map((m: { tool_call_id: string }) => m.tool_call_id)).toEqual([ + 'call_a', + 'call_b', + ]) + }) + + it('filters out tools with usageControl "none"', async () => { + await ollamaProvider.executeRequest({ + ...baseRequest, + tools: [makeTool('keep'), makeTool('drop', 'none')], + }) + const sent = mockCreate.mock.calls[0][0].tools + expect(sent.map((t: { function: { name: string } }) => t.function.name)).toEqual(['keep']) + }) + + it('never forces tools (Ollama ignores tool_choice) and keeps "auto"', async () => { + await ollamaProvider.executeRequest({ ...baseRequest, tools: [makeTool('forced', 'force')] }) + const payload = mockCreate.mock.calls[0][0] + expect(payload.tool_choice).toBe('auto') + expect(payload.tools.map((t: { function: { name: string } }) => t.function.name)).toEqual([ + 'forced', + ]) + }) + + it('surfaces an OpenAI APIError message through ProviderError', async () => { + mockCreate.mockRejectedValue( + new MockAPIError('model not found', { + status: 404, + code: 'not_found', + type: 'invalid_request_error', + }) + ) + await expect(ollamaProvider.executeRequest(baseRequest)).rejects.toThrow('model not found') + }) + + it('streams content and usage when no tools are used', async () => { + const result = (await ollamaProvider.executeRequest({ + ...baseRequest, + stream: true, + })) as unknown as StreamingResult + + expect(result.stream).toBe('OLLAMA_STREAM') + expect(mockCreate.mock.calls[0][0].stream_options).toEqual({ include_usage: true }) + + streamOnComplete.current?.('streamed text', { + prompt_tokens: 4, + completion_tokens: 6, + total_tokens: 10, + }) + expect(result.execution.output.content).toBe('streamed text') + expect(result.execution.output.tokens).toMatchObject({ input: 4, output: 6, total: 10 }) + }) + + it('strips ```json fences from streamed content when responseFormat is set', async () => { + const result = (await ollamaProvider.executeRequest({ + ...baseRequest, + stream: true, + responseFormat: { name: 'r', schema: { type: 'object' }, strict: true }, + })) as unknown as StreamingResult + + streamOnComplete.current?.('```json\n{"a":1}\n```', { + prompt_tokens: 1, + completion_tokens: 2, + total_tokens: 3, + }) + expect(result.execution.output.content).toBe('{"a":1}') + }) + + it('streams the final response after a tool loop, carrying tool calls', async () => { + mockCreate + .mockResolvedValueOnce( + completion({ + toolCalls: [ + { id: 'call_1', type: 'function', function: { name: 'mytool', arguments: '{}' } }, + ], + }) + ) + .mockResolvedValueOnce(completion({ content: 'intermediate' })) + + const result = (await ollamaProvider.executeRequest({ + ...baseRequest, + stream: true, + tools: [makeTool('mytool')], + })) as unknown as StreamingResult + + expect(result.stream).toBe('OLLAMA_STREAM') + expect(mockExecuteTool).toHaveBeenCalledTimes(1) + + streamOnComplete.current?.('final answer', { + prompt_tokens: 2, + completion_tokens: 4, + total_tokens: 6, + }) + expect(result.execution.output.content).toBe('final answer') + expect(result.execution.output.toolCalls).toMatchObject({ count: 1 }) + }) +}) diff --git a/apps/sim/providers/ollama/index.ts b/apps/sim/providers/ollama/index.ts index 52332aecdb2..3abb1ea958d 100644 --- a/apps/sim/providers/ollama/index.ts +++ b/apps/sim/providers/ollama/index.ts @@ -1,5 +1,5 @@ import { createLogger } from '@sim/logger' -import { getErrorMessage, toError } from '@sim/utils/errors' +import { getErrorMessage } from '@sim/utils/errors' import OpenAI from 'openai' import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions' import { getOllamaUrl } from '@/lib/core/utils/urls' @@ -10,6 +10,7 @@ import type { ModelsObject } from '@/providers/ollama/types' import { createReadableStreamFromOllamaStream } from '@/providers/ollama/utils' import { enrichLastModelSegmentFromChatCompletions } from '@/providers/trace-enrichment' import type { + Message, ProviderConfig, ProviderRequest, ProviderResponse, @@ -73,7 +74,7 @@ export const ollamaProvider: ProviderConfig = { baseURL: `${OLLAMA_HOST}/v1`, }) - const allMessages = [] + const allMessages: Message[] = [] if (request.systemPrompt) { allMessages.push({ @@ -92,7 +93,7 @@ export const ollamaProvider: ProviderConfig = { if (request.messages) { allMessages.push(...request.messages) } - const formattedMessages = formatMessagesForProvider(allMessages, 'ollama') + const formattedMessages = formatMessagesForProvider(allMessages, 'ollama') as Message[] const tools = request.tools?.length ? request.tools.map((tool) => ({ @@ -180,7 +181,7 @@ export const ollamaProvider: ProviderConfig = { stream: createReadableStreamFromOllamaStream(streamResponse, (content, usage) => { streamingResult.execution.output.content = content - if (content) { + if (content && request.responseFormat) { streamingResult.execution.output.content = content .replace(/```json\n?|\n?```/g, '') .trim() @@ -264,7 +265,7 @@ export const ollamaProvider: ProviderConfig = { let content = currentResponse.choices[0]?.message?.content || '' - if (content) { + if (content && request.responseFormat) { content = content.replace(/```json\n?|\n?```/g, '') content = content.trim() } @@ -295,6 +296,9 @@ export const ollamaProvider: ProviderConfig = { while (iterationCount < MAX_TOOL_ITERATIONS) { if (currentResponse.choices[0]?.message?.content) { content = currentResponse.choices[0].message.content + if (request.responseFormat) { + content = content.replace(/```json\n?|\n?```/g, '').trim() + } } const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls @@ -450,8 +454,9 @@ export const ollamaProvider: ProviderConfig = { if (currentResponse.choices[0]?.message?.content) { content = currentResponse.choices[0].message.content - content = content.replace(/```json\n?|\n?```/g, '') - content = content.trim() + if (request.responseFormat) { + content = content.replace(/```json\n?|\n?```/g, '').trim() + } } if (currentResponse.usage) { @@ -493,7 +498,7 @@ export const ollamaProvider: ProviderConfig = { stream: createReadableStreamFromOllamaStream(streamResponse, (content, usage) => { streamingResult.execution.output.content = content - if (content) { + if (content && request.responseFormat) { streamingResult.execution.output.content = content .replace(/```json\n?|\n?```/g, '') .trim() @@ -589,12 +594,27 @@ export const ollamaProvider: ProviderConfig = { const providerEndTimeISO = new Date(providerEndTime).toISOString() const totalDuration = providerEndTime - providerStartTime + let errorMessage = getErrorMessage(error, 'Unknown error') + let errorType: string | undefined + let errorCode: string | undefined + let status: number | undefined + + if (error instanceof OpenAI.APIError) { + errorMessage = error.message + errorType = error.type + errorCode = error.code ?? undefined + status = error.status + } + logger.error('Error in Ollama request:', { - error, + error: errorMessage, + errorType, + errorCode, + status, duration: totalDuration, }) - throw new ProviderError(toError(error).message, { + throw new ProviderError(errorMessage, { startTime: providerStartTimeISO, endTime: providerEndTimeISO, duration: totalDuration, @@ -602,8 +622,3 @@ export const ollamaProvider: ProviderConfig = { } }, } - -/** - * Enriches the last model segment with per-iteration content from a Chat - * Completions response: assistant text, tool calls, finish reason, token usage. - */ diff --git a/apps/sim/providers/openai/core.ts b/apps/sim/providers/openai/core.ts index 6946f0c0fa3..6f19cef1562 100644 --- a/apps/sim/providers/openai/core.ts +++ b/apps/sim/providers/openai/core.ts @@ -8,6 +8,7 @@ import type { Message, ProviderRequest, ProviderResponse, TimeSegment } from '@/ import { ProviderError } from '@/providers/types' import { calculateCost, + enforceStrictSchema, prepareToolExecution, prepareToolsWithUsageControl, sumToolCosts, @@ -31,60 +32,6 @@ import { type PreparedTools = ReturnType type ToolChoice = PreparedTools['toolChoice'] -/** - * Recursively enforces OpenAI strict mode requirements on a JSON schema. - * - Sets additionalProperties: false on all object types. - * - Ensures required includes ALL property keys. - */ -function enforceStrictSchema(schema: Record): Record { - if (!schema || typeof schema !== 'object') return schema - - const result = { ...schema } - - // If this is an object type, enforce strict requirements - if (result.type === 'object') { - result.additionalProperties = false - - // Recursively process properties and ensure required includes all keys - if (result.properties && typeof result.properties === 'object') { - const propKeys = Object.keys(result.properties as Record) - result.required = propKeys // Strict mode requires ALL properties - result.properties = Object.fromEntries( - Object.entries(result.properties as Record).map(([key, value]) => [ - key, - enforceStrictSchema(value as Record), - ]) - ) - } - } - - // Handle array items - if (result.type === 'array' && result.items) { - result.items = enforceStrictSchema(result.items as Record) - } - - // Handle anyOf, oneOf, allOf - for (const keyword of ['anyOf', 'oneOf', 'allOf']) { - if (Array.isArray(result[keyword])) { - result[keyword] = (result[keyword] as Record[]).map(enforceStrictSchema) - } - } - - // Handle $defs / definitions - for (const defKey of ['$defs', 'definitions']) { - if (result[defKey] && typeof result[defKey] === 'object') { - result[defKey] = Object.fromEntries( - Object.entries(result[defKey] as Record).map(([key, value]) => [ - key, - enforceStrictSchema(value as Record), - ]) - ) - } - } - - return result -} - export interface ResponsesProviderConfig { providerId: string providerLabel: string diff --git a/apps/sim/providers/openrouter/index.test.ts b/apps/sim/providers/openrouter/index.test.ts new file mode 100644 index 00000000000..88339fe4a93 --- /dev/null +++ b/apps/sim/providers/openrouter/index.test.ts @@ -0,0 +1,345 @@ +/** + * @vitest-environment node + */ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { + mockCreate, + mockExecuteTool, + mockSupportsNative, + mockPrepareTools, + mockCheckForced, + mockCreateStream, +} = vi.hoisted(() => ({ + mockCreate: vi.fn(), + mockExecuteTool: vi.fn(), + mockSupportsNative: vi.fn(), + mockPrepareTools: vi.fn((tools: unknown) => ({ + tools, + toolChoice: 'auto', + forcedTools: [], + hasFilteredTools: false, + })), + mockCheckForced: vi.fn(() => ({ hasUsedForcedTool: false, usedForcedTools: [] })), + mockCreateStream: vi.fn(), +})) + +vi.mock('openai', () => ({ + default: vi.fn().mockImplementation(() => ({ + chat: { completions: { create: mockCreate } }, + })), +})) + +vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 10 })) + +vi.mock('@/tools', () => ({ executeTool: mockExecuteTool })) + +vi.mock('@/providers/models', () => ({ + getProviderModels: vi.fn().mockReturnValue([]), + getProviderDefaultModel: vi.fn().mockReturnValue(''), +})) + +vi.mock('@/providers/attachments', () => ({ + formatMessagesForProvider: vi.fn((messages: unknown) => messages), +})) + +vi.mock('@/providers/openrouter/utils', () => ({ + supportsNativeStructuredOutputs: mockSupportsNative, + createReadableStreamFromOpenAIStream: mockCreateStream, + checkForForcedToolUsage: mockCheckForced, +})) + +vi.mock('@/providers/trace-enrichment', () => ({ + enrichLastModelSegmentFromChatCompletions: vi.fn(), +})) + +vi.mock('@/providers/utils', () => ({ + calculateCost: vi.fn(() => ({ input: 0, output: 0, total: 0 })), + prepareToolsWithUsageControl: mockPrepareTools, + prepareToolExecution: vi.fn((_tool: unknown, toolArgs: Record) => ({ + toolParams: toolArgs, + executionParams: toolArgs, + })), + sumToolCosts: vi.fn(() => 0), + generateSchemaInstructions: vi.fn(() => 'SCHEMA_INSTRUCTIONS'), +})) + +import { openRouterProvider } from '@/providers/openrouter/index' +import type { ProviderRequest, ProviderResponse, ProviderToolConfig } from '@/providers/types' + +interface Usage { + prompt_tokens: number + completion_tokens: number + total_tokens: number +} + +function textResponse( + content: string, + usage: Usage = { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 } +) { + return { + choices: [{ message: { content, tool_calls: undefined }, finish_reason: 'stop' }], + usage, + } +} + +function toolCallResponse(name: string, args: Record, id = 'call_1') { + return { + choices: [ + { + message: { + content: null, + tool_calls: [ + { id, type: 'function', function: { name, arguments: JSON.stringify(args) } }, + ], + }, + finish_reason: 'tool_calls', + }, + ], + usage: { prompt_tokens: 8, completion_tokens: 4, total_tokens: 12 }, + } +} + +function tool(id: string): ProviderToolConfig { + return { + id, + name: id, + description: 'test tool', + params: {}, + parameters: { type: 'object', properties: {}, required: [] }, + } +} + +const baseRequest: ProviderRequest = { + apiKey: 'sk-or-test', + model: 'openrouter/anthropic/claude-3.5-sonnet', + systemPrompt: 'You are helpful.', + messages: [{ role: 'user', content: 'Hello' }], +} + +describe('openRouterProvider.executeRequest', () => { + beforeEach(() => { + vi.clearAllMocks() + mockCreate.mockReset() + mockExecuteTool.mockReset() + mockSupportsNative.mockResolvedValue(false) + }) + + it('requires an API key', async () => { + await expect( + openRouterProvider.executeRequest({ model: 'openrouter/x', messages: [] }) + ).rejects.toThrow('API key is required for OpenRouter') + }) + + it('strips the openrouter/ prefix and returns content + tokens', async () => { + mockCreate.mockResolvedValueOnce(textResponse('Hi there')) + + const res = (await openRouterProvider.executeRequest(baseRequest)) as ProviderResponse + + expect(res.content).toBe('Hi there') + expect(res.model).toBe('anthropic/claude-3.5-sonnet') + expect(res.tokens).toEqual({ input: 10, output: 5, total: 15 }) + + const payload = mockCreate.mock.calls[0][0] + expect(payload.model).toBe('anthropic/claude-3.5-sonnet') + expect(payload.messages[0]).toEqual({ role: 'system', content: 'You are helpful.' }) + expect(payload.messages.at(-1)).toEqual({ role: 'user', content: 'Hello' }) + }) + + it('inserts context as a user message between system and history', async () => { + mockCreate.mockResolvedValueOnce(textResponse('ok')) + + await openRouterProvider.executeRequest({ ...baseRequest, context: 'CTX' }) + + const { messages } = mockCreate.mock.calls[0][0] + expect(messages[0]).toEqual({ role: 'system', content: 'You are helpful.' }) + expect(messages[1]).toEqual({ role: 'user', content: 'CTX' }) + expect(messages[2]).toEqual({ role: 'user', content: 'Hello' }) + }) + + it('forwards maxTokens as max_tokens and temperature', async () => { + mockCreate.mockResolvedValueOnce(textResponse('ok')) + + await openRouterProvider.executeRequest({ ...baseRequest, maxTokens: 256, temperature: 0.4 }) + + const payload = mockCreate.mock.calls[0][0] + expect(payload.max_tokens).toBe(256) + expect(payload.temperature).toBe(0.4) + }) + + it('runs the tool loop: executes the tool, echoes tool_calls, returns the tool result, sums tokens', async () => { + mockCreate + .mockResolvedValueOnce(toolCallResponse('get_weather', { city: 'SF' })) + .mockResolvedValueOnce( + textResponse('It is sunny', { prompt_tokens: 20, completion_tokens: 6, total_tokens: 26 }) + ) + mockExecuteTool.mockResolvedValueOnce({ success: true, output: { temp: 70 } }) + + const res = (await openRouterProvider.executeRequest({ + ...baseRequest, + tools: [tool('get_weather')], + })) as ProviderResponse + + expect(mockExecuteTool).toHaveBeenCalledWith('get_weather', { city: 'SF' }, expect.anything()) + expect(res.content).toBe('It is sunny') + expect(res.toolCalls?.[0]).toMatchObject({ + name: 'get_weather', + result: { temp: 70 }, + success: true, + }) + expect(res.toolResults).toEqual([{ temp: 70 }]) + expect(res.tokens).toEqual({ input: 28, output: 10, total: 38 }) + + const secondMessages = mockCreate.mock.calls[1][0].messages + const assistant = secondMessages.find((m: { role: string }) => m.role === 'assistant') + expect(assistant).toMatchObject({ + content: null, + tool_calls: [{ id: 'call_1', type: 'function', function: { name: 'get_weather' } }], + }) + const toolMsg = secondMessages.find((m: { role: string }) => m.role === 'tool') + expect(toolMsg).toEqual({ + role: 'tool', + tool_call_id: 'call_1', + content: JSON.stringify({ temp: 70 }), + }) + }) + + it('reports a failed tool result as an error payload to the model', async () => { + mockCreate + .mockResolvedValueOnce(toolCallResponse('get_weather', { city: 'SF' })) + .mockResolvedValueOnce(textResponse('done')) + mockExecuteTool.mockResolvedValueOnce({ success: false, output: undefined, error: 'boom' }) + + const res = (await openRouterProvider.executeRequest({ + ...baseRequest, + tools: [tool('get_weather')], + })) as ProviderResponse + + expect(res.toolResults).toBeUndefined() + expect(res.toolCalls?.[0]).toMatchObject({ success: false }) + const toolMsg = mockCreate.mock.calls[1][0].messages.find( + (m: { role: string }) => m.role === 'tool' + ) + expect(JSON.parse(toolMsg.content)).toEqual({ + error: true, + message: 'boom', + tool: 'get_weather', + }) + }) + + it('applies native structured outputs (json_schema + require_parameters) when no tools are active', async () => { + mockSupportsNative.mockResolvedValue(true) + mockCreate.mockResolvedValueOnce(textResponse('{"x":1}')) + + await openRouterProvider.executeRequest({ + ...baseRequest, + responseFormat: { + name: 'out', + schema: { type: 'object', properties: { x: { type: 'number' } } }, + strict: true, + }, + }) + + const payload = mockCreate.mock.calls[0][0] + expect(payload.response_format).toMatchObject({ + type: 'json_schema', + json_schema: { name: 'out', strict: true }, + }) + expect(payload.provider).toMatchObject({ require_parameters: true }) + }) + + it('falls back to json_object + prompt instructions when native structured outputs are unsupported', async () => { + mockSupportsNative.mockResolvedValue(false) + mockCreate.mockResolvedValueOnce(textResponse('{"x":1}')) + + await openRouterProvider.executeRequest({ + ...baseRequest, + responseFormat: { name: 'out', schema: { type: 'object' } }, + }) + + const payload = mockCreate.mock.calls[0][0] + expect(payload.response_format).toEqual({ type: 'json_object' }) + expect(payload.messages.at(-1)).toEqual({ role: 'user', content: 'SCHEMA_INSTRUCTIONS' }) + }) + + it('defers response_format until after the tool loop when tools are active', async () => { + mockSupportsNative.mockResolvedValue(true) + mockCreate + .mockResolvedValueOnce(textResponse('interim')) + .mockResolvedValueOnce(textResponse('{"x":1}')) + + const res = (await openRouterProvider.executeRequest({ + ...baseRequest, + tools: [tool('get_weather')], + responseFormat: { name: 'out', schema: { type: 'object' }, strict: true }, + })) as ProviderResponse + + const toolCall = mockCreate.mock.calls[0][0] + expect(toolCall.tools).toBeDefined() + expect(toolCall.response_format).toBeUndefined() + + const finalCall = mockCreate.mock.calls[1][0] + expect(finalCall.response_format).toMatchObject({ type: 'json_schema' }) + expect(finalCall.tools).toBeUndefined() + expect(finalCall.tool_choice).toBeUndefined() + expect(res.content).toBe('{"x":1}') + }) + + it('forces the next tool after a forced tool is used', async () => { + mockPrepareTools.mockReturnValueOnce({ + tools: [tool('a')], + toolChoice: { type: 'function', function: { name: 'a' } }, + forcedTools: ['a', 'b'], + hasFilteredTools: false, + }) + mockCheckForced.mockReturnValueOnce({ hasUsedForcedTool: true, usedForcedTools: ['a'] }) + mockCreate + .mockResolvedValueOnce(toolCallResponse('a', {})) + .mockResolvedValueOnce(textResponse('done')) + mockExecuteTool.mockResolvedValueOnce({ success: true, output: {} }) + + await openRouterProvider.executeRequest({ ...baseRequest, tools: [tool('a'), tool('b')] }) + + expect(mockCreate.mock.calls[0][0].tool_choice).toEqual({ + type: 'function', + function: { name: 'a' }, + }) + expect(mockCreate.mock.calls[1][0].tool_choice).toEqual({ + type: 'function', + function: { name: 'b' }, + }) + }) + + it('streams directly when there are no tools and sends usage opt-in', async () => { + mockCreate.mockResolvedValueOnce({}) + + const res = await openRouterProvider.executeRequest({ ...baseRequest, stream: true }) + + const payload = mockCreate.mock.calls[0][0] + expect(payload.stream).toBe(true) + expect(payload.stream_options).toEqual({ include_usage: true }) + expect(mockCreateStream).toHaveBeenCalledTimes(1) + expect(res).toHaveProperty('stream') + expect(res).toHaveProperty('execution.output.model', 'anthropic/claude-3.5-sonnet') + }) + + it('stops the tool loop at MAX_TOOL_ITERATIONS', async () => { + mockCreate.mockResolvedValue(toolCallResponse('looping', {})) + mockExecuteTool.mockResolvedValue({ success: true, output: {} }) + + const res = (await openRouterProvider.executeRequest({ + ...baseRequest, + tools: [tool('looping')], + })) as ProviderResponse + + expect(mockCreate).toHaveBeenCalledTimes(11) + expect(mockExecuteTool).toHaveBeenCalledTimes(10) + expect(res.toolCalls?.length).toBe(10) + }) + + it('wraps SDK errors in a ProviderError', async () => { + mockCreate.mockRejectedValueOnce(new Error('rate limited')) + + await expect(openRouterProvider.executeRequest(baseRequest)).rejects.toThrow('rate limited') + }) +}) diff --git a/apps/sim/providers/openrouter/index.ts b/apps/sim/providers/openrouter/index.ts index d3d2535b43d..9bc180bdd11 100644 --- a/apps/sim/providers/openrouter/index.ts +++ b/apps/sim/providers/openrouter/index.ts @@ -376,8 +376,8 @@ export const openRouterProvider: ProviderConfig = { }) let resultContent: any - if (result.success) { - toolResults.push(result.output!) + if (result.success && result.output) { + toolResults.push(result.output) resultContent = result.output } else { resultContent = { @@ -653,8 +653,3 @@ export const openRouterProvider: ProviderConfig = { } }, } - -/** - * Enriches the last model segment with per-iteration content from a Chat - * Completions response: assistant text, tool calls, finish reason, token usage. - */ diff --git a/apps/sim/providers/openrouter/utils.ts b/apps/sim/providers/openrouter/utils.ts index 8d8dade8279..51637f5148d 100644 --- a/apps/sim/providers/openrouter/utils.ts +++ b/apps/sim/providers/openrouter/utils.ts @@ -20,9 +20,6 @@ let modelCapabilitiesCache: Map | null = null let cacheTimestamp = 0 const CACHE_TTL_MS = 5 * 60 * 1000 // 5 minutes -/** - * Fetches and caches OpenRouter model capabilities from their API. - */ async function fetchModelCapabilities(): Promise> { try { const response = await fetch('https://openrouter.ai/api/v1/models', { @@ -82,18 +79,11 @@ export async function getOpenRouterModelCapabilities( return modelCapabilitiesCache.get(normalizedId) ?? null } -/** - * Checks if a model supports native structured outputs (json_schema). - */ export async function supportsNativeStructuredOutputs(modelId: string): Promise { const capabilities = await getOpenRouterModelCapabilities(modelId) return capabilities?.supportsStructuredOutputs ?? false } -/** - * Creates a ReadableStream from an OpenRouter streaming response. - * Uses the shared OpenAI-compatible streaming utility. - */ export function createReadableStreamFromOpenAIStream( openaiStream: AsyncIterable, onComplete?: (content: string, usage: CompletionUsage) => void @@ -101,10 +91,6 @@ export function createReadableStreamFromOpenAIStream( return createOpenAICompatibleStream(openaiStream, 'OpenRouter', onComplete) } -/** - * Checks if a forced tool was used in an OpenRouter response. - * Uses the shared OpenAI-compatible forced tool usage helper. - */ export function checkForForcedToolUsage( response: any, toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }, diff --git a/apps/sim/providers/utils.ts b/apps/sim/providers/utils.ts index 205fb307873..22466f8a143 100644 --- a/apps/sim/providers/utils.ts +++ b/apps/sim/providers/utils.ts @@ -669,6 +669,59 @@ export function calculateCost( } } +/** + * Recursively enforces OpenAI strict-mode requirements on a JSON schema: + * - Sets `additionalProperties: false` on every object type. + * - Forces `required` to include ALL property keys. + * + * Required for any OpenAI-compatible backend that validates strict structured + * outputs (OpenAI, Azure OpenAI, and OpenAI routes behind proxies like LiteLLM), + * which reject schemas missing these constraints with an HTTP 400. + */ +export function enforceStrictSchema(schema: Record): Record { + if (!schema || typeof schema !== 'object') return schema + + const result = { ...schema } + + if (result.type === 'object') { + result.additionalProperties = false + + if (result.properties && typeof result.properties === 'object') { + const propKeys = Object.keys(result.properties as Record) + result.required = propKeys + result.properties = Object.fromEntries( + Object.entries(result.properties as Record).map(([key, value]) => [ + key, + enforceStrictSchema(value as Record), + ]) + ) + } + } + + if (result.type === 'array' && result.items) { + result.items = enforceStrictSchema(result.items as Record) + } + + for (const keyword of ['anyOf', 'oneOf', 'allOf']) { + if (Array.isArray(result[keyword])) { + result[keyword] = (result[keyword] as Record[]).map(enforceStrictSchema) + } + } + + for (const defKey of ['$defs', 'definitions']) { + if (result[defKey] && typeof result[defKey] === 'object') { + result[defKey] = Object.fromEntries( + Object.entries(result[defKey] as Record).map(([key, value]) => [ + key, + enforceStrictSchema(value as Record), + ]) + ) + } + } + + return result +} + /** * Sums the `cost.total` from each tool result returned during a provider tool loop. * Tool results may carry a `cost` object injected by `applyHostedKeyCostToResult`. diff --git a/apps/sim/providers/vllm/index.test.ts b/apps/sim/providers/vllm/index.test.ts new file mode 100644 index 00000000000..6727b27b0fe --- /dev/null +++ b/apps/sim/providers/vllm/index.test.ts @@ -0,0 +1,296 @@ +/** + * @vitest-environment node + */ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { + mockCreate, + mockExecuteTool, + mockPrepareTools, + mockCheckForced, + mockCreateStream, + envState, +} = vi.hoisted(() => ({ + mockCreate: vi.fn(), + mockExecuteTool: vi.fn(), + mockPrepareTools: vi.fn(), + mockCheckForced: vi.fn(), + mockCreateStream: vi.fn(), + envState: { + VLLM_BASE_URL: 'http://localhost:8000', + VLLM_API_KEY: undefined as string | undefined, + }, +})) + +vi.mock('openai', () => ({ + default: vi.fn(() => ({ chat: { completions: { create: mockCreate } } })), +})) +vi.mock('@/lib/core/config/env', () => ({ env: envState })) +vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 20 })) +vi.mock('@/providers/models', () => ({ + getProviderModels: vi.fn(() => []), + getProviderDefaultModel: vi.fn(() => 'vllm/generic'), +})) +vi.mock('@/providers/attachments', () => ({ + formatMessagesForProvider: vi.fn((messages) => messages), +})) +vi.mock('@/providers/trace-enrichment', () => ({ + enrichLastModelSegmentFromChatCompletions: vi.fn(), +})) +vi.mock('@/providers/utils', () => ({ + calculateCost: vi.fn(() => ({ input: 0, output: 0, total: 0 })), + prepareToolExecution: vi.fn((_tool, args) => ({ toolParams: args, executionParams: args })), + prepareToolsWithUsageControl: mockPrepareTools, + sumToolCosts: vi.fn(() => 0), +})) +vi.mock('@/providers/vllm/utils', () => ({ + checkForForcedToolUsage: mockCheckForced, + createReadableStreamFromVLLMStream: mockCreateStream, +})) +vi.mock('@/tools', () => ({ executeTool: mockExecuteTool })) +vi.mock('@/stores/providers', () => ({ + useProvidersStore: { getState: () => ({ setProviderModels: vi.fn() }) }, +})) + +import type { ProviderToolConfig } from '@/providers/types' +import { vllmProvider } from '@/providers/vllm/index' + +interface ToolCall { + id: string + type: 'function' + function: { name: string; arguments: string } +} + +function chatResponse(content: string | null, toolCalls?: ToolCall[]) { + return { + choices: [{ message: { content, tool_calls: toolCalls } }], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, + } +} + +function makeTool(id: string): ProviderToolConfig { + return { + id, + name: id, + description: '', + params: {}, + parameters: { type: 'object', properties: {}, required: [] }, + } +} + +const toolCall = (id: string, name: string, args = '{}'): ToolCall => ({ + id, + type: 'function', + function: { name, arguments: args }, +}) + +/** Payload passed to the Nth `chat.completions.create` call. */ +const createPayload = (callIndex: number) => mockCreate.mock.calls[callIndex][0] + +describe('vllmProvider', () => { + beforeEach(() => { + vi.clearAllMocks() + envState.VLLM_BASE_URL = 'http://localhost:8000' + envState.VLLM_API_KEY = undefined + mockPrepareTools.mockReturnValue({ + tools: [{ type: 'function', function: { name: 'myTool' } }], + toolChoice: 'auto', + forcedTools: [], + hasFilteredTools: false, + }) + mockCheckForced.mockReturnValue({ hasUsedForcedTool: false, usedForcedTools: [] }) + mockCreateStream.mockReturnValue(new ReadableStream({ start: (c) => c.close() })) + mockExecuteTool.mockResolvedValue({ success: true, output: { result: 'ok' } }) + }) + + it('builds a chat payload with the vllm/ prefix stripped and messages assembled in order', async () => { + mockCreate.mockResolvedValueOnce(chatResponse('hello')) + + const result = await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + systemPrompt: 'be helpful', + context: 'prior context', + messages: [{ role: 'user', content: 'hi' }], + temperature: 0.7, + maxTokens: 256, + }) + + const payload = createPayload(0) + expect(payload.model).toBe('llama-3') + expect(payload.temperature).toBe(0.7) + expect(payload.max_completion_tokens).toBe(256) + expect(payload.messages.map((m: { role: string }) => m.role)).toEqual([ + 'system', + 'user', + 'user', + ]) + expect(result.content).toBe('hello') + expect(result.tokens).toEqual({ input: 10, output: 5, total: 15 }) + }) + + it('sends response_format as json_schema with strict when a responseFormat is provided', async () => { + mockCreate.mockResolvedValueOnce(chatResponse('{}')) + + await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'hi' }], + responseFormat: { name: 'out', schema: { type: 'object' }, strict: true }, + }) + + expect(createPayload(0).response_format).toEqual({ + type: 'json_schema', + json_schema: { name: 'out', schema: { type: 'object' }, strict: true }, + }) + }) + + it('strips markdown code fences from structured-output content', async () => { + mockCreate.mockResolvedValueOnce(chatResponse('```json\n{"a":1}\n```')) + + const result = await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'hi' }], + responseFormat: { name: 'out', schema: { type: 'object' }, strict: true }, + }) + + expect(result.content).toBe('{"a":1}') + }) + + it('runs the tool loop: executes tools, appends assistant + tool messages, returns results', async () => { + mockCreate + .mockResolvedValueOnce(chatResponse(null, [toolCall('call_1', 'myTool', '{"x":1}')])) + .mockResolvedValueOnce(chatResponse('final answer')) + + const result = await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'use a tool' }], + tools: [makeTool('myTool')], + }) + + expect(mockExecuteTool).toHaveBeenCalledWith('myTool', { x: 1 }, expect.anything()) + + const [assistantMessage, toolMessage] = createPayload(1).messages.slice(-2) + expect(assistantMessage).toMatchObject({ + role: 'assistant', + content: null, + tool_calls: [{ id: 'call_1', type: 'function', function: { name: 'myTool' } }], + }) + expect(toolMessage).toMatchObject({ role: 'tool', tool_call_id: 'call_1' }) + expect(toolMessage).not.toHaveProperty('name') + + expect(result.content).toBe('final answer') + expect(result.toolCalls).toHaveLength(1) + expect(result.toolCalls?.[0]).toMatchObject({ name: 'myTool', success: true }) + expect(result.toolResults).toHaveLength(1) + }) + + it('records a failed tool result without throwing', async () => { + mockExecuteTool.mockResolvedValueOnce({ success: false, error: 'tool blew up' }) + mockCreate + .mockResolvedValueOnce(chatResponse(null, [toolCall('call_1', 'myTool')])) + .mockResolvedValueOnce(chatResponse('done')) + + const result = await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'go' }], + tools: [makeTool('myTool')], + }) + + expect(result.toolCalls?.[0]).toMatchObject({ name: 'myTool', success: false }) + const toolMessage = createPayload(1).messages.at(-1) + expect(JSON.parse(toolMessage.content)).toMatchObject({ error: true, tool: 'myTool' }) + }) + + it('preserves partial results when a follow-up model call fails mid-loop', async () => { + mockCreate + .mockResolvedValueOnce(chatResponse(null, [toolCall('call_1', 'myTool')])) + .mockRejectedValueOnce(new Error('connection reset')) + + const result = await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'go' }], + tools: [makeTool('myTool')], + }) + + expect(mockExecuteTool).toHaveBeenCalledTimes(1) + expect(result.toolCalls).toHaveLength(1) + expect(result.toolResults).toHaveLength(1) + }) + + it('cycles forced tools: forces the next forced tool after the first is used', async () => { + mockPrepareTools.mockReturnValue({ + tools: [{ type: 'function', function: { name: 'toolA' } }], + toolChoice: { type: 'function', function: { name: 'toolA' } }, + forcedTools: ['toolA', 'toolB'], + hasFilteredTools: false, + }) + mockCheckForced + .mockReturnValueOnce({ hasUsedForcedTool: true, usedForcedTools: ['toolA'] }) + .mockReturnValueOnce({ hasUsedForcedTool: true, usedForcedTools: ['toolA', 'toolB'] }) + mockCreate + .mockResolvedValueOnce(chatResponse(null, [toolCall('c1', 'toolA')])) + .mockResolvedValueOnce(chatResponse('done')) + + await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'go' }], + tools: [makeTool('toolA'), makeTool('toolB')], + }) + + expect(createPayload(1).tool_choice).toEqual({ type: 'function', function: { name: 'toolB' } }) + }) + + it('streams directly when there are no tools, requesting usage in the stream', async () => { + mockCreate.mockResolvedValueOnce({}) + + const result = await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'hi' }], + stream: true, + }) + + expect(mockCreate).toHaveBeenCalledTimes(1) + const payload = createPayload(0) + expect(payload.stream).toBe(true) + expect(payload.stream_options).toEqual({ include_usage: true }) + expect('stream' in result && 'execution' in result).toBe(true) + }) + + it('uses tool_choice "none" on the final streaming call after tool processing', async () => { + mockCreate.mockResolvedValueOnce(chatResponse('answer')).mockResolvedValueOnce({}) + + await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'hi' }], + stream: true, + tools: [makeTool('myTool')], + }) + + const streamingPayload = createPayload(1) + expect(streamingPayload.stream).toBe(true) + expect(streamingPayload.tool_choice).toBe('none') + }) + + it('throws a ProviderError carrying the vLLM error message on API failure', async () => { + mockCreate.mockRejectedValueOnce({ + error: { message: 'bad request', type: 'invalid', code: 400 }, + }) + + await expect( + vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'hi' }], + }) + ).rejects.toThrow('bad request') + }) + + it('throws when no base URL is configured', async () => { + envState.VLLM_BASE_URL = '' + + await expect( + vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'hi' }], + }) + ).rejects.toThrow('VLLM_BASE_URL is required') + }) +}) diff --git a/apps/sim/providers/vllm/index.ts b/apps/sim/providers/vllm/index.ts index 2de3c695116..5310a5dd26e 100644 --- a/apps/sim/providers/vllm/index.ts +++ b/apps/sim/providers/vllm/index.ts @@ -21,9 +21,8 @@ import { prepareToolExecution, prepareToolsWithUsageControl, sumToolCosts, - trackForcedToolUsage, } from '@/providers/utils' -import { createReadableStreamFromVLLMStream } from '@/providers/vllm/utils' +import { checkForForcedToolUsage, createReadableStreamFromVLLMStream } from '@/providers/vllm/utils' import { useProvidersStore } from '@/stores/providers' import { executeTool } from '@/tools' @@ -282,25 +281,7 @@ export const vllmProvider: ProviderConfig = { const forcedTools = preparedTools?.forcedTools || [] let usedForcedTools: string[] = [] - - const checkForForcedToolUsage = ( - response: any, - toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any } - ) => { - if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) { - const toolCallsResponse = response.choices[0].message.tool_calls - const result = trackForcedToolUsage( - toolCallsResponse, - toolChoice, - logger, - 'vllm', - forcedTools, - usedForcedTools - ) - hasUsedForcedTool = result.hasUsedForcedTool - usedForcedTools = result.usedForcedTools - } - } + let hasUsedForcedTool = false let currentResponse = await vllm.chat.completions.create( payload, @@ -327,8 +308,6 @@ export const vllmProvider: ProviderConfig = { let modelTime = firstResponseTime let toolsTime = 0 - let hasUsedForcedTool = false - const timeSegments: TimeSegment[] = [ { type: 'model', @@ -339,207 +318,233 @@ export const vllmProvider: ProviderConfig = { }, ] - checkForForcedToolUsage(currentResponse, originalToolChoice) + if (originalToolChoice) { + const forcedResult = checkForForcedToolUsage( + currentResponse, + originalToolChoice, + forcedTools, + usedForcedTools + ) + hasUsedForcedTool = forcedResult.hasUsedForcedTool + usedForcedTools = forcedResult.usedForcedTools + } - while (iterationCount < MAX_TOOL_ITERATIONS) { - if (currentResponse.choices[0]?.message?.content) { - content = currentResponse.choices[0].message.content - if (request.responseFormat) { - content = content.replace(/```json\n?|\n?```/g, '').trim() + try { + while (iterationCount < MAX_TOOL_ITERATIONS) { + if (currentResponse.choices[0]?.message?.content) { + content = currentResponse.choices[0].message.content + if (request.responseFormat) { + content = content.replace(/```json\n?|\n?```/g, '').trim() + } } - } - const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls + const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls - enrichLastModelSegmentFromChatCompletions( - timeSegments, - currentResponse, - toolCallsInResponse, - { model: request.model, provider: 'vllm' } - ) + enrichLastModelSegmentFromChatCompletions( + timeSegments, + currentResponse, + toolCallsInResponse, + { model: request.model, provider: 'vllm' } + ) - if (!toolCallsInResponse || toolCallsInResponse.length === 0) { - break - } + if (!toolCallsInResponse || toolCallsInResponse.length === 0) { + break + } - logger.info( - `Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})` - ) + logger.info( + `Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})` + ) - const toolsStartTime = Date.now() + const toolsStartTime = Date.now() - const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => { - const toolCallStartTime = Date.now() - const toolName = toolCall.function.name + const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => { + const toolCallStartTime = Date.now() + const toolName = toolCall.function.name - try { - const toolArgs = JSON.parse(toolCall.function.arguments) - const tool = request.tools?.find((t) => t.id === toolName) + try { + const toolArgs = JSON.parse(toolCall.function.arguments) + const tool = request.tools?.find((t) => t.id === toolName) - if (!tool) return null + if (!tool) return null - const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams, { - signal: request.abortSignal, - }) - const toolCallEndTime = Date.now() - - return { - toolCall, - toolName, - toolParams, - result, - startTime: toolCallStartTime, - endTime: toolCallEndTime, - duration: toolCallEndTime - toolCallStartTime, - } - } catch (error) { - const toolCallEndTime = Date.now() - logger.error('Error processing tool call:', { error, toolName }) - - return { - toolCall, - toolName, - toolParams: {}, - result: { - success: false, - output: undefined, - error: getErrorMessage(error, 'Tool execution failed'), - }, - startTime: toolCallStartTime, - endTime: toolCallEndTime, - duration: toolCallEndTime - toolCallStartTime, + const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) + const toolCallEndTime = Date.now() + + return { + toolCall, + toolName, + toolParams, + result, + startTime: toolCallStartTime, + endTime: toolCallEndTime, + duration: toolCallEndTime - toolCallStartTime, + } + } catch (error) { + const toolCallEndTime = Date.now() + logger.error('Error processing tool call:', { error, toolName }) + + return { + toolCall, + toolName, + toolParams: {}, + result: { + success: false, + output: undefined, + error: getErrorMessage(error, 'Tool execution failed'), + }, + startTime: toolCallStartTime, + endTime: toolCallEndTime, + duration: toolCallEndTime - toolCallStartTime, + } } - } - }) + }) - const executionResults = await Promise.allSettled(toolExecutionPromises) - - currentMessages.push({ - role: 'assistant', - content: null, - tool_calls: toolCallsInResponse.map((tc) => ({ - id: tc.id, - type: 'function', - function: { - name: tc.function.name, - arguments: tc.function.arguments, - }, - })), - }) + const executionResults = await Promise.allSettled(toolExecutionPromises) - for (const settledResult of executionResults) { - if (settledResult.status === 'rejected' || !settledResult.value) continue + currentMessages.push({ + role: 'assistant', + content: null, + tool_calls: toolCallsInResponse.map((tc) => ({ + id: tc.id, + type: 'function', + function: { + name: tc.function.name, + arguments: tc.function.arguments, + }, + })), + }) - const { toolCall, toolName, toolParams, result, startTime, endTime, duration } = - settledResult.value + for (const settledResult of executionResults) { + if (settledResult.status === 'rejected' || !settledResult.value) continue - timeSegments.push({ - type: 'tool', - name: toolName, - startTime: startTime, - endTime: endTime, - duration: duration, - toolCallId: toolCall.id, - }) + const { toolCall, toolName, toolParams, result, startTime, endTime, duration } = + settledResult.value - let resultContent: any - if (result.success && result.output) { - toolResults.push(result.output) - resultContent = result.output - } else { - resultContent = { - error: true, - message: result.error || 'Tool execution failed', - tool: toolName, - } - } + timeSegments.push({ + type: 'tool', + name: toolName, + startTime: startTime, + endTime: endTime, + duration: duration, + toolCallId: toolCall.id, + }) - toolCalls.push({ - name: toolName, - arguments: toolParams, - startTime: new Date(startTime).toISOString(), - endTime: new Date(endTime).toISOString(), - duration: duration, - result: resultContent, - success: result.success, - }) + let resultContent: any + if (result.success && result.output) { + toolResults.push(result.output) + resultContent = result.output + } else { + resultContent = { + error: true, + message: result.error || 'Tool execution failed', + tool: toolName, + } + } - currentMessages.push({ - role: 'tool', - tool_call_id: toolCall.id, - content: JSON.stringify(resultContent), - }) - } + toolCalls.push({ + name: toolName, + arguments: toolParams, + startTime: new Date(startTime).toISOString(), + endTime: new Date(endTime).toISOString(), + duration: duration, + result: resultContent, + success: result.success, + }) - const thisToolsTime = Date.now() - toolsStartTime - toolsTime += thisToolsTime + currentMessages.push({ + role: 'tool', + tool_call_id: toolCall.id, + content: JSON.stringify(resultContent), + }) + } - const nextPayload = { - ...payload, - messages: currentMessages, - } + const thisToolsTime = Date.now() - toolsStartTime + toolsTime += thisToolsTime - if (typeof originalToolChoice === 'object' && hasUsedForcedTool && forcedTools.length > 0) { - const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool)) + const nextPayload = { + ...payload, + messages: currentMessages, + } - if (remainingTools.length > 0) { - nextPayload.tool_choice = { - type: 'function', - function: { name: remainingTools[0] }, + if ( + typeof originalToolChoice === 'object' && + hasUsedForcedTool && + forcedTools.length > 0 + ) { + const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool)) + + if (remainingTools.length > 0) { + nextPayload.tool_choice = { + type: 'function', + function: { name: remainingTools[0] }, + } + logger.info(`Forcing next tool: ${remainingTools[0]}`) + } else { + nextPayload.tool_choice = 'auto' + logger.info('All forced tools have been used, switching to auto tool_choice') } - logger.info(`Forcing next tool: ${remainingTools[0]}`) - } else { - nextPayload.tool_choice = 'auto' - logger.info('All forced tools have been used, switching to auto tool_choice') } - } - const nextModelStartTime = Date.now() + const nextModelStartTime = Date.now() - currentResponse = await vllm.chat.completions.create( - nextPayload, - request.abortSignal ? { signal: request.abortSignal } : undefined - ) + currentResponse = await vllm.chat.completions.create( + nextPayload, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) - checkForForcedToolUsage(currentResponse, nextPayload.tool_choice) + if (nextPayload.tool_choice && typeof nextPayload.tool_choice === 'object') { + const forcedResult = checkForForcedToolUsage( + currentResponse, + nextPayload.tool_choice, + forcedTools, + usedForcedTools + ) + hasUsedForcedTool = forcedResult.hasUsedForcedTool + usedForcedTools = forcedResult.usedForcedTools + } - const nextModelEndTime = Date.now() - const thisModelTime = nextModelEndTime - nextModelStartTime + const nextModelEndTime = Date.now() + const thisModelTime = nextModelEndTime - nextModelStartTime - timeSegments.push({ - type: 'model', - name: request.model, - startTime: nextModelStartTime, - endTime: nextModelEndTime, - duration: thisModelTime, - }) + timeSegments.push({ + type: 'model', + name: request.model, + startTime: nextModelStartTime, + endTime: nextModelEndTime, + duration: thisModelTime, + }) - modelTime += thisModelTime + modelTime += thisModelTime - if (currentResponse.choices[0]?.message?.content) { - content = currentResponse.choices[0].message.content - if (request.responseFormat) { - content = content.replace(/```json\n?|\n?```/g, '').trim() + if (currentResponse.choices[0]?.message?.content) { + content = currentResponse.choices[0].message.content + if (request.responseFormat) { + content = content.replace(/```json\n?|\n?```/g, '').trim() + } } - } - if (currentResponse.usage) { - tokens.input += currentResponse.usage.prompt_tokens || 0 - tokens.output += currentResponse.usage.completion_tokens || 0 - tokens.total += currentResponse.usage.total_tokens || 0 - } + if (currentResponse.usage) { + tokens.input += currentResponse.usage.prompt_tokens || 0 + tokens.output += currentResponse.usage.completion_tokens || 0 + tokens.total += currentResponse.usage.total_tokens || 0 + } - iterationCount++ - } + iterationCount++ + } - if (iterationCount === MAX_TOOL_ITERATIONS) { - enrichLastModelSegmentFromChatCompletions( - timeSegments, - currentResponse, - currentResponse.choices[0]?.message?.tool_calls, - { model: request.model, provider: 'vllm' } - ) + if (iterationCount === MAX_TOOL_ITERATIONS) { + enrichLastModelSegmentFromChatCompletions( + timeSegments, + currentResponse, + currentResponse.choices[0]?.message?.tool_calls, + { model: request.model, provider: 'vllm' } + ) + } + } catch (error) { + logger.error('Error in vLLM tool processing:', { error }) } if (request.stream) { @@ -550,7 +555,7 @@ export const vllmProvider: ProviderConfig = { const streamingParams: ChatCompletionCreateParamsStreaming = { ...payload, messages: currentMessages, - tool_choice: 'auto', + tool_choice: 'none', stream: true, stream_options: { include_usage: true }, } @@ -685,8 +690,3 @@ export const vllmProvider: ProviderConfig = { } }, } - -/** - * Enriches the last model segment with per-iteration content from a Chat - * Completions response: assistant text, tool calls, finish reason, token usage. - */ diff --git a/apps/sim/providers/vllm/utils.ts b/apps/sim/providers/vllm/utils.ts index 6f433291488..2b1db5bf553 100644 --- a/apps/sim/providers/vllm/utils.ts +++ b/apps/sim/providers/vllm/utils.ts @@ -1,6 +1,6 @@ import type { ChatCompletionChunk } from 'openai/resources/chat/completions' import type { CompletionUsage } from 'openai/resources/completions' -import { createOpenAICompatibleStream } from '@/providers/utils' +import { checkForForcedToolUsageOpenAI, createOpenAICompatibleStream } from '@/providers/utils' /** * Creates a ReadableStream from a vLLM streaming response. @@ -12,3 +12,16 @@ export function createReadableStreamFromVLLMStream( ): ReadableStream { return createOpenAICompatibleStream(vllmStream, 'vLLM', onComplete) } + +/** + * Checks if a forced tool was used in a vLLM response. + * Uses the shared OpenAI-compatible forced tool usage helper. + */ +export function checkForForcedToolUsage( + response: any, + toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }, + forcedTools: string[], + usedForcedTools: string[] +): { hasUsedForcedTool: boolean; usedForcedTools: string[] } { + return checkForForcedToolUsageOpenAI(response, toolChoice, 'vLLM', forcedTools, usedForcedTools) +} From 4e286a213f10cbdd8d81053e334a00ac84a9dc11 Mon Sep 17 00:00:00 2001 From: Waleed Latif Date: Fri, 29 May 2026 13:14:12 -0700 Subject: [PATCH 2/9] fix(vllm): let tool-loop errors propagate instead of returning silent partial success --- apps/sim/providers/vllm/index.test.ts | 16 +- apps/sim/providers/vllm/index.ts | 354 +++++++++++++------------- 2 files changed, 181 insertions(+), 189 deletions(-) diff --git a/apps/sim/providers/vllm/index.test.ts b/apps/sim/providers/vllm/index.test.ts index 6727b27b0fe..4477beeeda7 100644 --- a/apps/sim/providers/vllm/index.test.ts +++ b/apps/sim/providers/vllm/index.test.ts @@ -200,20 +200,20 @@ describe('vllmProvider', () => { expect(JSON.parse(toolMessage.content)).toMatchObject({ error: true, tool: 'myTool' }) }) - it('preserves partial results when a follow-up model call fails mid-loop', async () => { + it('surfaces a ProviderError when a follow-up model call fails mid-loop', async () => { mockCreate .mockResolvedValueOnce(chatResponse(null, [toolCall('call_1', 'myTool')])) .mockRejectedValueOnce(new Error('connection reset')) - const result = await vllmProvider.executeRequest({ - model: 'vllm/llama-3', - messages: [{ role: 'user', content: 'go' }], - tools: [makeTool('myTool')], - }) + await expect( + vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'go' }], + tools: [makeTool('myTool')], + }) + ).rejects.toThrow('connection reset') expect(mockExecuteTool).toHaveBeenCalledTimes(1) - expect(result.toolCalls).toHaveLength(1) - expect(result.toolResults).toHaveLength(1) }) it('cycles forced tools: forces the next forced tool after the first is used', async () => { diff --git a/apps/sim/providers/vllm/index.ts b/apps/sim/providers/vllm/index.ts index 5310a5dd26e..87610d9c43e 100644 --- a/apps/sim/providers/vllm/index.ts +++ b/apps/sim/providers/vllm/index.ts @@ -329,222 +329,214 @@ export const vllmProvider: ProviderConfig = { usedForcedTools = forcedResult.usedForcedTools } - try { - while (iterationCount < MAX_TOOL_ITERATIONS) { - if (currentResponse.choices[0]?.message?.content) { - content = currentResponse.choices[0].message.content - if (request.responseFormat) { - content = content.replace(/```json\n?|\n?```/g, '').trim() - } + while (iterationCount < MAX_TOOL_ITERATIONS) { + if (currentResponse.choices[0]?.message?.content) { + content = currentResponse.choices[0].message.content + if (request.responseFormat) { + content = content.replace(/```json\n?|\n?```/g, '').trim() } + } - const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls - - enrichLastModelSegmentFromChatCompletions( - timeSegments, - currentResponse, - toolCallsInResponse, - { model: request.model, provider: 'vllm' } - ) + const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls - if (!toolCallsInResponse || toolCallsInResponse.length === 0) { - break - } + enrichLastModelSegmentFromChatCompletions( + timeSegments, + currentResponse, + toolCallsInResponse, + { model: request.model, provider: 'vllm' } + ) - logger.info( - `Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})` - ) + if (!toolCallsInResponse || toolCallsInResponse.length === 0) { + break + } - const toolsStartTime = Date.now() + logger.info( + `Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})` + ) - const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => { - const toolCallStartTime = Date.now() - const toolName = toolCall.function.name + const toolsStartTime = Date.now() - try { - const toolArgs = JSON.parse(toolCall.function.arguments) - const tool = request.tools?.find((t) => t.id === toolName) + const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => { + const toolCallStartTime = Date.now() + const toolName = toolCall.function.name - if (!tool) return null + try { + const toolArgs = JSON.parse(toolCall.function.arguments) + const tool = request.tools?.find((t) => t.id === toolName) - const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams, { - signal: request.abortSignal, - }) - const toolCallEndTime = Date.now() + if (!tool) return null - return { - toolCall, - toolName, - toolParams, - result, - startTime: toolCallStartTime, - endTime: toolCallEndTime, - duration: toolCallEndTime - toolCallStartTime, - } - } catch (error) { - const toolCallEndTime = Date.now() - logger.error('Error processing tool call:', { error, toolName }) - - return { - toolCall, - toolName, - toolParams: {}, - result: { - success: false, - output: undefined, - error: getErrorMessage(error, 'Tool execution failed'), - }, - startTime: toolCallStartTime, - endTime: toolCallEndTime, - duration: toolCallEndTime - toolCallStartTime, - } + const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) + const toolCallEndTime = Date.now() + + return { + toolCall, + toolName, + toolParams, + result, + startTime: toolCallStartTime, + endTime: toolCallEndTime, + duration: toolCallEndTime - toolCallStartTime, } - }) - - const executionResults = await Promise.allSettled(toolExecutionPromises) - - currentMessages.push({ - role: 'assistant', - content: null, - tool_calls: toolCallsInResponse.map((tc) => ({ - id: tc.id, - type: 'function', - function: { - name: tc.function.name, - arguments: tc.function.arguments, + } catch (error) { + const toolCallEndTime = Date.now() + logger.error('Error processing tool call:', { error, toolName }) + + return { + toolCall, + toolName, + toolParams: {}, + result: { + success: false, + output: undefined, + error: getErrorMessage(error, 'Tool execution failed'), }, - })), - }) + startTime: toolCallStartTime, + endTime: toolCallEndTime, + duration: toolCallEndTime - toolCallStartTime, + } + } + }) - for (const settledResult of executionResults) { - if (settledResult.status === 'rejected' || !settledResult.value) continue + const executionResults = await Promise.allSettled(toolExecutionPromises) + + currentMessages.push({ + role: 'assistant', + content: null, + tool_calls: toolCallsInResponse.map((tc) => ({ + id: tc.id, + type: 'function', + function: { + name: tc.function.name, + arguments: tc.function.arguments, + }, + })), + }) - const { toolCall, toolName, toolParams, result, startTime, endTime, duration } = - settledResult.value + for (const settledResult of executionResults) { + if (settledResult.status === 'rejected' || !settledResult.value) continue - timeSegments.push({ - type: 'tool', - name: toolName, - startTime: startTime, - endTime: endTime, - duration: duration, - toolCallId: toolCall.id, - }) + const { toolCall, toolName, toolParams, result, startTime, endTime, duration } = + settledResult.value - let resultContent: any - if (result.success && result.output) { - toolResults.push(result.output) - resultContent = result.output - } else { - resultContent = { - error: true, - message: result.error || 'Tool execution failed', - tool: toolName, - } + timeSegments.push({ + type: 'tool', + name: toolName, + startTime: startTime, + endTime: endTime, + duration: duration, + toolCallId: toolCall.id, + }) + + let resultContent: any + if (result.success && result.output) { + toolResults.push(result.output) + resultContent = result.output + } else { + resultContent = { + error: true, + message: result.error || 'Tool execution failed', + tool: toolName, } + } - toolCalls.push({ - name: toolName, - arguments: toolParams, - startTime: new Date(startTime).toISOString(), - endTime: new Date(endTime).toISOString(), - duration: duration, - result: resultContent, - success: result.success, - }) + toolCalls.push({ + name: toolName, + arguments: toolParams, + startTime: new Date(startTime).toISOString(), + endTime: new Date(endTime).toISOString(), + duration: duration, + result: resultContent, + success: result.success, + }) - currentMessages.push({ - role: 'tool', - tool_call_id: toolCall.id, - content: JSON.stringify(resultContent), - }) - } + currentMessages.push({ + role: 'tool', + tool_call_id: toolCall.id, + content: JSON.stringify(resultContent), + }) + } - const thisToolsTime = Date.now() - toolsStartTime - toolsTime += thisToolsTime + const thisToolsTime = Date.now() - toolsStartTime + toolsTime += thisToolsTime - const nextPayload = { - ...payload, - messages: currentMessages, - } + const nextPayload = { + ...payload, + messages: currentMessages, + } - if ( - typeof originalToolChoice === 'object' && - hasUsedForcedTool && - forcedTools.length > 0 - ) { - const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool)) - - if (remainingTools.length > 0) { - nextPayload.tool_choice = { - type: 'function', - function: { name: remainingTools[0] }, - } - logger.info(`Forcing next tool: ${remainingTools[0]}`) - } else { - nextPayload.tool_choice = 'auto' - logger.info('All forced tools have been used, switching to auto tool_choice') + if (typeof originalToolChoice === 'object' && hasUsedForcedTool && forcedTools.length > 0) { + const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool)) + + if (remainingTools.length > 0) { + nextPayload.tool_choice = { + type: 'function', + function: { name: remainingTools[0] }, } + logger.info(`Forcing next tool: ${remainingTools[0]}`) + } else { + nextPayload.tool_choice = 'auto' + logger.info('All forced tools have been used, switching to auto tool_choice') } + } - const nextModelStartTime = Date.now() + const nextModelStartTime = Date.now() - currentResponse = await vllm.chat.completions.create( - nextPayload, - request.abortSignal ? { signal: request.abortSignal } : undefined - ) - - if (nextPayload.tool_choice && typeof nextPayload.tool_choice === 'object') { - const forcedResult = checkForForcedToolUsage( - currentResponse, - nextPayload.tool_choice, - forcedTools, - usedForcedTools - ) - hasUsedForcedTool = forcedResult.hasUsedForcedTool - usedForcedTools = forcedResult.usedForcedTools - } + currentResponse = await vllm.chat.completions.create( + nextPayload, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) - const nextModelEndTime = Date.now() - const thisModelTime = nextModelEndTime - nextModelStartTime + if (nextPayload.tool_choice && typeof nextPayload.tool_choice === 'object') { + const forcedResult = checkForForcedToolUsage( + currentResponse, + nextPayload.tool_choice, + forcedTools, + usedForcedTools + ) + hasUsedForcedTool = forcedResult.hasUsedForcedTool + usedForcedTools = forcedResult.usedForcedTools + } - timeSegments.push({ - type: 'model', - name: request.model, - startTime: nextModelStartTime, - endTime: nextModelEndTime, - duration: thisModelTime, - }) + const nextModelEndTime = Date.now() + const thisModelTime = nextModelEndTime - nextModelStartTime - modelTime += thisModelTime + timeSegments.push({ + type: 'model', + name: request.model, + startTime: nextModelStartTime, + endTime: nextModelEndTime, + duration: thisModelTime, + }) - if (currentResponse.choices[0]?.message?.content) { - content = currentResponse.choices[0].message.content - if (request.responseFormat) { - content = content.replace(/```json\n?|\n?```/g, '').trim() - } - } + modelTime += thisModelTime - if (currentResponse.usage) { - tokens.input += currentResponse.usage.prompt_tokens || 0 - tokens.output += currentResponse.usage.completion_tokens || 0 - tokens.total += currentResponse.usage.total_tokens || 0 + if (currentResponse.choices[0]?.message?.content) { + content = currentResponse.choices[0].message.content + if (request.responseFormat) { + content = content.replace(/```json\n?|\n?```/g, '').trim() } - - iterationCount++ } - if (iterationCount === MAX_TOOL_ITERATIONS) { - enrichLastModelSegmentFromChatCompletions( - timeSegments, - currentResponse, - currentResponse.choices[0]?.message?.tool_calls, - { model: request.model, provider: 'vllm' } - ) + if (currentResponse.usage) { + tokens.input += currentResponse.usage.prompt_tokens || 0 + tokens.output += currentResponse.usage.completion_tokens || 0 + tokens.total += currentResponse.usage.total_tokens || 0 } - } catch (error) { - logger.error('Error in vLLM tool processing:', { error }) + + iterationCount++ + } + + if (iterationCount === MAX_TOOL_ITERATIONS) { + enrichLastModelSegmentFromChatCompletions( + timeSegments, + currentResponse, + currentResponse.choices[0]?.message?.tool_calls, + { model: request.model, provider: 'vllm' } + ) } if (request.stream) { From 8f9428b0d8665a0f845931019978e4b454f3a1d3 Mon Sep 17 00:00:00 2001 From: Waleed Latif Date: Fri, 29 May 2026 13:15:59 -0700 Subject: [PATCH 3/9] fix(litellm): force tool_choice 'none' on final structured-output call The deferred final call used tool_choice 'auto', so the model could emit another tool_calls round instead of the structured answer, leaving content stale. Use 'none' (matching vLLM/Fireworks) on both the streaming and non-streaming final calls so the model must return the structured response. Co-Authored-By: Claude Opus 4.8 --- apps/sim/providers/litellm/index.test.ts | 3 ++- apps/sim/providers/litellm/index.ts | 16 ++++++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/apps/sim/providers/litellm/index.test.ts b/apps/sim/providers/litellm/index.test.ts index 653090a1160..84ac2f2fa0b 100644 --- a/apps/sim/providers/litellm/index.test.ts +++ b/apps/sim/providers/litellm/index.test.ts @@ -174,7 +174,7 @@ describe('litellmProvider.executeRequest', () => { const final = lastPayload() expect(final.response_format.type).toBe('json_schema') expect(final.tools).toBeDefined() - expect(final.tool_choice).toBe('auto') + expect(final.tool_choice).toBe('none') expect(final.parallel_tool_calls).toBe(false) expect(result.content).toBe('{"answer":1}') }) @@ -196,6 +196,7 @@ describe('litellmProvider.executeRequest', () => { expect(final.stream).toBe(true) expect(final.response_format.type).toBe('json_schema') expect(final.tools).toBeDefined() + expect(final.tool_choice).toBe('none') expect(final.parallel_tool_calls).toBe(false) expect(result.execution.isStreaming).toBe(true) }) diff --git a/apps/sim/providers/litellm/index.ts b/apps/sim/providers/litellm/index.ts index 10f3be87d31..eea8ea77a57 100644 --- a/apps/sim/providers/litellm/index.ts +++ b/apps/sim/providers/litellm/index.ts @@ -593,13 +593,15 @@ export const litellmProvider: ProviderConfig = { const streamingParams: ChatCompletionCreateParamsStreaming = { ...payload, messages: currentMessages, - tool_choice: 'auto', + // Tools are resolved; force a final answer so the model can't emit another + // tool_calls round the stream reader would drop. Keep tools defined for + // backends (e.g. Anthropic) that reject a tool-result history without them. + tool_choice: 'none', stream: true, stream_options: { include_usage: true }, } if (deferResponseFormat && responseFormatPayload) { - // Keep tools defined (Anthropic requires it once history holds tool results) and - // disable parallel calls (OpenAI's rule for strict outputs alongside tools). + // Disable parallel calls — OpenAI's rule for strict outputs alongside tools. streamingParams.response_format = responseFormatPayload streamingParams.parallel_tool_calls = false } @@ -689,10 +691,12 @@ export const litellmProvider: ProviderConfig = { model: payload.model, messages: currentMessages, response_format: responseFormatPayload, - // Keep tools defined (Anthropic requires it once history holds tool results) and - // disable parallel calls (OpenAI's rule for strict outputs alongside tools). + // Force the structured answer: 'none' stops the model from returning another + // tool_calls round (which would leave content stale). Keep tools defined for + // backends (e.g. Anthropic) that reject a tool-result history without them, and + // disable parallel calls per OpenAI's strict-outputs-with-tools rule. tools: payload.tools, - tool_choice: 'auto', + tool_choice: 'none', parallel_tool_calls: false, } if (request.temperature !== undefined) finalPayload.temperature = request.temperature From 7f6c49a36a331e2fb0ef2bfa45858c1f38f797c5 Mon Sep 17 00:00:00 2001 From: Waleed Latif Date: Fri, 29 May 2026 13:28:52 -0700 Subject: [PATCH 4/9] fix(providers/ollama): drop tools from post-tool streaming call Ollama ignores tool_choice (not in its supported fields), so vLLM/Fireworks' tool_choice:'none' guard is a no-op here. Omit tools from the final streaming payload instead so the summarization turn can't emit dropped tool calls. Co-Authored-By: Claude Opus 4.8 --- apps/sim/providers/ollama/index.test.ts | 4 ++++ apps/sim/providers/ollama/index.ts | 5 +++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/apps/sim/providers/ollama/index.test.ts b/apps/sim/providers/ollama/index.test.ts index 3b409ec8216..a6b91e9e8f5 100644 --- a/apps/sim/providers/ollama/index.test.ts +++ b/apps/sim/providers/ollama/index.test.ts @@ -336,6 +336,10 @@ describe('ollamaProvider.executeRequest', () => { expect(result.stream).toBe('OLLAMA_STREAM') expect(mockExecuteTool).toHaveBeenCalledTimes(1) + const finalCall = mockCreate.mock.calls[2][0] + expect(finalCall.tools).toBeUndefined() + expect(finalCall.tool_choice).toBeUndefined() + streamOnComplete.current?.('final answer', { prompt_tokens: 2, completion_tokens: 4, diff --git a/apps/sim/providers/ollama/index.ts b/apps/sim/providers/ollama/index.ts index 3abb1ea958d..bfe7cff6134 100644 --- a/apps/sim/providers/ollama/index.ts +++ b/apps/sim/providers/ollama/index.ts @@ -482,10 +482,11 @@ export const ollamaProvider: ProviderConfig = { const accumulatedCost = calculateCost(request.model, tokens.input, tokens.output) + const { tools: _tools, tool_choice: _toolChoice, ...streamPayload } = payload + const streamingParams: ChatCompletionCreateParamsStreaming = { - ...payload, + ...streamPayload, messages: currentMessages, - tool_choice: 'auto', stream: true, stream_options: { include_usage: true }, } From bfd6fac6d3c59e7c7f65ac483554ef3b5e4a4524 Mon Sep 17 00:00:00 2001 From: Waleed Latif Date: Fri, 29 May 2026 13:31:20 -0700 Subject: [PATCH 5/9] fix(litellm): spread payload into deferred final call so reasoning_effort carries over The non-streaming deferred finalPayload hand-picked fields and dropped reasoning_effort (and any future payload field), diverging from the streaming path which spreads ...payload. Spread payload here too for consistency. Co-Authored-By: Claude Opus 4.8 --- apps/sim/providers/litellm/index.test.ts | 2 ++ apps/sim/providers/litellm/index.ts | 14 ++++++-------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/apps/sim/providers/litellm/index.test.ts b/apps/sim/providers/litellm/index.test.ts index 84ac2f2fa0b..5261f4e23d6 100644 --- a/apps/sim/providers/litellm/index.test.ts +++ b/apps/sim/providers/litellm/index.test.ts @@ -165,6 +165,7 @@ describe('litellmProvider.executeRequest', () => { const result = await run({ tools: [tool('known')], + reasoningEffort: 'high', responseFormat: { name: 'r', schema: { type: 'object', properties: {} } }, }) @@ -176,6 +177,7 @@ describe('litellmProvider.executeRequest', () => { expect(final.tools).toBeDefined() expect(final.tool_choice).toBe('none') expect(final.parallel_tool_calls).toBe(false) + expect(final.reasoning_effort).toBe('high') expect(result.content).toBe('{"answer":1}') }) diff --git a/apps/sim/providers/litellm/index.ts b/apps/sim/providers/litellm/index.ts index eea8ea77a57..8264f63256b 100644 --- a/apps/sim/providers/litellm/index.ts +++ b/apps/sim/providers/litellm/index.ts @@ -687,20 +687,18 @@ export const litellmProvider: ProviderConfig = { logger.info('Applying deferred JSON schema response format after tool processing') const finalFormatStartTime = Date.now() + // Spread payload so all request fields carry over (model, temperature, + // max_completion_tokens, reasoning_effort, tools) — matching the streaming path. + // 'none' forces the structured answer instead of another tool_calls round that + // would leave content stale; tools stay defined for backends like Anthropic that + // reject a tool-result history without them; parallel calls off per OpenAI's rule. const finalPayload: any = { - model: payload.model, + ...payload, messages: currentMessages, response_format: responseFormatPayload, - // Force the structured answer: 'none' stops the model from returning another - // tool_calls round (which would leave content stale). Keep tools defined for - // backends (e.g. Anthropic) that reject a tool-result history without them, and - // disable parallel calls per OpenAI's strict-outputs-with-tools rule. - tools: payload.tools, tool_choice: 'none', parallel_tool_calls: false, } - if (request.temperature !== undefined) finalPayload.temperature = request.temperature - if (request.maxTokens != null) finalPayload.max_completion_tokens = request.maxTokens currentResponse = await litellm.chat.completions.create( finalPayload, From 683c105f120fea8252d8f96c6ccbde4024d41cbd Mon Sep 17 00:00:00 2001 From: Waleed Latif Date: Fri, 29 May 2026 13:32:46 -0700 Subject: [PATCH 6/9] chore(providers/ollama): restore enrichment TSDoc block Keeps parity with sibling Chat Completions providers (cerebras/mistral/xai). Co-Authored-By: Claude Opus 4.8 --- apps/sim/providers/ollama/index.ts | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/apps/sim/providers/ollama/index.ts b/apps/sim/providers/ollama/index.ts index bfe7cff6134..a6f7034fd2c 100644 --- a/apps/sim/providers/ollama/index.ts +++ b/apps/sim/providers/ollama/index.ts @@ -623,3 +623,8 @@ export const ollamaProvider: ProviderConfig = { } }, } + +/** + * Enriches the last model segment with per-iteration content from a Chat + * Completions response: assistant text, tool calls, finish reason, token usage. + */ From 9866ab546fb92dd0221ff180947dceedd1eacddc Mon Sep 17 00:00:00 2001 From: Waleed Latif Date: Fri, 29 May 2026 13:34:00 -0700 Subject: [PATCH 7/9] docs(fireworks): restore TSDoc on utils helpers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Restore the TSDoc blocks on supportsNativeStructuredOutputs, createReadableStreamFromOpenAIStream, and checkForForcedToolUsage — TSDoc is the codebase documentation standard and should not have been stripped. Co-Authored-By: Claude Opus 4.8 --- apps/sim/providers/fireworks/utils.ts | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/apps/sim/providers/fireworks/utils.ts b/apps/sim/providers/fireworks/utils.ts index 631800cc67f..70444e07b69 100644 --- a/apps/sim/providers/fireworks/utils.ts +++ b/apps/sim/providers/fireworks/utils.ts @@ -2,11 +2,18 @@ import type { ChatCompletionChunk } from 'openai/resources/chat/completions' import type { CompletionUsage } from 'openai/resources/completions' import { checkForForcedToolUsageOpenAI, createOpenAICompatibleStream } from '@/providers/utils' -/** Fireworks supports native json_schema structured outputs for all models on its inference API. */ +/** + * Checks if a model supports native structured outputs (json_schema). + * Fireworks AI supports structured outputs across their inference API. + */ export async function supportsNativeStructuredOutputs(_modelId: string): Promise { return true } +/** + * Creates a ReadableStream from a Fireworks streaming response. + * Uses the shared OpenAI-compatible streaming utility. + */ export function createReadableStreamFromOpenAIStream( openaiStream: AsyncIterable, onComplete?: (content: string, usage: CompletionUsage) => void @@ -14,6 +21,10 @@ export function createReadableStreamFromOpenAIStream( return createOpenAICompatibleStream(openaiStream, 'Fireworks', onComplete) } +/** + * Checks if a forced tool was used in a Fireworks response. + * Uses the shared OpenAI-compatible forced tool usage helper. + */ export function checkForForcedToolUsage( response: any, toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }, From 7d56656be67fd348355919e75a9e3622173a8624 Mon Sep 17 00:00:00 2001 From: Waleed Latif Date: Fri, 29 May 2026 13:34:47 -0700 Subject: [PATCH 8/9] chore(litellm): remove inline rationale comments (codebase uses TSDoc) Co-Authored-By: Claude Opus 4.8 --- apps/sim/providers/litellm/index.ts | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/apps/sim/providers/litellm/index.ts b/apps/sim/providers/litellm/index.ts index 8264f63256b..53a5360d2c9 100644 --- a/apps/sim/providers/litellm/index.ts +++ b/apps/sim/providers/litellm/index.ts @@ -160,8 +160,6 @@ export const litellmProvider: ProviderConfig = { type: 'json_schema' as const, json_schema: { name: request.responseFormat.name || 'response_schema', - // Strict mode requires additionalProperties:false and all-required keys; - // OpenAI-backed routes 400 without it. schema: isStrictResponseFormat ? enforceStrictSchema(request.responseFormat.schema || request.responseFormat) : request.responseFormat.schema || request.responseFormat, @@ -195,8 +193,6 @@ export const litellmProvider: ProviderConfig = { } } - // response_format + tools conflict on some backends (Anthropic rejects the pair, - // vLLM guided decoding suppresses tool calls), so defer the format past the tool loop. const deferResponseFormat = !!responseFormatPayload && hasActiveTools if (responseFormatPayload && !deferResponseFormat) { payload.response_format = responseFormatPayload @@ -499,8 +495,6 @@ export const litellmProvider: ProviderConfig = { respondedToolCallIds.add(toolCall.id) } - // Every tool_call needs a matching `tool` response or the next request 400s; - // stub any the model left unanswered (e.g. an unknown/filtered tool name). for (const tc of toolCallsInResponse) { if (respondedToolCallIds.has(tc.id)) continue currentMessages.push({ @@ -593,15 +587,11 @@ export const litellmProvider: ProviderConfig = { const streamingParams: ChatCompletionCreateParamsStreaming = { ...payload, messages: currentMessages, - // Tools are resolved; force a final answer so the model can't emit another - // tool_calls round the stream reader would drop. Keep tools defined for - // backends (e.g. Anthropic) that reject a tool-result history without them. tool_choice: 'none', stream: true, stream_options: { include_usage: true }, } if (deferResponseFormat && responseFormatPayload) { - // Disable parallel calls — OpenAI's rule for strict outputs alongside tools. streamingParams.response_format = responseFormatPayload streamingParams.parallel_tool_calls = false } @@ -687,11 +677,6 @@ export const litellmProvider: ProviderConfig = { logger.info('Applying deferred JSON schema response format after tool processing') const finalFormatStartTime = Date.now() - // Spread payload so all request fields carry over (model, temperature, - // max_completion_tokens, reasoning_effort, tools) — matching the streaming path. - // 'none' forces the structured answer instead of another tool_calls round that - // would leave content stale; tools stay defined for backends like Anthropic that - // reject a tool-result history without them; parallel calls off per OpenAI's rule. const finalPayload: any = { ...payload, messages: currentMessages, From 0be3ca213a30eb0416aaad83a456ec4fa715c24d Mon Sep 17 00:00:00 2001 From: Waleed Latif Date: Fri, 29 May 2026 13:34:56 -0700 Subject: [PATCH 9/9] chore(providers/ollama): drop orphaned enrichment TSDoc The block documented a function that now lives in trace-enrichment.ts, so it documents nothing in this file. Co-Authored-By: Claude Opus 4.8 --- apps/sim/providers/ollama/index.ts | 5 ----- 1 file changed, 5 deletions(-) diff --git a/apps/sim/providers/ollama/index.ts b/apps/sim/providers/ollama/index.ts index a6f7034fd2c..bfe7cff6134 100644 --- a/apps/sim/providers/ollama/index.ts +++ b/apps/sim/providers/ollama/index.ts @@ -623,8 +623,3 @@ export const ollamaProvider: ProviderConfig = { } }, } - -/** - * Enriches the last model segment with per-iteration content from a Chat - * Completions response: assistant text, tool calls, finish reason, token usage. - */