diff --git a/firestore/google/cloud/firestore_v1beta1/document.py b/firestore/google/cloud/firestore_v1beta1/document.py index b3069bdf4753..fc2f0f6c271e 100644 --- a/firestore/google/cloud/firestore_v1beta1/document.py +++ b/firestore/google/cloud/firestore_v1beta1/document.py @@ -422,6 +422,25 @@ def get(self, field_paths=None, transaction=None): [self], field_paths=field_paths, transaction=transaction) return _consume_single_get(snapshot_generator) + def collections(self, page_size=None): + """List subcollections of the current document. + + Args: + page_size (Optional[int]]): Iterator page size. + + Returns: + Sequence[~.firestore_v1beta1.collection.CollectionReference[: + iterator of subcollections of the current document. If the + document does not exist at the time of `snapshot`, the + iterator will be empty + """ + iterator = self._client._firestore_api.list_collection_ids( + self._document_path, page_size=page_size, + metadata=self._client._rpc_metadata) + iterator.document = self + iterator.item_to_value = _item_to_collection_ref + return iterator + class DocumentSnapshot(object): """A snapshot of document data in a Firestore database. @@ -658,3 +677,14 @@ def _first_write_result(write_results): raise ValueError('Expected at least one write result') return write_results[0] + + +def _item_to_collection_ref(iterator, item): + """Convert collection ID to collection ref. + + Args: + iterator (google.api_core.page_iterator.GRPCIterator): + iterator response + item (str): ID of the collection + """ + return iterator.document.collection(item) diff --git a/firestore/tests/system.py b/firestore/tests/system.py index 65348673b3a4..e4346feb9c8b 100644 --- a/firestore/tests/system.py +++ b/firestore/tests/system.py @@ -104,6 +104,28 @@ def test_create_document(client, cleanup): assert stored_data == expected_data +def test_create_document_w_subcollection(client, cleanup): + document_id = 'shun' + unique_resource_id('-') + document = client.document('collek', document_id) + # Add to clean-up before API request (in case ``create()`` fails). + cleanup(document) + + data = { + 'now': firestore.SERVER_TIMESTAMP, + } + document.create(data) + + child_ids = ['child1', 'child2'] + + for child_id in child_ids: + subcollection = document.collection(child_id) + _, subdoc = subcollection.add({'foo': 'bar'}) + cleanup(subdoc) + + children = document.collections() + assert sorted(child.id for child in children) == sorted(child_ids) + + def test_cannot_use_foreign_key(client, cleanup): document_id = 'cannot' + unique_resource_id('-') document = client.document('foreign-key', document_id) diff --git a/firestore/tests/unit/test_document.py b/firestore/tests/unit/test_document.py index e60e1140abe4..401ae0b8b7ca 100644 --- a/firestore/tests/unit/test_document.py +++ b/firestore/tests/unit/test_document.py @@ -511,6 +511,51 @@ def test_get_not_found(self): client.get_all.assert_called_once_with( [document], field_paths=field_paths, transaction=None) + def _collections_helper(self, page_size=None): + from google.api_core import grpc_helpers + from google.cloud.firestore_v1beta1.collection import ( + CollectionReference) + from google.cloud.firestore_v1beta1.gapic.firestore_client import ( + FirestoreClient) + from google.cloud.firestore_v1beta1.proto import firestore_pb2 + + collection_ids = ['coll-1', 'coll-2'] + list_coll_response = firestore_pb2.ListCollectionIdsResponse( + collection_ids=collection_ids) + channel = grpc_helpers.ChannelStub() + api_client = FirestoreClient(channel=channel) + channel.ListCollectionIds.response = list_coll_response + + client = _make_client() + client._firestore_api_internal = api_client + + # Actually make a document and call delete(). + document = self._make_one('where', 'we-are', client=client) + if page_size is not None: + collections = list(document.collections(page_size=page_size)) + else: + collections = list(document.collections()) + + # Verify the response and the mocks. + self.assertEqual(len(collections), len(collection_ids)) + for collection, collection_id in zip(collections, collection_ids): + self.assertIsInstance(collection, CollectionReference) + self.assertEqual(collection.parent, document) + self.assertEqual(collection.id, collection_id) + + request, = channel.ListCollectionIds.requests + self.assertEqual(request.parent, document._document_path) + if page_size is None: + self.assertEqual(request.page_size, 0) + else: + self.assertEqual(request.page_size, page_size) + + def test_collections_wo_page_size(self): + self._collections_helper() + + def test_collections_w_page_size(self): + self._collections_helper(page_size=10) + class TestDocumentSnapshot(unittest.TestCase):