Ardantia (formerly VsingerXiaoiceSing)
at main 122 lines 4.6 kB view raw
"""Helpers that render NumPy arrays / torch tensors as matplotlib figures.

Every function creates a new figure through ``pyplot`` and returns it without
closing it, so the caller owns the figure and should close it when done
(e.g. ``plt.close(fig)``) to avoid accumulating open figures.
"""
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.ticker import MultipleLocator


def _to_numpy(x):
    """Return ``x`` as a NumPy array, or ``None`` if ``x`` is ``None``.

    Tensors are detached and moved to the CPU first, so tensors that require
    grad can be plotted directly (``Tensor.numpy()`` raises on those).
    Other inputs go through ``np.asanyarray``, which keeps ndarray subclasses
    such as masked arrays intact and also accepts plain sequences.
    """
    if x is None:
        return None
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asanyarray(x)


def spec_to_figure(spec, vmin=None, vmax=None, title=None):
    """Draw a 2-D array as a colour map on a 12x9-inch figure.

    Args:
        spec: 2-D array, tensor or nested sequence. It is transposed before
            drawing, so its first axis runs along x and its second along y.
        vmin: Optional lower bound of the colour scale.
        vmax: Optional upper bound of the colour scale.
        title: Optional figure title.

    Returns:
        The new ``matplotlib.figure.Figure``.
    """
    spec = _to_numpy(spec)
    fig = plt.figure(figsize=(12, 9))
    # pcolormesh draws the same regular grid of cells as pcolor but is much
    # faster for large arrays.
    plt.pcolormesh(spec.T, vmin=vmin, vmax=vmax)
    if title is not None:
        plt.title(title, fontsize=15)
    plt.tight_layout()
    return fig


