Skip to content

Commit

Permalink
Add stream_index as an option to VideoDecoder
Browse files Browse the repository at this point in the history
  • Loading branch information
scotts committed Oct 9, 2024
1 parent 21aef92 commit c574980
Show file tree
Hide file tree
Showing 40 changed files with 280 additions and 106 deletions.
44 changes: 29 additions & 15 deletions src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numbers
from pathlib import Path
from typing import Literal, Tuple, Union
from typing import Literal, Optional, Tuple, Union

from torch import Tensor

Expand All @@ -22,14 +22,16 @@
class VideoDecoder:
"""A single-stream video decoder.
If the video contains multiple video streams, the :term:`best stream` is
used. This decoder always performs a :term:`scan` of the video.
This decoder always performs a :term:`scan` of the video.
Args:
source (str, ``pathlib.Path``, ``torch.Tensor``, or bytes): The source of the video.
- If ``str`` or ``pathlib.Path``: a path to a local video file.
- If ``bytes`` object or ``torch.Tensor``: the raw encoded video data.
stream_index (int, optional): Specifies which stream in the video to decode frames from.
Note that this index is absolute across all media types. If left unspecified, then
the :term:`best stream` is used.
dimension_order (str, optional): The dimension order of the decoded frames.
This can be either "NCHW" (default) or "NHWC", where N is the batch
size, C is the number of channels, H is the height, and W is the
Expand All @@ -45,11 +47,16 @@ class VideoDecoder:
Attributes:
metadata (VideoStreamMetadata): Metadata of the video stream.
stream_index (int): The stream index that this decoder is retrieving frames from. If a
stream index was provided at initialization, this is the same value. If it was left
unspecified, this is the index of the :term:`best stream`.
"""

def __init__(
self,
source: Union[str, Path, bytes, Tensor],
*,
stream_index: Optional[int] = None,
dimension_order: Literal["NCHW", "NHWC"] = "NCHW",
):
if isinstance(source, str):
Expand All @@ -74,10 +81,12 @@ def __init__(
)

core.scan_all_streams_to_update_metadata(self._decoder)
core.add_video_stream(self._decoder, dimension_order=dimension_order)
core.add_video_stream(
self._decoder, stream_index=stream_index, dimension_order=dimension_order
)

self.metadata, self._stream_index = _get_and_validate_stream_metadata(
self._decoder
self.metadata, self.stream_index = _get_and_validate_stream_metadata(
self._decoder, stream_index
)

if self.metadata.num_frames_from_content is None:
Expand Down Expand Up @@ -114,7 +123,7 @@ def _getitem_int(self, key: int) -> Tensor:
)

frame_data, *_ = core.get_frame_at_index(
self._decoder, frame_index=key, stream_index=self._stream_index
self._decoder, frame_index=key, stream_index=self.stream_index
)
return frame_data

