Skip to content

Commit

Permalink
Update message decoder APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
Tuan Tran committed Jan 29, 2024
1 parent b9a3036 commit 31f05c7
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 70 deletions.
30 changes: 15 additions & 15 deletions CODE_OF_CONDUCT.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,22 @@ appearance, race, religion, or sexual identity and orientation.
Examples of behavior that contributes to creating a positive environment
include:

- Using welcoming and inclusive language
- Being respectful of differing viewpoints and experiences
- Gracefully accepting constructive criticism
- Focusing on what is best for the community
- Showing empathy towards other community members
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members

Examples of unacceptable behavior by participants include:

- The use of sexualized language or imagery and unwelcome sexual attention or
advances
- Trolling, insulting/derogatory comments, and personal or political attacks
- Public or private harassment
- Publishing others' private information, such as a physical or electronic
address, without explicit permission
- Other conduct which could reasonably be considered inappropriate in a
professional setting
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting

## Our Responsibilities

Expand Down Expand Up @@ -59,7 +59,7 @@ the project or its community.
## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@fb.com>. All
reported by contacting the project team at <opensource-conduct@meta.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Expand All @@ -77,4 +77,4 @@ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.ht
[homepage]: https://www.contributor-covenant.org

For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq
https://www.contributor-covenant.org/faq
36 changes: 0 additions & 36 deletions examples/convert_checkpoints.py

This file was deleted.

54 changes: 46 additions & 8 deletions src/audioseal/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ class AudioSealDetector(torch.nn.Module):
Detect the watermarking from an audio signal
Args:
SEANetEncoderKeepDimension (_type_): _description_
nbits (int): The number of bits in the secret message. The watermarks (if detected)
will have size 2 + nbits, where the first two items indicate the possibilities
of a true watermarking (positive / negative scores), he rest is used to decode
nbits (int): The number of bits in the secret message. The result will have size
of 2 + nbits, where the first two items indicate the possibilities of the
audio being watermarked (positive / negative scores), he rest is used to decode
the secret message. In 0bit watermarking (no secret message), the detector just
returns 2 values.
"""
Expand All @@ -154,13 +154,51 @@ def __init__(self, *args, nbits: int = 0, **kwargs):
encoder = SEANetEncoderKeepDimension(*args, **kwargs)
last_layer = torch.nn.Conv1d(encoder.output_dim, 2 + nbits, 1)
self.detector = torch.nn.Sequential(encoder, last_layer)
self.nbits = nbits

def decode_message(self, result: torch.Tensor): ...
def detect_watermark(
self, x: torch.Tensor, message_threshold: float = 0.6
) -> Tuple[float, torch.Tensor]:
"""
A convenience function that returns a probability of an audio being watermarked,
together with its message in n-bits (binary) format. If the audio is not watermarked,
the message will be all zeros.
Args:
x: Audio signal, size batch x frames
message_threshold: threshold used to convert the watermark output (probability
of each bits being 0 or 1) into the binary n-bit message.
"""
result, message = self.forward(x) # b x 2+nbits
detected = torch.count_nonzero(torch.gt(result[:, 1, :], 0.5)) / result.shape[-1]
detect_prob = detected.cpu().item() # type: ignore
message = torch.gt(message, message_threshold).int()
return detect_prob, message

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
def decode_message(self, result: torch.Tensor) -> torch.Tensor:
"""
Decode the message from the watermark result (batch x nbits x frames)
Args:
result: watermark result (batch x nbits x frames)
Returns:
The message of size batch x nbits, indicating probability of 1 for each bit
"""
assert (
(result.dim() > 2 and result.shape[1] == self.nbits) or
(self.dim() == 2 and result.shape[0] == self.nbits)
), f"Expect message of size [,{self.nbits}, frames] (get {result.size()})"
decoded_message = result.mean(dim=-1)
return torch.sigmoid(decoded_message)

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Detect the watermarks from the audio signal
Args:
x: Audio signal, size batch x frames
"""
result = self.detector(x) # b x 2+nbits
# hardcode softmax on 2 first units used for detection
result[:, :2, :] = torch.softmax(result[:, :2, :], dim=1)

# TODO: Return the result and the message as a tuple
return result[:, :2, :], result[:, 2:, :]
message = self.decode_message(result[:, 2:, :])
return result[:, :2, :], message
25 changes: 14 additions & 11 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,24 +24,27 @@ def example_audio(tmp_path):
wav, _ = torchaudio.load(tmp_path / "test.wav")

# Add batch dimension
wav = wav.unsqueeze(0)

yield wav
yield wav.unsqueeze(0)


def test_detector(example_audio):
print(example_audio.size())

model = AudioSeal.load_generator("audioseal_wm_16bits")

secret_message = torch.randint(0, 2, (1, 16))
secret_message = torch.randint(0, 2, (1, 16), dtype=torch.int32)
watermark = model(example_audio, message=secret_message, alpha=0.8)

watermarked_audio = example_audio + watermark

detector = AudioSeal.load_detector(("audioseal_detector_16bits"))
result, message = detector(watermarked_audio) # noqa

pred_prob = torch.count_nonzero(torch.gt(result[:, 1, :], 0.5)) / result.shape[-1]

assert pred_prob.item() > 0.7
result, message = detector.detect_watermark(watermarked_audio) # noqa

# Due to non-deterministic decoding, messages are not always the same as message
print(
"Matching bits in decoded and original messages: "
f"{torch.count_nonzero(torch.eq(message, secret_message)).item()}\n"
)
assert result > 0.7

# Try to detect the unwatermarked audio
result, message = detector.detect_watermark(example_audio) # noqa
assert torch.all(message == 0)

0 comments on commit 31f05c7

Please sign in to comment.