hamxaameer committed
Commit 89410ee · verified · 1 Parent(s): eceb5f0

Update app.py

Files changed (1):
  1. app.py +116 -26
app.py CHANGED
@@ -44,8 +44,9 @@ def initialize_llm():
     logger.info("🔄 Initializing FREE local language model...")
 
     BACKUP_MODELS = [
-        "google/flan-t5-base",   # Primary - 250M, very fast on CPU
-        "google/flan-t5-large",  # Backup - 780M, slower but better
+        "microsoft/phi-2",                     # Primary - 2.7B, excellent quality, fast
+        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # Backup - 1.1B, very fast
+        "google/flan-t5-large",                # Fallback - 780M
     ]
 
     for model_name in BACKUP_MODELS:
@@ -53,19 +54,37 @@ def initialize_llm():
             logger.info(f" Trying {model_name}...")
             device = 0 if torch.cuda.is_available() else -1
 
-            # Use text2text-generation for T5 models (not text-generation)
-            task = "text2text-generation" if "t5" in model_name.lower() else "text-generation"
+            # Determine task and model type
+            if "t5" in model_name.lower():
+                task = "text2text-generation"
+                model_type = "t5"
+            elif "phi" in model_name.lower():
+                task = "text-generation"
+                model_type = "phi"
+            elif "tinyllama" in model_name.lower():
+                task = "text-generation"
+                model_type = "tinyllama"
+            else:
+                task = "text-generation"
+                model_type = "instruct"
+
+            # Model-specific kwargs for optimization
+            model_kwargs = {
+                "low_cpu_mem_usage": True,
+                "trust_remote_code": True  # Required for Phi-2
+            }
 
             llm_client = pipeline(
                 task,
                 model=model_name,
                 device=device,
-                max_length=300,
+                max_length=400,  # Good length for detailed answers
                 truncation=True,
+                model_kwargs=model_kwargs
             )
 
             CONFIG["llm_model"] = model_name
-            CONFIG["model_type"] = "t5" if "t5" in model_name.lower() else "instruct"
+            CONFIG["model_type"] = model_type
             logger.info(f"✅ FREE LLM initialized: {model_name}")
             logger.info(f" Device: {'GPU' if device == 0 else 'CPU'}")
             return llm_client
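The sketch below (not part of app.py; the test prompt and print are illustrative) loads the new primary model with the same pipeline() arguments the loop above passes, so the download size and CPU memory cost can be checked in isolation before deploying the Space.

```python
# Standalone smoke test for the new primary backup model (sketch, not from app.py).
# phi-2 has ~2.7B parameters, so expect a multi-gigabyte download and several GB of RAM on CPU.
import torch
from transformers import pipeline

device = 0 if torch.cuda.is_available() else -1
pipe = pipeline(
    "text-generation",
    model="microsoft/phi-2",
    device=device,
    model_kwargs={"low_cpu_mem_usage": True, "trust_remote_code": True},
)

# Illustrative prompt in the same "Instruct:/Output:" style the app now uses for Phi-2.
out = pipe("Instruct: Name one breathable summer fabric.\nOutput:", max_new_tokens=20)
print(out[0]["generated_text"])
```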
@@ -356,33 +375,68 @@ def generate_llm_answer(
         repetition_penalty = 1.25
 
     # Create prompt based on model type
-    if CONFIG.get("model_type") == "t5":
-        # T5 needs simple input-output format
-        user_prompt = f"Question: {query}\n\nContext: {context_text[:800]}\n\nProvide a helpful fashion answer:"
+    model_type = CONFIG.get("model_type", "instruct")
+
+    if model_type == "t5":
+        # T5 needs simple format
+        user_prompt = f"Question: {query}\n\nContext: {context_text[:800]}\n\nProvide helpful fashion advice:"
+    elif model_type == "phi":
+        # Phi-2 format (no special tokens needed)
+        user_prompt = f"""Instruct: You are a fashion advisor. Use the following knowledge to answer the question.
+
+Fashion Knowledge:
+{context_text}
+
+Question: {query}
+
+Output: Provide specific, helpful fashion advice in 150-200 words."""
+    elif model_type == "tinyllama":
+        # TinyLlama chat format
+        user_prompt = f"""<|system|>
+You are a helpful fashion advisor.</s>
+<|user|>
+Use this fashion knowledge to answer: {context_text[:1000]}
+
+Question: {query}</s>
+<|assistant|>"""
     else:
-        # Instruct models use INST format
+        # Generic instruct format
         user_prompt = f"""[INST] Question: {query}
 
 Fashion Knowledge:
 {context_text}
 
-Answer the question using the knowledge above. Be specific and helpful (100-250 words). [/INST]"""
+Answer the question using the knowledge above. Be specific and helpful (150-200 words). [/INST]"""
 
     try:
         logger.info(f" → Calling {CONFIG['llm_model']} (temp={temperature}, tokens={max_tokens})...")
 
         # Call pipeline with model-specific parameters
-        if CONFIG.get("model_type") == "t5":
-            # T5 uses max_length not max_new_tokens
+        if model_type == "t5":
+            # T5 uses max_length
             output = llm_client(
                 user_prompt,
-                max_length=200,  # Shorter for speed
-                temperature=temperature,
-                top_p=top_p,
+                max_length=150,
+                temperature=0.7,
+                top_p=0.9,
+                do_sample=True,
+                num_beams=1,
+                early_stopping=True
+            )
+        elif model_type in ["phi", "tinyllama"]:
+            # Phi-2 and TinyLlama - optimized for quality and speed
+            output = llm_client(
+                user_prompt,
+                max_new_tokens=min(max_tokens, 300),  # Cap at 300 for speed
+                temperature=0.75,  # Balanced creativity
+                top_p=0.92,
+                repetition_penalty=1.15,
                 do_sample=True,
+                return_full_text=False,
+                pad_token_id=llm_client.tokenizer.eos_token_id if hasattr(llm_client.tokenizer, 'eos_token_id') else None
             )
         else:
-            # Other models use max_new_tokens
+            # Other models
            output = llm_client(
                 user_prompt,
                 max_new_tokens=max_tokens,
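The TinyLlama branch above hard-codes Zephyr-style chat markers. An alternative sketch (not part of this commit, shown for comparison) builds the same kind of prompt from the tokenizer's bundled chat template, which avoids drift if the template ever changes; it assumes it runs inside generate_llm_answer, where query, context_text and llm_client are already in scope.

```python
# Alternative sketch: derive the TinyLlama prompt from the model's own chat template
# instead of hand-writing <|system|>/<|user|>/<|assistant|> markers (assumed equivalent).
messages = [
    {"role": "system", "content": "You are a helpful fashion advisor."},
    {"role": "user", "content": f"Use this fashion knowledge to answer: {context_text[:1000]}\n\nQuestion: {query}"},
]
user_prompt = llm_client.tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # return a string, not token ids
    add_generation_prompt=True,  # appends the assistant turn marker
)
```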
@@ -391,7 +445,7 @@ Answer the question using the knowledge above. Be specific and helpful (100-250
                 repetition_penalty=repetition_penalty,
                 do_sample=True,
                 return_full_text=False,
-                pad_token_id=llm_client.tokenizer.eos_token_id
+                pad_token_id=llm_client.tokenizer.eos_token_id if hasattr(llm_client.tokenizer, 'eos_token_id') else None
             )
 
         # Extract generated text
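One note on the pad_token_id guard added above: transformers tokenizers expose eos_token_id as a property that returns None when no EOS token is defined, so the hasattr() check is effectively always true and the expression still evaluates to the EOS id or None. If that is the intent, the shorter form in this sketch (not part of the commit) behaves the same.

```python
# Equivalent sketch: eos_token_id is always present as an attribute; it may simply be None.
pad_token_id = getattr(llm_client.tokenizer, "eos_token_id", None)
```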
@@ -488,26 +542,62 @@ def generate_answer_langchain(
 # GRADIO INTERFACE
 # ============================================================================
 
-def fashion_chatbot(message: str, history: List[List[str]]) -> str:
+def fashion_chatbot(message: str, history: List[List[str]]):
     """
-    Chatbot function for Gradio interface
+    Chatbot function for Gradio interface with streaming
     """
     try:
         if not message or not message.strip():
-            return "Please ask a fashion-related question!"
+            yield "Please ask a fashion-related question!"
+            return
 
-        # Generate answer using RAG pipeline
-        answer = generate_answer_langchain(
+        # Show searching indicator
+        yield "🔍 Searching fashion knowledge..."
+
+        # Retrieve documents
+        retrieved_docs, confidence = retrieve_knowledge_langchain(
             message.strip(),
             vectorstore,
-            llm_client
+            top_k=CONFIG["top_k"]
         )
 
-        return answer
+        if not retrieved_docs:
+            yield "I couldn't find relevant information to answer your question."
+            return
+
+        # Show generating indicator
+        yield f"💭 Generating answer ({len(retrieved_docs)} sources found)..."
+
+        # Generate answer with multiple attempts
+        llm_answer = None
+        for attempt in range(1, 5):
+            logger.info(f"\n 🤖 LLM Generation Attempt {attempt}/4")
+            llm_answer = generate_llm_answer(message.strip(), retrieved_docs, llm_client, attempt)
+
+            if llm_answer:
+                break
+
+        # Fallback if needed
+        if not llm_answer:
+            logger.error(f" ✗ All LLM attempts failed - using fallback")
+            llm_answer = synthesize_direct_answer(message.strip(), retrieved_docs)
+
+        # Stream the answer word by word for natural flow
+        import time
+        words = llm_answer.split()
+        displayed_text = ""
+
+        for i, word in enumerate(words):
+            displayed_text += word + " "
+
+            # Yield every 3 words for smooth streaming
+            if i % 3 == 0 or i == len(words) - 1:
+                yield displayed_text.strip()
+                time.sleep(0.05)  # Small delay for natural flow
 
     except Exception as e:
         logger.error(f"Error in chatbot: {e}")
-        return f"Sorry, I encountered an error: {str(e)}"
+        yield f"Sorry, I encountered an error: {str(e)}"
 
 # ============================================================================
 # INITIALIZE AND LAUNCH
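The launch code sits outside this diff, but a generator chatbot like fashion_chatbot above is typically wired into Gradio as in the sketch below (the title string is illustrative, not from app.py); gr.ChatInterface re-renders the pending bot message on each yield, which is what produces the streaming effect.

```python
# Sketch of the wiring assumed by the streaming generator above (not shown in this diff).
import gradio as gr

demo = gr.ChatInterface(
    fn=fashion_chatbot,       # generator function: each yield replaces the displayed reply
    title="Fashion Advisor",  # illustrative title, not taken from app.py
)
demo.queue().launch()         # queueing lets intermediate yields stream to the browser
```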
 