"""chainfury_server.database: engine setup, ORM models, and session/JWT helpers."""

import os
import jwt
import json
import random, string
from datetime import datetime
from enum import Enum as EnumType
from fastapi import HTTPException
from passlib.hash import sha256_crypt
from dataclasses import dataclass, asdict
from typing import Dict, Any

from sqlalchemy.pool import QueuePool
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, scoped_session, sessionmaker
from sqlalchemy import Column, ForeignKey, Integer, String, JSON, Text, Float, DateTime, Enum, create_engine

from chainfury_server.utils import logger, Env, folder, joinp

########
#
# Init things
#
########

# Declarative base shared by every ORM model in this module.
Base = declarative_base()

# Default length of generated primary keys (see unique_string / unique_number).
ID_LENGTH = 8

db = Env.CFS_DATABASE("")
if not db:
    # No database URL configured: fall back to a SQLite file in ~/cf.
    cf_folder = os.path.expanduser("~/cf")
    os.makedirs(cf_folder, exist_ok=True)
    db = "sqlite:///" + cf_folder + "/cfs.db"
    logger.warning(f"No database passed will connect to local SQLite: {db}")
    engine = create_engine(
        db,
        # sqlite3 refuses cross-thread use of a connection by default; this is
        # presumably disabled because the web server may use a connection from
        # a different thread than the one that created it.
        connect_args={
            "check_same_thread": False,
        },
    )
else:
    logger.info("Using via database URL")
    engine = create_engine(
        db,
        poolclass=QueuePool,
        pool_size=10,
        # NOTE(review): connections are recycled after only 30 seconds; confirm
        # this is intentional (e.g. an aggressive server-side idle timeout).
        pool_recycle=30,
        # Test each pooled connection before use so stale ones are replaced.
        pool_pre_ping=True,
    )


########
#
# Helper Functions
#
########


def get_random_alphanumeric_string(length) -> str:
    """Return ``length`` characters chosen from ASCII letters and digits.

    Uses the non-cryptographic :mod:`random` module, one ``random.choice``
    call per character.
    """
    alphabet = string.ascii_letters + string.digits
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)
def get_random_number(length) -> int:
    """Return a random integer with exactly ``length`` decimal digits.

    The result lies in ``[10**(length-1), 10**length - 1]`` (inclusive).
    """
    low, high = 10 ** (length - 1), 10**length - 1
    return random.randint(low, high)
def get_local_session() -> sessionmaker:
    """Build and return a session factory bound to the module-level engine.

    Sessions created by the factory have autocommit and autoflush disabled.
    Despite the debug message, creating a ``sessionmaker`` does not itself
    open a database connection.
    """
    logger.debug("Database opened successfully")
    return sessionmaker(bind=engine, autoflush=False, autocommit=False)
def db_session() -> Session:  # type: ignore
    """Return a new standalone :class:`Session` bound to the module engine.

    The caller owns the returned session and should ``close()`` it when done,
    otherwise its pooled connection is only released when the object is
    garbage-collected.

    A plain session is returned on purpose: wrapping the factory in a
    ``scoped_session`` and calling ``remove()`` before returning (as a
    thread-local registry would require) closes the session immediately,
    so no thread-local scoping would actually apply.
    """
    session_factory = sessionmaker(bind=engine)
    return session_factory()
def fastapi_db_session():
    """Yield a fresh session and always close it afterwards.

    Written as a generator so it can presumably be used as a FastAPI
    dependency (``Depends(fastapi_db_session)``); the ``finally`` clause
    runs when the generator is closed.
    """
    session = get_local_session()()
    try:
        yield session
    finally:
        session.close()
