Refactor seeking to only store pts as int64 timestamp #262

Closed
wants to merge 2 commits
40 changes: 25 additions & 15 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -647,26 +647,25 @@ bool VideoDecoder::canWeAvoidSeekingForStream(
 // AVFormatContext if it is needed. We can skip seeking in certain cases. See
 // the comment of canWeAvoidSeeking() for details.
 void VideoDecoder::maybeSeekToBeforeDesiredPts() {
+  TORCH_CHECK(
+      hasDesiredPts_,
+      "maybeSeekToBeforeDesiredPts() called when hasDesiredPts_ is false");
   if (activeStreamIndices_.size() == 0) {
     return;
   }
-  for (int streamIndex : activeStreamIndices_) {
-    StreamInfo& streamInfo = streams_[streamIndex];
-    streamInfo.discardFramesBeforePts =
-        *maybeDesiredPts_ * streamInfo.timeBase.den;
-  }
 
   decodeStats_.numSeeksAttempted++;
   // See comment for canWeAvoidSeeking() for details on why this optimization
   // works.
   bool mustSeek = false;
   for (int streamIndex : activeStreamIndices_) {
     StreamInfo& streamInfo = streams_[streamIndex];
-    int64_t desiredPtsForStream = *maybeDesiredPts_ * streamInfo.timeBase.den;
     if (!canWeAvoidSeekingForStream(
-            streamInfo, streamInfo.currentPts, desiredPtsForStream)) {
+            streamInfo,
+            streamInfo.currentPts,
+            *streamInfo.discardFramesBeforePts)) {
       VLOG(5) << "Seeking is needed for streamIndex=" << streamIndex
-              << " desiredPts=" << desiredPtsForStream
+              << " desiredPts=" << *streamInfo.discardFramesBeforePts
               << " currentPts=" << streamInfo.currentPts;
       mustSeek = true;
       break;
@@ -678,7 +677,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
   }
   int firstActiveStreamIndex = *activeStreamIndices_.begin();
   const auto& firstStreamInfo = streams_[firstActiveStreamIndex];
-  int64_t desiredPts = *maybeDesiredPts_ * firstStreamInfo.timeBase.den;
+  int64_t desiredPts = *firstStreamInfo.discardFramesBeforePts;
 
   // For some encodings like H265, FFMPEG sometimes seeks past the point we
   // set as the max_ts. So we use our own index to give it the exact pts of
@@ -718,10 +717,10 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(
   }
   VLOG(9) << "Starting getDecodedOutputWithFilter()";
   resetDecodeStats();
-  if (maybeDesiredPts_.has_value()) {
-    VLOG(9) << "maybeDesiredPts_=" << *maybeDesiredPts_;
+  if (hasDesiredPts_) {
     maybeSeekToBeforeDesiredPts();
-    maybeDesiredPts_ = std::nullopt;
+    hasDesiredPts_ = false;
+    // FIXME: should we also reset each stream info's discardFramesBeforePts?
     VLOG(9) << "seeking done";
   }
   auto seekDone = std::chrono::high_resolution_clock::now();
@@ -988,7 +987,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
   validateFrameIndex(stream, frameIndex);
 
   int64_t pts = stream.allFrames[frameIndex].pts;
-  setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
+  setCursorPts(pts);
   return getNextDecodedOutputNoDemux();
 }

@@ -1010,7 +1009,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
         "Invalid frame index=" + std::to_string(frameIndex));
     }
     int64_t pts = stream.allFrames[frameIndex].pts;
-    setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
+    setCursorPts(pts);
     auto rawSingleOutput = getNextRawDecodedOutputNoDemux();
     if (stream.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
       // We are using sws_scale to convert the frame to tensor. sws_scale can
Expand Down Expand Up @@ -1179,7 +1178,18 @@ VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux() {
}

void VideoDecoder::setCursorPtsInSeconds(double seconds) {
maybeDesiredPts_ = seconds;
for (int streamIndex : activeStreamIndices_) {
Contributor:
I like the original change because it's lazier -- when the user seeks a bunch of times, only the last seek value should affect anything. The new change sets things eagerly.

More importantly, if the user adds a stream after calling seek, it doesn't seek that stream, so it's less robust too.
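To make the robustness point concrete, here is a hypothetical call sequence; the file-path factory and addVideoStreamDecoder() call are assumptions about the surrounding API, while the other calls appear in this diff. With the eager approach, a stream activated after the seek call never gets a per-stream target:

// Hypothetical usage sketch -- not part of this PR.
auto decoder = VideoDecoder::createFromFilePath("video.mp4");  // assumed factory
decoder->addVideoStreamDecoder(0);       // stream 0 becomes active
decoder->setCursorPtsInSeconds(5.0);     // eagerly writes stream 0's discardFramesBeforePts
decoder->addVideoStreamDecoder(1);       // stream 1 is activated after the seek call,
                                         // so its discardFramesBeforePts is never set
decoder->getNextDecodedOutputNoDemux();  // hasDesiredPts_ is true, but the pending seek
                                         // carries no target for stream 1

With the lazier approach that stores maybeDesiredPts_, the per-stream conversion happens inside maybeSeekToBeforeDesiredPts() at decode time, after all active streams are known.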

+    StreamInfo& streamInfo = streams_[streamIndex];
+    streamInfo.discardFramesBeforePts = seconds * streamInfo.timeBase.den;
+  }
+  hasDesiredPts_ = true;
 }
 
+void VideoDecoder::setCursorPts(int64_t pts) {
Contributor:
I don't know that this function makes sense without all streams having a common timebase

Contributor Author:
Since it's internal, and we always know a stream index when calling it, I think it's fine if this takes a stream index.
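To make the time-base concern concrete: an int64 pts value has no meaning on its own, and the same number denotes different instants in streams with different time bases. A small standalone illustration, not part of the PR; av_q2d() is the standard FFmpeg helper:

#include <cstdint>
#include <cstdio>

extern "C" {
#include <libavutil/rational.h>  // AVRational, av_q2d()
}

int main() {
  const int64_t pts = 90000;
  const AVRational videoTimeBase = {1, 90000};  // e.g. MPEG-TS video
  const AVRational audioTimeBase = {1, 44100};  // e.g. 44.1 kHz audio
  std::printf("pts=%lld -> %.3f s in the video stream\n",
              static_cast<long long>(pts), pts * av_q2d(videoTimeBase));  // 1.000 s
  std::printf("pts=%lld -> %.3f s in the audio stream\n",
              static_cast<long long>(pts), pts * av_q2d(audioTimeBase));  // ~2.041 s
  return 0;
}

A per-stream signature along the lines the author suggests could look like the following hypothetical variant (not what the diff currently contains):

void VideoDecoder::setCursorPts(int streamIndex, int64_t pts) {
  // pts is only interpreted in this one stream's time base.
  streams_[streamIndex].discardFramesBeforePts = pts;
  hasDesiredPts_ = true;
}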

+  for (int streamIndex : activeStreamIndices_) {
+    streams_[streamIndex].discardFramesBeforePts = pts;
+  }
+  hasDesiredPts_ = true;
+}
+
 VideoDecoder::DecodeStats VideoDecoder::getDecodeStats() const {
7 changes: 4 additions & 3 deletions src/torchcodec/decoders/_core/VideoDecoder.h
@@ -367,6 +367,7 @@ class VideoDecoder {
   void convertAVFrameToDecodedOutputOnCPU(
       RawDecodedOutput& rawOutput,
       DecodedOutput& output);
+  void setCursorPts(int64_t pts);
 
   DecoderOptions options_;
   ContainerMetadata containerMetadata_;
@@ -375,9 +376,9 @@
   // Stores the stream indices of the active streams, i.e. the streams we are
   // decoding and returning to the user.
   std::set<int> activeStreamIndices_;
-  // Set when the user wants to seek and stores the desired pts that the user
-  // wants to seek to.
-  std::optional<double> maybeDesiredPts_;
+  // True when the user wants to seek. The actual pts values to seek to are
+  // stored in the per-stream metadata in discardFramesBeforePts.
+  bool hasDesiredPts_ = false;
Contributor:
Seeking is hard to implement correctly and I am not confident about this change.

I think at least for debugging purposes we should keep the original double value around, and when returning a frame after a seek we should ensure it's >= the value passed in here.
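A rough sketch of that debug check, assuming the seconds value is kept next to hasDesiredPts_ and reusing the ptsToSeconds() helper already visible elsewhere in this diff; the member and function names here are assumptions:

// Remember the last value passed to setCursorPtsInSeconds(), purely for validation.
std::optional<double> lastSeekSecondsForDebug_;

// Call on the first frame returned after a seek.
void validateFrameIsNotBeforeSeekTarget(int64_t framePts, AVRational timeBase) {
  if (lastSeekSecondsForDebug_.has_value()) {
    double frameSeconds = ptsToSeconds(framePts, timeBase);
    TORCH_CHECK(
        frameSeconds >= *lastSeekSecondsForDebug_,
        "Returned a frame at ",
        frameSeconds,
        "s, which is before the requested seek target of ",
        *lastSeekSecondsForDebug_,
        "s");
  }
}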


   // Stores various internal decoding stats.
   DecodeStats decodeStats_;