Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save fix for t0==t1 #494

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open

Conversation

dkweiss31
Copy link

Addresses edge case raised in #488 when t0 == t1 and saveat.ts is not None. Additionally if saveat.t0 is True then those values were not updated either, which should be addressed by this PR. I've additionally included a test for this case.

WRT the implementation: while a loop is not very nice since everything could in principle be done in parallel, the below did not work for the ts part due to dynamic slicing errors. Let me know if there is a nicer workaround I could try :)

if subsaveat.ts is not None:
    _ts = subsaveat.ts
    save_idx = save_state.save_index
    ts = save_state.ts.at[save_idx: save_idx + len(_ts)].set(_ts)
    _ys = [subsaveat.fn(t1, yfinal, args)] * len(_ts)
    ys = save_state.ys.at[save_idx: save_idx + len(_ts)].set(_ys)
    save_state = SaveState(
         saveat_ts_index=save_idx + len(_ts),
         ts=ts,
         ys=ys,
         save_index=save_idx + len(_ts),
     )

@dkweiss31
Copy link
Author

To address some failing tests re reverse mode differentiation I converted it to a while_loop, but I'm still seeing some failed tests. Converting this to a draft for now

@dkweiss31 dkweiss31 marked this pull request as draft August 21, 2024 19:18
@dkweiss31 dkweiss31 marked this pull request as ready for review November 13, 2024 13:15
@dkweiss31
Copy link
Author

@patrick-kidger sorry for the long delay! I think the PR is ready for review now. All tests pass except for one of the tqdm progress bar tests involving jit: I'm not at all sure what is going on there?

Additionally I wanted to draw your attention to the line I wrote on line 773:

def _save_ts_impl(ts, fn, _save_state):
    def _cond_fun(__save_state):
        return __save_state.saveat_ts_index < len(_save_state.ts)

where I had to use _save_state.ts instead of ts in the conditional check because saveat_ts_index can already be 1 if _save_state.t0==True. So if I used ts, then the last entry doesn't get updated. This doesn't mirror exactly what's happening on lines 421-427, so I just wanted to briefly mention it.

@dkweiss31 dkweiss31 changed the title Save fix for to==t1 Save fix for t0==t1 Nov 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant