diff --git a/csv/client.go b/csv/client.go new file mode 100644 index 00000000..661c2b77 --- /dev/null +++ b/csv/client.go @@ -0,0 +1,32 @@ +package csv + +type Options func(*Client) + +// Client is a csv client. +type Client struct { + IncludeHeaders bool + Delimiter rune +} + +func NewClient(options ...Options) (*Client, error) { + c := &Client{ + Delimiter: ',', + } + for _, option := range options { + option(c) + } + + return c, nil +} + +func WithHeader() Options { + return func(c *Client) { + c.IncludeHeaders = true + } +} + +func WithDelimiter(delimiter rune) Options { + return func(c *Client) { + c.Delimiter = delimiter + } +} diff --git a/csv/read.go b/csv/read.go index 3d194ce0..a431b5df 100644 --- a/csv/read.go +++ b/csv/read.go @@ -9,12 +9,19 @@ import ( "github.com/cloudquery/plugin-sdk/schema" ) -func Read(f io.Reader, table *schema.Table, sourceName string, res chan<- []any) error { - reader := csv.NewReader(f) +func (cl *Client) Read(r io.Reader, table *schema.Table, sourceName string, res chan<- []any) error { + reader := csv.NewReader(r) + reader.Comma = cl.Delimiter sourceNameIndex := table.Columns.Index(schema.CqSourceNameColumn.Name) if sourceNameIndex == -1 { return fmt.Errorf("could not find column %s in table %s", schema.CqSourceNameColumn.Name, table.Name) } + if cl.IncludeHeaders { + _, err := reader.Read() + if err != nil { + return err + } + } for { record, err := reader.Read() if err != nil { diff --git a/csv/write.go b/csv/write.go index 6db88d53..c1a367c8 100644 --- a/csv/write.go +++ b/csv/write.go @@ -7,8 +7,14 @@ import ( "github.com/cloudquery/plugin-sdk/schema" ) -func WriteTableBatch(w io.Writer, _ *schema.Table, resources [][]any) error { +func (cl *Client) WriteTableBatch(w io.Writer, table *schema.Table, resources [][]any) error { writer := csv.NewWriter(w) + writer.Comma = cl.Delimiter + if cl.IncludeHeaders { + if err := cl.WriteTableHeaders(w, table); err != nil { + return err + } + } for _, resource := range resources { 
record := make([]string, len(resource)) for i, v := range resource { @@ -21,3 +27,18 @@ func WriteTableBatch(w io.Writer, _ *schema.Table, resources [][]any) error { writer.Flush() return nil } + +func (cl *Client) WriteTableHeaders(w io.Writer, table *schema.Table) error { + writer := csv.NewWriter(w) + writer.Comma = cl.Delimiter + + tableHeaders := make([]string, len(table.Columns)) + for index, header := range table.Columns { + tableHeaders[index] = header.Name + } + if err := writer.Write(tableHeaders); err != nil { + return err + } + writer.Flush() + return nil +} diff --git a/csv/write_read_test.go b/csv/write_read_test.go index 0493ba41..6d6a56b9 100644 --- a/csv/write_read_test.go +++ b/csv/write_read_test.go @@ -10,43 +10,67 @@ import ( ) func TestWriteRead(t *testing.T) { - var b bytes.Buffer - table := testdata.TestTable("test") - cqtypes := testdata.GenTestData(table) - if err := cqtypes[0].Set("test-source"); err != nil { - t.Fatal(err) + cases := []struct { + name string + options []Options + outputCount int + }{ + {name: "default", outputCount: 1}, + {name: "with_headers", options: []Options{WithHeader()}, outputCount: 1}, + {name: "with_delimiter", options: []Options{WithDelimiter('\t')}, outputCount: 1}, + {name: "with_delimiter_headers", options: []Options{WithDelimiter('\t'), WithHeader()}, outputCount: 1}, } - writer := bufio.NewWriter(&b) - transformer := &Transformer{} - transformedValues := schema.TransformWithTransformer(transformer, cqtypes) - // schema.TransformWithTransformer(tra) - if err := WriteTableBatch(writer, table, [][]any{transformedValues}); err != nil { - t.Fatal(err) - } - writer.Flush() - reader := bufio.NewReader(&b) - ch := make(chan []any) - var readErr error - go func() { - readErr = Read(reader, table, "test-source", ch) - close(ch) - }() - totalCount := 0 - reverseTransformer := &ReverseTransformer{} - for row := range ch { - gotCqtypes, err := reverseTransformer.ReverseTransformValues(table, row) - if err != nil { - 
t.Fatal(err) - } - if diff := cqtypes.Diff(gotCqtypes); diff != "" { - t.Fatalf("got diff: %s", diff) - } - totalCount++ - } - if readErr != nil { - t.Fatal(readErr) - } - if totalCount != 1 { - t.Fatalf("expected 1 row, got %d", totalCount) + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var b bytes.Buffer + table := testdata.TestTable("test") + cqtypes := testdata.GenTestData(table) + if err := cqtypes[0].Set("test-source"); err != nil { + t.Fatal(err) + } + writer := bufio.NewWriter(&b) + reader := bufio.NewReader(&b) + transformer := &Transformer{} + transformedValues := schema.TransformWithTransformer(transformer, cqtypes) + client, err := NewClient(tc.options...) + if err != nil { + t.Fatal(err) + } + + if err := client.WriteTableBatch(writer, table, [][]any{transformedValues}); err != nil { + t.Fatal(err) + } + writer.Flush() + + ch := make(chan []any) + var readErr error + go func() { + readErr = client.Read(reader, table, "test-source", ch) + close(ch) + }() + totalCount := 0 + reverseTransformer := &ReverseTransformer{} + for row := range ch { + // NOTE(review): Read consumes the header row itself when + // IncludeHeaders is set, so every row received here is a + // data row and must be validated; skipping the first row + // would leave the with_headers cases asserting nothing. + gotCqtypes, err := reverseTransformer.ReverseTransformValues(table, row) + if err != nil { + t.Fatal(err) + } + if diff := cqtypes.Diff(gotCqtypes); diff != "" { + t.Fatalf("got diff: %s", diff) + } + totalCount++ + } + if readErr != nil { + t.Fatal(readErr) + } + if totalCount != tc.outputCount { + t.Fatalf("expected %d row, got %d", tc.outputCount, totalCount) + } + }) } } diff --git a/json/client.go b/json/client.go new file mode 100644 index 00000000..0fd6045a --- /dev/null +++ b/json/client.go @@ -0,0 +1,14 @@ +package json + +type Option func(*Client) + +type Client struct{} + +func NewClient(options ...Option) (*Client, error) { + c := &Client{} + for _, option := range options { + option(c) + } + + return c, nil +} diff --git a/json/read.go b/json/read.go index 666e49ff..9fcbe811 100644 --- a/json/read.go 
+++ b/json/read.go @@ -11,7 +11,7 @@ import ( const maxJSONSize = 1024 * 1024 * 20 -func Read(f io.Reader, table *schema.Table, sourceName string, res chan<- []any) error { +func (*Client) Read(f io.Reader, table *schema.Table, sourceName string, res chan<- []any) error { sourceNameIndex := table.Columns.Index(schema.CqSourceNameColumn.Name) if sourceNameIndex == -1 { return fmt.Errorf("could not find column %s in table %s", schema.CqSourceNameColumn.Name, table.Name) diff --git a/json/write.go b/json/write.go index a2678825..8f7d76c1 100644 --- a/json/write.go +++ b/json/write.go @@ -7,7 +7,7 @@ import ( "github.com/cloudquery/plugin-sdk/schema" ) -func WriteTableBatch(w io.Writer, table *schema.Table, resources [][]any) error { +func (*Client) WriteTableBatch(w io.Writer, table *schema.Table, resources [][]any) error { for _, resource := range resources { jsonObj := make(map[string]any, len(table.Columns)) for i := range resource { diff --git a/json/write_read_test.go b/json/write_read_test.go index 498995f9..31aca079 100644 --- a/json/write_read_test.go +++ b/json/write_read_test.go @@ -16,19 +16,25 @@ func TestWriteRead(t *testing.T) { if err := cqtypes[0].Set("test-source"); err != nil { t.Fatal(err) } - - writer := bufio.NewWriter(&b) transformer := &Transformer{} transformedValues := schema.TransformWithTransformer(transformer, cqtypes) - if err := WriteTableBatch(writer, table, [][]any{transformedValues}); err != nil { + + writer := bufio.NewWriter(&b) + reader := bufio.NewReader(&b) + + cl, err := NewClient() + if err != nil { + t.Fatal(err) + } + if err := cl.WriteTableBatch(writer, table, [][]any{transformedValues}); err != nil { t.Fatal(err) } writer.Flush() - reader := bufio.NewReader(&b) + ch := make(chan []any) var readErr error go func() { - readErr = Read(reader, table, "test-source", ch) + readErr = cl.Read(reader, table, "test-source", ch) close(ch) }() totalCount := 0