Skip to content

Commit

Permalink
Add script to cache provider artifacts for faster startup. (#28335)
Browse files Browse the repository at this point in the history
This should be run during template docker image creation.
  • Loading branch information
robertwb authored Sep 13, 2023
1 parent 9f3bea9 commit 141e3e6
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 11 deletions.
46 changes: 46 additions & 0 deletions sdks/python/apache_beam/yaml/cache_provider_artifacts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import sys
import time

from apache_beam.version import __version__ as beam_version
from apache_beam.yaml import yaml_provider


def cache_provider_artifacts():
  """Pre-populates the artifact caches of all standard YAML providers.

  Calls ``cache_artifacts()`` on every provider returned by
  ``yaml_provider.standard_providers()`` and logs what was cached and how
  long it took. For non-dev Beam versions, the base virtual environment
  that ``PypiExpansionService`` clones from is built as well.
  """
  # The same provider object may be listed under several transform names;
  # key by identity so each one is cached (and logged) exactly once.
  providers_by_id = {
      id(provider): provider
      for providers in yaml_provider.standard_providers().values()
      for provider in providers
  }
  for provider in providers_by_id.values():
    t = time.time()
    artifacts = provider.cache_artifacts()
    # Providers with nothing to cache return None (or an empty iterable).
    if artifacts:
      logging.info(
          'Cached %s in %0.03f seconds.', ', '.join(artifacts), time.time() - t)
  if '.dev' not in beam_version:
    # Also cache a base python venv for fast cloning.
    t = time.time()
    # _create_venv_to_clone requires the base interpreter; use the running
    # one, which is also PypiExpansionService's default base_python.
    artifacts = yaml_provider.PypiExpansionService._create_venv_to_clone(
        sys.executable)
    logging.info('Cached %s in %0.03f seconds.', artifacts, time.time() - t)


if __name__ == '__main__':
  # Raise the root logger to INFO so the "Cached ... in N seconds." messages
  # logged by cache_provider_artifacts are not filtered out.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  cache_provider_artifacts()
76 changes: 65 additions & 11 deletions sdks/python/apache_beam/yaml/yaml_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from typing import Dict
from typing import Iterable
from typing import Mapping
from typing import Optional

import yaml
from yaml.loader import SafeLoader
Expand All @@ -57,6 +58,9 @@ def available(self) -> bool:
"""Returns whether this provider is available to use in this environment."""
raise NotImplementedError(type(self))

def cache_artifacts(self) -> Optional[Iterable[str]]:
  """Caches whatever artifacts this provider needs ahead of time.

  Returns:
    Strings describing the cached artifacts, or None when there is nothing
    to cache.

  Raises:
    NotImplementedError: always; subclasses are expected to override this.
  """
  provider_type = type(self)
  raise NotImplementedError(provider_type)

def provided_transforms(self) -> Iterable[str]:
  """Returns the transform type names that this provider can handle.

  Raises:
    NotImplementedError: always; subclasses are expected to override this.
  """
  provider_type = type(self)
  raise NotImplementedError(provider_type)
Expand Down Expand Up @@ -256,17 +260,24 @@ def available(self):
self._is_available = False
return self._is_available

def cache_artifacts(self):
  """Does nothing: this provider has no artifacts to cache.

  Returns:
    None.
  """
  return None


class ExternalJavaProvider(ExternalProvider):
  """External provider backed by an expansion service started from a jar.

  Args:
    urns: forwarded unchanged to ``ExternalProvider.__init__``.
    jar_provider: zero-argument callable whose result is handed to
      ``external.JavaJarExpansionService``. It is only invoked when the
      service is actually needed or when artifacts are cached.
  """
  def __init__(self, urns, jar_provider):
    super().__init__(
        urns, lambda: external.JavaJarExpansionService(jar_provider()))
    self._jar_provider = jar_provider

  def available(self):
    """Returns whether a ``java`` executable can be found on the PATH."""
    # Imported locally so this class does not depend on a module-level
    # import that may not exist in the file.
    import shutil

    # shutil.which does the PATH lookup in-process, so this no longer fails
    # with FileNotFoundError on systems lacking a `which` binary.
    return shutil.which('java') is not None

  def cache_artifacts(self):
    """Invokes the jar provider now and returns its result in a list."""
    return [self._jar_provider()]


@ExternalProvider.register_provider_type('python')
def python(urns, packages=()):
Expand All @@ -289,6 +300,9 @@ def __init__(self, urns, packages):
def available(self):
  """Returns True unconditionally."""
  # If we're running this script, we have Python installed.
  python_is_installed = True
  return python_is_installed

def cache_artifacts(self):
  """Builds (or reuses) the service's venv and returns it in a list.

  Returns:
    A one-element list holding the result of ``self._service._venv()``.
  """
  venv = self._service._venv()
  return [venv]

def create_external_transform(self, urn, args):
# Python transforms are "registered" by fully qualified name.
return external.ExternalTransform(
Expand Down Expand Up @@ -351,6 +365,9 @@ def __init__(self, transform_factories):
def available(self):
  """Returns True unconditionally: this provider is always usable."""
  is_available = True
  return is_available

def cache_artifacts(self):
  """Does nothing: this provider has no artifacts to cache.

  Returns:
    None.
  """
  return None

def provided_transforms(self):
  """Returns the keys of this provider's transform-factory mapping."""
  factories = self._transform_factories
  return factories.keys()

Expand Down Expand Up @@ -527,23 +544,60 @@ def __init__(self, packages, base_python=sys.executable):
self._packages = packages
self._base_python = base_python

@classmethod
def _key(cls, base_python, packages):
  """Returns a deterministic string identifying a venv configuration.

  Args:
    base_python: the interpreter the venv is created from.
    packages: iterable of requirement strings installed into the venv.

  Returns:
    A JSON string of the interpreter and the sorted package list, so the
    same set of packages yields the same key regardless of order.
  """
  # Only this classmethod form is kept: _path calls it as
  # cls._key(base_python, packages), which a zero-argument instance-method
  # variant of the same name could not serve.
  return json.dumps(
      {'binary': base_python, 'packages': sorted(packages)}, sort_keys=True)

@classmethod
def _path(cls, base_python, packages):
  """Returns the cache directory for the given venv configuration.

  Args:
    base_python: the interpreter the venv is created from.
    packages: iterable of requirement strings installed into the venv.

  Returns:
    A path under ``cls.VENV_CACHE`` whose final component is the SHA-256
    hex digest of ``cls._key(base_python, packages)``.
  """
  # Path computation lives only here; _venv is defined separately and
  # delegates to the venv-creation helpers, which call this method.
  key = cls._key(base_python, packages)
  digest = hashlib.sha256(key.encode('utf-8')).hexdigest()
  return os.path.join(cls.VENV_CACHE, digest)

@classmethod
def _create_venv_from_scratch(cls, base_python, packages):
  """Creates, if not already cached, a venv with the given packages.

  Args:
    base_python: interpreter used to run ``python -m venv``.
    packages: iterable of requirement strings to ``pip install``.

  Returns:
    The venv directory, as computed by ``cls._path``.

  Raises:
    subprocess.CalledProcessError: if any setup command fails. The
      partially built venv is removed before the error propagates.
  """
  # Imported locally so this method does not depend on a module-level
  # import that may not exist in the file.
  import shutil

  # Accept any iterable (e.g. a tuple); the list concatenation below and
  # the requirements file both need a concrete list.
  packages = list(packages)
  venv = cls._path(base_python, packages)
  if not os.path.exists(venv):
    try:
      subprocess.run([base_python, '-m', 'venv', venv], check=True)
      venv_python = os.path.join(venv, 'bin', 'python')
      subprocess.run([venv_python, '-m', 'ensurepip'], check=True)
      # `pip install` with no requirements exits with an error, so skip it
      # when there is nothing to install.
      if packages:
        subprocess.run([venv_python, '-m', 'pip', 'install'] + packages,
                       check=True)
      with open(venv + '-requirements.txt', 'w') as fout:
        fout.write('\n'.join(packages))
    except BaseException:
      # A half-built directory would satisfy the os.path.exists check on
      # the next call and be returned as a valid venv; remove it. This
      # also covers interruption (e.g. KeyboardInterrupt) mid-build.
      shutil.rmtree(venv, ignore_errors=True)
      raise
  return venv

@classmethod
def _create_venv_from_clone(cls, base_python, packages):
  """Creates, if not already cached, a venv by cloning the base venv.

  The base venv from ``cls._create_venv_to_clone`` is copied with
  ``clonevirtualenv`` and the requested packages are installed on top.

  Args:
    base_python: interpreter the base venv is created from.
    packages: iterable of requirement strings to ``pip install``.

  Returns:
    The venv directory, as computed by ``cls._path``.

  Raises:
    subprocess.CalledProcessError: if cloning or installing fails. The
      partially built venv is removed before the error propagates.
  """
  # Imported locally so this method does not depend on a module-level
  # import that may not exist in the file.
  import shutil

  # Accept any iterable (e.g. a tuple); the list concatenation below and
  # the requirements file both need a concrete list.
  packages = list(packages)
  venv = cls._path(base_python, packages)
  if not os.path.exists(venv):
    clonable_venv = cls._create_venv_to_clone(base_python)
    clonable_python = os.path.join(clonable_venv, 'bin', 'python')
    try:
      subprocess.run(
          [clonable_python, '-m', 'clonevirtualenv', clonable_venv, venv],
          check=True)
      venv_binary = os.path.join(venv, 'bin', 'python')
      # `pip install` with no requirements exits with an error, so skip it
      # when there is nothing to install on top of the clone.
      if packages:
        subprocess.run([venv_binary, '-m', 'pip', 'install'] + packages,
                       check=True)
      with open(venv + '-requirements.txt', 'w') as fout:
        fout.write('\n'.join(packages))
    except BaseException:
      # A half-built directory would satisfy the os.path.exists check on
      # the next call and be returned as a valid venv; remove it. This
      # also covers interruption (e.g. KeyboardInterrupt) mid-build.
      shutil.rmtree(venv, ignore_errors=True)
      raise
  return venv

@classmethod
def _create_venv_to_clone(cls, base_python):
  """Builds (or reuses) the base venv that other venvs are cloned from.

  The base venv holds this Beam version with the dataframe, gcp and test
  extras, plus virtualenv-clone.

  Args:
    base_python: interpreter the base venv is created from.

  Returns:
    Whatever ``cls._create_venv_from_scratch`` returns for these packages.
  """
  base_requirements = [
      'apache_beam[dataframe,gcp,test]==' + beam_version,
      'virtualenv-clone',
  ]
  return cls._create_venv_from_scratch(base_python, base_requirements)

def _venv(self):
  """Returns this service's venv, cloning it from the base venv if needed.

  Delegates to ``_create_venv_from_clone`` with this instance's base
  interpreter and package list.
  """
  make_venv = self._create_venv_from_clone
  return make_venv(self._base_python, self._packages)

def __enter__(self):
venv = self._venv()
self._service_provider = subprocess_server.SubprocessServer(
Expand Down

0 comments on commit 141e3e6

Please sign in to comment.