# colab_patient_summary_script.py (HNTAI)
# Colab patient summary script with AI service utilities for performance.
# @title Install Dependencies
# Run this cell first to install necessary packages
import subprocess
import sys
def install_dependencies():
    packages = [
        "torch",
        "transformers",
        "optimum",
        "optimum-intel",
        "openvino",
        "accelerate",
        "scipy"
    ]
    print(f"Installing packages: {', '.join(packages)}")
    subprocess.check_call([sys.executable, "-m", "pip", "install"] + packages)
    print("Dependencies installed successfully.")
# Uncomment the line below to install dependencies in Colab
# install_dependencies()
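# Optional sanity check (a sketch, not part of the original script): after installing,
# confirm the core libraries import cleanly before running the rest of the notebook.
# The module list mirrors the packages above; uncomment to run it in Colab.
# def check_core_imports():
#     import importlib
#     for module_name in ("torch", "transformers", "optimum.intel", "openvino"):
#         importlib.import_module(module_name)
#         print(f"{module_name} imported OK")
# check_core_imports()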
import os
import gc
import time
import logging
import json
import re
import warnings
import datetime
from typing import List, Dict, Union, Optional, Any, Tuple
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from textwrap import fill
import concurrent.futures
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)
# ==========================================
# MOCK PERFORMANCE MONITOR
# ==========================================
# No-op stand-ins for the decorators provided by the full app's performance monitor;
# they keep the copied function signatures importable without pulling in that module.
def cached_robust_parsing(func):
    return func

def track_robust_processing(func):
    return func

def track_prompt_generation(func):
    return func
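# Illustrative use of the stand-ins above (the wrapped function here is hypothetical):
# in the full application these decorators add caching and timing, while here they
# simply hand back the function unchanged.
@track_robust_processing
def _decorated_passthrough(value):
    return value

assert _decorated_passthrough("ok") == "ok"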
# ==========================================
# MODEL CONFIGURATION (from model_config.py)
# ==========================================
# Detect if running on Hugging Face Spaces
IS_HF_SPACES = os.getenv("HUGGINGFACE_SPACES", "").lower() == "true"
IS_T4_MEDIUM = IS_HF_SPACES and os.getenv("SPACES_MACHINE", "").lower() == "t4-medium"
# T4 Medium optimizations
T4_OPTIMIZATIONS = {
    "max_memory_mb": 14000,
    "use_quantization": True,
    "load_in_4bit": True,
    "torch_dtype": "float16",
    "device_map": "auto",
    "trust_remote_code": True,
    "cache_dir": "/tmp/hf_cache",
    "local_files_only": False
}
# Model generation settings
GENERATION_CONFIG = {
    "use_cache": True,
    "max_length": 8192,
    "temperature": 0.1,
    "num_return_sequences": 1,
    "do_sample": False,
    "pad_token_id": 0,
    "generation_config": {
        "use_cache": True,
        "max_new_tokens": 8192,
        "do_sample": False,
        "temperature": 0.1
    }
}
# Default models
DEFAULT_MODELS = {
    "text-generation": {
        "primary": "microsoft/DialoGPT-small",
        "fallback": "facebook/bart-base",
    },
    "summarization": {
        "primary": "sshleifer/distilbart-cnn-6-6",
        "fallback": "facebook/bart-base",
    },
    "openvino": {
        "primary": "microsoft/Phi-3-mini-4k-instruct",
        "fallback": "OpenVINO/Phi-3-mini-4k-instruct-fp16-ov",
    },
    "causal-openvino": {
        "primary": "microsoft/Phi-3-mini-4k-instruct",
        "fallback": "OpenVINO/Phi-3-mini-4k-instruct-fp16-ov",
    }
}
MODEL_TYPE_MAPPINGS = {
    ".gguf": "gguf",
    "gguf": "gguf",
    "openvino": "openvino",
    "ov": "openvino",
    "causal-openvino": "causal-openvino",
    "text-generation": "text-generation",
    "summarization": "summarization",
    "instruct": "text-generation",
}
MODEL_TOKEN_LIMITS = {
    "microsoft/Phi-3-mini-4k-instruct": 8192,
    "OpenVINO/Phi-3-mini-4k-instruct-fp16-ov": 8192,
    "default": 4096
}
def get_model_token_limit(model_name: str) -> int:
    if model_name in MODEL_TOKEN_LIMITS:
        return MODEL_TOKEN_LIMITS[model_name]
    if "128k" in model_name.lower():
        return 131072
    elif "8k" in model_name.lower():
        return 8192
    elif "4k" in model_name.lower():
        return 4096
    return MODEL_TOKEN_LIMITS["default"]
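# Quick illustration of the lookup above: exact model ids hit the table first, then the
# "128k"/"8k"/"4k" substrings decide, and anything else gets the default. The last two
# model ids are made up purely to show the fallbacks.
assert get_model_token_limit("microsoft/Phi-3-mini-4k-instruct") == 8192
assert get_model_token_limit("example/some-8k-model") == 8192
assert get_model_token_limit("example/unknown-model") == 4096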
def get_t4_model_kwargs(model_type: str) -> dict:
    # Always return T4 optimizations for Colab usage to be safe/efficient
    base_kwargs = T4_OPTIMIZATIONS.copy()
    if model_type in ["summarization", "seq2seq", "text-generation"]:
        base_kwargs.update({
            "load_in_4bit": True,
            "bnb_4bit_compute_dtype": "float16",
            "bnb_4bit_use_double_quant": True,
            "bnb_4bit_quant_type": "nf4"
        })
    return base_kwargs
def get_t4_generation_config(model_type: str) -> dict:
    # Copy the nested dict as well, so the module-level GENERATION_CONFIG is never mutated.
    config = GENERATION_CONFIG.copy()
    config["generation_config"] = GENERATION_CONFIG["generation_config"].copy()
    config["max_length"] = 8192
    config["generation_config"]["max_new_tokens"] = 8192
    return config
def is_model_supported_on_t4(model_name: str, model_type: str) -> bool:
    return True
def detect_model_type(model_name: str) -> str:
    model_name_lower = model_name.lower()
    for indicator, model_type in MODEL_TYPE_MAPPINGS.items():
        if indicator in model_name_lower:
            return model_type
    return "text-generation"
# ==========================================
# ROBUST JSON PARSER (from robust_json_parser.py)
# ==========================================
def safe_get(data_dict: Dict[str, Any], key_aliases: List[str]) -> Optional[Any]:
    if not isinstance(data_dict, dict):
        return None
    for alias in key_aliases:
        for key, value in data_dict.items():
            if key.lower() == alias.lower():
                return value
    return None
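# safe_get in action on toy dicts: aliases are tried in order and keys are matched
# case-insensitively, with None returned when nothing matches.
assert safe_get({"ChartDate": "2023-01-15"}, ["chartdate", "date"]) == "2023-01-15"
assert safe_get({"sex": "F"}, ["gender", "sex"]) == "F"
assert safe_get({"unrelated": 1}, ["vitals"]) is None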
def normalize_visit_data(visit: Dict[str, Any]) -> Dict[str, Any]:
    if not isinstance(visit, dict):
        return {}
    normalized = {}
    date_value = safe_get(visit, ['chartdate', 'date', 'visitDate', 'encounterDate'])
    if date_value:
        normalized['chartdate'] = str(date_value)[:10]
    vitals = safe_get(visit, ['vitals', 'vitalSigns', 'vital_signs'])
    if vitals:
        if isinstance(vitals, dict):
            normalized['vitals'] = vitals
        elif isinstance(vitals, list):
            vitals_dict = {}
            for item in vitals:
                if isinstance(item, str) and ':' in item:
                    key, value = item.split(':', 1)
                    vitals_dict[key.strip()] = value.strip()
            normalized['vitals'] = vitals_dict
    diagnoses = safe_get(visit, ['diagnoses', 'diagnosis', 'conditions'])
    if diagnoses:
        if isinstance(diagnoses, list):
            normalized['diagnosis'] = [str(d).strip() for d in diagnoses if d]
        elif isinstance(diagnoses, str):
            normalized['diagnosis'] = [diagnoses.strip()]
    medications = safe_get(visit, ['medications', 'meds', 'prescriptions'])
    if medications:
        if isinstance(medications, list):
            normalized['medications'] = [str(m).strip() for m in medications if m]
        elif isinstance(medications, str):
            normalized['medications'] = [medications.strip()]
    complaint = safe_get(visit, ['chiefComplaint', 'reasonForVisit', 'chief_complaint'])
    if complaint:
        normalized['chiefComplaint'] = str(complaint).strip()
    symptoms = safe_get(visit, ['symptoms', 'reportedSymptoms'])
    if symptoms:
        if isinstance(symptoms, list):
            normalized['symptoms'] = [str(s).strip() for s in symptoms if s]
        elif isinstance(symptoms, str):
            normalized['symptoms'] = [symptoms.strip()]
    return normalized
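# Normalization example with a made-up visit: list-style vitals such as "BP: 120/80"
# are folded into a dict, the date is truncated to YYYY-MM-DD, and a single diagnosis
# string is wrapped in a list.
_example_visit = normalize_visit_data({
    "visitDate": "2023-06-20T09:30:00",
    "vitalSigns": ["BP: 120/80", "Pulse: 72"],
    "diagnosis": "Hypertension",
})
assert _example_visit["chartdate"] == "2023-06-20"
assert _example_visit["vitals"] == {"BP": "120/80", "Pulse": "72"}
assert _example_visit["diagnosis"] == ["Hypertension"]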
def process_patient_record_robust(patient_data: Dict[str, Any]) -> Dict[str, Any]:
    if not isinstance(patient_data, dict):
        return {"error": "Invalid patient data format"}
    processed = {}
    demographics = safe_get(patient_data, ['demographics', 'patientInfo', 'patient_info'])
    if demographics and isinstance(demographics, dict):
        processed['demographics'] = {
            'age': safe_get(demographics, ['age', 'yearsOld']),
            'gender': safe_get(demographics, ['gender', 'sex']),
            'dob': safe_get(demographics, ['dob', 'dateOfBirth'])
        }
    processed['patientName'] = safe_get(patient_data, ['patientName', 'patient_name', 'name'])
    processed['patientNumber'] = safe_get(patient_data, ['patientNumber', 'patient_number', 'id'])
    pmh = safe_get(patient_data, ['pastMedicalHistory', 'pmh', 'medical_history'])
    if pmh:
        processed['pastMedicalHistory'] = pmh if isinstance(pmh, list) else [pmh]
    allergies = safe_get(patient_data, ['allergies', 'allergyInfo'])
    if allergies:
        processed['allergies'] = allergies if isinstance(allergies, list) else [allergies]
    visits = safe_get(patient_data, ['visits', 'encounters', 'appointments'])
    if visits and isinstance(visits, list):
        processed_visits = []
        for visit in visits:
            if isinstance(visit, dict):
                normalized_visit = normalize_visit_data(visit)
                if normalized_visit:
                    processed_visits.append(normalized_visit)
        processed['visits'] = processed_visits
    return processed
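# The record-level processor tolerates alternate key names (sketch with made-up data):
# "encounters" stands in for "visits", and nested demographic aliases are resolved.
_example_record = process_patient_record_robust({
    "patient_name": "Jane Roe",
    "patientInfo": {"yearsOld": 54, "sex": "Female"},
    "encounters": [{"date": "2024-02-01", "chiefComplaint": "Cough"}],
})
assert _example_record["patientName"] == "Jane Roe"
assert _example_record["demographics"]["age"] == 54
assert _example_record["visits"][0]["chiefComplaint"] == "Cough"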
def extract_structured_summary(processed_data: Dict[str, Any]) -> str:
    summary_parts = []
    summary_parts.append("Patient Baseline Profile:")
    demographics = processed_data.get('demographics', {})
    age = demographics.get('age', 'N/A')
    gender = demographics.get('gender', 'N/A')
    summary_parts.append(f"- Demographics: {age} y/o {gender}")
    pmh = processed_data.get('pastMedicalHistory', [])
    if pmh:
        # Coerce entries to str so unexpected non-string items cannot break the join.
        summary_parts.append(f"- Past Medical History: {', '.join(str(p) for p in pmh)}")
    allergies = processed_data.get('allergies', [])
    if allergies:
        summary_parts.append(f"- Allergies: {', '.join(str(a) for a in allergies)}")
    visits = processed_data.get('visits', [])
    if visits:
        sorted_visits = sorted(visits, key=lambda v: v.get('chartdate', ''))
        historical_visits = sorted_visits[:-1] if len(sorted_visits) > 1 else []
        if historical_visits:
            summary_parts.append("\nLongitudinal Visit History:")
            for visit in historical_visits:
                visit_date = visit.get('chartdate', 'N/A')
                summary_parts.append(f"\n- Date: {visit_date}")
                vitals = visit.get('vitals', {})
                if vitals:
                    vitals_str = ", ".join([f"{k}: {v}" for k, v in vitals.items()])
                    summary_parts.append(f" - Vitals: {vitals_str}")
                diagnoses = visit.get('diagnosis', [])
                if diagnoses:
                    summary_parts.append(f" - Diagnoses: {', '.join(diagnoses)}")
                medications = visit.get('medications', [])
                if medications:
                    summary_parts.append(f" - Medications: {', '.join(medications)}")
        if sorted_visits:
            current_visit = sorted_visits[-1]
            summary_parts.append("\nCurrent Visit Details:")
            current_date = current_visit.get('chartdate', 'N/A')
            summary_parts.append(f"- Date: {current_date}")
            complaint = current_visit.get('chiefComplaint', 'Not specified')
            summary_parts.append(f"- Chief Complaint: {complaint}")
            symptoms = current_visit.get('symptoms', [])
            if symptoms:
                summary_parts.append(f"- Reported Symptoms: {', '.join(symptoms)}")
            vitals = current_visit.get('vitals', {})
            if vitals:
                vitals_str = ", ".join([f"{key}: {value}" for key, value in vitals.items()])
                summary_parts.append(f"- Vitals: {vitals_str}")
            diagnoses = current_visit.get('diagnosis', [])
            if diagnoses:
                summary_parts.append(f"- Diagnoses This Visit: {', '.join(diagnoses)}")
    return "\n".join(summary_parts)
def create_ai_prompt(processed_data: Dict[str, Any]) -> str:
    structured_text = extract_structured_summary(processed_data)
    visits = processed_data.get('visits', [])
    current_complaint = "Not specified"
    if visits:
        try:
            sorted_visits = sorted(visits, key=lambda v: v.get('chartdate', ''))
            if sorted_visits:
                current_complaint = sorted_visits[-1].get('chiefComplaint', 'Not specified')
        except Exception:
            pass
    prompt = f"""<|system|>
You are an expert clinical AI assistant. Your task is to generate a comprehensive patient summary by integrating the patient's baseline profile, longitudinal history, and their current visit details. Your analysis must be holistic, connecting past events with the current presentation. The final output MUST strictly follow the multi-part markdown structure below.
---
**PATIENT DATA FOR ANALYSIS:**
{structured_text}
---
**REQUIRED OUTPUT FORMAT:**
## Longitudinal Assessment
- **Baseline Health Status:** [Summarize the patient's core health profile including chronic comorbidities, relevant PMH, and habits.]
- **Key Historical Trends:** [Analyze trends from past visits. Comment on vital signs, consistency of chronic disease management, and recurring issues.]
## Current Visit Triage Assessment
**Chief Complaint:** {current_complaint}
**Clinical Findings:**
- **Primary Symptoms:** [List the key symptoms from the current visit.]
- **Objective Vitals:** [State the vitals and note any abnormalities.]
- **Diagnoses:** [List the diagnoses for this visit.]
## Synthesized Plan & Guidance
- **Integrated Assessment:** [Provide a short paragraph connecting the current complaint to the patient's baseline health.]
- **Medication Management:** [Comment on the overall medication regimen.]
- **Monitoring & Follow-up:** [Recommend specific parameters to monitor and suggest a clear follow-up timeline.]
## Clinical Recommendations
- **Primary Clinical Concern:** [State the most important issue to focus on.]
- **Potential Risks & Considerations:** [Identify key risks based on combined data.]
<|user|>
Generate a comprehensive patient summary in markdown format.
<|assistant|>
"""
    return prompt
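# Prompt preview (illustrative, commented out): building the prompt is pure string work
# and needs no model, so it can be inspected before committing to a full generation run.
# The record passed in here is made up for demonstration.
# _preview_record = process_patient_record_robust({
#     "patientName": "Jane Roe",
#     "demographics": {"age": 54, "gender": "Female"},
#     "visits": [{"chartdate": "2024-02-01", "chiefComplaint": "Cough", "symptoms": ["Cough", "Fever"]}],
# })
# print(create_ai_prompt(_preview_record)[:800])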
# ==========================================
# UNIFIED MODEL MANAGER (from unified_model_manager.py)
# ==========================================
import torch
class ModelType(Enum):
    TRANSFORMERS = "transformers"
    GGUF = "gguf"
    OPENVINO = "openvino"
    FALLBACK = "fallback"

class ModelStatus(Enum):
    UNINITIALIZED = "uninitialized"
    LOADING = "loading"
    LOADED = "loaded"
    ERROR = "error"
@dataclass
class GenerationConfig:
    max_tokens: int = 8192
    min_tokens: int = 50
    temperature: float = 0.3
    top_p: float = 0.9
    timeout: float = 180.0
    stream: bool = False
class BaseModel(ABC):
    def __init__(self, name: str, model_type: str, **kwargs):
        self.name = name
        self.model_type = model_type
        self._model = None
        self._status = ModelStatus.UNINITIALIZED
        self._kwargs = kwargs

    @property
    def status(self) -> ModelStatus:
        return self._status

    @abstractmethod
    def _load_implementation(self) -> bool:
        pass

    def load(self):
        if self._status == ModelStatus.LOADED:
            return self
        try:
            self._status = ModelStatus.LOADING
            logger.info(f"Loading model: {self.name} ({self.model_type})")
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            if self._load_implementation():
                self._status = ModelStatus.LOADED
                logger.info(f"Model {self.name} loaded successfully")
                return self
            else:
                self._status = ModelStatus.ERROR
                return None
        except Exception as e:
            self._status = ModelStatus.ERROR
            logger.error(f"Failed to load model {self.name}: {e}")
            return None

    @abstractmethod
    def generate(self, prompt: str, config: GenerationConfig) -> str:
        pass
class OpenVINOModel(BaseModel):
    def __init__(self, name: str, model_type: str, **kwargs):
        super().__init__(name, model_type, **kwargs)
        self._tokenizer = None

    def _load_implementation(self) -> bool:
        try:
            from optimum.intel import OVModelForCausalLM
            from transformers import AutoTokenizer
            model_kwargs = get_t4_model_kwargs("openvino")
            # OVModelForCausalLM does not understand the bitsandbytes/torch-specific
            # options in T4_OPTIMIZATIONS, so pass through only the loader-safe ones.
            ov_kwargs = {
                key: model_kwargs[key]
                for key in ("trust_remote_code", "cache_dir", "local_files_only")
                if key in model_kwargs
            }
            model_path = self.name
            tokenizer_path = self.name
            if "OpenVINO/" in self.name:
                if "Phi-3-mini-4k-instruct" in self.name:
                    # The pre-exported OpenVINO repo reuses the original model's tokenizer.
                    tokenizer_path = "microsoft/Phi-3-mini-4k-instruct"
            logger.info(f"Loading OpenVINO model from {model_path} with tokenizer from {tokenizer_path}")
            # OpenVINO's "GPU" device targets Intel GPUs, so CUDA availability is not a
            # useful signal here; stay on CPU by default. Models that are not already
            # exported to OpenVINO IR are converted on the fly via export=True.
            needs_export = "openvino" not in self.name.lower() and not self.name.lower().endswith("-ov")
            self._model = OVModelForCausalLM.from_pretrained(
                model_path,
                export=needs_export,
                device="CPU",
                **ov_kwargs
            )
            self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            return True
        except Exception as e:
            logger.error(f"Failed to load OpenVINO model {self.name}: {e}")
            return False
    def generate(self, prompt: str, config: GenerationConfig) -> str:
        if self._model is None or self._tokenizer is None:
            raise RuntimeError("Model not loaded")
        try:
            # OpenVINO handles device placement internally, so input tensors stay on CPU.
            inputs = self._tokenizer(prompt, return_tensors="pt")
            outputs = self._model.generate(
                **inputs,
                max_new_tokens=min(config.max_tokens, 8192),
                temperature=config.temperature,
                top_p=config.top_p,
                do_sample=config.temperature > 0.1,
                pad_token_id=self._tokenizer.eos_token_id
            )
            generated_text = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Strip the echoed prompt so only the newly generated text is returned.
            if generated_text.startswith(prompt):
                generated_text = generated_text[len(prompt):].strip()
            return generated_text
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            raise
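# Minimal standalone smoke test (commented out; a sketch that assumes the same
# optimum-intel API used above): load the pre-exported OpenVINO Phi-3 repo directly
# and generate a few tokens, bypassing the wrapper classes.
# from optimum.intel import OVModelForCausalLM
# from transformers import AutoTokenizer
# _smoke_tok = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
# _smoke_model = OVModelForCausalLM.from_pretrained("OpenVINO/Phi-3-mini-4k-instruct-fp16-ov", device="CPU")
# _smoke_inputs = _smoke_tok("Summarize: patient is stable.", return_tensors="pt")
# print(_smoke_tok.decode(_smoke_model.generate(**_smoke_inputs, max_new_tokens=32)[0], skip_special_tokens=True))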
class UnifiedModelManager:
    def __init__(self):
        self._models = {}

    def get_model(self, name: str, model_type: Optional[str] = None, lazy: bool = True, **kwargs) -> BaseModel:
        if model_type is None:
            model_type = detect_model_type(name)
        cache_key = f"{name}:{model_type}"
        if cache_key in self._models:
            return self._models[cache_key]
        model_kwargs = get_t4_model_kwargs(model_type)
        model_kwargs.update(kwargs)
        if model_type == "openvino" or model_type == "causal-openvino":
            model = OpenVINOModel(name, model_type, **model_kwargs)
        else:
            # Fallback for this script
            raise ValueError(f"Model type {model_type} not implemented in this script")
        self._models[cache_key] = model
        if not lazy:
            model.load()
        return model
unified_model_manager = UnifiedModelManager()
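# Illustrative check of the manager's behavior: models are constructed lazily (no
# download or load happens here) and cached by name/type, so repeated lookups return
# the same wrapper object.
_ov_wrapper = unified_model_manager.get_model("OpenVINO/Phi-3-mini-4k-instruct-fp16-ov", "openvino")
assert _ov_wrapper.status == ModelStatus.UNINITIALIZED
assert unified_model_manager.get_model("OpenVINO/Phi-3-mini-4k-instruct-fp16-ov", "openvino") is _ov_wrapper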
# ==========================================
# PATIENT SUMMARIZER AGENT (from patient_summary_agent.py)
# ==========================================
class PatientSummarizerAgent:
    def __init__(self, model_name: Optional[str] = None, model_type: Optional[str] = None):
        self.current_model_name = model_name
        self.current_model_type = model_type
        self.model_loader = None

    def configure_model(self, model_name: str, model_type: Optional[str] = None):
        self.current_model_name = model_name
        self.current_model_type = model_type or detect_model_type(model_name)
        self.model_loader = unified_model_manager.get_model(
            self.current_model_name,
            self.current_model_type,
            lazy=True
        )
        return self.model_loader

    def generate_patient_summary(self, patient_data: Union[List[str], Dict]) -> str:
        if not self.model_loader:
            self.configure_model(self.current_model_name, self.current_model_type)
        if self.model_loader.status != ModelStatus.LOADED:
            self.model_loader.load()
        if self.model_loader.status != ModelStatus.LOADED:
            raise RuntimeError(f"Model {self.current_model_name} failed to load")
        # Process data
        if isinstance(patient_data, dict):
            processed_data = process_patient_record_robust(patient_data)
            prompt = create_ai_prompt(processed_data)
        else:
            raise ValueError("Patient data must be a dictionary")
        # Generate
        # GENERATION_CONFIG uses transformers-style keys, so map the relevant values onto
        # the GenerationConfig dataclass instead of unpacking the dict directly.
        gen_dict = get_t4_generation_config(self.current_model_type)
        config = GenerationConfig(
            max_tokens=gen_dict.get("generation_config", {}).get("max_new_tokens", 8192),
            temperature=gen_dict.get("temperature", 0.1),
        )
        result = self.model_loader.generate(prompt, config)
        return result
# ==========================================
# MAIN EXECUTION
# ==========================================
if __name__ == "__main__":
    # Sample Patient Data
    sample_patient_data = {
        "patientName": "John Doe",
        "patientNumber": "12345",
        "demographics": {
            "age": "65",
            "gender": "Male",
            "dob": "1958-05-15"
        },
        "pastMedicalHistory": [
            "Hypertension",
            "Type 2 Diabetes",
            "Hyperlipidemia"
        ],
        "allergies": [
            "Penicillin"
        ],
        "visits": [
            {
                "chartdate": "2023-01-15",
                "chiefComplaint": "Routine checkup",
                "vitals": {
                    "Bp(sys)(mmHg)": "130",
                    "Bp(dia)(mmHg)": "85",
                    "Pulse(bpm)": "72"
                },
                "diagnosis": ["Hypertension", "Type 2 Diabetes"],
                "medications": ["Lisinopril 10mg", "Metformin 500mg"]
            },
            {
                "chartdate": "2023-06-20",
                "chiefComplaint": "Dizziness and fatigue",
                "vitals": {
                    "Bp(sys)(mmHg)": "110",
                    "Bp(dia)(mmHg)": "70",
                    "Pulse(bpm)": "65"
                },
                "diagnosis": ["Dehydration", "Hypotension"],
                "medications": ["Lisinopril held", "Metformin 500mg"]
            }
        ]
    }

    print("Initializing PatientSummarizerAgent...")
    agent = PatientSummarizerAgent(
        model_name="microsoft/Phi-3-mini-4k-instruct",
        model_type="causal-openvino"
    )

    print("Generating summary...")
    try:
        summary = agent.generate_patient_summary(sample_patient_data)
        print("\n" + "="*50)
        print("GENERATED PATIENT SUMMARY")
        print("="*50)
        print(summary)
    except Exception as e:
        print(f"Error generating summary: {e}")