Spaces:
Sleeping
Sleeping
| """ | |
| Vera - AI Coaching Dashboard | |
| A real-time speech emotion analysis tool for coaching sessions. | |
| """ | |
import os

# Point the Hugging Face / XDG caches at a directory that is writable inside
# the Space. These variables must be set before transformers is first
# imported (below), since the original code relies on that ordering.
os.environ["HF_HOME"] = "/app/cache"
os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
os.environ["XDG_CACHE_HOME"] = "/app/cache"
# Make sure the cache directory exists before any download is attempted.
os.makedirs("/app/cache", exist_ok=True)

# Imported only after the cache variables above are in place.
from transformers import pipeline  # noqa: E402

# NOTE: no model is loaded at module level. SentimentAnalyzer builds its own
# pipeline; a module-level pipeline(...) here was never referenced, yet it
# would be re-created each time Streamlit re-executes this script.
import html
import io
import logging
import re
import threading
import time
import wave
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import pandas as pd
import plotly.graph_objects as go
import pyaudio
import streamlit as st
from dotenv import load_dotenv
from openai import OpenAI
from transformers import pipeline
# Module-wide logging: INFO and above go through the root logger's handler.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
# Populate os.environ from a local .env file when one exists (used for
# settings such as OPENAI_API_KEY, which the dashboard reads from the env).
load_dotenv()
@dataclass
class SentimentResult:
    """Data class for sentiment analysis results.

    Values are normalised on construction: any label other than POSITIVE,
    NEGATIVE or NEUTRAL becomes NEUTRAL, and the score is clamped to
    [0.0, 1.0].
    """

    label: str  # "POSITIVE", "NEGATIVE" or "NEUTRAL" after validation
    score: float  # confidence score, clamped to [0.0, 1.0]

    def __post_init__(self):
        """Validate sentiment result (runs automatically after __init__)."""
        if self.label not in ("POSITIVE", "NEGATIVE", "NEUTRAL"):
            self.label = "NEUTRAL"
        self.score = max(0.0, min(1.0, self.score))
@dataclass
class TranscriptionEntry:
    """Data class for a single transcription entry.

    One transcribed audio chunk, its sentiment, and the time the entry was
    created.
    """

    text: str
    sentiment: "SentimentResult"
    timestamp: datetime
class AudioConfig:
    """Audio capture settings shared by the recorder and the WAV encoder.

    Args:
        chunk_duration: Seconds of audio grouped into one transcription chunk.
        sample_rate: Sampling rate in Hz.
        channels: Number of input channels.
        chunk_size: Frames requested from the input stream per read.
        format: PyAudio sample-format constant (paInt16 by default).
    """

    def __init__(
        self,
        chunk_duration: int = 3,
        sample_rate: int = 16000,
        channels: int = 1,
        chunk_size: int = 1024,
        format: int = pyaudio.paInt16
    ):
        (
            self.chunk_duration,
            self.sample_rate,
            self.channels,
            self.chunk_size,
            self.format,
        ) = (chunk_duration, sample_rate, channels, chunk_size, format)
