Skip to content

Commit

Permalink
modified zipformer
Browse files Browse the repository at this point in the history
  • Loading branch information
dkulko committed Oct 17, 2024
1 parent f84270c commit 6072407
Show file tree
Hide file tree
Showing 5 changed files with 2,138 additions and 4,088 deletions.
225 changes: 136 additions & 89 deletions egs/librispeech/ASR/zipformer/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,115 +20,162 @@
from scaling import Balancer


class Decoder(nn.Module):
"""This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction
network. Different from the above paper, it adds an extra Conv1d
right after the embedding layer.
class Decoder(torch.nn.Module):
"""
This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
It removes the recurrent connection from the decoder, i.e., the prediction network.
Different from the above paper, it adds an extra Conv1d right after the embedding layer.
"""

def __init__(
self,
vocab_size: int,
decoder_dim: int,
blank_id: int,
context_size: int,
):
self, vocab_size: int, decoder_dim: int, context_size: int, device: torch.device,
) -> None:
"""
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
decoder_dim:
Dimension of the input embedding, and of the decoder output.
blank_id:
The ID of the blank symbol.
context_size:
Number of previous words to use to predict the next word.
Decoder initialization.
Parameters
----------
vocab_size : int
A number of tokens or modeling units, includes blank.
decoder_dim : int
A dimension of the decoder embeddings, and the decoder output.
context_size : int
A number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
device : torch.device
The device used to store the layer weights. Should be
either torch.device("cpu") or torch.device("cuda").
"""
super().__init__()

self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=decoder_dim,
)
# the balancers are to avoid any drift in the magnitude of the
# embeddings, which would interact badly with parameter averaging.
self.balancer = Balancer(
decoder_dim,
channel_dim=-1,
min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
super().__init__()

self.blank_id = blank_id
self.embedding = torch.nn.Embedding(vocab_size, decoder_dim)

assert context_size >= 1, context_size
if context_size < 1:
raise ValueError(
'RNN-T decoder context size should be an integer greater '
f'or equal than 1, but got {context_size}.',
)
self.context_size = context_size
self.vocab_size = vocab_size

if context_size > 1:
self.conv = nn.Conv1d(
in_channels=decoder_dim,
out_channels=decoder_dim,
kernel_size=context_size,
padding=0,
groups=decoder_dim // 4, # group size == 4
bias=False,
)
self.balancer2 = Balancer(
decoder_dim,
channel_dim=-1,
min_positive=0.0,
max_positive=1.0,
min_abs=0.5,
max_abs=1.0,
prob=0.05,
)
else:
# To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'`
# when inference with torch.jit.script and context_size == 1
self.conv = nn.Identity()
self.balancer2 = nn.Identity()
self.conv = torch.nn.Conv1d(
decoder_dim,
decoder_dim,
context_size,
groups=decoder_dim // 4,
bias=False,
device=device,
)

def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
def forward(self, y: torch.Tensor) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
need_pad:
True to left pad the input. Should be True during training.
False to not pad the input. Should be False during inference.
Returns:
Return a tensor of shape (N, U, decoder_dim).
Does a forward pass of the stateless Decoder module. Returns an output decoder tensor.
Parameters
----------
y : torch.Tensor[torch.int32]
The input integer tensor of shape (N, context_size).
The module input that corresponds to the last context_size decoded token indexes.
Returns
-------
torch.Tensor[torch.float32]
An output float tensor of shape (N, 1, decoder_dim).
"""
y = y.to(torch.int64)
# this stuff about clamp() is a temporary fix for a mismatch
# at utterance start, we use negative ids in beam_search.py
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1)

embedding_out = self.balancer(embedding_out)
# this stuff about clamp() is a fix for a mismatch at utterance start,
# we use negative ids in RNN-T decoding.
embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(2)

if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = F.relu(embedding_out)
embedding_out = self.balancer2(embedding_out)
embedding_out = torch.nn.functional.relu(embedding_out)

return embedding_out


class DecoderModule(torch.nn.Module):
"""
A helper module to combine decoder, decoder projection, and joiner inference together.
"""

def __init__(
self,
vocab_size: int,
decoder_dim: int,
joiner_dim: int,
context_size: int,
beam: int,
device: torch.device,
) -> None:
"""
DecoderModule initialization.
Parameters
----------
vocab_size:
A number of tokens or modeling units, includes blank.
decoder_dim : int
A dimension of the decoder embeddings, and the decoder output.
joiner_dim : int
Input joiner dimension.
context_size : int
A number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
beam : int
A decoder beam.
device : torch.device
The device used to store the layer weights. Should be
either torch.device("cpu") or torch.device("cuda").
"""

