Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get_ancestors and get_descendants have the same arguments. #572

Merged
merged 10 commits into from
May 12, 2023
12 changes: 10 additions & 2 deletions ontopy/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,9 +400,17 @@ def add_branch( # pylint: disable=too-many-arguments,too-many-locals
nodeattrs=nodeattrs,
**attrs,
)

common_ancestors = False
ancestor_generations = None
if include_parents in ("common", "closest"):
common_ancestors = True
elif isinstance(include_parents, int):
ancestor_generations = include_parents
francescalb marked this conversation as resolved.
Show resolved Hide resolved
parents = self.ontology.get_ancestors(
classes, include=include_parents, strict=True
classes,
common=common_ancestors,
generations=ancestor_generations,
strict=True,
)
if parents:
for parent in parents:
Expand Down
58 changes: 33 additions & 25 deletions ontopy/ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
If desirable some of these additions may be moved back into owlready2.
"""
# pylint: disable=too-many-lines,fixme,arguments-differ,protected-access
from typing import TYPE_CHECKING, Optional, Union, Sequence
from typing import TYPE_CHECKING, Optional, Union
import os
import itertools
import inspect
Expand Down Expand Up @@ -1468,18 +1468,25 @@ def closest_common_ancestor(*classes):
"A closest common ancestor should always exist !"
)

def get_ancestors(self, classes, include="all", strict=True):
def get_ancestors(
self,
classes: "Union[List, ThingClass]",
common: bool = False,
generations: int = None,
strict: bool = True,
) -> set:
"""Return ancestors of all classes in `classes`.
classes to be provided as list

The values of `include` may be:
- None: ignore this argument
- "all": Include all ancestors.
- "closest": Include all ancestors up to the closest common
ancestor of all classes.
- int: Include this number of ancestor levels. Here `include`
may be an integer or a string that can be converted to int.
Args:
francescalb marked this conversation as resolved.
Show resolved Hide resolved
classes: class(es) for which ancestors are desired.
common: whether to only return the closest common ancestor.
francescalb marked this conversation as resolved.
Show resolved Hide resolved
generations: Include this number of generations, default is all.
strict: only return real ancestors if True.
jesper-friis marked this conversation as resolved.
Show resolved Hide resolved
Returns:
A set of ancestors for given number of generations.
"""
if not isinstance(classes, Iterable):
classes = [classes]

ancestors = set()
if not classes:
return ancestors
Expand All @@ -1490,22 +1497,23 @@ def addancestors(entity, counter, subject):
subject.add(parent)
addancestors(parent, counter - 1, subject)

if isinstance(include, str) and include.isdigit():
include = int(include)

if include == "all":
ancestors.update(*(_.ancestors() for _ in classes))
elif include == "closest":
closest = self.closest_common_ancestor(*classes)
if sum(map(bool, [common, generations])) > 1:
raise ValueError(
"Only one of `generations` or `common` may be specified."
)
if common:
closest_ancestor = self.closest_common_ancestor(*classes)
for cls in classes:
ancestors.update(
_ for _ in cls.ancestors() if closest in _.ancestors()
_
francescalb marked this conversation as resolved.
Show resolved Hide resolved
for _ in cls.ancestors()
if closest_ancestor in _.ancestors()
)
elif isinstance(include, int):
elif isinstance(generations, int):
for entity in classes:
addancestors(entity, int(include), ancestors)
elif include not in (None, "None", "none", ""):
raise ValueError('include must be "all", "closest" or None')
addancestors(entity, generations, ancestors)
else:
ancestors.update(*(_.ancestors() for _ in classes))
francescalb marked this conversation as resolved.
Show resolved Hide resolved

if strict:
return ancestors.difference(classes)
Expand All @@ -1519,7 +1527,7 @@ def get_descendants(
) -> set:
"""Return descendants/subclasses of all classes in `classes`.
Args:
classes: to be provided as list.
classes: class(es) for which descendants are desired.
common: whether to only return descendants common to all classes.
generations: Include this number of generations, default is all.
Returns:
Expand All @@ -1529,7 +1537,7 @@ def get_descendants(
'generations' defaults to all.
"""

if not isinstance(classes, Sequence):
if not isinstance(classes, Iterable):
francescalb marked this conversation as resolved.
Show resolved Hide resolved
classes = [classes]

descendants = {name: [] for name in classes}
Expand Down
67 changes: 67 additions & 0 deletions tests/test_generation_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,70 @@ def test_descendants(emmo: "Ontology", repo_dir: "Path") -> None:
assert onto.get_descendants([onto.Tree, onto.NaturalDye], common=True) == {
onto.Avocado
}


def test_ancestors(emmo: "Ontology", repo_dir: "Path") -> None:
francescalb marked this conversation as resolved.
Show resolved Hide resolved
from ontopy import get_ontology
from ontopy.utils import LabelDefinitionError

ontopath = repo_dir / "tests" / "testonto" / "testontology.ttl"

onto = get_ontology(ontopath).load()

# Test that default gives all ancestors.
assert onto.get_ancestors(onto.NorwaySpruce) == {
onto.Spruce,
onto.Tree,
onto.EvergreenTree,
onto.Thing,
}

# Test that asking for 0 generations returns empty set
assert onto.get_ancestors(onto.NorwaySpruce, generations=0) == set()

# Check that number of generations are returned correctly
assert onto.get_ancestors(onto.NorwaySpruce, generations=2) == {
onto.Spruce,
onto.EvergreenTree,
}

assert onto.get_ancestors(onto.NorwaySpruce, generations=1) == {
onto.Spruce,
}
# Check that no error is generated if one of the classes do
# not have enough parents for all given generations
assert onto.get_ancestors(onto.NorwaySpruce, generations=10) == (
onto.get_ancestors(onto.NorwaySpruce)
)

# Check that ancestors of a list is returned correctly
assert onto.get_ancestors([onto.NorwaySpruce, onto.Avocado]) == {
onto.Tree,
onto.EvergreenTree,
onto.Spruce,
onto.NaturalDye,
onto.Thing,
}
# Check that classes up to closest common ancestor are returned

assert onto.get_ancestors(
[onto.NorwaySpruce, onto.Avocado], common=True
) == {
onto.EvergreenTree,
onto.Spruce,
}

with pytest.raises(ValueError):
onto.get_ancestors(onto.NorwaySpruce, common=True, generations=4)

# Test strict == False
assert onto.get_ancestors(
[onto.NorwaySpruce, onto.Avocado],
common=True,
strict=False,
) == {
onto.EvergreenTree,
onto.Spruce,
onto.NorwaySpruce,
onto.Avocado,
}