import json
import time
from datetime import datetime, timedelta
from typing import Annotated, List, Union
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
from fastapi.responses import Response, StreamingResponse
from fastapi import Depends, Header, Request, Response, HTTPException
import chainfury.types as T
import chainfury_server.database as DB
from chainfury_server.engines import engine_registry
def create_chain(
    req: Request,
    resp: Response,
    token: Annotated[str, Header()],
    chatbot_data: T.ApiCreateChainRequest,
    db: Session = Depends(DB.fastapi_db_session),
) -> Union[T.ApiChain, T.ApiResponse]:
    """Create a new chain (a ``DB.ChatBot`` row) owned by the calling user.

    Args:
        req: The incoming request (not used).
        resp: Response object; its status code is set to 400 on validation failure.
        token: JWT taken from the ``token`` header; resolved to a user via
            ``DB.get_user_from_jwt``.
        chatbot_data: Name, engine, optional DAG and description of the new chain.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        The created chain as ``T.ApiChain``, or a ``T.ApiResponse`` error message when
        the name or engine is missing or the engine is not in ``DB.ChatBotTypes.all()``.
    """
    owner = DB.get_user_from_jwt(token=token, db=db)

    # The first failing check determines the (400) error message.
    if not chatbot_data.name:
        problem = "Name not specified"
    elif not chatbot_data.engine:
        problem = "Engine not specified"
    elif chatbot_data.engine not in DB.ChatBotTypes.all():
        problem = f"Invalid engine should be one of {DB.ChatBotTypes.all()}"
    else:
        problem = ""
    if problem:
        resp.status_code = 400
        return T.ApiResponse(message=problem)

    # A missing DAG is persisted as an empty dict.
    if chatbot_data.dag:
        dag_payload = chatbot_data.dag.dict()
    else:
        dag_payload = {}

    new_chatbot = DB.ChatBot(
        name=chatbot_data.name,
        created_by=owner.id,
        dag=dag_payload,
        engine=chatbot_data.engine,
        created_at=datetime.now(),
        description=chatbot_data.description,
    )  # type: ignore
    db.add(new_chatbot)
    db.commit()
    # Reload the row so values filled in by the database (presumably the id) are present.
    db.refresh(new_chatbot)
    return T.ApiChain(**new_chatbot.to_dict())
def get_chain(
    req: Request,
    resp: Response,
    token: Annotated[str, Header()],
    id: str,
    tag_id: str = "",
    db: Session = Depends(DB.fastapi_db_session),
) -> Union[T.ApiChain, T.ApiResponse]:
    """Fetch one non-deleted chain owned by the calling user.

    Args:
        req: The incoming request (not used).
        resp: Response object; its status code is set to 404 when nothing matches.
        token: JWT taken from the ``token`` header; resolved via ``DB.get_user_from_jwt``.
        id: Id of the chain to fetch.
        tag_id: When non-empty, the chain must also carry this tag id.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        The chain as ``T.ApiChain``, or a ``T.ApiResponse`` ("ChatBot not found") with
        status 404.
    """
    owner = DB.get_user_from_jwt(token=token, db=db)

    query = db.query(DB.ChatBot).filter(
        DB.ChatBot.id == id,
        DB.ChatBot.created_by == owner.id,
        DB.ChatBot.deleted_at == None,  # noqa: E711 -- SQLAlchemy renders this as IS NULL
    )
    if tag_id:
        query = query.filter(DB.ChatBot.tag_id == tag_id)

    found = query.first()  # type: ignore
    if not found:
        resp.status_code = 404
        return T.ApiResponse(message="ChatBot not found")
    return T.ApiChain(**found.to_dict())
