From bd28835d620e8f17d61d91393359927c4dd58ff8 Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Fri, 13 Aug 2021 06:02:13 -0700 Subject: [PATCH] Advantage Actor Critic (A2C) Model (#598) * Apply suggestions from code review Co-authored-by: Akihiro Nitta Co-authored-by: Jirka Borovec Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com> --- CHANGELOG.md | 3 +- .../rl_benchmark/cartpole_a2c_results.jpg | Bin 0 -> 43042 bytes docs/source/reinforce_learn.rst | 112 +++++- pl_bolts/models/rl/__init__.py | 16 +- .../models/rl/advantage_actor_critic_model.py | 327 ++++++++++++++++++ pl_bolts/models/rl/common/agents.py | 30 ++ pl_bolts/models/rl/common/networks.py | 34 ++ .../integration/test_actor_critic_models.py | 27 ++ tests/models/rl/test_scripts.py | 15 + tests/models/rl/unit/test_a2c.py | 55 +++ tests/models/rl/unit/test_agents.py | 14 +- 11 files changed, 617 insertions(+), 16 deletions(-) create mode 100644 docs/source/_images/rl_benchmark/cartpole_a2c_results.jpg create mode 100644 pl_bolts/models/rl/advantage_actor_critic_model.py create mode 100644 tests/models/rl/integration/test_actor_critic_models.py create mode 100644 tests/models/rl/unit/test_a2c.py diff --git a/CHANGELOG.md b/CHANGELOG.md index ab2004794b..40fca95982 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,11 +4,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). - ## [unReleased] - 2021-MM-DD ### Added +- Added Advantage Actor-Critic (A2C) Model [#598](https://github.com/PyTorchLightning/lightning-bolts/pull/598)) + ### Changed diff --git a/docs/source/_images/rl_benchmark/cartpole_a2c_results.jpg b/docs/source/_images/rl_benchmark/cartpole_a2c_results.jpg new file mode 100644 index 0000000000000000000000000000000000000000..15be42c746d81afb4d9a5b07c44f103afcf4779c GIT binary patch literal 43042 zcmdpdWmH^C)9yfU2_D=nKyY_=clY4#u7N;sf&r`)0IT=xS7%Uh7001v8CZqrWfEEJ)Ak$C~KnW9MK{fyYXK5xV zC?_r`NFZl#V{B$=1OSMAiBpACQyh5rUGr;H)HrB{7<3EtkD^#a4KR=b|L`aPp=@?|N3yVJWE&4EW`hRsz5rmsML`T5XP>F2<5xzOg4Y z+`I)Ow`hGKL7xUCko`s*9tou%AuVnAl>-1{3IK6#_aBN8SF^L504%IMom@YYT?ROh zh-X}lJ-?W9c>0Aw0vPZ%;9A2HAnh0dw-Dcjk`Vy%9hx#D3GZ!@x*HV|3A&5W*d)KW zP}qJ(W`(~c>8&Dc1@*;>8;UaqNw>p49|qNVH#Kn!R}q>>i~J^?9cP2@d(=>nTJ*zd zbMpX<7Dxr|Wda}E1qcu3s+sYVG_Rpp+RzAA=b#&WqDH`2{iuNs1fPq*Q^^K*NdznbXGDG33IUPpuv_5#a60QKI5oz zAgUrAWN?vx$@$?!jB2ZNpD2<+M-OfQ)2#xYxrGtTOyC^9O(Im_De?|xp8gY$f?G3@ z4^{9c_(m{d4v$I{1$?CHUBb=)#BQh+QV1J8r6+-IB1S2w404@yS{~=K8S&iOBf`%Spwju_)WXqWetaWuq3VVHm~#jXsYL*;&IkCW zG`toZ>xW0*Oy(kAOIm?dRrsS=J#Uwq5`843n8!=@&eu|?Y zKXw0b|K4E)qmb1jf4A@U@T71o=q4YZyFz%5|L}t|C}$+I{*~_++&-*w$vJs&JPmDG zNR?T=>RUfg`}ByP_-FNni6Ih$TIO0D+B{Sb=irm)XxNe5>I~YHgUCCK>ZeBpdiLF1 zyX#V{=&Y&$)Dx`j%>!z~HC%r>8qdv;RS?AXlamvh`(EL$7H?CKrRT9$yV|}rLjinW z2P3YAYQBN>5Lgf%Ql@noMzGv=8%5j?y~Xw8^AbNEVQN42 z7D`Blk`g*IztuQWKEj|t*0_)(XmX$oKkUTaPlamG5`mj^r=QT5MDA~_M0GQ{G1{l- zA&eu~^?5L`LwkOdOG$-{^h#9N7sC|wV6DG+!dvt;mPfK7C-9#9kN}75+h>o0Pp99-NEm`^ zg>3546~D}PV|J5vhj)8*8yjGsL01F`5n+DC5eLauLr#d5r6epNJ|e^=*Cus}a-W`> z9-dm7>YV;D)otEmzHQz#9Z{gI>`^)+Izyxx^VDDQk;je04bzRYA=)!(iT^muRMCb>SWO(Ytc#)wz>Kx@snICvR+?ju3jbh=zxAx+MON4WT7h#n#4J6Mbe;E0a z9GJYF49QGu;cOONV_7|?p;a=k5vL(qZEbdCI$5J#V_0)honU%0!&CMo5R`40g_7Gf zn^aaMy_-FX!XZ+ZUA4z`&voy7Z{{9EBR{K>kd;!RncJjVr(CBgD>^1ft8SC5ld~e< z;tbJSu5XLc5M~tiCM+dP_MYtigO}g4!IQ>Q(LFxgXG}0`C^!lDYs@9O+mHQBaE!zB zTKZ}>x9w1@_HfzXMb-AXTk#6FlD8@jc zd>P<~i6PU@g`3oyw49vI?UUjh`DyHDOjI>(+&A*9w@;uijd5RRl?XfH;O{@Buj^! z6Ff>eO0%>=wXwBV>)GqimME7?n)WnvFA-fW&+CuGZ*Yzp#;bZRLKEr| zrix~YCKm@empwk+P~&T-v89pn?0Q1C61F;crg-w+yWMR+;GRAVtq$<^`P99HM%NO9 z6W2w0BI_Zw!_Q%ja|Y*#H+&SCQdh(+CnF5#iF{<{tX>Q?}`E z$MP###^D@j!r>CZ8u(Xi3xpb2%WPP<+~${s6E}v%hTZg66$J4pO_bOC->jBm>vj7O z`?|vma4tTqFnb$6&V#vkW(O=opP&dLCy=g4R!V*)uOYD@ieZ2AmO#s4$&UW2VNp zlX|(_n^=FS3t5cm7?_Ki7%{wJR9U=(rg69ga()tpu%9F>-?W6tLY% zgQQulbIHNU;Nfr`^lk22D*ICdMZ?4S)cKE#h+wB~_6@}tX>!?=B{I9%Sh0n?TWNL1 z98E0JWkSJ^C?DbP8c(yHzb?tN@;&Dhl&4pwyO^I(eVVM!`;w=fd`#zccDWn49>_b^ zZgOf|J~>E-NA>-nLp}rE?_oR5(+(By<_QO!KiOR@XgaIH$Pz@B(?|D=crrv(lO3 z%a7rllA-2!D-9Vf8Xa#3;qee{sv6aIrJj|*^=X_Qp3W&+ElH~0@oW&i+o_4IIj=cF zn?QpV-?eF`6p8uY>rWBssls!!~*IW66hV$PakNwfH97HT?lRDz^=P-s~= zScg=HX6aGO?e;LA=H*?wqF*mKU{`v(8QKk*_#$*5+XHY&# zaNoCMRrBY?x8kz6Qq$T#V?0kUuq`ArXU zib)?Pm!%7Y1IY;mY_n;Waclb1?dG7m`K9g0E5{LbLh;c`+nKWmKB z&~}7et~77fKKv8DXg|o?vu9fN+DKg^Mu`_7{BSi=6E~KY22cX8p#b0@$bdJ%D-hr> z00avF{?9c4AOV8?pKAq>_rKbJ0sui~0I*+eG=QJ4KT*Ipu=lS|&~(t>EkKLY-~7G? znSO0*n&S-({D8IsK)?5y3xA94)ws)THGI1a0h%2v}(7Xy}M| zUa~b0`Rj{OS(;i<`*Q(b1NJme$3^g~o-E#>U=+mY$uRotBP)mVtp9*n--@ z&Dv4lmD<{Y_}`oSzK@WRgQ2~dt)rQZHNoqB^$l#C9Jz^zULW+IpMUiOy8#M%E5M)p(d0 z8Myv=%>Vf2kD`C|RQuDDm5%Q3o_~GwSI<{XIArb3jDVB$Iu$(hT(ti~_TTLlj2vt% zonD(OTbnuZ02TbF%YTV~_2Qy^o!!4?``@MfM+#g^9vCj#|JVf&OxqJ_FaW>@5EtTC zas@d|g>=Q4p8g!^ep2~S^0+Isn<7 zxnoZ6lgCLVsbu`eE1km^Z{o4zVav%vtB28xV&}osB#%{f@zw9_H{Z&i1_MMuN%;P} zyd?;hWFLC0qXE!S|9Sa}`1;4;gCP@u{O8q1#@DewTNji{{6~hIuD5OC3nuoTS9=+> zT>1Ugk^b%A^z_5m<#|f(p|-Kj`-^V zK(*5TdOiWj|EgB{BJ|_!W+a2-Pr0M3wwI@}#j!DzufIna@D2oKX1s9~hx*T`!Ek^w z-zK9&x<9i2_jZ2P=*863)GZ?AH8nM}r8&C!B_-0=M~iu799}2V5Hg&4^K;YBt{aDD zZ8@A;q<;oc2IZ(;J~K7dpB2eNiLUJ!E%lvCy*CV#I0)A{@$tH@ElxUxh23$da)>5I zQI=cz`F^&I<8YGoaH&Dj;*)cl*L6~>x<^@I;n${J(430M&scqs?vIaceFYm z5RKU&%7B)U7;)hYK07gXwVj|lI2cC-?v=)5_7%r*oL@mx(}H9Bc)4l2_2D>15P@4s zBodEWF;^-YmZqLh$NSDmk+b;ma#FU|!Dgj-UZ0nR>vCLBBZd6CYvV&_mf8n8uY1REgkF(9Cgyicf;b-+Q=(ZXQ`)eD<=4)H)s3x#=#Id1 zoBqvhC#iP%4Fp^iIPAOr2wcvso3@u0Az|V9>pj>s@B4ycofZuuA|j^sJe$>8s}w$| zb+zT|!+B!Y^>2u!Y;Pet=FvXsO^#jN?WV!re?%=1p{45e>BTe=b8~Ar`ih2zmK^*! zO2j-b2;W=nYfMbd6JqtjY#F*MS=+-gCnmi?z!BSrwS}^_=SC$J6$x?iFIqlNSIF<) z8JmA{$n>G7r)RL=5?kcKfwj9^au^$UetNK4JlNj{yMZmMTlLmlaoffH5*k`{)F+^? zpW!@udFj}&mz%Qn34y2l$i~#PoO*Y6*WyP1V5X$j>L;bjok^iH-RE=kkuMp?I^W#{2yuQZ6p7*86=bx5J7)%Vh~Esn`O=JoAH9 zpO+S^rq*)DJq77v16%{o=|bgunnh0g)XueeCKy>LLstnxl`-xcEw{5CbY`UcNS{Z+ zi}H;`cx)D^#0G{ia092)wT|S3TUnp`80mlw9xb=+xXjRLSlVVelv!QRvq0xjyUpID z<_(dF!+rN9nLm#07|L;>>wZEMAG+E4)E}r^(T_SC-9;yYsmiJg&}!9wW9VWjNt&lD%%%gIJ;+yWm-&bQ<~J zfA>5eNap*rTiFC}`1r8;QVMN)(6FxmdAClt&HMcM;Y3nEz<-8>nOPa7Zen7>zwaJ# z_7sqXdzhaU@q?=z@jGHybo8|*$Qfjp(>i_6eE>8{wbiYk)qR!=F|Xyzn9p-o4d=__ z#kLbC2WZ2fhzC=dOs+4R9qh~^4mdniX4^uYVvpLTc ztKaW2odv{IS2K-FnlIYv5mHcyA2%M*9PC)D2SK8(=oGD=-6QZmrdxehRYx!Cbv_j# zI?QNjn7<6h^+$u?w+ZK;20zG2&?&5$n5{v(dUIrlBgKmlpOT!+gcR~%e^fJT5c6qo z9rF|p9Tk=Bun!{m@a|%lGg$f1V?7Z4G}o<_w(Zg0YX0`k$|lgAEZ2V-`g2x|lY*wc zD^7G}a@di|6L`zrX=v`MuHQQ+TYTV;W|uot4)FSlHd+F!gY$-#430>}2dB)6?vbvy zOT&-{exzAvs!yLA(I5M>1+HB;(#-M#M{jyx{_#T1&XQsgdnj@}qI@582MY8T14n`;yPkLKp&05O| zmNeImFB+@gYWOh%te4&JFCETPe=h_iZcyN zK=W5dfaQZD$MYh7i?Zu0H(63$t)yG+W@Ux#SudxelB%n z2{4D3&H|`pzI!`^2gN75oPUmDM|3|KSe0z$#*0x-X1`ZQV@@BFC}+fEx!UU0WEW$F z8Iw)myk4VOZ5n+oraktS&g0rr@$q9CkC$bidKauEr)PP3UV*?%)Til~4nuRM;Zso! z%I@!-Ra~C8wpNwM8+H{45T8%ih%NwJ?m z>*z?Q&GUFax+WM*D>i4-NS)PVccQJo>E8=+(2!3105(=yBsyPOt2s8UeQ+AwNwDNp zadgyhRjtIPGX~)+t|%hd3y!g&c2x{^Z`bSWjsd7$hxKJVpd{Wh4qWFM{}ZmIvqCv) zlyAthz|kj+ZYZU>?O-krJF5~wXTqMvB_dw-6GxivGRoBcV8OZdBe|s;PMjqt!=1}% z;b+=NZft1GZEfRRkr(5W-3$DdcmI(Q=uG*Aw<%0=WxokH17c%(;Jmz2R`5XIMV_n& z3Y!Ar?W~G0tH^;~n0#hgBSJHFVtQl}9i}Dq43}L+Uu+LNvTWKGKg{qb{6X*N&8+zE z2-W{}X$}a3a5AIYg^0Fzg_Go@paedUZLrG+>9BGQx)>vx1Xl!P=mDf|8qKLbh)-Iv7ECUS3JXe_SB*dC|H_I<#Pl+<^af z?@lAU?|O^jJenv6ZNn8w(EW!vxf#0@A;(p=!YkKnbN@4}X^6r0Fi0@m|fY1UOeS`UY4`XzSKHgc};`+lTNK z!AbkrtkzR_vIlnA_};e})F8ZMb{Wa2LRcG;E%aOLPRe@&lWqivNO*v*><>>!EqR-b z9=z66tt13EfEzG%O!v;q2VViuTV%n%V_vfHy(oa8=$f){!OVA-t>^A;A}E7w)yf?& zVNh?3AKZkCzO-;&*zES%v@}+YWX&~*(cIe6gY#Z=P61SZ9hkzP%NzPf2KXF~Yz#=;zawrLGh=Otjx2GiEF0+hV38uH#}VP00kjdtFxRZ z+D-FmO~;`$(2;LMJuB;%@_oz#9@Uj(DugW2*C9EAUN7qg2!HtlPAIT2y&gTJ$S)s2 z+$TV(0Nur-NsOW?a37cZoc{PJ0>5?=MBrHvjD+QTo)xUx0~`4eJqk@<-7lj@t%{Xc z7uZ7f2Pz?tLxX+Oo8}yK`TGvz1mk;L#TkA<4%ABUIKz(r_dnIzh$1% zzW&p*F21wme<@@u2^3P}9smELq@P%wN5dqs9}t^K7#ZdCIbU5}eN9aKS%HiFhe+U~ zLlFZW@EL+Z?#x2&YAv194ZXlvDLO|oLBh()YF7qvOY{}oZ3>E%+Blk-6<7iRrqi3_ z=*LG-hVxVQb`z*?rlrun^6KhI2M~m3A?+yTTD?A1=H;o(B0cRm`HEPRSk2re2X2>HqRsqDlRUrps3ij z$|%Z|ORcZ3&tSJ9G|5ux7!Ker*52&P`r8ED69M2B2yb@w{uVX@X_Xp_M^~mlQB4#d z0??|C^xL_({}#?d$(1{Y<468&{6F=9v4ibwOPx=o&BN^`kBl5U?{4;PqL_||d>zLGS zX$9R?$^z0FAeJUuccqw3xW*&FOqe0RW<4ZU_)BB|pls72pboh^@eXVRPQT)>8 zVkYFD(66)0AGiB!fBF#+%9RO#3|gdcG9}wm&b+lVAU=VnBWd?A=JNkFTLS*=JZP}x zO_@63$~|BhmU71a8n7sbeXxXYTqDCn=&Cn}pdImn6L)Ml!1(Jh;gdspUtk{l4y+Wb z*hV3{8$VHPy7-z}txk02u<)x>pt6G63PD#DijmX{Y7!TlbT0guQcM*bq(d|qhd6Lq zW7b?Tepw0_qhDjLN17Y{?p{Ah@BZ zrco}j`@FuC1=~QdX3};H6`L;=zfBH|-H#;>TCZ*pMvv(upg-8=#=)z#P^|R!A@Z%k z^c>!hWbJVr><;bMksjt5Z7p(;GN5_72Zo0KGH!xf=tKDoK#5haoG;ygc<3bF3HogH zAoq(7^+Rm_QLr*=bw734hXe>_w@?w#RMIsULXfyP<@5a_n*ku_ux=++Qc`k0%8B1! zsBDp=H)|~}BNHp1B{~RXW=Qw;%w!^IKP*x6KAy&WxH~f`C@hR0h#~KZYutf+k}vv0s^skO#(K!;X?x((M` zF4vD)a_zEW&~OvW%3R{~UR+q%x?Vg6vVffGG;NQwKk^Wa2|tS)8-G48{Z5TA+9C@C z>mUXB`PJ|)cjwCI-J4^Bin6>PcE7vi!eK;Fefa`yzcifnvD6$tx49^i`@Xt_8~xW7 zBA9|gEr3?EdN|Bvkjg3jAqdiLa7S|iMrCt5l|SwqcZaru(^*TRN4)0H(dd4~!T!=b zwB72lJLtAA|BUOha(FiJ*mUCk-s{15$V{S~qv~w`>{&l(BYFE5IlTXT0l1a80OXXFBzF@mn^xmgN*^LuXt z>rx7B>=MPMS(S~z%fpzS_%x`7cHWj<%j0QcK|6l-9pf%uadtnLIa(O@?k*(?HlWi&yZ-}#Z$5A=xR!iS(L4{oF|*r zaj1%q*q`6UOHkGElIh7`AvU9wY%im2mO#WV(sXmn3DK6GP+e{{BvO%mEz zy@6$etb`>YCDm|Y*M8-+b92e^9xlVFsdosRy*O8YIh^(cu&Ca`;@(Hz57dMjnMA1% zRUt(h*^Lw5h10XV#*`|`+foN`igCAlfY}33nQ|6Z=1`u!!S^cRg4mqSu{a)w(dT78 zEi|gSUP;|OJraglcXmM)Ep`nR72*~)OFPyL#l=#AWOI`y>=v;6nrdqK&8@bMkdl)2j_@~W>8z4XOFM&LW%#Q&C}w`FV|Gm2>L_Iff?9O5{S?zbAsJ;jOlf!s$QynA`lEt;^*$PeTum_+vN|_c#QQAN>>lG= zYnttb-T6=ABR6IFtHR7?0f@JUC)Fi(#2W_MIjsz}lL;U=FMuV?=V+o_79hK$yjkfH zu#Af$@vufe3rt-|bL_*>kJq;I1Eha z$IfmpmmpIeNT@eizV;rz3rR^ydj(%8ahoJi^nuwj+7DG|KVdjzF=?Hyf;iNwNE>bSJiL&3(6} zZXK!F>{@`s<@EKhjpZkk-dZ4Sj&)>!w%H*>WnzFvUn=X3W;S2Lv@Bl)*CT z9NOq1o=SGy_avOSWkZ6vAJwRu9?^AB^7!#Q;kQ#DfVsbN~v!WAv^xM6rK~J$Pkc%}OnrE94JZp0297C{Wa40CwDK2g?%V5lq; zDM8?LmpD&-zjcgZDnM0h%wTJ0W)%Hd5%g zH|z0}2&Bk`S*NE|Bp>tnJDgeZ_uG$I1@jY|mD_6>d;bpnzxa>zP_SbEnyhs|8O;)5 z6BU-FYNP@&p&P$|{$mRV|=da#}pf!_<-Ss&c^K6|baa7#M+72-cI>=k+ zguX=s8A~&5+o$3DZF(enBx{MIgASC1K(KOtEDPQ{*+1O}pYSoaf`NWLiu7bp4FOPBv5>PDQe2^Ey{Y5#_wU=ZYT)V@h)4zyDBwdI(CbY{{N+ zmZ~3|cE1=fv^>xv0P!IOmSdE;e10POMKQ60tx;Q}%@e`K6UFBd{`YK_(FxMTypYm9ABM;l1&qbpV1Er$xX& zt2hRx#%L0-SfPjaw*9!LD57A$~<}Fz9bP2MStLnOPpv$x= zB!YB;A;KWbKPqv*s?;WYNbpPdL};~(GO`|#WNaL@hZ}^u2(w}teJrZH^7@JYu3TxC zN^fvVOYe+ZmRlZ_FDDA0uF_4KR7-WX^}&pbvE+@FBtawq!{J>A`C;d?XYB3Ws-&vg zbR?s#or;Q%J_L*oN$Ke)tQO^#=EukNu&;q6Rd)3hj)l=8W+yB#5>!xErvPS#38gxo z#$68Q2uDUnK5jvvV5~D%+oZL`=B4ZqHo%4z78X{ovig9_P0x{SJ2`h?VWW{fP~wtn zwNz&d1iqM2 zP*CWEC;3lPX+?xwZzF_!f^R|y_Ru0<`BlRykY351S7om^K*g8V^VT%XD5NsywRG0*HT=7XHs8Tk|x zzD`X|o!$(01-;RQ17`hyW}vxv2m&1swua1th={ug2HvQgKKal3FaP-Ic1ciXc-L+l z^r8xxLQU!=Im%F$B*7x@F(wtAuf`5O#C)XNHZB2AguiwnI+!_@hP21m4V+C9UG@6q zDdYWeICf*4NwuVOThcWkZ33}GZ5Z_LIV%v|&T|#k);iFlM_phYo3l75?YgF&6U~y!P7&Je~nz6RW?SW7nvS^sWTC?K5yIcL(Z zj*E#zV>f}68|_u?S!d>TA1pN>7DjKi2?GsA3G{({Ee?mQ|7O#NSS*)y5BfFBDna zLop!QHqEW9JX;f3d!3+1kA5Ve7YD=iEBtI+;XdVW8H^Kw4i+_7Xtm0-h!iXxuIfK! zicifdTg$PNZpkkXf&~o*3sg$4AqEE$I*M-`aZYpc!?x21KO^4TYzweL>Xm6Bw70;5 zDI-J46weytNPXDKH@7Bc>WbvRIn#MPIzZp0yv7)OvM4~LAH6ffLe0+=02K}fg!L92 zMK-f>VyPMFd}&*1NgPb6YhncXuS-6Cw511J6wuW6m1MF-#6gQaqnZw=iNM|+7Bj1= zpf>@5f#R7#w!SFj*D=x#;grEk<)hSf02z)JG%}FN zVR|bep2Lqiq0F(C&|z;qm8170M^0PC87g%Z&`bLT#oKvb>-ToO^vi(i3WNcEykiel z57aYCE<${u!@;ly24B?Q=Xu=G-llUEi&Pmy>K!W+8}MRLeL-6hbkxoW+PA9O_j?61s7=Xq2`a-8hRJ=wbwkuah*nyruR! zNnZiRuB=Qc_)l1nodc>OSQ%B8;kwcxp$zcPhLRLBpOlS#4`bC7v4P=YIU*geCRDt< z9nR~qiO2_Bh3TLDP?K4D!SfKznoAm-Jiedj3QB8hj}lbk=vFVS=i^h6$;IRQwo)4t z({QIUxeodf_3G%DWt#Ma6Z`w9hI>F|l7&tRfp+QtN^X}g%SQwO?J^yTFF5mS^o0G7 zjqK`1@$H%FF+|Z17Fv$qD%8hmK0)KT6q%Xe^+Cy_hxwyn5`}%HrHV$nw)I83Lj&@; zdKSo8g&MdCv$M6p>MbZ1P&jC)TF$|q59_0qZTi3 zE>%=uP5OePU@zHZk#YvSNo*++4Ie=EbT%bHEC$94pRPvjwSWWnmIbJT?0-NhWQZ}S z_BLc;6L{meSk@e#5w%e$PG3+%H-X9KVPjrFKJ8M=i&#g<>C0{?l!yUN#WkK*LdgOj z{;$ypyc&(GB0I^e(GbYx;yqwf^Kd9^!0Rv+B`~8ZXsK_ybwGaaioG=jZgHeoTxP89 zavO+qa)zjWgET}XZxNp=F@UHJ$AvYE`< z%Do^bg)1+yx9=$m%&Va*S(rhfE5hgS!XVDnt;Azjd5Jp+MbP|9=T(p_+uDb)q)>ho zpA}Jjwgs48u)f<(v3}VLx@3jcn3MsUKosIBXNQKmI%!$!eY}>(5gH!S?(WG~V11n~ z@Wdr#X0G|c$|;$P!zC}n`-J)>f~C)AkDTg%XOopZlHQu6qP!AK?(x3UF&IT;BWf|7 z#AU|j0@bmz3)~fIZ&frkKQo_K)FRsb z=Jf>??Z}R;&mSCiJ3i&)Yy8f;hU0rGqm?d%{_%`Y+cmSR@g65hS#kW1Y{A&CZa+N* zn(czbn5?d^RD(7y7p&+jb?pcu(-q}s`Z+kM;>gY`SDgbjLnTXEu5jsn3^lr*3JpvG zoKCe5YvIDiC;Rl9_Y(1>Qk{dL2PBAaV2>i@)^p|J*3IiInB4B!Lc$F$kH^cJG-F&( zvUj@|vI7)VvQgeomCPrm*47b^$DfO~3yWytG;JD&9DU+xbxCEpo{PDzF1zfH-Ao33 zK^lMwLV~^#E{Po=S=|}ya#bOk#)8|0_KIA937AptFQfPuqN1NUU3IA7@MgdNOrsrb zP{oenW(1IZ!z9PTe5jU{=JVK?3_m}AskKNML5U5ceq%FFh|`YvEUz>{uv$t>55K=8 z2$}SH%5?x4P=BYZeJA^8w}c_Do6)$7QC`X8`U~U8$Q@m;ZEE(1Lq$^ZKDW^Z6xrm0 zfuS~a1+<25@!D_L>P;_W#M6-qk zF=M^(nW4F^^aBTngfW%P(Qca8Q6cm{j5c2abakk5(@)!7@hRHReVwjw)^%$Vb~Sai z^Qd3Za+(JRAjUFzj>~3`qhZsuNERg*4cQ1+bhPUQ=`g7=?59-JL%>N22UHMC=%jvl zxS07x2zcbVOb2|T;+Fx=ATwuEWOkm%4ZPj%DCpiI>*7%f1tn!XFpM*`YJIf#!*QF^ zw_Y?PUjSAZu;MLNfV>{~j`68ne-}UVkdZk<1C}&nFD4#KC z})E%Zq$@-JCXr=Yr^s`;YY>j1XEU|?_Y0TVfF9)w3-KAY>p zSw`%_X(Pbfk45S_0gOW$s9p^tx6e(BDk)LtPb<2#oYq$Ukm)TdseJ$KlQ{}Ex6NuJ z7hc5KkYSL2>Z5>Wrs3tfO;MH7z|-v2w292{W=$Lkj=tO#T=3Ai0jtkj5O2Ca*p_;? zPzGRdoxGbc2n$==fcmqHg4RMXCc4PmajjV}xJ$7Zn@j_fO0}icG^kb5gV)g9hyNWo z7IXj%$W2wsV=O4%YExC)^OYU@dcdZo^xzPc@}YT%)$%eZnE|1&ih@FBlkw6&7)A%B z+?dBP@R@0Kx+d65n)m)uKa+a(cAl zy4__igZ`nIW&}XQ?IL{q*DaS+6pFyi8y<^w5DRmr!{469TP{-}Ru;VBysjHr2lFn^ z(CR70DpC?Iw4^TBmuEfXMk{0AeX5}(FIYYZJ`!8M91o1y8WFxbKZAMbUQlgaBn;kT zawnbOL4<=V&LHZw1z1oPW%U_Lu1?kSVPHLofi8=*pOqQar8c_A2Kc|ZS!Gq-*f7}s z?qS2IjkAb_jr+2ef`zMQW@-6jblozo$_hx812NW~Wr7uUskWNh`y$mcm4fP@4K;r3 z<$#H@|0{q0=xJtM+jo119mzr}1WT#BRFGjO{eCQ)`Rn^x$ngW72S!NRH<9VHc2%vWTG`}|{^rYkL{ zk%#Q+QM@0B1w1I>q(?8`srDDQikNe91#>UW;%H5qWpV#vh+)>4PRA?G_gKA!G&KG1u$ z+=NJ2RA%tR(K5-=uabwjqFwV?ImJrxIm%GEXRk@?r62&u0DKO>|? z0a=NO7;Sn=@~f0*(ARmgNc;Gn*6v}&w1o+*stfo&2ZpJs31n^GXZqrW@Lvuhu%?O9 z%Bncwcvh9jtH_PTjWZw0>7Txhu5{$ry}V^ssuMFv)Gp7Z9P^M6cxc;=WQn7Yw>jD4 z>oHvj=DCJ%X;xBJ&iO(Qr0L=D(g>kc80<&#o12%eWY9f8HOg6-nFojwcyqWCEVe89 z@M3|9J~muqU{)0s_1!?g7dM%k1MGDtQws|esNo2W)z(pm4y?$aSvp_}1p)zYX?+}Q z^Q@y4$VW>oB04xYkWx_%)s6-_gT80RK}XLcLw-w&j1G4C5ZVU)JP36skfnI&UlZwG zZ8g%yagvflhEU8ib^8+vllQ&!TZPD%m$T`2UR5f4W(DVBI?hUcJm`qnx93aCVG@N{ z;05^$MfUTJyY?WF?V6z$q&y?f_4r&GQ=2cMx|-Q!`15F35;H3vD&KKVK8ieU+j)I7 zdq&p8v2-}_cJuMddc0H3($ZO+d`tHVnW^{TyVO!bElA)i2p}!RR5JCJpn`7ir;qN9 zkDfBKMDib4+p9}hRj&)a=rc1jFF#}MkI-LU);P+qF1ML*FQ2S%@nX1H;67*^AE|p} zaC2p->eyv<_w)@&zTBrsR)mxevpE)e-YgSEwrKpyeDpd0TXnpG=*w4?YbX)I>_ab7 zyuUGGlsk#XNTd1ebOd*I{XOY?cw2FMA&GdavkM|+q4~_Kbu0vm-~)&Nj&dO#(lAD( z`F^$$c!Oy;;|Beg_xxYZQ|eEQxi6m%%gL;NZ74rdNFO8h(17OUk|^x_`Sux(Ra-B_Fb z=K{zZgrCKTxgD>C9VB`$MZ4R6t(MQ-n?5m{#(1Q_0#@^&M!)*n--()e@3Re3%4)q9 zzNH`k41p6`&BRx;rASgKX(FvQ$*{3BiMi7Zok2p)7HfAT~FsU5mp?D$lOy^LZ;oh_=8&e)BZX^PF!E zqkB5^jnDT~(;p5f-Od74agCA7v>HTqxS4c!c-d#O>Sl^UE9rrFvmeE`4*^dli?AFEqp)KtiKE)6%Qqv32;JQB(@v6G^zD?x$!OOk!O6ej>fDQ92fzX}DQO!zdYqxxJ z>n-EHzD_iZj9RpQ_%uVKwSAMqxNDtMc3!PxKfwGoHAkY;ID2S;`JmK< z%0>Q9F`X|fn0XxQ4~)ex5;E^G(ic17bNtUU8!1GS>(5Iw;RT>86%Af#deE_A>zDgg z=Oh*`LEWQk7pRksK?rKE*rpg&O8%4#yfZ`O_-TTr6u7un38kfjC4`xz=1NzR=byrg z9I>)!#RR8od)&G`*QImn`Y1nwOS~M;WP%cPEz&$2bo8M!#CX=UpVoE-<}=9*_8?C7 zr~BA%>JO}diIfhRc3SWQbluk7uMfbwoUJ}m-G-$TYe-{_RmVvujTmkZB^^m+0=;fr zRtnj`O?(xX-F2RSqWNJzAC&nZU$hyF&jss!x1fhL?(D+sR-D8zVWD zMzY~>WB-S_ul}pz`@RN52_+;%x?4m*L^`BX;?mvS-Cfch64KmDr!+`+UOJ__Ti_Xt z_via3JpSOzADo$U&z!x_K5MVFM%vcFIKQ@*E>g<}!y9GLP^Mfo=}au*Wt#f9Twq+@pAZiTT1LJ`tHJazQan` z+xf9LAoeM{iJUj@PTA~U;rV1Cp$uA}W`qDy+p|TGCwNxt)^mUM7WUzm;UM=I=#)hR z#H7zf7Z_HWf$##+FhmzSu`y@I0DS`BaCKwl7q}=HMNFh}cqmB=NqM6jrqYn#&IrBe zU<7kJ)Uwj*)M)LdFds|H*v(eaYAWp%J20xR8y-i z#%kvzH9p~0Hq*faBGhlvt@o)!I}{&1f3E{W0BGG1*+O?diu%#4tnnr_+$OZIA@IF4 zA(IAyw5&~0un`ocFt5mA^V}oNMWnHi)yv!$W8M#hM|;QQrsLvAvVpOvpA9XvSE?7m zz7uG0A-0I}a&H2;Xbr+S3ocOF|VM@?arz*M#>nU6W zhx!)uAw-#xrh>x0x#x_#BK>vpDJ z5~jEtas&gpDCZF_I_t7k2H>J*O7C8lE|G0?&{0ksk7DG(%Jn!(`4=TZPIayGs}eWm z(WBsUm6Jnarx{&?m%KM^!%t5jxPDye-5O8U5Q`NznenKU`HRRjFN7##G8DGIlQcW{ z;b1dLy8lReMrYHIUopCJ%r=`nzsFG7-De_yyBp+mL&z1zRAVGVoU`sA#H&MpLa(3e zDrFr_SJVVL{3U1zV`f0_ylCa-;TJ;LIBg`2Kf;;z6C_Fk1E{7_;rWbagdW!jA{N}_ z9*pWoZ#g$~-(k<{&i7$6U{~x84H2nfEboh>8lps_3!>%<>madT_GgV(LX^gVXX`qx8b*y@{}={hl~^y5K7%)*q2rrE4M7O8d64H<>$b-H@Re{1D% z*`-`JIIrgC!W& zz3saV!IG^fT|7Z%Y=qp{xGKX;8m*sKcE;%vZhCFC4<%A}t~x-<1h@=5RnpYYCdjVk zI+q-~mK~ZDq>5bccecltzD$79OANyGEaaFfafq2Z>;VDJsFxO=0PN=BK*WEz^{jRFyqP#wvag7b^w)jP=pzyagH(m7{eqqs=@GRZ8H6v2jYn@l zwwm2t+Dx@I&%)Zo?zZQ0gicFk2Vu`elMqKHg!f)EkrTU!H6@!wkicY@2?X!t(AlgEVnc@Al+xwx-BZtn_dx}_kmYUa+ArT7p#nD9H-6TRL#V@t_Xzg6vgxkp} z9a;pxwbLnr8qM2W^TRRi0c`WjaT2ljbL>~wxt-M6H+3fMRUw8l9yPra@t9O=-x&yW zOb3A1O$xVBnoIiL8a4j0Gvw+ELzJs#)K1PzLxiB06*IaV-XWMblmyBleRYz&B#FHy z$cm>c)gCWZfp6v#w>(7jpN0+K&2QyC3rpaawbH)@gjORg817B1a-6n z_huSX)uyoZpBF&j=O^m&Y?7wFLHI5f(RF-mh%;sM7M_%_>|Fy3L5Jp|&{+9v;LVVS}8aHdnCn_&$plr4Jg^hYSoO01a@LNLArm{pnzQg|0|Ip@+U;(maP- z7}pU5DM1%D&oJg2oi@Md&a-k;^{HB?sRmFhc?9-U!_hmtH@j-QYi8>*4ZHP(>}Y+H zs&*$oPmiE@c5jYcf>0fAo@SCyY2#xcIszfo<{88elmb4Qq$64Iaz5hQ{&X_+q%UA1 z+eqQGl$dgb+x1+z86h@qB8k~?Y^des^*V|?%;;ge8Dyu7+|zMWSdrZiX9sxKHu)-& zyWKm7=<_t+>a{IG=5H@dgG@orarb*xsIoR-zw=M(&7a-~RMn^@X^$z&s22UydEfMY z4=A9OG>vYnqc3(nB%0+{ZMe(lHECM!>-d^MWE_8X7kd1-pnzVXoyK?my+aH{?#P8+?J_t+U75?$(voA~yO8e?lFf@^mVYB#N<|-*aTwZ~+|` z5a}Ry?Z{iA?W{>7pzZpYs?ub7UukAHRaiF^>X+DPIWewQ%Bn1xb22c$X{>h6w;+7( zA;0>5_@l0e2{6DRq@&t)b{xw`zX2aqOuhKK z-_-&|$GJm{r*Z^Uf57H%RIT8iOg*RS9V~q_&Q&@ciKH;>sYO-Fa4RQ#E)nx`hqXeh zqUMFwk%d*!=Q?>pd0w3#6U5R)`_)_i2(Jl#D)73^7A3=lX`zoiB;MCbrG^D~E_-@p zg|dCkI$64?BF$4_EGg69)pQw+DTG;E%4g)9R?w}@zwJ1U;X7sk$})8F6Z0}@;O7@g-;fYRuGZt! zUDQDEN16_3oin&zihI zW%ZVpagV1?rZeZ~vAu|ewqYaJ$u6z$bRXY4|5Gx-A2hT0Aay1|w!I$!>DVg)E9rG( z-AAn>JTUw!BQ!BMmTKcWAzibxz5UPE@GL(Bqsn14FGLh-p6Vq%r7Gpl;aYY9@8AHO zxK#(ByGiS0O$R!R zcu2MvZ}mZx9|U>-@nwyxTB8hcG%cFOM==c0{WKr*(K}>C|EF8N4pDR-ZWGgQ&Y$W= zwnxL-Qj@4VtjyHJwH<-e-OlI?A)drIK0(*rA?49J+h~!3?UrAIE{1Zhb&5J0E z&|;*7SSwVTCUF#dpPhfN&&iX;oK1S+xEa; zXw9}&_7GZ7C`H{Ik7JAWx@qx}Tr3 zsLI(Z@nr}W=#MLhN(L<`&&f|0`4SL{Jfz+evYa3(S36;tXiw8THeC#zBk+#~ z(<!u#xMpu5zh4P z?g&?23xA^V9|cSNaZs zh%_|A?pZ8=`IbXBxKd*t0CblF0f!a*kMlynKe=IpS#ys0i-`T@<3gb)>zwxzR*xlyFkq(4-t1+aB|(43)L zf7MvyV?2qD6l@)ad7LHOzIn1`Gc)@HPJf?b3{b=Pmyu4E_?lvLo{S z=3BFuN$k{z0`yQ20L$;S?s+-2^Ls6WcyBVz*>&v`w(Wi%d^`a9L)H`kqp8PV4X6+} z_+g0heuTQp&l265psURgX_+h;YWZplsgeaNIloxh{&HD~#QPebjvtpnHeP)nLd3Fr z5g?B2HQb^!qbpucs8P5TRmtg@E&BwHUh5qpT5XTNgxg!A0o`N)fD9vJ3E#jE=OdY^uujGRX|pp~58U&w z9{!c zzC)F((G=(>@FYXXP`wk5FPpLme%+NvMXJ;^+qW7mhf1{Bcu&PHJS+_8hQXkLR{R(I zS>3L>CxZ`7H9y|FAZV_#Z%RdxE1cY?zwa>msapS8xC=%p8Ln?P-g(YSx)Xk7K?~_n zo9Re2IOLrN6(l~fIx=+7!QdzbkGkX=p}IJo>tsE>J}0&0T= z;XrAQ<+}vPBl|=VT+fmG31EN4GJe*rhRWFyK?5t~bScP>`NBU7J?kAmkAevvpK6J6H80k$=KdN{GadX#Qi58|jP(nAG2lTE13Hno^5NTgii3e;B@j^}J5V zmAsuR=^9Nq&%4pSnPpO4rNg2)r^Hu1^}clMkF3UQL!0b9T3t8?o2|X+^w~h3wtlns zXHOz8j|+$c#W>iS&nivGUNu@dW!OXk3@p=_<3Ob;0ou|mM$NW98*Jx5rg8Rt@**D13xajHtR_=ILx5rjvY91`(xLZ^@< zDAeSmp;8#v1^&47L5;84GZeMSs(9~nrME?u+05N-_ZwlG&@jIQpKI$67nrz>h3g56|yl8YJGB3@+hu%2X8no9xk?!3r?3zBYn_AJ6^X2+9MQ!Y+|vYyJzmT>KaTiHML2}??2EmK6tLXu-X%vhtjL%Laj5} zRY8)D8uq12pFY+8a((HQeYb8E-Qo*J{EK)uqDAPQqpRfNxfi5~6&Vmvs?2V#treV6 zA5epq=M({L31E3uIPY1Qz+&yt9YL5v@+1WCxN=#;4eeOwe#5{JFBVI6ZGV+SOAPs` z7@m+V03)YBM=BFIGkeOrKw@^|qsz0k|mJ16kF)?5)3R?D4g1Fp$M z^|k9>!6M}Tt&C_Am({Kd)g^ZaQg_P_I8#YfzrNMu2uRw0qT&0xavEkGF{=;{e#0qo zn_67E+K>j5jv)~2(0!PESJYMu-^O9{vtt=Q2~9+>ZC2>Y#wL9=o{VhZ*j=NM0u7+I zzoaRrkdTpc5#=OPOtdJCa$C}*Y`942EU_W!Q3|ElmEG}JD50CtOWyOZQHz7AQ(I@h zlUwI((vmXW>8TE%+S+2)wGR(ePq{*XFTWe{n&Lz}Ff)ueN>&!)B?p-^2v+Z!ztB;Q zfOh4G78l!?x78O}CThzm1hnmH<-&ty%Ty0psLW`)-#?bUXpc0u6)BxW(NRDiR3;XJ}MN(QizJQ?p z6Vya%D9_9LdXsLL%&_usj4iN-vv~yIS{*R(s3ma)4d*dc!bRm!3QNcam#w@LBN=(5 zzReQqd5N_mz#lndn?#2_V_JTT%V{{W1e{-IR1}T4B-OgIXZj-nGG;+c{E@E;ito=N zlP93gO#9;I+wT772HQ2Z^4P^c!NS*rcD#MnmBM|UBKAw!S0}2cADQMn4-j_W!&>Z7 z`;*8i5td?f-};dX#aCnCa0M2O!0Ql(K8r?;s&A!yw9Sa9;AWKc->Wj^<=F zNvt&k7~#$Mg0_}us|->cq#NAj4sqr>0Eba?Dam-&h-p47GAUyJ*aj#D)y^`(NKVg2 zL*e(F2Rs0%Y=WXd_<9aUO+S`Yu@WS8?+Ua~PAV}faiqhrKZVLVjRZokeEaozgSYr7@tb}cRcgt+LnL=_p@8#*QSTjpiIlIydW0(X$2-mqX+ z&{Ji3m0aADXSY@DjURzSgFJvINS9la1OJrP&T#dwS6WYcj|YZnYd2=ag8hC&EMZ{a z8Q_oz&ckB7{2HHcT9nVZXIBlYIS;-Ol@h!F{X8EUM;qfxd(cqKY&5vY2C8Y0P3cQp zXc*RZU)B5UsbcZLdasky@CwDwaXb{w{_8~P=tkIkRv$UWQE+0IgQHj2Dq_5Wdmz9A ztdo;{@k6-?t|%ssI_U|Ds1mDz!9Tt4_Z!AK8y;E2libF_)iICrmnPE9Wr?YJ#>u7&4wO&!L z8~-(5SG@b;@@+S97vdqlc|QOGtGg&(s#0n3zvuP1R=~|N?Y&KHYY#P*9y$nbIj?5B z_3L#YnUQj9?e_u3xxH5#eHv;@# zyArUh*P+U4(+N05I#ZG4Bgv02N;#2Dir&~n!Z)Wp%8c|c7S(gv*V_#OW}nBF58ZZ> zusf=2QQ8yUnl+)ysqka19zRGdSt^-8_O5+KUy)(!0-6Y7luLhI=ye*uvA;%H${a}} zCCFv)5x%&8G}8D(_UEDW7Afo!$rnvqE0Vs!p&Oid?6;mf6ks?HYjHQJ$;lWLTnTg# zkaa1r!BcSAj_SKyH|(dh&Krfj{Cta^;H{K!!jaDHk@=O~!aHUEP(5jRgqu$IB16k6 z-fA*=6=cJmGk3H5G`8O0%}3&5CnO3;8O*#E>we$L4@6w+i>CwMNZA1B@+uBEG|3#F z{(O9$=#i?jMvRG}>AG3*4qex?zI1iB#USzN*=}Ez=}0Z++&OrnAIQ)2JNhqTIOI!Y zgQMsp3E#>G=JY=~EP5aB1a*_o`iY%!;~tw3WBL^tldqsPXzB)w6h7btx)sR6?9S#8o23ZJ2u z#$NDNO5a4$^gZm6+niS=c+J=u&LP@x{q3gDxT#be@>d$si?8XV5)RHgIYVDc;-}M| zN)JY7oIni3Z2aFWsvBaAY#|UGle@RR7VuUOcIKZ@MM;J^;C5>q7}zyi=sbs^fRO=% z%DTwoL2^Hpt$J+3PMv6yQB83t?Y`s)*KttTP)ot~{A+kHup3MNHBZ<9AwkaYUPT^65Vz z)9c~UH!bjinz2@%wO8C(*~Tx5G`oqw5Ggv3hkqkKD{<|PDi#A&!33vcR~2lHAAkg) zg-$ab3P0ThDerr3Ft8osb>q}nCnvFEwRe;&f*P+S-3x4y7wZb^v=X&z10DV;6&ww% zL`6jB4}bvjWM;p+M=C#Tl&g|(AF;8zRMU&5(^SOgsVHDpr>COG5Fa0#g$>J0;1~-p z5CQmqkUI6VYhzf;z?4H+!T!0UQEz*VWjyP}n#(E1d0o|O)bn+@-P%8YhPivs^RpOa zAa+S_MXzf@T5v4zx;;7QMbvoyMPgQ;aJ+oA+=qjbQU41?>{?SvA~b%#5EZts?^H2# zuamyB9q3tUWpkHWRXVCDVIH|Q$TSiGaDO+Dq8-C9)uJCs_Ff%U7bQP!z?KhcT4apw z4N_XrMIb`^GrU;abOaI#W!Xwe3??sPJjB1x5??F5ucBo#l|Hj--0Sfa^PSqGWuG;4 zj;W2G1o(gs%);S1vqvvNLF8p5YQl}SIeqq{dCddbx+K6ZOJa>U{qX#$bjFN404@@d zd54HW&o`1MdK$oL@w1ot4@IL&)lO^o?m+K=j@|9hm2%}}Cq)sVxHaB${3MP|9ZGov z?$+lf5bep&PtbzwoM15cLg^K(^ynZkjV^n(qK_p1?YZH$Snp&Q@z?{lHE@_>Zf;mP z#BNay(x*P#UqwdBOGUl%vWfA9&{JE`2j4V<3o+%zb8^W56o0X*K0a+ih|{LdE+46P z84MmUQOkg$3#HMmizpUQ2zn=0$bsCLRRh22uNy%^;TZur)7iUT{3z4;iQRtZfjKr- z#XWvxovid5(-ownA6Ya?-)~-iOL#es9Q_4}ZMH%60|n34dGo1nlMD;UvwVEI zW>v_O^CtxVxkYh5I*U+mpUhLB!@x(if*U7!qLDQMm4$v8$qM-m+#b4)(@>#E$392H zH=phLP1k*7?1}*!2Dez^CY28iiT8U&QKMY7Bn+_VVNMqO+G4Fzs(2!q1N6=)a*B#+ zd;j{MH|PIEm&=>*7pjdA(Ob+y0k}o=uZ)1t24bkKT)r1ERa{M#x|-& zTZa)Rz!*hR3>V_Ern52nG>1EYpx}|1k+&!ffRsj3_{RTdd?>*9n)s8aNMLy>z3oB5 z5)ohGgekSX9!fK+ojkPU)Z^yu6p4&hcKLuVFYx*(Mj-HTAJBGrShHPLeh{x!4mZdA z%;Fvc?}X&%=l8YAQ};4@xx@hW1RPIQl)tNaUae58oU9JZfj|tg)G|>f<9VWsS2lES zBY!)Ryg-IaQipPn0;5JlT&XJ_fy=8r^r2e+vcHC=sP5OIA5jBIln8wHxLe&W|%WkLRNz@?lV+I#~} z<|`SEeJ>uJNC02PYn@{`m;Liv%8fL%wdPwZB_fK4UV4)qE+x@SOnB2cngz%SL{|>m zAN;znL{%mzR92?ooi}l9(yF2tJEf9z@>es9p?@Y_TdZ)f+j76UdG9I7gWD@KX@45Kk=$v)bsuBeog1{^Vn#;FrPTVZkSX4*^8d|jt+Qs z%MDp(rvnd+zs4imuL16UCI%3gS$)rFX-X5HtqTPhLDQ5wKelLSky+i9jlb8=DXU*S z^ZTt0eSA4UyG4qG1j&_4z^cKxzQy9@yJs z`1tOr!36~|043Y9&cMJxJOKHnH@|C;BnJp$;fXhd!vd3aWicTFC@Za3xcidPK!al%1R4I z>j#XR(orijVmDOHf9+4aWsu-=w7HLSl2sc`Y>tT#oRgx?^KT#M0QPy31sn))9*Ie0 z8~Rr4@Z_gzG`qJjoR1((okHz?VSK%C!`8%Y1|D%Sys2-io@ts)aENkkbG^D_0(5(@ z8qmQKP>O{0GCwN7{x_+HNUhm-`ROUt5AOE{=hH26fvEUz%9)rd+?>kyXldiWc;3+` zzhfHis#dQSlaMGcZcu9#EJ{gHO*APl>i)pSCZU#k+Fk0Z=In0r=t&SIQk{cWr!c4q zau!?Rmr}fzPn!Ys&B?cI94cDI_r8v^ZFlQaL1;FPh_(O3uEU7HZw`QQ7s>sm+F;le zM!>F~m8C%2TOb+|5+0qhgXgju{vf*KITMq%a6kzxf;d1ah{`W5CH1n(DZU%(8l%zR zCI6ZMg!(YS>_#+lFI!cb#)HX{kYO0V+IDSy2DCz0p0dz%CuT+}dUHEi|hlg?8 zeRArk%3MwjEd)&U^s;w_Nt+SE`uh45lvTunW374Ie=)M$4>~T=Rn`7v@yAM@JSrN= z(N-}o zlX%d0Lj0PSx8(5U>tm+@R^mNUKv#Nn&A&zHHLPy;!wB(Vs;7sQ+?dR2@Rq!>Nr0_Z ztIjqYt5D!1RyI-X^7eUh>$WFlU=ZIwp)`a_ z^iJEmaxl+=({b-T2B`6nv*z$#sf5SQRcuhsZtCB? z41n^*NS5Kv(n6j5Vb1l9<&~x~S%3SHY-JCxuf-bGahsRZjk_LAh*{2AQzKpZwl3Dc zqLpWY;3mX@dh%0JiKth0<{C$VgwyB1_=`d<)GREudVL`0b=|>LG_zW~dNtAJNYZcJ z(WAE_8X~;9Dg*$$dTvlsxw7UmXTylN14?%23IN4?Wht`i5XCI8K;9FTl4A)w1~)3X>#LT5=f3^$7O^bh8{)rl&d zk1<%kc>U(Qwobt>i!y+=!2kwdKJ)bhO1Uh%$#G2l&yOod-3)H)6-Hj4pTZ9ugI5vq z`Wn1Gs+Xj(%KglK;-nj2dicB);M4(|(K0?9V8AmzCbu4lUz))C>)vsu(FO9P+BAL= z`Nw;K&hx?h^^lsoN6QlN=L(dJ%}3x*JNS}oDVHPe$q#Gl3_=O}ByXd<(GcI$_|!M$ zDE65~x%Pe4VwE%a`Z@=!nA^d*p>$`h@=d+;2MjvUB-BIOp&!<;2|kl}z9$CiVzT+e z?f{IX;niQzTcTCKE>DY6*j9d>}~WFpt#zQP8fW6Gma~#z39g z#;}+4+qyNp-l3grIn7+NI`wGo<+i#@u*$bT<|R#HfNjYNh@n*It!(9IAzqBE9|oe1 zS%Tm&7{%M$DYf-v#t(%^DjVFJ54vUZ_I<`Y_iTDAe>PD*@0Y-yhTN4za6VipV+txn ze(ogk=~MRF+L`>eD08(v#?ZkN+w28<`GzPVQT|*y^mjiaQn?HXtR3H4k^^@CFr}tC zLQPdwDdV6Oyd&_1R7gT#86rY?z0A7)-Km?|IZd302)g_ z{o&~n3~jh?66euuo@yIXC?k9t#|;l`CWIk@sTV)HG4M?A3Y6ceeZU`T!q_ynVnS;dhNG`--teVXq2*JSMsDRLTBgx);lnQwjZSW|v7dmrC8g?)pB(DJ7UdB(?u~ z*r30EWDag9+LHJPz!c^qE6meTNzE`wl5mlK%_$N{eYLLQadRuKQa-mAO&Il(ahP`4 zREt^&es6OF}y*XUoUN&F^aXKOsBG99=OLZPtMlmX%(du{>?Ws{FV+%K*IpjO&p)Z91ksoY4||I##`as z){qAA-KC#HoJ<*9 z^M&Fi2Q2=mByOz>l?Y<4&S_Q%KOp~V!AL~F)KX;GWoCrr^V9?#D+)tin_+Ip$HChA z<}%g$qf!Gpd+m(kNA}l1gLO<;pQFa-*2$#je?sH`a7&x1VW>@P46+N!ZAPP|KfyWz zIby2iD3)(|=i*_10m&U}eWCT(>wxb7W`#!({biH{ovxZWb~RD;Bp6yF&OPE;gV!_Bl#x8K*mtJCte?)0@w=QgOeZVN;9swlc; zVP8Gy`I$46l<$w76@FTG)8gn#=!S9dgLKc=)V?j^Wa%Q^K zp(G9mc4p>xiq+T`i&l~G2ZpcdG5(AXu?m=l8r36IAM#f;Y{hvfHPoux;ff2h#g}!L z&)4@II%|IIN)uD*Ju{s9H$E8HmvD+haDeydvbReXNG#stHKgc3EbGaGTmC-ft-+|o zZd0ZH{4zn^4So0Qe^>oCJYVnoj+V-zA;)Y`Fr(Jgv5?$udMClQar(0gL+d+zZD-%7h>T%qSaV0Hn|V3Dt?G~ zE**7Wx8Hy=jm+c&^Z5I`zdS`$)^I;6blpUZQ)H{c$0lK56|)`RXHzYnSEXukaZ9zx zHcFdT8cH2m`tO|~PCv!ikCk*#$jr$RG&M83d;4ezy1u?%@z^B=gHO~Z!&&h0@zuah zPKTPZrN&~cD1&ke(E5k^rCy}tF$QWO=T34Ec(^Z9nCjNfcU z1@rE6sFe=RkOHyHSnNmDW*WRywBAMV6Wfu>(+$?P7n!qcEc|`_dom@iJMH4Mg(jx7 zg{wuHthAr-Oz8$%63jEK&IsgFJ#H_q->hEX?qtK#3pWqLvyORVlRpR_A0Njg^-N{Q z#%O*yyA8AJ5!D0n}Z zzlnFM-9N{w)GYmV^0{uWXgR~M+HO;Nbn@5nYWLh^yC8&5Ts-vs`}f)t;ypS%=;-VIyT%*50vos;tQ(YdfLrePR(;gvmLY^?hRPjMVU&xl*G5ZNgw9jWEaVS zFn?TY#ITn2Cw}d3;J?BVDsNg0Ynu1p@@baq3>F2jSxSQgGU@<^b+*q>Y{@QOqmB{|~h z0!3yqWS#ZmSkcWeM>Gc4uG%7to=@#%-pk*k|H$lfu~Z{c!;>IPLiMzY-a7(ySq#bW zYuDGR3pSNe>{jv)`adI0qvmU!bxA0q6@F?z06n?eS5xxaM$TM#_8OZ+Smp&d;Bg|+ zKfuCGmdRFrGsYfih6FXLDq%{o!Q@XuFF(Wq0leJD$VA6iMdWmW|ClRDY~c}=$v2hD zs9(~I=o6tXPt?y|@X#}`Y|dPuA%0IsfEQerHAwtD&skF5-mJ3`uW#`cVF{g}4Gtiy z3e#I}7z=l(Hq@<$o(LU%vc040_y5&Z2mWst+B@iT4L&s``inVYj7;(;RTdg|r=bEa zZW`qlpID@cyPm+AlSw}2d0^c2z|3wrzQJ9E%LU~*K!0tDiEY?`RV;{F>Z`vC6S=xH zpbo7(F7nrKK>Ky7>J@eSG!=+qWv zTA2EdWs7SB{(DvsFk74*LON`S#{DJ+bCb{kEH&uTQlm^lQ^P{SWwp)S+WhgI`G13J zjed?^MpQ>?7R=-m19p|qo}7eYjn}8l6jwWd{AOl`e@$JiL7LeY)}+Vl|J#R#a7eWG z-Q_1=cfwVN56Xe2@(xW?C!rbKwZfVB2cY?c^s7N@|812pK0}zmv$!NCvDV?% zS;NPc9lBc-^PLn;C5q2T;m~Suf0z!mn%5nK=`a{L`~4177`QY>U>#oP6)_<{MSV|P zaW_y38evjuNYw#znXG)-&6$NX@LjIK*Imx(31yfH{o7tk9v{V_Hc21^C$u39@w*9w z4#!?d44#^5+zofhU2pDir_W68q8JfT1DkX0`rl3H|2(e`P==z4f-JHvO3z7QLo#Z= znzH#^d)-ZZ@X~*HQ)k|%twx0TZ`Wf5)`iAupK7)sA3o?(aqFF8v6xj=pOlPcukuaT ztC_me8Us}|EhCMOq4_G+!AH-2mo+{mKswl0)tV$)fit+p&spKy%?3F*-Re7;&FnG= zv7R*4)>cdIrb@`*Qc%C>srcV0HK<2iMZ~5o$j!}ooR2cO?)&rCVk8*jh`4PDYa7g_ zELF;Xl3y?1#RC$0N=&M-_*j{A8a?TWS1doBM5g@tpwv=YbztxMF!6zwX{BXH;xVD}D@CJK5KEl_k>AYm#bI#nv)I28D zGM)4}&R{I?k$|3YFK5M#dF~}Z6)QkmC!AR@=9wM8xPR!84Lv13hyu;^TM*S|c=_hI zgcE3|_Wmur5bZgsW@H zSEl{PBiX~?{AXJsmrvis!s5-#LcGZS*i_!!OSdFyW|h^+K{UW%#*dV z(Wd_s!c)&X1Y1>=lo-)?PbBL0FNklj)UNoN=h?h!*4Fen6O!3aKR_US{B5Y+*ZN%% z-lh1(wqvSRRjOxW9EC*|9)ggFI@j(gnP5TyLH)f*`LK|HTh<%PHmjfC9f||WbMmHV z-&~C@kGpRV+>*Y2r2xLZQPX(Shaa&^!r}%zvsO7&r|&^mEu<01Y7rxRuP;abi5*G`;Q^l3izUJccd@Mwm;RF7k|Gre0%> z1U-pM5h-%b->b?=z%Mq%Pc=xd8|~u^-%6xS%;4zY{BrqsouUi-6rd_!?$gIE-CsnW zWD{#0CKhSCeFHww(r1O6R(PlCum&TyhitqwVaSr=Pu2GQS7RLmtS+3LB!=4duV#d5{c1~X24K_VJxXXL&R@aldKBO<&KPC7Y3?j3S$o&oQF8y*_MU?4Vct#{Z3 z9Q6x~8}R1NfX`AHOqFO<7!KmGF!NjclvzAw;*nLGyTK|1K0BeTs3E(4Qg2vTWM!ct z^{zO-S|y}{MLMT%ZRuRR(4x|V1=GNE$yEh_gcu%-C4nXdJlAkUm#5I!hvxv%uzd7B zctLcwMRYn5_`c(#!@|VWT*kpGq}x}tx0%AZxw&%c>IAe6ShBa@mJAlB_JU{Y5NB|? zgMJQNJ`V>qqxAq3MGAU|UO2xjZkgpwg_Nepxs*!1eYq7%Df>cw?lTXcl2R*l;75(c zGLJk#s0zzh&#Y*(WK--flu>%UJ^jF%=Vj`yZpN6jcYqX=gwDuub5P0)F7s84g!Crt zO836>kTj#9h?G#878JPS&?CMId#bp`g*?<__x9=A{bmdz{1Tz&$foYD)18JKx6Lh= ziBT*ESTlL~*k9%c!|Mm^4}j%-+rB4Y4HTCc-=E3<1nx#bLE(nrrq5&rh;|Nxxf933 zMTWbZa~8dh2;}T&$`@5c&hLiOIEBD~G{1tYc=MAMWoonO?~0aPB=wH@`1qr>MNIQr(eMf|VKTE?KneRW_T~ zlp!kp8_d@BPqW@?r?G6Jw&*^70b$-U-PvKUOi#5>DxFL0)qKnJeAKAE~7MW=3n6j3s-`u487h+HA>R)p0iBsBvPdbx>|-p8&E`8y2KW%aU7;QRTcslU%aH8U!~2V8 zw@VCUcqOiFGRAL@TSYCpG*Q%wl(^SBrm?!V-aX$)Y7n>$sCP~H!sqjo?p8T=AqJC@ zoAvR8F6_7mcwm4U5C)o@!}RkhMYI{FDgS1-;5u)h?mUUNZ{1U~Y64_M3i{7<+1RX6 zGF;3O!)~|Dh{=nw&4J@>&ap<#70iP9)at= zbg-s6og*tz{ewX;LnP9_iIIezy3h$9Xm8Ub3~4205eJ#1M#IKnuC93{6qlpuK*_l~ z@v!+s5J*V>oVfVbo{#*%yY3F=ej%6X?v?$XmDrq?(1>h|dY~QFF`=6Hhq#~OKb%7_ z2q@3(38*iay<7Rm8&y?{WvWf3W{u7>GWJXc-yK%rdprK`&S!||Wx~Vdx-xf&anScE z8Fy-GO6OgWZp{T_wVxmD{Nfjl|C^gF3}9|bRy!9t5Vyz`FQHX6B|D#<5V@N)Z9%BF zM?j~H_|zP6*hqgTCQHKG>TrugbFbLX`xyj5Y6L9SP`2@4myMNhP#fQPKEG;UtXcnm zrUex!^tV1r@8zQTJQ)lUYr+PNv_fy#$A$mo6gnqKlDN{68NQO1oRgZ z{Crkb9675h$!BPGe$bZ5gX(xZwvQqQ&BkH>=l*XW@BcWSL~m3bkWOI5SPnVd8#k+l z^TpGa`?`g-HhQWNE9jFSf8#@z4k*-f5h4~|&8W#T1`T)x);J$1fj|5qa$z!fZ zZ&Qm&t1>J6D-hO!LRQujW}*Jpjcz9=uerCX>_hJp#**0juF39ls*dhn-~Mj9*myM= z)Sd9(M=3kt3MMB}y%qsV!!Xc3@`&|v<{MX&v^}%aR__P)2l!-L_2KED%)md7_%28? zN_n1>!bFxCsl`xSl1O-Kx?Vo@L5x4d;|km_P5379>yKa;9O74 z13BCsfE__ZlhmIMTHL>Oq~8Bm+y*lMGdj(Y&ZLrBA82?z*izejxHv<1E_I;@Dn6Rg zyX_;y35%)1lHN?>f-S&$9Fh0#)2l$IJ6iM^7*{{&15Z<*sfI;EzRE)E`jFLE5~K9# z+Wj|w&)a+?fse!4((cj)#bB%oT2+~*Hru;1*7b9>fsKU4=Ts#BLK1voNbib3VT2>f z5a}8dp{7kBLD8-p70z*2^J>euiqfaUgx^Uaq z0Q`^nVQy8BxbENK4)ziVjrvVx5eKPK4vT`VmaC7#u)#Y|dWZfz6x|nCeTc@3U*ZxI zfhLcilVZbVc=-OUb`8B?pcX6~sD*1l@Hbvbk_Uiukw-@f@$q6pLIF$F#;x@Lz@}1Z z7@m`#f6K7jm0Vq2%>Z32)A2Ka#-aOkf}>OB)(v?@Ik?GFha3#3w?AXGwIMgi}X!RF^Cc)@FC>0N0{%teP4{*^(<;{&wQAoG~@YJ03#2N zrt9TFcnF%$&Wyf+MWsDho9m3M^{GIl$f@LMpi~{>}*P^ zDKni|t+|z)5(HM;I+Ei2O*#NjBJc>5tnroM{|!qZhXNRG(a_s9vOnn*h&;hy;N|%f zAUZzz@48kMZ*S;KfcoDg%?IOy_!7u>QBT?||KJ{v7k>m$0*;2LFOR|d`)25*z^z}} zl1!ugcO9$|ES#bdj_u!=0+s-oAh1OUpu#u&JDwntEWZKZ^*w>N|L)_>=l__AdwjIn z{y=4qh|^bK-2lr}==Umm{3{q82slJ`^u52){TNV{FtCuy=XL-4u`kJhv_*>GyYL5O zf4t5b0Jt3fKbs^U!Ogwt!G&tG+9KuSX!oOg@xr2_7$nTDzFumXI40_dhtBU-F#oT) z>kMjY?bd*BP*Ln6H7Wv1S3pF%bg4@3AWH8TLJug)Q3O;v5rU|I1Oif}Bw(S0jvTF|+P!S{sJ3cHP=jxW?J2-!UIo zEe|k|?+sgtce9nCQf&O*y^boeAA0IPuH<)e+84imtyv3da+km-^)=<;OOk?dI!wVQQMl6(~sIe2iOa4foriZj(~ckrKK|f zBc!W+z~wFBZD;3A^>;s%5dcj9SfI^Un@SujA*XqG5?+9W8BzD{aqwh@Ta7=BtpSSk zDywn>YL_&^oT{p-w7WR@EBwi5)zL?OK+1vVjG}a==k@%W70LhdWG2BLJBvf0CV$nI zX=rL{jz|`c>~|Ub>#J)p&{2@P{_eS>O)T)L7Aro`^H&_zi~jxjn!R$-6AkBG=J8Uv zgJ+`8e6OTZ`HdtzOQib9+94rTxRUYmL*tEftgFzj-bS5UPO7Go%heSW ze4FaGHsLzLd*#AtpAX&q9rLTaOA8Ay<1olB39$@>NX)Nq4OC~Y4lwWcTScfQM6AsX z#fU0p#D>E9)z4mDioW3qfyKwOha`0^I(AS(2KCD|y}Z1nGXs%mX4=T$;Mb|{?Do#p z`#$4#6a4RMYw_zwpb%hN0CIk(O~P*NNsFHXuh~BoWqdX?I+Uj)!;E{${87f!l4^4O z(lqiw9G=&388FHRY}%9f#lpkGKUknl-go^@^R|+cK@JA0wGOF|4_U2IHUNVa;Bp>964rg^7$gEs?qyXhA6UNrPFn2mqzXdX zG}V9sX-QhD=&_^h@dbbOmoBR{(4s`l^kt2YT54OHXmfP{YBTv27;nwUa24Hs$*~$? z9KgfP%1Zw`%c=NYw6H`0b>t{ewIAB=My6+DH753dJnv$;01L=-C7P{Li!!n^@VX<% zY)}zpl2K=!mogGA7Hu)cJ%+8 z(HBx0*gJPYpf&DYp@AsCKwfppgvU1_^F zXC(kv8ki?p_Ke2gzJQnLP*a-|+||!$1mKbga5kV)`5(!^Q?>97wI`soc7NsoV8yd^ za~77}^_{H*+>h|qMg?T|=0NZUu#C`yEW1+nYu5sP^#9Y#xuCnYpu;5L1_y!BS8vFI ze%{_o%JA&nCtzVibBu9kPefh~VLRYERH-8_F1~$wq!CiSGJ|ao;C0^%5&Vw=U7a(| zX+73=a&qd&y?dnKDL(aJW@hF%56_fJ>Yp%y6{d0H#sh$DJ5e!g(GXSup^;eYDu9NV zi0-dXPVCD7yTt+K6DW|Xx_j>)ro#2h5B;MRy});;w{PD*zmJo9uAft7^r>5MOOAq1 z>4MT5y_5_9E<00VJ7od1^K}BIG@Bb38>a)MW#HbMZoX0gky`ni4W!b6F%ox5Lvzh@ zllT?w;@W=w{30SJR~2H#r{yyMDpO}}xFOE@e2zbG*`lf!@sufE)@%Zi55ijNwO&lw z-C$>rS^=A-BDm`EuG(4Rk^!bX)Z*FFpO@weWW|Ri6#4fy%hBm#K=J=ef>PY)i_cJY zeKafkXK@WKADjB2J>MIvp$BX{WpUl`KZ~pY7Wp_Gmb7=`_TRSv*z-4c()4zB^wIqY zV3Gd^f;6tN11*o=qS0ysu%|PA0xDrt5-YQC9JzRL8WOp1Oez6&Y0j5gI`Dm{#7F;LRTlhG%!UBV?+R4PpwYr zUMJ5)nmM}PpNVvAyx00PW#ZwwH86}g-6%`e%0w-^;F(fT+ShjF9n{QU1SoSHJd;i^ z7c~(e^cu+xM$QhJQ@Wh<#;&R7$nC?ani18QVKq&glges|tp=~QC-D4boA2J;3+_sR zig)EP)BBbN)+W=6?K9MQ0IdOLpzjdl9|-trWrY!fq4Gr{;c^0f;+@G0lC#VdXh3YU()e5_X>LZ5 zlv^G1ZlMbO`*{Se=AA|w+ID^DuWZ?%(r1E-qSJh(cB{_e6F=K2n4PXVHp{*T9!AQ> zD`es%SLmWvN_PHypkY>T3YaiFY!^Ndb!nRk>FWRj%;|GMtv(+3Xa+qg7;8+p>26gh zF{hFqwe*DCnKdBn2j|T#YW6phhxbX8Z^S*jvh4Q35?w_V>Qv|n5h&PPMXi%EkEn0h z#H@Z5ewp>1PwlYKl@7X*agO{ay3W|pR19^Qr@|AtGeCk&kx!itNcEdtRHmmZR(j7&oBFy zcd`q~wcspCZkRqeAk#^fu@`r1yWfp)CSby#}32w7}6#0hsenCkdcPN0@qq* zN%*T|!hQKcA_1MC43Aa#F(yrRa&t?K+W6T%Ia%@~M&$VMT>Du}#Uo9Yaw?e9qpj{0to|&NBTysMAV_o-} z!zHud4zSxfAA`ZLQp?r1PWuKhFv_EB$7}%n!y;lGbAlm3|Dpy#ME|OLs%s=`)S|QRy2@L`=IECQY-up0wu>R5P2<-6r5P?ae@<_vs=-)bb zN;*8?Wtm(uD$YzgA+rq1S0O4UV&LbP`$|5j$zkUqq7e-H&Kl3M8YW2Er8t%$G%^09=JwXa}^>n#K1=?nF)5X5# zUj-^_i8^>$d30xMFtaVmexkK`2fID+-So6f9{C&UhW()(7hU+q0>r&TE0MMwnHe&K z4z6FtX9uqldY>E0Fx)KyIlObpwPh{Ao6SW+tM(v80&(97;IvnjuKP$v39a%u*)^CR zloT! z$dC-r7rFLpP_G>tHHL}WWUlbv+gv(g*XF9~n>K%GavkS3an-u{qMy6a_51oupm2v; zSO&En$3=#J;(WJBdhJZfoeWamXQgK^{*R80``dO)p_#+XXaI2}_2p0Rl=)AXxy zMj8E?FLSW*8Oi1TmR}X6X}R}lwu|LfA{8W8tBpQ&t6d?_P4}{w>|v>Zn(&tnr37M#|NT`7Rv&Awsu);o z@~t`_f{=Y)`GIK3_Zcbpx9U67v9wQb(US(&k!vSybt2d1&M7~X{g&Y!eLICTvmi=v zQK@-G`V_MXk*@XkZJeE@*=jvfT||)oY7Fa6$HE>x;B7N(FiPnk?!PZdq7>#WBpob8_C@Nb-V z&%Qb@flmuXWhn;Ks_bl&%Up2}9Q7pABR{Lm7}#^|6A`Q|hBl+6T31rui=?QzBzeEi zj_7r~KR~!Q*&SJ0gs3_B?13V5Yec=2|18Z}Ru5UfcC$UEfrJYTA)*9TXj;HCEe$Zk z9T=FLe57FeA;CMuzbTg!A2mc0D4sy+q-qXbnz&}`5nV$x74^GReafo+4i2B^X?*>D zCKKDJ&k!j?aCk0rLG5D1$cVZ}8l~Z>c}VX5Rfu zdfRjE!wr$Q-I94FsXfJ~;=YxnSa$^pKdw264KY_wB+bnFpgo9pj~h$j={95!=1Noj zYe{FITJfi&Ae=^YfQY4;m3uWHJ$F>r=125ueYv;U5yt$AiJ?z^nG85VXbl2A>+Gd^ z=Hz!HYScNoiTWw#WL3H)bTUtQJ*Ix+Wq>xSpMbtY6_lH52d*rkOis`H7T}FmJZ~RFAK?K*7dKAJORKeVoRl8;LU!Sx% zDHyRBp3bpbAG+nBF`6r4Mzpta!?_UV*9I(WaxOQ_xUo(ULUjE^w(sVnG{E1)bF zD}Et`jn^!t+ApM;symXA*N4jt6ZEiR^4_w1Mo%*uIL$Ky(_C#3ffuWiLfOXKl3v*R z*MDMnO35nHNbkxzVfuW2CZni31o18Id}X#7j4y9XI6Z0_|+@)Q?3HFlXh~ zW?NVprJU!dwT=NF^xW9QxAIgYa%#u^8$e{#W%e&78)tlDYg9><8!LsJZ@K>bb?rJ4 zl3BkvX_qbQT56=?<-nPdyinjQKdwD%3Ok~jLl3xBh|QP~Y}U_f?SUl2O^Y-J-BVS8SGrm*Lf# z(5C}^h`{OG41QMFHy%C(_Y)#3>kDmcemNe1oCiSatK-PR9I-2>J4vY4U`l!>M#-=3 z+L~oRPfB6K7DDgiWNUUKllzg6n=dD;7%M6y{Nd(j0T&^K62&8}!2pDDS^d&{?Smc? zDw4y|(h>PtMO!n$^IM{WEXEl5Da@=*I{(dI)1W6k?#1PxsSU2(Haa)ze0I<$ zXQxJ&`J`_ RL models currently only support CPU and single GPU training with `distributed_backend=dp`. Full GPU support will be added in later updates. +------------ DQN Models ---------- @@ -86,7 +87,7 @@ Example:: trainer = Trainer() trainer.fit(dqn) -.. autoclass:: pl_bolts.models.rl.dqn_model.DQN +.. autoclass:: pl_bolts.models.rl.DQN :noindex: --------------- @@ -150,7 +151,7 @@ Example:: trainer = Trainer() trainer.fit(ddqn) -.. autoclass:: pl_bolts.models.rl.double_dqn_model.DoubleDQN +.. autoclass:: pl_bolts.models.rl.DoubleDQN :noindex: --------------- @@ -240,7 +241,7 @@ Example:: trainer = Trainer() trainer.fit(dueling_dqn) -.. autoclass:: pl_bolts.models.rl.dueling_dqn_model.DuelingDQN +.. autoclass:: pl_bolts.models.rl.DuelingDQN :noindex: -------------- @@ -326,7 +327,7 @@ Example:: trainer = Trainer() trainer.fit(noisy_dqn) -.. autoclass:: pl_bolts.models.rl.noisy_dqn_model.NoisyDQN +.. autoclass:: pl_bolts.models.rl.NoisyDQN :noindex: -------------- @@ -519,7 +520,7 @@ Example:: trainer = Trainer() trainer.fit(per_dqn) -.. autoclass:: pl_bolts.models.rl.per_dqn_model.PERDQN +.. autoclass:: pl_bolts.models.rl.PERDQN :noindex: @@ -611,7 +612,7 @@ Example:: trainer = Trainer() trainer.fit(reinforce) -.. autoclass:: pl_bolts.models.rl.reinforce_model.Reinforce +.. autoclass:: pl_bolts.models.rl.Reinforce :noindex: -------------- @@ -664,5 +665,102 @@ Example:: trainer = Trainer() trainer.fit(vpg) -.. autoclass:: pl_bolts.models.rl.vanilla_policy_gradient_model.VanillaPolicyGradient +.. autoclass:: pl_bolts.models.rl.VanillaPolicyGradient + :noindex: + +-------------- + +Actor-Critic Models +------------------- +The following models are based on Actor Critic. Actor Critic conbines the approaches of value-based learning (the DQN family) +and the policy-based learning (the PG family) by learning the value function as well as the policy distribution. This approach +updates the policy network according to the policy gradient, and updates the value network to fit the discounted rewards. + +Actor Critic Key Points: + - Actor outputs a distribution of actions for controlling the agent + - Critic outputs a value of current state for policy update suggestion + - The addition of critic allows the model to do n-step training instead of generating an entire trajectory + +Advantage Actor Critic (A2C) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +(Asynchronous) Advantage Actor Critic model introduced in `Asynchronous Methods for Deep Reinforcement Learning `_ +Paper authors: Volodymyr Mnih, Adrià Puigdomènech Badia, Mehdi Mirza, Alex Graves, Timothy P. Lillicrap, Tim Harley, David Silver, Koray Kavukcuoglu + +Original implementation by: `Jason Wang `_ + +Advantage Actor Critic (A2C) is the classical actor critic approach in reinforcement learning. The underlying neural +network has an actor head and a critic head to output action distribution as well as value of current state. Usually the +first few layers are shared by the two heads to prevent learning similar stuff twice. It builds upon the idea of using a +baseline of average reward to reduce variance (in VPG) by using the critic as a baseline which could theoretically have +better performance. + +The algorithm can use an n-step training approach instead of generating an entire trajectory. The algorithm is as follows: + +1. Initialize our network. +2. Rollout n steps and save the transitions (states, actions, rewards, values, dones). +3. Calculate the n-step (discounted) return by bootstrapping the last value. + +.. math:: + + G_{n+1} = V_{n+1}, G_t = r_t + \gamma G_{t+1} \ \forall t \in [0,n] + +4. Calculate actor loss using values as baseline. + +.. math:: + + L_{actor} = - \frac1n \sum_t (G_t - V_t) \log \pi (a_t | s_t) + +5. Calculate critic loss using returns as target. + +.. math:: + L_{critic} = \frac1n \sum_t (V_t - G_t)^2 + +6. Calculate entropy bonus to encourage exploration. + +.. math:: + + H_\pi = - \frac1n \sum_t \pi (a_t | s_t) \log \pi (a_t | s_t) + +7. Calculate total loss as a weighted sum of the three components above. + +.. math:: + + L = L_{actor} + \beta_{critic} L_{critic} - \beta_{entropy} H_\pi + +8. Perform gradient descent to update our network. + +.. note:: + The current implementation only support discrete action space, and has only been tested on the CartPole environment. + +A2C Benefits +~~~~~~~~~~~~~~~ + +- Combines the benefit from value-based learning and policy-based learning + +- Further reduces variance using the critic as a value estimator + +A2C Results +~~~~~~~~~~~~~~~~ + +Hyperparameters: + +- Batch Size: 32 +- Learning Rate: 0.001 +- Entropy Beta: 0.01 +- Critic Beta: 0.5 +- Gamma: 0.99 + +.. image:: _images/rl_benchmark/cartpole_a2c_results.jpg + :width: 300 + :alt: A2C Results + +Example:: + + from pl_bolts.models.rl import AdvantageActorCritic + a2c = AdvantageActorCritic("CartPole-v0") + trainer = Trainer() + trainer.fit(a2c) + +.. autoclass:: pl_bolts.models.rl.AdvantageActorCritic :noindex: diff --git a/pl_bolts/models/rl/__init__.py b/pl_bolts/models/rl/__init__.py index 070ec666be..a84b51dec6 100644 --- a/pl_bolts/models/rl/__init__.py +++ b/pl_bolts/models/rl/__init__.py @@ -1,12 +1,14 @@ -from pl_bolts.models.rl.double_dqn_model import DoubleDQN # noqa: F401 -from pl_bolts.models.rl.dqn_model import DQN # noqa: F401 -from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN # noqa: F401 -from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN # noqa: F401 -from pl_bolts.models.rl.per_dqn_model import PERDQN # noqa: F401 -from pl_bolts.models.rl.reinforce_model import Reinforce # noqa: F401 -from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient # noqa: F401 +from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic +from pl_bolts.models.rl.double_dqn_model import DoubleDQN +from pl_bolts.models.rl.dqn_model import DQN +from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN +from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN +from pl_bolts.models.rl.per_dqn_model import PERDQN +from pl_bolts.models.rl.reinforce_model import Reinforce +from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient __all__ = [ + "AdvantageActorCritic", "DoubleDQN", "DQN", "DuelingDQN", diff --git a/pl_bolts/models/rl/advantage_actor_critic_model.py b/pl_bolts/models/rl/advantage_actor_critic_model.py new file mode 100644 index 0000000000..9f2a835a0d --- /dev/null +++ b/pl_bolts/models/rl/advantage_actor_critic_model.py @@ -0,0 +1,327 @@ +""" +Advantage Actor Critic (A2C) +""" +from argparse import ArgumentParser +from collections import OrderedDict +from typing import Any, Iterator, List, Tuple + +import numpy as np +import torch +from pytorch_lightning import LightningModule, seed_everything, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from torch import optim as optim +from torch import Tensor +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader + +from pl_bolts.datamodules import ExperienceSourceDataset +from pl_bolts.models.rl.common.agents import ActorCriticAgent +from pl_bolts.models.rl.common.networks import ActorCriticMLP +from pl_bolts.utils import _GYM_AVAILABLE +from pl_bolts.utils.warnings import warn_missing_pkg + +if _GYM_AVAILABLE: + import gym +else: # pragma: no cover + warn_missing_pkg("gym") + + +class AdvantageActorCritic(LightningModule): + """ + PyTorch Lightning implementation of `Advantage Actor Critic + `_ + + Paper Authors: Volodymyr Mnih, Adrià Puigdomènech Badia, et al. + + Model implemented by: + + - `Jason Wang `_ + + Example: + >>> from pl_bolts.models.rl import AdvantageActorCritic + ... + >>> model = AdvantageActorCritic("CartPole-v0") + """ + + def __init__( + self, + env: str, + gamma: float = 0.99, + lr: float = 0.001, + batch_size: int = 32, + avg_reward_len: int = 100, + entropy_beta: float = 0.01, + critic_beta: float = 0.5, + epoch_len: int = 1000, + **kwargs: Any, + ) -> None: + """ + Args: + env: gym environment tag + gamma: discount factor + lr: learning rate + batch_size: size of minibatch pulled from the DataLoader + batch_episodes: how many episodes to rollout for each batch of training + avg_reward_len: how many episodes to take into account when calculating the avg reward + entropy_beta: dictates the level of entropy per batch + critic_beta: dictates the level of critic loss per batch + epoch_len: how many batches before pseudo epoch + """ + super().__init__() + + if not _GYM_AVAILABLE: # pragma: no cover + raise ModuleNotFoundError("This Module requires gym environment which is not installed yet.") + + # Hyperparameters + self.save_hyperparameters() + self.batches_per_epoch = batch_size * epoch_len + + # Model components + self.env = gym.make(env) + self.net = ActorCriticMLP(self.env.observation_space.shape, self.env.action_space.n) + self.agent = ActorCriticAgent(self.net) + + # Tracking metrics + self.total_rewards = [0] + self.episode_reward = 0 + self.done_episodes = 0 + self.avg_rewards = 0.0 + self.avg_reward_len = avg_reward_len + self.eps = np.finfo(np.float32).eps.item() + self.batch_states: List = [] + self.batch_actions: List = [] + self.batch_rewards: List = [] + self.batch_masks: List = [] + + self.state = self.env.reset() + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """ + Passes in a state x through the network and gets the log prob of each action + and the value for the state as an output + + Args: + x: environment state + + Returns: + action log probabilities, values + """ + if not isinstance(x, list): + x = [x] + + if not isinstance(x, Tensor): + x = torch.tensor(x, device=self.device) + + logprobs, values = self.net(x) + return logprobs, values + + def train_batch(self) -> Iterator[Tuple[np.ndarray, int, Tensor]]: + """ + Contains the logic for generating a new batch of data to be passed to the DataLoader + + Returns: + yields a tuple of Lists containing tensors for + states, actions, and returns of the batch. + + Note: + This is what's taken by the dataloader: + states: a list of numpy array + actions: a list of list of int + returns: a torch tensor + """ + while True: + for _ in range(self.hparams.batch_size): + action = self.agent(self.state, self.device)[0] + + next_state, reward, done, _ = self.env.step(action) + + self.batch_rewards.append(reward) + self.batch_actions.append(action) + self.batch_states.append(self.state) + self.batch_masks.append(done) + self.state = next_state + self.episode_reward += reward + + if done: + self.done_episodes += 1 + self.state = self.env.reset() + self.total_rewards.append(self.episode_reward) + self.episode_reward = 0 + self.avg_rewards = float(np.mean(self.total_rewards[-self.avg_reward_len:])) + + _, last_value = self.forward(self.state) + + returns = self.compute_returns(self.batch_rewards, self.batch_masks, last_value) + for idx in range(self.hparams.batch_size): + yield self.batch_states[idx], self.batch_actions[idx], returns[idx] + + self.batch_states = [] + self.batch_actions = [] + self.batch_rewards = [] + self.batch_masks = [] + + def compute_returns( + self, + rewards: List[float], + dones: List[bool], + last_value: Tensor, + ) -> Tensor: + """ + Calculate the discounted rewards of the batched rewards + + Args: + rewards: list of rewards + dones: list of done masks + last_value: the predicted value for the last state (for bootstrap) + + Returns: + tensor of discounted rewards + """ + g = last_value + returns = [] + + for r, d in zip(rewards[::-1], dones[::-1]): + g = r + self.hparams.gamma * g * (1 - d) + returns.append(g) + + # reverse list and stop the gradients + returns = torch.tensor(returns[::-1]) + + return returns + + def loss( + self, + states: Tensor, + actions: Tensor, + returns: Tensor, + ) -> Tensor: + """ + Calculates the loss for A2C which is a weighted sum of + actor loss (MSE), critic loss (PG), and entropy (for exploration) + + Args: + states: tensor of shape (batch_size, state dimension) + actions: tensor of shape (batch_size, ) + returns: tensor of shape (batch_size, ) + """ + + logprobs, values = self.net(states) + + # calculates (normalized) advantage + with torch.no_grad(): + # critic is trained with normalized returns, so we need to scale the values here + advs = returns - values * returns.std() + returns.mean() + # normalize advantages to train actor + advs = (advs - advs.mean()) / (advs.std() + self.eps) + # normalize returns to train critic + targets = (returns - returns.mean()) / (returns.std() + self.eps) + + # entropy loss + entropy = -logprobs.exp() * logprobs + entropy = self.hparams.entropy_beta * entropy.sum(1).mean() + + # actor loss + logprobs = logprobs[range(self.hparams.batch_size), actions] + actor_loss = -(logprobs * advs).mean() + + # critic loss + critic_loss = self.hparams.critic_beta * torch.square(targets - values).mean() + + # total loss (weighted sum) + total_loss = actor_loss + critic_loss - entropy + return total_loss + + def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> OrderedDict: + """ + Perform one actor-critic update using a batch of data + + Args: + batch: a batch of (states, actions, returns) + """ + states, actions, returns = batch + loss = self.loss(states, actions, returns) + + log = { + "episodes": self.done_episodes, + "reward": self.total_rewards[-1], + "avg_reward": self.avg_rewards, + } + return OrderedDict({ + "loss": loss, + "avg_reward": self.avg_rewards, + "log": log, + "progress_bar": log, + }) + + def configure_optimizers(self) -> List[Optimizer]: + """Initialize Adam optimizer""" + optimizer = optim.Adam(self.net.parameters(), lr=self.hparams.lr) + return [optimizer] + + def _dataloader(self) -> DataLoader: + """Initialize the Replay Buffer dataset used for retrieving experiences""" + dataset = ExperienceSourceDataset(self.train_batch) + dataloader = DataLoader(dataset=dataset, batch_size=self.hparams.batch_size) + return dataloader + + def train_dataloader(self) -> DataLoader: + """Get train loader""" + return self._dataloader() + + def get_device(self, batch) -> str: + """Retrieve device currently being used by minibatch""" + return batch[0][0][0].device.index if self.on_gpu else "cpu" + + @staticmethod + def add_model_specific_args(arg_parser: ArgumentParser) -> ArgumentParser: + """ + Adds arguments for A2C model + + Args: + arg_parser: the current argument parser to add to + + Returns: + arg_parser with model specific cargs added + """ + + arg_parser.add_argument("--entropy_beta", type=float, default=0.01, help="entropy coefficient") + arg_parser.add_argument("--critic_beta", type=float, default=0.5, help="critic loss coefficient") + arg_parser.add_argument("--batches_per_epoch", type=int, default=10000, help="number of batches in an epoch") + arg_parser.add_argument("--batch_size", type=int, default=32, help="size of the batches") + arg_parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") + arg_parser.add_argument("--env", type=str, required=True, help="gym environment tag") + arg_parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") + arg_parser.add_argument("--seed", type=int, default=123, help="seed for training run") + + arg_parser.add_argument( + "--avg_reward_len", + type=int, + default=100, + help="how many episodes to include in avg reward", + ) + + return arg_parser + + +def cli_main() -> None: + parser = ArgumentParser(add_help=False) + + # trainer args + parser = Trainer.add_argparse_args(parser) + + # model args + parser = AdvantageActorCritic.add_model_specific_args(parser) + args = parser.parse_args() + + model = AdvantageActorCritic(**args.__dict__) + + # save checkpoints based on avg_reward + checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="avg_reward", mode="max", period=1, verbose=True) + + seed_everything(123) + trainer = Trainer.from_argparse_args(args, deterministic=True, checkpoint_callback=checkpoint_callback) + trainer.fit(model) + + +if __name__ == "__main__": + cli_main() diff --git a/pl_bolts/models/rl/common/agents.py b/pl_bolts/models/rl/common/agents.py index e692c8becb..057108b702 100644 --- a/pl_bolts/models/rl/common/agents.py +++ b/pl_bolts/models/rl/common/agents.py @@ -137,3 +137,33 @@ def __call__(self, states: Tensor, device: str) -> List[int]: actions = [np.random.choice(len(prob), p=prob) for prob in prob_np] return actions + + +class ActorCriticAgent(Agent): + """Actor-Critic based agent that returns an action based on the networks policy""" + + def __call__(self, states: Tensor, device: str) -> List[int]: + """ + Takes in the current state and returns the action based on the agents policy + + Args: + states: current state of the environment + device: the device used for the current batch + + Returns: + action defined by policy + """ + if not isinstance(states, list): + states = [states] + + if not isinstance(states, Tensor): + states = torch.tensor(states, device=device) + + logprobs, _ = self.net(states) + probabilities = logprobs.exp().squeeze(dim=-1) + prob_np = probabilities.data.cpu().numpy() + + # take the numpy values and randomly select action based on prob distribution + actions = [np.random.choice(len(prob), p=prob) for prob in prob_np] + + return actions diff --git a/pl_bolts/models/rl/common/networks.py b/pl_bolts/models/rl/common/networks.py index 48a9380f9c..3a88931398 100644 --- a/pl_bolts/models/rl/common/networks.py +++ b/pl_bolts/models/rl/common/networks.py @@ -93,6 +93,40 @@ def forward(self, input_x): return self.net(input_x.float()) +class ActorCriticMLP(nn.Module): + """ + MLP network with heads for actor and critic + """ + + def __init__(self, input_shape: Tuple[int], n_actions: int, hidden_size: int = 128): + """ + Args: + input_shape: observation shape of the environment + n_actions: number of discrete actions available in the environment + hidden_size: size of hidden layers + """ + super().__init__() + + self.fc1 = nn.Linear(input_shape[0], hidden_size) + self.actor_head = nn.Linear(hidden_size, n_actions) + self.critic_head = nn.Linear(hidden_size, 1) + + def forward(self, x) -> Tuple[Tensor, Tensor]: + """ + Forward pass through network. Calculates the action logits and the value + + Args: + x: input to network + + Returns: + action log probs (logits), value + """ + x = F.relu(self.fc1(x.float())) + a = F.log_softmax(self.actor_head(x), dim=-1) + c = self.critic_head(x) + return a, c + + class DuelingMLP(nn.Module): """ MLP network with duel heads for val and advantage diff --git a/tests/models/rl/integration/test_actor_critic_models.py b/tests/models/rl/integration/test_actor_critic_models.py new file mode 100644 index 0000000000..dd21b4505a --- /dev/null +++ b/tests/models/rl/integration/test_actor_critic_models.py @@ -0,0 +1,27 @@ +import argparse + +from pytorch_lightning import Trainer + +from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic + + +def test_a2c(): + """Smoke test that the A2C model runs""" + + parent_parser = argparse.ArgumentParser(add_help=False) + parent_parser = AdvantageActorCritic.add_model_specific_args(parent_parser) + args_list = [ + "--env", + "CartPole-v0", + ] + hparams = parent_parser.parse_args(args_list) + + trainer = Trainer( + gpus=0, + max_steps=100, + max_epochs=100, # Set this as the same as max steps to ensure that it doesn't stop early + val_check_interval=1, # This just needs 'some' value, does not effect training right now + fast_dev_run=True + ) + model = AdvantageActorCritic(hparams.env) + trainer.fit(model) diff --git a/tests/models/rl/test_scripts.py b/tests/models/rl/test_scripts.py index ee30206718..829049fc19 100644 --- a/tests/models/rl/test_scripts.py +++ b/tests/models/rl/test_scripts.py @@ -126,3 +126,18 @@ def test_cli_run_rl_vanilla_policy_gradient(cli_args): cli_args = cli_args.strip().split(' ') if cli_args else [] with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() + + +@pytest.mark.parametrize('cli_args', [ + ' --env CartPole-v0' + ' --max_steps 10' + ' --fast_dev_run 1' + ' --batch_size 10', +]) +def test_cli_run_rl_advantage_actor_critic(cli_args): + """Test running CLI for an example with default params.""" + from pl_bolts.models.rl.advantage_actor_critic_model import cli_main + + cli_args = cli_args.strip().split(' ') if cli_args else [] + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + cli_main() diff --git a/tests/models/rl/unit/test_a2c.py b/tests/models/rl/unit/test_a2c.py new file mode 100644 index 0000000000..79805f1b62 --- /dev/null +++ b/tests/models/rl/unit/test_a2c.py @@ -0,0 +1,55 @@ +import argparse + +import torch +from torch import Tensor + +from pl_bolts.models.rl.advantage_actor_critic_model import AdvantageActorCritic + + +def test_a2c_loss(): + """Test the reinforce loss function""" + parent_parser = argparse.ArgumentParser(add_help=False) + parent_parser = AdvantageActorCritic.add_model_specific_args(parent_parser) + args_list = [ + "--env", + "CartPole-v0", + "--batch_size", + "32", + ] + hparams = parent_parser.parse_args(args_list) + model = AdvantageActorCritic(**vars(hparams)) + + batch_states = torch.rand(32, 4) + batch_actions = torch.rand(32).long() + batch_qvals = torch.rand(32) + + loss = model.loss(batch_states, batch_actions, batch_qvals) + + assert isinstance(loss, Tensor) + + +def test_a2c_train_batch(): + """Tests that a single batch generates correctly""" + parent_parser = argparse.ArgumentParser(add_help=False) + parent_parser = AdvantageActorCritic.add_model_specific_args(parent_parser) + args_list = [ + "--env", + "CartPole-v0", + "--batch_size", + "32", + ] + hparams = parent_parser.parse_args(args_list) + model = AdvantageActorCritic(**vars(hparams)) + + model.n_steps = 4 + model.hparams.batch_size = 1 + xp_dataloader = model.train_dataloader() + + batch = next(iter(xp_dataloader)) + + assert len(batch) == 3 + assert len(batch[0]) == model.hparams.batch_size + assert isinstance(batch, list) + assert isinstance(batch[0], Tensor) + assert isinstance(batch[1], Tensor) + assert isinstance(batch[2], Tensor) diff --git a/tests/models/rl/unit/test_agents.py b/tests/models/rl/unit/test_agents.py index 2f25e54de4..ca37d864dd 100644 --- a/tests/models/rl/unit/test_agents.py +++ b/tests/models/rl/unit/test_agents.py @@ -7,7 +7,7 @@ import torch from torch import Tensor -from pl_bolts.models.rl.common.agents import Agent, PolicyAgent, ValueAgent +from pl_bolts.models.rl.common.agents import ActorCriticAgent, Agent, PolicyAgent, ValueAgent class TestAgents(TestCase): @@ -61,3 +61,15 @@ def test_policy_agent(self): action = policy_agent(self.states, self.device) self.assertIsInstance(action, list) self.assertEqual(action[0], 1) + + +def test_a2c_agent(): + env = gym.make("CartPole-v0") + logprobs = torch.nn.functional.log_softmax(Tensor([[0.0, 100.0]])) + net = Mock(return_value=(logprobs, Tensor([[1]]))) + states = [env.reset()] + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + a2c_agent = ActorCriticAgent(net) + action = a2c_agent(states, device) + assert isinstance(action, list) + assert action[0] == 1