From 2ea149dcf929cffe75507ac5dcdf7d3381b95eb2 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 15 Oct 2024 13:01:18 -0700 Subject: [PATCH] Refactor seeking to only store pts as int64 timestamp --- .../decoders/_core/VideoDecoder.cpp | 40 ++++++++++++------- src/torchcodec/decoders/_core/VideoDecoder.h | 7 ++-- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 58a605fd..da93746f 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -647,14 +647,12 @@ bool VideoDecoder::canWeAvoidSeekingForStream( // AVFormatContext if it is needed. We can skip seeking in certain cases. See // the comment of canWeAvoidSeeking() for details. void VideoDecoder::maybeSeekToBeforeDesiredPts() { + TORCH_CHECK( + hasDesiredPts_, + "maybeSeekToBeforeDesiredPts() called when hasDesiredPts_ is false"); if (activeStreamIndices_.size() == 0) { return; } - for (int streamIndex : activeStreamIndices_) { - StreamInfo& streamInfo = streams_[streamIndex]; - streamInfo.discardFramesBeforePts = - *maybeDesiredPts_ * streamInfo.timeBase.den; - } decodeStats_.numSeeksAttempted++; // See comment for canWeAvoidSeeking() for details on why this optimization @@ -662,11 +660,12 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { bool mustSeek = false; for (int streamIndex : activeStreamIndices_) { StreamInfo& streamInfo = streams_[streamIndex]; - int64_t desiredPtsForStream = *maybeDesiredPts_ * streamInfo.timeBase.den; if (!canWeAvoidSeekingForStream( - streamInfo, streamInfo.currentPts, desiredPtsForStream)) { + streamInfo, + streamInfo.currentPts, + *streamInfo.discardFramesBeforePts)) { VLOG(5) << "Seeking is needed for streamIndex=" << streamIndex - << " desiredPts=" << desiredPtsForStream + << " desiredPts=" << *streamInfo.discardFramesBeforePts << " currentPts=" << streamInfo.currentPts; mustSeek = true; break; @@ -678,7 +677,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { } int firstActiveStreamIndex = *activeStreamIndices_.begin(); const auto& firstStreamInfo = streams_[firstActiveStreamIndex]; - int64_t desiredPts = *maybeDesiredPts_ * firstStreamInfo.timeBase.den; + int64_t desiredPts = *firstStreamInfo.discardFramesBeforePts; // For some encodings like H265, FFMPEG sometimes seeks past the point we // set as the max_ts. So we use our own index to give it the exact pts of @@ -718,10 +717,10 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter( } VLOG(9) << "Starting getDecodedOutputWithFilter()"; resetDecodeStats(); - if (maybeDesiredPts_.has_value()) { - VLOG(9) << "maybeDesiredPts_=" << *maybeDesiredPts_; + if (hasDesiredPts_) { maybeSeekToBeforeDesiredPts(); - maybeDesiredPts_ = std::nullopt; + hasDesiredPts_ = false; + // FIXME: should we also reset each stream info's discardFramesBeforePts? VLOG(9) << "seeking done"; } auto seekDone = std::chrono::high_resolution_clock::now(); @@ -988,7 +987,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex( validateFrameIndex(stream, frameIndex); int64_t pts = stream.allFrames[frameIndex].pts; - setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase)); + setCursorPts(pts); return getNextDecodedOutputNoDemux(); } @@ -1010,7 +1009,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( "Invalid frame index=" + std::to_string(frameIndex)); } int64_t pts = stream.allFrames[frameIndex].pts; - setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase)); + setCursorPts(pts); auto rawSingleOutput = getNextRawDecodedOutputNoDemux(); if (stream.colorConversionLibrary == ColorConversionLibrary::SWSCALE) { // We are using sws_scale to convert the frame to tensor. sws_scale can @@ -1179,7 +1178,18 @@ VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux() { } void VideoDecoder::setCursorPtsInSeconds(double seconds) { - maybeDesiredPts_ = seconds; + for (int streamIndex : activeStreamIndices_) { + StreamInfo& streamInfo = streams_[streamIndex]; + streamInfo.discardFramesBeforePts = seconds * streamInfo.timeBase.den; + } + hasDesiredPts_ = true; +} + +void VideoDecoder::setCursorPts(int64_t pts) { + for (int streamIndex : activeStreamIndices_) { + streams_[streamIndex].discardFramesBeforePts = pts; + } + hasDesiredPts_ = true; } VideoDecoder::DecodeStats VideoDecoder::getDecodeStats() const { diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 8535a61e..a7ff06d4 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -367,6 +367,7 @@ class VideoDecoder { void convertAVFrameToDecodedOutputOnCPU( RawDecodedOutput& rawOutput, DecodedOutput& output); + void setCursorPts(int64_t pts); DecoderOptions options_; ContainerMetadata containerMetadata_; @@ -375,9 +376,9 @@ class VideoDecoder { // Stores the stream indices of the active streams, i.e. the streams we are // decoding and returning to the user. std::set activeStreamIndices_; - // Set when the user wants to seek and stores the desired pts that the user - // wants to seek to. - std::optional maybeDesiredPts_; + // True when the user wants to seek. The actual pts values to seek to are + // stored in the per-stream metadata in discardFramesBeforePts. + bool hasDesiredPts_; // Stores various internal decoding stats. DecodeStats decodeStats_;