Skip to content

Commit

Permalink
Merge branch 'main' into add-waiting-env
Browse files Browse the repository at this point in the history
Signed-off-by: Will Sackfield <sackfield@replicate.com>
  • Loading branch information
8W9aG authored Sep 25, 2024
2 parents f79c597 + 45d7d2d commit f0c5b79
Showing 1 changed file with 39 additions and 40 deletions.
79 changes: 39 additions & 40 deletions python/cog/server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,56 +288,55 @@ def run(self) -> None:
callback=self._stream_write_hook,
)

self._setup(redirector)
self._loop(redirector)
with redirector:
self._setup(redirector)
self._loop(redirector)

def send_cancel(self) -> None:
if self.is_alive() and self.pid:
os.kill(self.pid, signal.SIGUSR1)

def _setup(self, redirector: StreamRedirector) -> None:
with redirector:
done = Done()
wait_for_env()
done = Done()
wait_for_env()
try:
self._predictor = load_predictor_from_ref(self._predictor_ref)
# Could be a function or a class
if hasattr(self._predictor, "setup"):
run_setup(self._predictor)
except Exception as e: # pylint: disable=broad-exception-caught
traceback.print_exc()
done.error = True
done.error_detail = str(e)
except BaseException as e:
# For SystemExit and friends we attempt to add some useful context
# to the logs, but reraise to ensure the process dies.
traceback.print_exc()
done.error = True
done.error_detail = str(e)
raise
finally:
try:
self._predictor = load_predictor_from_ref(self._predictor_ref)
# Could be a function or a class
if hasattr(self._predictor, "setup"):
run_setup(self._predictor)
except Exception as e: # pylint: disable=broad-exception-caught
traceback.print_exc()
done.error = True
done.error_detail = str(e)
except BaseException as e:
# For SystemExit and friends we attempt to add some useful context
# to the logs, but reraise to ensure the process dies.
traceback.print_exc()
done.error = True
done.error_detail = str(e)
raise
finally:
try:
redirector.drain(timeout=10)
except TimeoutError:
self._events.send(
Log(
"WARNING: logs may be truncated due to excessive volume.",
source="stderr",
)
redirector.drain(timeout=10)
except TimeoutError:
self._events.send(
Log(
"WARNING: logs may be truncated due to excessive volume.",
source="stderr",
)
raise
self._events.send(done)
)
raise
self._events.send(done)

def _loop(self, redirector: StreamRedirector) -> None:
with redirector:
while True:
ev = self._events.recv()
if isinstance(ev, Shutdown):
break
if isinstance(ev, PredictionInput):
self._predict(ev.payload, redirector)
else:
print(f"Got unexpected event: {ev}", file=sys.stderr)
while True:
ev = self._events.recv()
if isinstance(ev, Shutdown):
break
if isinstance(ev, PredictionInput):
self._predict(ev.payload, redirector)
else:
print(f"Got unexpected event: {ev}", file=sys.stderr)

def _predict(
self,
Expand Down

0 comments on commit f0c5b79

Please sign in to comment.