From bcd67c7bcd73d9946473bf9a333d14a729c5e2a9 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 20 Sep 2023 16:03:21 +0200 Subject: [PATCH 01/13] fix tests --- .../components/audio/test_whisper_local.py | 2 +- .../components/audio/test_whisper_remote.py | 2 +- test/preview/test_files/audio/answer.wav | Bin 0 -> 29228 bytes 3 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 test/preview/test_files/audio/answer.wav diff --git a/test/preview/components/audio/test_whisper_local.py b/test/preview/components/audio/test_whisper_local.py index 1acb47878a..745ece37c5 100644 --- a/test/preview/components/audio/test_whisper_local.py +++ b/test/preview/components/audio/test_whisper_local.py @@ -157,7 +157,7 @@ def test_transcribe_stream(self): assert results == [expected] @pytest.mark.integration - def test_whisper_local_transcriber(preview_samples_path): + def test_whisper_local_transcriber(self, preview_samples_path): comp = LocalWhisperTranscriber(model_name_or_path="medium") comp.warm_up() output = comp.run( diff --git a/test/preview/components/audio/test_whisper_remote.py b/test/preview/components/audio/test_whisper_remote.py index 26851845a2..ad71c9a5c7 100644 --- a/test/preview/components/audio/test_whisper_remote.py +++ b/test/preview/components/audio/test_whisper_remote.py @@ -217,7 +217,7 @@ def test_custom_api_base(self, mock_request, preview_samples_path): reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", ) @pytest.mark.integration - def test_whisper_remote_transcriber(preview_samples_path): + def test_whisper_remote_transcriber(self, preview_samples_path): comp = RemoteWhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY")) output = comp.run( diff --git a/test/preview/test_files/audio/answer.wav b/test/preview/test_files/audio/answer.wav new file mode 100644 index 0000000000000000000000000000000000000000..874eab01fede4e6b7fe24270d70ff85a87f5caa3 GIT binary patch literal 29228 zcmX6_1$-38+n$--J&PwKKyY{0qQ%|a3&q{twP}F zg;bOX_20w9-%MrUi{lvr|5Ch{@xN>S|MLvuBoQPWzl$OrAP9c?~zNl}wm#_H#L4Cr%-|-0Gy`*yf?+Tv(T_F}%N<-3;SlpW% zR}3c}tc>9fL%1XIzu!6jcWo#BnjVh~BsHF6;9dWFej5B7Mt-VG>IB{;AL&WDldR;F z`b%w7j5H>l$?v%GGxaZaB?hmKCjoT_&pn{7<2R{EZru5OJa>UQs)8yl$%X5lQ|r|l zby7W7@6~Jd3Oi(BJxA0?HAOv8-N*}apR^$BRXNoSYq>^>&}jM>d7(zDq3R!{;~lG$ z-$)@6NldIl$NkhNgUB#agQUY9)F9nRQ>^&Bny7lK)i|>vcz$}^#cb6=wO1Qd1Zha> zV9ndq1ho=-Rvv%PhWiSt47gGk(wt-_Pt{&L^AzqZBYyg#rl~S2o2sjJW2fKXFD+CJ z)l>~sE7TdR_oAAiN~$odU_YtADr!Zv;w*`j$0{n)!mKkpO2?4WsJE{dFTx;!rL zOTWyg4ywVV6g^I#(W~?%Jx|}zaOPrrsUQW(DD_j$lFemdnN_BjsbmS+M_!S6)l~ID zC6Z2LHF-!Bq4XVDL;fUXai%x1nzp#h*K(yCB74et^1Upq{!sm}4?Wa)oPdUXt4i{c zAL^M(AnnO^Tw?|-?;mwuEhh76Jyx6br@hE4m6iM=ljtD&0~RnwWmWI+-Y4WG`AQy? z73FjBOKg=xVb8@#4pNCMC((2=y-dH+)T|JTX8UL!`he^wTgh~s$x@Y8U6beKKDkhi zlD*|(`AL>ggVj3h@n_YXoFEqYOs3%so~Tt=UoX`TUt`rG?DHnz+B@TyB=Kn%RHw z@#2+Ss&1*bYCmi#n)D+xNgmRIwe%cI6_q$TVS7ZhbB!KpWOU^Y9}qFqJ&JAQe}txa z6O->Im++nPSMmMsec)Rg+-fzH8)+>2s5-0bY_s#QXHZx|SS@#Z;~E(tJUk z=@3pkowdOd_8@UhCaK?OGp(>=wQGwzw|k7UkrA%-p*iSq+Mm86eN=r>(C%Zt4E<#e zwFcM~_%psrRDxZuSM$|9xk+^8-1gaVd>CKE+wi^iR7)|oisx3$cSALX(P&pU2?o73H+it%gkwIvzOVetURIff%?Ju z<^V3_A~KVnWMlMBjz`XAt}(78M=?EtU8e)tzwCGRg1nM5`9o`z^~@@6AF+?{zeIO& zi|@b+_wZ!xxc&fygOFJd*W>|`&jd+da;NqPfS*_`6Nm;5fp^#AT(;9ck)?Hdue zXGWfv3(LQaR)TUM|`=5BMb<+tD4H>~aEQnRq#Ph3_bX?IpY zU+Ea+I_v7{++_6EHj+E=nofwB32KPEB1Z7Gb~C#H-zHAWwW=2)a9cG)9v9tsD~pA4 z1Pc0#`w#f11||iYg({d?EZdsN=gA(VIvcC?(b{XtOk;~^OL9QvS9AC`;dVAp)iPtl z^j%|Jp6qnuw?5jl>^NK5ujOreN-LsA>I00{&Ma=j^W2rqxla>v z1uwzth?z1YqJ3s@!v4*EY!4AxkqiIvmDM!u$%Lz9!ceLeW4>d)&bQ&XxZUHeftX{-0Al~>E^@;cm_LtEk88s0T3 zSCl{epsNIn=Q4OI@FZBm+RjUe-}znZH>)9kp=xM3j4ygG-DAviEOf@ZqMTi{@~SqU zz*8ePEKzM_E)mYt*uCvT{ENIw9?_-DMRUkc_AG0HHQRIsd;2b?^hycy`n<(_gMGvN zKLdZ5_w5bhfqD;XsEQN&oh8t(Dh5?SGU|govadCspLBLeGcJ4mY&+7njw<3Q5>`06 zc$x$03MFI4E$u-}cd+GRRgTkF~t{x~8r z;+JQIVS2uD5^+2kD{B+FJ^FG1kYQbw3_`|NZm%Fy9-EGGly%g$8GkybyJooax!M>z$U8fy{Tp8) zGf^j7NMqDM{)e5OpB9bO5_%ieR86A9TOP@KS^Yya@Wk6YSoU81?kXPHI56hp;1ev5bJ3+-!G0dhAceRfu0W}fZo?#F(PtP!~& zesW@;#5JkeuPTQ}o?UcePOHDQ*25_MmnX@w=P`&8^>@|QN>Z+u(xHT(Am$*%+TMN2J%LA5CMReR(zA}UAz z6*)GdockavV`m6_@lEol4Q>lD(+c@RRjj!}BC6-0d6>s|?R?~>VK+USTv_y1>M&o( zD~r!UO0Rq%7mBZT89tQn5!uOIs#qB|i8NC&G7s9Re*!hVj^zGHpMNe*^8cFXedufN zALB0-%w@e525q42(yD5HmY3~7)%sDLlbPf!p4rM7%4p?u)yR+~-{t(#ndisuh|t64 zMK?}sW=O~=#QGlRB}~vCwrc{M>$5& z!D5Q2A&)5^s`JNkHZN|c#g5*jo=(k#5YksqS7W1L%p|7-sES>hsaC>ft z_lasBdnaaG_&}|${Uwkpureq^2h0}c@X(cz2u-v0+qZZ{HJ{ehlMRn^qdVxy=c(lg zr}N}yzJXss`TGV6~9#bHuYQoRS)ZIpzD2>Kkh{z~{%qbU zDHVKs{d%Cd-|`;wKJ?}coHOt9#%ekVkq6`xy8gm+2|YtgliT8i=qtL4wS0ipm>!OH zW;&g_XAW=r_|(UvIz+#XKb2TH56k)BB5Ua03*hA6^w@@o9E-Nf#Ko-5p1XpKOxvI_P3M^ai# zF}Apucr@1sEj#Y13GK($YR$CKY#}Koqj?3Zr1{Z2jeah>_+;m@^V;p~b9PC)u~{-W z$Cu6*?wjVl=e^)HeVGGAgP(&3Lw&5DyfRtGMzLJ16P2VljYcJxMJ<#igl26sCusfS z&*pfQw_7%mdQ)^{cqpQ0{F?NevMfq}AhwEgVkrNQ)1MoEEt0hU=ZL&Q{m#ae34u};|^_c*9B%zj}%erni1szo~Ua%sI-d+T5CD@3k-6f7?CnMLdrvAcA~~ zUDC=JV!^(_#86&ydgy*&pg-9g^zQI)4-U4xe6IW{_n^Y7r931r{YjIEt$Y%ZO)NLB znYWD!aYeJ{%6TiptTY*8cSoH_RW8lyj5#xnP17UxU#Fih3H+U0@z=ZL9N!Cld;aC& z*Oy6OQ+5SLgjk?Y;5XBdJG4g5VxFwwb3CP-YqivDJ{!rp(i^Inq~a?dF6WWMWSE?9 zMTL68o))Ux#sYV&C#@qyuA$|eYp zpJlBH-3Vk4v<@5yt_r;gT?y_Eqz+6DED!cFPuNthm;K~taa@#>PvmaWm*%DGNPhKH zW)NGD7XB?QmPU@V|g&lv$VCZ(hAIO?g_verA z`M$UQnUp*%C9OY???g&BU#DOm{*F}BcQ|8R9UP1GM6He9-l%I#(<`!_qym`*!W~GqyUO=#6L@bxZb8Ey-G1jNL#+C@CNEhqjw<;fHv7e%yLtRMg6`n(;K`sXbS8Ac%wzv!Kes#bPGW}arP>4av(#f$ zHuqGVoMd-mB_dh6Hd)T3YZ6~0CV%AN=#*6768C00l;Lga{826=x4detvX)zJ|H~xt zJ?-~*KaMB${x!?HDka;mslPsXONVyx+2p$>^;YO6v$6f!7~`U&w4<_KU5nNtfhT2l zlrf63Qj+jSc6C8%4?UZsj{X<9FCOsQ!cwuUiMEM#C-HI`KaJY85C5C*uq#-F%r0gf zE3Li4&TcO-2Zr*6N}J=Yvesd$uJ=Gt;p?06*57z5*_%4V1jo@ z@|Yx^G&NZzKl#-!Ir(Svr2D^~`m32~c`6`z1=KH@tWxMRZK~1K7^vsh|I+(7%D6VT zK0DhRHQ6rYx@z=}w%n-T7^kPA6Oi{#@V4@-vdLWXNj4Pk?6=k)tD)tvmYHs|ty$Fa z*hB2{c3SI8C@$17wA&0>JJ9*|v>VwG_CHpb)y~{uwzg^l6>Eso8=|EAS8kS7TVirW#jm#4DHYQKP_4IKW7bZ4KwJ)rKW2?59IoW4X(%c)E z}-dvaA}T4#ER30_nV{ zE~_xoTji3gX?fSD$d57J$R}a%!Y+nyirgAKEbe^jL22!TPcd)9V%#I$TiwN7^>q&! zDyrKDLze=_{3(9!YnYNZ`JW{E>qp8cZ#Qp*uYo^4ct6zM8o-Z9sX76*zCaeSwfdiq zLr&_-@9OCCxQ06Yj%>~(=Wb^&hti|LDNSkTuTS7;=QV2Gazbe%-&p zKP?akUppAM?oZ|K?@tpvZ5n)(NF$5LA3!V8AXCQ6qUsM4M>Em(G$;KIjK&Mkd!W() z2VAZyqEAk5lyT4Ww0B*0%yblYUGVe^9~YSteL1Fa%&^Fi=ZD+xKIX~p+34JF+}EO6 zQ$)kB;+semlk7!iV(3z^Xs}{nhTrYq=5G?%97rGB9n2OQ7wR9njd;1AXGD#$1Xz6< z;P*r5dX`yVW8`!kGN$OiX;)cl?YeeWZ(#gsT+(N1YZznw=tpu(ZImm;ao(LzMV&g0 zZ?G#{>CL90R-ue$d214~>QFw9&*pFVbWvWukz9V3m*o;UTK)+X{jt0WBt5hIF7}Jp zqK@={vxo-o@BqkkT9OFNqn2LLSYb>wG8^@c1HeiiIybqJT+3a%ong-Hj?Rutj_%I8 zPR+I2DIIkX3ma?y($wS~kj3TVrTANX;VInz909RRX71_vsP>u zJI9h(te)3M=a}PI>4+a(!|IT(4b6UAbH*ofDlkfiEY2p-rZy&N&o7v5=W-sh`Y2K87unXJkt=X2#zG=7To3PS{Tyh^z z;qAnC(MVQ?UAz;8=DyOy}RyW(6y=OgEOXAS2%hsV*=SgbGARsd6YO~=#eB#+W$eleaO zvR_#lt=bIIYTD7ZfZ?CDlju1wWEW-IN>U zUYQv^#McT-o}@JJ*XmYh-|lJvMvA^*l?st2q!pOMhv0q+q66s(Zs9c9N~3^Xm!l=f zLRDG~0+%(P6r(L@W?Gp}rIBnY3$okVIlZs(!uVu7H|81Hfb4zHiP6NkZnSZna%^`D zcT{mC8yk&|Mh~N}QNTE_XVLosqhCThk{ESX+~GQJY~Qv<0K?p4?Z(Q^Sjkpq+hwP< zN5Gdh@C4BkIo1#}xr;ADcI#rVv^746-{X-Y1V5d}Tk(p#A@}nS;-lQAGJ|oujWssH z~^W1gp)={&lk{e~)V`ojDWccVbLw6#Njn5?E2FlP zYMX0m~Is8_5`&5q;7ux|uvu zBM?Q_B~^6z%10pf8=}A;i{IV79vOBdAH9iwrVOP8W}`8<}~3#G>)y zuE;7&qw?A%Bh(V0rAnm&v)q`RBI9Xq_Jy59HIk0D1wZ~w zh>NbFU#Lls(pcDOPB8z145NS2U!)=#tKP^OGDen?8)Q79>q|uY&T^?tM(wjx6~$?n zr>W=#ylN&`bE1l?@?e4A$%Zl{W&nxkF18@|Y?c+U^4xg#2VjE@;SYb%Y_tf?Pj8Wp zq!&3vVqpER!PY;=IbK(J$rz{uey4xYEA$S1f)2Ge&4?bS2Y&MgzgdNA+yoz9kpy95 zy;Wo6Azc|@pErR&?}oL0LDad26-FUuRv?eaPH^6L7>7Sz1}EB^%puF5K3ELB#8~iZ zbHzF?kT=u8uNChsw!tFOBgPFw1^q=$ATP;tvKKbJkE|u*NNM6j99{+ndjS*=yEXH`Jw6f4kd(*GcN-Tg%9@h4*OCZ zR@?;~Yd2hD7rBW$6QnR+JBd!Dv*}oB;Y{YKhjIqauRWscI=NNu#7Rt-gJlmHklTUp zm!~&rQ??46<`g=MT!3<8Hn8M^@WUiEKpj9_iUBY1o2-g!41y)h#uX03QX3;SXNIlM zA)iPRiJ&J?cdQ~ONoJf+1oaayNeBOYgi|uX+HX>ipa#efk2{9gH4}S311$0joLDy4 zWievP1K8P9M1tbDc21mWIK58#LgnB`j;N&e z$QJUBC?;+K1DVP%^UUB3Pr#E)$q}g4zRC{jA#@{WU_}Ag(suYjBX~kJ?Bsb^p9>Yl zN;y(?$Fo<+RdNz6p@|$TU(1fL^F^c#yzxFFP!0MUc#UavCS8b(GJ}qzWpFnkvIX7k z1!zweL2J?iNXsd(_)}2x9#xZ}VyKO5JOf&j8@PvmV9lSQ$w`kjyRkFZV8QcHt>(u) z5&9N;Rt?v_gZg;^>}4%H^fh@5PrV8TEQ8txADoMt=N~arSmLor_djpxA|J{7sFK}k zIdmf($Oe+07SkHhN@Oba;MyT_L#|T;>2WZ+$Dy<-L`SoMX6(n3pxYOR=7jLuVG^tH4^M_PzHlxt~va$GK=2o`! zkg|xQBZymuL*tV{0!m`auv^plhMHN@~hUeFQSYN3!8|_eixVGr`Sz2aoGMO zQ$!!0S+$~nqZjIj(;O{c%ih4?$H}y$s_}#@l+{=Wv8{^whZJOI7r@f;y_#^dFDGILmk(+An)Z8HJ)A<(NHUuRm*8ntEN_1Pwy|H<)itnydsNp zh??Xd!}4i^ptrEd@9dggkId1^vl-?lZK0N$Poovoc8&0Fe6;pO&t>J6+nu?zj{fy( zb+`+t2ytkxLUs{e!EqftQWjNNEhCe385*bjz&EO?h2lLsqII_0X<7A))?BF>|7d%l z(x|N0*YELyGOIC}HsmR!9{ns=(z&uS+oQ$#KhQ0%lfK1LcdZKKBTw}!qMp8vXCwoS zA~c&dgv>QE$!FRfHJ5h}ixHPoCWM_Pl>!lt<#rp|j5*CYEQVE(?}Egb&Ta-;vYAxc z!$>_bOItxQScUYaq@}oQ)ncjju7SP!ZezCZk#i5}7n-UksJlGQ@XPjAoO2QxWKA@_ zl0Sl)bF_ZTtcoZxhW*9X+MQ@FsG2&dYFNW@{>}BAj5LWh-8sa|9b4T0CUtXnceh9` z>%Q--<)^ZTYY@3;UNCmaFfq>kx6ehVL_Pof8(k5z&U?j5(I+5_CWPM8(_v-`G2;@3 z^7r3!Ms08}_ogRn^^AHPDI8kmEaYet@N*V%$$t5(ihG(-#otP%GqwSz2t(I97=6io z?v-QpI-#r9IcG`!mcIf+T2a&yf9g}sG^Bx}i?u@4*ZSG_wMovEDX~Ty_q|XtWabg{ zg-k`8smAm>c_}98No0jd!nU#&{(Z(=V`{KCo$c&zx2L1!UHIvAM7xFbB#9DEP0+?c z2YKqwXLc8xc)%4P^L=?EK8tPs;}J=6W1yMiE14)$sRKp}5pQqSi#xWOjqUXAKV@HW zo6aE3RY>b$U89R!9RkZ_Y{WgQL&}KQ^{V|Z+mqT!eH#16dDDI=W7I-J27l-)nJ2JS z=WLLfLuvX$tEU=Bmdj&SPe)qOA#f;Sj+G@a$dzQo`10_P(X;&lYoI4+o)?De4gYGd zcMK70&~fdPxwSp2DSJnIT33yQ>V=gvtcmy(=pOb#E8#!xE+Bf^+l_nbExhhD|7g{T z*b@|idy(hNdEUA)d#!uHa7TLiKx<-G6O)|DJQcX z&@N+}F(x^Kk=8xPIvuo)_r|XPW$oxJu>;x4E6&mRsD4VZi()vcr2e!3V(fM8H?qMz z;i#uo^fhB1t+GmM=I3|a&18i@GiOz0hE^KgDGj>mQTIOXx856tv~R$%-iJda=U)`< z;T?GkeX{t*mN>$KujoT(kHA>oDQcrx-(M-Vm3=hibo+Q}JJQ*YZzG$z%|c3ne@^_`Ha=EOWnMF&JvL7GcFpa*eY~aU3Ea~TZL{Bd*7HbnynCy4R9bX~ebJrGBB28jmt})Mb$4ggP9zy6 z<%dvOM>Xv!ua4TQvZqq06W`++E-sqw^l;5Gm$09xKXU1d`CWUTd!FqH&33<}BYgL@ z0qzt2c=DNjrito`HAYKs%(1SBB-dMW4a>^z*iH33=u;*bwPp5T4(BxeX6UV5)-{VP z^WOIKR1eL;Bwl+<*M+9g)p}5t5jjW|Q8VmlU@yNMKGj|wZ0juMDwpCkY8wZve&lz@ zYx^nU@nCDPc0~Ol#uBeCtd>A}?u+wki9D~G(gU&wKjc~;S|$g(2U=0;F)D@hsNJh* zjI^-MxjTVB8l=z1dF2j#)`-^E8izh_r^u|X={5W}-T6fS;5)}t?WzBacFs6vGv0#c z_0&qq%WP-b&^XcCmD7GNW1L?D=UHE4gnhvdIM4HZ{IoMdZV@IQ<7`b{2dJm3T`%N~ z&KsB+IuiB)ilTXEPJ~S+_Ha0PI#{XkCgMi&9u*PxDImo9u=QsCK&GfMWZ;jE(N(m)$rBwP^{Mtjm5PST-n6(eBp63; z8#jZg9eGGk%ar4c59+M-3{~@Nl15yyM!DXvu2(Ux^*1p}oVD;!Bk zh6=k~sCz;hmvOA1Y3PP4Yw(lpa8)!i1;c}5<3`y7lZQkHgK?yQE2FyonvP`EGQ8-sqn-#DWLA5N<3M|fedU|+$bOx8c!hqZa?seRs=1!&Dd7HwoTU-3=O zh1!8&bG{#)v79$Qr{k+R=?R+LP1wyU^MkfC-5$QJ63Hz`rPHQ{QnV;;m z?)oA*)WdO(E;W}h2(TLgnk+oQB=3(Y=|BsmCW*>W= zqq`koewkgoGs@DrDgDD=(9ynTuBQB*Z0$H>E|B?LBSY2oS*%>}i9T5$A2ii<=gYuw z>_rcIy77YMFx{}W-|UB~6n#faiju_P&{Yj z9^zz_90G>iFMQ}-o2zQtKj0txhkdhZ`6|21xL6=rq<3u5H-?s386%Zl!~f28l2qZ3 z#@m2u+T-uZH+{I&K+tc#`O+N*YVF z^!`Syv3t3lG4#`ula2GY*3LTih`VMBU@NDon>*xiJwJ(*{lz)QcdLY5*OSg%APVc- z#0}M1TV?OJ&O7_-i9u<-cJ*>LPufhH#CJ(~1<~=uuJRFtT`(6OyaB*j@&B3Zp9ZGJspn%339e$yUcER+~G1`sGj3*t*F_E)zG@x zy~P!bu5|GR*d+G_vzB$x#q2b)yS_;*uuss>+66n0n61^;Png%_DE6n`o%&&2Ul2{YzZ3LTn`Vau-jFS~ONWD5BI4AR<}0 zlbzRA_%j=AG?zLJytxc`_;jM8t=M|)fjS(T;-tFEOr`wJjbg7k+Ste9?Ba5k;p~gunt&k?9rX(;2Oo*vaEc=_xR72qR$5!M`}B)7hyAVX zv_dNA9L3vO_t+D?GIjcz7WHG>J2|i zlj&zMSKiW|%T*#3G?8buOjbWp-Kfga$d?w7H#R~2Af?1cS>G7P|Fv@%??_#6>reCz z`mNA;d$VJ$J~|X9_ZTP0G~R^1RJBPRT2Iu{mh(b#y8ek&v)hqOtPi-LzO<0qE-T0m zEZR{oSWT@oZ0J?$Qeib zyF;duuJ$SRjE=WMY!zK7s`CE&E7p^5x1X^4>@IrwpDF`AAhK)i=~iA;RM0$HxV&ON z1lD>*J1eW%7fC#;P0q+D)=Pw|T{N3=vM|+$$B|YH3QM)uG{qFx3$l{`ZV%B~I158N z6C|hfqI5mD>zCR#x>LQhuc>d)t$q{#{15c7P2xCfikEwMjcK zM?sa?O6x~fiIPCHCj* zddYWMEX^fW%IjngJtB($4Va86gN-s3`IkN>6U8r;j?ADj=!LFgGUc8q3g-BM46rQ3 zDSwH7p@!TqhN-{E4b_}pV1LTNVkpV3SBF~Y3ed*IYymAoy!Jn0wf3Cqz?1R<%WH>M z-64y4d2x+wCV6NB@)c?|ziOj(VKe1bu~z!jH|-0igM6Zj9Hq~sE95|+c(u^cxX|%S zWdenL3N&Chrfu4)mEuoYjW~fX9m33(9|(ITwseCj<`;oU8IU5FScCu_bVRp&6h*B5>ePP)o1Dh+J{ys)pNYONZjE+aOi zB5t)^l&1@*L0z<~cp-C>O|%d_2poQ$swIn%&$Kn&Po|-Fs0kKfDak^Q0q44fsk!N3 za|g)fl+sSB0g$w+m+l{>n$O)p8nr7_jCM6M_0Xa#av zw!r)Egg&7wUNs7nfE_R+H&Q04dthj~t4Uy>eyT!h2G+X{h`3ZQ5R4Z##K@Q6aS*V0wL}Uz3LULuNdyO0QU3} zc&mz7|4Q6ler3rA@XEh&tqSB2kotCrc^gy`c~7p$w1~(JfKfY1K5z-=pwox~-?#67 zGmRs`Tje6DRY_RRo;)=1P4|cI3-n9{M{W7#L?2Ujo>xH%K0v?@}_@FZAM@Ev% z@cL`u9gLhGuH>EQ!6Hl5vty5%G*st<+EWvuSiYw5eRx_TBU zx3+47DPt?y8oB{7V**xo8yrw`+|NEyT8xHra3$FJj=Yo@F5<*$@t8+}9T+11kauKS z8Lm=POU(Oj2TyyJlti4KhZr{qdo_eUBwfiOau+j>LE!eksSh%z*eni;9?)1k!OX;H z-rufg@3TAFE3C108hgC$;pO-)D25n1-dAEcxW9|456MEZBQkZMrL~i~W~?GLyis(_DQ zW;HN#S+A@#c3pdo9WORv>hvVEEX#0q39zmbbSG(xInCWPL2Crn%oHHAldx;;Sv{?g z)|i#1{b)3611_@}%|Sx2#b(fBX9sWpj34CH`8~TNSmO5lyuFOS;%;6AoXkERjS6ZK zpTcutMqz+R2Kzln7L?cV$d40Oau>80;}E+d!CReUOZ8_)eJDmsJAN4>441J=FKt8{ zQP9#w85z))tPqsS0=bY<9A zXSEq@$7@8oIby!J#tVV(c8J&fU+6;zig-~M^J+JR8`_i)VklVQ8zL7J<2LMfi>wOG z`W(3&d80EXY*xZM3(H@yvBRnu{hQU+#_J`G1;#<6gR#@7X5=!2G1fS(|1kC%)%4|h zDRd#nXjibgZOBw`IQi9Qs9FCM2e7++xyHBIN3quJc3*oYKgD0z&G}SO5>}fE+-o_o zDmB2>K0uZZ()IKuW`J5@HZ_rDW*J!}=z|n(0`})AeFVnzG}zmrV2A%A4n)ANP?~-a z0#m*3#V(Os-V)<5*}G9Z!sB0XPQU$6kunrI-TTmGG(jXk3od^h_(-YlBW9jdE^-1~ zUq;kXUyz>*k@n#HI-?eB1O;IUWTTbHjaQ)w7f|Iq$Ncy#u%xfa2-=J$(iEB#`q(El z1Ixk2f;YO%#^co5v2xI!Xsjj}%WPm6N7MXd9A4iVDjAPz3!iMGZ1|>2-G`oL4p{7U zqJ-Rz$1$)<9n~I0o0WL~@8IL6A+HPpgSk*GM8|Lr3}s&MTg{RG`XMf7BuB9iHxWbQ z5xqHht~kUi4><8&G9ZpZwK4+Ql`6soHO@dN{RW5w(0)jM3tIXYqN}`%sF00BvmRPG zeU#C~QOr3BI;c!8=`7&P;~eAM>u^JHa>D3he9{kVH`!cR;#zV~9hIlyt26ludzck% zwYG@0F|@#}ZQX%pyR5y`&dl?|%f7;ET_OYaupOdwfF7U|q39WhDb5%0)6rU(HUXNK z-Yf}~)kK<>rUP4-97 zoZSP|1o&hG#Ie7KNe0vAv@C5$>wybx0-bv;k_yaq92rYm(=*WJtVPyuNk)^}$nJwM zx&0Go`Y(R-MvVch`vVoqAh5OpH6Ky-3p6?l5iNBx8y!?U>Wv=o=In@~-7z(u1)ool z4tA3V@7)lrcUv&3f8hEd)sftR%{+sOAQdz~_pt-XU?UEo!tI0_=Kv<^s)18H1Pp3f|qWn ziX!Ge!5-M)siJ{ht)_D^Mb_iLDINA52>4ExTYJtLp@-WED?bB;!UP&e{{}jE>neEJ;Bx z)}4N&1y~LC1PY`!V2OKby;)lJlC@_o@R=HiOAU6R1X8z^dBQm2@fk$RYFs)IL%8 zc`}q9f8jF-#-ZZ4juiz}BQ+K^j)wbwfz|&ZTOvz85?aSBbxYHH#7G<~$O#n@z=F)kW&jF5g#*NlaFg5F;j zT0Ur_&aemcZ~7EEf%Z`7aFGKTcQO9RE(?A0Gs}efi`v_4FLa_D(-fyLB}7o;EWrH0 zX!4Czf|6?=+s-1jquNfbjdoVs1KrYF#t&jcoueIT5|kLjp*R}=70y*W6FncQt|zFrS3u=ZSzVPMkuR1) zyVFD##qQUTKSe`%4qj0c8Mr1q=ano6)kZti)}xRO0`TB?#N6`mezn|!%naEDN}E=)F?_T+tgjXJV*@hL51AEtxfG(+5vZXS!0*087c?1N(0s(HX=)yJ zrz{i$O<_v{TAK&*n%sd`{*BjKJO-vb=~r&dO{}HwT!V%|hmO zvnoCrA+w#x5AzC8nlYSg0#q||Xd+8yowW^^>o~2S)xYZwW3xU|uc9~oZ>9lA=uYTZ zOr>M?yEVLTpy(+|;6rUnLPht6M?h6rMKloWQTgTkpVDz4J||@r6h7IIoqExp@cZ5H z`*N_jX4HXd_PALJK||c=)*QZTeT3HapX{h~34#Z&T}m&|KnE?FrUtOwvvc zbq!4q-3iq)&zU`}Dwuxi2bJxAs^?8=B^Zl*Y%>(1e`&9^c>Q<1x4uokq;J<(gAK@~ zFVm=2n+>HC(WSdpSH$n*(4hAc)xqf3Lgw)ZDx&aj1kQ0j;!$QOlH>5XF(=WbULvtL z!=w0o7e!xzJNSthUK^U58c?UYfa~ok*@r1NClgDEqbp9Y{f4F%$V?5Av8vo*^z~raxg>-;vcXp~^i2RACd|Yc$TkCTi9l zKtQso=Qy!txVPnqxVI3q5^>sHRYzb7L!nk3g#L8}o_Rz4LTyqDeewvr!wTGUDfHG} z^wcZxxj}8P(yBP+vhefLxUWjUBf6p*nFV{ggM51rb@F$peJ`V4S&h#%>Vy0;6#aZ8 z>Zf6}J^M||qQ@JXjh6;>oHf!K{qzIc7o7YB*fOK%VX+k1Hc6J3GogA*!x!5%?0j|x zyE^p7ragc^;%zXI&=5M-QP9|}#JUgB1XLJBSXg! zG455#1^9HJB{-Wd_`gF?3rypjoxUX&%G#<6IZb1= zar!Ug2R`8<&K2WY=JYs+Ip!NjFb}m1I@_7702@YkkzdfdE*8Cb4Lgff%e)Zk6si#F z655Xm!|UcBmb7l$OL;X!QYUcWaijxg-2Y((wN082+Nd7-ZvBb=Q4i{$p~A1CAJbe~ z9`+k;NV=&J@*S!fDlYJ&d=(!KWpxWowC3g6cm`g9kAx~YJMOoLydc}*v%Z?cGn~+a z7ysY>DcJfK_|Ps?uQL(tYCuPR5Pnh$nQ1PvQGKX7L4l%E^`lEHi3;hBYzl4SI`s50 z=z*7O5&C$1K2NNXXf!j18EuUJYw14VZJ_%Wh7;#2!(nItt)yY zl>R+W>#0b)o;FHoNRf&LWtXhXP-fXHLWq0s8UOe1^t{gNu5-`%jqmt;Ki}_foZIlD z(OHe`ng1X&mf0ZIFM4Z4?MV%Ny6oZpc=hG_~SHpMqFv<_b0N46QceNxZP|PceN4zo>(9sxg$~Ej{39l zKjH`D=i-fE(zHZz)yFTB1<3*Gg>zLKi{KywSxRF&MVCaE#AasB$a*@vRL&haLvkWH zw`NbynwFYHIH zhb?}Ox&4gM&GmhsjCofgQx1BS^Wfj&2=}T5R#PY1>7-DbNcTu{d~gGFtATCamWre% z+uvCNzJH3>-s1d9=S1m5u|$K!WhC%5{GD&N_}v*#My5s=$9~A1n)P^gBKv}z3OPqb z{S{d)vfj=-5xXY#4c_yg$WEuCx;oh zQxB-YyagY|ME{bz&WxRKe&FKFCYi-mUEale{)|4Sb}(MO^BIh>FN+#2JNeRylpg9V z1@UceQhHPUWq5p;@8g|O-5>9icqXwiaZU0I`O;(XpqBdnOy|M(vWRMK=&FcAU5#mH zMW2r5#=eU6%-rJS)IYPX$~u8#Op7gy7T^Yt$r)c%<2>gK*7#&|8PTnYA60v{7d~4! zqHtW{p2D8-<8qCGEaNTJ;H%W8e}yQuBZKU#UEl_g%VaRaqc6Inl@{)VR zkP$Mh+vTmNQ|o1&<8j=j@O>?$-{rUWMg5|wrDEUBMR3bohe8FQUycr@14T{ES# zYGq|-O;wjX8T(JHRP5{M1JR1n>9Xt1nB~+|Ii41i>x^Ovof0>y#Js_-PO*b-svNtM zb;ZFD?_Oq4wdCksBkfhM&SJ)&K)*qlb@z;p?qX{y7hNreb$7yYCWJ4UT$^~?{@uEG z+Ck@3R?^?G_*o~C8slqs+kJO`Vx-=lm51bF-$PPQW~_?zj;+XSnmsyac#&xCCAoWw z^egg&>f@7H<1#15ekR$kBOSz(%iRRDO%I(_PbU@T7mO_ETF|zjSHVm+_e0@x@g8h$ zX>usco{>>KGFEM>mK#{Pu(&~DT;TRioJ?vV_BX~y`qKSKyWm&Td3O=%=ZtFj=p*)!kMp=Y(u?)iRBzaV zmAtP4-Cv|CCYyKT&WJq;h*BW^*s@@n5S_Um4*ZTsme~ zel6=3-BetLs{`!KD33uuFUMS|9`}f7@P5+h%^OW_ySB(5lDBe*O`H(8h z%XnCJ>Xh7Nf*gF7Si2@!GS$^NmRXoyhF4X!x49o8QEeZC`qCXd0#iS z(l6s$-@DM#RrED9<6-P^mJ_=3aM2gcx|VEutT`s2<-f6$KA7@V$^OZ~9(RdZLp+zq zbI0PR<&7XOqkE)dw7WaAYG$p=>f%&sLP_K7_n{Q_ScD>qt8PR&zm)kJnA zIwmH@OUEb6de1McUU+@s^umVmpW=_<5`QLNSNU0&(Tk>QxtVf;`o)syJU<_a_KKE^ z?s69631?F-f&NR`)+Jc|eQ<4pRrlSADTyJ@zg?SXljxGT4yIf$cW9ZYmuRLAGz4Dl zP4q}Ev<|XB_3b@A@VpxDtGhTMj~U-V2gp-%OE z^p)tP>NpSLm7nsD+>B*xp*4HjK!)9{g6Ag=s%B?8lQj_Q-=mfmvhZLT-^22K>>^84 z?v+0O^O3xCz8cANtnv}KI>P>?m()0`M}}iFC2{nJqxX_kxTmwDeaoH1tk1%!~r$17jf+FqR3EL{RT3b=Jd!wtsWMyTs0|cnytSYDyJ_%fLQ7`oZ&Mb^qXuh zn*}#d)k{^gK2(&Q6w&Hs*3l-xvI8o7wQ%5U>yn@I@;-8_N>H(^e6@{vjbgEHo7q;8 zzN*!|f3v==^0!{rJN~lY`YkqZ*LLgy=jN)r>+*W%JX<3f^KEa>4>v2uRw_N>5HyxkMc0#Ea>5qYNqxsDT za)makh8?0?q9bCbRDy5L4LON; z6%qAo74dJ}{qdE0?SC`YV?muT)d`X5B=bJTdb6BlEsbYo%q4}o(CDB$3RXen>G~~Gb@XzD6ZL=ampm2m8Bn{62$#U4s?qHA>~tRNFACZFvBz;D z$R3$b&z8_XK9ST40f-Y~ssUyi!(B?5#u?Sw)REfJK zEzhNHU3C=1*{-sYbENw@uG+t9~PP zn6HoAWG7ykSnt?etPB49jP2H#_aDFl=BgskD;Tunh z7-y2L;LCW{@~cW=DJa?&iuSQ$pGW)6RGcbi{3!=r3;$mi5%REu%Xvn5?0#! z_B><%#3v%qX0^#EU$L`Bq@JkNYk8QNug02m8fvinck4x8 za+EKg2dU%j%fn#Ywg>>)fgPvmZDhxkr@vQ{l3%67-QgSs!wS;xkAllx1M zrIdW^Gjp13W~=C=7-mr0YE?cdyq<350WqT$dsHKgwvKI&waUCsPV$X2z1P{*v`Zx2 zO)@XQkhfKm!Yv1NQ(Kddh)S8s0HJb zCHW!RzZimFhzmBLr$KrduZA*{TvlT*N9m;)F4YvvD2qu&AnbaNZRz#9^Ua}>dz|yt zoo2%Dd(@`}W27(pYq@v->iOB`_nuiMQ=>)Zb&(0tUa=kI+%>a8=D1h_r++I_FR}#J zdsm(3AXLwhnGCj?)G{^F&9l8#0XE2*8zr7j6efmVwhz4KXCX$Vn?Wa^N4XPHzz7F(P?^4ZxQ71x9I;)Jv2 zm&4C~G=6RR4)R@U@)Jn^30^mrcYa7ROVf`9P-7vVU1;?x;rwoWcGTXZ1#d1-^5y;A z6UyH!df%eA)1uYwu(z1F)(Qh%C?oj?s|zcoJ#pV}%(oTnXhz3B+4q*8@t#|dK8qG} zTj3S460xOD^`DX{w1}LP$9&G8*NfNdt<8-yqotVaMA_*R^0&io$w zP1`#qYBUwMVyVu_5sBGuWn1cAuKk7A#t*wQbY!v)j@1%V_!_VJQie5FT|3Wu--m3h z8uXna<2a2c?umSES7=cy1`{YWc zR17L<_q@nItYbZ;o?O{#%>y)E$a{Cn6l%b+v9jca;_}P-4qW*)Hda%mzYP4k0S2uj z!E0rD4`3T3_-V+`4zai2_)BeeGsb(@t7m+V!;fSWM_9}#2wGPwEuh|DUi>jD{Zrmk zh)tG)Gw0(3tz}CCM6-viaZZ5LQ~AQ%a;c~2>EAMgVJZq8-MIe;9~jRX$3Xb^c|`)s z-$AltyrP{J=gX?DFxL7|;(GY^u@V2quQ#hiyyJb3S|xi!JlSVASrb+30W6^dOD}7M z{T%w(~iTk6JC8DUv5omcN^s!R=%cSUsKg_ z$EQdB3@HucZoZuXm1R1TWoZe#la!pwkOABl^LdC;#=`7CW^>-U&3pMdoz z^j)1C!_Khjs?DoJpC#!`yB`_kT!cv!F&}61&OF z9#BzukQLl0>NGIJ0)ayFhnr%+uw)Hn5{UXnwlp_@3TRGNs51H$tbiTXMcuV6# zeW1<=V;hxb7afeZBJ?Xm=B3U4LSA@@=+Ky?tFXOu+MC5=M#It)Tz#nq-Ab#Jp;Mjva8WiPFKXQqpM&7qTJo;^r*f6@Ix zb6+Xvm~J*xRUfvi%{767fv-+x?@#f%P7pi6A2)0BM|GQ@{I-rXX7QvqS<2IF@=Y;e zK6!4T;e<9yvZ#t?Si)D`v@cziaT#{f(eLe`V)M6_?Tll-CXN#*Eh0l37N}n;3D6jWA`Y0jp z{%-8c$Y23$T?t{%rENb)51IV*EXyiRsu$5lU;*`EKzLu&=uVPV0>`ZfIlJ0HbG4t9 zwYZf$r{RI`lHUT+Gzrh@`(Lo0QdXQc>+M^Q4aQW)NC(05ci`viQ1lifjmw4B}Q4A z%z`Wz!iZhQUXrX^h~i~%m}NZsVe#iqxHO4RgE1}k0xSddkyREZEkg~F8!*{8GLL$8_7>qC#(D6&dYezX!E_2 zM)T=r6DvEwGcriKCCk2A8|_(L75X}2-gC6@v3BM_kgeui%E;S8plfly_WZgE1UaUK z1LhXmNs?_Xa&IjM>PUMRvYTV1xRS*zz+1M^UkU!zIn5bckX7*ORqS^>%YDw*EA;m* zO)T@CO|-Tm?K$7mK-llFlKt&ZkG7$)=JCEEzHan;;1>1Bwzk&W(t8`Y)0r0f8f$wo zy|eF^7~6U5D2HT!OII@&>1!cfPlrNNM6l;@v6qZu3Jds(jJJ4%PgvQ*JCCK2Jz_>R z_0)`Awj;+Y{0xY3F{w39&$ud`pW+$Y%waw)gcb4={IHxlMc`W{;L=XGvP=*2`2Qg| z+~9vz#JkPsX87L4<4Pm$LLc4e?lQBhpZ1^1`Y#EW4tU=dx?QWCE&4i0_9v}yor2_B zw7wiFEYSW+n)!_dMfF=sR(@7=Eijjp|MRknIA=q%37oqXdG%z?z38qj9X8>CW#COz zG)j2=Dc=)p__*G;)6Ht5Ut+a!t~mxf-J$;k^et_cW%z$dHOoJYY(tt2&i9&~{tliU zKGP;d)m`*ElUKdMiYNN*3p!p$@9Wsxw_5(#*K+U5@alT(vj;u)<|V!8HOS~1{*cE9 z?_<@W#^Ho23HBqic4Qi3dppy1^DSA=;Z3{se^zwK6pwyq(aZG^sw=b2W1o3CNlF5@ zYV!tt1W%|6UFwkhrDQTd+r#wMpKcn_W{OAbrH51Ec8=akv8-$uwb#7XYx@WD4z;nL zc)&Iz*+gP1^fb>rKlSyE|Nl&C2S_x0>dGMXAE8GrUG<2z)OOSzUu1( z_BxMM9pVeW)9w!bzfx~Yd|&Um0y?Nle^-)F$jv&CMl&AKlnn)}zW`R1XGay%daCZV z&MoNUPc3fN&i6bb+<>;jxMJ+16-jg?`4B7W&}?BEufwMXl}NWz)ANnA-h`Y&Bndbj zstv2me4{q97(TU!J`Q@ECi(MdI9vP)ks?Ha@QG5n{?0Yh%IVDH zLNcjE>LHee`vbGAtNq2YLruECI76JP#TT2=RpazH1B=>6@*7Ec4=qKEqNuMze=R5d zNxXdw-uD{T_qkbbNTbU>qx#MJ{*WV;({}^D*WBFe>$?gar;I&L0;g!}q`nW(M98Xs zG~e|+Vw=|g#=4@ILT0nWyEK7ftjIE*B2kdPOb%T_GjDlXv`F{=1Ht~vj zljl>al?73v`4KW`FS|y2M}Z;4m1v)jMrX>7CZEQ(g;hi z)M{qmf;Y717oGiVmwvQ2uQv44fTqiP6f=)}+6@t6yHT&C)!@m2SFeO%%i-#GelFAZ zA-JAWQz%2?&9xc$NN@Pvn|Ay7y|1rMo^Pjzz_gm^r3UHclFex-5qekw>sImHkelz+ zVqoRPv{=b_D${2@n(F}1`a++(v4i{614fd{O=i&yW6dUq{od=73E0S#^xR*?D_#-J z>~`~*Lc-HzEHl~CBK`eH&b!IEfcKo^nYsEaDx!p%K*Z0po{d6-98xafzxkdIc}=j) z!$!T6mX7&r7yAi0#vZnG+m zpXIwDMunIi=2FBfFW?b1JVNXaCq;r+2ET4$ticD{!Q0j(+SvcjhhT-?d(iB*l2*8% z^{2G2Y%_5kA!}$oGV?95=teKfy0yo@b!QF<%GhF<8VwzXd*W22SMCa#=mbz2byt z3XCG?Cd9g6X93XzbH2#PDjRD#&m~E47rCz>}`sgX=xZy!ko{_*(#7m8FI|_%mw->#siOQA>TXyGNRDNuaI^Ee|ZcWe=J=K z2sh9zpxYfD1u&!(xi;p8L+lKMOF!AAJ87mTV!fd}B|f>wk4fJn-F=_QAZ!5;uzM&!lw|f(dYjIApF;a literal 0 HcmV?d00001 From 9253431b5b125d7b3efa99b38bdf56b991725620 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 20 Sep 2023 16:07:27 +0200 Subject: [PATCH 02/13] add component --- .../preview/components/joiners/__init__.py | 0 haystack/preview/components/joiners/join.py | 24 +++++++++++++++++++ 2 files changed, 24 insertions(+) create mode 100644 haystack/preview/components/joiners/__init__.py create mode 100644 haystack/preview/components/joiners/join.py diff --git a/haystack/preview/components/joiners/__init__.py b/haystack/preview/components/joiners/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/haystack/preview/components/joiners/join.py b/haystack/preview/components/joiners/join.py new file mode 100644 index 0000000000..17da22e275 --- /dev/null +++ b/haystack/preview/components/joiners/join.py @@ -0,0 +1,24 @@ +from typing import Type +from haystack.preview import component + + +@component +class Join: + """ + Simple component that joins together a group of inputs of the same type. Works with every type that supports + the + operator for joining, like lists, strings, etc. + """ + + def __init__(inputs_count: int, inputs_type: Type): + """ + :param inputs_count: the number of inputs to expect. + :param inputs_type: the type of the inputs. Every type that supports the + operator works. + """ + component.set_input_types({f"input_{i}": inputs_type for i in range(inputs_count)}) + component.set_output_types(output=inputs_type) + + def run(self, **kwargs): + output = [] + for values in kwargs.values(): + output += values + return {"output": output} From 4392537a32e38c00fca8188a8350d4a66ab6880b Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 20 Sep 2023 16:25:00 +0200 Subject: [PATCH 03/13] add test --- test/preview/components/joiner/test_join.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 test/preview/components/joiner/test_join.py diff --git a/test/preview/components/joiner/test_join.py b/test/preview/components/joiner/test_join.py new file mode 100644 index 0000000000..133e21700b --- /dev/null +++ b/test/preview/components/joiner/test_join.py @@ -0,0 +1,13 @@ +from typing import List + +import pytest + +from haystack.preview.components.joiners.join import Join + + +class TestJoin: + @pytest.mark.unit + def test_join(self): + comp = Join(inputs_count=2, inputs_type=List[int]) + output = comp.run(input_1=[1, 2], input_2=[3, 4]) + assert output == {"output": [1, 2, 3, 4]} From bfd771791aed80ffb3e126432d74ebbd7c670b41 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 20 Sep 2023 16:36:47 +0200 Subject: [PATCH 04/13] reno --- releasenotes/notes/join-lists-1307f7872a37e238.yaml | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 releasenotes/notes/join-lists-1307f7872a37e238.yaml diff --git a/releasenotes/notes/join-lists-1307f7872a37e238.yaml b/releasenotes/notes/join-lists-1307f7872a37e238.yaml new file mode 100644 index 0000000000..06f96b0add --- /dev/null +++ b/releasenotes/notes/join-lists-1307f7872a37e238.yaml @@ -0,0 +1,3 @@ +--- +preview: + - Add `Join`, a small component that can be used to join lists and other types supporting the + operator. From 1a25d2f98898e630125dd9e45fbf98b6a645b19d Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 20 Sep 2023 16:48:15 +0200 Subject: [PATCH 05/13] typo --- haystack/preview/components/joiners/join.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/haystack/preview/components/joiners/join.py b/haystack/preview/components/joiners/join.py index 17da22e275..21118266e4 100644 --- a/haystack/preview/components/joiners/join.py +++ b/haystack/preview/components/joiners/join.py @@ -9,7 +9,7 @@ class Join: the + operator for joining, like lists, strings, etc. """ - def __init__(inputs_count: int, inputs_type: Type): + def __init__(self, inputs_count: int, inputs_type: Type): """ :param inputs_count: the number of inputs to expect. :param inputs_type: the type of the inputs. Every type that supports the + operator works. From 74eed5989ab9d7503f1885553dca4f05e2ef2fec Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 20 Sep 2023 16:54:33 +0200 Subject: [PATCH 06/13] add serialization --- haystack/preview/components/joiners/join.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/haystack/preview/components/joiners/join.py b/haystack/preview/components/joiners/join.py index 21118266e4..e0cca3275d 100644 --- a/haystack/preview/components/joiners/join.py +++ b/haystack/preview/components/joiners/join.py @@ -1,5 +1,5 @@ from typing import Type -from haystack.preview import component +from haystack.preview import component, default_from_dict, default_to_dict @component @@ -14,10 +14,23 @@ def __init__(self, inputs_count: int, inputs_type: Type): :param inputs_count: the number of inputs to expect. :param inputs_type: the type of the inputs. Every type that supports the + operator works. """ + self.inputs_count = inputs_count + self.inputs_type = inputs_type component.set_input_types({f"input_{i}": inputs_type for i in range(inputs_count)}) component.set_output_types(output=inputs_type) + def to_dict(self): + return default_to_dict(self, inputs_count=self.inputs_count, inputs_type=self.inputs_type) + + @classmethod + def from_dict(self, data): + return default_from_dict(self, data) + def run(self, **kwargs): + """ + Joins together a group of inputs of the same type. Works with every type that supports the + operator, + like lists, strings, etc. + """ output = [] for values in kwargs.values(): output += values From 8192e34ec408053c7a2cb3e6b73cd801a98ea05b Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 20 Sep 2023 17:00:14 +0200 Subject: [PATCH 07/13] stray changes --- test/preview/components/audio/test_whisper_local.py | 2 +- test/preview/components/audio/test_whisper_remote.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/preview/components/audio/test_whisper_local.py b/test/preview/components/audio/test_whisper_local.py index 745ece37c5..1acb47878a 100644 --- a/test/preview/components/audio/test_whisper_local.py +++ b/test/preview/components/audio/test_whisper_local.py @@ -157,7 +157,7 @@ def test_transcribe_stream(self): assert results == [expected] @pytest.mark.integration - def test_whisper_local_transcriber(self, preview_samples_path): + def test_whisper_local_transcriber(preview_samples_path): comp = LocalWhisperTranscriber(model_name_or_path="medium") comp.warm_up() output = comp.run( diff --git a/test/preview/components/audio/test_whisper_remote.py b/test/preview/components/audio/test_whisper_remote.py index ad71c9a5c7..26851845a2 100644 --- a/test/preview/components/audio/test_whisper_remote.py +++ b/test/preview/components/audio/test_whisper_remote.py @@ -217,7 +217,7 @@ def test_custom_api_base(self, mock_request, preview_samples_path): reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.", ) @pytest.mark.integration - def test_whisper_remote_transcriber(self, preview_samples_path): + def test_whisper_remote_transcriber(preview_samples_path): comp = RemoteWhisperTranscriber(api_key=os.environ.get("OPENAI_API_KEY")) output = comp.run( From 7b1767e447d07029c061434acb71885775265639 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 20 Sep 2023 17:10:21 +0200 Subject: [PATCH 08/13] add tests --- haystack/preview/components/joiners/join.py | 18 +++++++++----- test/preview/components/joiner/test_join.py | 26 +++++++++++++++++++-- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/haystack/preview/components/joiners/join.py b/haystack/preview/components/joiners/join.py index e0cca3275d..dc42772eeb 100644 --- a/haystack/preview/components/joiners/join.py +++ b/haystack/preview/components/joiners/join.py @@ -14,24 +14,30 @@ def __init__(self, inputs_count: int, inputs_type: Type): :param inputs_count: the number of inputs to expect. :param inputs_type: the type of the inputs. Every type that supports the + operator works. """ + if inputs_count < 1: + raise ValueError("inputs_count must be at least 1") self.inputs_count = inputs_count self.inputs_type = inputs_type - component.set_input_types({f"input_{i}": inputs_type for i in range(inputs_count)}) - component.set_output_types(output=inputs_type) + component.set_input_types(self, **{f"input_{i}": inputs_type for i in range(inputs_count)}) + component.set_output_types(self, output=inputs_type) def to_dict(self): return default_to_dict(self, inputs_count=self.inputs_count, inputs_type=self.inputs_type) @classmethod - def from_dict(self, data): - return default_from_dict(self, data) + def from_dict(cls, data): + return default_from_dict(cls, data) def run(self, **kwargs): """ Joins together a group of inputs of the same type. Works with every type that supports the + operator, like lists, strings, etc. """ - output = [] - for values in kwargs.values(): + if len(kwargs) != self.inputs_count: + raise ValueError(f"Join expected {self.inputs_count} inputs, but got {len(kwargs)}") + + values = list(kwargs.values()) + output = values.pop() + for values in values: output += values return {"output": output} diff --git a/test/preview/components/joiner/test_join.py b/test/preview/components/joiner/test_join.py index 133e21700b..4609da6196 100644 --- a/test/preview/components/joiner/test_join.py +++ b/test/preview/components/joiner/test_join.py @@ -7,7 +7,29 @@ class TestJoin: @pytest.mark.unit - def test_join(self): + def test_join_to_dict(self): + comp = Join(inputs_count=2, inputs_type=str) + assert comp.to_dict() == {"type": "Join", "init_parameters": {"inputs_count": 2, "inputs_type": str}} + + @pytest.mark.unit + def test_join_list(self): comp = Join(inputs_count=2, inputs_type=List[int]) - output = comp.run(input_1=[1, 2], input_2=[3, 4]) + output = comp.run(input_0=[1, 2], input_1=[3, 4]) assert output == {"output": [1, 2, 3, 4]} + + @pytest.mark.unit + def test_join_str(self): + comp = Join(inputs_count=2, inputs_type=str) + output = comp.run(input_0="hello", input_1="test") + assert output == {"output": "hellotest"} + + @pytest.mark.unit + def test_join_one_input(self): + comp = Join(inputs_count=1, inputs_type=str) + output = comp.run(input_0="hello") + assert output == {"output": "hello"} + + @pytest.mark.unit + def test_join_zero_input(self): + with pytest.raises(ValueError): + Join(inputs_count=0, inputs_type=str) From 74998714239026e7689ea3b98a063737d709b6f1 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 20 Sep 2023 17:18:03 +0200 Subject: [PATCH 09/13] fix order --- haystack/preview/components/joiners/join.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/haystack/preview/components/joiners/join.py b/haystack/preview/components/joiners/join.py index dc42772eeb..f27af791da 100644 --- a/haystack/preview/components/joiners/join.py +++ b/haystack/preview/components/joiners/join.py @@ -37,7 +37,7 @@ def run(self, **kwargs): raise ValueError(f"Join expected {self.inputs_count} inputs, but got {len(kwargs)}") values = list(kwargs.values()) - output = values.pop() - for values in values: + output = values[0] + for values in values[1:]: output += values return {"output": output} From 0657dc2d9d8712e6ff3dddd64a1b1149c092deba Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 20 Sep 2023 17:49:34 +0200 Subject: [PATCH 10/13] marshalling types --- haystack/preview/components/joiners/join.py | 8 +++- haystack/preview/utils/__init__.py | 1 + haystack/preview/utils/marshalling.py | 46 +++++++++++++++++++++ test/preview/components/joiner/test_join.py | 9 +++- 4 files changed, 61 insertions(+), 3 deletions(-) create mode 100644 haystack/preview/utils/marshalling.py diff --git a/haystack/preview/components/joiners/join.py b/haystack/preview/components/joiners/join.py index f27af791da..00e37dda9b 100644 --- a/haystack/preview/components/joiners/join.py +++ b/haystack/preview/components/joiners/join.py @@ -1,5 +1,6 @@ from typing import Type -from haystack.preview import component, default_from_dict, default_to_dict +from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError +from haystack.preview.utils import marshal_type, unmarshal_type @component @@ -22,10 +23,13 @@ def __init__(self, inputs_count: int, inputs_type: Type): component.set_output_types(self, output=inputs_type) def to_dict(self): - return default_to_dict(self, inputs_count=self.inputs_count, inputs_type=self.inputs_type) + return default_to_dict(self, inputs_count=self.inputs_count, inputs_type=marshal_type(self.inputs_type)) @classmethod def from_dict(cls, data): + if not "inputs_type" in data["init_parameters"]: + raise DeserializationError("The inputs_type parameter for Join is missing.") + data["init_parameters"]["inputs_type"] = unmarshal_type(data["init_parameters"]["inputs_type"]) return default_from_dict(cls, data) def run(self, **kwargs): diff --git a/haystack/preview/utils/__init__.py b/haystack/preview/utils/__init__.py index a84ea468e2..9b13aa8604 100644 --- a/haystack/preview/utils/__init__.py +++ b/haystack/preview/utils/__init__.py @@ -1,3 +1,4 @@ from haystack.preview.utils.expit import expit from haystack.preview.utils.requests_utils import request_with_retry from haystack.preview.utils.filters import document_matches_filter +from haystack.preview.utils.marshalling import marshal_type, unmarshal_type diff --git a/haystack/preview/utils/marshalling.py b/haystack/preview/utils/marshalling.py new file mode 100644 index 0000000000..1248b3cac0 --- /dev/null +++ b/haystack/preview/utils/marshalling.py @@ -0,0 +1,46 @@ +from typing import Type +import builtins +import sys + +from haystack.preview import DeserializationError + + +def marshal_type(type_: Type) -> str: + """ + Given a type, return a string representation that can be unmarshalled. + + :param type_: The type. + :return: Its string representation. + """ + module = type_.__module__ + if module == "builtins": + return type_.__name__ + return f"{module}.{type_.__name__}" + + +def unmarshal_type(type_name: str) -> Type: + """ + Given the string representation of a type, return the type itself. + + :param type_name: The string representation of the type. + :return: The type itself. + """ + if "." not in type_name: + type_ = getattr(builtins, type_name, None) + if not type_: + raise DeserializationError(f"Could not locate builtin called '{type_name}'") + return type_ + + parts = type_name.split(".") + module_name = ".".join(parts[:-1]) + type_name = parts[-1] + + module = sys.modules.get(module_name, None) + if not module: + raise DeserializationError(f"Could not locate the module '{module_name}'") + + type_ = getattr(module, type_name, None) + if not type_: + raise DeserializationError(f"Could not locate the type '{type_name}'") + + return type_ diff --git a/test/preview/components/joiner/test_join.py b/test/preview/components/joiner/test_join.py index 4609da6196..481c04e392 100644 --- a/test/preview/components/joiner/test_join.py +++ b/test/preview/components/joiner/test_join.py @@ -9,7 +9,14 @@ class TestJoin: @pytest.mark.unit def test_join_to_dict(self): comp = Join(inputs_count=2, inputs_type=str) - assert comp.to_dict() == {"type": "Join", "init_parameters": {"inputs_count": 2, "inputs_type": str}} + assert comp.to_dict() == {"type": "Join", "init_parameters": {"inputs_count": 2, "inputs_type": "str"}} + + @pytest.mark.unit + def test_join_from_dict(self): + data = {"type": "Join", "init_parameters": {"inputs_count": 2, "inputs_type": "str"}} + comp = Join.from_dict(data) + assert comp.inputs_count == 2 + assert comp.inputs_type == str @pytest.mark.unit def test_join_list(self): From 617fffd1a68a885d8906869d047584c5dc84103f Mon Sep 17 00:00:00 2001 From: ZanSara Date: Wed, 20 Sep 2023 17:52:23 +0200 Subject: [PATCH 11/13] marshalling tests --- test/preview/utils/test_marshalling.py | 37 ++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 test/preview/utils/test_marshalling.py diff --git a/test/preview/utils/test_marshalling.py b/test/preview/utils/test_marshalling.py new file mode 100644 index 0000000000..c44d26c373 --- /dev/null +++ b/test/preview/utils/test_marshalling.py @@ -0,0 +1,37 @@ +import pytest + +from haystack.preview import Document, DeserializationError +from haystack.preview.utils.marshalling import marshal_type, unmarshal_type + + +TYPE_STRING_PAIRS = [(int, "int"), (Document, "haystack.preview.dataclasses.document.Document")] + + +@pytest.mark.unit +@pytest.mark.parametrize("type_,string", TYPE_STRING_PAIRS) +def test_marshal_type(type_, string): + assert marshal_type(type_) == string + + +@pytest.mark.unit +@pytest.mark.parametrize("type_,string", TYPE_STRING_PAIRS) +def test_unmarshal_type(type_, string): + assert unmarshal_type(string) == type_ + + +@pytest.mark.unit +def test_unmarshal_type_missing_builtin(): + with pytest.raises(DeserializationError, match="Could not locate builtin called 'something'"): + unmarshal_type("something") + + +@pytest.mark.unit +def test_unmarshal_type_missing_module(): + with pytest.raises(DeserializationError, match="Could not locate the module 'something'"): + unmarshal_type("something.int") + + +@pytest.mark.unit +def test_unmarshal_type_missing_type(): + with pytest.raises(DeserializationError, match="Could not locate the type 'Documentttt'"): + unmarshal_type("haystack.preview.dataclasses.document.Documentttt") From 28246c53a28f7449def977d33941f1ac5f802bdc Mon Sep 17 00:00:00 2001 From: Daria Fokina Date: Mon, 25 Sep 2023 12:19:00 +0200 Subject: [PATCH 12/13] docstrings update --- haystack/preview/components/joiners/join.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/haystack/preview/components/joiners/join.py b/haystack/preview/components/joiners/join.py index 00e37dda9b..c2ba97ac82 100644 --- a/haystack/preview/components/joiners/join.py +++ b/haystack/preview/components/joiners/join.py @@ -6,14 +6,14 @@ @component class Join: """ - Simple component that joins together a group of inputs of the same type. Works with every type that supports - the + operator for joining, like lists, strings, etc. + A simple component that joins together a group of inputs of the same type. Works with every type that supports + the + operator for joining, such as lists, strings, etc. """ def __init__(self, inputs_count: int, inputs_type: Type): """ - :param inputs_count: the number of inputs to expect. - :param inputs_type: the type of the inputs. Every type that supports the + operator works. + :param inputs_count: The number of inputs to expect. + :param inputs_type: The type of the inputs. Every type that supports the + operator works. """ if inputs_count < 1: raise ValueError("inputs_count must be at least 1") @@ -35,7 +35,7 @@ def from_dict(cls, data): def run(self, **kwargs): """ Joins together a group of inputs of the same type. Works with every type that supports the + operator, - like lists, strings, etc. + such as lists, strings, etc. """ if len(kwargs) != self.inputs_count: raise ValueError(f"Join expected {self.inputs_count} inputs, but got {len(kwargs)}") From 684c0bbaeeafb164cfac9223d6923cd183a0f319 Mon Sep 17 00:00:00 2001 From: ZanSara Date: Mon, 25 Sep 2023 17:41:49 +0200 Subject: [PATCH 13/13] explain the type error --- haystack/preview/components/joiners/join.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/haystack/preview/components/joiners/join.py b/haystack/preview/components/joiners/join.py index c2ba97ac82..e819e41b99 100644 --- a/haystack/preview/components/joiners/join.py +++ b/haystack/preview/components/joiners/join.py @@ -1,5 +1,5 @@ from typing import Type -from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError +from haystack.preview import component, default_from_dict, default_to_dict, DeserializationError, ComponentError from haystack.preview.utils import marshal_type, unmarshal_type @@ -42,6 +42,11 @@ def run(self, **kwargs): values = list(kwargs.values()) output = values[0] - for values in values[1:]: - output += values + try: + for values in values[1:]: + output += values + except TypeError: + raise ComponentError( + f"Join expected inputs of a type that supports the + operator, but got: {[type(v) for v in values]}" + ) return {"output": output}