diff --git a/tableauserverclient/server/endpoint/endpoint.py b/tableauserverclient/server/endpoint/endpoint.py index e04acc595..b1a42b20c 100644 --- a/tableauserverclient/server/endpoint/endpoint.py +++ b/tableauserverclient/server/endpoint/endpoint.py @@ -3,7 +3,7 @@ from packaging.version import Version from functools import wraps from xml.etree.ElementTree import ParseError -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Mapping from .exceptions import ( ServerResponseError, @@ -35,15 +35,35 @@ def __init__(self, parent_srv: "Server"): self.parent_srv = parent_srv @staticmethod - def _make_common_headers(auth_token, content_type): - _client_version: Optional[str] = get_versions()["version"] - headers = {} + def set_parameters(http_options, auth_token, content, content_type, parameters) -> Dict[str, Any]: + parameters = parameters or {} + parameters.update(http_options) + if "headers" not in parameters: + parameters["headers"] = {} + if auth_token is not None: - headers[TABLEAU_AUTH_HEADER] = auth_token + parameters["headers"][TABLEAU_AUTH_HEADER] = auth_token if content_type is not None: - headers[CONTENT_TYPE_HEADER] = content_type - headers[USER_AGENT_HEADER] = "Tableau Server Client/{}".format(_client_version) - return headers + parameters["headers"][CONTENT_TYPE_HEADER] = content_type + + Endpoint.set_user_agent(parameters) + if content is not None: + parameters["data"] = content + return parameters or {} + + @staticmethod + def set_user_agent(parameters): + if USER_AGENT_HEADER not in parameters["headers"]: + if USER_AGENT_HEADER in parameters: + parameters["headers"][USER_AGENT_HEADER] = parameters[USER_AGENT_HEADER] + else: + # only set the TSC user agent if not already populated + _client_version: Optional[str] = get_versions()["version"] + parameters["headers"][USER_AGENT_HEADER] = "Tableau Server Client/{}".format(_client_version) + + # result: parameters["headers"]["User-Agent"] is set + # return explicitly for testing only + return parameters def _make_request( self, @@ -54,18 +74,14 @@ def _make_request( content_type: Optional[str] = None, parameters: Optional[Dict[str, Any]] = None, ) -> "Response": - parameters = parameters or {} - if "headers" not in parameters: - parameters["headers"] = {} - parameters.update(self.parent_srv.http_options) - parameters["headers"].update(Endpoint._make_common_headers(auth_token, content_type)) - - if content is not None: - parameters["data"] = content + parameters = Endpoint.set_parameters( + self.parent_srv.http_options, auth_token, content, content_type, parameters + ) - logger.debug("request {}, url: {}".format(method.__name__, url)) + logger.debug("request {}, url: {}".format(method, url)) if content: - logger.debug("request content: {}".format(helpers.strings.redact_xml(content[:1000]))) + redacted = helpers.strings.redact_xml(content[:1000]) + logger.debug("request content: {}".format(redacted)) server_response = method(url, **parameters) self._check_status(server_response, url) diff --git a/tableauserverclient/server/server.py b/tableauserverclient/server/server.py index 5e2dacf33..d2a8b933b 100644 --- a/tableauserverclient/server/server.py +++ b/tableauserverclient/server/server.py @@ -31,6 +31,7 @@ Fileuploads, FlowRuns, Metrics, + Endpoint, ) from .endpoint.exceptions import ( ServerInfoEndpointNotFoundError, @@ -62,6 +63,10 @@ def __init__(self, server_address, use_server_version=False, http_options=None, self._site_id = None self._user_id = None + # TODO: this needs to change to default to https, but without breaking existing code + if not server_address.startswith("https://") and not server_address.startswith("https://"): + server_address = "https://" + server_address + self._server_address: str = server_address self._session_factory = session_factory or requests.session @@ -96,21 +101,17 @@ def __init__(self, server_address, use_server_version=False, http_options=None, if http_options: self.add_http_options(http_options) - self.validate_server_connection() + self.validate_connection_settings() # does not make an actual outgoing request self.version = default_server_version if use_server_version: self.use_server_version() # this makes a server call - def validate_server_connection(self): + def validate_connection_settings(self): try: - if not self._server_address.startswith("https://") and not self._server_address.startswith("https://"): - self._server_address = "https://" + self._server_address - self._session.prepare_request( - requests.Request("GET", url=self._server_address, params=self._http_options) - ) + Endpoint(self).set_parameters(self._http_options, None, None, None, None) except Exception as req_ex: - raise ValueError("Invalid server initialization", req_ex) + raise ValueError("Server connection settings not valid", req_ex) def __repr__(self): return " [Connection: {}, {}]".format(self.baseurl, self.server_info.serverInfo) @@ -143,10 +144,12 @@ def _set_auth(self, site_id, user_id, auth_token): self._auth_token = auth_token def _get_legacy_version(self): - response = self._session.get(self.server_address + "/auth?format=xml") + dest = Endpoint(self) + response = dest._make_request(method=self.session.get, url=self.server_address + "/auth?format=xml") try: info_xml = fromstring(response.content) except ParseError as parseError: + logging.getLogger("TSC.server").info(parseError) logging.getLogger("TSC.server").info( "Could not read server version info. The server may not be running or configured." ) diff --git a/test/http/test_http_requests.py b/test/http/test_http_requests.py index e96879277..bf9292dec 100644 --- a/test/http/test_http_requests.py +++ b/test/http/test_http_requests.py @@ -82,20 +82,20 @@ def test_http_options_not_sequence_fails(self): def test_validate_connection_http(self): url = "http://cookies.com" server = TSC.Server(url) - server.validate_server_connection() + server.validate_connection_settings() self.assertEqual(url, server.server_address) def test_validate_connection_https(self): url = "https://cookies.com" server = TSC.Server(url) - server.validate_server_connection() + server.validate_connection_settings() self.assertEqual(url, server.server_address) def test_validate_connection_no_protocol(self): url = "cookies.com" fixed_url = "http://cookies.com" server = TSC.Server(url) - server.validate_server_connection() + server.validate_connection_settings() self.assertEqual(fixed_url, server.server_address) diff --git a/test/test_endpoint.py b/test/test_endpoint.py index e583a9188..5b6324cab 100644 --- a/test/test_endpoint.py +++ b/test/test_endpoint.py @@ -38,3 +38,21 @@ class FakeResponse(object): server_response = FakeResponse() log = endpoint.log_response_safely(server_response) self.assertTrue(log.find("[Truncated File Contents]") > 0, log) + + def test_set_user_agent_from_options_headers(self): + params = {"User-Agent": "1", "headers": {"User-Agent": "2"}} + result = TSC.server.Endpoint.set_user_agent(params) + # it should use the value under 'headers' if more than one is given + print(result) + print(result["headers"]["User-Agent"]) + self.assertTrue(result["headers"]["User-Agent"] == "2") + + def test_set_user_agent_from_options(self): + params = {"headers": {"User-Agent": "2"}} + result = TSC.server.Endpoint.set_user_agent(params) + self.assertTrue(result["headers"]["User-Agent"] == "2") + + def test_set_user_agent_when_blank(self): + params = {"headers": {}} + result = TSC.server.Endpoint.set_user_agent(params) + self.assertTrue(result["headers"]["User-Agent"].startswith("Tableau Server Client"))