def update_chain(
    req: Request,
    resp: Response,
    token: Annotated[str, Header()],
    id: str,
    chatbot_data: T.ApiChain,
    tag_id: str = "",
    db: Session = Depends(DB.fastapi_db_session),
) -> Union[T.ApiChain, T.ApiResponse]:
    """Partially update one of the calling user's chains.

    Only the fields listed in ``chatbot_data.update_keys`` are written. The allowed
    keys are ``name``, ``description`` and ``dag``.

    Args:
        req: The incoming request (not used).
        resp: Response object; its status code is set on validation / lookup failure.
        token: JWT taken from the ``token`` header; resolved via ``DB.get_user_from_jwt``.
        id: Id of the chain to update.
        chatbot_data: New values plus the ``update_keys`` selecting which ones to apply.
        tag_id: When non-empty, the chain must also carry this tag id.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        The updated chain as ``T.ApiChain``. Otherwise a ``T.ApiResponse`` with status
        400 (no keys, invalid keys, empty name, missing dag) or 404 (chain not found).
    """
    user = DB.get_user_from_jwt(token=token, db=db)

    # `not ...` also covers update_keys being None (if the schema allows it), where
    # `len(...)` would raise TypeError.
    if not chatbot_data.update_keys:
        resp.status_code = 400
        return T.ApiResponse(message="No keys to update")
    unq_keys = set(chatbot_data.update_keys)
    valid_keys = {"name", "description", "dag"}
    if not unq_keys.issubset(valid_keys):
        resp.status_code = 400
        return T.ApiResponse(message=f"Invalid keys {unq_keys.difference(valid_keys)}")
    # Same rule as create_chain: a chain must keep a non-empty name.
    if "name" in unq_keys and not chatbot_data.name:
        resp.status_code = 400
        return T.ApiResponse(message="Name not specified")
    # Without this guard `chatbot_data.dag.dict()` below raises AttributeError (HTTP 500).
    if "dag" in unq_keys and chatbot_data.dag is None:
        resp.status_code = 400
        return T.ApiResponse(message="DAG not specified")

    filters = [
        DB.ChatBot.id == id,
        DB.ChatBot.created_by == user.id,
        DB.ChatBot.deleted_at == None,  # noqa: E711 -- SQL IS NULL
    ]
    if tag_id:
        filters.append(DB.ChatBot.tag_id == tag_id)
    chatbot: DB.ChatBot = db.query(DB.ChatBot).filter(*filters).first()  # type: ignore
    if not chatbot:
        resp.status_code = 404
        return T.ApiResponse(message="ChatBot not found")

    if "name" in unq_keys:
        chatbot.name = chatbot_data.name  # type: ignore
    if "description" in unq_keys:
        chatbot.description = chatbot_data.description  # type: ignore
    if "dag" in unq_keys:
        chatbot.dag = chatbot_data.dag.dict()  # type: ignore
    db.commit()
    db.refresh(chatbot)
    return T.ApiChain(**chatbot.to_dict())
def delete_chain(
    req: Request,
    resp: Response,
    token: Annotated[str, Header()],
    id: str,
    tag_id: str = "",
    db: Session = Depends(DB.fastapi_db_session),
) -> T.ApiResponse:
    """Soft-delete one of the calling user's chains by stamping ``deleted_at``.

    Args:
        req: The incoming request (not used).
        resp: Response object; its status code is set to 404 when nothing matches.
        token: JWT taken from the ``token`` header; resolved via ``DB.get_user_from_jwt``.
        id: Id of the chain to delete.
        tag_id: When non-empty, the chain must also carry this tag id.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        A ``T.ApiResponse`` confirming the deletion, or "ChatBot not found" with
        status 404.
    """
    owner = DB.get_user_from_jwt(token=token, db=db)

    query = db.query(DB.ChatBot).filter(
        DB.ChatBot.id == id,
        DB.ChatBot.created_by == owner.id,
        DB.ChatBot.deleted_at == None,  # noqa: E711 -- SQLAlchemy renders this as IS NULL
    )
    if tag_id:
        query = query.filter(DB.ChatBot.tag_id == tag_id)

    doomed: DB.ChatBot = query.first()  # type: ignore
    if not doomed:
        resp.status_code = 404
        return T.ApiResponse(message="ChatBot not found")

    # The row is kept; every lookup in this module filters on deleted_at IS NULL.
    doomed.deleted_at = datetime.now()
    db.commit()
    return T.ApiResponse(message=f"ChatBot: '{doomed.name}' ({doomed.id}) deleted")
