Commit
Merge branch 'main' of github.com:pytorch/torchcodec into stream_index
scotts committed Oct 9, 2024
2 parents c574980 + e2ed57c commit 38a85bb
Showing 3 changed files with 55 additions and 226 deletions.
267 changes: 46 additions & 221 deletions benchmarks/samplers/benchmark_samplers.py
@@ -1,227 +1,52 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+from pathlib import Path
+from time import perf_counter_ns
+
-import abc
-import argparse
-import importlib
-import os
-
-import decord
-import numpy as np
 import torch
-
-import torch.utils.benchmark as benchmark
-from torchcodec.samplers import (
-    IndexBasedSamplerArgs,
-    TimeBasedSamplerArgs,
-    VideoArgs,
-    VideoClipSampler,
-)
-from torchmultimodal.fb.utils.video_utils import (
-    ClipSamplerType,
-    VideoClipSampler as tmm_vcs,
-)
-from torchvision.datasets.video_clip_sampler import (  # @manual=//pytorch/vision:internal_datasets
-    TVVideoClipDecoder,
-    UniformClipSamplingStrategy,
-    VideoClipSampler as ta_vcs,
-)
-
-
-class AbstractSampler:
-    def __init__(self):
-        pass
-
-    @abc.abstractmethod
-    def sample_frames_uniformly(self, video_file, clips_per_video):
-        pass
-
-
-class TorchCodecTimeBasedSampler(AbstractSampler):
-    def __init__(self):
-        pass
-
-    def sample_frames_uniformly(self, video_file, clips_per_video):
-        arr = np.fromfile(video_file, dtype=np.uint8)
-        video_tensor = torch.from_numpy(arr)
-        video_input = VideoArgs()
-        sampler_input = TimeBasedSamplerArgs(
-            sampler_type="uniform", clips_per_video=clips_per_video, frames_per_clip=1
-        )
-        sampler = VideoClipSampler(video_input, sampler_input)
-        return sampler(video_tensor)
-
-
-class TorchCodecIndexBasedSampler(AbstractSampler):
-    def __init__(self):
-        pass
-
-    def sample_frames_uniformly(self, video_file, clips_per_video):
-        arr = np.fromfile(video_file, dtype=np.uint8)
-        video_tensor = torch.from_numpy(arr)
-        video_input = VideoArgs()
-        sampler_input = IndexBasedSamplerArgs(
-            sampler_type="uniform", clips_per_video=clips_per_video, frames_per_clip=1
-        )
-        sampler = VideoClipSampler(video_input, sampler_input)
-        return sampler(video_tensor)
-
-
-class TorchCodecIndexBasedSamplerWithStackedOutput(AbstractSampler):
-    """
-    On large batch, torch stack has impact on performance, but it's not obvious locally.
-    """
-
-    def __init__(self):
-        pass
-
-    def sample_frames_uniformly(self, video_file, clips_per_video):
-        arr = np.fromfile(video_file, dtype=np.uint8)
-        video_tensor = torch.from_numpy(arr)
-        video_input = VideoArgs()
-        sampler_input = IndexBasedSamplerArgs(
-            sampler_type="uniform", clips_per_video=clips_per_video, frames_per_clip=1
-        )
-        sampler = VideoClipSampler(video_input, sampler_input)
-        clips = sampler(video_tensor)
-        return torch.stack([clip[0] for clip in clips])
-
-
-class DecordSampler(AbstractSampler):
-    def __init__(self):
-        pass
-
-    def sample_frames_uniformly(self, video_file, clips_per_video):
-        decord.bridge.set_bridge("torch")
-        av_reader = decord.VideoReader(video_file)
-        num_frames = len(av_reader)
-        frame_indices = np.linspace(0, num_frames - 1, clips_per_video, dtype=int)
-        frames = av_reader.get_batch(frame_indices)
-        return frames
-
-
-class TorchMMSamplerWithTorchVisionBackend(AbstractSampler):
-    """
-    Here we use TorchMultimodal sampler as it's updated version on top of torchvision decoder.
-    """
-
-    def __init__(self):
-        pass
-
-    def sample_frames_uniformly(self, video_file, clips_per_video):
-        arr = np.fromfile(video_file, dtype=np.uint8)
-        video_tensor = torch.from_numpy(arr)
-        sampler = tmm_vcs(
-            clip_sampler_type=ClipSamplerType("UNIFORM"),
-            clips_per_video=clips_per_video,
-            frames_per_clip=1,
-            frame_dilation=1,
-        )
-        return sampler(video_tensor)
-
-
-class TorchVisionNewSamplerWithTorchVisionBackend(AbstractSampler):
-    def __init__(self):
-        pass
-
-    def sample_frames_uniformly(self, video_file, clips_per_video):
-        clip_sampling_strategy = UniformClipSamplingStrategy(
-            clips_per_video=clips_per_video
-        )
-        decoder = TVVideoClipDecoder(clip_length_in_frames=1, read_audio_stream=False)
-        sampler = ta_vcs(clip_sampling_strategy, decoder)
-        return sampler(str(video_file))
-
-
-def main():
-    """Benchmarks the performance of different samplers"""
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--bm_small_video_speed",
-        help="Benchmark small video decoding speed",
-        default=True,
-        action=argparse.BooleanOptionalAction,
-    )
-    parser.add_argument(
-        "--bm_large_video_speed",
-        help="Benchmark large video decoding speed",
-        default=True,
-        action=argparse.BooleanOptionalAction,
-    )
+from torchcodec.decoders import VideoDecoder
+from torchcodec.samplers import clips_at_random_indices
+
+
+def bench(f, *args, num_exp=100, warmup=0, **kwargs):
+
+    for _ in range(warmup):
+        f(*args, **kwargs)
+
+    times = []
+    for _ in range(num_exp):
+        start = perf_counter_ns()
+        f(*args, **kwargs)
+        end = perf_counter_ns()
+        times.append(end - start)
+    return torch.tensor(times).float()
+
+
+def report_stats(times, unit="ms"):
+    mul = {
+        "ns": 1,
+        "µs": 1e-3,
+        "ms": 1e-6,
+        "s": 1e-9,
+    }[unit]
+    times = times * mul
+    std = times.std().item()
+    med = times.median().item()
+    print(f"{med = :.2f}{unit} +- {std:.2f}")
+    return med
+
+
+def sample(num_clips):
+    decoder = VideoDecoder(VIDEO_PATH)
+    clips_at_random_indices(
+        decoder,
+        num_clips=num_clips,
+        num_frames_per_clip=10,
+        num_indices_between_frames=2,
+    )
-    parser.add_argument(
-        "--bm_video_speed_min_run_seconds",
-        help="Benchmark minimum run time, in seconds, to wait per datapoint",
-        type=float,
-        default=5.0,
-    )
-    args = parser.parse_args()
-
-    small_video_path = importlib.resources.path(__package__, "nasa_13013.mp4")
-    small_video_path = os.fspath(str(small_video_path))
-
-    large_video_path = importlib.resources.path(__package__, "853.mp4")
-    large_video_path = os.fspath(str(large_video_path))
-
-    clips_per_video = 8
-
-    sampler_dict = {}
-    sampler_dict["TorchCodecTimeBasedSampler"] = TorchCodecTimeBasedSampler()
-    sampler_dict["TorchCodecIndexBasedSampler"] = TorchCodecIndexBasedSampler()
-    sampler_dict["TorchCodecIndexBasedSamplerWithStackedOutput"] = (
-        TorchCodecIndexBasedSamplerWithStackedOutput()
-    )
-    sampler_dict["DecordSampler"] = DecordSampler()
-    sampler_dict["TorchMMSamplerWithTorchVisionBackend"] = (
-        TorchMMSamplerWithTorchVisionBackend()
-    )
-    sampler_dict["TorchVisionNewSamplerWithTorchVisionBackend"] = (
-        TorchVisionNewSamplerWithTorchVisionBackend()
-    )
-
-    results = []
-
-    for sampler_name, sampler in sampler_dict.items():
-        if args.bm_small_video_speed:
-            sampler_result = benchmark.Timer(
-                stmt="sampler.sample_frames_uniformly(video_file, clips_per_video)",
-                globals={
-                    "video_file": small_video_path,
-                    "clips_per_video": clips_per_video,
-                    "sampler": sampler,
-                },
-                label="uniform sampling latency for 700KB video",
-                sub_label=sampler_name,
-                description=f"uniform sampling {clips_per_video} frames",
-            )
-            results.append(
-                sampler_result.blocked_autorange(
-                    min_run_time=args.bm_video_speed_min_run_seconds
-                )
-            )
-
-        if args.bm_large_video_speed:
-            if sampler_name == "TorchMMSamplerWithTorchVisionBackend":
-                continue
-            sampler_result = benchmark.Timer(
-                stmt="sampler.sample_frames_uniformly(video_file, clips_per_video)",
-                globals={
-                    "video_file": large_video_path,
-                    "clips_per_video": clips_per_video,
-                    "sampler": sampler,
-                },
-                label="uniform sampling latency for 50MB video",
-                sub_label=sampler_name,
-                description=f"uniform sampling {clips_per_video} frames",
-            )
-            results.append(
-                sampler_result.blocked_autorange(
-                    min_run_time=args.bm_video_speed_min_run_seconds
-                )
-            )
+
+
+VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4"
+
-
-    compare = benchmark.Compare(results)
-    compare.print()
+times = bench(sample, num_clips=1, num_exp=30, warmup=2)
+report_stats(times, unit="ms")
+times = bench(sample, num_clips=50, num_exp=30, warmup=2)
+report_stats(times, unit="ms")
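For anyone running the new benchmark, a minimal sketch of a wider sweep, reusing the bench, report_stats, and sample helpers defined in the new file above. The num_clips values other than 1 and 50 are illustrative and not part of this commit:

    # bench forwards num_clips to sample() via **kwargs: 30 timed runs, 2 warmups.
    for num_clips in (1, 10, 50):
        print(f"num_clips={num_clips}:")
        times = bench(sample, num_clips=num_clips, num_exp=30, warmup=2)
        # report_stats prints the median +- std in the requested unit.
        report_stats(times, unit="ms")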
6 changes: 5 additions & 1 deletion src/torchcodec/decoders/_video_decoder.py
@@ -144,7 +144,7 @@ def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor:
         """Return frame or frames as tensors, at the given index or range.
 
         Args:
-            key(numbers.Integral or slice): The index or range of frame(s) to retrieve.
+            key(int or slice): The index or range of frame(s) to retrieve.
 
         Returns:
             torch.Tensor: The frame or frames at the given index or range.
@@ -296,5 +296,9 @@ def _get_and_validate_stream_metadata(
     if stream_index is None:
         stream_index = best_stream_index
 
+    # This should be logically true because of the above conditions, but type checker
+    # is not clever enough.
+    assert stream_index is not None
+
     stream_metadata = video_metadata.streams[stream_index]
     return (stream_metadata, stream_index)
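As a usage sketch of the indexing API whose docstring is updated above (not part of this commit; the file path is hypothetical, and the frame shape follows the NCHW layout asserted in the C++ tests below):

    from torchcodec.decoders import VideoDecoder

    decoder = VideoDecoder("video.mp4")  # hypothetical local file
    frame = decoder[0]     # int index: one frame as a (C, H, W) tensor
    clip = decoder[10:20]  # slice: a range of frames as a batched tensor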
8 changes: 4 additions & 4 deletions test/decoders/VideoDecoderTest.cpp
@@ -180,9 +180,9 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) {
   EXPECT_EQ(output.pts, 1001);
 
   torch::Tensor tensor0FromFFMPEG =
-      readTensorFromDisk("nasa_13013.mp4.frame000000.pt");
+      readTensorFromDisk("nasa_13013.mp4.stream3.frame000000.pt");
   torch::Tensor tensor1FromFFMPEG =
-      readTensorFromDisk("nasa_13013.mp4.frame000001.pt");
+      readTensorFromDisk("nasa_13013.mp4.stream3.frame000001.pt");
 
   EXPECT_EQ(tensor1FromFFMPEG.sizes(), std::vector<long>({3, 270, 480}));
   EXPECT_TRUE(torch::equal(tensor0FromOurDecoder, tensor0FromFFMPEG));
@@ -215,7 +215,7 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNCHW) {
   EXPECT_EQ(tensor.sizes(), std::vector<long>({2, 3, 270, 480}));
 
   torch::Tensor tensor0FromFFMPEG =
-      readTensorFromDisk("nasa_13013.mp4.frame000000.pt");
+      readTensorFromDisk("nasa_13013.mp4.stream3.frame000000.pt");
   torch::Tensor tensorTime6FromFFMPEG =
       readTensorFromDisk("nasa_13013.mp4.time6.000000.pt");
 
@@ -239,7 +239,7 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNHWC) {
  EXPECT_EQ(tensor.sizes(), std::vector<long>({2, 270, 480, 3}));
 
  torch::Tensor tensor0FromFFMPEG =
-      readTensorFromDisk("nasa_13013.mp4.frame000000.pt");
+      readTensorFromDisk("nasa_13013.mp4.stream3.frame000000.pt");
  torch::Tensor tensorTime6FromFFMPEG =
       readTensorFromDisk("nasa_13013.mp4.time6.000000.pt");
 
