From 704b6407b4e51800376e73fe934a762e94b30d9d Mon Sep 17 00:00:00 2001 From: Navan Chauhan Date: Sat, 14 Oct 2023 04:39:41 -0600 Subject: rebased --- main.py | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 72 insertions(+), 6 deletions(-) (limited to 'main.py') diff --git a/main.py b/main.py index c49cdd6..54b6c9c 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,11 @@ +import asyncio + +import os +from typing import AsyncGenerator, AsyncIterable, Awaitable, Optional, Tuple + +from vocode.streaming.models.agent import AgentConfig, AgentType +from vocode.streaming.agent.base_agent import BaseAgent, RespondAgent + import logging import os from fastapi import FastAPI @@ -14,6 +22,17 @@ from vocode.streaming.telephony.server.base import ( TelephonyServer, ) +from vocode.streaming.telephony.server.base import TwilioCallConfig + +import uvicorn +from dotenv import load_dotenv +from fastapi import FastAPI +from fastapi.responses import StreamingResponse +from langchain.callbacks import AsyncIteratorCallbackHandler +from langchain.chat_models import ChatOpenAI +from langchain.schema import HumanMessage +from pydantic import BaseModel + from speller_agent import SpellerAgentFactory import sys @@ -23,7 +42,7 @@ from dotenv import load_dotenv load_dotenv() -app = FastAPI(docs_url=None) +app = FastAPI() logging.basicConfig() logger = logging.getLogger(__name__) @@ -46,17 +65,23 @@ if not BASE_URL: if not BASE_URL: raise ValueError("BASE_URL must be set in environment if not using pyngrok") +from speller_agent import SpellerAgentConfig + +print(AgentType) + telephony_server = TelephonyServer( base_url=BASE_URL, config_manager=config_manager, inbound_call_configs=[ TwilioInboundCallConfig( url="/inbound_call", - agent_config=ChatGPTAgentConfig( - initial_message=BaseMessage(text="What up."), - prompt_preamble="Act as a customer talking to 'Cosmos', a pizza establisment ordering a large pepperoni pizza for pickup. If asked for a name, your name is 'Hunter McRobie', and your credit card number is 4743 2401 5792 0539 CVV: 123 and expiratoin is 10/25. If asked for numbers, say them one by one",#"Have a polite conversation about life while talking like a pirate.", - generate_responses=True, - ), + # agent_config=ChatGPTAgentConfig( + # initial_message=BaseMessage(text="What up."), + # prompt_preamble="Act as a customer talking to 'Cosmos', a pizza establisment ordering a large pepperoni pizza for pickup. If asked for a name, your name is 'Hunter McRobie', and your credit card number is 4743 2401 5792 0539 CVV: 123 and expiratoin is 10/25. If asked for numbers, say them one by one",#"Have a polite conversation about life while talking like a pirate.", + # generate_responses=True, + # model_name="gpt-3.5-turbo" + # ), + agent_config=SpellerAgentConfig(generate_responses=False, initial_message=BaseMessage(text="What up.")), twilio_config=TwilioConfig( account_sid=os.environ["TWILIO_ACCOUNT_SID"], auth_token=os.environ["TWILIO_AUTH_TOKEN"], @@ -71,4 +96,45 @@ telephony_server = TelephonyServer( logger=logger, ) +async def send_message(message: str) -> AsyncIterable[str]: + callback = AsyncIteratorCallbackHandler() + model = ChatOpenAI( + streaming=True, + verbose=True, + callbacks=[callback], + ) + + async def wrap_done(fn: Awaitable, event: asyncio.Event): + """Wrap an awaitable with a event to signal when it's done or an exception is raised.""" + try: + await fn + except Exception as e: + # TODO: handle exception + print(f"Caught exception: {e}") + finally: + # Signal the aiter to stop. + event.set() + + # Begin a task that runs in the background. + task = asyncio.create_task(wrap_done( + model.agenerate(messages=[[HumanMessage(content=message)]]), + callback.done), + ) + + async for token in callback.aiter(): + # Use server-sent-events to stream the response + yield f"data: {token}\n\n" + + await task + + +class StreamRequest(BaseModel): + """Request body for streaming.""" + message: str + + +@app.post("/stream") +def stream(body: StreamRequest): + return StreamingResponse(send_message(body.message), media_type="text/event-stream") + app.include_router(telephony_server.get_router()) \ No newline at end of file -- cgit v1.2.3