diff --git a/apps/sim/providers/fireworks/index.test.ts b/apps/sim/providers/fireworks/index.test.ts new file mode 100644 index 0000000000..68fba04c73 --- /dev/null +++ b/apps/sim/providers/fireworks/index.test.ts @@ -0,0 +1,226 @@ +/** + * @vitest-environment node + */ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { + mockCreate, + mockSupportsNativeStructuredOutputs, + mockPrepareToolsWithUsageControl, + mockExecuteTool, +} = vi.hoisted(() => ({ + mockCreate: vi.fn(), + mockSupportsNativeStructuredOutputs: vi.fn(), + mockPrepareToolsWithUsageControl: vi.fn(), + mockExecuteTool: vi.fn(), +})) + +vi.mock('openai', () => ({ + default: vi.fn().mockImplementation(() => ({ + chat: { completions: { create: mockCreate } }, + })), +})) + +vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 5 })) + +vi.mock('@/providers/models', () => ({ + getProviderModels: vi.fn().mockReturnValue([]), + getProviderDefaultModel: vi.fn().mockReturnValue('llama-v3p1-70b-instruct'), +})) + +vi.mock('@/providers/attachments', () => ({ + formatMessagesForProvider: vi.fn((messages) => messages), +})) + +vi.mock('@/providers/fireworks/utils', () => ({ + supportsNativeStructuredOutputs: mockSupportsNativeStructuredOutputs, + createReadableStreamFromOpenAIStream: vi.fn(() => ({}) as ReadableStream), + checkForForcedToolUsage: vi.fn(() => ({ hasUsedForcedTool: false, usedForcedTools: [] })), +})) + +vi.mock('@/providers/trace-enrichment', () => ({ + enrichLastModelSegmentFromChatCompletions: vi.fn(), +})) + +vi.mock('@/providers/utils', () => ({ + calculateCost: vi.fn().mockReturnValue({ input: 0, output: 0, total: 0 }), + generateSchemaInstructions: vi.fn(() => 'SCHEMA_INSTRUCTIONS'), + prepareToolExecution: vi.fn(() => ({ toolParams: { x: 1 }, executionParams: { x: 1 } })), + prepareToolsWithUsageControl: mockPrepareToolsWithUsageControl, + sumToolCosts: vi.fn().mockReturnValue(0), +})) + +vi.mock('@/tools', () => ({ executeTool: mockExecuteTool })) + +import { 
fireworksProvider } from '@/providers/fireworks/index' +import { ProviderError } from '@/providers/types' + +const textResponse = (content: string) => ({ + choices: [{ message: { content, tool_calls: [] } }], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, +}) + +const toolCallResponse = () => ({ + choices: [ + { + message: { + content: null, + tool_calls: [ + { id: 'call_1', type: 'function', function: { name: 'my_tool', arguments: '{"x":1}' } }, + ], + }, + }, + ], + usage: { prompt_tokens: 8, completion_tokens: 4, total_tokens: 12 }, +}) + +const toolDef = { + id: 'my_tool', + name: 'my_tool', + description: '', + params: {}, + parameters: { type: 'object', properties: {}, required: [] }, +} + +const callBody = (index: number) => mockCreate.mock.calls[index][0] +const lastCallBody = () => mockCreate.mock.calls.at(-1)?.[0] + +describe('fireworksProvider', () => { + beforeEach(() => { + vi.clearAllMocks() + mockSupportsNativeStructuredOutputs.mockResolvedValue(true) + mockPrepareToolsWithUsageControl.mockImplementation((tools) => ({ + tools, + toolChoice: 'auto', + forcedTools: [], + })) + mockExecuteTool.mockResolvedValue({ success: true, output: { ok: true } }) + }) + + const baseRequest = { + model: 'fireworks/llama-v3p1-70b-instruct', + systemPrompt: 'You are helpful.', + messages: [{ role: 'user' as const, content: 'Hello' }], + apiKey: 'fw-test-key', + } + + it('throws when the API key is missing', async () => { + await expect( + fireworksProvider.executeRequest({ ...baseRequest, apiKey: undefined }) + ).rejects.toThrow('API key is required for Fireworks') + }) + + it('returns content and token usage for a simple request', async () => { + mockCreate.mockResolvedValueOnce(textResponse('hi there')) + + const result = await fireworksProvider.executeRequest(baseRequest) + + expect(result).toMatchObject({ + content: 'hi there', + model: 'llama-v3p1-70b-instruct', + tokens: { input: 10, output: 5, total: 15 }, + }) + }) + + it('wraps API 
errors in a ProviderError', async () => { + mockCreate.mockRejectedValueOnce(new Error('boom')) + + await expect(fireworksProvider.executeRequest(baseRequest)).rejects.toBeInstanceOf( + ProviderError + ) + }) + + it('streams directly when there are no tools', async () => { + mockCreate.mockResolvedValueOnce({}) + + const result = await fireworksProvider.executeRequest({ ...baseRequest, stream: true }) + + expect(lastCallBody()).toMatchObject({ stream: true, stream_options: { include_usage: true } }) + expect(result).toHaveProperty('stream') + expect(result).toHaveProperty('execution') + }) + + it('sends a json_schema response_format with no strict field', async () => { + mockCreate.mockResolvedValueOnce(textResponse('{}')) + + await fireworksProvider.executeRequest({ + ...baseRequest, + responseFormat: { name: 'my_schema', schema: { type: 'object' }, strict: true }, + }) + + expect(lastCallBody().response_format).toEqual({ + type: 'json_schema', + json_schema: { name: 'my_schema', schema: { type: 'object' } }, + }) + expect(lastCallBody().response_format.json_schema).not.toHaveProperty('strict') + }) + + it('falls back to json_object with prompt instructions when native is unsupported', async () => { + mockSupportsNativeStructuredOutputs.mockResolvedValue(false) + mockCreate.mockResolvedValueOnce(textResponse('{}')) + + await fireworksProvider.executeRequest({ + ...baseRequest, + responseFormat: { name: 'my_schema', schema: { type: 'object' } }, + }) + + expect(lastCallBody().response_format).toEqual({ type: 'json_object' }) + expect(lastCallBody().messages.at(-1)).toEqual({ + role: 'user', + content: 'SCHEMA_INSTRUCTIONS', + }) + }) + + it('defers response_format to a final call when tools are active', async () => { + mockCreate + .mockResolvedValueOnce(textResponse('intermediate')) + .mockResolvedValueOnce(textResponse('{"done":true}')) + + await fireworksProvider.executeRequest({ + ...baseRequest, + responseFormat: { name: 'my_schema', schema: { type: 'object' } 
}, + tools: [toolDef], + }) + + expect(mockCreate).toHaveBeenCalledTimes(2) + expect(callBody(0).response_format).toBeUndefined() + expect(callBody(0).tools).toBeDefined() + expect(callBody(1).response_format).toEqual({ + type: 'json_schema', + json_schema: { name: 'my_schema', schema: { type: 'object' } }, + }) + expect(callBody(1).tools).toBeUndefined() + }) + + it('runs the tool loop and threads tool results back into the conversation', async () => { + mockCreate + .mockResolvedValueOnce(toolCallResponse()) + .mockResolvedValueOnce(textResponse('final answer')) + + const result = await fireworksProvider.executeRequest({ ...baseRequest, tools: [toolDef] }) + + expect(mockExecuteTool).toHaveBeenCalledWith('my_tool', { x: 1 }, expect.anything()) + expect(result).toMatchObject({ content: 'final answer' }) + expect((result as { toolCalls?: unknown[] }).toolCalls).toHaveLength(1) + + const followUpMessages = callBody(1).messages + expect(followUpMessages).toContainEqual( + expect.objectContaining({ role: 'assistant', tool_calls: expect.any(Array) }) + ) + expect(followUpMessages).toContainEqual( + expect.objectContaining({ role: 'tool', tool_call_id: 'call_1' }) + ) + }) + + it("forces tool_choice 'none' on the final streaming call after tools run", async () => { + mockCreate + .mockResolvedValueOnce(toolCallResponse()) + .mockResolvedValueOnce(textResponse('done')) + .mockResolvedValueOnce({}) + + await fireworksProvider.executeRequest({ ...baseRequest, stream: true, tools: [toolDef] }) + + expect(mockCreate).toHaveBeenCalledTimes(3) + expect(lastCallBody()).toMatchObject({ tool_choice: 'none', stream: true }) + }) +}) diff --git a/apps/sim/providers/fireworks/index.ts b/apps/sim/providers/fireworks/index.ts index 794ee3f080..cdf355d345 100644 --- a/apps/sim/providers/fireworks/index.ts +++ b/apps/sim/providers/fireworks/index.ts @@ -34,7 +34,7 @@ const logger = createLogger('FireworksProvider') /** * Applies structured output configuration to a payload based on 
model capabilities. - * Uses json_schema with strict mode for supported models, falls back to json_object with prompt instructions. + * Uses native json_schema for supported models, falls back to json_object with prompt instructions. */ async function applyResponseFormat( targetPayload: any, @@ -51,7 +51,6 @@ async function applyResponseFormat( json_schema: { name: responseFormat.name || 'response_schema', schema: responseFormat.schema || responseFormat, - strict: responseFormat.strict !== false, }, } return messages @@ -469,7 +468,7 @@ export const fireworksProvider: ProviderConfig = { const streamingParams: ChatCompletionCreateParamsStreaming = { ...payload, messages: [...currentMessages], - tool_choice: 'auto', + tool_choice: 'none', stream: true, stream_options: { include_usage: true }, } @@ -652,8 +651,3 @@ export const fireworksProvider: ProviderConfig = { } }, } - -/** - * Enriches the last model segment with per-iteration content from a Chat - * Completions response: assistant text, tool calls, finish reason, token usage. 
- */ diff --git a/apps/sim/providers/litellm/index.test.ts b/apps/sim/providers/litellm/index.test.ts new file mode 100644 index 0000000000..5261f4e23d --- /dev/null +++ b/apps/sim/providers/litellm/index.test.ts @@ -0,0 +1,290 @@ +/** + * @vitest-environment node + */ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { mockCreate, mockExecuteTool } = vi.hoisted(() => ({ + mockCreate: vi.fn(), + mockExecuteTool: vi.fn(), +})) + +vi.mock('openai', () => ({ + default: vi.fn().mockImplementation(() => ({ + chat: { completions: { create: mockCreate } }, + })), +})) + +vi.mock('@/tools', () => ({ executeTool: mockExecuteTool })) + +vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 20 })) + +vi.mock('@/lib/core/config/env', () => ({ + env: { LITELLM_BASE_URL: 'http://litellm.test', LITELLM_API_KEY: '' }, +})) + +vi.mock('@/stores/providers', () => ({ + useProvidersStore: { getState: () => ({ setProviderModels: vi.fn() }) }, +})) + +vi.mock('@/providers/models', () => ({ + getProviderModels: () => [], + getProviderDefaultModel: () => '', +})) + +vi.mock('@/providers/attachments', () => ({ + formatMessagesForProvider: (messages: unknown) => messages, +})) + +vi.mock('@/providers/trace-enrichment', () => ({ + enrichLastModelSegmentFromChatCompletions: vi.fn(), +})) + +vi.mock('@/providers/litellm/utils', () => ({ + createReadableStreamFromLiteLLMStream: vi.fn( + () => new ReadableStream({ start: (c) => c.close() }) + ), +})) + +vi.mock('@/providers/utils', () => ({ + calculateCost: vi.fn(() => ({ input: 0, output: 0, total: 0 })), + sumToolCosts: vi.fn(() => 0), + prepareToolExecution: vi.fn((_tool, toolArgs) => ({ + toolParams: toolArgs, + executionParams: toolArgs, + })), + prepareToolsWithUsageControl: vi.fn((tools) => ({ + tools, + toolChoice: 'auto', + forcedTools: [], + hasFilteredTools: false, + })), + trackForcedToolUsage: vi.fn(() => ({ hasUsedForcedTool: false, usedForcedTools: [] })), + enforceStrictSchema: vi.fn((schema) => ({ ...schema, 
additionalProperties: false })), +})) + +import { litellmProvider } from '@/providers/litellm' +import { ProviderError } from '@/providers/types' + +interface ChatOptions { + content?: string | null + toolCalls?: Array<{ id: string; function: { name: string; arguments: string } }> + usage?: { prompt_tokens: number; completion_tokens: number; total_tokens: number } +} + +function chat({ content = null, toolCalls, usage }: ChatOptions = {}) { + return { + choices: [ + { + message: { content, tool_calls: toolCalls }, + finish_reason: toolCalls ? 'tool_calls' : 'stop', + }, + ], + usage: usage ?? { prompt_tokens: 5, completion_tokens: 3, total_tokens: 8 }, + } +} + +function tool(name: string) { + return { id: name, name, description: 'd', parameters: {} } +} + +function run(request: Record) { + return litellmProvider.executeRequest!({ + model: 'litellm/llama-3', + messages: [{ role: 'user', content: 'Hi' }], + ...request, + } as never) as Promise +} + +const firstPayload = () => mockCreate.mock.calls[0][0] +const lastPayload = () => mockCreate.mock.calls.at(-1)![0] + +describe('litellmProvider.executeRequest', () => { + beforeEach(() => { + vi.clearAllMocks() + mockCreate.mockResolvedValue(chat({ content: 'hello' })) + mockExecuteTool.mockResolvedValue({ success: true, output: { ok: true } }) + }) + + it('assembles messages, strips the model prefix, and maps params', async () => { + const result = await run({ + systemPrompt: 'You are helpful.', + context: 'Some context', + temperature: 0.5, + maxTokens: 256, + }) + + const payload = firstPayload() + expect(payload.model).toBe('llama-3') + expect(payload.messages).toEqual([ + { role: 'system', content: 'You are helpful.' 
}, + { role: 'user', content: 'Some context' }, + { role: 'user', content: 'Hi' }, + ]) + expect(payload.temperature).toBe(0.5) + expect(payload.max_completion_tokens).toBe(256) + expect(result.content).toBe('hello') + expect(result.tokens).toEqual({ input: 5, output: 3, total: 8 }) + }) + + it('forwards reasoning_effort only when set to a non-default value', async () => { + await run({ reasoningEffort: 'high' }) + expect(firstPayload().reasoning_effort).toBe('high') + + mockCreate.mockClear() + await run({ reasoningEffort: 'auto' }) + expect(firstPayload().reasoning_effort).toBeUndefined() + + mockCreate.mockClear() + await run({}) + expect(firstPayload().reasoning_effort).toBeUndefined() + }) + + it('sanitizes the schema for strict response_format and passes it through otherwise', async () => { + await run({ responseFormat: { name: 'r', schema: { type: 'object', properties: {} } } }) + let rf = firstPayload().response_format + expect(rf.type).toBe('json_schema') + expect(rf.json_schema.strict).toBe(true) + expect(rf.json_schema.schema.additionalProperties).toBe(false) + + mockCreate.mockClear() + await run({ + responseFormat: { name: 'r', schema: { type: 'object', properties: {} }, strict: false }, + }) + rf = firstPayload().response_format + expect(rf.json_schema.strict).toBe(false) + expect(rf.json_schema.schema.additionalProperties).toBeUndefined() + }) + + it('defers response_format past the tool loop and keeps tools on the final call', async () => { + mockCreate + .mockResolvedValueOnce( + chat({ toolCalls: [{ id: 'c1', function: { name: 'known', arguments: '{"q":1}' } }] }) + ) + .mockResolvedValueOnce(chat({ content: 'mid' })) + .mockResolvedValueOnce(chat({ content: '{"answer":1}' })) + + const result = await run({ + tools: [tool('known')], + reasoningEffort: 'high', + responseFormat: { name: 'r', schema: { type: 'object', properties: {} } }, + }) + + expect(firstPayload().response_format).toBeUndefined() + expect(firstPayload().tools).toBeDefined() + + 
const final = lastPayload() + expect(final.response_format.type).toBe('json_schema') + expect(final.tools).toBeDefined() + expect(final.tool_choice).toBe('none') + expect(final.parallel_tool_calls).toBe(false) + expect(final.reasoning_effort).toBe('high') + expect(result.content).toBe('{"answer":1}') + }) + + it('defers response_format into the final streaming call while keeping tools', async () => { + mockCreate + .mockResolvedValueOnce( + chat({ toolCalls: [{ id: 'c1', function: { name: 'known', arguments: '{}' } }] }) + ) + .mockResolvedValueOnce(chat({ content: 'mid' })) + + const result = await run({ + stream: true, + tools: [tool('known')], + responseFormat: { name: 'r', schema: { type: 'object', properties: {} } }, + }) + + const final = lastPayload() + expect(final.stream).toBe(true) + expect(final.response_format.type).toBe('json_schema') + expect(final.tools).toBeDefined() + expect(final.tool_choice).toBe('none') + expect(final.parallel_tool_calls).toBe(false) + expect(result.execution.isStreaming).toBe(true) + }) + + it('threads assistant tool_calls and a named tool response, and reports toolCalls', async () => { + mockCreate + .mockResolvedValueOnce( + chat({ toolCalls: [{ id: 'c1', function: { name: 'known', arguments: '{}' } }] }) + ) + .mockResolvedValueOnce(chat({ content: 'done' })) + mockExecuteTool.mockResolvedValue({ success: true, output: { temp: 72 } }) + + const result = await run({ tools: [tool('known')] }) + + const followupMessages = mockCreate.mock.calls[1][0].messages + expect(followupMessages).toContainEqual({ + role: 'assistant', + content: null, + tool_calls: [{ id: 'c1', type: 'function', function: { name: 'known', arguments: '{}' } }], + }) + expect(followupMessages).toContainEqual({ + role: 'tool', + tool_call_id: 'c1', + name: 'known', + content: JSON.stringify({ temp: 72 }), + }) + expect(result.toolCalls).toHaveLength(1) + expect(result.content).toBe('done') + }) + + it('emits a stub tool response for an unanswered 
tool_call_id', async () => { + mockCreate + .mockResolvedValueOnce( + chat({ toolCalls: [{ id: 'cX', function: { name: 'ghost', arguments: '{}' } }] }) + ) + .mockResolvedValueOnce(chat({ content: 'recovered' })) + + await run({ tools: [tool('known')] }) + + expect(mockExecuteTool).not.toHaveBeenCalled() + const followupMessages = mockCreate.mock.calls[1][0].messages + const toolMsg = followupMessages.find((m: any) => m.role === 'tool' && m.tool_call_id === 'cX') + expect(toolMsg).toBeDefined() + expect(toolMsg.content).toContain('not available') + }) + + it('executes a tool with empty arguments without failing', async () => { + mockCreate + .mockResolvedValueOnce( + chat({ toolCalls: [{ id: 'c1', function: { name: 'ping', arguments: '' } }] }) + ) + .mockResolvedValueOnce(chat({ content: 'pong' })) + + await run({ tools: [tool('ping')] }) + + expect(mockExecuteTool).toHaveBeenCalledTimes(1) + const toolMsg = mockCreate.mock.calls[1][0].messages.find((m: any) => m.role === 'tool') + expect(toolMsg.content).not.toContain('"error":true') + }) + + it('stops the tool loop at MAX_TOOL_ITERATIONS', async () => { + mockCreate.mockResolvedValue( + chat({ toolCalls: [{ id: 'c1', function: { name: 'known', arguments: '{}' } }] }) + ) + + await run({ tools: [tool('known')] }) + + expect(mockCreate).toHaveBeenCalledTimes(1 + 20) + expect(mockExecuteTool).toHaveBeenCalledTimes(20) + }) + + it('returns a streaming execution when streaming without active tools', async () => { + const result = await run({ stream: true }) + + expect(firstPayload().stream).toBe(true) + expect(firstPayload().stream_options).toEqual({ include_usage: true }) + expect(result.stream).toBeInstanceOf(ReadableStream) + expect(result.execution.isStreaming).toBe(true) + }) + + it('wraps API errors in a ProviderError using the error envelope message', async () => { + mockCreate.mockRejectedValue({ + error: { message: 'rate limited', type: 'rate_limit_error', code: '429' }, + }) + + await 
expect(run({})).rejects.toBeInstanceOf(ProviderError) + await expect(run({})).rejects.toThrow('rate limited') + }) +}) diff --git a/apps/sim/providers/litellm/index.ts b/apps/sim/providers/litellm/index.ts index 33e363f050..53a5360d2c 100644 --- a/apps/sim/providers/litellm/index.ts +++ b/apps/sim/providers/litellm/index.ts @@ -19,6 +19,7 @@ import type { import { ProviderError } from '@/providers/types' import { calculateCost, + enforceStrictSchema, prepareToolExecution, prepareToolsWithUsageControl, sumToolCosts, @@ -146,19 +147,27 @@ export const litellmProvider: ProviderConfig = { if (request.temperature !== undefined) payload.temperature = request.temperature if (request.maxTokens != null) payload.max_completion_tokens = request.maxTokens - if (request.responseFormat) { - payload.response_format = { - type: 'json_schema', - json_schema: { - name: request.responseFormat.name || 'response_schema', - schema: request.responseFormat.schema || request.responseFormat, - strict: request.responseFormat.strict !== false, - }, - } - - logger.info('Added JSON schema response format to LiteLLM request') + if (request.reasoningEffort !== undefined && request.reasoningEffort !== 'auto') { + payload.reasoning_effort = request.reasoningEffort } + const isStrictResponseFormat = request.responseFormat + ? request.responseFormat.strict !== false + : false + + const responseFormatPayload = request.responseFormat + ? { + type: 'json_schema' as const, + json_schema: { + name: request.responseFormat.name || 'response_schema', + schema: isStrictResponseFormat + ? 
enforceStrictSchema(request.responseFormat.schema || request.responseFormat) + : request.responseFormat.schema || request.responseFormat, + strict: isStrictResponseFormat, + }, + } + : undefined + let preparedTools: ReturnType | null = null let hasActiveTools = false @@ -184,6 +193,12 @@ export const litellmProvider: ProviderConfig = { } } + const deferResponseFormat = !!responseFormatPayload && hasActiveTools + if (responseFormatPayload && !deferResponseFormat) { + payload.response_format = responseFormatPayload + logger.info('Added JSON schema response format to LiteLLM request') + } + const providerStartTime = Date.now() const providerStartTimeISO = new Date(providerStartTime).toISOString() @@ -271,6 +286,7 @@ export const litellmProvider: ProviderConfig = { endTime: new Date().toISOString(), duration: Date.now() - providerStartTime, }, + isStreaming: true, }, } as StreamingExecution @@ -374,7 +390,9 @@ export const litellmProvider: ProviderConfig = { const toolName = toolCall.function.name try { - const toolArgs = JSON.parse(toolCall.function.arguments) + const toolArgs = toolCall.function.arguments + ? 
JSON.parse(toolCall.function.arguments) + : {} const tool = request.tools?.find((t) => t.id === toolName) if (!tool) return null @@ -429,6 +447,8 @@ export const litellmProvider: ProviderConfig = { })), }) + const respondedToolCallIds = new Set() + for (const settledResult of executionResults) { if (settledResult.status === 'rejected' || !settledResult.value) continue @@ -469,8 +489,24 @@ export const litellmProvider: ProviderConfig = { currentMessages.push({ role: 'tool', tool_call_id: toolCall.id, + name: toolName, content: JSON.stringify(resultContent), }) + respondedToolCallIds.add(toolCall.id) + } + + for (const tc of toolCallsInResponse) { + if (respondedToolCallIds.has(tc.id)) continue + currentMessages.push({ + role: 'tool', + tool_call_id: tc.id, + name: tc.function.name, + content: JSON.stringify({ + error: true, + message: `Tool "${tc.function.name}" is not available`, + tool: tc.function.name, + }), + }) } const thisToolsTime = Date.now() - toolsStartTime @@ -551,10 +587,14 @@ export const litellmProvider: ProviderConfig = { const streamingParams: ChatCompletionCreateParamsStreaming = { ...payload, messages: currentMessages, - tool_choice: 'auto', + tool_choice: 'none', stream: true, stream_options: { include_usage: true }, } + if (deferResponseFormat && responseFormatPayload) { + streamingParams.response_format = responseFormatPayload + streamingParams.parallel_tool_calls = false + } const streamResponse = await litellm.chat.completions.create( streamingParams, request.abortSignal ? 
{ signal: request.abortSignal } : undefined @@ -626,12 +666,59 @@ export const litellmProvider: ProviderConfig = { endTime: new Date().toISOString(), duration: Date.now() - providerStartTime, }, + isStreaming: true, }, } as StreamingExecution return streamingResult as StreamingExecution } + if (deferResponseFormat && responseFormatPayload) { + logger.info('Applying deferred JSON schema response format after tool processing') + + const finalFormatStartTime = Date.now() + const finalPayload: any = { + ...payload, + messages: currentMessages, + response_format: responseFormatPayload, + tool_choice: 'none', + parallel_tool_calls: false, + } + + currentResponse = await litellm.chat.completions.create( + finalPayload, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) + + const finalFormatEndTime = Date.now() + timeSegments.push({ + type: 'model', + name: request.model, + startTime: finalFormatStartTime, + endTime: finalFormatEndTime, + duration: finalFormatEndTime - finalFormatStartTime, + }) + modelTime += finalFormatEndTime - finalFormatStartTime + + const formattedContent = currentResponse.choices[0]?.message?.content + if (formattedContent) { + content = formattedContent.replace(/```json\n?|\n?```/g, '').trim() + } + + if (currentResponse.usage) { + tokens.input += currentResponse.usage.prompt_tokens || 0 + tokens.output += currentResponse.usage.completion_tokens || 0 + tokens.total += currentResponse.usage.total_tokens || 0 + } + + enrichLastModelSegmentFromChatCompletions( + timeSegments, + currentResponse, + currentResponse.choices[0]?.message?.tool_calls, + { model: request.model, provider: 'litellm' } + ) + } + const providerEndTime = Date.now() const providerEndTimeISO = new Date(providerEndTime).toISOString() const totalDuration = providerEndTime - providerStartTime @@ -660,7 +747,7 @@ export const litellmProvider: ProviderConfig = { let errorMessage = toError(error).message let errorType: string | undefined - let errorCode: number | 
undefined + let errorCode: string | number | undefined if (error && typeof error === 'object' && 'error' in error) { const litellmError = error.error as any diff --git a/apps/sim/providers/ollama/index.test.ts b/apps/sim/providers/ollama/index.test.ts new file mode 100644 index 0000000000..a6b91e9e8f --- /dev/null +++ b/apps/sim/providers/ollama/index.test.ts @@ -0,0 +1,351 @@ +/** + * @vitest-environment node + */ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +type StreamUsage = { prompt_tokens: number; completion_tokens: number; total_tokens: number } + +const { mockCreate, mockExecuteTool, streamOnComplete, MockAPIError } = vi.hoisted(() => { + class MockAPIError extends Error { + status?: number + code?: string | null + type?: string + constructor(message: string, opts: { status?: number; code?: string; type?: string } = {}) { + super(message) + this.name = 'APIError' + this.status = opts.status + this.code = opts.code + this.type = opts.type + } + } + return { + mockCreate: vi.fn(), + mockExecuteTool: vi.fn(), + streamOnComplete: { + current: undefined as undefined | ((content: string, usage: StreamUsage) => void), + }, + MockAPIError, + } +}) + +vi.mock('openai', () => { + const OpenAI = vi.fn(() => ({ chat: { completions: { create: mockCreate } } })) + ;(OpenAI as unknown as { APIError: typeof MockAPIError }).APIError = MockAPIError + return { default: OpenAI } +}) + +vi.mock('@/lib/core/utils/urls', () => ({ getOllamaUrl: () => 'http://localhost:11434' })) +vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 20 })) +vi.mock('@/providers/attachments', () => ({ + formatMessagesForProvider: (messages: unknown) => messages, +})) +vi.mock('@/providers/trace-enrichment', () => ({ + enrichLastModelSegmentFromChatCompletions: vi.fn(), +})) +vi.mock('@/providers/ollama/utils', () => ({ + createReadableStreamFromOllamaStream: ( + _stream: unknown, + onComplete: (content: string, usage: StreamUsage) => void + ) => { + streamOnComplete.current = 
onComplete + return 'OLLAMA_STREAM' + }, +})) +vi.mock('@/providers/utils', () => ({ + calculateCost: () => ({ input: 0, output: 0, total: 0, pricing: null }), + prepareToolExecution: (_tool: unknown, args: Record) => ({ + toolParams: args, + executionParams: args, + }), + sumToolCosts: () => 0, +})) +vi.mock('@/tools', () => ({ executeTool: mockExecuteTool })) +vi.mock('@/stores/providers', () => ({ + useProvidersStore: { getState: () => ({ setProviderModels: vi.fn() }) }, +})) + +import { ollamaProvider } from '@/providers/ollama' +import type { ProviderRequest, ProviderResponse, ProviderToolConfig } from '@/providers/types' + +interface StreamingResult { + stream: string + execution: { + output: { + content: string + tokens: { input: number; output: number; total: number } + toolCalls?: { list: unknown[]; count: number } + } + } +} + +type ToolCallChunk = { id: string; type: 'function'; function: { name: string; arguments: string } } + +function completion( + opts: { content?: string | null; toolCalls?: ToolCallChunk[]; usage?: StreamUsage } = {} +) { + return { + choices: [{ message: { content: opts.content ?? null, tool_calls: opts.toolCalls } }], + usage: opts.usage ?? { prompt_tokens: 5, completion_tokens: 3, total_tokens: 8 }, + } +} + +function makeTool(id: string, usageControl?: 'auto' | 'force' | 'none'): ProviderToolConfig { + return { + id, + name: id, + description: `${id} tool`, + params: {}, + parameters: { type: 'object', properties: {}, required: [] }, + ...(usageControl ? 
{ usageControl } : {}), + } +} + +const baseRequest: ProviderRequest = { + model: 'llama3.2', + messages: [{ role: 'user', content: 'hi' }], +} + +describe('ollamaProvider.executeRequest', () => { + beforeEach(() => { + vi.clearAllMocks() + streamOnComplete.current = undefined + mockCreate.mockResolvedValue(completion({ content: 'hello' })) + mockExecuteTool.mockResolvedValue({ success: true, output: { ok: true } }) + }) + + it('assembles system, context, then history in order and forwards params', async () => { + const result = (await ollamaProvider.executeRequest({ + ...baseRequest, + systemPrompt: 'be nice', + context: 'ctx', + temperature: 0.5, + maxTokens: 128, + })) as ProviderResponse + + expect(result).toMatchObject({ content: 'hello', model: 'llama3.2' }) + const payload = mockCreate.mock.calls[0][0] + expect(payload.messages).toEqual([ + { role: 'system', content: 'be nice' }, + { role: 'user', content: 'ctx' }, + { role: 'user', content: 'hi' }, + ]) + expect(payload.model).toBe('llama3.2') + expect(payload.temperature).toBe(0.5) + expect(payload.max_tokens).toBe(128) + }) + + it('returns content verbatim (keeps ```json fences) when no responseFormat', async () => { + const fenced = '```json\n{"a":1}\n```' + mockCreate.mockResolvedValue(completion({ content: fenced })) + const result = (await ollamaProvider.executeRequest(baseRequest)) as ProviderResponse + expect(result.content).toBe(fenced) + }) + + it('strips ```json fences and sends a json_schema response_format when requested', async () => { + mockCreate.mockResolvedValue(completion({ content: '```json\n{"a":1}\n```' })) + const result = (await ollamaProvider.executeRequest({ + ...baseRequest, + responseFormat: { name: 'r', schema: { type: 'object' }, strict: true }, + })) as ProviderResponse + expect(result.content).toBe('{"a":1}') + expect(mockCreate.mock.calls[0][0].response_format).toMatchObject({ + type: 'json_schema', + json_schema: { name: 'r', schema: { type: 'object' }, strict: true }, + }) 
+ }) + + it('runs the tool loop: parses string args, feeds results back, then terminates', async () => { + mockCreate + .mockResolvedValueOnce( + completion({ + toolCalls: [ + { id: 'call_1', type: 'function', function: { name: 'mytool', arguments: '{"x":1}' } }, + ], + }) + ) + .mockResolvedValueOnce(completion({ content: 'done' })) + + const result = (await ollamaProvider.executeRequest({ + ...baseRequest, + tools: [makeTool('mytool')], + })) as ProviderResponse + + expect(mockExecuteTool).toHaveBeenCalledWith('mytool', { x: 1 }, expect.anything()) + expect(mockCreate).toHaveBeenCalledTimes(2) + expect(result.content).toBe('done') + expect(result.toolCalls).toEqual([ + expect.objectContaining({ name: 'mytool', success: true, arguments: { x: 1 } }), + ]) + expect(result.toolResults).toEqual([{ ok: true }]) + + const followUp = mockCreate.mock.calls[1][0].messages + expect(followUp).toContainEqual( + expect.objectContaining({ + role: 'assistant', + content: null, + tool_calls: [ + expect.objectContaining({ + id: 'call_1', + function: { name: 'mytool', arguments: '{"x":1}' }, + }), + ], + }) + ) + expect(followUp).toContainEqual({ + role: 'tool', + tool_call_id: 'call_1', + content: JSON.stringify({ ok: true }), + }) + }) + + it('records a failed tool result without aborting the loop', async () => { + mockExecuteTool.mockResolvedValue({ success: false, error: 'boom' }) + mockCreate + .mockResolvedValueOnce( + completion({ + toolCalls: [ + { id: 'call_1', type: 'function', function: { name: 'mytool', arguments: '{}' } }, + ], + }) + ) + .mockResolvedValueOnce(completion({ content: 'recovered' })) + + const result = (await ollamaProvider.executeRequest({ + ...baseRequest, + tools: [makeTool('mytool')], + })) as ProviderResponse + + expect(result.content).toBe('recovered') + expect(result.toolCalls?.[0]).toMatchObject({ name: 'mytool', success: false }) + const toolMsg = mockCreate.mock.calls[1][0].messages.find( + (m: { role: string }) => m.role === 'tool' + ) + 
expect(JSON.parse(toolMsg.content)).toMatchObject({ error: true, message: 'boom' }) + }) + + it('executes parallel tool calls from a single response', async () => { + mockExecuteTool + .mockResolvedValueOnce({ success: true, output: { from: 'a' } }) + .mockResolvedValueOnce({ success: true, output: { from: 'b' } }) + mockCreate + .mockResolvedValueOnce( + completion({ + toolCalls: [ + { id: 'call_a', type: 'function', function: { name: 'a', arguments: '{}' } }, + { id: 'call_b', type: 'function', function: { name: 'b', arguments: '{}' } }, + ], + }) + ) + .mockResolvedValueOnce(completion({ content: 'summary' })) + + const result = (await ollamaProvider.executeRequest({ + ...baseRequest, + tools: [makeTool('a'), makeTool('b')], + })) as ProviderResponse + + expect(mockExecuteTool).toHaveBeenCalledTimes(2) + expect(result.toolCalls?.map((c) => c.name)).toEqual(['a', 'b']) + const toolMsgs = mockCreate.mock.calls[1][0].messages.filter( + (m: { role: string }) => m.role === 'tool' + ) + expect(toolMsgs.map((m: { tool_call_id: string }) => m.tool_call_id)).toEqual([ + 'call_a', + 'call_b', + ]) + }) + + it('filters out tools with usageControl "none"', async () => { + await ollamaProvider.executeRequest({ + ...baseRequest, + tools: [makeTool('keep'), makeTool('drop', 'none')], + }) + const sent = mockCreate.mock.calls[0][0].tools + expect(sent.map((t: { function: { name: string } }) => t.function.name)).toEqual(['keep']) + }) + + it('never forces tools (Ollama ignores tool_choice) and keeps "auto"', async () => { + await ollamaProvider.executeRequest({ ...baseRequest, tools: [makeTool('forced', 'force')] }) + const payload = mockCreate.mock.calls[0][0] + expect(payload.tool_choice).toBe('auto') + expect(payload.tools.map((t: { function: { name: string } }) => t.function.name)).toEqual([ + 'forced', + ]) + }) + + it('surfaces an OpenAI APIError message through ProviderError', async () => { + mockCreate.mockRejectedValue( + new MockAPIError('model not found', { + status: 
404, + code: 'not_found', + type: 'invalid_request_error', + }) + ) + await expect(ollamaProvider.executeRequest(baseRequest)).rejects.toThrow('model not found') + }) + + it('streams content and usage when no tools are used', async () => { + const result = (await ollamaProvider.executeRequest({ + ...baseRequest, + stream: true, + })) as unknown as StreamingResult + + expect(result.stream).toBe('OLLAMA_STREAM') + expect(mockCreate.mock.calls[0][0].stream_options).toEqual({ include_usage: true }) + + streamOnComplete.current?.('streamed text', { + prompt_tokens: 4, + completion_tokens: 6, + total_tokens: 10, + }) + expect(result.execution.output.content).toBe('streamed text') + expect(result.execution.output.tokens).toMatchObject({ input: 4, output: 6, total: 10 }) + }) + + it('strips ```json fences from streamed content when responseFormat is set', async () => { + const result = (await ollamaProvider.executeRequest({ + ...baseRequest, + stream: true, + responseFormat: { name: 'r', schema: { type: 'object' }, strict: true }, + })) as unknown as StreamingResult + + streamOnComplete.current?.('```json\n{"a":1}\n```', { + prompt_tokens: 1, + completion_tokens: 2, + total_tokens: 3, + }) + expect(result.execution.output.content).toBe('{"a":1}') + }) + + it('streams the final response after a tool loop, carrying tool calls', async () => { + mockCreate + .mockResolvedValueOnce( + completion({ + toolCalls: [ + { id: 'call_1', type: 'function', function: { name: 'mytool', arguments: '{}' } }, + ], + }) + ) + .mockResolvedValueOnce(completion({ content: 'intermediate' })) + + const result = (await ollamaProvider.executeRequest({ + ...baseRequest, + stream: true, + tools: [makeTool('mytool')], + })) as unknown as StreamingResult + + expect(result.stream).toBe('OLLAMA_STREAM') + expect(mockExecuteTool).toHaveBeenCalledTimes(1) + + const finalCall = mockCreate.mock.calls[2][0] + expect(finalCall.tools).toBeUndefined() + expect(finalCall.tool_choice).toBeUndefined() + + 
streamOnComplete.current?.('final answer', { + prompt_tokens: 2, + completion_tokens: 4, + total_tokens: 6, + }) + expect(result.execution.output.content).toBe('final answer') + expect(result.execution.output.toolCalls).toMatchObject({ count: 1 }) + }) +}) diff --git a/apps/sim/providers/ollama/index.ts b/apps/sim/providers/ollama/index.ts index 52332aecdb..bfe7cff613 100644 --- a/apps/sim/providers/ollama/index.ts +++ b/apps/sim/providers/ollama/index.ts @@ -1,5 +1,5 @@ import { createLogger } from '@sim/logger' -import { getErrorMessage, toError } from '@sim/utils/errors' +import { getErrorMessage } from '@sim/utils/errors' import OpenAI from 'openai' import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions' import { getOllamaUrl } from '@/lib/core/utils/urls' @@ -10,6 +10,7 @@ import type { ModelsObject } from '@/providers/ollama/types' import { createReadableStreamFromOllamaStream } from '@/providers/ollama/utils' import { enrichLastModelSegmentFromChatCompletions } from '@/providers/trace-enrichment' import type { + Message, ProviderConfig, ProviderRequest, ProviderResponse, @@ -73,7 +74,7 @@ export const ollamaProvider: ProviderConfig = { baseURL: `${OLLAMA_HOST}/v1`, }) - const allMessages = [] + const allMessages: Message[] = [] if (request.systemPrompt) { allMessages.push({ @@ -92,7 +93,7 @@ export const ollamaProvider: ProviderConfig = { if (request.messages) { allMessages.push(...request.messages) } - const formattedMessages = formatMessagesForProvider(allMessages, 'ollama') + const formattedMessages = formatMessagesForProvider(allMessages, 'ollama') as Message[] const tools = request.tools?.length ? 
request.tools.map((tool) => ({ @@ -180,7 +181,7 @@ export const ollamaProvider: ProviderConfig = { stream: createReadableStreamFromOllamaStream(streamResponse, (content, usage) => { streamingResult.execution.output.content = content - if (content) { + if (content && request.responseFormat) { streamingResult.execution.output.content = content .replace(/```json\n?|\n?```/g, '') .trim() @@ -264,7 +265,7 @@ export const ollamaProvider: ProviderConfig = { let content = currentResponse.choices[0]?.message?.content || '' - if (content) { + if (content && request.responseFormat) { content = content.replace(/```json\n?|\n?```/g, '') content = content.trim() } @@ -295,6 +296,9 @@ export const ollamaProvider: ProviderConfig = { while (iterationCount < MAX_TOOL_ITERATIONS) { if (currentResponse.choices[0]?.message?.content) { content = currentResponse.choices[0].message.content + if (request.responseFormat) { + content = content.replace(/```json\n?|\n?```/g, '').trim() + } } const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls @@ -450,8 +454,9 @@ export const ollamaProvider: ProviderConfig = { if (currentResponse.choices[0]?.message?.content) { content = currentResponse.choices[0].message.content - content = content.replace(/```json\n?|\n?```/g, '') - content = content.trim() + if (request.responseFormat) { + content = content.replace(/```json\n?|\n?```/g, '').trim() + } } if (currentResponse.usage) { @@ -477,10 +482,11 @@ export const ollamaProvider: ProviderConfig = { const accumulatedCost = calculateCost(request.model, tokens.input, tokens.output) + const { tools: _tools, tool_choice: _toolChoice, ...streamPayload } = payload + const streamingParams: ChatCompletionCreateParamsStreaming = { - ...payload, + ...streamPayload, messages: currentMessages, - tool_choice: 'auto', stream: true, stream_options: { include_usage: true }, } @@ -493,7 +499,7 @@ export const ollamaProvider: ProviderConfig = { stream: 
createReadableStreamFromOllamaStream(streamResponse, (content, usage) => { streamingResult.execution.output.content = content - if (content) { + if (content && request.responseFormat) { streamingResult.execution.output.content = content .replace(/```json\n?|\n?```/g, '') .trim() @@ -589,12 +595,27 @@ export const ollamaProvider: ProviderConfig = { const providerEndTimeISO = new Date(providerEndTime).toISOString() const totalDuration = providerEndTime - providerStartTime + let errorMessage = getErrorMessage(error, 'Unknown error') + let errorType: string | undefined + let errorCode: string | undefined + let status: number | undefined + + if (error instanceof OpenAI.APIError) { + errorMessage = error.message + errorType = error.type + errorCode = error.code ?? undefined + status = error.status + } + logger.error('Error in Ollama request:', { - error, + error: errorMessage, + errorType, + errorCode, + status, duration: totalDuration, }) - throw new ProviderError(toError(error).message, { + throw new ProviderError(errorMessage, { startTime: providerStartTimeISO, endTime: providerEndTimeISO, duration: totalDuration, @@ -602,8 +623,3 @@ export const ollamaProvider: ProviderConfig = { } }, } - -/** - * Enriches the last model segment with per-iteration content from a Chat - * Completions response: assistant text, tool calls, finish reason, token usage. 
- */ diff --git a/apps/sim/providers/openai/core.ts b/apps/sim/providers/openai/core.ts index 6946f0c0fa..6f19cef156 100644 --- a/apps/sim/providers/openai/core.ts +++ b/apps/sim/providers/openai/core.ts @@ -8,6 +8,7 @@ import type { Message, ProviderRequest, ProviderResponse, TimeSegment } from '@/ import { ProviderError } from '@/providers/types' import { calculateCost, + enforceStrictSchema, prepareToolExecution, prepareToolsWithUsageControl, sumToolCosts, @@ -31,60 +32,6 @@ import { type PreparedTools = ReturnType type ToolChoice = PreparedTools['toolChoice'] -/** - * Recursively enforces OpenAI strict mode requirements on a JSON schema. - * - Sets additionalProperties: false on all object types. - * - Ensures required includes ALL property keys. - */ -function enforceStrictSchema(schema: Record): Record { - if (!schema || typeof schema !== 'object') return schema - - const result = { ...schema } - - // If this is an object type, enforce strict requirements - if (result.type === 'object') { - result.additionalProperties = false - - // Recursively process properties and ensure required includes all keys - if (result.properties && typeof result.properties === 'object') { - const propKeys = Object.keys(result.properties as Record) - result.required = propKeys // Strict mode requires ALL properties - result.properties = Object.fromEntries( - Object.entries(result.properties as Record).map(([key, value]) => [ - key, - enforceStrictSchema(value as Record), - ]) - ) - } - } - - // Handle array items - if (result.type === 'array' && result.items) { - result.items = enforceStrictSchema(result.items as Record) - } - - // Handle anyOf, oneOf, allOf - for (const keyword of ['anyOf', 'oneOf', 'allOf']) { - if (Array.isArray(result[keyword])) { - result[keyword] = (result[keyword] as Record[]).map(enforceStrictSchema) - } - } - - // Handle $defs / definitions - for (const defKey of ['$defs', 'definitions']) { - if (result[defKey] && typeof result[defKey] === 'object') { - 
result[defKey] = Object.fromEntries( - Object.entries(result[defKey] as Record).map(([key, value]) => [ - key, - enforceStrictSchema(value as Record), - ]) - ) - } - } - - return result -} - export interface ResponsesProviderConfig { providerId: string providerLabel: string diff --git a/apps/sim/providers/openrouter/index.test.ts b/apps/sim/providers/openrouter/index.test.ts new file mode 100644 index 0000000000..88339fe4a9 --- /dev/null +++ b/apps/sim/providers/openrouter/index.test.ts @@ -0,0 +1,345 @@ +/** + * @vitest-environment node + */ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { + mockCreate, + mockExecuteTool, + mockSupportsNative, + mockPrepareTools, + mockCheckForced, + mockCreateStream, +} = vi.hoisted(() => ({ + mockCreate: vi.fn(), + mockExecuteTool: vi.fn(), + mockSupportsNative: vi.fn(), + mockPrepareTools: vi.fn((tools: unknown) => ({ + tools, + toolChoice: 'auto', + forcedTools: [], + hasFilteredTools: false, + })), + mockCheckForced: vi.fn(() => ({ hasUsedForcedTool: false, usedForcedTools: [] })), + mockCreateStream: vi.fn(), +})) + +vi.mock('openai', () => ({ + default: vi.fn().mockImplementation(() => ({ + chat: { completions: { create: mockCreate } }, + })), +})) + +vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 10 })) + +vi.mock('@/tools', () => ({ executeTool: mockExecuteTool })) + +vi.mock('@/providers/models', () => ({ + getProviderModels: vi.fn().mockReturnValue([]), + getProviderDefaultModel: vi.fn().mockReturnValue(''), +})) + +vi.mock('@/providers/attachments', () => ({ + formatMessagesForProvider: vi.fn((messages: unknown) => messages), +})) + +vi.mock('@/providers/openrouter/utils', () => ({ + supportsNativeStructuredOutputs: mockSupportsNative, + createReadableStreamFromOpenAIStream: mockCreateStream, + checkForForcedToolUsage: mockCheckForced, +})) + +vi.mock('@/providers/trace-enrichment', () => ({ + enrichLastModelSegmentFromChatCompletions: vi.fn(), +})) + +vi.mock('@/providers/utils', () => 
({ + calculateCost: vi.fn(() => ({ input: 0, output: 0, total: 0 })), + prepareToolsWithUsageControl: mockPrepareTools, + prepareToolExecution: vi.fn((_tool: unknown, toolArgs: Record) => ({ + toolParams: toolArgs, + executionParams: toolArgs, + })), + sumToolCosts: vi.fn(() => 0), + generateSchemaInstructions: vi.fn(() => 'SCHEMA_INSTRUCTIONS'), +})) + +import { openRouterProvider } from '@/providers/openrouter/index' +import type { ProviderRequest, ProviderResponse, ProviderToolConfig } from '@/providers/types' + +interface Usage { + prompt_tokens: number + completion_tokens: number + total_tokens: number +} + +function textResponse( + content: string, + usage: Usage = { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 } +) { + return { + choices: [{ message: { content, tool_calls: undefined }, finish_reason: 'stop' }], + usage, + } +} + +function toolCallResponse(name: string, args: Record, id = 'call_1') { + return { + choices: [ + { + message: { + content: null, + tool_calls: [ + { id, type: 'function', function: { name, arguments: JSON.stringify(args) } }, + ], + }, + finish_reason: 'tool_calls', + }, + ], + usage: { prompt_tokens: 8, completion_tokens: 4, total_tokens: 12 }, + } +} + +function tool(id: string): ProviderToolConfig { + return { + id, + name: id, + description: 'test tool', + params: {}, + parameters: { type: 'object', properties: {}, required: [] }, + } +} + +const baseRequest: ProviderRequest = { + apiKey: 'sk-or-test', + model: 'openrouter/anthropic/claude-3.5-sonnet', + systemPrompt: 'You are helpful.', + messages: [{ role: 'user', content: 'Hello' }], +} + +describe('openRouterProvider.executeRequest', () => { + beforeEach(() => { + vi.clearAllMocks() + mockCreate.mockReset() + mockExecuteTool.mockReset() + mockSupportsNative.mockResolvedValue(false) + }) + + it('requires an API key', async () => { + await expect( + openRouterProvider.executeRequest({ model: 'openrouter/x', messages: [] }) + ).rejects.toThrow('API key is required 
for OpenRouter') + }) + + it('strips the openrouter/ prefix and returns content + tokens', async () => { + mockCreate.mockResolvedValueOnce(textResponse('Hi there')) + + const res = (await openRouterProvider.executeRequest(baseRequest)) as ProviderResponse + + expect(res.content).toBe('Hi there') + expect(res.model).toBe('anthropic/claude-3.5-sonnet') + expect(res.tokens).toEqual({ input: 10, output: 5, total: 15 }) + + const payload = mockCreate.mock.calls[0][0] + expect(payload.model).toBe('anthropic/claude-3.5-sonnet') + expect(payload.messages[0]).toEqual({ role: 'system', content: 'You are helpful.' }) + expect(payload.messages.at(-1)).toEqual({ role: 'user', content: 'Hello' }) + }) + + it('inserts context as a user message between system and history', async () => { + mockCreate.mockResolvedValueOnce(textResponse('ok')) + + await openRouterProvider.executeRequest({ ...baseRequest, context: 'CTX' }) + + const { messages } = mockCreate.mock.calls[0][0] + expect(messages[0]).toEqual({ role: 'system', content: 'You are helpful.' 
}) + expect(messages[1]).toEqual({ role: 'user', content: 'CTX' }) + expect(messages[2]).toEqual({ role: 'user', content: 'Hello' }) + }) + + it('forwards maxTokens as max_tokens and temperature', async () => { + mockCreate.mockResolvedValueOnce(textResponse('ok')) + + await openRouterProvider.executeRequest({ ...baseRequest, maxTokens: 256, temperature: 0.4 }) + + const payload = mockCreate.mock.calls[0][0] + expect(payload.max_tokens).toBe(256) + expect(payload.temperature).toBe(0.4) + }) + + it('runs the tool loop: executes the tool, echoes tool_calls, returns the tool result, sums tokens', async () => { + mockCreate + .mockResolvedValueOnce(toolCallResponse('get_weather', { city: 'SF' })) + .mockResolvedValueOnce( + textResponse('It is sunny', { prompt_tokens: 20, completion_tokens: 6, total_tokens: 26 }) + ) + mockExecuteTool.mockResolvedValueOnce({ success: true, output: { temp: 70 } }) + + const res = (await openRouterProvider.executeRequest({ + ...baseRequest, + tools: [tool('get_weather')], + })) as ProviderResponse + + expect(mockExecuteTool).toHaveBeenCalledWith('get_weather', { city: 'SF' }, expect.anything()) + expect(res.content).toBe('It is sunny') + expect(res.toolCalls?.[0]).toMatchObject({ + name: 'get_weather', + result: { temp: 70 }, + success: true, + }) + expect(res.toolResults).toEqual([{ temp: 70 }]) + expect(res.tokens).toEqual({ input: 28, output: 10, total: 38 }) + + const secondMessages = mockCreate.mock.calls[1][0].messages + const assistant = secondMessages.find((m: { role: string }) => m.role === 'assistant') + expect(assistant).toMatchObject({ + content: null, + tool_calls: [{ id: 'call_1', type: 'function', function: { name: 'get_weather' } }], + }) + const toolMsg = secondMessages.find((m: { role: string }) => m.role === 'tool') + expect(toolMsg).toEqual({ + role: 'tool', + tool_call_id: 'call_1', + content: JSON.stringify({ temp: 70 }), + }) + }) + + it('reports a failed tool result as an error payload to the model', async () => { 
+ mockCreate + .mockResolvedValueOnce(toolCallResponse('get_weather', { city: 'SF' })) + .mockResolvedValueOnce(textResponse('done')) + mockExecuteTool.mockResolvedValueOnce({ success: false, output: undefined, error: 'boom' }) + + const res = (await openRouterProvider.executeRequest({ + ...baseRequest, + tools: [tool('get_weather')], + })) as ProviderResponse + + expect(res.toolResults).toBeUndefined() + expect(res.toolCalls?.[0]).toMatchObject({ success: false }) + const toolMsg = mockCreate.mock.calls[1][0].messages.find( + (m: { role: string }) => m.role === 'tool' + ) + expect(JSON.parse(toolMsg.content)).toEqual({ + error: true, + message: 'boom', + tool: 'get_weather', + }) + }) + + it('applies native structured outputs (json_schema + require_parameters) when no tools are active', async () => { + mockSupportsNative.mockResolvedValue(true) + mockCreate.mockResolvedValueOnce(textResponse('{"x":1}')) + + await openRouterProvider.executeRequest({ + ...baseRequest, + responseFormat: { + name: 'out', + schema: { type: 'object', properties: { x: { type: 'number' } } }, + strict: true, + }, + }) + + const payload = mockCreate.mock.calls[0][0] + expect(payload.response_format).toMatchObject({ + type: 'json_schema', + json_schema: { name: 'out', strict: true }, + }) + expect(payload.provider).toMatchObject({ require_parameters: true }) + }) + + it('falls back to json_object + prompt instructions when native structured outputs are unsupported', async () => { + mockSupportsNative.mockResolvedValue(false) + mockCreate.mockResolvedValueOnce(textResponse('{"x":1}')) + + await openRouterProvider.executeRequest({ + ...baseRequest, + responseFormat: { name: 'out', schema: { type: 'object' } }, + }) + + const payload = mockCreate.mock.calls[0][0] + expect(payload.response_format).toEqual({ type: 'json_object' }) + expect(payload.messages.at(-1)).toEqual({ role: 'user', content: 'SCHEMA_INSTRUCTIONS' }) + }) + + it('defers response_format until after the tool loop when tools 
are active', async () => { + mockSupportsNative.mockResolvedValue(true) + mockCreate + .mockResolvedValueOnce(textResponse('interim')) + .mockResolvedValueOnce(textResponse('{"x":1}')) + + const res = (await openRouterProvider.executeRequest({ + ...baseRequest, + tools: [tool('get_weather')], + responseFormat: { name: 'out', schema: { type: 'object' }, strict: true }, + })) as ProviderResponse + + const toolCall = mockCreate.mock.calls[0][0] + expect(toolCall.tools).toBeDefined() + expect(toolCall.response_format).toBeUndefined() + + const finalCall = mockCreate.mock.calls[1][0] + expect(finalCall.response_format).toMatchObject({ type: 'json_schema' }) + expect(finalCall.tools).toBeUndefined() + expect(finalCall.tool_choice).toBeUndefined() + expect(res.content).toBe('{"x":1}') + }) + + it('forces the next tool after a forced tool is used', async () => { + mockPrepareTools.mockReturnValueOnce({ + tools: [tool('a')], + toolChoice: { type: 'function', function: { name: 'a' } }, + forcedTools: ['a', 'b'], + hasFilteredTools: false, + }) + mockCheckForced.mockReturnValueOnce({ hasUsedForcedTool: true, usedForcedTools: ['a'] }) + mockCreate + .mockResolvedValueOnce(toolCallResponse('a', {})) + .mockResolvedValueOnce(textResponse('done')) + mockExecuteTool.mockResolvedValueOnce({ success: true, output: {} }) + + await openRouterProvider.executeRequest({ ...baseRequest, tools: [tool('a'), tool('b')] }) + + expect(mockCreate.mock.calls[0][0].tool_choice).toEqual({ + type: 'function', + function: { name: 'a' }, + }) + expect(mockCreate.mock.calls[1][0].tool_choice).toEqual({ + type: 'function', + function: { name: 'b' }, + }) + }) + + it('streams directly when there are no tools and sends usage opt-in', async () => { + mockCreate.mockResolvedValueOnce({}) + + const res = await openRouterProvider.executeRequest({ ...baseRequest, stream: true }) + + const payload = mockCreate.mock.calls[0][0] + expect(payload.stream).toBe(true) + expect(payload.stream_options).toEqual({ 
include_usage: true }) + expect(mockCreateStream).toHaveBeenCalledTimes(1) + expect(res).toHaveProperty('stream') + expect(res).toHaveProperty('execution.output.model', 'anthropic/claude-3.5-sonnet') + }) + + it('stops the tool loop at MAX_TOOL_ITERATIONS', async () => { + mockCreate.mockResolvedValue(toolCallResponse('looping', {})) + mockExecuteTool.mockResolvedValue({ success: true, output: {} }) + + const res = (await openRouterProvider.executeRequest({ + ...baseRequest, + tools: [tool('looping')], + })) as ProviderResponse + + expect(mockCreate).toHaveBeenCalledTimes(11) + expect(mockExecuteTool).toHaveBeenCalledTimes(10) + expect(res.toolCalls?.length).toBe(10) + }) + + it('wraps SDK errors in a ProviderError', async () => { + mockCreate.mockRejectedValueOnce(new Error('rate limited')) + + await expect(openRouterProvider.executeRequest(baseRequest)).rejects.toThrow('rate limited') + }) +}) diff --git a/apps/sim/providers/openrouter/index.ts b/apps/sim/providers/openrouter/index.ts index d3d2535b43..9bc180bdd1 100644 --- a/apps/sim/providers/openrouter/index.ts +++ b/apps/sim/providers/openrouter/index.ts @@ -376,8 +376,8 @@ export const openRouterProvider: ProviderConfig = { }) let resultContent: any - if (result.success) { - toolResults.push(result.output!) + if (result.success && result.output) { + toolResults.push(result.output) resultContent = result.output } else { resultContent = { @@ -653,8 +653,3 @@ export const openRouterProvider: ProviderConfig = { } }, } - -/** - * Enriches the last model segment with per-iteration content from a Chat - * Completions response: assistant text, tool calls, finish reason, token usage. 
- */ diff --git a/apps/sim/providers/openrouter/utils.ts b/apps/sim/providers/openrouter/utils.ts index 8d8dade827..51637f5148 100644 --- a/apps/sim/providers/openrouter/utils.ts +++ b/apps/sim/providers/openrouter/utils.ts @@ -20,9 +20,6 @@ let modelCapabilitiesCache: Map | null = null let cacheTimestamp = 0 const CACHE_TTL_MS = 5 * 60 * 1000 // 5 minutes -/** - * Fetches and caches OpenRouter model capabilities from their API. - */ async function fetchModelCapabilities(): Promise> { try { const response = await fetch('https://openrouter.ai/api/v1/models', { @@ -82,18 +79,11 @@ export async function getOpenRouterModelCapabilities( return modelCapabilitiesCache.get(normalizedId) ?? null } -/** - * Checks if a model supports native structured outputs (json_schema). - */ export async function supportsNativeStructuredOutputs(modelId: string): Promise { const capabilities = await getOpenRouterModelCapabilities(modelId) return capabilities?.supportsStructuredOutputs ?? false } -/** - * Creates a ReadableStream from an OpenRouter streaming response. - * Uses the shared OpenAI-compatible streaming utility. - */ export function createReadableStreamFromOpenAIStream( openaiStream: AsyncIterable, onComplete?: (content: string, usage: CompletionUsage) => void @@ -101,10 +91,6 @@ export function createReadableStreamFromOpenAIStream( return createOpenAICompatibleStream(openaiStream, 'OpenRouter', onComplete) } -/** - * Checks if a forced tool was used in an OpenRouter response. - * Uses the shared OpenAI-compatible forced tool usage helper. 
/**
 * Recursively enforces OpenAI strict-mode requirements on a JSON schema:
 * - Sets `additionalProperties: false` on every object type.
 * - Forces `required` to include ALL property keys.
 *
 * Required for any OpenAI-compatible backend that validates strict structured
 * outputs (OpenAI, Azure OpenAI, and OpenAI routes behind proxies like LiteLLM),
 * which reject schemas missing these constraints with an HTTP 400.
 *
 * A node is treated as an object/array when its `type` is that string OR a
 * union array containing it (e.g. `['object', 'null']` for a nullable object),
 * so nullable nested objects are tightened as well.
 *
 * The walk descends into `properties`, `items`, `anyOf` / `oneOf` / `allOf`
 * and `$defs` / `definitions`. The input is never mutated: every visited node
 * is shallow-copied before being changed.
 *
 * @param schema - JSON schema node to tighten.
 * @returns A strict copy of `schema`. Non-object values and arrays are
 *   returned unchanged.
 */
export function enforceStrictSchema(schema: Record<string, unknown>): Record<string, unknown> {
  // Arrays (e.g. tuple-style `items`) are returned as-is: spreading one into an
  // object literal would turn it into an index-keyed object.
  if (!schema || typeof schema !== 'object' || Array.isArray(schema)) return schema

  const result: Record<string, unknown> = { ...schema }

  // `type` may be a single name or a union such as ['object', 'null'].
  const declaresType = (name: string): boolean =>
    result.type === name || (Array.isArray(result.type) && result.type.includes(name))

  // Applies strict mode to every value of a keyed group of sub-schemas.
  const strictEntries = (group: unknown): Record<string, unknown> =>
    Object.fromEntries(
      Object.entries(group as Record<string, unknown>).map(([key, value]) => [
        key,
        enforceStrictSchema(value as Record<string, unknown>),
      ])
    )

  if (declaresType('object')) {
    result.additionalProperties = false

    if (result.properties && typeof result.properties === 'object') {
      // Strict mode requires every declared property to be listed as required.
      result.required = Object.keys(result.properties as Record<string, unknown>)
      result.properties = strictEntries(result.properties)
    }
  }

  if (declaresType('array') && result.items) {
    result.items = enforceStrictSchema(result.items as Record<string, unknown>)
  }

  for (const keyword of ['anyOf', 'oneOf', 'allOf']) {
    const branches = result[keyword]
    if (Array.isArray(branches)) {
      result[keyword] = branches.map((branch) =>
        enforceStrictSchema(branch as Record<string, unknown>)
      )
    }
  }

  for (const defKey of ['$defs', 'definitions']) {
    const defs = result[defKey]
    if (defs && typeof defs === 'object') {
      result[defKey] = strictEntries(defs)
    }
  }

  return result
}
Sums the `cost.total` from each tool result returned during a provider tool loop. * Tool results may carry a `cost` object injected by `applyHostedKeyCostToResult`. diff --git a/apps/sim/providers/vllm/index.test.ts b/apps/sim/providers/vllm/index.test.ts new file mode 100644 index 0000000000..4477beeeda --- /dev/null +++ b/apps/sim/providers/vllm/index.test.ts @@ -0,0 +1,296 @@ +/** + * @vitest-environment node + */ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { + mockCreate, + mockExecuteTool, + mockPrepareTools, + mockCheckForced, + mockCreateStream, + envState, +} = vi.hoisted(() => ({ + mockCreate: vi.fn(), + mockExecuteTool: vi.fn(), + mockPrepareTools: vi.fn(), + mockCheckForced: vi.fn(), + mockCreateStream: vi.fn(), + envState: { + VLLM_BASE_URL: 'http://localhost:8000', + VLLM_API_KEY: undefined as string | undefined, + }, +})) + +vi.mock('openai', () => ({ + default: vi.fn(() => ({ chat: { completions: { create: mockCreate } } })), +})) +vi.mock('@/lib/core/config/env', () => ({ env: envState })) +vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 20 })) +vi.mock('@/providers/models', () => ({ + getProviderModels: vi.fn(() => []), + getProviderDefaultModel: vi.fn(() => 'vllm/generic'), +})) +vi.mock('@/providers/attachments', () => ({ + formatMessagesForProvider: vi.fn((messages) => messages), +})) +vi.mock('@/providers/trace-enrichment', () => ({ + enrichLastModelSegmentFromChatCompletions: vi.fn(), +})) +vi.mock('@/providers/utils', () => ({ + calculateCost: vi.fn(() => ({ input: 0, output: 0, total: 0 })), + prepareToolExecution: vi.fn((_tool, args) => ({ toolParams: args, executionParams: args })), + prepareToolsWithUsageControl: mockPrepareTools, + sumToolCosts: vi.fn(() => 0), +})) +vi.mock('@/providers/vllm/utils', () => ({ + checkForForcedToolUsage: mockCheckForced, + createReadableStreamFromVLLMStream: mockCreateStream, +})) +vi.mock('@/tools', () => ({ executeTool: mockExecuteTool })) +vi.mock('@/stores/providers', 
() => ({ + useProvidersStore: { getState: () => ({ setProviderModels: vi.fn() }) }, +})) + +import type { ProviderToolConfig } from '@/providers/types' +import { vllmProvider } from '@/providers/vllm/index' + +interface ToolCall { + id: string + type: 'function' + function: { name: string; arguments: string } +} + +function chatResponse(content: string | null, toolCalls?: ToolCall[]) { + return { + choices: [{ message: { content, tool_calls: toolCalls } }], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, + } +} + +function makeTool(id: string): ProviderToolConfig { + return { + id, + name: id, + description: '', + params: {}, + parameters: { type: 'object', properties: {}, required: [] }, + } +} + +const toolCall = (id: string, name: string, args = '{}'): ToolCall => ({ + id, + type: 'function', + function: { name, arguments: args }, +}) + +/** Payload passed to the Nth `chat.completions.create` call. */ +const createPayload = (callIndex: number) => mockCreate.mock.calls[callIndex][0] + +describe('vllmProvider', () => { + beforeEach(() => { + vi.clearAllMocks() + envState.VLLM_BASE_URL = 'http://localhost:8000' + envState.VLLM_API_KEY = undefined + mockPrepareTools.mockReturnValue({ + tools: [{ type: 'function', function: { name: 'myTool' } }], + toolChoice: 'auto', + forcedTools: [], + hasFilteredTools: false, + }) + mockCheckForced.mockReturnValue({ hasUsedForcedTool: false, usedForcedTools: [] }) + mockCreateStream.mockReturnValue(new ReadableStream({ start: (c) => c.close() })) + mockExecuteTool.mockResolvedValue({ success: true, output: { result: 'ok' } }) + }) + + it('builds a chat payload with the vllm/ prefix stripped and messages assembled in order', async () => { + mockCreate.mockResolvedValueOnce(chatResponse('hello')) + + const result = await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + systemPrompt: 'be helpful', + context: 'prior context', + messages: [{ role: 'user', content: 'hi' }], + temperature: 0.7, + 
maxTokens: 256, + }) + + const payload = createPayload(0) + expect(payload.model).toBe('llama-3') + expect(payload.temperature).toBe(0.7) + expect(payload.max_completion_tokens).toBe(256) + expect(payload.messages.map((m: { role: string }) => m.role)).toEqual([ + 'system', + 'user', + 'user', + ]) + expect(result.content).toBe('hello') + expect(result.tokens).toEqual({ input: 10, output: 5, total: 15 }) + }) + + it('sends response_format as json_schema with strict when a responseFormat is provided', async () => { + mockCreate.mockResolvedValueOnce(chatResponse('{}')) + + await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'hi' }], + responseFormat: { name: 'out', schema: { type: 'object' }, strict: true }, + }) + + expect(createPayload(0).response_format).toEqual({ + type: 'json_schema', + json_schema: { name: 'out', schema: { type: 'object' }, strict: true }, + }) + }) + + it('strips markdown code fences from structured-output content', async () => { + mockCreate.mockResolvedValueOnce(chatResponse('```json\n{"a":1}\n```')) + + const result = await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'hi' }], + responseFormat: { name: 'out', schema: { type: 'object' }, strict: true }, + }) + + expect(result.content).toBe('{"a":1}') + }) + + it('runs the tool loop: executes tools, appends assistant + tool messages, returns results', async () => { + mockCreate + .mockResolvedValueOnce(chatResponse(null, [toolCall('call_1', 'myTool', '{"x":1}')])) + .mockResolvedValueOnce(chatResponse('final answer')) + + const result = await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'use a tool' }], + tools: [makeTool('myTool')], + }) + + expect(mockExecuteTool).toHaveBeenCalledWith('myTool', { x: 1 }, expect.anything()) + + const [assistantMessage, toolMessage] = createPayload(1).messages.slice(-2) + expect(assistantMessage).toMatchObject({ + 
role: 'assistant', + content: null, + tool_calls: [{ id: 'call_1', type: 'function', function: { name: 'myTool' } }], + }) + expect(toolMessage).toMatchObject({ role: 'tool', tool_call_id: 'call_1' }) + expect(toolMessage).not.toHaveProperty('name') + + expect(result.content).toBe('final answer') + expect(result.toolCalls).toHaveLength(1) + expect(result.toolCalls?.[0]).toMatchObject({ name: 'myTool', success: true }) + expect(result.toolResults).toHaveLength(1) + }) + + it('records a failed tool result without throwing', async () => { + mockExecuteTool.mockResolvedValueOnce({ success: false, error: 'tool blew up' }) + mockCreate + .mockResolvedValueOnce(chatResponse(null, [toolCall('call_1', 'myTool')])) + .mockResolvedValueOnce(chatResponse('done')) + + const result = await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'go' }], + tools: [makeTool('myTool')], + }) + + expect(result.toolCalls?.[0]).toMatchObject({ name: 'myTool', success: false }) + const toolMessage = createPayload(1).messages.at(-1) + expect(JSON.parse(toolMessage.content)).toMatchObject({ error: true, tool: 'myTool' }) + }) + + it('surfaces a ProviderError when a follow-up model call fails mid-loop', async () => { + mockCreate + .mockResolvedValueOnce(chatResponse(null, [toolCall('call_1', 'myTool')])) + .mockRejectedValueOnce(new Error('connection reset')) + + await expect( + vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'go' }], + tools: [makeTool('myTool')], + }) + ).rejects.toThrow('connection reset') + + expect(mockExecuteTool).toHaveBeenCalledTimes(1) + }) + + it('cycles forced tools: forces the next forced tool after the first is used', async () => { + mockPrepareTools.mockReturnValue({ + tools: [{ type: 'function', function: { name: 'toolA' } }], + toolChoice: { type: 'function', function: { name: 'toolA' } }, + forcedTools: ['toolA', 'toolB'], + hasFilteredTools: false, + }) + mockCheckForced + 
.mockReturnValueOnce({ hasUsedForcedTool: true, usedForcedTools: ['toolA'] }) + .mockReturnValueOnce({ hasUsedForcedTool: true, usedForcedTools: ['toolA', 'toolB'] }) + mockCreate + .mockResolvedValueOnce(chatResponse(null, [toolCall('c1', 'toolA')])) + .mockResolvedValueOnce(chatResponse('done')) + + await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'go' }], + tools: [makeTool('toolA'), makeTool('toolB')], + }) + + expect(createPayload(1).tool_choice).toEqual({ type: 'function', function: { name: 'toolB' } }) + }) + + it('streams directly when there are no tools, requesting usage in the stream', async () => { + mockCreate.mockResolvedValueOnce({}) + + const result = await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'hi' }], + stream: true, + }) + + expect(mockCreate).toHaveBeenCalledTimes(1) + const payload = createPayload(0) + expect(payload.stream).toBe(true) + expect(payload.stream_options).toEqual({ include_usage: true }) + expect('stream' in result && 'execution' in result).toBe(true) + }) + + it('uses tool_choice "none" on the final streaming call after tool processing', async () => { + mockCreate.mockResolvedValueOnce(chatResponse('answer')).mockResolvedValueOnce({}) + + await vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'hi' }], + stream: true, + tools: [makeTool('myTool')], + }) + + const streamingPayload = createPayload(1) + expect(streamingPayload.stream).toBe(true) + expect(streamingPayload.tool_choice).toBe('none') + }) + + it('throws a ProviderError carrying the vLLM error message on API failure', async () => { + mockCreate.mockRejectedValueOnce({ + error: { message: 'bad request', type: 'invalid', code: 400 }, + }) + + await expect( + vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'hi' }], + }) + ).rejects.toThrow('bad request') + }) + + it('throws when no base 
URL is configured', async () => { + envState.VLLM_BASE_URL = '' + + await expect( + vllmProvider.executeRequest({ + model: 'vllm/llama-3', + messages: [{ role: 'user', content: 'hi' }], + }) + ).rejects.toThrow('VLLM_BASE_URL is required') + }) +}) diff --git a/apps/sim/providers/vllm/index.ts b/apps/sim/providers/vllm/index.ts index 2de3c69511..87610d9c43 100644 --- a/apps/sim/providers/vllm/index.ts +++ b/apps/sim/providers/vllm/index.ts @@ -21,9 +21,8 @@ import { prepareToolExecution, prepareToolsWithUsageControl, sumToolCosts, - trackForcedToolUsage, } from '@/providers/utils' -import { createReadableStreamFromVLLMStream } from '@/providers/vllm/utils' +import { checkForForcedToolUsage, createReadableStreamFromVLLMStream } from '@/providers/vllm/utils' import { useProvidersStore } from '@/stores/providers' import { executeTool } from '@/tools' @@ -282,25 +281,7 @@ export const vllmProvider: ProviderConfig = { const forcedTools = preparedTools?.forcedTools || [] let usedForcedTools: string[] = [] - - const checkForForcedToolUsage = ( - response: any, - toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any } - ) => { - if (typeof toolChoice === 'object' && response.choices[0]?.message?.tool_calls) { - const toolCallsResponse = response.choices[0].message.tool_calls - const result = trackForcedToolUsage( - toolCallsResponse, - toolChoice, - logger, - 'vllm', - forcedTools, - usedForcedTools - ) - hasUsedForcedTool = result.hasUsedForcedTool - usedForcedTools = result.usedForcedTools - } - } + let hasUsedForcedTool = false let currentResponse = await vllm.chat.completions.create( payload, @@ -327,8 +308,6 @@ export const vllmProvider: ProviderConfig = { let modelTime = firstResponseTime let toolsTime = 0 - let hasUsedForcedTool = false - const timeSegments: TimeSegment[] = [ { type: 'model', @@ -339,7 +318,16 @@ export const vllmProvider: ProviderConfig = { }, ] - checkForForcedToolUsage(currentResponse, originalToolChoice) + if 
(originalToolChoice) { + const forcedResult = checkForForcedToolUsage( + currentResponse, + originalToolChoice, + forcedTools, + usedForcedTools + ) + hasUsedForcedTool = forcedResult.hasUsedForcedTool + usedForcedTools = forcedResult.usedForcedTools + } while (iterationCount < MAX_TOOL_ITERATIONS) { if (currentResponse.choices[0]?.message?.content) { @@ -502,7 +490,16 @@ export const vllmProvider: ProviderConfig = { request.abortSignal ? { signal: request.abortSignal } : undefined ) - checkForForcedToolUsage(currentResponse, nextPayload.tool_choice) + if (nextPayload.tool_choice && typeof nextPayload.tool_choice === 'object') { + const forcedResult = checkForForcedToolUsage( + currentResponse, + nextPayload.tool_choice, + forcedTools, + usedForcedTools + ) + hasUsedForcedTool = forcedResult.hasUsedForcedTool + usedForcedTools = forcedResult.usedForcedTools + } const nextModelEndTime = Date.now() const thisModelTime = nextModelEndTime - nextModelStartTime @@ -550,7 +547,7 @@ export const vllmProvider: ProviderConfig = { const streamingParams: ChatCompletionCreateParamsStreaming = { ...payload, messages: currentMessages, - tool_choice: 'auto', + tool_choice: 'none', stream: true, stream_options: { include_usage: true }, } @@ -685,8 +682,3 @@ export const vllmProvider: ProviderConfig = { } }, } - -/** - * Enriches the last model segment with per-iteration content from a Chat - * Completions response: assistant text, tool calls, finish reason, token usage. 
- */ diff --git a/apps/sim/providers/vllm/utils.ts b/apps/sim/providers/vllm/utils.ts index 6f43329148..2b1db5bf55 100644 --- a/apps/sim/providers/vllm/utils.ts +++ b/apps/sim/providers/vllm/utils.ts @@ -1,6 +1,6 @@ import type { ChatCompletionChunk } from 'openai/resources/chat/completions' import type { CompletionUsage } from 'openai/resources/completions' -import { createOpenAICompatibleStream } from '@/providers/utils' +import { checkForForcedToolUsageOpenAI, createOpenAICompatibleStream } from '@/providers/utils' /** * Creates a ReadableStream from a vLLM streaming response. @@ -12,3 +12,16 @@ export function createReadableStreamFromVLLMStream( ): ReadableStream { return createOpenAICompatibleStream(vllmStream, 'vLLM', onComplete) } + +/** + * Returns the { hasUsedForcedTool, usedForcedTools } forced-tool usage result for a vLLM chat completion response. + * Thin wrapper: forwards its arguments to the shared checkForForcedToolUsageOpenAI helper, inserting 'vLLM' as the third argument (presumably the provider label; confirm in @/providers/utils). + */ +export function checkForForcedToolUsage( + response: any, + toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }, + forcedTools: string[], + usedForcedTools: string[] +): { hasUsedForcedTool: boolean; usedForcedTools: string[] } { + return checkForForcedToolUsageOpenAI(response, toolChoice, 'vLLM', forcedTools, usedForcedTools) +}