forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
net_builder.py
742 lines (643 loc) · 27.1 KB
/
net_builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
## @package net_builder
# Module caffe2.python.net_builder
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core, context
from caffe2.python.task import Task, TaskGroup
from caffe2.python.control_ops_util import add_if_op, add_while_op
@context.define_context()
class NetBuilder(object):
    """
    Scope-driven mechanism for building nets, loops and conditional blocks.

    Arguments:
      name: NetBuilder's name
      initial_scope: list of blobs that are available for reading/writing

    Leading-underscore arguments are internal:
      _stop_blob_required: if True, a clean __exit__ asserts that a stop
          condition was set with `stop` / `stop_if`.
      _stop_blob: an existing stop blob to use instead of creating one.
      _fullname: full name used verbatim instead of "<parent name>/<name>";
          cannot be combined with `name`.
      _use_control_ops: children must be Nets or control-op NetBuilders and
          are merged into a single Net on exit; stop blobs are not allowed.

    Example:
        from caffe2.python.net_builder import NetBuilder, ops
        with NetBuilder() as nb:
            c = ops.Const(5)
            d = ops.Const(0)
            with ops.loop():
                ops.stop_if(ops.LE([c, ops.Const(0)]))
                ops.Add([c, ops.Const(-1)], [c])
                with ops.If(ops.GE([c, ops.Const(3)])):
                    ops.Add([d, ops.Const(10)], [d])
            ops.Print(c, [])
            ops.Print(d, [])
        step = core.to_execution_step(nb)
    """

    def __init__(self, name=None, initial_scope=None, _stop_blob_required=False,
                 _stop_blob=None, _fullname=None, _use_control_ops=False):
        parent = NetBuilder.current(required=False)
        assert not _fullname or not name, 'Cannot set both _fullname and name'
        assert not _use_control_ops or \
            (not _stop_blob_required and not _stop_blob), \
            'Stop blobs are not used with control operators'
        # Hierarchical name "<parent name>/<name>", skipping empty parts.
        self.name = _fullname or '/'.join(
            n for n in (parent.name if parent else None, name) if n
        )
        self._frozen = False
        self._current_net = None
        self._children = []
        if parent:
            # make sure parent has an up to date lexical scope computed
            parent._update_lexical_scope()
        # Blobs visible to this builder: everything visible in the parent
        # plus any explicitly provided initial scope.
        self._init_lexical_scope = set(parent._lexical_scope) if parent else set()
        if initial_scope:
            self._init_lexical_scope |= {str(b) for b in initial_scope}
        self._lexical_scope = set(self._init_lexical_scope)
        self._stop_blob = _stop_blob
        self._stop_blob_required = _stop_blob_required
        self._use_control_ops = _use_control_ops

    def stop_blob(self):
        """
        Returns the BlobReference to the stop_blob of this NetBuilder.
        If one is not yet available, creates one.

        A newly created stop blob is set to False in the current net, on the
        assumption that it will be used immediately in that net. If the
        current net is not the builder's first child, a 'stop_blob_init' net
        that also sets it to False is inserted as the first child.
        """
        assert not self._use_control_ops, \
            'Stop blobs are not used with control operators'
        if self._stop_blob is None:
            net = self.current_net()
            self._stop_blob = core.BlobReference(
                net.NextName('stop_blob'), net=net)
            net.Const(False, blob_out=self._stop_blob)
            if self._current_net != self._children[0]:
                self._children.insert(0, core.Net('stop_blob_init'))
                self._children[0].Const(False, blob_out=self._stop_blob)
        return self._stop_blob

    def stop_if(self, blob):
        """
        Sets the stop blob to `stop_blob OR blob` via the active builder's
        current net, then drops the current net so that later ops go into
        a new one.
        """
        assert not self._use_control_ops, \
            'Stop blobs are not used with control operators'
        stop_blob = self.stop_blob()
        ops.Or([stop_blob, blob], [stop_blob])
        self._current_net = None

    def _assert_mutable(self):
        assert not self._frozen, (
            'This NetBuilder (%s) has been built already.' % self.name)

    def _update_lexical_scope(self):
        """
        Updates lexical scope based on the current list of children.
        Lexical scope contains names of blobs that are currently available
        and were introduced in the net builder
        """
        self._lexical_scope = set(self._init_lexical_scope)
        for child in self._children:
            if isinstance(child, core.Net):
                self._lexical_scope |= child.UsedBlobNames()
            elif isinstance(child, NetBuilder) and child._use_control_ops:
                self._lexical_scope |= child._lexical_scope

    def _reset_children(self):
        """Drops all children and restores the initial lexical scope."""
        self._current_net = None
        self._children = []
        self._lexical_scope = set(self._init_lexical_scope)

    def add(self, child):
        """
        Appends `child` to this builder and returns it.

        If `child` is a Net it becomes the current net; otherwise the current
        net is cleared, so the next op starts a new net.
        """
        self._assert_mutable()
        if self._use_control_ops:
            assert isinstance(child, core.Net) or (
                isinstance(child, NetBuilder) and child._use_control_ops), \
                "Expected Net or NetBuilder with control ops"
        self._current_net = None
        self._children.append(child)
        # to-do : check it's not a dag net
        if isinstance(child, core.Net):
            self._current_net = child
        self._update_lexical_scope()
        return child

    def current_net(self, name=None):
        """
        Returns the current net, creating and adding a new one if there is
        none or if `name` is given.
        """
        self._assert_mutable()
        if self._current_net is None or name is not None:
            self.add(core.Net(name))
        return self._current_net

    def freeze(self):
        """Freezes all freezable children and marks this builder immutable."""
        for child in self._children:
            if hasattr(child, 'freeze'):
                child.freeze()
        self._current_net = None
        self._frozen = True

    def get(self):
        """Freezes the builder and returns its list of children."""
        self.freeze()
        return self._children

    def __exit__(self, etype, *args):
        # Control-op builders collapse their children into a single net.
        if self._use_control_ops and len(self._children) > 0:
            children = self._children
            self._reset_children()
            merged_net = NetBuilder.merge_nets(
                children, self._lexical_scope)
            assert merged_net, "Expected a non-empty merge of children"
            self._children = [merged_net]
        self.freeze()
        if etype is not None:
            return
        assert (not self._stop_blob_required) or self._stop_blob is not None, (
            'This NetBuilder (%s) requires a stop condition ' % self.name +
            'to be set with `stop` or `stop_if`')

    @staticmethod
    def merge_nets(nets_or_builders, outer_blob_names):
        """
        Concatenates Nets and control-op NetBuilders into a single Net.

        The first Net is modified in place (later ones are appended to it
        with AppendNet) and returned; returns None if the input is empty.
        External outputs of the result are restricted to names contained
        in `outer_blob_names`.
        """
        # Only nets or builders with control ops are allowed.
        # Need to pay attention to external outputs, e.g.
        #   ...
        #   IfNet1 (cond_blob):
        #       (Net1)
        #           X = 1
        #       IfNet2 (...):
        #           X = X + 1
        #   ...
        # In this example there're two children in then branch of IfNet1:
        # a subnet Net1 that creates blob X and sets its value to one, and
        # a net builder IfNet2 that (conditionally) increments X.
        # From IfNet2's point of view X is an external input
        # and output blob, it will be put into IfNet2 net's external_output.
        # At the same time, from the point of view of IfNet1 X is purely local.
        # Net.AppendNet just merges external outputs of the networks, so
        # without checking this the result of Net1.AppendNet(IfNet2's net)
        # would have blob X in external_output
        net = None
        for n in nets_or_builders:
            if isinstance(n, NetBuilder):
                assert n._use_control_ops, \
                    "Merging of NetBuilder supported only for control ops"
                nets = n.get()
                assert len(nets) == 1 and isinstance(nets[0], core.Net), \
                    "Invalid control op net builder"
                cur = nets[0]
            else:
                assert isinstance(n, core.Net)
                cur = n
            if net:
                net.AppendNet(cur)
            else:
                net = cur
        if net:
            # correct external output
            external_outputs = [o for o in net.Proto().external_output
                                if o in outer_blob_names]
            net.Proto().external_output[:] = external_outputs
        return net

    def __str__(self):
        return self.name or 'Un-named NetBuilder'
