Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streaming node add #434

Closed
wants to merge 10 commits into from
26 changes: 26 additions & 0 deletions starcluster/awsutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,32 @@ def wait_for_propagation(self, instances=None, spot_requests=None,
instance_ids, self.get_all_instances, 'instance-id',
'instances', max_retries=max_retries, interval=interval)

def _check_for_propagation(self, obj_ids, fetch_func, id_filter, obj_name):
filters = {id_filter: obj_ids}
reqs_ids = []
reqs = fetch_func(filters=filters)
reqs_ids = [req.id for req in reqs]
found = [oid for oid in obj_ids if oid in reqs_ids]
return found

def check_for_propagation(self, instance_ids=None, spot_ids=None):
"""
Check propagated instances. Returns a tuple where the first item is
a list of the found instances and the second a list of the found
spot requests.
"""
found_instance_ids = []
found_spot_ids = []
if spot_ids:
found_instance_ids = self._check_for_propagation(
spot_ids, self.get_all_spot_requests,
'spot-instance-request-id', 'spot requests')
if instance_ids:
found_spot_ids = self._check_for_propagation(
instance_ids, self.get_all_instances, 'instance-id',
'instances')
return found_instance_ids, found_spot_ids

def run_instances(self, image_id, instance_type='m1.small', min_count=1,
max_count=1, key_name=None, security_groups=None,
placement=None, user_data=None, placement_group=None,
Expand Down
140 changes: 113 additions & 27 deletions starcluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,13 +775,16 @@ def nodes(self):
log.debug('returning self._nodes = %s' % self._nodes)
return self._nodes

def get_nodes_or_raise(self):
nodes = self.nodes
if not nodes:
def get_nodes_or_raise(self, nodes=None):
_nodes = self.nodes
if not _nodes:
filters = {'instance.group-name': self._security_group}
terminated_nodes = self.ec2.get_all_instances(filters=filters)
raise exception.NoClusterNodesFound(terminated_nodes)
return nodes
if nodes:
nodes_ids = [n.id for n in nodes]
_nodes = filter(lambda n: n.id in nodes_ids, _nodes)
return _nodes

def get_node(self, identifier, nodes=None):
"""
Expand Down Expand Up @@ -880,11 +883,14 @@ def spot_requests(self):
filters['launch.group-id'] = group_id
return self.ec2.get_all_spot_requests(filters=filters)

def get_spot_requests_or_raise(self):
spots = self.spot_requests
if not spots:
def get_spot_requests_or_raise(self, spots):
_spots = self.spot_requests
if not _spots:
raise exception.NoClusterSpotRequests
return spots
if spots:
spots_ids = [s.id for s in spots]
_spots = filter(lambda s: s.id in spots_ids, _spots)
return _spots

def create_node(self, alias, image_id=None, instance_type=None, zone=None,
placement_group=None, spot_bid=None, force_flat=False):
Expand Down Expand Up @@ -1015,7 +1021,13 @@ def add_nodes(self, num_nodes, aliases=None, image_id=None,
if self._make_alias(master=True) in aliases:
raise exception.ClusterValidationError(
"worker nodes cannot have master as an alias")
if not no_create:
if no_create:
self.wait_for_cluster(msg="Waiting for node(s) to come up...")
log.debug("Adding node(s): %s" % aliases)
for alias in aliases:
node = self.get_node(alias)
self.run_plugins(method_name="on_add_node", node=node)
else:
if self.subnet:
ip_count = self.subnet.available_ip_address_count
if ip_count < len(aliases):
Expand All @@ -1032,14 +1044,86 @@ def add_nodes(self, num_nodes, aliases=None, image_id=None,
placement_group=placement_group,
spot_bid=spot_bid)
if spot_bid or self.spot_bid:
self.ec2.wait_for_propagation(spot_requests=resp)
self.streaming_add(spots=resp)
else:
self.ec2.wait_for_propagation(instances=resp[0].instances)
self.wait_for_cluster(msg="Waiting for node(s) to come up...")
log.debug("Adding node(s): %s" % aliases)
for alias in aliases:
node = self.get_node(alias)
self.run_plugins(method_name="on_add_node", node=node)
self.streaming_add(instances=resp[0].instances)

    def streaming_add(self, spots=[], instances=[]):
        """
        As soon as a new node is ready, run the add plugins commands over it.

        Exactly one of spots/instances must be non-empty.  Each item flows
        through a pipeline of states polled once per loop iteration:

            unpropagated spot -> propagated spot -> active spot with an
            instance id -> unpropagated instance -> propagated instance ->
            running instance with SSH up -> "ready" (plugins executed)

        NOTE(review): the mutable default arguments ([]) are rebound below
        before any mutation, so they are not actually shared across calls,
        but the idiom is fragile - confirm before touching this signature.
        This code relies on Python 2 semantics: filter()/zip() (used here
        and inside utils.filter_move) must return lists, not iterators.
        """
        assert bool(spots) != bool(instances), \
            "You must define either spots or instances"

        interval = self.refresh_interval
        log.info("Waiting for one of the new nodes to be up "
                 "(updating every {}s)".format(interval))

        # Rebind the parameters as the "not yet propagated" pools; the
        # bare names now collect items that reached the next stage.
        unpropagated_spots = spots
        spots = []
        unpropagated_instances = instances
        instances = []
        while True:
            # Instances that became fully ready during this iteration.
            ready_instances = []
            if unpropagated_spots:
                # First tuple item is the list of propagated spot ids.
                propagated_spot_ids, _ = self.ec2.check_for_propagation(
                    spot_ids=[s.id for s in unpropagated_spots])
                # Keep still-unpropagated spots; move propagated ones
                # into the `spots` pool.
                unpropagated_spots = utils.filter_move(
                    lambda s: s.id not in propagated_spot_ids,
                    unpropagated_spots, spots)
                if unpropagated_spots:
                    log.info("Still waiting for unpropagated spots:"
                             + str(unpropagated_spots))

            if spots:
                instance_ids = []
                # Refresh spot request state from the API (raises if the
                # cluster has no spot requests at all).
                spots = self.get_spot_requests_or_raise(spots)
                # Keep spots that are not active yet; extract the instance
                # id from those that are.
                spots = utils.filter_move(
                    lambda s: s.state != 'active' or s.instance_id is None,
                    spots, instance_ids, lambda s: s.instance_id)
                if instance_ids:
                    log.info("Instance ids:" + str(instance_ids))
                    # Those one are already propagated: they came back from
                    # an API call, so skip the unpropagated_instances stage.
                    instances += \
                        self.ec2.get_all_instances(instance_ids=instance_ids)
                if spots:
                    log.info("Still waiting for spots: " + str(spots))

            if unpropagated_instances:
                # Second tuple item is the list of propagated instance ids.
                _, propagated_instance_ids = self.ec2.check_for_propagation(
                    instance_ids=[s.id for s in unpropagated_instances])
                unpropagated_instances = utils.filter_move(
                    lambda i: i.id not in propagated_instance_ids,
                    unpropagated_instances, instances)
                if unpropagated_instances:
                    log.info("Still waiting for unpropagated instances: "
                             + str(unpropagated_instances))

            if instances:
                # Re-fetch as Node objects restricted to our instances.
                instances = self.get_nodes_or_raise(nodes=instances)
                # Probe SSH reachability in parallel; ssh_up[i] pairs with
                # instances[i] below.
                ssh_up = self.pool.map(lambda i: i.is_up(), instances)
                # Keep (node, up) pairs that are not ready; move ready
                # nodes (running + SSH up) into ready_instances.
                zip_instances = utils.filter_move(
                    lambda i: i[0].state != 'running' or not i[1],
                    zip(instances, ssh_up), ready_instances,
                    lambda i: i[0])
                instances = [i[0] for i in zip_instances]
                if instances:
                    log.info("Still waiting for instances: " + str(instances))
            for ready_instance in ready_instances:
                log.info("Adding node: %s" % ready_instance.alias)
                # Only pass nodes that are reachable to the plugins.
                up_nodes = filter(lambda n: n.is_up(), self.nodes)
                self.run_plugins(method_name="on_add_node",
                                 node=ready_instance, nodes=up_nodes)
            if any([unpropagated_spots, spots,
                    unpropagated_instances, instances]):
                if instances or ready_instances:
                    # instances means we wait on ssh is_up, no need to sleep
                    # ready_instances means nodes were added, that took
                    # time so we should loop again now
                    continue
                time.sleep(interval)
            else:
                # Every item has been processed - nothing left in any pool.
                break

def remove_node(self, node=None, terminate=True, force=False):
"""
Expand Down Expand Up @@ -1386,7 +1470,7 @@ def wait_for_active_spots(self, spots=None):
pbar.update(len(active_spots))
if not pbar.finished:
time.sleep(self.refresh_interval)
spots = self.get_spot_requests_or_raise()
spots = self.get_spot_requests_or_raise(spots)
pbar.reset()
self.ec2.wait_for_propagation(
instances=[s.instance_id for s in spots])
Expand All @@ -1397,27 +1481,27 @@ def wait_for_running_instances(self, nodes=None,
Wait until all cluster nodes are in a 'running' state
"""
log.info("Waiting for all nodes to be in a 'running' state...")
nodes = nodes or self.get_nodes_or_raise()
_nodes = nodes or self.get_nodes_or_raise()
pbar = self.progress_bar.reset()
pbar.maxval = len(nodes)
pbar.maxval = len(_nodes)
pbar.update(0)
now = datetime.datetime.utcnow()
timeout = now + datetime.timedelta(minutes=kill_pending_after_mins)
while not pbar.finished:
running_nodes = [n for n in nodes if n.state == "running"]
pbar.maxval = len(nodes)
running_nodes = [n for n in _nodes if n.state == "running"]
pbar.maxval = len(_nodes)
pbar.update(len(running_nodes))
if not pbar.finished:
if datetime.datetime.utcnow() > timeout:
pending = [n for n in nodes if n not in running_nodes]
pending = [n for n in _nodes if n not in running_nodes]
log.warn("%d nodes have been pending for >= %d mins "
"- terminating" % (len(pending),
kill_pending_after_mins))
for node in pending:
node.terminate()
else:
time.sleep(self.refresh_interval)
nodes = self.get_nodes_or_raise()
_nodes = self.get_nodes_or_raise(nodes)
pbar.reset()

def wait_for_ssh(self, nodes=None):
Expand Down Expand Up @@ -1672,7 +1756,7 @@ def _setup_cluster(self):
self.run_plugins()

def run_plugins(self, plugins=None, method_name="run", node=None,
reverse=False):
reverse=False, nodes=None):
"""
Run all plugins specified in this Cluster object's self.plugins list
Uses plugins list instead of self.plugins if specified.
Expand All @@ -1687,9 +1771,11 @@ def run_plugins(self, plugins=None, method_name="run", node=None,
if reverse:
plugs.reverse()
for plug in plugs:
self.run_plugin(plug, method_name=method_name, node=node)
self.run_plugin(plug, method_name=method_name, node=node,
nodes=nodes)

def run_plugin(self, plugin, name='', method_name='run', node=None):
def run_plugin(self, plugin, name='', method_name='run', node=None,
nodes=None):
"""
Run a StarCluster plugin.

Expand All @@ -1707,7 +1793,7 @@ def run_plugin(self, plugin, name='', method_name='run', node=None):
log.warn("Plugin %s has no %s method...skipping" %
(plugin_name, method_name))
return
args = [self.nodes, self.master_node, self.cluster_user,
args = [nodes or self.nodes, self.master_node, self.cluster_user,
self.cluster_shell, self.volumes]
if node:
args.insert(0, node)
Expand Down
12 changes: 12 additions & 0 deletions starcluster/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,3 +649,15 @@ def get_spinner(msg):
log.info(msg, extra=dict(__nonewline__=True))
s.start()
return s


def filter_move(keep_fct, in_, out, extract_fct=None):
    """Partition *in_* with *keep_fct*.

    Items for which keep_fct returns a truthy value are kept and returned
    (via filter); every other item is appended to *out* as a side effect,
    transformed by *extract_fct* first when one is given.
    """
    def _sieve(element):
        if keep_fct(element):
            return True
        out.append(extract_fct(element) if extract_fct else element)
        return False
    return filter(_sieve, in_)