diff --git a/benchmarks/csharp/Program.cs b/benchmarks/csharp/Program.cs index 11df0e36be..fd7db07424 100644 --- a/benchmarks/csharp/Program.cs +++ b/benchmarks/csharp/Program.cs @@ -14,6 +14,8 @@ using StackExchange.Redis; +using static Glide.ConnectionConfiguration; + public static class MainClass { private enum ChosenAction { GET_NON_EXISTING, GET_EXISTING, SET }; @@ -292,7 +294,11 @@ private static async Task run_with_parameters(int total_commands, { var clients = await createClients(clientCount, () => { - var glide_client = new AsyncClient(host, PORT, useTLS); + var config = new StandaloneClientConfigurationBuilder() + .WithAddress(host, PORT) + .WithTlsMode(useTLS ? TlsMode.SecureTls : TlsMode.NoTls) + .Build(); + var glide_client = new AsyncClient(config); return Task.FromResult<(Func>, Func, Action)>( (async (key) => await glide_client.GetAsync(key), async (key, value) => await glide_client.SetAsync(key, value), diff --git a/csharp/lib/AsyncClient.cs b/csharp/lib/AsyncClient.cs index 83e3d4c39b..d92c6ba39e 100644 --- a/csharp/lib/AsyncClient.cs +++ b/csharp/lib/AsyncClient.cs @@ -4,18 +4,23 @@ using System.Runtime.InteropServices; +using static Glide.ConnectionConfiguration; + namespace Glide; public class AsyncClient : IDisposable { #region public methods - public AsyncClient(string host, UInt32 port, bool useTLS) + public AsyncClient(StandaloneClientConfiguration config) { successCallbackDelegate = SuccessCallback; var successCallbackPointer = Marshal.GetFunctionPointerForDelegate(successCallbackDelegate); failureCallbackDelegate = FailureCallback; var failureCallbackPointer = Marshal.GetFunctionPointerForDelegate(failureCallbackDelegate); - clientPointer = CreateClientFfi(host, port, useTLS, successCallbackPointer, failureCallbackPointer); + var configPtr = Marshal.AllocHGlobal(Marshal.SizeOf(typeof(ConnectionRequest))); + Marshal.StructureToPtr(config.ToRequest(), configPtr, false); + clientPointer = CreateClientFfi(configPtr, successCallbackPointer, failureCallbackPointer); + Marshal.FreeHGlobal(configPtr); if (clientPointer == IntPtr.Zero) { throw new Exception("Failed creating a client"); @@ -104,7 +109,7 @@ private void FailureCallback(ulong index) private delegate void IntAction(IntPtr arg); [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "create_client")] - private static extern IntPtr CreateClientFfi(String host, UInt32 port, bool useTLS, IntPtr successCallback, IntPtr failureCallback); + private static extern IntPtr CreateClientFfi(IntPtr config, IntPtr successCallback, IntPtr failureCallback); [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "close_client")] private static extern void CloseClientFfi(IntPtr client); diff --git a/csharp/lib/ConnectionConfiguration.cs b/csharp/lib/ConnectionConfiguration.cs new file mode 100644 index 0000000000..101ca8119e --- /dev/null +++ b/csharp/lib/ConnectionConfiguration.cs @@ -0,0 +1,563 @@ +/** + * Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 + */ + +using System.Runtime.InteropServices; + +namespace Glide; + +public abstract class ConnectionConfiguration +{ + #region Structs and Enums definitions + /// + /// A mirror of ConnectionRequest from connection_request.proto. + /// + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)] + internal struct ConnectionRequest + { + public nuint AddressCount; + public IntPtr Addresses; // ** NodeAddress - array pointer + public TlsMode TlsMode; + public bool ClusterMode; + public uint RequestTimeout; + public ReadFrom ReadFrom; + public RetryStrategy ConnectionRetryStrategy; + public AuthenticationInfo AuthenticationInfo; + public uint DatabaseId; + public Protocol Protocol; + [MarshalAs(UnmanagedType.LPStr)] + public string? ClientName; + } + + /// + /// Represents the address and port of a node in the cluster. + /// + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)] + internal struct NodeAddress + { + [MarshalAs(UnmanagedType.LPStr)] + public string Host; + public ushort Port; + } + + /// + /// Represents the strategy used to determine how and when to reconnect, in case of connection + /// failures. The time between attempts grows exponentially, to the formula rand(0 ... factor * + /// (exponentBase ^ N)), where N is the number of failed attempts. + /// + /// Once the maximum value is reached, that will remain the time between retry attempts until a + /// reconnect attempt is successful. The client will attempt to reconnect indefinitely. + /// + /// + [StructLayout(LayoutKind.Sequential)] + public struct RetryStrategy + { + /// + /// Number of retry attempts that the client should perform when disconnected from the server, + /// where the time between retries increases. Once the retries have reached the maximum value, the + /// time between retries will remain constant until a reconnect attempt is successful. + /// + public uint NumberOfRetries; + /// + /// The multiplier that will be applied to the waiting time between each retry. + /// + public uint Factor; + /// + /// The exponent base configured for the strategy. + /// + public uint ExponentBase; + + public RetryStrategy(uint number_of_retries, uint factor, uint exponent_base) + { + NumberOfRetries = number_of_retries; + Factor = factor; + ExponentBase = exponent_base; + } + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)] + internal struct AuthenticationInfo + { + [MarshalAs(UnmanagedType.LPStr)] + public string? Username; + [MarshalAs(UnmanagedType.LPStr)] + public string Password; + + public AuthenticationInfo(string? username, string password) + { + Username = username; + Password = password; + } + } + + // TODO doc + public enum TlsMode : uint + { + NoTls = 0, + SecureTls = 1, + //InsecureTls = 2, + } + + /// + /// Represents the client's read from strategy. + /// + public enum ReadFrom : uint + { + /// + /// Always get from primary, in order to get the freshest data. + /// + Primary = 0, + /// + /// Spread the requests between all replicas in a round-robin manner. If no replica is available, route the requests to the primary. + /// + PreferReplica = 1, + // TODO: doc or comment out/remove + //LowestLatency = 2, + //AZAffinity = 3, + } + + /// + /// Represents the communication protocol with the server. + /// + public enum Protocol : uint + { + /// + /// Use RESP3 to communicate with the server nodes. + /// + RESP3 = 0, + /// + /// Use RESP2 to communicate with the server nodes. + /// + RESP2 = 1, + } + #endregion + + private static readonly string DEFAULT_HOST = "localhost"; + private static readonly ushort DEFAULT_PORT = 6379; + + /// + /// Basic class which holds common configuration for all types of clients.
+ /// Refer to derived classes for more details: and . + ///
+ public abstract class BaseClientConfiguration + { + internal ConnectionRequest Request; + + internal ConnectionRequest ToRequest() => Request; + } + + /// + /// Configuration for a standalone client. Use to create an instance. + /// + public sealed class StandaloneClientConfiguration : BaseClientConfiguration + { + internal StandaloneClientConfiguration() { } + } + + /// + /// Configuration for a cluster client. Use to create an instance. + /// + public sealed class ClusterClientConfiguration : BaseClientConfiguration + { + internal ClusterClientConfiguration() { } + } + + /// + /// Builder for configuration of common parameters for standalone and cluster client. + /// + /// Derived builder class + public abstract class ClientConfigurationBuilder : IDisposable + where T : ClientConfigurationBuilder, new() + { + internal ConnectionRequest Config; + + protected ClientConfigurationBuilder(bool cluster_mode) + { + Config = new ConnectionRequest { ClusterMode = cluster_mode }; + } + + #region address + private readonly List addresses = new(); + + /// + /// Add a new address to the list.
+ /// See also . + // + // + + protected (string? host, ushort? port) Address + { + set + { + addresses.Add(new NodeAddress + { + Host = value.host ?? DEFAULT_HOST, + Port = value.port ?? DEFAULT_PORT + }); + } + } + + /// + public T WithAddress((string? host, ushort? port) address) + { + Address = (address.host, address.port); + return (T)this; + } + + /// + public T WithAddress((string host, ushort port) address) + { + Address = (address.host, address.port); + return (T)this; + } + + /// + public T WithAddress(string? host, ushort? port) + { + Address = (host, port); + return (T)this; + } + + /// + public T WithAddress(string host, ushort port) + { + Address = (host, port); + return (T)this; + } + + /// + /// Add a new address to the list with default port. + /// + public T WithAddress(string host) + { + Address = (host, DEFAULT_PORT); + return (T)this; + } + + /// + /// Add a new address to the list with default host. + /// + public T WithAddress(ushort port) + { + Address = (DEFAULT_HOST, port); + return (T)this; + } + + /// + /// Syntax sugar helper class for adding addresses. + /// + public sealed class AddressBuilder + { + private readonly ClientConfigurationBuilder owner; + + internal AddressBuilder(ClientConfigurationBuilder owner) + { + this.owner = owner; + } + + /// + public static AddressBuilder operator +(AddressBuilder builder, (string? host, ushort? port) address) + { + builder.owner.WithAddress(address); + return builder; + } + + /// + public static AddressBuilder operator +(AddressBuilder builder, (string host, ushort port) address) + { + builder.owner.WithAddress(address); + return builder; + } + + /// + public static AddressBuilder operator +(AddressBuilder builder, string host) + { + builder.owner.WithAddress(host); + return builder; + } + + /// + public static AddressBuilder operator +(AddressBuilder builder, ushort port) + { + builder.owner.WithAddress(port); + return builder; + } + } + + /// + /// DNS Addresses and ports of known nodes in the cluster. If the server is in cluster mode the + /// list can be partial, as the client will attempt to map out the cluster and find all nodes. If + /// the server is in standalone mode, only nodes whose addresses were provided will be used by the + /// client. + /// + /// For example: + /// [ + /// ("sample-address-0001.use1.cache.amazonaws.com", 6378), + /// ("sample-address-0002.use2.cache.amazonaws.com"), + /// ("sample-address-0002.use3.cache.amazonaws.com", 6380) + /// ] + /// + public AddressBuilder Addresses + { + get + { + return new AddressBuilder(this); + } + set { } // needed for += + } + // TODO possible options : list and array + #endregion + #region TLS + /// + /// Configure whether communication with the server should use Transport Level Security.
+ /// Should match the TLS configuration of the server/cluster, otherwise the connection attempt will fail. + ///
+ public TlsMode TlsMode + { + set + { + Config.TlsMode = value; + } + } + /// + public T WithTlsMode(TlsMode tls_mode) + { + TlsMode = tls_mode; + return (T)this; + } + /// + public T With(TlsMode tls_mode) + { + return WithTlsMode(tls_mode); + } + #endregion + #region Request Timeout + /// + /// The duration in milliseconds that the client should wait for a request to complete. This + /// duration encompasses sending the request, awaiting for a response from the server, and any + /// required reconnections or retries. If the specified timeout is exceeded for a pending request, + /// it will result in a timeout error. If not set, a default value will be used. + /// + public uint RequestTimeout + { + set + { + Config.RequestTimeout = value; + } + } + /// + public T WithRequestTimeout(uint request_timeout) + { + RequestTimeout = request_timeout; + return (T)this; + } + #endregion + #region Read From + /// + /// Configure the client's read from strategy. If not set, will be used. + /// + public ReadFrom ReadFrom + { + set + { + Config.ReadFrom = value; + } + } + /// + public T WithReadFrom(ReadFrom read_from) + { + ReadFrom = read_from; + return (T)this; + } + /// + public T With(ReadFrom read_from) + { + return WithReadFrom(read_from); + } + #endregion + #region Authentication + /// + /// Configure credentials for authentication process. If none are set, the client will not authenticate itself with the server. + /// + /// + /// username The username that will be used for authenticating connections to the Redis servers. If not supplied, "default" will be used.
+ /// password The password that will be used for authenticating connections to the Redis servers. + ///
+ public (string? username, string password) Authentication + { + set + { + Config.AuthenticationInfo = new AuthenticationInfo + ( + value.username, + value.password + ); + } + } + /// + /// Configure credentials for authentication process. If none are set, the client will not authenticate itself with the server. + /// + /// The username that will be used for authenticating connections to the Redis servers. If not supplied, "default" will be used.> + /// The password that will be used for authenticating connections to the Redis servers. + public T WithAuthentication(string? username, string password) + { + Authentication = (username, password); + return (T)this; + } + /// + public T WithAuthentication((string? username, string password) credentials) + { + return WithAuthentication(credentials.username, credentials.password); + } + #endregion + #region Protocol + /// + /// Configure the protocol version to use. If not set, will be used.
+ /// See also . + ///
+ public Protocol ProtocolVersion + { + set + { + Config.Protocol = value; + } + } + + /// + public T WithProtocolVersion(Protocol protocol) + { + ProtocolVersion = protocol; + return (T)this; + } + + /// + public T With(Protocol protocol) + { + ProtocolVersion = protocol; + return (T)this; + } + #endregion + #region Client Name + /// + /// Client name to be used for the client. Will be used with CLIENT SETNAME command during connection establishment. + /// + public string? ClientName + { + set + { + Config.ClientName = value; + } + } + + /// + public T WithClientName(string? clientName) + { + ClientName = clientName; + return (T)this; + } + #endregion + + public void Dispose() => Clean(); + + private void Clean() + { + if (Config.Addresses != IntPtr.Zero) + { + Marshal.FreeHGlobal(Config.Addresses); + Config.Addresses = IntPtr.Zero; + } + } + + internal ConnectionRequest Build() + { + Clean(); // memory leak protection on rebuilding a config from the builder + Config.AddressCount = (uint)addresses.Count; + var address_size = Marshal.SizeOf(typeof(NodeAddress)); + Config.Addresses = Marshal.AllocHGlobal(address_size * addresses.Count); + for (int i = 0; i < addresses.Count; i++) + { + Marshal.StructureToPtr(addresses[i], Config.Addresses + i * address_size, false); + } + return Config; + } + } + + /// + /// Represents the configuration settings for a Standalone Redis client. + /// + public class StandaloneClientConfigurationBuilder : ClientConfigurationBuilder + { + public StandaloneClientConfigurationBuilder() : base(false) { } + + /// + /// Complete the configuration with given settings. + /// + public new StandaloneClientConfiguration Build() => new() { Request = base.Build() }; + + #region DataBase ID + // TODO: not used + /// + /// Index of the logical database to connect to. + /// + public uint DataBaseId + { + set + { + Config.DatabaseId = value; + } + } + /// + public StandaloneClientConfigurationBuilder WithDataBaseId(uint dataBaseId) + { + DataBaseId = dataBaseId; + return this; + } + #endregion + #region Connection Retry Strategy + /// + /// Strategy used to determine how and when to reconnect, in case of connection failures.
+ /// See also + ///
+ public RetryStrategy ConnectionRetryStrategy + { + set + { + Config.ConnectionRetryStrategy = value; + } + } + /// + public StandaloneClientConfigurationBuilder WithConnectionRetryStrategy(RetryStrategy connection_retry_strategy) + { + ConnectionRetryStrategy = connection_retry_strategy; + return this; + } + /// + public StandaloneClientConfigurationBuilder With(RetryStrategy connection_retry_strategy) + { + return WithConnectionRetryStrategy(connection_retry_strategy); + } + /// + /// + /// + /// + public StandaloneClientConfigurationBuilder WithConnectionRetryStrategy(uint number_of_retries, uint factor, uint exponent_base) + { + return WithConnectionRetryStrategy(new RetryStrategy(number_of_retries, factor, exponent_base)); + } + #endregion + } + + /// + /// Represents the configuration settings for a Cluster Redis client.
+ /// Notes: Currently, the reconnection strategy in cluster mode is not configurable, and exponential backoff with fixed values is used. + ///
+ public class ClusterClientConfigurationBuilder : ClientConfigurationBuilder + { + public ClusterClientConfigurationBuilder() : base(true) { } + + /// + /// Complete the configuration with given settings. + /// + public new ClusterClientConfiguration Build() => new() { Request = base.Build() }; + } +} diff --git a/csharp/lib/src/configuration.rs b/csharp/lib/src/configuration.rs new file mode 100644 index 0000000000..06e27d0f1e --- /dev/null +++ b/csharp/lib/src/configuration.rs @@ -0,0 +1,88 @@ +/** + * Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 + */ +use std::ffi::c_char; + +/// A mirror of `ConnectionRequest` from [`connection_request.proto`](https://github.com/aws/glide-for-redis/blob/main/glide-core/src/protobuf/connection_request.proto). +#[repr(C)] +pub struct ConnectionConfig { + pub address_count: usize, + /// Pointer to an array. + pub addresses: *const *const NodeAddress, + pub tls_mode: TlsMode, + pub cluster_mode: bool, + pub request_timeout: u32, + pub read_from: ReadFrom, + pub connection_retry_strategy: ConnectionRetryStrategy, + pub authentication_info: AuthenticationInfo, + pub database_id: u32, + pub protocol: ProtocolVersion, + pub client_name: *const c_char, +} + +/// A mirror of `NodeAddress` from [`connection_request.proto`](https://github.com/aws/glide-for-redis/blob/main/glide-core/src/protobuf/connection_request.proto). +/// Represents the address and port of a node in the cluster. +#[repr(C)] +pub struct NodeAddress { + pub host: *const c_char, + pub port: u16, +} + +/// A mirror of `TlsMode` from [`connection_request.proto`](https://github.com/aws/glide-for-redis/blob/main/glide-core/src/protobuf/connection_request.proto). +#[repr(C)] +pub enum TlsMode { + NoTls = 0, + Secure = 1, + Insecure = 2, +} + +/// A mirror of `ReadFrom` from [`connection_request.proto`](https://github.com/aws/glide-for-redis/blob/main/glide-core/src/protobuf/connection_request.proto). +/// Represents the client's read from strategy. +#[repr(C)] +pub enum ReadFrom { + /// Always get from primary, in order to get the freshest data. + Primary = 0, + /// Spread the requests between all replicas in a round-robin manner. If no replica is available, route the requests to the primary. + PreferReplica = 1, + LowestLatency = 2, + AZAffinity = 3, +} + +/// A mirror of `ConnectionRetryStrategy` from [`connection_request.proto`](https://github.com/aws/glide-for-redis/blob/main/glide-core/src/protobuf/connection_request.proto). +/// Represents the strategy used to determine how and when to reconnect, in case of connection failures. +/// The time between attempts grows exponentially, to the formula +/// ``` +/// rand(0 ... factor * (exponentBase ^ N)) +/// ``` +/// where `N` is the number of failed attempts. +/// +/// Once the maximum value is reached, that will remain the time between retry attempts until a +/// reconnect attempt is successful. The client will attempt to reconnect indefinitely. +#[repr(C)] +pub struct ConnectionRetryStrategy { + /// Number of retry attempts that the client should perform when disconnected from the server, + /// where the time between retries increases. Once the retries have reached the maximum value, the + /// time between retries will remain constant until a reconnect attempt is successful. + pub number_of_retries: u32, + /// The multiplier that will be applied to the waiting time between each retry. + pub factor: u32, + /// The exponent base configured for the strategy. + pub exponent_base: u32, +} + +/// A mirror of `AuthenticationInfo` from [`connection_request.proto`](https://github.com/aws/glide-for-redis/blob/main/glide-core/src/protobuf/connection_request.proto). +#[repr(C)] +pub struct AuthenticationInfo { + pub username: *const c_char, + pub password: *const c_char, +} + +/// A mirror of `ProtocolVersion` from [`connection_request.proto`](https://github.com/aws/glide-for-redis/blob/main/glide-core/src/protobuf/connection_request.proto). +/// Represents the communication protocol with the server. +#[repr(C)] +pub enum ProtocolVersion { + /// Use RESP3 to communicate with the server nodes. + RESP3 = 0, + /// Use RESP2 to communicate with the server nodes. + RESP2 = 1, +} diff --git a/csharp/lib/src/lib.rs b/csharp/lib/src/lib.rs index 495d959598..c951da2fd4 100644 --- a/csharp/lib/src/lib.rs +++ b/csharp/lib/src/lib.rs @@ -1,15 +1,17 @@ /** * Copyright GLIDE-for-Redis Project Contributors - SPDX Identifier: Apache-2.0 */ +pub mod configuration; +use configuration::{ConnectionConfig, NodeAddress, ProtocolVersion, ReadFrom, TlsMode}; + +use glide_core::client::Client as GlideClient; use glide_core::connection_request; -use glide_core::{client::Client as GlideClient, connection_request::NodeAddress}; use redis::{Cmd, FromRedisValue, RedisResult}; use std::{ ffi::{c_void, CStr, CString}, os::raw::c_char, }; -use tokio::runtime::Builder; -use tokio::runtime::Runtime; +use tokio::runtime::{Builder, Runtime}; pub enum Level { Error = 0, @@ -26,37 +28,111 @@ pub struct Client { runtime: Runtime, } -fn create_connection_request( - host: String, - port: u32, - use_tls: bool, +/// Convert raw C string to a rust string. +/// +/// # Safety +/// +/// * `ptr` must be able to be safely casted to a valid `CStr` via `CStr::from_ptr`. See the safety documentation of [`std::ffi::CStr::from_ptr`](https://doc.rust-lang.org/std/ffi/struct.CStr.html#method.from_ptr). +unsafe fn ptr_to_str(ptr: *const c_char) -> String { + if ptr as i64 != 0 { + unsafe { CStr::from_ptr(ptr) }.to_str().unwrap().into() + } else { + "".into() + } +} + +/// Convert raw array pointer to a vector of [`NodeAddress`](NodeAddress)es. +/// +/// # Safety +/// +/// * `len` must not be greater than `isize::MAX`. See the safety documentation of [`std::slice::from_raw_parts`](https://doc.rust-lang.org/std/slice/fn.from_raw_parts.html). +/// * `data` must not be null. +/// * `data` must point to `len` consecutive properly initialized [`NodeAddress`](NodeAddress) structs. +/// * Each [`NodeAddress`](NodeAddress) dereferenced by `data` must contain a valid string pointer. See the safety documentation of [`ptr_to_str`](ptr_to_str). +#[allow(rustdoc::redundant_explicit_links)] +unsafe fn node_addresses_to_proto( + data: *const *const NodeAddress, + len: usize, +) -> Vec { + unsafe { std::slice::from_raw_parts(data as *mut NodeAddress, len) } + .iter() + .map(|addr| { + let mut address_info = connection_request::NodeAddress::new(); + address_info.host = unsafe { ptr_to_str(addr.host) }.into(); + address_info.port = addr.port as u32; + address_info + }) + .collect() +} + +/// Convert connection configuration to a corresponding protobuf object. +/// +/// # Safety +/// +/// * `config` must not be null. +/// * `config` must be a valid pointer to a [`ConnectionConfig`](ConnectionConfig) struct. +/// * Dereferenced [`ConnectionConfig`](ConnectionConfig) struct and all nested structs must contain valid pointers. See the safety documentation of [`node_addresses_to_proto`](node_addresses_to_proto) and [`ptr_to_str`](ptr_to_str). +#[allow(rustdoc::redundant_explicit_links)] +unsafe fn create_connection_request( + config: *const ConnectionConfig, ) -> connection_request::ConnectionRequest { - let mut address_info = NodeAddress::new(); - address_info.host = host.to_string().into(); - address_info.port = port; - let addresses_info = vec![address_info]; let mut connection_request = connection_request::ConnectionRequest::new(); - connection_request.addresses = addresses_info; - connection_request.tls_mode = if use_tls { - connection_request::TlsMode::SecureTls - } else { - connection_request::TlsMode::NoTls + + let config_ref = unsafe { &*config }; + + connection_request.addresses = + unsafe { node_addresses_to_proto(config_ref.addresses, config_ref.address_count) }; + + connection_request.tls_mode = match config_ref.tls_mode { + TlsMode::Secure => connection_request::TlsMode::SecureTls, + TlsMode::Insecure => connection_request::TlsMode::InsecureTls, + TlsMode::NoTls => connection_request::TlsMode::NoTls, + } + .into(); + connection_request.cluster_mode_enabled = config_ref.cluster_mode; + connection_request.request_timeout = config_ref.request_timeout; + + connection_request.read_from = match config_ref.read_from { + ReadFrom::AZAffinity => connection_request::ReadFrom::AZAffinity, + ReadFrom::PreferReplica => connection_request::ReadFrom::PreferReplica, + ReadFrom::Primary => connection_request::ReadFrom::Primary, + ReadFrom::LowestLatency => connection_request::ReadFrom::LowestLatency, + } + .into(); + + let mut retry_strategy = connection_request::ConnectionRetryStrategy::new(); + retry_strategy.number_of_retries = config_ref.connection_retry_strategy.number_of_retries; + retry_strategy.factor = config_ref.connection_retry_strategy.factor; + retry_strategy.exponent_base = config_ref.connection_retry_strategy.exponent_base; + connection_request.connection_retry_strategy = Some(retry_strategy).into(); + + let mut auth_info = connection_request::AuthenticationInfo::new(); + auth_info.username = unsafe { ptr_to_str(config_ref.authentication_info.username) }.into(); + auth_info.password = unsafe { ptr_to_str(config_ref.authentication_info.password) }.into(); + connection_request.authentication_info = Some(auth_info).into(); + + connection_request.database_id = config_ref.database_id; + connection_request.protocol = match config_ref.protocol { + ProtocolVersion::RESP2 => connection_request::ProtocolVersion::RESP2, + ProtocolVersion::RESP3 => connection_request::ProtocolVersion::RESP3, } .into(); + connection_request.client_name = unsafe { ptr_to_str(config_ref.client_name) }.into(); + connection_request } -fn create_client_internal( - host: *const c_char, - port: u32, - use_tls: bool, +/// # Safety +/// +/// * `config` must be a valid [`ConnectionConfig`](ConnectionConfig) pointer. See the safety documentation of [`create_connection_request`](create_connection_request). +#[allow(rustdoc::redundant_explicit_links)] +unsafe fn create_client_internal( + config: *const ConnectionConfig, success_callback: unsafe extern "C" fn(usize, *const c_char) -> (), failure_callback: unsafe extern "C" fn(usize) -> (), ) -> RedisResult { - let host_cstring = unsafe { CStr::from_ptr(host as *mut c_char) }; - let host_string = host_cstring.to_str()?.to_string(); - let request = create_connection_request(host_string, port, use_tls); + let request = unsafe { create_connection_request(config) }; let runtime = Builder::new_multi_thread() .enable_all() .thread_name("GLIDE for Redis C# thread") @@ -71,31 +147,49 @@ fn create_client_internal( }) } -/// Creates a new client to the given address. The success callback needs to copy the given string synchronously, since it will be dropped by Rust once the callback returns. All callbacks should be offloaded to separate threads in order not to exhaust the client's thread pool. +/// Creates a new client with the given configuration. The success callback needs to copy the given string synchronously, since it will be dropped by Rust once the callback returns. All callbacks should be offloaded to separate threads in order not to exhaust the client's thread pool. +/// +/// # Safety +/// +/// * `config` must be a valid [`ConnectionConfig`](ConnectionConfig) pointer. See the safety documentation of [`create_client_internal`](create_client_internal). +#[allow(rustdoc::redundant_explicit_links)] +#[allow(rustdoc::private_intra_doc_links)] #[no_mangle] -pub extern "C" fn create_client( - host: *const c_char, - port: u32, - use_tls: bool, +pub unsafe extern "C" fn create_client( + config: *const ConnectionConfig, success_callback: unsafe extern "C" fn(usize, *const c_char) -> (), failure_callback: unsafe extern "C" fn(usize) -> (), ) -> *const c_void { - match create_client_internal(host, port, use_tls, success_callback, failure_callback) { + match unsafe { create_client_internal(config, success_callback, failure_callback) } { Err(_) => std::ptr::null(), // TODO - log errors Ok(client) => Box::into_raw(Box::new(client)) as *const c_void, } } +/// Closes the given client, deallocating it from the heap. +/// +/// # Safety +/// +/// * `client_ptr` must not be null. +/// * `client_ptr` must be able to be safely casted to a valid `Box` via `Box::from_raw`. See the safety documentation of [`std::boxed::Box::from_raw`](https://doc.rust-lang.org/std/boxed/struct.Box.html#method.from_raw). #[no_mangle] -pub extern "C" fn close_client(client_ptr: *const c_void) { +pub unsafe extern "C" fn close_client(client_ptr: *const c_void) { let client_ptr = unsafe { Box::from_raw(client_ptr as *mut Client) }; let _runtime_handle = client_ptr.runtime.enter(); drop(client_ptr); } -/// Expects that key and value will be kept valid until the callback is called. +/// Execute `SET` command. See [`redis.io`](https://redis.io/commands/set/) for details. +/// +/// # Safety +/// +/// * `client_ptr` must not be null. +/// * `client_ptr` must be able to be safely casted to a valid `Box` via `Box::from_raw`. See the safety documentation of [`std::boxed::Box::from_raw`](https://doc.rust-lang.org/std/boxed/struct.Box.html#method.from_raw). +/// * `key` and `value` must not be null. +/// * `key` and `value` must be able to be safely casted to a valid `CStr` via `CStr::from_ptr`. See the safety documentation of [`std::ffi::CStr::from_ptr`](https://doc.rust-lang.org/std/ffi/struct.CStr.html#method.from_ptr). +/// * `key` and `value` must be kept valid until the callback is called. #[no_mangle] -pub extern "C" fn set( +pub unsafe extern "C" fn set( client_ptr: *const c_void, callback_index: usize, key: *const c_char, @@ -124,10 +218,18 @@ pub extern "C" fn set( }); } -/// Expects that key will be kept valid until the callback is called. If the callback is called with a string pointer, the pointer must -/// be used synchronously, because the string will be dropped after the callback. +/// Execute `GET` command. See [`redis.io`](https://redis.io/commands/get/) for details. +/// +/// # Safety +/// +/// * `client_ptr` must not be null. +/// * `client_ptr` must be able to be safely casted to a valid `Box` via `Box::from_raw`. See the safety documentation of [`std::boxed::Box::from_raw`](https://doc.rust-lang.org/std/boxed/struct.Box.html#method.from_raw). +/// * `key` must not be null. +/// * `key` must be able to be safely casted to a valid `CStr` via `CStr::from_ptr`. See the safety documentation of [`std::ffi::CStr::from_ptr`](https://doc.rust-lang.org/std/ffi/struct.CStr.html#method.from_ptr). +/// * `key` must be kept valid until the callback is called. +/// * If the callback is called with a string pointer, the pointer must be used synchronously, because the string will be dropped after the callback. #[no_mangle] -pub extern "C" fn get(client_ptr: *const c_void, callback_index: usize, key: *const c_char) { +pub unsafe extern "C" fn get(client_ptr: *const c_void, callback_index: usize, key: *const c_char) { let client = unsafe { Box::leak(Box::from_raw(client_ptr as *mut Client)) }; // The safety of this needs to be ensured by the calling code. Cannot dispose of the pointer before all operations have completed. let ptr_address = client_ptr as usize; @@ -183,10 +285,12 @@ impl From for logger_core::Level { } } +/// # Safety +/// +/// * `message` must not be null. +/// * `message` must be able to be safely casted to a valid `CStr` via `CStr::from_ptr`. See the safety documentation of [`std::ffi::CStr::from_ptr`](https://doc.rust-lang.org/std/ffi/struct.CStr.html#method.from_ptr). #[no_mangle] #[allow(improper_ctypes_definitions)] -/// # Safety -/// Unsafe function because creating string from pointer pub unsafe extern "C" fn log( log_level: Level, log_identifier: *const c_char, @@ -205,10 +309,12 @@ pub unsafe extern "C" fn log( } } +/// # Safety +/// +/// * `file_name` must not be null. +/// * `file_name` must be able to be safely casted to a valid `CStr` via `CStr::from_ptr`. See the safety documentation of [`std::ffi::CStr::from_ptr`](https://doc.rust-lang.org/std/ffi/struct.CStr.html#method.from_ptr). #[no_mangle] #[allow(improper_ctypes_definitions)] -/// # Safety -/// Unsafe function because creating string from pointer pub unsafe extern "C" fn init(level: Option, file_name: *const c_char) -> Level { let file_name_as_str; unsafe { diff --git a/csharp/tests/AsyncClientTests.cs b/csharp/tests/AsyncClientTests.cs index e9adfdf97b..80645b49d3 100644 --- a/csharp/tests/AsyncClientTests.cs +++ b/csharp/tests/AsyncClientTests.cs @@ -6,6 +6,8 @@ namespace tests; using Glide; +using static Glide.ConnectionConfiguration; + // TODO - need to start a new redis server for each test? public class AsyncClientTests { @@ -24,10 +26,16 @@ private async Task GetAndSetRandomValues(AsyncClient client) Assert.That(result, Is.EqualTo(value)); } + private StandaloneClientConfiguration GetConfig() + { + return new StandaloneClientConfigurationBuilder() + .WithAddress("localhost", 6379).Build(); + } + [Test] public async Task GetReturnsLastSet() { - using (var client = new AsyncClient("localhost", 6379, false)) + using (var client = new AsyncClient(GetConfig())) { await GetAndSetRandomValues(client); } @@ -36,7 +44,7 @@ public async Task GetReturnsLastSet() [Test] public async Task GetAndSetCanHandleNonASCIIUnicode() { - using (var client = new AsyncClient("localhost", 6379, false)) + using (var client = new AsyncClient(GetConfig())) { var key = Guid.NewGuid().ToString(); var value = "שלום hello 汉字"; @@ -49,7 +57,7 @@ public async Task GetAndSetCanHandleNonASCIIUnicode() [Test] public async Task GetReturnsNull() { - using (var client = new AsyncClient("localhost", 6379, false)) + using (var client = new AsyncClient(GetConfig())) { var result = await client.GetAsync(Guid.NewGuid().ToString()); Assert.That(result, Is.EqualTo(null)); @@ -59,7 +67,7 @@ public async Task GetReturnsNull() [Test] public async Task GetReturnsEmptyString() { - using (var client = new AsyncClient("localhost", 6379, false)) + using (var client = new AsyncClient(GetConfig())) { var key = Guid.NewGuid().ToString(); var value = ""; @@ -72,7 +80,7 @@ public async Task GetReturnsEmptyString() [Test] public async Task HandleVeryLargeInput() { - using (var client = new AsyncClient("localhost", 6379, false)) + using (var client = new AsyncClient(GetConfig())) { var key = Guid.NewGuid().ToString(); var value = Guid.NewGuid().ToString(); @@ -92,7 +100,7 @@ public async Task HandleVeryLargeInput() [Test] public void ConcurrentOperationsWork() { - using (var client = new AsyncClient("localhost", 6379, false)) + using (var client = new AsyncClient(GetConfig())) { var operations = new List();