NeoPy committed
Commit cbe74d4 · verified · 1 Parent(s): 0ba66e0

Update RVC/modules/rmvpe.py

Files changed (1):
  1. RVC/modules/rmvpe.py +384 -52
RVC/modules/rmvpe.py CHANGED
@@ -10,14 +10,258 @@ from librosa.filters import mel

sys.path.append(os.getcwd())

- from modules import opencl
-
N_MELS, N_CLASS = 128, 360

+ def autopad(k, p=None):
+     if p is None: p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
+     return p
+
+ class Conv(nn.Module):
+     def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
+         super().__init__()
+         self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
+         self.bn = nn.BatchNorm2d(c2)
+         self.act = nn.SiLU() if act else nn.Identity()
+
+     def forward(self, x):
+         return self.act(self.bn(self.conv(x)))
+
+ class DSConv(nn.Module):
+     def __init__(self, c1, c2, k=3, s=1, p=None, act=True):
+         super().__init__()
+         self.dwconv = nn.Conv2d(c1, c1, k, s, autopad(k, p), groups=c1, bias=False)
+         self.pwconv = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
+         self.bn = nn.BatchNorm2d(c2)
+         self.act = nn.SiLU() if act else nn.Identity()
+
+     def forward(self, x):
+         return self.act(self.bn(self.pwconv(self.dwconv(x))))
+
+ class DS_Bottleneck(nn.Module):
+     def __init__(self, c1, c2, k=3, shortcut=True):
+         super().__init__()
+         self.dsconv1 = DSConv(c1, c1, k=3, s=1)
+         self.dsconv2 = DSConv(c1, c2, k=k, s=1)
+         self.shortcut = shortcut and c1 == c2
+
+     def forward(self, x):
+         return x + self.dsconv2(self.dsconv1(x)) if self.shortcut else self.dsconv2(self.dsconv1(x))
+
+ class DS_C3k(nn.Module):
+     def __init__(self, c1, c2, n=1, k=3, e=0.5):
+         super().__init__()
+         self.cv1 = Conv(c1, int(c2 * e), 1, 1)
+         self.cv2 = Conv(c1, int(c2 * e), 1, 1)
+         self.cv3 = Conv(2 * int(c2 * e), c2, 1, 1)
+         self.m = nn.Sequential(*[DS_Bottleneck(int(c2 * e), int(c2 * e), k=k, shortcut=True) for _ in range(n)])
+
+     def forward(self, x):
+         return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
+
+ class DS_C3k2(nn.Module):
+     def __init__(self, c1, c2, n=1, k=3, e=0.5):
+         super().__init__()
+         self.cv1 = Conv(c1, int(c2 * e), 1, 1)
+         self.m = DS_C3k(int(c2 * e), int(c2 * e), n=n, k=k, e=1.0)
+         self.cv2 = Conv(int(c2 * e), c2, 1, 1)
+
+     def forward(self, x):
+         return self.cv2(self.m(self.cv1(x)))
+
+ class AdaptiveHyperedgeGeneration(nn.Module):
+     def __init__(self, in_channels, num_hyperedges, num_heads):
+         super().__init__()
+         self.num_hyperedges = num_hyperedges
+         self.num_heads = num_heads
+         self.head_dim = max(1, in_channels // num_heads)
+         self.global_proto = nn.Parameter(torch.randn(num_hyperedges, in_channels))
+         self.context_mapper = nn.Linear(2 * in_channels, num_hyperedges * in_channels, bias=False)
+         self.query_proj = nn.Linear(in_channels, in_channels, bias=False)
+         self.scale = self.head_dim ** -0.5
+
+     def forward(self, x):
+         B, N, C = x.shape
+         P = self.global_proto.unsqueeze(0) + self.context_mapper(torch.cat((F.adaptive_avg_pool1d(x.permute(0, 2, 1), 1).squeeze(-1), F.adaptive_max_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)), dim=1)).view(B, self.num_hyperedges, C)
+
+         return F.softmax(((self.query_proj(x).view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) @ P.view(B, self.num_hyperedges, self.num_heads, self.head_dim).permute(0, 2, 3, 1)) * self.scale).mean(dim=1).permute(0, 2, 1), dim=-1)
+
+ class HypergraphConvolution(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.W_e = nn.Linear(in_channels, in_channels, bias=False)
+         self.W_v = nn.Linear(in_channels, out_channels, bias=False)
+         self.act = nn.SiLU()
+
+     def forward(self, x, A):
+         return x + self.act(self.W_v(A.transpose(1, 2).bmm(self.act(self.W_e(A.bmm(x))))))
+
+ class AdaptiveHypergraphComputation(nn.Module):
+     def __init__(self, in_channels, out_channels, num_hyperedges, num_heads):
+         super().__init__()
+         self.adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(in_channels, num_hyperedges, num_heads)
+         self.hypergraph_conv = HypergraphConvolution(in_channels, out_channels)
+
+     def forward(self, x):
+         B, _, H, W = x.shape
+         x_flat = x.flatten(2).permute(0, 2, 1)
+         return self.hypergraph_conv(x_flat, self.adaptive_hyperedge_gen(x_flat)).permute(0, 2, 1).view(B, -1, H, W)
+
+ class C3AH(nn.Module):
+     def __init__(self, c1, c2, num_hyperedges, num_heads, e=0.5):
+         super().__init__()
+         self.cv1 = Conv(c1, int(c1 * e), 1, 1)
+         self.cv2 = Conv(c1, int(c1 * e), 1, 1)
+         self.ahc = AdaptiveHypergraphComputation(int(c1 * e), int(c1 * e), num_hyperedges, num_heads)
+         self.cv3 = Conv(2 * int(c1 * e), c2, 1, 1)
+
+     def forward(self, x):
+         return self.cv3(torch.cat((self.ahc(self.cv2(x)), self.cv1(x)), dim=1))
+
+ class HyperACE(nn.Module):
+     def __init__(self, in_channels, out_channels, num_hyperedges=16, num_heads=8, k=2, l=1, c_h=0.5, c_l=0.25):
+         super().__init__()
+         c2, c3, c4, c5 = in_channels
+         c_mid = c4
+         self.fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)
+         self.c_h = int(c_mid * c_h)
+         self.c_l = int(c_mid * c_l)
+         self.c_s = c_mid - self.c_h - self.c_l
+         self.high_order_branch = nn.ModuleList([C3AH(self.c_h, self.c_h, num_hyperedges=num_hyperedges, num_heads=num_heads, e=1.0) for _ in range(k)])
+         self.high_order_fuse = Conv(self.c_h * k, self.c_h, 1, 1)
+         self.low_order_branch = nn.Sequential(*[DS_C3k(self.c_l, self.c_l, n=1, k=3, e=1.0) for _ in range(l)])
+         self.final_fuse = Conv(self.c_h + self.c_l + self.c_s, out_channels, 1, 1)
+
+     def forward(self, x):
+         B2, B3, B4, B5 = x
+         _, _, H4, W4 = B4.shape
+
+         x_h, x_l, x_s = self.fuse_conv(
+             torch.cat(
+                 (
+                     F.interpolate(B2, size=(H4, W4), mode='bilinear', align_corners=False),
+                     F.interpolate(B3, size=(H4, W4), mode='bilinear', align_corners=False),
+                     B4,
+                     F.interpolate(B5, size=(H4, W4), mode='bilinear', align_corners=False)
+                 ),
+                 dim=1
+             )
+         ).split([self.c_h, self.c_l, self.c_s], dim=1)
+
+         return self.final_fuse(torch.cat((self.high_order_fuse(torch.cat([m(x_h) for m in self.high_order_branch], dim=1)), self.low_order_branch(x_l), x_s), dim=1))
+
+ class GatedFusion(nn.Module):
+     def __init__(self, in_channels):
+         super().__init__()
+         self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
+
+     def forward(self, f_in, h):
+         return f_in + self.gamma * h
+
+ class YOLO13Encoder(nn.Module):
+     def __init__(self, in_channels, base_channels=32):
+         super().__init__()
+         self.stem = DSConv(in_channels, base_channels, k=3, s=1)
+
+         self.p2 = nn.Sequential(
+             DSConv(base_channels, base_channels*2, k=3, s=(2, 2)),
+             DS_C3k2(base_channels*2, base_channels*2, n=1)
+         )
+
+         self.p3 = nn.Sequential(
+             DSConv(base_channels*2, base_channels*4, k=3, s=(2, 2)),
+             DS_C3k2(base_channels*4, base_channels*4, n=2)
+         )
+
+         self.p4 = nn.Sequential(
+             DSConv(base_channels*4, base_channels*8, k=3, s=(2, 2)),
+             DS_C3k2(base_channels*8, base_channels*8, n=2)
+         )
+
+         self.p5 = nn.Sequential(
+             DSConv(base_channels*8, base_channels*16, k=3, s=(2, 2)),
+             DS_C3k2(base_channels*16, base_channels*16, n=1)
+         )
+
+         self.out_channels = [base_channels*2, base_channels*4, base_channels*8, base_channels*16]
+
+     def forward(self, x):
+         x = self.stem(x)
+         p2 = self.p2(x)
+         p3 = self.p3(p2)
+         p4 = self.p4(p3)
+         p5 = self.p5(p4)
+         return [p2, p3, p4, p5]
+
+ class YOLO13FullPADDecoder(nn.Module):
+     def __init__(self, encoder_channels, hyperace_out_c, out_channels_final):
+         super().__init__()
+         c_p2, c_p3, c_p4, c_p5 = encoder_channels
+         c_d5, c_d4, c_d3, c_d2 = c_p5, c_p4, c_p3, c_p2
+
+         self.h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
+         self.h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
+         self.h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
+         self.h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)
+
+         self.fusion_d5 = GatedFusion(c_d5)
+         self.fusion_d4 = GatedFusion(c_d4)
+         self.fusion_d3 = GatedFusion(c_d3)
+         self.fusion_d2 = GatedFusion(c_d2)
+
+         self.skip_p5 = Conv(c_p5, c_d5, 1, 1)
+         self.skip_p4 = Conv(c_p4, c_d4, 1, 1)
+         self.skip_p3 = Conv(c_p3, c_d3, 1, 1)
+         self.skip_p2 = Conv(c_p2, c_d2, 1, 1)
+
+         self.up_d5 = DS_C3k2(c_d5, c_d4, n=1)
+         self.up_d4 = DS_C3k2(c_d4, c_d3, n=1)
+         self.up_d3 = DS_C3k2(c_d3, c_d2, n=1)
+
+         self.final_d2 = DS_C3k2(c_d2, c_d2, n=1)
+         self.final_conv = Conv(c_d2, out_channels_final, 1, 1)
+
+     def forward(self, enc_feats, h_ace):
+         p2, p3, p4, p5 = enc_feats
+
+         d5 = self.skip_p5(p5)
+         d4 = self.up_d5(F.interpolate(self.fusion_d5(d5, self.h_to_d5(F.interpolate(h_ace, size=d5.shape[2:], mode='bilinear', align_corners=False))), size=p4.shape[2:], mode='bilinear', align_corners=False)) + self.skip_p4(p4)
+         d3 = self.up_d4(F.interpolate(self.fusion_d4(d4, self.h_to_d4(F.interpolate(h_ace, size=d4.shape[2:], mode='bilinear', align_corners=False))), size=p3.shape[2:], mode='bilinear', align_corners=False)) + self.skip_p3(p3)
+         d2 = self.up_d3(F.interpolate(self.fusion_d3(d3, self.h_to_d3(F.interpolate(h_ace, size=d3.shape[2:], mode='bilinear', align_corners=False))), size=p2.shape[2:], mode='bilinear', align_corners=False)) + self.skip_p2(p2)
+
+         return self.final_conv(self.final_d2(self.fusion_d2(d2, self.h_to_d2(F.interpolate(h_ace, size=d2.shape[2:], mode='bilinear', align_corners=False)))))
+
class ConvBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, momentum=0.01):
        super(ConvBlockRes, self).__init__()
-         self.conv = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
+         self.conv = nn.Sequential(
+             nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=(1, 1),
+                 bias=False
+             ),
+             nn.BatchNorm2d(
+                 out_channels,
+                 momentum=momentum
+             ),
+             nn.ReLU(),
+             nn.Conv2d(
+                 in_channels=out_channels,
+                 out_channels=out_channels,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=(1, 1),
+                 bias=False
+             ),
+             nn.BatchNorm2d(
+                 out_channels,
+                 momentum=momentum
+             ),
+             nn.ReLU()
+         )
+
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
            self.is_shortcut = True
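Note: the participation matrix returned by AdaptiveHyperedgeGeneration has shape (B, num_hyperedges, N), softmax-normalized over the node axis, which is what lets HypergraphConvolution pool nodes into hyperedges via A.bmm(x) and scatter them back via A.transpose(1, 2).bmm(...). A minimal shape check, assuming the updated file is importable as RVC.modules.rmvpe and that in_channels divides evenly by num_heads:

import torch
from RVC.modules.rmvpe import AdaptiveHyperedgeGeneration

x = torch.randn(2, 64, 32)  # (B, N, C): 64 flattened spatial "nodes", 32 channels
gen = AdaptiveHyperedgeGeneration(in_channels=32, num_hyperedges=16, num_heads=4)

A = gen(x)  # one softmax-normalized row of node weights per hyperedge
print(A.shape)                                       # torch.Size([2, 16, 64])
print(torch.allclose(A.sum(-1), torch.ones(2, 16)))  # True: each row sums to 1 over nodes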
@@ -91,7 +335,23 @@ class ResDecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
        super(ResDecoderBlock, self).__init__()
        out_padding = (0, 1) if stride == (1, 2) else (1, 1)
-         self.conv1 = nn.Sequential(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=stride, padding=(1, 1), output_padding=out_padding, bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
+         self.conv1 = nn.Sequential(
+             nn.ConvTranspose2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=(3, 3),
+                 stride=stride,
+                 padding=(1, 1),
+                 output_padding=out_padding,
+                 bias=False
+             ),
+             nn.BatchNorm2d(
+                 out_channels,
+                 momentum=momentum
+             ),
+             nn.ReLU()
+         )
+
        self.conv2 = nn.ModuleList()
        self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
 
@@ -131,19 +391,89 @@ class DeepUnet(nn.Module):
    def forward(self, x):
        x, concat_tensors = self.encoder(x)
        return self.decoder(self.intermediate(x), concat_tensors)
+
+ class HPADeepUnet(nn.Module):
+     def __init__(self, in_channels=1, en_out_channels=16, base_channels=64, hyperace_k=2, hyperace_l=1, num_hyperedges=16, num_heads=8):
+         super().__init__()
+         self.encoder = YOLO13Encoder(in_channels, base_channels)
+         enc_ch = self.encoder.out_channels
+
+         self.hyperace = HyperACE(
+             in_channels=enc_ch,
+             out_channels=enc_ch[-1],
+             num_hyperedges=num_hyperedges,
+             num_heads=num_heads,
+             k=hyperace_k,
+             l=hyperace_l
+         )
+
+         self.decoder = YOLO13FullPADDecoder(
+             encoder_channels=enc_ch,
+             hyperace_out_c=enc_ch[-1],
+             out_channels_final=en_out_channels
+         )

+     def forward(self, x):
+         features = self.encoder(x)
+         return nn.functional.interpolate(self.decoder(features, self.hyperace(features)), size=x.shape[2:], mode='bilinear', align_corners=False)
+
+ class BiGRU(nn.Module):
+     def __init__(self, input_features, hidden_features, num_layers):
+         super(BiGRU, self).__init__()
+         self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
+
+     def forward(self, x):
+         try:
+             return self.gru(x)[0]
+         except:
+             torch.backends.cudnn.enabled = False
+             return self.gru(x)[0]
+
class E2E(nn.Module):
-     def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
+     def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16, hpa=False):
        super(E2E, self).__init__()
-         self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
+         self.unet = (
+             HPADeepUnet(
+                 in_channels=in_channels,
+                 en_out_channels=en_out_channels,
+                 base_channels=64,
+                 hyperace_k=2,
+                 hyperace_l=1,
+                 num_hyperedges=16,
+                 num_heads=4
+             )
+         ) if hpa else (
+             DeepUnet(
+                 kernel_size,
+                 n_blocks,
+                 en_de_layers,
+                 inter_layers,
+                 in_channels,
+                 en_out_channels
+             )
+         )
+
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
-         self.fc = nn.Sequential(BiGRU(3 * 128, 256, n_gru), nn.Linear(512, N_CLASS), nn.Dropout(0.25), nn.Sigmoid()) if n_gru else nn.Sequential(nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid())
+         self.fc = (
+             nn.Sequential(
+                 BiGRU(3 * 128, 256, n_gru),
+                 nn.Linear(512, N_CLASS),
+                 nn.Dropout(0.25),
+                 nn.Sigmoid()
+             )
+         ) if n_gru else (
+             nn.Sequential(
+                 nn.Linear(3 * N_MELS, N_CLASS),
+                 nn.Dropout(0.25),
+                 nn.Sigmoid()
+             )
+         )

    def forward(self, mel):
        return self.fc(self.cnn(self.unet(mel.transpose(-1, -2).unsqueeze(1))).transpose(1, 2).flatten(-2))

- class MelSpectrogram(torch.nn.Module):
-     def __init__(self, is_half, n_mel_channels, sample_rate, win_length, hop_length, n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
+ class MelSpectrogram(nn.Module):
+     def __init__(self, n_mel_channels, sample_rate, win_length, hop_length, n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
        super().__init__()
        n_fft = win_length if n_fft is None else n_fft
        self.hann_window = {}
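Note: the unet/cnn/fc pipeline preserves the contract that mel2hidden below depends on: a (B, n_mels, T) mel with T a multiple of 32 maps to (B, T, N_CLASS) salience. A sketch of that contract, assuming the unchanged DeepUnet in this file matches the upstream RMVPE definition (it is not part of this diff):

import torch
from RVC.modules.rmvpe import E2E, N_MELS

model = E2E(4, 1, (2, 2)).eval()  # same arguments RMVPE.__init__ passes below
mel = torch.randn(1, N_MELS, 64)  # 64 frames, already 32-aligned

with torch.no_grad():
    hidden = model(mel)

print(hidden.shape)  # expected: torch.Size([1, 64, 360]), i.e. (B, T, N_CLASS)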
@@ -156,7 +486,6 @@ class MelSpectrogram(torch.nn.Module):
        self.sample_rate = sample_rate
        self.n_mel_channels = n_mel_channels
        self.clamp = clamp
-         self.is_half = is_half

    def forward(self, audio, keyshift=0, speed=1, center=True):
        factor = 2 ** (keyshift / 12)
@@ -167,12 +496,8 @@ class MelSpectrogram(torch.nn.Module):
        n_fft = int(np.round(self.n_fft * factor))
        hop_length = int(np.round(self.hop_length * speed))

-         if str(audio.device).startswith("ocl"):
-             stft = opencl.STFT(filter_length=n_fft, hop_length=hop_length, win_length=win_length_new).to(audio.device)
-             magnitude = stft.transform(audio, 1e-9)
-         else:
-             fft = torch.stft(audio, n_fft=n_fft, hop_length=hop_length, win_length=win_length_new, window=self.hann_window[keyshift_key], center=center, return_complex=True)
-             magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
+         fft = torch.stft(audio, n_fft=n_fft, hop_length=hop_length, win_length=win_length_new, window=self.hann_window[keyshift_key], center=center, return_complex=True)
+         magnitude = (fft.real.pow(2) + fft.imag.pow(2)).sqrt()

        if keyshift != 0:
            size = self.n_fft // 2 + 1
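Note: the rewritten magnitude line is a pure refactor; for a complex STFT result it matches both the removed expression and torch.abs. A quick equivalence check:

import torch

z = torch.randn(4, 8, dtype=torch.complex64)             # stand-in for torch.stft output
a = torch.sqrt(z.real.pow(2) + z.imag.pow(2))            # removed form
b = (z.real.pow(2) + z.imag.pow(2)).sqrt()               # new form
print(torch.allclose(a, b), torch.allclose(b, z.abs()))  # True True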
@@ -180,34 +505,55 @@ class MelSpectrogram(torch.nn.Module):
            if resize < size: magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new

-         mel_output = torch.matmul(self.mel_basis, magnitude)
-         if self.is_half: mel_output = mel_output.half()
-
-         return torch.log(torch.clamp(mel_output, min=self.clamp))
+         mel_output = self.mel_basis @ magnitude
+         return mel_output.clamp(min=self.clamp).log()

class RMVPE:
-     def __init__(self, model_path, is_half, device=None):
-         self.resample_kernel = {}
-         self.resample_kernel = {}
-         model = E2E(4, 1, (2, 2))
-         ckpt = torch.load(model_path, map_location="cpu")
-         model.load_state_dict(ckpt)
-         model.eval()
-         if is_half: model = model.half()
-         self.model = model.to(device)
-         self.is_half = is_half
+     def __init__(self, model_path, is_half, device=None, providers=None, onnx=False, hpa=False):
+         self.onnx = onnx
+
+         if self.onnx:
+             import onnxruntime as ort
+
+             sess_options = ort.SessionOptions()
+             sess_options.log_severity_level = 3
+             self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
+         else:
+             model = E2E(4, 1, (2, 2), 5, 4, 1, 16, hpa=hpa)
+
+             model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
+             model.eval()
+             if is_half: model = model.half()
+             self.model = model.to(device)
+
        self.device = device
-         self.mel_extractor = MelSpectrogram(is_half, N_MELS, 16000, 1024, 160, None, 30, 8000).to(device)
+         self.is_half = is_half
+         self.mel_extractor = MelSpectrogram(N_MELS, 16000, 1024, 160, None, 30, 8000).to(device)
        cents_mapping = 20 * np.arange(N_CLASS) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))

-     def mel2hidden(self, mel):
+     def mel2hidden(self, mel, chunk_size=32000):
        with torch.no_grad():
            n_frames = mel.shape[-1]
-             n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
-             if n_pad > 0: mel = F.pad(mel, (0, n_pad), mode="constant")
+             mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect")

-             hidden = self.model(mel.half() if self.is_half else mel.float())
+             output_chunks = []
+             pad_frames = mel.shape[-1]
+
+             for start in range(0, pad_frames, chunk_size):
+                 mel_chunk = mel[..., start:min(start + chunk_size, pad_frames)]
+                 assert mel_chunk.shape[-1] % 32 == 0
+
+                 if self.onnx:
+                     mel_chunk = mel_chunk.cpu().numpy().astype(np.float32)
+                     out_chunk = torch.as_tensor(self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: mel_chunk})[0], device=self.device)
+                 else:
+                     if self.is_half: mel_chunk = mel_chunk.half()
+                     out_chunk = self.model(mel_chunk)
+
+                 output_chunks.append(out_chunk)
+
+             hidden = torch.cat(output_chunks, dim=1)
            return hidden[:, :n_frames]

    def decode(self, hidden, thred=0.03):
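Note: the F.pad expression rounds the frame axis up to the next multiple of 32, and every chunk stays 32-aligned because the default chunk_size (32000) is itself a multiple of 32; that is exactly what the in-loop assert enforces. One caveat: mode="reflect" requires the pad to be shorter than the input, so inputs of 16 frames or fewer would raise. The rounding can be checked standalone:

# Same rounding expression as the F.pad call in mel2hidden.
for n_frames in (1, 31, 32, 33, 1000):
    padded = 32 * ((n_frames - 1) // 32 + 1)
    assert padded % 32 == 0 and 0 <= padded - n_frames < 32
    print(n_frames, "->", padded)  # 1 -> 32, 31 -> 32, 32 -> 32, 33 -> 64, 1000 -> 1024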
@@ -219,12 +565,10 @@ class RMVPE:
    def infer_from_audio(self, audio, thred=0.03):
        hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))

-         return self.decode((hidden.squeeze(0).cpu().numpy().astype(np.float32) if self.is_half else hidden.squeeze(0).cpu().numpy()), thred=thred)
+         return self.decode(hidden.squeeze(0).cpu().numpy().astype(np.float32), thred=thred)

    def infer_from_audio_with_pitch(self, audio, thred=0.03, f0_min=50, f0_max=1100):
-         hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))
-
-         f0 = self.decode((hidden.squeeze(0).cpu().numpy().astype(np.float32) if self.is_half else hidden.squeeze(0).cpu().numpy()), thred=thred)
+         f0 = self.infer_from_audio(audio, thred)
        f0[(f0 < f0_min) | (f0 > f0_max)] = 0

        return f0
@@ -245,16 +589,4 @@ class RMVPE:
        devided = np.sum(todo_salience * np.array(todo_cents_mapping), 1) / np.sum(todo_salience, 1)
        devided[np.max(salience, axis=1) <= thred] = 0

-         return devided
-
- class BiGRU(nn.Module):
-     def __init__(self, input_features, hidden_features, num_layers):
-         super(BiGRU, self).__init__()
-         self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
-
-     def forward(self, x):
-         try:
-             return self.gru(x)[0]
-         except:
-             torch.backends.cudnn.enabled = False
-             return self.gru(x)[0]
+         return devided
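Usage sketch for the updated constructor (the checkpoint paths below are placeholders, not files this commit adds; audio must be 16 kHz mono float32 to match the MelSpectrogram(N_MELS, 16000, 1024, 160, ...) extractor above):

import numpy as np
from RVC.modules.rmvpe import RMVPE

audio = np.random.randn(16000 * 3).astype(np.float32)  # stand-in for 3 s of 16 kHz audio

# PyTorch checkpoint path; hpa=True selects the new HPADeepUnet backbone,
# so the checkpoint must have been trained with it.
rmvpe = RMVPE("assets/rmvpe/rmvpe.pt", is_half=False, device="cpu", hpa=False)
f0 = rmvpe.infer_from_audio_with_pitch(audio, thred=0.03, f0_min=50, f0_max=1100)
print(f0.shape)  # one F0 value per 10 ms frame (hop 160 @ 16 kHz)

# ONNX variant (requires onnxruntime; the .onnx filename is a placeholder):
# rmvpe = RMVPE("assets/rmvpe/rmvpe.onnx", is_half=False, device="cpu",
#               providers=["CPUExecutionProvider"], onnx=True)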