Skip to content

Commit

Permalink
add options to fetch triples/graph from an external KG (#14)
Browse files Browse the repository at this point in the history
* Modify mapping strategy to fetch graph from a KG
  • Loading branch information
Treesarj authored Jul 11, 2024
1 parent fbe1094 commit d5f5606
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 13 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ ci:
autoupdate_schedule: 'weekly'
skip: ["pylint", "pylint-tests"]
submodules: false

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
Expand Down
242 changes: 230 additions & 12 deletions oteapi_dlite/strategies/mapping.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
"""Mapping filter strategy."""

# pylint: disable=unused-argument,invalid-name
from __future__ import annotations

import logging

# pylint: disable=unused-argument,invalid-name,disable=line-too-long,E1133,W0511
from enum import Enum
from typing import TYPE_CHECKING, Annotated, Optional

import rdflib
from jinja2 import Template, TemplateError
from oteapi.models import AttrDict, MappingConfig
from pydantic import AnyUrl
from pydantic.dataclasses import Field, dataclass
from rdflib.exceptions import Error as RDFLibException
from SPARQLWrapper import JSON, SPARQLWrapper
from SPARQLWrapper.SPARQLExceptions import SPARQLWrapperException
from tripper import Triplestore

from oteapi_dlite.models import DLiteSessionUpdate
from oteapi_dlite.utils import get_collection, update_collection

if TYPE_CHECKING: # pragma: no cover
from typing import Any
logger = logging.getLogger(__name__)


class BackendEnum(str, Enum):
Expand Down Expand Up @@ -86,6 +96,18 @@ class DLiteMappingStrategyConfig(AttrDict):
)
),
] = None
graph_uri: Annotated[
Optional[str],
Field(
description=("The URI of the graph in which to perform the query")
),
] = None
sparql_endpoint: Annotated[
Optional[str],
Field(
description="Endpoint Url to create an instance of SPARQLWrapper configured for the target SPARQL service"
),
] = None


class DLiteMappingConfig(MappingConfig):
Expand Down Expand Up @@ -120,7 +142,7 @@ def initialize(self) -> DLiteSessionUpdate:
ts = Triplestore(
backend=self.mapping_config.configuration.backend,
base_iri=self.mapping_config.configuration.base_iri,
triplestore_url=self.mapping_config.configuration.triplestore_url, # pylint: disable=line-too-long
triplestore_url=self.mapping_config.configuration.triplestore_url,
database=self.mapping_config.configuration.database,
uname=self.mapping_config.configuration.username,
pwd=self.mapping_config.configuration.password,
Expand All @@ -131,18 +153,38 @@ def initialize(self) -> DLiteSessionUpdate:
if self.mapping_config.prefixes:
for prefix, iri in self.mapping_config.prefixes.items():
ts.bind(prefix, iri)

if self.mapping_config.triples:
ts.add_triples(
[
[
ts.expand_iri(t) if isinstance(t, str) else t
for t in triple
]
for triple in self.mapping_config.triples # pylint: disable=not-an-iterable
]
if (
self.mapping_config.configuration.sparql_endpoint
and self.mapping_config.configuration.graph_uri
):
config = self.mapping_config.configuration
sparql_instance = SPARQLWrapper(config.sparql_endpoint)
sparql_instance.setHTTPAuth("BASIC")
sparql_instance.setCredentials(
config.username,
config.password,
)
# extract class names i.e. objects from triples
class_names = [triple[2] for triple in self.mapping_config.triples]
# Find parent node of the class_names
parent_node: str | None = find_parent_node(
sparql_instance,
class_names,
config.graph_uri, # type:ignore
)
# If parent node exists, find the KG
if parent_node:
graph: rdflib.Graph = fetch_and_populate_graph(
sparql_instance,
config.graph_uri, # type:ignore
parent_node,
)
graph_triples = [(str(s), str(p), str(o)) for s, p, o in graph]
# Add triples to the collection
populate_triplestore(ts, graph_triples)

# Add triples to the collection
populate_triplestore(ts, self.mapping_config.triples)
update_collection(coll)
return DLiteSessionUpdate(collection_id=coll.uuid)

Expand All @@ -155,3 +197,179 @@ def get(self) -> DLiteSessionUpdate:
else get_collection().uuid
)
)


def populate_triplestore(ts: Triplestore, triples: list):
    """Add the given triples to the triplestore instance.

    String components of each triple are expanded to full IRIs via the
    triplestore's prefix mappings; non-string components (e.g. dlite/tripper
    term objects) are passed through unchanged.
    """
    expanded = []
    for triple in triples:  # pylint: disable=not-an-iterable
        expanded.append(
            [
                ts.expand_iri(part) if isinstance(part, str) else part
                for part in triple
            ]
        )
    ts.add_triples(expanded)


