From f858d750e2bdae54ad15c00a828460d5b9484b5f Mon Sep 17 00:00:00 2001 From: Chengshu Li Date: Fri, 10 May 2024 18:43:09 -0700 Subject: [PATCH] allow refreshing for taxonomy files for sanity check --- bddl/data_generation/sanity_check.py | 5 +++-- bddl/object_taxonomy.py | 8 ++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/bddl/data_generation/sanity_check.py b/bddl/data_generation/sanity_check.py index 091ce388..ad511567 100644 --- a/bddl/data_generation/sanity_check.py +++ b/bddl/data_generation/sanity_check.py @@ -2,6 +2,7 @@ import json import os from collections.abc import Iterable +from bddl.object_taxonomy import ObjectTaxonomy TRANSITION_RULE_FOLDER = pathlib.Path(__file__).parents[1] / "generated_data" / "transition_map" / "tm_jsons" SYNSET_KEYS = ["machine", "container", "washed_item", "heat_source", "input_synsets", "output_synsets"] @@ -63,9 +64,9 @@ def sanity_check_transition_rules_washer(object_taxonomy): assert cleanser_synset in leaf_synsets, f"In washer transition rule, {cleanser_synset} is not a leaf synset." def sanity_check(): - # Lazy import so that it can use the latest version of output_hierarchy_properties.json - from bddl.object_taxonomy import ObjectTaxonomy object_taxonomy = ObjectTaxonomy() + # Needs to refresh to use the latest version of the taxonomy + object_taxonomy.refresh_hierarchy_file() sanity_check_object_hierarchy(object_taxonomy) sanity_check_transition_rules(object_taxonomy) sanity_check_transition_rules_washer(object_taxonomy) diff --git a/bddl/object_taxonomy.py b/bddl/object_taxonomy.py index 403dc241..f5872f1f 100644 --- a/bddl/object_taxonomy.py +++ b/bddl/object_taxonomy.py @@ -11,8 +11,12 @@ class ObjectTaxonomy(object): def __init__(self, hierarchy_type="default"): - hierarchy_file = DEFAULT_HIERARCHY_FILE - self.taxonomy = self._parse_taxonomy(hierarchy_file) + self.taxonomy = self._parse_taxonomy(DEFAULT_HIERARCHY_FILE) + + def refresh_hierarchy_file(self): + DEFAULT_HIERARCHY_FILE = pkgutil.get_data( + bddl.__package__, "generated_data/output_hierarchy_properties.json") + self.taxonomy = self._parse_taxonomy(DEFAULT_HIERARCHY_FILE) @staticmethod def _parse_taxonomy(json_str):