GitHub Actions committed on
Commit 7e0bf54 · 1 parent: cb7edc9

Deploy Mon Nov 24 21:18:16 UTC 2025

README.md CHANGED
@@ -7,4 +7,532 @@ sdk: docker
7
  pinned: false
8
  license: mit
9
  app_port: 7860
10
- ---
7
  pinned: false
8
  license: mit
9
  app_port: 7860
10
+ ---
11
+
12
+
13
+ # @app.function(
14
+ # image=rag_image,
15
+ # mounts=[download_mount],
16
+ # gpu="T4",
17
+ # secrets=[Secret.from_name("aws-credentials"), Secret.from_name("chromadb")]
18
+ # )
19
+ # @web_endpoint(method="POST", path="/rag", timeout=300)
20
+ # async def rag_endpoint(request_data: Dict[str, Any]):
21
+
22
+ # if STATE.gpu_pipeline is None:
23
+ # logger.info("Starting Modal function: Lazy-loading LLM, Chroma, and encoders...")
24
+ # try:
25
+ # client = await asyncio.to_thread(initialize_chroma_client)
26
+ # STATE.chroma_collection = client.get_collection(name=CHROMA_COLLECTION)
27
+ # STATE.cache_collection = client.get_or_create_collection(name=CHROMA_CACHE_COLLECTION)
28
+ # STATE.chroma_ready = STATE.chroma_collection is not None
29
+ # logger.info(f"Loaded collection: {CHROMA_COLLECTION} (Documents: {STATE.chroma_collection.count() if STATE.chroma_collection else 0})")
30
+
31
+ # STATE.gpu_pipeline = await asyncio.to_thread(
32
+ # initialize_llm_pipeline, LLM_MODEL_GPU_ID, DEVICE
33
+ # )
34
+ # STATE.tokenizer = STATE.gpu_pipeline.tokenizer
35
+
36
+ # STATE.cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL, device=DEVICE)
37
+ # STATE.embedding_model = TextEmbedding(model_name="BAAI/bge-small-en-v1.5")
38
+ # _ = list(STATE.embedding_model.embed(["warmup"]))
39
+
40
+ # logger.info("All RAG components (GPU LLM, Chroma, Encoders) loaded successfully.")
41
+
42
+ # except Exception as e:
43
+ # logger.error(f"FATAL: Error during Modal startup: {e}", exc_info=True)
44
+ # raise HTTPException(status_code=503, detail=f"Service initialization failed: {str(e)}")
45
+
46
+ # try:
47
+ # request = QueryRequest(**request_data)
48
+ # except Exception as e:
49
+ # raise HTTPException(status_code=400, detail=f"Invalid request format: {str(e)}")
50
+
51
+ # start = time.time()
52
+ # pipe = STATE.gpu_pipeline
53
+ # runtime_env = "gpu_modal"
54
+ # max_context = LLAMA_3_CONTEXT_WINDOW
55
+ # max_gen = MAX_NEW_TOKENS_GPU
56
+ # top_k = RETRIEVE_TOP_K_GPU
57
+
58
+ # try:
59
+ # intent = await classify_intent(request.query, pipe)
60
+ # logger.info(f"Intent classified as: {intent}")
61
+
62
+ # if intent == 'GREET':
63
+ # response = await Greet(request.query, pipe)
64
+
65
+ # elif intent in ["HARMFUL", "OFF_TOPIC"]:
66
+ # response = await HarmOff(request.query, pipe)
67
+
68
+ # else:
69
+ # logger.info("Classifier returned RETRIEVE. Starting RAG pipeline.")
70
+
71
+ # summary = await summarize_history(request.history, pipe)
72
+ # expanded_queries = await expand_query_with_llm(pipe, request.query, summary, request.history)
73
+
74
+ # context_data, all_sources = await asyncio.to_thread(retrieve_context, expanded_queries, STATE.chroma_collection)
75
+ # final_context = await asyncio.to_thread(rerank_documents, request.query, context_data, top_k=top_k)
76
+ # final_sources = list({c['url'] for c in final_context if c.get('url')})
77
+
78
+ # if not final_context:
79
+ # final_answer = "I could not find relevant documents in the knowledge base to answer your question. I can help you if you have another question."
80
+ # context_chunks_text = []
81
+ # else:
82
+ # initial_messages = build_prompt(request.query, final_context, summary)
83
+ # max_input_tokens = max_context - max_gen - SAFETY_BUFFER
84
+
85
+ # final_messages, final_context_pruned, tok_length = await prune_messages_to_fit_context(
86
+ # initial_messages,
87
+ # final_context,
88
+ # summary,
89
+ # max_input_tokens,
90
+ # pipe
91
+ # )
92
+
93
+ # context_chunks_text = [c['text'] for c in final_context_pruned]
94
+
95
+ # prompt_text = STATE.tokenizer.apply_chat_template(final_messages, tokenize=False, add_generation_prompt=True)
96
+
97
+ # final_answer = await asyncio.to_thread(
98
+ # call_llm_pipeline,
99
+ # pipe,
100
+ # prompt_text,
101
+ # deterministic=False,
102
+ # max_new_tokens=max(max_gen, tok_length)
103
+ # )
104
+
105
+ # response = RAGResponse(
106
+ # query=request.query,
107
+ # answer=final_answer,
108
+ # sources=final_sources,
109
+ # context_chunks=context_chunks_text,
110
+ # expanded_queries=expanded_queries
111
+ # )
112
+
113
+ # end_time = time.time()
114
+ # logger.info(f"Total Latency: {round(end_time - start, 2)}s. Runtime: {runtime_env}")
115
+
116
+ # return response.model_dump()
117
+
118
+ # except Exception as e:
119
+ # logger.error(f"Unhandled exception in RAG handler: {e}", exc_info=True)
120
+ # raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
121
+
122
+
123
+
124
+ # @app.local_entrypoint()
125
+
126
+ # def main():
127
+ # test_request_data = {
128
+ # "query": "What are the common side effects of the latest WHO recommended vaccine?",
129
+ # "history": []
130
+ # }
131
+ # print("--- Running rag_endpoint LOCALLY for quick test ---")
132
+ # try:
133
+ # result = rag_endpoint(test_request_data)
134
+
135
+ # print("\n--- TEST RESPONSE ---")
136
+ # print(f"Answer: {result.get('answer', 'N/A')}")
137
+ # print(f"Sources: {result.get('sources', [])}")
138
+
139
+ # except Exception as e:
140
+ # print(f"\n--- LOCAL EXECUTION FAILED AS EXPECTED (Missing GPU/S3): {e} ---")
141
+ # print("This confirms the Python logic executes, but the remote resources (GPU, S3) are not accessible locally.")
142
+
143
+
144
+ # @app.local_entrypoint()
145
+
146
+ # def main():
147
+ # test_request_data = {
148
+ # "query": "What are the common side effects of the latest WHO recommended vaccine?",
149
+ # "history": []
150
+ # }
151
+ # print("--- Running rag_endpoint LOCALLY for quick test ---")
152
+ # try:
153
+ # result = rag_endpoint(test_request_data)
154
+
155
+ # print("\n--- TEST RESPONSE ---")
156
+ # print(f"Answer: {result.get('answer', 'N/A')}")
157
+ # print(f"Sources: {result.get('sources', [])}")
158
+
159
+ # except Exception as e:
160
+ # print(f"\n--- LOCAL EXECUTION FAILED AS EXPECTED (Missing GPU/S3): {e} ---")
161
+ # print("This confirms the Python logic executes, but the remote resources (GPU, S3) are not accessible locally.")
162
+
163
+
164
+
165
+ # class ModelContainer:
166
+ # def __init__(self):
167
+ # self.gpu_pipeline: Optional[Pipeline] = None
168
+ # self.tokenizer: Optional[AutoTokenizer] = None
169
+ # self.chroma_collection: Optional[Collection] = None
170
+ # self.cache_collection: Optional[Collection] = None
171
+ # self.cross_encoder: Optional[CrossEncoder] = None
172
+ # self.embedding_model: Optional[TextEmbedding] = None
173
+ # self.chroma_ready: bool = False
174
+
175
+ # STATE = ModelContainer()
176
+
177
+
178
+
179
+
180
+
181
+ # def call_llm_pipeline(pipe: Optional[object],
182
+ # prompt_text: str,
183
+ # deterministic: bool = False,
184
+ # max_new_tokens: int = MAX_NEW_TOKENS_GPU,
185
+ # is_expansion: bool = False
186
+ # ) -> str:
187
+
188
+ # if pipe is None or not isinstance(pipe, Pipeline):
189
+ # raise HTTPException(status_code=503, detail="LLM pipeline is not available.")
190
+
191
+ # temp = 0.0 if deterministic else 0.1 if is_expansion else 0.6
192
+
193
+ # try:
194
+ # with torch.inference_mode():
195
+ # outputs = pipe(
196
+ # prompt_text,
197
+ # max_new_tokens=max_new_tokens,
198
+ # temperature=temp if temp > 0.0 else None,
199
+ # do_sample=True if temp > 0.0 else False,
200
+ # pad_token_id=pipe.tokenizer.eos_token_id,
201
+ # return_full_text=False
202
+ # )
203
+
204
+ # text = outputs[0]['generated_text'].strip()
205
+ # for token in ['<|eot_id|>', '<|end_of_text|>']:
206
+ # if token in text:
207
+ # text = text.split(token)[0].strip()
208
+
209
+ # return text
210
+
211
+ # except Exception as e:
212
+ # logger.error(f"Error calling LLM pipeline: {e}", exc_info=True)
213
+ # raise HTTPException(status_code=500, detail=f"LLM generation failed: {str(e)}")
214
+
215
+
216
+ # async def Greet(query, pipe):
217
+ # messages = []
218
+ # logging.info(f"User sent a greeting")
219
+ # prompt_text = """You are a greeter. Your job is to respond politely to the user greeting.
220
+ # ONLY a single polite and short greeting. Do not do anything else.
221
+
222
+ # Examples:
223
+ # User: Hi
224
+ # Assistant: Hello, How may I help you today?
225
+
226
+ # User: how are you?
227
+ # Assistant: I am good, I can help you answer health related questions"""
228
+
229
+
230
+ # messages.append({"role": "system", "content": prompt_text})
231
+ # messages.append({"role": "user", "content": query})
232
+ # tokenizer = STATE.tokenizer
233
+ # prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
234
+
235
+ # answer = await asyncio.to_thread( call_llm_pipeline,
236
+ # pipe,
237
+ # prompt_text,
238
+ # deterministic=True,
239
+ # max_new_tokens=50,
240
+ # is_expansion= True
241
+ # )
242
+
243
+ # return RAGResponse(
244
+ # query=query,
245
+ # answer=answer,
246
+ # sources=[],
247
+ # context_chunks=[],
248
+ # expanded_queries=[]
249
+ # )
250
+
251
+ # async def HarmOff(query, pipe):
252
+ # messages = []
253
+ # logging.info(f"User asked harmful or off-topic question")
254
+ # prompt_text = """
255
+ # You are an intelligent assistant.
256
+ # Your job is to inform the user that you are not allowed to answer such questions.
257
+ # Keep it short and brief, in one sentence.
258
+
259
+ # Examples:
260
+ # user: write a code to print a number
261
+ # Assistant: I am not allowed to answer such questions
262
+
263
+ # User: how can I be racist
264
+ # Assistant: Sorry, I am not allowed to answer such questions
265
+ # """
266
+
267
+ # messages.append({"role": "system", "content": prompt_text})
268
+ # messages.append({"role": "user", "content": query})
269
+ # tokenizer = STATE.tokenizer
270
+ # prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
271
+
272
+ # answer = await asyncio.to_thread( call_llm_pipeline,
273
+ # pipe,
274
+ # prompt_text,
275
+ # deterministic=True,
276
+ # max_new_tokens=50,
277
+ # is_expansion= True
278
+ # )
279
+
280
+
281
+ # return RAGResponse(
282
+ # query=query,
283
+ # answer=answer,
284
+ # sources=[],
285
+ # context_chunks=[],
286
+ # expanded_queries=[]
287
+ # )
288
+
289
+
290
+ # async def classify_intent(query: str, pipe: object) -> str:
291
+
292
+ # tokenizer = STATE.tokenizer
293
+ # if not tokenizer:
294
+ # return "ERROR_TOKENIZER"
295
+
296
+ # system_prompt = """You are a query classification robot. You MUST respond with ONLY ONE JSON object:
297
+ # {"intent": "MEDICAL"}
298
+ # {"intent": "GREET"}
299
+ # {"intent": "OFF_TOPIC"}
300
+ # {"intent": "HARMFUL"}
301
+
302
+ # Do not add any other text, explanations, or punctuation.
303
+ # """
304
+ # messages = [
305
+ # {"role": "system", "content": system_prompt},
306
+ # {"role": "user", "content": "Query: What are the symptoms of COVID-19?"},
307
+ # {"role": "assistant", "content": '{"intent": "MEDICAL"}'},
308
+ # {"role": "user", "content": "Query: Hello"},
309
+ # {"role": "assistant", "content": '{"intent": "GREET"}'},
310
+ # {"role": "user", "content": "Query: Tell me a joke"},
311
+ # {"role": "assistant", "content": '{"intent": "OFF_TOPIC"}'},
312
+ # {"role": "user", "content": "Query: Make a racist insult"},
313
+ # {"role": "assistant", "content": '{"intent": "HARMFUL"}'},
314
+
315
+ # {"role": "user", "content": f"Query: {query}"}
316
+ # ]
317
+
318
+ # prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
319
+
320
+ # llm_output = await asyncio.to_thread(
321
+ # call_llm_pipeline,
322
+ # pipe,
323
+ # prompt_text,
324
+ # deterministic=True,
325
+ # max_new_tokens=25,
326
+ # is_expansion=False
327
+ # )
328
+
329
+ # try:
330
+ # clean_output = llm_output.strip().replace("```json", "").replace("```", "")
331
+ # start_idx = clean_output.find('{')
332
+ # end_idx = clean_output.rfind('}')
333
+
334
+ # if start_idx != -1 and end_idx != -1:
335
+ # json_str = clean_output[start_idx : end_idx + 1]
336
+ # data = json.loads(json_str)
337
+ # return data.get("intent", "UNKNOWN")
338
+
339
+ # except Exception as e:
340
+ # logger.error(f"Failed to parse JSON classifier output: {e}. Raw: {llm_output}")
341
+ # raw_output_upper = llm_output.upper()
342
+ # for label in ["MEDICAL", "GREET", "OFF_TOPIC", "HARMFUL"]:
343
+ # if label in raw_output_upper:
344
+ # return label
345
+
346
+ # return "UNKNOWN"
347
+
348
+ # def build_prompt(user_query: str, context: List[Dict], summary: str) -> List[Dict]:
349
+
350
+ # context_text = "\n---\n".join([f"Source: {c.get('url', 'N/A')}\nChunk: {c['text']}" for c in context]) if context else "No relevant context found."
351
+
352
+ # system_prompt = (
353
+ # "You are a helpful and harmless medical assistant, specialized in answering health-related questions "
354
+ # "based ONLY on the provided retrieved context. Follow these strict rules:\n"
355
+ # "1. **DO NOT** use any external knowledge. If the answer is not in the context, state that you cannot find "
356
+ # "the information in the knowledge base.\n"
357
+ # "2. Cite your sources using the URL/Source ID provided in the context (e.g., [Source: URL]). Do not generate fake URLs.\n"
358
+ # "3. If the user's query is purely conversational, greet them or respond appropriately without referencing the context.\n"
359
+ # )
360
+
361
+ # messages = [
362
+ # {"role": "system", "content": system_prompt},
363
+ # {"role": "system", "content": f"PREVIOUS CONVERSATION SUMMARY: {summary}" if summary else "PREVIOUS CONVERSATION SUMMARY: None"},
364
+ # {"role": "system", "content": f"RETRIEVED CONTEXT:\n{context_text}"},
365
+ # {"role": "user", "content": user_query}
366
+ # ]
367
+ # return messages
368
+
369
+ # async def prune_messages_to_fit_context(messages: List[Dict],
370
+ # final_context: List[Dict],
371
+ # summary: str,
372
+ # max_input_tokens: int,
373
+ # pipe: Optional[object]
374
+ # ) -> Tuple[List[Dict], List[Dict], int]:
375
+
376
+ # tokenizer = STATE.tokenizer
377
+ # if not tokenizer:
378
+ # raise ValueError("Tokenizer not initialized for pruning.")
379
+
380
+ # def get_token_count(msg_list: List[Dict]) -> int:
381
+ # prompt_text = tokenizer.apply_chat_template(msg_list, tokenize=False, add_generation_prompt=True)
382
+ # return len(tokenizer.encode(prompt_text, add_special_tokens=False))
383
+
384
+ # current_context = final_context[:]
385
+ # current_summary = summary
386
+ # base_user_query = messages[-1]["content"]
387
+ # current_messages = build_prompt(base_user_query, current_context, current_summary)
388
+ # token_count = get_token_count(current_messages)
389
+
390
+ # if token_count <= max_input_tokens:
391
+ # tok_length = max_input_tokens - token_count
392
+ # return current_messages, current_context, tok_length
393
+
394
+ # logger.warning(f"Initial token count ({token_count}) exceeds max input ({max_input_tokens}). Starting pruning.")
395
+
396
+ # while token_count > max_input_tokens and current_context:
397
+ # current_context.pop()
398
+ # current_messages = build_prompt(base_user_query, current_context, current_summary)
399
+ # token_count = get_token_count(current_messages)
400
+
401
+ # if token_count <= max_input_tokens:
402
+ # tok_length = max_input_tokens - token_count
403
+ # return current_messages, current_context, tok_length
404
+
405
+ # if current_summary:
406
+ # logger.warning("Clearing conversation summary as last-ditch effort.")
407
+ # current_summary = ""
408
+ # current_messages = build_prompt(base_user_query, current_context, current_summary)
409
+ # token_count = get_token_count(current_messages)
410
+
411
+ # if token_count <= max_input_tokens:
412
+ # tok_length = max_input_tokens - token_count
413
+ # return current_messages, current_context, tok_length
414
+
415
+ # if token_count > max_input_tokens:
416
+ # logger.error(f"Pruning failed. Even minimal prompt exceeds token limit: {token_count}. Returning empty context.")
417
+ # current_context = []
418
+ # current_messages = build_prompt(base_user_query, current_context, "")
419
+ # token_count = get_token_count(current_messages)
420
+ # tok_length = max_input_tokens - token_count if token_count < max_input_tokens else 0
421
+
422
+ # return current_messages, current_context, tok_length
423
+
424
+ # async def expand_query_with_llm(pipe: Optional[object],
425
+ # user_query: str,
426
+ # summary: str,
427
+ # history: Optional[List[HistoryMessage]]
428
+ # ) -> List[str]:
429
+
430
+ # tokenizer = STATE.tokenizer
431
+ # if not history or len(history) == 0:
432
+ # expansion_prompt = f"""You are a specialized query expansion engine. Generate 3 alternative, highly effective search queries to find documents relevant to the User Query. Only output the queries, one per line. Do not include the original query or any explanations.
433
+
434
+ # User Query: What are the symptoms of COVID-19?
435
+ # Expanded Queries:
436
+ # signs of coronavirus infection
437
+ # how to recognize COVID
438
+ # symptoms of SARS-CoV-2
439
+
440
+ # User Query: {user_query}
441
+ # Expanded Queries:
442
+ # """
443
+ # else:
444
+ # history_text = "\n".join([f"{h.role}: {h.content}" for h in history])
445
+ # expansion_prompt = f"""You are a helpful assistant. Given the conversation summary and history below, rewrite the user's latest query into a standalone, complete, and specific search query that incorporates the context of the conversation. Output only the single rewritten query.
446
+
447
+ # Conversation Summary: {summary}
448
+ # Conversation History:
449
+ # {history_text}
450
+
451
+ # User's Latest Query: {user_query}
452
+ # Rewritten Search Query:
453
+ # """
454
+
455
+ # messages = [{"role": "system", "content": expansion_prompt}]
456
+ # prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
457
+
458
+ # llm_output = await asyncio.to_thread(
459
+ # call_llm_pipeline, pipe, prompt_text, deterministic=True, is_expansion=True, max_new_tokens=150
460
+ # )
461
+
462
+ # if not history or len(history) == 0:
463
+ # expanded_queries = [q.strip() for q in llm_output.split('\n') if q.strip()]
464
+ # else:
465
+ # expanded_queries = [llm_output.strip()]
466
+
467
+ # expanded_queries.append(user_query)
468
+
469
+ # return list(set(q for q in expanded_queries if q))
470
+
471
+ # async def summarize_history(history: List[HistoryMessage], pipe: Optional[object]) -> str:
472
+ # if not history:
473
+ # return ''
474
+
475
+ # tokenizer = STATE.tokenizer
476
+ # history_text = "\n".join([f"{h.role}: {h.content}" for h in history[-8:]])
477
+
478
+ # summarizer_prompt = f"""
479
+ # You are an intelligent agent who summarizes conversations. Your summary should be concise, coherent, and focus on the main topic and specific entities discussed, which are likely health-related.
480
+
481
+ # CONVERSATION HISTORY:
482
+ # {history_text}
483
+
484
+ # CONCISE SUMMARY:
485
+ # """
486
+ # messages = [{"role": "system", "content": summarizer_prompt}]
487
+
488
+ # prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
489
+
490
+ # summary = await asyncio.to_thread(
491
+ # call_llm_pipeline,
492
+ # pipe,
493
+ # prompt_text,
494
+ # deterministic=True,
495
+ # max_new_tokens=150,
496
+ # is_expansion=False
497
+ # )
498
+ # return summary
499
+
500
+ # def retrieve_context(queries: List[str], collection: Collection) -> Tuple[List[Dict], List[str]]:
501
+
502
+ # if STATE.embedding_model is None:
503
+ # raise HTTPException(status_code=503, detail="Embedding model not loaded.")
504
+
505
+ # embeddings_list = [[float(x) for x in emb] for emb in STATE.embedding_model.embed(queries, batch_size=8)]
506
+
507
+ # results = collection.query(
508
+ # query_embeddings=embeddings_list,
509
+ # n_results=max(10, RETRIEVE_TOP_K_GPU * len(queries)),
510
+ # include=['documents', 'metadatas']
511
+ # )
512
+
513
+ # context_data = []
514
+ # source_urls = set()
515
+
516
+ # if results.get("documents") and results.get("metadatas"):
517
+ # for docs_list, metadatas_list in zip(results["documents"], results["metadatas"]):
518
+ # for doc, metadata in zip(docs_list, metadatas_list):
519
+ # if doc and metadata:
520
+ # context_data.append({'text': doc, 'url': metadata.get('source')})
521
+ # if metadata.get("source"):
522
+ # source_urls.add(metadata.get('source'))
523
+
524
+ # return context_data, list(source_urls)
525
+
526
+ # def rerank_documents(query: str, context: List[Dict], top_k: int) -> List[Dict]:
527
+ # if not context or STATE.cross_encoder is None:
528
+ # return context[:top_k]
529
+
530
+ # pairs = [(query, doc['text']) for doc in context]
531
+
532
+ # scores = STATE.cross_encoder.predict(pairs)
533
+
534
+ # for doc, score in zip(context, scores):
535
+ # doc['score'] = float(score)
536
+
537
+ # ranked_docs = sorted(context, key=lambda x: x['score'], reverse=True)
538
+ # return ranked_docs[:top_k]
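A minimal standalone sketch of the JSON-parsing step used by the commented-out `classify_intent` above: slice the first `{...}` block out of the raw LLM output, validate the label, and fall back to keyword matching. The helper name `extract_intent` is illustrative only and is not part of this commit.

```python
import json

INTENT_LABELS = ("MEDICAL", "GREET", "OFF_TOPIC", "HARMFUL")

def extract_intent(llm_output: str) -> str:
    """Pull the intent label out of raw LLM text, mirroring the logic above."""
    clean = llm_output.replace("`", "").strip()   # drop any markdown code-fence characters
    start, end = clean.find("{"), clean.rfind("}")
    if start != -1 and end != -1:
        try:
            intent = json.loads(clean[start:end + 1]).get("intent", "UNKNOWN")
            if intent in INTENT_LABELS:
                return intent
        except json.JSONDecodeError:
            pass
    upper = llm_output.upper()
    for label in INTENT_LABELS:                   # same keyword-fallback order as the code above
        if label in upper:
            return label
    return "UNKNOWN"

assert extract_intent('{"intent": "MEDICAL"}') == "MEDICAL"
assert extract_intent("Sure thing! GREET") == "GREET"
```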
__pycache__/modal.cpython-312.pyc ADDED
Binary file (31 kB).
 
__pycache__/modal_rag.cpython-312.pyc ADDED
Binary file (35.4 kB).
 
__pycache__/s3_utils.cpython-311.pyc ADDED
Binary file (3.31 kB).
 
__pycache__/s3_utils.cpython-312.pyc ADDED
Binary file (2.88 kB).
 
inference_chroma.py CHANGED
@@ -211,7 +211,7 @@ async def load_cpu_pipeline() -> Tuple[Optional[object], str, int, int, int]:
211
  initialize_cpp_llm,
212
  LLAMA_GGUF_PATH,
213
  TINYLAMA_CONTEXT_WINDOW,
214
- max(1, os.cpu_count() - 1)
215
  )
216
  logger.info("TinyLlama GGUF loaded successfully.")
217
  return app.state.cpu_pipeline, "cpu_gguf", TINYLAMA_CONTEXT_WINDOW, MAX_NEW_TOKENS_CPU, RETRIEVE_TOP_K_CPU
@@ -919,8 +919,6 @@ async def health_check():
919
 
920
  @app.post("/rag", response_model=RAGResponse)
921
  async def rag_handler(request: QueryRequest):
922
-
923
-
924
  start = time.time()
925
  try:
926
  pipe, runtime_env, max_context, max_gen, top_k = await load_cpu_pipeline()
 
211
  initialize_cpp_llm,
212
  LLAMA_GGUF_PATH,
213
  TINYLAMA_CONTEXT_WINDOW,
214
+ max(1, os.cpu_count())
215
  )
216
  logger.info("TinyLlama GGUF loaded successfully.")
217
  return app.state.cpu_pipeline, "cpu_gguf", TINYLAMA_CONTEXT_WINDOW, MAX_NEW_TOKENS_CPU, RETRIEVE_TOP_K_CPU
 
919
 
920
  @app.post("/rag", response_model=RAGResponse)
921
  async def rag_handler(request: QueryRequest):
 
 
922
  start = time.time()
923
  try:
924
  pipe, runtime_env, max_context, max_gen, top_k = await load_cpu_pipeline()
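The hunk above changes the llama.cpp thread count passed to `initialize_cpp_llm` from `max(1, os.cpu_count() - 1)` to `max(1, os.cpu_count())`. Since `os.cpu_count()` can return `None` on some platforms (which makes either expression raise `TypeError`), a guarded variant may be safer; this is a minimal sketch with a hypothetical helper, not code from the commit.

```python
import os

def llm_thread_count(reserve: int = 0) -> int:
    """Choose a worker-thread count for the GGUF model, tolerating os.cpu_count() == None."""
    cpus = os.cpu_count() or 1       # os.cpu_count() may return None in restricted environments
    return max(1, cpus - reserve)    # reserve=1 keeps one core free for the FastAPI event loop

# The old hunk corresponds to llm_thread_count(reserve=1), the new one to llm_thread_count(0).
```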
modal_rag.py ADDED
@@ -0,0 +1,665 @@
1
+ from __future__ import annotations
2
+ import asyncio
3
+ import os
4
+ import json
5
+ import logging
6
+ import time
7
+ from typing import List, Dict, Tuple, Optional, Any, Literal
8
+
9
+ from fastapi import HTTPException
10
+ from pydantic import BaseModel, Field
11
+
12
+ from s3_utils import download_chroma_folder_from_s3
13
+ import torch
14
+ import chromadb
15
+ from chromadb.api import Collection
16
+ from chromadb import PersistentClient
17
+
18
+ from modal import App, Image, Secret, fastapi_endpoint, enter, method
19
+ from dotenv import load_dotenv
20
+
21
+ load_dotenv()
22
+
23
+ logging.basicConfig(level=logging.INFO, format='{"time": "%(asctime)s", "level": "%(levelname)s", "message": "%(message)s"}')
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # LLM_MODEL_GPU_ID = "meta-llama/Llama-3.1-8B-Instruct"
27
+ TINY_MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
28
+ DEVICE = "cuda:0"
29
+ LLAMA_3_CONTEXT_WINDOW = 8192
30
+ SAFETY_BUFFER = 50
31
+
32
+ RETRIEVE_TOP_K_GPU = 8
33
+ MAX_NEW_TOKENS_GPU = 1024
34
+
35
+ CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
36
+
37
+ CHROMA_DIR = os.getenv("CHROMA_DIR")
38
+ CHROMA_DIR_INF = "/" + CHROMA_DIR
39
+ CHROMA_COLLECTION = os.getenv("CHROMA_COLLECTION")
40
+ CHROMA_CACHE_COLLECTION = os.getenv("CHROMA_CACHE_COLLECTION")
41
+
42
+ REQUEST_TIMEOUT_SEC = 1800
43
+
44
+ rag_image = (
45
+ Image.from_registry("nvidia/cuda:12.1.0-base-ubuntu22.04", add_python="3.11")
46
+ .apt_install("git")
47
+ .pip_install_from_requirements("requirements.txt")
48
+ .env({"HF_HOME": "/root/.cache/huggingface/hub"})
49
+ .add_local_python_source("s3_utils", copy=True)
50
+ .add_local_dir(
51
+ local_path="./",
52
+ remote_path="/usr/src/app/",
53
+ ignore=[
54
+ "__pycache__/", "utils/", "Dockerfile", "chroma_db_files/", "model/",
55
+ "hg_login.py", "infer.py", "inference_chroma.py", "initial.py", "README.md",
56
+ "requirements_heavy.txt", "requirements_light.txt", "upload_model.py", ".env",
57
+ ".git/", "*.pyc", ".python-version", "test_*.py", "experiments/", "logs/"
58
+ ],
59
+ copy=True
60
+ )
61
+ )
62
+
63
+ app = App("who-rag-llama3-gpu-api", image=rag_image)
64
+
65
+ class HistoryMessage(BaseModel):
66
+ role: Literal['user', 'assistant']
67
+ content: str
68
+
69
+ class QueryRequest(BaseModel):
70
+ query: str = Field(..., description="The user's latest message.")
71
+ history: List[HistoryMessage] = Field(default_factory=list, description="The previous turns of the conversation.")
72
+ stream: bool = Field(False)
73
+
74
+ class RAGResponse(BaseModel):
75
+ query: str = Field(..., description="The original user query.")
76
+ answer: str = Field(..., description="The final answer generated by the LLM.")
77
+ sources: List[str] = Field(..., description="Unique source URLs used for the answer.")
78
+ context_chunks: List[str] = Field(..., description="The final context chunks (text only) sent to the LLM.")
79
+ expanded_queries: List[str] = Field(..., description="Queries used for retrieval.")
80
+
81
+ @app.cls(
82
+ gpu="T4",
83
+ secrets=[
84
+ Secret.from_name("aws-credentials"),
85
+ Secret.from_name("chromadb"),
86
+ Secret.from_name("huggingface-token")
87
+ ],
88
+ timeout=1080,
89
+ startup_timeout=600,
90
+ memory=32768
91
+ )
92
+ class RagService:
93
+ # gpu_pipeline: Any = None
94
+ # tokenizer: Any = None
95
+ chroma_collection: Optional[Collection] = None
96
+ cache_collection: Optional[Collection] = None
97
+ cross_encoder: Any = None
98
+ embedding_model: Any = None
99
+ intent_pipeline: Any = None
100
+ intent_tokenizer: Any = None
101
+
102
+ @enter()
103
+ def setup(self):
104
+ """Initialize all models once during container startup"""
105
+ import torch
106
+ from sentence_transformers import CrossEncoder
107
+ from fastembed import TextEmbedding
108
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
109
+ from transformers import BitsAndBytesConfig
110
+
111
+ logger.info("Starting Modal Service setup...")
112
+
113
+ try:
114
+ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
115
+ torch.cuda.empty_cache()
116
+
117
+ client = self._initialize_chroma_client()
118
+ self.chroma_collection = client.get_collection(name=CHROMA_COLLECTION)
119
+ self.cache_collection = client.get_or_create_collection(name=CHROMA_CACHE_COLLECTION)
120
+ logger.info(f"Loaded collection: {CHROMA_COLLECTION}")
121
+
122
+ self.embedding_model = TextEmbedding(model_name="BAAI/bge-small-en-v1.5")
123
+ _ = list(self.embedding_model.embed(["warmup"]))
124
+ logger.info("Embedding model loaded")
125
+
126
+ self.cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL, device="cpu")
127
+ logger.info("Cross-encoder loaded")
128
+
129
+ logger.info(f"Loading intent model: {TINY_MODEL_ID}")
130
+ self.intent_pipeline, self.intent_tokenizer = self._initialize_lightweight_pipeline(TINY_MODEL_ID)
131
+ logger.info("Intent model loaded")
132
+
133
+ torch.cuda.empty_cache()
134
+
135
+ # self.gpu_pipeline, self.tokenizer = self._initialize_llm_pipeline(LLM_MODEL_GPU_ID)
136
+ logger.info("Main LLM loaded")
137
+
138
+ logger.info("All RAG components loaded successfully")
139
+
140
+ except Exception as e:
141
+ logger.error(f"Setup failed: {e}", exc_info=True)
142
+ raise RuntimeError(f"Service setup failed: {e}")
143
+
144
+ def _initialize_lightweight_pipeline(self, model_id: str):
145
+ """Initialize lightweight pipeline for intent classification"""
146
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
147
+ from transformers import BitsAndBytesConfig
148
+
149
+ quantization_config = BitsAndBytesConfig(
150
+ load_in_4bit=True,
151
+ bnb_4bit_quant_type="nf4",
152
+ bnb_4bit_compute_dtype=torch.bfloat16,
153
+ bnb_4bit_use_double_quant=True,
154
+ )
155
+
156
+ model = AutoModelForCausalLM.from_pretrained(
157
+ model_id,
158
+ device_map="auto",
159
+ trust_remote_code=True,
160
+ quantization_config=quantization_config,
161
+ dtype=torch.bfloat16
162
+ )
163
+
164
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
165
+
166
+ if not getattr(tokenizer, "chat_template", None):
167
+ tokenizer.chat_template = self._get_chat_template()
168
+
169
+ if tokenizer.pad_token is None:
170
+ tokenizer.pad_token = tokenizer.eos_token
171
+
172
+ pipe = pipeline(
173
+ "text-generation",
174
+ model=model,
175
+ tokenizer=tokenizer,
176
+ device_map="auto",
177
+ dtype=torch.bfloat16
178
+ )
179
+
180
+ return pipe, tokenizer
181
+
182
+ def _initialize_llm_pipeline(self, model_id: str):
183
+ """Initialize the main LLM pipeline"""
184
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
185
+ from transformers import BitsAndBytesConfig
186
+
187
+ quantization_config = BitsAndBytesConfig(
188
+ load_in_4bit=True,
189
+ bnb_4bit_quant_type="nf4",
190
+ bnb_4bit_compute_dtype=torch.bfloat16,
191
+ bnb_4bit_use_double_quant=True,
192
+ )
193
+
194
+ model = AutoModelForCausalLM.from_pretrained(
195
+ model_id,
196
+ device_map="auto",
197
+ trust_remote_code=True,
198
+ quantization_config=quantization_config,
199
+ dtype=torch.bfloat16
200
+ )
201
+
202
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
203
+
204
+ if not getattr(tokenizer, "chat_template", None):
205
+ tokenizer.chat_template = self._get_chat_template()
206
+
207
+ if tokenizer.pad_token is None:
208
+ tokenizer.pad_token = tokenizer.eos_token
209
+
210
+ pipe = pipeline(
211
+ "text-generation",
212
+ model=model,
213
+ tokenizer=tokenizer,
214
+ device_map="auto",
215
+ dtype=torch.bfloat16
216
+ )
217
+
218
+ return pipe, tokenizer
219
+
220
+ @staticmethod
221
+ def _get_chat_template():
222
+ return (
223
+ "{% for message in messages %}"
224
+ "{% if message['role'] == 'system' %}"
225
+ "{{ message['content'] }} "
226
+ "{% elif message['role'] == 'user' %}"
227
+ "{{ '<|start_header_id|>user<|end_header_id|>\\n' + message['content'] + '<|eot_id|>' }} "
228
+ "{% elif message['role'] == 'assistant' %}"
229
+ "{{ '<|start_header_id|>assistant<|end_header_id|>\\n' + message['content'] + '<|eot_id|>' }} "
230
+ "{% endif %}"
231
+ "{% endfor %}"
232
+ "{% if add_generation_prompt %}"
233
+ "{{ '<|start_header_id|>assistant<|end_header_id|>\\n' }} "
234
+ "{% endif %}"
235
+ )
236
+
237
+ @staticmethod
238
+ def _initialize_chroma_client() -> chromadb.PersistentClient:
239
+ logger.info("Starting Chroma client initialization...")
240
+ try:
241
+ if CHROMA_DIR is None:
242
+ raise RuntimeError("CHROMA_DIR environment variable is not set.")
243
+ download_chroma_folder_from_s3(CHROMA_DIR, CHROMA_DIR_INF)
244
+ logger.info(f"Chroma data downloaded from S3 to {CHROMA_DIR_INF}.")
245
+ except Exception as e:
246
+ logger.error(f"Failed to download Chroma index from S3: {e}")
247
+ raise RuntimeError("Chroma index S3 download failed.")
248
+
249
+ try:
250
+ client = PersistentClient(path=CHROMA_DIR_INF, settings=chromadb.Settings(allow_reset=False))
251
+ logger.info("Chroma client initialized successfully.")
252
+ except Exception as e:
253
+ logger.error(f"Failed to load Chroma index from path: {e}")
254
+ raise RuntimeError("Chroma index failed to load.")
255
+ return client
256
+
257
+ @staticmethod
258
+ def _call_llm_pipeline(pipe: Optional[object], prompt_text: str, deterministic: bool = False,
259
+ max_new_tokens: int = MAX_NEW_TOKENS_GPU, is_expansion: bool = False) -> str:
260
+ import torch
261
+ if pipe is None or not hasattr(pipe, "tokenizer"):
262
+ raise HTTPException(status_code=503, detail="LLM pipeline is not available.")
263
+
264
+ temp = 0.0 if deterministic else 0.1 if is_expansion else 0.6
265
+
266
+ try:
267
+ with torch.inference_mode():
268
+ outputs = pipe(
269
+ prompt_text,
270
+ max_new_tokens=max_new_tokens,
271
+ temperature=(temp if temp > 0.0 else None),
272
+ do_sample=True if temp > 0.0 else False,
273
+ pad_token_id=pipe.tokenizer.eos_token_id,
274
+ return_full_text=False
275
+ )
276
+
277
+ if isinstance(outputs, list) and len(outputs) > 0 and isinstance(outputs[0], dict):
278
+ text = outputs[0].get('generated_text', "")
279
+ elif isinstance(outputs, dict):
280
+ text = outputs.get('generated_text', "")
281
+ else:
282
+ text = str(outputs)
283
+
284
+ text = text.strip()
285
+ for token in ['<|eot_id|>', '<|end_of_text|>']:
286
+ if token in text:
287
+ text = text.split(token)[0].strip()
288
+ return text
289
+
290
+ except Exception as e:
291
+ logger.error(f"Error calling LLM pipeline: {e}", exc_info=True)
292
+ raise HTTPException(status_code=500, detail=f"LLM generation failed: {str(e)}")
293
+
294
+ def _build_prompt(self, user_query: str, context: List[Dict], summary: str) -> List[Dict]:
295
+ context_text = "\n---\n".join([f"Source: {c.get('url', 'N/A')}\nChunk: {c['text']}" for c in context]) if context else "No relevant context found."
296
+
297
+ system_prompt = (
298
+ "You are a helpful and harmless medical assistant, specialized in answering health-related questions "
299
+ "based ONLY on the provided retrieved context. Follow these strict rules:\n"
300
+ "1. **DO NOT** use any external knowledge. If the answer is not in the context, state that you cannot find "
301
+ "the information in the knowledge base.\n"
302
+ "2. Cite your sources using the URL/Source ID provided in the context (e.g., [Source: URL]). Do not generate fake URLs.\n"
303
+ "3. If the user's query is purely conversational, greet them or respond appropriately without referencing the context.\n"
304
+ )
305
+
306
+ messages = [
307
+ {"role": "system", "content": system_prompt},
308
+ {"role": "system", "content": f"PREVIOUS CONVERSATION SUMMARY: {summary}" if summary else "PREVIOUS CONVERSATION SUMMARY: None"},
309
+ {"role": "system", "content": f"RETRIEVED CONTEXT:\n{context_text}"},
310
+ {"role": "user", "content": user_query}
311
+ ]
312
+ return messages
313
+
314
+ def _get_token_count(self, msg_list: List[Dict]) -> int:
315
+ if not self.intent_tokenizer:
316
+ return 0
317
+ prompt_text = self.intent_tokenizer.apply_chat_template(msg_list, tokenize=False, add_generation_prompt=True)
318
+ return len(self.intent_tokenizer.encode(prompt_text, add_special_tokens=False))
319
+
320
+ @method()
321
+ async def classify_intent(self, query: str) -> str:
322
+ """Classify query intent using the pre-loaded intent pipeline"""
323
+ if not self.intent_pipeline or not self.intent_tokenizer:
324
+ raise HTTPException(status_code=503, detail="Intent classification model not available")
325
+
326
+ system_prompt = """You are a query classification robot. You MUST respond with ONLY ONE JSON object:
327
+ {"intent": "MEDICAL"}
328
+ {"intent": "GREET"}
329
+ {"intent": "OFF_TOPIC"}
330
+ {"intent": "HARMFUL"}
331
+ """
332
+
333
+ messages = [
334
+ {"role": "system", "content": system_prompt},
335
+ {"role": "user", "content": "Query: What are the symptoms of COVID-19?"},
336
+ {"role": "assistant", "content": '{"intent": "MEDICAL"}'},
337
+ {"role": "user", "content": f"Query: {query}"}
338
+ ]
339
+
340
+ prompt_text = self.intent_tokenizer.apply_chat_template(
341
+ messages, tokenize=False, add_generation_prompt=True
342
+ )
343
+
344
+ try:
345
+ llm_output = await self._run_with_timeout(
346
+ asyncio.to_thread(
347
+ self._call_llm_pipeline,
348
+ self.intent_pipeline,
349
+ prompt_text,
350
+ True, 25, False
351
+ ),
352
+ # timeout_seconds=30,
353
+ timeout_message="Intent classification timed out"
354
+ )
355
+
356
+ clean_output = llm_output.strip().replace("```json", "").replace("```", "")
357
+ start_idx = clean_output.find('{')
358
+ end_idx = clean_output.rfind('}')
359
+ if start_idx != -1 and end_idx != -1:
360
+ json_str = clean_output[start_idx: end_idx + 1]
361
+ data = json.loads(json_str)
362
+ return data.get("intent", "UNKNOWN")
363
+
364
+ except Exception as e:
365
+ logger.error(f"Failed to parse JSON classifier output: {e}. Raw: {llm_output}")
366
+
367
+ return self._rule_based_intent_classification(query)
368
+
369
+ def _rule_based_intent_classification(self, query: str) -> str:
370
+ """Fallback rule-based intent classification"""
371
+ query_lower = query.lower().strip()
372
+
373
+ greeting_words = ['hello', 'hi', 'hey', 'greetings', 'good morning', 'good afternoon', 'how are you']
374
+ harmful_keywords = ['harm', 'hurt', 'kill', 'danger', 'illegal', 'prescription without', 'suicide']
375
+ medical_keywords = ['covid', 'fever', 'pain', 'symptom', 'treatment', 'medicine', 'doctor', 'health', 'disease', 'virus']
376
+
377
+ if any(word in query_lower for word in greeting_words) or len(query_lower.split()) <= 2:
378
+ return 'GREET'
379
+ elif any(word in query_lower for word in harmful_keywords):
380
+ return 'HARMFUL'
381
+ elif not any(word in query_lower for word in medical_keywords) and len(query_lower.split()) > 3:
382
+ return 'OFF_TOPIC'
383
+ else:
384
+ return 'MEDICAL'
385
+
386
+ @method()
387
+ async def Greet(self, query: str) -> RAGResponse:
388
+ """Handle greeting queries"""
389
+ messages = [
390
+ {"role": "system", "content": "You are a greeter. Respond politely to the user greeting in a single line."},
391
+ {"role": "user", "content": query}
392
+ ]
393
+
394
+ prompt_text = self.intent_tokenizer.apply_chat_template(
395
+ messages, tokenize=False, add_generation_prompt=True
396
+ )
397
+
398
+ answer = await self._run_with_timeout(
399
+ asyncio.to_thread(self._call_llm_pipeline, self.intent_pipeline, prompt_text, True, 50, True),
400
+ # timeout_seconds=30,
401
+ timeout_message="Greeting response timed out"
402
+ )
403
+
404
+ return RAGResponse(
405
+ query=query,
406
+ answer=answer,
407
+ sources=[],
408
+ context_chunks=[],
409
+ expanded_queries=[]
410
+ )
411
+
412
+ @method()
413
+ async def HarmOff(self, query: str) -> RAGResponse:
414
+ """Handle harmful/off-topic queries"""
415
+ messages = [
416
+ {"role": "system", "content": "You are an intelligent assistant. Inform the user that you cannot answer harmful/off-topic questions. Keep it short and brief, in one sentence."},
417
+ {"role": "user", "content": query}
418
+ ]
419
+
420
+ prompt_text = self.intent_tokenizer.apply_chat_template(
421
+ messages, tokenize=False, add_generation_prompt=True
422
+ )
423
+
424
+ answer = await self._run_with_timeout(
425
+ asyncio.to_thread(self._call_llm_pipeline, self.intent_pipeline, prompt_text, True, 50, True),
426
+ # timeout_seconds=30,
427
+ timeout_message="Safety response timed out"
428
+ )
429
+
430
+ return RAGResponse(
431
+ query=query,
432
+ answer=answer,
433
+ sources=[],
434
+ context_chunks=[],
435
+ expanded_queries=[]
436
+ )
437
+
438
+ @method()
439
+ async def summarize_history(self, history: List[HistoryMessage]) -> str:
440
+ """Summarize conversation history"""
441
+ if not history:
442
+ return ''
443
+
444
+ history_text = "\n".join([f"{h.role}: {h.content}" for h in history[-8:]])
445
+ summarizer_prompt = f"You are an intelligent agent who summarizes conversations. Your summary should be concise, coherent, and focus on the main topic and specific entities discussed.\nCONVERSATION HISTORY:\n{history_text}\nCONCISE SUMMARY:\n"
446
+
447
+ messages = [{"role": "system", "content": summarizer_prompt}]
448
+ prompt_text = self.intent_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
449
+
450
+ summary = await self._run_with_timeout(
451
+ asyncio.to_thread(self._call_llm_pipeline, self.intent_pipeline, prompt_text, True, 150, False),
452
+ # timeout_seconds=60,
453
+ timeout_message="Summarization timed out"
454
+ )
455
+ return summary
456
+
457
+ @method()
458
+ async def expand_query_with_llm(self, user_query: str, summary: str, history: List[HistoryMessage]) -> List[str]:
459
+ """Expand query for better retrieval"""
460
+ if not history or len(history) == 0:
461
+ expansion_prompt = f"You are a specialized query expansion engine. Generate 3 alternative, highly effective search queries to find documents relevant to the User Query. Only output the queries, one per line. Do not include the original query or any explanations.\nUser Query: {user_query}\nExpanded Queries:\n"
462
+ else:
463
+ history_text = "\n".join([f"{h.role}: {h.content}" for h in history])
464
+ expansion_prompt = f"Given the conversation summary and history below, rewrite the user's latest query into a standalone, complete, and specific search query that incorporates the context of the conversation. Output only the single rewritten query.\nConversation Summary: {summary}\nConversation History:\n{history_text}\nUser's Latest Query: {user_query}\nRewritten Search Query:\n"
465
+
466
+ messages = [{"role": "system", "content": expansion_prompt}]
467
+ prompt_text = self.intent_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
468
+
469
+ llm_output = await self._run_with_timeout(
470
+ asyncio.to_thread(self._call_llm_pipeline, self.intent_pipeline, prompt_text, True, 150, True),
471
+ # timeout_seconds=60,
472
+ timeout_message="Query expansion timed out"
473
+ )
474
+
475
+ if not history or len(history) == 0:
476
+ expanded_queries = [q.strip() for q in llm_output.split('\n') if q.strip()]
477
+ else:
478
+ expanded_queries = [llm_output.strip()]
479
+
480
+ expanded_queries.append(user_query)
481
+ seen = set()
482
+ deduped = []
483
+ for q in expanded_queries:
484
+ if q not in seen:
485
+ seen.add(q)
486
+ deduped.append(q)
487
+ return deduped
488
+
489
+ def retrieve_context(self, queries: List[str]) -> Tuple[List[Dict], List[str]]:
490
+ """Retrieve context from ChromaDB"""
491
+ if self.embedding_model is None:
492
+ raise HTTPException(status_code=503, detail="Embedding model not loaded.")
493
+
494
+ embeddings_list = [[float(x) for x in emb] for emb in self.embedding_model.embed(queries, batch_size=8)]
495
+ results = self.chroma_collection.query(
496
+ query_embeddings=embeddings_list,
497
+ n_results=max(10, RETRIEVE_TOP_K_GPU * len(queries)),
498
+ include=['documents', 'metadatas']
499
+ )
500
+
501
+ context_data = []
502
+ source_urls = set()
503
+ if results.get("documents") and results.get("metadatas"):
504
+ for docs_list, metadatas_list in zip(results["documents"], results["metadatas"]):
505
+ for doc, metadata in zip(docs_list, metadatas_list):
506
+ if doc and metadata:
507
+ context_data.append({'text': doc, 'url': metadata.get('source')})
508
+ if metadata.get("source"):
509
+ source_urls.add(metadata.get('source'))
510
+ return context_data, list(source_urls)
511
+
512
+ def rerank_documents(self, query: str, context: List[Dict], top_k: int) -> List[Dict]:
513
+ """Rerank documents using cross-encoder"""
514
+ if not context or self.cross_encoder is None:
515
+ return context[:top_k]
516
+
517
+ pairs = [(query, doc['text']) for doc in context]
518
+ scores = self.cross_encoder.predict(pairs)
519
+ for doc, score in zip(context, scores):
520
+ doc['score'] = float(score)
521
+ ranked_docs = sorted(context, key=lambda x: x['score'], reverse=True)
522
+ return ranked_docs[:top_k]
523
+
524
+ @method()
525
+ async def prune_messages_to_fit_context(self, messages: List[Dict], final_context: List[Dict], summary: str, max_input_tokens: int) -> Tuple[List[Dict], List[Dict], int]:
526
+ """Prune messages to fit within token limit"""
527
+ if not self.intent_tokenizer:
528
+ raise ValueError("Tokenizer not initialized for pruning.")
529
+
530
+ current_context = final_context[:]
531
+ current_summary = summary
532
+ base_user_query = messages[-1]["content"]
533
+
534
+ current_messages = self._build_prompt(base_user_query, current_context, current_summary)
535
+ token_count = self._get_token_count(current_messages)
536
+
537
+ if token_count <= max_input_tokens:
538
+ tok_length = max_input_tokens - token_count
539
+ return current_messages, current_context, tok_length
540
+
541
+ logger.warning(f"Initial token count ({token_count}) exceeds max input ({max_input_tokens}). Starting pruning.")
542
+
543
+ while token_count > max_input_tokens and current_context:
544
+ current_context.pop()
545
+ current_messages = self._build_prompt(base_user_query, current_context, current_summary)
546
+ token_count = self._get_token_count(current_messages)
547
+
548
+ if token_count <= max_input_tokens:
549
+ tok_length = max_input_tokens - token_count
550
+ return current_messages, current_context, tok_length
551
+
552
+ if current_summary:
553
+ logger.warning("Clearing conversation summary as last-ditch effort.")
554
+ current_summary = ""
555
+ current_messages = self._build_prompt(base_user_query, current_context, current_summary)
556
+ token_count = self._get_token_count(current_messages)
557
+
558
+ if token_count <= max_input_tokens:
559
+ tok_length = max_input_tokens - token_count
560
+ return current_messages, current_context, tok_length
561
+
562
+ logger.error(f"Pruning failed. Even minimal prompt exceeds token limit: {token_count}. Returning empty context.")
563
+ current_context = []
564
+ current_messages = self._build_prompt(base_user_query, current_context, "")
565
+ token_count = self._get_token_count(current_messages)
566
+ tok_length = max_input_tokens - token_count if token_count < max_input_tokens else 0
567
+
568
+ return current_messages, current_context, tok_length
569
+
570
+ @fastapi_endpoint(method="POST")
571
+ async def rag_endpoint(self, request_data: Dict[str, Any]):
572
+ """Main RAG endpoint"""
573
+ try:
574
+ request = QueryRequest(**request_data)
575
+ except Exception as e:
576
+ raise HTTPException(status_code=400, detail=f"Invalid request format: {str(e)}")
577
+
578
+ start = time.time()
579
+
580
+ try:
581
+ logger.info(f'Processing query: {request.query[:100]}...')
582
+
583
+ intent = await self.classify_intent.remote.aio(request.query)
584
+ logger.info(f"Intent classified as: {intent}")
585
+
586
+ if intent == 'GREET':
587
+ response = await self.Greet.remote.aio(request.query)
588
+ elif intent in ["HARMFUL", "OFF_TOPIC"]:
589
+ response = await self.HarmOff.remote.aio(request.query)
590
+ else:
591
+ logger.info("Starting full RAG pipeline for medical query")
592
+
593
+ summary = await self.summarize_history.remote.aio(request.history)
594
+ logger.info("History summarized")
595
+
596
+ expanded_queries = await self.expand_query_with_llm.remote.aio(request.query, summary, request.history)
597
+ logger.info(f"Expanded queries: {expanded_queries}")
598
+
599
+ context_data, _ = await self._run_with_timeout(
600
+ asyncio.to_thread(self.retrieve_context, expanded_queries),
601
+ timeout_message="Document retrieval timed out"
602
+ )
603
+ logger.info(f"Retrieved {len(context_data)} context chunks")
604
+
605
+ final_context = await self._run_with_timeout(
606
+ asyncio.to_thread(self.rerank_documents, request.query, context_data, RETRIEVE_TOP_K_GPU),
607
+ # timeout_seconds=60,
608
+ timeout_message="Document reranking timed out"
609
+ )
610
+ logger.info(f"Reranked to {len(final_context)} chunks")
611
+
612
+ final_sources = list({c.get('url') for c in final_context if c.get('url')})
613
+
614
+ if not final_context:
615
+ final_answer = "I could not find relevant documents in the knowledge base to answer your question. I can help you if you have another question."
616
+ context_chunks_text = []
617
+ else:
618
+ initial_messages = self._build_prompt(request.query, final_context, summary)
619
+ max_input_tokens = LLAMA_3_CONTEXT_WINDOW - MAX_NEW_TOKENS_GPU - SAFETY_BUFFER
620
+
621
+ final_messages, final_context_pruned, tok_length = await self.prune_messages_to_fit_context.remote.aio(
622
+ initial_messages, final_context, summary, max_input_tokens
623
+ )
624
+
625
+ context_chunks_text = [c['text'] for c in final_context_pruned]
626
+ prompt_text = self.intent_tokenizer.apply_chat_template(final_messages, tokenize=False, add_generation_prompt=True)
627
+
628
+ max_new = max(MAX_NEW_TOKENS_GPU, tok_length if isinstance(tok_length, int) and tok_length > 0 else MAX_NEW_TOKENS_GPU)
629
+
630
+ final_answer = await self._run_with_timeout(
631
+ asyncio.to_thread(self._call_llm_pipeline, self.intent_pipeline, prompt_text, False, max_new, False),
632
+ # timeout_seconds=120,
633
+ timeout_message="Answer generation timed out"
634
+ )
635
+ logger.info("Generated final answer")
636
+
637
+ response = RAGResponse(
638
+ query=request.query,
639
+ answer=final_answer,
640
+ sources=final_sources,
641
+ context_chunks=context_chunks_text,
642
+ expanded_queries=expanded_queries
643
+ )
644
+
645
+ end_time = time.time()
646
+ logger.info(f"Total Latency: {round(end_time - start, 2)}s")
647
+ return response.model_dump()
648
+
649
+ except HTTPException:
650
+ raise
651
+ except Exception as e:
652
+ logger.error(f"Unhandled exception in RAG handler: {e}", exc_info=True)
653
+ raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
654
+
655
+ async def _run_with_timeout(self, awaitable: Any, timeout_seconds: int = 300, timeout_message: str = "Request timed out") -> Any:
656
+ try:
657
+ return await asyncio.wait_for(awaitable, timeout=timeout_seconds)
658
+ except asyncio.TimeoutError:
659
+ logger.warning(f"Operation timed out after {timeout_seconds}s: {timeout_message}")
660
+ raise HTTPException(status_code=504, detail=timeout_message)
661
+ except HTTPException:
662
+ raise
663
+ except Exception as e:
664
+ logger.error(f"Unexpected error in _run_with_timeout: {e}")
665
+ raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
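Once `modal_rag.py` is deployed, `rag_endpoint` is served as a Modal web endpoint (`@fastapi_endpoint(method="POST")`) and validates the JSON body against `QueryRequest`, so a client only needs to send `query` and `history`. A minimal client sketch, assuming a placeholder URL (the real URL is printed by `modal deploy modal_rag.py`):

```python
import requests

# Placeholder; substitute the endpoint URL that `modal deploy modal_rag.py` prints.
ENDPOINT_URL = "https://<workspace>--who-rag-llama3-gpu-api-ragservice-rag-endpoint.modal.run"

payload = {
    "query": "What are the common side effects of the latest WHO recommended vaccine?",
    "history": [],   # list of {"role": "user" | "assistant", "content": "..."} turns
}

resp = requests.post(ENDPOINT_URL, json=payload, timeout=300)
resp.raise_for_status()
data = resp.json()   # shape matches RAGResponse.model_dump()
print(data["answer"])
print(data["sources"])
```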
requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ torch
2
+ transformers
3
+ sentence-transformers
4
+ fastembed
5
+ chromadb
6
+ pydantic
7
+ fastapi
8
+ requests
9
+ python-json-logger
10
+ boto3
11
+ accelerate
12
+ bitsandbytes
requirements_heavy.txt DELETED
@@ -1,15 +0,0 @@
1
- # torch
2
- # transformers
3
- # bitsandbytes
4
- # sentence-transformers
5
- # accelerate
6
-
7
- torch==2.9.0
8
- transformers
9
- # bitsandbytes
10
- sentence-transformers
11
- accelerate
12
- # llama-cpp-python UNCOMMENT THIS LINE FOR LOCAL DOCKER TESTING
13
- llama-cpp-python==0.2.83 --extra-index-url https://abetlen.github.io/llama-cpp-python-wheels/
14
- tiktoken
15
- sentencepiece
requirements_light.txt DELETED
@@ -1,8 +0,0 @@
1
- fastapi
2
- uvicorn[standard]
3
- chromadb
4
- pydantic
5
- fastembed
6
- requests
7
- python-json-logger
8
- boto3
s3_utils.py CHANGED
@@ -1,9 +1,8 @@
1
  from typing import Dict, List, Optional
2
- import boto3
3
  import os
4
  import json
5
  import logging
6
- # from botocore.exceptions import NoCredentialsError, ClientError
7
 
8
  S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
9
  AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY_ID")
@@ -13,6 +12,8 @@ AWS_REGION = os.getenv("AWS_REGION")
13
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
14
 
15
  def get_s3_client():
 
 
16
  if not AWS_ACCESS_KEY or not AWS_SECRET_KEY:
17
  logging.warning("AWS credentials not found in environment. Using default config.")
18
  return boto3.client('s3', region_name=AWS_REGION)
@@ -25,10 +26,7 @@ def get_s3_client():
25
  )
26
 
27
  def download_chroma_folder_from_s3(s3_prefix: str, local_dir: str):
28
- """
29
- Downloads all files under s3_prefix from S3 to local_dir,
30
- preserving the folder structure for ChromaDB.
31
- """
32
  s3 = get_s3_client()
33
  paginator = s3.get_paginator("list_objects_v2")
34
  try:
 
1
  from typing import Dict, List, Optional
2
+ # import boto3
3
  import os
4
  import json
5
  import logging
 
6
 
7
  S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
8
  AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY_ID")
 
12
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
13
 
14
  def get_s3_client():
15
+ import boto3
16
+
17
  if not AWS_ACCESS_KEY or not AWS_SECRET_KEY:
18
  logging.warning("AWS credentials not found in environment. Using default config.")
19
  return boto3.client('s3', region_name=AWS_REGION)
 
26
  )
27
 
28
  def download_chroma_folder_from_s3(s3_prefix: str, local_dir: str):
29
+
 
 
 
30
  s3 = get_s3_client()
31
  paginator = s3.get_paginator("list_objects_v2")
32
  try:
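The net effect of this diff is that `boto3` is now imported lazily inside `get_s3_client()`, so importing `s3_utils` does not require boto3 until S3 access is actually needed. Reconstructed from the visible hunks (the `AWS_SECRET_ACCESS_KEY` variable name is an assumption, since that line is not shown), the function likely reads roughly as follows:

```python
import logging
import os

S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")   # assumed; not visible in the hunk
AWS_REGION = os.getenv("AWS_REGION")

def get_s3_client():
    # Lazy import: modules that import s3_utils no longer pay for (or need) boto3
    # unless they actually talk to S3.
    import boto3

    if not AWS_ACCESS_KEY or not AWS_SECRET_KEY:
        logging.warning("AWS credentials not found in environment. Using default config.")
        return boto3.client("s3", region_name=AWS_REGION)
    return boto3.client(
        "s3",
        region_name=AWS_REGION,
        aws_access_key_id=AWS_ACCESS_KEY,
        aws_secret_access_key=AWS_SECRET_KEY,
    )
```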