# Source code for chainfury_server.engines.fury

import time
import json
import traceback
from uuid import uuid4
from fastapi import HTTPException
from sqlalchemy.orm import Session
from typing import Tuple, Dict, Any, Generator, Union

import chainfury.types as T
from chainfury import Chain
from chainfury.utils import SimplerTimes

import chainfury_server.database as DB
from chainfury_server.utils import logger
from chainfury_server.engines.registry import EngineInterface, engine_registry


class FuryEngine(EngineInterface):
    """Engine that executes a chatbot's DAG as a ``chainfury.Chain``.

    Three entry points share the same setup (validate, create the ``Prompt`` row,
    build the chain):

    * ``run``    - execute the chain and return its main output.
    * ``stream`` - execute the chain and yield intermediate outputs as they arrive.
    * ``submit`` - record the prompt and return a fresh task id; this method does
      not itself run the chain.

    ``store_io`` controls whether the user's message and the chain's response are
    written to the ``Prompt`` row. ``store_ir`` is accepted for interface
    compatibility but is not read by any method of this engine.
    """

    @property
    def engine_name(self) -> str:
        """Name of this engine: ``"fury"``."""
        return "fury"

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _check_prompt(prompt: T.ApiPromptBody) -> None:
        """Reject a request that sets both ``new_message`` and ``data`` (HTTP 400).

        Must be called outside the ``try`` blocks below, otherwise the 400 would be
        caught and rewrapped as a 500.
        """
        if prompt.new_message and prompt.data:
            raise HTTPException(status_code=400, detail="prompt cannot have both new_message and data")

    @staticmethod
    def _prepare(
        chatbot: DB.ChatBot,
        prompt: T.ApiPromptBody,
        db: Session,
        store_io: bool,
    ) -> Tuple[DB.Prompt, Chain]:
        """Create the ``Prompt`` row and build the chain for ``chatbot``.

        When ``prompt.new_message`` is set it is written into ``prompt.data`` keyed by
        the chain's main input (mutating ``prompt``), so the chain is always invoked
        with ``prompt.data``.
        """
        logger.debug("Adding prompt to database")
        # the user's message is only persisted when I/O storage is enabled
        prompt_row = create_prompt(db, chatbot.id, prompt.new_message if store_io else "", prompt.session_id)  # type: ignore

        # Create a Fury chain from the chatbot's stored DAG
        dag = T.Dag(**chatbot.dag)  # type: ignore
        chain = Chain.from_dag(dag, check_server=False)
        if prompt.new_message:
            prompt.data = {chain.main_in: prompt.new_message}
        return prompt_row, chain

    @staticmethod
    def _to_text(output: Any) -> str:
        """Return ``output`` unchanged if it is a string, otherwise its JSON encoding.

        Shared by ``run`` and ``stream`` so both report results in the same format.
        """
        return output if isinstance(output, str) else json.dumps(output)

    @staticmethod
    def _save_result(
        db: Session,
        prompt_row: DB.Prompt,
        result: T.CFPromptResult,
        start: float,
        store_io: bool,
    ) -> None:
        """Record the response (only when ``store_io``) and elapsed time, then commit."""
        if store_io:
            prompt_row.response = result.result  # type: ignore
        prompt_row.time_taken = float(time.time() - start)  # type: ignore
        db.commit()

    @staticmethod
    def _internal_error(e: Exception) -> HTTPException:
        """Log the active exception and build the HTTP 500 to raise in its place.

        Must be called from inside an ``except`` block so the traceback is available.
        """
        traceback.print_exc()
        logger.exception(e)
        return HTTPException(status_code=500, detail=str(e))

    # ------------------------------------------------------------------ engine API

    def run(
        self,
        chatbot: DB.ChatBot,
        prompt: T.ApiPromptBody,
        db: Session,
        start: float,
        store_ir: bool,
        store_io: bool,
    ) -> T.CFPromptResult:
        """Run the chatbot's chain to completion and return its main output.

        Args:
            chatbot: chatbot whose ``dag`` is executed.
            prompt: request body; exactly one of ``new_message`` / ``data`` may be set.
            db: database session used for the ``Prompt`` row.
            start: ``time.time()`` value at request start, used for ``time_taken``.
            store_ir: unused.
            store_io: whether to persist the input message and the response.

        Returns:
            ``CFPromptResult`` whose ``result`` is the chain's main output (JSON-encoded
            when not already a string) and whose ``prompt_id`` is the new row's id.

        Raises:
            HTTPException: 400 if both ``new_message`` and ``data`` are set; 500 for any
                error while preparing or running the chain.
        """
        self._check_prompt(prompt)
        try:
            prompt_row, chain = self._prepare(chatbot, prompt, db, store_io)
            callback = FuryThoughts(db, prompt_row.id)

            # call the chain; the intermediate representation is not used here
            mainline_out, _full_ir = chain(
                data=prompt.data,
                thoughts_callback=callback,
                print_thoughts=False,
            )
            result = T.CFPromptResult(
                result=self._to_text(mainline_out),
                prompt_id=prompt_row.id,  # type: ignore
            )
            self._save_result(db, prompt_row, result, start, store_io)
            logger.debug("Processed graph")
            return result
        except Exception as e:
            raise self._internal_error(e) from e

    def stream(
        self,
        chatbot: DB.ChatBot,
        prompt: T.ApiPromptBody,
        db: Session,
        start: float,
        store_ir: bool,
        store_io: bool,
    ) -> Generator[Tuple[Union[T.CFPromptResult, Dict[str, Any]], bool], None, None]:
        """Run the chatbot's chain, yielding ``(item, is_final)`` pairs.

        Every item produced by ``chain.stream`` is yielded as ``(ir, False)``, including
        the one flagged as done. The last pair is ``(CFPromptResult, True)``, whose
        ``result`` is the done item (JSON-encoded when not a string), or ``""`` if
        no item was flagged done.

        Arguments and exceptions are as for ``run``. Because this is a generator,
        validation runs on first iteration.
        NOTE(review): an error after the first yield surfaces as an HTTPException
        mid-stream, after the response has started; confirm the caller handles that.
        """
        self._check_prompt(prompt)
        try:
            prompt_row, chain = self._prepare(chatbot, prompt, db, store_io)
            callback = FuryThoughts(db, prompt_row.id)

            # call the chain
            iterator = chain.stream(
                data=prompt.data,
                thoughts_callback=callback,
                print_thoughts=False,
            )
            mainline_out: Any = ""
            for ir, done in iterator:
                if done:
                    mainline_out = ir
                # forwarded as a non-final chunk either way; the terminal pair is the
                # CFPromptResult yielded below
                yield ir, False

            result = T.CFPromptResult(
                result=self._to_text(mainline_out),
                prompt_id=prompt_row.id,  # type: ignore
            )
            self._save_result(db, prompt_row, result, start, store_io)
            logger.debug("Processed graph")
            yield result, True
        except Exception as e:
            raise self._internal_error(e) from e

    def submit(
        self,
        chatbot: DB.ChatBot,
        prompt: T.ApiPromptBody,
        db: Session,
        start: float,
        store_ir: bool,
        store_io: bool,
    ) -> T.CFPromptResult:
        """Record the prompt, build the chain and return a new task id.

        The returned ``CFPromptResult`` carries ``task_id`` and the message
        ``"Task '<task_id>' scheduled"``. Arguments and exceptions are as for ``run``.
        NOTE(review): nothing in this method dispatches the chain for ``task_id``;
        building it presumably only validates the DAG. Confirm scheduling happens
        elsewhere.
        """
        self._check_prompt(prompt)
        try:
            prompt_row, _chain = self._prepare(chatbot, prompt, db, store_io)

            task_id: str = str(uuid4())
            result = T.CFPromptResult(
                result=f"Task '{task_id}' scheduled",
                prompt_id=prompt_row.id,
                task_id=task_id,
            )
            self._save_result(db, prompt_row, result, start, store_io)
            return result
        except Exception as e:
            raise self._internal_error(e) from e
