From 46c462eb4dfaeb6e3b7de4488807fcd4aef16ee6 Mon Sep 17 00:00:00 2001 From: krazer Date: Sat, 23 May 2026 13:12:52 -0400 Subject: [PATCH 1/4] support harmony format and streaming Co-authored-by: Copilot --- docker/Dockerfile | 1 + schemas/model_config.schema.json | 6 + src/arbiterAI/arbiterAI.cpp | 11 + src/arbiterAI/arbiterAI.h | 14 +- src/arbiterAI/hardwareDetector.cpp | 79 ++- src/arbiterAI/modelManager.cpp | 14 + src/arbiterAI/modelManager.h | 1 + src/arbiterAI/modelRuntime.cpp | 17 +- src/arbiterAI/modelRuntime.h | 5 + src/arbiterAI/providers/baseProvider.cpp | 8 + src/arbiterAI/providers/baseProvider.h | 12 + src/arbiterAI/providers/llama.cpp | 748 +++++++++++++++++++- src/arbiterAI/providers/llama.h | 23 + src/server/dashboard.h | 19 +- src/server/dashboardConfig.h | 47 +- src/server/main.cpp | 14 + src/server/routes.cpp | 549 +++++++++++++- vcpkg/custom_ports/llama-cpp/portfile.cmake | 10 +- vcpkg/custom_ports/llama-cpp/vcpkg.json | 2 +- 19 files changed, 1481 insertions(+), 99 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 079cfad..072d0aa 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -33,6 +33,7 @@ RUN apt-get update && apt-get install -y \ vulkan-tools \ libvulkan-dev \ mesa-vulkan-drivers \ + spirv-headers \ glslc \ glslang-tools \ wget \ diff --git a/schemas/model_config.schema.json b/schemas/model_config.schema.json index 537777b..50d0914 100644 --- a/schemas/model_config.schema.json +++ b/schemas/model_config.schema.json @@ -299,6 +299,12 @@ "enum": ["vulkan", "rocm", "cuda"] }, "uniqueItems": true + }, + "api_format": { + "type": "string", + "description": "Output format produced by the model. When set (e.g. 'harmony'), the server converts the model's native output to standard OpenAI API format so clients don't need to understand the model's native format.", + "enum": ["", "harmony"], + "default": "" } } } diff --git a/src/arbiterAI/arbiterAI.cpp b/src/arbiterAI/arbiterAI.cpp index 43d7bb7..63fc42a 100644 --- a/src/arbiterAI/arbiterAI.cpp +++ b/src/arbiterAI/arbiterAI.cpp @@ -254,6 +254,13 @@ ErrorCode ArbiterAI::completion(const CompletionRequest &request, CompletionResp ErrorCode ArbiterAI::streamingCompletion(const CompletionRequest &request, std::function callback) +{ + return streamingCompletion(request, callback, nullptr); +} + +ErrorCode ArbiterAI::streamingCompletion(const CompletionRequest &request, + std::function callback, + std::function waitCallback) { if (!ArbiterAI::instance().initialized) { @@ -273,6 +280,10 @@ ErrorCode ArbiterAI::streamingCompletion(const CompletionRequest &request, return ErrorCode::UnsupportedProvider; } + if(waitCallback) + { + return provider->streamingCompletion(request, callback, waitCallback); + } return provider->streamingCompletion(request, callback); } diff --git a/src/arbiterAI/arbiterAI.h b/src/arbiterAI/arbiterAI.h index e5106fb..4874664 100644 --- a/src/arbiterAI/arbiterAI.h +++ b/src/arbiterAI/arbiterAI.h @@ -76,7 +76,8 @@ enum class ErrorCode ModelLoadError, ModelDownloading, ModelDownloadFailed, - InsufficientStorage + InsufficientStorage, + ServerOverloaded }; /** @@ -616,6 +617,17 @@ class ArbiterAI ErrorCode streamingCompletion(const CompletionRequest &request, std::function callback); + /** + * @brief Perform streaming completion with queue wait notification + * @param request Completion parameters + * @param callback Function to receive streaming chunks + * @param waitCallback Called periodically while waiting for backend availability + * @return ErrorCode indicating success or failure + */ + ErrorCode streamingCompletion(const CompletionRequest &request, + std::function callback, + std::function waitCallback); + /** * @brief Process multiple completion requests in batch * @param requests Vector of completion requests diff --git a/src/arbiterAI/hardwareDetector.cpp b/src/arbiterAI/hardwareDetector.cpp index 7ca4685..100f12b 100644 --- a/src/arbiterAI/hardwareDetector.cpp +++ b/src/arbiterAI/hardwareDetector.cpp @@ -962,7 +962,7 @@ void HardwareDetector::detectUnifiedMemory() continue; } - // Read sysfs VRAM and GTT for diagnostic logging + // Read sysfs VRAM and GTT for unified memory sizing long long sysfsVramTotal=0, sysfsVramUsed=0; long long gttTotalBytes=0, gttUsedBytes=0; @@ -983,57 +983,84 @@ void HardwareDetector::detectUnifiedMemory() if(file.is_open()) file >> gttUsedBytes; } + // Read TTM pages limit — the kernel's hard cap on GPU-accessible + // memory. On unified memory systems the actual usable GPU memory + // is min(gttTotal, ttmPagesLimit * 4096). + long long ttmLimitBytes=0; + { + std::ifstream file("/sys/module/ttm/parameters/pages_limit"); + if(file.is_open()) + { + long long pages=0; + file >> pages; + ttmLimitBytes=pages*4096LL; + } + } + int sysfsVramTotalMb=static_cast(sysfsVramTotal/(1024LL*1024LL)); int sysfsVramUsedMb=static_cast(sysfsVramUsed/(1024LL*1024LL)); int gttTotalMb=static_cast(gttTotalBytes/(1024LL*1024LL)); int gttUsedMb=static_cast(gttUsedBytes/(1024LL*1024LL)); int gttFreeMb=gttTotalMb-gttUsedMb; + int ttmLimitMb=static_cast(ttmLimitBytes/(1024LL*1024LL)); if(gttFreeMb<0) gttFreeMb=0; + // Compute kernel-reported usable GPU memory. + // Vulkan's RADV driver splits unified memory into device-local and + // host heaps using a 2/3 : 1/3 heuristic, which under-reports the + // actual GPU-accessible memory. The kernel sysfs values are + // authoritative: usable = min(gttTotal, ttmLimit). + int kernelUsableMb=0; + if(gttTotalBytes>0) + { + kernelUsableMb=gttTotalMb; + if(ttmLimitBytes>0&&ttmLimitMb0) { - // Vulkan budget is authoritative for allocation decisions. - // Log sysfs as a diagnostic side channel only. - continue; - } + // Override vramTotalMb with the kernel-reported usable memory + // so model-fit calculations use the real capacity. + int usedMb=gpu.vramTotalMb-gpu.vramFreeMb; + if(usedMb<0) usedMb=0; - // No Vulkan budget — use sysfs data for GPU-accessible memory. - // Refine VRAM free from sysfs (more accurate than "assume all free"). - if(sysfsVramTotal>0) - { - gpu.vramTotalMb=sysfsVramTotalMb; - gpu.vramFreeMb=sysfsVramTotalMb-sysfsVramUsedMb; - if(gpu.vramFreeMb<0) gpu.vramFreeMb=0; - } + gpu.vramTotalMb=kernelUsableMb; + gpu.vramFreeMb=std::max(0, kernelUsableMb-usedMb); - if(gttTotalBytes>0) - { - // GPU-accessible memory = VRAM + GTT (system RAM mapped to GPU) - gpu.gpuAccessibleRamMb=gpu.vramTotalMb+gttTotalMb; - gpu.gpuAccessibleRamFreeMb=gpu.vramFreeMb+gttFreeMb; + gpu.gpuAccessibleRamMb=kernelUsableMb; + gpu.gpuAccessibleRamFreeMb=gpu.vramFreeMb; spdlog::log(m_firstRefreshDone ? spdlog::level::debug : spdlog::level::info, - "Unified memory GPU {}: {} — sysfs fallback: " - "total accessible {}MB ({}MB free)", + "Unified memory GPU {}: {} — kernel sysfs override: " + "total {}MB ({}MB free)", gpu.index, gpu.name, gpu.gpuAccessibleRamMb, gpu.gpuAccessibleRamFreeMb); } - else + else if(!hasBudgetData) { - // No GTT info — fall back to system RAM + // No kernel sysfs data and no Vulkan budget — fall back to + // system RAM as a last resort. gpu.gpuAccessibleRamMb=m_systemInfo.totalRamMb; gpu.gpuAccessibleRamFreeMb=m_systemInfo.freeRamMb; spdlog::log(m_firstRefreshDone ? spdlog::level::debug : spdlog::level::info, - "Unified memory GPU {}: {} — no GTT info, " + "Unified memory GPU {}: {} — no sysfs or budget data, " "falling back to system RAM ({}MB total, {}MB free)", gpu.index, gpu.name, gpu.gpuAccessibleRamMb, gpu.gpuAccessibleRamFreeMb); diff --git a/src/arbiterAI/modelManager.cpp b/src/arbiterAI/modelManager.cpp index 11f629a..e0de4bf 100644 --- a/src/arbiterAI/modelManager.cpp +++ b/src/arbiterAI/modelManager.cpp @@ -433,6 +433,12 @@ bool ModelManager::parseModelInfo(const nlohmann::json &modelJson, ModelInfo &in } } + // API format (output format conversion: "" = standard openai, "harmony" = harmony-to-openai) + if(modelJson.contains("api_format")&&modelJson["api_format"].is_string()) + { + info.apiFormat=modelJson["api_format"].get(); + } + return true; } @@ -666,6 +672,8 @@ void ModelManager::mergeModelInfo(ModelInfo &existing, const ModelInfo &source, existing.download=source.download; if(sourceJson.contains("version")) existing.configVersion=source.configVersion; + if(sourceJson.contains("api_format")) + existing.apiFormat=source.apiFormat; } bool ModelManager::addModelFromJson(const nlohmann::json &modelJson, std::string &error) @@ -1041,6 +1049,12 @@ nlohmann::json ModelManager::modelInfoToJson(const ModelInfo &info) j["disabled_backends"]=info.disabledBackends; } + // API format + if(!info.apiFormat.empty()) + { + j["api_format"]=info.apiFormat; + } + return j; } diff --git a/src/arbiterAI/modelManager.h b/src/arbiterAI/modelManager.h index 04309a1..ed12b07 100644 --- a/src/arbiterAI/modelManager.h +++ b/src/arbiterAI/modelManager.h @@ -123,6 +123,7 @@ struct ModelInfo RuntimeOptions runtimeOptions; // Per-model llama.cpp runtime options std::vector backendPriority; // Ordered preference: ["vulkan", "rocm", "cuda"] std::vector disabledBackends; // Backends to exclude (model-level override) + std::string apiFormat; // API output format: "" (default/openai) or "harmony" bool isCompatible(const std::string &clientVersion) const; bool isSchemaCompatible(const std::string &schemaVersion) const; diff --git a/src/arbiterAI/modelRuntime.cpp b/src/arbiterAI/modelRuntime.cpp index ccd3d0c..da76092 100644 --- a/src/arbiterAI/modelRuntime.cpp +++ b/src/arbiterAI/modelRuntime.cpp @@ -1331,7 +1331,10 @@ void ModelRuntime::evictIfNeeded(int requiredVramMb, int gpuIndex) void ModelRuntime::beginInference(const std::string &model) { - m_activeInference.insert(model); + { + std::lock_guard lock(m_activeInferenceMutex); + m_activeInference.insert(model); + } std::lock_guard lock(m_mutex); auto it=m_models.find(model); @@ -1354,9 +1357,14 @@ void ModelRuntime::endInference(const std::string &model) } } - m_activeInference.erase(model); + bool shouldDrain=false; + { + std::lock_guard lock(m_activeInferenceMutex); + m_activeInference.erase(model); + shouldDrain=m_activeInference.empty(); + } - if(m_activeInference.empty()) + if(shouldDrain) { drainPendingSwaps(); } @@ -1364,16 +1372,19 @@ void ModelRuntime::endInference(const std::string &model) bool ModelRuntime::isInferenceActive() const { + std::lock_guard lock(m_activeInferenceMutex); return !m_activeInference.empty(); } bool ModelRuntime::isInferenceActive(const std::string &model) const { + std::lock_guard lock(m_activeInferenceMutex); return m_activeInference.count(model)>0; } int ModelRuntime::getActiveInferenceCount() const { + std::lock_guard lock(m_activeInferenceMutex); return static_cast(m_activeInference.size()); } diff --git a/src/arbiterAI/modelRuntime.h b/src/arbiterAI/modelRuntime.h index b238851..10a1b83 100644 --- a/src/arbiterAI/modelRuntime.h +++ b/src/arbiterAI/modelRuntime.h @@ -189,6 +189,9 @@ class ModelRuntime { /// Mark inference as completed on a model and drain pending swaps. void endInference(const std::string &model); + /// Get the inference mutex for serializing llama_context access. + std::timed_mutex &getInferenceMutex() { return m_inferenceMutex; } + /// Check if any inference is currently active. bool isInferenceActive() const; @@ -286,6 +289,8 @@ class ModelRuntime { std::map m_models; mutable std::mutex m_mutex; + mutable std::timed_mutex m_inferenceMutex; // serializes llama_context access + mutable std::mutex m_activeInferenceMutex; // protects m_activeInference std::string m_modelsDir="/models/"; int m_readyRamBudgetMb=0; std::vector m_defaultBackendPriority; diff --git a/src/arbiterAI/providers/baseProvider.cpp b/src/arbiterAI/providers/baseProvider.cpp index d97e04b..983a2c4 100644 --- a/src/arbiterAI/providers/baseProvider.cpp +++ b/src/arbiterAI/providers/baseProvider.cpp @@ -115,4 +115,12 @@ std::vector BaseProvider::batchCompletion(const std::vector< return responses; } +ErrorCode BaseProvider::streamingCompletion(const CompletionRequest &request, + std::function callback, + std::function waitCallback) +{ + // Default: ignore waitCallback, delegate to standard streaming + return streamingCompletion(request, callback); +} + } // namespace arbiterAI diff --git a/src/arbiterAI/providers/baseProvider.h b/src/arbiterAI/providers/baseProvider.h index 0758541..2b80e4a 100644 --- a/src/arbiterAI/providers/baseProvider.h +++ b/src/arbiterAI/providers/baseProvider.h @@ -54,6 +54,18 @@ class BaseProvider virtual ErrorCode streamingCompletion(const CompletionRequest &request, std::function callback) = 0; + /** + * @brief Perform streaming text completion with queue wait notification + * @param request Completion parameters + * @param callback Function to receive streaming chunks + * @param waitCallback Called periodically while waiting for backend availability + * @return ErrorCode indicating success or failure + */ + virtual ErrorCode streamingCompletion(const CompletionRequest &request, + std::function callback, + std::function waitCallback); + + /** * @brief Process multiple completion requests in batch * @param requests Vector of completion requests diff --git a/src/arbiterAI/providers/llama.cpp b/src/arbiterAI/providers/llama.cpp index 384f409..541cf88 100644 --- a/src/arbiterAI/providers/llama.cpp +++ b/src/arbiterAI/providers/llama.cpp @@ -7,8 +7,11 @@ #include #include +#include +#include #include #include +#include namespace arbiterAI { @@ -44,8 +47,30 @@ ErrorCode Llama::completion(const CompletionRequest &request, return ErrorCode::ModelNotLoaded; } + spdlog::info("[llama] completion request for model '{}', waiting for inference lock", request.model); + auto lockWaitStart=std::chrono::steady_clock::now(); + runtime.beginInference(request.model); + // Use timed lock to avoid blocking HTTP threads indefinitely. + // If the lock can't be acquired within 5 minutes, return overloaded. + bool lockAcquired=runtime.getInferenceMutex().try_lock_for(std::chrono::minutes(5)); + if(!lockAcquired) + { + runtime.endInference(request.model); + spdlog::warn("[llama] completion request for model '{}' timed out waiting for inference lock", request.model); + return ErrorCode::ServerOverloaded; + } + std::lock_guard inferenceLock(runtime.getInferenceMutex(), std::adopt_lock); + + auto lockWaitMs=std::chrono::duration( + std::chrono::steady_clock::now()-lockWaitStart).count(); + if(lockWaitMs>100.0) + { + spdlog::warn("[llama] inference lock acquired after {:.1f}ms wait", lockWaitMs); + } + + spdlog::info("[llama] starting inference for model '{}'", request.model); std::chrono::steady_clock::time_point startTime=std::chrono::steady_clock::now(); std::string resultText; @@ -62,33 +87,40 @@ ErrorCode Llama::completion(const CompletionRequest &request, runtime.endInference(request.model); - if(code==ErrorCode::Success) + if(code!=ErrorCode::Success) { - response.text=resultText; - response.provider="llama"; - response.model=request.model; - response.usage.prompt_tokens=promptTokens; - response.usage.completion_tokens=completionTokens; - response.usage.total_tokens=promptTokens+completionTokens; - response.finishReason="stop"; + spdlog::error("[llama] inference failed for model '{}' after {:.1f}ms (error={})", + request.model, totalTimeMs, static_cast(code)); + return code; + } - // Record telemetry - std::optional state=runtime.getModelState(request.model); + spdlog::info("[llama] inference complete: prompt={} tokens ({:.1f}ms), gen={} tokens ({:.1f}ms), total={:.1f}ms", + promptTokens, promptTimeMs, completionTokens, generationTimeMs, totalTimeMs); - InferenceStats stats; - stats.model=request.model; - stats.variant=state?state->variant:""; - stats.promptTokens=promptTokens; - stats.completionTokens=completionTokens; - stats.totalTimeMs=totalTimeMs; - stats.promptTimeMs=promptTimeMs; - stats.generationTimeMs=generationTimeMs; - stats.tokensPerSecond=totalTimeMs>0.0?(completionTokens/(totalTimeMs/1000.0)):0.0; - stats.promptTokensPerSecond=promptTimeMs>0.0?(promptTokens/(promptTimeMs/1000.0)):0.0; - stats.generationTokensPerSecond=generationTimeMs>0.0?(completionTokens/(generationTimeMs/1000.0)):0.0; - stats.timestamp=std::chrono::system_clock::now(); - TelemetryCollector::instance().recordInference(stats); - } + response.text=resultText; + response.provider="llama"; + response.model=request.model; + response.usage.prompt_tokens=promptTokens; + response.usage.completion_tokens=completionTokens; + response.usage.total_tokens=promptTokens+completionTokens; + response.finishReason="stop"; + + // Record telemetry + std::optional state=runtime.getModelState(request.model); + + InferenceStats stats; + stats.model=request.model; + stats.variant=state?state->variant:""; + stats.promptTokens=promptTokens; + stats.completionTokens=completionTokens; + stats.totalTimeMs=totalTimeMs; + stats.promptTimeMs=promptTimeMs; + stats.generationTimeMs=generationTimeMs; + stats.tokensPerSecond=totalTimeMs>0.0?(completionTokens/(totalTimeMs/1000.0)):0.0; + stats.promptTokensPerSecond=promptTimeMs>0.0?(promptTokens/(promptTimeMs/1000.0)):0.0; + stats.generationTokensPerSecond=generationTimeMs>0.0?(completionTokens/(generationTimeMs/1000.0)):0.0; + stats.timestamp=std::chrono::system_clock::now(); + TelemetryCollector::instance().recordInference(stats); return code; } @@ -119,8 +151,28 @@ ErrorCode Llama::streamingCompletion(const CompletionRequest &request, return ErrorCode::ModelNotFound; } + spdlog::info("[llama] streaming completion request for model '{}', waiting for inference lock", request.model); + auto lockWaitStart=std::chrono::steady_clock::now(); + runtime.beginInference(request.model); + bool lockAcquired=runtime.getInferenceMutex().try_lock_for(std::chrono::minutes(5)); + if(!lockAcquired) + { + runtime.endInference(request.model); + spdlog::warn("[llama] streaming completion request for model '{}' timed out waiting for inference lock", request.model); + return ErrorCode::ServerOverloaded; + } + std::lock_guard inferenceLock(runtime.getInferenceMutex(), std::adopt_lock); + + auto lockWaitMs=std::chrono::duration( + std::chrono::steady_clock::now()-lockWaitStart).count(); + if(lockWaitMs>100.0) + { + spdlog::warn("[llama] streaming inference lock acquired after {:.1f}ms wait", lockWaitMs); + } + + spdlog::info("[llama] starting streaming inference for model '{}'", request.model); std::chrono::steady_clock::time_point startTime=std::chrono::steady_clock::now(); std::string resultText; @@ -137,7 +189,16 @@ ErrorCode Llama::streamingCompletion(const CompletionRequest &request, runtime.endInference(request.model); - if(code==ErrorCode::Success) + if(code!=ErrorCode::Success) + { + spdlog::error("[llama] streaming inference failed for model '{}' after {:.1f}ms (error={})", + request.model, totalTimeMs, static_cast(code)); + return code; + } + + spdlog::info("[llama] streaming inference complete: prompt={} tokens ({:.1f}ms), gen={} tokens ({:.1f}ms), total={:.1f}ms", + promptTokens, promptTimeMs, completionTokens, generationTimeMs, totalTimeMs); + { std::optional state=runtime.getModelState(request.model); @@ -178,6 +239,8 @@ ErrorCode Llama::getEmbeddings(const EmbeddingRequest &request, return ErrorCode::ModelNotLoaded; } + std::lock_guard inferenceLock(runtime.getInferenceMutex()); + // Combine input text std::string inputText; std::visit([&inputText](auto &&arg) @@ -350,6 +413,209 @@ std::string Llama::applyTemplate(llama_model *model, return result; } +std::string Llama::formatHarmonyPrompt(const CompletionRequest &request, + const ModelInfo &modelInfo) const +{ + std::string prompt; + bool hasTools=request.tools.has_value()&&!request.tools->empty(); + + // Build system message + std::string systemContent="You are ChatGPT, a large language model trained by OpenAI.\n" + "Knowledge cutoff: 2024-06\n" + "Current date: 2025-06-28\n" + "\n" + "Reasoning: high\n" + "\n" + "# Valid channels: analysis, commentary, final. Channel must be included for every message."; + + if(hasTools) + { + systemContent+="\nCalls to these tools must go to the commentary channel: 'functions'."; + } + + prompt+="<|start|>system<|message|>"+systemContent+"<|end|>"; + + // Build developer message from the first system-role message (if any) + // and tool definitions + std::string developerContent; + bool hasInstructions=false; + + for(const Message &msg:request.messages) + { + if(msg.role=="system") + { + if(!developerContent.empty()) + developerContent+="\n\n"; + developerContent+="# Instructions\n\n"+msg.content; + hasInstructions=true; + } + } + + if(hasTools) + { + if(!developerContent.empty()) + developerContent+="\n\n"; + else + developerContent+="# Instructions\n\nYou are a helpful assistant.\n\n"; + + developerContent+="# Tools\n\n## functions\n\nnamespace functions {\n"; + + for(const ToolDefinition &tool:*request.tools) + { + if(!tool.description.empty()) + developerContent+="\n// "+tool.description+"\n"; + else + developerContent+="\n"; + + if(tool.parametersSchema.is_object()&&tool.parametersSchema.contains("properties") + &&!tool.parametersSchema["properties"].empty()) + { + developerContent+="type "+tool.name+" = (_: {\n"; + + const nlohmann::json &props=tool.parametersSchema["properties"]; + std::vector required; + if(tool.parametersSchema.contains("required")&&tool.parametersSchema["required"].is_array()) + { + for(const nlohmann::json &r:tool.parametersSchema["required"]) + { + if(r.is_string()) + required.push_back(r.get()); + } + } + + for(auto it=props.begin(); it!=props.end(); ++it) + { + std::string paramName=it.key(); + const nlohmann::json ¶mDef=it.value(); + + // Add description as comment + if(paramDef.contains("description")) + developerContent+="// "+paramDef["description"].get()+"\n"; + + bool isRequired=std::find(required.begin(), required.end(), paramName)!=required.end(); + + // Determine type string + std::string typeStr="any"; + if(paramDef.contains("type")) + { + std::string jsonType=paramDef["type"].get(); + if(jsonType=="string") + { + if(paramDef.contains("enum")) + { + typeStr=""; + for(size_t i=0; i0) typeStr+=" | "; + typeStr+="\""+paramDef["enum"][i].get()+"\""; + } + } + else + { + typeStr="string"; + } + } + else if(jsonType=="integer"||jsonType=="number") + typeStr="number"; + else if(jsonType=="boolean") + typeStr="boolean"; + else if(jsonType=="array") + { + if(paramDef.contains("items")&¶mDef["items"].contains("type")) + { + std::string itemType=paramDef["items"]["type"].get(); + if(itemType=="string") typeStr="string[]"; + else if(itemType=="number"||itemType=="integer") typeStr="number[]"; + else typeStr="any[]"; + } + else + { + typeStr="any[]"; + } + } + } + + developerContent+=paramName+(isRequired?": ":"?: ")+typeStr+","; + + // Add default as inline comment + if(paramDef.contains("default")) + { + developerContent+=" // default: "+paramDef["default"].dump(); + } + developerContent+="\n"; + } + + developerContent+="}) => any;\n"; + } + else + { + developerContent+="type "+tool.name+" = () => any;\n"; + } + } + + developerContent+="\n} // namespace functions"; + } + + if(!developerContent.empty()) + { + prompt+="<|start|>developer<|message|>"+developerContent+"<|end|>"; + } + + // Format conversation messages (skip system messages, already handled above) + for(const Message &msg:request.messages) + { + if(msg.role=="system") + continue; + + if(msg.role=="user") + { + prompt+="<|start|>user<|message|>"+msg.content+"<|end|>"; + } + else if(msg.role=="assistant") + { + if(msg.toolCalls.has_value()&&!msg.toolCalls->empty()) + { + // Assistant message with tool calls — recreate harmony format + // First the content/reasoning if any + if(!msg.content.empty()) + { + prompt+="<|start|>assistant<|channel|>final<|message|>"+msg.content+"<|end|>"; + } + + // Then each tool call + for(const ToolCall &tc:*msg.toolCalls) + { + std::string argsStr; + if(tc.arguments.is_string()) + argsStr=tc.arguments.get(); + else + argsStr=tc.arguments.dump(); + + prompt+="<|start|>assistant<|channel|>commentary to=functions."+tc.name + +" <|constrain|>json<|message|>"+argsStr+"<|call|>"; + } + } + else + { + // Regular assistant message — use final channel + prompt+="<|start|>assistant<|channel|>final<|message|>"+msg.content+"<|end|>"; + } + } + else if(msg.role=="tool") + { + // Tool result message + std::string toolName=msg.name.value_or("unknown"); + prompt+="<|start|>functions."+toolName+" to=assistant<|channel|>commentary<|message|>" + +msg.content+"<|end|>"; + } + } + + // Prompt the assistant to start generating + prompt+="<|start|>assistant"; + + return prompt; +} + ErrorCode Llama::runInference(llama_model *model, llama_context *ctx, const CompletionRequest &request, const ModelInfo &modelInfo, std::string &result, int &promptTokens, int &completionTokens, @@ -357,20 +623,29 @@ ErrorCode Llama::runInference(llama_model *model, llama_context *ctx, std::function streamCallback) { const llama_vocab *vocab=llama_model_get_vocab(model); + bool harmonyMode=(modelInfo.apiFormat=="harmony"); // Apply chat template to format messages properly - std::string prompt=applyTemplate(model, request.messages); + std::string prompt; + if(harmonyMode) + { + prompt=formatHarmonyPrompt(request, modelInfo); + } + else + { + prompt=applyTemplate(model, request.messages); + } - // Tokenize the formatted prompt + // Tokenize the formatted prompt — use special token parsing for harmony std::vector tokensList(prompt.size()+256); int nTokens=llama_tokenize(vocab, prompt.c_str(), prompt.length(), - tokensList.data(), tokensList.size(), true, false); + tokensList.data(), tokensList.size(), true, harmonyMode); if(nTokens<0) { // Buffer too small, resize and retry tokensList.resize(-nTokens); nTokens=llama_tokenize(vocab, prompt.c_str(), prompt.length(), - tokensList.data(), tokensList.size(), true, false); + tokensList.data(), tokensList.size(), true, harmonyMode); if(nTokens<0) { spdlog::error("Failed to tokenize prompt"); @@ -381,6 +656,7 @@ ErrorCode Llama::runInference(llama_model *model, llama_context *ctx, promptTokens=nTokens; // Clear KV cache for fresh inference + spdlog::debug("[llama] clearing KV cache, prompt tokens={}", nTokens); llama_memory_clear(llama_get_memory(ctx), true); int nBatch=static_cast(llama_n_batch(ctx)); @@ -409,9 +685,11 @@ ErrorCode Llama::runInference(llama_model *model, llama_context *ctx, batch.logits[chunkSize-1]=1; } - if(llama_decode(ctx, batch)!=0) + int decodeResult=llama_decode(ctx, batch); + if(decodeResult!=0) { - spdlog::error("llama_decode failed during prompt processing (chunk at offset {})", start); + spdlog::error("[llama] llama_decode failed during prompt processing (chunk at offset {}, chunkSize={}, totalTokens={}, result={})", + start, chunkSize, nTokens, decodeResult); llama_batch_free(batch); return ErrorCode::GenerationError; } @@ -421,6 +699,17 @@ ErrorCode Llama::runInference(llama_model *model, llama_context *ctx, promptTimeMs=std::chrono::duration(promptEnd-promptStart).count(); int maxOutputTokens=request.max_tokens.value_or(modelInfo.maxOutputTokens); + + // Harmony mode models use internal reasoning tokens that are hidden from the client. + // The client's max_tokens should apply to visible output, so we need extra headroom + // for the analysis channel. Apply a multiplier to ensure the model can complete both + // reasoning and the final response. + if(harmonyMode) + { + int minHarmonyTokens=std::max(maxOutputTokens*8, 16384); + maxOutputTokens=std::min(minHarmonyTokens, modelInfo.maxOutputTokens>0?modelInfo.maxOutputTokens:131072); + } + int nCur=nTokens; completionTokens=0; @@ -453,20 +742,61 @@ ErrorCode Llama::runInference(llama_model *model, llama_context *ctx, // Generation loop (timed) std::chrono::steady_clock::time_point genStart=std::chrono::steady_clock::now(); + // For harmony mode, look up special stop token IDs + llama_token harmonyCallToken=-1; + llama_token harmonyReturnToken=-1; + if(harmonyMode) + { + // Try to find <|call|> and <|return|> tokens by tokenizing them + llama_token buf[4]; + int n; + + n=llama_tokenize(vocab, "<|call|>", 8, buf, 4, false, true); + if(n==1) harmonyCallToken=buf[0]; + + n=llama_tokenize(vocab, "<|return|>", 10, buf, 4, false, true); + if(n==1) harmonyReturnToken=buf[0]; + + spdlog::debug("[llama] harmony stop tokens: <|call|>={}, <|return|>={}", + harmonyCallToken, harmonyReturnToken); + } + for(int i=0; i and <|call|> are marked as EOG in the vocab + // but we need to handle them specially in harmony mode. + if(harmonyMode) + { + if(nextToken==harmonyCallToken||nextToken==harmonyReturnToken) + { + // Append the special token text so the output parser can detect it + if(nextToken==harmonyCallToken) + { + result+="<|call|>"; + if(streamCallback) streamCallback("<|call|>"); + } + completionTokens++; + break; + } + } + + // Check for end of sequence (skip harmony-handled tokens) if(llama_vocab_is_eog(vocab, nextToken)) { + if(harmonyMode) + { + spdlog::info("[llama] harmony EOG hit: token={}", nextToken); + } break; } - // Convert token to text - char piece[64]; - int len=llama_token_to_piece(vocab, nextToken, piece, sizeof(piece), 0, false); + // Convert token to text — use special=true for harmony to preserve special token text + char piece[128]; + int len=llama_token_to_piece(vocab, nextToken, piece, sizeof(piece), 0, harmonyMode); if(len>0) { std::string tokenText(piece, len); @@ -509,9 +839,11 @@ ErrorCode Llama::runInference(llama_model *model, llama_context *ctx, batch.logits[0]=1; nCur++; - if(llama_decode(ctx, batch)!=0) + int decodeResult=llama_decode(ctx, batch); + if(decodeResult!=0) { - spdlog::error("llama_decode failed during generation"); + spdlog::error("[llama] llama_decode failed during generation (token #{}, pos={}, result={})", + i, nCur-1, decodeResult); llama_sampler_free(samplerChain); llama_batch_free(batch); return ErrorCode::GenerationError; @@ -527,4 +859,346 @@ ErrorCode Llama::runInference(llama_model *model, llama_context *ctx, return ErrorCode::Success; } +ErrorCode Llama::tokenizePrompt(llama_model *model, + const CompletionRequest &request, const ModelInfo &modelInfo, + std::vector &tokens, std::string &formattedPrompt) +{ + const llama_vocab *vocab=llama_model_get_vocab(model); + bool harmonyMode=(modelInfo.apiFormat=="harmony"); + + if(harmonyMode) + { + formattedPrompt=formatHarmonyPrompt(request, modelInfo); + } + else + { + formattedPrompt=applyTemplate(model, request.messages); + } + + tokens.resize(formattedPrompt.size()+256); + int nTokens=llama_tokenize(vocab, formattedPrompt.c_str(), formattedPrompt.length(), + tokens.data(), tokens.size(), true, harmonyMode); + if(nTokens<0) + { + tokens.resize(-nTokens); + nTokens=llama_tokenize(vocab, formattedPrompt.c_str(), formattedPrompt.length(), + tokens.data(), tokens.size(), true, harmonyMode); + if(nTokens<0) + { + spdlog::error("Failed to tokenize prompt"); + return ErrorCode::GenerationError; + } + } + tokens.resize(nTokens); + return ErrorCode::Success; +} + +ErrorCode Llama::runInferenceWithTokens(llama_model *model, llama_context *ctx, + const CompletionRequest &request, const ModelInfo &modelInfo, + const std::vector &promptTokens, + std::string &result, int &promptTokenCount, int &completionTokens, + double &promptTimeMs, double &generationTimeMs, + std::function streamCallback) +{ + const llama_vocab *vocab=llama_model_get_vocab(model); + bool harmonyMode=(modelInfo.apiFormat=="harmony"); + + int nTokens=static_cast(promptTokens.size()); + promptTokenCount=nTokens; + + spdlog::debug("[llama] clearing KV cache, prompt tokens={}", nTokens); + llama_memory_clear(llama_get_memory(ctx), true); + + int nBatch=static_cast(llama_n_batch(ctx)); + llama_batch batch=llama_batch_init(std::max(nBatch, 512), 0, 1); + + std::chrono::steady_clock::time_point promptStart=std::chrono::steady_clock::now(); + + for(int start=0; start=nTokens); + + batch.n_tokens=chunkSize; + for(int32_t i=0; i(promptEnd-promptStart).count(); + + int maxOutputTokens=request.max_tokens.value_or(modelInfo.maxOutputTokens); + + if(harmonyMode) + { + int minHarmonyTokens=std::max(maxOutputTokens*8, 16384); + maxOutputTokens=std::min(minHarmonyTokens, modelInfo.maxOutputTokens>0?modelInfo.maxOutputTokens:131072); + } + + int nCur=nTokens; + completionTokens=0; + + llama_sampler_chain_params samplerParams=llama_sampler_chain_default_params(); + llama_sampler *samplerChain=llama_sampler_chain_init(samplerParams); + + llama_sampler_chain_add(samplerChain, llama_sampler_init_penalties( + -1, + 1.0f, + request.frequency_penalty.value_or(0.0f), + request.presence_penalty.value_or(0.0f))); + + if(request.top_p.has_value()) + { + llama_sampler_chain_add(samplerChain, llama_sampler_init_top_p(*request.top_p, 1)); + } + if(request.temperature.has_value()&&*request.temperature>0.0) + { + llama_sampler_chain_add(samplerChain, llama_sampler_init_temp(*request.temperature)); + } + llama_sampler_chain_add(samplerChain, llama_sampler_init_greedy()); + + for(const llama_token &token:promptTokens) + { + llama_sampler_accept(samplerChain, token); + } + + std::chrono::steady_clock::time_point genStart=std::chrono::steady_clock::now(); + + llama_token harmonyCallToken=-1; + llama_token harmonyReturnToken=-1; + if(harmonyMode) + { + llama_token buf[4]; + int n; + + n=llama_tokenize(vocab, "<|call|>", 8, buf, 4, false, true); + if(n==1) harmonyCallToken=buf[0]; + + n=llama_tokenize(vocab, "<|return|>", 10, buf, 4, false, true); + if(n==1) harmonyReturnToken=buf[0]; + + spdlog::debug("[llama] harmony stop tokens: <|call|>={}, <|return|>={}", + harmonyCallToken, harmonyReturnToken); + } + + for(int i=0; i"); + } + completionTokens++; + break; + } + } + + if(llama_vocab_is_eog(vocab, nextToken)) + { + if(harmonyMode) + { + spdlog::info("[llama] harmony EOG hit: token={}", nextToken); + } + break; + } + + char piece[128]; + int len=llama_token_to_piece(vocab, nextToken, piece, sizeof(piece), 0, harmonyMode); + if(len>0) + { + std::string tokenText(piece, len); + result+=tokenText; + completionTokens++; + + if(streamCallback) + { + streamCallback(tokenText); + } + } + + if(request.stop.has_value()) + { + bool stopFound=false; + for(const std::string &stopWord:*request.stop) + { + if(result.size()>=stopWord.size()&& + result.substr(result.size()-stopWord.size())==stopWord) + { + result.resize(result.size()-stopWord.size()); + stopFound=true; + break; + } + } + if(stopFound) + { + break; + } + } + + batch.n_tokens=1; + batch.token[0]=nextToken; + batch.pos[0]=nCur; + batch.n_seq_id[0]=1; + batch.seq_id[0][0]=0; + batch.logits[0]=1; + nCur++; + + int decodeResult=llama_decode(ctx, batch); + if(decodeResult!=0) + { + spdlog::error("[llama] llama_decode failed during generation (token #{}, pos={}, result={})", + i, nCur-1, decodeResult); + llama_sampler_free(samplerChain); + llama_batch_free(batch); + return ErrorCode::GenerationError; + } + } + + std::chrono::steady_clock::time_point genEnd=std::chrono::steady_clock::now(); + generationTimeMs=std::chrono::duration(genEnd-genStart).count(); + + llama_sampler_free(samplerChain); + llama_batch_free(batch); + + return ErrorCode::Success; +} + +ErrorCode Llama::streamingCompletion(const CompletionRequest &request, + std::function callback, + std::function waitCallback) +{ + ModelRuntime &runtime=ModelRuntime::instance(); + + ErrorCode loadResult=runtime.loadModel(request.model); + if(loadResult!=ErrorCode::Success) + { + return loadResult; + } + + llama_model *llamaModel=runtime.getLlamaModel(request.model); + llama_context *llamaCtx=runtime.getLlamaContext(request.model); + + if(!llamaModel||!llamaCtx) + { + spdlog::error("Llama model handles not available for: {}", request.model); + return ErrorCode::ModelNotLoaded; + } + + std::optional modelInfo=runtime.getLoadedModelInfo(request.model); + if(!modelInfo) + { + return ErrorCode::ModelNotFound; + } + + // Tokenize prompt BEFORE acquiring the inference lock. + // applyTemplate/formatHarmonyPrompt and llama_tokenize only need + // llama_model/llama_vocab (read-only), not llama_context. + spdlog::info("[llama] streaming: pre-tokenizing prompt for model '{}'", request.model); + std::vector tokens; + std::string formattedPrompt; + ErrorCode tokenizeResult=tokenizePrompt(llamaModel, request, *modelInfo, tokens, formattedPrompt); + if(tokenizeResult!=ErrorCode::Success) + { + return tokenizeResult; + } + spdlog::info("[llama] streaming: tokenized {} tokens, waiting for inference lock", tokens.size()); + + auto lockWaitStart=std::chrono::steady_clock::now(); + + runtime.beginInference(request.model); + + // Use timed try-lock loop so we can call waitCallback periodically + // while another request holds the lock. + bool lockAcquired=false; + while(!lockAcquired) + { + lockAcquired=runtime.getInferenceMutex().try_lock_for(std::chrono::milliseconds(500)); + if(!lockAcquired&&waitCallback) + { + waitCallback(); + } + } + std::lock_guard inferenceLock(runtime.getInferenceMutex(), std::adopt_lock); + + auto lockWaitMs=std::chrono::duration( + std::chrono::steady_clock::now()-lockWaitStart).count(); + if(lockWaitMs>100.0) + { + spdlog::warn("[llama] streaming inference lock acquired after {:.1f}ms wait", lockWaitMs); + } + + spdlog::info("[llama] starting streaming inference for model '{}' ({} prompt tokens pre-tokenized)", request.model, tokens.size()); + std::chrono::steady_clock::time_point startTime=std::chrono::steady_clock::now(); + + std::string resultText; + int promptTokens=0; + int completionTokens=0; + double promptTimeMs=0.0; + double generationTimeMs=0.0; + + ErrorCode code=runInferenceWithTokens(llamaModel, llamaCtx, request, *modelInfo, + tokens, resultText, promptTokens, completionTokens, promptTimeMs, generationTimeMs, callback); + + std::chrono::steady_clock::time_point endTime=std::chrono::steady_clock::now(); + double totalTimeMs=std::chrono::duration(endTime-startTime).count(); + + runtime.endInference(request.model); + + if(code!=ErrorCode::Success) + { + spdlog::error("[llama] streaming inference failed for model '{}' after {:.1f}ms (error={})", + request.model, totalTimeMs, static_cast(code)); + return code; + } + + spdlog::info("[llama] streaming inference complete: prompt={} tokens ({:.1f}ms), gen={} tokens ({:.1f}ms), total={:.1f}ms", + promptTokens, promptTimeMs, completionTokens, generationTimeMs, totalTimeMs); + + { + std::optional state=runtime.getModelState(request.model); + + InferenceStats stats; + stats.model=request.model; + stats.variant=state?state->variant:""; + stats.promptTokens=promptTokens; + stats.completionTokens=completionTokens; + stats.totalTimeMs=totalTimeMs; + stats.promptTimeMs=promptTimeMs; + stats.generationTimeMs=generationTimeMs; + stats.tokensPerSecond=totalTimeMs>0.0?(completionTokens/(totalTimeMs/1000.0)):0.0; + stats.promptTokensPerSecond=promptTimeMs>0.0?(promptTokens/(promptTimeMs/1000.0)):0.0; + stats.generationTokensPerSecond=generationTimeMs>0.0?(completionTokens/(generationTimeMs/1000.0)):0.0; + stats.timestamp=std::chrono::system_clock::now(); + TelemetryCollector::instance().recordInference(stats); + } + + return code; +} + } // namespace arbiterAI diff --git a/src/arbiterAI/providers/llama.h b/src/arbiterAI/providers/llama.h index 6366271..26e3af9 100644 --- a/src/arbiterAI/providers/llama.h +++ b/src/arbiterAI/providers/llama.h @@ -6,6 +6,7 @@ #include #include #include +#include // Forward declarations for llama.cpp types struct llama_model; @@ -26,6 +27,10 @@ class Llama : public BaseProvider { ErrorCode streamingCompletion(const CompletionRequest &request, std::function callback) override; + ErrorCode streamingCompletion(const CompletionRequest &request, + std::function callback, + std::function waitCallback) override; + ErrorCode getEmbeddings(const EmbeddingRequest &request, EmbeddingResponse &response) override; @@ -39,12 +44,30 @@ class Llama : public BaseProvider { std::string applyTemplate(llama_model *model, const std::vector &messages) const; + /// Format messages into harmony special token format for gpt-oss models. + std::string formatHarmonyPrompt(const CompletionRequest &request, + const ModelInfo &modelInfo) const; + + /// Tokenize the prompt outside of the inference mutex. + /// Returns the formatted prompt tokens ready for decode. + ErrorCode tokenizePrompt(llama_model *model, + const CompletionRequest &request, const ModelInfo &modelInfo, + std::vector &tokens, std::string &formattedPrompt); + /// Run the inference loop (shared by completion and streaming). ErrorCode runInference(llama_model *model, llama_context *ctx, const CompletionRequest &request, const ModelInfo &modelInfo, std::string &result, int &promptTokens, int &completionTokens, double &promptTimeMs, double &generationTimeMs, std::function streamCallback); + + /// Run inference with pre-tokenized prompt (avoids re-tokenizing under lock). + ErrorCode runInferenceWithTokens(llama_model *model, llama_context *ctx, + const CompletionRequest &request, const ModelInfo &modelInfo, + const std::vector &promptTokens, + std::string &result, int &promptTokenCount, int &completionTokens, + double &promptTimeMs, double &generationTimeMs, + std::function streamCallback); }; } // namespace arbiterAI diff --git a/src/server/dashboard.h b/src/server/dashboard.h index a6469cd..7a27de3 100644 --- a/src/server/dashboard.h +++ b/src/server/dashboard.h @@ -1553,17 +1553,17 @@ function renderDownloadProgress(downloads) el.innerHTML=html; } -function renderActiveRequests(history) +function renderActiveRequests(history, activeCount) { const el=document.getElementById("activeRequestTable"); - if(!history||history.length===0) + if((!history||history.length===0)&&!activeCount) { el.innerHTML='No recent requests'; return; } - const recent=history.slice(-20).reverse(); + const recent=history?history.slice(-20).reverse():[]; let html=""; for(const s of recent) { @@ -1571,11 +1571,10 @@ function renderActiveRequests(history) const genTps=s.generation_tokens_per_second||0; const totalMs=s.total_time_ms||0; const latencyMs=s.latency_ms||0; - const isActive=(totalMs===0&&latencyMs===0); html+=` ${s.model} - ${isActive?"Running":"Done"} + Done ${s.prompt_tokens.toLocaleString()} ${s.completion_tokens.toLocaleString()} ${promptTps.toFixed(1)} @@ -1634,11 +1633,11 @@ async function refresh() document.getElementById("cpuBar").style.width=stats.hardware.cpu_utilization_percent.toFixed(1)+"%"; } - // TPS chart - if(stats.avg_prompt_tokens_per_second!==undefined||stats.avg_generation_tokens_per_second!==undefined) + // TPS chart — show 0 when no active inference (avoid stale rolling averages) { - promptTpsHistory.push(stats.avg_prompt_tokens_per_second||0); - genTpsHistory.push(stats.avg_generation_tokens_per_second||0); + const idle=(stats.active_requests||0)===0; + promptTpsHistory.push(idle?0:(stats.avg_prompt_tokens_per_second||0)); + genTpsHistory.push(idle?0:(stats.avg_generation_tokens_per_second||0)); if(promptTpsHistory.length>MAX_TPS_POINTS) promptTpsHistory.shift(); if(genTpsHistory.length>MAX_TPS_POINTS) genTpsHistory.shift(); updateTpsChart(); @@ -1655,7 +1654,7 @@ async function refresh() if(history) renderInferences(history); // Active requests summary - if(history) renderActiveRequests(history); + if(history) renderActiveRequests(history, stats.active_requests||0); // Swaps if(swaps) renderSwaps(swaps); diff --git a/src/server/dashboardConfig.h b/src/server/dashboardConfig.h index 6eb5dc8..2674f64 100644 --- a/src/server/dashboardConfig.h +++ b/src/server/dashboardConfig.h @@ -1430,7 +1430,8 @@ function addStartupModel() variant: '', context_size: 0, runtime_options: {}, - devices: [] + devices: [], + api_format: '' }); renderStartupModels(); } @@ -1472,6 +1473,15 @@ function updateStartupModelField(index, field, value) { startupModelsState[index].context_size=opt.effective_context_size; } + // Set default api_format from model options + if(opt&&opt.api_format) + { + startupModelsState[index].api_format=opt.api_format; + } + else + { + startupModelsState[index].api_format=''; + } renderStartupModels(); } } @@ -1650,6 +1660,9 @@ function renderStartupModels() // Runtime options const ro=entry.runtime_options||{}; + // API format (from model catalog or overridden) + const currentApiFormat=entry.api_format||(modelOpt&&modelOpt.api_format?modelOpt.api_format:'')||''; + html+='
' +'
' +'Model '+(i+1)+'' @@ -1677,6 +1690,14 @@ function renderStartupModels() +'
'+devicesHtml+'
' +(entry.devices.length===0?'
No devices selected — auto-assignment will be used.
':'') +'
' + +'
' + +'' + +'' + +'
When set to Harmony, the server parses channel tags from model output and converts to standard OpenAI format.
' + +'
' +'
' +'' +'
' @@ -1714,7 +1735,8 @@ async function saveAllStartupModels() variant: e.variant||'', context_size: e.context_size||0, runtime_options: e.runtime_options||{}, - devices: e.devices||[] + devices: e.devices||[], + api_format: e.api_format||'' })); try @@ -1759,8 +1781,26 @@ async function saveStartupModelEntry(index) // Collect this entry's runtime opts from UI entry.runtime_options=readStartupModelRuntimeOpts(index); + // Read api_format from UI + const apiFormatEl=document.getElementById('smApiFormat_'+index); + if(apiFormatEl) entry.api_format=apiFormatEl.value||''; + showModelStatus(index, 'Saving config...', ''); + // Update model catalog api_format if changed + if(entry.model&&entry.api_format!==undefined) + { + try + { + await fetch('/api/models/config', { + method: 'PUT', + headers: {'Content-Type': 'application/json'}, + body: JSON.stringify({model: entry.model, api_format: entry.api_format}) + }); + } + catch(e) {} + } + // Save all models to config const saved=await saveAllStartupModels(); if(!saved) @@ -1861,7 +1901,8 @@ async function initializePage() variant: e.variant||'', context_size: e.context_size||0, runtime_options: e.runtime_options||{}, - devices: e.devices||[] + devices: e.devices||[], + api_format: e.api_format||'' })); } else diff --git a/src/server/main.cpp b/src/server/main.cpp index 3926c6f..dbfee12 100644 --- a/src/server/main.cpp +++ b/src/server/main.cpp @@ -20,6 +20,8 @@ #include #include #include +#include +#include namespace { @@ -574,6 +576,18 @@ int main(int argc, char *argv[]) logger->set_level(spdlog::get_level()); spdlog::set_default_logger(logger); + // Install crash signal handlers for better diagnostics + auto crashHandler=[](int sig) + { + spdlog::critical("FATAL SIGNAL {} ({}) received — aborting", + sig, (sig==SIGABRT?"SIGABRT":sig==SIGSEGV?"SIGSEGV":"OTHER")); + spdlog::default_logger()->flush(); + std::signal(sig, SIG_DFL); + std::raise(sig); + }; + std::signal(SIGABRT, crashHandler); + std::signal(SIGSEGV, crashHandler); + spdlog::info("Loaded config from: {}", configPath); if(!logDir.empty()) diff --git a/src/server/routes.cpp b/src/server/routes.cpp index f2e5880..023a5d8 100644 --- a/src/server/routes.cpp +++ b/src/server/routes.cpp @@ -351,7 +351,8 @@ nlohmann::json buildStartupOptionJson( {"compatibility", "cloud"}, {"compatibility_label", "Cloud"}, {"compatibility_reason", "Provider-managed model; no local download or VRAM requirement."}, - {"sort_rank", startupCompatibilitySortRank("cloud")} + {"sort_rank", startupCompatibilitySortRank("cloud")}, + {"api_format", model.apiFormat} }; if(model.variants.empty()) @@ -967,6 +968,7 @@ std::string errorCodeToString(ErrorCode code) case ErrorCode::NotImplemented: return "not_implemented"; case ErrorCode::GenerationError: return "generation_error"; case ErrorCode::ApiKeyNotFound: return "api_key_not_found"; + case ErrorCode::ServerOverloaded: return "server_overloaded"; default: return "unknown_error"; } } @@ -998,6 +1000,396 @@ std::pair parseModelVariant(const std::string &modelId return {modelId, ""}; } +/// A single tool call extracted from harmony output. +struct HarmonyToolCall { + std::string name; // Function name (e.g. "get_current_weather") + std::string arguments; // JSON arguments string +}; + +/// Result of parsing harmony format text into separate channels. +struct HarmonyParseResult { + std::string content; // "final" channel → assistant content + std::string reasoningContent; // "analysis" channel → reasoning/thinking + std::vector toolCalls; // tool calls from commentary channel + bool hasToolCall=false; // Whether output ended with <|call|> +}; + +/// Parse the header portion of a harmony message to extract channel and recipient. +/// Header format: "channel_name" or "channel_name to=recipient" +/// Also handles <|constrain|> before <|message|>. +struct HarmonyHeader { + std::string channel; + std::string recipient; // e.g. "functions.get_current_weather" +}; + +HarmonyHeader parseHarmonyHeader(const std::string &headerStr) +{ + HarmonyHeader header; + std::string str=headerStr; + + // Strip trailing <|constrain|>... portion (everything after space+<|constrain|>) + size_t constrainPos=str.find("<|constrain|>"); + if(constrainPos!=std::string::npos) + { + str=str.substr(0, constrainPos); + } + + // Trim trailing whitespace + while(!str.empty()&&(str.back()==' '||str.back()=='\t')) + str.pop_back(); + + // Check for "to=" in the header + size_t toPos=str.find(" to="); + if(toPos!=std::string::npos) + { + header.channel=str.substr(0, toPos); + header.recipient=str.substr(toPos+4); + } + else + { + // Check if recipient is specified without space: "commentary to=functions.x" + toPos=str.find("to="); + if(toPos!=std::string::npos&&toPos>0) + { + header.channel=str.substr(0, toPos); + // Trim trailing space from channel + while(!header.channel.empty()&&header.channel.back()==' ') + header.channel.pop_back(); + header.recipient=str.substr(toPos+3); + } + else + { + header.channel=str; + } + } + + return header; +} + +/// Parse harmony format output into separate content channels and tool calls. +/// Harmony format uses tags like: +/// <|channel|>analysis<|message|>...reasoning...<|end|> +/// <|channel|>final<|message|>...response...<|return|> +/// <|channel|>commentary to=functions.name <|constrain|>json<|message|>{"args"}<|call|> +HarmonyParseResult parseHarmonyFormat(const std::string &text) +{ + HarmonyParseResult result; + std::string remaining=text; + + // Check if output ended with <|call|> + if(remaining.size()>=8&&remaining.substr(remaining.size()-8)=="<|call|>") + { + result.hasToolCall=true; + remaining=remaining.substr(0, remaining.size()-8); + } + + // Parse all channel blocks + while(!remaining.empty()) + { + // Find channel tag + size_t channelPos=remaining.find("<|channel|>"); + if(channelPos==std::string::npos) + { + // No more channels; if we haven't extracted anything yet, treat as plain content + if(result.content.empty()&&result.reasoningContent.empty()&&result.toolCalls.empty()) + { + result.content=remaining; + } + break; + } + + size_t channelNameStart=channelPos+11; // length of "<|channel|>" + size_t messagePos=remaining.find("<|message|>", channelNameStart); + if(messagePos==std::string::npos) + { + break; + } + + // Extract header between <|channel|> and <|message|> + std::string headerStr=remaining.substr(channelNameStart, messagePos-channelNameStart); + HarmonyHeader header=parseHarmonyHeader(headerStr); + + size_t contentStart=messagePos+11; // length of "<|message|>" + + // Find end of this message block — could be <|end|>, <|call|>, or next <|start|> + size_t endPos=remaining.find("<|end|>", contentStart); + size_t callPos=remaining.find("<|call|>", contentStart); + size_t nextStartPos=remaining.find("<|start|>", contentStart); + + std::string messageContent; + size_t nextBlockStart; + bool isCallEnd=false; + + // Find the nearest end marker + size_t nearestEnd=std::string::npos; + if(endPos!=std::string::npos) nearestEnd=endPos; + if(callPos!=std::string::npos&&(nearestEnd==std::string::npos||callPos" + else + nextBlockStart=nearestEnd+7; // length of "<|end|>" + } + else if(nextStartPos!=std::string::npos) + { + messageContent=remaining.substr(contentStart, nextStartPos-contentStart); + nextBlockStart=nextStartPos; + } + else + { + messageContent=remaining.substr(contentStart); + nextBlockStart=remaining.size(); + } + + // Route to appropriate field based on channel name + if(header.channel=="final") + { + if(!result.content.empty()) + result.content+=messageContent; + else + result.content=messageContent; + } + else if(header.channel=="analysis") + { + if(!result.reasoningContent.empty()) + result.reasoningContent+="\n"+messageContent; + else + result.reasoningContent=messageContent; + } + else if(header.channel=="commentary") + { + // Commentary with a recipient = tool call + if(!header.recipient.empty()) + { + // Extract function name from "functions.{name}" + std::string funcName=header.recipient; + size_t dotPos=funcName.find('.'); + if(dotPos!=std::string::npos) + { + funcName=funcName.substr(dotPos+1); + } + + HarmonyToolCall tc; + tc.name=funcName; + tc.arguments=messageContent; + result.toolCalls.push_back(std::move(tc)); + result.hasToolCall=true; + } + else + { + // Commentary without recipient = preamble (show to user as content) + if(!result.content.empty()) + result.content+=messageContent; + else + result.content=messageContent; + } + } + + remaining=remaining.substr(nextBlockStart); + } + + return result; +} + +/// Streaming harmony format parser. Buffers tokens and emits content as channels are identified. +/// Handles tool calls by accumulating them internally. +class HarmonyStreamParser { +public: + /// Feed a new token chunk. Returns content to emit to the client (final channel text only). + /// Reasoning content and tool calls are accumulated internally. + std::string feed(const std::string &chunk) + { + m_buffer+=chunk; + std::string output; + + while(true) + { + if(m_inMessage) + { + // Look for end-of-message markers + size_t endPos=m_buffer.find("<|end|>"); + size_t callPos=m_buffer.find("<|call|>"); + size_t startPos=m_buffer.find("<|start|>"); + + // Find nearest end marker + size_t endOfContent=std::string::npos; + bool isCallEnd=false; + + if(endPos!=std::string::npos) + endOfContent=endPos; + if(callPos!=std::string::npos&&(endOfContent==std::string::npos||callPos — don't consume it + } + else + { + // Check if buffer might contain a partial tag + size_t possibleTag=m_buffer.find("<|"); + if(possibleTag!=std::string::npos&&possibleTag>0) + { + // Emit everything before the potential tag + std::string safe=m_buffer.substr(0, possibleTag); + routeContent(safe, output); + m_buffer=m_buffer.substr(possibleTag); + } + else if(possibleTag==std::string::npos&&!m_buffer.empty()) + { + // No potential tag at all — emit everything + routeContent(m_buffer, output); + m_buffer.clear(); + } + break; + } + } + else + { + // Look for channel start + size_t channelPos=m_buffer.find("<|channel|>"); + if(channelPos==std::string::npos) + { + // Skip past <|start|> tags and role text + size_t startTag=m_buffer.find("<|start|>"); + if(startTag!=std::string::npos) + { + m_buffer=m_buffer.substr(startTag+9); + // Skip role text (e.g. "assistant") + size_t nextTag=m_buffer.find("<|"); + if(nextTag!=std::string::npos) + m_buffer=m_buffer.substr(nextTag); + else + break; + continue; + } + break; + } + + // Skip anything before <|channel|> (e.g. role text after <|start|>) + size_t channelNameStart=channelPos+11; + size_t messagePos=m_buffer.find("<|message|>", channelNameStart); + if(messagePos==std::string::npos) + { + break; // Wait for more data + } + + // Parse header + std::string headerStr=m_buffer.substr(channelNameStart, messagePos-channelNameStart); + HarmonyHeader header=parseHarmonyHeader(headerStr); + + m_currentChannel=header.channel; + m_currentRecipient=header.recipient; + m_currentContent.clear(); + m_inMessage=true; + m_buffer=m_buffer.substr(messagePos+11); + } + } + + return output; + } + + /// Get accumulated reasoning content after streaming completes. + const std::string &getReasoningContent() const { return m_reasoning; } + + /// Get tool calls extracted during streaming. + const std::vector &getToolCalls() const { return m_toolCalls; } + + /// Whether the output contained a tool call. + bool hasToolCall() const { return m_hasToolCall; } + +private: + void routeContent(const std::string &content, std::string &output) + { + if(content.empty()) return; + + if(m_currentChannel=="final") + { + output+=content; + } + else if(m_currentChannel=="analysis") + { + m_reasoning+=content; + } + else if(m_currentChannel=="commentary") + { + if(!m_currentRecipient.empty()) + { + // Accumulate tool call arguments + m_currentContent+=content; + } + else + { + // Preamble commentary — show to user + output+=content; + } + } + } + + std::string m_buffer; + std::string m_currentChannel; + std::string m_currentRecipient; + std::string m_currentContent; // Current tool call content being accumulated + std::string m_reasoning; + std::vector m_toolCalls; + bool m_inMessage=false; + bool m_hasToolCall=false; +}; + +/// Check if a model uses harmony API format. +bool isHarmonyFormat(const std::string &modelName) +{ + ModelInfo info; + if(ArbiterAI::instance().getModelInfo(modelName, info)==ErrorCode::Success) + { + return info.apiFormat=="harmony"; + } + return false; +} + } // anonymous namespace // ========== Override Path ========== @@ -1289,11 +1681,13 @@ void handleChatCompletions(const httplib::Request &req, httplib::Response &res) includeUsage=requestJson.at("stream_options").value("include_usage", false); } + bool harmonyMode=isHarmonyFormat(arbiterRequest.model); + if(stream) { res.set_chunked_content_provider( "text/event-stream", - [arbiterRequest, requestId, created, includeUsage, responseModelId](size_t, httplib::DataSink &sink) + [arbiterRequest, requestId, created, includeUsage, responseModelId, harmonyMode](size_t, httplib::DataSink &sink) { // Send initial chunk with role nlohmann::json roleChunk={ @@ -1311,9 +1705,23 @@ void handleChatCompletions(const httplib::Request &req, httplib::Response &res) std::string roleLine="data: "+roleChunk.dump()+"\n\n"; sink.write(roleLine.c_str(), roleLine.length()); + HarmonyStreamParser harmonyParser; + auto callback=[&](const std::string &chunk) { if(chunk.empty()) return; + + std::string emitContent; + if(harmonyMode) + { + emitContent=harmonyParser.feed(chunk); + if(emitContent.empty()) return; + } + else + { + emitContent=chunk; + } + nlohmann::json sseChunk={ {"id", requestId}, {"object", "chat.completion.chunk"}, @@ -1322,7 +1730,7 @@ void handleChatCompletions(const httplib::Request &req, httplib::Response &res) {"system_fingerprint", nullptr}, {"choices", {{ {"index", 0}, - {"delta", {{"content", chunk}}}, + {"delta", {{"content", emitContent}}}, {"finish_reason", nullptr} }}} }; @@ -1330,7 +1738,16 @@ void handleChatCompletions(const httplib::Request &req, httplib::Response &res) sink.write(line.c_str(), line.length()); }; - ErrorCode err=ArbiterAI::instance().streamingCompletion(arbiterRequest, callback); + // Send SSE comments while waiting for the inference lock. + // This keeps the connection alive and signals to clients that + // the request is queued for processing. + auto waitCallback=[&]() + { + std::string comment=": queued - waiting for model availability\n\n"; + sink.write(comment.c_str(), comment.length()); + }; + + ErrorCode err=ArbiterAI::instance().streamingCompletion(arbiterRequest, callback, waitCallback); std::string finishReason=(err==ErrorCode::Success)?"stop":"error"; @@ -1339,6 +1756,41 @@ void handleChatCompletions(const httplib::Request &req, httplib::Response &res) spdlog::error("Streaming completion failed: {}", errorCodeToString(err)); } + // For harmony mode, check if tool calls were detected and emit them + if(harmonyMode&&harmonyParser.hasToolCall()) + { + finishReason="tool_calls"; + const std::vector &toolCalls=harmonyParser.getToolCalls(); + for(size_t i=0; i(i)}, + {"id", callId}, + {"type", "function"}, + {"function", { + {"name", toolCalls[i].name}, + {"arguments", toolCalls[i].arguments} + }} + }}} + }}, + {"finish_reason", nullptr} + }}} + }; + std::string tcLine="data: "+toolCallChunk.dump()+"\n\n"; + sink.write(tcLine.c_str(), tcLine.length()); + } + } + // Send final chunk with finish_reason nlohmann::json finishChunk={ {"id", requestId}, @@ -1403,6 +1855,11 @@ void handleChatCompletions(const httplib::Request &req, httplib::Response &res) status=400; errType="invalid_request_error"; } + else if(err==ErrorCode::ServerOverloaded) + { + status=503; + errType="server_error"; + } res.status=status; res.set_content(errorJson("Completion failed: "+errCode, errType, "", errCode).dump(), "application/json"); @@ -1411,6 +1868,40 @@ void handleChatCompletions(const httplib::Request &req, httplib::Response &res) std::string finishReason=arbiterResponse.finishReason.empty()?"stop":arbiterResponse.finishReason; + // Convert harmony format to standard OpenAI format if needed + if(harmonyMode&&!arbiterResponse.text.empty()) + { + HarmonyParseResult parsed=parseHarmonyFormat(arbiterResponse.text); + arbiterResponse.text=parsed.content; + if(!parsed.reasoningContent.empty()) + arbiterResponse.reasoningContent=parsed.reasoningContent; + + // Convert harmony tool calls to OpenAI format + if(!parsed.toolCalls.empty()) + { + for(size_t i=0; i0?m.contextSize:info.contextWindow; + // Emit bare model name - data.push_back({ + nlohmann::json modelObj={ {"id", m.modelName}, {"object", "model"}, {"created", created}, {"owned_by", "arbiterai"}, - {"permission", nlohmann::json::array()} - }); + {"permission", nlohmann::json::array()}, + {"context_length", contextLength}, + {"max_completion_tokens", info.maxOutputTokens} + }; + data.push_back(modelObj); // Also emit "model:variant" if a variant is loaded if(!m.variant.empty()) { - data.push_back({ - {"id", m.modelName+":"+m.variant}, - {"object", "model"}, - {"created", created}, - {"owned_by", "arbiterai"}, - {"permission", nlohmann::json::array()} - }); + nlohmann::json variantObj=modelObj; + variantObj["id"]=m.modelName+":"+m.variant; + data.push_back(variantObj); } } @@ -1547,12 +2044,26 @@ void handleGetModelV1(const httplib::Request &req, httplib::Response &res) return; } + // Get context size from loaded model state if available, otherwise from config + int contextLength=info.contextWindow; + std::vector states=ModelRuntime::instance().getModelStates(); + for(const LoadedModel &lm:states) + { + if(lm.modelName==baseName&&lm.state==ModelState::Loaded&&lm.contextSize>0) + { + contextLength=lm.contextSize; + break; + } + } + nlohmann::json response={ {"id", modelId}, {"object", "model"}, {"created", static_cast(std::time(nullptr))}, {"owned_by", "arbiterai"}, - {"permission", nlohmann::json::array()} + {"permission", nlohmann::json::array()}, + {"context_length", contextLength}, + {"max_completion_tokens", info.maxOutputTokens} }; res.set_content(response.dump(), "application/json"); @@ -1862,6 +2373,10 @@ void handleGetModels(const httplib::Request &, httplib::Response &res) { modelJson["backend_priority"]=info.backendPriority; } + if(!info.apiFormat.empty()) + { + modelJson["api_format"]=info.apiFormat; + } } models.push_back(modelJson); @@ -1900,6 +2415,10 @@ void handleGetModels(const httplib::Request &, httplib::Response &res) { modelJson["backend_priority"]=info.backendPriority; } + if(!info.apiFormat.empty()) + { + modelJson["api_format"]=info.apiFormat; + } } models.push_back(modelJson); diff --git a/vcpkg/custom_ports/llama-cpp/portfile.cmake b/vcpkg/custom_ports/llama-cpp/portfile.cmake index d734586..7c545cd 100644 --- a/vcpkg/custom_ports/llama-cpp/portfile.cmake +++ b/vcpkg/custom_ports/llama-cpp/portfile.cmake @@ -23,7 +23,7 @@ else() OUT_SOURCE_PATH SOURCE_PATH REPO ggml-org/llama.cpp REF b${VERSION} - SHA512 6be3482ef58872ee4a386ba831175e53ce0d93c6992e4389ffd97f9af3cc7becdd1356fda575702681f55261e7fe81bc1baa12edd0d5f809aa80684f5c890bac + SHA512 50d06c9b3fc72245ba621dc0dbe6fd8a67fbd36a6d5cc27715dc36675e93cf0aee992f90a85e710013ecf7bd1164b13b0b09dc611d69abd6032e36682cbbd719 HEAD_REF master ) endif() @@ -65,8 +65,12 @@ endif() ) file(MAKE_DIRECTORY "${CURRENT_PACKAGES_DIR}/tools/${PORT}") -file(RENAME "${CURRENT_PACKAGES_DIR}/bin/convert_hf_to_gguf.py" "${CURRENT_PACKAGES_DIR}/tools/${PORT}/convert-hf-to-gguf.py") -file(INSTALL "${SOURCE_PATH}/gguf-py" DESTINATION "${CURRENT_PACKAGES_DIR}/tools/${PORT}") +if(EXISTS "${CURRENT_PACKAGES_DIR}/bin/convert_hf_to_gguf.py") + file(RENAME "${CURRENT_PACKAGES_DIR}/bin/convert_hf_to_gguf.py" "${CURRENT_PACKAGES_DIR}/tools/${PORT}/convert-hf-to-gguf.py") +endif() +if(EXISTS "${SOURCE_PATH}/gguf-py") + file(INSTALL "${SOURCE_PATH}/gguf-py" DESTINATION "${CURRENT_PACKAGES_DIR}/tools/${PORT}") +endif() if (NOT VCPKG_BUILD_TYPE) file(REMOVE "${CURRENT_PACKAGES_DIR}/debug/bin/convert_hf_to_gguf.py") endif() diff --git a/vcpkg/custom_ports/llama-cpp/vcpkg.json b/vcpkg/custom_ports/llama-cpp/vcpkg.json index cc672ae..622abad 100644 --- a/vcpkg/custom_ports/llama-cpp/vcpkg.json +++ b/vcpkg/custom_ports/llama-cpp/vcpkg.json @@ -1,6 +1,6 @@ { "name": "llama-cpp", - "version": "8748", + "version": "9204", "port-version": 0, "description": "LLM inference in C/C++", "homepage": "https://github.com/ggml-org/llama.cpp", From bc3565bafb5a8d293e9b2087072669212317682e Mon Sep 17 00:00:00 2001 From: krazer Date: Wed, 10 Jun 2026 18:05:47 -0400 Subject: [PATCH 2/4] add inference scheduler pipeline for local models - InferenceScheduler: tokenizer thread + per-accelerator worker threads, job tracking with cancellation on client disconnect, TokenChannel for streaming tokens back to HTTP threads - server: scheduler-backed /v1/chat/completions for local models with SSE status keepalives (queue position), heartbeats for non-streaming, proper HTTP status for fast failures, real token usage in include_usage chunk - /api/scheduler/jobs endpoint; dashboard shows active jobs and cancelled status; scheduler shutdown on server exit - llama provider: abort checks during prompt processing and generation - telemetry: job id and cancelled flag on inference stats - config: harmony api_format for gpt-oss-120b (submodule) - tests: TokenChannel and InferenceScheduler coverage; fix flaky configDownloader test (libgit2 init, daemon readiness poll, cleanup) - docs: scheduler endpoint and local model architecture Co-Authored-By: Claude Opus 4.8 (1M context) --- CMakeLists.txt | 3 + arbiterAI_config | 2 +- docs/developer.md | 13 +- docs/server.md | 38 +- src/arbiterAI/arbiterAI.h | 3 +- src/arbiterAI/inferenceScheduler.cpp | 661 +++++++++++++++++++++++++++ src/arbiterAI/inferenceScheduler.h | 228 +++++++++ src/arbiterAI/providers/llama.cpp | 22 +- src/arbiterAI/providers/llama.h | 29 +- src/arbiterAI/telemetryCollector.h | 2 + src/server/dashboard.h | 85 +++- src/server/main.cpp | 21 +- src/server/routes.cpp | 408 +++++++++++++++-- tests/configDownloaderTests.cpp | 70 ++- tests/inferenceSchedulerTests.cpp | 284 ++++++++++++ 15 files changed, 1764 insertions(+), 105 deletions(-) create mode 100644 src/arbiterAI/inferenceScheduler.cpp create mode 100644 src/arbiterAI/inferenceScheduler.h create mode 100644 tests/inferenceSchedulerTests.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f34128a..ffb4a4a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -75,6 +75,8 @@ set(arbiterai_src ./src/arbiterAI/modelRuntime.cpp ./src/arbiterAI/telemetryCollector.h ./src/arbiterAI/telemetryCollector.cpp + ./src/arbiterAI/inferenceScheduler.h + ./src/arbiterAI/inferenceScheduler.cpp ./src/arbiterAI/storageManager.h ./src/arbiterAI/storageManager.cpp ./src/arbiterAI/providers/baseProvider.h @@ -141,6 +143,7 @@ target_link_libraries(arbiterai tests/hardwareDetectorTests.cpp tests/modelRuntimeTests.cpp tests/telemetryCollectorTests.cpp + tests/inferenceSchedulerTests.cpp tests/llamaProviderTests.cpp tests/storageManagerTests.cpp tests/serverConnectTests.cpp diff --git a/arbiterAI_config b/arbiterAI_config index e6a4342..7315138 160000 --- a/arbiterAI_config +++ b/arbiterAI_config @@ -1 +1 @@ -Subproject commit e6a4342141f6e84f229be0141ae1374b16194110 +Subproject commit 7315138bdeaba9ce001d5c8dd814f0e919d08017 diff --git a/docs/developer.md b/docs/developer.md index ce6c990..f102352 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -53,14 +53,15 @@ ArbiterAI follows a layered architecture: - **[`ModelManager`](../src/arbiterAI/modelManager.h)** — Singleton that loads and manages model configurations from JSON files with schema validation. - **Utility Components** — Cross-cutting functionality including caching ([`CacheManager`](../src/arbiterAI/cacheManager.h)), cost tracking ([`CostManager`](../src/arbiterAI/costManager.h)), model downloading ([`ModelDownloader`](../src/arbiterAI/modelDownloader.h)), and file verification ([`FileVerifier`](../src/arbiterAI/fileVerifier.h)). -### Planned Components +### Local Model Components -See [Local Model Management Task](tasks/local_model_management.md) for upcoming additions: +Components supporting local (llama.cpp) models — see [Local Model Management Task](tasks/local_model_management.md) for background: -- **`HardwareDetector`** — GPU/RAM/CPU detection (NVML + Vulkan) -- **`ModelRuntime`** — Multi-model loading, swap queueing, LRU eviction (refactor of `LlamaInterface`) -- **`TelemetryCollector`** — Inference stats and system snapshots -- **Standalone Server** — Separate `arbiterAI-server` application providing an OpenAI-compatible API, model management endpoints, and a live stats dashboard +- **[`HardwareDetector`](../src/arbiterAI/hardwareDetector.h)** — GPU/RAM/CPU detection (NVML + Vulkan) +- **[`ModelRuntime`](../src/arbiterAI/modelRuntime.h)** — Multi-model loading, swap queueing, LRU eviction, load-failure classification +- **[`InferenceScheduler`](../src/arbiterAI/inferenceScheduler.h)** — Inference pipeline used by the server for local models. HTTP threads submit jobs; a tokenizer thread loads the model and pre-tokenizes the prompt; per-accelerator worker threads run inference. Streaming tokens flow back to the HTTP thread through a `TokenChannel`, and jobs are cancelled on client disconnect. Active jobs are exposed at `/api/scheduler/jobs`. +- **[`TelemetryCollector`](../src/arbiterAI/telemetryCollector.h)** — Inference stats and system snapshots +- **Standalone Server** — Separate `arbiterAI-server` application providing an OpenAI-compatible API, model management endpoints, and a live stats dashboard (see [Server Guide](server.md)) --- diff --git a/docs/server.md b/docs/server.md index 9ea331b..3c924ec 100644 --- a/docs/server.md +++ b/docs/server.md @@ -31,7 +31,7 @@ The server supports: - **Model lifecycle management** — Load, unload, pin, and download models at runtime - **Runtime model config injection** — Add, update, or remove model configurations via REST without restarting - **Storage management** — Track downloaded model files, set hot ready / protected flags, configure automated cleanup, monitor disk usage and download progress with speed and ETA -- **Telemetry** — System snapshots, inference history, swap history, and hardware info +- **Telemetry** — System snapshots, inference history, swap history, active scheduler jobs, and hardware info - **Live dashboard** — Browser-based UI at `/dashboard` with storage bar, download progress, and model management - **CORS** — All responses include permissive CORS headers @@ -738,6 +738,8 @@ Inference history within a time window. { "model": "gpt-4", "variant": "", + "job_id": 17, + "cancelled": false, "tokens_per_second": 45.2, "prompt_tokens": 120, "completion_tokens": 80, @@ -747,6 +749,40 @@ Inference history within a time window. ] ``` +#### `GET /api/scheduler/jobs` + +Active inference scheduler jobs (local models only). Jobs flow through the +pipeline stages `queued` → `tokenizing` → `waiting` → `inferring`; completed +and cancelled jobs are not listed. Returns `[]` when the scheduler is not +running. + +**Response:** + +```json +[ + { + "id": 17, + "model": "my-local-model", + "stage": "inferring", + "streaming": true, + "prompt_tokens": 120, + "completion_tokens": 34, + "queue_position": 0, + "elapsed_ms": 2150.0 + } +] +``` + +| Field | Description | +|-------|-------------| +| `id` | Scheduler job ID (matches `job_id` in inference history) | +| `stage` | `queued`, `tokenizing`, `waiting`, or `inferring` | +| `streaming` | Whether the request is a streaming completion | +| `prompt_tokens` | Prompt token count (available once tokenized) | +| `completion_tokens` | Tokens generated so far (streaming jobs only) | +| `queue_position` | Position in the accelerator queue (`0` = running) | +| `elapsed_ms` | Time since the job was submitted | + #### `GET /api/stats/swaps` Model swap history. diff --git a/src/arbiterAI/arbiterAI.h b/src/arbiterAI/arbiterAI.h index 4874664..b79d6e4 100644 --- a/src/arbiterAI/arbiterAI.h +++ b/src/arbiterAI/arbiterAI.h @@ -77,7 +77,8 @@ enum class ErrorCode ModelDownloading, ModelDownloadFailed, InsufficientStorage, - ServerOverloaded + ServerOverloaded, + Cancelled }; /** diff --git a/src/arbiterAI/inferenceScheduler.cpp b/src/arbiterAI/inferenceScheduler.cpp new file mode 100644 index 0000000..30d87ab --- /dev/null +++ b/src/arbiterAI/inferenceScheduler.cpp @@ -0,0 +1,661 @@ +#include "arbiterAI/inferenceScheduler.h" +#include "arbiterAI/modelRuntime.h" +#include "arbiterAI/modelManager.h" +#include "arbiterAI/hardwareDetector.h" +#include "arbiterAI/telemetryCollector.h" +#include "arbiterAI/providers/llama.h" + +#include +#include + +#include + +namespace arbiterAI +{ + +// ── TokenChannel ────────────────────────────────────────────── + +void TokenChannel::push(const std::string &token) +{ + if(m_cancelled.load()) return; + + { + std::lock_guard lock(m_mutex); + m_tokens.push(token); + } + m_cv.notify_one(); +} + +void TokenChannel::finish(ErrorCode result) +{ + { + std::lock_guard lock(m_mutex); + m_result=result; + } + m_done.store(true); + m_cv.notify_all(); +} + +bool TokenChannel::pop(std::string &token, std::chrono::milliseconds timeout) +{ + std::unique_lock lock(m_mutex); + + if(!m_cv.wait_for(lock, timeout, [this]() + { + return !m_tokens.empty()||m_done.load()||m_cancelled.load(); + })) + { + // Timeout — no token available, but not done either. + // Caller can use this to send keepalive. + return true; // still alive, just no data yet + } + + if(m_cancelled.load()) + { + return false; + } + + if(!m_tokens.empty()) + { + token=std::move(m_tokens.front()); + m_tokens.pop(); + return true; + } + + // Queue empty and done — no more tokens. + return false; +} + +bool TokenChannel::isDone() const +{ + return m_done.load(); +} + +ErrorCode TokenChannel::getResult() const +{ + return m_result; +} + +void TokenChannel::cancel() +{ + m_cancelled.store(true); + m_cv.notify_all(); +} + +bool TokenChannel::isCancelled() const +{ + return m_cancelled.load(); +} + +// ── InferenceScheduler ──────────────────────────────────────── + +InferenceScheduler &InferenceScheduler::instance() +{ + static InferenceScheduler s; + return s; +} + +InferenceScheduler::~InferenceScheduler() +{ + shutdown(); +} + +void InferenceScheduler::initialize(const std::vector &gpuIndices) +{ + if(m_running.load()) + { + spdlog::warn("[scheduler] already initialized"); + return; + } + + spdlog::info("[scheduler] initializing with {} accelerator(s)", gpuIndices.size()); + + // Create per-accelerator queues + for(int idx:gpuIndices) + { + auto q=std::make_unique(); + q->gpuIndex=idx; + q->running.store(true); + + auto gpus=HardwareDetector::instance().getGpus(); + for(const GpuInfo &gpu:gpus) + { + if(gpu.index==idx) + { + q->deviceName=gpu.name; + break; + } + } + + spdlog::info("[scheduler] accelerator {}: {}", idx, q->deviceName); + m_accelerators.push_back(std::move(q)); + } + + // If no GPUs provided, create a single CPU-based accelerator queue + if(m_accelerators.empty()) + { + auto q=std::make_unique(); + q->gpuIndex=-1; + q->deviceName="CPU"; + q->running.store(true); + spdlog::info("[scheduler] no GPUs specified, using single CPU accelerator queue"); + m_accelerators.push_back(std::move(q)); + } + + m_running.store(true); + + // Start tokenizer thread + m_tokenizerThread=std::thread(&InferenceScheduler::tokenizerLoop, this); + + // Start accelerator threads + for(auto &q:m_accelerators) + { + q->workerThread=std::thread(&InferenceScheduler::acceleratorLoop, this, std::ref(*q)); + } + + spdlog::info("[scheduler] started: 1 tokenizer thread, {} accelerator thread(s)", + m_accelerators.size()); +} + +void InferenceScheduler::shutdown() +{ + if(!m_running.load()) return; + + spdlog::info("[scheduler] shutting down"); + m_running.store(false); + + // Wake tokenizer + m_tokenizerCv.notify_all(); + + // Wake all accelerators + for(auto &q:m_accelerators) + { + q->running.store(false); + q->cv.notify_all(); + } + + // Join threads + if(m_tokenizerThread.joinable()) + m_tokenizerThread.join(); + + for(auto &q:m_accelerators) + { + if(q->workerThread.joinable()) + q->workerThread.join(); + } + + m_accelerators.clear(); + spdlog::info("[scheduler] shutdown complete"); +} + +std::shared_ptr InferenceScheduler::submit(const CompletionRequest &request, bool streaming) +{ + auto job=std::make_shared(); + job->id=m_nextJobId.fetch_add(1); + job->request=request; + job->streaming=streaming; + job->submitTime=std::chrono::steady_clock::now(); + job->stage.store(InferenceStage::Queued); + + if(streaming) + { + job->channel=std::make_shared(); + } + + // Track job + { + std::lock_guard lock(m_jobsMutex); + m_activeJobs[job->id]=job; + } + + // Enqueue for tokenization + { + std::lock_guard lock(m_tokenizerMutex); + m_tokenizerQueue.push_back(job); + } + m_tokenizerCv.notify_one(); + + spdlog::info("[scheduler] job {} submitted (model='{}', streaming={})", + job->id, request.model, streaming); + + return job; +} + +void InferenceScheduler::cancel(uint64_t jobId) +{ + std::lock_guard lock(m_jobsMutex); + auto it=m_activeJobs.find(jobId); + if(it!=m_activeJobs.end()) + { + auto job=it->second.lock(); + if(job) + { + job->cancelled.store(true); + job->stage.store(InferenceStage::Cancelled); + if(job->channel) + { + job->channel->cancel(); + } + } + } +} + +void InferenceScheduler::finishJob(const std::shared_ptr &job, ErrorCode result) +{ + job->result=result; + job->stage.store(result==ErrorCode::Cancelled? + InferenceStage::Cancelled:InferenceStage::Complete); + + { + std::lock_guard lock(job->completionMutex); + job->complete.store(true); + } + job->completionCv.notify_all(); + + if(job->channel) + { + job->channel->finish(result); + } + + std::lock_guard lock(m_jobsMutex); + m_activeJobs.erase(job->id); +} + +int InferenceScheduler::getTotalQueueDepth() const +{ + int total=0; + { + std::lock_guard lock(m_tokenizerMutex); + total+=static_cast(m_tokenizerQueue.size()); + } + for(const auto &q:m_accelerators) + { + std::lock_guard lock(q->mutex); + total+=static_cast(q->jobs.size()); + if(q->activeJob) total++; + } + return total; +} + +int InferenceScheduler::getQueueDepth(int gpuIndex) const +{ + for(const auto &q:m_accelerators) + { + if(q->gpuIndex==gpuIndex) + { + std::lock_guard lock(q->mutex); + int depth=static_cast(q->jobs.size()); + if(q->activeJob) depth++; + return depth; + } + } + return 0; +} + +InferenceStage InferenceScheduler::getJobStage(uint64_t jobId) const +{ + std::lock_guard lock(m_jobsMutex); + auto it=m_activeJobs.find(jobId); + if(it!=m_activeJobs.end()) + { + auto job=it->second.lock(); + if(job) + { + return job->stage.load(); + } + } + return InferenceStage::Complete; +} + +std::vector InferenceScheduler::getActiveJobs() const +{ + std::vector result; + auto now=std::chrono::steady_clock::now(); + + std::lock_guard lock(m_jobsMutex); + for(const auto &[id, weak]:m_activeJobs) + { + auto job=weak.lock(); + if(!job) continue; + + InferenceStage stage=job->stage.load(); + if(stage==InferenceStage::Complete||stage==InferenceStage::Cancelled) continue; + if(job->cancelled.load()) continue; + + JobSnapshot snap; + snap.id=job->id; + snap.model=job->request.model; + snap.stage=stage; + snap.streaming=job->streaming; + // tokens is only stable once the tokenizer has handed the job off + if(stage==InferenceStage::WaitingAccelerator||stage==InferenceStage::Inferring) + { + snap.promptTokens=static_cast(job->tokens.size()); + } + snap.completionTokens=job->completionTokens.load(); + snap.queuePosition=job->queuePosition.load(); + snap.elapsedMs=std::chrono::duration(now-job->submitTime).count(); + result.push_back(snap); + } + return result; +} + +// ── Tokenizer Thread ────────────────────────────────────────── + +void InferenceScheduler::tokenizerLoop() +{ + spdlog::info("[scheduler:tokenizer] thread started"); + + while(m_running.load()) + { + std::shared_ptr job; + + { + std::unique_lock lock(m_tokenizerMutex); + m_tokenizerCv.wait(lock, [this]() + { + return !m_tokenizerQueue.empty()||!m_running.load(); + }); + + if(!m_running.load()) break; + if(m_tokenizerQueue.empty()) continue; + + job=m_tokenizerQueue.front(); + m_tokenizerQueue.pop_front(); + } + + if(job->cancelled.load()) + { + finishJob(job, ErrorCode::Cancelled); + continue; + } + + job->stage.store(InferenceStage::Tokenizing); + job->tokenizeStartTime=std::chrono::steady_clock::now(); + + spdlog::debug("[scheduler:tokenizer] tokenizing job {} (model='{}')", + job->id, job->request.model); + + // Ensure model is loaded + ModelRuntime &runtime=ModelRuntime::instance(); + ErrorCode loadResult=runtime.loadModel(job->request.model); + if(loadResult!=ErrorCode::Success) + { + finishJob(job, loadResult); + continue; + } + + llama_model *llamaModel=runtime.getLlamaModel(job->request.model); + if(!llamaModel) + { + finishJob(job, ErrorCode::ModelNotLoaded); + continue; + } + + // Get model info for template/format + std::optional modelInfo=runtime.getLoadedModelInfo(job->request.model); + if(!modelInfo) + { + finishJob(job, ErrorCode::ModelNotFound); + continue; + } + + // Tokenize — this only reads llama_model/vocab (thread-safe, no context needed) + Llama llamaProvider; + ErrorCode tokenizeResult=llamaProvider.tokenizePrompt( + llamaModel, job->request, *modelInfo, job->tokens, job->formattedPrompt); + + if(tokenizeResult!=ErrorCode::Success) + { + finishJob(job, tokenizeResult); + continue; + } + + spdlog::debug("[scheduler:tokenizer] job {} tokenized: {} tokens", + job->id, job->tokens.size()); + + if(job->cancelled.load()) + { + finishJob(job, ErrorCode::Cancelled); + continue; + } + + // Move to accelerator queue + job->stage.store(InferenceStage::WaitingAccelerator); + AcceleratorQueue &accel=selectAccelerator(*job); + + { + std::lock_guard lock(accel.mutex); + accel.jobs.push_back(job); + + // Update queue positions + int pos=1; + if(accel.activeJob) pos++; + for(auto &queued:accel.jobs) + { + queued->queuePosition.store(pos++); + } + } + accel.cv.notify_one(); + } + + spdlog::info("[scheduler:tokenizer] thread exiting"); +} + +// ── Accelerator Thread ──────────────────────────────────────── + +void InferenceScheduler::acceleratorLoop(AcceleratorQueue &queue) +{ + spdlog::info("[scheduler:accel:{}] thread started (device='{}')", + queue.gpuIndex, queue.deviceName); + + while(queue.running.load()) + { + std::shared_ptr job; + + { + std::unique_lock lock(queue.mutex); + queue.cv.wait(lock, [&queue]() + { + return !queue.jobs.empty()||!queue.running.load(); + }); + + if(!queue.running.load()) break; + if(queue.jobs.empty()) continue; + + job=queue.jobs.front(); + queue.jobs.pop_front(); + queue.activeJob=job; + + // Update queue positions for remaining jobs + int pos=1; + for(auto &queued:queue.jobs) + { + queued->queuePosition.store(pos++); + } + } + + if(job->cancelled.load()) + { + finishJob(job, ErrorCode::Cancelled); + std::lock_guard lock(queue.mutex); + queue.activeJob=nullptr; + continue; + } + + job->stage.store(InferenceStage::Inferring); + job->inferenceStartTime=std::chrono::steady_clock::now(); + job->queuePosition.store(0); + + spdlog::info("[scheduler:accel:{}] starting inference for job {} (model='{}', {} prompt tokens)", + queue.gpuIndex, job->id, job->request.model, job->tokens.size()); + + // Acquire inference lock and run + ModelRuntime &runtime=ModelRuntime::instance(); + + llama_model *llamaModel=runtime.getLlamaModel(job->request.model); + llama_context *llamaCtx=runtime.getLlamaContext(job->request.model); + + if(!llamaModel||!llamaCtx) + { + spdlog::error("[scheduler:accel:{}] model handles not available for job {}", + queue.gpuIndex, job->id); + finishJob(job, ErrorCode::ModelNotLoaded); + + std::lock_guard lock(queue.mutex); + queue.activeJob=nullptr; + continue; + } + + std::optional modelInfo=runtime.getLoadedModelInfo(job->request.model); + if(!modelInfo) + { + finishJob(job, ErrorCode::ModelNotFound); + + std::lock_guard lock(queue.mutex); + queue.activeJob=nullptr; + continue; + } + + // Lock the inference mutex (one inference at a time per context) + runtime.beginInference(job->request.model); + std::lock_guard inferenceLock(runtime.getInferenceMutex()); + + // Check for cancellation after acquiring lock + if(job->cancelled.load()) + { + runtime.endInference(job->request.model); + finishJob(job, ErrorCode::Cancelled); + std::lock_guard lock(queue.mutex); + queue.activeJob=nullptr; + continue; + } + + // Build stream callback that pushes to channel + std::function streamCallback=nullptr; + if(job->streaming&&job->channel) + { + streamCallback=[&job](const std::string &token) + { + if(job->channel->isCancelled()) return; + job->channel->push(token); + job->completionTokens.fetch_add(1); + }; + } + + // Abort callback — checks if the job was cancelled by client disconnect + auto abortCheck=[&job]() -> bool + { + return job->cancelled.load(); + }; + + // Run inference with pre-tokenized prompt + int completionTokens=0; + + Llama llamaProvider; + ErrorCode code=llamaProvider.runInferenceWithTokens( + llamaModel, llamaCtx, job->request, *modelInfo, + job->tokens, job->resultText, + job->promptTokens, completionTokens, + job->promptTimeMs, job->generationTimeMs, + streamCallback, abortCheck); + + runtime.endInference(job->request.model); + job->completionTokens.store(completionTokens); + + auto endTime=std::chrono::steady_clock::now(); + double totalTimeMs=std::chrono::duration( + endTime-job->inferenceStartTime).count(); + + finishJob(job, code); + + // Record telemetry + if(code==ErrorCode::Success) + { + spdlog::info("[scheduler:accel:{}] job {} complete: prompt={} ({:.1f}ms), gen={} ({:.1f}ms), total={:.1f}ms", + queue.gpuIndex, job->id, job->promptTokens, job->promptTimeMs, + completionTokens, job->generationTimeMs, totalTimeMs); + + std::optional state=runtime.getModelState(job->request.model); + + InferenceStats stats; + stats.jobId=job->id; + stats.cancelled=false; + stats.model=job->request.model; + stats.variant=state?state->variant:""; + stats.promptTokens=job->promptTokens; + stats.completionTokens=completionTokens; + stats.totalTimeMs=totalTimeMs; + stats.promptTimeMs=job->promptTimeMs; + stats.generationTimeMs=job->generationTimeMs; + stats.latencyMs=std::chrono::duration(job->inferenceStartTime-job->submitTime).count(); + stats.tokensPerSecond=totalTimeMs>0.0?(completionTokens/(totalTimeMs/1000.0)):0.0; + stats.promptTokensPerSecond=job->promptTimeMs>0.0?(job->promptTokens/(job->promptTimeMs/1000.0)):0.0; + stats.generationTokensPerSecond=job->generationTimeMs>0.0?(completionTokens/(job->generationTimeMs/1000.0)):0.0; + stats.timestamp=std::chrono::system_clock::now(); + TelemetryCollector::instance().recordInference(stats); + } + else if(code==ErrorCode::Cancelled) + { + spdlog::info("[scheduler:accel:{}] job {} cancelled after {} gen tokens ({:.1f}ms)", + queue.gpuIndex, job->id, completionTokens, totalTimeMs); + + std::optional state=runtime.getModelState(job->request.model); + + InferenceStats stats; + stats.jobId=job->id; + stats.cancelled=true; + stats.model=job->request.model; + stats.variant=state?state->variant:""; + stats.promptTokens=job->promptTokens; + stats.completionTokens=completionTokens; + stats.totalTimeMs=totalTimeMs; + stats.promptTimeMs=job->promptTimeMs; + stats.generationTimeMs=job->generationTimeMs; + stats.latencyMs=std::chrono::duration(job->inferenceStartTime-job->submitTime).count(); + stats.tokensPerSecond=0.0; + stats.promptTokensPerSecond=0.0; + stats.generationTokensPerSecond=0.0; + stats.timestamp=std::chrono::system_clock::now(); + TelemetryCollector::instance().recordInference(stats); + } + else + { + spdlog::error("[scheduler:accel:{}] job {} failed (error={})", + queue.gpuIndex, job->id, static_cast(code)); + } + + { + std::lock_guard lock(queue.mutex); + queue.activeJob=nullptr; + } + } + + spdlog::info("[scheduler:accel:{}] thread exiting", queue.gpuIndex); +} + +// ── Accelerator Selection ───────────────────────────────────── + +AcceleratorQueue &InferenceScheduler::selectAccelerator(const InferenceJob &job) +{ + // For now, select the accelerator with the shortest queue. + // Future: match by model's GPU affinity / loaded state. + AcceleratorQueue *best=m_accelerators[0].get(); + int bestDepth=std::numeric_limits::max(); + + for(auto &q:m_accelerators) + { + std::lock_guard lock(q->mutex); + int depth=static_cast(q->jobs.size()); + if(q->activeJob) depth++; + if(depth +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace arbiterAI +{ + +/// States a request passes through in the inference pipeline. +enum class InferenceStage +{ + Queued, // Waiting in the submission queue + Tokenizing, // Being tokenized by the tokenizer thread + WaitingAccelerator, // Tokenized, waiting for an accelerator slot + Inferring, // Running on an accelerator + Complete, // Finished (success or error) + Cancelled // Cancelled by client disconnect +}; + +/// Thread-safe channel for streaming tokens from accelerator to HTTP thread. +class TokenChannel { +public: + /// Push a token/chunk. Called by accelerator thread. + void push(const std::string &token); + + /// Signal that generation is complete. + void finish(ErrorCode result); + + /// Block until a token is available or finished. + /// Returns false when the channel is finished (no more tokens). + bool pop(std::string &token, std::chrono::milliseconds timeout=std::chrono::milliseconds(500)); + + /// Check if the channel is done (finished or error). + bool isDone() const; + + /// Get the final error code (only valid after isDone()). + ErrorCode getResult() const; + + /// Signal cancellation (e.g. client disconnect). + void cancel(); + + bool isCancelled() const; + +private: + mutable std::mutex m_mutex; + std::condition_variable m_cv; + std::queue m_tokens; + std::atomic m_done{false}; + std::atomic m_cancelled{false}; + ErrorCode m_result=ErrorCode::Success; +}; + +/// A single inference request flowing through the pipeline. +struct InferenceJob { + /// Unique job ID for tracking. + uint64_t id=0; + + /// The original completion request. + CompletionRequest request; + + /// Whether this is a streaming request. + bool streaming=false; + + /// Current pipeline stage (display/telemetry only — use 'cancelled' for control flow). + std::atomic stage{InferenceStage::Queued}; + + /// Cancellation flag. Separate from 'stage' so a cancel can never be + /// lost to a concurrent forward stage transition. + std::atomic cancelled{false}; + + /// Timestamp when the job was submitted. + std::chrono::steady_clock::time_point submitTime; + + /// Timestamp when each stage started (for telemetry). + std::chrono::steady_clock::time_point tokenizeStartTime; + std::chrono::steady_clock::time_point inferenceStartTime; + + /// Pre-tokenized prompt tokens (filled by tokenizer thread). + std::vector tokens; + std::string formattedPrompt; + + /// Result for non-streaming requests. + std::string resultText; + int promptTokens=0; + double promptTimeMs=0.0; + double generationTimeMs=0.0; + ErrorCode result=ErrorCode::Success; + + /// Completion token count. Atomic because the dashboard snapshots it + /// while the accelerator thread is still generating. + std::atomic completionTokens{0}; + + /// For streaming requests — tokens pushed here by accelerator. + std::shared_ptr channel; + + /// Condition variable signaled when job is complete (non-streaming). + std::mutex completionMutex; + std::condition_variable completionCv; + std::atomic complete{false}; + + /// Queue position (updated by scheduler for status reporting). + std::atomic queuePosition{0}; +}; + +/// Snapshot of a single job for the dashboard/API. +struct JobSnapshot { + uint64_t id=0; + std::string model; + InferenceStage stage=InferenceStage::Queued; + bool streaming=false; + int promptTokens=0; + int completionTokens=0; + int queuePosition=0; + double elapsedMs=0.0; +}; + +/// Per-accelerator worker that processes inference jobs from its queue. +struct AcceleratorQueue { + int gpuIndex=-1; + std::string deviceName; + + mutable std::mutex mutex; + std::condition_variable cv; + std::deque> jobs; + std::shared_ptr activeJob; // currently inferring + + std::thread workerThread; + std::atomic running{false}; +}; + +/// Central inference pipeline scheduler. +/// +/// Architecture: +/// HTTP threads → submit() → [tokenizer queue] +/// ↓ tokenizer thread +/// [accelerator queue(s)] +/// ↓ accelerator thread(s) +/// [completion / streaming channel] +/// ↓ +/// HTTP thread picks up result +class InferenceScheduler { +public: + static InferenceScheduler &instance(); + + /// Initialize the scheduler with the detected accelerators. + /// Call once after hardware detection. + void initialize(const std::vector &gpuIndices); + + /// Shut down all worker threads. + void shutdown(); + + /// Submit a job to the pipeline. Returns immediately. + /// The caller can wait on job->completionCv (non-streaming) + /// or read from job->channel (streaming). + std::shared_ptr submit(const CompletionRequest &request, bool streaming); + + /// Cancel a job by ID (e.g. on client disconnect). + void cancel(uint64_t jobId); + + /// Get the current queue depth across all accelerators. + int getTotalQueueDepth() const; + + /// Get the queue depth for a specific accelerator. + int getQueueDepth(int gpuIndex) const; + + /// Get status info for a job. + InferenceStage getJobStage(uint64_t jobId) const; + + /// Get snapshot of all active (non-complete) jobs. + std::vector getActiveJobs() const; + + /// Check if the scheduler is initialized and running. + bool isRunning() const { return m_running.load(); } + +private: + InferenceScheduler()=default; + ~InferenceScheduler(); + + InferenceScheduler(const InferenceScheduler &)=delete; + InferenceScheduler &operator=(const InferenceScheduler &)=delete; + + /// Mark a job terminal: set result/stage, signal waiters, finish the + /// streaming channel, and remove it from the active job map. + void finishJob(const std::shared_ptr &job, ErrorCode result); + + /// Tokenizer thread function. + void tokenizerLoop(); + + /// Accelerator worker thread function. + void acceleratorLoop(AcceleratorQueue &queue); + + /// Pick the best accelerator queue for a job. + AcceleratorQueue &selectAccelerator(const InferenceJob &job); + + // Tokenizer queue + mutable std::mutex m_tokenizerMutex; + std::condition_variable m_tokenizerCv; + std::deque> m_tokenizerQueue; + std::thread m_tokenizerThread; + + // Accelerator queues (one per GPU) + std::vector> m_accelerators; + + // Job tracking + mutable std::mutex m_jobsMutex; + std::map> m_activeJobs; + std::atomic m_nextJobId{1}; + + std::atomic m_running{false}; +}; + +} // namespace arbiterAI + +#endif//_ARBITERAI_INFERENCESCHEDULER_H_ diff --git a/src/arbiterAI/providers/llama.cpp b/src/arbiterAI/providers/llama.cpp index 541cf88..54ecdba 100644 --- a/src/arbiterAI/providers/llama.cpp +++ b/src/arbiterAI/providers/llama.cpp @@ -898,7 +898,8 @@ ErrorCode Llama::runInferenceWithTokens(llama_model *model, llama_context *ctx, const std::vector &promptTokens, std::string &result, int &promptTokenCount, int &completionTokens, double &promptTimeMs, double &generationTimeMs, - std::function streamCallback) + std::function streamCallback, + std::function shouldAbort) { const llama_vocab *vocab=llama_model_get_vocab(model); bool harmonyMode=(modelInfo.apiFormat=="harmony"); @@ -916,6 +917,14 @@ ErrorCode Llama::runInferenceWithTokens(llama_model *model, llama_context *ctx, for(int start=0; start=nTokens); @@ -1002,6 +1011,17 @@ ErrorCode Llama::runInferenceWithTokens(llama_model *model, llama_context *ctx, for(int i=0; i( + std::chrono::steady_clock::now()-genStart).count(); + return ErrorCode::Cancelled; + } + llama_token nextToken=llama_sampler_sample(samplerChain, ctx, -1); llama_sampler_accept(samplerChain, nextToken); diff --git a/src/arbiterAI/providers/llama.h b/src/arbiterAI/providers/llama.h index 26e3af9..cd76433 100644 --- a/src/arbiterAI/providers/llama.h +++ b/src/arbiterAI/providers/llama.h @@ -39,6 +39,21 @@ class Llama : public BaseProvider { ErrorCode getAvailableModels(std::vector &models) override; + /// Tokenize the prompt outside of the inference mutex. + /// Only reads llama_model/vocab (thread-safe without context lock). + ErrorCode tokenizePrompt(llama_model *model, + const CompletionRequest &request, const ModelInfo &modelInfo, + std::vector &tokens, std::string &formattedPrompt); + + /// Run inference with pre-tokenized prompt (requires inference lock held). + ErrorCode runInferenceWithTokens(llama_model *model, llama_context *ctx, + const CompletionRequest &request, const ModelInfo &modelInfo, + const std::vector &promptTokens, + std::string &result, int &promptTokenCount, int &completionTokens, + double &promptTimeMs, double &generationTimeMs, + std::function streamCallback, + std::function shouldAbort=nullptr); + private: /// Format messages into a prompt string using the model's chat template. std::string applyTemplate(llama_model *model, @@ -48,26 +63,12 @@ class Llama : public BaseProvider { std::string formatHarmonyPrompt(const CompletionRequest &request, const ModelInfo &modelInfo) const; - /// Tokenize the prompt outside of the inference mutex. - /// Returns the formatted prompt tokens ready for decode. - ErrorCode tokenizePrompt(llama_model *model, - const CompletionRequest &request, const ModelInfo &modelInfo, - std::vector &tokens, std::string &formattedPrompt); - /// Run the inference loop (shared by completion and streaming). ErrorCode runInference(llama_model *model, llama_context *ctx, const CompletionRequest &request, const ModelInfo &modelInfo, std::string &result, int &promptTokens, int &completionTokens, double &promptTimeMs, double &generationTimeMs, std::function streamCallback); - - /// Run inference with pre-tokenized prompt (avoids re-tokenizing under lock). - ErrorCode runInferenceWithTokens(llama_model *model, llama_context *ctx, - const CompletionRequest &request, const ModelInfo &modelInfo, - const std::vector &promptTokens, - std::string &result, int &promptTokenCount, int &completionTokens, - double &promptTimeMs, double &generationTimeMs, - std::function streamCallback); }; } // namespace arbiterAI diff --git a/src/arbiterAI/telemetryCollector.h b/src/arbiterAI/telemetryCollector.h index 625edcd..931eadf 100644 --- a/src/arbiterAI/telemetryCollector.h +++ b/src/arbiterAI/telemetryCollector.h @@ -16,6 +16,8 @@ namespace arbiterAI struct InferenceStats { std::string model; std::string variant; + uint64_t jobId=0; + bool cancelled=false; double tokensPerSecond=0.0; double promptTokensPerSecond=0.0; // prompt processing speed (tokens in / sec) double generationTokensPerSecond=0.0; // generation speed (tokens out / sec) diff --git a/src/server/dashboard.h b/src/server/dashboard.h index 7a27de3..967caf1 100644 --- a/src/server/dashboard.h +++ b/src/server/dashboard.h @@ -195,6 +195,16 @@ td background: #2a2a10; color: #f0c040; } +.badge-cooldown +{ + background: #1a2a3a; + color: #64b5f6; +} +.badge-error +{ + background: #3a1a1a; + color: #ef5350; +} .btn { padding: 4px 12px; @@ -676,6 +686,7 @@ td + @@ -687,7 +698,7 @@ td - +
Job Model Status Input Tokens
No recent requests
No active requests
@@ -696,9 +707,9 @@ td

Recent Inferences

- + - +
ModelPrompt t/sGen t/sPromptCompletionLatency
JobStatusModelPrompt t/sGen t/sPromptCompletionLatency
No recent inferences
No recent inferences
@@ -1427,7 +1438,7 @@ function renderInferences(history) if(!history||history.length===0) { - el.innerHTML='No recent inferences'; + el.innerHTML='No recent inferences'; return; } @@ -1437,8 +1448,14 @@ function renderInferences(history) { const promptTps=s.prompt_tokens_per_second||0; const genTps=s.generation_tokens_per_second||0; + const jobId=s.job_id||"—"; + const statusBadge=s.cancelled + ?'Cancelled' + :'Complete'; html+=` + #${jobId} + ${statusBadge} ${s.model} ${promptTps.toFixed(1)} ${genTps.toFixed(1)} @@ -1553,34 +1570,43 @@ function renderDownloadProgress(downloads) el.innerHTML=html; } -function renderActiveRequests(history, activeCount) +function renderActiveRequests(activeJobs) { const el=document.getElementById("activeRequestTable"); - if((!history||history.length===0)&&!activeCount) + if(!activeJobs||activeJobs.length===0) { - el.innerHTML='No recent requests'; + el.innerHTML='No active requests'; return; } - const recent=history?history.slice(-20).reverse():[]; let html=""; - for(const s of recent) + for(const j of activeJobs) { - const promptTps=s.prompt_tokens_per_second||0; - const genTps=s.generation_tokens_per_second||0; - const totalMs=s.total_time_ms||0; - const latencyMs=s.latency_ms||0; + let stageBadge; + switch(j.stage) + { + case "queued": stageBadge='Queued'; break; + case "tokenizing": stageBadge='Tokenizing'; break; + case "waiting": stageBadge='Waiting'; break; + case "inferring": stageBadge='Inferring'; break; + default: stageBadge='' + j.stage + ''; + } + + const elapsed=(j.elapsed_ms/1000).toFixed(1); + const promptToks=j.prompt_tokens||0; + const compToks=j.completion_tokens||0; html+=` - ${s.model} - Done - ${s.prompt_tokens.toLocaleString()} - ${s.completion_tokens.toLocaleString()} - ${promptTps.toFixed(1)} - ${genTps.toFixed(1)} - ${latencyMs.toFixed(0)} ms - ${totalMs.toFixed(0)} ms + #${j.id} + ${j.model} + ${stageBadge} + ${promptToks.toLocaleString()} + ${compToks.toLocaleString()} + — + — + — + ${elapsed}s `; } el.innerHTML=html; @@ -1596,11 +1622,12 @@ async function refreshDownloads() async function refresh() { - const [stats, history, swaps, hw]=await Promise.all([ + const [stats, history, swaps, hw, schedulerJobs]=await Promise.all([ fetchJson("/api/stats"), fetchJson("/api/stats/history?minutes=5"), fetchJson("/api/stats/swaps"), - fetchJson("/api/hardware") + fetchJson("/api/hardware"), + fetchJson("/api/scheduler/jobs") ]); const dot=document.getElementById("statusDot"); @@ -1654,7 +1681,7 @@ async function refresh() if(history) renderInferences(history); // Active requests summary - if(history) renderActiveRequests(history, stats.active_requests||0); + renderActiveRequests(schedulerJobs); // Swaps if(swaps) renderSwaps(swaps); @@ -1826,6 +1853,16 @@ td background: #2a2a10; color: #f0c040; } +.badge-cooldown +{ + background: #1a2a3a; + color: #64b5f6; +} +.badge-error +{ + background: #3a1a1a; + color: #ef5350; +} .btn { padding: 4px 12px; diff --git a/src/server/main.cpp b/src/server/main.cpp index dbfee12..4c64895 100644 --- a/src/server/main.cpp +++ b/src/server/main.cpp @@ -6,6 +6,7 @@ #include "arbiterAI/modelManager.h" #include "arbiterAI/modelRuntime.h" #include "arbiterAI/storageManager.h" +#include "arbiterAI/inferenceScheduler.h" #include #include @@ -776,6 +777,17 @@ int main(int argc, char *argv[]) } } + // ── Inference Scheduler ───────────────────────────────────────── + { + std::vector gpuIndices; + auto gpus=arbiterAI::HardwareDetector::instance().getGpus(); + for(const auto &gpu:gpus) + { + gpuIndices.push_back(gpu.index); + } + arbiterAI::InferenceScheduler::instance().initialize(gpuIndices); + } + // ── HTTP server ────────────────────────────────────────────── httplib::Server server; @@ -807,6 +819,7 @@ int main(int argc, char *argv[]) spdlog::info(" GET /api/stats - System snapshot"); spdlog::info(" GET /api/stats/history - Inference history"); spdlog::info(" GET /api/stats/swaps - Swap history"); + spdlog::info(" GET /api/scheduler/jobs - Active scheduler jobs"); spdlog::info(" GET /api/hardware - Hardware info"); spdlog::info(" POST /api/hardware/vram-override - Set VRAM override"); spdlog::info(" DEL /api/hardware/vram-override/:idx - Clear VRAM override"); @@ -828,7 +841,13 @@ int main(int argc, char *argv[]) spdlog::info("Starting server on {}:{}", host, port); spdlog::info("Dashboard: http://{}:{}/dashboard", host=="0.0.0.0"?"localhost":host, port); - if(!server.listen(host, port)) + bool listenOk=server.listen(host, port); + + // Stop scheduler worker threads before static destruction tears down + // the singletons (ModelRuntime, TelemetryCollector) they depend on. + arbiterAI::InferenceScheduler::instance().shutdown(); + + if(!listenOk) { spdlog::error("Failed to start server on {}:{}", host, port); return 1; diff --git a/src/server/routes.cpp b/src/server/routes.cpp index 023a5d8..b77d018 100644 --- a/src/server/routes.cpp +++ b/src/server/routes.cpp @@ -10,6 +10,7 @@ #include "arbiterAI/hardwareDetector.h" #include "arbiterAI/telemetryCollector.h" #include "arbiterAI/storageManager.h" +#include "arbiterAI/inferenceScheduler.h" #include #include @@ -909,6 +910,8 @@ nlohmann::json inferenceStatsToJson(const InferenceStats &s) return { {"model", s.model}, {"variant", s.variant}, + {"job_id", s.jobId}, + {"cancelled", s.cancelled}, {"tokens_per_second", s.tokensPerSecond}, {"prompt_tokens_per_second", s.promptTokensPerSecond}, {"generation_tokens_per_second", s.generationTokensPerSecond}, @@ -969,6 +972,7 @@ std::string errorCodeToString(ErrorCode code) case ErrorCode::GenerationError: return "generation_error"; case ErrorCode::ApiKeyNotFound: return "api_key_not_found"; case ErrorCode::ServerOverloaded: return "server_overloaded"; + case ErrorCode::Cancelled: return "cancelled"; default: return "unknown_error"; } } @@ -1390,6 +1394,16 @@ bool isHarmonyFormat(const std::string &modelName) return false; } +bool isLocalModel(const std::string &modelName) +{ + ModelInfo info; + if(ArbiterAI::instance().getModelInfo(modelName, info)==ErrorCode::Success) + { + return info.provider=="llama"; + } + return false; +} + } // anonymous namespace // ========== Override Path ========== @@ -1407,6 +1421,8 @@ void setOverridePath(const std::string &path) // ========== Route Registration ========== +void handleGetSchedulerJobs(const httplib::Request &, httplib::Response &res); + void registerRoutes(httplib::Server &server) { // Ensure POST/PUT/DELETE requests without a body include Content-Length: 0. @@ -1480,6 +1496,7 @@ void registerRoutes(httplib::Server &server) server.Get("/api/stats", handleGetStats); server.Get("/api/stats/history", handleGetStatsHistory); server.Get("/api/stats/swaps", handleGetStatsSwaps); + server.Get("/api/scheduler/jobs", handleGetSchedulerJobs); server.Get("/api/hardware", handleGetHardware); server.Post("/api/hardware/vram-override", handleSetVramOverride); server.Delete(R"(/api/hardware/vram-override/(\d+))", handleClearVramOverride); @@ -1682,12 +1699,13 @@ void handleChatCompletions(const httplib::Request &req, httplib::Response &res) } bool harmonyMode=isHarmonyFormat(arbiterRequest.model); + bool useScheduler=isLocalModel(arbiterRequest.model)&&InferenceScheduler::instance().isRunning(); if(stream) { res.set_chunked_content_provider( "text/event-stream", - [arbiterRequest, requestId, created, includeUsage, responseModelId, harmonyMode](size_t, httplib::DataSink &sink) + [arbiterRequest, requestId, created, includeUsage, responseModelId, harmonyMode, useScheduler](size_t, httplib::DataSink &sink) { // Send initial chunk with role nlohmann::json roleChunk={ @@ -1706,48 +1724,137 @@ void handleChatCompletions(const httplib::Request &req, httplib::Response &res) sink.write(roleLine.c_str(), roleLine.length()); HarmonyStreamParser harmonyParser; + ErrorCode err=ErrorCode::Success; + int usagePromptTokens=0; + int usageCompletionTokens=0; - auto callback=[&](const std::string &chunk) + if(useScheduler) { - if(chunk.empty()) return; + // Submit to the inference scheduler pipeline + auto job=InferenceScheduler::instance().submit(arbiterRequest, true); - std::string emitContent; - if(harmonyMode) + // Read tokens from the channel, sending keepalive while queued + while(true) { - emitContent=harmonyParser.feed(chunk); - if(emitContent.empty()) return; + std::string token; + bool alive=job->channel->pop(token, std::chrono::milliseconds(500)); + + if(!alive) + { + // Channel finished + break; + } + + if(token.empty()) + { + // Timeout — send keepalive based on current stage + InferenceStage stage=job->stage.load(); + std::string comment; + if(stage==InferenceStage::Queued||stage==InferenceStage::Tokenizing) + { + comment=": status: tokenizing prompt\n\n"; + } + else if(stage==InferenceStage::WaitingAccelerator) + { + int pos=job->queuePosition.load(); + comment=": status: queued for inference (position "+std::to_string(pos)+")\n\n"; + } + else if(stage==InferenceStage::Inferring) + { + comment=": status: generating\n\n"; + } + else + { + comment=": status: processing\n\n"; + } + if(!sink.write(comment.c_str(), comment.length())) + { + // Client disconnected — cancel the job + InferenceScheduler::instance().cancel(job->id); + break; + } + continue; + } + + // Emit token via SSE + std::string emitContent; + if(harmonyMode) + { + emitContent=harmonyParser.feed(token); + if(emitContent.empty()) continue; + } + else + { + emitContent=token; + } + + nlohmann::json sseChunk={ + {"id", requestId}, + {"object", "chat.completion.chunk"}, + {"created", created}, + {"model", responseModelId}, + {"system_fingerprint", nullptr}, + {"choices", {{ + {"index", 0}, + {"delta", {{"content", emitContent}}}, + {"finish_reason", nullptr} + }}} + }; + std::string line="data: "+sseChunk.dump()+"\n\n"; + if(!sink.write(line.c_str(), line.length())) + { + // Client disconnected — cancel the job + InferenceScheduler::instance().cancel(job->id); + break; + } } - else + + err=job->channel->getResult(); + usagePromptTokens=job->promptTokens; + usageCompletionTokens=job->completionTokens.load(); + } + else + { + // Remote provider — use direct streaming callback + auto callback=[&](const std::string &chunk) { - emitContent=chunk; - } + if(chunk.empty()) return; + + std::string emitContent; + if(harmonyMode) + { + emitContent=harmonyParser.feed(chunk); + if(emitContent.empty()) return; + } + else + { + emitContent=chunk; + } - nlohmann::json sseChunk={ - {"id", requestId}, - {"object", "chat.completion.chunk"}, - {"created", created}, - {"model", responseModelId}, - {"system_fingerprint", nullptr}, - {"choices", {{ - {"index", 0}, - {"delta", {{"content", emitContent}}}, - {"finish_reason", nullptr} - }}} + nlohmann::json sseChunk={ + {"id", requestId}, + {"object", "chat.completion.chunk"}, + {"created", created}, + {"model", responseModelId}, + {"system_fingerprint", nullptr}, + {"choices", {{ + {"index", 0}, + {"delta", {{"content", emitContent}}}, + {"finish_reason", nullptr} + }}} + }; + std::string line="data: "+sseChunk.dump()+"\n\n"; + sink.write(line.c_str(), line.length()); }; - std::string line="data: "+sseChunk.dump()+"\n\n"; - sink.write(line.c_str(), line.length()); - }; - // Send SSE comments while waiting for the inference lock. - // This keeps the connection alive and signals to clients that - // the request is queued for processing. - auto waitCallback=[&]() - { - std::string comment=": queued - waiting for model availability\n\n"; - sink.write(comment.c_str(), comment.length()); - }; + auto waitCallback=[&]() + { + std::string comment=": queued - waiting for model availability\n\n"; + sink.write(comment.c_str(), comment.length()); + }; - ErrorCode err=ArbiterAI::instance().streamingCompletion(arbiterRequest, callback, waitCallback); + err=ArbiterAI::instance().streamingCompletion(arbiterRequest, callback, waitCallback); + } std::string finishReason=(err==ErrorCode::Success)?"stop":"error"; @@ -1818,9 +1925,9 @@ void handleChatCompletions(const httplib::Request &req, httplib::Response &res) {"system_fingerprint", nullptr}, {"choices", nlohmann::json::array()}, {"usage", { - {"prompt_tokens", 0}, - {"completion_tokens", 0}, - {"total_tokens", 0} + {"prompt_tokens", usagePromptTokens}, + {"completion_tokens", usageCompletionTokens}, + {"total_tokens", usagePromptTokens+usageCompletionTokens} }} }; std::string usageLine="data: "+usageChunk.dump()+"\n\n"; @@ -1834,8 +1941,198 @@ void handleChatCompletions(const httplib::Request &req, httplib::Response &res) } ); } + else if(useScheduler) + { + // Submit before committing to a chunked response (whose status is + // locked at 200) so fast failures — missing model file, load errors, + // tokenize errors — can still return a proper HTTP status. + auto job=InferenceScheduler::instance().submit(arbiterRequest, false); + + { + std::unique_lock lock(job->completionMutex); + job->completionCv.wait_for(lock, std::chrono::seconds(2), [&job]() + { + return job->complete.load(); + }); + } + + if(job->complete.load()&&job->result!=ErrorCode::Success) + { + ErrorCode err=job->result; + int status=500; + std::string errType="server_error"; + std::string errCode=errorCodeToString(err); + + if(err==ErrorCode::UnknownModel||err==ErrorCode::ModelNotFound) + { + status=404; + errType="invalid_request_error"; + } + else if(err==ErrorCode::InvalidRequest) + { + status=400; + errType="invalid_request_error"; + } + else if(err==ErrorCode::ServerOverloaded||err==ErrorCode::ModelDownloading) + { + status=503; + } + + res.status=status; + res.set_content(errorJson("Completion failed: "+errCode, errType, "", errCode).dump(), "application/json"); + return; + } + + // Non-streaming scheduler path: use chunked encoding with heartbeat + // newlines every ~30s to keep the client alive while queued/inferring. + res.set_chunked_content_provider( + "application/json", + [job, arbiterRequest, requestId, created, responseModelId, harmonyMode](size_t, httplib::DataSink &sink) + { + // Poll for completion, sending heartbeat newlines every ~30s + constexpr auto heartbeatInterval=std::chrono::seconds(30); + { + std::unique_lock lock(job->completionMutex); + while(!job->complete.load()) + { + if(job->completionCv.wait_for(lock, heartbeatInterval, [&job]() + { + return job->complete.load(); + })) + { + break; + } + + // Send heartbeat newline to keep connection alive + if(!sink.write("\n", 1)) + { + // Client disconnected — cancel the job + InferenceScheduler::instance().cancel(job->id); + sink.done(); + return true; + } + } + } + + CompletionResponse arbiterResponse; + ErrorCode err=job->result; + + if(err==ErrorCode::Success) + { + arbiterResponse.text=job->resultText; + arbiterResponse.provider="llama"; + arbiterResponse.model=arbiterRequest.model; + arbiterResponse.usage.prompt_tokens=job->promptTokens; + arbiterResponse.usage.completion_tokens=job->completionTokens.load(); + arbiterResponse.usage.total_tokens=job->promptTokens+job->completionTokens.load(); + arbiterResponse.finishReason="stop"; + } + + if(err!=ErrorCode::Success) + { + std::string errCode=errorCodeToString(err); + std::string errBody=errorJson("Completion failed: "+errCode, "server_error", "", errCode).dump(); + sink.write(errBody.c_str(), errBody.length()); + sink.done(); + return true; + } + + std::string finishReason=arbiterResponse.finishReason.empty()?"stop":arbiterResponse.finishReason; + + if(harmonyMode&&!arbiterResponse.text.empty()) + { + HarmonyParseResult parsed=parseHarmonyFormat(arbiterResponse.text); + arbiterResponse.text=parsed.content; + if(!parsed.reasoningContent.empty()) + arbiterResponse.reasoningContent=parsed.reasoningContent; + + if(!parsed.toolCalls.empty()) + { + for(size_t i=0; i(); + else + argsStr=tc.arguments.dump(); + + toolCallsJson.push_back({ + {"id", tc.id}, + {"type", "function"}, + {"function", { + {"name", tc.name}, + {"arguments", argsStr} + }} + }); + } + messageJson["tool_calls"]=toolCallsJson; + if(finishReason=="stop") finishReason="tool_calls"; + } + else + { + messageJson["content"]=arbiterResponse.text; + if(!arbiterResponse.reasoningContent.empty()) + messageJson["reasoning_content"]=arbiterResponse.reasoningContent; + messageJson["tool_calls"]=nullptr; + } + + nlohmann::json responseJson={ + {"id", requestId}, + {"object", "chat.completion"}, + {"created", created}, + {"model", responseModelId}, + {"system_fingerprint", nullptr}, + {"choices", {{ + {"index", 0}, + {"message", messageJson}, + {"finish_reason", finishReason} + }}}, + {"usage", { + {"prompt_tokens", arbiterResponse.usage.prompt_tokens}, + {"completion_tokens", arbiterResponse.usage.completion_tokens}, + {"total_tokens", arbiterResponse.usage.total_tokens} + }} + }; + + std::string body=responseJson.dump(); + sink.write(body.c_str(), body.length()); + sink.done(); + return true; + } + ); + } else { + // Non-streaming remote provider path (synchronous, no heartbeat needed) CompletionResponse arbiterResponse; ErrorCode err=ArbiterAI::instance().completion(arbiterRequest, arbiterResponse); @@ -3012,6 +3309,45 @@ void handleGetStatsSwaps(const httplib::Request &, httplib::Response &res) res.set_content(arr.dump(), "application/json"); } +void handleGetSchedulerJobs(const httplib::Request &, httplib::Response &res) +{ + if(!InferenceScheduler::instance().isRunning()) + { + res.set_content("[]", "application/json"); + return; + } + + std::vector jobs=InferenceScheduler::instance().getActiveJobs(); + + nlohmann::json arr=nlohmann::json::array(); + for(const JobSnapshot &j:jobs) + { + std::string stageStr; + switch(j.stage) + { + case InferenceStage::Queued: stageStr="queued"; break; + case InferenceStage::Tokenizing: stageStr="tokenizing"; break; + case InferenceStage::WaitingAccelerator: stageStr="waiting"; break; + case InferenceStage::Inferring: stageStr="inferring"; break; + case InferenceStage::Complete: stageStr="complete"; break; + case InferenceStage::Cancelled: stageStr="cancelled"; break; + } + + arr.push_back({ + {"id", j.id}, + {"model", j.model}, + {"stage", stageStr}, + {"streaming", j.streaming}, + {"prompt_tokens", j.promptTokens}, + {"completion_tokens", j.completionTokens}, + {"queue_position", j.queuePosition}, + {"elapsed_ms", j.elapsedMs} + }); + } + + res.set_content(arr.dump(), "application/json"); +} + void handleGetHardware(const httplib::Request &, httplib::Response &res) { HardwareDetector::instance().refresh(); diff --git a/tests/configDownloaderTests.cpp b/tests/configDownloaderTests.cpp index 2656eef..9280544 100644 --- a/tests/configDownloaderTests.cpp +++ b/tests/configDownloaderTests.cpp @@ -7,6 +7,11 @@ #include #include +#include +#include +#include +#include + namespace arbiterAI { @@ -15,46 +20,71 @@ class ConfigDownloaderTest : public ::testing::Test protected: std::string remote_repo_path; std::string local_repo_path; - pid_t server_pid; + + /// Poll until the git daemon accepts TCP connections (or timeout). + static bool waitForDaemon(int port, std::chrono::seconds timeout) + { + auto deadline=std::chrono::steady_clock::now()+timeout; + + while(std::chrono::steady_clock::now()=0) + { + sockaddr_in addr{}; + addr.sin_family=AF_INET; + addr.sin_port=htons(port); + addr.sin_addr.s_addr=htonl(INADDR_LOOPBACK); + + int result=connect(sock, reinterpret_cast(&addr), sizeof(addr)); + close(sock); + if(result==0) + return true; + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + return false; + } void SetUp() override { local_repo_path=(std::filesystem::temp_directory_path()/"test_repo").string(); remote_repo_path=(std::filesystem::temp_directory_path()/"remote_repo.git").string(); - // Create a bare git repository to act as the remote + // Clean up any leftovers from a previous (crashed) run + std::system("pkill -x git-daemon >/dev/null 2>&1"); + std::filesystem::remove_all(local_repo_path); + std::filesystem::remove_all(remote_repo_path); + + // Create a bare git repository to act as the remote. + // libgit2 must be initialized first — without it + // git_repository_init produces an invalid repository. + git_libgit2_init(); + git_repository *repo=nullptr; - git_repository_init(&repo, remote_repo_path.c_str(), 1); + ASSERT_EQ(git_repository_init(&repo, remote_repo_path.c_str(), 1), 0); git_repository_free(repo); - std::thread server_thread([this]() + std::thread server_thread([]() { std::string base_path=std::filesystem::temp_directory_path().string(); - std::string command="git daemon --verbose --export-all --port=8080 --reuseaddr --base-path="+base_path+" "; + std::string command="git daemon --export-all --port=8080 --reuseaddr --base-path="+base_path+" >/dev/null 2>&1"; - int result=std::system(command.c_str()); - if(result!=0) - { - std::cerr<<"Failed to start git daemon"<0) - { - kill(server_pid, SIGTERM); - waitpid(server_pid, nullptr, 0); - } + std::system("pkill -x git-daemon >/dev/null 2>&1"); std::filesystem::remove_all(local_repo_path); std::filesystem::remove_all(remote_repo_path); + git_libgit2_shutdown(); } }; diff --git a/tests/inferenceSchedulerTests.cpp b/tests/inferenceSchedulerTests.cpp new file mode 100644 index 0000000..1d70ed6 --- /dev/null +++ b/tests/inferenceSchedulerTests.cpp @@ -0,0 +1,284 @@ +#include "arbiterAI/inferenceScheduler.h" +#include "arbiterAI/modelManager.h" +#include "arbiterAI/modelRuntime.h" + +#include +#include +#include +#include + +namespace arbiterAI +{ + +// ── TokenChannel ────────────────────────────────────────────── + +TEST(TokenChannelTest, PushPopOrder) +{ + TokenChannel channel; + + channel.push("alpha"); + channel.push("beta"); + + std::string token; + ASSERT_TRUE(channel.pop(token, std::chrono::milliseconds(100))); + EXPECT_EQ(token, "alpha"); + + ASSERT_TRUE(channel.pop(token, std::chrono::milliseconds(100))); + EXPECT_EQ(token, "beta"); +} + +TEST(TokenChannelTest, PopDrainsAfterFinish) +{ + TokenChannel channel; + + channel.push("alpha"); + channel.push("beta"); + channel.finish(ErrorCode::Success); + + std::string token; + ASSERT_TRUE(channel.pop(token, std::chrono::milliseconds(100))); + EXPECT_EQ(token, "alpha"); + + ASSERT_TRUE(channel.pop(token, std::chrono::milliseconds(100))); + EXPECT_EQ(token, "beta"); + + EXPECT_FALSE(channel.pop(token, std::chrono::milliseconds(100))); + EXPECT_TRUE(channel.isDone()); + EXPECT_EQ(channel.getResult(), ErrorCode::Success); +} + +TEST(TokenChannelTest, PopTimeoutKeepsAlive) +{ + TokenChannel channel; + + // No tokens and not finished — pop times out but reports the + // channel as still alive (caller uses this to send keepalives). + std::string token; + EXPECT_TRUE(channel.pop(token, std::chrono::milliseconds(50))); + EXPECT_TRUE(token.empty()); + EXPECT_FALSE(channel.isDone()); +} + +TEST(TokenChannelTest, FinishStoresResult) +{ + TokenChannel channel; + + channel.finish(ErrorCode::NetworkError); + + std::string token; + EXPECT_FALSE(channel.pop(token, std::chrono::milliseconds(100))); + EXPECT_TRUE(channel.isDone()); + EXPECT_EQ(channel.getResult(), ErrorCode::NetworkError); +} + +TEST(TokenChannelTest, CancelWakesConsumer) +{ + TokenChannel channel; + + std::thread canceller([&channel]() + { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + channel.cancel(); + }); + + std::string token; + bool alive=channel.pop(token, std::chrono::seconds(5)); + canceller.join(); + + EXPECT_FALSE(alive); + EXPECT_TRUE(channel.isCancelled()); +} + +TEST(TokenChannelTest, PushAfterCancelIsDropped) +{ + TokenChannel channel; + + channel.cancel(); + channel.push("alpha"); + + std::string token; + EXPECT_FALSE(channel.pop(token, std::chrono::milliseconds(100))); +} + +TEST(TokenChannelTest, ConcurrentProducerConsumer) +{ + TokenChannel channel; + constexpr int tokenCount=200; + + std::thread producer([&channel]() + { + for(int i=0; i &job, + std::chrono::milliseconds timeout) +{ + std::unique_lock lock(job->completionMutex); + return job->completionCv.wait_for(lock, timeout, [&job]() + { + return job->complete.load(); + }); +} + +} // anonymous namespace + +class InferenceSchedulerTest : public ::testing::Test +{ +protected: + void SetUp() override + { + ModelRuntime::reset(); + ModelManager::reset(); + + // No GPUs — the scheduler falls back to a single CPU queue. + InferenceScheduler::instance().initialize({}); + } + + void TearDown() override + { + InferenceScheduler::instance().shutdown(); + } +}; + +TEST_F(InferenceSchedulerTest, InitializeCreatesCpuQueueWhenNoGpus) +{ + EXPECT_TRUE(InferenceScheduler::instance().isRunning()); + EXPECT_EQ(InferenceScheduler::instance().getTotalQueueDepth(), 0); + EXPECT_TRUE(InferenceScheduler::instance().getActiveJobs().empty()); +} + +TEST_F(InferenceSchedulerTest, SubmitUnknownModelFailsAndSignalsCompletion) +{ + CompletionRequest request; + request.model="no-such-model"; + request.messages={{"user", "hello"}}; + + auto job=InferenceScheduler::instance().submit(request, false); + ASSERT_TRUE(job!=nullptr); + + ASSERT_TRUE(waitForComplete(job, std::chrono::seconds(10))); + EXPECT_NE(job->result, ErrorCode::Success); + + // Terminal jobs must be removed from active tracking + EXPECT_EQ(InferenceScheduler::instance().getJobStage(job->id), InferenceStage::Complete); + EXPECT_TRUE(InferenceScheduler::instance().getActiveJobs().empty()); +} + +TEST_F(InferenceSchedulerTest, StreamingFailureFinishesChannel) +{ + CompletionRequest request; + request.model="no-such-model"; + request.messages={{"user", "hello"}}; + + auto job=InferenceScheduler::instance().submit(request, true); + ASSERT_TRUE(job!=nullptr); + ASSERT_TRUE(job->channel!=nullptr); + + // Pop until the channel reports finished — a failed job must + // always finish its channel so the HTTP thread is not left hanging. + int idlePolls=0; + bool finished=false; + + while(idlePolls<100) + { + std::string token; + if(!job->channel->pop(token, std::chrono::milliseconds(100))) + { + finished=true; + break; + } + if(token.empty()) + { + idlePolls++; + } + } + + ASSERT_TRUE(finished); + EXPECT_TRUE(job->channel->isDone()); + EXPECT_NE(job->channel->getResult(), ErrorCode::Success); +} + +TEST_F(InferenceSchedulerTest, CancelledJobSignalsCompletion) +{ + CompletionRequest request; + request.model="no-such-model"; + request.messages={{"user", "hello"}}; + + auto job=InferenceScheduler::instance().submit(request, false); + InferenceScheduler::instance().cancel(job->id); + + // The job either gets cancelled or fails at load — but it must + // always reach a terminal state and wake any waiter. + ASSERT_TRUE(waitForComplete(job, std::chrono::seconds(10))); + EXPECT_NE(job->result, ErrorCode::Success); + EXPECT_TRUE(InferenceScheduler::instance().getActiveJobs().empty()); +} + +TEST_F(InferenceSchedulerTest, CancelUnknownJobIsNoop) +{ + InferenceScheduler::instance().cancel(987654321); + EXPECT_TRUE(InferenceScheduler::instance().isRunning()); +} + +TEST_F(InferenceSchedulerTest, JobIdsAreUnique) +{ + CompletionRequest request; + request.model="no-such-model"; + request.messages={{"user", "hello"}}; + + auto job1=InferenceScheduler::instance().submit(request, false); + auto job2=InferenceScheduler::instance().submit(request, false); + + EXPECT_NE(job1->id, job2->id); + + EXPECT_TRUE(waitForComplete(job1, std::chrono::seconds(10))); + EXPECT_TRUE(waitForComplete(job2, std::chrono::seconds(10))); +} + +TEST_F(InferenceSchedulerTest, GetJobStageUnknownReturnsComplete) +{ + EXPECT_EQ(InferenceScheduler::instance().getJobStage(123456789), InferenceStage::Complete); +} + +TEST_F(InferenceSchedulerTest, ShutdownIsIdempotent) +{ + InferenceScheduler::instance().shutdown(); + EXPECT_FALSE(InferenceScheduler::instance().isRunning()); + + InferenceScheduler::instance().shutdown(); + EXPECT_FALSE(InferenceScheduler::instance().isRunning()); +} + +} // namespace arbiterAI From b970fa6bf09c4ff2be2760e4465448a219dcbe4a Mon Sep 17 00:00:00 2001 From: krazer Date: Wed, 10 Jun 2026 18:05:55 -0400 Subject: [PATCH 3/4] add CLAUDE.md with build, architecture, and style guidance Co-Authored-By: Claude Opus 4.8 (1M context) --- CLAUDE.md | 103 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 CLAUDE.md diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..8211291 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,103 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## What this is + +ArbiterAI is a C++17 library providing a unified, embeddable interface across multiple LLM +providers (OpenAI, Anthropic, DeepSeek, OpenRouter, llama.cpp local models, and a Mock provider for +testing). It also ships a standalone OpenAI-compatible HTTP server (`arbiterAI-server`) with model +lifecycle management, telemetry, and a live dashboard. + +## Build, test, run — everything goes through Docker + +All building, testing, and running happens **inside the Docker container** (`docker/Dockerfile`). The +host is not guaranteed to have the toolchain (CMake + vcpkg + llama.cpp) or dependencies. Dependencies +are managed by vcpkg (`vcpkg.json`). + +```bash +./runDocker.sh # start/attach the container (bind-mounts repo at /app) +./runDocker.sh ./build.sh # build (runs CMake automatically if cmake files changed) +./runDocker.sh ./build.sh --rebuild # clean rebuild of the app +./runDocker.sh ./build.sh --rebuild-cmake # nuke CMake dir + re-run CMake (only if cmake is broken) +./runDocker.sh --rebuild # rebuild the Docker *image* (only when Dockerfile changes) +./runDocker.sh --stop # stop and remove the container +``` + +Build output: `build/${OS}_${ARCH}_${BUILD_TYPE}`, default `build/linux_x64_debug/`. +Targets: `arbiterai` (library), `arbiterai_tests`, `arbiterAI-cli`, `arbiterAI-proxy`, `arbiterAI-server`. + +### Tests (Google Test) + +```bash +./runDocker.sh ./build/linux_x64_debug/arbiterai_tests +./runDocker.sh ./build/linux_x64_debug/arbiterai_tests --gtest_filter='ModelManager*' # single suite/test +``` + +### Working rules + +- Run binaries/commands through `./runDocker.sh ...`. Do **not** use host `python`/`pip`/`pytest` or host virtualenvs — the container is the environment. +- Do **not** launch `arbiterAI-server` yourself; ask the user to launch it so it doesn't occupy the agent terminal. +- Avoid `2>&1` redirection — the user needs to see live output. + +## Configuration model + +Model/provider configs are JSON, loaded by `ModelManager` (singleton) with schema validation +(`schemas/`). The default configs live in the **`arbiterAI_config` git submodule** (`arbiterAI_config/configs/defaults/{models,backends}/`). +`ArbiterAI::initialize()` takes a list of config directories. The server merges these with runtime-injected +configs (added/updated/removed via REST without restart) and can persist them via an override path. + +## Architecture + +Layered, strategy-pattern core (see `docs/developer.md` for the full API reference): + +``` +ArbiterAI (singleton factory + lifecycle) ── src/arbiterAI/arbiterAI.{h,cpp} + ├─ createChatClient() → ChatClient (stateful per-session: history, tools, cache, stats) + ├─ owns ModelManager (singleton: config load, schema validation, model lookup, ConfigDownloader) + └─ stateless convenience: completion(), streamingCompletion(), batchCompletion(), getEmbeddings() + │ delegates to + BaseProvider (abstract) ── src/arbiterAI/providers/baseProvider.h + OpenAI · Anthropic · DeepSeek · OpenRouter · Llama (local) · Mock +``` + +- **Providers** are instantiated by a `switch` in `arbiterAI.cpp` keyed on the provider string (`createProvider`-style factory). To add a provider: create `providers/.{h,cpp}` subclassing `BaseProvider`, add it to that switch, add the source to `CMakeLists.txt`, and add a model config JSON. +- **Error handling is error-code based** (`ErrorCode` enum), not exceptions — follow this; avoid try/catch where an error code works. + +### Local model subsystem (llama.cpp) + +Distinct from the cloud providers, this is the heavier piece: + +- **`ModelRuntime`** (`modelRuntime.{h,cpp}`) — multi-model loading into VRAM/RAM, swap queueing, LRU eviction, GGUF-aware load-failure classification (`LoadFailureReason`/`LoadErrorDetail`). +- **`InferenceScheduler`** (`inferenceScheduler.{h,cpp}`) — request pipeline with stages (Queued → Tokenizing → WaitingAccelerator → Inferring → Complete), and `TokenChannel` for streaming tokens from the accelerator thread to the HTTP thread. +- **`HardwareDetector`** — GPU/VRAM/RAM/CPU detection; **`ModelFitCalculator`** — whether a model fits available hardware. +- **`ModelDownloader`** / **`StorageManager`** — download GGUF files (libgit2 / HTTP), track storage, hot-ready/protected flags, cleanup. +- **`TelemetryCollector`** — inference stats and system snapshots, surfaced by the server. + +### Server (`src/server/`) + +Separate CMake target linking `arbiterai` + cpp-httplib (httplib is a server-only dependency, kept out of +the core library). `routes.cpp` defines the OpenAI-compatible endpoints (`/v1/chat/completions`, +`/v1/models`, `/v1/embeddings` with SSE streaming), model management, telemetry (`/api/stats`), and config +injection. `dashboard.h`/`dashboardConfig.h` are embedded HTML/JS for the `/dashboard` UI. The server takes a +single required config file: `arbiterAI-server -c `. See `docs/server.md`. + +### Testing without API keys + +The **Mock provider** (`providers/mock.{h,cpp}`) returns deterministic responses driven by `...` +tags in messages — no network or keys. Use `"provider": "mock"` in a model config. See `docs/testing.md`. + +## Code style (from `.roo/rules-code/` and `.github/instructions/`) + +- Files: **camelCase** names, `.h`/`.cpp`/`.inl`. Header guards `_PROJECT_FILENAME_EXT_`, **no `#pragma once`**. +- Braces: open brace on a **new line** for namespaces/functions/control blocks; **same line** for struct/class definitions in headers. +- Naming: Types `PascalCase`; functions/methods `camelCase`; class members `m_camelCase`; locals/struct vars `camelCase`; macros `UPPER_CASE`. +- Spacing: no space around `=`, `::`, unary operators, or between a keyword/function name and `(`; spaces around comparison/logical operators; comma after, not before. +- Pointers/refs bind to the variable: `type *var`, `type &var`. Minimize `auto`. Minimize comments — none for obvious code. +- Includes: `""` for local files, `<>` for libraries. Namespaces: prefer explicit qualification over `using` directives; aliases allowed. + +## Docs map + +`docs/developer.md` (architecture + API), `docs/server.md` (server API), `docs/testing.md` (mock/echo), +`docs/project.md` (goals/providers), `docs/tasks/` (active task plans). The `docs/old/` and +`docs/development/tasks/completed/` dirs are historical. From e906b7ee28c08a0fca199f9d5e739966bab5847c Mon Sep 17 00:00:00 2001 From: krazer Date: Wed, 10 Jun 2026 21:43:31 -0400 Subject: [PATCH 4/4] bump DOCKER_VERSION to 1.2.2 for spirv-headers addition Co-Authored-By: Claude Opus 4.8 (1M context) --- docker/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 072d0aa..e8ecea2 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,5 +1,5 @@ # syntax=docker/dockerfile:1 -ARG DOCKER_VERSION=1.2.1 +ARG DOCKER_VERSION=1.2.2 FROM ubuntu:24.04 # Install basic build tools, Python 3, and GPU libraries.