def unique_string(table, row_reference, length=ID_LENGTH):
    """
    Gets Random Unique String for Primary key and makes sure its unique for the table.

    Args:
        table: ORM model class to query.
        row_reference: column attribute compared against the candidate
            (e.g. ``User.id``).
        length: number of characters in the generated string.

    Returns:
        A lowercase alphanumeric string of ``length`` characters that was not
        present in ``table`` at the time of the check.

    NOTE(review): the check and the later insert are not atomic, so concurrent
    writers could still collide; the primary-key constraint is the real guard.
    """
    db = db_session()
    try:
        random_string = get_random_alphanumeric_string(length).lower()
        while db.query(table).filter(row_reference == random_string).limit(1).first() is not None:  # type: ignore
            random_string = get_random_alphanumeric_string(length).lower()
        return random_string
    finally:
        # This lookup session is private to the helper; close it so its
        # connection returns to the pool instead of leaking on every call.
        db.close()
def unique_number(Table, row_reference, length=ID_LENGTH):
    """
    Gets Random Unique Number for Primary key and makes sure its unique for the table.

    Args:
        Table: ORM model class to query.
        row_reference: column attribute compared against the candidate
            (e.g. ``Prompt.id``).
        length: number of decimal digits in the generated number.

    Returns:
        An integer with ``length`` digits that was not present in ``Table``
        at the time of the check.

    NOTE(review): the check and the later insert are not atomic, so concurrent
    writers could still collide; the primary-key constraint is the real guard.
    """
    db = db_session()
    try:
        random_number = get_random_number(length)
        while db.query(Table).filter(row_reference == random_number).limit(1).first() is not None:  # type: ignore
            random_number = get_random_number(length)
        return random_number
    finally:
        # Release the helper's private session/connection back to the pool.
        db.close()
def add_default_user():
    """Insert the default ``admin`` user if it is not already present.

    An ``IntegrityError`` (e.g. the unique username/email already exists) is
    treated as "already seeded": the transaction is rolled back and the event
    is logged at info level.
    """
    # NOTE(review): hard-coded admin/admin credentials; make sure deployments
    # change this password or disable the seeding.
    admin_password = sha256_crypt.hash("admin")
    db = db_session()
    try:
        db.add(User(username="admin", password=admin_password, email="admin@admin.com", meta=""))  # type: ignore
        db.commit()
    except IntegrityError:
        # Roll back the failed flush so the session is left in a clean state.
        db.rollback()
        logger.info("Not adding default user")
    finally:
        db.close()
def _load_example_json(ex_folder, filename):
    """Read and parse one JSON file from the bundled examples folder."""
    with open(joinp(ex_folder, filename), encoding="utf-8") as f:
        return json.load(f)


def add_default_templates():
    """Create or refresh the bundled example templates.

    Reads ``examples/index.json`` next to this module. For each entry, the
    referenced dag file is loaded and the ``Template`` row with the same id is
    updated (name, description, dag) or created. All changes are committed
    together. On ``IntegrityError`` the transaction is rolled back and the
    event is logged at info level.
    """
    db = db_session()
    try:
        ex_folder = joinp(folder(__file__), "examples")
        data = _load_example_json(ex_folder, "index.json")
        for template_data in data:
            template = db.query(Template).filter_by(id=template_data["id"]).first()
            dag = _load_example_json(ex_folder, template_data["dag"])
            if template:
                template.name = template_data["name"]
                template.description = template_data["description"]
                template.dag = dag
            else:
                template = Template(
                    id=template_data["id"],
                    name=template_data["name"],
                    description=template_data["description"],
                    dag=dag,
                )  # type: ignore
                db.add(template)
        db.commit()
    except IntegrityError:
        # Leave the session usable rather than stuck in a failed transaction.
        db.rollback()
        logger.info("Not adding default templates")
    finally:
        db.close()
########
#
# Tables
#
########
class User(Base):
    """ORM model for the ``user`` table."""

    __tablename__ = "user"
    # 8-character primary key; generated via unique_string() when not supplied.
    id: str = Column(String(8), default=lambda: unique_string(User, User.id), primary_key=True)
    email: str = Column(String(80), unique=True, nullable=False)
    username: str = Column(String(80), unique=True, nullable=False)
    # add_default_user stores a sha256_crypt hash here, not the plain password.
    password: str = Column(String(80), nullable=False)
    # Free-form JSON metadata.
    meta: Dict[str, Any] = Column(JSON)

    def __repr__(self):
        return f"User(id={self.id}, username={self.username}, meta={self.meta})"
