Skip to content
This repository has been archived by the owner on Mar 31, 2019. It is now read-only.

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander Panin committed Jun 6, 2016
1 parent 6290839 commit ae58262
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
9 changes: 5 additions & 4 deletions agentnet/environment/session_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def __init__(self, observations=1,
self.observations = [
create_shared(
"sessions.observations_history." + str(i),
np.zeros((10, 5) + tuple(obs.output_shape[1:]), dtype=get_layer_dtype(obs)))
np.zeros((10, 5) + tuple(obs.output_shape[1:])),
dtype=get_layer_dtype(obs)
)
for i, obs in enumerate(observations)
]

Expand Down Expand Up @@ -117,9 +119,8 @@ def __init__(self, observations=1,
self.preceding_agent_memories = [
create_shared(
"session.prev_memory." + str(i),
np.zeros((10,) + tuple(mem.output_shape[1:]),
dtype=get_layer_dtype(mem)
)
np.zeros((10,) + tuple(mem.output_shape[1:])),
dtype=get_layer_dtype(mem)
)
for i, mem in enumerate(agent_memories)
]
Expand Down
42 changes: 42 additions & 0 deletions agentnet/memory/logical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""
A few layers that implement simple switching and counting ops
"""
from lasagne.layers import Layer,MergeLayer
import theano.tensor as T
from ..utils.layers import get_layer_dtype


class CounterLayer(Layer):
"""A simple counter Layer that increments it's state by 1 each turn and loops each k iterations"""
def __init__(self,prev_counter,k=None,name=None):
super(CounterLayer,self).__init__(prev_counter,name=name)
self.k=k
def get_output_for(self,prev_state,**kwargs):
return prev_state+1 if self.k is None else (prev_state+1)%self.k
@property
def output_dtype(self):
return 'int32'

#TODO(jheuristic) try ifelse?

class SwitchLayer(MergeLayer):
"""a simple layer that implements an 'if-than-else' logic"""
def __init__(self,condition,than_branch,else_branch,name=None):
super(SwitchLayer,self).__init__(incomings=[condition,than_branch,else_branch], name=name)

assert tuple(than_branch.output_shape) == tuple(else_branch.output_shape)
assert get_layer_dtype(than_branch) == get_layer_dtype(else_branch)

self.output_dtype = get_layer_dtype(than_branch)

def get_output_for(self,inputs,**kwargs):
"""
:param inputs: a tuple of [condition,than,else]
"""

cond, than_branch, else_branch = inputs
return T.switch(cond,than_branch,else_branch)

def get_output_shape_for(self,input_shapes):
return input_shapes[-1]

0 comments on commit ae58262

Please sign in to comment.