Skip to content

Commit

Permalink
Merge pull request #178 from NeuralEnsemble/feat/path-util
Browse files Browse the repository at this point in the history
Towards semi-automated path generation
  • Loading branch information
pgleeson authored Sep 11, 2023
2 parents b7f3bdf + 63dc4cf commit 81d4d5c
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 6 deletions.
58 changes: 58 additions & 0 deletions neuroml/nml/generatedssupersuper.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,3 +553,61 @@ def _check_arg_list(self, **kwargs):
print(err)
self.info()
raise ValueError(err)

@classmethod
def get_class_hierarchy(cls):
"""Get the class hierarchy for a component classs.
Reference: https://stackoverflow.com/a/75161393/375067
See the methods in neuroml.utils to use this generated hierarchy.
:returns: nested single key dictionaries where the key of each
dictionary is the root node of that subtree, and keys are its
immediate descendents
"""
# classes that don't have any members, like ZeroOrNone, which is an Enum
schema = sys.modules[cls.__module__]
try:
allmembers = cls._get_members()
except AttributeError:
return {cls.__name__: []}

retlist = []
for member in allmembers:
if member is not None:
# is it a complex type, which will have a corresponding class?
member_class = getattr(schema, member.get_data_type(), None)
# if it isn't a class, so a simple type, just added it with an
# empty list
if member_class is None:
retlist.append({member.get_name(): []})
else:
# if it is a class, see if it has a hierarchy
try:
retlist.append(member_class.get_class_hierarchy())
except AttributeError:
retlist.append({member_class.__name__: []})

return {cls.__name__: retlist}

@classmethod
def get_nml2_class_hierarchy(cls):
"""Return the NeuroML class hierarchy.
The root here is NeuroMLDocument.
This is useful in calculating paths to different components to aid in
construction of relative paths.
This caches the value as a class variable so that it is not
re-calculated when used multiple times.
"""
# if hierarchy exists, return
try:
return cls.__nml_hier
# first run
except AttributeError:
schema = sys.modules[cls.__module__]
cls.__nml_hier = schema.NeuroMLDocument.get_class_hierarchy()
return cls.__nml_hier
17 changes: 16 additions & 1 deletion neuroml/test/test_nml.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
except ImportError:
import unittest

from neuroml.utils import component_factory
from neuroml.utils import (component_factory, print_hierarchy)
import neuroml


Expand Down Expand Up @@ -762,3 +762,18 @@ def test_morphinfo(self):

cell.morphinfo(True)
cell.biophysinfo()

def test_class_hierarchy(self):
"""Test the class hierarchy getter and printer
"""
hier = neuroml.Cell.get_class_hierarchy()
self.assertIsNotNone(hier)
print()
print_hierarchy(hier)

hier = neuroml.Morphology.get_class_hierarchy()
self.assertIsNotNone(hier)
print()
print(hier)
print()
print_hierarchy(hier)
17 changes: 14 additions & 3 deletions neuroml/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
Copyright 2023 NeuroML contributors
"""

import neuroml
from neuroml.utils import component_factory
import unittest
import tempfile

import neuroml
from neuroml.utils import (component_factory, get_relative_component_path,
print_hierarchy)


class UtilsTestCase(unittest.TestCase):
Expand Down Expand Up @@ -42,3 +43,13 @@ def test_component_factory_should_fail(self):
"IafCell",
id="test_cell",
)

def test_networkx_hier_graph(self):
"""Test constructing a networkx graph of a hierarchy"""
hier = neuroml.NeuroMLDocument.get_class_hierarchy()
self.assertIsNotNone(hier)
print_hierarchy(hier)

path, graph = get_relative_component_path("Input", "Instance")
self.assertIsNotNone(graph)
self.assertEqual(path, "../Population/Instance")
81 changes: 79 additions & 2 deletions neuroml/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
Utilities for checking generated code
"""
import sys
import inspect
import os
import sys
import warnings
from typing import Union, Any
from typing import Any, Dict, Union, Optional, Type

import networkx

import neuroml.nml.nml as schema

from . import loaders


Expand Down Expand Up @@ -228,6 +232,79 @@ def component_factory(
return new_obj


def print_hierarchy(tree, indent=4, current_ind=0):
"""Print the hierarchy tree generated by get_class_hierarchy
Reference: https://stackoverflow.com/a/75161393/375067
"""
for k, v in tree.items():
if current_ind:
before_dashes = current_ind - indent
print(' ' * before_dashes + '└' + '-'*(indent-1) + k)
else:
print(k)
for sub_tree in v:
print_hierarchy(sub_tree, indent=indent, current_ind=current_ind + indent)


def get_hier_graph_networkx(graph: networkx.DiGraph, hier: Dict[str, Any]):
"""Get a networkx graph of the NeuroML hierarchy
:param graph: graph object to populate
:param hier: component type hierarchy obtained from `get_class_hierarchy`
and `get_nml2_class_hierarchy` methods
:returns: None
"""
for k, vs in hier.items():
for v in vs:
if type(v) is dict:
graph.add_edge(k, list(v.keys())[0])
get_hier_graph_networkx(graph, v)
else:
graph.add_edge(k, v)


def get_relative_component_path(src: str, dest: str, root: Type =
schema.NeuroMLDocument, graph:
Optional[networkx.DiGraph] = None):
"""Construct a path from src component to dest in a neuroml document.
Useful when referring to components in other components
Note that
:param src: source component
:param dest: destination component
:param root: root component of the hierarchy
:param graph: a networkx digraph of the NeuroML hierarchy if available
if not, one is constructed
:returns: calculated path and networkx digraph for future use
"""
if graph is None:
graph = networkx.DiGraph()
get_hier_graph_networkx(graph, root.get_nml2_class_hierarchy())

p1 = (list(networkx.all_shortest_paths(graph, root.__name__, "Instance")))
p2 = (list(networkx.all_shortest_paths(graph, root.__name__, "Input")))

if len(p1) > 1 or len(p2) > 1:
print("Multiple paths found, cannot calculate recommended path")
print("Paths are:")
for p in p1 + p2:
print("/".join(p1[0]))
else:
p1s = "/".join(p1[0])
p2s = "/".join(p2[0])
print(f"Path1: {p1s}")
print(f"Path2: {p2s}")
# remove one "../" because we do not need to get to the common level
# here, unlike actual file system path traversal
path = os.path.relpath(p1s, p2s).replace("../", "", 1)
print("Relative path: " + path)

return (path, graph)


def main():
if len(sys.argv) != 2:
print("Please specify the name of the NeuroML2 file...")
Expand Down

0 comments on commit 81d4d5c

Please sign in to comment.