BirdGen / app.py
triaNova's picture
dsds
90eaa93
import torch
import torchaudio
import gradio as gr
from pathlib import Path
from preprocess import PreprocessingPipeline
from model_arch import ResNetModel
# Class labels. predict_bird pairs these positionally with the model's output
# scores, so adjust this list (and num_classes below) if the trained classes
# change.
BIRD_CLASSES = [
    'short_chirp',
    'sweep_down',
    'sweep_up',
    'warble',
    'whistle_click',
]

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the classifier, restore its weights from the checkpoint file, move it
# to the selected device and switch it to evaluation mode for inference.
model = ResNetModel(ch_in=1, num_classes=5)
model.load_state_dict(
    torch.load("resnet3.pth", map_location=device)
)
model.to(device)
model.eval()

# Shared preprocessing pipeline used by predict_bird.
# NOTE(review): presumably max_duration is in seconds and sr_out in Hz --
# confirm against PreprocessingPipeline.
pipeline = PreprocessingPipeline(
    save_dir="tmp",
    max_duration=4,
    sr_out=22050,
)
def predict_bird(audio_file):
    """Classify an audio file and return the top label with all class scores.

    Args:
        audio_file: Path to the audio file, as delivered by the Gradio
            ``Audio(type="filepath")`` input. Gradio passes ``None`` when the
            form is submitted without any audio.

    Returns:
        A ``(predicted_class, result)`` tuple, where ``predicted_class`` is
        the entry of ``BIRD_CLASSES`` with the highest softmax score and
        ``result`` maps each class name to its score as a Python float.

    Raises:
        gr.Error: If no audio file was provided.
    """
    # Without this guard a missing upload reaches the loader and fails with
    # an opaque error; gr.Error shows a readable message in the UI instead.
    if not audio_file:
        raise gr.Error("Please upload an audio file before submitting.")

    # Preprocess (steps named after the pipeline methods they call):
    # load -> mono -> resample to pipeline.sr_out -> truncate or pad.
    signal, sr = pipeline.loader.load(audio_file)
    signal = pipeline.stereo2mono.stereo_to_mono(signal)
    # The sample rate returned by the resampler is not needed afterwards.
    signal, _ = pipeline.resampler.resample(
        signal, sr, pipeline.sr_out, debug=False
    )
    signal = pipeline.truncate_pad.truncate_or_pad(signal, debug=False)

    # Extract log-mel spectrogram
    # NOTE(review): shape assumed to be [1, n_mels, T] per the original
    # comment -- confirm against the feature extractor.
    log_mel = pipeline.fe.log_mel_scale(signal)
    log_mel = log_mel.unsqueeze(0).to(device)  # add batch dim

    # Forward pass; gradients are not needed for inference.
    with torch.no_grad():
        output = model(log_mel)
        probs = torch.softmax(output, dim=1).cpu().numpy()[0]

    # Build result dictionary (class name -> probability as a plain float).
    result = {cls: float(prob) for cls, prob in zip(BIRD_CLASSES, probs)}
    # Convert the NumPy index to a plain int before indexing the list.
    predicted_class = BIRD_CLASSES[int(probs.argmax())]
    return predicted_class, result
# Gradio front end: one audio input (handed to predict_bird as a file path)
# and two outputs -- the top label as text and the per-class scores as a
# label widget.
iface = gr.Interface(
    predict_bird,
    gr.Audio(type="filepath"),
    [
        gr.Textbox(label="Predicted Bird"),
        gr.Label(num_top_classes=5),
    ],
    title="Bird Call Classifier",
    description="Upload a bird call .wav file and get the predicted bird species.",
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()