class Operations(object):
    """
    Operations to be used in the context of a NetBuilder.

    Unknown attributes are forwarded to the active builder's current net,
    so `ops.Add(...)` appends an Add operator to that net.
    """

    def net(self, net=None, name=None):
        """
        Retrieves the current net, or adds a new net to the builder.
        Args:
            net:  If provided, add the given net to the active builder.
                  Else, returns the current Net or creates a new one as needed.
            name: if provided, creates a new Net with given name and makes
                  it the new current net of the active builder. Cannot
                  be provided if net is provided.
        """
        assert name is None or net is None, (
            'Cannot provide both `net` and `name`.')
        if net is not None:
            NetBuilder.current().add(net)
            return net
        return NetBuilder.current().current_net(name=name)

    def __getattr__(self, op_type):
        """
        Adds an operator call to the currently active Net.
        """
        if op_type.startswith('__'):
            # Never forward dunder lookups (e.g. copy/pickle protocol probes).
            raise AttributeError(op_type)
        # We want hasattr to work properly even if no context is active.
        if NetBuilder.current(required=False) is None:
            raise AttributeError('No active NetBuilder.')
        return getattr(self.net(), op_type)

    def task_group(self):
        """
        Creates a local task group which will execute as the next step of
        the current NetBuilder.
        """
        from caffe2.python import task
        builder = NetBuilder.current()
        with task.Cluster():
            with task.Node('local'):
                tg = task.TaskGroup()
                builder.add(tg)
                return tg

    def stop(self):
        """
        Stop execution of the current execution step.
            Example:
                ops.Print(a, 0)
                ops.stop()
                ops.Print(b, 0)
            In the example, 'b' will never be printed.
        """
        return self.stop_if(ops.Const(True))

    def stop_if(self, blob):
        """
        Stop execution of the current execution step if the
        condition `blob` is met.
            Example:
                ops.Print(a, 0)
                ops.stop_if(ops.LE([x, ops.Const(0)]))
                ops.Print(b, 0)
            In the example, 'b' will only be printed if the value of scalar
            tensor 'x' is greater than 0.
        """
        return NetBuilder.current().stop_if(blob)

    def loop(self, iters=None, name=None):
        """
        Creates a NetBuilder that will execute in a loop as the next step of
        the current NetBuilder. If `iters` is provided, the loop will execute
        for `iters` iterations and then stop. `iters` can be a constant or a
        BlobReference. If `iters` is not provided, the loop will execute
        until `ops.stop` or `ops.stop_if` is called.
            Examples:
                a = ops.Const(5)
                with ops.loop():
                    ops.stop_if(ops.LE([a, ops.Const(0)]))
                    ops.Print(a, 0)
                    ops.Add([a, ops.Const(-1)], [a])
            Above, 'a' will be printed 5 times, with values 5 to 1.

                with ops.loop(10) as loop:
                    ops.LogInfo(loop.iter())
            This will print the numbers from 0 to 9.

                x = ops.Add([ops.Const(10), ops.Const(10)])
                with ops.loop(x) as loop:
                    ops.LogInfo(loop.iter())
            This will print the numbers from 0 to 19.
        """
        return NetBuilder.current().add(_Loop(iters, name=name))

    def stop_guard(self, has_stopped_blob=None, name=None):
        """
        Creates a NetBuilder that will execute once as the next step of the
        current NetBuilder. After execution, a bool tensor will indicate
        whether the inner execution was halted with `stop` or `stop_if`.
            Example:
                a = ops.Const(True)
                with ops.stop_guard() as sg1:
                    ops.stop_if(a)
                    ops.Print(ops.Const('did not stop'))
                b = ops.Const(False)
                with ops.stop_guard() as sg2:
                    ops.stop_if(b)
                    ops.Print(ops.Const('did not stop'))
                ops.Print(sg1.has_stopped(), [])
                ops.Print(sg2.has_stopped(), [])
            In the example, 'did not stop' will be printed once,
            followed by True and False.
        """
        return NetBuilder.current().add(
            _StopGuard(has_stopped_blob=has_stopped_blob, name=name))

    def If(self, cond, name=None):
        """
        Creates a NetBuilder that will execute once as the next step of the
        current NetBuilder if the blob `cond` is True.
            Example:
                with ops.If(ops.Const(True)):
                    ops.Print(ops.Const('Will print'))
                with ops.If(ops.Const(False)):
                    ops.Print(ops.Const('Wont print'))
            The example will print 'Will print' once.
        """
        return NetBuilder.current().add(_RunIf(cond, name=name))

    def IfNet(self, cond, name=None):
        """
        Same as If, but uses 'If' operator instead of execution step logic
        """
        return NetBuilder.current().add(_RunIfNet(cond, name=name))

    def Else(self, name=None):
        """
        Else branch of IfNet, has to be specified immediately after IfNet.
            Example:
                with ops.IfNet(ops.LT([x, y])):
                    ...
                with ops.Else():
                    ...
        """
        # Not added as a child: the Else builder rewrites the preceding
        # IfNet builder's children on exit instead.
        return _RunElseNet(name=name)

    def WhileNet(self, name=None):
        """
        NetBuilder for 'While' control operator
        """
        return NetBuilder.current().add(_RunWhileNet(name=name))

    def Condition(self, name=None):
        """
        Loop's condition, executed within WhileNet context
        """
        assert isinstance(NetBuilder.current(), _RunWhileNet), \
            "Use of Condition outside of WhileNet"
        # Not added as a child: it registers itself on the enclosing
        # WhileNet builder, which consumes it on exit.
        return _RunWhileCondition(name=name)

    def task_init(self):
        """
        Defines operations that will be executed once at task startup.
        Useful when implementing processors, that don't have access to the Task
        top-level structure.

        This setup will be run only once, even if multiple instances of the task
        will run in parallel. For instance-local initialization, use
        `task_instance_init` instead.

            Example:
                def my_processor(rec):
                    with ops.task_init():
                        one = ops.Const(1)
                        two = ops.Const(2)
                    return Tuple(
                        ops.Add(rec[0](), one), ops.Add(rec[1](), two))
        """
        setup = _SetupBuilder(_SetupBuilder.INIT)
        self.net().add_attribute(Task.TASK_SETUP, setup)
        return setup

    def task_exit(self):
        """
        Define operations to be executed once at task shutdown.
        Useful when implementing processors, that don't have access to the Task
        top-level structure.

        This shutdown will be run only once, after all concurrent instances of
        the task have already finished. For instance-local shutdown,
        use `task_instance_exit` instead.

            Example:
                def read_queue(queue):
                    with ops.task_exit():
                        queue.close(ops.net())
                    return queue.read(ops.net())
        """
        setup = _SetupBuilder(_SetupBuilder.EXIT)
        self.net().add_attribute(Task.TASK_SETUP, setup)
        return setup

    def task_instance_init(self):
        """
        Defines operations that will be executed once at startup of each
        instance of a task. This can be seen as "thread_local" initialization.
        It is guaranteed to run only after all `task_init` logic finishes.

        This setup will be run concurrently for each instance of a task.
        For global task initialization, use `task_init` instead.
        """
        setup = _SetupBuilder(_SetupBuilder.INIT)
        self.net().add_attribute(Task.TASK_INSTANCE_SETUP, setup)
        return setup

    def task_instance_exit(self):
        """
        Defines operations that will be executed once at shutdown of each
        instance of a task. This can be seen as "thread_local" finalization.

        This shutdown will be run concurrently for each instance of a task.
        For global task shutdown, use `task_exit` instead.
        """
        setup = _SetupBuilder(_SetupBuilder.EXIT)
        self.net().add_attribute(Task.TASK_INSTANCE_SETUP, setup)
        return setup

    def local_init(self):
        """
        Similar to `task_init`, but executes at TaskGroup's startup instead,
        before any task of the group starts executing. This will run only
        once on each node, before initialization of any task, so it can be
        used e.g. to initialize blobs shared across tasks.
        """
        setup = _SetupBuilder(_SetupBuilder.INIT)
        self.net().add_attribute(TaskGroup.LOCAL_SETUP, setup)
        return setup

    def local_exit(self, name=None):
        """
        Similar to `task_exit`, but executes at TaskGroup's exit instead,
        after all tasks of the group finished execution.
        This will run only once on each node.
        """
        setup = _SetupBuilder(_SetupBuilder.EXIT, name)
        self.net().add_attribute(TaskGroup.LOCAL_SETUP, setup)
        return setup

    def task_reporter(self, interval_ms=1000, name=None):
        """
        Define operations to be executed at every time interval from
        task start-up to finish. These operations are guaranteed to
        execute at least once after all other operations of the task are
        finished.

            Example:
                with ops.task_reporter(interval_ms=10000):
                    ops.LogInfo('10s elapsed')
        """
        return _ReporterBuilder(interval_ms, net=self.net(), name=name)

    def local_reporter(self, interval_ms=1000, name=None):
        """
        Similar to `task_reporter`, but operations defined within this block
        will run repeatedly for as long as any of the tasks in the current
        TaskGroup have not finished.
        """
        return _ReporterBuilder(interval_ms, name=name)


