SNT Integration: Pre-load LabkitProjectFrame with pre-defined labels #112

Open
tferr opened this issue Jan 2, 2024 · 7 comments

tferr commented Jan 2, 2024

@maarzt,

We've been hacking ways to integrate both Labkit and TWS into SNT. Right now we can import and train models from paths, but a more useful command would be to "Send neurites" to Labkit, so that the traced paths of selected neurites are pre-loaded into a Labkit instance as labels.

This would allow for pixel-perfect labels along thin neurites, which, in our experience, significantly improves training relative to freehand annotations.

In TWS, we convert Paths into ROIs and feed those to its GUI at startup. It is a bit clunky because we need to use reflection, but it works quite well (will probably open a PR there to formalize this feature). Labkit can hold much larger datasets, so it would be a pity to be restricted to TWS.

Do you have pointers on how to attempt this in Labkit? I've looked at LabelBrushController briefly, but did not find it useful. Did I miss something? I am convinced that a method for 'painting' programmatically would work. A method that converts polylines or freehand lines into labels would be best.

NB: We can convert selected paths into a digitized skeleton mask and train a model with it, but that is not interactive and somewhat defeats what we are trying to achieve.
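
One way to turn a traced polyline into voxel positions that could be painted as a label is to rasterize each segment between consecutive nodes. The sketch below is a plain-Java illustration under stated assumptions: `PolylineRasterizer`, `segment` and `polyline` are hypothetical names (not SNT or Labkit API), nodes are assumed to be already expressed in integer voxel coordinates ordered as X, Y, Z, and the resulting line is one voxel thick.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper (not part of SNT or Labkit): rasterizes a 3D polyline into voxels. */
final class PolylineRasterizer {

	/** Returns the voxels on the straight segment from a to b (both inclusive), each as {x, y, z}. */
	static List<long[]> segment(long[] a, long[] b) {
		List<long[]> voxels = new ArrayList<>();
		// the longest axis extent decides the number of steps, so no voxel is skipped
		long steps = 0;
		for (int d = 0; d < 3; d++)
			steps = Math.max(steps, Math.abs(b[d] - a[d]));
		if (steps == 0) {
			voxels.add(a.clone());
			return voxels;
		}
		for (long i = 0; i <= steps; i++) {
			long[] p = new long[3];
			for (int d = 0; d < 3; d++) {
				// linear interpolation between a and b, rounded to the nearest voxel
				p[d] = a[d] + Math.round((double) ((b[d] - a[d]) * i) / steps);
			}
			voxels.add(p);
		}
		return voxels;
	}

	/** Rasterizes all segments of a polyline; the voxel shared by two segments is emitted once. */
	static List<long[]> polyline(List<long[]> nodes) {
		List<long[]> voxels = new ArrayList<>();
		if (nodes.size() == 1)
			voxels.add(nodes.get(0).clone());
		for (int i = 0; i + 1 < nodes.size(); i++) {
			List<long[]> seg = segment(nodes.get(i), nodes.get(i + 1));
			// every segment but the first starts on the previous end point, so drop that voxel
			voxels.addAll(i == 0 ? seg : seg.subList(1, seg.size()));
		}
		return voxels;
	}
}
```

Because the step count equals the largest per-axis distance, consecutive voxels differ by at most one along every axis, which keeps the rasterized path connected. Each returned position could then be written into a labeling through a `RandomAccess` and its `setPosition` calls.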

Collaborator

maarzt commented Jan 12, 2024

OK, interesting. In Labkit I call a collection of ROIs a "labeling". So, if I understand you correctly, you would like to start Labkit with a pre-initialized "labeling". That is definitely possible. Here is an example for you:

import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;
import java.util.Arrays;

import javax.swing.JFrame;
import javax.swing.WindowConstants;

import ij.ImagePlus;
import net.imagej.ImgPlus;
import net.imglib2.FinalInterval;
import net.imglib2.RandomAccess;
import net.imglib2.img.VirtualStackAdapter;
import net.imglib2.img.display.imagej.ImageJFunctions;
import net.imglib2.roi.labeling.LabelingType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.real.FloatType;
import org.scijava.Context;
import sc.fiji.labkit.ui.SegmentationComponent;
import sc.fiji.labkit.ui.inputimage.DatasetInputImage;
import sc.fiji.labkit.ui.labeling.Label;
import sc.fiji.labkit.ui.labeling.Labeling;
import sc.fiji.labkit.ui.models.DefaultSegmentationModel;
import sc.fiji.labkit.ui.segmentation.SegmentationTool;
import sc.fiji.labkit.ui.segmentation.Segmenter;

/**
 * An example of how to start a "custom" Labkit with a pre-initialized labeling.
 */
public class LabkitExample {

	public static void main(String... args) {
		// open an image and create a segmentation model
		Context context = new Context();
		ImagePlus imagePlus = new ImagePlus("https://imagej.nih.gov/ij/images/t1-head.zip");
		ImgPlus<?> image = VirtualStackAdapter.wrap(imagePlus);
		image.setChannelMinimum(0, 0 );     // set brightness and contrast for displaying the given image
		image.setChannelMaximum(0, 1000 );
		DefaultSegmentationModel segmentationModel =
			new DefaultSegmentationModel(context, new DatasetInputImage(image));

		// Initialize an empty labeling of correct size
		FinalInterval imageSize = new FinalInterval(image);
		Labeling labeling = Labeling.createEmpty(Arrays.asList("S1", "S2", "background"), imageSize);

		// add one straight line
		RandomAccess<LabelingType<Label>> randomAccess = labeling.randomAccess();
		LabelingType<Label> color = randomAccess.get().createVariable();
		color.add(labeling.getLabel("S1"));
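		// clamp the length so that the diagonal stays inside the image on every axis
		long length = Math.min(200, Math.min(image.dimension(0), Math.min(image.dimension(1), image.dimension(2))));
		for (int i = 0; i < length; i++) {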
			randomAccess.setPosition(i, 0); // set X
			randomAccess.setPosition(i, 1); // set Y
			randomAccess.setPosition(i, 2); // set Z
			randomAccess.get().set( color );
		}

		// Create the gui components
		segmentationModel.imageLabelingModel().labeling().set(labeling);
		JFrame frame = new JFrame();
		SegmentationComponent segmentationComponent =
			new SegmentationComponent(frame, segmentationModel, false);
		frame.add(segmentationComponent);
		frame.setSize(1000, 600);
		frame.setDefaultCloseOperation( WindowConstants.DISPOSE_ON_CLOSE );
		frame.addWindowListener(new WindowAdapter() {

			@Override
			public void windowClosed(WindowEvent e) {
				showResults( segmentationModel, image );
				segmentationComponent.close(); // don't forget to close the segmentation component if you are done.
			}
		});
		frame.setVisible(true);
	}

	private static void showResults(DefaultSegmentationModel segmentationModel, ImgPlus<?> image) {
		// Compute and show the complete segmentation and probability map.
		Segmenter segmenter = segmentationModel.segmenterList().segmenters().get().get(0);
		SegmentationTool segmentationTool = new SegmentationTool(segmenter);
		ImgPlus<UnsignedByteType> segmentation = segmentationTool.segment(image);
		ImageJFunctions.show(segmentation).setDisplayRange(0, 3);
		ImgPlus<FloatType> probabilityMap = segmentationTool.probabilityMap(image);
		ImageJFunctions.show(probabilityMap).setDisplayRange(0, 1);
	}

}

The example starts Labkit on a given image and also shows a labeling that is pre-initialized with a straight line.

Additionally, if you train a pixel classifier and close the window, the results will be shown in two new windows.

@tferr I hope this example helps you!

Author

tferr commented Jan 12, 2024

Thanks @maarzt, this is really helpful, and thanks for the detailed example. It should be straightforward to adapt this.

Should I assume that the ZCT positions are indexed according to the ImgPlus<T> axes? I.e.:

// dimensionIndex() returns -1 when the image has no such axis, so look the index up before reading the size
final int zIdx = image.dimensionIndex(Axes.Z);
final int cIdx = image.dimensionIndex(Axes.CHANNEL);
final int tIdx = image.dimensionIndex(Axes.TIME);
for (int i = 0; i < 200; i++) {
	randomAccess.setPosition(i, image.dimensionIndex(Axes.X)); // set X
	randomAccess.setPosition(i, image.dimensionIndex(Axes.Y)); // set Y
	if (zIdx >= 0 && image.dimension(zIdx) > 1)
		randomAccess.setPosition(i, zIdx); // set Z
	if (cIdx >= 0 && image.dimension(cIdx) > 1)
		randomAccess.setPosition(i, cIdx); // set C
	if (tIdx >= 0 && image.dimension(tIdx) > 1)
		randomAccess.setPosition(i, tIdx); // set T
	randomAccess.get().set(color);
}

Or is there some other convention in place in which the axis index is always known, e.g., X=0, Y=1, Z=2, C=3, etc.?

Collaborator

maarzt commented Jan 12, 2024

Labkit can deal with varying axes order. But yes, you need to set the axis types correctly in the ImgPlus<T>.

For the Labeling it's a different story. I think there the axes order is fixed to XYZT. The labeling has no channel axis. Every channel has the same "label".
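
To make the difference concrete, here is a minimal plain-Java sketch of how a position in an image with arbitrary axis order could be mapped onto a labeling that has no channel axis. It assumes the XYZT order mentioned above (stated there with "I think", so treat it as unverified), and `LabelingAxes` and `toLabelingPosition` are hypothetical names, not Labkit API.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical helper (not Labkit API): maps image positions to labeling positions. */
final class LabelingAxes {

	/** Axis order assumed for the labeling; the channel axis is deliberately absent. */
	static final List<String> LABELING_ORDER = Arrays.asList("X", "Y", "Z", "T");

	/**
	 * @param imageAxes axis names of the image in its own order, e.g. {"X", "Y", "C", "Z", "T"}
	 * @param imagePos  position in the image, one value per image axis
	 * @return the position with the channel dropped and the remaining axes in X, Y, Z, T order;
	 *         axes that the image does not have are skipped
	 */
	static long[] toLabelingPosition(String[] imageAxes, long[] imagePos) {
		List<String> axes = Arrays.asList(imageAxes);
		return LABELING_ORDER.stream()
			.filter(axes::contains) // keep only the axes present in the image
			.mapToLong(axis -> imagePos[axes.indexOf(axis)])
			.toArray();
	}
}
```

Under these assumptions a 2D multichannel image (axes X, Y, C) maps onto a 2D labeling, and a 5D XYCZT image onto a 4D one.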

Author

tferr commented Jan 12, 2024

> The labeling has no channel axis. Every channel has the same "label".

Why is that? Are multichannel images not formally supported? If you have an image of two fluorophores (A and B) and want to train the model with 3 classes ("foreground a", "foreground b", and "background"), how would one go about it? I do indeed trigger this error when I use a 2D multichannel image. The error says that the image has 2 dimensions(!) (and my labeling 3), so I guess the channel dimension has been simplified internally by Labkit? Something akin to an RGB conversion?

Apart from that, I think I got everything to work well with 'normal' grayscale images (2D or 3D, with or without a time axis). I will link the code here as soon as I have time to check that everything is indeed working.

Author

tferr commented Jan 16, 2024

Multichannel support aside: the code is here and I added a dedicated page to the documentation. Hopefully I got the main differences between Labkit and TWS right. Do let me know otherwise.

Collaborator

maarzt commented Jan 19, 2024

Your code looks good. Does it work as intended?

Regarding your earlier question about multichannel images: yes, the channel axis is treated specially, but the details are hard to explain. The "LabelingType" class that is used for the labeling does a lot of tricks. It is loosely similar to an RGB type, but rather than representing a color, it represents a set of labels. So technically a pixel can be annotated as "foreground", "background", {"foreground", "background"} or an empty set.

I modified the example to show how to use Labkit on a two-channel image:
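
The "set of labels per pixel" idea can be pictured with a toy model in plain Java. This is only an illustration of the concept and not how LabelingType is implemented; `LabelSetImage` and its methods are hypothetical names, and the image is assumed to be 2D with a known width.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/** Toy model of a set-valued labeling (hypothetical, unrelated to the real LabelingType). */
final class LabelSetImage {

	private final long width;
	// sparse storage: only pixels that carry at least one label get an entry
	private final Map<Long, Set<String>> pixels = new HashMap<>();

	LabelSetImage(long width) {
		this.width = width;
	}

	/** Adds a label to the set stored at (x, y); labels already present are kept. */
	void add(long x, long y, String label) {
		pixels.computeIfAbsent(y * width + x, key -> new TreeSet<>()).add(label);
	}

	/** Returns the labels at (x, y), or an empty set for an unannotated pixel. */
	Set<String> get(long x, long y) {
		return pixels.getOrDefault(y * width + x, Collections.emptySet());
	}
}
```

In this model a pixel annotated as both "foreground" and "background" simply holds a two-element set, and an untouched pixel holds the empty set, which matches the four cases listed above.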

import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;
import java.util.Arrays;

import javax.swing.JFrame;
import javax.swing.WindowConstants;

import ij.ImagePlus;
import net.imagej.ImgPlus;
import net.imagej.axis.Axes;
import net.imglib2.Interval;
import net.imglib2.RandomAccess;
import net.imglib2.img.VirtualStackAdapter;
import net.imglib2.img.display.imagej.ImageJFunctions;
import net.imglib2.roi.labeling.LabelingType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Intervals;
import net.imglib2.view.Views;
import org.scijava.Context;
import sc.fiji.labkit.ui.SegmentationComponent;
import sc.fiji.labkit.ui.inputimage.DatasetInputImage;
import sc.fiji.labkit.ui.labeling.Label;
import sc.fiji.labkit.ui.labeling.Labeling;
import sc.fiji.labkit.ui.models.DefaultSegmentationModel;
import sc.fiji.labkit.ui.segmentation.SegmentationTool;
import sc.fiji.labkit.ui.segmentation.Segmenter;
import sc.fiji.labkit.ui.utils.DimensionUtils;

/**
 * An example of how to start a "custom" Labkit with a pre-initialized labeling.
 */
public class LabkitExample {

	public static void main(String... args) {
		// open an image and create a segmentation model
		Context context = new Context();
		ImagePlus imagePlus = new ImagePlus("https://imagej.net/ij/images/confocal-series.zip");
		ImgPlus<?> image = VirtualStackAdapter.wrap(imagePlus);
		DefaultSegmentationModel segmentationModel =
			new DefaultSegmentationModel(context, new DatasetInputImage(image));

		// Initialize an empty labeling of correct size
		int channelDimension = image.dimensionIndex(Axes.CHANNEL);
		Interval imageSize = DimensionUtils.intervalRemoveDimension(image, channelDimension);
		Labeling labeling = Labeling.createEmpty(Arrays.asList("S1", "S2", "background"), imageSize);

		// add some labels
		RandomAccess<LabelingType<Label>> randomAccess = labeling.randomAccess();
		LabelingType<Label> color = randomAccess.get().createVariable();
		color.clear();
		color.add(labeling.getLabel("S1"));
		Views.interval(labeling, Intervals.createMinSize(120,205,11, 30, 3, 5)).forEach(pixel -> pixel.set(color));
		color.clear();
		color.add(labeling.getLabel("S2"));
		Views.interval(labeling, Intervals.createMinSize(250,139,11, 3, 30, 5)).forEach(pixel -> pixel.set(color));
		color.clear();
		color.add(labeling.getLabel("background"));
		Views.interval(labeling, Intervals.createMinSize(34,28,11, 20, 4, 5)).forEach(pixel -> pixel.set(color));

		// Create the gui components
		segmentationModel.imageLabelingModel().labeling().set(labeling);
		JFrame frame = new JFrame();
		SegmentationComponent segmentationComponent =
			new SegmentationComponent(frame, segmentationModel, false);
		frame.add(segmentationComponent);
		frame.setSize(1000, 600);
		frame.setDefaultCloseOperation( WindowConstants.DISPOSE_ON_CLOSE );
		frame.addWindowListener(new WindowAdapter() {

			@Override
			public void windowClosed(WindowEvent e) {
				showResults( segmentationModel, image );
				segmentationComponent.close(); // don't forget to close the segmentation component if you are done.
			}
		});
		frame.setVisible(true);
	}

	private static void showResults(DefaultSegmentationModel segmentationModel, ImgPlus<?> image) {
		Segmenter segmenter = segmentationModel.segmenterList().segmenters().get().get(0);
		SegmentationTool segmentationTool = new SegmentationTool(segmenter);
		ImgPlus<UnsignedByteType> segmentation = segmentationTool.segment(image);
		ImageJFunctions.show(segmentation).setDisplayRange(0, 3);
		ImgPlus<FloatType> probabilityMap = segmentationTool.probabilityMap(image);
		ImageJFunctions.show(probabilityMap).setDisplayRange(0, 1);
	}

}

The biggest difference is in these two lines:

		int channelDimension = image.dimensionIndex(Axes.CHANNEL);
		Interval imageSize = DimensionUtils.intervalRemoveDimension(image, channelDimension);
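
What those two lines achieve can be sketched on a plain array of dimension sizes: the labeling gets the image's extent with the channel dimension taken out. The helper below is a hypothetical stand-in written for illustration (`Dimensions.removeDimension` is not Labkit API), and returning the input unchanged for a negative index is an assumption of this sketch, not a statement about how `DimensionUtils.intervalRemoveDimension` behaves.

```java
/** Hypothetical helper (not Labkit API): drops one dimension from a list of sizes. */
final class Dimensions {

	/**
	 * @param dims size of each image dimension
	 * @param d    index of the dimension to drop; a negative value means "no such axis"
	 * @return a new array without dimension d, or a copy of dims if d is negative
	 */
	static long[] removeDimension(long[] dims, int d) {
		if (d < 0)
			return dims.clone(); // assumption: nothing to remove when the axis is absent
		long[] result = new long[dims.length - 1];
		System.arraycopy(dims, 0, result, 0, d); // dimensions before d
		System.arraycopy(dims, d + 1, result, d, dims.length - d - 1); // dimensions after d
		return result;
	}
}
```

For an image with illustrative sizes {400, 400, 2, 25} and the channel at index 2, the labeling would span {400, 400, 25}.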

Author

tferr commented Jan 24, 2024

@maarzt, thanks! I've incorporated this and tweaked the documentation. Everything seems to be working well in a couple of tests with 2D and 3D multichannel images. I consider the SNT <> Labkit bridge finalized, so feel free to close this. Thanks a lot for the thoughtful examples.
