Skip to content

Commit

Permalink
fix: ensure saved model tensors are contiguous
Browse files Browse the repository at this point in the history
  • Loading branch information
percevalw committed Aug 24, 2024
1 parent 436fe39 commit 4659bd5
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions edsnlp/pipes/trainable/embeddings/transformer/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,20 @@ def to_disk(self, path, *, exclude: Optional[Set[str]]):
if repr_id in exclude:
return
self.tokenizer.save_pretrained(path)

# Fix for https://github.com/aphp/edsnlp/issues/317
old_params_data = {}
for param in self.transformer.parameters():
if not param.is_contiguous():
old_params_data[param] = param.data
param.data = param.data.contiguous()

self.transformer.save_pretrained(path)

# Restore non-contiguous tensors
for param, data in old_params_data.items():
param.data = data

for param in self.transformer.parameters():
exclude.add(object.__repr__(param))
cfg = super().to_disk(path, exclude=exclude) or {}
Expand Down

0 comments on commit 4659bd5

Please sign in to comment.