Tasks for loading GA data into Snowflake (PART 1)
This is part 1 of the GA loading pipeline, which DOES NOT depend on a
Luigi upgrade.

DE-1374 (PART 1)
pwnage101 committed Apr 23, 2019
1 parent 6db66ef commit d3c5ac1
Showing 5 changed files with 208 additions and 74 deletions.
6 changes: 3 additions & 3 deletions edx/analytics/tasks/common/bigquery_load.py
@@ -53,9 +53,9 @@ def __init__(self, credentials_target, dataset_id, table, update_id):
self.update_id = update_id
with credentials_target.open('r') as credentials_file:
json_creds = json.load(credentials_file)
self.project_id = json_creds['project_id']
credentials = service_account.Credentials.from_service_account_info(json_creds)
self.client = bigquery.Client(credentials=credentials, project=self.project_id)
self.project_id = json_creds['project_id']
credentials = service_account.Credentials.from_service_account_info(json_creds)
self.client = bigquery.Client(credentials=credentials, project=self.project_id)

def touch(self):
self.create_marker_table()
169 changes: 125 additions & 44 deletions edx/analytics/tasks/common/snowflake_load.py
@@ -11,6 +11,7 @@
from snowflake.connector import ProgrammingError

from edx.analytics.tasks.util.overwrite import OverwriteOutputMixin
from edx.analytics.tasks.util.s3_util import canonicalize_s3_url
from edx.analytics.tasks.util.url import ExternalURL

log = logging.getLogger(__name__)
@@ -90,9 +91,22 @@ def touch(self, connection):
self.create_marker_table()

connection.cursor().execute(
"""INSERT INTO {database}.{schema}.{marker_table} (update_id, target_table)
VALUES (%s, %s)""".format(database=self.database, schema=self.schema, marker_table=self.marker_table),
(self.update_id, "{database}.{schema}.{table}".format(database=self.database, schema=self.schema, table=self.table))
"""
INSERT INTO {database}.{schema}.{marker_table} (update_id, target_table)
VALUES (%s, %s)
""".format(
database=self.database,
schema=self.schema,
marker_table=self.marker_table,
),
(
self.update_id,
"{database}.{schema}.{table}".format(
database=self.database,
schema=self.schema,
table=self.table,
),
)
)

# make sure update is properly marked
@@ -112,8 +126,17 @@ def exists(self, connection=None):
return False

cursor = connection.cursor()
query = "SELECT 1 FROM {database}.{schema}.{marker_table} WHERE update_id='{update_id}' AND target_table='{database}.{schema}.{table}'".format(
database=self.database, schema=self.schema, marker_table=self.marker_table, update_id=self.update_id, table=self.table)
query = """
SELECT 1
FROM {database}.{schema}.{marker_table}
WHERE update_id='{update_id}' AND target_table='{database}.{schema}.{table}'
""".format(
database=self.database,
schema=self.schema,
marker_table=self.marker_table,
update_id=self.update_id,
table=self.table,
)
log.debug(query)
cursor.execute(query)
row = cursor.fetchone()
@@ -135,8 +158,8 @@ def marker_table_exists(self, connection):
schema=self.schema,
))
row = cursor.fetchone()
except ProgrammingError as e:
if "does not exist" in e.msg:
except ProgrammingError as err:
if "does not exist" in err.msg:
# If so, the query failed because the database or schema doesn't exist.
row = None
else:
@@ -169,8 +192,14 @@ def clear_marker_table(self, connection):
Delete all markers related to this table update.
"""
if self.marker_table_exists(connection):
query = "DELETE FROM {database}.{schema}.{marker_table} where target_table='{database}.{schema}.{table}'".format(
database=self.database, schema=self.schema, marker_table=self.marker_table, table=self.table,
query = """
DELETE FROM {database}.{schema}.{marker_table}
WHERE target_table='{database}.{schema}.{table}'
""".format(
database=self.database,
schema=self.schema,
marker_table=self.marker_table,
table=self.table,
)
connection.cursor().execute(query)

@@ -233,20 +262,6 @@ def columns(self):
def file_format_name(self):
raise NotImplementedError

@property
def field_delimiter(self):
"""
The delimiter in the data to be copied. Default is tab (\t).
"""
return "\t"

@property
def null_marker(self):
"""
The null sequence in the data to be copied. Default is Hive NULL (\\N).
"""
return r'\\N'

@property
def pattern(self):
"""
@@ -260,7 +275,11 @@ def create_database(self, connection):

def create_schema(self, connection):
cursor = connection.cursor()
cursor.execute("CREATE SCHEMA IF NOT EXISTS {database}.{schema}".format(database=self.database, schema=self.schema))
cursor.execute(
"CREATE SCHEMA IF NOT EXISTS {database}.{schema}".format(
database=self.database, schema=self.schema,
)
)

def create_table(self, connection):
coldefs = ','.join(
@@ -273,29 +292,15 @@ def create_table(self, connection):

def create_format(self, connection):
"""
Creates a named file format used for bulk loading data into Snowflake tables.
Implementations should invoke Snowflake's CREATE FILE FORMAT statement to create the named
file format that configures the load.
The resulting file format name should be: {self.database}.{self.schema}.{self.file_format_name}
"""
query = """
CREATE OR REPLACE FILE FORMAT {database}.{schema}.{file_format_name}
TYPE = 'CSV' COMPRESSION = 'AUTO' FIELD_DELIMITER = '{field_delimiter}'
FIELD_OPTIONALLY_ENCLOSED_BY = 'NONE' ERROR_ON_COLUMN_COUNT_MISMATCH = TRUE
EMPTY_FIELD_AS_NULL = FALSE ESCAPE_UNENCLOSED_FIELD = 'NONE'
NULL_IF = ('{null_marker}')
""".format(
database=self.database,
schema=self.schema,
file_format_name=self.file_format_name,
field_delimiter=self.field_delimiter,
null_marker=self.null_marker,
)
log.debug(query)
connection.cursor().execute(query)
raise NotImplementedError

def create_stage(self, connection):
"""
Creates a named external stage to use for loading data into Snowflake.
"""
stage_url = self.input()['insert_source_task'].path
stage_url = canonicalize_s3_url(self.input()['insert_source_task'].path)
query = """
CREATE OR REPLACE STAGE {database}.{schema}.{table}_stage
URL = '{stage_url}'
@@ -377,3 +382,79 @@ def output(self):

def update_id(self):
return '{task_name}(date={key})'.format(task_name=self.task_family, key=self.date.isoformat())


class SnowflakeLoadFromHiveTSVTask(SnowflakeLoadTask): # pylint: disable=abstract-method
"""
Abstract Task for loading Hive-exported TSV data from S3 into a table in Snowflake.
Implementations should define the following properties:
- self.insert_source_task
- self.table
- self.columns
- self.file_format_name
"""

@property
def field_delimiter(self):
"""
The delimiter in the data to be copied. Default is tab (\t).
"""
return "\t"

@property
def null_marker(self):
"""
The null sequence in the data to be copied. Default is Hive NULL (\\N).
"""
return r'\\N'

def create_format(self, connection):
query = """
CREATE OR REPLACE FILE FORMAT {database}.{schema}.{file_format_name}
TYPE = 'CSV' COMPRESSION = 'AUTO' FIELD_DELIMITER = '{field_delimiter}'
FIELD_OPTIONALLY_ENCLOSED_BY = 'NONE' ERROR_ON_COLUMN_COUNT_MISMATCH = TRUE
EMPTY_FIELD_AS_NULL = FALSE ESCAPE_UNENCLOSED_FIELD = 'NONE'
NULL_IF = ('{null_marker}')
""".format(
database=self.database,
schema=self.schema,
file_format_name=self.file_format_name,
field_delimiter=self.field_delimiter,
null_marker=self.null_marker,
)
log.debug(query)
connection.cursor().execute(query)


class SnowflakeLoadJSONTask(SnowflakeLoadTask): # pylint: disable=abstract-method
"""
Abstract Task for loading JSON data from S3 into a table in Snowflake. The resulting table will
contain a single VARIANT column called raw_json.
Implementations should define the following properties:
- self.insert_source_task
- self.table
- self.file_format_name
"""

@property
def columns(self):
return [
('raw_json', 'VARIANT'),
]

def create_format(self, connection):
query = """
CREATE OR REPLACE FILE FORMAT {database}.{schema}.{file_format_name}
TYPE = 'JSON'
COMPRESSION = 'AUTO'
""".format(
database=self.database,
schema=self.schema,
file_format_name=self.file_format_name,
)
log.debug(query)
connection.cursor().execute(query)
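
For context, a minimal sketch of how a pipeline task might subclass the new SnowflakeLoadFromHiveTSVTask follows. This is not part of the commit: the task name, S3 path, table, and column list are hypothetical, and it assumes insert_source_task returns a Luigi task (here an ExternalURL) whose output path becomes the external stage URL, as the abstract class's create_stage() expects.

# Sketch only -- not part of this commit.  All concrete values below are made up
# for illustration.
from edx.analytics.tasks.common.snowflake_load import SnowflakeLoadFromHiveTSVTask
from edx.analytics.tasks.util.url import ExternalURL


class ExampleGaSessionsSnowflakeLoadTask(SnowflakeLoadFromHiveTSVTask):
    """Hypothetical task: load a Hive TSV export of GA sessions into Snowflake."""

    @property
    def insert_source_task(self):
        # The task whose output path is used as the external stage URL.
        return ExternalURL(url='s3://example-bucket/warehouse/ga_sessions/dt=2019-04-22/')

    @property
    def table(self):
        return 'ga_sessions'

    @property
    def columns(self):
        # (column name, Snowflake type) pairs, as consumed by create_table().
        return [
            ('session_id', 'VARCHAR'),
            ('course_id', 'VARCHAR'),
            ('pageviews', 'NUMBER'),
        ]

    @property
    def file_format_name(self):
        return 'ga_sessions_tsv_format'
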
43 changes: 33 additions & 10 deletions edx/analytics/tasks/util/s3_util.py
@@ -5,7 +5,7 @@
import os
import time
from fnmatch import fnmatch
from urlparse import urlparse
from urlparse import urlparse, urlunparse

from luigi.contrib.hdfs.format import Plain
from luigi.contrib.hdfs.target import HdfsTarget
@@ -122,22 +122,19 @@ def func(name):
return (n for n in names if func(n))


# TODO: Once we upgrade to boto3 (luigi>=2.7.6), delete this class! In boto3/luigi>=2.7.6, we must
# NOT pass `host` to the s3 client or else it will throw a KeyError. boto3 will already default to
# s3.amazonaws.com.
class ScalableS3Client(S3Client):
"""
S3 client that adds support for defaulting host name.
S3 client that adds support for defaulting host name to s3.amazonaws.com.
"""
# TODO: Make this behavior configurable and submit this change upstream.

def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, **kwargs):

if not aws_access_key_id:
aws_access_key_id = self._get_s3_config('aws_access_key_id')
if not aws_secret_access_key:
aws_secret_access_key = self._get_s3_config('aws_secret_access_key')
def __init__(self, *args, **kwargs):
if 'host' not in kwargs:
kwargs['host'] = self._get_s3_config('host') or 's3.amazonaws.com'

super(ScalableS3Client, self).__init__(aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, **kwargs)
super(ScalableS3Client, self).__init__(*args, **kwargs)


class S3HdfsTarget(HdfsTarget):
@@ -162,3 +159,29 @@ def open(self, mode='r'):
if not hasattr(self, 's3_client'):
self.s3_client = ScalableS3Client()
return AtomicS3File(safe_path, self.s3_client, policy=DEFAULT_KEY_ACCESS_POLICY)


def canonicalize_s3_url(url):
"""
Convert the given s3 URL into a form which is safe to use with external tools.
Specifically, URL schemes such as "s3+https" are not recognized by gsutil and Snowflake, and must
be converted to "s3".
Args:
url (str): An s3 URL.
Raises:
ValueError: If the scheme of the input URL is not recognized as an S3 scheme.
"""
parsed_url = urlparse(url)
if parsed_url.scheme == 's3':
canonical_url = url # Simple passthrough, no change needed.
elif parsed_url.scheme == 's3+https':
new_url_parts = parsed_url._replace(scheme='s3')
canonical_url = urlunparse(new_url_parts)
else:
raise ValueError(
'The URL scheme "{}" does not appear to be an S3 URL scheme.'.format(parsed_url.scheme)
)
return canonical_url
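
For illustration (not part of this diff), the new helper behaves as follows; the bucket and key names are hypothetical:

from edx.analytics.tasks.util.s3_util import canonicalize_s3_url

canonicalize_s3_url('s3+https://example-bucket/ga/sessions.tsv.gz')  # returns 's3://example-bucket/ga/sessions.tsv.gz'
canonicalize_s3_url('s3://example-bucket/ga/sessions.tsv.gz')        # passes through unchanged
canonicalize_s3_url('hdfs://example-bucket/ga/sessions.tsv.gz')      # raises ValueError
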
42 changes: 36 additions & 6 deletions edx/analytics/tasks/util/tests/test_s3_util.py
@@ -1,31 +1,41 @@
"""Tests for S3-related utility functionality."""
"""
Tests for S3-related utility functionality.
"""
from __future__ import print_function

from unittest import TestCase

from ddt import data, ddt, unpack
from mock import MagicMock

from edx.analytics.tasks.util import s3_util


class GenerateS3SourcesTestCase(TestCase):
"""Tests for generate_s3_sources()."""
"""
Tests for generate_s3_sources().
"""

def _make_key(self, keyname, size):
"""Makes a dummy key object, providing the necessary accessors."""
"""
Makes a dummy key object, providing the necessary accessors.
"""
s3_key = MagicMock()
s3_key.key = keyname
s3_key.size = size
return s3_key

def _make_s3_generator(self, bucket_name, root, path_info, patterns):
"""Generates a list of matching S3 sources using a mock S3 connection."""
"""
Generates a list of matching S3 sources using a mock S3 connection.
"""
s3_conn = MagicMock()
s3_bucket = MagicMock()
s3_conn.get_bucket = MagicMock(return_value=s3_bucket)
target_list = [self._make_key("{root}/{path}".format(root=root, path=path), size)
for path, size in path_info.iteritems()]
s3_bucket.list = MagicMock(return_value=target_list)
print [(k.key, k.size) for k in target_list]
print([(k.key, k.size) for k in target_list])

s3_bucket.name = bucket_name
source = "s3://{bucket}/{root}".format(bucket=bucket_name, root=root)
@@ -34,7 +44,9 @@ def _make_s3_generator(self, bucket_name, root, path_info, patterns):
return output

def _run_without_filtering(self, bucket_name, root, path_info):
"""Runs generator and checks output."""
"""
Runs generator and checks output.
"""
patterns = ['*']
output = self._make_s3_generator(bucket_name, root, path_info, patterns)
self.assertEquals(len(output), len(path_info))
Expand Down Expand Up @@ -97,3 +109,21 @@ def test_generate_with_trailing_slash(self):
(bucket_name, root.rstrip('/'), "subdir1/path1"),
(bucket_name, root.rstrip('/'), "path2")
]))


@ddt
class CanonicalizeS3URLTestCase(TestCase):
"""
Tests for canonicalize_s3_url().
"""

@data(
('s3://hello/world', 's3://hello/world'),
('s3+https://hello/world', 's3://hello/world'),
)
@unpack
def test_canonicalize_s3_url(self, original_url, canonicalized_url):
self.assertEquals(
s3_util.canonicalize_s3_url(original_url),
canonicalized_url,
)
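
The new tests cover the passthrough and "s3+https" cases; the ValueError path for non-S3 schemes is not exercised. A sketch of such a test (not part of this commit) could be added to CanonicalizeS3URLTestCase:

    def test_canonicalize_s3_url_rejects_non_s3_scheme(self):
        with self.assertRaises(ValueError):
            s3_util.canonicalize_s3_url('hdfs://hello/world')
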