Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| Gradio demo for CountEM Automatic Music Transcription. | |
| This demo allows users to upload audio files and transcribe them to MIDI | |
| using pre-trained models from Hugging Face Hub. | |
| """ | |
| import gradio as gr | |
| import spaces | |
| import tempfile | |
| import os | |
| from pathlib import Path | |
| import numpy as np | |
| import soundfile as sf | |
| import librosa | |
| import logging | |
| from onsets_and_frames.hf_model import CountEMModel | |
| from onsets_and_frames.constants import SAMPLE_RATE | |
# Logging: configure the root logger at INFO and use a module-scoped logger
# for every message emitted by this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Loaded models keyed by the name passed to load_model(), so a model that has
# already been loaded is reused instead of being loaded again.
model_cache = dict()
def load_model(model_name: str) -> CountEMModel:
    """Return the model registered under ``model_name``, loading it if needed.

    Args:
        model_name: Identifier handed to ``CountEMModel.from_pretrained``
            (callers in this file pass Hugging Face repo IDs).

    Returns:
        The model stored in ``model_cache`` for ``model_name``. On a cache
        miss the model is loaded with ``CountEMModel.from_pretrained`` and
        stored before being returned.
    """
    # Fast path: the model was loaded by an earlier call.
    try:
        return model_cache[model_name]
    except KeyError:
        pass

    logger.info(f"Loading model: {model_name}")
    model = CountEMModel.from_pretrained(model_name)
    model_cache[model_name] = model
    logger.info("Model loaded successfully")
    return model
# Radio-button label -> Hugging Face repo ID of the corresponding model.
_MODEL_REPOS = {
    "MusicNet (Recommended)": "Yoni232/countem-musicnet",
    "Synth": "Yoni232/countem-synth",
}

# Accepted clip length, in seconds.
_MIN_DURATION_S = 0.5
_MAX_DURATION_S = 600  # 10 minutes


def _read_audio(audio_input):
    """Turn a Gradio audio value into ``(samples, sample_rate, filename_stem)``.

    ``audio_input`` must be either a ``(sample_rate, ndarray)`` tuple or a file
    path string; the caller is responsible for rejecting anything else.
    ``filename_stem`` is ``None`` for tuple input, which carries no filename.
    """
    if isinstance(audio_input, tuple):
        sr, audio = audio_input
        # Scale integer PCM to floats; other dtypes are passed through as-is.
        if audio.dtype == np.int16:
            audio = audio.astype(np.float32) / 32768.0
        elif audio.dtype == np.int32:
            audio = audio.astype(np.float32) / 2147483648.0
        return audio, sr, None

    # File path: let librosa decode at the native rate, already mixed to mono.
    audio, sr = librosa.load(audio_input, sr=None, mono=True)
    return audio, sr, Path(audio_input).stem


def _to_mono_at_model_rate(audio, sr):
    """Down-mix multi-channel audio and resample it to ``SAMPLE_RATE``."""
    if audio.ndim > 1:
        # Averages over axis 1, i.e. assumes a (samples, channels) layout.
        audio = audio.mean(axis=1)
    if sr != SAMPLE_RATE:
        logger.info(f"Resampling from {sr}Hz to {SAMPLE_RATE}Hz")
        audio = librosa.resample(audio, orig_sr=sr, target_sr=SAMPLE_RATE)
        sr = SAMPLE_RATE
    return audio, sr


def _unique_midi_path(stem):
    """Return a ``.mid`` path inside a freshly created private temp directory.

    A new directory per request keeps the friendly ``<stem>.mid`` download name
    while guaranteeing that two requests with the same upload name never write
    to the same file. The directory is not removed here, because the file has
    to outlive this call so it can be downloaded.
    """
    out_dir = tempfile.mkdtemp(prefix="countem_")
    return os.path.join(out_dir, f"{stem or 'transcription'}.mid")


def transcribe_audio(
    audio_input,
    model_choice: str,
    onset_threshold: float,
    frame_threshold: float,
) -> tuple:
    """
    Transcribe audio to MIDI.

    Args:
        audio_input: Either a file path string or a (sample_rate, audio_data)
            tuple, as produced by the Gradio Audio component. ``None`` means
            nothing was uploaded.
        model_choice: Model label; must be a key of ``_MODEL_REPOS``
            ("MusicNet (Recommended)" or "Synth").
        onset_threshold: Threshold for onset detection
        frame_threshold: Threshold for frame detection

    Returns:
        Tuple of (output_midi_path, status_message). On any failure the path
        is ``None`` and the message starts with "Error"; exceptions are caught
        and reported through the message rather than raised.
    """
    try:
        # --- Input validation (cheap checks first, before any decoding) ---
        if audio_input is None:
            return None, "Error: Please upload an audio file"

        model_name = _MODEL_REPOS.get(model_choice)
        if model_name is None:
            return None, f"Error: Unknown model choice: {model_choice!r}"

        if not isinstance(audio_input, (tuple, str)):
            return None, f"Error: Unexpected audio input type: {type(audio_input)}"

        # --- Decode and normalise the audio ---
        audio, sr, input_filename = _read_audio(audio_input)
        audio, sr = _to_mono_at_model_rate(audio, sr)

        # --- Length limits ---
        duration = len(audio) / sr
        if duration < _MIN_DURATION_S:
            return None, "Error: Audio is too short (minimum 0.5 seconds)"
        if duration > _MAX_DURATION_S:
            return (
                None,
                f"Error: Audio is too long ({duration:.1f}s). Maximum is 10 minutes (600s).",
            )

        # --- Model ---
        logger.info(f"Loading {model_choice} model...")
        model = load_model(model_name)

        # --- Transcription ---
        logger.info(f"Transcribing {duration:.1f} seconds of audio...")
        output_path = _unique_midi_path(input_filename)
        model.transcribe_to_midi(
            audio,
            output_path,
            onset_threshold=onset_threshold,
            frame_threshold=frame_threshold,
        )

        success_msg = "\n".join(
            [
                "✓ Transcription complete!",
                f"- Model: {model_choice}",
                f"- Duration: {duration:.2f} seconds",
                f"- Sample rate: {sr} Hz",
                f"- Onset threshold: {onset_threshold}",
                f"- Frame threshold: {frame_threshold}",
                "Download your MIDI file using the button below.",
            ]
        )
        return output_path, success_msg
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising,
        # but keep the full traceback in the logs for debugging.
        error_msg = f"Error during transcription: {e}"
        logger.exception(error_msg)
        return None, error_msg
# Build Gradio interface.
# The Markdown copy and the settings shared by both threshold sliders live in
# constants so the layout code below only describes structure.
_HEADER_MD = """
# CountEM - Automatic Music Transcription
Upload a piano/music recording and transcribe it to MIDI using a model that was trained using the CountEM framework on the MusicNet dataset.
**Paper:** [Count the Notes: Histogram-Based Supervision for Automatic Music Transcription](https://arxiv.org/abs/2511.14250) (ISMIR 2025)
**Models on Hugging Face:**
- [countem-musicnet](https://huggingface.co/Yoni232/countem-musicnet) - Trained on MusicNet dataset
- [countem-synth](https://huggingface.co/Yoni232/countem-synth) - Trained on synthetic data
"""

_NOTES_MD = """
### Notes:
- Audio will be automatically resampled to 16kHz if needed, and converted to mono
- Supports common formats: WAV, FLAC, MP3, M4a
- Maximum duration: 10 minutes
- Best results with classical music
- Processing time depends on audio length (typically a few seconds per minute of audio)
"""

_FOOTER_MD = """
---
**Project Links:**
- [GitHub Repository](https://github.com/Yoni-Yaffe/count-the-notes)
- [Project Page](https://yoni-yaffe.github.io/count-the-notes/)
- [ArXiv Paper](https://arxiv.org/abs/2511.14250)
If you use this work, please cite:
```
@misc{yaffe2025countnoteshistogrambasedsupervision,
title={Count The Notes: Histogram-Based Supervision for Automatic Music Transcription},
author={Jonathan Yaffe and Ben Maman and Meinard Müller and Amit H. Bermano},
year={2025},
eprint={2511.14250},
archivePrefix={arXiv},
primaryClass={cs.SD},
url={https://arxiv.org/abs/2511.14250},
}
```
"""

# Labels shown in the model selector; the first one is the default.
_MODEL_LABELS = ["MusicNet (Recommended)", "Synth"]

# Both threshold sliders share the same range, default and step.
_THRESHOLD_SLIDER = {"minimum": 0.1, "maximum": 0.9, "value": 0.5, "step": 0.05}

with gr.Blocks(title="CountEM - Music Transcription") as demo:
    gr.Markdown(_HEADER_MD)

    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            audio_input = gr.Audio(
                label="Upload Audio File",
                type="filepath",
                sources=["upload"],
            )
            model_choice = gr.Radio(
                choices=_MODEL_LABELS,
                value=_MODEL_LABELS[0],
                label="Model Selection",
                info="MusicNet model is trained on real piano recordings, Synth on synthetic data",
            )
            with gr.Row():
                onset_threshold = gr.Slider(
                    label="Onset Threshold",
                    info="Higher = fewer notes detected",
                    **_THRESHOLD_SLIDER,
                )
                frame_threshold = gr.Slider(
                    label="Frame Threshold",
                    info="Higher = shorter note durations",
                    **_THRESHOLD_SLIDER,
                )
            transcribe_btn = gr.Button("Transcribe to MIDI", variant="primary")

        # Right column: downloadable result and status text.
        with gr.Column():
            output_midi = gr.File(label="Download MIDI", interactive=False)
            status_output = gr.Textbox(
                label="Status",
                lines=10,
                interactive=False,
                placeholder="Upload audio and click 'Transcribe to MIDI' to start...",
            )

    # Usage notes shown below the input/output row.
    gr.Markdown(_NOTES_MD)

    # Wire the button: four inputs in, (MIDI file, status text) out.
    transcribe_btn.click(
        fn=transcribe_audio,
        inputs=[audio_input, model_choice, onset_threshold, frame_threshold],
        outputs=[output_midi, status_output],
    )

    # Project links and citation.
    gr.Markdown(_FOOTER_MD)
if __name__ == "__main__":
    # Load the default model before serving so the first transcription
    # request does not have to wait for it.
    logger.info("Pre-loading default model...")
    load_model("Yoni232/countem-musicnet")
    logger.info("Model pre-loaded. Starting Gradio interface...")

    # Launch settings for the demo server.
    launch_options = {
        "share": False,  # Set to True to create a public link
        "server_name": "0.0.0.0",  # Allow access from network
        "server_port": 7860,
    }
    demo.launch(**launch_options)