Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Midi-120: BART #7

Open
wants to merge 19 commits into
base: MIDI-120/unsupervised-training
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions configs/BARTdenoise-dstart.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
train:
num_epochs: 5
accum_iter: 5
batch_size: 2
base_lr: 3e-5
warmup: 4000
finetune: False

model_name: BART
dataset_name: 'roszcz/maestro-v1-sustain'
target: denoise
seed: 26

overfit: False

tokens_per_note: single
time_quantization_method: dstart
masking_probability: 0.3
mask: tokens

encoder: velocity
time_bins: 100

dataset:
sequence_len: 128
sequence_step: 42

quantization:
dstart: 8
duration: 8
velocity: 3

device: "cuda:0"

log: True
log_frequency: 10
run_name: midi-bart-${now:%Y-%m-%d-%H-%M}
project: "midi-bart"

pre_defined_model: null
model:
encoder_layers: 6
encoder_ffn_dim: 2048
encoder_attention_heads: 8
decoder_layers: 6
decoder_ffn_dim: 2048
decoder_attention_heads: 8
d_model: 512
49 changes: 49 additions & 0 deletions configs/BARTdenoise.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
train:
num_epochs: 5
accum_iter: 10
batch_size: 8
base_lr: 5e-7
warmup: 4000
finetune: True

model_name: BART
dataset_name: 'roszcz/maestro-v1-sustain'
target: denoise
seed: 26

overfit: False

tokens_per_note: multiple
time_quantization_method: start
masking_probability: 0.3
mask: notes

encoder: velocity
time_bins: 100

dataset:
sequence_duration: 5
sequence_step: 2

quantization:
start: 100
duration: 5
velocity: 3

device: "cuda:0"

log: True
log_frequency: 10
run_name: midi-bart-${now:%Y-%m-%d-%H-%M}
project: "midi-bart"

pre_defined_model: null

model:
encoder_layers: 6
encoder_ffn_dim: 2048
encoder_attention_heads: 8
decoder_layers: 6
decoder_ffn_dim: 2048
decoder_attention_heads: 8
d_model: 512
44 changes: 44 additions & 0 deletions configs/BARTvelocity.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
train:
num_epochs: 5
accum_iter: 10
batch_size: 8
base_lr: 5e-6
warmup: 4000
finetune: True

pretrained_checkpoint: midi-bart-2023-12-26-19-05.pt
model_name: BART
dataset_name: 'roszcz/maestro-v1-sustain'
target: velocity
seed: 26

overfit: False

tokens_per_note: "multiple"
time_quantization_method: start
dataset:
sequence_duration: 5
sequence_step: 2

quantization:
start: 20
duration: 3
velocity: 3

device: "cuda:0"

log: True
log_frequency: 10
run_name: midi-bart-${now:%Y-%m-%d-%H-%M}
project: "midi-bart"

pre_defined_model: null

model:
encoder_layers: 6
encoder_ffn_dim: 2048
encoder_attention_heads: 8
decoder_layers: 6
decoder_ffn_dim: 2048
decoder_attention_heads: 8
d_model: 512
153 changes: 110 additions & 43 deletions dashboard/denoise/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@
import pandas as pd
import fortepyan as ff
import streamlit as st
from datasets import Dataset
from fortepyan import MidiPiece
from datasets import Dataset, load_dataset
from omegaconf import OmegaConf, DictConfig
from streamlit_pianoroll import from_fortepyan
from transformers import T5Config, T5ForConditionalGeneration
from transformers import T5Config, BartConfig, T5ForConditionalGeneration, BartForConditionalGeneration

from utils import vocab_size
from data.midiencoder import QuantizedMidiEncoder
from data.multitokencoder import MultiMidiEncoder
from data.quantizer import MidiQuantizer, MidiATQuantizer
from data.dataset import MaskedMidiDataset, load_cache_dataset
from data.maskedmidiencoder import MaskedMidiEncoder, MaskedNoteEncoder
from data.maskedmidiencoder import MaskedMidiEncoder, MaskedNoteEncoder, SingleMaskedNoteEncoder
from data.dataset import MaskedMidiDataset, build_translation_dataset, build_AT_translation_dataset

