diff --git a/llama-index-integrations/readers/llama-index-readers-google/llama_index/readers/google/drive/base.py b/llama-index-integrations/readers/llama-index-readers-google/llama_index/readers/google/drive/base.py index 4934074509e36..0e0bea6015fd6 100644 --- a/llama-index-integrations/readers/llama-index-readers-google/llama_index/readers/google/drive/base.py +++ b/llama-index-integrations/readers/llama-index-readers-google/llama_index/readers/google/drive/base.py @@ -91,6 +91,7 @@ def _get_credentials(self) -> Tuple[Credentials]: def _get_fileids_meta( self, + drive_id: Optional[str] = None, folder_id: Optional[str] = None, file_id: Optional[str] = None, mime_types: Optional[List[str]] = None, @@ -98,6 +99,7 @@ def _get_fileids_meta( ) -> List[List[str]]: """Get file ids present in folder/ file id Args: + drive_id: Drive id of the shared drive in google drive. folder_id: folder id of the folder in google drive. file_id: file id of the file in google drive mime_types: The mimeTypes you want to allow e.g.: "application/vnd.google-apps.document" @@ -134,16 +136,30 @@ def _get_fileids_meta( items = [] # get files taking into account that the results are paginated while True: - results = ( - service.files() - .list( - q=query, - includeItemsFromAllDrives=True, - supportsAllDrives=True, - fields="*", + if drive_id: + results = ( + service.files() + .list( + q=query, + driveId=drive_id, + corpora="drive", + includeItemsFromAllDrives=True, + supportsAllDrives=True, + fields="*", + ) + .execute() + ) + else: + results = ( + service.files() + .list( + q=query, + includeItemsFromAllDrives=True, + supportsAllDrives=True, + fields="*", + ) + .execute() ) - .execute() - ) items.extend(results.get("files", [])) page_token = results.get("nextPageToken", None) if page_token is None: @@ -151,13 +167,23 @@ def _get_fileids_meta( for item in items: if item["mimeType"] == folder_mime_type: - fileids_meta.extend( - self._get_fileids_meta( - folder_id=item["id"], - mime_types=mime_types, - query_string=query_string, + if drive_id: + fileids_meta.extend( + self._get_fileids_meta( + drive_id=drive_id, + folder_id=item["id"], + mime_types=mime_types, + query_string=query_string, + ) + ) + else: + fileids_meta.extend( + self._get_fileids_meta( + folder_id=item["id"], + mime_types=mime_types, + query_string=query_string, + ) ) - ) else: # Check if file doesn't belong to a Shared Drive. "owners" doesn't exist in a Shared Drive is_shared_drive = "driveId" in item @@ -338,6 +364,7 @@ def _load_from_file_ids( def _load_from_folder( self, + drive_id: str, folder_id: str, mime_types: Optional[List[str]], query_string: Optional[str], @@ -345,6 +372,7 @@ def _load_from_folder( """Load data from folder_id. Args: + drive_id: Drive id of the shared drive in google drive. folder_id: folder id of the folder in google drive. mime_types: The mimeTypes you want to allow e.g.: "application/vnd.google-apps.document" query_string: A more generic query string to filter the documents, e.g. "name contains 'test'". @@ -354,6 +382,7 @@ def _load_from_folder( """ try: fileids_meta = self._get_fileids_meta( + drive_id=drive_id, folder_id=folder_id, mime_types=mime_types, query_string=query_string, @@ -366,6 +395,7 @@ def _load_from_folder( def load_data( self, + drive_id: Optional[str] = None, folder_id: Optional[str] = None, file_ids: Optional[List[str]] = None, mime_types: Optional[List[str]] = None, # Deprecated @@ -374,6 +404,7 @@ def load_data( """Load data from the folder id or file ids. Args: + drive_id: Drive id of the shared drive in google drive. folder_id: Folder id of the folder in google drive. file_ids: File ids of the files in google drive. mime_types: The mimeTypes you want to allow e.g.: "application/vnd.google-apps.document" @@ -386,9 +417,11 @@ def load_data( self._creds = self._get_credentials() if folder_id: - return self._load_from_folder(folder_id, mime_types, query_string) + return self._load_from_folder(drive_id, folder_id, mime_types, query_string) elif file_ids: - return self._load_from_file_ids(file_ids, mime_types, query_string) + return self._load_from_file_ids( + drive_id, file_ids, mime_types, query_string + ) else: logger.warning("Either 'folder_id' or 'file_ids' must be provided.") return [] diff --git a/llama-index-integrations/readers/llama-index-readers-google/tests/test_drive_id.py b/llama-index-integrations/readers/llama-index-readers-google/tests/test_drive_id.py new file mode 100644 index 0000000000000..6ffc4e921fa7b --- /dev/null +++ b/llama-index-integrations/readers/llama-index-readers-google/tests/test_drive_id.py @@ -0,0 +1,29 @@ +import unittest +from unittest.mock import MagicMock +from llama_index.readers.google import GoogleDriveReader + + +class TestGoogleDriveReader(unittest.TestCase): + def test_load_data_with_drive_id(self): + # Mock the necessary objects and methods + mock_credentials = MagicMock() + mock_drive = MagicMock() + reader = GoogleDriveReader() + reader._get_credentials = MagicMock(return_value=(mock_credentials, mock_drive)) + reader._load_from_folder = MagicMock(return_value=["document1", "document2"]) + + # Test with a specific drive_id + drive_id = "example_drive_id" + folder_id = "example_folder_id" + result = reader.load_data(drive_id=drive_id, folder_id=folder_id) + + # Assert that the correct methods are called and the correct result is returned + reader._get_credentials.assert_called_once() + reader._load_from_folder.assert_called_once_with( + drive_id, folder_id, None, None + ) + self.assertEqual(result, ["document1", "document2"]) + + +if __name__ == "__main__": + unittest.main()