# Ardantia (formerly VsingerXiaoiceSing)
1import csv
2import json
3import os
4import pathlib
5
6import librosa
7import numpy as np
8import torch
9import torch.nn.functional as F
10from scipy import interpolate
11
12from basics.base_binarizer import BaseBinarizer, BinarizationError
13from basics.base_pe import BasePE
14from modules.fastspeech.tts_modules import LengthRegulator
15from modules.pe import initialize_pe
16from utils.binarizer_utils import (
17 SinusoidalSmoothingConv1d,
18 get_mel2ph_torch,
19 get_energy_librosa,
20 get_breathiness,
21 get_voicing,
22 get_tension_base_harmonic,
23)
24from utils.decomposed_waveform import DecomposedWaveform
25from utils.hparams import hparams
26from utils.infer_utils import resample_align_curve
27from utils.pitch_utils import interp_f0
28from utils.plot import distribution_to_figure
29
# Ask OpenMP-backed libraries to use a single thread in this process.
os.environ["OMP_NUM_THREADS"] = "1"
# Names of the per-item attributes of the variance dataset; passed to BaseBinarizer as `data_attrs`.
VARIANCE_ITEM_ATTRIBUTES = [
    'spk_id',  # index number of dataset/speaker, int64
    'languages',  # index numbers of phoneme languages, int64[T_ph,]
    'tokens',  # index numbers of phonemes, int64[T_ph,]
    'ph_dur',  # durations of phonemes, in number of frames, int64[T_ph,]
    'midi',  # phoneme-level mean MIDI pitch, rounded and clamped to [0, 127], int64[T_ph,]
    'ph2word',  # length-regulated expansion of ph_num (phoneme count per word), same format as mel2ph, int64[T_ph,]
    'mel2ph',  # for each frame, the 1-based index of the phoneme it belongs to (0 is the padding slot), int64[T_s,]
    'note_midi',  # note-level MIDI pitch (rest notes filled from the nearest non-rest note), float32[T_n,]
    'note_rest',  # flags for rest notes, bool[T_n,]
    'note_dur',  # durations of notes, in number of frames, int64[T_n,]
    'note_glide',  # glide type indices, 0 = none, k = k-th entry of hparams['glide_types'], int64[T_n,]
    'mel2note',  # for each frame, the 1-based index of the note it belongs to (0 is the padding slot), int64[T_s,]
    'base_pitch',  # interpolated and smoothed frame-level MIDI pitch, float32[T_s,]
    'pitch',  # actual pitch in semitones, float32[T_s,]
    'uv',  # unvoiced masks (only for objective evaluation metrics), bool[T_s,]
    'energy',  # frame-level RMS (dB), float32[T_s,]
    'breathiness',  # frame-level RMS of aperiodic parts (dB), float32[T_s,]
    'voicing',  # frame-level RMS of harmonic parts (dB), float32[T_s,]
    'tension',  # frame-level tension (logit), float32[T_s,]
]
# Audio file extensions tried, in this order, when looking for an item's waveform.
WAV_CANDIDATE_EXTENSIONS = ['.wav', '.flac']
# Separator between an item name and its segment index, e.g. 'name#3'.
DS_INDEX_SEP = '#'

