From 6bb2b3da66a6cb915618b72574416824b9f63ee6 Mon Sep 17 00:00:00 2001 From: Jon Clucas Date: Mon, 31 Oct 2022 12:15:35 -0400 Subject: [PATCH 1/6] :goal_net: Guardrail BBR only --- .../nipype_pipeline_engine/__init__.py | 35 ++- .../pipeline/nipype_pipeline_engine/engine.py | 138 +++++++++--- CPAC/pipeline/random_state/__init__.py | 8 +- CPAC/pipeline/random_state/seed.py | 81 +++++-- CPAC/pipeline/schema.py | 15 +- CPAC/qc/__init__.py | 24 +- CPAC/qc/globals.py | 42 ++++ CPAC/qc/qcmetrics.py | 163 +++++++++++--- CPAC/registration/exceptions.py | 41 ++++ CPAC/registration/guardrails.py | 208 ++++++++++++++++++ CPAC/registration/registration.py | 201 +++++++++-------- .../configs/pipeline_config_blank.yml | 6 + .../configs/pipeline_config_default.yml | 14 +- .../configs/pipeline_config_rbc-options.yml | 3 + CPAC/utils/docs.py | 37 ++++ 15 files changed, 818 insertions(+), 198 deletions(-) create mode 100644 CPAC/qc/globals.py create mode 100644 CPAC/registration/exceptions.py create mode 100644 CPAC/registration/guardrails.py diff --git a/CPAC/pipeline/nipype_pipeline_engine/__init__.py b/CPAC/pipeline/nipype_pipeline_engine/__init__.py index 48b445241b..66f4111cce 100644 --- a/CPAC/pipeline/nipype_pipeline_engine/__init__.py +++ b/CPAC/pipeline/nipype_pipeline_engine/__init__.py @@ -1,25 +1,24 @@ -'''Module to import Nipype Pipeline engine and override some Classes. -See https://fcp-indi.github.io/docs/developer/nodes -for C-PAC-specific documentation. -See https://nipype.readthedocs.io/en/latest/api/generated/nipype.pipeline.engine.html -for Nipype's documentation. - -Copyright (C) 2022 C-PAC Developers +# Copyright (C) 2022 C-PAC Developers -This file is part of C-PAC. +# This file is part of C-PAC. -C-PAC is free software: you can redistribute it and/or modify it under -the terms of the GNU Lesser General Public License as published by the -Free Software Foundation, either version 3 of the License, or (at your -option) any later version. 
+# C-PAC is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. -C-PAC is distributed in the hope that it will be useful, but WITHOUT -ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or -FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public -License for more details. +# C-PAC is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. -You should have received a copy of the GNU Lesser General Public -License along with C-PAC. If not, see .''' # noqa: E501 +# You should have received a copy of the GNU Lesser General Public +# License along with C-PAC. If not, see . +'''Module to import Nipype Pipeline engine and override some Classes. +See https://fcp-indi.github.io/docs/developer/nodes +for C-PAC-specific documentation. +See https://nipype.readthedocs.io/en/latest/api/generated/nipype.pipeline.engine.html +for Nipype's documentation.''' # noqa: E501 # pylint: disable=line-too-long from nipype.pipeline import engine as pe # import everything in nipype.pipeline.engine.__all__ from nipype.pipeline.engine import * # noqa: F401,F403 diff --git a/CPAC/pipeline/nipype_pipeline_engine/engine.py b/CPAC/pipeline/nipype_pipeline_engine/engine.py index 12e8808f1f..8695b1e536 100644 --- a/CPAC/pipeline/nipype_pipeline_engine/engine.py +++ b/CPAC/pipeline/nipype_pipeline_engine/engine.py @@ -1,43 +1,57 @@ -'''Module to import Nipype Pipeline engine and override some Classes. -See https://fcp-indi.github.io/docs/developer/nodes -for C-PAC-specific documentation. -See https://nipype.readthedocs.io/en/latest/api/generated/nipype.pipeline.engine.html -for Nipype's documentation. 
+# STATEMENT OF CHANGES: +# This file is derived from sources licensed under the Apache-2.0 terms, +# and this file has been changed. -STATEMENT OF CHANGES: - This file is derived from sources licensed under the Apache-2.0 terms, - and this file has been changed. +# CHANGES: +# * Supports just-in-time dynamic memory allocation +# * Skips doctests that require files that we haven't copied over +# * Applies a random seed +# * Supports overriding memory estimates via a log file and a buffer +# * Adds quotation marks around strings in dotfiles -CHANGES: - * Supports just-in-time dynamic memory allocation - * Skips doctests that require files that we haven't copied over - * Applies a random seed - * Supports overriding memory estimates via a log file and a buffer +# ORIGINAL WORK'S ATTRIBUTION NOTICE: +# Copyright (c) 2009-2016, Nipype developers -ORIGINAL WORK'S ATTRIBUTION NOTICE: - Copyright (c) 2009-2016, Nipype developers +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 - http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- See the License for the specific language governing permissions and - limitations under the License. +# Prior to release 0.12, Nipype was licensed under a BSD license. - Prior to release 0.12, Nipype was licensed under a BSD license. +# Modifications Copyright (C) 2022 C-PAC Developers -Modifications Copyright (C) 2022 C-PAC Developers +# This file is part of C-PAC. -This file is part of C-PAC.''' # noqa: E501 +# C-PAC is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. + +# C-PAC is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +# You should have received a copy of the GNU Lesser General Public +# License along with C-PAC. If not, see . +'''Module to import Nipype Pipeline engine and override some Classes. +See https://fcp-indi.github.io/docs/developer/nodes +for C-PAC-specific documentation. 
+See https://nipype.readthedocs.io/en/latest/api/generated/nipype.pipeline.engine.html +for Nipype's documentation.''' # noqa: E501 # pylint: disable=line-too-long import os import re -from logging import getLogger from inspect import Parameter, Signature, signature +from logging import getLogger +from typing import Iterable, Tuple, Union from nibabel import load from nipype import logging from nipype.interfaces.utility import Function @@ -53,6 +67,7 @@ UNDEFINED_SIZE = (42, 42, 42, 1200) random_state_logger = getLogger('random') +logger = getLogger("nipype.workflow") def _check_mem_x_path(mem_x_path): @@ -399,10 +414,9 @@ def run(self, updatehash=False): if self.seed is not None: self._apply_random_seed() if self.seed_applied: - random_state_logger.info('%s', - '%s # (Atropos constant)' % - self.name if 'atropos' in - self.name else self.name) + random_state_logger.info('%s\t%s', '# (Atropos constant)' if + 'atropos' in self.name else + str(self.seed), self.name) return super().run(updatehash) @@ -483,6 +497,40 @@ def _configure_exec_nodes(self, graph): TypeError): self._handle_just_in_time_exception(node) + def connect_retries(self, nodes: Iterable['Node'], + connections: Iterable[Tuple['Node', Union[str, tuple], + str]]) -> None: + """Method to generalize making the same connections to try and + retry nodes. 
+ + For each 3-tuple (``conn``) in ``connections``, will do + ``wf.connect(conn[0], conn[1], node, conn[2])`` for each ``node`` + in ``nodes`` + + Parameters + ---------- + nodes : iterable of Nodes + + connections : iterable of 3-tuples of (Node, str or tuple, str) + """ + wrong_conn_type_msg = (r'connect_retries `connections` argument ' + 'must be an iterable of (Node, str or ' + 'tuple, str) tuples.') + if not isinstance(connections, (list, tuple)): + raise TypeError(f'{wrong_conn_type_msg}: Given {connections}') + for node in nodes: + if not isinstance(node, Node): + raise TypeError('connect_retries requires an iterable ' + r'of nodes for the `nodes` parameter: ' + f'Given {node}') + for conn in connections: + if not all((isinstance(conn, (list, tuple)), len(conn) == 3, + isinstance(conn[0], Node), + isinstance(conn[1], (tuple, str)), + isinstance(conn[2], str))): + raise TypeError(f'{wrong_conn_type_msg}: Given {conn}') + self.connect(*conn[:2], node, conn[2]) + def _handle_just_in_time_exception(self, node): # pylint: disable=protected-access if hasattr(self, '_local_func_scans'): @@ -492,6 +540,32 @@ def _handle_just_in_time_exception(self, node): # TODO: handle S3 files node._apply_mem_x(UNDEFINED_SIZE) # noqa: W0212 + def nodes_and_guardrails(self, *nodes, registered, add_clones=True): + """Returns a two tuples of Nodes: (try, retry) and their + respective guardrails + + Parameters + ---------- + nodes : any number of Nodes + + Returns + ------- + nodes : tuple of Nodes + + guardrails : tuple of Nodes + """ + from CPAC.registration.guardrails import registration_guardrail_node, \ + retry_clone + nodes = list(nodes) + if add_clones is True: + nodes.extend([retry_clone(node) for node in nodes]) + guardrails = [None] * len(nodes) + for i, node in enumerate(nodes): + guardrails[i] = registration_guardrail_node( + f'guardrail_{node.name}', i) + self.connect(node, registered, guardrails[i], 'registered') + return tuple(nodes), tuple(guardrails) + def 
get_data_size(filepath, mode='xyzt'): """Function to return the size of a functional image (x * y * z * t) diff --git a/CPAC/pipeline/random_state/__init__.py b/CPAC/pipeline/random_state/__init__.py index 5956f33416..b590912417 100644 --- a/CPAC/pipeline/random_state/__init__.py +++ b/CPAC/pipeline/random_state/__init__.py @@ -1,6 +1,6 @@ '''Random state for C-PAC''' -from .seed import random_seed, random_seed_flags, set_up_random_state, \ - set_up_random_state_logger +from .seed import MAX_SEED, random_seed, random_seed_flags, \ + set_up_random_state, set_up_random_state_logger -__all__ = ['random_seed', 'random_seed_flags', 'set_up_random_state', - 'set_up_random_state_logger'] +__all__ = ['MAX_SEED', 'random_seed', 'random_seed_flags', + 'set_up_random_state', 'set_up_random_state_logger'] diff --git a/CPAC/pipeline/random_state/seed.py b/CPAC/pipeline/random_state/seed.py index b66e0d4611..21b67ccb90 100644 --- a/CPAC/pipeline/random_state/seed.py +++ b/CPAC/pipeline/random_state/seed.py @@ -1,5 +1,20 @@ -'''Functions to set, check, and log random seed''' -import os +# Copyright (C) 2022 C-PAC Developers + +# This file is part of C-PAC. + +# C-PAC is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. + +# C-PAC is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +# You should have received a copy of the GNU Lesser General Public +# License along with C-PAC. If not, see . 
+"""Functions to set, check, and log random seed""" import random from logging import getLogger @@ -13,9 +28,26 @@ from CPAC.utils.interfaces.ants import AI from CPAC.utils.monitoring.custom_logging import set_up_logger +MAX_SEED = np.iinfo(np.int32).max _seed = {'seed': None} +def increment_seed(node): + """Increment the random seed for a given node + + Parameters + ---------- + node : Node + + Returns + ------- + node : Node + """ + if isinstance(node.seed, int): + node.seed = seed_plus_1() + return node + + def random_random_seed(): '''Returns a random postive integer up to 2147483647 @@ -29,10 +61,10 @@ def random_random_seed(): Examples -------- - >>> 0 < random_random_seed() <= np.iinfo(np.int32).max + >>> 0 < random_random_seed() <= MAX_SEED True ''' - return random.randint(1, np.iinfo(np.int32).max) + return random.randint(1, MAX_SEED) def random_seed(): @@ -46,7 +78,7 @@ def random_seed(): ------- seed : int or None ''' - if _seed['seed'] == 'random': + if _seed['seed'] in ['random', None]: _seed['seed'] = random_random_seed() return _seed['seed'] @@ -137,6 +169,24 @@ def _reusable_flags(): } +def seed_plus_1(seed=None): + '''Increment seed, looping back to 1 at MAX_SEED + + Parameters + ---------- + seed : int, optional + Uses configured seed if not specified + + Returns + ------- + int + ''' + seed = random_seed() if seed is None else int(seed) + if seed < MAX_SEED: # increment random seed + return seed + 1 + return 1 # loop back to 1 + + def set_up_random_state(seed): '''Set global random seed @@ -160,8 +210,8 @@ def set_up_random_state(seed): >>> set_up_random_state(0) Traceback (most recent call last): ValueError: Valid random seeds are positive integers up to 2147483647, "random", or None, not 0 - >>> set_up_random_state(None) - + >>> 1 <= set_up_random_state(None) <= MAX_SEED + True ''' # noqa: E501 # pylint: disable=line-too-long if seed is not None: if seed == 'random': @@ -169,11 +219,13 @@ def set_up_random_state(seed): else: try: seed = 
int(seed) - assert 0 < seed <= np.iinfo(np.int32).max - except(ValueError, TypeError, AssertionError): - raise ValueError('Valid random seeds are positive integers up to ' - f'2147483647, "random", or None, not {seed}') - + assert 0 < seed <= MAX_SEED + except (ValueError, TypeError, AssertionError) as error: + raise ValueError( + 'Valid random seeds are positive integers up ' + f'to {MAX_SEED}, "random", or None, not {seed}' + ) from error + _seed['seed'] = seed return random_seed() @@ -185,5 +237,6 @@ def set_up_random_state_logger(log_dir): ---------- log_dir : str ''' - set_up_logger('random', level='info', log_dir=log_dir) - getLogger('random').info('seed: %s', random_seed()) + set_up_logger('random', filename='random.tsv', level='info', + log_dir=log_dir) + getLogger('random').info('seed\tnode') diff --git a/CPAC/pipeline/schema.py b/CPAC/pipeline/schema.py index 19764b6cfb..1be140c7c5 100644 --- a/CPAC/pipeline/schema.py +++ b/CPAC/pipeline/schema.py @@ -19,12 +19,12 @@ # pylint: disable=too-many-lines import re from itertools import chain, permutations -import numpy as np from pathvalidate import sanitize_filename -from voluptuous import All, ALLOW_EXTRA, Any, Capitalize, Coerce, \ +from voluptuous import All, ALLOW_EXTRA, Any, Capitalize, Coerce, Equal, \ ExactSequence, ExclusiveInvalid, In, Length, Lower, \ Match, Maybe, Optional, Range, Required, Schema from CPAC import docs_prefix +from CPAC.pipeline.random_state.seed import MAX_SEED from CPAC.utils.datatypes import ListFromItem from CPAC.utils.utils import YAML_BOOLS @@ -467,6 +467,10 @@ def sanitize(filename): }, }, 'registration_workflows': { + 'guardrails': { + 'thresholds': {metric: Maybe(float) for metric in + ('Dice', 'Jaccard', 'CrossCorr', 'Coverage')}, + }}, 'anatomical_registration': { 'run': bool1_1, 'resolution_for_anat': All(str, Match(resolution_regex)), @@ -526,9 +530,12 @@ def sanitize(filename): }, }, 'boundary_based_registration': { - 'run': forkable, + 'run': 
All(Coerce(ListFromItem), + [Any(bool1_1, All(Lower, Equal('fallback')))], + Length(max=3)), 'bbr_schedule': str, - 'bbr_wm_map': In({'probability_map', 'partial_volume_map'}), + 'bbr_wm_map': In({'probability_map', + 'partial_volume_map'}), 'bbr_wm_mask_args': str, 'reference': In({'whole-head', 'brain'}) }, diff --git a/CPAC/qc/__init__.py b/CPAC/qc/__init__.py index 75ee654fec..810d06aedc 100644 --- a/CPAC/qc/__init__.py +++ b/CPAC/qc/__init__.py @@ -1,2 +1,22 @@ -from .utils import * -from .qc import * +# Copyright (C) 2013-2022 C-PAC Developers + +# This file is part of C-PAC. + +# C-PAC is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. + +# C-PAC is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +# You should have received a copy of the GNU Lesser General Public +# License along with C-PAC. If not, see . +"""Quality control utilities for C-PAC""" +from CPAC.qc.globals import registration_guardrail_thresholds, \ + update_thresholds +from CPAC.qc.qcmetrics import qc_masks +__all__ = ['qc_masks', 'registration_guardrail_thresholds', + 'update_thresholds'] diff --git a/CPAC/qc/globals.py b/CPAC/qc/globals.py new file mode 100644 index 0000000000..e4a05d8d9d --- /dev/null +++ b/CPAC/qc/globals.py @@ -0,0 +1,42 @@ +# Copyright (C) 2022 C-PAC Developers + +# This file is part of C-PAC. + +# C-PAC is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. 
+ +# C-PAC is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +# You should have received a copy of the GNU Lesser General Public +# License along with C-PAC. If not, see . +"""Global QC values""" +_REGISTRATION_GUARDRAIL_THRESHOLDS = {'thresholds': {}} + + +def registration_guardrail_thresholds() -> dict: + """Get registration guardrail thresholds + + Returns + ------- + dict + """ + return _REGISTRATION_GUARDRAIL_THRESHOLDS['thresholds'] + + +def update_thresholds(thresholds) -> None: + """Set a registration guardrail threshold + + Parameters + ---------- + thresholds : dict of {str: float or int} + + Returns + ------- + None + """ + _REGISTRATION_GUARDRAIL_THRESHOLDS['thresholds'].update(thresholds) diff --git a/CPAC/qc/qcmetrics.py b/CPAC/qc/qcmetrics.py index 6db977c495..b45430020c 100644 --- a/CPAC/qc/qcmetrics.py +++ b/CPAC/qc/qcmetrics.py @@ -1,24 +1,88 @@ +# Modifications: Copyright (C) 2022 C-PAC Developers + +# This file is part of C-PAC. + +# C-PAC is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. + +# C-PAC is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +# You should have received a copy of the GNU Lesser General Public +# License along with C-PAC. If not, see . 
+ +# Original code: BSD 3-Clause License + +# Copyright (c) 2020, Lifespan Informatics and Neuroimaging Center + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. 
"""QC metrics from XCP-D v0.0.9 Ref: https://github.com/PennLINC/xcp_d/tree/0.0.9 """ +# LGPL-3.0-or-later: Module docstring and lint exclusions # pylint: disable=invalid-name, redefined-outer-name +# BSD-3-Clause: imports and unspecified sections import nibabel as nb import numpy as np -def regisQ(bold2t1w_mask, t1w_mask, bold2template_mask, template_mask): - reg_qc = {'coregDice': [dc(bold2t1w_mask, t1w_mask)], - 'coregJaccard': [jc(bold2t1w_mask, t1w_mask)], - 'coregCrossCorr': [crosscorr(bold2t1w_mask, t1w_mask)], - 'coregCoverage': [coverage(bold2t1w_mask, t1w_mask)], - 'normDice': [dc(bold2template_mask, template_mask)], - 'normJaccard': [jc(bold2template_mask, template_mask)], - 'normCrossCorr': [crosscorr(bold2template_mask, template_mask)], - 'normCoverage': [coverage(bold2template_mask, template_mask)]} - return reg_qc +# BSD-3-Clause +def coverage(input1, input2): + """Estimate the coverage between two masks.""" + input1 = nb.load(input1).get_fdata() + input2 = nb.load(input2).get_fdata() + input1 = np.atleast_1d(input1.astype(np.bool)) + input2 = np.atleast_1d(input2.astype(np.bool)) + intsec = np.count_nonzero(input1 & input2) + if np.sum(input1) > np.sum(input2): + smallv = np.sum(input2) + else: + smallv = np.sum(input1) + cov = float(intsec)/float(smallv) + return cov + + +# BSD-3-Clause +def crosscorr(input1, input2): + r"""cross correlation: compute cross correction bewteen input masks""" + input1 = nb.load(input1).get_fdata() + input2 = nb.load(input2).get_fdata() + input1 = np.atleast_1d(input1.astype(np.bool)).flatten() + input2 = np.atleast_1d(input2.astype(np.bool)).flatten() + cc = np.corrcoef(input1, input2)[0][1] + return cc +# BSD-3-Clause def dc(input1, input2): r""" Dice coefficient @@ -71,6 +135,7 @@ def dc(input1, input2): return dc +# BSD-3-Clause def jc(input1, input2): r""" Jaccard coefficient @@ -106,26 +171,62 @@ def jc(input1, input2): return jc -def crosscorr(input1, input2): - r"""cross correlation: compute cross correction 
bewteen input masks""" - input1 = nb.load(input1).get_fdata() - input2 = nb.load(input2).get_fdata() - input1 = np.atleast_1d(input1.astype(np.bool)).flatten() - input2 = np.atleast_1d(input2.astype(np.bool)).flatten() - cc = np.corrcoef(input1, input2)[0][1] - return cc +# LGPL-3.0-or-later +def _prefix_regqc_keys(qc_dict: dict, prefix: str) -> str: + """Prepend string to each key in a qc dict + Parameters + ---------- + qc_dict : dict + output of ``qc_masks`` -def coverage(input1, input2): - """Estimate the coverage between two masks.""" - input1 = nb.load(input1).get_fdata() - input2 = nb.load(input2).get_fdata() - input1 = np.atleast_1d(input1.astype(np.bool)) - input2 = np.atleast_1d(input2.astype(np.bool)) - intsec = np.count_nonzero(input1 & input2) - if np.sum(input1) > np.sum(input2): - smallv = np.sum(input2) - else: - smallv = np.sum(input1) - cov = float(intsec)/float(smallv) - return cov + prefix : str + string to prepend + + Returns + ------- + dict + """ + return {f'{prefix}{_key}': _value for _key, _value in qc_dict.items()} + + +# BSD-3-Clause: logic +# LGPL-3.0-or-later: docstring and refactored function +def qc_masks(registered_mask: str, native_mask: str) -> dict: + """Return QC measures for coregistration + + Parameters + ---------- + registered_mask : str + path to registered mask + + native_mask : str + path to native-space mask + + Returns + ------- + dict + """ + return {'Dice': [dc(registered_mask, native_mask)], + 'Jaccard': [jc(registered_mask, native_mask)], + 'CrossCorr': [crosscorr(registered_mask, native_mask)], + 'Coverage': [coverage(registered_mask, native_mask)]} + + +# BSD-3-Clause: name and signature +# LGPL-3.0-or-later: docstring and refactored function +def regisQ(bold2t1w_mask: str, t1w_mask: str, bold2template_mask: str, + template_mask: str) -> dict: + """Collect coregistration QC measures + + Parameters + ---------- + bold2t1w_mask, t1w_mask, bold2template_mask, template_mask : str + + Returns + ------- + dict + """ + 
return {**_prefix_regqc_keys(qc_masks(bold2t1w_mask, t1w_mask), 'coreg'), + **_prefix_regqc_keys(qc_masks(bold2template_mask, template_mask), + 'norm')} diff --git a/CPAC/registration/exceptions.py b/CPAC/registration/exceptions.py new file mode 100644 index 0000000000..d962ddfa30 --- /dev/null +++ b/CPAC/registration/exceptions.py @@ -0,0 +1,41 @@ +# Copyright (C) 2022 C-PAC Developers + +# This file is part of C-PAC. + +# C-PAC is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. + +# C-PAC is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +# You should have received a copy of the GNU Lesser General Public +# License along with C-PAC. If not, see . +"""Custom registration exceptions""" + + +class BadRegistrationError(ValueError): + """Exception for when a QC measure for a registration falls below a + specified threshold""" + def __init__(self, *args, metric=None, value=None, threshold=None, + **kwargs): + """ + Parameters + ---------- + metric : str + QC metric + + value : float + calculated QC value + + threshold : float + specified threshold + """ + msg = "Registration failed quality control" + if all(arg is not None for arg in (metric, value, threshold)): + msg += f" ({metric}: {value} < {threshold})" + msg += "." + super().__init__(msg, *args, **kwargs) diff --git a/CPAC/registration/guardrails.py b/CPAC/registration/guardrails.py new file mode 100644 index 0000000000..1329cad97f --- /dev/null +++ b/CPAC/registration/guardrails.py @@ -0,0 +1,208 @@ +# Copyright (C) 2022 C-PAC Developers + +# This file is part of C-PAC. 
+ +# C-PAC is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. + +# C-PAC is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +# You should have received a copy of the GNU Lesser General Public +# License along with C-PAC. If not, see . +"""Guardrails to protect against bad registrations""" +import logging +from typing import Tuple +from nipype.interfaces.utility import Function, Merge, Select +# pylint: disable=unused-import +from CPAC.pipeline.nipype_pipeline_engine import Node, Workflow +from CPAC.pipeline.random_state.seed import increment_seed +from CPAC.qc import qc_masks, registration_guardrail_thresholds +from CPAC.registration.exceptions import BadRegistrationError +from CPAC.registration.utils import hardcoded_reg +from CPAC.utils.docs import retry_docstring + + +# noqa: F401 +def guardrail_selection(wf: 'Workflow', node1: 'Node', node2: 'Node', + output_key: str = 'registered', + guardrail_node: 'Node' = None) -> Node: + """Generate requisite Nodes for choosing a path through the graph + with retries. + + Takes two nodes to choose an output from. These nodes are assumed + to be guardrail nodes if `output_key` and `guardrail_node` are not + specified. + + A ``nipype.interfaces.utility.Merge`` is generated, connecting + ``output_key`` from ``node1`` and ``node2`` in that order. + + A ``nipype.interfaces.utility.Select`` node is generated taking the + output from the generated ``Merge`` and using the ``failed_qc`` + output of ``guardrail_node`` (``node1`` if ``guardrail_node`` is + unspecified). + + All relevant connections are made in the given Workflow. 
+ + The ``Select`` node is returned; its output is keyed ``out`` and + contains the value of the given ``output_key`` (``registered`` if + unspecified). + + Parameters + ---------- + wf : Workflow + + node1, node2 : Node + first try, retry + + output_key : str + field to choose + + guardrail_node : Node + guardrail to collect 'failed_qc' from if not node1 + + Returns + ------- + select : Node + """ + # pylint: disable=redefined-outer-name,reimported,unused-import + from CPAC.pipeline.nipype_pipeline_engine import Node, Workflow + if guardrail_node is None: + guardrail_node = node1 + name = node1.name + if output_key != 'registered': + name = f'{name}_{output_key}' + choices = Node(Merge(2), run_without_submitting=True, + name=f'{name}_choices') + select = Node(Select(), run_without_submitting=True, + name=f'choose_{name}') + wf.connect([(node1, choices, [(output_key, 'in1')]), + (node2, choices, [(output_key, 'in2')]), + (choices, select, [('out', 'inlist')]), + (guardrail_node, select, [('failed_qc', 'index')])]) + return select + + +def registration_guardrail(registered: str, reference: str, + retry: bool = False, retry_num: int = 0 + ) -> Tuple[str, int]: + """Check QC metrics post-registration and throw an exception if + metrics are below given thresholds. + + If inputs point to images that are not masks, images will be + binarized before being compared. + + .. seealso:: + + :py:mod:`CPAC.qc.qcmetrics` + Documentation of the :py:mod:`CPAC.qc.qcmetrics` module. + + Parameters + ---------- + registered, reference : str + path to mask + + retry : bool, optional + can retry? + + retry_num : int, optional + how many previous tries? + + Returns + ------- + registered_mask : str + path to mask + + failed_qc : int + metrics met specified thresholds?, used as index for selecting + outputs + .. 
seealso:: + + :py:mod:`guardrail_selection` + """ + logger = logging.getLogger('nipype.workflow') + qc_metrics = qc_masks(registered, reference) + failed_qc = 0 + for metric, threshold in registration_guardrail_thresholds().items(): + if threshold is not None: + value = qc_metrics.get(metric) + if isinstance(value, list): + value = value[0] + if value < threshold: + failed_qc = 1 + with open(f'{registered}.failed_qc', 'w', + encoding='utf-8') as _f: + _f.write(f'{metric}: {value} < {threshold}') + if retry: + registered = f'{registered}-failed' + else: + bad_registration = BadRegistrationError( + metric=metric, value=value, threshold=threshold) + logger.error(str(bad_registration)) + if retry_num: + # if we've already retried, raise the error + raise bad_registration + return registered, failed_qc + + +def registration_guardrail_node(name=None, retry_num=0): + """Convenience method to get a new registration_guardrail Node + + Parameters + ---------- + name : str, optional + + retry_num : int, optional + how many previous tries? 
+ + Returns + ------- + Node + """ + if name is None: + name = 'registration_guardrail' + node = Node(Function(input_names=['registered', 'reference', 'retry_num'], + output_names=['registered', 'failed_qc'], + imports=['import logging', + 'from typing import Tuple', + 'from CPAC.qc import qc_masks, ' + 'registration_guardrail_thresholds', + 'from CPAC.registration.guardrails ' + 'import BadRegistrationError'], + function=registration_guardrail), name=name) + if retry_num: + node.inputs.retry_num = retry_num + return node + + +def retry_clone(node: 'Node') -> 'Node': + """Function to clone a node, name the clone, and increment its + random seed + + Parameters + ---------- + node : Node + + Returns + ------- + Node + """ + return increment_seed(node.clone(f'retry_{node.name}')) + + +# pylint: disable=missing-function-docstring,too-many-arguments +@retry_docstring(hardcoded_reg) +def retry_hardcoded_reg(moving_brain, reference_brain, moving_skull, + reference_skull, ants_para, moving_mask=None, + reference_mask=None, fixed_image_mask=None, + interp=None, reg_with_skull=0, previous_failure=False): + if not previous_failure: + return [], None + return hardcoded_reg(moving_brain, reference_brain, moving_skull, + reference_skull, ants_para, moving_mask, + reference_mask, fixed_image_mask, interp, + reg_with_skull) diff --git a/CPAC/registration/registration.py b/CPAC/registration/registration.py index b2260a9641..3603c70b2e 100644 --- a/CPAC/registration/registration.py +++ b/CPAC/registration/registration.py @@ -23,6 +23,8 @@ from CPAC.anat_preproc.lesion_preproc import create_lesion_preproc from CPAC.func_preproc.utils import chunk_ts, split_ts_chunks +from CPAC.registration.guardrails import guardrail_selection, \ + registration_guardrail_node from CPAC.registration.utils import seperate_warps_list, \ check_transforms, \ generate_inverse_transform_flags, \ @@ -739,7 +741,7 @@ def create_register_func_to_anat(config, phase_diff_distcor=False, return 
register_func_to_anat -def create_register_func_to_anat_use_T2(config, name='register_func_to_anat_use_T2'): +def create_register_func_to_anat_use_T2(name='register_func_to_anat_use_T2'): # for monkey data # ref: https://github.com/DCAN-Labs/dcan-macaque-pipeline/blob/master/fMRIVolume/GenericfMRIVolumeProcessingPipeline.sh#L287-L295 # https://github.com/HechengJin0/dcan-macaque-pipeline/blob/master/fMRIVolume/GenericfMRIVolumeProcessingPipeline.sh#L524-L535 @@ -776,8 +778,6 @@ def create_register_func_to_anat_use_T2(config, name='register_func_to_anat_use_ outputspec.anat_func_nobbreg : string (nifti file) Functional scan registered to anatomical space """ - - register_func_to_anat_use_T2 = pe.Workflow(name=name) inputspec = pe.Node(util.IdentityInterface(fields=['func', @@ -877,13 +877,12 @@ def create_register_func_to_anat_use_T2(config, name='register_func_to_anat_use_ def create_bbregister_func_to_anat(phase_diff_distcor=False, - name='bbregister_func_to_anat'): - + name='bbregister_func_to_anat', + retry=False): """ Registers a functional scan in native space to structural. This is meant to be used after create_nonlinear_register() has been run and relies on some of its outputs. - Parameters ---------- fieldmap_distortion : bool, optional @@ -891,6 +890,8 @@ def create_bbregister_func_to_anat(phase_diff_distcor=False, take in the appropriate field map-related inputs. name : string, optional Name of the workflow. + retry : bool + Try twice? 
Returns ------- @@ -919,7 +920,6 @@ def create_bbregister_func_to_anat(phase_diff_distcor=False, outputspec.anat_func : string (nifti file) Functional data in anatomical space """ - register_bbregister_func_to_anat = pe.Workflow(name=name) inputspec = pe.Node(util.IdentityInterface(fields=['func', @@ -948,7 +948,6 @@ def create_bbregister_func_to_anat(phase_diff_distcor=False, register_bbregister_func_to_anat.connect( inputspec, 'bbr_wm_mask_args', wm_bb_mask, 'op_string') - register_bbregister_func_to_anat.connect(inputspec, 'anat_wm_segmentation', wm_bb_mask, 'in_file') @@ -959,49 +958,38 @@ def bbreg_args(bbreg_target): bbreg_func_to_anat = pe.Node(interface=fsl.FLIRT(), name='bbreg_func_to_anat') bbreg_func_to_anat.inputs.dof = 6 - - register_bbregister_func_to_anat.connect( - inputspec, 'bbr_schedule', - bbreg_func_to_anat, 'schedule') - - register_bbregister_func_to_anat.connect( - wm_bb_mask, ('out_file', bbreg_args), - bbreg_func_to_anat, 'args') - - register_bbregister_func_to_anat.connect( - inputspec, 'func', - bbreg_func_to_anat, 'in_file') - - register_bbregister_func_to_anat.connect( - inputspec, 'anat', - bbreg_func_to_anat, 'reference') - - register_bbregister_func_to_anat.connect( - inputspec, 'linear_reg_matrix', - bbreg_func_to_anat, 'in_matrix_file') - + nodes, guardrails = register_bbregister_func_to_anat.nodes_and_guardrails( + bbreg_func_to_anat, registered='out_file', add_clones=bool(retry)) + register_bbregister_func_to_anat.connect_retries(nodes, [ + (inputspec, 'bbr_schedule', 'schedule'), + (wm_bb_mask, ('out_file', bbreg_args), 'args'), + (inputspec, 'func', 'in_file'), + (inputspec, 'anat', 'reference'), + (inputspec, 'linear_reg_matrix', 'in_matrix_file')]) if phase_diff_distcor: + register_bbregister_func_to_anat.connect_retries(nodes, [ + (inputNode_pedir, ('pedir', convert_pedir), 'pedir'), + (inputspec, 'fieldmap', 'fieldmap'), + (inputspec, 'fieldmapmask', 'fieldmapmask'), + (inputNode_echospacing, 'echospacing', 'echospacing')]) 
+ register_bbregister_func_to_anat.connect_retries(guardrails, [ + (inputspec, 'anat', 'reference')]) + if retry: + # pylint: disable=no-value-for-parameter + outfile = guardrail_selection(register_bbregister_func_to_anat, + *guardrails) + matrix = guardrail_selection(register_bbregister_func_to_anat, *nodes, + 'out_matrix_file', guardrails[0]) register_bbregister_func_to_anat.connect( - inputNode_pedir, ('pedir', convert_pedir), - bbreg_func_to_anat, 'pedir') - register_bbregister_func_to_anat.connect( - inputspec, 'fieldmap', - bbreg_func_to_anat, 'fieldmap') - register_bbregister_func_to_anat.connect( - inputspec, 'fieldmapmask', - bbreg_func_to_anat, 'fieldmapmask') + matrix, 'out', outputspec, 'func_to_anat_linear_xfm') + register_bbregister_func_to_anat.connect(outfile, 'out', + outputspec, 'anat_func') + else: register_bbregister_func_to_anat.connect( - inputNode_echospacing, 'echospacing', - bbreg_func_to_anat, 'echospacing') - - register_bbregister_func_to_anat.connect( - bbreg_func_to_anat, 'out_matrix_file', - outputspec, 'func_to_anat_linear_xfm') - - register_bbregister_func_to_anat.connect( - bbreg_func_to_anat, 'out_file', - outputspec, 'anat_func') - + bbreg_func_to_anat, 'out_matrix_file', + outputspec, 'func_to_anat_linear_xfm') + register_bbregister_func_to_anat.connect(guardrails[0], 'registered', + outputspec, 'anat_func') return register_bbregister_func_to_anat @@ -2754,8 +2742,8 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): "config": ["registration_workflows", "functional_registration", "coregistration"], "switch": ["run"], - "option_key": "None", - "option_val": "None", + "option_key": ["boundary_based_registration", "run"], + "option_val": [True, False, "fallback"], "inputs": [("sbref", "desc-motion_bold", "space-bold_label-WM_mask", @@ -2766,7 +2754,6 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): ("desc-preproc_T1w", "desc-restore-brain_T1w", "desc-preproc_T2w", - "desc-preproc_T2w", "T2w", 
["label-WM_probseg", "label-WM_mask"], ["label-WM_pveseg", "label-WM_mask"], @@ -2775,22 +2762,19 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): "from-bold_to-T1w_mode-image_desc-linear_xfm", "from-bold_to-T1w_mode-image_desc-linear_warp"]} ''' - - diff_complete = False - if strat_pool.check_rpool("despiked-fieldmap") and \ - strat_pool.check_rpool("fieldmap-mask"): - diff_complete = True - + diff_complete = (strat_pool.check_rpool("despiked-fieldmap") and + strat_pool.check_rpool("fieldmap-mask")) + bbreg_status = "On" if opt is True else "Off" if isinstance( + opt, bool) else opt.title() + subwfname = f'func_to_anat_FLIRT_bbreg{bbreg_status}_{pipe_num}' if strat_pool.check_rpool('T2w') and cfg.anatomical_preproc['run_t2']: # monkey data - func_to_anat = create_register_func_to_anat_use_T2(cfg, - f'func_to_anat_FLIRT_' - f'{pipe_num}') + func_to_anat = create_register_func_to_anat_use_T2(subwfname) # https://github.com/DCAN-Labs/dcan-macaque-pipeline/blob/master/fMRIVolume/GenericfMRIVolumeProcessingPipeline.sh#L177 # fslmaths "$fMRIFolder"/"$NameOffMRI"_mc -Tmean "$fMRIFolder"/"$ScoutName"_gdc func_mc_mean = pe.Node(interface=afni_utils.TStat(), - name=f'func_motion_corrected_mean_{pipe_num}') + name=f'func_motion_corrected_mean_{pipe_num}') func_mc_mean.inputs.options = '-mean' func_mc_mean.inputs.outputtype = 'NIFTI_GZ' @@ -2813,24 +2797,23 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): # if field map-based distortion correction is on, but BBR is off, # send in the distortion correction files here func_to_anat = create_register_func_to_anat(cfg, diff_complete, - f'func_to_anat_FLIRT_' - f'{pipe_num}') + subwfname) func_to_anat.inputs.inputspec.dof = cfg.registration_workflows[ - 'functional_registration']['coregistration']['dof'] + 'functional_registration']['coregistration']['dof'] func_to_anat.inputs.inputspec.interp = cfg.registration_workflows[ - 'functional_registration']['coregistration']['interpolation'] + 
'functional_registration']['coregistration']['interpolation'] node, out = strat_pool.get_data('sbref') wf.connect(node, out, func_to_anat, 'inputspec.func') if cfg.registration_workflows['functional_registration'][ - 'coregistration']['reference'] == 'brain': + 'coregistration']['reference'] == 'brain': # TODO: use JSON meta-data to confirm node, out = strat_pool.get_data('desc-preproc_T1w') elif cfg.registration_workflows['functional_registration'][ - 'coregistration']['reference'] == 'restore-brain': + 'coregistration']['reference'] == 'restore-brain': node, out = strat_pool.get_data('desc-restore-brain_T1w') wf.connect(node, out, func_to_anat, 'inputspec.anat') @@ -2864,22 +2847,22 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): (func_to_anat, 'outputspec.func_to_anat_linear_xfm_nobbreg') } - if True in cfg.registration_workflows['functional_registration'][ - 'coregistration']["boundary_based_registration"]["run"]: - - func_to_anat_bbreg = create_bbregister_func_to_anat(diff_complete, - f'func_to_anat_' - f'bbreg_' - f'{pipe_num}') + if opt in [True, 'fallback']: + fallback = opt == 'fallback' + func_to_anat_bbreg = create_bbregister_func_to_anat( + diff_complete, f'func_to_anat_bbreg{bbreg_status}_{pipe_num}', + opt is True) func_to_anat_bbreg.inputs.inputspec.bbr_schedule = \ cfg.registration_workflows['functional_registration'][ 'coregistration']['boundary_based_registration'][ 'bbr_schedule'] - func_to_anat_bbreg.inputs.inputspec.bbr_wm_mask_args = \ cfg.registration_workflows['functional_registration'][ 'coregistration']['boundary_based_registration'][ 'bbr_wm_mask_args'] + if fallback: + bbreg_guardrail = registration_guardrail_node( + f'bbreg{bbreg_status}_guardrail_{pipe_num}', 1) node, out = strat_pool.get_data('sbref') wf.connect(node, out, func_to_anat_bbreg, 'inputspec.func') @@ -2889,31 +2872,35 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): 'reference'] == 'whole-head': node, out = 
strat_pool.get_data('desc-head_T1w') wf.connect(node, out, func_to_anat_bbreg, 'inputspec.anat') + if fallback: + wf.connect(node, out, bbreg_guardrail, 'reference') elif cfg.registration_workflows['functional_registration'][ 'coregistration']['boundary_based_registration'][ 'reference'] == 'brain': node, out = strat_pool.get_data('desc-preproc_T1w') wf.connect(node, out, func_to_anat_bbreg, 'inputspec.anat') + if fallback: + wf.connect(node, out, bbreg_guardrail, 'reference') wf.connect(func_to_anat, 'outputspec.func_to_anat_linear_xfm_nobbreg', func_to_anat_bbreg, 'inputspec.linear_reg_matrix') if strat_pool.check_rpool('space-bold_label-WM_mask'): node, out = strat_pool.get_data(["space-bold_label-WM_mask"]) - wf.connect(node, out, - func_to_anat_bbreg, 'inputspec.anat_wm_segmentation') else: - if cfg.registration_workflows['functional_registration'][ - 'coregistration']['boundary_based_registration']['bbr_wm_map'] == 'probability_map': + if cfg['registration_workflows', 'functional_registration', + 'coregistration', 'boundary_based_registration', + 'bbr_wm_map'] == 'probability_map': node, out = strat_pool.get_data(["label-WM_probseg", "label-WM_mask"]) - elif cfg.registration_workflows['functional_registration'][ - 'coregistration']['boundary_based_registration']['bbr_wm_map'] == 'partial_volume_map': + elif cfg['registration_workflows', 'functional_registration', + 'coregistration', 'boundary_based_registration', + 'bbr_wm_map'] == 'partial_volume_map': node, out = strat_pool.get_data(["label-WM_pveseg", "label-WM_mask"]) - wf.connect(node, out, - func_to_anat_bbreg, 'inputspec.anat_wm_segmentation') + wf.connect(node, out, + func_to_anat_bbreg, 'inputspec.anat_wm_segmentation') if diff_complete: node, out = strat_pool.get_data('effectiveEchoSpacing') @@ -2929,15 +2916,45 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): node, out = strat_pool.get_data("fieldmap-mask") wf.connect(node, out, func_to_anat_bbreg, 'inputspec.fieldmapmask') - - 
outputs = { - 'space-T1w_sbref': - (func_to_anat_bbreg, 'outputspec.anat_func'), - 'from-bold_to-T1w_mode-image_desc-linear_xfm': - (func_to_anat_bbreg, 'outputspec.func_to_anat_linear_xfm') - } - - return (wf, outputs) + if fallback: + # Fall back to no-BBReg + mean_bolds = pe.Node(util.Merge(2), run_without_submitting=True, + name=f'bbreg_mean_bold_choices_{pipe_num}') + xfms = pe.Node(util.Merge(2), run_without_submitting=True, + name=f'bbreg_xfm_choices_{pipe_num}') + fallback_mean_bolds = pe.Node(util.Select(), + run_without_submitting=True, + name='bbreg_choose_mean_bold_' + f'{pipe_num}') + fallback_xfms = pe.Node(util.Select(), run_without_submitting=True, + name=f'bbreg_choose_xfm_{pipe_num}') + wf.connect([ + (func_to_anat_bbreg, bbreg_guardrail, [ + ('outputspec.anat_func', 'registered')]), + (bbreg_guardrail, mean_bolds, [('registered', 'in1')]), + (func_to_anat, mean_bolds, [('outputspec.anat_func_nobbreg', + 'in2')]), + (func_to_anat_bbreg, xfms, [ + ('outputspec.func_to_anat_linear_xfm', 'in1')]), + (func_to_anat, xfms, [ + ('outputspec.func_to_anat_linear_xfm_nobbreg', 'in2')]), + (mean_bolds, fallback_mean_bolds, [('out', 'inlist')]), + (xfms, fallback_xfms, [('out', 'inlist')]), + (bbreg_guardrail, fallback_mean_bolds, [ + ('failed_qc', 'index')]), + (bbreg_guardrail, fallback_xfms, [('failed_qc', 'index')])]) + outputs = { + 'space-T1w_sbref': (fallback_mean_bolds, 'out'), + 'from-bold_to-T1w_mode-image_desc-linear_xfm': (fallback_xfms, + 'out')} + else: + outputs = { + 'space-T1w_sbref': (func_to_anat_bbreg, + 'outputspec.anat_func'), + 'from-bold_to-T1w_mode-image_desc-linear_xfm': ( + func_to_anat_bbreg, + 'outputspec.func_to_anat_linear_xfm')} + return wf, outputs def create_func_to_T1template_xfm(wf, cfg, strat_pool, pipe_num, opt=None): diff --git a/CPAC/resources/configs/pipeline_config_blank.yml b/CPAC/resources/configs/pipeline_config_blank.yml index 09ff784f42..f4dd650d64 100644 --- a/CPAC/resources/configs/pipeline_config_blank.yml +++ 
b/CPAC/resources/configs/pipeline_config_blank.yml @@ -522,6 +522,12 @@ segmentation: WM_label: [2, 41] registration_workflows: + guardrails: + thresholds: + Dice: + Jaccard: + CrossCorr: + Coverage: anatomical_registration: run: Off registration: diff --git a/CPAC/resources/configs/pipeline_config_default.yml b/CPAC/resources/configs/pipeline_config_default.yml index be84b46008..9fd0fa0b35 100644 --- a/CPAC/resources/configs/pipeline_config_default.yml +++ b/CPAC/resources/configs/pipeline_config_default.yml @@ -586,6 +586,17 @@ segmentation: registration_workflows: + # Runtime quality checks + guardrails: + # Minimum QC values to allow a run to complete post-registration + # Set any metric empty (like "Dice:") or to None to disable that guardrail + # Default thresholds adopted from XCP-Engine + # (https://github.com/PennLINC/xcpEngine/blob/397ab6cf/designs/cbf_all.dsn#L66) + thresholds: + Dice: 0.8 + Jaccard: 0.9 + CrossCorr: 0.7 + Coverage: 0.8 anatomical_registration: @@ -784,7 +795,8 @@ registration_workflows: boundary_based_registration: # this is a fork point - # run: [On, Off] - this will run both and fork the pipeline + # run: [On, Off, fallback] - this will run both and fork the pipeline + # if 'fallback' is one of the selected options, BBR will run and, if its output fails the quality thresholds in registration_workflows.guardrails.thresholds, the pipeline will fall back to BBR's input image run: [On] # Standard FSL 5.0 Scheduler used for Boundary Based Registration.
diff --git a/CPAC/resources/configs/pipeline_config_rbc-options.yml b/CPAC/resources/configs/pipeline_config_rbc-options.yml index b79a016de3..5b5d89f83d 100644 --- a/CPAC/resources/configs/pipeline_config_rbc-options.yml +++ b/CPAC/resources/configs/pipeline_config_rbc-options.yml @@ -46,6 +46,9 @@ registration_workflows: T1w_brain_template_mask: $FSLDIR/data/standard/MNI152_T1_${resolution_for_anat}_brain_mask.nii.gz functional_registration: + coregistration: + boundary_based_registration: + run: [fallback] func_registration_to_template: output_resolution: diff --git a/CPAC/utils/docs.py b/CPAC/utils/docs.py index b1ee23df0b..181df9aa98 100644 --- a/CPAC/utils/docs.py +++ b/CPAC/utils/docs.py @@ -71,4 +71,41 @@ def grab_docstring_dct(fn): return dct +def retry_docstring(orig): + """Decorator to autodocument retries. + + Examples + -------- + >>> @retry_docstring(grab_docstring_dct) + ... def do_nothing(): + ... '''Does this do anything?''' + ... pass + >>> print(do_nothing.__doc__) + Does this do anything? + Retries the following after a failed QC check: + Function to grab a NodeBlock dictionary from a docstring. + + Parameters + ---------- + fn : function + The NodeBlock function with the docstring to be parsed. + + Returns + ------- + dct : dict + A NodeBlock configuration dictionary. 
+ + """ + def retry(obj): + if obj.__doc__ is None: + obj.__doc__ = '' + origdoc = (f'{orig.__module__}.{orig.__name__}' if + orig.__doc__ is None else orig.__doc__) + obj.__doc__ = '\n'.join([ + obj.__doc__, 'Retries the following after a failed QC check:', + origdoc]) + return obj + return retry + + DOCS_URL_PREFIX = _docs_url_prefix() From 7629b3f53cf0223816273729acad90b087289fd5 Mon Sep 17 00:00:00 2001 From: Jon Clucas Date: Tue, 1 Nov 2022 15:26:05 -0400 Subject: [PATCH 2/6] :necktie: Add guardrail configurability --- CPAC/pipeline/cpac_pipeline.py | 7 +- .../pipeline/nipype_pipeline_engine/engine.py | 271 +++++++++++++++--- CPAC/pipeline/schema.py | 9 +- CPAC/registration/guardrails.py | 155 ++++------ CPAC/registration/registration.py | 49 ++-- .../configs/pipeline_config_blank.yml | 4 +- .../configs/pipeline_config_default.yml | 6 + .../configs/pipeline_config_rbc-options.yml | 7 + CPAC/utils/__init__.py | 2 - 9 files changed, 344 insertions(+), 166 deletions(-) diff --git a/CPAC/pipeline/cpac_pipeline.py b/CPAC/pipeline/cpac_pipeline.py index 14ce239602..d9435ffeef 100644 --- a/CPAC/pipeline/cpac_pipeline.py +++ b/CPAC/pipeline/cpac_pipeline.py @@ -767,8 +767,11 @@ def initialize_nipype_wf(cfg, sub_data_dct, name=""): if name: name = f'_{name}' - workflow_name = f'cpac{name}_{sub_data_dct["subject_id"]}_{sub_data_dct["unique_id"]}' - wf = pe.Workflow(name=workflow_name) + workflow_name = (f'cpac{name}_{sub_data_dct["subject_id"]}_' + f'{sub_data_dct["unique_id"]}') + wf = pe.Workflow(name=workflow_name, + guardrail_config=cfg['registration_workflows', + 'guardrails']) wf.base_dir = cfg.pipeline_setup['working_directory']['path'] wf.config['execution'] = { 'hash_method': 'timestamp', diff --git a/CPAC/pipeline/nipype_pipeline_engine/engine.py b/CPAC/pipeline/nipype_pipeline_engine/engine.py index 8695b1e536..eefdb5ea9a 100644 --- a/CPAC/pipeline/nipype_pipeline_engine/engine.py +++ b/CPAC/pipeline/nipype_pipeline_engine/engine.py @@ -54,13 +54,18 @@ from 
typing import Iterable, Tuple, Union from nibabel import load from nipype import logging -from nipype.interfaces.utility import Function +from nipype.interfaces.base.support import Bunch, InterfaceResult +from nipype.interfaces.utility import Function, Merge, Select from nipype.pipeline import engine as pe from nipype.pipeline.engine.utils import load_resultfile as _load_resultfile from nipype.utils.functions import getsource from numpy import prod from traits.trait_base import Undefined from traits.trait_handlers import TraitListObject +from CPAC.pipeline.random_state.seed import increment_seed, random_seed, \ + random_seed_flags +from CPAC.registration.guardrails import BestOf, registration_guardrail, \ + skip_if_first_try_succeeds # set global default mem_gb DEFAULT_MEM_GB = 2.0 @@ -147,8 +152,6 @@ class Node(pe.Node): ) def __init__(self, *args, mem_gb=DEFAULT_MEM_GB, **kwargs): - # pylint: disable=import-outside-toplevel - from CPAC.pipeline.random_state import random_seed super().__init__(*args, mem_gb=mem_gb, **kwargs) self.logger = logging.getLogger("nipype.workflow") self.seed = random_seed() @@ -345,8 +348,6 @@ def parse_multiplicand(multiplicand): def _apply_random_seed(self): '''Apply flags for the first matched interface''' - # pylint: disable=import-outside-toplevel - from CPAC.pipeline.random_state import random_seed_flags if isinstance(self.interface, Function): for rsf, flags in random_seed_flags()['functions'].items(): if self.interface.inputs.function_str == getsource(rsf): @@ -411,6 +412,11 @@ def override_mem_gb(self, new_mem_gb): def run(self, updatehash=False): self.__doc__ = getattr(super(), '__doc__', '') + if hasattr(self.interface, 'inputs' + ) and hasattr(self.interface.inputs, 'previous_failure'): + if self.interface.inputs.previous_failure is False: + return InterfaceResult(self.interface, Bunch(), + self.inputs, self.outputs, None) if self.seed is not None: self._apply_random_seed() if self.seed_applied: @@ -420,6 +426,55 @@ def 
run(self, updatehash=False): return super().run(updatehash) +class GuardrailedNode: + '''A Node with QC guardrails.''' + def __init__(self, wf, node, reference, registered): + '''A Node with guardrails + + Parameters + ---------- + wf : Workflow + The parent workflow in which this node is guardrailed + + node : Node + Node to guardrail + + reference : str + key for reference image + + registered : str + key for registered image + ''' + self.guardrails = [registration_guardrail_node( + f'{node.name}_guardrail')] + self.node = node + self.reference = reference + self.registered = registered + self.retries = [] + self.wf = wf + self.wf.connect(self.node, registered, + self.guardrails[0], 'registered') + if self.wf.num_tries > 1: + if self.wf.retry_on_first_failure: + self.guardrails.append(registration_guardrail_node( + f'{node.name}_guardrail')) + self.retries.append(retry_clone(self.node)) + else: + for i in range(self.wf.num_tries - 1): + self.retries.append(retry_clone(self.node, i + 2)) + self.retries[0].interface = skip_if_first_try_succeeds( + self.retries[0].interface) + self.wf.connect(self.node, 'failed_qc', + self.retries[0], 'previous_failure') + for i, retry in enumerate(self.retries): + self.wf.connect(retry, registered, + self.guardrails[i + 1], 'registered') + + def guardrail_selection(self, node, output_key): + """Convenience method to :py:method:`Workflow.guardrail_selection`""" + return self.wf.guardrail_selection(node, output_key) + + class MapNode(Node, pe.MapNode): # pylint: disable=empty-docstring __doc__ = _doctest_skiplines( @@ -441,7 +496,8 @@ def __init__(self, *args, **kwargs): class Workflow(pe.Workflow): """Controls the setup and execution of a pipeline of processes.""" - def __init__(self, name, base_dir=None, debug=False): + def __init__(self, name, base_dir=None, debug=False, guardrail_config=None + ): """Create a workflow object. 
Parameters ---------- @@ -451,6 +507,7 @@ def __init__(self, name, base_dir=None, debug=False): path to workflow storage debug : boolean, optional enable verbose debug-level logging + guardrail_config : dict, optional """ import networkx as nx @@ -458,7 +515,8 @@ def __init__(self, name, base_dir=None, debug=False): self._debug = debug self.verbose_logger = getLogger('engine') if debug else None self._graph = nx.DiGraph() - + self._guardrail_config = { + } if guardrail_config is None else guardrail_config self._nodes_cache = set() self._nested_workflows_cache = set() @@ -497,9 +555,37 @@ def _configure_exec_nodes(self, graph): TypeError): self._handle_just_in_time_exception(node) - def connect_retries(self, nodes: Iterable['Node'], - connections: Iterable[Tuple['Node', Union[str, tuple], - str]]) -> None: + def connect(self, *args, **kwargs): + """Connects all the retry nodes and guardrails of guardrailed nodes, + then connects the other nodes as usual + + .. seealso:: pe.Node.connect + """ + if len(args) == 1: + connection_list = args.pop(0) + elif len(args) == 4: + connection_list = [(args.pop(0), args.pop(2), [ + (args.pop(1), args.pop(3))])] + new_connection_list = [] + for srcnode, destnode, connects in connection_list: + if isinstance(srcnode, GuardrailedNode): + for _from, _to in connects: + selected = srcnode.guardrail_selection(srcnode, _from) + self.connect(selected, 'out', destnode, _to) + if isinstance(destnode, GuardrailedNode): + for _from, _to in connects: + guardnodes = [destnode.node, *destnode.retries] + self.connect_many(guardnodes, [(srcnode, _from, _to)]) + if _from == destnode.reference: + self.connect_many(guardnodes, [ + (srcnode, _from, 'reference')]) + else: + new_connection_list.extend([srcnode, destnode, connects]) + super().connect(new_connection_list, *args, **kwargs) + + def connect_many(self, nodes: Iterable['Node'], + connections: Iterable[Tuple['Node', Union[str, tuple], + str]]) -> None: """Method to generalize making the same 
connections to try and retry nodes. @@ -513,14 +599,14 @@ def connect_retries(self, nodes: Iterable['Node'], connections : iterable of 3-tuples of (Node, str or tuple, str) """ - wrong_conn_type_msg = (r'connect_retries `connections` argument ' + wrong_conn_type_msg = (r'connect_many `connections` argument ' 'must be an iterable of (Node, str or ' 'tuple, str) tuples.') if not isinstance(connections, (list, tuple)): raise TypeError(f'{wrong_conn_type_msg}: Given {connections}') for node in nodes: if not isinstance(node, Node): - raise TypeError('connect_retries requires an iterable ' + raise TypeError('connect_many requires an iterable ' r'of nodes for the `nodes` parameter: ' f'Given {node}') for conn in connections: @@ -531,6 +617,78 @@ def connect_retries(self, nodes: Iterable['Node'], raise TypeError(f'{wrong_conn_type_msg}: Given {conn}') self.connect(*conn[:2], node, conn[2]) + @property + def guardrail(self): + """Are guardrails on? + + Returns + ------- + boolean + """ + return any(self._guardrail_config['thresholds'].values()) + + def guardrailed_node(self, node, reference, registered): + """Method to return a GuardrailedNode in the given Workflow. + + .. seealso:: GuardrailedNode + """ + return GuardrailedNode(self, node, reference, registered) + + def guardrail_selection(self, node: 'GuardrailedNode', output_key: str + ) -> Node: + """Generate requisite Nodes for choosing a path through the graph + with retries. + + Takes two nodes to choose an output from. These nodes are assumed + to be guardrail nodes if `output_key` and `guardrail_node` are not + specified. + + A ``nipype.interfaces.utility.Merge`` is generated, connecting + ``output_key`` from ``node1`` and ``node2`` in that order. + + A ``nipype.interfaces.utility.Select`` node is generated taking the + output from the generated ``Merge`` and using the ``failed_qc`` + output of ``guardrail_node`` (``node1`` if ``guardrail_node`` is + unspecified). 
+ + All relevant connections are made in the given Workflow. + + The ``Select`` node is returned; its output is keyed ``out`` and + contains the value of the given ``output_key`` (``registered`` if + unspecified). + + Parameters + ---------- + node : GuardrailedNode + + output_key : key to select from Node + + Returns + ------- + select : Node + """ + name = node.node.name + if output_key != 'registered': + name = f'{name}_{output_key}' + choices = Node(Merge(self.num_tries), run_without_submitting=True, + name=f'{name}_choices') + select = Node(Select(), run_without_submitting=True, + name=f'choose_{name}') + self.connect([(node.node, choices, [(output_key, 'in1')]), + (choices, select, [('out', 'inlist')])]) + if self._guardrail_config['best_of'] > 1: + best_of = Node(BestOf(len(self.num_tries))) + self.connect([(node.guardrail, best_of, [('error', 'error1')]), + (best_of, 'index', [(select, 'index')])]) + for i, retry in enumerate(node.retries): + self.connect([(retry, choices, [(output_key, f'in{i+2}')]), + (retry.guardrail, best_of, [('error', + f'error{i+2}')])]) + elif self.retry_on_first_failure: + self.connect([(node.retries[0], choices, [(output_key, 'in2')]), + (node.guardrail, select, [('failed_qc', 'index')])]) + return select + def _handle_just_in_time_exception(self, node): # pylint: disable=protected-access if hasattr(self, '_local_func_scans'): @@ -540,31 +698,29 @@ def _handle_just_in_time_exception(self, node): # TODO: handle S3 files node._apply_mem_x(UNDEFINED_SIZE) # noqa: W0212 - def nodes_and_guardrails(self, *nodes, registered, add_clones=True): - """Returns a two tuples of Nodes: (try, retry) and their - respective guardrails - - Parameters - ---------- - nodes : any number of Nodes + @property + def num_tries(self): + """How many maximum tries? 
Returns ------- - nodes : tuple of Nodes + int + """ + if self.guardrail is False: + return 1 + return 2 if (self.retry_on_first_failure + ) else self._guardrail_config['best_of'] + + @property + def retry_on_first_failure(self): + """Retry iff first attempt fails? - guardrails : tuple of Nodes + Returns + ------- + bool """ - from CPAC.registration.guardrails import registration_guardrail_node, \ - retry_clone - nodes = list(nodes) - if add_clones is True: - nodes.extend([retry_clone(node) for node in nodes]) - guardrails = [None] * len(nodes) - for i, node in enumerate(nodes): - guardrails[i] = registration_guardrail_node( - f'guardrail_{node.name}', i) - self.connect(node, registered, guardrails[i], 'registered') - return tuple(nodes), tuple(guardrails) + return (self._guardrail_config['best_of'] == 1 and + self._guardrail_config['retry_on_first_failure'] is True) def get_data_size(filepath, mode='xyzt'): @@ -601,3 +757,54 @@ def get_data_size(filepath, mode='xyzt'): if mode == 'xyz': return prod(data_shape[0:3]).item() return prod(data_shape).item() + + +def registration_guardrail_node(name=None, retry_num=0): + """Convenience method to get a new registration_guardrail Node + + Parameters + ---------- + name : str, optional + + retry_num : int, optional + how many previous tries? 
+ + Returns + ------- + Node + """ + if name is None: + name = 'registration_guardrail' + node = Node(Function(input_names=['registered', 'reference', 'retry_num'], + output_names=['registered', 'failed_qc', 'error'], + imports=['import logging', + 'from typing import Tuple', + 'from CPAC.qc import qc_masks, ' + 'registration_guardrail_thresholds', + 'from CPAC.registration.guardrails ' + 'import BadRegistrationError'], + function=registration_guardrail), name=name) + if retry_num: + node.inputs.retry_num = retry_num + return node + + +def retry_clone(node: 'Node', index: int = 1) -> 'Node': + """Function to clone a node, name the clone, and increment its + random seed + + Parameters + ---------- + node : Node + + index : int + if multiple tries regardless of initial success, nth try + (starting with 2) + + Returns + ------- + Node + """ + if index > 1: + return increment_seed(node.clone(f'{node.name}_try{index}')) + return increment_seed(node.clone(f'retry_{node.name}')) diff --git a/CPAC/pipeline/schema.py b/CPAC/pipeline/schema.py index 1be140c7c5..38e79ad896 100644 --- a/CPAC/pipeline/schema.py +++ b/CPAC/pipeline/schema.py @@ -282,7 +282,7 @@ def sanitize(filename): 'num_participants_at_once': int, 'random_seed': Maybe(Any( 'random', - All(int, Range(min=1, max=np.iinfo(np.int32).max)))), + All(int, Range(min=1, max=MAX_SEED)))), 'observed_usage': { 'callback_log': Maybe(str), 'buffer': Number, @@ -468,9 +468,10 @@ def sanitize(filename): }, 'registration_workflows': { 'guardrails': { - 'thresholds': {metric: Maybe(float) for metric in - ('Dice', 'Jaccard', 'CrossCorr', 'Coverage')}, - }}, + 'thresholds': {metric: Maybe(float) for metric in + ('Dice', 'Jaccard', 'CrossCorr', 'Coverage')}, + 'retry_on_first_failure': bool1_1, + 'best_of': All(int, Range(min=1))}, 'anatomical_registration': { 'run': bool1_1, 'resolution_for_anat': All(str, Match(resolution_regex)), diff --git a/CPAC/registration/guardrails.py b/CPAC/registration/guardrails.py index 
1329cad97f..73f4875be3 100644 --- a/CPAC/registration/guardrails.py +++ b/CPAC/registration/guardrails.py @@ -17,74 +17,61 @@ """Guardrails to protect against bad registrations""" import logging from typing import Tuple -from nipype.interfaces.utility import Function, Merge, Select -# pylint: disable=unused-import -from CPAC.pipeline.nipype_pipeline_engine import Node, Workflow -from CPAC.pipeline.random_state.seed import increment_seed +from nipype.interfaces.base import TraitedSpec, traits +from nipype.interfaces.io import add_traits, IOBase +from nipype.interfaces.utility.base import MergeInputSpec from CPAC.qc import qc_masks, registration_guardrail_thresholds from CPAC.registration.exceptions import BadRegistrationError from CPAC.registration.utils import hardcoded_reg from CPAC.utils.docs import retry_docstring -# noqa: F401 -def guardrail_selection(wf: 'Workflow', node1: 'Node', node2: 'Node', - output_key: str = 'registered', - guardrail_node: 'Node' = None) -> Node: - """Generate requisite Nodes for choosing a path through the graph - with retries. +class BestOfOutputSpec(TraitedSpec): + """Outputspec for :py:class:`BestOf`""" + index = traits.Int(desc="0-indexed index of minimum error") - Takes two nodes to choose an output from. These nodes are assumed - to be guardrail nodes if `output_key` and `guardrail_node` are not - specified. - A ``nipype.interfaces.utility.Merge`` is generated, connecting - ``output_key`` from ``node1`` and ``node2`` in that order. +class BestOf(IOBase): + """Returns the index of the smallest 'error'. Inputs are 1-indexed + and output is 0-indexed to mirror Merge and Select. - A ``nipype.interfaces.utility.Select`` node is generated taking the - output from the generated ``Merge`` and using the ``failed_qc`` - output of ``guardrail_node`` (``node1`` if ``guardrail_node`` is - unspecified). + .. seealso:: - All relevant connections are made in the given Workflow. 
+ nipype.interfaces.utility.base.Merge - The ``Select`` node is returned; its output is keyed ``out`` and - contains the value of the given ``output_key`` (``registered`` if - unspecified). + nipype.interfaces.utility.base.Select - Parameters - ---------- - wf : Workflow + Example + ------- + >>> best_of = BestOf(3) + >>> best_of.inputs.error1 = 0.5 + >>> best_of.inputs.error2 = 0.1 + >>> best_of.inputs.error3 = 0.2 + >>> res = best_of.run() + >>> res.outputs.index + 1 + """ + input_spec = MergeInputSpec + output_spec = BestOfOutputSpec - node1, node2 : Node - first try, retry + def __init__(self, numinputs=0, **inputs): + super().__init__(**inputs) + self._numinputs = numinputs + if numinputs >= 1: + input_names = [f"error{(i + 1)}" for i in range(numinputs)] + else: + input_names = [] + add_traits(self.inputs, input_names) - output_key : str - field to choose + def _getval(self, idx): + return getattr(self.inputs, f"error{idx + 1}", 1) - guardrail_node : Node - guardrail to collect 'failed_qc' from if not node1 - - Returns - ------- - select : Node - """ - # pylint: disable=redefined-outer-name,reimported,unused-import - from CPAC.pipeline.nipype_pipeline_engine import Node, Workflow - if guardrail_node is None: - guardrail_node = node1 - name = node1.name - if output_key != 'registered': - name = f'{name}_{output_key}' - choices = Node(Merge(2), run_without_submitting=True, - name=f'{name}_choices') - select = Node(Select(), run_without_submitting=True, - name=f'choose_{name}') - wf.connect([(node1, choices, [(output_key, 'in1')]), - (node2, choices, [(output_key, 'in2')]), - (choices, select, [('out', 'inlist')]), - (guardrail_node, select, [('failed_qc', 'index')])]) - return select + def _list_outputs(self): + outputs = self._outputs().get() + if self._numinputs >= 1: + values = [self._getval(idx) for idx in range(self._numinputs)] + outputs["index"] = values.index(min(values)) + return outputs def registration_guardrail(registered: str, reference: str, @@ 
-119,14 +106,15 @@ def registration_guardrail(registered: str, reference: str, failed_qc : int metrics met specified thresholds?, used as index for selecting - outputs - .. seealso:: + outputs with ``retry_on_first_failure`` - :py:mod:`guardrail_selection` + error : float: + sum of distance from thresholded QC metrics (min=0, max=count(metrics)) """ logger = logging.getLogger('nipype.workflow') qc_metrics = qc_masks(registered, reference) failed_qc = 0 + error = 0 for metric, threshold in registration_guardrail_thresholds().items(): if threshold is not None: value = qc_metrics.get(metric) @@ -146,52 +134,8 @@ def registration_guardrail(registered: str, reference: str, if retry_num: # if we've already retried, raise the error raise bad_registration - return registered, failed_qc - - -def registration_guardrail_node(name=None, retry_num=0): - """Convenience method to get a new registration_guardrail Node - - Parameters - ---------- - name : str, optional - - retry_num : int, optional - how many previous tries? 
- - Returns - ------- - Node - """ - if name is None: - name = 'registration_guardrail' - node = Node(Function(input_names=['registered', 'reference', 'retry_num'], - output_names=['registered', 'failed_qc'], - imports=['import logging', - 'from typing import Tuple', - 'from CPAC.qc import qc_masks, ' - 'registration_guardrail_thresholds', - 'from CPAC.registration.guardrails ' - 'import BadRegistrationError'], - function=registration_guardrail), name=name) - if retry_num: - node.inputs.retry_num = retry_num - return node - - -def retry_clone(node: 'Node') -> 'Node': - """Function to clone a node, name the clone, and increment its - random seed - - Parameters - ---------- - node : Node - - Returns - ------- - Node - """ - return increment_seed(node.clone(f'retry_{node.name}')) + error += (1 - value) + return registered, failed_qc, error # pylint: disable=missing-function-docstring,too-many-arguments @@ -206,3 +150,10 @@ def retry_hardcoded_reg(moving_brain, reference_brain, moving_skull, reference_skull, ants_para, moving_mask, reference_mask, fixed_image_mask, interp, reg_with_skull) + + +def skip_if_first_try_succeeds(interface): + """Set an interface up to skip if a previous attempt succeeded""" + if hasattr(interface, 'input_spec'): + interface.inputs.add_trait('previous_failure', traits.Bool()) + return interface diff --git a/CPAC/registration/registration.py b/CPAC/registration/registration.py index 3603c70b2e..08f4be93ea 100644 --- a/CPAC/registration/registration.py +++ b/CPAC/registration/registration.py @@ -23,8 +23,8 @@ from CPAC.anat_preproc.lesion_preproc import create_lesion_preproc from CPAC.func_preproc.utils import chunk_ts, split_ts_chunks -from CPAC.registration.guardrails import guardrail_selection, \ - registration_guardrail_node +# from CPAC.registration.guardrails import guardrail_selection, \ +# registration_guardrail_node from CPAC.registration.utils import seperate_warps_list, \ check_transforms, \ generate_inverse_transform_flags, \ @@ 
-958,28 +958,31 @@ def bbreg_args(bbreg_target): bbreg_func_to_anat = pe.Node(interface=fsl.FLIRT(), name='bbreg_func_to_anat') bbreg_func_to_anat.inputs.dof = 6 - nodes, guardrails = register_bbregister_func_to_anat.nodes_and_guardrails( - bbreg_func_to_anat, registered='out_file', add_clones=bool(retry)) - register_bbregister_func_to_anat.connect_retries(nodes, [ - (inputspec, 'bbr_schedule', 'schedule'), - (wm_bb_mask, ('out_file', bbreg_args), 'args'), - (inputspec, 'func', 'in_file'), - (inputspec, 'anat', 'reference'), - (inputspec, 'linear_reg_matrix', 'in_matrix_file')]) + bbreg_func_to_anat = register_bbregister_func_to_anat.guardrailed_node( + bbreg_func_to_anat, 'reference', 'out_file') + + register_bbregister_func_to_anat.connect([ + (inputspec, bbreg_func_to_anat, [ + ('bbr_schedule', 'schedule'), + ('func', 'in_file'), + ('anat', 'reference'), + ('linear_reg_matrix', 'in_matrix_file')]), + (wm_bb_mask, bbreg_func_to_anat, ('out_file', bbreg_args), 'args')]) if phase_diff_distcor: - register_bbregister_func_to_anat.connect_retries(nodes, [ - (inputNode_pedir, ('pedir', convert_pedir), 'pedir'), - (inputspec, 'fieldmap', 'fieldmap'), - (inputspec, 'fieldmapmask', 'fieldmapmask'), - (inputNode_echospacing, 'echospacing', 'echospacing')]) - register_bbregister_func_to_anat.connect_retries(guardrails, [ - (inputspec, 'anat', 'reference')]) + register_bbregister_func_to_anat.connect([ + (inputNode_pedir, bbreg_func_to_anat, [ + ('pedir', convert_pedir), 'pedir']), + (inputspec, bbreg_func_to_anat, [ + ('fieldmap', 'fieldmap'), + ('fieldmapmask', 'fieldmapmask')]), + (inputNode_echospacing, bbreg_func_to_anat, [ + ('echospacing', 'echospacing')])]) if retry: # pylint: disable=no-value-for-parameter - outfile = guardrail_selection(register_bbregister_func_to_anat, - *guardrails) - matrix = guardrail_selection(register_bbregister_func_to_anat, *nodes, - 'out_matrix_file', guardrails[0]) + outfile = register_bbregister_func_to_anat.guardrail_selection( + 
bbreg_func_to_anat, 'out_file') + matrix = register_bbregister_func_to_anat.guardrail_selection( + bbreg_func_to_anat, 'out_matrix_file') register_bbregister_func_to_anat.connect( matrix, 'out', outputspec, 'func_to_anat_linear_xfm') register_bbregister_func_to_anat.connect(outfile, 'out', @@ -988,8 +991,8 @@ def bbreg_args(bbreg_target): register_bbregister_func_to_anat.connect( bbreg_func_to_anat, 'out_matrix_file', outputspec, 'func_to_anat_linear_xfm') - register_bbregister_func_to_anat.connect(guardrails[0], 'registered', - outputspec, 'anat_func') + register_bbregister_func_to_anat.connect( + bbreg_func_to_anat, 'out_file', outputspec, 'anat_func') return register_bbregister_func_to_anat diff --git a/CPAC/resources/configs/pipeline_config_blank.yml b/CPAC/resources/configs/pipeline_config_blank.yml index f4dd650d64..b51e1abaae 100644 --- a/CPAC/resources/configs/pipeline_config_blank.yml +++ b/CPAC/resources/configs/pipeline_config_blank.yml @@ -522,12 +522,14 @@ segmentation: WM_label: [2, 41] registration_workflows: - gaurdrails: + guardrails: thresholds: Dice: Jaccard: CrossCorr: Coverage: + retry_on_first_failure: Off + best_of: 1 anatomical_registration: run: Off registration: diff --git a/CPAC/resources/configs/pipeline_config_default.yml b/CPAC/resources/configs/pipeline_config_default.yml index 9fd0fa0b35..9aa13efcb4 100644 --- a/CPAC/resources/configs/pipeline_config_default.yml +++ b/CPAC/resources/configs/pipeline_config_default.yml @@ -597,6 +597,12 @@ registration_workflows: Jaccard: 0.9 CrossCorr: 0.7 Coverage: 0.8 + # If this option is turned on and best_of is set to 1, any registration step that falls below the specified thresholds will retry with an incremented seed + retry_on_first_failure: Off + # If this number is > 1, C-PAC will run that many iterations of each registration calculation and choose the registration with the smallest difference from 1 across all specified thresholds + # If this number ≠ 1, the retry_on_first_failure option 
will have no effect + # Must be at least 1 + best_of: 1 anatomical_registration: diff --git a/CPAC/resources/configs/pipeline_config_rbc-options.yml b/CPAC/resources/configs/pipeline_config_rbc-options.yml index 5b5d89f83d..6f0ea57dfe 100644 --- a/CPAC/resources/configs/pipeline_config_rbc-options.yml +++ b/CPAC/resources/configs/pipeline_config_rbc-options.yml @@ -31,6 +31,13 @@ pipeline_setup: random_seed: 77742777 registration_workflows: + guardrails: + thresholds: + Dice: 0.8 + Jaccard: 0.9 + CrossCorr: 0.7 + Coverage: 0.8 + retry_on_first_failure: On anatomical_registration: # Template to be used during registration. diff --git a/CPAC/utils/__init__.py b/CPAC/utils/__init__.py index c5c791ec03..b13af927fe 100644 --- a/CPAC/utils/__init__.py +++ b/CPAC/utils/__init__.py @@ -6,8 +6,6 @@ from .extract_data import run from .datatypes import ListFromItem from .configuration import check_pname, Configuration, set_subject -from .strategy import Strategy -from .outputs import Outputs from .utils import ( get_zscore, From b68dbfb8e4a15fdf72b4608c1e8f2f2f86a3f822 Mon Sep 17 00:00:00 2001 From: Jon Clucas Date: Tue, 1 Nov 2022 15:27:18 -0400 Subject: [PATCH 3/6] :loud_sound: Add BBR guardrail to CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 21f84bde5b..b7f6ce78cc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added the ability to run AFNI 3dDespike on template-space BOLD data. - Added the ability to ingress TotalReadoutTime from epi field map meta-data from the JSON sidecars. - Added the ability to use TotalReadoutTime of epi field maps in the calculation of FSL topup distortion correction. +- Added ability to set minimum quality measure thresholds to boundary-based registration. 
- Difference method (``-``) for ``CPAC.utils.configuration.Configuration`` instances - Calculate reho and alff when timeseries in template space From 81c24714047133cccedf718feff8ba4d1fe9ca98 Mon Sep 17 00:00:00 2001 From: Jon Clucas Date: Wed, 2 Nov 2022 16:33:47 +0000 Subject: [PATCH 4/6] :necktie: Finish abstracting and installing BBR guardrails --- .pylintrc | 3 +- CPAC/pipeline/cpac_pipeline.py | 8 +- .../nipype_pipeline_engine/__init__.py | 7 +- .../pipeline/nipype_pipeline_engine/engine.py | 254 +++++++++++------- CPAC/pipeline/random_state/seed.py | 57 ++-- CPAC/pipeline/schema.py | 2 +- CPAC/qc/__init__.py | 5 +- CPAC/qc/globals/__init__.py | 17 ++ .../registration_guardrails.py} | 28 +- CPAC/registration/guardrails.py | 24 +- CPAC/registration/registration.py | 48 ++-- CPAC/utils/monitoring/monitoring.py | 5 +- dev/docker_data/run.py | 2 +- 13 files changed, 267 insertions(+), 193 deletions(-) create mode 100644 CPAC/qc/globals/__init__.py rename CPAC/qc/{globals.py => globals/registration_guardrails.py} (59%) diff --git a/.pylintrc b/.pylintrc index 41277323fd..9552292a7b 100644 --- a/.pylintrc +++ b/.pylintrc @@ -426,7 +426,8 @@ function-naming-style=snake_case #function-rgx= # Good variable names which should always be accepted, separated by a comma. 
-good-names=c, +good-names=by, + c, e, f, i, diff --git a/CPAC/pipeline/cpac_pipeline.py b/CPAC/pipeline/cpac_pipeline.py index d9435ffeef..172f9f5075 100644 --- a/CPAC/pipeline/cpac_pipeline.py +++ b/CPAC/pipeline/cpac_pipeline.py @@ -32,6 +32,7 @@ from CPAC.pipeline import nipype_pipeline_engine as pe from CPAC.pipeline.nipype_pipeline_engine.plugins import \ LegacyMultiProcPlugin, MultiProcPlugin +from CPAC.qc.globals import registration_guardrails from nipype import config, logging from indi_aws import aws_utils, fetch_creds @@ -199,7 +200,7 @@ network_centrality ) -from CPAC.pipeline.random_state import set_up_random_state_logger +from CPAC.pipeline.random_state.seed import set_up_random_state_logger from CPAC.pipeline.schema import valid_options from CPAC.utils.trimmer import the_trimmer from CPAC.utils import Configuration, set_subject @@ -769,9 +770,8 @@ def initialize_nipype_wf(cfg, sub_data_dct, name=""): workflow_name = (f'cpac{name}_{sub_data_dct["subject_id"]}_' f'{sub_data_dct["unique_id"]}') - wf = pe.Workflow(name=workflow_name, - guardrail_config=cfg['registration_workflows', - 'guardrails']) + wf = pe.Workflow(name=workflow_name) + registration_guardrails.update(cfg['registration_workflows', 'guardrails']) wf.base_dir = cfg.pipeline_setup['working_directory']['path'] wf.config['execution'] = { 'hash_method': 'timestamp', diff --git a/CPAC/pipeline/nipype_pipeline_engine/__init__.py b/CPAC/pipeline/nipype_pipeline_engine/__init__.py index 66f4111cce..2d89c1c7fe 100644 --- a/CPAC/pipeline/nipype_pipeline_engine/__init__.py +++ b/CPAC/pipeline/nipype_pipeline_engine/__init__.py @@ -24,11 +24,12 @@ from nipype.pipeline.engine import * # noqa: F401,F403 # import our DEFAULT_MEM_GB and override Node, MapNode from .engine import DEFAULT_MEM_GB, get_data_size, Node, MapNode, \ - UNDEFINED_SIZE, Workflow + registration_guardrail_node, UNDEFINED_SIZE, \ + Workflow __all__ = [ interface for interface in dir(pe) if not interface.startswith('_') -] + 
['DEFAULT_MEM_GB', 'get_data_size', 'Node', 'MapNode', 'UNDEFINED_SIZE', - 'Workflow'] +] + ['DEFAULT_MEM_GB', 'get_data_size', 'Node', 'MapNode', + 'registration_guardrail_node', 'UNDEFINED_SIZE', 'Workflow'] del pe diff --git a/CPAC/pipeline/nipype_pipeline_engine/engine.py b/CPAC/pipeline/nipype_pipeline_engine/engine.py index eefdb5ea9a..d3ecaea2f1 100644 --- a/CPAC/pipeline/nipype_pipeline_engine/engine.py +++ b/CPAC/pipeline/nipype_pipeline_engine/engine.py @@ -49,11 +49,13 @@ for Nipype's documentation.''' # noqa: E501 # pylint: disable=line-too-long import os import re +from copy import deepcopy from inspect import Parameter, Signature, signature from logging import getLogger from typing import Iterable, Tuple, Union from nibabel import load from nipype import logging +from nipype.interfaces.base import traits from nipype.interfaces.base.support import Bunch, InterfaceResult from nipype.interfaces.utility import Function, Merge, Select from nipype.pipeline import engine as pe @@ -63,9 +65,9 @@ from traits.trait_base import Undefined from traits.trait_handlers import TraitListObject from CPAC.pipeline.random_state.seed import increment_seed, random_seed, \ - random_seed_flags -from CPAC.registration.guardrails import BestOf, registration_guardrail, \ - skip_if_first_try_succeeds + random_seed_flags, Seed +from CPAC.qc.globals import registration_guardrails +from CPAC.registration.guardrails import BestOf, registration_guardrail # set global default mem_gb DEFAULT_MEM_GB = 2.0 @@ -425,54 +427,18 @@ def run(self, updatehash=False): str(self.seed), self.name) return super().run(updatehash) + @property + def seed(self): + """Random seed for this Node""" + return self._seed -class GuardrailedNode: - '''A Node with QC guardrails.''' - def __init__(self, wf, node, reference, registered): - '''A Node with guardrails - - Parameters - ---------- - wf : Workflow - The parent workflow in which this node is guardrailed - - node : Node - Node to guardrail - - reference : 
str - key for reference image - - registered : str - key for registered image - ''' - self.guardrails = [registration_guardrail_node( - f'{node.name}_guardrail')] - self.node = node - self.reference = reference - self.registered = registered - self.retries = [] - self.wf = wf - self.wf.connect(self.node, registered, - self.guardrails[0], 'registered') - if self.wf.num_tries > 1: - if self.wf.retry_on_first_failure: - self.guardrails.append(registration_guardrail_node( - f'{node.name}_guardrail')) - self.retries.append(retry_clone(self.node)) - else: - for i in range(self.wf.num_tries - 1): - self.retries.append(retry_clone(self.node, i + 2)) - self.retries[0].interface = skip_if_first_try_succeeds( - self.retries[0].interface) - self.wf.connect(self.node, 'failed_qc', - self.retries[0], 'previous_failure') - for i, retry in enumerate(self.retries): - self.wf.connect(retry, registered, - self.guardrails[i + 1], 'registered') - - def guardrail_selection(self, node, output_key): - """Convenience method to :py:method:`Workflow.guardrail_selection`""" - return self.wf.guardrail_selection(node, output_key) + @seed.setter + def seed(self, value): + """Cast seed to Seed type to ensure limits""" + try: + self._seed = Seed(value) + except TypeError: + self._seed = Undefined class MapNode(Node, pe.MapNode): @@ -496,8 +462,100 @@ def __init__(self, *args, **kwargs): class Workflow(pe.Workflow): """Controls the setup and execution of a pipeline of processes.""" - def __init__(self, name, base_dir=None, debug=False, guardrail_config=None - ): + class GuardrailedNode: + """A Node with QC guardrails. 
+ + Set a node to a Workflow.GuardrailedNode like + ``node = wf.guardrailed_node(node, reference, registered, pipe_num)`` + to automatically build guardrails + """ + def __init__(self, wf, node, reference, registered, pipe_num, + retry=True): + '''A Node with guardrails + + Parameters + ---------- + wf : Workflow + The parent workflow in which this node is guardrailed + + node : Node + Node to guardrail + + reference : str + key for reference image + + registered : str + key for registered image + + pipe_num : int + int + + retry : bool + retry if run is so configured + ''' + self.guardrails = [registration_guardrail_node( + f'{node.name}_guardrail_{pipe_num}')] + self.node = node + self.reference = reference + self.registered = registered + self.retries = [] + self.wf = wf + self.wf.connect(self.node, registered, + self.guardrails[0], 'registered') + if retry and self.wf.num_tries > 1: + if registration_guardrails.retry_on_first_failure: + self.guardrails.append(registration_guardrail_node( + f'{node.name}_guardrail')) + self.retries.append(retry_clone(self.node)) + self.retries[0].interface.inputs.add_trait( + 'previous_failure', traits.Bool()) + self.guardrails.append(registration_guardrail_node( + f'{self.retries[0].name}_guardrail', + raise_on_failure=True)) + self.wf.connect(self.guardrails[0], 'failed_qc', + self.retries[0], 'previous_failure') + else: + num_retries = self.wf.num_tries - 1 + for i in range(num_retries): + self.retries.append(retry_clone(self.node, i + 2)) + self.guardrails.append(registration_guardrail_node( + f'{self.retries[-1].name}_guardrail', + raise_on_failure=(i + 1 == num_retries))) + for i, _retry in enumerate(self.retries): + self.wf.connect(_retry, registered, + self.guardrails[i + 1], 'registered') + for guardrail in self.guardrails: + guardrail.inputs.reference = self.reference + + def __getattr__(self, __name): + """Get attributes from the node that is guardrailed if that + attribute isn't an attribute of the guardrail""" + if 
__name in ('_reference', '_registered'): + return object.__getattribute__(self, __name[1:]) + if __name not in self.__dict__: + return getattr(self.node, __name) + return object.__getattribute__(self, __name) + + def __setattr__(self, __name, __value): + """Set an attribute in a node and all retries""" + if __name in ('reference', 'registered'): + super().__setattr__(f'_{__name}', __value) + for guardrail in self.guardrails: + setattr(guardrail.inputs, __name, __value) + if __name not in self.__dict__: + super().__setattr__(__name, __value) + else: + setattr(self.node, __name, __value) + for node in self.retries: + setattr(node, __name, __value) + + def guardrail_selection(self, node, output_key): + """Convenience method to + :py:method:`Workflow.guardrail_selection` + """ + return self.wf.guardrail_selection(node, output_key) + + def __init__(self, name, base_dir=None, debug=False): """Create a workflow object. Parameters ---------- @@ -507,7 +565,6 @@ def __init__(self, name, base_dir=None, debug=False, guardrail_config=None path to workflow storage debug : boolean, optional enable verbose debug-level logging - guardrail_config : dict, optional """ import networkx as nx @@ -515,8 +572,6 @@ def __init__(self, name, base_dir=None, debug=False, guardrail_config=None self._debug = debug self.verbose_logger = getLogger('engine') if debug else None self._graph = nx.DiGraph() - self._guardrail_config = { - } if guardrail_config is None else guardrail_config self._nodes_cache = set() self._nested_workflows_cache = set() @@ -562,17 +617,19 @@ def connect(self, *args, **kwargs): .. 
seealso:: pe.Node.connect """ if len(args) == 1: - connection_list = args.pop(0) + connection_list = args[0] elif len(args) == 4: - connection_list = [(args.pop(0), args.pop(2), [ - (args.pop(1), args.pop(3))])] + connection_list = [(args[0], args[2], [(args[1], args[3])])] + else: + raise TypeError("connect() takes either 4 arguments, or 1 list of" + f" connection tuples ({len(args)} args given)") new_connection_list = [] for srcnode, destnode, connects in connection_list: - if isinstance(srcnode, GuardrailedNode): + if isinstance(srcnode, Workflow.GuardrailedNode): for _from, _to in connects: selected = srcnode.guardrail_selection(srcnode, _from) self.connect(selected, 'out', destnode, _to) - if isinstance(destnode, GuardrailedNode): + elif isinstance(destnode, Workflow.GuardrailedNode): for _from, _to in connects: guardnodes = [destnode.node, *destnode.retries] self.connect_many(guardnodes, [(srcnode, _from, _to)]) @@ -580,8 +637,8 @@ def connect(self, *args, **kwargs): self.connect_many(guardnodes, [ (srcnode, _from, 'reference')]) else: - new_connection_list.extend([srcnode, destnode, connects]) - super().connect(new_connection_list, *args, **kwargs) + new_connection_list.extend([(srcnode, destnode, connects)]) + super().connect(new_connection_list, **kwargs) def connect_many(self, nodes: Iterable['Node'], connections: Iterable[Tuple['Node', Union[str, tuple], @@ -625,17 +682,32 @@ def guardrail(self): ------- boolean """ - return any(self._guardrail_config['thresholds'].values()) + return any(registration_guardrails.thresholds.values()) - def guardrailed_node(self, node, reference, registered): + def guardrailed_node(self, node, reference, registered, pipe_num): """Method to return a GuardrailedNode in the given Workflow. - .. seealso:: GuardrailedNode + .. 
seealso:: Workflow.GuardrailedNode + + Parameters + ---------- + node : Node + Node to guardrail + + reference : str + key for reference image + + registered : str + key for registered image + + pipe_num : int + int """ - return GuardrailedNode(self, node, reference, registered) + return self.GuardrailedNode(self, node, reference, registered, + pipe_num) - def guardrail_selection(self, node: 'GuardrailedNode', output_key: str - ) -> Node: + def guardrail_selection(self, node: 'Workflow.GuardrailedNode', + output_key: str) -> Node: """Generate requisite Nodes for choosing a path through the graph with retries. @@ -659,7 +731,7 @@ def guardrail_selection(self, node: 'GuardrailedNode', output_key: str Parameters ---------- - node : GuardrailedNode + node : Workflow.GuardrailedNode output_key : key to select from Node @@ -676,17 +748,18 @@ def guardrail_selection(self, node: 'GuardrailedNode', output_key: str name=f'choose_{name}') self.connect([(node.node, choices, [(output_key, 'in1')]), (choices, select, [('out', 'inlist')])]) - if self._guardrail_config['best_of'] > 1: + if registration_guardrails.best_of > 1: best_of = Node(BestOf(len(self.num_tries))) - self.connect([(node.guardrail, best_of, [('error', 'error1')]), + self.connect([(node.guardrails[0], best_of, [('error', 'error1')]), (best_of, 'index', [(select, 'index')])]) for i, retry in enumerate(node.retries): self.connect([(retry, choices, [(output_key, f'in{i+2}')]), - (retry.guardrail, best_of, [('error', + (retry.guardrails[i], best_of, [('error', f'error{i+2}')])]) - elif self.retry_on_first_failure: + elif registration_guardrails.retry_on_first_failure: self.connect([(node.retries[0], choices, [(output_key, 'in2')]), - (node.guardrail, select, [('failed_qc', 'index')])]) + (node.guardrails[0], select, [ + ('failed_qc', 'index')])]) return select def _handle_just_in_time_exception(self, node): @@ -708,19 +781,8 @@ def num_tries(self): """ if self.guardrail is False: return 1 - return 2 if 
(self.retry_on_first_failure - ) else self._guardrail_config['best_of'] - - @property - def retry_on_first_failure(self): - """Retry iff first attempt fails? - - Returns - ------- - bool - """ - return (self._guardrail_config['best_of'] == 1 and - self._guardrail_config['retry_on_first_failure'] is True) + return 2 if (registration_guardrails.retry_on_first_failure + ) else registration_guardrails.best_of def get_data_size(filepath, mode='xyzt'): @@ -759,33 +821,31 @@ def get_data_size(filepath, mode='xyzt'): return prod(data_shape).item() -def registration_guardrail_node(name=None, retry_num=0): +def registration_guardrail_node(name=None, raise_on_failure=False): """Convenience method to get a new registration_guardrail Node Parameters ---------- name : str, optional - retry_num : int, optional - how many previous tries? - Returns ------- Node """ if name is None: name = 'registration_guardrail' - node = Node(Function(input_names=['registered', 'reference', 'retry_num'], + node = Node(Function(input_names=['registered', 'reference', + 'raise_on_failure'], output_names=['registered', 'failed_qc', 'error'], imports=['import logging', 'from typing import Tuple', - 'from CPAC.qc import qc_masks, ' - 'registration_guardrail_thresholds', + 'from CPAC.qc import qc_masks', + 'from CPAC.qc.globals import ' + 'registration_guardrails', 'from CPAC.registration.guardrails ' 'import BadRegistrationError'], function=registration_guardrail), name=name) - if retry_num: - node.inputs.retry_num = retry_num + node.inputs.raise_on_failure = raise_on_failure return node @@ -806,5 +866,5 @@ def retry_clone(node: 'Node', index: int = 1) -> 'Node': Node """ if index > 1: - return increment_seed(node.clone(f'{node.name}_try{index}')) + return increment_seed(node.clone(f'{node.name}_try{index}'), index) return increment_seed(node.clone(f'retry_{node.name}')) diff --git a/CPAC/pipeline/random_state/seed.py b/CPAC/pipeline/random_state/seed.py index 21b67ccb90..18f9c36bbb 100644 --- 
a/CPAC/pipeline/random_state/seed.py +++ b/CPAC/pipeline/random_state/seed.py @@ -24,27 +24,51 @@ from nipype.interfaces.freesurfer.preprocess import ApplyVolTransform, ReconAll from nipype.interfaces.fsl.maths import MathsCommand from nipype.interfaces.fsl.utils import ImageMaths - +from CPAC.utils.docs import docstring_parameter from CPAC.utils.interfaces.ants import AI -from CPAC.utils.monitoring.custom_logging import set_up_logger MAX_SEED = np.iinfo(np.int32).max _seed = {'seed': None} -def increment_seed(node): +@docstring_parameter(1, MAX_SEED) +class Seed(int): + """An integer bounded at [{}, {}] for use in setting up random state.""" + def __new__(cls, *args, **kwargs): + _value = super().__new__(cls, *args, **kwargs + ) if args else random_seed() + while _value > MAX_SEED: + _value = _value - MAX_SEED + while _value < 1: + _value = _value + MAX_SEED + return _value + + @docstring_parameter(1, MAX_SEED) + def __add__(self, other): + """Loop back around to {} after {}""" + return Seed(super().__add__(other)) + + @docstring_parameter(MAX_SEED, 0) + def __sub__(self, other): + """Loop back around to {} at {}""" + return Seed(super().__sub__(other)) + + +def increment_seed(node, by=1): """Increment the random seed for a given node Parameters ---------- node : Node + by : int + how much to increment by + Returns ------- node : Node """ - if isinstance(node.seed, int): - node.seed = seed_plus_1() + node.seed += by return node @@ -64,7 +88,7 @@ def random_random_seed(): >>> 0 < random_random_seed() <= MAX_SEED True ''' - return random.randint(1, MAX_SEED) + return Seed(random.randint(1, MAX_SEED)) def random_seed(): @@ -169,24 +193,6 @@ def _reusable_flags(): } -def seed_plus_1(seed=None): - '''Increment seed, looping back to 1 at MAX_SEED - - Parameters - ---------- - seed : int, optional - Uses configured seed if not specified - - Returns - ------- - int - ''' - seed = random_seed() if seed is None else int(seed) - if seed < MAX_SEED: # increment random seed - 
return seed + 1 - return 1 # loop back to 1 - - def set_up_random_state(seed): '''Set global random seed @@ -218,7 +224,7 @@ def set_up_random_state(seed): seed = random_random_seed() else: try: - seed = int(seed) + seed = Seed(seed) assert 0 < seed <= MAX_SEED except (ValueError, TypeError, AssertionError) as error: raise ValueError( @@ -237,6 +243,7 @@ def set_up_random_state_logger(log_dir): ---------- log_dir : str ''' + from CPAC.utils.monitoring.custom_logging import set_up_logger set_up_logger('random', filename='random.tsv', level='info', log_dir=log_dir) getLogger('random').info('seed\tnode') diff --git a/CPAC/pipeline/schema.py b/CPAC/pipeline/schema.py index 38e79ad896..df03a02641 100644 --- a/CPAC/pipeline/schema.py +++ b/CPAC/pipeline/schema.py @@ -532,7 +532,7 @@ def sanitize(filename): }, 'boundary_based_registration': { 'run': All(Coerce(ListFromItem), - [Any(bool1_1, All(Lower, Equal('fallback')))], + [Any(All(Lower, Equal('fallback')), bool1_1)], Length(max=3)), 'bbr_schedule': str, 'bbr_wm_map': In({'probability_map', diff --git a/CPAC/qc/__init__.py b/CPAC/qc/__init__.py index 810d06aedc..45ea7b20cb 100644 --- a/CPAC/qc/__init__.py +++ b/CPAC/qc/__init__.py @@ -15,8 +15,5 @@ # You should have received a copy of the GNU Lesser General Public # License along with C-PAC. If not, see . """Quality control utilities for C-PAC""" -from CPAC.qc.globals import registration_guardrail_thresholds, \ - update_thresholds from CPAC.qc.qcmetrics import qc_masks -__all__ = ['qc_masks', 'registration_guardrail_thresholds', - 'update_thresholds'] +__all__ = ['qc_masks'] diff --git a/CPAC/qc/globals/__init__.py b/CPAC/qc/globals/__init__.py new file mode 100644 index 0000000000..af8700a6a9 --- /dev/null +++ b/CPAC/qc/globals/__init__.py @@ -0,0 +1,17 @@ +# Copyright (C) 2022 C-PAC Developers + +# This file is part of C-PAC. 
+ +# C-PAC is free software: you can redistribute it and/or modify it under +# the terms of the GNU Lesser General Public License as published by the +# Free Software Foundation, either version 3 of the License, or (at your +# option) any later version. + +# C-PAC is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +# You should have received a copy of the GNU Lesser General Public +# License along with C-PAC. If not, see . +"""Global registration guardrail values""" diff --git a/CPAC/qc/globals.py b/CPAC/qc/globals/registration_guardrails.py similarity index 59% rename from CPAC/qc/globals.py rename to CPAC/qc/globals/registration_guardrails.py index e4a05d8d9d..c59cc6bc20 100644 --- a/CPAC/qc/globals.py +++ b/CPAC/qc/globals/registration_guardrails.py @@ -14,29 +14,29 @@ # You should have received a copy of the GNU Lesser General Public # License along with C-PAC. If not, see . 
-"""Global QC values""" -_REGISTRATION_GUARDRAIL_THRESHOLDS = {'thresholds': {}} +"""Global registration guardrail values""" +from traits.trait_base import Undefined +_REGISTRATION_GUARDRAILS = {} -def registration_guardrail_thresholds() -> dict: - """Get registration guardrail thresholds - - Returns - ------- - dict - """ - return _REGISTRATION_GUARDRAIL_THRESHOLDS['thresholds'] +def __getattr__(name): + """Get global values""" + if name == 'retry_on_first_failure': + return (_REGISTRATION_GUARDRAILS.get('best_of') == 1 and + _REGISTRATION_GUARDRAILS.get(name) is True) + return _REGISTRATION_GUARDRAILS.get(name, Undefined) -def update_thresholds(thresholds) -> None: - """Set a registration guardrail threshold +def update(_dict) -> None: + """Set registration guardrails Parameters ---------- - thresholds : dict of {str: float or int} + _dict : dict + keys and values to update Returns ------- None """ - _REGISTRATION_GUARDRAIL_THRESHOLDS['thresholds'].update(thresholds) + _REGISTRATION_GUARDRAILS.update(_dict) diff --git a/CPAC/registration/guardrails.py b/CPAC/registration/guardrails.py index 73f4875be3..b031ef351c 100644 --- a/CPAC/registration/guardrails.py +++ b/CPAC/registration/guardrails.py @@ -20,7 +20,8 @@ from nipype.interfaces.base import TraitedSpec, traits from nipype.interfaces.io import add_traits, IOBase from nipype.interfaces.utility.base import MergeInputSpec -from CPAC.qc import qc_masks, registration_guardrail_thresholds +from CPAC.qc import qc_masks +from CPAC.qc.globals import registration_guardrails from CPAC.registration.exceptions import BadRegistrationError from CPAC.registration.utils import hardcoded_reg from CPAC.utils.docs import retry_docstring @@ -75,7 +76,7 @@ def _list_outputs(self): def registration_guardrail(registered: str, reference: str, - retry: bool = False, retry_num: int = 0 + retry: bool = False, raise_on_failure: bool = False ) -> Tuple[str, int]: """Check QC metrics post-registration and throw an exception if metrics are 
below given thresholds. @@ -93,11 +94,8 @@ def registration_guardrail(registered: str, reference: str, registered, reference : str path to mask - retry : bool, optional - can retry? - - retry_num : int, optional - how many previous tries? + raise_on_failure : bool + raise exception if guardrail catches failed QC? Returns ------- @@ -115,7 +113,7 @@ def registration_guardrail(registered: str, reference: str, qc_metrics = qc_masks(registered, reference) failed_qc = 0 error = 0 - for metric, threshold in registration_guardrail_thresholds().items(): + for metric, threshold in registration_guardrails.thresholds.items(): if threshold is not None: value = qc_metrics.get(metric) if isinstance(value, list): @@ -131,8 +129,7 @@ def registration_guardrail(registered: str, reference: str, bad_registration = BadRegistrationError( metric=metric, value=value, threshold=threshold) logger.error(str(bad_registration)) - if retry_num: - # if we've already retried, raise the error + if raise_on_failure: raise bad_registration error += (1 - value) return registered, failed_qc, error @@ -150,10 +147,3 @@ def retry_hardcoded_reg(moving_brain, reference_brain, moving_skull, reference_skull, ants_para, moving_mask, reference_mask, fixed_image_mask, interp, reg_with_skull) - - -def skip_if_first_try_succeeds(interface): - """Set an interface up to skip if a previous attempt succeeded""" - if hasattr(interface, 'input_spec'): - interface.inputs.add_trait('previous_failure', traits.Bool()) - return interface diff --git a/CPAC/registration/registration.py b/CPAC/registration/registration.py index 08f4be93ea..d17213ca49 100644 --- a/CPAC/registration/registration.py +++ b/CPAC/registration/registration.py @@ -20,11 +20,9 @@ from CPAC.pipeline import nipype_pipeline_engine as pe from nipype.interfaces import afni, ants, c3, fsl, utility as util from nipype.interfaces.afni import utils as afni_utils - from CPAC.anat_preproc.lesion_preproc import create_lesion_preproc from CPAC.func_preproc.utils 
import chunk_ts, split_ts_chunks -# from CPAC.registration.guardrails import guardrail_selection, \ -# registration_guardrail_node +from CPAC.qc.globals import registration_guardrails from CPAC.registration.utils import seperate_warps_list, \ check_transforms, \ generate_inverse_transform_flags, \ @@ -876,22 +874,25 @@ def create_register_func_to_anat_use_T2(name='register_func_to_anat_use_T2'): return register_func_to_anat_use_T2 -def create_bbregister_func_to_anat(phase_diff_distcor=False, - name='bbregister_func_to_anat', - retry=False): +def create_bbregister_func_to_anat(phase_diff_distcor, name, bbreg_status, + pipe_num=0): """ Registers a functional scan in native space to structural. This is meant to be used after create_nonlinear_register() has been run and relies on some of its outputs. Parameters ---------- - fieldmap_distortion : bool, optional + fieldmap_distortion : bool If field map-based distortion correction is being run, FLIRT should take in the appropriate field map-related inputs. - name : string, optional + + name : string Name of the workflow. - retry : bool - Try twice? 
+ + bbreg_status : string + 'On' or 'fallback' + + pipe_num : int Returns ------- @@ -920,8 +921,9 @@ def create_bbregister_func_to_anat(phase_diff_distcor=False, outputspec.anat_func : string (nifti file) Functional data in anatomical space """ - register_bbregister_func_to_anat = pe.Workflow(name=name) - + suffix = f'{bbreg_status.title()}_{pipe_num}' + retry = bbreg_status == 'On' + register_bbregister_func_to_anat = pe.Workflow(name=f'{name}_{suffix}') inputspec = pe.Node(util.IdentityInterface(fields=['func', 'anat', 'linear_reg_matrix', @@ -956,18 +958,18 @@ def bbreg_args(bbreg_target): return '-cost bbr -wmseg ' + bbreg_target bbreg_func_to_anat = pe.Node(interface=fsl.FLIRT(), - name='bbreg_func_to_anat') + name=f'bbreg_func_to_anat_{suffix}') bbreg_func_to_anat.inputs.dof = 6 bbreg_func_to_anat = register_bbregister_func_to_anat.guardrailed_node( - bbreg_func_to_anat, 'reference', 'out_file') - + bbreg_func_to_anat, 'reference', 'out_file', pipe_num) register_bbregister_func_to_anat.connect([ (inputspec, bbreg_func_to_anat, [ ('bbr_schedule', 'schedule'), ('func', 'in_file'), ('anat', 'reference'), - ('linear_reg_matrix', 'in_matrix_file')]), - (wm_bb_mask, bbreg_func_to_anat, ('out_file', bbreg_args), 'args')]) + ('linear_reg_matrix', 'in_matrix_file')])]) + register_bbregister_func_to_anat.connect( + wm_bb_mask, ('out_file', bbreg_args), bbreg_func_to_anat, 'args') if phase_diff_distcor: register_bbregister_func_to_anat.connect([ (inputNode_pedir, bbreg_func_to_anat, [ @@ -977,8 +979,7 @@ def bbreg_args(bbreg_target): ('fieldmapmask', 'fieldmapmask')]), (inputNode_echospacing, bbreg_func_to_anat, [ ('echospacing', 'echospacing')])]) - if retry: - # pylint: disable=no-value-for-parameter + if retry and registration_guardrails.retry_on_first_failure: outfile = register_bbregister_func_to_anat.guardrail_selection( bbreg_func_to_anat, 'out_file') matrix = register_bbregister_func_to_anat.guardrail_selection( @@ -2849,12 +2850,10 @@ def coregistration(wf, 
cfg, strat_pool, pipe_num, opt=None): 'from-bold_to-T1w_mode-image_desc-linear_xfm': (func_to_anat, 'outputspec.func_to_anat_linear_xfm_nobbreg') } - if opt in [True, 'fallback']: fallback = opt == 'fallback' func_to_anat_bbreg = create_bbregister_func_to_anat( - diff_complete, f'func_to_anat_bbreg{bbreg_status}_{pipe_num}', - opt is True) + diff_complete, 'func_to_anat_bbreg', bbreg_status, pipe_num) func_to_anat_bbreg.inputs.inputspec.bbr_schedule = \ cfg.registration_workflows['functional_registration'][ 'coregistration']['boundary_based_registration'][ @@ -2864,8 +2863,9 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): 'coregistration']['boundary_based_registration'][ 'bbr_wm_mask_args'] if fallback: - bbreg_guardrail = registration_guardrail_node( - f'bbreg{bbreg_status}_guardrail_{pipe_num}', 1) + bbreg_guardrail = pe.registration_guardrail_node( + f'bbreg{bbreg_status}_guardrail_{pipe_num}', + raise_on_failure=False) node, out = strat_pool.get_data('sbref') wf.connect(node, out, func_to_anat_bbreg, 'inputspec.func') diff --git a/CPAC/utils/monitoring/monitoring.py b/CPAC/utils/monitoring/monitoring.py index 00d8b34dfd..37956b9ef2 100644 --- a/CPAC/utils/monitoring/monitoring.py +++ b/CPAC/utils/monitoring/monitoring.py @@ -8,14 +8,15 @@ from traits.trait_base import Undefined -from CPAC.pipeline import nipype_pipeline_engine as pe +from CPAC.pipeline.nipype_pipeline_engine import Workflow from .custom_logging import getLogger # Log initial information from all the nodes def recurse_nodes(workflow, prefix=''): + # pylint: disable=protected-access for node in nx.topological_sort(workflow._graph): - if isinstance(node, pe.Workflow): + if isinstance(node, Workflow): for subnode in recurse_nodes(node, prefix + workflow.name + '.'): yield subnode else: diff --git a/dev/docker_data/run.py b/dev/docker_data/run.py index f4814a8c5e..a8e8ecc568 100755 --- a/dev/docker_data/run.py +++ b/dev/docker_data/run.py @@ -27,7 +27,7 @@ import yaml from CPAC 
import license_notice, __version__ from CPAC.pipeline import AVAILABLE_PIPELINE_CONFIGS -from CPAC.pipeline.random_state import set_up_random_state +from CPAC.pipeline.random_state.seed import set_up_random_state from CPAC.utils.bids_utils import create_cpac_data_config, \ load_cpac_data_config, \ load_yaml_config, \ From 64d875e47de8bbecce41e1e00727fc76376d6c58 Mon Sep 17 00:00:00 2001 From: Jon Clucas Date: Wed, 2 Nov 2022 20:18:28 +0000 Subject: [PATCH 5/6] :recycle: Attach guardrails directly to Nodes --- .../pipeline/nipype_pipeline_engine/engine.py | 52 +++++++++---------- CPAC/registration/registration.py | 38 ++++---------- 2 files changed, 36 insertions(+), 54 deletions(-) diff --git a/CPAC/pipeline/nipype_pipeline_engine/engine.py b/CPAC/pipeline/nipype_pipeline_engine/engine.py index d3ecaea2f1..9d88d3a571 100644 --- a/CPAC/pipeline/nipype_pipeline_engine/engine.py +++ b/CPAC/pipeline/nipype_pipeline_engine/engine.py @@ -469,8 +469,7 @@ class GuardrailedNode: ``node = wf.guardrailed_node(node, reference, registered, pipe_num)`` to automatically build guardrails """ - def __init__(self, wf, node, reference, registered, pipe_num, - retry=True): + def __init__(self, wf, node, reference, registered, pipe_num, retry): '''A Node with guardrails Parameters @@ -493,39 +492,34 @@ def __init__(self, wf, node, reference, registered, pipe_num, retry : bool retry if run is so configured ''' - self.guardrails = [registration_guardrail_node( - f'{node.name}_guardrail_{pipe_num}')] self.node = node + self.node.guardrail = registration_guardrail_node( + f'{node.name}_guardrail_{pipe_num}') self.reference = reference self.registered = registered - self.retries = [] + self.tries = [node] self.wf = wf - self.wf.connect(self.node, registered, - self.guardrails[0], 'registered') if retry and self.wf.num_tries > 1: if registration_guardrails.retry_on_first_failure: - self.guardrails.append(registration_guardrail_node( - f'{node.name}_guardrail')) - 
self.retries.append(retry_clone(self.node)) - self.retries[0].interface.inputs.add_trait( + self.tries.append(retry_clone(self.node)) + self.tries[1].interface.inputs.add_trait( 'previous_failure', traits.Bool()) - self.guardrails.append(registration_guardrail_node( - f'{self.retries[0].name}_guardrail', - raise_on_failure=True)) - self.wf.connect(self.guardrails[0], 'failed_qc', - self.retries[0], 'previous_failure') + self.tries[1].guardrail = registration_guardrail_node( + f'{self.tries[1].name}_guardrail', + raise_on_failure=True) + self.wf.connect(self.tries[0].guardrail, 'failed_qc', + self.tries[1], 'previous_failure') else: num_retries = self.wf.num_tries - 1 for i in range(num_retries): - self.retries.append(retry_clone(self.node, i + 2)) - self.guardrails.append(registration_guardrail_node( - f'{self.retries[-1].name}_guardrail', - raise_on_failure=(i + 1 == num_retries))) - for i, _retry in enumerate(self.retries): - self.wf.connect(_retry, registered, - self.guardrails[i + 1], 'registered') - for guardrail in self.guardrails: - guardrail.inputs.reference = self.reference + self.tries.append(retry_clone(self.node, i + 2)) + self.tries[i + 1].guardrail = ( + registration_guardrail_node( + f'{self.tries[i + 1].name}_guardrail', + raise_on_failure=(i + 1 == num_retries))) + for i, _try in enumerate(self.tries): + self.wf.connect(_try, registered, _try.guardrail, 'registered') + _try.guardrail.inputs.reference = self.reference def __getattr__(self, __name): """Get attributes from the node that is guardrailed if that @@ -684,7 +678,8 @@ def guardrail(self): """ return any(registration_guardrails.thresholds.values()) - def guardrailed_node(self, node, reference, registered, pipe_num): + def guardrailed_node(self, node, reference, registered, pipe_num, + retry=True): """Method to return a GuardrailedNode in the given Workflow. .. 
seealso:: Workflow.GuardrailedNode @@ -702,9 +697,12 @@ def guardrailed_node(self, node, reference, registered, pipe_num): pipe_num : int int + + retry : bool + retry if run is so configured """ return self.GuardrailedNode(self, node, reference, registered, - pipe_num) + pipe_num, retry) def guardrail_selection(self, node: 'Workflow.GuardrailedNode', output_key: str) -> Node: diff --git a/CPAC/registration/registration.py b/CPAC/registration/registration.py index d17213ca49..a5140e1a7d 100644 --- a/CPAC/registration/registration.py +++ b/CPAC/registration/registration.py @@ -22,7 +22,6 @@ from nipype.interfaces.afni import utils as afni_utils from CPAC.anat_preproc.lesion_preproc import create_lesion_preproc from CPAC.func_preproc.utils import chunk_ts, split_ts_chunks -from CPAC.qc.globals import registration_guardrails from CPAC.registration.utils import seperate_warps_list, \ check_transforms, \ generate_inverse_transform_flags, \ @@ -922,7 +921,6 @@ def create_bbregister_func_to_anat(phase_diff_distcor, name, bbreg_status, Functional data in anatomical space """ suffix = f'{bbreg_status.title()}_{pipe_num}' - retry = bbreg_status == 'On' register_bbregister_func_to_anat = pe.Workflow(name=f'{name}_{suffix}') inputspec = pe.Node(util.IdentityInterface(fields=['func', 'anat', @@ -961,7 +959,8 @@ def bbreg_args(bbreg_target): name=f'bbreg_func_to_anat_{suffix}') bbreg_func_to_anat.inputs.dof = 6 bbreg_func_to_anat = register_bbregister_func_to_anat.guardrailed_node( - bbreg_func_to_anat, 'reference', 'out_file', pipe_num) + bbreg_func_to_anat, 'reference', 'out_file', pipe_num, + retry=bbreg_status == 'On') register_bbregister_func_to_anat.connect([ (inputspec, bbreg_func_to_anat, [ ('bbr_schedule', 'schedule'), @@ -979,21 +978,14 @@ def bbreg_args(bbreg_target): ('fieldmapmask', 'fieldmapmask')]), (inputNode_echospacing, bbreg_func_to_anat, [ ('echospacing', 'echospacing')])]) - if retry and registration_guardrails.retry_on_first_failure: - outfile = 
register_bbregister_func_to_anat.guardrail_selection( - bbreg_func_to_anat, 'out_file') - matrix = register_bbregister_func_to_anat.guardrail_selection( - bbreg_func_to_anat, 'out_matrix_file') - register_bbregister_func_to_anat.connect( - matrix, 'out', outputspec, 'func_to_anat_linear_xfm') - register_bbregister_func_to_anat.connect(outfile, 'out', - outputspec, 'anat_func') - else: - register_bbregister_func_to_anat.connect( - bbreg_func_to_anat, 'out_matrix_file', - outputspec, 'func_to_anat_linear_xfm') - register_bbregister_func_to_anat.connect( - bbreg_func_to_anat, 'out_file', outputspec, 'anat_func') + outfile = register_bbregister_func_to_anat.guardrail_selection( + bbreg_func_to_anat, 'out_file') + matrix = register_bbregister_func_to_anat.guardrail_selection( + bbreg_func_to_anat, 'out_matrix_file') + register_bbregister_func_to_anat.connect( + matrix, 'out', outputspec, 'func_to_anat_linear_xfm') + register_bbregister_func_to_anat.connect(outfile, 'out', + outputspec, 'anat_func') return register_bbregister_func_to_anat @@ -1827,7 +1819,7 @@ def bold_to_T1template_xfm_connector(wf_name, cfg, reg_tool, symmetric=False, name='change_transform_type') wf.connect(fsl_reg_2_itk, 'itk_transform', - change_transform, 'input_affine_file') + change_transform, 'input_affine_file') # combine ALL xfm's into one - makes it easier downstream write_composite_xfm = pe.Node( @@ -2862,10 +2854,6 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): cfg.registration_workflows['functional_registration'][ 'coregistration']['boundary_based_registration'][ 'bbr_wm_mask_args'] - if fallback: - bbreg_guardrail = pe.registration_guardrail_node( - f'bbreg{bbreg_status}_guardrail_{pipe_num}', - raise_on_failure=False) node, out = strat_pool.get_data('sbref') wf.connect(node, out, func_to_anat_bbreg, 'inputspec.func') @@ -2875,16 +2863,12 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None): 'reference'] == 'whole-head': node, out = 
strat_pool.get_data('desc-head_T1w') wf.connect(node, out, func_to_anat_bbreg, 'inputspec.anat') - if fallback: - wf.connect(node, out, bbreg_guardrail, 'reference') elif cfg.registration_workflows['functional_registration'][ 'coregistration']['boundary_based_registration'][ 'reference'] == 'brain': node, out = strat_pool.get_data('desc-preproc_T1w') wf.connect(node, out, func_to_anat_bbreg, 'inputspec.anat') - if fallback: - wf.connect(node, out, bbreg_guardrail, 'reference') wf.connect(func_to_anat, 'outputspec.func_to_anat_linear_xfm_nobbreg', func_to_anat_bbreg, 'inputspec.linear_reg_matrix') From e5ba0f2eeddc193ecdcc40a7f8f4156c083a98e5 Mon Sep 17 00:00:00 2001 From: "Theodore (Machine User)" Date: Wed, 2 Nov 2022 21:37:54 +0000 Subject: [PATCH 6/6] :bulb: Update comments based on default preconfig --- .../configs/pipeline_config_anat-only.yml | 3 ++- .../pipeline_config_benchmark-ANTS.yml | 3 ++- .../pipeline_config_benchmark-FNIRT.yml | 3 ++- .../configs/pipeline_config_blank.yml | 20 ++++++++++++++++--- .../configs/pipeline_config_ccs-options.yml | 3 ++- .../pipeline_config_fmriprep-options.yml | 3 ++- .../configs/pipeline_config_ndmg.yml | 3 ++- .../configs/pipeline_config_rbc-options.yml | 17 +++++++++++++++- .../configs/pipeline_config_regtest-1.yml | 3 ++- .../configs/pipeline_config_regtest-2.yml | 3 ++- .../configs/pipeline_config_regtest-3.yml | 3 ++- .../configs/pipeline_config_regtest-4.yml | 3 ++- .../configs/pipeline_config_rodent.yml | 3 ++- 13 files changed, 55 insertions(+), 15 deletions(-) diff --git a/CPAC/resources/configs/pipeline_config_anat-only.yml b/CPAC/resources/configs/pipeline_config_anat-only.yml index 50f04bc001..b5bf6890e8 100644 --- a/CPAC/resources/configs/pipeline_config_anat-only.yml +++ b/CPAC/resources/configs/pipeline_config_anat-only.yml @@ -98,7 +98,8 @@ registration_workflows: boundary_based_registration: # this is a fork point - # run: [On, Off] - this will run both and fork the pipeline + # run: [On, Off, fallback] - 
this will run both and fork the pipeline + # if 'fallback' is one of the selected options, BBR will run and, if its output fails quality_thresholds, the pipeline will fallback to BBR's input image run: [Off] func_registration_to_template: diff --git a/CPAC/resources/configs/pipeline_config_benchmark-ANTS.yml b/CPAC/resources/configs/pipeline_config_benchmark-ANTS.yml index 8c9c93f465..e3d378f973 100644 --- a/CPAC/resources/configs/pipeline_config_benchmark-ANTS.yml +++ b/CPAC/resources/configs/pipeline_config_benchmark-ANTS.yml @@ -109,7 +109,8 @@ registration_workflows: boundary_based_registration: # this is a fork point - # run: [On, Off] - this will run both and fork the pipeline + # run: [On, Off, fallback] - this will run both and fork the pipeline + # if 'fallback' is one of the selected options, BBR will run and, if its output fails quality_thresholds, the pipeline will fallback to BBR's input image run: [On] func_registration_to_template: diff --git a/CPAC/resources/configs/pipeline_config_benchmark-FNIRT.yml b/CPAC/resources/configs/pipeline_config_benchmark-FNIRT.yml index 749a5d5ace..fbbbb34d4b 100644 --- a/CPAC/resources/configs/pipeline_config_benchmark-FNIRT.yml +++ b/CPAC/resources/configs/pipeline_config_benchmark-FNIRT.yml @@ -126,7 +126,8 @@ registration_workflows: boundary_based_registration: # this is a fork point - # run: [On, Off] - this will run both and fork the pipeline + # run: [On, Off, fallback] - this will run both and fork the pipeline + # if 'fallback' is one of the selected options, BBR will run and, if its output fails quality_thresholds, the pipeline will fallback to BBR's input image run: [On] func_registration_to_template: diff --git a/CPAC/resources/configs/pipeline_config_blank.yml b/CPAC/resources/configs/pipeline_config_blank.yml index b51e1abaae..b9c19595ee 100644 --- a/CPAC/resources/configs/pipeline_config_blank.yml +++ b/CPAC/resources/configs/pipeline_config_blank.yml @@ -522,14 +522,28 @@ segmentation: WM_label: [2, 41] 
registration_workflows: + + # Runtime quality checks guardrails: + + # Minimum QC values to allow a run to complete post-registration + # Set any metric empty (like "Dice:") or to None to disable that guardrail + # Default thresholds adopted from XCP-Engine + # (https://github.com/PennLINC/xcpEngine/blob/397ab6cf/designs/cbf_all.dsn#L66) thresholds: Dice: Jaccard: CrossCorr: Coverage: + + # If this option is turned on and best_of is set to 1, any registration step that falls below the specified thresholds will retry with an incremented seed retry_on_first_failure: Off + + # If this number is > 1, C-PAC will run that many iterations of each registration calculation and choose the registration with the smallest difference from 1 across all specified thresholds + # If this number ≠ 1, the retry_on_first_failure option will have no effect + # Must be at least 1 best_of: 1 + anatomical_registration: run: Off registration: @@ -695,7 +709,8 @@ registration_workflows: boundary_based_registration: # this is a fork point - # run: [On, Off] - this will run both and fork the pipeline + # run: [On, Off, fallback] - this will run both and fork the pipeline + # if 'fallback' is one of the selected options, BBR will run and, if its output fails quality_thresholds, the pipeline will fallback to BBR's input image run: [Off] # reference for boundary based registration @@ -1199,8 +1214,7 @@ functional_preproc: # this is a fork point # run: [On, Off] - this will run both and fork the pipeline run: [Off] - - space: 'native' + space: native nuisance_corrections: 2-nuisance_regression: diff --git a/CPAC/resources/configs/pipeline_config_ccs-options.yml b/CPAC/resources/configs/pipeline_config_ccs-options.yml index 6b3c43068b..4bcc492898 100644 --- a/CPAC/resources/configs/pipeline_config_ccs-options.yml +++ b/CPAC/resources/configs/pipeline_config_ccs-options.yml @@ -126,7 +126,8 @@ registration_workflows: boundary_based_registration: # this is a fork point - # run: [On, Off] - this will 
run both and fork the pipeline + # run: [On, Off, fallback] - this will run both and fork the pipeline + # if 'fallback' is one of the selected options, BBR will run and, if its output fails quality_thresholds, the pipeline will fallback to BBR's input image run: [On] func_registration_to_template: diff --git a/CPAC/resources/configs/pipeline_config_fmriprep-options.yml b/CPAC/resources/configs/pipeline_config_fmriprep-options.yml index 2f6059cea2..cd4002c0b4 100644 --- a/CPAC/resources/configs/pipeline_config_fmriprep-options.yml +++ b/CPAC/resources/configs/pipeline_config_fmriprep-options.yml @@ -264,7 +264,8 @@ registration_workflows: boundary_based_registration: # this is a fork point - # run: [On, Off] - this will run both and fork the pipeline + # run: [On, Off, fallback] - this will run both and fork the pipeline + # if 'fallback' is one of the selected options, BBR will run and, if its output fails quality_thresholds, the pipeline will fallback to BBR's input image run: [On] # reference for boundary based registration diff --git a/CPAC/resources/configs/pipeline_config_ndmg.yml b/CPAC/resources/configs/pipeline_config_ndmg.yml index aa1e31d486..50c48e3d38 100644 --- a/CPAC/resources/configs/pipeline_config_ndmg.yml +++ b/CPAC/resources/configs/pipeline_config_ndmg.yml @@ -74,7 +74,8 @@ registration_workflows: boundary_based_registration: # this is a fork point - # run: [On, Off] - this will run both and fork the pipeline + # run: [On, Off, fallback] - this will run both and fork the pipeline + # if 'fallback' is one of the selected options, BBR will run and, if its output fails quality_thresholds, the pipeline will fallback to BBR's input image run: [On] func_registration_to_template: diff --git a/CPAC/resources/configs/pipeline_config_rbc-options.yml b/CPAC/resources/configs/pipeline_config_rbc-options.yml index 6f0ea57dfe..0f6096a859 100644 --- a/CPAC/resources/configs/pipeline_config_rbc-options.yml +++ 
b/CPAC/resources/configs/pipeline_config_rbc-options.yml @@ -31,13 +31,23 @@ pipeline_setup: random_seed: 77742777 registration_workflows: + + # Runtime quality checks guardrails: + + # Minimum QC values to allow a run to complete post-registration + # Set any metric empty (like "Dice:") or to None to disable that guardrail + # Default thresholds adopted from XCP-Engine + # (https://github.com/PennLINC/xcpEngine/blob/397ab6cf/designs/cbf_all.dsn#L66) thresholds: Dice: 0.8 Jaccard: 0.9 CrossCorr: 0.7 Coverage: 0.8 + + # If this option is turned on and best_of is set to 1, any registration step that falls below the specified thresholds will retry with an incremented seed retry_on_first_failure: On + anatomical_registration: # Template to be used during registration. @@ -55,7 +65,12 @@ registration_workflows: functional_registration: coregistration: boundary_based_registration: + + # this is a fork point + # run: [On, Off, fallback] - this will run both and fork the pipeline + # if 'fallback' is one of the selected options, BBR will run and, if its output fails quality_thresholds, the pipeline will fallback to BBR's input image run: [fallback] + func_registration_to_template: output_resolution: @@ -120,7 +135,7 @@ functional_preproc: # this is a fork point # run: [On, Off] - this will run both and fork the pipeline run: [On] - space: 'template' + space: template nuisance_corrections: 2-nuisance_regression: diff --git a/CPAC/resources/configs/pipeline_config_regtest-1.yml b/CPAC/resources/configs/pipeline_config_regtest-1.yml index a86ef2e573..8ddf1b784b 100644 --- a/CPAC/resources/configs/pipeline_config_regtest-1.yml +++ b/CPAC/resources/configs/pipeline_config_regtest-1.yml @@ -109,7 +109,8 @@ registration_workflows: boundary_based_registration: # this is a fork point - # run: [On, Off] - this will run both and fork the pipeline + # run: [On, Off, fallback] - this will run both and fork the pipeline + # if 'fallback' is one of the selected options, BBR will run and, 
if its output fails quality_thresholds, the pipeline will fallback to BBR's input image run: [On] func_registration_to_template: diff --git a/CPAC/resources/configs/pipeline_config_regtest-2.yml b/CPAC/resources/configs/pipeline_config_regtest-2.yml index e421366f22..ff7ccdb444 100644 --- a/CPAC/resources/configs/pipeline_config_regtest-2.yml +++ b/CPAC/resources/configs/pipeline_config_regtest-2.yml @@ -130,7 +130,8 @@ registration_workflows: boundary_based_registration: # this is a fork point - # run: [On, Off] - this will run both and fork the pipeline + # run: [On, Off, fallback] - this will run both and fork the pipeline + # if 'fallback' is one of the selected options, BBR will run and, if its output fails quality_thresholds, the pipeline will fallback to BBR's input image run: [On] func_registration_to_template: diff --git a/CPAC/resources/configs/pipeline_config_regtest-3.yml b/CPAC/resources/configs/pipeline_config_regtest-3.yml index 5f735557d2..064b916898 100644 --- a/CPAC/resources/configs/pipeline_config_regtest-3.yml +++ b/CPAC/resources/configs/pipeline_config_regtest-3.yml @@ -130,7 +130,8 @@ registration_workflows: boundary_based_registration: # this is a fork point - # run: [On, Off] - this will run both and fork the pipeline + # run: [On, Off, fallback] - this will run both and fork the pipeline + # if 'fallback' is one of the selected options, BBR will run and, if its output fails quality_thresholds, the pipeline will fallback to BBR's input image run: [On] func_registration_to_template: diff --git a/CPAC/resources/configs/pipeline_config_regtest-4.yml b/CPAC/resources/configs/pipeline_config_regtest-4.yml index 2384e3b3bd..3b8299b195 100644 --- a/CPAC/resources/configs/pipeline_config_regtest-4.yml +++ b/CPAC/resources/configs/pipeline_config_regtest-4.yml @@ -165,7 +165,8 @@ registration_workflows: boundary_based_registration: # this is a fork point - # run: [On, Off] - this will run both and fork the pipeline + # run: [On, Off, fallback] - 
this will run both and fork the pipeline + # if 'fallback' is one of the selected options, BBR will run and, if its output fails quality_thresholds, the pipeline will fallback to BBR's input image run: [On] func_registration_to_template: diff --git a/CPAC/resources/configs/pipeline_config_rodent.yml b/CPAC/resources/configs/pipeline_config_rodent.yml index 62387568ec..8e53970704 100644 --- a/CPAC/resources/configs/pipeline_config_rodent.yml +++ b/CPAC/resources/configs/pipeline_config_rodent.yml @@ -96,7 +96,8 @@ registration_workflows: boundary_based_registration: # this is a fork point - # run: [On, Off] - this will run both and fork the pipeline + # run: [On, Off, fallback] - this will run both and fork the pipeline + # if 'fallback' is one of the selected options, BBR will run and, if its output fails quality_thresholds, the pipeline will fallback to BBR's input image run: [On] # Standard FSL 5.0 Scheduler used for Boundary Based Registration.