Skip to content

Commit

Permalink
trimming tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelJanas committed Nov 7, 2023
1 parent 9c6fd42 commit 933749c
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 22 deletions.
2 changes: 1 addition & 1 deletion fortepyan/midi/structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __sanitize_get_index(self, index: slice) -> slice:
if not isinstance(index, slice):
raise TypeError("You can only get a part of MidiFile that has multiple notes: Index must be a slice")

# If you wan piece[:stop]
# If you want piece[:stop]
if not index.start:
index = slice(0, index.stop)

Expand Down
93 changes: 72 additions & 21 deletions tests/midi/test_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,49 +6,100 @@

# Define a single comprehensive fixture
@pytest.fixture
def df_full():
return pd.DataFrame(
def sample_df():
df = pd.DataFrame(
{
"pitch": [60, 62],
"start": [0, 1],
"end": [1, 2],
"duration": [1, 1],
"velocity": [100, 100],
"start": [0, 1, 2, 3, 4],
"end": [1, 2, 3, 4, 5],
"duration": [1, 1, 1, 1, 1],
"pitch": [60, 62, 64, 65, 67],
"velocity": [80, 80, 80, 80, 80],
}
)
return df


# Tests using the fixture and modifying it according to the test's need
def test_with_start_end_duration(df_full):
piece = MidiPiece(df=df_full)
assert piece.df.shape[0] == 2
@pytest.fixture
def sample_midi_piece():
df = pd.DataFrame(
{
"start": [0, 1, 2, 3, 4],
"end": [1, 2, 3, 4, 5],
"duration": [1, 1, 1, 1, 1],
"pitch": [60, 62, 64, 65, 67],
"velocity": [80, 80, 80, 80, 80],
}
)
return MidiPiece(df)


def test_with_start_end_duration(sample_df):
piece = MidiPiece(df=sample_df)
assert piece.df.shape[0] == 5

def test_with_start_end(df_full):
df_mod = df_full.drop(columns=["duration"])

def test_with_start_end(sample_df):
df_mod = sample_df.drop(columns=["duration"])
piece = MidiPiece(df=df_mod)
assert "duration" in piece.df.columns


def test_with_start_duration(df_full):
df_mod = df_full.drop(columns=["end"])
def test_with_start_duration(sample_df):
df_mod = sample_df.drop(columns=["end"])
piece = MidiPiece(df=df_mod)
assert "end" in piece.df.columns


def test_with_end_duration(df_full):
df_mod = df_full.drop(columns=["start"])
def test_with_end_duration(sample_df):
df_mod = sample_df.drop(columns=["start"])
piece = MidiPiece(df=df_mod)
assert "start" in piece.df.columns


def test_missing_velocity(df_full):
df_mod = df_full.drop(columns=["velocity"])
def test_missing_velocity(sample_df):
df_mod = sample_df.drop(columns=["velocity"])
with pytest.raises(ValueError):
MidiPiece(df=df_mod)


def test_missing_pitch(df_full):
df_mod = df_full.drop(columns=["pitch"])
def test_missing_pitch(sample_df):
df_mod = sample_df.drop(columns=["pitch"])
with pytest.raises(ValueError):
MidiPiece(df=df_mod)


def test_midi_piece_duration_calculation(sample_df):
piece = MidiPiece(df=sample_df)
assert piece.duration == 5


def test_trim_within_bounds(sample_midi_piece):
# Test currently works as in the original code.
# We might want to change this behavior so that
# we do not treat the trimed piece as a new piece
trimmed_piece = sample_midi_piece.trim(2, 3)
assert len(trimmed_piece.df) == 2, "Trimmed MidiPiece should contain 2 notes."
assert trimmed_piece.df["start"].iloc[0] == 0, "New first note should start at 0 seconds."
assert trimmed_piece.df["end"].iloc[-1] == 2, "New last note should end at 2 seconds."


def test_trim_at_boundaries(sample_midi_piece):
trimmed_piece = sample_midi_piece.trim(0, 5)
assert trimmed_piece.size == sample_midi_piece.size, "Trimming at boundaries should not change the size."


def test_trim_out_of_bounds(sample_midi_piece):
# Assuming the behavior is to return an empty MidiPiece or raise an error
with pytest.raises(IndexError):
_ = sample_midi_piece.trim(6, 7) # Out of bounds, should raise an error


def test_trim_with_invalid_range(sample_midi_piece):
# Assuming the behavior is to raise an error with invalid range
with pytest.raises(IndexError):
_ = sample_midi_piece.trim(4, 2) # Invalid range, start is greater than finish


def test_source_update_after_trimming(sample_midi_piece):
trimmed_piece = sample_midi_piece.trim(1, 3)
assert trimmed_piece.source["start_time"] == 1, "Source start_time should be updated to reflect trimming."

0 comments on commit 933749c

Please sign in to comment.