# Register one FuryEngine instance with ``engine_registry`` when this module is imported.
_FURY_ENGINE = FuryEngine()
engine_registry.register(_FURY_ENGINE)

# helpers
class FuryThoughts:
    """Callable passed to a chain as ``thoughts_callback``.

    Each call normalises ``thought["value"]`` to a string and increments ``count``.
    Persisting that string as an intermediate step is currently disabled (the
    ``create_intermediate_steps`` call is commented out), so ``count`` is the only
    state that changes.
    """

    def __init__(self, db, prompt_id):
        # kept for the (disabled) persistence of intermediate steps
        self.db = db
        self.prompt_id = prompt_id
        # number of thoughts received so far
        self.count = 0

    def __call__(self, thought):
        """Handle one thought dict; ``None``/missing values become ``""``."""
        raw = thought.get("value", "")
        if raw is None:
            response = ""
        elif type(raw) is str:
            response = raw
        else:
            response = str(raw)
        # create_intermediate_steps(self.db, prompt_id=self.prompt_id, intermediate_response=response)
        self.count = self.count + 1
# def create_intermediate_steps(
#     db: Session,
#     prompt_id: int,
#     intermediate_prompt: str = "",
#     intermediate_response: str = "",
#     response_json: Dict = {},
# ) -> DB.IntermediateStep:
#     db_prompt = DB.IntermediateStep(
#         prompt_id=prompt_id,
#         intermediate_prompt=intermediate_prompt,
#         intermediate_response=intermediate_response,
#         response_json=response_json,
#         created_at=SimplerTimes.get_now_datetime(),
#     )  # type: ignore
#     db.add(db_prompt)
#     db.commit()
#     db.refresh(db_prompt)
#     return db_prompt
def create_prompt(db: Session, chatbot_id: str, input_prompt: str, session_id: str) -> DB.Prompt:
    """Insert a new ``DB.Prompt`` row for ``chatbot_id`` and return it.

    The row is stamped with ``SimplerTimes.get_now_datetime()``, committed, then
    refreshed from the database so column values are reloaded before returning
    (callers read ``.id`` from the returned row).
    """
    row = DB.Prompt(  # type: ignore
        created_at=SimplerTimes.get_now_datetime(),
        chatbot_id=chatbot_id,
        session_id=session_id,
        input_prompt=input_prompt,
    )
    db.add(row)
    db.commit()
    db.refresh(row)
    return row