diff --git a/firestore/google/cloud/firestore_v1beta1/document.py b/firestore/google/cloud/firestore_v1beta1/document.py index 7b8fd6dedb18..b4d6c2fa1312 100644 --- a/firestore/google/cloud/firestore_v1beta1/document.py +++ b/firestore/google/cloud/firestore_v1beta1/document.py @@ -18,7 +18,9 @@ import six +from google.api_core import exceptions from google.cloud.firestore_v1beta1 import _helpers +from google.cloud.firestore_v1beta1.proto import common_pb2 from google.cloud.firestore_v1beta1.watch import Watch @@ -423,9 +425,37 @@ def get(self, field_paths=None, transaction=None): if isinstance(field_paths, six.string_types): raise ValueError( "'field_paths' must be a sequence of paths, not a string.") - snapshot_generator = self._client.get_all( - [self], field_paths=field_paths, transaction=transaction) - return _consume_single_get(snapshot_generator) + + if field_paths is not None: + mask = common_pb2.DocumentMask(field_paths=sorted(field_paths)) + else: + mask = None + + firestore_api = self._client._firestore_api + try: + document_pb = firestore_api.get_document( + self._document_path, + mask=mask, + transaction=_helpers.get_transaction_id(transaction), + metadata=self._client._rpc_metadata) + except exceptions.NotFound: + data = None + exists = False + create_time = None + update_time = None + else: + data = _helpers.decode_dict(document_pb.fields, self._client) + exists = True + create_time = document_pb.create_time + update_time = document_pb.update_time + + return DocumentSnapshot( + reference=self, + data=data, + exists=exists, + read_time=None, # No server read_time available + create_time=create_time, + update_time=update_time) def collections(self, page_size=None): """List subcollections of the current document. diff --git a/firestore/tests/system.py b/firestore/tests/system.py index 62ea42c7ed0e..be391eeeb213 100644 --- a/firestore/tests/system.py +++ b/firestore/tests/system.py @@ -413,7 +413,6 @@ def test_document_get(client, cleanup): write_result = document.create(data) snapshot = document.get() check_snapshot(snapshot, document, data, write_result) - assert_timestamp_less(snapshot.create_time, snapshot.read_time) def test_document_delete(client, cleanup): diff --git a/firestore/tests/unit/test_cross_language.py b/firestore/tests/unit/test_cross_language.py index 3438a838ffa0..9362d874861b 100644 --- a/firestore/tests/unit/test_cross_language.py +++ b/firestore/tests/unit/test_cross_language.py @@ -21,6 +21,7 @@ import pytest from google.protobuf import text_format +from google.cloud.firestore_v1beta1.proto import document_pb2 from google.cloud.firestore_v1beta1.proto import firestore_pb2 from google.cloud.firestore_v1beta1.proto import test_pb2 from google.cloud.firestore_v1beta1.proto import write_pb2 @@ -170,19 +171,18 @@ def test_create_testprotos(test_proto): @pytest.mark.parametrize('test_proto', _GET_TESTPROTOS) def test_get_testprotos(test_proto): testcase = test_proto.get - # XXX this stub currently does nothing because no get testcases have - # is_error; taking this bit out causes the existing tests to fail - # due to a lack of batch getting - try: - testcase.is_error - except AttributeError: - return - else: # pragma: NO COVER - testcase = test_proto.get - firestore_api = _mock_firestore_api() - client, document = _make_client_document(firestore_api, testcase) - call = functools.partial(document.get, None, None) - _run_testcase(testcase, call, firestore_api, client) + firestore_api = mock.Mock(spec=['get_document']) + response = document_pb2.Document() + firestore_api.get_document.return_value = response + client, document = _make_client_document(firestore_api, testcase) + + document.get() # No '.textprotos' for errors, field_paths. + + firestore_api.get_document.assert_called_once_with( + document._document_path, + mask=None, + transaction=None, + metadata=client._rpc_metadata) @pytest.mark.parametrize('test_proto', _SET_TESTPROTOS) diff --git a/firestore/tests/unit/test_document.py b/firestore/tests/unit/test_document.py index c3348fe77af0..75531d92edbe 100644 --- a/firestore/tests/unit/test_document.py +++ b/firestore/tests/unit/test_document.py @@ -463,74 +463,90 @@ def test_delete_with_option(self): ) self._delete_helper(last_update_time=timestamp_pb) - def test_get_w_single_field_path(self): - client = mock.Mock(spec=[]) + def _get_helper( + self, field_paths=None, use_transaction=False, not_found=False): + from google.api_core.exceptions import NotFound + from google.cloud.firestore_v1beta1.proto import common_pb2 + from google.cloud.firestore_v1beta1.proto import document_pb2 + from google.cloud.firestore_v1beta1.transaction import Transaction - document = self._make_one('yellow', 'mellow', client=client) - with self.assertRaises(ValueError): - document.get('foo') + # Create a minimal fake GAPIC with a dummy response. + create_time = 123 + update_time = 234 + firestore_api = mock.Mock(spec=['get_document']) + response = mock.create_autospec(document_pb2.Document) + response.fields = {} + response.create_time = create_time + response.update_time = update_time + + if not_found: + firestore_api.get_document.side_effect = NotFound('testing') + else: + firestore_api.get_document.return_value = response - def test_get_success(self): - # Create a minimal fake client with a dummy response. - response_iterator = iter([mock.sentinel.snapshot]) - client = mock.Mock(spec=['get_all']) - client.get_all.return_value = response_iterator + client = _make_client('donut-base') + client._firestore_api_internal = firestore_api - # Actually make a document and call get(). - document = self._make_one('yellow', 'mellow', client=client) - snapshot = document.get() + document = self._make_one('where', 'we-are', client=client) - # Verify the response and the mocks. - self.assertIs(snapshot, mock.sentinel.snapshot) - client.get_all.assert_called_once_with( - [document], field_paths=None, transaction=None) + if use_transaction: + transaction = Transaction(client) + transaction_id = transaction._id = b'asking-me-2' + else: + transaction = None + + snapshot = document.get( + field_paths=field_paths, transaction=transaction) + + self.assertIs(snapshot.reference, document) + if not_found: + self.assertIsNone(snapshot._data) + self.assertFalse(snapshot.exists) + self.assertIsNone(snapshot.read_time) + self.assertIsNone(snapshot.create_time) + self.assertIsNone(snapshot.update_time) + else: + self.assertEqual(snapshot.to_dict(), {}) + self.assertTrue(snapshot.exists) + self.assertIsNone(snapshot.read_time) + self.assertIs(snapshot.create_time, create_time) + self.assertIs(snapshot.update_time, update_time) + + # Verify the request made to the API + if field_paths is not None: + mask = common_pb2.DocumentMask(field_paths=sorted(field_paths)) + else: + mask = None - def test_get_with_transaction(self): - from google.cloud.firestore_v1beta1.client import Client - from google.cloud.firestore_v1beta1.transaction import Transaction + if use_transaction: + expected_transaction_id = transaction_id + else: + expected_transaction_id = None - # Create a minimal fake client with a dummy response. - response_iterator = iter([mock.sentinel.snapshot]) - client = mock.create_autospec(Client, instance=True) - client.get_all.return_value = response_iterator + firestore_api.get_document.assert_called_once_with( + document._document_path, + mask=mask, + transaction=expected_transaction_id, + metadata=client._rpc_metadata) - # Actually make a document and call get(). - document = self._make_one('yellow', 'mellow', client=client) - transaction = Transaction(client) - transaction._id = b'asking-me-2' - snapshot = document.get(transaction=transaction) + def test_get_not_found(self): + self._get_helper(not_found=True) - # Verify the response and the mocks. - self.assertIs(snapshot, mock.sentinel.snapshot) - client.get_all.assert_called_once_with( - [document], field_paths=None, transaction=transaction) + def test_get_default(self): + self._get_helper() - def test_get_not_found(self): - from google.cloud.firestore_v1beta1.document import DocumentSnapshot + def test_get_w_string_field_path(self): + with self.assertRaises(ValueError): + self._get_helper(field_paths='foo') - # Create a minimal fake client with a dummy response. - read_time = 123 - expected = DocumentSnapshot(None, None, False, read_time, None, None) - response_iterator = iter([expected]) - client = mock.Mock( - _database_string='sprinklez', - spec=['_database_string', 'get_all']) - client.get_all.return_value = response_iterator - - # Actually make a document and call get(). - document = self._make_one('house', 'cowse', client=client) - field_paths = ['x.y', 'x.z', 't'] - snapshot = document.get(field_paths=field_paths) - self.assertIsNone(snapshot.reference) - self.assertIsNone(snapshot._data) - self.assertFalse(snapshot.exists) - self.assertEqual(snapshot.read_time, expected.read_time) - self.assertIsNone(snapshot.create_time) - self.assertIsNone(snapshot.update_time) + def test_get_with_field_path(self): + self._get_helper(field_paths=['foo']) - # Verify the response and the mocks. - client.get_all.assert_called_once_with( - [document], field_paths=field_paths, transaction=None) + def test_get_with_multiple_field_paths(self): + self._get_helper(field_paths=['foo', 'bar.baz']) + + def test_get_with_transaction(self): + self._get_helper(use_transaction=True) def _collections_helper(self, page_size=None): from google.api_core.page_iterator import Iterator