nyxionlabs committed · Commit e2ffd2f · verified · 1 Parent(s): 08c8ad0

Upload 3 files

Files changed (3):
  1. app.py +187 -0
  2. rag_demo.py +346 -0
  3. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,187 @@
+import os, pickle, json, gradio as gr
+import numpy as np, faiss
+from sentence_transformers import SentenceTransformer
+
+# ---------- Optional LLM (OpenAI) ----------
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")  # add in Space -> Settings -> Secrets
+USE_OPENAI = bool(OPENAI_API_KEY)
+if USE_OPENAI:
+    try:
+        from openai import OpenAI
+        oai = OpenAI(api_key=OPENAI_API_KEY)
+        OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
+    except Exception as e:
+        print("[RAG] OpenAI not available:", e)
+        USE_OPENAI = False
+
+# ---------- Artifacts you already have ----------
+FAISS_PATH = os.getenv("FAISS_PATH", "squad_v2.faiss")
+META_PATH = os.getenv("META_PATH", "squad_v2_meta.pkl")
+
+CACHE = {
+    "index": None,
+    "contexts": None,
+    "encoder": None,
+    "model_name": None
+}
+
+def _coerce_text_list(x):
+    """Accepts list[str] or list[dict]; extracts the text either way."""
+    out = []
+    if isinstance(x, list):
+        for it in x:
+            if isinstance(it, str):
+                out.append(it)
+            elif isinstance(it, dict):
+                # common keys people use (plus "context", the key written by rag_demo.py)
+                text = it.get("text") or it.get("content") or it.get("ctx") or it.get("context") or ""
+                if text:
+                    out.append(text)
+    return out
+
+def load_artifacts():
+    if CACHE["index"] is not None:
+        return
+
+    # 1) FAISS
+    if not os.path.exists(FAISS_PATH):
+        raise FileNotFoundError(f"Missing FAISS index: {FAISS_PATH}")
+    index = faiss.read_index(FAISS_PATH)
+
+    # 2) META
+    if not os.path.exists(META_PATH):
+        raise FileNotFoundError(f"Missing meta file: {META_PATH}")
+    with open(META_PATH, "rb") as f:
+        meta = pickle.load(f)
+
+    # parse meta
+    model_name = "all-MiniLM-L6-v2"
+    contexts = None
+
+    if isinstance(meta, dict):
+        # common keys ("embedding_model" is what rag_demo.py writes)
+        model_name = meta.get("model") or meta.get("encoder") or meta.get("embedding_model") or model_name
+        contexts = (
+            meta.get("contexts")
+            or meta.get("texts")
+            or meta.get("documents")
+            or meta.get("corpus")
+            or meta.get("records")  # format written by rag_demo.py
+        )
+    else:
+        # meta is just a list of contexts
+        contexts = meta
+
+    # normalize contexts
+    contexts = _coerce_text_list(contexts) if contexts is not None else []
+    if not contexts:
+        raise ValueError("No contexts found in meta; expected a list of texts.")
+
+    # Align lengths (safeguard)
+    ntotal = index.ntotal
+    if ntotal != len(contexts):
+        m = min(ntotal, len(contexts))
+        print(f"[RAG] WARNING: index.ntotal({ntotal}) != contexts({len(contexts)}). Trimming to {m}.")
+        # We can't resize FAISS easily here; instead trim contexts so we never index out of range.
+        contexts = contexts[:m]
+
+    # 3) load encoder (instantiate now to avoid a first-click delay)
+    encoder = SentenceTransformer(model_name)
+
+    CACHE.update(index=index, contexts=contexts, encoder=encoder, model_name=model_name)
+    print(f"[RAG] Loaded index={FAISS_PATH} (ntotal={CACHE['index'].ntotal}), "
+          f"contexts={len(CACHE['contexts'])}, model={CACHE['model_name']}")
+
+def _retrieve(question: str, k: int):
+    # encode query; FAISS expects float32
+    q_emb = CACHE["encoder"].encode([question]).astype("float32")
+    D, I = CACHE["index"].search(q_emb, int(k))
+    idxs = I[0].tolist()
+    dists = D[0].tolist()
+    # guard against out-of-range hits caused by mismatched sizes
+    max_ok = len(CACHE["contexts"]) - 1
+    pairs = []
+    for j, dist in zip(idxs, dists):
+        if 0 <= j <= max_ok:
+            pairs.append((j, dist, CACHE["contexts"][j]))
+    return pairs
+
+def _build_prompt(question: str, pairs):
+    chunks = []
+    for i, (_, _d, ctx) in enumerate(pairs, start=1):
+        # keep prompt size reasonable
+        ctx_short = ctx.strip()
+        if len(ctx_short) > 1200:
+            ctx_short = ctx_short[:1200] + "..."
+        chunks.append(f"[Source {i}] {ctx_short}")
+    context_block = "\n\n".join(chunks) if chunks else "(no context)"
+    prompt = f"""Answer strictly from the context below. If not answerable, say so.
+Include [Source X] citations in your answer.
+
+Context:
+{context_block}
+
+Question: {question}
+Answer:"""
+    return prompt
+
+def answer(question: str, k: int):
+    if not question.strip():
+        return "Please enter a question.", [], None
+
+    pairs = _retrieve(question, k)
+    if not pairs:
+        return "No results found in the index.", [], None
+
+    # Build citations list for UI
+    citations = [{"rank": i+1, "faiss_dist": round(d, 4), "snippet": ctx[:240] + ("..." if len(ctx) > 240 else "")}
+                 for i, (_idx, d, ctx) in enumerate(pairs)]
+
+    if USE_OPENAI:
+        prompt = _build_prompt(question, pairs)
+        try:
+            resp = oai.chat.completions.create(
+                model=OPENAI_MODEL,
+                messages=[{"role": "user", "content": prompt}],
+                temperature=0.2
+            )
+            ans = resp.choices[0].message.content
+        except Exception as e:
+            ans = f"LLM call failed: {e}\n\nTop results shown below."
+    else:
+        # Fallback: show top-1 context as the "answer"
+        ans = ("(No OPENAI_API_KEY set — showing most relevant context instead.)\n\n" +
+               pairs[0][2][:1200])
+
+    # simple JSON for debugging/export
+    raw = {
+        "k": int(k),
+        "answer": ans,
+        "citations": citations
+    }
+    # gr.Dataframe renders rows of values, not dicts, so flatten for the table output
+    rows = [[c["rank"], c["faiss_dist"], c["snippet"]] for c in citations]
+    return ans, rows, json.dumps(raw, indent=2)
+
+# ---------- UI ----------
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown("## Nyxion Labs · Grounded Q&A (SQuAD v2, FAISS)")
+
+    with gr.Row():
+        q = gr.Textbox(label="Ask a question", placeholder="e.g., What is the capital of France?")
+        k = gr.Slider(1, 10, value=3, step=1, label="Citations (top-k)")
+
+    run_btn = gr.Button("Ask")
+    ans_md = gr.Markdown(label="Answer")
+    cites = gr.Dataframe(headers=["rank", "faiss_dist", "snippet"], datatype=["number", "number", "str"],
+                         row_count=(0, "dynamic"), label="Retrieved contexts")
+    raw_json = gr.JSON(label="Debug / raw response")
+
+    def _startup():
+        load_artifacts()
+        return "Ready."
+
+    status = gr.Markdown()
+    demo.load(_startup, inputs=None, outputs=status)
+    run_btn.click(answer, [q, k], [ans_md, cites, raw_json])
+
+if __name__ == "__main__":
+    load_artifacts()
+    demo.launch()
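
A quick way to sanity-check the artifacts before deploying the Space is to reuse app.py's own loader and retriever from a Python shell. This is a minimal smoke-test sketch (the question string is just an example); importing app builds the Gradio Blocks but does not launch the server, since the __main__ guard never fires:

    import app  # the module above

    app.load_artifacts()  # loads FAISS index, meta pickle, and encoder into CACHE
    for idx, dist, ctx in app._retrieve("What is the capital of France?", 3):
        print(f"{dist:.4f}  {ctx[:120]}")
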
rag_demo.py ADDED
@@ -0,0 +1,346 @@
+# --- quiet TensorFlow / tf-keras noise (must be first lines in the file) -----
+import os, warnings
+
+# Hide TF C++ INFO/WARNING levels; keep errors
+os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")
+
+# Suppress the oneDNN startup notice
+os.environ.setdefault("TF_ENABLE_ONEDNN_OPTS", "0")
+
+# Silence specific deprecation chatter from tf_keras
+warnings.filterwarnings(
+    "ignore",
+    category=DeprecationWarning,
+    message=r".*tf\.losses\.sparse_softmax_cross_entropy.*",
+)
+# Blanket-ignore DeprecationWarnings originating from tensorflow / tf_keras modules
+warnings.filterwarnings("ignore", category=DeprecationWarning, module=r"^(tensorflow|tf_keras)\b")
+warnings.filterwarnings("ignore", category=UserWarning, module=r"^(tensorflow|tf_keras)\b")
+# -----------------------------------------------------------------------------
+
+from contextlib import asynccontextmanager
+import argparse
+import sys
+import json
+import time
+import pickle
+import pathlib
+from typing import List, Tuple, Dict, Any
+
+import numpy as np
+from tqdm import tqdm
+
+# --- DEV-ONLY TOKEN (kept out of .env on purpose) ------------------------------
+OPENAI_API_KEY = 'sk-your-dev-key-here'  # <<< put your dev key here
+OPENAI_MODEL = "gpt-4o-mini"  # solid + cost-effective for the demo
+
+# --- Heavy deps ----------------------------------------------------------------
+try:
+    import faiss  # type: ignore
+except Exception:
+    print("FAISS is required. pip install faiss-cpu", file=sys.stderr)
+    raise
+
+try:
+    from datasets import load_dataset  # type: ignore
+except Exception:
+    print("HuggingFace datasets is required. pip install datasets", file=sys.stderr)
+    raise
+
+try:
+    from sentence_transformers import SentenceTransformer  # type: ignore
+except Exception:
+    print("sentence-transformers is required. pip install sentence-transformers", file=sys.stderr)
+    raise
+
+try:
+    from openai import OpenAI  # type: ignore
+except Exception:
+    print("openai>=1.0 is required. pip install openai", file=sys.stderr)
+    raise
+
+# --- Optional API mode ---------------------------------------------------------
+try:
+    from fastapi import FastAPI
+    from pydantic import BaseModel
+    import uvicorn
+    FASTAPI_AVAILABLE = True
+except Exception:
+    FASTAPI_AVAILABLE = False
+
+# --- Paths --------------------------------------------------------------------
+ROOT = pathlib.Path(__file__).resolve().parent
+ART = ROOT / "artifacts"
+ART.mkdir(exist_ok=True)
+
+INDEX_FILE = ART / "squad_v2.faiss"
+META_FILE = ART / "squad_v2_meta.pkl"
+
+# --- Chunking params -----------------------------------------------------------
+# SQuAD contexts can be long; we chunk them for better retrieval quality.
+CHUNK_SIZE = 500     # characters per chunk
+CHUNK_OVERLAP = 100  # overlap to preserve context across boundaries
+
+# --- Minimal logging -----------------------------------------------------------
+def log(msg: str):
+    print(f"[RAG] {msg}", flush=True)
+
+# --- Data prep ----------------------------------------------------------------
+def chunk_text(text: str, chunk_size: int, overlap: int) -> List[str]:
+    if not text:
+        return []
+    chunks = []
+    start = 0
+    while start < len(text):
+        end = min(len(text), start + chunk_size)
+        chunks.append(text[start:end])
+        if end == len(text):
+            break
+        start = end - overlap
+        if start < 0:
+            start = 0
+    return chunks
+
+def load_and_prepare_squad() -> List[Dict[str, Any]]:
+    """
+    Returns a list of dicts:
+      {
+        'id': str,            # synthetic id per chunk
+        'title': str,
+        'context': str,       # chunk text
+        'source_meta': {'split': 'train|validation', 'orig_example_id': ..., 'title': ...}
+      }
+    """
+    log("Downloading SQuAD v2 via datasets ...")
+    ds = load_dataset("rajpurkar/squad_v2")
+
+    prepared: List[Dict[str, Any]] = []
+    for split in ["train", "validation"]:
+        rows = ds[split]
+        log(f"Processing split: {split} (n={len(rows)})")
+        for i, ex in enumerate(rows):
+            title = ex.get("title") or ""
+            context = ex.get("context") or ""
+            ex_id = ex.get("id") or f"{split}-{i}"
+            chunks = chunk_text(context, CHUNK_SIZE, CHUNK_OVERLAP)
+            for j, chunk in enumerate(chunks):
+                prepared.append({
+                    "id": f"{ex_id}::chunk{j}",
+                    "title": title,
+                    "context": chunk.strip(),
+                    "source_meta": {"split": split, "orig_example_id": ex_id, "title": title},
+                })
+    log(f"Prepared {len(prepared)} chunks total.")
+    return prepared
+
+# --- Embeddings & FAISS -------------------------------------------------------
+def build_index(prepared: List[Dict[str, Any]], model_name: str = "all-MiniLM-L6-v2"):
+    log(f"Loading embedding model: {model_name}")
+    st_model = SentenceTransformer(model_name)
+    texts = [r["context"] for r in prepared]
+
+    log("Encoding chunks -> embeddings (this can take a while) ...")
+    embs = st_model.encode(texts, show_progress_bar=True, convert_to_numpy=True, batch_size=256)
+    embs = embs.astype("float32")
+
+    dim = embs.shape[1]
+    index = faiss.IndexFlatL2(dim)
+    index.add(embs)
+
+    log(f"Built FAISS index with {index.ntotal} vectors. Saving to disk ...")
+    faiss.write_index(index, str(INDEX_FILE))
+
+    meta = {
+        "records": prepared,
+        "embedding_model": model_name,
+        "dim": dim,
+        "created_at": time.time(),
+        "chunk_size": CHUNK_SIZE,
+        "chunk_overlap": CHUNK_OVERLAP,
+    }
+    with open(META_FILE, "wb") as f:
+        pickle.dump(meta, f)
+
+    log("Index + metadata saved.")
+    return index, meta, st_model
+
+def load_index():
+    if not INDEX_FILE.exists() or not META_FILE.exists():
+        raise FileNotFoundError("Index or metadata not found. Run with --build-index first.")
+    index = faiss.read_index(str(INDEX_FILE))
+    with open(META_FILE, "rb") as f:
+        meta = pickle.load(f)
+    # load the embedding model named in the metadata so query and index vectors match
+    st_model = SentenceTransformer(meta.get("embedding_model", "all-MiniLM-L6-v2"))
+    return index, meta, st_model
+
+# --- RAG core -----------------------------------------------------------------
+class GroundedQA:
+    def __init__(self, index, records: List[Dict[str, Any]], embed_model, openai_api_key: str):
+        self.index = index
+        self.records = records
+        self.embed_model = embed_model
+        self.client = OpenAI(api_key=openai_api_key)
+
+    def retrieve(self, question: str, k: int = 5) -> List[Tuple[Dict[str, Any], float]]:
+        q_emb = self.embed_model.encode([question], convert_to_numpy=True).astype("float32")
+        distances, indices = self.index.search(q_emb, k)
+        out = []
+        for rank, idx in enumerate(indices[0]):
+            rec = self.records[idx]
+            dist = float(distances[0][rank])
+            out.append((rec, dist))
+        return out
+
+    def _build_prompt(self, question: str, retrieved: List[Tuple[Dict[str, Any], float]]) -> str:
+        context_blocks = []
+        for i, (rec, _) in enumerate(retrieved, start=1):
+            title = rec.get("title") or "Untitled"
+            ctx = rec["context"]
+            context_blocks.append(f"[Source {i} | {title}] {ctx}")
+        context_text = "\n\n".join(context_blocks)
+
+        prompt = (
+            "You are a precise, grounded Q&A assistant. "
+            "Answer ONLY using the provided context. If the answer is not in the context, say you don't know.\n"
+            "Add citations like [Source X] inline where relevant.\n\n"
+            f"Context:\n{context_text}\n\n"
+            f"Question: {question}\n\n"
+            "Answer (with citations):"
+        )
+        return prompt
+
+    def answer_with_citations(self, question: str, k: int = 5) -> Dict[str, Any]:
+        retrieved = self.retrieve(question, k=k)
+        prompt = self._build_prompt(question, retrieved)
+
+        resp = self.client.chat.completions.create(
+            model=OPENAI_MODEL,
+            messages=[{"role": "user", "content": prompt}],
+            temperature=0.2,
+            max_tokens=400,
+        )
+        answer = resp.choices[0].message.content.strip()
+        return {
+            "answer": answer,
+            "sources": [
+                {
+                    "rank": i + 1,
+                    "distance": d,
+                    "id": rec["id"],
+                    "title": rec.get("title"),
+                    "split": rec["source_meta"]["split"],
+                    "excerpt": rec["context"][:240] + ("..." if len(rec["context"]) > 240 else "")
+                }
+                for i, (rec, d) in enumerate(retrieved)
+            ],
+        }
+
+# --- Simple confidence heuristic ----------------------------------------------
+def should_review(rag_result: Dict[str, Any], threshold: float = 1.2) -> bool:
+    # Lower L2 distance -> closer match. Flag for human review if the average distance is high.
+    if not rag_result.get("sources"):
+        return True
+    avg = float(np.mean([s["distance"] for s in rag_result["sources"]]))
+    return avg > threshold
+
+# --- CLI ----------------------------------------------------------------------
+def cli_build_index():
+    prepared = load_and_prepare_squad()
+    build_index(prepared)
+
+def cli_query(question: str, k: int = 5):
+    index, meta, st_model = load_index()
+    qa = GroundedQA(index, meta["records"], st_model, OPENAI_API_KEY)
+    result = qa.answer_with_citations(question, k=k)
+
+    print("\n=== Answer ===")
+    print(result["answer"])
+    print("\n=== Sources ===")
+    for s in result["sources"]:
+        print(f"[{s['rank']}] ({s['distance']:.4f}) {s['title']} :: {s['id']}")
+        print(f"    {s['excerpt']}")
+    print("\nReview flag:", "YES" if should_review(result) else "NO")
+
+# --- API (optional) -----------------------------------------------------------
+if FASTAPI_AVAILABLE:
+    @asynccontextmanager
+    async def lifespan(app: FastAPI):
+        # Startup: warm the RAG pipeline once
+        index, meta, st_model = load_index()
+        app.state.qa = GroundedQA(index, meta["records"], st_model, OPENAI_API_KEY)
+        yield
+        # Teardown: nothing to clean up
+
+    app = FastAPI(title="Nyxion Labs RAG — SQuAD v2", lifespan=lifespan)
+
+    class AskBody(BaseModel):
+        question: str
+        k: int = 5
+
+    @app.post("/api/v1/assistant/query")
+    def query_api(body: AskBody):
+        qa: GroundedQA = app.state.qa
+        res = qa.answer_with_citations(body.question, k=body.k)
+        # attach a quick review flag (JSON-safe bool)
+        res["needs_review"] = should_review(res)
+        return res
+
+# --- main ---------------------------------------------------------------------
+def parse_args():
+    p = argparse.ArgumentParser(description="Nyxion Labs — RAG on SQuAD v2")
+    p.add_argument("--build-index", action="store_true", help="Download SQuAD and build FAISS index")
+    p.add_argument("--q", "--question", dest="question", type=str, help="Ask a question")
+    p.add_argument("-k", type=int, default=5, help="Top-k contexts to retrieve")
+    p.add_argument("--serve", action="store_true", help="Run FastAPI server on :8000")
+    return p.parse_args()
+
+def main():
+    args = parse_args()
+
+    if args.build_index:
+        cli_build_index()
+        return
+
+    if args.serve:
+        if not FASTAPI_AVAILABLE:
+            print("FastAPI not installed. pip install fastapi uvicorn pydantic", file=sys.stderr)
+            sys.exit(1)
+        uvicorn.run("rag_demo:app", host="0.0.0.0", port=8000, reload=False)
+        return
+
+    if args.question:
+        if OPENAI_API_KEY.startswith("sk-your-dev-key-here"):
+            log("WARNING: Set your OPENAI_API_KEY at the top of the file.")
+        cli_query(args.question, k=args.k)
+        return
+
+    print("No action given. Use --build-index, --q/--question, or --serve.")
+
+if __name__ == "__main__":
+    main()
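
Once the index is built (python rag_demo.py --build-index) and the API is up (python rag_demo.py --serve), the /api/v1/assistant/query route can be exercised with nothing but the standard library. A minimal client sketch, assuming the default host/port and an example question:

    import json
    import urllib.request

    body = json.dumps({"question": "Who wrote Hamlet?", "k": 3}).encode("utf-8")
    req = urllib.request.Request(
        "http://localhost:8000/api/v1/assistant/query",
        data=body,
        headers={"Content-Type": "application/json"},  # POST is implied when data is set
    )
    with urllib.request.urlopen(req) as resp:
        res = json.load(resp)

    print(res["answer"])
    print("needs_review:", res["needs_review"])
    for s in res["sources"]:
        print(f"[{s['rank']}] ({s['distance']:.4f}) {s['title']} :: {s['id']}")
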
requirements.txt ADDED
@@ -0,0 +1,6 @@
+gradio
+datasets
+faiss-cpu
+sentence-transformers
+openai
+tiktoken