Miscellaneous fixes to the x-transformers implementation #79

Open · wants to merge 8 commits into base: main
65 changes: 59 additions & 6 deletions mammoth/distributed/components.py
@@ -87,7 +87,50 @@ def needs_communication(self) -> bool:
return self.group is not None


# TODO: This is a misnomer: Not an entire XCoder, but just one AttentionLayers block
@dataclass # type: ignore
class DistributedTransformerWrapper(DistributedComponent, ABC):
task_id: str
side: Side

def get_name(self) -> str:
return f'{self.side.name}_{self.task_id}'

def get_module(self, model: NMTModel) -> nn.Module:
parent = model.encoder if self.side == Side.encoder else model.decoder
tw = parent[self.task_id]
return tw

def named_parameters(self, model: NMTModel):
module = self.get_module(model)
for name, p in module.named_parameters():
# TransformerWrapper contains the AttentionLayers and the embs.
# however, we want to treat these as distinct DistributedComponents
if name.startswith('attn_layers.'):
continue
if name.startswith('token_emb.'):
continue
yield name, p

def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, Any]:
module = self.get_module(model)
destination: Dict[str, Any] = OrderedDict()
for name, sub_module in module._modules.items():
if name.endswith('attn_layers'):
# stored separately
continue
sub_module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
return destination

def load_state_dict(self, model: NMTModel, state_dict: Dict[str, Any]):
module = self.get_module(model)
mismatch = module.load_state_dict(state_dict, strict=False)
missing_keys = [
name for name in mismatch.missing_keys
if not (name.startswith('attn_layers.') or name.startswith('token_emb.'))
]
return mismatch._replace(missing_keys=missing_keys)

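Note: `DistributedTransformerWrapper.named_parameters` deliberately skips the `attn_layers.*` and `token_emb.*` prefixes, since those submodules are synchronized by their own `DistributedComponent`s. The snippet below is a minimal sketch of that prefix filter on a toy module; the toy layout is an assumption that only mirrors the real `TransformerWrapper` submodule names and is not part of the diff.

```python
import torch.nn as nn

# Hypothetical stand-in: only the submodule names mirror the real TransformerWrapper.
toy = nn.ModuleDict({
    'attn_layers': nn.Linear(8, 8),      # synced by its own AttentionLayers component
    'token_emb': nn.Embedding(100, 8),   # synced by the embedding component
    'pos_emb': nn.Embedding(32, 8),      # owned by the wrapper component itself
})

owned = [
    name for name, _ in toy.named_parameters()
    if not name.startswith(('attn_layers.', 'token_emb.'))
]
print(owned)  # ['pos_emb.weight'] -- only these parameters belong to this component
```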

@dataclass # type: ignore
class DistributedAttentionLayersBlock(DistributedComponent, ABC):
layer_stack_index: int
@@ -106,22 +149,32 @@ def named_parameters(self, model: NMTModel):
for name, p in module.named_parameters():
# encoders and decoders contain embeddings and adapters as submodules
# however, we want to treat these as distinct DistributedComponents
if 'embeddings' not in name and 'adapter' not in name:
yield name, p
if 'adapter' in name:
continue
yield name, p

def state_dict(self, model: NMTModel, prefix='', keep_vars=False) -> Dict[str, Any]:
module = self.get_module(model)
destination: Dict[str, Any] = OrderedDict()
for name, sub_module in module._modules.items():
for name, sub_module in module.get_sub_modules().items():
if name == 'adapters':
# Adapters are stored separately
continue
sub_module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
return destination

def load_state_dict(self, model: NMTModel, state_dict: Dict[str, Any]):
module = self.get_module(model)
mismatch = module.load_state_dict(state_dict, strict=False)
missing_keys = [
name for name in mismatch.missing_keys
if not name.startswith('layers.')
]
return mismatch._replace(missing_keys=missing_keys)


@dataclass
class DistributedEncoder(DistributedAttentionLayersBlock):
class DistributedEncoderAttentionLayersBlock(DistributedAttentionLayersBlock):
@property
def side(self) -> Side:
return Side.encoder
@@ -136,7 +189,7 @@ def get_module(self, model: NMTModel) -> nn.Module:


@dataclass
class DistributedDecoder(DistributedAttentionLayersBlock):
class DistributedDecoderAttentionLayersBlock(DistributedAttentionLayersBlock):
@property
def side(self) -> Side:
return Side.decoder
27 changes: 23 additions & 4 deletions mammoth/distributed/tasks.py
@@ -17,9 +17,10 @@
DistributedComponent,
DistributedComponentBuilder,
DistributedComponentGradientSync,
DistributedDecoder,
DistributedDecoderAttentionLayersBlock,
DistributedEmbedding,
DistributedEncoder,
DistributedEncoderAttentionLayersBlock,
DistributedTransformerWrapper,
Side,
)
from mammoth.distributed.contexts import DeviceContext, WorldContext
@@ -369,9 +370,27 @@ def create_all_distributed_components(
lang=task.tgt_lang,
)
)
builder.add(
DistributedTransformerWrapper(
global_ranks={global_rank},
task_ids={task.corpus_id},
group=None,
side=Side.encoder,
task_id=task.corpus_id,
)
)
builder.add(
DistributedTransformerWrapper(
global_ranks={global_rank},
task_ids={task.corpus_id},
group=None,
side=Side.decoder,
task_id=task.corpus_id,
)
)
for layer_stack_index, encoder_id in enumerate(task.encoder_id):
builder.add(
DistributedEncoder(
DistributedEncoderAttentionLayersBlock(
global_ranks={global_rank},
task_ids={task.corpus_id},
group=None,
@@ -381,7 +400,7 @@
)
for layer_stack_index, decoder_id in enumerate(task.decoder_id):
builder.add(
DistributedDecoder(
DistributedDecoderAttentionLayersBlock(
global_ranks={global_rank},
task_ids={task.corpus_id},
group=None,
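Note: with the additions above, `create_all_distributed_components` now registers one `DistributedTransformerWrapper` per task and side, alongside the per-layer-stack blocks. The sketch below only illustrates the expected component naming; the corpus id and the import paths are assumptions based on this diff, not part of it.

```python
from mammoth.distributed.components import DistributedTransformerWrapper, Side

# 'train_en-de' is a hypothetical corpus id used only for illustration.
component = DistributedTransformerWrapper(
    global_ranks={0},
    task_ids={'train_en-de'},
    group=None,
    side=Side.decoder,
    task_id='train_en-de',
)
print(component.get_name())  # expected: 'decoder_train_en-de'
```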
66 changes: 51 additions & 15 deletions mammoth/model_builder.py
@@ -14,8 +14,8 @@
from mammoth.distributed.components import (
DistributedAdapter,
DistributedComponent,
DistributedDecoder,
DistributedEncoder,
DistributedDecoderAttentionLayersBlock,
DistributedEncoderAttentionLayersBlock,
Side,
)
from mammoth.modules.adapters import (
@@ -31,6 +31,14 @@
from mammoth.utils.logging import logger
from mammoth.utils.misc import use_gpu

TRANSFORMER_WRAPPER_OPTS = {
'post_emb_norm',
'tie_embedding',
'use_abs_pos_emb',
'scaled_sinu_pos_emb',
'emb_frac_gradient',
}


def _combine_ordered_dicts(input_dicts: Dict[str, OrderedDict]) -> OrderedDict:
result = []
@@ -59,17 +67,32 @@ def get_attention_layers_kwargs(
is_last = layer_stack_index == len(depths) - 1
pre_norm_has_final_norm = is_last
kwargs = model_opts.x_transformers_opts if model_opts.x_transformers_opts else dict()
kwargs = {key: val for key, val in kwargs.items() if key not in TRANSFORMER_WRAPPER_OPTS}
kwargs.update({
'dim': model_opts.model_dim,
'depth': depth,
'heads': model_opts.heads,
'causal': causal,
'cross_attend': cross_attend,
'pre_norm_has_final_norm': pre_norm_has_final_norm,
})
return kwargs


def get_transformer_wrapper_kwargs(
side: Side,
model_opts,
):
"""Return arguments for x_transformers.TransformerWrapper"""
assert side in {Side.encoder, Side.decoder}, f'Invalid side "{side}"'
kwargs = model_opts.x_transformers_opts if model_opts.x_transformers_opts else dict()
kwargs = {key: val for key, val in kwargs.items() if key in TRANSFORMER_WRAPPER_OPTS}
max_seq_len = 0 if model_opts.max_length is None else model_opts.max_length
kwargs.update({
'max_seq_len': max_seq_len,
})
return kwargs

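Note: the two helpers above split a single `x_transformers_opts` dict by key, so that options meant for `TransformerWrapper` never leak into the `AttentionLayers` kwargs and vice versa. A minimal sketch of that partitioning, using a hypothetical opts dict (only `TRANSFORMER_WRAPPER_OPTS` is taken from the diff):

```python
# Mirrors the set added near the top of model_builder.py.
TRANSFORMER_WRAPPER_OPTS = {
    'post_emb_norm', 'tie_embedding', 'use_abs_pos_emb',
    'scaled_sinu_pos_emb', 'emb_frac_gradient',
}

x_transformers_opts = {          # hypothetical user config
    'tie_embedding': True,       # TransformerWrapper option
    'rotary_pos_emb': True,      # AttentionLayers option
    'ff_glu': True,              # AttentionLayers option
}

wrapper_kwargs = {k: v for k, v in x_transformers_opts.items() if k in TRANSFORMER_WRAPPER_OPTS}
attention_kwargs = {k: v for k, v in x_transformers_opts.items() if k not in TRANSFORMER_WRAPPER_OPTS}

assert wrapper_kwargs == {'tie_embedding': True}
assert attention_kwargs == {'rotary_pos_emb': True, 'ff_glu': True}
```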

def build_xcoder(
side: Side,
model_opts,
@@ -96,10 +119,10 @@
]
distributed_xcoder_class: type
if side == Side.encoder:
distributed_xcoder_class = DistributedEncoder
distributed_xcoder_class = DistributedEncoderAttentionLayersBlock
side_str = 'encoder'
else:
distributed_xcoder_class = DistributedDecoder
distributed_xcoder_class = DistributedDecoderAttentionLayersBlock
side_str = 'decoder'
if single_task:
my_components = [
@@ -197,6 +220,10 @@ def build_xcoder(
if single_task:
tasks = [task for task in tasks if task.corpus_id == single_task]
transformer_wrappers = dict()
transformer_wrapper_kwargs = get_transformer_wrapper_kwargs(
side=side,
model_opts=model_opts,
)
for task in tasks:
if side == Side.encoder:
xcoder_ids = task.encoder_id
@@ -212,22 +239,13 @@

lang = task.src_lang if side == Side.encoder else task.tgt_lang
vocab = vocabs_dict[(side_alt_str, lang)]
max_seq_len = 0 if model_opts.max_length is None else model_opts.max_length
post_emb_norm = True
tie_embedding = True
use_abs_pos_emb = True
emb_frac_gradient = 1.
# Using custom extended TransformerWrapper to allow passing in an embedding
transformer_wrapper = TransformerWrapper(
num_tokens=len(vocab),
max_seq_len=max_seq_len,
attn_layers=adapted_attention_layers_stack,
emb_dim=model_opts.model_dim,
post_emb_norm=post_emb_norm,
tie_embedding=tie_embedding,
use_abs_pos_emb=use_abs_pos_emb,
emb_frac_gradient=emb_frac_gradient,
token_emb=token_embs[lang],
**transformer_wrapper_kwargs,
)
transformer_wrappers[task.corpus_id] = transformer_wrapper

@@ -310,3 +328,21 @@ def build_model(
# logger.info(model)
logger.info('Building model - done!')
return model


def validate_optimizer_coverage(model, optimizer):
trainable_model_params = {
name: p for name, p in model.named_parameters()
if p.requires_grad
}
optimized_params = set()
for group in optimizer.param_groups:
optimized_params.update(group['params'])
missing_params = [
name for name, p in trainable_model_params.items()
if p not in optimized_params
]
if len(missing_params) > 0:
raise Exception(f'Missing optimizer for params: {sorted(missing_params)}')
else:
logger.info('All non-frozen parameters have an optimizer')
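Note: a usage sketch for `validate_optimizer_coverage`. The toy model and optimizer below are hypothetical; the intended call site would be wherever both the model and its optimizers have been built.

```python
import torch
import torch.nn as nn

from mammoth.model_builder import validate_optimizer_coverage

model = nn.Linear(4, 4)
# Deliberately omit the bias so the check has something to report.
optimizer = torch.optim.SGD([model.weight], lr=0.1)

try:
    validate_optimizer_coverage(model, optimizer)
except Exception as exc:
    print(exc)  # Missing optimizer for params: ['bias']
```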
7 changes: 7 additions & 0 deletions mammoth/modules/adapters.py
@@ -225,3 +225,10 @@ def _inject_adapters(self):
def forward(self, *args, **kwargs):
self._inject_adapters()
return super().forward(*args, **kwargs)

def get_sub_modules(self):
omit_submodules = {'layers'}
return {
name: sub_module for name, sub_module in self._modules.items()
if name not in omit_submodules
}
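Note: `get_sub_modules` pairs with the `state_dict` change in `distributed/components.py`, letting a component checkpoint every child except the separately-handled `layers`. A toy stand-in that only mimics the submodule layout (`ToyAdapted` is hypothetical, not the real `AdaptedAttentionLayers`):

```python
import torch.nn as nn

class ToyAdapted(nn.Module):
    """Hypothetical stand-in that only mimics the submodule layout."""
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4)])  # excluded from get_sub_modules()
        self.final_norm = nn.LayerNorm(4)               # exposed to the caller

    def get_sub_modules(self):
        omit_submodules = {'layers'}
        return {
            name: sub_module for name, sub_module in self._modules.items()
            if name not in omit_submodules
        }

print(list(ToyAdapted().get_sub_modules().keys()))  # ['final_norm']
```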
29 changes: 5 additions & 24 deletions mammoth/modules/transformer_decoder.py
@@ -60,6 +60,8 @@ def __init__(

"""
super(TransformerDecoderLayerBase, self).__init__()
assert not full_context_alignment, 'alignment is obsolete'
assert alignment_heads == 0, 'alignment is obsolete'

if self_attn_type == "scaled-dot":
self.self_attn = MultiHeadedAttention(
@@ -86,8 +88,6 @@ def __init__(
self.layer_norm_3 = nn.LayerNorm(d_model, eps=1e-6)
self.layer_norm_4 = nn.LayerNorm(d_model, eps=1e-6)
self.drop = nn.Dropout(dropout)
self.full_context_alignment = full_context_alignment
self.alignment_heads = alignment_heads

def forward(self, *args, **kwargs):
"""Extend `_forward` for (possibly) multiple decoder pass:
@@ -97,7 +97,6 @@ def forward(self, *args, **kwargs):

Args:
* All arguments of _forward.
with_align (bool): whether return alignment attention.

Returns:
(FloatTensor, FloatTensor, FloatTensor or None):
Expand All @@ -106,22 +105,9 @@ def forward(self, *args, **kwargs):
* top_attn ``(batch_size, T, src_len)``
* attn_align ``(batch_size, T, src_len)`` or None
"""
with_align = kwargs.pop("with_align", False)
output, attns = self._forward(*args, **kwargs)
top_attn = attns[:, 0, :, :].contiguous()
attn_align = None
if with_align:
if self.full_context_alignment:
# return _, (B, Q_len, K_len)
_, attns = self._forward(*args, **kwargs, future=True)

if self.alignment_heads > 0:
attns = attns[:, : self.alignment_heads, :, :].contiguous()
# layer average attention across heads, get ``(B, Q, K)``
# Case 1: no full_context, no align heads -> layer avg baseline
# Case 2: no full_context, 1 align heads -> guided align
# Case 3: full_context, 1 align heads -> full cte guided align
attn_align = attns.mean(dim=1)
return output, top_attn, attn_align

def update_dropout(self, dropout, attention_dropout):
@@ -317,9 +303,9 @@ def from_opts(cls, opts, embeddings, is_on_top=False):
embeddings,
opts.max_relative_positions,
opts.aan_useffn,
opts.full_context_alignment,
opts.alignment_layer,
alignment_heads=opts.alignment_heads,
False,
None,
alignment_heads=0,
pos_ffn_activation_fn=opts.pos_ffn_activation_fn,
layer_norm_module=(
nn.LayerNorm(opts.model_dim, eps=1e-6) if is_on_top
@@ -489,7 +475,6 @@ def forward(
src_max_len = memory_bank.size(1)
src_pad_mask = ~sequence_mask(memory_lengths, src_max_len).unsqueeze(1)

with_align = kwargs.pop("with_align", False)
attn_aligns = []

for i, layer in enumerate(self._get_layers()):
@@ -505,7 +490,6 @@
tgt_pad_mask,
layer_cache=layer_cache,
step=step,
with_align=with_align,
)
if attn_align is not None:
attn_aligns.append(attn_align)
@@ -517,9 +501,6 @@
attns = {"std": attn}
if self._copy:
attns["copy"] = attn
if with_align:
attns["align"] = attn_aligns[self.alignment_layer] # `(B, Q, K)`
# attns["align"] = torch.stack(attn_aligns, 0).mean(0) # All avg

# TODO change the way attns is returned dict => list or tuple (onnx)
return dec_outs, attns