SNOW-1029646: Add WithFileGetStream that supports downloading a file into stream (#1192)

Added the WithFileGetStream context function to support downloading a file into an in-memory stream.
sfc-gh-ext-simba-jl authored Aug 27, 2024
1 parent aa4b7cd commit f434413
Showing 12 changed files with 329 additions and 115 deletions.
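For orientation before the per-file diffs: a minimal sketch of the flow this commit enables, adapted from the doc.go example added below. The snippet is written as package-internal code because getFileToStream is an unexported option field; db is assumed to be an open connection, and the stage path is a placeholder.

	// Sketch (inside package gosnowflake): stream a staged file into memory.
	var streamBuf bytes.Buffer
	ctx := WithFileTransferOptions(context.Background(),
		&SnowflakeFileTransferOptions{getFileToStream: true}) // enable streaming GET
	ctx = WithFileGetStream(ctx, &streamBuf) // destination io.Writer for the GET result
	if _, err := db.ExecContext(ctx, "GET @~/data1.txt.gz file:///tmp/testData"); err != nil {
		return err
	}
	// streamBuf now holds the downloaded bytes instead of a local file.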
34 changes: 24 additions & 10 deletions azure_storage_client.go
@@ -34,6 +34,7 @@ type azureAPI interface {
UploadStream(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error)
UploadFile(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error)
DownloadFile(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error)
DownloadStream(ctx context.Context, o *blob.DownloadStreamOptions) (azblob.DownloadStreamResponse, error)
GetProperties(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error)
}

@@ -276,16 +277,29 @@ func (util *snowflakeAzureClient) nativeDownloadFile(
if meta.mockAzureClient != nil {
blobClient = meta.mockAzureClient
}
f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode)
if err != nil {
return err
}
defer f.Close()
_, err = blobClient.DownloadFile(
context.Background(), f, &azblob.DownloadFileOptions{
Concurrency: uint16(maxConcurrency)})
if err != nil {
return err
if meta.options.getFileToStream {
blobDownloadResponse, err := blobClient.DownloadStream(context.Background(), &azblob.DownloadStreamOptions{})
if err != nil {
return err
}
retryReader := blobDownloadResponse.NewRetryReader(context.Background(), &azblob.RetryReaderOptions{})
defer retryReader.Close()
_, err = meta.dstStream.ReadFrom(retryReader)
if err != nil {
return err
}
} else {
f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode)
if err != nil {
return err
}
defer f.Close()
_, err = blobClient.DownloadFile(
context.Background(), f, &azblob.DownloadFileOptions{
Concurrency: uint16(maxConcurrency)})
if err != nil {
return err
}
}
meta.resStatus = downloaded
return nil
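The streaming branch above swaps DownloadFile for the azblob DownloadStream/NewRetryReader pair. A standalone sketch of that SDK pattern, using the same azure-sdk-for-go calls as the diff; the SAS blob URL is a placeholder:

	package main

	import (
		"bytes"
		"context"
		"fmt"
		"log"

		"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blob"
	)

	func main() {
		// Placeholder pre-signed/SAS blob URL; the driver builds this from stage info.
		client, err := blob.NewClientWithNoCredential(
			"https://acct.blob.core.windows.net/container/key?sas", nil)
		if err != nil {
			log.Fatal(err)
		}
		resp, err := client.DownloadStream(context.Background(), nil)
		if err != nil {
			log.Fatal(err)
		}
		// RetryReader transparently resumes the body on transient network errors.
		rr := resp.NewRetryReader(context.Background(), &blob.RetryReaderOptions{})
		defer rr.Close()

		var buf bytes.Buffer
		if _, err := buf.ReadFrom(rr); err != nil {
			log.Fatal(err)
		}
		fmt.Printf("read %d bytes\n", buf.Len())
	}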
13 changes: 9 additions & 4 deletions azure_storage_client_test.go
@@ -109,10 +109,11 @@ func TestUnitDetectAzureTokenExpireError(t *testing.T) {
}

type azureObjectAPIMock struct {
UploadStreamFunc func(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error)
UploadFileFunc func(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error)
DownloadFileFunc func(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error)
GetPropertiesFunc func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error)
UploadStreamFunc func(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error)
UploadFileFunc func(ctx context.Context, file *os.File, o *azblob.UploadFileOptions) (azblob.UploadFileResponse, error)
DownloadFileFunc func(ctx context.Context, file *os.File, o *blob.DownloadFileOptions) (int64, error)
DownloadStreamFunc func(ctx context.Context, o *blob.DownloadStreamOptions) (azblob.DownloadStreamResponse, error)
GetPropertiesFunc func(ctx context.Context, o *blob.GetPropertiesOptions) (blob.GetPropertiesResponse, error)
}

func (c *azureObjectAPIMock) UploadStream(ctx context.Context, body io.Reader, o *azblob.UploadStreamOptions) (azblob.UploadStreamResponse, error) {
@@ -131,6 +132,10 @@ func (c *azureObjectAPIMock) DownloadFile(ctx context.Context, file *os.File, o
return c.DownloadFileFunc(ctx, file, o)
}

func (c *azureObjectAPIMock) DownloadStream(ctx context.Context, o *blob.DownloadStreamOptions) (azblob.DownloadStreamResponse, error) {
return c.DownloadStreamFunc(ctx, o)
}

func TestUploadFileWithAzureUploadFailedError(t *testing.T) {
info := execResponseStageInfo{
Location: "azblob/storage/users/456/",
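The mock above follows the standard Go function-field pattern: each interface method delegates to a swappable func field so a test can stub only the call it needs. A generic sketch of the idiom with an illustrative interface (names are not from this repo):

	// downloader is a stand-in interface; the real code mocks azureAPI this way.
	type downloader interface {
		DownloadStream(ctx context.Context) ([]byte, error)
	}

	type downloaderMock struct {
		DownloadStreamFunc func(ctx context.Context) ([]byte, error)
	}

	func (m *downloaderMock) DownloadStream(ctx context.Context) ([]byte, error) {
		return m.DownloadStreamFunc(ctx) // delegate to the per-test stub
	}

	func TestStreamingGet(t *testing.T) {
		mock := &downloaderMock{
			DownloadStreamFunc: func(ctx context.Context) ([]byte, error) {
				return []byte("payload"), nil // canned response, no network
			},
		}
		data, err := mock.DownloadStream(context.Background())
		if err != nil || string(data) != "payload" {
			t.Fatalf("unexpected result: %q, %v", data, err)
		}
	}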
28 changes: 24 additions & 4 deletions connection_util.go
@@ -5,6 +5,7 @@ package gosnowflake
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
@@ -88,10 +89,11 @@ func (sc *snowflakeConn) processFileTransfer(
isInternal bool) (
*execResponse, error) {
sfa := snowflakeFileTransferAgent{
sc: sc,
data: &data.Data,
command: query,
options: new(SnowflakeFileTransferOptions),
sc: sc,
data: &data.Data,
command: query,
options: new(SnowflakeFileTransferOptions),
streamBuffer: new(bytes.Buffer),
}
if fs := getFileStream(ctx); fs != nil {
sfa.sourceStream = fs
@@ -112,6 +114,11 @@
if err != nil {
return nil, err
}
if sfa.options.getFileToStream {
if err := writeFileStream(ctx, sfa.streamBuffer); err != nil {
return nil, err
}
}
return data, nil
}

@@ -138,6 +145,19 @@ func getFileTransferOptions(ctx context.Context) *SnowflakeFileTransferOptions {
return o
}

func writeFileStream(ctx context.Context, streamBuf *bytes.Buffer) error {
s := ctx.Value(fileGetStream)
w, ok := s.(io.Writer)
if !ok {
return errors.New("expected an io.Writer")
}
_, err := streamBuf.WriteTo(w)
if err != nil {
return err
}
return nil
}

func (sc *snowflakeConn) populateSessionParameters(parameters []nameValueParameter) {
// other session parameters (not all)
logger.WithContext(sc.ctx).Infof("params: %#v", parameters)
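writeFileStream recovers the io.Writer that WithFileGetStream stored in the context under the fileGetStream key. A minimal self-contained sketch of that context-value pattern; the key and helper names here are illustrative stand-ins:

	package main

	import (
		"bytes"
		"context"
		"errors"
		"fmt"
		"io"
	)

	type ctxKey string

	const getStreamKey ctxKey = "STREAMING_GET_FILE" // illustrative key

	// withGetStream stashes the caller's writer, analogous to WithFileGetStream.
	func withGetStream(ctx context.Context, w io.Writer) context.Context {
		return context.WithValue(ctx, getStreamKey, w)
	}

	// flushTo recovers the writer and drains the buffered download into it.
	func flushTo(ctx context.Context, buf *bytes.Buffer) error {
		w, ok := ctx.Value(getStreamKey).(io.Writer)
		if !ok {
			return errors.New("expected an io.Writer")
		}
		_, err := buf.WriteTo(w)
		return err
	}

	func main() {
		var dst bytes.Buffer
		ctx := withGetStream(context.Background(), &dst)
		src := bytes.NewBufferString("downloaded payload")
		if err := flushTo(ctx, src); err != nil {
			panic(err)
		}
		fmt.Println(dst.String()) // downloaded payload
	}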
13 changes: 13 additions & 0 deletions doc.go
@@ -1254,6 +1254,19 @@ an absolute path rather than a relative path. For example:
db.Query("GET @~ file:///tmp/my_data_file auto_compress=false overwrite=false")
To download a file into an in-memory stream (rather than a file) use code similar to the code below.
var streamBuf bytes.Buffer
ctx := WithFileTransferOptions(context.Background(), &SnowflakeFileTransferOptions{getFileToStream: true})
ctx = WithFileGetStream(ctx, &streamBuf)
sql := "get @~/data1.txt.gz file:///tmp/testData"
dbt.mustExecContext(ctx, sql)
// streamBuf is now filled with the stream. Use bytes.NewReader(streamBuf.Bytes()) to read an
// uncompressed stream, or use gzip.NewReader(&streamBuf) to read a compressed stream.
Note: GET statements are not supported for multi-statement queries.
Specifying temporary directory for encryption and compression:
Putting and getting requires compression and/or encryption, which is done in the OS temporary directory.
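Two notes on the example above: dbt.mustExecContext is a test helper, so application code would issue the GET with ExecContext on an open connection; and because GET results are gzip-compressed by default, reading the buffer usually means wrapping it in a gzip reader. A small standard-library sketch of consuming streamBuf after the GET:

	// streamBuf holds the gzip-compressed payload once the GET completes.
	gz, err := gzip.NewReader(&streamBuf)
	if err != nil {
		return err
	}
	defer gz.Close()
	plain, err := io.ReadAll(gz) // decompressed file contents
	if err != nil {
		return err
	}
	fmt.Printf("decompressed %d bytes\n", len(plain))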
72 changes: 51 additions & 21 deletions encrypt_util.go
@@ -190,46 +190,51 @@ func encryptFile(
return meta, tmpOutputFile.Name(), nil
}

func decryptFile(
func decryptFileKey(
metadata *encryptMetadata,
sfe *snowflakeFileEncryption,
filename string,
chunkSize int,
tmpDir string) (
string, error) {
if chunkSize == 0 {
chunkSize = aes.BlockSize * 4 * 1024
}
sfe *snowflakeFileEncryption) ([]byte, []byte, error) {
decodedKey, err := base64.StdEncoding.DecodeString(sfe.QueryStageMasterKey)
if err != nil {
return "", err
return nil, nil, err
}
keyBytes, err := base64.StdEncoding.DecodeString(metadata.key) // encrypted file key
if err != nil {
return "", err
return nil, nil, err
}
ivBytes, err := base64.StdEncoding.DecodeString(metadata.iv)
if err != nil {
return "", err
return nil, nil, err
}

// decrypt file key
decryptedKey := make([]byte, len(keyBytes))
if err = decryptECB(decryptedKey, keyBytes, decodedKey); err != nil {
return "", err
return nil, nil, err
}
decryptedKey, err = paddingTrim(decryptedKey)
if err != nil {
return "", err
return nil, nil, err
}

// decrypt file
return decryptedKey, ivBytes, err
}

func initCBC(decryptedKey []byte, ivBytes []byte) (cipher.BlockMode, error) {
block, err := aes.NewCipher(decryptedKey)
if err != nil {
return "", err
return nil, err
}
mode := cipher.NewCBCDecrypter(block, ivBytes)

return mode, err
}

func decryptFile(
metadata *encryptMetadata,
sfe *snowflakeFileEncryption,
filename string,
chunkSize int,
tmpDir string) (string, error) {
tmpOutputFile, err := os.CreateTemp(tmpDir, baseName(filename)+"#")
if err != nil {
return "", err
@@ -240,11 +245,37 @@ func decryptFile(
return "", err
}
defer infile.Close()
totalFileSize, err := decryptStream(metadata, sfe, chunkSize, infile, tmpOutputFile)
if err != nil {
return "", err
}
tmpOutputFile.Truncate(int64(totalFileSize))
return tmpOutputFile.Name(), nil
}

func decryptStream(
metadata *encryptMetadata,
sfe *snowflakeFileEncryption,
chunkSize int,
src io.Reader,
out io.Writer) (int, error) {
if chunkSize == 0 {
chunkSize = aes.BlockSize * 4 * 1024
}
decryptedKey, ivBytes, err := decryptFileKey(metadata, sfe)
if err != nil {
return 0, err
}
mode, err := initCBC(decryptedKey, ivBytes)
if err != nil {
return 0, err
}

var totalFileSize int
var prevChunk []byte
for {
chunk := make([]byte, chunkSize)
n, err := infile.Read(chunk)
n, err := src.Read(chunk)
if n == 0 || err != nil {
break
} else if n%aes.BlockSize != 0 {
@@ -255,17 +286,16 @@
totalFileSize += n
chunk = chunk[:n]
mode.CryptBlocks(chunk, chunk)
tmpOutputFile.Write(chunk)
out.Write(chunk)
prevChunk = chunk
}
if err != nil {
return "", err
return 0, err
}
if prevChunk != nil {
totalFileSize -= paddingOffset(prevChunk)
}
tmpOutputFile.Truncate(int64(totalFileSize))
return tmpOutputFile.Name(), nil
return totalFileSize, err
}

type materialDescriptor struct {
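The refactor splits key unwrapping (decryptFileKey), cipher setup (initCBC), and the chunked CBC loop (decryptStream) so the same decryption can target a temp file or an in-memory stream. A self-contained sketch of the chunked-CBC idea with a toy key and IV; it mirrors the shape of decryptStream but is not the driver's exact code:

	package main

	import (
		"bytes"
		"crypto/aes"
		"crypto/cipher"
		"fmt"
		"io"
		"log"
	)

	// decryptCBCStream decrypts src into out in chunkSize pieces and returns the
	// plaintext length after trimming trailing PKCS#5/7 padding.
	func decryptCBCStream(key, iv []byte, chunkSize int, src io.Reader, out io.Writer) (int, error) {
		block, err := aes.NewCipher(key)
		if err != nil {
			return 0, err
		}
		mode := cipher.NewCBCDecrypter(block, iv) // keeps state across chunks

		total := 0
		var lastChunk []byte
		for {
			chunk := make([]byte, chunkSize)
			n, readErr := src.Read(chunk)
			if n > 0 {
				if n%aes.BlockSize != 0 {
					return 0, fmt.Errorf("chunk of %d bytes is not block-aligned", n)
				}
				chunk = chunk[:n]
				mode.CryptBlocks(chunk, chunk) // in-place CBC decryption
				if _, err := out.Write(chunk); err != nil {
					return 0, err
				}
				total += n
				lastChunk = chunk
			}
			if readErr == io.EOF {
				break
			}
			if readErr != nil {
				return 0, readErr
			}
			if n == 0 {
				break
			}
		}
		if len(lastChunk) > 0 {
			total -= int(lastChunk[len(lastChunk)-1]) // strip padding bytes
		}
		return total, nil
	}

	func main() {
		key := bytes.Repeat([]byte{0x11}, 16) // toy 128-bit key
		iv := bytes.Repeat([]byte{0x22}, 16)  // toy IV

		// Build block-aligned ciphertext to feed the stream decryptor.
		plain := []byte("hello, streaming GET")
		pad := aes.BlockSize - len(plain)%aes.BlockSize
		padded := append(plain, bytes.Repeat([]byte{byte(pad)}, pad)...)
		block, _ := aes.NewCipher(key)
		ct := make([]byte, len(padded))
		cipher.NewCBCEncrypter(block, iv).CryptBlocks(ct, padded)

		var out bytes.Buffer
		n, err := decryptCBCStream(key, iv, 4*aes.BlockSize, bytes.NewReader(ct), &out)
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println(string(out.Bytes()[:n])) // hello, streaming GET
	}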
5 changes: 5 additions & 0 deletions file_transfer_agent.go
@@ -91,6 +91,9 @@ type SnowflakeFileTransferOptions struct {
/* streaming PUT */
compressSourceFromStream bool

/* streaming GET */
getFileToStream bool

/* PUT */
putCallback *snowflakeProgressPercentage
putAzureCallback *snowflakeProgressPercentage
@@ -124,6 +127,7 @@ type snowflakeFileTransferAgent struct {
useAccelerateEndpoint bool
presignedURLs []string
options *SnowflakeFileTransferOptions
streamBuffer *bytes.Buffer
}

func (sfa *snowflakeFileTransferAgent) execute() error {
@@ -411,6 +415,7 @@ func (sfa *snowflakeFileTransferAgent) initFileMetadata() error {
name: baseName(fileName),
srcFileName: fileName,
dstFileName: dstFileName,
dstStream: new(bytes.Buffer),
stageLocationType: sfa.stageLocationType,
stageInfo: sfa.stageInfo,
localLocation: sfa.localLocation,
3 changes: 3 additions & 0 deletions file_util.go
@@ -137,6 +137,9 @@ type fileMetadata struct {
srcStream *bytes.Buffer
realSrcStream *bytes.Buffer

/* streaming GET */
dstStream *bytes.Buffer

/* GCS */
presignedURL *url.URL
gcsFileHeaderDigest string
31 changes: 18 additions & 13 deletions gcs_storage_client.go
@@ -322,13 +322,24 @@ func (util *snowflakeGcsClient) nativeDownloadFile(
return meta.lastError
}

f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode)
if err != nil {
return err
}
defer f.Close()
if _, err = io.Copy(f, resp.Body); err != nil {
return err
if meta.options.getFileToStream {
if _, err := io.Copy(meta.dstStream, resp.Body); err != nil {
return err
}
} else {
f, err := os.OpenFile(fullDstFileName, os.O_CREATE|os.O_WRONLY, readWriteFileMode)
if err != nil {
return err
}
defer f.Close()
if _, err = io.Copy(f, resp.Body); err != nil {
return err
}
fi, err := os.Stat(fullDstFileName)
if err != nil {
return err
}
meta.srcFileSize = fi.Size()
}

var encryptMeta encryptMetadata
@@ -348,12 +359,6 @@ func (util *snowflakeGcsClient) nativeDownloadFile(
}
}
}

fi, err := os.Stat(fullDstFileName)
if err != nil {
return err
}
meta.srcFileSize = fi.Size()
meta.resStatus = downloaded
meta.gcsFileHeaderDigest = resp.Header.Get(gcsMetadataSfcDigest)
meta.gcsFileHeaderContentLength = resp.ContentLength
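Both the Azure and GCS download paths now branch between a caller-supplied buffer and a local file before copying the response body. As a design note, the two branches can be viewed as selecting an io.Writer up front; a hedged sketch of that alternative shape (illustrative helpers, not what the commit does):

	// pickDest chooses where a download lands: the caller's in-memory stream
	// when streaming GET is enabled, otherwise a newly created local file.
	func pickDest(toStream bool, stream *bytes.Buffer, path string) (io.Writer, func() error, error) {
		if toStream {
			return stream, func() error { return nil }, nil
		}
		f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0o666)
		if err != nil {
			return nil, nil, err
		}
		return f, f.Close, nil
	}

	// download copies the response body into the chosen destination.
	func download(body io.Reader, dst io.Writer, closeDst func() error) error {
		defer closeDst()
		_, err := io.Copy(dst, body)
		return err
	}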
