Commit cb7edc9 · 1 parent: 2ab3610 · committed by GitHub Actions

Deploy Wed Nov 19 16:30:12 UTC 2025

Dockerfile ADDED
@@ -0,0 +1,31 @@
+
+ FROM python:3.11-slim
+
+ # For Hugging Face Spaces
+ ENV PORT=7860
+ EXPOSE 7860
+
+ WORKDIR /app
+
+
+ RUN apt-get update \
+     && apt-get install -y --no-install-recommends \
+        build-essential \
+     && rm -rf /var/lib/apt/lists/*
+
+
+ COPY requirements_heavy.txt .
+ RUN pip install --timeout 2000 -r requirements_heavy.txt \
+     --extra-index-url https://download.pytorch.org/whl/cpu
+
+
+ COPY requirements_light.txt .
+ RUN pip install --timeout 2000 -r requirements_light.txt \
+     --extra-index-url https://download.pytorch.org/whl/cpu
+
+
+ # COPY inference_chroma.py .
+ COPY . .
+
+ # CMD ["uvicorn", "inference_chroma:app", "--host", "0.0.0.0", "--port", "8000"]
+ CMD ["uvicorn", "inference_chroma:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,10 @@
  ---
- title: WHO Rag System
+ title: "WHO Rag System"
  emoji: 😻
  colorFrom: pink
  colorTo: purple
  sdk: docker
  pinned: false
  license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ app_port: 7860
+ ---
 
hg_login.py ADDED
@@ -0,0 +1,4 @@
+ from huggingface_hub import login
+ import os
+
+ login(token=os.getenv('HF_TOKEN'))
inference_chroma.py ADDED
@@ -0,0 +1,1029 @@
+ from __future__ import annotations
+ import asyncio
+ import gc
+ import os
+ import json
+ import logging
+ import torch
+ from typing import List, Dict, Tuple, Optional, Literal
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel, Field
+
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, Pipeline
+ from sentence_transformers import CrossEncoder
+ from fastembed import TextEmbedding
+ from s3_utils import download_chroma_folder_from_s3
+
+ import chromadb
+ from chromadb.api import Collection
+ from chromadb import PersistentClient
+ from chromadb.api.types import QueryResult
+ import time
+ from llama_cpp import Llama
+
+ from huggingface_hub import hf_hub_download
+
+ logging.basicConfig(level=logging.INFO, format='{"time": "%(asctime)s", "level": "%(levelname)s", "message": "%(message)s"}')
+ logger = logging.getLogger(__name__)
+
+ # NOTE: CHROMA_DIR must be set in the environment; "/" + None would raise at import time.
+ CHROMA_DIR = os.getenv("CHROMA_DIR")
+ CHROMA_DIR_INF = "/" + CHROMA_DIR
+ CHROMA_COLLECTION = os.getenv("CHROMA_COLLECTION")
+ CHROMA_CACHE_COLLECTION = os.getenv("CHROMA_CACHE_COLLECTION", "semantic_cache")
+
+ LLM_MODEL_CPU_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+ LLM_MODEL_GPU_ID = "meta-llama/Llama-3.1-8B-Instruct"
+ DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
+ CHROMA_DB_FILENAME = os.getenv("CHROMA_DB_FILENAME")
+ SUMMARY_TRIGGER_TOKENS = int(os.getenv("SUMMARY_TRIGGER_TOKENS", 1000))
+ SUMMARY_TARGET_TOKENS = int(os.getenv("SUMMARY_TARGET_TOKENS", 120))
+ # SEMANTIC_CACHE_DIST_THRESHOLD = float(os.getenv("SEMANTIC_CACHE_SIM_THRESHOLD", 0.1))
+
+ RETRIEVE_TOP_K_CPU = int(os.getenv("RETRIEVE_TOP_K_CPU", 3))
+ RETRIEVE_TOP_K_GPU = int(os.getenv("RETRIEVE_TOP_K_GPU", 8))
+ MAX_NEW_TOKENS_CPU = int(os.getenv("MAX_NEW_TOKENS_CPU", 256))
+ MAX_NEW_TOKENS_GPU = int(os.getenv("MAX_NEW_TOKENS_GPU", 1024))
+ # GPU_MIN_FREE_HOURS_THRESHOLD = float(os.getenv("GPU_MIN_FREE_HOURS_THRESHOLD", 0.5))
+ GPU_MIN_FREE_HOURS_THRESHOLD = 11
+
+ # LLAMA_GGUF_PATH = os.getenv("LLAMA_GGUF_PATH", "/model/tinyllama-reasoning.Q4_K_M.gguf")
+ LLM_TOKENIZER_ID = "alexredna/TinyLlama-1.1B-Chat-v1.0-reasoning-v2"
+
+ TINYLAMA_CONTEXT_WINDOW = 2048
+ LLAMA_3_CONTEXT_WINDOW = 8192
+ SAFETY_BUFFER = 50
+ # MAX_INPUT_TOKENS = TINYLAMA_CONTEXT_WINDOW - MAX_NEW_TOKENS - SAFETY_BUFFER
+
+ LLAMA_3_CHAT_TEMPLATE = (
+     "{% for message in messages %}"
+     "{% if message['role'] == 'user' %}"
+     "{{ '<|start_header_id|>user<|end_header_id|>\n' + message['content'] + '<|eot_id|>' }}"
+     "{% elif message['role'] == 'assistant' %}"
+     "{{ '<|start_header_id|>assistant<|end_header_id|>\n' + message['content'] + '<|eot_id|>' }}"
+     "{% elif message['role'] == 'system' %}"
+     "{{ '<|start_header_id|>system<|end_header_id|>\n' + message['content'] + '<|eot_id|>' }}"
+     "{% endif %}"
+     "{% if loop.last and message['role'] == 'user' %}"
+     "{{ '<|start_header_id|>assistant<|end_header_id|>\n' }}"
+     "{% endif %}"
+     "{% endfor %}"
+ )
+
+ CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
+ MODEL_ID = "EJ4U/WHO-rag-model"
+ FILENAME = "tinyllama-reasoning.Q4_K_M.gguf"
+
+ try:
+     cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL, device=DEVICE)
+     logger.info("Cross-encoder model loaded successfully.")
+ except Exception as e:
+     logger.warning("Cross-encoder model error: %s", e)
+
+ LLAMA_GGUF_PATH = hf_hub_download(
+     repo_id=MODEL_ID,
+     filename=FILENAME,
+     cache_dir="model"
+ )
+
+ EMBEDDING_MODEL = TextEmbedding(model_name="BAAI/bge-small-en-v1.5")
+ _ = list(EMBEDDING_MODEL.embed(["warmup"]))
+ logger.info("FastEmbed model warmup complete.")
+
+
+ def initialize_cpp_llm(gguf_path: str, n_ctx: int = TINYLAMA_CONTEXT_WINDOW, n_threads: int = 4) -> Llama:
+     """
+     Initialize a quantized GGUF model via llama-cpp (llama_cpp.Llama).
+     This replaces the HF AutoModel pipeline for CPU inference.
+     """
+     logger.info(f"Initializing llama.cpp model from GGUF: {gguf_path}")
+     if not os.path.exists(gguf_path):
+         logger.error(f"GGUF model not found at {gguf_path}. Make sure the file exists.")
+         raise RuntimeError(f"GGUF model not found at {gguf_path}")
+
+     llm = Llama(
+         model_path=gguf_path,
+         n_ctx=n_ctx,
+         n_threads=n_threads,
+         n_batch=256,
+         use_mmap=True,  # memory-map weights for faster cold-start
+         n_gpu_layers=0
+     )
+     logger.info("llama.cpp model loaded successfully.")
+     return llm
+
+
+ def initialize_llm_pipeline(model_id: str, device: str) -> Pipeline:
+     """Initializes a Hugging Face transformers pipeline for GPU."""
+     logger.info(f"Initializing HF Pipeline for model: {model_id} on {device}")
+     model = AutoModelForCausalLM.from_pretrained(
+         model_id,
+         torch_dtype=torch.bfloat16,
+         device_map=device,
+         trust_remote_code=True
+     )
+     tokenizer = AutoTokenizer.from_pretrained(model_id)
+     if not getattr(tokenizer, "chat_template", None):
+         logger.info("Applying Llama-3 chat template to tokenizer.")
+         tokenizer.chat_template = LLAMA_3_CHAT_TEMPLATE
+
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     pipe = pipeline(
+         "text-generation",
+         model=model,
+         tokenizer=tokenizer,
+         device_map=device
+     )
+     logger.info(f"HF Pipeline for {model_id} loaded successfully.")
+     return pipe
+
+
+ def initialize_chroma_client() -> chromadb.PersistentClient:
+     """Initializes the Chroma client and loads the index from S3/disk."""
+     logger.info(f"Initializing Chroma client from persistence directory: {CHROMA_DIR_INF}")
+
+     local_db = os.path.exists(CHROMA_DIR_INF) and os.listdir(CHROMA_DIR_INF)
+     db_exist = os.path.join(CHROMA_DIR_INF, "chroma.sqlite3")
+
+     if not local_db or not os.path.exists(db_exist):
+         logger.warning(f"Local chroma directory {CHROMA_DIR_INF} is missing or empty. "
+                        f"Attempting download from S3.")
+
+         if CHROMA_DIR:
+             download_chroma_folder_from_s3(
+                 s3_prefix=CHROMA_DIR,
+                 local_dir=CHROMA_DIR_INF
+             )
+             logger.info("Chroma data downloaded from S3.")
+         else:
+             logger.error("CHROMA_DIR is not set. Cannot retrieve chroma index.")
+             raise RuntimeError("Chroma index failed to load")
+     else:
+         logger.info(f"Local chroma data found at {CHROMA_DIR_INF}.")
+
+     logger.info(f"Initializing chroma client from persistence directory: {CHROMA_DIR_INF}")
+     try:
+         client = PersistentClient(path=CHROMA_DIR_INF, settings=chromadb.Settings(allow_reset=False))
+         logger.info("Chroma client initialized successfully.")
+     except Exception as e:
+         logger.error(f"Failed to load Chroma index: {e}")
+         raise RuntimeError("Chroma index failed to load.")
+
+     return client
+
+
+ async def load_cpu_pipeline() -> Tuple[Optional[object], str, int, int, int]:
+
+     if getattr(app.state, 'cpu_pipeline', None) is not None:
+         if isinstance(app.state.cpu_pipeline, Llama):
+             return app.state.cpu_pipeline, "cpu_gguf", TINYLAMA_CONTEXT_WINDOW, MAX_NEW_TOKENS_CPU, RETRIEVE_TOP_K_CPU
+         return app.state.cpu_pipeline, "hf_gpu" if torch.cuda.is_available() else "cpu_hf", TINYLAMA_CONTEXT_WINDOW, MAX_NEW_TOKENS_CPU, RETRIEVE_TOP_K_CPU
+
+     if getattr(app.state, 'tokenizer', None) is None:
+         try:
+             logger.info(f"Loading tokenizer from {LLM_TOKENIZER_ID}")
+             app.state.tokenizer = AutoTokenizer.from_pretrained(LLM_TOKENIZER_ID, use_fast=False)
+             if not getattr(app.state.tokenizer, "chat_template", None):
+                 app.state.tokenizer.chat_template = LLAMA_3_CHAT_TEMPLATE
+         except Exception as e:
+             logger.error(f"Failed to load tokenizer: {e}", exc_info=True)
+             raise HTTPException(status_code=503, detail=f"Failed to load tokenizer: {e}")
+
+     if torch.cuda.is_available():
+         try:
+             logger.info(f"GPU detected. Attempting to load HF GPU model {LLM_MODEL_GPU_ID}...")
+             app.state.cpu_pipeline = await asyncio.to_thread(
+                 initialize_llm_pipeline,
+                 LLM_MODEL_GPU_ID,
+                 "cuda"
+             )
+             logger.info("HF GPU model loaded successfully.")
+             return app.state.cpu_pipeline, "hf_gpu", TINYLAMA_CONTEXT_WINDOW, MAX_NEW_TOKENS_CPU, RETRIEVE_TOP_K_CPU
+         except Exception as e:
+             logger.warning(f"Failed to load HF GPU model: {e}. Falling back to CPU...")
+
+     if LLAMA_GGUF_PATH and os.path.exists(LLAMA_GGUF_PATH):
+         try:
+             logger.info("Loading TinyLlama GGUF (CPU)...")
+             logger.info(f"Model: {LLAMA_GGUF_PATH}")
+             app.state.cpu_pipeline = await asyncio.to_thread(
+                 initialize_cpp_llm,
+                 LLAMA_GGUF_PATH,
+                 TINYLAMA_CONTEXT_WINDOW,
+                 max(1, os.cpu_count() - 1)
+             )
+             logger.info("TinyLlama GGUF loaded successfully.")
+             return app.state.cpu_pipeline, "cpu_gguf", TINYLAMA_CONTEXT_WINDOW, MAX_NEW_TOKENS_CPU, RETRIEVE_TOP_K_CPU
+         except Exception as e:
+             logger.warning(f"Failed to load GGUF CPU model: {e}")
+
+     try:
+         logger.info(f"Loading HF CPU model {LLM_MODEL_CPU_ID}...")
+         app.state.cpu_pipeline = await asyncio.to_thread(
+             initialize_llm_pipeline,
+             LLM_MODEL_CPU_ID,
+             "cpu"
+         )
+         logger.info("HF CPU model loaded successfully.")
+         return app.state.cpu_pipeline, "cpu_hf", TINYLAMA_CONTEXT_WINDOW, MAX_NEW_TOKENS_CPU, RETRIEVE_TOP_K_CPU
+     except Exception as e:
+         logger.error(f"FATAL: Failed to load any CPU model: {e}", exc_info=True)
+         raise HTTPException(status_code=503, detail=f"Failed to load any model: {e}")
+
+
+ async def get_pipeline_for_runtime() -> Tuple[Optional[object], str, int, int, int]:
+     """
+     Determines the runtime, lazily loads the correct pipeline (GPU/CPU),
+     and returns the pipeline and its associated settings.
+     NOTE: Return type is Optional[object] to handle both Pipeline and Llama.
+     """
+     if await gpu_hours_available():
+         logger.info("GPU hours available. Attempting to load GPU pipeline.")
+         if getattr(app.state, 'gpu_pipeline', None) is None:
+             logger.info("Lazy-loading Llama-3.1-8B (GPU)...")
+             try:
+                 # NOTE: Clear both the Llama object and the separate tokenizer
+                 if getattr(app.state, 'cpu_pipeline', None):
+                     del app.state.cpu_pipeline
+                     app.state.cpu_pipeline = None
+                 if getattr(app.state, 'llm_cpp', None):
+                     del app.state.llm_cpp
+                     app.state.llm_cpp = None
+                 if getattr(app.state, 'tokenizer', None):
+                     del app.state.tokenizer
+                     app.state.tokenizer = None
+                 gc.collect()
+                 logger.info("Cleared CPU pipeline (Llama) and tokenizer from memory.")
+
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+
+                 app.state.gpu_pipeline = await asyncio.to_thread(
+                     initialize_llm_pipeline, LLM_MODEL_GPU_ID, "cuda"
+                 )
+                 logger.info("GPU pipeline loaded successfully.")
+             except Exception as e:
+                 logger.error(f"Failed to load GPU pipeline: {e}. Falling back to CPU.", exc_info=True)
+                 return await load_cpu_pipeline()
+
+         return app.state.gpu_pipeline, "gpu", LLAMA_3_CONTEXT_WINDOW, MAX_NEW_TOKENS_GPU, RETRIEVE_TOP_K_GPU
+     else:
+         logger.info("GPU hours exhausted or unavailable. Loading CPU pipeline.")
+         return await load_cpu_pipeline()
+
+
+ async def gpu_hours_available() -> bool:
+
+     force_gpu = False
+     if force_gpu:
+         return True
+
+     remaining_hours = 10
+     return remaining_hours > GPU_MIN_FREE_HOURS_THRESHOLD
+
+
+ app = FastAPI(title="RAG Inference API (Chroma + Llama 3)", version="1.0.0")
+
+
+ @app.on_event("startup")
+ async def load_models():
+     try:
+         logger.info("Starting FastAPI model loading...")
+
+         client = await asyncio.to_thread(initialize_chroma_client)
+
+         if not CHROMA_COLLECTION:
+             raise RuntimeError("CHROMA_COLLECTION variable not set in env")
+         app.state.chroma_collection = client.get_collection(name=CHROMA_COLLECTION)
+         if app.state.chroma_collection:
+             app.state.chroma_ready = app.state.chroma_collection is not None
+             logger.info(f"Loaded collection: {CHROMA_COLLECTION} (Documents: {app.state.chroma_collection.count()})")
+
+         app.state.cache_collection = client.get_or_create_collection(name=CHROMA_CACHE_COLLECTION)
+         logger.info(f"Loaded Cache collection: {CHROMA_CACHE_COLLECTION} ({app.state.cache_collection.count()} items)")
+
+         app.state.gpu_pipeline: Optional[Pipeline] = None  # type: ignore
+         app.state.cpu_pipeline: Optional[object] = None  # type: ignore
+         app.state.llm_cpp: Optional[Llama] = None  # type: ignore
+         app.state.tokenizer: Optional[AutoTokenizer] = None  # type: ignore
+
+         if not app.state.chroma_ready:
+             raise RuntimeError("ChromaDB critical component failed to load.")
+
+         await load_cpu_pipeline()
+         logger.info("FastAPI models loaded successfully (CPU pipeline pre-warmed).")
+
+     except Exception as e:
+         app.state.chroma_ready = False
+         logger.error(f"Error during startup: {e}", exc_info=True)
+         raise
+
+
+ class HistoryMessage(BaseModel):
+     role: Literal['user', 'assistant']
+     content: str
+
+
+ class QueryRequest(BaseModel):
+     query: str = Field(..., description="The user's latest message.")
+     history: List[HistoryMessage] = Field(default_factory=list, description="The previous turns of the conversation.")
+     stream: bool = Field(False)
+
+
+ class RAGResponse(BaseModel):
+     query: str = Field(..., description="The original user query.")
+     answer: str = Field(..., description="The final answer generated by the LLM.")
+     sources: List[str] = Field(..., description="Unique source URLs used for the answer.")
+     context_chunks: List[str] = Field(..., description="The final context chunks (text only) sent to the LLM.")
+     expanded_queries: List[str] = Field(..., description="Queries used for retrieval.")
+
+
+ def call_llm_pipeline(pipe_like: Optional[object],
+                       prompt_text: str,
+                       deterministic=False,
+                       max_new_tokens: int = MAX_NEW_TOKENS_CPU,
+                       is_expansion: bool = False
+                       ) -> str:
+     """
+     Unified caller for the LLM:
+     - Handles llama_cpp.Llama instances (CPU)
+     - Handles transformers.Pipeline instances (GPU)
+     """
+
+     logging.info(f"model used: {pipe_like}")
+
+     if pipe_like is None:
+         raise HTTPException(status_code=503, detail="LLM pipeline is not available.")
+
+     if deterministic:
+         temp = 0.0
+     elif is_expansion:
+         temp = 0.1
+     else:
+         temp = 0.6
+
+     tokenizer = getattr(app.state, "tokenizer", None)
+     if tokenizer is None and isinstance(pipe_like, Pipeline):
+         tokenizer = getattr(pipe_like, "tokenizer", None)
+
+     try:
+         if tokenizer:
+             input_token_count = len(tokenizer.encode(prompt_text, add_special_tokens=True))
+             logger.info(f"LLM Input Token Count: {input_token_count}.")
+     except Exception:
+         logger.debug("Token counting failed, continuing without token count.")
+
+     try:
+         if isinstance(pipe_like, Llama):
+             llm = pipe_like
+             with torch.inference_mode():
+                 resp = llm(
+                     prompt_text,
+                     max_tokens=max_new_tokens,
+                     temperature=temp,
+                     stop=["<|eot_id|>", "<|start_header_id|>", "<|end_of_text|>"]
+                 )
+             text = resp.get("choices", [{}])[0].get("text", "").strip()
+             return text
+
+         elif isinstance(pipe_like, Pipeline):
+             pipe = pipe_like
+
+             with torch.inference_mode():
+                 outputs = pipe(
+                     prompt_text,
+                     max_new_tokens=max_new_tokens,
+                     temperature=temp if temp > 0.0 else None,
+                     do_sample=True if temp > 0.0 else False,
+                     pad_token_id=pipe.tokenizer.eos_token_id,
+                     return_full_text=False
+                 )
+
+             text = outputs[0]['generated_text'].strip()
+
+             if '<|eot_id|>' in text:
+                 text = text.split('<|eot_id|>')[0].strip()
+             if '<|end_of_text|>' in text:
+                 text = text.split('<|end_of_text|>')[0].strip()
+
+             return text
+
+         else:
+             logger.error(f"Unknown pipeline type: {type(pipe_like)}")
+             raise TypeError(f"Unknown pipeline type: {type(pipe_like)}")
+
+     except Exception as e:
+         logger.error(f"Error calling LLM pipeline: {e}", exc_info=True)
+         raise HTTPException(status_code=500, detail=f"LLM generation failed: {str(e)}")
+
+
+ async def expand_query_with_llm(pipe: Optional[object],
+                                 user_query: str,
+                                 summary: str,
+                                 history: Optional[List[HistoryMessage]]
+                                 ) -> List[str]:
+     """
+     Implements the two-mode query strategy: expansion (no history) or rewriting (with history).
+     """
+
+     messages = []
+     expanded_queries: List[str] = []
+
+     if not history or len(history) == 0:
+         system_prompt = "You are a specialized query expansion engine."
+         user_prompt = f"""
+         Generate 3 alternative search queries similar to the user query below.
+         The goal is to maximize retrieval relevance based on the user query.
+         Return only the queries, one per line, without numbers or extra text.
+         If the user query is a greeting, do not reply with a greeting; ask how you can help instead.
+
+         User Query:
+         {user_query}
+         """
+         messages.append({"role": "system", "content": system_prompt})
+         messages.append({"role": "user", "content": user_prompt})
+
+     else:
+         messages = [
+             {
+                 "role": "system",
+                 "content": "You are a helpful assistant who expands user queries into multiple search queries based on conversation history and user intent."
+             },
+             {
+                 "role": "user",
+                 "content": f"""
+                 Given the conversation summary below and the user query, expand the user query into three queries that best reflect the conversation history, intent, and user needs.
+                 Return only the queries, one per line, without numbers, preamble, or other text.
+
+                 Conversation Summary:
+                 {summary}
+
+                 User Query:
+                 {user_query}
+
+                 Queries:
+                 """
+             }
+         ]
+
+     tokenizer = getattr(app.state, "tokenizer", None)
+     if tokenizer is None and isinstance(pipe, Pipeline):
+         tokenizer = getattr(pipe, "tokenizer", None)
+
+     if tokenizer:
+         prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     else:
+         logger.warning("No tokenizer found for expansion, using simple join.")
+         prompt_text = "\n".join([m["content"] for m in messages])
+
+     logger.info(f"Query Expansion/Rewrite Prompt: {prompt_text}")
+     start = time.time()
+
+     llm_output = await asyncio.to_thread(
+         call_llm_pipeline, pipe, prompt_text, deterministic=True, is_expansion=True, max_new_tokens=150
+     )
+     end = time.time()
+     logger.info(f"Query Expansion/Rewrite Output: {llm_output} (Time: {end-start:.2f}s)")
+
+     if not history or len(history) <= 0:
+         expanded_queries = [
+             q.strip() for q in llm_output.split('\n')
+             if q.strip() and "engine" not in q.lower() and "task" not in q.lower() and "search queries" not in q.lower()
+         ]
+     else:
+         expanded_queries = [llm_output.strip()]
+
+     expanded_queries.append(user_query)
+
+     return list(set(q for q in expanded_queries if q))
+
+
+ def retrieve_context(queries: List[str], collection: Collection) -> Tuple[List[Dict], List[str]]:
+     """Retrieves context from ChromaDB based on query embeddings."""
+     try:
+         embeddings_list = [
+             [float(x) for x in emb]
+             for emb in EMBEDDING_MODEL.embed(queries, batch_size=8)
+         ]
+     except Exception as e:
+         logger.error(f"Failed to generate embeddings for retrieval: {e}", exc_info=True)
+         return [], []
+
+     try:
+         n_results_to_fetch = max(10, RETRIEVE_TOP_K_CPU * len(queries))
+         start = time.time()
+         results = collection.query(
+             query_embeddings=embeddings_list,
+             n_results=n_results_to_fetch,
+             include=['documents', 'metadatas']
+         )
+         end = time.time()
+         logger.info(f'RETRIEVING TOOK: {end-start:.2f}s')
+     except Exception as e:
+         logger.error(f"Chroma query failed: {e}")
+         return [], []
+
+     context_data = []
+     source_urls = set()
+     seen_texts = set()
+
+     if results.get("documents") and results.get("metadatas"):
+         for docs_list, metadatas_list in zip(results["documents"], results["metadatas"]):
+             for doc, metadata in zip(docs_list, metadatas_list):
+                 if doc and metadata and doc not in seen_texts:
+                     context_data.append({'text': doc, 'url': metadata.get('source')})
+                     if metadata.get("source"):
+                         source_urls.add(metadata.get('source'))
+                     seen_texts.add(doc)
+
+     return context_data, list(source_urls)
+
+
+ def rerank_documents(query: str, context: List[Dict], top_k: int) -> List[Dict]:
+     """
+     Re-ranks context documents using a cross-encoder.
+     Returns the top-k most relevant documents.
+     """
+     if not context or not cross_encoder:
+         logger.warning("Skipping reranking (no context or cross-encoder not loaded).")
+         return context[:top_k]
+
+     top_k = min(top_k, len(context))
+     pairs = [(query, doc['text']) for doc in context]
+
+     try:
+         start = time.time()
+         scores = cross_encoder.predict(pairs)
+         end = time.time()
+         logger.info(f'RERANKING TOOK {end-start:.2f}s')
+     except Exception as e:
+         logger.error(f"Cross-encoder prediction failed: {e}. Returning non-reranked results.", exc_info=True)
+         return context[:top_k]
+
+     for doc, score in zip(context, scores):
+         doc['score'] = score
+
+     ranked_docs = sorted(context, key=lambda x: x['score'], reverse=True)
+
+     return ranked_docs[:top_k]
+
+
+ async def summarize_history(history: List[HistoryMessage], pipe: Optional[object]) -> str:
+     """
+     Summarizes the conversation history using the active LLM (via call_llm_pipeline).
+     """
+     if not history:
+         return ''
+
+     history_text = "\n".join([f"{h.role}: {h.content}" for h in history[-8:]])
+
+     tokenizer = getattr(app.state, "tokenizer", None)
+     if tokenizer is None and isinstance(pipe, Pipeline):
+         tokenizer = getattr(pipe, "tokenizer", None)
+
+     history_tokens = len(tokenizer.encode(history_text)) if tokenizer else len(history_text.split())
+
+     if history_tokens < 150:
+         return ""
+
+     summarizer_prompt = f"""
+     You are an intelligent agent who summarizes conversations.
+     Concisely summarize the key topics and entities discussed in the
+     conversation history between a user and an assistant.
+     The summary should be a few sentences long.
+
+     CONVERSATION HISTORY:
+     {history_text}
+
+     CONCISE SUMMARY:
+     """
+
+     logger.info("Generating conversation summary...")
+     start = time.time()
+     summary = await asyncio.to_thread(
+         call_llm_pipeline, pipe, summarizer_prompt, deterministic=True, max_new_tokens=150, is_expansion=False
+     )
+     end = time.time()
+     logger.info(f"HISTORY SUMMARIZATION: {summary} (Time: {end-start:.2f}s)")
+     return summary
+
+
+ def build_prompt(user_query: str, context: List[Dict], summary: str) -> List[Dict]:
+     """
+     Builds the final list of messages for the chat template, including the RAG context.
+     """
+
+     messages = []
+     context_text = "\n---\n".join([c['text'] for c in context]) if context else "No relevant context found."
+
+     rag_system_prompt = f"""
+     You are a helpful medical assistant with a friendly, conversational tone.
+     Use the retrieved context to answer the user's query accurately.
+     If the context is missing, clearly state that the WHO factsheets don't contain the information.
+     Do not repeat the user's question or the context in your response. Do not answer racist, harmful, discriminatory, or non-health questions.
+     Formulate a direct, conversational answer using only the provided context as definitive truth.
+
+     ---
+     retrieved context:
+     {context_text}
+     ---
+     conversation history summary:
+     {summary if summary else "No summary available."}
+     ---
+     """
+     messages.append({"role": "system", "content": rag_system_prompt})
+
+     messages.append({"role": "user", "content": user_query})
+
+     return messages
+
+
+ async def prune_messages_to_fit_context(messages: List[Dict],
+                                         final_context: List[Dict],
+                                         summary: str,
+                                         max_input_tokens: int,
+                                         pipe: Optional[object]
+                                         ) -> Tuple[List[Dict], List[Dict], int]:
+     """
+     Ensures the total prompt fits within max_input_tokens.
+     Prunes retrieved context chunks and compresses the summary if needed.
+     Uses app.state.tokenizer (CPU) or pipe.tokenizer (GPU).
+     """
+
+     tokenizer = getattr(app.state, "tokenizer", None)
+     if tokenizer is None and isinstance(pipe, Pipeline):
+         tokenizer = getattr(pipe, "tokenizer", None)
+
+     if not tokenizer:
+         logger.error("Tokenizer not loaded for pruning.")
+         return messages, final_context, 0
+
+     def get_token_count(msg_list: List[Dict]) -> int:
+         prompt_text = tokenizer.apply_chat_template(msg_list, tokenize=False, add_generation_prompt=True)
+         return len(tokenizer.encode(prompt_text, add_special_tokens=False))
+
+     current_context = final_context[:]
+     current_messages = messages[:]
+     tok_length = 0
+     token_count = get_token_count(current_messages)
+     base_user_query = messages[-1]["content"]
+     logger.info(f"Pre-pruning token count: {token_count}. Max: {max_input_tokens}")
+
+     if token_count <= max_input_tokens:
+         tok_length = max_input_tokens - token_count
+         return current_messages, current_context, tok_length
+
+     current_summary = summary
+     if len(summary.split()) > 50:
+         logger.warning(f"Context overflow ({token_count} > {max_input_tokens}). Compressing conversation summary.")
+         compression_prompt = f"""
+         The following conversation summary is too long for the LLM's context window.
+         Rewrite it to be half its length, retaining only the essential topics.
+         Do not add preamble or commentary.
+
+         Current summary:
+         {summary}
+
+         Compressed summary:
+         """
+         start = time.time()
+         new_summary_text = await asyncio.to_thread(
+             call_llm_pipeline, pipe, compression_prompt, deterministic=True, is_expansion=False, max_new_tokens=75
+         )
+         end = time.time()
+         current_summary = new_summary_text.strip()
+         logger.info(f"SUMMARY COMPRESSED {end - start:.2f}s. New summary: {current_summary}")
+
+         token_count = get_token_count(current_messages)
+         logger.info(f"Token count after summary compression: {token_count}")
+
+         if token_count <= max_input_tokens:
+             tok_length = max_input_tokens - token_count
+             return current_messages, current_context, tok_length
+
+     logger.warning(f"Context still overflowing ({token_count} > {max_input_tokens}). Pruning context chunks.")
+     while token_count > max_input_tokens and len(current_context) > 1:
+         removed_chunk = current_context.pop()
+         logger.warning(f"Removing last context chunk: {removed_chunk['text'][:50]}...")
+
+         current_messages = build_prompt(base_user_query, current_context, current_summary)
+         token_count = get_token_count(current_messages)
+         logger.info(f"Token count after removing a chunk: {token_count}")
+
+     if token_count <= max_input_tokens:
+         tok_length = max_input_tokens - token_count
+         return current_messages, current_context, tok_length
+
+     logger.warning(f"Context still overflowing ({token_count} > {max_input_tokens}). Aggressively dropping least-relevant chunks.")
+
+     while token_count > max_input_tokens and len(current_context) > 1:
+         removed_chunk = current_context.pop()
+         logger.warning(f"Removing last context chunk: {removed_chunk['text'][:50]}...")
+
+         current_messages = build_prompt(base_user_query, current_context, current_summary)
+         token_count = get_token_count(current_messages)
+         logger.info(f"Token count after removing a chunk: {token_count}")
+
+     if token_count <= max_input_tokens:
+         tok_length = max_input_tokens - token_count
+         return current_messages, current_context, tok_length
+
+     if token_count > max_input_tokens and current_context:
+         logger.error(f"FATAL: Prompt still exceeds limit ({token_count}) with only 1 chunk remaining. Token-based truncation required.")
+
+         tokens_without_chunk = get_token_count(build_prompt(base_user_query, [], current_summary))
+
+         max_chunk_tokens = max_input_tokens - tokens_without_chunk - 5
+
+         final_chunk = current_context[0]
+
+         if max_chunk_tokens > 50:
+             encoded_chunk = tokenizer.encode(final_chunk['text'])
+             truncated_tokens = encoded_chunk[:max_chunk_tokens]
+
+             final_chunk['text'] = tokenizer.decode(truncated_tokens, skip_special_tokens=True) + " [TRUNCATED]"
+
+             current_messages = build_prompt(base_user_query, current_context, current_summary)
+             token_count = get_token_count(current_messages)
+             logger.warning(f"Aggressively truncated final chunk. New count: {token_count}")
+         else:
+             current_context = []
+             current_messages = build_prompt(base_user_query, current_context, current_summary)
+             token_count = get_token_count(current_messages)
+             logger.warning("Remaining context budget too small; removing all context.")
+
+         tok_length = max_input_tokens - token_count
+         return current_messages, current_context, tok_length
+
+     tok_length = max_input_tokens - token_count if token_count < max_input_tokens else 0
+     return current_messages, current_context, tok_length
+
+
+ async def Greet(query, pipe):
+     messages = []
+     logging.info("User sent a greeting")
+     prompt_text = """You are a greeter. Your job is to respond politely to the user greeting.
+     ONLY a single polite and short greeting. Do not do anything else.
+
+     Examples:
+     User: Hi
+     Assistant: Hello, how may I help you today?
+
+     User: how are you?
+     Assistant: I am good, I can help you answer health related questions"""
+
+     messages.append({"role": "system", "content": prompt_text})
+     messages.append({"role": "user", "content": query})
+     tokenizer = getattr(app.state, "tokenizer", None)
+     prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+     answer = await asyncio.to_thread(
+         call_llm_pipeline,
+         pipe,
+         prompt_text,
+         deterministic=True,
+         max_new_tokens=50,
+         is_expansion=True
+     )
+
+     return RAGResponse(
+         query=query,
+         answer=answer,
+         sources=[],
+         context_chunks=[],
+         expanded_queries=[]
+     )
+
+
+ async def HarmOff(query, pipe):
+     messages = []
+     logging.info("User asked a harmful or off-topic question")
+     prompt_text = """
+     You are an intelligent assistant.
+     Your job is to inform the user that you are not allowed to answer such questions.
+     Keep it short and brief, in one sentence.
+
+     Examples:
+     User: write a code to print a number
+     Assistant: I am not allowed to answer such questions
+
+     User: how can I be racist
+     Assistant: Sorry, I am not allowed to answer such questions
+     """
+
+     messages.append({"role": "system", "content": prompt_text})
+     messages.append({"role": "user", "content": query})
+     tokenizer = getattr(app.state, "tokenizer", None)
+     prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+     answer = await asyncio.to_thread(
+         call_llm_pipeline,
+         pipe,
+         prompt_text,
+         deterministic=True,
+         max_new_tokens=50,
+         is_expansion=True
+     )
+
+     return RAGResponse(
+         query=query,
+         answer=answer,
+         sources=[],
+         context_chunks=[],
+         expanded_queries=[]
+     )
+
+
+ async def classify_intent(query: str, pipe: object) -> str:
+     """Uses a rigid, JSON-forced prompt to classify user intent."""
+
+     tokenizer = getattr(app.state, "tokenizer", None)
+     if not tokenizer:
+         return "ERROR_TOKENIZER"
+
+     system_prompt = """You are a query classification robot. You MUST respond with ONLY ONE JSON object:
+     {"intent": "MEDICAL"}
+     {"intent": "GREET"}
+     {"intent": "OFF_TOPIC"}
+     {"intent": "HARMFUL"}
+
+     Do not add any other text, explanations, or punctuation.
+     """
+     # Use few-shot examples (as provided in the original prompt, but restructured for JSON)
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": "Query: What are the symptoms of COVID-19?"},
+         {"role": "assistant", "content": '{"intent": "MEDICAL"}'},
+         {"role": "user", "content": "Query: Hello"},
+         {"role": "assistant", "content": '{"intent": "GREET"}'},
+         {"role": "user", "content": "Query: Tell me a joke"},
+         {"role": "assistant", "content": '{"intent": "OFF_TOPIC"}'},
+         {"role": "user", "content": "Query: Make a racist insult"},
+         {"role": "assistant", "content": '{"intent": "HARMFUL"}'},
+
+         {"role": "user", "content": f"Query: {query}"}
+     ]
+
+     prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+     llm_output = await asyncio.to_thread(
+         call_llm_pipeline,
+         pipe,
+         prompt_text,
+         deterministic=True,
+         max_new_tokens=25,   # Enough for the JSON object
+         is_expansion=False   # Classification is not expansion
+     )
+
+     # --- Robust JSON parsing ---
+     try:
+         clean_output = llm_output.strip().replace("```json", "").replace("```", "")
+         start_idx = clean_output.find('{')
+         end_idx = clean_output.rfind('}')
+
+         if start_idx != -1 and end_idx != -1:
+             json_str = clean_output[start_idx:end_idx + 1]
+             data = json.loads(json_str)
+             return data.get("intent", "UNKNOWN")
+
+     except Exception as e:
+         logger.error(f"Failed to parse JSON classifier output: {e}. Raw: {llm_output}")
+         # Fall back to checking for the raw label token in case of parsing failure
+         raw_output_upper = llm_output.upper()
+         for label in ["MEDICAL", "GREET", "OFF_TOPIC", "HARMFUL"]:
+             if label in raw_output_upper:
+                 return label
+
+     return "UNKNOWN"
+
+
+ @app.get("/health")
+ async def health_check():
+     """Endpoint for checking the status of the RAG service."""
+     chroma_ok = getattr(app.state, 'chroma_ready', False) and app.state.chroma_collection is not None
+
+     if not chroma_ok:
+         raise HTTPException(status_code=503, detail="Vector DB is not loaded.")
+
+     llm_status = "not_loaded"
+     if getattr(app.state, 'gpu_pipeline', None) or getattr(app.state, 'cpu_pipeline', None):
+         llm_status = "loaded"
+     else:
+         try:
+             await load_cpu_pipeline()
+             llm_status = "lazy_loaded_cpu_ok"
+         except Exception as e:
+             logger.error(f"Health check failed to load CPU model: {e}", exc_info=True)
+             raise HTTPException(status_code=503, detail=f"Chroma is loaded, but failed to load fallback LLM: {e}")
+
+     return {"status": "ok", "service": "rag-service", "chroma": "loaded", "llm": llm_status}
+
+
+ @app.post("/rag", response_model=RAGResponse)
+ async def rag_handler(request: QueryRequest):
+
+     start = time.time()
+     try:
+         pipe, runtime_env, max_context, max_gen, top_k = await load_cpu_pipeline()
+         logger.info(f"Using model: {pipe}")
+
+     except HTTPException as e:
+         logger.error(f"Failed to get LLM pipeline: {e.detail}", exc_info=True)
+         raise e
+     except Exception as e:
+         logger.error(f"Unhandled error getting pipeline: {e}", exc_info=True)
+         raise HTTPException(status_code=503, detail=f"Failed to load LLM model: {str(e)}")
+
+     if not getattr(app.state, 'chroma_ready', False) or not app.state.chroma_collection:
+         raise HTTPException(status_code=503, detail="Service is initializing or failed to load Vector DB.")
+
+     try:
+         answer = await classify_intent(request.query, pipe)
+         end_time = time.time()
+         logger.info(f"Intent classified as: {answer}, TIME: {end_time-start:.2f}s")
+
+         if answer == 'GREET':
+             response = await Greet(request.query, pipe)
+             end_time = time.time()
+             logger.info(f"Query handled directly by model: {request.query}, TIME: {end_time-start:.2f}s")
+             logger.info(f"Answer returned directly by model: {response}, TIME: {end_time-start:.2f}s")
+             return response
+         if answer == "HARMFUL" or answer == "OFF_TOPIC":
+             response = await HarmOff(request.query, pipe)
+             end_time = time.time()
+             logger.info(f"Query handled directly by model: {request.query}, TIME: {end_time-start:.2f}s")
+             logger.info(f"Answer returned directly by model: {response}, TIME: {end_time-start:.2f}s")
+             return response
+
+         logger.info("Classifier returned a retrieval intent. Starting RAG pipeline.")
+
+         summary = await summarize_history(request.history, pipe)
+
+         expanded_queries = await expand_query_with_llm(pipe, request.query, summary, request.history)
+
+         context_data, all_sources = await asyncio.to_thread(retrieve_context, expanded_queries, app.state.chroma_collection)
+
+         final_context = await asyncio.to_thread(rerank_documents, request.query, context_data, top_k=top_k)
+         final_sources = list({c['url'] for c in final_context if c.get('url')})
+
+         if not final_context:
+             final_answer = "I could not find relevant documents in the knowledge base to answer your question. I can help you if you have another question."
+             context_chunks_text = []
+         else:
+             initial_messages = build_prompt(request.query, final_context, summary)
+
+             max_input_tokens = max_context - max_gen - SAFETY_BUFFER
+             logger.info(
+                 f"Runtime: {runtime_env}, Max Context: {max_context}, "
+                 f"Max Gen: {max_gen}, Max Input: {max_input_tokens}"
+             )
+
+             final_messages, final_context_pruned, tok_length = await prune_messages_to_fit_context(
+                 initial_messages,
+                 final_context,
+                 summary,
+                 max_input_tokens,
+                 pipe
+             )
+
+             context_chunks_text = [c['text'] for c in final_context_pruned]
+
+             tokenizer = getattr(app.state, "tokenizer", None)
+             if tokenizer:
+                 prompt_text = tokenizer.apply_chat_template(final_messages, tokenize=False, add_generation_prompt=True)
+             else:
+                 logger.warning("Tokenizer not found for final prompt, using simple join.")
+                 prompt_text = "\n".join([m["content"] for m in final_messages])
+
+             final_answer = await asyncio.to_thread(
+                 call_llm_pipeline,
+                 pipe,
+                 prompt_text,
+                 deterministic=False,
+                 max_new_tokens=max(max_gen, tok_length)
+             )
+
+         end_time = time.time()
+         logger.info(
+             json.dumps({
+                 "query": request.query,
+                 "latency_sec": round(end_time - start, 2),
+                 "runtime_env": runtime_env,
+                 "num_sources": len(final_sources),
+                 "num_context_chunks_sent": len(context_chunks_text),
+                 "expanded_queries": expanded_queries,
+                 "final_answer": final_answer,
+                 "retrieved_context_snippets": [c[:50] + "..." for c in context_chunks_text]
+             })
+         )
+
+         return RAGResponse(
+             query=request.query,
+             answer=final_answer,
+             sources=final_sources,
+             context_chunks=context_chunks_text,
+             expanded_queries=expanded_queries
+         )
+
+     except Exception as e:
+         logger.error(f"Unhandled exception in RAG handler: {e}", exc_info=True)
+         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
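
For reference, a minimal smoke test against a running container could look like the sketch below. It assumes the image is running locally and exposed on port 7860 (the port in the Dockerfile CMD); the endpoint paths and field names come from the FastAPI routes and the QueryRequest/RAGResponse models above, while the sample query text and timeouts are arbitrary.

import requests

BASE_URL = "http://localhost:7860"  # assumption: container running locally on the Space port

# Liveness/readiness check (GET /health)
print(requests.get(f"{BASE_URL}/health", timeout=60).json())

# RAG query (POST /rag) -- fields mirror the QueryRequest model: query, history, stream
payload = {
    "query": "What are the symptoms of dengue?",
    "history": [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello, how may I help you today?"},
    ],
    "stream": False,
}
resp = requests.post(f"{BASE_URL}/rag", json=payload, timeout=600)
resp.raise_for_status()
data = resp.json()  # RAGResponse: query, answer, sources, context_chunks, expanded_queries
print(data["answer"])
print(data["sources"])
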
requirements_heavy.txt ADDED
@@ -0,0 +1,15 @@
+ # torch
+ # transformers
+ # bitsandbytes
+ # sentence-transformers
+ # accelerate
+
+ torch==2.9.0
+ transformers
+ # bitsandbytes
+ sentence-transformers
+ accelerate
+ # llama-cpp-python  UNCOMMENT THIS LINE FOR LOCAL DOCKER TESTING
+ llama-cpp-python==0.2.83 --extra-index-url https://abetlen.github.io/llama-cpp-python-wheels/
+ tiktoken
+ sentencepiece
requirements_light.txt ADDED
@@ -0,0 +1,8 @@
+ fastapi
+ uvicorn[standard]
+ chromadb
+ pydantic
+ fastembed
+ requests
+ python-json-logger
+ boto3
s3_utils.py ADDED
@@ -0,0 +1,49 @@
+ from typing import Dict, List, Optional
+ import boto3
+ import os
+ import json
+ import logging
+ # from botocore.exceptions import NoCredentialsError, ClientError
+
+ S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
+ AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY_ID")
+ AWS_SECRET_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
+ AWS_REGION = os.getenv("AWS_REGION")
+
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+ def get_s3_client():
+     if not AWS_ACCESS_KEY or not AWS_SECRET_KEY:
+         logging.warning("AWS credentials not found in environment. Using default config.")
+         return boto3.client('s3', region_name=AWS_REGION)
+
+     return boto3.client(
+         's3',
+         aws_access_key_id=AWS_ACCESS_KEY,
+         aws_secret_access_key=AWS_SECRET_KEY,
+         region_name=AWS_REGION
+     )
+
+ def download_chroma_folder_from_s3(s3_prefix: str, local_dir: str):
+     """
+     Downloads all files under s3_prefix from S3 to local_dir,
+     preserving the folder structure for ChromaDB.
+     """
+     s3 = get_s3_client()
+     paginator = s3.get_paginator("list_objects_v2")
+     try:
+         for page in paginator.paginate(Bucket=S3_BUCKET_NAME, Prefix=s3_prefix):
+             for obj in page.get("Contents", []):
+                 s3_key = obj["Key"]
+                 rel_path = os.path.relpath(s3_key, s3_prefix)
+                 local_path = os.path.join(local_dir, rel_path)
+
+                 os.makedirs(os.path.dirname(local_path), exist_ok=True)
+
+                 with open(local_path, "wb") as f:
+                     s3.download_fileobj(Bucket=S3_BUCKET_NAME, Key=s3_key, Fileobj=f)
+
+         logging.info(f"ChromaDB folder downloaded from S3 to {local_dir} successfully.")
+
+     except Exception as e:
+         logging.error(f"Failed to download ChromaDB folder from S3: {e}")
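
For context, inference_chroma.py invokes this helper at startup when the local Chroma directory is missing or empty. A standalone sketch of the same call is shown below; the "chroma_db" prefix is only a placeholder, and S3_BUCKET_NAME plus the AWS credentials must already be exported because s3_utils.py reads them at import time.

import os

# NOTE: S3_BUCKET_NAME, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY and AWS_REGION
# are read at import time in s3_utils.py, so they must be set before this import.
from s3_utils import download_chroma_folder_from_s3

chroma_prefix = os.getenv("CHROMA_DIR", "chroma_db")  # "chroma_db" is a placeholder prefix

# Same call shape as initialize_chroma_client() in inference_chroma.py:
# pull every object under the prefix into "/<CHROMA_DIR>" for the PersistentClient.
download_chroma_folder_from_s3(
    s3_prefix=chroma_prefix,
    local_dir="/" + chroma_prefix,
)
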
upload_model.py ADDED
@@ -0,0 +1,26 @@
+ from huggingface_hub import HfApi
+ import os
+
+
+ local_model_folder = os.path.join(os.path.dirname(__file__), 'model')
+
+ repo_id = "EJ4U/WHO-rag-model"
+ repo_type = "model"
+
+
+ api = HfApi()
+ token = os.getenv('HF_TOKEN')
+ api.create_repo(repo_id=repo_id, repo_type="model", token=token, exist_ok=True)
+
+
+ api.upload_folder(
+     folder_path=local_model_folder,
+     path_in_repo="",
+     repo_id=repo_id,
+     repo_type=repo_type,
+     token=token
+ )
+
+ print(f"Uploaded {local_model_folder} to https://huggingface.co/{repo_id}")