"""Audio preprocessing: load, downmix, resample, fix length, extract mel
features, and save them to disk grouped by bird category."""

from pathlib import Path
import random
import pickle
import os

import torch
import torchaudio
import torchaudio.transforms as T
import torch.nn.functional as F
import torchaudio.functional as AF
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import noisereduce as nr
import librosa
import scipy
from tqdm import tqdm


class Load:
    """Loads an audio signal into memory in normalized form."""

    def load(self, file_path):
        """Read ``file_path`` with ``torchaudio.load``.

        Returns:
            The ``(signal, sample_rate)`` pair produced by ``torchaudio.load``
            called with ``channels_first=True`` and ``normalize=True``.
        """
        return torchaudio.load(file_path, channels_first=True, normalize=True)


class StereoToMono:
    """Applies mapping from stereo to mono."""

    def stereo_to_mono(self, stereo_signal: torch.Tensor) -> torch.Tensor:
        """Average over dim 0, keeping that dim with size 1.

        NOTE(review): assumes dim 0 is the channel axis (as produced by
        ``Load.load`` with ``channels_first=True``).
        """
        return stereo_signal.mean(dim=0, keepdim=True)


class Resample:
    """Applies resampling onto a signal."""

    def __init__(self):
        # Rates used by the most recent call to ``resample``.
        self.sr_in = None
        self.sr_out = None
        # Resample transforms keyed by (sr_in, sr_out), so the resampling
        # kernel is not rebuilt for every file that shares the same rates.
        self._resamplers = {}

    def resample(self, signal, sr_in, sr_out, debug=True):
        """Resample ``signal`` from ``sr_in`` to ``sr_out``.

        Args:
            signal: Audio tensor to resample.
            sr_in: Sample rate of ``signal``.
            sr_out: Target sample rate.
            debug: When true, print which branch was taken.

        Returns:
            ``(signal, sr_out)``; the input tensor is returned untouched when
            the two rates are already equal.
        """
        self.sr_in = sr_in
        self.sr_out = sr_out

        if sr_in == sr_out:
            if debug:
                print('No remsampling needed')
            return signal, sr_out

        # Progress messages are debug-only so that batch runs stay quiet.
        if debug:
            print('Resampling the signal...')

        key = (sr_in, sr_out)
        resampler = self._resamplers.get(key)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sr_in, new_freq=sr_out
            )
            self._resamplers[key] = resampler
        return resampler(signal), sr_out


class NoiseRemover:
    """Denoises a signal with ``noisereduce.reduce_noise``."""

    def __init__(self):
        # Inputs and output of the most recent ``remove_noise`` call.
        self._sr = None
        self._signal = None
        self._denoised_signal = None

    def remove_noise(self, signal, sr):
        """Denoise ``signal`` and return ``(denoised_tensor, sr)``.

        Args:
            signal: Audio tensor; a leading dim of size 1 is squeezed away
                before the data is handed to ``noisereduce``.
            sr: Sample rate of ``signal``.

        Returns:
            ``(denoised, sr)``. A 1-D denoised result gets a leading dim of
            size 1 added back; a result that is already multi-dimensional
            (multi-channel input) is returned with its shape unchanged.
        """
        self._sr = sr
        # detach()/cpu() make the NumPy conversion valid for tensors that
        # require grad or live on an accelerator.
        samples = signal.squeeze(0).detach().cpu().numpy()
        self._signal = samples

        denoised = torch.tensor(nr.reduce_noise(y=samples, sr=sr))
        # squeeze(0) only removes the leading axis when it has size 1, so only
        # restore it when it was actually dropped; otherwise a multi-channel
        # result would pick up a spurious extra dimension.
        if denoised.dim() == 1:
            denoised = denoised.unsqueeze(0)

        self._denoised_signal = denoised
        return self._denoised_signal, sr