class SentimentAnalyzer:
    """Handles sentiment analysis with enhanced neutral detection.

    Wraps a Hugging Face ``sentiment-analysis`` pipeline and relabels a
    prediction as NEUTRAL when the text contains a hedging keyword, the
    model's confidence is below CONFIDENCE_THRESHOLD, or the text has fewer
    than MIN_WORD_COUNT words.
    """

    NEUTRAL_KEYWORDS = [
        'okay', 'ok', 'fine', 'alright', 'whatever', 'maybe', 'perhaps',
        'guess', 'not sure', "don't know", 'dunno', 'meh', 'so-so',
        'neither', 'middle', 'normal', 'average', 'moderate', 'fair'
    ]
    CONFIDENCE_THRESHOLD = 0.8
    MIN_WORD_COUNT = 3

    def __init__(self, model_name: str = "distilbert-base-uncased-finetuned-sst-2-english"):
        """Initialize sentiment analyzer with specified model.

        Args:
            model_name: Model id passed to ``transformers.pipeline``.
        """
        self.model = pipeline("sentiment-analysis", model=model_name)

    def analyze(self, text: str) -> "SentimentResult":
        """
        Analyze sentiment of text with enhanced neutral detection.

        Args:
            text: Input text to analyze

        Returns:
            SentimentResult with label and confidence score. Empty/blank
            input, or any model error, yields NEUTRAL with score 0.5.
        """
        if not text or not text.strip():
            return SentimentResult(label="NEUTRAL", score=0.5)

        try:
            # Truncates characters (not tokens) as a rough guard against the
            # model's input-length limit.
            result = self.model(text[:512])[0]
            label = result["label"]
            score = result["score"]

            if self._should_be_neutral(text, score):
                return SentimentResult(label="NEUTRAL", score=score)
            return SentimentResult(label=label, score=score)
        except Exception as e:
            logger.error(f"Sentiment analysis error: {e}")
            return SentimentResult(label="NEUTRAL", score=0.5)

    def _should_be_neutral(self, text: str, score: float) -> bool:
        """Decide whether a prediction should be reported as NEUTRAL.

        Args:
            text: The analysed text.
            score: The model's confidence in its predicted label.

        Returns:
            True if the text contains a neutral keyword, the score is below
            CONFIDENCE_THRESHOLD, or the text has fewer than MIN_WORD_COUNT
            whitespace-separated words.
        """
        return (
            self._has_neutral_keyword(text)
            or score < self.CONFIDENCE_THRESHOLD
            or len(text.split()) < self.MIN_WORD_COUNT
        )

    def _has_neutral_keyword(self, text: str) -> bool:
        """Return True if *text* contains a NEUTRAL_KEYWORDS entry as a whole word/phrase.

        Matching is case-insensitive and anchored on word boundaries, so
        words that merely contain a keyword ("book" / "ok", "unfair" /
        "fair", "define" / "fine") do not count.
        """
        if not self.NEUTRAL_KEYWORDS:
            return False
        alternation = "|".join(re.escape(keyword) for keyword in self.NEUTRAL_KEYWORDS)
        return re.search(rf"\b(?:{alternation})\b", text, flags=re.IGNORECASE) is not None
@st.cache_resource(show_spinner=False)
def get_sentiment_analyzer() -> SentimentAnalyzer:
    """Get cached sentiment analyzer instance.

    ``st.cache_resource`` makes every session and rerun share a single
    SentimentAnalyzer, so the model is loaded once per process instead of
    once per dashboard. The spinner is disabled to keep the original UI.
    NOTE(review): the shared instance may be called from several sessions'
    worker threads at once — confirm the pipeline tolerates that.
    """
    return SentimentAnalyzer()
class AudioTranscriber:
    """Turns raw PCM audio chunks into text using OpenAI's ``whisper-1`` model."""

    def __init__(self, client: OpenAI, audio_config: AudioConfig):
        """
        Set up the transcriber.

        Args:
            client: OpenAI client used for transcription requests
            audio_config: Settings describing the raw audio (channels,
                sample format, sample rate)
        """
        self.client = client
        self.audio_config = audio_config
        # PyAudio handle; used to look up the sample width for the WAV header.
        self._audio = pyaudio.PyAudio()

    def transcribe(self, audio_data: bytes) -> Optional[str]:
        """
        Send one chunk of raw audio to Whisper and return the text.

        Args:
            audio_data: Raw PCM bytes matching ``audio_config``

        Returns:
            The stripped transcript, or None when the API returns no text or
            any step fails (the error is logged).
        """
        try:
            wav_bytes = self._create_wav_buffer(audio_data).read()
            reply = self.client.audio.transcriptions.create(
                model="whisper-1",
                file=("audio.wav", wav_bytes, "audio/wav"),
                language="en",
            )
            if not reply.text:
                return None
            return reply.text.strip()
        except Exception as e:
            logger.error(f"Transcription error: {e}")
            return None

    def _create_wav_buffer(self, audio_data: bytes) -> io.BytesIO:
        """Wrap raw PCM bytes in an in-memory WAV file, rewound to offset 0."""
        cfg = self.audio_config
        buffer = io.BytesIO()
        with wave.open(buffer, "wb") as writer:
            writer.setnchannels(cfg.channels)
            writer.setsampwidth(self._audio.get_sample_size(cfg.format))
            writer.setframerate(cfg.sample_rate)
            writer.writeframes(audio_data)
        buffer.seek(0)
        return buffer

    def cleanup(self):
        """Terminate the PyAudio handle created in __init__."""
        audio = self._audio
        if audio:
            audio.terminate()
