Skip to content
4 changes: 1 addition & 3 deletions src/agora_agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,12 @@
if typing.TYPE_CHECKING:
from . import agents, agentkit, core, phone_numbers, telephony
from .core.domain import Area, Pool, create_pool
from .pool_client import Agora, AsyncAgora
from .pool_client import Agora, AsyncAgora, AgentClient, AsyncAgentClient
from .version import __version__
from .agentkit import (
Agent,
AgentSession,
AgentSessionOptions,
AgentClient,
AsyncAgentClient,
CNAgent,
GlobalAgent,
GenericAvatar,
Expand Down
99 changes: 51 additions & 48 deletions src/agora_agent/agentkit/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
# Shorter aliases for the StartAgentsRequestProperties* types, so the rest of
# this module can refer to them by compact names.
AvatarConfig = StartAgentsRequestPropertiesAvatar
AvatarVendor = StartAgentsRequestPropertiesAvatarVendor
TurnDetectionConfig = StartAgentsRequestPropertiesTurnDetection
# Accepted forms for a turn-detection setting: either a TurnDetectionConfig
# object or a plain dict with string keys and arbitrary values.
TurnDetectionInput = typing.Union[TurnDetectionConfig, typing.Dict[str, typing.Any]]
SalConfig = StartAgentsRequestPropertiesSal
SalMode = StartAgentsRequestPropertiesSalSalMode
AdvancedFeatures = StartAgentsRequestPropertiesAdvancedFeatures
Expand Down Expand Up @@ -351,52 +352,54 @@ class Agent:
"""

# CNAgent / GlobalAgent are imported only while type checking; they are needed
# here solely for the overload return annotations below.
if typing.TYPE_CHECKING:
from .regional_agent import CNAgent, GlobalAgent

# The three areas other than Area.CN that the "global" overloads match on.
_GlobalArea = typing_extensions.Literal[Area.US, Area.EU, Area.AP]

# Overload stubs for __new__ (bodies are `...`): they only tell a type checker
# which class the constructor call yields, keyed on the client's Area parameter.

# Agora client typed with Area.CN -> CNAgent.
@typing.overload
def __new__(
cls,
client: "Agora[typing_extensions.Literal[Area.CN]]",
*args: typing.Any,
**kwargs: typing.Any,
) -> "CNAgent":
...

# Agora client typed with US / EU / AP -> GlobalAgent.
@typing.overload
def __new__(
cls,
client: "Agora[_GlobalArea]",
*args: typing.Any,
**kwargs: typing.Any,
) -> "GlobalAgent":
...

# AsyncAgora client typed with Area.CN -> CNAgent.
@typing.overload
def __new__(
cls,
client: "AsyncAgora[typing_extensions.Literal[Area.CN]]",
*args: typing.Any,
**kwargs: typing.Any,
) -> "CNAgent":
...

# AsyncAgora client typed with US / EU / AP -> GlobalAgent.
@typing.overload
def __new__(
cls,
client: "AsyncAgora[_GlobalArea]",
*args: typing.Any,
**kwargs: typing.Any,
) -> "GlobalAgent":
...

# Fallback: any other client type is annotated as producing a plain Agent.
@typing.overload
def __new__(
cls,
client: typing.Any,
*args: typing.Any,
**kwargs: typing.Any,
) -> "Agent":
...
# Overload stubs for __new__ (bodies are `...`): they only tell a type checker
# which class the constructor call yields, keyed on the client's Area parameter.

# Agora client typed with Area.CN -> CNAgent.
@typing.overload
def __new__(
cls,
client: "Agora[typing_extensions.Literal[Area.CN]]",
*args: typing.Any,
**kwargs: typing.Any,
) -> "CNAgent":
...

# Agora client typed with _GlobalArea -> GlobalAgent.
@typing.overload
def __new__(
cls,
client: "Agora[_GlobalArea]",
*args: typing.Any,
**kwargs: typing.Any,
) -> "GlobalAgent":
...

# AsyncAgora client typed with Area.CN -> CNAgent.
@typing.overload
def __new__(
cls,
client: "AsyncAgora[typing_extensions.Literal[Area.CN]]",
*args: typing.Any,
**kwargs: typing.Any,
) -> "CNAgent":
...

# AsyncAgora client typed with _GlobalArea -> GlobalAgent.
@typing.overload
def __new__(
cls,
client: "AsyncAgora[_GlobalArea]",
*args: typing.Any,
**kwargs: typing.Any,
) -> "GlobalAgent":
...

# Fallback: any other client type is annotated as producing a plain Agent.
@typing.overload
def __new__(
cls,
client: typing.Any,
*args: typing.Any,
**kwargs: typing.Any,
) -> "Agent":
...

def __new__(
cls,
Expand All @@ -422,7 +425,7 @@ def __init__(
self,
client: typing.Any,
instructions: typing.Optional[str] = None,
turn_detection: typing.Optional[TurnDetectionConfig] = None,
turn_detection: typing.Optional[TurnDetectionInput] = None,
interruption: typing.Optional[InterruptionConfig] = None,
sal: typing.Optional[SalConfig] = None,
advanced_features: typing.Optional[AdvancedFeatures] = None,
Expand Down Expand Up @@ -541,7 +544,7 @@ def with_avatar(self, vendor: BaseAvatar) -> "Agent":
new_agent._avatar_required_sample_rate = required_sample_rate
return new_agent

# Builder-style setter for turn detection: obtains an object from
# self._clone(), stores `config` on it, and returns that object. Nothing is
# assigned on `self` here.
# Two signatures are shown; they differ only in the annotation of `config`
# (TurnDetectionConfig vs. the wider TurnDetectionInput).
def with_turn_detection(self, config: TurnDetectionConfig) -> "Agent":
def with_turn_detection(self, config: TurnDetectionInput) -> "Agent":
new_agent = self._clone()
# Stored exactly as passed in; no validation or conversion happens here.
new_agent._turn_detection = config
return new_agent
Expand Down Expand Up @@ -733,7 +736,7 @@ def mllm(self) -> typing.Optional[typing.Dict[str, typing.Any]]:
return self._mllm

# Property exposing the stored turn-detection setting: returns
# self._turn_detection unchanged. The Optional annotation indicates it may be
# None. Two signatures are shown; they differ only in the return annotation.
@property
def turn_detection(self) -> typing.Optional[TurnDetectionConfig]:
def turn_detection(self) -> typing.Optional[TurnDetectionInput]:
return self._turn_detection

@property
Expand Down Expand Up @@ -1072,7 +1075,7 @@ def _resolve_llm_config(self) -> typing.Dict[str, typing.Any]:
llm_config["max_history"] = self._max_history
return llm_config

def _resolve_asr_config(self, turn_detection_config: TurnDetectionConfig) -> typing.Dict[str, typing.Any]:
def _resolve_asr_config(self, turn_detection_config: TurnDetectionInput) -> typing.Dict[str, typing.Any]:
asr_config = dict(self._stt or {})
if not asr_config:
asr_config["vendor"] = "ares"
Expand Down
32 changes: 16 additions & 16 deletions src/agora_agent/agentkit/regional_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,31 +108,31 @@


# Agent subclass whose with_stt / with_llm / with_tts / with_avatar builders
# are annotated to return "CNAgent" instead of the base "Agent".
#
# Every method simply forwards to the same-named Agent method via super().
# typing.cast only informs the type checker and returns its argument unchanged
# at runtime, so these overrides add no runtime behaviour of their own.
#
# Each method appears in two variants: one whose `vendor` is annotated with a
# CN-specific type and cast to the Base* type before forwarding, and one whose
# `vendor` is annotated with the Base* type and forwarded as-is.
class CNAgent(Agent):
def with_stt(self, vendor: CNSTT) -> "CNAgent":
return typing.cast("CNAgent", super().with_stt(typing.cast(BaseSTT, vendor)))
def with_stt(self, vendor: BaseSTT) -> "CNAgent":
return typing.cast("CNAgent", super().with_stt(vendor))

def with_llm(self, vendor: CNLLM) -> "CNAgent":
return typing.cast("CNAgent", super().with_llm(typing.cast(BaseLLM, vendor)))
def with_llm(self, vendor: BaseLLM) -> "CNAgent":
return typing.cast("CNAgent", super().with_llm(vendor))

def with_tts(self, vendor: CNTTS) -> "CNAgent":
return typing.cast("CNAgent", super().with_tts(typing.cast(BaseTTS, vendor)))
def with_tts(self, vendor: BaseTTS) -> "CNAgent":
return typing.cast("CNAgent", super().with_tts(vendor))

def with_avatar(self, vendor: CNAvatar) -> "CNAgent":
return typing.cast("CNAgent", super().with_avatar(typing.cast(BaseAvatar, vendor)))
def with_avatar(self, vendor: BaseAvatar) -> "CNAgent":
return typing.cast("CNAgent", super().with_avatar(vendor))


# Agent subclass whose with_stt / with_llm / with_tts / with_avatar builders
# are annotated to return "GlobalAgent" instead of the base "Agent".
#
# Every method simply forwards to the same-named Agent method via super().
# typing.cast only informs the type checker and returns its argument unchanged
# at runtime, so these overrides add no runtime behaviour of their own.
#
# Each method appears in two variants: one whose `vendor` is annotated with a
# Global-specific type and cast to the Base* type before forwarding, and one
# whose `vendor` is annotated with the Base* type and forwarded as-is.
class GlobalAgent(Agent):
def with_stt(self, vendor: GlobalSTT) -> "GlobalAgent":
return typing.cast("GlobalAgent", super().with_stt(typing.cast(BaseSTT, vendor)))
def with_stt(self, vendor: BaseSTT) -> "GlobalAgent":
return typing.cast("GlobalAgent", super().with_stt(vendor))

def with_llm(self, vendor: GlobalLLM) -> "GlobalAgent":
return typing.cast("GlobalAgent", super().with_llm(typing.cast(BaseLLM, vendor)))
def with_llm(self, vendor: BaseLLM) -> "GlobalAgent":
return typing.cast("GlobalAgent", super().with_llm(vendor))

def with_tts(self, vendor: GlobalTTS) -> "GlobalAgent":
return typing.cast("GlobalAgent", super().with_tts(typing.cast(BaseTTS, vendor)))
def with_tts(self, vendor: BaseTTS) -> "GlobalAgent":
return typing.cast("GlobalAgent", super().with_tts(vendor))

def with_avatar(self, vendor: GlobalAvatar) -> "GlobalAgent":
return typing.cast("GlobalAgent", super().with_avatar(typing.cast(BaseAvatar, vendor)))
def with_avatar(self, vendor: BaseAvatar) -> "GlobalAgent":
return typing.cast("GlobalAgent", super().with_avatar(vendor))


# Union alias for annotations that accept either regional Agent subclass.
RegionalAgent = typing.Union[CNAgent, GlobalAgent]
22 changes: 11 additions & 11 deletions src/agora_agent/agentkit/vendors/cn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field, model_validator

Expand Down Expand Up @@ -148,7 +148,7 @@ class MicrosoftSTTOptions(BaseModel):
key: str = Field(..., description="Azure subscription key")
region: str = Field(..., description="Azure region (e.g., eastus)")
language: str = Field(..., description="Language code (e.g., zh-CN)")
phrase_list: Optional[list[str]] = Field(default=None, description="Microsoft ASR phrase list")
phrase_list: Optional[List[str]] = Field(default=None, description="Microsoft ASR phrase list")
additional_params: Optional[Dict[str, Any]] = Field(default=None)


Expand Down Expand Up @@ -183,7 +183,7 @@ class TencentTTSOptions(BaseModel):
emotion_category: Optional[str] = Field(default=None, description="Tencent TTS emotion category")
emotion_intensity: Optional[int] = Field(default=None, description="Tencent TTS emotion intensity")
additional_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional Tencent TTS params")
skip_patterns: Optional[list[int]] = Field(default=None)
skip_patterns: Optional[List[int]] = Field(default=None)


class TencentTTS(_BaseTTSCompat):
Expand Down Expand Up @@ -239,7 +239,7 @@ class BytedanceTTSOptions(BaseModel):
pitch_ratio: Optional[float] = Field(default=None, description="Bytedance TTS pitch ratio")
emotion: Optional[str] = Field(default=None, description="Bytedance TTS emotion")
additional_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional Bytedance TTS params")
skip_patterns: Optional[list[int]] = Field(default=None)
skip_patterns: Optional[List[int]] = Field(default=None)


class BytedanceTTS(_BaseTTSCompat):
Expand Down Expand Up @@ -290,7 +290,7 @@ class BytedanceDuplexTTSOptions(BaseModel):
app_id: str = Field(..., description="Bytedance Duplex TTS app id")
speaker: str = Field(..., description="Bytedance Duplex TTS speaker")
additional_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional Bytedance Duplex TTS params")
skip_patterns: Optional[list[int]] = Field(default=None)
skip_patterns: Optional[List[int]] = Field(default=None)


class BytedanceDuplexTTS(_BaseTTSCompat):
Expand Down Expand Up @@ -333,7 +333,7 @@ class CosyVoiceTTSOptions(BaseModel):
sample_rate: Optional[int] = Field(default=None, description="Output sample rate in Hz")
voice: Optional[str] = Field(default=None, description="CosyVoice voice")
additional_params: Optional[Dict[str, Any]] = Field(default=None, description="CosyVoice TTS params from REST doc")
skip_patterns: Optional[list[int]] = Field(default=None)
skip_patterns: Optional[List[int]] = Field(default=None)


class CosyVoiceTTS(_BaseTTSCompat):
Expand Down Expand Up @@ -377,7 +377,7 @@ class StepFunTTSOptions(BaseModel):
model: Optional[str] = Field(default=None, description="StepFun TTS model")
voice_id: Optional[str] = Field(default=None, description="StepFun TTS voice id")
additional_params: Optional[Dict[str, Any]] = Field(default=None, description="StepFun TTS params from REST doc")
skip_patterns: Optional[list[int]] = Field(default=None)
skip_patterns: Optional[List[int]] = Field(default=None)


class StepFunTTS(_BaseTTSCompat):
Expand Down Expand Up @@ -420,7 +420,7 @@ class MicrosoftTTSOptions(BaseModel):
speed: Optional[float] = Field(default=None, description="Speaking rate multiplier")
volume: Optional[float] = Field(default=None, description="Audio volume")
additional_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional Microsoft TTS params")
skip_patterns: Optional[list[int]] = Field(default=None)
skip_patterns: Optional[List[int]] = Field(default=None)


class MicrosoftTTS(_BaseTTSCompat):
Expand Down Expand Up @@ -463,12 +463,12 @@ class MiniMaxTTSOptions(BaseModel):
emotion: Optional[str] = Field(default=None, description="Emotion style")
latex_read: Optional[bool] = Field(default=None, description="Whether to read LaTeX expressions")
english_normalization: Optional[bool] = Field(default=None, description="Whether to normalize English text")
timber_weights: Optional[list[Dict[str, Any]]] = Field(default=None, description="Alternative timbre mix config")
timber_weights: Optional[List[Dict[str, Any]]] = Field(default=None, description="Alternative timbre mix config")
sample_rate: Optional[int] = Field(default=None, description="Output sample rate in Hz")
pronunciation_dict: Optional[Dict[str, Any]] = Field(default=None, description="Pronunciation replacement dictionary")
language_boost: Optional[str] = Field(default=None, description="Language boost strategy")
additional_params: Optional[Dict[str, Any]] = Field(default=None, description="Additional MiniMax TTS params")
skip_patterns: Optional[list[int]] = Field(default=None)
skip_patterns: Optional[List[int]] = Field(default=None)

@model_validator(mode="after")
def _validate_params(self) -> "MiniMaxTTSOptions":
Expand Down Expand Up @@ -556,7 +556,7 @@ class SenseTimeAvatarOptions(BaseModel):
agora_uid: str = Field(..., description="Avatar RTC publisher uid")
app_id: Optional[str] = Field(default=None, alias="appId", description="SenseTime app id")
app_key: str = Field(..., description="SenseTime app key")
scene_list: list[Dict[str, Any]] = Field(..., alias="sceneList", description="SenseTime scene list")
scene_list: List[Dict[str, Any]] = Field(..., alias="sceneList", description="SenseTime scene list")
enable: Optional[bool] = Field(default=None)
additional_params: Optional[Dict[str, Any]] = Field(default=None)

Expand Down
Loading
Loading