Skip to content

Commit

Permalink
Merge pull request #135 from pytti-tools/test
Browse files Browse the repository at this point in the history
merge test for release
  • Loading branch information
dmarx authored Apr 17, 2022
2 parents 520c29e + 0c94a36 commit 37098d8
Show file tree
Hide file tree
Showing 5 changed files with 558 additions and 321 deletions.
313 changes: 31 additions & 282 deletions src/pytti/ImageGuide.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ def __init__(
embedder: nn.Module,
optimizer: optim.Optimizer = None,
lr: float = None,
null_update=True,
# null_update=True,
params=None,
writer=None,
OUTPATH=None,
base_name=None,
fig=None,
axs=None,
base_name=None,
OUTPATH=None, # <<<<<<<<<<<<<<
#####################
video_frames=None, # # only need this to pass to animate_video_source
optical_flows=None,
Expand All @@ -108,7 +108,7 @@ def __init__(
self.optimizer = optimizer
self.dataframe = []

self.null_update = null_update
# self.null_update = null_update
self.params = params
self.writer = writer
self.OUTPATH = OUTPATH
Expand All @@ -117,13 +117,13 @@ def __init__(
self.axs = axs
self.video_frames = video_frames
self.optical_flows = optical_flows
if stabilization_augs is None:
stabilization_augs = []
# if stabilization_augs is None:
# stabilization_augs = []
self.stabilization_augs = stabilization_augs
self.last_frame_semantic = last_frame_semantic
self.semantic_init_prompt = semantic_init_prompt
if init_augs is None:
init_augs = []
# if init_augs is None:
# init_augs = []
self.init_augs = init_augs

def run_steps(
Expand Down Expand Up @@ -152,8 +152,26 @@ def run_steps(
# and here we can check if the DirectImageGuide was
# initialized with a renderer or not, and call self.renderer.update()
# if appropriate
if not self.null_update:
self.update(i + i_offset, i + skipped_steps)
# if not self.null_update:
# self.update(i + i_offset, i + skipped_steps)
self.update(
model=self,
img=self.image_rep,
i=i + i_offset,
stage_i=i + skipped_steps,
params=self.params,
writer=self.writer,
fig=self.fig,
axs=self.axs,
base_name=self.base_name,
optical_flows=self.optical_flows,
video_frames=self.video_frames,
stabilization_augs=self.stabilization_augs,
last_frame_semantic=self.last_frame_semantic,
embedder=self.embedder,
init_augs=self.init_augs,
semantic_init_prompt=self.semantic_init_prompt,
)
losses = self.train(
i + skipped_steps,
prompts,
Expand Down Expand Up @@ -343,277 +361,8 @@ def train(

return {"TOTAL": float(total_loss)}

def report_out(
self,
i,
stage_i,
# model,
writer,
fig, # default to None...
axs, # default to None...
clear_every,
display_every,
approximate_vram_usage,
display_scale,
show_graphs,
show_palette,
):
model = self
img = self.image_rep # pretty sure this is right
# DM: I bet this could be abstracted out into a report_out() function or whatever
if clear_every > 0 and i > 0 and i % clear_every == 0:
display.clear_output()

if display_every > 0 and i % display_every == 0:
logger.debug(f"Step {i} losses:")
if model.dataframe:
rec = model.dataframe[0].iloc[-1]
logger.debug(rec)
if writer is not None:
for k, v in rec.iteritems():
writer.add_scalar(
tag=f"losses/{k}", scalar_value=v, global_step=i
)

# does this VRAM stuff even do anything?
if approximate_vram_usage:
logger.debug("VRAM Usage:")
print_vram_usage() # update this function to use logger
# update this stuff to use/rely on tensorboard
display_width = int(img.image_shape[0] * display_scale)
display_height = int(img.image_shape[1] * display_scale)
if stage_i > 0 and show_graphs:
model.plot_losses(axs)
im = img.decode_image()
sidebyside = make_hbox(
im.resize((display_width, display_height), Image.LANCZOS),
fig,
)
display.display(sidebyside)
else:
im = img.decode_image()
display.display(
im.resize((display_width, display_height), Image.LANCZOS)
)
logger.debug(PixelImage)
logger.debug(type(PixelImage))
if show_palette and isinstance(img, PixelImage):
logger.debug("Palette:")
display.display(img.render_pallet())

def save_out(
self,
i,
# img,
writer,
OUTPATH,
base_name,
save_every,
file_namespace,
backups,
):
img = self.image_rep
# save
# if i > 0 and save_every > 0 and i % save_every == 0:
if i > 0 and save_every > 0 and (i + 1) % save_every == 0:
im = (
img.decode_image()
) # let's turn this into a property so decoding is cheap
# n = i // save_every
n = (i + 1) // save_every
filename = f"{OUTPATH}/{file_namespace}/{base_name}_{n}.png"
logger.debug(filename)
im.save(filename)

im_np = np.array(im)
if writer is not None:
writer.add_image(
tag="pytti output",
# img_tensor=filename, # thought this would work?
img_tensor=im_np,
global_step=i,
dataformats="HWC", # this was the key
)

if backups > 0:
filename = f"backup/{file_namespace}/{base_name}_{n}.bak"
torch.save(img.state_dict(), filename)
if n > backups:

# YOOOOOOO let's not start shell processes unnecessarily
# and then execute commands using string interpolation.
# Replace this with a pythonic folder removal, then see
# if we can't deprecate the folder removal entirely. What
# is the purpose of "backups" here? Just use the frames that
# are being written to disk.
subprocess.run(
[
"rm",
f"backup/{file_namespace}/{base_name}_{n-backups}.bak",
]
)

def update(
self,
# params,
# move to class
i,
stage_i,
):
def update(self, model, img, i, stage_i, *args, **kwargs):
"""
Orchestrates animation transformations and reporting
update hook called ever step
"""
# logger.debug("model.update called")

# ... I have regrets.
params = self.params
writer = self.writer
OUTPATH = self.OUTPATH
base_name = self.base_name
fig = self.fig
axs = self.axs
video_frames = self.video_frames
optical_flows = self.optical_flows
stabilization_augs = self.stabilization_augs
last_frame_semantic = self.last_frame_semantic
semantic_init_prompt = self.semantic_init_prompt
init_augs = self.init_augs

model = self
img = self.image_rep
embedder = self.embedder

model.report_out(
i=i,
stage_i=stage_i,
# model=model,
writer=writer,
fig=fig, # default to None...
axs=axs, # default to None...
clear_every=params.clear_every,
display_every=params.display_every,
approximate_vram_usage=params.approximate_vram_usage,
display_scale=params.display_scale,
show_graphs=params.show_graphs,
show_palette=params.show_palette,
)

model.save_out(
i=i,
# img=img,
writer=writer,
OUTPATH=OUTPATH,
base_name=base_name,
save_every=params.save_every,
file_namespace=params.file_namespace,
backups=params.backups,
)

# animate
################
## TO DO: attach T as a class attribute
t = (i - params.pre_animation_steps) / (
params.steps_per_frame * params.frames_per_second
)
set_t(t) # this won't need to be a thing with `t`` attached to the class
if i >= params.pre_animation_steps:
# next_step_pil = None
if (i - params.pre_animation_steps) % params.steps_per_frame == 0:
logger.debug(f"Time: {t:.4f} seconds")
# update_rotoscopers(
ROTOSCOPERS.update_rotoscopers(
((i - params.pre_animation_steps) // params.steps_per_frame + 1)
* params.frame_stride
)
if params.reset_lr_each_frame:
model.set_optim(None)

if params.animation_mode == "2D":

next_step_pil = animate_2d(
translate_y=params.translate_y,
translate_x=params.translate_x,
rotate_2d=params.rotate_2d,
zoom_x_2d=params.zoom_x_2d,
zoom_y_2d=params.zoom_y_2d,
infill_mode=params.infill_mode,
sampling_mode=params.sampling_mode,
writer=writer,
i=i,
img=img,
t=t, # just here for logging
)

###########################
elif params.animation_mode == "3D":
try:
im
except NameError:
im = img.decode_image()
with vram_usage_mode("Optical Flow Loss"):
# zoom_3d -> rename to animate_3d or transform_3d
flow, next_step_pil = zoom_3d(
img,
(
params.translate_x,
params.translate_y,
params.translate_z_3d,
),
params.rotate_3d,
params.field_of_view,
params.near_plane,
params.far_plane,
border_mode=params.infill_mode,
sampling_mode=params.sampling_mode,
stabilize=params.lock_camera,
)
freeze_vram_usage()

for optical_flow in optical_flows:
optical_flow.set_last_step(im)
optical_flow.set_target_flow(flow)
optical_flow.set_enabled(True)

elif params.animation_mode == "Video Source":

flow_im, next_step_pil = animate_video_source(
i=i,
img=img,
video_frames=video_frames,
optical_flows=optical_flows,
base_name=base_name,
pre_animation_steps=params.pre_animation_steps,
frame_stride=params.frame_stride,
steps_per_frame=params.steps_per_frame,
file_namespace=params.file_namespace,
reencode_each_frame=params.reencode_each_frame,
lock_palette=params.lock_palette,
save_every=params.save_every,
infill_mode=params.infill_mode,
sampling_mode=params.sampling_mode,
)

if params.animation_mode != "off":
try:
for aug in stabilization_augs:
aug.set_comp(next_step_pil)
aug.set_enabled(True)
if last_frame_semantic is not None:
last_frame_semantic.set_image(embedder, next_step_pil)
last_frame_semantic.set_enabled(True)
for aug in init_augs:
aug.set_enabled(False)
if semantic_init_prompt is not None:
semantic_init_prompt.set_enabled(False)
except UnboundLocalError:
logger.critical(
"\n\n-----< PYTTI-TOOLS > ------"
"If you are seeing this error, it might mean "
"you are using an option that expects you have "
"provided an init_image or video_file.\n\nIf you "
"think you are seeing this message in error, please "
"file an issue here: "
"https://github.com/pytti-tools/pytti-core/issues/new"
"-----< PYTTI-TOOLS > ------\n\n"
)
raise
pass
Loading

0 comments on commit 37098d8

Please sign in to comment.