# @title Install Dependencies
# Run this cell first to install necessary packages
import subprocess
import sys

def install_dependencies():
    packages = [
        "torch",
        "transformers",
        "optimum",
        "optimum-intel",
        "openvino",
        "accelerate",
        "scipy"
    ]
    print(f"Installing packages: {', '.join(packages)}")
    subprocess.check_call([sys.executable, "-m", "pip", "install"] + packages)
    print("Dependencies installed successfully.")

# Uncomment the line below to install dependencies in Colab
# install_dependencies()
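
# Illustrative alternative (not part of the original script): in a Colab notebook the
# same packages can also be installed with the cell magic below instead of subprocess.
# !pip install -q torch transformers optimum optimum-intel openvino accelerate scipy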

import os
import gc
import copy  # used below to deep-copy the nested generation config
import time
import logging
import json
import re
import warnings
import datetime
from typing import List, Dict, Union, Optional, Any, Tuple
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from textwrap import fill
import concurrent.futures

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)

# ==========================================
# MOCK PERFORMANCE MONITOR
# ==========================================
# No-op stand-ins for the performance-monitoring decorators used in the full service modules.
def cached_robust_parsing(func):
    return func

def track_robust_processing(func):
    return func

def track_prompt_generation(func):
    return func

# ==========================================
# MODEL CONFIGURATION (from model_config.py)
# ==========================================
# Detect if running on Hugging Face Spaces
IS_HF_SPACES = os.getenv("HUGGINGFACE_SPACES", "").lower() == "true"
IS_T4_MEDIUM = IS_HF_SPACES and os.getenv("SPACES_MACHINE", "").lower() == "t4-medium"

# T4 Medium optimizations
T4_OPTIMIZATIONS = {
    "max_memory_mb": 14000,
    "use_quantization": True,
    "load_in_4bit": True,
    "torch_dtype": "float16",
    "device_map": "auto",
    "trust_remote_code": True,
    "cache_dir": "/tmp/hf_cache",
    "local_files_only": False
}

# Model generation settings
GENERATION_CONFIG = {
    "use_cache": True,
    "max_length": 8192,
    "temperature": 0.1,
    "num_return_sequences": 1,
    "do_sample": False,
    "pad_token_id": 0,
    "generation_config": {
        "use_cache": True,
        "max_new_tokens": 8192,
        "do_sample": False,
        "temperature": 0.1
    }
}

# Default models
DEFAULT_MODELS = {
    "text-generation": {
        "primary": "microsoft/DialoGPT-small",
        "fallback": "facebook/bart-base",
    },
    "summarization": {
        "primary": "sshleifer/distilbart-cnn-6-6",
        "fallback": "facebook/bart-base",
    },
    "openvino": {
        "primary": "microsoft/Phi-3-mini-4k-instruct",
        "fallback": "OpenVINO/Phi-3-mini-4k-instruct-fp16-ov",
    },
    "causal-openvino": {
        "primary": "microsoft/Phi-3-mini-4k-instruct",
        "fallback": "OpenVINO/Phi-3-mini-4k-instruct-fp16-ov",
    }
}

MODEL_TYPE_MAPPINGS = {
    ".gguf": "gguf",
    "gguf": "gguf",
    "openvino": "openvino",
    "ov": "openvino",
    "causal-openvino": "causal-openvino",
    "text-generation": "text-generation",
    "summarization": "summarization",
    "instruct": "text-generation",
}

MODEL_TOKEN_LIMITS = {
    "microsoft/Phi-3-mini-4k-instruct": 8192,
    "OpenVINO/Phi-3-mini-4k-instruct-fp16-ov": 8192,
    "default": 4096
}

def get_model_token_limit(model_name: str) -> int:
    if model_name in MODEL_TOKEN_LIMITS:
        return MODEL_TOKEN_LIMITS[model_name]
    if "128k" in model_name.lower():
        return 131072
    elif "8k" in model_name.lower():
        return 8192
    elif "4k" in model_name.lower():
        return 4096
    return MODEL_TOKEN_LIMITS["default"]
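
# Illustrative examples (not part of the original script) of how the lookup resolves
# limits: an exact entry in MODEL_TOKEN_LIMITS wins over the substring heuristics.
# get_model_token_limit("OpenVINO/Phi-3-mini-4k-instruct-fp16-ov")  -> 8192   (exact match)
# get_model_token_limit("some-org/llama-128k-chat")                 -> 131072 ("128k" substring)
# get_model_token_limit("unknown/model")                            -> 4096   (default)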

def get_t4_model_kwargs(model_type: str) -> dict:
    # Always return T4 optimizations for Colab usage to be safe/efficient
    base_kwargs = T4_OPTIMIZATIONS.copy()
    if model_type in ["summarization", "seq2seq", "text-generation"]:
        base_kwargs.update({
            "load_in_4bit": True,
            "bnb_4bit_compute_dtype": "float16",
            "bnb_4bit_use_double_quant": True,
            "bnb_4bit_quant_type": "nf4"
        })
    return base_kwargs

def get_t4_generation_config(model_type: str) -> dict:
    # Deep copy so the nested "generation_config" dict inside GENERATION_CONFIG
    # is not mutated globally when this copy is adjusted.
    config = copy.deepcopy(GENERATION_CONFIG)
    config["max_length"] = 8192
    config["generation_config"]["max_new_tokens"] = 8192
    return config

def is_model_supported_on_t4(model_name: str, model_type: str) -> bool:
    return True

def detect_model_type(model_name: str) -> str:
    model_name_lower = model_name.lower()
    for indicator, model_type in MODEL_TYPE_MAPPINGS.items():
        if indicator in model_name_lower:
            return model_type
    return "text-generation"

# ==========================================
# ROBUST JSON PARSER (from robust_json_parser.py)
# ==========================================
def safe_get(data_dict: Dict[str, Any], key_aliases: List[str]) -> Optional[Any]:
    if not isinstance(data_dict, dict):
        return None
    for alias in key_aliases:
        for key, value in data_dict.items():
            if key.lower() == alias.lower():
                return value
    return None
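
# Illustrative example (not part of the original script): safe_get matches keys
# case-insensitively against any of the provided aliases.
# safe_get({"ChartDate": "2023-06-20"}, ["chartdate", "date"])  -> "2023-06-20"
# safe_get({"visits": []}, ["encounters"])                      -> None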

def normalize_visit_data(visit: Dict[str, Any]) -> Dict[str, Any]:
    if not isinstance(visit, dict):
        return {}
    normalized = {}
    date_value = safe_get(visit, ['chartdate', 'date', 'visitDate', 'encounterDate'])
    if date_value:
        normalized['chartdate'] = str(date_value)[:10]
    vitals = safe_get(visit, ['vitals', 'vitalSigns', 'vital_signs'])
    if vitals:
        if isinstance(vitals, dict):
            normalized['vitals'] = vitals
        elif isinstance(vitals, list):
            vitals_dict = {}
            for item in vitals:
                if isinstance(item, str) and ':' in item:
                    key, value = item.split(':', 1)
                    vitals_dict[key.strip()] = value.strip()
            normalized['vitals'] = vitals_dict
    diagnoses = safe_get(visit, ['diagnoses', 'diagnosis', 'conditions'])
    if diagnoses:
        if isinstance(diagnoses, list):
            normalized['diagnosis'] = [str(d).strip() for d in diagnoses if d]
        elif isinstance(diagnoses, str):
            normalized['diagnosis'] = [diagnoses.strip()]
    medications = safe_get(visit, ['medications', 'meds', 'prescriptions'])
    if medications:
        if isinstance(medications, list):
            normalized['medications'] = [str(m).strip() for m in medications if m]
        elif isinstance(medications, str):
            normalized['medications'] = [medications.strip()]
    complaint = safe_get(visit, ['chiefComplaint', 'reasonForVisit', 'chief_complaint'])
    if complaint:
        normalized['chiefComplaint'] = str(complaint).strip()
    symptoms = safe_get(visit, ['symptoms', 'reportedSymptoms'])
    if symptoms:
        if isinstance(symptoms, list):
            normalized['symptoms'] = [str(s).strip() for s in symptoms if s]
        elif isinstance(symptoms, str):
            normalized['symptoms'] = [symptoms.strip()]
    return normalized

def process_patient_record_robust(patient_data: Dict[str, Any]) -> Dict[str, Any]:
    if not isinstance(patient_data, dict):
        return {"error": "Invalid patient data format"}
    processed = {}
    demographics = safe_get(patient_data, ['demographics', 'patientInfo', 'patient_info'])
    if demographics and isinstance(demographics, dict):
        processed['demographics'] = {
            'age': safe_get(demographics, ['age', 'yearsOld']),
            'gender': safe_get(demographics, ['gender', 'sex']),
            'dob': safe_get(demographics, ['dob', 'dateOfBirth'])
        }
    processed['patientName'] = safe_get(patient_data, ['patientName', 'patient_name', 'name'])
    processed['patientNumber'] = safe_get(patient_data, ['patientNumber', 'patient_number', 'id'])
    pmh = safe_get(patient_data, ['pastMedicalHistory', 'pmh', 'medical_history'])
    if pmh:
        processed['pastMedicalHistory'] = pmh if isinstance(pmh, list) else [pmh]
    allergies = safe_get(patient_data, ['allergies', 'allergyInfo'])
    if allergies:
        processed['allergies'] = allergies if isinstance(allergies, list) else [allergies]
    visits = safe_get(patient_data, ['visits', 'encounters', 'appointments'])
    if visits and isinstance(visits, list):
        processed_visits = []
        for visit in visits:
            if isinstance(visit, dict):
                normalized_visit = normalize_visit_data(visit)
                if normalized_visit:
                    processed_visits.append(normalized_visit)
        processed['visits'] = processed_visits
    return processed
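
# Illustrative example (not part of the original script): alias keys and list-style
# vitals are normalized into the canonical shape expected downstream.
# process_patient_record_robust({
#     "patient_name": "Jane Roe",
#     "encounters": [{"date": "2024-01-02", "vitalSigns": ["Pulse(bpm): 80"]}],
# })
# -> {'patientName': 'Jane Roe', 'patientNumber': None,
#     'visits': [{'chartdate': '2024-01-02', 'vitals': {'Pulse(bpm)': '80'}}]}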

def extract_structured_summary(processed_data: Dict[str, Any]) -> str:
    summary_parts = []
    summary_parts.append("Patient Baseline Profile:")
    demographics = processed_data.get('demographics', {})
    age = demographics.get('age', 'N/A')
    gender = demographics.get('gender', 'N/A')
    summary_parts.append(f"- Demographics: {age} y/o {gender}")
    pmh = processed_data.get('pastMedicalHistory', [])
    if pmh:
        summary_parts.append(f"- Past Medical History: {', '.join(pmh)}")
    allergies = processed_data.get('allergies', [])
    if allergies:
        summary_parts.append(f"- Allergies: {', '.join(allergies)}")
    visits = processed_data.get('visits', [])
    if visits:
        sorted_visits = sorted(visits, key=lambda v: v.get('chartdate', ''))
        historical_visits = sorted_visits[:-1] if len(sorted_visits) > 1 else []
        if historical_visits:
            summary_parts.append("\nLongitudinal Visit History:")
            for visit in historical_visits:
                visit_date = visit.get('chartdate', 'N/A')
                summary_parts.append(f"\n- Date: {visit_date}")
                vitals = visit.get('vitals', {})
                if vitals:
                    vitals_str = ", ".join([f"{k}: {v}" for k, v in vitals.items()])
                    summary_parts.append(f" - Vitals: {vitals_str}")
                diagnoses = visit.get('diagnosis', [])
                if diagnoses:
                    summary_parts.append(f" - Diagnoses: {', '.join(diagnoses)}")
                medications = visit.get('medications', [])
                if medications:
                    summary_parts.append(f" - Medications: {', '.join(medications)}")
        if sorted_visits:
            current_visit = sorted_visits[-1]
            summary_parts.append("\nCurrent Visit Details:")
            current_date = current_visit.get('chartdate', 'N/A')
            summary_parts.append(f"- Date: {current_date}")
            complaint = current_visit.get('chiefComplaint', 'Not specified')
            summary_parts.append(f"- Chief Complaint: {complaint}")
            symptoms = current_visit.get('symptoms', [])
            if symptoms:
                summary_parts.append(f"- Reported Symptoms: {', '.join(symptoms)}")
            vitals = current_visit.get('vitals', {})
            if vitals:
                vitals_str = ", ".join([f"{key}: {value}" for key, value in vitals.items()])
                summary_parts.append(f"- Vitals: {vitals_str}")
            diagnoses = current_visit.get('diagnosis', [])
            if diagnoses:
                summary_parts.append(f"- Diagnoses This Visit: {', '.join(diagnoses)}")
    return "\n".join(summary_parts)

def create_ai_prompt(processed_data: Dict[str, Any]) -> str:
    structured_text = extract_structured_summary(processed_data)
    visits = processed_data.get('visits', [])
    current_complaint = "Not specified"
    if visits:
        try:
            sorted_visits = sorted(visits, key=lambda v: v.get('chartdate', ''))
            if sorted_visits:
                current_complaint = sorted_visits[-1].get('chiefComplaint', 'Not specified')
        except Exception:
            pass
    prompt = f"""<|system|>
You are an expert clinical AI assistant. Your task is to generate a comprehensive patient summary by integrating the patient's baseline profile, longitudinal history, and their current visit details. Your analysis must be holistic, connecting past events with the current presentation. The final output MUST strictly follow the multi-part markdown structure below.
---
**PATIENT DATA FOR ANALYSIS:**
{structured_text}
---
**REQUIRED OUTPUT FORMAT:**
## Longitudinal Assessment
- **Baseline Health Status:** [Summarize the patient's core health profile including chronic comorbidities, relevant PMH, and habits.]
- **Key Historical Trends:** [Analyze trends from past visits. Comment on vital signs, consistency of chronic disease management, and recurring issues.]
## Current Visit Triage Assessment
**Chief Complaint:** {current_complaint}
**Clinical Findings:**
- **Primary Symptoms:** [List the key symptoms from the current visit.]
- **Objective Vitals:** [State the vitals and note any abnormalities.]
- **Diagnoses:** [List the diagnoses for this visit.]
## Synthesized Plan & Guidance
- **Integrated Assessment:** [Provide a short paragraph connecting the current complaint to the patient's baseline health.]
- **Medication Management:** [Comment on the overall medication regimen.]
- **Monitoring & Follow-up:** [Recommend specific parameters to monitor and suggest a clear follow-up timeline.]
## Clinical Recommendations
- **Primary Clinical Concern:** [State the most important issue to focus on.]
- **Potential Risks & Considerations:** [Identify key risks based on combined data.]
<|user|>
Generate a comprehensive patient summary in markdown format.
<|assistant|>
"""
    return prompt
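
# Illustrative preview (not part of the original script): because the prompt is plain
# text, it can be inspected without loading any model, e.g. with a hypothetical record:
#     record = {"patientName": "Jane Roe",
#               "visits": [{"chartdate": "2024-01-02", "chiefComplaint": "Headache"}]}
#     print(create_ai_prompt(process_patient_record_robust(record)))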

# ==========================================
# UNIFIED MODEL MANAGER (from unified_model_manager.py)
# ==========================================
import torch

class ModelType(Enum):
    TRANSFORMERS = "transformers"
    GGUF = "gguf"
    OPENVINO = "openvino"
    FALLBACK = "fallback"

class ModelStatus(Enum):
    UNINITIALIZED = "uninitialized"
    LOADING = "loading"
    LOADED = "loaded"
    ERROR = "error"

@dataclass
class GenerationConfig:
    max_tokens: int = 8192
    min_tokens: int = 50
    temperature: float = 0.3
    top_p: float = 0.9
    timeout: float = 180.0
    stream: bool = False

class BaseModel(ABC):
    def __init__(self, name: str, model_type: str, **kwargs):
        self.name = name
        self.model_type = model_type
        self._model = None
        self._status = ModelStatus.UNINITIALIZED
        self._kwargs = kwargs

    @property
    def status(self) -> ModelStatus:
        return self._status

    @abstractmethod
    def _load_implementation(self) -> bool:
        pass

    def load(self):
        if self._status == ModelStatus.LOADED:
            return self
        try:
            self._status = ModelStatus.LOADING
            logger.info(f"Loading model: {self.name} ({self.model_type})")
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            if self._load_implementation():
                self._status = ModelStatus.LOADED
                logger.info(f"Model {self.name} loaded successfully")
                return self
            else:
                self._status = ModelStatus.ERROR
                return None
        except Exception as e:
            self._status = ModelStatus.ERROR
            logger.error(f"Failed to load model {self.name}: {e}")
            return None

    @abstractmethod
    def generate(self, prompt: str, config: GenerationConfig) -> str:
        pass

class OpenVINOModel(BaseModel):
    def __init__(self, name: str, model_type: str, **kwargs):
        super().__init__(name, model_type, **kwargs)
        self._tokenizer = None

    def _load_implementation(self) -> bool:
        try:
            from optimum.intel import OVModelForCausalLM
            from transformers import AutoTokenizer
            model_kwargs = get_t4_model_kwargs("openvino")
            model_path = self.name
            tokenizer_path = self.name
            if "OpenVINO/" in self.name:
                if "Phi-3-mini-4k-instruct" in self.name:
                    tokenizer_path = "microsoft/Phi-3-mini-4k-instruct"
            logger.info(f"Loading OpenVINO model from {model_path} with tokenizer from {tokenizer_path}")
            # Note: several T4_OPTIMIZATIONS keys (e.g. load_in_4bit, device_map) are
            # transformers/bitsandbytes-specific and may be ignored or rejected by
            # OVModelForCausalLM; filter model_kwargs if loading fails. Also note that
            # OpenVINO's "GPU" device targets Intel GPUs, so on a CUDA-only Colab
            # runtime "CPU" may be the only working device.
            self._model = OVModelForCausalLM.from_pretrained(
                model_path,
                device="GPU" if torch.cuda.is_available() else "CPU",
                **model_kwargs
            )
            self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            return True
        except Exception as e:
            logger.error(f"Failed to load OpenVINO model {self.name}: {e}")
            return False

    def generate(self, prompt: str, config: GenerationConfig) -> str:
        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Model not loaded")
        try:
            # Keep inputs as CPU tensors: the OpenVINO runtime handles device
            # placement itself, so moving them to CUDA is unnecessary here.
            inputs = self._tokenizer(prompt, return_tensors="pt")
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=min(config.max_tokens, 8192),
                temperature=config.temperature,
                top_p=config.top_p,
                do_sample=config.temperature > 0.1,
                pad_token_id=self._tokenizer.eos_token_id
            )
            generated_text = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
            if generated_text.startswith(prompt):
                generated_text = generated_text[len(prompt):].strip()
            return generated_text
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise

class UnifiedModelManager:
    def __init__(self):
        self._models = {}

    def get_model(self, name: str, model_type: str = None, lazy: bool = True, **kwargs) -> BaseModel:
        if model_type is None:
            model_type = detect_model_type(name)
        cache_key = f"{name}:{model_type}"
        if cache_key in self._models:
            return self._models[cache_key]
        model_kwargs = get_t4_model_kwargs(model_type)
        model_kwargs.update(kwargs)
        if model_type == "openvino" or model_type == "causal-openvino":
            model = OpenVINOModel(name, model_type, **model_kwargs)
        else:
            # Fallback for this script
            raise ValueError(f"Model type {model_type} not implemented in this script")
        self._models[cache_key] = model
        if not lazy:
            model.load()
        return model

unified_model_manager = UnifiedModelManager()
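
# Illustrative usage (not part of the original script): the manager caches models by
# "<name>:<type>" and, with lazy=True, defers loading until load()/generate() is called.
# ov_model = unified_model_manager.get_model(
#     "microsoft/Phi-3-mini-4k-instruct", "causal-openvino", lazy=True
# )
# print(ov_model.status)  # ModelStatus.UNINITIALIZED until load() succeeds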

# ==========================================
# PATIENT SUMMARIZER AGENT (from patient_summary_agent.py)
# ==========================================
class PatientSummarizerAgent:
    def __init__(self, model_name: str = None, model_type: str = None):
        self.current_model_name = model_name
        self.current_model_type = model_type
        self.model_loader = None

    def configure_model(self, model_name: str, model_type: str = None):
        self.current_model_name = model_name
        self.current_model_type = model_type or detect_model_type(model_name)
        self.model_loader = unified_model_manager.get_model(
            self.current_model_name,
            self.current_model_type,
            lazy=True
        )
        return self.model_loader

    def generate_patient_summary(self, patient_data: Union[List[str], Dict]) -> str:
        if not self.model_loader:
            self.configure_model(self.current_model_name, self.current_model_type)
        if self.model_loader.status != ModelStatus.LOADED:
            self.model_loader.load()
        # Process data
        if isinstance(patient_data, dict):
            processed_data = process_patient_record_robust(patient_data)
            prompt = create_ai_prompt(processed_data)
        else:
            raise ValueError("Patient data must be a dictionary")
        # Generate
        gen_config = get_t4_generation_config(self.current_model_type)
        # Map the transformers-style settings onto the GenerationConfig dataclass fields;
        # passing the raw dict directly would fail on unexpected keys such as "use_cache".
        config = GenerationConfig(
            max_tokens=gen_config["generation_config"]["max_new_tokens"],
            temperature=gen_config["temperature"]
        )
        result = self.model_loader.generate(prompt, config)
        return result

# ==========================================
# MAIN EXECUTION
# ==========================================
if __name__ == "__main__":
    # Sample Patient Data
    sample_patient_data = {
        "patientName": "John Doe",
        "patientNumber": "12345",
        "demographics": {
            "age": "65",
            "gender": "Male",
            "dob": "1958-05-15"
        },
        "pastMedicalHistory": [
            "Hypertension",
            "Type 2 Diabetes",
            "Hyperlipidemia"
        ],
        "allergies": [
            "Penicillin"
        ],
        "visits": [
            {
                "chartdate": "2023-01-15",
                "chiefComplaint": "Routine checkup",
                "vitals": {
                    "Bp(sys)(mmHg)": "130",
                    "Bp(dia)(mmHg)": "85",
                    "Pulse(bpm)": "72"
                },
                "diagnosis": ["Hypertension", "Type 2 Diabetes"],
                "medications": ["Lisinopril 10mg", "Metformin 500mg"]
            },
            {
                "chartdate": "2023-06-20",
                "chiefComplaint": "Dizziness and fatigue",
                "vitals": {
                    "Bp(sys)(mmHg)": "110",
                    "Bp(dia)(mmHg)": "70",
                    "Pulse(bpm)": "65"
                },
                "diagnosis": ["Dehydration", "Hypotension"],
                "medications": ["Lisinopril held", "Metformin 500mg"]
            }
        ]
    }

    print("Initializing PatientSummarizerAgent...")
    agent = PatientSummarizerAgent(
        model_name="microsoft/Phi-3-mini-4k-instruct",
        model_type="causal-openvino"
    )
    print("Generating summary...")
    try:
        summary = agent.generate_patient_summary(sample_patient_data)
        print("\n" + "="*50)
        print("GENERATED PATIENT SUMMARY")
        print("="*50)
        print(summary)
    except Exception as e:
        print(f"Error generating summary: {e}")