diff --git a/firestore/tests/unit/test_document.py b/firestore/tests/unit/test_document.py index 401ae0b8b7ca..9c079be7c37b 100644 --- a/firestore/tests/unit/test_document.py +++ b/firestore/tests/unit/test_document.py @@ -512,19 +512,28 @@ def test_get_not_found(self): [document], field_paths=field_paths, transaction=None) def _collections_helper(self, page_size=None): - from google.api_core import grpc_helpers + from google.api_core.page_iterator import Iterator + from google.api_core.page_iterator import Page from google.cloud.firestore_v1beta1.collection import ( CollectionReference) from google.cloud.firestore_v1beta1.gapic.firestore_client import ( FirestoreClient) - from google.cloud.firestore_v1beta1.proto import firestore_pb2 + + class _Iterator(Iterator): + + def __init__(self, pages): + super(_Iterator, self).__init__(client=None) + self._pages = pages + + def _next_page(self): + if self._pages: + page, self._pages = self._pages[0], self._pages[1:] + return Page(self, page, self.item_to_value) collection_ids = ['coll-1', 'coll-2'] - list_coll_response = firestore_pb2.ListCollectionIdsResponse( - collection_ids=collection_ids) - channel = grpc_helpers.ChannelStub() - api_client = FirestoreClient(channel=channel) - channel.ListCollectionIds.response = list_coll_response + iterator = _Iterator(pages=[collection_ids]) + api_client = mock.create_autospec(FirestoreClient) + api_client.list_collection_ids.return_value = iterator client = _make_client() client._firestore_api_internal = api_client @@ -543,12 +552,11 @@ def _collections_helper(self, page_size=None): self.assertEqual(collection.parent, document) self.assertEqual(collection.id, collection_id) - request, = channel.ListCollectionIds.requests - self.assertEqual(request.parent, document._document_path) - if page_size is None: - self.assertEqual(request.page_size, 0) - else: - self.assertEqual(request.page_size, page_size) + api_client.list_collection_ids.assert_called_once_with( + document._document_path, + page_size=page_size, + metadata=client._rpc_metadata, + ) def test_collections_wo_page_size(self): self._collections_helper()