Skip to content

Commit

Permalink
Added support for different drives (#12146)
Browse files Browse the repository at this point in the history
  • Loading branch information
DiogoMFonseca authored Mar 26, 2024
1 parent 2c92e88 commit c5bd9ed
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,15 @@ def _get_credentials(self) -> Tuple[Credentials]:

def _get_fileids_meta(
self,
drive_id: Optional[str] = None,
folder_id: Optional[str] = None,
file_id: Optional[str] = None,
mime_types: Optional[List[str]] = None,
query_string: Optional[str] = None,
) -> List[List[str]]:
"""Get file ids present in folder/ file id
Args:
drive_id: Drive id of the shared drive in google drive.
folder_id: folder id of the folder in google drive.
file_id: file id of the file in google drive
mime_types: The mimeTypes you want to allow e.g.: "application/vnd.google-apps.document"
Expand Down Expand Up @@ -134,30 +136,54 @@ def _get_fileids_meta(
items = []
# get files taking into account that the results are paginated
while True:
results = (
service.files()
.list(
q=query,
includeItemsFromAllDrives=True,
supportsAllDrives=True,
fields="*",
if drive_id:
results = (
service.files()
.list(
q=query,
driveId=drive_id,
corpora="drive",
includeItemsFromAllDrives=True,
supportsAllDrives=True,
fields="*",
)
.execute()
)
else:
results = (
service.files()
.list(
q=query,
includeItemsFromAllDrives=True,
supportsAllDrives=True,
fields="*",
)
.execute()
)
.execute()
)
items.extend(results.get("files", []))
page_token = results.get("nextPageToken", None)
if page_token is None:
break

for item in items:
if item["mimeType"] == folder_mime_type:
fileids_meta.extend(
self._get_fileids_meta(
folder_id=item["id"],
mime_types=mime_types,
query_string=query_string,
if drive_id:
fileids_meta.extend(
self._get_fileids_meta(
drive_id=drive_id,
folder_id=item["id"],
mime_types=mime_types,
query_string=query_string,
)
)
else:
fileids_meta.extend(
self._get_fileids_meta(
folder_id=item["id"],
mime_types=mime_types,
query_string=query_string,
)
)
)
else:
# Check if file doesn't belong to a Shared Drive. "owners" doesn't exist in a Shared Drive
is_shared_drive = "driveId" in item
Expand Down Expand Up @@ -338,13 +364,15 @@ def _load_from_file_ids(

def _load_from_folder(
self,
drive_id: str,
folder_id: str,
mime_types: Optional[List[str]],
query_string: Optional[str],
) -> List[Document]:
"""Load data from folder_id.
Args:
drive_id: Drive id of the shared drive in google drive.
folder_id: folder id of the folder in google drive.
mime_types: The mimeTypes you want to allow e.g.: "application/vnd.google-apps.document"
query_string: A more generic query string to filter the documents, e.g. "name contains 'test'".
Expand All @@ -354,6 +382,7 @@ def _load_from_folder(
"""
try:
fileids_meta = self._get_fileids_meta(
drive_id=drive_id,
folder_id=folder_id,
mime_types=mime_types,
query_string=query_string,
Expand All @@ -366,6 +395,7 @@ def _load_from_folder(

def load_data(
self,
drive_id: Optional[str] = None,
folder_id: Optional[str] = None,
file_ids: Optional[List[str]] = None,
mime_types: Optional[List[str]] = None, # Deprecated
Expand All @@ -374,6 +404,7 @@ def load_data(
"""Load data from the folder id or file ids.
Args:
drive_id: Drive id of the shared drive in google drive.
folder_id: Folder id of the folder in google drive.
file_ids: File ids of the files in google drive.
mime_types: The mimeTypes you want to allow e.g.: "application/vnd.google-apps.document"
Expand All @@ -386,9 +417,11 @@ def load_data(
self._creds = self._get_credentials()

if folder_id:
return self._load_from_folder(folder_id, mime_types, query_string)
return self._load_from_folder(drive_id, folder_id, mime_types, query_string)
elif file_ids:
return self._load_from_file_ids(file_ids, mime_types, query_string)
return self._load_from_file_ids(
drive_id, file_ids, mime_types, query_string
)
else:
logger.warning("Either 'folder_id' or 'file_ids' must be provided.")
return []
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import unittest
from unittest.mock import MagicMock
from llama_index.readers.google import GoogleDriveReader


class TestGoogleDriveReader(unittest.TestCase):
    """Unit tests for how ``GoogleDriveReader.load_data`` handles ``drive_id``."""

    def test_load_data_with_drive_id(self):
        """``load_data`` passes ``drive_id`` and ``folder_id`` on to the folder loader."""
        expected_docs = ["document1", "document2"]
        shared_drive, target_folder = "example_drive_id", "example_folder_id"

        # Replace the credential lookup and the folder loader with mocks so
        # the test only exercises the dispatch logic inside load_data.
        gdrive_reader = GoogleDriveReader()
        gdrive_reader._get_credentials = MagicMock(
            return_value=(MagicMock(), MagicMock())
        )
        gdrive_reader._load_from_folder = MagicMock(
            return_value=list(expected_docs)
        )

        loaded = gdrive_reader.load_data(
            drive_id=shared_drive, folder_id=target_folder
        )

        # Credentials are fetched exactly once, and the folder loader receives
        # the ids positionally with mime_types/query_string left as None.
        gdrive_reader._get_credentials.assert_called_once()
        gdrive_reader._load_from_folder.assert_called_once_with(
            shared_drive, target_folder, None, None
        )
        self.assertEqual(loaded, expected_docs)


# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit c5bd9ed

Please sign in to comment.