class ChatBotTypes:
    """String constants naming the supported chatbot engines."""

    LANGFLOW = "langflow"
    FURY = "fury"

    @staticmethod
    def all():  # type: ignore
        """Return every engine constant's value, in ``dir()`` (alphabetical) order.

        Methods such as ``all`` itself are skipped; previously only dunder
        names were filtered, so the ``all`` function leaked into the result.
        """
        values = (getattr(ChatBotTypes, attr) for attr in dir(ChatBotTypes) if not attr.startswith("__"))
        return [value for value in values if not callable(value)]
class ChatBot(Base):
    """ORM model for the ``chatbot`` table."""

    __tablename__ = "chatbot"
    # 8-character primary key generated via unique_string() when not supplied.
    id: str = Column(
        String(8),
        default=lambda: unique_string(ChatBot, ChatBot.id),
        primary_key=True,
    )
    name: str = Column(String(80), unique=False)
    description: str = Column(Text, nullable=True)
    # Id of the owning User row.
    created_by: str = Column(String(8), ForeignKey("user.id"), nullable=False)
    # Chatbot graph definition stored as JSON.
    dag: Dict[str, Any] = Column(JSON)
    meta: Dict[str, Any] = Column(JSON)
    # presumably one of ChatBotTypes.all() ("fury"/"langflow") — TODO confirm.
    engine: str = Column(String(80), nullable=False)
    tag_id: str = Column(String(80), nullable=True)
    created_at: datetime = Column(DateTime, nullable=False)
    # Nullable timestamp; presumably set on soft delete — verify against callers.
    deleted_at: datetime = Column(DateTime, nullable=True)

    def to_dict(self):
        """Return the row's columns as a dict (datetimes are left as ``datetime``)."""
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "created_by": self.created_by,
            "dag": self.dag,
            "meta": self.meta,
            "engine": self.engine,
            "tag_id": self.tag_id,
            "created_at": self.created_at,
            "deleted_at": self.deleted_at,
        }

    def __repr__(self):
        return f"ChatBot(id={self.id}, name={self.name}, created_by={self.created_by}, dag={self.dag}, meta={self.meta})"
class PromptRating(EnumType):
    """User rating of how a chat conversation went.

    Used as the type of ``Prompt.user_rating`` via ``Enum(PromptRating)``,
    so renaming or removing members affects stored data.
    """

    UNRATED = 0
    SAD = 1
    NEUTRAL = 2
    HAPPY = 3
class Prompt(Base):
    """ORM model for the ``prompt`` table: one input/response exchange with a chatbot."""

    __tablename__ = "prompt"
    # Integer primary key (ID_LENGTH digits) generated via unique_number().
    id: int = Column(
        Integer,
        default=lambda: unique_number(Prompt, Prompt.id),
        primary_key=True,
    )
    chatbot_id: str = Column(String(8), ForeignKey("chatbot.id"), nullable=False)
    input_prompt: str = Column(Text, nullable=False)
    response: str = Column(Text, nullable=True)
    gpt_rating: str = Column(String(5), nullable=True)
    # Stored through SQLAlchemy's Enum type, so values are PromptRating members.
    user_rating: PromptRating = Column(Enum(PromptRating), nullable=True)
    # presumably seconds — TODO confirm units with the writer of this column.
    time_taken: float = Column(Float, nullable=True)
    num_tokens: int = Column(Integer, nullable=True)
    created_at: datetime = Column(DateTime, nullable=False)
    # String(80) column, so this is a string, not a dict.
    session_id: str = Column(String(80), nullable=False)
    meta: Dict[str, Any] = Column(JSON)

    def to_dict(self):
        """Return the row's columns as a dict.

        ``created_at`` stays a ``datetime`` and ``user_rating`` a ``PromptRating``
        (or ``None``); callers must serialise them for JSON.
        """
        return {
            "id": self.id,
            "chatbot_id": self.chatbot_id,
            "input_prompt": self.input_prompt,
            "response": self.response,
            "gpt_rating": self.gpt_rating,
            "user_rating": self.user_rating,
            "time_taken": self.time_taken,
            "num_tokens": self.num_tokens,
            "created_at": self.created_at,
            "session_id": self.session_id,
            "meta": self.meta,
        }
