Skip to content

Commit

Permalink
additional readability improvements to pipeline.py
Browse files Browse the repository at this point in the history
  • Loading branch information
andyx13 committed Jan 31, 2024
1 parent 6781dec commit 58b57ae
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions sigpro/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ def __init__(self, primitives, primitive_combinations, features_as_strings=False
length = max([len(combination) for combination in primitive_combinations] + [0])
if length == 0:
raise ValueError('At least one non-empty output feature must be specified')

self.num_layers = length
self.primitives = copy(primitives[:]) # Will need to adjust API to remove primitives arg
self.primitive_combinations = None
Expand All @@ -366,7 +367,6 @@ def __init__(self, primitives, primitive_combinations, features_as_strings=False
in primitive_combinations]

for combination in self.primitive_combinations:

combo_length = len(combination)
for ind in range(combo_length - 1):
if combination[ind].get_type_subtype()[0] != 'transformation':
Expand Down Expand Up @@ -432,13 +432,16 @@ def _build_pipeline(self): # pylint: disable=too-many-locals, too-many-branches
for input_dict in final_primitive.get_inputs():
final_primitive_inputs[numbered_primitive_name][input_dict['name']] = \
f'{final_primitive_str}.' + str(input_dict['name'])

in_name = input_dict['name']
is_required = True
if 'optional' in input_dict:
is_required = not input_dict['optional']

# Context arguments should be named properly in the input data.
if in_name not in final_primitive.get_context_arguments() and \
in_name != 'amplitude_values' and is_required:

# We need to hook up the primitive input to the proper output in chain
if layer == 1:
npn = numbered_primitive_name[:] # lint
Expand All @@ -463,23 +466,28 @@ def _build_pipeline(self): # pylint: disable=too-many-locals, too-many-branches

final_primitive_inputs[numbered_primitive_name][in_name] = \
f'{final_primitive_str}.' + str(in_name)
LOGGER.warning(f'expecting {in_name} to be given by the user.')
warning_str = f'expecting {in_name} to be given by the user.'
LOGGER.warning(warning_str)
# fps = final_primitive_str # lint
# raise ValueError(f'Arg {in_name} of primitive {fps} \
# not produced by any predecessor primitive.')

else:
clpi = combination[:prev_ind]
icn = '.'.join([pr.get_tag() for pr in clpi]) + f'.{prev_ind}'
previous_comb = combination[:prev_ind]
previous_comb_str = '.'.join([pr.get_tag() for pr in
previous_comb])
previous_comb_str += f'.{prev_ind}'
final_primitive_inputs[numbered_primitive_name][in_name] = \
f'{icn}.{in_name}'
f'{previous_comb_str}.{in_name}'

if layer == 1:
final_primitive_inputs[numbered_primitive_name]['amplitude_values'] = \
'amplitude_values'

else:
clm1 = combination[:layer - 1] # lint character limit
input_column_name = '.'.join([pr.get_tag() for pr in clm1]) \
+ f'.{layer-1}'
input_column_name = '.'.join([pr.get_tag() for pr in
combination[:layer - 1]])
input_column_name += f'.{layer-1}'
final_primitive_inputs[numbered_primitive_name]['amplitude_values'] = \
input_column_name + '.amplitude_values'

Expand All @@ -491,12 +499,14 @@ def _build_pipeline(self): # pylint: disable=too-many-locals, too-many-branches
out_name = output_dict['name']
final_primitive_outputs[numbered_primitive_name][out_name] = \
f'{output_column_name}.' + str(out_name)

else:
npn = numbered_primitive_name[:] # lint
for output_dict in final_primitive.get_outputs():
out_name = output_dict['name']
final_outputs.append({'name': output_column_name + '.' + str(out_name),
'variable': f'{npn}.{out_name}'})

return MLPipeline(
primitives=final_primitives_list,
init_params=final_init_params,
Expand Down Expand Up @@ -543,16 +553,21 @@ def build_tree_pipeline(transformation_layers, aggregation_layer):
for primitive_ in layer:
if not isinstance(primitive_, Primitive):
raise ValueError('Non-primitive specified in transformation_layers')

all_layers.append(layer.copy())
primitives_all.update(layer)

else:
raise ValueError('Each layer in transformation_layers must be a list')

if isinstance(aggregation_layer, list):
for primitive_ in aggregation_layer:
if not isinstance(primitive_, Primitive):
raise ValueError('Non-primitive specified in aggregation_layer')

all_layers.append(aggregation_layer.copy())
primitives_all.update(aggregation_layer)

else:
raise ValueError('aggregation_layer must be a list')

Expand Down Expand Up @@ -645,7 +660,6 @@ def merge_pipelines(pipelines):
primitive_combinations = set()

for pipeline in (pipelines)[::-1]:

primitives_all.update(pipeline.get_primitives())
primitive_combinations.update(pipeline.get_output_combinations())

Expand Down

0 comments on commit 58b57ae

Please sign in to comment.