class CoachingDashboard:
    """Main dashboard for real-time coaching emotion analysis.

    Owns the microphone input stream and two daemon worker threads: one reads
    audio into fixed-length chunks, the other transcribes each chunk and
    scores its sentiment. Results accumulate in ``entries`` for the UI.
    """

    # Upper bound (seconds) that stop_recording() waits for the recorder thread
    # to finish its in-flight stream.read() before the stream is closed.
    _RECORDER_JOIN_TIMEOUT = 2.0

    def __init__(
        self,
        chunk_duration: int = 3,
        sample_rate: int = 16000,
        max_history: int = 50
    ):
        """
        Initialize coaching dashboard.

        Args:
            chunk_duration: Duration of each audio chunk in seconds
            sample_rate: Audio sample rate in Hz
            max_history: Maximum number of transcriptions to keep

        Raises:
            ValueError: If OPENAI_API_KEY is found in neither Streamlit
                secrets nor the environment.
        """
        self.audio_config = AudioConfig(
            chunk_duration=chunk_duration,
            sample_rate=sample_rate
        )
        self.max_history = max_history

        # Prefer Streamlit secrets; st.secrets may raise (e.g. with no secrets
        # file configured), hence the plain-environment fallback.
        try:
            api_key = st.secrets.get("OPENAI_API_KEY") or os.getenv("OPENAI_API_KEY")
        except Exception:
            api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY not found in environment or secrets")
        self.client = OpenAI(api_key=api_key)

        # Initialize components
        self.transcriber = AudioTranscriber(self.client, self.audio_config)
        self.sentiment_analyzer = get_sentiment_analyzer()

        # Audio recording state
        self.stream: Optional[pyaudio.Stream] = None
        self.is_recording = False
        self.audio_buffer_lock = threading.Lock()
        self.audio_buffer: List[bytes] = []
        # PyAudio instance backing self.stream; kept so it can be terminated.
        self._pa: Optional[pyaudio.PyAudio] = None
        # Recorder worker, joined in stop_recording() before closing the stream.
        self._record_thread: Optional[threading.Thread] = None

        # Session data
        self.entries: deque[TranscriptionEntry] = deque(maxlen=max_history)
        self.current_sentiment = SentimentResult(label="NEUTRAL", score=0.5)
        self.session_start: Optional[datetime] = None

    def start_recording(self) -> bool:
        """
        Open the microphone and start the background worker threads.

        Returns:
            True once recording has started; False if a session is already
            in progress.

        Raises:
            Exception: Whatever PyAudio raises when the input stream cannot
                be opened. Audio resources acquired so far are released
                before re-raising.
        """
        if self.is_recording:
            logger.warning("Recording already in progress")
            return False

        # Drop chunks left over from a previous session so they are not
        # transcribed into this one.
        with self.audio_buffer_lock:
            self.audio_buffer.clear()

        try:
            self._pa = pyaudio.PyAudio()
            self.stream = self._pa.open(
                format=self.audio_config.format,
                channels=self.audio_config.channels,
                rate=self.audio_config.sample_rate,
                input=True,
                frames_per_buffer=self.audio_config.chunk_size,
            )

            self.is_recording = True
            self.session_start = datetime.now()

            # Start background threads
            self._record_thread = threading.Thread(target=self._record_audio, daemon=True)
            self._record_thread.start()
            threading.Thread(target=self._process_transcription, daemon=True).start()

            logger.info("Recording started successfully")
            return True
        except Exception as e:
            logger.error(f"Failed to start recording: {e}")
            self.stop_recording()
            raise

    def stop_recording(self):
        """Stop the recording session and release the microphone.

        Safe to call when nothing is recording: it then only releases audio
        handles left over from a failed start, if any.
        """
        was_recording = self.is_recording
        # Signals both worker loops to exit.
        self.is_recording = False

        # Let the recorder leave its current stream.read() before the stream
        # is closed underneath it.
        recorder = self._record_thread
        self._record_thread = None
        if recorder is not None and recorder is not threading.current_thread():
            recorder.join(timeout=self._RECORDER_JOIN_TIMEOUT)
            if recorder.is_alive():
                logger.warning("Recorder thread did not exit in time; closing stream anyway")

        self._release_audio()
        if was_recording:
            logger.info("Recording stopped")

    def _release_audio(self):
        """Close the input stream and terminate its PyAudio instance, if present."""
        stream, self.stream = self.stream, None
        if stream is not None:
            try:
                stream.stop_stream()
                stream.close()
            except Exception as e:
                logger.error(f"Error closing stream: {e}")

        pa, self._pa = self._pa, None
        if pa is not None:
            try:
                pa.terminate()
            except Exception as e:
                logger.error(f"Error terminating PyAudio: {e}")

    def _record_audio(self):
        """Background thread: read the stream and queue chunk_duration-second chunks.

        Frames not yet forming a full chunk when recording stops are dropped.
        """
        frames: List[bytes] = []
        frames_per_chunk = int(
            self.audio_config.sample_rate * self.audio_config.chunk_duration
        )

        while self.is_recording:
            try:
                if not self.stream:
                    break

                data = self.stream.read(
                    self.audio_config.chunk_size,
                    exception_on_overflow=False
                )
                frames.append(data)

                # When we have enough frames, hand the chunk to the transcriber.
                if len(frames) * self.audio_config.chunk_size >= frames_per_chunk:
                    audio_chunk = b"".join(frames)
                    with self.audio_buffer_lock:
                        self.audio_buffer.append(audio_chunk)
                    frames = []
            except Exception as e:
                logger.error(f"Error recording audio: {e}")
                break

    def _process_transcription(self):
        """Background thread: transcribe queued chunks in FIFO order.

        Sleeps 0.1 s when the queue is empty. Exits once recording stops;
        chunks still queued at that point are not processed.
        """
        while self.is_recording:
            audio_data = None
            with self.audio_buffer_lock:
                if self.audio_buffer:
                    audio_data = self.audio_buffer.pop(0)

            if audio_data:
                self._process_audio_chunk(audio_data)
            else:
                time.sleep(0.1)

    def _process_audio_chunk(self, audio_data: bytes):
        """Process a single audio chunk through transcription and sentiment analysis."""
        try:
            text = self.transcriber.transcribe(audio_data)
            if not text:
                return

            sentiment = self.sentiment_analyzer.analyze(text)

            entry = TranscriptionEntry(
                text=text,
                sentiment=sentiment,
                timestamp=datetime.now()
            )
            self.entries.append(entry)
            self.current_sentiment = sentiment

            logger.info(f"Processed: {text[:50]}... ({sentiment.label})")
        except Exception as e:
            logger.error(f"Error processing audio chunk: {e}")

    def get_session_duration(self) -> int:
        """Get current session duration in whole seconds (0 before the first start)."""
        if not self.session_start:
            return 0
        return int((datetime.now() - self.session_start).total_seconds())

    def get_sentiment_stats(self) -> Dict[str, int]:
        """Get count of each sentiment label among the stored entries."""
        stats = {"POSITIVE": 0, "NEUTRAL": 0, "NEGATIVE": 0}
        # Iterate a snapshot: the transcription thread appends concurrently,
        # and a Python-level loop over a deque mutated mid-iteration raises
        # RuntimeError.
        for entry in list(self.entries):
            stats[entry.sentiment.label] += 1
        return stats

    def get_recent_entries(self, n: int = 5) -> List["TranscriptionEntry"]:
        """Get the n most recent transcription entries, oldest first.

        Returns an empty list for ``n <= 0`` (a bare ``[-0:]`` slice would
        return every entry).
        """
        if n <= 0:
            return []
        return list(self.entries)[-n:]

    def cleanup(self):
        """Clean up all resources."""
        self.stop_recording()
        self.transcriber.cleanup()