class TruncateOrPad:
    """Dynamically truncates or pads depending on the signal."""

    def __init__(self, max_duration: float, sr_out: int = 16_000):
        """
        Args:
            max_duration: Target duration; multiplied by ``sr_out`` to get the
                target number of samples.
            sr_out: Sample rate of the signals that will be passed in.
        """
        self.max_duration = max_duration
        self.sr_out = sr_out
        # Slicing and padding need an integer sample count, so round here in
        # case a fractional duration is supplied.
        self.tot_samples_expected = int(round(sr_out * max_duration))

    def truncate_or_pad(self, signal, debug=True):
        """Return ``signal`` with exactly ``tot_samples_expected`` samples
        along its last dim.

        Longer signals are cut at the end, shorter ones are zero-padded at the
        end, and signals of the right length are returned as-is.

        Args:
            signal: Audio tensor whose last dim is the sample axis.
            debug: When true, print which branch was taken.
        """
        tot_samples = signal.shape[-1]

        if tot_samples == self.tot_samples_expected:
            if debug:
                print('Signal already at max duration')
            return signal

        if tot_samples > self.tot_samples_expected:
            if debug:
                print('Truncating the signal')
            return self._truncate(signal)

        if debug:
            print('Padding the signal')
        return self._pad(signal)

    def _truncate(self, signal):
        """Keep only the first ``tot_samples_expected`` samples."""
        return signal[..., :self.tot_samples_expected]

    def _pad(self, signal):
        """Right-pad the last dim up to ``tot_samples_expected`` samples."""
        pad_amount = self.tot_samples_expected - signal.shape[-1]
        return F.pad(signal, (0, pad_amount))


class FeatureExtractor:
    """Extracts features: linear, log spectrograms, mel spectrograms."""

    def __init__(self, n_fft=1024, hop_length=256, sr=16000, n_mels=80):
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.sr = sr
        self.n_mels = n_mels
        self._window = torch.hann_window(n_fft)

    def stft_spec(self, signal):
        """Complex STFT of ``signal`` using a Hann window of length ``n_fft``."""
        return torch.stft(
            signal,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            # The window must match the signal's device and dtype.
            window=self._window.to(device=signal.device, dtype=signal.dtype),
            center=True,
            return_complex=True,
        )

    def linear_mag(self, signal):
        """stft -> abs"""
        return self.stft_spec(signal).abs()

    def linear_power(self, signal):
        """stft -> abs -> **2"""
        return self.linear_mag(signal).pow(2)

    def mel_scale(self, signal):
        """Mel spectrogram (power)."""
        mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sr,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            n_mels=self.n_mels,
            center=True,
            power=2.0,
        )
        return mel_transform(signal)

    def log_mag(self, signal, eps=1e-10):
        """Magnitude spectrogram as ``20 * log10(mag + eps)``."""
        return 20 * torch.log10(self.linear_mag(signal) + eps)

    def log_power(self, signal, eps=1e-10):
        """Power spectrogram as ``10 * log10(power + eps)``."""
        return 10 * torch.log10(self.linear_power(signal) + eps)

    def mel_to_db(self, mel_spec):
        """Apply ``AmplitudeToDB(top_db=80)`` to an existing mel spectrogram.

        Lets callers that already hold the mel spectrogram obtain the log-mel
        version without recomputing the STFT and mel projection.
        """
        return torchaudio.transforms.AmplitudeToDB(top_db=80)(mel_spec)

    def log_mel_scale(self, signal):
        """Log-mel spectrogram for classification."""
        return self.mel_to_db(self.mel_scale(signal))


class NormalizeFeatures:
    """Feature normalization helpers."""

    @staticmethod
    def min_max_normalize(mel: torch.Tensor):
        """Scale ``mel`` by its global min and max.

        Returns:
            ``(mel_norm, min_val, max_val)`` where
            ``mel_norm = (mel - min_val) / (max_val - min_val + 1e-8)``.
        """
        max_val = mel.max()
        min_val = mel.min()
        # The epsilon avoids a division by zero for a constant input.
        mel_norm = (mel - min_val) / (max_val - min_val + 1e-8)
        return mel_norm, min_val, max_val


