Ardantia (formerly VsingerXiaoiceSing)
1import matplotlib.pyplot as plt
2import numpy as np
3import torch
4from matplotlib.ticker import MultipleLocator
5
6
def spec_to_figure(spec, vmin=None, vmax=None, title=None):
    """Render a 2-D array (e.g. a spectrogram) as a pseudocolor figure.

    Args:
        spec: 2-D ``numpy.ndarray`` or ``torch.Tensor``. It is transposed before
            plotting, so its first axis runs along x and its second axis along y
            (presumably frames x bins — confirm against callers).
        vmin: Lower bound of the color scale; ``None`` lets matplotlib choose.
        vmax: Upper bound of the color scale; ``None`` lets matplotlib choose.
        title: Optional figure title.

    Returns:
        The newly created ``matplotlib.figure.Figure``. It is not closed here;
        the caller is responsible for it.
    """
    if isinstance(spec, torch.Tensor):
        # detach() first: Tensor.numpy() raises on tensors that require grad.
        spec = spec.detach().cpu().numpy()
    fig = plt.figure(figsize=(12, 9))
    plt.pcolor(spec.T, vmin=vmin, vmax=vmax)
    if title is not None:
        plt.title(title, fontsize=15)
    plt.tight_layout()
    return fig
16
17
def dur_to_figure(dur_gt, dur_pred, txt, title=None):
    """Compare ground-truth and predicted token durations on one timeline.

    Both duration sequences are cast to int64 (truncating any fractional part)
    and accumulated into segment boundaries. Predicted boundaries are drawn as
    red vertical lines in the upper band (y 12-22). Ground-truth boundaries are
    drawn as blue lines in the lower band (y 0-10). A dotted black line joins
    the i-th predicted boundary to the i-th ground-truth boundary. Each token
    label is written at the centre of its segment in both bands, with its
    height cycling over 8 rows.

    Args:
        dur_gt: 1-D ground-truth durations (``numpy.ndarray`` or ``torch.Tensor``).
        dur_pred: 1-D predicted durations (``numpy.ndarray`` or ``torch.Tensor``).
        txt: Sequence of token labels; ``txt[i]`` labels the i-th duration, so it
            must not be longer than either duration sequence (IndexError otherwise).
        title: Optional figure title.

    Returns:
        The newly created ``matplotlib.figure.Figure``.
    """
    if isinstance(dur_gt, torch.Tensor):
        # detach() first: Tensor.numpy() raises on tensors that require grad.
        dur_gt = dur_gt.detach().cpu().numpy()
    if isinstance(dur_pred, torch.Tensor):
        dur_pred = dur_pred.detach().cpu().numpy()
    bounds_gt = np.cumsum(dur_gt.astype(np.int64))
    bounds_pred = np.cumsum(dur_pred.astype(np.int64))
    # Segment i spans from bounds[i-1] (0 for the first segment) to bounds[i];
    # its centre is the mean of those two boundaries.
    centers_gt = (np.concatenate(([0], bounds_gt[:-1])) + bounds_gt) / 2
    centers_pred = (np.concatenate(([0], bounds_pred[:-1])) + bounds_pred) / 2

    # Figure width grows with the number of tokens, clamped to 12-48 inches.
    width = max(12, min(48, len(txt) // 2))
    fig = plt.figure(figsize=(width, 8))
    plt.vlines(bounds_pred, 12, 22, colors='r', label='pred')
    plt.vlines(bounds_gt, 0, 10, colors='b', label='gt')
    for i, token in enumerate(txt):
        # Cycle label heights over 8 rows so neighbouring labels are offset vertically.
        shift = (i % 8) + 1
        plt.text(centers_pred[i], 12 + shift, token, size=16, horizontalalignment='center')
        plt.text(centers_gt[i], shift, token, size=16, horizontalalignment='center')
        plt.plot([bounds_pred[i], bounds_gt[i]], [12, 10], color='black', linewidth=2, linestyle=':')
    plt.yticks([])
    plt.xlim(0, max(bounds_pred[-1], bounds_gt[-1]))
    plt.legend()
    if title is not None:
        plt.title(title, fontsize=15)
    plt.tight_layout()
    return fig
45
46
def pitch_note_to_figure(pitch_gt, pitch_pred=None, note_midi=None, note_dur=None, note_rest=None, title=None):
    """Plot pitch curve(s) over an outline of note rectangles.

    Args:
        pitch_gt: 1-D ground-truth pitch curve, plotted in blue.
        pitch_pred: Optional 1-D predicted pitch curve, plotted in red.
        note_midi: Optional per-note pitch values; each note is drawn as an
            unfilled rectangle of height 1 centred on its value.
        note_dur: Optional per-note durations in the same x units as the curves
            (presumably frames — confirm against callers). Notes are laid end to
            end starting at x = 0.
        note_rest: Optional per-note flags; truthy entries get a dashed outline
            instead of a solid one. Defaults to all ``False``.
        title: Optional figure title.

    Every array argument may be a ``numpy.ndarray`` or a ``torch.Tensor``. Notes
    are drawn only when both ``note_midi`` and ``note_dur`` are given.

    Returns:
        The newly created ``matplotlib.figure.Figure``.
    """
    # detach() first: Tensor.numpy() raises on tensors that require grad.
    if isinstance(pitch_gt, torch.Tensor):
        pitch_gt = pitch_gt.detach().cpu().numpy()
    if isinstance(pitch_pred, torch.Tensor):
        pitch_pred = pitch_pred.detach().cpu().numpy()
    if isinstance(note_midi, torch.Tensor):
        note_midi = note_midi.detach().cpu().numpy()
    if isinstance(note_dur, torch.Tensor):
        note_dur = note_dur.detach().cpu().numpy()
    if isinstance(note_rest, torch.Tensor):
        note_rest = note_rest.detach().cpu().numpy()
    fig = plt.figure()
    if note_midi is not None and note_dur is not None:
        # note_end[i] is where note i ends; note i starts where note i-1 ends.
        note_end = np.cumsum(note_dur)
        if note_rest is None:
            note_rest = np.zeros_like(note_midi, dtype=np.bool_)
        ax = plt.gca()
        for i, midi in enumerate(note_midi):
            # Rest notes are drawn with a dashed outline rather than skipped.
            ax.add_patch(
                plt.Rectangle(
                    xy=(note_end[i - 1] if i > 0 else 0, midi - 0.5),
                    width=note_dur[i], height=1,
                    edgecolor='grey', fill=False,
                    linewidth=1.5, linestyle='--' if note_rest[i] else '-',
                )
            )
    plt.plot(pitch_gt, color='b', label='gt')
    if pitch_pred is not None:
        plt.plot(pitch_pred, color='r', label='pred')
    # One y tick and grid line per unit (per semitone if the values are MIDI numbers).
    plt.gca().yaxis.set_major_locator(MultipleLocator(1))
    plt.grid(axis='y')
    plt.legend()
    if title is not None:
        plt.title(title, fontsize=15)
    plt.tight_layout()
    return fig
84
85
def curve_to_figure(curve_gt, curve_pred=None, curve_base=None, grid=None, title=None):
    """Plot a ground-truth curve with optional predicted and baseline curves.

    Args:
        curve_gt: 1-D ground-truth curve, plotted in blue.
        curve_pred: Optional 1-D predicted curve, plotted in red.
        curve_base: Optional 1-D baseline curve, plotted in green before the
            other curves.
        grid: Optional spacing of the y-axis major ticks; when given,
            horizontal grid lines are drawn at those ticks.
        title: Optional figure title.

    Every curve may be a ``numpy.ndarray`` or a ``torch.Tensor``.

    Returns:
        The newly created ``matplotlib.figure.Figure``.
    """
    # detach() first: Tensor.numpy() raises on tensors that require grad.
    if isinstance(curve_gt, torch.Tensor):
        curve_gt = curve_gt.detach().cpu().numpy()
    if isinstance(curve_pred, torch.Tensor):
        curve_pred = curve_pred.detach().cpu().numpy()
    if isinstance(curve_base, torch.Tensor):
        curve_base = curve_base.detach().cpu().numpy()
    fig = plt.figure()
    if curve_base is not None:
        plt.plot(curve_base, color='g', label='base')
    plt.plot(curve_gt, color='b', label='gt')
    if curve_pred is not None:
        plt.plot(curve_pred, color='r', label='pred')
    if grid is not None:
        plt.gca().yaxis.set_major_locator(MultipleLocator(grid))
    plt.grid(axis='y')
    plt.legend()
    if title is not None:
        plt.title(title, fontsize=15)
    plt.tight_layout()
    return fig
107
108
def distribution_to_figure(title, x_label, y_label, items: list, values: list, zoom=0.8, rotate=False):
    """Draw a bar chart of ``values`` over ``items``, annotating each bar with its value.

    Args:
        title: Figure title.
        x_label: X-axis label.
        y_label: Y-axis label.
        items: Bar categories/positions, one per bar.
        values: Bar heights; expected to match ``items`` in length.
        zoom: Figure width in inches per item.
        rotate: If true, tilt the x tick labels by 45 degrees (via
            ``Figure.autofmt_xdate``).

    Returns:
        The newly created ``matplotlib.figure.Figure``.
    """
    # Width scales with the number of bars. Clamp it to at least 1 inch so a
    # single item, an empty list, or a small/non-positive zoom cannot produce
    # a zero-width (or negative-width) figure.
    width = max(1, int(len(items) * zoom))
    fig = plt.figure(figsize=(width, 10))
    plt.bar(x=items, height=values)
    plt.tick_params(labelsize=15)
    # Leave one empty slot on each side (assumes bars sit at positions 0..n-1).
    plt.xlim(-1, len(items))
    for item, value in zip(items, values):
        # Value label anchored at the top of its bar.
        plt.text(item, value, value, ha='center', va='bottom', fontsize=15)
    plt.grid()
    plt.title(title, fontsize=30)
    plt.xlabel(x_label, fontsize=20)
    plt.ylabel(y_label, fontsize=20)
    if rotate:
        fig.autofmt_xdate(rotation=45)
    return fig