diff --git a/.github/scripts/check_lineage_updates.sh b/.github/scripts/check_lineage_updates.sh old mode 100755 new mode 100644 index d2299e315..06aa3ab39 --- a/.github/scripts/check_lineage_updates.sh +++ b/.github/scripts/check_lineage_updates.sh @@ -1,176 +1,32 @@ #!/bin/bash +set -e -if [ "$#" -ne 4 ]; then - echo "Usage: $0 " +if [ "$#" -ne 2 ]; then + echo "Usage: $0 " exit 1 fi old_branch="$1" new_branch="$2" -output_path="$3" -run_id="$4" -cd straxen -git checkout "$old_branch" - -get_old_plugins=$(cat<<'EOF' -import strax, straxen -import numpy as np -import sys -import json - -output_path = sys.argv[3] - -st = straxen.contexts.xenonnt_online(output_folder=output_path) -st.storage.append(strax.DataDirectory(output_path, readonly = True)) - -version_hash_dict_old = straxen.test_utils.version_hash_dict(st) -with open("version_hash_dict_old.json", 'w') as jsonfile: - json.dump(version_hash_dict_old, jsonfile) -EOF -) - -python3 -c "$get_old_plugins" "$@" - - -git checkout "$new_branch" - -get_new_plugins=$(cat<<'EOF' -import strax, straxen -import numpy as np -import sys -import json - -output_path = sys.argv[3] - -run_id = sys.argv[4] - -st = straxen.contexts.xenonnt_online(output_folder=output_path) -st.storage.append(strax.DataDirectory(output_path, readonly = True)) - -#Make the new version-hash dictionary -version_hash_dict_new = straxen.test_utils.version_hash_dict(st) -with open("version_hash_dict_new.json", 'w') as jsonfile: - json.dump(version_hash_dict_new, jsonfile) - -#... and open the old one -with open("version_hash_dict_old.json", 'r') as jsonfile: - version_hash_dict_old = json.load(jsonfile) - -#... to compare the changes -updated_plugins_dict = straxen.test_utils.updated_plugins(version_hash_dict_old, version_hash_dict_new) -with open("plugin_update_comparison.json", 'w') as jsonfile: - json.dump(updated_plugins_dict, jsonfile) - -#Print the deleted plugins: -print("\nDeleted plugins:") -for p in updated_plugins_dict['deleted']: - print(f" - {p}") -print('\n') - -#Now print the info for the added plugins -bad_field_info_added = straxen.test_utils.bad_field_info(st, run_id, updated_plugins_dict['added']) - -for p in bad_field_info_added: - print(f"\nNew plugin '{p}' has the following bad field fractions:") - for c in bad_field_info_added[p]: - if bad_field_info_added[p][c]>0: - #Don't print the mean values (unless the column name literally starts with mean) - if (not c.startswith('mean')) or (c.startswith('mean_mean')): - print(f" - {c}: {bad_field_info_added[p][c]}") -print('\n') - -##################################### Comparing differences to changed plugins ##################################### - -lowest_level_changed_plugins = straxen.test_utils.lowest_level_plugins(st, updated_plugins_dict['changed']) - -#See the nan field fractions + mean of each field -new_changed_plugin_bad_info = straxen.test_utils.bad_field_info(st, run_id, lowest_level_changed_plugins) - -with open("new_changed_plugin_bad_info.json", 'w') as jsonfile: - json.dump(new_changed_plugin_bad_info, jsonfile) -print("Finish writing to file") -#affected means the plugins which directly depend on the lowest level changed plugins -affected_changed_plugins = straxen.test_utils.directly_depends_on(st, - lowest_level_changed_plugins, - updated_plugins_dict['changed']) - -new_affected_plugin_bad_info = straxen.test_utils.bad_field_info(st, run_id, affected_changed_plugins) -with open("new_affected_plugin_bad_info.json", 'w') as jsonfile: - json.dump(new_affected_plugin_bad_info, jsonfile) -EOF -) - -python3 -c "$get_new_plugins" "$@" - -git checkout "$old_branch" - -compare_plugins=$(cat<<'EOF' -import strax, straxen -import numpy as np -import sys -import json - -output_path = sys.argv[3] - -run_id = sys.argv[4] -#'025423' - -st = straxen.contexts.xenonnt_online(output_folder=output_path) -st.storage.append(strax.DataDirectory(output_path, readonly = True)) - -#Load in the updated plugins dict -with open("plugin_update_comparison.json", 'r') as jsonfile: - updated_plugins_dict = json.load(jsonfile) - -#Load the bad fields info of the newly changed plugins (remember, we're back to the old branch) -with open("new_changed_plugin_bad_info.json", 'r') as jsonfile: - new_changed_plugin_bad_info = json.load(jsonfile) - -#Load the affected plugins bad info -with open("new_affected_plugin_bad_info.json", 'r') as jsonfile: - new_affected_plugin_bad_info = json.load(jsonfile) - -##################### Now compute the same for the old version of the plugins ##################### -old_changed_plugin_bad_info = straxen.test_utils.bad_field_info(st, run_id, list(new_changed_plugin_bad_info.keys())) -old_affected_plugin_bad_info = straxen.test_utils.bad_field_info(st, run_id, list(new_affected_plugin_bad_info.keys())) - -###Now report the differences -#Lowest level plugins -all_plugin_change_info = {"Lowest Levels":{'old':old_changed_plugin_bad_info, - 'new':new_changed_plugin_bad_info}, - "Affected":{'old':old_affected_plugin_bad_info, - 'new':new_affected_plugin_bad_info}} - -for level in ['Lowest Levels', 'Affected']: - print(f"#################### {level} Plugins ####################") - - for p in all_plugin_change_info[level]['old']: - print(f"Change report for '{p}':") - data_types_old = np.array(list(all_plugin_change_info[level]['old'][p].keys()))[::2] - data_types_new = np.array(list(all_plugin_change_info[level]['new'][p].keys()))[::2] - - all_data_types = np.unique(np.concatenate([data_types_old, data_types_new])) - for d in all_data_types: - if d not in data_types_old: - print(f" - New column {d} added") - elif d not in data_types_new: - print(f" - Column {d} deleted") - else: - if (all_plugin_change_info[level]['old'][p][d] != all_plugin_change_info[level]['new'][p][d]): - print(f" - {d} bad fraction changed from: {all_plugin_change_info[level]['old'][p][d]} -> {all_plugin_change_info[level]['new'][p][d]}") - if (all_plugin_change_info[level]['old'][p][f'mean_{d}'] != all_plugin_change_info[level]['new'][p][f'mean_{d}']): - print(f" - {d} mean value changed from: {all_plugin_change_info[level]['old'][p][f'mean_{d}']} -> {all_plugin_change_info[level]['new'][p][f'mean_{d}']}") - print("All other columns remained the same\n") - print('\n') -EOF -) - -python3 -c "$compare_plugins" "$@" - -#Remove the temporary dictionaries that were created to deal with switching branches -rm new_affected_plugin_bad_info.json -rm new_changed_plugin_bad_info.json -rm plugin_update_comparison.json -rm version_hash_dict_new.json -rm version_hash_dict_old.json +top_level=$(git rev-parse --show-toplevel) +current_branch=$(git rev-parse --abbrev-ref HEAD) + +# run cleanup on exit +function cleanup { + git checkout $current_branch +} +trap cleanup EXIT + +cd $top_level/bin + +git checkout $old_branch +./report_pr_changes --branch old --computation lineage_hash +git checkout $new_branch +./report_pr_changes --branch new --computation lineage_hash +./report_pr_changes --branch new --computation hash_comparison +./report_pr_changes --branch new --computation print_added_plugin +./report_pr_changes --branch new --computation changed_affected_plugin +git checkout $old_branch +./report_pr_changes --branch old --computation changed_affected_plugin +./report_pr_changes --branch old --computation report_changes diff --git a/bin/report_pr_changes b/bin/report_pr_changes new file mode 100755 index 000000000..76209f0f9 --- /dev/null +++ b/bin/report_pr_changes @@ -0,0 +1,203 @@ +#!/usr/bin/env python +"""Check updates of straxen results. + +Example: + report_pr_changes branch_a branch_b + +""" +import argparse +import json + +import numpy as np +import straxen +from straxen.test_utils import nt_test_run_id + + +def read_json(json_name): + with open(json_name, "r") as jsonfile: + return json.load(jsonfile) + + +def save_json(json_name, json_dict): + with open(json_name, "w") as jsonfile: + json.dump(json_dict, jsonfile, indent=4) + + +def save_lineage_hash_dict(st, json_name): + version_hash_dict = straxen.test_utils.version_hash_dict(st) + save_json(json_name, version_hash_dict) + + +def compare_lineage_hash_dict(old_json_name, new_json_name, save_json_name): + version_hash_dict_old = read_json(old_json_name) + version_hash_dict_new = read_json(new_json_name) + + updated_plugins_dict = straxen.test_utils.updated_plugins( + version_hash_dict_old, version_hash_dict_new + ) + save_json(save_json_name, updated_plugins_dict) + + +version_hash_dict_old_name = "version_hash_dict_old.json" +version_hash_dict_new_name = "version_hash_dict_new.json" +plugin_update_comparison_name = "plugin_update_comparison.json" +new_changed_plugin_bad_info_name = "new_changed_plugin_bad_info.json" +new_affected_plugin_bad_info_name = "new_affected_plugin_bad_info.json" +old_changed_plugin_bad_info_name = "old_changed_plugin_bad_info.json" +old_affected_plugin_bad_info_name = "old_affected_plugin_bad_info.json" + + +def get_lineage_hash_dict(): + st = straxen.test_utils.nt_test_context() + if args.branch == "old": + save_lineage_hash_dict(st, version_hash_dict_old_name) + elif args.branch == "new": + save_lineage_hash_dict(st, version_hash_dict_new_name) + + +def get_hash_comparison(): + compare_lineage_hash_dict( + version_hash_dict_old_name, + version_hash_dict_new_name, + plugin_update_comparison_name, + ) + updated_plugins_dict = read_json(plugin_update_comparison_name) + # Print the deleted plugins: + print("\nDeleted plugins:") + for p in updated_plugins_dict["deleted"]: + print(f" - {p}") + print("\n") + + +def print_added_plugin_info(): + # Now print the info for the added plugins + st = straxen.test_utils.nt_test_context() + updated_plugins_dict = read_json(plugin_update_comparison_name) + bad_field_info_added = straxen.test_utils.bad_field_info( + st, nt_test_run_id, updated_plugins_dict["added"] + ) + for p in bad_field_info_added: + print(f"\nNew plugin '{p}' has the following bad field fractions:") + for c in bad_field_info_added[p]: + if bad_field_info_added[p][c] > 0: + # Don't print the mean values (unless the column name literally starts with mean) + if (not c.startswith("mean")) or (c.startswith("mean_mean")): + print(f" - {c}: {bad_field_info_added[p][c]}") + print("\n") + + +def get_changed_affected_plugin_info(): + if args.branch == "old": + # Compute the same for the old version of the plugins + st = straxen.test_utils.nt_test_context() + new_changed_plugin_bad_info = read_json(new_changed_plugin_bad_info_name) + new_affected_plugin_bad_info = read_json(new_affected_plugin_bad_info_name) + old_changed_plugin_bad_info = straxen.test_utils.bad_field_info( + st, nt_test_run_id, list(new_changed_plugin_bad_info.keys()) + ) + old_affected_plugin_bad_info = straxen.test_utils.bad_field_info( + st, nt_test_run_id, list(new_affected_plugin_bad_info.keys()) + ) + save_json(old_changed_plugin_bad_info_name, old_changed_plugin_bad_info) + save_json(old_affected_plugin_bad_info_name, old_affected_plugin_bad_info) + elif args.branch == "new": + # Comparing differences to changed plugins + st = straxen.test_utils.nt_test_context() + updated_plugins_dict = read_json(plugin_update_comparison_name) + lowest_level_changed_plugins = straxen.test_utils.lowest_level_plugins( + st, updated_plugins_dict["changed"] + ) + # Affected means the plugins which directly depend on the lowest level changed plugins + affected_changed_plugins = straxen.test_utils.directly_depends_on( + st, lowest_level_changed_plugins, updated_plugins_dict["changed"] + ) + # See the nan field fractions + mean of each field + new_changed_plugin_bad_info = straxen.test_utils.bad_field_info( + st, nt_test_run_id, lowest_level_changed_plugins + ) + new_affected_plugin_bad_info = straxen.test_utils.bad_field_info( + st, nt_test_run_id, affected_changed_plugins + ) + save_json(new_changed_plugin_bad_info_name, new_changed_plugin_bad_info) + save_json(new_affected_plugin_bad_info_name, new_affected_plugin_bad_info) + + +def report_changes(): + old_changed_plugin_bad_info = read_json(old_changed_plugin_bad_info_name) + old_affected_plugin_bad_info = read_json(old_affected_plugin_bad_info_name) + new_changed_plugin_bad_info = read_json(new_changed_plugin_bad_info_name) + new_affected_plugin_bad_info = read_json(new_affected_plugin_bad_info_name) + # Lowest level plugins + all_plugin_change_info = { + "Lowest Levels": {"old": old_changed_plugin_bad_info, "new": new_changed_plugin_bad_info}, + "Affected": {"old": old_affected_plugin_bad_info, "new": new_affected_plugin_bad_info}, + } + for level in ["Lowest Levels", "Affected"]: + print(f"#################### {level} Plugins ####################") + for p in all_plugin_change_info[level]["old"]: + print(f"Change report for '{p}':") + data_types_old = np.array(list(all_plugin_change_info[level]["old"][p].keys()))[::2] + data_types_new = np.array(list(all_plugin_change_info[level]["new"][p].keys()))[::2] + all_data_types = np.unique(np.concatenate([data_types_old, data_types_new])) + for d in all_data_types: + if d not in data_types_old: + print(f" - New column {d} added") + elif d not in data_types_new: + print(f" - Column {d} deleted") + else: + if ( + all_plugin_change_info[level]["old"][p][d] + != all_plugin_change_info[level]["new"][p][d] + ): + print( + f" - {d} bad fraction changed from:" + f" {all_plugin_change_info[level]['old'][p][d]} ->" + f" {all_plugin_change_info[level]['new'][p][d]}" + ) + if ( + all_plugin_change_info[level]["old"][p][f"mean_{d}"] + != all_plugin_change_info[level]["new"][p][f"mean_{d}"] + ): + print( + f" - {d} mean value changed from:" + f" {all_plugin_change_info[level]['old'][p][f'mean_{d}']} ->" + f" {all_plugin_change_info[level]['new'][p][f'mean_{d}']}" + ) + print("All other columns remained the same\n") + print("\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Check results changes induced by PR") + parser.add_argument( + "--branch", + type=str, + required=True, + choices=["old", "new"], + help="Whether on the old or new branch", + ) + parser.add_argument( + "--computation", + type=str, + required=True, + choices=[ + "lineage_hash", + "hash_comparison", + "print_added_plugin", + "changed_affected_plugin", + "report_changes", + ], + help="Type of computation", + ) + args = parser.parse_args() + + if args.computation == "lineage_hash": + get_lineage_hash_dict() + elif args.computation == "hash_comparison": + get_hash_comparison() + elif args.computation == "print_added_plugin": + print_added_plugin_info() + elif args.computation == "changed_affected_plugin": + get_changed_affected_plugin_info() + elif args.computation == "report_changes": + report_changes() diff --git a/straxen/test_utils.py b/straxen/test_utils.py index 5866bfe17..69d9a8198 100644 --- a/straxen/test_utils.py +++ b/straxen/test_utils.py @@ -312,7 +312,7 @@ def lowest_level_plugins(context, plugins): """ - if not type(plugins) == np.ndarray: + if not isinstance(plugins, np.ndarray): plugins = np.array(plugins) low_lev_plugs = [] @@ -339,9 +339,9 @@ def directly_depends_on(context, base_plugins, all_plugins): next_layer_plugins = [] for p in all_plugins: dep_on = context._plugin_class_registry[p].depends_on - if type(dep_on) == str: + if isinstance(dep_on, str): dep_on = [dep_on] - elif type(dep_on) == tuple: + elif isinstance(dep_on, tuple): dep_on = list(dep_on) if np.any(np.isin(base_plugins, dep_on)):