From 31f05c74f27876738e01786d3909c03706175a34 Mon Sep 17 00:00:00 2001
From: Tuan Tran
Date: Mon, 29 Jan 2024 13:30:26 +0000
Subject: [PATCH] Update message decoder APIs

---
 CODE_OF_CONDUCT.md              | 30 +++++++++---------
 examples/convert_checkpoints.py | 36 ----------------------
 src/audioseal/models.py         | 54 ++++++++++++++++++++++++++++-----
 tests/test_models.py            | 25 ++++++++-------
 4 files changed, 75 insertions(+), 70 deletions(-)
 delete mode 100644 examples/convert_checkpoints.py

diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md
index c4a3c1d..cf9dc24 100644
--- a/CODE_OF_CONDUCT.md
+++ b/CODE_OF_CONDUCT.md
@@ -14,22 +14,22 @@ appearance, race, religion, or sexual identity and orientation.
 Examples of behavior that contributes to creating a positive environment
 include:
 
-- Using welcoming and inclusive language
-- Being respectful of differing viewpoints and experiences
-- Gracefully accepting constructive criticism
-- Focusing on what is best for the community
-- Showing empathy towards other community members
+* Using welcoming and inclusive language
+* Being respectful of differing viewpoints and experiences
+* Gracefully accepting constructive criticism
+* Focusing on what is best for the community
+* Showing empathy towards other community members
 
 Examples of unacceptable behavior by participants include:
 
-- The use of sexualized language or imagery and unwelcome sexual attention or
-  advances
-- Trolling, insulting/derogatory comments, and personal or political attacks
-- Public or private harassment
-- Publishing others' private information, such as a physical or electronic
-  address, without explicit permission
-- Other conduct which could reasonably be considered inappropriate in a
-  professional setting
+* The use of sexualized language or imagery and unwelcome sexual attention or
+advances
+* Trolling, insulting/derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or electronic
+address, without explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+professional setting
 
 ## Our Responsibilities
 
@@ -59,7 +59,7 @@ the project or its community.
 ## Enforcement
 
 Instances of abusive, harassing, or otherwise unacceptable behavior may be
-reported by contacting the project team at . All
+reported by contacting the project team at . All
 complaints will be reviewed and investigated and will result in a response that
 is deemed necessary and appropriate to the circumstances. The project team is
 obligated to maintain confidentiality with regard to the reporter of an incident.
@@ -77,4 +77,4 @@ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
 [homepage]: https://www.contributor-covenant.org
 
 For answers to common questions about this code of conduct, see
-https://www.contributor-covenant.org/faq
+https://www.contributor-covenant.org/faq
\ No newline at end of file

diff --git a/examples/convert_checkpoints.py b/examples/convert_checkpoints.py
deleted file mode 100644
index a2160c0..0000000
--- a/examples/convert_checkpoints.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-from pathlib import Path
-
-import torch
-
-
-def convert(checkpoint: Path, outdir: Path):
-    """Convert the checkpoint to generator and detector"""
-    ckpt = torch.load(checkpoint)
-    generator_ckpt = {"xp.cfg": ckpt["xp.cfg"], "model": {}}
-    detector_ckpt = {"xp.cfg": ckpt["xp.cfg"], "model": {}}
-
-    for layer in ckpt["model"].keys():
-        if layer.startswith("detector"):
-            detector_ckpt["model"][layer] = ckpt["model"][layer]
-        elif layer == "msg_processor.msg_processor.0.weight":
-            generator_ckpt["model"]["msg_processor.msg_processor.weight"] = ckpt[
-                "model"
-            ][layer]
-        else:
-            generator_ckpt["model"][layer] = ckpt["model"][layer]
-
-    torch.save(generator_ckpt, outdir / (checkpoint.stem + "_generator.pth"))
-    torch.save(detector_ckpt, outdir / (checkpoint.stem + "_detector.pth"))
-
-
-if __name__ == "__main__":
-    import func_argparse
-
-    func_argparse.single_main(convert)

diff --git a/src/audioseal/models.py b/src/audioseal/models.py
index 96a5ce4..e223315 100644
--- a/src/audioseal/models.py
+++ b/src/audioseal/models.py
@@ -142,9 +142,9 @@ class AudioSealDetector(torch.nn.Module):
     Detect the watermarking from an audio signal
     Args:
         SEANetEncoderKeepDimension (_type_): _description_
