#!/usr/bin/env python3
#
# Usage: python ray_pytorch_experiments.py <s3src> <dstfile> <field_aware: yes|no>
#
# Runs all PyTorch experiments, pulling datasets from S3 on demand.
# <s3src> is the S3 key prefix for the .tar.zst archives.
#
# The script runs "incrementally": finished experiments are recorded in
# <dstfile>, and matching records are skipped on later runs, so successful
# runs are idempotent.
#
# If <field_aware> is "yes", a different set of experiments is run, focused
# on field-aware methods.
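#
# Example invocation (the S3 prefix and output path below are hypothetical):
#
#   python ray_pytorch_experiments.py s3://my-bucket/svm-binaries/ results.jsonl no
#
# Note that <s3src> is concatenated directly with the archive basename, so it
# should include its trailing delimiter (e.g. end with '/').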
import traceback
import ray
ray.init()
import os
os.makedirs('svms-data', exist_ok=True) # data dir
DATA_DIR = os.path.abspath("svms-data")
import sys
assert len(sys.argv) == 4, sys.argv
s3src = sys.argv[1]
dstfile = sys.argv[2]
field_aware = sys.argv[3] == "yes"
records = []
import shutil
if os.path.exists(dstfile):
    print('NOTE: dstfile {} exists, so will re-use previous runs'.format(dstfile))
    import json
    with open(dstfile, 'r') as f:
        for line in f:
            records.append(json.loads(line))
tups = set(tuple(record[k] for k in ('budget', 'dataset', 'compress', 'learner'))
           for record in records)
print('found {} existing records'.format(len(tups)))
for tup in tups:
    print(' {}'.format(tup))
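# Each line of dstfile is a JSON record; the de-duplication key is the
# (budget, dataset, compress, learner) tuple. Any other fields are whatever
# run_torchfm.main returns, which this script treats as opaque. A hypothetical
# record might look like:
#   {"budget": 1024, "dataset": "url", "compress": "sm", "learner": "lr", ...}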
def any_match(tup):
    # Return the previously recorded result for this (budget, dataset,
    # compress, learner) tuple, or None if it still needs to be run.
    for record in records:
        rtup = tuple(record[k] for k in ('budget', 'dataset', 'compress', 'learner'))
        if rtup == tup:
            print('found record for {}\n'.format(tup), end='')
            return record
    print('running {}\n'.format(tup), end='')
    return None
from run_torchfm import main

def logged_main(cmd_args):
    # Run one experiment, skipping it if a matching record already exists.
    # Returns (result_or_exception, tup) and never raises, so the driver
    # loop can log failures without losing other in-flight tasks.
    budget, dataset, compress = cmd_args[0:3]
    learner = cmd_args[-1]
    tup = (budget, dataset, compress, learner)
    d = any_match(tup)
    if d is not None:
        return (d, tup)
    try:
        return (main(*cmd_args), tup)
    except Exception as e:
        return (e, tup)
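# GPU scheduling: field-aware runs get a whole GPU per task, otherwise two
# tasks share one GPU (num_gpus=0.5). max_calls=1 makes Ray start a fresh
# worker process for every task, so GPU memory is fully released between runs.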
ngpus = 1 if field_aware else 0.5

@ray.remote(num_cpus=1, num_gpus=ngpus, max_calls=1)
def runner(cmd_args):
    return logged_main(cmd_args)
@ray.remote
def cleanup(bash_glob, dependencies):
    # Wait for every training task that reads these files to finish, then
    # delete the extracted binary shards to reclaim disk space.
    from subprocess import check_call  # import locally so the Ray worker process has it
    ray.wait(dependencies, num_returns=len(dependencies), timeout=None)
    # bash (not sh) is needed so the {a,b}-style brace expansion in the glob works.
    check_call(f"""bash -c 'rm -f {bash_glob}'""", shell=True)
if field_aware:
    # ['ft'] OOMs or doesn't finish
    suffixes = ['faft']
    models = ['wd']
    budgets = [256 * 1024, 1024 * 1024]
else:
    suffixes = ['sm', 'ft', 'te', 'ht']
    budgets = [1024]
    models = ['lr', 'wd', 'fm', 'nfm', 'dfm']
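# With the non-field-aware settings this is 4 suffixes x 5 models = 20 tasks
# per (budget, dataset) pair; the field-aware grid is 1 x 1 per pair but runs
# at two budgets.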
def launch_grid(budget, dataset):
    # Launch one runner task per (suffix, model) and, per suffix, a cleanup
    # task that removes that suffix's binary files once its runners finish.
    futures = []
    for suffix in suffixes:
        suffix_futures = []
        for model in models:
            cputhreads = 1
            quiet = "yes"
            cuda = "cuda"
            args = [budget, dataset, suffix, quiet, cputhreads, cuda, model]
            suffix_futures.append(runner.remote(args))
        rmglob = f'{DATA_DIR}/{dataset}.{{train,test}}.{suffix}{budget}.{{data,indices,indptr,y}}.bin'
        cleanup_fut = cleanup.remote(rmglob, suffix_futures)
        futures.extend(suffix_futures)
        futures.append(cleanup_fut)
    return futures
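# The returned list mixes runner and cleanup futures; the result loop below
# tells them apart because cleanup tasks return None.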
from multiprocessing import cpu_count
assert cpu_count() > 1, cpu_count()
from subprocess import check_call, DEVNULL
import json
compress_bases = [f'fieldaware{budget}.tar.zst' for budget in budgets] if field_aware else ['binary.tar.zst']
for compress_base in compress_bases:
    compress_file = f'{DATA_DIR}/{compress_base}'
    if not os.path.exists(compress_file):
        print('downloading binary zst')
        check_call(f"""aws s3 cp --no-progress "{s3src}{compress_base}" {compress_file}""", shell=True)
    else:
        print('found local zst', compress_file)
    print('extracting binary zst')
    cpus_for_unzst = max(cpu_count() - 1, 1)
    check_call(f"""tar -I "pzstd -p {cpus_for_unzst}" -xf {compress_file}""", shell=True)
    # check_call(f"""rm -f {compress_file}""", shell=True)
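# Note: tar extracts into the current working directory, so the archive is
# presumably laid out with an svms-data/ prefix; that way the .bin files end
# up under DATA_DIR where launch_grid expects them.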
datasets = ['url', 'kdd12', 'kdda', 'kddb']
futures = []
for budget in budgets:
    for dataset in datasets:
        futures.extend(launch_grid(budget, dataset))
while futures:
    ready_ids, remaining_ids = ray.wait(futures)
    # Append each finished experiment to dstfile as soon as it is available.
    with open(dstfile, "a") as f:
        for ready_id in ready_ids:
            result = ray.get(ready_id)
            if result is None:
                continue  # cleanup task, nothing to record
            result, tup = result
            if isinstance(result, Exception):
                ex = result
                tb = ''.join(traceback.format_exception(type(ex), ex, ex.__traceback__))
                print('error ignored for task {}:\n{}\n'.format(tup, tb), end='')
                continue
            print(json.dumps(result), file=f)
            print('completed {}\n'.format(tup), end='')
    futures = remaining_ids