Skip to content

Commit

Permalink
Merge pull request #764 from deepmodeling/zjgemi
Browse files Browse the repository at this point in the history
fix: support wf.query() in debug mode
  • Loading branch information
zjgemi authored Feb 22, 2024
2 parents 63bb0fa + a66ba11 commit 94d91a3
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/dflow/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,10 @@ def __init__(
if self.template.slices.input_parameter:
name = self.template.slices.input_parameter[0]
value = self.inputs.parameters[name].value
self.with_param = argo_range(argo_len(value))
if hasattr(value, "__len__"):
self.with_param = argo_range(len(value))
else:
self.with_param = argo_range(argo_len(value))
else:
assert len(self.template.slices.input_artifact) > 0, "sliced "\
"input parameter or artifact must not be empty to infer "\
Expand Down
30 changes: 29 additions & 1 deletion src/dflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,31 @@ def query(
Returns:
an ArgoWorkflow object
"""
if config["mode"] == "debug":
nodes = {}
for step in self.query_step():
step.inputs.parameters = list(step.inputs.parameters.values())
step.inputs.artifacts = list(step.inputs.artifacts.values())
step.outputs.parameters = list(
step.outputs.parameters.values())
step.outputs.artifacts = list(step.outputs.artifacts.values())
nodes[step.id] = step.recover()
outputs = self.query_global_outputs()
if outputs is not None:
outputs.parameters = list(outputs.parameters.values())
outputs.artifacts = list(outputs.artifacts.values())
outputs = outputs.recover()
response = {
"metadata": {
"name": self.id,
},
"status": {
"phase": self.query_status(),
"nodes": nodes,
"outputs": outputs,
}
}
return ArgoWorkflow(response)
query_params = None
if fields is not None:
query_params = [('fields', ",".join(fields))]
Expand Down Expand Up @@ -955,6 +980,7 @@ def query_step(
"workflow": self.id,
"displayName": _name,
"key": s,
"id": s,
"startedAt": os.path.getmtime(stepdir),
"phase": _phase,
"type": _type,
Expand Down Expand Up @@ -998,6 +1024,7 @@ def query_step(
})
step = ArgoStep(step, self.id)
step_list.append(step)
step_list.sort(key=lambda x: x["startedAt"])
return step_list

return self.query().get_step(name=name, key=key, phase=phase, id=id,
Expand All @@ -1015,7 +1042,8 @@ def query_keys_of_steps(
a list of keys
"""
if config["mode"] == "debug":
return [step.key for step in self.query_step()]
return [step.key for step in self.query_step()
if step.key is not None]
try:
try:
response = self.api_instance.api_client.call_api(
Expand Down

0 comments on commit 94d91a3

Please sign in to comment.