diff --git a/activitysim/abm/models/disaggregate_accessibility.py b/activitysim/abm/models/disaggregate_accessibility.py index ab4f9acef..8d1102743 100644 --- a/activitysim/abm/models/disaggregate_accessibility.py +++ b/activitysim/abm/models/disaggregate_accessibility.py @@ -154,6 +154,15 @@ class DisaggregateAccessibilitySettings(PydanticReadable, extra="forbid"): procedure work. """ + KEEP_COLS: list[str] | None = None + """ + Disaggreate accessibility table is grouped by the "by" cols above and the KEEP_COLS are averaged + across the group. Initializing the below as NA if not in the auto ownership level, they are skipped + in the groupby mean and the values are correct. + (It's a way to avoid having to update code to reshape the table and introduce new functionality there.) + If none, will keep all of the columns with "accessibility" in the name. + """ + FROM_TEMPLATES: bool = False annotate_proto_tables: list[DisaggregateAccessibilityAnnotateSettings] = [] """ @@ -164,6 +173,11 @@ class DisaggregateAccessibilitySettings(PydanticReadable, extra="forbid"): """ NEAREST_METHOD: str = "skims" + postprocess_proto_tables: list[DisaggregateAccessibilityAnnotateSettings] = [] + """ + List of preprocessor settings to apply to the proto-population tables after generation. + """ + def read_disaggregate_accessibility_yaml( state: workflow.State, file_name @@ -846,6 +860,10 @@ def compute_disaggregate_accessibility( state.tracing.register_traceable_table(tablename, df) del df + disagg_model_settings = read_disaggregate_accessibility_yaml( + state, "disaggregate_accessibility.yaml" + ) + # Run location choice logsums = get_disaggregate_logsums( state, @@ -906,4 +924,23 @@ def compute_disaggregate_accessibility( for k, df in logsums.items(): state.add_table(k, df) + # available post-processing + for annotations in disagg_model_settings.postprocess_proto_tables: + tablename = annotations.tablename + df = state.get_dataframe(tablename) + assert df is not None + assert annotations is not None + assign_columns( + state, + df=df, + model_settings={ + **annotations.annotate.dict(), + **disagg_model_settings.suffixes.dict(), + }, + trace_label=tracing.extend_trace_label( + "disaggregate_accessibility.postprocess", tablename + ), + ) + state.add_table(tablename, df) + return diff --git a/activitysim/abm/tables/disaggregate_accessibility.py b/activitysim/abm/tables/disaggregate_accessibility.py index 8ab0e0820..7828e1c4c 100644 --- a/activitysim/abm/tables/disaggregate_accessibility.py +++ b/activitysim/abm/tables/disaggregate_accessibility.py @@ -107,7 +107,7 @@ def maz_centroids(state: workflow.State): @workflow.table -def proto_disaggregate_accessibility(state: workflow.State): +def proto_disaggregate_accessibility(state: workflow.State) -> pd.DataFrame: # Read existing accessibilities, but is not required to enable model compatibility df = input.read_input_table( state, "proto_disaggregate_accessibility", required=False @@ -130,7 +130,7 @@ def proto_disaggregate_accessibility(state: workflow.State): @workflow.table -def disaggregate_accessibility(state: workflow.State): +def disaggregate_accessibility(state: workflow.State) -> pd.DataFrame: """ This step initializes pre-computed disaggregate accessibility and merges it onto the full synthetic population. Function adds merged all disaggregate accessibility tables to the pipeline but returns nothing. @@ -169,17 +169,17 @@ def disaggregate_accessibility(state: workflow.State): ) merging_params = model_settings.MERGE_ON nearest_method = model_settings.NEAREST_METHOD - accessibility_cols = [ - x for x in proto_accessibility_df.columns if "accessibility" in x - ] + + if model_settings.KEEP_COLS is None: + keep_cols = [x for x in proto_accessibility_df.columns if "accessibility" in x] + else: + keep_cols = model_settings.KEEP_COLS # Parse the merging parameters assert merging_params is not None # Check if already assigned! - if set(accessibility_cols).intersection(persons_merged_df.columns) == set( - accessibility_cols - ): + if set(keep_cols).intersection(persons_merged_df.columns) == set(keep_cols): return # Find the nearest zone (spatially) with accessibilities calculated @@ -211,7 +211,7 @@ def disaggregate_accessibility(state: workflow.State): # because it will get slightly different logsums for households in the same zone. # This is because different destination zones were selected. To resolve, get mean by cols. right_df = ( - proto_accessibility_df.groupby(merge_cols)[accessibility_cols] + proto_accessibility_df.groupby(merge_cols)[keep_cols] .mean() .sort_values(nearest_cols) .reset_index() @@ -244,9 +244,9 @@ def disaggregate_accessibility(state: workflow.State): ) # Predict the nearest person ID and pull the logsums - matched_logsums_df = right_df.loc[clf.predict(x_pop)][ - accessibility_cols - ].reset_index(drop=True) + matched_logsums_df = right_df.loc[clf.predict(x_pop)][keep_cols].reset_index( + drop=True + ) merge_df = pd.concat( [left_df.reset_index(drop=False), matched_logsums_df], axis=1 ).set_index("person_id") @@ -278,9 +278,9 @@ def disaggregate_accessibility(state: workflow.State): # Check that it was correctly left-joined assert all(persons_merged_df[merge_cols] == merge_df[merge_cols]) - assert any(merge_df[accessibility_cols].isnull()) + assert any(merge_df[keep_cols].isnull()) # Inject merged accessibilities so that it can be included in persons_merged function - state.add_table("disaggregate_accessibility", merge_df[accessibility_cols]) + state.add_table("disaggregate_accessibility", merge_df[keep_cols]) - return merge_df[accessibility_cols] + return merge_df[keep_cols]