Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make TimeTruncate functional option #1552

Merged
merged 7 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
buf = append(buf, "'0000-00-00'"...)
} else {
buf = append(buf, '\'')
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.TimeTruncate)
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil {
return "", err
}
Expand Down
49 changes: 41 additions & 8 deletions dsn.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ var (
// If a new Config is created instead of being parsed from a DSN string,
// the NewConfig function should be used, which sets default values.
type Config struct {
// non boolean fields

User string // Username
Passwd string // Password (requires User)
Net string // Network (e.g. "tcp", "tcp6", "unix". default: "tcp")
Expand All @@ -45,15 +47,15 @@ type Config struct {
Loc *time.Location // Location for time.Time values
MaxAllowedPacket int // Max packet size allowed
ServerPubKey string // Server public key name
pubKey *rsa.PublicKey // Server public key
TLSConfig string // TLS configuration name
TLS *tls.Config // TLS configuration, its priority is higher than TLSConfig
TimeTruncate time.Duration // Truncate time.Time values to the specified duration
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
Logger Logger // Logger

// boolean fields

AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
AllowFallbackToPlaintext bool // Allows fallback to unencrypted connection if server does not support TLS
Expand All @@ -66,17 +68,48 @@ type Config struct {
MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time
RejectReadOnly bool // Reject read-only connections

// private fields. new options should be come here
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

private -> unexpected - let's use Go naming


pubKey *rsa.PublicKey // Server public key
timeTruncate time.Duration // Truncate time.Time values to the specified duration
}

// Functional Options Pattern
// https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis
type option func(*Config) error
methane marked this conversation as resolved.
Show resolved Hide resolved

// NewConfig creates a new Config and sets default values.
func NewConfig() *Config {
return &Config{
func NewConfig(opts ...option) *Config {
cfg := &Config{
methane marked this conversation as resolved.
Show resolved Hide resolved
Loc: time.UTC,
MaxAllowedPacket: defaultMaxAllowedPacket,
Logger: defaultLogger,
AllowNativePasswords: true,
CheckConnLiveness: true,
}

cfg.SetOptions(opts...)
return cfg
}

func (c *Config) SetOptions(opts ...option) error {
for _, opt := range opts {
err := opt(c)
if err != nil {
return err
}
}
return nil
}

// TimeTruncate sets the time duration to truncate time.Time values in
// query parameters.
func TimeTruncate(d time.Duration) option {
return func(cfg *Config) error {
cfg.timeTruncate = d
return nil
}
}

func (cfg *Config) Clone() *Config {
Expand Down Expand Up @@ -263,8 +296,8 @@ func (cfg *Config) FormatDSN() string {
writeDSNParam(&buf, &hasParam, "parseTime", "true")
}

if cfg.TimeTruncate > 0 {
writeDSNParam(&buf, &hasParam, "timeTruncate", cfg.TimeTruncate.String())
if cfg.timeTruncate > 0 {
writeDSNParam(&buf, &hasParam, "timeTruncate", cfg.timeTruncate.String())
}

if cfg.ReadTimeout > 0 {
Expand Down Expand Up @@ -509,9 +542,9 @@ func parseDSNParams(cfg *Config, params string) (err error) {

// time.Time truncation
case "timeTruncate":
cfg.TimeTruncate, err = time.ParseDuration(value)
cfg.timeTruncate, err = time.ParseDuration(value)
methane marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return
return fmt.Errorf("invalid timeTruncate value: %v, error: %w", value, err)
}

// I/O read Timeout
Expand Down
2 changes: 1 addition & 1 deletion dsn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ var testDSNs = []struct {
&Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true},
}, {
"user:password@/dbname?loc=UTC&timeout=30s&parseTime=true&timeTruncate=1h",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, Timeout: 30 * time.Second, ParseTime: true, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, TimeTruncate: time.Hour},
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Loc: time.UTC, Timeout: 30 * time.Second, ParseTime: true, MaxAllowedPacket: defaultMaxAllowedPacket, Logger: defaultLogger, AllowNativePasswords: true, CheckConnLiveness: true, timeTruncate: time.Hour},
},
}

methane marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
2 changes: 1 addition & 1 deletion packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -1172,7 +1172,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if v.IsZero() {
b = append(b, "0000-00-00"...)
} else {
b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.TimeTruncate)
b, err = appendDateTime(b, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
if err != nil {
return err
}
Expand Down
5 changes: 2 additions & 3 deletions result.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ import "database/sql/driver"
// This is accessible by executing statements using sql.Conn.Raw() and
// downcasting the returned result:
//
// res, err := rawConn.Exec(...)
// res.(mysql.Result).AllRowsAffected()
//
// res, err := rawConn.Exec(...)
methane marked this conversation as resolved.
Show resolved Hide resolved
// res.(mysql.Result).AllRowsAffected()
type Result interface {
driver.Result
// AllRowsAffected returns a slice containing the affected rows for each
Expand Down
Loading