def dur_to_figure(dur_gt, dur_pred, txt, title=None):
    """Compare ground-truth and predicted per-token durations as boundaries.

    Durations are truncated to integers and accumulated into segment end
    positions. Predicted ends are drawn as red vertical lines in the band
    y=12..22 and ground-truth ends as blue lines in the band y=0..10. Each
    token of ``txt`` is written at the centre of its segment in both bands,
    and a dotted black line joins the matching predicted and ground-truth
    end positions.

    Args:
        dur_gt: 1-D ground-truth durations (array, tensor or sequence).
        dur_pred: 1-D predicted durations (array, tensor or sequence).
        txt: Sequence of token labels, indexed in step with the durations;
            it must not be longer than either duration sequence.
        title: Optional figure title.

    Returns:
        The new ``matplotlib.figure.Figure``.
    """
    gt_end = np.cumsum(_to_numpy(dur_gt).astype(np.int64))
    pred_end = np.cumsum(_to_numpy(dur_pred).astype(np.int64))
    # Each segment starts where the previous one ends; the first starts at 0.
    gt_start = np.concatenate([[0], gt_end[:-1]])
    pred_start = np.concatenate([[0], pred_end[:-1]])

    width = max(12, min(48, len(txt) // 2))
    fig = plt.figure(figsize=(width, 8))
    plt.vlines(pred_end, 12, 22, colors='r', label='pred')
    plt.vlines(gt_end, 0, 10, colors='b', label='gt')
    for i, token in enumerate(txt):
        # Stagger labels over 8 heights (presumably so labels of short,
        # adjacent segments do not collide).
        shift = i % 8 + 1
        plt.text((pred_start[i] + pred_end[i]) / 2, 12 + shift, token,
                 size=16, horizontalalignment='center')
        plt.text((gt_start[i] + gt_end[i]) / 2, shift, token,
                 size=16, horizontalalignment='center')
        plt.plot([pred_end[i], gt_end[i]], [12, 10], color='black', linewidth=2, linestyle=':')
    plt.yticks([])
    # Only non-empty inputs contribute a right edge; empty ones would raise IndexError.
    last_ends = [ends[-1] for ends in (pred_end, gt_end) if len(ends)]
    if last_ends:
        plt.xlim(0, max(last_ends))
    plt.legend()
    if title is not None:
        plt.title(title, fontsize=15)
    plt.tight_layout()
    return fig


def pitch_note_to_figure(pitch_gt, pitch_pred=None, note_midi=None, note_dur=None, note_rest=None, title=None):
    """Plot pitch curves, optionally over outlined note boxes.

    The ground-truth curve is drawn in blue and the optional prediction in
    red, both against their sample index. When both ``note_midi`` and
    ``note_dur`` are given, note ``i`` is drawn as an unfilled grey box that
    starts where the previous note ends (the first at 0), is ``note_dur[i]``
    wide, and spans ``note_midi[i] - 0.5`` to ``note_midi[i] + 0.5``
    vertically. The y-axis gets a major tick and grid line every 1 unit.

    Args:
        pitch_gt: 1-D ground-truth pitch curve.
        pitch_pred: Optional 1-D predicted pitch curve.
        note_midi: Optional per-note pitch values (box centres).
        note_dur: Optional per-note durations; presumably in the same units
            as the pitch sample index (TODO confirm against callers).
        note_rest: Optional per-note flags; truthy entries are drawn with a
            dashed outline. Defaults to all False.
        title: Optional figure title.

    Returns:
        The new ``matplotlib.figure.Figure``.
    """
    pitch_gt = _to_numpy(pitch_gt)
    pitch_pred = _to_numpy(pitch_pred)
    note_midi = _to_numpy(note_midi)
    note_dur = _to_numpy(note_dur)
    note_rest = _to_numpy(note_rest)

    fig = plt.figure()
    if note_midi is not None and note_dur is not None:
        note_start = np.concatenate([[0], np.cumsum(note_dur)[:-1]])
        if note_rest is None:
            note_rest = np.zeros_like(note_midi, dtype=np.bool_)
        ax = plt.gca()
        for i, midi in enumerate(note_midi):
            ax.add_patch(
                plt.Rectangle(
                    xy=(note_start[i], midi - 0.5),
                    width=note_dur[i], height=1,
                    edgecolor='grey', fill=False,
                    linewidth=1.5, linestyle='--' if note_rest[i] else '-',
                )
            )
    plt.plot(pitch_gt, color='b', label='gt')
    if pitch_pred is not None:
        plt.plot(pitch_pred, color='r', label='pred')
    plt.gca().yaxis.set_major_locator(MultipleLocator(1))
    plt.grid(axis='y')
    plt.legend()
    if title is not None:
        plt.title(title, fontsize=15)
    plt.tight_layout()
    return fig


def curve_to_figure(curve_gt, curve_pred=None, curve_base=None, grid=None, title=None):
    """Plot a ground-truth curve with optional predicted and base curves.

    Curves are drawn against their sample index: base in green (drawn
    first), ground truth in blue, prediction in red.

    Args:
        curve_gt: 1-D ground-truth curve.
        curve_pred: Optional 1-D predicted curve.
        curve_base: Optional 1-D base curve.
        grid: Optional spacing of the y-axis major ticks; when ``None``
            matplotlib's default tick locator is kept.
        title: Optional figure title.

    Returns:
        The new ``matplotlib.figure.Figure``.
    """
    curve_gt = _to_numpy(curve_gt)
    curve_pred = _to_numpy(curve_pred)
    curve_base = _to_numpy(curve_base)

    fig = plt.figure()
    if curve_base is not None:
        plt.plot(curve_base, color='g', label='base')
    plt.plot(curve_gt, color='b', label='gt')
    if curve_pred is not None:
        plt.plot(curve_pred, color='r', label='pred')
    if grid is not None:
        plt.gca().yaxis.set_major_locator(MultipleLocator(grid))
    plt.grid(axis='y')
    plt.legend()
    if title is not None:
        plt.title(title, fontsize=15)
    plt.tight_layout()
    return fig


def distribution_to_figure(title, x_label, y_label, items: list, values: list, zoom=0.8, rotate=False):
    """Draw a bar chart of ``values`` over ``items``, annotating each bar.

    Args:
        title: Figure title.
        x_label: X-axis label.
        y_label: Y-axis label.
        items: Bar positions / category labels.
        values: Bar heights; each value is also printed above its bar.
        zoom: Figure width in inches per item (truncated to an int, at least 1).
        rotate: If True, rotate the x tick labels by 45 degrees.

    Returns:
        The new ``matplotlib.figure.Figure``.
    """
    # int() truncation yields 0 for very few items (e.g. one item at the
    # default zoom), which would create a zero-width figure; keep >= 1 inch.
    width = max(1, int(len(items) * zoom))
    fig = plt.figure(figsize=(width, 10))
    plt.bar(x=items, height=values)
    plt.tick_params(labelsize=15)
    plt.xlim(-1, len(items))
    for item, value in zip(items, values):
        plt.text(item, value, value, ha='center', va='bottom', fontsize=15)
    plt.grid()
    plt.title(title, fontsize=30)
    plt.xlabel(x_label, fontsize=20)
    plt.ylabel(y_label, fontsize=20)
    if rotate:
        fig.autofmt_xdate(rotation=45)
    return fig