Skip to content

Commit

Permalink
Phase Sequence: modifier functions that copy must use attr.evolve (#961)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 337203373
  • Loading branch information
arsharma1 authored Oct 15, 2020
1 parent 94bebbc commit 70550f9
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 5 deletions.
15 changes: 10 additions & 5 deletions openhtf/core/phase_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,28 +155,33 @@ def _asdict(self) -> Dict[Text, Any]:

def with_args(self: SequenceClassT, **kwargs: Any) -> SequenceClassT:
"""Send these keyword-arguments when phases are called."""
return type(self)(
return attr.evolve(
self,
nodes=tuple(n.with_args(**kwargs) for n in self.nodes),
name=util.format_string(self.name, kwargs))

def with_plugs(self: SequenceClassT,
**subplugs: Type[base_plugs.BasePlug]) -> SequenceClassT:
"""Substitute plugs for placeholders for this phase, error on unknowns."""
return type(self)(
return attr.evolve(
self,
nodes=tuple(n.with_plugs(**subplugs) for n in self.nodes),
name=util.format_string(self.name, subplugs))

def load_code_info(self: SequenceClassT) -> SequenceClassT:
"""Load coded info for all contained phases."""
return type(self)(
nodes=tuple(n.load_code_info() for n in self.nodes), name=self.name)
return attr.evolve(
self,
nodes=tuple(n.load_code_info() for n in self.nodes),
name=self.name)

def apply_to_all_phases(
self: SequenceClassT, func: Callable[[phase_descriptor.PhaseDescriptor],
phase_descriptor.PhaseDescriptor]
) -> SequenceClassT:
"""Apply func to all contained phases."""
return type(self)(
return attr.evolve(
self,
nodes=tuple(n.apply_to_all_phases(func) for n in self.nodes),
name=self.name)

Expand Down
52 changes: 52 additions & 0 deletions test/core/phase_branches_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,58 @@ def test_as_dict(self):
}
self.assertEqual(expected, branch._asdict())

def test_with_args(self):
branch = phase_branches.BranchSequence(
phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET),
nodes=(run_phase,),
name='name_{arg}')
expected = phase_branches.BranchSequence(
phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET),
nodes=(run_phase.with_args(arg=1),),
name='name_1')

self.assertEqual(expected, branch.with_args(arg=1))

def test_with_plugs(self):

class MyPlug(htf.BasePlug):
pass

branch = phase_branches.BranchSequence(
phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET),
nodes=(run_phase,),
name='name_{my_plug.__name__}')
expected = phase_branches.BranchSequence(
phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET),
nodes=(run_phase.with_plugs(my_plug=MyPlug),),
name='name_MyPlug')

self.assertEqual(expected, branch.with_plugs(my_plug=MyPlug))

def test_load_code_info(self):
branch = phase_branches.BranchSequence(
phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET),
nodes=(run_phase,))
expected = phase_branches.BranchSequence(
phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET),
nodes=(run_phase.load_code_info(),))

self.assertEqual(expected, branch.load_code_info())

def test_apply_to_all_phases(self):

def do_rename(phase):
return _rename(phase, 'blah_blah')

branch = phase_branches.BranchSequence(
phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET),
nodes=(run_phase,))
expected = phase_branches.BranchSequence(
phase_branches.DiagnosisCondition.on_all(BranchDiagResult.SET),
nodes=(do_rename(run_phase),))

self.assertEqual(expected, branch.apply_to_all_phases(do_rename))


class BranchSequenceIntegrationTest(htf_test.TestCase):

Expand Down

0 comments on commit 70550f9

Please sign in to comment.