# Module-level singleton through which builder operations are issued.
ops = Operations()
class _ReporterBuilder(NetBuilder):
    """
    NetBuilder whose body becomes a periodically executed reporting step.

    On a clean exit the builder is converted to an execution step configured
    with `RunEveryMillis(interval_ms)`. When a net was supplied, the step is
    attached to it as the `Task.REPORT_STEP` attribute; otherwise it is
    registered with the active TaskGroup through `report_step`.
    """

    def __init__(self, interval_ms, net=None, name=None):
        NetBuilder.__init__(self, name)
        self.interval_ms = interval_ms
        self._net = net

    def __exit__(self, etype, *args):
        if etype is None:
            report_step = core.to_execution_step(self)
            report_step.RunEveryMillis(self.interval_ms)
            if not self._net:
                TaskGroup.current().report_step(
                    report_step, interval_ms=self.interval_ms)
            else:
                self._net.add_attribute(Task.REPORT_STEP, report_step)
        NetBuilder.__exit__(self, etype, *args)
class _SetupBuilder(NetBuilder):
    """
    NetBuilder holding either setup (INIT) or teardown (EXIT) operations.

    `setup()` yields an execution step built from this builder only when it
    is an INIT builder; `exit()` does so only for EXIT builders. In every
    other case they return None.
    """
    INIT = 'init'
    EXIT = 'exit'

    def __init__(self, type, name=None):
        NetBuilder.__init__(self, name)
        self.type = type

    def setup(self, net):
        """Returns this builder's execution step if it is an INIT builder."""
        if self.type != _SetupBuilder.INIT:
            return None
        return core.to_execution_step(self)

    def exit(self, net):
        """Returns this builder's execution step if it is an EXIT builder."""
        if self.type != _SetupBuilder.EXIT:
            return None
        return core.to_execution_step(self)
