import os
import random
import sys
import time

import librosa
import numpy as np
import soundfile
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

from onsets_and_frames import constants
from onsets_and_frames.constants import DEFAULT_DEVICE, N_KEYS, SAMPLE_RATE
from onsets_and_frames.mel import melspectrogram
from onsets_and_frames.midi_utils import (
    midi_to_frames,
    save_midi_alignments_and_predictions,
)
from onsets_and_frames.utils import (
    get_diff,
    get_logger,
    get_peaks,
    shift_label,
    smooth_labels,
)


class EMDATASET(Dataset):
    def __init__(
        self,
        audio_path="NoteEM_audio",
        tsv_path="NoteEM_tsv",
        labels_path="NoteEm_labels",
        groups=None,
        sequence_length=None,
        seed=42,
        device=DEFAULT_DEVICE,
        instrument_map=None,
        update_instruments=False,
        transcriber=None,
        conversion_map=None,
        pitch_shift=True,
        pitch_shift_limit=5,
        keep_eval_files=False,
        n_eval=1,
        evaluation_list=None,
        only_eval=False,
        save_to_memory=False,
        smooth_labels=False,
        use_onset_mask=False,
    ):
        # Get the dataset logger (the logging system should already have been
        # initialized by train.py).
        self.logger = get_logger("dataset")
        self.audio_path = audio_path
        self.tsv_path = tsv_path
        self.labels_path = labels_path
        self.sequence_length = sequence_length
        self.device = device
        self.random = np.random.RandomState(seed)
        self.groups = groups
        self.conversion_map = conversion_map
        self.eval_file_list = []
        self.file_list = self.files(
            self.groups,
            pitch_shift=pitch_shift,
            keep_eval_files=keep_eval_files,
            n_eval=n_eval,
            evaluation_list=evaluation_list,
            pitch_shift_limit=pitch_shift_limit,
        )
        self.save_to_memory = save_to_memory
        self.smooth_labels = smooth_labels
        self.use_onset_mask = use_onset_mask
        self.pitch_shift_limit = pitch_shift_limit
        self.logger.debug("Save to memory is %s", self.save_to_memory)
        self.logger.info("len file list %d", len(self.file_list))
        if instrument_map is None:
            self.get_instruments(conversion_map=conversion_map)
        else:
            self.instruments = instrument_map
            if update_instruments:
                self.add_instruments()
        self.transcriber = transcriber
        if only_eval:
            return
        self.load_pts(self.file_list)
        self.data = []
        self.logger.info("Reading files...")
        for input_files in tqdm(self.file_list, desc="creating data list"):
            flac, _ = input_files
            audio_len = librosa.get_duration(path=flac)
            # Sample longer pieces proportionally more often: one copy per
            # minute of audio.
            copies = int(np.ceil(audio_len / 60))
            for _ in range(copies):
                self.data.append(input_files)
        random.shuffle(self.data)

    def flac_to_pt_path(self, flac):
        pt_fname = os.path.basename(flac).replace(".flac", ".pt")
        return os.path.join(self.labels_path, pt_fname)

    def __len__(self):
        return len(self.data)
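
    # Expected on-disk layout, inferred from the path handling in files()
    # below (names are illustrative, not part of the dataset):
    #
    #   NoteEM_audio/MyGroup#0/piece#0.flac    audio, pitch shift of 0 semitones
    #   NoteEM_audio/MyGroup#-1/piece#-1.flac  audio shifted down one semitone
    #   NoteEM_tsv/MyGroup/piece.tsv           note events for the piece
    #   NoteEm_labels/piece#0.pt               cached labels (see load_pts)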

    def files(
        self,
        groups,
        pitch_shift=True,
        keep_eval_files=False,
        n_eval=1,
        evaluation_list=None,
        pitch_shift_limit=5,
    ):
        self.path = self.audio_path
        tsvs_path = self.tsv_path
        self.logger.info("tsv path: %s", tsvs_path)
        self.logger.info("Evaluation list: %s", evaluation_list)
        res = []
        self.logger.info("keep eval files: %s", keep_eval_files)
        self.logger.info("n eval: %d", n_eval)
        for group in groups:
            tsvs = os.listdir(tsvs_path + os.sep + group)
            tsvs = sorted(tsvs)
            if keep_eval_files and evaluation_list is None:
                eval_tsvs = tsvs[:n_eval]
                tsvs = tsvs[n_eval:]
            elif keep_eval_files and evaluation_list is not None:
                eval_tsvs_names = [
                    i.split("#")[0].split(".flac")[0].split(".tsv")[0]
                    for i in evaluation_list
                ]
                eval_tsvs = [
                    i
                    for i in tsvs
                    if i.split("#")[0].split(".tsv")[0] in eval_tsvs_names
                ]
                tsvs = [i for i in tsvs if i not in eval_tsvs]
            else:
                eval_tsvs = []
            self.logger.info("len tsvs: %d", len(tsvs))
            tsvs_names = [t.split(".tsv")[0].split("#")[0] for t in tsvs]
            eval_tsvs_names = [t.split(".tsv")[0].split("#")[0] for t in eval_tsvs]
            for shft in range(-5, 6):
                # Skip non-zero shifts when pitch shifting is disabled, and any
                # shift beyond the configured limit.
                if (shft != 0 and not pitch_shift) or abs(shft) > pitch_shift_limit:
                    continue
                curr_fls_pth = self.path + os.sep + group + "#{}".format(shft)
                fls = os.listdir(curr_fls_pth)
                orig_files = fls
                # Keep only audio files that have a corresponding tsv (in case
                # we don't have the corresponding MIDI).
                fls = [i for i in fls if i.split("#")[0] in tsvs_names]
                missing_fls = [i for i in orig_files if i not in fls]
                if len(missing_fls) > 0:
                    self.logger.warning("missing files: %s", missing_fls)
                fls_names = [i.split("#")[0].split(".flac")[0] for i in fls]
                tsvs = [
                    i for i in tsvs if i.split(".tsv")[0].split("#")[0] in fls_names
                ]
                assert len(tsvs) == len(fls)
                fls = sorted(fls)
                if shft == 0:
                    eval_fls = os.listdir(curr_fls_pth)
                    # Same filtering for the evaluation files.
                    eval_fls = [
                        i for i in eval_fls if i.split("#")[0] in eval_tsvs_names
                    ]
                    eval_fls_names = [i.split("#")[0] for i in eval_fls]
                    eval_tsvs = [
                        i
                        for i in eval_tsvs
                        if i.split(".tsv")[0].split("#")[0] in eval_fls_names
                    ]
                    assert len(eval_fls) == len(eval_tsvs)
                    eval_fls = sorted(eval_fls)
                    for f, t in zip(eval_fls, eval_tsvs):
                        self.eval_file_list.append(
                            (
                                curr_fls_pth + os.sep + f,
                                tsvs_path + os.sep + group + os.sep + t,
                            )
                        )
                for f, t in zip(fls, tsvs):
                    res.append(
                        (
                            curr_fls_pth + os.sep + f,
                            tsvs_path + os.sep + group + os.sep + t,
                        )
                    )
        # Sanity check: every flac must be paired with the tsv of the same piece.
        for flac, tsv in res:
            if (
                os.path.basename(flac).split("#")[0].split(".flac")[0]
                != os.path.basename(tsv).split("#")[0].split(".tsv")[0]
            ):
                self.logger.warning("found mismatch in the files:")
                self.logger.warning("flac: %s", os.path.basename(flac).split("#")[0])
                self.logger.warning("tsv: %s", os.path.basename(tsv).split("#")[0])
                self.logger.warning("please check the input files")
                sys.exit(1)
        return res

    def get_instruments(self, conversion_map=None):
        instruments = set()
        for _, f in self.file_list:
            events = np.loadtxt(f, delimiter="\t", skiprows=1)
            curr_instruments = set(events[:, -1])
            if conversion_map is not None:
                curr_instruments = {
                    conversion_map[c] if c in conversion_map else c
                    for c in curr_instruments
                }
            instruments = instruments.union(curr_instruments)
        instruments = [int(elem) for elem in instruments if elem < 115]
        if conversion_map is not None:
            instruments = [i for i in instruments if i in conversion_map]
        instruments = list(set(instruments))
        if 0 in instruments:
            # Keep piano (program 0) first in the instrument list.
            piano_ind = instruments.index(0)
            instruments.pop(piano_ind)
            instruments.insert(0, 0)
        self.instruments = instruments
        self.instruments = list(
            set(self.instruments) - set(range(88, 104)) - set(range(112, 150))
        )
        self.logger.info("Dataset instruments: %s", self.instruments)
        self.logger.info("Total: %d instruments", len(self.instruments))

    def add_instruments(self):
        for _, f in self.file_list:
            events = np.loadtxt(f, delimiter="\t", skiprows=1)
            curr_instruments = set(events[:, -1])
            new_instruments = curr_instruments - set(self.instruments)
            self.instruments += list(new_instruments)
        instruments = [int(elem) for elem in self.instruments if elem < 115]
        self.instruments = instruments
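
    # Label encoding, as decoded in __getitem__ below: each (frame, key) cell
    # of 'label' holds one of {0: inactive, 1: offset, 2: sustained frame,
    # 3: onset}, so (label == 3) recovers onsets, (label == 1) offsets, and
    # (label > 1) active frames.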
    def __getitem__(self, index):
        data = self.load(*self.data[index])
        midi_length = len(data["label"])
        n_steps = self.sequence_length // constants.HOP_LENGTH
        if midi_length < n_steps:
            step_begin = 0
            step_end = midi_length
        else:
            step_begin = self.random.randint(max(midi_length - n_steps, 1))
            step_end = step_begin + n_steps
        begin = step_begin * constants.HOP_LENGTH
        end = begin + self.sequence_length
        # torch.ShortTensor -> float in [-1, 1]
        audio = data["audio"][begin:end].float().div_(32768.0)
        label = data["label"][step_begin:step_end].clone()
        if audio.shape[0] < self.sequence_length:
            pad_amt = self.sequence_length - audio.shape[0]
            audio = torch.cat([audio, torch.zeros(pad_amt, dtype=audio.dtype)], dim=0)
        if label.shape[0] < n_steps:
            pad_amt = n_steps - label.shape[0]
            label = torch.cat(
                [label, torch.zeros((pad_amt, *label.shape[1:]), dtype=label.dtype)],
                dim=0,
            )
        audio = torch.clamp(audio, -1.0, 1.0)
        result = {"path": data["path"], "audio": audio, "label": label}
        if "velocity" in data:
            result["velocity"] = data["velocity"][step_begin:step_end, ...]
            result["velocity"] = result["velocity"].float() / 128.0
        if result["label"].max() < 3:
            # No onset anywhere in this segment: fall back to the raw label.
            result["onset"] = result["label"].float()
        else:
            result["onset"] = (result["label"] == 3).float()
        result["offset"] = (result["label"] == 1).float()
        result["frame"] = (result["label"] > 1).float()
        if self.smooth_labels:
            result["onset"] = smooth_labels(result["onset"])
        if self.use_onset_mask:
            if "onset_mask" in data:
                result["onset_mask"] = data["onset_mask"][
                    step_begin:step_end, ...
                ].float()
            else:
                result["onset_mask"] = torch.ones_like(result["onset"]).float()
            if "frame_mask" in data:
                result["frame_mask"] = data["frame_mask"][
                    step_begin:step_end, ...
                ].float()
            else:
                result["frame_mask"] = torch.ones_like(result["frame"]).float()
        # Collapse the per-instrument frame/offset stacks to pitch-only
        # (N_KEYS-wide) labels by taking the max over instruments.
        shape = result["frame"].shape
        keys = N_KEYS
        new_shape = shape[:-1] + (shape[-1] // keys, keys)
        result["big_frame"] = result["frame"]
        result["frame"], _ = result["frame"].reshape(new_shape).max(axis=-2)
        result["big_offset"] = result["offset"]
        result["offset"], _ = result["offset"].reshape(new_shape).max(axis=-2)
        result["group"] = self.data[index][0].split(os.sep)[-2].split("#")[0]
        return result

    def load(self, audio_path, tsv_path):
        if self.save_to_memory:
            data = self.pts[audio_path]
        else:
            data = torch.load(self.flac_to_pt_path(audio_path))
        if len(data["audio"].shape) > 1:
            # Down-mix stereo to mono.
            data["audio"] = (data["audio"].float().mean(dim=-1)).short()
        if "label" in data:
            return data
        # No stored label: this is a pitch-shifted copy, so derive its labels
        # from the unshifted ("#0") version of the same piece.
        piece, part = audio_path.split(os.sep)[-2:]
        piece_split = piece.split("#")
        if len(piece_split) == 2:
            piece, shift1 = piece_split
        else:
            piece, shift1 = "#".join(piece_split[:2]), piece_split[-1]
        part_split = part.split("#")
        if len(part_split) == 2:
            part, shift2 = part_split
        else:
            part, shift2 = "#".join(part_split[:2]), part_split[-1]
        shift2, _ = shift2.split(".")
        assert shift1 == shift2
        shift = shift1
        assert int(shift) != 0
        orig = audio_path.replace("#{}".format(shift), "#0")
        if self.save_to_memory:
            orig_data = self.pts[orig]
        else:
            orig_data = torch.load(self.flac_to_pt_path(orig))
        res = {}
        res["label"] = shift_label(orig_data["label"], int(shift))
        res["path"] = audio_path
        res["audio"] = data["audio"]
        if "velocity" in orig_data:
            res["velocity"] = shift_label(orig_data["velocity"], int(shift))
        if "onset_mask" in orig_data:
            res["onset_mask"] = shift_label(orig_data["onset_mask"], int(shift))
        if "frame_mask" in orig_data:
            res["frame_mask"] = shift_label(orig_data["frame_mask"], int(shift))
        return res
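
    # Cached .pt contents (written by load_pts below): every file stores the
    # int16 'audio' tensor; unshifted ("#0") files additionally store 'path',
    # 'unaligned_label', 'group', and the best bag-of-notes distances so far
    # ('BON' and 'BON_VEC'). Labels for pitch-shifted copies are derived on
    # the fly from the "#0" version in load() above via shift_label().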
    def load_pts(self, files):
        self.pts = {}
        self.logger.info("loading pts...")
        for flac, tsv in tqdm(files, desc="loading pts"):
            if os.path.isfile(self.flac_to_pt_path(flac)):
                # Cached labels already exist on disk.
                if self.save_to_memory:
                    self.pts[flac] = torch.load(self.flac_to_pt_path(flac))
            else:
                if flac.count("#") != 2:
                    self.logger.debug("expected two '#' in filename: %s", flac)
                audio, sr = soundfile.read(flac, dtype="int16")
                if len(audio.shape) == 2:
                    # Down-mix stereo to mono.
                    audio = audio.astype(float).mean(axis=1)
                else:
                    audio = audio.astype(float)
                audio = audio.astype(np.int16)
                self.logger.debug("audio len: %d", len(audio))
                assert sr == SAMPLE_RATE
                audio = torch.ShortTensor(audio)
                if "#0" not in flac:
                    # Pitch-shifted copies store audio only; their labels are
                    # derived from the "#0" version in load().
                    assert "#" in flac
                    data = {"audio": audio}
                    if self.save_to_memory:
                        self.pts[flac] = data
                    torch.save(data, self.flac_to_pt_path(flac))
                    continue
                midi = np.loadtxt(tsv, delimiter="\t", skiprows=1)
                unaligned_label = midi_to_frames(
                    midi, self.instruments, conversion_map=self.conversion_map
                )
                if len(self.instruments) == 1:
                    unaligned_label = unaligned_label[:, -N_KEYS:]
                if len(unaligned_label) < self.sequence_length // constants.HOP_LENGTH:
                    diff = self.sequence_length // constants.HOP_LENGTH - len(
                        unaligned_label
                    )
                    pad = torch.zeros(
                        (diff, unaligned_label.shape[1]), dtype=unaligned_label.dtype
                    )
                    unaligned_label = torch.cat((unaligned_label, pad), dim=0)
                group = flac.split(os.sep)[-2].split("#")[0]
                data = dict(
                    path=self.labels_path + os.sep + flac.split(os.sep)[-1],
                    audio=audio,
                    unaligned_label=unaligned_label,
                    group=group,
                    BON=float("inf"),
                    BON_VEC=np.full(unaligned_label.shape[1], float("inf")),
                )
                torch.save(data, self.flac_to_pt_path(flac))
                if self.save_to_memory:
                    self.pts[flac] = data

    def update_pts_counting(
        self,
        transcriber,
        counting_window_length,
        POS=1.1,
        NEG=-0.001,  # currently unused here
        FRAME_POS=0.5,
        to_save=None,
        first=False,
        update=True,
        BEST_DIST=False,
        peak_size=3,
        BEST_DIST_VEC=False,
        counting_window_hop=0,
    ):
        self.logger.info("Updating pts...")
        self.logger.info("First %s", first)
        total_counting_time = 0.0  # total time spent on counting-based alignment
        self.logger.info("POS, NEG: %s, %s", POS, NEG)
        if to_save is not None:
            os.makedirs(to_save, exist_ok=True)
        self.logger.info("There are %d pts", len(self.pts))
        update_count = 0
        sys.stdout.flush()
        only_pitch_0_files = [f for f in self.file_list if "#0" in f[0]]
        for input_files in tqdm(only_pitch_0_files, desc="updating pts"):
            flac, tsv = input_files
            data = torch.load(self.flac_to_pt_path(flac))
            if "unaligned_label" not in data:
                self.logger.warning("No unaligned labels for %s", flac)
                continue
            audio_inp = data["audio"].float() / 32768.0
            MAX_TIME = 5 * 60 * SAMPLE_RATE
            audio_inp_len = len(audio_inp)
            if audio_inp_len > MAX_TIME:
                # Transcribe long audio in segments to bound GPU memory.
                n_segments = int(np.ceil(audio_inp_len / MAX_TIME))
                self.logger.debug("Long audio, splitting to %d segments", n_segments)
                seg_len = MAX_TIME
                onsets_preds = []
                offset_preds = []
                frame_preds = []
                for i_s in range(n_segments):
                    curr = (
                        audio_inp[i_s * seg_len : (i_s + 1) * seg_len]
                        .unsqueeze(0)
                        .cuda()
                    )
                    curr_mel = melspectrogram(
                        curr.reshape(-1, curr.shape[-1])[:, :-1]
                    ).transpose(-1, -2)
                    (
                        curr_onset_pred,
                        curr_offset_pred,
                        _,
                        curr_frame_pred,
                        curr_velocity_pred,
                    ) = transcriber(curr_mel)
                    onsets_preds.append(curr_onset_pred)
                    offset_preds.append(curr_offset_pred)
                    frame_preds.append(curr_frame_pred)
                onset_pred = torch.cat(onsets_preds, dim=1)
                offset_pred = torch.cat(offset_preds, dim=1)
                frame_pred = torch.cat(frame_preds, dim=1)
            else:
                audio_inp = audio_inp.unsqueeze(0).cuda()
                mel = melspectrogram(
                    audio_inp.reshape(-1, audio_inp.shape[-1])[:, :-1]
                ).transpose(-1, -2)
                onset_pred, offset_pred, _, frame_pred, _ = transcriber(mel)
            self.logger.debug("Done predicting.")
            # We assume onset predictions are N_KEYS * (len(instruments) + 1)
            # columns wide: the first N_KEYS columns are the first instrument,
            # the next N_KEYS columns the next instrument, etc., and the last
            # N_KEYS columns are pitch regardless of instrument. Frame and
            # offset predictions are currently only N_KEYS columns wide.
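            # For example (illustrative, two instruments): columns [0, N_KEYS)
            # are instrument 0, [N_KEYS, 2 * N_KEYS) are instrument 1, and
            # [2 * N_KEYS, 3 * N_KEYS) are the instrument-agnostic pitch
            # activations -- hence the [:, -N_KEYS:] slices and the
            # f % N_KEYS indexing used below.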
            onset_pred = onset_pred.detach().squeeze().cpu()
            frame_pred = frame_pred.detach().squeeze().cpu()
            PEAK_SIZE = peak_size
            self.logger.debug("PEAK_SIZE: %d", PEAK_SIZE)
            # Peak-pick the onset predictions so that only local maxima
            # survive: keep only peaks in a (2 * PEAK_SIZE + 1)-frame
            # neighborhood, PEAK_SIZE frames to each side.
            if peak_size > 0:
                peaks = get_peaks(onset_pred, PEAK_SIZE)
                onset_pred[~peaks] = 0
            unaligned_onsets = (data["unaligned_label"] == 3).float().numpy()
            onset_pred_np = onset_pred.numpy()
            frame_pred_np = frame_pred.numpy()
            # Bag-of-notes distance: compare per-pitch onset counts between
            # the prediction and the unaligned label.
            pred_bag_of_notes = (onset_pred_np[:, -N_KEYS:] >= 0.5).sum(axis=0)
            gt_bag_of_notes = unaligned_onsets[:, -N_KEYS:].astype(bool).sum(axis=0)
            bon_dist = (((pred_bag_of_notes - gt_bag_of_notes) ** 2).sum()) ** 0.5
            pred_bag_of_notes_with_inst = (onset_pred_np >= 0.5).sum(axis=0)
            gt_bag_of_notes_with_inst = unaligned_onsets.astype(bool).sum(axis=0)
            bon_dist_vec = np.abs(
                pred_bag_of_notes_with_inst - gt_bag_of_notes_with_inst
            )
            bon_dist /= gt_bag_of_notes.sum()
            self.logger.debug("bag of notes dist: %f", bon_dist)

            aligned_onsets = np.zeros(onset_pred_np.shape, dtype=bool)
            aligned_frames = np.zeros(onset_pred_np.shape, dtype=bool)
            # This block is the main difference between the counting approach
            # and the DTW approach. In the counting approach we label the audio
            # by counting note onsets: for each onset pitch class, denote by K
            # the number of times it occurs in the unaligned label. We simply
            # take the K highest local peaks predicted by the current model.
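            # Worked example of the counting step (illustrative numbers, not
            # from the data): if pitch class f has K = 3 onsets in the
            # unaligned label within a window, we keep the 3 frames where the
            # model predicts f's onset most confidently, e.g.
            #   onset_pred_np[:, f] = [0.1, 0.9, 0.2, 0.8, 0.7, 0.1] -> frames {1, 3, 4}
            # np.sort afterwards restores the temporal order of the chosen frames.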
            # Split the unaligned onsets into chunks of counting_window_length
            # frames, advanced by counting_window_hop.
            self.logger.debug(
                "unaligned onsets shape: %s, counting window length: %d, counting window hop: %d",
                unaligned_onsets.shape,
                counting_window_length,
                counting_window_hop,
            )
            assert counting_window_hop <= counting_window_length
            if counting_window_hop == 0:
                counting_window_hop = counting_window_length
            num_chunks = (
                1
                if counting_window_length == 0
                else int(np.ceil(len(unaligned_onsets) / counting_window_hop))
            )
            self.logger.debug("number of chunks: %d", num_chunks)
            start_time = time.time()
            for chunk_idx in range(num_chunks):
                start_idx = chunk_idx * counting_window_hop
                if counting_window_length == 0:
                    # A window length of 0 means one window over the whole piece.
                    end_idx = max(len(unaligned_onsets), len(onset_pred_np))
                else:
                    end_idx = min(
                        start_idx + counting_window_length, len(unaligned_onsets)
                    )
                chunk_onsets = unaligned_onsets[start_idx:end_idx]
                chunk_onsets_count = (
                    (data["unaligned_label"][start_idx:end_idx, :] == 3)
                    .sum(dim=0)
                    .numpy()
                )
                for f, f_count in enumerate(chunk_onsets_count):
                    if f_count == 0:
                        continue
                    f_most_likely = np.sort(
                        onset_pred_np[start_idx:end_idx, f].argsort()[::-1][:f_count]
                    )
                    f_most_likely += start_idx  # adjust indices to the full piece
                    aligned_onsets[f_most_likely, f] = 1
                    f_unaligned = chunk_onsets[:, f].nonzero()
                    assert len(f_unaligned) == 1
                    f_unaligned = f_unaligned[0]  # note: currently unused
            counting_duration = time.time() - start_time
            total_counting_time += counting_duration
            self.logger.debug(
                "Counting alignment for file '%s' took %.2f seconds.",
                flac,
                counting_duration,
            )
            # Pseudo labels. A POS greater than 1 is equivalent to not using
            # pseudo labels at all.
            pseudo_onsets = (onset_pred_np >= POS) & (~aligned_onsets)
            onset_label = np.maximum(pseudo_onsets, aligned_onsets)
            # In this project we do not train the frame stack, but we compute
            # the frame labels anyway.
            pseudo_frames = np.zeros(pseudo_onsets.shape, dtype=pseudo_onsets.dtype)
            pseudo_offsets = np.zeros(pseudo_onsets.shape, dtype=pseudo_onsets.dtype)
            for t, f in zip(*onset_label.nonzero()):
                # Extend each onset forward while the frame prediction stays
                # above FRAME_POS; the first frame below it marks the offset.
                t_off = t
                while (
                    t_off < len(pseudo_frames)
                    and frame_pred[t_off, f % N_KEYS] >= FRAME_POS
                ):
                    t_off += 1
                pseudo_frames[t:t_off, f] = 1
                if t_off < len(pseudo_offsets):
                    pseudo_offsets[t_off, f] = 1
            frame_label = np.maximum(pseudo_frames, aligned_frames)
            offset_label = get_diff(frame_label, offset=True)
            # Combine into the {0, 1, 2, 3} encoding: onset wins over frame,
            # frame wins over offset.
            label = np.maximum(2 * frame_label, offset_label)
            label = np.maximum(3 * onset_label, label).astype(np.uint8)
            if to_save is not None:
                save_midi_alignments_and_predictions(
                    to_save,
                    data["path"],
                    self.instruments,
                    aligned_onsets,
                    aligned_frames,
                    onset_pred_np,
                    frame_pred_np,
                    prefix="",
                    group=data["group"],
                )
            prev_bon_dist = data.get("BON", float("inf"))
            prev_bon_dist_vec = data.get("BON_VEC", None)
            if update:
                if BEST_DIST_VEC:
                    self.logger.debug("Updated Labels")
                    if prev_bon_dist_vec is None:
                        raise ValueError(
                            "BEST_DIST_VEC is True but no previous BON_VEC found"
                        )
                    new_label = torch.from_numpy(label).byte()
                    if first:
                        # On the first pass there is no previous label yet; take
                        # the new label and record the current distances.
                        prev_label = new_label
                        prev_bon_dist_vec = bon_dist_vec
                        update_count += 1
                    else:
                        prev_label = data["label"]
                        updated_flag = False
                        num_pitches_updated = 0
                        # Update each pitch class independently, keeping the
                        # label with the best bag-of-notes distance so far.
                        for k in range(prev_label.shape[1]):
                            if prev_bon_dist_vec[k] > bon_dist_vec[k]:
                                prev_label[:, k] = new_label[:, k]
                                prev_bon_dist_vec[k] = bon_dist_vec[k]
                                num_pitches_updated += 1
                                updated_flag = True
                        if updated_flag:
                            update_count += 1
                        self.logger.debug("Updated %d pitches", num_pitches_updated)
                    data["label"] = prev_label
                    data["BON_VEC"] = prev_bon_dist_vec
                    self.logger.debug("saved updated pt")
                    torch.save(
                        data,
                        self.labels_path
                        + os.sep
                        + flac.split(os.sep)[-1]
                        .replace(".flac", ".pt")
                        .replace(".mp3", ".pt"),
                    )
                elif not BEST_DIST or bon_dist < prev_bon_dist:
                    update_count += 1
                    self.logger.debug("Updated Labels")
                    data["label"] = torch.from_numpy(label).byte()
                    data["BON"] = bon_dist
                    self.logger.debug("saved updated pt")
                    torch.save(
                        data,
                        self.labels_path
                        + os.sep
                        + flac.split(os.sep)[-1]
                        .replace(".flac", ".pt")
                        .replace(".mp3", ".pt"),
                    )
            if bon_dist < prev_bon_dist:
                self.logger.debug(
                    "Bag of notes distance improved from %f to %f",
                    prev_bon_dist,
                    bon_dist,
                )
                data["BON"] = bon_dist
                if to_save is not None and BEST_DIST:
                    os.makedirs(to_save + "/BEST_BON", exist_ok=True)
                    save_midi_alignments_and_predictions(
                        to_save + "/BEST_BON",
                        data["path"],
                        self.instruments,
                        aligned_onsets,
                        aligned_frames,
                        onset_pred_np,
                        frame_pred_np,
                        prefix="BEST_BON",
                        group=data["group"],
                        use_time=False,
                    )
        self.logger.info(
            "Updated %d pts out of %d", update_count, len(only_pitch_0_files)
        )
        self.logger.info(
            "Total counting alignment time for all files: %.2f seconds.",
            total_counting_time,
        )
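
# A minimal usage sketch (hypothetical paths and arguments, for orientation
# only; `transcriber` is a trained onsets_and_frames-style model and
# sequence_length is in samples):
#
#   dataset = EMDATASET(
#       audio_path="NoteEM_audio",
#       tsv_path="NoteEM_tsv",
#       labels_path="NoteEm_labels",
#       groups=["MyGroup"],
#       sequence_length=327680,
#       transcriber=transcriber,
#   )
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)
#   dataset.update_pts_counting(transcriber, counting_window_length=0, first=True)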