From 6a50abb23eab38c70c9cdb049f2b6b6abbad66a4 Mon Sep 17 00:00:00 2001 From: Jeffrey Newman Date: Thu, 7 Nov 2024 17:22:46 -0600 Subject: [PATCH] Estimation Pydantic (#2) * pydantic for estimation settings * allow df as type in config * fix table_info * repair for Pydantic * df is attribute --- activitysim/core/estimation.py | 148 +++++++++++++++++++++++---------- 1 file changed, 103 insertions(+), 45 deletions(-) diff --git a/activitysim/core/estimation.py b/activitysim/core/estimation.py index 43b677986..bb1f069e4 100644 --- a/activitysim/core/estimation.py +++ b/activitysim/core/estimation.py @@ -6,11 +6,13 @@ import os import shutil from pathlib import Path +from typing import Literal import pandas as pd -import yaml +from pydantic import model_validator from activitysim.core import simulate, workflow +from activitysim.core.configuration import PydanticReadable from activitysim.core.configuration.base import PydanticBase from activitysim.core.util import reindex from activitysim.core.yaml_tools import safe_dump @@ -48,14 +50,77 @@ def estimation_enabled(state): return settings is not None +class SurveyTableConfig(PydanticBase): + file_name: str + index_col: str + + # The dataframe is stored in the loaded config dynamically but not given + # directly in the config file, as it's not a simple serializable object that + # can be written in a YAML file. + class Config: + arbitrary_types_allowed = True + + df: pd.DataFrame | None = None + + +class EstimationTableRecipeConfig(PydanticBase): + omnibus_tables: dict[str, list[str]] + omnibus_tables_append_columns: list[str] + + +class EstimationConfig(PydanticReadable): + SKIP_BUNDLE_WRITE_FOR: list[str] = [] + EDB_FILETYPE: Literal["csv", "parquet", "pkl"] = "csv" + EDB_ALTS_FILE_FORMAT: Literal["verbose", "compact"] = "verbose" + + enable: bool = False + """Flag to enable estimation.""" + + bundles: list[str] = [] + """List of component names to create EDBs for.""" + + model_estimation_table_types: dict[str, str] = {} + """Mapping of component names to estimation table types. + + The keys of this mapping are the model component names, and the values are the + names of the estimation table recipes that should be used to generate the + estimation tables for the model component. The recipes are generally related + to the generic model types, such as 'simple_simulate', 'interaction_simulate', + 'interaction_sample_simulate', etc. + """ + + estimation_table_recipes: dict[str, EstimationTableRecipeConfig] = {} + """Mapping of estimation table recipe names to their configurations. + + The keys of this mapping are the names of the estimation table recipes. + The recipes are generally related to the generic model types, such as + 'simple_simulate', 'interaction_simulate', 'interaction_sample_simulate', + etc. The values are the configurations for the estimation table recipes. + """ + + survey_tables: dict[str, SurveyTableConfig] = {} + + # pydantic class validator to ensure that the model_estimation_table_types + # dictionary is a valid dictionary with string keys and string values, and + # that all the values are in the estimation_table_recipes dictionary + @model_validator(mode="after") + def validate_model_estimation_table_types(self): + for key, value in self.model_estimation_table_types.items(): + if value not in self.estimation_table_recipes: + raise ValueError( + f"model_estimation_table_types value '{value}' not in estimation_table_recipes" + ) + return self + + class Estimator: def __init__( self, state: workflow.State, - bundle_name, - model_name, - estimation_table_recipes, - settings, + bundle_name: str, + model_name: str, + estimation_table_recipe: EstimationTableRecipeConfig, + settings: EstimationConfig, ): logger.info("Initialize Estimator for'%s'" % (model_name,)) @@ -63,7 +128,7 @@ def __init__( self.bundle_name = bundle_name self.model_name = model_name self.settings_name = model_name - self.estimation_table_recipes = estimation_table_recipes + self.estimation_table_recipe = estimation_table_recipe self.estimating = True self.settings = settings @@ -84,10 +149,10 @@ def __init__( # assert 'override_choices' in self.model_settings, \ # "override_choices not found for %s in %s." % (model_name, ESTIMATION_SETTINGS_FILE_NAME) - self.omnibus_tables = self.estimation_table_recipes["omnibus_tables"] - self.omnibus_tables_append_columns = self.estimation_table_recipes[ - "omnibus_tables_append_columns" - ] + self.omnibus_tables = self.estimation_table_recipe.omnibus_tables + self.omnibus_tables_append_columns = ( + self.estimation_table_recipe.omnibus_tables_append_columns + ) self.tables = {} self.tables_to_cache = [ table_name @@ -345,7 +410,7 @@ def write_omnibus_table(self): if len(self.omnibus_tables) == 0: return - edbs_to_skip = self.settings.get("SKIP_BUNDLE_WRITE_FOR", []) + edbs_to_skip = self.settings.SKIP_BUNDLE_WRITE_FOR if self.bundle_name in edbs_to_skip: self.debug(f"Skipping write to disk for {self.bundle_name}") return @@ -376,7 +441,7 @@ def write_omnibus_table(self): self.debug(f"sorting tables: {table_names}") df.sort_index(ascending=True, inplace=True, kind="mergesort") - filetype = self.settings.get("EDB_FILETYPE", "csv") + filetype = self.settings.EDB_FILETYPE if filetype == "csv": file_path = self.output_file_path(omnibus_table, "csv") @@ -448,7 +513,7 @@ def write_coefficients_template(self, model_settings): assert self.estimating if isinstance(model_settings, PydanticBase): - model_settings = model_settings.dict() + model_settings = model_settings.model_dump() coefficients_df = simulate.read_model_coefficient_template( self.state.filesystem, model_settings ) @@ -460,7 +525,7 @@ def write_choosers(self, choosers_df): choosers_df, "choosers", append=True, - filetype=self.settings.get("EDB_FILETYPE", "csv"), + filetype=self.settings.EDB_FILETYPE, ) def write_choices(self, choices): @@ -471,7 +536,7 @@ def write_choices(self, choices): choices, "choices", append=True, - filetype=self.settings.get("EDB_FILETYPE", "csv"), + filetype=self.settings.EDB_FILETYPE, ) def write_override_choices(self, choices): @@ -482,7 +547,7 @@ def write_override_choices(self, choices): choices, "override_choices", append=True, - filetype=self.settings.get("EDB_FILETYPE", "csv"), + filetype=self.settings.EDB_FILETYPE, ) def write_constants(self, constants): @@ -521,7 +586,7 @@ def write_model_settings( ) assert not os.path.isfile(file_path) with open(file_path, "w") as f: - safe_dump(model_settings.dict(), f) + safe_dump(model_settings.model_dump(), f) else: if "include_settings" in model_settings: file_path = self.output_file_path( @@ -582,7 +647,7 @@ def melt_alternatives(self, df): # 31153,2,util_dist_0_1,1.0 # 31153,3,util_dist_0_1,1.0 - output_format = self.settings.get("EDB_ALTS_FILE_FORMAT", "verbose") + output_format = self.settings.EDB_ALTS_FILE_FORMAT assert output_format in ["verbose", "compact"] if output_format == "compact": @@ -613,7 +678,7 @@ def write_interaction_expression_values(self, df): df, "interaction_expression_values", append=True, - filetype=self.settings.get("EDB_FILETYPE", "csv"), + filetype=self.settings.EDB_FILETYPE, ) def write_expression_values(self, df): @@ -621,7 +686,7 @@ def write_expression_values(self, df): df, "expression_values", append=True, - filetype=self.settings.get("EDB_FILETYPE", "csv"), + filetype=self.settings.EDB_FILETYPE, ) def write_alternatives(self, alternatives_df, bundle_directory=False): @@ -638,7 +703,7 @@ def write_interaction_sample_alternatives(self, alternatives_df): alternatives_df, "interaction_sample_alternatives", append=True, - filetype=self.settings.get("EDB_FILETYPE", "csv"), + filetype=self.settings.EDB_FILETYPE, ) def write_interaction_simulate_alternatives(self, interaction_df): @@ -647,7 +712,7 @@ def write_interaction_simulate_alternatives(self, interaction_df): interaction_df, "interaction_simulate_alternatives", append=True, - filetype=self.settings.get("EDB_FILETYPE", "csv"), + filetype=self.settings.EDB_FILETYPE, ) def get_survey_values(self, model_values, table_name, column_names): @@ -679,8 +744,8 @@ class EstimationManager(object): def __init__(self): self.settings_initialized = False self.bundles = [] - self.estimation_table_recipes = {} - self.model_estimation_table_types = {} + self.estimation_table_recipes: dict[str, EstimationTableRecipeConfig] = {} + self.model_estimation_table_types: dict[str, str] = {} self.estimating = {} self.settings = None self.enabled = False @@ -691,40 +756,33 @@ def initialize_settings(self, state): return assert not self.settings_initialized - self.settings = state.filesystem.read_model_settings( - ESTIMATION_SETTINGS_FILE_NAME, mandatory=False + self.settings = EstimationConfig.read_settings_file( + state.filesystem, ESTIMATION_SETTINGS_FILE_NAME, mandatory=False ) if not self.settings: # if the model self.settings file is not found, we are not in estimation mode. self.enabled = False else: - self.enabled = self.settings.get("enable", "True") - self.bundles = self.settings.get("bundles", []) + self.enabled = self.settings.enable + self.bundles = self.settings.bundles - self.model_estimation_table_types = self.settings.get( - "model_estimation_table_types", {} - ) - self.estimation_table_recipes = self.settings.get( - "estimation_table_recipes", {} - ) + self.model_estimation_table_types = self.settings.model_estimation_table_types + self.estimation_table_recipes = self.settings.estimation_table_recipes if self.enabled: - self.survey_tables = self.settings.get("survey_tables", {}) + self.survey_tables = self.settings.survey_tables for table_name, table_info in self.survey_tables.items(): assert ( - "file_name" in table_info - ), "No file name specified for survey_table '%s' in %s" % ( - table_name, - ESTIMATION_SETTINGS_FILE_NAME, - ) + table_info.file_name + ), f"No file name specified for survey_table '{table_name}' in {ESTIMATION_SETTINGS_FILE_NAME}" file_path = state.filesystem.get_data_file_path( - table_info["file_name"], mandatory=True + table_info.file_name, mandatory=True ) assert os.path.exists( file_path ), "File for survey table '%s' not found: %s" % (table_name, file_path) df = pd.read_csv(file_path) - index_col = table_info.get("index_col") + index_col = table_info.index_col if index_col is not None: assert ( index_col in df.columns @@ -744,7 +802,7 @@ def initialize_settings(self, state): df = df[df.household_id.isin(pipeline_hh_ids)] # add the table df to survey_tables - table_info["df"] = df + table_info.df = df self.settings_initialized = True @@ -806,7 +864,7 @@ def begin_estimation( state, bundle_name, model_name, - estimation_table_recipes=self.estimation_table_recipes[ + estimation_table_recipe=self.estimation_table_recipes[ model_estimation_table_type ], settings=self.settings, @@ -824,7 +882,7 @@ def get_survey_table(self, table_name): "EstimationManager. get_survey_table: survey table '%s' not in survey_tables" % table_name ) - df = self.survey_tables[table_name].get("df") + df = self.survey_tables[table_name].df return df def get_survey_values(self, model_values, table_name, column_names):