"""Gradio demo: classify an uploaded bird-call recording with a trained ResNet."""

import torch
import torchaudio
import gradio as gr
from pathlib import Path

from preprocess import PreprocessingPipeline
from model_arch import ResNetModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Output labels, assumed to be in the same order as the model's output logits
# (predict_bird pairs them positionally). The model head and the UI are both
# sized from this list so they cannot drift apart when it is edited.
BIRD_CLASSES = ['short_chirp', 'sweep_down', 'sweep_up', 'warble', 'whistle_click']  # adjust

# Load model
model = ResNetModel(ch_in=1, num_classes=len(BIRD_CLASSES))
# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
model.load_state_dict(torch.load("resnet3.pth", map_location=device))
model.to(device)
model.eval()

pipeline = PreprocessingPipeline(save_dir="tmp", max_duration=4, sr_out=22050)


def _audio_to_log_mel(audio_file):
    """Load *audio_file* and return a batched log-mel tensor on ``device``.

    Runs the pipeline stages in order: load, stereo-to-mono, resample to
    ``pipeline.sr_out``, truncate/pad, then log-mel feature extraction.
    """
    signal, sr = pipeline.loader.load(audio_file)
    signal = pipeline.stereo2mono.stereo_to_mono(signal)
    signal, sr = pipeline.resampler.resample(signal, sr, pipeline.sr_out, debug=False)
    signal = pipeline.truncate_pad.truncate_or_pad(signal, debug=False)

    log_mel = pipeline.fe.log_mel_scale(signal)  # [1, n_mels, T]
    return log_mel.unsqueeze(0).to(device)  # add batch dim


def predict_bird(audio_file):
    """Classify one recording.

    Args:
        audio_file: Path to the audio file, as supplied by
            ``gr.Audio(type="filepath")``. It is ``None`` when nothing was
            uploaded.

    Returns:
        A ``(predicted_class, scores)`` tuple. ``predicted_class`` is the
        most probable label. ``scores`` maps every label in ``BIRD_CLASSES``
        to its softmax probability.

    Raises:
        gr.Error: If no audio was provided. Gradio shows this message in the UI.
    """
    if audio_file is None:
        raise gr.Error("Please upload or record an audio file first.")

    log_mel = _audio_to_log_mel(audio_file)

    # Forward pass
    with torch.no_grad():
        logits = model(log_mel)
    probs = torch.softmax(logits, dim=1).cpu().numpy()[0]

    # Build result dictionary
    result = {cls: float(prob) for cls, prob in zip(BIRD_CLASSES, probs)}
    predicted_class = BIRD_CLASSES[int(probs.argmax())]
    return predicted_class, result


# NOTE(review): BIRD_CLASSES are call types, while the UI text speaks of
# "bird species" — consider rewording the label/description.
iface = gr.Interface(
    fn=predict_bird,
    inputs=gr.Audio(type="filepath"),
    outputs=[
        gr.Textbox(label="Predicted Bird"),
        gr.Label(num_top_classes=len(BIRD_CLASSES)),
    ],
    title="Bird Call Classifier",
    description="Upload a bird call .wav file and get the predicted bird species.",
)

if __name__ == "__main__":
    iface.launch()