Skip to content

Commit

Permalink
sty(black): standardize formatting a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed Apr 8, 2020
1 parent 9446d47 commit 0ae5351
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 70 deletions.
95 changes: 54 additions & 41 deletions niworkflows/interfaces/nibabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,28 @@
from nipype import logging
from nipype.utils.filemanip import fname_presuffix
from nipype.interfaces.base import (
traits, TraitedSpec, BaseInterfaceInputSpec, File,
SimpleInterface, OutputMultiObject, InputMultiObject
traits,
TraitedSpec,
BaseInterfaceInputSpec,
File,
SimpleInterface,
OutputMultiObject,
InputMultiObject,
)

IFLOGGER = logging.getLogger('nipype.interface')
IFLOGGER = logging.getLogger("nipype.interface")


class _ApplyMaskInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True, desc='an image')
in_mask = File(exists=True, mandatory=True, desc='a mask')
threshold = traits.Float(0.5, usedefault=True,
desc='a threshold to the mask, if it is nonbinary')
in_file = File(exists=True, mandatory=True, desc="an image")
in_mask = File(exists=True, mandatory=True, desc="a mask")
threshold = traits.Float(
0.5, usedefault=True, desc="a threshold to the mask, if it is nonbinary"
)


class _ApplyMaskOutputSpec(TraitedSpec):
out_file = File(exists=True, desc='masked file')
out_file = File(exists=True, desc="masked file")


class ApplyMask(SimpleInterface):
Expand All @@ -36,8 +42,9 @@ def _run_interface(self, runtime):
msknii = nb.load(self.inputs.in_mask)
msk = msknii.get_fdata() > self.inputs.threshold

self._results['out_file'] = fname_presuffix(
self.inputs.in_file, suffix='_masked', newpath=runtime.cwd)
self._results["out_file"] = fname_presuffix(
self.inputs.in_file, suffix="_masked", newpath=runtime.cwd
)

if img.dataobj.shape[:3] != msk.shape:
raise ValueError("Image and mask sizes do not match.")
Expand All @@ -49,19 +56,18 @@ def _run_interface(self, runtime):
msk = msk[..., np.newaxis]

masked = img.__class__(img.dataobj * msk, None, img.header)
masked.to_filename(self._results['out_file'])
masked.to_filename(self._results["out_file"])
return runtime


class _BinarizeInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True, desc='input image')
thresh_low = traits.Float(mandatory=True,
desc='non-inclusive lower threshold')
in_file = File(exists=True, mandatory=True, desc="input image")
thresh_low = traits.Float(mandatory=True, desc="non-inclusive lower threshold")


class _BinarizeOutputSpec(TraitedSpec):
out_file = File(exists=True, desc='masked file')
out_mask = File(exists=True, desc='output mask')
out_file = File(exists=True, desc="masked file")
out_mask = File(exists=True, desc="output mask")


class Binarize(SimpleInterface):
Expand All @@ -73,33 +79,35 @@ class Binarize(SimpleInterface):
def _run_interface(self, runtime):
img = nb.load(self.inputs.in_file)

self._results['out_file'] = fname_presuffix(
self.inputs.in_file, suffix='_masked', newpath=runtime.cwd)
self._results['out_mask'] = fname_presuffix(
self.inputs.in_file, suffix='_mask', newpath=runtime.cwd)
self._results["out_file"] = fname_presuffix(
self.inputs.in_file, suffix="_masked", newpath=runtime.cwd
)
self._results["out_mask"] = fname_presuffix(
self.inputs.in_file, suffix="_mask", newpath=runtime.cwd
)

data = img.get_fdata()
mask = data > self.inputs.thresh_low
data[~mask] = 0.0
masked = img.__class__(data, img.affine, img.header)
masked.to_filename(self._results['out_file'])
masked.to_filename(self._results["out_file"])

img.header.set_data_dtype('uint8')
maskimg = img.__class__(mask.astype('uint8'), img.affine,
img.header)
maskimg.to_filename(self._results['out_mask'])
img.header.set_data_dtype("uint8")
maskimg = img.__class__(mask.astype("uint8"), img.affine, img.header)
maskimg.to_filename(self._results["out_mask"])

return runtime


class _FourToThreeInputSpec(BaseInterfaceInputSpec):
in_file = File(exists=True, mandatory=True, desc='input 4d image')
allow_3D = traits.Bool(False, usedefault=True, desc='do not fail if a 3D volume is passed in')
in_file = File(exists=True, mandatory=True, desc="input 4d image")
allow_3D = traits.Bool(
False, usedefault=True, desc="do not fail if a 3D volume is passed in"
)