# Set the layout of the Streamlit page
st.set_page_config(layout="wide", page_title="T5 Denoise", page_icon=":musical_keyboard")
# Emoji shortcodes must be wrapped in colons on both sides (":name:");
# without the closing colon the value is not recognised as an emoji.
st.set_page_config(layout="wide", page_title="Denoise MIDI", page_icon=":musical_keyboard:")

with st.sidebar:
devices = ["cpu"] + [f"cuda:{it}" for it in range(torch.cuda.device_count())]
Expand Down Expand Up @@ -74,23 +74,24 @@ def main():
)


def model_predictions_review(
checkpoint: dict,
train_cfg: DictConfig,
):
# load checkpoint, force dashboard device
dataset_cfg: DictConfig = train_cfg.dataset
def dataset_selection(train_cfg: DictConfig):
dataset_name: str = st.text_input(label="dataset", value=train_cfg.dataset_name)
split: str = st.text_input(label="split", value="test")

random_seed: int = st.selectbox(label="random seed", options=range(20))

# load the selected split of the source dataset
val_translation_dataset: Dataset = load_cache_dataset(
dataset_cfg=dataset_cfg,
dataset_name=dataset_name,
split=split,
)
val_translation_dataset: Dataset = load_dataset(path=dataset_name, split=split)
return val_translation_dataset


def create_dataset(base_dataset: Dataset, train_cfg: DictConfig):
# dataset section of the training config (sequence and quantization settings)
dataset_cfg: DictConfig = train_cfg.dataset

if "dstart" in dataset_cfg.quantization:
translation_dataset = build_translation_dataset(base_dataset, dataset_cfg)
else:
translation_dataset = build_AT_translation_dataset(base_dataset, dataset_cfg)

