diff --git a/src/main/kotlin/org/phoenixframework/Channel.kt b/src/main/kotlin/org/phoenixframework/Channel.kt index 0794cfb..d6103ab 100644 --- a/src/main/kotlin/org/phoenixframework/Channel.kt +++ b/src/main/kotlin/org/phoenixframework/Channel.kt @@ -38,7 +38,7 @@ data class Binding( */ class Channel( val topic: String, - params: Payload, + paramsClosure: PayloadClosure, internal val socket: Socket ) { @@ -94,10 +94,10 @@ class Channel( internal var timeout: Long /** Params passed in through constructions and provided to the JoinPush */ - var params: Payload = params + var params: Payload + get() = joinPush.payload set(value) { joinPush.payload = value - field = value } /** Set to true once the channel has attempted to join */ @@ -121,6 +121,12 @@ class Channel( */ internal var onMessage: (Message) -> Message = { it } + constructor( + topic: String, + params: Payload, + socket: Socket + ) : this(topic, { params }, socket) + init { this.state = State.CLOSED this.bindings = ConcurrentLinkedQueue() @@ -148,7 +154,7 @@ class Channel( this.joinPush = Push( channel = this, event = Event.JOIN.value, - payload = params, + payloadClosure = paramsClosure, timeout = timeout) // Perform once the Channel has joined diff --git a/src/main/kotlin/org/phoenixframework/Defaults.kt b/src/main/kotlin/org/phoenixframework/Defaults.kt index 5c43d3e..06ecc9f 100644 --- a/src/main/kotlin/org/phoenixframework/Defaults.kt +++ b/src/main/kotlin/org/phoenixframework/Defaults.kt @@ -26,13 +26,9 @@ import com.google.gson.FieldNamingPolicy import com.google.gson.Gson import com.google.gson.GsonBuilder import com.google.gson.JsonObject -import com.google.gson.JsonParser import com.google.gson.reflect.TypeToken -import okhttp3.HttpUrl import okhttp3.HttpUrl.Companion.toHttpUrlOrNull -import org.phoenixframework.Defaults.gson import java.net.URL -import javax.swing.text.html.HTML.Tag.P object Defaults { @@ -156,10 +152,8 @@ object Defaults { httpBuilder.addQueryParameter("vsn", vsn) // Append any additional query params - paramsClosure.invoke()?.let { - it.forEach { (key, value) -> - httpBuilder.addQueryParameter(key, value.toString()) - } + paramsClosure.invoke().forEach { (key, value) -> + httpBuilder.addQueryParameter(key, value.toString()) } // Return the [URL] that will be used to establish a connection diff --git a/src/main/kotlin/org/phoenixframework/Push.kt b/src/main/kotlin/org/phoenixframework/Push.kt index 5234205..432e4bf 100644 --- a/src/main/kotlin/org/phoenixframework/Push.kt +++ b/src/main/kotlin/org/phoenixframework/Push.kt @@ -32,8 +32,8 @@ class Push( val channel: Channel, /** The event the Push is targeting */ val event: String, - /** The message to be sent */ - var payload: Payload = mapOf(), + /** Closure that allows changing parameters sent during push */ + var payloadClosure: PayloadClosure, /** Duration before the message is considered timed out and failed to send */ var timeout: Long = Defaults.TIMEOUT ) { @@ -56,6 +56,23 @@ class Push( /** The event that is associated with the reference ID of the Push */ var refEvent: String? = null + var payload: Payload + get() = payloadClosure.invoke() + set(value) { + payloadClosure = { value } + } + + constructor( + /** The channel the Push is being sent through */ + channel: Channel, + /** The event the Push is targeting */ + event: String, + /** The message to be sent */ + payload: Payload = mapOf(), + /** Duration before the message is considered timed out and failed to send */ + timeout: Long = Defaults.TIMEOUT + ) : this(channel, event, { payload }, timeout) + //------------------------------------------------------------------------------ // Public //------------------------------------------------------------------------------ diff --git a/src/main/kotlin/org/phoenixframework/Socket.kt b/src/main/kotlin/org/phoenixframework/Socket.kt index 253df13..4d3ddef 100644 --- a/src/main/kotlin/org/phoenixframework/Socket.kt +++ b/src/main/kotlin/org/phoenixframework/Socket.kt @@ -100,7 +100,7 @@ const val WS_CLOSE_ABNORMAL = 1006 /** * A closure that will return an optional Payload */ -typealias PayloadClosure = () -> Payload? +typealias PayloadClosure = () -> Payload /** A closure that will encode a Map into a JSON String */ typealias EncodeClosure = (Any) -> String @@ -242,7 +242,7 @@ class Socket( */ constructor( url: String, - params: Payload? = null, + params: Payload = mapOf(), vsn: String = Defaults.VSN, encode: EncodeClosure = Defaults.encode, decode: DecodeClosure = Defaults.decode, @@ -358,9 +358,14 @@ class Socket( fun channel( topic: String, params: Payload = mapOf() + ): Channel = this.channel(topic) { params } + + fun channel( + topic: String, + paramsClosure: PayloadClosure ): Channel { - val channel = Channel(topic, params, this) - this.channels = this.channels + channel + val channel = Channel(topic, paramsClosure, this) + this.channels += channel return channel } diff --git a/src/test/kotlin/org/phoenixframework/ChannelTest.kt b/src/test/kotlin/org/phoenixframework/ChannelTest.kt index bd83c49..78e3820 100644 --- a/src/test/kotlin/org/phoenixframework/ChannelTest.kt +++ b/src/test/kotlin/org/phoenixframework/ChannelTest.kt @@ -149,6 +149,32 @@ class ChannelTest { /* End JoinParams */ } + + @Nested + @DisplayName("join paramsClosure") + inner class JoinParamsClosure { + @Test + internal fun `updating join params closure`() { + val paramsClosure = { mapOf("value" to 1) } + val change = mapOf("value" to 2) + + channel = Channel("topic", paramsClosure, socket) + val joinPush = channel.joinPush + + assertThat(joinPush.channel).isEqualTo(channel) + assertThat(joinPush.payload["value"]).isEqualTo(1) + assertThat(joinPush.event).isEqualTo("phx_join") + assertThat(joinPush.timeout).isEqualTo(10_000L) + + channel.params = change + assertThat(joinPush.channel).isEqualTo(channel) + assertThat(joinPush.payload["value"]).isEqualTo(2) + assertThat(channel.params["value"]).isEqualTo(2) + assertThat(joinPush.event).isEqualTo("phx_join") + assertThat(joinPush.timeout).isEqualTo(10_000L) + } + } + @Nested @DisplayName("join") inner class Join { diff --git a/src/test/kotlin/org/phoenixframework/PresenceTest.kt b/src/test/kotlin/org/phoenixframework/PresenceTest.kt index 4eaee21..2874e1f 100644 --- a/src/test/kotlin/org/phoenixframework/PresenceTest.kt +++ b/src/test/kotlin/org/phoenixframework/PresenceTest.kt @@ -196,7 +196,7 @@ class PresenceTest { @Test internal fun `onJoins new presences and onLeaves left presences`() { val newState = fixState - var state = mutableMapOf( + var state: MutableMap>>> = mutableMapOf( "u4" to mutableMapOf("metas" to listOf(mapOf("id" to 4, "phx_ref" to "4")))) val joined: PresenceDiff = mutableMapOf() @@ -245,9 +245,9 @@ class PresenceTest { @Test internal fun `onJoins only newly added metas`() { - var state = mutableMapOf( + var state: MutableMap>>> = mutableMapOf( "u3" to mutableMapOf("metas" to listOf(mapOf("id" to 3, "phx_ref" to "3")))) - val newState = mutableMapOf( + val newState: MutableMap>>> = mutableMapOf( "u3" to mutableMapOf("metas" to listOf( mapOf("id" to 3, "phx_ref" to "3"), mapOf("id" to 3, "phx_ref" to "3.new") @@ -285,9 +285,9 @@ class PresenceTest { @Test internal fun `onLeaves only newly removed metas`() { - val newState = mutableMapOf( + val newState: MutableMap>>> = mutableMapOf( "u3" to mutableMapOf("metas" to listOf(mapOf("id" to 3, "phx_ref" to "3")))) - var state = mutableMapOf( + var state: MutableMap>>> = mutableMapOf( "u3" to mutableMapOf("metas" to listOf( mapOf("id" to 3, "phx_ref" to "3"), mapOf("id" to 3, "phx_ref" to "3.left") @@ -326,13 +326,13 @@ class PresenceTest { @Test internal fun `syncs both joined and left metas`() { - val newState = mutableMapOf( + val newState: MutableMap>>> = mutableMapOf( "u3" to mutableMapOf("metas" to listOf( mapOf("id" to 3, "phx_ref" to "3"), mapOf("id" to 3, "phx_ref" to "3.new") ))) - var state = mutableMapOf( + var state: MutableMap>>> = mutableMapOf( "u3" to mutableMapOf("metas" to listOf( mapOf("id" to 3, "phx_ref" to "3"), mapOf("id" to 3, "phx_ref" to "3.left") @@ -421,13 +421,13 @@ class PresenceTest { @Test internal fun `removes meta while leaving key if other metas exist`() { - var state = mutableMapOf( + var state: MutableMap>>> = mutableMapOf( "u1" to mutableMapOf("metas" to listOf( mapOf("id" to 1, "phx_ref" to "1"), mapOf("id" to 1, "phx_ref" to "1.2") ))) - val leaves = mutableMapOf( + val leaves: MutableMap>>> = mutableMapOf( "u1" to mutableMapOf("metas" to listOf( mapOf("id" to 1, "phx_ref" to "1") ))) diff --git a/src/test/kotlin/org/phoenixframework/SocketTest.kt b/src/test/kotlin/org/phoenixframework/SocketTest.kt index 5fc512f..6bf44a4 100644 --- a/src/test/kotlin/org/phoenixframework/SocketTest.kt +++ b/src/test/kotlin/org/phoenixframework/SocketTest.kt @@ -48,7 +48,7 @@ class SocketTest { internal fun `sets defaults`() { val socket = Socket("wss://localhost:4000/socket") - assertThat(socket.paramsClosure.invoke()).isNull() + assertThat(socket.paramsClosure.invoke()).isEmpty() assertThat(socket.channels).isEmpty() assertThat(socket.sendBuffer).isEmpty() assertThat(socket.ref).isEqualTo(0)