diff --git a/pkg/sequence/segment.go b/pkg/sequence/segment.go index 4e48d7d..633a649 100644 --- a/pkg/sequence/segment.go +++ b/pkg/sequence/segment.go @@ -33,7 +33,7 @@ const ( " PRIMARY KEY ( `business_id` )" + ");" segmentExits = "SELECT EXISTS (SELECT 1 FROM `segment` WHERE `biz_id` = ?)" - insertSegment = "INSERT INTO `segment`(`biz_id`, `step`, `max_id`) VALUES (?, ?, 1)" + insertSegment = "INSERT INTO `segment`(`biz_id`, `step`, `max_id`) VALUES (?, ?, ?)" selectSegment = "SELECT max_id FROM segment WHERE biz_id = ? FOR UPDATE" updateSegment = "UPDATE segment SET max_id = ? WHERE biz_id = ? AND max_id = ?" ) @@ -47,7 +47,7 @@ type SegmentWorker struct { step int64 } -func NewSegmentWorker(dsn string, len int64, biz string) (*SegmentWorker, error) { +func NewSegmentWorker(dsn string, from, len int64, biz string) (*SegmentWorker, error) { db, err := sql.Open("mysql", dsn) if err != nil { return nil, err @@ -66,7 +66,7 @@ func NewSegmentWorker(dsn string, len int64, biz string) (*SegmentWorker, error) } } if !exists { - if _, err := db.ExecContext(context.Background(), insertSegment, biz, len); err != nil { + if _, err := db.ExecContext(context.Background(), insertSegment, biz, len, from); err != nil { log.Error(err) return nil, err } @@ -103,8 +103,8 @@ func (w *SegmentWorker) ProduceID() { } } - w.buffer <- w.min w.min++ + w.buffer <- w.min } } diff --git a/pkg/sequence/segment_test.go b/pkg/sequence/segment_test.go index 9ace551..2796036 100644 --- a/pkg/sequence/segment_test.go +++ b/pkg/sequence/segment_test.go @@ -35,7 +35,7 @@ func (suite *_SegmentTestSuite) SetupSuite() { suite.T().Fatal(err) } dsn := fmt.Sprintf("root:123456@tcp(localhost:%d)/segment", port.Int()) - segment, err := NewSegmentWorker(dsn, 1000, "student") + segment, err := NewSegmentWorker(dsn, 0, 1000, "student") if err != nil { suite.T().Fatal(err) } diff --git a/pkg/sequence/sequence.go b/pkg/sequence/sequence.go index af3397d..d7e45d0 100644 --- a/pkg/sequence/sequence.go +++ b/pkg/sequence/sequence.go @@ -28,7 +28,7 @@ func NewSequence(generator *config.SequenceGenerator, tableName string) (proto.S if segmentConfig.Step == 0 { segmentConfig.Step = 1000 } - return NewSegmentWorker(segmentConfig.DSN, segmentConfig.Step, tableName) + return NewSegmentWorker(segmentConfig.DSN, segmentConfig.From, segmentConfig.Step, tableName) case config.Snowflake: var ( err error @@ -49,6 +49,7 @@ func NewSequence(generator *config.SequenceGenerator, tableName string) (proto.S type SegmentConfig struct { DSN string `yaml:"dsn" json:"dsn"` + From int64 `yaml:"from" json:"from"` Step int64 `default:"1000" yaml:"step" json:"step"` } diff --git a/pkg/sequence/snowflake_test.go b/pkg/sequence/snowflake_test.go index 03f071b..dc697c9 100644 --- a/pkg/sequence/snowflake_test.go +++ b/pkg/sequence/snowflake_test.go @@ -5,6 +5,14 @@ import ( "testing" ) +func TestNextID(t *testing.T) { + worker, err := NewWorker(10) + if err != nil { + t.Error(err) + } + t.Log(worker.NextID()) +} + func BenchmarkSnowFlakeNextID(b *testing.B) { worker, err := NewWorker(10) if err != nil {