-
Notifications
You must be signed in to change notification settings - Fork 3
/
script.py
439 lines (364 loc) · 12.9 KB
/
script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
"""
Lua script-related classes and functions.
"""
try: # pragma: no cover
from ujson import dumps as jdumps
from ujson import loads as jloads
except ImportError:
from functools import partial
from json import dumps
from json import loads as jloads
jdumps = partial(dumps, separators=(',', ':'))
import six
from collections import namedtuple
from functools import partial
from redis.client import Script as RedisScript
from redis.client import BasePipeline
from .exceptions import error_handler
from .regions import (
ArgumentRegion,
KeyRegion,
ReturnRegion,
PragmaRegion,
ScriptRegion,
)
from .render import RenderContext
@six.python_2_unicode_compatible
class Script(object):
    """
    A Lua script assembled from a list of regions.

    The script's keys, arguments, return type and inclusion policy are
    derived from its regions when it is constructed. The script can then be
    rendered to Lua source (`render`) and executed on a Redis client
    (`get_runner`).
    """

    # Placeholder marking a key/argument slot that no caller value has
    # filled yet in `runner`. Being a private object, it cannot collide with
    # any value a caller passes, including `None`.
    SENTINEL = object()

    __slots__ = [
        'name',
        'keys',
        'args',
        'return_type',
        'multiple_inclusion',
        'line_infos',
        'regions',
        '_render',
        '_redis_script',
    ]

    @classmethod
    def get_keys_from_regions(cls, regions):
        """
        Get the ordered list of key names declared by `regions`.

        Keys declared by included scripts are spliced in at the position of
        their inclusion.

        :param regions: A list of regions.
        :returns: A list of key names.
        :raises ValueError: If a key index is out of sequence or a key name
            is declared more than once.
        """
        result = []

        for region in regions:
            if isinstance(region, KeyRegion):
                # Key indexes must be contiguous, counting the keys already
                # pulled in from included scripts.
                if region.index != len(result) + 1:
                    raise ValueError(
                        "Encountered key %s with index %d when index %d was "
                        "expected" % (
                            region.name,
                            region.index,
                            len(result) + 1,
                        )
                    )

                result.append(region.name)
            elif isinstance(region, ScriptRegion):
                result.extend(cls.get_keys_from_regions(region.script.regions))

        duplicates = {x for x in result if result.count(x) > 1}

        if duplicates:
            raise ValueError("Duplicate key(s) %r" % list(duplicates))

        return result

    @classmethod
    def get_args_from_regions(cls, regions):
        """
        Get the ordered list of `(name, type_)` arguments declared by
        `regions`.

        Arguments declared by included scripts are spliced in at the position
        of their inclusion.

        :param regions: A list of regions.
        :returns: A list of `(name, type_)` tuples.
        :raises ValueError: If an argument index is out of sequence or an
            argument name is declared more than once.
        """
        result = []

        for region in regions:
            if isinstance(region, ArgumentRegion):
                # Argument indexes must be contiguous, counting the arguments
                # already pulled in from included scripts.
                if region.index != len(result) + 1:
                    raise ValueError(
                        "Encountered argument %s with index %d when index %d "
                        "was expected" % (
                            region.name,
                            region.index,
                            len(result) + 1,
                        )
                    )

                result.append((region.name, region.type_))
            elif isinstance(region, ScriptRegion):
                result.extend(cls.get_args_from_regions(region.script.regions))

        # Compare by name only: the same name declared twice with different
        # types is a conflict too, even though the tuples differ.
        names = [name for name, _ in result]
        duplicates = {
            (name, type_)
            for name, type_ in result
            if names.count(name) > 1
        }

        if duplicates:
            raise ValueError("Duplicate arguments(s) %r" % list(duplicates))

        return result

    @classmethod
    def get_return_from_regions(cls, regions):
        """
        Get the declared return type of the script.

        Only the top-level `regions` are inspected; return statements of
        included scripts are not considered.

        :param regions: A list of regions.
        :returns: The `type_` of the single `ReturnRegion`, or `None` if there
            is none.
        :raises ValueError: If more than one return statement is found.
        """
        result = None

        for region in regions:
            if isinstance(region, ReturnRegion):
                if result is not None:
                    raise ValueError("There can be only one return statement.")

                result = region.type_

        return result

    @classmethod
    def get_multiple_inclusion_from_regions(cls, regions):
        """
        Tell whether the script may be included more than once.

        :param regions: A list of regions.
        :returns: `False` if any top-level region is a `PragmaRegion`, `True`
            otherwise.
        """
        # There is only one type of pragma as of now, so its mere presence is
        # enough; its exact value is not inspected.
        return not any(isinstance(region, PragmaRegion) for region in regions)

    # One entry of `line_infos`: where a region starts and how many lines it
    # spans, both in the script source ("real" lines) and once included
    # scripts are expanded (plain lines).
    _LineInfo = namedtuple(
        '_LineInfo',
        [
            'first_real_line',
            'real_line',
            'real_line_count',
            'first_line',
            'line',
            'line_count',
            'region',
        ],
    )

    @classmethod
    def get_line_info_for_regions(cls, regions, included_scripts):
        """
        Get the line information for the specified list of regions.

        Two numberings are tracked: "real" lines advance by each region's
        `real_line_count`, while plain lines advance by the expanded size of
        included scripts. A script already present in `included_scripts` is
        not expanded again: it only advances the real line counter and gets
        no entry.

        :param regions: A list of regions to get the line information from.
        :param included_scripts: A set of scripts that were visited already.
            It is updated in place with every script expanded here.
        :returns: A list of `_LineInfo` tuples, one per region not skipped.
        """
        result = []
        real_line = 1
        line = 1

        def add_region(real_line, line, region):
            result.append(
                cls._LineInfo(
                    real_line,
                    real_line,
                    region.real_line_count,
                    line,
                    line,
                    region.line_count,
                    region,
                ),
            )

        for region in regions:
            if isinstance(region, ScriptRegion):
                if region.script in included_scripts:
                    real_line += region.real_line_count
                    continue

                included_scripts.add(region.script)
                sub_result = cls.get_line_info_for_regions(
                    regions=region.script.regions,
                    included_scripts=included_scripts,
                )
                add_region(real_line, line, region)
                real_line += region.real_line_count

                # An included script whose regions were all skipped expands
                # to nothing; indexing `sub_result[-1]` would raise.
                if sub_result:
                    last = sub_result[-1]
                    line += last.line + last.line_count - 1
            else:
                add_region(real_line, line, region)
                real_line += region.real_line_count
                line += region.line_count

        return result

    def __init__(self, name, regions):
        """
        Create a new script object.

        :param name: The name of the script, without its `.lua` extension.
        :param regions: A non-empty list of regions that compose the script.
        :raises ValueError: If `regions` is empty, or if keys/arguments are
            inconsistent (bad indexes, duplicates, or a name used both as a
            key and as an argument).
        """
        if not regions:
            raise ValueError('regions cannot be empty')

        self.name = name
        self.keys = self.get_keys_from_regions(regions)
        self.args = self.get_args_from_regions(regions)
        self.return_type = self.get_return_from_regions(regions)
        self.multiple_inclusion = self.get_multiple_inclusion_from_regions(
            regions,
        )
        # The script itself is marked as visited so that a recursive include
        # of itself is not expanded.
        self.line_infos = self.get_line_info_for_regions(regions, {self})

        duplicates = set(self.keys) & {arg for arg, _ in self.args}

        if duplicates:
            raise ValueError(
                'Some key(s) and argument(s) have the same names: %r' % list(
                    duplicates,
                ),
            )

        self.regions = regions
        self._render = None
        self._redis_script = None

    def __repr__(self):
        return '{_class}(name={self.name!r})'.format(
            _class=self.__class__.__name__,
            self=self,
        )

    def __hash__(self):
        # Consistent with `__eq__`: equal scripts always share a name.
        return hash(self.name)

    @property
    def line_count(self):
        """The number of lines of the script once includes are expanded."""
        info = self.line_infos[-1]
        return info.line + info.line_count - 1

    @property
    def real_line_count(self):
        """The number of lines of the script source itself."""
        info = self.line_infos[-1]
        return info.real_line + info.real_line_count - 1

    def get_real_line_content(self, line):
        """
        Get the real line content for the script at the specified line.

        :param line: The line.
        :returns: A line content. For a line inside an included script, this
            is the content of the including region itself.
        :raises ValueError: If no such line exists.
        """
        info = self.get_line_info(line)

        if isinstance(info.region, ScriptRegion):
            return info.region.content
        else:
            return info.region.content.split('\n')[line - info.first_line]

    def get_scripts_for_line(self, line):
        """
        Get the list of (script, line) by order of traversal for a given line.

        :param line: The line.
        :returns: A list of (script, real_line) that got traversed by that
            line, outermost script first.
        :raises ValueError: If no such line exists.
        """
        info = self.get_line_info(line)
        result = [(self, info.real_line)]

        if isinstance(info.region, ScriptRegion):
            result.extend(
                info.region.script.get_scripts_for_line(
                    line - info.first_line + 1,
                ),
            )

        return result

    def get_line_info(self, line):
        """
        Get the line information for the specified line.

        :param line: The line.
        :returns: A `_LineInfo` tuple whose `line` is `line` and whose
            `real_line` is the matching line in the script source.
        :raises ValueError: If no such line exists.
        """
        for info in self.line_infos:
            if info.line <= line < info.line + info.line_count:
                return self._LineInfo(
                    first_real_line=info.first_real_line,
                    # A region spanning more expanded lines than source lines
                    # (such as an include) maps its extra lines onto its last
                    # real line.
                    real_line=info.real_line + min(
                        line - info.line,
                        info.real_line_count - 1,
                    ),
                    real_line_count=info.real_line_count,
                    first_line=info.first_line,
                    line=line,
                    line_count=info.line_count,
                    region=info.region,
                )

        raise ValueError("No such line %d in script %s" % (line, self))

    def __str__(self):
        return self.name + ".lua"

    def render(self, context=None):
        """
        Render the script to Lua source.

        :param context: The render context to use. When omitted, a fresh
            `RenderContext` is used and its output is cached on the script.
        :returns: The rendered script.
        """
        if context is not None:
            # Output produced with a caller-supplied context may differ from a
            # standalone render, so it is neither cached nor served from cache.
            return context.render_script(self)

        if not self._render:
            self._render = RenderContext().render_script(self)

        return self._render

    def __eq__(self, other):
        if not isinstance(other, Script):
            return NotImplemented

        return other.name == self.name and other.regions == self.regions

    @classmethod
    def convert_argument_for_call(cls, type_, value):
        """
        Convert a Python value into a script argument.

        :param type_: The declared type of the argument.
        :param value: The value to convert.
        :returns: An `int` for `int`/`bool` types, JSON text for `list`/`dict`
            types, raw bytes unchanged, and `str(value)` otherwise.
        """
        if type_ is int:
            return int(value)
        elif type_ is bool:
            return 1 if value else 0
        elif type_ is list:
            return jdumps(list(value))
        elif type_ is dict:
            return jdumps(dict(value))
        elif isinstance(value, six.binary_type):
            # Pass raw bytes through: on Python 3, `str(b'x')` would produce
            # the literal text "b'x'".
            return value
        else:
            return str(value)

    @classmethod
    def convert_return_value_from_call(cls, type_, value):
        """
        Convert a raw script reply into the declared return type.

        :param type_: The declared return type, or `None` for none.
        :param value: The raw reply.
        :returns: The converted value; `value` unchanged if `type_` is not one
            of `str`, `int`, `bool`, `list` or `dict`.
        """
        if type_ is str:
            # Replies may be `bytes` on Python 3, where `str()` would yield
            # "b'...'". On Python 2, `bytes` is `str` and is left untouched.
            if isinstance(value, six.binary_type) and not isinstance(value, str):
                value = value.decode('utf-8')

            return str(value)
        elif type_ is int:
            return int(value)
        elif type_ is bool:
            return bool(value)
        elif type_ in [list, dict]:
            if isinstance(value, six.binary_type):
                value = value.decode('utf-8')

            return jloads(value)
        else:
            return value

    def runner(self, client, **kwargs):
        """
        Call the script with its named arguments.

        :param client: The Redis client or pipeline to call the script on.
        :param kwargs: One value per key and argument, by name.
        :returns: The script result or, if `client` is a pipeline, a callable
            that converts the eventual raw reply into the result.
        :raises TypeError: If a name is unknown, or a key or argument is
            missing.
        """
        sentinel = self.SENTINEL
        keys = {
            key: index
            for index, key
            in enumerate(self.keys)
        }
        args = {
            arg: (index, type_)
            for index, (arg, type_)
            in enumerate(self.args)
        }
        keys_params = [sentinel] * len(self.keys)
        args_params = [sentinel] * len(self.args)

        for name, value in kwargs.items():
            # Explicit membership tests keep errors raised by the conversion
            # itself from being reported as an unknown name.
            if name in keys:
                keys_params[keys[name]] = value
            elif name in args:
                index, type_ = args[name]
                args_params[index] = self.convert_argument_for_call(
                    type_,
                    value,
                )
            else:
                raise TypeError("Unknown key/argument %r" % name)

        missing_keys = {
            key
            for key, index in keys.items()
            if keys_params[index] is sentinel
        }

        if missing_keys:
            raise TypeError("Missing key(s) %r" % list(missing_keys))

        missing_args = {
            arg
            for arg, (index, type_) in args.items()
            if args_params[index] is sentinel
        }

        if missing_args:
            raise TypeError(
                "Missing argument(s) %r" % list(missing_args),
            )

        with error_handler(self):
            # The redis-py script object is created lazily, on first call.
            if not self._redis_script:
                self._redis_script = RedisScript(
                    registered_client=client,
                    script=self.render(),
                )

            result = self._redis_script(
                keys=keys_params,
                args=args_params,
                client=client,
            )

        if isinstance(client, BasePipeline):
            # On a pipeline, the raw reply only becomes available later, so
            # the caller gets the conversion step to apply to it.
            return partial(
                self.convert_return_value_from_call,
                self.return_type,
            )
        else:
            return self.convert_return_value_from_call(
                self.return_type,
                result,
            )

    def get_runner(self, client):
        """
        Get a runner for the script on the specified `client`.

        :param client: The Redis instance to call the script on.
        :returns: The runner, a callable that takes the script named arguments
            and returns its result. If `client` is a pipeline, then the runner
            returns another callable, through which the resulting value must be
            passed to be parsed.
        """
        return partial(self.runner, client)