From 2c47af6b290b5ddb2d3827d8e993680eb3cc9557 Mon Sep 17 00:00:00 2001 From: seefs001 Date: Fri, 27 Sep 2024 20:45:41 +0800 Subject: [PATCH] add xai --- examples/xai_example/main.go | 76 +++++ x/x.go | 27 ++ xai/xai.go | 581 +++++++++++++++++++++++++++++++++++ xhttpc/xhttpc.go | 447 +++++++++++++++++++++------ xhttpc/xhttpc_test.go | 21 +- 5 files changed, 1045 insertions(+), 107 deletions(-) create mode 100644 examples/xai_example/main.go create mode 100644 xai/xai.go diff --git a/examples/xai_example/main.go b/examples/xai_example/main.go new file mode 100644 index 0000000..9eff39e --- /dev/null +++ b/examples/xai_example/main.go @@ -0,0 +1,76 @@ +package main + +import ( + "context" + + "github.com/seefs001/xox/xai" + "github.com/seefs001/xox/xenv" + "github.com/seefs001/xox/xlog" +) + +func main() { + xenv.Load() + client := xai.NewOpenAIClient() + + // Text generation (non-streaming) + response, err := client.QuickGenerateText(context.Background(), []string{"Hello, world!"}, xai.WithTextModel(xai.ModelClaude35Sonnet)) + if err != nil { + xlog.Error("Error generating text", "error", err) + return + } + xlog.Info("Text generation response:") + xlog.Info(response) + + // Text generation (streaming) + xlog.Info("Streaming text generation:") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + textChan, errChan := client.QuickGenerateTextStream(ctx, []string{"Hello, world!"}, xai.WithTextModel(xai.ModelClaude35Sonnet)) + + streamFinished := make(chan struct{}) + go func() { + defer close(streamFinished) + for { + select { + case text, ok := <-textChan: + if !ok { + xlog.Info("Stream finished") + return + } + xlog.Info(text) + case err, ok := <-errChan: + if !ok { + return + } + if err != nil { + xlog.Error("Error generating text stream", "error", err) + return + } + case <-ctx.Done(): + return + } + } + }() + + <-streamFinished + + // Image generation + xlog.Info("Image generation:") + imageURLs, err := client.GenerateImage(context.Background(), "A beautiful sunset over the ocean", xai.WithImageModel(xai.DefaultImageModel)) + if err != nil { + xlog.Error("Error generating image", "error", err) + return + } + for i, url := range imageURLs { + xlog.Infof("Image %d URL: %s", i+1, url) + } + + // Embedding generation + xlog.Info("Embedding generation:") + embeddings, err := client.CreateEmbeddings(context.Background(), []string{"Hello, world!"}, xai.DefaultEmbeddingModel) + if err != nil { + xlog.Error("Error creating embeddings", "error", err) + return + } + xlog.Infof("Embedding for 'Hello, world!': %v", embeddings[0][:5]) // Print first 5 values of the embedding +} diff --git a/x/x.go b/x/x.go index c336011..13b5a4f 100644 --- a/x/x.go +++ b/x/x.go @@ -3,12 +3,14 @@ package x import ( "context" "crypto/rand" + "encoding/base64" "encoding/json" "errors" "fmt" "math/big" "reflect" "runtime/debug" + "strings" "sync" "time" ) @@ -1010,3 +1012,28 @@ func IsZero[T any](value T) bool { } return false } + +// IsImageURL checks if a string is a valid image URL +func IsImageURL(s string) bool { + // Check if the URL starts with http:// or https:// + if !strings.HasPrefix(s, "http://") && !strings.HasPrefix(s, "https://") { + return false + } + + // Simple check for common image file extensions + // You might want to implement a more robust check based on your requirements + extensions := []string{".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp"} + lowered := strings.ToLower(s) + for _, ext := range extensions { + if strings.HasSuffix(lowered, ext) { + return true + } + } + return false +} + +// IsBase64 checks if a string is a valid base64 encoded string +func IsBase64(s string) bool { + _, err := base64.StdEncoding.DecodeString(s) + return err == nil +} diff --git a/xai/xai.go b/xai/xai.go new file mode 100644 index 0000000..c3dbd3d --- /dev/null +++ b/xai/xai.go @@ -0,0 +1,581 @@ +package xai + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + "time" + + "github.com/seefs001/xox/x" + "github.com/seefs001/xox/xhttpc" +) + +// OpenAIClient represents a client for interacting with the OpenAI API +type OpenAIClient struct { + baseURL string + apiKey string + httpClient *xhttpc.Client + model string + debug bool +} + +// OpenAIClientOption is a function type for configuring the OpenAIClient +type OpenAIClientOption func(*OpenAIClient) + +// TextGenerationOptions contains options for generating text +type TextGenerationOptions struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + SystemPrompt string `json:"system"` + Messages []Message `json:"messages"` + IsStreaming bool `json:"stream"` + ObjectSchema string `json:"object_schema"` + Temperature float64 `json:"temperature,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + TopP float64 `json:"top_p,omitempty"` + N int `json:"n,omitempty"` + ChunkSize int `json:"chunk_size,omitempty"` +} + +// Message represents a single message in a conversation +type Message struct { + Role string `json:"role"` + Content string `json:"content"` + Image string `json:"image,omitempty"` +} + +// ChatCompletionResponse represents the response from the chat completion API +type ChatCompletionResponse struct { + ID string `json:"id"` + ObjectType string `json:"object"` + CreatedAt int64 `json:"created"` + ModelName string `json:"model"` + UsageInfo Usage `json:"usage"` + Choices []Choice `json:"choices"` +} + +// Usage represents the token usage information +type Usage struct { + PromptTokenCount int `json:"prompt_tokens"` + CompletionTokenCount int `json:"completion_tokens"` + TotalTokenCount int `json:"total_tokens"` +} + +// Choice represents a single choice in the API response +type Choice struct { + Message Message `json:"message"` + FinishReason string `json:"finish_reason"` + Index int `json:"index"` +} + +// StreamResponse represents a single chunk of the streaming response +type StreamResponse struct { + ID string `json:"id"` + ObjectType string `json:"object"` + CreatedAt int64 `json:"created"` + ModelName string `json:"model"` + Choices []StreamChoice `json:"choices"` +} + +// StreamChoice represents a single choice in the streaming response +type StreamChoice struct { + Delta StreamDelta `json:"delta"` + Index int `json:"index"` + FinishReason string `json:"finish_reason"` +} + +// StreamDelta represents the delta content in a streaming response +type StreamDelta struct { + Content string `json:"content"` +} + +// EmbeddingRequest represents the request for creating embeddings +type EmbeddingRequest struct { + Model string `json:"model"` + Input []string `json:"input"` +} + +// EmbeddingResponse represents the response from the embeddings API +type EmbeddingResponse struct { + Object string `json:"object"` + Data []struct { + Object string `json:"object"` + Embedding []float64 `json:"embedding"` + Index int `json:"index"` + } `json:"data"` + Model string `json:"model"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` +} + +// ImageGenerationRequest represents the request for generating images +type ImageGenerationRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` +} + +// ImageGenerationResponse represents the response from the image generation API +type ImageGenerationResponse struct { + Created int64 `json:"created"` + Data []struct { + URL string `json:"url,omitempty"` + B64JSON string `json:"b64_json,omitempty"` + } `json:"data"` +} + +// Constants for message roles and models +const ( + MessageRoleSystem = "system" + MessageRoleUser = "user" + MessageRoleAssistant = "assistant" + + DefaultModel = "gpt-3.5-turbo" + ModelGPT4o = "gpt-4o" + ModelClaude35Sonnet = "claude-3-5-sonnet-20240620" +) + +// Environment variable keys +const ( + EnvOpenAIAPIKey = "OPENAI_API_KEY" + EnvOpenAIBaseURL = "OPENAI_API_BASE" +) + +// API endpoints +const ( + DefaultBaseURL = "https://api.openai.com/v1" + ChatCompletionsURL = "/chat/completions" + EmbeddingsURL = "/embeddings" + ImageGenerationURL = "/images/generations" + DefaultEmbeddingModel = "text-embedding-ada-002" + DefaultImageModel = "dall-e-3" + DefaultChunkSize = 100 // Default chunk size for streaming +) + +// WithBaseURL sets the base URL for the OpenAI API +func WithBaseURL(url string) OpenAIClientOption { + return func(c *OpenAIClient) { + c.baseURL = url + } +} + +// WithAPIKey sets the API key for authentication +func WithAPIKey(key string) OpenAIClientOption { + return func(c *OpenAIClient) { + c.apiKey = key + } +} + +// WithHTTPClient sets a custom HTTP client +func WithHTTPClient(client *xhttpc.Client) OpenAIClientOption { + return func(c *OpenAIClient) { + c.httpClient = client + } +} + +// WithModel sets the default model for the OpenAI client +func WithModel(model string) OpenAIClientOption { + return func(c *OpenAIClient) { + c.model = model + } +} + +// WithDebug enables or disables debug mode +func WithDebug(debug bool) OpenAIClientOption { + return func(c *OpenAIClient) { + c.debug = debug + } +} + +// NewOpenAIClient creates a new OpenAIClient with the given options +func NewOpenAIClient(options ...OpenAIClientOption) *OpenAIClient { + client := &OpenAIClient{ + baseURL: DefaultBaseURL, + httpClient: xhttpc.NewClient( + xhttpc.WithTimeout(30 * time.Second), + ), + model: DefaultModel, + debug: false, + } + + client.loadEnvironmentVariables() + + for _, option := range options { + option(client) + } + + if client.debug { + client.httpClient.SetDebug(true) + } + + return client +} + +func (c *OpenAIClient) loadEnvironmentVariables() { + if apiKey := os.Getenv(EnvOpenAIAPIKey); apiKey != "" { + c.apiKey = apiKey + } + if baseURL := os.Getenv(EnvOpenAIBaseURL); baseURL != "" { + c.baseURL = baseURL + } +} + +// validateTextGenerationOptions checks if the provided options are valid +func validateTextGenerationOptions(options *TextGenerationOptions) error { + hasMessages := len(options.Messages) > 0 + hasPromptOrSystem := options.Prompt != "" || options.SystemPrompt != "" + + if hasMessages && hasPromptOrSystem { + return fmt.Errorf("either 'Messages' or 'Prompt'/'SystemPrompt' should be provided, not both") + } + + if !hasMessages && !hasPromptOrSystem { + return fmt.Errorf("either 'Messages' or 'Prompt'/'SystemPrompt' must be provided") + } + + return nil +} + +// GenerateText generates text based on the provided options +func (c *OpenAIClient) GenerateText(ctx context.Context, options TextGenerationOptions) (string, error) { + if err := validateTextGenerationOptions(&options); err != nil { + return "", err + } + + if x.IsEmpty(options.Model) { + options.Model = c.model + } + + requestBody := map[string]interface{}{ + "model": options.Model, + "messages": options.Messages, + } + + // Only add non-default parameters + if options.Temperature != 0 { + requestBody["temperature"] = options.Temperature + } + if options.MaxTokens != 0 { + requestBody["max_tokens"] = options.MaxTokens + } + if options.TopP != 0 { + requestBody["top_p"] = options.TopP + } + if options.N != 0 { + requestBody["n"] = options.N + } + if options.IsStreaming { + requestBody["stream"] = options.IsStreaming + } + + resp, err := c.sendRequest(ctx, ChatCompletionsURL, requestBody) + if err != nil { + return "", err + } + defer resp.Body.Close() + + var result ChatCompletionResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("error decoding response: %w", err) + } + + if len(result.Choices) == 0 { + return "", fmt.Errorf("no choices returned from API") + } + + return result.Choices[0].Message.Content, nil +} + +// GenerateTextStream generates text in a streaming fashion +func (c *OpenAIClient) GenerateTextStream(ctx context.Context, options TextGenerationOptions) (<-chan string, <-chan error) { + textChan := make(chan string) + errChan := make(chan error, 1) + + go func() { + defer close(textChan) + defer close(errChan) + + if err := validateTextGenerationOptions(&options); err != nil { + errChan <- err + return + } + + if x.IsEmpty(options.Model) { + options.Model = c.model + } + + requestBody := map[string]interface{}{ + "model": options.Model, + "messages": options.Messages, + "stream": true, + } + + // Only add non-default parameters + if options.Temperature != 0 { + requestBody["temperature"] = options.Temperature + } + if options.MaxTokens != 0 { + requestBody["max_tokens"] = options.MaxTokens + } + if options.TopP != 0 { + requestBody["top_p"] = options.TopP + } + if options.N != 0 { + requestBody["n"] = options.N + } + + resp, err := c.sendRequest(ctx, ChatCompletionsURL, requestBody) + if err != nil { + errChan <- err + return + } + defer resp.Body.Close() + + c.handleStreamResponse(ctx, resp, textChan, errChan, options.ChunkSize) + }() + + return textChan, errChan +} + +func (c *OpenAIClient) sendRequest(ctx context.Context, endpoint string, body interface{}) (*http.Response, error) { + resp, err := c.httpClient. + SetBaseURL(c.baseURL). + SetBearerToken(c.apiKey). + PostJSON(ctx, endpoint, body) + + if err != nil { + return nil, fmt.Errorf("error sending request: %w", err) + } + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + return resp, nil +} + +func (c *OpenAIClient) handleStreamResponse(ctx context.Context, resp *http.Response, textChan chan<- string, errChan chan<- error, chunkSize int) { + reader := bufio.NewReader(resp.Body) + buffer := strings.Builder{} + + if chunkSize <= 0 { + chunkSize = DefaultChunkSize + } + + for { + select { + case <-ctx.Done(): + return + default: + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + if buffer.Len() > 0 { + textChan <- buffer.String() + } + return + } + errChan <- fmt.Errorf("error reading stream: %w", err) + return + } + + line = strings.TrimSpace(line) + if line == "" || line == "data: [DONE]" { + continue + } + + if !strings.HasPrefix(line, "data: ") { + errChan <- fmt.Errorf("unexpected line format: %s", line) + return + } + + data := strings.TrimPrefix(line, "data: ") + var streamResponse StreamResponse + if err := json.Unmarshal([]byte(data), &streamResponse); err != nil { + errChan <- fmt.Errorf("error unmarshaling stream data: %w", err) + return + } + + if len(streamResponse.Choices) > 0 && streamResponse.Choices[0].Delta.Content != "" { + buffer.WriteString(streamResponse.Choices[0].Delta.Content) + if buffer.Len() > 0 { + select { + case textChan <- buffer.String(): + buffer.Reset() + case <-ctx.Done(): + return + } + } + } + } + } +} + +// prepareTextGenerationOptions prepares the options for text generation +func (c *OpenAIClient) prepareTextGenerationOptions(prompt []string, options ...func(*TextGenerationOptions)) TextGenerationOptions { + opts := TextGenerationOptions{ + Model: c.model, + Messages: []Message{}, + ChunkSize: DefaultChunkSize, + } + + for _, option := range options { + option(&opts) + } + + if opts.SystemPrompt != "" { + opts.Messages = append(opts.Messages, Message{Role: MessageRoleSystem, Content: opts.SystemPrompt}) + } + + for i, content := range prompt { + role := MessageRoleUser + if i%2 != 0 { + role = MessageRoleAssistant + } + + message := Message{Role: role} + if x.IsImageURL(content) || x.IsBase64(content) { + message.Image = content + } else { + message.Content = content + } + opts.Messages = append(opts.Messages, message) + } + return opts +} + +// QuickGenerateText is a convenience method for generating text +func (c *OpenAIClient) QuickGenerateText(ctx context.Context, prompt []string, options ...func(*TextGenerationOptions)) (string, error) { + opts := c.prepareTextGenerationOptions(prompt, options...) + return c.GenerateText(ctx, opts) +} + +// QuickGenerateTextStream is a convenience method for generating text in a streaming fashion +func (c *OpenAIClient) QuickGenerateTextStream(ctx context.Context, prompt []string, options ...func(*TextGenerationOptions)) (<-chan string, <-chan error) { + opts := c.prepareTextGenerationOptions(prompt, options...) + return c.GenerateTextStream(ctx, opts) +} + +// WithTextModel sets the model for text generation +func WithTextModel(model string) func(*TextGenerationOptions) { + return func(opts *TextGenerationOptions) { + opts.Model = model + } +} + +// WithChunkSize sets the chunk size for streaming text generation +func WithChunkSize(size int) func(*TextGenerationOptions) { + return func(opts *TextGenerationOptions) { + opts.ChunkSize = size + } +} + +// CreateEmbeddings generates embeddings for the given input +func (c *OpenAIClient) CreateEmbeddings(ctx context.Context, input []string, model string) ([][]float64, error) { + if model == "" { + model = DefaultEmbeddingModel + } + + requestBody := EmbeddingRequest{ + Model: model, + Input: input, + } + + resp, err := c.sendRequest(ctx, EmbeddingsURL, requestBody) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var result EmbeddingResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("error decoding response: %w", err) + } + + embeddings := make([][]float64, len(result.Data)) + for i, data := range result.Data { + embeddings[i] = data.Embedding + } + + return embeddings, nil +} + +// GenerateImage generates an image based on the provided prompt +func (c *OpenAIClient) GenerateImage(ctx context.Context, prompt string, options ...func(*ImageGenerationRequest)) ([]string, error) { + requestBody := ImageGenerationRequest{ + Model: DefaultImageModel, + Prompt: prompt, + } + + for _, option := range options { + option(&requestBody) + } + + resp, err := c.sendRequest(ctx, ImageGenerationURL, requestBody) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + var result ImageGenerationResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("error decoding response: %w", err) + } + + urls := make([]string, len(result.Data)) + for i, data := range result.Data { + if data.URL != "" { + urls[i] = data.URL + } else { + urls[i] = data.B64JSON + } + } + + return urls, nil +} + +// WithImageModel sets the model for image generation +func WithImageModel(model string) func(*ImageGenerationRequest) { + return func(opts *ImageGenerationRequest) { + opts.Model = model + } +} + +// WithImageSize sets the size for image generation +func WithImageSize(size string) func(*ImageGenerationRequest) { + return func(opts *ImageGenerationRequest) { + opts.Size = size + } +} + +// WithImageQuality sets the quality for image generation +func WithImageQuality(quality string) func(*ImageGenerationRequest) { + return func(opts *ImageGenerationRequest) { + opts.Quality = quality + } +} + +// WithImageResponseFormat sets the response format for image generation +func WithImageResponseFormat(format string) func(*ImageGenerationRequest) { + return func(opts *ImageGenerationRequest) { + opts.ResponseFormat = format + } +} + +// WithImageCount sets the number of images to generate +func WithImageCount(n int) func(*ImageGenerationRequest) { + return func(opts *ImageGenerationRequest) { + opts.N = n + } +} diff --git a/xhttpc/xhttpc.go b/xhttpc/xhttpc.go index c2cba5d..774bd60 100644 --- a/xhttpc/xhttpc.go +++ b/xhttpc/xhttpc.go @@ -3,6 +3,7 @@ package xhttpc import ( "bytes" "context" + "encoding/base64" "encoding/json" "fmt" "io" @@ -17,14 +18,40 @@ import ( "github.com/seefs001/xox/xlog" ) +const ( + defaultTimeout = 10 * time.Second + defaultDialTimeout = 5 * time.Second + defaultKeepAlive = 30 * time.Second + defaultMaxIdleConns = 100 + defaultIdleConnTimeout = 90 * time.Second + defaultTLSHandshakeTimeout = 5 * time.Second + defaultExpectContinueTimeout = 1 * time.Second + defaultRetryCount = 3 + defaultMaxBackoff = 30 * time.Second + defaultMaxBodyLogSize = 1024 // 1KB + defaultUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" +) + // Client is a high-performance HTTP client with sensible defaults and advanced features type Client struct { - client *http.Client - retryCount int - maxBackoff time.Duration - userAgent string - debug bool - logOptions LogOptions + client *http.Client + retryConfig RetryConfig + userAgent string + debug bool + logOptions LogOptions + baseURL string + headers http.Header + cookies []*http.Cookie + queryParams url.Values + formData url.Values + authToken string +} + +// RetryConfig contains retry-related configuration +type RetryConfig struct { + Enabled bool + Count int + MaxBackoff time.Duration } // LogOptions contains configuration for debug logging @@ -42,30 +69,41 @@ type ClientOption func(*Client) func NewClient(options ...ClientOption) *Client { c := &Client{ client: &http.Client{ - Timeout: 10 * time.Second, + Timeout: defaultTimeout, Transport: &http.Transport{ DialContext: (&net.Dialer{ - Timeout: 5 * time.Second, - KeepAlive: 30 * time.Second, + Timeout: defaultDialTimeout, + KeepAlive: defaultKeepAlive, }).DialContext, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 5 * time.Second, - ExpectContinueTimeout: 1 * time.Second, + MaxIdleConns: defaultMaxIdleConns, + IdleConnTimeout: defaultIdleConnTimeout, + TLSHandshakeTimeout: defaultTLSHandshakeTimeout, + ExpectContinueTimeout: defaultExpectContinueTimeout, }, }, - retryCount: 3, - maxBackoff: 30 * time.Second, - userAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36", - debug: false, + retryConfig: RetryConfig{ + Enabled: false, + Count: defaultRetryCount, + MaxBackoff: defaultMaxBackoff, + }, + userAgent: defaultUserAgent, + debug: false, logOptions: LogOptions{ - LogHeaders: false, + LogHeaders: true, LogBody: true, LogResponse: true, - MaxBodyLogSize: 1024, // Default to 1KB + MaxBodyLogSize: defaultMaxBodyLogSize, }, + headers: make(http.Header), + queryParams: make(url.Values), + formData: make(url.Values), } + // Set default headers + c.headers.Set("User-Agent", c.userAgent) + c.headers.Set("Accept", "application/json") + c.headers.Set("Accept-Language", "en-US,en;q=0.9") + for _, option := range options { option(c) } @@ -80,17 +118,10 @@ func WithTimeout(timeout time.Duration) ClientOption { } } -// WithRetryCount sets the number of retries for failed requests -func WithRetryCount(count int) ClientOption { - return func(c *Client) { - c.retryCount = count - } -} - -// WithMaxBackoff sets the maximum backoff duration -func WithMaxBackoff(maxBackoff time.Duration) ClientOption { +// WithRetryConfig sets the retry configuration +func WithRetryConfig(config RetryConfig) ClientOption { return func(c *Client) { - c.maxBackoff = maxBackoff + c.retryConfig = config } } @@ -98,15 +129,16 @@ func WithMaxBackoff(maxBackoff time.Duration) ClientOption { func WithUserAgent(userAgent string) ClientOption { return func(c *Client) { c.userAgent = userAgent + c.headers.Set("User-Agent", userAgent) } } -// WithHTTPProxy sets an HTTP proxy for the client -func WithHTTPProxy(proxyURL string) ClientOption { +// WithProxy sets a proxy for the client +func WithProxy(proxyURL string) ClientOption { return func(c *Client) { proxy, err := url.Parse(proxyURL) if err != nil { - // Handle error (e.g., log it) + xlog.Error("Failed to parse proxy URL", "error", err) return } transport, ok := c.client.Transport.(*http.Transport) @@ -118,24 +150,6 @@ func WithHTTPProxy(proxyURL string) ClientOption { } } -// WithSOCKS5Proxy sets a SOCKS5 proxy for the client -func WithSOCKS5Proxy(proxyAddr string) ClientOption { - return func(c *Client) { - transport, ok := c.client.Transport.(*http.Transport) - if !ok { - transport = &http.Transport{} - } - transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - dialer := &net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - } - return dialer.DialContext(ctx, "tcp", proxyAddr) - } - c.client.Transport = transport - } -} - // WithDebug enables or disables debug mode func WithDebug(debug bool) ClientOption { return func(c *Client) { @@ -143,6 +157,11 @@ func WithDebug(debug bool) ClientOption { } } +// SetDebug enables or disables debug mode +func (c *Client) SetDebug(debug bool) { + c.debug = debug +} + // WithLogOptions sets the logging options for debug mode func WithLogOptions(options LogOptions) ClientOption { return func(c *Client) { @@ -150,27 +169,202 @@ func WithLogOptions(options LogOptions) ClientOption { } } +// SetBaseURL sets the base URL for all requests +func (c *Client) SetBaseURL(url string) *Client { + c.baseURL = url + return c +} + +// SetHeader sets a header for all requests +func (c *Client) SetHeader(key, value string) *Client { + c.headers.Set(key, value) + return c +} + +// SetHeaders sets multiple headers for all requests +func (c *Client) SetHeaders(headers map[string]string) *Client { + for k, v := range headers { + c.headers.Set(k, v) + } + return c +} + +// AddCookie adds a cookie for all requests +func (c *Client) AddCookie(cookie *http.Cookie) *Client { + c.cookies = append(c.cookies, cookie) + return c +} + +// SetQueryParam sets a query parameter for all requests +func (c *Client) SetQueryParam(key, value string) *Client { + c.queryParams.Set(key, value) + return c +} + +// SetQueryParams sets multiple query parameters for all requests +func (c *Client) SetQueryParams(params map[string]string) *Client { + for k, v := range params { + c.queryParams.Set(k, v) + } + return c +} + +// SetFormData sets form data for all requests +func (c *Client) SetFormData(data map[string]string) *Client { + for k, v := range data { + c.formData.Set(k, v) + } + return c +} + +// SetBasicAuth sets basic auth for all requests +func (c *Client) SetBasicAuth(username, password string) *Client { + c.SetHeader("Authorization", "Basic "+basicAuth(username, password)) + return c +} + +// SetBearerToken sets bearer auth token for all requests +func (c *Client) SetBearerToken(token string) *Client { + c.authToken = token + return c +} + +// AddQueryParam adds a query parameter for all requests +func (c *Client) AddQueryParam(key string, value interface{}) *Client { + c.queryParams.Add(key, fmt.Sprintf("%v", value)) + return c +} + +// AddFormDataField adds a form data field for all requests +func (c *Client) AddFormDataField(key string, value interface{}) *Client { + c.formData.Add(key, fmt.Sprintf("%v", value)) + return c +} + +// Request performs an HTTP request +func (c *Client) Request(ctx context.Context, method, url string, body interface{}) (*http.Response, error) { + return c.doRequest(ctx, method, url, body) +} + // Get performs a GET request func (c *Client) Get(ctx context.Context, url string) (*http.Response, error) { - return c.doRequest(ctx, http.MethodGet, url, nil) + return c.Request(ctx, http.MethodGet, url, nil) } // Post performs a POST request func (c *Client) Post(ctx context.Context, url string, body interface{}) (*http.Response, error) { - return c.doRequest(ctx, http.MethodPost, url, body) + return c.Request(ctx, http.MethodPost, url, body) } // Put performs a PUT request func (c *Client) Put(ctx context.Context, url string, body interface{}) (*http.Response, error) { - return c.doRequest(ctx, http.MethodPut, url, body) + return c.Request(ctx, http.MethodPut, url, body) +} + +// Patch performs a PATCH request +func (c *Client) Patch(ctx context.Context, url string, body interface{}) (*http.Response, error) { + return c.Request(ctx, http.MethodPatch, url, body) } // Delete performs a DELETE request func (c *Client) Delete(ctx context.Context, url string) (*http.Response, error) { - return c.doRequest(ctx, http.MethodDelete, url, nil) + return c.Request(ctx, http.MethodDelete, url, nil) +} + +// Do sends an HTTP request and returns an HTTP response, following +// policy (such as redirects, cookies, auth) as configured on the client. +func (c *Client) Do(req *http.Request) (*http.Response, error) { + return c.doRequest(req.Context(), req.Method, req.URL.String(), req.Body) } +// Head issues a HEAD request to the specified URL. +func (c *Client) Head(url string) (resp *http.Response, err error) { + return c.Request(context.Background(), http.MethodHead, url, nil) +} + +// PostForm issues a POST request to the specified URL, with data's keys and +// values URL-encoded as the request body. +func (c *Client) PostForm(url string, data url.Values) (resp *http.Response, err error) { + return c.Post(context.Background(), url, data) +} + +// PostJSON performs a POST request with a JSON body +func (c *Client) PostJSON(ctx context.Context, url string, body interface{}) (*http.Response, error) { + jsonBody, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to marshal JSON body: %w", err) + } + return c.Post(ctx, url, bytes.NewReader(jsonBody)) +} + +// PutJSON performs a PUT request with a JSON body +func (c *Client) PutJSON(ctx context.Context, url string, body interface{}) (*http.Response, error) { + jsonBody, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to marshal JSON body: %w", err) + } + return c.Put(ctx, url, bytes.NewReader(jsonBody)) +} + +// PatchJSON performs a PATCH request with a JSON body +func (c *Client) PatchJSON(ctx context.Context, url string, body interface{}) (*http.Response, error) { + jsonBody, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to marshal JSON body: %w", err) + } + return c.Patch(ctx, url, bytes.NewReader(jsonBody)) +} + +// startTimeKey is the key used to store the start time in the request context +var startTimeKey = struct{}{} + func (c *Client) doRequest(ctx context.Context, method, reqURL string, body interface{}) (*http.Response, error) { + fullURL := c.baseURL + reqURL + req, err := c.createRequest(ctx, method, fullURL, body) + if err != nil { + return nil, err + } + + if c.debug { + c.logRequest(req) + // Set the start time before sending the request + ctx = context.WithValue(ctx, startTimeKey, time.Now()) + req = req.WithContext(ctx) + } + + var resp *http.Response + if c.retryConfig.Enabled { + operation := func() error { + var err error + resp, err = c.client.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + if resp.StatusCode >= 500 { + return fmt.Errorf("server error: %d", resp.StatusCode) + } + return nil + } + + err = c.retryWithBackoff(ctx, operation) + if err != nil { + return nil, err + } + } else { + resp, err = c.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + } + + if c.debug && c.logOptions.LogResponse { + c.logResponse(resp) + } + + return resp, nil +} + +func (c *Client) createRequest(ctx context.Context, method, reqURL string, body interface{}) (*http.Request, error) { var bodyReader io.Reader var contentType string @@ -179,6 +373,16 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, body inte case url.Values: bodyReader = strings.NewReader(v.Encode()) contentType = "application/x-www-form-urlencoded" + case io.Reader: + bodyReader = v + // Check if the reader is a *bytes.Reader containing JSON data + if jsonReader, ok := v.(*bytes.Reader); ok { + jsonData, _ := io.ReadAll(jsonReader) + if json.Valid(jsonData) { + contentType = "application/json" + bodyReader = bytes.NewReader(jsonData) + } + } default: jsonBody, err := json.Marshal(body) if err != nil { @@ -194,49 +398,62 @@ func (c *Client) doRequest(ctx context.Context, method, reqURL string, body inte return nil, fmt.Errorf("failed to create request: %w", err) } - if contentType != "" { - req.Header.Set("Content-Type", contentType) + // Set method-specific headers + switch method { + case http.MethodGet: + req.Header.Set("Cache-Control", "no-cache") + case http.MethodPost, http.MethodPut, http.MethodPatch: + if contentType != "" { + req.Header.Set("Content-Type", contentType) + } } + + // Set default headers req.Header.Set("User-Agent", c.userAgent) + req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Language", "en-US,en;q=0.9") - if c.debug { - c.logRequest(req) + // Set custom headers + for k, v := range c.headers { + req.Header[k] = v } - var resp *http.Response - operation := func() error { - var err error - resp, err = c.client.Do(req) - if err != nil { - return fmt.Errorf("failed to send request: %w", err) - } - if resp.StatusCode >= 500 { - return fmt.Errorf("server error: %d", resp.StatusCode) + // Set cookies + for _, cookie := range c.cookies { + req.AddCookie(cookie) + } + + // Set query parameters + q := req.URL.Query() + for k, v := range c.queryParams { + for _, vv := range v { + q.Add(k, vv) } - return nil } + req.URL.RawQuery = q.Encode() - err = c.retryWithBackoff(ctx, operation) - if err != nil { - return nil, err + // Set form data + if (method == http.MethodPost || method == http.MethodPut || method == http.MethodPatch) && contentType == "application/x-www-form-urlencoded" { + req.PostForm = c.formData } - if c.debug && c.logOptions.LogResponse { - c.logResponse(resp) + // Set auth token + if c.authToken != "" { + req.Header.Set("Authorization", "Bearer "+c.authToken) } - return resp, nil + return req, nil } func (c *Client) retryWithBackoff(ctx context.Context, operation func() error) error { var err error - for i := 0; i < c.retryCount; i++ { + for i := 0; i < c.retryConfig.Count; i++ { err = operation() if err == nil { return nil } - if i == c.retryCount-1 { + if i == c.retryConfig.Count-1 { break } @@ -257,33 +474,12 @@ func (c *Client) retryWithBackoff(ctx context.Context, operation func() error) e func (c *Client) calculateBackoff(attempt int) time.Duration { backoff := float64(time.Second) - max := float64(c.maxBackoff) + max := float64(c.retryConfig.MaxBackoff) temp := math.Min(max, math.Pow(2, float64(attempt))*backoff) - rand.Seed(time.Now().UnixNano()) backoff = temp/2 + rand.Float64()*(temp/2) return time.Duration(backoff) } -// SetRequestHeaders sets custom headers for all requests -func (c *Client) SetRequestHeaders(headers map[string]string) { - c.client.Transport = &headerTransport{ - base: c.client.Transport, - headers: headers, - } -} - -type headerTransport struct { - base http.RoundTripper - headers map[string]string -} - -func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) { - for k, v := range t.headers { - req.Header.Set(k, v) - } - return t.base.RoundTrip(req) -} - func (c *Client) logRequest(req *http.Request) { xlog.Info("HTTP Request", "method", req.Method, "url", req.URL.String()) @@ -304,10 +500,34 @@ func (c *Client) logRequest(req *http.Request) { xlog.Info("Request Body", "body", string(body)) } } + + xlog.Info("Request Details", + "method", req.Method, + "url", req.URL.String(), + "protocol", req.Proto, + "host", req.Host, + "content_length", req.ContentLength, + "transfer_encoding", req.TransferEncoding, + "close", req.Close, + "trailer", req.Trailer, + "remote_addr", req.RemoteAddr, + "request_uri", req.RequestURI, + ) + + if c.debug { + xlog.Info("Custom Client Settings", + "base_url", c.baseURL, + "user_agent", c.userAgent, + "retry_enabled", c.retryConfig.Enabled, + "retry_count", c.retryConfig.Count, + "max_backoff", c.retryConfig.MaxBackoff, + "debug_mode", c.debug, + ) + } } func (c *Client) logResponse(resp *http.Response) { - xlog.Info("HTTP Response", "status", resp.Status) + xlog.Info("HTTP Response", "status", resp.Status, "status_code", resp.StatusCode) if c.logOptions.LogHeaders { for key, values := range resp.Header { @@ -326,4 +546,33 @@ func (c *Client) logResponse(resp *http.Response) { xlog.Info("Response Body", "body", string(body)) } } + + xlog.Info("Response Details", + "status", resp.Status, + "status_code", resp.StatusCode, + "protocol", resp.Proto, + "content_length", resp.ContentLength, + "transfer_encoding", resp.TransferEncoding, + "uncompressed", resp.Uncompressed, + "trailer", resp.Trailer, + ) + + if c.debug { + // Safely get and use startTimeKey + if startTimeValue := resp.Request.Context().Value(startTimeKey); startTimeValue != nil { + if startTime, ok := startTimeValue.(time.Time); ok { + duration := time.Since(startTime) + xlog.Info("Response Timing", + "duration", duration.String(), + "start_time", startTime.Format(time.RFC3339), + "end_time", time.Now().Format(time.RFC3339), + ) + } + } + } +} + +func basicAuth(username, password string) string { + auth := username + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) } diff --git a/xhttpc/xhttpc_test.go b/xhttpc/xhttpc_test.go index 3aae526..2d97322 100644 --- a/xhttpc/xhttpc_test.go +++ b/xhttpc/xhttpc_test.go @@ -17,22 +17,24 @@ func TestNewClient(t *testing.T) { client := NewClient() assert.NotNil(t, client, "NewClient should not return nil") - assert.Equal(t, 3, client.retryCount, "Default retryCount should be 3") - assert.Equal(t, 30*time.Second, client.maxBackoff, "Default maxBackoff should be 30s") + assert.Equal(t, 3, client.retryConfig.Count, "Default retry count should be 3") + assert.Equal(t, 30*time.Second, client.retryConfig.MaxBackoff, "Default max backoff should be 30s") assert.NotEmpty(t, client.userAgent, "Default userAgent should be set") } func TestClientOptions(t *testing.T) { client := NewClient( WithTimeout(5*time.Second), - WithRetryCount(5), - WithMaxBackoff(10*time.Second), + WithRetryConfig(RetryConfig{ + Count: 5, + MaxBackoff: 10 * time.Second, + }), WithUserAgent("TestAgent"), ) assert.Equal(t, 5*time.Second, client.client.Timeout, "Timeout should be 5s") - assert.Equal(t, 5, client.retryCount, "RetryCount should be 5") - assert.Equal(t, 10*time.Second, client.maxBackoff, "MaxBackoff should be 10s") + assert.Equal(t, 5, client.retryConfig.Count, "RetryCount should be 5") + assert.Equal(t, 10*time.Second, client.retryConfig.MaxBackoff, "MaxBackoff should be 10s") assert.Equal(t, "TestAgent", client.userAgent, "UserAgent should be TestAgent") } @@ -111,7 +113,10 @@ func TestRetryWithBackoff(t *testing.T) { })) defer server.Close() - client := NewClient(WithRetryCount(3), WithMaxBackoff(100*time.Millisecond)) + client := NewClient(WithRetryConfig(RetryConfig{ + Count: 3, + MaxBackoff: 100 * time.Millisecond, + })) resp, err := client.Get(context.Background(), server.URL) require.NoError(t, err, "Request should not fail") @@ -129,7 +134,7 @@ func TestSetRequestHeaders(t *testing.T) { defer server.Close() client := NewClient() - client.SetRequestHeaders(map[string]string{ + client.SetHeaders(map[string]string{ "X-Custom-Header": "test-value", })