Ardantia (formerly VsingerXiaoiceSing)
at main 529 lines 24 kB view raw
"""Binarizer for variance-model training data.

Reads transcriptions (and optionally .ds label files), extracts frame-level
curves (f0/pitch, energy, breathiness, voicing, tension) and phone/note
alignments, and packs them into the attributes listed in
VARIANCE_ITEM_ATTRIBUTES.
"""
import csv
import json
import os
import pathlib

import librosa
import numpy as np
import torch
import torch.nn.functional as F
from scipy import interpolate

from basics.base_binarizer import BaseBinarizer, BinarizationError
from basics.base_pe import BasePE
from modules.fastspeech.tts_modules import LengthRegulator
from modules.pe import initialize_pe
from utils.binarizer_utils import (
    SinusoidalSmoothingConv1d,
    get_mel2ph_torch,
    get_energy_librosa,
    get_breathiness,
    get_voicing,
    get_tension_base_harmonic,
)
from utils.decomposed_waveform import DecomposedWaveform
from utils.hparams import hparams
from utils.infer_utils import resample_align_curve
from utils.pitch_utils import interp_f0
from utils.plot import distribution_to_figure

os.environ["OMP_NUM_THREADS"] = "1"
VARIANCE_ITEM_ATTRIBUTES = [
    'spk_id',  # index number of dataset/speaker, int64
    'languages',  # index numbers of phoneme languages, int64[T_ph,]
    'tokens',  # index numbers of phonemes, int64[T_ph,]
    'ph_dur',  # durations of phonemes, in number of frames, int64[T_ph,]
    'midi',  # phoneme-level mean MIDI pitch, int64[T_ph,]
    'ph2word',  # similar to mel2ph format, representing number of phones within each note, int64[T_ph,]
    'mel2ph',  # mel2ph format representing number of frames within each phone, int64[T_s,]
    'note_midi',  # note-level MIDI pitch, float32[T_n,]
    'note_rest',  # flags for rest notes, bool[T_n,]
    'note_dur',  # durations of notes, in number of frames, int64[T_n,]
    'note_glide',  # flags for glides, 0 = none, 1 = up, 2 = down, int64[T_n,]
    'mel2note',  # mel2ph format representing number of frames within each note, int64[T_s,]
    'base_pitch',  # interpolated and smoothed frame-level MIDI pitch, float32[T_s,]
    'pitch',  # actual pitch in semitones, float32[T_s,]
    'uv',  # unvoiced masks (only for objective evaluation metrics), bool[T_s,]
    'energy',  # frame-level RMS (dB), float32[T_s,]
    'breathiness',  # frame-level RMS of aperiodic parts (dB), float32[T_s,]
    'voicing',  # frame-level RMS of harmonic parts (dB), float32[T_s,]
    'tension',  # frame-level tension (logit), float32[T_s,]
]
WAV_CANDIDATE_EXTENSIONS = ['.wav', '.flac']
DS_INDEX_SEP = '#'

# These operators are used as global variables due to a PyTorch shared memory bug on Windows platforms.
# See https://github.com/pytorch/pytorch/issues/100358
pitch_extractor: BasePE = None
midi_smooth: SinusoidalSmoothingConv1d = None
energy_smooth: SinusoidalSmoothingConv1d = None
breathiness_smooth: SinusoidalSmoothingConv1d = None
voicing_smooth: SinusoidalSmoothingConv1d = None
tension_smooth: SinusoidalSmoothingConv1d = None


