Creating a streamline aggregator

There are many ways to group streamlines. Streamlines can be labeled based on which brain regions they connect, as with the :py:class:`~dsi2.aggregation.region_labeled_clusters.RegionLabelAggregator`. Or they can be grouped based on the similarity of their morphology, as with a QuickBundlesAggregator.
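As a rough illustration of the first approach, the sketch below labels streamlines by the pair of regions their endpoints fall in. It assumes each streamline's two endpoint region IDs have already been looked up in an atlas. The `endpoint_regions` input and the `label_by_region_pair` function are hypothetical and are not part of RegionLabelAggregator's API.

```python
import numpy as np


def label_by_region_pair(endpoint_regions):
    """Hypothetical helper: one label per unique (unordered) region pair.

    endpoint_regions: (n_streamlines, 2) integer array holding the atlas
    region IDs at each streamline's start and end points (an assumed input).
    """
    # Sort each row so that (a, b) and (b, a) count as the same connection
    pairs = np.sort(np.asarray(endpoint_regions), axis=1)
    # np.unique over rows assigns each distinct pair an index 0..n_pairs-1
    _, labels = np.unique(pairs, axis=0, return_inverse=True)
    # ravel() keeps the result 1-D across numpy versions
    return labels.ravel()


labels = label_by_region_pair([[3, 7], [7, 3], [3, 5]])
```

Streamlines connecting the same two regions share a label regardless of the direction they were tracked in.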

One of the advantages of using DSI2 is that you can test a custom clustering algorithm on any part of the brain while interactively changing its parameters. Here we’ll build a new Aggregator from scratch that:

  • Subclasses ClusterEditor
  • Overrides the aggregate method
  • Operates on the TrackDatasets returned by a TrackDataSource
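To make that contract concrete without a DSI2 install, the sketch below uses a hypothetical stand-in base class, `StandInClusterEditor`, in place of ClusterEditor, and `types.SimpleNamespace` in place of a TrackDataset (only its `.tracks` attribute is modeled). That aggregate() returns one label per streamline matches the k-means example later on this page; the length check in `label_streamlines` is illustrative and not DSI2's behavior.

```python
from types import SimpleNamespace

import numpy as np


class StandInClusterEditor:
    """Hypothetical stand-in for ClusterEditor (not DSI2's implementation)."""

    def aggregate(self, track_dataset):
        # Subclasses must supply the clustering algorithm
        raise NotImplementedError("override aggregate() in a subclass")

    def label_streamlines(self, track_dataset):
        # Run the subclass's algorithm and check there is one label per streamline
        labels = np.asarray(self.aggregate(track_dataset))
        if len(labels) != len(track_dataset.tracks):
            raise ValueError("aggregate() must return one label per streamline")
        return labels


class StartSideAggregator(StandInClusterEditor):
    """Toy aggregator: label 1 if a streamline starts at x > 0, else 0."""

    def aggregate(self, track_dataset):
        return [int(trk[0][0] > 0) for trk in track_dataset.tracks]


# Each streamline is an (n_points, 3) array of coordinates
dataset = SimpleNamespace(tracks=[
    np.array([[-5.0, 0.0, 0.0], [-4.0, 1.0, 0.0]]),
    np.array([[6.0, 0.0, 0.0], [7.0, 1.0, 0.0]]),
])
labels = StartSideAggregator().label_streamlines(dataset)
```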

Defining the Aggregator class

Defining algorithm parameters as traits

The UI and event-listening functionality for Aggregators is inherited from :py:class:`~dsi2.aggregation.cluster_ui.ClusterEditor`. Any parameters your aggregation algorithm needs should be defined as Traits on the class.

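from dsi2.aggregation.cluster_ui import ClusterEditor
from sklearn.cluster import MiniBatchKMeans
# downsample comes from older dipy releases; newer code typically uses
# dipy.tracking.streamline.set_number_of_points instead
from dipy.tracking.metrics import downsample
from traits.api import Int
from traitsui.api import Item, Group
import numpy as np

class KMeansAggregator(ClusterEditor):
    # Number of clusters; parameter=True tells ClusterEditor to
    # re-run the aggregation whenever this value changes in the UI
    k = Int(3,
            label="k",
            desc="the number of clusters to group the streamlines into",
            parameter=True)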

The only parameter of the k-means algorithm, k, is attached to the class as a Trait with some special metadata. It is crucial that parameter=True is passed to every such Trait definition so that ClusterEditor knows to update its streamline labeling when the value is changed in the UI.
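ClusterEditor does this bookkeeping through Traits metadata. As a stdlib-only illustration of the idea (not the actual DSI2 or Traits mechanism), the hypothetical `ParameterizedAggregator` below records which attributes count as parameters and re-runs its labeling only when one of them changes value.

```python
class ParameterizedAggregator:
    """Hypothetical sketch: relabel whenever a 'parameter' attribute changes."""

    # Names playing the role of Traits declared with parameter=True
    parameter_names = {"k"}

    def __init__(self, k=3):
        self.relabel_count = 0
        self.k = k  # the initial assignment also triggers a relabel

    def __setattr__(self, name, value):
        changed = getattr(self, name, None) != value
        super().__setattr__(name, value)
        # Only changes to parameters invalidate the current labeling
        if name in self.parameter_names and changed:
            self.relabel()

    def relabel(self):
        # Placeholder: a real aggregator would re-run aggregate() here
        self.relabel_count += 1


agg = ParameterizedAggregator(k=3)
agg.k = 3          # same value: no relabel
agg.k = 5          # new value: relabel
agg.note = "temp"  # not a parameter: no relabel
```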

Specifying the GUI editors for your aggregator

The ClusterEditor superclass looks for a class attribute called algorithm_widgets in its subclasses. This attribute holds a TraitsUI Group containing the editor widgets for your algorithm's parameters.

class KMeansAggregator(ClusterEditor):
    ...
    algorithm_widgets = Group(
                          Item("k")
                          )

More advanced editors can be specified, but for now we’ll just use the default editors provided by TraitsUI.

Overriding the .aggregate() method

To fit the streamline data into groups, we must override the .aggregate() method. This method receives a single TrackDataset as its argument and should return one cluster label per streamline. We access the dataset's .tracks property and transform each streamline, whatever its number of points, into a fixed-length feature vector that k-means can use. Here we mimic DSI Studio’s k-means clustering: the first, middle and last coordinates of each streamline (3 points × 3 coordinates = 9 values) plus its length form a 10-feature vector.
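For readers without dipy at hand, the numpy-only sketch below builds the same kind of 10-feature vectors. The hypothetical `resample_by_arc_length` helper stands in for dipy's downsample: it places 3 points evenly along each streamline's arc length (first point, arc-length midpoint, last point). As in the example, the "length" feature is the number of points in the streamline.

```python
import numpy as np


def resample_by_arc_length(streamline, n_points=3):
    """Hypothetical helper: n_points positions spaced evenly along arc length."""
    streamline = np.asarray(streamline, dtype=float)
    # Cumulative distance travelled along the streamline at each vertex
    seg_lengths = np.linalg.norm(np.diff(streamline, axis=0), axis=1)
    cumulative = np.concatenate(([0.0], np.cumsum(seg_lengths)))
    targets = np.linspace(0.0, cumulative[-1], n_points)
    # Interpolate x, y and z independently as functions of arc length
    return np.column_stack(
        [np.interp(targets, cumulative, streamline[:, d]) for d in range(3)]
    )


def streamline_features(tracks):
    """Hypothetical helper: first/middle/last coordinates plus point count."""
    points = np.array([resample_by_arc_length(trk).ravel() for trk in tracks])
    lengths = np.array([len(trk) for trk in tracks], dtype=float).reshape(-1, 1)
    return np.hstack((points, lengths))


# A straight streamline along x with 11 points, from x=0 to x=10
line = np.column_stack([np.arange(11.0), np.zeros(11), np.zeros(11)])
features = streamline_features([line, line[:5]])
```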

class KMeansAggregator(ClusterEditor):
    ...
    def aggregate(self, track_dataset):
        """
        An example implementation of the k-means clustering used in
        DSI Studio. This method is automatically applied to every
        TrackDataset returned from a query.

        Parameters
        ----------
        track_dataset : dsi2.streamlines.track_dataset.TrackDataset
            The streamlines to cluster.

        Returns
        -------
        labels : numpy.ndarray
            One integer cluster label per streamline.
        """
        # Extract the streamline data (each streamline is an array of 3D points)
        tracks = track_dataset.tracks

        # Resample each streamline to 3 points (first, middle, last)
        # and flatten them into a 9-element row
        points = np.array([downsample(trk, n_pols=3).flatten()
                           for trk in tracks])

        # Use the number of points in each streamline as its length
        lengths = np.array([len(trk) for trk in tracks]).reshape(-1, 1)

        # Concatenate the points and lengths into 10-feature vectors
        features = np.hstack((points, lengths))

        # Fit mini-batch k-means with the user-selected number of clusters
        kmeans = MiniBatchKMeans(n_clusters=self.k, compute_labels=True)
        kmeans.fit(features)

        # Return one cluster label per streamline
        return kmeans.labels_

That’s all there is to it!
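MiniBatchKMeans fits k-means on small random batches of the data for speed. To show what the clustering step is doing, here is a plain full-batch Lloyd's k-means in numpy. It is a swapped-in, simplified algorithm for illustration, not scikit-learn's implementation, and the name `lloyd_kmeans` is hypothetical.

```python
import numpy as np


def _assign(features, centroids):
    # Index of the nearest centroid (squared Euclidean distance) for each row
    dists = ((features[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return dists.argmin(axis=1)


def lloyd_kmeans(features, k, n_iter=100, seed=0):
    """Hypothetical full-batch k-means returning one label per row."""
    features = np.asarray(features, dtype=float)
    rng = np.random.default_rng(seed)
    # Start from k distinct rows chosen at random as the initial centroids
    centroids = features[rng.choice(len(features), size=k, replace=False)]
    for _ in range(n_iter):
        labels = _assign(features, centroids)
        # Move each centroid to the mean of its rows (empty clusters stay put)
        new_centroids = np.array([
            features[labels == j].mean(axis=0) if np.any(labels == j)
            else centroids[j]
            for j in range(k)
        ])
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return _assign(features, centroids)


# Two well-separated groups of 2D points
X = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
              [10.0, 10.0], [10.1, 10.0], [10.0, 10.1]])
labels = lloyd_kmeans(X, k=2)
```

The feature columns mix coordinates with a point count, so standardizing each column before clustering may be worth trying in practice; neither this sketch nor the aggregator above does so.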

Exploring your aggregator in realtime

With the aggregator class defined, we can begin applying it to streamlines. Let’s create a data source and an aggregator, and set up a SphereBrowser to use them.

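from dsi2.database.mongo_track_datasource import MongoTrackDataSource
from dsi2.ui.sphere_browser import SphereBrowser

# KMeansAggregator is the class defined above

# Only select a single scan from the test data
test_subject = ["0377A"]

# Connect to the local MongoDB instance holding the test database
data_source = MongoTrackDataSource(
    scan_ids=test_subject,
    mongo_host="127.0.0.1",
    mongo_port=27017,
    db_name="dsi2_test",
)

kmeans_agg = KMeansAggregator()

# Browse the streamlines; they are relabeled whenever k changes in the UI.
# edit_traits() returns immediately; from a standalone script,
# configure_traits() keeps the window open instead.
browser = SphereBrowser(track_source=data_source, aggregator=kmeans_agg)
browser.edit_traits()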
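Before pointing the browser at MongoDB, it can help to sanity-check an aggregator on synthetic streamlines where the right grouping is known. The sketch below is not part of DSI2: the hypothetical `make_synthetic_bundles` helper generates noisy straight streamlines in two spatially separate bundles, and `types.SimpleNamespace` stands in for a TrackDataset by exposing only `.tracks`.

```python
from types import SimpleNamespace

import numpy as np


def make_synthetic_bundles(n_per_bundle=20, n_points=30, noise=0.5, seed=0):
    """Hypothetical helper: two bundles of noisy straight streamlines.

    Returns (dataset, truth): dataset.tracks is a list of (n_points, 3)
    arrays and truth holds each streamline's bundle index.
    """
    rng = np.random.default_rng(seed)
    t = np.linspace(0.0, 1.0, n_points)[:, None]
    # Each bundle runs between its own pair of (arbitrary) end coordinates
    bundle_ends = [
        (np.array([-40.0, 0.0, 0.0]), np.array([-10.0, 30.0, 0.0])),
        (np.array([10.0, -20.0, 10.0]), np.array([40.0, 20.0, 10.0])),
    ]
    tracks, truth = [], []
    for bundle_id, (start, end) in enumerate(bundle_ends):
        for _ in range(n_per_bundle):
            line = start + t * (end - start)
            tracks.append(line + rng.normal(scale=noise, size=line.shape))
            truth.append(bundle_id)
    return SimpleNamespace(tracks=tracks), np.array(truth)


dataset, truth = make_synthetic_bundles()
```

With dsi2 and scikit-learn installed, and assuming KMeansAggregator can be instantiated outside the browser, `KMeansAggregator(k=2).aggregate(dataset)` would be expected to place the two bundles in different clusters.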