Foreword
Generative modeling is one of the hottest topics in AI. It is now possible to teach a machine to excel at human endeavors such as painting, writing, and composing music, and that is exactly what this project does: it generates one-second clips of Gnaoua music (a Moroccan kind of traditional music).
What is generative modeling?
“Generative modeling is an unsupervised learning task in machine learning that involves automatically discovering and learning the regularities or patterns in input data in such a way that the model can be used to generate or output new examples that plausibly could have been drawn from the original dataset.” A Gentle Introduction to Generative Adversarial Networks (GANs).
Data preparation and preprocessing
Data collection
We used 10 videos to get samples to train our WaveGAN model; here is one of them:
Data preprocessing
In order to train our GAN model, we have to turn our mp4 videos into small 2-second wav chunks and feed them to the model in batches; check the audio below:
PyTorch offers a solution for parallelizing the data loading process with automatic batching through its DataLoader class.
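For illustration only, here is a minimal sketch of that built-in route, assuming a simple map-style Dataset over already-extracted wav chunks (this is not the streaming setup we actually use below):

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class WavChunkDataset(Dataset):
    """Hypothetical map-style dataset over fixed-length audio chunks."""
    def __init__(self, chunks):
        self.chunks = chunks  # list of float32 numpy arrays of equal length

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        # add the channel dimension expected by Conv1d: (1, window_length)
        return torch.from_numpy(self.chunks[idx]).unsqueeze(0)

# dummy data just to make the sketch runnable
chunks = [np.random.randn(16384).astype("float32") for _ in range(32)]
# DataLoader handles shuffling, automatic batching and parallel workers
loader = DataLoader(WavChunkDataset(chunks), batch_size=8, shuffle=True, num_workers=0)
batch = next(iter(loader))  # shape: (8, 1, 16384)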
To deal with our data, however, we developed our own iterable, the WaveLoader helper class, to load, manipulate and stream the audio. This class calls the get_recursive_files function which, as its name indicates, collects all training and validation data paths, and audio_stream_sampling, which creates a stream reader over our data and streams it as batches.
# assumed module-level imports for the data-loading snippets; WINDOW_LENGTH,
# BATCH_SIZE and the other upper-case constants come from the project's config
import logging
import os

import librosa
import numpy as np
import pescador
import torch

LOGGER = logging.getLogger(__name__)  # used by load_wav below


class WaveLoader:
    """
    The wave files loader and manipulator to train the network.
    """

    def __init__(self, folder_path, audio_extension="wav"):
        """
        The constructor of the class
        :param folder_path: string
            the training folder path.
        :param audio_extension: string
            audio extension
        """
        self.audio_paths = get_recursive_files(folder_path, audio_extension)
        self.data_iter = None
        self.initialize_iterator()

    def initialize_iterator(self):
        """
        Create an iterator over the buffered audio stream.
        :return:
        """
        data_iter = audio_stream_sampling(self.audio_paths)
        self.data_iter = iter(data_iter)

    def __len__(self):
        return len(self.audio_paths)

    def __iter__(self):
        return self

    def __next__(self):
        x = next(self.data_iter)
        return self.numpy_to_tensor(x["single"])

    @staticmethod
    def numpy_to_tensor(numpy_array):
        """
        Convert a numpy array to tensor.
        :param numpy_array: ndarray array
        :return: tensor
        """
        # add the channel dimension: (batch, 1, window_length)
        numpy_array = numpy_array[:, np.newaxis, :]
        return torch.Tensor(numpy_array).to(torch.device("cpu"))
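A minimal usage sketch of the loader (the folder path here is an assumption; it just needs to contain the preprocessed wav chunks):

# stream one batch of training audio as a tensor of shape (BATCH_SIZE, 1, WINDOW_LENGTH)
train_loader = WaveLoader("data/train")   # assumed location of the 2-second wav chunks
batch = next(iter(train_loader))
print(batch.shape, batch.dtype)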
def get_recursive_files(folder_path, ext):
    results = os.listdir(folder_path)
    out_files = []
    for file in results:
        if os.path.isdir(os.path.join(folder_path, file)):
            out_files += get_recursive_files(os.path.join(folder_path, file), ext)
        elif file.endswith(ext):
            out_files.append(os.path.join(folder_path, file))
    return out_files
The get_recursive_files function searches recursively for files with the ext extension inside the given folder_path and returns the paths of all matching files found in that folder.
The audio_stream_sampling function
def audio_stream_sampling(file_path_list):
    """
    Create stream reader of audios
    :param file_path_list: list of audio file paths.
    :return: a buffered batch generator.
    """
    data_streams = []
    # loop through training data.
    for audio_path in file_path_list:
        # create stream for each sample
        stream = pescador.Streamer(sample_generator, audio_path)
        data_streams.append(stream)
    mux = pescador.ShuffledMux(data_streams)
    batch_gen = pescador.buffer_stream(mux, BATCH_SIZE)
    return batch_gen
The audio_stream_sampling function creates a stream for each sample (2-second audio clips) backed by an internal memory buffer, since our training data is treated as too large to hold at once (we built it this way even though our data, roughly 400 MB, could obviously fit into memory). It uses pescador to stream samples from the sample_generator generator, then takes the N streamers and samples from them equally, guaranteeing that all N streamers stay “active”, and finally buffers the stream into data objects of batch size BATCH_SIZE.
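To make the buffering concrete, here is a small sketch of how the stream could be consumed directly (the folder path is an assumption):

# each item yielded by the buffered stream is a dict of BATCH_SIZE stacked windows
paths = get_recursive_files("data/train", "wav")   # assumed folder of wav chunks
batch_gen = audio_stream_sampling(paths)
batch = next(iter(batch_gen))
print(batch["single"].shape)                       # (BATCH_SIZE, WINDOW_LENGTH)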
def sample_generator(file_path):
    """
    Load an audio file and yield sampled windows from it indefinitely.
    :param file_path: string
        path of the audio.
    :return: sampled audio
    """
    audio_data = load_wav(file_path)
    while True:
        sample = sample_audio(audio_data)
        yield {"single": sample}
The sample_generator function loads the audio (load_wav), samples a window of WINDOW_LENGTH samples (WINDOW_LENGTH = 16384) from it (sample_audio) and yields it as a numpy vector of type float32.
def load_wav(wav_file_path, normalize_audio=True):
    try:
        audio_data, _ = librosa.load(wav_file_path, sr=16000)
        if normalize_audio:
            # clip magnitude
            max_mag = np.max(np.abs(audio_data))
            if max_mag > 1:
                audio_data /= max_mag
    except Exception as e:
        LOGGER.error("Could not load {}: {}".format(wav_file_path, str(e)))
        raise e
    audio_len = len(audio_data)
    if audio_len < WINDOW_LENGTH:
        pad_length = WINDOW_LENGTH - audio_len
        left_pad = pad_length // 2
        right_pad = pad_length - left_pad
        audio_data = np.pad(audio_data, (left_pad, right_pad), mode="constant")
    return audio_data.astype("float32")
The load_wav function tries to load the audio at the given path wav_file_path using the librosa package. To turn a signal into a finite set of numbers, we need to choose a sampling rate (sr = 16000 Hz), which means that we record the value of the amplitude at intervals of 1/16000 seconds, so an audio clip of 2 seconds results in a single vector of 32000 numbers (if the audio is stereo we get 2 vectors of the same length, one for each audio channel).
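As a quick sanity check of that arithmetic (the values are the ones stated above):

sr = 16000              # sampling rate in Hz
print(sr * 2)           # 32000 amplitude values for a 2-second mono clip
print(16384 / sr)       # a WINDOW_LENGTH of 16384 samples covers roughly 1.02 seconds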
def sample_audio(audio_data, start_index=None, end_index=None):
    """
    Sample a window from an audio.
    :param audio_data: an audio as array
    :param start_index: start index if provided
    :param end_index: end index if provided.
    :return: the sample extracted
    """
    audio_len = len(audio_data)
    # the window and audio length are the same
    if audio_len == WINDOW_LENGTH:
        # yield the entire audio.
        sample = audio_data
    else:
        # select a random window from the audio
        if start_index is None or end_index is None:
            start_index = np.random.randint(0, (audio_len - WINDOW_LENGTH) // 2)
            end_index = start_index + WINDOW_LENGTH
        sample = audio_data[start_index:end_index]
    sample = sample.astype("float32")
    return sample
The sample_audio function samples a window from the given array representation of the audio. Windowing is used in spectral analysis to view a short time segment of a longer signal and analyze its frequency content. Windows are also used to create short sound segments of a few milliseconds’ duration called “grains”, which can be combined into granular sound clouds for unique sorts of synthesis. In general, one can think of any finite sound that has a starting point and a stopping point as being a windowed segment in time.
GANs preliminaries
The image above shows the overall process of training a generative adversarial network (GAN). To train these models (the discriminator and the generator) we begin by training the discriminator, which is a binary classifier that learns to recognize real data (training data) from fake data (produced by the generator).
To train the discriminator we use the training data with the true label and the samples generated by the generator from random noise with the false label; the model updates its weights and learns how to differentiate between real and fake data.
Now it is the generator's turn to be trained. To train the generator we must freeze the discriminator so that it stops updating its previously learned weights; otherwise, the discriminator would learn to be more convinced by the generator every time the generator is trained, undoing anything useful it could have learned! We then take the discriminator's predictions on the generated samples and use them as the objective for training the generator model.
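Here is a minimal sketch of that freezing step in PyTorch, using toy stand-in networks rather than the WaveGAN models (the training class later in this post does the same thing through a gradients_status helper):

import torch
import torch.nn as nn

# toy stand-ins just to illustrate the mechanics, not the real WaveGAN models
generator = nn.Linear(100, 16384)
discriminator = nn.Linear(16384, 1)

# freeze the discriminator while the generator is being updated
for p in discriminator.parameters():
    p.requires_grad = False

noise = torch.randn(8, 100)
g_loss = -discriminator(generator(noise)).mean()  # WGAN-style generator objective
g_loss.backward()                                  # gradients reach only the generator's weights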
Model architecture
We used the WaveGAN model architecture proposed in Adversarial Audio Synthesis, published as a conference paper at ICLR 2019.
WaveGAN is a Generative Adversarial Network architecture capable of synthesizing audio. The network structure is extremely similar to DCGAN (Deep Convolutional Generative Adversarial Networks), also proposed as a conference paper at ICLR 2016, which uses convolutional layers in both the generator and the discriminator.
“GANs consists of two networks, a Generator $\mathcal{G(X)}$, and a Discriminator $\mathcal{D(X)}$. They both play an adversarial game where the generator tries to fool the discriminator by generating data similar to those in the training set. The Discriminator tries not to be fooled by identifying fake data from real data.“ GANs — A Brief Introduction to Generative Adversarial Networks .
Both networks work simultaneously to learn and train on complex data (e.g. audio in our case).
GANs learn mappings from low-dimensional latent vectors $z\in \mathcal{Z}$, i.i.d. samples from known prior $P_{Z}$, to points in the space of natural data $\mathcal{X}$ . In their original formulation (Goodfellow et al., 2014), a generator $G: \mathcal{Z} \mapsto \mathcal{X}$ is pitted against a discriminator $D: \mathcal{X} \mapsto[0,1]$ in a two-player minimax game. $\mathcal{G}$ is trained to minimize the following value function, while $\mathcal{D}$ is trained to maximize it.
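For reference, this is the standard minimax value function from Goodfellow et al. (2014), which the text below refers to as Equation 1:

$$\min_{G} \max_{D} V(D, G)=\mathbb{E}_{x \sim P_{X}}[\log D(x)]+\mathbb{E}_{z \sim P_{Z}}[\log (1-D(G(z)))] \hspace{1cm} (1)$$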
In other words, $\mathcal{D}$ is trained to determine if an example is real or fake, and $\mathcal{G}$ is trained to fool the discriminator into thinking its output is real. Goodfellow et al. (2014) demonstrate that their proposed training algorithm for ${Equation \hspace{1mm} 1}$ equates to minimizing the Jensen-Shannon divergence between $P_{X}$, the data distribution, and $P_{G}$, the implicit distribution of the generator when $z \sim\ P_{Z}$.
In this original formulation, GANs are notoriously difficult to train, and prone to catastrophic failure cases. Instead of Jensen-Shannon divergence, Arjovsky et al. (2017) suggest minimizing the smoother Wasserstein-1 distance between generated and data distributions.
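The Wasserstein-1 distance, written in the Kantorovich-Rubinstein dual form used by Arjovsky et al. (2017), is:

$$W\left(P_{X}, P_{G}\right)=\sup _{\|f\|_{L} \leq 1} \mathbb{E}_{x \sim P_{X}}[f(x)]-\mathbb{E}_{x \sim P_{G}}[f(x)]$$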
Where $||f||_{L} \leq 1: \mathcal{X} \mapsto \mathbb{R}$ is the family of functions that are $1-Lipschitz$. To minimize $Wasserstein \hspace{1mm} distance$, they suggest a GAN training algorithm (WGAN), similar to that of Goodfellow et al. (2014), for the following value function:
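The corresponding WGAN value function is:

$$V_{\mathrm{WGAN}}\left(D_{w}, G\right)=\mathbb{E}_{x \sim P_{X}}\left[D_{w}(x)\right]-\mathbb{E}_{z \sim P_{Z}}\left[D_{w}(G(z))\right]$$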
With this formulation, $D_{w}: \mathcal{X} \mapsto \mathbb{R}$ is not trained to identify examples as real or fake, but instead is trained as a function that assists in computing the Wasserstein distance. Arjovsky et al. (2017) suggest weight clipping as a means of enforcing that $D_{w}$ is $1-Lipschitz$. As an alternative strategy, Gulrajani et al. (2017) replace weight clipping with a gradient penalty (WGAN-GP) that also enforces the constraint. They demonstrate that their WGAN-GP strategy can successfully train a variety of model configurations where other GAN losses fail.
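For completeness, the WGAN-GP penalty of Gulrajani et al. (2017) adds the following term to the critic loss, evaluated at points interpolated between real and generated samples (the calculate_discriminator_loss method later in this post computes this with $\lambda$ = P_COEFF):

$$L_{GP}=\lambda \, \mathbb{E}_{\hat{x}}\left[\left(\left\|\nabla_{\hat{x}} D_{w}(\hat{x})\right\|_{2}-1\right)^{2}\right], \quad \hat{x}=\epsilon x+(1-\epsilon) G(z), \quad \epsilon \sim U[0,1]$$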
The representation above is from the paper discussed earlier, and it depicts the transposed convolution operation for the first layer of the WaveGAN generator, with one-dimensional filters and an upsampling factor (in signal processing, upsampling is the process of resampling that produces an approximation of the sequence that would have been obtained by sampling the signal at a higher rate).
When using generative models for images (e.g. DCGAN), we often see a strange checkerboard pattern of artifacts in the generated images. Why? When generating images with neural networks, we usually build them up from a low-resolution description (the latent space) to a high-resolution one, and we do that with deconvolution (transposed convolution) layers. These layers can easily produce “uneven overlap” (which happens when the kernel size is not divisible by the stride), and even though the network can update its weights to prevent this from happening, it often struggles to avoid it completely (for more details, check the awesome article here).
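A quick way to check the uneven-overlap condition for the filter settings used later in this post (length-25 kernels with stride 4):

kernel_size, stride = 25, 4
# a non-zero remainder means the kernel size is not divisible by the stride,
# which is exactly the uneven-overlap situation described above
print(kernel_size % stride)   # 1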
That said, for audio the equivalent is pitched noise artifacts; moreover, these artifacts occur at a particular phase, allowing the discriminator to reject generated samples as fake (the discriminator's job is easy in this case). But we want the discriminator to have to work hard to tell fake data from real data (this is how the generator learns), which in turn pushes the generated samples to look real. To solve this issue we can randomly perturb the phase of each layer's activations before they are fed to the next layer; this operation is called Phase Shuffle.
The figure above illustrates the Phase Shuffle operation, which perturbs the phase of each feature map by a shift drawn from Uniform[−n, n] samples at each layer of the discriminator model, filling in the missing samples (dashed outlines) by reflection.
The phase shuffle operation is applied to the discriminator model only, as the latent space already provides the generator with a mechanism to manipulate the phase of the resultant waveform.
We copied the implementation of the phase shuffle operation from this repository.
# assumed imports for the model code in this post
import torch
import torch.nn as nn
import torch.nn.functional as F


class PhaseShuffle(nn.Module):
    """
    Performs phase shuffling, i.e. shifting the feature axis of a 3D tensor
    by a random integer in [-n, n] and performing reflection padding where
    necessary.
    """
    # Copied from https://github.com/jtcramer/wavegan/blob/master/wavegan.py#L8

    def __init__(self, shift_factor):
        super(PhaseShuffle, self).__init__()
        self.shift_factor = shift_factor

    def forward(self, x):
        if self.shift_factor == 0:
            return x
        # uniform in (L, R)
        k_list = (
            torch.Tensor(x.shape[0]).random_(0, 2 * self.shift_factor + 1)
            - self.shift_factor
        )
        k_list = k_list.numpy().astype(int)
        # Combine sample indices into lists so that less shuffle operations
        # need to be performed
        k_map = {}
        for idx, k in enumerate(k_list):
            k = int(k)
            if k not in k_map:
                k_map[k] = []
            k_map[k].append(idx)
        # Make a copy of x for our output
        x_shuffle = x.clone()
        # Apply shuffle to each sample
        for k, idxs in k_map.items():
            if k > 0:
                x_shuffle[idxs] = F.pad(x[idxs][..., :-k], (k, 0), mode="reflect")
            else:
                x_shuffle[idxs] = F.pad(x[idxs][..., -k:], (0, -k), mode="reflect")
        assert x_shuffle.shape == x.shape, "{}, {}".format(x_shuffle.shape, x.shape)
        return x_shuffle
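A small sanity check of the module on random feature maps (shapes are preserved, only the time axis is shifted):

import torch

x = torch.randn(4, 64, 1024)                # (batch, channels, time), arbitrary values
shuffled = PhaseShuffle(shift_factor=2)(x)
print(shuffled.shape)                        # torch.Size([4, 64, 1024])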
Discriminator model
The discriminator architecture is roughly the same as the DCGAN discriminator, modified to use length-25 filters in one dimension (the flattened counterpart of DCGAN's 5 x 5 filters) and to increase the stride of the convolution layers from 2 to 4.
Unlike the DCGAN architecture, we do not use batch normalization in either the discriminator or the generator of WaveGAN.
The discriminator model architecture proposed by the paper is as follows,
Our implementation of the discriminator
class Discriminator(nn.Module):
    """
    The discriminator model.
    Adapted from: https://github.com/caogang/wgan-gp/blob/ae47a185ed2e938c39cf3eb2f06b32dc1b6a2064/gan_toy.py#L63
    """

    def __init__(
        self,
        model_size=64,
        num_channels=1,
        shuffle_factor=2,
        alpha=0.2,
    ):
        super(Discriminator, self).__init__()
        self.model_size = model_size
        self.num_channels = num_channels
        self.shift_factor = shuffle_factor
        self.alpha = alpha
        # build the model
        self._build()

    def _build(self):
        """
        Build the stacked layers.
        :return:
        """
        num_channels = self.num_channels
        model_size = self.model_size
        shift_factor = self.shift_factor
        input_channels_iter = [num_channels, model_size,
                               2 * model_size, 4 * model_size,
                               8 * model_size]
        output_channels_iter = [model_size, 2 * model_size,
                                4 * model_size, 8 * model_size,
                                16 * model_size]
        # phase shuffle in every block except the last one
        shift_factor_iter = [shift_factor, shift_factor,
                             shift_factor, shift_factor, 0]
        # conv1d is assumed to refer to the Conv1dLeakyReluPhaseShuffle block shown below
        convolutional_layers = [
            conv1d(
                i,
                j,
                25,
                stride=4,
                padding=11,
                alpha=self.alpha,
                shift_factor=k,
            ) for i, j, k in zip(input_channels_iter,
                                 output_channels_iter, shift_factor_iter)
        ]
        self.fc_input_size = 256 * model_size
        self.convolutional_layers = nn.ModuleList(convolutional_layers)
        self.fc1 = nn.Linear(self.fc_input_size, 1)

    # forward pass
    def forward(self, x):
        for conv in self.convolutional_layers:
            x = conv(x)
        x = x.view(-1, self.fc_input_size)
        return self.fc1(x)
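A hedged shape check for this implementation, assuming conv1d above refers to the Conv1dLeakyReluPhaseShuffle block shown next (e.g. imported under that alias):

import torch

x = torch.randn(4, 1, 16384)        # a batch of 4 mono clips of WINDOW_LENGTH samples
disc = Discriminator(model_size=64)
print(disc(x).shape)                # torch.Size([4, 1]), one unbounded critic score per clip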
The Conv1dLeakyReluPhaseShuffle class
class Conv1dLeakyReluPhaseShuffle(nn.Module):
    def __init__(
        self,
        input_channels,
        output_channels,
        kernel_size,
        alpha=0.2,
        shift_factor=2,
        stride=4,
        padding=11,
        use_batch_norm=False,
        drop_prob=0,
    ):
        super(Conv1dLeakyReluPhaseShuffle, self).__init__()
        self.conv1d = nn.Conv1d(
            input_channels, output_channels, kernel_size, stride=stride, padding=padding
        )
        self.batch_norm = nn.BatchNorm1d(output_channels)
        self.phase_shuffle = PhaseShuffle(shift_factor)
        self.alpha = alpha
        # apply phase shuffle only when a non-zero shift factor is given
        self.use_phase_shuffle = shift_factor != 0
        self.use_batch_norm = use_batch_norm
        self.use_drop = drop_prob > 0
        self.dropout = nn.Dropout2d(drop_prob)

    def forward(self, x):
        x = self.conv1d(x)
        if self.use_batch_norm:
            x = self.batch_norm(x)
        x = F.leaky_relu(x, negative_slope=self.alpha)
        if self.use_phase_shuffle:
            x = self.phase_shuffle(x)
        if self.use_drop:
            x = self.dropout(x)
        return x
Generator model
Our implementation of the Generator class
class Generator(nn.Module):
    """
    The generator model.
    Adapted from: https://github.com/caogang/wgan-gp/blob/ae47a185ed2e938c39cf3eb2f06b32dc1b6a2064/gan_toy.py#L39
    """

    def __init__(
        self,
        model_size=64,
        num_channels=1
    ):
        super(Generator, self).__init__()
        self.model_size = model_size  # d
        self.num_channels = num_channels  # c
        self.latent_dim = 100
        self.dim_mul = 16
        self.up_sample = True
        self.stride = 4
        self.fc1 = nn.Linear(self.latent_dim, 4 * 4 * model_size * self.dim_mul)
        # build the model
        self._build()

    def _build(self):
        """
        Build the network.
        :return: None
        """
        num_channels = self.num_channels
        model_size = self.model_size
        dim_mul = self.dim_mul
        input_channels_iter = [(model_size * dim_mul // i) for i in [1, 2, 4, 8, 16]]
        # the final layer outputs num_channels audio channels
        output_channels_iter = [(model_size * dim_mul // i) for i in [2, 4, 8, 16]] + [num_channels]
        deconvolution_layers = [
            ConvTranspose1d(
                i,
                j,
                25,
                stride=4,
            ) for i, j in zip(input_channels_iter, output_channels_iter)
        ]
        self.deconvolution_layers_list = nn.ModuleList(deconvolution_layers)

    def forward(self, x):
        x = self.fc1(x).view(-1, self.dim_mul * self.model_size, 16)
        x = F.relu(x)
        for deconv_layer in self.deconvolution_layers_list[:-1]:
            x = F.relu(deconv_layer(x))
        output = torch.tanh(self.deconvolution_layers_list[-1](x))
        return output
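The matching shape check for the generator, assuming the ConvTranspose1d wrapper shown next with its default padding and output padding, so that each layer upsamples the time axis by a factor of 4:

import torch

z = torch.randn(4, 100)            # latent vectors of dimension 100, as hard-coded above
gen = Generator(model_size=64)
print(gen(z).shape)                # torch.Size([4, 1, 16384]), i.e. 16 -> 64 -> ... -> 16384 samples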
The ConvTranspose1d class
class ConvTranspose1d(nn.Module):
    def __init__(
        self,
        input_channels,
        output_channels,
        kernel_size,
        stride=4,
        padding=11,
        upsample=None,
        output_padding=1,
        use_batch_norm=False,
    ):
        super(ConvTranspose1d, self).__init__()
        self.upsample = upsample
        # zero padding placed in front of the plain convolution on the upsampling path
        reflection_pad = nn.ConstantPad1d(kernel_size // 2, value=0)
        conv1d = nn.Conv1d(input_channels, output_channels, kernel_size, stride)
        conv1d.weight.data.normal_(0.0, 0.02)
        conv1d_trans = nn.ConvTranspose1d(
            input_channels, output_channels, kernel_size, stride, padding, output_padding
        )
        batch_norm = nn.BatchNorm1d(output_channels)
        if self.upsample:
            operation_list = [reflection_pad, conv1d]
        else:
            operation_list = [conv1d_trans]
        if use_batch_norm:
            operation_list.append(batch_norm)
        self.transpose_ops = nn.Sequential(*operation_list)
        self.upsample = upsample

    def forward(self, x):
        # use nearest mode as pointed out in the paper
        if self.upsample:
            # the wavegan paper recommends nearest-neighbour upsampling
            x = nn.functional.interpolate(x, scale_factor=self.upsample, mode="nearest")
        return self.transpose_ops(x)
The Wavegan model class
# assumed imports and helpers for the training class: torch.optim as optim,
# torch.autograd.Variable and torch.autograd.grad, tqdm, plus the project
# utilities (sample_noise, save_audio, weights_init, gradients_status,
# interpolate_latent_space) and the hyper-parameter constants.
class Wavegan(object):
    """
    The wavegan model, with generator and discriminator combined.
    """

    def __init__(self, train_loader, val_loader):
        super(Wavegan, self).__init__()
        self.g_cost = []
        self.train_d_cost = []
        self.train_w_distance = []
        self.valid_g_cost = [-1]
        self.valid_reconstruction = []
        self.discriminator = Discriminator(
            model_size=MODEL_CAPACITY_SIZE,).to(torch.device("cpu"))
        self.discriminator.apply(weights_init)
        self.generator = Generator(
            model_size=MODEL_CAPACITY_SIZE,).to(torch.device("cpu"))
        self.generator.apply(weights_init)
        self.optimizer_g = optim.Adam(
            self.generator.parameters(), lr=LR_G, betas=(ADAM_BETA_1, ADAM_BETA_2)
        )  # Setup Adam optimizers for both G and D
        self.optimizer_d = optim.Adam(
            self.discriminator.parameters(), lr=LR_D, betas=(ADAM_BETA_1, ADAM_BETA_2)
        )
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.n_samples_per_batch = len(train_loader)

    def calculate_discriminator_loss(self, real, generated):
        disc_out_gen = self.discriminator(generated)
        disc_out_real = self.discriminator(real)
        alpha = torch.FloatTensor(BATCH_SIZE, 1, 1).uniform_(0, 1).to(torch.device("cpu"))
        alpha = alpha.expand(BATCH_SIZE, real.size(1), real.size(2))
        interpolated = (1 - alpha) * real.data + (alpha) * generated.data[:BATCH_SIZE]
        interpolated = Variable(interpolated, requires_grad=True)
        # calculate probability of interpolated examples
        prob_interpolated = self.discriminator(interpolated)
        grad_inputs = interpolated
        ones = torch.ones(prob_interpolated.size()).to(torch.device("cpu"))
        gradients = grad(
            outputs=prob_interpolated,
            inputs=grad_inputs,
            grad_outputs=ones,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        # calculate gradient penalty
        grad_penalty = (
            P_COEFF
            * ((gradients.view(gradients.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()
        )
        cost_wd = disc_out_gen.mean() - disc_out_real.mean()
        cost = cost_wd + grad_penalty
        return cost, cost_wd

    def apply_zero_grad(self):
        self.generator.zero_grad()
        self.optimizer_g.zero_grad()
        self.discriminator.zero_grad()
        self.optimizer_d.zero_grad()

    def enable_disc_disable_gen(self):
        gradients_status(self.discriminator, True)
        gradients_status(self.generator, False)

    def enable_gen_disable_disc(self):
        gradients_status(self.discriminator, False)
        gradients_status(self.generator, True)

    def disable_all(self):
        gradients_status(self.discriminator, False)
        gradients_status(self.generator, False)

    def train(self):
        progress_bar = tqdm(total=N_ITERATIONS // PROGRESS_BAR_STEP_ITER_SIZE)
        fixed_noise = sample_noise(BATCH_SIZE).to(torch.device("cpu"))
        gan_model_name = "gan_{}.tar".format('wavegan')
        first_iter = 0
        # resume from the last checkpoint if one exists
        if TAKE_BACKUP and os.path.isfile(gan_model_name):
            checkpoint = torch.load(gan_model_name, map_location="cpu")
            self.generator.load_state_dict(checkpoint["generator"])
            self.discriminator.load_state_dict(checkpoint["discriminator"])
            self.optimizer_d.load_state_dict(checkpoint["optimizer_d"])
            self.optimizer_g.load_state_dict(checkpoint["optimizer_g"])
            self.train_d_cost = checkpoint["train_d_cost"]
            self.train_w_distance = checkpoint["train_w_distance"]
            self.valid_g_cost = checkpoint["valid_g_cost"]
            self.g_cost = checkpoint["g_cost"]
            first_iter = checkpoint["n_iterations"]
            for i in range(0, first_iter, PROGRESS_BAR_STEP_ITER_SIZE):
                progress_bar.update()
            self.generator.eval()
            with torch.no_grad():
                fake = self.generator(fixed_noise).detach().cpu().numpy()
            save_audio(fake, first_iter)
        self.generator.train()
        self.discriminator.train()
        for iter_indx in range(first_iter, N_ITERATIONS):
            # update the critic N_CRITIC times per generator update
            self.enable_disc_disable_gen()
            for _ in range(N_CRITIC):
                real_signal = next(self.train_loader)
                noise = sample_noise(BATCH_SIZE)
                generated = self.generator(noise)
                self.apply_zero_grad()
                disc_cost, disc_wd = self.calculate_discriminator_loss(
                    real_signal.data, generated.data
                )
                disc_cost.backward()
                self.optimizer_d.step()
            # generator update
            self.apply_zero_grad()
            self.enable_gen_disable_disc()
            noise = sample_noise(BATCH_SIZE)
            generated = self.generator(noise)
            discriminator_output_fake = self.discriminator(generated)
            generator_cost = -discriminator_output_fake.mean()
            generator_cost.backward()
            self.optimizer_g.step()
            self.disable_all()
            if iter_indx % SAVE_SAMPLES_EVERY == 0:
                with torch.no_grad():
                    interpolate_latent_space(self.generator, n_samples=2)
                    fake = self.generator(fixed_noise).detach().cpu().numpy()
                save_audio(fake, iter_indx)
            if TAKE_BACKUP and iter_indx % BACKUP_EVERY_N_ITERS == 0:
                saving_dict = {
                    "generator": self.generator.state_dict(),
                    "discriminator": self.discriminator.state_dict(),
                    "n_iterations": iter_indx,
                    "optimizer_d": self.optimizer_d.state_dict(),
                    "optimizer_g": self.optimizer_g.state_dict(),
                    "train_d_cost": self.train_d_cost,
                    "train_w_distance": self.train_w_distance,
                    "valid_g_cost": self.valid_g_cost,
                    "g_cost": self.g_cost,
                }
                torch.save(saving_dict, gan_model_name)
        self.generator.eval()
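Finally, a hedged sketch of how the pieces above could be wired together (the folder paths are assumptions; in the actual project they come from its configuration):

# build streaming loaders over the preprocessed wav chunks and train the GAN on CPU
train_loader = WaveLoader("data/train")   # assumed folder of training chunks
val_loader = WaveLoader("data/valid")     # assumed folder of validation chunks

wavegan = Wavegan(train_loader, val_loader)
wavegan.train()                            # runs N_ITERATIONS of WGAN-GP updates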
Results
The following clips are a training sample (the first audio instance) and 6 audio clips generated at different epochs while training the network.
Resources
- https://sites.uci.edu/computermusic/2013/05/14/windowing-an-audio-signal/
- https://towardsdatascience.com/synthesizing-audio-with-generative-adversarial-networks-8e0308184edd/
- https://arxiv.org/pdf/1802.04208v3.pdf/
- https://en.wikipedia.org/wiki/Upsampling
- https://distill.pub/2016/deconv-checkerboard/
- https://www.quora.com/Why-are-the-weights-frozen-in-the-discriminator-of-GANs-during-training
- https://arxiv.org/abs/1701.07875
- https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
Future work
We aim to build an application that generates human speech in the DARIJA dialect, using image-like feature representations with the SpecGAN model, trained on multiple GPUs.
Credits
- Background picture
Image by idouissadene rachid via Naive Art.
- Audio Player
Player component for reactJs by lijinke666 via react-jinke-music-player.
Conclusion
The trained model used to generate the above audio samples was trained on CPU only, due to the lack of GPUs. It took us 5 days to train the model, and that was after shrinking the data to 20% of the full training set; this is how far we have gotten so far with this basic prototype.
I would like to sincerely thank you for reading this article; I really appreciate it, and I hope you managed to learn something new.
Have fun!