class BirdDatasetSaver:
    """Saves extracted features under ``save_dir``, one folder per category."""

    def __init__(self, save_dir):
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)

    def save(self, bird_category: str, audio_file_name: str,
             log_mel: torch.Tensor, mel_norm: torch.Tensor):
        """Write both feature tensors for one audio file.

        Files are written with ``torch.save`` to:
            ``<save_dir>/<bird_category>/classification/<stem>_logmel.pt``
            ``<save_dir>/<bird_category>/generation/<stem>_mel.pt``
        where ``<stem>`` is ``audio_file_name`` without its extension.
        """
        category_path = os.path.join(self.save_dir, bird_category)
        classification_path = os.path.join(category_path, "classification")
        generation_path = os.path.join(category_path, "generation")
        os.makedirs(classification_path, exist_ok=True)
        os.makedirs(generation_path, exist_ok=True)

        stem = Path(audio_file_name).stem
        torch.save(log_mel, os.path.join(classification_path, f"{stem}_logmel.pt"))
        torch.save(mel_norm, os.path.join(generation_path, f"{stem}_mel.pt"))


class PreprocessingPipeline:
    """Runs load -> mono -> resample -> truncate/pad -> features -> save."""

    def __init__(self, save_dir, max_duration=4, sr_out=22050, n_fft=1024,
                 hop_length=256, n_mels=80, debug=False):
        """
        Args:
            save_dir: Root directory for the saved feature tensors.
            max_duration: Target duration passed to ``TruncateOrPad``.
            sr_out: Sample rate every signal is resampled to.
            n_fft: FFT size for feature extraction.
            hop_length: Hop length for feature extraction.
            n_mels: Number of mel bands.
            debug: Forwarded to the resample and truncate/pad steps to enable
                their progress messages.
        """
        self.loader = Load()
        self.stereo2mono = StereoToMono()
        self.resampler = Resample()
        self.truncate_pad = TruncateOrPad(max_duration=max_duration, sr_out=sr_out)
        self.fe = FeatureExtractor(n_fft=n_fft, hop_length=hop_length,
                                   sr=sr_out, n_mels=n_mels)
        self.normer = NormalizeFeatures()
        self.saver = BirdDatasetSaver(save_dir)
        self.sr_out = sr_out
        self.debug = debug

    def process_file(self, bird_category, audio_file_path):
        """Extract and save the features of a single audio file.

        Args:
            bird_category: Category name; used as the output sub-folder.
            audio_file_path: Path of the audio file to process.
        """
        audio_file_name = Path(audio_file_path).name

        # Load
        signal, sr = self.loader.load(audio_file_path)

        # Stereo -> mono
        signal = self.stereo2mono.stereo_to_mono(signal)

        # Resample
        signal, sr = self.resampler.resample(signal, sr, self.sr_out, self.debug)

        # Truncate/pad
        signal = self.truncate_pad.truncate_or_pad(signal, self.debug)

        # Extract features. The mel spectrogram is computed once and reused
        # for the log-mel version instead of running the transform twice.
        mel = self.fe.mel_scale(signal)      # linear mel
        log_mel = self.fe.mel_to_db(mel)     # for classification
        mel_norm, _, _ = self.normer.min_max_normalize(mel)

        # Save
        self.saver.save(bird_category, audio_file_name, log_mel, mel_norm)

    def process_dataset(self, root_dir):
        """Process every ``.wav`` file found in each sub-directory of ``root_dir``.

        Each immediate sub-directory name is used as the bird category.
        Entries are visited in sorted order so runs are reproducible, and the
        extension check is case-insensitive so ``.WAV`` files are included.
        """
        for bird_category in tqdm(sorted(os.listdir(root_dir))):
            category_path = os.path.join(root_dir, bird_category)
            if not os.path.isdir(category_path):
                continue

            for audio_file in sorted(os.listdir(category_path)):
                if not audio_file.lower().endswith(".wav"):
                    continue
                audio_file_path = os.path.join(category_path, audio_file)
                self.process_file(bird_category, audio_file_path)