Skip to content

Commit

Permalink
fix: Fix memory management (#131)
Browse files Browse the repository at this point in the history
A number of fixes, most notably removing support for Retain/Release, as we've decided to support only the default Go memory allocator going forward. With the Go allocator we can rely on Go's garbage collector.
  • Loading branch information
hermanschaaf authored Apr 20, 2023
1 parent 2157205 commit 9b97bae
Show file tree
Hide file tree
Showing 11 changed files with 43 additions and 76 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.PHONY: test
test:
go test -race ./...
go test -race -tags=assert ./...

.PHONY: lint
lint:
Expand Down
12 changes: 1 addition & 11 deletions csv/write_read_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"time"

"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/memory"
"github.com/bradleyjkemp/cupaloy/v2"
"github.com/cloudquery/plugin-sdk/v2/plugins/destination"
"github.com/cloudquery/plugin-sdk/v2/testdata"
Expand All @@ -33,21 +32,14 @@ func TestWriteRead(t *testing.T) {
arrowSchema := table.ToArrowSchema()
sourceName := "test-source"
syncTime := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
defer mem.AssertSize(t, 0)
opts := testdata.GenTestDataOptions{
SourceName: sourceName,
SyncTime: syncTime,
MaxRows: 2,
StableUUID: uuid.MustParse("00000000-0000-0000-0000-000000000001"),
StableTime: time.Date(2021, 1, 2, 0, 0, 0, 0, time.UTC),
}
records := testdata.GenTestData(mem, arrowSchema, opts)
defer func() {
for _, r := range records {
r.Release()
}
}()
records := testdata.GenTestData(arrowSchema, opts)
cl, err := NewClient(tc.options...)
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -83,10 +75,8 @@ func TestWriteRead(t *testing.T) {
totalCount := 0
for got := range ch {
if diff := destination.RecordDiff(records[totalCount], got); diff != "" {
got.Release()
t.Errorf("got diff: %s", diff)
}
got.Release()
totalCount++
}
if readErr != nil {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module github.com/cloudquery/filetypes/v2
go 1.19

require (
github.com/cloudquery/plugin-sdk/v2 v2.3.6
github.com/cloudquery/plugin-sdk/v2 v2.3.7
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/testify v1.8.2
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ github.com/bradleyjkemp/cupaloy/v2 v2.8.0 h1:any4BmKE+jGIaMpnU8YgH/I2LPiLBufr6oM
github.com/bradleyjkemp/cupaloy/v2 v2.8.0/go.mod h1:bm7JXdkRd4BHJk9HpwqAI8BoAY1lps46Enkdqw6aRX0=
github.com/cloudquery/arrow/go/v12 v12.0.0-20230419074556-00ceafa3b033 h1:wMIRbdyx9Oe9Cfzf9DN1lEyTuQnwSYLls2gsN7EfhZM=
github.com/cloudquery/arrow/go/v12 v12.0.0-20230419074556-00ceafa3b033/go.mod h1:d+tV/eHZZ7Dz7RPrFKtPK02tpr+c9/PEd/zm8mDS9Vg=
github.com/cloudquery/plugin-sdk/v2 v2.3.6 h1:fzsmALscu9w6pZNB+6Aj4cobhqCCKorCDjV3EtkF8Ao=
github.com/cloudquery/plugin-sdk/v2 v2.3.6/go.mod h1:/wAbhyQbdIUAMEL+Yo9zkgoBls83xt3ev6jLpJblIoU=
github.com/cloudquery/plugin-sdk/v2 v2.3.7 h1:tDRi61+NzIfOORxrRjP48E2bM322maWoVZi2E6VR0rI=
github.com/cloudquery/plugin-sdk/v2 v2.3.7/go.mod h1:/wAbhyQbdIUAMEL+Yo9zkgoBls83xt3ev6jLpJblIoU=
github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down
1 change: 0 additions & 1 deletion json/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ func (*Client) Read(r io.Reader, arrowSchema *arrow.Schema, _ string, res chan<-
scanner := bufio.NewScanner(r)
scanner.Buffer(make([]byte, maxJSONSize), maxJSONSize)
rb := array.NewRecordBuilder(memory.DefaultAllocator, arrowSchema)
defer rb.Release()
for scanner.Scan() {
b := scanner.Bytes()
err := rb.UnmarshalJSON(b)
Expand Down
1 change: 0 additions & 1 deletion json/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ func (c *Client) WriteTableBatch(w io.Writer, _ *arrow.Schema, records []arrow.R

func (*Client) writeRecord(w io.Writer, record arrow.Record) error {
arr := array.RecordToStructArray(record)
defer arr.Release()
enc := json.NewEncoder(w)
enc.SetEscapeHTML(false)
for i := 0; i < arr.Len(); i++ {
Expand Down
19 changes: 2 additions & 17 deletions json/write_read_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"time"

"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/memory"
"github.com/bradleyjkemp/cupaloy/v2"
"github.com/cloudquery/plugin-sdk/v2/plugins/destination"
"github.com/cloudquery/plugin-sdk/v2/testdata"
Expand All @@ -21,19 +20,12 @@ func TestWrite(t *testing.T) {
arrowSchema := table.ToArrowSchema()
sourceName := "test-source"
syncTime := time.Now().UTC().Round(1 * time.Second)
mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
defer mem.AssertSize(t, 0)
opts := testdata.GenTestDataOptions{
SourceName: sourceName,
SyncTime: syncTime,
MaxRows: 1,
}
records := testdata.GenTestData(mem, arrowSchema, opts)
defer func() {
for _, r := range records {
r.Release()
}
}()
records := testdata.GenTestData(arrowSchema, opts)
cl, err := NewClient()
if err != nil {
t.Fatal(err)
Expand All @@ -49,21 +41,14 @@ func TestWriteRead(t *testing.T) {
arrowSchema := table.ToArrowSchema()
sourceName := "test-source"
syncTime := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
defer mem.AssertSize(t, 0)
opts := testdata.GenTestDataOptions{
SourceName: sourceName,
SyncTime: syncTime,
MaxRows: 2,
StableUUID: uuid.MustParse("00000000-0000-0000-0000-000000000001"),
StableTime: time.Date(2021, 1, 2, 0, 0, 0, 0, time.UTC),
}
records := testdata.GenTestData(mem, arrowSchema, opts)
defer func() {
for _, r := range records {
r.Release()
}
}()
records := testdata.GenTestData(arrowSchema, opts)
cl, err := NewClient()
if err != nil {
t.Fatal(err)
Expand Down
33 changes: 19 additions & 14 deletions parquet/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ type ReaderAtSeeker interface {
}

func (*Client) Read(f ReaderAtSeeker, arrowSchema *arrow.Schema, _ string, res chan<- arrow.Record) error {
mem := memory.DefaultAllocator
ctx := context.Background()
rdr, err := file.NewParquetReader(f)
if err != nil {
Expand All @@ -31,38 +30,44 @@ func (*Client) Read(f ReaderAtSeeker, arrowSchema *arrow.Schema, _ string, res c
Parallel: false,
BatchSize: 1024,
}
fr, err := pqarrow.NewFileReader(rdr, arrProps, mem)
fr, err := pqarrow.NewFileReader(rdr, arrProps, memory.DefaultAllocator)
if err != nil {
return fmt.Errorf("failed to create new parquet file reader: %w", err)
}
rr, err := fr.GetRecordReader(ctx, nil, nil)
if err != nil {
return fmt.Errorf("failed to get parquet record reader: %w", err)
}

for rr.Next() {
rec := rr.Record()
castRec, err := castStringsToExtensions(mem, rec, arrowSchema)
castRec, err := castStringsToExtensions(rec, arrowSchema)
if err != nil {
return fmt.Errorf("failed to cast extension types: %w", err)
}
castRec.Retain()
res <- castRec
_, err = rr.Read()
if err == io.EOF {
break
} else if err != nil {
return fmt.Errorf("error while reading record: %w", err)
castRecs := convertToSingleRowRecords(castRec)
for _, r := range castRecs {
res <- r
}
}
rr.Release()
if rr.Err() != nil && rr.Err() != io.EOF {
return fmt.Errorf("failed to read parquet record: %w", rr.Err())
}

return nil
}

func castStringsToExtensions(mem memory.Allocator, rec arrow.Record, arrowSchema *arrow.Schema) (arrow.Record, error) {
rb := array.NewRecordBuilder(mem, arrowSchema)
// convertToSingleRowRecords splits rec into one record per row. Each
// returned record is a zero-copy slice of the original record obtained
// via NewSlice, so the backing data is shared with rec.
func convertToSingleRowRecords(rec arrow.Record) []arrow.Record {
	rowCount := rec.NumRows()
	out := make([]arrow.Record, 0, rowCount)
	for row := int64(0); row < rowCount; row++ {
		out = append(out, rec.NewSlice(row, row+1))
	}
	return out
}

defer rb.Release()
// castStringsToExtensions casts string columns back to the extension types declared in arrowSchema.
func castStringsToExtensions(rec arrow.Record, arrowSchema *arrow.Schema) (arrow.Record, error) {
rb := array.NewRecordBuilder(memory.DefaultAllocator, arrowSchema)
for c := 0; c < int(rec.NumCols()); c++ {
col := rec.Column(c)
switch {
Expand Down
23 changes: 13 additions & 10 deletions parquet/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,31 @@ import (
"github.com/cloudquery/plugin-sdk/v2/types"
)

func (*Client) WriteTableBatch(w io.Writer, arrowSchema *arrow.Schema, records []arrow.Record) error {
func (c *Client) WriteTableBatch(w io.Writer, arrowSchema *arrow.Schema, records []arrow.Record) error {
props := parquet.NewWriterProperties()
arrprops := pqarrow.DefaultWriterProps()
newSchema := convertSchema(arrowSchema)
fw, err := pqarrow.NewFileWriter(newSchema, w, props, arrprops)
if err != nil {
return err
}
mem := memory.DefaultAllocator
for _, rec := range records {
castRec, err := castExtensionColsToString(mem, rec)
err := c.writeRecord(rec, fw)
if err != nil {
return fmt.Errorf("failed to cast to string: %w", err)
}
if err := fw.Write(castRec); err != nil {
return err
}
}
return fw.Close()
}

// writeRecord casts any extension columns in rec to string and writes the
// resulting record through the parquet file writer fw.
func (*Client) writeRecord(rec arrow.Record, fw *pqarrow.FileWriter) error {
	converted, castErr := castExtensionColsToString(rec)
	if castErr != nil {
		return fmt.Errorf("failed to cast to string: %w", castErr)
	}
	return fw.Write(converted)
}

func convertSchema(sch *arrow.Schema) *arrow.Schema {
oldFields := sch.Fields()
fields := make([]arrow.Field, len(oldFields))
Expand All @@ -55,11 +59,10 @@ func convertSchema(sch *arrow.Schema) *arrow.Schema {
return newSchema
}

func castExtensionColsToString(mem memory.Allocator, rec arrow.Record) (arrow.Record, error) {
// castExtensionColsToString casts extension columns to string. It does not release the original record.
func castExtensionColsToString(rec arrow.Record) (arrow.Record, error) {
newSchema := convertSchema(rec.Schema())
rb := array.NewRecordBuilder(mem, newSchema)

defer rb.Release()
rb := array.NewRecordBuilder(memory.DefaultAllocator, newSchema)
for c := 0; c < int(rec.NumCols()); c++ {
col := rec.Column(c)
switch {
Expand Down
16 changes: 4 additions & 12 deletions parquet/write_read_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"time"

"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/memory"
"github.com/cloudquery/plugin-sdk/v2/plugins/destination"
"github.com/cloudquery/plugin-sdk/v2/testdata"
)
Expand All @@ -19,19 +18,12 @@ func TestWriteRead(t *testing.T) {
arrowSchema := table.ToArrowSchema()
sourceName := "test-source"
syncTime := time.Now().UTC().Round(1 * time.Second)
mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
defer mem.AssertSize(t, 0)
opts := testdata.GenTestDataOptions{
SourceName: sourceName,
SyncTime: syncTime,
MaxRows: 1,
MaxRows: 2,
}
records := testdata.GenTestData(mem, arrowSchema, opts)
defer func() {
for _, r := range records {
r.Release()
}
}()
records := testdata.GenTestData(arrowSchema, opts)
writer := bufio.NewWriter(&b)
reader := bufio.NewReader(&b)

Expand Down Expand Up @@ -68,7 +60,7 @@ func TestWriteRead(t *testing.T) {
if readErr != nil {
t.Fatal(readErr)
}
if totalCount != 1 {
t.Fatalf("expected 1 row, got %d", totalCount)
if totalCount != 2 {
t.Fatalf("expected 2 rows, got %d", totalCount)
}
}
6 changes: 0 additions & 6 deletions write.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,6 @@ import (
)

func (cl *Client) WriteTableBatchFile(w io.Writer, arrowSchema *arrow.Schema, records []arrow.Record) error {
defer func() {
for _, r := range records {
r.Release()
}
}()

switch cl.spec.Format {
case FormatTypeCSV:
if err := cl.csv.WriteTableBatch(w, arrowSchema, records); err != nil {
Expand Down

0 comments on commit 9b97bae

Please sign in to comment.