Skip to content

Commit

Permalink
Add timing column conversion and MidiPiece
Browse files Browse the repository at this point in the history
addition
  • Loading branch information
SamuelJanas committed Nov 8, 2023
1 parent 6d92857 commit 772b403
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 1 deletion.
34 changes: 34 additions & 0 deletions fortepyan/midi/structures.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from warnings import showwarning
from dataclasses import field, dataclass

import numpy as np
Expand Down Expand Up @@ -31,6 +32,10 @@ def __post_init__(self):
elif "duration" not in self.df.columns:
self.df["duration"] = self.df["end"] - self.df["start"]

# Convert timing columns to float to ensure consistency
for col in timing_columns:
self.df[col] = self.df[col].astype(float)

# Check for the absolutely required columns: 'pitch' and 'velocity'
if "pitch" not in self.df.columns:
raise ValueError("The DataFrame is missing the required column: 'pitch'.")
Expand Down Expand Up @@ -109,6 +114,35 @@ def __getitem__(self, index: slice) -> "MidiPiece":

return out

def __add__(self, other: "MidiPiece") -> "MidiPiece":
if not isinstance(other, MidiPiece):
raise TypeError("You can only add MidiPiece objects to other MidiPiece objects.")

# Adjust the start/end times of the second piece
other.df.start += self.end
other.df.end += self.end

# Concatenate the two pieces
df = pd.concat([self.df, other.df], ignore_index=True)

# make sure the other piece is not modified
other.df.start -= self.end
other.df.end -= self.end

# make sure that start and end times are floats
df.start = df.start.astype(float)
df.end = df.end.astype(float)

out = MidiPiece(df=df)

# Show warning as the piece might not be musically valid.
showwarning("The resulting piece may not be musically valid.", UserWarning, "fortepyan", lineno=1)

return out

def __len__(self) -> int:
return self.size

@property
def duration(self) -> float:
duration = self.df.end.max() - self.df.start.min()
Expand Down
56 changes: 55 additions & 1 deletion tests/midi/test_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
def sample_df():
df = pd.DataFrame(
{
"start": [0, 1, 2, 3, 4],
"start": [0.0, 1, 2, 3, 4],
"end": [1, 2, 3, 4, 5.5],
"duration": [1, 1, 1, 1, 1.5],
"pitch": [60, 62, 64, 65, 67],
Expand Down Expand Up @@ -114,3 +114,57 @@ def test_to_midi(sample_midi_piece):
midi_end_time = midi_track.get_end_time()

assert midi_end_time == expected_end_time, f"MIDI end time {midi_end_time} does not match expected {expected_end_time}"


def test_add_two_midi_pieces(sample_midi_piece):
# Create a second MidiPiece to add to the sample one
df2 = pd.DataFrame(
{
"start": [0, 1, 2],
"end": [1, 2, 3],
"duration": [1, 1, 1],
"pitch": [70, 72, 74],
"velocity": [80, 80, 80],
}
)
midi_piece2 = MidiPiece(df=df2)

# Add the two pieces together
combined_piece = sample_midi_piece + midi_piece2

# Check that the resulting piece has the correct number of notes
assert len(combined_piece) == len(sample_midi_piece) + len(midi_piece2)

# Check if duration has been adjusted
assert combined_piece.duration == sample_midi_piece.duration + midi_piece2.duration


def test_add_non_midi_piece(sample_midi_piece):
# Try to add a non-MidiPiece object to a MidiPiece
with pytest.raises(TypeError):
_ = sample_midi_piece + "not a MidiPiece"


def test_add_does_not_modify_originals(sample_midi_piece):
# Create a second MidiPiece to add to the sample one
df2 = pd.DataFrame(
{
"start": [0, 1, 2],
"end": [1, 2, 3],
"duration": [1, 1, 1],
"pitch": [70, 72, 74],
"velocity": [80, 80, 80],
}
)
midi_piece2 = MidiPiece(df=df2)

# Store the original dataframes for comparison
original_df1 = sample_midi_piece.df.copy()
original_df2 = midi_piece2.df.copy()

# Add the two pieces together
_ = sample_midi_piece + midi_piece2

# Check that the original pieces have not been modified
pd.testing.assert_frame_equal(sample_midi_piece.df, original_df1)
pd.testing.assert_frame_equal(midi_piece2.df, original_df2)

0 comments on commit 772b403

Please sign in to comment.