Skip to content

Commit

Permalink
Support valid data for multitask training (#257)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **New Features**
- Introduced support for multitask validation data, allowing users to
specify multiple validation datasets through new arguments.
- Enhanced flexibility in handling validation data, accommodating both
single and multitask configurations.
- Added dynamic model freezing capability based on configuration
parameters.
- Improved configurability by allowing external configuration data to be
integrated into function execution.

- **Bug Fixes**
- Improved error handling for validation data inputs, ensuring robust
processing of various data structures.

- **Documentation**
- Updated documentation to clarify the usage of new arguments related to
multitask validation data.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: zjgemi <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
zjgemi and pre-commit-ci[bot] authored Sep 3, 2024
1 parent 643e889 commit 8fb287e
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 36 deletions.
19 changes: 19 additions & 0 deletions dpgen2/entrypoint/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,11 @@ def input_args():
doc_valid_data_prefix = "The prefix of validation data systems"
doc_valid_sys = "The validation data systems"
doc_valid_data_uri = "The URI of validation data"
doc_multi_valid_data = (
"The validation data for multitask, it should be a dict, whose keys are task names and each value is a dict"
"containing fields `prefix` and `sys` for initial data of each task"
)
doc_multi_valid_data_uri = "The URI of validation data for multitask"

return [
Argument("type_map", List[str], optional=False, doc=doc_type_map),
Expand Down Expand Up @@ -607,6 +612,20 @@ def input_args():
default=None,
doc=doc_valid_data_uri,
),
Argument(
"multi_valid_data",
dict,
optional=True,
default=None,
doc=doc_multi_valid_data,
),
Argument(
"multi_valid_data_uri",
str,
optional=True,
default=None,
doc=doc_multi_valid_data_uri,
),
]


Expand Down
31 changes: 23 additions & 8 deletions dpgen2/entrypoint/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,13 +513,28 @@ def workflow_concurrent_learning(
]
upload_python_packages = _upload_python_packages

valid_data = config["inputs"]["valid_data_sys"]
if config["inputs"]["valid_data_uri"] is not None:
valid_data = get_artifact_from_uri(config["inputs"]["valid_data_uri"])
elif valid_data is not None:
valid_data_prefix = config["inputs"]["valid_data_prefix"]
valid_data = get_systems_from_data(valid_data, valid_data_prefix)
valid_data = upload_artifact_and_print_uri(valid_data, "valid_data")
multitask = config["inputs"]["multitask"]
valid_data = None
if multitask:
if config["inputs"]["multi_valid_data_uri"] is not None:
valid_data = get_artifact_from_uri(config["inputs"]["multi_valid_data_uri"])
elif config["inputs"]["multi_valid_data"] is not None:
multi_valid_data = config["inputs"]["multi_valid_data"]
valid_data = {}
for k, v in multi_valid_data.items():
sys = v["sys"]
sys = get_systems_from_data(sys, v.get("prefix", None))
valid_data[k] = sys
valid_data = upload_artifact_and_print_uri(valid_data, "multi_valid_data")
else:
if config["inputs"]["valid_data_uri"] is not None:
valid_data = get_artifact_from_uri(config["inputs"]["valid_data_uri"])
elif config["inputs"]["valid_data_prefix"] is not None:
valid_data_prefix = config["inputs"]["valid_data_prefix"]
valid_data = config["inputs"]["valid_data_sys"]
valid_data = get_systems_from_data(valid_data, valid_data_prefix)
valid_data = upload_artifact_and_print_uri(valid_data, "valid_data")

