Skip to content

Commit

Permalink
Merge pull request #107 from gaia-dpci/add-query-marker
Browse files Browse the repository at this point in the history
Add query marker
  • Loading branch information
druzm authored Jun 11, 2024
2 parents 3426ebe + f049924 commit 6fa4820
Show file tree
Hide file tree
Showing 21 changed files with 239 additions and 81 deletions.
6 changes: 0 additions & 6 deletions src/gaiaxpy/core/generic_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,12 +347,6 @@ def format_additional_columns(additional_columns: Optional[Union[str, list, dict
return convert_values_to_lists(additional_columns)


def validate_error_correction(phot_system, error_correction):
if any(system.name.startswith(ADDITIONAL_SYSTEM_PREFIX) for system in phot_system) and error_correction:
raise ValueError('Photometry is requested for a non-built-in system, but error_correction is set to True. '
'Error correction is only implemented for built-in systems.')


def validate_photometric_system(photometric_system):
"""
Ensure photometric system input isn't empty.
Expand Down
2 changes: 1 addition & 1 deletion src/gaiaxpy/core/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.1.1'
__version__ = '2.1.2'
2 changes: 1 addition & 1 deletion src/gaiaxpy/core/xml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_xp_merge(x_root):

def get_array_text(x_root, tag):
result = x_root.find(tag)
result = [element.text for element in result] if result else None
result = [element.text for element in result] if result is not None else None
length = len(result) if result else None
return result, length

Expand Down
10 changes: 8 additions & 2 deletions src/gaiaxpy/file_parser/parse_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ def parse_file(self, file_path, disable_info=False):
str: File extension ('.csv', '.fits', or '.xml').
"""
if not disable_info:
print(self.info_msg, end='\r')
self.print_info_msg()
extension = _get_file_extension(file_path)
parser = self.get_parser(extension)
parsed_data = _cast(parser(file_path))
if not disable_info:
print(self.info_msg + ' Done!', end='\r')
self.print_info_msg(done=True)
return parsed_data, extension

def _parse_avro(self, avro_file):
Expand Down Expand Up @@ -156,6 +156,12 @@ def _parse_xml(self, xml_file, _array_columns=None, _matrix_columns=None, _useco
row[size_column]), axis=1)
return df

def print_info_msg(self, done=False):
msg = self.info_msg
if done:
msg = msg + ' Done!'
print(msg, end='\r')


def _get_file_extension(file_path):
"""
Expand Down
2 changes: 1 addition & 1 deletion src/gaiaxpy/generator/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __get_file_names_recursively(dir_path: str, show_warning: bool = False) -> l
if len(compliant_files) == 0:
raise ValueError('No filter files found in the given directory. Please check your files.')
elif len(compliant_files) < len(all_files) and show_warning:
message = 'Some files in the directory do not correspond to filter files. The program will ignore them.'
message = 'Based on their names, some files in the directory do not correspond to filter files. They will be ignored.'
print(f'UserWarning: {message}', file=sys.stderr)
return compliant_files

Expand Down
4 changes: 1 addition & 3 deletions src/gaiaxpy/generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import pandas as pd

from gaiaxpy.colour_equation.xp_filter_system_colour_equation import _apply_colour_equation
from gaiaxpy.core.generic_functions import cast_output, format_additional_columns, validate_photometric_system, \
validate_error_correction
from gaiaxpy.core.generic_functions import cast_output, format_additional_columns, validate_photometric_system
from gaiaxpy.error_correction.error_correction import _apply_error_correction
from gaiaxpy.input_reader.input_reader import InputReader
from gaiaxpy.output.photometry_data import PhotometryData
Expand Down Expand Up @@ -89,7 +88,6 @@ def __is_gaia_initially_in_systems(_internal_photometric_system: list,
# Prepare systems, keep track of original systems (especially required for error_correction)
internal_phot_system = photometric_system.copy() if isinstance(photometric_system, list) else (
[photometric_system].copy())
validate_error_correction(internal_phot_system, error_correction)
gaia_system = PhotometricSystem.Gaia_DR3_Vega
is_gaia_in_input = __is_gaia_initially_in_systems(internal_phot_system)
if error_correction and not is_gaia_in_input:
Expand Down
34 changes: 21 additions & 13 deletions src/gaiaxpy/generator/internal_photometric_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,27 +115,35 @@ def _set_file(self, bp_model, rp_model):
str: Path of a file.
"""

def _validate_additional_system_file(_actual_path):
file_names = [split(p)[1] for p in _actual_path]
def _extract_matches(_actual_path, _system_name):
__is_built_in_name = lambda fname: fname.startswith('XpFilter')
__is_exact_match = lambda fname, sysname: fname.startswith(sysname) and f'{sysname}.' in fname
__matches_additional = lambda fname: pattern.match(fname)
__split_paths = [split(p) for p in _actual_path]
__paths, __filenames = zip(*__split_paths)
pattern = re.compile(_ADDITIONAL_SYSTEM_FILES_REGEX, re.IGNORECASE)
if all(f.startswith('XpFilter') for f in file_names):
if all(__is_built_in_name(f) for f in __filenames): # If system is built-in
return _actual_path
return [join(file_path, s) for s in file_names if pattern.match(s) and not s.startswith('XpFilter')]
return [join(p, fn) for p, fn in __split_paths if not __is_built_in_name(fn) and __matches_additional(fn)
and __is_exact_match(fn, _system_name)] # If is additional

def _validate_path(_actual_path):
if len(actual_path) == 0:
raise ValueError('Filter file not found in given path.')
elif len(actual_path) > 1:
# Remove configuration file if it exists to avoid issues when reloading
if exists(_CFG_FILE_PATH):
remove(_CFG_FILE_PATH)
raise ValueError(f'More than one system named {self.label.replace(f"{ADDITIONAL_SYSTEM_PREFIX}_", "")}'
f' were found. System names in the given directory should be unique. Operation aborted.')

file_name = replace_file_name(self.config_file, 'filter', 'filter', bp_model, rp_model, self.label)
system_name = file_name.split('.')[0]
file_path = get_file_path(self.config_file)
# Search file in file path to obtain the actual path
actual_path = glob(file_path + f"/**/{system_name}*.xml", recursive=True)
actual_path = _validate_additional_system_file(actual_path)
if len(actual_path) == 0:
raise ValueError('Filter file not found in given path.')
elif len(actual_path) > 1:
# Remove configuration file if it exists to avoid issues on reloading
if exists(_CFG_FILE_PATH):
remove(_CFG_FILE_PATH)
raise ValueError(f'More than one system named {self.label.replace(f"{ADDITIONAL_SYSTEM_PREFIX}_", "")}'
f' were found. System names in the given directory should be unique. Operation aborted.')
actual_path = _extract_matches(actual_path, system_name)
_validate_path(actual_path)
self.filter_file = actual_path[0]

def _load_offset_from_xml(self):
Expand Down
6 changes: 6 additions & 0 deletions src/gaiaxpy/input_reader/archive_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,9 @@ def _login(self, gaia):
gaia.login(user=user, password=password)
else:
pass

def show_info_msg(self, done=False):
msg = self.info_msg
if done:
msg = msg + ' Done!'
print(msg, end='\r')
10 changes: 8 additions & 2 deletions src/gaiaxpy/input_reader/dataframe_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __get_parseable_columns(self):

def read(self):
if not self.disable_info:
print(self.info_msg, end='\r')
self.show_info_msg()
content = self.content
str_array_columns, np_array_columns = self.__get_parseable_columns()
if str_array_columns:
Expand All @@ -88,10 +88,16 @@ def read(self):
data[f'{band}_covariance_matrix'] = data.apply(get_covariance_matrix, axis=1, args=(band,))
self.requested_columns = self.requested_columns + covariance_columns
if not self.disable_info:
print(self.info_msg + ' Done!', end='\r')
self.show_info_msg(done=True)
data = _cast(data)
if self.additional_columns:
data = rename_with_required(data, self.additional_columns)
data = data[self.requested_columns] if self.requested_columns else data
# No extension returned for DataFrames
return data, None

def show_info_msg(self, done=False):
msg = self.info_msg
if done:
msg = msg + ' Done!'
print(msg, end='\r')
4 changes: 2 additions & 2 deletions src/gaiaxpy/input_reader/list_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def read(self, _data_release=data_release):
self._login(gaia)
# ADQL query
if not self.disable_info:
print(self.info_msg, end='\r')
self.show_info_msg()
result = gaia.load_data(ids=sources, format='csv', data_release=_data_release, data_structure='raw',
retrieval_type='XP_CONTINUOUS', avoid_datatype_check=True)
try:
Expand All @@ -52,6 +52,6 @@ def read(self, _data_release=data_release):
except (KeyError, IndexError):
raise ValueError('No continuous raw data found for the given sources.')
if not self.disable_info:
print(self.info_msg + ' Done!', end='\r')
self.show_info_msg(done=True)
return DataFrameReader(data, function_name, self.truncation, additional_columns=self.additional_columns,
disable_info=True).read()
26 changes: 22 additions & 4 deletions src/gaiaxpy/input_reader/query_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from astroquery.gaia import GaiaClass

from gaiaxpy.core.server import data_release, gaia_server
from gaiaxpy.core.version import __version__
from .archive_reader import ArchiveReader
from .dataframe_reader import DataFrameReader
from ..core.custom_errors import SelectorNotImplementedError
Expand Down Expand Up @@ -43,11 +44,28 @@ def get_srcids(_table):
except StopIteration:
raise ValueError('Source ID column not found in query result.')
if not isinstance(sid_col, str):
raise ValueError(f'Index of source ID column should be a string, but is {type(sid_col).__name__}: {sid_col}.')
raise ValueError(
f'Index of source ID column should be a string, but is {type(sid_col).__name__}: {sid_col}.')
return _table[sid_col]

def read(self, _data_release=data_release):
@staticmethod
def _add_marker(query, comment):
def __remove_comments(_query: str) -> str:
single_line_comment = re.compile(r'--.*$', re.MULTILINE)
multi_line_comment = re.compile(r'/\*.*?\*/', re.DOTALL)
query_no_single_line_comments = single_line_comment.sub('', _query) # Remove single-line comments
return multi_line_comment.sub('', query_no_single_line_comments) # Remove multi-line comments

query = __remove_comments(query)

if comment:
insensitive_select = re.compile(re.escape('select'), re.IGNORECASE)
query = insensitive_select.sub(f'select --{comment} \n', query)
return query

def read(self, _data_release=data_release, _comment=f'This query was launched from within GaiaXPy {__version__}'):
query = self.content
query = self._add_marker(query, _comment)
function_name = self.function.__name__
if function_name in not_supported_functions:
raise ValueError(f'Function {function_name} does not accept ADQL queries.')
Expand All @@ -56,7 +74,7 @@ def read(self, _data_release=data_release):
self._login(gaia)
# ADQL query
if not self.disable_info:
print(self.info_msg, end='\r')
self.show_info_msg()
job = gaia.launch_job_async(query, dump_to_file=False)
query_result = job.get_results()
result = gaia.load_data(ids=self.get_srcids(query_result), format='csv', data_release=_data_release,
Expand All @@ -67,6 +85,6 @@ def read(self, _data_release=data_release):
except KeyError:
raise ValueError('No continuous raw data found for the requested query.')
if not self.disable_info:
print(self.info_msg + ' Done!', end='\r')
self.show_info_msg(done=True)
return DataFrameReader(data, function_name, self.truncation, additional_columns=self.additional_columns,
disable_info=True).read()
22 changes: 18 additions & 4 deletions src/gaiaxpy/spectrum/multi_synthetic_photometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
====================================
Module to represent a synthetic photometry in multiple photometric systems.
"""
import re
from typing import List

import pandas as pd

Expand Down Expand Up @@ -39,13 +41,25 @@ def __init__(self, photometric_system, photometries):
self.mags, self.fluxes, self.errors = _generate_variables(photometries)

def _generate_output_df(self):
def _get_column_sublist(_label: str, _df: pd.DataFrame) -> List:
"""
Get the columns in a DataFrame that correspond to a particular system.
Args:
_label: A system label.
_df: A DataFrame containing photometry data.
Returns:
A list of column names.
"""
_pattern = fr'{_label}_(flux_error|flux|mag)_[A-Za-z0-9]+'
return list(filter(lambda col: re.match(_pattern, col), _df.columns))

photometries_df = self._photometries_to_df()
# Reorder DataFrame columns
phot_system_labels = [phot_system.get_system_label() for phot_system in self.photometric_system]
reordered_columns = ['source_id']
for label in phot_system_labels:
column_sublist = [column for column in photometries_df.columns if column.startswith(f'{label}_')]
reordered_columns.extend(column_sublist)
reordered_columns = ['source_id'] + [col for label in phot_system_labels for col in _get_column_sublist(
label, photometries_df)]
photometries_df = photometries_df[reordered_columns]
return photometries_df

Expand Down
Loading

0 comments on commit 6fa4820

Please sign in to comment.