class ChainLog(Base):
    """ORM model for the ``chain_logs`` table: log messages tied to a Prompt."""

    __tablename__ = "chain_logs"
    # 16-character primary key generated via unique_string(..., 16).
    id: str = Column(
        String(16),
        default=lambda: unique_string(ChainLog, ChainLog.id, 16),
        primary_key=True,
    )
    created_at: datetime = Column(DateTime, nullable=False)
    prompt_id: int = Column(Integer, ForeignKey("prompt.id"), nullable=False)
    # Column widths come from environment settings.
    # NOTE(review): CFS_ vs CF_ prefix differs between these two settings — confirm intended.
    node_id: str = Column(String(Env.CFS_MAX_NODE_ID_LEN()), nullable=False)
    worker_id: str = Column(String(Env.CF_MAX_WORKER_ID_LEN()), nullable=False)
    message: str = Column(Text, nullable=False)
    data: Dict[str, Any] = Column(JSON, nullable=True)
class Template(Base):
    """ORM model for the ``template`` table (seeded by add_default_templates)."""

    __tablename__ = "template"
    # Integer primary key (ID_LENGTH digits) generated via unique_number() when not supplied.
    id: int = Column(
        Integer,
        default=lambda: unique_number(Template, Template.id),
        primary_key=True,
    )
    name: str = Column(Text, nullable=False)
    # Graph definition loaded from the examples folder for default templates.
    dag: Dict[str, Any] = Column(JSON, nullable=False)
    description: str = Column(Text)
    meta: Dict[str, Any] = Column(JSON)

    def to_dict(self):
        """Return the row's columns as a plain dict."""
        return {
            "id": self.id,
            "name": self.name,
            "dag": self.dag,
            "description": self.description,
            "meta": self.meta,
        }
# Import-time side effect: create any tables declared above that do not yet
# exist in the configured database (existing tables are not altered).
Base.metadata.create_all(bind=engine)  # type: ignore

########
#
# JWT Helpers
#
########
@dataclass
class JWTPayload:
    """Claims extracted from an authentication JWT."""

    # Name of the authenticated user.
    username: str
    # Id of the authenticated user.
    user_id: str

    def to_dict(self):
        """Return the claims as a ``{"username": ..., "user_id": ...}`` dict."""
        as_mapping = asdict(self)
        return as_mapping
def get_user_from_jwt(token, db: Session) -> User:
    """Decode ``token`` and return the matching :class:`User`.

    Args:
        token: encoded JWT, verified with ``Env.JWT_SECRET()`` using HS256.
        db: session used for the user lookup.

    Returns:
        The ``User`` whose username equals the token's ``username`` claim, or
        ``None`` when no such row exists (result of ``Query.first()``).

    Raises:
        HTTPException: 401 when fetching the secret or decoding/verifying the
            token raises; the original error is chained as ``__cause__``.
    """
    try:
        claims = jwt.decode(token, key=Env.JWT_SECRET(), algorithms=["HS256"])
    except Exception as e:
        # Auth boundary: any failure becomes a 401, but keep the real cause.
        logger.error("Could not decode JWT token")
        raise HTTPException(status_code=401, detail="Could not decode JWT token") from e
    payload = JWTPayload(
        username=claims.get("username", ""),
        user_id=claims.get("user_id", "") or claims.get("userid", ""),  # grandfather 'userid'
    )
    logger.debug(f"Verifying user {payload.username}")
    # Lookup is by username only; payload.user_id is not consulted here.
    return db.query(User).filter(User.username == payload.username).first()  # type: ignore