feat: Include CSV headers (#30)

Adds the ability to conditionally include a header row in CSV files; the new client also accepts a configurable delimiter.
bbernays authored Jan 18, 2023
1 parent d6f1734 commit 9ab6df8
Showing 8 changed files with 151 additions and 47 deletions.
32 changes: 32 additions & 0 deletions csv/client.go
@@ -0,0 +1,32 @@
package csv

type Options func(*Client)

// Client is a csv client.
type Client struct {
	IncludeHeaders bool
	Delimiter      rune
}

func NewClient(options ...Options) (*Client, error) {
	c := &Client{
		Delimiter: ',',
	}
	for _, option := range options {
		option(c)
	}

	return c, nil
}

func WithHeader() Options {
	return func(c *Client) {
		c.IncludeHeaders = true
	}
}

func WithDelimiter(delimiter rune) Options {
	return func(c *Client) {
		c.Delimiter = delimiter
	}
}
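For context, the new client.go uses the functional options pattern to configure CSV output. A minimal usage sketch (the import path is an assumption based on this repository's layout, not shown in the diff):

package main

import (
	"fmt"

	"github.com/cloudquery/filetypes/csv" // assumed module path
)

func main() {
	// Zero options: comma delimiter, no header row.
	plain, _ := csv.NewClient()
	fmt.Println(plain.IncludeHeaders, plain.Delimiter == ',') // false true

	// Opt in to a header row and tab-separated output.
	tsv, err := csv.NewClient(csv.WithHeader(), csv.WithDelimiter('\t'))
	if err != nil {
		panic(err)
	}
	fmt.Println(tsv.IncludeHeaders, tsv.Delimiter == '\t') // true true
}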
11 changes: 9 additions & 2 deletions csv/read.go
@@ -9,12 +9,19 @@ import (
 	"github.com/cloudquery/plugin-sdk/schema"
 )
 
-func Read(f io.Reader, table *schema.Table, sourceName string, res chan<- []any) error {
-	reader := csv.NewReader(f)
+func (cl *Client) Read(r io.Reader, table *schema.Table, sourceName string, res chan<- []any) error {
+	reader := csv.NewReader(r)
+	reader.Comma = cl.Delimiter
 	sourceNameIndex := table.Columns.Index(schema.CqSourceNameColumn.Name)
 	if sourceNameIndex == -1 {
 		return fmt.Errorf("could not find column %s in table %s", schema.CqSourceNameColumn.Name, table.Name)
 	}
+	if cl.IncludeHeaders {
+		_, err := reader.Read()
+		if err != nil {
+			return err
+		}
+	}
 	for {
 		record, err := reader.Read()
 		if err != nil {
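The new IncludeHeaders branch is the standard encoding/csv idiom for discarding a header row before streaming data. A self-contained sketch of the same behavior, using only the standard library (the table/schema plumbing is omitted):

package main

import (
	"encoding/csv"
	"fmt"
	"strings"
)

func main() {
	input := "name\tage\nalice\t30\n" // first record is a header row
	reader := csv.NewReader(strings.NewReader(input))
	reader.Comma = '\t'

	// Equivalent of the IncludeHeaders branch: read and discard the
	// header record before streaming data rows.
	if _, err := reader.Read(); err != nil {
		panic(err)
	}
	for {
		record, err := reader.Read()
		if err != nil { // io.EOF ends the loop here
			break
		}
		fmt.Println(record) // prints: [alice 30]
	}
}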
23 changes: 22 additions & 1 deletion csv/write.go
@@ -7,8 +7,14 @@ import (
 	"github.com/cloudquery/plugin-sdk/schema"
 )
 
-func WriteTableBatch(w io.Writer, _ *schema.Table, resources [][]any) error {
+func (cl *Client) WriteTableBatch(w io.Writer, table *schema.Table, resources [][]any) error {
 	writer := csv.NewWriter(w)
+	writer.Comma = cl.Delimiter
+	if cl.IncludeHeaders {
+		if err := cl.WriteTableHeaders(w, table); err != nil {
+			return err
+		}
+	}
 	for _, resource := range resources {
 		record := make([]string, len(resource))
 		for i, v := range resource {
@@ -21,3 +27,18 @@ func WriteTableBatch(w io.Writer, _ *schema.Table, resources [][]any) error {
 	writer.Flush()
 	return nil
 }
+
+func (cl *Client) WriteTableHeaders(w io.Writer, table *schema.Table) error {
+	writer := csv.NewWriter(w)
+	writer.Comma = cl.Delimiter
+
+	tableHeaders := make([]string, len(table.Columns))
+	for index, header := range table.Columns {
+		tableHeaders[index] = header.Name
+	}
+	if err := writer.Write(tableHeaders); err != nil {
+		return err
+	}
+	writer.Flush()
+	return nil
+}
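WriteTableHeaders emits the column names as a single record through its own csv.Writer and flushes before the data records follow. A standard-library sketch of the resulting output shape (the column names here are illustrative, not taken from the diff):

package main

import (
	"encoding/csv"
	"os"
)

func main() {
	writer := csv.NewWriter(os.Stdout)
	writer.Comma = '\t'

	// Header record first, mirroring WriteTableHeaders...
	if err := writer.Write([]string{"_cq_source_name", "name", "age"}); err != nil {
		panic(err)
	}
	// ...then data records, mirroring WriteTableBatch.
	if err := writer.Write([]string{"test-source", "alice", "30"}); err != nil {
		panic(err)
	}
	writer.Flush()
	if err := writer.Error(); err != nil {
		panic(err)
	}
}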
98 changes: 61 additions & 37 deletions csv/write_read_test.go
@@ -10,43 +10,67 @@ import (
 )
 
 func TestWriteRead(t *testing.T) {
-	var b bytes.Buffer
-	table := testdata.TestTable("test")
-	cqtypes := testdata.GenTestData(table)
-	if err := cqtypes[0].Set("test-source"); err != nil {
-		t.Fatal(err)
+	cases := []struct {
+		name        string
+		options     []Options
+		outputCount int
+	}{
+		{name: "default", outputCount: 1},
+		{name: "with_headers", options: []Options{WithHeader()}, outputCount: 1},
+		{name: "with_delimiter", options: []Options{WithDelimiter('\t')}, outputCount: 1},
+		{name: "with_delimter_headers", options: []Options{WithDelimiter('\t'), WithHeader()}, outputCount: 1},
 	}
-	writer := bufio.NewWriter(&b)
-	transformer := &Transformer{}
-	transformedValues := schema.TransformWithTransformer(transformer, cqtypes)
-	// schema.TransformWithTransformer(tra)
-	if err := WriteTableBatch(writer, table, [][]any{transformedValues}); err != nil {
-		t.Fatal(err)
-	}
-	writer.Flush()
-	reader := bufio.NewReader(&b)
-	ch := make(chan []any)
-	var readErr error
-	go func() {
-		readErr = Read(reader, table, "test-source", ch)
-		close(ch)
-	}()
-	totalCount := 0
-	reverseTransformer := &ReverseTransformer{}
-	for row := range ch {
-		gotCqtypes, err := reverseTransformer.ReverseTransformValues(table, row)
-		if err != nil {
-			t.Fatal(err)
-		}
-		if diff := cqtypes.Diff(gotCqtypes); diff != "" {
-			t.Fatalf("got diff: %s", diff)
-		}
-		totalCount++
-	}
-	if readErr != nil {
-		t.Fatal(readErr)
-	}
-	if totalCount != 1 {
-		t.Fatalf("expected 1 row, got %d", totalCount)
+
+	for _, tc := range cases {
+		t.Run(tc.name, func(t *testing.T) {
+			var b bytes.Buffer
+			table := testdata.TestTable("test")
+			cqtypes := testdata.GenTestData(table)
+			if err := cqtypes[0].Set("test-source"); err != nil {
+				t.Fatal(err)
+			}
+			writer := bufio.NewWriter(&b)
+			reader := bufio.NewReader(&b)
+			transformer := &Transformer{}
+			transformedValues := schema.TransformWithTransformer(transformer, cqtypes)
+			client, err := NewClient(tc.options...)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			if err := client.WriteTableBatch(writer, table, [][]any{transformedValues}); err != nil {
+				t.Fatal(err)
+			}
+			writer.Flush()
+
+			ch := make(chan []any)
+			var readErr error
+			go func() {
+				readErr = client.Read(reader, table, "test-source", ch)
+				close(ch)
+			}()
+			totalCount := 0
+			reverseTransformer := &ReverseTransformer{}
+			for row := range ch {
+				if client.IncludeHeaders && totalCount == 0 {
+					totalCount++
+					continue
+				}
+				gotCqtypes, err := reverseTransformer.ReverseTransformValues(table, row)
+				if err != nil {
+					t.Fatal(err)
+				}
+				if diff := cqtypes.Diff(gotCqtypes); diff != "" {
+					t.Fatalf("got diff: %s", diff)
+				}
+				totalCount++
+			}
+			if readErr != nil {
+				t.Fatal(readErr)
+			}
+			if totalCount != tc.outputCount {
+				t.Fatalf("expected %d row, got %d", tc.outputCount, totalCount)
+			}
+		})
 	}
 }
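Since each case now runs under t.Run, a single scenario can be exercised with Go's subtest filter, e.g. go test ./csv -run 'TestWriteRead/with_headers' -v (assuming the package lives at ./csv in the module root).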
14 changes: 14 additions & 0 deletions json/client.go
@@ -0,0 +1,14 @@
package json

type Option func(*Client)

type Client struct{}

func NewClient(options ...Option) (*Client, error) {
	c := &Client{}
	for _, option := range options {
		option(c)
	}

	return c, nil
}
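Note that the JSON client accepts no options yet; the constructor appears to exist so both file types expose the same Client-based API, which is why Read and WriteTableBatch below become *Client methods.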
2 changes: 1 addition & 1 deletion json/read.go
@@ -11,7 +11,7 @@ import (
 
 const maxJSONSize = 1024 * 1024 * 20
 
-func Read(f io.Reader, table *schema.Table, sourceName string, res chan<- []any) error {
+func (*Client) Read(f io.Reader, table *schema.Table, sourceName string, res chan<- []any) error {
 	sourceNameIndex := table.Columns.Index(schema.CqSourceNameColumn.Name)
 	if sourceNameIndex == -1 {
 		return fmt.Errorf("could not find column %s in table %s", schema.CqSourceNameColumn.Name, table.Name)
2 changes: 1 addition & 1 deletion json/write.go
@@ -7,7 +7,7 @@ import (
 	"github.com/cloudquery/plugin-sdk/schema"
 )
 
-func WriteTableBatch(w io.Writer, table *schema.Table, resources [][]any) error {
+func (*Client) WriteTableBatch(w io.Writer, table *schema.Table, resources [][]any) error {
 	for _, resource := range resources {
 		jsonObj := make(map[string]any, len(table.Columns))
 		for i := range resource {
16 changes: 11 additions & 5 deletions json/write_read_test.go
@@ -16,19 +16,25 @@ func TestWriteRead(t *testing.T) {
 	if err := cqtypes[0].Set("test-source"); err != nil {
 		t.Fatal(err)
 	}
-
-	writer := bufio.NewWriter(&b)
 	transformer := &Transformer{}
 	transformedValues := schema.TransformWithTransformer(transformer, cqtypes)
-	if err := WriteTableBatch(writer, table, [][]any{transformedValues}); err != nil {
+
+	writer := bufio.NewWriter(&b)
+	reader := bufio.NewReader(&b)
+
+	cl, err := NewClient()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if err := cl.WriteTableBatch(writer, table, [][]any{transformedValues}); err != nil {
 		t.Fatal(err)
 	}
 	writer.Flush()
-	reader := bufio.NewReader(&b)
+
 	ch := make(chan []any)
 	var readErr error
 	go func() {
-		readErr = Read(reader, table, "test-source", ch)
+		readErr = cl.Read(reader, table, "test-source", ch)
 		close(ch)
 	}()
 	totalCount := 0
