Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
226 changes: 226 additions & 0 deletions apps/sim/providers/fireworks/index.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
/**
* @vitest-environment node
*/
import { beforeEach, describe, expect, it, vi } from 'vitest'

const {
mockCreate,
mockSupportsNativeStructuredOutputs,
mockPrepareToolsWithUsageControl,
mockExecuteTool,
} = vi.hoisted(() => ({
mockCreate: vi.fn(),
mockSupportsNativeStructuredOutputs: vi.fn(),
mockPrepareToolsWithUsageControl: vi.fn(),
mockExecuteTool: vi.fn(),
}))

vi.mock('openai', () => ({
default: vi.fn().mockImplementation(() => ({
chat: { completions: { create: mockCreate } },
})),
}))

vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 5 }))

vi.mock('@/providers/models', () => ({
getProviderModels: vi.fn().mockReturnValue([]),
getProviderDefaultModel: vi.fn().mockReturnValue('llama-v3p1-70b-instruct'),
}))

vi.mock('@/providers/attachments', () => ({
formatMessagesForProvider: vi.fn((messages) => messages),
}))

vi.mock('@/providers/fireworks/utils', () => ({
supportsNativeStructuredOutputs: mockSupportsNativeStructuredOutputs,
createReadableStreamFromOpenAIStream: vi.fn(() => ({}) as ReadableStream),
checkForForcedToolUsage: vi.fn(() => ({ hasUsedForcedTool: false, usedForcedTools: [] })),
}))

vi.mock('@/providers/trace-enrichment', () => ({
enrichLastModelSegmentFromChatCompletions: vi.fn(),
}))

vi.mock('@/providers/utils', () => ({
calculateCost: vi.fn().mockReturnValue({ input: 0, output: 0, total: 0 }),
generateSchemaInstructions: vi.fn(() => 'SCHEMA_INSTRUCTIONS'),
prepareToolExecution: vi.fn(() => ({ toolParams: { x: 1 }, executionParams: { x: 1 } })),
prepareToolsWithUsageControl: mockPrepareToolsWithUsageControl,
sumToolCosts: vi.fn().mockReturnValue(0),
}))

vi.mock('@/tools', () => ({ executeTool: mockExecuteTool }))

import { fireworksProvider } from '@/providers/fireworks/index'
import { ProviderError } from '@/providers/types'

const textResponse = (content: string) => ({
choices: [{ message: { content, tool_calls: [] } }],
usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
})

const toolCallResponse = () => ({
choices: [
{
message: {
content: null,
tool_calls: [
{ id: 'call_1', type: 'function', function: { name: 'my_tool', arguments: '{"x":1}' } },
],
},
},
],
usage: { prompt_tokens: 8, completion_tokens: 4, total_tokens: 12 },
})

const toolDef = {
id: 'my_tool',
name: 'my_tool',
description: '',
params: {},
parameters: { type: 'object', properties: {}, required: [] },
}

const callBody = (index: number) => mockCreate.mock.calls[index][0]
const lastCallBody = () => mockCreate.mock.calls.at(-1)?.[0]

