diff --git a/starcluster/awsutils.py b/starcluster/awsutils.py
index 904be21af..e39cbc2a1 100644
--- a/starcluster/awsutils.py
+++ b/starcluster/awsutils.py
@@ -597,6 +597,32 @@ def wait_for_propagation(self, instances=None, spot_requests=None,
             instance_ids, self.get_all_instances, 'instance-id',
             'instances', max_retries=max_retries, interval=interval)
 
+    def _check_for_propagation(self, obj_ids, fetch_func, id_filter,
+                               obj_name):
+        filters = {id_filter: obj_ids}
+        reqs = fetch_func(filters=filters)
+        req_ids = [req.id for req in reqs]
+        found = [oid for oid in obj_ids if oid in req_ids]
+        return found
+
+    def check_for_propagation(self, instance_ids=None, spot_ids=None):
+        """
+        Check which of the given ids have propagated through the EC2
+        API. Returns a tuple whose first item is the list of propagated
+        spot request ids and whose second is the propagated instance ids.
+        """
+        found_spot_ids = []
+        found_instance_ids = []
+        if spot_ids:
+            found_spot_ids = self._check_for_propagation(
+                spot_ids, self.get_all_spot_requests,
+                'spot-instance-request-id', 'spot requests')
+        if instance_ids:
+            found_instance_ids = self._check_for_propagation(
+                instance_ids, self.get_all_instances, 'instance-id',
+                'instances')
+        return found_spot_ids, found_instance_ids
+
     def run_instances(self, image_id, instance_type='m1.small', min_count=1,
                       max_count=1, key_name=None, security_groups=None,
                       placement=None, user_data=None, placement_group=None,
diff --git a/starcluster/cluster.py b/starcluster/cluster.py
index bbbaa127c..2b8b4124f 100644
--- a/starcluster/cluster.py
+++ b/starcluster/cluster.py
@@ -775,13 +775,16 @@ def nodes(self):
         log.debug('returning self._nodes = %s' % self._nodes)
         return self._nodes
 
-    def get_nodes_or_raise(self):
-        nodes = self.nodes
-        if not nodes:
+    def get_nodes_or_raise(self, nodes=None):
+        _nodes = self.nodes
+        if not _nodes:
             filters = {'instance.group-name': self._security_group}
             terminated_nodes = self.ec2.get_all_instances(filters=filters)
             raise exception.NoClusterNodesFound(terminated_nodes)
-        return nodes
+        if nodes:
+            node_ids = [n.id for n in nodes]
+            _nodes = filter(lambda n: n.id in node_ids, _nodes)
+        return _nodes
 
     def get_node(self, identifier, nodes=None):
         """
@@ -880,11 +883,14 @@ def spot_requests(self):
             filters['launch.group-id'] = group_id
         return self.ec2.get_all_spot_requests(filters=filters)
 
-    def get_spot_requests_or_raise(self):
-        spots = self.spot_requests
-        if not spots:
+    def get_spot_requests_or_raise(self, spots=None):
+        _spots = self.spot_requests
+        if not _spots:
             raise exception.NoClusterSpotRequests
-        return spots
+        if spots:
+            spot_ids = [s.id for s in spots]
+            _spots = filter(lambda s: s.id in spot_ids, _spots)
+        return _spots
 
     def create_node(self, alias, image_id=None, instance_type=None, zone=None,
                     placement_group=None, spot_bid=None, force_flat=False):
@@ -1015,7 +1021,13 @@ def add_nodes(self, num_nodes, aliases=None, image_id=None,
         if self._make_alias(master=True) in aliases:
             raise exception.ClusterValidationError(
                 "worker nodes cannot have master as an alias")
-        if not no_create:
+        if no_create:
+            self.wait_for_cluster(msg="Waiting for node(s) to come up...")
+            log.debug("Adding node(s): %s" % aliases)
+            for alias in aliases:
+                node = self.get_node(alias)
+                self.run_plugins(method_name="on_add_node", node=node)
+        else:
             if self.subnet:
                 ip_count = self.subnet.available_ip_address_count
                 if ip_count < len(aliases):
@@ -1032,14 +1044,86 @@ def add_nodes(self, num_nodes, aliases=None, image_id=None,
                 placement_group=placement_group,
                 spot_bid=spot_bid)
             if spot_bid or self.spot_bid:
-                self.ec2.wait_for_propagation(spot_requests=resp)
+                self.streaming_add(spots=resp)
             else:
-                self.ec2.wait_for_propagation(instances=resp[0].instances)
-        self.wait_for_cluster(msg="Waiting for node(s) to come up...")
-        log.debug("Adding node(s): %s" % aliases)
-        for alias in aliases:
-            node = self.get_node(alias)
-            self.run_plugins(method_name="on_add_node", node=node)
+                self.streaming_add(instances=resp[0].instances)
+
+    def streaming_add(self, spots=None, instances=None):
+        """
+        Run the on_add_node plugins on each new node as soon as it is up.
+        """
+        assert bool(spots) != bool(instances), \
+            "You must provide either spots or instances, but not both"
+
+        interval = self.refresh_interval
+        log.info("Waiting for one of the new nodes to be up "
+                 "(updating every {}s)".format(interval))
+
+        unpropagated_spots = spots or []
+        spots = []
+        unpropagated_instances = instances or []
+        instances = []
+        while True:
+            ready_instances = []
+            if unpropagated_spots:
+                propagated_spot_ids, _ = self.ec2.check_for_propagation(
+                    spot_ids=[s.id for s in unpropagated_spots])
+                unpropagated_spots = utils.filter_move(
+                    lambda s: s.id not in propagated_spot_ids,
+                    unpropagated_spots, spots)
+                if unpropagated_spots:
+                    log.info("Still waiting for unpropagated spots: " +
+                             str(unpropagated_spots))
+
+            if spots:
+                instance_ids = []
+                spots = self.get_spot_requests_or_raise(spots)
+                spots = utils.filter_move(
+                    lambda s: s.state != 'active' or s.instance_id is None,
+                    spots, instance_ids, lambda s: s.instance_id)
+                if instance_ids:
+                    log.info("Instance ids: " + str(instance_ids))
+                    # these instances are already propagated by definition
+                    instances += \
+                        self.ec2.get_all_instances(instance_ids=instance_ids)
+                if spots:
+                    log.info("Still waiting for spots: " + str(spots))
+
+            if unpropagated_instances:
+                _, propagated_instance_ids = self.ec2.check_for_propagation(
+                    instance_ids=[i.id for i in unpropagated_instances])
+                unpropagated_instances = utils.filter_move(
+                    lambda i: i.id not in propagated_instance_ids,
+                    unpropagated_instances, instances)
+                if unpropagated_instances:
+                    log.info("Still waiting for unpropagated instances: " +
+                             str(unpropagated_instances))
+
+            if instances:
+                instances = self.get_nodes_or_raise(nodes=instances)
+                ssh_up = self.pool.map(lambda i: i.is_up(), instances)
+                zip_instances = utils.filter_move(
+                    lambda i: i[0].state != 'running' or not i[1],
+                    zip(instances, ssh_up), ready_instances,
+                    lambda i: i[0])
+                instances = [i[0] for i in zip_instances]
+                if instances:
+                    log.info("Still waiting for instances: " + str(instances))
+            for ready_instance in ready_instances:
+                log.info("Adding node: %s" % ready_instance.alias)
+                up_nodes = filter(lambda n: n.is_up(), self.nodes)
+                self.run_plugins(method_name="on_add_node",
+                                 node=ready_instance, nodes=up_nodes)
+            if any([unpropagated_spots, spots,
+                    unpropagated_instances, instances]):
+                if instances or ready_instances:
+                    # instances: we just blocked on the ssh is_up checks;
+                    # ready_instances: running the plugins took time; in
+                    # both cases loop again immediately without sleeping
+                    continue
+                time.sleep(interval)
+            else:
+                break
 
     def remove_node(self, node=None, terminate=True, force=False):
         """
@@ -1386,7 +1470,7 @@ def wait_for_active_spots(self, spots=None):
                 pbar.update(len(active_spots))
                 if not pbar.finished:
                     time.sleep(self.refresh_interval)
-                    spots = self.get_spot_requests_or_raise()
+                    spots = self.get_spot_requests_or_raise(spots)
             pbar.reset()
         self.ec2.wait_for_propagation(
             instances=[s.instance_id for s in spots])
@@ -1397,19 +1481,19 @@ def wait_for_running_instances(self, nodes=None,
         Wait until all cluster nodes are in a 'running' state
         """
         log.info("Waiting for all nodes to be in a 'running' state...")
-        nodes = nodes or self.get_nodes_or_raise()
+        _nodes = nodes or self.get_nodes_or_raise()
         pbar = self.progress_bar.reset()
-        pbar.maxval = len(nodes)
+        pbar.maxval = len(_nodes)
         pbar.update(0)
         now = datetime.datetime.utcnow()
         timeout = now + datetime.timedelta(minutes=kill_pending_after_mins)
         while not pbar.finished:
-            running_nodes = [n for n in nodes if n.state == "running"]
-            pbar.maxval = len(nodes)
+            running_nodes = [n for n in _nodes if n.state == "running"]
+            pbar.maxval = len(_nodes)
             pbar.update(len(running_nodes))
             if not pbar.finished:
                 if datetime.datetime.utcnow() > timeout:
-                    pending = [n for n in nodes if n not in running_nodes]
+                    pending = [n for n in _nodes if n not in running_nodes]
                     log.warn("%d nodes have been pending for >= %d mins "
                              "- terminating" % (len(pending),
                                                 kill_pending_after_mins))
@@ -1417,7 +1501,7 @@ def wait_for_running_instances(self, nodes=None,
                         node.terminate()
                 else:
                     time.sleep(self.refresh_interval)
-                    nodes = self.get_nodes_or_raise()
+                    _nodes = self.get_nodes_or_raise(nodes)
                 pbar.reset()
 
     def wait_for_ssh(self, nodes=None):
@@ -1672,7 +1756,7 @@ def _setup_cluster(self):
         self.run_plugins()
 
     def run_plugins(self, plugins=None, method_name="run", node=None,
-                    reverse=False):
+                    reverse=False, nodes=None):
         """
         Run all plugins specified in this Cluster object's self.plugins list
         Uses plugins list instead of self.plugins if specified.
@@ -1687,9 +1771,11 @@ def run_plugins(self, plugins=None, method_name="run", node=None,
         if reverse:
             plugs.reverse()
         for plug in plugs:
-            self.run_plugin(plug, method_name=method_name, node=node)
+            self.run_plugin(plug, method_name=method_name, node=node,
+                            nodes=nodes)
 
-    def run_plugin(self, plugin, name='', method_name='run', node=None):
+    def run_plugin(self, plugin, name='', method_name='run', node=None,
+                   nodes=None):
         """
         Run a StarCluster plugin.
 
@@ -1707,7 +1793,7 @@ def run_plugin(self, plugin, name='', method_name='run', node=None,
             log.warn("Plugin %s has no %s method...skipping" %
                      (plugin_name, method_name))
             return
-        args = [self.nodes, self.master_node, self.cluster_user,
+        args = [nodes or self.nodes, self.master_node, self.cluster_user,
                 self.cluster_shell, self.volumes]
         if node:
             args.insert(0, node)
diff --git a/starcluster/utils.py b/starcluster/utils.py
index ec6f0bfff..1dda4f2c0 100644
--- a/starcluster/utils.py
+++ b/starcluster/utils.py
@@ -649,3 +649,20 @@ def get_spinner(msg):
     log.info(msg, extra=dict(__nonewline__=True))
     s.start()
     return s
+
+
+def filter_move(keep_fct, in_, out, extract_fct=None):
+    """
+    Return the items of in_ for which keep_fct(item) is true. Items
+    that fail the test are appended to the out list, transformed by
+    extract_fct first when one is given.
+    """
+    def _filter(item):
+        if keep_fct(item):
+            return True
+        if extract_fct:
+            out.append(extract_fct(item))
+        else:
+            out.append(item)
+        return False
+    return filter(_filter, in_)
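
Notes with illustrative sketches (not part of the patch):

The new EasyEC2.check_for_propagation helper is a non-blocking probe: it
returns the subset of the given ids that are already visible in the EC2
API (spot request ids first, instance ids second) and never raises for
the missing ones. A hypothetical caller follows; the config bootstrap
and the ids are assumptions made for the sketch, and the calls need real
AWS credentials to actually run:

    from starcluster import config

    # assumes a working ~/.starcluster/config
    cfg = config.StarClusterConfig().load()
    ec2 = cfg.get_easy_ec2()

    spot_ids = ['sir-11111111', 'sir-22222222']  # made-up ids
    instance_ids = ['i-aaaaaaaa']
    found_spots, found_instances = ec2.check_for_propagation(
        instance_ids=instance_ids, spot_ids=spot_ids)
    # Ids absent from the results have not propagated yet; callers such
    # as streaming_add simply poll again later instead of failing.
    still_waiting = [sid for sid in spot_ids if sid not in found_spots]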
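streaming_add advances nodes through four buckets (unpropagated spots ->
propagated spots -> instances -> ready instances), using
utils.filter_move for each hop. A minimal sketch of the spot -> instance
hop, with a hypothetical FakeSpot standing in for a boto spot request:

    from starcluster import utils

    class FakeSpot(object):
        def __init__(self, id, state, instance_id=None):
            self.id = id
            self.state = state
            self.instance_id = instance_id

        def __repr__(self):
            return self.id

    spots = [FakeSpot('sir-1', 'active', instance_id='i-1'),
             FakeSpot('sir-2', 'open')]
    instance_ids = []
    # Keep the spots that are not ready yet; extract the instance id
    # from each active spot and move it into instance_ids.
    spots = utils.filter_move(
        lambda s: s.state != 'active' or s.instance_id is None,
        spots, instance_ids, lambda s: s.instance_id)
    print spots         # -> [sir-2]
    print instance_ids  # -> ['i-1']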
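utils.filter_move itself partitions a list with side effects: kept items
are returned, rejected items are appended to the out list, optionally
transformed first. A toy example on integers (this relies on Python 2's
filter() returning a list; under Python 3 the lazy filter object would
defer the appends to out until the result is consumed):

    from starcluster import utils

    pending = [1, 2, 3, 4]
    done = []
    # Odd numbers stay in the returned list; even numbers are doubled
    # by the extract function and moved into `done`.
    pending = utils.filter_move(lambda n: n % 2, pending, done,
                                lambda n: n * 2)
    print pending  # -> [1, 3]
    print done     # -> [4, 8]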