Publish internal code to upstream (#36)
Co-authored-by: Tuan Tran <{ID}+{username}@users.noreply.github.com>
antoine-tran and Tuan Tran authored Jun 18, 2024
1 parent 391ae43 commit 3cff28b
Showing 6 changed files with 195 additions and 9 deletions.
7 changes: 4 additions & 3 deletions CHANGELOG.md
@@ -5,12 +5,13 @@ All notable changes to AudioSeal are documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.1.3] - 2024-04-30

- Fix a bug in getting the watermark when a non-empty message is created on CPU while the model is loaded on CUDA
- Fix a bug in building the model card programmatically (not via a YAML file using OmegaConf)
- Add support for the Hugging Face Hub; the model can now be loaded from HF. Unit tests are updated

## [0.1.2] - 2024-02-29

- Add py.typed to make audioseal mypy-friendly
- Add the option to resample the input audio's sample rate to the expected sample rate of the model (https://github.com/facebookresearch/audioseal/pull/18)
- Move `attacks.py` to non-core code base of audioseal
12 changes: 9 additions & 3 deletions README.md
@@ -13,6 +13,8 @@ More details can be found in the [paper](https://arxiv.org/abs/2401.17264).

# Updates:

- 2024-06-17: Training code is now available. Check the [instructions](./docs/TRAINING.md)!
- 2024-05-31: Our paper was accepted at ICML'24 :)
- 2024-04-02: We have updated our license to a full MIT license (including the license for the model weights)! You can now use AudioSeal in commercial applications too!
- 2024-02-29: AudioSeal 0.1.2 is out, with more bug fixes for resampled audio and updated notebooks

@@ -101,6 +103,10 @@ print(result[:, 1 , :])
print(message)
```

# Train your own watermarking model

See [here](./docs/TRAINING.md) for details on how to train your own watermarking model.

# Want to contribute?

We welcome Pull Requests with improvements or suggestions.
@@ -115,8 +121,8 @@ dummy batch dimension to your input (e.g. `wav.unsqueeze(0)`, see [example noteb
uploaded to the model hub, which is not compatible with Windows. Try invalidating the cache by removing the files in `C:\Users\<USER>\.cache\audioseal`
and re-running (see the sketch below).
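
A minimal cross-platform sketch that clears this cache (the cache path is the one assumed in the note above):

```python
import shutil
from pathlib import Path

# Remove cached AudioSeal checkpoints so they are re-downloaded on the next run
cache_dir = Path.home() / ".cache" / "audioseal"
shutil.rmtree(cache_dir, ignore_errors=True)
```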

- If you use torchaudio to handle your audio and encounter the error `Couldn't find appropriate backend to handle uri ...`, this is because newer versions of
torchaudio do not handle the default backend well. Either downgrade your torchaudio to `2.1.0` or earlier, or install `soundfile` as your audio backend.

# License

@@ -136,7 +142,7 @@ If you find this repository useful, please consider giving a star :star: and ple
@article{sanroman2024proactive,
  title={Proactive Detection of Voice Cloning with Localized Watermarking},
  author={San Roman, Robin and Fernandez, Pierre and Elsahar, Hady and D{\'e}fossez, Alexandre and Furon, Teddy and Tran, Tuan},
  journal={ICML},
  year={2024}
}
```
129 changes: 129 additions & 0 deletions docs/TRAINING.md
@@ -0,0 +1,129 @@
# Training a new watermarking model

This doc shows how to train a new AudioSeal model. The training pipeline was developed with [AudioCraft](https://github.com/facebookresearch/audiocraft) (version 1.4.0 and later). The following example was tested with PyTorch==2.1.0 and torchaudio==2.1.0.
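
To pin the tested versions up front, one possible command (a suggestion, not part of the original instructions):

```bash
# Install the PyTorch/torchaudio versions the example was tested with
pip install torch==2.1.0 torchaudio==2.1.0
```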

## Prerequisites

We need AudioCraft >=1.4.0a1. If you want to experiment with different datasets and training recipes, we advise downloading the AudioCraft source code and installing it from source; see the [Installation notes](https://github.com/facebookresearch/audiocraft/blob/main/README.md#installation):

```bash
git clone https://github.com/facebookresearch/audiocraft.git
cd audiocraft
pip install -e .

sudo apt-get install ffmpeg
# Or if you are using Anaconda or Miniconda
conda install "ffmpeg<5" -c conda-forge
```

Note that installing ffmpeg (<5.0.0) as described in the notes is mandatory; otherwise the training loop will fail, since our AAC augmentation step depends on it.
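
As a quick sanity check (a suggested command, not from the installation notes), confirm that ffmpeg is on your PATH and below version 5:

```bash
# Should print a version line such as "ffmpeg version 4.x ..."
ffmpeg -version | head -n 1
```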

## Preparing the dataset

The dataset should be processed into the AudioCraft [format](https://github.com/facebookresearch/audiocraft/blob/main/docs/DATASETS.md). The first step is to create the manifest for your dataset. For VoxPopuli (which is used in the paper), run the following commands:

```bash

# Download the raw audios and segment them
git clone https://github.com/facebookresearch/voxpopuli.git
cd voxpopuli
python -m voxpopuli.download_audios --root [ROOT] --subset 400k
python -m voxpopuli.get_unlabelled_data --root [ROOT] --subset 400k

# Run audiocraft data tool to prepare the manifest
cd [PATH to audiocraft]
python -m audiocraft.data.audio_dataset [ROOT] egs/voxpopuli/data.jsonl.gz
```
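
To sanity-check the manifest, you can peek at its first entry (a minimal sketch; the exact fields depend on your AudioCraft version):

```python
import gzip
import json

# Read the first record of the manifest written by audiocraft.data.audio_dataset
with gzip.open("egs/voxpopuli/data.jsonl.gz", "rt") as f:
    entry = json.loads(f.readline())
print(entry)  # typically includes the audio path, duration and sample rate
```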

Then, prepare the following datasource definition and put it in `[audiocraft root]/configs/dset/audio/voxpopuli.yaml`:

```yaml
# @package __global__

datasource:
  max_sample_rate: 16000
  max_channels: 1

  train: egs/voxpopuli
  valid: egs/voxpopuli
  evaluate: egs/voxpopuli
  generate: egs/voxpopuli
```
## Training

The training pipeline uses [Dora](https://github.com/facebookresearch/dora) to structure the experiments and perform grid-based parameter tuning. It is useful to familiarize yourself with Dora concepts such as `dora run` and `dora grid` before starting.

To test the training pipeline locally, see [this documentation in AudioCraft](https://github.com/facebookresearch/audiocraft/blob/main/docs/WATERMARKING.md). You can replace the example dataset with the VoxPopuli one above, e.g. run the following command from within the cloned AudioCraft directory:
```bash
dora run solver=watermark/robustness dset=audio/example
```

By default, checkpoints and experiment files are stored in `/tmp/audiocraft_$USER/outputs`. To customize where your Dora outputs and experiment folders live, as well as to run on a SLURM cluster, define a config file with the following structure:

```yaml
# File name: my_config.yaml

default:
  dora_dir: [DORA PATH]
  partitions:
    global: your_slurm_partitions
    team: your_slurm_partitions
  reference_dir: /tmp
darwin:  # if we detect we are on a Mac, then most likely we are doing unit testing etc.
  dora_dir: [YOUR PATH]
  partitions:
    global: your_slurm_partitions
    team: your_slurm_partitions
  reference_dir: [REFERENCE PATH]
```
where `partitions` indicates the SLURM partitions on which you are entitled to run your jobs. Then re-run the `dora run` command with the custom config:

```bash
AUDIOCRAFT_CONFIG=my_config.yaml dora run solver=watermark/robustness dset=audio/voxpopuli
```

## Evaluate the checkpoint

If successful, the checkpoints will be stored in an experiment folder in your Dora dir, i.e. `[DORA_PATH]/xps/[HASH-ID]/checkpoint_XXX.th`, where `HASH-ID` is the ID of the experiment shown in the output log when running `dora run`. You can evaluate your checkpoints with different settings for nbits and choose the one with the lowest losses:

```bash
AUDIOCRAFT_CONFIG=my_config.yaml dora run solver=watermark/robustness execute_only=evaluate dset=audio/voxpopuli continue_from=[PATH_TO_THE_CHECKPOINT_FILE] +dummy_watermarker.nbits=16 seanet.detector.output_dim=32
```

## Postprocessing the checkpoints for inference

The checkpoint contains the jointly trained generator and detector, so it cannot be used directly with the AudioSeal API. To extract the generator and detector, run the conversion script `src/scripts/checkpoints.py` in the AudioSeal code base:

```bash
python [AudioSeal path]/src/scripts/checkpoints.py --checkpoint=[PATH TO CHECKPOINT] --outdir=[OUTPUT_DIR] --suffix=[name of the new model]
```

After this step, the output directory `[OUTPUT_DIR]` will contain two checkpoint files named `[checkpoint stem]_generator_[suffix].pth` and `[checkpoint stem]_detector_[suffix].pth`. You can use these new checkpoints directly with the AudioSeal API, for instance:

```python
model = AudioSeal.load_generator("[OUTPUT_DIR]/[checkpoint stem]_generator_[suffix].pth", nbits=16)
watermark = model.get_watermark(wav, sr)
detector = AudioSeal.load_detector("[OUTPUT_DIR]/[checkpoint stem]_detector_[suffix].pth", nbits=16)
result, message = detector(watermarked_audio, sr)
```
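
For context, here is a minimal end-to-end sketch (the file `sample.wav` and `nbits=16` are illustrative assumptions, not part of the original instructions):

```python
import torchaudio

from audioseal import AudioSeal

# Load audio and add a batch dimension: the API expects [batch, channels, samples]
wav, sr = torchaudio.load("sample.wav")  # "sample.wav" is a placeholder
wav = wav.unsqueeze(0)

generator = AudioSeal.load_generator("[OUTPUT_DIR]/[checkpoint stem]_generator_[suffix].pth", nbits=16)
watermark = generator.get_watermark(wav, sr)
watermarked_audio = wav + watermark

detector = AudioSeal.load_detector("[OUTPUT_DIR]/[checkpoint stem]_detector_[suffix].pth", nbits=16)
result, message = detector.detect_watermark(watermarked_audio, sr)
print(result)   # probability that the audio carries a watermark
print(message)  # the decoded message bits
```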

## Training the HF AudioSeal model

We also provide the hyperparameters and training config (in Dora terms, a "grid") to reproduce our checkpoint for AudioSeal on the Hugging Face Hub (which is also the one used to produce the results reported in the ICML paper). To get this, check AudioCraft's watermarking [grid](https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/grids/watermarking/1315_kbits_seeds.py). To reproduce the result, run the `dora grid` command:

```bash
AUDIOCRAFT_CONFIG=my_config.yaml AUDIOCRAFT_DSET=audio/voxpopuli dora grid watermarking.1315_kbits_seeds
```

## Troubleshooting

1. If you encounter the error `Unsupported formats` on Linux, ffmpeg is not properly installed or is superseded by other backends on your system. Try instructing Dora to explicitly use the libraries installed in your environment, i.e. add them to `LD_LIBRARY_PATH`. If you use Anaconda, you can try:

```bash
LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH AUDIOCRAFT_CONFIG=my_config.yaml [dora run/grid command]
```
7 changes: 6 additions & 1 deletion examples/attacks.py
Expand Up @@ -3,7 +3,12 @@
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

#
# Example attacks using different audio effects.
# For the full list of attacks, check
# https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/audio_effects.py
#
import typing as tp

import julius
4 changes: 2 additions & 2 deletions requirements-dev.txt
@@ -1,5 +1,5 @@
# For developers wanting to contribute to AudioSeal
fire
torchaudio
soundfile
pytest
@@ -8,4 +8,4 @@ black
isort
flake8
pre-commit
huggingface_hub
45 changes: 45 additions & 0 deletions src/scripts/checkpoints.py
@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


from pathlib import Path

import torch


def convert(checkpoint: Path, outdir: Path, suffix: str = "base"):
    """Convert the checkpoint to generator and detector"""
    # fire passes CLI arguments as plain strings, so coerce them to Path
    checkpoint = Path(checkpoint)
    outdir = Path(outdir)
    ckpt = torch.load(checkpoint)

    # keep inference-related params only
    infer_cfg = {
        "seanet": ckpt["xp.cfg"]["seanet"],
        "channels": ckpt["xp.cfg"]["channels"],
        "dtype": ckpt["xp.cfg"]["dtype"],
        "sample_rate": ckpt["xp.cfg"]["sample_rate"],
    }

    generator_ckpt = {"xp.cfg": infer_cfg, "model": {}}
    detector_ckpt = {"xp.cfg": infer_cfg, "model": {}}

    # split the jointly trained model: detector weights go to the detector
    # checkpoint, everything else to the generator checkpoint
    for layer in ckpt["model"].keys():
        if layer.startswith("detector"):
            detector_ckpt["model"][layer] = ckpt["model"][layer]
        elif layer == "msg_processor.msg_processor.0.weight":
            # rename the message-processor weight to the key expected at inference
            generator_ckpt["model"]["msg_processor.msg_processor.weight"] = ckpt["model"][layer]
        else:
            generator_ckpt["model"][layer] = ckpt["model"][layer]

    torch.save(generator_ckpt, outdir / (checkpoint.stem + f"_generator_{suffix}.pth"))
    torch.save(detector_ckpt, outdir / (checkpoint.stem + f"_detector_{suffix}.pth"))


if __name__ == "__main__":
    import fire

    fire.Fire(convert)
