|
|
import torch |
|
|
import torchaudio |
|
|
import gradio as gr |
|
|
from pathlib import Path |
|
|
from preprocess import PreprocessingPipeline |
|
|
|
|
|
from model_arch import ResNetModel |
|
|
|
|
|
|
|
|
|
|
|
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

# Build the network, restore its weights from the checkpoint on disk, move it
# to the selected device, and switch it to evaluation mode.
# NOTE(review): torch.load unpickles the file, so "resnet3.pth" should only
# ever come from a trusted source -- confirm how this checkpoint is shipped.
model = ResNetModel(ch_in=1, num_classes=5)
_state_dict = torch.load("resnet3.pth", map_location=device)
model.load_state_dict(_state_dict)
del _state_dict  # do not keep a second copy of the weights alive
model.to(device)
model.eval()

# Preprocessing stages reused by predict_bird for every request.
# NOTE(review): max_duration is presumably in seconds and sr_out in Hz --
# confirm against PreprocessingPipeline.
pipeline = PreprocessingPipeline(
    save_dir="tmp",
    max_duration=4,
    sr_out=22050,
)

# Class labels; predict_bird pairs them positionally with the model's output
# probabilities, so the order here must match the order used in training.
BIRD_CLASSES = [
    'short_chirp',
    'sweep_down',
    'sweep_up',
    'warble',
    'whistle_click',
]
|
|
|
|
|
def predict_bird(audio_file):
    """Classify one audio recording and return the top label with all scores.

    Args:
        audio_file: Path of the audio file to classify. The Gradio input
            component in this file is created with ``type="filepath"``, so
            the value is a file path; it is presumably ``None`` when the
            user submits without providing audio.

    Returns:
        A ``(predicted_class, result)`` tuple, where ``predicted_class`` is
        the entry of ``BIRD_CLASSES`` with the highest probability and
        ``result`` maps every class name to its softmax probability as a
        Python float.

    Raises:
        gr.Error: If no audio file was provided.
    """
    # Fail early with a message Gradio can show in the UI instead of letting
    # the loader crash on a missing path.
    if not audio_file:
        raise gr.Error("Please upload or record an audio clip before submitting.")

    # Preprocessing: load -> mono -> resample to the pipeline rate -> fixed length.
    signal, sr = pipeline.loader.load(audio_file)
    signal = pipeline.stereo2mono.stereo_to_mono(signal)
    signal, sr = pipeline.resampler.resample(signal, sr, pipeline.sr_out, debug=False)
    signal = pipeline.truncate_pad.truncate_or_pad(signal, debug=False)

    # Feature extraction. unsqueeze(0) adds a leading dimension of size 1
    # (presumably the batch axis the model expects -- TODO confirm).
    log_mel = pipeline.fe.log_mel_scale(signal)
    log_mel = log_mel.unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = model(log_mel)
        probs = torch.softmax(output, dim=1).cpu().numpy()[0]

    result = {cls: float(prob) for cls, prob in zip(BIRD_CLASSES, probs)}
    # Convert the NumPy integer to a plain int before indexing the list.
    predicted_class = BIRD_CLASSES[int(probs.argmax())]

    return predicted_class, result
|
|
|
|
|
# Gradio front end: one audio input (handed to predict_bird as a file path)
# and two outputs matching predict_bird's (predicted_class, result) return.
_audio_input = gr.Audio(type="filepath")
_prediction_outputs = [
    gr.Textbox(label="Predicted Bird"),
    gr.Label(num_top_classes=5),
]

iface = gr.Interface(
    fn=predict_bird,
    inputs=_audio_input,
    outputs=_prediction_outputs,
    title="Bird Call Classifier",
    description="Upload a bird call .wav file and get the predicted bird species.",
)


if __name__ == "__main__":
    # Start the web app only when executed as a script, not on import.
    iface.launch()
|
|
|