class _FourToThreeOutputSpec(TraitedSpec):
out_files = OutputMultiObject(File(exists=True),
desc='output list of 3d images')
out_files = OutputMultiObject(File(exists=True), desc="output list of 3d images")


class SplitSeries(SimpleInterface):
Expand All @@ -117,9 +125,11 @@ def _run_interface(self, runtime):
if ndim != 4:
if self.inputs.allow_3D and ndim == 3:
out_file = str(
Path(fname_presuffix(self.inputs.in_file, suffix=f"_idx-000")).absolute()
Path(
fname_presuffix(self.inputs.in_file, suffix=f"_idx-000")
).absolute()
)
self._results['out_files'] = out_file
self._results["out_files"] = out_file
filenii.to_filename(out_file)
return runtime
raise RuntimeError(
Expand All @@ -128,27 +138,29 @@ def _run_interface(self, runtime):
)

files_3d = nb.four_to_three(filenii)
self._results['out_files'] = []
self._results["out_files"] = []
in_file = self.inputs.in_file
for i, file_3d in enumerate(files_3d):
out_file = str(
Path(fname_presuffix(in_file, suffix=f"_idx-{i:03}")).absolute()
)
file_3d.to_filename(out_file)
self._results['out_files'].append(out_file)
self._results["out_files"].append(out_file)

return runtime


class _MergeSeriesInputSpec(BaseInterfaceInputSpec):
in_files = InputMultiObject(File(exists=True, mandatory=True,
desc='input list of 3d images'))
allow_4D = traits.Bool(True, usedefault=True,
desc='whether 4D images are allowed to be concatenated')
in_files = InputMultiObject(
File(exists=True, mandatory=True, desc="input list of 3d images")
)
allow_4D = traits.Bool(
True, usedefault=True, desc="whether 4D images are allowed to be concatenated"
)


class _MergeSeriesOutputSpec(TraitedSpec):
out_file = File(exists=True, desc='output 4d image')
out_file = File(exists=True, desc="output 4d image")


class MergeSeries(SimpleInterface):
Expand All @@ -169,12 +181,13 @@ def _run_interface(self, runtime):
nii_list += nb.four_to_three(filenii)
continue
else:
raise ValueError("Input image has an incorrect number of dimensions"
f" ({ndim}).")
raise ValueError(
"Input image has an incorrect number of dimensions" f" ({ndim})."
)

img_4d = nb.concat_images(nii_list)
out_file = fname_presuffix(self.inputs.in_files[0], suffix="_merged")
img_4d.to_filename(out_file)

self._results['out_file'] = out_file
self._results["out_file"] = out_file
return runtime
66 changes: 37 additions & 29 deletions niworkflows/interfaces/tests/test_nibabel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ def test_Binarize(tmp_path):
mask = np.zeros((20, 20, 20), dtype=bool)
mask[5:15, 5:15, 5:15] = bool

data = np.zeros_like(mask, dtype='float32')
data = np.zeros_like(mask, dtype="float32")
data[mask] = np.random.gamma(2, size=mask.sum())

in_file = tmp_path / 'input.nii.gz'
in_file = tmp_path / "input.nii.gz"
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))

binif = Binarize(thresh_low=0.0, in_file=str(in_file)).run()
Expand All @@ -36,28 +36,32 @@ def test_ApplyMask(tmp_path):
mask[8:11, 8:11, 8:11] = 1.0

# Test the 3D
in_file = tmp_path / 'input3D.nii.gz'
in_file = tmp_path / "input3D.nii.gz"
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))

in_mask = tmp_path / 'mask.nii.gz'
in_mask = tmp_path / "mask.nii.gz"
nb.Nifti1Image(mask, np.eye(4), None).to_filename(str(in_mask))

masked1 = ApplyMask(in_file=str(in_file), in_mask=str(in_mask), threshold=0.4).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5**3
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5 ** 3

masked1 = ApplyMask(in_file=str(in_file), in_mask=str(in_mask), threshold=0.6).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3**3
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3 ** 3

data4d = np.stack((data, 2 * data, 3 * data), axis=-1)
# Test the 4D case
in_file4d = tmp_path / 'input4D.nii.gz'
in_file4d = tmp_path / "input4D.nii.gz"
nb.Nifti1Image(data4d, np.eye(4), None).to_filename(str(in_file4d))