class VarianceBinarizer(BaseBinarizer):
    """Binarizer that prepares all variance-model training attributes."""

    def __init__(self):
        super().__init__(data_attrs=VARIANCE_ITEM_ATTRIBUTES)

        self.use_glide_embed = hparams['use_glide_embed']
        glide_types = hparams['glide_types']
        assert 'none' not in glide_types, 'Type name \'none\' is reserved and should not appear in glide_types.'
        # Map glide type names to embedding indices; 'none' is always 0.
        self.glide_map = {
            'none': 0,
            **{
                typename: idx + 1
                for idx, typename in enumerate(glide_types)
            }
        }

        predict_energy = hparams['predict_energy']
        predict_breathiness = hparams['predict_breathiness']
        predict_voicing = hparams['predict_voicing']
        predict_tension = hparams['predict_tension']
        self.predict_variances = predict_energy or predict_breathiness or predict_voicing or predict_tension
        self.lr = LengthRegulator().to(self.device)
        self.prefer_ds = self.binarization_args['prefer_ds']
        self.cached_ds = {}  # cache of parsed .ds files, keyed by item name (with or without segment index)

    def load_attr_from_ds(self, ds_id, name, attr, idx=0):
        """Load attribute `attr` of item `name` (segment `idx`) from a .ds file.

        Looks for a per-segment file `name#idx.ds` first, then a whole-item
        `name.ds` containing a list of segments. Parsed files are cached.
        Returns None when no .ds file exists or the attribute is absent.
        """
        item_name = f'{ds_id}:{name}'
        item_name_with_idx = f'{item_name}{DS_INDEX_SEP}{idx}'
        if item_name_with_idx in self.cached_ds:
            # Per-segment file: it holds exactly one segment, at position 0.
            ds = self.cached_ds[item_name_with_idx][0]
        elif item_name in self.cached_ds:
            # Whole-item file: segments are addressed by idx.
            ds = self.cached_ds[item_name][idx]
        else:
            ds_path = self.raw_data_dirs[ds_id] / 'ds' / f'{name}{DS_INDEX_SEP}{idx}.ds'
            if ds_path.exists():
                cache_key = item_name_with_idx
            else:
                ds_path = self.raw_data_dirs[ds_id] / 'ds' / f'{name}.ds'
                cache_key = item_name
            if not ds_path.exists():
                return None
            with open(ds_path, 'r', encoding='utf8') as f:
                ds = json.load(f)
            if not isinstance(ds, list):
                ds = [ds]
            self.cached_ds[cache_key] = ds
            # Fix: index consistently with the cache-hit branches above —
            # a per-segment file contains only its own segment at position 0,
            # so using ds[idx] here was wrong (IndexError/wrong segment for idx > 0).
            ds = ds[0] if cache_key == item_name_with_idx else ds[idx]
        return ds.get(attr)

    def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk, lang):
        """Parse transcriptions.csv of one raw dataset into a metadata dict.

        Returns a mapping '{ds_id}:{item_name}' -> per-utterance metadata.
        Raises ValueError on missing required attributes and FileNotFoundError
        when no waveform exists and 'prefer_ds' is disabled.
        """
        meta_data_dict = {}

        with open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf8') as f:
            for utterance_label in csv.DictReader(f):
                utterance_label: dict
                item_name = utterance_label['name']
                # Items named like 'song#3' are segment 3 of a multi-segment .ds file.
                item_idx = int(item_name.rsplit(DS_INDEX_SEP, maxsplit=1)[-1]) if DS_INDEX_SEP in item_name else 0

                def require(attr, optional=False):
                    # Prefer the value from the .ds file when configured, falling
                    # back to the CSV column; missing required values are fatal.
                    if self.prefer_ds:
                        value = self.load_attr_from_ds(ds_id, item_name, attr, item_idx)
                    else:
                        value = None
                    if value is None:
                        value = utterance_label.get(attr)
                    if value is None and not optional:
                        raise ValueError(f'Missing required attribute {attr} of item \'{item_name}\'.')
                    return value

                wav_fn = None
                for ext in WAV_CANDIDATE_EXTENSIONS:
                    candidate_fn = raw_data_dir / 'wavs' / f'{item_name}{ext}'
                    if candidate_fn.exists():
                        wav_fn = candidate_fn
                        break
                if wav_fn is None and not self.prefer_ds:
                    raise FileNotFoundError(
                        f'Waveform file not found for item \'{item_name}\'. '
                        f'Candidate extensions: {WAV_CANDIDATE_EXTENSIONS}\n'
                        f'If you are using DS files instead of waveform files, please set \'prefer_ds\' to true.'
                    )

                temp_dict = {
                    'ds_idx': item_idx,
                    'spk_id': self.spk_map[spk],
                    'spk_name': spk,
                    'language_id': self.lang_map[lang],
                    'language_name': lang,
                    'wav_fn': str(wav_fn) if wav_fn is not None else None,
                    # Language id per phoneme: cross-lingual phonemes keep their
                    # own language tag ('xx/ph'), others use index 0.
                    'lang_seq': [
                        (
                            self.lang_map[lang if '/' not in p else p.split('/', maxsplit=1)[0]]
                            if self.phoneme_dictionary.is_cross_lingual(p if '/' in p else f'{lang}/{p}')
                            else 0
                        )
                        for p in utterance_label['ph_seq'].split()
                    ],
                    'ph_seq': self.phoneme_dictionary.encode(require('ph_seq'), lang=lang),
                    'ph_dur': [float(x) for x in require('ph_dur').split()],
                    'ph_text': require('ph_seq'),
                }

                assert len(temp_dict['ph_seq']) == len(temp_dict['ph_dur']), \
                    f'Lengths of ph_seq and ph_dur mismatch in \'{item_name}\'.'
                assert all(ph_dur >= 0 for ph_dur in temp_dict['ph_dur']), \
                    f'Negative ph_dur found in \'{item_name}\'.'

                if hparams['predict_dur']:
                    temp_dict['ph_num'] = [int(x) for x in require('ph_num').split()]
                    assert len(temp_dict['ph_seq']) == sum(temp_dict['ph_num']), \
                        f'Sum of ph_num does not equal length of ph_seq in \'{item_name}\'.'

                if hparams['predict_pitch']:
                    temp_dict['note_seq'] = require('note_seq').split()
                    temp_dict['note_dur'] = [float(x) for x in require('note_dur').split()]
                    assert all(note_dur >= 0 for note_dur in temp_dict['note_dur']), \
                        f'Negative note_dur found in \'{item_name}\'.'
                    assert len(temp_dict['note_seq']) == len(temp_dict['note_dur']), \
                        f'Lengths of note_seq and note_dur mismatch in \'{item_name}\'.'
                    assert any(note != 'rest' for note in temp_dict['note_seq']), \
                        f'All notes are rest in \'{item_name}\'.'
                    if hparams['use_glide_embed']:
                        note_glide = require('note_glide', optional=True)
                        if note_glide is None:
                            # No glide labels: default every note to 'none'.
                            note_glide = ['none' for _ in temp_dict['note_seq']]
                        else:
                            note_glide = note_glide.split()
                            assert len(note_glide) == len(temp_dict['note_seq']), \
                                f'Lengths of note_seq and note_glide mismatch in \'{item_name}\'.'
                            assert all(g in self.glide_map for g in note_glide), \
                                f'Invalid glide type found in \'{item_name}\'.'
                        temp_dict['note_glide'] = note_glide

                meta_data_dict[f'{ds_id}:{item_name}'] = temp_dict

        return meta_data_dict

    def check_coverage(self):
        """Print/plot MIDI pitch and glide type coverage of the dataset.

        Raises BinarizationError when glide embedding is enabled and some
        configured glide type never occurs in the data.
        """
        super().check_coverage()
        if not hparams['predict_pitch']:
            return

        # MIDI pitch distribution summary
        midi_map = {}
        for item_name in self.items:
            for midi in self.items[item_name]['note_seq']:
                if midi == 'rest':
                    continue
                midi = librosa.note_to_midi(midi, round_midi=True)
                midi_map[midi] = midi_map.get(midi, 0) + 1

        print('===== MIDI Pitch Distribution Summary =====')
        for i, key in enumerate(sorted(midi_map.keys())):
            if i == len(midi_map) - 1:
                end = '\n'
            elif i % 10 == 9:
                end = ',\n'
            else:
                end = ', '
            print(f'\'{librosa.midi_to_note(key, unicode=False)}\': {midi_map[key]}', end=end)

        # Draw graph.
        midis = sorted(midi_map.keys())
        notes = [librosa.midi_to_note(m, unicode=False) for m in range(midis[0], midis[-1] + 1)]
        plt = distribution_to_figure(
            title='MIDI Pitch Distribution Summary',
            x_label='MIDI Key', y_label='Number of occurrences',
            items=notes, values=[midi_map.get(m, 0) for m in range(midis[0], midis[-1] + 1)]
        )
        filename = self.binary_data_dir / 'midi_distribution.jpg'
        plt.savefig(fname=filename,
                    bbox_inches='tight',
                    pad_inches=0.25)
        # Fix: report the actual output path instead of a corrupted placeholder.
        print(f'| save summary to \'{filename}\'')

        if self.use_glide_embed:
            # Glide type distribution summary
            glide_count = {
                g: 0
                for g in self.glide_map
            }
            for item_name in self.items:
                for glide in self.items[item_name]['note_glide']:
                    if glide == 'none' or glide not in self.glide_map:
                        glide_count['none'] += 1
                    else:
                        glide_count[glide] += 1

            print('===== Glide Type Distribution Summary =====')
            for i, key in enumerate(sorted(glide_count.keys(), key=lambda k: self.glide_map[k])):
                if i == len(glide_count) - 1:
                    end = '\n'
                elif i % 10 == 9:
                    end = ',\n'
                else:
                    end = ', '
                print(f'\'{key}\': {glide_count[key]}', end=end)

            if any(n == 0 for n in glide_count.values()):
                raise BinarizationError(
                    f'Missing glide types in dataset: '
                    f'{sorted([g for g, n in glide_count.items() if n == 0], key=lambda k: self.glide_map[k])}'
                )

    @torch.no_grad()
    def process_item(self, item_name, meta_data, binarization_args):
        """Binarize a single utterance into a dict of numpy attributes.

        Returns None (skipping the item) when the ground-truth f0 is entirely
        unvoiced.
        """
        ds_id, name = item_name.split(':', maxsplit=1)
        name = name.rsplit(DS_INDEX_SEP, maxsplit=1)[0]
        ds_id = int(ds_id)
        ds_seg_idx = meta_data['ds_idx']
        seconds = sum(meta_data['ph_dur'])
        length = round(seconds / self.timestep)
        T_ph = len(meta_data['ph_seq'])
        processed_input = {
            'name': item_name,
            'wav_fn': meta_data['wav_fn'],
            'spk_id': meta_data['spk_id'],
            'spk_name': meta_data['spk_name'],
            'seconds': seconds,
            'length': length,
            'languages': np.array(meta_data['lang_seq'], dtype=np.int64),
            'tokens': np.array(meta_data['ph_seq'], dtype=np.int64),
            'ph_text': meta_data['ph_text'],
        }

        # Convert second-level phone durations to frame counts via cumulative
        # rounding so the per-phone frame counts sum exactly to the boundaries.
        ph_dur_sec = torch.FloatTensor(meta_data['ph_dur']).to(self.device)
        ph_acc = torch.round(torch.cumsum(ph_dur_sec, dim=0) / self.timestep + 0.5).long()
        ph_dur = torch.diff(ph_acc, dim=0, prepend=torch.LongTensor([0]).to(self.device))
        processed_input['ph_dur'] = ph_dur.cpu().numpy()

        mel2ph = get_mel2ph_torch(
            self.lr, ph_dur_sec, length, self.timestep, device=self.device
        )

        if hparams['predict_pitch'] or self.predict_variances:
            processed_input['mel2ph'] = mel2ph.cpu().numpy()

        # Below: extract actual f0, convert to pitch and calculate delta pitch
        if meta_data['wav_fn'] is not None:
            waveform, _ = librosa.load(meta_data['wav_fn'], sr=hparams['audio_sample_rate'], mono=True)
        else:
            waveform = None

        global pitch_extractor
        if pitch_extractor is None:
            pitch_extractor = initialize_pe()
        f0 = uv = None
        if self.prefer_ds:
            f0_seq = self.load_attr_from_ds(ds_id, name, 'f0_seq', idx=ds_seg_idx)
            if f0_seq is not None:
                f0 = resample_align_curve(
                    np.array(f0_seq.split(), np.float32),
                    original_timestep=float(self.load_attr_from_ds(ds_id, name, 'f0_timestep', idx=ds_seg_idx)),
                    target_timestep=self.timestep,
                    align_length=length
                )
                uv = f0 == 0
                f0, _ = interp_f0(f0, uv)
        if f0 is None:
            f0, uv = pitch_extractor.get_pitch(
                waveform, samplerate=hparams['audio_sample_rate'], length=length,
                hop_size=hparams['hop_size'], f0_min=hparams['f0_min'], f0_max=hparams['f0_max'],
                interp_uv=True
            )
        if uv.all():  # All unvoiced
            print(f'Skipped \'{item_name}\': empty gt f0')
            return None
        pitch = torch.from_numpy(librosa.hz_to_midi(f0.astype(np.float32)).astype(np.float32)).to(self.device)

        if hparams['predict_dur']:
            ph_num = torch.LongTensor(meta_data['ph_num']).to(self.device)
            ph2word = self.lr(ph_num[None])[0]
            processed_input['ph2word'] = ph2word.cpu().numpy()
            mel2dur = torch.gather(F.pad(ph_dur, [1, 0], value=1), 0, mel2ph)  # frame-level phone duration
            # Phone-level mean pitch: scatter frame pitch / frame count into each
            # phone slot (index 0 collects padding frames and is dropped).
            ph_midi = pitch.new_zeros(T_ph + 1).scatter_add(
                0, mel2ph, pitch / mel2dur
            )[1:]
            processed_input['midi'] = ph_midi.round().long().clamp(min=0, max=127).cpu().numpy()

        if hparams['predict_pitch']:
            # Below: get note sequence and interpolate rest notes
            note_midi = np.array(
                [(librosa.note_to_midi(n, round_midi=False) if n != 'rest' else -1) for n in meta_data['note_seq']],
                dtype=np.float32
            )
            note_rest = note_midi < 0
            interp_func = interpolate.interp1d(
                np.where(~note_rest)[0], note_midi[~note_rest],
                kind='nearest', fill_value='extrapolate'
            )
            note_midi[note_rest] = interp_func(np.where(note_rest)[0])
            processed_input['note_midi'] = note_midi
            processed_input['note_rest'] = note_rest
            note_midi = torch.from_numpy(note_midi).to(self.device)

            note_dur_sec = torch.FloatTensor(meta_data['note_dur']).to(self.device)
            note_acc = torch.round(torch.cumsum(note_dur_sec, dim=0) / self.timestep + 0.5).long()
            note_dur = torch.diff(note_acc, dim=0, prepend=torch.LongTensor([0]).to(self.device))
            processed_input['note_dur'] = note_dur.cpu().numpy()

            mel2note = get_mel2ph_torch(
                self.lr, note_dur_sec, mel2ph.shape[0], self.timestep, device=self.device
            )
            processed_input['mel2note'] = mel2note.cpu().numpy()

            # Below: get ornament attributes
            if hparams['use_glide_embed']:
                processed_input['note_glide'] = np.array([
                    self.glide_map.get(x, 0) for x in meta_data['note_glide']
                ], dtype=np.int64)

            # Below:
            # 1. Get the frame-level MIDI pitch, which is a step function curve
            # 2. smoothen the pitch step curve as the base pitch curve
            frame_midi_pitch = torch.gather(F.pad(note_midi, [1, 0], value=0), 0, mel2note)
            global midi_smooth
            if midi_smooth is None:
                midi_smooth = SinusoidalSmoothingConv1d(
                    round(hparams['midi_smooth_width'] / self.timestep)
                ).eval().to(self.device)
            smoothed_midi_pitch = midi_smooth(frame_midi_pitch[None])[0]
            processed_input['base_pitch'] = smoothed_midi_pitch.cpu().numpy()

        if hparams['predict_pitch'] or self.predict_variances:
            processed_input['pitch'] = pitch.cpu().numpy()
            processed_input['uv'] = uv

        # Below: extract energy
        if hparams['predict_energy']:
            energy = None
            energy_from_wav = False
            if self.prefer_ds:
                energy_seq = self.load_attr_from_ds(ds_id, name, 'energy', idx=ds_seg_idx)
                if energy_seq is not None:
                    energy = resample_align_curve(
                        np.array(energy_seq.split(), np.float32),
                        original_timestep=float(self.load_attr_from_ds(
                            ds_id, name, 'energy_timestep', idx=ds_seg_idx
                        )),
                        target_timestep=self.timestep,
                        align_length=length
                    )
            if energy is None:
                energy = get_energy_librosa(
                    waveform, length,
                    hop_size=hparams['hop_size'], win_size=hparams['win_size']
                ).astype(np.float32)
                energy_from_wav = True

            # Only smooth curves extracted from the waveform; .ds-provided
            # curves are assumed to be already smoothed.
            if energy_from_wav:
                global energy_smooth
                if energy_smooth is None:
                    energy_smooth = SinusoidalSmoothingConv1d(
                        round(hparams['energy_smooth_width'] / self.timestep)
                    ).eval().to(self.device)
                energy = energy_smooth(torch.from_numpy(energy).to(self.device)[None])[0].cpu().numpy()

            processed_input['energy'] = energy

        # create a DecomposedWaveform object for further feature extraction
        dec_waveform = DecomposedWaveform(
            waveform, samplerate=hparams['audio_sample_rate'], f0=f0 * ~uv,
            hop_size=hparams['hop_size'], fft_size=hparams['fft_size'], win_size=hparams['win_size'],
            algorithm=hparams['hnsep']
        ) if waveform is not None else None

        # Below: extract breathiness
        if hparams['predict_breathiness']:
            breathiness = None
            breathiness_from_wav = False
            if self.prefer_ds:
                breathiness_seq = self.load_attr_from_ds(ds_id, name, 'breathiness', idx=ds_seg_idx)
                if breathiness_seq is not None:
                    breathiness = resample_align_curve(
                        np.array(breathiness_seq.split(), np.float32),
                        original_timestep=float(self.load_attr_from_ds(
                            ds_id, name, 'breathiness_timestep', idx=ds_seg_idx
                        )),
                        target_timestep=self.timestep,
                        align_length=length
                    )
            if breathiness is None:
                breathiness = get_breathiness(
                    dec_waveform, None, None, length=length
                )
                breathiness_from_wav = True

            if breathiness_from_wav:
                global breathiness_smooth
                if breathiness_smooth is None:
                    breathiness_smooth = SinusoidalSmoothingConv1d(
                        round(hparams['breathiness_smooth_width'] / self.timestep)
                    ).eval().to(self.device)
                breathiness = breathiness_smooth(torch.from_numpy(breathiness).to(self.device)[None])[0].cpu().numpy()

            processed_input['breathiness'] = breathiness

        # Below: extract voicing
        if hparams['predict_voicing']:
            voicing = None
            voicing_from_wav = False
            if self.prefer_ds:
                voicing_seq = self.load_attr_from_ds(ds_id, name, 'voicing', idx=ds_seg_idx)
                if voicing_seq is not None:
                    voicing = resample_align_curve(
                        np.array(voicing_seq.split(), np.float32),
                        original_timestep=float(self.load_attr_from_ds(
                            ds_id, name, 'voicing_timestep', idx=ds_seg_idx
                        )),
                        target_timestep=self.timestep,
                        align_length=length
                    )
            if voicing is None:
                voicing = get_voicing(
                    dec_waveform, None, None, length=length
                )
                voicing_from_wav = True

            if voicing_from_wav:
                global voicing_smooth
                if voicing_smooth is None:
                    voicing_smooth = SinusoidalSmoothingConv1d(
                        round(hparams['voicing_smooth_width'] / self.timestep)
                    ).eval().to(self.device)
                voicing = voicing_smooth(torch.from_numpy(voicing).to(self.device)[None])[0].cpu().numpy()

            processed_input['voicing'] = voicing

        # Below: extract tension
        if hparams['predict_tension']:
            tension = None
            tension_from_wav = False
            if self.prefer_ds:
                tension_seq = self.load_attr_from_ds(ds_id, name, 'tension', idx=ds_seg_idx)
                if tension_seq is not None:
                    tension = resample_align_curve(
                        np.array(tension_seq.split(), np.float32),
                        original_timestep=float(self.load_attr_from_ds(
                            ds_id, name, 'tension_timestep', idx=ds_seg_idx
                        )),
                        target_timestep=self.timestep,
                        align_length=length
                    )
            if tension is None:
                tension = get_tension_base_harmonic(
                    dec_waveform, None, None, length=length, domain='logit'
                )
                tension_from_wav = True

            if tension_from_wav:
                global tension_smooth
                if tension_smooth is None:
                    tension_smooth = SinusoidalSmoothingConv1d(
                        round(hparams['tension_smooth_width'] / self.timestep)
                    ).eval().to(self.device)
                tension = tension_smooth(torch.from_numpy(tension).to(self.device)[None])[0].cpu().numpy()

            processed_input['tension'] = tension

        return processed_input

    def arrange_data_augmentation(self, data_iterator):
        """No data augmentation is applied for variance binarization."""
        return {}