def list_chains(
    req: Request,
    resp: Response,
    token: Annotated[str, Header()],
    skip: int = 0,
    limit: int = 10,
    tag_id: str = "",
    db: Session = Depends(DB.fastapi_db_session),
) -> T.ApiListChainsResponse:
    """List the calling user's non-deleted chains, one page at a time.

    Args:
        req: The incoming request (not used).
        resp: Response object (not modified).
        token: JWT taken from the ``token`` header; resolved via ``DB.get_user_from_jwt``.
        skip: Number of chains to skip (SQL OFFSET); must not be negative.
        limit: Maximum number of chains to return (SQL LIMIT); must not be negative.
        tag_id: When non-empty, only chains carrying this tag id are listed.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        ``T.ApiListChainsResponse`` holding the requested page, oldest chain first.

    Raises:
        HTTPException: 400 when ``skip`` or ``limit`` is negative.
    """
    user = DB.get_user_from_jwt(token=token, db=db)

    # A negative OFFSET/LIMIT is an error on some databases (e.g. PostgreSQL) and means
    # "no upper bound" on SQLite, which would return every row; reject it explicitly.
    if skip < 0:
        raise HTTPException(status_code=400, detail="skip must not be negative")
    if limit < 0:
        raise HTTPException(status_code=400, detail="limit must not be negative")

    filters = [
        DB.ChatBot.created_by == user.id,
        DB.ChatBot.deleted_at == None,  # noqa: E711 -- SQL IS NULL
    ]
    if tag_id:
        filters.append(DB.ChatBot.tag_id == tag_id)

    # OFFSET/LIMIT paging needs a deterministic order, otherwise rows can repeat or be
    # skipped between pages. created_at (set by create_chain) with id as a tie-breaker.
    chatbots: List[DB.ChatBot] = (
        db.query(DB.ChatBot)
        .filter(*filters)
        .order_by(DB.ChatBot.created_at, DB.ChatBot.id)
        .offset(skip)
        .limit(limit)
        .all()
    )  # type: ignore
    return T.ApiListChainsResponse(
        chatbots=[T.ApiChain(**chatbot.to_dict()) for chatbot in chatbots],
    )
def run_chain(
    req: Request,
    resp: Response,
    id: str,
    token: Annotated[str, Header()],
    prompt: T.ApiPromptBody,
    stream: bool = False,
    as_task: bool = False,
    store_ir: bool = False,
    store_io: bool = False,
    db: Session = Depends(DB.fastapi_db_session),
) -> Union[StreamingResponse, T.CFPromptResult, T.ApiResponse]:
    """
    This is the master function to run any chain over the API. This can behave in a bunch of different formats like:

    - (default) this will wait for the entire chain to execute and return the response
    - if ``stream`` is passed it will give a streaming response with line by line JSON and last response containing ``"done":true``
    - if ``as_task`` is passed then a task ID is received and you can poll for the results at ``/chains/{id}/results`` this supersedes the ``stream``.

    Returns a 404 ``T.ApiResponse`` when the chain does not exist, is deleted, or is not
    owned by the caller.

    Raises:
        HTTPException: 400 for an invalid prompt body or an engine missing from
            ``engine_registry``.
    """
    user = DB.get_user_from_jwt(token=token, db=db)

    # Validate the prompt body.
    if not prompt.session_id:
        raise HTTPException(status_code=400, detail="Session ID not specified")
    if prompt.chat_history:
        raise HTTPException(status_code=400, detail="chat history is not supported yet")
    # Exactly one of `new_message` / `data` must be provided.
    if prompt.new_message and prompt.data:
        raise HTTPException(status_code=400, detail="new_message and data cannot be passed together")
    if not prompt.new_message and not prompt.data:
        raise HTTPException(status_code=400, detail="new_message or data must be passed")

    filters = [
        DB.ChatBot.id == id,
        DB.ChatBot.created_by == user.id,
        DB.ChatBot.deleted_at == None,  # noqa: E711 -- SQL IS NULL
    ]
    chatbot = db.query(DB.ChatBot).filter(*filters).first()  # type: ignore
    if not chatbot:
        resp.status_code = 404
        return T.ApiResponse(message="ChatBot not found")

    engine = engine_registry.get(chatbot.engine)
    if engine is None:
        raise HTTPException(status_code=400, detail=f"Invalid engine {chatbot.engine}")

    # submit / stream / run all take the same arguments.
    engine_kwargs = dict(
        chatbot=chatbot,
        prompt=prompt,
        db=db,
        start=time.time(),
        store_ir=store_ir,
        store_io=store_io,
    )

    if as_task:
        # Returns a task handle / ID instead of waiting for the chain to finish.
        return engine.submit(**engine_kwargs)

    if stream:

        def _to_ndjson(chunks):
            """Serialize ``(ir, done)`` pairs from ``engine.stream`` as JSON lines."""
            for ir, done in chunks:
                if isinstance(ir, str):
                    # Wrap plain-text chunks so every emitted line is a JSON object.
                    ir = {"main_out": ir}
                if done:
                    # Strip the "result" entry from the final line. Build a new dict
                    # instead of `ir.pop("result")`: that raised KeyError when the key
                    # was absent and mutated the engine's own object.
                    ir = {key: value for key, value in ir.items() if key != "result"}
                yield json.dumps({**ir, "done": done}) + "\n"

        return StreamingResponse(content=_to_ndjson(engine.stream(**engine_kwargs)))

    return engine.run(**engine_kwargs)
