Skip to content

Commit

Permalink
Fix warnings and move Testkit backend to context APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
fbiville committed Apr 8, 2022
1 parent c81ce99 commit b602bd6
Showing 1 changed file with 76 additions and 59 deletions.
135 changes: 76 additions & 59 deletions testkit-backend/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package main

import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
Expand All @@ -39,22 +40,23 @@ import (
// Handles a testkit backend session.
// Tracks all objects (and errors) that is created by testkit frontend.
type backend struct {
rd *bufio.Reader // Socket to read requests from
wr io.Writer // Socket to write responses (and logs) on, don't buffer (WriteString on bufio was weird...)
drivers map[string]neo4j.Driver
sessionStates map[string]*sessionState
results map[string]neo4j.Result
transactions map[string]neo4j.Transaction
recordedErrors map[string]error
resolvedAddresses map[string][]interface{}
id int // Id to use for next object created by frontend
wrLock sync.Mutex
rd *bufio.Reader // Socket to read requests from
wr io.Writer // Socket to write responses (and logs) on, don't buffer (WriteString on bufio was weird...)
drivers map[string]neo4j.DriverWithContext
sessionStates map[string]*sessionState
results map[string]neo4j.ResultWithContext
managedTransactions map[string]neo4j.ManagedTransaction
explicitTransactions map[string]neo4j.ExplicitTransaction
recordedErrors map[string]error
resolvedAddresses map[string][]interface{}
id int // ID to use for next object created by frontend
wrLock sync.Mutex
}

// To implement transactional functions a bit of extra state is needed on the
// driver session.
type sessionState struct {
session neo4j.Session
session neo4j.SessionWithContext
retryableState int
retryableErrorId string
}
Expand All @@ -65,17 +67,20 @@ const (
retryableNegative = -1
)

var ctx = context.Background()

func newBackend(rd *bufio.Reader, wr io.Writer) *backend {
return &backend{
rd: rd,
wr: wr,
drivers: make(map[string]neo4j.Driver),
sessionStates: make(map[string]*sessionState),
results: make(map[string]neo4j.Result),
transactions: make(map[string]neo4j.Transaction),
recordedErrors: make(map[string]error),
resolvedAddresses: make(map[string][]interface{}),
id: 0,
rd: rd,
wr: wr,
drivers: make(map[string]neo4j.DriverWithContext),
sessionStates: make(map[string]*sessionState),
results: make(map[string]neo4j.ResultWithContext),
managedTransactions: make(map[string]neo4j.ManagedTransaction),
explicitTransactions: make(map[string]neo4j.ExplicitTransaction),
recordedErrors: make(map[string]error),
resolvedAddresses: make(map[string][]interface{}),
id: 0,
}
}

Expand Down Expand Up @@ -166,7 +171,7 @@ func (b *backend) nextId() string {

func (b *backend) process() bool {
request := ""
in_request := false
inRequest := false

for {
line, err := b.rd.ReadString('\n')
Expand All @@ -176,20 +181,20 @@ func (b *backend) process() bool {

switch line {
case "#request begin\n":
if in_request {
if inRequest {
panic("Already in request")
}
in_request = true
inRequest = true
case "#request end\n":
if !in_request {
if !inRequest {
panic("End while not in request")
}
b.handleRequest(b.toRequest(request))
request = ""
in_request = false
inRequest = false
return true
default:
if !in_request {
if !inRequest {
panic("Line while not in request")
}

Expand Down Expand Up @@ -263,12 +268,12 @@ func (b *backend) toCypherAndParams(data map[string]interface{}) (string, map[st
func (b *backend) handleTransactionFunc(isRead bool, data map[string]interface{}) {
sid := data["sessionId"].(string)
sessionState := b.sessionStates[sid]
blockingRetry := func(tx neo4j.Transaction) (interface{}, error) {
blockingRetry := func(tx neo4j.ManagedTransaction) (interface{}, error) {
sessionState.retryableState = retryableNothing
// Instruct client to start doing it's work
txid := b.nextId()
b.transactions[txid] = tx
b.writeResponse("RetryableTry", map[string]interface{}{"id": txid})
// Instruct client to start doing its work
txId := b.nextId()
b.managedTransactions[txId] = tx
b.writeResponse("RetryableTry", map[string]interface{}{"id": txId})
// Process all things that the client might do within the transaction
for {
b.process()
Expand All @@ -290,9 +295,9 @@ func (b *backend) handleTransactionFunc(isRead bool, data map[string]interface{}
}
var err error
if isRead {
_, err = sessionState.session.ReadTransaction(blockingRetry, b.toTransactionConfigApply(data))
_, err = sessionState.session.ExecuteRead(ctx, blockingRetry, b.toTransactionConfigApply(data))
} else {
_, err = sessionState.session.WriteTransaction(blockingRetry, b.toTransactionConfigApply(data))
_, err = sessionState.session.ExecuteWrite(ctx, blockingRetry, b.toTransactionConfigApply(data))
}

if err != nil {
Expand Down Expand Up @@ -387,7 +392,7 @@ func (b *backend) handleRequest(req map[string]interface{}) {
}
// Parse URI (or rather type cast)
uri := data["uri"].(string)
driver, err := neo4j.NewDriver(uri, authToken, func(c *neo4j.Config) {
driver, err := neo4j.NewDriverWithContext(uri, authToken, func(c *neo4j.Config) {
// Setup custom logger that redirects log entries back to frontend
c.Log = &streamLog{writeLine: b.writeLineLocked}
// Optional custom user agent from frontend
Expand Down Expand Up @@ -425,7 +430,7 @@ func (b *backend) handleRequest(req map[string]interface{}) {
case "DriverClose":
driverId := data["driverId"].(string)
driver := b.drivers[driverId]
err := driver.Close()
err := driver.Close(ctx)
if err != nil {
b.writeError(err)
return
Expand All @@ -442,13 +447,13 @@ func (b *backend) handleRequest(req map[string]interface{}) {
case "w":
sessionConfig.AccessMode = neo4j.AccessModeWrite
default:
b.writeError(errors.New("Unknown accessmode: " + data["accessMode"].(string)))
b.writeError(errors.New("Unknown access mode: " + data["accessMode"].(string)))
return
}
if data["bookmarks"] != nil {
bookmarksx := data["bookmarks"].([]interface{})
bookmarks := make([]string, len(bookmarksx))
for i, x := range bookmarksx {
rawBookmarks := data["bookmarks"].([]interface{})
bookmarks := make([]string, len(rawBookmarks))
for i, x := range rawBookmarks {
bookmarks[i] = x.(string)
}
sessionConfig.Bookmarks = neo4j.BookmarksFromRawValues(bookmarks...)
Expand All @@ -470,7 +475,7 @@ func (b *backend) handleRequest(req map[string]interface{}) {
case "SessionClose":
sessionId := data["sessionId"].(string)
sessionState := b.sessionStates[sessionId]
err := sessionState.session.Close()
err := sessionState.session.Close(ctx)
if err != nil {
b.writeError(err)
return
Expand All @@ -480,7 +485,7 @@ func (b *backend) handleRequest(req map[string]interface{}) {
case "SessionRun":
sessionState := b.sessionStates[data["sessionId"].(string)]
cypher, params := b.toCypherAndParams(data)
result, err := sessionState.session.Run(cypher, params, b.toTransactionConfigApply(data))
result, err := sessionState.session.Run(ctx, cypher, params, b.toTransactionConfigApply(data))
if err != nil {
b.writeError(err)
return
Expand All @@ -496,13 +501,13 @@ func (b *backend) handleRequest(req map[string]interface{}) {

case "SessionBeginTransaction":
sessionState := b.sessionStates[data["sessionId"].(string)]
tx, err := sessionState.session.BeginTransaction(b.toTransactionConfigApply(data))
tx, err := sessionState.session.BeginTransaction(ctx, b.toTransactionConfigApply(data))
if err != nil {
b.writeError(err)
return
}
idKey := b.nextId()
b.transactions[idKey] = tx
b.explicitTransactions[idKey] = tx
b.writeResponse("Transaction", map[string]interface{}{"id": idKey})

case "SessionLastBookmarks":
Expand All @@ -514,9 +519,16 @@ func (b *backend) handleRequest(req map[string]interface{}) {
b.writeResponse("Bookmarks", map[string]interface{}{"bookmarks": bookmarks})

case "TransactionRun":
tx := b.transactions[data["txId"].(string)]
// ManagedTransaction is compatible with ExplicitTransaction
// and is all that is needed for TransactionRun
var tx neo4j.ManagedTransaction
var found bool
transactionId := data["txId"].(string)
if tx, found = b.explicitTransactions[transactionId]; !found {
tx = b.managedTransactions[transactionId]
}
cypher, params := b.toCypherAndParams(data)
result, err := tx.Run(cypher, params)
result, err := tx.Run(ctx, cypher, params)
if err != nil {
b.writeError(err)
return
Expand All @@ -532,8 +544,8 @@ func (b *backend) handleRequest(req map[string]interface{}) {

case "TransactionCommit":
txId := data["txId"].(string)
tx := b.transactions[txId]
err := tx.Commit()
tx := b.explicitTransactions[txId]
err := tx.Commit(ctx)
if err != nil {
b.writeError(err)
return
Expand All @@ -542,8 +554,8 @@ func (b *backend) handleRequest(req map[string]interface{}) {

case "TransactionRollback":
txId := data["txId"].(string)
tx := b.transactions[txId]
err := tx.Rollback()
tx := b.explicitTransactions[txId]
err := tx.Rollback(ctx)
if err != nil {
b.writeError(err)
return
Expand All @@ -552,8 +564,8 @@ func (b *backend) handleRequest(req map[string]interface{}) {

case "TransactionClose":
txId := data["txId"].(string)
tx := b.transactions[txId]
err := tx.Close()
tx := b.explicitTransactions[txId]
err := tx.Close(ctx)
if err != nil {
b.writeError(err)
return
Expand All @@ -577,16 +589,16 @@ func (b *backend) handleRequest(req map[string]interface{}) {

case "ResultNext":
result := b.results[data["resultId"].(string)]
more := result.Next()
more := result.Next(ctx)
b.writeRecord(result, result.Record(), more)
case "ResultPeek":
result := b.results[data["resultId"].(string)]
var record *db.Record = nil
more := result.PeekRecord(&record)
more := result.PeekRecord(ctx, &record)
b.writeRecord(result, record, more)
case "ResultList":
result := b.results[data["resultId"].(string)]
records, err := result.Collect()
records, err := result.Collect(ctx)
if err != nil {
b.writeError(err)
return
Expand All @@ -600,7 +612,7 @@ func (b *backend) handleRequest(req map[string]interface{}) {
})
case "ResultConsume":
result := b.results[data["resultId"].(string)]
summary, err := result.Consume()
summary, err := result.Consume(ctx)
if err != nil {
b.writeError(err)
return
Expand Down Expand Up @@ -655,13 +667,18 @@ func (b *backend) handleRequest(req map[string]interface{}) {
case "CheckMultiDBSupport":
driver := b.drivers[data["driverId"].(string)]
session := driver.NewSession(neo4j.SessionConfig{})
result, err := session.Run("RETURN 42", nil)
defer session.Close()
result, err := session.Run(ctx, "RETURN 42", nil)
defer func() {
err = session.Close(ctx)
if err != nil {
b.writeError(fmt.Errorf("could not check multi DB support: %w", err))
}
}()
if err != nil {
b.writeError(fmt.Errorf("could not check multi DB support: %w", err))
return
}
summary, err := result.Consume()
summary, err := result.Consume(ctx)
if err != nil {
b.writeError(fmt.Errorf("could not check multi DB support: %w", err))
return
Expand Down Expand Up @@ -724,7 +741,7 @@ func (b *backend) handleRequest(req map[string]interface{}) {
}
}

func (b *backend) writeRecord(result neo4j.Result, record *neo4j.Record, expectRecord bool) {
func (b *backend) writeRecord(result neo4j.ResultWithContext, record *neo4j.Record, expectRecord bool) {
if expectRecord && record == nil {
b.writeResponse("BackendError", map[string]interface{}{
"msg": "Found no record where one was expected.",
Expand Down

0 comments on commit b602bd6

Please sign in to comment.