class _RunOnce(NetBuilder):
    """
    NetBuilder whose body is intended to execute a single time.

    If a stop blob was created while building the body (e.g. by
    `ops.stop_if`), a clean exit appends an unconditional `ops.stop()` —
    presumably so a step driven by that stop blob does not repeat.
    """

    def __init__(self, name=None):
        NetBuilder.__init__(self, name)

    def __exit__(self, etype, *args):
        body_uses_stop_blob = self._stop_blob is not None
        if body_uses_stop_blob and etype is None:
            ops.stop()
        NetBuilder.__exit__(self, etype, *args)
class _StopGuard(_RunOnce):
    """
    Run-once NetBuilder that records whether its body was halted early.

    The flag blob is set to True when the body starts and set to False by
    an op appended at the very end of the body. A `stop`/`stop_if` earlier in
    the body is therefore expected to leave it True.
    """

    def __init__(self, has_stopped_blob=None, name=None):
        _RunOnce.__init__(self, name)
        # Optional output blob from the caller; __enter__ replaces it with
        # the reference returned by ops.Const.
        self._stopped = has_stopped_blob
        self._ran = False

    def __enter__(self):
        entered = _RunOnce.__enter__(self)
        # Start pessimistic: assume the body will be stopped.
        self._stopped = ops.Const(True, blob_out=self._stopped)
        return entered

    def __exit__(self, etype, *args):
        if etype is None:
            self._ran = True
            # Last op of the body: clears the flag if execution got here.
            ops.Const(False, blob_out=self._stopped)
        _RunOnce.__exit__(self, etype, *args)

    def has_stopped(self):
        """
        Return a blob that will be set to scalar bool `True` after
        this net builder ran, iff it was halted early.
        """
        assert self._ran, 'Context not used yet.'
        return self._stopped
