Skip to content

Commit

Permalink
fix 37 (#38)
Browse files Browse the repository at this point in the history
Co-authored-by: Tuan Tran <{ID}+{username}@users.noreply.github.com>
  • Loading branch information
antoine-tran and Tuan Tran authored Jun 24, 2024
1 parent 3cff28b commit f95f92e
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 13 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ All notable changes to AudioSeal are documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.1.3] - 2024-06-24

- Update scripts to new training code
- Fix bugs in loading custom fine-tuned model (https://github.com/facebookresearch/audioseal/issues/37)

## [0.1.3] - 2024-04-30

- Fix bug in getting the watermark with non-empty message created in CPU, while the model is loaded in CUDA
Expand Down
8 changes: 4 additions & 4 deletions src/audioseal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import asdict, dataclass, is_dataclass
from typing import Any, Dict, List, Mapping, Optional
from dataclasses import asdict, dataclass, field, is_dataclass
from typing import Any, Dict, List, Optional

from omegaconf import DictConfig, OmegaConf
from torch import device, dtype
Expand Down Expand Up @@ -55,7 +55,7 @@ class DecoderConfig:

@dataclass
class DetectorConfig:
output_dim: int
output_dim: int = 32


@dataclass
Expand All @@ -69,7 +69,7 @@ class AudioSealWMConfig:
class AudioSealDetectorConfig:
nbits: int
seanet: SEANetConfig
detector: DetectorConfig
detector: DetectorConfig = field(default_factory=lambda: DetectorConfig())


def as_dict(obj: Any) -> Dict[str, Any]:
Expand Down
6 changes: 4 additions & 2 deletions src/audioseal/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
COMPATIBLE_WARNING = """
AudioSeal is designed to work at a sample rate 16khz.
Implicit sampling rate usage is deprecated and will be removed in future version.
To remove this warning please add this argument to the function call:
To remove this warning please add this argument to the function call:
sample_rate = your_sample_rate
"""

Expand Down Expand Up @@ -111,7 +111,9 @@ def get_watermark(
if self.msg_processor is not None:
if message is None:
if self.message is None:
message = torch.randint(0, 2, (x.shape[0], self.msg_processor.nbits), device=x.device)
message = torch.randint(
0, 2, (x.shape[0], self.msg_processor.nbits), device=x.device
)
else:
message = self.message.to(device=x.device)
else:
Expand Down
20 changes: 13 additions & 7 deletions src/scripts/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
import torch


def convert(checkpoint: Path, outdir: Path, suffix: str = "base"):
def convert(checkpoint: str, outdir: str, suffix: str = "base"):
"""Convert the checkpoint to generator and detector"""
outdir_path = Path(outdir)
ckpt = torch.load(checkpoint)

# keep inference-related params only
Expand All @@ -27,16 +28,21 @@ def convert(checkpoint: Path, outdir: Path, suffix: str = "base"):

for layer in ckpt["model"].keys():
if layer.startswith("detector"):
detector_ckpt["model"][layer] = ckpt["model"][layer]
new_layer = layer[9:]
detector_ckpt["model"][new_layer] = ckpt["model"][layer] # type: ignore
elif layer == "msg_processor.msg_processor.0.weight":
generator_ckpt["model"]["msg_processor.msg_processor.weight"] = ckpt[
generator_ckpt["model"]["msg_processor.msg_processor.weight"] = ckpt[ # type: ignore
"model"
][layer]
][
layer
]
else:
generator_ckpt["model"][layer] = ckpt["model"][layer]
assert layer.startswith("generator"), f"Invalid layer: {layer}"
new_layer = layer[10:]
generator_ckpt["model"][new_layer] = ckpt["model"][layer] # type: ignore

torch.save(generator_ckpt, outdir / (checkpoint.stem + f"_generator_{suffix}.pth"))
torch.save(detector_ckpt, outdir / (checkpoint.stem + f"_detector_{suffix}.pth"))
torch.save(generator_ckpt, outdir_path / (f"checkpoint_generator_{suffix}.pth"))
torch.save(detector_ckpt, outdir_path / (f"checkpoint_detector_{suffix}.pth"))


if __name__ == "__main__":
Expand Down

0 comments on commit f95f92e

Please sign in to comment.