Skip to content

Commit

Permalink
sync patch to upstream repo (#31)
Browse files Browse the repository at this point in the history
Co-authored-by: Tuan Tran <{ID}+{username}@users.noreply.github.com>
  • Loading branch information
antoine-tran and Tuan Tran authored Apr 30, 2024
1 parent 3cd73dd commit 673c000
Show file tree
Hide file tree
Showing 13 changed files with 132 additions and 73 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint_and_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ jobs:
run: |
sudo apt-get install libsndfile1
python -m pip install --upgrade pip
pip install torch==1.13.0 torchaudio==0.13.0 func_argparse soundfile pytest omegaconf numpy julius
pip install torch==1.13.0 torchaudio==0.13.0 func_argparse soundfile pytest omegaconf numpy julius huggingface_hub
pip install --no-deps -e .
- name: pytest_unit
run: pytest -s -v tests/test_models.py
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -102,5 +102,8 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# Visual Studio Code
.vscode/

# local training outputs
outputs/*
3 changes: 1 addition & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@ repos:
- id: end-of-file-fixer

- repo: https://github.com/psf/black
rev: 24.1.1
rev: 24.4.2
hooks:
- id: black
language_version: python3.8

- repo: https://github.com/pycqa/isort
rev: 5.12.0
Expand Down
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ All notable changes to AudioSeal are documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.1.3] - 2024-04-30
- Fix bug in getting the watermark with non-empty message created in CPU, while the model is loaded in CUDA
- Update Fix bug in building the model card programmatically (not via .YAML file using OmegaConf)
- Add support for HuggingFace Hub, now we can load the model from HF. Unit tests are updated


## [0.1.2] - 2024-02-29
- Add py.typed to make audioseal mypy-friendly
- Add the option to resample the input audio's sample rate to the expected sample rate of the model (https://github.com/facebookresearch/audioseal/pull/18)
Expand Down
6 changes: 3 additions & 3 deletions examples/attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ def echo(
# Define a few reflections with decreasing amplitude
impulse_response[0] = 1.0 # Direct sound

impulse_response[
int(sample_rate * duration) - 1
] = volume # First reflection after 100ms
impulse_response[int(sample_rate * duration) - 1] = (
volume # First reflection after 100ms
)

# Add batch and channel dimensions to the impulse response
impulse_response = impulse_response.unsqueeze(0).unsqueeze(0)
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ black
isort
flake8
pre-commit
huggingface_hub
2 changes: 1 addition & 1 deletion src/audioseal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"""

__version__ = "0.1.2"
__version__ = "0.1.3"


from audioseal import builder
Expand Down
24 changes: 18 additions & 6 deletions src/audioseal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from dataclasses import asdict, dataclass, is_dataclass
from typing import Any, Dict, List, Mapping, Optional

from omegaconf import DictConfig, OmegaConf
from torch import device, dtype
from typing_extensions import TypeAlias

Expand Down Expand Up @@ -71,6 +72,17 @@ class AudioSealDetectorConfig:
detector: DetectorConfig


def as_dict(obj: Any) -> Dict[str, Any]:
if isinstance(obj, dict):
return obj
if is_dataclass(obj):
return asdict(obj)
elif isinstance(obj, DictConfig):
return OmegaConf.to_container(obj) # type: ignore
else:
raise NotImplementedError(f"Unsupported type for config: {type(obj)}")


def create_generator(
config: AudioSealWMConfig,
*,
Expand All @@ -81,11 +93,11 @@ def create_generator(

# Currently the encoder hparams are the same as
# SEANet, but this can be changed in the future.
encoder = audiocraft.modules.SEANetEncoder(**config.seanet) # type: ignore[arg-type]
encoder = audiocraft.modules.SEANetEncoder(**as_dict(config.seanet))
encoder = encoder.to(device=device, dtype=dtype)

decoder_config = {**config.seanet, **config.decoder} # type: ignore
decoder = audiocraft.modules.SEANetDecoder(**decoder_config) # type: ignore[arg-type]
decoder_config = {**as_dict(config.seanet), **as_dict(config.decoder)}
decoder = audiocraft.modules.SEANetDecoder(**as_dict(decoder_config))
decoder = decoder.to(device=device, dtype=dtype)

msgprocessor = MsgProcessor(nbits=config.nbits, hidden_size=config.seanet.dimension)
Expand All @@ -100,7 +112,7 @@ def create_detector(
device: Optional[Device] = None,
dtype: Optional[DataType] = None,
) -> AudioSealDetector:
detector_config = {**config.seanet, **config.detector} # type: ignore
detector_config = {**as_dict(config.seanet), **as_dict(config.detector)}
detector = AudioSealDetector(nbits=config.nbits, **detector_config)
detector = detector.to(device=device, dtype=dtype)
return detector
2 changes: 1 addition & 1 deletion src/audioseal/libs/audiocraft/modules/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
#
# Vendor from https://github.com/facebookresearch/audiocraft

import math
Expand Down
1 change: 1 addition & 0 deletions src/audioseal/libs/audiocraft/modules/seanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def forward(self, x):
# make sure dim didn't change
return x[:, :, :orig_nframes]


class SEANetDecoder(nn.Module):
"""SEANet decoder.
Expand Down
112 changes: 68 additions & 44 deletions src/audioseal/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import torch
from omegaconf import DictConfig, OmegaConf

import audioseal
from audioseal.builder import (
AudioSealDetectorConfig,
AudioSealWMConfig,
Expand Down Expand Up @@ -80,11 +81,29 @@ def load_model_checkpoint(
parts = urlparse(str(model_path))
if parts.scheme == "https":

# TODO: Add HF Hub
hash_ = sha1(parts.path.encode()).hexdigest()[:24]
return torch.hub.load_state_dict_from_url(
str(model_path), model_dir=cache_dir, map_location=device, file_name=hash_
)
elif str(model_path).startswith("facebook/audioseal/"):
hf_filename = str(model_path)[len("facebook/audioseal/") :]

try:
from huggingface_hub import hf_hub_download
except ModuleNotFoundError:
print(
f"The model path {model_path} seems to be a direct HF path, "
"but you do not install Huggingface_hub. Install with for example "
"`pip install huggingface_hub` to use this feature."
)
file = hf_hub_download(
repo_id="facebook/audioseal",
filename=hf_filename,
cache_dir=cache_dir,
library_name="audioseal",
library_version=audioseal.__version__,
)
return torch.load(file, map_location=device)
else:
raise ModelLoadError(f"Path or uri {model_path} is unknown or does not exist")

Expand All @@ -100,7 +119,7 @@ def load_local_model_config(model_card: str) -> Optional[DictConfig]:
class AudioSeal:

@staticmethod
def _parse_model(
def parse_model(
model_card_or_path: str,
model_type: Type[AudioSealT],
nbits: Optional[int] = None,
Expand All @@ -126,64 +145,67 @@ def _parse_model(
config_dict = {}
checkpoint = load_model_checkpoint(model_card_or_path)

# If the checkpoint has config in its, take this but uses the info
# in the mode as precedence
assert isinstance(
checkpoint, dict
), f"Expect loaded checkpoint to be a dictionary, get {type(checkpoint)}"
assert isinstance(
config_dict, dict
), f"Except loaded config to be a dictionary, get {type(config_dict)}"
if "xp.cfg" in checkpoint:
config = {**checkpoint["xp.cfg"], **config_dict} # type: ignore
assert config is not None
assert (
"seanet" in config
), f"missing seanet backbone config in {model_card_or_path}"

# Patch 1: Resolve the variables in the checkpoint
config = OmegaConf.create(config)
OmegaConf.resolve(config)
config = OmegaConf.to_container(config) # type: ignore

# Patch 2: Put decoder, encoder and detector outside seanet
seanet_config = config["seanet"]
for key_to_patch in ["encoder", "decoder", "detector"]:
if key_to_patch in seanet_config:
config_to_patch = config.get(key_to_patch) or {}
config[key_to_patch] = {
**config_to_patch,
**seanet_config.pop(key_to_patch),
}

config["seanet"] = seanet_config

# Patch 3: Put nbits into config if specified
if nbits and "nbits" not in config:
config["nbits"] = nbits
config_dict = {**checkpoint["xp.cfg"], **config_dict} # type: ignore

model_config = AudioSeal.parse_config(config_dict, config_type=model_type, nbits=nbits) # type: ignore

if "model" in checkpoint:
checkpoint = checkpoint["model"]

return checkpoint, model_config

@staticmethod
def parse_config(
config: Dict[str, Any],
config_type: Type[AudioSealT],
nbits: Optional[int] = None,
) -> AudioSealT:

assert "seanet" in config, f"missing seanet backbone config in {config}"

# Patch 1: Resolve the variables in the checkpoint
config = OmegaConf.create(config) # type: ignore
OmegaConf.resolve(config) # type: ignore
config = OmegaConf.to_container(config) # type: ignore

# Patch 2: Put decoder, encoder and detector outside seanet
seanet_config = config["seanet"]
for key_to_patch in ["encoder", "decoder", "detector"]:
if key_to_patch in seanet_config:
config_to_patch = config.get(key_to_patch) or {}
config[key_to_patch] = {
**config_to_patch,
**seanet_config.pop(key_to_patch),
}

config["seanet"] = seanet_config

# Patch 3: Put nbits into config if specified
if nbits and "nbits" not in config:
config["nbits"] = nbits

# remove attributes not related to the model_type
result_config = {}
assert config, f"Empty config in {model_card_or_path}"
for field in fields(model_type):
assert config, f"Empty config"
for field in fields(config_type):
if field.name in config:
result_config[field.name] = config[field.name]

schema = OmegaConf.structured(model_type)
schema = OmegaConf.structured(config_type)
schema.merge_with(result_config)
return checkpoint, schema
return schema

@staticmethod
def load_generator(
model_card_or_path: str,
nbits: Optional[int] = None,
) -> AudioSealWM:
"""Load the AudioSeal generator from the model card"""
checkpoint, config = AudioSeal._parse_model(
model_card_or_path, AudioSealWMConfig, nbits=nbits,
checkpoint, config = AudioSeal.parse_model(
model_card_or_path,
AudioSealWMConfig,
nbits=nbits,
)

model = create_generator(config)
Expand All @@ -195,8 +217,10 @@ def load_detector(
model_card_or_path: str,
nbits: Optional[int] = None,
) -> AudioSealDetector:
checkpoint, config = AudioSeal._parse_model(
model_card_or_path, AudioSealDetectorConfig, nbits=nbits,
checkpoint, config = AudioSeal.parse_model(
model_card_or_path,
AudioSealDetectorConfig,
nbits=nbits,
)
model = create_detector(config)
model.load_state_dict(checkpoint)
Expand Down
33 changes: 18 additions & 15 deletions src/audioseal/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,22 +109,24 @@ def get_watermark(
hidden = self.encoder(x)

if self.msg_processor is not None:
if message is None:
self.message = self.message or torch.randint(
0, 2, (x.shape[0], self.msg_processor.nbits), device=x.device
)
message = self.message
if self.message is None:
self.message = torch.randint(0, 2, (x.shape[0], self.msg_processor.nbits), device=x.device)
else:
self.message = self.message.to(device=x.device)


message = self.message

hidden = self.msg_processor(hidden, message)

watermark = self.decoder(hidden)

if sample_rate != 16000:
watermark = julius.resample_frac(watermark, old_sr=16000, new_sr=sample_rate)
watermark = julius.resample_frac(
watermark, old_sr=16000, new_sr=sample_rate
)

return watermark[
..., : length
] # trim output cf encodec codebase
return watermark[..., :length] # trim output cf encodec codebase

def forward(
self,
Expand Down Expand Up @@ -164,7 +166,7 @@ def detect_watermark(
self,
x: torch.Tensor,
sample_rate: Optional[int] = None,
message_threshold: float = 0.5
message_threshold: float = 0.5,
) -> Tuple[float, torch.Tensor]:
"""
A convenience function that returns a probability of an audio being watermarked,
Expand All @@ -174,13 +176,15 @@ def detect_watermark(
x: Audio signal, size: batch x frames
sample_rate: The sample rate of the input audio
message_threshold: threshold used to convert the watermark output (probability
of each bits being 0 or 1) into the binary n-bit message.
of each bits being 0 or 1) into the binary n-bit message.
"""
if sample_rate is None:
logger.warning(COMPATIBLE_WARNING)
sample_rate = 16_000
result, message = self.forward(x, sample_rate=sample_rate) # b x 2+nbits
detected = torch.count_nonzero(torch.gt(result[:, 1, :], 0.5)) / result.shape[-1]
detected = (
torch.count_nonzero(torch.gt(result[:, 1, :], 0.5)) / result.shape[-1]
)
detect_prob = detected.cpu().item() # type: ignore
message = torch.gt(message, message_threshold).int()
return detect_prob, message
Expand All @@ -193,9 +197,8 @@ def decode_message(self, result: torch.Tensor) -> torch.Tensor:
Returns:
The message of size batch x nbits, indicating probability of 1 for each bit
"""
assert (
(result.dim() > 2 and result.shape[1] == self.nbits) or
(self.dim() == 2 and result.shape[0] == self.nbits)
assert (result.dim() > 2 and result.shape[1] == self.nbits) or (
self.dim() == 2 and result.shape[0] == self.nbits
), f"Expect message of size [,{self.nbits}, frames] (get {result.size()})"
decoded_message = result.mean(dim=-1)
return torch.sigmoid(decoded_message)
Expand Down
Loading

0 comments on commit 673c000

Please sign in to comment.