diff --git a/bddl/knowledge_base/models.py b/bddl/knowledge_base/models.py index ab067ace..edc87008 100644 --- a/bddl/knowledge_base/models.py +++ b/bddl/knowledge_base/models.py @@ -357,23 +357,29 @@ def transition_subgraph(self): return nx.relabel_nodes(G, lambda x: (x.name if isinstance(x, Synset) else f'recipe: {x.name}'), copy=True) def is_produceable_from(self, synsets): - # If it's already available, then we're good. - if self in synsets: - return True, set() - - # Otherwise, are there any recipes that I can use to obtain it? - recipe_alternatives = set() - for recipe in self.produced_by_transition_rules: - producabilities_and_recipe_sets = [ingredient.is_produceable_from(synsets) for ingredient in recipe.input_synsets] - producabilities, recipe_sets = zip(*producabilities_and_recipe_sets) - if all(producabilities): - recipe_alternatives.add(recipe) - recipe_alternatives.update(ingredient_recipe for recipe_set in recipe_sets for ingredient_recipe in recipe_set) + def _is_produceable_from(_self, _synsets, _seen): + if _self in _seen: + return False, set() + + # If it's already available, then we're good. + if _self in _synsets: + return True, set() + + # Otherwise, are there any recipes that I can use to obtain it? + recipe_alternatives = set() + for recipe in _self.produced_by_transition_rules: + producabilities_and_recipe_sets = [_is_produceable_from(ingredient, _synsets, _seen | {self}) for ingredient in recipe.input_synsets] + producabilities, recipe_sets = zip(*producabilities_and_recipe_sets) + if all(producabilities): + recipe_alternatives.add(recipe) + recipe_alternatives.update(ingredient_recipe for recipe_set in recipe_sets for ingredient_recipe in recipe_set) + + if not recipe_alternatives: + return False, set() - if not recipe_alternatives: - return False, set() + return True, recipe_alternatives - return True, recipe_alternatives + return _is_produceable_from(self, synsets, set()) @dataclass(eq=False, order=False) diff --git a/bddl/knowledge_base/processing.py b/bddl/knowledge_base/processing.py index 9c699ffe..256c89b5 100644 --- a/bddl/knowledge_base/processing.py +++ b/bddl/knowledge_base/processing.py @@ -309,16 +309,24 @@ def create_transitions(self): json_paths = glob.glob(str(GENERATED_DATA_DIR / "transition_map/tm_jsons/*.json")) transitions = [] for jp in json_paths: + if "washer_" in jp: + continue with open(jp) as f: transitions.extend(json.load(f)) # Create the transition objects for transition_data in self.tqdm(transitions): - transition = TransitionRule.create(name=transition_data["rule_name"]) - for synset_name in transition_data["input_objects"].keys(): + rule_name = transition_data["rule_name"] + transition = TransitionRule.create(name=rule_name) + inputs = set(transition_data["input_objects"].keys()) + assert inputs, f"Transition {transition.name} has no inputs!" + outputs = set(transition_data["output_objects"].keys()) + assert outputs, f"Transition {transition.name} has no outputs!" + assert inputs & outputs == set(), f"Inputs and outputs of {transition.name} overlap!" + for synset_name in inputs: synset = Synset.get(name=synset_name) transition.input_synsets.add(synset) - for synset_name in transition_data["output_objects"].keys(): + for synset_name in outputs: synset = Synset.get(name=synset_name) transition.output_synsets.add(synset)