# TODO: import the below function from SOFT7 once it's available
def find_parent_node(
    sparql: SPARQLWrapper,
    class_names: list[str],
    graph_uri: str,
) -> str | None:
    """
    Queries a SPARQL endpoint to find a common parent node (LCA) for a given list of
    class URIs within a specified RDF graph.
    Args:
        sparql (SPARQLWrapper): An instance of SPARQLWrapper configured for the target
            SPARQL service.
        class_names (list[str]): The class URIs to find a common parent for.
            Duplicates are ignored; an empty list returns None immediately.
        graph_uri (str): The URI of the graph in which to perform the query.
    Returns:
        str | None: The URI of the common parent node if one exists, otherwise None.
    Raises:
        RuntimeError: If there is an error in executing or processing the SPARQL query
        or if there is an error in rendering the SPARQL query using Jinja2 templates.
    Note:
        This function assumes that the provided `sparql` instance is already configured
        with necessary authentication and format settings. The counting below also
        assumes the endpoint returns at most one binding row per (class, parent)
        pair — duplicate rows would inflate the counts. TODO confirm for the
        endpoints in use.
    """
    # Deduplicate (order-preserving): a repeated class URI yields only one
    # solution row per parent, so with duplicates `target_count` could never
    # be reached and a genuine common parent would be missed.
    unique_names = list(dict.fromkeys(class_names))
    if not unique_names:
        # An empty FILTER() is not valid SPARQL; nothing to search for.
        logger.info("No class names given; cannot find a common parent node.")
        return None

    try:
        template_str = """
        {% macro sparql_query(class_names, graph_uri) %}
        PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        SELECT ?parentClass
        WHERE {
        GRAPH <{{ graph_uri }}> {
            ?class rdfs:subClassOf* ?parentClass .
            FILTER(
            {% for class_name in class_names -%}
            ?class = <{{ class_name }}>{{ " ||" if not loop.last }}
            {% endfor %})
        }
        }
        {% endmacro %}
        """

        template = Template(template_str)
        query = template.module.sparql_query(unique_names, graph_uri)
        sparql.setReturnFormat(JSON)
        sparql.setQuery(query)

        # A parent is "common" when it appears once for every distinct class.
        target_count = len(unique_names)
        counts: dict[str, int] = {}

        results = sparql.query().convert()
        for result in results["results"]["bindings"]:
            parent_class = result["parentClass"]["value"]
            counts[parent_class] = counts.get(parent_class, 0) + 1
            if counts[parent_class] == target_count:
                return parent_class

    except SPARQLWrapperException as wrapper_error:
        raise RuntimeError(
            f"Failed to fetch or parse results: {wrapper_error}"
        ) from wrapper_error

    except TemplateError as template_error:
        raise RuntimeError(
            f"Jinja2 template error: {template_error}"
        ) from template_error

    logger.info("Could not find a common parent node.")
    return None


def _term_from_binding(binding: dict) -> rdflib.term.Node:
    """Convert one SPARQL JSON result binding into the matching rdflib term.

    The SPARQL 1.1 JSON results format tags each binding with a ``type`` of
    ``uri``, ``bnode``, ``literal`` or ``typed-literal``; literals may carry
    an ``xml:lang`` tag or a ``datatype`` IRI (never both at once).
    """
    term_type = binding["type"]
    if term_type == "uri":
        return rdflib.URIRef(binding["value"])
    if term_type == "bnode":
        return rdflib.BNode(binding["value"])
    # "literal" or (legacy) "typed-literal"
    if "xml:lang" in binding:
        return rdflib.Literal(binding["value"], lang=binding["xml:lang"])
    return rdflib.Literal(binding["value"], datatype=binding.get("datatype"))


# TODO: import the below function from SOFT7 once it's available
def fetch_and_populate_graph(
    sparql: SPARQLWrapper,
    graph_uri: str,
    parent_node: str,
    graph: Optional[rdflib.Graph] = None,
) -> rdflib.Graph:
    """
    Fetches RDF triples related to a specified parent node from a SPARQL endpoint and
    populates them into an RDF graph.
    Args:
        sparql (SPARQLWrapper): An instance of SPARQLWrapper configured for the target
            SPARQL service.
        graph_uri (str): The URI of the graph from which triples will be fetched.
        parent_node (str): The URI of the parent node to base the triple fetching on.
        graph (rdflib.Graph, optional): An instance of an RDFlib graph to populate with
            fetched triples.
            If `None`, a new empty graph is created. Defaults to `None`.
    Returns:
        rdflib.Graph: The graph populated with the fetched triples.
    Raises:
        RuntimeError: If processing the SPARQL query or building the RDF graph fails.
    Note:
        This function assumes that the provided `sparql` instance is already configured
        with necessary authentication and format settings.
    """
    # Create a new graph if one is not provided
    graph = graph or rdflib.Graph()

    try:
        sparql.setReturnFormat(JSON)

        query = f"""
        PREFIX owl: <http://www.w3.org/2002/07/owl#>
        PREFIX skos: <http://www.w3.org/2004/02/skos/core#>
        PREFIX rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        PREFIX fno: <https://w3id.org/function/ontology#>
        SELECT ?subject ?predicate ?object
        WHERE {{
            GRAPH <{graph_uri}> {{
                ?subject ?predicate ?object .
                ?subject rdfs:subClassOf* <{parent_node}> .
                FILTER (?predicate IN (
                    rdfs:subClassOf,
                    skos:prefLabel,
                    rdfs:subPropertyOf,
                    rdfs:domain,
                    rdfs:range,
                    rdf:type,
                    owl:propertyDisjointWith,
                    fno:expects,
                    fno:predicate,
                    fno:type,
                    fno:returns,
                    fno:executes))
            }}
        }}
        """
        sparql.setQuery(query)

        results = sparql.query().convert()
        for result in results["results"]["bindings"]:
            # Build terms according to their reported binding type: the query
            # fetches skos:prefLabel (and other) objects that are RDF
            # *literals*, so blindly wrapping every value in URIRef (as the
            # previous version did) would corrupt the graph with bogus URIs.
            graph.add(
                (
                    _term_from_binding(result["subject"]),
                    _term_from_binding(result["predicate"]),
                    _term_from_binding(result["object"]),
                )
            )

        logger.info("Graph populated with fetched triples.")

    except SPARQLWrapperException as wrapper_error:
        raise RuntimeError(
            f"Failed to fetch or parse results: {wrapper_error}"
        ) from wrapper_error

    except RDFLibException as rdflib_error:
        raise RuntimeError(
            f"Failed to build graph elements: {rdflib_error}"
        ) from rdflib_error

    return graph

0 comments on commit d5f5606

Please sign in to comment.