From 567ffd799625c17c98ff44461e16c48b10224f15 Mon Sep 17 00:00:00 2001 From: Nick Heppert Date: Thu, 15 Feb 2024 15:37:56 +0100 Subject: [PATCH] Adds a flag to step in raft to store the flow --- flow_control/flow/module_raft.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flow_control/flow/module_raft.py b/flow_control/flow/module_raft.py index 8ba9949..86de79b 100644 --- a/flow_control/flow/module_raft.py +++ b/flow_control/flow/module_raft.py @@ -74,7 +74,7 @@ def _totorch(self, array): """ return torch.from_numpy(array)[None].float().permute(0, 3, 1, 2).cuda() - def step(self, img0, img1): + def step(self, img0, img1, store_flow: bool=True): """ compute flow @@ -100,7 +100,7 @@ def step(self, img0, img1): test_mode=True ) - self.flow_prev = forward_interpolate(flow_low[0])[None].cuda() + self.flow_prev = forward_interpolate(flow_low[0])[None].cuda() if store_flow else None return padder.unpad(flow_up[0]).permute(1, 2, 0).detach().cpu().numpy() def warp(self, x, flow, mode="bilinear"):