concurrent_learning_op = make_concurrent_learning_op(
train_style,
explore_style,
Expand Down Expand Up @@ -591,7 +606,7 @@ def workflow_concurrent_learning(
init_data = upload_artifact_and_print_uri(init_data, "multi_init_data")
train_config["multitask"] = True
train_config["head"] = head
explore_config["head"] = head
explore_config["model_frozen_head"] = head
else:
if config["inputs"]["init_data_uri"] is not None:
init_data = get_artifact_from_uri(config["inputs"]["init_data_uri"])
Expand Down
22 changes: 17 additions & 5 deletions dpgen2/op/run_dp_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def get_input_sign(cls):
"init_model": Artifact(Path, optional=True),
"init_data": Artifact(NestedDict[Path]),
"iter_data": Artifact(List[Path]),
"valid_data": Artifact(List[Path], optional=True),
"valid_data": Artifact(NestedDict[Path], optional=True),
"optional_files": Artifact(List[Path], optional=True),
}
)
Expand Down Expand Up @@ -182,11 +182,10 @@ def execute(
finetune_mode = ip["optional_parameter"]["finetune_mode"]
config = ip["config"] if ip["config"] is not None else {}
impl = ip["config"].get("impl", "tensorflow")
dp_command = ip["config"].get("command", "dp").split()
assert impl in ["tensorflow", "pytorch"]
if impl == "pytorch":
dp_command = ["dp", "--pt"]
else:
dp_command = ["dp"]
dp_command.append("--pt")
finetune_args = config.get("finetune_args", "")
train_args = config.get("train_args", "")
config = RunDPTrain.normalize_config(config)
Expand Down Expand Up @@ -356,7 +355,7 @@ def write_data_to_input_script(
iter_data: List[Path],
auto_prob_str: str = "prob_sys_size",
major_version: str = "1",
valid_data: Optional[List[Path]] = None,
valid_data: Optional[Union[List[Path], Dict[str, List[Path]]]] = None,
):
odict = idict.copy()
if config["multitask"]:
Expand All @@ -368,6 +367,11 @@ def write_data_to_input_script(
if k == head:
v["training_data"]["systems"] += [str(ii) for ii in iter_data]
v["training_data"]["auto_prob"] = auto_prob_str
if valid_data is None:
v.pop("validation_data", None)
else:
v["validation_data"] = v.get("validation_data", {"batch_size": 1})
v["validation_data"]["systems"] = [str(ii) for ii in valid_data[k]]
return odict
data_list = [str(ii) for ii in init_data] + [str(ii) for ii in iter_data]
if major_version == "1":
Expand Down Expand Up @@ -490,6 +494,7 @@ def decide_init_model(

@staticmethod
def training_args():
doc_command = "The command for DP, 'dp' for default"
doc_impl = "The implementation/backend of DP. It can be 'tensorflow' or 'pytorch'. 'tensorflow' for default."
doc_init_model_policy = "The policy of init-model training. It can be\n\n\
- 'no': No init-model training. Traing from scratch.\n\n\
Expand All @@ -513,6 +518,13 @@ def training_args():
doc_init_model_with_finetune = "Use finetune for init model"
doc_train_args = "Extra arguments for dp train"
return [
Argument(
"command",
str,
optional=True,
default="dp",
doc=doc_command,
),
Argument(
"impl",
str,
Expand Down
52 changes: 29 additions & 23 deletions dpgen2/op/run_lmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,28 +150,7 @@ def execute(
elif ext == ".pt":
# freeze model
mname = pytorch_model_name_pattern % (idx)
freeze_args = "-o %s" % mname
if config.get("head") is not None:
freeze_args += " --head %s" % config["head"]
freeze_cmd = "dp --pt freeze -c %s %s" % (mm, freeze_args)
ret, out, err = run_command(freeze_cmd, shell=True)
if ret != 0:
logging.error(
"".join(
(
"freeze failed\n",
"command was",
freeze_cmd,
"out msg",
out,
"\n",
"err msg",
err,
"\n",
)
)
)
raise TransientError("freeze failed")
freeze_model(mm, mname, config.get("model_frozen_head"))
else:
raise RuntimeError(
"Model file with extension '%s' is not supported" % ext
Expand Down Expand Up @@ -240,7 +219,9 @@ def lmp_args():
default=False,
doc=doc_shuffle_models,
),
Argument("head", str, optional=True, default=None, doc=doc_head),
Argument(
"model_frozen_head", str, optional=True, default=None, doc=doc_head
),
]

@staticmethod
Expand Down Expand Up @@ -310,3 +291,28 @@ def find_only_one_key(lmp_lines, key, raise_not_found=True):
else:
return None
return found[0]


def freeze_model(input_model, frozen_model, head=None):
    """Freeze a pytorch DP model checkpoint into a deployable frozen model.

    Runs ``dp --pt freeze`` as a subprocess.

    Parameters
    ----------
    input_model : str or Path
        Path to the input (unfrozen) pytorch model checkpoint.
    frozen_model : str or Path
        Output path for the frozen model.
    head : str, optional
        For multitask models, the task head to select when freezing.
        Omitted from the command line when None.

    Raises
    ------
    TransientError
        If the ``dp --pt freeze`` command exits with a non-zero status.
    """
    freeze_args = "-o %s" % frozen_model
    if head is not None:
        freeze_args += " --head %s" % head
    freeze_cmd = "dp --pt freeze -c %s %s" % (input_model, freeze_args)
    ret, out, err = run_command(freeze_cmd, shell=True)
    if ret != 0:
        # Fix: the original joined the message fragments with no separators
        # ("command was" abutted the command string and the "out msg"/"err msg"
        # labels), producing an unreadable log line. Use lazy %-style logging
        # with explicit labels and newlines instead.
        logging.error(
            "freeze failed\ncommand was: %s\nout msg: %s\nerr msg: %s\n",
            freeze_cmd,
            out,
            err,
        )
        raise TransientError("freeze failed")
37 changes: 37 additions & 0 deletions dpgen2/op/run_relax.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import os
from pathlib import (
Path,
Expand All @@ -6,6 +7,9 @@
List,
)

from dargs import (
Argument,
)
from dflow.python import (
OP,
OPIO,
Expand All @@ -14,13 +18,19 @@
OPIOSign,
)

from dpgen2.constants import (
pytorch_model_name_pattern,
)
from dpgen2.exploration.task import (
DiffCSPTaskGroup,
)

from .run_caly_model_devi import (
atoms2lmpdump,
)
from .run_lmp import (
freeze_model,
)


class RunRelax(OP):
Expand All @@ -29,6 +39,7 @@ def get_input_sign(cls):
return OPIOSign(
{
"diffcsp_task_grp": BigParameter(DiffCSPTaskGroup),
"expl_config": dict,
"task_path": Artifact(Path),
"models": Artifact(List[Path]),
}
Expand Down Expand Up @@ -73,6 +84,15 @@ def execute(
task_group = ip["diffcsp_task_grp"]
task = next(iter(task_group)) # Only support single task
models = ip["models"]
config = ip["expl_config"]
config = RunRelax.normalize_config(config)
if config["model_frozen_head"] is not None:
frozen_models = []
for idx in range(len(models)):
mname = pytorch_model_name_pattern % (idx)
freeze_model(models[idx], mname, config["model_frozen_head"])
frozen_models.append(Path(mname))
models = frozen_models
relaxer = Relaxer(models[0])
type_map = relaxer.calculator.dp.get_type_map()
fmax = task.fmax
Expand Down Expand Up @@ -178,3 +198,20 @@ def execute(
"model_devis": model_devis,
}
)

@staticmethod
def relax_args():
    """Argument definitions accepted by the relax exploration config."""
    doc_model_frozen_head = "Select a head from multitask"
    head_arg = Argument(
        "model_frozen_head",
        str,
        optional=True,
        default=None,
        doc=doc_model_frozen_head,
    )
    return [head_arg]

@staticmethod
def normalize_config(data=None):
    """Normalize and validate an exploration config against ``relax_args``.

    Parameters
    ----------
    data : dict, optional
        Raw configuration mapping; treated as an empty dict when omitted.

    Returns
    -------
    dict
        The normalized configuration with defaults filled in by dargs.
    """
    # Fix: the original used a mutable default argument (data={}), which is
    # shared across calls; use None as the sentinel instead.
    if data is None:
        data = {}
    ta = RunRelax.relax_args()
    base = Argument("base", dict, ta)
    data = base.normalize_value(data, trim_pattern="_*")
    base.check_value(data, strict=False)
    return data
1 change: 1 addition & 0 deletions dpgen2/superop/prep_run_diffcsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def _prep_run_diffcsp(
),
parameters={
"diffcsp_task_grp": expl_task_grp,
"expl_config": expl_config,
},
artifacts={
"models": models,
Expand Down
1 change: 1 addition & 0 deletions tests/op/test_run_relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def testRunRelax(self, mocked_run):
op_in = OPIO(
{
"diffcsp_task_grp": task_group,
"expl_config": {},
"task_path": Path("task.000000"),
"models": [Path("model_0.pt"), Path("model_1.pt")],
}
Expand Down

0 comments on commit 8fb287e

Please sign in to comment.