Skip to content

Commit

Permalink
PGC-69: Add docker pull progress indicator
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 7e32636c7ecb61a26807b90ce72bb2c488160a3b
  • Loading branch information
drew committed Oct 9, 2021
1 parent 67c29a4 commit 39cff28
Show file tree
Hide file tree
Showing 4 changed files with 982 additions and 4 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ pyyaml>=5.3,<=5.4
requests==2.25.0
smart_open>=2.1.0,<3.0
tabulate==0.8.9
tqdm==4.62.3
urllib3>=1.25.3,<1.26
98 changes: 95 additions & 3 deletions src/gretel_client/projects/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
import tarfile
import uuid

from dataclasses import dataclass
from pathlib import Path
from time import sleep
from typing import List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Dict, Iterator, List, Optional, Tuple, TYPE_CHECKING, Union
from urllib.parse import urlparse

import docker
Expand All @@ -21,6 +22,7 @@
from docker.models.containers import Container
from docker.models.volumes import Volume
from docker.types.containers import DeviceRequest
from tqdm import tqdm

from gretel_client.config import get_logger, get_session_config
from gretel_client.projects.jobs import ACTIVE_STATES, Job
Expand Down Expand Up @@ -234,15 +236,21 @@ def delete(self):
pass

def _pull(self):
    """Pull ``self.image`` and render the download progress.

    The image is pulled exactly once, through the low-level streaming
    API, and the resulting stream of updates is handed to
    ``_PullProgressPrinter``, which iterates it to the end.

    Returns:
        ``self.image``, the reference of the image that was pulled.

    Raises:
        ContainerRunError: if starting the pull or processing its
            progress stream raises any exception.
    """
    self.logger.debug("Authenticating image pull")
    auth, _ = _get_container_auth()
    self.logger.info(f"Pulling container image {self.image}")
    try:
        # A single streamed pull: issuing a blocking ``images.pull`` first
        # would download the whole image silently and leave nothing for
        # the progress stream to report.
        pull = self._docker_client.api.pull(
            self.image, auth_config=auth, stream=True, decode=True
        )
        _PullProgressPrinter(pull).start()
    except Exception as ex:
        raise ContainerRunError(f"Could not pull image {self.image}") from ex
    return self.image

def _run(self, remove: bool = True):
self.logger.debug("Pulling container image")
image = self._pull()
self.logger.debug("Preparing input data volume")
volume_config = self.input_volume.prepare_volume()
Expand Down Expand Up @@ -377,3 +385,87 @@ def check_docker_env():
raise DockerEnvironmentError(
"Can't connect to docker. Please check that docker is installed and running."
) from ex


@dataclass
class _PullUpdate:
    """One progress update emitted by the Docker daemon during a pull.

    The daemon reports pull progress as a stream of JSON objects; this
    dataclass holds the fields of a single decoded object. ``current``
    and ``total`` are passed in as the raw values from the update
    (presumably bytes -- confirm against the Docker API) and are rescaled
    by ``__post_init__`` to whole units of 2**20.
    """

    # Update id. ``from_dict`` falls back to the status text when the raw
    # update carries no "id" key.
    id: str

    # Update status text; used as the progress bar description.
    status: str

    # Progress so far, in whole "mb" (raw value / 2**20, rounded) after
    # ``__post_init__``. None when the raw value is missing or zero.
    current: Optional[int]

    # Expected total, in whole "mb" (raw value / 2**20, rounded) after
    # ``__post_init__``. None when the raw value is missing or zero.
    total: Optional[int]

    def __post_init__(self):
        self.current = self._to_whole_mb(self.current)
        self.total = self._to_whole_mb(self.total)

    @staticmethod
    def _to_whole_mb(raw: Optional[int]) -> Optional[int]:
        """Scale ``raw`` down by 2**20 and round; falsy input maps to None."""
        return round(raw / 2 ** 20) if raw else None

    @classmethod
    def from_dict(cls, source: dict) -> "_PullUpdate":
        """Build an update from one decoded progress object.

        Args:
            source: the decoded JSON object for a single update.

        Returns:
            The corresponding ``_PullUpdate``.

        Raises:
            KeyError: if ``source`` has no "status" key.
        """
        status = source["status"]
        # "progressDetail" may be absent, empty or null; treat all three
        # the same way instead of calling ``.get`` on None.
        detail = source.get("progressDetail") or {}
        return cls(
            id=source.get("id", status),
            status=status,
            current=detail.get("current"),
            total=detail.get("total"),
        )

    @property
    def units(self) -> str:
        """Unit label shown on the progress bar."""
        return "mb"

    def build_indicator(self) -> "tqdm":
        """Create a tqdm bar sized to ``total`` and labelled with ``status``."""
        t = tqdm(total=self.total, unit=self.units)
        t.set_description(self.status)
        return t


class _PullProgressPrinter:
    """Print docker pull progress, one tqdm bar per update id."""

    def __init__(self, pull: Iterator):
        """
        Args:
            pull: iterator of decoded pull updates, as accepted by
                ``_PullUpdate.from_dict``.
        """
        self._pull = pull
        self._bars: Dict[str, tqdm] = {}
        # Last status applied to each bar, keyed by update id. tqdm stores
        # the description with a ": " suffix, so comparing ``bar.desc``
        # with the raw status would never match and would force a
        # description reset (and redraw) on every single update.
        self._statuses: Dict[str, str] = {}

    def start(self):
        """Begin iterating and printing pull updates
        from the docker daemon.

        Updates without a (non-zero) ``current`` value are skipped. All
        bars are closed when the stream ends, including when iterating
        it raises.
        """
        try:
            for update in self._iter_updates():
                if update.current:
                    self._update_progress(update)
        finally:
            self._close_bars()

    def _close_bars(self):
        """Close every progress bar created so far."""
        for bar in self._bars.values():
            bar.close()

    def _update_progress(self, update: "_PullUpdate"):
        """Apply ``update`` to its bar, creating the bar on first use."""
        bar = self._get_or_create_bar(update)
        self._update_bar_total(bar, update)

    def _get_or_create_bar(self, update: "_PullUpdate") -> "tqdm":
        """Return the bar registered for ``update.id``, building it if needed."""
        # Explicit None check: a tqdm bar can be falsy (e.g. total of 0).
        bar = self._bars.get(update.id)
        if bar is None:
            bar = update.build_indicator()
            self._bars[update.id] = bar
            # build_indicator already labelled the bar with this status.
            self._statuses[update.id] = update.status
        return bar

    def _update_bar_total(self, bar: "tqdm", update: "_PullUpdate"):
        """Sync ``bar`` with ``update``: relabel on status change, then
        advance the counter to ``update.current``.
        """
        if self._statuses.get(update.id) != update.status:
            bar.set_description(update.status)
            self._statuses[update.id] = update.status
        if update.current:
            # tqdm counts in increments, so feed it the delta from its
            # current position rather than the absolute value.
            bar.update(update.current - bar.n)

    def _iter_updates(self) -> Iterator["_PullUpdate"]:
        """Yield each raw update from the stream as a ``_PullUpdate``."""
        for raw_update in self._pull:
            yield _PullUpdate.from_dict(raw_update)
Loading

0 comments on commit 39cff28

Please sign in to comment.