class _Loop(NetBuilder):
    """
    NetBuilder whose body is repeated until its stop condition is met.

    A stop condition is mandatory (`_stop_blob_required=True`). When `iters`
    is given (a BlobReference, or a value passed to `ops.Const`), counter
    blobs are created; the body starts with `stop_if(iter >= num_iters)` and
    ends by incrementing the counter.
    """

    def __init__(self, iters=None, name=None):
        NetBuilder.__init__(self, name, _stop_blob_required=True)
        # Define the counter attributes unconditionally so the object is
        # always fully initialized and iter() fails via its assertions
        # rather than relying on attribute existence.
        self._inc = None
        self._iter = None
        self._num_iters = None
        if iters is not None:
            # Created before __enter__, i.e. through the builder that is
            # active at construction time, outside the loop body.
            self._inc = ops.Const(1)
            self._iter = ops.Const(0)
            self._num_iters = (
                iters if isinstance(iters, core.BlobReference)
                else ops.Const(iters))

    def iter(self):
        """
        Returns the blob holding the 0-based iteration counter.
        Only available for loops created with `iters`.
        """
        assert self._num_iters is not None, (
            'This loop does not have a number of iterations.')
        assert self._iter is not None, (
            'iter() must be called from inside the loop context')
        return self._iter

    def __enter__(self):
        builder = NetBuilder.__enter__(self)
        if self._num_iters is not None:
            ops.stop_if(ops.GE([self._iter, self._num_iters]))
        return builder

    def __exit__(self, etype, *args):
        if etype is None and self._num_iters is not None:
            self.current_net().Add([self._iter, self._inc], [self._iter])
        NetBuilder.__exit__(self, etype, *args)