describe('fireworksProvider', () => {
beforeEach(() => {
vi.clearAllMocks()
mockSupportsNativeStructuredOutputs.mockResolvedValue(true)
mockPrepareToolsWithUsageControl.mockImplementation((tools) => ({
tools,
toolChoice: 'auto',
forcedTools: [],
}))
mockExecuteTool.mockResolvedValue({ success: true, output: { ok: true } })
})

const baseRequest = {
model: 'fireworks/llama-v3p1-70b-instruct',
systemPrompt: 'You are helpful.',
messages: [{ role: 'user' as const, content: 'Hello' }],
apiKey: 'fw-test-key',
}

it('throws when the API key is missing', async () => {
await expect(
fireworksProvider.executeRequest({ ...baseRequest, apiKey: undefined })
).rejects.toThrow('API key is required for Fireworks')
})

it('returns content and token usage for a simple request', async () => {
mockCreate.mockResolvedValueOnce(textResponse('hi there'))

const result = await fireworksProvider.executeRequest(baseRequest)

expect(result).toMatchObject({
content: 'hi there',
model: 'llama-v3p1-70b-instruct',
tokens: { input: 10, output: 5, total: 15 },
})
})

it('wraps API errors in a ProviderError', async () => {
mockCreate.mockRejectedValueOnce(new Error('boom'))

await expect(fireworksProvider.executeRequest(baseRequest)).rejects.toBeInstanceOf(
ProviderError
)
})

it('streams directly when there are no tools', async () => {
mockCreate.mockResolvedValueOnce({})

const result = await fireworksProvider.executeRequest({ ...baseRequest, stream: true })

expect(lastCallBody()).toMatchObject({ stream: true, stream_options: { include_usage: true } })
expect(result).toHaveProperty('stream')
expect(result).toHaveProperty('execution')
})

it('sends a json_schema response_format with no strict field', async () => {
mockCreate.mockResolvedValueOnce(textResponse('{}'))

await fireworksProvider.executeRequest({
...baseRequest,
responseFormat: { name: 'my_schema', schema: { type: 'object' }, strict: true },
})

expect(lastCallBody().response_format).toEqual({
type: 'json_schema',
json_schema: { name: 'my_schema', schema: { type: 'object' } },
})
expect(lastCallBody().response_format.json_schema).not.toHaveProperty('strict')
})

it('falls back to json_object with prompt instructions when native is unsupported', async () => {
mockSupportsNativeStructuredOutputs.mockResolvedValue(false)
mockCreate.mockResolvedValueOnce(textResponse('{}'))

await fireworksProvider.executeRequest({
...baseRequest,
responseFormat: { name: 'my_schema', schema: { type: 'object' } },
})

expect(lastCallBody().response_format).toEqual({ type: 'json_object' })
expect(lastCallBody().messages.at(-1)).toEqual({
role: 'user',
content: 'SCHEMA_INSTRUCTIONS',
})
})

it('defers response_format to a final call when tools are active', async () => {
mockCreate
.mockResolvedValueOnce(textResponse('intermediate'))
.mockResolvedValueOnce(textResponse('{"done":true}'))

await fireworksProvider.executeRequest({
...baseRequest,
responseFormat: { name: 'my_schema', schema: { type: 'object' } },
tools: [toolDef],
})

expect(mockCreate).toHaveBeenCalledTimes(2)
expect(callBody(0).response_format).toBeUndefined()
expect(callBody(0).tools).toBeDefined()
expect(callBody(1).response_format).toEqual({
type: 'json_schema',
json_schema: { name: 'my_schema', schema: { type: 'object' } },
})
expect(callBody(1).tools).toBeUndefined()
})

it('runs the tool loop and threads tool results back into the conversation', async () => {
mockCreate
.mockResolvedValueOnce(toolCallResponse())
.mockResolvedValueOnce(textResponse('final answer'))

const result = await fireworksProvider.executeRequest({ ...baseRequest, tools: [toolDef] })

expect(mockExecuteTool).toHaveBeenCalledWith('my_tool', { x: 1 }, expect.anything())
expect(result).toMatchObject({ content: 'final answer' })
expect((result as { toolCalls?: unknown[] }).toolCalls).toHaveLength(1)

const followUpMessages = callBody(1).messages
expect(followUpMessages).toContainEqual(
expect.objectContaining({ role: 'assistant', tool_calls: expect.any(Array) })
)
expect(followUpMessages).toContainEqual(
expect.objectContaining({ role: 'tool', tool_call_id: 'call_1' })
)
})

it("forces tool_choice 'none' on the final streaming call after tools run", async () => {
mockCreate
.mockResolvedValueOnce(toolCallResponse())
.mockResolvedValueOnce(textResponse('done'))
.mockResolvedValueOnce({})

await fireworksProvider.executeRequest({ ...baseRequest, stream: true, tools: [toolDef] })

expect(mockCreate).toHaveBeenCalledTimes(3)
expect(lastCallBody()).toMatchObject({ tool_choice: 'none', stream: true })
})
})
10 changes: 2 additions & 8 deletions apps/sim/providers/fireworks/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ const logger = createLogger('FireworksProvider')

/**
* Applies structured output configuration to a payload based on model capabilities.
* Uses json_schema with strict mode for supported models, falls back to json_object with prompt instructions.
* Uses native json_schema for supported models, falls back to json_object with prompt instructions.
*/
async function applyResponseFormat(
targetPayload: any,
Expand All @@ -51,7 +51,6 @@ async function applyResponseFormat(
json_schema: {
name: responseFormat.name || 'response_schema',
schema: responseFormat.schema || responseFormat,
strict: responseFormat.strict !== false,
},
}
return messages
Expand Down Expand Up @@ -469,7 +468,7 @@ export const fireworksProvider: ProviderConfig = {
const streamingParams: ChatCompletionCreateParamsStreaming = {
...payload,
messages: [...currentMessages],
tool_choice: 'auto',
tool_choice: 'none',
stream: true,
stream_options: { include_usage: true },
}
Expand Down Expand Up @@ -652,8 +651,3 @@ export const fireworksProvider: ProviderConfig = {
}
},
}

/**
* Enriches the last model segment with per-iteration content from a Chat
* Completions response: assistant text, tool calls, finish reason, token usage.
*/
Loading
Loading