diff --git a/scripts/benchmark_migration.py b/scripts/benchmark_migration.py index d923ff4ab8bff..72cbc68d2ab90 100644 --- a/scripts/benchmark_migration.py +++ b/scripts/benchmark_migration.py @@ -94,11 +94,22 @@ def find_models(module: ModuleType) -> List[Type[Model]]: engine = create_engine(sqlalchemy_uri) Base = automap_base() Base.prepare(engine, reflect=True) - for table in tables: + seen = set() + while tables: + table = tables.pop() + seen.add(table) model = getattr(Base.classes, table) model.__tablename__ = table models.append(model) + # add other models referenced in foreign keys + inspector = inspect(model) + for column in inspector.columns.values(): + for foreign_key in column.foreign_keys: + table = foreign_key.column.table.name + if table not in seen: + tables.add(table) + # sort topologically so we can create entities in order and # maintain relationships (eg, create a database before creating # a slice) @@ -108,7 +119,8 @@ def find_models(module: ModuleType) -> List[Type[Model]]: dependent_tables: List[str] = [] for column in inspector.columns.values(): for foreign_key in column.foreign_keys: - dependent_tables.append(foreign_key.target_fullname.split(".")[0]) + if foreign_key.column.table.name != model.__tablename__: + dependent_tables.append(foreign_key.column.table.name) sorter.add(model.__tablename__, *dependent_tables) order = list(sorter.static_order()) models.sort(key=lambda model: order.index(model.__tablename__))