This repository has been archived by the owner on May 1, 2024. It is now read-only.

Tasks for loading GA data into Snowflake (PART 1)
This is part 1 of the GA loading pipeline which DOES NOT depend on a
Luigi upgrade.

DE-1374 (PART 1)
pwnage101 committed Apr 19, 2019
1 parent ad0d76e commit b51a455
Showing 7 changed files with 181 additions and 74 deletions.
1 change: 1 addition & 0 deletions Makefile
@@ -76,6 +76,7 @@ test-acceptance-local-all:
REMOTE_TASK=$(shell which remote-task) LUIGI_CONFIG_PATH='config/test.cfg' ACCEPTANCE_TEST_CONFIG="/var/tmp/acceptance.json" python -m coverage run --rcfile=./.coveragerc -m nose --nocapture --with-xunit -A acceptance -v

quality-docker-local:
bash -c 'source ${ANALYTICS_PIPELINE_VENV}/analytics_pipeline/bin/activate && pip install -r requirements/test.txt'
bash -c 'source ${ANALYTICS_PIPELINE_VENV}/analytics_pipeline/bin/activate && isort --check-only --recursive edx/'
pycodestyle edx

6 changes: 3 additions & 3 deletions edx/analytics/tasks/common/bigquery_load.py
@@ -53,9 +53,9 @@ def __init__(self, credentials_target, dataset_id, table, update_id):
self.update_id = update_id
with credentials_target.open('r') as credentials_file:
json_creds = json.load(credentials_file)
self.project_id = json_creds['project_id']
credentials = service_account.Credentials.from_service_account_info(json_creds)
self.client = bigquery.Client(credentials=credentials, project=self.project_id)
self.project_id = json_creds['project_id']
credentials = service_account.Credentials.from_service_account_info(json_creds)
self.client = bigquery.Client(credentials=credentials, project=self.project_id)

def touch(self):
self.create_marker_table()
173 changes: 127 additions & 46 deletions edx/analytics/tasks/common/snowflake_load.py
@@ -5,13 +5,14 @@
import logging

import luigi
import snowflake.connector
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from snowflake.connector import ProgrammingError

import snowflake.connector
from edx.analytics.tasks.util.overwrite import OverwriteOutputMixin
from edx.analytics.tasks.util.s3_util import canonicalize_s3_url
from edx.analytics.tasks.util.url import ExternalURL
from snowflake.connector import ProgrammingError

log = logging.getLogger(__name__)

@@ -90,9 +91,22 @@ def touch(self, connection):
self.create_marker_table()

connection.cursor().execute(
"""INSERT INTO {database}.{schema}.{marker_table} (update_id, target_table)
VALUES (%s, %s)""".format(database=self.database, schema=self.schema, marker_table=self.marker_table),
(self.update_id, "{database}.{schema}.{table}".format(database=self.database, schema=self.schema, table=self.table))
"""
INSERT INTO {database}.{schema}.{marker_table} (update_id, target_table)
VALUES (%s, %s)
""".format(
database=self.database,
schema=self.schema,
marker_table=self.marker_table,
),
(
self.update_id,
"{database}.{schema}.{table}".format(
database=self.database,
schema=self.schema,
table=self.table,
),
)
)

# make sure update is properly marked
@@ -112,8 +126,17 @@ def exists(self, connection=None):
return False

cursor = connection.cursor()
query = "SELECT 1 FROM {database}.{schema}.{marker_table} WHERE update_id='{update_id}' AND target_table='{database}.{schema}.{table}'".format(
database=self.database, schema=self.schema, marker_table=self.marker_table, update_id=self.update_id, table=self.table)
query = """
SELECT 1
FROM {database}.{schema}.{marker_table}
WHERE update_id='{update_id}' AND target_table='{database}.{schema}.{table}'
""".format(
database=self.database,
schema=self.schema,
marker_table=self.marker_table,
update_id=self.update_id,
table=self.table,
)
log.debug(query)
cursor.execute(query)
row = cursor.fetchone()
@@ -135,8 +158,8 @@ def marker_table_exists(self, connection):
schema=self.schema,
))
row = cursor.fetchone()
except ProgrammingError as e:
if "does not exist" in e.msg:
except ProgrammingError as err:
if "does not exist" in err.msg:
# If so then the query failed because the database or schema doesn't exist.
row = None
else:
@@ -169,8 +192,14 @@ def clear_marker_table(self, connection):
Delete all markers related to this table update.
"""
if self.marker_table_exists(connection):
query = "DELETE FROM {database}.{schema}.{marker_table} where target_table='{database}.{schema}.{table}'".format(
database=self.database, schema=self.schema, marker_table=self.marker_table, table=self.table,
query = """
DELETE FROM {database}.{schema}.{marker_table}
WHERE target_table='{database}.{schema}.{table}'
""".format(
database=self.database,
schema=self.schema,
marker_table=self.marker_table,
table=self.table,
)
connection.cursor().execute(query)

@@ -233,20 +262,6 @@ def columns(self):
def file_format_name(self):
raise NotImplementedError

@property
def field_delimiter(self):
"""
The delimiter in the data to be copied. Default is tab (\t).
"""
return "\t"

@property
def null_marker(self):
"""
The null sequence in the data to be copied. Default is Hive NULL (\\N).
"""
return r'\\N'

@property
def pattern(self):
"""
@@ -260,7 +275,11 @@ def create_database(self, connection):

def create_schema(self, connection):
cursor = connection.cursor()
cursor.execute("CREATE SCHEMA IF NOT EXISTS {database}.{schema}".format(database=self.database, schema=self.schema))
cursor.execute(
"CREATE SCHEMA IF NOT EXISTS {database}.{schema}".format(
database=self.database, schema=self.schema,
)
)

def create_table(self, connection):
coldefs = ','.join(
@@ -273,29 +292,15 @@ def create_table(self, connection):

def create_format(self, connection):
"""
Creates a named file format used for bulk loading data into Snowflake tables.
Invoke Snowflake's CREATE FILE FORMAT statement to create the named file format which
configures the loading.
The resulting file format name should be: {self.database}.{self.schema}.{self.file_format_name}
"""
query = """
CREATE OR REPLACE FILE FORMAT {database}.{schema}.{file_format_name}
TYPE = 'CSV' COMPRESSION = 'AUTO' FIELD_DELIMITER = '{field_delimiter}'
FIELD_OPTIONALLY_ENCLOSED_BY = 'NONE' ERROR_ON_COLUMN_COUNT_MISMATCH = TRUE
EMPTY_FIELD_AS_NULL = FALSE ESCAPE_UNENCLOSED_FIELD = 'NONE'
NULL_IF = ('{null_marker}')
""".format(
database=self.database,
schema=self.schema,
file_format_name=self.file_format_name,
field_delimiter=self.field_delimiter,
null_marker=self.null_marker,
)
log.debug(query)
connection.cursor().execute(query)
raise NotImplementedError

def create_stage(self, connection):
"""
Creates a named external stage to use for loading data into Snowflake.
"""
stage_url = self.input()['insert_source_task'].path
stage_url = canonicalize_s3_url(self.input()['insert_source_task'].path)
query = """
CREATE OR REPLACE STAGE {database}.{schema}.{table}_stage
URL = '{stage_url}'
@@ -377,3 +382,79 @@ def output(self):

def update_id(self):
return '{task_name}(date={key})'.format(task_name=self.task_family, key=self.date.isoformat())


class SnowflakeLoadCSVTask(SnowflakeLoadTask): # pylint: disable=abstract-method
"""
Abstract Task for loading CSV data from s3 into a table in Snowflake.
Implementations should define the following properties:
- self.insert_source_task
- self.table
- self.columns
- self.file_format_name
"""

@property
def field_delimiter(self):
"""
The delimiter in the data to be copied. Default is tab (\t).
"""
return "\t"

@property
def null_marker(self):
"""
The null sequence in the data to be copied. Default is Hive NULL (\\N).
"""
return r'\\N'

def create_format(self, connection):
query = """
CREATE OR REPLACE FILE FORMAT {database}.{schema}.{file_format_name}
TYPE = 'CSV' COMPRESSION = 'AUTO' FIELD_DELIMITER = '{field_delimiter}'
FIELD_OPTIONALLY_ENCLOSED_BY = 'NONE' ERROR_ON_COLUMN_COUNT_MISMATCH = TRUE
EMPTY_FIELD_AS_NULL = FALSE ESCAPE_UNENCLOSED_FIELD = 'NONE'
NULL_IF = ('{null_marker}')
""".format(
database=self.database,
schema=self.schema,
file_format_name=self.file_format_name,
field_delimiter=self.field_delimiter,
null_marker=self.null_marker,
)
log.debug(query)
connection.cursor().execute(query)


class SnowflakeLoadJSONTask(SnowflakeLoadTask): # pylint: disable=abstract-method
"""
Abstract Task for loading JSON data from s3 into a table in Snowflake. The resulting table will
contain a single VARIANT column called raw_json.
Implementations should define the following properties:
- self.insert_source_task
- self.table
- self.file_format_name
"""

@property
def columns(self):
return [
('raw_json', 'VARIANT'),
]

def create_format(self, connection):
query = """
CREATE OR REPLACE FILE FORMAT {database}.{schema}.{file_format_name}
TYPE = 'JSON'
COMPRESSION = 'AUTO'
""".format(
database=self.database,
schema=self.schema,
file_format_name=self.file_format_name,
)
log.debug(query)
connection.cursor().execute(query)
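
For reference, a concrete loader built on these abstract tasks only needs to name its source data, target table, columns, and file format, per the class docstrings above. The sketch below is illustrative only and is not part of this commit; the task name, S3 path, table, and columns are all hypothetical:

from edx.analytics.tasks.common.snowflake_load import SnowflakeLoadCSVTask
from edx.analytics.tasks.util.url import ExternalURL


class LoadExampleGaDataToSnowflake(SnowflakeLoadCSVTask):  # hypothetical example, not in this commit
    """Load one day of tab-delimited GA data from S3 into a Snowflake table."""

    @property
    def insert_source_task(self):
        # The external stage created by create_stage() will point at this (made-up) URL.
        return ExternalURL(url='s3://example-bucket/ga_data/dt=2019-04-19/')

    @property
    def table(self):
        return 'example_ga_data'

    @property
    def file_format_name(self):
        return 'example_ga_data_tsv_format'

    @property
    def columns(self):
        return [
            ('session_id', 'VARCHAR'),
            ('course_id', 'VARCHAR'),
            ('pageviews', 'NUMBER'),
        ]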
44 changes: 33 additions & 11 deletions edx/analytics/tasks/util/s3_util.py
@@ -5,7 +5,7 @@
import os
import time
from fnmatch import fnmatch
from urlparse import urlparse
from urlparse import urlparse, urlunparse

from luigi.contrib.hdfs.format import Plain
from luigi.contrib.hdfs.target import HdfsTarget
@@ -125,19 +125,17 @@ def func(name):
class ScalableS3Client(S3Client):
    """
    S3 client that adds support for defaulting host name.
    """
    # TODO: Make this behavior configurable and submit this change upstream.
    def __init__(self, aws_access_key_id=None, aws_secret_access_key=None, **kwargs):
        if not aws_access_key_id:
            aws_access_key_id = self._get_s3_config('aws_access_key_id')
        if not aws_secret_access_key:
            aws_secret_access_key = self._get_s3_config('aws_secret_access_key')
        if 'host' not in kwargs:
            kwargs['host'] = self._get_s3_config('host') or 's3.amazonaws.com'
        super(ScalableS3Client, self).__init__(aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, **kwargs)

    DEPRECATED: Just specify `host` in the `[s3]` configuration section, e.g.:

      [s3]
      host = s3.amazonaws.com

    NOTE: In future versions of Luigi, we must NOT pass `host` to the s3 client
    or else it will throw a KeyError. boto3 will already default to
    s3.amazonaws.com.
    """
    pass


class S3HdfsTarget(HdfsTarget):
@@ -162,3 +160,27 @@ def open(self, mode='r'):
if not hasattr(self, 's3_client'):
self.s3_client = ScalableS3Client()
return AtomicS3File(safe_path, self.s3_client, policy=DEFAULT_KEY_ACCESS_POLICY)


def canonicalize_s3_url(url):
"""
Convert the given s3 URL into a form which is safe to use with external tools.
Specifically, URL schemes such as "s3+https" are unrecognized by gsutil and Snowflake, and must
be converted to "s3".
Args:
url (str): An s3 URL.
Raises:
ValueError: if the scheme of the input url is unrecognized.
"""
parsed_url = urlparse(url)
if parsed_url.scheme == 's3':
canonical_url = url # Simple passthrough, no change needed.
elif parsed_url.scheme == 's3+https':
new_url_parts = parsed_url._replace(scheme='s3')
canonical_url = urlunparse(new_url_parts)
else:
raise ValueError('The S3 URL scheme "{}" is unrecognized.'.format(parsed_url.scheme))
return canonical_url
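
To illustrate the intended behavior of canonicalize_s3_url, a quick sketch (the bucket and key names below are made up):

from edx.analytics.tasks.util.s3_util import canonicalize_s3_url

# The Luigi-style "s3+https" scheme is rewritten so that gsutil and Snowflake accept the URL.
print(canonicalize_s3_url('s3+https://example-bucket/ga_data/dt=2019-04-19/data.tsv'))
# -> s3://example-bucket/ga_data/dt=2019-04-19/data.tsv

# Plain "s3" URLs pass through unchanged.
print(canonicalize_s3_url('s3://example-bucket/ga_data/dt=2019-04-19/data.tsv'))
# -> s3://example-bucket/ga_data/dt=2019-04-19/data.tsv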
7 changes: 5 additions & 2 deletions edx/analytics/tasks/util/tests/test_s3_util.py
@@ -1,4 +1,7 @@
"""Tests for S3-related utility functionality."""
"""
Tests for S3-related utility functionality.
"""
from __future__ import print_function

from unittest import TestCase

@@ -25,7 +28,7 @@ def _make_s3_generator(self, bucket_name, root, path_info, patterns):
target_list = [self._make_key("{root}/{path}".format(root=root, path=path), size)
for path, size in path_info.iteritems()]
s3_bucket.list = MagicMock(return_value=target_list)
print [(k.key, k.size) for k in target_list]
print([(k.key, k.size) for k in target_list])

s3_bucket.name = bucket_name
source = "s3://{bucket}/{root}".format(bucket=bucket_name, root=root)
2 changes: 1 addition & 1 deletion edx/analytics/tasks/warehouse/load_ga_permissions.py
@@ -7,9 +7,9 @@
import logging

import luigi
from apiclient.discovery import build
from google.oauth2 import service_account

from apiclient.discovery import build
from edx.analytics.tasks.common.vertica_load import VerticaCopyTask, VerticaCopyTaskMixin
from edx.analytics.tasks.util.hive import WarehouseMixin
from edx.analytics.tasks.util.overwrite import OverwriteOutputMixin