if train_cfg.time_quantization_method == "start":
quantizer = MidiATQuantizer(
n_duration_bins=dataset_cfg.quantization.duration,
Expand All @@ -117,35 +118,94 @@ def model_predictions_review(

if "mask" in train_cfg:
if train_cfg.mask == "notes":
encoder = MaskedNoteEncoder(base_encoder=base_tokenizer, masking_probability=train_cfg.masking_probability)
if train_cfg.model_name == "T5":
encoder = MaskedNoteEncoder(
base_encoder=base_tokenizer,
masking_probability=train_cfg.masking_probability,
)
else:
encoder = SingleMaskedNoteEncoder(
base_encoder=base_tokenizer,
masking_probability=train_cfg.masking_probability,
)
else:
encoder = MaskedMidiEncoder(base_encoder=base_tokenizer, masking_probability=train_cfg.masking_probability)
else:
encoder = MaskedMidiEncoder(base_encoder=base_tokenizer, masking_probability=train_cfg.masking_probability)

dataset = MaskedMidiDataset(
dataset=val_translation_dataset,
dataset=translation_dataset,
dataset_cfg=train_cfg.dataset,
base_encoder=base_tokenizer,
encoder=encoder,
)
return dataset, quantizer


def model_predictions_review(
checkpoint: dict,
train_cfg: DictConfig,
):
midi_dataset = dataset_selection(train_cfg=train_cfg)
source_df = midi_dataset.to_pandas()
composers = source_df.composer.unique()
selected_composer = st.selectbox(
label="Select composer",
options=composers,
index=3,
)

ids = source_df.composer == selected_composer
piece_titles = source_df[ids].title.unique()
selected_title = st.selectbox(
label="Select title",
options=piece_titles,
)
st.write(selected_title)

ids = (source_df.composer == selected_composer) & (source_df.title == selected_title)
part_df = source_df[ids]
part_dataset = midi_dataset.select(part_df.index.values)

dataset, quantizer = create_dataset(part_dataset, train_cfg=train_cfg)

random_seed: int = st.selectbox(label="random seed", options=range(20))

start_token_id: int = dataset.encoder.token_to_id["<CLS>"]
pad_token_id: int = dataset.encoder.token_to_id["<PAD>"]
config = T5Config(
vocab_size=vocab_size(train_cfg),
decoder_start_token_id=start_token_id,
pad_token_id=pad_token_id,
eos_token_id=pad_token_id,
use_cache=False,
d_model=train_cfg.model.d_model,
d_kv=train_cfg.model.d_kv,
d_ff=train_cfg.model.d_ff,
num_layers=train_cfg.model.num_layers,
num_heads=train_cfg.model.num_heads,
)
if train_cfg.model_name == "T5":
config = T5Config(
vocab_size=vocab_size(train_cfg),
decoder_start_token_id=start_token_id,
pad_token_id=pad_token_id,
eos_token_id=pad_token_id,
use_cache=False,
d_model=train_cfg.model.d_model,
d_kv=train_cfg.model.d_kv,
d_ff=train_cfg.model.d_ff,
num_layers=train_cfg.model.num_layers,
num_heads=train_cfg.model.num_heads,
)

model = T5ForConditionalGeneration(config)
else:
config = BartConfig(
vocab_size=vocab_size(train_cfg),
decoder_start_token_id=start_token_id,
pad_token_id=pad_token_id,
eos_token_id=pad_token_id,
use_cache=False,
d_model=train_cfg.model.d_model,
encoder_layers=train_cfg.model.encoder_layers,
decoder_layers=train_cfg.model.decoder_layers,
encoder_ffn_dim=train_cfg.model.encoder_ffn_dim,
decoder_ffn_dim=train_cfg.model.decoder_ffn_dim,
encoder_attention_heads=train_cfg.model.encoder_attention_heads,
decoder_attention_heads=train_cfg.model.decoder_attention_heads,
)

model = BartForConditionalGeneration(config)

model = T5ForConditionalGeneration(config)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval().to(DEVICE)

Expand All @@ -154,7 +214,7 @@ def model_predictions_review(

n_samples: int = 5
np.random.seed(random_seed)
idxs: np.ndarray[int] = np.random.randint(len(dataset), size=n_samples)
ids: int | np.ndarray[int] = np.random.randint(len(dataset), size=n_samples)

cols = st.columns(2)
with cols[0]:
Expand All @@ -167,7 +227,7 @@ def model_predictions_review(

# widget id for streamlit_pianoroll widget
key = 0
for record_id in idxs:
for record_id in ids:
# Numpy to int :(
record: dict = dataset.get_complete_record(int(record_id))
record_source: dict = json.loads(record["source"])
Expand All @@ -182,16 +242,21 @@ def model_predictions_review(
true_piece = MidiPiece(df=true_notes, source=record_source)
true_piece.time_shift(-true_piece.df.start.min())
try:
generated_df: pd.DataFrame = encoder.decode(src_token_ids, generated_token_ids)
generated_df: pd.DataFrame = dataset.encoder.decode(src_token_ids, generated_token_ids)
df = quantizer.apply_quantization(generated_df)
df["mask"] = generated_df["mask"]
# create quantized piece with predicted notes
pred_piece = MidiPiece(df)
unmasked_notes_df = generated_df[generated_df["mask"]]
unmasked_notes_piece = MidiPiece(unmasked_notes_df)

except ValueError:
# create an empty piece if the model did not generate the structure correctly
generated_df = pd.DataFrame([[23.0, 1.0, 1.0, 1.0, 1.0]], columns=midi_columns)
generated_df["mask"] = [False]
pred_piece = MidiPiece(generated_df)
unmasked_notes_df = generated_df[generated_df["mask"]]
unmasked_notes_piece = MidiPiece(unmasked_notes_df)

pred_piece.source = true_piece.source.copy()

Expand All @@ -208,18 +273,20 @@ def model_predictions_review(
st.pyplot(fig)
from_fortepyan(true_piece, key=key)
# Unchanged
st.markdown("**Source tokens:**")
st.markdown(source_tokens)
st.markdown("**Target tokens:**")
st.markdown(tgt_tokens)
with st.expander(label="original tokens", expanded=False):
st.markdown("**Source tokens:**")
st.markdown(source_tokens)
st.markdown("**Target tokens:**")
st.markdown(tgt_tokens)

with cols[1]:
# Predicted
fig = ff.view.draw_dual_pianoroll(pred_piece)
st.pyplot(fig)
from_fortepyan(pred_piece, key=key + 1)
st.markdown("**Predicted tokens:**")
st.markdown(generated_tokens)
from_fortepyan(pred_piece, secondary_piece=unmasked_notes_piece, key=key + 1)
with st.expander(label="predicted tokens", expanded=False):
st.markdown("**Predicted tokens:**")
st.markdown(generated_tokens)
key += 2


Expand Down
Loading