Skip to content

Commit

Permalink
Refactor seeking to only store pts as int64 timestamp
Browse files Browse the repository at this point in the history
  • Loading branch information
scotts committed Oct 15, 2024
1 parent 82924d2 commit 2ea149d
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 18 deletions.
40 changes: 25 additions & 15 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -647,26 +647,25 @@ bool VideoDecoder::canWeAvoidSeekingForStream(
// AVFormatContext if it is needed. We can skip seeking in certain cases. See
// the comment of canWeAvoidSeekingForStream() for details.
void VideoDecoder::maybeSeekToBeforeDesiredPts() {
TORCH_CHECK(
hasDesiredPts_,
"maybeSeekToBeforeDesiredPts() called when hasDesiredPts_ is false");
if (activeStreamIndices_.size() == 0) {
return;
}
for (int streamIndex : activeStreamIndices_) {
StreamInfo& streamInfo = streams_[streamIndex];
streamInfo.discardFramesBeforePts =
*maybeDesiredPts_ * streamInfo.timeBase.den;
}

decodeStats_.numSeeksAttempted++;
// See comment for canWeAvoidSeeking() for details on why this optimization
// works.
bool mustSeek = false;
for (int streamIndex : activeStreamIndices_) {
StreamInfo& streamInfo = streams_[streamIndex];
int64_t desiredPtsForStream = *maybeDesiredPts_ * streamInfo.timeBase.den;
if (!canWeAvoidSeekingForStream(
streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
streamInfo,
streamInfo.currentPts,
*streamInfo.discardFramesBeforePts)) {
VLOG(5) << "Seeking is needed for streamIndex=" << streamIndex
<< " desiredPts=" << desiredPtsForStream
<< " desiredPts=" << *streamInfo.discardFramesBeforePts
<< " currentPts=" << streamInfo.currentPts;
mustSeek = true;
break;
Expand All @@ -678,7 +677,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
}
int firstActiveStreamIndex = *activeStreamIndices_.begin();
const auto& firstStreamInfo = streams_[firstActiveStreamIndex];
int64_t desiredPts = *maybeDesiredPts_ * firstStreamInfo.timeBase.den;
int64_t desiredPts = *firstStreamInfo.discardFramesBeforePts;

// For some encodings like H265, FFMPEG sometimes seeks past the point we
// set as the max_ts. So we use our own index to give it the exact pts of
Expand Down Expand Up @@ -718,10 +717,10 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
}
VLOG(9) << "Starting getDecodedOutputWithFilter()";
resetDecodeStats();
if (maybeDesiredPts_.has_value()) {
VLOG(9) << "maybeDesiredPts_=" << *maybeDesiredPts_;
if (hasDesiredPts_) {
maybeSeekToBeforeDesiredPts();
maybeDesiredPts_ = std::nullopt;
hasDesiredPts_ = false;
// FIXME: should we also reset each stream info's discardFramesBeforePts?
VLOG(9) << "seeking done";
}
auto seekDone = std::chrono::high_resolution_clock::now();
Expand Down Expand Up @@ -988,7 +987,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
validateFrameIndex(stream, frameIndex);

int64_t pts = stream.allFrames[frameIndex].pts;
setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
setCursorPts(pts);
return getNextDecodedOutputNoDemux();
}

Expand All @@ -1010,7 +1009,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
"Invalid frame index=" + std::to_string(frameIndex));
}
int64_t pts = stream.allFrames[frameIndex].pts;
setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
setCursorPts(pts);
auto rawSingleOutput = getNextRawDecodedOutputNoDemux();
if (stream.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
// We are using sws_scale to convert the frame to tensor. sws_scale can
Expand Down Expand Up @@ -1179,7 +1178,18 @@ VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux() {
}

// Records a pending seek target expressed in seconds. For every active stream
// the target is stored in that stream's discardFramesBeforePts, and
// hasDesiredPts_ is raised so that the next call to
// getDecodedOutputWithFilter() runs maybeSeekToBeforeDesiredPts() and then
// clears the flag again.
void VideoDecoder::setCursorPtsInSeconds(double seconds) {
// NOTE(review): in this diff view the assignment below appears to be the
// removed pre-refactor line (the header hunk replaces maybeDesiredPts_ with
// hasDesiredPts_) -- confirm whether it still exists after the commit.
maybeDesiredPts_ = seconds;
for (int streamIndex : activeStreamIndices_) {
StreamInfo& streamInfo = streams_[streamIndex];
// Only timeBase.den is used for the seconds-to-pts conversion; timeBase.num
// is not consulted. NOTE(review): presumably this assumes num == 1 -- confirm.
// The double product is implicitly converted on assignment; the field's type
// is not visible here (presumably an optional int64 timestamp) -- TODO confirm
// that truncation is the intended rounding.
streamInfo.discardFramesBeforePts = seconds * streamInfo.timeBase.den;
}
// Signals that a seek has been requested but not performed yet.
hasDesiredPts_ = true;
}

// Records a pending seek target given directly as an integer pts. The same raw
// value is written to discardFramesBeforePts of every active stream, after
// which hasDesiredPts_ is set to mark that a seek has been requested.
// NOTE(review): callers in view pass a pts taken from one stream's frame
// index; presumably all active streams share that time base -- confirm.
void VideoDecoder::setCursorPts(int64_t pts) {
  for (const int activeIndex : activeStreamIndices_) {
    auto& activeStream = streams_[activeIndex];
    activeStream.discardFramesBeforePts = pts;
  }
  // Raised only after every active stream has received its target.
  hasDesiredPts_ = true;
}

VideoDecoder::DecodeStats VideoDecoder::getDecodeStats() const {
Expand Down
7 changes: 4 additions & 3 deletions src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ class VideoDecoder {
void convertAVFrameToDecodedOutputOnCPU(
RawDecodedOutput& rawOutput,
DecodedOutput& output);
void setCursorPts(int64_t pts);

DecoderOptions options_;
ContainerMetadata containerMetadata_;
Expand All @@ -375,9 +376,9 @@ class VideoDecoder {
// Stores the stream indices of the active streams, i.e. the streams we are
// decoding and returning to the user.
std::set<int> activeStreamIndices_;
// Set when the user wants to seek and stores the desired pts that the user
// wants to seek to.
std::optional<double> maybeDesiredPts_;
// True when the user wants to seek. The actual pts values to seek to are
// stored in the per-stream metadata in discardFramesBeforePts.
bool hasDesiredPts_;

// Stores various internal decoding stats.
DecodeStats decodeStats_;
Expand Down

0 comments on commit 2ea149d

Please sign in to comment.