-        nbits (int): The number of bits in the secret message. The watermarks (if detected)
-            will have size 2 + nbits, where the first two items indicate the possibilities
-            of a true watermarking (positive / negative scores), he rest is used to decode
+        nbits (int): The number of bits in the secret message. The result will have size
+            of 2 + nbits, where the first two items indicate the probabilities of the
+            audio being watermarked (positive / negative scores); the rest is used to decode
             the secret message. In 0bit watermarking (no secret message), the detector
             just returns 2 values.
     """
@@ -154,13 +154,51 @@ def __init__(self, *args, nbits: int = 0, **kwargs):
         encoder = SEANetEncoderKeepDimension(*args, **kwargs)
         last_layer = torch.nn.Conv1d(encoder.output_dim, 2 + nbits, 1)
         self.detector = torch.nn.Sequential(encoder, last_layer)
+        self.nbits = nbits
 
-    def decode_message(self, result: torch.Tensor): ...
+    def detect_watermark(
+        self, x: torch.Tensor, message_threshold: float = 0.6
+    ) -> Tuple[float, torch.Tensor]:
+        """
+        A convenience function that returns the probability of the audio being watermarked,
+        together with the decoded message in n-bit (binary) format. If the audio is not
+        watermarked, the message will be all zeros.
+
+        Args:
+            x: Audio signal, size batch x frames
+            message_threshold: threshold used to convert the watermark output (probability
+                of each bit being 0 or 1) into the binary n-bit message.
+ """ + result, message = self.forward(x) # b x 2+nbits + detected = torch.count_nonzero(torch.gt(result[:, 1, :], 0.5)) / result.shape[-1] + detect_prob = detected.cpu().item() # type: ignore + message = torch.gt(message, message_threshold).int() + return detect_prob, message - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + def decode_message(self, result: torch.Tensor) -> torch.Tensor: + """ + Decode the message from the watermark result (batch x nbits x frames) + Args: + result: watermark result (batch x nbits x frames) + Returns: + The message of size batch x nbits, indicating probability of 1 for each bit + """ + assert ( + (result.dim() > 2 and result.shape[1] == self.nbits) or + (self.dim() == 2 and result.shape[0] == self.nbits) + ), f"Expect message of size [,{self.nbits}, frames] (get {result.size()})" + decoded_message = result.mean(dim=-1) + return torch.sigmoid(decoded_message) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Detect the watermarks from the audio signal + + Args: + x: Audio signal, size batch x frames + """ result = self.detector(x) # b x 2+nbits # hardcode softmax on 2 first units used for detection result[:, :2, :] = torch.softmax(result[:, :2, :], dim=1) - - # TODO: Return the result and the message as a tuple - return result[:, :2, :], result[:, 2:, :] + message = self.decode_message(result[:, 2:, :]) + return result[:, :2, :], message diff --git a/tests/test_models.py b/tests/test_models.py index 5e9e6f5..f9f839e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -24,24 +24,27 @@ def example_audio(tmp_path): wav, _ = torchaudio.load(tmp_path / "test.wav") # Add batch dimension - wav = wav.unsqueeze(0) - - yield wav + yield wav.unsqueeze(0) def test_detector(example_audio): - print(example_audio.size()) - model = AudioSeal.load_generator("audioseal_wm_16bits") - secret_message = torch.randint(0, 2, (1, 16)) + secret_message = torch.randint(0, 2, (1, 16), dtype=torch.int32) watermark = model(example_audio, message=secret_message, alpha=0.8) watermarked_audio = example_audio + watermark detector = AudioSeal.load_detector(("audioseal_detector_16bits")) - result, message = detector(watermarked_audio) # noqa - - pred_prob = torch.count_nonzero(torch.gt(result[:, 1, :], 0.5)) / result.shape[-1] - - assert pred_prob.item() > 0.7 + result, message = detector.detect_watermark(watermarked_audio) # noqa + + # Due to non-deterministic decoding, messages are not always the same as message + print( + "Matching bits in decoded and original messages: " + f"{torch.count_nonzero(torch.eq(message, secret_message)).item()}\n" + ) + assert result > 0.7 + + # Try to detect the unwatermarked audio + result, message = detector.detect_watermark(example_audio) # noqa + assert torch.all(message == 0)