Expand All @@ -124,7 +133,7 @@ def _getitem_slice(self, key: slice) -> Tensor:
start, stop, step = key.indices(len(self))
frame_data, *_ = core.get_frames_in_range(
self._decoder,
stream_index=self._stream_index,
stream_index=self.stream_index,
start=start,
stop=stop,
step=step,
Expand Down Expand Up @@ -164,7 +173,7 @@ def get_frame_at(self, index: int) -> Frame:
f"Index {index} is out of bounds; must be in the range [0, {self._num_frames})."
)
data, pts_seconds, duration_seconds = core.get_frame_at_index(
self._decoder, frame_index=index, stream_index=self._stream_index
self._decoder, frame_index=index, stream_index=self.stream_index
)
return Frame(
data=data,
Expand Down Expand Up @@ -198,7 +207,7 @@ def get_frames_at(self, start: int, stop: int, step: int = 1) -> FrameBatch:
raise IndexError(f"Step ({step}) must be greater than 0.")
frames = core.get_frames_in_range(
self._decoder,
stream_index=self._stream_index,
stream_index=self.stream_index,
start=start,
stop=stop,
step=step,
Expand Down Expand Up @@ -264,7 +273,7 @@ def get_frames_displayed_at(
)
frames = core.get_frames_by_pts_in_range(
self._decoder,
stream_index=self._stream_index,
stream_index=self.stream_index,
start_seconds=start_seconds,
stop_seconds=stop_seconds,
)
Expand All @@ -273,14 +282,19 @@ def get_frames_displayed_at(

def _get_and_validate_stream_metadata(
decoder: Tensor,
stream_index: Optional[int] = None,
) -> Tuple[core.VideoStreamMetadata, int]:
video_metadata = core.get_video_metadata(decoder)

best_stream_index = video_metadata.best_video_stream_index
if best_stream_index is None:
if best_stream_index is None and stream_index is None:
raise ValueError(
"The best video stream is unknown. " + _ERROR_REPORTING_INSTRUCTIONS
"The best video stream is unknown and there is no specified stream. "
+ _ERROR_REPORTING_INSTRUCTIONS
)

best_stream_metadata = video_metadata.streams[best_stream_index]
return (best_stream_metadata, best_stream_index)
if stream_index is None:
stream_index = best_stream_index

stream_metadata = video_metadata.streams[stream_index]
return (stream_metadata, stream_index)
120 changes: 84 additions & 36 deletions test/decoders/test_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_create(self, source_kind):
== decoder.metadata.num_frames_from_content
== 390
)
assert decoder._stream_index == decoder.metadata.stream_index == 3
assert decoder.stream_index == decoder.metadata.stream_index == 3
assert decoder.metadata.duration_seconds == pytest.approx(13.013)
assert decoder.metadata.average_fps == pytest.approx(29.970029)
assert decoder.metadata.num_frames == 390
Expand All @@ -46,6 +46,9 @@ def test_create_fails(self):
with pytest.raises(TypeError, match="Unknown source type"):
decoder = VideoDecoder(123) # noqa

with pytest.raises(ValueError, match="No valid stream found"):
decoder = VideoDecoder(NASA_VIDEO.path, stream_index=40) # noqa

def test_getitem_int(self):
decoder = VideoDecoder(NASA_VIDEO.path)

Expand Down Expand Up @@ -345,7 +348,7 @@ def test_get_frame_displayed_at(self):
def test_get_frame_displayed_at_h265(self):
# Non-regression test for https://github.com/pytorch/torchcodec/issues/179
decoder = VideoDecoder(H265_VIDEO.path)
ref_frame6 = H265_VIDEO.get_frame_by_name("frame000005")
ref_frame6 = H265_VIDEO.get_frame_data_by_index(5)
assert_tensor_equal(ref_frame6, decoder.get_frame_displayed_at(0.5).data)

def test_get_frame_displayed_at_fails(self):
Expand All @@ -357,56 +360,71 @@ def test_get_frame_displayed_at_fails(self):
with pytest.raises(IndexError, match="Invalid pts in seconds"):
frame = decoder.get_frame_displayed_at(100.0) # noqa

def test_get_frames_at(self):
decoder = VideoDecoder(NASA_VIDEO.path)
@pytest.mark.parametrize("stream_index", [0, 3, None])
def test_get_frames_at(self, stream_index):
decoder = VideoDecoder(NASA_VIDEO.path, stream_index=stream_index)

# test degenerate case where we only actually get 1 frame
ref_frames9 = NASA_VIDEO.get_frame_data_by_range(start=9, stop=10)
ref_frames9 = NASA_VIDEO.get_frame_data_by_range(
start=9, stop=10, stream_index=stream_index
)
frames9 = decoder.get_frames_at(start=9, stop=10)

assert_tensor_equal(ref_frames9, frames9.data)
assert frames9.pts_seconds[0].item() == pytest.approx(0.3003, rel=1e-3)
assert frames9.duration_seconds[0].item() == pytest.approx(0.03337, rel=1e-3)
assert frames9.pts_seconds[0].item() == pytest.approx(
NASA_VIDEO.get_frame_info(9, stream_index=stream_index).pts_seconds,
rel=1e-3,
)
assert frames9.duration_seconds[0].item() == pytest.approx(
NASA_VIDEO.get_frame_info(9, stream_index=stream_index).duration_seconds,
rel=1e-3,
)

# test simple ranges
ref_frames0_9 = NASA_VIDEO.get_frame_data_by_range(start=0, stop=10)
ref_frames0_9 = NASA_VIDEO.get_frame_data_by_range(
start=0, stop=10, stream_index=stream_index
)
frames0_9 = decoder.get_frames_at(start=0, stop=10)
assert frames0_9.data.shape == torch.Size(
[
10,
NASA_VIDEO.num_color_channels,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.get_num_color_channels(stream_index=stream_index),
NASA_VIDEO.get_height(stream_index=stream_index),
NASA_VIDEO.get_width(stream_index=stream_index),
]
)
assert_tensor_equal(ref_frames0_9, frames0_9.data)
assert_tensor_close(
NASA_VIDEO.get_pts_seconds_by_range(0, 10),
NASA_VIDEO.get_pts_seconds_by_range(0, 10, stream_index=stream_index),
frames0_9.pts_seconds,
)
assert_tensor_close(
NASA_VIDEO.get_duration_seconds_by_range(0, 10),
NASA_VIDEO.get_duration_seconds_by_range(0, 10, stream_index=stream_index),
frames0_9.duration_seconds,
)

# test steps
ref_frames0_8_2 = NASA_VIDEO.get_frame_data_by_range(start=0, stop=10, step=2)
ref_frames0_8_2 = NASA_VIDEO.get_frame_data_by_range(
start=0, stop=10, step=2, stream_index=stream_index
)
frames0_8_2 = decoder.get_frames_at(start=0, stop=10, step=2)
assert frames0_8_2.data.shape == torch.Size(
[
5,
NASA_VIDEO.num_color_channels,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.get_num_color_channels(stream_index=stream_index),
NASA_VIDEO.get_height(stream_index=stream_index),
NASA_VIDEO.get_width(stream_index=stream_index),
]
)
assert_tensor_equal(ref_frames0_8_2, frames0_8_2.data)
assert_tensor_close(
NASA_VIDEO.get_pts_seconds_by_range(0, 10, 2),
NASA_VIDEO.get_pts_seconds_by_range(0, 10, 2, stream_index=stream_index),
frames0_8_2.pts_seconds,
)
assert_tensor_close(
NASA_VIDEO.get_duration_seconds_by_range(0, 10, 2),
NASA_VIDEO.get_duration_seconds_by_range(
0, 10, 2, stream_index=stream_index
),
frames0_8_2.duration_seconds,
)

Expand All @@ -418,7 +436,10 @@ def test_get_frames_at(self):

# an empty range is valid!
empty_frames = decoder.get_frames_at(5, 5)
assert_tensor_equal(empty_frames.data, NASA_VIDEO.empty_chw_tensor)
assert_tensor_equal(
empty_frames.data,
NASA_VIDEO.get_empty_chw_tensor(stream_index=stream_index),
)
assert_tensor_equal(empty_frames.pts_seconds, NASA_VIDEO.empty_pts_seconds)
assert_tensor_equal(
empty_frames.duration_seconds, NASA_VIDEO.empty_duration_seconds
Expand Down Expand Up @@ -456,8 +477,9 @@ def test_dimension_order_fails(self):
with pytest.raises(ValueError, match="Invalid dimension order"):
VideoDecoder(NASA_VIDEO.path, dimension_order="NCDHW")

def test_get_frames_by_pts_in_range(self):
decoder = VideoDecoder(NASA_VIDEO.path)
@pytest.mark.parametrize("stream_index", [0, 3, None])
def test_get_frames_by_pts_in_range(self, stream_index):
decoder = VideoDecoder(NASA_VIDEO.path, stream_index=stream_index)

# Note that we are comparing the results of VideoDecoder's method:
# get_frames_displayed_at()
Expand All @@ -480,7 +502,10 @@ def test_get_frames_by_pts_in_range(self):
frames0_4 = decoder.get_frames_displayed_at(
decoder.get_frame_at(0).pts_seconds, decoder.get_frame_at(5).pts_seconds
)
assert_tensor_equal(frames0_4.data, NASA_VIDEO.get_frame_data_by_range(0, 5))
assert_tensor_equal(
frames0_4.data,
NASA_VIDEO.get_frame_data_by_range(0, 5, stream_index=stream_index),
)

# Range where the stop seconds is about halfway between pts values for two frames.
also_frames0_4 = decoder.get_frames_displayed_at(
Expand All @@ -495,7 +520,10 @@ def test_get_frames_by_pts_in_range(self):
decoder.get_frame_at(5).pts_seconds,
decoder.get_frame_at(10).pts_seconds,
)
assert_tensor_equal(frames5_9.data, NASA_VIDEO.get_frame_data_by_range(5, 10))
assert_tensor_equal(
frames5_9.data,
NASA_VIDEO.get_frame_data_by_range(5, 10, stream_index=stream_index),
)

# Range where we provide start_seconds and stop_seconds that are different, but
# also should land in the same window of time between two frames' pts values. As
Expand All @@ -504,41 +532,61 @@ def test_get_frames_by_pts_in_range(self):
decoder.get_frame_at(6).pts_seconds,
decoder.get_frame_at(6).pts_seconds + HALF_DURATION,
)
assert_tensor_equal(frame6.data, NASA_VIDEO.get_frame_data_by_range(6, 7))
assert_tensor_equal(
frame6.data,
NASA_VIDEO.get_frame_data_by_range(6, 7, stream_index=stream_index),
)

# Very small range that falls in the same frame.
frame35 = decoder.get_frames_displayed_at(
decoder.get_frame_at(35).pts_seconds,
decoder.get_frame_at(35).pts_seconds + 1e-10,
)
assert_tensor_equal(frame35.data, NASA_VIDEO.get_frame_data_by_range(35, 36))
assert_tensor_equal(
frame35.data,
NASA_VIDEO.get_frame_data_by_range(35, 36, stream_index=stream_index),
)

# Single frame where the start seconds is before frame i's pts, and the stop is
# after frame i's pts, but before frame i+1's pts. In that scenario, we expect
# to see frames i-1 and i.
frames7_8 = decoder.get_frames_displayed_at(
NASA_VIDEO.frames[8].pts_seconds - HALF_DURATION,
NASA_VIDEO.frames[8].pts_seconds + HALF_DURATION,
NASA_VIDEO.get_frame_info(8, stream_index=stream_index).pts_seconds
- HALF_DURATION,
NASA_VIDEO.get_frame_info(8, stream_index=stream_index).pts_seconds
+ HALF_DURATION,
)
assert_tensor_equal(
frames7_8.data,
NASA_VIDEO.get_frame_data_by_range(7, 9, stream_index=stream_index),
)
assert_tensor_equal(frames7_8.data, NASA_VIDEO.get_frame_data_by_range(7, 9))

# Start and stop seconds are the same value, which should not return a frame.
empty_frame = decoder.get_frames_displayed_at(
NASA_VIDEO.frames[4].pts_seconds,
NASA_VIDEO.frames[4].pts_seconds,
NASA_VIDEO.get_frame_info(4, stream_index=stream_index).pts_seconds,
NASA_VIDEO.get_frame_info(4, stream_index=stream_index).pts_seconds,
)
assert_tensor_equal(
empty_frame.data, NASA_VIDEO.get_empty_chw_tensor(stream_index=stream_index)
)
assert_tensor_equal(
empty_frame.pts_seconds,
NASA_VIDEO.empty_pts_seconds,
)
assert_tensor_equal(empty_frame.data, NASA_VIDEO.empty_chw_tensor)
assert_tensor_equal(empty_frame.pts_seconds, NASA_VIDEO.empty_pts_seconds)
assert_tensor_equal(
empty_frame.duration_seconds, NASA_VIDEO.empty_duration_seconds
)

# Start and stop seconds land within the first frame.
frame0 = decoder.get_frames_displayed_at(
NASA_VIDEO.frames[0].pts_seconds,
NASA_VIDEO.frames[0].pts_seconds + HALF_DURATION,
NASA_VIDEO.get_frame_info(0, stream_index=stream_index).pts_seconds,
NASA_VIDEO.get_frame_info(0, stream_index=stream_index).pts_seconds
+ HALF_DURATION,
)
assert_tensor_equal(
frame0.data,
NASA_VIDEO.get_frame_data_by_range(0, 1, stream_index=stream_index),
)
assert_tensor_equal(frame0.data, NASA_VIDEO.get_frame_data_by_range(0, 1))

# We should be able to get all frames by giving the beginning and ending time
# for the stream.
Expand Down
Loading

0 comments on commit c574980

Please sign in to comment.