super().__init__()

self.decoder = Decoder(vocab_size, decoder_dim, context_size, device)
self.decoder_proj = torch.nn.Linear(decoder_dim, joiner_dim, device=device)
self.joiner = Joiner(joiner_dim, vocab_size, device)

self.vocab_size = vocab_size
self.beam = beam

def forward(
self, decoder_input: torch.Tensor, encoder_out: torch.Tensor, hyps_log_prob: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Does a forward pass of the stateless Decoder module. Returns an output decoder tensor.
Parameters
----------
decoder_input : torch.Tensor[torch.int32]
The input integer tensor of shape (num_hyps, context_size).
The module input that corresponds to the last context_size decoded token indexes.
encoder_out : torch.Tensor[torch.float32]
An output tensor from the encoder after projection of shape (num_hyps, joiner_dim).
hyps_log_prob : torch.Tensor[torch.float32]
Hypothesis probabilities in a logarithmic scale of shape (num_hyps, 1).
Returns
-------
torch.Tensor[torch.float32]
A float output tensor of logit token probabilities of shape (num_hyps, vocab_size).
"""

decoder_out = self.decoder(decoder_input)
decoder_out = self.decoder_proj(decoder_out)

logits = self.joiner(encoder_out, decoder_out[:, 0, :])

tokens_log_prob = torch.log_softmax(logits, dim=1)
log_probs = (tokens_log_prob + hyps_log_prob).reshape(-1)

hyps_topk_log_prob, topk_indexes = log_probs.topk(self.beam)
topk_hyp_indexes = torch.floor_divide(topk_indexes, self.vocab_size).to(torch.int32)
topk_token_indexes = torch.remainder(topk_indexes, self.vocab_size).to(torch.int32)
tokens_topk_prob = torch.exp(tokens_log_prob.reshape(-1)[topk_indexes])

return hyps_topk_log_prob, tokens_topk_prob, topk_hyp_indexes, topk_token_indexes
73 changes: 33 additions & 40 deletions egs/librispeech/ASR/zipformer/joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,49 +19,42 @@
from scaling import ScaledLinear


class Joiner(nn.Module):
def __init__(
self,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
):
super().__init__()

self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
self.output_linear = nn.Linear(joiner_dim, vocab_size)
class Joiner(torch.nn.Module):

def forward(
self,
encoder_out: torch.Tensor,
decoder_out: torch.Tensor,
project_input: bool = True,
) -> torch.Tensor:
def __init__(self, joiner_dim: int, vocab_size: int, device: torch.device) -> None:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, s_range, C).
decoder_out:
Output from the decoder. Its shape is (N, T, s_range, C).
project_input:
If true, apply input projections encoder_proj and decoder_proj.
If this is false, it is the user's responsibility to do this
manually.
Returns:
Return a tensor of shape (N, T, s_range, C).
Joiner initialization.
Parameters
----------
joiner_dim : int
Input joiner dimension.
vocab_size : int
Output joiner dimension, the vocabulary size, the number of BPEs of the model.
device : torch.device
The device used to store the layer weights. Should be
either torch.device("cpu") or torch.device("cuda").
"""
assert encoder_out.ndim == decoder_out.ndim, (
encoder_out.shape,
decoder_out.shape,
)

if project_input:
logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
else:
logit = encoder_out + decoder_out
super().__init__()

self.output_linear = torch.nn.Linear(joiner_dim, vocab_size, device=device)

logit = self.output_linear(torch.tanh(logit))
def forward(self, encoder_out: torch.Tensor, decoder_out: torch.Tensor) -> torch.Tensor:
"""
Does a forward pass of the Joiner module. Returns an output tensor after a simple joining.
Parameters
----------
encoder_out : torch.Tensor[torch.float32]
An output tensor from the encoder after projection of shape (N, joiner_dim).
decoder_out : torch.Tensor[torch.float32]
An output tensor from the decoder after projection of shape (N, joiner_dim).
Returns
-------
torch.Tensor[torch.float32]
A float output tensor of log token probabilities of shape (N, vocab_size).
"""

return logit
return self.output_linear(torch.tanh(encoder_out + decoder_out))
Loading

0 comments on commit 6072407

Please sign in to comment.