class _RunIf(_RunOnce):
    """
    Run-once NetBuilder guarded by a condition blob.

    On entry the body stops immediately if `_else_blob` is true; otherwise it
    sets the shared `_already_ran` blob to True. Elif/Else branches reuse the
    same `_already_ran` blob and include it in their own `_else_blob`, so at
    most one branch of an If/Elif/Else chain executes its body.

    Args:
      cond_blob: condition blob; None creates an Else branch.
      name: builder name.
      _already_ran: internal; the `_already_ran` blob of the preceding branch.
    """

    def __init__(self, cond_blob=None, name=None, _already_ran=None):
        _RunOnce.__init__(self, name)
        assert cond_blob or _already_ran
        self._is_else = cond_blob is None
        if _already_ran is None:
            self._else_blob = ops.Not(cond_blob)
            self._already_ran = ops.Const(False)
        else:
            self._already_ran = _already_ran
            self._else_blob = _already_ran if cond_blob is None else (
                ops.Or([_already_ran, ops.Not(cond_blob)]))

    def __enter__(self):
        r = _RunOnce.__enter__(self)
        ops.stop_if(self._else_blob)
        ops.Const(True, blob_out=self._already_ran)
        return r

    def Elif(self, cond, name=None):
        """Adds a conditional branch that runs only if no earlier one ran."""
        assert not self._is_else, 'Elif not allowed for an Else.'
        return NetBuilder.current().add(_RunIf(
            cond, name=name or self.name, _already_ran=self._already_ran))

    def Else(self, name=None):
        """Adds a final branch that runs only if no earlier one ran."""
        assert not self._is_else, 'Else not allowed for an Else.'
        return NetBuilder.current().add(
            _RunIf(name=name or self.name, _already_ran=self._already_ran))
class _RunIfNet(NetBuilder):
    """
    Generates a single net that uses If operator

    Children added inside the context form the "then" branch. On a clean
    exit they are merged into one net (or an 'empty_then_net' if the merge
    yields nothing), and `add_if_op` builds '<name>/if_net' around it, which
    becomes this builder's only child.
    """

    def __init__(self, cond_blob, name=None):
        NetBuilder.__init__(self, name=name, _use_control_ops=True)
        assert cond_blob, 'Conditional blob is not specified for an If net'
        self._cond_blob = cond_blob
        # Branch nets: "then" is filled by __exit__ below; "else" is filled
        # by a following _RunElseNet.
        self._then_net = None
        self._else_net = None

    def add(self, child):
        return NetBuilder.add(self, child)

    def __exit__(self, etype, *args):
        if etype is None:
            then_children = self._children
            self._reset_children()
            self._then_net = (
                NetBuilder.merge_nets(then_children, self._lexical_scope)
                or core.Net('empty_then_net'))
            wrapper = core.Net(self.name + '/if_net')
            add_if_op(wrapper, self._cond_blob, self._lexical_scope,
                      self._then_net, self._else_net)
            self._current_net = wrapper
            self._children = [wrapper]
        NetBuilder.__exit__(self, etype, *args)
