Skip to content

Commit

Permalink
snagrecover: dfu: Simplify timeout logic
Browse files Browse the repository at this point in the history
The current logic to fetch and enforce timeouts during DFU downloads is
unecessarily complicated. Simplify it.

Signed-off-by: Romain Gantois <[email protected]>
  • Loading branch information
rgantois committed Jul 1, 2024
1 parent 3e53405 commit 0368231
Showing 1 changed file with 17 additions and 27 deletions.
44 changes: 17 additions & 27 deletions src/snagrecover/protocols/dfu.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,29 +96,21 @@ def __init__(self, dev: usb.core.Device, stm32: bool = True):
"""
self.transfer_size = bMaxPacketSize0 * (wTransferSize // bMaxPacketSize0)
logger.info(f"Found DFU Functional descriptor: wTransferSize = {self.transfer_size}")


def check_timeout(timeout: int, t0: int) -> int:
t1 = round(time.time() * 1000)
if t1 - t0 < timeout:
logger.warning("Too soon to send another get_status command")
while round(time.time() * 1000) - t0 < timeout:
pass
t1 = round(time.time() * 1000)
return round(time.time() * 1000)
self.status_timeout = 100

def get_status(self) -> tuple:
# make sure to wait long enough after last get_status()
time.sleep(self.status_timeout / 1000.0)
# status = status polltimeout state iString
status = self.dev.ctrl_transfer(0xa1, 3, wValue=0, wIndex=0, data_or_wLength=6)# DFU_GETSTATUS
state = status[4]
timeout = int.from_bytes(bytes(status[1:3]), "little")
self.status_timeout = int.from_bytes(bytes(status[1:3]), "little")
logger.debug(f"DFU state: {state} DFU status: {DFU.status_codes[status[0]]}")
return (state,timeout)
return state

def download_and_run(self, blob: bytes, partid: int, offset: int, size: int, show_progress=False) -> bool:
self.set_partition(partid)
(state,timeout) = self.get_status()
t0 = round(time.time() * 1000)
state = self.get_status()
if state != DFU.state_codes["dfuIDLE"]:
raise ValueError(f"Incompatible state {state} detected")

Expand All @@ -130,36 +122,34 @@ def download_and_run(self, blob: bytes, partid: int, offset: int, size: int, sho
bytes_written = 0
for chunk in utils.dnload_iter(blob[offset:offset + size], self.transfer_size):
bytes_written += self.dev.ctrl_transfer(0x21, 1, wValue=block_index, wIndex=0, data_or_wLength=chunk)
progress = int(100 * bytes_written / size)
if show_progress:
print(f"\rprogress:{progress}%", end="")

# make sure to wait enough before sending next get_status
t0 = DFU.check_timeout(timeout, t0)
(state,timeout) = self.get_status()
state = self.get_status()
while state != DFU.state_codes["dfuDNLOAD-IDLE"]:
# make sure to wait enough before sending next get_status
t0 = DFU.check_timeout(timeout, t0)
(state,timeout) = self.get_status()
if state == DFU.state_codes["dfuERROR"]:
raise ValueError("DFU error code reported by device!")
state = self.get_status()
block_index += 1

# send zero-length download command to leave DFU mode and manifest
# firmware
bytes_written += self.dev.ctrl_transfer(0x21, 1, wValue=block_index, wIndex=0, data_or_wLength=None)
t0 = DFU.check_timeout(timeout, t0)
(state,timeout) = self.get_status()
state = self.get_status()
while state != DFU.state_codes["dfuIDLE"]:
t0 = DFU.check_timeout(timeout, t0)
if state == DFU.state_codes["dfuMANIFEST"]:
(state,timeout) = self.get_status()
state = self.get_status()
time.sleep(1)
elif state == DFU.state_codes["dfuMANIFEST-SYNC"]:
try:
# this fails on AM625, but is still necessary
(state,timeout) = self.get_status()
state = self.get_status()
except usb.core.USBError:
print("Could not read status after end of manifest phase")
return True
elif state == DFU.state_codes["dfuMANIFEST-WAIT-RESET"]:
self.detach()
return True

if show_progress:
print("")
logger.info("Done manifesting firmware")
Expand Down

0 comments on commit 0368231

Please sign in to comment.