# These operators are used as global variables due to a PyTorch shared memory bug on Windows platforms.
# See https://github.com/pytorch/pytorch/issues/100358
# Each one starts as None and is created on first use inside VarianceBinarizer.process_item.
pitch_extractor: BasePE = None
midi_smooth: SinusoidalSmoothingConv1d = None
energy_smooth: SinusoidalSmoothingConv1d = None
breathiness_smooth: SinusoidalSmoothingConv1d = None
voicing_smooth: SinusoidalSmoothingConv1d = None
tension_smooth: SinusoidalSmoothingConv1d = None
63
64
65class VarianceBinarizer(BaseBinarizer):
66 def __init__(self):
67 super().__init__(data_attrs=VARIANCE_ITEM_ATTRIBUTES)
68
69 self.use_glide_embed = hparams['use_glide_embed']
70 glide_types = hparams['glide_types']
71 assert 'none' not in glide_types, 'Type name \'none\' is reserved and should not appear in glide_types.'
72 self.glide_map = {
73 'none': 0,
74 **{
75 typename: idx + 1
76 for idx, typename in enumerate(glide_types)
77 }
78 }
79
80 predict_energy = hparams['predict_energy']
81 predict_breathiness = hparams['predict_breathiness']
82 predict_voicing = hparams['predict_voicing']
83 predict_tension = hparams['predict_tension']
84 self.predict_variances = predict_energy or predict_breathiness or predict_voicing or predict_tension
85 self.lr = LengthRegulator().to(self.device)
86 self.prefer_ds = self.binarization_args['prefer_ds']
87 self.cached_ds = {}
88
89 def load_attr_from_ds(self, ds_id, name, attr, idx=0):
90 item_name = f'{ds_id}:{name}'
91 item_name_with_idx = f'{item_name}{DS_INDEX_SEP}{idx}'
92 if item_name_with_idx in self.cached_ds:
93 ds = self.cached_ds[item_name_with_idx][0]
94 elif item_name in self.cached_ds:
95 ds = self.cached_ds[item_name][idx]
96 else:
97 ds_path = self.raw_data_dirs[ds_id] / 'ds' / f'{name}{DS_INDEX_SEP}{idx}.ds'
98 if ds_path.exists():
99 cache_key = item_name_with_idx
100 else:
101 ds_path = self.raw_data_dirs[ds_id] / 'ds' / f'{name}.ds'
102 cache_key = item_name
103 if not ds_path.exists():
104 return None
105 with open(ds_path, 'r', encoding='utf8') as f:
106 ds = json.load(f)
107 if not isinstance(ds, list):
108 ds = [ds]
109 self.cached_ds[cache_key] = ds
110 ds = ds[idx]
111 return ds.get(attr)
112
113 def load_meta_data(self, raw_data_dir: pathlib.Path, ds_id, spk, lang):
114 meta_data_dict = {}
115
116 with open(raw_data_dir / 'transcriptions.csv', 'r', encoding='utf8') as f:
117 for utterance_label in csv.DictReader(f):
118 utterance_label: dict
119 item_name = utterance_label['name']
120 item_idx = int(item_name.rsplit(DS_INDEX_SEP, maxsplit=1)[-1]) if DS_INDEX_SEP in item_name else 0
121
122 def require(attr, optional=False):
123 if self.prefer_ds:
124 value = self.load_attr_from_ds(ds_id, item_name, attr, item_idx)
125 else:
126 value = None
127 if value is None:
128 value = utterance_label.get(attr)
129 if value is None and not optional:
130 raise ValueError(f'Missing required attribute {attr} of item \'{item_name}\'.')
131 return value
132
133 wav_fn = None
134 for ext in WAV_CANDIDATE_EXTENSIONS:
135 candidate_fn = raw_data_dir / 'wavs' / f'{item_name}{ext}'
136 if candidate_fn.exists():
137 wav_fn = candidate_fn
138 break
139 if wav_fn is None and not self.prefer_ds:
140 raise FileNotFoundError(
141 f'Waveform file not found for item \'{item_name}\'. '
142 f'Candidate extensions: {WAV_CANDIDATE_EXTENSIONS}\n'
143 f'If you are using DS files instead of waveform files, please set \'prefer_ds\' to true.'
144 )
145
146 temp_dict = {
147 'ds_idx': item_idx,
148 'spk_id': self.spk_map[spk],
149 'spk_name': spk,
150 'language_id': self.lang_map[lang],
151 'language_name': lang,
152 'wav_fn': str(wav_fn) if wav_fn is not None else None,
153 'lang_seq': [
154 (
155 self.lang_map[lang if '/' not in p else p.split('/', maxsplit=1)[0]]
156 if self.phoneme_dictionary.is_cross_lingual(p if '/' in p else f'{lang}/{p}')
157 else 0
158 )
159 for p in utterance_label['ph_seq'].split()
160 ],
161 'ph_seq': self.phoneme_dictionary.encode(require('ph_seq'), lang=lang),
162 'ph_dur': [float(x) for x in require('ph_dur').split()],
163 'ph_text': require('ph_seq'),
164 }
165
166 assert len(temp_dict['ph_seq']) == len(temp_dict['ph_dur']), \
167 f'Lengths of ph_seq and ph_dur mismatch in \'{item_name}\'.'
168 assert all(ph_dur >= 0 for ph_dur in temp_dict['ph_dur']), \
169 f'Negative ph_dur found in \'{item_name}\'.'
170
171 if hparams['predict_dur']:
172 temp_dict['ph_num'] = [int(x) for x in require('ph_num').split()]
173 assert len(temp_dict['ph_seq']) == sum(temp_dict['ph_num']), \
174 f'Sum of ph_num does not equal length of ph_seq in \'{item_name}\'.'
175
176 if hparams['predict_pitch']:
177 temp_dict['note_seq'] = require('note_seq').split()
178 temp_dict['note_dur'] = [float(x) for x in require('note_dur').split()]
179 assert all(note_dur >= 0 for note_dur in temp_dict['note_dur']), \
180 f'Negative note_dur found in \'{item_name}\'.'
181 assert len(temp_dict['note_seq']) == len(temp_dict['note_dur']), \
182 f'Lengths of note_seq and note_dur mismatch in \'{item_name}\'.'
183 assert any([note != 'rest' for note in temp_dict['note_seq']]), \
184 f'All notes are rest in \'{item_name}\'.'
185 if hparams['use_glide_embed']:
186 note_glide = require('note_glide', optional=True)
187 if note_glide is None:
188 note_glide = ['none' for _ in temp_dict['note_seq']]
189 else:
190 note_glide = note_glide.split()
191 assert len(note_glide) == len(temp_dict['note_seq']), \
192 f'Lengths of note_seq and note_glide mismatch in \'{item_name}\'.'
193 assert all(g in self.glide_map for g in note_glide), \
194 f'Invalid glide type found in \'{item_name}\'.'
195 temp_dict['note_glide'] = note_glide
196
197 meta_data_dict[f'{ds_id}:{item_name}'] = temp_dict
198
199 return meta_data_dict
200
201 def check_coverage(self):
202 super().check_coverage()
203 if not hparams['predict_pitch']:
204 return
205
206 # MIDI pitch distribution summary
207 midi_map = {}
208 for item_name in self.items:
209 for midi in self.items[item_name]['note_seq']:
210 if midi == 'rest':
211 continue
212 midi = librosa.note_to_midi(midi, round_midi=True)
213 if midi in midi_map:
214 midi_map[midi] += 1
215 else:
216 midi_map[midi] = 1
217
218 print('===== MIDI Pitch Distribution Summary =====')
219 for i, key in enumerate(sorted(midi_map.keys())):
220 if i == len(midi_map) - 1:
221 end = '\n'
222 elif i % 10 == 9:
223 end = ',\n'
224 else:
225 end = ', '
226 print(f'\'{librosa.midi_to_note(key, unicode=False)}\': {midi_map[key]}', end=end)
227
228 # Draw graph.
229 midis = sorted(midi_map.keys())
230 notes = [librosa.midi_to_note(m, unicode=False) for m in range(midis[0], midis[-1] + 1)]
231 plt = distribution_to_figure(
232 title='MIDI Pitch Distribution Summary',
233 x_label='MIDI Key', y_label='Number of occurrences',
234 items=notes, values=[midi_map.get(m, 0) for m in range(midis[0], midis[-1] + 1)]
235 )
236 filename = self.binary_data_dir / 'midi_distribution.jpg'
237 plt.savefig(fname=filename,
238 bbox_inches='tight',
239 pad_inches=0.25)
240 print(f'| save summary to \'{filename}\'')
241
242 if self.use_glide_embed:
243 # Glide type distribution summary
244 glide_count = {
245 g: 0
246 for g in self.glide_map
247 }
248 for item_name in self.items:
249 for glide in self.items[item_name]['note_glide']:
250 if glide == 'none' or glide not in self.glide_map:
251 glide_count['none'] += 1
252 else:
253 glide_count[glide] += 1
254
255 print('===== Glide Type Distribution Summary =====')
256 for i, key in enumerate(sorted(glide_count.keys(), key=lambda k: self.glide_map[k])):
257 if i == len(glide_count) - 1:
258 end = '\n'
259 elif i % 10 == 9:
260 end = ',\n'
261 else:
262 end = ', '
263 print(f'\'{key}\': {glide_count[key]}', end=end)
264
265 if any(n == 0 for _, n in glide_count.items()):
266 raise BinarizationError(
267 f'Missing glide types in dataset: '
268 f'{sorted([g for g, n in glide_count.items() if n == 0], key=lambda k: self.glide_map[k])}'
269 )
270
271 @torch.no_grad()
272 def process_item(self, item_name, meta_data, binarization_args):
273 ds_id, name = item_name.split(':', maxsplit=1)
274 name = name.rsplit(DS_INDEX_SEP, maxsplit=1)[0]
275 ds_id = int(ds_id)
276 ds_seg_idx = meta_data['ds_idx']
277 seconds = sum(meta_data['ph_dur'])
278 length = round(seconds / self.timestep)
279 T_ph = len(meta_data['ph_seq'])
280 processed_input = {
281 'name': item_name,
282 'wav_fn': meta_data['wav_fn'],
283 'spk_id': meta_data['spk_id'],
284 'spk_name': meta_data['spk_name'],
285 'seconds': seconds,
286 'length': length,
287 'languages': np.array(meta_data['lang_seq'], dtype=np.int64),
288 'tokens': np.array(meta_data['ph_seq'], dtype=np.int64),
289 'ph_text': meta_data['ph_text'],
290 }
291
292 ph_dur_sec = torch.FloatTensor(meta_data['ph_dur']).to(self.device)
293 ph_acc = torch.round(torch.cumsum(ph_dur_sec, dim=0) / self.timestep + 0.5).long()
294 ph_dur = torch.diff(ph_acc, dim=0, prepend=torch.LongTensor([0]).to(self.device))
295 processed_input['ph_dur'] = ph_dur.cpu().numpy()
296
297 mel2ph = get_mel2ph_torch(
298 self.lr, ph_dur_sec, length, self.timestep, device=self.device
299 )
300
301 if hparams['predict_pitch'] or self.predict_variances:
302 processed_input['mel2ph'] = mel2ph.cpu().numpy()
303
304 # Below: extract actual f0, convert to pitch and calculate delta pitch
305 if meta_data['wav_fn'] is not None:
306 waveform, _ = librosa.load(meta_data['wav_fn'], sr=hparams['audio_sample_rate'], mono=True)
307 else:
308 waveform = None
309
310 global pitch_extractor
311 if pitch_extractor is None:
312 pitch_extractor = initialize_pe()
313 f0 = uv = None
314 if self.prefer_ds:
315 f0_seq = self.load_attr_from_ds(ds_id, name, 'f0_seq', idx=ds_seg_idx)
316 if f0_seq is not None:
317 f0 = resample_align_curve(
318 np.array(f0_seq.split(), np.float32),
319 original_timestep=float(self.load_attr_from_ds(ds_id, name, 'f0_timestep', idx=ds_seg_idx)),
320 target_timestep=self.timestep,
321 align_length=length
322 )
323 uv = f0 == 0
324 f0, _ = interp_f0(f0, uv)
325 if f0 is None:
326 f0, uv = pitch_extractor.get_pitch(
327 waveform, samplerate=hparams['audio_sample_rate'], length=length,
328 hop_size=hparams['hop_size'], f0_min=hparams['f0_min'], f0_max=hparams['f0_max'],
329 interp_uv=True
330 )
331 if uv.all(): # All unvoiced
332 print(f'Skipped \'{item_name}\': empty gt f0')
333 return None
334 pitch = torch.from_numpy(librosa.hz_to_midi(f0.astype(np.float32)).astype(np.float32)).to(self.device)
335
336 if hparams['predict_dur']:
337 ph_num = torch.LongTensor(meta_data['ph_num']).to(self.device)
338 ph2word = self.lr(ph_num[None])[0]
339 processed_input['ph2word'] = ph2word.cpu().numpy()
340 mel2dur = torch.gather(F.pad(ph_dur, [1, 0], value=1), 0, mel2ph) # frame-level phone duration
341 ph_midi = pitch.new_zeros(T_ph + 1).scatter_add(
342 0, mel2ph, pitch / mel2dur
343 )[1:]
344 processed_input['midi'] = ph_midi.round().long().clamp(min=0, max=127).cpu().numpy()
345
346 if hparams['predict_pitch']:
347 # Below: get note sequence and interpolate rest notes
348 note_midi = np.array(
349 [(librosa.note_to_midi(n, round_midi=False) if n != 'rest' else -1) for n in meta_data['note_seq']],
350 dtype=np.float32
351 )
352 note_rest = note_midi < 0
353 interp_func = interpolate.interp1d(
354 np.where(~note_rest)[0], note_midi[~note_rest],
355 kind='nearest', fill_value='extrapolate'
356 )
357 note_midi[note_rest] = interp_func(np.where(note_rest)[0])
358 processed_input['note_midi'] = note_midi
359 processed_input['note_rest'] = note_rest
360 note_midi = torch.from_numpy(note_midi).to(self.device)
361
362 note_dur_sec = torch.FloatTensor(meta_data['note_dur']).to(self.device)
363 note_acc = torch.round(torch.cumsum(note_dur_sec, dim=0) / self.timestep + 0.5).long()
364 note_dur = torch.diff(note_acc, dim=0, prepend=torch.LongTensor([0]).to(self.device))
365 processed_input['note_dur'] = note_dur.cpu().numpy()
366
367 mel2note = get_mel2ph_torch(
368 self.lr, note_dur_sec, mel2ph.shape[0], self.timestep, device=self.device
369 )
370 processed_input['mel2note'] = mel2note.cpu().numpy()
371
372 # Below: get ornament attributes
373 if hparams['use_glide_embed']:
374 processed_input['note_glide'] = np.array([
375 self.glide_map.get(x, 0) for x in meta_data['note_glide']
376 ], dtype=np.int64)
377
378 # Below:
379 # 1. Get the frame-level MIDI pitch, which is a step function curve
380 # 2. smoothen the pitch step curve as the base pitch curve
381 frame_midi_pitch = torch.gather(F.pad(note_midi, [1, 0], value=0), 0, mel2note)
382 global midi_smooth
383 if midi_smooth is None:
384 midi_smooth = SinusoidalSmoothingConv1d(
385 round(hparams['midi_smooth_width'] / self.timestep)
386 ).eval().to(self.device)
387 smoothed_midi_pitch = midi_smooth(frame_midi_pitch[None])[0]
388 processed_input['base_pitch'] = smoothed_midi_pitch.cpu().numpy()
389
390 if hparams['predict_pitch'] or self.predict_variances:
391 processed_input['pitch'] = pitch.cpu().numpy()
392 processed_input['uv'] = uv
393
394 # Below: extract energy
395 if hparams['predict_energy']:
396 energy = None
397 energy_from_wav = False
398 if self.prefer_ds:
399 energy_seq = self.load_attr_from_ds(ds_id, name, 'energy', idx=ds_seg_idx)
400 if energy_seq is not None:
401 energy = resample_align_curve(
402 np.array(energy_seq.split(), np.float32),
403 original_timestep=float(self.load_attr_from_ds(
404 ds_id, name, 'energy_timestep', idx=ds_seg_idx
405 )),
406 target_timestep=self.timestep,
407 align_length=length
408 )
409 if energy is None:
410 energy = get_energy_librosa(
411 waveform, length,
412 hop_size=hparams['hop_size'], win_size=hparams['win_size']
413 ).astype(np.float32)
414 energy_from_wav = True
415
416 if energy_from_wav:
417 global energy_smooth
418 if energy_smooth is None:
419 energy_smooth = SinusoidalSmoothingConv1d(
420 round(hparams['energy_smooth_width'] / self.timestep)
421 ).eval().to(self.device)
422 energy = energy_smooth(torch.from_numpy(energy).to(self.device)[None])[0].cpu().numpy()
423
424 processed_input['energy'] = energy
425
426 # create a DecomposedWaveform object for further feature extraction
427 dec_waveform = DecomposedWaveform(
428 waveform, samplerate=hparams['audio_sample_rate'], f0=f0 * ~uv,
429 hop_size=hparams['hop_size'], fft_size=hparams['fft_size'], win_size=hparams['win_size'],
430 algorithm=hparams['hnsep']
431 ) if waveform is not None else None
432
433 # Below: extract breathiness
434 if hparams['predict_breathiness']:
435 breathiness = None
436 breathiness_from_wav = False
437 if self.prefer_ds:
438 breathiness_seq = self.load_attr_from_ds(ds_id, name, 'breathiness', idx=ds_seg_idx)
439 if breathiness_seq is not None:
440 breathiness = resample_align_curve(
441 np.array(breathiness_seq.split(), np.float32),
442 original_timestep=float(self.load_attr_from_ds(
443 ds_id, name, 'breathiness_timestep', idx=ds_seg_idx
444 )),
445 target_timestep=self.timestep,
446 align_length=length
447 )
448 if breathiness is None:
449 breathiness = get_breathiness(
450 dec_waveform, None, None, length=length
451 )
452 breathiness_from_wav = True
453
454 if breathiness_from_wav:
455 global breathiness_smooth
456 if breathiness_smooth is None:
457 breathiness_smooth = SinusoidalSmoothingConv1d(
458 round(hparams['breathiness_smooth_width'] / self.timestep)
459 ).eval().to(self.device)
460 breathiness = breathiness_smooth(torch.from_numpy(breathiness).to(self.device)[None])[0].cpu().numpy()
461
462 processed_input['breathiness'] = breathiness
463
464 # Below: extract voicing
465 if hparams['predict_voicing']:
466 voicing = None
467 voicing_from_wav = False
468 if self.prefer_ds:
469 voicing_seq = self.load_attr_from_ds(ds_id, name, 'voicing', idx=ds_seg_idx)
470 if voicing_seq is not None:
471 voicing = resample_align_curve(
472 np.array(voicing_seq.split(), np.float32),
473 original_timestep=float(self.load_attr_from_ds(
474 ds_id, name, 'voicing_timestep', idx=ds_seg_idx
475 )),
476 target_timestep=self.timestep,
477 align_length=length
478 )
479 if voicing is None:
480 voicing = get_voicing(
481 dec_waveform, None, None, length=length
482 )
483 voicing_from_wav = True
484
485 if voicing_from_wav:
486 global voicing_smooth
487 if voicing_smooth is None:
488 voicing_smooth = SinusoidalSmoothingConv1d(
489 round(hparams['voicing_smooth_width'] / self.timestep)
490 ).eval().to(self.device)
491 voicing = voicing_smooth(torch.from_numpy(voicing).to(self.device)[None])[0].cpu().numpy()
492
493 processed_input['voicing'] = voicing
494
495 # Below: extract tension
496 if hparams['predict_tension']:
497 tension = None
498 tension_from_wav = False
499 if self.prefer_ds:
500 tension_seq = self.load_attr_from_ds(ds_id, name, 'tension', idx=ds_seg_idx)
501 if tension_seq is not None:
502 tension = resample_align_curve(
503 np.array(tension_seq.split(), np.float32),
504 original_timestep=float(self.load_attr_from_ds(
505 ds_id, name, 'tension_timestep', idx=ds_seg_idx
506 )),
507 target_timestep=self.timestep,
508 align_length=length
509 )
510 if tension is None:
511 tension = get_tension_base_harmonic(
512 dec_waveform, None, None, length=length, domain='logit'
513 )
514 tension_from_wav = True
515
516 if tension_from_wav:
517 global tension_smooth
518 if tension_smooth is None:
519 tension_smooth = SinusoidalSmoothingConv1d(
520 round(hparams['tension_smooth_width'] / self.timestep)
521 ).eval().to(self.device)
522 tension = tension_smooth(torch.from_numpy(tension).to(self.device)[None])[0].cpu().numpy()
523
524 processed_input['tension'] = tension
525
526 return processed_input
527
528 def arrange_data_augmentation(self, data_iterator):
529 return {}