def get_chain_metrics(
    req: Request,
    resp: Response,
    id: str,
    token: Annotated[str, Header()],
    db: Session = Depends(DB.fastapi_db_session),
):
    """Return usage metrics for one of the calling user's chains.

    Args:
        req: The incoming request (not used).
        resp: Response object; its status code is set to 404 when the chain is not found.
        id: Id of the chain whose ``DB.Prompt`` rows are aggregated.
        token: JWT taken from the ``token`` header; resolved via ``DB.get_user_from_jwt``.
        db: SQLAlchemy session injected by FastAPI.

    Returns:
        ``{"metrics": {"total_conversations": <prompt count>}, "latencies": [...]}``.
        Each latency entry is ``{"created_at": <hour bucket> + "00:00", "time": <avg
        time_taken>}`` for prompts from the last 24 hours, ordered by hour.
        Returns a 404 ``T.ApiResponse`` when the chain does not exist, is deleted, or
        belongs to another user.
    """
    user = DB.get_user_from_jwt(token=token, db=db)

    # Authorization: `user` used to be resolved but never checked, so any authenticated
    # user could read metrics for any chain id. Apply the same ownership filters as the
    # other chain endpoints.
    chatbot = (
        db.query(DB.ChatBot)
        .filter(
            DB.ChatBot.id == id,
            DB.ChatBot.created_by == user.id,
            DB.ChatBot.deleted_at == None,  # noqa: E711 -- SQL IS NULL
        )
        .first()
    )  # type: ignore
    if not chatbot:
        resp.status_code = 404
        return T.ApiResponse(message="ChatBot not found")

    total_conversations = (
        db.query(func.count()).filter(DB.Prompt.chatbot_id == id).scalar()  # type: ignore
    )

    # Hour bucket = first 14 characters of the timestamp's string form. Assuming the DB
    # renders timestamps as "YYYY-MM-DD HH:MM:SS", that is "YYYY-MM-DD HH:".
    # One expression object is reused for SELECT, GROUP BY and ORDER BY so they match.
    hour = func.substr(DB.Prompt.created_at, 1, 14)
    since = datetime.now() - timedelta(hours=24)
    hourly_average_latency = (
        db.query(
            hour.label("hour"),
            func.avg(DB.Prompt.time_taken).label("avg_time_taken"),
        )
        .filter(DB.Prompt.chatbot_id == id)  # type: ignore
        .filter(DB.Prompt.created_at >= since)
        .group_by(hour)
        # Without ORDER BY the row order is unspecified; keep the series chronological.
        .order_by(hour)
        .all()
    )

    latency_per_hour = [
        {"created_at": bucket + "00:00", "time": avg_time}
        for bucket, avg_time in hourly_average_latency
    ]
    return {
        "metrics": {"total_conversations": total_conversations},
        "latencies": latency_per_hour,
    }