masked1 = ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5**3 * 6
masked1 = ApplyMask(
in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.4
).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 5 ** 3 * 6

masked1 = ApplyMask(in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.6).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3**3 * 6
masked1 = ApplyMask(
in_file=str(in_file4d), in_mask=str(in_mask), threshold=0.6
).run()
assert nb.load(masked1.outputs.out_file).get_fdata().sum() == 3 ** 3 * 6

# Test errors
nb.Nifti1Image(mask, 2 * np.eye(4), None).to_filename(str(in_mask))
Expand All @@ -77,16 +81,17 @@ def test_SplitSeries(tmp_path):

# Test the 4D
data = np.ones((20, 20, 20, 15), dtype=float)
in_file = tmp_path / 'input4D.nii.gz'
in_file = tmp_path / "input4D.nii.gz"
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))

split = SplitSeries(in_file=str(in_file)).run()
assert len(split.outputs.out_files) == 15

# Test the 3D
data = np.ones((20, 20, 20), dtype=float)
in_file = tmp_path / 'input3D.nii.gz'
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))
in_file = tmp_path / "input3D.nii.gz"
nb.Nifti1Image(np.ones((20, 20, 20), dtype=float), np.eye(4), None).to_filename(
str(in_file)
)

with pytest.raises(RuntimeError):
SplitSeries(in_file=str(in_file)).run()
Expand All @@ -95,9 +100,10 @@ def test_SplitSeries(tmp_path):
assert isinstance(split.outputs.out_files, str)

# Test the 3D
data = np.ones((20, 20, 20, 1), dtype=float)
in_file = tmp_path / 'input3D.nii.gz'
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))
in_file = tmp_path / "input3D.nii.gz"
nb.Nifti1Image(np.ones((20, 20, 20, 1), dtype=float), np.eye(4), None).to_filename(
str(in_file)
)

with pytest.raises(RuntimeError):
SplitSeries(in_file=str(in_file)).run()
Expand All @@ -106,9 +112,10 @@ def test_SplitSeries(tmp_path):
assert isinstance(split.outputs.out_files, str)

# Test the 5D
data = np.ones((20, 20, 20, 2, 2), dtype=float)
in_file = tmp_path / 'input5D.nii.gz'
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))
in_file = tmp_path / "input5D.nii.gz"
nb.Nifti1Image(
np.ones((20, 20, 20, 2, 2), dtype=float), np.eye(4), None
).to_filename(str(in_file))

with pytest.raises(RuntimeError):
SplitSeries(in_file=str(in_file)).run()
Expand All @@ -118,7 +125,7 @@ def test_SplitSeries(tmp_path):

# Test splitting ANTs warpfields
data = np.ones((20, 20, 20, 1, 3), dtype=float)
in_file = tmp_path / 'warpfield.nii.gz'
in_file = tmp_path / "warpfield.nii.gz"
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))

split = SplitSeries(in_file=str(in_file)).run()
Expand All @@ -129,17 +136,18 @@ def test_MergeSeries(tmp_path):
"""Test 3-to-4 NIfTI concatenation interface."""
os.chdir(str(tmp_path))

data = np.ones((20, 20, 20), dtype=float)
in_file = tmp_path / 'input3D.nii.gz'
nb.Nifti1Image(data, np.eye(4), None).to_filename(str(in_file))
in_file = tmp_path / "input3D.nii.gz"
nb.Nifti1Image(np.ones((20, 20, 20), dtype=float), np.eye(4), None).to_filename(
str(in_file)
)

merge = MergeSeries(in_files=[str(in_file)] * 5).run()
assert nb.load(merge.outputs.out_file).dataobj.shape == (20, 20, 20, 5)

in_4D = tmp_path / 'input4D.nii.gz'
nb.Nifti1Image(
np.ones((20, 20, 20, 4), dtype=float), np.eye(4), None
).to_filename(str(in_4D))
in_4D = tmp_path / "input4D.nii.gz"
nb.Nifti1Image(np.ones((20, 20, 20, 4), dtype=float), np.eye(4), None).to_filename(
str(in_4D)
)

merge = MergeSeries(in_files=[str(in_file)] + [str(in_4D)]).run()
assert nb.load(merge.outputs.out_file).dataobj.shape == (20, 20, 20, 5)
Expand Down

0 comments on commit 0ae5351

Please sign in to comment.