class _RunElseNet(NetBuilder):
    """
    Else branch for _RunIfNet builder

    Must be created while the active builder's last child is a _RunIfNet.
    On a clean exit, if its children merge into a non-empty net, the
    preceding If builder's children are replaced by a single
    '<name>/if_else_net' holding both branches.
    """

    def __init__(self, name=None):
        NetBuilder.__init__(self, name=name, _use_control_ops=True)
        parent = NetBuilder.current(required=False)
        assert (parent and len(parent._children) > 0
                and isinstance(parent._children[-1], _RunIfNet)), \
            'Invalid use of Else builder'
        self._if_builder = parent._children[-1]

    def __exit__(self, etype, *args):
        if etype is None:
            else_children = self._children
            self._reset_children()
            if_builder = self._if_builder
            if_builder._else_net = NetBuilder.merge_nets(
                else_children, self._lexical_scope)
            if if_builder._else_net:
                combined = core.Net(self.name + '/if_else_net')
                add_if_op(combined, if_builder._cond_blob,
                          self._lexical_scope, if_builder._then_net,
                          if_builder._else_net)
                if_builder._current_net = combined
                if_builder._children = [combined]
        NetBuilder.__exit__(self, etype, *args)
class _RunWhileNet(NetBuilder):
    """
    Generates a single net that uses While operator

    A condition must be declared first with `ops.Condition()`. On a clean
    exit the body children are merged into one net (or an
    'empty_loop_body_net'), and `add_while_op` builds '<name>/while_net'
    from the condition and body; that net becomes the only child.
    """

    def __init__(self, name=None):
        NetBuilder.__init__(self, name=name, _use_control_ops=True)
        # Set by _RunWhileCondition when it is constructed.
        self._cond_builder = None

    def __exit__(self, etype, *args):
        if etype is None:
            assert self._cond_builder, \
                'Condition builder must be specified in While op'
            cond_blob = self._cond_builder._cond_blob
            cond_net = self._cond_builder._cond_net
            body_children = self._children
            self._reset_children()
            body_net = (
                NetBuilder.merge_nets(body_children, self._lexical_scope)
                or core.Net('empty_loop_body_net'))
            loop_net = core.Net(self.name + '/while_net')
            add_while_op(loop_net, cond_blob, self._lexical_scope,
                         body_net, cond_net)
            self._current_net = loop_net
            self._children = [loop_net]
        NetBuilder.__exit__(self, etype, *args)
class _RunWhileCondition(NetBuilder):
    """
    Computes loop's condition, used in the context of WhileNet.
    Last operator must have a single scalar boolean output that will be used
    as a condition value, no other blobs created in the condition net are
    visible outside of it

    On construction it registers itself as the enclosing _RunWhileNet's
    condition builder; on a clean exit it merges its children into
    `_cond_net` and records the last op's single output as `_cond_blob`.
    """

    def __init__(self, name=None):
        NetBuilder.__init__(self, name=name, _use_control_ops=True)
        parent = NetBuilder.current(required=False)
        assert parent and isinstance(parent, _RunWhileNet), \
            'Invalid use of loop condition builder'
        assert not parent._cond_builder, \
            'Multiple loop condition builders specified'
        assert len(parent._children) == 0, \
            "Condition definition must be specified before the loop's body"
        parent._cond_builder = self
        self._cond_blob = None
        self._cond_net = None

    def __exit__(self, etype, *args):
        if etype is None:
            body = self._children
            self._reset_children()
            cond_net = self._cond_net = NetBuilder.merge_nets(
                body, self._lexical_scope)
            assert cond_net, 'Invalid loop condition specified'
            cond_ops = cond_net.Proto().op
            assert len(cond_ops) > 0, 'Invalid condition net'
            final_outputs = cond_ops[-1].output
            assert len(final_outputs) == 1, 'Invalid condition net'
            self._cond_blob = core.BlobReference(
                name=final_outputs[0], net=None)
            self._current_net = cond_net
            self._children = [cond_net]
        NetBuilder.__exit__(self, etype, *args)