-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
92 lines (82 loc) · 3.32 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Copyright 2022 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, List
import sys
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from merchant.merchant_center_uploader import MerchantCenterUpdaterDoFn
from merchant.offer import BatchOffers, RubikOffer
from utils.logger import logger
from config.read import (
read_from_yaml,
rubik_offer_from_csv_line,
rubik_offer_from_big_query_row,
)
from vision.vision import Vision
"""Rubik's main module.

TODO improvements:
- Create a separate class for Rubik.
- Create getters and setters for the configuration options.
- Give each read option (CSV file, BigQuery table) its own method.
"""
class Rubik:
    """Build and run the Rubik Apache Beam pipeline for one configuration.

    Constructing an instance reads the YAML configuration, builds the
    pipeline, and runs it: the ``with beam.Pipeline(...)`` block runs the
    pipeline when it exits.

    Pipeline stages:
      1. Read offers from a CSV file (``csv_file``) or a BigQuery table
         (``big_query``). ``csv_file`` takes precedence when both are set.
      2. Optionally, when ``vision_ai`` is ``True``, map every offer through
         ``Vision(config).find_best_image``.
      3. Key offers by ``merchant_id``, group them, batch them with
         ``BatchOffers(batch_size)``, and upload each batch with
         ``MerchantCenterUpdaterDoFn``.
    """

    def __init__(self, config_file):
        """Read ``config_file`` and run the pipeline it describes.

        Args:
            config_file: Path to the YAML configuration, passed to
                ``read_from_yaml``.

        Raises:
            ValueError: If neither ``csv_file`` nor ``big_query`` is set.
            KeyError: If a configuration key read by a stage is missing.
        """
        config = read_from_yaml(config_file)
        # With no explicit flags, PipelineOptions parses Beam options from
        # sys.argv.
        pipeline_options = PipelineOptions()
        with beam.Pipeline(options=pipeline_options) as pipeline:
            offers = self._read_offers(pipeline, config)
            if config["vision_ai"] is True:
                # Re-bind so that the downstream stages consume the output of
                # find_best_image. A bare `offers | ...` would create a
                # dangling branch whose results are never used.
                # NOTE(review): assumes find_best_image returns the (updated)
                # offer; confirm in vision.vision.
                offers = offers | "Vision AI to select best image" >> beam.Map(
                    Vision(config).find_best_image
                )
            self._upload_offers(offers, config)

    @staticmethod
    def _read_offers(pipeline, config):
        """Attach the configured source to ``pipeline`` and return its offers.

        Args:
            pipeline: The ``beam.Pipeline`` under construction.
            config: Parsed configuration. ``csv_file`` is checked before
                ``big_query``.

        Returns:
            A PCollection of the objects produced by
            ``rubik_offer_from_csv_line`` or ``rubik_offer_from_big_query_row``.

        Raises:
            ValueError: If both ``csv_file`` and ``big_query`` are ``None``.
        """
        if config["csv_file"] is not None:
            return (
                pipeline
                | "Load rows" >> beam.io.ReadFromText(config["csv_file"])
                | "Map lines to objects" >> beam.Map(rubik_offer_from_csv_line)
            )
        if config["big_query"] is not None:
            return (
                pipeline
                | "Load rows"
                >> beam.io.gcp.bigquery.ReadFromBigQuery(
                    table=config["big_query"],
                    gcs_location=config["big_query_gcs_location"],
                )
                | "Map rows to objects" >> beam.Map(rubik_offer_from_big_query_row)
            )
        # Without this check, the pipeline would later fail on an unbound
        # variable instead of reporting the real configuration problem.
        raise ValueError(
            "Rubik configuration must set either 'csv_file' or 'big_query'"
        )

    @staticmethod
    def _upload_offers(offers, config):
        """Group ``offers`` by merchant, batch them, and upload each batch.

        Args:
            offers: PCollection of offers exposing a ``merchant_id`` attribute.
            config: Parsed configuration providing ``batch_size`` and the
                Merchant Center credentials and custom label.
        """
        (
            offers
            | "Build Tuples"
            >> beam.Map(lambda product: (product.merchant_id, product))
            | "Group by Merchant Id" >> beam.GroupByKey()
            | "Batch elements"
            >> beam.ParDo(BatchOffers(config["batch_size"])).with_output_types(
                Tuple[str, List[RubikOffer]]
            )
            | "Upload to Merchant Center"
            >> beam.ParDo(
                MerchantCenterUpdaterDoFn(
                    config["client_id"],
                    config["client_secret"],
                    config["access_token"],
                    config["refresh_token"],
                    config["rubik_custom_label"],
                )
            )
        )
if __name__ == "__main__":
    # The first positional argument is the YAML config path. Any remaining
    # arguments stay in sys.argv, where PipelineOptions() can parse them.
    if len(sys.argv) < 2:
        # Exit with a usage message (status 1) instead of an IndexError
        # traceback.
        sys.exit(f"Usage: {sys.argv[0]} <config_file>")
    logger().info("Starting Rubik execution")
    Rubik(sys.argv[1])