class DashboardUI:
    """Handles the Streamlit UI for the coaching dashboard.

    Reads state from a CoachingDashboard and renders the sidebar controls and
    statistics, the emotion timeline, the current-status card and the most
    recent transcriptions.
    """

    # Accent colour per sentiment label (hex RGB).
    COLORS = {
        "POSITIVE": "#00C853",
        "NEUTRAL": "#FFC107",
        "NEGATIVE": "#FF1744"
    }

    # Per label: minimum confidence score -> emoji. _get_emoji picks the
    # highest threshold that the score reaches.
    EMOJIS = {
        "POSITIVE": {
            0.95: "🥳",
            0.85: "😁",
            0.70: "😊",
            0.00: "🙂"
        },
        "NEGATIVE": {
            0.95: "😭",
            0.85: "😢",
            0.70: "😟",
            0.00: "😕"
        },
        "NEUTRAL": {
            0.60: "😐",
            0.00: "🤷"
        }
    }

    # (label, sidebar caption) pairs, in display order.
    _STAT_ROWS = (
        ("POSITIVE", "😊 Positive"),
        ("NEUTRAL", "😐 Neutral"),
        ("NEGATIVE", "😟 Negative"),
    )

    def __init__(self, dashboard: "CoachingDashboard"):
        """Initialize UI with dashboard instance."""
        self.dashboard = dashboard

    def render(self):
        """Render the complete dashboard UI.

        While a recording is active, sleeps 2 s and calls st.rerun() so newly
        processed entries show up.
        """
        st.set_page_config(page_title="Vera", layout="wide")
        self._inject_custom_css()

        st.title("🎯 Vera - Your Coaching Companion")

        self._render_sidebar()
        self._render_main_content()

        # Auto-refresh when recording
        if self.dashboard.is_recording:
            time.sleep(2)
            st.rerun()

    def _inject_custom_css(self):
        """Inject the CSS classes used by the status box and transcription cards."""
        st.markdown("""
        <style>
        .sentiment-box {
            padding: 30px;
            border-radius: 15px;
            text-align: center;
            font-size: 20px;
            font-weight: bold;
            margin: 20px 0;
            box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
        }
        .transcription-card {
            border-radius: 8px;
            padding: 15px;
            margin: 10px 0;
            transition: transform 0.2s;
        }
        .transcription-card:hover {
            transform: translateX(5px);
        }
        </style>
        """, unsafe_allow_html=True)

    def _render_sidebar(self):
        """Render sidebar with controls and stats."""
        with st.sidebar:
            st.header("🎮 Controls")

            col1, col2 = st.columns(2)
            with col1:
                if st.button("▶️ Start", disabled=self.dashboard.is_recording, use_container_width=True):
                    try:
                        self.dashboard.start_recording()
                    except Exception as e:
                        st.error(f"Failed to start: {e}")
                    else:
                        # Outside the try so a rerun request can never be
                        # reported as a start failure.
                        st.rerun()

            with col2:
                if st.button("⏹️ Stop", disabled=not self.dashboard.is_recording, use_container_width=True):
                    self.dashboard.stop_recording()
                    st.rerun()

            st.divider()

            # Recording status
            if self.dashboard.is_recording:
                st.success("🔴 Recording...")
                duration = self.dashboard.get_session_duration()
                st.metric("Duration", f"{duration // 60}m {duration % 60}s")
            else:
                st.info("⚪ Stopped")

            st.divider()

            # Statistics
            st.header("📊 Statistics")
            st.metric("Total Entries", len(self.dashboard.entries))

            if self.dashboard.entries:
                stats = self.dashboard.get_sentiment_stats()
                # Percentages use the counted total so the rows stay
                # consistent even if an entry lands while rendering.
                total = sum(stats.values()) or 1
                for label, caption in self._STAT_ROWS:
                    count = stats[label]
                    st.metric(caption, f"{count} ({count/total*100:.0f}%)")

    def _render_main_content(self):
        """Render main content area."""
        col1, col2 = st.columns([2, 1])

        with col1:
            self._render_emotion_timeline()

        with col2:
            self._render_current_status()

        st.divider()
        self._render_recent_transcriptions()

    def _render_emotion_timeline(self):
        """Render emotion timeline chart."""
        st.subheader("📈 Emotion Timeline")

        # Single snapshot so timestamps, scores and labels have equal length
        # even if the transcription thread appends meanwhile.
        entries = list(self.dashboard.entries)
        if not entries:
            st.info("Start a session to see the emotion timeline")
            return

        timestamps = [entry.timestamp for entry in entries]
        scores = [self._sentiment_to_score(entry.sentiment) for entry in entries]
        labels = [entry.sentiment.label for entry in entries]

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=timestamps,
            y=scores,
            mode='lines+markers',
            line=dict(width=3, color='#2196F3'),
            marker=dict(
                size=12,
                color=[self.COLORS[label] for label in labels],
                line=dict(width=2, color='white')
            ),
            hovertemplate='<b>%{text}</b><br>Score: %{y:.2f}<br>%{x}<extra></extra>',
            text=labels
        ))

        # Add reference zones
        fig.add_hline(y=0, line_dash="dash", line_color="gray", opacity=0.5)
        fig.add_hrect(y0=0.3, y1=1, fillcolor="green", opacity=0.1, line_width=0, annotation_text="Positive")
        fig.add_hrect(y0=-0.3, y1=0.3, fillcolor="yellow", opacity=0.1, line_width=0, annotation_text="Neutral")
        fig.add_hrect(y0=-1, y1=-0.3, fillcolor="red", opacity=0.1, line_width=0, annotation_text="Negative")

        fig.update_layout(
            height=400,
            xaxis_title="Time",
            yaxis_title="Emotional Valence",
            yaxis=dict(range=[-1.1, 1.1]),
            showlegend=False,
            hovermode='closest'
        )

        st.plotly_chart(fig, use_container_width=True)

    def _render_current_status(self):
        """Render current emotional status."""
        st.subheader("💭 Current Status")

        sentiment = self.dashboard.current_sentiment
        color = self.COLORS[sentiment.label]
        emoji = self._get_emoji(sentiment)

        st.markdown(f"""
        <div class="sentiment-box" style="background-color: {color}; color: white;">
            <div style="font-size: 48px;">{emoji}</div>
            <div style="margin: 10px 0;">{sentiment.label}</div>
            <div style="font-size: 16px; opacity: 0.9;">
                Confidence: {sentiment.score:.0%}
            </div>
        </div>
        """, unsafe_allow_html=True)

    def _render_recent_transcriptions(self):
        """Render the five most recent transcription entries, newest first."""
        st.subheader("💬 Recent Transcriptions")

        recent = self.dashboard.get_recent_entries(5)
        if not recent:
            st.info("No transcriptions yet. Start recording to see results.")
            return

        for entry in reversed(recent):
            st.markdown(self._transcription_card_html(entry), unsafe_allow_html=True)

    def _transcription_card_html(self, entry: "TranscriptionEntry") -> str:
        """Build the HTML card for one transcription entry.

        The transcribed text is HTML-escaped: it is model output rendered
        with unsafe_allow_html=True, so characters such as '<' or '&' would
        otherwise be interpreted as markup.
        """
        color = self.COLORS[entry.sentiment.label]
        time_str = entry.timestamp.strftime("%H:%M:%S")
        emoji = self._get_emoji(entry.sentiment)
        safe_text = html.escape(entry.text)

        return f"""
        <div class="transcription-card" style="
            background-color: {color}20;
            border-left: 5px solid {color};
        ">
            <div style="color: {color}; font-weight: bold; margin-bottom: 8px;">
                {emoji} [{time_str}] {entry.sentiment.label}
                <span style="opacity: 0.8;">({entry.sentiment.score:.0%})</span>
            </div>
            <div style="font-size: 16px; color: #333;">
                {safe_text}
            </div>
        </div>
        """

    def _sentiment_to_score(self, sentiment: "SentimentResult") -> float:
        """Convert sentiment to -1 to 1 scale for visualization (NEUTRAL maps to 0)."""
        if sentiment.label == "POSITIVE":
            return sentiment.score
        elif sentiment.label == "NEGATIVE":
            return -sentiment.score
        else:
            return 0

    def _get_emoji(self, sentiment: "SentimentResult") -> str:
        """Get appropriate emoji for sentiment and confidence."""
        emoji_map = self.EMOJIS.get(sentiment.label, self.EMOJIS["NEUTRAL"])
        for threshold, emoji in sorted(emoji_map.items(), reverse=True):
            if sentiment.score >= threshold:
                return emoji
        return "😐"
def main():
    """Streamlit entry point.

    Builds the CoachingDashboard once per session (kept in st.session_state
    across reruns) and renders its UI. If construction fails, the error is
    shown and the script run is stopped.
    """
    if "dashboard" not in st.session_state:
        try:
            st.session_state.dashboard = CoachingDashboard(chunk_duration=3)
        except Exception as exc:
            st.error(f"Failed to initialize dashboard: {exc}")
            st.stop()

    DashboardUI(st.session_state.dashboard).render()


# Run the app when this file is executed as the main script.
if __name__ == "__main__":
    main()