diff --git a/src/pytti/ImageGuide.py b/src/pytti/ImageGuide.py
index 5cb6ccb..3534ce0 100644
--- a/src/pytti/ImageGuide.py
+++ b/src/pytti/ImageGuide.py
@@ -77,13 +77,13 @@ def __init__(
         embedder: nn.Module,
         optimizer: optim.Optimizer = None,
         lr: float = None,
-        null_update=True,
+        # null_update=True,
         params=None,
         writer=None,
-        OUTPATH=None,
-        base_name=None,
         fig=None,
         axs=None,
+        base_name=None,
+        OUTPATH=None,  # <<<<<<<<<<<<<< #####################
         video_frames=None,  # only need this to pass to animate_video_source
         optical_flows=None,
@@ -108,7 +108,7 @@ def __init__(
         self.optimizer = optimizer
         self.dataframe = []

-        self.null_update = null_update
+        # self.null_update = null_update
         self.params = params
         self.writer = writer
         self.OUTPATH = OUTPATH
@@ -117,13 +117,13 @@ def __init__(
         self.axs = axs
         self.video_frames = video_frames
         self.optical_flows = optical_flows
-        if stabilization_augs is None:
-            stabilization_augs = []
+        # if stabilization_augs is None:
+        #     stabilization_augs = []
         self.stabilization_augs = stabilization_augs
         self.last_frame_semantic = last_frame_semantic
         self.semantic_init_prompt = semantic_init_prompt
-        if init_augs is None:
-            init_augs = []
+        # if init_augs is None:
+        #     init_augs = []
         self.init_augs = init_augs

     def run_steps(
@@ -152,8 +152,26 @@ def run_steps(
             # and here we can check if the DirectImageGuide was
             # initialized with a renderer or not, and call self.renderer.update()
             # if appropriate
-            if not self.null_update:
-                self.update(i + i_offset, i + skipped_steps)
+            # if not self.null_update:
+            #     self.update(i + i_offset, i + skipped_steps)
+            self.update(
+                model=self,
+                img=self.image_rep,
+                i=i + i_offset,
+                stage_i=i + skipped_steps,
+                params=self.params,
+                writer=self.writer,
+                fig=self.fig,
+                axs=self.axs,
+                base_name=self.base_name,
+                optical_flows=self.optical_flows,
+                video_frames=self.video_frames,
+                stabilization_augs=self.stabilization_augs,
+                last_frame_semantic=self.last_frame_semantic,
+                embedder=self.embedder,
+                init_augs=self.init_augs,
+                semantic_init_prompt=self.semantic_init_prompt,
+            )
             losses = self.train(
                 i + skipped_steps,
                 prompts,
@@ -343,277 +361,8 @@ def train(

         return {"TOTAL": float(total_loss)}

-    def report_out(
-        self,
-        i,
-        stage_i,
-        # model,
-        writer,
-        fig,  # default to None...
-        axs,  # default to None...
-        clear_every,
-        display_every,
-        approximate_vram_usage,
-        display_scale,
-        show_graphs,
-        show_palette,
-    ):
-        model = self
-        img = self.image_rep  # pretty sure this is right
-        # DM: I bet this could be abstracted out into a report_out() function or whatever
-        if clear_every > 0 and i > 0 and i % clear_every == 0:
-            display.clear_output()
-
-        if display_every > 0 and i % display_every == 0:
-            logger.debug(f"Step {i} losses:")
-            if model.dataframe:
-                rec = model.dataframe[0].iloc[-1]
-                logger.debug(rec)
-                if writer is not None:
-                    for k, v in rec.iteritems():
-                        writer.add_scalar(
-                            tag=f"losses/{k}", scalar_value=v, global_step=i
-                        )
-
-            # does this VRAM stuff even do anything?
-            if approximate_vram_usage:
-                logger.debug("VRAM Usage:")
-                print_vram_usage()  # update this function to use logger
-            # update this stuff to use/rely on tensorboard
-            display_width = int(img.image_shape[0] * display_scale)
-            display_height = int(img.image_shape[1] * display_scale)
-            if stage_i > 0 and show_graphs:
-                model.plot_losses(axs)
-                im = img.decode_image()
-                sidebyside = make_hbox(
-                    im.resize((display_width, display_height), Image.LANCZOS),
-                    fig,
-                )
-                display.display(sidebyside)
-            else:
-                im = img.decode_image()
-                display.display(
-                    im.resize((display_width, display_height), Image.LANCZOS)
-                )
-            logger.debug(PixelImage)
-            logger.debug(type(PixelImage))
-            if show_palette and isinstance(img, PixelImage):
-                logger.debug("Palette:")
-                display.display(img.render_pallet())
-
-    def save_out(
-        self,
-        i,
-        # img,
-        writer,
-        OUTPATH,
-        base_name,
-        save_every,
-        file_namespace,
-        backups,
-    ):
-        img = self.image_rep
-        # save
-        # if i > 0 and save_every > 0 and i % save_every == 0:
-        if i > 0 and save_every > 0 and (i + 1) % save_every == 0:
-            im = (
-                img.decode_image()
-            )  # let's turn this into a property so decoding is cheap
-            # n = i // save_every
-            n = (i + 1) // save_every
-            filename = f"{OUTPATH}/{file_namespace}/{base_name}_{n}.png"
-            logger.debug(filename)
-            im.save(filename)
-
-            im_np = np.array(im)
-            if writer is not None:
-                writer.add_image(
-                    tag="pytti output",
-                    # img_tensor=filename,  # thought this would work?
-                    img_tensor=im_np,
-                    global_step=i,
-                    dataformats="HWC",  # this was the key
-                )
-
-            if backups > 0:
-                filename = f"backup/{file_namespace}/{base_name}_{n}.bak"
-                torch.save(img.state_dict(), filename)
-                if n > backups:
-
-                    # YOOOOOOO let's not start shell processes unnecessarily
-                    # and then execute commands using string interpolation.
-                    # Replace this with a pythonic folder removal, then see
-                    # if we can't deprecate the folder removal entirely. What
-                    # is the purpose of "backups" here? Just use the frames that
-                    # are being written to disk.
-                    subprocess.run(
-                        [
-                            "rm",
-                            f"backup/{file_namespace}/{base_name}_{n-backups}.bak",
-                        ]
-                    )
-
-    def update(
-        self,
-        # params,
-        # move to class
-        i,
-        stage_i,
-    ):
+    def update(self, model, img, i, stage_i, *args, **kwargs):
         """
-        Orchestrates animation transformations and reporting
+        update hook called every step
         """
-        # logger.debug("model.update called")
-
-        # ... I have regrets.
-        params = self.params
-        writer = self.writer
-        OUTPATH = self.OUTPATH
-        base_name = self.base_name
-        fig = self.fig
-        axs = self.axs
-        video_frames = self.video_frames
-        optical_flows = self.optical_flows
-        stabilization_augs = self.stabilization_augs
-        last_frame_semantic = self.last_frame_semantic
-        semantic_init_prompt = self.semantic_init_prompt
-        init_augs = self.init_augs
-
-        model = self
-        img = self.image_rep
-        embedder = self.embedder
-
-        model.report_out(
-            i=i,
-            stage_i=stage_i,
-            # model=model,
-            writer=writer,
-            fig=fig,  # default to None...
-            axs=axs,  # default to None...
-            clear_every=params.clear_every,
-            display_every=params.display_every,
-            approximate_vram_usage=params.approximate_vram_usage,
-            display_scale=params.display_scale,
-            show_graphs=params.show_graphs,
-            show_palette=params.show_palette,
-        )
-
-        model.save_out(
-            i=i,
-            # img=img,
-            writer=writer,
-            OUTPATH=OUTPATH,
-            base_name=base_name,
-            save_every=params.save_every,
-            file_namespace=params.file_namespace,
-            backups=params.backups,
-        )
-
-        # animate
-        ################
-        ## TO DO: attach T as a class attribute
-        t = (i - params.pre_animation_steps) / (
-            params.steps_per_frame * params.frames_per_second
-        )
-        set_t(t)  # this won't need to be a thing with `t` attached to the class
-        if i >= params.pre_animation_steps:
-            # next_step_pil = None
-            if (i - params.pre_animation_steps) % params.steps_per_frame == 0:
-                logger.debug(f"Time: {t:.4f} seconds")
-                # update_rotoscopers(
-                ROTOSCOPERS.update_rotoscopers(
-                    ((i - params.pre_animation_steps) // params.steps_per_frame + 1)
-                    * params.frame_stride
-                )
-                if params.reset_lr_each_frame:
-                    model.set_optim(None)
-
-                if params.animation_mode == "2D":
-
-                    next_step_pil = animate_2d(
-                        translate_y=params.translate_y,
-                        translate_x=params.translate_x,
-                        rotate_2d=params.rotate_2d,
-                        zoom_x_2d=params.zoom_x_2d,
-                        zoom_y_2d=params.zoom_y_2d,
-                        infill_mode=params.infill_mode,
-                        sampling_mode=params.sampling_mode,
-                        writer=writer,
-                        i=i,
-                        img=img,
-                        t=t,  # just here for logging
-                    )
-
-                ###########################
-                elif params.animation_mode == "3D":
-                    try:
-                        im
-                    except NameError:
-                        im = img.decode_image()
-                    with vram_usage_mode("Optical Flow Loss"):
-                        # zoom_3d -> rename to animate_3d or transform_3d
-                        flow, next_step_pil = zoom_3d(
-                            img,
-                            (
-                                params.translate_x,
-                                params.translate_y,
-                                params.translate_z_3d,
-                            ),
-                            params.rotate_3d,
-                            params.field_of_view,
-                            params.near_plane,
-                            params.far_plane,
-                            border_mode=params.infill_mode,
-                            sampling_mode=params.sampling_mode,
-                            stabilize=params.lock_camera,
-                        )
-                        freeze_vram_usage()
-
-                    for optical_flow in optical_flows:
-                        optical_flow.set_last_step(im)
-                        optical_flow.set_target_flow(flow)
-                        optical_flow.set_enabled(True)
-
-                elif params.animation_mode == "Video Source":
-
-                    flow_im, next_step_pil = animate_video_source(
-                        i=i,
-                        img=img,
-                        video_frames=video_frames,
-                        optical_flows=optical_flows,
-                        base_name=base_name,
-                        pre_animation_steps=params.pre_animation_steps,
-                        frame_stride=params.frame_stride,
-                        steps_per_frame=params.steps_per_frame,
-                        file_namespace=params.file_namespace,
-                        reencode_each_frame=params.reencode_each_frame,
-                        lock_palette=params.lock_palette,
-                        save_every=params.save_every,
-                        infill_mode=params.infill_mode,
-                        sampling_mode=params.sampling_mode,
-                    )
-
-                if params.animation_mode != "off":
-                    try:
-                        for aug in stabilization_augs:
-                            aug.set_comp(next_step_pil)
-                            aug.set_enabled(True)
-                        if last_frame_semantic is not None:
-                            last_frame_semantic.set_image(embedder, next_step_pil)
-                            last_frame_semantic.set_enabled(True)
-                        for aug in init_augs:
-                            aug.set_enabled(False)
-                        if semantic_init_prompt is not None:
-                            semantic_init_prompt.set_enabled(False)
-                    except UnboundLocalError:
-                        logger.critical(
-                            "\n\n-----< PYTTI-TOOLS > ------"
-                            "If you are seeing this error, it might mean "
-                            "you are using an option that expects you have "
-                            "provided an init_image or video_file.\n\nIf you "
-                            "think you are seeing this message in error, please "
-                            "file an issue here: "
-                            "https://github.com/pytti-tools/pytti-core/issues/new"
-                            "-----< PYTTI-TOOLS > ------\n\n"
-                        )
-                        raise
+        pass
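A note on the new hook seam above: `run_steps()` now calls `self.update(...)` with everything passed as keyword arguments, and the class-level `update` is a no-op, so per-step behavior can be swapped in by assigning a plain function to the instance (exactly what `workhorse.py` does further down with `model.update = update`). A minimal sketch of a custom hook, assuming a `DirectImageGuide` instance named `guide` (the logging body is illustrative, not part of this change):

```python
from loguru import logger

def my_update(model, img, i, stage_i, **kwargs):
    # Same signature contract as the no-op DirectImageGuide.update:
    # everything arrives as keyword arguments, so state we don't
    # care about here is absorbed by **kwargs.
    if i % 10 == 0:
        logger.debug(f"step={i} stage_i={stage_i}")

# Assigning to the instance shadows the class method. Because run_steps()
# calls self.update(model=self, img=self.image_rep, ...), the plain function
# never needs binding -- the guide passes itself explicitly as `model`.
guide.update = my_update
```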
diff --git a/src/pytti/LossAug/LossOrchestratorClass.py b/src/pytti/LossAug/LossOrchestratorClass.py
index 6e64c4e..736bfe6 100644
--- a/src/pytti/LossAug/LossOrchestratorClass.py
+++ b/src/pytti/LossAug/LossOrchestratorClass.py
@@ -13,6 +13,133 @@
 from pytti.LossAug.EdgeLossClass import EdgeLoss

+#################################
+
+
+LOSS_DICT = {"edge": EdgeLoss, "depth": DepthLoss}
+
+
+def build_loss(weight_name, weight, name, img, pil_target):
+    # from pytti.LossAug import LOSS_DICT
+
+    weight_name, suffix = weight_name.split("_", 1)
+    if weight_name == "direct":
+        Loss = type(img).get_preferred_loss()
+    else:
+        Loss = LOSS_DICT[weight_name]
+    out = Loss.TargetImage(
+        f"{weight_name} {name}:{weight}", img.image_shape, pil_target
+    )
+    out.set_enabled(pil_target is not None)
+    return out
+
+
+#################################
+
+
+def configure_init_image(
+    init_image_pil: Image.Image,
+    restore: bool,
+    img: PixelImage,
+    params,
+    loss_augs,
+):
+
+    if init_image_pil is not None:
+        if not restore:
+            # move these logging statements into .encode_image()
+            logger.info("Encoding image...")
+            img.encode_image(init_image_pil)
+            logger.info("Encoded Image:")
+            # pretty sure this assumes we're in a notebook
+            display.display(img.decode_image())
+        # set up init image prompt
+        init_augs = ["direct_init_weight"]
+        init_augs = [
+            build_loss(
+                x,
+                params[x],
+                f"init image ({params.init_image})",
+                img,
+                init_image_pil,
+            )
+            for x in init_augs
+            if params[x] not in ["", "0"]
+        ]
+        loss_augs.extend(init_augs)
+        if params.semantic_init_weight not in ["", "0"]:
+            semantic_init_prompt = parse_prompt(
+                embedder,
+                f"init image [{params.init_image}]:{params.semantic_init_weight}",
+                init_image_pil,
+            )
+            prompts[0].append(semantic_init_prompt)
+        else:
+            semantic_init_prompt = None
+    else:
+        init_augs, semantic_init_prompt = [], None
+
+    return init_augs, semantic_init_prompt, loss_augs, img
+
+
+def configure_stabilization_augs(img, init_image_pil, params, loss_augs):
+    ## NB: in loss orchestrator impl, this begins with an init_image override.
+    ## possibly the source of the issue?
+    stabilization_augs = [
+        "direct_stabilization_weight",
+        "depth_stabilization_weight",
+        "edge_stabilization_weight",
+    ]
+    stabilization_augs = [
+        build_loss(x, params[x], "stabilization", img, init_image_pil)
+        for x in stabilization_augs
+        if params[x] not in ["", "0"]
+    ]
+    loss_augs.extend(stabilization_augs)
+
+    return loss_augs, img, init_image_pil, stabilization_augs
+
+
+def configure_optical_flows(img, params, loss_augs):
+
+    if params.animation_mode == "Video Source":
+        if params.flow_stabilization_weight == "":
+            params.flow_stabilization_weight = "0"
+        optical_flows = [
+            OpticalFlowLoss.TargetImage(
+                f"optical flow stabilization (frame {-2**i}):{params.flow_stabilization_weight}",
+                img.image_shape,
+            )
+            for i in range(params.flow_long_term_samples + 1)
+        ]
+        for optical_flow in optical_flows:
+            optical_flow.set_enabled(False)
+        loss_augs.extend(optical_flows)
+    elif params.animation_mode == "3D" and params.flow_stabilization_weight not in [
+        "0",
+        "",
+    ]:
+        optical_flows = [
+            TargetFlowLoss.TargetImage(
+                f"optical flow stabilization:{params.flow_stabilization_weight}",
+                img.image_shape,
+            )
+        ]
+        for optical_flow in optical_flows:
+            optical_flow.set_enabled(False)
+        loss_augs.extend(optical_flows)
+    else:
+        optical_flows = []
+    # other loss augs
+    if params.smoothing_weight != 0:
+        loss_augs.append(TVLoss(weight=params.smoothing_weight))
+
+    return img, loss_augs, optical_flows
+
+
+#######################################
+
+
 class LossBuilder:
     LOSS_DICT = {"edge": EdgeLoss, "depth": DepthLoss}

diff --git a/src/pytti/rotoscoper.py b/src/pytti/rotoscoper.py
index 0b2e835..57b464c 100644
--- a/src/pytti/rotoscoper.py
+++ b/src/pytti/rotoscoper.py
@@ -21,7 +21,9 @@ def update_rotoscopers(self, frame_n: int):

 ROTOSCOPERS = RotoscopingOrchestrator()  # fml...

-
+rotoscopers = ROTOSCOPERS.rotoscopers
+update_rotoscopers = ROTOSCOPERS.update_rotoscopers
+clear_rotoscopers = ROTOSCOPERS.clear_rotoscopers

 # surprised we're not using opencv here.
 # let's call this another unnecessary subprocess call to deprecate.
diff --git a/src/pytti/update_func.py b/src/pytti/update_func.py
new file mode 100644
index 0000000..7e8eb20
--- /dev/null
+++ b/src/pytti/update_func.py
@@ -0,0 +1,294 @@
+from pathlib import Path
+import os
+import subprocess
+
+from PIL import Image
+import numpy as np
+import torch
+from IPython import display
+from loguru import logger
+
+from pytti import (
+    parametric_eval,
+    set_t,
+    vram_usage_mode,
+    print_vram_usage,
+    freeze_vram_usage,
+)
+
+from pytti.Transforms import (
+    animate_2d,
+    zoom_2d,
+    zoom_3d,
+    animate_video_source,
+)
+
+from pytti.rotoscoper import (
+    # clear_rotoscopers,
+    update_rotoscopers,
+)
+
+# OUTPATH = f"{os.getcwd()}/images_out/"
+OUTPATH = f"{os.getcwd()}/images_out"
+
+
+# Update is called each step.
+def update(
+    model,
+    img,
+    i,
+    stage_i,
+    params=None,
+    writer=None,
+    fig=None,
+    axs=None,
+    base_name=None,
+    optical_flows=None,
+    video_frames=None,
+    stabilization_augs=None,
+    last_frame_semantic=None,
+    embedder=None,
+    init_augs=None,
+    semantic_init_prompt=None,
+):
+    def report_out(
+        img,
+        i,
+        stage_i,
+        model,
+        writer,
+        fig,  # default to None...
+        axs,  # default to None...
+        clear_every,
+        display_every,
+        approximate_vram_usage,
+        display_scale,
+        show_graphs,
+        show_palette,
+    ):
+
+        # DM: I bet this could be abstracted out into a report_out() function or whatever
+        if clear_every > 0 and i > 0 and i % clear_every == 0:
+            display.clear_output()
+
+        if display_every > 0 and i % display_every == 0:
+            logger.debug(f"Step {i} losses:")
+            if model.dataframe:
+                rec = model.dataframe[0].iloc[-1]
+                logger.debug(rec)
+                for k, v in rec.iteritems():
+                    writer.add_scalar(tag=f"losses/{k}", scalar_value=v, global_step=i)
+
+            # does this VRAM stuff even do anything?
+            if approximate_vram_usage:
+                logger.debug("VRAM Usage:")
+                print_vram_usage()  # update this function to use logger
+            # update this stuff to use/rely on tensorboard
+            display_width = int(img.image_shape[0] * display_scale)
+            display_height = int(img.image_shape[1] * display_scale)
+            if stage_i > 0 and show_graphs:
+                model.plot_losses(axs)
+                im = img.decode_image()
+                sidebyside = make_hbox(
+                    im.resize((display_width, display_height), Image.LANCZOS),
+                    fig,
+                )
+                display.display(sidebyside)
+            else:
+                im = img.decode_image()
+                display.display(
+                    im.resize((display_width, display_height), Image.LANCZOS)
+                )
+            if show_palette and isinstance(img, PixelImage):
+                logger.debug("Palette:")
+                display.display(img.render_pallet())
+
+    def save_out(
+        i,
+        img,
+        writer,
+        OUTPATH,
+        base_name,
+        save_every,
+        file_namespace,
+        backups,
+    ):
+        # save
+        if i > 0 and save_every > 0 and i % save_every == 0:
+            try:
+                im
+            except NameError:
+                im = img.decode_image()
+            n = i // save_every
+            Path(f"{OUTPATH}/{file_namespace}").mkdir(
+                parents=True,
+                exist_ok=True,
+            )
+            filename = f"{OUTPATH}/{file_namespace}/{base_name}_{n}.png"
+            im.save(filename)
+
+            im_np = np.array(im)
+            writer.add_image(
+                tag="pytti output",
+                # img_tensor=filename,  # thought this would work?
+                img_tensor=im_np,
+                global_step=i,
+                dataformats="HWC",  # this was the key
+            )
+
+            if backups > 0:
+                filename = f"backup/{file_namespace}/{base_name}_{n}.bak"
+                torch.save(img.state_dict(), filename)
+                if n > backups:
+
+                    # YOOOOOOO let's not start shell processes unnecessarily
+                    # and then execute commands using string interpolation.
+                    # Replace this with a pythonic folder removal, then see
+                    # if we can't deprecate the folder removal entirely. What
+                    # is the purpose of "backups" here? Just use the frames that
+                    # are being written to disk.
+                    subprocess.run(
+                        [
+                            "rm",
+                            f"backup/{file_namespace}/{base_name}_{n-backups}.bak",
+                        ]
+                    )
+
+    report_out(
+        img=img,
+        i=i,
+        stage_i=stage_i,
+        model=model,
+        writer=writer,
+        fig=fig,  # default to None...
+        axs=axs,  # default to None...
+        clear_every=params.clear_every,
+        display_every=params.display_every,
+        approximate_vram_usage=params.approximate_vram_usage,
+        display_scale=params.display_scale,
+        show_graphs=params.show_graphs,
+        show_palette=params.show_palette,
+    )
+
+    save_out(
+        i=i,
+        img=img,
+        writer=writer,
+        OUTPATH=OUTPATH,
+        base_name=base_name,
+        save_every=params.save_every,
+        file_namespace=params.file_namespace,
+        backups=params.backups,
+    )
+
+    # animate
+    ################
+    t = (i - params.pre_animation_steps) / (
+        params.steps_per_frame * params.frames_per_second
+    )
+    set_t(t)
+    if i >= params.pre_animation_steps:
+        if (i - params.pre_animation_steps) % params.steps_per_frame == 0:
+            logger.debug(f"Time: {t:.4f} seconds")
+            update_rotoscopers(
+                ((i - params.pre_animation_steps) // params.steps_per_frame + 1)
+                * params.frame_stride
+            )
+            if params.reset_lr_each_frame:
+                model.set_optim(None)
+            if params.animation_mode == "2D":
+                tx, ty = parametric_eval(params.translate_x), parametric_eval(
+                    params.translate_y
+                )
+                theta = parametric_eval(params.rotate_2d)
+                zx, zy = parametric_eval(params.zoom_x_2d), parametric_eval(
+                    params.zoom_y_2d
+                )
+                next_step_pil = zoom_2d(
+                    img,
+                    (tx, ty),
+                    (zx, zy),
+                    theta,
+                    border_mode=params.infill_mode,
+                    sampling_mode=params.sampling_mode,
+                )
+                ################
+                for k, v in {
+                    "tx": tx,
+                    "ty": ty,
+                    "theta": theta,
+                    "zx": zx,
+                    "zy": zy,
+                    "t": t,
+                }.items():
+
+                    writer.add_scalar(
+                        tag=f"translation_2d/{k}", scalar_value=v, global_step=i
+                    )
+
+            ###########################
+            elif params.animation_mode == "3D":
+                try:
+                    im
+                except NameError:
+                    im = img.decode_image()
+                with vram_usage_mode("Optical Flow Loss"):
+                    flow, next_step_pil = zoom_3d(
+                        img,
+                        (
+                            params.translate_x,
+                            params.translate_y,
+                            params.translate_z_3d,
+                        ),
+                        params.rotate_3d,
+                        params.field_of_view,
+                        params.near_plane,
+                        params.far_plane,
+                        border_mode=params.infill_mode,
+                        sampling_mode=params.sampling_mode,
+                        stabilize=params.lock_camera,
+                    )
+                    freeze_vram_usage()
+
+                for optical_flow in optical_flows:
+                    optical_flow.set_last_step(im)
+                    optical_flow.set_target_flow(flow)
+                    optical_flow.set_enabled(True)
+            elif params.animation_mode == "Video Source":
+                flow_im, next_step_pil = animate_video_source(
+                    i=i,
+                    img=img,
+                    video_frames=video_frames,
+                    optical_flows=optical_flows,
+                    base_name=base_name,
+                    pre_animation_steps=params.pre_animation_steps,
+                    frame_stride=params.frame_stride,
+                    steps_per_frame=params.steps_per_frame,
+                    file_namespace=params.file_namespace,
+                    reencode_each_frame=params.reencode_each_frame,
+                    lock_palette=params.lock_palette,
+                    save_every=params.save_every,
+                    infill_mode=params.infill_mode,
+                    sampling_mode=params.sampling_mode,
+                )
+
+            if params.animation_mode != "off":
+                for aug in stabilization_augs:
+                    aug.set_comp(next_step_pil)
+                    aug.set_enabled(True)
+                if last_frame_semantic is not None:
+                    last_frame_semantic.set_image(embedder, next_step_pil)
+                    last_frame_semantic.set_enabled(True)
+                for aug in init_augs:
+                    aug.set_enabled(False)
+                if semantic_init_prompt is not None:
+                    semantic_init_prompt.set_enabled(False)
+
+
+###############################################################
+###
+
+# Wait.... we literally instantiated the model just before
+# defining update here.
+# I bet all of this can go in the DirectImageGuide class and then
+# we can just instantiate that class with the config object.
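The `save_out` TODO above ("let's not start shell processes unnecessarily") has a one-line pythonic answer; a sketch of a drop-in replacement for the `subprocess.run(["rm", ...])` call, using the same local names as the enclosing function (not part of this diff):

```python
from pathlib import Path

# Unlink the stale backup directly instead of shelling out to `rm`.
stale = Path("backup") / file_namespace / f"{base_name}_{n - backups}.bak"
stale.unlink(missing_ok=True)  # missing_ok (Python 3.8+) tolerates an already-deleted file
```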
diff --git a/src/pytti/workhorse.py b/src/pytti/workhorse.py
index cdfc981..9fc6baa 100644
--- a/src/pytti/workhorse.py
+++ b/src/pytti/workhorse.py
@@ -70,6 +70,16 @@
 # writer = SummaryWriter(TB_LOGDIR)
 OUTPATH = f"{os.getcwd()}/images_out/"

+#######################################################
+
+from pytti.LossAug.LossOrchestratorClass import (
+    configure_init_image,
+    configure_stabilization_augs,
+    configure_optical_flows,
+)
+
+#######################################################
+
 # To do: move remaining gunk into this...
 # class Renderer:
 #     """
@@ -330,40 +340,93 @@ def do_run():

     #######################################

-    # set up losses
-    loss_orch = LossConfigurator(
-        init_image_pil=init_image_pil,
-        restore=restore,
-        img=img,
-        embedder=embedder,
-        prompts=prompts,
-        # params=params,
-        ########
-        # To do: group arguments into param groups
-        animation_mode=params.animation_mode,
-        init_image=params.init_image,
-        direct_image_prompts=params.direct_image_prompts,
-        semantic_init_weight=params.semantic_init_weight,
-        semantic_stabilization_weight=params.semantic_stabilization_weight,
-        flow_stabilization_weight=params.flow_stabilization_weight,
-        flow_long_term_samples=params.flow_long_term_samples,
-        smoothing_weight=params.smoothing_weight,
-        ###########
-        direct_init_weight=params.direct_init_weight,
-        direct_stabilization_weight=params.direct_stabilization_weight,
-        depth_stabilization_weight=params.depth_stabilization_weight,
-        edge_stabilization_weight=params.edge_stabilization_weight,
+    loss_augs = []
+
+    #####################
+    # set up init image #
+    #####################
+
+    (init_augs, semantic_init_prompt, loss_augs, img) = configure_init_image(
+        init_image_pil,
+        restore,
+        img,
+        params,
+        loss_augs,
     )

+    # other image prompts
+
+    loss_augs.extend(
+        type(img)
+        .get_preferred_loss()
+        .TargetImage(p.strip(), img.image_shape, is_path=True)
+        for p in params.direct_image_prompts.split("|")
+        if p.strip()
+    )
+
+    # stabilization
     (
         loss_augs,
-        init_augs,
-        stabilization_augs,
-        optical_flows,
-        semantic_init_prompt,
-        last_frame_semantic,
         img,
-    ) = loss_orch.configure_losses()
+        init_image_pil,
+        stabilization_augs,
+    ) = configure_stabilization_augs(img, init_image_pil, params, loss_augs)
+
+    ############################
+    ### I think this bit might've been lost in the shuffle?
+
+    if params.semantic_stabilization_weight not in ["0", ""]:
+        last_frame_semantic = parse_prompt(
+            embedder,
+            f"stabilization:{params.semantic_stabilization_weight}",
+            init_image_pil if init_image_pil else img.decode_image(),
+        )
+        last_frame_semantic.set_enabled(init_image_pil is not None)
+        for scene in prompts:
+            scene.append(last_frame_semantic)
+    else:
+        last_frame_semantic = None
+
+    ###
+    ############################
+
+    # optical flow
+    img, loss_augs, optical_flows = configure_optical_flows(img, params, loss_augs)
+
+    # # set up losses
+    # loss_orch = LossConfigurator(
+    #     init_image_pil=init_image_pil,
+    #     restore=restore,
+    #     img=img,
+    #     embedder=embedder,
+    #     prompts=prompts,
+    #     # params=params,
+    #     ########
+    #     # To do: group arguments into param groups
+    #     animation_mode=params.animation_mode,
+    #     init_image=params.init_image,
+    #     direct_image_prompts=params.direct_image_prompts,
+    #     semantic_init_weight=params.semantic_init_weight,
+    #     semantic_stabilization_weight=params.semantic_stabilization_weight,
+    #     flow_stabilization_weight=params.flow_stabilization_weight,
+    #     flow_long_term_samples=params.flow_long_term_samples,
+    #     smoothing_weight=params.smoothing_weight,
+    #     ###########
+    #     direct_init_weight=params.direct_init_weight,
+    #     direct_stabilization_weight=params.direct_stabilization_weight,
+    #     depth_stabilization_weight=params.depth_stabilization_weight,
+    #     edge_stabilization_weight=params.edge_stabilization_weight,
+    # )
+
+    # (
+    #     loss_augs,
+    #     init_augs,
+    #     stabilization_augs,
+    #     optical_flows,
+    #     semantic_init_prompt,
+    #     last_frame_semantic,
+    #     img,
+    # ) = loss_orch.configure_losses()

     # Phase 4 - setup outputs
     ##########################
@@ -437,25 +500,27 @@ def do_run():

     # make the main model object
     model = DirectImageGuide(
-        img,
-        embedder,
+        image_rep=img,
+        embedder=embedder,
         lr=params.learning_rate,
         params=params,
         writer=writer,
-        OUTPATH=OUTPATH,
-        base_name=base_name,
         fig=fig,
         axs=axs,
-        video_frames=video_frames,
-        # these can be passed in together as the loss orchestrator
+        base_name=base_name,
         optical_flows=optical_flows,
+        video_frames=video_frames,
         stabilization_augs=stabilization_augs,
-        last_frame_semantic=last_frame_semantic,  # fml...
-        semantic_init_prompt=semantic_init_prompt,
+        last_frame_semantic=last_frame_semantic,
+        # embedder=embedder,
         init_augs=init_augs,
-        null_update=False,  # uh... we can do better.
+        semantic_init_prompt=semantic_init_prompt,
     )

+    from pytti.update_func import update
+
+    model.update = update
+
     # Pretty sure this isn't necessary, Hydra should take care of saving
     # the run settings now
     settings_path = f"{OUTPATH}/{params.file_namespace}/{base_name}_settings.txt"
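Taken together, the refactor reduces `do_run()`'s loss setup and model wiring to a short, linear sequence. A condensed sketch of the new flow, trimmed to the essentials (it assumes the objects built earlier in `do_run()` -- `img`, `embedder`, `init_image_pil`, `restore`, `params` -- already exist; the real call site also passes `writer`, `fig`/`axs`, `video_frames`, and friends):

```python
from pytti.ImageGuide import DirectImageGuide
from pytti.LossAug.LossOrchestratorClass import (
    configure_init_image,
    configure_stabilization_augs,
    configure_optical_flows,
)
from pytti.update_func import update

loss_augs = []

# 1. init-image losses (plus the optional semantic init prompt)
init_augs, semantic_init_prompt, loss_augs, img = configure_init_image(
    init_image_pil, restore, img, params, loss_augs
)

# 2. stabilization losses (direct / depth / edge)
loss_augs, img, init_image_pil, stabilization_augs = configure_stabilization_augs(
    img, init_image_pil, params, loss_augs
)

# 3. optical-flow losses, plus TV smoothing
img, loss_augs, optical_flows = configure_optical_flows(img, params, loss_augs)

# Hand everything to the guide, then swap the real per-step callback
# in over the no-op DirectImageGuide.update hook.
model = DirectImageGuide(
    image_rep=img,
    embedder=embedder,
    lr=params.learning_rate,
    params=params,
    optical_flows=optical_flows,
    stabilization_augs=stabilization_augs,
    init_augs=init_augs,
    semantic_init_prompt=semantic_init_prompt,
)
model.update = update
```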