Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for non-Python ipynb notebooks to DABs #1827

Merged
merged 18 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/bundle/generate/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (n *downloader) markNotebookForDownload(ctx context.Context, notebookPath *

ext := notebook.GetExtensionByLanguage(info)

filename := path.Base(*notebookPath) + ext
filename := path.Base(*notebookPath) + string(ext)
targetPath := filepath.Join(n.sourceDir, filename)

n.files[targetPath] = *notebookPath
Expand Down
2 changes: 1 addition & 1 deletion cmd/workspace/workspace/export_dir.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (opts exportDirOptions) callback(ctx context.Context, workspaceFiler filer.
return err
}
objectInfo := info.Sys().(workspace.ObjectInfo)
targetPath += notebook.GetExtensionByLanguage(&objectInfo)
targetPath += string(notebook.GetExtensionByLanguage(&objectInfo))

// Skip file if a file already exists in path.
// os.Stat returns a fs.ErrNotExist if a file does not exist at path.
Expand Down
572 changes: 362 additions & 210 deletions internal/filer_test.go

Large diffs are not rendered by default.

22 changes: 13 additions & 9 deletions internal/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,13 @@ func RequireErrorRun(t *testing.T, args ...string) (bytes.Buffer, bytes.Buffer,
return stdout, stderr, err
}

func readFile(t *testing.T, name string) string {
b, err := os.ReadFile(name)
require.NoError(t, err)

return string(b)
}

func writeFile(t *testing.T, name string, body string) string {
f, err := os.Create(filepath.Join(t.TempDir(), name))
require.NoError(t, err)
Expand Down Expand Up @@ -562,12 +569,10 @@ func setupLocalFiler(t *testing.T) (filer.Filer, string) {
}

func setupWsfsFiler(t *testing.T) (filer.Filer, string) {
t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV"))
ctx, wt := acc.WorkspaceTest(t)

ctx := context.Background()
w := databricks.Must(databricks.NewWorkspaceClient())
tmpdir := TemporaryWorkspaceDir(t, w)
f, err := filer.NewWorkspaceFilesClient(w, tmpdir)
tmpdir := TemporaryWorkspaceDir(t, wt.W)
f, err := filer.NewWorkspaceFilesClient(wt.W, tmpdir)
require.NoError(t, err)

// Check if we can use this API here, skip test if we cannot.
Expand All @@ -581,11 +586,10 @@ func setupWsfsFiler(t *testing.T) (filer.Filer, string) {
}

func setupWsfsExtensionsFiler(t *testing.T) (filer.Filer, string) {
t.Log(GetEnvOrSkipTest(t, "CLOUD_ENV"))
_, wt := acc.WorkspaceTest(t)

w := databricks.Must(databricks.NewWorkspaceClient())
tmpdir := TemporaryWorkspaceDir(t, w)
f, err := filer.NewWorkspaceFilesExtensionsClient(w, tmpdir)
tmpdir := TemporaryWorkspaceDir(t, wt.W)
f, err := filer.NewWorkspaceFilesExtensionsClient(wt.W, tmpdir)
require.NoError(t, err)

return f, tmpdir
Expand Down
27 changes: 27 additions & 0 deletions internal/testdata/notebooks/py1.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(1)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.8.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
27 changes: 27 additions & 0 deletions internal/testdata/notebooks/py2.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(2)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.8.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
25 changes: 25 additions & 0 deletions internal/testdata/notebooks/r1.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(1)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "R",
"language": "R",
"name": "ir"
},
"language_info": {
"name": "R"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
29 changes: 29 additions & 0 deletions internal/testdata/notebooks/r2.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "r"
}
},
"outputs": [],
"source": [
"print(2)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "R",
"language": "R",
"name": "ir"
},
"language_info": {
"name": "R"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
38 changes: 38 additions & 0 deletions internal/testdata/notebooks/scala1.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n"
]
}
],
"source": [
"println(1)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Scala",
"language": "scala",
"name": "scala"
},
"language_info": {
"codemirror_mode": "text/x-scala",
"file_extension": ".sc",
"mimetype": "text/x-scala",
"name": "scala",
"nbconvert_exporter": "script",
"version": "2.13.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
38 changes: 38 additions & 0 deletions internal/testdata/notebooks/scala2.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n"
]
}
],
"source": [
"println(2)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Scala",
"language": "scala",
"name": "scala"
},
"language_info": {
"codemirror_mode": "text/x-scala",
"file_extension": ".sc",
"mimetype": "text/x-scala",
"name": "scala",
"nbconvert_exporter": "script",
"version": "2.13.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
20 changes: 20 additions & 0 deletions internal/testdata/notebooks/sql1.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"select 1"
]
}
],
"metadata": {
"language_info": {
"name": "sql"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
20 changes: 20 additions & 0 deletions internal/testdata/notebooks/sql2.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"select 2"
]
}
],
"metadata": {
"language_info": {
"name": "sql"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
46 changes: 25 additions & 21 deletions libs/filer/workspace_files_extensions_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"io/fs"
"path"
"slices"
"strings"

"github.com/databricks/cli/libs/log"
Expand All @@ -23,14 +24,6 @@ type workspaceFilesExtensionsClient struct {
readonly bool
}

var extensionsToLanguages = map[string]workspace.Language{
".py": workspace.LanguagePython,
".r": workspace.LanguageR,
".scala": workspace.LanguageScala,
".sql": workspace.LanguageSql,
".ipynb": workspace.LanguagePython,
}

type workspaceFileStatus struct {
wsfsFileInfo

Expand All @@ -50,11 +43,16 @@ func (w *workspaceFilesExtensionsClient) stat(ctx context.Context, name string)
// This function returns the stat for the provided notebook. The stat object itself contains the path
// with the extension since it is meant to be used in the context of a fs.FileInfo.
func (w *workspaceFilesExtensionsClient) getNotebookStatByNameWithExt(ctx context.Context, name string) (*workspaceFileStatus, error) {
ext := path.Ext(name)
nameWithoutExt := strings.TrimSuffix(name, ext)
ext := notebook.Extension(path.Ext(name))
nameWithoutExt := strings.TrimSuffix(name, string(ext))

// File name does not have an extension associated with Databricks notebooks, return early.
if _, ok := extensionsToLanguages[ext]; !ok {
if !slices.Contains([]notebook.Extension{
notebook.ExtensionPython,
notebook.ExtensionR,
notebook.ExtensionScala,
notebook.ExtensionSql,
notebook.ExtensionJupyter}, ext) {
return nil, nil
}

Expand All @@ -75,29 +73,35 @@ func (w *workspaceFilesExtensionsClient) getNotebookStatByNameWithExt(ctx contex
return nil, nil
}

// Not the correct language. Return early.
if stat.Language != extensionsToLanguages[ext] {
log.Debugf(ctx, "attempting to determine if %s could be a notebook. Found a notebook at %s but it is not of the correct language. Expected %s but found %s.", name, path.Join(w.root, nameWithoutExt), extensionsToLanguages[ext], stat.Language)
// Not the correct language. Return early. Note: All languages are supported
// for Jupyter notebooks.
if ext != notebook.ExtensionJupyter && stat.Language != notebook.ExtensionToLanguage[ext] {
log.Debugf(ctx, "attempting to determine if %s could be a notebook. Found a notebook at %s but it is not of the correct language. Expected %s but found %s.", name, path.Join(w.root, nameWithoutExt), notebook.ExtensionToLanguage[ext], stat.Language)
return nil, nil
}

// When the extension is .py we expect the export format to be source.
// When the extension is one of .py, .r, .scala or .sql we expect the export format to be source.
// If it's not, return early.
if ext == ".py" && stat.ReposExportFormat != workspace.ExportFormatSource {
if slices.Contains([]notebook.Extension{
notebook.ExtensionPython,
notebook.ExtensionR,
notebook.ExtensionScala,
notebook.ExtensionSql}, ext) &&
stat.ReposExportFormat != workspace.ExportFormatSource {
shreyas-goenka marked this conversation as resolved.
Show resolved Hide resolved
log.Debugf(ctx, "attempting to determine if %s could be a notebook. Found a notebook at %s but it is not exported as a source notebook. Its export format is %s.", name, path.Join(w.root, nameWithoutExt), stat.ReposExportFormat)
return nil, nil
}

// When the extension is .ipynb we expect the export format to be Jupyter.
// If it's not, return early.
if ext == ".ipynb" && stat.ReposExportFormat != workspace.ExportFormatJupyter {
if ext == notebook.ExtensionJupyter && stat.ReposExportFormat != workspace.ExportFormatJupyter {
log.Debugf(ctx, "attempting to determine if %s could be a notebook. Found a notebook at %s but it is not exported as a Jupyter notebook. Its export format is %s.", name, path.Join(w.root, nameWithoutExt), stat.ReposExportFormat)
return nil, nil
}

// Modify the stat object path to include the extension. This stat object will be used
// to return the fs.FileInfo object in the stat method.
stat.Path = stat.Path + ext
stat.Path = stat.Path + string(ext)
return &workspaceFileStatus{
wsfsFileInfo: stat,
nameForWorkspaceAPI: nameWithoutExt,
Expand All @@ -120,13 +124,13 @@ func (w *workspaceFilesExtensionsClient) getNotebookStatByNameWithoutExt(ctx con
ext := notebook.GetExtensionByLanguage(&stat.ObjectInfo)

// If the notebook was exported as a Jupyter notebook, the extension should be .ipynb.
if stat.Language == workspace.LanguagePython && stat.ReposExportFormat == workspace.ExportFormatJupyter {
ext = ".ipynb"
if stat.ReposExportFormat == workspace.ExportFormatJupyter {
ext = notebook.ExtensionJupyter
}

// Modify the stat object path to include the extension. This stat object will be used
// to return the fs.DirEntry object in the ReadDir method.
stat.Path = stat.Path + ext
stat.Path = stat.Path + string(ext)
return &workspaceFileStatus{
wsfsFileInfo: stat,
nameForWorkspaceAPI: name,
Expand Down
Loading
Loading