from __future__ import annotations

import asyncio
import gc
import json
import logging
import os
import time
from typing import List, Dict, Tuple, Optional, Literal

import chromadb
import torch
from chromadb import PersistentClient
from chromadb.api import Collection
from chromadb.api.types import QueryResult
from fastapi import FastAPI, HTTPException
from fastembed import TextEmbedding
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from pydantic import BaseModel, Field
from sentence_transformers import CrossEncoder
from transformers import AutoModelForCausalLM, AutoTokenizer, Pipeline, pipeline

from s3_utils import download_chroma_folder_from_s3

logging.basicConfig(
    level=logging.INFO,
    format='{"time": "%(asctime)s", "level": "%(levelname)s", "message": "%(message)s"}',
)
logger = logging.getLogger(__name__)

# --- Configuration (env-driven where possible) ---
CHROMA_DIR = os.getenv("CHROMA_DIR", "")  # must be set; an empty value fails fast in initialize_chroma_client()
CHROMA_DIR_INF = "/" + CHROMA_DIR
CHROMA_COLLECTION = os.getenv("CHROMA_COLLECTION")
CHROMA_CACHE_COLLECTION = os.getenv("CHROMA_CACHE_COLLECTION", "semantic_cache")
LLM_MODEL_CPU_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
LLM_MODEL_GPU_ID = "meta-llama/Llama-3.1-8B-Instruct"
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
CHROMA_DB_FILENAME = os.getenv("CHROMA_DB_FILENAME")
SUMMARY_TRIGGER_TOKENS = int(os.getenv("SUMMARY_TRIGGER_TOKENS", 1000))
SUMMARY_TARGET_TOKENS = int(os.getenv("SUMMARY_TARGET_TOKENS", 120))
# SEMANTIC_CACHE_DIST_THRESHOLD = float(os.getenv("SEMANTIC_CACHE_SIM_THRESHOLD", 0.1))
RETRIEVE_TOP_K_CPU = int(os.getenv("RETRIEVE_TOP_K_CPU", 3))
RETRIEVE_TOP_K_GPU = int(os.getenv("RETRIEVE_TOP_K_GPU", 8))
MAX_NEW_TOKENS_CPU = int(os.getenv("MAX_NEW_TOKENS_CPU", 256))
MAX_NEW_TOKENS_GPU = int(os.getenv("MAX_NEW_TOKENS_GPU", 1024))
# GPU_MIN_FREE_HOURS_THRESHOLD = float(os.getenv("GPU_MIN_FREE_HOURS_THRESHOLD", 0.5))
GPU_MIN_FREE_HOURS_THRESHOLD = 11
# LLAMA_GGUF_PATH = os.getenv("LLAMA_GGUF_PATH", "/model/tinyllama-reasoning.Q4_K_M.gguf")
LLM_TOKENIZER_ID = "alexredna/TinyLlama-1.1B-Chat-v1.0-reasoning-v2"
TINYLAMA_CONTEXT_WINDOW = 2048
LLAMA_3_CONTEXT_WINDOW = 8192
SAFETY_BUFFER = 50
# MAX_INPUT_TOKENS = TINYLAMA_CONTEXT_WINDOW - MAX_NEW_TOKENS - SAFETY_BUFFER

# Minimal Llama-3 chat template, applied when a tokenizer ships without one.
LLAMA_3_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ '<|start_header_id|>user<|end_header_id|>\n' + message['content'] + '<|eot_id|>' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ '<|start_header_id|>assistant<|end_header_id|>\n' + message['content'] + '<|eot_id|>' }}"
    "{% elif message['role'] == 'system' %}"
    "{{ '<|start_header_id|>system<|end_header_id|>\n' + message['content'] + '<|eot_id|>' }}"
    "{% endif %}"
    "{% if loop.last and message['role'] == 'user' %}"
    "{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}"
    "{% endif %}"
    "{% endfor %}"
)

CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
MODEL_ID = "EJ4U/WHO-rag-model"
FILENAME = "tinyllama-reasoning.Q4_K_M.gguf"

# The cross-encoder is optional: if it fails to load, reranking is skipped.
cross_encoder: Optional[CrossEncoder] = None
try:
    cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL, device=DEVICE)
    logger.info("Cross-encoder model loaded successfully.")
except Exception as e:
    logger.warning("Cross-encoder model error: %s", e)

LLAMA_GGUF_PATH = hf_hub_download(
    repo_id=MODEL_ID,
    filename=FILENAME,
    cache_dir="model"
)

EMBEDDING_MODEL = TextEmbedding(model_name="BAAI/bge-small-en-v1.5")
_ = list(EMBEDDING_MODEL.embed(["warmup"]))
logger.info("FastEmbed model warmup complete.")
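
# Illustrative only: for a [system, user] message list, LLAMA_3_CHAT_TEMPLATE above
# renders a prompt of roughly this shape (content elided):
#   <|start_header_id|>system<|end_header_id|>\n{system}<|eot_id|>
#   <|start_header_id|>user<|end_header_id|>\n{user}<|eot_id|>
#   <|start_header_id|>assistant<|end_header_id|>\n
# The trailing assistant header is what makes the model continue as the assistant.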


def initialize_cpp_llm(gguf_path: str, n_ctx: int = TINYLAMA_CONTEXT_WINDOW, n_threads: int = 4) -> Llama:
    """
    Initialize a quantized GGUF model via llama-cpp (llama_cpp.Llama).
    This replaces the HF AutoModel pipeline for CPU inference.
    """
    logger.info(f"Initializing llama.cpp model from GGUF: {gguf_path}")
    if not os.path.exists(gguf_path):
        logger.error(f"GGUF model not found at {gguf_path}. Make sure the file exists.")
        raise RuntimeError(f"GGUF model not found at {gguf_path}")
    llm = Llama(
        model_path=gguf_path,
        n_ctx=n_ctx,
        n_threads=n_threads,
        n_batch=256,
        use_mmap=True,   # memory-map weights for faster cold-start
        n_gpu_layers=0,  # CPU-only inference
    )
    logger.info("llama.cpp model loaded successfully.")
    return llm


def initialize_llm_pipeline(model_id: str, device: str) -> Pipeline:
    """Initializes a Hugging Face transformers pipeline for GPU."""
    logger.info(f"Initializing HF Pipeline for model: {model_id} on {device}")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if not getattr(tokenizer, "chat_template", None):
        logger.info("Applying Llama-3 chat template to tokenizer.")
        tokenizer.chat_template = LLAMA_3_CHAT_TEMPLATE
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map=device
    )
    logger.info(f"HF Pipeline for {model_id} loaded successfully.")
    return pipe


def initialize_chroma_client() -> chromadb.PersistentClient:
    """Initializes the Chroma client and loads the index from S3/disk."""
    logger.info(f"Initializing Chroma client from persistence directory: {CHROMA_DIR_INF}")
    local_db = os.path.exists(CHROMA_DIR_INF) and os.listdir(CHROMA_DIR_INF)
    db_exist = os.path.join(CHROMA_DIR_INF, "chroma.sqlite3")
    if not local_db or not os.path.exists(db_exist):
        logger.warning(
            f"Local chroma directory {CHROMA_DIR_INF} is missing or empty. "
            f"Attempting download from S3."
        )
        if CHROMA_DIR:
            download_chroma_folder_from_s3(
                s3_prefix=CHROMA_DIR,
                local_dir=CHROMA_DIR_INF
            )
            logger.info("Chroma data downloaded from S3.")
        else:
            logger.error("CHROMA_DIR is not set. Cannot retrieve the Chroma index.")
            raise RuntimeError("Chroma index failed to load.")
    else:
        logger.info(f"Local chroma data found at {CHROMA_DIR_INF}.")
    try:
        client = PersistentClient(path=CHROMA_DIR_INF, settings=chromadb.Settings(allow_reset=False))
        logger.info("Chroma client initialized successfully.")
    except Exception as e:
        logger.error(f"Failed to load Chroma index: {e}")
        raise RuntimeError("Chroma index failed to load.")
    return client


async def load_cpu_pipeline() -> Tuple[Optional[object], str, int, int, int]:
    """Lazily loads the default pipeline (HF GPU > GGUF CPU > HF CPU) and returns it with its settings."""
    if getattr(app.state, 'cpu_pipeline', None) is not None:
        if isinstance(app.state.cpu_pipeline, Llama):
            return app.state.cpu_pipeline, "cpu_gguf", TINYLAMA_CONTEXT_WINDOW, MAX_NEW_TOKENS_CPU, RETRIEVE_TOP_K_CPU
        return app.state.cpu_pipeline, "hf_gpu" if torch.cuda.is_available() else "cpu_hf", TINYLAMA_CONTEXT_WINDOW, MAX_NEW_TOKENS_CPU, RETRIEVE_TOP_K_CPU

    if getattr(app.state, 'tokenizer', None) is None:
        try:
            logger.info(f"Loading tokenizer from {LLM_TOKENIZER_ID}")
            app.state.tokenizer = AutoTokenizer.from_pretrained(LLM_TOKENIZER_ID, use_fast=False)
            if not getattr(app.state.tokenizer, "chat_template", None):
                app.state.tokenizer.chat_template = LLAMA_3_CHAT_TEMPLATE
        except Exception as e:
            logger.error(f"Failed to load tokenizer: {e}", exc_info=True)
            raise HTTPException(status_code=503, detail=f"Failed to load tokenizer: {e}")

    if torch.cuda.is_available():
        try:
            logger.info(f"GPU detected. Attempting to load HF GPU model {LLM_MODEL_GPU_ID}...")
            app.state.cpu_pipeline = await asyncio.to_thread(
                initialize_llm_pipeline, LLM_MODEL_GPU_ID, "cuda"
            )
            logger.info("HF GPU model loaded successfully.")
            return app.state.cpu_pipeline, "hf_gpu", TINYLAMA_CONTEXT_WINDOW, MAX_NEW_TOKENS_CPU, RETRIEVE_TOP_K_CPU
        except Exception as e:
            logger.warning(f"Failed to load HF GPU model: {e}. Falling back to CPU...")

    if LLAMA_GGUF_PATH and os.path.exists(LLAMA_GGUF_PATH):
        try:
            logger.info("Loading TinyLlama GGUF (CPU)...")
            logger.info(f"Model: {LLAMA_GGUF_PATH}")
            app.state.cpu_pipeline = await asyncio.to_thread(
                initialize_cpp_llm, LLAMA_GGUF_PATH, TINYLAMA_CONTEXT_WINDOW, max(1, os.cpu_count() or 1)
            )
            logger.info("TinyLlama GGUF loaded successfully.")
            return app.state.cpu_pipeline, "cpu_gguf", TINYLAMA_CONTEXT_WINDOW, MAX_NEW_TOKENS_CPU, RETRIEVE_TOP_K_CPU
        except Exception as e:
            logger.warning(f"Failed to load GGUF CPU model: {e}")

    try:
        logger.info(f"Loading HF CPU model {LLM_MODEL_CPU_ID}...")
        app.state.cpu_pipeline = await asyncio.to_thread(
            initialize_llm_pipeline, LLM_MODEL_CPU_ID, "cpu"
        )
        logger.info("HF CPU model loaded successfully.")
        return app.state.cpu_pipeline, "cpu_hf", TINYLAMA_CONTEXT_WINDOW, MAX_NEW_TOKENS_CPU, RETRIEVE_TOP_K_CPU
    except Exception as e:
        logger.error(f"FATAL: Failed to load any CPU model: {e}", exc_info=True)
        raise HTTPException(status_code=503, detail=f"Failed to load any model: {e}")


async def get_pipeline_for_runtime() -> Tuple[Optional[object], str, int, int, int]:
    """
    Determines the runtime, lazily loads the correct pipeline (GPU/CPU), and returns
    the pipeline and its associated settings.
    NOTE: Return type is Optional[object] to handle both Pipeline and Llama.
    """
    if await gpu_hours_available():
        logger.info("GPU hours available. Attempting to load GPU pipeline.")
        if getattr(app.state, 'gpu_pipeline', None) is None:
            logger.info("Lazy-loading Llama-3.1-8B (GPU)...")
            try:
                if getattr(app.state, 'cpu_pipeline', None):
                    # NOTE: Clear both the Llama object and the separate tokenizer
                    del app.state.cpu_pipeline
                    app.state.cpu_pipeline = None
                if getattr(app.state, 'llm_cpp', None):
                    del app.state.llm_cpp
                    app.state.llm_cpp = None
                if getattr(app.state, 'tokenizer', None):
                    del app.state.tokenizer
                    app.state.tokenizer = None
                gc.collect()
                logger.info("Cleared CPU pipeline (Llama) and tokenizer from memory.")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                app.state.gpu_pipeline = await asyncio.to_thread(
                    initialize_llm_pipeline, LLM_MODEL_GPU_ID, "cuda"
                )
                logger.info("GPU pipeline loaded successfully.")
            except Exception as e:
                logger.error(f"Failed to load GPU pipeline: {e}. Falling back to CPU.", exc_info=True)
                return await load_cpu_pipeline()
        return app.state.gpu_pipeline, "gpu", LLAMA_3_CONTEXT_WINDOW, MAX_NEW_TOKENS_GPU, RETRIEVE_TOP_K_GPU
    else:
        logger.info("GPU hours exhausted or unavailable. Loading CPU pipeline.")
        return await load_cpu_pipeline()
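

# gpu_hours_available() below is currently a hard-coded stub (remaining_hours = 10
# against a threshold of 11, so the GPU path is never taken). A fuller check is
# sketched here for reference only; it assumes the remaining budget is exposed via a
# GPU_HOURS_REMAINING environment variable, which is NOT part of the current setup:
#
#   async def gpu_hours_available() -> bool:
#       remaining_hours = float(os.getenv("GPU_HOURS_REMAINING", 0))
#       return remaining_hours > GPU_MIN_FREE_HOURS_THRESHOLD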
async def gpu_hours_available() -> bool:
    """Gatekeeper for the GPU runtime. Currently a stub that always reports no budget."""
    force_gpu = False
    if force_gpu:
        return True
    remaining_hours = 10  # placeholder value; no real budget lookup is wired up yet
    return remaining_hours > GPU_MIN_FREE_HOURS_THRESHOLD


app = FastAPI(title="RAG Inference API (Chroma + Llama 3)", version="1.0.0")


@app.on_event("startup")
async def load_models():
    try:
        logger.info("Starting FastAPI model loading...")
        client = await asyncio.to_thread(initialize_chroma_client)
        if not CHROMA_COLLECTION:
            raise RuntimeError("CHROMA_COLLECTION variable not set in env")
        app.state.chroma_collection = client.get_collection(name=CHROMA_COLLECTION)
        app.state.chroma_ready = app.state.chroma_collection is not None
        logger.info(f"Loaded collection: {CHROMA_COLLECTION} (Documents: {app.state.chroma_collection.count()})")
        app.state.cache_collection = client.get_or_create_collection(name=CHROMA_CACHE_COLLECTION)
        logger.info(f"Loaded cache collection: {CHROMA_CACHE_COLLECTION} ({app.state.cache_collection.count()} items)")
        app.state.gpu_pipeline: Optional[Pipeline] = None  # type: ignore
        app.state.cpu_pipeline: Optional[object] = None  # type: ignore
        app.state.llm_cpp: Optional[Llama] = None  # type: ignore
        app.state.tokenizer: Optional[AutoTokenizer] = None  # type: ignore
        if not app.state.chroma_ready:
            raise RuntimeError("ChromaDB critical component failed to load.")
        await load_cpu_pipeline()
        logger.info("FastAPI models loaded successfully (CPU pipeline pre-warmed).")
    except Exception as e:
        app.state.chroma_ready = False
        logger.error(f"Error during startup: {e}", exc_info=True)
        raise


class HistoryMessage(BaseModel):
    role: Literal['user', 'assistant']
    content: str


class QueryRequest(BaseModel):
    query: str = Field(..., description="The user's latest message.")
    history: List[HistoryMessage] = Field(default_factory=list, description="The previous turns of the conversation.")
    stream: bool = Field(False)


class RAGResponse(BaseModel):
    query: str = Field(..., description="The original user query.")
    answer: str = Field(..., description="The final answer generated by the LLM.")
    sources: List[str] = Field(..., description="Unique source URLs used for the answer.")
    context_chunks: List[str] = Field(..., description="The final context chunks (text only) sent to the LLM.")
    expanded_queries: List[str] = Field(..., description="Queries used for retrieval.")


def call_llm_pipeline(pipe_like: Optional[object],
                      prompt_text: str,
                      deterministic: bool = False,
                      max_new_tokens: int = MAX_NEW_TOKENS_CPU,
                      is_expansion: bool = False
                      ) -> str:
    """
    Unified caller for the LLM:
    - Handles llama_cpp.Llama instances (CPU)
    - Handles transformers.Pipeline instances (GPU)
    """
    logger.info(f"Model used: {pipe_like}")
    if pipe_like is None:
        raise HTTPException(status_code=503, detail="LLM pipeline is not available.")

    if deterministic:
        temp = 0.0
    elif is_expansion:
        temp = 0.1
    else:
        temp = 0.6

    tokenizer = getattr(app.state, "tokenizer", None)
    if tokenizer is None and isinstance(pipe_like, Pipeline):
        tokenizer = getattr(pipe_like, "tokenizer", None)
    try:
        if tokenizer:
            input_token_count = len(tokenizer.encode(prompt_text, add_special_tokens=True))
            logger.info(f"LLM Input Token Count: {input_token_count}.")
    except Exception:
        logger.debug("Token counting failed, continuing without token count.")

    try:
        if isinstance(pipe_like, Llama):
            llm = pipe_like
            with torch.inference_mode():
                resp = llm(
                    prompt_text,
                    max_tokens=max_new_tokens,
                    temperature=temp,
                    stop=["<|eot_id|>", "<|start_header_id|>", "<|end_of_text|>"]
                )
            text = resp.get("choices", [{}])[0].get("text", "").strip()
            return text
        elif isinstance(pipe_like, Pipeline):
            pipe = pipe_like
            with torch.inference_mode():
                outputs = pipe(
                    prompt_text,
                    max_new_tokens=max_new_tokens,
                    temperature=temp if temp > 0.0 else None,
                    do_sample=temp > 0.0,
                    pad_token_id=pipe.tokenizer.eos_token_id,
                    return_full_text=False
                )
            text = outputs[0]['generated_text'].strip()
            # Strip any trailing Llama-3 control tokens the model may emit.
            if '<|eot_id|>' in text:
                text = text.split('<|eot_id|>')[0].strip()
            if '<|end_of_text|>' in text:
                text = text.split('<|end_of_text|>')[0].strip()
            return text
        else:
            logger.error(f"Unknown pipeline type: {type(pipe_like)}")
            raise TypeError(f"Unknown pipeline type: {type(pipe_like)}")
    except Exception as e:
        logger.error(f"Error calling LLM pipeline: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"LLM generation failed: {str(e)}")


async def expand_query_with_llm(pipe: Optional[object],
                                user_query: str,
                                summary: str,
                                history: Optional[List[HistoryMessage]]
                                ) -> List[str]:
    """
    Implements the two-mode query strategy: expansion (no history) or rewriting (with history).
    """
    messages = []
    expanded_queries: List[str] = []

    if not history:
        system_prompt = "You are a specialized query expansion engine."
        user_prompt = f"""
        Generate 3 alternative search queries similar to the user query below.
        The goal is to maximize retrieval relevance based on the user query.
        Return only the queries, one per line, without numbers or extra text.
        If the user query is a greeting, do not reply with a greeting; instead ask how you can help.
        If the user query is gibberish, ask the user to rephrase it instead of generating queries.

        User Query: {user_query}
        """
        messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_prompt})
    else:
        messages = [
            {
                "role": "system",
                "content": "You are a helpful assistant who expands user queries into multiple search queries based on conversation history and user intent."
            },
            {
                "role": "user",
                "content": f"""
                Given the conversation summary below and the user query, expand the user query into three queries
                that best reflect the conversation history, intent, and user needs.
                Return only the queries, one per line, without numbers, preamble, or other text.

                Conversation Summary: {summary}

                User Query: {user_query}

                Queries:
                """
            }
        ]

    tokenizer = getattr(app.state, "tokenizer", None)
    if tokenizer is None and isinstance(pipe, Pipeline):
        tokenizer = getattr(pipe, "tokenizer", None)
    if tokenizer:
        prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        logger.warning("No tokenizer found for expansion, using simple join.")
        prompt_text = "\n".join([m["content"] for m in messages])

    logger.info(f"Query Expansion/Rewrite Prompt: {prompt_text}")
    start = time.time()
    llm_output = await asyncio.to_thread(
        call_llm_pipeline, pipe, prompt_text,
        deterministic=True, is_expansion=True, max_new_tokens=150
    )
    end = time.time()
    logger.info(f"Query Expansion/Rewrite Output: {llm_output} (Time: {end-start:.2f}s)")

    if not history:
        expanded_queries = [
            q.strip() for q in llm_output.split('\n')
            if q.strip()
            and "engine" not in q.lower()
            and "task" not in q.lower()
            and "search queries" not in q.lower()
        ]
    else:
        expanded_queries = [llm_output.strip()]

    expanded_queries.append(user_query)
    return list(set(q for q in expanded_queries if q))


def retrieve_context(queries: List[str], collection: Collection) -> Tuple[List[Dict], List[str]]:
    """Retrieves context from ChromaDB based on query embeddings."""
    try:
        embeddings_list = [
            [float(x) for x in emb]
            for emb in EMBEDDING_MODEL.embed(queries, batch_size=8)
        ]
    except Exception as e:
        logger.error(f"Failed to generate embeddings for retrieval: {e}", exc_info=True)
        return [], []

    try:
        n_results_to_fetch = max(10, RETRIEVE_TOP_K_CPU * len(queries))
        start = time.time()
        results = collection.query(
            query_embeddings=embeddings_list,
            n_results=n_results_to_fetch,
            include=['documents', 'metadatas']
        )
        end = time.time()
        logger.info(f'RETRIEVING TOOK: {end-start:.2f}s')
    except Exception as e:
        logger.error(f"Chroma query failed: {e}")
        return [], []

    context_data = []
    source_urls = set()
    seen_texts = set()
    if results.get("documents") and results.get("metadatas"):
        for docs_list, metadatas_list in zip(results["documents"], results["metadatas"]):
            for doc, metadata in zip(docs_list, metadatas_list):
                if doc and metadata and doc not in seen_texts:
                    context_data.append({'text': doc, 'url': metadata.get('source')})
                    if metadata.get("source"):
                        source_urls.add(metadata.get('source'))
                    seen_texts.add(doc)
    return context_data, list(source_urls)


def rerank_documents(query: str, context: List[Dict], top_k: int) -> List[Dict]:
    """
    Re-ranks context documents using a cross-encoder.
    Returns the top-k most relevant documents.
    """
    if not context or not cross_encoder:
        logger.warning("Skipping reranking (no context or cross-encoder not loaded).")
        return context[:top_k]

    top_k = min(top_k, len(context))
    pairs = [(query, doc['text']) for doc in context]
    try:
        start = time.time()
        scores = cross_encoder.predict(pairs)
        end = time.time()
        logger.info(f'RERANKING TOOK {end-start:.2f}s')
    except Exception as e:
        logger.error(f"Cross-encoder prediction failed: {e}. Returning non-reranked results.", exc_info=True)
        return context[:top_k]

    for doc, score in zip(context, scores):
        doc['score'] = score
    ranked_docs = sorted(context, key=lambda x: x['score'], reverse=True)
    return ranked_docs[:top_k]


async def summarize_history(history: List[HistoryMessage], pipe: Optional[object]) -> str:
    """
    Summarizes the conversation history using the active LLM (via call_llm_pipeline).
    """
    if not history:
        return ''

    history_text = "\n".join([f"{h.role}: {h.content}" for h in history[-8:]])
    tokenizer = getattr(app.state, "tokenizer", None)
    if tokenizer is None and isinstance(pipe, Pipeline):
        tokenizer = getattr(pipe, "tokenizer", None)
    history_tokens = len(tokenizer.encode(history_text)) if tokenizer else len(history_text.split())
    if history_tokens < 150:
        return ""

    summarizer_prompt = f"""
    You are an intelligent agent who summarizes conversations.
    Concisely summarize the key topics and entities discussed in the conversation history
    between a user and an assistant. The summary should be a few sentences long.

    CONVERSATION HISTORY:
    {history_text}

    CONCISE SUMMARY:
    """
    logger.info("Generating conversation summary...")
    start = time.time()
    summary = await asyncio.to_thread(
        call_llm_pipeline, pipe, summarizer_prompt,
        deterministic=True, max_new_tokens=150, is_expansion=False
    )
    end = time.time()
    logger.info(f"HISTORY SUMMARIZATION: {summary} (Time: {end-start:.2f}s)")
    return summary


def build_prompt(user_query: str, context: List[Dict], summary: str) -> List[Dict]:
    """
    Builds the final list of messages for the chat template, including the RAG context.
    """
    messages = []
    context_text = "\n---\n".join([c['text'] for c in context]) if context else "No relevant context found."
    rag_system_prompt = f"""
    You are a helpful medical assistant with a friendly, conversational tone.
    Use the retrieved context to answer the user's query accurately.
    If the context is missing, clearly state that the WHO factsheets don't contain the information.
    Do not repeat the user's question or the context in your response.
    Do not answer racist, harmful, discriminatory, or non-health questions.
    Formulate a direct, conversational answer using only the provided context as definitive truth.
    ---
    retrieved context: {context_text}
    ---
    conversation history summary: {summary if summary else "No summary available."}
    ---
    """
    messages.append({"role": "system", "content": rag_system_prompt})
    messages.append({"role": "user", "content": user_query})
    return messages


async def prune_messages_to_fit_context(messages: List[Dict],
                                        final_context: List[Dict],
                                        summary: str,
                                        max_input_tokens: int,
                                        pipe: Optional[object]
                                        ) -> Tuple[List[Dict], List[Dict], int]:
    """
    Ensures the total prompt fits within max_input_tokens.
    Prunes retrieved context chunks and compresses the summary if needed.
    Uses app.state.tokenizer (CPU) or pipe.tokenizer (GPU).
    """
    tokenizer = getattr(app.state, "tokenizer", None)
    if tokenizer is None and isinstance(pipe, Pipeline):
        tokenizer = getattr(pipe, "tokenizer", None)
    if not tokenizer:
        logger.error("Tokenizer not loaded for pruning.")
        return messages, final_context, 0

    def get_token_count(msg_list: List[Dict]) -> int:
        prompt_text = tokenizer.apply_chat_template(msg_list, tokenize=False, add_generation_prompt=True)
        return len(tokenizer.encode(prompt_text, add_special_tokens=False))

    current_context = final_context[:]
    current_messages = messages[:]
    tok_length = 0
    token_count = get_token_count(current_messages)
    base_user_query = messages[-1]["content"]
    logger.info(f"Pre-pruning token count: {token_count}. Max: {max_input_tokens}")

    if token_count <= max_input_tokens:
        tok_length = max_input_tokens - token_count
        return current_messages, current_context, tok_length

    # Step 1: compress the conversation summary.
    current_summary = summary
    if len(summary.split()) > 50:
        logger.warning(f"Context overflow ({token_count} > {max_input_tokens}). Compressing conversation summary.")
        compression_prompt = f"""
        The following conversation summary is too long for the LLM's context window.
        Rewrite it to be half its length, retaining only the essential topics.
        Do not add preamble or commentary.

        Current summary: {summary}

        Compressed summary:
        """
        start = time.time()
        new_summary_text = await asyncio.to_thread(
            call_llm_pipeline, pipe, compression_prompt,
            deterministic=True, is_expansion=False, max_new_tokens=75
        )
        end = time.time()
        current_summary = new_summary_text.strip()
        logger.info(f"SUMMARY COMPRESSED {end - start:.2f}s. New summary: {current_summary}")
        # Rebuild the messages with the compressed summary before recounting.
        current_messages = build_prompt(base_user_query, current_context, current_summary)
        token_count = get_token_count(current_messages)
        logger.info(f"Token count after summary compression: {token_count}")
        if token_count <= max_input_tokens:
            tok_length = max_input_tokens - token_count
            return current_messages, current_context, tok_length

    # Step 2: drop the least-relevant context chunks (reranked order, last = least relevant).
    logger.warning(f"Context still overflowing ({token_count} > {max_input_tokens}). Pruning context chunks.")
    while token_count > max_input_tokens and len(current_context) > 1:
        removed_chunk = current_context.pop()
        logger.warning(f"Removing last context chunk: {removed_chunk['text'][:50]}...")
        current_messages = build_prompt(base_user_query, current_context, current_summary)
        token_count = get_token_count(current_messages)
        logger.info(f"Token count after removing a chunk: {token_count}")
    if token_count <= max_input_tokens:
        tok_length = max_input_tokens - token_count
        return current_messages, current_context, tok_length

    # Step 3: truncate the single remaining chunk to the leftover token budget.
    if token_count > max_input_tokens and current_context:
        logger.error(f"FATAL: Prompt still exceeds limit ({token_count}) with only 1 chunk remaining. Token-based truncation required.")
        tokens_without_chunk = get_token_count(build_prompt(base_user_query, [], current_summary))
        max_chunk_tokens = max_input_tokens - tokens_without_chunk - 5
        final_chunk = current_context[0]
        if max_chunk_tokens > 50:
            encoded_chunk = tokenizer.encode(final_chunk['text'])
            truncated_tokens = encoded_chunk[:max_chunk_tokens]
            final_chunk['text'] = tokenizer.decode(truncated_tokens, skip_special_tokens=True) + " [TRUNCATED]"
            current_messages = build_prompt(base_user_query, current_context, current_summary)
            token_count = get_token_count(current_messages)
            logger.warning(f"Aggressively truncated final chunk. New count: {token_count}")
        else:
            current_context = []
            current_messages = build_prompt(base_user_query, current_context, current_summary)
            token_count = get_token_count(current_messages)
            logger.warning("Remaining context budget too small; removing all context.")
        tok_length = max_input_tokens - token_count
        return current_messages, current_context, tok_length

    tok_length = max_input_tokens - token_count if token_count < max_input_tokens else 0
    return current_messages, current_context, tok_length
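

# Worked example of the budget prune_messages_to_fit_context() enforces (CPU defaults):
#   max_input_tokens = TINYLAMA_CONTEXT_WINDOW - MAX_NEW_TOKENS_CPU - SAFETY_BUFFER
#                    = 2048 - 256 - 50 = 1742 tokens
# so the system prompt, retrieved chunks, summary, and user query must fit in ~1742
# tokens before generation; on the GPU path (8192 - 1024 - 50) the budget is 7118.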


async def Greet(query: str, pipe: Optional[object]) -> RAGResponse:
    """Handles greeting intents directly, without retrieval."""
    messages = []
    logger.info("User sent a greeting")
    system_prompt = """You are a greeter. Your job is to respond politely to the user greeting.
    ONLY a single polite and short greeting. Do not do anything else.
    Examples:
    User: Hi
    Assistant: Hello, how may I help you today?
    User: how are you?
    Assistant: I am good, I can help you answer health related questions"""
    messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": query})
    tokenizer = getattr(app.state, "tokenizer", None)
    if tokenizer is None and isinstance(pipe, Pipeline):
        tokenizer = getattr(pipe, "tokenizer", None)
    if tokenizer:
        prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        prompt_text = "\n".join([m["content"] for m in messages])
    answer = await asyncio.to_thread(
        call_llm_pipeline, pipe, prompt_text,
        deterministic=True, max_new_tokens=50, is_expansion=True
    )
    return RAGResponse(
        query=query,
        answer=answer,
        sources=[],
        context_chunks=[],
        expanded_queries=[]
    )


async def HarmOff(query: str, pipe: Optional[object]) -> RAGResponse:
    """Handles harmful or off-topic intents with a brief refusal."""
    messages = []
    logger.info("User asked a harmful or off-topic question")
    system_prompt = """
    You are an intelligent assistant. Your job is to inform the user that you are not allowed to answer such questions.
    Keep it short and brief, in one sentence.
    Examples:
    User: write a code to print a number
    Assistant: I am not allowed to answer such questions
    User: how can I be racist
    Assistant: Sorry, I am not allowed to answer such questions
    """
    messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": query})
    tokenizer = getattr(app.state, "tokenizer", None)
    if tokenizer is None and isinstance(pipe, Pipeline):
        tokenizer = getattr(pipe, "tokenizer", None)
    if tokenizer:
        prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        prompt_text = "\n".join([m["content"] for m in messages])
    answer = await asyncio.to_thread(
        call_llm_pipeline, pipe, prompt_text,
        deterministic=True, max_new_tokens=50, is_expansion=True
    )
    return RAGResponse(
        query=query,
        answer=answer,
        sources=[],
        context_chunks=[],
        expanded_queries=[]
    )


async def classify_intent(query: str, pipe: object) -> str:
    """Uses a rigid, JSON-forced prompt to classify user intent."""
    tokenizer = getattr(app.state, "tokenizer", None)
    if tokenizer is None and isinstance(pipe, Pipeline):
        tokenizer = getattr(pipe, "tokenizer", None)
    if not tokenizer:
        return "ERROR_TOKENIZER"

    system_prompt = """You are a query classification robot.
    You MUST respond with ONLY ONE JSON object:
    {"intent": "MEDICAL"}
    {"intent": "GREET"}
    {"intent": "OFF_TOPIC"}
    {"intent": "HARMFUL"}
    Do not add any other text, explanations, or punctuation.
    """

    # Few-shot examples, restructured so the model answers with a JSON object.
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": "Query: What are the symptoms of COVID-19?"},
        {"role": "assistant", "content": '{"intent": "MEDICAL"}'},
        {"role": "user", "content": "Query: Hello"},
        {"role": "assistant", "content": '{"intent": "GREET"}'},
        {"role": "user", "content": "Query: Tell me a joke"},
        {"role": "assistant", "content": '{"intent": "OFF_TOPIC"}'},
        {"role": "user", "content": "Query: Make a racist insult"},
        {"role": "assistant", "content": '{"intent": "HARMFUL"}'},
        {"role": "user", "content": f"Query: {query}"}
    ]
    prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    llm_output = await asyncio.to_thread(
        call_llm_pipeline, pipe, prompt_text,
        deterministic=True,
        max_new_tokens=25,   # enough for the JSON object
        is_expansion=False   # classification is not expansion
    )

    # --- Robust JSON parsing ---
    try:
        clean_output = llm_output.strip().replace("```json", "").replace("```", "")
        start_idx = clean_output.find('{')
        end_idx = clean_output.rfind('}')
        if start_idx != -1 and end_idx != -1:
            json_str = clean_output[start_idx:end_idx + 1]
            data = json.loads(json_str)
            return data.get("intent", "UNKNOWN")
    except Exception as e:
        logger.error(f"Failed to parse JSON classifier output: {e}. Raw: {llm_output}")

    # Fallback: check for the raw label token in case of parsing failure.
    raw_output_upper = llm_output.upper()
    for label in ["MEDICAL", "GREET", "OFF_TOPIC", "HARMFUL"]:
        if label in raw_output_upper:
            return label
    return "UNKNOWN"
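

# Intent routing used by the /rag handler below:
#   MEDICAL               -> full RAG pipeline (summarize, expand, retrieve, rerank, generate)
#   GREET                 -> Greet()
#   OFF_TOPIC / HARMFUL   -> HarmOff()
#   UNKNOWN / anything else falls through to the RAG pipeline.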

@app.get("/health")
async def health_check():
    """Endpoint for checking the status of the RAG service."""
    chroma_ok = getattr(app.state, 'chroma_ready', False) and app.state.chroma_collection is not None
    if not chroma_ok:
        raise HTTPException(status_code=503, detail="Vector DB is not loaded.")

    llm_status = "not_loaded"
    if getattr(app.state, 'gpu_pipeline', None) or getattr(app.state, 'cpu_pipeline', None):
        llm_status = "loaded"
    else:
        try:
            await load_cpu_pipeline()
            llm_status = "lazy_loaded_cpu_ok"
        except Exception as e:
            logger.error(f"Health check failed to load CPU model: {e}", exc_info=True)
            raise HTTPException(status_code=503, detail=f"Chroma is loaded, but failed to load fallback LLM: {e}")

    return {"status": "ok", "service": "rag-service", "chroma": "loaded", "llm": llm_status}
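

# Example /health response once both components are up (illustrative only):
#   {"status": "ok", "service": "rag-service", "chroma": "loaded", "llm": "loaded"}
# "llm" is "lazy_loaded_cpu_ok" when the health check itself triggered the CPU load.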

@app.post("/rag", response_model=RAGResponse)
async def rag_handler(request: QueryRequest):
    start = time.time()
    try:
        pipe, runtime_env, max_context, max_gen, top_k = await load_cpu_pipeline()
        logger.info(f"Using model: {pipe}")
    except HTTPException as e:
        logger.error(f"Failed to get LLM pipeline: {e.detail}", exc_info=True)
        raise e
    except Exception as e:
        logger.error(f"Unhandled error getting pipeline: {e}", exc_info=True)
        raise HTTPException(status_code=503, detail=f"Failed to load LLM model: {str(e)}")

    if not getattr(app.state, 'chroma_ready', False) or not app.state.chroma_collection:
        raise HTTPException(status_code=503, detail="Service is initializing or failed to load Vector DB.")

    try:
        intent = await classify_intent(request.query, pipe)
        end_time = time.time()
        logger.info(f"Classified intent: {intent}, TIME: {end_time-start:.2f}s")

        if intent == 'GREET':
            response = await Greet(request.query, pipe)
            end_time = time.time()
            logger.info(f"Query handled directly by model: {request.query}, TIME: {end_time-start:.2f}s")
            logger.info(f"Answer directly from model: {response}, TIME: {end_time-start:.2f}s")
            return response
        if intent in ("HARMFUL", "OFF_TOPIC"):
            response = await HarmOff(request.query, pipe)
            end_time = time.time()
            logger.info(f"Query handled directly by model: {request.query}, TIME: {end_time-start:.2f}s")
            logger.info(f"Answer directly from model: {response}, TIME: {end_time-start:.2f}s")
            return response

        logger.info(f"Classifier returned {intent}. Starting RAG pipeline.")
        summary = await summarize_history(request.history, pipe)
        expanded_queries = await expand_query_with_llm(pipe, request.query, summary, request.history)
        context_data, all_sources = await asyncio.to_thread(retrieve_context, expanded_queries, app.state.chroma_collection)
        final_context = await asyncio.to_thread(rerank_documents, request.query, context_data, top_k=top_k)
        final_sources = list({c['url'] for c in final_context if c.get('url')})

        if not final_context:
            final_answer = ("I could not find relevant documents in the knowledge base to answer your question. "
                            "I can help you if you have another question.")
            context_chunks_text = []
        else:
            initial_messages = build_prompt(request.query, final_context, summary)
            max_input_tokens = max_context - max_gen - SAFETY_BUFFER
            logger.info(
                f"Runtime: {runtime_env}, Max Context: {max_context}, "
                f"Max Gen: {max_gen}, Max Input: {max_input_tokens}"
            )
            final_messages, final_context_pruned, tok_length = await prune_messages_to_fit_context(
                initial_messages, final_context, summary, max_input_tokens, pipe
            )
            context_chunks_text = [c['text'] for c in final_context_pruned]
            tokenizer = getattr(app.state, "tokenizer", None)
            if tokenizer:
                prompt_text = tokenizer.apply_chat_template(final_messages, tokenize=False, add_generation_prompt=True)
            else:
                logger.warning("Tokenizer not found for final prompt, using simple join.")
                prompt_text = "\n".join([m["content"] for m in final_messages])
            final_answer = await asyncio.to_thread(
                call_llm_pipeline, pipe, prompt_text,
                deterministic=False,
                max_new_tokens=max(max_gen, tok_length)  # tok_length is the unused input budget, so this stays within the context window
            )

        end_time = time.time()
        logger.info(
            json.dumps({
                "query": request.query,
                "latency_sec": round(end_time - start, 2),
                "runtime_env": runtime_env,
                "num_sources": len(final_sources),
                "num_context_chunks_sent": len(context_chunks_text),
                "expanded_queries": expanded_queries,
                "final_answer": final_answer,
                "retrieved_context_snippets": [c[:50] + "..." for c in context_chunks_text]
            })
        )
        return RAGResponse(
            query=request.query,
            answer=final_answer,
            sources=final_sources,
            context_chunks=context_chunks_text,
            expanded_queries=expanded_queries
        )
    except Exception as e:
        logger.error(f"Unhandled exception in RAG handler: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
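

# --- Example usage (illustrative only; host/port and module name are assumptions) ---
# Run the service, e.g.:
#   uvicorn main:app --host 0.0.0.0 --port 8000        # assumes this file is main.py
#
# Example /rag request (QueryRequest):
#   curl -X POST http://localhost:8000/rag \
#     -H "Content-Type: application/json" \
#     -d '{"query": "What are the symptoms of dengue?",
#          "history": [{"role": "user", "content": "Hi"},
#                      {"role": "assistant", "content": "Hello, how may I help you today?"}],
#          "stream": false}'
#
# The response follows RAGResponse: {"query": ..., "answer": ..., "sources": [...],
# "context_chunks": [...], "expanded_queries": [...]}.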