From ff4a41754d68791ae0df3593661f3e29583e7b7f Mon Sep 17 00:00:00 2001 From: Rodrigo Primo Date: Fri, 12 Apr 2013 17:49:57 -0300 Subject: [PATCH 001/151] url limit is present in openid 1 and not on openid 2 --- openid/server/server.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/openid/server/server.py b/openid/server/server.py index 5d426ea9..4f37a34c 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -128,7 +128,7 @@ from openid.server.trustroot import TrustRoot, verifyReturnTo from openid.association import Association, default_negotiator, getSecretSize from openid.message import Message, InvalidOpenIDNamespace, \ - OPENID_NS, OPENID2_NS, IDENTIFIER_SELECT, OPENID1_URL_LIMIT + OPENID_NS, OPENID1_NS, OPENID2_NS, IDENTIFIER_SELECT, OPENID1_URL_LIMIT from openid.urinorm import urinorm HTTP_OK = 200 @@ -1043,7 +1043,7 @@ def whichEncoding(self): @change: 2.1.0 added the ENCODE_HTML_FORM response. """ if self.request.mode in BROWSER_REQUEST_MODES: - if self.fields.getOpenIDNamespace() == OPENID2_NS and \ + if self.fields.getOpenIDNamespace() == OPENID1_NS and \ len(self.encodeToURL()) > OPENID1_URL_LIMIT: return ENCODE_HTML_FORM else: @@ -1719,7 +1719,7 @@ def whichEncoding(self): displayed to the user. """ if self.hasReturnTo(): - if self.openid_message.getOpenIDNamespace() == OPENID2_NS and \ + if self.openid_message.getOpenIDNamespace() == OPENID1_NS and \ len(self.encodeToURL()) > OPENID1_URL_LIMIT: return ENCODE_HTML_FORM else: From 562398e374e65fc8cd8c09953d45eace6d31c636 Mon Sep 17 00:00:00 2001 From: Rodrigo Primo Date: Thu, 9 May 2013 10:40:21 -0300 Subject: [PATCH 002/151] better way to check if message is using openid version 1 --- openid/server/server.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/openid/server/server.py b/openid/server/server.py index 4f37a34c..dd7657a9 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -128,7 +128,7 @@ from openid.server.trustroot import TrustRoot, verifyReturnTo from openid.association import Association, default_negotiator, getSecretSize from openid.message import Message, InvalidOpenIDNamespace, \ - OPENID_NS, OPENID1_NS, OPENID2_NS, IDENTIFIER_SELECT, OPENID1_URL_LIMIT + OPENID_NS, OPENID2_NS, IDENTIFIER_SELECT, OPENID1_URL_LIMIT from openid.urinorm import urinorm HTTP_OK = 200 @@ -1043,7 +1043,7 @@ def whichEncoding(self): @change: 2.1.0 added the ENCODE_HTML_FORM response. """ if self.request.mode in BROWSER_REQUEST_MODES: - if self.fields.getOpenIDNamespace() == OPENID1_NS and \ + if self.fields.isOpenID1() and \ len(self.encodeToURL()) > OPENID1_URL_LIMIT: return ENCODE_HTML_FORM else: @@ -1719,7 +1719,7 @@ def whichEncoding(self): displayed to the user. """ if self.hasReturnTo(): - if self.openid_message.getOpenIDNamespace() == OPENID1_NS and \ + if self.openid_message.isOpenID1() and \ len(self.encodeToURL()) > OPENID1_URL_LIMIT: return ENCODE_HTML_FORM else: From 05b759cfb1615d7b18be207b506ec5bfd974c9a8 Mon Sep 17 00:00:00 2001 From: Rodrigo Primo Date: Tue, 28 May 2013 14:58:45 -0300 Subject: [PATCH 003/151] remove unsupported python version 2.5 from travis-ci configuration file --- .travis.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 1c0f4d7f..bc8614df 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,6 @@ language: python python: - - 2.5 - 2.6 - 2.7 From d199d5954834ff2aaa0648ba3a8602dc1ea28713 Mon Sep 17 00:00:00 2001 From: Rodrigo Primo Date: Wed, 29 May 2013 18:20:21 -0300 Subject: [PATCH 004/151] adapt djopenid example to work with django 1.4 (fixes #51) --- examples/djopenid/settings.py | 20 ++++++++++++-------- examples/djopenid/util.py | 12 ++++++------ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/examples/djopenid/settings.py b/examples/djopenid/settings.py index 6d0fe0c2..f2a7c872 100644 --- a/examples/djopenid/settings.py +++ b/examples/djopenid/settings.py @@ -19,12 +19,16 @@ MANAGERS = ADMINS -DATABASE_ENGINE = 'sqlite3' # 'postgresql', 'mysql', 'sqlite3' or 'ado_mssql'. -DATABASE_NAME = '/tmp/test.db' # Or path to database file if using sqlite3. -DATABASE_USER = '' # Not used with sqlite3. -DATABASE_PASSWORD = '' # Not used with sqlite3. -DATABASE_HOST = '' # Set to empty string for localhost. Not used with sqlite3. -DATABASE_PORT = '' # Set to empty string for default. Not used with sqlite3. +DATABASES = { + 'default': { + 'ENGINE': 'django.db.backends.sqlite3', # Add 'postgresql_psycopg2', 'mysql', 'sqlite3' or 'oracle'. + 'NAME': '/tmp/test.db', # Or path to database file if using sqlite3. + 'USER': '', # Not used with sqlite3. + 'PASSWORD': '', # Not used with sqlite3. + 'HOST': '', # Set to empty string for localhost. Not used with sqlite3. + 'PORT': '', # Set to empty string for default. Not used with sqlite3. + } +} # Local time zone for this installation. All choices can be found here: # http://www.postgresql.org/docs/current/static/datetime-keywords.html#DATETIME-TIMEZONE-SET-TABLE @@ -55,8 +59,8 @@ # List of callables that know how to import templates from various sources. TEMPLATE_LOADERS = ( - 'django.template.loaders.filesystem.load_template_source', - 'django.template.loaders.app_directories.load_template_source', + 'django.template.loaders.filesystem.Loader', + 'django.template.loaders.app_directories.Loader', # 'django.template.loaders.eggs.load_template_source', ) diff --git a/examples/djopenid/util.py b/examples/djopenid/util.py index 4f359e14..7b30df13 100644 --- a/examples/djopenid/util.py +++ b/examples/djopenid/util.py @@ -41,7 +41,7 @@ def getOpenIDStore(filestore_path, table_prefix): The result of this function should be passed to the Consumer constructor as the store parameter. """ - if not settings.DATABASE_ENGINE: + if not settings.DATABASES.get('default', {'ENGINE':None}).get('ENGINE'): return FileOpenIDStore(filestore_path) # Possible side-effect: create a database connection if one isn't @@ -55,18 +55,18 @@ def getOpenIDStore(filestore_path, table_prefix): } types = { - 'postgresql': sqlstore.PostgreSQLStore, - 'mysql': sqlstore.MySQLStore, - 'sqlite3': sqlstore.SQLiteStore, + 'django.db.backends.postgresql': sqlstore.PostgreSQLStore, + 'django.db.backends.mysql': sqlstore.MySQLStore, + 'django.db.backends.sqlite3': sqlstore.SQLiteStore, } try: - s = types[settings.DATABASE_ENGINE](connection.connection, + s = types[settings.DATABASES.get('default', {'ENGINE':None}).get('ENGINE')](connection.connection, **tablenames) except KeyError: raise ImproperlyConfigured, \ "Database engine %s not supported by OpenID library" % \ - (settings.DATABASE_ENGINE,) + (settings.DATABASES.get('default', {'ENGINE':None}).get('ENGINE'),) try: s.createTables() From f8e8e8f1132fd7b3a3987e68f6d9dd045ed3e59e Mon Sep 17 00:00:00 2001 From: Rodrigo Primo Date: Thu, 11 Jul 2013 19:51:38 -0300 Subject: [PATCH 005/151] call all test files instead of just the first --- run_tests.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/run_tests.sh b/run_tests.sh index fc79b772..3a2b7249 100755 --- a/run_tests.sh +++ b/run_tests.sh @@ -1,3 +1,3 @@ #!/bin/sh -python openid/test/test*.py +nosetests From 4c8d4af7fbcf962fdee29613582e03979c278367 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Fri, 12 Jul 2013 13:10:19 -0700 Subject: [PATCH 006/151] Fix test broken by 95aa2a9 (change form built-in logging to logging module) --- openid/test/test_association_response.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/openid/test/test_association_response.py b/openid/test/test_association_response.py index cf9d0147..5a68ac31 100644 --- a/openid/test/test_association_response.py +++ b/openid/test/test_association_response.py @@ -176,7 +176,7 @@ def mkTest(expected_session_type, session_type_value): """ def test(self): self._doTest(expected_session_type, session_type_value) - self.failUnlessEqual(0, len(self.messages)) + self.failUnlessLogEmpty() return test @@ -214,9 +214,7 @@ def test_explicitNoEncryption(self): session_type_value='no-encryption', expected_session_type='no-encryption', ) - self.failUnlessEqual(1, len(self.messages)) - self.failUnless(self.messages[0].startswith( - 'WARNING: OpenID server sent "no-encryption"')) + self.failUnlessLogMatches('OpenID server sent "no-encryption"') test_dhSHA1 = mkTest( session_type_value='DH-SHA1', From 45fe1dbad3fd2ed4bd50c5e2f4aedd6a68bba5d6 Mon Sep 17 00:00:00 2001 From: Jeff Dairiki Date: Fri, 12 Jul 2013 13:59:20 -0700 Subject: [PATCH 007/151] Rename TestAuthRequestMixin so that nose doesn't collect it as a test. Nose is apparently quite aggressive about what it considers to be a test. TestAuthRequestMixin was being treated as a test (even though it does not inherit from unittest.TestCase) solely on account of its name. --- openid/test/test_auth_request.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/openid/test/test_auth_request.py b/openid/test/test_auth_request.py index d9e72332..9114823a 100644 --- a/openid/test/test_auth_request.py +++ b/openid/test/test_auth_request.py @@ -23,7 +23,7 @@ def isOPIdentifier(self): class DummyAssoc(object): handle = "assoc-handle" -class TestAuthRequestMixin(support.OpenIDTestMixin): +class AuthRequestTestMixin(support.OpenIDTestMixin): """Mixin for AuthRequest tests for OpenID 1 and 2; DON'T add unittest.TestCase as a base class here.""" @@ -101,7 +101,7 @@ def test_standard(self): self.failUnlessHasIdentifiers( msg, self.endpoint.local_id, self.endpoint.claimed_id) -class TestAuthRequestOpenID2(TestAuthRequestMixin, unittest.TestCase): +class TestAuthRequestOpenID2(AuthRequestTestMixin, unittest.TestCase): preferred_namespace = message.OPENID2_NS def failUnlessHasRealm(self, msg): @@ -151,7 +151,7 @@ def test_opIdentifierSendsIdentifierSelect(self): self.failUnlessHasIdentifiers( msg, message.IDENTIFIER_SELECT, message.IDENTIFIER_SELECT) -class TestAuthRequestOpenID1(TestAuthRequestMixin, unittest.TestCase): +class TestAuthRequestOpenID1(AuthRequestTestMixin, unittest.TestCase): preferred_namespace = message.OPENID1_NS def setUpEndpoint(self): From a3ce4cc90d585b988e39078d3f291e5c250dcb00 Mon Sep 17 00:00:00 2001 From: Leonardo Santagada Date: Wed, 15 Oct 2014 11:28:11 -0300 Subject: [PATCH 008/151] Update README.md removed most requirements as they are not needed on the supported python versions. --- README.md | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/README.md b/README.md index 180b6c4b..17b0cb6a 100644 --- a/README.md +++ b/README.md @@ -9,14 +9,7 @@ This is the Python OpenID library. REQUIREMENTS ============ - - Python 2.3, 2.4, or 2.5. - - - ElementTree. This is included in the Python 2.5 standard library, - but users of earlier versions of Python may need to install it - seperately. - - - pycrypto, if on Python 2.3 and without /dev/urandom, or on Python - 2.3 or 2.4 and you want SHA256. + - Python 2.6, 2.7. INSTALLATION From 48c6dbfb312fc49dfdebdedbe17b16813528b8ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 21 Nov 2017 16:38:50 +0100 Subject: [PATCH 009/151] Fix bug introduced in ff4a417 --- openid/server/server.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/openid/server/server.py b/openid/server/server.py index dd7657a9..681de527 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -1043,8 +1043,9 @@ def whichEncoding(self): @change: 2.1.0 added the ENCODE_HTML_FORM response. """ if self.request.mode in BROWSER_REQUEST_MODES: - if self.fields.isOpenID1() and \ + if self.fields.isOpenID2() and \ len(self.encodeToURL()) > OPENID1_URL_LIMIT: + # Message can be encoded as HTML form only if it's OpenID 2.0. return ENCODE_HTML_FORM else: return ENCODE_URL @@ -1719,8 +1720,9 @@ def whichEncoding(self): displayed to the user. """ if self.hasReturnTo(): - if self.openid_message.isOpenID1() and \ + if self.openid_message.isOpenID2() and \ len(self.encodeToURL()) > OPENID1_URL_LIMIT: + # Message can be encoded as HTML form only if it's OpenID 2.0. return ENCODE_HTML_FORM else: return ENCODE_URL From 1b759dbfe526ac8749fb144c4b4a6b85ed8841c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 21 Nov 2017 16:20:23 +0100 Subject: [PATCH 010/151] Fix broken tests --- .gitignore | 2 ++ .travis.yml | 2 +- openid/test/test_server.py | 2 +- run_tests.sh | 3 +-- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index d5864bab..f1ff221d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ *.pyc *.swp .tox +# Created in tests +/sstore diff --git a/.travis.yml b/.travis.yml index bc8614df..e79365f4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,6 +4,6 @@ python: - 2.6 - 2.7 -before_install: pip install --use-mirrors Django nose twill pycrypto +before_install: pip install Django pycrypto lxml install: python setup.py install script: ./run_tests.sh diff --git a/openid/test/test_server.py b/openid/test/test_server.py index 268f226a..17e4e59f 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -570,7 +570,7 @@ def test_id_res_OpenID2_POST(self): self.failUnless(len(response.encodeToURL()) > OPENID1_URL_LIMIT) self.failUnless(response.whichEncoding() == server.ENCODE_HTML_FORM) webresponse = self.encode(response) - self.failUnlessEqual(webresponse.body, response.toFormMarkup()) + self.assertIn(response.toFormMarkup(), webresponse.body) def test_toFormMarkup(self): request = server.CheckIDRequest( diff --git a/run_tests.sh b/run_tests.sh index 3a2b7249..9cb637fa 100755 --- a/run_tests.sh +++ b/run_tests.sh @@ -1,3 +1,2 @@ #!/bin/sh - -nosetests +python admin/runtests From 12614250d97db739be05177a5c4fa424574e18f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 22 Nov 2017 09:47:19 +0100 Subject: [PATCH 011/151] Drop support for python 2.6 --- .travis.yml | 1 - README.md | 2 +- openid/test/test_discover.py | 4 ---- 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/.travis.yml b/.travis.yml index e79365f4..7799c93e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,6 @@ language: python python: - - 2.6 - 2.7 before_install: pip install Django pycrypto lxml diff --git a/README.md b/README.md index 17b0cb6a..b54b3ed9 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ This is the Python OpenID library. REQUIREMENTS ============ - - Python 2.6, 2.7. + - Python 2.7. INSTALLATION diff --git a/openid/test/test_discover.py b/openid/test/test_discover.py index 80be5cb4..18bdd0ca 100644 --- a/openid/test/test_discover.py +++ b/openid/test/test_discover.py @@ -91,10 +91,6 @@ class TestFetchException(datadriven.DataDrivenTestCase): RuntimeError(), ] - # String exceptions are finally gone from Python 2.6. - if sys.version_info[:2] < (2, 6): - cases.append('oi!') - def __init__(self, exc): datadriven.DataDrivenTestCase.__init__(self, repr(exc)) self.exc = exc From d0b5f037fbb946af39d5092d47729a3a9ff5b248 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 22 Nov 2017 12:03:07 +0100 Subject: [PATCH 012/151] Rewrite 'Urllib2Fetcher' tests --- openid/test/test_fetchers.py | 89 ++++++++++++++++++++++++++++++++++-- 1 file changed, 84 insertions(+), 5 deletions(-) diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index da1eea84..6370a89c 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -2,7 +2,11 @@ import unittest import sys import urllib2 +from urllib import addinfourl import socket +from cStringIO import StringIO + +from mock import Mock from openid import fetchers @@ -12,13 +16,15 @@ def failUnlessResponseExpected(expected, actual): assert expected.final_url == actual.final_url, ( "%r != %r" % (expected.final_url, actual.final_url)) assert expected.status == actual.status - assert expected.body == actual.body + assert expected.body == actual.body, "%r != %r" % (expected.body, actual.body) got_headers = dict(actual.headers) - del got_headers['date'] - del got_headers['server'] + # TODO: Delete these pops + got_headers.pop('date', None) + got_headers.pop('server', None) for k, v in expected.headers.iteritems(): assert got_headers[k] == v, (k, v, got_headers[k]) + def test_fetcher(fetcher, exc, server): def geturl(path): return 'http://%s:%s%s' % (socket.getfqdn(server.server_name), @@ -83,7 +89,6 @@ def plain(path, code): def run_fetcher_tests(server): exc_fetchers = [] for klass, library_name in [ - (fetchers.Urllib2Fetcher, 'urllib2'), (fetchers.CurlHTTPFetcher, 'pycurl'), (fetchers.HTTPLib2Fetcher, 'httplib2'), ]: @@ -278,8 +283,82 @@ def test_notWrapped(self): else: self.fail('Should have raised an exception') + +class TestHandler(urllib2.BaseHandler): + """Urllib2 test handler.""" + + def __init__(self, http_mock): + self.http_mock = http_mock + + def http_open(self, req): + return self.http_mock() + + +class TestUrllib2Fetcher(unittest.TestCase): + """Test `Urllib2Fetcher` class.""" + + fetcher = fetchers.Urllib2Fetcher() + invalid_url_error = ValueError + + def setUp(self): + self.http_mock = Mock(side_effect=[]) + opener = urllib2.OpenerDirector() + opener.add_handler(TestHandler(self.http_mock)) + urllib2.install_opener(opener) + + def tearDown(self): + # Uninstall custom opener + urllib2.install_opener(None) + + def add_response(self, url, status_code, headers, body=None): + response = addinfourl(StringIO(body or ''), headers, url, status_code) + responses = list(self.http_mock.side_effect) + responses.append(response) + self.http_mock.side_effect = responses + + def test_success(self): + # Test success response + self.add_response('http://example.cz/success/', 200, {'Content-Type': 'text/plain'}, 'BODY') + response = self.fetcher.fetch('http://example.cz/success/') + expected = fetchers.HTTPResponse('http://example.cz/success/', 200, {'Content-Type': 'text/plain'}, 'BODY') + failUnlessResponseExpected(expected, response) + + def test_redirect(self): + # Test redirect response - a final response comes from another URL. + self.add_response('http://example.cz/success/', 200, {'Content-Type': 'text/plain'}, 'BODY') + response = self.fetcher.fetch('http://example.cz/redirect/') + expected = fetchers.HTTPResponse('http://example.cz/success/', 200, {'Content-Type': 'text/plain'}, 'BODY') + failUnlessResponseExpected(expected, response) + + def test_error(self): + # Test error responses - returned as obtained + self.add_response('http://example.cz/error/', 500, {'Content-Type': 'text/plain'}, 'BODY') + response = self.fetcher.fetch('http://example.cz/error/') + expected = fetchers.HTTPResponse('http://example.cz/error/', 500, {'Content-Type': 'text/plain'}, 'BODY') + failUnlessResponseExpected(expected, response) + + def test_invalid_url(self): + with self.assertRaisesRegexp(self.invalid_url_error, 'Bad URL scheme:'): + self.fetcher.fetch('invalid://example.cz/') + + def test_connection_error(self): + # Test connection error + self.http_mock.side_effect = urllib2.HTTPError('http://example.cz/error/', 500, 'Error message', + {'Content-Type': 'text/plain'}, StringIO('BODY')) + response = self.fetcher.fetch('http://example.cz/error/') + expected = fetchers.HTTPResponse('http://example.cz/error/', 500, {'Content-Type': 'text/plain'}, 'BODY') + failUnlessResponseExpected(expected, response) + + +class TestSilencedUrllib2Fetcher(TestUrllib2Fetcher): + """Test silenced `Urllib2Fetcher` class.""" + + fetcher = fetchers.ExceptionWrappingFetcher(fetchers.Urllib2Fetcher()) + invalid_url_error = fetchers.HTTPFetchingError + + def pyUnitTests(): case1 = unittest.FunctionTestCase(test) loadTests = unittest.defaultTestLoader.loadTestsFromTestCase case2 = loadTests(DefaultFetcherTest) - return unittest.TestSuite([case1, case2]) + return unittest.TestSuite([case1, case2, loadTests(TestUrllib2Fetcher), loadTests(TestSilencedUrllib2Fetcher)]) From a612136235d6734f751669b22fb1c660c36f1cf3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Mon, 29 Nov 2010 12:01:17 +0100 Subject: [PATCH 013/151] Assoc_type is required for protocol 2.0 --- openid/server/server.py | 10 +++++++++- openid/test/test_server.py | 14 ++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/openid/server/server.py b/openid/server/server.py index 681de527..9a9a15b4 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -425,12 +425,21 @@ def fromMessage(klass, message, op_endpoint=UNUSED): 'assocaition session type. Continuing anyway.') elif not session_type: session_type = 'no-encryption' + + # in 1.0 assoc_type has default + assoc_type = message.getArg(OPENID_NS, 'assoc_type', 'HMAC-SHA1') else: session_type = message.getArg(OPENID2_NS, 'session_type') if session_type is None: raise ProtocolError(message, text="session_type missing from request") + # in 2.0 assoc_type is required + assoc_type = message.getArg(OPENID2_NS, 'assoc_type') + if assoc_type is None: + raise ProtocolError(message, + text="assoc_type missing from request") + try: session_class = klass.session_classes[session_type] except KeyError: @@ -443,7 +452,6 @@ def fromMessage(klass, message, op_endpoint=UNUSED): raise ProtocolError(message, 'Error parsing %s session: %s' % (session_class.session_type, why[0])) - assoc_type = message.getArg(OPENID_NS, 'assoc_type', 'HMAC-SHA1') if assoc_type not in session.allowed_assoc_types: fmt = 'Session type %s does not support association type %s' raise ProtocolError(message, fmt % (session_type, assoc_type)) diff --git a/openid/test/test_server.py b/openid/test/test_server.py index 17e4e59f..797c051f 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -709,6 +709,7 @@ def test_cancelToForm(self): def test_assocReply(self): msg = Message(OPENID2_NS) msg.setArg(OPENID2_NS, 'session_type', 'no-encryption') + msg.setArg(OPENID2_NS, 'assoc_type', 'HMAC-SHA1') request = server.AssociateRequest.fromMessage(msg) response = server.OpenIDResponse(request) response.fields = Message.fromPostArgs( @@ -834,6 +835,7 @@ def test_cancel(self): def test_assocReply(self): msg = Message(OPENID2_NS) msg.setArg(OPENID2_NS, 'session_type', 'no-encryption') + msg.setArg(OPENID2_NS, 'assoc_type', 'HMAC-SHA1') request = server.AssociateRequest.fromMessage(msg) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({'assoc_handle': "every-zig"}) @@ -1702,6 +1704,7 @@ def test_associate2(self): msg = Message.fromPostArgs({ 'openid.ns': OPENID2_NS, 'openid.session_type': 'no-encryption', + 'openid.assoc_type': 'HMAC-SHA1', }) request = server.AssociateRequest.fromMessage(msg) @@ -1724,6 +1727,7 @@ def test_associate3(self): msg = Message.fromPostArgs({ 'openid.ns': OPENID2_NS, 'openid.session_type': 'no-encryption', + 'openid.assoc_type': 'HMAC-SHA1', }) request = server.AssociateRequest.fromMessage(msg) @@ -1766,6 +1770,16 @@ def test_missingSessionTypeOpenID2(self): self.assertRaises(server.ProtocolError, server.AssociateRequest.fromMessage, msg) + def test_missingAssocTypeOpenID2(self): + """Make sure assoc_type is required in OpenID 2""" + msg = Message.fromPostArgs({ + 'openid.ns': OPENID2_NS, + 'openid.session_type': 'no-encryption', + }) + + self.assertRaises(server.ProtocolError, + server.AssociateRequest.fromMessage, msg) + def test_checkAuth(self): request = server.CheckAuthRequest('arrrrrf', '0x3999', []) response = self.server.openid_check_authentication(request) From 229bcc3054d04251cbbebba0539929637d7b0d24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Thu, 10 Feb 2011 15:19:20 +0100 Subject: [PATCH 014/151] Introduce InvalidNamespace exception --- openid/message.py | 9 +++++++-- openid/server/server.py | 8 +++++++- openid/test/test_message.py | 21 +++++++++++++++++++++ 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/openid/message.py b/openid/message.py index b287d2e2..35ca22a2 100644 --- a/openid/message.py +++ b/openid/message.py @@ -72,6 +72,11 @@ def __str__(self): s += " %r" % (self.args[0],) return s +class InvalidNamespace(KeyError): + """ + Raised if there is problem with other namespaces than OpenID namespace + """ + # Sentinel used for Message implementation to indicate that getArg # should raise an exception instead of returning a default. @@ -582,7 +587,7 @@ def addAlias(self, namespace_uri, desired_alias, implicit=False): desired_alias, current_namespace_uri, desired_alias) - raise KeyError(msg) + raise InvalidNamespace(msg) # Check that there is not already a (different) alias for # this namespace URI @@ -590,7 +595,7 @@ def addAlias(self, namespace_uri, desired_alias, implicit=False): if alias is not None and alias != desired_alias: fmt = ('Cannot map %r to alias %r. ' 'It is already mapped to alias %r') - raise KeyError(fmt % (namespace_uri, desired_alias, alias)) + raise InvalidNamespace(fmt % (namespace_uri, desired_alias, alias)) assert (desired_alias == NULL_NAMESPACE or type(desired_alias) in [str, unicode]), repr(desired_alias) diff --git a/openid/server/server.py b/openid/server/server.py index 9a9a15b4..db550328 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -127,7 +127,7 @@ from openid.store.nonce import mkNonce from openid.server.trustroot import TrustRoot, verifyReturnTo from openid.association import Association, default_negotiator, getSecretSize -from openid.message import Message, InvalidOpenIDNamespace, \ +from openid.message import Message, InvalidOpenIDNamespace, InvalidNamespace, \ OPENID_NS, OPENID2_NS, IDENTIFIER_SELECT, OPENID1_URL_LIMIT from openid.urinorm import urinorm @@ -1444,6 +1444,12 @@ def decode(self, query): query['openid.ns'] = OPENID2_NS message = Message.fromPostArgs(query) raise ProtocolError(message, str(err)) + except InvalidNamespace, err: + # If openid.ns is OK, but there is problem with other namespaces + # We keep only bare parts of query and we try to make a ProtocolError from it + query = [(key, value) for key, value in query.items() if key.count('.') < 2] + message = Message.fromPostArgs(dict(query)) + raise ProtocolError(message, str(err)) mode = message.getArg(OPENID_NS, 'mode') if not mode: diff --git a/openid/test/test_message.py b/openid/test/test_message.py index 3c176ae2..44d6c790 100644 --- a/openid/test/test_message.py +++ b/openid/test/test_message.py @@ -700,6 +700,27 @@ def test_112B(self): self.assertEqual(args, m.toPostArgs()) self.failUnless(m.isOpenID2()) + def test_repetitive_namespaces(self): + """ + Message that raises KeyError during encoding, because openid namespace is used in attributes + """ + args = {'openid.assoc_handle': 'fa1f5ff0-cde4-11dc-a183-3714bfd55ca8', + 'openid.claimed_id': 'http://binkley.lan/user/test01', + 'openid.identity': 'http://test01.binkley.lan/', + 'openid.mode': 'id_res', + 'openid.ns': 'http://specs.openid.net/auth/2.0', + 'openid.op_endpoint': 'http://binkley.lan/server', + 'openid.response_nonce': '2008-01-28T21:07:04Z99Q=', + 'openid.return_to': 'http://binkley.lan:8001/process?janrain_nonce=2008-01-28T21%3A07%3A02Z0tMIKx', + 'openid.sig': 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', + 'openid.signed': 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,ns.pape,pape.nist_auth_level,pape.auth_policies', + 'openid.ns.pape': 'http://specs.openid.net/auth/2.0', + 'openid.pape.auth_policies': 'none', + 'openid.pape.auth_time': '2008-01-28T20:42:36Z', + 'openid.pape.nist_auth_level': '0', + } + self.failUnlessRaises(message.InvalidNamespace, message.Message.fromPostArgs, args) + def test_implicit_sreg_ns(self): openid_args = { 'sreg.email': 'a@b.com' From 06ffd1a9bcfc5f5d8f6a7ddeeae0745b21a273a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Mon, 29 Nov 2010 11:47:53 +0100 Subject: [PATCH 015/151] Add usual methods to StoreRequest and StoreResponse --- openid/extensions/ax.py | 60 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/openid/extensions/ax.py b/openid/extensions/ax.py index 65d0a512..a718e6dd 100644 --- a/openid/extensions/ax.py +++ b/openid/extensions/ax.py @@ -739,6 +739,36 @@ def getExtensionArgs(self): ax_args.update(kv_args) return ax_args + def fromOpenIDRequest(cls, openid_request): + """Extract a StoreRequest from an OpenID message + + @param openid_request: The OpenID authentication request + containing the attribute fetch request + @type openid_request: C{L{openid.server.server.CheckIDRequest}} + + @rtype: C{L{StoreRequest}} or C{None} + @returns: The StoreRequest extracted from the message or None, if + the message contained no AX extension. + + @raises KeyError: if the AuthRequest is not consistent in its use + of namespace aliases. + + @raises AXError: When parseExtensionArgs would raise same. + + @see: L{parseExtensionArgs} + """ + message = openid_request.message + ax_args = message.getArgs(cls.ns_uri) + self = cls() + try: + self.parseExtensionArgs(ax_args) + except NotAXMessage, err: + return None + + return self + + fromOpenIDRequest = classmethod(fromOpenIDRequest) + class StoreResponse(AXMessage): """An indication that the store request was processed along with @@ -772,3 +802,33 @@ def getExtensionArgs(self): ax_args['error'] = self.error_message return ax_args + + def fromSuccessResponse(cls, success_response, signed=True): + """Construct a StoreResponse object from an OpenID library + SuccessResponse object. + + @param success_response: A successful id_res response object + @type success_response: openid.consumer.consumer.SuccessResponse + + @param signed: Whether non-signed args should be + processsed. If True (the default), only signed arguments + will be processsed. + @type signed: bool + + @returns: A StoreResponse containing the data from the OpenID + message, or None if the SuccessResponse did not contain AX + extension data. + + @raises AXError: when the AX data cannot be parsed. + """ + self = cls() + ax_args = success_response.extensionResponse(self.ns_uri, signed) + + try: + self.parseExtensionArgs(ax_args) + except NotAXMessage, err: + return None + else: + return self + + fromSuccessResponse = classmethod(fromSuccessResponse) From 6ebea90538a058922e82021ad0af0f1d78f3dc3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 19 Aug 2011 13:33:54 +0200 Subject: [PATCH 016/151] Fix missing port on building discovery URL for realm with wildcard --- openid/server/trustroot.py | 6 +++++- openid/test/test_rpverify.py | 6 ++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/openid/server/trustroot.py b/openid/server/trustroot.py index 843ddc46..f51c035f 100644 --- a/openid/server/trustroot.py +++ b/openid/server/trustroot.py @@ -338,7 +338,11 @@ def buildDiscoveryURL(self): # Use "www." in place of the star assert self.host.startswith('.'), self.host www_domain = 'www' + self.host - return '%s://%s%s' % (self.proto, www_domain, self.path) + if self.port: + port = ':%s' % self.port + else: + port = '' + return '%s://%s%s%s' % (self.proto, www_domain, port, self.path) else: return self.unparsed diff --git a/openid/test/test_rpverify.py b/openid/test/test_rpverify.py index 9d781bb1..e84d7af4 100644 --- a/openid/test/test_rpverify.py +++ b/openid/test/test_rpverify.py @@ -36,6 +36,12 @@ def test_wildcard(self): self.failUnlessDiscoURL('http://*.example.com/foo', 'http://www.example.com/foo') + def test_wildcard_port(self): + """There is a wildcard + """ + self.failUnlessDiscoURL('http://*.example.com:8001/foo', + 'http://www.example.com:8001/foo') + class TestExtractReturnToURLs(unittest.TestCase): disco_url = 'http://example.com/' From efc49289a1bf383bb6ff75ac3f33d4cd23712e73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Mon, 12 Sep 2011 13:07:08 +0200 Subject: [PATCH 017/151] Fix case when return URL contains parameter without value --- openid/consumer/consumer.py | 2 +- openid/test/test_consumer.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index 91e6d75a..18c36a2c 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -852,7 +852,7 @@ def _verifyReturnToArgs(query): parsed_url = urlparse(return_to) rt_query = parsed_url[4] - parsed_args = cgi.parse_qsl(rt_query) + parsed_args = cgi.parse_qsl(rt_query, keep_blank_values=True) for rt_key, rt_value in parsed_args: try: diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index 33a75647..b6e6ec5b 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -1083,6 +1083,15 @@ def test_returnToArgsOkay(self): # no return value, success is assumed if there are no exceptions. self.consumer._verifyReturnToArgs(query) + def test_returnToEmptyArg(self): + query = { + 'openid.mode': 'id_res', + 'openid.return_to': 'http://example.com/?foo=', + 'foo': '', + } + # no return value, success is assumed if there are no exceptions. + self.consumer._verifyReturnToArgs(query) + def test_returnToArgsUnexpectedArg(self): query = { 'openid.mode': 'id_res', From 44ed430eb790c776e6d097fd55594ec71bfe92cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 22 Nov 2017 12:48:39 +0100 Subject: [PATCH 018/151] Add backport for Server encoding classes 4c17264 --- openid/server/server.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/openid/server/server.py b/openid/server/server.py index db550328..9bef5892 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -1520,14 +1520,18 @@ class Server(object): associations I can make and how. @type negotiator: L{openid.association.SessionNegotiator} """ - + + signatoryClass = Signatory + encoderClass = SigningEncoder + decoderClass = Decoder + def __init__( self, store, op_endpoint=None, - signatoryClass=Signatory, - encoderClass=SigningEncoder, - decoderClass=Decoder): + signatoryClass=None, + encoderClass=None, + decoderClass=None): """A new L{Server}. @param store: The back-end where my associations are stored. @@ -1543,8 +1547,23 @@ def __init__( if you want to respond to any version 2 OpenID requests. """ self.store = store + if signatoryClass is None: + signatoryClass = self.signatoryClass + if signatoryClass != Server.signatoryClass: + warnings.warn("Attribute signatoryClass on Server class is deprecated." + "Use signatoryClass argument of __init__ instead.", DeprecationWarning) self.signatory = signatoryClass(self.store) + if encoderClass is None: + encoderClass = self.encoderClass + if encoderClass != Server.encoderClass: + warnings.warn("Attribute encoderClass on Server class is deprecated." + "Use encoderClass argument of __init__ instead.", DeprecationWarning) self.encoder = encoderClass(self.signatory) + if decoderClass is None: + decoderClass = self.decoderClass + if decoderClass != Server.decoderClass: + warnings.warn("Attribute decoderClass on Server class is deprecated." + "Use decoderClass argument of __init__ instead.", DeprecationWarning) self.decoder = decoderClass(self) self.negotiator = default_negotiator.copy() From 189a2092b20c7e6f61c0b35c73626b7513e3100b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 22 Nov 2017 13:31:54 +0100 Subject: [PATCH 019/151] Add Makefile --- .gitattributes | 1 + .gitignore | 2 ++ Makefile | 10 ++++++++++ 3 files changed, 13 insertions(+) create mode 100644 .gitattributes create mode 100644 Makefile diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..1aadc140 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +/Makefile whitespace=space-before-tab,indent-with-non-tab,tabwidth=4 diff --git a/.gitignore b/.gitignore index f1ff221d..faa1bf6e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,6 @@ *.swp .tox # Created in tests +/.coverage +/htmlcov /sstore diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..161b4211 --- /dev/null +++ b/Makefile @@ -0,0 +1,10 @@ +.PHONY: test coverage + +test: + python admin/runtests + +coverage: + python-coverage erase + -rm -r htmlcov + python-coverage run --branch --source="." admin/runtests + python-coverage html --directory=htmlcov From fe13eb842bb3958c55f29c1cd0e244af3c4d2c3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 22 Nov 2017 13:24:46 +0100 Subject: [PATCH 020/151] Clean imports by isort --- .isort.cfg | 5 +++ .travis.yml | 6 ++-- Makefile | 8 ++++- admin/builddiscover.py | 1 + admin/gettlds.py | 4 +-- examples/consumer.py | 21 ++++++------ examples/djopenid/consumer/views.py | 4 +-- examples/djopenid/manage.py | 1 + examples/djopenid/server/tests.py | 13 ++++---- examples/djopenid/server/views.py | 17 +++++----- examples/djopenid/util.py | 12 +++---- examples/djopenid/views.py | 5 +-- examples/server.py | 20 ++++++------ openid/association.py | 4 +-- openid/consumer/consumer.py | 21 ++++-------- openid/consumer/discover.py | 20 ++++-------- openid/cryptutil.py | 2 +- openid/dh.py | 4 +-- openid/extension.py | 1 + openid/extensions/ax.py | 2 +- openid/extensions/draft/pape2.py | 3 +- openid/extensions/draft/pape5.py | 5 +-- openid/extensions/sreg.py | 6 ++-- openid/fetchers.py | 4 +-- openid/kvform.py | 3 +- openid/message.py | 6 ++-- openid/oidutil.py | 3 +- openid/server/server.py | 15 ++++----- openid/server/trustroot.py | 8 ++--- openid/sreg.py | 5 +-- openid/store/filestore.py | 14 ++++---- openid/store/memstore.py | 5 +-- openid/store/nonce.py | 7 ++-- openid/store/sqlstore.py | 3 +- openid/test/cryptutil.py | 4 +-- openid/test/datadriven.py | 3 +- openid/test/dh.py | 2 ++ openid/test/discoverdata.py | 4 +-- openid/test/kvform.py | 4 ++- openid/test/linkparse.py | 6 ++-- openid/test/oidutil.py | 6 ++-- openid/test/storetest.py | 14 ++++---- openid/test/support.py | 6 ++-- openid/test/test_accept.py | 4 ++- openid/test/test_association.py | 27 +++++++--------- openid/test/test_association_response.py | 12 +++---- openid/test/test_auth_request.py | 3 +- openid/test/test_ax.py | 6 ++-- openid/test/test_consumer.py | 41 ++++++++++-------------- openid/test/test_discover.py | 17 +++++----- openid/test/test_etxrd.py | 6 ++-- openid/test/test_examples.py | 13 +++++--- openid/test/test_extension.py | 6 ++-- openid/test/test_fetchers.py | 10 +++--- openid/test/test_htmldiscover.py | 4 ++- openid/test/test_message.py | 10 +++--- openid/test/test_negotiation.py | 8 +++-- openid/test/test_nonce.py | 9 ++---- openid/test/test_openidyadis.py | 3 +- openid/test/test_pape.py | 3 +- openid/test/test_pape_draft2.py | 3 +- openid/test/test_pape_draft5.py | 5 +-- openid/test/test_parsehtml.py | 7 ++-- openid/test/test_rpverify.py | 8 +++-- openid/test/test_server.py | 14 ++++---- openid/test/test_sreg.py | 5 +-- openid/test/test_symbol.py | 1 + openid/test/test_urinorm.py | 2 ++ openid/test/test_verifydisco.py | 5 +-- openid/test/test_xri.py | 2 ++ openid/test/test_xrires.py | 2 ++ openid/test/test_yadis_discover.py | 9 +++--- openid/test/trustroot.py | 2 ++ openid/yadis/discover.py | 6 ++-- openid/yadis/etxrd.py | 6 ++-- openid/yadis/filters.py | 1 + openid/yadis/parsehtml.py | 2 +- openid/yadis/services.py | 5 +-- openid/yadis/xrires.py | 3 +- setup.py | 2 +- 80 files changed, 308 insertions(+), 271 deletions(-) create mode 100644 .isort.cfg diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 00000000..6c8243e1 --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,5 @@ +[settings] +line_length = 120 +combine_as_imports = true +known_third_party = mock,twill +known_first_party = openid diff --git a/.travis.yml b/.travis.yml index 7799c93e..14636d8c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,6 +3,8 @@ language: python python: - 2.7 -before_install: pip install Django pycrypto lxml +before_install: pip install Django pycrypto lxml isort install: python setup.py install -script: ./run_tests.sh +script: + - make check-isort + - make test diff --git a/Makefile b/Makefile index 161b4211..8b779090 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: test coverage +.PHONY: test coverage isort check-isort test: python admin/runtests @@ -8,3 +8,9 @@ coverage: -rm -r htmlcov python-coverage run --branch --source="." admin/runtests python-coverage html --directory=htmlcov + +isort: + isort --recursive . + +check-isort: + isort --check-only --diff --recursive . diff --git a/admin/builddiscover.py b/admin/builddiscover.py index d065c0a5..011ab883 100755 --- a/admin/builddiscover.py +++ b/admin/builddiscover.py @@ -4,6 +4,7 @@ from openid.test import discoverdata + manifest_header = """\ # This file contains test cases for doing YADIS identity URL and # service discovery. For each case, there are three URLs. The first diff --git a/admin/gettlds.py b/admin/gettlds.py index 43006380..f473224d 100644 --- a/admin/gettlds.py +++ b/admin/gettlds.py @@ -8,10 +8,8 @@ Then cut-n-paste. """ - -import urllib2 - import sys +import urllib2 langs = { 'php': (r"'/\.(", diff --git a/examples/consumer.py b/examples/consumer.py index 1c38a623..c4f299c0 100644 --- a/examples/consumer.py +++ b/examples/consumer.py @@ -8,17 +8,18 @@ """ __copyright__ = 'Copyright 2005-2008, Janrain, Inc.' -from Cookie import SimpleCookie import cgi -import urlparse import cgitb import sys +import urlparse +from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer +from Cookie import SimpleCookie + def quoteattr(s): qs = cgi.escape(s, 1) return '"%s"' % (qs,) -from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler try: import openid @@ -32,14 +33,14 @@ def quoteattr(s): For more information, see the README in the root of the library distribution.""") sys.exit(1) +else: + from openid.consumer import consumer + from openid.cryptutil import randomString + from openid.extensions import pape, sreg + from openid.fetchers import Urllib2Fetcher, setDefaultFetcher + from openid.oidutil import appendArgs + from openid.store import filestore, memstore -from openid.store import memstore -from openid.store import filestore -from openid.consumer import consumer -from openid.oidutil import appendArgs -from openid.cryptutil import randomString -from openid.fetchers import setDefaultFetcher, Urllib2Fetcher -from openid.extensions import pape, sreg # Used with an OpenID provider affiliate program. OPENID_PROVIDER_NAME = 'MyOpenID' diff --git a/examples/djopenid/consumer/views.py b/examples/djopenid/consumer/views.py index c8992947..1f4dd945 100644 --- a/examples/djopenid/consumer/views.py +++ b/examples/djopenid/consumer/views.py @@ -6,10 +6,10 @@ from openid.consumer import consumer from openid.consumer.discover import DiscoveryFailure from openid.extensions import ax, pape, sreg -from openid.yadis.constants import YADIS_HEADER_NAME, YADIS_CONTENT_TYPE from openid.server.trustroot import RP_RETURN_TO_URL_TYPE +from openid.yadis.constants import YADIS_CONTENT_TYPE, YADIS_HEADER_NAME -from djopenid import util +from .. import util PAPE_POLICIES = [ 'AUTH_PHISHING_RESISTANT', diff --git a/examples/djopenid/manage.py b/examples/djopenid/manage.py index 5e78ea97..ae949585 100644 --- a/examples/djopenid/manage.py +++ b/examples/djopenid/manage.py @@ -1,5 +1,6 @@ #!/usr/bin/env python from django.core.management import execute_manager + try: import settings # Assumed to be in the same directory. except ImportError: diff --git a/examples/djopenid/server/tests.py b/examples/djopenid/server/tests.py index e7ddd06e..d86151bc 100644 --- a/examples/djopenid/server/tests.py +++ b/examples/djopenid/server/tests.py @@ -1,16 +1,17 @@ -from django.test.testcases import TestCase -from djopenid.server import views -from djopenid import util - -from django.http import HttpRequest from django.contrib.sessions.middleware import SessionWrapper +from django.http import HttpRequest +from django.test.testcases import TestCase -from openid.server.server import CheckIDRequest from openid.message import Message +from openid.server.server import CheckIDRequest from openid.yadis.constants import YADIS_CONTENT_TYPE from openid.yadis.services import applyFilter +from .. import util +from ..server import views + + def dummyRequest(): request = HttpRequest() request.session = SessionWrapper("test") diff --git a/examples/djopenid/server/views.py b/examples/djopenid/server/views.py index 67fa00b1..bb6d6602 100644 --- a/examples/djopenid/server/views.py +++ b/examples/djopenid/server/views.py @@ -17,20 +17,19 @@ import cgi -from djopenid import util -from djopenid.util import getViewURL - from django import http from django.views.generic.simple import direct_to_template -from openid.server.server import Server, ProtocolError, CheckIDRequest, \ - EncodingError -from openid.server.trustroot import verifyReturnTo -from openid.yadis.discover import DiscoveryFailure from openid.consumer.discover import OPENID_IDP_2_0_TYPE -from openid.extensions import sreg -from openid.extensions import pape +from openid.extensions import pape, sreg from openid.fetchers import HTTPFetchingError +from openid.server.server import CheckIDRequest, EncodingError, ProtocolError, Server +from openid.server.trustroot import verifyReturnTo +from openid.yadis.discover import DiscoveryFailure + +from .. import util +from ..util import getViewURL + def getOpenIDStore(): """ diff --git a/examples/djopenid/util.py b/examples/djopenid/util.py index 7b30df13..f06e11fb 100644 --- a/examples/djopenid/util.py +++ b/examples/djopenid/util.py @@ -5,20 +5,20 @@ from urlparse import urljoin -from django.db import connection -from django.template.context import RequestContext -from django.template import loader from django import http +from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.core.urlresolvers import reverse as reverseURL +from django.db import connection +from django.template import loader +from django.template.context import RequestContext from django.views.generic.simple import direct_to_template -from django.conf import settings - -from openid.store.filestore import FileOpenIDStore from openid.store import sqlstore +from openid.store.filestore import FileOpenIDStore from openid.yadis.constants import YADIS_CONTENT_TYPE + def getOpenIDStore(filestore_path, table_prefix): """ Returns an OpenID association store object based on the database diff --git a/examples/djopenid/views.py b/examples/djopenid/views.py index 5d399d60..3f08324d 100644 --- a/examples/djopenid/views.py +++ b/examples/djopenid/views.py @@ -1,7 +1,9 @@ -from djopenid import util from django.views.generic.simple import direct_to_template +from . import util + + def index(request): consumer_url = util.getViewURL( request, 'djopenid.consumer.views.startOpenID') @@ -11,4 +13,3 @@ def index(request): request, 'index.html', {'consumer_url':consumer_url, 'server_url':server_url}) - diff --git a/examples/server.py b/examples/server.py index 3adc61b5..ddbe5e45 100644 --- a/examples/server.py +++ b/examples/server.py @@ -2,19 +2,20 @@ __copyright__ = 'Copyright 2005-2008, Janrain, Inc.' -from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler -from urlparse import urlparse - -import time -import Cookie import cgi import cgitb +import Cookie import sys +import time +from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer +from urlparse import urlparse + def quoteattr(s): qs = cgi.escape(s, 1) return '"%s"' % (qs,) + try: import openid except ImportError: @@ -27,11 +28,12 @@ def quoteattr(s): For more information, see the README in the root of the library distribution.""") sys.exit(1) +else: + from openid.consumer import discover + from openid.extensions import sreg + from openid.server import server + from openid.store.filestore import FileOpenIDStore -from openid.extensions import sreg -from openid.server import server -from openid.store.filestore import FileOpenIDStore -from openid.consumer import discover class OpenIDHTTPServer(HTTPServer): """ diff --git a/openid/association.py b/openid/association.py index e1429ff7..f9cc91e4 100644 --- a/openid/association.py +++ b/openid/association.py @@ -34,9 +34,7 @@ import time -from openid import cryptutil -from openid import kvform -from openid import oidutil +from openid import cryptutil, kvform, oidutil from openid.message import OPENID_NS all_association_types = [ diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index 18c36a2c..4b5dfce0 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -190,23 +190,16 @@ import cgi import copy import logging -from urlparse import urlparse, urldefrag - -from openid import fetchers - -from openid.consumer.discover import discover, OpenIDServiceEndpoint, \ - DiscoveryFailure, OPENID_1_0_TYPE, OPENID_1_1_TYPE, OPENID_2_0_TYPE -from openid.message import Message, OPENID_NS, OPENID2_NS, OPENID1_NS, \ - IDENTIFIER_SELECT, no_default, BARE_NS -from openid import cryptutil -from openid import oidutil -from openid.association import Association, default_negotiator, \ - SessionNegotiator +from urlparse import urldefrag, urlparse + +from openid import cryptutil, fetchers, oidutil, urinorm +from openid.association import Association, SessionNegotiator, default_negotiator +from openid.consumer.discover import (OPENID_1_0_TYPE, OPENID_1_1_TYPE, OPENID_2_0_TYPE, DiscoveryFailure, + OpenIDServiceEndpoint, discover) from openid.dh import DiffieHellman +from openid.message import BARE_NS, IDENTIFIER_SELECT, OPENID1_NS, OPENID2_NS, OPENID_NS, Message, no_default from openid.store.nonce import mkNonce, split as splitNonce from openid.yadis.manager import Discovery -from openid import urinorm - __all__ = ['AuthRequest', 'Consumer', 'SuccessResponse', 'SetupNeededResponse', 'CancelResponse', 'FailureResponse', diff --git a/openid/consumer/discover.py b/openid/consumer/discover.py index a30e7872..e4b9e639 100644 --- a/openid/consumer/discover.py +++ b/openid/consumer/discover.py @@ -13,20 +13,16 @@ 'discover', ] -import urlparse import logging +import urlparse -from openid import fetchers, urinorm - -from openid import yadis -from openid.yadis.etxrd import nsTag, XRDSError, XRD_NS_2_0 -from openid.yadis.services import applyFilter as extractServices -from openid.yadis.discover import discover as yadisDiscover -from openid.yadis.discover import DiscoveryFailure -from openid.yadis import xrires, filters -from openid.yadis import xri - +from openid import fetchers, urinorm, yadis from openid.consumer import html_parse +from openid.message import OPENID1_NS as OPENID_1_0_MESSAGE_NS, OPENID2_NS as OPENID_2_0_MESSAGE_NS +from openid.yadis import filters, xri, xrires +from openid.yadis.discover import DiscoveryFailure, discover as yadisDiscover +from openid.yadis.etxrd import XRD_NS_2_0, XRDSError, nsTag +from openid.yadis.services import applyFilter as extractServices OPENID_1_0_NS = 'http://openid.net/xmlns/1.0' OPENID_IDP_2_0_TYPE = 'http://specs.openid.net/auth/2.0/server' @@ -34,8 +30,6 @@ OPENID_1_1_TYPE = 'http://openid.net/signon/1.1' OPENID_1_0_TYPE = 'http://openid.net/signon/1.0' -from openid.message import OPENID1_NS as OPENID_1_0_MESSAGE_NS -from openid.message import OPENID2_NS as OPENID_2_0_MESSAGE_NS class OpenIDServiceEndpoint(object): """Object representing an OpenID service endpoint. diff --git a/openid/cryptutil.py b/openid/cryptutil.py index 0ac3ce3d..868877a9 100644 --- a/openid/cryptutil.py +++ b/openid/cryptutil.py @@ -30,7 +30,7 @@ import os import random -from openid.oidutil import toBase64, fromBase64 +from openid.oidutil import fromBase64, toBase64 try: import hashlib diff --git a/openid/dh.py b/openid/dh.py index bb83bbe8..3478240b 100644 --- a/openid/dh.py +++ b/openid/dh.py @@ -1,5 +1,5 @@ -from openid import cryptutil -from openid import oidutil +from openid import cryptutil, oidutil + def strxor(x, y): if len(x) != len(y): diff --git a/openid/extension.py b/openid/extension.py index d48bbb2f..6366f03d 100644 --- a/openid/extension.py +++ b/openid/extension.py @@ -1,5 +1,6 @@ from openid import message as message_module + class Extension(object): """An interface for OpenID extensions. diff --git a/openid/extensions/ax.py b/openid/extensions/ax.py index a718e6dd..6b21812b 100644 --- a/openid/extensions/ax.py +++ b/openid/extensions/ax.py @@ -13,8 +13,8 @@ ] from openid import extension +from openid.message import OPENID_NS, NamespaceMap from openid.server.trustroot import TrustRoot -from openid.message import NamespaceMap, OPENID_NS # Use this as the 'count' value for an attribute in a FetchRequest to # ask for as many values as the OP can provide. diff --git a/openid/extensions/draft/pape2.py b/openid/extensions/draft/pape2.py index e7320465..b800ce2b 100644 --- a/openid/extensions/draft/pape2.py +++ b/openid/extensions/draft/pape2.py @@ -15,9 +15,10 @@ 'AUTH_MULTI_FACTOR_PHYSICAL', ] -from openid.extension import Extension import re +from openid.extension import Extension + ns_uri = "http://specs.openid.net/extensions/pape/1.0" AUTH_MULTI_FACTOR_PHYSICAL = \ diff --git a/openid/extensions/draft/pape5.py b/openid/extensions/draft/pape5.py index 3bd1ffc0..e1468736 100644 --- a/openid/extensions/draft/pape5.py +++ b/openid/extensions/draft/pape5.py @@ -17,9 +17,10 @@ 'LEVELS_JISA', ] -from openid.extension import Extension -import warnings import re +import warnings + +from openid.extension import Extension ns_uri = "http://specs.openid.net/extensions/pape/1.0" diff --git a/openid/extensions/sreg.py b/openid/extensions/sreg.py index c66a8b08..87e46fa3 100644 --- a/openid/extensions/sreg.py +++ b/openid/extensions/sreg.py @@ -35,11 +35,11 @@ namespace and XRD Type value """ -from openid.message import registerNamespaceAlias, \ - NamespaceAliasRegistrationError -from openid.extension import Extension import logging +from openid.extension import Extension +from openid.message import NamespaceAliasRegistrationError, registerNamespaceAlias + try: basestring #pylint:disable-msg=W0104 except NameError: diff --git a/openid/fetchers.py b/openid/fetchers.py index 1c119a45..d4b80290 100644 --- a/openid/fetchers.py +++ b/openid/fetchers.py @@ -7,10 +7,10 @@ 'HTTPFetcher', 'createHTTPFetcher', 'HTTPFetchingError', 'HTTPError'] -import urllib2 -import time import cStringIO import sys +import time +import urllib2 import openid import openid.urinorm diff --git a/openid/kvform.py b/openid/kvform.py index 38258550..846cf74c 100644 --- a/openid/kvform.py +++ b/openid/kvform.py @@ -1,7 +1,8 @@ __all__ = ['seqToKV', 'kvToSeq', 'dictToKV', 'kvToDict'] -import types import logging +import types + class KVFormError(ValueError): pass diff --git a/openid/message.py b/openid/message.py index 35ca22a2..92706d93 100644 --- a/openid/message.py +++ b/openid/message.py @@ -5,11 +5,11 @@ 'IDENTIFIER_SELECT'] import copy -import warnings import urllib +import warnings + +from openid import kvform, oidutil -from openid import oidutil -from openid import kvform try: ElementTree = oidutil.importElementTree() except ImportError: diff --git a/openid/oidutil.py b/openid/oidutil.py index b109a734..36d0af10 100644 --- a/openid/oidutil.py +++ b/openid/oidutil.py @@ -8,10 +8,9 @@ __all__ = ['log', 'appendArgs', 'toBase64', 'fromBase64', 'autoSubmitHTML', 'toUnicode'] import binascii +import logging import sys import urlparse -import logging - from urllib import urlencode elementtree_modules = [ diff --git a/openid/server/server.py b/openid/server/server.py index 9bef5892..7cd1ae99 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -116,19 +116,18 @@ @group Response Encodings: ENCODE_KVFORM, ENCODE_HTML_FORM, ENCODE_URL """ -import time, warnings import logging +import time +import warnings from copy import deepcopy -from openid import cryptutil -from openid import oidutil -from openid import kvform +from openid import cryptutil, kvform, oidutil +from openid.association import Association, default_negotiator, getSecretSize from openid.dh import DiffieHellman -from openid.store.nonce import mkNonce +from openid.message import (IDENTIFIER_SELECT, OPENID1_URL_LIMIT, OPENID2_NS, OPENID_NS, InvalidNamespace, + InvalidOpenIDNamespace, Message) from openid.server.trustroot import TrustRoot, verifyReturnTo -from openid.association import Association, default_negotiator, getSecretSize -from openid.message import Message, InvalidOpenIDNamespace, InvalidNamespace, \ - OPENID_NS, OPENID2_NS, IDENTIFIER_SELECT, OPENID1_URL_LIMIT +from openid.store.nonce import mkNonce from openid.urinorm import urinorm HTTP_OK = 200 diff --git a/openid/server/trustroot.py b/openid/server/trustroot.py index f51c035f..49863539 100644 --- a/openid/server/trustroot.py +++ b/openid/server/trustroot.py @@ -17,13 +17,13 @@ 'verifyReturnTo', ] +import logging +import re +from urlparse import urlparse, urlunparse + from openid import urinorm from openid.yadis import services -from urlparse import urlparse, urlunparse -import re -import logging - ############################################ _protocols = ['http', 'https'] _top_level_domains = [ diff --git a/openid/sreg.py b/openid/sreg.py index d665a5d0..bf454d7b 100644 --- a/openid/sreg.py +++ b/openid/sreg.py @@ -1,7 +1,8 @@ """moved to L{openid.extensions.sreg}""" import warnings -warnings.warn("openid.sreg has moved to openid.extensions.sreg", - DeprecationWarning) from openid.extensions.sreg import * + +warnings.warn("openid.sreg has moved to openid.extensions.sreg", + DeprecationWarning) diff --git a/openid/store/filestore.py b/openid/store/filestore.py index 89884aca..adb69dac 100644 --- a/openid/store/filestore.py +++ b/openid/store/filestore.py @@ -3,14 +3,18 @@ flat files. """ -import string +import logging import os import os.path +import string import time -import logging - from errno import EEXIST, ENOENT +from openid import cryptutil, oidutil +from openid.association import Association +from openid.store import nonce +from openid.store.interface import OpenIDStore + try: from tempfile import mkstemp except ImportError: @@ -34,10 +38,6 @@ def mkstemp(dir): raise RuntimeError('Failed to get temp file after 5 attempts') -from openid.association import Association -from openid.store.interface import OpenIDStore -from openid.store import nonce -from openid import cryptutil, oidutil _filename_allowed = string.ascii_letters + string.digits + '.' try: diff --git a/openid/store/memstore.py b/openid/store/memstore.py index e2748fb2..89a16bdc 100644 --- a/openid/store/memstore.py +++ b/openid/store/memstore.py @@ -1,10 +1,11 @@ """A simple store using only in-process memory.""" -from openid.store import nonce - import copy import time +from openid.store import nonce + + class ServerAssocs(object): def __init__(self): self.assocs = {} diff --git a/openid/store/nonce.py b/openid/store/nonce.py index e9337a8a..3814dd1d 100644 --- a/openid/store/nonce.py +++ b/openid/store/nonce.py @@ -4,10 +4,11 @@ 'checkTimestamp', ] -from openid import cryptutil -from time import strptime, strftime, gmtime, time -from calendar import timegm import string +from calendar import timegm +from time import gmtime, strftime, strptime, time + +from openid import cryptutil NONCE_CHARS = string.ascii_letters + string.digits diff --git a/openid/store/sqlstore.py b/openid/store/sqlstore.py index 58c4337e..a629e726 100644 --- a/openid/store/sqlstore.py +++ b/openid/store/sqlstore.py @@ -10,8 +10,9 @@ import time from openid.association import Association -from openid.store.interface import OpenIDStore from openid.store import nonce +from openid.store.interface import OpenIDStore + def _inTxn(func): def wrapped(self, *args, **kwargs): diff --git a/openid/test/cryptutil.py b/openid/test/cryptutil.py index 753596cb..e52b6a3b 100644 --- a/openid/test/cryptutil.py +++ b/openid/test/cryptutil.py @@ -1,6 +1,6 @@ -import sys -import random import os.path +import random +import sys from openid import cryptutil diff --git a/openid/test/datadriven.py b/openid/test/datadriven.py index 2dbcfd0d..c7dc4f70 100644 --- a/openid/test/datadriven.py +++ b/openid/test/datadriven.py @@ -1,5 +1,6 @@ -import unittest import types +import unittest + class DataDrivenTestCase(unittest.TestCase): cases = [] diff --git a/openid/test/dh.py b/openid/test/dh.py index 16b8d560..299730b1 100644 --- a/openid/test/dh.py +++ b/openid/test/dh.py @@ -1,6 +1,8 @@ import os.path + from openid.dh import DiffieHellman, strxor + def test_strxor(): NUL = '\x00' diff --git a/openid/test/discoverdata.py b/openid/test/discoverdata.py index 78b18f73..1d906d8a 100644 --- a/openid/test/discoverdata.py +++ b/openid/test/discoverdata.py @@ -1,9 +1,9 @@ """Module to make discovery data test cases available""" -import urlparse import os.path +import urlparse -from openid.yadis.discover import DiscoveryResult, DiscoveryFailure from openid.yadis.constants import YADIS_HEADER_NAME +from openid.yadis.discover import DiscoveryFailure, DiscoveryResult tests_dir = os.path.dirname(__file__) data_path = os.path.join(tests_dir, 'data') diff --git a/openid/test/kvform.py b/openid/test/kvform.py index 636aa0bf..b54a64b5 100644 --- a/openid/test/kvform.py +++ b/openid/test/kvform.py @@ -1,6 +1,8 @@ +import unittest + from openid import kvform from openid.test.support import CatchLogs -import unittest + class KVBaseTest(unittest.TestCase, CatchLogs): def shortDescription(self): diff --git a/openid/test/linkparse.py b/openid/test/linkparse.py index 04475b7a..adcdfb35 100644 --- a/openid/test/linkparse.py +++ b/openid/test/linkparse.py @@ -1,8 +1,10 @@ -from openid.consumer.html_parse import parseLinkAttrs -import os.path import codecs +import os.path import unittest +from openid.consumer.html_parse import parseLinkAttrs + + def parseLink(line): parts = line.split() optional = parts[0] == 'Link*:' diff --git a/openid/test/oidutil.py b/openid/test/oidutil.py index cc42887b..568f16af 100644 --- a/openid/test/oidutil.py +++ b/openid/test/oidutil.py @@ -1,10 +1,12 @@ # -*- coding: utf-8 -*- -import unittest import codecs -import string import random +import string +import unittest + from openid import oidutil + def test_base64(): allowed_s = string.ascii_letters + string.digits + '+/=' allowed_d = {} diff --git a/openid/test/storetest.py b/openid/test/storetest.py index e428c36d..6d876fc2 100644 --- a/openid/test/storetest.py +++ b/openid/test/storetest.py @@ -1,14 +1,14 @@ +import os +import random +import socket +import string +import time +import unittest + from openid.association import Association from openid.cryptutil import randomString from openid.store.nonce import mkNonce, split -import unittest -import string -import time -import socket -import random -import os - db_host = 'dbtest' allowed_handle = [] diff --git a/openid/test/support.py b/openid/test/support.py index 621b5a66..3901e25d 100644 --- a/openid/test/support.py +++ b/openid/test/support.py @@ -1,6 +1,8 @@ -from openid import message -from logging.handlers import BufferingHandler import logging +from logging.handlers import BufferingHandler + +from openid import message + class TestHandler(BufferingHandler): def __init__(self, messages): diff --git a/openid/test/test_accept.py b/openid/test/test_accept.py index e8d9be00..547e42a6 100644 --- a/openid/test/test_accept.py +++ b/openid/test/test_accept.py @@ -1,7 +1,9 @@ -import unittest import os.path +import unittest + from openid.yadis import accept + def getTestData(): """Read the test data off of disk diff --git a/openid/test/test_association.py b/openid/test/test_association.py index 6404a008..8ab81785 100644 --- a/openid/test/test_association.py +++ b/openid/test/test_association.py @@ -1,13 +1,17 @@ -from openid.test import datadriven - -import unittest - -from openid.message import Message, BARE_NS, OPENID_NS, OPENID2_NS -from openid import association import time -from openid import cryptutil +import unittest import warnings +from openid import association, cryptutil +from openid.consumer.consumer import (DiffieHellmanSHA1ConsumerSession, DiffieHellmanSHA256ConsumerSession, + PlainTextConsumerSession) +from openid.dh import DiffieHellman +from openid.message import BARE_NS, OPENID2_NS, OPENID_NS, Message +from openid.server.server import (DiffieHellmanSHA1ServerSession, DiffieHellmanSHA256ServerSession, + PlainTextServerSession) +from openid.test import datadriven + + class AssociationSerializationTest(unittest.TestCase): def test_roundTrip(self): issued = int(time.time()) @@ -22,17 +26,8 @@ def test_roundTrip(self): self.failUnlessEqual(assoc.lifetime, assoc2.lifetime) self.failUnlessEqual(assoc.assoc_type, assoc2.assoc_type) -from openid.server.server import \ - DiffieHellmanSHA1ServerSession, \ - DiffieHellmanSHA256ServerSession, \ - PlainTextServerSession -from openid.consumer.consumer import \ - DiffieHellmanSHA1ConsumerSession, \ - DiffieHellmanSHA256ConsumerSession, \ - PlainTextConsumerSession -from openid.dh import DiffieHellman def createNonstandardConsumerDH(): nonstandard_dh = DiffieHellman(1315291, 2) diff --git a/openid/test/test_association_response.py b/openid/test/test_association_response.py index 5a68ac31..11161fbb 100644 --- a/openid/test/test_association_response.py +++ b/openid/test/test_association_response.py @@ -3,15 +3,15 @@ This duplicates some things that are covered by test_consumer, but this works for now. """ +import unittest + from openid import oidutil -from openid.test.test_consumer import CatchLogs -from openid.message import Message, OPENID2_NS, OPENID_NS, no_default +from openid.consumer.consumer import DiffieHellmanSHA1ConsumerSession, GenericConsumer, ProtocolError +from openid.consumer.discover import OPENID_1_1_TYPE, OPENID_2_0_TYPE, OpenIDServiceEndpoint +from openid.message import OPENID2_NS, OPENID_NS, Message, no_default from openid.server.server import DiffieHellmanSHA1ServerSession -from openid.consumer.consumer import GenericConsumer, \ - DiffieHellmanSHA1ConsumerSession, ProtocolError -from openid.consumer.discover import OpenIDServiceEndpoint, OPENID_1_1_TYPE, OPENID_2_0_TYPE from openid.store import memstore -import unittest +from openid.test.test_consumer import CatchLogs # Some values we can use for convenience (see mkAssocResponse) association_response_values = { diff --git a/openid/test/test_auth_request.py b/openid/test/test_auth_request.py index 9114823a..1419ab54 100644 --- a/openid/test/test_auth_request.py +++ b/openid/test/test_auth_request.py @@ -1,10 +1,11 @@ import cgi import unittest -from openid.consumer import consumer from openid import message +from openid.consumer import consumer from openid.test import support + class DummyEndpoint(object): preferred_namespace = None local_id = None diff --git a/openid/test/test_ax.py b/openid/test/test_ax.py index 9c349a78..28f90fab 100644 --- a/openid/test/test_ax.py +++ b/openid/test/test_ax.py @@ -2,9 +2,11 @@ """ import unittest -from openid.extensions import ax -from openid.message import NamespaceMap, Message, OPENID2_NS + from openid.consumer.consumer import SuccessResponse +from openid.extensions import ax +from openid.message import OPENID2_NS, Message, NamespaceMap + class BogusAXMessage(ax.AXMessage): mode = 'bogus' diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index b6e6ec5b..4bc51122 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -1,32 +1,27 @@ -import urlparse import cgi import time +import unittest +import urlparse import warnings -from openid.message import Message, OPENID_NS, OPENID2_NS, IDENTIFIER_SELECT, \ - OPENID1_NS, BARE_NS -from openid import cryptutil, dh, oidutil, kvform -from openid.store.nonce import mkNonce, split as splitNonce -from openid.consumer.discover import OpenIDServiceEndpoint, OPENID_2_0_TYPE, \ - OPENID_1_1_TYPE -from openid.consumer.consumer import \ - AuthRequest, GenericConsumer, SUCCESS, FAILURE, CANCEL, SETUP_NEEDED, \ - SuccessResponse, FailureResponse, SetupNeededResponse, CancelResponse, \ - DiffieHellmanSHA1ConsumerSession, Consumer, PlainTextConsumerSession, \ - SetupNeededError, DiffieHellmanSHA256ConsumerSession, ServerError, \ - ProtocolError, _httpResponseToMessage -from openid import association -from openid.server.server import \ - PlainTextServerSession, DiffieHellmanSHA1ServerSession -from openid.yadis.manager import Discovery -from openid.yadis.discover import DiscoveryFailure +from openid import association, cryptutil, dh, fetchers, kvform, oidutil +from openid.consumer.consumer import (CANCEL, FAILURE, SETUP_NEEDED, SUCCESS, AuthRequest, CancelResponse, Consumer, + DiffieHellmanSHA1ConsumerSession, DiffieHellmanSHA256ConsumerSession, + FailureResponse, GenericConsumer, PlainTextConsumerSession, ProtocolError, + ServerError, SetupNeededError, SetupNeededResponse, SuccessResponse, + _httpResponseToMessage) +from openid.consumer.discover import OPENID_1_1_TYPE, OPENID_2_0_TYPE, OpenIDServiceEndpoint from openid.dh import DiffieHellman - -from openid.fetchers import HTTPResponse, HTTPFetchingError -from openid import fetchers +from openid.extension import Extension +from openid.fetchers import HTTPFetchingError, HTTPResponse +from openid.message import BARE_NS, IDENTIFIER_SELECT, OPENID1_NS, OPENID2_NS, OPENID_NS, Message +from openid.server.server import DiffieHellmanSHA1ServerSession, PlainTextServerSession from openid.store import memstore +from openid.store.nonce import mkNonce, split as splitNonce +from openid.yadis.discover import DiscoveryFailure +from openid.yadis.manager import Discovery -from support import CatchLogs +from .support import CatchLogs assocs = [ ('another 20-byte key.', 'Snarky'), @@ -207,7 +202,6 @@ def run(): run() assert fetcher.num_assocs == 2 -import unittest http_server_url = 'http://server.example.com/' consumer_url = 'http://consumer.example.com/' @@ -2044,7 +2038,6 @@ def returnTrue(unused1, unused2): 'http://claimed.id/', [self.to_match]) self.failUnlessEqual(matching_endpoint, result) -from openid.extension import Extension class SillyExtension(Extension): ns_uri = 'http://silly.example.com/' ns_alias = 'silly' diff --git a/openid/test/test_discover.py b/openid/test/test_discover.py index 18bdd0ca..5b6c2996 100644 --- a/openid/test/test_discover.py +++ b/openid/test/test_discover.py @@ -1,16 +1,18 @@ # -*- coding: utf-8 -*- +import os.path import sys import unittest -import datadriven -import os.path -from openid import fetchers -from openid.fetchers import HTTPResponse -from openid.yadis.discover import DiscoveryFailure +import warnings +from urlparse import urlsplit + +from openid import fetchers, message from openid.consumer import discover +from openid.fetchers import HTTPResponse from openid.yadis import xrires +from openid.yadis.discover import DiscoveryFailure from openid.yadis.xri import XRI -from urlparse import urlsplit -from openid import message + +from . import datadriven ### Tests for conditions that trigger DiscoveryFailure @@ -64,7 +66,6 @@ def runOneTest(self): # testing the behaviour in the presence of string exceptions, # deprecated or not, so tell it no to complain when this particular # string exception is raised. -import warnings warnings.filterwarnings('ignore', 'raising a string.*', DeprecationWarning, r'^openid\.test\.test_discover$', 77) diff --git a/openid/test/test_etxrd.py b/openid/test/test_etxrd.py index 51cd27f6..c3ff68ab 100644 --- a/openid/test/test_etxrd.py +++ b/openid/test/test_etxrd.py @@ -1,6 +1,8 @@ -import unittest -from openid.yadis import services, etxrd, xri import os.path +import unittest + +from openid.yadis import etxrd, services, xri + def datapath(filename): module_directory = os.path.dirname(os.path.abspath(__file__)) diff --git a/openid/test/test_examples.py b/openid/test/test_examples.py index 92269d05..ca83d839 100644 --- a/openid/test/test_examples.py +++ b/openid/test/test_examples.py @@ -1,14 +1,19 @@ "Test some examples." +import os.path import socket -import os.path, unittest, sys, time +import sys +import time +import unittest from cStringIO import StringIO -import twill.commands, twill.parse, twill.unit +import twill.commands +import twill.parse +import twill.unit -from openid.consumer.discover import \ - OpenIDServiceEndpoint, OPENID_1_1_TYPE from openid.consumer.consumer import AuthRequest +from openid.consumer.discover import OPENID_1_1_TYPE, OpenIDServiceEndpoint + class TwillTest(twill.unit.TestInfo): """Variant of twill.unit.TestInfo that runs a function as a test script, diff --git a/openid/test/test_extension.py b/openid/test/test_extension.py index 7dadbd0b..11ba1b26 100644 --- a/openid/test/test_extension.py +++ b/openid/test/test_extension.py @@ -1,8 +1,8 @@ -from openid import extension -from openid import message - import unittest +from openid import extension, message + + class DummyExtension(extension.Extension): ns_uri = 'http://an.extension/' ns_alias = 'dummy' diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index 6370a89c..1ec5641f 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -1,10 +1,11 @@ -import warnings -import unittest +import socket import sys +import unittest import urllib2 -from urllib import addinfourl -import socket +import warnings +from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer from cStringIO import StringIO +from urllib import addinfourl from mock import Mock @@ -118,7 +119,6 @@ def run_fetcher_tests(server): for f in non_exc_fetchers: test_fetcher(f, False, server) -from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer class FetcherTestHandler(BaseHTTPRequestHandler): cases = { diff --git a/openid/test/test_htmldiscover.py b/openid/test/test_htmldiscover.py index 0a49e163..e310435d 100644 --- a/openid/test/test_htmldiscover.py +++ b/openid/test/test_htmldiscover.py @@ -1,5 +1,7 @@ from openid.consumer.discover import OpenIDServiceEndpoint -import datadriven + +from . import datadriven + class BadLinksTestCase(datadriven.DataDrivenTestCase): cases = [ diff --git a/openid/test/test_message.py b/openid/test/test_message.py index 44d6c790..571eec95 100644 --- a/openid/test/test_message.py +++ b/openid/test/test_message.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- -from openid import message -from openid import oidutil -from openid.extensions import sreg - -import urllib import cgi import unittest +import urllib + +from openid import message, oidutil +from openid.extensions import sreg + def mkGetArgTest(ns, key, expected=None): def test(self): diff --git a/openid/test/test_negotiation.py b/openid/test/test_negotiation.py index 6245c142..c23ef96e 100644 --- a/openid/test/test_negotiation.py +++ b/openid/test/test_negotiation.py @@ -1,11 +1,13 @@ import unittest -from support import CatchLogs -from openid.message import Message, OPENID2_NS, OPENID1_NS, OPENID_NS from openid import association from openid.consumer.consumer import GenericConsumer, ServerError -from openid.consumer.discover import OpenIDServiceEndpoint, OPENID_2_0_TYPE +from openid.consumer.discover import OPENID_2_0_TYPE, OpenIDServiceEndpoint +from openid.message import OPENID1_NS, OPENID2_NS, OPENID_NS, Message + +from .support import CatchLogs + class ErrorRaisingConsumer(GenericConsumer): """ diff --git a/openid/test/test_nonce.py b/openid/test/test_nonce.py index 2138305c..fe171512 100644 --- a/openid/test/test_nonce.py +++ b/openid/test/test_nonce.py @@ -1,12 +1,9 @@ -from openid.test import datadriven +import re import time import unittest -import re -from openid.store.nonce import \ - mkNonce, \ - split as splitNonce, \ - checkTimestamp +from openid.store.nonce import checkTimestamp, mkNonce, split as splitNonce +from openid.test import datadriven nonce_re = re.compile(r'\A\d{4}-\d\d-\d\dT\d\d:\d\d:\d\dZ') diff --git a/openid/test/test_openidyadis.py b/openid/test/test_openidyadis.py index 8573d3ce..4b7cca4a 100644 --- a/openid/test/test_openidyadis.py +++ b/openid/test/test_openidyadis.py @@ -1,7 +1,6 @@ import unittest -from openid.consumer.discover import \ - OpenIDServiceEndpoint, OPENID_1_1_TYPE, OPENID_1_0_TYPE +from openid.consumer.discover import OPENID_1_0_TYPE, OPENID_1_1_TYPE, OpenIDServiceEndpoint from openid.yadis.services import applyFilter diff --git a/openid/test/test_pape.py b/openid/test/test_pape.py index ef47f60c..0507b2c8 100644 --- a/openid/test/test_pape.py +++ b/openid/test/test_pape.py @@ -1,7 +1,8 @@ +import unittest + from openid.extensions import pape -import unittest class PapeImportTestCase(unittest.TestCase): def test_version(self): diff --git a/openid/test/test_pape_draft2.py b/openid/test/test_pape_draft2.py index ed3d439f..f468015b 100644 --- a/openid/test/test_pape_draft2.py +++ b/openid/test/test_pape_draft2.py @@ -1,9 +1,10 @@ +import unittest + from openid.extensions.draft import pape2 as pape from openid.message import * from openid.server import server -import unittest class PapeRequestTestCase(unittest.TestCase): def setUp(self): diff --git a/openid/test/test_pape_draft5.py b/openid/test/test_pape_draft5.py index d93ee96e..9693fad9 100644 --- a/openid/test/test_pape_draft5.py +++ b/openid/test/test_pape_draft5.py @@ -1,13 +1,14 @@ +import unittest +import warnings + from openid.extensions.draft import pape5 as pape from openid.message import * from openid.server import server -import warnings warnings.filterwarnings('ignore', module=__name__, message='"none" used as a policy URI') -import unittest class PapeRequestTestCase(unittest.TestCase): def setUp(self): diff --git a/openid/test/test_parsehtml.py b/openid/test/test_parsehtml.py index 221565b2..fe90ac71 100644 --- a/openid/test/test_parsehtml.py +++ b/openid/test/test_parsehtml.py @@ -1,7 +1,10 @@ -from openid.yadis.parsehtml import YadisHTMLParser, ParseDone +import os.path +import sys +import unittest from HTMLParser import HTMLParseError -import os.path, unittest, sys +from openid.yadis.parsehtml import ParseDone, YadisHTMLParser + class _TestCase(unittest.TestCase): reserved_values = ['None', 'EOF'] diff --git a/openid/test/test_rpverify.py b/openid/test/test_rpverify.py index e84d7af4..d37b5949 100644 --- a/openid/test/test_rpverify.py +++ b/openid/test/test_rpverify.py @@ -3,11 +3,13 @@ __all__ = ['TestBuildDiscoveryURL'] -from openid.yadis.discover import DiscoveryResult, DiscoveryFailure -from openid.yadis import services +import unittest + from openid.server import trustroot from openid.test.support import CatchLogs -import unittest +from openid.yadis import services +from openid.yadis.discover import DiscoveryFailure, DiscoveryResult + # Too many methods does not apply to unit test objects #pylint:disable-msg=R0904 diff --git a/openid/test/test_server.py b/openid/test/test_server.py index 797c051f..0b3b6cc7 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -1,18 +1,16 @@ """Tests for openid.server. """ -from openid.server import server -from openid import association, cryptutil, oidutil -from openid.message import Message, OPENID_NS, OPENID2_NS, OPENID1_NS, \ - IDENTIFIER_SELECT, no_default, OPENID1_URL_LIMIT -from openid.store import memstore -from openid.test.support import CatchLogs import cgi - import unittest import warnings - from urlparse import urlparse +from openid import association, cryptutil, oidutil +from openid.message import IDENTIFIER_SELECT, OPENID1_NS, OPENID1_URL_LIMIT, OPENID2_NS, OPENID_NS, Message, no_default +from openid.server import server +from openid.store import memstore +from openid.test.support import CatchLogs + # In general, if you edit or add tests here, try to move in the direction # of testing smaller units. For testing the external interfaces, we'll be # developing an implementation-agnostic testing suite. diff --git a/openid/test/test_sreg.py b/openid/test/test_sreg.py index 49f9ea91..0abbc5eb 100644 --- a/openid/test/test_sreg.py +++ b/openid/test/test_sreg.py @@ -1,8 +1,9 @@ +import unittest + from openid.extensions import sreg -from openid.message import NamespaceMap, Message, registerNamespaceAlias +from openid.message import Message, NamespaceMap, registerNamespaceAlias from openid.server.server import OpenIDRequest, OpenIDResponse -import unittest class SRegURITest(unittest.TestCase): def test_is11(self): diff --git a/openid/test/test_symbol.py b/openid/test/test_symbol.py index 42ca2123..7f9b79bb 100644 --- a/openid/test/test_symbol.py +++ b/openid/test/test_symbol.py @@ -2,6 +2,7 @@ from openid import oidutil + class SymbolTest(unittest.TestCase): def test_selfEquality(self): s = oidutil.Symbol('xxx') diff --git a/openid/test/test_urinorm.py b/openid/test/test_urinorm.py index 7f5ecaa1..154f7516 100644 --- a/openid/test/test_urinorm.py +++ b/openid/test/test_urinorm.py @@ -1,7 +1,9 @@ import os import unittest + import openid.urinorm + class UrinormTest(unittest.TestCase): def __init__(self, desc, case, expected): unittest.TestCase.__init__(self) diff --git a/openid/test/test_verifydisco.py b/openid/test/test_verifydisco.py index e64a48a5..43664bc3 100644 --- a/openid/test/test_verifydisco.py +++ b/openid/test/test_verifydisco.py @@ -1,9 +1,10 @@ import unittest + from openid import message +from openid.consumer import consumer, discover from openid.test.support import OpenIDTestMixin -from openid.consumer import consumer from openid.test.test_consumer import TestIdRes -from openid.consumer import discover + def const(result): """Return a function that ignores any arguments and just returns diff --git a/openid/test/test_xri.py b/openid/test/test_xri.py index dc5921ae..33ea0e05 100644 --- a/openid/test/test_xri.py +++ b/openid/test/test_xri.py @@ -1,6 +1,8 @@ from unittest import TestCase + from openid.yadis import xri + class XriDiscoveryTestCase(TestCase): def test_isXRI(self): i = xri.identifierScheme diff --git a/openid/test/test_xrires.py b/openid/test/test_xrires.py index 4e17e3b4..873255c4 100644 --- a/openid/test/test_xrires.py +++ b/openid/test/test_xrires.py @@ -1,7 +1,9 @@ from unittest import TestCase + from openid.yadis import xrires + class ProxyQueryTestCase(TestCase): def setUp(self): self.proxy_url = 'http://xri.example.com/' diff --git a/openid/test/test_yadis_discover.py b/openid/test/test_yadis_discover.py index b13241d8..8c222d08 100644 --- a/openid/test/test_yadis_discover.py +++ b/openid/test/test_yadis_discover.py @@ -6,16 +6,15 @@ tests with a mock fetcher instead of spawning threads with BaseHTTPServer. """ -import unittest -import urlparse import re import types - -from openid.yadis.discover import discover, DiscoveryFailure +import unittest +import urlparse from openid import fetchers +from openid.yadis.discover import DiscoveryFailure, discover -import discoverdata +from . import discoverdata status_header_re = re.compile(r'Status: (\d+) .*?$', re.MULTILINE) diff --git a/openid/test/trustroot.py b/openid/test/trustroot.py index 236649ba..f934ce36 100644 --- a/openid/test/trustroot.py +++ b/openid/test/trustroot.py @@ -1,7 +1,9 @@ import os import unittest + from openid.server.trustroot import TrustRoot + class _ParseTest(unittest.TestCase): def __init__(self, sanity, desc, case): unittest.TestCase.__init__(self) diff --git a/openid/yadis/discover.py b/openid/yadis/discover.py index 1412d744..27fcd013 100644 --- a/openid/yadis/discover.py +++ b/openid/yadis/discover.py @@ -4,11 +4,10 @@ from StringIO import StringIO from openid import fetchers - -from openid.yadis.constants import \ - YADIS_HEADER_NAME, YADIS_CONTENT_TYPE, YADIS_ACCEPT_HEADER +from openid.yadis.constants import YADIS_ACCEPT_HEADER, YADIS_CONTENT_TYPE, YADIS_HEADER_NAME from openid.yadis.parsehtml import MetaNotFound, findHTMLMeta + class DiscoveryFailure(Exception): """Raised when a YADIS protocol error occurs in the discovery process""" identity_url = None @@ -152,4 +151,3 @@ def whereIsYadis(resp): pass return yadis_loc - diff --git a/openid/yadis/etxrd.py b/openid/yadis/etxrd.py index ef5cadfe..52a8ab32 100644 --- a/openid/yadis/etxrd.py +++ b/openid/yadis/etxrd.py @@ -18,13 +18,14 @@ 'expandServices', ] -import sys import random - +import sys from datetime import datetime from time import strptime from openid.oidutil import importElementTree +from openid.yadis import xri + ElementTree = importElementTree() # the different elementtree modules don't have a common exception @@ -40,7 +41,6 @@ except: XMLError = sys.exc_info()[0] -from openid.yadis import xri class XRDSError(Exception): """An error with the XRDS document.""" diff --git a/openid/yadis/filters.py b/openid/yadis/filters.py index d01c221b..43e4f3f1 100644 --- a/openid/yadis/filters.py +++ b/openid/yadis/filters.py @@ -13,6 +13,7 @@ from openid.yadis.etxrd import expandService + class BasicServiceEndpoint(object): """Generic endpoint object that contains parsed service information, as well as a reference to the service element from diff --git a/openid/yadis/parsehtml.py b/openid/yadis/parsehtml.py index 875a089c..c2f80294 100644 --- a/openid/yadis/parsehtml.py +++ b/openid/yadis/parsehtml.py @@ -1,8 +1,8 @@ __all__ = ['findHTMLMeta', 'MetaNotFound'] -from HTMLParser import HTMLParser, HTMLParseError import htmlentitydefs import re +from HTMLParser import HTMLParseError, HTMLParser from openid.yadis.constants import YADIS_HEADER_NAME diff --git a/openid/yadis/services.py b/openid/yadis/services.py index 4753c194..65d88344 100644 --- a/openid/yadis/services.py +++ b/openid/yadis/services.py @@ -1,8 +1,9 @@ # -*- test-case-name: openid.test.test_services -*- +from openid.yadis.discover import DiscoveryFailure, discover +from openid.yadis.etxrd import XRDSError, iterServices, parseXRDS from openid.yadis.filters import mkFilter -from openid.yadis.discover import discover, DiscoveryFailure -from openid.yadis.etxrd import parseXRDS, iterServices, XRDSError + def getServiceEndpoints(input_url, flt=None): """Perform the Yadis protocol on the input URL and return an diff --git a/openid/yadis/xrires.py b/openid/yadis/xrires.py index be663c66..e8fd7e4c 100644 --- a/openid/yadis/xrires.py +++ b/openid/yadis/xrires.py @@ -3,10 +3,11 @@ """ from urllib import urlencode + from openid import fetchers from openid.yadis import etxrd -from openid.yadis.xri import toURINormal from openid.yadis.services import iterServices +from openid.yadis.xri import toURINormal DEFAULT_PROXY = 'http://proxy.xri.net/' diff --git a/setup.py b/setup.py index 9e03e5fb..d68abde6 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,5 @@ -import sys import os +import sys try: from setuptools import setup From 17a8ffc7a5ba93e36e4fba02ca9a0c1d8ad3a598 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Thu, 23 Nov 2017 11:51:00 +0100 Subject: [PATCH 021/151] Drop support for Python <2.7 --- examples/consumer.py | 60 +++++++------- examples/server.py | 52 ++++++------- openid/consumer/discover.py | 7 -- openid/cryptutil.py | 64 +++------------ openid/extensions/sreg.py | 6 -- openid/server/trustroot.py | 4 - openid/store/filestore.py | 44 +---------- openid/store/nonce.py | 5 +- openid/test/test_association.py | 34 ++++---- openid/test/test_consumer.py | 11 ++- openid/test/test_discover.py | 7 -- openid/test/test_server.py | 134 +++++++++++++++----------------- 12 files changed, 146 insertions(+), 282 deletions(-) diff --git a/examples/consumer.py b/examples/consumer.py index c4f299c0..1d448ded 100644 --- a/examples/consumer.py +++ b/examples/consumer.py @@ -10,6 +10,7 @@ import cgi import cgitb +import optparse import sys import urlparse from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer @@ -470,39 +471,30 @@ def main(host, port, data_path, weak_ssl=False): server.serve_forever() if __name__ == '__main__': - host = 'localhost' - port = 8001 - weak_ssl = False - - try: - import optparse - except ImportError: - pass # Use defaults (for Python 2.2) - else: - parser = optparse.OptionParser('Usage:\n %prog [options]') - parser.add_option( - '-d', '--data-path', dest='data_path', - help='Data directory for storing OpenID consumer state. ' - 'Setting this option implies using a "FileStore."') - parser.add_option( - '-p', '--port', dest='port', type='int', default=port, - help='Port on which to listen for HTTP requests. ' - 'Defaults to port %default.') - parser.add_option( - '-s', '--host', dest='host', default=host, - help='Host on which to listen for HTTP requests. ' - 'Also used for generating URLs. Defaults to %default.') - parser.add_option( - '-w', '--weakssl', dest='weakssl', default=False, - action='store_true', help='Skip ssl cert verification') - - options, args = parser.parse_args() - if args: - parser.error('Expected no arguments. Got %r' % args) - - host = options.host - port = options.port - data_path = options.data_path - weak_ssl = options.weakssl + parser = optparse.OptionParser('Usage:\n %prog [options]') + parser.add_option( + '-d', '--data-path', dest='data_path', + help='Data directory for storing OpenID consumer state. ' + 'Setting this option implies using a "FileStore."') + parser.add_option( + '-p', '--port', dest='port', type='int', default=8001, + help='Port on which to listen for HTTP requests. ' + 'Defaults to port %default.') + parser.add_option( + '-s', '--host', dest='host', default='localhost', + help='Host on which to listen for HTTP requests. ' + 'Also used for generating URLs. Defaults to %default.') + parser.add_option( + '-w', '--weakssl', dest='weakssl', default=False, + action='store_true', help='Skip ssl cert verification') + + options, args = parser.parse_args() + if args: + parser.error('Expected no arguments. Got %r' % args) + + host = options.host + port = options.port + data_path = options.data_path + weak_ssl = options.weakssl main(host, port, data_path, weak_ssl) diff --git a/examples/server.py b/examples/server.py index ddbe5e45..0b12597e 100644 --- a/examples/server.py +++ b/examples/server.py @@ -5,6 +5,7 @@ import cgi import cgitb import Cookie +import optparse import sys import time from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer @@ -689,35 +690,26 @@ def main(host, port, data_path): httpserver.serve_forever() if __name__ == '__main__': - host = 'localhost' - data_path = 'sstore' - port = 8000 - - try: - import optparse - except ImportError: - pass # Use defaults (for Python 2.2) - else: - parser = optparse.OptionParser('Usage:\n %prog [options]') - parser.add_option( - '-d', '--data-path', dest='data_path', default=data_path, - help='Data directory for storing OpenID consumer state. ' - 'Defaults to "%default" in the current directory.') - parser.add_option( - '-p', '--port', dest='port', type='int', default=port, - help='Port on which to listen for HTTP requests. ' - 'Defaults to port %default.') - parser.add_option( - '-s', '--host', dest='host', default=host, - help='Host on which to listen for HTTP requests. ' - 'Also used for generating URLs. Defaults to %default.') - - options, args = parser.parse_args() - if args: - parser.error('Expected no arguments. Got %r' % args) - - host = options.host - port = options.port - data_path = options.data_path + parser = optparse.OptionParser('Usage:\n %prog [options]') + parser.add_option( + '-d', '--data-path', dest='data_path', default='sstore', + help='Data directory for storing OpenID consumer state. ' + 'Defaults to "%default" in the current directory.') + parser.add_option( + '-p', '--port', dest='port', type='int', default=8000, + help='Port on which to listen for HTTP requests. ' + 'Defaults to port %default.') + parser.add_option( + '-s', '--host', dest='host', default='localhost', + help='Host on which to listen for HTTP requests. ' + 'Also used for generating URLs. Defaults to %default.') + + options, args = parser.parse_args() + if args: + parser.error('Expected no arguments. Got %r' % args) + + host = options.host + port = options.port + data_path = options.data_path main(host, port, data_path) diff --git a/openid/consumer/discover.py b/openid/consumer/discover.py index e4b9e639..c8db7a59 100644 --- a/openid/consumer/discover.py +++ b/openid/consumer/discover.py @@ -314,13 +314,6 @@ def arrangeByType(service_list, preferred_types): """Rearrange service_list in a new list so services are ordered by types listed in preferred_types. Return the new list.""" - def enumerate(elts): - """Return an iterable that pairs the index of an element with - that element. - - For Python 2.2 compatibility""" - return zip(range(len(elts)), elts) - def bestMatchingService(service): """Return the index of the first matching type, or something higher if no type matches. diff --git a/openid/cryptutil.py b/openid/cryptutil.py index 868877a9..769aa6c5 100644 --- a/openid/cryptutil.py +++ b/openid/cryptutil.py @@ -1,9 +1,6 @@ """Module containing a cryptographic-quality source of randomness and other cryptographically useful functionality -Python 2.4 needs no external support for this module, nor does Python -2.3 on a system with /dev/urandom. - Other configurations will need a quality source of random bytes and access to a function that will convert binary strings to long integers. This module will work with the Python Cryptography Toolkit @@ -26,30 +23,21 @@ 'sha256', ] +import hashlib import hmac import os import random from openid.oidutil import fromBase64, toBase64 -try: - import hashlib -except ImportError: - import sha as sha1_module - try: - from Crypto.Hash import SHA256 as sha256_module - except ImportError: - sha256_module = None - -else: - class HashContainer(object): - def __init__(self, hash_constructor): - self.new = hash_constructor - self.digest_size = hash_constructor().digest_size +class HashContainer(object): + def __init__(self, hash_constructor): + self.new = hash_constructor + self.digest_size = hash_constructor().digest_size - sha1_module = HashContainer(hashlib.sha1) - sha256_module = HashContainer(hashlib.sha256) +sha1_module = HashContainer(hashlib.sha1) +sha256_module = HashContainer(hashlib.sha256) def hmacSha1(key, text): return hmac.new(key, text, sha1_module).digest() @@ -57,47 +45,17 @@ def hmacSha1(key, text): def sha1(s): return sha1_module.new(s).digest() -if sha256_module is not None: - def hmacSha256(key, text): - return hmac.new(key, text, sha256_module).digest() +def hmacSha256(key, text): + return hmac.new(key, text, sha256_module).digest() - def sha256(s): - return sha256_module.new(s).digest() +def sha256(s): + return sha256_module.new(s).digest() - SHA256_AVAILABLE = True - -else: - _no_sha256 = NotImplementedError( - 'Use Python 2.5, install pycrypto or install hashlib to use SHA256') - - def hmacSha256(unused_key, unused_text): - raise _no_sha256 - - def sha256(s): - raise _no_sha256 - - SHA256_AVAILABLE = False try: from Crypto.Util.number import long_to_bytes, bytes_to_long except ImportError: import pickle - try: - # Check Python compatiblity by raising an exception on import - # if the needed functionality is not present. Present in - # Python >= 2.3 - pickle.encode_long - pickle.decode_long - except AttributeError: - raise ImportError( - 'No functionality for serializing long integers found') - - # Present in Python >= 2.4 - try: - reversed - except NameError: - def reversed(seq): - return map(seq.__getitem__, xrange(len(seq) - 1, -1, -1)) def longToBinary(l): if l == 0: diff --git a/openid/extensions/sreg.py b/openid/extensions/sreg.py index 87e46fa3..18849ece 100644 --- a/openid/extensions/sreg.py +++ b/openid/extensions/sreg.py @@ -40,12 +40,6 @@ from openid.extension import Extension from openid.message import NamespaceAliasRegistrationError, registerNamespaceAlias -try: - basestring #pylint:disable-msg=W0104 -except NameError: - # For Python 2.2 - basestring = (str, unicode) #pylint:disable-msg=W0622 - __all__ = [ 'SRegRequest', 'SRegResponse', diff --git a/openid/server/trustroot.py b/openid/server/trustroot.py index 49863539..6dce8d34 100644 --- a/openid/server/trustroot.py +++ b/openid/server/trustroot.py @@ -87,10 +87,6 @@ def _parseURL(url): return None proto, netloc, path, params, query, frag = urlparse(url) if not path: - # Python <2.4 does not parse URLs with no path properly - if not query and '?' in netloc: - netloc, query = netloc.split('?', 1) - path = '/' path = urlunparse(('', '', path, params, query, frag)) diff --git a/openid/store/filestore.py b/openid/store/filestore.py index adb69dac..e43b1468 100644 --- a/openid/store/filestore.py +++ b/openid/store/filestore.py @@ -9,55 +9,15 @@ import string import time from errno import EEXIST, ENOENT +from tempfile import mkstemp from openid import cryptutil, oidutil from openid.association import Association from openid.store import nonce from openid.store.interface import OpenIDStore -try: - from tempfile import mkstemp -except ImportError: - # Python < 2.3 - import warnings - warnings.filterwarnings("ignore", - "tempnam is a potential security risk", - RuntimeWarning, - "openid.store.filestore") - - def mkstemp(dir): - for _ in range(5): - name = os.tempnam(dir) - try: - fd = os.open(name, os.O_CREAT | os.O_EXCL | os.O_RDWR, 0600) - except OSError, why: - if why.errno != EEXIST: - raise - else: - return fd, name - - raise RuntimeError('Failed to get temp file after 5 attempts') - - _filename_allowed = string.ascii_letters + string.digits + '.' -try: - # 2.4 - set -except NameError: - try: - # 2.3 - import sets - except ImportError: - # Python < 2.2 - d = {} - for c in _filename_allowed: - d[c] = None - _isFilenameSafe = d.has_key - del d - else: - _isFilenameSafe = sets.Set(_filename_allowed).__contains__ -else: - _isFilenameSafe = set(_filename_allowed).__contains__ +_isFilenameSafe = set(_filename_allowed).__contains__ def _safe64(s): h64 = oidutil.toBase64(cryptutil.sha1(s)) diff --git a/openid/store/nonce.py b/openid/store/nonce.py index 3814dd1d..89ef096f 100644 --- a/openid/store/nonce.py +++ b/openid/store/nonce.py @@ -33,10 +33,7 @@ def split(nonce_string): formatted time string """ timestamp_str = nonce_string[:time_str_len] - try: - timestamp = timegm(strptime(timestamp_str, time_fmt)) - except AssertionError: # Python 2.2 - timestamp = -1 + timestamp = timegm(strptime(timestamp_str, time_fmt)) if timestamp < 0: raise ValueError('time out of range') return timestamp, nonce_string[time_str_len:] diff --git a/openid/test/test_association.py b/openid/test/test_association.py index 8ab81785..86c2883d 100644 --- a/openid/test/test_association.py +++ b/openid/test/test_association.py @@ -112,14 +112,13 @@ def test_sha1(self): sig = assoc.sign(self.pairs) self.failUnlessEqual(sig, expected) - if cryptutil.SHA256_AVAILABLE: - def test_sha256(self): - assoc = association.Association.fromExpiresIn( - 3600, '{sha256SA}', 'very_secret', "HMAC-SHA256") - expected = ('\xfd\xaa\xfe;\xac\xfc*\x988\xad\x05d6-\xeaVy' - '\xd5\xa5Z.<\xa9\xed\x18\x82\\$\x95x\x1c&') - sig = assoc.sign(self.pairs) - self.failUnlessEqual(sig, expected) + def test_sha256(self): + assoc = association.Association.fromExpiresIn( + 3600, '{sha256SA}', 'very_secret', "HMAC-SHA256") + expected = ('\xfd\xaa\xfe;\xac\xfc*\x988\xad\x05d6-\xeaVy' + '\xd5\xa5Z.<\xa9\xed\x18\x82\\$\x95x\x1c&') + sig = assoc.sign(self.pairs) + self.failUnlessEqual(sig, expected) @@ -144,16 +143,15 @@ def test_signSHA1(self): self.failUnlessEqual(signed.getArg(BARE_NS, "xey"), "value", signed) - if cryptutil.SHA256_AVAILABLE: - def test_signSHA256(self): - assoc = association.Association.fromExpiresIn( - 3600, '{sha1}', 'very_secret', "HMAC-SHA256") - signed = assoc.signMessage(self.message) - self.failUnless(signed.getArg(OPENID_NS, "sig")) - self.failUnlessEqual(signed.getArg(OPENID_NS, "signed"), - "assoc_handle,identifier,mode,ns,signed") - self.failUnlessEqual(signed.getArg(BARE_NS, "xey"), "value", - signed) + def test_signSHA256(self): + assoc = association.Association.fromExpiresIn( + 3600, '{sha1}', 'very_secret', "HMAC-SHA256") + signed = assoc.signMessage(self.message) + self.failUnless(signed.getArg(OPENID_NS, "sig")) + self.failUnlessEqual(signed.getArg(OPENID_NS, "signed"), + "assoc_handle,identifier,mode,ns,signed") + self.failUnlessEqual(signed.getArg(BARE_NS, "xey"), "value", + signed) class TestCheckMessageSignature(unittest.TestCase): diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index 4bc51122..acab7c04 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -1941,12 +1941,11 @@ class TestOpenID2SHA1(TestDiffieHellmanResponseParameters, unittest.TestCase): session_cls = DiffieHellmanSHA1ConsumerSession message_namespace = OPENID2_NS -if cryptutil.SHA256_AVAILABLE: - class TestOpenID2SHA256(TestDiffieHellmanResponseParameters, unittest.TestCase): - session_cls = DiffieHellmanSHA256ConsumerSession - message_namespace = OPENID2_NS -else: - warnings.warn("Not running SHA256 association session tests.") + +class TestOpenID2SHA256(TestDiffieHellmanResponseParameters, unittest.TestCase): + session_cls = DiffieHellmanSHA256ConsumerSession + message_namespace = OPENID2_NS + class TestNoStore(unittest.TestCase): def setUp(self): diff --git a/openid/test/test_discover.py b/openid/test/test_discover.py index 5b6c2996..a09a1f2c 100644 --- a/openid/test/test_discover.py +++ b/openid/test/test_discover.py @@ -62,13 +62,6 @@ def runOneTest(self): ### Tests for raising/catching exceptions from the fetcher through the ### discover function -# Python 2.5 displays a message when running this test, which is -# testing the behaviour in the presence of string exceptions, -# deprecated or not, so tell it no to complain when this particular -# string exception is raised. -warnings.filterwarnings('ignore', 'raising a string.*', DeprecationWarning, - r'^openid\.test\.test_discover$', 77) - class ErrorRaisingFetcher(object): """Just raise an exception when fetch is called""" diff --git a/openid/test/test_server.py b/openid/test/test_server.py index 0b3b6cc7..a19a734e 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -2,7 +2,6 @@ """ import cgi import unittest -import warnings from urlparse import urlparse from openid import association, cryptutil, oidutil @@ -1434,58 +1433,54 @@ def test_dhSHA1(self): secret = consumer_dh.xorSecret(spub, enc_key, cryptutil.sha1) self.failUnlessEqual(secret, self.assoc.secret) + def test_dhSHA256(self): + self.assoc = self.signatory.createAssociation( + dumb=False, assoc_type='HMAC-SHA256') + from openid.dh import DiffieHellman + from openid.server.server import DiffieHellmanSHA256ServerSession + consumer_dh = DiffieHellman.fromDefaults() + cpub = consumer_dh.public + server_dh = DiffieHellman.fromDefaults() + session = DiffieHellmanSHA256ServerSession(server_dh, cpub) + self.request = server.AssociateRequest(session, 'HMAC-SHA256') + response = self.request.answer(self.assoc) + rfg = lambda f: response.fields.getArg(OPENID_NS, f) + self.failUnlessEqual(rfg("assoc_type"), "HMAC-SHA256") + self.failUnlessEqual(rfg("assoc_handle"), self.assoc.handle) + self.failIf(rfg("mac_key")) + self.failUnlessEqual(rfg("session_type"), "DH-SHA256") + self.failUnless(rfg("enc_mac_key")) + self.failUnless(rfg("dh_server_public")) + + enc_key = rfg("enc_mac_key").decode('base64') + spub = cryptutil.base64ToLong(rfg("dh_server_public")) + secret = consumer_dh.xorSecret(spub, enc_key, cryptutil.sha256) + self.failUnlessEqual(secret, self.assoc.secret) - if not cryptutil.SHA256_AVAILABLE: - warnings.warn("Not running SHA256 tests.") - else: - def test_dhSHA256(self): - self.assoc = self.signatory.createAssociation( - dumb=False, assoc_type='HMAC-SHA256') - from openid.dh import DiffieHellman - from openid.server.server import DiffieHellmanSHA256ServerSession - consumer_dh = DiffieHellman.fromDefaults() - cpub = consumer_dh.public - server_dh = DiffieHellman.fromDefaults() - session = DiffieHellmanSHA256ServerSession(server_dh, cpub) - self.request = server.AssociateRequest(session, 'HMAC-SHA256') - response = self.request.answer(self.assoc) - rfg = lambda f: response.fields.getArg(OPENID_NS, f) - self.failUnlessEqual(rfg("assoc_type"), "HMAC-SHA256") - self.failUnlessEqual(rfg("assoc_handle"), self.assoc.handle) - self.failIf(rfg("mac_key")) - self.failUnlessEqual(rfg("session_type"), "DH-SHA256") - self.failUnless(rfg("enc_mac_key")) - self.failUnless(rfg("dh_server_public")) - - enc_key = rfg("enc_mac_key").decode('base64') - spub = cryptutil.base64ToLong(rfg("dh_server_public")) - secret = consumer_dh.xorSecret(spub, enc_key, cryptutil.sha256) - self.failUnlessEqual(secret, self.assoc.secret) - - def test_protoError256(self): - from openid.consumer.consumer import \ - DiffieHellmanSHA256ConsumerSession - - s256_session = DiffieHellmanSHA256ConsumerSession() - - invalid_s256 = {'openid.assoc_type':'HMAC-SHA1', - 'openid.session_type':'DH-SHA256',} - invalid_s256.update(s256_session.getRequest()) - - invalid_s256_2 = {'openid.assoc_type':'MONKEY-PIRATE', - 'openid.session_type':'DH-SHA256',} - invalid_s256_2.update(s256_session.getRequest()) - - bad_request_argss = [ - invalid_s256, - invalid_s256_2, - ] - - for request_args in bad_request_argss: - message = Message.fromPostArgs(request_args) - self.failUnlessRaises(server.ProtocolError, - server.AssociateRequest.fromMessage, - message) + def test_protoError256(self): + from openid.consumer.consumer import \ + DiffieHellmanSHA256ConsumerSession + + s256_session = DiffieHellmanSHA256ConsumerSession() + + invalid_s256 = {'openid.assoc_type':'HMAC-SHA1', + 'openid.session_type':'DH-SHA256',} + invalid_s256.update(s256_session.getRequest()) + + invalid_s256_2 = {'openid.assoc_type':'MONKEY-PIRATE', + 'openid.session_type':'DH-SHA256',} + invalid_s256_2.update(s256_session.getRequest()) + + bad_request_argss = [ + invalid_s256, + invalid_s256_2, + ] + + for request_args in bad_request_argss: + message = Message.fromPostArgs(request_args) + self.failUnlessRaises(server.ProtocolError, + server.AssociateRequest.fromMessage, + message) def test_protoError(self): from openid.consumer.consumer import DiffieHellmanSHA1ConsumerSession @@ -1739,25 +1734,22 @@ def test_associate3(self): self.failUnlessEqual(response.fields.getArg(OPENID_NS, "session_type"), 'DH-SHA256') - if not cryptutil.SHA256_AVAILABLE: - warnings.warn("Not running SHA256 tests.") - else: - def test_associate4(self): - """DH-SHA256 association session""" - self.server.negotiator.setAllowedTypes( - [('HMAC-SHA256', 'DH-SHA256')]) - query = { - 'openid.dh_consumer_public': - 'ALZgnx8N5Lgd7pCj8K86T/DDMFjJXSss1SKoLmxE72kJTzOtG6I2PaYrHX' - 'xku4jMQWSsGfLJxwCZ6280uYjUST/9NWmuAfcrBfmDHIBc3H8xh6RBnlXJ' - '1WxJY3jHd5k1/ZReyRZOxZTKdF/dnIqwF8ZXUwI6peV0TyS/K1fOfF/s', - 'openid.assoc_type': 'HMAC-SHA256', - 'openid.session_type': 'DH-SHA256', - } - message = Message.fromPostArgs(query) - request = server.AssociateRequest.fromMessage(message) - response = self.server.openid_associate(request) - self.failUnless(response.fields.hasKey(OPENID_NS, "assoc_handle")) + def test_associate4(self): + """DH-SHA256 association session""" + self.server.negotiator.setAllowedTypes( + [('HMAC-SHA256', 'DH-SHA256')]) + query = { + 'openid.dh_consumer_public': + 'ALZgnx8N5Lgd7pCj8K86T/DDMFjJXSss1SKoLmxE72kJTzOtG6I2PaYrHX' + 'xku4jMQWSsGfLJxwCZ6280uYjUST/9NWmuAfcrBfmDHIBc3H8xh6RBnlXJ' + '1WxJY3jHd5k1/ZReyRZOxZTKdF/dnIqwF8ZXUwI6peV0TyS/K1fOfF/s', + 'openid.assoc_type': 'HMAC-SHA256', + 'openid.session_type': 'DH-SHA256', + } + message = Message.fromPostArgs(query) + request = server.AssociateRequest.fromMessage(message) + response = self.server.openid_associate(request) + self.failUnless(response.fields.hasKey(OPENID_NS, "assoc_handle")) def test_missingSessionTypeOpenID2(self): """Make sure session_type is required in OpenID 2""" From 95a143cc06edae839178c9d99d6c89dc43b99d40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Thu, 23 Nov 2017 13:29:52 +0100 Subject: [PATCH 022/151] Update logging --- openid/consumer/consumer.py | 85 +++++++++++++++---------------------- openid/consumer/discover.py | 4 +- openid/extensions/sreg.py | 5 ++- openid/kvform.py | 6 ++- openid/oidutil.py | 9 ++-- openid/server/server.py | 17 +++----- openid/server/trustroot.py | 8 ++-- openid/store/filestore.py | 6 ++- openid/test/support.py | 4 +- 9 files changed, 67 insertions(+), 77 deletions(-) diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index 4b5dfce0..eaa48472 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -206,6 +206,8 @@ 'SUCCESS', 'FAILURE', 'CANCEL', 'SETUP_NEEDED', ] +_LOGGER = logging.getLogger(__name__) + def makeKVPost(request_message, server_url): """Make a Direct Request to an OpenID Provider and return the @@ -656,7 +658,7 @@ def _checkReturnTo(self, message, return_to): try: self._verifyReturnToArgs(message.toPostArgs()) except ProtocolError, why: - logging.exception("Verifying return_to arguments: %s" % (why[0],)) + _LOGGER.exception("Verifying return_to arguments: %s", why) return False # Check the return_to base URL against the one in the message. @@ -722,9 +724,8 @@ def _doIdRes(self, message, endpoint, return_to): # Verify discovery information: endpoint = self._verifyDiscoveryResults(message, endpoint) - logging.info("Received id_res response from %s using association %s" % - (endpoint.server_url, - message.getArg(OPENID_NS, 'assoc_handle'))) + _LOGGER.info("Received id_res response from %s using association %s", + endpoint.server_url, message.getArg(OPENID_NS, 'assoc_handle')) self._idResCheckSignature(message, endpoint.server_url) @@ -916,7 +917,7 @@ def _verifyDiscoveryResultsOpenID2(self, resp_msg, endpoint): # endpoints and responses that didn't match the original # request. if not endpoint: - logging.info('No pre-discovered information supplied.') + _LOGGER.info('No pre-discovered information supplied.') endpoint = self._discoverAndVerify(to_match.claimed_id, [to_match]) else: # The claimed ID matches, so we use the endpoint that we @@ -925,10 +926,8 @@ def _verifyDiscoveryResultsOpenID2(self, resp_msg, endpoint): try: self._verifyDiscoverySingle(endpoint, to_match) except ProtocolError, e: - logging.exception( - "Error attempting to use stored discovery information: " + - str(e)) - logging.info("Attempting discovery to verify endpoint") + _LOGGER.exception("Error attempting to use stored discovery information: %s", e) + _LOGGER.info("Attempting discovery to verify endpoint") endpoint = self._discoverAndVerify( to_match.claimed_id, [to_match]) @@ -970,9 +969,8 @@ def _verifyDiscoveryResultsOpenID1(self, resp_msg, endpoint): except TypeURIMismatch: self._verifyDiscoverySingle(endpoint, to_match_1_0) except ProtocolError, e: - logging.exception("Error attempting to use stored discovery information: " + - str(e)) - logging.info("Attempting discovery to verify endpoint") + _LOGGER.exception("Error attempting to use stored discovery information: %s", e) + _LOGGER.info("Attempting discovery to verify endpoint") else: return endpoint @@ -1042,7 +1040,7 @@ def _discoverAndVerify(self, claimed_id, to_match_endpoints): @raises DiscoveryFailure: when discovery fails. """ - logging.info('Performing discovery on %s' % (claimed_id,)) + _LOGGER.info('Performing discovery on %s', claimed_id) _, services = self._discover(claimed_id) if not services: raise DiscoveryFailure('No OpenID information found at %s' % @@ -1069,10 +1067,9 @@ def _verifyDiscoveredServices(self, claimed_id, services, to_match_endpoints): # succeeded. Return this endpoint. return endpoint else: - logging.error('Discovery verification failure for %s' % - (claimed_id,)) + _LOGGER.error('Discovery verification failure for %s', claimed_id) for failure_message in failure_messages: - logging.error(' * Endpoint mismatch: ' + failure_message) + _LOGGER.error(' * Endpoint mismatch: %s', failure_message) raise DiscoveryFailure( 'No matching endpoint found after discovering %s' @@ -1084,14 +1081,14 @@ def _checkAuth(self, message, server_url): @returns: True if the request is valid. @rtype: bool """ - logging.info('Using OpenID check_authentication') + _LOGGER.info('Using OpenID check_authentication') request = self._createCheckAuthRequest(message) if request is None: return False try: response = self._makeKVPost(request, server_url) except (fetchers.HTTPFetchingError, ServerError), e: - logging.exception('check_authentication failed: %s' % (e[0],)) + _LOGGER.exception('check_authentication failed: %s', e) return False else: return self._processCheckAuthResponse(response, server_url) @@ -1103,12 +1100,12 @@ def _createCheckAuthRequest(self, message): signed = message.getArg(OPENID_NS, 'signed') if signed: for k in signed.split(','): - logging.info(k) + _LOGGER.info(k) val = message.getAliasedArg(k) # Signed value is missing if val is None: - logging.info('Missing signed field %r' % (k,)) + _LOGGER.info('Missing signed field %r', k) return None check_auth_message = message.copy() @@ -1123,18 +1120,16 @@ def _processCheckAuthResponse(self, response, server_url): invalidate_handle = response.getArg(OPENID_NS, 'invalidate_handle') if invalidate_handle is not None: - logging.info( - 'Received "invalidate_handle" from server %s' % (server_url,)) + _LOGGER.info('Received "invalidate_handle" from server %s', server_url) if self.store is None: - logging.error('Unexpectedly got invalidate_handle without ' - 'a store!') + _LOGGER.error('Unexpectedly got invalidate_handle without a store!') else: self.store.removeAssociation(server_url, invalidate_handle) if is_valid == 'true': return True else: - logging.error('Server responds that checkAuth call is not valid') + _LOGGER.error('Server responds that checkAuth call is not valid') return False def _getAssociation(self, endpoint): @@ -1187,10 +1182,8 @@ def _negotiateAssociation(self, endpoint): except ServerError, why: # Do not keep trying, since it rejected the # association type that it told us to use. - logging.error('Server %s refused its suggested association ' - 'type: session_type=%s, assoc_type=%s' - % (endpoint.server_url, session_type, - assoc_type)) + _LOGGER.error('Server %s refused its suggested association type: session_type=%s, assoc_type=%s', + endpoint.server_url, session_type, assoc_type) return None else: return assoc @@ -1210,17 +1203,14 @@ def _extractSupportedAssociationType(self, server_error, endpoint, # should be considered a total failure. if server_error.error_code != 'unsupported-type' or \ server_error.message.isOpenID1(): - logging.error( - 'Server error when requesting an association from %r: %s' - % (endpoint.server_url, server_error.error_text)) + _LOGGER.error('Server error when requesting an association from %r: %s', + endpoint.server_url, server_error.error_text) return None # The server didn't like the association/session type # that we sent, and it sent us back a message that # might tell us how to handle it. - logging.error( - 'Unsupported association type %s: %s' % (assoc_type, - server_error.error_text,)) + _LOGGER.error('Unsupported association type %s: %s', assoc_type, server_error.error_text) # Extract the session_type and assoc_type from the # error message @@ -1228,13 +1218,11 @@ def _extractSupportedAssociationType(self, server_error, endpoint, session_type = server_error.message.getArg(OPENID_NS, 'session_type') if assoc_type is None or session_type is None: - logging.error('Server responded with unsupported association ' - 'session but did not supply a fallback.') + _LOGGER.error('Server responded with unsupported association session but did not supply a fallback.') return None elif not self.negotiator.isAllowed(assoc_type, session_type): - fmt = ('Server sent unsupported session/association type: ' - 'session_type=%s, assoc_type=%s') - logging.error(fmt % (session_type, assoc_type)) + _LOGGER.error('Server sent unsupported session/association type: session_type=%s, assoc_type=%s', + session_type, assoc_type) return None else: return assoc_type, session_type @@ -1255,18 +1243,16 @@ def _requestAssociation(self, endpoint, assoc_type, session_type): try: response = self._makeKVPost(args, endpoint.server_url) except fetchers.HTTPFetchingError, why: - logging.exception('openid.associate request failed: %s' % (why[0],)) + _LOGGER.exception('openid.associate request failed: %s', why) return None try: assoc = self._extractAssociation(response, assoc_session) except KeyError, why: - logging.exception('Missing required parameter in response from %s: %s' - % (endpoint.server_url, why[0])) + _LOGGER.exception('Missing required parameter in response from %s: %s', endpoint.server_url, why) return None except ProtocolError, why: - logging.exception('Protocol error parsing response from %s: %s' % ( - endpoint.server_url, why[0])) + _LOGGER.exception('Protocol error parsing response from %s: %s', endpoint.server_url, why) return None else: return assoc @@ -1342,8 +1328,7 @@ def _getOpenID1SessionType(self, assoc_response): # OpenID 1, but we'll accept it anyway, while issuing a # warning. if session_type == 'no-encryption': - logging.warn('OpenID server sent "no-encryption"' - 'for OpenID 1.X') + _LOGGER.warn('OpenID server sent "no-encryption" for OpenID 1.X') # Missing or empty session type is the way to flag a # 'no-encryption' response. Change the session type to @@ -1593,8 +1578,7 @@ def getMessage(self, realm, return_to=None, immediate=False): else: assoc_log_msg = 'using stateless mode.' - logging.info("Generated %s request to %s %s" % - (mode, self.endpoint.server_url, assoc_log_msg)) + _LOGGER.info("Generated %s request to %s %s", mode, self.endpoint.server_url, assoc_log_msg) return message @@ -1773,8 +1757,7 @@ def getSignedNS(self, ns_uri): for key in msg_args.iterkeys(): if not self.isSigned(ns_uri, key): - logging.info("SuccessResponse.getSignedNS: (%s, %s) not signed." - % (ns_uri, key)) + _LOGGER.info("SuccessResponse.getSignedNS: (%s, %s) not signed.", ns_uri, key) return None return msg_args diff --git a/openid/consumer/discover.py b/openid/consumer/discover.py index c8db7a59..5764dc5f 100644 --- a/openid/consumer/discover.py +++ b/openid/consumer/discover.py @@ -24,6 +24,8 @@ from openid.yadis.etxrd import XRD_NS_2_0, XRDSError, nsTag from openid.yadis.services import applyFilter as extractServices +_LOGGER = logging.getLogger(__name__) + OPENID_1_0_NS = 'http://openid.net/xmlns/1.0' OPENID_IDP_2_0_TYPE = 'http://specs.openid.net/auth/2.0/server' OPENID_2_0_TYPE = 'http://specs.openid.net/auth/2.0/signon' @@ -413,7 +415,7 @@ def discoverXRI(iname): for service_element in services: endpoints.extend(flt.getServiceEndpoints(iname, service_element)) except XRDSError: - logging.exception('xrds error on ' + iname) + _LOGGER.exception('xrds error on %s', iname) for endpoint in endpoints: # Is there a way to pass this through the filter to the endpoint diff --git a/openid/extensions/sreg.py b/openid/extensions/sreg.py index 18849ece..e147cf16 100644 --- a/openid/extensions/sreg.py +++ b/openid/extensions/sreg.py @@ -50,6 +50,8 @@ 'supportsSReg', ] +_LOGGER = logging.getLogger(__name__) + # The data fields that are listed in the sreg spec data_fields = { 'fullname':'Full Name', @@ -89,8 +91,7 @@ def checkFieldName(field_name): try: registerNamespaceAlias(ns_uri_1_1, 'sreg') except NamespaceAliasRegistrationError, e: - logging.exception('registerNamespaceAlias(%r, %r) failed: %s' % (ns_uri_1_1, - 'sreg', str(e),)) + _LOGGER.exception('registerNamespaceAlias(%r, %r) failed: %s', ns_uri_1_1, 'sreg', e) def supportsSReg(endpoint): """Does the given endpoint advertise support for simple diff --git a/openid/kvform.py b/openid/kvform.py index 846cf74c..8252d91a 100644 --- a/openid/kvform.py +++ b/openid/kvform.py @@ -3,6 +3,8 @@ import logging import types +_LOGGER = logging.getLogger(__name__) + class KVFormError(ValueError): pass @@ -22,7 +24,7 @@ def err(msg): if strict: raise KVFormError(formatted) else: - logging.warn(formatted) + _LOGGER.warn(formatted) lines = [] for k, v in seq: @@ -73,7 +75,7 @@ def err(msg): if strict: raise KVFormError(formatted) else: - logging.warn(formatted) + _LOGGER.warn(formatted) lines = data.split('\n') if lines[-1]: diff --git a/openid/oidutil.py b/openid/oidutil.py index 36d0af10..a92b453a 100644 --- a/openid/oidutil.py +++ b/openid/oidutil.py @@ -13,6 +13,8 @@ import urlparse from urllib import urlencode +_LOGGER = logging.getLogger(__name__) + elementtree_modules = [ 'lxml.etree', 'xml.etree.cElementTree', @@ -77,8 +79,8 @@ def importElementTree(module_names=None): except (SystemExit, MemoryError, AssertionError): raise except: - logging.exception('Not using ElementTree library %r because it failed to ' - 'parse a trivial document: %s' % mod_name) + logging.exception('Not using ElementTree library %r because it failed to parse a trivial document: %s', + mod_name) else: return ElementTree else: @@ -105,8 +107,7 @@ def log(message, level=0): @returns: Nothing. """ - logging.error("This is a legacy log message, please use the " - "logging module. Message: %s", message) + logging.error("This is a legacy log message, please use the logging module. Message: %s", message) def appendArgs(url, args): """Append query arguments to a HTTP(s) URL. If the URL already has diff --git a/openid/server/server.py b/openid/server/server.py index 7cd1ae99..1e456e0a 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -130,6 +130,8 @@ from openid.store.nonce import mkNonce from openid.urinorm import urinorm +_LOGGER = logging.getLogger(__name__) + HTTP_OK = 200 HTTP_REDIRECT = 302 HTTP_ERROR = 400 @@ -420,7 +422,7 @@ def fromMessage(klass, message, op_endpoint=UNUSED): if message.isOpenID1(): session_type = message.getArg(OPENID_NS, 'session_type') if session_type == 'no-encryption': - logging.warn('Received OpenID 1 request with a no-encryption ' + _LOGGER.warn('Received OpenID 1 request with a no-encryption ' 'assocaition session type. Continuing anyway.') elif not session_type: session_type = 'no-encryption' @@ -1179,17 +1181,13 @@ def verify(self, assoc_handle, message): """ assoc = self.getAssociation(assoc_handle, dumb=True) if not assoc: - logging.error("failed to get assoc with handle %r to verify " - "message %r" - % (assoc_handle, message)) + _LOGGER.error("failed to get assoc with handle %r to verify message %r", assoc_handle, message) return False try: valid = assoc.checkMessageSignature(message) except ValueError, ex: - logging.exception("Error in verifying %s with %s: %s" % (message, - assoc, - ex)) + _LOGGER.exception("Error in verifying %s with %s: %s", message, assoc, ex) return False return valid @@ -1294,9 +1292,8 @@ def getAssociation(self, assoc_handle, dumb, checkExpiration=True): key = self._normal_key assoc = self.store.getAssociation(key, assoc_handle) if assoc is not None and assoc.expiresIn <= 0: - logging.info("requested %sdumb key %r is expired (by %s seconds)" % - ((not dumb) and 'not-' or '', - assoc_handle, assoc.expiresIn)) + _LOGGER.info("requested %sdumb key %r is expired (by %s seconds)", + (not dumb) and 'not-' or '', assoc_handle, assoc.expiresIn) if checkExpiration: self.store.removeAssociation(key, assoc_handle) assoc = None diff --git a/openid/server/trustroot.py b/openid/server/trustroot.py index 6dce8d34..955a0d8b 100644 --- a/openid/server/trustroot.py +++ b/openid/server/trustroot.py @@ -24,6 +24,8 @@ from openid import urinorm from openid.yadis import services +_LOGGER = logging.getLogger(__name__) + ############################################ _protocols = ['http', 'https'] _top_level_domains = [ @@ -443,12 +445,12 @@ def verifyReturnTo(realm_str, return_to, _vrfy=getAllowedReturnURLs): try: allowable_urls = _vrfy(realm.buildDiscoveryURL()) except RealmVerificationRedirected, err: - logging.exception(str(err)) + _LOGGER.exception(str(err)) return False if returnToMatches(allowable_urls, return_to): return True else: - logging.error("Failed to validate return_to %r for realm %r, was not " - "in %s" % (return_to, realm_str, allowable_urls)) + _LOGGER.error("Failed to validate return_to %r for realm %r, was not in %s", + return_to, realm_str, allowable_urls) return False diff --git a/openid/store/filestore.py b/openid/store/filestore.py index e43b1468..3ec4c599 100644 --- a/openid/store/filestore.py +++ b/openid/store/filestore.py @@ -16,6 +16,8 @@ from openid.store import nonce from openid.store.interface import OpenIDStore +_LOGGER = logging.getLogger(__name__) + _filename_allowed = string.ascii_letters + string.digits + '.' _isFilenameSafe = set(_filename_allowed).__contains__ @@ -332,8 +334,8 @@ def _allAssocs(self): association_file = file(association_filename, 'rb') except IOError, why: if why.errno == ENOENT: - logging.exception("%s disappeared during %s._allAssocs" % ( - association_filename, self.__class__.__name__)) + _LOGGER.exception("%s disappeared during %s._allAssocs", + association_filename, self.__class__.__name__) else: raise else: diff --git a/openid/test/support.py b/openid/test/support.py index 3901e25d..d61973c6 100644 --- a/openid/test/support.py +++ b/openid/test/support.py @@ -13,7 +13,7 @@ def shouldFlush(self): return False def emit(self, record): - self.messages.append(record.__dict__) + self.messages.append(record) class OpenIDTestMixin(object): def failUnlessOpenIDValueEquals(self, msg, key, expected, ns=None): @@ -57,7 +57,7 @@ def failUnlessLogMatches(self, *prefixes): number of prefixes is different than the number of log messages. """ - messages = [r['msg'] for r in self.messages] + messages = [r.getMessage() for r in self.messages] assert len(prefixes) == len(messages), \ "Expected log prefixes %r, got %r" % (prefixes, messages) From f58d7cee3e9f4bff9854dc10ffcd105fb3bc6619 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 24 Nov 2017 15:09:25 +0100 Subject: [PATCH 023/151] Drop uncaught_exceptions from ExceptionWrappingFetcher --- openid/fetchers.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/openid/fetchers.py b/openid/fetchers.py index d4b80290..b30f8954 100644 --- a/openid/fetchers.py +++ b/openid/fetchers.py @@ -159,13 +159,7 @@ def __init__(self, why=None): self.why = why class ExceptionWrappingFetcher(HTTPFetcher): - """Fetcher that wraps another fetcher, causing all exceptions - - @cvar uncaught_exceptions: Exceptions that should be exposed to the - user if they are raised by the fetch call - """ - - uncaught_exceptions = (SystemExit, KeyboardInterrupt, MemoryError) + """Fetcher wrapper which wraps all exceptions to `HTTPFetchingError`.""" def __init__(self, fetcher): self.fetcher = fetcher @@ -173,9 +167,7 @@ def __init__(self, fetcher): def fetch(self, *args, **kwargs): try: return self.fetcher.fetch(*args, **kwargs) - except self.uncaught_exceptions: - raise - except: + except Exception: exc_cls, exc_inst = sys.exc_info()[:2] if exc_inst is None: # string exceptions From 8f0ff0d27771514d16a415b8ac76d18ea0809f38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 24 Nov 2017 15:07:29 +0100 Subject: [PATCH 024/151] Pepify and add flake8 --- .flake8 | 5 + .travis.yml | 3 +- Makefile | 7 +- admin/builddiscover.py | 2 + admin/gettlds.py | 2 +- admin/runtests | 39 +- contrib/associate | 3 + contrib/openid-parse | 26 +- contrib/upgrade-store-1.1-to-2.0 | 36 +- examples/consumer.py | 28 +- examples/discover | 12 +- examples/djopenid/consumer/models.py | 4 +- examples/djopenid/consumer/urls.py | 3 +- examples/djopenid/consumer/views.py | 20 +- examples/djopenid/manage.py | 6 +- examples/djopenid/server/models.py | 4 +- examples/djopenid/server/tests.py | 8 +- examples/djopenid/server/urls.py | 3 +- examples/djopenid/server/views.py | 31 +- examples/djopenid/settings.py | 8 +- examples/djopenid/urls.py | 2 +- examples/djopenid/util.py | 33 +- examples/djopenid/views.py | 2 +- examples/server.py | 62 ++- openid/__init__.py | 2 +- openid/association.py | 48 +-- openid/consumer/consumer.py | 113 +++-- openid/consumer/discover.py | 42 +- openid/consumer/html_parse.py | 38 +- openid/cryptutil.py | 28 +- openid/dh.py | 16 +- openid/extension.py | 2 + openid/extensions/ax.py | 40 +- openid/extensions/draft/pape2.py | 15 +- openid/extensions/draft/pape5.py | 16 +- openid/extensions/sreg.py | 37 +- openid/fetchers.py | 28 +- openid/kvform.py | 7 +- openid/message.py | 33 +- openid/oidutil.py | 29 +- openid/server/server.py | 108 ++--- openid/server/trustroot.py | 25 +- openid/sreg.py | 4 +- openid/store/filestore.py | 29 +- openid/store/interface.py | 1 + openid/store/memstore.py | 2 +- openid/store/nonce.py | 5 +- openid/store/sqlstore.py | 25 +- openid/test/cryptutil.py | 46 +- openid/test/datadriven.py | 4 +- openid/test/dh.py | 9 +- openid/test/discoverdata.py | 50 ++- openid/test/kvform.py | 43 +- openid/test/linkparse.py | 5 + openid/test/oidutil.py | 20 +- openid/test/storetest.py | 49 ++- openid/test/support.py | 36 +- openid/test/test_accept.py | 11 +- openid/test/test_association.py | 27 +- openid/test/test_association_response.py | 73 ++-- openid/test/test_auth_request.py | 12 +- openid/test/test_ax.py | 243 +++++------ openid/test/test_consumer.py | 511 +++++++++++------------ openid/test/test_discover.py | 122 +++--- openid/test/test_etxrd.py | 17 +- openid/test/test_examples.py | 13 +- openid/test/test_extension.py | 1 + openid/test/test_fetchers.py | 51 +-- openid/test/test_htmldiscover.py | 3 +- openid/test/test_message.py | 318 +++++++------- openid/test/test_negotiation.py | 13 +- openid/test/test_nonce.py | 10 +- openid/test/test_openidyadis.py | 31 +- openid/test/test_pape_draft2.py | 68 +-- openid/test/test_pape_draft5.py | 70 ++-- openid/test/test_parsehtml.py | 9 +- openid/test/test_rpverify.py | 14 +- openid/test/test_server.py | 438 ++++++++++--------- openid/test/test_sreg.py | 112 ++--- openid/test/test_symbol.py | 1 + openid/test/test_urinorm.py | 3 +- openid/test/test_verifydisco.py | 77 ++-- openid/test/test_xri.py | 9 +- openid/test/test_xrires.py | 2 - openid/test/test_yadis_discover.py | 19 +- openid/test/trustroot.py | 5 + openid/urinorm.py | 23 +- openid/yadis/__init__.py | 2 +- openid/yadis/accept.py | 30 +- openid/yadis/constants.py | 2 +- openid/yadis/discover.py | 6 +- openid/yadis/etxrd.py | 33 +- openid/yadis/filters.py | 10 +- openid/yadis/manager.py | 3 +- openid/yadis/parsehtml.py | 17 +- openid/yadis/services.py | 3 +- openid/yadis/xri.py | 20 +- openid/yadis/xrires.py | 9 +- pylintrc | 40 -- setup.py | 22 +- 100 files changed, 1954 insertions(+), 1853 deletions(-) create mode 100644 .flake8 delete mode 100644 pylintrc diff --git a/.flake8 b/.flake8 new file mode 100644 index 00000000..75ab4379 --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] +max-line-length = 120 +# Ignore E123 - enforce hang-closing instead +ignore = E123,W503 +max-complexity = 22 diff --git a/.travis.yml b/.travis.yml index 14636d8c..fe119573 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,8 +3,9 @@ language: python python: - 2.7 -before_install: pip install Django pycrypto lxml isort +before_install: pip install Django pycrypto lxml isort flake8 install: python setup.py install script: - make check-isort + - make check-flake8 - make test diff --git a/Makefile b/Makefile index 8b779090..2cba2775 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: test coverage isort check-isort +.PHONY: test coverage isort check-all check-isort check-flake8 test: python admin/runtests @@ -12,5 +12,10 @@ coverage: isort: isort --recursive . +check-all: check-isort check-flake8 + check-isort: isort --check-only --diff --recursive . + +check-flake8: + flake8 --format=pylint . diff --git a/admin/builddiscover.py b/admin/builddiscover.py index 011ab883..ef4ede92 100755 --- a/admin/builddiscover.py +++ b/admin/builddiscover.py @@ -29,6 +29,7 @@ """ + def buildDiscover(base_url, out_dir): """Convert all files in a directory to apache mod_asis files in another directory.""" @@ -63,6 +64,7 @@ def writeTestFile(test_name): manifest_file.write(chunk) manifest_file.close() + if __name__ == '__main__': import sys buildDiscover(*sys.argv[1:]) diff --git a/admin/gettlds.py b/admin/gettlds.py index f473224d..b2a7c92c 100644 --- a/admin/gettlds.py +++ b/admin/gettlds.py @@ -21,7 +21,7 @@ 'ruby': ("%w'", "", " ", "", "'"), - } +} lang = sys.argv[1] prefix, line_prefix, separator, line_suffix, suffix = langs[lang] diff --git a/admin/runtests b/admin/runtests index b2a3a79f..db7a647e 100755 --- a/admin/runtests +++ b/admin/runtests @@ -1,11 +1,14 @@ #!/usr/bin/env python -import os.path, sys, warnings +import os.path +import sys +import warnings test_modules = [ 'cryptutil', 'oidutil', 'dh', - ] +] + def fixpath(): try: @@ -17,10 +20,11 @@ def fixpath(): print "putting %s in sys.path" % (parent,) sys.path.insert(0, parent) + def otherTests(): failed = [] for module_name in test_modules: - print 'Testing %s...' % (module_name,) , + print 'Testing %s...' % (module_name,), sys.stdout.flush() module_name = 'openid.test.' + module_name try: @@ -31,17 +35,15 @@ def otherTests(): else: try: test_mod.test() - except (SystemExit, KeyboardInterrupt): - raise - except: + except Exception: sys.excepthook(*sys.exc_info()) failed.append(module_name) else: print 'Succeeded.' - return failed + def pyunitTests(): import unittest pyunit_module_names = [ @@ -63,16 +65,16 @@ def pyunitTests(): 'pape_draft5', 'rpverify', 'extension', - ] + ] pyunit_modules = [ __import__('openid.test.test_%s' % (name,), {}, {}, ['unused']) for name in pyunit_module_names - ] + ] try: from openid.test import test_examples - except ImportError, e: + except ImportError as e: if 'twill' in str(e): warnings.warn("Could not import twill; skipping test_examples.") else: @@ -98,7 +100,7 @@ def pyunitTests(): 'test_urinorm', 'test_yadis_discover', 'trustroot', - ] + ] loader = unittest.TestLoader() s = unittest.TestSuite() @@ -110,18 +112,17 @@ def pyunitTests(): m = __import__('openid.test.%s' % (name,), {}, {}, ['unused']) try: s.addTest(m.pyUnitTests()) - except AttributeError, ex: + except AttributeError as ex: # because the AttributeError doesn't actually say which # object it was. print "Error loading tests from %s:" % (name,) raise - runner = unittest.TextTestRunner() # verbosity=2) + runner = unittest.TextTestRunner() # verbosity=2) return runner.run(s) - def splitDir(d, count): # in python2.4 and above, it's easier to spell this as # d.rsplit(os.sep, count) @@ -130,7 +131,6 @@ def splitDir(d, count): return d - def _import_djopenid(): """Import djopenid from examples/ @@ -153,7 +153,6 @@ def _import_djopenid(): sys.modules['djopenid'] = djopenid - def django_tests(): """Runs tests from examples/djopenid. @@ -167,11 +166,12 @@ def django_tests(): try: import django.test.simple - except ImportError, e: + except ImportError as e: warnings.warn("django.test.simple not found; " "django examples not tested.") return 0 - import djopenid.server.models, djopenid.consumer.models + import djopenid.server.models + import djopenid.consumer.models print "Testing Django examples:" # These tests do get put in to a pyunit test suite, so we could run them @@ -180,12 +180,14 @@ def django_tests(): return django.test.simple.run_tests([djopenid.server.models, djopenid.consumer.models]) + try: bool except NameError: def bool(x): return not not x + def main(): fixpath() other_failed = otherTests() @@ -200,5 +202,6 @@ def main(): (django_failures > 0)) return failed + if __name__ == '__main__': sys.exit(main() and 1 or 0) diff --git a/contrib/associate b/contrib/associate index 4cb05c31..76fe5b0e 100755 --- a/contrib/associate +++ b/contrib/associate @@ -10,6 +10,7 @@ from openid.consumer.discover import OpenIDServiceEndpoint from datetime import datetime + def verboseAssociation(assoc): """A more verbose representation of an Association. """ @@ -24,6 +25,7 @@ def verboseAssociation(assoc): """ return fmt % d + def main(): if not sys.argv[1:]: print "Usage: %s ENDPOINT_URL..." % (sys.argv[0],) @@ -43,5 +45,6 @@ def main(): else: print " ...no association." + if __name__ == '__main__': main() diff --git a/contrib/openid-parse b/contrib/openid-parse index 21ab18df..ac2c5dff 100644 --- a/contrib/openid-parse +++ b/contrib/openid-parse @@ -9,12 +9,16 @@ Requires the 'xsel' program to get the contents of the clipboard. from pprint import pformat from urlparse import urlsplit, urlunsplit -import cgi, re, subprocess, sys +import cgi +import re +import subprocess +import sys from openid import message OPENID_SORT_ORDER = ['mode', 'identity', 'claimed_id'] + class NoQuery(Exception): def __init__(self, url): self.url = url @@ -42,7 +46,7 @@ def main(): for url in urls: try: queries.append(queryFromURL(url)) - except NoQuery, err: + except NoQuery as err: errors.append(err) queries.extend(queriesFromLogs(source)) @@ -73,7 +77,7 @@ def openidFromQuery(query): try: msg = message.Message.fromPostArgs(unlistify(query)) s = formatOpenIDMessage(msg) - except Exception, err: + except Exception as err: # XXX - side effect. sys.stderr.write(str(err)) s = pformat(query) @@ -103,8 +107,7 @@ def formatOpenIDMessage(msg): except KeyError: pass - values = values.items() - values.sort() + values = sorted(values.items()) for k, v in values: ns_output.append(" %s = %s" % (k, v)) @@ -124,6 +127,7 @@ def queriesFromLogs(s): return [(match.group(1), cgi.parse_qs(match.group(2))) for match in qre.finditer(s)] + def queriesFromPostdata(s): # This looks for query data in a line that starts POSTDATA=. # Tamperdata outputs such lines. If there's a 'Host=' in that block, @@ -133,16 +137,18 @@ def queriesFromPostdata(s): return [(match.group('host') or 'POSTDATA', cgi.parse_qs(match.group('query'))) for match in qre.finditer(s)] + def find_urls(s): # Regular expression borrowed from urlscan # by Daniel Burrows , GPL. - urlinternalpattern=r'[{}a-zA-Z/\-_0-9%?&.=:;+,#~]' - urltrailingpattern=r'[{}a-zA-Z/\-_0-9%&=+#]' + urlinternalpattern = r'[{}a-zA-Z/\-_0-9%?&.=:;+,#~]' + urltrailingpattern = r'[{}a-zA-Z/\-_0-9%&=+#]' httpurlpattern = r'(?:https?://' + urlinternalpattern + r'*' + urltrailingpattern + r')' # Used to guess that blah.blah.blah.TLD is a URL. - tlds=['biz', 'com', 'edu', 'info', 'org'] - guessedurlpattern=r'(?:[a-zA-Z0-9_\-%]+(?:\.[a-zA-Z0-9_\-%]+)*\.(?:' + '|'.join(tlds) + '))' - urlre = re.compile(r'(?:<(?:URL:)?)?(' + httpurlpattern + '|' + guessedurlpattern + '|(?:mailto:[a-zA-Z0-9\-_]*@[0-9a-zA-Z_\-.]*[0-9a-zA-Z_\-]))>?') + tlds = ['biz', 'com', 'edu', 'info', 'org'] + guessedurlpattern = r'(?:[a-zA-Z0-9_\-%]+(?:\.[a-zA-Z0-9_\-%]+)*\.(?:' + '|'.join(tlds) + '))' + urlre = re.compile(r'(?:<(?:URL:)?)?(' + httpurlpattern + '|' + guessedurlpattern + + '|(?:mailto:[a-zA-Z0-9\-_]*@[0-9a-zA-Z_\-.]*[0-9a-zA-Z_\-]))>?') return [match.group(1) for match in urlre.finditer(s)] diff --git a/contrib/upgrade-store-1.1-to-2.0 b/contrib/upgrade-store-1.1-to-2.0 index 1f587c35..1907ce37 100644 --- a/contrib/upgrade-store-1.1-to-2.0 +++ b/contrib/upgrade-store-1.1-to-2.0 @@ -23,15 +23,16 @@ from optparse import OptionParser def askForPassword(): return getpass.getpass("DB Password: ") -def askForConfirmation(dbname,tablename): + +def askForConfirmation(dbname, tablename): print """The table %s from the database %s will be dropped, and - an empty table with the new nonce table schema will replace it."""%( - tablename, dbname) + an empty table with the new nonce table schema will replace it.""" % (tablename, dbname) return raw_input("Continue? ").lower().strip().startswith('y') + def doSQLiteUpgrade(db_conn, nonce_table_name='oid_nonces'): cur = db_conn.cursor() - cur.execute('DROP TABLE %s'%nonce_table_name) + cur.execute('DROP TABLE %s' % nonce_table_name) sql = """ CREATE TABLE %s ( server_url VARCHAR, @@ -39,13 +40,14 @@ def doSQLiteUpgrade(db_conn, nonce_table_name='oid_nonces'): salt CHAR(40), UNIQUE(server_url, timestamp, salt) ); - """%nonce_table_name + """ % nonce_table_name cur.execute(sql) cur.close() + def doMySQLUpgrade(db_conn, nonce_table_name='oid_nonces'): cur = db_conn.cursor() - cur.execute('DROP TABLE %s'%nonce_table_name) + cur.execute('DROP TABLE %s' % nonce_table_name) sql = """ CREATE TABLE %s ( server_url BLOB, @@ -54,13 +56,14 @@ def doMySQLUpgrade(db_conn, nonce_table_name='oid_nonces'): PRIMARY KEY (server_url(255), timestamp, salt) ) TYPE=InnoDB; - """%nonce_table_name + """ % nonce_table_name cur.execute(sql) cur.close() + def doPostgreSQLUpgrade(db_conn, nonce_table_name='oid_nonces'): cur = db_conn.cursor() - cur.execute('DROP TABLE %s'%nonce_table_name) + cur.execute('DROP TABLE %s' % nonce_table_name) sql = """ CREATE TABLE %s ( server_url VARCHAR(2047), @@ -68,11 +71,12 @@ def doPostgreSQLUpgrade(db_conn, nonce_table_name='oid_nonces'): salt CHAR(40), PRIMARY KEY (server_url, timestamp, salt) ); - """%nonce_table_name + """ % nonce_table_name cur.execute(sql) cur.close() db_conn.commit() + def main(argv=None): parser = OptionParser() parser.add_option("-u", "--user", dest="username", @@ -106,7 +110,7 @@ def main(argv=None): return 1 try: db_conn = sqlite.connect(options.sqlite_db_name) - except Exception, e: + except Exception as e: print "Could not connect to SQLite database:", str(e) return 1 @@ -125,11 +129,11 @@ def main(argv=None): return 1 try: - db_conn = psycopg.connect(database = options.postgres_db_name, - user = options.username, - host = options.db_host, - password = password) - except Exception, e: + db_conn = psycopg.connect(database=options.postgres_db_name, + user=options.username, + host=options.db_host, + password=password) + except Exception as e: print "Could not connect to PostgreSQL database:", str(e) return 1 @@ -150,7 +154,7 @@ def main(argv=None): try: db_conn = MySQLdb.connect(options.db_host, options.username, password, options.mysql_db_name) - except Exception, e: + except Exception as e: print "Could not connect to MySQL database:", str(e) return 1 diff --git a/examples/consumer.py b/examples/consumer.py index 1d448ded..908130af 100644 --- a/examples/consumer.py +++ b/examples/consumer.py @@ -35,6 +35,7 @@ def quoteattr(s): distribution.""") sys.exit(1) else: + del openid from openid.consumer import consumer from openid.cryptutil import randomString from openid.extensions import pape, sreg @@ -45,13 +46,14 @@ def quoteattr(s): # Used with an OpenID provider affiliate program. OPENID_PROVIDER_NAME = 'MyOpenID' -OPENID_PROVIDER_URL ='https://www.myopenid.com/affiliate_signup?affiliate_id=39' +OPENID_PROVIDER_URL = 'https://www.myopenid.com/affiliate_signup?affiliate_id=39' class OpenIDHTTPServer(HTTPServer): """http server that contains a reference to an OpenID consumer and knows its base URL. """ + def __init__(self, store, *args, **kwargs): HTTPServer.__init__(self, *args, **kwargs) self.sessions = {} @@ -63,6 +65,7 @@ def __init__(self, store, *args, **kwargs): else: self.base_url = 'http://%s/' % (self.server_name,) + class OpenIDRequestHandler(BaseHTTPRequestHandler): """Request handler that knows how to verify an OpenID identity.""" SESSION_COOKIE_NAME = 'pyoidconsexsid' @@ -145,9 +148,7 @@ def do_GET(self): else: self.notFound() - except (KeyboardInterrupt, SystemExit): - raise - except: + except Exception: self.send_response(500) self.send_header('Content-type', 'text/html') self.setSessionCookie() @@ -170,10 +171,10 @@ def doVerify(self): use_pape = 'use_pape' in self.query use_stateless = 'use_stateless' in self.query - oidconsumer = self.getConsumer(stateless = use_stateless) + oidconsumer = self.getConsumer(stateless=use_stateless) try: request = oidconsumer.begin(openid_url) - except consumer.DiscoveryFailure, exc: + except consumer.DiscoveryFailure as exc: fetch_error_string = 'Error in discovery: %s' % ( cgi.escape(str(exc[0]))) self.render(fetch_error_string, @@ -207,7 +208,7 @@ def doVerify(self): else: form_html = request.htmlMarkup( trust_root, return_to, - form_tag_attrs={'id':'openid_message'}, + form_tag_attrs={'id': 'openid_message'}, immediate=immediate) self.wfile.write(form_html) @@ -230,7 +231,7 @@ def doProcess(self): # us. Status is a code indicating the response type. info is # either None or a string containing more information about # the return type. - url = 'http://'+self.headers.get('Host')+self.path + url = 'http://' + self.headers.get('Host') + self.path info = oidconsumer.complete(self.query, url) sreg_resp = None @@ -300,8 +301,7 @@ def renderSREG(self, sreg_data): self.wfile.write( '
No registration data was returned
') else: - sreg_list = sreg_data.items() - sreg_list.sort() + sreg_list = sorted(sreg_data.items()) self.wfile.write( '

Registration Data

' '' @@ -443,14 +443,17 @@ def pageFooter(self, form_contents):
- - + + + + ''' % (quoteattr(self.buildURL('verify')), quoteattr(form_contents))) + def main(host, port, data_path, weak_ssl=False): # Instantiate OpenID consumer store and OpenID consumer. If you # were connecting to a database, you would create the database @@ -470,6 +473,7 @@ def main(host, port, data_path, weak_ssl=False): print server.base_url server.serve_forever() + if __name__ == '__main__': parser = optparse.OptionParser('Usage:\n %prog [options]') parser.add_option( diff --git a/examples/discover b/examples/discover index 9b74e8a0..e2ede67e 100644 --- a/examples/discover +++ b/examples/discover @@ -2,10 +2,11 @@ from openid.consumer.discover import discover, DiscoveryFailure from openid.fetchers import HTTPFetchingError -names = [["server_url", "Server URL "], - ["local_id", "Local ID "], +names = [["server_url", "Server URL "], + ["local_id", "Local ID "], ["canonicalID", "Canonical ID"], - ] + ] + def show_services(user_input, normalized, services): print " Claimed identifier:", normalized @@ -28,6 +29,7 @@ def show_services(user_input, normalized, services): print " No OpenID services found" print + if __name__ == "__main__": import sys @@ -36,10 +38,10 @@ if __name__ == "__main__": print "Running discovery on", user_input try: normalized, services = discover(user_input) - except DiscoveryFailure, why: + except DiscoveryFailure as why: print "Discovery failed:", why print - except HTTPFetchingError, why: + except HTTPFetchingError as why: print "HTTP request failed:", why print else: diff --git a/examples/djopenid/consumer/models.py b/examples/djopenid/consumer/models.py index 71a83623..b194906e 100644 --- a/examples/djopenid/consumer/models.py +++ b/examples/djopenid/consumer/models.py @@ -1,3 +1 @@ -from django.db import models - -# Create your models here. +"""Required module for Django application.""" diff --git a/examples/djopenid/consumer/urls.py b/examples/djopenid/consumer/urls.py index d55e056c..7190093e 100644 --- a/examples/djopenid/consumer/urls.py +++ b/examples/djopenid/consumer/urls.py @@ -1,5 +1,4 @@ - -from django.conf.urls.defaults import * +from django.conf.urls.defaults import patterns urlpatterns = patterns( 'djopenid.consumer.views', diff --git a/examples/djopenid/consumer/views.py b/examples/djopenid/consumer/views.py index 1f4dd945..bbc0ff87 100644 --- a/examples/djopenid/consumer/views.py +++ b/examples/djopenid/consumer/views.py @@ -1,5 +1,3 @@ - -from django import http from django.http import HttpResponseRedirect from django.views.generic.simple import direct_to_template @@ -7,7 +5,7 @@ from openid.consumer.discover import DiscoveryFailure from openid.extensions import ax, pape, sreg from openid.server.trustroot import RP_RETURN_TO_URL_TYPE -from openid.yadis.constants import YADIS_CONTENT_TYPE, YADIS_HEADER_NAME +from openid.yadis.constants import YADIS_HEADER_NAME from .. import util @@ -15,12 +13,13 @@ 'AUTH_PHISHING_RESISTANT', 'AUTH_MULTI_FACTOR', 'AUTH_MULTI_FACTOR_PHYSICAL', - ] +] # List of (name, uri) for use in generating the request form. POLICY_PAIRS = [(p, getattr(pape, p)) for p in PAPE_POLICIES] + def getOpenIDStore(): """ Return an OpenID store object fit for the currently-chosen @@ -28,21 +27,24 @@ def getOpenIDStore(): """ return util.getOpenIDStore('/tmp/djopenid_c_store', 'c_') + def getConsumer(request): """ Get a Consumer object to perform OpenID authentication. """ return consumer.Consumer(request.session, getOpenIDStore()) + def renderIndexPage(request, **template_args): template_args['consumer_url'] = util.getViewURL(request, startOpenID) template_args['pape_policies'] = POLICY_PAIRS - response = direct_to_template( + response = direct_to_template( request, 'consumer/index.html', template_args) response[YADIS_HEADER_NAME] = util.getViewURL(request, rpXRDS) return response + def startOpenID(request): """ Start the OpenID authentication process. Renders an @@ -67,7 +69,7 @@ def startOpenID(request): try: auth_request = c.begin(openid_url) - except DiscoveryFailure, e: + except DiscoveryFailure as e: # Some other protocol-level failure occurred. error = "OpenID discovery error: %s" % (str(e),) @@ -133,6 +135,7 @@ def startOpenID(request): return renderIndexPage(request) + def finishOpenID(request): """ Finish the OpenID authentication process. Invoke the OpenID @@ -173,7 +176,7 @@ def finishOpenID(request): 'http://schema.openid.net/namePerson'), 'web': ax_response.get( 'http://schema.openid.net/contact/web/default'), - } + } # Get a PAPE response object if response information was # included in the OpenID response. @@ -197,7 +200,7 @@ def finishOpenID(request): 'sreg': sreg_response and sreg_response.items(), 'ax': ax_items.items(), 'pape': pape_response} - } + } result = results[response.status] @@ -210,6 +213,7 @@ def finishOpenID(request): return renderIndexPage(request, **result) + def rpXRDS(request): """ Return a relying party verification XRDS document diff --git a/examples/djopenid/manage.py b/examples/djopenid/manage.py index ae949585..45a1ee63 100644 --- a/examples/djopenid/manage.py +++ b/examples/djopenid/manage.py @@ -2,10 +2,12 @@ from django.core.management import execute_manager try: - import settings # Assumed to be in the same directory. + import settings # Assumed to be in the same directory. except ImportError: import sys - sys.stderr.write("Error: Can't find the file 'settings.py' in the directory containing %r. It appears you've customized things.\nYou'll have to run django-admin.py, passing it your settings module.\n(If the file settings.py does indeed exist, it's causing an ImportError somehow.)\n" % __file__) + sys.stderr.write("Error: Can't find the file 'settings.py' in the directory containing %r. It appears you've " + "customized things.\nYou'll have to run django-admin.py, passing it your settings module.\n(If " + "the file settings.py does indeed exist, it's causing an ImportError somehow.)\n" % __file__) sys.exit(1) if __name__ == "__main__": diff --git a/examples/djopenid/server/models.py b/examples/djopenid/server/models.py index 71a83623..b194906e 100644 --- a/examples/djopenid/server/models.py +++ b/examples/djopenid/server/models.py @@ -1,3 +1 @@ -from django.db import models - -# Create your models here. +"""Required module for Django application.""" diff --git a/examples/djopenid/server/tests.py b/examples/djopenid/server/tests.py index d86151bc..6cae5471 100644 --- a/examples/djopenid/server/tests.py +++ b/examples/djopenid/server/tests.py @@ -19,6 +19,7 @@ def dummyRequest(): request.META['SERVER_PROTOCOL'] = 'HTTP' return request + class TestProcessTrustResult(TestCase): def setUp(self): self.request = dummyRequest() @@ -32,12 +33,11 @@ def setUp(self): 'openid.identity': id_url, 'openid.return_to': 'http://127.0.0.1/%s' % (self.id(),), 'openid.sreg.required': 'postcode', - }) + }) self.openid_request = CheckIDRequest.fromMessage(message, op_endpoint) views.setRequest(self.request, self.openid_request) - def test_allow(self): self.request.POST['allow'] = 'Yes' @@ -61,7 +61,6 @@ def test_cancel(self): self.failIf('openid.sreg.postcode=12345' in finalURL, finalURL) - class TestShowDecidePage(TestCase): def test_unreachableRealm(self): self.request = dummyRequest() @@ -75,7 +74,7 @@ def test_unreachableRealm(self): 'openid.identity': id_url, 'openid.return_to': 'http://unreachable.invalid/%s' % (self.id(),), 'openid.sreg.required': 'postcode', - }) + }) self.openid_request = CheckIDRequest.fromMessage(message, op_endpoint) views.setRequest(self.request, self.openid_request) @@ -85,7 +84,6 @@ def test_unreachableRealm(self): response) - class TestGenericXRDS(TestCase): def test_genericRender(self): """Render an XRDS document with a single type URI and a single endpoint URL diff --git a/examples/djopenid/server/urls.py b/examples/djopenid/server/urls.py index d6931a4d..6763d856 100644 --- a/examples/djopenid/server/urls.py +++ b/examples/djopenid/server/urls.py @@ -1,5 +1,4 @@ - -from django.conf.urls.defaults import * +from django.conf.urls.defaults import patterns urlpatterns = patterns( 'djopenid.server.views', diff --git a/examples/djopenid/server/views.py b/examples/djopenid/server/views.py index bb6d6602..bbb9468d 100644 --- a/examples/djopenid/server/views.py +++ b/examples/djopenid/server/views.py @@ -23,7 +23,7 @@ from openid.consumer.discover import OPENID_IDP_2_0_TYPE from openid.extensions import pape, sreg from openid.fetchers import HTTPFetchingError -from openid.server.server import CheckIDRequest, EncodingError, ProtocolError, Server +from openid.server.server import EncodingError, ProtocolError, Server from openid.server.trustroot import verifyReturnTo from openid.yadis.discover import DiscoveryFailure @@ -38,12 +38,14 @@ def getOpenIDStore(): """ return util.getOpenIDStore('/tmp/djopenid_s_store', 's_') + def getServer(request): """ Get a Server object to perform OpenID authentication. """ return Server(getOpenIDStore(), getViewURL(request, endpoint)) + def setRequest(request, openid_request): """ Store the openid request information in the session. @@ -53,12 +55,14 @@ def setRequest(request, openid_request): else: request.session['openid_request'] = None + def getRequest(request): """ Get an openid request from the session, if any. """ return request.session.get('openid_request') + def server(request): """ Respond to requests for the server's primary web page. @@ -70,6 +74,7 @@ def server(request): 'server_xrds_url': getViewURL(request, idpXrds), }) + def idpXrds(request): """ Respond to requests for the IDP's XRDS document, which is used in @@ -78,6 +83,7 @@ def idpXrds(request): return util.renderXRDS( request, [OPENID_IDP_2_0_TYPE], [getViewURL(request, endpoint)]) + def idPage(request): """ Serve the identity page for OpenID URLs. @@ -87,6 +93,7 @@ def idPage(request): 'server/idPage.html', {'server_url': getViewURL(request, endpoint)}) + def trustPage(request): """ Display the trust page template, which allows the user to decide @@ -95,7 +102,8 @@ def trustPage(request): return direct_to_template( request, 'server/trust.html', - {'trust_handler_url':getViewURL(request, processTrustResult)}) + {'trust_handler_url': getViewURL(request, processTrustResult)}) + def endpoint(request): """ @@ -109,7 +117,7 @@ def endpoint(request): # library can use. try: openid_request = s.decodeRequest(query) - except ProtocolError, why: + except ProtocolError as why: # This means the incoming request was invalid. return direct_to_template( request, @@ -134,6 +142,7 @@ def endpoint(request): openid_response = s.handleRequest(openid_request) return displayResponse(request, openid_response) + def handleCheckIDRequest(request, openid_request): """ Handle checkid_* requests. Get input from the user to find out @@ -175,6 +184,7 @@ def handleCheckIDRequest(request, openid_request): setRequest(request, openid_request) return showDecidePage(request, openid_request) + def showDecidePage(request, openid_request): """ Render a page to the user so a trust decision can be made. @@ -186,11 +196,10 @@ def showDecidePage(request, openid_request): try: # Stringify because template's ifequal can only compare to strings. - trust_root_valid = verifyReturnTo(trust_root, return_to) \ - and "Valid" or "Invalid" - except DiscoveryFailure, err: + trust_root_valid = verifyReturnTo(trust_root, return_to) and "Valid" or "Invalid" + except DiscoveryFailure: trust_root_valid = "DISCOVERY_FAILED" - except HTTPFetchingError, err: + except HTTPFetchingError: trust_root_valid = "Unreachable" pape_request = pape.Request.fromOpenIDRequest(openid_request) @@ -199,11 +208,12 @@ def showDecidePage(request, openid_request): request, 'server/trust.html', {'trust_root': trust_root, - 'trust_handler_url':getViewURL(request, processTrustResult), + 'trust_handler_url': getViewURL(request, processTrustResult), 'trust_root_valid': trust_root_valid, 'pape_request': pape_request, }) + def processTrustResult(request): """ Handle the result of a trust decision and respond to the RP @@ -236,7 +246,7 @@ def processTrustResult(request): 'country': 'ES', 'language': 'eu', 'timezone': 'America/New_York', - } + } sreg_req = sreg.SRegRequest.fromOpenIDRequest(openid_request) sreg_resp = sreg.SRegResponse.extractResponse(sreg_req, sreg_data) @@ -248,6 +258,7 @@ def processTrustResult(request): return displayResponse(request, openid_response) + def displayResponse(request, openid_response): """ Display an OpenID response. Errors will be displayed directly to @@ -260,7 +271,7 @@ def displayResponse(request, openid_response): # Encode the response into something that is renderable. try: webresponse = s.encodeResponse(openid_response) - except EncodingError, why: + except EncodingError as why: # If it couldn't be encoded, display an error. text = why.response.encodeToKVForm() return direct_to_template( diff --git a/examples/djopenid/settings.py b/examples/djopenid/settings.py index f2a7c872..1ba3ff44 100644 --- a/examples/djopenid/settings.py +++ b/examples/djopenid/settings.py @@ -6,9 +6,11 @@ try: import openid -except ImportError, e: +except ImportError as e: warnings.warn("Could not import OpenID library. Please consult the djopenid README.") sys.exit(1) +else: + del openid DEBUG = True TEMPLATE_DEBUG = DEBUG @@ -21,7 +23,7 @@ DATABASES = { 'default': { - 'ENGINE': 'django.db.backends.sqlite3', # Add 'postgresql_psycopg2', 'mysql', 'sqlite3' or 'oracle'. + 'ENGINE': 'django.db.backends.sqlite3', # Add 'postgresql_psycopg2', 'mysql', 'sqlite3' or 'oracle'. 'NAME': '/tmp/test.db', # Or path to database file if using sqlite3. 'USER': '', # Not used with sqlite3. 'PASSWORD': '', # Not used with sqlite3. @@ -61,7 +63,7 @@ TEMPLATE_LOADERS = ( 'django.template.loaders.filesystem.Loader', 'django.template.loaders.app_directories.Loader', -# 'django.template.loaders.eggs.load_template_source', + # 'django.template.loaders.eggs.load_template_source', ) MIDDLEWARE_CLASSES = ( diff --git a/examples/djopenid/urls.py b/examples/djopenid/urls.py index d91ee1f1..37833177 100644 --- a/examples/djopenid/urls.py +++ b/examples/djopenid/urls.py @@ -1,4 +1,4 @@ -from django.conf.urls.defaults import * +from django.conf.urls.defaults import include, patterns urlpatterns = patterns( '', diff --git a/examples/djopenid/util.py b/examples/djopenid/util.py index f06e11fb..2847d8e3 100644 --- a/examples/djopenid/util.py +++ b/examples/djopenid/util.py @@ -1,17 +1,12 @@ - """ Utility code for the Django example consumer and server. """ - from urlparse import urljoin -from django import http from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.core.urlresolvers import reverse as reverseURL from django.db import connection -from django.template import loader -from django.template.context import RequestContext from django.views.generic.simple import direct_to_template from openid.store import sqlstore @@ -41,7 +36,7 @@ def getOpenIDStore(filestore_path, table_prefix): The result of this function should be passed to the Consumer constructor as the store parameter. """ - if not settings.DATABASES.get('default', {'ENGINE':None}).get('ENGINE'): + if not settings.DATABASES.get('default', {'ENGINE': None}).get('ENGINE'): return FileOpenIDStore(filestore_path) # Possible side-effect: create a database connection if one isn't @@ -52,27 +47,23 @@ def getOpenIDStore(filestore_path, table_prefix): tablenames = { 'associations_table': table_prefix + 'openid_associations', 'nonces_table': table_prefix + 'openid_nonces', - } + } types = { 'django.db.backends.postgresql': sqlstore.PostgreSQLStore, 'django.db.backends.mysql': sqlstore.MySQLStore, 'django.db.backends.sqlite3': sqlstore.SQLiteStore, - } + } + engine = settings.DATABASES.get('default', {'ENGINE': None}).get('ENGINE') try: - s = types[settings.DATABASES.get('default', {'ENGINE':None}).get('ENGINE')](connection.connection, - **tablenames) + s = types[engine](connection.connection, **tablenames) except KeyError: - raise ImproperlyConfigured, \ - "Database engine %s not supported by OpenID library" % \ - (settings.DATABASES.get('default', {'ENGINE':None}).get('ENGINE'),) + raise ImproperlyConfigured("Database engine %s not supported by OpenID library" % engine) try: s.createTables() - except (SystemExit, KeyboardInterrupt, MemoryError), e: - raise - except: + except Exception: # XXX This is not the Right Way to do this, but because the # underlying database implementation might differ in behavior # at this point, we can't reliably catch the right @@ -85,11 +76,13 @@ def getOpenIDStore(filestore_path, table_prefix): return s + def getViewURL(req, view_name_or_obj, args=None, kwargs=None): relative_url = reverseURL(view_name_or_obj, args=args, kwargs=kwargs) full_path = req.META.get('SCRIPT_NAME', '') + relative_url return urljoin(getBaseURL(req), full_path) + def getBaseURL(req): """ Given a Django web request object, returns the OpenID 'trust root' @@ -101,12 +94,12 @@ def getBaseURL(req): name = req.META['HTTP_HOST'] try: name = name[:name.index(':')] - except: + except Exception: pass try: port = int(req.META['SERVER_PORT']) - except: + except Exception: port = 80 proto = req.META['SERVER_PROTOCOL'] @@ -124,6 +117,7 @@ def getBaseURL(req): url = "%s://%s%s/" % (proto, name, port) return url + def normalDict(request_data): """ Converts a django request MutliValueDict (e.g., request.GET, @@ -135,6 +129,7 @@ def normalDict(request_data): """ return dict((k, v) for k, v in request_data.iteritems()) + def renderXRDS(request, type_uris, endpoint_urls): """Render an XRDS page with the specified type URIs and endpoint URLs in one service block, and return a response with the @@ -142,6 +137,6 @@ def renderXRDS(request, type_uris, endpoint_urls): """ response = direct_to_template( request, 'xrds.xml', - {'type_uris':type_uris, 'endpoint_urls':endpoint_urls,}) + {'type_uris': type_uris, 'endpoint_urls': endpoint_urls}) response['Content-Type'] = YADIS_CONTENT_TYPE return response diff --git a/examples/djopenid/views.py b/examples/djopenid/views.py index 3f08324d..5d7a4e2a 100644 --- a/examples/djopenid/views.py +++ b/examples/djopenid/views.py @@ -12,4 +12,4 @@ def index(request): return direct_to_template( request, 'index.html', - {'consumer_url':consumer_url, 'server_url':server_url}) + {'consumer_url': consumer_url, 'server_url': server_url}) diff --git a/examples/server.py b/examples/server.py index 0b12597e..2da8835c 100644 --- a/examples/server.py +++ b/examples/server.py @@ -30,6 +30,7 @@ def quoteattr(s): distribution.""") sys.exit(1) else: + del openid from openid.consumer import discover from openid.extensions import sreg from openid.server import server @@ -41,6 +42,7 @@ class OpenIDHTTPServer(HTTPServer): http server that contains a reference to an OpenID Server and knows its base URL. """ + def __init__(self, *args, **kwargs): HTTPServer.__init__(self, *args, **kwargs) @@ -63,7 +65,6 @@ def __init__(self, *args, **kwargs): self.user = None BaseHTTPRequestHandler.__init__(self, *args, **kwargs) - def do_GET(self): try: self.parsed_uri = urlparse(self.path) @@ -94,9 +95,7 @@ def do_GET(self): self.send_response(404) self.end_headers() - except (KeyboardInterrupt, SystemExit): - raise - except: + except Exception: self.send_response(500) self.send_header('Content-type', 'text/html') self.end_headers() @@ -124,9 +123,7 @@ def do_POST(self): self.send_response(404) self.end_headers() - except (KeyboardInterrupt, SystemExit): - raise - except: + except Exception: self.send_response(500) self.send_header('Content-type', 'text/html') self.end_headers() @@ -160,7 +157,6 @@ def handleAllow(self, query): self.displayResponse(response) - def setUser(self): cookies = self.headers.get('Cookie') if cookies: @@ -181,7 +177,7 @@ def isAuthorized(self, identity_url, trust_root): def serverEndPoint(self, query): try: request = self.server.openid.decodeRequest(query) - except server.ProtocolError, why: + except server.ProtocolError as why: self.displayResponse(why) return @@ -203,8 +199,8 @@ def addSRegResponse(self, request, response): # and the user should be asked for permission to release # it. sreg_data = { - 'nickname':self.user - } + 'nickname': self.user + } sreg_resp = sreg.SRegResponse.extractResponse(sreg_req, sreg_data) response.addExtension(sreg_resp) @@ -229,7 +225,7 @@ def handleCheckIDRequest(self, request): def displayResponse(self, response): try: webresponse = self.server.openid.encodeResponse(response) - except server.EncodingError, why: + except server.EncodingError as why: text = why.response.encodeToKVForm() self.showErrorPage('
%s
' % cgi.escape(text)) return @@ -287,7 +283,7 @@ def term(url, text): ('http://www.openidenabled.com/', 'An OpenID community Web site, home of this library'), ('http://www.openid.net/', 'the official OpenID Web site'), - ] + ] resource_markup = ''.join([term(url, text) for url, text in resources]) @@ -336,14 +332,14 @@ def showErrorPage(self, error_message): ''' % error_message) def showDecidePage(self, request): - id_url_base = self.server.base_url+'id/' + id_url_base = self.server.base_url + 'id/' # XXX: This may break if there are any synonyms for id_url_base, # such as referring to it by IP address or a CNAME. - assert (request.identity.startswith(id_url_base) or + assert (request.identity.startswith(id_url_base) or request.idSelect()), repr((request.identity, id_url_base)) expected_user = request.identity[len(id_url_base):] - if request.idSelect(): # We are being asked to select an ID + if request.idSelect(): # We are being asked to select an ID msg = '''\

A site has asked for your identity. You may select an identifier by which you would like this site to know you. @@ -355,7 +351,7 @@ def showDecidePage(self, request): fdata = { 'id_url_base': id_url_base, 'trust_root': request.trust_root, - } + } form = '''\

@@ -370,7 +366,7 @@ def showDecidePage(self, request): - '''%fdata + ''' % fdata elif expected_user == self.user: msg = '''\

A new site has asked to confirm your identity. If you @@ -382,7 +378,7 @@ def showDecidePage(self, request): fdata = { 'identity': request.identity, 'trust_root': request.trust_root, - } + } form = '''\

@@ -400,7 +396,7 @@ def showDecidePage(self, request): mdata = { 'expected_user': expected_user, 'user': self.user, - } + } msg = '''\

A site has asked for an identity belonging to %(expected_user)s, but you are logged in as %(user)s. To @@ -412,7 +408,7 @@ def showDecidePage(self, request): 'identity': request.identity, 'trust_root': request.trust_root, 'expected_user': expected_user, - } + } form = '''\

Identity:%(identity)s
@@ -432,9 +428,9 @@ def showDecidePage(self, request): def showIdPage(self, path): link_tag = '' %\ - self.server.base_url - yadis_loc_tag = ''%\ - (self.server.base_url+'yadis/'+path[4:]) + self.server.base_url + yadis_loc_tag = '' %\ + (self.server.base_url + 'yadis/' + path[4:]) disco_tags = link_tag + yadis_loc_tag ident = self.server.base_url + path[1:] @@ -480,8 +476,8 @@ def showYadis(self, user): -"""%(discover.OPENID_2_0_TYPE, discover.OPENID_1_0_TYPE, - endpoint_url, user_url)) +""" % (discover.OPENID_2_0_TYPE, discover.OPENID_1_0_TYPE, + endpoint_url, user_url)) def showServerYadis(self): self.send_response(200) @@ -503,10 +499,10 @@ def showServerYadis(self): -"""%(discover.OPENID_IDP_2_0_TYPE, endpoint_url,)) +""" % (discover.OPENID_IDP_2_0_TYPE, endpoint_url,)) def showMainPage(self): - yadis_tag = ''%\ + yadis_tag = '' %\ (self.server.base_url + 'serveryadis') if self.user: openid_url = self.server.base_url + 'id/' + self.user @@ -521,7 +517,7 @@ def showMainPage(self): order to simulate a standard Web user experience. You are not logged in.

""" - self.showPage(200, 'Main Page', head_extras = yadis_tag, msg='''\ + self.showPage(200, 'Main Page', head_extras=yadis_tag, msg='''\

This is a simple OpenID server implemented using the Python OpenID library.

@@ -557,13 +553,14 @@ def showPage(self, response_code, title, if self.user is None: user_link = 'not logged in.' else: - user_link = 'logged in as %s.
Log out' % \ + user_link = 'logged in as %s.
' \ + 'Log out' % \ (self.user, self.user) body = '' if err is not None: - body += '''\ + body += '''\
%s
@@ -588,7 +585,7 @@ def showPage(self, response_code, title, 'head_extras': head_extras, 'body': body, 'user_link': user_link, - } + } self.send_response(response_code) self.writeUserHeader() @@ -689,6 +686,7 @@ def main(host, port, data_path): print httpserver.base_url httpserver.serve_forever() + if __name__ == '__main__': parser = optparse.OptionParser('Usage:\n %prog [options]') parser.add_option( diff --git a/openid/__init__.py b/openid/__init__.py index 8ecb0339..b172b30c 100644 --- a/openid/__init__.py +++ b/openid/__init__.py @@ -41,7 +41,7 @@ 'store', 'urinorm', 'yadis', - ] +] # Parse the version info try: diff --git a/openid/association.py b/openid/association.py index f9cc91e4..8a52b78f 100644 --- a/openid/association.py +++ b/openid/association.py @@ -30,7 +30,7 @@ 'encrypted_negotiator', 'SessionNegotiator', 'Association', - ] +] import time @@ -40,7 +40,7 @@ all_association_types = [ 'HMAC-SHA1', 'HMAC-SHA256', - ] +] if hasattr(cryptutil, 'hmacSha256'): supported_association_types = list(all_association_types) @@ -50,32 +50,34 @@ ('HMAC-SHA1', 'no-encryption'), ('HMAC-SHA256', 'DH-SHA256'), ('HMAC-SHA256', 'no-encryption'), - ] + ] only_encrypted_association_order = [ ('HMAC-SHA1', 'DH-SHA1'), ('HMAC-SHA256', 'DH-SHA256'), - ] + ] else: supported_association_types = ['HMAC-SHA1'] default_association_order = [ ('HMAC-SHA1', 'DH-SHA1'), ('HMAC-SHA1', 'no-encryption'), - ] + ] only_encrypted_association_order = [ ('HMAC-SHA1', 'DH-SHA1'), - ] + ] + def getSessionTypes(assoc_type): """Return the allowed session types for a given association type""" assoc_to_session = { 'HMAC-SHA1': ['DH-SHA1', 'no-encryption'], 'HMAC-SHA256': ['DH-SHA256', 'no-encryption'], - } + } return assoc_to_session.get(assoc_type, []) + def checkSessionType(assoc_type, session_type): """Check to make sure that this pair of assoc type and session type are allowed""" @@ -84,6 +86,7 @@ def checkSessionType(assoc_type, session_type): 'Session type %r not valid for assocation type %r' % (session_type, assoc_type)) + class SessionNegotiator(object): """A session negotiator controls the allowed and preferred association types and association session types. Both the @@ -166,7 +169,6 @@ def addAllowedType(self, assoc_type, session_type=None): checkSessionType(assoc_type, session_type) self.allowed_types.append((assoc_type, session_type)) - def isAllowed(self, assoc_type, session_type): """Is this combination of association type and session type allowed?""" assoc_good = (assoc_type, session_type) in self.allowed_types @@ -181,9 +183,11 @@ def getAllowedType(self): except IndexError: return (None, None) + default_negotiator = SessionNegotiator(default_association_order) encrypted_negotiator = SessionNegotiator(only_encrypted_association_order) + def getSecretSize(assoc_type): if assoc_type == 'HMAC-SHA1': return 20 @@ -192,6 +196,7 @@ def getSecretSize(assoc_type): else: raise ValueError('Unsupported association type: %r' % (assoc_type,)) + class Association(object): """ This class represents an association between a server and a @@ -247,14 +252,12 @@ class Association(object): 'issued', 'lifetime', 'assoc_type', - ] - + ] _macs = { 'HMAC-SHA1': cryptutil.hmacSha1, 'HMAC-SHA256': cryptutil.hmacSha256, - } - + } def fromExpiresIn(cls, expires_in, handle, secret, assoc_type): """ @@ -378,7 +381,7 @@ def __eq__(self, other): @rtype: C{bool} """ - return type(self) is type(other) and self.__dict__ == other.__dict__ + return type(self) == type(other) and self.__dict__ == other.__dict__ def __ne__(self, other): """ @@ -403,13 +406,13 @@ def serialize(self): @rtype: str """ data = { - 'version':'2', - 'handle':self.handle, - 'secret':oidutil.toBase64(self.secret), - 'issued':str(int(self.issued)), - 'lifetime':str(int(self.lifetime)), - 'assoc_type':self.assoc_type - } + 'version': '2', + 'handle': self.handle, + 'secret': oidutil.toBase64(self.secret), + 'issued': str(int(self.issued)), + 'lifetime': str(int(self.lifetime)), + 'assoc_type': self.assoc_type + } assert len(data) == len(self.assoc_keys) pairs = [] @@ -476,7 +479,6 @@ def sign(self, pairs): return mac(self.secret, kv) - def getMessageSignature(self, message): """Return the signature of a message. @@ -499,8 +501,7 @@ def signMessage(self, message): @return: a new Message object with a signature @rtype: L{openid.message.Message} """ - if (message.hasKey(OPENID_NS, 'sig') or - message.hasKey(OPENID_NS, 'signed')): + if (message.hasKey(OPENID_NS, 'sig') or message.hasKey(OPENID_NS, 'signed')): raise ValueError('Message already has signed list or signature') extant_handle = message.getArg(OPENID_NS, 'assoc_handle') @@ -532,7 +533,6 @@ def checkMessageSignature(self, message): calculated_sig = self.getMessageSignature(message) return cryptutil.const_eq(calculated_sig, message_sig) - def _makePairs(self, message): signed = message.getArg(OPENID_NS, 'signed') if not signed: diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index eaa48472..c811ce05 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -251,7 +251,6 @@ def _httpResponseToMessage(response, server_url): return response_message - class Consumer(object): """An OpenID consumer implementation that performs discovery and does session management. @@ -338,7 +337,7 @@ def begin(self, user_url, anonymous=False): disco = Discovery(self.session, user_url, self.session_key_prefix) try: service = disco.getNextService(self._discover) - except fetchers.HTTPFetchingError, why: + except fetchers.HTTPFetchingError as why: raise DiscoveryFailure( 'Error fetching XRDS document: %s' % (why[0],), None) @@ -374,7 +373,7 @@ def beginWithoutDiscovery(self, service, anonymous=False): try: auth_req.setAnonymous(anonymous) - except ValueError, why: + except ValueError as why: raise ProtocolError(str(why)) return auth_req @@ -414,8 +413,7 @@ def complete(self, query, current_url): except KeyError: pass - if (response.status in ['success', 'cancel'] and - response.identity_url is not None): + if (response.status in ['success', 'cancel'] and response.identity_url is not None): disco = Discovery(self.session, response.identity_url, @@ -448,6 +446,7 @@ def setAssociationPreference(self, association_preferences): """ self.consumer.negotiator = SessionNegotiator(association_preferences) + class DiffieHellmanSHA1ConsumerSession(object): session_type = 'DH-SHA1' hash_func = staticmethod(cryptutil.sha1) @@ -469,7 +468,7 @@ def getRequest(self): args.update({ 'dh_modulus': cryptutil.longToBase64(self.dh.modulus), 'dh_gen': cryptutil.longToBase64(self.dh.generator), - }) + }) return args @@ -481,12 +480,14 @@ def extractSecret(self, response): enc_mac_key = oidutil.fromBase64(enc_mac_key64) return self.dh.xorSecret(dh_server_public, enc_mac_key, self.hash_func) + class DiffieHellmanSHA256ConsumerSession(DiffieHellmanSHA1ConsumerSession): session_type = 'DH-SHA256' hash_func = staticmethod(cryptutil.sha256) secret_size = 32 allowed_assoc_types = ['HMAC-SHA256'] + class PlainTextConsumerSession(object): session_type = 'no-encryption' allowed_assoc_types = ['HMAC-SHA1', 'HMAC-SHA256'] @@ -498,17 +499,21 @@ def extractSecret(self, response): mac_key64 = response.getArg(OPENID_NS, 'mac_key', no_default) return oidutil.fromBase64(mac_key64) + class SetupNeededError(Exception): """Internally-used exception that indicates that an immediate-mode request cancelled.""" + def __init__(self, user_setup_url=None): Exception.__init__(self, user_setup_url) self.user_setup_url = user_setup_url + class ProtocolError(ValueError): """Exception that indicates that a message violated the protocol. It is raised and caught internally to this file.""" + class TypeURIMismatch(ProtocolError): """A protocol error arising from type URIs mismatching """ @@ -525,7 +530,6 @@ def __str__(self): return s - class ServerError(Exception): """Exception that is raised when the server returns a 400 response code to a direct request.""" @@ -546,6 +550,7 @@ def fromMessage(cls, message): fromMessage = classmethod(fromMessage) + class GenericConsumer(object): """This is the implementation of the common logic for OpenID consumers. It is unaware of the application in which it is @@ -573,10 +578,10 @@ class GenericConsumer(object): openid1_return_to_identifier_name = 'openid1_claimed_id' session_types = { - 'DH-SHA1':DiffieHellmanSHA1ConsumerSession, - 'DH-SHA256':DiffieHellmanSHA256ConsumerSession, - 'no-encryption':PlainTextConsumerSession, - } + 'DH-SHA1': DiffieHellmanSHA1ConsumerSession, + 'DH-SHA256': DiffieHellmanSHA256ConsumerSession, + 'no-encryption': PlainTextConsumerSession, + } _discover = staticmethod(discover) @@ -635,12 +640,12 @@ def _complete_setup_needed(self, message, endpoint, _): def _complete_id_res(self, message, endpoint, return_to): try: self._checkSetupNeeded(message) - except SetupNeededError, why: + except SetupNeededError as why: return SetupNeededResponse(endpoint, why.user_setup_url) else: try: return self._doIdRes(message, endpoint, return_to) - except (ProtocolError, DiscoveryFailure), why: + except (ProtocolError, DiscoveryFailure) as why: return FailureResponse(endpoint, why[0]) def _completeInvalid(self, message, endpoint, _): @@ -657,7 +662,7 @@ def _checkReturnTo(self, message, return_to): # message. try: self._verifyReturnToArgs(message.toPostArgs()) - except ProtocolError, why: + except ProtocolError as why: _LOGGER.exception("Verifying return_to arguments: %s", why) return False @@ -721,7 +726,6 @@ def _doIdRes(self, message, endpoint, return_to): "return_to does not match return URL. Expected %r, got %r" % (return_to, message.getArg(OPENID_NS, 'return_to'))) - # Verify discovery information: endpoint = self._verifyDiscoveryResults(message, endpoint) _LOGGER.info("Received id_res response from %s using association %s", @@ -763,11 +767,10 @@ def _idResCheckNonce(self, message, endpoint): try: timestamp, salt = splitNonce(nonce) - except ValueError, why: + except ValueError as why: raise ProtocolError('Malformed nonce: %s' % (why[0],)) - if (self.store is not None and - not self.store.useNonce(server_url, timestamp, salt)): + if (self.store is not None and not self.store.useNonce(server_url, timestamp, salt)): raise ProtocolError('Nonce already used or out of range') def _idResCheckSignature(self, message, server_url): @@ -811,15 +814,12 @@ def _idResCheckForFields(self, message): require_fields = { OPENID2_NS: basic_fields + ['op_endpoint'], OPENID1_NS: basic_fields + ['identity'], - } + } require_sigs = { - OPENID2_NS: basic_sig_fields + ['response_nonce', - 'claimed_id', - 'assoc_handle', - 'op_endpoint',], + OPENID2_NS: basic_sig_fields + ['response_nonce', 'claimed_id', 'assoc_handle', 'op_endpoint'], OPENID1_NS: basic_sig_fields, - } + } for field in require_fields[message.getOpenIDNamespace()]: if not message.hasKey(OPENID_NS, field): @@ -833,7 +833,6 @@ def _idResCheckForFields(self, message): if message.hasKey(OPENID_NS, field) and field not in signed_list: raise ProtocolError('"%s" not signed' % (field,)) - def _verifyReturnToArgs(query): """Verify that the arguments in the return_to URL are present in this response. @@ -883,7 +882,6 @@ def _verifyDiscoveryResults(self, resp_msg, endpoint=None): else: return self._verifyDiscoveryResultsOpenID1(resp_msg, endpoint) - def _verifyDiscoveryResultsOpenID2(self, resp_msg, endpoint): to_match = OpenIDServiceEndpoint() to_match.type_uris = [OPENID_2_0_TYPE] @@ -896,8 +894,7 @@ def _verifyDiscoveryResultsOpenID2(self, resp_msg, endpoint): # claimed_id and identifier must both be present or both # be absent - if (to_match.claimed_id is None and - to_match.local_id is not None): + if (to_match.claimed_id is None and to_match.local_id is not None): raise ProtocolError( 'openid.identity is present without openid.claimed_id') @@ -925,7 +922,7 @@ def _verifyDiscoveryResultsOpenID2(self, resp_msg, endpoint): # case. try: self._verifyDiscoverySingle(endpoint, to_match) - except ProtocolError, e: + except ProtocolError as e: _LOGGER.exception("Error attempting to use stored discovery information: %s", e) _LOGGER.info("Attempting discovery to verify endpoint") endpoint = self._discoverAndVerify( @@ -968,7 +965,7 @@ def _verifyDiscoveryResultsOpenID1(self, resp_msg, endpoint): self._verifyDiscoverySingle(endpoint, to_match) except TypeURIMismatch: self._verifyDiscoverySingle(endpoint, to_match_1_0) - except ProtocolError, e: + except ProtocolError as e: _LOGGER.exception("Error attempting to use stored discovery information: %s", e) _LOGGER.info("Attempting discovery to verify endpoint") else: @@ -1048,7 +1045,6 @@ def _discoverAndVerify(self, claimed_id, to_match_endpoints): return self._verifyDiscoveredServices(claimed_id, services, to_match_endpoints) - def _verifyDiscoveredServices(self, claimed_id, services, to_match_endpoints): """See @L{_discoverAndVerify}""" @@ -1060,7 +1056,7 @@ def _verifyDiscoveredServices(self, claimed_id, services, to_match_endpoints): try: self._verifyDiscoverySingle( endpoint, to_match_endpoint) - except ProtocolError, why: + except ProtocolError as why: failure_messages.append(str(why)) else: # It matches, so discover verification has @@ -1087,7 +1083,7 @@ def _checkAuth(self, message, server_url): return False try: response = self._makeKVPost(request, server_url) - except (fetchers.HTTPFetchingError, ServerError), e: + except (fetchers.HTTPFetchingError, ServerError) as e: _LOGGER.exception('check_authentication failed: %s', e) return False else: @@ -1167,7 +1163,7 @@ def _negotiateAssociation(self, endpoint): try: assoc = self._requestAssociation( endpoint, assoc_type, session_type) - except ServerError, why: + except ServerError as why: supportedTypes = self._extractSupportedAssociationType(why, endpoint, assoc_type) @@ -1179,7 +1175,7 @@ def _negotiateAssociation(self, endpoint): try: assoc = self._requestAssociation( endpoint, assoc_type, session_type) - except ServerError, why: + except ServerError as why: # Do not keep trying, since it rejected the # association type that it told us to use. _LOGGER.error('Server %s refused its suggested association type: session_type=%s, assoc_type=%s', @@ -1201,8 +1197,7 @@ def _extractSupportedAssociationType(self, server_error, endpoint, """ # Any error message whose code is not 'unsupported-type' # should be considered a total failure. - if server_error.error_code != 'unsupported-type' or \ - server_error.message.isOpenID1(): + if server_error.error_code != 'unsupported-type' or server_error.message.isOpenID1(): _LOGGER.error('Server error when requesting an association from %r: %s', endpoint.server_url, server_error.error_text) return None @@ -1227,7 +1222,6 @@ def _extractSupportedAssociationType(self, server_error, endpoint, else: return assoc_type, session_type - def _requestAssociation(self, endpoint, assoc_type, session_type): """Make and process one association request to this endpoint's OP endpoint URL. @@ -1242,16 +1236,16 @@ def _requestAssociation(self, endpoint, assoc_type, session_type): try: response = self._makeKVPost(args, endpoint.server_url) - except fetchers.HTTPFetchingError, why: + except fetchers.HTTPFetchingError as why: _LOGGER.exception('openid.associate request failed: %s', why) return None try: assoc = self._extractAssociation(response, assoc_session) - except KeyError, why: + except KeyError as why: _LOGGER.exception('Missing required parameter in response from %s: %s', endpoint.server_url, why) return None - except ProtocolError, why: + except ProtocolError as why: _LOGGER.exception('Protocol error parsing response from %s: %s', endpoint.server_url, why) return None else: @@ -1287,15 +1281,14 @@ def _createAssociateRequest(self, endpoint, assoc_type, session_type): args = { 'mode': 'associate', 'assoc_type': assoc_type, - } + } if not endpoint.compatibilityMode(): args['ns'] = OPENID2_NS # Leave out the session type if we're in compatibility mode # *and* it's no-encryption. - if (not endpoint.compatibilityMode() or - assoc_session.session_type != 'no-encryption'): + if (not endpoint.compatibilityMode() or assoc_session.session_type != 'no-encryption'): args['session_type'] = assoc_session.session_type args.update(assoc_session.getRequest()) @@ -1372,7 +1365,7 @@ def _extractAssociation(self, assoc_response, assoc_session): OPENID_NS, 'expires_in', no_default) try: expires_in = int(expires_in_str) - except ValueError, why: + except ValueError as why: raise ProtocolError('Invalid expires_in field: %s' % (why[0],)) # OpenID 1 has funny association session behaviour. @@ -1384,8 +1377,7 @@ def _extractAssociation(self, assoc_response, assoc_session): # Session type mismatch if assoc_session.session_type != session_type: - if (assoc_response.isOpenID1() and - session_type == 'no-encryption'): + if (assoc_response.isOpenID1() and session_type == 'no-encryption'): # In OpenID 1, any association request can result in a # 'no-encryption' association response. Setting # assoc_session to a new no-encryption session should @@ -1410,13 +1402,14 @@ def _extractAssociation(self, assoc_response, assoc_session): # type. try: secret = assoc_session.extractSecret(assoc_response) - except ValueError, why: + except ValueError as why: fmt = 'Malformed response for %s session: %s' raise ProtocolError(fmt % (assoc_session.session_type, why[0])) return Association.fromExpiresIn( expires_in, assoc_handle, secret, assoc_type) + class AuthRequest(object): """An object that holds the state necessary for generating an OpenID authentication request. This object holds the association @@ -1550,11 +1543,7 @@ def getMessage(self, realm, return_to=None, immediate=False): realm_key = 'realm' message.updateArgs(OPENID_NS, - { - realm_key:realm, - 'mode':mode, - 'return_to':return_to, - }) + {realm_key: realm, 'mode': mode, 'return_to': return_to}) if not self._anonymous: if self.endpoint.isOPIdentifier(): @@ -1623,8 +1612,7 @@ def redirectURL(self, realm, return_to=None, immediate=False): message = self.getMessage(realm, return_to, immediate) return message.toURL(self.endpoint.server_url) - def formMarkup(self, realm, return_to=None, immediate=False, - form_tag_attrs=None): + def formMarkup(self, realm, return_to=None, immediate=False, form_tag_attrs=None): """Get html for a form to submit this request to the IDP. @param form_tag_attrs: Dictionary of attributes to be added to @@ -1634,11 +1622,9 @@ def formMarkup(self, realm, return_to=None, immediate=False, @type form_tag_attrs: {unicode: unicode} """ message = self.getMessage(realm, return_to, immediate) - return message.toFormMarkup(self.endpoint.server_url, - form_tag_attrs) + return message.toFormMarkup(self.endpoint.server_url, form_tag_attrs) - def htmlMarkup(self, realm, return_to=None, immediate=False, - form_tag_attrs=None): + def htmlMarkup(self, realm, return_to=None, immediate=False, form_tag_attrs=None): """Get an autosubmitting HTML page that submits this request to the IDP. This is just a wrapper for formMarkup. @@ -1646,10 +1632,7 @@ def htmlMarkup(self, realm, return_to=None, immediate=False, @returns: str """ - return oidutil.autoSubmitHTML(self.formMarkup(realm, - return_to, - immediate, - form_tag_attrs)) + return oidutil.autoSubmitHTML(self.formMarkup(realm, return_to, immediate, form_tag_attrs)) def shouldSendRedirect(self): """Should this OpenID authentication request be sent as a HTTP @@ -1659,11 +1642,13 @@ def shouldSendRedirect(self): """ return self.endpoint.compatibilityMode() + FAILURE = 'failure' SUCCESS = 'success' CANCEL = 'cancel' SETUP_NEEDED = 'setup_needed' + class Response(object): status = None @@ -1694,6 +1679,7 @@ def getDisplayIdentifier(self): return self.endpoint.getDisplayIdentifier() return None + class SuccessResponse(Response): """A response with a status of SUCCESS. Indicates that this request is a successful acknowledgement from the OpenID server that the @@ -1854,6 +1840,7 @@ class CancelResponse(Response): def __init__(self, endpoint): self.setEndpoint(endpoint) + class SetupNeededResponse(Response): """A response with a status of SETUP_NEEDED. Indicates that the request was in immediate mode, and the server is unable to diff --git a/openid/consumer/discover.py b/openid/consumer/discover.py index 5764dc5f..f847c63a 100644 --- a/openid/consumer/discover.py +++ b/openid/consumer/discover.py @@ -11,12 +11,12 @@ 'OPENID_IDP_2_0_TYPE', 'OpenIDServiceEndpoint', 'discover', - ] +] import logging import urlparse -from openid import fetchers, urinorm, yadis +from openid import fetchers, urinorm from openid.consumer import html_parse from openid.message import OPENID1_NS as OPENID_1_0_MESSAGE_NS, OPENID2_NS as OPENID_2_0_MESSAGE_NS from openid.yadis import filters, xri, xrires @@ -48,7 +48,7 @@ class OpenIDServiceEndpoint(object): OPENID_2_0_TYPE, OPENID_1_1_TYPE, OPENID_1_0_TYPE, - ] + ] def __init__(self): self.claimed_id = None @@ -56,15 +56,14 @@ def __init__(self): self.type_uris = [] self.local_id = None self.canonicalID = None - self.used_yadis = False # whether this came from an XRDS + self.used_yadis = False # whether this came from an XRDS self.display_identifier = None def usesExtension(self, extension_uri): return extension_uri in self.type_uris def preferredNamespace(self): - if (OPENID_IDP_2_0_TYPE in self.type_uris or - OPENID_2_0_TYPE in self.type_uris): + if (OPENID_IDP_2_0_TYPE in self.type_uris or OPENID_2_0_TYPE in self.type_uris): return OPENID_2_0_MESSAGE_NS else: return OPENID_1_0_MESSAGE_NS @@ -74,10 +73,7 @@ def supportsType(self, type_uri): I consider C{/server} endpoints to implicitly support C{/signon}. """ - return ( - (type_uri in self.type_uris) or - (type_uri == OPENID_2_0_TYPE and self.isOPIdentifier()) - ) + return ((type_uri in self.type_uris) or (type_uri == OPENID_2_0_TYPE and self.isOPIdentifier())) def getDisplayIdentifier(self): """Return the display_identifier if set, else return the claimed_id. @@ -155,7 +151,7 @@ def fromHTML(cls, uri, html): discovery_types = [ (OPENID_2_0_TYPE, 'openid2.provider', 'openid2.local_id'), (OPENID_1_1_TYPE, 'openid.server', 'openid.delegate'), - ] + ] link_attrs = html_parse.parseLinkAttrs(html) services = [] @@ -178,7 +174,6 @@ def fromHTML(cls, uri, html): fromHTML = classmethod(fromHTML) - def fromXRDS(cls, uri, xrds): """Parse the given document as XRDS looking for OpenID services. @@ -192,7 +187,6 @@ def fromXRDS(cls, uri, xrds): fromXRDS = classmethod(fromXRDS) - def fromDiscoveryResult(cls, discoveryResult): """Create endpoints from a DiscoveryResult. @@ -213,7 +207,6 @@ def fromDiscoveryResult(cls, discoveryResult): fromDiscoveryResult = classmethod(fromDiscoveryResult) - def fromOPEndpointURL(cls, op_endpoint_url): """Construct an OP-Identifier OpenIDServiceEndpoint object for a given OP Endpoint URL @@ -228,7 +221,6 @@ def fromOPEndpointURL(cls, op_endpoint_url): fromOPEndpointURL = classmethod(fromOPEndpointURL) - def __str__(self): return ("<%s.%s " "server_url=%r " @@ -237,7 +229,7 @@ def __str__(self): "canonicalID=%r " "used_yadis=%s " ">" - % (self.__class__.__module__, self.__class__.__name__, + % (self.__class__.__module__, self.__class__.__name__, self.server_url, self.claimed_id, self.local_id, @@ -245,7 +237,6 @@ def __str__(self): self.used_yadis)) - def findOPLocalIdentifier(service_element, type_uris): """Find the OP-Local Identifier for this xrd:Service element. @@ -275,8 +266,7 @@ def findOPLocalIdentifier(service_element, type_uris): # Build the list of tags that could contain the OP-Local Identifier local_id_tags = [] - if (OPENID_1_1_TYPE in type_uris or - OPENID_1_0_TYPE in type_uris): + if (OPENID_1_1_TYPE in type_uris or OPENID_1_0_TYPE in type_uris): local_id_tags.append(nsTag(OPENID_1_0_NS, 'Delegate')) if OPENID_2_0_TYPE in type_uris: @@ -296,22 +286,25 @@ def findOPLocalIdentifier(service_element, type_uris): return local_id + def normalizeURL(url): """Normalize a URL, converting normalization failures to DiscoveryFailure""" try: normalized = urinorm.urinorm(url) - except ValueError, why: + except ValueError as why: raise DiscoveryFailure('Normalizing identifier: %s' % (why[0],), None) else: return urlparse.urldefrag(normalized)[0] + def normalizeXRI(xri): """Normalize an XRI, stripping its scheme if present""" if xri.startswith("xri://"): xri = xri[6:] return xri + def arrangeByType(service_list, preferred_types): """Rearrange service_list in a new list so services are ordered by types listed in preferred_types. Return the new list.""" @@ -333,9 +326,7 @@ def bestMatchingService(service): # Build a list with the service elements in tuples whose # comparison will prefer the one with the best matching service - prio_services = [(bestMatchingService(s), orig_index, s) - for (orig_index, s) in enumerate(service_list)] - prio_services.sort() + prio_services = sorted((bestMatchingService(s), orig_index, s) for (orig_index, s) in enumerate(service_list)) # Now that the services are sorted by priority, remove the sort # keys from the list. @@ -344,6 +335,7 @@ def bestMatchingService(service): return prio_services + def getOPOrUserServices(openid_services): """Extract OP Identifier services. If none found, return the rest, sorted with most preferred first according to @@ -360,6 +352,7 @@ def getOPOrUserServices(openid_services): return op_services or openid_services + def discoverYadis(uri): """Discover OpenID services for a URI. Tries Yadis and falls back on old-style discovery if Yadis fails. @@ -401,6 +394,7 @@ def discoverYadis(uri): return (yadis_url, getOPOrUserServices(openid_services)) + def discoverXRI(iname): endpoints = [] iname = normalizeXRI(iname) @@ -440,6 +434,7 @@ def discoverNoYadis(uri): claimed_id, http_resp.body) return claimed_id, openid_services + def discoverURI(uri): parsed = urlparse.urlparse(uri) if parsed[0] and parsed[1]: @@ -453,6 +448,7 @@ def discoverURI(uri): claimed_id = normalizeURL(claimed_id) return claimed_id, openid_services + def discover(identifier): if xri.identifierScheme(identifier) == "XRI": return discoverXRI(identifier) diff --git a/openid/consumer/html_parse.py b/openid/consumer/html_parse.py index 880dfda6..14ff8cc2 100644 --- a/openid/consumer/html_parse.py +++ b/openid/consumer/html_parse.py @@ -70,12 +70,17 @@ __all__ = ['parseLinkAttrs'] import re - -flags = ( re.DOTALL # Match newlines with '.' - | re.IGNORECASE - | re.VERBOSE # Allow comments and whitespace in patterns - | re.UNICODE # Make \b respect Unicode word boundaries - ) +from functools import partial + +flags = ( + # Match newlines with '.' + re.DOTALL | + re.IGNORECASE | + # Allow comments and whitespace in patterns + re.VERBOSE | + # Make \b respect Unicode word boundaries + re.UNICODE +) # Stuff to remove before we start looking for tags removed_re = re.compile(r''' @@ -123,6 +128,7 @@ ) ''' + def tagMatcher(tag_name, *close_tags): if close_tags: options = '|'.join((tag_name,) + close_tags) @@ -133,6 +139,7 @@ def tagMatcher(tag_name, *close_tags): expr = tag_expr % locals() return re.compile(expr, flags) + # Must contain at least an open html and an open head tag html_find = tagMatcher('html') head_find = tagMatcher('head', 'body') @@ -160,17 +167,20 @@ def tagMatcher(tag_name, *close_tags): # Entity replacement: replacements = { - 'amp':'&', - 'lt':'<', - 'gt':'>', - 'quot':'"', - } + 'amp': '&', + 'lt': '<', + 'gt': '>', + 'quot': '"', +} ent_replace = re.compile(r'&(%s);' % '|'.join(replacements.keys())) + + def replaceEnt(mo): "Replace the entities that are specified by OpenID" return replacements.get(mo.group(1), mo.group()) + def parseLinkAttrs(html): """Find all link tags in a string representing a HTML document and return a list of their attributes. @@ -214,6 +224,7 @@ def parseLinkAttrs(html): return matches + def relMatches(rel_attr, target_rel): """Does this target_rel appear in the rel_str?""" # XXX: TESTME @@ -225,19 +236,22 @@ def relMatches(rel_attr, target_rel): return 0 + def linkHasRel(link_attrs, target_rel): """Does this link have target_rel as a relationship?""" # XXX: TESTME rel_attr = link_attrs.get('rel') return rel_attr and relMatches(rel_attr, target_rel) + def findLinksRel(link_attrs_list, target_rel): """Filter the list of link attributes on whether it has target_rel as a relationship.""" # XXX: TESTME - matchesTarget = lambda attrs: linkHasRel(attrs, target_rel) + matchesTarget = partial(linkHasRel, target_rel=target_rel) return filter(matchesTarget, link_attrs_list) + def findFirstHref(link_attrs_list, target_rel): """Return the value of the href attribute for the first link tag in the list that has target_rel as a relationship.""" diff --git a/openid/cryptutil.py b/openid/cryptutil.py index 769aa6c5..27ff7965 100644 --- a/openid/cryptutil.py +++ b/openid/cryptutil.py @@ -21,7 +21,7 @@ 'randrange', 'sha1', 'sha256', - ] +] import hashlib import hmac @@ -36,18 +36,23 @@ def __init__(self, hash_constructor): self.new = hash_constructor self.digest_size = hash_constructor().digest_size + sha1_module = HashContainer(hashlib.sha1) sha256_module = HashContainer(hashlib.sha256) + def hmacSha1(key, text): return hmac.new(key, text, sha1_module).digest() + def sha1(s): return sha1_module.new(s).digest() + def hmacSha256(key, text): return hmac.new(key, text, sha256_module).digest() + def sha256(s): return sha256_module.new(s).digest() @@ -57,22 +62,22 @@ def sha256(s): except ImportError: import pickle - def longToBinary(l): - if l == 0: + def longToBinary(value): + if value == 0: return '\x00' - return ''.join(reversed(pickle.encode_long(l))) + return ''.join(reversed(pickle.encode_long(value))) def binaryToLong(s): return pickle.decode_long(''.join(reversed(s))) else: # We have pycrypto - def longToBinary(l): - if l < 0: + def longToBinary(value): + if value < 0: raise ValueError('This function only supports positive integers') - bytes = long_to_bytes(l) + bytes = long_to_bytes(value) if ord(bytes[0]) > 127: return '\x00' + bytes else: @@ -112,6 +117,7 @@ def getBytes(n): return ''.join(bytes) else: _pool = RandomPool() + def getBytes(n, pool=_pool): if pool.entropy < n: pool.randomize() @@ -125,9 +131,9 @@ def getBytes(n, pool=_pool): # numbers larger than sys.maxint for randrange. For simplicity, # use this implementation for any Python that does not have # random.SystemRandom - from math import log, ceil _duplicate_cache = {} + def randrange(start, stop=None, step=1): if stop is None: stop = start @@ -154,7 +160,7 @@ def randrange(start, stop=None, step=1): _duplicate_cache[r] = (duplicate, nbytes) - while 1: + while True: bytes = '\x00' + getBytes(nbytes) n = binaryToLong(bytes) # Keep looping if this value is in the low duplicated range @@ -163,12 +169,15 @@ def randrange(start, stop=None, step=1): return start + (n % r) * step + def longToBase64(l): return toBase64(longToBinary(l)) + def base64ToLong(s): return binaryToLong(fromBase64(s)) + def randomString(length, chrs=None): """Produce a string of length random bytes, chosen from chrs.""" if chrs is None: @@ -177,6 +186,7 @@ def randomString(length, chrs=None): n = len(chrs) return ''.join([chrs[randrange(n)] for _ in xrange(length)]) + def const_eq(s1, s2): if len(s1) != len(s2): return False diff --git a/openid/dh.py b/openid/dh.py index 3478240b..b0400b9e 100644 --- a/openid/dh.py +++ b/openid/dh.py @@ -1,15 +1,23 @@ -from openid import cryptutil, oidutil +from openid import cryptutil + + +def _xor(a_b): + a, b = a_b + return chr(ord(a) ^ ord(b)) def strxor(x, y): if len(x) != len(y): raise ValueError('Inputs to strxor must have the same length') - xor = lambda (a, b): chr(ord(a) ^ ord(b)) - return "".join(map(xor, zip(x, y))) + return "".join(map(_xor, zip(x, y))) + class DiffieHellman(object): - DEFAULT_MOD = 155172898181473697471232257763715539915724801966915404479707795314057629378541917580651227423698188993727816152646631438561595825688188889951272158842675419950341258706556549803580104870537681476726513255747040765857479291291572334510643245094715007229621094194349783925984760375594985848253359305585439638443L + DEFAULT_MOD = int('155172898181473697471232257763715539915724801966915404479707795314057629378541917580651227423698' + '188993727816152646631438561595825688188889951272158842675419950341258706556549803580104870537681' + '476726513255747040765857479291291572334510643245094715007229621094194349783925984760375594985848' + '253359305585439638443') DEFAULT_GEN = 2 diff --git a/openid/extension.py b/openid/extension.py index 6366f03d..55e129b5 100644 --- a/openid/extension.py +++ b/openid/extension.py @@ -1,3 +1,5 @@ +import warnings + from openid import message as message_module diff --git a/openid/extensions/ax.py b/openid/extensions/ax.py index 6b21812b..c8fac3f4 100644 --- a/openid/extensions/ax.py +++ b/openid/extensions/ax.py @@ -5,12 +5,12 @@ """ __all__ = [ - 'AttributeRequest', + 'AttrInfo', 'FetchRequest', 'FetchResponse', 'StoreRequest', 'StoreResponse', - ] +] from openid import extension from openid.message import OPENID_NS, NamespaceMap @@ -24,6 +24,7 @@ # completeness. MINIMUM_SUPPORTED_ALIAS_LENGTH = 32 + def checkAlias(alias): """ Check an alias for invalid characters; raise AXError if any are @@ -60,11 +61,6 @@ class AXMessage(extension.Extension): be overridden in subclasses. """ - # This class is abstract, so it's OK that it doesn't override the - # abstract method in Extension: - # - #pylint:disable-msg=W0223 - ns_alias = 'ax' mode = None ns_uri = 'http://openid.net/srv/ax/1.0' @@ -90,7 +86,7 @@ def _newArgs(self): basic information that must be in every attribute exchange message. """ - return {'mode':self.mode} + return {'mode': self.mode} class AttrInfo(object): @@ -122,11 +118,6 @@ class AttrInfo(object): @type alias: str or NoneType """ - # It's OK that this class doesn't have public methods (it's just a - # holder for a bunch of attributes): - # - #pylint:disable-msg=R0903 - def __init__(self, type_uri, count=1, required=False, alias=None): self.required = required self.count = count @@ -146,6 +137,7 @@ def wantsUnlimitedValues(self): """ return self.count == UNLIMITED_VALUES + def toTypeURIs(namespace_map, alias_list_s): """Given a namespace mapping and a string containing a comma-separated list of namespace aliases, return a list of type @@ -304,7 +296,7 @@ def fromOpenIDRequest(cls, openid_request): self = cls() try: self.parseExtensionArgs(ax_args) - except NotAXMessage, err: + except NotAXMessage: return None if self.update_url: @@ -413,11 +405,6 @@ class AXKeyValueMessage(AXMessage): fetch_response and store_request. """ - # This class is abstract, so it's OK that it doesn't override the - # abstract method in Extension: - # - #pylint:disable-msg=W0223 - def __init__(self): AXMessage.__init__(self) self.data = {} @@ -652,8 +639,7 @@ def getExtensionArgs(self): values = [] zero_value_types.append(attr_info) - if (attr_info.count != UNLIMITED_VALUES) and \ - (attr_info.count < len(values)): + if (attr_info.count != UNLIMITED_VALUES) and (attr_info.count < len(values)): raise AXError( 'More than the number of requested values were ' 'specified for %r' % (attr_info.type_uri,)) @@ -671,8 +657,7 @@ def getExtensionArgs(self): kv_args['type.' + alias] = attr_info.type_uri kv_args['count.' + alias] = '0' - update_url = ((self.request and self.request.update_url) - or self.update_url) + update_url = ((self.request and self.request.update_url) or self.update_url) if update_url: ax_args['update_url'] = update_url @@ -709,7 +694,7 @@ def fromSuccessResponse(cls, success_response, signed=True): try: self.parseExtensionArgs(ax_args) - except NotAXMessage, err: + except NotAXMessage: return None else: return self @@ -762,7 +747,7 @@ def fromOpenIDRequest(cls, openid_request): self = cls() try: self.parseExtensionArgs(ax_args) - except NotAXMessage, err: + except NotAXMessage: return None return self @@ -782,8 +767,7 @@ def __init__(self, succeeded=True, error_message=None): AXMessage.__init__(self) if succeeded and error_message is not None: - raise AXError('An error message may only be included in a ' - 'failing fetch response') + raise AXError('An error message may only be included in a failing fetch response') if succeeded: self.mode = self.SUCCESS_MODE else: @@ -826,7 +810,7 @@ def fromSuccessResponse(cls, success_response, signed=True): try: self.parseExtensionArgs(ax_args) - except NotAXMessage, err: + except NotAXMessage: return None else: return self diff --git a/openid/extensions/draft/pape2.py b/openid/extensions/draft/pape2.py index b800ce2b..f9b84c84 100644 --- a/openid/extensions/draft/pape2.py +++ b/openid/extensions/draft/pape2.py @@ -13,7 +13,7 @@ 'AUTH_PHISHING_RESISTANT', 'AUTH_MULTI_FACTOR', 'AUTH_MULTI_FACTOR_PHYSICAL', - ] +] import re @@ -30,6 +30,7 @@ TIME_VALIDATOR = re.compile('^\d\d\d\d-\d\d-\d\dT\d\d:\d\d:\d\dZ$') + class Request(Extension): """A Provider Authentication Policy request, sent from a relying party to a provider @@ -75,8 +76,8 @@ def getExtensionArgs(self): """@see: C{L{Extension.getExtensionArgs}} """ ns_args = { - 'preferred_auth_policies':' '.join(self.preferred_auth_policies) - } + 'preferred_auth_policies': ' '.join(self.preferred_auth_policies) + } if self.max_auth_age is not None: ns_args['max_auth_age'] = str(self.max_auth_age) @@ -148,6 +149,7 @@ def preferredTypes(self, supported_types): return filter(self.preferred_auth_policies.__contains__, supported_types) + Request.ns_uri = ns_uri @@ -254,12 +256,12 @@ def getExtensionArgs(self): """ if len(self.auth_policies) == 0: ns_args = { - 'auth_policies':'none', + 'auth_policies': 'none', } else: ns_args = { - 'auth_policies':' '.join(self.auth_policies), - } + 'auth_policies': ' '.join(self.auth_policies), + } if self.nist_auth_level is not None: if self.nist_auth_level not in range(0, 5): @@ -275,4 +277,5 @@ def getExtensionArgs(self): return ns_args + Response.ns_uri = ns_uri diff --git a/openid/extensions/draft/pape5.py b/openid/extensions/draft/pape5.py index e1468736..6d0b1ddf 100644 --- a/openid/extensions/draft/pape5.py +++ b/openid/extensions/draft/pape5.py @@ -15,7 +15,7 @@ 'AUTH_MULTI_FACTOR_PHYSICAL', 'LEVELS_NIST', 'LEVELS_JISA', - ] +] import re import warnings @@ -38,11 +38,12 @@ LEVELS_NIST = 'http://csrc.nist.gov/publications/nistpubs/800-63/SP800-63V1_0_2.pdf' LEVELS_JISA = 'http://www.jisa.or.jp/spec/auth_level.html' + class PAPEExtension(Extension): _default_auth_level_aliases = { 'nist': LEVELS_NIST, 'jisa': LEVELS_JISA, - } + } def __init__(self): self.auth_level_aliases = self._default_auth_level_aliases.copy() @@ -90,6 +91,7 @@ def _getAlias(self, auth_level_uri): raise KeyError(auth_level_uri) + class Request(PAPEExtension): """A Provider Authentication Policy request, sent from a relying party to a provider @@ -152,8 +154,8 @@ def getExtensionArgs(self): """@see: C{L{Extension.getExtensionArgs}} """ ns_args = { - 'preferred_auth_policies':' '.join(self.preferred_auth_policies), - } + 'preferred_auth_policies': ' '.join(self.preferred_auth_policies), + } if self.max_auth_age is not None: ns_args['max_auth_age'] = str(self.max_auth_age) @@ -266,6 +268,7 @@ def preferredTypes(self, supported_types): return filter(self.preferred_auth_policies.__contains__, supported_types) + Request.ns_uri = ns_uri @@ -455,8 +458,8 @@ def getExtensionArgs(self): } else: ns_args = { - 'auth_policies':' '.join(self.auth_policies), - } + 'auth_policies': ' '.join(self.auth_policies), + } for level_type, level in self.auth_levels.iteritems(): alias = self._getAlias(level_type) @@ -471,4 +474,5 @@ def getExtensionArgs(self): return ns_args + Response.ns_uri = ns_uri diff --git a/openid/extensions/sreg.py b/openid/extensions/sreg.py index e147cf16..786aeeaf 100644 --- a/openid/extensions/sreg.py +++ b/openid/extensions/sreg.py @@ -48,22 +48,23 @@ 'ns_uri_1_0', 'ns_uri_1_1', 'supportsSReg', - ] +] _LOGGER = logging.getLogger(__name__) # The data fields that are listed in the sreg spec data_fields = { - 'fullname':'Full Name', - 'nickname':'Nickname', - 'dob':'Date of Birth', - 'email':'E-mail Address', - 'gender':'Gender', - 'postcode':'Postal Code', - 'country':'Country', - 'language':'Language', - 'timezone':'Time Zone', - } + 'fullname': 'Full Name', + 'nickname': 'Nickname', + 'dob': 'Date of Birth', + 'email': 'E-mail Address', + 'gender': 'Gender', + 'postcode': 'Postal Code', + 'country': 'Country', + 'language': 'Language', + 'timezone': 'Time Zone', +} + def checkFieldName(field_name): """Check to see that the given value is a valid simple @@ -76,6 +77,7 @@ def checkFieldName(field_name): raise ValueError('%r is not a defined simple registration field' % (field_name,)) + # URI used in the wild for Yadis documents advertising simple # registration support ns_uri_1_0 = 'http://openid.net/sreg/1.0' @@ -90,9 +92,10 @@ def checkFieldName(field_name): try: registerNamespaceAlias(ns_uri_1_1, 'sreg') -except NamespaceAliasRegistrationError, e: +except NamespaceAliasRegistrationError as e: _LOGGER.exception('registerNamespaceAlias(%r, %r) failed: %s', ns_uri_1_1, 'sreg', e) + def supportsSReg(endpoint): """Does the given endpoint advertise support for simple registration? @@ -106,6 +109,7 @@ def supportsSReg(endpoint): return (endpoint.usesExtension(ns_uri_1_1) or endpoint.usesExtension(ns_uri_1_0)) + class SRegNamespaceError(ValueError): """The simple registration namespace was not found and could not be created using the expected name (there's another extension @@ -120,6 +124,7 @@ class SRegNamespaceError(ValueError): the message that is being processed. """ + def getSRegNS(message): """Extract the simple registration namespace URI from the given OpenID message. Handles OpenID 1 and 2, as well as both sreg @@ -151,14 +156,13 @@ def getSRegNS(message): sreg_ns_uri = ns_uri_1_1 try: message.namespaces.addAlias(ns_uri_1_1, 'sreg') - except KeyError, why: + except KeyError as why: # An alias for the string 'sreg' already exists, but it's # defined for something other than simple registration raise SRegNamespaceError(why[0]) - # we know that sreg_ns_uri defined, because it's defined in the - # else clause of the loop as well, so disable the warning - return sreg_ns_uri #pylint:disable-msg=W0631 + return sreg_ns_uri + class SRegRequest(Extension): """An object to hold the state of a simple registration request. @@ -368,6 +372,7 @@ def getExtensionArgs(self): return args + class SRegResponse(Extension): """Represents the data returned in a simple registration response inside of an OpenID C{id_res} response. This object will be diff --git a/openid/fetchers.py b/openid/fetchers.py index b30f8954..750b5f55 100644 --- a/openid/fetchers.py +++ b/openid/fetchers.py @@ -32,6 +32,7 @@ USER_AGENT = "python-openid/%s (%s)" % (openid.__version__, sys.platform) MAX_RESPONSE_KB = 1024 + def fetch(url, body=None, headers=None): """Invoke the fetch method on the default fetcher. Most users should need only this method. @@ -41,6 +42,7 @@ def fetch(url, body=None, headers=None): fetcher = getDefaultFetcher() return fetcher.fetch(url, body, headers) + def createHTTPFetcher(): """Create a default HTTP fetcher instance @@ -52,11 +54,13 @@ def createHTTPFetcher(): return fetcher + # Contains the currently set HTTP fetcher. If it is set to None, the # library will call createHTTPFetcher() to set it. Do not access this # variable outside of this module. _default_fetcher = None + def getDefaultFetcher(): """Return the default fetcher instance if no fetcher has been set, it will create a default fetcher. @@ -71,6 +75,7 @@ def getDefaultFetcher(): return _default_fetcher + def setDefaultFetcher(fetcher, wrap_exceptions=True): """Set the default fetcher @@ -91,6 +96,7 @@ def setDefaultFetcher(fetcher, wrap_exceptions=True): else: _default_fetcher = ExceptionWrappingFetcher(fetcher) + def usingCurl(): """Whether the currently set HTTP fetcher is a Curl HTTP fetcher.""" fetcher = getDefaultFetcher() @@ -98,6 +104,7 @@ def usingCurl(): fetcher = fetcher.fetcher return isinstance(fetcher, CurlHTTPFetcher) + class HTTPResponse(object): """XXX document attributes""" headers = None @@ -116,6 +123,7 @@ def __repr__(self): self.status, self.final_url) + class HTTPFetcher(object): """ This class is the interface for openid HTTP fetchers. This @@ -145,19 +153,23 @@ def fetch(self, url, body=None, headers=None): """ raise NotImplementedError + def _allowedURL(url): return url.startswith('http://') or url.startswith('https://') + class HTTPFetchingError(Exception): """Exception that is wrapped around all exceptions that are raised by the underlying fetcher when using the ExceptionWrappingFetcher @ivar why: The exception that caused this exception """ + def __init__(self, why=None): Exception.__init__(self, why) self.why = why + class ExceptionWrappingFetcher(HTTPFetcher): """Fetcher wrapper which wraps all exceptions to `HTTPFetchingError`.""" @@ -175,6 +187,7 @@ def fetch(self, *args, **kwargs): raise HTTPFetchingError(why=exc_inst) + class Urllib2Fetcher(HTTPFetcher): """An C{L{HTTPFetcher}} that uses urllib2. """ @@ -201,7 +214,7 @@ def fetch(self, url, body=None, headers=None): return self._makeResponse(f) finally: f.close() - except urllib2.HTTPError, why: + except urllib2.HTTPError as why: try: return self._makeResponse(why) finally: @@ -220,6 +233,7 @@ def _makeResponse(self, urllib2_response): return resp + class HTTPError(HTTPFetchingError): """ This exception is raised by the C{L{CurlHTTPFetcher}} when it @@ -228,12 +242,14 @@ class HTTPError(HTTPFetchingError): pass # XXX: define what we mean by paranoid, and make sure it is. + + class CurlHTTPFetcher(HTTPFetcher): """ An C{L{HTTPFetcher}} that uses pycurl for fetching. See U{http://pycurl.sourceforge.net/}. """ - ALLOWED_TIME = 20 # seconds + ALLOWED_TIME = 20 # seconds def __init__(self): HTTPFetcher.__init__(self) @@ -244,7 +260,7 @@ def _parseHeaders(self, header_file): header_file.seek(0) # Remove the status line from the beginning of the input - unused_http_status_line = header_file.readline().lower () + unused_http_status_line = header_file.readline().lower() if unused_http_status_line.startswith('http/1.1 100 '): unused_http_status_line = header_file.readline() unused_http_status_line = header_file.readline() @@ -309,8 +325,9 @@ def fetch(self, url, body=None, headers=None): raise HTTPError("Fetching URL not allowed: %r" % (url,)) data = cStringIO.StringIO() + def write_data(chunk): - if data.tell() > 1024*MAX_RESPONSE_KB: + if data.tell() > 1024 * MAX_RESPONSE_KB: return 0 else: return data.write(chunk) @@ -350,6 +367,7 @@ def write_data(chunk): finally: c.close() + class HTTPLib2Fetcher(HTTPFetcher): """A fetcher that uses C{httplib2} for performing HTTP requests. This implementation supports HTTP caching. @@ -419,4 +437,4 @@ def fetch(self, url, body=None, headers=None): final_url=final_url, headers=dict(httplib2_response.items()), status=httplib2_response.status, - ) + ) diff --git a/openid/kvform.py b/openid/kvform.py index 8252d91a..e0e91a0d 100644 --- a/openid/kvform.py +++ b/openid/kvform.py @@ -9,6 +9,7 @@ class KVFormError(ValueError): pass + def seqToKV(seq, strict=False): """Represent a sequence of pairs of strings as newline-terminated key:value pairs. The pairs are generated in the order given. @@ -62,6 +63,7 @@ def err(msg): return ''.join(lines).encode('UTF8') + def kvToSeq(data, strict=False): """ @@ -116,10 +118,11 @@ def err(msg): return pairs + def dictToKV(d): - seq = d.items() - seq.sort() + seq = sorted(d.items()) return seqToKV(seq) + def kvToDict(s): return dict(kvToSeq(s)) diff --git a/openid/message.py b/openid/message.py index 92706d93..9c487d60 100644 --- a/openid/message.py +++ b/openid/message.py @@ -55,23 +55,27 @@ 'dh_consumer_public', 'claimed_id', 'identity', 'realm', 'invalidate_handle', 'op_endpoint', 'response_nonce', 'sig', 'assoc_handle', 'trust_root', 'openid', - ] +] + class UndefinedOpenIDNamespace(ValueError): """Raised if the generic OpenID namespace is accessed when there is no OpenID namespace set for this message.""" + class InvalidOpenIDNamespace(ValueError): """Raised if openid.ns is not a recognized value. For recognized values, see L{Message.allowed_openid_namespaces} """ + def __str__(self): s = "Invalid OpenID Namespace" if self.args: s += " %r" % (self.args[0],) return s + class InvalidNamespace(KeyError): """ Raised if there is problem with other namespaces than OpenID namespace @@ -86,12 +90,14 @@ class InvalidNamespace(KeyError): # registerNamespaceAlias. registered_aliases = {} + class NamespaceAliasRegistrationError(Exception): """ Raised when an alias or namespace URI has already been registered. """ pass + def registerNamespaceAlias(namespace_uri, alias): """ Registers a (namespace URI, alias) mapping in a global namespace @@ -106,15 +112,14 @@ def registerNamespaceAlias(namespace_uri, alias): return if namespace_uri in registered_aliases.values(): - raise NamespaceAliasRegistrationError, \ - 'Namespace uri %r already registered' % (namespace_uri,) + raise NamespaceAliasRegistrationError('Namespace uri %r already registered' % (namespace_uri,)) if alias in registered_aliases: - raise NamespaceAliasRegistrationError, \ - 'Alias %r already registered' % (alias,) + raise NamespaceAliasRegistrationError('Alias %r already registered' % (alias,)) registered_aliases[alias] = namespace_uri + class Message(object): """ In the implementation of this object, None represents the global @@ -158,7 +163,6 @@ def fromPostArgs(cls, args): raise TypeError("query dict must have one value for each key, " "not lists of values. Query is %r" % (args,)) - try: prefix, rest = key.split('.', 1) except ValueError: @@ -348,7 +352,7 @@ def toFormMarkup(self, action_url, form_tag_attrs=None, form.append(ElementTree.Element(u'input', attrs)) submit = ElementTree.Element(u'input', - {u'type':'submit', u'value':oidutil.toUnicode(submit_text)}) + {u'type': 'submit', u'value': oidutil.toUnicode(submit_text)}) form.append(submit) return ElementTree.tostring(form, encoding='utf-8') @@ -367,8 +371,7 @@ def toKVForm(self): def toURLEncoded(self): """Generate an x-www-urlencoded string""" - args = self.toPostArgs().items() - args.sort() + args = sorted(self.toPostArgs().items()) return urllib.urlencode(args) def _fixNS(self, namespace): @@ -464,7 +467,7 @@ def getArgs(self, namespace): for ((pair_ns, ns_key), value) in self.args.iteritems() if pair_ns == namespace - ]) + ]) def updateArgs(self, namespace, updates): """Set multiple key/value pairs in one call @@ -497,11 +500,9 @@ def __repr__(self): def __eq__(self, other): return self.args == other.args - def __ne__(self, other): return not (self == other) - def getAliasedArg(self, aliased_key, default=None): if aliased_key == 'ns': return self.getOpenIDNamespace() @@ -530,9 +531,11 @@ def getAliasedArg(self, aliased_key, default=None): return self.getArg(ns, key, default) + class NamespaceMap(object): """Maintains a bijective map between namespace uris and aliases. """ + def __init__(self): self.alias_to_namespace = {} self.namespace_to_alias = {} @@ -564,8 +567,7 @@ def addAlias(self, namespace_uri, desired_alias, implicit=False): """ # Check that desired_alias is not an openid protocol field as # per the spec. - assert desired_alias not in OPENID_PROTOCOL_FIELDS, \ - "%r is not an allowed namespace alias" % (desired_alias,) + assert desired_alias not in OPENID_PROTOCOL_FIELDS, "%r is not an allowed namespace alias" % (desired_alias,) # Check that desired_alias does not contain a period as per # the spec. @@ -576,8 +578,7 @@ def addAlias(self, namespace_uri, desired_alias, implicit=False): # Check that there is not a namespace already defined for # the desired alias current_namespace_uri = self.alias_to_namespace.get(desired_alias) - if (current_namespace_uri is not None - and current_namespace_uri != namespace_uri): + if (current_namespace_uri is not None and current_namespace_uri != namespace_uri): fmt = ('Cannot map %r to alias %r. ' '%r is already mapped to alias %r') diff --git a/openid/oidutil.py b/openid/oidutil.py index a92b453a..13954b76 100644 --- a/openid/oidutil.py +++ b/openid/oidutil.py @@ -9,8 +9,6 @@ import binascii import logging -import sys -import urlparse from urllib import urlencode _LOGGER = logging.getLogger(__name__) @@ -21,7 +19,8 @@ 'xml.etree.ElementTree', 'cElementTree', 'elementtree.ElementTree', - ] +] + def toUnicode(value): """Returns the given argument as a unicode object. @@ -35,6 +34,7 @@ def toUnicode(value): return value.decode('utf-8') return unicode(value) + def autoSubmitHTML(form, title='OpenID transaction in progress'): return """ @@ -53,6 +53,7 @@ def autoSubmitHTML(form, title='OpenID transaction in progress'): """ % (title, form) + def importElementTree(module_names=None): """Find a working ElementTree implementation, trying the standard places that such a thing might show up. @@ -76,9 +77,7 @@ def importElementTree(module_names=None): # Make sure it can actually parse XML try: ElementTree.XML('') - except (SystemExit, MemoryError, AssertionError): - raise - except: + except Exception: logging.exception('Not using ElementTree library %r because it failed to parse a trivial document: %s', mod_name) else: @@ -89,6 +88,7 @@ def importElementTree(module_names=None): 'Tried importing %r' % (module_names,) ) + def log(message, level=0): """Handle a log message from the OpenID library. @@ -109,6 +109,7 @@ def log(message, level=0): logging.error("This is a legacy log message, please use the logging module. Message: %s", message) + def appendArgs(url, args): """Append query arguments to a HTTP(s) URL. If the URL already has query arguemtns, these arguments will be added, and the existing @@ -129,8 +130,7 @@ def appendArgs(url, args): @rtype: str """ if hasattr(args, 'items'): - args = args.items() - args.sort() + args = sorted(args.items()) else: args = list(args) @@ -146,10 +146,10 @@ def appendArgs(url, args): # about the encodings of plain bytes (str). i = 0 for k, v in args: - if type(k) is not str: + if not isinstance(k, str): k = k.encode('UTF-8') - if type(v) is not str: + if not isinstance(v, str): v = v.encode('UTF-8') args[i] = (k, v) @@ -157,17 +157,20 @@ def appendArgs(url, args): return '%s%s%s' % (url, sep, urlencode(args)) + def toBase64(s): """Represent string s as base64, omitting newlines""" return binascii.b2a_base64(s)[:-1] + def fromBase64(s): try: return binascii.a2b_base64(s) - except binascii.Error, why: + except binascii.Error as why: # Convert to a common exception type raise ValueError(why[0]) + class Symbol(object): """This class implements an object that compares equal to others of the same type that have the same name. These are distict from @@ -178,13 +181,13 @@ def __init__(self, name): self.name = name def __eq__(self, other): - return type(self) is type(other) and self.name == other.name + return type(self) == type(other) and self.name == other.name def __ne__(self, other): return not (self == other) def __hash__(self): return hash((self.__class__, self.name)) - + def __repr__(self): return '' % (self.name,) diff --git a/openid/server/server.py b/openid/server/server.py index 1e456e0a..436b8add 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -144,6 +144,7 @@ UNUSED = None + class OpenIDRequest(object): """I represent an incoming OpenID request. @@ -190,7 +191,6 @@ def __init__(self, assoc_handle, signed, invalidate_handle=None): self.invalidate_handle = invalidate_handle self.namespace = OPENID2_NS - def fromMessage(klass, message, op_endpoint=UNUSED): """Construct me from an OpenID Message. @@ -206,7 +206,7 @@ def fromMessage(klass, message, op_endpoint=UNUSED): self.sig = message.getArg(OPENID_NS, 'sig') if (self.assoc_handle is None or - self.sig is None): + self.sig is None): fmt = "%s request missing required parameter from message %s" raise ProtocolError( message, text=fmt % (self.mode, message)) @@ -253,7 +253,6 @@ def answer(self, signatory): OPENID_NS, 'invalidate_handle', self.invalidate_handle) return response - def __str__(self): if self.invalidate_handle: ih = " invalidate? %r" % (self.invalidate_handle,) @@ -330,7 +329,7 @@ def fromMessage(cls, message): dh_modulus = message.getArg(OPENID_NS, 'dh_modulus') dh_gen = message.getArg(OPENID_NS, 'dh_gen') if (dh_modulus is None and dh_gen is not None or - dh_gen is None and dh_modulus is not None): + dh_gen is None and dh_modulus is not None): if dh_modulus is None: missing = 'modulus' @@ -367,13 +366,15 @@ def answer(self, secret): return { 'dh_server_public': cryptutil.longToBase64(self.dh.public), 'enc_mac_key': oidutil.toBase64(mac_key), - } + } + class DiffieHellmanSHA256ServerSession(DiffieHellmanSHA1ServerSession): session_type = 'DH-SHA256' hash_func = staticmethod(cryptutil.sha256) allowed_assoc_types = ['HMAC-SHA256'] + class AssociateRequest(OpenIDRequest): """A request to establish an X{association}. @@ -397,7 +398,7 @@ class AssociateRequest(OpenIDRequest): 'no-encryption': PlainTextServerSession, 'DH-SHA1': DiffieHellmanSHA1ServerSession, 'DH-SHA256': DiffieHellmanSHA256ServerSession, - } + } def __init__(self, session, assoc_type): """Construct me. @@ -410,7 +411,6 @@ def __init__(self, session, assoc_type): self.assoc_type = assoc_type self.namespace = OPENID2_NS - def fromMessage(klass, message, op_endpoint=UNUSED): """Construct me from an OpenID Message. @@ -423,7 +423,7 @@ def fromMessage(klass, message, op_endpoint=UNUSED): session_type = message.getArg(OPENID_NS, 'session_type') if session_type == 'no-encryption': _LOGGER.warn('Received OpenID 1 request with a no-encryption ' - 'assocaition session type. Continuing anyway.') + 'assocaition session type. Continuing anyway.') elif not session_type: session_type = 'no-encryption' @@ -449,7 +449,7 @@ def fromMessage(klass, message, op_endpoint=UNUSED): try: session = session_class.fromMessage(message) - except ValueError, why: + except ValueError as why: raise ProtocolError(message, 'Error parsing %s session: %s' % (session_class.session_type, why[0])) @@ -479,7 +479,7 @@ def answer(self, assoc): 'expires_in': '%d' % (assoc.getExpiresIn(),), 'assoc_type': self.assoc_type, 'assoc_handle': assoc.handle, - }) + }) response.fields.updateArgs(OPENID_NS, self.session.answer(assoc.secret)) @@ -513,6 +513,7 @@ def answerUnsupported(self, message, preferred_association_type=None, return response + class CheckIDRequest(OpenIDRequest): """A request to confirm the identity of a user. @@ -571,8 +572,7 @@ def __init__(self, identity, return_to, trust_root=None, immediate=False, self.immediate = False self.mode = "checkid_setup" - if self.return_to is not None and \ - not TrustRoot.parse(self.return_to): + if self.return_to is not None and not TrustRoot.parse(self.return_to): raise MalformedReturnURL(None, self.return_to) if not self.trustRootValid(): raise UntrustedReturnURL(None, self.return_to, self.trust_root) @@ -650,8 +650,7 @@ def fromMessage(klass, message, op_endpoint): # Using 'or' here is slightly different than sending a default # argument to getArg, as it will treat no value and an empty # string as equivalent. - self.trust_root = (message.getArg(OPENID_NS, trust_root_param) - or self.return_to) + self.trust_root = (message.getArg(OPENID_NS, trust_root_param) or self.return_to) if not message.isOpenID1(): if self.return_to is self.trust_root is None: @@ -666,8 +665,7 @@ def fromMessage(klass, message, op_endpoint): # is a valid URL. Not all trust roots are valid return_to URLs, # however (particularly ones with wildcards), so this is still a # little sketchy. - if self.return_to is not None and \ - not TrustRoot.parse(self.return_to): + if self.return_to is not None and not TrustRoot.parse(self.return_to): raise MalformedReturnURL(message, self.return_to) # I first thought that checking to see if the return_to is within @@ -798,10 +796,10 @@ def answer(self, allow, server_url=None, identity=None, claimed_id=None): if allow: mode = 'id_res' elif self.message.isOpenID1(): - if self.immediate: - mode = 'id_res' - else: - mode = 'cancel' + if self.immediate: + mode = 'id_res' + else: + mode = 'cancel' else: if self.immediate: mode = 'setup_needed' @@ -829,8 +827,7 @@ def answer(self, allow, server_url=None, identity=None, claimed_id=None): normalized_request_identity = urinorm(self.identity) normalized_answer_identity = urinorm(identity) - if (normalized_request_identity != - normalized_answer_identity): + if normalized_request_identity != normalized_answer_identity: raise ValueError( "Request was for identity %r, cannot reply " "with identity %r" % (self.identity, identity)) @@ -851,13 +848,13 @@ def answer(self, allow, server_url=None, identity=None, claimed_id=None): raise ValueError( "Request was an OpenID 1 request, so response must " "include an identifier." - ) + ) response.fields.updateArgs(OPENID_NS, { 'mode': mode, 'return_to': self.return_to, 'response_nonce': mkNonce(), - }) + }) if server_url: response.fields.setArg(OPENID_NS, 'op_endpoint', server_url) @@ -888,7 +885,6 @@ def answer(self, allow, server_url=None, identity=None, claimed_id=None): return response - def encodeToURL(self, server_url): """Encode this request as a URL to GET. @@ -922,7 +918,6 @@ def encodeToURL(self, server_url): response.updateArgs(OPENID_NS, q) return response.toURL(server_url) - def getCancelURL(self): """Get the URL to cancel this request. @@ -949,7 +944,6 @@ def getCancelURL(self): response.setArg(OPENID_NS, 'mode', 'cancel') return response.toURL(self.return_to) - def __repr__(self): return '<%s id:%r im:%s tr:%r ah:%r>' % (self.__class__.__name__, self.identity, @@ -958,7 +952,6 @@ def __repr__(self): self.assoc_handle) - class OpenIDResponse(object): """I am a response to an OpenID request. @@ -995,7 +988,6 @@ def __str__(self): self.request.__class__.__name__, self.fields) - def toFormMarkup(self, form_tag_attrs=None): """Returns the form markup for this response. @@ -1033,7 +1025,6 @@ def renderAsForm(self): """ return self.whichEncoding() == ENCODE_HTML_FORM - def needsSigning(self): """Does this response require signing? @@ -1041,7 +1032,6 @@ def needsSigning(self): """ return self.fields.getArg(OPENID_NS, 'mode') == 'id_res' - # implements IEncodable def whichEncoding(self): @@ -1061,7 +1051,6 @@ def whichEncoding(self): else: return ENCODE_KVFORM - def encodeToURL(self): """Encode a response as a URL for the user agent to GET. @@ -1072,7 +1061,6 @@ def encodeToURL(self): """ return self.fields.toURL(self.request.return_to) - def addExtension(self, extension_response): """ Add an extension response to this response message. @@ -1086,7 +1074,6 @@ def addExtension(self, extension_response): """ extension_response.toMessage(self.fields) - def encodeToKVForm(self): """Encode a response in key-value colon/newline format. @@ -1101,7 +1088,6 @@ def encodeToKVForm(self): return self.fields.toKVForm() - class WebResponse(object): """I am a response to an OpenID request in terms a web server understands. @@ -1132,7 +1118,6 @@ def __init__(self, code=HTTP_OK, headers=None, body=""): self.body = body - class Signatory(object): """I sign things. @@ -1146,7 +1131,7 @@ class Signatory(object): @type SECRET_LIFETIME: int """ - SECRET_LIFETIME = 14 * 24 * 60 * 60 # 14 days, in seconds + SECRET_LIFETIME = 14 * 24 * 60 * 60 # 14 days, in seconds # keys have a bogus server URL in them because the filestore # really does expect that key to be a URL. This seems a little @@ -1155,7 +1140,6 @@ class Signatory(object): _normal_key = 'http://localhost/|normal' _dumb_key = 'http://localhost/|dumb' - def __init__(self, store): """Create a new Signatory. @@ -1165,7 +1149,6 @@ def __init__(self, store): assert store is not None self.store = store - def verify(self, assoc_handle, message): """Verify that the signature for some data is valid. @@ -1186,12 +1169,11 @@ def verify(self, assoc_handle, message): try: valid = assoc.checkMessageSignature(message) - except ValueError, ex: + except ValueError as ex: _LOGGER.exception("Error in verifying %s with %s: %s", message, assoc, ex) return False return valid - def sign(self, response): """Sign a response. @@ -1232,11 +1214,10 @@ def sign(self, response): try: signed_response.fields = assoc.signMessage(signed_response.fields) - except kvform.KVFormError, err: + except kvform.KVFormError as err: raise EncodingError(response, explanation=str(err)) return signed_response - def createAssociation(self, dumb=True, assoc_type='HMAC-SHA1'): """Make a new association. @@ -1264,7 +1245,6 @@ def createAssociation(self, dumb=True, assoc_type='HMAC-SHA1'): self.store.storeAssociation(key, assoc) return assoc - def getAssociation(self, assoc_handle, dumb, checkExpiration=True): """Get the association with the specified handle. @@ -1299,7 +1279,6 @@ def getAssociation(self, assoc_handle, dumb, checkExpiration=True): assoc = None return assoc - def invalidate(self, assoc_handle, dumb): """Invalidates the association with the given handle. @@ -1315,7 +1294,6 @@ def invalidate(self, assoc_handle, dumb): self.store.removeAssociation(key, assoc_handle) - class Encoder(object): """I encode responses in to L{WebResponses}. @@ -1327,7 +1305,6 @@ class Encoder(object): responseFactory = WebResponse - def encode(self, response): """Encode a response to a L{WebResponse}. @@ -1353,7 +1330,6 @@ def encode(self, response): return wr - class SigningEncoder(Encoder): """I encode responses in to L{WebResponses}, signing them when required. """ @@ -1366,7 +1342,6 @@ def __init__(self, signatory): """ self.signatory = signatory - def encode(self, response): """Encode a response to a L{WebResponse}, signing it first if appropriate. @@ -1390,7 +1365,6 @@ def encode(self, response): return super(SigningEncoder, self).encode(response) - class Decoder(object): """I decode an incoming web request in to a L{OpenIDRequest}. """ @@ -1400,7 +1374,7 @@ class Decoder(object): 'checkid_immediate': CheckIDRequest.fromMessage, 'check_authentication': CheckAuthRequest.fromMessage, 'associate': AssociateRequest.fromMessage, - } + } def __init__(self, server): """Construct a Decoder. @@ -1431,7 +1405,7 @@ def decode(self, query): try: message = Message.fromPostArgs(query) - except InvalidOpenIDNamespace, err: + except InvalidOpenIDNamespace as err: # It's useful to have a Message attached to a ProtocolError, so we # override the bad ns value to build a Message out of it. Kinda # kludgy, since it's made of lies, but the parts that aren't lies @@ -1440,7 +1414,7 @@ def decode(self, query): query['openid.ns'] = OPENID2_NS message = Message.fromPostArgs(query) raise ProtocolError(message, str(err)) - except InvalidNamespace, err: + except InvalidNamespace as err: # If openid.ns is OK, but there is problem with other namespaces # We keep only bare parts of query and we try to make a ProtocolError from it query = [(key, value) for key, value in query.items() if key.count('.') < 2] @@ -1455,7 +1429,6 @@ def decode(self, query): handler = self._handlers.get(mode, self.defaultDecoder) return handler(message, self.server.op_endpoint) - def defaultDecoder(self, message, server): """Called to decode queries when no handler for that mode is found. @@ -1467,7 +1440,6 @@ def defaultDecoder(self, message, server): raise ProtocolError(message, text=fmt % (mode,)) - class Server(object): """I handle requests for an OpenID server. @@ -1521,13 +1493,7 @@ class Server(object): encoderClass = SigningEncoder decoderClass = Decoder - def __init__( - self, - store, - op_endpoint=None, - signatoryClass=None, - encoderClass=None, - decoderClass=None): + def __init__(self, store, op_endpoint=None, signatoryClass=None, encoderClass=None, decoderClass=None): """A new L{Server}. @param store: The back-end where my associations are stored. @@ -1570,7 +1536,6 @@ def __init__( stacklevel=2) self.op_endpoint = op_endpoint - def handleRequest(self, request): """Handle a request. @@ -1592,7 +1557,6 @@ def handleRequest(self, request): "%s has no handler for a request of mode %r." % (self, request.mode)) - def openid_check_authentication(self, request): """Handle and respond to C{check_authentication} requests. @@ -1600,7 +1564,6 @@ def openid_check_authentication(self, request): """ return request.answer(self.signatory) - def openid_associate(self, request): """Handle and respond to C{associate} requests. @@ -1616,14 +1579,12 @@ def openid_associate(self, request): else: message = ('Association type %r is not supported with ' 'session type %r' % (assoc_type, session_type)) - (preferred_assoc_type, preferred_session_type) = \ - self.negotiator.getAllowedType() + (preferred_assoc_type, preferred_session_type) = self.negotiator.getAllowedType() return request.answerUnsupported( message, preferred_assoc_type, preferred_session_type) - def decodeRequest(self, query): """Transform query parameters into an L{OpenIDRequest}. @@ -1643,7 +1604,6 @@ def decodeRequest(self, query): """ return self.decoder.decode(query) - def encodeResponse(self, response): """Encode a response to a L{WebResponse}, signing it first if appropriate. @@ -1659,7 +1619,6 @@ def encodeResponse(self, response): return self.encoder.encode(response) - class ProtocolError(Exception): """A message did not conform to the OpenID protocol. @@ -1683,7 +1642,6 @@ def __init__(self, message, text=None, reference=None, contact=None): assert type(message) not in [str, unicode] Exception.__init__(self, text) - def getReturnTo(self): """Get the return_to argument from the request, if any. @@ -1778,13 +1736,11 @@ def whichEncoding(self): return None - class VersionError(Exception): """Raised when an operation was attempted that is not compatible with the protocol version being used.""" - class NoReturnToError(Exception): """Raised when a response to a request cannot be generated because the request contains no return_to URL. @@ -1792,7 +1748,6 @@ class NoReturnToError(Exception): pass - class EncodingError(Exception): """Could not encode this as a protocol message. @@ -1821,7 +1776,6 @@ class AlreadySigned(EncodingError): """This response is already signed.""" - class UntrustedReturnURL(ProtocolError): """A return_to is outside the trust_root.""" @@ -1837,12 +1791,12 @@ def __str__(self): class MalformedReturnURL(ProtocolError): """The return_to URL doesn't look like a valid URL.""" + def __init__(self, openid_message, return_to): self.return_to = return_to ProtocolError.__init__(self, openid_message) - class MalformedTrustRoot(ProtocolError): """The trust root is not well-formed. @@ -1851,7 +1805,7 @@ class MalformedTrustRoot(ProtocolError): pass -#class IEncodable: # Interface +# class IEncodable: # Interface # def encodeToURL(return_to): # """Encode a response as a URL for redirection. # diff --git a/openid/server/trustroot.py b/openid/server/trustroot.py index 955a0d8b..ec771b9b 100644 --- a/openid/server/trustroot.py +++ b/openid/server/trustroot.py @@ -12,10 +12,10 @@ __all__ = [ 'TrustRoot', 'RP_RETURN_TO_URL_TYPE', - 'extractReturnToURLs', + 'getAllowedReturnURLs', 'returnToMatches', 'verifyReturnTo', - ] +] import logging import re @@ -66,11 +66,13 @@ host_segment_re = re.compile( r"(?:[-a-zA-Z0-9!$&'\(\)\*+,;=._~]|%[a-zA-Z0-9]{2})+$") + class RealmVerificationRedirected(Exception): """Attempting to verify this realm resulted in a redirect. @since: 2.1.0 """ + def __init__(self, relying_party_url, rp_url_after_redirects): self.relying_party_url = relying_party_url self.rp_url_after_redirects = rp_url_after_redirects @@ -111,6 +113,7 @@ def _parseURL(url): return proto, host, port, path + class TrustRoot(object): """ This class represents an OpenID trust root. The C{L{parse}} @@ -178,7 +181,7 @@ def isSane(self): if self.wildcard: if len(tld) == 2 and len(host_parts[-2]) <= 3: # It's a 2-letter tld with a short second to last segment - # so there needs to be more than two segments specified + # so there needs to be more than two segments specified # (e.g. *.co.uk is insane) return len(host_parts) > 2 @@ -239,8 +242,7 @@ def validateURL(self, url): else: allowed = '?/' - return (self.path[-1] in allowed or - path[path_len] in allowed) + return (self.path[-1] in allowed or path[path_len] in allowed) return True @@ -352,12 +354,14 @@ def __repr__(self): def __str__(self): return repr(self) + # The URI for relying party discovery, used in realm verification. # # XXX: This should probably live somewhere else (like in # openid.consumer or openid.yadis somewhere) RP_RETURN_TO_URL_TYPE = 'http://specs.openid.net/auth/2.0/return_to' + def _extractReturnURL(endpoint): """If the endpoint is a relying party OpenID return_to endpoint, return the endpoint URL. Otherwise, return None. @@ -380,6 +384,7 @@ def _extractReturnURL(endpoint): else: return None + def returnToMatches(allowed_return_to_urls, return_to): """Is the return_to URL under one of the supplied allowed return_to URLs? @@ -394,7 +399,8 @@ def returnToMatches(allowed_return_to_urls, return_to): # a wildcard. return_realm = TrustRoot.parse(allowed_return_to) - if (# Parses as a trust root + if ( + # Parses as a trust root return_realm is not None and # Does not have a wildcard @@ -402,12 +408,13 @@ def returnToMatches(allowed_return_to_urls, return_to): # Matches the return_to that we passed in with it return_realm.validateURL(return_to) - ): + ): return True # No URL in the list matched return False + def getAllowedReturnURLs(relying_party_url): """Given a relying party discovery URL return a list of return_to URLs. @@ -424,6 +431,8 @@ def getAllowedReturnURLs(relying_party_url): return return_to_urls # _vrfy parameter is there to make testing easier + + def verifyReturnTo(realm_str, return_to, _vrfy=getAllowedReturnURLs): """Verify that a return_to URL is valid for the given realm. @@ -444,7 +453,7 @@ def verifyReturnTo(realm_str, return_to, _vrfy=getAllowedReturnURLs): try: allowable_urls = _vrfy(realm.buildDiscoveryURL()) - except RealmVerificationRedirected, err: + except RealmVerificationRedirected as err: _LOGGER.exception(str(err)) return False diff --git a/openid/sreg.py b/openid/sreg.py index bf454d7b..bceb53fe 100644 --- a/openid/sreg.py +++ b/openid/sreg.py @@ -2,7 +2,9 @@ import warnings -from openid.extensions.sreg import * +from openid.extensions.sreg import SRegRequest, SRegResponse, data_fields, ns_uri, ns_uri_1_0, ns_uri_1_1, supportsSReg warnings.warn("openid.sreg has moved to openid.extensions.sreg", DeprecationWarning) + +__all__ = ['SRegRequest', 'SRegResponse', 'data_fields', 'ns_uri', 'ns_uri_1_0', 'ns_uri_1_1', 'supportsSReg'] diff --git a/openid/store/filestore.py b/openid/store/filestore.py index 3ec4c599..0c5c044d 100644 --- a/openid/store/filestore.py +++ b/openid/store/filestore.py @@ -21,6 +21,7 @@ _filename_allowed = string.ascii_letters + string.digits + '.' _isFilenameSafe = set(_filename_allowed).__contains__ + def _safe64(s): h64 = oidutil.toBase64(cryptutil.sha1(s)) h64 = h64.replace('+', '_') @@ -28,6 +29,7 @@ def _safe64(s): h64 = h64.replace('=', '') return h64 + def _filenameEscape(s): filename_chunks = [] for c in s: @@ -37,6 +39,7 @@ def _filenameEscape(s): filename_chunks.append('_%02X' % ord(c)) return ''.join(filename_chunks) + def _removeIfPresent(filename): """Attempt to remove a file, returning whether the file existed at the time of the call. @@ -45,7 +48,7 @@ def _removeIfPresent(filename): """ try: os.unlink(filename) - except OSError, why: + except OSError as why: if why.errno == ENOENT: # Someone beat us to it, but it's gone, so that's OK return 0 @@ -55,6 +58,7 @@ def _removeIfPresent(filename): # File was present return 1 + def _ensureDir(dir_name): """Create dir_name as a directory if it does not exist. If it exists, make sure that it is, in fact, a directory. @@ -65,10 +69,11 @@ def _ensureDir(dir_name): """ try: os.makedirs(dir_name) - except OSError, why: + except OSError as why: if why.errno != EEXIST or not os.path.isdir(dir_name): raise + class FileOpenIDStore(OpenIDStore): """ This is a filesystem-based store for OpenID associations and @@ -108,7 +113,7 @@ def __init__(self, directory): # directory self.temp_dir = os.path.join(directory, 'temp') - self.max_nonce_age = 6 * 60 * 60 # Six hours, in seconds + self.max_nonce_age = 6 * 60 * 60 # Six hours, in seconds self._setup() @@ -137,7 +142,9 @@ def _mktemp(self): try: file_obj = os.fdopen(fd, 'wb') return file_obj, name - except: + except Exception: + # If there was an error, don't leave the temporary file + # around. _removeIfPresent(name) raise @@ -183,7 +190,7 @@ def storeAssociation(self, server_url, association): try: os.rename(tmp, filename) - except OSError, why: + except OSError as why: if why.errno != EEXIST: raise @@ -192,7 +199,7 @@ def storeAssociation(self, server_url, association): # file, but not in putting the temporary file in place. try: os.unlink(filename) - except OSError, why: + except OSError as why: if why.errno == ENOENT: pass else: @@ -201,7 +208,7 @@ def storeAssociation(self, server_url, association): # Now the target should not exist. Try renaming again, # giving up if it fails. os.rename(tmp, filename) - except: + except Exception: # If there was an error, don't leave the temporary file # around. _removeIfPresent(tmp) @@ -252,7 +259,7 @@ def getAssociation(self, server_url, handle=None): def _getAssociation(self, filename): try: assoc_file = file(filename, 'rb') - except IOError, why: + except IOError as why: if why.errno == ENOENT: # No association exists for that URL and handle return None @@ -313,8 +320,8 @@ def useNonce(self, server_url, timestamp, salt): filename = os.path.join(self.nonce_dir, filename) try: - fd = os.open(filename, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0200) - except OSError, why: + fd = os.open(filename, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o200) + except OSError as why: if why.errno == EEXIST: return False else: @@ -332,7 +339,7 @@ def _allAssocs(self): for association_filename in association_filenames: try: association_file = file(association_filename, 'rb') - except IOError, why: + except IOError as why: if why.errno == ENOENT: _LOGGER.exception("%s disappeared during %s._allAssocs", association_filename, self.__class__.__name__) diff --git a/openid/store/interface.py b/openid/store/interface.py index bb90972f..63776572 100644 --- a/openid/store/interface.py +++ b/openid/store/interface.py @@ -3,6 +3,7 @@ interface. """ + class OpenIDStore(object): """ This is the interface for the store objects the OpenID library diff --git a/openid/store/memstore.py b/openid/store/memstore.py index 89a16bdc..366a596e 100644 --- a/openid/store/memstore.py +++ b/openid/store/memstore.py @@ -49,12 +49,12 @@ def cleanup(self): return len(remove), len(self.assocs) - class MemoryStore(object): """In-process memory store. Use for single long-running processes. No persistence supplied. """ + def __init__(self): self.server_assocs = {} self.nonces = {} diff --git a/openid/store/nonce.py b/openid/store/nonce.py index 89ef096f..800dfecf 100644 --- a/openid/store/nonce.py +++ b/openid/store/nonce.py @@ -2,7 +2,7 @@ 'split', 'mkNonce', 'checkTimestamp', - ] +] import string from calendar import timegm @@ -20,6 +20,7 @@ time_fmt = '%Y-%m-%dT%H:%M:%SZ' time_str_len = len('0000-00-00T00:00:00Z') + def split(nonce_string): """Extract a timestamp from the given nonce string @@ -38,6 +39,7 @@ def split(nonce_string): raise ValueError('time out of range') return timestamp, nonce_string[time_str_len:] + def checkTimestamp(nonce_string, allowed_skew=SKEW, now=None): """Is the timestamp that is part of the specified nonce string within the allowed clock-skew of the current time? @@ -74,6 +76,7 @@ def checkTimestamp(nonce_string, allowed_skew=SKEW, now=None): # the past return past <= stamp <= future + def mkNonce(when=None): """Generate a nonce with the current timestamp diff --git a/openid/store/sqlstore.py b/openid/store/sqlstore.py index a629e726..c9e7b23a 100644 --- a/openid/store/sqlstore.py +++ b/openid/store/sqlstore.py @@ -4,7 +4,8 @@ Example of how to initialize a store database:: - python -c 'from openid.store import sqlstore; import pysqlite2.dbapi2; sqlstore.SQLiteStore(pysqlite2.dbapi2.connect("cstore.db")).createTables()' + python -c 'from openid.store import sqlstore; import pysqlite2.dbapi2; + sqlstore.SQLiteStore(pysqlite2.dbapi2.connect("cstore.db")).createTables()' """ import re import time @@ -29,6 +30,7 @@ def wrapped(self, *args, **kwargs): return wrapped + class SQLStore(OpenIDStore): """ This is the parent class for the SQL stores, which contains the @@ -98,14 +100,13 @@ def __init__(self, conn, associations_table=None, nonces_table=None): self._table_names = { 'associations': associations_table or self.associations_table, 'nonces': nonces_table or self.nonces_table, - } - self.max_nonce_age = 6 * 60 * 60 # Six hours, in seconds + } + self.max_nonce_age = 6 * 60 * 60 # Six hours, in seconds # DB API extension: search for "Connection Attributes .Error, # .ProgrammingError, etc." in # http://www.python.org/dev/peps/pep-0249/ - if (hasattr(self.conn, 'IntegrityError') and - hasattr(self.conn, 'OperationalError')): + if hasattr(self.conn, 'IntegrityError') and hasattr(self.conn, 'OperationalError'): self.exceptions = self.conn if not (hasattr(self.exceptions, 'IntegrityError') and @@ -139,6 +140,7 @@ def _execSQL(self, sql_name, *args): # arguments if they are passed in as unicode instead of str. # Currently the strings in our tables just have ascii in them, # so this ought to be safe. + def unicode_to_str(arg): if isinstance(arg, unicode): return str(arg) @@ -153,6 +155,7 @@ def __getattr__(self, attr): # as an attribute of this object and executes it. if attr[:3] == 'db_': sql_name = attr[3:] + '_sql' + def func(*args): return self._execSQL(sql_name, *args) setattr(self, attr, func) @@ -174,7 +177,7 @@ def _callInTransaction(self, func, *args, **kwargs): finally: self.cur.close() self.cur = None - except: + except Exception: self.conn.rollback() raise else: @@ -248,7 +251,7 @@ def txn_removeAssociation(self, server_url, handle): (str, str) -> bool """ self.db_remove_assoc(server_url, handle) - return self.cur.rowcount > 0 # -1 is undefined + return self.cur.rowcount > 0 # -1 is undefined removeAssociation = _inTxn(txn_removeAssociation) @@ -350,12 +353,13 @@ def useNonce(self, *args, **kwargs): # message from the OperationalError. try: return super(SQLiteStore, self).useNonce(*args, **kwargs) - except self.exceptions.OperationalError, why: + except self.exceptions.OperationalError as why: if re.match('^columns .* are not unique$', why[0]): return False else: raise + class MySQLStore(SQLStore): """ This is a MySQL-based specialization of C{L{SQLStore}}. @@ -417,13 +421,14 @@ class MySQLStore(SQLStore): clean_nonce_sql = 'DELETE FROM %(nonces)s WHERE timestamp < %%s;' def blobDecode(self, blob): - if type(blob) is str: + if isinstance(blob, str): # Versions of MySQLdb >= 1.2.2 return blob else: # Versions of MySQLdb prior to 1.2.2 (as far as we can tell) return blob.tostring() + class PostgreSQLStore(SQLStore): """ This is a PostgreSQL-based specialization of C{L{SQLStore}}. @@ -473,7 +478,7 @@ def db_set_assoc(self, server_url, handle, secret, issued, lifetime, assoc_type) REPLACE INTO is not supported by PostgreSQL (and is not standard SQL). """ - result = self.db_get_assoc(server_url, handle) + self.db_get_assoc(server_url, handle) rows = self.cur.fetchall() if len(rows): # Update the table since this associations already exists. diff --git a/openid/test/cryptutil.py b/openid/test/cryptutil.py index e52b6a3b..cf6074c1 100644 --- a/openid/test/cryptutil.py +++ b/openid/test/cryptutil.py @@ -7,6 +7,7 @@ # Most of the purpose of this test is to make sure that cryptutil can # find a good source of randomness on this machine. + def test_cryptrand(): # It's possible, but HIGHLY unlikely that a correct implementation # will fail by returning the same number twice @@ -17,15 +18,16 @@ def test_cryptrand(): assert len(t) == 32 assert s != t - a = cryptutil.randrange(2L ** 128) - b = cryptutil.randrange(2L ** 128) - assert type(a) is long - assert type(b) is long + a = cryptutil.randrange(2 ** 128) + b = cryptutil.randrange(2 ** 128) + assert isinstance(a, long) + assert isinstance(b, long) assert b != a # Make sure that we can generate random numbers that are larger # than platform int size - cryptutil.randrange(long(sys.maxint) + 1L) + cryptutil.randrange(long(sys.maxsize) + 1) + def test_reversed(): if hasattr(cryptutil, 'reversed'): @@ -37,10 +39,10 @@ def test_reversed(): ('abcdefg', 'gfedcba'), ([], []), ([1], [1]), - ([1,2], [2,1]), - ([1,2,3], [3,2,1]), + ([1, 2], [2, 1]), + ([1, 2, 3], [3, 2, 1]), (range(1000), range(999, -1, -1)), - ] + ] for case, expected in cases: expected = list(expected) @@ -49,28 +51,29 @@ def test_reversed(): twice = list(cryptutil.reversed(actual)) assert twice == list(case), (actual, case, twice) + def test_binaryLongConvert(): - MAX = sys.maxint + MAX = sys.maxsize for iteration in xrange(500): - n = 0L + n = 0 for i in range(10): n += long(random.randrange(MAX)) s = cryptutil.longToBinary(n) - assert type(s) is str + assert isinstance(s, str) n_prime = cryptutil.binaryToLong(s) assert n == n_prime, (n, n_prime) cases = [ - ('\x00', 0L), - ('\x01', 1L), - ('\x7F', 127L), - ('\x00\xFF', 255L), - ('\x00\x80', 128L), - ('\x00\x81', 129L), - ('\x00\x80\x00', 32768L), - ('OpenID is cool', 1611215304203901150134421257416556L) - ] + ('\x00', 0), + ('\x01', 1), + ('\x7F', 127), + ('\x00\xFF', 255), + ('\x00\x80', 128), + ('\x00\x81', 129), + ('\x00\x80\x00', 32768), + ('OpenID is cool', 1611215304203901150134421257416556) + ] for s, n in cases: n_prime = cryptutil.binaryToLong(s) @@ -78,6 +81,7 @@ def test_binaryLongConvert(): assert n == n_prime, (s, n, n_prime) assert s == s_prime, (n, s, s_prime) + def test_longToBase64(): f = file(os.path.join(os.path.dirname(__file__), 'n2b64')) try: @@ -87,6 +91,7 @@ def test_longToBase64(): finally: f.close() + def test_base64ToLong(): f = file(os.path.join(os.path.dirname(__file__), 'n2b64')) try: @@ -104,5 +109,6 @@ def test(): test_longToBase64() test_base64ToLong() + if __name__ == '__main__': test() diff --git a/openid/test/datadriven.py b/openid/test/datadriven.py index c7dc4f70..aac6e9db 100644 --- a/openid/test/datadriven.py +++ b/openid/test/datadriven.py @@ -31,6 +31,7 @@ def __init__(self, description): def shortDescription(self): return '%s for %s' % (self.__class__.__name__, self.description) + def loadTests(module_name): loader = unittest.defaultTestLoader this_module = __import__(module_name, {}, {}, [None]) @@ -38,8 +39,7 @@ def loadTests(module_name): tests = [] for name in dir(this_module): obj = getattr(this_module, name) - if (isinstance(obj, (type, types.ClassType)) and - issubclass(obj, unittest.TestCase)): + if isinstance(obj, (type, types.ClassType)) and issubclass(obj, unittest.TestCase): if hasattr(obj, 'loadTests'): tests.extend(obj.loadTests()) else: diff --git a/openid/test/dh.py b/openid/test/dh.py index 299730b1..01a6ab52 100644 --- a/openid/test/dh.py +++ b/openid/test/dh.py @@ -16,7 +16,7 @@ def test_strxor(): ('\x01', '\x02', '\x03'), ('\xf0', '\x0f', '\xff'), ('\xff', '\x0f', '\xf0'), - ] + ] for aa, bb, expected in cases: actual = strxor(aa, bb) @@ -28,7 +28,7 @@ def test_strxor(): (NUL * 3, NUL * 4), (''.join(map(chr, xrange(256))), ''.join(map(chr, xrange(128)))), - ] + ] for aa, bb in exc_cases: try: @@ -38,6 +38,7 @@ def test_strxor(): else: assert False, 'Expected ValueError, got %r' % (unexpected,) + def test1(): dh1 = DiffieHellman.fromDefaults() dh2 = DiffieHellman.fromDefaults() @@ -46,11 +47,13 @@ def test1(): assert secret1 == secret2 return secret1 + def test_exchange(): s1 = test1() s2 = test1() assert s1 != s2 + def test_public(): f = file(os.path.join(os.path.dirname(__file__), 'dhpriv')) dh = DiffieHellman.fromDefaults() @@ -63,10 +66,12 @@ def test_public(): finally: f.close() + def test(): test_exchange() test_public() test_strxor() + if __name__ == '__main__': test() diff --git a/openid/test/discoverdata.py b/openid/test/discoverdata.py index 1d906d8a..32d9619c 100644 --- a/openid/test/discoverdata.py +++ b/openid/test/discoverdata.py @@ -9,25 +9,26 @@ data_path = os.path.join(tests_dir, 'data') testlist = [ -# success, input_name, id_name, result_name - (True, "equiv", "equiv", "xrds"), - (True, "header", "header", "xrds"), - (True, "lowercase_header", "lowercase_header", "xrds"), - (True, "xrds", "xrds", "xrds"), - (True, "xrds_ctparam", "xrds_ctparam", "xrds_ctparam"), - (True, "xrds_ctcase", "xrds_ctcase", "xrds_ctcase"), - (False, "xrds_html", "xrds_html", "xrds_html"), - (True, "redir_equiv", "equiv", "xrds"), - (True, "redir_header", "header", "xrds"), - (True, "redir_xrds", "xrds", "xrds"), - (False, "redir_xrds_html", "xrds_html", "xrds_html"), - (True, "redir_redir_equiv", "equiv", "xrds"), - (False, "404_server_response", None, None), - (False, "404_with_header", None, None), - (False, "404_with_meta", None, None), - (False, "201_server_response", None, None), - (False, "500_server_response", None, None), - ] + # success, input_name, id_name, result_name + (True, "equiv", "equiv", "xrds"), + (True, "header", "header", "xrds"), + (True, "lowercase_header", "lowercase_header", "xrds"), + (True, "xrds", "xrds", "xrds"), + (True, "xrds_ctparam", "xrds_ctparam", "xrds_ctparam"), + (True, "xrds_ctcase", "xrds_ctcase", "xrds_ctcase"), + (False, "xrds_html", "xrds_html", "xrds_html"), + (True, "redir_equiv", "equiv", "xrds"), + (True, "redir_header", "header", "xrds"), + (True, "redir_xrds", "xrds", "xrds"), + (False, "redir_xrds_html", "xrds_html", "xrds_html"), + (True, "redir_redir_equiv", "equiv", "xrds"), + (False, "404_server_response", None, None), + (False, "404_with_header", None, None), + (False, "404_with_meta", None, None), + (False, "201_server_response", None, None), + (False, "500_server_response", None, None), +] + def getDataName(*components): sanitized = [] @@ -42,15 +43,18 @@ def getDataName(*components): return os.path.join(data_path, *sanitized) + def getExampleXRDS(): filename = getDataName('example-xrds.xml') return file(filename).read() + example_xrds = getExampleXRDS() default_test_file = getDataName('test1-discover.txt') discover_tests = {} + def readTests(filename): data = file(filename).read() tests = {} @@ -59,6 +63,7 @@ def readTests(filename): tests[name] = content return tests + def getData(filename, name): global discover_tests try: @@ -67,25 +72,27 @@ def getData(filename, name): file_tests = discover_tests[filename] = readTests(filename) return file_tests[name] + def fillTemplate(test_name, template, base_url, example_xrds): mapping = [ ('URL_BASE/', base_url), ('', example_xrds), ('YADIS_HEADER', YADIS_HEADER_NAME), ('NAME', test_name), - ] + ] for k, v in mapping: template = template.replace(k, v) return template + def generateSample(test_name, base_url, example_xrds=example_xrds, filename=default_test_file): try: template = getData(filename, test_name) - except IOError, why: + except IOError as why: import errno if why[0] == errno.ENOENT: raise KeyError(filename) @@ -94,6 +101,7 @@ def generateSample(test_name, base_url, return fillTemplate(test_name, template, base_url, example_xrds) + def generateResult(base_url, input_name, id_name, result_name, success): input_url = urlparse.urljoin(base_url, input_name) diff --git a/openid/test/kvform.py b/openid/test/kvform.py index b54a64b5..7bbb5cef 100644 --- a/openid/test/kvform.py +++ b/openid/test/kvform.py @@ -17,6 +17,7 @@ def setUp(self): def tearDown(self): CatchLogs.tearDown(self) + class KVDictTest(KVBaseTest): def __init__(self, kv, dct, warnings): unittest.TestCase.__init__(self) @@ -40,6 +41,7 @@ def runTest(self): d2 = kvform.kvToDict(kv) self.failUnlessEqual(d, d2) + class KVSeqTest(KVBaseTest): def __init__(self, seq, kv, expected_warnings): unittest.TestCase.__init__(self) @@ -52,9 +54,9 @@ def cleanSeq(self, seq): and end of each value of each pair""" clean = [] for k, v in self.seq: - if type(k) is str: + if isinstance(k, str): k = k.decode('utf8') - if type(v) is str: + if isinstance(v, str): v = v.decode('utf8') clean.append((k.strip(), v.strip())) return clean @@ -63,7 +65,7 @@ def runTest(self): # seq serializes to expected kvform actual = kvform.seqToKV(self.seq) self.failUnlessEqual(self.kvform, actual) - self.failUnless(type(actual) is str) + self.assertIsInstance(actual, str) # Parse back to sequence. Expected to be unchanged, except # stripping whitespace from start and end of values @@ -74,15 +76,14 @@ def runTest(self): self.failUnlessEqual(seq, clean_seq) self.checkWarnings(self.expected_warnings) + kvdict_cases = [ # (kvform, parsed dictionary, expected warnings) ('', {}, 0), - ('college:harvey mudd\n', {'college':'harvey mudd'}, 0), - ('city:claremont\nstate:CA\n', - {'city':'claremont', 'state':'CA'}, 0), + ('college:harvey mudd\n', {'college': 'harvey mudd'}, 0), + ('city:claremont\nstate:CA\n', {'city': 'claremont', 'state': 'CA'}, 0), ('is_valid:true\ninvalidate_handle:{HMAC-SHA1:2398410938412093}\n', - {'is_valid':'true', - 'invalidate_handle':'{HMAC-SHA1:2398410938412093}'}, 0), + {'is_valid': 'true', 'invalidate_handle': '{HMAC-SHA1:2398410938412093}'}, 0), # Warnings from lines with no colon: ('x\n', {}, 1), @@ -93,18 +94,18 @@ def runTest(self): ('x\n\n', {}, 1), # Warning from empty key - (':\n', {'':''}, 1), - (':missing key\n', {'':'missing key'}, 1), + (':\n', {'': ''}, 1), + (':missing key\n', {'': 'missing key'}, 1), # Warnings from leading or trailing whitespace in key or value - (' street:foothill blvd\n', {'street':'foothill blvd'}, 1), - ('major: computer science\n', {'major':'computer science'}, 1), - (' dorm : east \n', {'dorm':'east'}, 2), + (' street:foothill blvd\n', {'street': 'foothill blvd'}, 1), + ('major: computer science\n', {'major': 'computer science'}, 1), + (' dorm : east \n', {'dorm': 'east'}, 2), # Warnings from missing trailing newline - ('e^(i*pi)+1:0', {'e^(i*pi)+1':'0'}, 1), - ('east:west\nnorth:south', {'east':'west', 'north':'south'}, 1), - ] + ('e^(i*pi)+1:0', {'e^(i*pi)+1': '0'}, 1), + ('east:west\nnorth:south', {'east': 'west', 'north': 'south'}, 1), +] kvseq_cases = [ ([], '', 0), @@ -131,7 +132,7 @@ def runTest(self): (' a ', ' b ')], ' open id : use ful \n a : b \n', 8), ([(u'foo', 'bar')], 'foo:bar\n', 0), - ] +] kvexc_cases = [ [('openid', 'use\nful')], @@ -140,7 +141,8 @@ def runTest(self): [('open:id', 'useful')], [('foo', 'bar'), ('ba\n d', 'seed')], [('foo', 'bar'), ('bad:', 'seed')], - ] +] + class KVExcTest(unittest.TestCase): def __init__(self, seq): @@ -153,14 +155,16 @@ def shortDescription(self): def runTest(self): self.failUnlessRaises(ValueError, kvform.seqToKV, self.seq) + class GeneralTest(KVBaseTest): kvform = '' def test_convert(self): - result = kvform.seqToKV([(1,1)]) + result = kvform.seqToKV([(1, 1)]) self.failUnlessEqual(result, '1:1\n') self.checkWarnings(2) + def pyUnitTests(): tests = [KVDictTest(*case) for case in kvdict_cases] tests.extend([KVSeqTest(*case) for case in kvseq_cases]) @@ -168,6 +172,7 @@ def pyUnitTests(): tests.append(unittest.defaultTestLoader.loadTestsFromTestCase(GeneralTest)) return unittest.TestSuite(tests) + if __name__ == '__main__': suite = pyUnitTests() runner = unittest.TextTestRunner() diff --git a/openid/test/linkparse.py b/openid/test/linkparse.py index adcdfb35..f31f9ef2 100644 --- a/openid/test/linkparse.py +++ b/openid/test/linkparse.py @@ -23,6 +23,7 @@ def parseLink(line): return (optional, attrs) + def parseCase(s): header, markup = s.split('\n\n', 1) lines = header.split('\n') @@ -31,6 +32,7 @@ def parseCase(s): desc = name[6:] return desc, markup, map(parseLink, lines) + def parseTests(s): tests = [] @@ -47,6 +49,7 @@ def parseTests(s): return num_tests, tests + class _LinkTest(unittest.TestCase): def __init__(self, desc, case, expected, raw): unittest.TestCase.__init__(self) @@ -84,6 +87,7 @@ def runTest(self): assert i == len(actual) + def pyUnitTests(): here = os.path.dirname(os.path.abspath(__file__)) test_data_file_name = os.path.join(here, 'linkparse.txt') @@ -105,6 +109,7 @@ def test_parseSucceeded(): return unittest.TestSuite(tests) + if __name__ == '__main__': suite = pyUnitTests() runner = unittest.TextTestRunner() diff --git a/openid/test/oidutil.py b/openid/test/oidutil.py index 568f16af..c7a002fa 100644 --- a/openid/test/oidutil.py +++ b/openid/test/oidutil.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -import codecs import random import string import unittest @@ -25,7 +24,7 @@ def checkEncoded(s): '\x01', '\x00' * 100, ''.join(map(chr, range(256))), - ] + ] for s in cases: b64 = oidutil.toBase64(s) @@ -42,6 +41,7 @@ def checkEncoded(s): s_prime = oidutil.fromBase64(b64) assert s_prime == s, (s, b64, s_prime) + class AppendArgsTest(unittest.TestCase): def __init__(self, desc, args, expected): unittest.TestCase.__init__(self) @@ -56,6 +56,7 @@ def runTest(self): def shortDescription(self): return self.desc + class TestUnicodeConversion(unittest.TestCase): def test_toUnicode(self): @@ -68,6 +69,7 @@ def test_toUnicode(self): # Other encodings raise exceptions self.assertRaises(UnicodeDecodeError, lambda: oidutil.toUnicode(u'fööbär'.encode('latin-1'))) + class TestSymbol(unittest.TestCase): def testCopyHash(self): import copy @@ -96,7 +98,7 @@ def buildAppendTests(): simple + '?a=b'), ('one dict', - (simple, {'a':'b'}), + (simple, {'a': 'b'}), simple + '?a=b'), ('two list (same)', @@ -112,7 +114,7 @@ def buildAppendTests(): simple + '?b=c&a=b'), ('two dict (order)', - (simple, {'b':'c', 'a':'b'}), + (simple, {'b': 'c', 'a': 'b'}), simple + '?a=b&b=c'), ('escape', @@ -144,17 +146,17 @@ def buildAppendTests(): simple + '?stuff=bother&ack=ack'), ('args exist (dict 2)', - (simple + '?stuff=bother', {'ack': 'ack', 'zebra':'lion'}), + (simple + '?stuff=bother', {'ack': 'ack', 'zebra': 'lion'}), simple + '?stuff=bother&ack=ack&zebra=lion'), ('three args (dict)', - (simple, {'stuff': 'bother', 'ack': 'ack', 'zebra':'lion'}), + (simple, {'stuff': 'bother', 'ack': 'ack', 'zebra': 'lion'}), simple + '?ack=ack&stuff=bother&zebra=lion'), ('three args (list)', (simple, [('stuff', 'bother'), ('ack', 'ack'), ('zebra', 'lion')]), simple + '?stuff=bother&ack=ack&zebra=lion'), - ] + ] tests = [] @@ -164,12 +166,14 @@ def buildAppendTests(): return unittest.TestSuite(tests) + def pyUnitTests(): some = buildAppendTests() some.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestSymbol)) some.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestUnicodeConversion)) return some + def test_appendArgs(): suite = buildAppendTests() suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestSymbol)) @@ -181,10 +185,12 @@ def test_appendArgs(): # specified and tested in oidutil.py These include, but are not # limited to appendArgs + def test(skipPyUnit=True): test_base64() if not skipPyUnit: test_appendArgs() + if __name__ == '__main__': test(skipPyUnit=False) diff --git a/openid/test/storetest.py b/openid/test/storetest.py index 6d876fc2..a3885b5c 100644 --- a/openid/test/storetest.py +++ b/openid/test/storetest.py @@ -17,18 +17,20 @@ allowed_handle.append(c) allowed_handle = ''.join(allowed_handle) + def generateHandle(n): return randomString(n, allowed_handle) + generateSecret = randomString + def getTmpDbName(): hostname = socket.gethostname() hostname = hostname.replace('.', '_') hostname = hostname.replace('-', '_') - return "%s_%d_%s_openid_test" % \ - (hostname, os.getpid(), \ - random.randrange(1, int(time.time()))) + return "%s_%d_%s_openid_test" % (hostname, os.getpid(), random.randrange(1, int(time.time()))) + def testStore(store): """Make sure a given store has a minimum of API compliance. Call @@ -38,10 +40,11 @@ def testStore(store): OpenIDStore -> NoneType """ - ### Association functions + # Association functions now = int(time.time()) server_url = 'http://www.myopenid.com/openid' + def genAssoc(issued, lifetime=600): sec = generateSecret(20) hdl = generateHandle(128) @@ -146,15 +149,15 @@ def checkRemove(url, handle, expected): checkRemove(server_url, assoc.handle, False) checkRemove(server_url, assoc3.handle, False) - ### test expired associations + # test expired associations # assoc 1: server 1, valid # assoc 2: server 1, expired # assoc 3: server 2, expired # assoc 4: server 3, valid - assocValid1 = genAssoc(issued=-3600,lifetime=7200) + assocValid1 = genAssoc(issued=-3600, lifetime=7200) assocValid2 = genAssoc(issued=-5) - assocExpired1 = genAssoc(issued=-7200,lifetime=3600) - assocExpired2 = genAssoc(issued=-7200,lifetime=3600) + assocExpired1 = genAssoc(issued=-7200, lifetime=3600) + assocExpired2 = genAssoc(issued=-7200, lifetime=3600) store.cleanupAssociations() store.storeAssociation(server_url + '1', assocValid1) @@ -165,7 +168,7 @@ def checkRemove(url, handle, expected): cleaned = store.cleanupAssociations() assert cleaned == 2, cleaned - ### Nonce functions + # Nonce functions def checkUseNonce(nonce, expected, server_url, msg=''): stamp, salt = split(nonce) @@ -189,7 +192,6 @@ def checkUseNonce(nonce, expected, server_url, msg=''): old_nonce = mkNonce(3600) checkUseNonce(old_nonce, False, url, "Old nonce (%r) passed." % (old_nonce,)) - old_nonce1 = mkNonce(now - 20000) old_nonce2 = mkNonce(now - 10000) recent_nonce = mkNonce(now - 600) @@ -235,11 +237,12 @@ def test_filestore(): try: testStore(store) store.cleanup() - except: + except Exception: raise else: shutil.rmtree(temp_dir) + def test_sqlite(): from openid.store import sqlstore try: @@ -252,6 +255,7 @@ def test_sqlite(): store.createTables() testStore(store) + def test_mysql(): from openid.store import sqlstore try: @@ -263,12 +267,10 @@ def test_mysql(): db_passwd = '' db_name = getTmpDbName() - from MySQLdb.constants import ER - # Change this connect line to use the right user and password try: - conn = MySQLdb.connect(user=db_user, passwd=db_passwd, host = db_host) - except MySQLdb.OperationalError, why: + conn = MySQLdb.connect(user=db_user, passwd=db_passwd, host=db_host) + except MySQLdb.OperationalError as why: if why[0] == 2005: print ('Skipping MySQL store test (cannot connect ' 'to test server on host %r)' % (db_host,)) @@ -292,6 +294,7 @@ def test_mysql(): # failing test, comment out this line. conn.query('DROP DATABASE %s;' % db_name) + def test_postgresql(): """ Tests the PostgreSQLStore on a locally-hosted PostgreSQL database @@ -329,8 +332,7 @@ def test_postgresql(): # Connect once to create the database; reconnect to access the # new database. - conn_create = psycopg.connect(database = 'template1', user = db_user, - host = db_host) + conn_create = psycopg.connect(database='template1', user=db_user, host=db_host) conn_create.autocommit() # Create the test database. @@ -339,8 +341,7 @@ def test_postgresql(): conn_create.close() # Connect to the test database. - conn_test = psycopg.connect(database = db_name, user = db_user, - host = db_host) + conn_test = psycopg.connect(database=db_name, user=db_user, host=db_host) # OK, we're in the right environment. Create the store # instance and create the tables. @@ -361,31 +362,33 @@ def test_postgresql(): time.sleep(1) # Remove the database now that the test is over. - conn_remove = psycopg.connect(database = 'template1', user = db_user, - host = db_host) + conn_remove = psycopg.connect(database='template1', user=db_user, host=db_host) conn_remove.autocommit() cursor = conn_remove.cursor() cursor.execute('DROP DATABASE %s;' % (db_name,)) conn_remove.close() + def test_memstore(): from openid.store import memstore testStore(memstore.MemoryStore()) + test_functions = [ test_filestore, test_sqlite, test_mysql, test_postgresql, test_memstore, - ] +] + def pyUnitTests(): tests = map(unittest.FunctionTestCase, test_functions) - load = unittest.defaultTestLoader.loadTestsFromTestCase return unittest.TestSuite(tests) + if __name__ == '__main__': import sys suite = pyUnitTests() diff --git a/openid/test/support.py b/openid/test/support.py index d61973c6..e864c899 100644 --- a/openid/test/support.py +++ b/openid/test/support.py @@ -7,7 +7,7 @@ class TestHandler(BufferingHandler): def __init__(self, messages): BufferingHandler.__init__(self, 0) - self.messages = messages + self.messages = messages def shouldFlush(self): return False @@ -15,6 +15,7 @@ def shouldFlush(self): def emit(self, record): self.messages.append(record) + class OpenIDTestMixin(object): def failUnlessOpenIDValueEquals(self, msg, key, expected, ns=None): if ns is None: @@ -33,22 +34,23 @@ def failIfOpenIDKeyExists(self, msg, key, ns=None): error_message = 'openid.%s unexpectedly present: %s' % (key, actual) self.failIf(actual is not None, error_message) + class CatchLogs(object): def setUp(self): - self.messages = [] - root_logger = logging.getLogger() - self.old_log_level = root_logger.getEffectiveLevel() - root_logger.setLevel(logging.DEBUG) + self.messages = [] + root_logger = logging.getLogger() + self.old_log_level = root_logger.getEffectiveLevel() + root_logger.setLevel(logging.DEBUG) - self.handler = TestHandler(self.messages) - formatter = logging.Formatter("%(message)s [%(asctime)s - %(name)s - %(levelname)s]") - self.handler.setFormatter(formatter) - root_logger.addHandler(self.handler) + self.handler = TestHandler(self.messages) + formatter = logging.Formatter("%(message)s [%(asctime)s - %(name)s - %(levelname)s]") + self.handler.setFormatter(formatter) + root_logger.addHandler(self.handler) def tearDown(self): root_logger = logging.getLogger() - root_logger.removeHandler(self.handler) - root_logger.setLevel(self.old_log_level) + root_logger.removeHandler(self.handler) + root_logger.setLevel(self.old_log_level) def failUnlessLogMatches(self, *prefixes): """ @@ -58,14 +60,10 @@ def failUnlessLogMatches(self, *prefixes): messages. """ messages = [r.getMessage() for r in self.messages] - assert len(prefixes) == len(messages), \ - "Expected log prefixes %r, got %r" % (prefixes, - messages) - - for prefix, message in zip(prefixes, messages): - assert message.startswith(prefix), \ - "Expected log prefixes %r, got %r" % (prefixes, - messages) + assert len(prefixes) == len(messages), "Expected log prefixes %r, got %r" % (prefixes, messages) + + for prefix, msg in zip(prefixes, messages): + assert msg.startswith(prefix), "Expected log prefixes %r, got %r" % (prefixes, messages) def failUnlessLogEmpty(self): self.failUnlessLogMatches() diff --git a/openid/test/test_accept.py b/openid/test/test_accept.py index 547e42a6..b8af670e 100644 --- a/openid/test/test_accept.py +++ b/openid/test/test_accept.py @@ -17,6 +17,7 @@ def getTestData(): i += 1 return lines + def chunk(lines): """Return groups of lines separated by whitespace or comments @@ -38,6 +39,7 @@ def chunk(lines): return chunks + def parseLines(chunk): """Take the given chunk of lines and turn it into a test data dictionary @@ -51,6 +53,7 @@ def parseLines(chunk): return items + def parseAvailable(available_text): """Parse an Available: line's data @@ -58,6 +61,7 @@ def parseAvailable(available_text): """ return [s.strip() for s in available_text.split(',')] + def parseExpected(expected_text): """Parse an Expected: line's data @@ -78,6 +82,7 @@ def parseExpected(expected_text): return expected + class MatchAcceptTest(unittest.TestCase): def __init__(self, descr, accept_header, available, expected): unittest.TestCase.__init__(self) @@ -94,6 +99,7 @@ def runTest(self): actual = accept.matchTypes(accepted, self.available) self.failUnlessEqual(self.expected, actual) + def pyUnitTests(): lines = getTestData() chunks = chunk(lines) @@ -107,7 +113,7 @@ def pyUnitTests(): lnos.append(lno) try: available = parseAvailable(avail_data) - except: + except Exception: print 'On line', lno raise @@ -115,7 +121,7 @@ def pyUnitTests(): lnos.append(lno) try: expected = parseExpected(exp_data) - except: + except Exception: print 'On line', lno raise @@ -124,6 +130,7 @@ def pyUnitTests(): cases.append(case) return unittest.TestSuite(cases) + if __name__ == '__main__': runner = unittest.TextTestRunner() runner.run(pyUnitTests()) diff --git a/openid/test/test_association.py b/openid/test/test_association.py index 86c2883d..929d5b6e 100644 --- a/openid/test/test_association.py +++ b/openid/test/test_association.py @@ -1,14 +1,11 @@ import time import unittest -import warnings -from openid import association, cryptutil -from openid.consumer.consumer import (DiffieHellmanSHA1ConsumerSession, DiffieHellmanSHA256ConsumerSession, - PlainTextConsumerSession) +from openid import association +from openid.consumer.consumer import DiffieHellmanSHA1ConsumerSession, PlainTextConsumerSession from openid.dh import DiffieHellman from openid.message import BARE_NS, OPENID2_NS, OPENID_NS, Message -from openid.server.server import (DiffieHellmanSHA1ServerSession, DiffieHellmanSHA256ServerSession, - PlainTextServerSession) +from openid.server.server import DiffieHellmanSHA1ServerSession, PlainTextServerSession from openid.test import datadriven @@ -27,25 +24,24 @@ def test_roundTrip(self): self.failUnlessEqual(assoc.assoc_type, assoc2.assoc_type) - - def createNonstandardConsumerDH(): nonstandard_dh = DiffieHellman(1315291, 2) return DiffieHellmanSHA1ConsumerSession(nonstandard_dh) + class DiffieHellmanSessionTest(datadriven.DataDrivenTestCase): secrets = [ '\x00' * 20, '\xff' * 20, ' ' * 20, 'This is a secret....', - ] + ] session_factories = [ (DiffieHellmanSHA1ConsumerSession, DiffieHellmanSHA1ServerSession), (createNonstandardConsumerDH, DiffieHellmanSHA1ServerSession), (PlainTextConsumerSession, PlainTextServerSession), - ] + ] def generateCases(cls): return [(c, s, sec) @@ -69,7 +65,6 @@ def runOneTest(self): self.failUnlessEqual(self.secret, check_secret) - class TestMakePairs(unittest.TestCase): """Check the key-value formatting methods of associations. """ @@ -81,29 +76,26 @@ def setUp(self): 'identifier': '=example', 'signed': 'identifier,mode', 'sig': 'cephalopod', - }) + }) m.updateArgs(BARE_NS, {'xey': 'value'}) self.assoc = association.Association.fromExpiresIn( 3600, '{sha1}', 'very_secret', "HMAC-SHA1") - def testMakePairs(self): """Make pairs using the OpenID 1.x type signed list.""" pairs = self.assoc._makePairs(self.message) expected = [ ('identifier', '=example'), ('mode', 'id_res'), - ] + ] self.failUnlessEqual(pairs, expected) - class TestMac(unittest.TestCase): def setUp(self): self.pairs = [('key1', 'value1'), ('key2', 'value2')] - def test_sha1(self): assoc = association.Association.fromExpiresIn( 3600, '{sha1}', 'very_secret', "HMAC-SHA1") @@ -121,7 +113,6 @@ def test_sha256(self): self.failUnlessEqual(sig, expected) - class TestMessageSigning(unittest.TestCase): def setUp(self): self.message = m = Message(OPENID2_NS) @@ -132,7 +123,6 @@ def setUp(self): 'openid.identifier': '=example', 'xey': 'value'} - def test_signSHA1(self): assoc = association.Association.fromExpiresIn( 3600, '{sha1}', 'very_secret', "HMAC-SHA1") @@ -170,6 +160,7 @@ def test_aintGotSignedList(self): def pyUnitTests(): return datadriven.loadTests(__name__) + if __name__ == '__main__': suite = pyUnitTests() runner = unittest.TextTestRunner() diff --git a/openid/test/test_association_response.py b/openid/test/test_association_response.py index 11161fbb..79be68c7 100644 --- a/openid/test/test_association_response.py +++ b/openid/test/test_association_response.py @@ -5,10 +5,9 @@ """ import unittest -from openid import oidutil -from openid.consumer.consumer import DiffieHellmanSHA1ConsumerSession, GenericConsumer, ProtocolError +from openid.consumer.consumer import GenericConsumer, ProtocolError from openid.consumer.discover import OPENID_1_1_TYPE, OPENID_2_0_TYPE, OpenIDServiceEndpoint -from openid.message import OPENID2_NS, OPENID_NS, Message, no_default +from openid.message import OPENID2_NS, OPENID_NS, Message from openid.server.server import DiffieHellmanSHA1ServerSession from openid.store import memstore from openid.test.test_consumer import CatchLogs @@ -16,11 +15,12 @@ # Some values we can use for convenience (see mkAssocResponse) association_response_values = { 'expires_in': '1000', - 'assoc_handle':'a handle', - 'assoc_type':'a type', - 'session_type':'a session type', - 'ns':OPENID2_NS, - } + 'assoc_handle': 'a handle', + 'assoc_type': 'a type', + 'session_type': 'a session type', + 'ns': OPENID2_NS, +} + def mkAssocResponse(*keys): """Build an association response message that contains the @@ -32,6 +32,7 @@ def mkAssocResponse(*keys): args = dict([(key, association_response_values[key]) for key in keys]) return Message.fromOpenIDArgs(args) + class BaseAssocTest(CatchLogs, unittest.TestCase): def setUp(self): CatchLogs.setUp(self) @@ -42,12 +43,13 @@ def setUp(self): def failUnlessProtocolError(self, str_prefix, func, *args, **kwargs): try: result = func(*args, **kwargs) - except ProtocolError, e: + except ProtocolError as e: message = 'Expected prefix %r, got %r' % (str_prefix, e[0]) self.failUnless(e[0].startswith(str_prefix), message) else: self.fail('Expected ProtocolError, got %r' % (result,)) + def mkExtractAssocMissingTest(keys): """Factory function for creating test methods for generating missing field tests. @@ -77,6 +79,7 @@ def test(self): return test + class TestExtractAssociationMissingFieldsOpenID2(BaseAssocTest): """Test for returning an error upon missing fields in association responses for OpenID 2""" @@ -95,6 +98,7 @@ class TestExtractAssociationMissingFieldsOpenID2(BaseAssocTest): test_missingSessionType_openid2 = mkExtractAssocMissingTest( ['expires_in', 'assoc_handle', 'assoc_type', 'ns']) + class TestExtractAssociationMissingFieldsOpenID1(BaseAssocTest): """Test for returning an error upon missing fields in association responses for OpenID 2""" @@ -110,11 +114,13 @@ class TestExtractAssociationMissingFieldsOpenID1(BaseAssocTest): test_missingAssocType_openid1 = mkExtractAssocMissingTest( ['expires_in', 'assoc_handle']) + class DummyAssocationSession(object): def __init__(self, session_type, allowed_assoc_types=()): self.session_type = session_type self.allowed_assoc_types = allowed_assoc_types + class ExtractAssociationSessionTypeMismatch(BaseAssocTest): def mkTest(requested_session_type, response_session_type, openid1=False): def test(self): @@ -124,48 +130,47 @@ def test(self): keys.remove('ns') msg = mkAssocResponse(*keys) msg.setArg(OPENID_NS, 'session_type', response_session_type) - self.failUnlessProtocolError('Session type mismatch', - self.consumer._extractAssociation, msg, assoc_session) + self.failUnlessProtocolError('Session type mismatch', self.consumer._extractAssociation, msg, assoc_session) return test test_typeMismatchNoEncBlank_openid2 = mkTest( requested_session_type='no-encryption', response_session_type='', - ) + ) test_typeMismatchDHSHA1NoEnc_openid2 = mkTest( requested_session_type='DH-SHA1', response_session_type='no-encryption', - ) + ) test_typeMismatchDHSHA256NoEnc_openid2 = mkTest( requested_session_type='DH-SHA256', response_session_type='no-encryption', - ) + ) test_typeMismatchNoEncDHSHA1_openid2 = mkTest( requested_session_type='no-encryption', response_session_type='DH-SHA1', - ) + ) test_typeMismatchDHSHA1NoEnc_openid1 = mkTest( requested_session_type='DH-SHA1', response_session_type='DH-SHA256', openid1=True, - ) + ) test_typeMismatchDHSHA256NoEnc_openid1 = mkTest( requested_session_type='DH-SHA256', response_session_type='DH-SHA1', openid1=True, - ) + ) test_typeMismatchNoEncDHSHA1_openid1 = mkTest( requested_session_type='no-encryption', response_session_type='DH-SHA1', openid1=True, - ) + ) class TestOpenID1AssociationResponseSessionType(BaseAssocTest): @@ -174,6 +179,7 @@ def mkTest(expected_session_type, session_type_value): be used if the OpenID 1 response to an associate call sets the 'session_type' field to `session_type_value` """ + def test(self): self._doTest(expected_session_type, session_type_value) self.failUnlessLogEmpty() @@ -201,25 +207,25 @@ def _doTest(self, expected_session_type, session_type_value): test_none = mkTest( session_type_value=None, expected_session_type='no-encryption', - ) + ) test_empty = mkTest( session_type_value='', expected_session_type='no-encryption', - ) + ) # This one's different because it expects log messages def test_explicitNoEncryption(self): self._doTest( session_type_value='no-encryption', expected_session_type='no-encryption', - ) + ) self.failUnlessLogMatches('OpenID server sent "no-encryption"') test_dhSHA1 = mkTest( session_type_value='DH-SHA1', expected_session_type='DH-SHA1', - ) + ) # DH-SHA256 is not a valid session type for OpenID1, but this # function does not test that. This is mostly just to make sure @@ -229,7 +235,8 @@ def test_explicitNoEncryption(self): test_dhSHA256 = mkTest( session_type_value='DH-SHA256', expected_session_type='DH-SHA256', - ) + ) + class DummyAssociationSession(object): secret = "shh! don't tell!" @@ -243,6 +250,7 @@ def extractSecret(self, message): self.extract_secret_called = True return self.secret + class TestInvalidFields(BaseAssocTest): def setUp(self): BaseAssocTest.setUp(self) @@ -256,11 +264,11 @@ def setUp(self): # These arguments should all be valid self.assoc_response = Message.fromOpenIDArgs({ 'expires_in': '1000', - 'assoc_handle':self.assoc_handle, - 'assoc_type':self.assoc_type, - 'session_type':self.session_type, - 'ns':OPENID2_NS, - }) + 'assoc_handle': self.assoc_handle, + 'assoc_type': self.assoc_type, + 'session_type': self.session_type, + 'ns': OPENID2_NS, + }) self.assoc_session = DummyAssociationSession() @@ -283,15 +291,13 @@ def test_badAssocType(self): # for the given session. self.assoc_session.allowed_assoc_types = [] self.failUnlessProtocolError('Unsupported assoc_type for session', - self.consumer._extractAssociation, - self.assoc_response, self.assoc_session) + self.consumer._extractAssociation, self.assoc_response, self.assoc_session) def test_badExpiresIn(self): # Invalid value for expires_in should cause failure self.assoc_response.setArg(OPENID_NS, 'expires_in', 'forever') self.failUnlessProtocolError('Invalid expires_in', - self.consumer._extractAssociation, - self.assoc_response, self.assoc_session) + self.consumer._extractAssociation, self.assoc_response, self.assoc_session) # XXX: This is what causes most of the imports in this file. It is @@ -334,5 +340,4 @@ def test_openid2success(self): def test_badDHValues(self): sess, server_resp = self._setUpDH() server_resp.setArg(OPENID_NS, 'enc_mac_key', '\x00\x00\x00') - self.failUnlessProtocolError('Malformed response for', - self.consumer._extractAssociation, server_resp, sess) + self.failUnlessProtocolError('Malformed response for', self.consumer._extractAssociation, server_resp, sess) diff --git a/openid/test/test_auth_request.py b/openid/test/test_auth_request.py index 1419ab54..c92ccb31 100644 --- a/openid/test/test_auth_request.py +++ b/openid/test/test_auth_request.py @@ -1,4 +1,3 @@ -import cgi import unittest from openid import message @@ -21,9 +20,11 @@ def getLocalID(self): def isOPIdentifier(self): return self.is_op_identifier + class DummyAssoc(object): handle = "assoc-handle" + class AuthRequestTestMixin(support.OpenIDTestMixin): """Mixin for AuthRequest tests for OpenID 1 and 2; DON'T add unittest.TestCase as a base class here.""" @@ -102,6 +103,7 @@ def test_standard(self): self.failUnlessHasIdentifiers( msg, self.endpoint.local_id, self.endpoint.claimed_id) + class TestAuthRequestOpenID2(AuthRequestTestMixin, unittest.TestCase): preferred_namespace = message.OPENID2_NS @@ -152,13 +154,10 @@ def test_opIdentifierSendsIdentifierSelect(self): self.failUnlessHasIdentifiers( msg, message.IDENTIFIER_SELECT, message.IDENTIFIER_SELECT) + class TestAuthRequestOpenID1(AuthRequestTestMixin, unittest.TestCase): preferred_namespace = message.OPENID1_NS - def setUpEndpoint(self): - TestAuthRequestBase.setUpEndpoint(self) - self.endpoint.preferred_namespace = message.OPENID1_NS - def failUnlessHasIdentifiers(self, msg, op_specific_id, claimed_id): """Make sure claimed_is is *absent* in request.""" self.failUnlessOpenIDValueEquals(msg, 'identity', op_specific_id) @@ -195,13 +194,16 @@ def test_identifierSelect(self): self.failUnlessEqual(message.IDENTIFIER_SELECT, msg.getArg(message.OPENID1_NS, 'identity')) + class TestAuthRequestOpenID1Immediate(TestAuthRequestOpenID1): immediate = True expected_mode = 'checkid_immediate' + class TestAuthRequestOpenID2Immediate(TestAuthRequestOpenID2): immediate = True expected_mode = 'checkid_immediate' + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_ax.py b/openid/test/test_ax.py index 28f90fab..80be5c8c 100644 --- a/openid/test/test_ax.py +++ b/openid/test/test_ax.py @@ -13,10 +13,12 @@ class BogusAXMessage(ax.AXMessage): getExtensionArgs = ax.AXMessage._newArgs + class DummyRequest(object): def __init__(self, message): self.message = message + class AXMessageTest(unittest.TestCase): def setUp(self): self.bax = BogusAXMessage() @@ -24,10 +26,10 @@ def setUp(self): def test_checkMode(self): check = self.bax._checkMode self.failUnlessRaises(ax.NotAXMessage, check, {}) - self.failUnlessRaises(ax.AXError, check, {'mode':'fetch_request'}) + self.failUnlessRaises(ax.AXError, check, {'mode': 'fetch_request'}) # does not raise an exception when the mode is right - check({'mode':self.bax.mode}) + check({'mode': self.bax.mode}) def test_checkMode_newArgs(self): """_newArgs generates something that has the correct mode""" @@ -80,6 +82,7 @@ def test_two(self): uris = ax.toTypeURIs(self.aliases, ','.join([alias1, alias2])) self.failUnlessEqual([uri1, uri2], uris) + class ParseAXValuesTest(unittest.TestCase): """Testing AXKeyValueMessage.parseExtensionArgs.""" @@ -97,27 +100,27 @@ def test_emptyIsValid(self): self.failUnlessAXValues({}, {}) def test_missingValueForAliasExplodes(self): - self.failUnlessAXKeyError({'type.foo':'urn:foo'}) + self.failUnlessAXKeyError({'type.foo': 'urn:foo'}) def test_countPresentButNotValue(self): - self.failUnlessAXKeyError({'type.foo':'urn:foo', - 'count.foo':'1'}) + self.failUnlessAXKeyError({'type.foo': 'urn:foo', + 'count.foo': '1'}) def test_invalidCountValue(self): msg = ax.FetchRequest() self.failUnlessRaises(ax.AXError, msg.parseExtensionArgs, - {'type.foo':'urn:foo', - 'count.foo':'bogus'}) + {'type.foo': 'urn:foo', + 'count.foo': 'bogus'}) def test_requestUnlimitedValues(self): msg = ax.FetchRequest() msg.parseExtensionArgs( - {'mode':'fetch_request', - 'required':'foo', - 'type.foo':'urn:foo', - 'count.foo':ax.UNLIMITED_VALUES}) + {'mode': 'fetch_request', + 'required': 'foo', + 'type.foo': 'urn:foo', + 'count.foo': ax.UNLIMITED_VALUES}) attrs = list(msg.iterAttrs()) foo = attrs[0] @@ -135,20 +138,20 @@ def test_longAlias(self): {'type.%s' % (alias,): 'urn:foo', 'count.%s' % (alias,): '1', 'value.%s.1' % (alias,): 'first'} - ) + ) def test_invalidAlias(self): types = [ ax.AXKeyValueMessage, ax.FetchRequest - ] + ] inputs = [ - {'type.a.b':'urn:foo', - 'count.a.b':'1'}, - {'type.a,b':'urn:foo', - 'count.a,b':'1'}, - ] + {'type.a.b': 'urn:foo', + 'count.a.b': '1'}, + {'type.a,b': 'urn:foo', + 'count.a,b': '1'}, + ] for typ in types: for input in inputs: @@ -158,37 +161,37 @@ def test_invalidAlias(self): def test_countPresentAndIsZero(self): self.failUnlessAXValues( - {'type.foo':'urn:foo', - 'count.foo':'0', - }, {'urn:foo':[]}) + {'type.foo': 'urn:foo', + 'count.foo': '0', + }, {'urn:foo': []}) def test_singletonEmpty(self): self.failUnlessAXValues( - {'type.foo':'urn:foo', - 'value.foo':'', - }, {'urn:foo':[]}) + {'type.foo': 'urn:foo', + 'value.foo': '', + }, {'urn:foo': []}) def test_doubleAlias(self): self.failUnlessAXKeyError( - {'type.foo':'urn:foo', - 'value.foo':'', - 'type.bar':'urn:foo', - 'value.bar':'', + {'type.foo': 'urn:foo', + 'value.foo': '', + 'type.bar': 'urn:foo', + 'value.bar': '', }) def test_doubleSingleton(self): self.failUnlessAXValues( - {'type.foo':'urn:foo', - 'value.foo':'', - 'type.bar':'urn:bar', - 'value.bar':'', - }, {'urn:foo':[], 'urn:bar':[]}) + {'type.foo': 'urn:foo', + 'value.foo': '', + 'type.bar': 'urn:bar', + 'value.bar': '', + }, {'urn:foo': [], 'urn:bar': []}) def test_singletonValue(self): self.failUnlessAXValues( - {'type.foo':'urn:foo', - 'value.foo':'Westfall', - }, {'urn:foo':['Westfall']}) + {'type.foo': 'urn:foo', + 'value.foo': 'Westfall', + }, {'urn:foo': ['Westfall']}) class FetchRequestTest(unittest.TestCase): @@ -197,7 +200,6 @@ def setUp(self): self.type_a = 'http://janrain.example.com/a' self.alias_a = 'a' - def test_mode(self): self.failUnlessEqual(self.msg.mode, 'fetch_request') @@ -230,14 +232,14 @@ def test_addTwice(self): def test_getExtensionArgs_empty(self): expected_args = { - 'mode':'fetch_request', - } + 'mode': 'fetch_request', + } self.failUnlessEqual(expected_args, self.msg.getExtensionArgs()) def test_getExtensionArgs_noAlias(self): attr = ax.AttrInfo( - type_uri = 'type://of.transportation', - ) + type_uri='type://of.transportation', + ) self.msg.add(attr) ax_args = self.msg.getExtensionArgs() for k, v in ax_args.iteritems(): @@ -248,32 +250,32 @@ def test_getExtensionArgs_noAlias(self): self.fail("Didn't find the type definition") self.failUnlessExtensionArgs({ - 'type.' + alias:attr.type_uri, - 'if_available':alias, - }) + 'type.' + alias: attr.type_uri, + 'if_available': alias, + }) def test_getExtensionArgs_alias_if_available(self): attr = ax.AttrInfo( - type_uri = 'type://of.transportation', - alias = 'transport', - ) + type_uri='type://of.transportation', + alias='transport', + ) self.msg.add(attr) self.failUnlessExtensionArgs({ - 'type.' + attr.alias:attr.type_uri, - 'if_available':attr.alias, - }) + 'type.' + attr.alias: attr.type_uri, + 'if_available': attr.alias, + }) def test_getExtensionArgs_alias_req(self): attr = ax.AttrInfo( - type_uri = 'type://of.transportation', - alias = 'transport', - required = True, - ) + type_uri='type://of.transportation', + alias='transport', + required=True, + ) self.msg.add(attr) self.failUnlessExtensionArgs({ - 'type.' + attr.alias:attr.type_uri, - 'required':attr.alias, - }) + 'type.' + attr.alias: attr.type_uri, + 'required': attr.alias, + }) def failUnlessExtensionArgs(self, expected_args): """Make sure that getExtensionArgs has the expected result @@ -293,18 +295,18 @@ def test_getRequiredAttrs_empty(self): def test_parseExtensionArgs_extraType(self): extension_args = { - 'mode':'fetch_request', - 'type.' + self.alias_a:self.type_a, - } + 'mode': 'fetch_request', + 'type.' + self.alias_a: self.type_a, + } self.failUnlessRaises(ValueError, self.msg.parseExtensionArgs, extension_args) def test_parseExtensionArgs(self): extension_args = { - 'mode':'fetch_request', - 'type.' + self.alias_a:self.type_a, - 'if_available':self.alias_a - } + 'mode': 'fetch_request', + 'type.' + self.alias_a: self.type_a, + 'if_available': self.alias_a + } self.msg.parseExtensionArgs(extension_args) self.failUnless(self.type_a in self.msg) self.failUnlessEqual([self.type_a], list(self.msg)) @@ -317,37 +319,37 @@ def test_parseExtensionArgs(self): def test_extensionArgs_idempotent(self): extension_args = { - 'mode':'fetch_request', - 'type.' + self.alias_a:self.type_a, - 'if_available':self.alias_a - } + 'mode': 'fetch_request', + 'type.' + self.alias_a: self.type_a, + 'if_available': self.alias_a + } self.msg.parseExtensionArgs(extension_args) self.failUnlessEqual(extension_args, self.msg.getExtensionArgs()) self.failIf(self.msg.requested_attributes[self.type_a].required) def test_extensionArgs_idempotent_count_required(self): extension_args = { - 'mode':'fetch_request', - 'type.' + self.alias_a:self.type_a, - 'count.' + self.alias_a:'2', - 'required':self.alias_a - } + 'mode': 'fetch_request', + 'type.' + self.alias_a: self.type_a, + 'count.' + self.alias_a: '2', + 'required': self.alias_a + } self.msg.parseExtensionArgs(extension_args) self.failUnlessEqual(extension_args, self.msg.getExtensionArgs()) self.failUnless(self.msg.requested_attributes[self.type_a].required) def test_extensionArgs_count1(self): extension_args = { - 'mode':'fetch_request', - 'type.' + self.alias_a:self.type_a, - 'count.' + self.alias_a:'1', - 'if_available':self.alias_a, - } + 'mode': 'fetch_request', + 'type.' + self.alias_a: self.type_a, + 'count.' + self.alias_a: '1', + 'if_available': self.alias_a, + } extension_args_norm = { - 'mode':'fetch_request', - 'type.' + self.alias_a:self.type_a, - 'if_available':self.alias_a, - } + 'mode': 'fetch_request', + 'type.' + self.alias_a: self.type_a, + 'if_available': self.alias_a, + } self.msg.parseExtensionArgs(extension_args) self.failUnlessEqual(extension_args_norm, self.msg.getExtensionArgs()) @@ -358,7 +360,7 @@ def test_openidNoRealm(self): 'ns.ax': ax.AXMessage.ns_uri, 'ax.update_url': 'http://different.site/path', 'ax.mode': 'fetch_request', - }) + }) self.failUnlessRaises(ax.AXError, ax.FetchRequest.fromOpenIDRequest, DummyRequest(openid_req_msg)) @@ -371,7 +373,7 @@ def test_openidUpdateURLVerificationError(self): 'ns.ax': ax.AXMessage.ns_uri, 'ax.update_url': 'http://different.site/path', 'ax.mode': 'fetch_request', - }) + }) self.failUnlessRaises(ax.AXError, ax.FetchRequest.fromOpenIDRequest, @@ -385,9 +387,9 @@ def test_openidUpdateURLVerificationSuccess(self): 'ns.ax': ax.AXMessage.ns_uri, 'ax.update_url': 'http://example.com/realm/update_path', 'ax.mode': 'fetch_request', - }) + }) - fr = ax.FetchRequest.fromOpenIDRequest(DummyRequest(openid_req_msg)) + ax.FetchRequest.fromOpenIDRequest(DummyRequest(openid_req_msg)) def test_openidUpdateURLVerificationSuccessReturnTo(self): openid_req_msg = Message.fromOpenIDArgs({ @@ -397,16 +399,16 @@ def test_openidUpdateURLVerificationSuccessReturnTo(self): 'ns.ax': ax.AXMessage.ns_uri, 'ax.update_url': 'http://example.com/realm/update_path', 'ax.mode': 'fetch_request', - }) + }) - fr = ax.FetchRequest.fromOpenIDRequest(DummyRequest(openid_req_msg)) + ax.FetchRequest.fromOpenIDRequest(DummyRequest(openid_req_msg)) def test_fromOpenIDRequestWithoutExtension(self): """return None for an OpenIDRequest without AX paramaters.""" openid_req_msg = Message.fromOpenIDArgs({ 'mode': 'checkid_setup', 'ns': OPENID2_NS, - }) + }) oreq = DummyRequest(openid_req_msg) r = ax.FetchRequest.fromOpenIDRequest(oreq) self.failUnless(r is None, "%s is not None" % (r,)) @@ -420,7 +422,7 @@ def test_fromOpenIDRequestWithoutData(self): 'ns': OPENID2_NS, 'ns.ax': ax.AXMessage.ns_uri, 'ax.mode': 'fetch_request', - }) + }) oreq = DummyRequest(openid_req_msg) r = ax.FetchRequest.fromOpenIDRequest(oreq) self.failUnless(r is not None) @@ -440,14 +442,14 @@ def test_construct(self): def test_getExtensionArgs_empty(self): expected_args = { - 'mode':'fetch_response', - } + 'mode': 'fetch_response', + } self.failUnlessEqual(expected_args, self.msg.getExtensionArgs()) def test_getExtensionArgs_empty_request(self): expected_args = { - 'mode':'fetch_response', - } + 'mode': 'fetch_response', + } req = ax.FetchRequest() msg = ax.FetchResponse(request=req) self.failUnlessEqual(expected_args, msg.getExtensionArgs()) @@ -457,10 +459,10 @@ def test_getExtensionArgs_empty_request_some(self): alias = 'ext0' expected_args = { - 'mode':'fetch_response', + 'mode': 'fetch_response', 'type.%s' % (alias,): uri, 'count.%s' % (alias,): '0' - } + } req = ax.FetchRequest() req.add(ax.AttrInfo(uri)) msg = ax.FetchResponse(request=req) @@ -471,11 +473,11 @@ def test_updateUrlInResponse(self): alias = 'ext0' expected_args = { - 'mode':'fetch_response', + 'mode': 'fetch_response', 'update_url': self.request_update_url, 'type.%s' % (alias,): uri, 'count.%s' % (alias,): '0' - } + } req = ax.FetchRequest(update_url=self.request_update_url) req.add(ax.AttrInfo(uri)) msg = ax.FetchResponse(request=req) @@ -483,11 +485,11 @@ def test_updateUrlInResponse(self): def test_getExtensionArgs_some_request(self): expected_args = { - 'mode':'fetch_response', - 'type.' + self.alias_a:self.type_a, - 'value.' + self.alias_a + '.1':self.value_a, + 'mode': 'fetch_response', + 'type.' + self.alias_a: self.type_a, + 'value.' + self.alias_a + '.1': self.value_a, 'count.' + self.alias_a: '1' - } + } req = ax.FetchRequest() req.add(ax.AttrInfo(self.type_a, alias=self.alias_a)) msg = ax.FetchResponse(request=req) @@ -501,7 +503,6 @@ def test_getExtensionArgs_some_not_request(self): self.failUnlessRaises(KeyError, msg.getExtensionArgs) def test_getSingle_success(self): - req = ax.FetchRequest() self.msg.addValue(self.type_a, self.value_a) self.failUnlessEqual(self.value_a, self.msg.getSingle(self.type_a)) @@ -520,9 +521,10 @@ def test_fromSuccessResponseWithoutExtension(self): args = { 'mode': 'id_res', 'ns': OPENID2_NS, - } + } sf = ['openid.' + i for i in args.keys()] msg = Message.fromOpenIDArgs(args) + class Endpoint: claimed_id = 'http://invalid.' @@ -538,9 +540,10 @@ def test_fromSuccessResponseWithoutData(self): 'ns': OPENID2_NS, 'ns.ax': ax.AXMessage.ns_uri, 'ax.mode': 'fetch_response', - } + } sf = ['openid.' + i for i in args.keys()] msg = Message.fromOpenIDArgs(args) + class Endpoint: claimed_id = 'http://invalid.' @@ -558,12 +561,13 @@ def test_fromSuccessResponseWithData(self): 'ns.ax': ax.AXMessage.ns_uri, 'ax.update_url': 'http://example.com/realm/update_path', 'ax.mode': 'fetch_response', - 'ax.type.'+name: uri, - 'ax.count.'+name: '1', - 'ax.value.%s.1'%name: value, - } + 'ax.type.' + name: uri, + 'ax.count.' + name: '1', + 'ax.value.%s.1' % name: value, + } sf = ['openid.' + i for i in args.keys()] msg = Message.fromOpenIDArgs(args) + class Endpoint: claimed_id = 'http://invalid.' @@ -585,8 +589,8 @@ def test_construct(self): def test_getExtensionArgs_empty(self): args = self.msg.getExtensionArgs() expected_args = { - 'mode':'store_request', - } + 'mode': 'store_request', + } self.failUnlessEqual(expected_args, args) def test_getExtensionArgs_nonempty(self): @@ -596,27 +600,28 @@ def test_getExtensionArgs_nonempty(self): msg.setValues(self.type_a, ['foo', 'bar']) args = msg.getExtensionArgs() expected_args = { - 'mode':'store_request', + 'mode': 'store_request', 'type.' + self.alias_a: self.type_a, 'count.' + self.alias_a: '2', - 'value.%s.1' % (self.alias_a,):'foo', - 'value.%s.2' % (self.alias_a,):'bar', - } + 'value.%s.1' % (self.alias_a,): 'foo', + 'value.%s.2' % (self.alias_a,): 'bar', + } self.failUnlessEqual(expected_args, args) + class StoreResponseTest(unittest.TestCase): def test_success(self): msg = ax.StoreResponse() self.failUnless(msg.succeeded()) self.failIf(msg.error_message) - self.failUnlessEqual({'mode':'store_response_success'}, + self.failUnlessEqual({'mode': 'store_response_success'}, msg.getExtensionArgs()) def test_fail_nomsg(self): msg = ax.StoreResponse(False) self.failIf(msg.succeeded()) self.failIf(msg.error_message) - self.failUnlessEqual({'mode':'store_response_failure'}, + self.failUnlessEqual({'mode': 'store_response_failure'}, msg.getExtensionArgs()) def test_fail_msg(self): @@ -624,5 +629,5 @@ def test_fail_msg(self): msg = ax.StoreResponse(False, reason) self.failIf(msg.succeeded()) self.failUnlessEqual(reason, msg.error_message) - self.failUnlessEqual({'mode':'store_response_failure', - 'error':reason}, msg.getExtensionArgs()) + self.failUnlessEqual({'mode': 'store_response_failure', + 'error': reason}, msg.getExtensionArgs()) diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index acab7c04..05496638 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -2,9 +2,8 @@ import time import unittest import urlparse -import warnings -from openid import association, cryptutil, dh, fetchers, kvform, oidutil +from openid import association, cryptutil, fetchers, kvform, oidutil from openid.consumer.consumer import (CANCEL, FAILURE, SETUP_NEEDED, SUCCESS, AuthRequest, CancelResponse, Consumer, DiffieHellmanSHA1ConsumerSession, DiffieHellmanSHA256ConsumerSession, FailureResponse, GenericConsumer, PlainTextConsumerSession, ProtocolError, @@ -26,7 +25,8 @@ assocs = [ ('another 20-byte key.', 'Snarky'), ('\x00' * 20, 'Zeros'), - ] +] + def mkSuccess(endpoint, q): """Convenience function to create a SuccessResponse with the given @@ -34,13 +34,15 @@ def mkSuccess(endpoint, q): signed_list = ['openid.' + k for k in q.keys()] return SuccessResponse(endpoint, Message.fromOpenIDArgs(q), signed_list) + def parseQuery(qs): q = {} for (k, v) in cgi.parse_qsl(qs): - assert not q.has_key(k) + assert k not in q q[k] = v return q + def associate(qs, assoc_secret, assoc_handle): """Do the server's half of the associate call, using the given secret and handle.""" @@ -48,10 +50,10 @@ def associate(qs, assoc_secret, assoc_handle): assert q['openid.mode'] == 'associate' assert q['openid.assoc_type'] == 'HMAC-SHA1' reply_dict = { - 'assoc_type':'HMAC-SHA1', - 'assoc_handle':assoc_handle, - 'expires_in':'600', - } + 'assoc_type': 'HMAC-SHA1', + 'assoc_handle': assoc_handle, + 'expires_in': '600', + } if q.get('openid.session_type') == 'DH-SHA1': assert len(q) == 6 or len(q) == 4 @@ -86,8 +88,9 @@ def getAssociation(self, server_url, handle=None): class TestFetcher(object): - def __init__(self, user_url, user_page, (assoc_secret, assoc_handle)): - self.get_responses = {user_url:self.response(user_url, 200, user_page)} + def __init__(self, user_url, user_page, xxx_todo_changeme): + (assoc_secret, assoc_handle) = xxx_todo_changeme + self.get_responses = {user_url: self.response(user_url, 200, user_page)} self.assoc_secret = assoc_secret self.assoc_handle = assoc_handle self.num_assocs = 0 @@ -104,7 +107,7 @@ def fetch(self, url, body=None, headers=None): try: body.index('openid.mode=associate') except ValueError: - pass # fall through + pass # fall through else: assert body.find('DH-SHA1') != -1 response = associate( @@ -114,6 +117,7 @@ def fetch(self, url, body=None, headers=None): return self.response(url, 404, 'Not found') + def makeFastConsumerSession(): """ Create custom DH object so tests run quickly. @@ -121,9 +125,11 @@ def makeFastConsumerSession(): dh = DiffieHellman(100389557, 2) return DiffieHellmanSHA1ConsumerSession(dh) + def setConsumerSession(con): con.session_types = {'DH-SHA1': makeFastConsumerSession} + def _test_success(server_url, user_url, delegate_url, links, immediate=False): store = memstore.MemoryStore() if immediate: @@ -149,8 +155,6 @@ def run(): request = consumer.begin(endpoint) return_to = consumer_url - m = request.getMessage(trust_root, return_to, immediate) - redirect_url = request.redirectURL(trust_root, return_to, immediate) parsed = urlparse.urlparse(redirect_url) @@ -159,11 +163,11 @@ def run(): new_return_to = q['openid.return_to'] del q['openid.return_to'] assert q == { - 'openid.mode':mode, - 'openid.identity':delegate_url, - 'openid.trust_root':trust_root, - 'openid.assoc_handle':fetcher.assoc_handle, - }, (q, user_url, delegate_url, mode) + 'openid.mode': mode, + 'openid.identity': delegate_url, + 'openid.trust_root': trust_root, + 'openid.assoc_handle': fetcher.assoc_handle, + }, (q, user_url, delegate_url, mode) assert new_return_to.startswith(return_to) assert redirect_url.startswith(server_url) @@ -171,11 +175,11 @@ def run(): parsed = urlparse.urlparse(new_return_to) query = parseQuery(parsed[4]) query.update({ - 'openid.mode':'id_res', - 'openid.return_to':new_return_to, - 'openid.identity':delegate_url, - 'openid.assoc_handle':fetcher.assoc_handle, - }) + 'openid.mode': 'id_res', + 'openid.return_to': new_return_to, + 'openid.identity': delegate_url, + 'openid.assoc_handle': fetcher.assoc_handle, + }) assoc = store.getAssociation(server_url, fetcher.assoc_handle) @@ -207,6 +211,7 @@ def run(): consumer_url = 'http://consumer.example.com/' https_server_url = 'https://server.example.com/' + class TestSuccess(unittest.TestCase, CatchLogs): server_url = http_server_url user_url = 'http://www.example.com/user.html' @@ -284,10 +289,12 @@ def checkReturnTo(unused1, unused2): return True self.consumer._checkReturnTo = checkReturnTo complete = self.consumer.complete + def callCompleteWithoutReturnTo(message, endpoint): return complete(message, endpoint, None) self.consumer.complete = callCompleteWithoutReturnTo + class TestIdResCheckSignature(TestIdRes): def setUp(self): TestIdRes.setUp(self) @@ -302,22 +309,19 @@ def setUp(self): 'openid.assoc_handle': self.assoc.handle, 'openid.signed': 'mode,identity,assoc_handle,signed', 'frobboz': 'banzit', - }) - + }) def test_sign(self): # assoc_handle to assoc with good sig self.consumer._idResCheckSignature(self.message, self.endpoint.server_url) - def test_signFailsWithBadSig(self): self.message.setArg(OPENID_NS, 'sig', 'BAD SIGNATURE') self.failUnlessRaises( ProtocolError, self.consumer._idResCheckSignature, self.message, self.endpoint.server_url) - def test_stateless(self): # assoc_handle missing assoc, consumer._checkAuth returns goodthings self.message.setArg(OPENID_NS, "assoc_handle", "dumbHandle") @@ -364,11 +368,12 @@ def test_notAList(self): query = {'openid.mode': ['cancel']} try: r = Message.fromPostArgs(query) - except TypeError, err: + except TypeError as err: self.failUnless(str(err).find('values') != -1, err) else: self.fail("expected TypeError, got this instead: %s" % (r,)) + class TestComplete(TestIdRes): """Testing GenericConsumer.complete. @@ -404,9 +409,7 @@ def test_cancel_with_return_to(self): def test_error(self): msg = 'an error message' - message = Message.fromPostArgs({'openid.mode': 'error', - 'openid.error': msg, - }) + message = Message.fromPostArgs({'openid.mode': 'error', 'openid.error': msg}) self.disableReturnToChecking() r = self.consumer.complete(message, self.endpoint) self.failUnlessEqual(r.status, FAILURE) @@ -416,10 +419,7 @@ def test_error(self): def test_errorWithNoOptionalKeys(self): msg = 'an error message' contact = 'some contact info here' - message = Message.fromPostArgs({'openid.mode': 'error', - 'openid.error': msg, - 'openid.contact': contact, - }) + message = Message.fromPostArgs({'openid.mode': 'error', 'openid.error': msg, 'openid.contact': contact}) self.disableReturnToChecking() r = self.consumer.complete(message, self.endpoint) self.failUnlessEqual(r.status, FAILURE) @@ -432,10 +432,8 @@ def test_errorWithOptionalKeys(self): msg = 'an error message' contact = 'me' reference = 'support ticket' - message = Message.fromPostArgs({'openid.mode': 'error', - 'openid.error': msg, 'openid.reference': reference, - 'openid.contact': contact, 'openid.ns': OPENID2_NS, - }) + message = Message.fromPostArgs({'openid.mode': 'error', 'openid.error': msg, 'openid.reference': reference, + 'openid.contact': contact, 'openid.ns': OPENID2_NS}) r = self.consumer.complete(message, self.endpoint, None) self.failUnlessEqual(r.status, FAILURE) self.failUnless(r.identity_url == self.endpoint.claimed_id) @@ -458,7 +456,8 @@ def test_idResMissingField(self): message, self.endpoint, None) def test_idResURLMismatch(self): - class VerifiedError(Exception): pass + class VerifiedError(Exception): + pass def discoverAndVerify(claimed_id, _to_match_endpoints): raise VerifiedError @@ -483,6 +482,7 @@ def discoverAndVerify(claimed_id, _to_match_endpoints): self.failUnlessLogMatches('Error attempting to use stored', 'Attempting discovery') + class TestCompleteMissingSig(unittest.TestCase, CatchLogs): def setUp(self): @@ -503,18 +503,17 @@ def setUp(self): 'signed': 'identity,return_to,response_nonce,assoc_handle,claimed_id,op_endpoint', 'claimed_id': claimed_id, 'op_endpoint': self.server_url, - 'ns':OPENID2_NS, + 'ns': OPENID2_NS, }) self.endpoint = OpenIDServiceEndpoint() self.endpoint.server_url = self.server_url self.endpoint.claimed_id = claimed_id - self.consumer._checkReturnTo = lambda unused1, unused2 : True + self.consumer._checkReturnTo = lambda unused1, unused2: True def tearDown(self): CatchLogs.tearDown(self) - def test_idResMissingNoSigs(self): def _vrfy(resp_msg, endpoint=None): return endpoint @@ -523,7 +522,6 @@ def _vrfy(resp_msg, endpoint=None): r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessSuccess(r) - def test_idResNoIdentity(self): self.message.delArg(OPENID_NS, 'identity') self.message.delArg(OPENID_NS, 'claimed_id') @@ -532,37 +530,31 @@ def test_idResNoIdentity(self): r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessSuccess(r) - def test_idResMissingIdentitySig(self): self.message.setArg(OPENID_NS, 'signed', 'return_to,response_nonce,assoc_handle,claimed_id') r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessEqual(r.status, FAILURE) - def test_idResMissingReturnToSig(self): self.message.setArg(OPENID_NS, 'signed', 'identity,response_nonce,assoc_handle,claimed_id') r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessEqual(r.status, FAILURE) - def test_idResMissingAssocHandleSig(self): self.message.setArg(OPENID_NS, 'signed', 'identity,response_nonce,return_to,claimed_id') r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessEqual(r.status, FAILURE) - def test_idResMissingClaimedIDSig(self): self.message.setArg(OPENID_NS, 'signed', 'identity,response_nonce,return_to,assoc_handle') r = self.consumer.complete(self.message, self.endpoint, None) self.failUnlessEqual(r.status, FAILURE) - def failUnlessSuccess(self, response): if response.status != SUCCESS: self.fail("Non-successful response: %s" % (response,)) - class TestCheckAuthResponse(TestIdRes, CatchLogs): def setUp(self): CatchLogs.setUp(self) @@ -583,7 +575,7 @@ def _createAssoc(self): def test_goodResponse(self): """successful response to check_authentication""" - response = Message.fromOpenIDArgs({'is_valid':'true',}) + response = Message.fromOpenIDArgs({'is_valid': 'true'}) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failUnless(r) @@ -595,7 +587,7 @@ def test_missingAnswer(self): def test_badResponse(self): """check_authentication returns false when is_valid is false""" - response = Message.fromOpenIDArgs({'is_valid':'false',}) + response = Message.fromOpenIDArgs({'is_valid': 'false'}) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failIf(r) @@ -610,9 +602,9 @@ def test_badResponseInvalidate(self): """ self._createAssoc() response = Message.fromOpenIDArgs({ - 'is_valid':'false', - 'invalidate_handle':'handle', - }) + 'is_valid': 'false', + 'invalidate_handle': 'handle', + }) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failIf(r) self.failUnless( @@ -621,21 +613,21 @@ def test_badResponseInvalidate(self): def test_invalidateMissing(self): """invalidate_handle with a handle that is not present""" response = Message.fromOpenIDArgs({ - 'is_valid':'true', - 'invalidate_handle':'missing', - }) + 'is_valid': 'true', + 'invalidate_handle': 'missing', + }) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failUnless(r) self.failUnlessLogMatches( 'Received "invalidate_handle"' - ) + ) def test_invalidateMissing_noStore(self): """invalidate_handle with a handle that is not present""" response = Message.fromOpenIDArgs({ - 'is_valid':'true', - 'invalidate_handle':'missing', - }) + 'is_valid': 'true', + 'invalidate_handle': 'missing', + }) self.consumer.store = None r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failUnless(r) @@ -654,19 +646,20 @@ def test_invalidatePresent(self): """ self._createAssoc() response = Message.fromOpenIDArgs({ - 'is_valid':'true', - 'invalidate_handle':'handle', - }) + 'is_valid': 'true', + 'invalidate_handle': 'handle', + }) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failUnless(r) self.failUnless( self.consumer.store.getAssociation(self.server_url) is None) + class TestSetupNeeded(TestIdRes): def failUnlessSetupNeeded(self, expected_setup_url, message): try: self.consumer._checkSetupNeeded(message) - except SetupNeededError, why: + except SetupNeededError as why: self.failUnlessEqual(expected_setup_url, why.user_setup_url) else: self.fail("Expected to find an immediate-mode response") @@ -677,7 +670,7 @@ def test_setupNeededOpenID1(self): message = Message.fromPostArgs({ 'openid.mode': 'id_res', 'openid.user_setup_url': setup_url, - }) + }) self.failUnless(message.isOpenID1()) self.failUnlessSetupNeeded(setup_url, message) @@ -688,7 +681,7 @@ def test_setupNeededOpenID1_extra(self): 'openid.mode': 'id_res', 'openid.user_setup_url': setup_url, 'openid.identity': 'bogus', - }) + }) self.failUnless(message.isOpenID1()) self.failUnlessSetupNeeded(setup_url, message) @@ -703,9 +696,9 @@ def test_noSetupNeededOpenID1(self): def test_setupNeededOpenID2(self): message = Message.fromOpenIDArgs({ - 'mode':'setup_needed', - 'ns':OPENID2_NS, - }) + 'mode': 'setup_needed', + 'ns': OPENID2_NS, + }) self.failUnless(message.isOpenID2()) response = self.consumer.complete(message, None, None) self.failUnlessEqual('setup_needed', response.status) @@ -713,8 +706,8 @@ def test_setupNeededOpenID2(self): def test_setupNeededDoesntWorkForOpenID1(self): message = Message.fromOpenIDArgs({ - 'mode':'setup_needed', - }) + 'mode': 'setup_needed', + }) # No SetupNeededError raised self.consumer._checkSetupNeeded(message) @@ -725,15 +718,16 @@ def test_setupNeededDoesntWorkForOpenID1(self): def test_noSetupNeededOpenID2(self): message = Message.fromOpenIDArgs({ - 'mode':'id_res', - 'game':'puerto_rico', - 'ns':OPENID2_NS, - }) + 'mode': 'id_res', + 'game': 'puerto_rico', + 'ns': OPENID2_NS, + }) self.failUnless(message.isOpenID2()) # No SetupNeededError raised self.consumer._checkSetupNeeded(message) + class IdResCheckForFieldsTest(TestIdRes): def setUp(self): self.consumer = GenericConsumer(None) @@ -746,32 +740,32 @@ def test(self): return test test_openid1Success = mkSuccessTest( - {'return_to':'return', - 'assoc_handle':'assoc handle', - 'sig':'a signature', - 'identity':'someone', + {'return_to': 'return', + 'assoc_handle': 'assoc handle', + 'sig': 'a signature', + 'identity': 'someone', }, ['return_to', 'identity']) test_openid2Success = mkSuccessTest( - {'ns':OPENID2_NS, - 'return_to':'return', - 'assoc_handle':'assoc handle', - 'sig':'a signature', - 'op_endpoint':'my favourite server', - 'response_nonce':'use only once', + {'ns': OPENID2_NS, + 'return_to': 'return', + 'assoc_handle': 'assoc handle', + 'sig': 'a signature', + 'op_endpoint': 'my favourite server', + 'response_nonce': 'use only once', }, ['return_to', 'response_nonce', 'assoc_handle', 'op_endpoint']) test_openid2Success_identifiers = mkSuccessTest( - {'ns':OPENID2_NS, - 'return_to':'return', - 'assoc_handle':'assoc handle', - 'sig':'a signature', - 'claimed_id':'i claim to be me', - 'identity':'my server knows me as me', - 'op_endpoint':'my favourite server', - 'response_nonce':'use only once', + {'ns': OPENID2_NS, + 'return_to': 'return', + 'assoc_handle': 'assoc handle', + 'sig': 'a signature', + 'claimed_id': 'i claim to be me', + 'identity': 'my server knows me as me', + 'op_endpoint': 'my favourite server', + 'response_nonce': 'use only once', }, ['return_to', 'response_nonce', 'identity', 'claimed_id', 'assoc_handle', 'op_endpoint']) @@ -781,7 +775,7 @@ def test(self): message = Message.fromOpenIDArgs(openid_args) try: self.consumer._idResCheckForFields(message) - except ProtocolError, why: + except ProtocolError as why: self.failUnless(why[0].startswith('Missing required')) else: self.fail('Expected an error, but none occurred') @@ -792,53 +786,56 @@ def test(self): message = Message.fromOpenIDArgs(openid_args) try: self.consumer._idResCheckForFields(message) - except ProtocolError, why: + except ProtocolError as why: self.failUnless(why[0].endswith('not signed')) else: self.fail('Expected an error, but none occurred') return test test_openid1Missing_returnToSig = mkMissingSignedTest( - {'return_to':'return', - 'assoc_handle':'assoc handle', - 'sig':'a signature', - 'identity':'someone', - 'signed':'identity', + {'return_to': 'return', + 'assoc_handle': 'assoc handle', + 'sig': 'a signature', + 'identity': 'someone', + 'signed': 'identity', }) test_openid1Missing_identitySig = mkMissingSignedTest( - {'return_to':'return', - 'assoc_handle':'assoc handle', - 'sig':'a signature', - 'identity':'someone', - 'signed':'return_to' + {'return_to': 'return', + 'assoc_handle': 'assoc handle', + 'sig': 'a signature', + 'identity': 'someone', + 'signed': 'return_to' }) test_openid2Missing_opEndpointSig = mkMissingSignedTest( - {'ns':OPENID2_NS, - 'return_to':'return', - 'assoc_handle':'assoc handle', - 'sig':'a signature', - 'identity':'someone', - 'op_endpoint':'the endpoint', - 'signed':'return_to,identity,assoc_handle' + {'ns': OPENID2_NS, + 'return_to': 'return', + 'assoc_handle': 'assoc handle', + 'sig': 'a signature', + 'identity': 'someone', + 'op_endpoint': 'the endpoint', + 'signed': 'return_to,identity,assoc_handle' }) test_openid1MissingReturnTo = mkMissingFieldTest( - {'assoc_handle':'assoc handle', - 'sig':'a signature', - 'identity':'someone', + {'assoc_handle': 'assoc handle', + 'sig': 'a signature', + 'identity': 'someone', }) test_openid1MissingAssocHandle = mkMissingFieldTest( - {'return_to':'return', - 'sig':'a signature', - 'identity':'someone', + {'return_to': 'return', + 'sig': 'a signature', + 'identity': 'someone', }) # XXX: I could go on... -class CheckAuthHappened(Exception): pass + +class CheckAuthHappened(Exception): + pass + class CheckNonceVerifyTest(TestIdRes, CatchLogs): def setUp(self): @@ -869,24 +866,21 @@ def test_consumerNonceOpenID2(self): """OpenID 2 does not use consumer-generated nonce""" self.return_to = 'http://rt.unittest/?nonce=%s' % (mkNonce(),) self.response = Message.fromOpenIDArgs( - {'return_to': self.return_to, 'ns':OPENID2_NS}) + {'return_to': self.return_to, 'ns': OPENID2_NS}) self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) self.failUnlessLogEmpty() def test_serverNonce(self): """use server-generated nonce""" - self.response = Message.fromOpenIDArgs( - {'ns':OPENID2_NS, 'response_nonce': mkNonce(),}) + self.response = Message.fromOpenIDArgs({'ns': OPENID2_NS, 'response_nonce': mkNonce()}) self.consumer._idResCheckNonce(self.response, self.endpoint) self.failUnlessLogEmpty() def test_serverNonceOpenID1(self): """OpenID 1 does not use server-generated nonce""" self.response = Message.fromOpenIDArgs( - {'ns':OPENID1_NS, - 'return_to': 'http://return.to/', - 'response_nonce': mkNonce(),}) + {'ns': OPENID1_NS, 'return_to': 'http://return.to/', 'response_nonce': mkNonce()}) self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) self.failUnlessLogEmpty() @@ -905,38 +899,31 @@ def test_badNonce(self): nonce = mkNonce() stamp, salt = splitNonce(nonce) self.store.useNonce(self.server_url, stamp, salt) - self.response = Message.fromOpenIDArgs( - {'response_nonce': nonce, - 'ns':OPENID2_NS, - }) + self.response = Message.fromOpenIDArgs({'response_nonce': nonce, 'ns': OPENID2_NS}) self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) def test_successWithNoStore(self): """When there is no store, checking the nonce succeeds""" self.consumer.store = None - self.response = Message.fromOpenIDArgs( - {'response_nonce': mkNonce(), - 'ns':OPENID2_NS, - }) + self.response = Message.fromOpenIDArgs({'response_nonce': mkNonce(), 'ns': OPENID2_NS}) self.consumer._idResCheckNonce(self.response, self.endpoint) self.failUnlessLogEmpty() def test_tamperedNonce(self): """Malformed nonce""" - self.response = Message.fromOpenIDArgs( - {'ns':OPENID2_NS, - 'response_nonce':'malformed'}) + self.response = Message.fromOpenIDArgs({'ns': OPENID2_NS, 'response_nonce': 'malformed'}) self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) def test_missingNonce(self): """no nonce parameter on the return_to""" self.response = Message.fromOpenIDArgs( - {'return_to': self.return_to}) + {'return_to': self.return_to}) self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) + class CheckAuthDetectingConsumer(GenericConsumer): def _checkAuth(self, *args): raise CheckAuthHappened(args) @@ -946,6 +933,7 @@ def _idResCheckNonce(self, *args): when it asks.""" return True + class TestCheckAuthTriggered(TestIdRes, CatchLogs): consumer_class = CheckAuthDetectingConsumer @@ -956,12 +944,12 @@ def setUp(self): def test_checkAuthTriggered(self): message = Message.fromPostArgs({ - 'openid.return_to':self.return_to, - 'openid.identity':self.server_id, - 'openid.assoc_handle':'not_found', + 'openid.return_to': self.return_to, + 'openid.identity': self.server_id, + 'openid.assoc_handle': 'not_found', 'openid.sig': GOODSIG, 'openid.signed': 'identity,return_to', - }) + }) self.disableReturnToChecking() try: result = self.consumer._doIdRes(message, self.endpoint, None) @@ -981,12 +969,12 @@ def test_checkAuthTriggeredWithAssoc(self): self.store.storeAssociation(self.server_url, assoc) self.disableReturnToChecking() message = Message.fromPostArgs({ - 'openid.return_to':self.return_to, - 'openid.identity':self.server_id, - 'openid.assoc_handle':'not_found', + 'openid.return_to': self.return_to, + 'openid.identity': self.server_id, + 'openid.assoc_handle': 'not_found', 'openid.sig': GOODSIG, 'openid.signed': 'identity,return_to', - }) + }) try: result = self.consumer._doIdRes(message, self.endpoint, None) except CheckAuthHappened: @@ -1006,12 +994,12 @@ def test_expiredAssoc(self): self.store.storeAssociation(self.server_url, assoc) message = Message.fromPostArgs({ - 'openid.return_to':self.return_to, - 'openid.identity':self.server_id, - 'openid.assoc_handle':handle, + 'openid.return_to': self.return_to, + 'openid.identity': self.server_id, + 'openid.assoc_handle': handle, 'openid.sig': GOODSIG, 'openid.signed': 'identity,return_to', - }) + }) self.disableReturnToChecking() self.failUnlessRaises(ProtocolError, self.consumer._doIdRes, message, self.endpoint, None) @@ -1032,10 +1020,10 @@ def test_newerAssoc(self): self.store.storeAssociation(self.server_url, bad_assoc) query = { - 'return_to':self.return_to, - 'identity':self.server_id, - 'assoc_handle':good_handle, - } + 'return_to': self.return_to, + 'identity': self.server_id, + 'assoc_handle': good_handle, + } message = Message.fromOpenIDArgs(query) message = good_assoc.signMessage(message) @@ -1045,7 +1033,6 @@ def test_newerAssoc(self): self.failUnlessEqual(self.consumer_id, info.identity_url) - class TestReturnToArgs(unittest.TestCase): """Verifying the Return URL paramaters. From the specification "Verifying the Return URL":: @@ -1073,7 +1060,7 @@ def test_returnToArgsOkay(self): 'openid.mode': 'id_res', 'openid.return_to': 'http://example.com/?foo=bar', 'foo': 'bar', - } + } # no return value, success is assumed if there are no exceptions. self.consumer._verifyReturnToArgs(query) @@ -1082,7 +1069,7 @@ def test_returnToEmptyArg(self): 'openid.mode': 'id_res', 'openid.return_to': 'http://example.com/?foo=', 'foo': '', - } + } # no return value, success is assumed if there are no exceptions. self.consumer._verifyReturnToArgs(query) @@ -1091,7 +1078,7 @@ def test_returnToArgsUnexpectedArg(self): 'openid.mode': 'id_res', 'openid.return_to': 'http://example.com/', 'foo': 'bar', - } + } # no return value, success is assumed if there are no exceptions. self.failUnlessRaises(ProtocolError, self.consumer._verifyReturnToArgs, query) @@ -1100,7 +1087,7 @@ def test_returnToMismatch(self): query = { 'openid.mode': 'id_res', 'openid.return_to': 'http://example.com/?foo=bar', - } + } # fail, query has no key 'foo'. self.failUnlessRaises(ValueError, self.consumer._verifyReturnToArgs, query) @@ -1110,7 +1097,6 @@ def test_returnToMismatch(self): self.failUnlessRaises(ValueError, self.consumer._verifyReturnToArgs, query) - def test_noReturnTo(self): query = {'openid.mode': 'id_res'} self.failUnlessRaises(ValueError, @@ -1135,12 +1121,11 @@ def test_completeBadReturnTo(self): # Query args differ "http://some.url/path?foo=bar2", "http://some.url/path?foo2=bar", - ] + ] m = Message(OPENID1_NS) m.setArg(OPENID_NS, 'mode', 'cancel') m.setArg(BARE_NS, 'foo', 'bar') - endpoint = None for bad in bad_return_tos: m.setArg(OPENID_NS, 'return_to', bad) @@ -1156,12 +1141,12 @@ def test_completeGoodReturnTo(self): (return_to, {}), (return_to + "?another=arg", {(BARE_NS, 'another'): 'arg'}), (return_to + "?another=arg#fragment", {(BARE_NS, 'another'): 'arg'}), - ("HTTP"+return_to[4:], {}), - (return_to.replace('url','URL'), {}), + ("HTTP" + return_to[4:], {}), + (return_to.replace('url', 'URL'), {}), ("http://some.url:80/path", {}), ("http://some.url/p%61th", {}), ("http://some.url/./path", {}), - ] + ] endpoint = None @@ -1174,9 +1159,10 @@ def test_completeGoodReturnTo(self): m.setArg(OPENID_NS, 'return_to', good) result = self.consumer.complete(m, endpoint, return_to) - self.failUnless(isinstance(result, CancelResponse), \ + self.failUnless(isinstance(result, CancelResponse), "Expected CancelResponse, got %r for %s" % (result, good,)) + class MockFetcher(object): def __init__(self, response=None): self.response = response or HTTPResponse() @@ -1186,6 +1172,7 @@ def fetch(self, url, body=None, headers=None): self.fetches.append((url, body, headers)) return self.response + class ExceptionRaisingMockFetcher(object): class MyException(Exception): pass @@ -1193,15 +1180,17 @@ class MyException(Exception): def fetch(self, url, body=None, headers=None): raise self.MyException('mock fetcher exception') + class BadArgCheckingConsumer(GenericConsumer): def _makeKVPost(self, args, _): assert args == { - 'openid.mode':'check_authentication', - 'openid.signed':'foo', - 'openid.ns':OPENID1_NS - }, args + 'openid.mode': 'check_authentication', + 'openid.signed': 'foo', + 'openid.ns': OPENID1_NS + }, args return None + class TestCheckAuth(unittest.TestCase, CatchLogs): consumer_class = GenericConsumer @@ -1223,7 +1212,7 @@ def test_error(self): self.fetcher.response = HTTPResponse( "http://some_url", 404, {'Hea': 'der'}, 'blah:blah\n') query = {'openid.signed': 'stuff', - 'openid.stuff':'a value'} + 'openid.stuff': 'a value'} r = self.consumer._checkAuth(Message.fromPostArgs(query), http_server_url) self.failIf(r) @@ -1231,13 +1220,12 @@ def test_error(self): def test_bad_args(self): query = { - 'openid.signed':'foo', - 'closid.foo':'something', - } + 'openid.signed': 'foo', + 'closid.foo': 'something', + } consumer = BadArgCheckingConsumer(self.store) consumer._checkAuth(Message.fromPostArgs(query), 'does://not.matter') - def test_signedList(self): query = Message.fromOpenIDArgs({ 'mode': 'id_res', @@ -1248,41 +1236,41 @@ def test_signedList(self): 'sreg.email': 'bogus@example.com', 'signed': 'identity,mode,ns.sreg,sreg.email', 'foo': 'bar', - }) + }) args = self.consumer._createCheckAuthRequest(query) self.failUnless(args.isOpenID1()) for signed_arg in query.getArg(OPENID_NS, 'signed').split(','): - self.failUnless(args.getAliasedArg(signed_arg), signed_arg) + self.failUnless(args.getAliasedArg(signed_arg), signed_arg) def test_112(self): - args = {'openid.assoc_handle': 'fa1f5ff0-cde4-11dc-a183-3714bfd55ca8', - 'openid.claimed_id': 'http://binkley.lan/user/test01', - 'openid.identity': 'http://test01.binkley.lan/', - 'openid.mode': 'id_res', - 'openid.ns': 'http://specs.openid.net/auth/2.0', - 'openid.ns.pape': 'http://specs.openid.net/extensions/pape/1.0', - 'openid.op_endpoint': 'http://binkley.lan/server', - 'openid.pape.auth_policies': 'none', - 'openid.pape.auth_time': '2008-01-28T20:42:36Z', - 'openid.pape.nist_auth_level': '0', - 'openid.response_nonce': '2008-01-28T21:07:04Z99Q=', - 'openid.return_to': 'http://binkley.lan:8001/process?janrain_nonce=2008-01-28T21%3A07%3A02Z0tMIKx', - 'openid.sig': 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', - 'openid.signed': 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,ns.pape,pape.nist_auth_level,pape.auth_policies' - } + args = { + 'openid.assoc_handle': 'fa1f5ff0-cde4-11dc-a183-3714bfd55ca8', + 'openid.claimed_id': 'http://binkley.lan/user/test01', + 'openid.identity': 'http://test01.binkley.lan/', + 'openid.mode': 'id_res', + 'openid.ns': 'http://specs.openid.net/auth/2.0', + 'openid.ns.pape': 'http://specs.openid.net/extensions/pape/1.0', + 'openid.op_endpoint': 'http://binkley.lan/server', + 'openid.pape.auth_policies': 'none', + 'openid.pape.auth_time': '2008-01-28T20:42:36Z', + 'openid.pape.nist_auth_level': '0', + 'openid.response_nonce': '2008-01-28T21:07:04Z99Q=', + 'openid.return_to': 'http://binkley.lan:8001/process?janrain_nonce=2008-01-28T21%3A07%3A02Z0tMIKx', + 'openid.sig': 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', + 'openid.signed': 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,' + 'ns.pape,pape.nist_auth_level,pape.auth_policies'} self.failUnlessEqual(OPENID2_NS, args['openid.ns']) incoming = Message.fromPostArgs(args) self.failUnless(incoming.isOpenID2()) car = self.consumer._createCheckAuthRequest(incoming) expected_args = args.copy() expected_args['openid.mode'] = 'check_authentication' - expected =Message.fromPostArgs(expected_args) + expected = Message.fromPostArgs(expected_args) self.failUnless(expected.isOpenID2()) self.failUnlessEqual(expected, car) self.failUnlessEqual(expected_args, car.toPostArgs()) - class TestFetchAssoc(unittest.TestCase, CatchLogs): consumer_class = GenericConsumer @@ -1300,7 +1288,7 @@ def test_error_404(self): self.failUnlessRaises( fetchers.HTTPFetchingError, self.consumer._makeKVPost, - Message.fromPostArgs({'mode':'associate'}), + Message.fromPostArgs({'mode': 'associate'}), "http://server_url") def test_error_exception_unwrapped(self): @@ -1311,7 +1299,7 @@ def test_error_exception_unwrapped(self): fetchers.setDefaultFetcher(self.fetcher, wrap_exceptions=False) self.failUnlessRaises(self.fetcher.MyException, self.consumer._makeKVPost, - Message.fromPostArgs({'mode':'associate'}), + Message.fromPostArgs({'mode': 'associate'}), "http://server_url") # exception fetching returns no association @@ -1322,7 +1310,7 @@ def test_error_exception_unwrapped(self): self.failUnlessRaises(self.fetcher.MyException, self.consumer._checkAuth, - Message.fromPostArgs({'openid.signed':''}), + Message.fromPostArgs({'openid.signed': ''}), 'some://url') def test_error_exception_wrapped(self): @@ -1334,7 +1322,7 @@ def test_error_exception_wrapped(self): fetchers.setDefaultFetcher(self.fetcher) self.failUnlessRaises(fetchers.HTTPFetchingError, self.consumer._makeKVPost, - Message.fromOpenIDArgs({'mode':'associate'}), + Message.fromOpenIDArgs({'mode': 'associate'}), "http://server_url") # exception fetching returns no association @@ -1342,7 +1330,7 @@ def test_error_exception_wrapped(self): e.server_url = 'some://url' self.failUnless(self.consumer._getAssociation(e) is None) - msg = Message.fromPostArgs({'openid.signed':''}) + msg = Message.fromPostArgs({'openid.signed': ''}) self.failIf(self.consumer._checkAuth(msg, 'some://url')) @@ -1353,33 +1341,33 @@ def setUp(self): def test_extensionResponse(self): resp = mkSuccess(self.endpoint, { - 'ns.sreg':'urn:sreg', - 'ns.unittest':'urn:unittest', - 'unittest.one':'1', - 'unittest.two':'2', - 'sreg.nickname':'j3h', - 'return_to':'return_to', - }) + 'ns.sreg': 'urn:sreg', + 'ns.unittest': 'urn:unittest', + 'unittest.one': '1', + 'unittest.two': '2', + 'sreg.nickname': 'j3h', + 'return_to': 'return_to', + }) utargs = resp.extensionResponse('urn:unittest', False) - self.failUnlessEqual(utargs, {'one':'1', 'two':'2'}) + self.failUnlessEqual(utargs, {'one': '1', 'two': '2'}) sregargs = resp.extensionResponse('urn:sreg', False) - self.failUnlessEqual(sregargs, {'nickname':'j3h'}) + self.failUnlessEqual(sregargs, {'nickname': 'j3h'}) def test_extensionResponseSigned(self): args = { - 'ns.sreg':'urn:sreg', - 'ns.unittest':'urn:unittest', - 'unittest.one':'1', - 'unittest.two':'2', - 'sreg.nickname':'j3h', - 'sreg.dob':'yesterday', - 'return_to':'return_to', + 'ns.sreg': 'urn:sreg', + 'ns.unittest': 'urn:unittest', + 'unittest.one': '1', + 'unittest.two': '2', + 'sreg.nickname': 'j3h', + 'sreg.dob': 'yesterday', + 'return_to': 'return_to', 'signed': 'sreg.nickname,unittest.one,sreg.dob', - } + } signed_list = ['openid.sreg.nickname', 'openid.unittest.one', - 'openid.sreg.dob',] + 'openid.sreg.dob'] # Don't use mkSuccess because it creates an all-inclusive # signed list. @@ -1388,7 +1376,7 @@ def test_extensionResponseSigned(self): # All args in this NS are signed, so expect all. sregargs = resp.extensionResponse('urn:sreg', True) - self.failUnlessEqual(sregargs, {'nickname':'j3h', 'dob': 'yesterday'}) + self.failUnlessEqual(sregargs, {'nickname': 'j3h', 'dob': 'yesterday'}) # Not all args in this NS are signed, so expect None when # asking for them. @@ -1400,7 +1388,7 @@ def test_noReturnTo(self): self.failUnless(resp.getReturnTo() is None) def test_returnTo(self): - resp = mkSuccess(self.endpoint, {'return_to':'return_to'}) + resp = mkSuccess(self.endpoint, {'return_to': 'return_to'}) self.failUnlessEqual(resp.getReturnTo(), 'return_to') def test_displayIdentifierClaimedId(self): @@ -1414,6 +1402,7 @@ def test_displayIdentifierOverride(self): self.failUnlessEqual(resp.getDisplayIdentifier(), "http://input.url/") + class StubConsumer(object): def __init__(self): self.assoc = object() @@ -1429,11 +1418,13 @@ def complete(self, message, endpoint, return_to): assert endpoint is self.endpoint return self.response + class ConsumerTest(unittest.TestCase): """Tests for high-level consumer.Consumer functions. Its GenericConsumer component is stubbed out with StubConsumer. """ + def setUp(self): self.endpoint = OpenIDServiceEndpoint() self.endpoint.claimed_id = self.identity_url = 'http://identity.url/' @@ -1473,13 +1464,14 @@ def __init__(self, *ignored): def test_beginHTTPError(self): """Make sure that the discovery HTTP failure case behaves properly """ + def getNextService(self, ignored): raise HTTPFetchingError("Unit test") def test(): try: self.consumer.begin('unused in this test') - except DiscoveryFailure, why: + except DiscoveryFailure as why: self.failUnless(why[0].startswith('Error fetching')) self.failIf(why[0].find('Unit test') == -1) else: @@ -1492,10 +1484,11 @@ def getNextService(self, ignored): return None url = 'http://a.user.url/' + def test(): try: self.consumer.begin(url) - except DiscoveryFailure, why: + except DiscoveryFailure as why: self.failUnless(why[0].startswith('No usable OpenID')) self.failIf(why[0].find(url) == -1) else: @@ -1503,7 +1496,6 @@ def test(): self.withDummyDiscovery(test, getNextService) - def test_beginWithoutDiscovery(self): # Does this really test anything non-trivial? result = self.consumer.beginWithoutDiscovery(self.endpoint) @@ -1631,9 +1623,7 @@ def test_successDifferentURL(self): resp_endpoint = OpenIDServiceEndpoint() resp_endpoint.claimed_id = "http://user.url/" - resp = self._doRespDisco( - True, - mkSuccess(resp_endpoint, {})) + self._doRespDisco(True, mkSuccess(resp_endpoint, {})) self.failUnless(self.discovery.getManager(force=True) is None) def test_begin(self): @@ -1646,7 +1636,6 @@ def test_begin(self): self.failUnless(auth_req.assoc is self.consumer.consumer.assoc) - class IDPDrivenTest(unittest.TestCase): def setUp(self): @@ -1655,12 +1644,10 @@ def setUp(self): self.endpoint = OpenIDServiceEndpoint() self.endpoint.server_url = "http://idp.unittest/" - def test_idpDrivenBegin(self): # Testing here that the token-handling doesn't explode... self.consumer.begin(self.endpoint) - def test_idpDrivenComplete(self): identifier = '=directed_identifier' message = Message.fromPostArgs({ @@ -1669,20 +1656,21 @@ def test_idpDrivenComplete(self): 'openid.assoc_handle': 'z', 'openid.signed': 'identity,return_to', 'openid.sig': GOODSIG, - }) + }) discovered_endpoint = OpenIDServiceEndpoint() discovered_endpoint.claimed_id = identifier discovered_endpoint.server_url = self.endpoint.server_url discovered_endpoint.local_id = identifier iverified = [] + def verifyDiscoveryResults(identifier, endpoint): self.failUnless(endpoint is self.endpoint) iverified.append(discovered_endpoint) return discovered_endpoint self.consumer._verifyDiscoveryResults = verifyDiscoveryResults self.consumer._idResCheckNonce = lambda *args: True - self.consumer._checkReturnTo = lambda unused1, unused2 : True + self.consumer._checkReturnTo = lambda unused1, unused2: True response = self.consumer._doIdRes(message, self.endpoint, None) self.failUnlessSuccess(response) @@ -1691,7 +1679,6 @@ def verifyDiscoveryResults(identifier, endpoint): # assert that discovery attempt happens and returns good self.failUnlessEqual(iverified, [discovered_endpoint]) - def test_idpDrivenCompleteFraud(self): # crap with an identifier that doesn't match discovery info message = Message.fromPostArgs({ @@ -1700,21 +1687,20 @@ def test_idpDrivenCompleteFraud(self): 'openid.assoc_handle': 'z', 'openid.signed': 'identity,return_to', 'openid.sig': GOODSIG, - }) + }) + def verifyDiscoveryResults(identifier, endpoint): raise DiscoveryFailure("PHREAK!", None) self.consumer._verifyDiscoveryResults = verifyDiscoveryResults - self.consumer._checkReturnTo = lambda unused1, unused2 : True + self.consumer._checkReturnTo = lambda unused1, unused2: True self.failUnlessRaises(DiscoveryFailure, self.consumer._doIdRes, message, self.endpoint, None) - def failUnlessSuccess(self, response): if response.status != SUCCESS: self.fail("Non-successful response: %s" % (response,)) - class TestDiscoveryVerification(unittest.TestCase): services = [] @@ -1732,7 +1718,7 @@ def setUp(self): 'openid.identity': self.identifier, 'openid.claimed_id': self.identifier, 'openid.op_endpoint': self.server_url, - }) + }) self.endpoint = OpenIDServiceEndpoint() self.endpoint.server_url = self.server_url @@ -1748,7 +1734,6 @@ def test_theGoodStuff(self): self.failUnlessEqual(r, endpoint) - def test_otherServer(self): text = "verify failed" @@ -1769,12 +1754,11 @@ def discoverAndVerify(claimed_id, to_match_endpoints): self.services = [endpoint] try: r = self.consumer._verifyDiscoveryResults(self.message, endpoint) - except ProtocolError, e: + except ProtocolError as e: # Should we make more ProtocolError subclasses? self.failUnless(str(e), text) else: self.fail("expected ProtocolError, %r returned." % (r,)) - def test_foreignDelegate(self): text = "verify failed" @@ -1796,7 +1780,7 @@ def discoverAndVerify(claimed_id, to_match_endpoints): try: r = self.consumer._verifyDiscoveryResults(self.message, endpoint) - except ProtocolError, e: + except ProtocolError as e: self.failUnlessEqual(str(e), text) else: self.fail("Exepected ProtocolError, %r returned" % (r,)) @@ -1808,7 +1792,6 @@ def test_nothingDiscovered(self): self.consumer._verifyDiscoveryResults, self.message, self.endpoint) - def discoveryFunc(self, identifier): return identifier, self.services @@ -1832,10 +1815,10 @@ def test_noEncryptionSendsType(self): self.failUnless(isinstance(session, PlainTextConsumerSession)) expected = Message.fromOpenIDArgs( - {'ns':OPENID2_NS, - 'session_type':session_type, - 'mode':'associate', - 'assoc_type':self.assoc_type, + {'ns': OPENID2_NS, + 'session_type': session_type, + 'mode': 'associate', + 'assoc_type': self.assoc_type, }) self.failUnlessEqual(expected, args) @@ -1847,9 +1830,9 @@ def test_noEncryptionCompatibility(self): self.endpoint, self.assoc_type, session_type) self.failUnless(isinstance(session, PlainTextConsumerSession)) - self.failUnlessEqual(Message.fromOpenIDArgs({'mode':'associate', - 'assoc_type':self.assoc_type, - }), args) + self.failUnlessEqual( + Message.fromOpenIDArgs({'mode': 'associate', 'assoc_type': self.assoc_type}), + args) def test_dhSHA1Compatibility(self): # Set the consumer's session type to a fast session since we @@ -1870,9 +1853,9 @@ def test_dhSHA1Compatibility(self): # OK, session_type is set here and not for no-encryption # compatibility - expected = Message.fromOpenIDArgs({'mode':'associate', - 'session_type':'DH-SHA1', - 'assoc_type':self.assoc_type, + expected = Message.fromOpenIDArgs({'mode': 'associate', + 'session_type': 'DH-SHA1', + 'assoc_type': self.assoc_type, 'dh_modulus': 'BfvStQ==', 'dh_gen': 'Ag==', }) @@ -1881,6 +1864,7 @@ def test_dhSHA1Compatibility(self): # XXX: test the other types + class TestDiffieHellmanResponseParameters(object): session_cls = None message_namespace = None @@ -1933,10 +1917,12 @@ def testInvalidBase64MacKey(self): self.failUnlessRaises(ValueError, self.consumer_session.extractSecret, self.msg) + class TestOpenID1SHA1(TestDiffieHellmanResponseParameters, unittest.TestCase): session_cls = DiffieHellmanSHA1ConsumerSession message_namespace = OPENID1_NS + class TestOpenID2SHA1(TestDiffieHellmanResponseParameters, unittest.TestCase): session_cls = DiffieHellmanSHA1ConsumerSession message_namespace = OPENID2_NS @@ -1960,18 +1946,17 @@ def notCalled(unused): endpoint.claimed_id = 'identity_url' self.consumer._getAssociation = notCalled - auth_request = self.consumer.begin(endpoint) + self.consumer.begin(endpoint) # _getAssociation was not called - - class NonAnonymousAuthRequest(object): endpoint = 'unused' def setAnonymous(self, unused): raise ValueError('Should trigger ProtocolError') + class TestConsumerAnonymous(unittest.TestCase): def test_beginWithoutDiscoveryAnonymousFail(self): """Make sure that ValueError for setting an auth request @@ -1979,6 +1964,7 @@ def test_beginWithoutDiscoveryAnonymousFail(self): """ sess = {} consumer = Consumer(sess, None) + def bogusBegin(unused): return NonAnonymousAuthRequest() consumer.consumer.begin = bogusBegin @@ -1991,6 +1977,7 @@ class TestDiscoverAndVerify(unittest.TestCase): def setUp(self): self.consumer = GenericConsumer(None) self.discovery_result = None + def dummyDiscover(unused_identifier): return self.discovery_result self.consumer._discover = dummyDiscover @@ -2014,6 +2001,7 @@ def test_noMatches(self): assertion, then we end up raising a ProtocolError """ self.discovery_result = (None, ['unused']) + def raiseProtocolError(unused1, unused2): raise ProtocolError('unit test') self.consumer._verifyDiscoverySingle = raiseProtocolError @@ -2037,12 +2025,14 @@ def returnTrue(unused1, unused2): 'http://claimed.id/', [self.to_match]) self.failUnlessEqual(matching_endpoint, result) + class SillyExtension(Extension): ns_uri = 'http://silly.example.com/' ns_alias = 'silly' def getExtensionArgs(self): - return {'i_am':'silly'} + return {'i_am': 'silly'} + class TestAddExtension(unittest.TestCase): @@ -2054,7 +2044,6 @@ def test_SillyExtension(self): self.failUnlessEqual(ext.getExtensionArgs(), ext_args) - class TestKVPost(unittest.TestCase): def setUp(self): self.server_url = 'http://unittest/%s' % (self.id(),) @@ -2065,23 +2054,21 @@ def test_200(self): response.status = 200 response.body = "foo:bar\nbaz:quux\n" r = _httpResponseToMessage(response, self.server_url) - expected_msg = Message.fromOpenIDArgs({'foo':'bar','baz':'quux'}) + expected_msg = Message.fromOpenIDArgs({'foo': 'bar', 'baz': 'quux'}) self.failUnlessEqual(expected_msg, r) - def test_400(self): response = HTTPResponse() response.status = 400 response.body = "error:bonk\nerror_code:7\n" try: r = _httpResponseToMessage(response, self.server_url) - except ServerError, e: + except ServerError as e: self.failUnlessEqual(e.error_text, 'bonk') self.failUnlessEqual(e.error_code, '7') else: self.fail("Expected ServerError, got return %r" % (r,)) - def test_500(self): # 500 as an example of any non-200, non-400 code. response = HTTPResponse() @@ -2092,7 +2079,5 @@ def test_500(self): self.server_url) - - if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_discover.py b/openid/test/test_discover.py index a09a1f2c..29a73ffa 100644 --- a/openid/test/test_discover.py +++ b/openid/test/test_discover.py @@ -2,7 +2,6 @@ import os.path import sys import unittest -import warnings from urlparse import urlsplit from openid import fetchers, message @@ -14,7 +13,8 @@ from . import datadriven -### Tests for conditions that trigger DiscoveryFailure +# Tests for conditions that trigger DiscoveryFailure + class SimpleMockFetcher(object): def __init__(self, responses): @@ -26,6 +26,7 @@ def fetch(self, url, body=None, headers=None): assert response.final_url == url return response + class TestDiscoveryFailure(datadriven.DataDrivenTestCase): cases = [ [HTTPResponse('http://network.error/', None)], @@ -33,9 +34,9 @@ class TestDiscoveryFailure(datadriven.DataDrivenTestCase): [HTTPResponse('http://bad.request/', 400)], [HTTPResponse('http://server.error/', 500)], [HTTPResponse('http://header.found/', 200, - headers={'x-xrds-location':'http://xrds.missing/'}), + headers={'x-xrds-location': 'http://xrds.missing/'}), HTTPResponse('http://xrds.missing/', 404)], - ] + ] def __init__(self, responses): self.url = responses[0].final_url @@ -53,14 +54,14 @@ def runOneTest(self): expected_status = self.responses[-1].status try: discover.discover(self.url) - except DiscoveryFailure, why: + except DiscoveryFailure as why: self.failUnlessEqual(why.http_response.status, expected_status) else: self.fail('Did not raise DiscoveryFailure') -### Tests for raising/catching exceptions from the fetcher through the -### discover function +# Tests for raising/catching exceptions from the fetcher through the +# discover function class ErrorRaisingFetcher(object): """Just raise an exception when fetch is called""" @@ -71,9 +72,11 @@ def __init__(self, thing_to_raise): def fetch(self, url, body=None, headers=None): raise self.thing_to_raise + class DidFetch(Exception): """Custom exception just to make sure it's not handled differently""" + class TestFetchException(datadriven.DataDrivenTestCase): """Make sure exceptions get passed through discover function from fetcher.""" @@ -83,7 +86,7 @@ class TestFetchException(datadriven.DataDrivenTestCase): DidFetch(), ValueError(), RuntimeError(), - ] + ] def __init__(self, exc): datadriven.DataDrivenTestCase.__init__(self, repr(exc)) @@ -99,7 +102,7 @@ def tearDown(self): def runOneTest(self): try: discover.discover('http://doesnt.matter/') - except: + except Exception: exc = sys.exc_info()[1] if exc is None: # str exception @@ -110,7 +113,7 @@ def runOneTest(self): self.fail('Expected %r', self.exc) -### Tests for openid.consumer.discover.discover +# Tests for openid.consumer.discover.discover class TestNormalization(unittest.TestCase): def testAddingProtocol(self): @@ -119,10 +122,10 @@ def testAddingProtocol(self): try: discover.discover('users.stompy.janrain.com:8000/x') - except DiscoveryFailure, why: + except DiscoveryFailure: self.fail('failed to parse url with port correctly') except RuntimeError: - pass #expected + pass # expected fetchers.setDefaultFetcher(None) @@ -154,6 +157,7 @@ def fetch(self, url, body=None, headers=None): # from twisted.trial import unittest as trialtest + class BaseTestDiscovery(unittest.TestCase): id_url = "http://someuser.unittest/" @@ -195,7 +199,7 @@ def _checkService(self, s, '1.0': discover.OPENID_1_0_TYPE, '2.0': discover.OPENID_2_0_TYPE, '2.0 OP': discover.OPENID_IDP_2_0_TYPE, - } + } type_uris = [openid_types[t] for t in types] self.failUnlessEqual(type_uris, s.type_uris) @@ -217,12 +221,14 @@ def setUp(self): def tearDown(self): fetchers.setDefaultFetcher(None) + def readDataFile(filename): module_directory = os.path.dirname(os.path.abspath(__file__)) filename = os.path.join( module_directory, 'data', 'test_discover', filename) return file(filename).read() + class TestDiscovery(BaseTestDiscovery): def _discover(self, content_type, data, expected_services, expected_id=None): @@ -254,8 +260,7 @@ def test_unicode_undecodable_html(self): """ data = readDataFile('unicode2.html') self.failUnlessRaises(UnicodeDecodeError, data.decode, 'utf-8') - self._discover(content_type='text/html;charset=utf-8', - data=data, expected_services=0) + self._discover(content_type='text/html;charset=utf-8', data=data, expected_services=0) def test_unicode_undecodable_html2(self): """ @@ -267,8 +272,7 @@ def test_unicode_undecodable_html2(self): data = readDataFile('unicode3.html') self.failUnlessRaises(UnicodeDecodeError, data.decode, 'utf-8') - self._discover(content_type='text/html;charset=utf-8', - data=data, expected_services=1) + self._discover(content_type='text/html;charset=utf-8', data=data, expected_services=1) def test_noOpenID(self): services = self._discover(content_type='text/plain', @@ -279,7 +283,7 @@ def test_noOpenID(self): content_type='text/html', data=readDataFile('openid_no_delegate.html'), expected_services=1, - ) + ) self._checkService( services[0], @@ -288,7 +292,7 @@ def test_noOpenID(self): server_url="http://www.myopenid.com/server", claimed_id=self.id_url, local_id=self.id_url, - ) + ) def test_html1(self): services = self._discover( @@ -296,7 +300,6 @@ def test_html1(self): data=readDataFile('openid.html'), expected_services=1) - self._checkService( services[0], used_yadis=False, @@ -305,7 +308,7 @@ def test_html1(self): claimed_id=self.id_url, local_id='http://smoker.myopenid.com/', display_identifier=self.id_url, - ) + ) def test_html1Fragment(self): """Ensure that the Claimed Identifier does not have a fragment @@ -329,14 +332,14 @@ def test_html1Fragment(self): claimed_id=expected_id, local_id='http://smoker.myopenid.com/', display_identifier=expected_id, - ) + ) def test_html2(self): services = self._discover( content_type='text/html', data=readDataFile('openid2.html'), expected_services=1, - ) + ) self._checkService( services[0], @@ -346,14 +349,14 @@ def test_html2(self): claimed_id=self.id_url, local_id='http://smoker.myopenid.com/', display_identifier=self.id_url, - ) + ) def test_html1And2(self): services = self._discover( content_type='text/html', data=readDataFile('openid_1_and_2.html'), expected_services=2, - ) + ) for t, s in zip(['2.0', '1.1'], services): self._checkService( @@ -364,12 +367,11 @@ def test_html1And2(self): claimed_id=self.id_url, local_id='http://smoker.myopenid.com/', display_identifier=self.id_url, - ) + ) def test_yadisEmpty(self): - services = self._discover(content_type='application/xrds+xml', - data=readDataFile('yadis_0entries.xml'), - expected_services=0) + self._discover(content_type='application/xrds+xml', data=readDataFile('yadis_0entries.xml'), + expected_services=0) def test_htmlEmptyYadis(self): """HTML document has discovery information, but points to an @@ -390,7 +392,7 @@ def test_htmlEmptyYadis(self): claimed_id=self.id_url, local_id='http://smoker.myopenid.com/', display_identifier=self.id_url, - ) + ) def test_yadis1NoDelegate(self): services = self._discover(content_type='application/xrds+xml', @@ -405,14 +407,14 @@ def test_yadis1NoDelegate(self): claimed_id=self.id_url, local_id=self.id_url, display_identifier=self.id_url, - ) + ) def test_yadis2NoLocalID(self): services = self._discover( content_type='application/xrds+xml', data=readDataFile('openid2_xrds_no_local_id.xml'), expected_services=1, - ) + ) self._checkService( services[0], @@ -422,14 +424,14 @@ def test_yadis2NoLocalID(self): claimed_id=self.id_url, local_id=self.id_url, display_identifier=self.id_url, - ) + ) def test_yadis2(self): services = self._discover( content_type='application/xrds+xml', data=readDataFile('openid2_xrds.xml'), expected_services=1, - ) + ) self._checkService( services[0], @@ -439,14 +441,14 @@ def test_yadis2(self): claimed_id=self.id_url, local_id='http://smoker.myopenid.com/', display_identifier=self.id_url, - ) + ) def test_yadis2OP(self): services = self._discover( content_type='application/xrds+xml', data=readDataFile('yadis_idp.xml'), expected_services=1, - ) + ) self._checkService( services[0], @@ -454,7 +456,7 @@ def test_yadis2OP(self): types=['2.0 OP'], server_url="http://www.myopenid.com/server", display_identifier=self.id_url, - ) + ) def test_yadis2OPDelegate(self): """The delegate tag isn't meaningful for OP entries.""" @@ -462,7 +464,7 @@ def test_yadis2OPDelegate(self): content_type='application/xrds+xml', data=readDataFile('yadis_idp_delegate.xml'), expected_services=1, - ) + ) self._checkService( services[0], @@ -470,21 +472,20 @@ def test_yadis2OPDelegate(self): types=['2.0 OP'], server_url="http://www.myopenid.com/server", display_identifier=self.id_url, - ) + ) def test_yadis2BadLocalID(self): self.failUnlessRaises(DiscoveryFailure, self._discover, - content_type='application/xrds+xml', - data=readDataFile('yadis_2_bad_local_id.xml'), - expected_services=1, - ) + content_type='application/xrds+xml', + data=readDataFile('yadis_2_bad_local_id.xml'), + expected_services=1) def test_yadis1And2(self): services = self._discover( content_type='application/xrds+xml', data=readDataFile('openid_1_and_2_xrds.xml'), expected_services=1, - ) + ) self._checkService( services[0], @@ -494,14 +495,14 @@ def test_yadis1And2(self): claimed_id=self.id_url, local_id='http://smoker.myopenid.com/', display_identifier=self.id_url, - ) + ) def test_yadis1And2BadLocalID(self): self.failUnlessRaises(DiscoveryFailure, self._discover, - content_type='application/xrds+xml', - data=readDataFile('openid_1_and_2_xrds_bad_delegate.xml'), - expected_services=1, - ) + content_type='application/xrds+xml', + data=readDataFile('openid_1_and_2_xrds_bad_delegate.xml'), + expected_services=1) + class MockFetcherForXRIProxy(object): @@ -510,12 +511,10 @@ def __init__(self, documents, proxy_url=xrires.DEFAULT_PROXY): self.fetchlog = [] self.proxy_url = None - def fetch(self, url, body=None, headers=None): self.fetchlog.append((url, body, headers)) u = urlsplit(url) - proxy_host = u[1] xri = u[2] query = u[3] @@ -544,7 +543,7 @@ class TestXRIDiscovery(BaseTestDiscovery): documents = {'=smoker': ('application/xrds+xml', readDataFile('yadis_2entries_delegate.xml')), '=smoker*bad': ('application/xrds+xml', - readDataFile('yadis_another_delegate.xml')) } + readDataFile('yadis_another_delegate.xml'))} def test_xri(self): user_xri, services = discover.discoverXRI('=smoker') @@ -558,7 +557,7 @@ def test_xri(self): canonical_id=XRI("=!1000"), local_id='http://smoker.myopenid.com/', display_identifier='=smoker' - ) + ) self._checkService( services[1], @@ -569,7 +568,7 @@ def test_xri(self): canonical_id=XRI("=!1000"), local_id='http://frank.livejournal.com/', display_identifier='=smoker' - ) + ) def test_xri_normalize(self): user_xri, services = discover.discoverXRI('xri://=smoker') @@ -583,7 +582,7 @@ def test_xri_normalize(self): canonical_id=XRI("=!1000"), local_id='http://smoker.myopenid.com/', display_identifier='=smoker' - ) + ) self._checkService( services[1], @@ -594,7 +593,7 @@ def test_xri_normalize(self): canonical_id=XRI("=!1000"), local_id='http://frank.livejournal.com/', display_identifier='=smoker' - ) + ) def test_xriNoCanonicalID(self): user_xri, services = discover.discoverXRI('=smoker*bad') @@ -613,7 +612,7 @@ class TestXRIDiscoveryIDP(BaseTestDiscovery): fetcherClass = MockFetcherForXRIProxy documents = {'=smoker': ('application/xrds+xml', - readDataFile('yadis_2entries_idp.xml')) } + readDataFile('yadis_2entries_idp.xml'))} def test_xri(self): user_xri, services = discover.discoverXRI('=smoker') @@ -646,7 +645,8 @@ def runOneTest(self): discover.OPENID_1_0_TYPE]), (message.OPENID2_NS, [discover.OPENID_1_0_TYPE, discover.OPENID_2_0_TYPE]), - ] + ] + class TestIsOPIdentifier(unittest.TestCase): def setUp(self): @@ -682,6 +682,7 @@ def test_multiplePresent(self): discover.OPENID_IDP_2_0_TYPE] self.failUnless(self.endpoint.isOPIdentifier()) + class TestFromOPEndpointURL(unittest.TestCase): def setUp(self): self.op_endpoint_url = 'http://example.com/op/endpoint' @@ -704,6 +705,7 @@ def test_canonicalID(self): def test_serverURL(self): self.failUnlessEqual(self.endpoint.server_url, self.op_endpoint_url) + class TestDiscoverFunction(unittest.TestCase): def setUp(self): self._old_discoverURI = discover.discoverURI @@ -734,6 +736,7 @@ def test_xri(self): def test_xriChar(self): self.failUnlessEqual('XRI', discover.discover('=something')) + class TestEndpointSupportsType(unittest.TestCase): def setUp(self): self.endpoint = discover.OpenIDServiceEndpoint() @@ -745,7 +748,7 @@ def failUnlessSupportsOnly(self, *types): discover.OPENID_1_0_TYPE, discover.OPENID_2_0_TYPE, discover.OPENID_IDP_2_0_TYPE, - ]: + ]: if t in types: self.failUnless(self.endpoint.supportsType(t), "Must support %r" % (t,)) @@ -799,6 +802,7 @@ def test_strip_fragment(self): def pyUnitTests(): return datadriven.loadTests(__name__) + if __name__ == '__main__': suite = pyUnitTests() runner = unittest.TextTestRunner() diff --git a/openid/test/test_etxrd.py b/openid/test/test_etxrd.py index c3ff68ab..cae2712c 100644 --- a/openid/test/test_etxrd.py +++ b/openid/test/test_etxrd.py @@ -8,7 +8,8 @@ def datapath(filename): module_directory = os.path.dirname(os.path.abspath(__file__)) return os.path.join(module_directory, 'data', 'test_etxrd', filename) -XRD_FILE = datapath('valid-populated-xrds.xml') + +XRD_FILE = datapath('valid-populated-xrds.xml') NOXRDS_FILE = datapath('not-xrds.xml') NOXRD_FILE = datapath('no-xrd.xml') @@ -18,6 +19,7 @@ def datapath(filename): LID_2_0 = "http://lid.netmesh.org/sso/2.0b5" TYPEKEY_1_0 = "http://typekey.com/services/1.0" + def simpleOpenIDTransformer(endpoint): """Function to extract information from an OpenID service element""" if 'http://openid.net/signon/1.0' not in endpoint.type_uris: @@ -29,6 +31,7 @@ def simpleOpenIDTransformer(endpoint): delegate = delegates[0].text return (endpoint.uri, delegate) + class TestServiceParser(unittest.TestCase): def setUp(self): self.xmldoc = file(XRD_FILE).read() @@ -39,7 +42,7 @@ def _getServices(self, flt=None): def testParse(self): """Make sure that parsing succeeds at all""" - services = self._getServices() + self._getServices() def testParseOpenID(self): """Parse for OpenID services with a transformer function""" @@ -50,7 +53,7 @@ def testParseOpenID(self): ("http://www.schtuff.com/openid", "http://users.schtuff.com/josh"), ("http://www.livejournal.com/openid/server.bml", "http://www.livejournal.com/users/nedthealpaca/"), - ] + ] it = iter(services) for (server_url, delegate) in expectedServices: @@ -79,16 +82,13 @@ def testGetSeveral(self): # type, URL (TYPEKEY_1_0, None), (LID_2_0, "http://mylid.net/josh"), - ] + ] self._checkServices(expectedServices) def testGetSeveralForOne(self): """Getting services for one Service with several Type elements.""" - types = [ 'http://lid.netmesh.org/sso/2.0b5' - , 'http://lid.netmesh.org/2.0b5' - ] - + types = ['http://lid.netmesh.org/sso/2.0b5', 'http://lid.netmesh.org/2.0b5'] uri = "http://mylid.net/josh" for service in self._getServices(): @@ -131,6 +131,7 @@ def mkTest(iname, filename, expectedID): test for the given set of inputs""" filename = datapath(filename) + def test(self): xrds = etxrd.parseXRDS(file(filename).read()) self._getCanonicalID(iname, xrds, expectedID) diff --git a/openid/test/test_examples.py b/openid/test/test_examples.py index ca83d839..e1c5797f 100644 --- a/openid/test/test_examples.py +++ b/openid/test/test_examples.py @@ -52,6 +52,7 @@ def splitDir(d, count): d = os.path.dirname(d) return d + def runExampleServer(host, port, data_path): thisfile = os.path.abspath(sys.modules[__name__].__file__) topDir = splitDir(thisfile, 3) @@ -64,7 +65,6 @@ def runExampleServer(host, port, data_path): serverMain(host, port, data_path) - class TestServer(unittest.TestCase): """Acceptance tests for examples/server.py. @@ -88,13 +88,11 @@ def setUp(self): twill.commands.reset_browser() - def runExampleServer(self): """Zero-arg run-the-server function to be passed to TestInfo.""" # FIXME - make sure sstore starts clean. runExampleServer('127.0.0.1', self.server_port, 'sstore') - def v1endpoint(self, port): """Return an OpenID 1.1 OpenIDServiceEndpoint for the server.""" base = "http://%s:%s" % (socket.getfqdn('127.0.0.1'), port) @@ -104,7 +102,6 @@ def v1endpoint(self, port): ep.type_uris = [OPENID_1_1_TYPE] return ep - # TODO: test discovery def test_checkidv1(self): @@ -116,7 +113,6 @@ def test_checkidv1(self): if self.twillErr.getvalue(): self.fail(self.twillErr.getvalue()) - def test_allowed(self): """OpenID 1.1 checkid_setup request.""" ti = TwillTest(self.twill_allowed, self.runExampleServer, @@ -126,7 +122,6 @@ def test_allowed(self): if self.twillErr.getvalue(): self.fail(self.twillErr.getvalue()) - def twill_checkidv1(self, twillInfo): endpoint = self.v1endpoint(self.server_port) authreq = AuthRequest(endpoint, assoc=None) @@ -143,12 +138,11 @@ def twill_checkidv1(self, twillInfo): finalURL = headers['Location'] self.failUnless('openid.mode=id_res' in finalURL, finalURL) self.failUnless('openid.identity=' in finalURL, finalURL) - except twill.commands.TwillAssertionError, e: + except twill.commands.TwillAssertionError as e: msg = '%s\nFinal page:\n%s' % ( str(e), c.get_browser().get_html()) self.fail(msg) - def twill_allowed(self, twillInfo): endpoint = self.v1endpoint(self.server_port) authreq = AuthRequest(endpoint, assoc=None) @@ -171,7 +165,7 @@ def twill_allowed(self, twillInfo): headers = c.get_browser()._browser.response().info() finalURL = headers['Location'] self.failUnless(finalURL.startswith(self.return_to)) - except twill.commands.TwillAssertionError, e: + except twill.commands.TwillAssertionError: from traceback import format_exc msg = '%s\nTwill output:%s\nTwill errors:%s\nFinal page:\n%s' % ( format_exc(), @@ -180,7 +174,6 @@ def twill_allowed(self, twillInfo): c.get_browser().get_html()) self.fail(msg) - def tearDown(self): twill.set_output(None) twill.set_errout(None) diff --git a/openid/test/test_extension.py b/openid/test/test_extension.py index 11ba1b26..0f714c62 100644 --- a/openid/test/test_extension.py +++ b/openid/test/test_extension.py @@ -10,6 +10,7 @@ class DummyExtension(extension.Extension): def getExtensionArgs(self): return {} + class ToMessageTest(unittest.TestCase): def test_OpenID1(self): oid1_msg = message.Message(message.OPENID1_NS) diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index 1ec5641f..4cf5a22b 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -13,6 +13,7 @@ # XXX: make these separate test cases + def failUnlessResponseExpected(expected, actual): assert expected.final_url == actual.final_url, ( "%r != %r" % (expected.final_url, actual.final_url)) @@ -32,7 +33,7 @@ def geturl(path): server.socket.getsockname()[1], path) - expected_headers = {'content-type':'text/plain'} + expected_headers = {'content-type': 'text/plain'} def plain(path, code): path = '/' + path @@ -53,15 +54,13 @@ def plain(path, code): plain('forbidden', 403), plain('error', 500), plain('server_error', 503), - ] + ] for path, expected in cases: fetch_url = geturl(path) try: actual = fetcher.fetch(fetch_url) - except (SystemExit, KeyboardInterrupt): - pass - except: + except Exception: print fetcher, fetch_url raise else: @@ -73,29 +72,28 @@ def plain(path, code): 'ftp://janrain.com/pub/']: try: result = fetcher.fetch(err_url) - except (KeyboardInterrupt, SystemExit): - raise - except fetchers.HTTPError, why: + except fetchers.HTTPError: # This is raised by the Curl fetcher for bad cases # detected by the fetchers module, but it's a subclass of # HTTPFetchingError, so we have to catch it explicitly. assert exc - except fetchers.HTTPFetchingError, why: + except fetchers.HTTPFetchingError: assert not exc, (fetcher, exc, server) - except: + except Exception: assert exc else: assert False, 'An exception was expected for %r (%r)' % (fetcher, result) + def run_fetcher_tests(server): exc_fetchers = [] for klass, library_name in [ (fetchers.CurlHTTPFetcher, 'pycurl'), (fetchers.HTTPLib2Fetcher, 'httplib2'), - ]: + ]: try: exc_fetchers.append(klass()) - except RuntimeError, why: + except RuntimeError as why: if why[0].startswith('Cannot find %s library' % (library_name,)): try: __import__(library_name) @@ -122,17 +120,17 @@ def run_fetcher_tests(server): class FetcherTestHandler(BaseHTTPRequestHandler): cases = { - '/success':(200, None), - '/301redirect':(301, '/success'), - '/302redirect':(302, '/success'), - '/303redirect':(303, '/success'), - '/307redirect':(307, '/success'), - '/notfound':(404, None), - '/badreq':(400, None), - '/forbidden':(403, None), - '/error':(500, None), - '/server_error':(503, None), - } + '/success': (200, None), + '/301redirect': (301, '/success'), + '/302redirect': (302, '/success'), + '/303redirect': (303, '/success'), + '/307redirect': (307, '/success'), + '/notfound': (404, None), + '/badreq': (400, None), + '/forbidden': (403, None), + '/error': (500, None), + '/server_error': (503, None), + } def log_request(self, *args): pass @@ -173,7 +171,7 @@ def errorResponse(self, message=None): req = [ ('HTTP method', self.command), ('path', self.path), - ] + ] if message: req.append(('message', message)) @@ -197,6 +195,7 @@ def finish(self): self.wfile.close() self.rfile.close() + def test(): import socket host = socket.getfqdn('127.0.0.1') @@ -215,12 +214,14 @@ def test(): run_fetcher_tests(server) + class FakeFetcher(object): sentinel = object() def fetch(self, *args, **kwargs): return self.sentinel + class DefaultFetcherTest(unittest.TestCase): def setUp(self): """reset the default fetcher to None""" @@ -276,7 +277,7 @@ def test_notWrapped(self): fetchers.fetch('http://invalid.janrain.com/') except fetchers.HTTPFetchingError: self.fail('Should not be wrapping exception') - except: + except Exception: exc = sys.exc_info()[1] self.failUnless(isinstance(exc, urllib2.URLError), exc) pass diff --git a/openid/test/test_htmldiscover.py b/openid/test/test_htmldiscover.py index e310435d..188565b2 100644 --- a/openid/test/test_htmldiscover.py +++ b/openid/test/test_htmldiscover.py @@ -8,7 +8,7 @@ class BadLinksTestCase(datadriven.DataDrivenTestCase): '', "http://not.in.a.link.tag/", '', - ] + ] def __init__(self, data): datadriven.DataDrivenTestCase.__init__(self, data) @@ -19,5 +19,6 @@ def runOneTest(self): expected = [] self.failUnlessEqual(expected, actual) + def pyUnitTests(): return datadriven.loadTests(__name__) diff --git a/openid/test/test_message.py b/openid/test/test_message.py index 571eec95..be6fc210 100644 --- a/openid/test/test_message.py +++ b/openid/test/test_message.py @@ -24,6 +24,7 @@ def test(self): return test + class EmptyMessageTest(unittest.TestCase): def setUp(self): self.msg = message.Message() @@ -94,7 +95,7 @@ def test_getAliasedArgSuccess(self): 'openid.test.flub': 'bogus'}) actual_uri = msg.getAliasedArg('ns.test', message.no_default) self.assertEquals("urn://foo", actual_uri) - + def test_getAliasedArgFailure(self): msg = message.Message.fromPostArgs({'openid.test.flub': 'bogus'}) self.assertRaises(KeyError, @@ -136,13 +137,13 @@ def test_getArgsNS3(self): def test_updateArgs(self): self.failUnlessRaises(message.UndefinedOpenIDNamespace, self.msg.updateArgs, message.OPENID_NS, - {'does not':'matter'}) + {'does not': 'matter'}) def _test_updateArgsNS(self, ns): update_args = { - 'Camper van Beethoven':'David Lowery', - 'Magnolia Electric Co.':'Jason Molina', - } + 'Camper van Beethoven': 'David Lowery', + 'Magnolia Electric Co.': 'Jason Molina', + } self.failUnlessEqual(self.msg.getArgs(ns), {}) self.msg.updateArgs(ns, update_args) @@ -219,19 +220,20 @@ def test_isOpenID1(self): def test_isOpenID2(self): self.failIf(self.msg.isOpenID2()) + class OpenID1MessageTest(unittest.TestCase): def setUp(self): - self.msg = message.Message.fromPostArgs({'openid.mode':'error', - 'openid.error':'unit test'}) + self.msg = message.Message.fromPostArgs({'openid.mode': 'error', + 'openid.error': 'unit test'}) def test_toPostArgs(self): self.failUnlessEqual(self.msg.toPostArgs(), - {'openid.mode':'error', - 'openid.error':'unit test'}) + {'openid.mode': 'error', + 'openid.error': 'unit test'}) def test_toArgs(self): - self.failUnlessEqual(self.msg.toArgs(), {'mode':'error', - 'error':'unit test'}) + self.failUnlessEqual(self.msg.toArgs(), {'mode': 'error', + 'error': 'unit test'}) def test_toKVForm(self): self.failUnlessEqual(self.msg.toKVForm(), @@ -249,8 +251,8 @@ def test_toURL(self): self.failUnlessEqual(actual[len(base_url)], '?') query = actual[len(base_url) + 1:] parsed = cgi.parse_qs(query) - self.failUnlessEqual(parsed, {'openid.mode':['error'], - 'openid.error':['unit test']}) + self.failUnlessEqual(parsed, {'openid.mode': ['error'], + 'openid.error': ['unit test']}) def test_getOpenID(self): self.failUnlessEqual(self.msg.getOpenIDNamespace(), message.OPENID1_NS) @@ -298,18 +300,14 @@ def test_hasKeyNS3(self): def test_getArgs(self): self.failUnlessEqual(self.msg.getArgs(message.OPENID_NS), - {'mode':'error', - 'error':'unit test', - }) + {'mode': 'error', 'error': 'unit test'}) def test_getArgsBARE(self): self.failUnlessEqual(self.msg.getArgs(message.BARE_NS), {}) def test_getArgsNS1(self): self.failUnlessEqual(self.msg.getArgs(message.OPENID1_NS), - {'mode':'error', - 'error':'unit test', - }) + {'mode': 'error', 'error': 'unit test'}) def test_getArgsNS2(self): self.failUnlessEqual(self.msg.getArgs(message.OPENID2_NS), {}) @@ -321,9 +319,9 @@ def _test_updateArgsNS(self, ns, before=None): if before is None: before = {} update_args = { - 'Camper van Beethoven':'David Lowery', - 'Magnolia Electric Co.':'Jason Molina', - } + 'Camper van Beethoven': 'David Lowery', + 'Magnolia Electric Co.': 'Jason Molina', + } self.failUnlessEqual(self.msg.getArgs(ns), before) self.msg.updateArgs(ns, update_args) @@ -333,14 +331,14 @@ def _test_updateArgsNS(self, ns, before=None): def test_updateArgs(self): self._test_updateArgsNS(message.OPENID_NS, - before={'mode':'error', 'error':'unit test'}) + before={'mode': 'error', 'error': 'unit test'}) def test_updateArgsBARE(self): self._test_updateArgsNS(message.BARE_NS) def test_updateArgsNS1(self): self._test_updateArgsNS(message.OPENID1_NS, - before={'mode':'error', 'error':'unit test'}) + before={'mode': 'error', 'error': 'unit test'}) def test_updateArgsNS2(self): self._test_updateArgsNS(message.OPENID2_NS) @@ -395,40 +393,40 @@ def test_delArgNS2(self): def test_delArgNS3(self): self._test_delArgNS('urn:nothing-significant') - def test_isOpenID1(self): self.failUnless(self.msg.isOpenID1()) def test_isOpenID2(self): self.failIf(self.msg.isOpenID2()) + class OpenID1ExplicitMessageTest(unittest.TestCase): def setUp(self): - self.msg = message.Message.fromPostArgs({'openid.mode':'error', - 'openid.error':'unit test', - 'openid.ns':message.OPENID1_NS + self.msg = message.Message.fromPostArgs({'openid.mode': 'error', + 'openid.error': 'unit test', + 'openid.ns': message.OPENID1_NS }) def test_toPostArgs(self): self.failUnlessEqual(self.msg.toPostArgs(), - {'openid.mode':'error', - 'openid.error':'unit test', - 'openid.ns':message.OPENID1_NS + {'openid.mode': 'error', + 'openid.error': 'unit test', + 'openid.ns': message.OPENID1_NS }) def test_toArgs(self): - self.failUnlessEqual(self.msg.toArgs(), {'mode':'error', - 'error':'unit test', - 'ns':message.OPENID1_NS}) + self.failUnlessEqual(self.msg.toArgs(), {'mode': 'error', + 'error': 'unit test', + 'ns': message.OPENID1_NS}) def test_toKVForm(self): self.failUnlessEqual(self.msg.toKVForm(), - 'error:unit test\nmode:error\nns:%s\n' - %message.OPENID1_NS) + 'error:unit test\nmode:error\nns:%s\n' % message.OPENID1_NS) def test_toURLEncoded(self): - self.failUnlessEqual(self.msg.toURLEncoded(), - 'openid.error=unit+test&openid.mode=error&openid.ns=https%3A%2F%2Fround-lake.dustinice.workers.dev%3A443%2Fhttp%2Fopenid.net%2Fsignon%2F1.0') + self.failUnlessEqual( + self.msg.toURLEncoded(), + 'openid.error=unit+test&openid.mode=error&openid.ns=https%3A%2F%2Fround-lake.dustinice.workers.dev%3A443%2Fhttp%2Fopenid.net%2Fsignon%2F1.0') def test_toURL(self): base_url = 'http://base.url/' @@ -438,50 +436,48 @@ def test_toURL(self): self.failUnlessEqual(actual[len(base_url)], '?') query = actual[len(base_url) + 1:] parsed = cgi.parse_qs(query) - self.failUnlessEqual(parsed, {'openid.mode':['error'], - 'openid.error':['unit test'], - 'openid.ns':[message.OPENID1_NS] - }) + self.failUnlessEqual( + parsed, + {'openid.mode': ['error'], 'openid.error': ['unit test'], 'openid.ns': [message.OPENID1_NS]}) def test_isOpenID1(self): self.failUnless(self.msg.isOpenID1()) + class OpenID2MessageTest(unittest.TestCase): def setUp(self): - self.msg = message.Message.fromPostArgs({'openid.mode':'error', - 'openid.error':'unit test', - 'openid.ns':message.OPENID2_NS - }) + self.msg = message.Message.fromPostArgs({'openid.mode': 'error', + 'openid.error': 'unit test', + 'openid.ns': message.OPENID2_NS}) self.msg.setArg(message.BARE_NS, "xey", "value") def test_toPostArgs(self): self.failUnlessEqual(self.msg.toPostArgs(), - {'openid.mode':'error', - 'openid.error':'unit test', - 'openid.ns':message.OPENID2_NS, + {'openid.mode': 'error', + 'openid.error': 'unit test', + 'openid.ns': message.OPENID2_NS, 'xey': 'value', }) def test_toPostArgs_bug_with_utf8_encoded_values(self): - msg = message.Message.fromPostArgs({'openid.mode':'error', - 'openid.error':'unit test', - 'openid.ns':message.OPENID2_NS - }) + msg = message.Message.fromPostArgs({'openid.mode': 'error', + 'openid.error': 'unit test', + 'openid.ns': message.OPENID2_NS + }) msg.setArg(message.BARE_NS, 'ünicöde_key', 'ünicöde_välüe') self.failUnlessEqual(msg.toPostArgs(), - {'openid.mode':'error', - 'openid.error':'unit test', - 'openid.ns':message.OPENID2_NS, + {'openid.mode': 'error', + 'openid.error': 'unit test', + 'openid.ns': message.OPENID2_NS, 'ünicöde_key': 'ünicöde_välüe', }) - def test_toArgs(self): # This method can't tolerate BARE_NS. self.msg.delArg(message.BARE_NS, "xey") - self.failUnlessEqual(self.msg.toArgs(), {'mode':'error', - 'error':'unit test', - 'ns':message.OPENID2_NS, + self.failUnlessEqual(self.msg.toArgs(), {'mode': 'error', + 'error': 'unit test', + 'ns': message.OPENID2_NS, }) def test_toKVForm(self): @@ -492,12 +488,10 @@ def test_toKVForm(self): (message.OPENID2_NS,)) def _test_urlencoded(self, s): - expected = ('openid.error=unit+test&openid.mode=error&' - 'openid.ns=%s&xey=value' % ( - urllib.quote(message.OPENID2_NS, ''),)) + expected = ('openid.error=unit+test&openid.mode=error&openid.ns=%s&xey=value' % + urllib.quote(message.OPENID2_NS, '')) self.failUnlessEqual(s, expected) - def test_toURLEncoded(self): self._test_urlencoded(self.msg.toURLEncoded()) @@ -558,9 +552,7 @@ def test_hasKeyNS3(self): def test_getArgsOpenID(self): self.failUnlessEqual(self.msg.getArgs(message.OPENID_NS), - {'mode':'error', - 'error':'unit test', - }) + {'mode': 'error', 'error': 'unit test'}) def test_getArgsBARE(self): self.failUnlessEqual(self.msg.getArgs(message.BARE_NS), @@ -571,9 +563,7 @@ def test_getArgsNS1(self): def test_getArgsNS2(self): self.failUnlessEqual(self.msg.getArgs(message.OPENID2_NS), - {'mode':'error', - 'error':'unit test', - }) + {'mode': 'error', 'error': 'unit test'}) def test_getArgsNS3(self): self.failUnlessEqual(self.msg.getArgs('urn:nothing-significant'), {}) @@ -582,9 +572,9 @@ def _test_updateArgsNS(self, ns, before=None): if before is None: before = {} update_args = { - 'Camper van Beethoven':'David Lowery', - 'Magnolia Electric Co.':'Jason Molina', - } + 'Camper van Beethoven': 'David Lowery', + 'Magnolia Electric Co.': 'Jason Molina', + } self.failUnlessEqual(self.msg.getArgs(ns), before) self.msg.updateArgs(ns, update_args) @@ -594,18 +584,18 @@ def _test_updateArgsNS(self, ns, before=None): def test_updateArgsOpenID(self): self._test_updateArgsNS(message.OPENID_NS, - before={'mode':'error', 'error':'unit test'}) + before={'mode': 'error', 'error': 'unit test'}) def test_updateArgsBARE(self): self._test_updateArgsNS(message.BARE_NS, - before={'xey':'value'}) + before={'xey': 'value'}) def test_updateArgsNS1(self): self._test_updateArgsNS(message.OPENID1_NS) def test_updateArgsNS2(self): self._test_updateArgsNS(message.OPENID2_NS, - before={'mode':'error', 'error':'unit test'}) + before={'mode': 'error', 'error': 'unit test'}) def test_updateArgsNS3(self): self._test_updateArgsNS('urn:nothing-significant') @@ -649,52 +639,53 @@ def test_badAlias(self): def test_mysterious_missing_namespace_bug(self): """A failing test for bug #112""" openid_args = { - 'assoc_handle': '{{HMAC-SHA256}{1211477242.29743}{v5cadg==}', - 'claimed_id': 'http://nerdbank.org/OPAffirmative/AffirmativeIdentityWithSregNoAssoc.aspx', - 'ns.sreg': 'http://openid.net/extensions/sreg/1.1', - 'response_nonce': '2008-05-22T17:27:22ZUoW5.\\NV', - 'signed': 'return_to,identity,claimed_id,op_endpoint,response_nonce,ns.sreg,sreg.email,sreg.nickname,assoc_handle', - 'sig': 'e3eGZ10+TNRZitgq5kQlk5KmTKzFaCRI8OrRoXyoFa4=', - 'mode': 'check_authentication', - 'op_endpoint': 'http://nerdbank.org/OPAffirmative/ProviderNoAssoc.aspx', - 'sreg.nickname': 'Andy', - 'return_to': 'http://localhost.localdomain:8001/process?janrain_nonce=2008-05-22T17%3A27%3A21ZnxHULd', - 'invalidate_handle': '{{HMAC-SHA1}{1211477241.92242}{H0akXw==}', - 'identity': 'http://nerdbank.org/OPAffirmative/AffirmativeIdentityWithSregNoAssoc.aspx', - 'sreg.email': 'a@b.com' - } + 'assoc_handle': '{{HMAC-SHA256}{1211477242.29743}{v5cadg==}', + 'claimed_id': 'http://nerdbank.org/OPAffirmative/AffirmativeIdentityWithSregNoAssoc.aspx', + 'ns.sreg': 'http://openid.net/extensions/sreg/1.1', + 'response_nonce': '2008-05-22T17:27:22ZUoW5.\\NV', + 'signed': 'return_to,identity,claimed_id,op_endpoint,response_nonce,ns.sreg,sreg.email,sreg.nickname,' + 'assoc_handle', + 'sig': 'e3eGZ10+TNRZitgq5kQlk5KmTKzFaCRI8OrRoXyoFa4=', + 'mode': 'check_authentication', + 'op_endpoint': 'http://nerdbank.org/OPAffirmative/ProviderNoAssoc.aspx', + 'sreg.nickname': 'Andy', + 'return_to': 'http://localhost.localdomain:8001/process?janrain_nonce=2008-05-22T17%3A27%3A21ZnxHULd', + 'invalidate_handle': '{{HMAC-SHA1}{1211477241.92242}{H0akXw==}', + 'identity': 'http://nerdbank.org/OPAffirmative/AffirmativeIdentityWithSregNoAssoc.aspx', + 'sreg.email': 'a@b.com'} m = message.Message.fromOpenIDArgs(openid_args) self.failUnless(('http://openid.net/extensions/sreg/1.1', 'sreg') in list(m.namespaces.iteritems())) missing = [] for k in openid_args['signed'].split(','): - if not ("openid."+k) in m.toPostArgs().keys(): + if not ("openid." + k) in m.toPostArgs().keys(): missing.append(k) self.assertEqual([], missing, missing) self.assertEqual(openid_args, m.toArgs()) self.failUnless(m.isOpenID1()) def test_112B(self): - args = {'openid.assoc_handle': 'fa1f5ff0-cde4-11dc-a183-3714bfd55ca8', - 'openid.claimed_id': 'http://binkley.lan/user/test01', - 'openid.identity': 'http://test01.binkley.lan/', - 'openid.mode': 'id_res', - 'openid.ns': 'http://specs.openid.net/auth/2.0', - 'openid.ns.pape': 'http://specs.openid.net/extensions/pape/1.0', - 'openid.op_endpoint': 'http://binkley.lan/server', - 'openid.pape.auth_policies': 'none', - 'openid.pape.auth_time': '2008-01-28T20:42:36Z', - 'openid.pape.nist_auth_level': '0', - 'openid.response_nonce': '2008-01-28T21:07:04Z99Q=', - 'openid.return_to': 'http://binkley.lan:8001/process?janrain_nonce=2008-01-28T21%3A07%3A02Z0tMIKx', - 'openid.sig': 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', - 'openid.signed': 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,ns.pape,pape.nist_auth_level,pape.auth_policies' - } + args = { + 'openid.assoc_handle': 'fa1f5ff0-cde4-11dc-a183-3714bfd55ca8', + 'openid.claimed_id': 'http://binkley.lan/user/test01', + 'openid.identity': 'http://test01.binkley.lan/', + 'openid.mode': 'id_res', + 'openid.ns': 'http://specs.openid.net/auth/2.0', + 'openid.ns.pape': 'http://specs.openid.net/extensions/pape/1.0', + 'openid.op_endpoint': 'http://binkley.lan/server', + 'openid.pape.auth_policies': 'none', + 'openid.pape.auth_time': '2008-01-28T20:42:36Z', + 'openid.pape.nist_auth_level': '0', + 'openid.response_nonce': '2008-01-28T21:07:04Z99Q=', + 'openid.return_to': 'http://binkley.lan:8001/process?janrain_nonce=2008-01-28T21%3A07%3A02Z0tMIKx', + 'openid.sig': 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', + 'openid.signed': 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,' + 'ns.pape,pape.nist_auth_level,pape.auth_policies'} m = message.Message.fromPostArgs(args) missing = [] for k in args['openid.signed'].split(','): - if not ("openid."+k) in m.toPostArgs().keys(): + if not ("openid." + k) in m.toPostArgs().keys(): missing.append(k) self.assertEqual([], missing, missing) self.assertEqual(args, m.toPostArgs()) @@ -704,27 +695,27 @@ def test_repetitive_namespaces(self): """ Message that raises KeyError during encoding, because openid namespace is used in attributes """ - args = {'openid.assoc_handle': 'fa1f5ff0-cde4-11dc-a183-3714bfd55ca8', - 'openid.claimed_id': 'http://binkley.lan/user/test01', - 'openid.identity': 'http://test01.binkley.lan/', - 'openid.mode': 'id_res', - 'openid.ns': 'http://specs.openid.net/auth/2.0', - 'openid.op_endpoint': 'http://binkley.lan/server', - 'openid.response_nonce': '2008-01-28T21:07:04Z99Q=', - 'openid.return_to': 'http://binkley.lan:8001/process?janrain_nonce=2008-01-28T21%3A07%3A02Z0tMIKx', - 'openid.sig': 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', - 'openid.signed': 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,ns.pape,pape.nist_auth_level,pape.auth_policies', - 'openid.ns.pape': 'http://specs.openid.net/auth/2.0', - 'openid.pape.auth_policies': 'none', - 'openid.pape.auth_time': '2008-01-28T20:42:36Z', - 'openid.pape.nist_auth_level': '0', - } + args = { + 'openid.assoc_handle': 'fa1f5ff0-cde4-11dc-a183-3714bfd55ca8', + 'openid.claimed_id': 'http://binkley.lan/user/test01', + 'openid.identity': 'http://test01.binkley.lan/', + 'openid.mode': 'id_res', + 'openid.ns': 'http://specs.openid.net/auth/2.0', + 'openid.op_endpoint': 'http://binkley.lan/server', + 'openid.response_nonce': '2008-01-28T21:07:04Z99Q=', + 'openid.return_to': 'http://binkley.lan:8001/process?janrain_nonce=2008-01-28T21%3A07%3A02Z0tMIKx', + 'openid.sig': 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', + 'openid.signed': 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,' + 'ns.pape,pape.nist_auth_level,pape.auth_policies', + 'openid.ns.pape': 'http://specs.openid.net/auth/2.0', + 'openid.pape.auth_policies': 'none', + 'openid.pape.auth_time': '2008-01-28T20:42:36Z', + 'openid.pape.nist_auth_level': '0', + } self.failUnlessRaises(message.InvalidNamespace, message.Message.fromPostArgs, args) def test_implicit_sreg_ns(self): - openid_args = { - 'sreg.email': 'a@b.com' - } + openid_args = {'sreg.email': 'a@b.com'} m = message.Message.fromOpenIDArgs(openid_args) self.failUnless((sreg.ns_uri, 'sreg') in list(m.namespaces.iteritems())) @@ -778,6 +769,7 @@ def test_isOpenID1(self): def test_isOpenID2(self): self.failUnless(self.msg.isOpenID2()) + class MessageTest(unittest.TestCase): def setUp(self): self.postargs = { @@ -786,24 +778,24 @@ def setUp(self): 'openid.identity': 'http://bogus.example.invalid:port/', 'openid.assoc_handle': 'FLUB', 'openid.return_to': 'Neverland', - } + } self.action_url = 'scheme://host:port/path?query' self.form_tag_attrs = { 'company': 'janrain', 'class': 'fancyCSS', - } + } self.submit_text = 'GO!' - ### Expected data regardless of input + # Expected data regardless of input self.required_form_attrs = { - 'accept-charset':'UTF-8', - 'enctype':'application/x-www-form-urlencoded', + 'accept-charset': 'UTF-8', + 'enctype': 'application/x-www-form-urlencoded', 'method': 'post', - } + } def _checkForm(self, html, message_, action_url, form_tag_attrs, submit_text): @@ -818,8 +810,7 @@ def _checkForm(self, html, message_, action_url, # Check required form attributes for k, v in self.required_form_attrs.iteritems(): assert form.attrib[k] == v, \ - "Expected '%s' for required form attribute '%s', got '%s'" % \ - (v, k, form.attrib[k]) + "Expected '%s' for required form attribute '%s', got '%s'" % (v, k, form.attrib[k]) # Check extra form attributes for k, v in form_tag_attrs.iteritems(): @@ -831,13 +822,11 @@ def _checkForm(self, html, message_, action_url, continue assert form.attrib[k] == v, \ - "Form attribute '%s' should be '%s', found '%s'" % \ - (k, v, form.attrib[k]) + "Form attribute '%s' should be '%s', found '%s'" % (k, v, form.attrib[k]) # Check hidden fields against post args - hiddens = [e for e in form \ - if e.tag.upper() == 'INPUT' and \ - e.attrib['type'].upper() == 'HIDDEN'] + hiddens = [e for e in form + if e.tag.upper() == 'INPUT' and e.attrib['type'].upper() == 'HIDDEN'] # For each post arg, make sure there is a hidden with that # value. Make sure there are no other hiddens. @@ -845,34 +834,29 @@ def _checkForm(self, html, message_, action_url, for e in hiddens: if e.attrib['name'] == name: assert e.attrib['value'] == value, \ - "Expected value of hidden input '%s' to be '%s', got '%s'" % \ - (e.attrib['name'], value, e.attrib['value']) + "Expected value of hidden input '%s' to be '%s', got '%s'" % \ + (e.attrib['name'], value, e.attrib['value']) break else: self.fail("Post arg '%s' not found in form" % (name,)) for e in hiddens: assert e.attrib['name'] in message_.toPostArgs().keys(), \ - "Form element for '%s' not in " + \ - "original message" % (e.attrib['name']) + "Form element for '%s' not in original message" % (e.attrib['name']) # Check action URL assert form.attrib['action'] == action_url, \ - "Expected form 'action' to be '%s', got '%s'" % \ - (action_url, form.attrib['action']) + "Expected form 'action' to be '%s', got '%s'" % (action_url, form.attrib['action']) # Check submit text - submits = [e for e in form \ - if e.tag.upper() == 'INPUT' and \ - e.attrib['type'].upper() == 'SUBMIT'] + submits = [e for e in form + if e.tag.upper() == 'INPUT' and e.attrib['type'].upper() == 'SUBMIT'] assert len(submits) == 1, \ - "Expected only one 'input' with type = 'submit', got %d" % \ - (len(submits),) + "Expected only one 'input' with type = 'submit', got %d" % (len(submits),) assert submits[0].attrib['value'] == submit_text, \ - "Expected submit value to be '%s', got '%s'" % \ - (submit_text, submits[0].attrib['value']) + "Expected submit value to be '%s', got '%s'" % (submit_text, submits[0].attrib['value']) def test_toFormMarkup(self): m = message.Message.fromPostArgs(self.postargs) @@ -888,8 +872,8 @@ def test_toFormMarkup_bug_with_utf8_values(self): 'openid.identity': 'http://bogus.example.invalid:port/', 'openid.assoc_handle': 'FLUB', 'openid.return_to': 'Neverland', - 'ünicöde_key' : 'ünicöde_välüe', - } + 'ünicöde_key': 'ünicöde_välüe', + } m = message.Message.fromPostArgs(postargs) # Calling m.toFormMarkup with lxml used for ElementTree will throw # a ValueError. @@ -930,7 +914,6 @@ def test_overrideRequired(self): self._checkForm(html, m, self.action_url, tag_attrs, self.submit_text) - def test_setOpenIDNamespace_invalid(self): m = message.Message() invalid_things = [ @@ -944,19 +927,18 @@ def test_setOpenIDNamespace_invalid(self): 'https%3A%2F%2Fround-lake.dustinice.workers.dev%3A443%2Fhttp%2Fspecs.openid.net%2Fauth%2F2.0', # This is a Type URI, not a openid.ns value. 'http://specs.openid.net/auth/2.0/signon', - ] + ] for x in invalid_things: self.failUnlessRaises(message.InvalidOpenIDNamespace, m.setOpenIDNamespace, x, False) - def test_isOpenID1(self): v1_namespaces = [ # Yes, there are two of them. 'http://openid.net/signon/1.1', 'http://openid.net/signon/1.0', - ] + ] for ns in v1_namespaces: m = message.Message(ns) @@ -983,14 +965,13 @@ def test_setOpenIDNamespace_implicit(self): m.setOpenIDNamespace(message.THE_OTHER_OPENID1_NS, True) self.failUnless(m.namespaces.isImplicit(message.THE_OTHER_OPENID1_NS)) - def test_explicitOpenID11NSSerialzation(self): m = message.Message() m.setOpenIDNamespace(message.THE_OTHER_OPENID1_NS, implicit=False) post_args = m.toPostArgs() self.failUnlessEqual(post_args, - {'openid.ns':message.THE_OTHER_OPENID1_NS}) + {'openid.ns': message.THE_OTHER_OPENID1_NS}) def test_fromPostArgs_ns11(self): # An example of the stuff that some Drupal installations send us, @@ -1005,12 +986,11 @@ def test_fromPostArgs_ns11(self): u'openid.return_to': u'http://drupal.invalid/return_to', u'openid.sreg.required': u'nickname,email', u'openid.trust_root': u'http://drupal.invalid', - } + } m = message.Message.fromPostArgs(query) self.failUnless(m.isOpenID1()) - class NamespaceMapTest(unittest.TestCase): def test_onealias(self): nsm = message.NamespaceMap() @@ -1024,16 +1004,16 @@ def test_iteration(self): nsm = message.NamespaceMap() uripat = 'http://example.com/foo%r' - nsm.add(uripat%0) - for n in range(1,23): - self.failUnless(uripat%(n-1) in nsm) - self.failUnless(nsm.isDefined(uripat%(n-1))) - nsm.add(uripat%n) + nsm.add(uripat % 0) + for n in range(1, 23): + self.failUnless(uripat % (n - 1) in nsm) + self.failUnless(nsm.isDefined(uripat % (n - 1))) + nsm.add(uripat % n) for (uri, alias) in nsm.iteritems(): - self.failUnless(uri[22:]==alias[3:]) + self.failUnless(uri[22:] == alias[3:]) - i=0 + i = 0 it = nsm.iterAliases() try: while True: @@ -1042,7 +1022,7 @@ def test_iteration(self): except StopIteration: self.failUnless(i == 23) - i=0 + i = 0 it = nsm.iterNamespaceURIs() try: while True: diff --git a/openid/test/test_negotiation.py b/openid/test/test_negotiation.py index c23ef96e..8936ecd8 100644 --- a/openid/test/test_negotiation.py +++ b/openid/test/test_negotiation.py @@ -1,10 +1,9 @@ - import unittest from openid import association from openid.consumer.consumer import GenericConsumer, ServerError from openid.consumer.discover import OPENID_2_0_TYPE, OpenIDServiceEndpoint -from openid.message import OPENID1_NS, OPENID2_NS, OPENID_NS, Message +from openid.message import OPENID1_NS, OPENID_NS, Message from .support import CatchLogs @@ -29,11 +28,13 @@ def _requestAssociation(self, endpoint, assoc_type, session_type): else: return m + class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): """ Test the session type negotiation behavior of an OpenID 2 consumer. """ + def setUp(self): CatchLogs.setUp(self) self.consumer = ErrorRaisingConsumer(store=None) @@ -141,7 +142,7 @@ def testUnsupportedWithRetryAndFail(self): msg.setArg(OPENID_NS, 'session_type', 'DH-SHA1') self.consumer.return_messages = [msg, - Message(self.endpoint.preferredNamespace())] + Message(self.endpoint.preferredNamespace())] self.failUnlessEqual(self.consumer._negotiateAssociation(self.endpoint), None) @@ -160,6 +161,7 @@ def testValid(self): self.failUnless(self.consumer._negotiateAssociation(self.endpoint) is assoc) self.failUnlessLogEmpty() + class TestOpenID1SessionNegotiation(unittest.TestCase, CatchLogs): """ Tests for the OpenID 1 consumer association session behavior. See @@ -170,6 +172,7 @@ class is not a subclass of the OpenID 2 tests. Instead, it uses these tests pass openid2-style messages to the openid 1 association processing logic to be sure it ignores the extra data. """ + def setUp(self): CatchLogs.setUp(self) self.consumer = ErrorRaisingConsumer(store=None) @@ -247,12 +250,13 @@ def testValid(self): self.failUnless(self.consumer._negotiateAssociation(self.endpoint) is assoc) self.failUnlessLogEmpty() + class TestNegotiatorBehaviors(unittest.TestCase, CatchLogs): def setUp(self): self.allowed_types = [ ('HMAC-SHA1', 'no-encryption'), ('HMAC-SHA256', 'no-encryption'), - ] + ] self.n = association.SessionNegotiator(self.allowed_types) @@ -269,5 +273,6 @@ def testAddAllowedTypeContents(self): for typ in association.getSessionTypes(assoc_type): self.failUnless((assoc_type, typ) in self.n.allowed_types) + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_nonce.py b/openid/test/test_nonce.py index fe171512..7b271346 100644 --- a/openid/test/test_nonce.py +++ b/openid/test/test_nonce.py @@ -1,5 +1,4 @@ import re -import time import unittest from openid.store.nonce import checkTimestamp, mkNonce, split as splitNonce @@ -7,6 +6,7 @@ nonce_re = re.compile(r'\A\d{4}-\d\d-\d\dT\d\d:\d\d:\d\dZ') + class NonceTest(unittest.TestCase): def test_mkNonce(self): nonce = mkNonce() @@ -35,6 +35,7 @@ def test_mkSplit(self): self.failUnlessEqual(len(salt), 6) self.failUnlessEqual(et, t) + class BadSplitTest(datadriven.DataDrivenTestCase): cases = [ '', @@ -44,7 +45,7 @@ class BadSplitTest(datadriven.DataDrivenTestCase): '1970.01-01T00:00:00Z', 'Thu Sep 7 13:29:31 PDT 2006', 'monkeys', - ] + ] def __init__(self, nonce_str): datadriven.DataDrivenTestCase.__init__(self, nonce_str) @@ -53,6 +54,7 @@ def __init__(self, nonce_str): def runOneTest(self): self.failUnlessRaises(ValueError, splitNonce, self.nonce_str) + class CheckTimestampTest(datadriven.DataDrivenTestCase): cases = [ # exact, no allowed skew @@ -78,7 +80,7 @@ class CheckTimestampTest(datadriven.DataDrivenTestCase): # malformed nonce string ('monkeys', 0, 0, False), - ] + ] def __init__(self, nonce_string, allowed_skew, now, expected): datadriven.DataDrivenTestCase.__init__( @@ -92,9 +94,11 @@ def runOneTest(self): actual = checkTimestamp(self.nonce_string, self.allowed_skew, self.now) self.failUnlessEqual(bool(self.expected), bool(actual)) + def pyUnitTests(): return datadriven.loadTests(__name__) + if __name__ == '__main__': suite = pyUnitTests() runner = unittest.TextTestRunner() diff --git a/openid/test/test_openidyadis.py b/openid/test/test_openidyadis.py index 4b7cca4a..16aebea1 100644 --- a/openid/test/test_openidyadis.py +++ b/openid/test/test_openidyadis.py @@ -15,9 +15,11 @@ ''' + def mkXRDS(services): return XRDS_BOILERPLATE % (services,) + def mkService(uris=None, type_uris=None, local_id=None, dent=' '): chunks = [dent, '\n'] dent2 = dent + ' ' @@ -27,7 +29,7 @@ def mkService(uris=None, type_uris=None, local_id=None, dent=' '): if uris: for uri in uris: - if type(uri) is tuple: + if isinstance(uri, tuple): uri, prio = uri else: prio = None @@ -45,18 +47,21 @@ def mkService(uris=None, type_uris=None, local_id=None, dent=' '): return ''.join(chunks) + # Different sets of server URLs for use in the URI tag server_url_options = [ - [], # This case should not generate an endpoint object + [], # This case should not generate an endpoint object ['http://server.url/'], ['https://server.url/'], ['https://server.url/', 'http://server.url/'], ['https://server.url/', 'http://server.url/', 'http://example.server.url/'], - ] +] # Used for generating test data + + def subsets(l): """Generate all non-empty sublists of a list""" subsets_list = [[]] @@ -64,12 +69,13 @@ def subsets(l): subsets_list += [[x] + t for t in subsets_list] return subsets_list + # A couple of example extension type URIs. These are not at all # official, but are just here for testing. ext_types = [ 'http://janrain.com/extension/blah', 'http://openid.net/sreg/1.0', - ] +] # All valid combinations of Type tags that should produce an OpenID endpoint type_uri_options = [ @@ -81,14 +87,14 @@ def subsets(l): # All combinations of extension types (including empty extenstion list) for exts in subsets(ext_types) - ] +] # Range of valid Delegate tag values for generating test data local_id_options = [ None, 'http://vanity.domain/', 'https://somewhere/yadis/', - ] +] # All combinations of valid URIs, Type URIs and Delegate tags data = [ @@ -96,7 +102,8 @@ def subsets(l): for uris in server_url_options for type_uris in type_uri_options for local_id in local_id_options - ] +] + class OpenIDYadisTest(unittest.TestCase): def __init__(self, uris, type_uris, local_id): @@ -129,8 +136,7 @@ def runTest(self): self.failUnlessEqual(len(self.uris), len(endpoints)) # So that we can check equality on the endpoint types - type_uris = list(self.type_uris) - type_uris.sort() + type_uris = sorted(self.type_uris) seen_uris = [] for endpoint in endpoints: @@ -143,19 +149,18 @@ def runTest(self): self.failUnlessEqual(self.local_id, endpoint.local_id) # and types - actual_types = list(endpoint.type_uris) - actual_types.sort() + actual_types = sorted(endpoint.type_uris) self.failUnlessEqual(actual_types, type_uris) # So that they will compare equal, because we don't care what # order they are in seen_uris.sort() - uris = list(self.uris) - uris.sort() + uris = sorted(self.uris) # Make sure we saw all URIs, and saw each one once self.failUnlessEqual(uris, seen_uris) + def pyUnitTests(): cases = [] for args in data: diff --git a/openid/test/test_pape_draft2.py b/openid/test/test_pape_draft2.py index f468015b..be76550b 100644 --- a/openid/test/test_pape_draft2.py +++ b/openid/test/test_pape_draft2.py @@ -1,8 +1,7 @@ - import unittest from openid.extensions.draft import pape2 as pape -from openid.message import * +from openid.message import OPENID2_NS, Message from openid.server import server @@ -39,14 +38,15 @@ def test_getExtensionArgs(self): self.req.addPolicyURI('http://zig') self.failUnlessEqual({'preferred_auth_policies': 'http://uri http://zig'}, self.req.getExtensionArgs()) self.req.max_auth_age = 789 - self.failUnlessEqual({'preferred_auth_policies': 'http://uri http://zig', 'max_auth_age': '789'}, self.req.getExtensionArgs()) + self.failUnlessEqual({'preferred_auth_policies': 'http://uri http://zig', 'max_auth_age': '789'}, + self.req.getExtensionArgs()) def test_parseExtensionArgs(self): args = {'preferred_auth_policies': 'http://foo http://bar', 'max_auth_age': '9'} self.req.parseExtensionArgs(args) self.failUnlessEqual(9, self.req.max_auth_age) - self.failUnlessEqual(['http://foo','http://bar'], self.req.preferred_auth_policies) + self.failUnlessEqual(['http://foo', 'http://bar'], self.req.preferred_auth_policies) def test_parseExtensionArgs_empty(self): self.req.parseExtensionArgs({}) @@ -55,12 +55,12 @@ def test_parseExtensionArgs_empty(self): def test_fromOpenIDRequest(self): openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'checkid_setup', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.preferred_auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), - 'pape.max_auth_age': '5476' - }) + 'mode': 'checkid_setup', + 'ns': OPENID2_NS, + 'ns.pape': pape.ns_uri, + 'pape.preferred_auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), + 'pape.max_auth_age': '5476' + }) oid_req = server.OpenIDRequest() oid_req.message = openid_req_msg req = pape.Request.fromOpenIDRequest(oid_req) @@ -81,6 +81,7 @@ def test_preferred_types(self): pape.AUTH_MULTI_FACTOR_PHYSICAL]) self.failUnlessEqual([pape.AUTH_MULTI_FACTOR], pt) + class DummySuccessResponse: def __init__(self, message, signed_stuff): self.message = message @@ -89,6 +90,7 @@ def __init__(self, message, signed_stuff): def getSignedNS(self, ns_uri): return self.signed_stuff + class PapeResponseTestCase(unittest.TestCase): def setUp(self): self.req = pape.Response() @@ -122,9 +124,13 @@ def test_getExtensionArgs(self): self.req.addPolicyURI('http://zig') self.failUnlessEqual({'auth_policies': 'http://uri http://zig'}, self.req.getExtensionArgs()) self.req.auth_time = "1776-07-04T14:43:12Z" - self.failUnlessEqual({'auth_policies': 'http://uri http://zig', 'auth_time': "1776-07-04T14:43:12Z"}, self.req.getExtensionArgs()) + self.failUnlessEqual({'auth_policies': 'http://uri http://zig', 'auth_time': "1776-07-04T14:43:12Z"}, + self.req.getExtensionArgs()) self.req.nist_auth_level = 3 - self.failUnlessEqual({'auth_policies': 'http://uri http://zig', 'auth_time': "1776-07-04T14:43:12Z", 'nist_auth_level': '3'}, self.req.getExtensionArgs()) + self.failUnlessEqual({'auth_policies': 'http://uri http://zig', + 'auth_time': "1776-07-04T14:43:12Z", + 'nist_auth_level': '3'}, + self.req.getExtensionArgs()) def test_getExtensionArgs_error_auth_age(self): self.req.auth_time = "long ago" @@ -143,13 +149,13 @@ def test_parseExtensionArgs(self): 'auth_time': '1970-01-01T00:00:00Z'} self.req.parseExtensionArgs(args) self.failUnlessEqual('1970-01-01T00:00:00Z', self.req.auth_time) - self.failUnlessEqual(['http://foo','http://bar'], self.req.auth_policies) + self.failUnlessEqual(['http://foo', 'http://bar'], self.req.auth_policies) def test_parseExtensionArgs_empty(self): self.req.parseExtensionArgs({}) self.failUnlessEqual(None, self.req.auth_time) self.failUnlessEqual([], self.req.auth_policies) - + def test_parseExtensionArgs_strict_bogus1(self): args = {'auth_policies': 'http://foo http://bar', 'auth_time': 'yesterday'} @@ -162,13 +168,13 @@ def test_parseExtensionArgs_strict_bogus2(self): 'nist_auth_level': 'some'} self.failUnlessRaises(ValueError, self.req.parseExtensionArgs, args, True) - + def test_parseExtensionArgs_strict_good(self): args = {'auth_policies': 'http://foo http://bar', 'auth_time': '1970-01-01T00:00:00Z', 'nist_auth_level': '0'} self.req.parseExtensionArgs(args, True) - self.failUnlessEqual(['http://foo','http://bar'], self.req.auth_policies) + self.failUnlessEqual(['http://foo', 'http://bar'], self.req.auth_policies) self.failUnlessEqual('1970-01-01T00:00:00Z', self.req.auth_time) self.failUnlessEqual(0, self.req.nist_auth_level) @@ -177,21 +183,21 @@ def test_parseExtensionArgs_nostrict_bogus(self): 'auth_time': 'when the cows come home', 'nist_auth_level': 'some'} self.req.parseExtensionArgs(args) - self.failUnlessEqual(['http://foo','http://bar'], self.req.auth_policies) + self.failUnlessEqual(['http://foo', 'http://bar'], self.req.auth_policies) self.failUnlessEqual(None, self.req.auth_time) self.failUnlessEqual(None, self.req.nist_auth_level) def test_fromSuccessResponse(self): openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'id_res', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), - 'pape.auth_time': '1970-01-01T00:00:00Z' - }) + 'mode': 'id_res', + 'ns': OPENID2_NS, + 'ns.pape': pape.ns_uri, + 'pape.auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), + 'pape.auth_time': '1970-01-01T00:00:00Z' + }) signed_stuff = { - 'auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), - 'auth_time': '1970-01-01T00:00:00Z' + 'auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), + 'auth_time': '1970-01-01T00:00:00Z' } oid_req = DummySuccessResponse(openid_req_msg, signed_stuff) req = pape.Response.fromSuccessResponse(oid_req) @@ -200,12 +206,12 @@ def test_fromSuccessResponse(self): def test_fromSuccessResponseNoSignedArgs(self): openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'id_res', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), - 'pape.auth_time': '1970-01-01T00:00:00Z' - }) + 'mode': 'id_res', + 'ns': OPENID2_NS, + 'ns.pape': pape.ns_uri, + 'pape.auth_policies': ' '.join([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]), + 'pape.auth_time': '1970-01-01T00:00:00Z' + }) signed_stuff = {} diff --git a/openid/test/test_pape_draft5.py b/openid/test/test_pape_draft5.py index 9693fad9..243eae5d 100644 --- a/openid/test/test_pape_draft5.py +++ b/openid/test/test_pape_draft5.py @@ -1,9 +1,8 @@ - import unittest import warnings from openid.extensions.draft import pape5 as pape -from openid.message import * +from openid.message import OPENID2_NS, Message from openid.server import server warnings.filterwarnings('ignore', module=__name__, @@ -111,7 +110,7 @@ def test_getExtensionArgsWithAuthLevels(self): ('auth_level.ns.%s' % alias2): uri2, 'preferred_auth_level_types': ' '.join([alias, alias2]), 'preferred_auth_policies': '', - } + } self.failUnlessEqual(expected_args, self.req.getExtensionArgs()) @@ -127,7 +126,7 @@ def test_parseExtensionArgsWithAuthLevels(self): ('auth_level.ns.%s' % alias2): uri2, 'preferred_auth_level_types': ' '.join([alias, alias2]), 'preferred_auth_policies': '', - } + } # Check request object state self.req.parseExtensionArgs(request_args, is_openid1=False, strict=False) @@ -141,8 +140,8 @@ def test_parseExtensionArgsWithAuthLevels(self): def test_parseExtensionArgsWithAuthLevels_openID1(self): request_args = { - 'preferred_auth_level_types':'nist jisa', - } + 'preferred_auth_level_types': 'nist jisa', + } expected_auth_levels = [pape.LEVELS_NIST, pape.LEVELS_JISA] self.req.parseExtensionArgs(request_args, is_openid1=True) self.assertEqual(expected_auth_levels, @@ -159,12 +158,12 @@ def test_parseExtensionArgsWithAuthLevels_openID1(self): request_args, is_openid1=False, strict=True) def test_parseExtensionArgs_ignoreBadAuthLevels(self): - request_args = {'preferred_auth_level_types':'monkeys'} + request_args = {'preferred_auth_level_types': 'monkeys'} self.req.parseExtensionArgs(request_args, False) self.assertEqual([], self.req.preferred_auth_level_types) def test_parseExtensionArgs_strictBadAuthLevels(self): - request_args = {'preferred_auth_level_types':'monkeys'} + request_args = {'preferred_auth_level_types': 'monkeys'} self.failUnlessRaises(ValueError, self.req.parseExtensionArgs, request_args, is_openid1=False, strict=True) @@ -173,7 +172,7 @@ def test_parseExtensionArgs(self): 'max_auth_age': '9'} self.req.parseExtensionArgs(args, False) self.failUnlessEqual(9, self.req.max_auth_age) - self.failUnlessEqual(['http://foo','http://bar'], + self.failUnlessEqual(['http://foo', 'http://bar'], self.req.preferred_auth_policies) self.failUnlessEqual([], self.req.preferred_auth_level_types) @@ -191,12 +190,12 @@ def test_parseExtensionArgs_empty(self): def test_fromOpenIDRequest(self): policy_uris = [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT] openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'checkid_setup', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.preferred_auth_policies': ' '.join(policy_uris), - 'pape.max_auth_age': '5476' - }) + 'mode': 'checkid_setup', + 'ns': OPENID2_NS, + 'ns.pape': pape.ns_uri, + 'pape.preferred_auth_policies': ' '.join(policy_uris), + 'pape.max_auth_age': '5476' + }) oid_req = server.OpenIDRequest() oid_req.message = openid_req_msg req = pape.Request.fromOpenIDRequest(oid_req) @@ -217,6 +216,7 @@ def test_preferred_types(self): pape.AUTH_MULTI_FACTOR_PHYSICAL]) self.failUnlessEqual([pape.AUTH_MULTI_FACTOR], pt) + class DummySuccessResponse: def __init__(self, message, signed_stuff): self.message = message @@ -228,6 +228,7 @@ def isOpenID1(self): def getSignedNS(self, ns_uri): return self.signed_stuff + class PapeResponseTestCase(unittest.TestCase): def setUp(self): self.resp = pape.Response() @@ -293,7 +294,7 @@ def test_parseExtensionArgs(self): 'auth_time': '1970-01-01T00:00:00Z'} self.resp.parseExtensionArgs(args, is_openid1=False) self.failUnlessEqual('1970-01-01T00:00:00Z', self.resp.auth_time) - self.failUnlessEqual(['http://foo','http://bar'], + self.failUnlessEqual(['http://foo', 'http://bar'], self.resp.auth_policies) def test_parseExtensionArgs_valid_none(self): @@ -327,7 +328,7 @@ def test_parseExtensionArgs_ignore_superfluous_none(self): args = { 'auth_policies': ' '.join(policies), - } + } self.resp.parseExtensionArgs(args, is_openid1=False, strict=False) @@ -339,7 +340,7 @@ def test_parseExtensionArgs_none_strict(self): args = { 'auth_policies': ' '.join(policies), - } + } self.failUnlessRaises(ValueError, self.resp.parseExtensionArgs, args, is_openid1=False, strict=True) @@ -385,7 +386,7 @@ def test_parseExtensionArgs_strict_good(self): 'auth_level.nist': '0', 'auth_level.ns.nist': pape.LEVELS_NIST} self.resp.parseExtensionArgs(args, is_openid1=False, strict=True) - self.failUnlessEqual(['http://foo','http://bar'], + self.failUnlessEqual(['http://foo', 'http://bar'], self.resp.auth_policies) self.failUnlessEqual('1970-01-01T00:00:00Z', self.resp.auth_time) self.failUnlessEqual(0, self.resp.nist_auth_level) @@ -395,7 +396,7 @@ def test_parseExtensionArgs_nostrict_bogus(self): 'auth_time': 'when the cows come home', 'nist_auth_level': 'some'} self.resp.parseExtensionArgs(args, is_openid1=False) - self.failUnlessEqual(['http://foo','http://bar'], + self.failUnlessEqual(['http://foo', 'http://bar'], self.resp.auth_policies) self.failUnlessEqual(None, self.resp.auth_time) self.failUnlessEqual(None, self.resp.nist_auth_level) @@ -403,15 +404,15 @@ def test_parseExtensionArgs_nostrict_bogus(self): def test_fromSuccessResponse(self): policy_uris = [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT] openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'id_res', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.auth_policies': ' '.join(policy_uris), - 'pape.auth_time': '1970-01-01T00:00:00Z' - }) + 'mode': 'id_res', + 'ns': OPENID2_NS, + 'ns.pape': pape.ns_uri, + 'pape.auth_policies': ' '.join(policy_uris), + 'pape.auth_time': '1970-01-01T00:00:00Z' + }) signed_stuff = { - 'auth_policies': ' '.join(policy_uris), - 'auth_time': '1970-01-01T00:00:00Z' + 'auth_policies': ' '.join(policy_uris), + 'auth_time': '1970-01-01T00:00:00Z' } oid_req = DummySuccessResponse(openid_req_msg, signed_stuff) req = pape.Response.fromSuccessResponse(oid_req) @@ -421,12 +422,12 @@ def test_fromSuccessResponse(self): def test_fromSuccessResponseNoSignedArgs(self): policy_uris = [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT] openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'id_res', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.auth_policies': ' '.join(policy_uris), - 'pape.auth_time': '1970-01-01T00:00:00Z' - }) + 'mode': 'id_res', + 'ns': OPENID2_NS, + 'ns.pape': pape.ns_uri, + 'pape.auth_policies': ' '.join(policy_uris), + 'pape.auth_time': '1970-01-01T00:00:00Z' + }) signed_stuff = {} @@ -438,5 +439,6 @@ def getSignedNS(self, ns_uri): resp = pape.Response.fromSuccessResponse(oid_req) self.failUnless(resp is None) + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_parsehtml.py b/openid/test/test_parsehtml.py index fe90ac71..a3c038da 100644 --- a/openid/test/test_parsehtml.py +++ b/openid/test/test_parsehtml.py @@ -18,9 +18,9 @@ def __init__(self, filename, testname, expected, case): def runTest(self): p = YadisHTMLParser() - try: + try: p.feed(self.case) - except ParseDone, why: + except ParseDone as why: found = why[0] # make sure we protect outselves against accidental bogus @@ -44,6 +44,7 @@ def shortDescription(self): self.__class__.__module__, os.path.basename(self.filename)) + def parseCases(data): cases = [] for chunk in data.split('\f\n'): @@ -51,6 +52,7 @@ def parseCases(data): cases.append((expected, case)) return cases + def pyUnitTests(): """Make a pyunit TestSuite from a file defining test cases.""" s = unittest.TestSuite() @@ -58,10 +60,12 @@ def pyUnitTests(): s.addTest(_TestCase(filename, str(test_num), expected, case)) return s + def test(): runner = unittest.TextTestRunner() return runner.run(pyUnitTests()) + filenames = ['data/test1-parsehtml.txt'] default_test_files = [] @@ -70,6 +74,7 @@ def test(): full_name = os.path.join(base, filename) default_test_files.append(full_name) + def getCases(test_files=default_test_files): cases = [] for filename in test_files: diff --git a/openid/test/test_rpverify.py b/openid/test/test_rpverify.py index d37b5949..c1069818 100644 --- a/openid/test/test_rpverify.py +++ b/openid/test/test_rpverify.py @@ -11,8 +11,6 @@ from openid.yadis.discover import DiscoveryFailure, DiscoveryResult -# Too many methods does not apply to unit test objects -#pylint:disable-msg=R0904 class TestBuildDiscoveryURL(unittest.TestCase): """Tests for building the discovery URL from a realm and a return_to URL @@ -44,6 +42,7 @@ def test_wildcard_port(self): self.failUnlessDiscoURL('http://*.example.com:8001/foo', 'http://www.example.com:8001/foo') + class TestExtractReturnToURLs(unittest.TestCase): disco_url = 'http://example.com/' @@ -141,8 +140,7 @@ def test_twoEntries(self): -''', ['http://rp.example.com/return', - 'http://other.rp.example.com/return']) +''', ['http://rp.example.com/return', 'http://other.rp.example.com/return']) def test_twoEntries_withOther(self): self.failUnlessXRDSHasReturnURLs('''\ @@ -165,9 +163,7 @@ def test_twoEntries_withOther(self): -''', ['http://rp.example.com/return', - 'http://other.rp.example.com/return']) - +''', ['http://rp.example.com/return', 'http://other.rp.example.com/return']) class TestReturnToMatches(unittest.TestCase): @@ -203,6 +199,7 @@ def test_noMatch(self): [r], 'http://example.com/xss_exploit')) + class TestVerifyReturnTo(unittest.TestCase, CatchLogs): def setUp(self): @@ -210,7 +207,7 @@ def setUp(self): def tearDown(self): CatchLogs.tearDown(self) - + def test_bogusRealm(self): self.failIf(trustroot.verifyReturnTo('', 'http://example.com/')) @@ -250,5 +247,6 @@ def vrfy(disco_url): trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) self.failUnlessLogMatches("Attempting to verify") + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_server.py b/openid/test/test_server.py index a19a734e..171be4c2 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -2,9 +2,11 @@ """ import cgi import unittest +from functools import partial from urlparse import urlparse from openid import association, cryptutil, oidutil +from openid.consumer.consumer import DiffieHellmanSHA256ConsumerSession from openid.message import IDENTIFIER_SELECT, OPENID1_NS, OPENID1_URL_LIMIT, OPENID2_NS, OPENID_NS, Message, no_default from openid.server import server from openid.store import memstore @@ -16,9 +18,13 @@ # for more, see /etc/ssh/moduli -ALT_MODULUS = 0xCAADDDEC1667FC68B5FA15D53C4E1532DD24561A1A2D47A12C01ABEA1E00731F6921AAC40742311FDF9E634BB7131BEE1AF240261554389A910425E044E88C8359B010F5AD2B80E29CB1A5B027B19D9E01A6F63A6F45E5D7ED2FF6A2A0085050A7D0CF307C3DB51D2490355907B4427C23A98DF1EB8ABEF2BA209BB7AFFE86A7 +ALT_MODULUS = int('1423261515703355186607439952816216983770573549498844689430217675736088990483613604225135575535147900' + '4551229946895343158530081254885941985717109436635815890343316791551733211386105974742540867014420109' + '9811846875730766487278261498262568348338476437200556998366087779709990807518291581860338635288400119' + '293970087') ALT_GEN = 5 + class TestProtocolError(unittest.TestCase): def test_browserWithReturnTo(self): return_to = "http://rp.unittest/consumer" @@ -27,13 +33,13 @@ def test_browserWithReturnTo(self): 'openid.mode': 'monkeydance', 'openid.identity': 'http://wagu.unittest/', 'openid.return_to': return_to, - }) + }) e = server.ProtocolError(args, "plucky") self.failUnless(e.hasReturnTo()) expected_args = { 'openid.mode': ['error'], 'openid.error': ['plucky'], - } + } rt_base, result_args = e.encodeToURL().split('?', 1) result_args = cgi.parse_qs(result_args) @@ -48,14 +54,14 @@ def test_browserWithReturnTo_OpenID2_GET(self): 'openid.identity': 'http://wagu.unittest/', 'openid.claimed_id': 'http://wagu.unittest/', 'openid.return_to': return_to, - }) + }) e = server.ProtocolError(args, "plucky") self.failUnless(e.hasReturnTo()) expected_args = { 'openid.ns': [OPENID2_NS], 'openid.mode': ['error'], 'openid.error': ['plucky'], - } + } rt_base, result_args = e.encodeToURL().split('?', 1) result_args = cgi.parse_qs(result_args) @@ -70,15 +76,9 @@ def test_browserWithReturnTo_OpenID2_POST(self): 'openid.identity': 'http://wagu.unittest/', 'openid.claimed_id': 'http://wagu.unittest/', 'openid.return_to': return_to, - }) + }) e = server.ProtocolError(args, "plucky") self.failUnless(e.hasReturnTo()) - expected_args = { - 'openid.ns': [OPENID2_NS], - 'openid.mode': ['error'], - 'openid.error': ['plucky'], - } - self.failUnless(e.whichEncoding() == server.ENCODE_HTML_FORM) self.failUnless(e.toFormMarkup() == e.toMessage().toFormMarkup( args.getArg(OPENID_NS, 'return_to'))) @@ -90,13 +90,13 @@ def test_browserWithReturnTo_OpenID1_exceeds_limit(self): 'openid.mode': 'monkeydance', 'openid.identity': 'http://wagu.unittest/', 'openid.return_to': return_to, - }) + }) e = server.ProtocolError(args, "plucky") self.failUnless(e.hasReturnTo()) expected_args = { 'openid.mode': ['error'], 'openid.error': ['plucky'], - } + } self.failUnless(e.whichEncoding() == server.ENCODE_URL) @@ -109,7 +109,7 @@ def test_noReturnTo(self): args = Message.fromPostArgs({ 'openid.mode': 'zebradance', 'openid.identity': 'http://wagu.unittest/', - }) + }) e = server.ProtocolError(args, "waffles") self.failIf(e.hasReturnTo()) expected = """error:waffles @@ -117,7 +117,6 @@ def test_noReturnTo(self): """ self.failUnlessEqual(e.encodeToKVForm(), expected) - def test_noMessage(self): e = server.ProtocolError(None, "no moar pancakes") self.failIf(e.hasReturnTo()) @@ -146,14 +145,14 @@ def test_irrelevant(self): args = { 'pony': 'spotted', 'sreg.mutant_power': 'decaffinator', - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) def test_bad(self): args = { 'openid.mode': 'twos-compliment', 'openid.pants': 'zippered', - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) def test_dictOfLists(self): @@ -163,10 +162,10 @@ def test_dictOfLists(self): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.trust_root': self.tr_url, - } + } try: result = self.decode(args) - except TypeError, err: + except TypeError as err: self.failUnless(str(err).find('values') != -1, err) else: self.fail("Expected TypeError, but got result %s" % (result,)) @@ -180,7 +179,7 @@ def test_checkidImmediate(self): 'openid.trust_root': self.tr_url, # should be ignored 'openid.some.extension': 'junk', - } + } r = self.decode(args) self.failUnless(isinstance(r, server.CheckIDRequest)) self.failUnlessEqual(r.mode, "checkid_immediate") @@ -197,7 +196,7 @@ def test_checkidSetup(self): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.trust_root': self.tr_url, - } + } r = self.decode(args) self.failUnless(isinstance(r, server.CheckIDRequest)) self.failUnlessEqual(r.mode, "checkid_setup") @@ -215,7 +214,7 @@ def test_checkidSetupOpenID2(self): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.realm': self.tr_url, - } + } r = self.decode(args) self.failUnless(isinstance(r, server.CheckIDRequest)) self.failUnlessEqual(r.mode, "checkid_setup") @@ -233,7 +232,7 @@ def test_checkidSetupNoClaimedIDOpenID2(self): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.realm': self.tr_url, - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) def test_checkidSetupNoIdentityOpenID2(self): @@ -243,7 +242,7 @@ def test_checkidSetupNoIdentityOpenID2(self): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.realm': self.tr_url, - } + } r = self.decode(args) self.failUnless(isinstance(r, server.CheckIDRequest)) self.failUnlessEqual(r.mode, "checkid_setup") @@ -261,7 +260,7 @@ def test_checkidSetupNoReturnOpenID1(self): 'openid.identity': self.id_url, 'openid.assoc_handle': self.assoc_handle, 'openid.trust_root': self.tr_url, - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) def test_checkidSetupNoReturnOpenID2(self): @@ -276,7 +275,7 @@ def test_checkidSetupNoReturnOpenID2(self): 'openid.claimed_id': self.id_url, 'openid.assoc_handle': self.assoc_handle, 'openid.realm': self.tr_url, - } + } self.failUnless(isinstance(self.decode(args), server.CheckIDRequest)) req = self.decode(args) @@ -294,7 +293,7 @@ def test_checkidSetupRealmRequiredOpenID2(self): 'openid.mode': 'checkid_setup', 'openid.identity': self.id_url, 'openid.assoc_handle': self.assoc_handle, - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) def test_checkidSetupBadReturn(self): @@ -303,10 +302,10 @@ def test_checkidSetupBadReturn(self): 'openid.identity': self.id_url, 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': 'not a url', - } + } try: result = self.decode(args) - except server.ProtocolError, err: + except server.ProtocolError as err: self.failUnless(err.openid_message) else: self.fail("Expected ProtocolError, instead returned with %s" % @@ -319,10 +318,10 @@ def test_checkidSetupUntrustedReturn(self): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': self.rt_url, 'openid.trust_root': 'http://not-the-return-place.unittest/', - } + } try: result = self.decode(args) - except server.UntrustedReturnURL, err: + except server.UntrustedReturnURL as err: self.failUnless(err.openid_message) else: self.fail("Expected UntrustedReturnURL, instead returned with %s" % @@ -338,13 +337,12 @@ def test_checkAuth(self): 'openid.return_to': 'signedval2', 'openid.response_nonce': 'signedval3', 'openid.baz': 'unsigned', - } + } r = self.decode(args) self.failUnless(isinstance(r, server.CheckAuthRequest)) self.failUnlessEqual(r.mode, 'check_authentication') self.failUnlessEqual(r.sig, 'sigblob') - def test_checkAuthMissingSignature(self): args = { 'openid.mode': 'check_authentication', @@ -353,10 +351,9 @@ def test_checkAuthMissingSignature(self): 'openid.foo': 'signedval1', 'openid.bar': 'signedval2', 'openid.baz': 'unsigned', - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) - def test_checkAuthAndInvalidate(self): args = { 'openid.mode': 'check_authentication', @@ -368,18 +365,17 @@ def test_checkAuthAndInvalidate(self): 'openid.return_to': 'signedval2', 'openid.response_nonce': 'signedval3', 'openid.baz': 'unsigned', - } + } r = self.decode(args) self.failUnless(isinstance(r, server.CheckAuthRequest)) self.failUnlessEqual(r.invalidate_handle, '[[SMART_handle]]') - def test_associateDH(self): args = { 'openid.mode': 'associate', 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "Rzup9265tw==", - } + } r = self.decode(args) self.failUnless(isinstance(r, server.AssociateRequest)) self.failUnlessEqual(r.mode, "associate") @@ -392,20 +388,18 @@ def test_associateDHMissingKey(self): args = { 'openid.mode': 'associate', 'openid.session_type': 'DH-SHA1', - } + } # Using DH-SHA1 without supplying dh_consumer_public is an error. self.failUnlessRaises(server.ProtocolError, self.decode, args) - def test_associateDHpubKeyNotB64(self): args = { 'openid.mode': 'associate', 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "donkeydonkeydonkey", - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) - def test_associateDHModGen(self): # test dh with non-default but valid values for dh_modulus and dh_gen args = { @@ -413,8 +407,8 @@ def test_associateDHModGen(self): 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "Rzup9265tw==", 'openid.dh_modulus': cryptutil.longToBase64(ALT_MODULUS), - 'openid.dh_gen': cryptutil.longToBase64(ALT_GEN) , - } + 'openid.dh_gen': cryptutil.longToBase64(ALT_GEN), + } r = self.decode(args) self.failUnless(isinstance(r, server.AssociateRequest)) self.failUnlessEqual(r.mode, "associate") @@ -424,7 +418,6 @@ def test_associateDHModGen(self): self.failUnlessEqual(r.session.dh.generator, ALT_GEN) self.failUnless(r.session.consumer_pubkey) - def test_associateDHCorruptModGen(self): # test dh with non-default but valid values for dh_modulus and dh_gen args = { @@ -433,10 +426,9 @@ def test_associateDHCorruptModGen(self): 'openid.dh_consumer_public': "Rzup9265tw==", 'openid.dh_modulus': 'pizza', 'openid.dh_gen': 'gnocchi', - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) - def test_associateDHMissingModGen(self): # test dh with non-default but valid values for dh_modulus and dh_gen args = { @@ -444,7 +436,7 @@ def test_associateDHMissingModGen(self): 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "Rzup9265tw==", 'openid.dh_modulus': 'pizza', - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) @@ -461,20 +453,18 @@ def test_associateDHMissingModGen(self): # self.failUnlessRaises(server.ProtocolError, self.decode, args) # test_associateDHInvalidModGen.todo = "low-priority feature" - def test_associateWeirdSession(self): args = { 'openid.mode': 'associate', 'openid.session_type': 'FLCL6', 'openid.dh_consumer_public': "YQ==\n", - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) - def test_associatePlain(self): args = { 'openid.mode': 'associate', - } + } r = self.decode(args) self.failUnless(isinstance(r, server.AssociateRequest)) self.failUnlessEqual(r.mode, "associate") @@ -485,16 +475,16 @@ def test_nomode(self): args = { 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "my public keeey", - } + } self.failUnlessRaises(server.ProtocolError, self.decode, args) def test_invalidns(self): - args = {'openid.ns': 'Tuesday', - 'openid.mode': 'associate'} + args = {'openid.ns': 'Tuesday', + 'openid.mode': 'associate'} try: r = self.decode(args) - except server.ProtocolError, err: + except server.ProtocolError as err: # Assert that the ProtocolError does have a Message attached # to it, even though the request wasn't a well-formed Message. self.failUnless(err.openid_message) @@ -519,12 +509,12 @@ def test_id_res_OpenID2_GET(self): issued. """ request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ @@ -533,12 +523,12 @@ def test_id_res_OpenID2_GET(self): 'identity': request.identity, 'claimed_id': request.identity, 'return_to': request.return_to, - }) + }) self.failIf(response.renderAsForm()) self.failUnless(response.whichEncoding() == server.ENCODE_URL) webresponse = self.encode(response) - self.failUnless(webresponse.headers.has_key('location')) + self.assertIn('location', webresponse.headers) def test_id_res_OpenID2_POST(self): """ @@ -547,12 +537,12 @@ def test_id_res_OpenID2_POST(self): returned. """ request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ @@ -561,7 +551,7 @@ def test_id_res_OpenID2_POST(self): 'identity': request.identity, 'claimed_id': request.identity, 'return_to': 'x' * OPENID1_URL_LIMIT, - }) + }) self.failUnless(response.renderAsForm()) self.failUnless(len(response.encodeToURL()) > OPENID1_URL_LIMIT) @@ -571,12 +561,12 @@ def test_id_res_OpenID2_POST(self): def test_toFormMarkup(self): request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ @@ -585,19 +575,19 @@ def test_toFormMarkup(self): 'identity': request.identity, 'claimed_id': request.identity, 'return_to': 'x' * OPENID1_URL_LIMIT, - }) + }) - form_markup = response.toFormMarkup({'foo':'bar'}) + form_markup = response.toFormMarkup({'foo': 'bar'}) self.failUnless(' foo="bar"' in form_markup) def test_toHTML(self): request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ @@ -606,7 +596,7 @@ def test_toHTML(self): 'identity': request.identity, 'claimed_id': request.identity, 'return_to': 'x' * OPENID1_URL_LIMIT, - }) + }) html = response.toHTML() self.failUnless('' in html) self.failUnless('' in html) @@ -622,19 +612,19 @@ def test_id_res_OpenID1_exceeds_limit(self): place to preserve the status quo for OpenID 1. """ request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ 'mode': 'id_res', 'identity': request.identity, 'return_to': 'x' * OPENID1_URL_LIMIT, - }) + }) self.failIf(response.renderAsForm()) self.failUnless(len(response.encodeToURL()) > OPENID1_URL_LIMIT) @@ -644,22 +634,22 @@ def test_id_res_OpenID1_exceeds_limit(self): def test_id_res(self): request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ 'mode': 'id_res', 'identity': request.identity, 'return_to': request.return_to, - }) + }) webresponse = self.encode(response) self.failUnlessEqual(webresponse.code, server.HTTP_REDIRECT) - self.failUnless(webresponse.headers.has_key('location')) + self.assertIn('location', webresponse.headers) location = webresponse.headers['location'] self.failUnless(location.startswith(request.return_to), @@ -672,34 +662,34 @@ def test_id_res(self): def test_cancel(self): request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ 'mode': 'cancel', - }) + }) webresponse = self.encode(response) self.failUnlessEqual(webresponse.code, server.HTTP_REDIRECT) - self.failUnless(webresponse.headers.has_key('location')) + self.assertIn('location', webresponse.headers) def test_cancelToForm(self): request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ 'mode': 'cancel', - }) + }) form = response.toFormMarkup() self.failUnless(form) @@ -726,7 +716,7 @@ def test_checkauthReply(self): response.fields = Message.fromOpenIDArgs({ 'is_valid': 'true', 'invalidate_handle': 'xXxX:xXXx' - }) + }) body = """invalidate_handle:xXxX:xXXx is_valid:true """ @@ -738,7 +728,7 @@ def test_checkauthReply(self): def test_unencodableError(self): args = Message.fromPostArgs({ 'openid.identity': 'http://limu.unittest/', - }) + }) e = server.ProtocolError(args, "wet paint") self.failUnlessRaises(server.EncodingError, self.encode, e) @@ -746,15 +736,14 @@ def test_encodableError(self): args = Message.fromPostArgs({ 'openid.mode': 'associate', 'openid.identity': 'http://limu.unittest/', - }) - body="error:snoot\nmode:error\n" + }) + body = "error:snoot\nmode:error\n" webresponse = self.encode(server.ProtocolError(args, "snoot")) self.failUnlessEqual(webresponse.code, server.HTTP_ERROR) self.failUnlessEqual(webresponse.headers, {}) self.failUnlessEqual(webresponse.body, body) - class TestSigningEncode(unittest.TestCase): def setUp(self): self._dumb_key = server.Signatory._dumb_key @@ -762,19 +751,19 @@ def setUp(self): self.store = memstore.MemoryStore() self.server = server.Server(self.store, "http://signing.unittest/enc") self.request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) self.request.message = Message(OPENID2_NS) self.response = server.OpenIDResponse(self.request) self.response.fields = Message.fromOpenIDArgs({ 'mode': 'id_res', 'identity': self.request.identity, 'return_to': self.request.return_to, - }) + }) self.signatory = server.Signatory(self.store) self.encoder = server.SigningEncoder(self.signatory) self.encode = self.encoder.encode @@ -788,7 +777,7 @@ def test_idres(self): self.request.assoc_handle = assoc_handle webresponse = self.encode(self.response) self.failUnlessEqual(webresponse.code, server.HTTP_REDIRECT) - self.failUnless(webresponse.headers.has_key('location')) + self.assertIn('location', webresponse.headers) location = webresponse.headers['location'] query = cgi.parse_qs(urlparse(location)[4]) @@ -799,7 +788,7 @@ def test_idres(self): def test_idresDumb(self): webresponse = self.encode(self.response) self.failUnlessEqual(webresponse.code, server.HTTP_REDIRECT) - self.failUnless(webresponse.headers.has_key('location')) + self.assertIn('location', webresponse.headers) location = webresponse.headers['location'] query = cgi.parse_qs(urlparse(location)[4]) @@ -813,18 +802,18 @@ def test_forgotStore(self): def test_cancel(self): request = server.CheckIDRequest( - identity = 'http://bombom.unittest/', - trust_root = 'http://burr.unittest/', - return_to = 'http://burr.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bombom.unittest/', + trust_root='http://burr.unittest/', + return_to='http://burr.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) request.message = Message(OPENID2_NS) response = server.OpenIDResponse(request) response.fields.setArg(OPENID_NS, 'mode', 'cancel') webresponse = self.encode(response) self.failUnlessEqual(webresponse.code, server.HTTP_REDIRECT) - self.failUnless(webresponse.headers.has_key('location')) + self.assertIn('location', webresponse.headers) location = webresponse.headers['location'] query = cgi.parse_qs(urlparse(location)[4]) self.failIf('openid.sig' in query, response.fields.toPostArgs()) @@ -847,18 +836,19 @@ def test_alreadySigned(self): self.response.fields.setArg(OPENID_NS, 'sig', 'priorSig==') self.failUnlessRaises(server.AlreadySigned, self.encode, self.response) + class TestCheckID(unittest.TestCase): def setUp(self): self.op_endpoint = 'http://endpoint.unittest/' self.store = memstore.MemoryStore() self.server = server.Server(self.store, self.op_endpoint) self.request = server.CheckIDRequest( - identity = 'http://bambam.unittest/', - trust_root = 'http://bar.unittest/', - return_to = 'http://bar.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bambam.unittest/', + trust_root='http://bar.unittest/', + return_to='http://bar.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) self.request.message = Message(OPENID2_NS) def test_trustRootInvalid(self): @@ -878,7 +868,7 @@ def test_malformedTrustRoot(self): self.request.message = sentinel try: result = self.request.trustRootValid() - except server.MalformedTrustRoot, why: + except server.MalformedTrustRoot as why: self.failUnless(sentinel is why.openid_message) else: self.fail('Expected MalformedTrustRoot exception. Got %r' @@ -886,12 +876,12 @@ def test_malformedTrustRoot(self): def test_trustRootValidNoReturnTo(self): request = server.CheckIDRequest( - identity = 'http://bambam.unittest/', - trust_root = 'http://bar.unittest/', - return_to = None, - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bambam.unittest/', + trust_root='http://bar.unittest/', + return_to=None, + immediate=False, + op_endpoint=self.server.op_endpoint, + ) self.failUnless(request.trustRootValid()) @@ -909,6 +899,7 @@ def withVerifyReturnTo(new_verify, callable): # Ensure that exceptions are passed through sentinel = Exception() + def vrfyExc(trust_root, return_to): self.failUnlessEqual(self.request.trust_root, trust_root) self.failUnlessEqual(self.request.return_to, return_to) @@ -916,7 +907,7 @@ def vrfyExc(trust_root, return_to): try: withVerifyReturnTo(vrfyExc, self.request.returnToVerified) - except Exception, e: + except Exception as e: self.failUnless(e is sentinel, e) # Ensure that True and False are passed through unchanged @@ -938,7 +929,7 @@ def _expectAnswer(self, answer, identity=None, claimed_id=None): ('mode', 'id_res'), ('return_to', self.request.return_to), ('op_endpoint', self.op_endpoint), - ] + ] if identity: expected_list.append(('identity', identity)) if claimed_id: @@ -1140,13 +1131,13 @@ def test_fromMessageWithoutTrustRoot(self): def test_fromMessageWithEmptyTrustRoot(self): return_to = u'http://someplace.invalid/?go=thing' msg = Message.fromPostArgs({ - u'openid.assoc_handle': u'{blah}{blah}{OZivdQ==}', - u'openid.claimed_id': u'http://delegated.invalid/', - u'openid.identity': u'http://op-local.example.com/', - u'openid.mode': u'checkid_setup', - u'openid.ns': u'http://openid.net/signon/1.0', - u'openid.return_to': return_to, - u'openid.trust_root': u''}) + u'openid.assoc_handle': u'{blah}{blah}{OZivdQ==}', + u'openid.claimed_id': u'http://delegated.invalid/', + u'openid.identity': u'http://op-local.example.com/', + u'openid.mode': u'checkid_setup', + u'openid.ns': u'http://openid.net/signon/1.0', + u'openid.return_to': return_to, + u'openid.trust_root': u''}) result = server.CheckIDRequest.fromMessage(msg, self.server.op_endpoint) @@ -1172,7 +1163,7 @@ def test_answerAllowNoEndpointOpenID1(self): 'identity': identity, 'trust_root': 'http://bar.unittest/', 'return_to': 'http://bar.unittest/999', - }) + }) self.request = server.CheckIDRequest.fromMessage(reqmessage, None) answer = self.request.answer(True) @@ -1180,7 +1171,7 @@ def test_answerAllowNoEndpointOpenID1(self): ('mode', 'id_res'), ('return_to', self.request.return_to), ('identity', identity), - ] + ] for k, expected in expected_list: actual = answer.fields.getArg(OPENID_NS, k) @@ -1241,7 +1232,7 @@ def test_answerSetupDeny(self): answer = self.request.answer(False) self.failUnlessEqual(answer.fields.getArgs(OPENID_NS), { 'mode': 'cancel', - }) + }) def test_encodeToURL(self): server_url = 'http://openid-server.unittest/' @@ -1262,8 +1253,8 @@ def test_getCancelURL(self): rt, query_string = url.split('?') self.failUnlessEqual(self.request.return_to, rt) query = dict(cgi.parse_qsl(query_string)) - self.failUnlessEqual(query, {'openid.mode':'cancel', - 'openid.ns':OPENID2_NS}) + self.failUnlessEqual(query, {'openid.mode': 'cancel', + 'openid.ns': OPENID2_NS}) def test_getCancelURLimmed(self): self.request.mode = 'checkid_immediate' @@ -1271,7 +1262,6 @@ def test_getCancelURLimmed(self): self.failUnlessRaises(ValueError, self.request.getCancelURL) - class TestCheckIDExtension(unittest.TestCase): def setUp(self): @@ -1279,18 +1269,17 @@ def setUp(self): self.store = memstore.MemoryStore() self.server = server.Server(self.store, self.op_endpoint) self.request = server.CheckIDRequest( - identity = 'http://bambam.unittest/', - trust_root = 'http://bar.unittest/', - return_to = 'http://bar.unittest/999', - immediate = False, - op_endpoint = self.server.op_endpoint, - ) + identity='http://bambam.unittest/', + trust_root='http://bar.unittest/', + return_to='http://bar.unittest/999', + immediate=False, + op_endpoint=self.server.op_endpoint, + ) self.request.message = Message(OPENID2_NS) self.response = server.OpenIDResponse(self.request) self.response.fields.setArg(OPENID_NS, 'mode', 'id_res') self.response.fields.setArg(OPENID_NS, 'blue', 'star') - def test_addField(self): namespace = 'something:' self.response.fields.setArg(namespace, 'bright', 'potato') @@ -1300,13 +1289,12 @@ def test_addField(self): }) self.failUnlessEqual(self.response.fields.getArgs(namespace), - {'bright':'potato'}) - + {'bright': 'potato'}) def test_addFields(self): namespace = 'mi5:' - args = {'tangy': 'suspenders', - 'bravo': 'inclusion'} + args = {'tangy': 'suspenders', + 'bravo': 'inclusion'} self.response.fields.updateArgs(namespace, args) self.failUnlessEqual(self.response.fields.getArgs(OPENID_NS), {'blue': 'star', @@ -1315,7 +1303,6 @@ def test_addFields(self): self.failUnlessEqual(self.response.fields.getArgs(namespace), args) - class MockSignatory(object): isValid = True @@ -1349,7 +1336,7 @@ def setUp(self): 'openid.sig': 'signarture', 'one': 'alpha', 'two': 'beta', - }) + }) self.request = server.CheckAuthRequest( self.assoc_handle, self.message) @@ -1420,7 +1407,8 @@ def test_dhSHA1(self): session = DiffieHellmanSHA1ServerSession(server_dh, cpub) self.request = server.AssociateRequest(session, 'HMAC-SHA1') response = self.request.answer(self.assoc) - rfg = lambda f: response.fields.getArg(OPENID_NS, f) + + rfg = partial(response.fields.getArg, OPENID_NS) self.failUnlessEqual(rfg("assoc_type"), "HMAC-SHA1") self.failUnlessEqual(rfg("assoc_handle"), self.assoc.handle) self.failIf(rfg("mac_key")) @@ -1444,7 +1432,8 @@ def test_dhSHA256(self): session = DiffieHellmanSHA256ServerSession(server_dh, cpub) self.request = server.AssociateRequest(session, 'HMAC-SHA256') response = self.request.answer(self.assoc) - rfg = lambda f: response.fields.getArg(OPENID_NS, f) + + rfg = partial(response.fields.getArg, OPENID_NS) self.failUnlessEqual(rfg("assoc_type"), "HMAC-SHA256") self.failUnlessEqual(rfg("assoc_handle"), self.assoc.handle) self.failIf(rfg("mac_key")) @@ -1458,23 +1447,18 @@ def test_dhSHA256(self): self.failUnlessEqual(secret, self.assoc.secret) def test_protoError256(self): - from openid.consumer.consumer import \ - DiffieHellmanSHA256ConsumerSession - s256_session = DiffieHellmanSHA256ConsumerSession() - invalid_s256 = {'openid.assoc_type':'HMAC-SHA1', - 'openid.session_type':'DH-SHA256',} + invalid_s256 = {'openid.assoc_type': 'HMAC-SHA1', 'openid.session_type': 'DH-SHA256'} invalid_s256.update(s256_session.getRequest()) - invalid_s256_2 = {'openid.assoc_type':'MONKEY-PIRATE', - 'openid.session_type':'DH-SHA256',} + invalid_s256_2 = {'openid.assoc_type': 'MONKEY-PIRATE', 'openid.session_type': 'DH-SHA256'} invalid_s256_2.update(s256_session.getRequest()) bad_request_argss = [ invalid_s256, invalid_s256_2, - ] + ] for request_args in bad_request_argss: message = Message.fromPostArgs(request_args) @@ -1487,19 +1471,17 @@ def test_protoError(self): s1_session = DiffieHellmanSHA1ConsumerSession() - invalid_s1 = {'openid.assoc_type':'HMAC-SHA256', - 'openid.session_type':'DH-SHA1',} + invalid_s1 = {'openid.assoc_type': 'HMAC-SHA256', 'openid.session_type': 'DH-SHA1'} invalid_s1.update(s1_session.getRequest()) - invalid_s1_2 = {'openid.assoc_type':'ROBOT-NINJA', - 'openid.session_type':'DH-SHA1',} + invalid_s1_2 = {'openid.assoc_type': 'ROBOT-NINJA', 'openid.session_type': 'DH-SHA1'} invalid_s1_2.update(s1_session.getRequest()) bad_request_argss = [ - {'openid.assoc_type':'Wha?'}, + {'openid.assoc_type': 'Wha?'}, invalid_s1, invalid_s1_2, - ] + ] for request_args in bad_request_argss: message = Message.fromPostArgs(request_args) @@ -1516,7 +1498,7 @@ def test_protoErrorFields(self): openid1_args = { 'openid.identitiy': 'invalid', 'openid.mode': 'checkid_setup', - } + } openid2_args = dict(openid1_args) openid2_args.update({'openid.ns': OPENID2_NS}) @@ -1545,7 +1527,7 @@ def failUnlessExpiresInMatches(self, msg, expected_expires_in): # Slop is necessary because the tests can sometimes get run # right on a second boundary - slop = 1 # second + slop = 1 # second difference = expected_expires_in - expires_in error_message = ('"expires_in" value not within %s of expected: ' @@ -1556,7 +1538,8 @@ def failUnlessExpiresInMatches(self, msg, expected_expires_in): def test_plaintext(self): self.assoc = self.signatory.createAssociation(dumb=False, assoc_type='HMAC-SHA1') response = self.request.answer(self.assoc) - rfg = lambda f: response.fields.getArg(OPENID_NS, f) + + rfg = partial(response.fields.getArg, OPENID_NS) self.failUnlessEqual(rfg("assoc_type"), "HMAC-SHA1") self.failUnlessEqual(rfg("assoc_handle"), self.assoc.handle) @@ -1578,7 +1561,7 @@ def test_plaintext_v2(self): 'openid.mode': 'associate', 'openid.assoc_type': 'HMAC-SHA1', 'openid.session_type': 'no-encryption', - } + } self.request = server.AssociateRequest.fromMessage( Message.fromPostArgs(args)) @@ -1587,7 +1570,8 @@ def test_plaintext_v2(self): self.assoc = self.signatory.createAssociation( dumb=False, assoc_type='HMAC-SHA1') response = self.request.answer(self.assoc) - rfg = lambda f: response.fields.getArg(OPENID_NS, f) + + rfg = partial(response.fields.getArg, OPENID_NS) self.failUnlessEqual(rfg("assoc_type"), "HMAC-SHA1") self.failUnlessEqual(rfg("assoc_handle"), self.assoc.handle) @@ -1605,7 +1589,8 @@ def test_plaintext_v2(self): def test_plaintext256(self): self.assoc = self.signatory.createAssociation(dumb=False, assoc_type='HMAC-SHA256') response = self.request.answer(self.assoc) - rfg = lambda f: response.fields.getArg(OPENID_NS, f) + + rfg = partial(response.fields.getArg, OPENID_NS) self.failUnlessEqual(rfg("assoc_type"), "HMAC-SHA1") self.failUnlessEqual(rfg("assoc_handle"), self.assoc.handle) @@ -1632,8 +1617,9 @@ def test_unsupportedPrefer(self): message=message, preferred_session_type=allowed_sess, preferred_association_type=allowed_assoc, - ) - rfg = lambda f: response.fields.getArg(OPENID_NS, f) + ) + + rfg = partial(response.fields.getArg, OPENID_NS) self.failUnlessEqual(rfg('error_code'), 'unsupported-type') self.failUnlessEqual(rfg('assoc_type'), allowed_assoc) self.failUnlessEqual(rfg('error'), message) @@ -1647,12 +1633,14 @@ def test_unsupported(self): self.request.message = Message(OPENID2_NS) response = self.request.answerUnsupported(message) - rfg = lambda f: response.fields.getArg(OPENID_NS, f) + + rfg = partial(response.fields.getArg, OPENID_NS) self.failUnlessEqual(rfg('error_code'), 'unsupported-type') self.failUnlessEqual(rfg('assoc_type'), None) self.failUnlessEqual(rfg('error'), message) self.failUnlessEqual(rfg('session_type'), None) + class Counter(object): def __init__(self): self.count = 0 @@ -1660,6 +1648,7 @@ def __init__(self): def inc(self): self.count += 1 + class TestServer(unittest.TestCase, CatchLogs): def setUp(self): self.store = memstore.MemoryStore() @@ -1668,6 +1657,7 @@ def setUp(self): def test_dispatch(self): monkeycalled = Counter() + def monkeyDo(request): monkeycalled.inc() r = server.OpenIDResponse(request) @@ -1676,7 +1666,7 @@ def monkeyDo(request): request = server.OpenIDRequest() request.mode = "monkeymode" request.namespace = OPENID1_NS - webresult = self.server.handleRequest(request) + self.server.handleRequest(request) self.failUnlessEqual(monkeycalled.count, 1) def test_associate(self): @@ -1698,7 +1688,7 @@ def test_associate2(self): 'openid.ns': OPENID2_NS, 'openid.session_type': 'no-encryption', 'openid.assoc_type': 'HMAC-SHA1', - }) + }) request = server.AssociateRequest.fromMessage(msg) @@ -1721,7 +1711,7 @@ def test_associate3(self): 'openid.ns': OPENID2_NS, 'openid.session_type': 'no-encryption', 'openid.assoc_type': 'HMAC-SHA1', - }) + }) request = server.AssociateRequest.fromMessage(msg) response = self.server.openid_associate(request) @@ -1745,7 +1735,7 @@ def test_associate4(self): '1WxJY3jHd5k1/ZReyRZOxZTKdF/dnIqwF8ZXUwI6peV0TyS/K1fOfF/s', 'openid.assoc_type': 'HMAC-SHA256', 'openid.session_type': 'DH-SHA256', - } + } message = Message.fromPostArgs(query) request = server.AssociateRequest.fromMessage(message) response = self.server.openid_associate(request) @@ -1755,7 +1745,7 @@ def test_missingSessionTypeOpenID2(self): """Make sure session_type is required in OpenID 2""" msg = Message.fromPostArgs({ 'openid.ns': OPENID2_NS, - }) + }) self.assertRaises(server.ProtocolError, server.AssociateRequest.fromMessage, msg) @@ -1765,7 +1755,7 @@ def test_missingAssocTypeOpenID2(self): msg = Message.fromPostArgs({ 'openid.ns': OPENID2_NS, 'openid.session_type': 'no-encryption', - }) + }) self.assertRaises(server.ProtocolError, server.AssociateRequest.fromMessage, msg) @@ -1775,6 +1765,7 @@ def test_checkAuth(self): response = self.server.openid_check_authentication(request) self.failUnless(response.fields.hasKey(OPENID_NS, "is_valid")) + class TestSignatory(unittest.TestCase, CatchLogs): def setUp(self): self.store = memstore.MemoryStore() @@ -1797,7 +1788,7 @@ def test_sign(self): 'foo': 'amsigned', 'bar': 'notsigned', 'azu': 'alsosigned', - }) + }) sresponse = self.signatory.sign(response) self.failUnlessEqual( sresponse.fields.getArg(OPENID_NS, 'assoc_handle'), @@ -1816,8 +1807,8 @@ def test_signDumb(self): 'foo': 'amsigned', 'bar': 'notsigned', 'azu': 'alsosigned', - 'ns':OPENID2_NS, - }) + 'ns': OPENID2_NS, + }) sresponse = self.signatory.sign(response) assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') self.failUnless(assoc_handle) @@ -1857,7 +1848,7 @@ def test_signExpired(self): 'foo': 'amsigned', 'bar': 'notsigned', 'azu': 'alsosigned', - }) + }) sresponse = self.signatory.sign(response) new_assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') @@ -1881,7 +1872,6 @@ def test_signExpired(self): self.failIf(self.store.getAssociation(self._normal_key, new_assoc_handle)) self.failUnless(self.messages) - def test_signInvalidHandle(self): request = server.OpenIDRequest() request.namespace = OPENID2_NS @@ -1893,7 +1883,7 @@ def test_signInvalidHandle(self): 'foo': 'amsigned', 'bar': 'notsigned', 'azu': 'alsosigned', - }) + }) sresponse = self.signatory.sign(response) new_assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') @@ -1913,7 +1903,6 @@ def test_signInvalidHandle(self): self.failIf(self.store.getAssociation(self._normal_key, new_assoc_handle)) self.failIf(self.messages, self.messages) - def test_verify(self): assoc_handle = '{vroom}{zoom}' assoc = association.Association.fromExpiresIn( @@ -1927,13 +1916,12 @@ def test_verify(self): 'openid.assoc_handle': assoc_handle, 'openid.signed': 'apple,assoc_handle,foo,signed', 'openid.sig': 'uXoT1qm62/BB09Xbj98TQ8mlBco=', - }) + }) verified = self.signatory.verify(assoc_handle, signed) self.failIf(self.messages, self.messages) self.failUnless(verified) - def test_verifyBadSig(self): assoc_handle = '{vroom}{zoom}' assoc = association.Association.fromExpiresIn( @@ -1947,7 +1935,7 @@ def test_verifyBadSig(self): 'openid.assoc_handle': assoc_handle, 'openid.signed': 'apple,assoc_handle,foo,signed', 'openid.sig': 'uXoT1qm62/BB09Xbj98TQ8mlBco='.encode('rot13'), - }) + }) verified = self.signatory.verify(assoc_handle, signed) self.failIf(self.messages, self.messages) @@ -1959,13 +1947,12 @@ def test_verifyBadHandle(self): 'foo': 'bar', 'apple': 'orange', 'openid.sig': "Ylu0KcIR7PvNegB/K41KpnRgJl0=", - }) + }) verified = self.signatory.verify(assoc_handle, signed) self.failIf(verified) self.failUnless(self.messages) - def test_verifyAssocMismatch(self): """Attempt to validate sign-all message with a signed-list assoc.""" assoc_handle = '{vroom}{zoom}' @@ -1978,7 +1965,7 @@ def test_verifyAssocMismatch(self): 'foo': 'bar', 'apple': 'orange', 'openid.sig': "d71xlHtqnq98DonoSgoK/nD+QRM=", - }) + }) verified = self.signatory.verify(assoc_handle, signed) self.failIf(verified) @@ -1992,10 +1979,10 @@ def test_getAssoc(self): self.failIf(self.messages, self.messages) def test_getAssocExpired(self): - assoc_handle = self.makeAssoc(dumb=True, lifetime=-10) + assoc_handle = self.makeAssoc(dumb=True, lifetime=-10) assoc = self.signatory.getAssociation(assoc_handle, True) self.failIf(assoc, assoc) - self.failUnless(self.messages) + self.failUnless(self.messages) def test_getAssocInvalid(self): ah = 'no-such-handle' @@ -2052,6 +2039,5 @@ def test_invalidate(self): self.failIf(self.messages, self.messages) - if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_sreg.py b/openid/test/test_sreg.py index 0abbc5eb..ddcf9dc4 100644 --- a/openid/test/test_sreg.py +++ b/openid/test/test_sreg.py @@ -1,7 +1,7 @@ import unittest from openid.extensions import sreg -from openid.message import Message, NamespaceMap, registerNamespaceAlias +from openid.message import Message, NamespaceMap from openid.server.server import OpenIDRequest, OpenIDResponse @@ -9,6 +9,7 @@ class SRegURITest(unittest.TestCase): def test_is11(self): self.failUnlessEqual(sreg.ns_uri_1_1, sreg.ns_uri) + class CheckFieldNameTest(unittest.TestCase): def test_goodNamePasses(self): for field_name in sreg.data_fields: @@ -21,6 +22,8 @@ def test_badTypeFails(self): self.failUnlessRaises(ValueError, sreg.checkFieldName, None) # For supportsSReg test + + class FakeEndpoint(object): def __init__(self, supported): self.supported = supported @@ -30,6 +33,7 @@ def usesExtension(self, namespace_uri): self.checked_uris.append(namespace_uri) return namespace_uri in self.supported + class SupportsSRegTest(unittest.TestCase): def test_unsupported(self): endpoint = FakeEndpoint([]) @@ -48,6 +52,7 @@ def test_supported_1_0(self): self.failUnlessEqual([sreg.ns_uri_1_1, sreg.ns_uri_1_0], endpoint.checked_uris) + class FakeMessage(object): def __init__(self): self.openid1 = False @@ -56,6 +61,7 @@ def __init__(self): def isOpenID1(self): return self.openid1 + class GetNSTest(unittest.TestCase): def setUp(self): self.msg = FakeMessage() @@ -110,13 +116,14 @@ def test_openID1_sregNSfromArgs(self): args = { 'sreg.optional': 'nickname', 'sreg.required': 'dob', - } + } m = Message.fromOpenIDArgs(args) self.failUnless(m.getArg(sreg.ns_uri_1_1, 'optional') == 'nickname') self.failUnless(m.getArg(sreg.ns_uri_1_1, 'required') == 'dob') + class SRegRequestTest(unittest.TestCase): def test_constructEmpty(self): req = sreg.SRegRequest() @@ -142,7 +149,6 @@ def test_constructBadFields(self): sreg.SRegRequest, ['elvis']) def test_fromOpenIDRequest(self): - args = {} ns_sentinel = object() args_sentinel = object() @@ -173,7 +179,7 @@ def parseExtensionArgs(req_self, args): openid_req.message = msg req = TestingReq.fromOpenIDRequest(openid_req) - self.failUnless(type(req) is TestingReq) + self.assertIsInstance(req, TestingReq) self.failUnless(msg.copied) def test_parseExtensionArgs_empty(self): @@ -183,60 +189,60 @@ def test_parseExtensionArgs_empty(self): def test_parseExtensionArgs_extraIgnored(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'janrain':'inc'}) + req.parseExtensionArgs({'janrain': 'inc'}) def test_parseExtensionArgs_nonStrict(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'required':'beans'}) + req.parseExtensionArgs({'required': 'beans'}) self.failUnlessEqual([], req.required) def test_parseExtensionArgs_strict(self): req = sreg.SRegRequest() self.failUnlessRaises( ValueError, - req.parseExtensionArgs, {'required':'beans'}, strict=True) + req.parseExtensionArgs, {'required': 'beans'}, strict=True) def test_parseExtensionArgs_policy(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'policy_url':'http://policy'}, strict=True) + req.parseExtensionArgs({'policy_url': 'http://policy'}, strict=True) self.failUnlessEqual('http://policy', req.policy_url) def test_parseExtensionArgs_requiredEmpty(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'required':''}, strict=True) + req.parseExtensionArgs({'required': ''}, strict=True) self.failUnlessEqual([], req.required) def test_parseExtensionArgs_optionalEmpty(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'optional':''}, strict=True) + req.parseExtensionArgs({'optional': ''}, strict=True) self.failUnlessEqual([], req.optional) def test_parseExtensionArgs_optionalSingle(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'optional':'nickname'}, strict=True) + req.parseExtensionArgs({'optional': 'nickname'}, strict=True) self.failUnlessEqual(['nickname'], req.optional) def test_parseExtensionArgs_optionalList(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'optional':'nickname,email'}, strict=True) - self.failUnlessEqual(['nickname','email'], req.optional) + req.parseExtensionArgs({'optional': 'nickname,email'}, strict=True) + self.failUnlessEqual(['nickname', 'email'], req.optional) def test_parseExtensionArgs_optionalListBadNonStrict(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'optional':'nickname,email,beer'}) - self.failUnlessEqual(['nickname','email'], req.optional) + req.parseExtensionArgs({'optional': 'nickname,email,beer'}) + self.failUnlessEqual(['nickname', 'email'], req.optional) def test_parseExtensionArgs_optionalListBadStrict(self): req = sreg.SRegRequest() self.failUnlessRaises( ValueError, - req.parseExtensionArgs, {'optional':'nickname,email,beer'}, + req.parseExtensionArgs, {'optional': 'nickname,email,beer'}, strict=True) def test_parseExtensionArgs_bothNonStrict(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'optional':'nickname', - 'required':'nickname'}) + req.parseExtensionArgs({'optional': 'nickname', + 'required': 'nickname'}) self.failUnlessEqual([], req.optional) self.failUnlessEqual(['nickname'], req.required) @@ -245,16 +251,16 @@ def test_parseExtensionArgs_bothStrict(self): self.failUnlessRaises( ValueError, req.parseExtensionArgs, - {'optional':'nickname', - 'required':'nickname'}, + {'optional': 'nickname', + 'required': 'nickname'}, strict=True) def test_parseExtensionArgs_bothList(self): req = sreg.SRegRequest() - req.parseExtensionArgs({'optional':'nickname,email', - 'required':'country,postcode'}, strict=True) - self.failUnlessEqual(['nickname','email'], req.optional) - self.failUnlessEqual(['country','postcode'], req.required) + req.parseExtensionArgs({'optional': 'nickname,email', + 'required': 'country,postcode'}, strict=True) + self.failUnlessEqual(['nickname', 'email'], req.optional) + self.failUnlessEqual(['country', 'postcode'], req.required) def test_allRequestedFields(self): req = sreg.SRegRequest() @@ -262,8 +268,7 @@ def test_allRequestedFields(self): req.requestField('nickname') self.failUnlessEqual(['nickname'], req.allRequestedFields()) req.requestField('gender', required=True) - requested = req.allRequestedFields() - requested.sort() + requested = sorted(req.allRequestedFields()) self.failUnlessEqual(['gender', 'nickname'], requested) def test_wereFieldsRequested(self): @@ -378,38 +383,40 @@ def test_getExtensionArgs(self): self.failUnlessEqual({}, req.getExtensionArgs()) req.requestField('nickname') - self.failUnlessEqual({'optional':'nickname'}, req.getExtensionArgs()) + self.failUnlessEqual({'optional': 'nickname'}, req.getExtensionArgs()) req.requestField('email') - self.failUnlessEqual({'optional':'nickname,email'}, + self.failUnlessEqual({'optional': 'nickname,email'}, req.getExtensionArgs()) req.requestField('gender', required=True) - self.failUnlessEqual({'optional':'nickname,email', - 'required':'gender'}, + self.failUnlessEqual({'optional': 'nickname,email', + 'required': 'gender'}, req.getExtensionArgs()) req.requestField('postcode', required=True) - self.failUnlessEqual({'optional':'nickname,email', - 'required':'gender,postcode'}, + self.failUnlessEqual({'optional': 'nickname,email', + 'required': 'gender,postcode'}, req.getExtensionArgs()) req.policy_url = 'http://policy.invalid/' - self.failUnlessEqual({'optional':'nickname,email', - 'required':'gender,postcode', - 'policy_url':'http://policy.invalid/'}, + self.failUnlessEqual({'optional': 'nickname,email', + 'required': 'gender,postcode', + 'policy_url': 'http://policy.invalid/'}, req.getExtensionArgs()) + data = { - 'nickname':'linusaur', - 'postcode':'12345', - 'country':'US', - 'gender':'M', - 'fullname':'Leonhard Euler', - 'email':'president@whitehouse.gov', - 'dob':'0000-00-00', - 'language':'en-us', - } + 'nickname': 'linusaur', + 'postcode': '12345', + 'country': 'US', + 'gender': 'M', + 'fullname': 'Leonhard Euler', + 'email': 'president@whitehouse.gov', + 'dob': '0000-00-00', + 'language': 'en-us', +} + class DummySuccessResponse(object): def __init__(self, message, signed_stuff): @@ -419,6 +426,7 @@ def __init__(self, message, signed_stuff): def getSignedNS(self, ns_uri): return self.signed_stuff + class SRegResponseTest(unittest.TestCase): def test_construct(self): resp = sreg.SRegResponse(data) @@ -432,22 +440,23 @@ def test_construct(self): def test_fromSuccessResponse_signed(self): message = Message.fromOpenIDArgs({ - 'sreg.nickname':'The Mad Stork', - }) + 'sreg.nickname': 'The Mad Stork', + }) success_resp = DummySuccessResponse(message, {}) sreg_resp = sreg.SRegResponse.fromSuccessResponse(success_resp) self.failIf(sreg_resp) def test_fromSuccessResponse_unsigned(self): message = Message.fromOpenIDArgs({ - 'sreg.nickname':'The Mad Stork', - }) + 'sreg.nickname': 'The Mad Stork', + }) success_resp = DummySuccessResponse(message, {}) sreg_resp = sreg.SRegResponse.fromSuccessResponse(success_resp, signed_only=False) self.failUnlessEqual([('nickname', 'The Mad Stork')], sreg_resp.items()) + class SendFieldsTest(unittest.TestCase): def test(self): # Create a request message with simple registration fields @@ -476,10 +485,11 @@ def test(self): # Extract the fields that were sent sreg_data_resp = resp_msg.getArgs(sreg.ns_uri) self.failUnlessEqual( - {'nickname':'linusaur', - 'email':'president@whitehouse.gov', - 'fullname':'Leonhard Euler', + {'nickname': 'linusaur', + 'email': 'president@whitehouse.gov', + 'fullname': 'Leonhard Euler', }, sreg_data_resp) + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_symbol.py b/openid/test/test_symbol.py index 7f9b79bb..74252226 100644 --- a/openid/test/test_symbol.py +++ b/openid/test/test_symbol.py @@ -32,5 +32,6 @@ def test_ne_inequality(self): y = oidutil.Symbol('yyy') self.failUnless(x != y) + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_urinorm.py b/openid/test/test_urinorm.py index 154f7516..98e9f490 100644 --- a/openid/test/test_urinorm.py +++ b/openid/test/test_urinorm.py @@ -17,7 +17,7 @@ def shortDescription(self): def runTest(self): try: actual = openid.urinorm.urinorm(self.case) - except ValueError, why: + except ValueError as why: self.assertEqual(self.expected, 'fail', why) else: self.assertEqual(actual, self.expected) @@ -43,6 +43,7 @@ def parseTests(test_data): return result + def pyUnitTests(): here = os.path.dirname(os.path.abspath(__file__)) test_data_file_name = os.path.join(here, 'urinorm.txt') diff --git a/openid/test/test_verifydisco.py b/openid/test/test_verifydisco.py index 43664bc3..57ead86a 100644 --- a/openid/test/test_verifydisco.py +++ b/openid/test/test_verifydisco.py @@ -14,11 +14,12 @@ def constResult(*args, **kwargs): return constResult + class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): def failUnlessProtocolError(self, prefix, callable, *args, **kwargs): try: result = callable(*args, **kwargs) - except consumer.ProtocolError, e: + except consumer.ProtocolError as e: self.failUnless( e[0].startswith(prefix), 'Expected message prefix %r, got message %r' % (prefix, e[0])) @@ -37,30 +38,30 @@ def test_openID1NoLocalID(self): self.failUnlessLogEmpty() def test_openID1NoEndpoint(self): - msg = message.Message.fromOpenIDArgs({'identity':'snakes on a plane'}) + msg = message.Message.fromOpenIDArgs({'identity': 'snakes on a plane'}) self.failUnlessRaises(RuntimeError, self.consumer._verifyDiscoveryResults, msg) self.failUnlessLogEmpty() def test_openID2NoOPEndpointArg(self): - msg = message.Message.fromOpenIDArgs({'ns':message.OPENID2_NS}) + msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS}) self.failUnlessRaises(KeyError, self.consumer._verifyDiscoveryResults, msg) self.failUnlessLogEmpty() def test_openID2LocalIDNoClaimed(self): - msg = message.Message.fromOpenIDArgs({'ns':message.OPENID2_NS, - 'op_endpoint':'Phone Home', - 'identity':'Jose Lius Borges'}) + msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS, + 'op_endpoint': 'Phone Home', + 'identity': 'Jose Lius Borges'}) self.failUnlessProtocolError( 'openid.identity is present without', self.consumer._verifyDiscoveryResults, msg) self.failUnlessLogEmpty() def test_openID2NoLocalIDClaimed(self): - msg = message.Message.fromOpenIDArgs({'ns':message.OPENID2_NS, - 'op_endpoint':'Phone Home', - 'claimed_id':'Manuel Noriega'}) + msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS, + 'op_endpoint': 'Phone Home', + 'claimed_id': 'Manuel Noriega'}) self.failUnlessProtocolError( 'openid.claimed_id is present without', self.consumer._verifyDiscoveryResults, msg) @@ -68,8 +69,8 @@ def test_openID2NoLocalIDClaimed(self): def test_openID2NoIdentifiers(self): op_endpoint = 'Phone Home' - msg = message.Message.fromOpenIDArgs({'ns':message.OPENID2_NS, - 'op_endpoint':op_endpoint}) + msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS, + 'op_endpoint': op_endpoint}) result_endpoint = self.consumer._verifyDiscoveryResults(msg) self.failUnless(result_endpoint.isOPIdentifier()) self.failUnlessEqual(op_endpoint, result_endpoint.server_url) @@ -82,10 +83,10 @@ def test_openID2NoEndpointDoesDisco(self): sentinel.claimed_id = 'monkeysoft' self.consumer._discoverAndVerify = const(sentinel) msg = message.Message.fromOpenIDArgs( - {'ns':message.OPENID2_NS, - 'identity':'sour grapes', - 'claimed_id':'monkeysoft', - 'op_endpoint':op_endpoint}) + {'ns': message.OPENID2_NS, + 'identity': 'sour grapes', + 'claimed_id': 'monkeysoft', + 'op_endpoint': op_endpoint}) result = self.consumer._verifyDiscoveryResults(msg) self.failUnlessEqual(sentinel, result) self.failUnlessLogMatches('No pre-discovered') @@ -100,10 +101,10 @@ def test_openID2MismatchedDoesDisco(self): sentinel.claimed_id = 'monkeysoft' self.consumer._discoverAndVerify = const(sentinel) msg = message.Message.fromOpenIDArgs( - {'ns':message.OPENID2_NS, - 'identity':'sour grapes', - 'claimed_id':'monkeysoft', - 'op_endpoint':op_endpoint}) + {'ns': message.OPENID2_NS, + 'identity': 'sour grapes', + 'claimed_id': 'monkeysoft', + 'op_endpoint': op_endpoint}) result = self.consumer._verifyDiscoveryResults(msg, mismatched) self.failUnlessEqual(sentinel, result) self.failUnlessLogMatches('Error attempting to use stored', @@ -117,10 +118,10 @@ def test_openid2UsePreDiscovered(self): endpoint.type_uris = [discover.OPENID_2_0_TYPE] msg = message.Message.fromOpenIDArgs( - {'ns':message.OPENID2_NS, - 'identity':endpoint.local_id, - 'claimed_id':endpoint.claimed_id, - 'op_endpoint':endpoint.server_url}) + {'ns': message.OPENID2_NS, + 'identity': endpoint.local_id, + 'claimed_id': endpoint.claimed_id, + 'op_endpoint': endpoint.server_url}) result = self.consumer._verifyDiscoveryResults(msg, endpoint) self.failUnless(result is endpoint) self.failUnlessLogEmpty() @@ -143,14 +144,14 @@ def discoverAndVerify(claimed_id, to_match_endpoints): self.consumer._discoverAndVerify = discoverAndVerify msg = message.Message.fromOpenIDArgs( - {'ns':message.OPENID2_NS, - 'identity':endpoint.local_id, - 'claimed_id':endpoint.claimed_id, - 'op_endpoint':endpoint.server_url}) + {'ns': message.OPENID2_NS, + 'identity': endpoint.local_id, + 'claimed_id': endpoint.claimed_id, + 'op_endpoint': endpoint.server_url}) try: r = self.consumer._verifyDiscoveryResults(msg, endpoint) - except consumer.ProtocolError, e: + except consumer.ProtocolError as e: # Should we make more ProtocolError subclasses? self.failUnless(str(e), text) else: @@ -167,14 +168,15 @@ def test_openid1UsePreDiscovered(self): endpoint.type_uris = [discover.OPENID_1_1_TYPE] msg = message.Message.fromOpenIDArgs( - {'ns':message.OPENID1_NS, - 'identity':endpoint.local_id}) + {'ns': message.OPENID1_NS, + 'identity': endpoint.local_id}) result = self.consumer._verifyDiscoveryResults(msg, endpoint) self.failUnless(result is endpoint) self.failUnlessLogEmpty() def test_openid1UsePreDiscoveredWrongType(self): - class VerifiedError(Exception): pass + class VerifiedError(Exception): + pass def discoverAndVerify(claimed_id, _to_match): raise VerifiedError @@ -188,8 +190,8 @@ def discoverAndVerify(claimed_id, _to_match): endpoint.type_uris = [discover.OPENID_2_0_TYPE] msg = message.Message.fromOpenIDArgs( - {'ns':message.OPENID1_NS, - 'identity':endpoint.local_id}) + {'ns': message.OPENID1_NS, + 'identity': endpoint.local_id}) self.failUnlessRaises( VerifiedError, @@ -208,18 +210,18 @@ def test_openid2Fragment(self): endpoint.type_uris = [discover.OPENID_2_0_TYPE] msg = message.Message.fromOpenIDArgs( - {'ns':message.OPENID2_NS, - 'identity':endpoint.local_id, + {'ns': message.OPENID2_NS, + 'identity': endpoint.local_id, 'claimed_id': claimed_id_frag, 'op_endpoint': endpoint.server_url}) result = self.consumer._verifyDiscoveryResults(msg, endpoint) - + self.failUnlessEqual(result.local_id, endpoint.local_id) self.failUnlessEqual(result.server_url, endpoint.server_url) self.failUnlessEqual(result.type_uris, endpoint.type_uris) self.failUnlessEqual(result.claimed_id, claimed_id_frag) - + self.failUnlessLogEmpty() def test_openid1Fallback1_0(self): @@ -267,5 +269,6 @@ def test_endpointWithoutLocalID(self): self.failUnlessEqual(result, None) self.failUnlessLogEmpty() + if __name__ == '__main__': unittest.main() diff --git a/openid/test/test_xri.py b/openid/test/test_xri.py index 33ea0e05..6e6ac8e2 100644 --- a/openid/test/test_xri.py +++ b/openid/test/test_xri.py @@ -18,7 +18,6 @@ def test_escaping_percents(self): self.failUnlessEqual(xri.escapeForIRI('@example/abc%2Fd/ef'), '@example/abc%252Fd/ef') - def test_escaping_xref(self): # no escapes esc = xri.escapeForIRI @@ -33,7 +32,6 @@ def test_escaping_xref(self): esc('@example/foo/(@baz?p=q#r)?i=j#k')) - class XriTransformationTestCase(TestCase): def test_to_iri_normal(self): self.failUnlessEqual(xri.toIRINormal('@example'), 'xri://@example') @@ -53,7 +51,6 @@ def test_iri_to_url(self): self.failUnlessEqual(xri.iriToURI(s), expected) - class CanonicalIDTest(TestCase): def mkTest(providerID, canonicalID, isAuthoritative): def test(self): @@ -73,6 +70,7 @@ def test(self): test_atEqualsAndTooDeepFails = mkTest('@!1234!ABCD', '=!1234', False) test_differentBeginningFails = mkTest('=!BABE', '=!D00D', False) + class TestGetRootAuthority(TestCase): def mkTest(the_xri, expected_root): def test(self): @@ -96,8 +94,9 @@ def test(self): # Looking at the ABNF in XRI Syntax 2.0, I don't think you can # have example.com*bar. You can do (example.com)*bar, but that # would mean something else. - ##("example.com*bar/(=baz)", "example.com*bar"), - ##("baz.example.com!01/foo", "baz.example.com!01"), + # ("example.com*bar/(=baz)", "example.com*bar"), + # ("baz.example.com!01/foo", "baz.example.com!01"), + if __name__ == '__main__': import unittest diff --git a/openid/test/test_xrires.py b/openid/test/test_xrires.py index 873255c4..b06a8b27 100644 --- a/openid/test/test_xrires.py +++ b/openid/test/test_xrires.py @@ -11,7 +11,6 @@ def setUp(self): self.servicetype = 'xri://+i-service*(+forwarding)*($v*1.0)' self.servicetype_enc = 'xri%3A%2F%2F%2Bi-service%2A%28%2Bforwarding%29%2A%28%24v%2A1.0%29' - def test_proxy_url(self): st = self.servicetype ste = self.servicetype_enc @@ -30,7 +29,6 @@ def test_proxy_url(self): args_esc = "_xrd_r=application%2Fxrds%2Bxml%3Bsep%3Dfalse" self.failUnlessEqual(h + '=foo?' + args_esc, pqu('=foo', None)) - def test_proxy_url_qmarks(self): st = self.servicetype ste = self.servicetype_enc diff --git a/openid/test/test_yadis_discover.py b/openid/test/test_yadis_discover.py index 8c222d08..c7ba05c0 100644 --- a/openid/test/test_yadis_discover.py +++ b/openid/test/test_yadis_discover.py @@ -24,7 +24,10 @@ No such file %s """ -class QuitServer(Exception): pass + +class QuitServer(Exception): + pass + def mkResponse(data): status_mo = status_header_re.match(data) @@ -40,6 +43,7 @@ def mkResponse(data): headers=headers, body=body) + class TestFetcher(object): def __init__(self, base_url): self.base_url = base_url @@ -64,16 +68,18 @@ def fetch(self, url, headers, body): response.final_url = current_url return response + class TestSecondGet(unittest.TestCase): class MockFetcher(object): def __init__(self): self.count = 0 + def fetch(self, uri, headers=None, body=None): self.count += 1 if self.count == 1: headers = { 'X-XRDS-Location'.lower(): 'http://unittest/404', - } + } return fetchers.HTTPResponse(uri, 200, headers, '') else: return fetchers.HTTPResponse(uri, 404) @@ -137,10 +143,8 @@ def runCustomTest(self): self.failUnlessEqual( self.expected.response_text, result.response_text, msg) - expected_keys = dir(self.expected) - expected_keys.sort() - actual_keys = dir(result) - actual_keys.sort() + expected_keys = sorted(dir(self.expected)) + actual_keys = sorted(dir(result)) self.failUnlessEqual(actual_keys, expected_keys) for k in dir(self.expected): @@ -162,6 +166,7 @@ def shortDescription(self): n, self.__class__.__module__) + def pyUnitTests(): s = unittest.TestSuite() for success, input_name, id_name, result_name in discoverdata.testlist: @@ -170,9 +175,11 @@ def pyUnitTests(): return s + def test(): runner = unittest.TextTestRunner() return runner.run(pyUnitTests()) + if __name__ == '__main__': test() diff --git a/openid/test/trustroot.py b/openid/test/trustroot.py index f934ce36..c9a0f726 100644 --- a/openid/test/trustroot.py +++ b/openid/test/trustroot.py @@ -23,6 +23,7 @@ def runTest(self): else: assert tr is None, tr + class _MatchTest(unittest.TestCase): def __init__(self, match, desc, line): unittest.TestCase.__init__(self) @@ -45,6 +46,7 @@ def runTest(self): else: assert not match + def getTests(t, grps, head, dat): tests = [] top = head.strip() @@ -61,6 +63,7 @@ def getTests(t, grps, head, dat): i += 2 return tests + def parseTests(data): parts = map(str.strip, data.split('=' * 40 + '\n')) assert not parts[0] @@ -71,6 +74,7 @@ def parseTests(data): tests.extend(getTests(_MatchTest, [1, 0], mh, mdat)) return tests + def pyUnitTests(): here = os.path.dirname(os.path.abspath(__file__)) test_data_file_name = os.path.join(here, 'data', 'trustroot.txt') @@ -81,6 +85,7 @@ def pyUnitTests(): tests = parseTests(test_data) return unittest.TestSuite(tests) + if __name__ == '__main__': suite = pyUnitTests() runner = unittest.TextTestRunner() diff --git a/openid/urinorm.py b/openid/urinorm.py index 5bdbaeff..21869c83 100644 --- a/openid/urinorm.py +++ b/openid/urinorm.py @@ -29,11 +29,11 @@ (0xA0, 0xD7FF), (0xF900, 0xFDCF), (0xFDF0, 0xFFEF), - ] + ] IPRIVATE = [ (0xE000, 0xF8FF), - ] + ] else: UCSCHAR = [ (0xA0, 0xD7FF), @@ -53,19 +53,22 @@ (0xC0000, 0xCFFFD), (0xD0000, 0xDFFFD), (0xE1000, 0xEFFFD), - ] + ] IPRIVATE = [ (0xE000, 0xF8FF), (0xF0000, 0xFFFFD), (0x100000, 0x10FFFD), - ] + ] _unreserved = [False] * 256 -for _ in range(ord('A'), ord('Z') + 1): _unreserved[_] = True -for _ in range(ord('0'), ord('9') + 1): _unreserved[_] = True -for _ in range(ord('a'), ord('z') + 1): _unreserved[_] = True +for _ in range(ord('A'), ord('Z') + 1): + _unreserved[_] = True +for _ in range(ord('0'), ord('9') + 1): + _unreserved[_] = True +for _ in range(ord('a'), ord('z') + 1): + _unreserved[_] = True _unreserved[ord('-')] = True _unreserved[ord('.')] = True _unreserved[ord('_')] = True @@ -73,7 +76,7 @@ _escapeme_re = re.compile('[%s]' % (''.join( - map(lambda (m, n): u'%s-%s' % (unichr(m), unichr(n)), + map(lambda m_n: u'%s-%s' % (unichr(m_n[0]), unichr(m_n[1])), UCSCHAR + IPRIVATE)),)) @@ -176,9 +179,7 @@ def urinorm(uri): host = host.lower() if port: - if (port == ':' or - (scheme == 'http' and port == ':80') or - (scheme == 'https' and port == ':443')): + if port == ':' or (scheme == 'http' and port == ':80') or (scheme == 'https' and port == ':443'): port = '' else: port = '' diff --git a/openid/yadis/__init__.py b/openid/yadis/__init__.py index cfa5f1e7..68a0d449 100644 --- a/openid/yadis/__init__.py +++ b/openid/yadis/__init__.py @@ -10,7 +10,7 @@ 'services', 'xri', 'xrires', - ] +] __version__ = '[library version:1.1.0-rc1]'[17:-1] diff --git a/openid/yadis/accept.py b/openid/yadis/accept.py index d7508131..2353bfbf 100644 --- a/openid/yadis/accept.py +++ b/openid/yadis/accept.py @@ -1,6 +1,8 @@ """Functions for generating and parsing HTTP Accept: headers for supporting server-directed content negotiation. """ +from operator import itemgetter + def generateAcceptHeader(*elements): """Generate an accept header value @@ -9,7 +11,7 @@ def generateAcceptHeader(*elements): """ parts = [] for element in elements: - if type(element) is str: + if isinstance(element, str): qs = "1.0" mtype = element else: @@ -32,6 +34,7 @@ def generateAcceptHeader(*elements): return ', '.join(chunks) + def parseAcceptHeader(value): """Parse an accept header, ignoring any accept-extensions @@ -65,11 +68,11 @@ def parseAcceptHeader(value): else: q = 1.0 - accept.append((q, main, sub)) + accept.append((main, sub, q)) + + # Sort in order q, main, sub + return sorted(accept, key=itemgetter(2, 0, 1), reverse=True) - accept.sort() - accept.reverse() - return [(main, sub, q) for (q, main, sub) in accept] def matchTypes(accept_types, have_types): """Given the result of parsing an Accept: header, and the @@ -93,31 +96,32 @@ def matchTypes(accept_types, have_types): match_main = {} match_sub = {} - for (main, sub, q) in accept_types: + for (main, sub, qvalue) in accept_types: if main == '*': - default = max(default, q) + default = max(default, qvalue) continue elif sub == '*': - match_main[main] = max(match_main.get(main, 0), q) + match_main[main] = max(match_main.get(main, 0), qvalue) else: - match_sub[(main, sub)] = max(match_sub.get((main, sub), 0), q) + match_sub[(main, sub)] = max(match_sub.get((main, sub), 0), qvalue) accepted_list = [] order_maintainer = 0 for mtype in have_types: main, sub = mtype.split('/') if (main, sub) in match_sub: - q = match_sub[(main, sub)] + quality = match_sub[(main, sub)] else: - q = match_main.get(main, default) + quality = match_main.get(main, default) - if q: - accepted_list.append((1 - q, order_maintainer, q, mtype)) + if quality: + accepted_list.append((1 - quality, order_maintainer, quality, mtype)) order_maintainer += 1 accepted_list.sort() return [(mtype, q) for (_, _, q, mtype) in accepted_list] + def getAcceptable(accept_header, have_types): """Parse the accept header and return a list of available types in preferred order. If a type is unacceptable, it will not be in the diff --git a/openid/yadis/constants.py b/openid/yadis/constants.py index 75ff96ef..d160c66f 100644 --- a/openid/yadis/constants.py +++ b/openid/yadis/constants.py @@ -10,4 +10,4 @@ ('text/html', 0.3), ('application/xhtml+xml', 0.5), (YADIS_CONTENT_TYPE, 1.0), - ) +) diff --git a/openid/yadis/discover.py b/openid/yadis/discover.py index 27fcd013..83655a90 100644 --- a/openid/yadis/discover.py +++ b/openid/yadis/discover.py @@ -16,6 +16,7 @@ def __init__(self, message, http_response): Exception.__init__(self, message) self.http_response = http_response + class DiscoveryResult(object): """Contains the result of performing Yadis discovery on a URI""" @@ -53,6 +54,7 @@ def isXRDS(self): return (self.usedYadisLocation() or self.content_type == YADIS_CONTENT_TYPE) + def discover(uri): """Discover services for a given URI. @@ -97,7 +99,6 @@ def discover(uri): return result - def whereIsYadis(resp): """Given a HTTPResponse, return the location of the Yadis document. @@ -114,8 +115,7 @@ def whereIsYadis(resp): # According to the spec, the content-type header must be an exact # match, or else we have to look for an indirection. - if (content_type and - content_type.split(';', 1)[0].lower() == YADIS_CONTENT_TYPE): + if content_type and content_type.split(';', 1)[0].lower() == YADIS_CONTENT_TYPE: return resp.final_url else: # Try the header diff --git a/openid/yadis/etxrd.py b/openid/yadis/etxrd.py index 52a8ab32..563a1f2e 100644 --- a/openid/yadis/etxrd.py +++ b/openid/yadis/etxrd.py @@ -16,7 +16,7 @@ 'iterServices', 'expandService', 'expandServices', - ] +] import random import sys @@ -36,9 +36,9 @@ # Make the parser raise an exception so we can sniff out the type # of exceptions ElementTree.XML('> purposely malformed XML <') -except (SystemExit, MemoryError, AssertionError, ImportError): +except (MemoryError, AssertionError, ImportError): raise -except: +except Exception: XMLError = sys.exc_info()[0] @@ -49,14 +49,12 @@ class XRDSError(Exception): reason = None - class XRDSFraud(XRDSError): """Raised when there's an assertion in the XRDS that it does not have the authority to make. """ - def parseXRDS(text): """Parse the given text as an XRDS document. @@ -67,7 +65,7 @@ def parseXRDS(text): """ try: element = ElementTree.XML(text) - except XMLError, why: + except XMLError as why: exc = XRDSError('Error parsing document as XML') exc.reason = why raise exc @@ -78,12 +76,15 @@ def parseXRDS(text): return tree + XRD_NS_2_0 = 'xri://$xrd*($v*2.0)' XRDS_NS = 'xri://$xrds' + def nsTag(ns, t): return '{%s}%s' % (ns, t) + def mkXRDTag(t): """basestring -> basestring @@ -92,6 +93,7 @@ def mkXRDTag(t): """ return nsTag(XRD_NS_2_0, t) + def mkXRDSTag(t): """basestring -> basestring @@ -100,6 +102,7 @@ def mkXRDSTag(t): """ return nsTag(XRDS_NS, t) + # Tags that are used in Yadis documents root_tag = mkXRDSTag('XRDS') service_tag = mkXRDTag('Service') @@ -111,11 +114,13 @@ def mkXRDSTag(t): # Other XRD tags canonicalID_tag = mkXRDTag('CanonicalID') + def isXRDS(xrd_tree): """Is this document an XRDS document?""" root = xrd_tree.getroot() return root.tag == root_tag + def getYadisXRD(xrd_tree): """Return the XRD element that should contain the Yadis services""" xrd = None @@ -132,6 +137,7 @@ def getYadisXRD(xrd_tree): return xrd + def getXRDExpiration(xrd_element, default=None): """Return the expiration date of this XRD element, or None if no expiration was specified. @@ -156,6 +162,7 @@ def getXRDExpiration(xrd_element, default=None): expires_time = strptime(expires_string, "%Y-%m-%dT%H:%M:%SZ") return datetime(*expires_time[0:6]) + def getCanonicalID(iname, xrd_tree): """Return the CanonicalID from this XRDS document. @@ -194,20 +201,22 @@ def getCanonicalID(iname, xrd_tree): return canonicalID - class _Max(object): """Value that compares greater than any other value. Should only be used as a singleton. Implemented for use as a priority value for when a priority is not specified.""" + def __cmp__(self, other): if other is self: return 0 return 1 + Max = _Max() + def getPriorityStrict(element): """Get the priority of this element. @@ -226,6 +235,7 @@ def getPriorityStrict(element): # Any errors in parsing the priority fall through to here return Max + def getPriority(element): """Get the priority of this element @@ -236,17 +246,18 @@ def getPriority(element): except ValueError: return Max + def prioSort(elements): """Sort a list of elements that have priority attributes""" # Randomize the services before sorting so that equal priority # elements are load-balanced. random.shuffle(elements) - prio_elems = [(getPriority(e), e) for e in elements] - prio_elems.sort() + prio_elems = sorted((getPriority(e), e) for e in elements) sorted_elems = [s for (_, s) in prio_elems] return sorted_elems + def iterServices(xrd_tree): """Return an iterable over the Service elements in the Yadis XRD @@ -254,18 +265,21 @@ def iterServices(xrd_tree): xrd = getYadisXRD(xrd_tree) return prioSort(xrd.findall(service_tag)) + def sortedURIs(service_element): """Given a Service element, return a list of the contents of all URI tags in priority order.""" return [uri_element.text for uri_element in prioSort(service_element.findall(uri_tag))] + def getTypeURIs(service_element): """Given a Service element, return a list of the contents of all Type tags""" return [type_element.text for type_element in service_element.findall(type_tag)] + def expandService(service_element): """Take a service element and expand it into an iterator of: ([type_uri], uri, service_element) @@ -281,6 +295,7 @@ def expandService(service_element): return expanded + def expandServices(service_elements): """Take a sorted iterator of service elements and expand it into a sorted iterator of: diff --git a/openid/yadis/filters.py b/openid/yadis/filters.py index 43e4f3f1..1a9d3e74 100644 --- a/openid/yadis/filters.py +++ b/openid/yadis/filters.py @@ -9,7 +9,7 @@ 'IFilter', 'TransformFilterMaker', 'CompoundFilter', - ] +] from openid.yadis.etxrd import expandService @@ -27,6 +27,7 @@ class BasicServiceEndpoint(object): The simplest kind of filter you can write implements fromBasicServiceEndpoint, which takes one of these objects. """ + def __init__(self, yadis_url, type_uris, uri, service_element): self.type_uris = type_uris self.yadis_url = yadis_url @@ -61,6 +62,7 @@ def fromBasicServiceEndpoint(endpoint): fromBasicServiceEndpoint = staticmethod(fromBasicServiceEndpoint) + class IFilter(object): """Interface for Yadis filter objects. Other filter-like things are convertable to this class.""" @@ -69,6 +71,7 @@ def getServiceEndpoints(self, yadis_url, service_element): """Returns an iterator of endpoint objects""" raise NotImplementedError + class TransformFilterMaker(object): """Take a list of basic filters and makes a filter that transforms the basic filter into a top-level filter. This is mostly useful @@ -124,10 +127,12 @@ def applyFilters(self, endpoint): return None + class CompoundFilter(object): """Create a new filter that applies a set of filters to an endpoint and collects their results. """ + def __init__(self, subfilters): self.subfilters = subfilters @@ -140,10 +145,12 @@ def getServiceEndpoints(self, yadis_url, service_element): subfilter.getServiceEndpoints(yadis_url, service_element)) return endpoints + # Exception raised when something is not able to be turned into a filter filter_type_error = TypeError( 'Expected a filter, an endpoint, a callable or a list of any of these.') + def mkFilter(parts): """Convert a filter-convertable thing into a filter @@ -160,6 +167,7 @@ def mkFilter(parts): else: return mkCompoundFilter(parts) + def mkCompoundFilter(parts): """Create a filter out of a list of filter-like things diff --git a/openid/yadis/manager.py b/openid/yadis/manager.py index 709adb7d..afd55eea 100644 --- a/openid/yadis/manager.py +++ b/openid/yadis/manager.py @@ -54,6 +54,7 @@ def store(self, session): """Store this object in the session, by its session key.""" session[self.session_key] = self + class Discovery(object): """State management for discovery. @@ -133,7 +134,7 @@ def cleanup(self, force=False): return service - ### Lower-level methods + # Lower-level methods def getSessionKey(self): """Get the session key for this starting URL and suffix diff --git a/openid/yadis/parsehtml.py b/openid/yadis/parsehtml.py index c2f80294..4ecef3b9 100644 --- a/openid/yadis/parsehtml.py +++ b/openid/yadis/parsehtml.py @@ -8,17 +8,20 @@ # Size of the chunks to search at a time (also the amount that gets # read at a time) -CHUNK_SIZE = 1024 * 16 # 16 KB +CHUNK_SIZE = 1024 * 16 # 16 KB + class ParseDone(Exception): """Exception to hold the URI that was located when the parse is finished. If the parse finishes without finding the URI, set it to None.""" + class MetaNotFound(Exception): """Exception to hold the content of the page if we did not find the appropriate tag""" + re_flags = re.IGNORECASE | re.UNICODE | re.VERBOSE ent_pat = r''' & @@ -32,6 +35,7 @@ class MetaNotFound(Exception): ent_re = re.compile(ent_pat, re_flags) + def substituteMO(mo): if mo.lastgroup == 'hex': codepoint = int(mo.group('hex'), 16) @@ -46,9 +50,11 @@ def substituteMO(mo): else: return unichr(codepoint) + def substituteEntities(s): return ent_re.sub(substituteMO, s) + class YadisHTMLParser(HTMLParser): """Parser that finds a meta http-equiv tag in the head of a html document. @@ -107,7 +113,7 @@ def handle_starttag(self, tag, attrs): # if we ever see a start body tag, bail out right away, since # we want to prevent the meta tag from appearing in the body # [2] - if tag=='body': + if tag == 'body': self._terminate() if self.phase == self.TOP: @@ -155,6 +161,7 @@ def feed(self, chars): return HTMLParser.feed(self, chars) + def findHTMLMeta(stream): """Look for a meta http-equiv tag with the YADIS header name. @@ -171,7 +178,7 @@ def findHTMLMeta(stream): parser = YadisHTMLParser() chunks = [] - while 1: + while True: chunk = stream.read(CHUNK_SIZE) if not chunk: # End of file @@ -180,11 +187,11 @@ def findHTMLMeta(stream): chunks.append(chunk) try: parser.feed(chunk) - except HTMLParseError, why: + except HTMLParseError as why: # HTML parse error, so bail chunks.append(stream.read()) break - except ParseDone, why: + except ParseDone as why: uri = why[0] if uri is None: # Parse finished, but we may need the rest of the file diff --git a/openid/yadis/services.py b/openid/yadis/services.py index 65d88344..740fec0a 100644 --- a/openid/yadis/services.py +++ b/openid/yadis/services.py @@ -27,10 +27,11 @@ def getServiceEndpoints(input_url, flt=None): try: endpoints = applyFilter(result.normalized_uri, result.response_text, flt) - except XRDSError, err: + except XRDSError as err: raise DiscoveryFailure(str(err), None) return (result.normalized_uri, endpoints) + def applyFilter(normalized_uri, xrd_data, flt=None): """Generate an iterable of endpoint objects given this input data, presumably from the result of performing the Yadis protocol. diff --git a/openid/yadis/xri.py b/openid/yadis/xri.py index 3a39a6b8..bd3b29ed 100644 --- a/openid/yadis/xri.py +++ b/openid/yadis/xri.py @@ -1,10 +1,12 @@ # -*- test-case-name: openid.test.test_xri -*- """Utility functions for handling XRIs. -@see: XRI Syntax v2.0 at the U{OASIS XRI Technical Committee} +@see: XRI Syntax v2.0 at the + U{OASIS XRI Technical Committee} """ import re +from functools import reduce XRI_AUTHORITIES = ['!', '=', '@', '+', '$', '('] @@ -16,11 +18,11 @@ (0xA0, 0xD7FF), (0xF900, 0xFDCF), (0xFDF0, 0xFFEF), - ] + ] IPRIVATE = [ (0xE000, 0xF8FF), - ] + ] else: UCSCHAR = [ (0xA0, 0xD7FF), @@ -40,17 +42,17 @@ (0xC0000, 0xCFFFD), (0xD0000, 0xDFFFD), (0xE1000, 0xEFFFD), - ] + ] IPRIVATE = [ (0xE000, 0xF8FF), (0xF0000, 0xFFFFD), (0x100000, 0x10FFFD), - ] + ] _escapeme_re = re.compile('[%s]' % (''.join( - map(lambda (m, n): u'%s-%s' % (unichr(m), unichr(n)), + map(lambda m_n: u'%s-%s' % (unichr(m_n[0]), unichr(m_n[1])), UCSCHAR + IPRIVATE)),)) @@ -59,8 +61,7 @@ def identifierScheme(identifier): @returns: C{"XRI"} or C{"URI"} """ - if identifier.startswith('xri://') or ( - identifier and identifier[0] in XRI_AUTHORITIES): + if identifier.startswith('xri://') or (identifier and identifier[0] in XRI_AUTHORITIES): return "XRI" else: return "URI" @@ -146,8 +147,7 @@ def rootAuthority(xri): else: # IRI reference. XXX: Can IRI authorities have segments? segments = authority.split('!') - segments = reduce(list.__add__, - map(lambda s: s.split('*'), segments)) + segments = reduce(list.__add__, map(lambda s: s.split('*'), segments)) root = segments[0] return XRI(root) diff --git a/openid/yadis/xrires.py b/openid/yadis/xrires.py index e8fd7e4c..4a365950 100644 --- a/openid/yadis/xrires.py +++ b/openid/yadis/xrires.py @@ -11,13 +11,14 @@ DEFAULT_PROXY = 'http://proxy.xri.net/' + class ProxyResolver(object): """Python interface to a remote XRI proxy resolver. """ + def __init__(self, proxy_url=DEFAULT_PROXY): self.proxy_url = proxy_url - def queryURL(self, xri, service_type=None): """Build a URL to query the proxy resolver. @@ -42,7 +43,7 @@ def queryURL(self, xri, service_type=None): # 11:13:42), then we could ask for application/xrd+xml instead, # which would give us a bit less to process. '_xrd_r': 'application/xrds+xml', - } + } if service_type: args['_xrd_t'] = service_type else: @@ -51,7 +52,6 @@ def queryURL(self, xri, service_type=None): query = _appendArgs(hxri, args) return query - def query(self, xri, service_types): """Resolve some services for an XRI. @@ -103,8 +103,7 @@ def _appendArgs(url, args): """ # to be merged with oidutil.appendArgs when we combine the projects. if hasattr(args, 'items'): - args = args.items() - args.sort() + args = sorted(args.items()) if len(args) == 0: return url diff --git a/pylintrc b/pylintrc deleted file mode 100644 index fb36e4c0..00000000 --- a/pylintrc +++ /dev/null @@ -1,40 +0,0 @@ -[REPORTS] - -include-ids=y - -[BASIC] - -# Required attributes for module, separated by a comma -required-attributes=__all__ - -# Regular expression which should only match functions or classes name which do -# not require a docstring -no-docstring-rgx=__.*__ - -# Regular expression which should only match correct module names -module-rgx=[a-z_][a-z0-9_]*$ - -# Regular expression which should only match correct module level names -const-rgx=(([a-z_][a-z0-9_]{3,30})|(__.*__)|([A-Z_][A-Z0-9_]{3,30}))$ - -# Regular expression which should only match correct class names -class-rgx=[A-Z_][a-zA-Z0-9]+$ - -# Regular expression which should only match correct function names -function-rgx=[a-z_][A-Za-z0-9_]{2,30}$ - -# Regular expression which should only match correct method names -method-rgx=[a-z_][A-Za-z0-9_]{2,30}$ - -# Regular expression which should only match correct list comprehension / -# generator expression variable names -inlinevar-rgx=[A-Za-z_][A-Za-z0-9_]*$ - -# Good variable names which should always be accepted, separated by a comma -good-names=i,j,k,ex,Run,_ - -# Bad variable names which should always be refused, separated by a comma -bad-names=foo,bar,baz,toto,tutu,tata - -# List of builtins function names that should not be used, separated by a comma -bad-functions=input diff --git a/setup.py b/setup.py index d68abde6..67ebdf83 100644 --- a/setup.py +++ b/setup.py @@ -35,15 +35,15 @@ author_email='openid@janrain.com', download_url='http://github.com/openid/python-openid/tarball/%s' % (version,), classifiers=[ - "Development Status :: 5 - Production/Stable", - "Environment :: Web Environment", - "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", - "Operating System :: POSIX", - "Programming Language :: Python", - "Topic :: Internet :: WWW/HTTP", - "Topic :: Internet :: WWW/HTTP :: Dynamic Content :: CGI Tools/Libraries", - "Topic :: Software Development :: Libraries :: Python Modules", - "Topic :: System :: Systems Administration :: Authentication/Directory", + "Development Status :: 5 - Production/Stable", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Operating System :: POSIX", + "Programming Language :: Python", + "Topic :: Internet :: WWW/HTTP", + "Topic :: Internet :: WWW/HTTP :: Dynamic Content :: CGI Tools/Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Topic :: System :: Systems Administration :: Authentication/Directory", ], - ) +) From 48ac8bb9acd4b00a44790770a19da05c4e84f8b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 12 Dec 2017 14:45:06 +0100 Subject: [PATCH 025/151] Enable test discovery --- Makefile | 10 +++-- openid/test/__init__.py | 11 ++++++ openid/test/kvform.py | 67 +++++++++++++------------------- openid/test/test_accept.py | 38 ++++++++++++------ openid/test/test_examples.py | 25 ++++++++++-- openid/test/test_openidyadis.py | 69 +++++++++++++++------------------ openid/test/test_parsehtml.py | 58 +++++++++++---------------- openid/test/test_urinorm.py | 38 +++++++++--------- run_tests.sh | 2 - 9 files changed, 164 insertions(+), 154 deletions(-) delete mode 100755 run_tests.sh diff --git a/Makefile b/Makefile index 2cba2775..4e66971c 100644 --- a/Makefile +++ b/Makefile @@ -1,13 +1,15 @@ .PHONY: test coverage isort check-all check-isort check-flake8 test: - python admin/runtests + # TODO: Ignore djopenid tests for the time being + python -m unittest discover --start openid/test -t . coverage: - python-coverage erase + python -m coverage erase -rm -r htmlcov - python-coverage run --branch --source="." admin/runtests - python-coverage html --directory=htmlcov + # TODO: Ignore djopenid tests for the time being + python -m coverage run --branch --source="." openid/test/__init__.py discover --start openid/test -t . + python -m coverage html --directory=htmlcov isort: isort --recursive . diff --git a/openid/test/__init__.py b/openid/test/__init__.py index e69de29b..a503e99a 100644 --- a/openid/test/__init__.py +++ b/openid/test/__init__.py @@ -0,0 +1,11 @@ +"""Openid library tests.""" +import unittest + + +# Utility code to allow run unittest under coverage called as module. +def _run_unittest(): + unittest.main() + + +if __name__ == '__main__': + _run_unittest() diff --git a/openid/test/kvform.py b/openid/test/kvform.py index 7bbb5cef..a2ebd7d0 100644 --- a/openid/test/kvform.py +++ b/openid/test/kvform.py @@ -5,8 +5,6 @@ class KVBaseTest(unittest.TestCase, CatchLogs): - def shortDescription(self): - return '%s test for %r' % (self.__class__.__name__, self.kvform) def checkWarnings(self, num_warnings): self.failUnlessEqual(num_warnings, len(self.messages), repr(self.messages)) @@ -19,35 +17,26 @@ def tearDown(self): class KVDictTest(KVBaseTest): - def __init__(self, kv, dct, warnings): - unittest.TestCase.__init__(self) - self.kvform = kv - self.dict = dct - self.expected_warnings = warnings def runTest(self): - # Convert KVForm to dict - d = kvform.kvToDict(self.kvform) + for kv_data, result, expected_warnings in kvdict_cases: + # Convert KVForm to dict + d = kvform.kvToDict(kv_data) - # make sure it parses to expected dict - self.failUnlessEqual(self.dict, d) + # make sure it parses to expected dict + self.failUnlessEqual(d, result) - # Check to make sure we got the expected number of warnings - self.checkWarnings(self.expected_warnings) + # Check to make sure we got the expected number of warnings + self.checkWarnings(expected_warnings) - # Convert back to KVForm and round-trip back to dict to make - # sure that *** dict -> kv -> dict is identity. *** - kv = kvform.dictToKV(d) - d2 = kvform.kvToDict(kv) - self.failUnlessEqual(d, d2) + # Convert back to KVForm and round-trip back to dict to make + # sure that *** dict -> kv -> dict is identity. *** + kv = kvform.dictToKV(d) + d2 = kvform.kvToDict(kv) + self.failUnlessEqual(d, d2) class KVSeqTest(KVBaseTest): - def __init__(self, seq, kv, expected_warnings): - unittest.TestCase.__init__(self) - self.kvform = kv - self.seq = seq - self.expected_warnings = expected_warnings def cleanSeq(self, seq): """Create a new sequence by stripping whitespace from start @@ -62,19 +51,20 @@ def cleanSeq(self, seq): return clean def runTest(self): - # seq serializes to expected kvform - actual = kvform.seqToKV(self.seq) - self.failUnlessEqual(self.kvform, actual) - self.assertIsInstance(actual, str) + for kv_data, result, expected_warnings in kvseq_cases: + # seq serializes to expected kvform + actual = kvform.seqToKV(kv_data) + self.failUnlessEqual(actual, result) + self.assertIsInstance(actual, str) - # Parse back to sequence. Expected to be unchanged, except - # stripping whitespace from start and end of values - # (i. e. ordering, case, and internal whitespace is preserved) - seq = kvform.kvToSeq(actual) - clean_seq = self.cleanSeq(seq) + # Parse back to sequence. Expected to be unchanged, except + # stripping whitespace from start and end of values + # (i. e. ordering, case, and internal whitespace is preserved) + seq = kvform.kvToSeq(actual) + clean_seq = self.cleanSeq(seq) - self.failUnlessEqual(seq, clean_seq) - self.checkWarnings(self.expected_warnings) + self.failUnlessEqual(seq, clean_seq) + self.checkWarnings(expected_warnings) kvdict_cases = [ @@ -145,15 +135,10 @@ def runTest(self): class KVExcTest(unittest.TestCase): - def __init__(self, seq): - unittest.TestCase.__init__(self) - self.seq = seq - - def shortDescription(self): - return 'KVExcTest for %r' % (self.seq,) def runTest(self): - self.failUnlessRaises(ValueError, kvform.seqToKV, self.seq) + for kv_data in kvexc_cases: + self.failUnlessRaises(ValueError, kvform.seqToKV, kv_data) class GeneralTest(KVBaseTest): diff --git a/openid/test/test_accept.py b/openid/test/test_accept.py index b8af670e..c180f8c7 100644 --- a/openid/test/test_accept.py +++ b/openid/test/test_accept.py @@ -84,20 +84,34 @@ def parseExpected(expected_text): class MatchAcceptTest(unittest.TestCase): - def __init__(self, descr, accept_header, available, expected): - unittest.TestCase.__init__(self) - self.accept_header = accept_header - self.available = available - self.expected = expected - self.descr = descr - - def shortDescription(self): - return self.descr def runTest(self): - accepted = accept.parseAcceptHeader(self.accept_header) - actual = accept.matchTypes(accepted, self.available) - self.failUnlessEqual(self.expected, actual) + lines = getTestData() + chunks = chunk(lines) + data_sets = map(parseLines, chunks) + for data in data_sets: + lnos = [] + lno, accept_header = data['accept'] + lnos.append(lno) + lno, avail_data = data['available'] + lnos.append(lno) + try: + available = parseAvailable(avail_data) + except Exception: + print 'On line', lno + raise + + lno, exp_data = data['expected'] + lnos.append(lno) + try: + expected = parseExpected(exp_data) + except Exception: + print 'On line', lno + raise + + accepted = accept.parseAcceptHeader(accept_header) + actual = accept.matchTypes(accepted, available) + self.failUnlessEqual(expected, actual) def pyUnitTests(): diff --git a/openid/test/test_examples.py b/openid/test/test_examples.py index e1c5797f..3550a4c0 100644 --- a/openid/test/test_examples.py +++ b/openid/test/test_examples.py @@ -7,14 +7,33 @@ import unittest from cStringIO import StringIO -import twill.commands -import twill.parse -import twill.unit +from mock import Mock from openid.consumer.consumer import AuthRequest from openid.consumer.discover import OPENID_1_1_TYPE, OpenIDServiceEndpoint +class FakeTestInfo(object): + """Twill TestInfo placeholder.""" + + def __init__(self, *args, **kwargs): + pass + + +try: + import twill.commands + import twill.parse + import twill.unit +except ImportError: + twill = Mock() + twill.unit.TestInfo = FakeTestInfo + + +def setUpModule(): + if twill.unit.TestInfo == FakeTestInfo: + unittest.skip("Skipping examples, twill is not available.") + + class TwillTest(twill.unit.TestInfo): """Variant of twill.unit.TestInfo that runs a function as a test script, not twill script from a file. diff --git a/openid/test/test_openidyadis.py b/openid/test/test_openidyadis.py index 16aebea1..3f730b5e 100644 --- a/openid/test/test_openidyadis.py +++ b/openid/test/test_openidyadis.py @@ -106,59 +106,54 @@ def subsets(l): class OpenIDYadisTest(unittest.TestCase): - def __init__(self, uris, type_uris, local_id): - unittest.TestCase.__init__(self) - self.uris = uris - self.type_uris = type_uris - self.local_id = local_id + + yadis_url = 'http://unit.test/' def shortDescription(self): # XXX: return 'Successful OpenID Yadis parsing case' - def setUp(self): - self.yadis_url = 'http://unit.test/' - + def make_xrds(self, uris, type_uris, local_id): # Create an XRDS document to parse - services = mkService(uris=self.uris, - type_uris=self.type_uris, - local_id=self.local_id) - self.xrds = mkXRDS(services) + services = mkService(uris=uris, + type_uris=type_uris, + local_id=local_id) + return mkXRDS(services) def runTest(self): - # Parse into endpoint objects that we will check - endpoints = applyFilter( - self.yadis_url, self.xrds, OpenIDServiceEndpoint) + for uris, type_uris, local_id in data: + # Parse into endpoint objects that we will check + endpoints = applyFilter(self.yadis_url, self.make_xrds(uris, type_uris, local_id), OpenIDServiceEndpoint) - # make sure there are the same number of endpoints as - # URIs. This assumes that the type_uris contains at least one - # OpenID type. - self.failUnlessEqual(len(self.uris), len(endpoints)) + # make sure there are the same number of endpoints as + # URIs. This assumes that the type_uris contains at least one + # OpenID type. + self.failUnlessEqual(len(uris), len(endpoints)) - # So that we can check equality on the endpoint types - type_uris = sorted(self.type_uris) + # So that we can check equality on the endpoint types + type_uris = sorted(type_uris) - seen_uris = [] - for endpoint in endpoints: - seen_uris.append(endpoint.server_url) + seen_uris = [] + for endpoint in endpoints: + seen_uris.append(endpoint.server_url) - # All endpoints will have same yadis_url - self.failUnlessEqual(self.yadis_url, endpoint.claimed_id) + # All endpoints will have same yadis_url + self.failUnlessEqual(self.yadis_url, endpoint.claimed_id) - # and local_id - self.failUnlessEqual(self.local_id, endpoint.local_id) + # and local_id + self.failUnlessEqual(local_id, endpoint.local_id) - # and types - actual_types = sorted(endpoint.type_uris) - self.failUnlessEqual(actual_types, type_uris) + # and types + actual_types = sorted(endpoint.type_uris) + self.failUnlessEqual(actual_types, type_uris) - # So that they will compare equal, because we don't care what - # order they are in - seen_uris.sort() - uris = sorted(self.uris) + # So that they will compare equal, because we don't care what + # order they are in + seen_uris.sort() + uris = sorted(uris) - # Make sure we saw all URIs, and saw each one once - self.failUnlessEqual(uris, seen_uris) + # Make sure we saw all URIs, and saw each one once + self.failUnlessEqual(uris, seen_uris) def pyUnitTests(): diff --git a/openid/test/test_parsehtml.py b/openid/test/test_parsehtml.py index a3c038da..4ee1b616 100644 --- a/openid/test/test_parsehtml.py +++ b/openid/test/test_parsehtml.py @@ -9,40 +9,28 @@ class _TestCase(unittest.TestCase): reserved_values = ['None', 'EOF'] - def __init__(self, filename, testname, expected, case): - self.filename = filename - self.testname = testname - self.expected = expected - self.case = case - unittest.TestCase.__init__(self) - def runTest(self): - p = YadisHTMLParser() - try: - p.feed(self.case) - except ParseDone as why: - found = why[0] - - # make sure we protect outselves against accidental bogus - # test cases - assert found not in self.reserved_values - - # convert to a string - if found is None: - found = 'None' - - msg = "%r != %r for case %s" % (found, self.expected, self.case) - self.failUnlessEqual(found, self.expected, msg) - except HTMLParseError: - self.failUnless(self.expected == 'None', (self.case, self.expected)) - else: - self.failUnless(self.expected == 'EOF', (self.case, self.expected)) - - def shortDescription(self): - return "%s (%s<%s>)" % ( - self.testname, - self.__class__.__module__, - os.path.basename(self.filename)) + for expected, case in getCases(): + p = YadisHTMLParser() + try: + p.feed(case) + except ParseDone as why: + found = why[0] + + # make sure we protect outselves against accidental bogus + # test cases + assert found not in self.reserved_values + + # convert to a string + if found is None: + found = 'None' + + msg = "%r != %r for case %s" % (found, expected, case) + self.failUnlessEqual(found, expected, msg) + except HTMLParseError: + self.failUnless(expected == 'None', (case, expected)) + else: + self.failUnless(expected == 'EOF', (case, expected)) def parseCases(data): @@ -78,11 +66,9 @@ def test(): def getCases(test_files=default_test_files): cases = [] for filename in test_files: - test_num = 0 data = file(filename).read() for expected, case in parseCases(data): - test_num += 1 - cases.append((filename, test_num, expected, case)) + cases.append((expected, case)) return cases diff --git a/openid/test/test_urinorm.py b/openid/test/test_urinorm.py index 98e9f490..49f18fc9 100644 --- a/openid/test/test_urinorm.py +++ b/openid/test/test_urinorm.py @@ -4,31 +4,31 @@ import openid.urinorm -class UrinormTest(unittest.TestCase): - def __init__(self, desc, case, expected): - unittest.TestCase.__init__(self) - self.desc = desc - self.case = case - self.expected = expected +with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'urinorm.txt')) as test_data_file: + test_data = test_data_file.read() - def shortDescription(self): - return self.desc + +class UrinormTest(unittest.TestCase): def runTest(self): - try: - actual = openid.urinorm.urinorm(self.case) - except ValueError as why: - self.assertEqual(self.expected, 'fail', why) - else: - self.assertEqual(actual, self.expected) - - def parse(cls, full_case): + for case in test_data.split('\n\n'): + case = case.strip() + if not case: + continue + + desc, raw, expected = self.parse(case) + try: + actual = openid.urinorm.urinorm(raw) + except ValueError as why: + self.assertEqual(expected, 'fail', why) + else: + self.assertEqual(actual, expected, desc) + + def parse(self, full_case): desc, case, expected = full_case.split('\n') case = unicode(case, 'utf-8') - return cls(desc, case, expected) - - parse = classmethod(parse) + return (desc, case, expected) def parseTests(test_data): diff --git a/run_tests.sh b/run_tests.sh deleted file mode 100755 index 9cb637fa..00000000 --- a/run_tests.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/sh -python admin/runtests From 265f022c3c95ce666d60c0c0d385cf9c63b3f8af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 12 Dec 2017 17:22:21 +0100 Subject: [PATCH 026/151] Clean data driven tests --- openid/test/datadriven.py | 48 ------------------ openid/test/test_association.py | 42 ++++----------- openid/test/test_discover.py | 87 ++++++++++++-------------------- openid/test/test_htmldiscover.py | 23 +++------ openid/test/test_nonce.py | 39 ++++---------- 5 files changed, 59 insertions(+), 180 deletions(-) delete mode 100644 openid/test/datadriven.py diff --git a/openid/test/datadriven.py b/openid/test/datadriven.py deleted file mode 100644 index aac6e9db..00000000 --- a/openid/test/datadriven.py +++ /dev/null @@ -1,48 +0,0 @@ -import types -import unittest - - -class DataDrivenTestCase(unittest.TestCase): - cases = [] - - def generateCases(cls): - return cls.cases - - generateCases = classmethod(generateCases) - - def loadTests(cls): - tests = [] - for case in cls.generateCases(): - if isinstance(case, tuple): - test = cls(*case) - elif isinstance(case, dict): - test = cls(**case) - else: - test = cls(case) - tests.append(test) - return tests - - loadTests = classmethod(loadTests) - - def __init__(self, description): - unittest.TestCase.__init__(self, 'runOneTest') - self.description = description - - def shortDescription(self): - return '%s for %s' % (self.__class__.__name__, self.description) - - -def loadTests(module_name): - loader = unittest.defaultTestLoader - this_module = __import__(module_name, {}, {}, [None]) - - tests = [] - for name in dir(this_module): - obj = getattr(this_module, name) - if isinstance(obj, (type, types.ClassType)) and issubclass(obj, unittest.TestCase): - if hasattr(obj, 'loadTests'): - tests.extend(obj.loadTests()) - else: - tests.append(loader.loadTestsFromTestCase(obj)) - - return unittest.TestSuite(tests) diff --git a/openid/test/test_association.py b/openid/test/test_association.py index 929d5b6e..e0598aff 100644 --- a/openid/test/test_association.py +++ b/openid/test/test_association.py @@ -6,7 +6,6 @@ from openid.dh import DiffieHellman from openid.message import BARE_NS, OPENID2_NS, OPENID_NS, Message from openid.server.server import DiffieHellmanSHA1ServerSession, PlainTextServerSession -from openid.test import datadriven class AssociationSerializationTest(unittest.TestCase): @@ -29,7 +28,7 @@ def createNonstandardConsumerDH(): return DiffieHellmanSHA1ConsumerSession(nonstandard_dh) -class DiffieHellmanSessionTest(datadriven.DataDrivenTestCase): +class DiffieHellmanSessionTest(unittest.TestCase): secrets = [ '\x00' * 20, '\xff' * 20, @@ -43,26 +42,15 @@ class DiffieHellmanSessionTest(datadriven.DataDrivenTestCase): (PlainTextConsumerSession, PlainTextServerSession), ] - def generateCases(cls): - return [(c, s, sec) - for c, s in cls.session_factories - for sec in cls.secrets] - - generateCases = classmethod(generateCases) - - def __init__(self, csess_fact, ssess_fact, secret): - datadriven.DataDrivenTestCase.__init__(self, csess_fact.__name__) - self.secret = secret - self.csess_fact = csess_fact - self.ssess_fact = ssess_fact - - def runOneTest(self): - csess = self.csess_fact() - msg = Message.fromOpenIDArgs(csess.getRequest()) - ssess = self.ssess_fact.fromMessage(msg) - check_secret = csess.extractSecret( - Message.fromOpenIDArgs(ssess.answer(self.secret))) - self.failUnlessEqual(self.secret, check_secret) + def test(self): + for csess_fact, ssess_fact in self.session_factories: + for secret in self.secrets: + csess = csess_fact() + msg = Message.fromOpenIDArgs(csess.getRequest()) + ssess = ssess_fact.fromMessage(msg) + check_secret = csess.extractSecret( + Message.fromOpenIDArgs(ssess.answer(secret))) + self.failUnlessEqual(secret, check_secret) class TestMakePairs(unittest.TestCase): @@ -155,13 +143,3 @@ def test_aintGotSignedList(self): assoc = association.Association.fromExpiresIn( 3600, '{sha1}', 'very_secret', "HMAC-SHA1") self.failUnlessRaises(ValueError, assoc.checkMessageSignature, m) - - -def pyUnitTests(): - return datadriven.loadTests(__name__) - - -if __name__ == '__main__': - suite = pyUnitTests() - runner = unittest.TextTestRunner() - runner.run(suite) diff --git a/openid/test/test_discover.py b/openid/test/test_discover.py index 29a73ffa..92fc6960 100644 --- a/openid/test/test_discover.py +++ b/openid/test/test_discover.py @@ -11,8 +11,6 @@ from openid.yadis.discover import DiscoveryFailure from openid.yadis.xri import XRI -from . import datadriven - # Tests for conditions that trigger DiscoveryFailure @@ -27,7 +25,7 @@ def fetch(self, url, body=None, headers=None): return response -class TestDiscoveryFailure(datadriven.DataDrivenTestCase): +class TestDiscoveryFailure(unittest.TestCase): cases = [ [HTTPResponse('http://network.error/', None)], [HTTPResponse('http://not.found/', 404)], @@ -38,27 +36,24 @@ class TestDiscoveryFailure(datadriven.DataDrivenTestCase): HTTPResponse('http://xrds.missing/', 404)], ] - def __init__(self, responses): - self.url = responses[0].final_url - datadriven.DataDrivenTestCase.__init__(self, self.url) - self.responses = responses - - def setUp(self): - fetcher = SimpleMockFetcher(self.responses) - fetchers.setDefaultFetcher(fetcher) - - def tearDown(self): - fetchers.setDefaultFetcher(None) - - def runOneTest(self): - expected_status = self.responses[-1].status + def runOneTest(self, url, expected_status): try: - discover.discover(self.url) + discover.discover(url) except DiscoveryFailure as why: self.failUnlessEqual(why.http_response.status, expected_status) else: self.fail('Did not raise DiscoveryFailure') + def test(self): + for responses in self.cases: + url = responses[0].final_url + status = responses[-1].status + + fetcher = SimpleMockFetcher(responses) + fetchers.setDefaultFetcher(fetcher) + self.runOneTest(url, status) + fetchers.setDefaultFetcher(None) + # Tests for raising/catching exceptions from the fetcher through the # discover function @@ -77,7 +72,7 @@ class DidFetch(Exception): """Custom exception just to make sure it's not handled differently""" -class TestFetchException(datadriven.DataDrivenTestCase): +class TestFetchException(unittest.TestCase): """Make sure exceptions get passed through discover function from fetcher.""" @@ -88,29 +83,25 @@ class TestFetchException(datadriven.DataDrivenTestCase): RuntimeError(), ] - def __init__(self, exc): - datadriven.DataDrivenTestCase.__init__(self, repr(exc)) - self.exc = exc - - def setUp(self): - fetcher = ErrorRaisingFetcher(self.exc) - fetchers.setDefaultFetcher(fetcher, wrap_exceptions=False) - - def tearDown(self): - fetchers.setDefaultFetcher(None) - - def runOneTest(self): + def runOneTest(self, exc): try: discover.discover('http://doesnt.matter/') except Exception: exc = sys.exc_info()[1] if exc is None: # str exception - self.failUnless(self.exc is sys.exc_info()[0]) + self.failUnless(exc is sys.exc_info()[0]) else: - self.failUnless(self.exc is exc, exc) + self.failUnless(exc is exc, exc) else: - self.fail('Expected %r', self.exc) + self.fail('Expected %r', exc) + + def test(self): + for exc in self.cases: + fetcher = ErrorRaisingFetcher(exc) + fetchers.setDefaultFetcher(fetcher, wrap_exceptions=False) + self.runOneTest(exc) + fetchers.setDefaultFetcher(None) # Tests for openid.consumer.discover.discover @@ -621,18 +612,14 @@ def test_xri(self): "http://www.livejournal.com/openid/server.bml") -class TestPreferredNamespace(datadriven.DataDrivenTestCase): - def __init__(self, expected_ns, type_uris): - datadriven.DataDrivenTestCase.__init__( - self, 'Expecting %s from %s' % (expected_ns, type_uris)) - self.expected_ns = expected_ns - self.type_uris = type_uris +class TestPreferredNamespace(unittest.TestCase): - def runOneTest(self): - endpoint = discover.OpenIDServiceEndpoint() - endpoint.type_uris = self.type_uris - actual_ns = endpoint.preferredNamespace() - self.failUnlessEqual(actual_ns, self.expected_ns) + def test(self): + for expected_ns, type_uris in self.cases: + endpoint = discover.OpenIDServiceEndpoint() + endpoint.type_uris = type_uris + actual_ns = endpoint.preferredNamespace() + self.failUnlessEqual(actual_ns, expected_ns) cases = [ (message.OPENID1_NS, []), @@ -797,13 +784,3 @@ def test_strip_fragment(self): endpoint = discover.OpenIDServiceEndpoint() endpoint.claimed_id = 'http://recycled.invalid/#123' self.failUnlessEqual('http://recycled.invalid/', endpoint.getDisplayIdentifier()) - - -def pyUnitTests(): - return datadriven.loadTests(__name__) - - -if __name__ == '__main__': - suite = pyUnitTests() - runner = unittest.TextTestRunner() - runner.run(suite) diff --git a/openid/test/test_htmldiscover.py b/openid/test/test_htmldiscover.py index 188565b2..9d7344a3 100644 --- a/openid/test/test_htmldiscover.py +++ b/openid/test/test_htmldiscover.py @@ -1,24 +1,17 @@ -from openid.consumer.discover import OpenIDServiceEndpoint +import unittest -from . import datadriven +from openid.consumer.discover import OpenIDServiceEndpoint -class BadLinksTestCase(datadriven.DataDrivenTestCase): +class BadLinksTestCase(unittest.TestCase): cases = [ '', "http://not.in.a.link.tag/", '', ] - def __init__(self, data): - datadriven.DataDrivenTestCase.__init__(self, data) - self.data = data - - def runOneTest(self): - actual = OpenIDServiceEndpoint.fromHTML('http://unused.url/', self.data) - expected = [] - self.failUnlessEqual(expected, actual) - - -def pyUnitTests(): - return datadriven.loadTests(__name__) + def test_from_html(self): + for html in self.cases: + actual = OpenIDServiceEndpoint.fromHTML('http://unused.url/', html) + expected = [] + self.failUnlessEqual(expected, actual) diff --git a/openid/test/test_nonce.py b/openid/test/test_nonce.py index 7b271346..fcc4687f 100644 --- a/openid/test/test_nonce.py +++ b/openid/test/test_nonce.py @@ -2,7 +2,6 @@ import unittest from openid.store.nonce import checkTimestamp, mkNonce, split as splitNonce -from openid.test import datadriven nonce_re = re.compile(r'\A\d{4}-\d\d-\d\dT\d\d:\d\d:\d\dZ') @@ -36,7 +35,7 @@ def test_mkSplit(self): self.failUnlessEqual(et, t) -class BadSplitTest(datadriven.DataDrivenTestCase): +class BadSplitTest(unittest.TestCase): cases = [ '', '1970-01-01T00:00:00+1:00', @@ -47,15 +46,12 @@ class BadSplitTest(datadriven.DataDrivenTestCase): 'monkeys', ] - def __init__(self, nonce_str): - datadriven.DataDrivenTestCase.__init__(self, nonce_str) - self.nonce_str = nonce_str + def test(self): + for nonce_str in self.cases: + self.failUnlessRaises(ValueError, splitNonce, nonce_str) - def runOneTest(self): - self.failUnlessRaises(ValueError, splitNonce, self.nonce_str) - -class CheckTimestampTest(datadriven.DataDrivenTestCase): +class CheckTimestampTest(unittest.TestCase): cases = [ # exact, no allowed skew ('1970-01-01T00:00:00Z', 0, 0, True), @@ -82,24 +78,7 @@ class CheckTimestampTest(datadriven.DataDrivenTestCase): ('monkeys', 0, 0, False), ] - def __init__(self, nonce_string, allowed_skew, now, expected): - datadriven.DataDrivenTestCase.__init__( - self, repr((nonce_string, allowed_skew, now))) - self.nonce_string = nonce_string - self.allowed_skew = allowed_skew - self.now = now - self.expected = expected - - def runOneTest(self): - actual = checkTimestamp(self.nonce_string, self.allowed_skew, self.now) - self.failUnlessEqual(bool(self.expected), bool(actual)) - - -def pyUnitTests(): - return datadriven.loadTests(__name__) - - -if __name__ == '__main__': - suite = pyUnitTests() - runner = unittest.TextTestRunner() - runner.run(suite) + def test(self): + for nonce_string, allowed_skew, now, expected in self.cases: + actual = checkTimestamp(nonce_string, allowed_skew, now) + self.failUnlessEqual(bool(expected), bool(actual)) From a3dd223cd78bafcedc2714090cb264ccef9948b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 12 Dec 2017 17:47:30 +0100 Subject: [PATCH 027/151] Rename test modules --- openid/test/{cryptutil.py => test_cryptutil.py} | 0 openid/test/{dh.py => test_dh.py} | 0 openid/test/{kvform.py => test_kvform.py} | 0 openid/test/{linkparse.py => test_linkparse.py} | 0 openid/test/{oidutil.py => test_oidutil.py} | 0 openid/test/{storetest.py => test_storetest.py} | 0 openid/test/{trustroot.py => test_trustroot.py} | 0 7 files changed, 0 insertions(+), 0 deletions(-) rename openid/test/{cryptutil.py => test_cryptutil.py} (100%) rename openid/test/{dh.py => test_dh.py} (100%) rename openid/test/{kvform.py => test_kvform.py} (100%) rename openid/test/{linkparse.py => test_linkparse.py} (100%) rename openid/test/{oidutil.py => test_oidutil.py} (100%) rename openid/test/{storetest.py => test_storetest.py} (100%) rename openid/test/{trustroot.py => test_trustroot.py} (100%) diff --git a/openid/test/cryptutil.py b/openid/test/test_cryptutil.py similarity index 100% rename from openid/test/cryptutil.py rename to openid/test/test_cryptutil.py diff --git a/openid/test/dh.py b/openid/test/test_dh.py similarity index 100% rename from openid/test/dh.py rename to openid/test/test_dh.py diff --git a/openid/test/kvform.py b/openid/test/test_kvform.py similarity index 100% rename from openid/test/kvform.py rename to openid/test/test_kvform.py diff --git a/openid/test/linkparse.py b/openid/test/test_linkparse.py similarity index 100% rename from openid/test/linkparse.py rename to openid/test/test_linkparse.py diff --git a/openid/test/oidutil.py b/openid/test/test_oidutil.py similarity index 100% rename from openid/test/oidutil.py rename to openid/test/test_oidutil.py diff --git a/openid/test/storetest.py b/openid/test/test_storetest.py similarity index 100% rename from openid/test/storetest.py rename to openid/test/test_storetest.py diff --git a/openid/test/trustroot.py b/openid/test/test_trustroot.py similarity index 100% rename from openid/test/trustroot.py rename to openid/test/test_trustroot.py From a14a20576f1d5e21806886a0af1b89203c8ca31a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 12 Dec 2017 16:46:35 +0100 Subject: [PATCH 028/151] Drop custom test runner --- .travis.yml | 2 +- admin/fixperms | 3 +- admin/runtests | 207 ------------------- openid/test/test_accept.py | 37 +--- openid/test/test_cryptutil.py | 171 +++++++--------- openid/test/test_dh.py | 135 ++++++------- openid/test/test_fetchers.py | 36 ++-- openid/test/test_kvform.py | 31 ++- openid/test/test_linkparse.py | 86 +++----- openid/test/test_oidutil.py | 271 +++++++++++-------------- openid/test/test_openidyadis.py | 7 - openid/test/test_parsehtml.py | 22 +-- openid/test/test_storetest.py | 307 ++++++++++++++--------------- openid/test/test_trustroot.py | 87 +++----- openid/test/test_urinorm.py | 25 --- openid/test/test_yadis_discover.py | 117 ++++------- 16 files changed, 537 insertions(+), 1007 deletions(-) delete mode 100755 admin/runtests diff --git a/.travis.yml b/.travis.yml index fe119573..7d67a575 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,7 +3,7 @@ language: python python: - 2.7 -before_install: pip install Django pycrypto lxml isort flake8 +before_install: pip install 'Django<=1.11.99' pycrypto lxml isort flake8 install: python setup.py install script: - make check-isort diff --git a/admin/fixperms b/admin/fixperms index d0303e11..8bcf8eca 100755 --- a/admin/fixperms +++ b/admin/fixperms @@ -4,7 +4,6 @@ admin/builddiscover.py admin/fixperms admin/makechangelog admin/pythonsource -admin/runtests admin/setversion admin/tagrelease -EOF \ No newline at end of file +EOF diff --git a/admin/runtests b/admin/runtests deleted file mode 100755 index db7a647e..00000000 --- a/admin/runtests +++ /dev/null @@ -1,207 +0,0 @@ -#!/usr/bin/env python -import os.path -import sys -import warnings - -test_modules = [ - 'cryptutil', - 'oidutil', - 'dh', -] - - -def fixpath(): - try: - d = os.path.dirname(__file__) - except NameError: - d = os.path.dirname(sys.argv[0]) - parent = os.path.normpath(os.path.join(d, '..')) - if parent not in sys.path: - print "putting %s in sys.path" % (parent,) - sys.path.insert(0, parent) - - -def otherTests(): - failed = [] - for module_name in test_modules: - print 'Testing %s...' % (module_name,), - sys.stdout.flush() - module_name = 'openid.test.' + module_name - try: - test_mod = __import__(module_name, {}, {}, [None]) - except ImportError: - print 'Failed to import test %r' % (module_name,) - failed.append(module_name) - else: - try: - test_mod.test() - except Exception: - sys.excepthook(*sys.exc_info()) - failed.append(module_name) - else: - print 'Succeeded.' - - return failed - - -def pyunitTests(): - import unittest - pyunit_module_names = [ - 'server', - 'consumer', - 'message', - 'symbol', - 'etxrd', - 'xri', - 'xrires', - 'association_response', - 'auth_request', - 'negotiation', - 'verifydisco', - 'sreg', - 'ax', - 'pape', - 'pape_draft2', - 'pape_draft5', - 'rpverify', - 'extension', - ] - - pyunit_modules = [ - __import__('openid.test.test_%s' % (name,), {}, {}, ['unused']) - for name in pyunit_module_names - ] - - try: - from openid.test import test_examples - except ImportError as e: - if 'twill' in str(e): - warnings.warn("Could not import twill; skipping test_examples.") - else: - raise - else: - pyunit_modules.append(test_examples) - - # Some modules have data-driven tests, and they use custom methods - # to build the test suite: - custom_module_names = [ - 'kvform', - 'linkparse', - 'oidutil', - 'storetest', - 'test_accept', - 'test_association', - 'test_discover', - 'test_fetchers', - 'test_htmldiscover', - 'test_nonce', - 'test_openidyadis', - 'test_parsehtml', - 'test_urinorm', - 'test_yadis_discover', - 'trustroot', - ] - - loader = unittest.TestLoader() - s = unittest.TestSuite() - - for m in pyunit_modules: - s.addTest(loader.loadTestsFromModule(m)) - - for name in custom_module_names: - m = __import__('openid.test.%s' % (name,), {}, {}, ['unused']) - try: - s.addTest(m.pyUnitTests()) - except AttributeError as ex: - # because the AttributeError doesn't actually say which - # object it was. - print "Error loading tests from %s:" % (name,) - raise - - runner = unittest.TextTestRunner() # verbosity=2) - - return runner.run(s) - - -def splitDir(d, count): - # in python2.4 and above, it's easier to spell this as - # d.rsplit(os.sep, count) - for i in xrange(count): - d = os.path.dirname(d) - return d - - -def _import_djopenid(): - """Import djopenid from examples/ - - It's not in sys.path, and I don't really want to put it in sys.path. - """ - import types - thisfile = os.path.abspath(sys.modules[__name__].__file__) - topDir = splitDir(thisfile, 2) - djdir = os.path.join(topDir, 'examples', 'djopenid') - - djinit = os.path.join(djdir, '__init__.py') - - djopenid = types.ModuleType('djopenid') - execfile(djinit, djopenid.__dict__) - djopenid.__file__ = djinit - - # __path__ is the magic that makes child modules of the djopenid package - # importable. New feature in python 2.3, see PEP 302. - djopenid.__path__ = [djdir] - sys.modules['djopenid'] = djopenid - - -def django_tests(): - """Runs tests from examples/djopenid. - - @returns: number of failed tests. - """ - import os - # Django uses this to find out where its settings are. - os.environ['DJANGO_SETTINGS_MODULE'] = 'djopenid.settings' - - _import_djopenid() - - try: - import django.test.simple - except ImportError as e: - warnings.warn("django.test.simple not found; " - "django examples not tested.") - return 0 - import djopenid.server.models - import djopenid.consumer.models - print "Testing Django examples:" - - # These tests do get put in to a pyunit test suite, so we could run them - # with the other pyunit tests, but django also establishes a test database - # for them, so we let it do that thing instead. - return django.test.simple.run_tests([djopenid.server.models, - djopenid.consumer.models]) - - -try: - bool -except NameError: - def bool(x): - return not not x - - -def main(): - fixpath() - other_failed = otherTests() - pyunit_result = pyunitTests() - django_failures = django_tests() - - if other_failed: - print 'Failures:', ', '.join(other_failed) - - failed = (bool(other_failed) or - bool(not pyunit_result.wasSuccessful()) or - (django_failures > 0)) - return failed - - -if __name__ == '__main__': - sys.exit(main() and 1 or 0) diff --git a/openid/test/test_accept.py b/openid/test/test_accept.py index c180f8c7..55e7eded 100644 --- a/openid/test/test_accept.py +++ b/openid/test/test_accept.py @@ -1,3 +1,4 @@ +"""Test `openid.yadis.accept` module.""" import os.path import unittest @@ -112,39 +113,3 @@ def runTest(self): accepted = accept.parseAcceptHeader(accept_header) actual = accept.matchTypes(accepted, available) self.failUnlessEqual(expected, actual) - - -def pyUnitTests(): - lines = getTestData() - chunks = chunk(lines) - data_sets = map(parseLines, chunks) - cases = [] - for data in data_sets: - lnos = [] - lno, header = data['accept'] - lnos.append(lno) - lno, avail_data = data['available'] - lnos.append(lno) - try: - available = parseAvailable(avail_data) - except Exception: - print 'On line', lno - raise - - lno, exp_data = data['expected'] - lnos.append(lno) - try: - expected = parseExpected(exp_data) - except Exception: - print 'On line', lno - raise - - descr = 'MatchAcceptTest for lines %r' % (lnos,) - case = MatchAcceptTest(descr, header, available, expected) - cases.append(case) - return unittest.TestSuite(cases) - - -if __name__ == '__main__': - runner = unittest.TextTestRunner() - runner.run(pyUnitTests()) diff --git a/openid/test/test_cryptutil.py b/openid/test/test_cryptutil.py index cf6074c1..4cf57b00 100644 --- a/openid/test/test_cryptutil.py +++ b/openid/test/test_cryptutil.py @@ -1,6 +1,8 @@ +"""Test `openid.cryptutil` module.""" import os.path import random import sys +import unittest from openid import cryptutil @@ -8,107 +10,84 @@ # find a good source of randomness on this machine. -def test_cryptrand(): - # It's possible, but HIGHLY unlikely that a correct implementation - # will fail by returning the same number twice +class TestRandRange(unittest.TestCase): + """Test `randrange` function.""" - s = cryptutil.getBytes(32) - t = cryptutil.getBytes(32) - assert len(s) == 32 - assert len(t) == 32 - assert s != t + def test_cryptrand(self): + # It's possible, but HIGHLY unlikely that a correct implementation + # will fail by returning the same number twice - a = cryptutil.randrange(2 ** 128) - b = cryptutil.randrange(2 ** 128) - assert isinstance(a, long) - assert isinstance(b, long) - assert b != a + s = cryptutil.getBytes(32) + t = cryptutil.getBytes(32) + assert len(s) == 32 + assert len(t) == 32 + assert s != t - # Make sure that we can generate random numbers that are larger - # than platform int size - cryptutil.randrange(long(sys.maxsize) + 1) + a = cryptutil.randrange(2 ** 128) + b = cryptutil.randrange(2 ** 128) + assert isinstance(a, long) + assert isinstance(b, long) + assert b != a + # Make sure that we can generate random numbers that are larger + # than platform int size + cryptutil.randrange(long(sys.maxsize) + 1) + + +class TestLongBinary(unittest.TestCase): + """Test `longToBinary` and `binaryToLong` functions.""" + + def test_binaryLongConvert(self): + MAX = sys.maxsize + for iteration in xrange(500): + n = 0 + for i in range(10): + n += long(random.randrange(MAX)) + + s = cryptutil.longToBinary(n) + assert isinstance(s, str) + n_prime = cryptutil.binaryToLong(s) + assert n == n_prime, (n, n_prime) -def test_reversed(): - if hasattr(cryptutil, 'reversed'): cases = [ - ('', ''), - ('a', 'a'), - ('ab', 'ba'), - ('abc', 'cba'), - ('abcdefg', 'gfedcba'), - ([], []), - ([1], [1]), - ([1, 2], [2, 1]), - ([1, 2, 3], [3, 2, 1]), - (range(1000), range(999, -1, -1)), + ('\x00', 0), + ('\x01', 1), + ('\x7F', 127), + ('\x00\xFF', 255), + ('\x00\x80', 128), + ('\x00\x81', 129), + ('\x00\x80\x00', 32768), + ('OpenID is cool', 1611215304203901150134421257416556) ] - for case, expected in cases: - expected = list(expected) - actual = list(cryptutil.reversed(case)) - assert actual == expected, (case, expected, actual) - twice = list(cryptutil.reversed(actual)) - assert twice == list(case), (actual, case, twice) - - -def test_binaryLongConvert(): - MAX = sys.maxsize - for iteration in xrange(500): - n = 0 - for i in range(10): - n += long(random.randrange(MAX)) - - s = cryptutil.longToBinary(n) - assert isinstance(s, str) - n_prime = cryptutil.binaryToLong(s) - assert n == n_prime, (n, n_prime) - - cases = [ - ('\x00', 0), - ('\x01', 1), - ('\x7F', 127), - ('\x00\xFF', 255), - ('\x00\x80', 128), - ('\x00\x81', 129), - ('\x00\x80\x00', 32768), - ('OpenID is cool', 1611215304203901150134421257416556) - ] - - for s, n in cases: - n_prime = cryptutil.binaryToLong(s) - s_prime = cryptutil.longToBinary(n) - assert n == n_prime, (s, n, n_prime) - assert s == s_prime, (n, s, s_prime) - - -def test_longToBase64(): - f = file(os.path.join(os.path.dirname(__file__), 'n2b64')) - try: - for line in f: - parts = line.strip().split(' ') - assert parts[0] == cryptutil.longToBase64(long(parts[1])) - finally: - f.close() - - -def test_base64ToLong(): - f = file(os.path.join(os.path.dirname(__file__), 'n2b64')) - try: - for line in f: - parts = line.strip().split(' ') - assert long(parts[1]) == cryptutil.base64ToLong(parts[0]) - finally: - f.close() - - -def test(): - test_reversed() - test_binaryLongConvert() - test_cryptrand() - test_longToBase64() - test_base64ToLong() - - -if __name__ == '__main__': - test() + for s, n in cases: + n_prime = cryptutil.binaryToLong(s) + s_prime = cryptutil.longToBinary(n) + assert n == n_prime, (s, n, n_prime) + assert s == s_prime, (n, s, s_prime) + + +class TestLongToBase64(unittest.TestCase): + """Test `longToBase64` function.""" + + def test_longToBase64(self): + f = file(os.path.join(os.path.dirname(__file__), 'n2b64')) + try: + for line in f: + parts = line.strip().split(' ') + assert parts[0] == cryptutil.longToBase64(long(parts[1])) + finally: + f.close() + + +class TestBase64ToLong(unittest.TestCase): + """Test `Base64ToLong` function.""" + + def test_base64ToLong(self): + f = file(os.path.join(os.path.dirname(__file__), 'n2b64')) + try: + for line in f: + parts = line.strip().split(' ') + assert long(parts[1]) == cryptutil.base64ToLong(parts[0]) + finally: + f.close() diff --git a/openid/test/test_dh.py b/openid/test/test_dh.py index 01a6ab52..6c78a0b9 100644 --- a/openid/test/test_dh.py +++ b/openid/test/test_dh.py @@ -1,77 +1,72 @@ +"""Test `openid.dh` module.""" import os.path +import unittest from openid.dh import DiffieHellman, strxor -def test_strxor(): - NUL = '\x00' - - cases = [ - (NUL, NUL, NUL), - ('\x01', NUL, '\x01'), - ('a', 'a', NUL), - ('a', NUL, 'a'), - ('abc', NUL * 3, 'abc'), - ('x' * 10, NUL * 10, 'x' * 10), - ('\x01', '\x02', '\x03'), - ('\xf0', '\x0f', '\xff'), - ('\xff', '\x0f', '\xf0'), - ] - - for aa, bb, expected in cases: - actual = strxor(aa, bb) - assert actual == expected, (aa, bb, expected, actual) - - exc_cases = [ - ('', 'a'), - ('foo', 'ba'), - (NUL * 3, NUL * 4), - (''.join(map(chr, xrange(256))), - ''.join(map(chr, xrange(128)))), - ] - - for aa, bb in exc_cases: +class TestStrXor(unittest.TestCase): + """Test `strxor` function.""" + + def test_strxor(self): + NUL = '\x00' + + cases = [ + (NUL, NUL, NUL), + ('\x01', NUL, '\x01'), + ('a', 'a', NUL), + ('a', NUL, 'a'), + ('abc', NUL * 3, 'abc'), + ('x' * 10, NUL * 10, 'x' * 10), + ('\x01', '\x02', '\x03'), + ('\xf0', '\x0f', '\xff'), + ('\xff', '\x0f', '\xf0'), + ] + + for aa, bb, expected in cases: + actual = strxor(aa, bb) + assert actual == expected, (aa, bb, expected, actual) + + exc_cases = [ + ('', 'a'), + ('foo', 'ba'), + (NUL * 3, NUL * 4), + (''.join(map(chr, xrange(256))), + ''.join(map(chr, xrange(128)))), + ] + + for aa, bb in exc_cases: + try: + unexpected = strxor(aa, bb) + except ValueError: + pass + else: + assert False, 'Expected ValueError, got %r' % (unexpected,) + + +class TestDiffieHellman(unittest.TestCase): + + def _test_dh(self): + dh1 = DiffieHellman.fromDefaults() + dh2 = DiffieHellman.fromDefaults() + secret1 = dh1.getSharedSecret(dh2.public) + secret2 = dh2.getSharedSecret(dh1.public) + assert secret1 == secret2 + return secret1 + + def test_exchange(self): + s1 = self._test_dh() + s2 = self._test_dh() + assert s1 != s2 + + def test_public(self): + f = file(os.path.join(os.path.dirname(__file__), 'dhpriv')) + dh = DiffieHellman.fromDefaults() try: - unexpected = strxor(aa, bb) - except ValueError: - pass - else: - assert False, 'Expected ValueError, got %r' % (unexpected,) - - -def test1(): - dh1 = DiffieHellman.fromDefaults() - dh2 = DiffieHellman.fromDefaults() - secret1 = dh1.getSharedSecret(dh2.public) - secret2 = dh2.getSharedSecret(dh1.public) - assert secret1 == secret2 - return secret1 - - -def test_exchange(): - s1 = test1() - s2 = test1() - assert s1 != s2 - - -def test_public(): - f = file(os.path.join(os.path.dirname(__file__), 'dhpriv')) - dh = DiffieHellman.fromDefaults() - try: - for line in f: - parts = line.strip().split(' ') - dh._setPrivate(long(parts[0])) - - assert dh.public == long(parts[1]) - finally: - f.close() - - -def test(): - test_exchange() - test_public() - test_strxor() - + for line in f: + parts = line.strip().split(' ') + dh._setPrivate(long(parts[0])) -if __name__ == '__main__': - test() + assert dh.public == long(parts[1]) + finally: + f.close() diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index 4cf5a22b..d7a4e6c7 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -196,23 +196,24 @@ def finish(self): self.rfile.close() -def test(): - import socket - host = socket.getfqdn('127.0.0.1') - # When I use port 0 here, it works for the first fetch and the - # next one gets connection refused. Bummer. So instead, pick a - # port that's *probably* not in use. - import os - port = (os.getpid() % 31000) + 1024 +class TestFetchers(unittest.TestCase): + def test(self): + import socket + host = socket.getfqdn('127.0.0.1') + # When I use port 0 here, it works for the first fetch and the + # next one gets connection refused. Bummer. So instead, pick a + # port that's *probably* not in use. + import os + port = (os.getpid() % 31000) + 1024 - server = HTTPServer((host, port), FetcherTestHandler) + server = HTTPServer((host, port), FetcherTestHandler) - import threading - server_thread = threading.Thread(target=server.serve_forever) - server_thread.setDaemon(True) - server_thread.start() + import threading + server_thread = threading.Thread(target=server.serve_forever) + server_thread.setDaemon(True) + server_thread.start() - run_fetcher_tests(server) + run_fetcher_tests(server) class FakeFetcher(object): @@ -356,10 +357,3 @@ class TestSilencedUrllib2Fetcher(TestUrllib2Fetcher): fetcher = fetchers.ExceptionWrappingFetcher(fetchers.Urllib2Fetcher()) invalid_url_error = fetchers.HTTPFetchingError - - -def pyUnitTests(): - case1 = unittest.FunctionTestCase(test) - loadTests = unittest.defaultTestLoader.loadTestsFromTestCase - case2 = loadTests(DefaultFetcherTest) - return unittest.TestSuite([case1, case2, loadTests(TestUrllib2Fetcher), loadTests(TestSilencedUrllib2Fetcher)]) diff --git a/openid/test/test_kvform.py b/openid/test/test_kvform.py index a2ebd7d0..19929279 100644 --- a/openid/test/test_kvform.py +++ b/openid/test/test_kvform.py @@ -6,8 +6,11 @@ class KVBaseTest(unittest.TestCase, CatchLogs): - def checkWarnings(self, num_warnings): - self.failUnlessEqual(num_warnings, len(self.messages), repr(self.messages)) + def checkWarnings(self, num_warnings, msg=None): + full_msg = 'Invalid number of warnings {} != {}'.format(num_warnings, len(self.messages)) + if msg is not None: + full_msg = full_msg + ' ' + msg + self.failUnlessEqual(num_warnings, len(self.messages), full_msg) def setUp(self): CatchLogs.setUp(self) @@ -20,6 +23,9 @@ class KVDictTest(KVBaseTest): def runTest(self): for kv_data, result, expected_warnings in kvdict_cases: + # Clean captrured messages + del self.messages[:] + # Convert KVForm to dict d = kvform.kvToDict(kv_data) @@ -27,7 +33,7 @@ def runTest(self): self.failUnlessEqual(d, result) # Check to make sure we got the expected number of warnings - self.checkWarnings(expected_warnings) + self.checkWarnings(expected_warnings, msg='kvToDict({!r})'.format(kv_data)) # Convert back to KVForm and round-trip back to dict to make # sure that *** dict -> kv -> dict is identity. *** @@ -42,7 +48,7 @@ def cleanSeq(self, seq): """Create a new sequence by stripping whitespace from start and end of each value of each pair""" clean = [] - for k, v in self.seq: + for k, v in seq: if isinstance(k, str): k = k.decode('utf8') if isinstance(v, str): @@ -52,6 +58,9 @@ def cleanSeq(self, seq): def runTest(self): for kv_data, result, expected_warnings in kvseq_cases: + # Clean captrured messages + del self.messages[:] + # seq serializes to expected kvform actual = kvform.seqToKV(kv_data) self.failUnlessEqual(actual, result) @@ -148,17 +157,3 @@ def test_convert(self): result = kvform.seqToKV([(1, 1)]) self.failUnlessEqual(result, '1:1\n') self.checkWarnings(2) - - -def pyUnitTests(): - tests = [KVDictTest(*case) for case in kvdict_cases] - tests.extend([KVSeqTest(*case) for case in kvseq_cases]) - tests.extend([KVExcTest(case) for case in kvexc_cases]) - tests.append(unittest.defaultTestLoader.loadTestsFromTestCase(GeneralTest)) - return unittest.TestSuite(tests) - - -if __name__ == '__main__': - suite = pyUnitTests() - runner = unittest.TextTestRunner() - runner.run(suite) diff --git a/openid/test/test_linkparse.py b/openid/test/test_linkparse.py index f31f9ef2..230bd051 100644 --- a/openid/test/test_linkparse.py +++ b/openid/test/test_linkparse.py @@ -1,4 +1,4 @@ -import codecs +"""Test `openid.consumer.html_parse` module.""" import os.path import unittest @@ -47,70 +47,42 @@ def parseTests(s): desc, markup, links = parseCase(case) tests.append((desc, markup, links, case)) + assert len(tests) == num_tests, (len(tests), num_tests) return num_tests, tests -class _LinkTest(unittest.TestCase): - def __init__(self, desc, case, expected, raw): - unittest.TestCase.__init__(self) - self.desc = desc - self.case = case - self.expected = expected - self.raw = raw +with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'linkparse.txt')) as link_test_data_file: + link_test_data = link_test_data_file.read().decode('utf-8') - def shortDescription(self): - return self.desc + +class LinkTest(unittest.TestCase): + """Test `parseLinkAttrs` function.""" def runTest(self): - actual = parseLinkAttrs(self.case) - i = 0 - for optional, exp_link in self.expected: - if optional: - if i >= len(actual): - continue - - act_link = actual[i] - for k, (o, v) in exp_link.items(): - if o: - act_v = act_link.get(k) - if act_v is None: + num_tests, test_cases = parseTests(link_test_data) + + for desc, case, expected, raw in test_cases: + actual = parseLinkAttrs(case) + i = 0 + for optional, exp_link in expected: + if optional: + if i >= len(actual): continue - else: - act_v = act_link[k] - - if optional and v != act_v: - break - - self.assertEqual(v, act_v) - else: - i += 1 - - assert i == len(actual) - -def pyUnitTests(): - here = os.path.dirname(os.path.abspath(__file__)) - test_data_file_name = os.path.join(here, 'linkparse.txt') - test_data_file = codecs.open(test_data_file_name, 'r', 'utf-8') - test_data = test_data_file.read() - test_data_file.close() + act_link = actual[i] + for k, (o, v) in exp_link.items(): + if o: + act_v = act_link.get(k) + if act_v is None: + continue + else: + act_v = act_link[k] - num_tests, test_cases = parseTests(test_data) - - tests = [_LinkTest(*case) for case in test_cases] - - def test_parseSucceeded(): - assert len(test_cases) == num_tests, (len(test_cases), num_tests) - - check_desc = 'Check that we parsed the correct number of test cases' - check = unittest.FunctionTestCase( - test_parseSucceeded, description=check_desc) - tests.insert(0, check) - - return unittest.TestSuite(tests) + if optional and v != act_v: + break + self.assertEqual(v, act_v) + else: + i += 1 -if __name__ == '__main__': - suite = pyUnitTests() - runner = unittest.TextTestRunner() - runner.run(suite) + assert i == len(actual) diff --git a/openid/test/test_oidutil.py b/openid/test/test_oidutil.py index c7a002fa..7eed8865 100644 --- a/openid/test/test_oidutil.py +++ b/openid/test/test_oidutil.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +"""Test `openid.oidutil` module.""" import random import string import unittest @@ -6,55 +7,128 @@ from openid import oidutil -def test_base64(): - allowed_s = string.ascii_letters + string.digits + '+/=' - allowed_d = {} - for c in allowed_s: - allowed_d[c] = None - isAllowed = allowed_d.has_key - - def checkEncoded(s): - for c in s: - assert isAllowed(c), s - - cases = [ - '', - 'x', - '\x00', - '\x01', - '\x00' * 100, - ''.join(map(chr, range(256))), - ] - - for s in cases: - b64 = oidutil.toBase64(s) - checkEncoded(b64) - s_prime = oidutil.fromBase64(b64) - assert s_prime == s, (s, b64, s_prime) - - # Randomized test - for _ in xrange(50): - n = random.randrange(2048) - s = ''.join(map(chr, map(lambda _: random.randrange(256), range(n)))) - b64 = oidutil.toBase64(s) - checkEncoded(b64) - s_prime = oidutil.fromBase64(b64) - assert s_prime == s, (s, b64, s_prime) +class TestBase64(unittest.TestCase): + """Test `toBase64` and `fromBase64` functions.""" + + def test_base64(self): + allowed_s = string.ascii_letters + string.digits + '+/=' + allowed_d = {} + for c in allowed_s: + allowed_d[c] = None + isAllowed = allowed_d.has_key + + def checkEncoded(s): + for c in s: + assert isAllowed(c), s + + cases = [ + '', + 'x', + '\x00', + '\x01', + '\x00' * 100, + ''.join(map(chr, range(256))), + ] + + for s in cases: + b64 = oidutil.toBase64(s) + checkEncoded(b64) + s_prime = oidutil.fromBase64(b64) + assert s_prime == s, (s, b64, s_prime) + + # Randomized test + for _ in xrange(50): + n = random.randrange(2048) + s = ''.join(map(chr, map(lambda _: random.randrange(256), range(n)))) + b64 = oidutil.toBase64(s) + checkEncoded(b64) + s_prime = oidutil.fromBase64(b64) + assert s_prime == s, (s, b64, s_prime) + + +simple = 'http://www.example.com/' +append_args_cases = [ + ('empty list', + (simple, []), + simple), + + ('empty dict', + (simple, {}), + simple), + + ('one list', + (simple, [('a', 'b')]), + simple + '?a=b'), + + ('one dict', + (simple, {'a': 'b'}), + simple + '?a=b'), + + ('two list (same)', + (simple, [('a', 'b'), ('a', 'c')]), + simple + '?a=b&a=c'), + + ('two list', + (simple, [('a', 'b'), ('b', 'c')]), + simple + '?a=b&b=c'), + + ('two list (order)', + (simple, [('b', 'c'), ('a', 'b')]), + simple + '?b=c&a=b'), + + ('two dict (order)', + (simple, {'b': 'c', 'a': 'b'}), + simple + '?a=b&b=c'), + + ('escape', + (simple, [('=', '=')]), + simple + '?%3D=%3D'), + + ('escape (URL)', + (simple, [('this_url', simple)]), + simple + '?this_url=https%3A%2F%2Fround-lake.dustinice.workers.dev%3A443%2Fhttp%2Fwww.example.com%2F'), + + ('use dots', + (simple, [('openid.stuff', 'bother')]), + simple + '?openid.stuff=bother'), + + ('args exist (empty)', + (simple + '?stuff=bother', []), + simple + '?stuff=bother'), + + ('args exist', + (simple + '?stuff=bother', [('ack', 'ack')]), + simple + '?stuff=bother&ack=ack'), + + ('args exist', + (simple + '?stuff=bother', [('ack', 'ack')]), + simple + '?stuff=bother&ack=ack'), + + ('args exist (dict)', + (simple + '?stuff=bother', {'ack': 'ack'}), + simple + '?stuff=bother&ack=ack'), + + ('args exist (dict 2)', + (simple + '?stuff=bother', {'ack': 'ack', 'zebra': 'lion'}), + simple + '?stuff=bother&ack=ack&zebra=lion'), + + ('three args (dict)', + (simple, {'stuff': 'bother', 'ack': 'ack', 'zebra': 'lion'}), + simple + '?ack=ack&stuff=bother&zebra=lion'), + + ('three args (list)', + (simple, [('stuff', 'bother'), ('ack', 'ack'), ('zebra', 'lion')]), + simple + '?stuff=bother&ack=ack&zebra=lion'), +] class AppendArgsTest(unittest.TestCase): - def __init__(self, desc, args, expected): - unittest.TestCase.__init__(self) - self.desc = desc - self.args = args - self.expected = expected + """Test `appendArgs` function.""" def runTest(self): - result = oidutil.appendArgs(*self.args) - self.assertEqual(self.expected, result, self.args) - - def shortDescription(self): - return self.desc + for name, args, expected in append_args_cases: + result = oidutil.appendArgs(*args) + self.assertEqual(expected, result, '{} {}'.format(name, args)) class TestUnicodeConversion(unittest.TestCase): @@ -82,115 +156,6 @@ def testCopyHash(self): self.failIfEqual(hash(s), hash(t)) -def buildAppendTests(): - simple = 'http://www.example.com/' - cases = [ - ('empty list', - (simple, []), - simple), - - ('empty dict', - (simple, {}), - simple), - - ('one list', - (simple, [('a', 'b')]), - simple + '?a=b'), - - ('one dict', - (simple, {'a': 'b'}), - simple + '?a=b'), - - ('two list (same)', - (simple, [('a', 'b'), ('a', 'c')]), - simple + '?a=b&a=c'), - - ('two list', - (simple, [('a', 'b'), ('b', 'c')]), - simple + '?a=b&b=c'), - - ('two list (order)', - (simple, [('b', 'c'), ('a', 'b')]), - simple + '?b=c&a=b'), - - ('two dict (order)', - (simple, {'b': 'c', 'a': 'b'}), - simple + '?a=b&b=c'), - - ('escape', - (simple, [('=', '=')]), - simple + '?%3D=%3D'), - - ('escape (URL)', - (simple, [('this_url', simple)]), - simple + '?this_url=https%3A%2F%2Fround-lake.dustinice.workers.dev%3A443%2Fhttp%2Fwww.example.com%2F'), - - ('use dots', - (simple, [('openid.stuff', 'bother')]), - simple + '?openid.stuff=bother'), - - ('args exist (empty)', - (simple + '?stuff=bother', []), - simple + '?stuff=bother'), - - ('args exist', - (simple + '?stuff=bother', [('ack', 'ack')]), - simple + '?stuff=bother&ack=ack'), - - ('args exist', - (simple + '?stuff=bother', [('ack', 'ack')]), - simple + '?stuff=bother&ack=ack'), - - ('args exist (dict)', - (simple + '?stuff=bother', {'ack': 'ack'}), - simple + '?stuff=bother&ack=ack'), - - ('args exist (dict 2)', - (simple + '?stuff=bother', {'ack': 'ack', 'zebra': 'lion'}), - simple + '?stuff=bother&ack=ack&zebra=lion'), - - ('three args (dict)', - (simple, {'stuff': 'bother', 'ack': 'ack', 'zebra': 'lion'}), - simple + '?ack=ack&stuff=bother&zebra=lion'), - - ('three args (list)', - (simple, [('stuff', 'bother'), ('ack', 'ack'), ('zebra', 'lion')]), - simple + '?stuff=bother&ack=ack&zebra=lion'), - ] - - tests = [] - - for name, args, expected in cases: - test = AppendArgsTest(name, args, expected) - tests.append(test) - - return unittest.TestSuite(tests) - - -def pyUnitTests(): - some = buildAppendTests() - some.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestSymbol)) - some.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestUnicodeConversion)) - return some - - -def test_appendArgs(): - suite = buildAppendTests() - suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestSymbol)) - runner = unittest.TextTestRunner() - result = runner.run(suite) - assert result.wasSuccessful() - # XXX: there are more functions that could benefit from being better # specified and tested in oidutil.py These include, but are not # limited to appendArgs - - -def test(skipPyUnit=True): - test_base64() - if not skipPyUnit: - test_appendArgs() - - -if __name__ == '__main__': - test(skipPyUnit=False) diff --git a/openid/test/test_openidyadis.py b/openid/test/test_openidyadis.py index 3f730b5e..4a76749f 100644 --- a/openid/test/test_openidyadis.py +++ b/openid/test/test_openidyadis.py @@ -154,10 +154,3 @@ def runTest(self): # Make sure we saw all URIs, and saw each one once self.failUnlessEqual(uris, seen_uris) - - -def pyUnitTests(): - cases = [] - for args in data: - cases.append(OpenIDYadisTest(*args)) - return unittest.TestSuite(cases) diff --git a/openid/test/test_parsehtml.py b/openid/test/test_parsehtml.py index 4ee1b616..bd5a6267 100644 --- a/openid/test/test_parsehtml.py +++ b/openid/test/test_parsehtml.py @@ -1,15 +1,14 @@ import os.path -import sys import unittest from HTMLParser import HTMLParseError from openid.yadis.parsehtml import ParseDone, YadisHTMLParser -class _TestCase(unittest.TestCase): +class TestParseHTML(unittest.TestCase): reserved_values = ['None', 'EOF'] - def runTest(self): + def test(self): for expected, case in getCases(): p = YadisHTMLParser() try: @@ -41,19 +40,6 @@ def parseCases(data): return cases -def pyUnitTests(): - """Make a pyunit TestSuite from a file defining test cases.""" - s = unittest.TestSuite() - for (filename, test_num, expected, case) in getCases(): - s.addTest(_TestCase(filename, str(test_num), expected, case)) - return s - - -def test(): - runner = unittest.TextTestRunner() - return runner.run(pyUnitTests()) - - filenames = ['data/test1-parsehtml.txt'] default_test_files = [] @@ -70,7 +56,3 @@ def getCases(test_files=default_test_files): for expected, case in parseCases(data): cases.append((expected, case)) return cases - - -if __name__ == '__main__': - sys.exit(not test().wasSuccessful()) diff --git a/openid/test/test_storetest.py b/openid/test/test_storetest.py index a3885b5c..6937f041 100644 --- a/openid/test/test_storetest.py +++ b/openid/test/test_storetest.py @@ -1,3 +1,4 @@ +"""Test `openid.store` module.""" import os import random import socket @@ -222,179 +223,169 @@ def checkUseNonce(nonce, expected, server_url, msg=''): nonceModule.SKEW = orig_skew -def test_filestore(): - from openid.store import filestore - import tempfile - import shutil - try: - temp_dir = tempfile.mkdtemp() - except AttributeError: - import os - temp_dir = os.tmpnam() - os.mkdir(temp_dir) +class TestFileOpenIDStore(unittest.TestCase): + """Test `FileOpenIDStore` class.""" - store = filestore.FileOpenIDStore(temp_dir) - try: - testStore(store) - store.cleanup() - except Exception: - raise - else: - shutil.rmtree(temp_dir) + def test_filestore(self): + from openid.store import filestore + import tempfile + import shutil + try: + temp_dir = tempfile.mkdtemp() + except AttributeError: + import os + temp_dir = os.tmpnam() + os.mkdir(temp_dir) + store = filestore.FileOpenIDStore(temp_dir) + try: + testStore(store) + store.cleanup() + except Exception: + raise + else: + shutil.rmtree(temp_dir) -def test_sqlite(): - from openid.store import sqlstore - try: - from pysqlite2 import dbapi2 as sqlite - except ImportError: - pass - else: - conn = sqlite.connect(':memory:') - store = sqlstore.SQLiteStore(conn) - store.createTables() - testStore(store) - - -def test_mysql(): - from openid.store import sqlstore - try: - import MySQLdb - except ImportError: - pass - else: - db_user = 'openid_test' - db_passwd = '' - db_name = getTmpDbName() - - # Change this connect line to use the right user and password + +class TestSQLiteStore(unittest.TestCase): + """Test `SQLiteStore` class.""" + + def test_sqlite(self): + from openid.store import sqlstore try: - conn = MySQLdb.connect(user=db_user, passwd=db_passwd, host=db_host) - except MySQLdb.OperationalError as why: - if why[0] == 2005: - print ('Skipping MySQL store test (cannot connect ' - 'to test server on host %r)' % (db_host,)) - return - else: - raise - - conn.query('CREATE DATABASE %s;' % db_name) + from pysqlite2 import dbapi2 as sqlite + except ImportError: + pass + else: + conn = sqlite.connect(':memory:') + store = sqlstore.SQLiteStore(conn) + store.createTables() + testStore(store) + + +class TestMySQLStore(unittest.TestCase): + """Test `MySQLStore` class.""" + + def test_mysql(self): + from openid.store import sqlstore try: - conn.query('USE %s;' % db_name) + import MySQLdb + except ImportError: + pass + else: + db_user = 'openid_test' + db_passwd = '' + db_name = getTmpDbName() - # OK, we're in the right environment. Create store and - # create the tables. - store = sqlstore.MySQLStore(conn) - store.createTables() + # Change this connect line to use the right user and password + try: + conn = MySQLdb.connect(user=db_user, passwd=db_passwd, host=db_host) + except MySQLdb.OperationalError as why: + if why[0] == 2005: + print ('Skipping MySQL store test (cannot connect ' + 'to test server on host %r)' % (db_host,)) + return + else: + raise - # At last, we get to run the test. - testStore(store) - finally: - # Remove the database. If you want to do post-mortem on a - # failing test, comment out this line. - conn.query('DROP DATABASE %s;' % db_name) + conn.query('CREATE DATABASE %s;' % db_name) + try: + conn.query('USE %s;' % db_name) + # OK, we're in the right environment. Create store and + # create the tables. + store = sqlstore.MySQLStore(conn) + store.createTables() -def test_postgresql(): - """ - Tests the PostgreSQLStore on a locally-hosted PostgreSQL database - cluster, version 7.4 or later. To run this test, you must have: + # At last, we get to run the test. + testStore(store) + finally: + # Remove the database. If you want to do post-mortem on a + # failing test, comment out this line. + conn.query('DROP DATABASE %s;' % db_name) - - The 'psycopg' python module (version 1.1) installed - - PostgreSQL running locally +class TestPostgreSQLStore(unittest.TestCase): + """Test `PostgreSQLStore` class.""" - - An 'openid_test' user account in your database cluster, which - you can create by running 'createuser -Ad openid_test' as the - 'postgres' user + def test_postgresql(self): + """ + Tests the PostgreSQLStore on a locally-hosted PostgreSQL database + cluster, version 7.4 or later. To run this test, you must have: - - Trust auth for the 'openid_test' account, which you can activate - by adding the following line to your pg_hba.conf file: + - The 'psycopg' python module (version 1.1) installed - local all openid_test trust + - PostgreSQL running locally - This test connects to the database cluster three times: + - An 'openid_test' user account in your database cluster, which + you can create by running 'createuser -Ad openid_test' as the + 'postgres' user - - To the 'template1' database, to create the test database + - Trust auth for the 'openid_test' account, which you can activate + by adding the following line to your pg_hba.conf file: - - To the test database, to run the store tests + local all openid_test trust - - To the 'template1' database once more, to drop the test database - """ - from openid.store import sqlstore - try: - import psycopg - except ImportError: - pass - else: - db_name = getTmpDbName() - db_user = 'openid_test' - - # Connect once to create the database; reconnect to access the - # new database. - conn_create = psycopg.connect(database='template1', user=db_user, host=db_host) - conn_create.autocommit() - - # Create the test database. - cursor = conn_create.cursor() - cursor.execute('CREATE DATABASE %s;' % (db_name,)) - conn_create.close() - - # Connect to the test database. - conn_test = psycopg.connect(database=db_name, user=db_user, host=db_host) - - # OK, we're in the right environment. Create the store - # instance and create the tables. - store = sqlstore.PostgreSQLStore(conn_test) - store.createTables() - - # At last, we get to run the test. - testStore(store) - - # Disconnect. - conn_test.close() - - # It takes a little time for the close() call above to take - # effect, so we'll wait for a second before trying to remove - # the database. (Maybe this is because we're using a UNIX - # socket to connect to postgres rather than TCP?) - import time - time.sleep(1) - - # Remove the database now that the test is over. - conn_remove = psycopg.connect(database='template1', user=db_user, host=db_host) - conn_remove.autocommit() - - cursor = conn_remove.cursor() - cursor.execute('DROP DATABASE %s;' % (db_name,)) - conn_remove.close() - - -def test_memstore(): - from openid.store import memstore - testStore(memstore.MemoryStore()) - - -test_functions = [ - test_filestore, - test_sqlite, - test_mysql, - test_postgresql, - test_memstore, -] - - -def pyUnitTests(): - tests = map(unittest.FunctionTestCase, test_functions) - return unittest.TestSuite(tests) - - -if __name__ == '__main__': - import sys - suite = pyUnitTests() - runner = unittest.TextTestRunner() - result = runner.run(suite) - if result.wasSuccessful(): - sys.exit(0) - else: - sys.exit(1) + This test connects to the database cluster three times: + + - To the 'template1' database, to create the test database + + - To the test database, to run the store tests + + - To the 'template1' database once more, to drop the test database + """ + from openid.store import sqlstore + try: + import psycopg + except ImportError: + pass + else: + db_name = getTmpDbName() + db_user = 'openid_test' + + # Connect once to create the database; reconnect to access the + # new database. + conn_create = psycopg.connect(database='template1', user=db_user, host=db_host) + conn_create.autocommit() + + # Create the test database. + cursor = conn_create.cursor() + cursor.execute('CREATE DATABASE %s;' % (db_name,)) + conn_create.close() + + # Connect to the test database. + conn_test = psycopg.connect(database=db_name, user=db_user, host=db_host) + + # OK, we're in the right environment. Create the store + # instance and create the tables. + store = sqlstore.PostgreSQLStore(conn_test) + store.createTables() + + # At last, we get to run the test. + testStore(store) + + # Disconnect. + conn_test.close() + + # It takes a little time for the close() call above to take + # effect, so we'll wait for a second before trying to remove + # the database. (Maybe this is because we're using a UNIX + # socket to connect to postgres rather than TCP?) + import time + time.sleep(1) + + # Remove the database now that the test is over. + conn_remove = psycopg.connect(database='template1', user=db_user, host=db_host) + conn_remove.autocommit() + + cursor = conn_remove.cursor() + cursor.execute('DROP DATABASE %s;' % (db_name,)) + conn_remove.close() + + +class TestMemoryStore(unittest.TestCase): + """Test `MemoryStore` class.""" + + def test_memstore(self): + from openid.store import memstore + testStore(memstore.MemoryStore()) diff --git a/openid/test/test_trustroot.py b/openid/test/test_trustroot.py index c9a0f726..7905e4a8 100644 --- a/openid/test/test_trustroot.py +++ b/openid/test/test_trustroot.py @@ -3,51 +3,43 @@ from openid.server.trustroot import TrustRoot +with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data', 'trustroot.txt')) as test_data_file: + trustroot_test_data = test_data_file.read() -class _ParseTest(unittest.TestCase): - def __init__(self, sanity, desc, case): - unittest.TestCase.__init__(self) - self.desc = desc + ': ' + repr(case) - self.case = case - self.sanity = sanity - def shortDescription(self): - return self.desc +class ParseTest(unittest.TestCase): - def runTest(self): - tr = TrustRoot.parse(self.case) - if self.sanity == 'sane': - assert tr.isSane(), self.case - elif self.sanity == 'insane': - assert not tr.isSane(), self.case - else: - assert tr is None, tr + def test(self): + ph, pdat, mh, mdat = parseTests(trustroot_test_data) + for sanity, desc, case in getTests(['bad', 'insane', 'sane'], ph, pdat): + tr = TrustRoot.parse(case) + if sanity == 'sane': + assert tr.isSane(), case + elif sanity == 'insane': + assert not tr.isSane(), case + else: + assert tr is None, tr -class _MatchTest(unittest.TestCase): - def __init__(self, match, desc, line): - unittest.TestCase.__init__(self) - tr, rt = line.split() - self.desc = desc + ': ' + repr(tr) + ' ' + repr(rt) - self.tr = tr - self.rt = rt - self.match = match - def shortDescription(self): - return self.desc +class MatchTest(unittest.TestCase): - def runTest(self): - tr = TrustRoot.parse(self.tr) - self.failIf(tr is None, self.tr) + def test(self): + ph, pdat, mh, mdat = parseTests(trustroot_test_data) - match = tr.validateURL(self.rt) - if self.match: - assert match - else: - assert not match + for expected_match, desc, line in getTests([1, 0], mh, mdat): + tr, rt = line.split() + tr = TrustRoot.parse(tr) + self.failIf(tr is None, tr) + match = tr.validateURL(rt) + if expected_match: + assert match + else: + assert not match -def getTests(t, grps, head, dat): + +def getTests(grps, head, dat): tests = [] top = head.strip() gdat = map(str.strip, dat.split('-' * 40 + '\n')) @@ -59,7 +51,7 @@ def getTests(t, grps, head, dat): cases = gdat[i + 1].split('\n') assert len(cases) == int(n) for case in cases: - tests.append(t(x, top + ' - ' + desc, case)) + tests.append((x, top + ' - ' + desc, case)) i += 2 return tests @@ -68,25 +60,4 @@ def parseTests(data): parts = map(str.strip, data.split('=' * 40 + '\n')) assert not parts[0] _, ph, pdat, mh, mdat = parts - - tests = [] - tests.extend(getTests(_ParseTest, ['bad', 'insane', 'sane'], ph, pdat)) - tests.extend(getTests(_MatchTest, [1, 0], mh, mdat)) - return tests - - -def pyUnitTests(): - here = os.path.dirname(os.path.abspath(__file__)) - test_data_file_name = os.path.join(here, 'data', 'trustroot.txt') - test_data_file = file(test_data_file_name) - test_data = test_data_file.read() - test_data_file.close() - - tests = parseTests(test_data) - return unittest.TestSuite(tests) - - -if __name__ == '__main__': - suite = pyUnitTests() - runner = unittest.TextTestRunner() - runner.run(suite) + return ph, pdat, mh, mdat diff --git a/openid/test/test_urinorm.py b/openid/test/test_urinorm.py index 49f18fc9..0db74eb0 100644 --- a/openid/test/test_urinorm.py +++ b/openid/test/test_urinorm.py @@ -3,7 +3,6 @@ import openid.urinorm - with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'urinorm.txt')) as test_data_file: test_data = test_data_file.read() @@ -29,27 +28,3 @@ def parse(self, full_case): case = unicode(case, 'utf-8') return (desc, case, expected) - - -def parseTests(test_data): - result = [] - - cases = test_data.split('\n\n') - for case in cases: - case = case.strip() - - if case: - result.append(UrinormTest.parse(case)) - - return result - - -def pyUnitTests(): - here = os.path.dirname(os.path.abspath(__file__)) - test_data_file_name = os.path.join(here, 'urinorm.txt') - test_data_file = file(test_data_file_name) - test_data = test_data_file.read() - test_data_file.close() - - tests = parseTests(test_data) - return unittest.TestSuite(tests) diff --git a/openid/test/test_yadis_discover.py b/openid/test/test_yadis_discover.py index c7ba05c0..e00f0e89 100644 --- a/openid/test/test_yadis_discover.py +++ b/openid/test/test_yadis_discover.py @@ -96,90 +96,51 @@ def test_404(self): self.failUnlessRaises(DiscoveryFailure, discover, uri) -class _TestCase(unittest.TestCase): +class TestDiscover(unittest.TestCase): base_url = 'http://invalid.unittest/' - def __init__(self, input_name, id_name, result_name, success): - self.input_name = input_name - self.id_name = id_name - self.result_name = result_name - self.success = success - # Still not quite sure how to best construct these custom tests. - # Between python2.3 and python2.4, a patch attached to pyunit.sf.net - # bug #469444 got applied which breaks loadTestsFromModule on this - # class if it has test_ or runTest methods. So, kludge to change - # the method name. - unittest.TestCase.__init__(self, methodName='runCustomTest') - def setUp(self): fetchers.setDefaultFetcher(TestFetcher(self.base_url), wrap_exceptions=False) - self.input_url, self.expected = discoverdata.generateResult( - self.base_url, - self.input_name, - self.id_name, - self.result_name, - self.success) - def tearDown(self): fetchers.setDefaultFetcher(None) - def runCustomTest(self): - if self.expected is DiscoveryFailure: - self.failUnlessRaises(DiscoveryFailure, - discover, self.input_url) - else: - result = discover(self.input_url) - self.failUnlessEqual(self.input_url, result.request_uri) - - msg = 'Identity URL mismatch: actual = %r, expected = %r' % ( - result.normalized_uri, self.expected.normalized_uri) - self.failUnlessEqual( - self.expected.normalized_uri, result.normalized_uri, msg) - - msg = 'Content mismatch: actual = %r, expected = %r' % ( - result.response_text, self.expected.response_text) - self.failUnlessEqual( - self.expected.response_text, result.response_text, msg) - - expected_keys = sorted(dir(self.expected)) - actual_keys = sorted(dir(result)) - self.failUnlessEqual(actual_keys, expected_keys) - - for k in dir(self.expected): - if k.startswith('__') and k.endswith('__'): - continue - exp_v = getattr(self.expected, k) - if isinstance(exp_v, types.MethodType): - continue - act_v = getattr(result, k) - assert act_v == exp_v, (k, exp_v, act_v) - - def shortDescription(self): - try: - n = self.input_url - except AttributeError: - # run before setUp, or if setUp did not complete successfully. - n = self.input_name - return "%s (%s)" % ( - n, - self.__class__.__module__) - - -def pyUnitTests(): - s = unittest.TestSuite() - for success, input_name, id_name, result_name in discoverdata.testlist: - test = _TestCase(input_name, id_name, result_name, success) - s.addTest(test) - - return s - - -def test(): - runner = unittest.TextTestRunner() - return runner.run(pyUnitTests()) - - -if __name__ == '__main__': - test() + def test(self): + for success, input_name, id_name, result_name in discoverdata.testlist: + input_url, expected = discoverdata.generateResult( + self.base_url, + input_name, + id_name, + result_name, + success) + + if expected is DiscoveryFailure: + self.failUnlessRaises(DiscoveryFailure, + discover, input_url) + else: + result = discover(input_url) + self.failUnlessEqual(input_url, result.request_uri) + + msg = 'Identity URL mismatch: actual = %r, expected = %r' % ( + result.normalized_uri, expected.normalized_uri) + self.failUnlessEqual( + expected.normalized_uri, result.normalized_uri, msg) + + msg = 'Content mismatch: actual = %r, expected = %r' % ( + result.response_text, expected.response_text) + self.failUnlessEqual( + expected.response_text, result.response_text, msg) + + expected_keys = sorted(dir(expected)) + actual_keys = sorted(dir(result)) + self.failUnlessEqual(actual_keys, expected_keys) + + for k in dir(expected): + if k.startswith('__') and k.endswith('__'): + continue + exp_v = getattr(expected, k) + if isinstance(exp_v, types.MethodType): + continue + act_v = getattr(result, k) + assert act_v == exp_v, (k, exp_v, act_v) From 615df55b186cac51fb4330eea0e946d75adf191f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 12 Dec 2017 16:47:27 +0100 Subject: [PATCH 029/151] Clean djopenid and run its tests --- Makefile | 6 +- examples/djopenid/consumer/urls.py | 16 +++-- examples/djopenid/consumer/views.py | 25 +++---- examples/djopenid/manage.py | 30 +++++--- examples/djopenid/server/tests.py | 15 ++-- examples/djopenid/server/urls.py | 23 ++++--- examples/djopenid/server/views.py | 69 ++++++------------- examples/djopenid/settings.py | 65 +++-------------- examples/djopenid/templates/index.html | 4 +- .../djopenid/templates/server/idPage.html | 4 +- examples/djopenid/templates/server/index.html | 2 +- examples/djopenid/templates/server/trust.html | 2 +- examples/djopenid/urls.py | 15 ++-- examples/djopenid/util.py | 53 +------------- examples/djopenid/views.py | 15 ---- 15 files changed, 113 insertions(+), 231 deletions(-) delete mode 100644 examples/djopenid/views.py diff --git a/Makefile b/Makefile index 4e66971c..2a2b1a4f 100644 --- a/Makefile +++ b/Makefile @@ -1,14 +1,12 @@ .PHONY: test coverage isort check-all check-isort check-flake8 test: - # TODO: Ignore djopenid tests for the time being - python -m unittest discover --start openid/test -t . + PYTHONPATH="examples" DJANGO_SETTINGS_MODULE="djopenid.settings" python -m unittest discover coverage: python -m coverage erase -rm -r htmlcov - # TODO: Ignore djopenid tests for the time being - python -m coverage run --branch --source="." openid/test/__init__.py discover --start openid/test -t . + PYTHONPATH="examples" DJANGO_SETTINGS_MODULE="djopenid.settings" python -m coverage run --branch --source="." openid/test/__init__.py discover python -m coverage html --directory=htmlcov isort: diff --git a/examples/djopenid/consumer/urls.py b/examples/djopenid/consumer/urls.py index 7190093e..9b37b1aa 100644 --- a/examples/djopenid/consumer/urls.py +++ b/examples/djopenid/consumer/urls.py @@ -1,8 +1,10 @@ -from django.conf.urls.defaults import patterns +"""Consumer URLs.""" +from django.conf.urls import url -urlpatterns = patterns( - 'djopenid.consumer.views', - (r'^$', 'startOpenID'), - (r'^finish/$', 'finishOpenID'), - (r'^xrds/$', 'rpXRDS'), -) +from djopenid.consumer.views import finishOpenID, rpXRDS, startOpenID + +urlpatterns = [ + url(r'^$', startOpenID, name='index'), + url(r'^finish/$', finishOpenID, name='return_to'), + url(r'^xrds/$', rpXRDS, name='xrds'), +] diff --git a/examples/djopenid/consumer/views.py b/examples/djopenid/consumer/views.py index bbc0ff87..c92208b6 100644 --- a/examples/djopenid/consumer/views.py +++ b/examples/djopenid/consumer/views.py @@ -1,5 +1,6 @@ from django.http import HttpResponseRedirect -from django.views.generic.simple import direct_to_template +from django.shortcuts import render +from django.urls import reverse from openid.consumer import consumer from openid.consumer.discover import DiscoveryFailure @@ -36,12 +37,11 @@ def getConsumer(request): def renderIndexPage(request, **template_args): - template_args['consumer_url'] = util.getViewURL(request, startOpenID) + template_args['consumer_url'] = request.build_absolute_uri(reverse('consumer:index')) template_args['pape_policies'] = POLICY_PAIRS - response = direct_to_template( - request, 'consumer/index.html', template_args) - response[YADIS_HEADER_NAME] = util.getViewURL(request, rpXRDS) + response = render(request, 'consumer/index.html', template_args) + response[YADIS_HEADER_NAME] = request.build_absolute_uri(reverse('consumer:xrds')) return response @@ -114,8 +114,8 @@ def startOpenID(request): # Compute the trust root and return URL values to build the # redirect information. - trust_root = util.getViewURL(request, startOpenID) - return_to = util.getViewURL(request, finishOpenID) + trust_root = util.request.build_absolute_uri(reverse('consumer:index')) + return_to = util.request.build_absolute_uri(reverse('consumer:return_to')) # Send the browser to the server either by sending a redirect # URL or by generating a POST form. @@ -130,8 +130,7 @@ def startOpenID(request): form_id = 'openid_message' form_html = auth_request.formMarkup(trust_root, return_to, False, {'id': form_id}) - return direct_to_template( - request, 'consumer/request_form.html', {'html': form_html}) + return render(request, 'consumer/request_form.html', {'html': form_html}) return renderIndexPage(request) @@ -159,7 +158,7 @@ def finishOpenID(request): # Get a response object indicating the result of the OpenID # protocol. - return_to = util.getViewURL(request, finishOpenID) + return_to = request.build_absolute_uri(reverse('consumer:return_to')) response = c.complete(request_args, return_to) # Get a Simple Registration response object if response @@ -218,7 +217,5 @@ def rpXRDS(request): """ Return a relying party verification XRDS document """ - return util.renderXRDS( - request, - [RP_RETURN_TO_URL_TYPE], - [util.getViewURL(request, finishOpenID)]) + return_to = request.build_absolute_uri(reverse('consumer:return_to')) + return util.renderXRDS(request, [RP_RETURN_TO_URL_TYPE], [return_to]) diff --git a/examples/djopenid/manage.py b/examples/djopenid/manage.py index 45a1ee63..fb88042f 100644 --- a/examples/djopenid/manage.py +++ b/examples/djopenid/manage.py @@ -1,14 +1,22 @@ #!/usr/bin/env python -from django.core.management import execute_manager - -try: - import settings # Assumed to be in the same directory. -except ImportError: - import sys - sys.stderr.write("Error: Can't find the file 'settings.py' in the directory containing %r. It appears you've " - "customized things.\nYou'll have to run django-admin.py, passing it your settings module.\n(If " - "the file settings.py does indeed exist, it's causing an ImportError somehow.)\n" % __file__) - sys.exit(1) +import os +import sys if __name__ == "__main__": - execute_manager(settings) + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "djopenid.settings") + try: + from django.core.management import execute_from_command_line + except ImportError: + # The above import may fail for some other reason. Ensure that the + # issue is really that Django is missing to avoid masking other + # exceptions on Python 2. + try: + import django # noqa: F401 + except ImportError: + raise ImportError( + "Couldn't import Django. Are you sure it's installed and " + "available on your PYTHONPATH environment variable? Did you " + "forget to activate a virtual environment?" + ) + raise + execute_from_command_line(sys.argv) diff --git a/examples/djopenid/server/tests.py b/examples/djopenid/server/tests.py index 6cae5471..8d0b8de4 100644 --- a/examples/djopenid/server/tests.py +++ b/examples/djopenid/server/tests.py @@ -1,7 +1,9 @@ +from urlparse import urljoin -from django.contrib.sessions.middleware import SessionWrapper +import django from django.http import HttpRequest from django.test.testcases import TestCase +from django.urls import reverse from openid.message import Message from openid.server.server import CheckIDRequest @@ -11,11 +13,14 @@ from .. import util from ..server import views +# Allow django tests to run through discover +django.setup() + def dummyRequest(): request = HttpRequest() - request.session = SessionWrapper("test") - request.META['HTTP_HOST'] = 'example.invalid' + request.session = {} + request.META['HTTP_HOST'] = 'example.cz' request.META['SERVER_PROTOCOL'] = 'HTTP' return request @@ -24,7 +29,7 @@ class TestProcessTrustResult(TestCase): def setUp(self): self.request = dummyRequest() - id_url = util.getViewURL(self.request, views.idPage) + id_url = urljoin('http://example.cz/', reverse('server:local_id')) # Set up the OpenID request we're responding to. op_endpoint = 'http://127.0.0.1:8080/endpoint' @@ -65,7 +70,7 @@ class TestShowDecidePage(TestCase): def test_unreachableRealm(self): self.request = dummyRequest() - id_url = util.getViewURL(self.request, views.idPage) + id_url = urljoin('http://example.cz/', reverse('server:local_id')) # Set up the OpenID request we're responding to. op_endpoint = 'http://127.0.0.1:8080/endpoint' diff --git a/examples/djopenid/server/urls.py b/examples/djopenid/server/urls.py index 6763d856..2eff514f 100644 --- a/examples/djopenid/server/urls.py +++ b/examples/djopenid/server/urls.py @@ -1,11 +1,14 @@ -from django.conf.urls.defaults import patterns +"""Server URLs.""" +from django.conf.urls import url +from django.views.generic import TemplateView -urlpatterns = patterns( - 'djopenid.server.views', - (r'^$', 'server'), - (r'^xrds/$', 'idpXrds'), - (r'^processTrustResult/$', 'processTrustResult'), - (r'^user/$', 'idPage'), - (r'^endpoint/$', 'endpoint'), - (r'^trust/$', 'trustPage'), -) +from djopenid.server.views import endpoint, idPage, idpXrds, processTrustResult, server + +urlpatterns = [ + url(r'^$', server, name='index'), + url(r'^xrds/$', idpXrds, name='xrds'), + url(r'^user/$', idPage, name='local_id'), + url(r'^endpoint/$', endpoint, name='endpoint'), + url(r'^trust/$', TemplateView.as_view(template_name='server/trust.html'), name='confirmation'), + url(r'^processTrustResult/$', processTrustResult, name='process-confirmation'), +] diff --git a/examples/djopenid/server/views.py b/examples/djopenid/server/views.py index bbb9468d..0701ab33 100644 --- a/examples/djopenid/server/views.py +++ b/examples/djopenid/server/views.py @@ -18,7 +18,8 @@ import cgi from django import http -from django.views.generic.simple import direct_to_template +from django.shortcuts import render +from django.urls import reverse from openid.consumer.discover import OPENID_IDP_2_0_TYPE from openid.extensions import pape, sreg @@ -28,7 +29,6 @@ from openid.yadis.discover import DiscoveryFailure from .. import util -from ..util import getViewURL def getOpenIDStore(): @@ -43,7 +43,8 @@ def getServer(request): """ Get a Server object to perform OpenID authentication. """ - return Server(getOpenIDStore(), getViewURL(request, endpoint)) + endpoint_url = request.build_absolute_uri(reverse('server:endpoint')) + return Server(getOpenIDStore(), endpoint_url) def setRequest(request, openid_request): @@ -67,12 +68,10 @@ def server(request): """ Respond to requests for the server's primary web page. """ - return direct_to_template( - request, - 'server/index.html', - {'user_url': getViewURL(request, idPage), - 'server_xrds_url': getViewURL(request, idpXrds), - }) + local_id = request.build_absolute_uri(reverse('server:local_id')) + server_xrds_url = request.build_absolute_uri(reverse('server:xrds')) + context = {'local_id': local_id, 'server_xrds_url': server_xrds_url} + return render(request, 'server/index.html', context) def idpXrds(request): @@ -80,29 +79,16 @@ def idpXrds(request): Respond to requests for the IDP's XRDS document, which is used in IDP-driven identifier selection. """ - return util.renderXRDS( - request, [OPENID_IDP_2_0_TYPE], [getViewURL(request, endpoint)]) + endpoint_url = request.build_absolute_uri(reverse('server:endpoint')) + return util.renderXRDS(request, [OPENID_IDP_2_0_TYPE], [endpoint_url]) def idPage(request): """ Serve the identity page for OpenID URLs. """ - return direct_to_template( - request, - 'server/idPage.html', - {'server_url': getViewURL(request, endpoint)}) - - -def trustPage(request): - """ - Display the trust page template, which allows the user to decide - whether to approve the OpenID verification. - """ - return direct_to_template( - request, - 'server/trust.html', - {'trust_handler_url': getViewURL(request, processTrustResult)}) + endpoint_url = request.build_absolute_uri(reverse('server:endpoint')) + return render(request, 'server/idPage.html', {'endpoint_url': endpoint_url}) def endpoint(request): @@ -119,18 +105,12 @@ def endpoint(request): openid_request = s.decodeRequest(query) except ProtocolError as why: # This means the incoming request was invalid. - return direct_to_template( - request, - 'server/endpoint.html', - {'error': str(why)}) + return render(request, 'server/endpoint.html', {'error': str(why)}) # If we did not get a request, display text indicating that this # is an endpoint. if openid_request is None: - return direct_to_template( - request, - 'server/endpoint.html', - {}) + return render(request, 'server/endpoint.html') # We got a request; if the mode is checkid_*, we will handle it by # getting feedback from the user or by checking the session. @@ -157,7 +137,7 @@ def handleCheckIDRequest(request, openid_request): # what URL should be sent. if not openid_request.idSelect(): - id_url = getViewURL(request, idPage) + id_url = request.build_absolute_uri(reverse('server:local_id')) # Confirm that this server can actually vouch for that # identifier @@ -204,14 +184,10 @@ def showDecidePage(request, openid_request): pape_request = pape.Request.fromOpenIDRequest(openid_request) - return direct_to_template( - request, - 'server/trust.html', - {'trust_root': trust_root, - 'trust_handler_url': getViewURL(request, processTrustResult), - 'trust_root_valid': trust_root_valid, - 'pape_request': pape_request, - }) + context = {'trust_root': trust_root, + 'trust_root_valid': trust_root_valid, + 'pape_request': pape_request} + return render(request, 'server/trust.html', context) def processTrustResult(request): @@ -224,7 +200,7 @@ def processTrustResult(request): openid_request = getRequest(request) # The identifier that this server can vouch for - response_identity = getViewURL(request, idPage) + response_identity = request.build_absolute_uri(reverse('server:local_id')) # If the decision was to allow the verification, respond # accordingly. @@ -274,10 +250,7 @@ def displayResponse(request, openid_response): except EncodingError as why: # If it couldn't be encoded, display an error. text = why.response.encodeToKVForm() - return direct_to_template( - request, - 'server/endpoint.html', - {'error': cgi.escape(text)}) + return render(request, 'server/endpoint.html', {'error': cgi.escape(text)}) # Construct the appropriate django framework response. r = http.HttpResponse(webresponse.body) diff --git a/examples/djopenid/settings.py b/examples/djopenid/settings.py index 1ba3ff44..fc2a2b1e 100644 --- a/examples/djopenid/settings.py +++ b/examples/djopenid/settings.py @@ -1,5 +1,4 @@ -# Django settings for djopenid project. - +"""Example Django settings for djopenid project.""" import os import sys import warnings @@ -13,60 +12,26 @@ del openid DEBUG = True -TEMPLATE_DEBUG = DEBUG - -ADMINS = ( - # ('Your Name', 'your_email@domain.com'), -) - -MANAGERS = ADMINS +ALLOWED_HOSTS = ['*'] DATABASES = { 'default': { 'ENGINE': 'django.db.backends.sqlite3', # Add 'postgresql_psycopg2', 'mysql', 'sqlite3' or 'oracle'. - 'NAME': '/tmp/test.db', # Or path to database file if using sqlite3. - 'USER': '', # Not used with sqlite3. - 'PASSWORD': '', # Not used with sqlite3. - 'HOST': '', # Set to empty string for localhost. Not used with sqlite3. - 'PORT': '', # Set to empty string for default. Not used with sqlite3. + 'NAME': ':memory:', } } -# Local time zone for this installation. All choices can be found here: -# http://www.postgresql.org/docs/current/static/datetime-keywords.html#DATETIME-TIMEZONE-SET-TABLE -TIME_ZONE = 'America/Chicago' - -# Language code for this installation. All choices can be found here: -# http://www.w3.org/TR/REC-html40/struct/dirlang.html#langcodes -# http://blogs.law.harvard.edu/tech/stories/storyReader$15 -LANGUAGE_CODE = 'en-us' - -SITE_ID = 1 - -# Absolute path to the directory that holds media. -# Example: "/home/media/media.lawrence.com/" -MEDIA_ROOT = '' - -# URL that handles the media served from MEDIA_ROOT. -# Example: "http://media.lawrence.com" -MEDIA_URL = '' - -# URL prefix for admin media -- CSS, JavaScript and images. Make sure to use a -# trailing slash. -# Examples: "http://foo.com/media/", "/media/". -ADMIN_MEDIA_PREFIX = '/media/' - -# Make this unique, and don't share it with anybody. SECRET_KEY = 'u^bw6lmsa6fah0$^lz-ct$)y7x7#ag92-z+y45-8!(jk0lkavy' -# List of callables that know how to import templates from various sources. -TEMPLATE_LOADERS = ( - 'django.template.loaders.filesystem.Loader', - 'django.template.loaders.app_directories.Loader', - # 'django.template.loaders.eggs.load_template_source', -) +TEMPLATES = [ + { + 'BACKEND': 'django.template.backends.django.DjangoTemplates', + 'DIRS': [os.path.abspath(os.path.join(os.path.dirname(__file__), 'templates'))], + 'APP_DIRS': True, + } +] -MIDDLEWARE_CLASSES = ( +MIDDLEWARE = ( 'django.middleware.common.CommonMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware', 'django.contrib.auth.middleware.AuthenticationMiddleware', @@ -75,16 +40,8 @@ ROOT_URLCONF = 'djopenid.urls' -TEMPLATE_CONTEXT_PROCESSORS = () - -TEMPLATE_DIRS = ( - os.path.abspath(os.path.join(os.path.dirname(__file__), 'templates')), -) - INSTALLED_APPS = ( - 'django.contrib.contenttypes', 'django.contrib.sessions', - 'djopenid.consumer', 'djopenid.server', ) diff --git a/examples/djopenid/templates/index.html b/examples/djopenid/templates/index.html index 62691ecf..2757bfc4 100644 --- a/examples/djopenid/templates/index.html +++ b/examples/djopenid/templates/index.html @@ -15,8 +15,8 @@

diff --git a/examples/djopenid/templates/server/idPage.html b/examples/djopenid/templates/server/idPage.html index 06eb582f..b63ea8f8 100644 --- a/examples/djopenid/templates/server/idPage.html +++ b/examples/djopenid/templates/server/idPage.html @@ -3,8 +3,8 @@ {% block head %} - - + + {% endblock %} {% block body %} diff --git a/examples/djopenid/templates/server/index.html b/examples/djopenid/templates/server/index.html index 01108d0d..8655ba90 100644 --- a/examples/djopenid/templates/server/index.html +++ b/examples/djopenid/templates/server/index.html @@ -41,7 +41,7 @@ application. The OpenID it serves is

-{{ user_url }}
+{{ local_id }}
     

diff --git a/examples/djopenid/templates/server/trust.html b/examples/djopenid/templates/server/trust.html index 815ab85d..ee098e2c 100644 --- a/examples/djopenid/templates/server/trust.html +++ b/examples/djopenid/templates/server/trust.html @@ -39,7 +39,7 @@
+ action="{% url 'server:process-confirmation' %}"> Verify your identity to the relying party?
diff --git a/examples/djopenid/urls.py b/examples/djopenid/urls.py index 37833177..551fc5e7 100644 --- a/examples/djopenid/urls.py +++ b/examples/djopenid/urls.py @@ -1,8 +1,9 @@ -from django.conf.urls.defaults import include, patterns +"""Djopenid URLs.""" +from django.conf.urls import include, url +from django.views.generic import TemplateView -urlpatterns = patterns( - '', - ('^$', 'djopenid.views.index'), - ('^consumer/', include('djopenid.consumer.urls')), - ('^server/', include('djopenid.server.urls')), -) +urlpatterns = [ + url('^$', TemplateView.as_view(template_name='index.html'), name='index'), + url('^consumer/', include(('djopenid.consumer.urls', 'consumer'))), + url('^server/', include(('djopenid.server.urls', 'server'))), +] diff --git a/examples/djopenid/util.py b/examples/djopenid/util.py index 2847d8e3..aa41727a 100644 --- a/examples/djopenid/util.py +++ b/examples/djopenid/util.py @@ -1,13 +1,10 @@ """ Utility code for the Django example consumer and server. """ -from urlparse import urljoin - from django.conf import settings from django.core.exceptions import ImproperlyConfigured -from django.core.urlresolvers import reverse as reverseURL from django.db import connection -from django.views.generic.simple import direct_to_template +from django.shortcuts import render from openid.store import sqlstore from openid.store.filestore import FileOpenIDStore @@ -77,47 +74,6 @@ def getOpenIDStore(filestore_path, table_prefix): return s -def getViewURL(req, view_name_or_obj, args=None, kwargs=None): - relative_url = reverseURL(view_name_or_obj, args=args, kwargs=kwargs) - full_path = req.META.get('SCRIPT_NAME', '') + relative_url - return urljoin(getBaseURL(req), full_path) - - -def getBaseURL(req): - """ - Given a Django web request object, returns the OpenID 'trust root' - for that request; namely, the absolute URL to the site root which - is serving the Django request. The trust root will include the - proper scheme and authority. It will lack a port if the port is - standard (80, 443). - """ - name = req.META['HTTP_HOST'] - try: - name = name[:name.index(':')] - except Exception: - pass - - try: - port = int(req.META['SERVER_PORT']) - except Exception: - port = 80 - - proto = req.META['SERVER_PROTOCOL'] - - if 'HTTPS' in proto: - proto = 'https' - else: - proto = 'http' - - if port in [80, 443] or not port: - port = '' - else: - port = ':%s' % (port,) - - url = "%s://%s%s/" % (proto, name, port) - return url - - def normalDict(request_data): """ Converts a django request MutliValueDict (e.g., request.GET, @@ -135,8 +91,5 @@ def renderXRDS(request, type_uris, endpoint_urls): URLs in one service block, and return a response with the appropriate content-type. """ - response = direct_to_template( - request, 'xrds.xml', - {'type_uris': type_uris, 'endpoint_urls': endpoint_urls}) - response['Content-Type'] = YADIS_CONTENT_TYPE - return response + context = {'type_uris': type_uris, 'endpoint_urls': endpoint_urls} + return render(request, 'xrds.xml', context, content_type=YADIS_CONTENT_TYPE) diff --git a/examples/djopenid/views.py b/examples/djopenid/views.py deleted file mode 100644 index 5d7a4e2a..00000000 --- a/examples/djopenid/views.py +++ /dev/null @@ -1,15 +0,0 @@ - -from django.views.generic.simple import direct_to_template - -from . import util - - -def index(request): - consumer_url = util.getViewURL( - request, 'djopenid.consumer.views.startOpenID') - server_url = util.getViewURL(request, 'djopenid.server.views.server') - - return direct_to_template( - request, - 'index.html', - {'consumer_url': consumer_url, 'server_url': server_url}) From 7a94df871477bacf986c3de310b59a204fb62249 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Mon, 22 Jan 2018 08:20:31 +0100 Subject: [PATCH 030/151] Drop example tests --- .isort.cfg | 2 +- openid/test/test_examples.py | 202 ----------------------------------- tox.ini | 1 - 3 files changed, 1 insertion(+), 204 deletions(-) delete mode 100644 openid/test/test_examples.py diff --git a/.isort.cfg b/.isort.cfg index 6c8243e1..271c8b6b 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,5 +1,5 @@ [settings] line_length = 120 combine_as_imports = true -known_third_party = mock,twill +known_third_party = mock known_first_party = openid diff --git a/openid/test/test_examples.py b/openid/test/test_examples.py deleted file mode 100644 index 3550a4c0..00000000 --- a/openid/test/test_examples.py +++ /dev/null @@ -1,202 +0,0 @@ -"Test some examples." - -import os.path -import socket -import sys -import time -import unittest -from cStringIO import StringIO - -from mock import Mock - -from openid.consumer.consumer import AuthRequest -from openid.consumer.discover import OPENID_1_1_TYPE, OpenIDServiceEndpoint - - -class FakeTestInfo(object): - """Twill TestInfo placeholder.""" - - def __init__(self, *args, **kwargs): - pass - - -try: - import twill.commands - import twill.parse - import twill.unit -except ImportError: - twill = Mock() - twill.unit.TestInfo = FakeTestInfo - - -def setUpModule(): - if twill.unit.TestInfo == FakeTestInfo: - unittest.skip("Skipping examples, twill is not available.") - - -class TwillTest(twill.unit.TestInfo): - """Variant of twill.unit.TestInfo that runs a function as a test script, - not twill script from a file. - """ - - # twill.unit is pretty small to start with, we're overriding - # run_script and bypassing twill.parse, so it may make sense to - # rewrite twill.unit altogether. - - # Desirable features: - # * better unittest.TestCase integration. - # - handle logs on setup and teardown. - # - treat TwillAssertionError as failed test assertion, make twill - # assertions more consistant with TestCase.failUnless idioms. - # - better error reporting on failed assertions. - # - The amount of functions passed back and forth between TestInfo - # and TestCase is currently pretty silly. - # * access to child process's logs. - # TestInfo.start_server redirects stdout/stderr to StringIO - # objects which are, afaict, inaccessible to the caller of - # test.unit.run_child_process. - # * notice when the child process dies, i.e. if you muck up and - # your runExampleServer function throws an exception. - - def run_script(self): - time.sleep(self.sleep) - # twill.commands.go(self.get_url()) - self.script(self) - - -def splitDir(d, count): - # in python2.4 and above, it's easier to spell this as - # d.rsplit(os.sep, count) - for i in xrange(count): - d = os.path.dirname(d) - return d - - -def runExampleServer(host, port, data_path): - thisfile = os.path.abspath(sys.modules[__name__].__file__) - topDir = splitDir(thisfile, 3) - exampleDir = os.path.join(topDir, 'examples') - serverExample = os.path.join(exampleDir, 'server.py') - serverModule = {} - execfile(serverExample, serverModule) - serverMain = serverModule['main'] - - serverMain(host, port, data_path) - - -class TestServer(unittest.TestCase): - """Acceptance tests for examples/server.py. - - These are more acceptance tests than unit tests as they actually - start the whole server running and test it on its external HTTP - interface. - """ - - def setUp(self): - self.twillOutput = StringIO() - self.twillErr = StringIO() - twill.set_output(self.twillOutput) - twill.set_errout(self.twillErr) - # FIXME: make sure we pick an available port. - self.server_port = 8080 - - # We need something to feed the server as a realm, but it needn't - # be reachable. (Until we test realm verification.) - self.realm = 'http://127.0.0.1/%s' % (self.id(),) - self.return_to = self.realm + '/return_to' - - twill.commands.reset_browser() - - def runExampleServer(self): - """Zero-arg run-the-server function to be passed to TestInfo.""" - # FIXME - make sure sstore starts clean. - runExampleServer('127.0.0.1', self.server_port, 'sstore') - - def v1endpoint(self, port): - """Return an OpenID 1.1 OpenIDServiceEndpoint for the server.""" - base = "http://%s:%s" % (socket.getfqdn('127.0.0.1'), port) - ep = OpenIDServiceEndpoint() - ep.claimed_id = base + "/id/bob" - ep.server_url = base + "/openidserver" - ep.type_uris = [OPENID_1_1_TYPE] - return ep - - # TODO: test discovery - - def test_checkidv1(self): - """OpenID 1.1 checkid_setup request.""" - ti = TwillTest(self.twill_checkidv1, self.runExampleServer, - self.server_port, sleep=0.2) - twill.unit.run_test(ti) - - if self.twillErr.getvalue(): - self.fail(self.twillErr.getvalue()) - - def test_allowed(self): - """OpenID 1.1 checkid_setup request.""" - ti = TwillTest(self.twill_allowed, self.runExampleServer, - self.server_port, sleep=0.2) - twill.unit.run_test(ti) - - if self.twillErr.getvalue(): - self.fail(self.twillErr.getvalue()) - - def twill_checkidv1(self, twillInfo): - endpoint = self.v1endpoint(self.server_port) - authreq = AuthRequest(endpoint, assoc=None) - url = authreq.redirectURL(self.realm, self.return_to) - - c = twill.commands - - try: - c.go(url) - c.get_browser()._browser.set_handle_redirect(False) - c.submit("yes") - c.code(302) - headers = c.get_browser()._browser.response().info() - finalURL = headers['Location'] - self.failUnless('openid.mode=id_res' in finalURL, finalURL) - self.failUnless('openid.identity=' in finalURL, finalURL) - except twill.commands.TwillAssertionError as e: - msg = '%s\nFinal page:\n%s' % ( - str(e), c.get_browser().get_html()) - self.fail(msg) - - def twill_allowed(self, twillInfo): - endpoint = self.v1endpoint(self.server_port) - authreq = AuthRequest(endpoint, assoc=None) - url = authreq.redirectURL(self.realm, self.return_to) - - c = twill.commands - - try: - c.go(url) - c.code(200) - c.get_browser()._browser.set_handle_redirect(False) - c.formvalue(1, 'remember', 'true') - c.find('name="login_as" value="bob"') - c.submit("yes") - c.code(302) - # Since we set remember=yes, the second time we shouldn't - # see that page. - c.go(url) - c.code(302) - headers = c.get_browser()._browser.response().info() - finalURL = headers['Location'] - self.failUnless(finalURL.startswith(self.return_to)) - except twill.commands.TwillAssertionError: - from traceback import format_exc - msg = '%s\nTwill output:%s\nTwill errors:%s\nFinal page:\n%s' % ( - format_exc(), - self.twillOutput.getvalue(), - self.twillErr.getvalue(), - c.get_browser().get_html()) - self.fail(msg) - - def tearDown(self): - twill.set_output(None) - twill.set_errout(None) - - -if __name__ == '__main__': - unittest.main() diff --git a/tox.ini b/tox.ini index 31d01e22..d34efa52 100644 --- a/tox.ini +++ b/tox.ini @@ -11,5 +11,4 @@ commands = ./run_tests.sh deps = Django nose - twill pycrypto From f0498277ebdcc1d8a2f49f4b7357af864bfb3190 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Mon, 22 Jan 2018 09:25:13 +0100 Subject: [PATCH 031/151] Update static checks --- Makefile | 8 +++++--- examples/djopenid/consumer/views.py | 1 - examples/djopenid/server/views.py | 1 - examples/djopenid/util.py | 1 - 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index 2a2b1a4f..7053456d 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,7 @@ .PHONY: test coverage isort check-all check-isort check-flake8 +SOURCES = openid setup.py admin contrib + test: PYTHONPATH="examples" DJANGO_SETTINGS_MODULE="djopenid.settings" python -m unittest discover @@ -10,12 +12,12 @@ coverage: python -m coverage html --directory=htmlcov isort: - isort --recursive . + isort --recursive ${SOURCES} check-all: check-isort check-flake8 check-isort: - isort --check-only --diff --recursive . + isort --check-only --diff --recursive ${SOURCES} check-flake8: - flake8 --format=pylint . + flake8 --format=pylint ${SOURCES} diff --git a/examples/djopenid/consumer/views.py b/examples/djopenid/consumer/views.py index c92208b6..74d26fd5 100644 --- a/examples/djopenid/consumer/views.py +++ b/examples/djopenid/consumer/views.py @@ -1,7 +1,6 @@ from django.http import HttpResponseRedirect from django.shortcuts import render from django.urls import reverse - from openid.consumer import consumer from openid.consumer.discover import DiscoveryFailure from openid.extensions import ax, pape, sreg diff --git a/examples/djopenid/server/views.py b/examples/djopenid/server/views.py index 0701ab33..2db2a415 100644 --- a/examples/djopenid/server/views.py +++ b/examples/djopenid/server/views.py @@ -20,7 +20,6 @@ from django import http from django.shortcuts import render from django.urls import reverse - from openid.consumer.discover import OPENID_IDP_2_0_TYPE from openid.extensions import pape, sreg from openid.fetchers import HTTPFetchingError diff --git a/examples/djopenid/util.py b/examples/djopenid/util.py index aa41727a..f98f6268 100644 --- a/examples/djopenid/util.py +++ b/examples/djopenid/util.py @@ -5,7 +5,6 @@ from django.core.exceptions import ImproperlyConfigured from django.db import connection from django.shortcuts import render - from openid.store import sqlstore from openid.store.filestore import FileOpenIDStore from openid.yadis.constants import YADIS_CONTENT_TYPE From f8f080fa87043babd82eebd4aa13b166ed2703b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Mon, 22 Jan 2018 14:46:32 +0100 Subject: [PATCH 032/151] Update requirements, tox and travis --- .gitignore | 2 ++ .travis.yml | 20 ++++++++++++------ Makefile | 15 +++++++++++--- openid/test/test_fetchers.py | 10 +-------- setup.py | 15 ++++++++++---- tox.ini | 40 ++++++++++++++++++++++++++---------- 6 files changed, 69 insertions(+), 33 deletions(-) diff --git a/.gitignore b/.gitignore index faa1bf6e..58aac550 100644 --- a/.gitignore +++ b/.gitignore @@ -4,5 +4,7 @@ .tox # Created in tests /.coverage +/.eggs /htmlcov +/python_openid.egg-info /sstore diff --git a/.travis.yml b/.travis.yml index 7d67a575..fe0b9baa 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,11 +1,19 @@ language: python +sudo: false + python: - - 2.7 + - "2.7" + - "pypy" + +addons: + apt: + packages: + # Dependencies for pycurl compilation + - libcurl4-openssl-dev + - libssl-dev -before_install: pip install 'Django<=1.11.99' pycrypto lxml isort flake8 -install: python setup.py install +install: + - pip install tox-travis script: - - make check-isort - - make check-flake8 - - make test + - tox diff --git a/Makefile b/Makefile index 7053456d..c1d31b0f 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,18 @@ -.PHONY: test coverage isort check-all check-isort check-flake8 +.PHONY: test test-openid test-djopenid coverage isort check-all check-isort check-flake8 SOURCES = openid setup.py admin contrib -test: - PYTHONPATH="examples" DJANGO_SETTINGS_MODULE="djopenid.settings" python -m unittest discover +# Run tests by default +all: test + +test-openid: + python -m unittest discover --start=openid + +# Run tests for djopenid example +test-djopenid: + DJANGO_SETTINGS_MODULE="djopenid.settings" python -m unittest discover --start=examples + +test: test-openid test-djopenid coverage: python -m coverage erase diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index d7a4e6c7..069e0fa9 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -198,15 +198,7 @@ def finish(self): class TestFetchers(unittest.TestCase): def test(self): - import socket - host = socket.getfqdn('127.0.0.1') - # When I use port 0 here, it works for the first fetch and the - # next one gets connection refused. Bummer. So instead, pick a - # port that's *probably* not in use. - import os - port = (os.getpid() % 31000) + 1024 - - server = HTTPServer((host, port), FetcherTestHandler) + server = HTTPServer(("", 0), FetcherTestHandler) import threading server_thread = threading.Thread(target=server.serve_forever) diff --git a/setup.py b/setup.py index 67ebdf83..a7e3bce7 100644 --- a/setup.py +++ b/setup.py @@ -1,15 +1,21 @@ import os import sys -try: - from setuptools import setup -except ImportError: - from distutils.core import setup +from setuptools import setup if 'sdist' in sys.argv: os.system('./admin/makedoc') version = '[library version:2.2.5]'[17:-1] +EXTRAS_REQUIRE = { + 'quality': ('flake8', 'isort'), + 'tests': ('mock', ), + # Optional dependencies for fetchers + 'httplib2': ('httplib2', ), + 'pycurl': ('pycurl', ), + # Dependencies for Django example + 'djopenid': ('django<1.11.99', ), +} setup( name='python-openid', @@ -29,6 +35,7 @@ 'openid.extensions', 'openid.extensions.draft', ], + extras_require=EXTRAS_REQUIRE, # license specified by classifier. # license=getLicense(), author='JanRain', diff --git a/tox.ini b/tox.ini index d34efa52..5d114bb3 100644 --- a/tox.ini +++ b/tox.ini @@ -1,14 +1,32 @@ -# Tox (http://tox.testrun.org/) is a tool for running tests -# in multiple virtualenvs. This configuration file will run the -# test suite on all supported python versions. To use it, "pip install tox" -# and then run "tox" from this directory. - [tox] -envlist = py25, py26, py27, pypy +envlist = + quality + py27-{openid,djopenid,httplib2,pycurl} + pypy-{openid,djopenid,httplib2,pycurl} + +# tox-travis specials +[travis] +python = + 2.7: py27, quality +# Generic specification for all unspecific environments [testenv] -commands = ./run_tests.sh -deps = - Django - nose - pycrypto +whitelist_externals = make +extras = + tests + djopenid: djopenid + httplib2: httplib2 + pycurl: pycurl +commands = + pip install --editable . + pip list + make test-openid + djopenid: make test-djopenid + +[testenv:quality] +whitelist_externals = make +basepython = python2.7 +commands = + pip install --editable .[quality] + pip list + make check-all From 68fd827919cbfd4f0d2a1518a7454931258e592d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 26 Jan 2018 09:07:44 +0100 Subject: [PATCH 033/151] Replace failUnlessEqual --- examples/djopenid/server/tests.py | 12 +- openid/test/support.py | 2 +- openid/test/test_accept.py | 2 +- openid/test/test_association.py | 30 +- openid/test/test_association_response.py | 22 +- openid/test/test_auth_request.py | 22 +- openid/test/test_ax.py | 89 +++--- openid/test/test_consumer.py | 130 ++++----- openid/test/test_discover.py | 59 ++-- openid/test/test_etxrd.py | 8 +- openid/test/test_extension.py | 14 +- openid/test/test_htmldiscover.py | 3 +- openid/test/test_kvform.py | 12 +- openid/test/test_message.py | 262 ++++++++---------- openid/test/test_negotiation.py | 2 +- openid/test/test_nonce.py | 10 +- openid/test/test_openidyadis.py | 10 +- openid/test/test_pape_draft2.py | 111 ++++---- openid/test/test_pape_draft5.py | 189 +++++-------- openid/test/test_parsehtml.py | 2 +- openid/test/test_rpverify.py | 8 +- openid/test/test_server.py | 331 ++++++++++------------- openid/test/test_sreg.py | 148 +++++----- openid/test/test_symbol.py | 4 +- openid/test/test_trustroot.py | 2 +- openid/test/test_verifydisco.py | 22 +- openid/test/test_xri.py | 34 +-- openid/test/test_xrires.py | 18 +- openid/test/test_yadis_discover.py | 10 +- 29 files changed, 686 insertions(+), 882 deletions(-) diff --git a/examples/djopenid/server/tests.py b/examples/djopenid/server/tests.py index 8d0b8de4..b6bb5850 100644 --- a/examples/djopenid/server/tests.py +++ b/examples/djopenid/server/tests.py @@ -6,7 +6,7 @@ from django.urls import reverse from openid.message import Message -from openid.server.server import CheckIDRequest +from openid.server.server import CheckIDRequest, HTTP_REDIRECT from openid.yadis.constants import YADIS_CONTENT_TYPE from openid.yadis.services import applyFilter @@ -48,7 +48,7 @@ def test_allow(self): response = views.processTrustResult(self.request) - self.failUnlessEqual(response.status_code, 302) + self.assertEqual(response.status_code, HTTP_REDIRECT) finalURL = response['location'] self.failUnless('openid.mode=id_res' in finalURL, finalURL) self.failUnless('openid.identity=' in finalURL, finalURL) @@ -59,7 +59,7 @@ def test_cancel(self): response = views.processTrustResult(self.request) - self.failUnlessEqual(response.status_code, 302) + self.assertEqual(response.status_code, HTTP_REDIRECT) finalURL = response['location'] self.failUnless('openid.mode=cancel' in finalURL, finalURL) self.failIf('openid.identity=' in finalURL, finalURL) @@ -102,6 +102,6 @@ def test_genericRender(self): requested_url = 'http://requested.invalid/' (endpoint,) = applyFilter(requested_url, response.content) - self.failUnlessEqual(YADIS_CONTENT_TYPE, response['Content-Type']) - self.failUnlessEqual(type_uris, endpoint.type_uris) - self.failUnlessEqual(endpoint_url, endpoint.uri) + self.assertEqual(response['Content-Type'], YADIS_CONTENT_TYPE) + self.assertEqual(endpoint.type_uris, type_uris) + self.assertEqual(endpoint.uri, endpoint_url) diff --git a/openid/test/support.py b/openid/test/support.py index e864c899..04749042 100644 --- a/openid/test/support.py +++ b/openid/test/support.py @@ -24,7 +24,7 @@ def failUnlessOpenIDValueEquals(self, msg, key, expected, ns=None): actual = msg.getArg(ns, key) error_format = 'Wrong value for openid.%s: expected=%s, actual=%s' error_message = error_format % (key, expected, actual) - self.failUnlessEqual(expected, actual, error_message) + self.assertEqual(actual, expected, error_message) def failIfOpenIDKeyExists(self, msg, key, ns=None): if ns is None: diff --git a/openid/test/test_accept.py b/openid/test/test_accept.py index 55e7eded..3f8d9fff 100644 --- a/openid/test/test_accept.py +++ b/openid/test/test_accept.py @@ -112,4 +112,4 @@ def runTest(self): accepted = accept.parseAcceptHeader(accept_header) actual = accept.matchTypes(accepted, available) - self.failUnlessEqual(expected, actual) + self.assertEqual(actual, expected) diff --git a/openid/test/test_association.py b/openid/test/test_association.py index e0598aff..12fd20cd 100644 --- a/openid/test/test_association.py +++ b/openid/test/test_association.py @@ -16,11 +16,11 @@ def test_roundTrip(self): 'handle', 'secret', issued, lifetime, 'HMAC-SHA1') s = assoc.serialize() assoc2 = association.Association.deserialize(s) - self.failUnlessEqual(assoc.handle, assoc2.handle) - self.failUnlessEqual(assoc.issued, assoc2.issued) - self.failUnlessEqual(assoc.secret, assoc2.secret) - self.failUnlessEqual(assoc.lifetime, assoc2.lifetime) - self.failUnlessEqual(assoc.assoc_type, assoc2.assoc_type) + self.assertEqual(assoc.handle, assoc2.handle) + self.assertEqual(assoc.issued, assoc2.issued) + self.assertEqual(assoc.secret, assoc2.secret) + self.assertEqual(assoc.lifetime, assoc2.lifetime) + self.assertEqual(assoc.assoc_type, assoc2.assoc_type) def createNonstandardConsumerDH(): @@ -50,7 +50,7 @@ def test(self): ssess = ssess_fact.fromMessage(msg) check_secret = csess.extractSecret( Message.fromOpenIDArgs(ssess.answer(secret))) - self.failUnlessEqual(secret, check_secret) + self.assertEqual(secret, check_secret) class TestMakePairs(unittest.TestCase): @@ -76,7 +76,7 @@ def testMakePairs(self): ('identifier', '=example'), ('mode', 'id_res'), ] - self.failUnlessEqual(pairs, expected) + self.assertEqual(pairs, expected) class TestMac(unittest.TestCase): @@ -90,7 +90,7 @@ def test_sha1(self): expected = ('\xe0\x1bv\x04\xf1G\xc0\xbb\x7f\x9a\x8b' '\xe9\xbc\xee}\\\xe5\xbb7*') sig = assoc.sign(self.pairs) - self.failUnlessEqual(sig, expected) + self.assertEqual(sig, expected) def test_sha256(self): assoc = association.Association.fromExpiresIn( @@ -98,7 +98,7 @@ def test_sha256(self): expected = ('\xfd\xaa\xfe;\xac\xfc*\x988\xad\x05d6-\xeaVy' '\xd5\xa5Z.<\xa9\xed\x18\x82\\$\x95x\x1c&') sig = assoc.sign(self.pairs) - self.failUnlessEqual(sig, expected) + self.assertEqual(sig, expected) class TestMessageSigning(unittest.TestCase): @@ -116,20 +116,16 @@ def test_signSHA1(self): 3600, '{sha1}', 'very_secret', "HMAC-SHA1") signed = assoc.signMessage(self.message) self.failUnless(signed.getArg(OPENID_NS, "sig")) - self.failUnlessEqual(signed.getArg(OPENID_NS, "signed"), - "assoc_handle,identifier,mode,ns,signed") - self.failUnlessEqual(signed.getArg(BARE_NS, "xey"), "value", - signed) + self.assertEqual(signed.getArg(OPENID_NS, "signed"), "assoc_handle,identifier,mode,ns,signed") + self.assertEqual(signed.getArg(BARE_NS, "xey"), "value") def test_signSHA256(self): assoc = association.Association.fromExpiresIn( 3600, '{sha1}', 'very_secret', "HMAC-SHA256") signed = assoc.signMessage(self.message) self.failUnless(signed.getArg(OPENID_NS, "sig")) - self.failUnlessEqual(signed.getArg(OPENID_NS, "signed"), - "assoc_handle,identifier,mode,ns,signed") - self.failUnlessEqual(signed.getArg(BARE_NS, "xey"), "value", - signed) + self.assertEqual(signed.getArg(OPENID_NS, "signed"), "assoc_handle,identifier,mode,ns,signed") + self.assertEqual(signed.getArg(BARE_NS, "xey"), "value") class TestCheckMessageSignature(unittest.TestCase): diff --git a/openid/test/test_association_response.py b/openid/test/test_association_response.py index 79be68c7..1b3e13c8 100644 --- a/openid/test/test_association_response.py +++ b/openid/test/test_association_response.py @@ -201,8 +201,7 @@ def _doTest(self, expected_session_type, session_type_value): 'to yield session type %r, but yielded %r' % (session_type_value, expected_session_type, actual_session_type)) - self.failUnlessEqual( - expected_session_type, actual_session_type, error_message) + self.assertEqual(expected_session_type, actual_session_type, error_message) test_none = mkTest( session_type_value=None, @@ -281,10 +280,10 @@ def test_worksWithGoodFields(self): assoc = self.consumer._extractAssociation( self.assoc_response, self.assoc_session) self.failUnless(self.assoc_session.extract_secret_called) - self.failUnlessEqual(self.assoc_session.secret, assoc.secret) - self.failUnlessEqual(1000, assoc.lifetime) - self.failUnlessEqual(self.assoc_handle, assoc.handle) - self.failUnlessEqual(self.assoc_type, assoc.assoc_type) + self.assertEqual(assoc.secret, self.assoc_session.secret) + self.assertEqual(assoc.lifetime, 1000) + self.assertEqual(assoc.handle, self.assoc_handle) + self.assertEqual(assoc.assoc_type, self.assoc_type) def test_badAssocType(self): # Make sure that the assoc type in the response is not valid @@ -311,8 +310,7 @@ def _setUpDH(self): self.endpoint, 'HMAC-SHA1', 'DH-SHA1') # XXX: this is testing _createAssociateRequest - self.failUnlessEqual(self.endpoint.compatibilityMode(), - message.isOpenID1()) + self.assertEqual(self.endpoint.compatibilityMode(), message.isOpenID1()) server_sess = DiffieHellmanSHA1ServerSession.fromMessage(message) server_resp = server_sess.answer(self.secret) @@ -326,10 +324,10 @@ def test_success(self): sess, server_resp = self._setUpDH() ret = self.consumer._extractAssociation(server_resp, sess) self.failIf(ret is None) - self.failUnlessEqual(ret.assoc_type, 'HMAC-SHA1') - self.failUnlessEqual(ret.secret, self.secret) - self.failUnlessEqual(ret.handle, 'handle') - self.failUnlessEqual(ret.lifetime, 1000) + self.assertEqual(ret.assoc_type, 'HMAC-SHA1') + self.assertEqual(ret.secret, self.secret) + self.assertEqual(ret.handle, 'handle') + self.assertEqual(ret.lifetime, 1000) def test_openid2success(self): # Use openid 2 type in endpoint so _setUpDH checks diff --git a/openid/test/test_auth_request.py b/openid/test/test_auth_request.py index c92ccb31..6f173812 100644 --- a/openid/test/test_auth_request.py +++ b/openid/test/test_auth_request.py @@ -49,11 +49,8 @@ def failUnlessAnonymous(self, msg): self.failIfOpenIDKeyExists(msg, key) def failUnlessHasRequiredFields(self, msg): - self.failUnlessEqual(self.preferred_namespace, - self.authreq.message.getOpenIDNamespace()) - - self.failUnlessEqual(self.preferred_namespace, - msg.getOpenIDNamespace()) + self.assertEqual(self.authreq.message.getOpenIDNamespace(), self.preferred_namespace) + self.assertEqual(msg.getOpenIDNamespace(), self.preferred_namespace) self.failUnlessOpenIDValueEquals(msg, 'mode', self.expected_mode) @@ -82,10 +79,8 @@ def test_checkWithAssocHandle(self): def test_addExtensionArg(self): self.authreq.addExtensionArg('bag:', 'color', 'brown') self.authreq.addExtensionArg('bag:', 'material', 'paper') - self.failUnless('bag:' in self.authreq.message.namespaces) - self.failUnlessEqual(self.authreq.message.getArgs('bag:'), - {'color': 'brown', - 'material': 'paper'}) + self.assertIn('bag:', self.authreq.message.namespaces) + self.assertEqual(self.authreq.message.getArgs('bag:'), {'color': 'brown', 'material': 'paper'}) msg = self.authreq.getMessage(self.realm, self.return_to, self.immediate) @@ -93,8 +88,8 @@ def test_addExtensionArg(self): # namespaces. Really it doesn't care that it has alias "0", # but that is tested anyway post_args = msg.toPostArgs() - self.failUnlessEqual('brown', post_args['openid.ext0.color']) - self.failUnlessEqual('paper', post_args['openid.ext0.material']) + self.assertEqual(post_args['openid.ext0.color'], 'brown') + self.assertEqual(post_args['openid.ext0.material'], 'paper') def test_standard(self): msg = self.authreq.getMessage(self.realm, self.return_to, @@ -117,7 +112,7 @@ def failUnlessIdentifiersPresent(self, msg): identity_present = msg.hasKey(message.OPENID_NS, 'identity') claimed_present = msg.hasKey(message.OPENID_NS, 'claimed_id') - self.failUnlessEqual(claimed_present, identity_present) + self.assertEqual(claimed_present, identity_present) def failUnlessHasIdentifiers(self, msg, op_specific_id, claimed_id): self.failUnlessOpenIDValueEquals(msg, 'identity', op_specific_id) @@ -191,8 +186,7 @@ def test_identifierSelect(self): msg = self.authreq.getMessage(self.realm, self.return_to, self.immediate) self.failUnlessHasRequiredFields(msg) - self.failUnlessEqual(message.IDENTIFIER_SELECT, - msg.getArg(message.OPENID1_NS, 'identity')) + self.assertEqual(msg.getArg(message.OPENID1_NS, 'identity'), message.IDENTIFIER_SELECT) class TestAuthRequestOpenID1Immediate(TestAuthRequestOpenID1): diff --git a/openid/test/test_ax.py b/openid/test/test_ax.py index 80be5c8c..4c107e3f 100644 --- a/openid/test/test_ax.py +++ b/openid/test/test_ax.py @@ -25,8 +25,8 @@ def setUp(self): def test_checkMode(self): check = self.bax._checkMode - self.failUnlessRaises(ax.NotAXMessage, check, {}) - self.failUnlessRaises(ax.AXError, check, {'mode': 'fetch_request'}) + self.assertRaises(ax.NotAXMessage, check, {}) + self.assertRaises(ax.AXError, check, {'mode': 'fetch_request'}) # does not raise an exception when the mode is right check({'mode': self.bax.mode}) @@ -39,14 +39,14 @@ def test_checkMode_newArgs(self): class AttrInfoTest(unittest.TestCase): def test_construct(self): - self.failUnlessRaises(TypeError, ax.AttrInfo) + self.assertRaises(TypeError, ax.AttrInfo) type_uri = 'a uri' ainfo = ax.AttrInfo(type_uri) - self.failUnlessEqual(type_uri, ainfo.type_uri) - self.failUnlessEqual(1, ainfo.count) + self.assertEqual(ainfo.type_uri, type_uri) + self.assertEqual(ainfo.count, 1) self.failIf(ainfo.required) - self.failUnless(ainfo.alias is None) + self.assertIsNone(ainfo.alias) class ToTypeURIsTest(unittest.TestCase): @@ -56,7 +56,7 @@ def setUp(self): def test_empty(self): for empty in [None, '']: uris = ax.toTypeURIs(self.aliases, empty) - self.failUnlessEqual([], uris) + self.assertEqual(uris, []) def test_undefined(self): self.failUnlessRaises( @@ -68,7 +68,7 @@ def test_one(self): alias = 'openid_hackers' self.aliases.addAlias(uri, alias) uris = ax.toTypeURIs(self.aliases, alias) - self.failUnlessEqual([uri], uris) + self.assertEqual(uris, [uri]) def test_two(self): uri1 = 'http://janrain.com/' @@ -80,7 +80,7 @@ def test_two(self): self.aliases.addAlias(uri2, alias2) uris = ax.toTypeURIs(self.aliases, ','.join([alias1, alias2])) - self.failUnlessEqual([uri1, uri2], uris) + self.assertEqual(uris, [uri1, uri2]) class ParseAXValuesTest(unittest.TestCase): @@ -94,7 +94,7 @@ def failUnlessAXValues(self, ax_args, expected_args): """Fail unless parseExtensionArgs(ax_args) == expected_args.""" msg = ax.AXKeyValueMessage() msg.parseExtensionArgs(ax_args) - self.failUnlessEqual(expected_args, msg.data) + self.assertEqual(msg.data, expected_args) def test_emptyIsValid(self): self.failUnlessAXValues({}, {}) @@ -201,15 +201,15 @@ def setUp(self): self.alias_a = 'a' def test_mode(self): - self.failUnlessEqual(self.msg.mode, 'fetch_request') + self.assertEqual(self.msg.mode, 'fetch_request') def test_construct(self): - self.failUnlessEqual({}, self.msg.requested_attributes) - self.failUnlessEqual(None, self.msg.update_url) + self.assertEqual(self.msg.requested_attributes, {}) + self.assertIsNone(self.msg.update_url) msg = ax.FetchRequest('hailstorm') - self.failUnlessEqual({}, msg.requested_attributes) - self.failUnlessEqual('hailstorm', msg.update_url) + self.assertEqual(msg.requested_attributes, {}) + self.assertEqual(msg.update_url, 'hailstorm') def test_add(self): uri = 'mud://puddle' @@ -234,7 +234,7 @@ def test_getExtensionArgs_empty(self): expected_args = { 'mode': 'fetch_request', } - self.failUnlessEqual(expected_args, self.msg.getExtensionArgs()) + self.assertEqual(self.msg.getExtensionArgs(), expected_args) def test_getExtensionArgs_noAlias(self): attr = ax.AttrInfo( @@ -284,14 +284,14 @@ def failUnlessExtensionArgs(self, expected_args): """ expected_args = dict(expected_args) expected_args['mode'] = self.msg.mode - self.failUnlessEqual(expected_args, self.msg.getExtensionArgs()) + self.assertEqual(self.msg.getExtensionArgs(), expected_args) def test_isIterable(self): - self.failUnlessEqual([], list(self.msg)) - self.failUnlessEqual([], list(self.msg.iterAttrs())) + self.assertEqual(list(self.msg), []) + self.assertEqual(list(self.msg.iterAttrs()), []) def test_getRequiredAttrs_empty(self): - self.failUnlessEqual([], self.msg.getRequiredAttrs()) + self.assertEqual(self.msg.getRequiredAttrs(), []) def test_parseExtensionArgs_extraType(self): extension_args = { @@ -309,13 +309,13 @@ def test_parseExtensionArgs(self): } self.msg.parseExtensionArgs(extension_args) self.failUnless(self.type_a in self.msg) - self.failUnlessEqual([self.type_a], list(self.msg)) + self.assertEqual(list(self.msg), [self.type_a]) attr_info = self.msg.requested_attributes.get(self.type_a) self.failUnless(attr_info) self.failIf(attr_info.required) - self.failUnlessEqual(self.type_a, attr_info.type_uri) - self.failUnlessEqual(self.alias_a, attr_info.alias) - self.failUnlessEqual([attr_info], list(self.msg.iterAttrs())) + self.assertEqual(attr_info.type_uri, self.type_a) + self.assertEqual(attr_info.alias, self.alias_a) + self.assertEqual(list(self.msg.iterAttrs()), [attr_info]) def test_extensionArgs_idempotent(self): extension_args = { @@ -324,7 +324,7 @@ def test_extensionArgs_idempotent(self): 'if_available': self.alias_a } self.msg.parseExtensionArgs(extension_args) - self.failUnlessEqual(extension_args, self.msg.getExtensionArgs()) + self.assertEqual(self.msg.getExtensionArgs(), extension_args) self.failIf(self.msg.requested_attributes[self.type_a].required) def test_extensionArgs_idempotent_count_required(self): @@ -335,7 +335,7 @@ def test_extensionArgs_idempotent_count_required(self): 'required': self.alias_a } self.msg.parseExtensionArgs(extension_args) - self.failUnlessEqual(extension_args, self.msg.getExtensionArgs()) + self.assertEqual(self.msg.getExtensionArgs(), extension_args) self.failUnless(self.msg.requested_attributes[self.type_a].required) def test_extensionArgs_count1(self): @@ -351,7 +351,7 @@ def test_extensionArgs_count1(self): 'if_available': self.alias_a, } self.msg.parseExtensionArgs(extension_args) - self.failUnlessEqual(extension_args_norm, self.msg.getExtensionArgs()) + self.assertEqual(self.msg.getExtensionArgs(), extension_args_norm) def test_openidNoRealm(self): openid_req_msg = Message.fromOpenIDArgs({ @@ -438,13 +438,13 @@ def setUp(self): def test_construct(self): self.failUnless(self.msg.update_url is None) - self.failUnlessEqual({}, self.msg.data) + self.assertEqual(self.msg.data, {}) def test_getExtensionArgs_empty(self): expected_args = { 'mode': 'fetch_response', } - self.failUnlessEqual(expected_args, self.msg.getExtensionArgs()) + self.assertEqual(self.msg.getExtensionArgs(), expected_args) def test_getExtensionArgs_empty_request(self): expected_args = { @@ -452,7 +452,7 @@ def test_getExtensionArgs_empty_request(self): } req = ax.FetchRequest() msg = ax.FetchResponse(request=req) - self.failUnlessEqual(expected_args, msg.getExtensionArgs()) + self.assertEqual(msg.getExtensionArgs(), expected_args) def test_getExtensionArgs_empty_request_some(self): uri = 'http://not.found/' @@ -466,7 +466,7 @@ def test_getExtensionArgs_empty_request_some(self): req = ax.FetchRequest() req.add(ax.AttrInfo(uri)) msg = ax.FetchResponse(request=req) - self.failUnlessEqual(expected_args, msg.getExtensionArgs()) + self.assertEqual(msg.getExtensionArgs(), expected_args) def test_updateUrlInResponse(self): uri = 'http://not.found/' @@ -481,7 +481,7 @@ def test_updateUrlInResponse(self): req = ax.FetchRequest(update_url=self.request_update_url) req.add(ax.AttrInfo(uri)) msg = ax.FetchResponse(request=req) - self.failUnlessEqual(expected_args, msg.getExtensionArgs()) + self.assertEqual(msg.getExtensionArgs(), expected_args) def test_getExtensionArgs_some_request(self): expected_args = { @@ -494,7 +494,7 @@ def test_getExtensionArgs_some_request(self): req.add(ax.AttrInfo(self.type_a, alias=self.alias_a)) msg = ax.FetchResponse(request=req) msg.addValue(self.type_a, self.value_a) - self.failUnlessEqual(expected_args, msg.getExtensionArgs()) + self.assertEqual(msg.getExtensionArgs(), expected_args) def test_getExtensionArgs_some_not_request(self): req = ax.FetchRequest() @@ -504,10 +504,10 @@ def test_getExtensionArgs_some_not_request(self): def test_getSingle_success(self): self.msg.addValue(self.type_a, self.value_a) - self.failUnlessEqual(self.value_a, self.msg.getSingle(self.type_a)) + self.assertEqual(self.msg.getSingle(self.type_a), self.value_a) def test_getSingle_none(self): - self.failUnlessEqual(None, self.msg.getSingle(self.type_a)) + self.assertIsNone(self.msg.getSingle(self.type_a)) def test_getSingle_extra(self): self.msg.setValues(self.type_a, ['x', 'y']) @@ -574,7 +574,7 @@ class Endpoint: resp = SuccessResponse(Endpoint(), msg, signed_fields=sf) ax_resp = ax.FetchResponse.fromSuccessResponse(resp) values = ax_resp.get(uri) - self.failUnlessEqual([value], values) + self.assertEqual(values, [value]) class StoreRequestTest(unittest.TestCase): @@ -584,14 +584,14 @@ def setUp(self): self.alias_a = 'juggling' def test_construct(self): - self.failUnlessEqual({}, self.msg.data) + self.assertEqual(self.msg.data, {}) def test_getExtensionArgs_empty(self): args = self.msg.getExtensionArgs() expected_args = { 'mode': 'store_request', } - self.failUnlessEqual(expected_args, args) + self.assertEqual(args, expected_args) def test_getExtensionArgs_nonempty(self): aliases = NamespaceMap() @@ -606,7 +606,7 @@ def test_getExtensionArgs_nonempty(self): 'value.%s.1' % (self.alias_a,): 'foo', 'value.%s.2' % (self.alias_a,): 'bar', } - self.failUnlessEqual(expected_args, args) + self.assertEqual(args, expected_args) class StoreResponseTest(unittest.TestCase): @@ -614,20 +614,17 @@ def test_success(self): msg = ax.StoreResponse() self.failUnless(msg.succeeded()) self.failIf(msg.error_message) - self.failUnlessEqual({'mode': 'store_response_success'}, - msg.getExtensionArgs()) + self.assertEqual(msg.getExtensionArgs(), {'mode': 'store_response_success'}) def test_fail_nomsg(self): msg = ax.StoreResponse(False) self.failIf(msg.succeeded()) self.failIf(msg.error_message) - self.failUnlessEqual({'mode': 'store_response_failure'}, - msg.getExtensionArgs()) + self.assertEqual(msg.getExtensionArgs(), {'mode': 'store_response_failure'}) def test_fail_msg(self): reason = 'no reason, really' msg = ax.StoreResponse(False, reason) self.failIf(msg.succeeded()) - self.failUnlessEqual(reason, msg.error_message) - self.failUnlessEqual({'mode': 'store_response_failure', - 'error': reason}, msg.getExtensionArgs()) + self.assertEqual(msg.error_message, reason) + self.assertEqual(msg.getExtensionArgs(), {'mode': 'store_response_failure', 'error': reason}) diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index 05496638..dce7075b 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -391,20 +391,20 @@ def raiseSetupNeeded(msg): self.consumer._checkSetupNeeded = raiseSetupNeeded response = self.consumer.complete(message, None, None) - self.failUnlessEqual(SETUP_NEEDED, response.status) + self.assertEqual(response.status, SETUP_NEEDED) self.failUnless(setup_url_sentinel is response.setup_url) def test_cancel(self): message = Message.fromPostArgs({'openid.mode': 'cancel'}) self.disableReturnToChecking() r = self.consumer.complete(message, self.endpoint) - self.failUnlessEqual(r.status, CANCEL) + self.assertEqual(r.status, CANCEL) self.failUnless(r.identity_url == self.endpoint.claimed_id) def test_cancel_with_return_to(self): message = Message.fromPostArgs({'openid.mode': 'cancel'}) r = self.consumer.complete(message, self.endpoint, self.return_to) - self.failUnlessEqual(r.status, CANCEL) + self.assertEqual(r.status, CANCEL) self.failUnless(r.identity_url == self.endpoint.claimed_id) def test_error(self): @@ -412,9 +412,9 @@ def test_error(self): message = Message.fromPostArgs({'openid.mode': 'error', 'openid.error': msg}) self.disableReturnToChecking() r = self.consumer.complete(message, self.endpoint) - self.failUnlessEqual(r.status, FAILURE) + self.assertEqual(r.status, FAILURE) self.failUnless(r.identity_url == self.endpoint.claimed_id) - self.failUnlessEqual(r.message, msg) + self.assertEqual(r.message, msg) def test_errorWithNoOptionalKeys(self): msg = 'an error message' @@ -422,11 +422,11 @@ def test_errorWithNoOptionalKeys(self): message = Message.fromPostArgs({'openid.mode': 'error', 'openid.error': msg, 'openid.contact': contact}) self.disableReturnToChecking() r = self.consumer.complete(message, self.endpoint) - self.failUnlessEqual(r.status, FAILURE) + self.assertEqual(r.status, FAILURE) self.failUnless(r.identity_url == self.endpoint.claimed_id) self.failUnless(r.contact == contact) self.failUnless(r.reference is None) - self.failUnlessEqual(r.message, msg) + self.assertEqual(r.message, msg) def test_errorWithOptionalKeys(self): msg = 'an error message' @@ -435,16 +435,16 @@ def test_errorWithOptionalKeys(self): message = Message.fromPostArgs({'openid.mode': 'error', 'openid.error': msg, 'openid.reference': reference, 'openid.contact': contact, 'openid.ns': OPENID2_NS}) r = self.consumer.complete(message, self.endpoint, None) - self.failUnlessEqual(r.status, FAILURE) + self.assertEqual(r.status, FAILURE) self.failUnless(r.identity_url == self.endpoint.claimed_id) self.failUnless(r.contact == contact) self.failUnless(r.reference == reference) - self.failUnlessEqual(r.message, msg) + self.assertEqual(r.message, msg) def test_noMode(self): message = Message.fromPostArgs({}) r = self.consumer.complete(message, self.endpoint, None) - self.failUnlessEqual(r.status, FAILURE) + self.assertEqual(r.status, FAILURE) self.failUnless(r.identity_url == self.endpoint.claimed_id) def test_idResMissingField(self): @@ -533,22 +533,22 @@ def test_idResNoIdentity(self): def test_idResMissingIdentitySig(self): self.message.setArg(OPENID_NS, 'signed', 'return_to,response_nonce,assoc_handle,claimed_id') r = self.consumer.complete(self.message, self.endpoint, None) - self.failUnlessEqual(r.status, FAILURE) + self.assertEqual(r.status, FAILURE) def test_idResMissingReturnToSig(self): self.message.setArg(OPENID_NS, 'signed', 'identity,response_nonce,assoc_handle,claimed_id') r = self.consumer.complete(self.message, self.endpoint, None) - self.failUnlessEqual(r.status, FAILURE) + self.assertEqual(r.status, FAILURE) def test_idResMissingAssocHandleSig(self): self.message.setArg(OPENID_NS, 'signed', 'identity,response_nonce,return_to,claimed_id') r = self.consumer.complete(self.message, self.endpoint, None) - self.failUnlessEqual(r.status, FAILURE) + self.assertEqual(r.status, FAILURE) def test_idResMissingClaimedIDSig(self): self.message.setArg(OPENID_NS, 'signed', 'identity,response_nonce,return_to,assoc_handle') r = self.consumer.complete(self.message, self.endpoint, None) - self.failUnlessEqual(r.status, FAILURE) + self.assertEqual(r.status, FAILURE) def failUnlessSuccess(self, response): if response.status != SUCCESS: @@ -571,7 +571,7 @@ def _createAssoc(self): store = self.consumer.store store.storeAssociation(self.server_url, assoc) assoc2 = store.getAssociation(self.server_url) - self.failUnlessEqual(assoc, assoc2) + self.assertEqual(assoc, assoc2) def test_goodResponse(self): """successful response to check_authentication""" @@ -657,12 +657,9 @@ def test_invalidatePresent(self): class TestSetupNeeded(TestIdRes): def failUnlessSetupNeeded(self, expected_setup_url, message): - try: + with self.assertRaises(SetupNeededError) as catch: self.consumer._checkSetupNeeded(message) - except SetupNeededError as why: - self.failUnlessEqual(expected_setup_url, why.user_setup_url) - else: - self.fail("Expected to find an immediate-mode response") + self.assertEqual(catch.exception.user_setup_url, expected_setup_url) def test_setupNeededOpenID1(self): """The minimum conditions necessary to trigger Setup Needed""" @@ -701,8 +698,8 @@ def test_setupNeededOpenID2(self): }) self.failUnless(message.isOpenID2()) response = self.consumer.complete(message, None, None) - self.failUnlessEqual('setup_needed', response.status) - self.failUnlessEqual(None, response.setup_url) + self.assertEqual(response.status, 'setup_needed') + self.assertIsNone(response.setup_url) def test_setupNeededDoesntWorkForOpenID1(self): message = Message.fromOpenIDArgs({ @@ -713,7 +710,7 @@ def test_setupNeededDoesntWorkForOpenID1(self): self.consumer._checkSetupNeeded(message) response = self.consumer.complete(message, None, None) - self.failUnlessEqual('failure', response.status) + self.assertEqual(response.status, 'failure') self.failUnless(response.message.startswith('Invalid openid.mode')) def test_noSetupNeededOpenID2(self): @@ -1029,8 +1026,8 @@ def test_newerAssoc(self): message = good_assoc.signMessage(message) self.disableReturnToChecking() info = self.consumer._doIdRes(message, self.endpoint, None) - self.failUnlessEqual(info.status, SUCCESS, info.message) - self.failUnlessEqual(self.consumer_id, info.identity_url) + self.assertEqual(info.status, SUCCESS, info.message) + self.assertEqual(info.identity_url, self.consumer_id) class TestReturnToArgs(unittest.TestCase): @@ -1259,7 +1256,7 @@ def test_112(self): 'openid.sig': 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', 'openid.signed': 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,' 'ns.pape,pape.nist_auth_level,pape.auth_policies'} - self.failUnlessEqual(OPENID2_NS, args['openid.ns']) + self.assertEqual(args['openid.ns'], OPENID2_NS) incoming = Message.fromPostArgs(args) self.failUnless(incoming.isOpenID2()) car = self.consumer._createCheckAuthRequest(incoming) @@ -1267,8 +1264,8 @@ def test_112(self): expected_args['openid.mode'] = 'check_authentication' expected = Message.fromPostArgs(expected_args) self.failUnless(expected.isOpenID2()) - self.failUnlessEqual(expected, car) - self.failUnlessEqual(expected_args, car.toPostArgs()) + self.assertEqual(car, expected) + self.assertEqual(car.toPostArgs(), expected_args) class TestFetchAssoc(unittest.TestCase, CatchLogs): @@ -1349,9 +1346,9 @@ def test_extensionResponse(self): 'return_to': 'return_to', }) utargs = resp.extensionResponse('urn:unittest', False) - self.failUnlessEqual(utargs, {'one': '1', 'two': '2'}) + self.assertEqual(utargs, {'one': '1', 'two': '2'}) sregargs = resp.extensionResponse('urn:sreg', False) - self.failUnlessEqual(sregargs, {'nickname': 'j3h'}) + self.assertEqual(sregargs, {'nickname': 'j3h'}) def test_extensionResponseSigned(self): args = { @@ -1376,12 +1373,12 @@ def test_extensionResponseSigned(self): # All args in this NS are signed, so expect all. sregargs = resp.extensionResponse('urn:sreg', True) - self.failUnlessEqual(sregargs, {'nickname': 'j3h', 'dob': 'yesterday'}) + self.assertEqual(sregargs, {'nickname': 'j3h', 'dob': 'yesterday'}) # Not all args in this NS are signed, so expect None when # asking for them. utargs = resp.extensionResponse('urn:unittest', True) - self.failUnlessEqual(utargs, None) + self.assertIsNone(utargs) def test_noReturnTo(self): resp = mkSuccess(self.endpoint, {}) @@ -1389,18 +1386,16 @@ def test_noReturnTo(self): def test_returnTo(self): resp = mkSuccess(self.endpoint, {'return_to': 'return_to'}) - self.failUnlessEqual(resp.getReturnTo(), 'return_to') + self.assertEqual(resp.getReturnTo(), 'return_to') def test_displayIdentifierClaimedId(self): resp = mkSuccess(self.endpoint, {}) - self.failUnlessEqual(resp.getDisplayIdentifier(), - resp.endpoint.claimed_id) + self.assertEqual(resp.getDisplayIdentifier(), resp.endpoint.claimed_id) def test_displayIdentifierOverride(self): self.endpoint.display_identifier = "http://input.url/" resp = mkSuccess(self.endpoint, {}) - self.failUnlessEqual(resp.getDisplayIdentifier(), - "http://input.url/") + self.assertEqual(resp.getDisplayIdentifier(), "http://input.url/") class StubConsumer(object): @@ -1440,11 +1435,9 @@ def test_setAssociationPreference(self): self.consumer.setAssociationPreference([]) self.failUnless(isinstance(self.consumer.consumer.negotiator, association.SessionNegotiator)) - self.failUnlessEqual([], - self.consumer.consumer.negotiator.allowed_types) + self.assertEqual(self.consumer.consumer.negotiator.allowed_types, []) self.consumer.setAssociationPreference([('HMAC-SHA1', 'DH-SHA1')]) - self.failUnlessEqual([('HMAC-SHA1', 'DH-SHA1')], - self.consumer.consumer.negotiator.allowed_types) + self.assertEqual(self.consumer.consumer.negotiator.allowed_types, [('HMAC-SHA1', 'DH-SHA1')]) def withDummyDiscovery(self, callable, dummy_getNextService): class DummyDisco(object): @@ -1520,8 +1513,8 @@ def checkEndpoint(message, endpoint, return_to): self.consumer.consumer.complete = checkEndpoint response = self.consumer.complete({}, None) - self.failUnlessEqual(response.status, FAILURE) - self.failUnlessEqual(response.message, text) + self.assertEqual(response.status, FAILURE) + self.assertEqual(response.message, text) self.failUnless(response.identity_url is None) def _doResp(self, auth_req, exp_resp): @@ -1543,7 +1536,7 @@ def _doResp(self, auth_req, exp_resp): self.failIf(self.consumer._token_key in self.session) # Expected status response - self.failUnlessEqual(resp.status, exp_resp.status) + self.assertEqual(resp.status, exp_resp.status) return resp @@ -1674,10 +1667,10 @@ def verifyDiscoveryResults(identifier, endpoint): response = self.consumer._doIdRes(message, self.endpoint, None) self.failUnlessSuccess(response) - self.failUnlessEqual(response.identity_url, "=directed_identifier") + self.assertEqual(response.identity_url, "=directed_identifier") # assert that discovery attempt happens and returns good - self.failUnlessEqual(iverified, [discovered_endpoint]) + self.assertEqual(iverified, [discovered_endpoint]) def test_idpDrivenCompleteFraud(self): # crap with an identifier that doesn't match discovery info @@ -1732,15 +1725,15 @@ def test_theGoodStuff(self): self.services = [endpoint] r = self.consumer._verifyDiscoveryResults(self.message, endpoint) - self.failUnlessEqual(r, endpoint) + self.assertEqual(r, endpoint) def test_otherServer(self): text = "verify failed" def discoverAndVerify(claimed_id, to_match_endpoints): - self.failUnlessEqual(claimed_id, self.identifier) + self.assertEqual(claimed_id, self.identifier) for to_match in to_match_endpoints: - self.failUnlessEqual(claimed_id, to_match.claimed_id) + self.assertEqual(claimed_id, to_match.claimed_id) raise ProtocolError(text) self.consumer._discoverAndVerify = discoverAndVerify @@ -1764,9 +1757,9 @@ def test_foreignDelegate(self): text = "verify failed" def discoverAndVerify(claimed_id, to_match_endpoints): - self.failUnlessEqual(claimed_id, self.identifier) + self.assertEqual(claimed_id, self.identifier) for to_match in to_match_endpoints: - self.failUnlessEqual(claimed_id, to_match.claimed_id) + self.assertEqual(claimed_id, to_match.claimed_id) raise ProtocolError(text) self.consumer._discoverAndVerify = discoverAndVerify @@ -1778,12 +1771,8 @@ def discoverAndVerify(claimed_id, to_match_endpoints): endpoint.server_url = self.server_url endpoint.local_id = "http://unittest/juan-carlos" - try: - r = self.consumer._verifyDiscoveryResults(self.message, endpoint) - except ProtocolError as e: - self.failUnlessEqual(str(e), text) - else: - self.fail("Exepected ProtocolError, %r returned" % (r,)) + with self.assertRaisesRegexp(ProtocolError, text): + self.consumer._verifyDiscoveryResults(self.message, endpoint) def test_nothingDiscovered(self): # a set of no things. @@ -1821,7 +1810,7 @@ def test_noEncryptionSendsType(self): 'assoc_type': self.assoc_type, }) - self.failUnlessEqual(expected, args) + self.assertEqual(args, expected) def test_noEncryptionCompatibility(self): self.endpoint.use_compatibility = True @@ -1830,9 +1819,7 @@ def test_noEncryptionCompatibility(self): self.endpoint, self.assoc_type, session_type) self.failUnless(isinstance(session, PlainTextConsumerSession)) - self.failUnlessEqual( - Message.fromOpenIDArgs({'mode': 'associate', 'assoc_type': self.assoc_type}), - args) + self.assertEqual(args, Message.fromOpenIDArgs({'mode': 'associate', 'assoc_type': self.assoc_type})) def test_dhSHA1Compatibility(self): # Set the consumer's session type to a fast session since we @@ -1860,7 +1847,7 @@ def test_dhSHA1Compatibility(self): 'dh_gen': 'Ag==', }) - self.failUnlessEqual(expected, args) + self.assertEqual(args, expected) # XXX: test the other types @@ -1893,7 +1880,7 @@ def testExtractSecret(self): self.msg.setArg(OPENID_NS, 'enc_mac_key', self.enc_mac_key) extracted = self.consumer_session.extractSecret(self.msg) - self.failUnlessEqual(extracted, self.secret) + self.assertEqual(extracted, self.secret) def testAbsentServerPublic(self): self.msg.setArg(OPENID_NS, 'enc_mac_key', self.enc_mac_key) @@ -2023,7 +2010,7 @@ def returnTrue(unused1, unused2): # first endpoint that we passed in as a result. result = self.consumer._discoverAndVerify( 'http://claimed.id/', [self.to_match]) - self.failUnlessEqual(matching_endpoint, result) + self.assertEqual(result, matching_endpoint) class SillyExtension(Extension): @@ -2041,7 +2028,7 @@ def test_SillyExtension(self): ar = AuthRequest(OpenIDServiceEndpoint(), None) ar.addExtension(ext) ext_args = ar.message.getArgs(ext.ns_uri) - self.failUnlessEqual(ext.getExtensionArgs(), ext_args) + self.assertEqual(ext_args, ext.getExtensionArgs()) class TestKVPost(unittest.TestCase): @@ -2055,19 +2042,16 @@ def test_200(self): response.body = "foo:bar\nbaz:quux\n" r = _httpResponseToMessage(response, self.server_url) expected_msg = Message.fromOpenIDArgs({'foo': 'bar', 'baz': 'quux'}) - self.failUnlessEqual(expected_msg, r) + self.assertEqual(r, expected_msg) def test_400(self): response = HTTPResponse() response.status = 400 response.body = "error:bonk\nerror_code:7\n" - try: - r = _httpResponseToMessage(response, self.server_url) - except ServerError as e: - self.failUnlessEqual(e.error_text, 'bonk') - self.failUnlessEqual(e.error_code, '7') - else: - self.fail("Expected ServerError, got return %r" % (r,)) + with self.assertRaises(ServerError) as catch: + _httpResponseToMessage(response, self.server_url) + self.assertEqual(catch.exception.error_text, 'bonk') + self.assertEqual(catch.exception.error_code, '7') def test_500(self): # 500 as an example of any non-200, non-400 code. diff --git a/openid/test/test_discover.py b/openid/test/test_discover.py index 92fc6960..4a214686 100644 --- a/openid/test/test_discover.py +++ b/openid/test/test_discover.py @@ -37,12 +37,9 @@ class TestDiscoveryFailure(unittest.TestCase): ] def runOneTest(self, url, expected_status): - try: + with self.assertRaises(DiscoveryFailure) as catch: discover.discover(url) - except DiscoveryFailure as why: - self.failUnlessEqual(why.http_response.status, expected_status) - else: - self.fail('Did not raise DiscoveryFailure') + self.assertEqual(catch.exception.http_response.status, expected_status) def test(self): for responses in self.cases: @@ -164,7 +161,7 @@ def _checkService(self, s, used_yadis=False, display_identifier=None ): - self.failUnlessEqual(server_url, s.server_url) + self.assertEqual(s.server_url, server_url) if types == ['2.0 OP']: self.failIf(claimed_id) self.failIf(local_id) @@ -173,11 +170,10 @@ def _checkService(self, s, self.failIf(s.getLocalID()) self.failIf(s.compatibilityMode()) self.failUnless(s.isOPIdentifier()) - self.failUnlessEqual(s.preferredNamespace(), - discover.OPENID_2_0_MESSAGE_NS) + self.assertEqual(s.preferredNamespace(), discover.OPENID_2_0_MESSAGE_NS) else: - self.failUnlessEqual(claimed_id, s.claimed_id) - self.failUnlessEqual(local_id, s.getLocalID()) + self.assertEqual(s.claimed_id, claimed_id) + self.assertEqual(s.getLocalID(), local_id) if used_yadis: self.failUnless(s.used_yadis, "Expected to use Yadis") @@ -193,16 +189,16 @@ def _checkService(self, s, } type_uris = [openid_types[t] for t in types] - self.failUnlessEqual(type_uris, s.type_uris) - self.failUnlessEqual(canonical_id, s.canonicalID) + self.assertEqual(s.type_uris, type_uris) + self.assertEqual(s.canonicalID, canonical_id) if s.canonicalID: self.failUnless(s.getDisplayIdentifier() != claimed_id) self.failUnless(s.getDisplayIdentifier() is not None) - self.failUnlessEqual(display_identifier, s.getDisplayIdentifier()) - self.failUnlessEqual(s.claimed_id, s.canonicalID) + self.assertEqual(s.getDisplayIdentifier(), display_identifier) + self.assertEqual(s.canonicalID, s.claimed_id) - self.failUnlessEqual(s.display_identifier or s.claimed_id, s.getDisplayIdentifier()) + self.assertEqual(s.display_identifier or s.claimed_id, s.getDisplayIdentifier()) def setUp(self): self.documents = self.documents.copy() @@ -228,8 +224,8 @@ def _discover(self, content_type, data, self.documents[self.id_url] = (content_type, data) id_url, services = discover.discover(self.id_url) - self.failUnlessEqual(expected_services, len(services)) - self.failUnlessEqual(expected_id, id_url) + self.assertEqual(len(services), expected_services) + self.assertEqual(id_url, expected_id) return services def test_404(self): @@ -312,8 +308,8 @@ def test_html1Fragment(self): expected_id = self.id_url self.id_url = self.id_url + '#fragment' id_url, services = discover.discover(self.id_url) - self.failUnlessEqual(expected_services, len(services)) - self.failUnlessEqual(expected_id, id_url) + self.assertEqual(len(services), expected_services) + self.assertEqual(id_url, expected_id) self._checkService( services[0], @@ -596,7 +592,7 @@ def test_useCanonicalID(self): endpoint = discover.OpenIDServiceEndpoint() endpoint.claimed_id = XRI("=!1000") endpoint.canonicalID = XRI("=!1000") - self.failUnlessEqual(endpoint.getLocalID(), XRI("=!1000")) + self.assertEqual(endpoint.getLocalID(), XRI("=!1000")) class TestXRIDiscoveryIDP(BaseTestDiscovery): @@ -608,8 +604,7 @@ class TestXRIDiscoveryIDP(BaseTestDiscovery): def test_xri(self): user_xri, services = discover.discoverXRI('=smoker') self.failUnless(services, "Expected services, got zero") - self.failUnlessEqual(services[0].server_url, - "http://www.livejournal.com/openid/server.bml") + self.assertEqual(services[0].server_url, "http://www.livejournal.com/openid/server.bml") class TestPreferredNamespace(unittest.TestCase): @@ -619,7 +614,7 @@ def test(self): endpoint = discover.OpenIDServiceEndpoint() endpoint.type_uris = type_uris actual_ns = endpoint.preferredNamespace() - self.failUnlessEqual(actual_ns, expected_ns) + self.assertEqual(actual_ns, expected_ns) cases = [ (message.OPENID1_NS, []), @@ -680,17 +675,17 @@ def test_isOPEndpoint(self): self.failUnless(self.endpoint.isOPIdentifier()) def test_noIdentifiers(self): - self.failUnlessEqual(self.endpoint.getLocalID(), None) - self.failUnlessEqual(self.endpoint.claimed_id, None) + self.assertIsNone(self.endpoint.getLocalID()) + self.assertIsNone(self.endpoint.claimed_id) def test_compatibility(self): self.failIf(self.endpoint.compatibilityMode()) def test_canonicalID(self): - self.failUnlessEqual(self.endpoint.canonicalID, None) + self.assertIsNone(self.endpoint.canonicalID) def test_serverURL(self): - self.failUnlessEqual(self.endpoint.server_url, self.op_endpoint_url) + self.assertEqual(self.endpoint.server_url, self.op_endpoint_url) class TestDiscoverFunction(unittest.TestCase): @@ -712,16 +707,16 @@ def discoverURI(self, identifier): return 'URI' def test_uri(self): - self.failUnlessEqual('URI', discover.discover('http://woo!')) + self.assertEqual(discover.discover('http://woo!'), 'URI') def test_uriForBogus(self): - self.failUnlessEqual('URI', discover.discover('not a URL or XRI')) + self.assertEqual(discover.discover('not a URL or XRI'), 'URI') def test_xri(self): - self.failUnlessEqual('XRI', discover.discover('xri://=something')) + self.assertEqual(discover.discover('xri://=something'), 'XRI') def test_xriChar(self): - self.failUnlessEqual('XRI', discover.discover('=something')) + self.assertEqual(discover.discover('=something'), 'XRI') class TestEndpointSupportsType(unittest.TestCase): @@ -783,4 +778,4 @@ class TestEndpointDisplayIdentifier(unittest.TestCase): def test_strip_fragment(self): endpoint = discover.OpenIDServiceEndpoint() endpoint.claimed_id = 'http://recycled.invalid/#123' - self.failUnlessEqual('http://recycled.invalid/', endpoint.getDisplayIdentifier()) + self.assertEqual(endpoint.getDisplayIdentifier(), 'http://recycled.invalid/') diff --git a/openid/test/test_etxrd.py b/openid/test/test_etxrd.py index cae2712c..3e349dcc 100644 --- a/openid/test/test_etxrd.py +++ b/openid/test/test_etxrd.py @@ -58,8 +58,8 @@ def testParseOpenID(self): it = iter(services) for (server_url, delegate) in expectedServices: for (actual_url, actual_delegate) in it: - self.failUnlessEqual(server_url, actual_url) - self.failUnlessEqual(delegate, actual_delegate) + self.assertEqual(actual_url, server_url) + self.assertEqual(actual_delegate, delegate) break else: self.fail('Not enough services found') @@ -71,7 +71,7 @@ def _checkServices(self, expectedServices): for (type_uri, uri) in expectedServices: for service in it: if type_uri in service.type_uris: - self.failUnlessEqual(service.uri, uri) + self.assertEqual(service.uri, uri) break else: self.fail('Did not find %r service' % (type_uri,)) @@ -184,7 +184,7 @@ def test(self): def _getCanonicalID(self, iname, xrds, expectedID): if isinstance(expectedID, (str, unicode, type(None))): cid = etxrd.getCanonicalID(iname, xrds) - self.failUnlessEqual(cid, expectedID and xri.XRI(expectedID)) + self.assertEqual(cid, expectedID and xri.XRI(expectedID)) elif issubclass(expectedID, etxrd.XRDSError): self.failUnlessRaises(expectedID, etxrd.getCanonicalID, iname, xrds) diff --git a/openid/test/test_extension.py b/openid/test/test_extension.py index 0f714c62..487968fd 100644 --- a/openid/test/test_extension.py +++ b/openid/test/test_extension.py @@ -18,11 +18,8 @@ def test_OpenID1(self): ext.toMessage(oid1_msg) namespaces = oid1_msg.namespaces self.failUnless(namespaces.isImplicit(DummyExtension.ns_uri)) - self.failUnlessEqual( - DummyExtension.ns_uri, - namespaces.getNamespaceURI(DummyExtension.ns_alias)) - self.failUnlessEqual(DummyExtension.ns_alias, - namespaces.getAlias(DummyExtension.ns_uri)) + self.assertEqual(DummyExtension.ns_uri, namespaces.getNamespaceURI(DummyExtension.ns_alias)) + self.assertEqual(DummyExtension.ns_alias, namespaces.getAlias(DummyExtension.ns_uri)) def test_OpenID2(self): oid2_msg = message.Message(message.OPENID2_NS) @@ -30,8 +27,5 @@ def test_OpenID2(self): ext.toMessage(oid2_msg) namespaces = oid2_msg.namespaces self.failIf(namespaces.isImplicit(DummyExtension.ns_uri)) - self.failUnlessEqual( - DummyExtension.ns_uri, - namespaces.getNamespaceURI(DummyExtension.ns_alias)) - self.failUnlessEqual(DummyExtension.ns_alias, - namespaces.getAlias(DummyExtension.ns_uri)) + self.assertEqual(DummyExtension.ns_uri, namespaces.getNamespaceURI(DummyExtension.ns_alias)) + self.assertEqual(DummyExtension.ns_alias, namespaces.getAlias(DummyExtension.ns_uri)) diff --git a/openid/test/test_htmldiscover.py b/openid/test/test_htmldiscover.py index 9d7344a3..65b036f1 100644 --- a/openid/test/test_htmldiscover.py +++ b/openid/test/test_htmldiscover.py @@ -13,5 +13,4 @@ class BadLinksTestCase(unittest.TestCase): def test_from_html(self): for html in self.cases: actual = OpenIDServiceEndpoint.fromHTML('http://unused.url/', html) - expected = [] - self.failUnlessEqual(expected, actual) + self.assertEqual(actual, []) diff --git a/openid/test/test_kvform.py b/openid/test/test_kvform.py index 19929279..b3fb6dd2 100644 --- a/openid/test/test_kvform.py +++ b/openid/test/test_kvform.py @@ -10,7 +10,7 @@ def checkWarnings(self, num_warnings, msg=None): full_msg = 'Invalid number of warnings {} != {}'.format(num_warnings, len(self.messages)) if msg is not None: full_msg = full_msg + ' ' + msg - self.failUnlessEqual(num_warnings, len(self.messages), full_msg) + self.assertEqual(num_warnings, len(self.messages), full_msg) def setUp(self): CatchLogs.setUp(self) @@ -30,7 +30,7 @@ def runTest(self): d = kvform.kvToDict(kv_data) # make sure it parses to expected dict - self.failUnlessEqual(d, result) + self.assertEqual(d, result) # Check to make sure we got the expected number of warnings self.checkWarnings(expected_warnings, msg='kvToDict({!r})'.format(kv_data)) @@ -39,7 +39,7 @@ def runTest(self): # sure that *** dict -> kv -> dict is identity. *** kv = kvform.dictToKV(d) d2 = kvform.kvToDict(kv) - self.failUnlessEqual(d, d2) + self.assertEqual(d, d2) class KVSeqTest(KVBaseTest): @@ -63,7 +63,7 @@ def runTest(self): # seq serializes to expected kvform actual = kvform.seqToKV(kv_data) - self.failUnlessEqual(actual, result) + self.assertEqual(actual, result) self.assertIsInstance(actual, str) # Parse back to sequence. Expected to be unchanged, except @@ -72,7 +72,7 @@ def runTest(self): seq = kvform.kvToSeq(actual) clean_seq = self.cleanSeq(seq) - self.failUnlessEqual(seq, clean_seq) + self.assertEqual(seq, clean_seq) self.checkWarnings(expected_warnings) @@ -155,5 +155,5 @@ class GeneralTest(KVBaseTest): def test_convert(self): result = kvform.seqToKV([(1, 1)]) - self.failUnlessEqual(result, '1:1\n') + self.assertEqual(result, '1:1\n') self.checkWarnings(2) diff --git a/openid/test/test_message.py b/openid/test/test_message.py index be6fc210..f0570437 100644 --- a/openid/test/test_message.py +++ b/openid/test/test_message.py @@ -10,17 +10,14 @@ def mkGetArgTest(ns, key, expected=None): def test(self): a_default = object() - self.failUnlessEqual(self.msg.getArg(ns, key), expected) + self.assertEqual(self.msg.getArg(ns, key), expected) if expected is None: - self.failUnlessEqual( - self.msg.getArg(ns, key, a_default), a_default) + self.assertEqual(self.msg.getArg(ns, key, a_default), a_default) self.failUnlessRaises( KeyError, self.msg.getArg, ns, key, message.no_default) else: - self.failUnlessEqual( - self.msg.getArg(ns, key, a_default), expected) - self.failUnlessEqual( - self.msg.getArg(ns, key, message.no_default), expected) + self.assertEqual(self.msg.getArg(ns, key, a_default), expected) + self.assertEqual(self.msg.getArg(ns, key, message.no_default), expected) return test @@ -30,23 +27,23 @@ def setUp(self): self.msg = message.Message() def test_toPostArgs(self): - self.failUnlessEqual(self.msg.toPostArgs(), {}) + self.assertEqual(self.msg.toPostArgs(), {}) def test_toArgs(self): - self.failUnlessEqual(self.msg.toArgs(), {}) + self.assertEqual(self.msg.toArgs(), {}) def test_toKVForm(self): - self.failUnlessEqual(self.msg.toKVForm(), '') + self.assertEqual(self.msg.toKVForm(), '') def test_toURLEncoded(self): - self.failUnlessEqual(self.msg.toURLEncoded(), '') + self.assertEqual(self.msg.toURLEncoded(), '') def test_toURL(self): base_url = 'http://base.url/' - self.failUnlessEqual(self.msg.toURL(base_url), base_url) + self.assertEqual(self.msg.toURL(base_url), base_url) def test_getOpenID(self): - self.failUnlessEqual(self.msg.getOpenIDNamespace(), None) + self.assertIsNone(self.msg.getOpenIDNamespace()) def test_getKeyOpenID(self): # Could reasonably return None instead of raising an @@ -57,17 +54,16 @@ def test_getKeyOpenID(self): self.msg.getKey, message.OPENID_NS, 'foo') def test_getKeyBARE(self): - self.failUnlessEqual(self.msg.getKey(message.BARE_NS, 'foo'), 'foo') + self.assertEqual(self.msg.getKey(message.BARE_NS, 'foo'), 'foo') def test_getKeyNS1(self): - self.failUnlessEqual(self.msg.getKey(message.OPENID1_NS, 'foo'), None) + self.assertIsNone(self.msg.getKey(message.OPENID1_NS, 'foo')) def test_getKeyNS2(self): - self.failUnlessEqual(self.msg.getKey(message.OPENID2_NS, 'foo'), None) + self.assertIsNone(self.msg.getKey(message.OPENID2_NS, 'foo')) def test_getKeyNS3(self): - self.failUnlessEqual(self.msg.getKey('urn:nothing-significant', 'foo'), - None) + self.assertIsNone(self.msg.getKey('urn:nothing-significant', 'foo')) def test_hasKey(self): # Could reasonably return False instead of raising an @@ -78,17 +74,16 @@ def test_hasKey(self): self.msg.hasKey, message.OPENID_NS, 'foo') def test_hasKeyBARE(self): - self.failUnlessEqual(self.msg.hasKey(message.BARE_NS, 'foo'), False) + self.assertFalse(self.msg.hasKey(message.BARE_NS, 'foo')) def test_hasKeyNS1(self): - self.failUnlessEqual(self.msg.hasKey(message.OPENID1_NS, 'foo'), False) + self.assertFalse(self.msg.hasKey(message.OPENID1_NS, 'foo')) def test_hasKeyNS2(self): - self.failUnlessEqual(self.msg.hasKey(message.OPENID2_NS, 'foo'), False) + self.assertFalse(self.msg.hasKey(message.OPENID2_NS, 'foo')) def test_hasKeyNS3(self): - self.failUnlessEqual(self.msg.hasKey('urn:nothing-significant', 'foo'), - False) + self.assertFalse(self.msg.hasKey('urn:nothing-significant', 'foo')) def test_getAliasedArgSuccess(self): msg = message.Message.fromPostArgs({'openid.ns.test': 'urn://foo', @@ -123,16 +118,16 @@ def test_getArgs(self): self.msg.getArgs, message.OPENID_NS) def test_getArgsBARE(self): - self.failUnlessEqual(self.msg.getArgs(message.BARE_NS), {}) + self.assertEqual(self.msg.getArgs(message.BARE_NS), {}) def test_getArgsNS1(self): - self.failUnlessEqual(self.msg.getArgs(message.OPENID1_NS), {}) + self.assertEqual(self.msg.getArgs(message.OPENID1_NS), {}) def test_getArgsNS2(self): - self.failUnlessEqual(self.msg.getArgs(message.OPENID2_NS), {}) + self.assertEqual(self.msg.getArgs(message.OPENID2_NS), {}) def test_getArgsNS3(self): - self.failUnlessEqual(self.msg.getArgs('urn:nothing-significant'), {}) + self.assertEqual(self.msg.getArgs('urn:nothing-significant'), {}) def test_updateArgs(self): self.failUnlessRaises(message.UndefinedOpenIDNamespace, @@ -145,9 +140,9 @@ def _test_updateArgsNS(self, ns): 'Magnolia Electric Co.': 'Jason Molina', } - self.failUnlessEqual(self.msg.getArgs(ns), {}) + self.assertEqual(self.msg.getArgs(ns), {}) self.msg.updateArgs(ns, update_args) - self.failUnlessEqual(self.msg.getArgs(ns), update_args) + self.assertEqual(self.msg.getArgs(ns), update_args) def test_updateArgsBARE(self): self._test_updateArgsNS(message.BARE_NS) @@ -169,9 +164,9 @@ def test_setArg(self): def _test_setArgNS(self, ns): key = 'Camper van Beethoven' value = 'David Lowery' - self.failUnlessEqual(self.msg.getArg(ns, key), None) + self.assertIsNone(self.msg.getArg(ns, key)) self.msg.setArg(ns, key, value) - self.failUnlessEqual(self.msg.getArg(ns, key), value) + self.assertEqual(self.msg.getArg(ns, key), value) def test_setArgBARE(self): self._test_setArgNS(message.BARE_NS) @@ -227,70 +222,59 @@ def setUp(self): 'openid.error': 'unit test'}) def test_toPostArgs(self): - self.failUnlessEqual(self.msg.toPostArgs(), - {'openid.mode': 'error', - 'openid.error': 'unit test'}) + self.assertEqual(self.msg.toPostArgs(), {'openid.mode': 'error', 'openid.error': 'unit test'}) def test_toArgs(self): - self.failUnlessEqual(self.msg.toArgs(), {'mode': 'error', - 'error': 'unit test'}) + self.assertEqual(self.msg.toArgs(), {'mode': 'error', 'error': 'unit test'}) def test_toKVForm(self): - self.failUnlessEqual(self.msg.toKVForm(), - 'error:unit test\nmode:error\n') + self.assertEqual(self.msg.toKVForm(), 'error:unit test\nmode:error\n') def test_toURLEncoded(self): - self.failUnlessEqual(self.msg.toURLEncoded(), - 'openid.error=unit+test&openid.mode=error') + self.assertEqual(self.msg.toURLEncoded(), 'openid.error=unit+test&openid.mode=error') def test_toURL(self): base_url = 'http://base.url/' actual = self.msg.toURL(base_url) actual_base = actual[:len(base_url)] - self.failUnlessEqual(actual_base, base_url) - self.failUnlessEqual(actual[len(base_url)], '?') + self.assertEqual(actual_base, base_url) + self.assertEqual(actual[len(base_url)], '?') query = actual[len(base_url) + 1:] parsed = cgi.parse_qs(query) - self.failUnlessEqual(parsed, {'openid.mode': ['error'], - 'openid.error': ['unit test']}) + self.assertEqual(parsed, {'openid.mode': ['error'], 'openid.error': ['unit test']}) def test_getOpenID(self): - self.failUnlessEqual(self.msg.getOpenIDNamespace(), message.OPENID1_NS) + self.assertEqual(self.msg.getOpenIDNamespace(), message.OPENID1_NS) def test_getKeyOpenID(self): - self.failUnlessEqual(self.msg.getKey(message.OPENID_NS, 'mode'), - 'openid.mode') + self.assertEqual(self.msg.getKey(message.OPENID_NS, 'mode'), 'openid.mode') def test_getKeyBARE(self): - self.failUnlessEqual(self.msg.getKey(message.BARE_NS, 'mode'), 'mode') + self.assertEqual(self.msg.getKey(message.BARE_NS, 'mode'), 'mode') def test_getKeyNS1(self): - self.failUnlessEqual( - self.msg.getKey(message.OPENID1_NS, 'mode'), 'openid.mode') + self.assertEqual(self.msg.getKey(message.OPENID1_NS, 'mode'), 'openid.mode') def test_getKeyNS2(self): - self.failUnlessEqual(self.msg.getKey(message.OPENID2_NS, 'mode'), None) + self.assertIsNone(self.msg.getKey(message.OPENID2_NS, 'mode')) def test_getKeyNS3(self): - self.failUnlessEqual( - self.msg.getKey('urn:nothing-significant', 'mode'), None) + self.assertIsNone(self.msg.getKey('urn:nothing-significant', 'mode')) def test_hasKey(self): - self.failUnlessEqual(self.msg.hasKey(message.OPENID_NS, 'mode'), True) + self.assertTrue(self.msg.hasKey(message.OPENID_NS, 'mode')) def test_hasKeyBARE(self): - self.failUnlessEqual(self.msg.hasKey(message.BARE_NS, 'mode'), False) + self.assertFalse(self.msg.hasKey(message.BARE_NS, 'mode')) def test_hasKeyNS1(self): - self.failUnlessEqual(self.msg.hasKey(message.OPENID1_NS, 'mode'), True) + self.assertTrue(self.msg.hasKey(message.OPENID1_NS, 'mode')) def test_hasKeyNS2(self): - self.failUnlessEqual( - self.msg.hasKey(message.OPENID2_NS, 'mode'), False) + self.assertFalse(self.msg.hasKey(message.OPENID2_NS, 'mode')) def test_hasKeyNS3(self): - self.failUnlessEqual( - self.msg.hasKey('urn:nothing-significant', 'mode'), False) + self.assertFalse(self.msg.hasKey('urn:nothing-significant', 'mode')) test_getArgBARE = mkGetArgTest(message.BARE_NS, 'mode') test_getArgNS = mkGetArgTest(message.OPENID_NS, 'mode', 'error') @@ -299,21 +283,19 @@ def test_hasKeyNS3(self): test_getArgNS3 = mkGetArgTest('urn:nothing-significant', 'mode') def test_getArgs(self): - self.failUnlessEqual(self.msg.getArgs(message.OPENID_NS), - {'mode': 'error', 'error': 'unit test'}) + self.assertEqual(self.msg.getArgs(message.OPENID_NS), {'mode': 'error', 'error': 'unit test'}) def test_getArgsBARE(self): - self.failUnlessEqual(self.msg.getArgs(message.BARE_NS), {}) + self.assertEqual(self.msg.getArgs(message.BARE_NS), {}) def test_getArgsNS1(self): - self.failUnlessEqual(self.msg.getArgs(message.OPENID1_NS), - {'mode': 'error', 'error': 'unit test'}) + self.assertEqual(self.msg.getArgs(message.OPENID1_NS), {'mode': 'error', 'error': 'unit test'}) def test_getArgsNS2(self): - self.failUnlessEqual(self.msg.getArgs(message.OPENID2_NS), {}) + self.assertEqual(self.msg.getArgs(message.OPENID2_NS), {}) def test_getArgsNS3(self): - self.failUnlessEqual(self.msg.getArgs('urn:nothing-significant'), {}) + self.assertEqual(self.msg.getArgs('urn:nothing-significant'), {}) def _test_updateArgsNS(self, ns, before=None): if before is None: @@ -323,11 +305,11 @@ def _test_updateArgsNS(self, ns, before=None): 'Magnolia Electric Co.': 'Jason Molina', } - self.failUnlessEqual(self.msg.getArgs(ns), before) + self.assertEqual(self.msg.getArgs(ns), before) self.msg.updateArgs(ns, update_args) after = dict(before) after.update(update_args) - self.failUnlessEqual(self.msg.getArgs(ns), after) + self.assertEqual(self.msg.getArgs(ns), after) def test_updateArgs(self): self._test_updateArgsNS(message.OPENID_NS, @@ -349,9 +331,9 @@ def test_updateArgsNS3(self): def _test_setArgNS(self, ns): key = 'Camper van Beethoven' value = 'David Lowery' - self.failUnlessEqual(self.msg.getArg(ns, key), None) + self.assertIsNone(self.msg.getArg(ns, key)) self.msg.setArg(ns, key, value) - self.failUnlessEqual(self.msg.getArg(ns, key), value) + self.assertEqual(self.msg.getArg(ns, key), value) def test_setArg(self): self._test_setArgNS(message.OPENID_NS) @@ -374,9 +356,9 @@ def _test_delArgNS(self, ns): self.failUnlessRaises(KeyError, self.msg.delArg, ns, key) self.msg.setArg(ns, key, value) - self.failUnlessEqual(self.msg.getArg(ns, key), value) + self.assertEqual(self.msg.getArg(ns, key), value) self.msg.delArg(ns, key) - self.failUnlessEqual(self.msg.getArg(ns, key), None) + self.assertIsNone(self.msg.getArg(ns, key)) def test_delArg(self): self._test_delArgNS(message.OPENID_NS) @@ -408,37 +390,29 @@ def setUp(self): }) def test_toPostArgs(self): - self.failUnlessEqual(self.msg.toPostArgs(), - {'openid.mode': 'error', - 'openid.error': 'unit test', - 'openid.ns': message.OPENID1_NS - }) + self.assertEqual(self.msg.toPostArgs(), + {'openid.mode': 'error', 'openid.error': 'unit test', 'openid.ns': message.OPENID1_NS}) def test_toArgs(self): - self.failUnlessEqual(self.msg.toArgs(), {'mode': 'error', - 'error': 'unit test', - 'ns': message.OPENID1_NS}) + self.assertEqual(self.msg.toArgs(), {'mode': 'error', 'error': 'unit test', 'ns': message.OPENID1_NS}) def test_toKVForm(self): - self.failUnlessEqual(self.msg.toKVForm(), - 'error:unit test\nmode:error\nns:%s\n' % message.OPENID1_NS) + self.assertEqual(self.msg.toKVForm(), 'error:unit test\nmode:error\nns:%s\n' % message.OPENID1_NS) def test_toURLEncoded(self): - self.failUnlessEqual( - self.msg.toURLEncoded(), - 'openid.error=unit+test&openid.mode=error&openid.ns=https%3A%2F%2Fround-lake.dustinice.workers.dev%3A443%2Fhttp%2Fopenid.net%2Fsignon%2F1.0') + self.assertEqual(self.msg.toURLEncoded(), + 'openid.error=unit+test&openid.mode=error&openid.ns=https%3A%2F%2Fround-lake.dustinice.workers.dev%3A443%2Fhttp%2Fopenid.net%2Fsignon%2F1.0') def test_toURL(self): base_url = 'http://base.url/' actual = self.msg.toURL(base_url) actual_base = actual[:len(base_url)] - self.failUnlessEqual(actual_base, base_url) - self.failUnlessEqual(actual[len(base_url)], '?') + self.assertEqual(actual_base, base_url) + self.assertEqual(actual[len(base_url)], '?') query = actual[len(base_url) + 1:] parsed = cgi.parse_qs(query) - self.failUnlessEqual( - parsed, - {'openid.mode': ['error'], 'openid.error': ['unit test'], 'openid.ns': [message.OPENID1_NS]}) + self.assertEqual(parsed, + {'openid.mode': ['error'], 'openid.error': ['unit test'], 'openid.ns': [message.OPENID1_NS]}) def test_isOpenID1(self): self.failUnless(self.msg.isOpenID1()) @@ -452,12 +426,9 @@ def setUp(self): self.msg.setArg(message.BARE_NS, "xey", "value") def test_toPostArgs(self): - self.failUnlessEqual(self.msg.toPostArgs(), - {'openid.mode': 'error', - 'openid.error': 'unit test', - 'openid.ns': message.OPENID2_NS, - 'xey': 'value', - }) + self.assertEqual( + self.msg.toPostArgs(), + {'openid.mode': 'error', 'openid.error': 'unit test', 'openid.ns': message.OPENID2_NS, 'xey': 'value'}) def test_toPostArgs_bug_with_utf8_encoded_values(self): msg = message.Message.fromPostArgs({'openid.mode': 'error', @@ -465,32 +436,24 @@ def test_toPostArgs_bug_with_utf8_encoded_values(self): 'openid.ns': message.OPENID2_NS }) msg.setArg(message.BARE_NS, 'ünicöde_key', 'ünicöde_välüe') - self.failUnlessEqual(msg.toPostArgs(), - {'openid.mode': 'error', - 'openid.error': 'unit test', - 'openid.ns': message.OPENID2_NS, - 'ünicöde_key': 'ünicöde_välüe', - }) + post_args = {'openid.mode': 'error', 'openid.error': 'unit test', 'openid.ns': message.OPENID2_NS, + 'ünicöde_key': 'ünicöde_välüe'} + self.assertEqual(msg.toPostArgs(), post_args) def test_toArgs(self): # This method can't tolerate BARE_NS. self.msg.delArg(message.BARE_NS, "xey") - self.failUnlessEqual(self.msg.toArgs(), {'mode': 'error', - 'error': 'unit test', - 'ns': message.OPENID2_NS, - }) + self.assertEqual(self.msg.toArgs(), {'mode': 'error', 'error': 'unit test', 'ns': message.OPENID2_NS}) def test_toKVForm(self): # Can't tolerate BARE_NS in kvform self.msg.delArg(message.BARE_NS, "xey") - self.failUnlessEqual(self.msg.toKVForm(), - 'error:unit test\nmode:error\nns:%s\n' % - (message.OPENID2_NS,)) + self.assertEqual(self.msg.toKVForm(), 'error:unit test\nmode:error\nns:%s\n' % message.OPENID2_NS) def _test_urlencoded(self, s): expected = ('openid.error=unit+test&openid.mode=error&openid.ns=%s&xey=value' % urllib.quote(message.OPENID2_NS, '')) - self.failUnlessEqual(s, expected) + self.assertEqual(s, expected) def test_toURLEncoded(self): self._test_urlencoded(self.msg.toURLEncoded()) @@ -499,50 +462,43 @@ def test_toURL(self): base_url = 'http://base.url/' actual = self.msg.toURL(base_url) actual_base = actual[:len(base_url)] - self.failUnlessEqual(actual_base, base_url) - self.failUnlessEqual(actual[len(base_url)], '?') + self.assertEqual(actual_base, base_url) + self.assertEqual(actual[len(base_url)], '?') query = actual[len(base_url) + 1:] self._test_urlencoded(query) def test_getOpenID(self): - self.failUnlessEqual(self.msg.getOpenIDNamespace(), message.OPENID2_NS) + self.assertEqual(self.msg.getOpenIDNamespace(), message.OPENID2_NS) def test_getKeyOpenID(self): - self.failUnlessEqual(self.msg.getKey(message.OPENID_NS, 'mode'), - 'openid.mode') + self.assertEqual(self.msg.getKey(message.OPENID_NS, 'mode'), 'openid.mode') def test_getKeyBARE(self): - self.failUnlessEqual(self.msg.getKey(message.BARE_NS, 'mode'), 'mode') + self.assertEqual(self.msg.getKey(message.BARE_NS, 'mode'), 'mode') def test_getKeyNS1(self): - self.failUnlessEqual( - self.msg.getKey(message.OPENID1_NS, 'mode'), None) + self.assertIsNone(self.msg.getKey(message.OPENID1_NS, 'mode')) def test_getKeyNS2(self): - self.failUnlessEqual( - self.msg.getKey(message.OPENID2_NS, 'mode'), 'openid.mode') + self.assertEqual(self.msg.getKey(message.OPENID2_NS, 'mode'), 'openid.mode') def test_getKeyNS3(self): - self.failUnlessEqual( - self.msg.getKey('urn:nothing-significant', 'mode'), None) + self.assertIsNone(self.msg.getKey('urn:nothing-significant', 'mode')) def test_hasKeyOpenID(self): - self.failUnlessEqual(self.msg.hasKey(message.OPENID_NS, 'mode'), True) + self.assertTrue(self.msg.hasKey(message.OPENID_NS, 'mode')) def test_hasKeyBARE(self): - self.failUnlessEqual(self.msg.hasKey(message.BARE_NS, 'mode'), False) + self.assertFalse(self.msg.hasKey(message.BARE_NS, 'mode')) def test_hasKeyNS1(self): - self.failUnlessEqual( - self.msg.hasKey(message.OPENID1_NS, 'mode'), False) + self.assertFalse(self.msg.hasKey(message.OPENID1_NS, 'mode')) def test_hasKeyNS2(self): - self.failUnlessEqual( - self.msg.hasKey(message.OPENID2_NS, 'mode'), True) + self.assertTrue(self.msg.hasKey(message.OPENID2_NS, 'mode')) def test_hasKeyNS3(self): - self.failUnlessEqual( - self.msg.hasKey('urn:nothing-significant', 'mode'), False) + self.assertFalse(self.msg.hasKey('urn:nothing-significant', 'mode')) test_getArgBARE = mkGetArgTest(message.BARE_NS, 'mode') test_getArgNS = mkGetArgTest(message.OPENID_NS, 'mode', 'error') @@ -551,22 +507,19 @@ def test_hasKeyNS3(self): test_getArgNS3 = mkGetArgTest('urn:nothing-significant', 'mode') def test_getArgsOpenID(self): - self.failUnlessEqual(self.msg.getArgs(message.OPENID_NS), - {'mode': 'error', 'error': 'unit test'}) + self.assertEqual(self.msg.getArgs(message.OPENID_NS), {'mode': 'error', 'error': 'unit test'}) def test_getArgsBARE(self): - self.failUnlessEqual(self.msg.getArgs(message.BARE_NS), - {'xey': 'value'}) + self.assertEqual(self.msg.getArgs(message.BARE_NS), {'xey': 'value'}) def test_getArgsNS1(self): - self.failUnlessEqual(self.msg.getArgs(message.OPENID1_NS), {}) + self.assertEqual(self.msg.getArgs(message.OPENID1_NS), {}) def test_getArgsNS2(self): - self.failUnlessEqual(self.msg.getArgs(message.OPENID2_NS), - {'mode': 'error', 'error': 'unit test'}) + self.assertEqual(self.msg.getArgs(message.OPENID2_NS), {'mode': 'error', 'error': 'unit test'}) def test_getArgsNS3(self): - self.failUnlessEqual(self.msg.getArgs('urn:nothing-significant'), {}) + self.assertEqual(self.msg.getArgs('urn:nothing-significant'), {}) def _test_updateArgsNS(self, ns, before=None): if before is None: @@ -576,11 +529,11 @@ def _test_updateArgsNS(self, ns, before=None): 'Magnolia Electric Co.': 'Jason Molina', } - self.failUnlessEqual(self.msg.getArgs(ns), before) + self.assertEqual(self.msg.getArgs(ns), before) self.msg.updateArgs(ns, update_args) after = dict(before) after.update(update_args) - self.failUnlessEqual(self.msg.getArgs(ns), after) + self.assertEqual(self.msg.getArgs(ns), after) def test_updateArgsOpenID(self): self._test_updateArgsNS(message.OPENID_NS, @@ -603,9 +556,9 @@ def test_updateArgsNS3(self): def _test_setArgNS(self, ns): key = 'Camper van Beethoven' value = 'David Lowery' - self.failUnlessEqual(self.msg.getArg(ns, key), None) + self.assertIsNone(self.msg.getArg(ns, key)) self.msg.setArg(ns, key, value) - self.failUnlessEqual(self.msg.getArg(ns, key), value) + self.assertEqual(self.msg.getArg(ns, key), value) def test_setArgOpenID(self): self._test_setArgNS(message.OPENID_NS) @@ -661,8 +614,8 @@ def test_mysterious_missing_namespace_bug(self): for k in openid_args['signed'].split(','): if not ("openid." + k) in m.toPostArgs().keys(): missing.append(k) - self.assertEqual([], missing, missing) - self.assertEqual(openid_args, m.toArgs()) + self.assertEqual(missing, []) + self.assertEqual(m.toArgs(), openid_args) self.failUnless(m.isOpenID1()) def test_112B(self): @@ -687,8 +640,8 @@ def test_112B(self): for k in args['openid.signed'].split(','): if not ("openid." + k) in m.toPostArgs().keys(): missing.append(k) - self.assertEqual([], missing, missing) - self.assertEqual(args, m.toPostArgs()) + self.assertEqual(missing, [], missing) + self.assertEqual(m.toPostArgs(), args) self.failUnless(m.isOpenID2()) def test_repetitive_namespaces(self): @@ -719,8 +672,8 @@ def test_implicit_sreg_ns(self): m = message.Message.fromOpenIDArgs(openid_args) self.failUnless((sreg.ns_uri, 'sreg') in list(m.namespaces.iteritems())) - self.assertEqual('a@b.com', m.getArg(sreg.ns_uri, 'email')) - self.assertEqual(openid_args, m.toArgs()) + self.assertEqual(m.getArg(sreg.ns_uri, 'email'), 'a@b.com') + self.assertEqual(m.toArgs(), openid_args) self.failUnless(m.isOpenID1()) def _test_delArgNS(self, ns): @@ -729,9 +682,9 @@ def _test_delArgNS(self, ns): self.failUnlessRaises(KeyError, self.msg.delArg, ns, key) self.msg.setArg(ns, key, value) - self.failUnlessEqual(self.msg.getArg(ns, key), value) + self.assertEqual(self.msg.getArg(ns, key), value) self.msg.delArg(ns, key) - self.failUnlessEqual(self.msg.getArg(ns, key), None) + self.assertIsNone(self.msg.getArg(ns, key)) def test_delArgOpenID(self): self._test_delArgNS(message.OPENID_NS) @@ -944,7 +897,7 @@ def test_isOpenID1(self): m = message.Message(ns) self.failUnless(m.isOpenID1(), "%r not recognized as OpenID 1" % (ns,)) - self.failUnlessEqual(ns, m.getOpenIDNamespace()) + self.assertEqual(m.getOpenIDNamespace(), ns) self.failUnless(m.namespaces.isImplicit(ns), m.namespaces.getNamespaceURI(message.NULL_NAMESPACE)) @@ -953,7 +906,7 @@ def test_isOpenID2(self): m = message.Message(ns) self.failUnless(m.isOpenID2()) self.failIf(m.namespaces.isImplicit(message.NULL_NAMESPACE)) - self.failUnlessEqual(ns, m.getOpenIDNamespace()) + self.assertEqual(m.getOpenIDNamespace(), ns) def test_setOpenIDNamespace_explicit(self): m = message.Message() @@ -970,8 +923,7 @@ def test_explicitOpenID11NSSerialzation(self): m.setOpenIDNamespace(message.THE_OTHER_OPENID1_NS, implicit=False) post_args = m.toPostArgs() - self.failUnlessEqual(post_args, - {'openid.ns': message.THE_OTHER_OPENID1_NS}) + self.assertEqual(post_args, {'openid.ns': message.THE_OTHER_OPENID1_NS}) def test_fromPostArgs_ns11(self): # An example of the stuff that some Drupal installations send us, diff --git a/openid/test/test_negotiation.py b/openid/test/test_negotiation.py index 8936ecd8..6be4528a 100644 --- a/openid/test/test_negotiation.py +++ b/openid/test/test_negotiation.py @@ -144,7 +144,7 @@ def testUnsupportedWithRetryAndFail(self): self.consumer.return_messages = [msg, Message(self.endpoint.preferredNamespace())] - self.failUnlessEqual(self.consumer._negotiateAssociation(self.endpoint), None) + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) self.failUnlessLogMatches('Unsupported association type', 'Server %s refused' % (self.endpoint.server_url)) diff --git a/openid/test/test_nonce.py b/openid/test/test_nonce.py index fcc4687f..1817f821 100644 --- a/openid/test/test_nonce.py +++ b/openid/test/test_nonce.py @@ -23,16 +23,16 @@ def test_splitNonce(self): expected_t = 0 expected_salt = '' actual_t, actual_salt = splitNonce(s) - self.failUnlessEqual(expected_t, actual_t) - self.failUnlessEqual(expected_salt, actual_salt) + self.assertEqual(actual_t, expected_t) + self.assertEqual(actual_salt, expected_salt) def test_mkSplit(self): t = 42 nonce_str = mkNonce(t) self.failUnless(nonce_re.match(nonce_str)) et, salt = splitNonce(nonce_str) - self.failUnlessEqual(len(salt), 6) - self.failUnlessEqual(et, t) + self.assertEqual(len(salt), 6) + self.assertEqual(et, t) class BadSplitTest(unittest.TestCase): @@ -81,4 +81,4 @@ class CheckTimestampTest(unittest.TestCase): def test(self): for nonce_string, allowed_skew, now, expected in self.cases: actual = checkTimestamp(nonce_string, allowed_skew, now) - self.failUnlessEqual(bool(expected), bool(actual)) + self.assertEqual(bool(actual), bool(expected)) diff --git a/openid/test/test_openidyadis.py b/openid/test/test_openidyadis.py index 4a76749f..91e7e4ac 100644 --- a/openid/test/test_openidyadis.py +++ b/openid/test/test_openidyadis.py @@ -128,7 +128,7 @@ def runTest(self): # make sure there are the same number of endpoints as # URIs. This assumes that the type_uris contains at least one # OpenID type. - self.failUnlessEqual(len(uris), len(endpoints)) + self.assertEqual(len(endpoints), len(uris)) # So that we can check equality on the endpoint types type_uris = sorted(type_uris) @@ -138,14 +138,14 @@ def runTest(self): seen_uris.append(endpoint.server_url) # All endpoints will have same yadis_url - self.failUnlessEqual(self.yadis_url, endpoint.claimed_id) + self.assertEqual(endpoint.claimed_id, self.yadis_url) # and local_id - self.failUnlessEqual(local_id, endpoint.local_id) + self.assertEqual(endpoint.local_id, local_id) # and types actual_types = sorted(endpoint.type_uris) - self.failUnlessEqual(actual_types, type_uris) + self.assertEqual(type_uris, actual_types) # So that they will compare equal, because we don't care what # order they are in @@ -153,4 +153,4 @@ def runTest(self): uris = sorted(uris) # Make sure we saw all URIs, and saw each one once - self.failUnlessEqual(uris, seen_uris) + self.assertEqual(seen_uris, uris) diff --git a/openid/test/test_pape_draft2.py b/openid/test/test_pape_draft2.py index be76550b..9f54f88f 100644 --- a/openid/test/test_pape_draft2.py +++ b/openid/test/test_pape_draft2.py @@ -10,48 +10,46 @@ def setUp(self): self.req = pape.Request() def test_construct(self): - self.failUnlessEqual([], self.req.preferred_auth_policies) - self.failUnlessEqual(None, self.req.max_auth_age) - self.failUnlessEqual('pape', self.req.ns_alias) + self.assertEqual(self.req.preferred_auth_policies, []) + self.assertIsNone(self.req.max_auth_age) + self.assertEqual(self.req.ns_alias, 'pape') req2 = pape.Request([pape.AUTH_MULTI_FACTOR], 1000) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR], req2.preferred_auth_policies) - self.failUnlessEqual(1000, req2.max_auth_age) + self.assertEqual(req2.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR]) + self.assertEqual(req2.max_auth_age, 1000) def test_add_policy_uri(self): - self.failUnlessEqual([], self.req.preferred_auth_policies) + self.assertEqual(self.req.preferred_auth_policies, []) self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR], self.req.preferred_auth_policies) + self.assertEqual(self.req.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR]) self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR], self.req.preferred_auth_policies) + self.assertEqual(self.req.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR]) self.req.addPolicyURI(pape.AUTH_PHISHING_RESISTANT) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT], - self.req.preferred_auth_policies) + self.assertEqual(self.req.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT], - self.req.preferred_auth_policies) + self.assertEqual(self.req.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) def test_getExtensionArgs(self): - self.failUnlessEqual({'preferred_auth_policies': ''}, self.req.getExtensionArgs()) + self.assertEqual(self.req.getExtensionArgs(), {'preferred_auth_policies': ''}) self.req.addPolicyURI('http://uri') - self.failUnlessEqual({'preferred_auth_policies': 'http://uri'}, self.req.getExtensionArgs()) + self.assertEqual(self.req.getExtensionArgs(), {'preferred_auth_policies': 'http://uri'}) self.req.addPolicyURI('http://zig') - self.failUnlessEqual({'preferred_auth_policies': 'http://uri http://zig'}, self.req.getExtensionArgs()) + self.assertEqual(self.req.getExtensionArgs(), {'preferred_auth_policies': 'http://uri http://zig'}) self.req.max_auth_age = 789 - self.failUnlessEqual({'preferred_auth_policies': 'http://uri http://zig', 'max_auth_age': '789'}, - self.req.getExtensionArgs()) + self.assertEqual(self.req.getExtensionArgs(), + {'preferred_auth_policies': 'http://uri http://zig', 'max_auth_age': '789'}) def test_parseExtensionArgs(self): args = {'preferred_auth_policies': 'http://foo http://bar', 'max_auth_age': '9'} self.req.parseExtensionArgs(args) - self.failUnlessEqual(9, self.req.max_auth_age) - self.failUnlessEqual(['http://foo', 'http://bar'], self.req.preferred_auth_policies) + self.assertEqual(self.req.max_auth_age, 9) + self.assertEqual(self.req.preferred_auth_policies, ['http://foo', 'http://bar']) def test_parseExtensionArgs_empty(self): self.req.parseExtensionArgs({}) - self.failUnlessEqual(None, self.req.max_auth_age) - self.failUnlessEqual([], self.req.preferred_auth_policies) + self.assertIsNone(self.req.max_auth_age) + self.assertEqual(self.req.preferred_auth_policies, []) def test_fromOpenIDRequest(self): openid_req_msg = Message.fromOpenIDArgs({ @@ -64,8 +62,8 @@ def test_fromOpenIDRequest(self): oid_req = server.OpenIDRequest() oid_req.message = openid_req_msg req = pape.Request.fromOpenIDRequest(oid_req) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT], req.preferred_auth_policies) - self.failUnlessEqual(5476, req.max_auth_age) + self.assertEqual(req.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) + self.assertEqual(req.max_auth_age, 5476) def test_fromOpenIDRequest_no_pape(self): message = Message() @@ -79,7 +77,7 @@ def test_preferred_types(self): self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) pt = self.req.preferredTypes([pape.AUTH_MULTI_FACTOR, pape.AUTH_MULTI_FACTOR_PHYSICAL]) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR], pt) + self.assertEqual(pt, [pape.AUTH_MULTI_FACTOR]) class DummySuccessResponse: @@ -96,41 +94,40 @@ def setUp(self): self.req = pape.Response() def test_construct(self): - self.failUnlessEqual([], self.req.auth_policies) - self.failUnlessEqual(None, self.req.auth_time) - self.failUnlessEqual('pape', self.req.ns_alias) - self.failUnlessEqual(None, self.req.nist_auth_level) + self.assertEqual(self.req.auth_policies, []) + self.assertIsNone(self.req.auth_time) + self.assertEqual(self.req.ns_alias, 'pape') + self.assertIsNone(self.req.nist_auth_level) req2 = pape.Response([pape.AUTH_MULTI_FACTOR], "2004-12-11T10:30:44Z", 3) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR], req2.auth_policies) - self.failUnlessEqual("2004-12-11T10:30:44Z", req2.auth_time) - self.failUnlessEqual(3, req2.nist_auth_level) + self.assertEqual(req2.auth_policies, [pape.AUTH_MULTI_FACTOR]) + self.assertEqual(req2.auth_time, "2004-12-11T10:30:44Z") + self.assertEqual(req2.nist_auth_level, 3) def test_add_policy_uri(self): - self.failUnlessEqual([], self.req.auth_policies) + self.assertEqual(self.req.auth_policies, []) self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR], self.req.auth_policies) + self.assertEqual(self.req.auth_policies, [pape.AUTH_MULTI_FACTOR]) self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR], self.req.auth_policies) + self.assertEqual(self.req.auth_policies, [pape.AUTH_MULTI_FACTOR]) self.req.addPolicyURI(pape.AUTH_PHISHING_RESISTANT) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT], self.req.auth_policies) + self.assertEqual(self.req.auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT], self.req.auth_policies) + self.assertEqual(self.req.auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) def test_getExtensionArgs(self): - self.failUnlessEqual({'auth_policies': 'none'}, self.req.getExtensionArgs()) + self.assertEqual(self.req.getExtensionArgs(), {'auth_policies': 'none'}) self.req.addPolicyURI('http://uri') - self.failUnlessEqual({'auth_policies': 'http://uri'}, self.req.getExtensionArgs()) + self.assertEqual(self.req.getExtensionArgs(), {'auth_policies': 'http://uri'}) self.req.addPolicyURI('http://zig') - self.failUnlessEqual({'auth_policies': 'http://uri http://zig'}, self.req.getExtensionArgs()) + self.assertEqual(self.req.getExtensionArgs(), {'auth_policies': 'http://uri http://zig'}) self.req.auth_time = "1776-07-04T14:43:12Z" - self.failUnlessEqual({'auth_policies': 'http://uri http://zig', 'auth_time': "1776-07-04T14:43:12Z"}, - self.req.getExtensionArgs()) + self.assertEqual(self.req.getExtensionArgs(), + {'auth_policies': 'http://uri http://zig', 'auth_time': "1776-07-04T14:43:12Z"}) self.req.nist_auth_level = 3 - self.failUnlessEqual({'auth_policies': 'http://uri http://zig', - 'auth_time': "1776-07-04T14:43:12Z", - 'nist_auth_level': '3'}, - self.req.getExtensionArgs()) + nist_data = {'auth_policies': 'http://uri http://zig', 'auth_time': "1776-07-04T14:43:12Z", + 'nist_auth_level': '3'} + self.assertEqual(self.req.getExtensionArgs(), nist_data) def test_getExtensionArgs_error_auth_age(self): self.req.auth_time = "long ago" @@ -148,13 +145,13 @@ def test_parseExtensionArgs(self): args = {'auth_policies': 'http://foo http://bar', 'auth_time': '1970-01-01T00:00:00Z'} self.req.parseExtensionArgs(args) - self.failUnlessEqual('1970-01-01T00:00:00Z', self.req.auth_time) - self.failUnlessEqual(['http://foo', 'http://bar'], self.req.auth_policies) + self.assertEqual(self.req.auth_time, '1970-01-01T00:00:00Z') + self.assertEqual(self.req.auth_policies, ['http://foo', 'http://bar']) def test_parseExtensionArgs_empty(self): self.req.parseExtensionArgs({}) - self.failUnlessEqual(None, self.req.auth_time) - self.failUnlessEqual([], self.req.auth_policies) + self.assertIsNone(self.req.auth_time) + self.assertEqual(self.req.auth_policies, []) def test_parseExtensionArgs_strict_bogus1(self): args = {'auth_policies': 'http://foo http://bar', @@ -174,18 +171,18 @@ def test_parseExtensionArgs_strict_good(self): 'auth_time': '1970-01-01T00:00:00Z', 'nist_auth_level': '0'} self.req.parseExtensionArgs(args, True) - self.failUnlessEqual(['http://foo', 'http://bar'], self.req.auth_policies) - self.failUnlessEqual('1970-01-01T00:00:00Z', self.req.auth_time) - self.failUnlessEqual(0, self.req.nist_auth_level) + self.assertEqual(self.req.auth_policies, ['http://foo', 'http://bar']) + self.assertEqual(self.req.auth_time, '1970-01-01T00:00:00Z') + self.assertEqual(self.req.nist_auth_level, 0) def test_parseExtensionArgs_nostrict_bogus(self): args = {'auth_policies': 'http://foo http://bar', 'auth_time': 'when the cows come home', 'nist_auth_level': 'some'} self.req.parseExtensionArgs(args) - self.failUnlessEqual(['http://foo', 'http://bar'], self.req.auth_policies) - self.failUnlessEqual(None, self.req.auth_time) - self.failUnlessEqual(None, self.req.nist_auth_level) + self.assertEqual(self.req.auth_policies, ['http://foo', 'http://bar']) + self.assertIsNone(self.req.auth_time) + self.assertIsNone(self.req.nist_auth_level) def test_fromSuccessResponse(self): openid_req_msg = Message.fromOpenIDArgs({ @@ -201,8 +198,8 @@ def test_fromSuccessResponse(self): } oid_req = DummySuccessResponse(openid_req_msg, signed_stuff) req = pape.Response.fromSuccessResponse(oid_req) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT], req.auth_policies) - self.failUnlessEqual('1970-01-01T00:00:00Z', req.auth_time) + self.assertEqual(req.auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) + self.assertEqual(req.auth_time, '1970-01-01T00:00:00Z') def test_fromSuccessResponseNoSignedArgs(self): openid_req_msg = Message.fromOpenIDArgs({ diff --git a/openid/test/test_pape_draft5.py b/openid/test/test_pape_draft5.py index 243eae5d..ee2be1ae 100644 --- a/openid/test/test_pape_draft5.py +++ b/openid/test/test_pape_draft5.py @@ -14,37 +14,31 @@ def setUp(self): self.req = pape.Request() def test_construct(self): - self.failUnlessEqual([], self.req.preferred_auth_policies) - self.failUnlessEqual(None, self.req.max_auth_age) - self.failUnlessEqual('pape', self.req.ns_alias) + self.assertEqual(self.req.preferred_auth_policies, []) + self.assertIsNone(self.req.max_auth_age) + self.assertEqual(self.req.ns_alias, 'pape') self.failIf(self.req.preferred_auth_level_types) bogus_levels = ['http://janrain.com/our_levels'] req2 = pape.Request( [pape.AUTH_MULTI_FACTOR], 1000, bogus_levels) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR], - req2.preferred_auth_policies) - self.failUnlessEqual(1000, req2.max_auth_age) - self.failUnlessEqual(bogus_levels, req2.preferred_auth_level_types) + self.assertEqual(req2.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR]) + self.assertEqual(req2.max_auth_age, 1000) + self.assertEqual(req2.preferred_auth_level_types, bogus_levels) def test_addAuthLevel(self): self.req.addAuthLevel('http://example.com/', 'example') - self.failUnlessEqual(['http://example.com/'], - self.req.preferred_auth_level_types) - self.failUnlessEqual('http://example.com/', - self.req.auth_level_aliases['example']) + self.assertEqual(self.req.preferred_auth_level_types, ['http://example.com/']) + self.assertEqual(self.req.auth_level_aliases['example'], 'http://example.com/') self.req.addAuthLevel('http://example.com/1', 'example1') - self.failUnlessEqual(['http://example.com/', 'http://example.com/1'], - self.req.preferred_auth_level_types) + self.assertEqual(self.req.preferred_auth_level_types, ['http://example.com/', 'http://example.com/1']) self.req.addAuthLevel('http://example.com/', 'exmpl') - self.failUnlessEqual(['http://example.com/', 'http://example.com/1'], - self.req.preferred_auth_level_types) + self.assertEqual(self.req.preferred_auth_level_types, ['http://example.com/', 'http://example.com/1']) self.req.addAuthLevel('http://example.com/', 'example') - self.failUnlessEqual(['http://example.com/', 'http://example.com/1'], - self.req.preferred_auth_level_types) + self.assertEqual(self.req.preferred_auth_level_types, ['http://example.com/', 'http://example.com/1']) self.failUnlessRaises(KeyError, self.req.addAuthLevel, @@ -60,41 +54,28 @@ def test_addAuthLevel(self): before_aliases = self.req.auth_level_aliases.keys() self.req.addAuthLevel(uri) after_aliases = self.req.auth_level_aliases.keys() - self.assertEqual(before_aliases, after_aliases) + self.assertEqual(after_aliases, before_aliases) def test_add_policy_uri(self): - self.failUnlessEqual([], self.req.preferred_auth_policies) + self.assertEqual(self.req.preferred_auth_policies, []) self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR], - self.req.preferred_auth_policies) + self.assertEqual(self.req.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR]) self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR], - self.req.preferred_auth_policies) + self.assertEqual(self.req.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR]) self.req.addPolicyURI(pape.AUTH_PHISHING_RESISTANT) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR, - pape.AUTH_PHISHING_RESISTANT], - self.req.preferred_auth_policies) + self.assertEqual(self.req.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR, - pape.AUTH_PHISHING_RESISTANT], - self.req.preferred_auth_policies) + self.assertEqual(self.req.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) def test_getExtensionArgs(self): - self.failUnlessEqual({'preferred_auth_policies': ''}, - self.req.getExtensionArgs()) + self.assertEqual(self.req.getExtensionArgs(), {'preferred_auth_policies': ''}) self.req.addPolicyURI('http://uri') - self.failUnlessEqual( - {'preferred_auth_policies': 'http://uri'}, - self.req.getExtensionArgs()) + self.assertEqual(self.req.getExtensionArgs(), {'preferred_auth_policies': 'http://uri'}) self.req.addPolicyURI('http://zig') - self.failUnlessEqual( - {'preferred_auth_policies': 'http://uri http://zig'}, - self.req.getExtensionArgs()) + self.assertEqual(self.req.getExtensionArgs(), {'preferred_auth_policies': 'http://uri http://zig'}) self.req.max_auth_age = 789 - self.failUnlessEqual( - {'preferred_auth_policies': 'http://uri http://zig', - 'max_auth_age': '789'}, - self.req.getExtensionArgs()) + self.assertEqual(self.req.getExtensionArgs(), + {'preferred_auth_policies': 'http://uri http://zig', 'max_auth_age': '789'}) def test_getExtensionArgsWithAuthLevels(self): uri = 'http://example.com/auth_level' @@ -112,7 +93,7 @@ def test_getExtensionArgsWithAuthLevels(self): 'preferred_auth_policies': '', } - self.failUnlessEqual(expected_args, self.req.getExtensionArgs()) + self.assertEqual(self.req.getExtensionArgs(), expected_args) def test_parseExtensionArgsWithAuthLevels(self): uri = 'http://example.com/auth_level' @@ -133,10 +114,9 @@ def test_parseExtensionArgsWithAuthLevels(self): expected_auth_levels = [uri, uri2] - self.assertEqual(expected_auth_levels, - self.req.preferred_auth_level_types) - self.assertEqual(uri, self.req.auth_level_aliases[alias]) - self.assertEqual(uri2, self.req.auth_level_aliases[alias2]) + self.assertEqual(self.req.preferred_auth_level_types, expected_auth_levels) + self.assertEqual(self.req.auth_level_aliases[alias], uri) + self.assertEqual(self.req.auth_level_aliases[alias2], uri2) def test_parseExtensionArgsWithAuthLevels_openID1(self): request_args = { @@ -144,13 +124,11 @@ def test_parseExtensionArgsWithAuthLevels_openID1(self): } expected_auth_levels = [pape.LEVELS_NIST, pape.LEVELS_JISA] self.req.parseExtensionArgs(request_args, is_openid1=True) - self.assertEqual(expected_auth_levels, - self.req.preferred_auth_level_types) + self.assertEqual(self.req.preferred_auth_level_types, expected_auth_levels) self.req = pape.Request() self.req.parseExtensionArgs(request_args, is_openid1=False) - self.assertEqual([], - self.req.preferred_auth_level_types) + self.assertEqual(self.req.preferred_auth_level_types, []) self.req = pape.Request() self.failUnlessRaises(ValueError, @@ -160,7 +138,7 @@ def test_parseExtensionArgsWithAuthLevels_openID1(self): def test_parseExtensionArgs_ignoreBadAuthLevels(self): request_args = {'preferred_auth_level_types': 'monkeys'} self.req.parseExtensionArgs(request_args, False) - self.assertEqual([], self.req.preferred_auth_level_types) + self.assertEqual(self.req.preferred_auth_level_types, []) def test_parseExtensionArgs_strictBadAuthLevels(self): request_args = {'preferred_auth_level_types': 'monkeys'} @@ -171,10 +149,9 @@ def test_parseExtensionArgs(self): args = {'preferred_auth_policies': 'http://foo http://bar', 'max_auth_age': '9'} self.req.parseExtensionArgs(args, False) - self.failUnlessEqual(9, self.req.max_auth_age) - self.failUnlessEqual(['http://foo', 'http://bar'], - self.req.preferred_auth_policies) - self.failUnlessEqual([], self.req.preferred_auth_level_types) + self.assertEqual(self.req.max_auth_age, 9) + self.assertEqual(self.req.preferred_auth_policies, ['http://foo', 'http://bar']) + self.assertEqual(self.req.preferred_auth_level_types, []) def test_parseExtensionArgs_strict_bad_auth_age(self): args = {'max_auth_age': 'not an int'} @@ -183,9 +160,9 @@ def test_parseExtensionArgs_strict_bad_auth_age(self): def test_parseExtensionArgs_empty(self): self.req.parseExtensionArgs({}, False) - self.failUnlessEqual(None, self.req.max_auth_age) - self.failUnlessEqual([], self.req.preferred_auth_policies) - self.failUnlessEqual([], self.req.preferred_auth_level_types) + self.assertIsNone(self.req.max_auth_age) + self.assertEqual(self.req.preferred_auth_policies, []) + self.assertEqual(self.req.preferred_auth_level_types, []) def test_fromOpenIDRequest(self): policy_uris = [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT] @@ -199,8 +176,8 @@ def test_fromOpenIDRequest(self): oid_req = server.OpenIDRequest() oid_req.message = openid_req_msg req = pape.Request.fromOpenIDRequest(oid_req) - self.failUnlessEqual(policy_uris, req.preferred_auth_policies) - self.failUnlessEqual(5476, req.max_auth_age) + self.assertEqual(req.preferred_auth_policies, policy_uris) + self.assertEqual(req.max_auth_age, 5476) def test_fromOpenIDRequest_no_pape(self): message = Message() @@ -214,7 +191,7 @@ def test_preferred_types(self): self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) pt = self.req.preferredTypes([pape.AUTH_MULTI_FACTOR, pape.AUTH_MULTI_FACTOR_PHYSICAL]) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR], pt) + self.assertEqual(pt, [pape.AUTH_MULTI_FACTOR]) class DummySuccessResponse: @@ -234,56 +211,44 @@ def setUp(self): self.resp = pape.Response() def test_construct(self): - self.failUnlessEqual([], self.resp.auth_policies) - self.failUnlessEqual(None, self.resp.auth_time) - self.failUnlessEqual('pape', self.resp.ns_alias) - self.failUnlessEqual(None, self.resp.nist_auth_level) + self.assertEqual(self.resp.auth_policies, []) + self.assertIsNone(self.resp.auth_time) + self.assertEqual(self.resp.ns_alias, 'pape') + self.assertIsNone(self.resp.nist_auth_level) req2 = pape.Response([pape.AUTH_MULTI_FACTOR], "2004-12-11T10:30:44Z", {pape.LEVELS_NIST: 3}) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR], req2.auth_policies) - self.failUnlessEqual("2004-12-11T10:30:44Z", req2.auth_time) - self.failUnlessEqual(3, req2.nist_auth_level) + self.assertEqual(req2.auth_policies, [pape.AUTH_MULTI_FACTOR]) + self.assertEqual(req2.auth_time, "2004-12-11T10:30:44Z") + self.assertEqual(req2.nist_auth_level, 3) def test_add_policy_uri(self): - self.failUnlessEqual([], self.resp.auth_policies) + self.assertEqual(self.resp.auth_policies, []) self.resp.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR], self.resp.auth_policies) + self.assertEqual(self.resp.auth_policies, [pape.AUTH_MULTI_FACTOR]) self.resp.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR], self.resp.auth_policies) + self.assertEqual(self.resp.auth_policies, [pape.AUTH_MULTI_FACTOR]) self.resp.addPolicyURI(pape.AUTH_PHISHING_RESISTANT) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR, - pape.AUTH_PHISHING_RESISTANT], - self.resp.auth_policies) + self.assertEqual(self.resp.auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) self.resp.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.failUnlessEqual([pape.AUTH_MULTI_FACTOR, - pape.AUTH_PHISHING_RESISTANT], - self.resp.auth_policies) + self.assertEqual(self.resp.auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) self.failUnlessRaises(RuntimeError, self.resp.addPolicyURI, pape.AUTH_NONE) def test_getExtensionArgs(self): - self.failUnlessEqual({'auth_policies': pape.AUTH_NONE}, - self.resp.getExtensionArgs()) + self.assertEqual(self.resp.getExtensionArgs(), {'auth_policies': pape.AUTH_NONE}) self.resp.addPolicyURI('http://uri') - self.failUnlessEqual({'auth_policies': 'http://uri'}, - self.resp.getExtensionArgs()) + self.assertEqual(self.resp.getExtensionArgs(), {'auth_policies': 'http://uri'}) self.resp.addPolicyURI('http://zig') - self.failUnlessEqual({'auth_policies': 'http://uri http://zig'}, - self.resp.getExtensionArgs()) + self.assertEqual(self.resp.getExtensionArgs(), {'auth_policies': 'http://uri http://zig'}) self.resp.auth_time = "1776-07-04T14:43:12Z" - self.failUnlessEqual( - {'auth_policies': 'http://uri http://zig', - 'auth_time': "1776-07-04T14:43:12Z"}, - self.resp.getExtensionArgs()) + self.assertEqual(self.resp.getExtensionArgs(), + {'auth_policies': 'http://uri http://zig', 'auth_time': "1776-07-04T14:43:12Z"}) self.resp.setAuthLevel(pape.LEVELS_NIST, '3') - self.failUnlessEqual( - {'auth_policies': 'http://uri http://zig', - 'auth_time': "1776-07-04T14:43:12Z", - 'auth_level.nist': '3', - 'auth_level.ns.nist': pape.LEVELS_NIST}, - self.resp.getExtensionArgs()) + nist_args = {'auth_policies': 'http://uri http://zig', 'auth_time': "1776-07-04T14:43:12Z", + 'auth_level.nist': '3', 'auth_level.ns.nist': pape.LEVELS_NIST} + self.assertEqual(self.resp.getExtensionArgs(), nist_args) def test_getExtensionArgs_error_auth_age(self): self.resp.auth_time = "long ago" @@ -293,19 +258,18 @@ def test_parseExtensionArgs(self): args = {'auth_policies': 'http://foo http://bar', 'auth_time': '1970-01-01T00:00:00Z'} self.resp.parseExtensionArgs(args, is_openid1=False) - self.failUnlessEqual('1970-01-01T00:00:00Z', self.resp.auth_time) - self.failUnlessEqual(['http://foo', 'http://bar'], - self.resp.auth_policies) + self.assertEqual(self.resp.auth_time, '1970-01-01T00:00:00Z') + self.assertEqual(self.resp.auth_policies, ['http://foo', 'http://bar']) def test_parseExtensionArgs_valid_none(self): args = {'auth_policies': pape.AUTH_NONE} self.resp.parseExtensionArgs(args, is_openid1=False) - self.failUnlessEqual([], self.resp.auth_policies) + self.assertEqual(self.resp.auth_policies, []) def test_parseExtensionArgs_old_none(self): args = {'auth_policies': 'none'} self.resp.parseExtensionArgs(args, is_openid1=False) - self.failUnlessEqual([], self.resp.auth_policies) + self.assertEqual(self.resp.auth_policies, []) def test_parseExtensionArgs_old_none_strict(self): args = {'auth_policies': 'none'} @@ -315,8 +279,8 @@ def test_parseExtensionArgs_old_none_strict(self): def test_parseExtensionArgs_empty(self): self.resp.parseExtensionArgs({}, is_openid1=False) - self.failUnlessEqual(None, self.resp.auth_time) - self.failUnlessEqual([], self.resp.auth_policies) + self.assertIsNone(self.resp.auth_time) + self.assertEqual(self.resp.auth_policies, []) def test_parseExtensionArgs_empty_strict(self): self.failUnlessRaises( @@ -332,8 +296,7 @@ def test_parseExtensionArgs_ignore_superfluous_none(self): self.resp.parseExtensionArgs(args, is_openid1=False, strict=False) - self.assertEqual([pape.AUTH_MULTI_FACTOR_PHYSICAL], - self.resp.auth_policies) + self.assertEqual(self.resp.auth_policies, [pape.AUTH_MULTI_FACTOR_PHYSICAL]) def test_parseExtensionArgs_none_strict(self): policies = [pape.AUTH_NONE, pape.AUTH_MULTI_FACTOR_PHYSICAL] @@ -356,8 +319,8 @@ def test_parseExtensionArgs_openid1_strict(self): 'auth_policies': pape.AUTH_NONE, } self.resp.parseExtensionArgs(args, strict=True, is_openid1=True) - self.failUnlessEqual('0', self.resp.getAuthLevel(pape.LEVELS_NIST)) - self.failUnlessEqual([], self.resp.auth_policies) + self.assertEqual(self.resp.getAuthLevel(pape.LEVELS_NIST), '0') + self.assertEqual(self.resp.auth_policies, []) def test_parseExtensionArgs_strict_no_namespace_decl_openid2(self): # Test the case where the namespace is not declared for an @@ -386,20 +349,18 @@ def test_parseExtensionArgs_strict_good(self): 'auth_level.nist': '0', 'auth_level.ns.nist': pape.LEVELS_NIST} self.resp.parseExtensionArgs(args, is_openid1=False, strict=True) - self.failUnlessEqual(['http://foo', 'http://bar'], - self.resp.auth_policies) - self.failUnlessEqual('1970-01-01T00:00:00Z', self.resp.auth_time) - self.failUnlessEqual(0, self.resp.nist_auth_level) + self.assertEqual(self.resp.auth_policies, ['http://foo', 'http://bar']) + self.assertEqual(self.resp.auth_time, '1970-01-01T00:00:00Z') + self.assertEqual(self.resp.nist_auth_level, 0) def test_parseExtensionArgs_nostrict_bogus(self): args = {'auth_policies': 'http://foo http://bar', 'auth_time': 'when the cows come home', 'nist_auth_level': 'some'} self.resp.parseExtensionArgs(args, is_openid1=False) - self.failUnlessEqual(['http://foo', 'http://bar'], - self.resp.auth_policies) - self.failUnlessEqual(None, self.resp.auth_time) - self.failUnlessEqual(None, self.resp.nist_auth_level) + self.assertEqual(self.resp.auth_policies, ['http://foo', 'http://bar']) + self.assertIsNone(self.resp.auth_time) + self.assertIsNone(self.resp.nist_auth_level) def test_fromSuccessResponse(self): policy_uris = [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT] @@ -416,8 +377,8 @@ def test_fromSuccessResponse(self): } oid_req = DummySuccessResponse(openid_req_msg, signed_stuff) req = pape.Response.fromSuccessResponse(oid_req) - self.failUnlessEqual(policy_uris, req.auth_policies) - self.failUnlessEqual('1970-01-01T00:00:00Z', req.auth_time) + self.assertEqual(req.auth_policies, policy_uris) + self.assertEqual(req.auth_time, '1970-01-01T00:00:00Z') def test_fromSuccessResponseNoSignedArgs(self): policy_uris = [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT] diff --git a/openid/test/test_parsehtml.py b/openid/test/test_parsehtml.py index bd5a6267..10a13163 100644 --- a/openid/test/test_parsehtml.py +++ b/openid/test/test_parsehtml.py @@ -25,7 +25,7 @@ def test(self): found = 'None' msg = "%r != %r for case %s" % (found, expected, case) - self.failUnlessEqual(found, expected, msg) + self.assertEqual(found, expected, msg) except HTMLParseError: self.failUnless(expected == 'None', (case, expected)) else: diff --git a/openid/test/test_rpverify.py b/openid/test/test_rpverify.py index c1069818..82432f32 100644 --- a/openid/test/test_rpverify.py +++ b/openid/test/test_rpverify.py @@ -22,7 +22,7 @@ def failUnlessDiscoURL(self, realm, expected_discovery_url): """ realm_obj = trustroot.TrustRoot.parse(realm) actual_discovery_url = realm_obj.buildDiscoveryURL() - self.failUnlessEqual(expected_discovery_url, actual_discovery_url) + self.assertEqual(actual_discovery_url, expected_discovery_url) def test_trivial(self): """There is no wildcard and the realm is the same as the return_to URL @@ -69,7 +69,7 @@ def failUnlessXRDSHasReturnURLs(self, data, expected_return_urls): actual_return_urls = list(trustroot.getAllowedReturnURLs( self.disco_url)) - self.failUnlessEqual(expected_return_urls, actual_return_urls) + self.assertEqual(actual_return_urls, expected_return_urls) def failUnlessDiscoveryFailure(self, text): self.data = text @@ -216,7 +216,7 @@ def test_verifyWithDiscoveryCalled(self): return_to = 'http://www.example.com/foo' def vrfy(disco_url): - self.failUnlessEqual('http://www.example.com/', disco_url) + self.assertEqual(disco_url, 'http://www.example.com/') return [return_to] self.failUnless( @@ -228,7 +228,7 @@ def test_verifyFailWithDiscoveryCalled(self): return_to = 'http://www.example.com/foo' def vrfy(disco_url): - self.failUnlessEqual('http://www.example.com/', disco_url) + self.assertEqual(disco_url, 'http://www.example.com/') return ['http://something-else.invalid/'] self.failIf( diff --git a/openid/test/test_server.py b/openid/test/test_server.py index 171be4c2..6433f0d2 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -43,7 +43,7 @@ def test_browserWithReturnTo(self): rt_base, result_args = e.encodeToURL().split('?', 1) result_args = cgi.parse_qs(result_args) - self.failUnlessEqual(result_args, expected_args) + self.assertEqual(result_args, expected_args) def test_browserWithReturnTo_OpenID2_GET(self): return_to = "http://rp.unittest/consumer" @@ -65,7 +65,7 @@ def test_browserWithReturnTo_OpenID2_GET(self): rt_base, result_args = e.encodeToURL().split('?', 1) result_args = cgi.parse_qs(result_args) - self.failUnlessEqual(result_args, expected_args) + self.assertEqual(result_args, expected_args) def test_browserWithReturnTo_OpenID2_POST(self): return_to = "http://rp.unittest/consumer" + ('x' * OPENID1_URL_LIMIT) @@ -102,7 +102,7 @@ def test_browserWithReturnTo_OpenID1_exceeds_limit(self): rt_base, result_args = e.encodeToURL().split('?', 1) result_args = cgi.parse_qs(result_args) - self.failUnlessEqual(result_args, expected_args) + self.assertEqual(result_args, expected_args) def test_noReturnTo(self): # will be a ProtocolError raised by Decode or CheckIDRequest.answer @@ -115,12 +115,12 @@ def test_noReturnTo(self): expected = """error:waffles mode:error """ - self.failUnlessEqual(e.encodeToKVForm(), expected) + self.assertEqual(e.encodeToKVForm(), expected) def test_noMessage(self): e = server.ProtocolError(None, "no moar pancakes") self.failIf(e.hasReturnTo()) - self.failUnlessEqual(e.whichEncoding(), None) + self.assertIsNone(e.whichEncoding()) class TestDecode(unittest.TestCase): @@ -139,7 +139,7 @@ def setUp(self): def test_none(self): args = {} r = self.decode(args) - self.failUnlessEqual(r, None) + self.assertIsNone(r) def test_irrelevant(self): args = { @@ -182,12 +182,12 @@ def test_checkidImmediate(self): } r = self.decode(args) self.failUnless(isinstance(r, server.CheckIDRequest)) - self.failUnlessEqual(r.mode, "checkid_immediate") - self.failUnlessEqual(r.immediate, True) - self.failUnlessEqual(r.identity, self.id_url) - self.failUnlessEqual(r.trust_root, self.tr_url) - self.failUnlessEqual(r.return_to, self.rt_url) - self.failUnlessEqual(r.assoc_handle, self.assoc_handle) + self.assertEqual(r.mode, "checkid_immediate") + self.assertTrue(r.immediate) + self.assertEqual(r.identity, self.id_url) + self.assertEqual(r.trust_root, self.tr_url) + self.assertEqual(r.return_to, self.rt_url) + self.assertEqual(r.assoc_handle, self.assoc_handle) def test_checkidSetup(self): args = { @@ -199,11 +199,11 @@ def test_checkidSetup(self): } r = self.decode(args) self.failUnless(isinstance(r, server.CheckIDRequest)) - self.failUnlessEqual(r.mode, "checkid_setup") - self.failUnlessEqual(r.immediate, False) - self.failUnlessEqual(r.identity, self.id_url) - self.failUnlessEqual(r.trust_root, self.tr_url) - self.failUnlessEqual(r.return_to, self.rt_url) + self.assertEqual(r.mode, "checkid_setup") + self.assertFalse(r.immediate) + self.assertEqual(r.identity, self.id_url) + self.assertEqual(r.trust_root, self.tr_url) + self.assertEqual(r.return_to, self.rt_url) def test_checkidSetupOpenID2(self): args = { @@ -217,12 +217,12 @@ def test_checkidSetupOpenID2(self): } r = self.decode(args) self.failUnless(isinstance(r, server.CheckIDRequest)) - self.failUnlessEqual(r.mode, "checkid_setup") - self.failUnlessEqual(r.immediate, False) - self.failUnlessEqual(r.identity, self.id_url) - self.failUnlessEqual(r.claimed_id, self.claimed_id) - self.failUnlessEqual(r.trust_root, self.tr_url) - self.failUnlessEqual(r.return_to, self.rt_url) + self.assertEqual(r.mode, "checkid_setup") + self.assertFalse(r.immediate) + self.assertEqual(r.identity, self.id_url) + self.assertEqual(r.claimed_id, self.claimed_id) + self.assertEqual(r.trust_root, self.tr_url) + self.assertEqual(r.return_to, self.rt_url) def test_checkidSetupNoClaimedIDOpenID2(self): args = { @@ -245,11 +245,11 @@ def test_checkidSetupNoIdentityOpenID2(self): } r = self.decode(args) self.failUnless(isinstance(r, server.CheckIDRequest)) - self.failUnlessEqual(r.mode, "checkid_setup") - self.failUnlessEqual(r.immediate, False) - self.failUnlessEqual(r.identity, None) - self.failUnlessEqual(r.trust_root, self.tr_url) - self.failUnlessEqual(r.return_to, self.rt_url) + self.assertEqual(r.mode, "checkid_setup") + self.assertFalse(r.immediate) + self.assertIsNone(r.identity) + self.assertEqual(r.trust_root, self.tr_url) + self.assertEqual(r.return_to, self.rt_url) def test_checkidSetupNoReturnOpenID1(self): """Make sure an OpenID 1 request cannot be decoded if it lacks @@ -340,8 +340,8 @@ def test_checkAuth(self): } r = self.decode(args) self.failUnless(isinstance(r, server.CheckAuthRequest)) - self.failUnlessEqual(r.mode, 'check_authentication') - self.failUnlessEqual(r.sig, 'sigblob') + self.assertEqual(r.mode, 'check_authentication') + self.assertEqual(r.sig, 'sigblob') def test_checkAuthMissingSignature(self): args = { @@ -368,7 +368,7 @@ def test_checkAuthAndInvalidate(self): } r = self.decode(args) self.failUnless(isinstance(r, server.CheckAuthRequest)) - self.failUnlessEqual(r.invalidate_handle, '[[SMART_handle]]') + self.assertEqual(r.invalidate_handle, '[[SMART_handle]]') def test_associateDH(self): args = { @@ -378,9 +378,9 @@ def test_associateDH(self): } r = self.decode(args) self.failUnless(isinstance(r, server.AssociateRequest)) - self.failUnlessEqual(r.mode, "associate") - self.failUnlessEqual(r.session.session_type, "DH-SHA1") - self.failUnlessEqual(r.assoc_type, "HMAC-SHA1") + self.assertEqual(r.mode, "associate") + self.assertEqual(r.session.session_type, "DH-SHA1") + self.assertEqual(r.assoc_type, "HMAC-SHA1") self.failUnless(r.session.consumer_pubkey) def test_associateDHMissingKey(self): @@ -411,11 +411,11 @@ def test_associateDHModGen(self): } r = self.decode(args) self.failUnless(isinstance(r, server.AssociateRequest)) - self.failUnlessEqual(r.mode, "associate") - self.failUnlessEqual(r.session.session_type, "DH-SHA1") - self.failUnlessEqual(r.assoc_type, "HMAC-SHA1") - self.failUnlessEqual(r.session.dh.modulus, ALT_MODULUS) - self.failUnlessEqual(r.session.dh.generator, ALT_GEN) + self.assertEqual(r.mode, "associate") + self.assertEqual(r.session.session_type, "DH-SHA1") + self.assertEqual(r.assoc_type, "HMAC-SHA1") + self.assertEqual(r.session.dh.modulus, ALT_MODULUS) + self.assertEqual(r.session.dh.generator, ALT_GEN) self.failUnless(r.session.consumer_pubkey) def test_associateDHCorruptModGen(self): @@ -467,9 +467,9 @@ def test_associatePlain(self): } r = self.decode(args) self.failUnless(isinstance(r, server.AssociateRequest)) - self.failUnlessEqual(r.mode, "associate") - self.failUnlessEqual(r.session.session_type, "no-encryption") - self.failUnlessEqual(r.assoc_type, "HMAC-SHA1") + self.assertEqual(r.mode, "associate") + self.assertEqual(r.session.session_type, "no-encryption") + self.assertEqual(r.assoc_type, "HMAC-SHA1") def test_nomode(self): args = { @@ -630,7 +630,7 @@ def test_id_res_OpenID1_exceeds_limit(self): self.failUnless(len(response.encodeToURL()) > OPENID1_URL_LIMIT) self.failUnless(response.whichEncoding() == server.ENCODE_URL) webresponse = self.encode(response) - self.failUnlessEqual(webresponse.headers['location'], response.encodeToURL()) + self.assertEqual(webresponse.headers['location'], response.encodeToURL()) def test_id_res(self): request = server.CheckIDRequest( @@ -648,7 +648,7 @@ def test_id_res(self): 'return_to': request.return_to, }) webresponse = self.encode(response) - self.failUnlessEqual(webresponse.code, server.HTTP_REDIRECT) + self.assertEqual(webresponse.code, server.HTTP_REDIRECT) self.assertIn('location', webresponse.headers) location = webresponse.headers['location'] @@ -658,7 +658,7 @@ def test_id_res(self): # argh. q2 = dict(cgi.parse_qsl(urlparse(location)[4])) expected = response.fields.toPostArgs() - self.failUnlessEqual(q2, expected) + self.assertEqual(q2, expected) def test_cancel(self): request = server.CheckIDRequest( @@ -674,7 +674,7 @@ def test_cancel(self): 'mode': 'cancel', }) webresponse = self.encode(response) - self.failUnlessEqual(webresponse.code, server.HTTP_REDIRECT) + self.assertEqual(webresponse.code, server.HTTP_REDIRECT) self.assertIn('location', webresponse.headers) def test_cancelToForm(self): @@ -704,9 +704,9 @@ def test_assocReply(self): webresponse = self.encode(response) body = """assoc_handle:every-zig """ - self.failUnlessEqual(webresponse.code, server.HTTP_OK) - self.failUnlessEqual(webresponse.headers, {}) - self.failUnlessEqual(webresponse.body, body) + self.assertEqual(webresponse.code, server.HTTP_OK) + self.assertEqual(webresponse.headers, {}) + self.assertEqual(webresponse.body, body) def test_checkauthReply(self): request = server.CheckAuthRequest('a_sock_monkey', @@ -721,9 +721,9 @@ def test_checkauthReply(self): is_valid:true """ webresponse = self.encode(response) - self.failUnlessEqual(webresponse.code, server.HTTP_OK) - self.failUnlessEqual(webresponse.headers, {}) - self.failUnlessEqual(webresponse.body, body) + self.assertEqual(webresponse.code, server.HTTP_OK) + self.assertEqual(webresponse.headers, {}) + self.assertEqual(webresponse.body, body) def test_unencodableError(self): args = Message.fromPostArgs({ @@ -739,9 +739,9 @@ def test_encodableError(self): }) body = "error:snoot\nmode:error\n" webresponse = self.encode(server.ProtocolError(args, "snoot")) - self.failUnlessEqual(webresponse.code, server.HTTP_ERROR) - self.failUnlessEqual(webresponse.headers, {}) - self.failUnlessEqual(webresponse.body, body) + self.assertEqual(webresponse.code, server.HTTP_ERROR) + self.assertEqual(webresponse.headers, {}) + self.assertEqual(webresponse.body, body) class TestSigningEncode(unittest.TestCase): @@ -776,7 +776,7 @@ def test_idres(self): 'sekrit', 'HMAC-SHA1')) self.request.assoc_handle = assoc_handle webresponse = self.encode(self.response) - self.failUnlessEqual(webresponse.code, server.HTTP_REDIRECT) + self.assertEqual(webresponse.code, server.HTTP_REDIRECT) self.assertIn('location', webresponse.headers) location = webresponse.headers['location'] @@ -787,7 +787,7 @@ def test_idres(self): def test_idresDumb(self): webresponse = self.encode(self.response) - self.failUnlessEqual(webresponse.code, server.HTTP_REDIRECT) + self.assertEqual(webresponse.code, server.HTTP_REDIRECT) self.assertIn('location', webresponse.headers) location = webresponse.headers['location'] @@ -812,7 +812,7 @@ def test_cancel(self): response = server.OpenIDResponse(request) response.fields.setArg(OPENID_NS, 'mode', 'cancel') webresponse = self.encode(response) - self.failUnlessEqual(webresponse.code, server.HTTP_REDIRECT) + self.assertEqual(webresponse.code, server.HTTP_REDIRECT) self.assertIn('location', webresponse.headers) location = webresponse.headers['location'] query = cgi.parse_qs(urlparse(location)[4]) @@ -828,9 +828,9 @@ def test_assocReply(self): webresponse = self.encode(response) body = """assoc_handle:every-zig """ - self.failUnlessEqual(webresponse.code, server.HTTP_OK) - self.failUnlessEqual(webresponse.headers, {}) - self.failUnlessEqual(webresponse.body, body) + self.assertEqual(webresponse.code, server.HTTP_OK) + self.assertEqual(webresponse.headers, {}) + self.assertEqual(webresponse.body, body) def test_alreadySigned(self): self.response.fields.setArg(OPENID_NS, 'sig', 'priorSig==') @@ -901,8 +901,8 @@ def withVerifyReturnTo(new_verify, callable): sentinel = Exception() def vrfyExc(trust_root, return_to): - self.failUnlessEqual(self.request.trust_root, trust_root) - self.failUnlessEqual(self.request.return_to, return_to) + self.assertEqual(trust_root, self.request.trust_root) + self.assertEqual(return_to, self.request.return_to) raise sentinel try: @@ -913,16 +913,13 @@ def vrfyExc(trust_root, return_to): # Ensure that True and False are passed through unchanged def constVerify(val): def verify(trust_root, return_to): - self.failUnlessEqual(self.request.trust_root, trust_root) - self.failUnlessEqual(self.request.return_to, return_to) + self.assertEqual(trust_root, self.request.trust_root) + self.assertEqual(return_to, self.request.return_to) return val return verify for val in [True, False]: - self.failUnlessEqual( - val, - withVerifyReturnTo(constVerify(val), - self.request.returnToVerified)) + self.assertEqual(withVerifyReturnTo(constVerify(val), self.request.returnToVerified), val) def _expectAnswer(self, answer, identity=None, claimed_id=None): expected_list = [ @@ -939,15 +936,13 @@ def _expectAnswer(self, answer, identity=None, claimed_id=None): for k, expected in expected_list: actual = answer.fields.getArg(OPENID_NS, k) - self.failUnlessEqual(actual, expected, "%s: expected %s, got %s" % (k, expected, actual)) + self.assertEqual(actual, expected, "%s: expected %s, got %s" % (k, expected, actual)) self.failUnless(answer.fields.hasKey(OPENID_NS, 'response_nonce')) self.failUnless(answer.fields.getOpenIDNamespace() == OPENID2_NS) # One for nonce, one for ns - self.failUnlessEqual(len(answer.fields.toPostArgs()), - len(expected_list) + 2, - answer.fields.toPostArgs()) + self.assertEqual(len(answer.fields.toPostArgs()), len(expected_list) + 2) def test_answerAllow(self): """Check the fields specified by "Positive Assertions" @@ -955,7 +950,7 @@ def test_answerAllow(self): including mode=id_res, identity, claimed_id, op_endpoint, return_to """ answer = self.request.answer(True) - self.failUnlessEqual(answer.request, self.request) + self.assertEqual(answer.request, self.request) self._expectAnswer(answer, self.request.identity) def test_answerAllowDelegatedIdentity(self): @@ -975,7 +970,7 @@ def test_answerAllowDelegatedIdentity2(self): def test_answerAllowWithoutIdentityReally(self): self.request.identity = None answer = self.request.answer(True) - self.failUnlessEqual(answer.request, self.request) + self.assertEqual(answer.request, self.request) self._expectAnswer(answer) def test_answerAllowAnonymousFail(self): @@ -1113,7 +1108,7 @@ def test_trustRootOpenID2(self): def test_answerAllowNoTrustRoot(self): self.request.trust_root = None answer = self.request.answer(True) - self.failUnlessEqual(answer.request, self.request) + self.assertEqual(answer.request, self.request) self._expectAnswer(answer, self.request.identity) def test_fromMessageWithoutTrustRoot(self): @@ -1126,7 +1121,7 @@ def test_fromMessageWithoutTrustRoot(self): result = server.CheckIDRequest.fromMessage(msg, self.server.op_endpoint) - self.failUnlessEqual(result.trust_root, 'http://real_trust_root/foo') + self.assertEqual(result.trust_root, 'http://real_trust_root/foo') def test_fromMessageWithEmptyTrustRoot(self): return_to = u'http://someplace.invalid/?go=thing' @@ -1141,7 +1136,7 @@ def test_fromMessageWithEmptyTrustRoot(self): result = server.CheckIDRequest.fromMessage(msg, self.server.op_endpoint) - self.failUnlessEqual(result.trust_root, return_to) + self.assertEqual(result.trust_root, return_to) def test_fromMessageWithoutTrustRootOrReturnTo(self): msg = Message(OPENID2_NS) @@ -1175,18 +1170,14 @@ def test_answerAllowNoEndpointOpenID1(self): for k, expected in expected_list: actual = answer.fields.getArg(OPENID_NS, k) - self.failUnlessEqual( - expected, actual, - "%s: expected %s, got %s" % (k, expected, actual)) + self.assertEqual(actual, expected, "%s: expected %s, got %s" % (k, expected, actual)) self.failUnless(answer.fields.hasKey(OPENID_NS, 'response_nonce')) - self.failUnlessEqual(answer.fields.getOpenIDNamespace(), OPENID1_NS) + self.assertEqual(answer.fields.getOpenIDNamespace(), OPENID1_NS) self.failUnless(answer.fields.namespaces.isImplicit(OPENID1_NS)) # One for nonce (OpenID v1 namespace is implicit) - self.failUnlessEqual(len(answer.fields.toPostArgs()), - len(expected_list) + 1, - answer.fields.toPostArgs()) + self.assertEqual(len(answer.fields.toPostArgs()), len(expected_list) + 1) def test_answerImmediateDenyOpenID2(self): """Look for mode=setup_needed in checkid_immediate negative @@ -1201,11 +1192,10 @@ def test_answerImmediateDenyOpenID2(self): server_url = "http://setup-url.unittest/" # crappiting setup_url, you dirty my interface with your presence! answer = self.request.answer(False, server_url=server_url) - self.failUnlessEqual(answer.request, self.request) - self.failUnlessEqual(len(answer.fields.toPostArgs()), 3, answer.fields) - self.failUnlessEqual(answer.fields.getOpenIDNamespace(), OPENID2_NS) - self.failUnlessEqual(answer.fields.getArg(OPENID_NS, 'mode'), - 'setup_needed') + self.assertEqual(answer.request, self.request) + self.assertEqual(len(answer.fields.toPostArgs()), 3, answer.fields) + self.assertEqual(answer.fields.getOpenIDNamespace(), OPENID2_NS) + self.assertEqual(answer.fields.getArg(OPENID_NS, 'mode'), 'setup_needed') usu = answer.fields.getArg(OPENID_NS, 'user_setup_url') expected_substr = 'openid.claimed_id=https%3A%2F%2Fround-lake.dustinice.workers.dev%3A443%2Fhttp%2Fclaimed-id.test%2F' @@ -1220,19 +1210,17 @@ def test_answerImmediateDenyOpenID1(self): server_url = "http://setup-url.unittest/" # crappiting setup_url, you dirty my interface with your presence! answer = self.request.answer(False, server_url=server_url) - self.failUnlessEqual(answer.request, self.request) - self.failUnlessEqual(len(answer.fields.toPostArgs()), 2, answer.fields) - self.failUnlessEqual(answer.fields.getOpenIDNamespace(), OPENID1_NS) + self.assertEqual(answer.request, self.request) + self.assertEqual(len(answer.fields.toPostArgs()), 2, answer.fields) + self.assertEqual(answer.fields.getOpenIDNamespace(), OPENID1_NS) self.failUnless(answer.fields.namespaces.isImplicit(OPENID1_NS)) - self.failUnlessEqual(answer.fields.getArg(OPENID_NS, 'mode'), 'id_res') + self.assertEqual(answer.fields.getArg(OPENID_NS, 'mode'), 'id_res') self.failUnless(answer.fields.getArg( OPENID_NS, 'user_setup_url', '').startswith(server_url)) def test_answerSetupDeny(self): answer = self.request.answer(False) - self.failUnlessEqual(answer.fields.getArgs(OPENID_NS), { - 'mode': 'cancel', - }) + self.assertEqual(answer.fields.getArgs(OPENID_NS), {'mode': 'cancel'}) def test_encodeToURL(self): server_url = 'http://openid-server.unittest/' @@ -1246,15 +1234,14 @@ def test_encodeToURL(self): self.server.op_endpoint) # argh, lousy hack self.request.message = message - self.failUnlessEqual(rebuilt_request.__dict__, self.request.__dict__) + self.assertEqual(rebuilt_request.__dict__, self.request.__dict__) def test_getCancelURL(self): url = self.request.getCancelURL() rt, query_string = url.split('?') - self.failUnlessEqual(self.request.return_to, rt) + self.assertEqual(self.request.return_to, rt) query = dict(cgi.parse_qsl(query_string)) - self.failUnlessEqual(query, {'openid.mode': 'cancel', - 'openid.ns': OPENID2_NS}) + self.assertEqual(query, {'openid.mode': 'cancel', 'openid.ns': OPENID2_NS}) def test_getCancelURLimmed(self): self.request.mode = 'checkid_immediate' @@ -1283,24 +1270,16 @@ def setUp(self): def test_addField(self): namespace = 'something:' self.response.fields.setArg(namespace, 'bright', 'potato') - self.failUnlessEqual(self.response.fields.getArgs(OPENID_NS), - {'blue': 'star', - 'mode': 'id_res', - }) - - self.failUnlessEqual(self.response.fields.getArgs(namespace), - {'bright': 'potato'}) + self.assertEqual(self.response.fields.getArgs(OPENID_NS), {'blue': 'star', 'mode': 'id_res'}) + self.assertEqual(self.response.fields.getArgs(namespace), {'bright': 'potato'}) def test_addFields(self): namespace = 'mi5:' args = {'tangy': 'suspenders', 'bravo': 'inclusion'} self.response.fields.updateArgs(namespace, args) - self.failUnlessEqual(self.response.fields.getArgs(OPENID_NS), - {'blue': 'star', - 'mode': 'id_res', - }) - self.failUnlessEqual(self.response.fields.getArgs(namespace), args) + self.assertEqual(self.response.fields.getArgs(OPENID_NS), {'blue': 'star', 'mode': 'id_res'}) + self.assertEqual(self.response.fields.getArgs(namespace), args) class MockSignatory(object): @@ -1344,14 +1323,13 @@ def setUp(self): def test_valid(self): r = self.request.answer(self.signatory) - self.failUnlessEqual(r.fields.getArgs(OPENID_NS), {'is_valid': 'true'}) - self.failUnlessEqual(r.request, self.request) + self.assertEqual(r.fields.getArgs(OPENID_NS), {'is_valid': 'true'}) + self.assertEqual(r.request, self.request) def test_invalid(self): self.signatory.isValid = False r = self.request.answer(self.signatory) - self.failUnlessEqual(r.fields.getArgs(OPENID_NS), - {'is_valid': 'false'}) + self.assertEqual(r.fields.getArgs(OPENID_NS), {'is_valid': 'false'}) def test_replay(self): """Don't validate the same response twice. @@ -1368,23 +1346,20 @@ def test_replay(self): """ r = self.request.answer(self.signatory) r = self.request.answer(self.signatory) - self.failUnlessEqual(r.fields.getArgs(OPENID_NS), - {'is_valid': 'false'}) + self.assertEqual(r.fields.getArgs(OPENID_NS), {'is_valid': 'false'}) def test_invalidatehandle(self): self.request.invalidate_handle = "bogusHandle" r = self.request.answer(self.signatory) - self.failUnlessEqual(r.fields.getArgs(OPENID_NS), - {'is_valid': 'true', - 'invalidate_handle': "bogusHandle"}) - self.failUnlessEqual(r.request, self.request) + self.assertEqual(r.fields.getArgs(OPENID_NS), {'is_valid': 'true', 'invalidate_handle': "bogusHandle"}) + self.assertEqual(r.request, self.request) def test_invalidatehandleNo(self): assoc_handle = 'goodhandle' self.signatory.assocs.append((False, 'goodhandle')) self.request.invalidate_handle = assoc_handle r = self.request.answer(self.signatory) - self.failUnlessEqual(r.fields.getArgs(OPENID_NS), {'is_valid': 'true'}) + self.assertEqual(r.fields.getArgs(OPENID_NS), {'is_valid': 'true'}) class TestAssociate(unittest.TestCase): @@ -1409,17 +1384,17 @@ def test_dhSHA1(self): response = self.request.answer(self.assoc) rfg = partial(response.fields.getArg, OPENID_NS) - self.failUnlessEqual(rfg("assoc_type"), "HMAC-SHA1") - self.failUnlessEqual(rfg("assoc_handle"), self.assoc.handle) + self.assertEqual(rfg("assoc_type"), "HMAC-SHA1") + self.assertEqual(rfg("assoc_handle"), self.assoc.handle) self.failIf(rfg("mac_key")) - self.failUnlessEqual(rfg("session_type"), "DH-SHA1") + self.assertEqual(rfg("session_type"), "DH-SHA1") self.failUnless(rfg("enc_mac_key")) self.failUnless(rfg("dh_server_public")) enc_key = rfg("enc_mac_key").decode('base64') spub = cryptutil.base64ToLong(rfg("dh_server_public")) secret = consumer_dh.xorSecret(spub, enc_key, cryptutil.sha1) - self.failUnlessEqual(secret, self.assoc.secret) + self.assertEqual(secret, self.assoc.secret) def test_dhSHA256(self): self.assoc = self.signatory.createAssociation( @@ -1434,17 +1409,17 @@ def test_dhSHA256(self): response = self.request.answer(self.assoc) rfg = partial(response.fields.getArg, OPENID_NS) - self.failUnlessEqual(rfg("assoc_type"), "HMAC-SHA256") - self.failUnlessEqual(rfg("assoc_handle"), self.assoc.handle) + self.assertEqual(rfg("assoc_type"), "HMAC-SHA256") + self.assertEqual(rfg("assoc_handle"), self.assoc.handle) self.failIf(rfg("mac_key")) - self.failUnlessEqual(rfg("session_type"), "DH-SHA256") + self.assertEqual(rfg("session_type"), "DH-SHA256") self.failUnless(rfg("enc_mac_key")) self.failUnless(rfg("dh_server_public")) enc_key = rfg("enc_mac_key").decode('base64') spub = cryptutil.base64ToLong(rfg("dh_server_public")) secret = consumer_dh.xorSecret(spub, enc_key, cryptutil.sha256) - self.failUnlessEqual(secret, self.assoc.secret) + self.assertEqual(secret, self.assoc.secret) def test_protoError256(self): s256_session = DiffieHellmanSHA256ConsumerSession() @@ -1510,16 +1485,16 @@ def test_protoErrorFields(self): contact=contact, reference=reference) reply = p.toMessage() - self.failUnlessEqual(reply.getArg(OPENID_NS, 'reference'), reference) - self.failUnlessEqual(reply.getArg(OPENID_NS, 'contact'), contact) + self.assertEqual(reply.getArg(OPENID_NS, 'reference'), reference) + self.assertEqual(reply.getArg(OPENID_NS, 'contact'), contact) openid2_msg = Message.fromPostArgs(openid2_args) p = server.ProtocolError(openid2_msg, error, contact=contact, reference=reference) reply = p.toMessage() - self.failUnlessEqual(reply.getArg(OPENID_NS, 'reference'), reference) - self.failUnlessEqual(reply.getArg(OPENID_NS, 'contact'), contact) + self.assertEqual(reply.getArg(OPENID_NS, 'reference'), reference) + self.assertEqual(reply.getArg(OPENID_NS, 'contact'), contact) def failUnlessExpiresInMatches(self, msg, expected_expires_in): expires_in_str = msg.getArg(OPENID_NS, 'expires_in', no_default) @@ -1541,14 +1516,13 @@ def test_plaintext(self): rfg = partial(response.fields.getArg, OPENID_NS) - self.failUnlessEqual(rfg("assoc_type"), "HMAC-SHA1") - self.failUnlessEqual(rfg("assoc_handle"), self.assoc.handle) + self.assertEqual(rfg("assoc_type"), "HMAC-SHA1") + self.assertEqual(rfg("assoc_handle"), self.assoc.handle) self.failUnlessExpiresInMatches( response.fields, self.signatory.SECRET_LIFETIME) - self.failUnlessEqual( - rfg("mac_key"), oidutil.toBase64(self.assoc.secret)) + self.assertEqual(rfg("mac_key"), oidutil.toBase64(self.assoc.secret)) self.failIf(rfg("session_type")) self.failIf(rfg("enc_mac_key")) self.failIf(rfg("dh_server_public")) @@ -1573,16 +1547,15 @@ def test_plaintext_v2(self): rfg = partial(response.fields.getArg, OPENID_NS) - self.failUnlessEqual(rfg("assoc_type"), "HMAC-SHA1") - self.failUnlessEqual(rfg("assoc_handle"), self.assoc.handle) + self.assertEqual(rfg("assoc_type"), "HMAC-SHA1") + self.assertEqual(rfg("assoc_handle"), self.assoc.handle) self.failUnlessExpiresInMatches( response.fields, self.signatory.SECRET_LIFETIME) - self.failUnlessEqual( - rfg("mac_key"), oidutil.toBase64(self.assoc.secret)) + self.assertEqual(rfg("mac_key"), oidutil.toBase64(self.assoc.secret)) - self.failUnlessEqual(rfg("session_type"), "no-encryption") + self.assertEqual(rfg("session_type"), "no-encryption") self.failIf(rfg("enc_mac_key")) self.failIf(rfg("dh_server_public")) @@ -1592,14 +1565,13 @@ def test_plaintext256(self): rfg = partial(response.fields.getArg, OPENID_NS) - self.failUnlessEqual(rfg("assoc_type"), "HMAC-SHA1") - self.failUnlessEqual(rfg("assoc_handle"), self.assoc.handle) + self.assertEqual(rfg("assoc_type"), "HMAC-SHA1") + self.assertEqual(rfg("assoc_handle"), self.assoc.handle) self.failUnlessExpiresInMatches( response.fields, self.signatory.SECRET_LIFETIME) - self.failUnlessEqual( - rfg("mac_key"), oidutil.toBase64(self.assoc.secret)) + self.assertEqual(rfg("mac_key"), oidutil.toBase64(self.assoc.secret)) self.failIf(rfg("session_type")) self.failIf(rfg("enc_mac_key")) self.failIf(rfg("dh_server_public")) @@ -1620,10 +1592,10 @@ def test_unsupportedPrefer(self): ) rfg = partial(response.fields.getArg, OPENID_NS) - self.failUnlessEqual(rfg('error_code'), 'unsupported-type') - self.failUnlessEqual(rfg('assoc_type'), allowed_assoc) - self.failUnlessEqual(rfg('error'), message) - self.failUnlessEqual(rfg('session_type'), allowed_sess) + self.assertEqual(rfg('error_code'), 'unsupported-type') + self.assertEqual(rfg('assoc_type'), allowed_assoc) + self.assertEqual(rfg('error'), message) + self.assertEqual(rfg('session_type'), allowed_sess) def test_unsupported(self): message = 'This is a unit test' @@ -1635,10 +1607,10 @@ def test_unsupported(self): response = self.request.answerUnsupported(message) rfg = partial(response.fields.getArg, OPENID_NS) - self.failUnlessEqual(rfg('error_code'), 'unsupported-type') - self.failUnlessEqual(rfg('assoc_type'), None) - self.failUnlessEqual(rfg('error'), message) - self.failUnlessEqual(rfg('session_type'), None) + self.assertEqual(rfg('error_code'), 'unsupported-type') + self.assertIsNone(rfg('assoc_type')) + self.assertEqual(rfg('error'), message) + self.assertIsNone(rfg('session_type')) class Counter(object): @@ -1667,7 +1639,7 @@ def monkeyDo(request): request.mode = "monkeymode" request.namespace = OPENID1_NS self.server.handleRequest(request) - self.failUnlessEqual(monkeycalled.count, 1) + self.assertEqual(monkeycalled.count, 1) def test_associate(self): request = server.AssociateRequest.fromMessage(Message.fromPostArgs({})) @@ -1719,10 +1691,8 @@ def test_associate3(self): self.failUnless(response.fields.hasKey(OPENID_NS, "error")) self.failUnless(response.fields.hasKey(OPENID_NS, "error_code")) self.failIf(response.fields.hasKey(OPENID_NS, "assoc_handle")) - self.failUnlessEqual(response.fields.getArg(OPENID_NS, "assoc_type"), - 'HMAC-SHA256') - self.failUnlessEqual(response.fields.getArg(OPENID_NS, "session_type"), - 'DH-SHA256') + self.assertEqual(response.fields.getArg(OPENID_NS, "assoc_type"), 'HMAC-SHA256') + self.assertEqual(response.fields.getArg(OPENID_NS, "session_type"), 'DH-SHA256') def test_associate4(self): """DH-SHA256 association session""" @@ -1790,11 +1760,8 @@ def test_sign(self): 'azu': 'alsosigned', }) sresponse = self.signatory.sign(response) - self.failUnlessEqual( - sresponse.fields.getArg(OPENID_NS, 'assoc_handle'), - assoc_handle) - self.failUnlessEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), - 'assoc_handle,azu,bar,foo,signed') + self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'assoc_handle'), assoc_handle) + self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), 'assoc_handle,azu,bar,foo,signed') self.failUnless(sresponse.fields.getArg(OPENID_NS, 'sig')) self.failIf(self.messages, self.messages) @@ -1814,8 +1781,7 @@ def test_signDumb(self): self.failUnless(assoc_handle) assoc = self.signatory.getAssociation(assoc_handle, dumb=True) self.failUnless(assoc) - self.failUnlessEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), - 'assoc_handle,azu,bar,foo,ns,signed') + self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), 'assoc_handle,azu,bar,foo,ns,signed') self.failUnless(sresponse.fields.getArg(OPENID_NS, 'sig')) self.failIf(self.messages, self.messages) @@ -1855,12 +1821,10 @@ def test_signExpired(self): self.failUnless(new_assoc_handle) self.failIfEqual(new_assoc_handle, assoc_handle) - self.failUnlessEqual( - sresponse.fields.getArg(OPENID_NS, 'invalidate_handle'), - assoc_handle) + self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'invalidate_handle'), assoc_handle) - self.failUnlessEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), - 'assoc_handle,azu,bar,foo,invalidate_handle,signed') + self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), + 'assoc_handle,azu,bar,foo,invalidate_handle,signed') self.failUnless(sresponse.fields.getArg(OPENID_NS, 'sig')) # make sure the expired association is gone @@ -1890,12 +1854,10 @@ def test_signInvalidHandle(self): self.failUnless(new_assoc_handle) self.failIfEqual(new_assoc_handle, assoc_handle) - self.failUnlessEqual( - sresponse.fields.getArg(OPENID_NS, 'invalidate_handle'), - assoc_handle) + self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'invalidate_handle'), assoc_handle) - self.failUnlessEqual( - sresponse.fields.getArg(OPENID_NS, 'signed'), 'assoc_handle,azu,bar,foo,invalidate_handle,signed') + self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), + 'assoc_handle,azu,bar,foo,invalidate_handle,signed') self.failUnless(sresponse.fields.getArg(OPENID_NS, 'sig')) # make sure the new key is a dumb mode association @@ -1975,7 +1937,7 @@ def test_getAssoc(self): assoc_handle = self.makeAssoc(dumb=True) assoc = self.signatory.getAssociation(assoc_handle, True) self.failUnless(assoc) - self.failUnlessEqual(assoc.handle, assoc_handle) + self.assertEqual(assoc.handle, assoc_handle) self.failIf(self.messages, self.messages) def test_getAssocExpired(self): @@ -1986,15 +1948,13 @@ def test_getAssocExpired(self): def test_getAssocInvalid(self): ah = 'no-such-handle' - self.failUnlessEqual( - self.signatory.getAssociation(ah, dumb=False), None) + self.assertIsNone(self.signatory.getAssociation(ah, dumb=False)) self.failIf(self.messages, self.messages) def test_getAssocDumbVsNormal(self): """getAssociation(dumb=False) cannot get a dumb assoc""" assoc_handle = self.makeAssoc(dumb=True) - self.failUnlessEqual( - self.signatory.getAssociation(assoc_handle, dumb=False), None) + self.assertIsNone(self.signatory.getAssociation(assoc_handle, dumb=False)) self.failIf(self.messages, self.messages) def test_getAssocNormalVsDumb(self): @@ -2006,8 +1966,7 @@ def test_getAssocNormalVsDumb(self): MAC keys. """ assoc_handle = self.makeAssoc(dumb=False) - self.failUnlessEqual( - self.signatory.getAssociation(assoc_handle, dumb=True), None) + self.assertIsNone(self.signatory.getAssociation(assoc_handle, dumb=True)) self.failIf(self.messages, self.messages) def test_createAssociation(self): diff --git a/openid/test/test_sreg.py b/openid/test/test_sreg.py index ddcf9dc4..bca85040 100644 --- a/openid/test/test_sreg.py +++ b/openid/test/test_sreg.py @@ -7,7 +7,7 @@ class SRegURITest(unittest.TestCase): def test_is11(self): - self.failUnlessEqual(sreg.ns_uri_1_1, sreg.ns_uri) + self.assertEqual(sreg.ns_uri_1_1, sreg.ns_uri) class CheckFieldNameTest(unittest.TestCase): @@ -38,19 +38,17 @@ class SupportsSRegTest(unittest.TestCase): def test_unsupported(self): endpoint = FakeEndpoint([]) self.failIf(sreg.supportsSReg(endpoint)) - self.failUnlessEqual([sreg.ns_uri_1_1, sreg.ns_uri_1_0], - endpoint.checked_uris) + self.assertEqual(endpoint.checked_uris, [sreg.ns_uri_1_1, sreg.ns_uri_1_0]) def test_supported_1_1(self): endpoint = FakeEndpoint([sreg.ns_uri_1_1]) self.failUnless(sreg.supportsSReg(endpoint)) - self.failUnlessEqual([sreg.ns_uri_1_1], endpoint.checked_uris) + self.assertEqual(endpoint.checked_uris, [sreg.ns_uri_1_1]) def test_supported_1_0(self): endpoint = FakeEndpoint([sreg.ns_uri_1_0]) self.failUnless(sreg.supportsSReg(endpoint)) - self.failUnlessEqual([sreg.ns_uri_1_1, sreg.ns_uri_1_0], - endpoint.checked_uris) + self.assertEqual(endpoint.checked_uris, [sreg.ns_uri_1_1, sreg.ns_uri_1_0]) class FakeMessage(object): @@ -68,20 +66,20 @@ def setUp(self): def test_openID2Empty(self): ns_uri = sreg.getSRegNS(self.msg) - self.failUnlessEqual(self.msg.namespaces.getAlias(ns_uri), 'sreg') - self.failUnlessEqual(sreg.ns_uri, ns_uri) + self.assertEqual(self.msg.namespaces.getAlias(ns_uri), 'sreg') + self.assertEqual(ns_uri, sreg.ns_uri) def test_openID1Empty(self): self.msg.openid1 = True ns_uri = sreg.getSRegNS(self.msg) - self.failUnlessEqual(self.msg.namespaces.getAlias(ns_uri), 'sreg') - self.failUnlessEqual(sreg.ns_uri, ns_uri) + self.assertEqual(self.msg.namespaces.getAlias(ns_uri), 'sreg') + self.assertEqual(ns_uri, sreg.ns_uri) def test_openID1Defined_1_0(self): self.msg.openid1 = True self.msg.namespaces.add(sreg.ns_uri_1_0) ns_uri = sreg.getSRegNS(self.msg) - self.failUnlessEqual(sreg.ns_uri_1_0, ns_uri) + self.assertEqual(ns_uri, sreg.ns_uri_1_0) def test_openID1Defined_1_0_overrideAlias(self): for openid_version in [True, False]: @@ -92,8 +90,8 @@ def test_openID1Defined_1_0_overrideAlias(self): self.msg.openid1 = openid_version self.msg.namespaces.addAlias(sreg_version, alias) ns_uri = sreg.getSRegNS(self.msg) - self.failUnlessEqual(self.msg.namespaces.getAlias(ns_uri), alias) - self.failUnlessEqual(sreg_version, ns_uri) + self.assertEqual(self.msg.namespaces.getAlias(ns_uri), alias) + self.assertEqual(ns_uri, sreg_version) def test_openID1DefinedBadly(self): self.msg.openid1 = True @@ -110,7 +108,7 @@ def test_openID2DefinedBadly(self): def test_openID2Defined_1_0(self): self.msg.namespaces.add(sreg.ns_uri_1_0) ns_uri = sreg.getSRegNS(self.msg) - self.failUnlessEqual(sreg.ns_uri_1_0, ns_uri) + self.assertEqual(ns_uri, sreg.ns_uri_1_0) def test_openID1_sregNSfromArgs(self): args = { @@ -127,10 +125,10 @@ def test_openID1_sregNSfromArgs(self): class SRegRequestTest(unittest.TestCase): def test_constructEmpty(self): req = sreg.SRegRequest() - self.failUnlessEqual([], req.optional) - self.failUnlessEqual([], req.required) - self.failUnlessEqual(None, req.policy_url) - self.failUnlessEqual(sreg.ns_uri, req.ns_uri) + self.assertEqual(req.optional, []) + self.assertEqual(req.required, []) + self.assertIsNone(req.policy_url) + self.assertEqual(req.ns_uri, sreg.ns_uri) def test_constructFields(self): req = sreg.SRegRequest( @@ -138,10 +136,10 @@ def test_constructFields(self): ['gender'], 'http://policy', 'http://sreg.ns_uri') - self.failUnlessEqual(['gender'], req.optional) - self.failUnlessEqual(['nickname'], req.required) - self.failUnlessEqual('http://policy', req.policy_url) - self.failUnlessEqual('http://sreg.ns_uri', req.ns_uri) + self.assertEqual(req.optional, ['gender']) + self.assertEqual(req.required, ['nickname']) + self.assertEqual(req.policy_url, 'http://policy') + self.assertEqual(req.ns_uri, 'http://sreg.ns_uri') def test_constructBadFields(self): self.failUnlessRaises( @@ -159,7 +157,7 @@ def __init__(self): self.message = Message() def getArgs(msg_self, ns_uri): - self.failUnlessEqual(ns_sentinel, ns_uri) + self.assertEqual(ns_uri, ns_sentinel) return args_sentinel def copy(msg_self): @@ -171,7 +169,7 @@ def _getSRegNS(req_self, unused): return ns_sentinel def parseExtensionArgs(req_self, args): - self.failUnlessEqual(args_sentinel, args) + self.assertEqual(args, args_sentinel) openid_req = OpenIDRequest() @@ -185,7 +183,7 @@ def parseExtensionArgs(req_self, args): def test_parseExtensionArgs_empty(self): req = sreg.SRegRequest() results = req.parseExtensionArgs({}) - self.failUnlessEqual(None, results) + self.assertIsNone(results) def test_parseExtensionArgs_extraIgnored(self): req = sreg.SRegRequest() @@ -194,7 +192,7 @@ def test_parseExtensionArgs_extraIgnored(self): def test_parseExtensionArgs_nonStrict(self): req = sreg.SRegRequest() req.parseExtensionArgs({'required': 'beans'}) - self.failUnlessEqual([], req.required) + self.assertEqual(req.required, []) def test_parseExtensionArgs_strict(self): req = sreg.SRegRequest() @@ -205,32 +203,32 @@ def test_parseExtensionArgs_strict(self): def test_parseExtensionArgs_policy(self): req = sreg.SRegRequest() req.parseExtensionArgs({'policy_url': 'http://policy'}, strict=True) - self.failUnlessEqual('http://policy', req.policy_url) + self.assertEqual(req.policy_url, 'http://policy') def test_parseExtensionArgs_requiredEmpty(self): req = sreg.SRegRequest() req.parseExtensionArgs({'required': ''}, strict=True) - self.failUnlessEqual([], req.required) + self.assertEqual(req.required, []) def test_parseExtensionArgs_optionalEmpty(self): req = sreg.SRegRequest() req.parseExtensionArgs({'optional': ''}, strict=True) - self.failUnlessEqual([], req.optional) + self.assertEqual(req.optional, []) def test_parseExtensionArgs_optionalSingle(self): req = sreg.SRegRequest() req.parseExtensionArgs({'optional': 'nickname'}, strict=True) - self.failUnlessEqual(['nickname'], req.optional) + self.assertEqual(req.optional, ['nickname']) def test_parseExtensionArgs_optionalList(self): req = sreg.SRegRequest() req.parseExtensionArgs({'optional': 'nickname,email'}, strict=True) - self.failUnlessEqual(['nickname', 'email'], req.optional) + self.assertEqual(req.optional, ['nickname', 'email']) def test_parseExtensionArgs_optionalListBadNonStrict(self): req = sreg.SRegRequest() req.parseExtensionArgs({'optional': 'nickname,email,beer'}) - self.failUnlessEqual(['nickname', 'email'], req.optional) + self.assertEqual(req.optional, ['nickname', 'email']) def test_parseExtensionArgs_optionalListBadStrict(self): req = sreg.SRegRequest() @@ -243,8 +241,8 @@ def test_parseExtensionArgs_bothNonStrict(self): req = sreg.SRegRequest() req.parseExtensionArgs({'optional': 'nickname', 'required': 'nickname'}) - self.failUnlessEqual([], req.optional) - self.failUnlessEqual(['nickname'], req.required) + self.assertEqual(req.optional, []) + self.assertEqual(req.required, ['nickname']) def test_parseExtensionArgs_bothStrict(self): req = sreg.SRegRequest() @@ -259,17 +257,17 @@ def test_parseExtensionArgs_bothList(self): req = sreg.SRegRequest() req.parseExtensionArgs({'optional': 'nickname,email', 'required': 'country,postcode'}, strict=True) - self.failUnlessEqual(['nickname', 'email'], req.optional) - self.failUnlessEqual(['country', 'postcode'], req.required) + self.assertEqual(req.optional, ['nickname', 'email']) + self.assertEqual(req.required, ['country', 'postcode']) def test_allRequestedFields(self): req = sreg.SRegRequest() - self.failUnlessEqual([], req.allRequestedFields()) + self.assertEqual(req.allRequestedFields(), []) req.requestField('nickname') - self.failUnlessEqual(['nickname'], req.allRequestedFields()) + self.assertEqual(req.allRequestedFields(), ['nickname']) req.requestField('gender', required=True) requested = sorted(req.allRequestedFields()) - self.failUnlessEqual(['gender', 'nickname'], requested) + self.assertEqual(requested, ['gender', 'nickname']) def test_wereFieldsRequested(self): req = sreg.SRegRequest() @@ -308,36 +306,36 @@ def test_requestField(self): for field_name in fields: req.requestField(field_name) - self.failUnlessEqual(fields, req.optional) - self.failUnlessEqual([], req.required) + self.assertEqual(req.optional, fields) + self.assertEqual(req.required, []) # By default, adding the same fields over again has no effect for field_name in fields: req.requestField(field_name) - self.failUnlessEqual(fields, req.optional) - self.failUnlessEqual([], req.required) + self.assertEqual(req.optional, fields) + self.assertEqual(req.required, []) # Requesting a field as required overrides requesting it as optional expected = list(fields) overridden = expected.pop(0) req.requestField(overridden, required=True) - self.failUnlessEqual(expected, req.optional) - self.failUnlessEqual([overridden], req.required) + self.assertEqual(req.optional, expected) + self.assertEqual(req.required, [overridden]) # Requesting a field as required overrides requesting it as optional for field_name in fields: req.requestField(field_name, required=True) - self.failUnlessEqual([], req.optional) - self.failUnlessEqual(fields, req.required) + self.assertEqual(req.optional, []) + self.assertEqual(req.required, fields) # Requesting it as optional does not downgrade it to optional for field_name in fields: req.requestField(field_name) - self.failUnlessEqual([], req.optional) - self.failUnlessEqual(fields, req.required) + self.assertEqual(req.optional, []) + self.assertEqual(req.required, fields) def test_requestFields_type(self): req = sreg.SRegRequest() @@ -350,60 +348,54 @@ def test_requestFields(self): fields = list(sreg.data_fields) req.requestFields(fields) - self.failUnlessEqual(fields, req.optional) - self.failUnlessEqual([], req.required) + self.assertEqual(req.optional, fields) + self.assertEqual(req.required, []) # By default, adding the same fields over again has no effect req.requestFields(fields) - self.failUnlessEqual(fields, req.optional) - self.failUnlessEqual([], req.required) + self.assertEqual(req.optional, fields) + self.assertEqual(req.required, []) # Requesting a field as required overrides requesting it as optional expected = list(fields) overridden = expected.pop(0) req.requestFields([overridden], required=True) - self.failUnlessEqual(expected, req.optional) - self.failUnlessEqual([overridden], req.required) + self.assertEqual(req.optional, expected) + self.assertEqual(req.required, [overridden]) # Requesting a field as required overrides requesting it as optional req.requestFields(fields, required=True) - self.failUnlessEqual([], req.optional) - self.failUnlessEqual(fields, req.required) + self.assertEqual(req.optional, []) + self.assertEqual(req.required, fields) # Requesting it as optional does not downgrade it to optional req.requestFields(fields) - self.failUnlessEqual([], req.optional) - self.failUnlessEqual(fields, req.required) + self.assertEqual(req.optional, []) + self.assertEqual(req.required, fields) def test_getExtensionArgs(self): req = sreg.SRegRequest() - self.failUnlessEqual({}, req.getExtensionArgs()) + self.assertEqual(req.getExtensionArgs(), {}) req.requestField('nickname') - self.failUnlessEqual({'optional': 'nickname'}, req.getExtensionArgs()) + self.assertEqual(req.getExtensionArgs(), {'optional': 'nickname'}) req.requestField('email') - self.failUnlessEqual({'optional': 'nickname,email'}, - req.getExtensionArgs()) + self.assertEqual(req.getExtensionArgs(), {'optional': 'nickname,email'}) req.requestField('gender', required=True) - self.failUnlessEqual({'optional': 'nickname,email', - 'required': 'gender'}, - req.getExtensionArgs()) + self.assertEqual(req.getExtensionArgs(), {'optional': 'nickname,email', 'required': 'gender'}) req.requestField('postcode', required=True) - self.failUnlessEqual({'optional': 'nickname,email', - 'required': 'gender,postcode'}, - req.getExtensionArgs()) + self.assertEqual(req.getExtensionArgs(), {'optional': 'nickname,email', 'required': 'gender,postcode'}) req.policy_url = 'http://policy.invalid/' - self.failUnlessEqual({'optional': 'nickname,email', - 'required': 'gender,postcode', - 'policy_url': 'http://policy.invalid/'}, - req.getExtensionArgs()) + policy_data = {'optional': 'nickname,email', 'required': 'gender,postcode', + 'policy_url': 'http://policy.invalid/'} + self.assertEqual(req.getExtensionArgs(), policy_data) data = { @@ -453,8 +445,7 @@ def test_fromSuccessResponse_unsigned(self): success_resp = DummySuccessResponse(message, {}) sreg_resp = sreg.SRegResponse.fromSuccessResponse(success_resp, signed_only=False) - self.failUnlessEqual([('nickname', 'The Mad Stork')], - sreg_resp.items()) + self.assertEqual(sreg_resp.items(), [('nickname', 'The Mad Stork')]) class SendFieldsTest(unittest.TestCase): @@ -484,11 +475,8 @@ def test(self): # Extract the fields that were sent sreg_data_resp = resp_msg.getArgs(sreg.ns_uri) - self.failUnlessEqual( - {'nickname': 'linusaur', - 'email': 'president@whitehouse.gov', - 'fullname': 'Leonhard Euler', - }, sreg_data_resp) + sent_data = {'nickname': 'linusaur', 'email': 'president@whitehouse.gov', 'fullname': 'Leonhard Euler'} + self.assertEqual(sreg_data_resp, sent_data) if __name__ == '__main__': diff --git a/openid/test/test_symbol.py b/openid/test/test_symbol.py index 74252226..7eab28b2 100644 --- a/openid/test/test_symbol.py +++ b/openid/test/test_symbol.py @@ -6,12 +6,12 @@ class SymbolTest(unittest.TestCase): def test_selfEquality(self): s = oidutil.Symbol('xxx') - self.failUnlessEqual(s, s) + self.assertEqual(s, s) def test_otherEquality(self): x = oidutil.Symbol('xxx') y = oidutil.Symbol('xxx') - self.failUnlessEqual(x, y) + self.assertEqual(x, y) def test_inequality(self): x = oidutil.Symbol('xxx') diff --git a/openid/test/test_trustroot.py b/openid/test/test_trustroot.py index 7905e4a8..8302141c 100644 --- a/openid/test/test_trustroot.py +++ b/openid/test/test_trustroot.py @@ -30,7 +30,7 @@ def test(self): for expected_match, desc, line in getTests([1, 0], mh, mdat): tr, rt = line.split() tr = TrustRoot.parse(tr) - self.failIf(tr is None, tr) + self.assertIsNotNone(tr) match = tr.validateURL(rt) if expected_match: diff --git a/openid/test/test_verifydisco.py b/openid/test/test_verifydisco.py index 57ead86a..04b629bd 100644 --- a/openid/test/test_verifydisco.py +++ b/openid/test/test_verifydisco.py @@ -73,8 +73,8 @@ def test_openID2NoIdentifiers(self): 'op_endpoint': op_endpoint}) result_endpoint = self.consumer._verifyDiscoveryResults(msg) self.failUnless(result_endpoint.isOPIdentifier()) - self.failUnlessEqual(op_endpoint, result_endpoint.server_url) - self.failUnlessEqual(None, result_endpoint.claimed_id) + self.assertEqual(result_endpoint.server_url, op_endpoint) + self.assertIsNone(result_endpoint.claimed_id) self.failUnlessLogEmpty() def test_openID2NoEndpointDoesDisco(self): @@ -88,7 +88,7 @@ def test_openID2NoEndpointDoesDisco(self): 'claimed_id': 'monkeysoft', 'op_endpoint': op_endpoint}) result = self.consumer._verifyDiscoveryResults(msg) - self.failUnlessEqual(sentinel, result) + self.assertEqual(result, sentinel) self.failUnlessLogMatches('No pre-discovered') def test_openID2MismatchedDoesDisco(self): @@ -106,7 +106,7 @@ def test_openID2MismatchedDoesDisco(self): 'claimed_id': 'monkeysoft', 'op_endpoint': op_endpoint}) result = self.consumer._verifyDiscoveryResults(msg, mismatched) - self.failUnlessEqual(sentinel, result) + self.assertEqual(result, sentinel) self.failUnlessLogMatches('Error attempting to use stored', 'Attempting discovery') @@ -136,9 +136,9 @@ def test_openid2UsePreDiscoveredWrongType(self): endpoint.type_uris = [discover.OPENID_1_1_TYPE] def discoverAndVerify(claimed_id, to_match_endpoints): - self.failUnlessEqual(claimed_id, endpoint.claimed_id) + self.assertEqual(claimed_id, endpoint.claimed_id) for to_match in to_match_endpoints: - self.failUnlessEqual(claimed_id, to_match.claimed_id) + self.assertEqual(claimed_id, to_match.claimed_id) raise consumer.ProtocolError(text) self.consumer._discoverAndVerify = discoverAndVerify @@ -216,11 +216,11 @@ def test_openid2Fragment(self): 'op_endpoint': endpoint.server_url}) result = self.consumer._verifyDiscoveryResults(msg, endpoint) - self.failUnlessEqual(result.local_id, endpoint.local_id) - self.failUnlessEqual(result.server_url, endpoint.server_url) - self.failUnlessEqual(result.type_uris, endpoint.type_uris) + self.assertEqual(result.local_id, endpoint.local_id) + self.assertEqual(result.server_url, endpoint.server_url) + self.assertEqual(result.type_uris, endpoint.type_uris) - self.failUnlessEqual(result.claimed_id, claimed_id_frag) + self.assertEqual(result.claimed_id, claimed_id_frag) self.failUnlessLogEmpty() @@ -266,7 +266,7 @@ def test_endpointWithoutLocalID(self): to_match.local_id = "http://localhost:8000/id/id-jo" result = self.consumer._verifyDiscoverySingle(endpoint, to_match) # result should always be None, raises exception on failure. - self.failUnlessEqual(result, None) + self.assertIsNone(result) self.failUnlessLogEmpty() diff --git a/openid/test/test_xri.py b/openid/test/test_xri.py index 6e6ac8e2..a5f0bfaf 100644 --- a/openid/test/test_xri.py +++ b/openid/test/test_xri.py @@ -6,35 +6,31 @@ class XriDiscoveryTestCase(TestCase): def test_isXRI(self): i = xri.identifierScheme - self.failUnlessEqual(i('=john.smith'), 'XRI') - self.failUnlessEqual(i('@smiths/john'), 'XRI') - self.failUnlessEqual(i('smoker.myopenid.com'), 'URI') - self.failUnlessEqual(i('xri://=john'), 'XRI') - self.failUnlessEqual(i(''), 'URI') + self.assertEqual(i('=john.smith'), 'XRI') + self.assertEqual(i('@smiths/john'), 'XRI') + self.assertEqual(i('smoker.myopenid.com'), 'URI') + self.assertEqual(i('xri://=john'), 'XRI') + self.assertEqual(i(''), 'URI') class XriEscapingTestCase(TestCase): def test_escaping_percents(self): - self.failUnlessEqual(xri.escapeForIRI('@example/abc%2Fd/ef'), - '@example/abc%252Fd/ef') + self.assertEqual(xri.escapeForIRI('@example/abc%2Fd/ef'), '@example/abc%252Fd/ef') def test_escaping_xref(self): # no escapes esc = xri.escapeForIRI - self.failUnlessEqual('@example/foo/(@bar)', esc('@example/foo/(@bar)')) + self.assertEqual('@example/foo/(@bar)', esc('@example/foo/(@bar)')) # escape slashes - self.failUnlessEqual('@example/foo/(@bar%2Fbaz)', - esc('@example/foo/(@bar/baz)')) - self.failUnlessEqual('@example/foo/(@bar%2Fbaz)/(+a%2Fb)', - esc('@example/foo/(@bar/baz)/(+a/b)')) + self.assertEqual('@example/foo/(@bar%2Fbaz)', esc('@example/foo/(@bar/baz)')) + self.assertEqual('@example/foo/(@bar%2Fbaz)/(+a%2Fb)', esc('@example/foo/(@bar/baz)/(+a/b)')) # escape query ? and fragment # - self.failUnlessEqual('@example/foo/(@baz%3Fp=q%23r)?i=j#k', - esc('@example/foo/(@baz?p=q#r)?i=j#k')) + self.assertEqual('@example/foo/(@baz%3Fp=q%23r)?i=j#k', esc('@example/foo/(@baz?p=q#r)?i=j#k')) class XriTransformationTestCase(TestCase): def test_to_iri_normal(self): - self.failUnlessEqual(xri.toIRINormal('@example'), 'xri://@example') + self.assertEqual(xri.toIRINormal('@example'), 'xri://@example') try: unichr(0x10000) @@ -43,12 +39,12 @@ def test_to_iri_normal(self): def test_iri_to_url(self): s = u'l\xa1m' expected = 'l%C2%A1m' - self.failUnlessEqual(xri.iriToURI(s), expected) + self.assertEqual(xri.iriToURI(s), expected) else: def test_iri_to_url(self): s = u'l\xa1m\U00101010n' expected = 'l%C2%A1m%F4%81%80%90n' - self.failUnlessEqual(xri.iriToURI(s), expected) + self.assertEqual(xri.iriToURI(s), expected) class CanonicalIDTest(TestCase): @@ -57,7 +53,7 @@ def test(self): result = xri.providerIsAuthoritative(providerID, canonicalID) format = "%s providing %s, expected %s" message = format % (providerID, canonicalID, isAuthoritative) - self.failUnlessEqual(isAuthoritative, result, message) + self.assertEqual(result, isAuthoritative, message) return test @@ -75,7 +71,7 @@ class TestGetRootAuthority(TestCase): def mkTest(the_xri, expected_root): def test(self): actual_root = xri.rootAuthority(the_xri) - self.failUnlessEqual(actual_root, xri.XRI(expected_root)) + self.assertEqual(actual_root, xri.XRI(expected_root)) return test test_at = mkTest("@foo", "@") diff --git a/openid/test/test_xrires.py b/openid/test/test_xrires.py index b06a8b27..9a02bec5 100644 --- a/openid/test/test_xrires.py +++ b/openid/test/test_xrires.py @@ -17,17 +17,14 @@ def test_proxy_url(self): args_esc = "_xrd_r=application%2Fxrds%2Bxml&_xrd_t=" + ste pqu = self.proxy.queryURL h = self.proxy_url - self.failUnlessEqual(h + '=foo?' + args_esc, pqu('=foo', st)) - self.failUnlessEqual(h + '=foo/bar?baz&' + args_esc, - pqu('=foo/bar?baz', st)) - self.failUnlessEqual(h + '=foo/bar?baz=quux&' + args_esc, - pqu('=foo/bar?baz=quux', st)) - self.failUnlessEqual(h + '=foo/bar?mi=fa&so=la&' + args_esc, - pqu('=foo/bar?mi=fa&so=la', st)) + self.assertEqual(pqu('=foo', st), h + '=foo?' + args_esc) + self.assertEqual(pqu('=foo/bar?baz', st), h + '=foo/bar?baz&' + args_esc) + self.assertEqual(pqu('=foo/bar?baz=quux', st), h + '=foo/bar?baz=quux&' + args_esc) + self.assertEqual(pqu('=foo/bar?mi=fa&so=la', st), h + '=foo/bar?mi=fa&so=la&' + args_esc) # With no service endpoint selection. args_esc = "_xrd_r=application%2Fxrds%2Bxml%3Bsep%3Dfalse" - self.failUnlessEqual(h + '=foo?' + args_esc, pqu('=foo', None)) + self.assertEqual(pqu('=foo', None), h + '=foo?' + args_esc) def test_proxy_url_qmarks(self): st = self.servicetype @@ -35,6 +32,5 @@ def test_proxy_url_qmarks(self): args_esc = "_xrd_r=application%2Fxrds%2Bxml&_xrd_t=" + ste pqu = self.proxy.queryURL h = self.proxy_url - self.failUnlessEqual(h + '=foo/bar??' + args_esc, pqu('=foo/bar?', st)) - self.failUnlessEqual(h + '=foo/bar????' + args_esc, - pqu('=foo/bar???', st)) + self.assertEqual(pqu('=foo/bar?', st), h + '=foo/bar??' + args_esc) + self.assertEqual(pqu('=foo/bar???', st), h + '=foo/bar????' + args_esc) diff --git a/openid/test/test_yadis_discover.py b/openid/test/test_yadis_discover.py index e00f0e89..7284cbd1 100644 --- a/openid/test/test_yadis_discover.py +++ b/openid/test/test_yadis_discover.py @@ -120,21 +120,19 @@ def test(self): discover, input_url) else: result = discover(input_url) - self.failUnlessEqual(input_url, result.request_uri) + self.assertEqual(result.request_uri, input_url) msg = 'Identity URL mismatch: actual = %r, expected = %r' % ( result.normalized_uri, expected.normalized_uri) - self.failUnlessEqual( - expected.normalized_uri, result.normalized_uri, msg) + self.assertEqual(result.normalized_uri, expected.normalized_uri, msg) msg = 'Content mismatch: actual = %r, expected = %r' % ( result.response_text, expected.response_text) - self.failUnlessEqual( - expected.response_text, result.response_text, msg) + self.assertEqual(result.response_text, expected.response_text, msg) expected_keys = sorted(dir(expected)) actual_keys = sorted(dir(result)) - self.failUnlessEqual(actual_keys, expected_keys) + self.assertEqual(actual_keys, expected_keys) for k in dir(expected): if k.startswith('__') and k.endswith('__'): From 2d0035d061b00e33a04a5b4646d246f892a4a084 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 26 Jan 2018 09:19:44 +0100 Subject: [PATCH 034/151] Replace failUnlessRaises --- openid/test/test_association.py | 2 +- openid/test/test_association_response.py | 3 +- openid/test/test_auth_request.py | 2 +- openid/test/test_ax.py | 33 +++---- openid/test/test_consumer.py | 108 +++++++---------------- openid/test/test_discover.py | 19 ++-- openid/test/test_etxrd.py | 15 +--- openid/test/test_fetchers.py | 3 +- openid/test/test_kvform.py | 2 +- openid/test/test_message.py | 50 ++++------- openid/test/test_nonce.py | 2 +- openid/test/test_pape_draft2.py | 14 ++- openid/test/test_pape_draft5.py | 39 +++----- openid/test/test_rpverify.py | 3 +- openid/test/test_server.py | 80 +++++++---------- openid/test/test_services.py | 4 +- openid/test/test_sreg.py | 41 +++------ openid/test/test_verifydisco.py | 10 +-- openid/test/test_yadis_discover.py | 5 +- 19 files changed, 147 insertions(+), 288 deletions(-) diff --git a/openid/test/test_association.py b/openid/test/test_association.py index 12fd20cd..5caa4214 100644 --- a/openid/test/test_association.py +++ b/openid/test/test_association.py @@ -138,4 +138,4 @@ def test_aintGotSignedList(self): m.updateArgs(BARE_NS, {'xey': 'value'}) assoc = association.Association.fromExpiresIn( 3600, '{sha1}', 'very_secret', "HMAC-SHA1") - self.failUnlessRaises(ValueError, assoc.checkMessageSignature, m) + self.assertRaises(ValueError, assoc.checkMessageSignature, m) diff --git a/openid/test/test_association_response.py b/openid/test/test_association_response.py index 1b3e13c8..6333c13a 100644 --- a/openid/test/test_association_response.py +++ b/openid/test/test_association_response.py @@ -74,8 +74,7 @@ def mkExtractAssocMissingTest(keys): def test(self): msg = mkAssocResponse(*keys) - self.failUnlessRaises(KeyError, - self.consumer._extractAssociation, msg, None) + self.assertRaises(KeyError, self.consumer._extractAssociation, msg, None) return test diff --git a/openid/test/test_auth_request.py b/openid/test/test_auth_request.py index 6f173812..06ee02b4 100644 --- a/openid/test/test_auth_request.py +++ b/openid/test/test_auth_request.py @@ -173,7 +173,7 @@ def failUnlessHasRealm(self, msg): def test_setAnonymousFailsForOpenID1(self): """OpenID 1 requests MUST NOT be able to set anonymous to True""" self.failUnless(self.authreq.message.isOpenID1()) - self.failUnlessRaises(ValueError, self.authreq.setAnonymous, True) + self.assertRaises(ValueError, self.authreq.setAnonymous, True) self.authreq.setAnonymous(False) def test_identifierSelect(self): diff --git a/openid/test/test_ax.py b/openid/test/test_ax.py index 4c107e3f..1c94e4ba 100644 --- a/openid/test/test_ax.py +++ b/openid/test/test_ax.py @@ -59,9 +59,7 @@ def test_empty(self): self.assertEqual(uris, []) def test_undefined(self): - self.failUnlessRaises( - KeyError, - ax.toTypeURIs, self.aliases, 'http://janrain.com/') + self.assertRaises(KeyError, ax.toTypeURIs, self.aliases, 'http://janrain.com/') def test_one(self): uri = 'http://janrain.com/' @@ -88,7 +86,7 @@ class ParseAXValuesTest(unittest.TestCase): def failUnlessAXKeyError(self, ax_args): msg = ax.AXKeyValueMessage() - self.failUnlessRaises(KeyError, msg.parseExtensionArgs, ax_args) + self.assertRaises(KeyError, msg.parseExtensionArgs, ax_args) def failUnlessAXValues(self, ax_args, expected_args): """Fail unless parseExtensionArgs(ax_args) == expected_args.""" @@ -108,10 +106,7 @@ def test_countPresentButNotValue(self): def test_invalidCountValue(self): msg = ax.FetchRequest() - self.failUnlessRaises(ax.AXError, - msg.parseExtensionArgs, - {'type.foo': 'urn:foo', - 'count.foo': 'bogus'}) + self.assertRaises(ax.AXError, msg.parseExtensionArgs, {'type.foo': 'urn:foo', 'count.foo': 'bogus'}) def test_requestUnlimitedValues(self): msg = ax.FetchRequest() @@ -156,8 +151,7 @@ def test_invalidAlias(self): for typ in types: for input in inputs: msg = typ() - self.failUnlessRaises(ax.AXError, msg.parseExtensionArgs, - input) + self.assertRaises(ax.AXError, msg.parseExtensionArgs, input) def test_countPresentAndIsZero(self): self.failUnlessAXValues( @@ -228,7 +222,7 @@ def test_addTwice(self): attr = ax.AttrInfo(uri) self.msg.add(attr) - self.failUnlessRaises(KeyError, self.msg.add, attr) + self.assertRaises(KeyError, self.msg.add, attr) def test_getExtensionArgs_empty(self): expected_args = { @@ -298,8 +292,7 @@ def test_parseExtensionArgs_extraType(self): 'mode': 'fetch_request', 'type.' + self.alias_a: self.type_a, } - self.failUnlessRaises(ValueError, - self.msg.parseExtensionArgs, extension_args) + self.assertRaises(ValueError, self.msg.parseExtensionArgs, extension_args) def test_parseExtensionArgs(self): extension_args = { @@ -361,9 +354,7 @@ def test_openidNoRealm(self): 'ax.update_url': 'http://different.site/path', 'ax.mode': 'fetch_request', }) - self.failUnlessRaises(ax.AXError, - ax.FetchRequest.fromOpenIDRequest, - DummyRequest(openid_req_msg)) + self.assertRaises(ax.AXError, ax.FetchRequest.fromOpenIDRequest, DummyRequest(openid_req_msg)) def test_openidUpdateURLVerificationError(self): openid_req_msg = Message.fromOpenIDArgs({ @@ -375,9 +366,7 @@ def test_openidUpdateURLVerificationError(self): 'ax.mode': 'fetch_request', }) - self.failUnlessRaises(ax.AXError, - ax.FetchRequest.fromOpenIDRequest, - DummyRequest(openid_req_msg)) + self.assertRaises(ax.AXError, ax.FetchRequest.fromOpenIDRequest, DummyRequest(openid_req_msg)) def test_openidUpdateURLVerificationSuccess(self): openid_req_msg = Message.fromOpenIDArgs({ @@ -500,7 +489,7 @@ def test_getExtensionArgs_some_not_request(self): req = ax.FetchRequest() msg = ax.FetchResponse(request=req) msg.addValue(self.type_a, self.value_a) - self.failUnlessRaises(KeyError, msg.getExtensionArgs) + self.assertRaises(KeyError, msg.getExtensionArgs) def test_getSingle_success(self): self.msg.addValue(self.type_a, self.value_a) @@ -511,10 +500,10 @@ def test_getSingle_none(self): def test_getSingle_extra(self): self.msg.setValues(self.type_a, ['x', 'y']) - self.failUnlessRaises(ax.AXError, self.msg.getSingle, self.type_a) + self.assertRaises(ax.AXError, self.msg.getSingle, self.type_a) def test_get(self): - self.failUnlessRaises(KeyError, self.msg.get, self.type_a) + self.assertRaises(KeyError, self.msg.get, self.type_a) def test_fromSuccessResponseWithoutExtension(self): """return None for SuccessResponse with no AX paramaters.""" diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index dce7075b..eae07f33 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -259,7 +259,7 @@ def test_construct(self): self.failUnless(oidc.store is self.store_sentinel) def test_nostore(self): - self.failUnlessRaises(TypeError, GenericConsumer) + self.assertRaises(TypeError, GenericConsumer) class TestIdRes(unittest.TestCase, CatchLogs): @@ -318,9 +318,7 @@ def test_sign(self): def test_signFailsWithBadSig(self): self.message.setArg(OPENID_NS, 'sig', 'BAD SIGNATURE') - self.failUnlessRaises( - ProtocolError, self.consumer._idResCheckSignature, - self.message, self.endpoint.server_url) + self.assertRaises(ProtocolError, self.consumer._idResCheckSignature, self.message, self.endpoint.server_url) def test_stateless(self): # assoc_handle missing assoc, consumer._checkAuth returns goodthings @@ -335,9 +333,7 @@ def test_statelessRaisesError(self): # assoc_handle missing assoc, consumer._checkAuth returns goodthings self.message.setArg(OPENID_NS, "assoc_handle", "dumbHandle") self.consumer._checkAuth = lambda unused1, unused2: False - self.failUnlessRaises( - ProtocolError, self.consumer._idResCheckSignature, - self.message, self.endpoint.server_url) + self.assertRaises(ProtocolError, self.consumer._idResCheckSignature, self.message, self.endpoint.server_url) def test_stateless_noStore(self): # assoc_handle missing assoc, consumer._checkAuth returns goodthings @@ -354,9 +350,7 @@ def test_statelessRaisesError_noStore(self): self.message.setArg(OPENID_NS, "assoc_handle", "dumbHandle") self.consumer._checkAuth = lambda unused1, unused2: False self.consumer.store = None - self.failUnlessRaises( - ProtocolError, self.consumer._idResCheckSignature, - self.message, self.endpoint.server_url) + self.assertRaises(ProtocolError, self.consumer._idResCheckSignature, self.message, self.endpoint.server_url) class TestQueryFormat(TestIdRes): @@ -452,8 +446,7 @@ def test_idResMissingField(self): # is supposed to test for. status in FAILURE, but it's because # *check_auth* failed, not because it's missing an arg, exactly. message = Message.fromPostArgs({'openid.mode': 'id_res'}) - self.failUnlessRaises(ProtocolError, self.consumer._doIdRes, - message, self.endpoint, None) + self.assertRaises(ProtocolError, self.consumer._doIdRes, message, self.endpoint, None) def test_idResURLMismatch(self): class VerifiedError(Exception): @@ -475,9 +468,7 @@ def discoverAndVerify(claimed_id, _to_match_endpoints): }) self.consumer.store = GoodAssocStore() - self.failUnlessRaises(VerifiedError, - self.consumer.complete, - message, self.endpoint) + self.assertRaises(VerifiedError, self.consumer.complete, message, self.endpoint) self.failUnlessLogMatches('Error attempting to use stored', 'Attempting discovery') @@ -864,8 +855,7 @@ def test_consumerNonceOpenID2(self): self.return_to = 'http://rt.unittest/?nonce=%s' % (mkNonce(),) self.response = Message.fromOpenIDArgs( {'return_to': self.return_to, 'ns': OPENID2_NS}) - self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, - self.response, self.endpoint) + self.assertRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) self.failUnlessLogEmpty() def test_serverNonce(self): @@ -878,8 +868,7 @@ def test_serverNonceOpenID1(self): """OpenID 1 does not use server-generated nonce""" self.response = Message.fromOpenIDArgs( {'ns': OPENID1_NS, 'return_to': 'http://return.to/', 'response_nonce': mkNonce()}) - self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, - self.response, self.endpoint) + self.assertRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) self.failUnlessLogEmpty() def test_badNonce(self): @@ -897,8 +886,7 @@ def test_badNonce(self): stamp, salt = splitNonce(nonce) self.store.useNonce(self.server_url, stamp, salt) self.response = Message.fromOpenIDArgs({'response_nonce': nonce, 'ns': OPENID2_NS}) - self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, - self.response, self.endpoint) + self.assertRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) def test_successWithNoStore(self): """When there is no store, checking the nonce succeeds""" @@ -910,15 +898,13 @@ def test_successWithNoStore(self): def test_tamperedNonce(self): """Malformed nonce""" self.response = Message.fromOpenIDArgs({'ns': OPENID2_NS, 'response_nonce': 'malformed'}) - self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, - self.response, self.endpoint) + self.assertRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) def test_missingNonce(self): """no nonce parameter on the return_to""" self.response = Message.fromOpenIDArgs( {'return_to': self.return_to}) - self.failUnlessRaises(ProtocolError, self.consumer._idResCheckNonce, - self.response, self.endpoint) + self.assertRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) class CheckAuthDetectingConsumer(GenericConsumer): @@ -998,8 +984,7 @@ def test_expiredAssoc(self): 'openid.signed': 'identity,return_to', }) self.disableReturnToChecking() - self.failUnlessRaises(ProtocolError, self.consumer._doIdRes, - message, self.endpoint, None) + self.assertRaises(ProtocolError, self.consumer._doIdRes, message, self.endpoint, None) def test_newerAssoc(self): lifetime = 1000 @@ -1077,8 +1062,7 @@ def test_returnToArgsUnexpectedArg(self): 'foo': 'bar', } # no return value, success is assumed if there are no exceptions. - self.failUnlessRaises(ProtocolError, - self.consumer._verifyReturnToArgs, query) + self.assertRaises(ProtocolError, self.consumer._verifyReturnToArgs, query) def test_returnToMismatch(self): query = { @@ -1086,18 +1070,15 @@ def test_returnToMismatch(self): 'openid.return_to': 'http://example.com/?foo=bar', } # fail, query has no key 'foo'. - self.failUnlessRaises(ValueError, - self.consumer._verifyReturnToArgs, query) + self.assertRaises(ValueError, self.consumer._verifyReturnToArgs, query) query['foo'] = 'baz' # fail, values for 'foo' do not match. - self.failUnlessRaises(ValueError, - self.consumer._verifyReturnToArgs, query) + self.assertRaises(ValueError, self.consumer._verifyReturnToArgs, query) def test_noReturnTo(self): query = {'openid.mode': 'id_res'} - self.failUnlessRaises(ValueError, - self.consumer._verifyReturnToArgs, query) + self.assertRaises(ValueError, self.consumer._verifyReturnToArgs, query) def test_completeBadReturnTo(self): """Test GenericConsumer.complete()'s handling of bad return_to @@ -1282,11 +1263,8 @@ def test_error_404(self): """404 from a kv post raises HTTPFetchingError""" self.fetcher.response = HTTPResponse( "http://some_url", 404, {'Hea': 'der'}, 'blah:blah\n') - self.failUnlessRaises( - fetchers.HTTPFetchingError, - self.consumer._makeKVPost, - Message.fromPostArgs({'mode': 'associate'}), - "http://server_url") + self.assertRaises(fetchers.HTTPFetchingError, self.consumer._makeKVPost, + Message.fromPostArgs({'mode': 'associate'}), "http://server_url") def test_error_exception_unwrapped(self): """Ensure that exceptions are bubbled through from fetchers @@ -1294,21 +1272,16 @@ def test_error_exception_unwrapped(self): """ self.fetcher = ExceptionRaisingMockFetcher() fetchers.setDefaultFetcher(self.fetcher, wrap_exceptions=False) - self.failUnlessRaises(self.fetcher.MyException, - self.consumer._makeKVPost, - Message.fromPostArgs({'mode': 'associate'}), - "http://server_url") + self.assertRaises(self.fetcher.MyException, self.consumer._makeKVPost, + Message.fromPostArgs({'mode': 'associate'}), "http://server_url") # exception fetching returns no association e = OpenIDServiceEndpoint() e.server_url = 'some://url' - self.failUnlessRaises(self.fetcher.MyException, - self.consumer._getAssociation, e) + self.assertRaises(self.fetcher.MyException, self.consumer._getAssociation, e) - self.failUnlessRaises(self.fetcher.MyException, - self.consumer._checkAuth, - Message.fromPostArgs({'openid.signed': ''}), - 'some://url') + self.assertRaises(self.fetcher.MyException, self.consumer._checkAuth, + Message.fromPostArgs({'openid.signed': ''}), 'some://url') def test_error_exception_wrapped(self): """Ensure that openid.fetchers.HTTPFetchingError is caught by @@ -1317,10 +1290,8 @@ def test_error_exception_wrapped(self): self.fetcher = ExceptionRaisingMockFetcher() # This will wrap exceptions! fetchers.setDefaultFetcher(self.fetcher) - self.failUnlessRaises(fetchers.HTTPFetchingError, - self.consumer._makeKVPost, - Message.fromOpenIDArgs({'mode': 'associate'}), - "http://server_url") + self.assertRaises(fetchers.HTTPFetchingError, self.consumer._makeKVPost, + Message.fromOpenIDArgs({'mode': 'associate'}), "http://server_url") # exception fetching returns no association e = OpenIDServiceEndpoint() @@ -1686,8 +1657,7 @@ def verifyDiscoveryResults(identifier, endpoint): raise DiscoveryFailure("PHREAK!", None) self.consumer._verifyDiscoveryResults = verifyDiscoveryResults self.consumer._checkReturnTo = lambda unused1, unused2: True - self.failUnlessRaises(DiscoveryFailure, self.consumer._doIdRes, - message, self.endpoint, None) + self.assertRaises(DiscoveryFailure, self.consumer._doIdRes, message, self.endpoint, None) def failUnlessSuccess(self, response): if response.status != SUCCESS: @@ -1777,9 +1747,7 @@ def discoverAndVerify(claimed_id, to_match_endpoints): def test_nothingDiscovered(self): # a set of no things. self.services = [] - self.failUnlessRaises(DiscoveryFailure, - self.consumer._verifyDiscoveryResults, - self.message, self.endpoint) + self.assertRaises(DiscoveryFailure, self.consumer._verifyDiscoveryResults, self.message, self.endpoint) def discoveryFunc(self, identifier): return identifier, self.services @@ -1885,24 +1853,24 @@ def testExtractSecret(self): def testAbsentServerPublic(self): self.msg.setArg(OPENID_NS, 'enc_mac_key', self.enc_mac_key) - self.failUnlessRaises(KeyError, self.consumer_session.extractSecret, self.msg) + self.assertRaises(KeyError, self.consumer_session.extractSecret, self.msg) def testAbsentMacKey(self): self.msg.setArg(OPENID_NS, 'dh_server_public', self.dh_server_public) - self.failUnlessRaises(KeyError, self.consumer_session.extractSecret, self.msg) + self.assertRaises(KeyError, self.consumer_session.extractSecret, self.msg) def testInvalidBase64Public(self): self.msg.setArg(OPENID_NS, 'dh_server_public', 'n o t b a s e 6 4.') self.msg.setArg(OPENID_NS, 'enc_mac_key', self.enc_mac_key) - self.failUnlessRaises(ValueError, self.consumer_session.extractSecret, self.msg) + self.assertRaises(ValueError, self.consumer_session.extractSecret, self.msg) def testInvalidBase64MacKey(self): self.msg.setArg(OPENID_NS, 'dh_server_public', self.dh_server_public) self.msg.setArg(OPENID_NS, 'enc_mac_key', 'n o t base 64') - self.failUnlessRaises(ValueError, self.consumer_session.extractSecret, self.msg) + self.assertRaises(ValueError, self.consumer_session.extractSecret, self.msg) class TestOpenID1SHA1(TestDiffieHellmanResponseParameters, unittest.TestCase): @@ -1955,9 +1923,7 @@ def test_beginWithoutDiscoveryAnonymousFail(self): def bogusBegin(unused): return NonAnonymousAuthRequest() consumer.consumer.begin = bogusBegin - self.failUnlessRaises( - ProtocolError, - consumer.beginWithoutDiscovery, None) + self.assertRaises(ProtocolError, consumer.beginWithoutDiscovery, None) class TestDiscoverAndVerify(unittest.TestCase): @@ -1971,11 +1937,7 @@ def dummyDiscover(unused_identifier): self.to_match = OpenIDServiceEndpoint() def failUnlessDiscoveryFailure(self): - self.failUnlessRaises( - DiscoveryFailure, - self.consumer._discoverAndVerify, - 'http://claimed-id.com/', - [self.to_match]) + self.assertRaises(DiscoveryFailure, self.consumer._discoverAndVerify, 'http://claimed-id.com/', [self.to_match]) def test_noServices(self): """Discovery returning no results results in a @@ -2058,9 +2020,7 @@ def test_500(self): response = HTTPResponse() response.status = 500 response.body = "foo:bar\nbaz:quux\n" - self.failUnlessRaises(fetchers.HTTPFetchingError, - _httpResponseToMessage, response, - self.server_url) + self.assertRaises(fetchers.HTTPFetchingError, _httpResponseToMessage, response, self.server_url) if __name__ == '__main__': diff --git a/openid/test/test_discover.py b/openid/test/test_discover.py index 4a214686..94a64702 100644 --- a/openid/test/test_discover.py +++ b/openid/test/test_discover.py @@ -229,8 +229,7 @@ def _discover(self, content_type, data, return services def test_404(self): - self.failUnlessRaises(DiscoveryFailure, - discover.discover, self.id_url + '/404') + self.assertRaises(DiscoveryFailure, discover.discover, self.id_url + '/404') def test_unicode(self): """ @@ -246,7 +245,7 @@ def test_unicode_undecodable_html(self): Check page with unicode and HTML entities that can not be decoded """ data = readDataFile('unicode2.html') - self.failUnlessRaises(UnicodeDecodeError, data.decode, 'utf-8') + self.assertRaises(UnicodeDecodeError, data.decode, 'utf-8') self._discover(content_type='text/html;charset=utf-8', data=data, expected_services=0) def test_unicode_undecodable_html2(self): @@ -258,7 +257,7 @@ def test_unicode_undecodable_html2(self): 'application/xrds+xml', readDataFile('yadis_idp.xml')) data = readDataFile('unicode3.html') - self.failUnlessRaises(UnicodeDecodeError, data.decode, 'utf-8') + self.assertRaises(UnicodeDecodeError, data.decode, 'utf-8') self._discover(content_type='text/html;charset=utf-8', data=data, expected_services=1) def test_noOpenID(self): @@ -462,10 +461,8 @@ def test_yadis2OPDelegate(self): ) def test_yadis2BadLocalID(self): - self.failUnlessRaises(DiscoveryFailure, self._discover, - content_type='application/xrds+xml', - data=readDataFile('yadis_2_bad_local_id.xml'), - expected_services=1) + self.assertRaises(DiscoveryFailure, self._discover, content_type='application/xrds+xml', + data=readDataFile('yadis_2_bad_local_id.xml'), expected_services=1) def test_yadis1And2(self): services = self._discover( @@ -485,10 +482,8 @@ def test_yadis1And2(self): ) def test_yadis1And2BadLocalID(self): - self.failUnlessRaises(DiscoveryFailure, self._discover, - content_type='application/xrds+xml', - data=readDataFile('openid_1_and_2_xrds_bad_delegate.xml'), - expected_services=1) + self.assertRaises(DiscoveryFailure, self._discover, content_type='application/xrds+xml', + data=readDataFile('openid_1_and_2_xrds_bad_delegate.xml'), expected_services=1) class MockFetcherForXRIProxy(object): diff --git a/openid/test/test_etxrd.py b/openid/test/test_etxrd.py index 3e349dcc..59c84115 100644 --- a/openid/test/test_etxrd.py +++ b/openid/test/test_etxrd.py @@ -103,25 +103,19 @@ def testNoXRDS(self): """Make sure that we get an exception when an XRDS element is not present""" self.xmldoc = file(NOXRDS_FILE).read() - self.failUnlessRaises( - etxrd.XRDSError, - services.applyFilter, self.yadis_url, self.xmldoc, None) + self.assertRaises(etxrd.XRDSError, services.applyFilter, self.yadis_url, self.xmldoc, None) def testEmpty(self): """Make sure that we get an exception when an XRDS element is not present""" self.xmldoc = '' - self.failUnlessRaises( - etxrd.XRDSError, - services.applyFilter, self.yadis_url, self.xmldoc, None) + self.assertRaises(etxrd.XRDSError, services.applyFilter, self.yadis_url, self.xmldoc, None) def testNoXRD(self): """Make sure that we get an exception when there is no XRD element present.""" self.xmldoc = file(NOXRD_FILE).read() - self.failUnlessRaises( - etxrd.XRDSError, - services.applyFilter, self.yadis_url, self.xmldoc, None) + self.assertRaises(etxrd.XRDSError, services.applyFilter, self.yadis_url, self.xmldoc, None) class TestCanonicalID(unittest.TestCase): @@ -186,8 +180,7 @@ def _getCanonicalID(self, iname, xrds, expectedID): cid = etxrd.getCanonicalID(iname, xrds) self.assertEqual(cid, expectedID and xri.XRI(expectedID)) elif issubclass(expectedID, etxrd.XRDSError): - self.failUnlessRaises(expectedID, etxrd.getCanonicalID, - iname, xrds) + self.assertRaises(expectedID, etxrd.getCanonicalID, iname, xrds) else: self.fail("Don't know how to test for expected value %r" % (expectedID,)) diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index 069e0fa9..2ca208b2 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -252,8 +252,7 @@ def test_wrappedByDefault(self): fetchers.ExceptionWrappingFetcher), default_fetcher) - self.failUnlessRaises(fetchers.HTTPFetchingError, - fetchers.fetch, 'http://invalid.janrain.com/') + self.assertRaises(fetchers.HTTPFetchingError, fetchers.fetch, 'http://invalid.janrain.com/') def test_notWrapped(self): """Make sure that if we set a non-wrapped fetcher as default, diff --git a/openid/test/test_kvform.py b/openid/test/test_kvform.py index b3fb6dd2..09532db9 100644 --- a/openid/test/test_kvform.py +++ b/openid/test/test_kvform.py @@ -147,7 +147,7 @@ class KVExcTest(unittest.TestCase): def runTest(self): for kv_data in kvexc_cases: - self.failUnlessRaises(ValueError, kvform.seqToKV, kv_data) + self.assertRaises(ValueError, kvform.seqToKV, kv_data) class GeneralTest(KVBaseTest): diff --git a/openid/test/test_message.py b/openid/test/test_message.py index f0570437..98278f24 100644 --- a/openid/test/test_message.py +++ b/openid/test/test_message.py @@ -13,8 +13,7 @@ def test(self): self.assertEqual(self.msg.getArg(ns, key), expected) if expected is None: self.assertEqual(self.msg.getArg(ns, key, a_default), a_default) - self.failUnlessRaises( - KeyError, self.msg.getArg, ns, key, message.no_default) + self.assertRaises(KeyError, self.msg.getArg, ns, key, message.no_default) else: self.assertEqual(self.msg.getArg(ns, key, a_default), expected) self.assertEqual(self.msg.getArg(ns, key, message.no_default), expected) @@ -50,8 +49,7 @@ def test_getKeyOpenID(self): # exception. I'm not sure which one is more right, since this # case should only happen when you're building a message from # scratch and so have no default namespace. - self.failUnlessRaises(message.UndefinedOpenIDNamespace, - self.msg.getKey, message.OPENID_NS, 'foo') + self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.getKey, message.OPENID_NS, 'foo') def test_getKeyBARE(self): self.assertEqual(self.msg.getKey(message.BARE_NS, 'foo'), 'foo') @@ -70,8 +68,7 @@ def test_hasKey(self): # exception. I'm not sure which one is more right, since this # case should only happen when you're building a message from # scratch and so have no default namespace. - self.failUnlessRaises(message.UndefinedOpenIDNamespace, - self.msg.hasKey, message.OPENID_NS, 'foo') + self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.hasKey, message.OPENID_NS, 'foo') def test_hasKeyBARE(self): self.assertFalse(self.msg.hasKey(message.BARE_NS, 'foo')) @@ -93,16 +90,14 @@ def test_getAliasedArgSuccess(self): def test_getAliasedArgFailure(self): msg = message.Message.fromPostArgs({'openid.test.flub': 'bogus'}) - self.assertRaises(KeyError, - msg.getAliasedArg, 'ns.test', message.no_default) + self.assertRaises(KeyError, msg.getAliasedArg, 'ns.test', message.no_default) def test_getArg(self): # Could reasonably return None instead of raising an # exception. I'm not sure which one is more right, since this # case should only happen when you're building a message from # scratch and so have no default namespace. - self.failUnlessRaises(message.UndefinedOpenIDNamespace, - self.msg.getArg, message.OPENID_NS, 'foo') + self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.getArg, message.OPENID_NS, 'foo') test_getArgBARE = mkGetArgTest(message.BARE_NS, 'foo') test_getArgNS1 = mkGetArgTest(message.OPENID1_NS, 'foo') @@ -114,8 +109,7 @@ def test_getArgs(self): # exception. I'm not sure which one is more right, since this # case should only happen when you're building a message from # scratch and so have no default namespace. - self.failUnlessRaises(message.UndefinedOpenIDNamespace, - self.msg.getArgs, message.OPENID_NS) + self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.getArgs, message.OPENID_NS) def test_getArgsBARE(self): self.assertEqual(self.msg.getArgs(message.BARE_NS), {}) @@ -130,9 +124,8 @@ def test_getArgsNS3(self): self.assertEqual(self.msg.getArgs('urn:nothing-significant'), {}) def test_updateArgs(self): - self.failUnlessRaises(message.UndefinedOpenIDNamespace, - self.msg.updateArgs, message.OPENID_NS, - {'does not': 'matter'}) + self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.updateArgs, message.OPENID_NS, + {'does not': 'matter'}) def _test_updateArgsNS(self, ns): update_args = { @@ -157,9 +150,7 @@ def test_updateArgsNS3(self): self._test_updateArgsNS('urn:nothing-significant') def test_setArg(self): - self.failUnlessRaises(message.UndefinedOpenIDNamespace, - self.msg.setArg, message.OPENID_NS, - 'does not', 'matter') + self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.setArg, message.OPENID_NS, 'does not', 'matter') def _test_setArgNS(self, ns): key = 'Camper van Beethoven' @@ -181,8 +172,7 @@ def test_setArgNS3(self): self._test_setArgNS('urn:nothing-significant') def test_setArgToNone(self): - self.failUnlessRaises(AssertionError, self.msg.setArg, - message.OPENID1_NS, 'op_endpoint', None) + self.assertRaises(AssertionError, self.msg.setArg, message.OPENID1_NS, 'op_endpoint', None) def test_delArg(self): # Could reasonably raise KeyError instead of raising @@ -190,12 +180,11 @@ def test_delArg(self): # right, since this case should only happen when you're # building a message from scratch and so have no default # namespace. - self.failUnlessRaises(message.UndefinedOpenIDNamespace, - self.msg.delArg, message.OPENID_NS, 'key') + self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.delArg, message.OPENID_NS, 'key') def _test_delArgNS(self, ns): key = 'Camper van Beethoven' - self.failUnlessRaises(KeyError, self.msg.delArg, ns, key) + self.assertRaises(KeyError, self.msg.delArg, ns, key) def test_delArgBARE(self): self._test_delArgNS(message.BARE_NS) @@ -354,7 +343,7 @@ def _test_delArgNS(self, ns): key = 'Camper van Beethoven' value = 'David Lowery' - self.failUnlessRaises(KeyError, self.msg.delArg, ns, key) + self.assertRaises(KeyError, self.msg.delArg, ns, key) self.msg.setArg(ns, key, value) self.assertEqual(self.msg.getArg(ns, key), value) self.msg.delArg(ns, key) @@ -586,8 +575,7 @@ def test_badAlias(self): # .fromPostArgs covers .fromPostArgs, .fromOpenIDArgs, # ._fromOpenIDArgs, and .fromOpenIDArgs (since it calls # .fromPostArgs). - self.failUnlessRaises(AssertionError, self.msg.fromPostArgs, - args) + self.assertRaises(AssertionError, self.msg.fromPostArgs, args) def test_mysterious_missing_namespace_bug(self): """A failing test for bug #112""" @@ -665,7 +653,7 @@ def test_repetitive_namespaces(self): 'openid.pape.auth_time': '2008-01-28T20:42:36Z', 'openid.pape.nist_auth_level': '0', } - self.failUnlessRaises(message.InvalidNamespace, message.Message.fromPostArgs, args) + self.assertRaises(message.InvalidNamespace, message.Message.fromPostArgs, args) def test_implicit_sreg_ns(self): openid_args = {'sreg.email': 'a@b.com'} @@ -680,7 +668,7 @@ def _test_delArgNS(self, ns): key = 'Camper van Beethoven' value = 'David Lowery' - self.failUnlessRaises(KeyError, self.msg.delArg, ns, key) + self.assertRaises(KeyError, self.msg.delArg, ns, key) self.msg.setArg(ns, key, value) self.assertEqual(self.msg.getArg(ns, key), value) self.msg.delArg(ns, key) @@ -713,8 +701,7 @@ def test_overwriteExtensionArg(self): self.failUnless(self.msg.getArg(ns, key) == value_2) def test_argList(self): - self.failUnlessRaises(TypeError, self.msg.fromPostArgs, - {'arg': [1, 2, 3]}) + self.assertRaises(TypeError, self.msg.fromPostArgs, {'arg': [1, 2, 3]}) def test_isOpenID1(self): self.failIf(self.msg.isOpenID1()) @@ -883,8 +870,7 @@ def test_setOpenIDNamespace_invalid(self): ] for x in invalid_things: - self.failUnlessRaises(message.InvalidOpenIDNamespace, - m.setOpenIDNamespace, x, False) + self.assertRaises(message.InvalidOpenIDNamespace, m.setOpenIDNamespace, x, False) def test_isOpenID1(self): v1_namespaces = [ diff --git a/openid/test/test_nonce.py b/openid/test/test_nonce.py index 1817f821..3e6c5fe5 100644 --- a/openid/test/test_nonce.py +++ b/openid/test/test_nonce.py @@ -48,7 +48,7 @@ class BadSplitTest(unittest.TestCase): def test(self): for nonce_str in self.cases: - self.failUnlessRaises(ValueError, splitNonce, nonce_str) + self.assertRaises(ValueError, splitNonce, nonce_str) class CheckTimestampTest(unittest.TestCase): diff --git a/openid/test/test_pape_draft2.py b/openid/test/test_pape_draft2.py index 9f54f88f..53e2c553 100644 --- a/openid/test/test_pape_draft2.py +++ b/openid/test/test_pape_draft2.py @@ -131,15 +131,15 @@ def test_getExtensionArgs(self): def test_getExtensionArgs_error_auth_age(self): self.req.auth_time = "long ago" - self.failUnlessRaises(ValueError, self.req.getExtensionArgs) + self.assertRaises(ValueError, self.req.getExtensionArgs) def test_getExtensionArgs_error_nist_auth_level(self): self.req.nist_auth_level = "high as a kite" - self.failUnlessRaises(ValueError, self.req.getExtensionArgs) + self.assertRaises(ValueError, self.req.getExtensionArgs) self.req.nist_auth_level = 5 - self.failUnlessRaises(ValueError, self.req.getExtensionArgs) + self.assertRaises(ValueError, self.req.getExtensionArgs) self.req.nist_auth_level = -1 - self.failUnlessRaises(ValueError, self.req.getExtensionArgs) + self.assertRaises(ValueError, self.req.getExtensionArgs) def test_parseExtensionArgs(self): args = {'auth_policies': 'http://foo http://bar', @@ -156,15 +156,13 @@ def test_parseExtensionArgs_empty(self): def test_parseExtensionArgs_strict_bogus1(self): args = {'auth_policies': 'http://foo http://bar', 'auth_time': 'yesterday'} - self.failUnlessRaises(ValueError, self.req.parseExtensionArgs, - args, True) + self.assertRaises(ValueError, self.req.parseExtensionArgs, args, True) def test_parseExtensionArgs_strict_bogus2(self): args = {'auth_policies': 'http://foo http://bar', 'auth_time': '1970-01-01T00:00:00Z', 'nist_auth_level': 'some'} - self.failUnlessRaises(ValueError, self.req.parseExtensionArgs, - args, True) + self.assertRaises(ValueError, self.req.parseExtensionArgs, args, True) def test_parseExtensionArgs_strict_good(self): args = {'auth_policies': 'http://foo http://bar', diff --git a/openid/test/test_pape_draft5.py b/openid/test/test_pape_draft5.py index ee2be1ae..0b0f10ea 100644 --- a/openid/test/test_pape_draft5.py +++ b/openid/test/test_pape_draft5.py @@ -40,9 +40,7 @@ def test_addAuthLevel(self): self.req.addAuthLevel('http://example.com/', 'example') self.assertEqual(self.req.preferred_auth_level_types, ['http://example.com/', 'http://example.com/1']) - self.failUnlessRaises(KeyError, - self.req.addAuthLevel, - 'http://example.com/2', 'example') + self.assertRaises(KeyError, self.req.addAuthLevel, 'http://example.com/2', 'example') # alias is None; we expect a new one to be generated. uri = 'http://another.example.com/' @@ -131,9 +129,7 @@ def test_parseExtensionArgsWithAuthLevels_openID1(self): self.assertEqual(self.req.preferred_auth_level_types, []) self.req = pape.Request() - self.failUnlessRaises(ValueError, - self.req.parseExtensionArgs, - request_args, is_openid1=False, strict=True) + self.assertRaises(ValueError, self.req.parseExtensionArgs, request_args, is_openid1=False, strict=True) def test_parseExtensionArgs_ignoreBadAuthLevels(self): request_args = {'preferred_auth_level_types': 'monkeys'} @@ -142,8 +138,7 @@ def test_parseExtensionArgs_ignoreBadAuthLevels(self): def test_parseExtensionArgs_strictBadAuthLevels(self): request_args = {'preferred_auth_level_types': 'monkeys'} - self.failUnlessRaises(ValueError, self.req.parseExtensionArgs, - request_args, is_openid1=False, strict=True) + self.assertRaises(ValueError, self.req.parseExtensionArgs, request_args, is_openid1=False, strict=True) def test_parseExtensionArgs(self): args = {'preferred_auth_policies': 'http://foo http://bar', @@ -155,8 +150,7 @@ def test_parseExtensionArgs(self): def test_parseExtensionArgs_strict_bad_auth_age(self): args = {'max_auth_age': 'not an int'} - self.assertRaises(ValueError, self.req.parseExtensionArgs, args, - is_openid1=False, strict=True) + self.assertRaises(ValueError, self.req.parseExtensionArgs, args, is_openid1=False, strict=True) def test_parseExtensionArgs_empty(self): self.req.parseExtensionArgs({}, False) @@ -233,8 +227,7 @@ def test_add_policy_uri(self): self.resp.addPolicyURI(pape.AUTH_MULTI_FACTOR) self.assertEqual(self.resp.auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) - self.failUnlessRaises(RuntimeError, self.resp.addPolicyURI, - pape.AUTH_NONE) + self.assertRaises(RuntimeError, self.resp.addPolicyURI, pape.AUTH_NONE) def test_getExtensionArgs(self): self.assertEqual(self.resp.getExtensionArgs(), {'auth_policies': pape.AUTH_NONE}) @@ -252,7 +245,7 @@ def test_getExtensionArgs(self): def test_getExtensionArgs_error_auth_age(self): self.resp.auth_time = "long ago" - self.failUnlessRaises(ValueError, self.resp.getExtensionArgs) + self.assertRaises(ValueError, self.resp.getExtensionArgs) def test_parseExtensionArgs(self): args = {'auth_policies': 'http://foo http://bar', @@ -273,9 +266,7 @@ def test_parseExtensionArgs_old_none(self): def test_parseExtensionArgs_old_none_strict(self): args = {'auth_policies': 'none'} - self.failUnlessRaises( - ValueError, - self.resp.parseExtensionArgs, args, is_openid1=False, strict=True) + self.assertRaises(ValueError, self.resp.parseExtensionArgs, args, is_openid1=False, strict=True) def test_parseExtensionArgs_empty(self): self.resp.parseExtensionArgs({}, is_openid1=False) @@ -283,9 +274,7 @@ def test_parseExtensionArgs_empty(self): self.assertEqual(self.resp.auth_policies, []) def test_parseExtensionArgs_empty_strict(self): - self.failUnlessRaises( - ValueError, - self.resp.parseExtensionArgs, {}, is_openid1=False, strict=True) + self.assertRaises(ValueError, self.resp.parseExtensionArgs, {}, is_openid1=False, strict=True) def test_parseExtensionArgs_ignore_superfluous_none(self): policies = [pape.AUTH_NONE, pape.AUTH_MULTI_FACTOR_PHYSICAL] @@ -305,14 +294,12 @@ def test_parseExtensionArgs_none_strict(self): 'auth_policies': ' '.join(policies), } - self.failUnlessRaises(ValueError, self.resp.parseExtensionArgs, - args, is_openid1=False, strict=True) + self.assertRaises(ValueError, self.resp.parseExtensionArgs, args, is_openid1=False, strict=True) def test_parseExtensionArgs_strict_bogus1(self): args = {'auth_policies': 'http://foo http://bar', 'auth_time': 'yesterday'} - self.failUnlessRaises(ValueError, self.resp.parseExtensionArgs, - args, is_openid1=False, strict=True) + self.assertRaises(ValueError, self.resp.parseExtensionArgs, args, is_openid1=False, strict=True) def test_parseExtensionArgs_openid1_strict(self): args = {'auth_level.nist': '0', @@ -328,8 +315,7 @@ def test_parseExtensionArgs_strict_no_namespace_decl_openid2(self): args = {'auth_policies': pape.AUTH_NONE, 'auth_level.nist': '0', } - self.failUnlessRaises(ValueError, self.resp.parseExtensionArgs, - args, is_openid1=False, strict=True) + self.assertRaises(ValueError, self.resp.parseExtensionArgs, args, is_openid1=False, strict=True) def test_parseExtensionArgs_nostrict_no_namespace_decl_openid2(self): # Test the case where the namespace is not declared for an @@ -340,8 +326,7 @@ def test_parseExtensionArgs_nostrict_no_namespace_decl_openid2(self): self.resp.parseExtensionArgs(args, is_openid1=False, strict=False) # There is no namespace declaration for this auth level. - self.failUnlessRaises(KeyError, self.resp.getAuthLevel, - pape.LEVELS_NIST) + self.assertRaises(KeyError, self.resp.getAuthLevel, pape.LEVELS_NIST) def test_parseExtensionArgs_strict_good(self): args = {'auth_policies': 'http://foo http://bar', diff --git a/openid/test/test_rpverify.py b/openid/test/test_rpverify.py index 82432f32..b2e2e23f 100644 --- a/openid/test/test_rpverify.py +++ b/openid/test/test_rpverify.py @@ -73,8 +73,7 @@ def failUnlessXRDSHasReturnURLs(self, data, expected_return_urls): def failUnlessDiscoveryFailure(self, text): self.data = text - self.failUnlessRaises( - DiscoveryFailure, trustroot.getAllowedReturnURLs, self.disco_url) + self.assertRaises(DiscoveryFailure, trustroot.getAllowedReturnURLs, self.disco_url) def test_empty(self): self.failUnlessDiscoveryFailure('') diff --git a/openid/test/test_server.py b/openid/test/test_server.py index 6433f0d2..e6e9d431 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -146,14 +146,14 @@ def test_irrelevant(self): 'pony': 'spotted', 'sreg.mutant_power': 'decaffinator', } - self.failUnlessRaises(server.ProtocolError, self.decode, args) + self.assertRaises(server.ProtocolError, self.decode, args) def test_bad(self): args = { 'openid.mode': 'twos-compliment', 'openid.pants': 'zippered', } - self.failUnlessRaises(server.ProtocolError, self.decode, args) + self.assertRaises(server.ProtocolError, self.decode, args) def test_dictOfLists(self): args = { @@ -233,7 +233,7 @@ def test_checkidSetupNoClaimedIDOpenID2(self): 'openid.return_to': self.rt_url, 'openid.realm': self.tr_url, } - self.failUnlessRaises(server.ProtocolError, self.decode, args) + self.assertRaises(server.ProtocolError, self.decode, args) def test_checkidSetupNoIdentityOpenID2(self): args = { @@ -261,7 +261,7 @@ def test_checkidSetupNoReturnOpenID1(self): 'openid.assoc_handle': self.assoc_handle, 'openid.trust_root': self.tr_url, } - self.failUnlessRaises(server.ProtocolError, self.decode, args) + self.assertRaises(server.ProtocolError, self.decode, args) def test_checkidSetupNoReturnOpenID2(self): """Make sure an OpenID 2 request with no return_to can be @@ -294,7 +294,7 @@ def test_checkidSetupRealmRequiredOpenID2(self): 'openid.identity': self.id_url, 'openid.assoc_handle': self.assoc_handle, } - self.failUnlessRaises(server.ProtocolError, self.decode, args) + self.assertRaises(server.ProtocolError, self.decode, args) def test_checkidSetupBadReturn(self): args = { @@ -352,7 +352,7 @@ def test_checkAuthMissingSignature(self): 'openid.bar': 'signedval2', 'openid.baz': 'unsigned', } - self.failUnlessRaises(server.ProtocolError, self.decode, args) + self.assertRaises(server.ProtocolError, self.decode, args) def test_checkAuthAndInvalidate(self): args = { @@ -390,7 +390,7 @@ def test_associateDHMissingKey(self): 'openid.session_type': 'DH-SHA1', } # Using DH-SHA1 without supplying dh_consumer_public is an error. - self.failUnlessRaises(server.ProtocolError, self.decode, args) + self.assertRaises(server.ProtocolError, self.decode, args) def test_associateDHpubKeyNotB64(self): args = { @@ -398,7 +398,7 @@ def test_associateDHpubKeyNotB64(self): 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "donkeydonkeydonkey", } - self.failUnlessRaises(server.ProtocolError, self.decode, args) + self.assertRaises(server.ProtocolError, self.decode, args) def test_associateDHModGen(self): # test dh with non-default but valid values for dh_modulus and dh_gen @@ -427,7 +427,7 @@ def test_associateDHCorruptModGen(self): 'openid.dh_modulus': 'pizza', 'openid.dh_gen': 'gnocchi', } - self.failUnlessRaises(server.ProtocolError, self.decode, args) + self.assertRaises(server.ProtocolError, self.decode, args) def test_associateDHMissingModGen(self): # test dh with non-default but valid values for dh_modulus and dh_gen @@ -437,7 +437,7 @@ def test_associateDHMissingModGen(self): 'openid.dh_consumer_public': "Rzup9265tw==", 'openid.dh_modulus': 'pizza', } - self.failUnlessRaises(server.ProtocolError, self.decode, args) + self.assertRaises(server.ProtocolError, self.decode, args) # def test_associateDHInvalidModGen(self): @@ -450,7 +450,7 @@ def test_associateDHMissingModGen(self): # 'openid.dh_modulus': cryptutil.longToBase64(9), # 'openid.dh_gen': cryptutil.longToBase64(27) , # } -# self.failUnlessRaises(server.ProtocolError, self.decode, args) +# self.assertRaises(server.ProtocolError, self.decode, args) # test_associateDHInvalidModGen.todo = "low-priority feature" def test_associateWeirdSession(self): @@ -459,7 +459,7 @@ def test_associateWeirdSession(self): 'openid.session_type': 'FLCL6', 'openid.dh_consumer_public': "YQ==\n", } - self.failUnlessRaises(server.ProtocolError, self.decode, args) + self.assertRaises(server.ProtocolError, self.decode, args) def test_associatePlain(self): args = { @@ -476,7 +476,7 @@ def test_nomode(self): 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "my public keeey", } - self.failUnlessRaises(server.ProtocolError, self.decode, args) + self.assertRaises(server.ProtocolError, self.decode, args) def test_invalidns(self): args = {'openid.ns': 'Tuesday', @@ -730,7 +730,7 @@ def test_unencodableError(self): 'openid.identity': 'http://limu.unittest/', }) e = server.ProtocolError(args, "wet paint") - self.failUnlessRaises(server.EncodingError, self.encode, e) + self.assertRaises(server.EncodingError, self.encode, e) def test_encodableError(self): args = Message.fromPostArgs({ @@ -798,7 +798,7 @@ def test_idresDumb(self): def test_forgotStore(self): self.encoder.signatory = None - self.failUnlessRaises(ValueError, self.encode, self.response) + self.assertRaises(ValueError, self.encode, self.response) def test_cancel(self): request = server.CheckIDRequest( @@ -834,7 +834,7 @@ def test_assocReply(self): def test_alreadySigned(self): self.response.fields.setArg(OPENID_NS, 'sig', 'priorSig==') - self.failUnlessRaises(server.AlreadySigned, self.encode, self.response) + self.assertRaises(server.AlreadySigned, self.encode, self.response) class TestCheckID(unittest.TestCase): @@ -976,8 +976,7 @@ def test_answerAllowWithoutIdentityReally(self): def test_answerAllowAnonymousFail(self): self.request.identity = None # XXX - Check on this, I think this behavior is legal in OpenID 2.0? - self.failUnlessRaises( - ValueError, self.request.answer, True, identity="=V") + self.assertRaises(ValueError, self.request.answer, True, identity="=V") def test_answerAllowWithIdentity(self): self.request.identity = IDENTIFIER_SELECT @@ -1004,15 +1003,11 @@ def test_answerAllowWithDelegatedIdentityOpenID1(self): self.request.identity = IDENTIFIER_SELECT selected_id = 'http://anon.unittest/9861' claimed_id = 'http://monkeyhat.unittest/' - self.failUnlessRaises(server.VersionError, - self.request.answer, True, - identity=selected_id, - claimed_id=claimed_id) + self.assertRaises(server.VersionError, self.request.answer, True, identity=selected_id, claimed_id=claimed_id) def test_answerAllowWithAnotherIdentity(self): # XXX - Check on this, I think this behavior is legal in OpenID 2.0? - self.failUnlessRaises(ValueError, self.request.answer, True, - identity="http://pebbles.unittest/") + self.assertRaises(ValueError, self.request.answer, True, identity="http://pebbles.unittest/") def test_answerAllowWithIdentityNormalization(self): # The RP has sent us a non-normalized value for openid.identity, @@ -1034,12 +1029,11 @@ def test_answerAllowWithIdentityNormalization(self): def test_answerAllowNoIdentityOpenID1(self): self.request.message = Message(OPENID1_NS) self.request.identity = None - self.failUnlessRaises(ValueError, self.request.answer, True, - identity=None) + self.assertRaises(ValueError, self.request.answer, True, identity=None) def test_answerAllowForgotEndpoint(self): self.request.op_endpoint = None - self.failUnlessRaises(RuntimeError, self.request.answer, True) + self.assertRaises(RuntimeError, self.request.answer, True) def test_checkIDWithNoIdentityOpenID1(self): msg = Message(OPENID1_NS) @@ -1048,9 +1042,7 @@ def test_checkIDWithNoIdentityOpenID1(self): msg.setArg(OPENID_NS, 'mode', 'checkid_setup') msg.setArg(OPENID_NS, 'assoc_handle', 'bogus') - self.failUnlessRaises(server.ProtocolError, - server.CheckIDRequest.fromMessage, - msg, self.server) + self.assertRaises(server.ProtocolError, server.CheckIDRequest.fromMessage, msg, self.server) def test_fromMessageClaimedIDWithoutIdentityOpenID2(self): name = 'https://example.myopenid.com' @@ -1060,9 +1052,7 @@ def test_fromMessageClaimedIDWithoutIdentityOpenID2(self): msg.setArg(OPENID_NS, 'return_to', 'http://invalid:8000/rt') msg.setArg(OPENID_NS, 'claimed_id', name) - self.failUnlessRaises(server.ProtocolError, - server.CheckIDRequest.fromMessage, - msg, self.server) + self.assertRaises(server.ProtocolError, server.CheckIDRequest.fromMessage, msg, self.server) def test_fromMessageIdentityWithoutClaimedIDOpenID2(self): name = 'https://example.myopenid.com' @@ -1072,9 +1062,7 @@ def test_fromMessageIdentityWithoutClaimedIDOpenID2(self): msg.setArg(OPENID_NS, 'return_to', 'http://invalid:8000/rt') msg.setArg(OPENID_NS, 'identity', name) - self.failUnlessRaises(server.ProtocolError, - server.CheckIDRequest.fromMessage, - msg, self.server) + self.assertRaises(server.ProtocolError, server.CheckIDRequest.fromMessage, msg, self.server) def test_trustRootOpenID1(self): """Ignore openid.realm in OpenID 1""" @@ -1145,9 +1133,7 @@ def test_fromMessageWithoutTrustRootOrReturnTo(self): msg.setArg(OPENID_NS, 'identity', 'george') msg.setArg(OPENID_NS, 'claimed_id', 'george') - self.failUnlessRaises(server.ProtocolError, - server.CheckIDRequest.fromMessage, - msg, self.server.op_endpoint) + self.assertRaises(server.ProtocolError, server.CheckIDRequest.fromMessage, msg, self.server.op_endpoint) def test_answerAllowNoEndpointOpenID1(self): """Test .allow() with an OpenID 1.x Message on a CheckIDRequest @@ -1246,7 +1232,7 @@ def test_getCancelURL(self): def test_getCancelURLimmed(self): self.request.mode = 'checkid_immediate' self.request.immediate = True - self.failUnlessRaises(ValueError, self.request.getCancelURL) + self.assertRaises(ValueError, self.request.getCancelURL) class TestCheckIDExtension(unittest.TestCase): @@ -1437,9 +1423,7 @@ def test_protoError256(self): for request_args in bad_request_argss: message = Message.fromPostArgs(request_args) - self.failUnlessRaises(server.ProtocolError, - server.AssociateRequest.fromMessage, - message) + self.assertRaises(server.ProtocolError, server.AssociateRequest.fromMessage, message) def test_protoError(self): from openid.consumer.consumer import DiffieHellmanSHA1ConsumerSession @@ -1460,9 +1444,7 @@ def test_protoError(self): for request_args in bad_request_argss: message = Message.fromPostArgs(request_args) - self.failUnlessRaises(server.ProtocolError, - server.AssociateRequest.fromMessage, - message) + self.assertRaises(server.ProtocolError, server.AssociateRequest.fromMessage, message) def test_protoErrorFields(self): @@ -1717,8 +1699,7 @@ def test_missingSessionTypeOpenID2(self): 'openid.ns': OPENID2_NS, }) - self.assertRaises(server.ProtocolError, - server.AssociateRequest.fromMessage, msg) + self.assertRaises(server.ProtocolError, server.AssociateRequest.fromMessage, msg) def test_missingAssocTypeOpenID2(self): """Make sure assoc_type is required in OpenID 2""" @@ -1727,8 +1708,7 @@ def test_missingAssocTypeOpenID2(self): 'openid.session_type': 'no-encryption', }) - self.assertRaises(server.ProtocolError, - server.AssociateRequest.fromMessage, msg) + self.assertRaises(server.ProtocolError, server.AssociateRequest.fromMessage, msg) def test_checkAuth(self): request = server.CheckAuthRequest('arrrrrf', '0x3999', []) diff --git a/openid/test/test_services.py b/openid/test/test_services.py index 8708a3c5..94ae817a 100644 --- a/openid/test/test_services.py +++ b/openid/test/test_services.py @@ -18,6 +18,4 @@ def discover(self, input_url): return result def test_catchXRDSError(self): - self.failUnlessRaises(DiscoveryFailure, - services.getServiceEndpoints, - "http://example.invalid/sometest") + self.assertRaises(DiscoveryFailure, services.getServiceEndpoints, "http://example.invalid/sometest") diff --git a/openid/test/test_sreg.py b/openid/test/test_sreg.py index bca85040..acb3c25f 100644 --- a/openid/test/test_sreg.py +++ b/openid/test/test_sreg.py @@ -16,10 +16,10 @@ def test_goodNamePasses(self): sreg.checkFieldName(field_name) def test_badNameFails(self): - self.failUnlessRaises(ValueError, sreg.checkFieldName, 'INVALID') + self.assertRaises(ValueError, sreg.checkFieldName, 'INVALID') def test_badTypeFails(self): - self.failUnlessRaises(ValueError, sreg.checkFieldName, None) + self.assertRaises(ValueError, sreg.checkFieldName, None) # For supportsSReg test @@ -96,14 +96,12 @@ def test_openID1Defined_1_0_overrideAlias(self): def test_openID1DefinedBadly(self): self.msg.openid1 = True self.msg.namespaces.addAlias('http://invalid/', 'sreg') - self.failUnlessRaises(sreg.SRegNamespaceError, - sreg.getSRegNS, self.msg) + self.assertRaises(sreg.SRegNamespaceError, sreg.getSRegNS, self.msg) def test_openID2DefinedBadly(self): self.msg.openid1 = False self.msg.namespaces.addAlias('http://invalid/', 'sreg') - self.failUnlessRaises(sreg.SRegNamespaceError, - sreg.getSRegNS, self.msg) + self.assertRaises(sreg.SRegNamespaceError, sreg.getSRegNS, self.msg) def test_openID2Defined_1_0(self): self.msg.namespaces.add(sreg.ns_uri_1_0) @@ -142,9 +140,7 @@ def test_constructFields(self): self.assertEqual(req.ns_uri, 'http://sreg.ns_uri') def test_constructBadFields(self): - self.failUnlessRaises( - ValueError, - sreg.SRegRequest, ['elvis']) + self.assertRaises(ValueError, sreg.SRegRequest, ['elvis']) def test_fromOpenIDRequest(self): ns_sentinel = object() @@ -196,9 +192,7 @@ def test_parseExtensionArgs_nonStrict(self): def test_parseExtensionArgs_strict(self): req = sreg.SRegRequest() - self.failUnlessRaises( - ValueError, - req.parseExtensionArgs, {'required': 'beans'}, strict=True) + self.assertRaises(ValueError, req.parseExtensionArgs, {'required': 'beans'}, strict=True) def test_parseExtensionArgs_policy(self): req = sreg.SRegRequest() @@ -232,10 +226,7 @@ def test_parseExtensionArgs_optionalListBadNonStrict(self): def test_parseExtensionArgs_optionalListBadStrict(self): req = sreg.SRegRequest() - self.failUnlessRaises( - ValueError, - req.parseExtensionArgs, {'optional': 'nickname,email,beer'}, - strict=True) + self.assertRaises(ValueError, req.parseExtensionArgs, {'optional': 'nickname,email,beer'}, strict=True) def test_parseExtensionArgs_bothNonStrict(self): req = sreg.SRegRequest() @@ -246,12 +237,8 @@ def test_parseExtensionArgs_bothNonStrict(self): def test_parseExtensionArgs_bothStrict(self): req = sreg.SRegRequest() - self.failUnlessRaises( - ValueError, - req.parseExtensionArgs, - {'optional': 'nickname', - 'required': 'nickname'}, - strict=True) + self.assertRaises(ValueError, req.parseExtensionArgs, {'optional': 'nickname', 'required': 'nickname'}, + strict=True) def test_parseExtensionArgs_bothList(self): req = sreg.SRegRequest() @@ -291,13 +278,9 @@ def test_contains(self): def test_requestField_bogus(self): req = sreg.SRegRequest() - self.failUnlessRaises( - ValueError, - req.requestField, 'something else') + self.assertRaises(ValueError, req.requestField, 'something else') - self.failUnlessRaises( - ValueError, - req.requestField, 'something else', strict=True) + self.assertRaises(ValueError, req.requestField, 'something else', strict=True) def test_requestField(self): # Add all of the fields, one at a time @@ -339,7 +322,7 @@ def test_requestField(self): def test_requestFields_type(self): req = sreg.SRegRequest() - self.failUnlessRaises(TypeError, req.requestFields, 'nickname') + self.assertRaises(TypeError, req.requestFields, 'nickname') def test_requestFields(self): # Add all of the fields diff --git a/openid/test/test_verifydisco.py b/openid/test/test_verifydisco.py index 04b629bd..f8086882 100644 --- a/openid/test/test_verifydisco.py +++ b/openid/test/test_verifydisco.py @@ -39,14 +39,12 @@ def test_openID1NoLocalID(self): def test_openID1NoEndpoint(self): msg = message.Message.fromOpenIDArgs({'identity': 'snakes on a plane'}) - self.failUnlessRaises(RuntimeError, - self.consumer._verifyDiscoveryResults, msg) + self.assertRaises(RuntimeError, self.consumer._verifyDiscoveryResults, msg) self.failUnlessLogEmpty() def test_openID2NoOPEndpointArg(self): msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS}) - self.failUnlessRaises(KeyError, - self.consumer._verifyDiscoveryResults, msg) + self.assertRaises(KeyError, self.consumer._verifyDiscoveryResults, msg) self.failUnlessLogEmpty() def test_openID2LocalIDNoClaimed(self): @@ -193,9 +191,7 @@ def discoverAndVerify(claimed_id, _to_match): {'ns': message.OPENID1_NS, 'identity': endpoint.local_id}) - self.failUnlessRaises( - VerifiedError, - self.consumer._verifyDiscoveryResults, msg, endpoint) + self.assertRaises(VerifiedError, self.consumer._verifyDiscoveryResults, msg, endpoint) self.failUnlessLogMatches('Error attempting to use stored', 'Attempting discovery') diff --git a/openid/test/test_yadis_discover.py b/openid/test/test_yadis_discover.py index 7284cbd1..472eef37 100644 --- a/openid/test/test_yadis_discover.py +++ b/openid/test/test_yadis_discover.py @@ -93,7 +93,7 @@ def tearDown(self): def test_404(self): uri = "http://something.unittest/" - self.failUnlessRaises(DiscoveryFailure, discover, uri) + self.assertRaises(DiscoveryFailure, discover, uri) class TestDiscover(unittest.TestCase): @@ -116,8 +116,7 @@ def test(self): success) if expected is DiscoveryFailure: - self.failUnlessRaises(DiscoveryFailure, - discover, input_url) + self.assertRaises(DiscoveryFailure, discover, input_url) else: result = discover(input_url) self.assertEqual(result.request_uri, input_url) From be3e62c24bfb28075006eecf616d58346bc743ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 26 Jan 2018 10:52:08 +0100 Subject: [PATCH 035/151] Replace failUnless --- examples/djopenid/server/tests.py | 11 +- openid/test/test_association.py | 4 +- openid/test/test_association_response.py | 13 +- openid/test/test_auth_request.py | 6 +- openid/test/test_ax.py | 24 +-- openid/test/test_consumer.py | 167 +++++++---------- openid/test/test_discover.py | 32 ++-- openid/test/test_extension.py | 2 +- openid/test/test_fetchers.py | 23 +-- openid/test/test_message.py | 48 +++-- openid/test/test_negotiation.py | 12 +- openid/test/test_nonce.py | 12 +- openid/test/test_oidutil.py | 6 +- openid/test/test_pape_draft2.py | 2 +- openid/test/test_pape_draft5.py | 2 +- openid/test/test_parsehtml.py | 4 +- openid/test/test_rpverify.py | 16 +- openid/test/test_server.py | 225 ++++++++++------------- openid/test/test_sreg.py | 16 +- openid/test/test_symbol.py | 2 +- openid/test/test_verifydisco.py | 28 +-- 21 files changed, 276 insertions(+), 379 deletions(-) diff --git a/examples/djopenid/server/tests.py b/examples/djopenid/server/tests.py index b6bb5850..0beefc21 100644 --- a/examples/djopenid/server/tests.py +++ b/examples/djopenid/server/tests.py @@ -50,9 +50,9 @@ def test_allow(self): self.assertEqual(response.status_code, HTTP_REDIRECT) finalURL = response['location'] - self.failUnless('openid.mode=id_res' in finalURL, finalURL) - self.failUnless('openid.identity=' in finalURL, finalURL) - self.failUnless('openid.sreg.postcode=12345' in finalURL, finalURL) + self.assertIn('openid.mode=id_res', finalURL) + self.assertIn('openid.identity=', finalURL) + self.assertIn('openid.sreg.postcode=12345', finalURL) def test_cancel(self): self.request.POST['cancel'] = 'Yes' @@ -61,7 +61,7 @@ def test_cancel(self): self.assertEqual(response.status_code, HTTP_REDIRECT) finalURL = response['location'] - self.failUnless('openid.mode=cancel' in finalURL, finalURL) + self.assertIn('openid.mode=cancel', finalURL) self.failIf('openid.identity=' in finalURL, finalURL) self.failIf('openid.sreg.postcode=12345' in finalURL, finalURL) @@ -85,8 +85,7 @@ def test_unreachableRealm(self): views.setRequest(self.request, self.openid_request) response = views.showDecidePage(self.request, self.openid_request) - self.failUnless('trust_root_valid is Unreachable' in response.content, - response) + self.assertIn('trust_root_valid is Unreachable', response.content) class TestGenericXRDS(TestCase): diff --git a/openid/test/test_association.py b/openid/test/test_association.py index 5caa4214..bd042e98 100644 --- a/openid/test/test_association.py +++ b/openid/test/test_association.py @@ -115,7 +115,7 @@ def test_signSHA1(self): assoc = association.Association.fromExpiresIn( 3600, '{sha1}', 'very_secret', "HMAC-SHA1") signed = assoc.signMessage(self.message) - self.failUnless(signed.getArg(OPENID_NS, "sig")) + self.assertTrue(signed.getArg(OPENID_NS, "sig")) self.assertEqual(signed.getArg(OPENID_NS, "signed"), "assoc_handle,identifier,mode,ns,signed") self.assertEqual(signed.getArg(BARE_NS, "xey"), "value") @@ -123,7 +123,7 @@ def test_signSHA256(self): assoc = association.Association.fromExpiresIn( 3600, '{sha1}', 'very_secret', "HMAC-SHA256") signed = assoc.signMessage(self.message) - self.failUnless(signed.getArg(OPENID_NS, "sig")) + self.assertTrue(signed.getArg(OPENID_NS, "sig")) self.assertEqual(signed.getArg(OPENID_NS, "signed"), "assoc_handle,identifier,mode,ns,signed") self.assertEqual(signed.getArg(BARE_NS, "xey"), "value") diff --git a/openid/test/test_association_response.py b/openid/test/test_association_response.py index 6333c13a..d880a407 100644 --- a/openid/test/test_association_response.py +++ b/openid/test/test_association_response.py @@ -41,13 +41,8 @@ def setUp(self): self.endpoint = OpenIDServiceEndpoint() def failUnlessProtocolError(self, str_prefix, func, *args, **kwargs): - try: - result = func(*args, **kwargs) - except ProtocolError as e: - message = 'Expected prefix %r, got %r' % (str_prefix, e[0]) - self.failUnless(e[0].startswith(str_prefix), message) - else: - self.fail('Expected ProtocolError, got %r' % (result,)) + with self.assertRaisesRegexp(ProtocolError, str_prefix): + func(*args, **kwargs) def mkExtractAssocMissingTest(keys): @@ -193,7 +188,7 @@ def _doTest(self, expected_session_type, session_type_value): if session_type_value is not None: args['session_type'] = session_type_value message = Message.fromOpenIDArgs(args) - self.failUnless(message.isOpenID1()) + self.assertTrue(message.isOpenID1()) actual_session_type = self.consumer._getOpenID1SessionType(message) error_message = ('Returned sesion type parameter %r was expected ' @@ -278,7 +273,7 @@ def test_worksWithGoodFields(self): """Handle a full successful association response""" assoc = self.consumer._extractAssociation( self.assoc_response, self.assoc_session) - self.failUnless(self.assoc_session.extract_secret_called) + self.assertTrue(self.assoc_session.extract_secret_called) self.assertEqual(assoc.secret, self.assoc_session.secret) self.assertEqual(assoc.lifetime, 1000) self.assertEqual(assoc.handle, self.assoc_handle) diff --git a/openid/test/test_auth_request.py b/openid/test/test_auth_request.py index 06ee02b4..3f287ac4 100644 --- a/openid/test/test_auth_request.py +++ b/openid/test/test_auth_request.py @@ -122,7 +122,7 @@ def failUnlessHasIdentifiers(self, msg, op_specific_id, claimed_id): def test_setAnonymousWorksForOpenID2(self): """OpenID AuthRequests should be able to set 'anonymous' to true.""" - self.failUnless(self.authreq.message.isOpenID2()) + self.assertTrue(self.authreq.message.isOpenID2()) self.authreq.setAnonymous(True) self.authreq.setAnonymous(False) @@ -160,7 +160,7 @@ def failUnlessHasIdentifiers(self, msg, op_specific_id, claimed_id): def failUnlessIdentifiersPresent(self, msg): self.failIfOpenIDKeyExists(msg, 'claimed_id') - self.failUnless(msg.hasKey(message.OPENID_NS, 'identity')) + self.assertTrue(msg.hasKey(message.OPENID_NS, 'identity')) def failUnlessHasRealm(self, msg): # check presence of proper realm key and absence of the wrong @@ -172,7 +172,7 @@ def failUnlessHasRealm(self, msg): def test_setAnonymousFailsForOpenID1(self): """OpenID 1 requests MUST NOT be able to set anonymous to True""" - self.failUnless(self.authreq.message.isOpenID1()) + self.assertTrue(self.authreq.message.isOpenID1()) self.assertRaises(ValueError, self.authreq.setAnonymous, True) self.authreq.setAnonymous(False) diff --git a/openid/test/test_ax.py b/openid/test/test_ax.py index 1c94e4ba..d6bdaa24 100644 --- a/openid/test/test_ax.py +++ b/openid/test/test_ax.py @@ -120,8 +120,8 @@ def test_requestUnlimitedValues(self): attrs = list(msg.iterAttrs()) foo = attrs[0] - self.failUnless(foo.count == ax.UNLIMITED_VALUES) - self.failUnless(foo.wantsUnlimitedValues()) + self.assertEqual(foo.count, ax.UNLIMITED_VALUES) + self.assertTrue(foo.wantsUnlimitedValues()) def test_longAlias(self): # Spec minimum length is 32 characters. This is a silly test @@ -215,7 +215,7 @@ def test_add(self): self.msg.add(attr) # Present after adding - self.failUnless(uri in self.msg) + self.assertIn(uri, self.msg) def test_addTwice(self): uri = 'lightning://storm' @@ -301,10 +301,10 @@ def test_parseExtensionArgs(self): 'if_available': self.alias_a } self.msg.parseExtensionArgs(extension_args) - self.failUnless(self.type_a in self.msg) + self.assertIn(self.type_a, self.msg) self.assertEqual(list(self.msg), [self.type_a]) attr_info = self.msg.requested_attributes.get(self.type_a) - self.failUnless(attr_info) + self.assertIsNotNone(attr_info) self.failIf(attr_info.required) self.assertEqual(attr_info.type_uri, self.type_a) self.assertEqual(attr_info.alias, self.alias_a) @@ -329,7 +329,7 @@ def test_extensionArgs_idempotent_count_required(self): } self.msg.parseExtensionArgs(extension_args) self.assertEqual(self.msg.getExtensionArgs(), extension_args) - self.failUnless(self.msg.requested_attributes[self.type_a].required) + self.assertTrue(self.msg.requested_attributes[self.type_a].required) def test_extensionArgs_count1(self): extension_args = { @@ -400,7 +400,7 @@ def test_fromOpenIDRequestWithoutExtension(self): }) oreq = DummyRequest(openid_req_msg) r = ax.FetchRequest.fromOpenIDRequest(oreq) - self.failUnless(r is None, "%s is not None" % (r,)) + self.assertIsNone(r) def test_fromOpenIDRequestWithoutData(self): """return something for SuccessResponse with AX paramaters, @@ -414,7 +414,7 @@ def test_fromOpenIDRequestWithoutData(self): }) oreq = DummyRequest(openid_req_msg) r = ax.FetchRequest.fromOpenIDRequest(oreq) - self.failUnless(r is not None) + self.assertIsNotNone(r) class FetchResponseTest(unittest.TestCase): @@ -426,7 +426,7 @@ def setUp(self): self.request_update_url = 'http://update.bogus/' def test_construct(self): - self.failUnless(self.msg.update_url is None) + self.assertIsNone(self.msg.update_url) self.assertEqual(self.msg.data, {}) def test_getExtensionArgs_empty(self): @@ -519,7 +519,7 @@ class Endpoint: oreq = SuccessResponse(Endpoint(), msg, signed_fields=sf) r = ax.FetchResponse.fromSuccessResponse(oreq) - self.failUnless(r is None, "%s is not None" % (r,)) + self.assertIsNone(r) def test_fromSuccessResponseWithoutData(self): """return something for SuccessResponse with AX paramaters, @@ -538,7 +538,7 @@ class Endpoint: oreq = SuccessResponse(Endpoint(), msg, signed_fields=sf) r = ax.FetchResponse.fromSuccessResponse(oreq) - self.failUnless(r is not None) + self.assertIsNotNone(r) def test_fromSuccessResponseWithData(self): name = 'ext0' @@ -601,7 +601,7 @@ def test_getExtensionArgs_nonempty(self): class StoreResponseTest(unittest.TestCase): def test_success(self): msg = ax.StoreResponse() - self.failUnless(msg.succeeded()) + self.assertTrue(msg.succeeded()) self.failIf(msg.error_message) self.assertEqual(msg.getExtensionArgs(), {'mode': 'store_response_success'}) diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index eae07f33..68caf0c8 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -256,7 +256,7 @@ def setUp(self): def test_construct(self): oidc = GenericConsumer(self.store_sentinel) - self.failUnless(oidc.store is self.store_sentinel) + self.assertEqual(oidc.store, self.store_sentinel) def test_nostore(self): self.assertRaises(TypeError, GenericConsumer) @@ -360,12 +360,8 @@ def test_notAList(self): # Value should be a single string. If it's a list, it should generate # an exception. query = {'openid.mode': ['cancel']} - try: - r = Message.fromPostArgs(query) - except TypeError as err: - self.failUnless(str(err).find('values') != -1, err) - else: - self.fail("expected TypeError, got this instead: %s" % (r,)) + with self.assertRaisesRegexp(TypeError, 'values'): + Message.fromPostArgs(query) class TestComplete(TestIdRes): @@ -379,27 +375,27 @@ def test_setupNeededIdRes(self): setup_url_sentinel = object() def raiseSetupNeeded(msg): - self.failUnless(msg is message) + self.assertEqual(msg, message) raise SetupNeededError(setup_url_sentinel) self.consumer._checkSetupNeeded = raiseSetupNeeded response = self.consumer.complete(message, None, None) self.assertEqual(response.status, SETUP_NEEDED) - self.failUnless(setup_url_sentinel is response.setup_url) + self.assertEqual(response.setup_url, setup_url_sentinel) def test_cancel(self): message = Message.fromPostArgs({'openid.mode': 'cancel'}) self.disableReturnToChecking() r = self.consumer.complete(message, self.endpoint) self.assertEqual(r.status, CANCEL) - self.failUnless(r.identity_url == self.endpoint.claimed_id) + self.assertEqual(r.identity_url, self.endpoint.claimed_id) def test_cancel_with_return_to(self): message = Message.fromPostArgs({'openid.mode': 'cancel'}) r = self.consumer.complete(message, self.endpoint, self.return_to) self.assertEqual(r.status, CANCEL) - self.failUnless(r.identity_url == self.endpoint.claimed_id) + self.assertEqual(r.identity_url, self.endpoint.claimed_id) def test_error(self): msg = 'an error message' @@ -407,7 +403,7 @@ def test_error(self): self.disableReturnToChecking() r = self.consumer.complete(message, self.endpoint) self.assertEqual(r.status, FAILURE) - self.failUnless(r.identity_url == self.endpoint.claimed_id) + self.assertEqual(r.identity_url, self.endpoint.claimed_id) self.assertEqual(r.message, msg) def test_errorWithNoOptionalKeys(self): @@ -417,9 +413,9 @@ def test_errorWithNoOptionalKeys(self): self.disableReturnToChecking() r = self.consumer.complete(message, self.endpoint) self.assertEqual(r.status, FAILURE) - self.failUnless(r.identity_url == self.endpoint.claimed_id) - self.failUnless(r.contact == contact) - self.failUnless(r.reference is None) + self.assertEqual(r.identity_url, self.endpoint.claimed_id) + self.assertEqual(r.contact, contact) + self.assertIsNone(r.reference) self.assertEqual(r.message, msg) def test_errorWithOptionalKeys(self): @@ -430,16 +426,16 @@ def test_errorWithOptionalKeys(self): 'openid.contact': contact, 'openid.ns': OPENID2_NS}) r = self.consumer.complete(message, self.endpoint, None) self.assertEqual(r.status, FAILURE) - self.failUnless(r.identity_url == self.endpoint.claimed_id) - self.failUnless(r.contact == contact) - self.failUnless(r.reference == reference) + self.assertEqual(r.identity_url, self.endpoint.claimed_id) + self.assertEqual(r.contact, contact) + self.assertEqual(r.reference, reference) self.assertEqual(r.message, msg) def test_noMode(self): message = Message.fromPostArgs({}) r = self.consumer.complete(message, self.endpoint, None) self.assertEqual(r.status, FAILURE) - self.failUnless(r.identity_url == self.endpoint.claimed_id) + self.assertEqual(r.identity_url, self.endpoint.claimed_id) def test_idResMissingField(self): # XXX - this test is passing, but not necessarily by what it @@ -568,7 +564,7 @@ def test_goodResponse(self): """successful response to check_authentication""" response = Message.fromOpenIDArgs({'is_valid': 'true'}) r = self.consumer._processCheckAuthResponse(response, self.server_url) - self.failUnless(r) + self.assertTrue(r) def test_missingAnswer(self): """check_authentication returns false when the server sends no answer""" @@ -598,8 +594,7 @@ def test_badResponseInvalidate(self): }) r = self.consumer._processCheckAuthResponse(response, self.server_url) self.failIf(r) - self.failUnless( - self.consumer.store.getAssociation(self.server_url) is None) + self.assertIsNone(self.consumer.store.getAssociation(self.server_url)) def test_invalidateMissing(self): """invalidate_handle with a handle that is not present""" @@ -608,7 +603,7 @@ def test_invalidateMissing(self): 'invalidate_handle': 'missing', }) r = self.consumer._processCheckAuthResponse(response, self.server_url) - self.failUnless(r) + self.assertTrue(r) self.failUnlessLogMatches( 'Received "invalidate_handle"' ) @@ -621,7 +616,7 @@ def test_invalidateMissing_noStore(self): }) self.consumer.store = None r = self.consumer._processCheckAuthResponse(response, self.server_url) - self.failUnless(r) + self.assertTrue(r) self.failUnlessLogMatches( 'Received "invalidate_handle"', 'Unexpectedly got invalidate_handle without a store') @@ -641,9 +636,8 @@ def test_invalidatePresent(self): 'invalidate_handle': 'handle', }) r = self.consumer._processCheckAuthResponse(response, self.server_url) - self.failUnless(r) - self.failUnless( - self.consumer.store.getAssociation(self.server_url) is None) + self.assertTrue(r) + self.assertIsNone(self.consumer.store.getAssociation(self.server_url)) class TestSetupNeeded(TestIdRes): @@ -659,7 +653,7 @@ def test_setupNeededOpenID1(self): 'openid.mode': 'id_res', 'openid.user_setup_url': setup_url, }) - self.failUnless(message.isOpenID1()) + self.assertTrue(message.isOpenID1()) self.failUnlessSetupNeeded(setup_url, message) def test_setupNeededOpenID1_extra(self): @@ -670,14 +664,14 @@ def test_setupNeededOpenID1_extra(self): 'openid.user_setup_url': setup_url, 'openid.identity': 'bogus', }) - self.failUnless(message.isOpenID1()) + self.assertTrue(message.isOpenID1()) self.failUnlessSetupNeeded(setup_url, message) def test_noSetupNeededOpenID1(self): """When the user_setup_url is missing on an OpenID 1 message, we assume that it's not a cancel response to checkid_immediate""" message = Message.fromOpenIDArgs({'mode': 'id_res'}) - self.failUnless(message.isOpenID1()) + self.assertTrue(message.isOpenID1()) # No SetupNeededError raised self.consumer._checkSetupNeeded(message) @@ -687,7 +681,7 @@ def test_setupNeededOpenID2(self): 'mode': 'setup_needed', 'ns': OPENID2_NS, }) - self.failUnless(message.isOpenID2()) + self.assertTrue(message.isOpenID2()) response = self.consumer.complete(message, None, None) self.assertEqual(response.status, 'setup_needed') self.assertIsNone(response.setup_url) @@ -702,7 +696,7 @@ def test_setupNeededDoesntWorkForOpenID1(self): response = self.consumer.complete(message, None, None) self.assertEqual(response.status, 'failure') - self.failUnless(response.message.startswith('Invalid openid.mode')) + self.assertTrue(response.message.startswith('Invalid openid.mode')) def test_noSetupNeededOpenID2(self): message = Message.fromOpenIDArgs({ @@ -710,7 +704,7 @@ def test_noSetupNeededOpenID2(self): 'game': 'puerto_rico', 'ns': OPENID2_NS, }) - self.failUnless(message.isOpenID2()) + self.assertTrue(message.isOpenID2()) # No SetupNeededError raised self.consumer._checkSetupNeeded(message) @@ -761,23 +755,17 @@ def test(self): def mkMissingFieldTest(openid_args): def test(self): message = Message.fromOpenIDArgs(openid_args) - try: + with self.assertRaises(ProtocolError) as catch: self.consumer._idResCheckForFields(message) - except ProtocolError as why: - self.failUnless(why[0].startswith('Missing required')) - else: - self.fail('Expected an error, but none occurred') + self.assertTrue(catch.exception[0].startswith('Missing required')) return test def mkMissingSignedTest(openid_args): def test(self): message = Message.fromOpenIDArgs(openid_args) - try: + with self.assertRaises(ProtocolError) as catch: self.consumer._idResCheckForFields(message) - except ProtocolError as why: - self.failUnless(why[0].endswith('not signed')) - else: - self.fail('Expected an error, but none occurred') + self.assertTrue(catch.exception[0].endswith('not signed')) return test test_openid1Missing_returnToSig = mkMissingSignedTest( @@ -847,7 +835,7 @@ def test_openid1Missing(self): """use consumer-generated nonce""" self.response = Message.fromOpenIDArgs({}) n = self.consumer._idResGetNonceOpenID1(self.response, self.endpoint) - self.failUnless(n is None, n) + self.assertIsNone(n) self.failUnlessLogEmpty() def test_consumerNonceOpenID2(self): @@ -973,7 +961,7 @@ def test_expiredAssoc(self): handle = 'handle' assoc = association.Association( handle, 'secret', issued, lifetime, 'HMAC-SHA1') - self.failUnless(assoc.expiresIn <= 0) + self.assertLessEqual(assoc.expiresIn, 0) self.store.storeAssociation(self.server_url, assoc) message = Message.fromPostArgs({ @@ -1137,8 +1125,7 @@ def test_completeGoodReturnTo(self): m.setArg(OPENID_NS, 'return_to', good) result = self.consumer.complete(m, endpoint, return_to) - self.failUnless(isinstance(result, CancelResponse), - "Expected CancelResponse, got %r for %s" % (result, good,)) + self.assertIsInstance(result, CancelResponse, "Expected CancelResponse, got %r for %s" % (result, good)) class MockFetcher(object): @@ -1194,7 +1181,7 @@ def test_error(self): r = self.consumer._checkAuth(Message.fromPostArgs(query), http_server_url) self.failIf(r) - self.failUnless(self.messages) + self.assertTrue(self.messages) def test_bad_args(self): query = { @@ -1216,9 +1203,9 @@ def test_signedList(self): 'foo': 'bar', }) args = self.consumer._createCheckAuthRequest(query) - self.failUnless(args.isOpenID1()) + self.assertTrue(args.isOpenID1()) for signed_arg in query.getArg(OPENID_NS, 'signed').split(','): - self.failUnless(args.getAliasedArg(signed_arg), signed_arg) + self.assertTrue(args.getAliasedArg(signed_arg)) def test_112(self): args = { @@ -1239,12 +1226,12 @@ def test_112(self): 'ns.pape,pape.nist_auth_level,pape.auth_policies'} self.assertEqual(args['openid.ns'], OPENID2_NS) incoming = Message.fromPostArgs(args) - self.failUnless(incoming.isOpenID2()) + self.assertTrue(incoming.isOpenID2()) car = self.consumer._createCheckAuthRequest(incoming) expected_args = args.copy() expected_args['openid.mode'] = 'check_authentication' expected = Message.fromPostArgs(expected_args) - self.failUnless(expected.isOpenID2()) + self.assertTrue(expected.isOpenID2()) self.assertEqual(car, expected) self.assertEqual(car.toPostArgs(), expected_args) @@ -1296,7 +1283,7 @@ def test_error_exception_wrapped(self): # exception fetching returns no association e = OpenIDServiceEndpoint() e.server_url = 'some://url' - self.failUnless(self.consumer._getAssociation(e) is None) + self.assertIsNone(self.consumer._getAssociation(e)) msg = Message.fromPostArgs({'openid.signed': ''}) self.failIf(self.consumer._checkAuth(msg, 'some://url')) @@ -1353,7 +1340,7 @@ def test_extensionResponseSigned(self): def test_noReturnTo(self): resp = mkSuccess(self.endpoint, {}) - self.failUnless(resp.getReturnTo() is None) + self.assertIsNone(resp.getReturnTo()) def test_returnTo(self): resp = mkSuccess(self.endpoint, {'return_to': 'return_to'}) @@ -1404,8 +1391,7 @@ def setUp(self): def test_setAssociationPreference(self): self.consumer.setAssociationPreference([]) - self.failUnless(isinstance(self.consumer.consumer.negotiator, - association.SessionNegotiator)) + self.assertIsInstance(self.consumer.consumer.negotiator, association.SessionNegotiator) self.assertEqual(self.consumer.consumer.negotiator.allowed_types, []) self.consumer.setAssociationPreference([('HMAC-SHA1', 'DH-SHA1')]) self.assertEqual(self.consumer.consumer.negotiator.allowed_types, [('HMAC-SHA1', 'DH-SHA1')]) @@ -1433,13 +1419,9 @@ def getNextService(self, ignored): raise HTTPFetchingError("Unit test") def test(): - try: + text = 'Error fetching XRDS document: Unit test' + with self.assertRaisesRegexp(DiscoveryFailure, text): self.consumer.begin('unused in this test') - except DiscoveryFailure as why: - self.failUnless(why[0].startswith('Error fetching')) - self.failIf(why[0].find('Unit test') == -1) - else: - self.fail('Expected DiscoveryFailure') self.withDummyDiscovery(test, getNextService) @@ -1450,13 +1432,9 @@ def getNextService(self, ignored): url = 'http://a.user.url/' def test(): - try: + text = 'No usable OpenID services found for http://a.user.url/' + with self.assertRaisesRegexp(DiscoveryFailure, text): self.consumer.begin(url) - except DiscoveryFailure as why: - self.failUnless(why[0].startswith('No usable OpenID')) - self.failIf(why[0].find(url) == -1) - else: - self.fail('Expected DiscoveryFailure') self.withDummyDiscovery(test, getNextService) @@ -1465,20 +1443,20 @@ def test_beginWithoutDiscovery(self): result = self.consumer.beginWithoutDiscovery(self.endpoint) # The result is an auth request - self.failUnless(isinstance(result, AuthRequest)) + self.assertIsInstance(result, AuthRequest) # Side-effect of calling beginWithoutDiscovery is setting the # session value to the endpoint attribute of the result - self.failUnless(self.session[self.consumer._token_key] is result.endpoint) + self.assertEqual(self.session[self.consumer._token_key], result.endpoint) # The endpoint that we passed in is the endpoint on the auth_request - self.failUnless(result.endpoint is self.endpoint) + self.assertEqual(result.endpoint, self.endpoint) def test_completeEmptySession(self): text = "failed complete" def checkEndpoint(message, endpoint, return_to): - self.failUnless(endpoint is None) + self.assertIsNone(endpoint) return FailureResponse(endpoint, text) self.consumer.consumer.complete = checkEndpoint @@ -1486,7 +1464,7 @@ def checkEndpoint(message, endpoint, return_to): response = self.consumer.complete({}, None) self.assertEqual(response.status, FAILURE) self.assertEqual(response.message, text) - self.failUnless(response.identity_url is None) + self.assertIsNone(response.identity_url) def _doResp(self, auth_req, exp_resp): """complete a transaction, using the expected response from @@ -1496,13 +1474,13 @@ def _doResp(self, auth_req, exp_resp): self.consumer.consumer.response = exp_resp # endpoint is stored in the session - self.failUnless(self.session) + self.assertTrue(self.session) resp = self.consumer.complete({}, None) # All responses should have the same identity URL, and the # session should be cleaned out if self.endpoint.claimed_id != IDENTIFIER_SELECT: - self.failUnless(resp.identity_url is self.identity_url) + self.assertEqual(resp.identity_url, self.identity_url) self.failIf(self.consumer._token_key in self.session) @@ -1528,13 +1506,13 @@ def test_noDiscoCompleteCancelWithToken(self): def test_noDiscoCompleteFailure(self): msg = 'failed!' resp = self._doRespNoDisco(FailureResponse(self.endpoint, msg)) - self.failUnless(resp.message is msg) + self.assertEqual(resp.message, msg) def test_noDiscoCompleteSetupNeeded(self): setup_url = 'http://setup.url/' resp = self._doRespNoDisco( SetupNeededResponse(self.endpoint, setup_url)) - self.failUnless(resp.setup_url is setup_url) + self.assertEqual(resp.setup_url, setup_url) # To test that discovery is cleaned up, we need to initialize a # Yadis manager, and have it put its values in the session. @@ -1546,7 +1524,7 @@ def _doRespDisco(self, is_clean, exp_resp): manager = self.discovery.getManager() if is_clean: - self.failUnless(self.discovery.getManager() is None, manager) + self.assertIsNone(self.discovery.getManager()) else: self.failIf(self.discovery.getManager() is None, manager) @@ -1563,14 +1541,14 @@ def test_completeCancel(self): def test_completeFailure(self): msg = 'failed!' resp = self._doRespDisco(False, FailureResponse(self.endpoint, msg)) - self.failUnless(resp.message is msg) + self.assertEqual(resp.message, msg) def test_completeSetupNeeded(self): setup_url = 'http://setup.url/' resp = self._doRespDisco( False, SetupNeededResponse(self.endpoint, setup_url)) - self.failUnless(resp.setup_url is setup_url) + self.assertEqual(resp.setup_url, setup_url) def test_successDifferentURL(self): """ @@ -1588,16 +1566,16 @@ def test_successDifferentURL(self): resp_endpoint.claimed_id = "http://user.url/" self._doRespDisco(True, mkSuccess(resp_endpoint, {})) - self.failUnless(self.discovery.getManager(force=True) is None) + self.assertIsNone(self.discovery.getManager(force=True)) def test_begin(self): self.discovery.createManager([self.endpoint], self.identity_url) # Should not raise an exception auth_req = self.consumer.begin(self.identity_url) - self.failUnless(isinstance(auth_req, AuthRequest)) - self.failUnless(auth_req.endpoint is self.endpoint) - self.failUnless(auth_req.endpoint is self.consumer.consumer.endpoint) - self.failUnless(auth_req.assoc is self.consumer.consumer.assoc) + self.assertIsInstance(auth_req, AuthRequest) + self.assertEqual(auth_req.endpoint, self.endpoint) + self.assertEqual(auth_req.endpoint, self.consumer.consumer.endpoint) + self.assertEqual(auth_req.assoc, self.consumer.consumer.assoc) class IDPDrivenTest(unittest.TestCase): @@ -1629,7 +1607,7 @@ def test_idpDrivenComplete(self): iverified = [] def verifyDiscoveryResults(identifier, endpoint): - self.failUnless(endpoint is self.endpoint) + self.assertEqual(endpoint, self.endpoint) iverified.append(discovered_endpoint) return discovered_endpoint self.consumer._verifyDiscoveryResults = verifyDiscoveryResults @@ -1715,13 +1693,8 @@ def discoverAndVerify(claimed_id, to_match_endpoints): endpoint.server_url = "http://the-MOON.unittest/" endpoint.local_id = self.identifier self.services = [endpoint] - try: - r = self.consumer._verifyDiscoveryResults(self.message, endpoint) - except ProtocolError as e: - # Should we make more ProtocolError subclasses? - self.failUnless(str(e), text) - else: - self.fail("expected ProtocolError, %r returned." % (r,)) + with self.assertRaisesRegexp(ProtocolError, text): + self.consumer._verifyDiscoveryResults(self.message, endpoint) def test_foreignDelegate(self): text = "verify failed" @@ -1770,7 +1743,7 @@ def test_noEncryptionSendsType(self): session, args = self.consumer._createAssociateRequest( self.endpoint, self.assoc_type, session_type) - self.failUnless(isinstance(session, PlainTextConsumerSession)) + self.assertIsInstance(session, PlainTextConsumerSession) expected = Message.fromOpenIDArgs( {'ns': OPENID2_NS, 'session_type': session_type, @@ -1786,7 +1759,7 @@ def test_noEncryptionCompatibility(self): session, args = self.consumer._createAssociateRequest( self.endpoint, self.assoc_type, session_type) - self.failUnless(isinstance(session, PlainTextConsumerSession)) + self.assertIsInstance(session, PlainTextConsumerSession) self.assertEqual(args, Message.fromOpenIDArgs({'mode': 'associate', 'assoc_type': self.assoc_type})) def test_dhSHA1Compatibility(self): @@ -1799,11 +1772,11 @@ def test_dhSHA1Compatibility(self): session, args = self.consumer._createAssociateRequest( self.endpoint, self.assoc_type, session_type) - self.failUnless(isinstance(session, DiffieHellmanSHA1ConsumerSession)) + self.assertIsInstance(session, DiffieHellmanSHA1ConsumerSession) # This is a random base-64 value, so just check that it's # present. - self.failUnless(args.getArg(OPENID1_NS, 'dh_consumer_public')) + self.assertTrue(args.getArg(OPENID1_NS, 'dh_consumer_public')) args.delArg(OPENID1_NS, 'dh_consumer_public') # OK, session_type is set here and not for no-encryption diff --git a/openid/test/test_discover.py b/openid/test/test_discover.py index 94a64702..2a7b4d8e 100644 --- a/openid/test/test_discover.py +++ b/openid/test/test_discover.py @@ -1,6 +1,5 @@ # -*- coding: utf-8 -*- import os.path -import sys import unittest from urlparse import urlsplit @@ -81,17 +80,9 @@ class TestFetchException(unittest.TestCase): ] def runOneTest(self, exc): - try: + with self.assertRaises(Exception) as catch: discover.discover('http://doesnt.matter/') - except Exception: - exc = sys.exc_info()[1] - if exc is None: - # str exception - self.failUnless(exc is sys.exc_info()[0]) - else: - self.failUnless(exc is exc, exc) - else: - self.fail('Expected %r', exc) + self.assertEqual(catch.exception, exc) def test(self): for exc in self.cases: @@ -169,14 +160,14 @@ def _checkService(self, s, self.failIf(s.local_id) self.failIf(s.getLocalID()) self.failIf(s.compatibilityMode()) - self.failUnless(s.isOPIdentifier()) + self.assertTrue(s.isOPIdentifier()) self.assertEqual(s.preferredNamespace(), discover.OPENID_2_0_MESSAGE_NS) else: self.assertEqual(s.claimed_id, claimed_id) self.assertEqual(s.getLocalID(), local_id) if used_yadis: - self.failUnless(s.used_yadis, "Expected to use Yadis") + self.assertTrue(s.used_yadis, "Expected to use Yadis") else: self.failIf(s.used_yadis, "Expected to use old-style discovery") @@ -193,8 +184,8 @@ def _checkService(self, s, self.assertEqual(s.canonicalID, canonical_id) if s.canonicalID: - self.failUnless(s.getDisplayIdentifier() != claimed_id) - self.failUnless(s.getDisplayIdentifier() is not None) + self.assertNotEqual(s.getDisplayIdentifier(), claimed_id) + self.assertIsNotNone(s.getDisplayIdentifier()) self.assertEqual(s.getDisplayIdentifier(), display_identifier) self.assertEqual(s.canonicalID, s.claimed_id) @@ -598,7 +589,7 @@ class TestXRIDiscoveryIDP(BaseTestDiscovery): def test_xri(self): user_xri, services = discover.discoverXRI('=smoker') - self.failUnless(services, "Expected services, got zero") + self.assertTrue(services, "Expected services, got zero") self.assertEqual(services[0].server_url, "http://www.livejournal.com/openid/server.bml") @@ -646,7 +637,7 @@ def test_openid2(self): def test_openid2OP(self): self.endpoint.type_uris = [discover.OPENID_IDP_2_0_TYPE] - self.failUnless(self.endpoint.isOPIdentifier()) + self.assertTrue(self.endpoint.isOPIdentifier()) def test_multipleMissing(self): self.endpoint.type_uris = [discover.OPENID_2_0_TYPE, @@ -657,7 +648,7 @@ def test_multiplePresent(self): self.endpoint.type_uris = [discover.OPENID_2_0_TYPE, discover.OPENID_1_0_TYPE, discover.OPENID_IDP_2_0_TYPE] - self.failUnless(self.endpoint.isOPIdentifier()) + self.assertTrue(self.endpoint.isOPIdentifier()) class TestFromOPEndpointURL(unittest.TestCase): @@ -667,7 +658,7 @@ def setUp(self): self.op_endpoint_url) def test_isOPEndpoint(self): - self.failUnless(self.endpoint.isOPIdentifier()) + self.assertTrue(self.endpoint.isOPIdentifier()) def test_noIdentifiers(self): self.assertIsNone(self.endpoint.getLocalID()) @@ -727,8 +718,7 @@ def failUnlessSupportsOnly(self, *types): discover.OPENID_IDP_2_0_TYPE, ]: if t in types: - self.failUnless(self.endpoint.supportsType(t), - "Must support %r" % (t,)) + self.assertTrue(self.endpoint.supportsType(t), "Must support %r" % t) else: self.failIf(self.endpoint.supportsType(t), "Shouldn't support %r" % (t,)) diff --git a/openid/test/test_extension.py b/openid/test/test_extension.py index 487968fd..4bccd754 100644 --- a/openid/test/test_extension.py +++ b/openid/test/test_extension.py @@ -17,7 +17,7 @@ def test_OpenID1(self): ext = DummyExtension() ext.toMessage(oid1_msg) namespaces = oid1_msg.namespaces - self.failUnless(namespaces.isImplicit(DummyExtension.ns_uri)) + self.assertTrue(namespaces.isImplicit(DummyExtension.ns_uri)) self.assertEqual(DummyExtension.ns_uri, namespaces.getNamespaceURI(DummyExtension.ns_alias)) self.assertEqual(DummyExtension.ns_alias, namespaces.getAlias(DummyExtension.ns_uri)) diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index 2ca208b2..e6df2d53 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -1,5 +1,4 @@ import socket -import sys import unittest import urllib2 import warnings @@ -226,31 +225,29 @@ def tearDown(self): def test_getDefaultNotNone(self): """Make sure that None is never returned as a default fetcher""" - self.failUnless(fetchers.getDefaultFetcher() is not None) + self.assertIsNotNone(fetchers.getDefaultFetcher()) fetchers.setDefaultFetcher(None) - self.failUnless(fetchers.getDefaultFetcher() is not None) + self.assertIsNotNone(fetchers.getDefaultFetcher()) def test_setDefault(self): """Make sure the getDefaultFetcher returns the object set for setDefaultFetcher""" sentinel = object() fetchers.setDefaultFetcher(sentinel, wrap_exceptions=False) - self.failUnless(fetchers.getDefaultFetcher() is sentinel) + self.assertEqual(fetchers.getDefaultFetcher(), sentinel) def test_callFetch(self): """Make sure that fetchers.fetch() uses the default fetcher instance that was set.""" fetchers.setDefaultFetcher(FakeFetcher()) actual = fetchers.fetch('bad://url') - self.failUnless(actual is FakeFetcher.sentinel) + self.assertEqual(actual, FakeFetcher.sentinel) def test_wrappedByDefault(self): """Make sure that the default fetcher instance wraps exceptions by default""" default_fetcher = fetchers.getDefaultFetcher() - self.failUnless(isinstance(default_fetcher, - fetchers.ExceptionWrappingFetcher), - default_fetcher) + self.assertIsInstance(default_fetcher, fetchers.ExceptionWrappingFetcher) self.assertRaises(fetchers.HTTPFetchingError, fetchers.fetch, 'http://invalid.janrain.com/') @@ -265,16 +262,8 @@ def test_notWrapped(self): self.failIf(isinstance(fetchers.getDefaultFetcher(), fetchers.ExceptionWrappingFetcher)) - try: + with self.assertRaises(urllib2.URLError): fetchers.fetch('http://invalid.janrain.com/') - except fetchers.HTTPFetchingError: - self.fail('Should not be wrapping exception') - except Exception: - exc = sys.exc_info()[1] - self.failUnless(isinstance(exc, urllib2.URLError), exc) - pass - else: - self.fail('Should have raised an exception') class TestHandler(urllib2.BaseHandler): diff --git a/openid/test/test_message.py b/openid/test/test_message.py index 98278f24..c414313a 100644 --- a/openid/test/test_message.py +++ b/openid/test/test_message.py @@ -365,7 +365,7 @@ def test_delArgNS3(self): self._test_delArgNS('urn:nothing-significant') def test_isOpenID1(self): - self.failUnless(self.msg.isOpenID1()) + self.assertTrue(self.msg.isOpenID1()) def test_isOpenID2(self): self.failIf(self.msg.isOpenID2()) @@ -404,7 +404,7 @@ def test_toURL(self): {'openid.mode': ['error'], 'openid.error': ['unit test'], 'openid.ns': [message.OPENID1_NS]}) def test_isOpenID1(self): - self.failUnless(self.msg.isOpenID1()) + self.assertTrue(self.msg.isOpenID1()) class OpenID2MessageTest(unittest.TestCase): @@ -596,15 +596,14 @@ def test_mysterious_missing_namespace_bug(self): 'sreg.email': 'a@b.com'} m = message.Message.fromOpenIDArgs(openid_args) - self.failUnless(('http://openid.net/extensions/sreg/1.1', 'sreg') in - list(m.namespaces.iteritems())) + self.assertEqual(m.namespaces.getAlias('http://openid.net/extensions/sreg/1.1'), 'sreg') missing = [] for k in openid_args['signed'].split(','): if not ("openid." + k) in m.toPostArgs().keys(): missing.append(k) self.assertEqual(missing, []) self.assertEqual(m.toArgs(), openid_args) - self.failUnless(m.isOpenID1()) + self.assertTrue(m.isOpenID1()) def test_112B(self): args = { @@ -630,7 +629,7 @@ def test_112B(self): missing.append(k) self.assertEqual(missing, [], missing) self.assertEqual(m.toPostArgs(), args) - self.failUnless(m.isOpenID2()) + self.assertTrue(m.isOpenID2()) def test_repetitive_namespaces(self): """ @@ -658,11 +657,10 @@ def test_repetitive_namespaces(self): def test_implicit_sreg_ns(self): openid_args = {'sreg.email': 'a@b.com'} m = message.Message.fromOpenIDArgs(openid_args) - self.failUnless((sreg.ns_uri, 'sreg') in - list(m.namespaces.iteritems())) + self.assertEqual(m.namespaces.getAlias(sreg.ns_uri), 'sreg') self.assertEqual(m.getArg(sreg.ns_uri, 'email'), 'a@b.com') self.assertEqual(m.toArgs(), openid_args) - self.failUnless(m.isOpenID1()) + self.assertTrue(m.isOpenID1()) def _test_delArgNS(self, ns): key = 'Camper van Beethoven' @@ -696,9 +694,9 @@ def test_overwriteExtensionArg(self): value_2 = 'value_2' self.msg.setArg(ns, key, value_1) - self.failUnless(self.msg.getArg(ns, key) == value_1) + self.assertEqual(self.msg.getArg(ns, key), value_1) self.msg.setArg(ns, key, value_2) - self.failUnless(self.msg.getArg(ns, key) == value_2) + self.assertEqual(self.msg.getArg(ns, key), value_2) def test_argList(self): self.assertRaises(TypeError, self.msg.fromPostArgs, {'arg': [1, 2, 3]}) @@ -707,7 +705,7 @@ def test_isOpenID1(self): self.failIf(self.msg.isOpenID1()) def test_isOpenID2(self): - self.failUnless(self.msg.isOpenID2()) + self.assertTrue(self.msg.isOpenID2()) class MessageTest(unittest.TestCase): @@ -881,16 +879,14 @@ def test_isOpenID1(self): for ns in v1_namespaces: m = message.Message(ns) - self.failUnless(m.isOpenID1(), "%r not recognized as OpenID 1" % - (ns,)) + self.assertTrue(m.isOpenID1(), "%r not recognized as OpenID 1" % ns) self.assertEqual(m.getOpenIDNamespace(), ns) - self.failUnless(m.namespaces.isImplicit(ns), - m.namespaces.getNamespaceURI(message.NULL_NAMESPACE)) + self.assertTrue(m.namespaces.isImplicit(ns)) def test_isOpenID2(self): ns = 'http://specs.openid.net/auth/2.0' m = message.Message(ns) - self.failUnless(m.isOpenID2()) + self.assertTrue(m.isOpenID2()) self.failIf(m.namespaces.isImplicit(message.NULL_NAMESPACE)) self.assertEqual(m.getOpenIDNamespace(), ns) @@ -902,7 +898,7 @@ def test_setOpenIDNamespace_explicit(self): def test_setOpenIDNamespace_implicit(self): m = message.Message() m.setOpenIDNamespace(message.THE_OTHER_OPENID1_NS, True) - self.failUnless(m.namespaces.isImplicit(message.THE_OTHER_OPENID1_NS)) + self.assertTrue(m.namespaces.isImplicit(message.THE_OTHER_OPENID1_NS)) def test_explicitOpenID11NSSerialzation(self): m = message.Message() @@ -926,7 +922,7 @@ def test_fromPostArgs_ns11(self): u'openid.trust_root': u'http://drupal.invalid', } m = message.Message.fromPostArgs(query) - self.failUnless(m.isOpenID1()) + self.assertTrue(m.isOpenID1()) class NamespaceMapTest(unittest.TestCase): @@ -935,8 +931,8 @@ def test_onealias(self): uri = 'http://example.com/foo' alias = "foo" nsm.addAlias(uri, alias) - self.failUnless(nsm.getNamespaceURI(alias) == uri) - self.failUnless(nsm.getAlias(uri) == alias) + self.assertEqual(nsm.getNamespaceURI(alias), uri) + self.assertEqual(nsm.getAlias(uri), alias) def test_iteration(self): nsm = message.NamespaceMap() @@ -944,12 +940,12 @@ def test_iteration(self): nsm.add(uripat % 0) for n in range(1, 23): - self.failUnless(uripat % (n - 1) in nsm) - self.failUnless(nsm.isDefined(uripat % (n - 1))) + self.assertIn(uripat % (n - 1), nsm) + self.assertTrue(nsm.isDefined(uripat % (n - 1))) nsm.add(uripat % n) for (uri, alias) in nsm.iteritems(): - self.failUnless(uri[22:] == alias[3:]) + self.assertEqual(uri[22:], alias[3:]) i = 0 it = nsm.iterAliases() @@ -958,7 +954,7 @@ def test_iteration(self): it.next() i += 1 except StopIteration: - self.failUnless(i == 23) + self.assertEqual(i, 23) i = 0 it = nsm.iterNamespaceURIs() @@ -967,7 +963,7 @@ def test_iteration(self): it.next() i += 1 except StopIteration: - self.failUnless(i == 23) + self.assertEqual(i, 23) if __name__ == '__main__': diff --git a/openid/test/test_negotiation.py b/openid/test/test_negotiation.py index 6be4528a..71ff200b 100644 --- a/openid/test/test_negotiation.py +++ b/openid/test/test_negotiation.py @@ -126,7 +126,7 @@ def testUnsupportedWithRetry(self): 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') self.consumer.return_messages = [msg, assoc] - self.failUnless(self.consumer._negotiateAssociation(self.endpoint) is assoc) + self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), assoc) self.failUnlessLogMatches('Unsupported association type') @@ -158,7 +158,7 @@ def testValid(self): 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') self.consumer.return_messages = [assoc] - self.failUnless(self.consumer._negotiateAssociation(self.endpoint) is assoc) + self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), assoc) self.failUnlessLogEmpty() @@ -238,7 +238,7 @@ def testUnsupportedWithRetry(self): 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') self.consumer.return_messages = [msg, assoc] - self.failUnless(self.consumer._negotiateAssociation(self.endpoint) is None) + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) self.failUnlessLogMatches('Server error when requesting an association') @@ -247,7 +247,7 @@ def testValid(self): 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') self.consumer.return_messages = [assoc] - self.failUnless(self.consumer._negotiateAssociation(self.endpoint) is assoc) + self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), assoc) self.failUnlessLogEmpty() @@ -268,10 +268,10 @@ def testAddAllowedTypeBadSessionType(self): def testAddAllowedTypeContents(self): assoc_type = 'HMAC-SHA1' - self.failUnless(self.n.addAllowedType(assoc_type) is None) + self.assertIsNone(self.n.addAllowedType(assoc_type)) for typ in association.getSessionTypes(assoc_type): - self.failUnless((assoc_type, typ) in self.n.allowed_types) + self.assertIn((assoc_type, typ), self.n.allowed_types) if __name__ == '__main__': diff --git a/openid/test/test_nonce.py b/openid/test/test_nonce.py index 3e6c5fe5..010178fa 100644 --- a/openid/test/test_nonce.py +++ b/openid/test/test_nonce.py @@ -9,14 +9,14 @@ class NonceTest(unittest.TestCase): def test_mkNonce(self): nonce = mkNonce() - self.failUnless(nonce_re.match(nonce)) - self.failUnless(len(nonce) == 26) + self.assertIsNotNone(nonce_re.match(nonce)) + self.assertEqual(len(nonce), 26) def test_mkNonce_when(self): nonce = mkNonce(0) - self.failUnless(nonce_re.match(nonce)) - self.failUnless(nonce.startswith('1970-01-01T00:00:00Z')) - self.failUnless(len(nonce) == 26) + self.assertIsNotNone(nonce_re.match(nonce)) + self.assertTrue(nonce.startswith('1970-01-01T00:00:00Z')) + self.assertEqual(len(nonce), 26) def test_splitNonce(self): s = '1970-01-01T00:00:00Z' @@ -29,7 +29,7 @@ def test_splitNonce(self): def test_mkSplit(self): t = 42 nonce_str = mkNonce(t) - self.failUnless(nonce_re.match(nonce_str)) + self.assertIsNotNone(nonce_re.match(nonce_str)) et, salt = splitNonce(nonce_str) self.assertEqual(len(salt), 6) self.assertEqual(et, t) diff --git a/openid/test/test_oidutil.py b/openid/test/test_oidutil.py index 7eed8865..3a66850e 100644 --- a/openid/test/test_oidutil.py +++ b/openid/test/test_oidutil.py @@ -135,10 +135,10 @@ class TestUnicodeConversion(unittest.TestCase): def test_toUnicode(self): # Unicode objects pass through - self.failUnless(isinstance(oidutil.toUnicode(u'fööbär'), unicode)) + self.assertIsInstance(oidutil.toUnicode(u'fööbär'), unicode) self.assertEquals(oidutil.toUnicode(u'fööbär'), u'fööbär') # UTF-8 encoded string are decoded - self.failUnless(isinstance(oidutil.toUnicode('fööbär'), unicode)) + self.assertIsInstance(oidutil.toUnicode('fööbär'), unicode) self.assertEquals(oidutil.toUnicode('fööbär'), u'fööbär') # Other encodings raise exceptions self.assertRaises(UnicodeDecodeError, lambda: oidutil.toUnicode(u'fööbär'.encode('latin-1'))) @@ -150,7 +150,7 @@ def testCopyHash(self): s = oidutil.Symbol("Foo") d = {s: 1} d_prime = copy.deepcopy(d) - self.failUnless(s in d_prime, "%r isn't in %r" % (s, d_prime)) + self.assertIn(s, d_prime, "%r isn't in %r" % (s, d_prime)) t = oidutil.Symbol("Bar") self.failIfEqual(hash(s), hash(t)) diff --git a/openid/test/test_pape_draft2.py b/openid/test/test_pape_draft2.py index 53e2c553..67ebcc73 100644 --- a/openid/test/test_pape_draft2.py +++ b/openid/test/test_pape_draft2.py @@ -216,4 +216,4 @@ def getSignedNS(self, ns_uri): oid_req = NoSigningDummyResponse(openid_req_msg, signed_stuff) resp = pape.Response.fromSuccessResponse(oid_req) - self.failUnless(resp is None) + self.assertIsNone(resp) diff --git a/openid/test/test_pape_draft5.py b/openid/test/test_pape_draft5.py index 0b0f10ea..89c79fda 100644 --- a/openid/test/test_pape_draft5.py +++ b/openid/test/test_pape_draft5.py @@ -383,7 +383,7 @@ def getSignedNS(self, ns_uri): oid_req = NoSigningDummyResponse(openid_req_msg, signed_stuff) resp = pape.Response.fromSuccessResponse(oid_req) - self.failUnless(resp is None) + self.assertIsNone(resp) if __name__ == '__main__': diff --git a/openid/test/test_parsehtml.py b/openid/test/test_parsehtml.py index 10a13163..546f9f8a 100644 --- a/openid/test/test_parsehtml.py +++ b/openid/test/test_parsehtml.py @@ -27,9 +27,9 @@ def test(self): msg = "%r != %r for case %s" % (found, expected, case) self.assertEqual(found, expected, msg) except HTMLParseError: - self.failUnless(expected == 'None', (case, expected)) + self.assertEqual(expected, 'None', (case, expected)) else: - self.failUnless(expected == 'EOF', (case, expected)) + self.assertEqual(expected, 'EOF', (case, expected)) def parseCases(data): diff --git a/openid/test/test_rpverify.py b/openid/test/test_rpverify.py index b2e2e23f..769037a3 100644 --- a/openid/test/test_rpverify.py +++ b/openid/test/test_rpverify.py @@ -171,21 +171,16 @@ def test_noEntries(self): def test_exactMatch(self): r = 'http://example.com/return.to' - self.failUnless(trustroot.returnToMatches([r], r)) + self.assertTrue(trustroot.returnToMatches([r], r)) def test_garbageMatch(self): r = 'http://example.com/return.to' - self.failUnless(trustroot.returnToMatches( - ['This is not a URL at all. In fact, it has characters, ' - 'like "<" that are not allowed in URLs', - r], - r)) + realm = 'This is not a URL at all. In fact, it has characters, like "<" that are not allowed in URLs' + self.assertTrue(trustroot.returnToMatches([realm, r], r)) def test_descendant(self): r = 'http://example.com/return.to' - self.failUnless(trustroot.returnToMatches( - [r], - 'http://example.com/return.to/user:joe')) + self.assertTrue(trustroot.returnToMatches([r], 'http://example.com/return.to/user:joe')) def test_wildcard(self): self.failIf(trustroot.returnToMatches( @@ -218,8 +213,7 @@ def vrfy(disco_url): self.assertEqual(disco_url, 'http://www.example.com/') return [return_to] - self.failUnless( - trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) + self.assertTrue(trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) self.failUnlessLogEmpty() def test_verifyFailWithDiscoveryCalled(self): diff --git a/openid/test/test_server.py b/openid/test/test_server.py index e6e9d431..c5b1c4ca 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -35,7 +35,7 @@ def test_browserWithReturnTo(self): 'openid.return_to': return_to, }) e = server.ProtocolError(args, "plucky") - self.failUnless(e.hasReturnTo()) + self.assertTrue(e.hasReturnTo()) expected_args = { 'openid.mode': ['error'], 'openid.error': ['plucky'], @@ -56,7 +56,7 @@ def test_browserWithReturnTo_OpenID2_GET(self): 'openid.return_to': return_to, }) e = server.ProtocolError(args, "plucky") - self.failUnless(e.hasReturnTo()) + self.assertTrue(e.hasReturnTo()) expected_args = { 'openid.ns': [OPENID2_NS], 'openid.mode': ['error'], @@ -78,10 +78,9 @@ def test_browserWithReturnTo_OpenID2_POST(self): 'openid.return_to': return_to, }) e = server.ProtocolError(args, "plucky") - self.failUnless(e.hasReturnTo()) - self.failUnless(e.whichEncoding() == server.ENCODE_HTML_FORM) - self.failUnless(e.toFormMarkup() == e.toMessage().toFormMarkup( - args.getArg(OPENID_NS, 'return_to'))) + self.assertTrue(e.hasReturnTo()) + self.assertEqual(e.whichEncoding(), server.ENCODE_HTML_FORM) + self.assertEqual(e.toFormMarkup(), e.toMessage().toFormMarkup(args.getArg(OPENID_NS, 'return_to'))) def test_browserWithReturnTo_OpenID1_exceeds_limit(self): return_to = "http://rp.unittest/consumer" + ('x' * OPENID1_URL_LIMIT) @@ -92,13 +91,13 @@ def test_browserWithReturnTo_OpenID1_exceeds_limit(self): 'openid.return_to': return_to, }) e = server.ProtocolError(args, "plucky") - self.failUnless(e.hasReturnTo()) + self.assertTrue(e.hasReturnTo()) expected_args = { 'openid.mode': ['error'], 'openid.error': ['plucky'], } - self.failUnless(e.whichEncoding() == server.ENCODE_URL) + self.assertEqual(e.whichEncoding(), server.ENCODE_URL) rt_base, result_args = e.encodeToURL().split('?', 1) result_args = cgi.parse_qs(result_args) @@ -163,12 +162,8 @@ def test_dictOfLists(self): 'openid.return_to': self.rt_url, 'openid.trust_root': self.tr_url, } - try: - result = self.decode(args) - except TypeError as err: - self.failUnless(str(err).find('values') != -1, err) - else: - self.fail("Expected TypeError, but got result %s" % (result,)) + with self.assertRaisesRegexp(TypeError, 'values'): + self.decode(args) def test_checkidImmediate(self): args = { @@ -181,7 +176,7 @@ def test_checkidImmediate(self): 'openid.some.extension': 'junk', } r = self.decode(args) - self.failUnless(isinstance(r, server.CheckIDRequest)) + self.assertIsInstance(r, server.CheckIDRequest) self.assertEqual(r.mode, "checkid_immediate") self.assertTrue(r.immediate) self.assertEqual(r.identity, self.id_url) @@ -198,7 +193,7 @@ def test_checkidSetup(self): 'openid.trust_root': self.tr_url, } r = self.decode(args) - self.failUnless(isinstance(r, server.CheckIDRequest)) + self.assertIsInstance(r, server.CheckIDRequest) self.assertEqual(r.mode, "checkid_setup") self.assertFalse(r.immediate) self.assertEqual(r.identity, self.id_url) @@ -216,7 +211,7 @@ def test_checkidSetupOpenID2(self): 'openid.realm': self.tr_url, } r = self.decode(args) - self.failUnless(isinstance(r, server.CheckIDRequest)) + self.assertIsInstance(r, server.CheckIDRequest) self.assertEqual(r.mode, "checkid_setup") self.assertFalse(r.immediate) self.assertEqual(r.identity, self.id_url) @@ -244,7 +239,7 @@ def test_checkidSetupNoIdentityOpenID2(self): 'openid.realm': self.tr_url, } r = self.decode(args) - self.failUnless(isinstance(r, server.CheckIDRequest)) + self.assertIsInstance(r, server.CheckIDRequest) self.assertEqual(r.mode, "checkid_setup") self.assertFalse(r.immediate) self.assertIsNone(r.identity) @@ -276,7 +271,7 @@ def test_checkidSetupNoReturnOpenID2(self): 'openid.assoc_handle': self.assoc_handle, 'openid.realm': self.tr_url, } - self.failUnless(isinstance(self.decode(args), server.CheckIDRequest)) + self.assertIsInstance(self.decode(args), server.CheckIDRequest) req = self.decode(args) self.assertRaises(server.NoReturnToError, req.answer, False) @@ -303,13 +298,9 @@ def test_checkidSetupBadReturn(self): 'openid.assoc_handle': self.assoc_handle, 'openid.return_to': 'not a url', } - try: - result = self.decode(args) - except server.ProtocolError as err: - self.failUnless(err.openid_message) - else: - self.fail("Expected ProtocolError, instead returned with %s" % - (result,)) + with self.assertRaises(server.ProtocolError) as catch: + self.decode(args) + self.assertTrue(catch.exception.openid_message) def test_checkidSetupUntrustedReturn(self): args = { @@ -319,13 +310,9 @@ def test_checkidSetupUntrustedReturn(self): 'openid.return_to': self.rt_url, 'openid.trust_root': 'http://not-the-return-place.unittest/', } - try: - result = self.decode(args) - except server.UntrustedReturnURL as err: - self.failUnless(err.openid_message) - else: - self.fail("Expected UntrustedReturnURL, instead returned with %s" % - (result,)) + with self.assertRaises(server.UntrustedReturnURL) as catch: + self.decode(args) + self.assertTrue(catch.exception.openid_message) def test_checkAuth(self): args = { @@ -339,7 +326,7 @@ def test_checkAuth(self): 'openid.baz': 'unsigned', } r = self.decode(args) - self.failUnless(isinstance(r, server.CheckAuthRequest)) + self.assertIsInstance(r, server.CheckAuthRequest) self.assertEqual(r.mode, 'check_authentication') self.assertEqual(r.sig, 'sigblob') @@ -367,7 +354,7 @@ def test_checkAuthAndInvalidate(self): 'openid.baz': 'unsigned', } r = self.decode(args) - self.failUnless(isinstance(r, server.CheckAuthRequest)) + self.assertIsInstance(r, server.CheckAuthRequest) self.assertEqual(r.invalidate_handle, '[[SMART_handle]]') def test_associateDH(self): @@ -377,11 +364,11 @@ def test_associateDH(self): 'openid.dh_consumer_public': "Rzup9265tw==", } r = self.decode(args) - self.failUnless(isinstance(r, server.AssociateRequest)) + self.assertIsInstance(r, server.AssociateRequest) self.assertEqual(r.mode, "associate") self.assertEqual(r.session.session_type, "DH-SHA1") self.assertEqual(r.assoc_type, "HMAC-SHA1") - self.failUnless(r.session.consumer_pubkey) + self.assertTrue(r.session.consumer_pubkey) def test_associateDHMissingKey(self): """Trying DH assoc w/o public key""" @@ -410,13 +397,13 @@ def test_associateDHModGen(self): 'openid.dh_gen': cryptutil.longToBase64(ALT_GEN), } r = self.decode(args) - self.failUnless(isinstance(r, server.AssociateRequest)) + self.assertIsInstance(r, server.AssociateRequest) self.assertEqual(r.mode, "associate") self.assertEqual(r.session.session_type, "DH-SHA1") self.assertEqual(r.assoc_type, "HMAC-SHA1") self.assertEqual(r.session.dh.modulus, ALT_MODULUS) self.assertEqual(r.session.dh.generator, ALT_GEN) - self.failUnless(r.session.consumer_pubkey) + self.assertTrue(r.session.consumer_pubkey) def test_associateDHCorruptModGen(self): # test dh with non-default but valid values for dh_modulus and dh_gen @@ -466,7 +453,7 @@ def test_associatePlain(self): 'openid.mode': 'associate', } r = self.decode(args) - self.failUnless(isinstance(r, server.AssociateRequest)) + self.assertIsInstance(r, server.AssociateRequest) self.assertEqual(r.mode, "associate") self.assertEqual(r.session.session_type, "no-encryption") self.assertEqual(r.assoc_type, "HMAC-SHA1") @@ -482,16 +469,9 @@ def test_invalidns(self): args = {'openid.ns': 'Tuesday', 'openid.mode': 'associate'} - try: - r = self.decode(args) - except server.ProtocolError as err: - # Assert that the ProtocolError does have a Message attached - # to it, even though the request wasn't a well-formed Message. - self.failUnless(err.openid_message) - # The error message contains the bad openid.ns. - self.failUnless('Tuesday' in str(err), str(err)) - else: - self.fail("Expected ProtocolError but returned with %r" % (r,)) + with self.assertRaisesRegexp(server.ProtocolError, 'Tuesday') as catch: + self.decode(args) + self.assertTrue(catch.exception.openid_message) class TestEncode(unittest.TestCase): @@ -526,7 +506,7 @@ def test_id_res_OpenID2_GET(self): }) self.failIf(response.renderAsForm()) - self.failUnless(response.whichEncoding() == server.ENCODE_URL) + self.assertEqual(response.whichEncoding(), server.ENCODE_URL) webresponse = self.encode(response) self.assertIn('location', webresponse.headers) @@ -553,9 +533,9 @@ def test_id_res_OpenID2_POST(self): 'return_to': 'x' * OPENID1_URL_LIMIT, }) - self.failUnless(response.renderAsForm()) - self.failUnless(len(response.encodeToURL()) > OPENID1_URL_LIMIT) - self.failUnless(response.whichEncoding() == server.ENCODE_HTML_FORM) + self.assertTrue(response.renderAsForm()) + self.assertGreater(len(response.encodeToURL()), OPENID1_URL_LIMIT) + self.assertEqual(response.whichEncoding(), server.ENCODE_HTML_FORM) webresponse = self.encode(response) self.assertIn(response.toFormMarkup(), webresponse.body) @@ -578,7 +558,7 @@ def test_toFormMarkup(self): }) form_markup = response.toFormMarkup({'foo': 'bar'}) - self.failUnless(' foo="bar"' in form_markup) + self.assertIn(' foo="bar"', form_markup) def test_toHTML(self): request = server.CheckIDRequest( @@ -598,11 +578,11 @@ def test_toHTML(self): 'return_to': 'x' * OPENID1_URL_LIMIT, }) html = response.toHTML() - self.failUnless('' in html) - self.failUnless('' in html) - self.failUnless('', html) + self.assertIn('', html) + self.assertIn(' OPENID1_URL_LIMIT) - self.failUnless(response.whichEncoding() == server.ENCODE_URL) + self.assertGreater(len(response.encodeToURL()), OPENID1_URL_LIMIT) + self.assertEqual(response.whichEncoding(), server.ENCODE_URL) webresponse = self.encode(response) self.assertEqual(webresponse.headers['location'], response.encodeToURL()) @@ -652,9 +632,8 @@ def test_id_res(self): self.assertIn('location', webresponse.headers) location = webresponse.headers['location'] - self.failUnless(location.startswith(request.return_to), - "%s does not start with %s" % (location, - request.return_to)) + self.assertTrue(location.startswith(request.return_to), + "%s does not start with %s" % (location, request.return_to)) # argh. q2 = dict(cgi.parse_qsl(urlparse(location)[4])) expected = response.fields.toPostArgs() @@ -691,7 +670,7 @@ def test_cancelToForm(self): 'mode': 'cancel', }) form = response.toFormMarkup() - self.failUnless(form) + self.assertTrue(form) def test_assocReply(self): msg = Message(OPENID2_NS) @@ -781,9 +760,9 @@ def test_idres(self): location = webresponse.headers['location'] query = cgi.parse_qs(urlparse(location)[4]) - self.failUnless('openid.sig' in query) - self.failUnless('openid.assoc_handle' in query) - self.failUnless('openid.signed' in query) + self.assertIn('openid.sig', query) + self.assertIn('openid.assoc_handle', query) + self.assertIn('openid.signed', query) def test_idresDumb(self): webresponse = self.encode(self.response) @@ -792,9 +771,9 @@ def test_idresDumb(self): location = webresponse.headers['location'] query = cgi.parse_qs(urlparse(location)[4]) - self.failUnless('openid.sig' in query) - self.failUnless('openid.assoc_handle' in query) - self.failUnless('openid.signed' in query) + self.assertIn('openid.sig', query) + self.assertIn('openid.assoc_handle', query) + self.assertIn('openid.signed', query) def test_forgotStore(self): self.encoder.signatory = None @@ -859,20 +838,16 @@ def test_trustRootInvalid(self): def test_trustRootValid(self): self.request.trust_root = "http://foo.unittest/" self.request.return_to = "http://foo.unittest/39" - self.failUnless(self.request.trustRootValid()) + self.assertTrue(self.request.trustRootValid()) def test_malformedTrustRoot(self): self.request.trust_root = "invalid://trust*root/" self.request.return_to = "http://foo.unittest/39" sentinel = object() self.request.message = sentinel - try: - result = self.request.trustRootValid() - except server.MalformedTrustRoot as why: - self.failUnless(sentinel is why.openid_message) - else: - self.fail('Expected MalformedTrustRoot exception. Got %r' - % (result,)) + with self.assertRaises(server.MalformedTrustRoot) as catch: + self.request.trustRootValid() + self.assertEqual(catch.exception.openid_message, sentinel) def test_trustRootValidNoReturnTo(self): request = server.CheckIDRequest( @@ -883,7 +858,7 @@ def test_trustRootValidNoReturnTo(self): op_endpoint=self.server.op_endpoint, ) - self.failUnless(request.trustRootValid()) + self.assertTrue(request.trustRootValid()) def test_returnToVerified_callsVerify(self): """Make sure that verifyReturnTo is calling the trustroot @@ -905,10 +880,9 @@ def vrfyExc(trust_root, return_to): self.assertEqual(return_to, self.request.return_to) raise sentinel - try: + with self.assertRaises(Exception) as catch: withVerifyReturnTo(vrfyExc, self.request.returnToVerified) - except Exception as e: - self.failUnless(e is sentinel, e) + self.assertEqual(catch.exception, sentinel) # Ensure that True and False are passed through unchanged def constVerify(val): @@ -938,8 +912,8 @@ def _expectAnswer(self, answer, identity=None, claimed_id=None): actual = answer.fields.getArg(OPENID_NS, k) self.assertEqual(actual, expected, "%s: expected %s, got %s" % (k, expected, actual)) - self.failUnless(answer.fields.hasKey(OPENID_NS, 'response_nonce')) - self.failUnless(answer.fields.getOpenIDNamespace() == OPENID2_NS) + self.assertTrue(answer.fields.hasKey(OPENID_NS, 'response_nonce')) + self.assertEqual(answer.fields.getOpenIDNamespace(), OPENID2_NS) # One for nonce, one for ns self.assertEqual(len(answer.fields.toPostArgs()), len(expected_list) + 2) @@ -1076,7 +1050,7 @@ def test_trustRootOpenID1(self): result = server.CheckIDRequest.fromMessage(msg, self.server.op_endpoint) - self.failUnless(result.trust_root == 'http://real_trust_root/') + self.assertEqual(result.trust_root, 'http://real_trust_root/') def test_trustRootOpenID2(self): """Ignore openid.trust_root in OpenID 2""" @@ -1091,7 +1065,7 @@ def test_trustRootOpenID2(self): result = server.CheckIDRequest.fromMessage(msg, self.server.op_endpoint) - self.failUnless(result.trust_root == 'http://real_trust_root/') + self.assertEqual(result.trust_root, 'http://real_trust_root/') def test_answerAllowNoTrustRoot(self): self.request.trust_root = None @@ -1158,9 +1132,9 @@ def test_answerAllowNoEndpointOpenID1(self): actual = answer.fields.getArg(OPENID_NS, k) self.assertEqual(actual, expected, "%s: expected %s, got %s" % (k, expected, actual)) - self.failUnless(answer.fields.hasKey(OPENID_NS, 'response_nonce')) + self.assertTrue(answer.fields.hasKey(OPENID_NS, 'response_nonce')) self.assertEqual(answer.fields.getOpenIDNamespace(), OPENID1_NS) - self.failUnless(answer.fields.namespaces.isImplicit(OPENID1_NS)) + self.assertTrue(answer.fields.namespaces.isImplicit(OPENID1_NS)) # One for nonce (OpenID v1 namespace is implicit) self.assertEqual(len(answer.fields.toPostArgs()), len(expected_list) + 1) @@ -1185,7 +1159,7 @@ def test_answerImmediateDenyOpenID2(self): usu = answer.fields.getArg(OPENID_NS, 'user_setup_url') expected_substr = 'openid.claimed_id=https%3A%2F%2Fround-lake.dustinice.workers.dev%3A443%2Fhttp%2Fclaimed-id.test%2F' - self.failUnless(expected_substr in usu, usu) + self.assertIn(expected_substr, usu) def test_answerImmediateDenyOpenID1(self): """Look for user_setup_url in checkid_immediate negative @@ -1199,10 +1173,9 @@ def test_answerImmediateDenyOpenID1(self): self.assertEqual(answer.request, self.request) self.assertEqual(len(answer.fields.toPostArgs()), 2, answer.fields) self.assertEqual(answer.fields.getOpenIDNamespace(), OPENID1_NS) - self.failUnless(answer.fields.namespaces.isImplicit(OPENID1_NS)) + self.assertTrue(answer.fields.namespaces.isImplicit(OPENID1_NS)) self.assertEqual(answer.fields.getArg(OPENID_NS, 'mode'), 'id_res') - self.failUnless(answer.fields.getArg( - OPENID_NS, 'user_setup_url', '').startswith(server_url)) + self.assertTrue(answer.fields.getArg(OPENID_NS, 'user_setup_url', '').startswith(server_url)) def test_answerSetupDeny(self): answer = self.request.answer(False) @@ -1374,8 +1347,8 @@ def test_dhSHA1(self): self.assertEqual(rfg("assoc_handle"), self.assoc.handle) self.failIf(rfg("mac_key")) self.assertEqual(rfg("session_type"), "DH-SHA1") - self.failUnless(rfg("enc_mac_key")) - self.failUnless(rfg("dh_server_public")) + self.assertTrue(rfg("enc_mac_key")) + self.assertTrue(rfg("dh_server_public")) enc_key = rfg("enc_mac_key").decode('base64') spub = cryptutil.base64ToLong(rfg("dh_server_public")) @@ -1399,8 +1372,8 @@ def test_dhSHA256(self): self.assertEqual(rfg("assoc_handle"), self.assoc.handle) self.failIf(rfg("mac_key")) self.assertEqual(rfg("session_type"), "DH-SHA256") - self.failUnless(rfg("enc_mac_key")) - self.failUnless(rfg("dh_server_public")) + self.assertTrue(rfg("enc_mac_key")) + self.assertTrue(rfg("dh_server_public")) enc_key = rfg("enc_mac_key").decode('base64') spub = cryptutil.base64ToLong(rfg("dh_server_public")) @@ -1490,7 +1463,7 @@ def failUnlessExpiresInMatches(self, msg, expected_expires_in): error_message = ('"expires_in" value not within %s of expected: ' 'expected=%s, actual=%s' % (slop, expected_expires_in, expires_in)) - self.failUnless(0 <= difference <= slop, error_message) + self.assertTrue(0 <= difference <= slop, error_message) def test_plaintext(self): self.assoc = self.signatory.createAssociation(dumb=False, assoc_type='HMAC-SHA1') @@ -1626,7 +1599,7 @@ def monkeyDo(request): def test_associate(self): request = server.AssociateRequest.fromMessage(Message.fromPostArgs({})) response = self.server.openid_associate(request) - self.failUnless(response.fields.hasKey(OPENID_NS, "assoc_handle"), + self.assertTrue(response.fields.hasKey(OPENID_NS, "assoc_handle"), "No assoc_handle here: %s" % (response.fields,)) def test_associate2(self): @@ -1647,8 +1620,8 @@ def test_associate2(self): request = server.AssociateRequest.fromMessage(msg) response = self.server.openid_associate(request) - self.failUnless(response.fields.hasKey(OPENID_NS, "error")) - self.failUnless(response.fields.hasKey(OPENID_NS, "error_code")) + self.assertTrue(response.fields.hasKey(OPENID_NS, "error")) + self.assertTrue(response.fields.hasKey(OPENID_NS, "error_code")) self.failIf(response.fields.hasKey(OPENID_NS, "assoc_handle")) self.failIf(response.fields.hasKey(OPENID_NS, "assoc_type")) self.failIf(response.fields.hasKey(OPENID_NS, "session_type")) @@ -1670,8 +1643,8 @@ def test_associate3(self): request = server.AssociateRequest.fromMessage(msg) response = self.server.openid_associate(request) - self.failUnless(response.fields.hasKey(OPENID_NS, "error")) - self.failUnless(response.fields.hasKey(OPENID_NS, "error_code")) + self.assertTrue(response.fields.hasKey(OPENID_NS, "error")) + self.assertTrue(response.fields.hasKey(OPENID_NS, "error_code")) self.failIf(response.fields.hasKey(OPENID_NS, "assoc_handle")) self.assertEqual(response.fields.getArg(OPENID_NS, "assoc_type"), 'HMAC-SHA256') self.assertEqual(response.fields.getArg(OPENID_NS, "session_type"), 'DH-SHA256') @@ -1691,7 +1664,7 @@ def test_associate4(self): message = Message.fromPostArgs(query) request = server.AssociateRequest.fromMessage(message) response = self.server.openid_associate(request) - self.failUnless(response.fields.hasKey(OPENID_NS, "assoc_handle")) + self.assertTrue(response.fields.hasKey(OPENID_NS, "assoc_handle")) def test_missingSessionTypeOpenID2(self): """Make sure session_type is required in OpenID 2""" @@ -1713,7 +1686,7 @@ def test_missingAssocTypeOpenID2(self): def test_checkAuth(self): request = server.CheckAuthRequest('arrrrrf', '0x3999', []) response = self.server.openid_check_authentication(request) - self.failUnless(response.fields.hasKey(OPENID_NS, "is_valid")) + self.assertTrue(response.fields.hasKey(OPENID_NS, "is_valid")) class TestSignatory(unittest.TestCase, CatchLogs): @@ -1742,7 +1715,7 @@ def test_sign(self): sresponse = self.signatory.sign(response) self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'assoc_handle'), assoc_handle) self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), 'assoc_handle,azu,bar,foo,signed') - self.failUnless(sresponse.fields.getArg(OPENID_NS, 'sig')) + self.assertTrue(sresponse.fields.getArg(OPENID_NS, 'sig')) self.failIf(self.messages, self.messages) def test_signDumb(self): @@ -1758,11 +1731,11 @@ def test_signDumb(self): }) sresponse = self.signatory.sign(response) assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') - self.failUnless(assoc_handle) + self.assertTrue(assoc_handle) assoc = self.signatory.getAssociation(assoc_handle, dumb=True) - self.failUnless(assoc) + self.assertTrue(assoc) self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), 'assoc_handle,azu,bar,foo,ns,signed') - self.failUnless(sresponse.fields.getArg(OPENID_NS, 'sig')) + self.assertTrue(sresponse.fields.getArg(OPENID_NS, 'sig')) self.failIf(self.messages, self.messages) def test_signExpired(self): @@ -1786,7 +1759,7 @@ def test_signExpired(self): self._normal_key, association.Association.fromExpiresIn(-10, assoc_handle, 'sekrit', 'HMAC-SHA1')) - self.failUnless(self.store.getAssociation(self._normal_key, assoc_handle)) + self.assertTrue(self.store.getAssociation(self._normal_key, assoc_handle)) request.assoc_handle = assoc_handle response = server.OpenIDResponse(request) @@ -1798,23 +1771,23 @@ def test_signExpired(self): sresponse = self.signatory.sign(response) new_assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') - self.failUnless(new_assoc_handle) + self.assertTrue(new_assoc_handle) self.failIfEqual(new_assoc_handle, assoc_handle) self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'invalidate_handle'), assoc_handle) self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), 'assoc_handle,azu,bar,foo,invalidate_handle,signed') - self.failUnless(sresponse.fields.getArg(OPENID_NS, 'sig')) + self.assertTrue(sresponse.fields.getArg(OPENID_NS, 'sig')) # make sure the expired association is gone self.failIf(self.store.getAssociation(self._normal_key, assoc_handle), "expired association is still retrievable.") # make sure the new key is a dumb mode association - self.failUnless(self.store.getAssociation(self._dumb_key, new_assoc_handle)) + self.assertTrue(self.store.getAssociation(self._dumb_key, new_assoc_handle)) self.failIf(self.store.getAssociation(self._normal_key, new_assoc_handle)) - self.failUnless(self.messages) + self.assertTrue(self.messages) def test_signInvalidHandle(self): request = server.OpenIDRequest() @@ -1831,17 +1804,17 @@ def test_signInvalidHandle(self): sresponse = self.signatory.sign(response) new_assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') - self.failUnless(new_assoc_handle) + self.assertTrue(new_assoc_handle) self.failIfEqual(new_assoc_handle, assoc_handle) self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'invalidate_handle'), assoc_handle) self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), 'assoc_handle,azu,bar,foo,invalidate_handle,signed') - self.failUnless(sresponse.fields.getArg(OPENID_NS, 'sig')) + self.assertTrue(sresponse.fields.getArg(OPENID_NS, 'sig')) # make sure the new key is a dumb mode association - self.failUnless(self.store.getAssociation(self._dumb_key, new_assoc_handle)) + self.assertTrue(self.store.getAssociation(self._dumb_key, new_assoc_handle)) self.failIf(self.store.getAssociation(self._normal_key, new_assoc_handle)) self.failIf(self.messages, self.messages) @@ -1862,7 +1835,7 @@ def test_verify(self): verified = self.signatory.verify(assoc_handle, signed) self.failIf(self.messages, self.messages) - self.failUnless(verified) + self.assertTrue(verified) def test_verifyBadSig(self): assoc_handle = '{vroom}{zoom}' @@ -1893,7 +1866,7 @@ def test_verifyBadHandle(self): verified = self.signatory.verify(assoc_handle, signed) self.failIf(verified) - self.failUnless(self.messages) + self.assertTrue(self.messages) def test_verifyAssocMismatch(self): """Attempt to validate sign-all message with a signed-list assoc.""" @@ -1911,12 +1884,12 @@ def test_verifyAssocMismatch(self): verified = self.signatory.verify(assoc_handle, signed) self.failIf(verified) - self.failUnless(self.messages) + self.assertTrue(self.messages) def test_getAssoc(self): assoc_handle = self.makeAssoc(dumb=True) assoc = self.signatory.getAssociation(assoc_handle, True) - self.failUnless(assoc) + self.assertTrue(assoc) self.assertEqual(assoc.handle, assoc_handle) self.failIf(self.messages, self.messages) @@ -1924,7 +1897,7 @@ def test_getAssocExpired(self): assoc_handle = self.makeAssoc(dumb=True, lifetime=-10) assoc = self.signatory.getAssociation(assoc_handle, True) self.failIf(assoc, assoc) - self.failUnless(self.messages) + self.assertTrue(self.messages) def test_getAssocInvalid(self): ah = 'no-such-handle' @@ -1951,7 +1924,7 @@ def test_getAssocNormalVsDumb(self): def test_createAssociation(self): assoc = self.signatory.createAssociation(dumb=False) - self.failUnless(self.signatory.getAssociation(assoc.handle, dumb=False)) + self.assertTrue(self.signatory.getAssociation(assoc.handle, dumb=False)) self.failIf(self.messages, self.messages) def makeAssoc(self, dumb, lifetime=60): @@ -1969,9 +1942,9 @@ def test_invalidate(self): self.store.storeAssociation(self._dumb_key, assoc) assoc = self.signatory.getAssociation(assoc_handle, dumb=True) - self.failUnless(assoc) + self.assertTrue(assoc) assoc = self.signatory.getAssociation(assoc_handle, dumb=True) - self.failUnless(assoc) + self.assertTrue(assoc) self.signatory.invalidate(assoc_handle, dumb=True) assoc = self.signatory.getAssociation(assoc_handle, dumb=True) self.failIf(assoc) diff --git a/openid/test/test_sreg.py b/openid/test/test_sreg.py index acb3c25f..572c7d19 100644 --- a/openid/test/test_sreg.py +++ b/openid/test/test_sreg.py @@ -42,12 +42,12 @@ def test_unsupported(self): def test_supported_1_1(self): endpoint = FakeEndpoint([sreg.ns_uri_1_1]) - self.failUnless(sreg.supportsSReg(endpoint)) + self.assertTrue(sreg.supportsSReg(endpoint)) self.assertEqual(endpoint.checked_uris, [sreg.ns_uri_1_1]) def test_supported_1_0(self): endpoint = FakeEndpoint([sreg.ns_uri_1_0]) - self.failUnless(sreg.supportsSReg(endpoint)) + self.assertTrue(sreg.supportsSReg(endpoint)) self.assertEqual(endpoint.checked_uris, [sreg.ns_uri_1_1, sreg.ns_uri_1_0]) @@ -116,8 +116,8 @@ def test_openID1_sregNSfromArgs(self): m = Message.fromOpenIDArgs(args) - self.failUnless(m.getArg(sreg.ns_uri_1_1, 'optional') == 'nickname') - self.failUnless(m.getArg(sreg.ns_uri_1_1, 'required') == 'dob') + self.assertEqual(m.getArg(sreg.ns_uri_1_1, 'optional'), 'nickname') + self.assertEqual(m.getArg(sreg.ns_uri_1_1, 'required'), 'dob') class SRegRequestTest(unittest.TestCase): @@ -174,7 +174,7 @@ def parseExtensionArgs(req_self, args): req = TestingReq.fromOpenIDRequest(openid_req) self.assertIsInstance(req, TestingReq) - self.failUnless(msg.copied) + self.assertTrue(msg.copied) def test_parseExtensionArgs_empty(self): req = sreg.SRegRequest() @@ -260,7 +260,7 @@ def test_wereFieldsRequested(self): req = sreg.SRegRequest() self.failIf(req.wereFieldsRequested()) req.requestField('gender') - self.failUnless(req.wereFieldsRequested()) + self.assertTrue(req.wereFieldsRequested()) def test_contains(self): req = sreg.SRegRequest() @@ -272,7 +272,7 @@ def test_contains(self): req.requestField('nickname') for field_name in sreg.data_fields: if field_name == 'nickname': - self.failUnless(field_name in req) + self.assertIn(field_name, req) else: self.failIf(field_name in req) @@ -406,7 +406,7 @@ class SRegResponseTest(unittest.TestCase): def test_construct(self): resp = sreg.SRegResponse(data) - self.failUnless(resp) + self.assertTrue(resp) empty_resp = sreg.SRegResponse({}) self.failIf(empty_resp) diff --git a/openid/test/test_symbol.py b/openid/test/test_symbol.py index 7eab28b2..9e52844a 100644 --- a/openid/test/test_symbol.py +++ b/openid/test/test_symbol.py @@ -30,7 +30,7 @@ def test_otherInequality(self): def test_ne_inequality(self): x = oidutil.Symbol('xxx') y = oidutil.Symbol('yyy') - self.failUnless(x != y) + self.assertNotEqual(x, y) if __name__ == '__main__': diff --git a/openid/test/test_verifydisco.py b/openid/test/test_verifydisco.py index f8086882..c0055ef9 100644 --- a/openid/test/test_verifydisco.py +++ b/openid/test/test_verifydisco.py @@ -17,15 +17,8 @@ def constResult(*args, **kwargs): class DiscoveryVerificationTest(OpenIDTestMixin, TestIdRes): def failUnlessProtocolError(self, prefix, callable, *args, **kwargs): - try: - result = callable(*args, **kwargs) - except consumer.ProtocolError as e: - self.failUnless( - e[0].startswith(prefix), - 'Expected message prefix %r, got message %r' % (prefix, e[0])) - else: - self.fail('Expected ProtocolError with prefix %r, ' - 'got successful return %r' % (prefix, result)) + with self.assertRaisesRegexp(consumer.ProtocolError, prefix): + callable(*args, **kwargs) def test_openID1NoLocalID(self): endpoint = discover.OpenIDServiceEndpoint() @@ -70,7 +63,7 @@ def test_openID2NoIdentifiers(self): msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS, 'op_endpoint': op_endpoint}) result_endpoint = self.consumer._verifyDiscoveryResults(msg) - self.failUnless(result_endpoint.isOPIdentifier()) + self.assertTrue(result_endpoint.isOPIdentifier()) self.assertEqual(result_endpoint.server_url, op_endpoint) self.assertIsNone(result_endpoint.claimed_id) self.failUnlessLogEmpty() @@ -121,7 +114,7 @@ def test_openid2UsePreDiscovered(self): 'claimed_id': endpoint.claimed_id, 'op_endpoint': endpoint.server_url}) result = self.consumer._verifyDiscoveryResults(msg, endpoint) - self.failUnless(result is endpoint) + self.assertEqual(result, endpoint) self.failUnlessLogEmpty() def test_openid2UsePreDiscoveredWrongType(self): @@ -147,13 +140,8 @@ def discoverAndVerify(claimed_id, to_match_endpoints): 'claimed_id': endpoint.claimed_id, 'op_endpoint': endpoint.server_url}) - try: - r = self.consumer._verifyDiscoveryResults(msg, endpoint) - except consumer.ProtocolError as e: - # Should we make more ProtocolError subclasses? - self.failUnless(str(e), text) - else: - self.fail("expected ProtocolError, %r returned." % (r,)) + with self.assertRaisesRegexp(consumer.ProtocolError, text): + self.consumer._verifyDiscoveryResults(msg, endpoint) self.failUnlessLogMatches('Error attempting to use stored', 'Attempting discovery') @@ -169,7 +157,7 @@ def test_openid1UsePreDiscovered(self): {'ns': message.OPENID1_NS, 'identity': endpoint.local_id}) result = self.consumer._verifyDiscoveryResults(msg, endpoint) - self.failUnless(result is endpoint) + self.assertEqual(result, endpoint) self.failUnlessLogEmpty() def test_openid1UsePreDiscoveredWrongType(self): @@ -243,7 +231,7 @@ def test_openid1Fallback1_0(self): actual_endpoint = self.consumer._verifyDiscoveryResults( resp_mesg, endpoint) - self.failUnless(actual_endpoint is expected_endpoint) + self.assertEqual(actual_endpoint, expected_endpoint) # XXX: test the implementation of _discoverAndVerify From 42523afffd293ea6acc947bf97eaaec093e521ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 26 Jan 2018 11:09:43 +0100 Subject: [PATCH 036/151] Replace failIf --- examples/djopenid/server/tests.py | 4 +- openid/test/support.py | 2 +- openid/test/test_association_response.py | 2 +- openid/test/test_ax.py | 16 ++--- openid/test/test_consumer.py | 19 +++--- openid/test/test_discover.py | 32 +++++---- openid/test/test_extension.py | 2 +- openid/test/test_fetchers.py | 3 +- openid/test/test_message.py | 20 +++--- openid/test/test_pape_draft5.py | 2 +- openid/test/test_rpverify.py | 18 ++---- openid/test/test_server.py | 82 ++++++++++++------------ openid/test/test_sreg.py | 14 ++-- openid/test/test_symbol.py | 4 +- 14 files changed, 105 insertions(+), 115 deletions(-) diff --git a/examples/djopenid/server/tests.py b/examples/djopenid/server/tests.py index 0beefc21..2a3b86b4 100644 --- a/examples/djopenid/server/tests.py +++ b/examples/djopenid/server/tests.py @@ -62,8 +62,8 @@ def test_cancel(self): self.assertEqual(response.status_code, HTTP_REDIRECT) finalURL = response['location'] self.assertIn('openid.mode=cancel', finalURL) - self.failIf('openid.identity=' in finalURL, finalURL) - self.failIf('openid.sreg.postcode=12345' in finalURL, finalURL) + self.assertNotIn('openid.identity=', finalURL) + self.assertNotIn('openid.sreg.postcode=12345', finalURL) class TestShowDecidePage(TestCase): diff --git a/openid/test/support.py b/openid/test/support.py index 04749042..c2d45ea3 100644 --- a/openid/test/support.py +++ b/openid/test/support.py @@ -32,7 +32,7 @@ def failIfOpenIDKeyExists(self, msg, key, ns=None): actual = msg.getArg(ns, key) error_message = 'openid.%s unexpectedly present: %s' % (key, actual) - self.failIf(actual is not None, error_message) + self.assertIsNone(actual, error_message) class CatchLogs(object): diff --git a/openid/test/test_association_response.py b/openid/test/test_association_response.py index d880a407..9bac3e21 100644 --- a/openid/test/test_association_response.py +++ b/openid/test/test_association_response.py @@ -317,7 +317,7 @@ def _setUpDH(self): def test_success(self): sess, server_resp = self._setUpDH() ret = self.consumer._extractAssociation(server_resp, sess) - self.failIf(ret is None) + self.assertIsNotNone(ret) self.assertEqual(ret.assoc_type, 'HMAC-SHA1') self.assertEqual(ret.secret, self.secret) self.assertEqual(ret.handle, 'handle') diff --git a/openid/test/test_ax.py b/openid/test/test_ax.py index d6bdaa24..83fe5cf1 100644 --- a/openid/test/test_ax.py +++ b/openid/test/test_ax.py @@ -45,7 +45,7 @@ def test_construct(self): self.assertEqual(ainfo.type_uri, type_uri) self.assertEqual(ainfo.count, 1) - self.failIf(ainfo.required) + self.assertFalse(ainfo.required) self.assertIsNone(ainfo.alias) @@ -209,7 +209,7 @@ def test_add(self): uri = 'mud://puddle' # Not yet added: - self.failIf(uri in self.msg) + self.assertNotIn(uri, self.msg) attr = ax.AttrInfo(uri) self.msg.add(attr) @@ -305,7 +305,7 @@ def test_parseExtensionArgs(self): self.assertEqual(list(self.msg), [self.type_a]) attr_info = self.msg.requested_attributes.get(self.type_a) self.assertIsNotNone(attr_info) - self.failIf(attr_info.required) + self.assertFalse(attr_info.required) self.assertEqual(attr_info.type_uri, self.type_a) self.assertEqual(attr_info.alias, self.alias_a) self.assertEqual(list(self.msg.iterAttrs()), [attr_info]) @@ -318,7 +318,7 @@ def test_extensionArgs_idempotent(self): } self.msg.parseExtensionArgs(extension_args) self.assertEqual(self.msg.getExtensionArgs(), extension_args) - self.failIf(self.msg.requested_attributes[self.type_a].required) + self.assertFalse(self.msg.requested_attributes[self.type_a].required) def test_extensionArgs_idempotent_count_required(self): extension_args = { @@ -602,18 +602,18 @@ class StoreResponseTest(unittest.TestCase): def test_success(self): msg = ax.StoreResponse() self.assertTrue(msg.succeeded()) - self.failIf(msg.error_message) + self.assertFalse(msg.error_message) self.assertEqual(msg.getExtensionArgs(), {'mode': 'store_response_success'}) def test_fail_nomsg(self): msg = ax.StoreResponse(False) - self.failIf(msg.succeeded()) - self.failIf(msg.error_message) + self.assertFalse(msg.succeeded()) + self.assertFalse(msg.error_message) self.assertEqual(msg.getExtensionArgs(), {'mode': 'store_response_failure'}) def test_fail_msg(self): reason = 'no reason, really' msg = ax.StoreResponse(False, reason) - self.failIf(msg.succeeded()) + self.assertFalse(msg.succeeded()) self.assertEqual(msg.error_message, reason) self.assertEqual(msg.getExtensionArgs(), {'mode': 'store_response_failure', 'error': reason}) diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index 68caf0c8..169a4233 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -570,13 +570,13 @@ def test_missingAnswer(self): """check_authentication returns false when the server sends no answer""" response = Message.fromOpenIDArgs({}) r = self.consumer._processCheckAuthResponse(response, self.server_url) - self.failIf(r) + self.assertFalse(r) def test_badResponse(self): """check_authentication returns false when is_valid is false""" response = Message.fromOpenIDArgs({'is_valid': 'false'}) r = self.consumer._processCheckAuthResponse(response, self.server_url) - self.failIf(r) + self.assertFalse(r) def test_badResponseInvalidate(self): """Make sure that the handle is invalidated when is_valid is false @@ -593,7 +593,7 @@ def test_badResponseInvalidate(self): 'invalidate_handle': 'handle', }) r = self.consumer._processCheckAuthResponse(response, self.server_url) - self.failIf(r) + self.assertFalse(r) self.assertIsNone(self.consumer.store.getAssociation(self.server_url)) def test_invalidateMissing(self): @@ -1095,7 +1095,7 @@ def test_completeBadReturnTo(self): for bad in bad_return_tos: m.setArg(OPENID_NS, 'return_to', bad) - self.failIf(self.consumer._checkReturnTo(m, return_to)) + self.assertFalse(self.consumer._checkReturnTo(m, return_to)) def test_completeGoodReturnTo(self): """Test GenericConsumer.complete()'s handling of good @@ -1180,7 +1180,7 @@ def test_error(self): 'openid.stuff': 'a value'} r = self.consumer._checkAuth(Message.fromPostArgs(query), http_server_url) - self.failIf(r) + self.assertFalse(r) self.assertTrue(self.messages) def test_bad_args(self): @@ -1286,7 +1286,7 @@ def test_error_exception_wrapped(self): self.assertIsNone(self.consumer._getAssociation(e)) msg = Message.fromPostArgs({'openid.signed': ''}) - self.failIf(self.consumer._checkAuth(msg, 'some://url')) + self.assertFalse(self.consumer._checkAuth(msg, 'some://url')) class TestSuccessResponse(unittest.TestCase): @@ -1482,7 +1482,7 @@ def _doResp(self, auth_req, exp_resp): if self.endpoint.claimed_id != IDENTIFIER_SELECT: self.assertEqual(resp.identity_url, self.identity_url) - self.failIf(self.consumer._token_key in self.session) + self.assertNotIn(self.consumer._token_key, self.session) # Expected status response self.assertEqual(resp.status, exp_resp.status) @@ -1494,7 +1494,7 @@ def _doRespNoDisco(self, exp_resp): auth_req = self.consumer.beginWithoutDiscovery(self.endpoint) resp = self._doResp(auth_req, exp_resp) # There should be nothing left in the session once we have completed. - self.failIf(self.session) + self.assertFalse(self.session) return resp def test_noDiscoCompleteSuccessWithToken(self): @@ -1522,11 +1522,10 @@ def _doRespDisco(self, is_clean, exp_resp): auth_req = self.consumer.begin(self.identity_url) resp = self._doResp(auth_req, exp_resp) - manager = self.discovery.getManager() if is_clean: self.assertIsNone(self.discovery.getManager()) else: - self.failIf(self.discovery.getManager() is None, manager) + self.assertIsNotNone(self.discovery.getManager()) return resp diff --git a/openid/test/test_discover.py b/openid/test/test_discover.py index 2a7b4d8e..458ed636 100644 --- a/openid/test/test_discover.py +++ b/openid/test/test_discover.py @@ -154,12 +154,12 @@ def _checkService(self, s, ): self.assertEqual(s.server_url, server_url) if types == ['2.0 OP']: - self.failIf(claimed_id) - self.failIf(local_id) - self.failIf(s.claimed_id) - self.failIf(s.local_id) - self.failIf(s.getLocalID()) - self.failIf(s.compatibilityMode()) + self.assertIsNone(claimed_id) + self.assertIsNone(local_id) + self.assertIsNone(s.claimed_id) + self.assertIsNone(s.local_id) + self.assertIsNone(s.getLocalID()) + self.assertFalse(s.compatibilityMode()) self.assertTrue(s.isOPIdentifier()) self.assertEqual(s.preferredNamespace(), discover.OPENID_2_0_MESSAGE_NS) else: @@ -169,8 +169,7 @@ def _checkService(self, s, if used_yadis: self.assertTrue(s.used_yadis, "Expected to use Yadis") else: - self.failIf(s.used_yadis, - "Expected to use old-style discovery") + self.assertFalse(s.used_yadis, "Expected to use old-style discovery") openid_types = { '1.1': discover.OPENID_1_1_TYPE, @@ -570,7 +569,7 @@ def test_xri_normalize(self): def test_xriNoCanonicalID(self): user_xri, services = discover.discoverXRI('=smoker*bad') - self.failIf(services) + self.assertFalse(services) def test_useCanonicalID(self): """When there is no delegate, the CanonicalID should be used with XRI. @@ -621,19 +620,19 @@ def setUp(self): self.endpoint = discover.OpenIDServiceEndpoint() def test_none(self): - self.failIf(self.endpoint.isOPIdentifier()) + self.assertFalse(self.endpoint.isOPIdentifier()) def test_openid1_0(self): self.endpoint.type_uris = [discover.OPENID_1_0_TYPE] - self.failIf(self.endpoint.isOPIdentifier()) + self.assertFalse(self.endpoint.isOPIdentifier()) def test_openid1_1(self): self.endpoint.type_uris = [discover.OPENID_1_1_TYPE] - self.failIf(self.endpoint.isOPIdentifier()) + self.assertFalse(self.endpoint.isOPIdentifier()) def test_openid2(self): self.endpoint.type_uris = [discover.OPENID_2_0_TYPE] - self.failIf(self.endpoint.isOPIdentifier()) + self.assertFalse(self.endpoint.isOPIdentifier()) def test_openid2OP(self): self.endpoint.type_uris = [discover.OPENID_IDP_2_0_TYPE] @@ -642,7 +641,7 @@ def test_openid2OP(self): def test_multipleMissing(self): self.endpoint.type_uris = [discover.OPENID_2_0_TYPE, discover.OPENID_1_0_TYPE] - self.failIf(self.endpoint.isOPIdentifier()) + self.assertFalse(self.endpoint.isOPIdentifier()) def test_multiplePresent(self): self.endpoint.type_uris = [discover.OPENID_2_0_TYPE, @@ -665,7 +664,7 @@ def test_noIdentifiers(self): self.assertIsNone(self.endpoint.claimed_id) def test_compatibility(self): - self.failIf(self.endpoint.compatibilityMode()) + self.assertFalse(self.endpoint.compatibilityMode()) def test_canonicalID(self): self.assertIsNone(self.endpoint.canonicalID) @@ -720,8 +719,7 @@ def failUnlessSupportsOnly(self, *types): if t in types: self.assertTrue(self.endpoint.supportsType(t), "Must support %r" % t) else: - self.failIf(self.endpoint.supportsType(t), - "Shouldn't support %r" % (t,)) + self.assertFalse(self.endpoint.supportsType(t), "Shouldn't support %r" % (t,)) def test_supportsNothing(self): self.failUnlessSupportsOnly() diff --git a/openid/test/test_extension.py b/openid/test/test_extension.py index 4bccd754..640f11a6 100644 --- a/openid/test/test_extension.py +++ b/openid/test/test_extension.py @@ -26,6 +26,6 @@ def test_OpenID2(self): ext = DummyExtension() ext.toMessage(oid2_msg) namespaces = oid2_msg.namespaces - self.failIf(namespaces.isImplicit(DummyExtension.ns_uri)) + self.assertFalse(namespaces.isImplicit(DummyExtension.ns_uri)) self.assertEqual(DummyExtension.ns_uri, namespaces.getNamespaceURI(DummyExtension.ns_alias)) self.assertEqual(DummyExtension.ns_alias, namespaces.getAlias(DummyExtension.ns_uri)) diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index e6df2d53..b1c066d0 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -259,8 +259,7 @@ def test_notWrapped(self): fetcher = fetchers.Urllib2Fetcher() fetchers.setDefaultFetcher(fetcher, wrap_exceptions=False) - self.failIf(isinstance(fetchers.getDefaultFetcher(), - fetchers.ExceptionWrappingFetcher)) + self.assertNotIsInstance(fetchers.getDefaultFetcher(), fetchers.ExceptionWrappingFetcher) with self.assertRaises(urllib2.URLError): fetchers.fetch('http://invalid.janrain.com/') diff --git a/openid/test/test_message.py b/openid/test/test_message.py index c414313a..3ce94eaa 100644 --- a/openid/test/test_message.py +++ b/openid/test/test_message.py @@ -199,10 +199,10 @@ def test_delArgNS3(self): self._test_delArgNS('urn:nothing-significant') def test_isOpenID1(self): - self.failIf(self.msg.isOpenID1()) + self.assertFalse(self.msg.isOpenID1()) def test_isOpenID2(self): - self.failIf(self.msg.isOpenID2()) + self.assertFalse(self.msg.isOpenID2()) class OpenID1MessageTest(unittest.TestCase): @@ -368,7 +368,7 @@ def test_isOpenID1(self): self.assertTrue(self.msg.isOpenID1()) def test_isOpenID2(self): - self.failIf(self.msg.isOpenID2()) + self.assertFalse(self.msg.isOpenID2()) class OpenID1ExplicitMessageTest(unittest.TestCase): @@ -702,7 +702,7 @@ def test_argList(self): self.assertRaises(TypeError, self.msg.fromPostArgs, {'arg': [1, 2, 3]}) def test_isOpenID1(self): - self.failIf(self.msg.isOpenID1()) + self.assertFalse(self.msg.isOpenID1()) def test_isOpenID2(self): self.assertTrue(self.msg.isOpenID2()) @@ -821,10 +821,10 @@ def test_toFormMarkup_bug_with_utf8_values(self): # encoded strings to be converted to XML character references, # "ünicöde_key" becomes "ünicöde_key" and # "ünicöde_välüe" becomes "ünicöde_välüe" - self.failIf('ünicöde_key' in html, - 'UTF-8 bytes should not convert to XML character references') - self.failIf('ünicöde_välüe' in html, - 'UTF-8 bytes should not convert to XML character references') + self.assertNotIn('ünicöde_key', html, + 'UTF-8 bytes should not convert to XML character references') + self.assertNotIn('ünicöde_välüe', html, + 'UTF-8 bytes should not convert to XML character references') def test_overrideMethod(self): """Be sure that caller cannot change form method to GET.""" @@ -887,13 +887,13 @@ def test_isOpenID2(self): ns = 'http://specs.openid.net/auth/2.0' m = message.Message(ns) self.assertTrue(m.isOpenID2()) - self.failIf(m.namespaces.isImplicit(message.NULL_NAMESPACE)) + self.assertFalse(m.namespaces.isImplicit(message.NULL_NAMESPACE)) self.assertEqual(m.getOpenIDNamespace(), ns) def test_setOpenIDNamespace_explicit(self): m = message.Message() m.setOpenIDNamespace(message.THE_OTHER_OPENID1_NS, False) - self.failIf(m.namespaces.isImplicit(message.THE_OTHER_OPENID1_NS)) + self.assertFalse(m.namespaces.isImplicit(message.THE_OTHER_OPENID1_NS)) def test_setOpenIDNamespace_implicit(self): m = message.Message() diff --git a/openid/test/test_pape_draft5.py b/openid/test/test_pape_draft5.py index 89c79fda..104411a3 100644 --- a/openid/test/test_pape_draft5.py +++ b/openid/test/test_pape_draft5.py @@ -17,7 +17,7 @@ def test_construct(self): self.assertEqual(self.req.preferred_auth_policies, []) self.assertIsNone(self.req.max_auth_age) self.assertEqual(self.req.ns_alias, 'pape') - self.failIf(self.req.preferred_auth_level_types) + self.assertFalse(self.req.preferred_auth_level_types) bogus_levels = ['http://janrain.com/our_levels'] req2 = pape.Request( diff --git a/openid/test/test_rpverify.py b/openid/test/test_rpverify.py index 769037a3..d12cc5b9 100644 --- a/openid/test/test_rpverify.py +++ b/openid/test/test_rpverify.py @@ -167,7 +167,7 @@ def test_twoEntries_withOther(self): class TestReturnToMatches(unittest.TestCase): def test_noEntries(self): - self.failIf(trustroot.returnToMatches([], 'anything')) + self.assertFalse(trustroot.returnToMatches([], 'anything')) def test_exactMatch(self): r = 'http://example.com/return.to' @@ -183,15 +183,11 @@ def test_descendant(self): self.assertTrue(trustroot.returnToMatches([r], 'http://example.com/return.to/user:joe')) def test_wildcard(self): - self.failIf(trustroot.returnToMatches( - ['http://*.example.com/return.to'], - 'http://example.com/return.to')) + self.assertFalse(trustroot.returnToMatches(['http://*.example.com/return.to'], 'http://example.com/return.to')) def test_noMatch(self): r = 'http://example.com/return.to' - self.failIf(trustroot.returnToMatches( - [r], - 'http://example.com/xss_exploit')) + self.assertFalse(trustroot.returnToMatches([r], 'http://example.com/xss_exploit')) class TestVerifyReturnTo(unittest.TestCase, CatchLogs): @@ -203,7 +199,7 @@ def tearDown(self): CatchLogs.tearDown(self) def test_bogusRealm(self): - self.failIf(trustroot.verifyReturnTo('', 'http://example.com/')) + self.assertFalse(trustroot.verifyReturnTo('', 'http://example.com/')) def test_verifyWithDiscoveryCalled(self): realm = 'http://*.example.com/' @@ -224,8 +220,7 @@ def vrfy(disco_url): self.assertEqual(disco_url, 'http://www.example.com/') return ['http://something-else.invalid/'] - self.failIf( - trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) + self.assertFalse(trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) self.failUnlessLogMatches("Failed to validate return_to") def test_verifyFailIfDiscoveryRedirects(self): @@ -236,8 +231,7 @@ def vrfy(disco_url): raise trustroot.RealmVerificationRedirected( disco_url, "http://redirected.invalid") - self.failIf( - trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) + self.assertFalse(trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) self.failUnlessLogMatches("Attempting to verify") diff --git a/openid/test/test_server.py b/openid/test/test_server.py index c5b1c4ca..38aa9c05 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -110,7 +110,7 @@ def test_noReturnTo(self): 'openid.identity': 'http://wagu.unittest/', }) e = server.ProtocolError(args, "waffles") - self.failIf(e.hasReturnTo()) + self.assertFalse(e.hasReturnTo()) expected = """error:waffles mode:error """ @@ -118,7 +118,7 @@ def test_noReturnTo(self): def test_noMessage(self): e = server.ProtocolError(None, "no moar pancakes") - self.failIf(e.hasReturnTo()) + self.assertFalse(e.hasReturnTo()) self.assertIsNone(e.whichEncoding()) @@ -505,7 +505,7 @@ def test_id_res_OpenID2_GET(self): 'return_to': request.return_to, }) - self.failIf(response.renderAsForm()) + self.assertFalse(response.renderAsForm()) self.assertEqual(response.whichEncoding(), server.ENCODE_URL) webresponse = self.encode(response) self.assertIn('location', webresponse.headers) @@ -606,7 +606,7 @@ def test_id_res_OpenID1_exceeds_limit(self): 'return_to': 'x' * OPENID1_URL_LIMIT, }) - self.failIf(response.renderAsForm()) + self.assertFalse(response.renderAsForm()) self.assertGreater(len(response.encodeToURL()), OPENID1_URL_LIMIT) self.assertEqual(response.whichEncoding(), server.ENCODE_URL) webresponse = self.encode(response) @@ -795,7 +795,7 @@ def test_cancel(self): self.assertIn('location', webresponse.headers) location = webresponse.headers['location'] query = cgi.parse_qs(urlparse(location)[4]) - self.failIf('openid.sig' in query, response.fields.toPostArgs()) + self.assertNotIn('openid.sig', query, response.fields.toPostArgs()) def test_assocReply(self): msg = Message(OPENID2_NS) @@ -833,7 +833,7 @@ def setUp(self): def test_trustRootInvalid(self): self.request.trust_root = "http://foo.unittest/17" self.request.return_to = "http://foo.unittest/39" - self.failIf(self.request.trustRootValid()) + self.assertFalse(self.request.trustRootValid()) def test_trustRootValid(self): self.request.trust_root = "http://foo.unittest/" @@ -1345,7 +1345,7 @@ def test_dhSHA1(self): rfg = partial(response.fields.getArg, OPENID_NS) self.assertEqual(rfg("assoc_type"), "HMAC-SHA1") self.assertEqual(rfg("assoc_handle"), self.assoc.handle) - self.failIf(rfg("mac_key")) + self.assertFalse(rfg("mac_key")) self.assertEqual(rfg("session_type"), "DH-SHA1") self.assertTrue(rfg("enc_mac_key")) self.assertTrue(rfg("dh_server_public")) @@ -1370,7 +1370,7 @@ def test_dhSHA256(self): rfg = partial(response.fields.getArg, OPENID_NS) self.assertEqual(rfg("assoc_type"), "HMAC-SHA256") self.assertEqual(rfg("assoc_handle"), self.assoc.handle) - self.failIf(rfg("mac_key")) + self.assertFalse(rfg("mac_key")) self.assertEqual(rfg("session_type"), "DH-SHA256") self.assertTrue(rfg("enc_mac_key")) self.assertTrue(rfg("dh_server_public")) @@ -1478,9 +1478,9 @@ def test_plaintext(self): response.fields, self.signatory.SECRET_LIFETIME) self.assertEqual(rfg("mac_key"), oidutil.toBase64(self.assoc.secret)) - self.failIf(rfg("session_type")) - self.failIf(rfg("enc_mac_key")) - self.failIf(rfg("dh_server_public")) + self.assertFalse(rfg("session_type")) + self.assertFalse(rfg("enc_mac_key")) + self.assertFalse(rfg("dh_server_public")) def test_plaintext_v2(self): # The main difference between this and the v1 test is that @@ -1494,7 +1494,7 @@ def test_plaintext_v2(self): self.request = server.AssociateRequest.fromMessage( Message.fromPostArgs(args)) - self.failIf(self.request.message.isOpenID1()) + self.assertFalse(self.request.message.isOpenID1()) self.assoc = self.signatory.createAssociation( dumb=False, assoc_type='HMAC-SHA1') @@ -1511,8 +1511,8 @@ def test_plaintext_v2(self): self.assertEqual(rfg("mac_key"), oidutil.toBase64(self.assoc.secret)) self.assertEqual(rfg("session_type"), "no-encryption") - self.failIf(rfg("enc_mac_key")) - self.failIf(rfg("dh_server_public")) + self.assertFalse(rfg("enc_mac_key")) + self.assertFalse(rfg("dh_server_public")) def test_plaintext256(self): self.assoc = self.signatory.createAssociation(dumb=False, assoc_type='HMAC-SHA256') @@ -1527,9 +1527,9 @@ def test_plaintext256(self): response.fields, self.signatory.SECRET_LIFETIME) self.assertEqual(rfg("mac_key"), oidutil.toBase64(self.assoc.secret)) - self.failIf(rfg("session_type")) - self.failIf(rfg("enc_mac_key")) - self.failIf(rfg("dh_server_public")) + self.assertFalse(rfg("session_type")) + self.assertFalse(rfg("enc_mac_key")) + self.assertFalse(rfg("dh_server_public")) def test_unsupportedPrefer(self): allowed_assoc = 'COLD-PET-RAT' @@ -1622,9 +1622,9 @@ def test_associate2(self): response = self.server.openid_associate(request) self.assertTrue(response.fields.hasKey(OPENID_NS, "error")) self.assertTrue(response.fields.hasKey(OPENID_NS, "error_code")) - self.failIf(response.fields.hasKey(OPENID_NS, "assoc_handle")) - self.failIf(response.fields.hasKey(OPENID_NS, "assoc_type")) - self.failIf(response.fields.hasKey(OPENID_NS, "session_type")) + self.assertFalse(response.fields.hasKey(OPENID_NS, "assoc_handle")) + self.assertFalse(response.fields.hasKey(OPENID_NS, "assoc_type")) + self.assertFalse(response.fields.hasKey(OPENID_NS, "session_type")) def test_associate3(self): """Request an assoc type that is not supported when there are @@ -1645,7 +1645,7 @@ def test_associate3(self): self.assertTrue(response.fields.hasKey(OPENID_NS, "error")) self.assertTrue(response.fields.hasKey(OPENID_NS, "error_code")) - self.failIf(response.fields.hasKey(OPENID_NS, "assoc_handle")) + self.assertFalse(response.fields.hasKey(OPENID_NS, "assoc_handle")) self.assertEqual(response.fields.getArg(OPENID_NS, "assoc_type"), 'HMAC-SHA256') self.assertEqual(response.fields.getArg(OPENID_NS, "session_type"), 'DH-SHA256') @@ -1716,7 +1716,7 @@ def test_sign(self): self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'assoc_handle'), assoc_handle) self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), 'assoc_handle,azu,bar,foo,signed') self.assertTrue(sresponse.fields.getArg(OPENID_NS, 'sig')) - self.failIf(self.messages, self.messages) + self.assertFalse(self.messages) def test_signDumb(self): request = server.OpenIDRequest() @@ -1736,7 +1736,7 @@ def test_signDumb(self): self.assertTrue(assoc) self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), 'assoc_handle,azu,bar,foo,ns,signed') self.assertTrue(sresponse.fields.getArg(OPENID_NS, 'sig')) - self.failIf(self.messages, self.messages) + self.assertFalse(self.messages) def test_signExpired(self): """Sign a response to a message with an expired handle (using invalidate_handle). @@ -1781,12 +1781,12 @@ def test_signExpired(self): self.assertTrue(sresponse.fields.getArg(OPENID_NS, 'sig')) # make sure the expired association is gone - self.failIf(self.store.getAssociation(self._normal_key, assoc_handle), - "expired association is still retrievable.") + self.assertFalse(self.store.getAssociation(self._normal_key, assoc_handle), + "expired association is still retrievable.") # make sure the new key is a dumb mode association self.assertTrue(self.store.getAssociation(self._dumb_key, new_assoc_handle)) - self.failIf(self.store.getAssociation(self._normal_key, new_assoc_handle)) + self.assertFalse(self.store.getAssociation(self._normal_key, new_assoc_handle)) self.assertTrue(self.messages) def test_signInvalidHandle(self): @@ -1815,8 +1815,8 @@ def test_signInvalidHandle(self): # make sure the new key is a dumb mode association self.assertTrue(self.store.getAssociation(self._dumb_key, new_assoc_handle)) - self.failIf(self.store.getAssociation(self._normal_key, new_assoc_handle)) - self.failIf(self.messages, self.messages) + self.assertFalse(self.store.getAssociation(self._normal_key, new_assoc_handle)) + self.assertFalse(self.messages) def test_verify(self): assoc_handle = '{vroom}{zoom}' @@ -1834,7 +1834,7 @@ def test_verify(self): }) verified = self.signatory.verify(assoc_handle, signed) - self.failIf(self.messages, self.messages) + self.assertFalse(self.messages) self.assertTrue(verified) def test_verifyBadSig(self): @@ -1853,8 +1853,8 @@ def test_verifyBadSig(self): }) verified = self.signatory.verify(assoc_handle, signed) - self.failIf(self.messages, self.messages) - self.failIf(verified) + self.assertFalse(self.messages) + self.assertFalse(verified) def test_verifyBadHandle(self): assoc_handle = '{vroom}{zoom}' @@ -1865,7 +1865,7 @@ def test_verifyBadHandle(self): }) verified = self.signatory.verify(assoc_handle, signed) - self.failIf(verified) + self.assertFalse(verified) self.assertTrue(self.messages) def test_verifyAssocMismatch(self): @@ -1883,7 +1883,7 @@ def test_verifyAssocMismatch(self): }) verified = self.signatory.verify(assoc_handle, signed) - self.failIf(verified) + self.assertFalse(verified) self.assertTrue(self.messages) def test_getAssoc(self): @@ -1891,24 +1891,24 @@ def test_getAssoc(self): assoc = self.signatory.getAssociation(assoc_handle, True) self.assertTrue(assoc) self.assertEqual(assoc.handle, assoc_handle) - self.failIf(self.messages, self.messages) + self.assertFalse(self.messages) def test_getAssocExpired(self): assoc_handle = self.makeAssoc(dumb=True, lifetime=-10) assoc = self.signatory.getAssociation(assoc_handle, True) - self.failIf(assoc, assoc) + self.assertFalse(assoc) self.assertTrue(self.messages) def test_getAssocInvalid(self): ah = 'no-such-handle' self.assertIsNone(self.signatory.getAssociation(ah, dumb=False)) - self.failIf(self.messages, self.messages) + self.assertFalse(self.messages) def test_getAssocDumbVsNormal(self): """getAssociation(dumb=False) cannot get a dumb assoc""" assoc_handle = self.makeAssoc(dumb=True) self.assertIsNone(self.signatory.getAssociation(assoc_handle, dumb=False)) - self.failIf(self.messages, self.messages) + self.assertFalse(self.messages) def test_getAssocNormalVsDumb(self): """getAssociation(dumb=True) cannot get a shared assoc @@ -1920,12 +1920,12 @@ def test_getAssocNormalVsDumb(self): """ assoc_handle = self.makeAssoc(dumb=False) self.assertIsNone(self.signatory.getAssociation(assoc_handle, dumb=True)) - self.failIf(self.messages, self.messages) + self.assertFalse(self.messages) def test_createAssociation(self): assoc = self.signatory.createAssociation(dumb=False) self.assertTrue(self.signatory.getAssociation(assoc.handle, dumb=False)) - self.failIf(self.messages, self.messages) + self.assertFalse(self.messages) def makeAssoc(self, dumb, lifetime=60): assoc_handle = '{bling}' @@ -1947,8 +1947,8 @@ def test_invalidate(self): self.assertTrue(assoc) self.signatory.invalidate(assoc_handle, dumb=True) assoc = self.signatory.getAssociation(assoc_handle, dumb=True) - self.failIf(assoc) - self.failIf(self.messages, self.messages) + self.assertFalse(assoc) + self.assertFalse(self.messages) if __name__ == '__main__': diff --git a/openid/test/test_sreg.py b/openid/test/test_sreg.py index 572c7d19..f358976a 100644 --- a/openid/test/test_sreg.py +++ b/openid/test/test_sreg.py @@ -37,7 +37,7 @@ def usesExtension(self, namespace_uri): class SupportsSRegTest(unittest.TestCase): def test_unsupported(self): endpoint = FakeEndpoint([]) - self.failIf(sreg.supportsSReg(endpoint)) + self.assertFalse(sreg.supportsSReg(endpoint)) self.assertEqual(endpoint.checked_uris, [sreg.ns_uri_1_1, sreg.ns_uri_1_0]) def test_supported_1_1(self): @@ -258,23 +258,23 @@ def test_allRequestedFields(self): def test_wereFieldsRequested(self): req = sreg.SRegRequest() - self.failIf(req.wereFieldsRequested()) + self.assertFalse(req.wereFieldsRequested()) req.requestField('gender') self.assertTrue(req.wereFieldsRequested()) def test_contains(self): req = sreg.SRegRequest() for field_name in sreg.data_fields: - self.failIf(field_name in req) + self.assertNotIn(field_name, req) - self.failIf('something else' in req) + self.assertNotIn('something else', req) req.requestField('nickname') for field_name in sreg.data_fields: if field_name == 'nickname': self.assertIn(field_name, req) else: - self.failIf(field_name in req) + self.assertNotIn(field_name, req) def test_requestField_bogus(self): req = sreg.SRegRequest() @@ -409,7 +409,7 @@ def test_construct(self): self.assertTrue(resp) empty_resp = sreg.SRegResponse({}) - self.failIf(empty_resp) + self.assertFalse(empty_resp) # XXX: finish this test @@ -419,7 +419,7 @@ def test_fromSuccessResponse_signed(self): }) success_resp = DummySuccessResponse(message, {}) sreg_resp = sreg.SRegResponse.fromSuccessResponse(success_resp) - self.failIf(sreg_resp) + self.assertFalse(sreg_resp) def test_fromSuccessResponse_unsigned(self): message = Message.fromOpenIDArgs({ diff --git a/openid/test/test_symbol.py b/openid/test/test_symbol.py index 9e52844a..c5db74c0 100644 --- a/openid/test/test_symbol.py +++ b/openid/test/test_symbol.py @@ -20,12 +20,12 @@ def test_inequality(self): def test_selfInequality(self): x = oidutil.Symbol('xxx') - self.failIf(x != x) + self.assertFalse(x != x) def test_otherInequality(self): x = oidutil.Symbol('xxx') y = oidutil.Symbol('xxx') - self.failIf(x != y) + self.assertFalse(x != y) def test_ne_inequality(self): x = oidutil.Symbol('xxx') From bf03ab4f440d76b59330ad423a8213a3b25fbceb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 26 Jan 2018 11:13:14 +0100 Subject: [PATCH 037/151] Replace failIfEqual --- openid/test/test_oidutil.py | 2 +- openid/test/test_server.py | 4 ++-- openid/test/test_symbol.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/openid/test/test_oidutil.py b/openid/test/test_oidutil.py index 3a66850e..b64898a1 100644 --- a/openid/test/test_oidutil.py +++ b/openid/test/test_oidutil.py @@ -153,7 +153,7 @@ def testCopyHash(self): self.assertIn(s, d_prime, "%r isn't in %r" % (s, d_prime)) t = oidutil.Symbol("Bar") - self.failIfEqual(hash(s), hash(t)) + self.assertNotEqual(hash(s), hash(t)) # XXX: there are more functions that could benefit from being better diff --git a/openid/test/test_server.py b/openid/test/test_server.py index 38aa9c05..8fd8ac88 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -1772,7 +1772,7 @@ def test_signExpired(self): new_assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') self.assertTrue(new_assoc_handle) - self.failIfEqual(new_assoc_handle, assoc_handle) + self.assertNotEqual(new_assoc_handle, assoc_handle) self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'invalidate_handle'), assoc_handle) @@ -1805,7 +1805,7 @@ def test_signInvalidHandle(self): new_assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') self.assertTrue(new_assoc_handle) - self.failIfEqual(new_assoc_handle, assoc_handle) + self.assertNotEqual(new_assoc_handle, assoc_handle) self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'invalidate_handle'), assoc_handle) diff --git a/openid/test/test_symbol.py b/openid/test/test_symbol.py index c5db74c0..a115937f 100644 --- a/openid/test/test_symbol.py +++ b/openid/test/test_symbol.py @@ -16,7 +16,7 @@ def test_otherEquality(self): def test_inequality(self): x = oidutil.Symbol('xxx') y = oidutil.Symbol('yyy') - self.failIfEqual(x, y) + self.assertNotEqual(x, y) def test_selfInequality(self): x = oidutil.Symbol('xxx') From 7e2c019bb4d4e3593398314f84c267787a8a4d77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 26 Jan 2018 13:02:02 +0100 Subject: [PATCH 038/151] Use testfixtures to replace CatchLogs --- openid/test/support.py | 49 ---------- openid/test/test_association_response.py | 22 +++-- openid/test/test_consumer.py | 103 +++++++++------------ openid/test/test_kvform.py | 52 ++++------- openid/test/test_negotiation.py | 111 ++++++++++++----------- openid/test/test_rpverify.py | 26 +++--- openid/test/test_server.py | 84 ++++++++++------- openid/test/test_verifydisco.py | 91 +++++++++++-------- setup.py | 2 +- 9 files changed, 247 insertions(+), 293 deletions(-) diff --git a/openid/test/support.py b/openid/test/support.py index c2d45ea3..16f54c77 100644 --- a/openid/test/support.py +++ b/openid/test/support.py @@ -1,21 +1,6 @@ -import logging -from logging.handlers import BufferingHandler - from openid import message -class TestHandler(BufferingHandler): - def __init__(self, messages): - BufferingHandler.__init__(self, 0) - self.messages = messages - - def shouldFlush(self): - return False - - def emit(self, record): - self.messages.append(record) - - class OpenIDTestMixin(object): def failUnlessOpenIDValueEquals(self, msg, key, expected, ns=None): if ns is None: @@ -33,37 +18,3 @@ def failIfOpenIDKeyExists(self, msg, key, ns=None): actual = msg.getArg(ns, key) error_message = 'openid.%s unexpectedly present: %s' % (key, actual) self.assertIsNone(actual, error_message) - - -class CatchLogs(object): - def setUp(self): - self.messages = [] - root_logger = logging.getLogger() - self.old_log_level = root_logger.getEffectiveLevel() - root_logger.setLevel(logging.DEBUG) - - self.handler = TestHandler(self.messages) - formatter = logging.Formatter("%(message)s [%(asctime)s - %(name)s - %(levelname)s]") - self.handler.setFormatter(formatter) - root_logger.addHandler(self.handler) - - def tearDown(self): - root_logger = logging.getLogger() - root_logger.removeHandler(self.handler) - root_logger.setLevel(self.old_log_level) - - def failUnlessLogMatches(self, *prefixes): - """ - Check that the log messages contained in self.messages have - prefixes in *prefixes. Raise AssertionError if not, or if the - number of prefixes is different than the number of log - messages. - """ - messages = [r.getMessage() for r in self.messages] - assert len(prefixes) == len(messages), "Expected log prefixes %r, got %r" % (prefixes, messages) - - for prefix, msg in zip(prefixes, messages): - assert msg.startswith(prefix), "Expected log prefixes %r, got %r" % (prefixes, messages) - - def failUnlessLogEmpty(self): - self.failUnlessLogMatches() diff --git a/openid/test/test_association_response.py b/openid/test/test_association_response.py index 9bac3e21..62b31750 100644 --- a/openid/test/test_association_response.py +++ b/openid/test/test_association_response.py @@ -5,12 +5,13 @@ """ import unittest +from testfixtures import LogCapture + from openid.consumer.consumer import GenericConsumer, ProtocolError from openid.consumer.discover import OPENID_1_1_TYPE, OPENID_2_0_TYPE, OpenIDServiceEndpoint from openid.message import OPENID2_NS, OPENID_NS, Message from openid.server.server import DiffieHellmanSHA1ServerSession from openid.store import memstore -from openid.test.test_consumer import CatchLogs # Some values we can use for convenience (see mkAssocResponse) association_response_values = { @@ -33,9 +34,8 @@ def mkAssocResponse(*keys): return Message.fromOpenIDArgs(args) -class BaseAssocTest(CatchLogs, unittest.TestCase): +class BaseAssocTest(unittest.TestCase): def setUp(self): - CatchLogs.setUp(self) self.store = memstore.MemoryStore() self.consumer = GenericConsumer(self.store) self.endpoint = OpenIDServiceEndpoint() @@ -175,8 +175,9 @@ def mkTest(expected_session_type, session_type_value): """ def test(self): - self._doTest(expected_session_type, session_type_value) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self._doTest(expected_session_type, session_type_value) + self.assertEqual(logbook.records, []) return test @@ -209,11 +210,12 @@ def _doTest(self, expected_session_type, session_type_value): # This one's different because it expects log messages def test_explicitNoEncryption(self): - self._doTest( - session_type_value='no-encryption', - expected_session_type='no-encryption', - ) - self.failUnlessLogMatches('OpenID server sent "no-encryption"') + with LogCapture() as logbook: + self._doTest( + session_type_value='no-encryption', + expected_session_type='no-encryption', + ) + logbook.check(('openid.consumer.consumer', 'WARNING', 'OpenID server sent "no-encryption" for OpenID 1.X')) test_dhSHA1 = mkTest( session_type_value='DH-SHA1', diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index 169a4233..a427dbff 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -3,6 +3,8 @@ import unittest import urlparse +from testfixtures import LogCapture, StringComparison + from openid import association, cryptutil, fetchers, kvform, oidutil from openid.consumer.consumer import (CANCEL, FAILURE, SETUP_NEEDED, SUCCESS, AuthRequest, CancelResponse, Consumer, DiffieHellmanSHA1ConsumerSession, DiffieHellmanSHA256ConsumerSession, @@ -20,8 +22,6 @@ from openid.yadis.discover import DiscoveryFailure from openid.yadis.manager import Discovery -from .support import CatchLogs - assocs = [ ('another 20-byte key.', 'Snarky'), ('\x00' * 20, 'Zeros'), @@ -212,13 +212,12 @@ def run(): https_server_url = 'https://server.example.com/' -class TestSuccess(unittest.TestCase, CatchLogs): +class TestSuccess(unittest.TestCase): server_url = http_server_url user_url = 'http://www.example.com/user.html' delegate_url = 'http://consumer.example.com/user' def setUp(self): - CatchLogs.setUp(self) self.links = '' % ( self.server_url,) @@ -226,9 +225,6 @@ def setUp(self): '') % ( self.server_url, self.delegate_url) - def tearDown(self): - CatchLogs.tearDown(self) - def test_nodelegate(self): _test_success(self.server_url, self.user_url, self.user_url, self.links) @@ -262,12 +258,10 @@ def test_nostore(self): self.assertRaises(TypeError, GenericConsumer) -class TestIdRes(unittest.TestCase, CatchLogs): +class TestIdRes(unittest.TestCase): consumer_class = GenericConsumer def setUp(self): - CatchLogs.setUp(self) - self.store = memstore.MemoryStore() self.consumer = self.consumer_class(self.store) self.return_to = "nonny" @@ -464,19 +458,18 @@ def discoverAndVerify(claimed_id, _to_match_endpoints): }) self.consumer.store = GoodAssocStore() - self.assertRaises(VerifiedError, self.consumer.complete, message, self.endpoint) - - self.failUnlessLogMatches('Error attempting to use stored', - 'Attempting discovery') + with LogCapture() as logbook: + self.assertRaises(VerifiedError, self.consumer.complete, message, self.endpoint) + logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Error attempting to use .*')), + ('openid.consumer.consumer', 'INFO', 'Attempting discovery to verify endpoint')) -class TestCompleteMissingSig(unittest.TestCase, CatchLogs): +class TestCompleteMissingSig(unittest.TestCase): def setUp(self): self.store = GoodAssocStore() self.consumer = GenericConsumer(self.store) self.server_url = "http://idp.unittest/" - CatchLogs.setUp(self) claimed_id = 'bogus.claimed' @@ -498,9 +491,6 @@ def setUp(self): self.endpoint.claimed_id = claimed_id self.consumer._checkReturnTo = lambda unused1, unused2: True - def tearDown(self): - CatchLogs.tearDown(self) - def test_idResMissingNoSigs(self): def _vrfy(resp_msg, endpoint=None): return endpoint @@ -542,14 +532,10 @@ def failUnlessSuccess(self, response): self.fail("Non-successful response: %s" % (response,)) -class TestCheckAuthResponse(TestIdRes, CatchLogs): +class TestCheckAuthResponse(TestIdRes): def setUp(self): - CatchLogs.setUp(self) TestIdRes.setUp(self) - def tearDown(self): - CatchLogs.tearDown(self) - def _createAssoc(self): issued = time.time() lifetime = 1000 @@ -602,11 +588,10 @@ def test_invalidateMissing(self): 'is_valid': 'true', 'invalidate_handle': 'missing', }) - r = self.consumer._processCheckAuthResponse(response, self.server_url) + with LogCapture() as logbook: + r = self.consumer._processCheckAuthResponse(response, self.server_url) self.assertTrue(r) - self.failUnlessLogMatches( - 'Received "invalidate_handle"' - ) + logbook.check(('openid.consumer.consumer', 'INFO', StringComparison('Received "invalidate_handle" from .*'))) def test_invalidateMissing_noStore(self): """invalidate_handle with a handle that is not present""" @@ -615,11 +600,11 @@ def test_invalidateMissing_noStore(self): 'invalidate_handle': 'missing', }) self.consumer.store = None - r = self.consumer._processCheckAuthResponse(response, self.server_url) + with LogCapture() as logbook: + r = self.consumer._processCheckAuthResponse(response, self.server_url) self.assertTrue(r) - self.failUnlessLogMatches( - 'Received "invalidate_handle"', - 'Unexpectedly got invalidate_handle without a store') + logbook.check(('openid.consumer.consumer', 'INFO', StringComparison('Received "invalidate_handle" from .*')), + ('openid.consumer.consumer', 'ERROR', 'Unexpectedly got invalidate_handle without a store!')) def test_invalidatePresent(self): """invalidate_handle with a handle that exists @@ -813,51 +798,52 @@ class CheckAuthHappened(Exception): pass -class CheckNonceVerifyTest(TestIdRes, CatchLogs): +class CheckNonceVerifyTest(TestIdRes): def setUp(self): - CatchLogs.setUp(self) TestIdRes.setUp(self) self.consumer.openid1_nonce_query_arg_name = 'nonce' - def tearDown(self): - CatchLogs.tearDown(self) - def test_openid1Success(self): """use consumer-generated nonce""" nonce_value = mkNonce() self.return_to = 'http://rt.unittest/?nonce=%s' % (nonce_value,) self.response = Message.fromOpenIDArgs({'return_to': self.return_to}) self.response.setArg(BARE_NS, 'nonce', nonce_value) - self.consumer._idResCheckNonce(self.response, self.endpoint) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.consumer._idResCheckNonce(self.response, self.endpoint) + self.assertEqual(logbook.records, []) def test_openid1Missing(self): """use consumer-generated nonce""" self.response = Message.fromOpenIDArgs({}) - n = self.consumer._idResGetNonceOpenID1(self.response, self.endpoint) + with LogCapture() as logbook: + n = self.consumer._idResGetNonceOpenID1(self.response, self.endpoint) self.assertIsNone(n) - self.failUnlessLogEmpty() + self.assertEqual(logbook.records, []) def test_consumerNonceOpenID2(self): """OpenID 2 does not use consumer-generated nonce""" self.return_to = 'http://rt.unittest/?nonce=%s' % (mkNonce(),) self.response = Message.fromOpenIDArgs( {'return_to': self.return_to, 'ns': OPENID2_NS}) - self.assertRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.assertRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) + self.assertEqual(logbook.records, []) def test_serverNonce(self): """use server-generated nonce""" self.response = Message.fromOpenIDArgs({'ns': OPENID2_NS, 'response_nonce': mkNonce()}) - self.consumer._idResCheckNonce(self.response, self.endpoint) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.consumer._idResCheckNonce(self.response, self.endpoint) + self.assertEqual(logbook.records, []) def test_serverNonceOpenID1(self): """OpenID 1 does not use server-generated nonce""" self.response = Message.fromOpenIDArgs( {'ns': OPENID1_NS, 'return_to': 'http://return.to/', 'response_nonce': mkNonce()}) - self.assertRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.assertRaises(ProtocolError, self.consumer._idResCheckNonce, self.response, self.endpoint) + self.assertEqual(logbook.records, []) def test_badNonce(self): """remove the nonce from the store @@ -880,8 +866,9 @@ def test_successWithNoStore(self): """When there is no store, checking the nonce succeeds""" self.consumer.store = None self.response = Message.fromOpenIDArgs({'response_nonce': mkNonce(), 'ns': OPENID2_NS}) - self.consumer._idResCheckNonce(self.response, self.endpoint) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.consumer._idResCheckNonce(self.response, self.endpoint) + self.assertEqual(logbook.records, []) def test_tamperedNonce(self): """Malformed nonce""" @@ -905,12 +892,11 @@ def _idResCheckNonce(self, *args): return True -class TestCheckAuthTriggered(TestIdRes, CatchLogs): +class TestCheckAuthTriggered(TestIdRes): consumer_class = CheckAuthDetectingConsumer def setUp(self): TestIdRes.setUp(self) - CatchLogs.setUp(self) self.disableDiscoveryVerification() def test_checkAuthTriggered(self): @@ -1156,11 +1142,10 @@ def _makeKVPost(self, args, _): return None -class TestCheckAuth(unittest.TestCase, CatchLogs): +class TestCheckAuth(unittest.TestCase): consumer_class = GenericConsumer def setUp(self): - CatchLogs.setUp(self) self.store = memstore.MemoryStore() self.consumer = self.consumer_class(self.store) @@ -1170,7 +1155,6 @@ def setUp(self): fetchers.setDefaultFetcher(self.fetcher) def tearDown(self): - CatchLogs.tearDown(self) fetchers.setDefaultFetcher(self._orig_fetcher, wrap_exceptions=False) def test_error(self): @@ -1178,10 +1162,12 @@ def test_error(self): "http://some_url", 404, {'Hea': 'der'}, 'blah:blah\n') query = {'openid.signed': 'stuff', 'openid.stuff': 'a value'} - r = self.consumer._checkAuth(Message.fromPostArgs(query), - http_server_url) + with LogCapture() as logbook: + r = self.consumer._checkAuth(Message.fromPostArgs(query), http_server_url) self.assertFalse(r) - self.assertTrue(self.messages) + logbook.check(('openid.consumer.consumer', 'INFO', 'Using OpenID check_authentication'), + ('openid.consumer.consumer', 'INFO', 'stuff'), + ('openid.consumer.consumer', 'ERROR', StringComparison('check_authentication failed: .*: 404'))) def test_bad_args(self): query = { @@ -1236,11 +1222,10 @@ def test_112(self): self.assertEqual(car.toPostArgs(), expected_args) -class TestFetchAssoc(unittest.TestCase, CatchLogs): +class TestFetchAssoc(unittest.TestCase): consumer_class = GenericConsumer def setUp(self): - CatchLogs.setUp(self) self.store = memstore.MemoryStore() self.fetcher = MockFetcher() fetchers.setDefaultFetcher(self.fetcher) diff --git a/openid/test/test_kvform.py b/openid/test/test_kvform.py index 09532db9..187629a1 100644 --- a/openid/test/test_kvform.py +++ b/openid/test/test_kvform.py @@ -1,39 +1,23 @@ import unittest -from openid import kvform -from openid.test.support import CatchLogs - - -class KVBaseTest(unittest.TestCase, CatchLogs): - - def checkWarnings(self, num_warnings, msg=None): - full_msg = 'Invalid number of warnings {} != {}'.format(num_warnings, len(self.messages)) - if msg is not None: - full_msg = full_msg + ' ' + msg - self.assertEqual(num_warnings, len(self.messages), full_msg) +from testfixtures import LogCapture - def setUp(self): - CatchLogs.setUp(self) - - def tearDown(self): - CatchLogs.tearDown(self) +from openid import kvform -class KVDictTest(KVBaseTest): +class KVDictTest(unittest.TestCase): def runTest(self): for kv_data, result, expected_warnings in kvdict_cases: - # Clean captrured messages - del self.messages[:] - # Convert KVForm to dict - d = kvform.kvToDict(kv_data) + with LogCapture() as logbook: + d = kvform.kvToDict(kv_data) # make sure it parses to expected dict self.assertEqual(d, result) # Check to make sure we got the expected number of warnings - self.checkWarnings(expected_warnings, msg='kvToDict({!r})'.format(kv_data)) + self.assertEqual(len(logbook.records), expected_warnings) # Convert back to KVForm and round-trip back to dict to make # sure that *** dict -> kv -> dict is identity. *** @@ -42,7 +26,7 @@ def runTest(self): self.assertEqual(d, d2) -class KVSeqTest(KVBaseTest): +class KVSeqTest(unittest.TestCase): def cleanSeq(self, seq): """Create a new sequence by stripping whitespace from start @@ -58,11 +42,9 @@ def cleanSeq(self, seq): def runTest(self): for kv_data, result, expected_warnings in kvseq_cases: - # Clean captrured messages - del self.messages[:] - # seq serializes to expected kvform - actual = kvform.seqToKV(kv_data) + with LogCapture() as logbook: + actual = kvform.seqToKV(kv_data) self.assertEqual(actual, result) self.assertIsInstance(actual, str) @@ -73,7 +55,8 @@ def runTest(self): clean_seq = self.cleanSeq(seq) self.assertEqual(seq, clean_seq) - self.checkWarnings(expected_warnings) + self.assertEqual(len(logbook.records), expected_warnings, + "Invalid warnings for {}: {}".format(kv_data, [r.getMessage() for r in logbook.records])) kvdict_cases = [ @@ -119,16 +102,16 @@ def runTest(self): ([('openid', 'useful'), ('a', 'b')], 'openid:useful\na:b\n', 0), # Warnings about leading whitespace - ([(' openid', 'useful'), ('a', 'b')], ' openid:useful\na:b\n', 2), + ([(' openid', 'useful'), ('a', 'b')], ' openid:useful\na:b\n', 1), # Warnings about leading and trailing whitespace ([(' openid ', ' useful '), - (' a ', ' b ')], ' openid : useful \n a : b \n', 8), + (' a ', ' b ')], ' openid : useful \n a : b \n', 4), # warnings about leading and trailing whitespace, but not about # internal whitespace. ([(' open id ', ' use ful '), - (' a ', ' b ')], ' open id : use ful \n a : b \n', 8), + (' a ', ' b ')], ' open id : use ful \n a : b \n', 4), ([(u'foo', 'bar')], 'foo:bar\n', 0), ] @@ -150,10 +133,11 @@ def runTest(self): self.assertRaises(ValueError, kvform.seqToKV, kv_data) -class GeneralTest(KVBaseTest): +class GeneralTest(unittest.TestCase): kvform = '' def test_convert(self): - result = kvform.seqToKV([(1, 1)]) + with LogCapture() as logbook: + result = kvform.seqToKV([(1, 1)]) self.assertEqual(result, '1:1\n') - self.checkWarnings(2) + self.assertEqual(len(logbook.records), 2) diff --git a/openid/test/test_negotiation.py b/openid/test/test_negotiation.py index 71ff200b..6c4cf1fe 100644 --- a/openid/test/test_negotiation.py +++ b/openid/test/test_negotiation.py @@ -1,12 +1,12 @@ import unittest +from testfixtures import LogCapture, StringComparison + from openid import association from openid.consumer.consumer import GenericConsumer, ServerError from openid.consumer.discover import OPENID_2_0_TYPE, OpenIDServiceEndpoint from openid.message import OPENID1_NS, OPENID_NS, Message -from .support import CatchLogs - class ErrorRaisingConsumer(GenericConsumer): """ @@ -29,14 +29,13 @@ def _requestAssociation(self, endpoint, assoc_type, session_type): return m -class TestOpenID2SessionNegotiation(unittest.TestCase, CatchLogs): +class TestOpenID2SessionNegotiation(unittest.TestCase): """ Test the session type negotiation behavior of an OpenID 2 consumer. """ def setUp(self): - CatchLogs.setUp(self) self.consumer = ErrorRaisingConsumer(store=None) self.endpoint = OpenIDServiceEndpoint() @@ -49,8 +48,10 @@ def testBadResponse(self): server error or is otherwise undecipherable. """ self.consumer.return_messages = [Message(self.endpoint.preferredNamespace())] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) - self.failUnlessLogMatches('Server error when requesting an association') + with LogCapture() as logbook: + self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) + logbook.check( + ('openid.consumer.consumer', 'ERROR', StringComparison('Server error when requesting an association .*'))) def testEmptyAssocType(self): """ @@ -64,11 +65,11 @@ def testEmptyAssocType(self): msg.setArg(OPENID_NS, 'session_type', 'new-session-type') self.consumer.return_messages = [msg] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) - - self.failUnlessLogMatches('Unsupported association type', - 'Server responded with unsupported association ' + - 'session but did not supply a fallback.') + with LogCapture() as logbook: + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) + no_fallback_msg = 'Server responded with unsupported association session but did not supply a fallback.' + logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Unsupported association type .*')), + ('openid.consumer.consumer', 'ERROR', no_fallback_msg)) def testEmptySessionType(self): """ @@ -82,11 +83,11 @@ def testEmptySessionType(self): # not set: msg.setArg(OPENID_NS, 'session_type', None) self.consumer.return_messages = [msg] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) - - self.failUnlessLogMatches('Unsupported association type', - 'Server responded with unsupported association ' + - 'session but did not supply a fallback.') + with LogCapture() as logbook: + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) + no_fallback_msg = 'Server responded with unsupported association session but did not supply a fallback.' + logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Unsupported association type .*')), + ('openid.consumer.consumer', 'ERROR', no_fallback_msg)) def testNotAllowed(self): """ @@ -106,10 +107,11 @@ def testNotAllowed(self): msg.setArg(OPENID_NS, 'session_type', 'not-allowed') self.consumer.return_messages = [msg] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) - - self.failUnlessLogMatches('Unsupported association type', - 'Server sent unsupported session/association type:') + with LogCapture() as logbook: + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) + unsupported_msg = StringComparison('Server sent unsupported session/association type: .*') + logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Unsupported association type .*')), + ('openid.consumer.consumer', 'ERROR', unsupported_msg)) def testUnsupportedWithRetry(self): """ @@ -126,9 +128,9 @@ def testUnsupportedWithRetry(self): 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') self.consumer.return_messages = [msg, assoc] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), assoc) - - self.failUnlessLogMatches('Unsupported association type') + with LogCapture() as logbook: + self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), assoc) + logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Unsupported association type .*'))) def testUnsupportedWithRetryAndFail(self): """ @@ -144,10 +146,11 @@ def testUnsupportedWithRetryAndFail(self): self.consumer.return_messages = [msg, Message(self.endpoint.preferredNamespace())] - self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) - - self.failUnlessLogMatches('Unsupported association type', - 'Server %s refused' % (self.endpoint.server_url)) + with LogCapture() as logbook: + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) + refused_msg = StringComparison('Server %s refused its .*' % self.endpoint.server_url) + logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Unsupported association type .*')), + ('openid.consumer.consumer', 'ERROR', refused_msg)) def testValid(self): """ @@ -158,23 +161,22 @@ def testValid(self): 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') self.consumer.return_messages = [assoc] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), assoc) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), assoc) + self.assertEqual(logbook.records, []) -class TestOpenID1SessionNegotiation(unittest.TestCase, CatchLogs): +class TestOpenID1SessionNegotiation(unittest.TestCase): """ Tests for the OpenID 1 consumer association session behavior. See the docs for TestOpenID2SessionNegotiation. Notice that this class is not a subclass of the OpenID 2 tests. Instead, it uses - many of the same inputs but inspects the log messages. - See the calls to self.failUnlessLogMatches. Some of - these tests pass openid2-style messages to the openid 1 + many of the same inputs but inspects the log messages, see the LogCapture. + Some of these tests pass openid2-style messages to the openid 1 association processing logic to be sure it ignores the extra data. """ def setUp(self): - CatchLogs.setUp(self) self.consumer = ErrorRaisingConsumer(store=None) self.endpoint = OpenIDServiceEndpoint() @@ -183,8 +185,10 @@ def setUp(self): def testBadResponse(self): self.consumer.return_messages = [Message(self.endpoint.preferredNamespace())] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) - self.failUnlessLogMatches('Server error when requesting an association') + with LogCapture() as logbook: + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) + logbook.check( + ('openid.consumer.consumer', 'ERROR', StringComparison('Server error when requesting an association .*'))) def testEmptyAssocType(self): msg = Message(self.endpoint.preferredNamespace()) @@ -194,9 +198,10 @@ def testEmptyAssocType(self): msg.setArg(OPENID_NS, 'session_type', 'new-session-type') self.consumer.return_messages = [msg] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) - - self.failUnlessLogMatches('Server error when requesting an association') + with LogCapture() as logbook: + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) + logbook.check( + ('openid.consumer.consumer', 'ERROR', StringComparison('Server error when requesting an association .*'))) def testEmptySessionType(self): msg = Message(self.endpoint.preferredNamespace()) @@ -206,9 +211,10 @@ def testEmptySessionType(self): # not set: msg.setArg(OPENID_NS, 'session_type', None) self.consumer.return_messages = [msg] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) - - self.failUnlessLogMatches('Server error when requesting an association') + with LogCapture() as logbook: + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) + logbook.check( + ('openid.consumer.consumer', 'ERROR', StringComparison('Server error when requesting an association .*'))) def testNotAllowed(self): allowed_types = [] @@ -223,9 +229,10 @@ def testNotAllowed(self): msg.setArg(OPENID_NS, 'session_type', 'not-allowed') self.consumer.return_messages = [msg] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), None) - - self.failUnlessLogMatches('Server error when requesting an association') + with LogCapture() as logbook: + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) + logbook.check( + ('openid.consumer.consumer', 'ERROR', StringComparison('Server error when requesting an association .*'))) def testUnsupportedWithRetry(self): msg = Message(self.endpoint.preferredNamespace()) @@ -238,20 +245,22 @@ def testUnsupportedWithRetry(self): 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') self.consumer.return_messages = [msg, assoc] - self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) - - self.failUnlessLogMatches('Server error when requesting an association') + with LogCapture() as logbook: + self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) + logbook.check( + ('openid.consumer.consumer', 'ERROR', StringComparison('Server error when requesting an association .*'))) def testValid(self): assoc = association.Association( 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') self.consumer.return_messages = [assoc] - self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), assoc) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), assoc) + self.assertEqual(logbook.records, []) -class TestNegotiatorBehaviors(unittest.TestCase, CatchLogs): +class TestNegotiatorBehaviors(unittest.TestCase): def setUp(self): self.allowed_types = [ ('HMAC-SHA1', 'no-encryption'), diff --git a/openid/test/test_rpverify.py b/openid/test/test_rpverify.py index d12cc5b9..04b693ee 100644 --- a/openid/test/test_rpverify.py +++ b/openid/test/test_rpverify.py @@ -5,8 +5,9 @@ import unittest +from testfixtures import LogCapture, StringComparison + from openid.server import trustroot -from openid.test.support import CatchLogs from openid.yadis import services from openid.yadis.discover import DiscoveryFailure, DiscoveryResult @@ -190,13 +191,7 @@ def test_noMatch(self): self.assertFalse(trustroot.returnToMatches([r], 'http://example.com/xss_exploit')) -class TestVerifyReturnTo(unittest.TestCase, CatchLogs): - - def setUp(self): - CatchLogs.setUp(self) - - def tearDown(self): - CatchLogs.tearDown(self) +class TestVerifyReturnTo(unittest.TestCase): def test_bogusRealm(self): self.assertFalse(trustroot.verifyReturnTo('', 'http://example.com/')) @@ -209,8 +204,9 @@ def vrfy(disco_url): self.assertEqual(disco_url, 'http://www.example.com/') return [return_to] - self.assertTrue(trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.assertTrue(trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) + self.assertEqual(logbook.records, []) def test_verifyFailWithDiscoveryCalled(self): realm = 'http://*.example.com/' @@ -220,8 +216,9 @@ def vrfy(disco_url): self.assertEqual(disco_url, 'http://www.example.com/') return ['http://something-else.invalid/'] - self.assertFalse(trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) - self.failUnlessLogMatches("Failed to validate return_to") + with LogCapture() as logbook: + self.assertFalse(trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) + logbook.check(('openid.server.trustroot', 'ERROR', StringComparison('Failed to validate return_to .*'))) def test_verifyFailIfDiscoveryRedirects(self): realm = 'http://*.example.com/' @@ -231,8 +228,9 @@ def vrfy(disco_url): raise trustroot.RealmVerificationRedirected( disco_url, "http://redirected.invalid") - self.assertFalse(trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) - self.failUnlessLogMatches("Attempting to verify") + with LogCapture() as logbook: + self.assertFalse(trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) + logbook.check(('openid.server.trustroot', 'ERROR', StringComparison('Attempting to verify .*'))) if __name__ == '__main__': diff --git a/openid/test/test_server.py b/openid/test/test_server.py index 8fd8ac88..c61878b9 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -5,12 +5,13 @@ from functools import partial from urlparse import urlparse +from testfixtures import LogCapture, StringComparison + from openid import association, cryptutil, oidutil from openid.consumer.consumer import DiffieHellmanSHA256ConsumerSession from openid.message import IDENTIFIER_SELECT, OPENID1_NS, OPENID1_URL_LIMIT, OPENID2_NS, OPENID_NS, Message, no_default from openid.server import server from openid.store import memstore -from openid.test.support import CatchLogs # In general, if you edit or add tests here, try to move in the direction # of testing smaller units. For testing the external interfaces, we'll be @@ -1576,11 +1577,10 @@ def inc(self): self.count += 1 -class TestServer(unittest.TestCase, CatchLogs): +class TestServer(unittest.TestCase): def setUp(self): self.store = memstore.MemoryStore() self.server = server.Server(self.store, "http://server.unittest/endpt") - CatchLogs.setUp(self) def test_dispatch(self): monkeycalled = Counter() @@ -1689,13 +1689,12 @@ def test_checkAuth(self): self.assertTrue(response.fields.hasKey(OPENID_NS, "is_valid")) -class TestSignatory(unittest.TestCase, CatchLogs): +class TestSignatory(unittest.TestCase): def setUp(self): self.store = memstore.MemoryStore() self.signatory = server.Signatory(self.store) self._dumb_key = self.signatory._dumb_key self._normal_key = self.signatory._normal_key - CatchLogs.setUp(self) def test_sign(self): request = server.OpenIDRequest() @@ -1712,11 +1711,12 @@ def test_sign(self): 'bar': 'notsigned', 'azu': 'alsosigned', }) - sresponse = self.signatory.sign(response) + with LogCapture() as logbook: + sresponse = self.signatory.sign(response) self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'assoc_handle'), assoc_handle) self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), 'assoc_handle,azu,bar,foo,signed') self.assertTrue(sresponse.fields.getArg(OPENID_NS, 'sig')) - self.assertFalse(self.messages) + self.assertEqual(logbook.records, []) def test_signDumb(self): request = server.OpenIDRequest() @@ -1729,14 +1729,15 @@ def test_signDumb(self): 'azu': 'alsosigned', 'ns': OPENID2_NS, }) - sresponse = self.signatory.sign(response) + with LogCapture() as logbook: + sresponse = self.signatory.sign(response) assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') self.assertTrue(assoc_handle) assoc = self.signatory.getAssociation(assoc_handle, dumb=True) self.assertTrue(assoc) self.assertEqual(sresponse.fields.getArg(OPENID_NS, 'signed'), 'assoc_handle,azu,bar,foo,ns,signed') self.assertTrue(sresponse.fields.getArg(OPENID_NS, 'sig')) - self.assertFalse(self.messages) + self.assertEqual(logbook.records, []) def test_signExpired(self): """Sign a response to a message with an expired handle (using invalidate_handle). @@ -1768,7 +1769,8 @@ def test_signExpired(self): 'bar': 'notsigned', 'azu': 'alsosigned', }) - sresponse = self.signatory.sign(response) + with LogCapture() as logbook: + sresponse = self.signatory.sign(response) new_assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') self.assertTrue(new_assoc_handle) @@ -1787,7 +1789,7 @@ def test_signExpired(self): # make sure the new key is a dumb mode association self.assertTrue(self.store.getAssociation(self._dumb_key, new_assoc_handle)) self.assertFalse(self.store.getAssociation(self._normal_key, new_assoc_handle)) - self.assertTrue(self.messages) + logbook.check(('openid.server.server', 'INFO', StringComparison('requested .* key .* is expired .*'))) def test_signInvalidHandle(self): request = server.OpenIDRequest() @@ -1801,7 +1803,8 @@ def test_signInvalidHandle(self): 'bar': 'notsigned', 'azu': 'alsosigned', }) - sresponse = self.signatory.sign(response) + with LogCapture() as logbook: + sresponse = self.signatory.sign(response) new_assoc_handle = sresponse.fields.getArg(OPENID_NS, 'assoc_handle') self.assertTrue(new_assoc_handle) @@ -1816,7 +1819,7 @@ def test_signInvalidHandle(self): # make sure the new key is a dumb mode association self.assertTrue(self.store.getAssociation(self._dumb_key, new_assoc_handle)) self.assertFalse(self.store.getAssociation(self._normal_key, new_assoc_handle)) - self.assertFalse(self.messages) + self.assertEqual(logbook.records, []) def test_verify(self): assoc_handle = '{vroom}{zoom}' @@ -1833,8 +1836,9 @@ def test_verify(self): 'openid.sig': 'uXoT1qm62/BB09Xbj98TQ8mlBco=', }) - verified = self.signatory.verify(assoc_handle, signed) - self.assertFalse(self.messages) + with LogCapture() as logbook: + verified = self.signatory.verify(assoc_handle, signed) + self.assertEqual(logbook.records, []) self.assertTrue(verified) def test_verifyBadSig(self): @@ -1852,8 +1856,9 @@ def test_verifyBadSig(self): 'openid.sig': 'uXoT1qm62/BB09Xbj98TQ8mlBco='.encode('rot13'), }) - verified = self.signatory.verify(assoc_handle, signed) - self.assertFalse(self.messages) + with LogCapture() as logbook: + verified = self.signatory.verify(assoc_handle, signed) + self.assertEqual(logbook.records, []) self.assertFalse(verified) def test_verifyBadHandle(self): @@ -1864,9 +1869,10 @@ def test_verifyBadHandle(self): 'openid.sig': "Ylu0KcIR7PvNegB/K41KpnRgJl0=", }) - verified = self.signatory.verify(assoc_handle, signed) + with LogCapture() as logbook: + verified = self.signatory.verify(assoc_handle, signed) self.assertFalse(verified) - self.assertTrue(self.messages) + logbook.check(('openid.server.server', 'ERROR', StringComparison('failed to get assoc with handle .*'))) def test_verifyAssocMismatch(self): """Attempt to validate sign-all message with a signed-list assoc.""" @@ -1882,33 +1888,38 @@ def test_verifyAssocMismatch(self): 'openid.sig': "d71xlHtqnq98DonoSgoK/nD+QRM=", }) - verified = self.signatory.verify(assoc_handle, signed) + with LogCapture() as logbook: + verified = self.signatory.verify(assoc_handle, signed) self.assertFalse(verified) - self.assertTrue(self.messages) + logbook.check(('openid.server.server', 'ERROR', StringComparison('Error in verifying .*'))) def test_getAssoc(self): assoc_handle = self.makeAssoc(dumb=True) - assoc = self.signatory.getAssociation(assoc_handle, True) + with LogCapture() as logbook: + assoc = self.signatory.getAssociation(assoc_handle, True) self.assertTrue(assoc) self.assertEqual(assoc.handle, assoc_handle) - self.assertFalse(self.messages) + self.assertEqual(logbook.records, []) def test_getAssocExpired(self): assoc_handle = self.makeAssoc(dumb=True, lifetime=-10) - assoc = self.signatory.getAssociation(assoc_handle, True) + with LogCapture() as logbook: + assoc = self.signatory.getAssociation(assoc_handle, True) self.assertFalse(assoc) - self.assertTrue(self.messages) + logbook.check(('openid.server.server', 'INFO', StringComparison('requested .* key .* is expired .*'))) def test_getAssocInvalid(self): ah = 'no-such-handle' - self.assertIsNone(self.signatory.getAssociation(ah, dumb=False)) - self.assertFalse(self.messages) + with LogCapture() as logbook: + self.assertIsNone(self.signatory.getAssociation(ah, dumb=False)) + self.assertEqual(logbook.records, []) def test_getAssocDumbVsNormal(self): """getAssociation(dumb=False) cannot get a dumb assoc""" assoc_handle = self.makeAssoc(dumb=True) - self.assertIsNone(self.signatory.getAssociation(assoc_handle, dumb=False)) - self.assertFalse(self.messages) + with LogCapture() as logbook: + self.assertIsNone(self.signatory.getAssociation(assoc_handle, dumb=False)) + self.assertEqual(logbook.records, []) def test_getAssocNormalVsDumb(self): """getAssociation(dumb=True) cannot get a shared assoc @@ -1919,13 +1930,15 @@ def test_getAssocNormalVsDumb(self): MAC keys. """ assoc_handle = self.makeAssoc(dumb=False) - self.assertIsNone(self.signatory.getAssociation(assoc_handle, dumb=True)) - self.assertFalse(self.messages) + with LogCapture() as logbook: + self.assertIsNone(self.signatory.getAssociation(assoc_handle, dumb=True)) + self.assertEqual(logbook.records, []) def test_createAssociation(self): - assoc = self.signatory.createAssociation(dumb=False) + with LogCapture() as logbook: + assoc = self.signatory.createAssociation(dumb=False) self.assertTrue(self.signatory.getAssociation(assoc.handle, dumb=False)) - self.assertFalse(self.messages) + self.assertEqual(logbook.records, []) def makeAssoc(self, dumb, lifetime=60): assoc_handle = '{bling}' @@ -1945,10 +1958,11 @@ def test_invalidate(self): self.assertTrue(assoc) assoc = self.signatory.getAssociation(assoc_handle, dumb=True) self.assertTrue(assoc) - self.signatory.invalidate(assoc_handle, dumb=True) + with LogCapture() as logbook: + self.signatory.invalidate(assoc_handle, dumb=True) assoc = self.signatory.getAssociation(assoc_handle, dumb=True) self.assertFalse(assoc) - self.assertFalse(self.messages) + self.assertEqual(logbook.records, []) if __name__ == '__main__': diff --git a/openid/test/test_verifydisco.py b/openid/test/test_verifydisco.py index c0055ef9..ec69a62b 100644 --- a/openid/test/test_verifydisco.py +++ b/openid/test/test_verifydisco.py @@ -1,5 +1,7 @@ import unittest +from testfixtures import LogCapture, StringComparison + from openid import message from openid.consumer import consumer, discover from openid.test.support import OpenIDTestMixin @@ -25,48 +27,51 @@ def test_openID1NoLocalID(self): endpoint.claimed_id = 'bogus' msg = message.Message.fromOpenIDArgs({}) - self.failUnlessProtocolError( - 'Missing required field openid.identity', - self.consumer._verifyDiscoveryResults, msg, endpoint) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.failUnlessProtocolError('Missing required field openid.identity', + self.consumer._verifyDiscoveryResults, msg, endpoint) + self.assertEqual(logbook.records, []) def test_openID1NoEndpoint(self): msg = message.Message.fromOpenIDArgs({'identity': 'snakes on a plane'}) - self.assertRaises(RuntimeError, self.consumer._verifyDiscoveryResults, msg) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.assertRaises(RuntimeError, self.consumer._verifyDiscoveryResults, msg) + self.assertEqual(logbook.records, []) def test_openID2NoOPEndpointArg(self): msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS}) - self.assertRaises(KeyError, self.consumer._verifyDiscoveryResults, msg) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.assertRaises(KeyError, self.consumer._verifyDiscoveryResults, msg) + self.assertEqual(logbook.records, []) def test_openID2LocalIDNoClaimed(self): msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS, 'op_endpoint': 'Phone Home', 'identity': 'Jose Lius Borges'}) - self.failUnlessProtocolError( - 'openid.identity is present without', - self.consumer._verifyDiscoveryResults, msg) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.failUnlessProtocolError('openid.identity is present without', + self.consumer._verifyDiscoveryResults, msg) + self.assertEqual(logbook.records, []) def test_openID2NoLocalIDClaimed(self): msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS, 'op_endpoint': 'Phone Home', 'claimed_id': 'Manuel Noriega'}) - self.failUnlessProtocolError( - 'openid.claimed_id is present without', - self.consumer._verifyDiscoveryResults, msg) - self.failUnlessLogEmpty() + with LogCapture() as logbook: + self.failUnlessProtocolError('openid.claimed_id is present without', + self.consumer._verifyDiscoveryResults, msg) + self.assertEqual(logbook.records, []) def test_openID2NoIdentifiers(self): op_endpoint = 'Phone Home' msg = message.Message.fromOpenIDArgs({'ns': message.OPENID2_NS, 'op_endpoint': op_endpoint}) - result_endpoint = self.consumer._verifyDiscoveryResults(msg) + with LogCapture() as logbook: + result_endpoint = self.consumer._verifyDiscoveryResults(msg) self.assertTrue(result_endpoint.isOPIdentifier()) self.assertEqual(result_endpoint.server_url, op_endpoint) self.assertIsNone(result_endpoint.claimed_id) - self.failUnlessLogEmpty() + self.assertEqual(logbook.records, []) def test_openID2NoEndpointDoesDisco(self): op_endpoint = 'Phone Home' @@ -78,9 +83,10 @@ def test_openID2NoEndpointDoesDisco(self): 'identity': 'sour grapes', 'claimed_id': 'monkeysoft', 'op_endpoint': op_endpoint}) - result = self.consumer._verifyDiscoveryResults(msg) + with LogCapture() as logbook: + result = self.consumer._verifyDiscoveryResults(msg) self.assertEqual(result, sentinel) - self.failUnlessLogMatches('No pre-discovered') + logbook.check(('openid.consumer.consumer', 'INFO', 'No pre-discovered information supplied.')) def test_openID2MismatchedDoesDisco(self): mismatched = discover.OpenIDServiceEndpoint() @@ -96,10 +102,11 @@ def test_openID2MismatchedDoesDisco(self): 'identity': 'sour grapes', 'claimed_id': 'monkeysoft', 'op_endpoint': op_endpoint}) - result = self.consumer._verifyDiscoveryResults(msg, mismatched) + with LogCapture() as logbook: + result = self.consumer._verifyDiscoveryResults(msg, mismatched) self.assertEqual(result, sentinel) - self.failUnlessLogMatches('Error attempting to use stored', - 'Attempting discovery') + logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Error attempting to use .*')), + ('openid.consumer.consumer', 'INFO', 'Attempting discovery to verify endpoint')) def test_openid2UsePreDiscovered(self): endpoint = discover.OpenIDServiceEndpoint() @@ -113,9 +120,10 @@ def test_openid2UsePreDiscovered(self): 'identity': endpoint.local_id, 'claimed_id': endpoint.claimed_id, 'op_endpoint': endpoint.server_url}) - result = self.consumer._verifyDiscoveryResults(msg, endpoint) + with LogCapture() as logbook: + result = self.consumer._verifyDiscoveryResults(msg, endpoint) self.assertEqual(result, endpoint) - self.failUnlessLogEmpty() + self.assertEqual(logbook.records, []) def test_openid2UsePreDiscoveredWrongType(self): text = "verify failed" @@ -140,11 +148,12 @@ def discoverAndVerify(claimed_id, to_match_endpoints): 'claimed_id': endpoint.claimed_id, 'op_endpoint': endpoint.server_url}) - with self.assertRaisesRegexp(consumer.ProtocolError, text): - self.consumer._verifyDiscoveryResults(msg, endpoint) + with LogCapture() as logbook: + with self.assertRaisesRegexp(consumer.ProtocolError, text): + self.consumer._verifyDiscoveryResults(msg, endpoint) - self.failUnlessLogMatches('Error attempting to use stored', - 'Attempting discovery') + logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Error attempting to use .*')), + ('openid.consumer.consumer', 'INFO', 'Attempting discovery to verify endpoint')) def test_openid1UsePreDiscovered(self): endpoint = discover.OpenIDServiceEndpoint() @@ -156,9 +165,10 @@ def test_openid1UsePreDiscovered(self): msg = message.Message.fromOpenIDArgs( {'ns': message.OPENID1_NS, 'identity': endpoint.local_id}) - result = self.consumer._verifyDiscoveryResults(msg, endpoint) + with LogCapture() as logbook: + result = self.consumer._verifyDiscoveryResults(msg, endpoint) self.assertEqual(result, endpoint) - self.failUnlessLogEmpty() + self.assertEqual(logbook.records, []) def test_openid1UsePreDiscoveredWrongType(self): class VerifiedError(Exception): @@ -179,10 +189,10 @@ def discoverAndVerify(claimed_id, _to_match): {'ns': message.OPENID1_NS, 'identity': endpoint.local_id}) - self.assertRaises(VerifiedError, self.consumer._verifyDiscoveryResults, msg, endpoint) - - self.failUnlessLogMatches('Error attempting to use stored', - 'Attempting discovery') + with LogCapture() as logbook: + self.assertRaises(VerifiedError, self.consumer._verifyDiscoveryResults, msg, endpoint) + logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Error attempting to use .*')), + ('openid.consumer.consumer', 'INFO', 'Attempting discovery to verify endpoint')) def test_openid2Fragment(self): claimed_id = "http://unittest.invalid/" @@ -198,15 +208,15 @@ def test_openid2Fragment(self): 'identity': endpoint.local_id, 'claimed_id': claimed_id_frag, 'op_endpoint': endpoint.server_url}) - result = self.consumer._verifyDiscoveryResults(msg, endpoint) + with LogCapture() as logbook: + result = self.consumer._verifyDiscoveryResults(msg, endpoint) self.assertEqual(result.local_id, endpoint.local_id) self.assertEqual(result.server_url, endpoint.server_url) self.assertEqual(result.type_uris, endpoint.type_uris) - self.assertEqual(result.claimed_id, claimed_id_frag) - self.failUnlessLogEmpty() + self.assertEqual(logbook.records, []) def test_openid1Fallback1_0(self): claimed_id = 'http://claimed.id/' @@ -248,10 +258,11 @@ def test_endpointWithoutLocalID(self): to_match.server_url = "http://localhost:8000/openidserver" to_match.claimed_id = "http://localhost:8000/id/id-jo" to_match.local_id = "http://localhost:8000/id/id-jo" - result = self.consumer._verifyDiscoverySingle(endpoint, to_match) + with LogCapture() as logbook: + result = self.consumer._verifyDiscoverySingle(endpoint, to_match) # result should always be None, raises exception on failure. self.assertIsNone(result) - self.failUnlessLogEmpty() + self.assertEqual(logbook.records, []) if __name__ == '__main__': diff --git a/setup.py b/setup.py index a7e3bce7..4b7e934c 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ version = '[library version:2.2.5]'[17:-1] EXTRAS_REQUIRE = { 'quality': ('flake8', 'isort'), - 'tests': ('mock', ), + 'tests': ('mock', 'testfixtures'), # Optional dependencies for fetchers 'httplib2': ('httplib2', ), 'pycurl': ('pycurl', ), From 2ead1e2a1189b79a19f1ca2c3ef5ef2a019e634b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 26 Jan 2018 13:48:23 +0100 Subject: [PATCH 039/151] Cleanup test utilities --- openid/test/test_association_response.py | 18 +++---- openid/test/test_auth_request.py | 69 ++++++++++++------------ openid/test/test_ax.py | 64 +++++++--------------- openid/test/test_consumer.py | 28 ++++------ openid/test/test_discover.py | 21 +++----- openid/test/test_fetchers.py | 12 ++--- openid/test/test_rpverify.py | 36 +++++-------- openid/test/test_server.py | 11 ++-- openid/test/test_verifydisco.py | 18 +++---- openid/test/{support.py => utils.py} | 13 +++-- 10 files changed, 116 insertions(+), 174 deletions(-) rename openid/test/{support.py => utils.py} (51%) diff --git a/openid/test/test_association_response.py b/openid/test/test_association_response.py index 62b31750..3e5dfd07 100644 --- a/openid/test/test_association_response.py +++ b/openid/test/test_association_response.py @@ -40,10 +40,6 @@ def setUp(self): self.consumer = GenericConsumer(self.store) self.endpoint = OpenIDServiceEndpoint() - def failUnlessProtocolError(self, str_prefix, func, *args, **kwargs): - with self.assertRaisesRegexp(ProtocolError, str_prefix): - func(*args, **kwargs) - def mkExtractAssocMissingTest(keys): """Factory function for creating test methods for generating @@ -124,7 +120,8 @@ def test(self): keys.remove('ns') msg = mkAssocResponse(*keys) msg.setArg(OPENID_NS, 'session_type', response_session_type) - self.failUnlessProtocolError('Session type mismatch', self.consumer._extractAssociation, msg, assoc_session) + with self.assertRaisesRegexp(ProtocolError, 'Session type mismatch'): + self.consumer._extractAssociation(msg, assoc_session) return test @@ -285,14 +282,14 @@ def test_badAssocType(self): # Make sure that the assoc type in the response is not valid # for the given session. self.assoc_session.allowed_assoc_types = [] - self.failUnlessProtocolError('Unsupported assoc_type for session', - self.consumer._extractAssociation, self.assoc_response, self.assoc_session) + with self.assertRaisesRegexp(ProtocolError, 'Unsupported assoc_type for session'): + self.consumer._extractAssociation(self.assoc_response, self.assoc_session) def test_badExpiresIn(self): # Invalid value for expires_in should cause failure self.assoc_response.setArg(OPENID_NS, 'expires_in', 'forever') - self.failUnlessProtocolError('Invalid expires_in', - self.consumer._extractAssociation, self.assoc_response, self.assoc_session) + with self.assertRaisesRegexp(ProtocolError, 'Invalid expires_in'): + self.consumer._extractAssociation(self.assoc_response, self.assoc_session) # XXX: This is what causes most of the imports in this file. It is @@ -334,4 +331,5 @@ def test_openid2success(self): def test_badDHValues(self): sess, server_resp = self._setUpDH() server_resp.setArg(OPENID_NS, 'enc_mac_key', '\x00\x00\x00') - self.failUnlessProtocolError('Malformed response for', self.consumer._extractAssociation, server_resp, sess) + with self.assertRaisesRegexp(ProtocolError, 'Malformed response for'): + self.consumer._extractAssociation(server_resp, sess) diff --git a/openid/test/test_auth_request.py b/openid/test/test_auth_request.py index 3f287ac4..cc969878 100644 --- a/openid/test/test_auth_request.py +++ b/openid/test/test_auth_request.py @@ -2,7 +2,8 @@ from openid import message from openid.consumer import consumer -from openid.test import support + +from .utils import OpenIDTestMixin class DummyEndpoint(object): @@ -25,7 +26,7 @@ class DummyAssoc(object): handle = "assoc-handle" -class AuthRequestTestMixin(support.OpenIDTestMixin): +class AuthRequestTestMixin(OpenIDTestMixin): """Mixin for AuthRequest tests for OpenID 1 and 2; DON'T add unittest.TestCase as a base class here.""" @@ -44,21 +45,21 @@ def setUp(self): self.assoc = DummyAssoc() self.authreq = consumer.AuthRequest(self.endpoint, self.assoc) - def failUnlessAnonymous(self, msg): + def assertAnonymous(self, msg): for key in ['claimed_id', 'identity']: - self.failIfOpenIDKeyExists(msg, key) + self.assertOpenIDKeyMissing(msg, key) - def failUnlessHasRequiredFields(self, msg): + def assertHasRequiredFields(self, msg): self.assertEqual(self.authreq.message.getOpenIDNamespace(), self.preferred_namespace) self.assertEqual(msg.getOpenIDNamespace(), self.preferred_namespace) - self.failUnlessOpenIDValueEquals(msg, 'mode', + self.assertOpenIDValueEqual(msg, 'mode', self.expected_mode) # Implement these in subclasses because they depend on # protocol differences! - self.failUnlessHasRealm(msg) - self.failUnlessIdentifiersPresent(msg) + self.assertHasRealm(msg) + self.assertIdentifiersPresent(msg) # TESTS @@ -67,13 +68,13 @@ def test_checkNoAssocHandle(self): msg = self.authreq.getMessage(self.realm, self.return_to, self.immediate) - self.failIfOpenIDKeyExists(msg, 'assoc_handle') + self.assertOpenIDKeyMissing(msg, 'assoc_handle') def test_checkWithAssocHandle(self): msg = self.authreq.getMessage(self.realm, self.return_to, self.immediate) - self.failUnlessOpenIDValueEquals(msg, 'assoc_handle', + self.assertOpenIDValueEqual(msg, 'assoc_handle', self.assoc.handle) def test_addExtensionArg(self): @@ -95,28 +96,27 @@ def test_standard(self): msg = self.authreq.getMessage(self.realm, self.return_to, self.immediate) - self.failUnlessHasIdentifiers( - msg, self.endpoint.local_id, self.endpoint.claimed_id) + self.assertIdentifiers(msg, self.endpoint.local_id, self.endpoint.claimed_id) class TestAuthRequestOpenID2(AuthRequestTestMixin, unittest.TestCase): preferred_namespace = message.OPENID2_NS - def failUnlessHasRealm(self, msg): + def assertHasRealm(self, msg): # check presence of proper realm key and absence of the wrong # one. - self.failUnlessOpenIDValueEquals(msg, 'realm', self.realm) - self.failIfOpenIDKeyExists(msg, 'trust_root') + self.assertOpenIDValueEqual(msg, 'realm', self.realm) + self.assertOpenIDKeyMissing(msg, 'trust_root') - def failUnlessIdentifiersPresent(self, msg): + def assertIdentifiersPresent(self, msg): identity_present = msg.hasKey(message.OPENID_NS, 'identity') claimed_present = msg.hasKey(message.OPENID_NS, 'claimed_id') self.assertEqual(claimed_present, identity_present) - def failUnlessHasIdentifiers(self, msg, op_specific_id, claimed_id): - self.failUnlessOpenIDValueEquals(msg, 'identity', op_specific_id) - self.failUnlessOpenIDValueEquals(msg, 'claimed_id', claimed_id) + def assertIdentifiers(self, msg, op_specific_id, claimed_id): + self.assertOpenIDValueEqual(msg, 'identity', op_specific_id) + self.assertOpenIDValueEqual(msg, 'claimed_id', claimed_id) # TESTS @@ -130,43 +130,42 @@ def test_userAnonymousIgnoresIdentfier(self): self.authreq.setAnonymous(True) msg = self.authreq.getMessage(self.realm, self.return_to, self.immediate) - self.failUnlessHasRequiredFields(msg) - self.failUnlessAnonymous(msg) + self.assertHasRequiredFields(msg) + self.assertAnonymous(msg) def test_opAnonymousIgnoresIdentifier(self): self.endpoint.is_op_identifier = True self.authreq.setAnonymous(True) msg = self.authreq.getMessage(self.realm, self.return_to, self.immediate) - self.failUnlessHasRequiredFields(msg) - self.failUnlessAnonymous(msg) + self.assertHasRequiredFields(msg) + self.assertAnonymous(msg) def test_opIdentifierSendsIdentifierSelect(self): self.endpoint.is_op_identifier = True msg = self.authreq.getMessage(self.realm, self.return_to, self.immediate) - self.failUnlessHasRequiredFields(msg) - self.failUnlessHasIdentifiers( - msg, message.IDENTIFIER_SELECT, message.IDENTIFIER_SELECT) + self.assertHasRequiredFields(msg) + self.assertIdentifiers(msg, message.IDENTIFIER_SELECT, message.IDENTIFIER_SELECT) class TestAuthRequestOpenID1(AuthRequestTestMixin, unittest.TestCase): preferred_namespace = message.OPENID1_NS - def failUnlessHasIdentifiers(self, msg, op_specific_id, claimed_id): + def assertIdentifiers(self, msg, op_specific_id, claimed_id): """Make sure claimed_is is *absent* in request.""" - self.failUnlessOpenIDValueEquals(msg, 'identity', op_specific_id) - self.failIfOpenIDKeyExists(msg, 'claimed_id') + self.assertOpenIDValueEqual(msg, 'identity', op_specific_id) + self.assertOpenIDKeyMissing(msg, 'claimed_id') - def failUnlessIdentifiersPresent(self, msg): - self.failIfOpenIDKeyExists(msg, 'claimed_id') + def assertIdentifiersPresent(self, msg): + self.assertOpenIDKeyMissing(msg, 'claimed_id') self.assertTrue(msg.hasKey(message.OPENID_NS, 'identity')) - def failUnlessHasRealm(self, msg): + def assertHasRealm(self, msg): # check presence of proper realm key and absence of the wrong # one. - self.failUnlessOpenIDValueEquals(msg, 'trust_root', self.realm) - self.failIfOpenIDKeyExists(msg, 'realm') + self.assertOpenIDValueEqual(msg, 'trust_root', self.realm) + self.assertOpenIDKeyMissing(msg, 'realm') # TESTS @@ -185,7 +184,7 @@ def test_identifierSelect(self): self.endpoint.is_op_identifier = True msg = self.authreq.getMessage(self.realm, self.return_to, self.immediate) - self.failUnlessHasRequiredFields(msg) + self.assertHasRequiredFields(msg) self.assertEqual(msg.getArg(message.OPENID1_NS, 'identity'), message.IDENTIFIER_SELECT) diff --git a/openid/test/test_ax.py b/openid/test/test_ax.py index 83fe5cf1..3221169c 100644 --- a/openid/test/test_ax.py +++ b/openid/test/test_ax.py @@ -84,25 +84,22 @@ def test_two(self): class ParseAXValuesTest(unittest.TestCase): """Testing AXKeyValueMessage.parseExtensionArgs.""" - def failUnlessAXKeyError(self, ax_args): - msg = ax.AXKeyValueMessage() - self.assertRaises(KeyError, msg.parseExtensionArgs, ax_args) - - def failUnlessAXValues(self, ax_args, expected_args): + def assertAXValues(self, ax_args, expected_args): """Fail unless parseExtensionArgs(ax_args) == expected_args.""" msg = ax.AXKeyValueMessage() msg.parseExtensionArgs(ax_args) self.assertEqual(msg.data, expected_args) def test_emptyIsValid(self): - self.failUnlessAXValues({}, {}) + self.assertAXValues({}, {}) def test_missingValueForAliasExplodes(self): - self.failUnlessAXKeyError({'type.foo': 'urn:foo'}) + msg = ax.AXKeyValueMessage() + self.assertRaises(KeyError, msg.parseExtensionArgs, {'type.foo': 'urn:foo'}) def test_countPresentButNotValue(self): - self.failUnlessAXKeyError({'type.foo': 'urn:foo', - 'count.foo': '1'}) + msg = ax.AXKeyValueMessage() + self.assertRaises(KeyError, msg.parseExtensionArgs, {'type.foo': 'urn:foo', 'count.foo': '1'}) def test_invalidCountValue(self): msg = ax.FetchRequest() @@ -154,38 +151,22 @@ def test_invalidAlias(self): self.assertRaises(ax.AXError, msg.parseExtensionArgs, input) def test_countPresentAndIsZero(self): - self.failUnlessAXValues( - {'type.foo': 'urn:foo', - 'count.foo': '0', - }, {'urn:foo': []}) + self.assertAXValues({'type.foo': 'urn:foo', 'count.foo': '0'}, {'urn:foo': []}) def test_singletonEmpty(self): - self.failUnlessAXValues( - {'type.foo': 'urn:foo', - 'value.foo': '', - }, {'urn:foo': []}) + self.assertAXValues({'type.foo': 'urn:foo', 'value.foo': ''}, {'urn:foo': []}) def test_doubleAlias(self): - self.failUnlessAXKeyError( - {'type.foo': 'urn:foo', - 'value.foo': '', - 'type.bar': 'urn:foo', - 'value.bar': '', - }) + msg = ax.AXKeyValueMessage() + self.assertRaises(KeyError, msg.parseExtensionArgs, + {'type.foo': 'urn:foo', 'value.foo': '', 'type.bar': 'urn:foo', 'value.bar': ''}) def test_doubleSingleton(self): - self.failUnlessAXValues( - {'type.foo': 'urn:foo', - 'value.foo': '', - 'type.bar': 'urn:bar', - 'value.bar': '', - }, {'urn:foo': [], 'urn:bar': []}) + self.assertAXValues({'type.foo': 'urn:foo', 'value.foo': '', 'type.bar': 'urn:bar', 'value.bar': ''}, + {'urn:foo': [], 'urn:bar': []}) def test_singletonValue(self): - self.failUnlessAXValues( - {'type.foo': 'urn:foo', - 'value.foo': 'Westfall', - }, {'urn:foo': ['Westfall']}) + self.assertAXValues({'type.foo': 'urn:foo', 'value.foo': 'Westfall'}, {'urn:foo': ['Westfall']}) class FetchRequestTest(unittest.TestCase): @@ -243,10 +224,7 @@ def test_getExtensionArgs_noAlias(self): else: self.fail("Didn't find the type definition") - self.failUnlessExtensionArgs({ - 'type.' + alias: attr.type_uri, - 'if_available': alias, - }) + self.assertExtensionArgs({'type.' + alias: attr.type_uri, 'if_available': alias}) def test_getExtensionArgs_alias_if_available(self): attr = ax.AttrInfo( @@ -254,10 +232,7 @@ def test_getExtensionArgs_alias_if_available(self): alias='transport', ) self.msg.add(attr) - self.failUnlessExtensionArgs({ - 'type.' + attr.alias: attr.type_uri, - 'if_available': attr.alias, - }) + self.assertExtensionArgs({'type.' + attr.alias: attr.type_uri, 'if_available': attr.alias}) def test_getExtensionArgs_alias_req(self): attr = ax.AttrInfo( @@ -266,12 +241,9 @@ def test_getExtensionArgs_alias_req(self): required=True, ) self.msg.add(attr) - self.failUnlessExtensionArgs({ - 'type.' + attr.alias: attr.type_uri, - 'required': attr.alias, - }) + self.assertExtensionArgs({'type.' + attr.alias: attr.type_uri, 'required': attr.alias}) - def failUnlessExtensionArgs(self, expected_args): + def assertExtensionArgs(self, expected_args): """Make sure that getExtensionArgs has the expected result This method will fill in the mode. diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index a427dbff..c3df47b3 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -497,7 +497,7 @@ def _vrfy(resp_msg, endpoint=None): self.consumer._verifyDiscoveryResults = _vrfy r = self.consumer.complete(self.message, self.endpoint, None) - self.failUnlessSuccess(r) + self.assertEqual(r.status, SUCCESS) def test_idResNoIdentity(self): self.message.delArg(OPENID_NS, 'identity') @@ -505,7 +505,7 @@ def test_idResNoIdentity(self): self.endpoint.claimed_id = None self.message.setArg(OPENID_NS, 'signed', 'return_to,response_nonce,assoc_handle,op_endpoint') r = self.consumer.complete(self.message, self.endpoint, None) - self.failUnlessSuccess(r) + self.assertEqual(r.status, SUCCESS) def test_idResMissingIdentitySig(self): self.message.setArg(OPENID_NS, 'signed', 'return_to,response_nonce,assoc_handle,claimed_id') @@ -527,10 +527,6 @@ def test_idResMissingClaimedIDSig(self): r = self.consumer.complete(self.message, self.endpoint, None) self.assertEqual(r.status, FAILURE) - def failUnlessSuccess(self, response): - if response.status != SUCCESS: - self.fail("Non-successful response: %s" % (response,)) - class TestCheckAuthResponse(TestIdRes): def setUp(self): @@ -626,7 +622,8 @@ def test_invalidatePresent(self): class TestSetupNeeded(TestIdRes): - def failUnlessSetupNeeded(self, expected_setup_url, message): + + def assertSetupNeeded(self, expected_setup_url, message): with self.assertRaises(SetupNeededError) as catch: self.consumer._checkSetupNeeded(message) self.assertEqual(catch.exception.user_setup_url, expected_setup_url) @@ -639,7 +636,7 @@ def test_setupNeededOpenID1(self): 'openid.user_setup_url': setup_url, }) self.assertTrue(message.isOpenID1()) - self.failUnlessSetupNeeded(setup_url, message) + self.assertSetupNeeded(setup_url, message) def test_setupNeededOpenID1_extra(self): """Extra stuff along with setup_url still trigger Setup Needed""" @@ -650,7 +647,7 @@ def test_setupNeededOpenID1_extra(self): 'openid.identity': 'bogus', }) self.assertTrue(message.isOpenID1()) - self.failUnlessSetupNeeded(setup_url, message) + self.assertSetupNeeded(setup_url, message) def test_noSetupNeededOpenID1(self): """When the user_setup_url is missing on an OpenID 1 message, @@ -1599,7 +1596,7 @@ def verifyDiscoveryResults(identifier, endpoint): self.consumer._checkReturnTo = lambda unused1, unused2: True response = self.consumer._doIdRes(message, self.endpoint, None) - self.failUnlessSuccess(response) + self.assertEqual(response.status, SUCCESS) self.assertEqual(response.identity_url, "=directed_identifier") # assert that discovery attempt happens and returns good @@ -1621,10 +1618,6 @@ def verifyDiscoveryResults(identifier, endpoint): self.consumer._checkReturnTo = lambda unused1, unused2: True self.assertRaises(DiscoveryFailure, self.consumer._doIdRes, message, self.endpoint, None) - def failUnlessSuccess(self, response): - if response.status != SUCCESS: - self.fail("Non-successful response: %s" % (response,)) - class TestDiscoveryVerification(unittest.TestCase): services = [] @@ -1893,14 +1886,11 @@ def dummyDiscover(unused_identifier): self.consumer._discover = dummyDiscover self.to_match = OpenIDServiceEndpoint() - def failUnlessDiscoveryFailure(self): - self.assertRaises(DiscoveryFailure, self.consumer._discoverAndVerify, 'http://claimed-id.com/', [self.to_match]) - def test_noServices(self): """Discovery returning no results results in a DiscoveryFailure exception""" self.discovery_result = (None, []) - self.failUnlessDiscoveryFailure() + self.assertRaises(DiscoveryFailure, self.consumer._discoverAndVerify, 'http://claimed-id.com/', [self.to_match]) def test_noMatches(self): """If no discovered endpoint matches the values from the @@ -1911,7 +1901,7 @@ def test_noMatches(self): def raiseProtocolError(unused1, unused2): raise ProtocolError('unit test') self.consumer._verifyDiscoverySingle = raiseProtocolError - self.failUnlessDiscoveryFailure() + self.assertRaises(DiscoveryFailure, self.consumer._discoverAndVerify, 'http://claimed-id.com/', [self.to_match]) def test_matches(self): """If an endpoint matches, we return it diff --git a/openid/test/test_discover.py b/openid/test/test_discover.py index 458ed636..9708cee9 100644 --- a/openid/test/test_discover.py +++ b/openid/test/test_discover.py @@ -708,7 +708,7 @@ class TestEndpointSupportsType(unittest.TestCase): def setUp(self): self.endpoint = discover.OpenIDServiceEndpoint() - def failUnlessSupportsOnly(self, *types): + def assertSupportsOnly(self, *types): for t in [ 'foo', discover.OPENID_1_1_TYPE, @@ -722,39 +722,34 @@ def failUnlessSupportsOnly(self, *types): self.assertFalse(self.endpoint.supportsType(t), "Shouldn't support %r" % (t,)) def test_supportsNothing(self): - self.failUnlessSupportsOnly() + self.assertSupportsOnly() def test_openid2(self): self.endpoint.type_uris = [discover.OPENID_2_0_TYPE] - self.failUnlessSupportsOnly(discover.OPENID_2_0_TYPE) + self.assertSupportsOnly(discover.OPENID_2_0_TYPE) def test_openid2provider(self): self.endpoint.type_uris = [discover.OPENID_IDP_2_0_TYPE] - self.failUnlessSupportsOnly(discover.OPENID_IDP_2_0_TYPE, - discover.OPENID_2_0_TYPE) + self.assertSupportsOnly(discover.OPENID_IDP_2_0_TYPE, discover.OPENID_2_0_TYPE) def test_openid1_0(self): self.endpoint.type_uris = [discover.OPENID_1_0_TYPE] - self.failUnlessSupportsOnly(discover.OPENID_1_0_TYPE) + self.assertSupportsOnly(discover.OPENID_1_0_TYPE) def test_openid1_1(self): self.endpoint.type_uris = [discover.OPENID_1_1_TYPE] - self.failUnlessSupportsOnly(discover.OPENID_1_1_TYPE) + self.assertSupportsOnly(discover.OPENID_1_1_TYPE) def test_multiple(self): self.endpoint.type_uris = [discover.OPENID_1_1_TYPE, discover.OPENID_2_0_TYPE] - self.failUnlessSupportsOnly(discover.OPENID_1_1_TYPE, - discover.OPENID_2_0_TYPE) + self.assertSupportsOnly(discover.OPENID_1_1_TYPE, discover.OPENID_2_0_TYPE) def test_multipleWithProvider(self): self.endpoint.type_uris = [discover.OPENID_1_1_TYPE, discover.OPENID_2_0_TYPE, discover.OPENID_IDP_2_0_TYPE] - self.failUnlessSupportsOnly(discover.OPENID_1_1_TYPE, - discover.OPENID_2_0_TYPE, - discover.OPENID_IDP_2_0_TYPE, - ) + self.assertSupportsOnly(discover.OPENID_1_1_TYPE, discover.OPENID_2_0_TYPE, discover.OPENID_IDP_2_0_TYPE) class TestEndpointDisplayIdentifier(unittest.TestCase): diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index b1c066d0..16f615a2 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -13,7 +13,7 @@ # XXX: make these separate test cases -def failUnlessResponseExpected(expected, actual): +def assertResponse(expected, actual): assert expected.final_url == actual.final_url, ( "%r != %r" % (expected.final_url, actual.final_url)) assert expected.status == actual.status @@ -63,7 +63,7 @@ def plain(path, code): print fetcher, fetch_url raise else: - failUnlessResponseExpected(expected, actual) + assertResponse(expected, actual) for err_url in [geturl('/closed'), 'http://invalid.janrain.com/', @@ -302,21 +302,21 @@ def test_success(self): self.add_response('http://example.cz/success/', 200, {'Content-Type': 'text/plain'}, 'BODY') response = self.fetcher.fetch('http://example.cz/success/') expected = fetchers.HTTPResponse('http://example.cz/success/', 200, {'Content-Type': 'text/plain'}, 'BODY') - failUnlessResponseExpected(expected, response) + assertResponse(expected, response) def test_redirect(self): # Test redirect response - a final response comes from another URL. self.add_response('http://example.cz/success/', 200, {'Content-Type': 'text/plain'}, 'BODY') response = self.fetcher.fetch('http://example.cz/redirect/') expected = fetchers.HTTPResponse('http://example.cz/success/', 200, {'Content-Type': 'text/plain'}, 'BODY') - failUnlessResponseExpected(expected, response) + assertResponse(expected, response) def test_error(self): # Test error responses - returned as obtained self.add_response('http://example.cz/error/', 500, {'Content-Type': 'text/plain'}, 'BODY') response = self.fetcher.fetch('http://example.cz/error/') expected = fetchers.HTTPResponse('http://example.cz/error/', 500, {'Content-Type': 'text/plain'}, 'BODY') - failUnlessResponseExpected(expected, response) + assertResponse(expected, response) def test_invalid_url(self): with self.assertRaisesRegexp(self.invalid_url_error, 'Bad URL scheme:'): @@ -328,7 +328,7 @@ def test_connection_error(self): {'Content-Type': 'text/plain'}, StringIO('BODY')) response = self.fetcher.fetch('http://example.cz/error/') expected = fetchers.HTTPResponse('http://example.cz/error/', 500, {'Content-Type': 'text/plain'}, 'BODY') - failUnlessResponseExpected(expected, response) + assertResponse(expected, response) class TestSilencedUrllib2Fetcher(TestUrllib2Fetcher): diff --git a/openid/test/test_rpverify.py b/openid/test/test_rpverify.py index 04b693ee..cbcb6dfd 100644 --- a/openid/test/test_rpverify.py +++ b/openid/test/test_rpverify.py @@ -17,7 +17,7 @@ class TestBuildDiscoveryURL(unittest.TestCase): return_to URL """ - def failUnlessDiscoURL(self, realm, expected_discovery_url): + def assertDiscoveryURL(self, realm, expected_discovery_url): """Build a discovery URL out of the realm and a return_to and make sure that it matches the expected discovery URL """ @@ -28,20 +28,17 @@ def failUnlessDiscoURL(self, realm, expected_discovery_url): def test_trivial(self): """There is no wildcard and the realm is the same as the return_to URL """ - self.failUnlessDiscoURL('http://example.com/foo', - 'http://example.com/foo') + self.assertDiscoveryURL('http://example.com/foo', 'http://example.com/foo') def test_wildcard(self): """There is a wildcard """ - self.failUnlessDiscoURL('http://*.example.com/foo', - 'http://www.example.com/foo') + self.assertDiscoveryURL('http://*.example.com/foo', 'http://www.example.com/foo') def test_wildcard_port(self): """There is a wildcard """ - self.failUnlessDiscoURL('http://*.example.com:8001/foo', - 'http://www.example.com:8001/foo') + self.assertDiscoveryURL('http://*.example.com:8001/foo', 'http://www.example.com:8001/foo') class TestExtractReturnToURLs(unittest.TestCase): @@ -61,29 +58,24 @@ def mockDiscover(self, uri): result.normalized_uri = uri return result - def failUnlessFileHasReturnURLs(self, filename, expected_return_urls): - self.failUnlessXRDSHasReturnURLs(file(filename).read(), - expected_return_urls) - - def failUnlessXRDSHasReturnURLs(self, data, expected_return_urls): + def assertReturnURLs(self, data, expected_return_urls): self.data = data - actual_return_urls = list(trustroot.getAllowedReturnURLs( - self.disco_url)) + actual_return_urls = trustroot.getAllowedReturnURLs(self.disco_url) self.assertEqual(actual_return_urls, expected_return_urls) - def failUnlessDiscoveryFailure(self, text): + def assertDiscoveryFailure(self, text): self.data = text self.assertRaises(DiscoveryFailure, trustroot.getAllowedReturnURLs, self.disco_url) def test_empty(self): - self.failUnlessDiscoveryFailure('') + self.assertDiscoveryFailure('') def test_badXML(self): - self.failUnlessDiscoveryFailure('>') + self.assertDiscoveryFailure('>') def test_noEntries(self): - self.failUnlessXRDSHasReturnURLs('''\ + self.assertReturnURLs('''\ Date: Fri, 26 Jan 2018 16:52:40 +0100 Subject: [PATCH 040/151] Replace 'cgi.parse_qs*' functions --- contrib/openid-parse | 9 ++++----- examples/consumer.py | 2 +- examples/server.py | 6 +++--- openid/consumer/consumer.py | 5 ++--- openid/test/test_consumer.py | 3 +-- openid/test/test_message.py | 6 +++--- openid/test/test_server.py | 21 ++++++++++----------- 7 files changed, 24 insertions(+), 28 deletions(-) diff --git a/contrib/openid-parse b/contrib/openid-parse index ac2c5dff..96bba315 100644 --- a/contrib/openid-parse +++ b/contrib/openid-parse @@ -8,8 +8,7 @@ Requires the 'xsel' program to get the contents of the clipboard. """ from pprint import pformat -from urlparse import urlsplit, urlunsplit -import cgi +from urlparse import urlsplit, urlunsplit, parse_qs import re import subprocess import sys @@ -63,7 +62,7 @@ def main(): def queryFromURL(url): split_url = urlsplit(url) - query = cgi.parse_qs(split_url[3]) + query = parse_qs(split_url[3]) if not query: raise NoQuery(url) @@ -124,7 +123,7 @@ def unlistify(d): def queriesFromLogs(s): qre = re.compile(r'GET (/.*)?\?(.+) HTTP') - return [(match.group(1), cgi.parse_qs(match.group(2))) + return [(match.group(1), parse_qs(match.group(2))) for match in qre.finditer(s)] @@ -135,7 +134,7 @@ def queriesFromPostdata(s): qre = re.compile(r'(?:^Host=(?P.+?)$.*?)?^POSTDATA=(?P.*)$', re.DOTALL | re.MULTILINE) return [(match.group('host') or 'POSTDATA', - cgi.parse_qs(match.group('query'))) for match in qre.finditer(s)] + parse_qs(match.group('query'))) for match in qre.finditer(s)] def find_urls(s): diff --git a/examples/consumer.py b/examples/consumer.py index 908130af..fa6b3f01 100644 --- a/examples/consumer.py +++ b/examples/consumer.py @@ -133,7 +133,7 @@ def do_GET(self): try: self.parsed_uri = urlparse.urlparse(self.path) self.query = {} - for k, v in cgi.parse_qsl(self.parsed_uri[4]): + for k, v in urlparse.parse_qsl(self.parsed_uri[4]): self.query[k] = v.decode('utf-8') path = self.parsed_uri[2] diff --git a/examples/server.py b/examples/server.py index 2da8835c..5ef52a69 100644 --- a/examples/server.py +++ b/examples/server.py @@ -9,7 +9,7 @@ import sys import time from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer -from urlparse import urlparse +from urlparse import parse_qsl, urlparse def quoteattr(s): @@ -69,7 +69,7 @@ def do_GET(self): try: self.parsed_uri = urlparse(self.path) self.query = {} - for k, v in cgi.parse_qsl(self.parsed_uri[4]): + for k, v in parse_qsl(self.parsed_uri[4]): self.query[k] = v self.setUser() @@ -110,7 +110,7 @@ def do_POST(self): post_data = self.rfile.read(content_length) self.query = {} - for k, v in cgi.parse_qsl(post_data): + for k, v in parse_qsl(post_data): self.query[k] = v path = self.parsed_uri[2] diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index c811ce05..a2289381 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -187,10 +187,9 @@ objects. """ -import cgi import copy import logging -from urlparse import urldefrag, urlparse +from urlparse import parse_qsl, urldefrag, urlparse from openid import cryptutil, fetchers, oidutil, urinorm from openid.association import Association, SessionNegotiator, default_negotiator @@ -845,7 +844,7 @@ def _verifyReturnToArgs(query): parsed_url = urlparse(return_to) rt_query = parsed_url[4] - parsed_args = cgi.parse_qsl(rt_query, keep_blank_values=True) + parsed_args = parse_qsl(rt_query, keep_blank_values=True) for rt_key, rt_value in parsed_args: try: diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index c3df47b3..2b0a2af8 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -1,4 +1,3 @@ -import cgi import time import unittest import urlparse @@ -37,7 +36,7 @@ def mkSuccess(endpoint, q): def parseQuery(qs): q = {} - for (k, v) in cgi.parse_qsl(qs): + for (k, v) in urlparse.parse_qsl(qs): assert k not in q q[k] = v return q diff --git a/openid/test/test_message.py b/openid/test/test_message.py index 3ce94eaa..d9b91c86 100644 --- a/openid/test/test_message.py +++ b/openid/test/test_message.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -import cgi import unittest import urllib +from urlparse import parse_qs from openid import message, oidutil from openid.extensions import sreg @@ -229,7 +229,7 @@ def test_toURL(self): self.assertEqual(actual_base, base_url) self.assertEqual(actual[len(base_url)], '?') query = actual[len(base_url) + 1:] - parsed = cgi.parse_qs(query) + parsed = parse_qs(query) self.assertEqual(parsed, {'openid.mode': ['error'], 'openid.error': ['unit test']}) def test_getOpenID(self): @@ -399,7 +399,7 @@ def test_toURL(self): self.assertEqual(actual_base, base_url) self.assertEqual(actual[len(base_url)], '?') query = actual[len(base_url) + 1:] - parsed = cgi.parse_qs(query) + parsed = parse_qs(query) self.assertEqual(parsed, {'openid.mode': ['error'], 'openid.error': ['unit test'], 'openid.ns': [message.OPENID1_NS]}) diff --git a/openid/test/test_server.py b/openid/test/test_server.py index ba192068..a3296fc9 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -1,9 +1,8 @@ """Tests for openid.server. """ -import cgi import unittest from functools import partial -from urlparse import urlparse +from urlparse import parse_qs, parse_qsl, urlparse from testfixtures import LogCapture, StringComparison @@ -43,7 +42,7 @@ def test_browserWithReturnTo(self): } rt_base, result_args = e.encodeToURL().split('?', 1) - result_args = cgi.parse_qs(result_args) + result_args = parse_qs(result_args) self.assertEqual(result_args, expected_args) def test_browserWithReturnTo_OpenID2_GET(self): @@ -65,7 +64,7 @@ def test_browserWithReturnTo_OpenID2_GET(self): } rt_base, result_args = e.encodeToURL().split('?', 1) - result_args = cgi.parse_qs(result_args) + result_args = parse_qs(result_args) self.assertEqual(result_args, expected_args) def test_browserWithReturnTo_OpenID2_POST(self): @@ -101,7 +100,7 @@ def test_browserWithReturnTo_OpenID1_exceeds_limit(self): self.assertEqual(e.whichEncoding(), server.ENCODE_URL) rt_base, result_args = e.encodeToURL().split('?', 1) - result_args = cgi.parse_qs(result_args) + result_args = parse_qs(result_args) self.assertEqual(result_args, expected_args) def test_noReturnTo(self): @@ -636,7 +635,7 @@ def test_id_res(self): self.assertTrue(location.startswith(request.return_to), "%s does not start with %s" % (location, request.return_to)) # argh. - q2 = dict(cgi.parse_qsl(urlparse(location)[4])) + q2 = dict(parse_qsl(urlparse(location)[4])) expected = response.fields.toPostArgs() self.assertEqual(q2, expected) @@ -760,7 +759,7 @@ def test_idres(self): self.assertIn('location', webresponse.headers) location = webresponse.headers['location'] - query = cgi.parse_qs(urlparse(location)[4]) + query = parse_qs(urlparse(location)[4]) self.assertIn('openid.sig', query) self.assertIn('openid.assoc_handle', query) self.assertIn('openid.signed', query) @@ -771,7 +770,7 @@ def test_idresDumb(self): self.assertIn('location', webresponse.headers) location = webresponse.headers['location'] - query = cgi.parse_qs(urlparse(location)[4]) + query = parse_qs(urlparse(location)[4]) self.assertIn('openid.sig', query) self.assertIn('openid.assoc_handle', query) self.assertIn('openid.signed', query) @@ -795,7 +794,7 @@ def test_cancel(self): self.assertEqual(webresponse.code, server.HTTP_REDIRECT) self.assertIn('location', webresponse.headers) location = webresponse.headers['location'] - query = cgi.parse_qs(urlparse(location)[4]) + query = parse_qs(urlparse(location)[4]) self.assertNotIn('openid.sig', query, response.fields.toPostArgs()) def test_assocReply(self): @@ -1188,7 +1187,7 @@ def test_encodeToURL(self): # How to check? How about a round-trip test. base, result_args = result.split('?', 1) - result_args = dict(cgi.parse_qsl(result_args)) + result_args = dict(parse_qsl(result_args)) message = Message.fromPostArgs(result_args) rebuilt_request = server.CheckIDRequest.fromMessage(message, self.server.op_endpoint) @@ -1200,7 +1199,7 @@ def test_getCancelURL(self): url = self.request.getCancelURL() rt, query_string = url.split('?') self.assertEqual(self.request.return_to, rt) - query = dict(cgi.parse_qsl(query_string)) + query = dict(parse_qsl(query_string)) self.assertEqual(query, {'openid.mode': 'cancel', 'openid.ns': OPENID2_NS}) def test_getCancelURLimmed(self): From c904ce78e6ebd9c47ef7b83facacd407d27265a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Thu, 1 Feb 2018 12:55:30 +0100 Subject: [PATCH 041/151] Update imports for new isort version --- admin/builddiscover.py | 1 - contrib/associate | 5 ++--- contrib/openid-parse | 4 ++-- contrib/upgrade-store-1.1-to-2.0 | 2 +- openid/test/test_openidyadis.py | 1 - 5 files changed, 5 insertions(+), 8 deletions(-) diff --git a/admin/builddiscover.py b/admin/builddiscover.py index ef4ede92..82681280 100755 --- a/admin/builddiscover.py +++ b/admin/builddiscover.py @@ -4,7 +4,6 @@ from openid.test import discoverdata - manifest_header = """\ # This file contains test cases for doing YADIS identity URL and # service discovery. For each case, there are three URLs. The first diff --git a/contrib/associate b/contrib/associate index 76fe5b0e..17eae99f 100755 --- a/contrib/associate +++ b/contrib/associate @@ -3,12 +3,11 @@ and print the results.""" import sys +from datetime import datetime -from openid.store.memstore import MemoryStore from openid.consumer import consumer from openid.consumer.discover import OpenIDServiceEndpoint - -from datetime import datetime +from openid.store.memstore import MemoryStore def verboseAssociation(assoc): diff --git a/contrib/openid-parse b/contrib/openid-parse index 96bba315..5915ad36 100644 --- a/contrib/openid-parse +++ b/contrib/openid-parse @@ -7,11 +7,11 @@ with a pattern like 'GET /foo?bar=baz HTTP'. Requires the 'xsel' program to get the contents of the clipboard. """ -from pprint import pformat -from urlparse import urlsplit, urlunsplit, parse_qs import re import subprocess import sys +from pprint import pformat +from urlparse import parse_qs, urlsplit, urlunsplit from openid import message diff --git a/contrib/upgrade-store-1.1-to-2.0 b/contrib/upgrade-store-1.1-to-2.0 index 1907ce37..2e73c4a2 100644 --- a/contrib/upgrade-store-1.1-to-2.0 +++ b/contrib/upgrade-store-1.1-to-2.0 @@ -14,8 +14,8 @@ # * test data for mysql and postgresql. # * automated tests. -import os import getpass +import os import sys from optparse import OptionParser diff --git a/openid/test/test_openidyadis.py b/openid/test/test_openidyadis.py index 91e7e4ac..0b19ce2d 100644 --- a/openid/test/test_openidyadis.py +++ b/openid/test/test_openidyadis.py @@ -3,7 +3,6 @@ from openid.consumer.discover import OPENID_1_0_TYPE, OPENID_1_1_TYPE, OpenIDServiceEndpoint from openid.yadis.services import applyFilter - XRDS_BOILERPLATE = '''\ Date: Fri, 2 Feb 2018 18:53:45 +0100 Subject: [PATCH 042/151] Use common decorators as such --- openid/association.py | 6 ++---- openid/consumer/consumer.py | 6 ++---- openid/consumer/discover.py | 15 +++++---------- openid/dh.py | 3 +-- openid/extensions/ax.py | 12 ++++-------- openid/extensions/draft/pape2.py | 6 ++---- openid/extensions/draft/pape5.py | 14 +++++--------- openid/extensions/sreg.py | 9 +++------ openid/message.py | 9 +++------ openid/server/server.py | 20 +++++++------------- openid/server/trustroot.py | 9 +++------ openid/yadis/filters.py | 3 +-- 12 files changed, 38 insertions(+), 74 deletions(-) diff --git a/openid/association.py b/openid/association.py index 8a52b78f..920bd634 100644 --- a/openid/association.py +++ b/openid/association.py @@ -259,6 +259,7 @@ class Association(object): 'HMAC-SHA256': cryptutil.hmacSha256, } + @classmethod def fromExpiresIn(cls, expires_in, handle, secret, assoc_type): """ This is an alternate constructor used by the OpenID consumer @@ -297,8 +298,6 @@ def fromExpiresIn(cls, expires_in, handle, secret, assoc_type): lifetime = expires_in return cls(handle, secret, issued, lifetime, assoc_type) - fromExpiresIn = classmethod(fromExpiresIn) - def __init__(self, handle, secret, issued, lifetime, assoc_type): """ This is the standard constructor for creating an association. @@ -421,6 +420,7 @@ def serialize(self): return kvform.seqToKV(pairs, strict=True) + @classmethod def deserialize(cls, assoc_s): """ Parse an association as stored by serialize(). @@ -453,8 +453,6 @@ def deserialize(cls, assoc_s): secret = oidutil.fromBase64(secret) return cls(handle, secret, issued, lifetime, assoc_type) - deserialize = classmethod(deserialize) - def sign(self, pairs): """ Generate a signature for a sequence of (key, value) pairs diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index a2289381..a836c900 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -539,6 +539,7 @@ def __init__(self, error_text, error_code, message): self.error_code = error_code self.message = message + @classmethod def fromMessage(cls, message): """Generate a ServerError instance, extracting the error text and the error code from the message.""" @@ -547,8 +548,6 @@ def fromMessage(cls, message): error_code = message.getArg(OPENID_NS, 'error_code') return cls(error_text, error_code, message) - fromMessage = classmethod(fromMessage) - class GenericConsumer(object): """This is the implementation of the common logic for OpenID @@ -832,6 +831,7 @@ def _idResCheckForFields(self, message): if message.hasKey(OPENID_NS, field) and field not in signed_list: raise ProtocolError('"%s" not signed' % (field,)) + @staticmethod def _verifyReturnToArgs(query): """Verify that the arguments in the return_to URL are present in this response. @@ -864,8 +864,6 @@ def _verifyReturnToArgs(query): if pair not in parsed_args: raise ProtocolError("Parameter %s not in return_to URL" % (pair[0],)) - _verifyReturnToArgs = staticmethod(_verifyReturnToArgs) - def _verifyDiscoveryResults(self, resp_msg, endpoint=None): """ Extract the information from an OpenID assertion message and diff --git a/openid/consumer/discover.py b/openid/consumer/discover.py index f847c63a..b9bc30e5 100644 --- a/openid/consumer/discover.py +++ b/openid/consumer/discover.py @@ -119,6 +119,7 @@ def getLocalID(self): else: return self.local_id or self.canonicalID + @classmethod def fromBasicServiceEndpoint(cls, endpoint): """Create a new instance of this class from the endpoint object passed in. @@ -140,8 +141,7 @@ def fromBasicServiceEndpoint(cls, endpoint): return openid_endpoint - fromBasicServiceEndpoint = classmethod(fromBasicServiceEndpoint) - + @classmethod def fromHTML(cls, uri, html): """Parse the given document as HTML looking for an OpenID @@ -172,8 +172,7 @@ def fromHTML(cls, uri, html): return services - fromHTML = classmethod(fromHTML) - + @classmethod def fromXRDS(cls, uri, xrds): """Parse the given document as XRDS looking for OpenID services. @@ -185,8 +184,7 @@ def fromXRDS(cls, uri, xrds): """ return extractServices(uri, xrds, cls) - fromXRDS = classmethod(fromXRDS) - + @classmethod def fromDiscoveryResult(cls, discoveryResult): """Create endpoints from a DiscoveryResult. @@ -205,8 +203,7 @@ def fromDiscoveryResult(cls, discoveryResult): return method(discoveryResult.normalized_uri, discoveryResult.response_text) - fromDiscoveryResult = classmethod(fromDiscoveryResult) - + @classmethod def fromOPEndpointURL(cls, op_endpoint_url): """Construct an OP-Identifier OpenIDServiceEndpoint object for a given OP Endpoint URL @@ -219,8 +216,6 @@ def fromOPEndpointURL(cls, op_endpoint_url): service.type_uris = [OPENID_IDP_2_0_TYPE] return service - fromOPEndpointURL = classmethod(fromOPEndpointURL) - def __str__(self): return ("<%s.%s " "server_url=%r " diff --git a/openid/dh.py b/openid/dh.py index b0400b9e..5b7a4400 100644 --- a/openid/dh.py +++ b/openid/dh.py @@ -21,11 +21,10 @@ class DiffieHellman(object): DEFAULT_GEN = 2 + @classmethod def fromDefaults(cls): return cls(cls.DEFAULT_MOD, cls.DEFAULT_GEN) - fromDefaults = classmethod(fromDefaults) - def __init__(self, modulus, generator): self.modulus = long(modulus) self.generator = long(generator) diff --git a/openid/extensions/ax.py b/openid/extensions/ax.py index c8fac3f4..b48d19ce 100644 --- a/openid/extensions/ax.py +++ b/openid/extensions/ax.py @@ -273,6 +273,7 @@ def getRequiredAttrs(self): return required + @classmethod def fromOpenIDRequest(cls, openid_request): """Extract a FetchRequest from an OpenID message @@ -316,8 +317,6 @@ def fromOpenIDRequest(cls, openid_request): return self - fromOpenIDRequest = classmethod(fromOpenIDRequest) - def parseExtensionArgs(self, ax_args): """Given attribute exchange arguments, populate this FetchRequest. @@ -671,6 +670,7 @@ def parseExtensionArgs(self, ax_args): super(FetchResponse, self).parseExtensionArgs(ax_args) self.update_url = ax_args.get('update_url') + @classmethod def fromSuccessResponse(cls, success_response, signed=True): """Construct a FetchResponse object from an OpenID library SuccessResponse object. @@ -699,8 +699,6 @@ def fromSuccessResponse(cls, success_response, signed=True): else: return self - fromSuccessResponse = classmethod(fromSuccessResponse) - class StoreRequest(AXKeyValueMessage): """A store request attribute exchange message representation @@ -724,6 +722,7 @@ def getExtensionArgs(self): ax_args.update(kv_args) return ax_args + @classmethod def fromOpenIDRequest(cls, openid_request): """Extract a StoreRequest from an OpenID message @@ -752,8 +751,6 @@ def fromOpenIDRequest(cls, openid_request): return self - fromOpenIDRequest = classmethod(fromOpenIDRequest) - class StoreResponse(AXMessage): """An indication that the store request was processed along with @@ -787,6 +784,7 @@ def getExtensionArgs(self): return ax_args + @classmethod def fromSuccessResponse(cls, success_response, signed=True): """Construct a StoreResponse object from an OpenID library SuccessResponse object. @@ -814,5 +812,3 @@ def fromSuccessResponse(cls, success_response, signed=True): return None else: return self - - fromSuccessResponse = classmethod(fromSuccessResponse) diff --git a/openid/extensions/draft/pape2.py b/openid/extensions/draft/pape2.py index f9b84c84..954c5c00 100644 --- a/openid/extensions/draft/pape2.py +++ b/openid/extensions/draft/pape2.py @@ -84,6 +84,7 @@ def getExtensionArgs(self): return ns_args + @classmethod def fromOpenIDRequest(cls, request): """Instantiate a Request object from the arguments in a C{checkid_*} OpenID message @@ -97,8 +98,6 @@ def fromOpenIDRequest(cls, request): self.parseExtensionArgs(args) return self - fromOpenIDRequest = classmethod(fromOpenIDRequest) - def parseExtensionArgs(self, args): """Set the state of this request to be that expressed in these PAPE arguments @@ -184,6 +183,7 @@ def addPolicyURI(self, policy_uri): if policy_uri not in self.auth_policies: self.auth_policies.append(policy_uri) + @classmethod def fromSuccessResponse(cls, success_response): """Create a C{L{Response}} object from a successful OpenID library response @@ -249,8 +249,6 @@ def parseExtensionArgs(self, args, strict=False): elif strict: raise ValueError("auth_time must be in RFC3339 format") - fromSuccessResponse = classmethod(fromSuccessResponse) - def getExtensionArgs(self): """@see: C{L{Extension.getExtensionArgs}} """ diff --git a/openid/extensions/draft/pape5.py b/openid/extensions/draft/pape5.py index 6d0b1ddf..e7568dd1 100644 --- a/openid/extensions/draft/pape5.py +++ b/openid/extensions/draft/pape5.py @@ -172,6 +172,7 @@ def getExtensionArgs(self): return ns_args + @classmethod def fromOpenIDRequest(cls, request): """Instantiate a Request object from the arguments in a C{checkid_*} OpenID message @@ -186,8 +187,6 @@ def fromOpenIDRequest(cls, request): self.parseExtensionArgs(args, is_openid1) return self - fromOpenIDRequest = classmethod(fromOpenIDRequest) - def parseExtensionArgs(self, args, is_openid1, strict=False): """Set the state of this request to be that expressed in these PAPE arguments @@ -325,16 +324,14 @@ def getAuthLevel(self, level_uri): """ return self.auth_levels[level_uri] - def _getNISTAuthLevel(self): + @property + def nist_auth_level(self): + """Backward-compatibility accessor for the NIST auth level.""" try: return int(self.getAuthLevel(LEVELS_NIST)) except KeyError: return None - nist_auth_level = property( - _getNISTAuthLevel, - doc="Backward-compatibility accessor for the NIST auth level") - def addPolicyURI(self, policy_uri): """Add a authentication policy to this response @@ -352,6 +349,7 @@ def addPolicyURI(self, policy_uri): if policy_uri not in self.auth_policies: self.auth_policies.append(policy_uri) + @classmethod def fromSuccessResponse(cls, success_response): """Create a C{L{Response}} object from a successful OpenID library response @@ -447,8 +445,6 @@ def parseExtensionArgs(self, args, is_openid1, strict=False): elif strict: raise ValueError("auth_time must be in RFC3339 format") - fromSuccessResponse = classmethod(fromSuccessResponse) - def getExtensionArgs(self): """@see: C{L{Extension.getExtensionArgs}} """ diff --git a/openid/extensions/sreg.py b/openid/extensions/sreg.py index 786aeeaf..4bdb262e 100644 --- a/openid/extensions/sreg.py +++ b/openid/extensions/sreg.py @@ -203,6 +203,7 @@ def __init__(self, required=None, optional=None, policy_url=None, # overridden for testing. _getSRegNS = staticmethod(getSRegNS) + @classmethod def fromOpenIDRequest(cls, request): """Create a simple registration request that contains the fields that were requested in the OpenID request with the @@ -226,8 +227,6 @@ def fromOpenIDRequest(cls, request): return self - fromOpenIDRequest = classmethod(fromOpenIDRequest) - def parseExtensionArgs(self, args, strict=False): """Parse the unqualified simple registration request parameters and add them to this object. @@ -404,6 +403,7 @@ def __init__(self, data=None, sreg_ns_uri=ns_uri): self.ns_uri = sreg_ns_uri + @classmethod def extractResponse(cls, request, data): """Take a C{L{SRegRequest}} and a dictionary of simple registration values and create a C{L{SRegResponse}} @@ -430,12 +430,11 @@ def extractResponse(cls, request, data): self.data[field] = value return self - extractResponse = classmethod(extractResponse) - # Assign getSRegArgs to a static method so that it can be # overridden for testing _getSRegNS = staticmethod(getSRegNS) + @classmethod def fromSuccessResponse(cls, success_response, signed_only=True): """Create a C{L{SRegResponse}} object from a successful OpenID library response @@ -469,8 +468,6 @@ def fromSuccessResponse(cls, success_response, signed_only=True): return self - fromSuccessResponse = classmethod(fromSuccessResponse) - def getExtensionArgs(self): """Get the fields to put in the simple registration namespace when adding them to an id_res message. diff --git a/openid/message.py b/openid/message.py index 9c487d60..c7d53230 100644 --- a/openid/message.py +++ b/openid/message.py @@ -150,6 +150,7 @@ def __init__(self, openid_namespace=None): implicit = openid_namespace in OPENID1_NAMESPACES self.setOpenIDNamespace(openid_namespace, implicit) + @classmethod def fromPostArgs(cls, args): """Construct a Message containing a set of POST arguments. @@ -177,8 +178,7 @@ def fromPostArgs(cls, args): return self - fromPostArgs = classmethod(fromPostArgs) - + @classmethod def fromOpenIDArgs(cls, openid_args): """Construct a Message from a parsed KVForm message. @@ -189,8 +189,6 @@ def fromOpenIDArgs(cls, openid_args): self._fromOpenIDArgs(openid_args) return self - fromOpenIDArgs = classmethod(fromOpenIDArgs) - def _fromOpenIDArgs(self, openid_args): ns_args = [] @@ -260,12 +258,11 @@ def isOpenID1(self): def isOpenID2(self): return self.getOpenIDNamespace() == OPENID2_NS + @classmethod def fromKVForm(cls, kvform_string): """Create a Message from a KVForm string""" return cls.fromOpenIDArgs(kvform.kvToDict(kvform_string)) - fromKVForm = classmethod(fromKVForm) - def copy(self): return copy.deepcopy(self) diff --git a/openid/server/server.py b/openid/server/server.py index 436b8add..8d45bc8d 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -191,6 +191,7 @@ def __init__(self, assoc_handle, signed, invalidate_handle=None): self.invalidate_handle = invalidate_handle self.namespace = OPENID2_NS + @classmethod def fromMessage(klass, message, op_endpoint=UNUSED): """Construct me from an OpenID Message. @@ -223,8 +224,6 @@ def fromMessage(klass, message, op_endpoint=UNUSED): return self - fromMessage = classmethod(fromMessage) - def answer(self, signatory): """Respond to this request. @@ -280,11 +279,10 @@ class PlainTextServerSession(object): session_type = 'no-encryption' allowed_assoc_types = ['HMAC-SHA1', 'HMAC-SHA256'] + @classmethod def fromMessage(cls, unused_request): return cls() - fromMessage = classmethod(fromMessage) - def answer(self, secret): return {'mac_key': oidutil.toBase64(secret)} @@ -316,6 +314,7 @@ def __init__(self, dh, consumer_pubkey): self.dh = dh self.consumer_pubkey = consumer_pubkey + @classmethod def fromMessage(cls, message): """ @param message: The associate request message @@ -357,8 +356,6 @@ def fromMessage(cls, message): return cls(dh, consumer_pubkey) - fromMessage = classmethod(fromMessage) - def answer(self, secret): mac_key = self.dh.xorSecret(self.consumer_pubkey, secret, @@ -411,6 +408,7 @@ def __init__(self, session, assoc_type): self.assoc_type = assoc_type self.namespace = OPENID2_NS + @classmethod def fromMessage(klass, message, op_endpoint=UNUSED): """Construct me from an OpenID Message. @@ -462,8 +460,6 @@ def fromMessage(klass, message, op_endpoint=UNUSED): self.namespace = message.getOpenIDNamespace() return self - fromMessage = classmethod(fromMessage) - def answer(self, assoc): """Respond to this request with an X{association}. @@ -578,14 +574,14 @@ def __init__(self, identity, return_to, trust_root=None, immediate=False, raise UntrustedReturnURL(None, self.return_to, self.trust_root) self.message = None - def _getNamespace(self): + @property + def namespace(self): warnings.warn('The "namespace" attribute of CheckIDRequest objects ' 'is deprecated. Use "message.getOpenIDNamespace()" ' 'instead', DeprecationWarning, stacklevel=2) return self.message.getOpenIDNamespace() - namespace = property(_getNamespace) - + @classmethod def fromMessage(klass, message, op_endpoint): """Construct me from an OpenID message. @@ -678,8 +674,6 @@ def fromMessage(klass, message, op_endpoint): return self - fromMessage = classmethod(fromMessage) - def idSelect(self): """Is the identifier to be selected by the IDP? diff --git a/openid/server/trustroot.py b/openid/server/trustroot.py index ec771b9b..a71ed718 100644 --- a/openid/server/trustroot.py +++ b/openid/server/trustroot.py @@ -246,6 +246,7 @@ def validateURL(self, url): return True + @classmethod def parse(cls, trust_root): """ This method creates a C{L{TrustRoot}} instance from the given @@ -298,8 +299,7 @@ def parse(cls, trust_root): return tr - parse = classmethod(parse) - + @classmethod def checkSanity(cls, trust_root_string): """str -> bool @@ -311,16 +311,13 @@ def checkSanity(cls, trust_root_string): else: return trust_root.isSane() - checkSanity = classmethod(checkSanity) - + @classmethod def checkURL(cls, trust_root, url): """quick func for validating a url against a trust root. See the TrustRoot class if you need more control.""" tr = cls.parse(trust_root) return tr is not None and tr.validateURL(url) - checkURL = classmethod(checkURL) - def buildDiscoveryURL(self): """Return a discovery URL for this realm. diff --git a/openid/yadis/filters.py b/openid/yadis/filters.py index 1a9d3e74..0d87ad0e 100644 --- a/openid/yadis/filters.py +++ b/openid/yadis/filters.py @@ -48,6 +48,7 @@ def matchTypes(self, type_uris): """ return [uri for uri in type_uris if uri in self.type_uris] + @staticmethod def fromBasicServiceEndpoint(endpoint): """Trivial transform from a basic endpoint to itself. This method exists to allow BasicServiceEndpoint to be used as a @@ -60,8 +61,6 @@ def fromBasicServiceEndpoint(endpoint): """ return endpoint - fromBasicServiceEndpoint = staticmethod(fromBasicServiceEndpoint) - class IFilter(object): """Interface for Yadis filter objects. Other filter-like things From 9b848d41ded88e1e02f502ef45d7e2b936a81a6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 31 Jan 2018 18:19:30 +0100 Subject: [PATCH 043/151] Split imports in test_message --- openid/test/test_message.py | 299 +++++++++++++++++------------------- 1 file changed, 143 insertions(+), 156 deletions(-) diff --git a/openid/test/test_message.py b/openid/test/test_message.py index d9b91c86..1ccc090f 100644 --- a/openid/test/test_message.py +++ b/openid/test/test_message.py @@ -3,8 +3,11 @@ import urllib from urlparse import parse_qs -from openid import message, oidutil +from openid import oidutil from openid.extensions import sreg +from openid.message import (BARE_NS, NULL_NAMESPACE, OPENID1_NS, OPENID2_NS, OPENID_NS, OPENID_PROTOCOL_FIELDS, + THE_OTHER_OPENID1_NS, InvalidNamespace, InvalidOpenIDNamespace, Message, NamespaceMap, + UndefinedOpenIDNamespace, no_default) def mkGetArgTest(ns, key, expected=None): @@ -13,17 +16,17 @@ def test(self): self.assertEqual(self.msg.getArg(ns, key), expected) if expected is None: self.assertEqual(self.msg.getArg(ns, key, a_default), a_default) - self.assertRaises(KeyError, self.msg.getArg, ns, key, message.no_default) + self.assertRaises(KeyError, self.msg.getArg, ns, key, no_default) else: self.assertEqual(self.msg.getArg(ns, key, a_default), expected) - self.assertEqual(self.msg.getArg(ns, key, message.no_default), expected) + self.assertEqual(self.msg.getArg(ns, key, no_default), expected) return test class EmptyMessageTest(unittest.TestCase): def setUp(self): - self.msg = message.Message() + self.msg = Message() def test_toPostArgs(self): self.assertEqual(self.msg.toPostArgs(), {}) @@ -49,16 +52,16 @@ def test_getKeyOpenID(self): # exception. I'm not sure which one is more right, since this # case should only happen when you're building a message from # scratch and so have no default namespace. - self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.getKey, message.OPENID_NS, 'foo') + self.assertRaises(UndefinedOpenIDNamespace, self.msg.getKey, OPENID_NS, 'foo') def test_getKeyBARE(self): - self.assertEqual(self.msg.getKey(message.BARE_NS, 'foo'), 'foo') + self.assertEqual(self.msg.getKey(BARE_NS, 'foo'), 'foo') def test_getKeyNS1(self): - self.assertIsNone(self.msg.getKey(message.OPENID1_NS, 'foo')) + self.assertIsNone(self.msg.getKey(OPENID1_NS, 'foo')) def test_getKeyNS2(self): - self.assertIsNone(self.msg.getKey(message.OPENID2_NS, 'foo')) + self.assertIsNone(self.msg.getKey(OPENID2_NS, 'foo')) def test_getKeyNS3(self): self.assertIsNone(self.msg.getKey('urn:nothing-significant', 'foo')) @@ -68,40 +71,39 @@ def test_hasKey(self): # exception. I'm not sure which one is more right, since this # case should only happen when you're building a message from # scratch and so have no default namespace. - self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.hasKey, message.OPENID_NS, 'foo') + self.assertRaises(UndefinedOpenIDNamespace, self.msg.hasKey, OPENID_NS, 'foo') def test_hasKeyBARE(self): - self.assertFalse(self.msg.hasKey(message.BARE_NS, 'foo')) + self.assertFalse(self.msg.hasKey(BARE_NS, 'foo')) def test_hasKeyNS1(self): - self.assertFalse(self.msg.hasKey(message.OPENID1_NS, 'foo')) + self.assertFalse(self.msg.hasKey(OPENID1_NS, 'foo')) def test_hasKeyNS2(self): - self.assertFalse(self.msg.hasKey(message.OPENID2_NS, 'foo')) + self.assertFalse(self.msg.hasKey(OPENID2_NS, 'foo')) def test_hasKeyNS3(self): self.assertFalse(self.msg.hasKey('urn:nothing-significant', 'foo')) def test_getAliasedArgSuccess(self): - msg = message.Message.fromPostArgs({'openid.ns.test': 'urn://foo', - 'openid.test.flub': 'bogus'}) - actual_uri = msg.getAliasedArg('ns.test', message.no_default) + msg = Message.fromPostArgs({'openid.ns.test': 'urn://foo', 'openid.test.flub': 'bogus'}) + actual_uri = msg.getAliasedArg('ns.test', no_default) self.assertEquals("urn://foo", actual_uri) def test_getAliasedArgFailure(self): - msg = message.Message.fromPostArgs({'openid.test.flub': 'bogus'}) - self.assertRaises(KeyError, msg.getAliasedArg, 'ns.test', message.no_default) + msg = Message.fromPostArgs({'openid.test.flub': 'bogus'}) + self.assertRaises(KeyError, msg.getAliasedArg, 'ns.test', no_default) def test_getArg(self): # Could reasonably return None instead of raising an # exception. I'm not sure which one is more right, since this # case should only happen when you're building a message from # scratch and so have no default namespace. - self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.getArg, message.OPENID_NS, 'foo') + self.assertRaises(UndefinedOpenIDNamespace, self.msg.getArg, OPENID_NS, 'foo') - test_getArgBARE = mkGetArgTest(message.BARE_NS, 'foo') - test_getArgNS1 = mkGetArgTest(message.OPENID1_NS, 'foo') - test_getArgNS2 = mkGetArgTest(message.OPENID2_NS, 'foo') + test_getArgBARE = mkGetArgTest(BARE_NS, 'foo') + test_getArgNS1 = mkGetArgTest(OPENID1_NS, 'foo') + test_getArgNS2 = mkGetArgTest(OPENID2_NS, 'foo') test_getArgNS3 = mkGetArgTest('urn:nothing-significant', 'foo') def test_getArgs(self): @@ -109,23 +111,22 @@ def test_getArgs(self): # exception. I'm not sure which one is more right, since this # case should only happen when you're building a message from # scratch and so have no default namespace. - self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.getArgs, message.OPENID_NS) + self.assertRaises(UndefinedOpenIDNamespace, self.msg.getArgs, OPENID_NS) def test_getArgsBARE(self): - self.assertEqual(self.msg.getArgs(message.BARE_NS), {}) + self.assertEqual(self.msg.getArgs(BARE_NS), {}) def test_getArgsNS1(self): - self.assertEqual(self.msg.getArgs(message.OPENID1_NS), {}) + self.assertEqual(self.msg.getArgs(OPENID1_NS), {}) def test_getArgsNS2(self): - self.assertEqual(self.msg.getArgs(message.OPENID2_NS), {}) + self.assertEqual(self.msg.getArgs(OPENID2_NS), {}) def test_getArgsNS3(self): self.assertEqual(self.msg.getArgs('urn:nothing-significant'), {}) def test_updateArgs(self): - self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.updateArgs, message.OPENID_NS, - {'does not': 'matter'}) + self.assertRaises(UndefinedOpenIDNamespace, self.msg.updateArgs, OPENID_NS, {'does not': 'matter'}) def _test_updateArgsNS(self, ns): update_args = { @@ -138,19 +139,19 @@ def _test_updateArgsNS(self, ns): self.assertEqual(self.msg.getArgs(ns), update_args) def test_updateArgsBARE(self): - self._test_updateArgsNS(message.BARE_NS) + self._test_updateArgsNS(BARE_NS) def test_updateArgsNS1(self): - self._test_updateArgsNS(message.OPENID1_NS) + self._test_updateArgsNS(OPENID1_NS) def test_updateArgsNS2(self): - self._test_updateArgsNS(message.OPENID2_NS) + self._test_updateArgsNS(OPENID2_NS) def test_updateArgsNS3(self): self._test_updateArgsNS('urn:nothing-significant') def test_setArg(self): - self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.setArg, message.OPENID_NS, 'does not', 'matter') + self.assertRaises(UndefinedOpenIDNamespace, self.msg.setArg, OPENID_NS, 'does not', 'matter') def _test_setArgNS(self, ns): key = 'Camper van Beethoven' @@ -160,19 +161,19 @@ def _test_setArgNS(self, ns): self.assertEqual(self.msg.getArg(ns, key), value) def test_setArgBARE(self): - self._test_setArgNS(message.BARE_NS) + self._test_setArgNS(BARE_NS) def test_setArgNS1(self): - self._test_setArgNS(message.OPENID1_NS) + self._test_setArgNS(OPENID1_NS) def test_setArgNS2(self): - self._test_setArgNS(message.OPENID2_NS) + self._test_setArgNS(OPENID2_NS) def test_setArgNS3(self): self._test_setArgNS('urn:nothing-significant') def test_setArgToNone(self): - self.assertRaises(AssertionError, self.msg.setArg, message.OPENID1_NS, 'op_endpoint', None) + self.assertRaises(AssertionError, self.msg.setArg, OPENID1_NS, 'op_endpoint', None) def test_delArg(self): # Could reasonably raise KeyError instead of raising @@ -180,20 +181,20 @@ def test_delArg(self): # right, since this case should only happen when you're # building a message from scratch and so have no default # namespace. - self.assertRaises(message.UndefinedOpenIDNamespace, self.msg.delArg, message.OPENID_NS, 'key') + self.assertRaises(UndefinedOpenIDNamespace, self.msg.delArg, OPENID_NS, 'key') def _test_delArgNS(self, ns): key = 'Camper van Beethoven' self.assertRaises(KeyError, self.msg.delArg, ns, key) def test_delArgBARE(self): - self._test_delArgNS(message.BARE_NS) + self._test_delArgNS(BARE_NS) def test_delArgNS1(self): - self._test_delArgNS(message.OPENID1_NS) + self._test_delArgNS(OPENID1_NS) def test_delArgNS2(self): - self._test_delArgNS(message.OPENID2_NS) + self._test_delArgNS(OPENID2_NS) def test_delArgNS3(self): self._test_delArgNS('urn:nothing-significant') @@ -207,8 +208,7 @@ def test_isOpenID2(self): class OpenID1MessageTest(unittest.TestCase): def setUp(self): - self.msg = message.Message.fromPostArgs({'openid.mode': 'error', - 'openid.error': 'unit test'}) + self.msg = Message.fromPostArgs({'openid.mode': 'error', 'openid.error': 'unit test'}) def test_toPostArgs(self): self.assertEqual(self.msg.toPostArgs(), {'openid.mode': 'error', 'openid.error': 'unit test'}) @@ -233,55 +233,55 @@ def test_toURL(self): self.assertEqual(parsed, {'openid.mode': ['error'], 'openid.error': ['unit test']}) def test_getOpenID(self): - self.assertEqual(self.msg.getOpenIDNamespace(), message.OPENID1_NS) + self.assertEqual(self.msg.getOpenIDNamespace(), OPENID1_NS) def test_getKeyOpenID(self): - self.assertEqual(self.msg.getKey(message.OPENID_NS, 'mode'), 'openid.mode') + self.assertEqual(self.msg.getKey(OPENID_NS, 'mode'), 'openid.mode') def test_getKeyBARE(self): - self.assertEqual(self.msg.getKey(message.BARE_NS, 'mode'), 'mode') + self.assertEqual(self.msg.getKey(BARE_NS, 'mode'), 'mode') def test_getKeyNS1(self): - self.assertEqual(self.msg.getKey(message.OPENID1_NS, 'mode'), 'openid.mode') + self.assertEqual(self.msg.getKey(OPENID1_NS, 'mode'), 'openid.mode') def test_getKeyNS2(self): - self.assertIsNone(self.msg.getKey(message.OPENID2_NS, 'mode')) + self.assertIsNone(self.msg.getKey(OPENID2_NS, 'mode')) def test_getKeyNS3(self): self.assertIsNone(self.msg.getKey('urn:nothing-significant', 'mode')) def test_hasKey(self): - self.assertTrue(self.msg.hasKey(message.OPENID_NS, 'mode')) + self.assertTrue(self.msg.hasKey(OPENID_NS, 'mode')) def test_hasKeyBARE(self): - self.assertFalse(self.msg.hasKey(message.BARE_NS, 'mode')) + self.assertFalse(self.msg.hasKey(BARE_NS, 'mode')) def test_hasKeyNS1(self): - self.assertTrue(self.msg.hasKey(message.OPENID1_NS, 'mode')) + self.assertTrue(self.msg.hasKey(OPENID1_NS, 'mode')) def test_hasKeyNS2(self): - self.assertFalse(self.msg.hasKey(message.OPENID2_NS, 'mode')) + self.assertFalse(self.msg.hasKey(OPENID2_NS, 'mode')) def test_hasKeyNS3(self): self.assertFalse(self.msg.hasKey('urn:nothing-significant', 'mode')) - test_getArgBARE = mkGetArgTest(message.BARE_NS, 'mode') - test_getArgNS = mkGetArgTest(message.OPENID_NS, 'mode', 'error') - test_getArgNS1 = mkGetArgTest(message.OPENID1_NS, 'mode', 'error') - test_getArgNS2 = mkGetArgTest(message.OPENID2_NS, 'mode') + test_getArgBARE = mkGetArgTest(BARE_NS, 'mode') + test_getArgNS = mkGetArgTest(OPENID_NS, 'mode', 'error') + test_getArgNS1 = mkGetArgTest(OPENID1_NS, 'mode', 'error') + test_getArgNS2 = mkGetArgTest(OPENID2_NS, 'mode') test_getArgNS3 = mkGetArgTest('urn:nothing-significant', 'mode') def test_getArgs(self): - self.assertEqual(self.msg.getArgs(message.OPENID_NS), {'mode': 'error', 'error': 'unit test'}) + self.assertEqual(self.msg.getArgs(OPENID_NS), {'mode': 'error', 'error': 'unit test'}) def test_getArgsBARE(self): - self.assertEqual(self.msg.getArgs(message.BARE_NS), {}) + self.assertEqual(self.msg.getArgs(BARE_NS), {}) def test_getArgsNS1(self): - self.assertEqual(self.msg.getArgs(message.OPENID1_NS), {'mode': 'error', 'error': 'unit test'}) + self.assertEqual(self.msg.getArgs(OPENID1_NS), {'mode': 'error', 'error': 'unit test'}) def test_getArgsNS2(self): - self.assertEqual(self.msg.getArgs(message.OPENID2_NS), {}) + self.assertEqual(self.msg.getArgs(OPENID2_NS), {}) def test_getArgsNS3(self): self.assertEqual(self.msg.getArgs('urn:nothing-significant'), {}) @@ -301,18 +301,16 @@ def _test_updateArgsNS(self, ns, before=None): self.assertEqual(self.msg.getArgs(ns), after) def test_updateArgs(self): - self._test_updateArgsNS(message.OPENID_NS, - before={'mode': 'error', 'error': 'unit test'}) + self._test_updateArgsNS(OPENID_NS, before={'mode': 'error', 'error': 'unit test'}) def test_updateArgsBARE(self): - self._test_updateArgsNS(message.BARE_NS) + self._test_updateArgsNS(BARE_NS) def test_updateArgsNS1(self): - self._test_updateArgsNS(message.OPENID1_NS, - before={'mode': 'error', 'error': 'unit test'}) + self._test_updateArgsNS(OPENID1_NS, before={'mode': 'error', 'error': 'unit test'}) def test_updateArgsNS2(self): - self._test_updateArgsNS(message.OPENID2_NS) + self._test_updateArgsNS(OPENID2_NS) def test_updateArgsNS3(self): self._test_updateArgsNS('urn:nothing-significant') @@ -325,16 +323,16 @@ def _test_setArgNS(self, ns): self.assertEqual(self.msg.getArg(ns, key), value) def test_setArg(self): - self._test_setArgNS(message.OPENID_NS) + self._test_setArgNS(OPENID_NS) def test_setArgBARE(self): - self._test_setArgNS(message.BARE_NS) + self._test_setArgNS(BARE_NS) def test_setArgNS1(self): - self._test_setArgNS(message.OPENID1_NS) + self._test_setArgNS(OPENID1_NS) def test_setArgNS2(self): - self._test_setArgNS(message.OPENID2_NS) + self._test_setArgNS(OPENID2_NS) def test_setArgNS3(self): self._test_setArgNS('urn:nothing-significant') @@ -350,16 +348,16 @@ def _test_delArgNS(self, ns): self.assertIsNone(self.msg.getArg(ns, key)) def test_delArg(self): - self._test_delArgNS(message.OPENID_NS) + self._test_delArgNS(OPENID_NS) def test_delArgBARE(self): - self._test_delArgNS(message.BARE_NS) + self._test_delArgNS(BARE_NS) def test_delArgNS1(self): - self._test_delArgNS(message.OPENID1_NS) + self._test_delArgNS(OPENID1_NS) def test_delArgNS2(self): - self._test_delArgNS(message.OPENID2_NS) + self._test_delArgNS(OPENID2_NS) def test_delArgNS3(self): self._test_delArgNS('urn:nothing-significant') @@ -373,20 +371,17 @@ def test_isOpenID2(self): class OpenID1ExplicitMessageTest(unittest.TestCase): def setUp(self): - self.msg = message.Message.fromPostArgs({'openid.mode': 'error', - 'openid.error': 'unit test', - 'openid.ns': message.OPENID1_NS - }) + self.msg = Message.fromPostArgs({'openid.mode': 'error', 'openid.error': 'unit test', 'openid.ns': OPENID1_NS}) def test_toPostArgs(self): self.assertEqual(self.msg.toPostArgs(), - {'openid.mode': 'error', 'openid.error': 'unit test', 'openid.ns': message.OPENID1_NS}) + {'openid.mode': 'error', 'openid.error': 'unit test', 'openid.ns': OPENID1_NS}) def test_toArgs(self): - self.assertEqual(self.msg.toArgs(), {'mode': 'error', 'error': 'unit test', 'ns': message.OPENID1_NS}) + self.assertEqual(self.msg.toArgs(), {'mode': 'error', 'error': 'unit test', 'ns': OPENID1_NS}) def test_toKVForm(self): - self.assertEqual(self.msg.toKVForm(), 'error:unit test\nmode:error\nns:%s\n' % message.OPENID1_NS) + self.assertEqual(self.msg.toKVForm(), 'error:unit test\nmode:error\nns:%s\n' % OPENID1_NS) def test_toURLEncoded(self): self.assertEqual(self.msg.toURLEncoded(), @@ -401,7 +396,7 @@ def test_toURL(self): query = actual[len(base_url) + 1:] parsed = parse_qs(query) self.assertEqual(parsed, - {'openid.mode': ['error'], 'openid.error': ['unit test'], 'openid.ns': [message.OPENID1_NS]}) + {'openid.mode': ['error'], 'openid.error': ['unit test'], 'openid.ns': [OPENID1_NS]}) def test_isOpenID1(self): self.assertTrue(self.msg.isOpenID1()) @@ -409,39 +404,34 @@ def test_isOpenID1(self): class OpenID2MessageTest(unittest.TestCase): def setUp(self): - self.msg = message.Message.fromPostArgs({'openid.mode': 'error', - 'openid.error': 'unit test', - 'openid.ns': message.OPENID2_NS}) - self.msg.setArg(message.BARE_NS, "xey", "value") + self.msg = Message.fromPostArgs({'openid.mode': 'error', 'openid.error': 'unit test', 'openid.ns': OPENID2_NS}) + self.msg.setArg(BARE_NS, "xey", "value") def test_toPostArgs(self): self.assertEqual( self.msg.toPostArgs(), - {'openid.mode': 'error', 'openid.error': 'unit test', 'openid.ns': message.OPENID2_NS, 'xey': 'value'}) + {'openid.mode': 'error', 'openid.error': 'unit test', 'openid.ns': OPENID2_NS, 'xey': 'value'}) def test_toPostArgs_bug_with_utf8_encoded_values(self): - msg = message.Message.fromPostArgs({'openid.mode': 'error', - 'openid.error': 'unit test', - 'openid.ns': message.OPENID2_NS - }) - msg.setArg(message.BARE_NS, 'ünicöde_key', 'ünicöde_välüe') - post_args = {'openid.mode': 'error', 'openid.error': 'unit test', 'openid.ns': message.OPENID2_NS, + msg = Message.fromPostArgs({'openid.mode': 'error', 'openid.error': 'unit test', 'openid.ns': OPENID2_NS}) + msg.setArg(BARE_NS, 'ünicöde_key', 'ünicöde_välüe') + post_args = {'openid.mode': 'error', 'openid.error': 'unit test', 'openid.ns': OPENID2_NS, 'ünicöde_key': 'ünicöde_välüe'} self.assertEqual(msg.toPostArgs(), post_args) def test_toArgs(self): # This method can't tolerate BARE_NS. - self.msg.delArg(message.BARE_NS, "xey") - self.assertEqual(self.msg.toArgs(), {'mode': 'error', 'error': 'unit test', 'ns': message.OPENID2_NS}) + self.msg.delArg(BARE_NS, "xey") + self.assertEqual(self.msg.toArgs(), {'mode': 'error', 'error': 'unit test', 'ns': OPENID2_NS}) def test_toKVForm(self): # Can't tolerate BARE_NS in kvform - self.msg.delArg(message.BARE_NS, "xey") - self.assertEqual(self.msg.toKVForm(), 'error:unit test\nmode:error\nns:%s\n' % message.OPENID2_NS) + self.msg.delArg(BARE_NS, "xey") + self.assertEqual(self.msg.toKVForm(), 'error:unit test\nmode:error\nns:%s\n' % OPENID2_NS) def _test_urlencoded(self, s): expected = ('openid.error=unit+test&openid.mode=error&openid.ns=%s&xey=value' % - urllib.quote(message.OPENID2_NS, '')) + urllib.quote(OPENID2_NS, '')) self.assertEqual(s, expected) def test_toURLEncoded(self): @@ -457,55 +447,55 @@ def test_toURL(self): self._test_urlencoded(query) def test_getOpenID(self): - self.assertEqual(self.msg.getOpenIDNamespace(), message.OPENID2_NS) + self.assertEqual(self.msg.getOpenIDNamespace(), OPENID2_NS) def test_getKeyOpenID(self): - self.assertEqual(self.msg.getKey(message.OPENID_NS, 'mode'), 'openid.mode') + self.assertEqual(self.msg.getKey(OPENID_NS, 'mode'), 'openid.mode') def test_getKeyBARE(self): - self.assertEqual(self.msg.getKey(message.BARE_NS, 'mode'), 'mode') + self.assertEqual(self.msg.getKey(BARE_NS, 'mode'), 'mode') def test_getKeyNS1(self): - self.assertIsNone(self.msg.getKey(message.OPENID1_NS, 'mode')) + self.assertIsNone(self.msg.getKey(OPENID1_NS, 'mode')) def test_getKeyNS2(self): - self.assertEqual(self.msg.getKey(message.OPENID2_NS, 'mode'), 'openid.mode') + self.assertEqual(self.msg.getKey(OPENID2_NS, 'mode'), 'openid.mode') def test_getKeyNS3(self): self.assertIsNone(self.msg.getKey('urn:nothing-significant', 'mode')) def test_hasKeyOpenID(self): - self.assertTrue(self.msg.hasKey(message.OPENID_NS, 'mode')) + self.assertTrue(self.msg.hasKey(OPENID_NS, 'mode')) def test_hasKeyBARE(self): - self.assertFalse(self.msg.hasKey(message.BARE_NS, 'mode')) + self.assertFalse(self.msg.hasKey(BARE_NS, 'mode')) def test_hasKeyNS1(self): - self.assertFalse(self.msg.hasKey(message.OPENID1_NS, 'mode')) + self.assertFalse(self.msg.hasKey(OPENID1_NS, 'mode')) def test_hasKeyNS2(self): - self.assertTrue(self.msg.hasKey(message.OPENID2_NS, 'mode')) + self.assertTrue(self.msg.hasKey(OPENID2_NS, 'mode')) def test_hasKeyNS3(self): self.assertFalse(self.msg.hasKey('urn:nothing-significant', 'mode')) - test_getArgBARE = mkGetArgTest(message.BARE_NS, 'mode') - test_getArgNS = mkGetArgTest(message.OPENID_NS, 'mode', 'error') - test_getArgNS1 = mkGetArgTest(message.OPENID1_NS, 'mode') - test_getArgNS2 = mkGetArgTest(message.OPENID2_NS, 'mode', 'error') + test_getArgBARE = mkGetArgTest(BARE_NS, 'mode') + test_getArgNS = mkGetArgTest(OPENID_NS, 'mode', 'error') + test_getArgNS1 = mkGetArgTest(OPENID1_NS, 'mode') + test_getArgNS2 = mkGetArgTest(OPENID2_NS, 'mode', 'error') test_getArgNS3 = mkGetArgTest('urn:nothing-significant', 'mode') def test_getArgsOpenID(self): - self.assertEqual(self.msg.getArgs(message.OPENID_NS), {'mode': 'error', 'error': 'unit test'}) + self.assertEqual(self.msg.getArgs(OPENID_NS), {'mode': 'error', 'error': 'unit test'}) def test_getArgsBARE(self): - self.assertEqual(self.msg.getArgs(message.BARE_NS), {'xey': 'value'}) + self.assertEqual(self.msg.getArgs(BARE_NS), {'xey': 'value'}) def test_getArgsNS1(self): - self.assertEqual(self.msg.getArgs(message.OPENID1_NS), {}) + self.assertEqual(self.msg.getArgs(OPENID1_NS), {}) def test_getArgsNS2(self): - self.assertEqual(self.msg.getArgs(message.OPENID2_NS), {'mode': 'error', 'error': 'unit test'}) + self.assertEqual(self.msg.getArgs(OPENID2_NS), {'mode': 'error', 'error': 'unit test'}) def test_getArgsNS3(self): self.assertEqual(self.msg.getArgs('urn:nothing-significant'), {}) @@ -525,19 +515,16 @@ def _test_updateArgsNS(self, ns, before=None): self.assertEqual(self.msg.getArgs(ns), after) def test_updateArgsOpenID(self): - self._test_updateArgsNS(message.OPENID_NS, - before={'mode': 'error', 'error': 'unit test'}) + self._test_updateArgsNS(OPENID_NS, before={'mode': 'error', 'error': 'unit test'}) def test_updateArgsBARE(self): - self._test_updateArgsNS(message.BARE_NS, - before={'xey': 'value'}) + self._test_updateArgsNS(BARE_NS, before={'xey': 'value'}) def test_updateArgsNS1(self): - self._test_updateArgsNS(message.OPENID1_NS) + self._test_updateArgsNS(OPENID1_NS) def test_updateArgsNS2(self): - self._test_updateArgsNS(message.OPENID2_NS, - before={'mode': 'error', 'error': 'unit test'}) + self._test_updateArgsNS(OPENID2_NS, before={'mode': 'error', 'error': 'unit test'}) def test_updateArgsNS3(self): self._test_updateArgsNS('urn:nothing-significant') @@ -550,16 +537,16 @@ def _test_setArgNS(self, ns): self.assertEqual(self.msg.getArg(ns, key), value) def test_setArgOpenID(self): - self._test_setArgNS(message.OPENID_NS) + self._test_setArgNS(OPENID_NS) def test_setArgBARE(self): - self._test_setArgNS(message.BARE_NS) + self._test_setArgNS(BARE_NS) def test_setArgNS1(self): - self._test_setArgNS(message.OPENID1_NS) + self._test_setArgNS(OPENID1_NS) def test_setArgNS2(self): - self._test_setArgNS(message.OPENID2_NS) + self._test_setArgNS(OPENID2_NS) def test_setArgNS3(self): self._test_setArgNS('urn:nothing-significant') @@ -568,7 +555,7 @@ def test_badAlias(self): """Make sure dotted aliases and OpenID protocol fields are not allowed as namespace aliases.""" - for f in message.OPENID_PROTOCOL_FIELDS + ['dotted.alias']: + for f in OPENID_PROTOCOL_FIELDS + ['dotted.alias']: args = {'openid.ns.%s' % f: 'blah', 'openid.%s.foo' % f: 'test'} @@ -594,7 +581,7 @@ def test_mysterious_missing_namespace_bug(self): 'invalidate_handle': '{{HMAC-SHA1}{1211477241.92242}{H0akXw==}', 'identity': 'http://nerdbank.org/OPAffirmative/AffirmativeIdentityWithSregNoAssoc.aspx', 'sreg.email': 'a@b.com'} - m = message.Message.fromOpenIDArgs(openid_args) + m = Message.fromOpenIDArgs(openid_args) self.assertEqual(m.namespaces.getAlias('http://openid.net/extensions/sreg/1.1'), 'sreg') missing = [] @@ -622,7 +609,7 @@ def test_112B(self): 'openid.sig': 'YJlWH4U6SroB1HoPkmEKx9AyGGg=', 'openid.signed': 'assoc_handle,identity,response_nonce,return_to,claimed_id,op_endpoint,pape.auth_time,' 'ns.pape,pape.nist_auth_level,pape.auth_policies'} - m = message.Message.fromPostArgs(args) + m = Message.fromPostArgs(args) missing = [] for k in args['openid.signed'].split(','): if not ("openid." + k) in m.toPostArgs().keys(): @@ -652,11 +639,11 @@ def test_repetitive_namespaces(self): 'openid.pape.auth_time': '2008-01-28T20:42:36Z', 'openid.pape.nist_auth_level': '0', } - self.assertRaises(message.InvalidNamespace, message.Message.fromPostArgs, args) + self.assertRaises(InvalidNamespace, Message.fromPostArgs, args) def test_implicit_sreg_ns(self): openid_args = {'sreg.email': 'a@b.com'} - m = message.Message.fromOpenIDArgs(openid_args) + m = Message.fromOpenIDArgs(openid_args) self.assertEqual(m.namespaces.getAlias(sreg.ns_uri), 'sreg') self.assertEqual(m.getArg(sreg.ns_uri, 'email'), 'a@b.com') self.assertEqual(m.toArgs(), openid_args) @@ -673,16 +660,16 @@ def _test_delArgNS(self, ns): self.assertIsNone(self.msg.getArg(ns, key)) def test_delArgOpenID(self): - self._test_delArgNS(message.OPENID_NS) + self._test_delArgNS(OPENID_NS) def test_delArgBARE(self): - self._test_delArgNS(message.BARE_NS) + self._test_delArgNS(BARE_NS) def test_delArgNS1(self): - self._test_delArgNS(message.OPENID1_NS) + self._test_delArgNS(OPENID1_NS) def test_delArgNS2(self): - self._test_delArgNS(message.OPENID2_NS) + self._test_delArgNS(OPENID2_NS) def test_delArgNS3(self): self._test_delArgNS('urn:nothing-significant') @@ -711,7 +698,7 @@ def test_isOpenID2(self): class MessageTest(unittest.TestCase): def setUp(self): self.postargs = { - 'openid.ns': message.OPENID2_NS, + 'openid.ns': OPENID2_NS, 'openid.mode': 'checkid_setup', 'openid.identity': 'http://bogus.example.invalid:port/', 'openid.assoc_handle': 'FLUB', @@ -797,7 +784,7 @@ def _checkForm(self, html, message_, action_url, "Expected submit value to be '%s', got '%s'" % (submit_text, submits[0].attrib['value']) def test_toFormMarkup(self): - m = message.Message.fromPostArgs(self.postargs) + m = Message.fromPostArgs(self.postargs) html = m.toFormMarkup(self.action_url, self.form_tag_attrs, self.submit_text) self._checkForm(html, m, self.action_url, @@ -805,14 +792,14 @@ def test_toFormMarkup(self): def test_toFormMarkup_bug_with_utf8_values(self): postargs = { - 'openid.ns': message.OPENID2_NS, + 'openid.ns': OPENID2_NS, 'openid.mode': 'checkid_setup', 'openid.identity': 'http://bogus.example.invalid:port/', 'openid.assoc_handle': 'FLUB', 'openid.return_to': 'Neverland', 'ünicöde_key': 'ünicöde_välüe', } - m = message.Message.fromPostArgs(postargs) + m = Message.fromPostArgs(postargs) # Calling m.toFormMarkup with lxml used for ElementTree will throw # a ValueError. html = m.toFormMarkup(self.action_url, self.form_tag_attrs, @@ -828,7 +815,7 @@ def test_toFormMarkup_bug_with_utf8_values(self): def test_overrideMethod(self): """Be sure that caller cannot change form method to GET.""" - m = message.Message.fromPostArgs(self.postargs) + m = Message.fromPostArgs(self.postargs) tag_attrs = dict(self.form_tag_attrs) tag_attrs['method'] = 'GET' @@ -841,7 +828,7 @@ def test_overrideMethod(self): def test_overrideRequired(self): """Be sure that caller CANNOT change the form charset for encoding type.""" - m = message.Message.fromPostArgs(self.postargs) + m = Message.fromPostArgs(self.postargs) tag_attrs = dict(self.form_tag_attrs) tag_attrs['accept-charset'] = 'UCS4' @@ -853,7 +840,7 @@ def test_overrideRequired(self): tag_attrs, self.submit_text) def test_setOpenIDNamespace_invalid(self): - m = message.Message() + m = Message() invalid_things = [ # Empty string is not okay here. '', @@ -868,7 +855,7 @@ def test_setOpenIDNamespace_invalid(self): ] for x in invalid_things: - self.assertRaises(message.InvalidOpenIDNamespace, m.setOpenIDNamespace, x, False) + self.assertRaises(InvalidOpenIDNamespace, m.setOpenIDNamespace, x, False) def test_isOpenID1(self): v1_namespaces = [ @@ -878,34 +865,34 @@ def test_isOpenID1(self): ] for ns in v1_namespaces: - m = message.Message(ns) + m = Message(ns) self.assertTrue(m.isOpenID1(), "%r not recognized as OpenID 1" % ns) self.assertEqual(m.getOpenIDNamespace(), ns) self.assertTrue(m.namespaces.isImplicit(ns)) def test_isOpenID2(self): ns = 'http://specs.openid.net/auth/2.0' - m = message.Message(ns) + m = Message(ns) self.assertTrue(m.isOpenID2()) - self.assertFalse(m.namespaces.isImplicit(message.NULL_NAMESPACE)) + self.assertFalse(m.namespaces.isImplicit(NULL_NAMESPACE)) self.assertEqual(m.getOpenIDNamespace(), ns) def test_setOpenIDNamespace_explicit(self): - m = message.Message() - m.setOpenIDNamespace(message.THE_OTHER_OPENID1_NS, False) - self.assertFalse(m.namespaces.isImplicit(message.THE_OTHER_OPENID1_NS)) + m = Message() + m.setOpenIDNamespace(THE_OTHER_OPENID1_NS, False) + self.assertFalse(m.namespaces.isImplicit(THE_OTHER_OPENID1_NS)) def test_setOpenIDNamespace_implicit(self): - m = message.Message() - m.setOpenIDNamespace(message.THE_OTHER_OPENID1_NS, True) - self.assertTrue(m.namespaces.isImplicit(message.THE_OTHER_OPENID1_NS)) + m = Message() + m.setOpenIDNamespace(THE_OTHER_OPENID1_NS, True) + self.assertTrue(m.namespaces.isImplicit(THE_OTHER_OPENID1_NS)) def test_explicitOpenID11NSSerialzation(self): - m = message.Message() - m.setOpenIDNamespace(message.THE_OTHER_OPENID1_NS, implicit=False) + m = Message() + m.setOpenIDNamespace(THE_OTHER_OPENID1_NS, implicit=False) post_args = m.toPostArgs() - self.assertEqual(post_args, {'openid.ns': message.THE_OTHER_OPENID1_NS}) + self.assertEqual(post_args, {'openid.ns': THE_OTHER_OPENID1_NS}) def test_fromPostArgs_ns11(self): # An example of the stuff that some Drupal installations send us, @@ -921,13 +908,13 @@ def test_fromPostArgs_ns11(self): u'openid.sreg.required': u'nickname,email', u'openid.trust_root': u'http://drupal.invalid', } - m = message.Message.fromPostArgs(query) + m = Message.fromPostArgs(query) self.assertTrue(m.isOpenID1()) class NamespaceMapTest(unittest.TestCase): def test_onealias(self): - nsm = message.NamespaceMap() + nsm = NamespaceMap() uri = 'http://example.com/foo' alias = "foo" nsm.addAlias(uri, alias) @@ -935,7 +922,7 @@ def test_onealias(self): self.assertEqual(nsm.getAlias(uri), alias) def test_iteration(self): - nsm = message.NamespaceMap() + nsm = NamespaceMap() uripat = 'http://example.com/foo%r' nsm.add(uripat % 0) From 65978bc1d99328a25506bd3334fe2bcc6aacabfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 31 Jan 2018 19:12:23 +0100 Subject: [PATCH 044/151] Refactor Message namespaces * Deprecate 'setOpenIDNamespace' method * Deprecate 'UndefinedOpenIDNamespace' exception * Drop private '_openid_ns_uri' attribute --- .isort.cfg | 2 +- openid/message.py | 33 +++++++++++++++++++++------------ openid/test/test_message.py | 35 ++++++++++++++++++++++++----------- 3 files changed, 46 insertions(+), 24 deletions(-) diff --git a/.isort.cfg b/.isort.cfg index 271c8b6b..3bf03262 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,5 +1,5 @@ [settings] line_length = 120 combine_as_imports = true -known_third_party = mock +default_section = THIRDPARTY known_first_party = openid diff --git a/openid/message.py b/openid/message.py index c7d53230..1a843c92 100644 --- a/openid/message.py +++ b/openid/message.py @@ -61,6 +61,9 @@ class UndefinedOpenIDNamespace(ValueError): """Raised if the generic OpenID namespace is accessed when there is no OpenID namespace set for this message.""" + def __init__(self, *args, **kwargs): + warnings.warn("UndefinedOpenIDNamespace exception is deprecated.", DeprecationWarning) + super(UndefinedOpenIDNamespace, self).__init__(*args, **kwargs) class InvalidOpenIDNamespace(ValueError): @@ -144,11 +147,9 @@ def __init__(self, openid_namespace=None): """ self.args = {} self.namespaces = NamespaceMap() - if openid_namespace is None: - self._openid_ns_uri = None - else: + if openid_namespace is not None: implicit = openid_namespace in OPENID1_NAMESPACES - self.setOpenIDNamespace(openid_namespace, implicit) + self._setOpenIDNamespace(openid_namespace, implicit) @classmethod def fromPostArgs(cls, args): @@ -204,13 +205,13 @@ def _fromOpenIDArgs(self, openid_args): self.namespaces.addAlias(value, ns_key) elif ns_alias == NULL_NAMESPACE and ns_key == 'ns': # null namespace - self.setOpenIDNamespace(value, False) + self._setOpenIDNamespace(value, False) else: ns_args.append((ns_alias, ns_key, value)) # Implicitly set an OpenID namespace definition (OpenID 1) if not self.getOpenIDNamespace(): - self.setOpenIDNamespace(OPENID1_NS, True) + self._setOpenIDNamespace(OPENID1_NS, True) # Actually put the pairs into the appropriate namespaces for (ns_alias, ns_key, value) in ns_args: @@ -237,7 +238,7 @@ def _getDefaultNamespace(self, mystery_alias): else: return None - def setOpenIDNamespace(self, openid_ns_uri, implicit): + def _setOpenIDNamespace(self, openid_ns_uri, implicit): """Set the OpenID namespace URI used in this message. @raises InvalidOpenIDNamespace: if the namespace is not in @@ -247,10 +248,19 @@ def setOpenIDNamespace(self, openid_ns_uri, implicit): raise InvalidOpenIDNamespace(openid_ns_uri) self.namespaces.addAlias(openid_ns_uri, NULL_NAMESPACE, implicit) - self._openid_ns_uri = openid_ns_uri + + def setOpenIDNamespace(self, openid_ns_uri, implicit): + """Set the OpenID namespace URI used in this message. + + @raises InvalidOpenIDNamespace: if the namespace is not in + L{Message.allowed_openid_namespaces} + """ + warnings.warn("Method 'setOpenIDNamespace' is deprecated. Pass namespace to Message constructor instead.", + DeprecationWarning) + self._setOpenIDNamespace(openid_ns_uri, implicit) def getOpenIDNamespace(self): - return self._openid_ns_uri + return self.namespaces.getNamespaceURI(NULL_NAMESPACE) def isOpenID1(self): return self.getOpenIDNamespace() in OPENID1_NAMESPACES @@ -379,10 +389,9 @@ def _fixNS(self, namespace): @type namespace: str or unicode or BARE_NS or OPENID_NS """ if namespace == OPENID_NS: - if self._openid_ns_uri is None: + namespace = self.getOpenIDNamespace() + if namespace is None: raise UndefinedOpenIDNamespace('OpenID namespace not set') - else: - namespace = self._openid_ns_uri if namespace != BARE_NS and type(namespace) not in [str, unicode]: raise TypeError( diff --git a/openid/test/test_message.py b/openid/test/test_message.py index 1ccc090f..32d8ee63 100644 --- a/openid/test/test_message.py +++ b/openid/test/test_message.py @@ -1,8 +1,11 @@ # -*- coding: utf-8 -*- import unittest import urllib +import warnings from urlparse import parse_qs +from testfixtures import ShouldWarn + from openid import oidutil from openid.extensions import sreg from openid.message import (BARE_NS, NULL_NAMESPACE, OPENID1_NS, OPENID2_NS, OPENID_NS, OPENID_PROTOCOL_FIELDS, @@ -839,8 +842,15 @@ def test_overrideRequired(self): self._checkForm(html, m, self.action_url, tag_attrs, self.submit_text) - def test_setOpenIDNamespace_invalid(self): - m = Message() + def test_setOpenIDNamespace_deprecated(self): + message = Message() + warning_msg = "Method 'setOpenIDNamespace' is deprecated. Pass namespace to Message constructor instead." + with ShouldWarn(DeprecationWarning(warning_msg)): + warnings.simplefilter('always') + message.setOpenIDNamespace(OPENID2_NS, False) + self.assertEqual(message.getOpenIDNamespace(), OPENID2_NS) + + def test_openid_namespace_invalid(self): invalid_things = [ # Empty string is not okay here. '', @@ -853,9 +863,15 @@ def test_setOpenIDNamespace_invalid(self): # This is a Type URI, not a openid.ns value. 'http://specs.openid.net/auth/2.0/signon', ] + warning_msg = "Method 'setOpenIDNamespace' is deprecated. Pass namespace to Message constructor instead." for x in invalid_things: - self.assertRaises(InvalidOpenIDNamespace, m.setOpenIDNamespace, x, False) + self.assertRaises(InvalidOpenIDNamespace, Message, x, False) + # Test also deprecated setOpenIDNamespace + message = Message() + with ShouldWarn(DeprecationWarning(warning_msg)): + warnings.simplefilter('always') + self.assertRaises(InvalidOpenIDNamespace, message.setOpenIDNamespace, x, False) def test_isOpenID1(self): v1_namespaces = [ @@ -877,19 +893,16 @@ def test_isOpenID2(self): self.assertFalse(m.namespaces.isImplicit(NULL_NAMESPACE)) self.assertEqual(m.getOpenIDNamespace(), ns) - def test_setOpenIDNamespace_explicit(self): - m = Message() - m.setOpenIDNamespace(THE_OTHER_OPENID1_NS, False) + def test_openid1_namespace_explicit(self): + m = Message(THE_OTHER_OPENID1_NS, False) self.assertFalse(m.namespaces.isImplicit(THE_OTHER_OPENID1_NS)) - def test_setOpenIDNamespace_implicit(self): - m = Message() - m.setOpenIDNamespace(THE_OTHER_OPENID1_NS, True) + def test_openid1_namespace_implicit(self): + m = Message(THE_OTHER_OPENID1_NS, True) self.assertTrue(m.namespaces.isImplicit(THE_OTHER_OPENID1_NS)) def test_explicitOpenID11NSSerialzation(self): - m = Message() - m.setOpenIDNamespace(THE_OTHER_OPENID1_NS, implicit=False) + m = Message(THE_OTHER_OPENID1_NS, False) post_args = m.toPostArgs() self.assertEqual(post_args, {'openid.ns': THE_OTHER_OPENID1_NS}) From fc74480637534e844033d95be39620af2f15446a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 31 Jan 2018 19:35:11 +0100 Subject: [PATCH 045/151] Refactor message creation * Add 'implicit_namespace' argument to Message constructor. * Turn '_fromOpenIDArgs' into a classmethod. --- openid/message.py | 63 ++++++++++++++++++++++++++++------------------- 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/openid/message.py b/openid/message.py index 1a843c92..378ffb56 100644 --- a/openid/message.py +++ b/openid/message.py @@ -139,27 +139,30 @@ class Message(object): allowed_openid_namespaces = [OPENID1_NS, THE_OTHER_OPENID1_NS, OPENID2_NS] - def __init__(self, openid_namespace=None): + def __init__(self, openid_namespace=None, implicit_namespace=None): """Create an empty Message. + @param openid_namespace: The message's OpenID namespace. + @param implicit_namespace: Whether the OpenID namespace is only implicit. + @raises InvalidOpenIDNamespace: if openid_namespace is not in L{Message.allowed_openid_namespaces} """ self.args = {} self.namespaces = NamespaceMap() if openid_namespace is not None: - implicit = openid_namespace in OPENID1_NAMESPACES - self._setOpenIDNamespace(openid_namespace, implicit) + if implicit_namespace is None: + implicit_namespace = openid_namespace in OPENID1_NAMESPACES + self._setOpenIDNamespace(openid_namespace, implicit_namespace) @classmethod def fromPostArgs(cls, args): """Construct a Message containing a set of POST arguments. """ - self = cls() - # Partition into "openid." args and bare args openid_args = {} + bare_args = {} for key, value in args.items(): if isinstance(value, list): raise TypeError("query dict must have one value for each key, " @@ -171,12 +174,14 @@ def fromPostArgs(cls, args): prefix = None if prefix != 'openid': - self.args[(BARE_NS, key)] = value + bare_args[key] = value else: openid_args[rest] = value - self._fromOpenIDArgs(openid_args) + self = cls._fromOpenIDArgs(openid_args) + for key, value in bare_args.items(): + self.args[(BARE_NS, key)] = value return self @classmethod @@ -186,32 +191,39 @@ def fromOpenIDArgs(cls, openid_args): @raises InvalidOpenIDNamespace: if openid.ns is not in L{Message.allowed_openid_namespaces} """ - self = cls() - self._fromOpenIDArgs(openid_args) - return self + return cls._fromOpenIDArgs(openid_args) - def _fromOpenIDArgs(self, openid_args): + @classmethod + def _fromOpenIDArgs(cls, openid_args): + # Resolve OpenID namespaces + openid_namespace = None + openid_implicit = False + # Other arguments + namespaces = {} ns_args = [] - - # Resolve namespaces - for rest, value in openid_args.iteritems(): - try: - ns_alias, ns_key = rest.split('.', 1) - except ValueError: + for key, value in openid_args.iteritems(): + if '.' not in key: ns_alias = NULL_NAMESPACE - ns_key = rest + ns_key = key + else: + ns_alias, ns_key = key.split('.', 1) - if ns_alias == 'ns': - self.namespaces.addAlias(value, ns_key) - elif ns_alias == NULL_NAMESPACE and ns_key == 'ns': - # null namespace - self._setOpenIDNamespace(value, False) + if ns_alias == NULL_NAMESPACE and ns_key == 'ns': + openid_namespace = value + elif ns_alias == 'ns': + namespaces[ns_key] = value else: ns_args.append((ns_alias, ns_key, value)) # Implicitly set an OpenID namespace definition (OpenID 1) - if not self.getOpenIDNamespace(): - self._setOpenIDNamespace(OPENID1_NS, True) + if openid_namespace is None: + openid_namespace = OPENID1_NS + openid_implicit = True + + self = cls(openid_namespace, openid_implicit) + + for alias, uri in namespaces.items(): + self.namespaces.addAlias(uri, alias) # Actually put the pairs into the appropriate namespaces for (ns_alias, ns_key, value) in ns_args: @@ -226,6 +238,7 @@ def _fromOpenIDArgs(self, openid_args): self.namespaces.addAlias(ns_uri, ns_alias, implicit=True) self.setArg(ns_uri, ns_key, value) + return self def _getDefaultNamespace(self, mystery_alias): """OpenID 1 compatibility: look for a default namespace URI to From 9f9050b012a71a9eb5ed4e2996672ff15d07cda1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 2 Feb 2018 18:38:00 +0100 Subject: [PATCH 046/151] Refactor server request message * Move `message` instance variable into base `OpenIDRequest`. * Deprecate `namespace` property for all requests. * Fix up request constructors. --- openid/server/server.py | 177 ++++++++++++++++++------------------- openid/test/test_server.py | 171 +++++++++++++---------------------- openid/test/test_sreg.py | 4 +- 3 files changed, 147 insertions(+), 205 deletions(-) diff --git a/openid/server/server.py b/openid/server/server.py index 8d45bc8d..dfe44444 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -150,9 +150,26 @@ class OpenIDRequest(object): @cvar mode: the C{X{openid.mode}} of this request. @type mode: str + + @ivar message: Original request message. + @type message: Message """ mode = None + def __init__(self, message=None): + if message is not None: + self.message = message + else: + # If no message is defined, create an empty one. + self.message = Message(OPENID2_NS) + + @property + def namespace(self): + """Return request namespace.""" + msg = 'The "namespace" attribute of {} objects is deprecated. Use "message.getOpenIDNamespace()" instead' + warnings.warn(msg.format(type(self).__name__), DeprecationWarning, stacklevel=2) + return self.message.getOpenIDNamespace() + class CheckAuthRequest(OpenIDRequest): """A request to verify the validity of a previous response. @@ -176,7 +193,7 @@ class CheckAuthRequest(OpenIDRequest): required_fields = ["identity", "return_to", "response_nonce"] - def __init__(self, assoc_handle, signed, invalidate_handle=None): + def __init__(self, assoc_handle, signed, invalidate_handle=None, message=None): """Construct me. These parameters are assigned directly as class attributes, see @@ -186,10 +203,10 @@ def __init__(self, assoc_handle, signed, invalidate_handle=None): @type signed: L{Message} @type invalidate_handle: str """ + super(CheckAuthRequest, self).__init__(message=message) self.assoc_handle = assoc_handle self.signed = signed self.invalidate_handle = invalidate_handle - self.namespace = OPENID2_NS @classmethod def fromMessage(klass, message, op_endpoint=UNUSED): @@ -200,28 +217,22 @@ def fromMessage(klass, message, op_endpoint=UNUSED): @returntype: L{CheckAuthRequest} """ - self = klass.__new__(klass) - self.message = message - self.namespace = message.getOpenIDNamespace() - self.assoc_handle = message.getArg(OPENID_NS, 'assoc_handle') - self.sig = message.getArg(OPENID_NS, 'sig') - - if (self.assoc_handle is None or - self.sig is None): + assoc_handle = message.getArg(OPENID_NS, 'assoc_handle') + sig = message.getArg(OPENID_NS, 'sig') + invalidate_handle = message.getArg(OPENID_NS, 'invalidate_handle') + if (assoc_handle is None or sig is None): fmt = "%s request missing required parameter from message %s" - raise ProtocolError( - message, text=fmt % (self.mode, message)) - - self.invalidate_handle = message.getArg(OPENID_NS, 'invalidate_handle') + raise ProtocolError(message, text=fmt % (klass.mode, message)) - self.signed = message.copy() + signed = message.copy() # openid.mode is currently check_authentication because # that's the mode of this request. But the signature # was made on something with a different openid.mode. # http://article.gmane.org/gmane.comp.web.openid.general/537 - if self.signed.hasKey(OPENID_NS, "mode"): - self.signed.setArg(OPENID_NS, "mode", "id_res") + if signed.hasKey(OPENID_NS, "mode"): + signed.setArg(OPENID_NS, "mode", "id_res") + self = klass(assoc_handle, signed, invalidate_handle, message) return self def answer(self, signatory): @@ -257,9 +268,8 @@ def __str__(self): ih = " invalidate? %r" % (self.invalidate_handle,) else: ih = "" - s = "<%s handle: %r sig: %r: signed: %r%s>" % ( - self.__class__.__name__, self.assoc_handle, - self.sig, self.signed, ih) + sig = self.message.getArg(OPENID_NS, 'sig') + s = "<%s handle: %r sig: %r: signed: %r%s>" % (self.__class__.__name__, self.assoc_handle, sig, self.signed, ih) return s @@ -397,16 +407,15 @@ class AssociateRequest(OpenIDRequest): 'DH-SHA256': DiffieHellmanSHA256ServerSession, } - def __init__(self, session, assoc_type): + def __init__(self, session, assoc_type, message=None): """Construct me. The session is assigned directly as a class attribute. See my L{class documentation} for its description. """ - super(AssociateRequest, self).__init__() + super(AssociateRequest, self).__init__(message=message) self.session = session self.assoc_type = assoc_type - self.namespace = OPENID2_NS @classmethod def fromMessage(klass, message, op_endpoint=UNUSED): @@ -455,9 +464,7 @@ def fromMessage(klass, message, op_endpoint=UNUSED): fmt = 'Session type %s does not support association type %s' raise ProtocolError(message, fmt % (session_type, assoc_type)) - self = klass(session, assoc_type) - self.message = message - self.namespace = message.getOpenIDNamespace() + self = klass(session, assoc_type, message=message) return self def answer(self, assoc): @@ -527,7 +534,7 @@ class CheckIDRequest(OpenIDRequest): @ivar claimed_id: The claimed identifier. Not present in OpenID 1.x messages. - @type claimed_id: str + @type claimed_id: str or None @ivar trust_root: "Are you Frank?" asks the checkid request. "Who wants to know?" C{trust_root}, that's who. This URL identifies the party @@ -546,7 +553,7 @@ class CheckIDRequest(OpenIDRequest): """ def __init__(self, identity, return_to, trust_root=None, immediate=False, - assoc_handle=None, op_endpoint=None, claimed_id=None): + assoc_handle=None, op_endpoint=None, claimed_id=None, message=None): """Construct me. These parameters are assigned directly as class attributes, see @@ -554,13 +561,33 @@ def __init__(self, identity, return_to, trust_root=None, immediate=False, @raises MalformedReturnURL: When the C{return_to} URL is not a URL. """ + super(CheckIDRequest, self).__init__(message=message) self.assoc_handle = assoc_handle + + # Check the identifier validity. In case of error, create protocol error from the message in the argument. + if self.message.isOpenID1(): + if identity is None: + s = "OpenID 1 message did not contain openid.identity" + raise ProtocolError(message, text=s) + else: + if identity and not claimed_id: + s = ("OpenID 2.0 message contained openid.identity but not " + "claimed_id") + raise ProtocolError(message, text=s) + elif claimed_id and not identity: + s = ("OpenID 2.0 message contained openid.claimed_id but not " + "identity") + raise ProtocolError(message, text=s) + self.identity = identity - self.claimed_id = claimed_id or identity + self.claimed_id = claimed_id self.return_to = return_to self.trust_root = trust_root or return_to + + if self.message.isOpenID2() and op_endpoint is None: + raise ValueError("CheckIDRequest requires op_endpoint argument for OpenID 2.0 requests.") self.op_endpoint = op_endpoint - assert self.op_endpoint is not None + if immediate: self.immediate = True self.mode = "checkid_immediate" @@ -568,18 +595,22 @@ def __init__(self, identity, return_to, trust_root=None, immediate=False, self.immediate = False self.mode = "checkid_setup" + # Using TrustRoot.parse here is a bit misleading, as we're not + # parsing return_to as a trust root at all. However, valid URLs + # are valid trust roots, so we can use this to get an idea if it + # is a valid URL. Not all trust roots are valid return_to URLs, + # however (particularly ones with wildcards), so this is still a + # little sketchy. if self.return_to is not None and not TrustRoot.parse(self.return_to): - raise MalformedReturnURL(None, self.return_to) - if not self.trustRootValid(): - raise UntrustedReturnURL(None, self.return_to, self.trust_root) - self.message = None + raise MalformedReturnURL(message, self.return_to) - @property - def namespace(self): - warnings.warn('The "namespace" attribute of CheckIDRequest objects ' - 'is deprecated. Use "message.getOpenIDNamespace()" ' - 'instead', DeprecationWarning, stacklevel=2) - return self.message.getOpenIDNamespace() + # I first thought that checking to see if the return_to is within + # the trust_root is premature here, a logic-not-decoding thing. But + # it was argued that this is really part of data validation. A + # request with an invalid trust_root/return_to is broken regardless of + # application, right? + if not self.trustRootValid(): + raise UntrustedReturnURL(message, self.return_to, self.trust_root) @classmethod def fromMessage(klass, message, op_endpoint): @@ -602,38 +633,17 @@ def fromMessage(klass, message, op_endpoint): @returntype: L{CheckIDRequest} """ - self = klass.__new__(klass) - self.message = message - self.op_endpoint = op_endpoint mode = message.getArg(OPENID_NS, 'mode') - if mode == "checkid_immediate": - self.immediate = True - self.mode = "checkid_immediate" - else: - self.immediate = False - self.mode = "checkid_setup" + assert mode in ('checkid_immediate', 'checkid_setup') + immediate = bool(mode == 'checkid_immediate') - self.return_to = message.getArg(OPENID_NS, 'return_to') - if message.isOpenID1() and not self.return_to: + return_to = message.getArg(OPENID_NS, 'return_to') + if message.isOpenID1() and not return_to: fmt = "Missing required field 'return_to' from %r" raise ProtocolError(message, text=fmt % (message,)) - self.identity = message.getArg(OPENID_NS, 'identity') - self.claimed_id = message.getArg(OPENID_NS, 'claimed_id') - if message.isOpenID1(): - if self.identity is None: - s = "OpenID 1 message did not contain openid.identity" - raise ProtocolError(message, text=s) - else: - if self.identity and not self.claimed_id: - s = ("OpenID 2.0 message contained openid.identity but not " - "claimed_id") - raise ProtocolError(message, text=s) - elif self.claimed_id and not self.identity: - s = ("OpenID 2.0 message contained openid.claimed_id but not " - "identity") - raise ProtocolError(message, text=s) - + identity = message.getArg(OPENID_NS, 'identity') + claimed_id = message.getArg(OPENID_NS, 'claimed_id') # There's a case for making self.trust_root be a TrustRoot # here. But if TrustRoot isn't currently part of the "public" API, # I'm not sure it's worth doing. @@ -646,32 +656,15 @@ def fromMessage(klass, message, op_endpoint): # Using 'or' here is slightly different than sending a default # argument to getArg, as it will treat no value and an empty # string as equivalent. - self.trust_root = (message.getArg(OPENID_NS, trust_root_param) or self.return_to) + trust_root = (message.getArg(OPENID_NS, trust_root_param) or return_to) - if not message.isOpenID1(): - if self.return_to is self.trust_root is None: - raise ProtocolError(message, "openid.realm required when " + - "openid.return_to absent") + if not message.isOpenID1() and (return_to is trust_root is None): + raise ProtocolError(message, "openid.realm required when openid.return_to absent") - self.assoc_handle = message.getArg(OPENID_NS, 'assoc_handle') - - # Using TrustRoot.parse here is a bit misleading, as we're not - # parsing return_to as a trust root at all. However, valid URLs - # are valid trust roots, so we can use this to get an idea if it - # is a valid URL. Not all trust roots are valid return_to URLs, - # however (particularly ones with wildcards), so this is still a - # little sketchy. - if self.return_to is not None and not TrustRoot.parse(self.return_to): - raise MalformedReturnURL(message, self.return_to) - - # I first thought that checking to see if the return_to is within - # the trust_root is premature here, a logic-not-decoding thing. But - # it was argued that this is really part of data validation. A - # request with an invalid trust_root/return_to is broken regardless of - # application, right? - if not self.trustRootValid(): - raise UntrustedReturnURL(message, self.return_to, self.trust_root) + assoc_handle = message.getArg(OPENID_NS, 'assoc_handle') + self = klass(identity, return_to, trust_root=trust_root, immediate=immediate, assoc_handle=assoc_handle, + op_endpoint=op_endpoint, claimed_id=claimed_id, message=message) return self def idSelect(self): @@ -773,8 +766,6 @@ def answer(self, allow, server_url=None, identity=None, claimed_id=None): @raises NoReturnError: when I do not have a return_to. """ - assert self.message is not None - if not self.return_to: raise NoReturnToError @@ -974,7 +965,7 @@ def __init__(self, request): @type request: L{OpenIDRequest} """ self.request = request - self.fields = Message(request.namespace) + self.fields = Message(request.message.getOpenIDNamespace()) def __str__(self): return "%s for %s: %s" % ( diff --git a/openid/test/test_server.py b/openid/test/test_server.py index a3296fc9..d4a1f146 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -1,10 +1,12 @@ """Tests for openid.server. """ import unittest +import warnings from functools import partial from urlparse import parse_qs, parse_qsl, urlparse -from testfixtures import LogCapture, StringComparison +from mock import sentinel +from testfixtures import LogCapture, ShouldWarn, StringComparison from openid import association, cryptutil, oidutil from openid.consumer.consumer import DiffieHellmanSHA256ConsumerSession @@ -25,6 +27,39 @@ ALT_GEN = 5 +# Example values to be used in tests +EXAMPLE_IDENTITY = 'http://id.example.cz/' +EXAMPLE_CLAIMED_ID = 'http://claimed.example.cz/' + + +def make_checkid_request(identity=EXAMPLE_IDENTITY, claimed_id=EXAMPLE_CLAIMED_ID, + trust_root='http://realm.example.cz/', return_to='http://realm.example.cz/return_to/', + op_endpoint='http://op.example.cz/', immediate=False, message=None): + """Create a simple CheckIDRequest.""" + message = message or Message(OPENID2_NS) + return server.CheckIDRequest(identity=identity, claimed_id=claimed_id, trust_root=trust_root, return_to=return_to, + op_endpoint=op_endpoint, immediate=immediate, message=message) + + +class TestOpenIDRequest(unittest.TestCase): + """Test OpenID request base class.""" + + def test_init_default_message(self): + # Test empty OpenID 2.0 message is create if not provided. + request = server.OpenIDRequest() + self.assertTrue(request.message) + self.assertEqual(request.message.getOpenIDNamespace(), OPENID2_NS) + + def test_namespace(self): + # Test deprecated namespace property + request = server.OpenIDRequest() + warning_msg = ('The "namespace" attribute of OpenIDRequest objects is deprecated. Use ' + '"message.getOpenIDNamespace()" instead') + with ShouldWarn(DeprecationWarning(warning_msg)): + warnings.simplefilter('always') + self.assertEqual(request.namespace, OPENID2_NS) + + class TestProtocolError(unittest.TestCase): def test_browserWithReturnTo(self): return_to = "http://rp.unittest/consumer" @@ -328,7 +363,7 @@ def test_checkAuth(self): r = self.decode(args) self.assertIsInstance(r, server.CheckAuthRequest) self.assertEqual(r.mode, 'check_authentication') - self.assertEqual(r.sig, 'sigblob') + self.assertEqual(r.message.getArg(OPENID_NS, 'sig'), 'sigblob') def test_checkAuthMissingSignature(self): args = { @@ -488,14 +523,7 @@ def test_id_res_OpenID2_GET(self): OpenID 1 message size, a GET response (i.e., redirect) is issued. """ - request = server.CheckIDRequest( - identity='http://bombom.unittest/', - trust_root='http://burr.unittest/', - return_to='http://burr.unittest/999', - immediate=False, - op_endpoint=self.server.op_endpoint, - ) - request.message = Message(OPENID2_NS) + request = make_checkid_request() response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ 'ns': OPENID2_NS, @@ -516,14 +544,7 @@ def test_id_res_OpenID2_POST(self): message size, a POST response (i.e., an HTML form) is returned. """ - request = server.CheckIDRequest( - identity='http://bombom.unittest/', - trust_root='http://burr.unittest/', - return_to='http://burr.unittest/999', - immediate=False, - op_endpoint=self.server.op_endpoint, - ) - request.message = Message(OPENID2_NS) + request = make_checkid_request() response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ 'ns': OPENID2_NS, @@ -540,14 +561,7 @@ def test_id_res_OpenID2_POST(self): self.assertIn(response.toFormMarkup(), webresponse.body) def test_toFormMarkup(self): - request = server.CheckIDRequest( - identity='http://bombom.unittest/', - trust_root='http://burr.unittest/', - return_to='http://burr.unittest/999', - immediate=False, - op_endpoint=self.server.op_endpoint, - ) - request.message = Message(OPENID2_NS) + request = make_checkid_request() response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ 'ns': OPENID2_NS, @@ -561,14 +575,7 @@ def test_toFormMarkup(self): self.assertIn(' foo="bar"', form_markup) def test_toHTML(self): - request = server.CheckIDRequest( - identity='http://bombom.unittest/', - trust_root='http://burr.unittest/', - return_to='http://burr.unittest/999', - immediate=False, - op_endpoint=self.server.op_endpoint, - ) - request.message = Message(OPENID2_NS) + request = make_checkid_request() response = server.OpenIDResponse(request) response.fields = Message.fromOpenIDArgs({ 'ns': OPENID2_NS, @@ -582,7 +589,7 @@ def test_toHTML(self): self.assertIn('', html) self.assertIn(' send checkid_* request From 154ced931e746749ba87eb7305c52cf0aa224fdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 7 Mar 2018 14:21:05 +0100 Subject: [PATCH 047/151] Remove disliked functions - map, filter and reduce --- openid/__init__.py | 2 +- openid/consumer/html_parse.py | 2 +- openid/dh.py | 2 +- openid/extensions/draft/pape2.py | 3 +-- openid/extensions/draft/pape5.py | 3 +-- openid/store/filestore.py | 4 +--- openid/store/sqlstore.py | 2 +- openid/test/test_accept.py | 2 +- openid/test/test_dh.py | 4 ++-- openid/test/test_linkparse.py | 2 +- openid/test/test_oidutil.py | 4 ++-- openid/test/test_trustroot.py | 4 ++-- openid/urinorm.py | 4 +--- openid/yadis/__init__.py | 2 +- openid/yadis/xri.py | 7 ++----- 15 files changed, 19 insertions(+), 28 deletions(-) diff --git a/openid/__init__.py b/openid/__init__.py index b172b30c..1ef3e0e6 100644 --- a/openid/__init__.py +++ b/openid/__init__.py @@ -45,7 +45,7 @@ # Parse the version info try: - version_info = map(int, __version__.split('.')) + version_info = tuple(int(i) for i in __version__.split('.')) except ValueError: version_info = (None, None, None) else: diff --git a/openid/consumer/html_parse.py b/openid/consumer/html_parse.py index 14ff8cc2..3c2a0252 100644 --- a/openid/consumer/html_parse.py +++ b/openid/consumer/html_parse.py @@ -249,7 +249,7 @@ def findLinksRel(link_attrs_list, target_rel): as a relationship.""" # XXX: TESTME matchesTarget = partial(linkHasRel, target_rel=target_rel) - return filter(matchesTarget, link_attrs_list) + return [i for i in link_attrs_list if matchesTarget(i)] def findFirstHref(link_attrs_list, target_rel): diff --git a/openid/dh.py b/openid/dh.py index 5b7a4400..74065fd2 100644 --- a/openid/dh.py +++ b/openid/dh.py @@ -10,7 +10,7 @@ def strxor(x, y): if len(x) != len(y): raise ValueError('Inputs to strxor must have the same length') - return "".join(map(_xor, zip(x, y))) + return "".join(_xor((a, b)) for a, b in zip(x, y)) class DiffieHellman(object): diff --git a/openid/extensions/draft/pape2.py b/openid/extensions/draft/pape2.py index 954c5c00..d1790a8e 100644 --- a/openid/extensions/draft/pape2.py +++ b/openid/extensions/draft/pape2.py @@ -145,8 +145,7 @@ def preferredTypes(self, supported_types): @returntype: [str] """ - return filter(self.preferred_auth_policies.__contains__, - supported_types) + return [i for i in supported_types if i in self.preferred_auth_policies] Request.ns_uri = ns_uri diff --git a/openid/extensions/draft/pape5.py b/openid/extensions/draft/pape5.py index e7568dd1..70655006 100644 --- a/openid/extensions/draft/pape5.py +++ b/openid/extensions/draft/pape5.py @@ -264,8 +264,7 @@ def preferredTypes(self, supported_types): @returntype: [str] """ - return filter(self.preferred_auth_policies.__contains__, - supported_types) + return [i for i in supported_types if i in self.preferred_auth_policies] Request.ns_uri = ns_uri diff --git a/openid/store/filestore.py b/openid/store/filestore.py index 0c5c044d..aadf20dc 100644 --- a/openid/store/filestore.py +++ b/openid/store/filestore.py @@ -333,9 +333,7 @@ def useNonce(self, server_url, timestamp, salt): def _allAssocs(self): all_associations = [] - association_filenames = map( - lambda filename: os.path.join(self.association_dir, filename), - os.listdir(self.association_dir)) + association_filenames = [os.path.join(self.association_dir, f) for f in os.listdir(self.association_dir)] for association_filename in association_filenames: try: association_file = file(association_filename, 'rb') diff --git a/openid/store/sqlstore.py b/openid/store/sqlstore.py index c9e7b23a..339931e2 100644 --- a/openid/store/sqlstore.py +++ b/openid/store/sqlstore.py @@ -146,7 +146,7 @@ def unicode_to_str(arg): return str(arg) else: return arg - str_args = map(unicode_to_str, args) + str_args = [unicode_to_str(i) for i in args] self.cur.execute(sql, str_args) def __getattr__(self, attr): diff --git a/openid/test/test_accept.py b/openid/test/test_accept.py index 3f8d9fff..0b2fbb91 100644 --- a/openid/test/test_accept.py +++ b/openid/test/test_accept.py @@ -89,7 +89,7 @@ class MatchAcceptTest(unittest.TestCase): def runTest(self): lines = getTestData() chunks = chunk(lines) - data_sets = map(parseLines, chunks) + data_sets = [parseLines(l) for l in chunks] for data in data_sets: lnos = [] lno, accept_header = data['accept'] diff --git a/openid/test/test_dh.py b/openid/test/test_dh.py index 6c78a0b9..03ef20bc 100644 --- a/openid/test/test_dh.py +++ b/openid/test/test_dh.py @@ -31,8 +31,8 @@ def test_strxor(self): ('', 'a'), ('foo', 'ba'), (NUL * 3, NUL * 4), - (''.join(map(chr, xrange(256))), - ''.join(map(chr, xrange(128)))), + (''.join(chr(i) for i in range(256)), + ''.join(chr(i) for i in range(128))), ] for aa, bb in exc_cases: diff --git a/openid/test/test_linkparse.py b/openid/test/test_linkparse.py index 230bd051..077caaf4 100644 --- a/openid/test/test_linkparse.py +++ b/openid/test/test_linkparse.py @@ -30,7 +30,7 @@ def parseCase(s): name = lines.pop(0) assert name.startswith('Name: ') desc = name[6:] - return desc, markup, map(parseLink, lines) + return desc, markup, [parseLink(l) for l in lines] def parseTests(s): diff --git a/openid/test/test_oidutil.py b/openid/test/test_oidutil.py index b64898a1..a8415e8f 100644 --- a/openid/test/test_oidutil.py +++ b/openid/test/test_oidutil.py @@ -27,7 +27,7 @@ def checkEncoded(s): '\x00', '\x01', '\x00' * 100, - ''.join(map(chr, range(256))), + ''.join(chr(i) for i in range(256)), ] for s in cases: @@ -39,7 +39,7 @@ def checkEncoded(s): # Randomized test for _ in xrange(50): n = random.randrange(2048) - s = ''.join(map(chr, map(lambda _: random.randrange(256), range(n)))) + s = ''.join(chr(random.randrange(256)) for i in range(n)) b64 = oidutil.toBase64(s) checkEncoded(b64) s_prime = oidutil.fromBase64(b64) diff --git a/openid/test/test_trustroot.py b/openid/test/test_trustroot.py index 8302141c..a470207c 100644 --- a/openid/test/test_trustroot.py +++ b/openid/test/test_trustroot.py @@ -42,7 +42,7 @@ def test(self): def getTests(grps, head, dat): tests = [] top = head.strip() - gdat = map(str.strip, dat.split('-' * 40 + '\n')) + gdat = [i.strip() for i in dat.split('-' * 40 + '\n')] assert not gdat[0] assert len(gdat) == (len(grps) * 2 + 1), (gdat, grps) i = 1 @@ -57,7 +57,7 @@ def getTests(grps, head, dat): def parseTests(data): - parts = map(str.strip, data.split('=' * 40 + '\n')) + parts = [i.strip() for i in data.split('=' * 40 + '\n')] assert not parts[0] _, ph, pdat, mh, mdat = parts return ph, pdat, mh, mdat diff --git a/openid/urinorm.py b/openid/urinorm.py index 21869c83..e7127d34 100644 --- a/openid/urinorm.py +++ b/openid/urinorm.py @@ -75,9 +75,7 @@ _unreserved[ord('~')] = True -_escapeme_re = re.compile('[%s]' % (''.join( - map(lambda m_n: u'%s-%s' % (unichr(m_n[0]), unichr(m_n[1])), - UCSCHAR + IPRIVATE)),)) +_escapeme_re = re.compile('[%s]' % ''.join(u'%s-%s' % (unichr(m_n[0]), unichr(m_n[1])) for m_n in UCSCHAR + IPRIVATE)) def _pct_escape_unicode(char_match): diff --git a/openid/yadis/__init__.py b/openid/yadis/__init__.py index 68a0d449..a163f803 100644 --- a/openid/yadis/__init__.py +++ b/openid/yadis/__init__.py @@ -16,7 +16,7 @@ # Parse the version info try: - version_info = map(int, __version__.split('.')) + version_info = tuple(int(i) for i in __version__.split('.')) except ValueError: version_info = (None, None, None) else: diff --git a/openid/yadis/xri.py b/openid/yadis/xri.py index bd3b29ed..60e0675b 100644 --- a/openid/yadis/xri.py +++ b/openid/yadis/xri.py @@ -6,7 +6,6 @@ """ import re -from functools import reduce XRI_AUTHORITIES = ['!', '=', '@', '+', '$', '('] @@ -51,9 +50,7 @@ ] -_escapeme_re = re.compile('[%s]' % (''.join( - map(lambda m_n: u'%s-%s' % (unichr(m_n[0]), unichr(m_n[1])), - UCSCHAR + IPRIVATE)),)) +_escapeme_re = re.compile('[%s]' % ''.join(u'%s-%s' % (unichr(m_n[0]), unichr(m_n[1])) for m_n in UCSCHAR + IPRIVATE)) def identifierScheme(identifier): @@ -147,7 +144,7 @@ def rootAuthority(xri): else: # IRI reference. XXX: Can IRI authorities have segments? segments = authority.split('!') - segments = reduce(list.__add__, map(lambda s: s.split('*'), segments)) + segments = [c for s in segments for c in s.split('*')] root = segments[0] return XRI(root) From ff6b54f3b66bcad8e3e9beaab9a7e6ab492fbf2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 26 Jan 2018 16:17:41 +0100 Subject: [PATCH 048/151] Use lxml for XML --- README.md | 1 + openid/message.py | 12 ++--------- openid/oidutil.py | 43 ------------------------------------- openid/test/test_message.py | 14 ++++-------- openid/yadis/etxrd.py | 21 +++--------------- setup.py | 5 +++++ 6 files changed, 15 insertions(+), 81 deletions(-) diff --git a/README.md b/README.md index b54b3ed9..af1d6e1c 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ REQUIREMENTS ============ - Python 2.7. + - lxml INSTALLATION diff --git a/openid/message.py b/openid/message.py index 378ffb56..ff8bcb5c 100644 --- a/openid/message.py +++ b/openid/message.py @@ -8,14 +8,9 @@ import urllib import warnings -from openid import kvform, oidutil +from lxml import etree as ElementTree -try: - ElementTree = oidutil.importElementTree() -except ImportError: - # No elementtree found, so give up, but don't fail to import, - # since we have fallbacks. - ElementTree = None +from openid import kvform, oidutil # This doesn't REALLY belong here, but where is better? IDENTIFIER_SELECT = 'http://specs.openid.net/auth/2.0/identifier_select' @@ -349,9 +344,6 @@ def toFormMarkup(self, action_url, form_tag_attrs=None, encodes the values in this Message object. @rtype: str or unicode """ - if ElementTree is None: - raise RuntimeError('This function requires ElementTree.') - assert action_url is not None form = ElementTree.Element(u'form') diff --git a/openid/oidutil.py b/openid/oidutil.py index 13954b76..70384d32 100644 --- a/openid/oidutil.py +++ b/openid/oidutil.py @@ -13,14 +13,6 @@ _LOGGER = logging.getLogger(__name__) -elementtree_modules = [ - 'lxml.etree', - 'xml.etree.cElementTree', - 'xml.etree.ElementTree', - 'cElementTree', - 'elementtree.ElementTree', -] - def toUnicode(value): """Returns the given argument as a unicode object. @@ -54,41 +46,6 @@ def autoSubmitHTML(form, title='OpenID transaction in progress'): """ % (title, form) -def importElementTree(module_names=None): - """Find a working ElementTree implementation, trying the standard - places that such a thing might show up. - - >>> ElementTree = importElementTree() - - @param module_names: The names of modules to try to use as - ElementTree. Defaults to C{L{elementtree_modules}} - - @returns: An ElementTree module - """ - if module_names is None: - module_names = elementtree_modules - - for mod_name in module_names: - try: - ElementTree = __import__(mod_name, None, None, ['unused']) - except ImportError: - pass - else: - # Make sure it can actually parse XML - try: - ElementTree.XML('') - except Exception: - logging.exception('Not using ElementTree library %r because it failed to parse a trivial document: %s', - mod_name) - else: - return ElementTree - else: - raise ImportError('No ElementTree library found. ' - 'You may need to install one. ' - 'Tried importing %r' % (module_names,) - ) - - def log(message, level=0): """Handle a log message from the OpenID library. diff --git a/openid/test/test_message.py b/openid/test/test_message.py index 32d8ee63..f00182ab 100644 --- a/openid/test/test_message.py +++ b/openid/test/test_message.py @@ -4,9 +4,9 @@ import warnings from urlparse import parse_qs +from lxml import etree as ElementTree from testfixtures import ShouldWarn -from openid import oidutil from openid.extensions import sreg from openid.message import (BARE_NS, NULL_NAMESPACE, OPENID1_NS, OPENID2_NS, OPENID_NS, OPENID_PROTOCOL_FIELDS, THE_OTHER_OPENID1_NS, InvalidNamespace, InvalidOpenIDNamespace, Message, NamespaceMap, @@ -727,10 +727,8 @@ def setUp(self): def _checkForm(self, html, message_, action_url, form_tag_attrs, submit_text): - E = oidutil.importElementTree() - # Build element tree from HTML source - input_tree = E.ElementTree(E.fromstring(html)) + input_tree = ElementTree.ElementTree(ElementTree.fromstring(html)) # Get root element form = input_tree.getroot() @@ -803,14 +801,10 @@ def test_toFormMarkup_bug_with_utf8_values(self): 'ünicöde_key': 'ünicöde_välüe', } m = Message.fromPostArgs(postargs) - # Calling m.toFormMarkup with lxml used for ElementTree will throw - # a ValueError. html = m.toFormMarkup(self.action_url, self.form_tag_attrs, self.submit_text) - # Using the (c)ElementTree from stdlib will result in the UTF-8 - # encoded strings to be converted to XML character references, - # "ünicöde_key" becomes "ünicöde_key" and - # "ünicöde_välüe" becomes "ünicöde_välüe" + self.assertIn('ünicöde_key', html) + self.assertIn('ünicöde_välüe', html) self.assertNotIn('ünicöde_key', html, 'UTF-8 bytes should not convert to XML character references') self.assertNotIn('ünicöde_välüe', html, diff --git a/openid/yadis/etxrd.py b/openid/yadis/etxrd.py index 563a1f2e..a5366172 100644 --- a/openid/yadis/etxrd.py +++ b/openid/yadis/etxrd.py @@ -19,27 +19,12 @@ ] import random -import sys from datetime import datetime from time import strptime -from openid.oidutil import importElementTree -from openid.yadis import xri - -ElementTree = importElementTree() +from lxml import etree as ElementTree -# the different elementtree modules don't have a common exception -# model. We just want to be able to catch the exceptions that signify -# malformed XML data and wrap them, so that the other library code -# doesn't have to know which XML library we're using. -try: - # Make the parser raise an exception so we can sniff out the type - # of exceptions - ElementTree.XML('> purposely malformed XML <') -except (MemoryError, AssertionError, ImportError): - raise -except Exception: - XMLError = sys.exc_info()[0] +from openid.yadis import xri class XRDSError(Exception): @@ -65,7 +50,7 @@ def parseXRDS(text): """ try: element = ElementTree.XML(text) - except XMLError as why: + except ElementTree.Error as why: exc = XRDSError('Error parsing document as XML') exc.reason = why raise exc diff --git a/setup.py b/setup.py index 4b7e934c..876da4f4 100644 --- a/setup.py +++ b/setup.py @@ -7,6 +7,10 @@ os.system('./admin/makedoc') version = '[library version:2.2.5]'[17:-1] +INSTALL_REQUIRES = [ + 'lxml;platform_python_implementation=="CPython"', + 'lxml <4.0;platform_python_implementation=="PyPy"', +] EXTRAS_REQUIRE = { 'quality': ('flake8', 'isort'), 'tests': ('mock', 'testfixtures'), @@ -35,6 +39,7 @@ 'openid.extensions', 'openid.extensions.draft', ], + install_requires=INSTALL_REQUIRES, extras_require=EXTRAS_REQUIRE, # license specified by classifier. # license=getLicense(), From 66b9f3a05a115d63aa268279cce574699bd603c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Thu, 8 Mar 2018 13:57:53 +0100 Subject: [PATCH 049/151] Fix XXE in XRDS parsing --- openid/test/test_etxrd.py | 50 +++++++++++++++++++++++++++++++++++++++ openid/yadis/etxrd.py | 9 +++---- 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/openid/test/test_etxrd.py b/openid/test/test_etxrd.py index 59c84115..2b842d7d 100644 --- a/openid/test/test_etxrd.py +++ b/openid/test/test_etxrd.py @@ -1,6 +1,9 @@ import os.path +import tempfile import unittest +from lxml import etree + from openid.yadis import etxrd, services, xri @@ -32,6 +35,53 @@ def simpleOpenIDTransformer(endpoint): return (endpoint.uri, delegate) +class TestParseXRDS(unittest.TestCase): + """Test `parseXRDS` function.""" + + def assertXmlEqual(self, result, expected): + self.assertEqual(result.tag, expected.tag) + self.assertEqual(result.text, expected.text) + self.assertEqual(result.tail, expected.tail) + self.assertEqual(result.attrib, expected.attrib) + self.assertEqual(len(result), len(expected)) + for child_r, child_e in zip(result, expected): + self.assertXmlEqual(child_r, child_e) + + def test_minimal_xrds(self): + xml = '' + tree = etxrd.parseXRDS(xml) + self.assertIsInstance(tree, type(etree.ElementTree())) + self.assertXmlEqual(tree.getroot(), etree.XML(xml)) + + def test_not_xrds(self): + xml = '' + with self.assertRaisesRegexp(etxrd.XRDSError, 'Not an XRDS document'): + etxrd.parseXRDS(xml) + + def test_invalid_xml(self): + xml = '<' + with self.assertRaisesRegexp(etxrd.XRDSError, 'Error parsing document as XML'): + etxrd.parseXRDS(xml) + + def test_xxe(self): + xxe_content = 'XXE CONTENT' + _, tmp_file = tempfile.mkstemp() + try: + with open(tmp_file, 'w') as xxe_file: + xxe_file.write(xxe_content) + # XXE example from Testing for XML Injection (OTG-INPVAL-008) + # https://www.owasp.org/index.php/Testing_for_XML_Injection_(OTG-INPVAL-008) + xml = ('' + '' + ']>' + '&xxe;') + tree = etxrd.parseXRDS(xml % tmp_file) + self.assertNotIn(xxe_content, etree.tostring(tree)) + finally: + os.remove(tmp_file) + + class TestServiceParser(unittest.TestCase): def setUp(self): self.xmldoc = file(XRD_FILE).read() diff --git a/openid/yadis/etxrd.py b/openid/yadis/etxrd.py index a5366172..a96a107d 100644 --- a/openid/yadis/etxrd.py +++ b/openid/yadis/etxrd.py @@ -22,7 +22,7 @@ from datetime import datetime from time import strptime -from lxml import etree as ElementTree +from lxml import etree from openid.yadis import xri @@ -48,14 +48,15 @@ def parseXRDS(text): @raises XRDSError: When there is a parse error or the document does not contain an XRDS. """ + parser = etree.XMLParser(resolve_entities=False) try: - element = ElementTree.XML(text) - except ElementTree.Error as why: + element = etree.XML(text, parser) + except etree.Error as why: exc = XRDSError('Error parsing document as XML') exc.reason = why raise exc else: - tree = ElementTree.ElementTree(element) + tree = etree.ElementTree(element) if not isXRDS(tree): raise XRDSError('Not an XRDS document') From 5652d9e3312f50b1009f5f30d3e8a31929e47b85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Thu, 8 Mar 2018 14:57:18 +0100 Subject: [PATCH 050/151] Add codecov --- setup.py | 2 +- tox.ini | 18 +++++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/setup.py b/setup.py index 876da4f4..d6de0ea3 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ ] EXTRAS_REQUIRE = { 'quality': ('flake8', 'isort'), - 'tests': ('mock', 'testfixtures'), + 'tests': ('mock', 'testfixtures', 'coverage'), # Optional dependencies for fetchers 'httplib2': ('httplib2', ), 'pycurl': ('pycurl', ), diff --git a/tox.ini b/tox.ini index 5d114bb3..ab9ef238 100644 --- a/tox.ini +++ b/tox.ini @@ -11,22 +11,26 @@ python = # Generic specification for all unspecific environments [testenv] -whitelist_externals = make +deps = + codecov extras = tests djopenid: djopenid httplib2: httplib2 pycurl: pycurl +passenv = CI TRAVIS TRAVIS_* +setenv = + DJANGO_SETTINGS_MODULE = djopenid.settings + PYTHONPATH = {toxinidir}/examples:{env:PYTHONPATH:} commands = - pip install --editable . - pip list - make test-openid - djopenid: make test-djopenid + coverage run --branch --source=openid,examples --module unittest discover --start=openid + djopenid: coverage run --branch --source=openid,examples --append --module unittest discover --start={toxinidir}/examples + codecov [testenv:quality] whitelist_externals = make basepython = python2.7 +extras = + quality commands = - pip install --editable .[quality] - pip list make check-all From 9138fb8b4af26f6d985d6c0d1d0b668fc39a49a6 Mon Sep 17 00:00:00 2001 From: "Lena (zansorgova)" Date: Sun, 25 Feb 2018 00:01:16 +0100 Subject: [PATCH 051/151] Add ResponseFetcher implementation --- openid/fetchers.py | 24 +++++++++++++ openid/test/test_fetchers.py | 65 ++++++++++++++++++++++++++++++++++++ setup.py | 3 +- tox.ini | 5 +-- 4 files changed, 94 insertions(+), 3 deletions(-) diff --git a/openid/fetchers.py b/openid/fetchers.py index 750b5f55..fa251849 100644 --- a/openid/fetchers.py +++ b/openid/fetchers.py @@ -29,6 +29,12 @@ except ImportError: pycurl = None +# try to import requests +try: + import requests +except ImportError: + requests = None + USER_AGENT = "python-openid/%s (%s)" % (openid.__version__, sys.platform) MAX_RESPONSE_KB = 1024 @@ -438,3 +444,21 @@ def fetch(self, url, body=None, headers=None): headers=dict(httplib2_response.items()), status=httplib2_response.status, ) + + +class RequestsFetcher(HTTPFetcher): + """A fetcher that uses C{requests} for performing HTTP requests.""" + + def fetch(self, url, body=None, headers=None): + """Perform an HTTP request + + @raises Exception: Any exception that can be raised by 'requests' + + @see: C{L{HTTPFetcher.fetch}} + """ + if body: + method = 'POST' + else: + method = 'GET' + response = requests.request(method, url, data=body, headers=headers) + return HTTPResponse(response.url, response.status_code, response.headers, response.content) diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index 16f615a2..36698ba0 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -6,10 +6,18 @@ from cStringIO import StringIO from urllib import addinfourl +import responses from mock import Mock from openid import fetchers +try: + import requests +except ImportError: + requests = None +else: + from requests.exceptions import ConnectionError, InvalidSchema + # XXX: make these separate test cases @@ -336,3 +344,60 @@ class TestSilencedUrllib2Fetcher(TestUrllib2Fetcher): fetcher = fetchers.ExceptionWrappingFetcher(fetchers.Urllib2Fetcher()) invalid_url_error = fetchers.HTTPFetchingError + + +@unittest.skipUnless(requests, "Requests are not installed") +class TestRequestsFetcher(unittest.TestCase): + """Test `RequestsFetcher` class.""" + + fetcher = fetchers.RequestsFetcher() + + def test_get(self): + # Test GET response + with responses.RequestsMock() as rsps: + rsps.add(responses.GET, 'http://example.cz/', status=200, body='BODY', + headers={'Content-Type': 'text/plain'}) + response = self.fetcher.fetch('http://example.cz/') + expected = fetchers.HTTPResponse('http://example.cz/', 200, {'Content-Type': 'text/plain'}, 'BODY') + assertResponse(expected, response) + + def test_post(self): + # Test POST response + with responses.RequestsMock() as rsps: + rsps.add(responses.POST, 'http://example.cz/', status=200, body='BODY', + headers={'Content-Type': 'text/plain'}) + response = self.fetcher.fetch('http://example.cz/', body='key=value') + expected = fetchers.HTTPResponse('http://example.cz/', 200, {'Content-Type': 'text/plain'}, 'BODY') + assertResponse(expected, response) + + def test_redirect(self): + # Test redirect response - a final response comes from another URL. + with responses.RequestsMock() as rsps: + rsps.add(responses.GET, 'http://example.cz/redirect/', status=302, + headers={'Location': 'http://example.cz/target/'}) + rsps.add(responses.GET, 'http://example.cz/target/', status=200, body='BODY', + headers={'Content-Type': 'text/plain'}) + response = self.fetcher.fetch('http://example.cz/redirect/') + expected = fetchers.HTTPResponse('http://example.cz/target/', 200, {'Content-Type': 'text/plain'}, 'BODY') + assertResponse(expected, response) + + def test_error(self): + # Test error responses - returned as obtained + with responses.RequestsMock() as rsps: + rsps.add(responses.GET, 'http://example.cz/error/', status=500, body='BODY', + headers={'Content-Type': 'text/plain'}) + response = self.fetcher.fetch('http://example.cz/error/') + expected = fetchers.HTTPResponse('http://example.cz/error/', 500, {'Content-Type': 'text/plain'}, 'BODY') + assertResponse(expected, response) + + def test_invalid_url(self): + invalid_url = 'invalid://example.cz/' + with self.assertRaisesRegexp(InvalidSchema, "No connection adapters were found for '" + invalid_url + "'"): + self.fetcher.fetch(invalid_url) + + def test_connection_error(self): + # Test connection error + with responses.RequestsMock() as rsps: + rsps.add(responses.GET, 'http://example.cz/', body=ConnectionError('Name or service not known')) + with self.assertRaisesRegexp(ConnectionError, 'Name or service not known'): + self.fetcher.fetch('http://example.cz/') diff --git a/setup.py b/setup.py index d6de0ea3..6b2e6726 100644 --- a/setup.py +++ b/setup.py @@ -13,10 +13,11 @@ ] EXTRAS_REQUIRE = { 'quality': ('flake8', 'isort'), - 'tests': ('mock', 'testfixtures', 'coverage'), + 'tests': ('mock', 'testfixtures', 'responses', 'coverage'), # Optional dependencies for fetchers 'httplib2': ('httplib2', ), 'pycurl': ('pycurl', ), + 'requests': ('requests', ), # Dependencies for Django example 'djopenid': ('django<1.11.99', ), } diff --git a/tox.ini b/tox.ini index ab9ef238..d75be233 100644 --- a/tox.ini +++ b/tox.ini @@ -1,8 +1,8 @@ [tox] envlist = quality - py27-{openid,djopenid,httplib2,pycurl} - pypy-{openid,djopenid,httplib2,pycurl} + py27-{openid,djopenid,httplib2,pycurl,requests} + pypy-{openid,djopenid,httplib2,pycurl,requests} # tox-travis specials [travis] @@ -18,6 +18,7 @@ extras = djopenid: djopenid httplib2: httplib2 pycurl: pycurl + requests: requests passenv = CI TRAVIS TRAVIS_* setenv = DJANGO_SETTINGS_MODULE = djopenid.settings From c99f3da49c6075a273aa6d3141ee6af5e6c6e655 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 10 Apr 2018 15:29:26 +0200 Subject: [PATCH 052/151] Deprecate pape drafts --- openid/extensions/__init__.py | 2 - openid/extensions/draft/pape2.py | 4 + openid/extensions/draft/pape5.py | 460 +----------------------------- openid/extensions/pape.py | 473 +++++++++++++++++++++++++++++++ openid/test/test_pape.py | 390 ++++++++++++++++++++++++- openid/test/test_pape_draft5.py | 392 +------------------------ 6 files changed, 874 insertions(+), 847 deletions(-) create mode 100644 openid/extensions/pape.py diff --git a/openid/extensions/__init__.py b/openid/extensions/__init__.py index 710b2002..5394e7a4 100644 --- a/openid/extensions/__init__.py +++ b/openid/extensions/__init__.py @@ -1,5 +1,3 @@ """OpenID Extension modules.""" __all__ = ['ax', 'pape', 'sreg'] - -from openid.extensions.draft import pape5 as pape diff --git a/openid/extensions/draft/pape2.py b/openid/extensions/draft/pape2.py index d1790a8e..a26ddfcb 100644 --- a/openid/extensions/draft/pape2.py +++ b/openid/extensions/draft/pape2.py @@ -16,9 +16,13 @@ ] import re +import warnings from openid.extension import Extension +warnings.warn("Module 'openid.extensions.draft.pape2' is deprecated. Use 'openid.extensions.pape' instead.", + DeprecationWarning) + ns_uri = "http://specs.openid.net/extensions/pape/1.0" AUTH_MULTI_FACTOR_PHYSICAL = \ diff --git a/openid/extensions/draft/pape5.py b/openid/extensions/draft/pape5.py index 70655006..47cf9b20 100644 --- a/openid/extensions/draft/pape5.py +++ b/openid/extensions/draft/pape5.py @@ -5,6 +5,10 @@ @since: 2.1.0 """ +import warnings + +from openid.extensions.pape import (AUTH_MULTI_FACTOR, AUTH_MULTI_FACTOR_PHYSICAL, AUTH_PHISHING_RESISTANT, LEVELS_JISA, + LEVELS_NIST, Request, Response, ns_uri) __all__ = [ 'Request', @@ -17,457 +21,5 @@ 'LEVELS_JISA', ] -import re -import warnings - -from openid.extension import Extension - -ns_uri = "http://specs.openid.net/extensions/pape/1.0" - -AUTH_MULTI_FACTOR_PHYSICAL = \ - 'http://schemas.openid.net/pape/policies/2007/06/multi-factor-physical' -AUTH_MULTI_FACTOR = \ - 'http://schemas.openid.net/pape/policies/2007/06/multi-factor' -AUTH_PHISHING_RESISTANT = \ - 'http://schemas.openid.net/pape/policies/2007/06/phishing-resistant' -AUTH_NONE = \ - 'http://schemas.openid.net/pape/policies/2007/06/none' - -TIME_VALIDATOR = re.compile('^\d\d\d\d-\d\d-\d\dT\d\d:\d\d:\d\dZ$') - -LEVELS_NIST = 'http://csrc.nist.gov/publications/nistpubs/800-63/SP800-63V1_0_2.pdf' -LEVELS_JISA = 'http://www.jisa.or.jp/spec/auth_level.html' - - -class PAPEExtension(Extension): - _default_auth_level_aliases = { - 'nist': LEVELS_NIST, - 'jisa': LEVELS_JISA, - } - - def __init__(self): - self.auth_level_aliases = self._default_auth_level_aliases.copy() - - def _addAuthLevelAlias(self, auth_level_uri, alias=None): - """Add an auth level URI alias to this request. - - @param auth_level_uri: The auth level URI to send in the - request. - - @param alias: The namespace alias to use for this auth level - in this message. May be None if the alias is not - important. - """ - if alias is None: - try: - alias = self._getAlias(auth_level_uri) - except KeyError: - alias = self._generateAlias() - else: - existing_uri = self.auth_level_aliases.get(alias) - if existing_uri is not None and existing_uri != auth_level_uri: - raise KeyError('Attempting to redefine alias %r from %r to %r', - alias, existing_uri, auth_level_uri) - - self.auth_level_aliases[alias] = auth_level_uri - - def _generateAlias(self): - """Return an unused auth level alias""" - for i in xrange(1000): - alias = 'cust%d' % (i,) - if alias not in self.auth_level_aliases: - return alias - - raise RuntimeError('Could not find an unused alias (tried 1000!)') - - def _getAlias(self, auth_level_uri): - """Return the alias for the specified auth level URI. - - @raises KeyError: if no alias is defined - """ - for (alias, existing_uri) in self.auth_level_aliases.iteritems(): - if auth_level_uri == existing_uri: - return alias - - raise KeyError(auth_level_uri) - - -class Request(PAPEExtension): - """A Provider Authentication Policy request, sent from a relying - party to a provider - - @ivar preferred_auth_policies: The authentication policies that - the relying party prefers - @type preferred_auth_policies: [str] - - @ivar max_auth_age: The maximum time, in seconds, that the relying - party wants to allow to have elapsed before the user must - re-authenticate - @type max_auth_age: int or NoneType - - @ivar preferred_auth_level_types: Ordered list of authentication - level namespace URIs - - @type preferred_auth_level_types: [str] - """ - - ns_alias = 'pape' - - def __init__(self, preferred_auth_policies=None, max_auth_age=None, - preferred_auth_level_types=None): - super(Request, self).__init__() - if preferred_auth_policies is None: - preferred_auth_policies = [] - - self.preferred_auth_policies = preferred_auth_policies - self.max_auth_age = max_auth_age - self.preferred_auth_level_types = [] - - if preferred_auth_level_types is not None: - for auth_level in preferred_auth_level_types: - self.addAuthLevel(auth_level) - - def __nonzero__(self): - return bool(self.preferred_auth_policies or - self.max_auth_age is not None or - self.preferred_auth_level_types) - - def addPolicyURI(self, policy_uri): - """Add an acceptable authentication policy URI to this request - - This method is intended to be used by the relying party to add - acceptable authentication types to the request. - - @param policy_uri: The identifier for the preferred type of - authentication. - @see: http://openid.net/specs/openid-provider-authentication-policy-extension-1_0-05.html#auth_policies - """ - if policy_uri not in self.preferred_auth_policies: - self.preferred_auth_policies.append(policy_uri) - - def addAuthLevel(self, auth_level_uri, alias=None): - self._addAuthLevelAlias(auth_level_uri, alias) - if auth_level_uri not in self.preferred_auth_level_types: - self.preferred_auth_level_types.append(auth_level_uri) - - def getExtensionArgs(self): - """@see: C{L{Extension.getExtensionArgs}} - """ - ns_args = { - 'preferred_auth_policies': ' '.join(self.preferred_auth_policies), - } - - if self.max_auth_age is not None: - ns_args['max_auth_age'] = str(self.max_auth_age) - - if self.preferred_auth_level_types: - preferred_types = [] - - for auth_level_uri in self.preferred_auth_level_types: - alias = self._getAlias(auth_level_uri) - ns_args['auth_level.ns.%s' % (alias,)] = auth_level_uri - preferred_types.append(alias) - - ns_args['preferred_auth_level_types'] = ' '.join(preferred_types) - - return ns_args - - @classmethod - def fromOpenIDRequest(cls, request): - """Instantiate a Request object from the arguments in a - C{checkid_*} OpenID message - """ - self = cls() - args = request.message.getArgs(self.ns_uri) - is_openid1 = request.message.isOpenID1() - - if args == {}: - return None - - self.parseExtensionArgs(args, is_openid1) - return self - - def parseExtensionArgs(self, args, is_openid1, strict=False): - """Set the state of this request to be that expressed in these - PAPE arguments - - @param args: The PAPE arguments without a namespace - - @param strict: Whether to raise an exception if the input is - out of spec or otherwise malformed. If strict is false, - malformed input will be ignored. - - @param is_openid1: Whether the input should be treated as part - of an OpenID1 request - - @rtype: None - - @raises ValueError: When the max_auth_age is not parseable as - an integer - """ - - # preferred_auth_policies is a space-separated list of policy URIs - self.preferred_auth_policies = [] - - policies_str = args.get('preferred_auth_policies') - if policies_str: - for uri in policies_str.split(' '): - if uri not in self.preferred_auth_policies: - self.preferred_auth_policies.append(uri) - - # max_auth_age is base-10 integer number of seconds - max_auth_age_str = args.get('max_auth_age') - self.max_auth_age = None - - if max_auth_age_str: - try: - self.max_auth_age = int(max_auth_age_str) - except ValueError: - if strict: - raise - - # Parse auth level information - preferred_auth_level_types = args.get('preferred_auth_level_types') - if preferred_auth_level_types: - aliases = preferred_auth_level_types.strip().split() - - for alias in aliases: - key = 'auth_level.ns.%s' % (alias,) - try: - uri = args[key] - except KeyError: - if is_openid1: - uri = self._default_auth_level_aliases.get(alias) - else: - uri = None - - if uri is None: - if strict: - raise ValueError('preferred auth level %r is not ' - 'defined in this message' % (alias,)) - else: - self.addAuthLevel(uri, alias) - - def preferredTypes(self, supported_types): - """Given a list of authentication policy URIs that a provider - supports, this method returns the subsequence of those types - that are preferred by the relying party. - - @param supported_types: A sequence of authentication policy - type URIs that are supported by a provider - - @returns: The sub-sequence of the supported types that are - preferred by the relying party. This list will be ordered - in the order that the types appear in the supported_types - sequence, and may be empty if the provider does not prefer - any of the supported authentication types. - - @returntype: [str] - """ - return [i for i in supported_types if i in self.preferred_auth_policies] - - -Request.ns_uri = ns_uri - - -class Response(PAPEExtension): - """A Provider Authentication Policy response, sent from a provider - to a relying party - - @ivar auth_policies: List of authentication policies conformed to - by this OpenID assertion, represented as policy URIs - """ - - ns_alias = 'pape' - - def __init__(self, auth_policies=None, auth_time=None, - auth_levels=None): - super(Response, self).__init__() - if auth_policies: - self.auth_policies = auth_policies - else: - self.auth_policies = [] - - self.auth_time = auth_time - self.auth_levels = {} - - if auth_levels is None: - auth_levels = {} - - for uri, level in auth_levels.iteritems(): - self.setAuthLevel(uri, level) - - def setAuthLevel(self, level_uri, level, alias=None): - """Set the value for the given auth level type. - - @param level: string representation of an authentication level - valid for level_uri - - @param alias: An optional namespace alias for the given auth - level URI. May be omitted if the alias is not - significant. The library will use a reasonable default for - widely-used auth level types. - """ - self._addAuthLevelAlias(level_uri, alias) - self.auth_levels[level_uri] = level - - def getAuthLevel(self, level_uri): - """Return the auth level for the specified auth level - identifier - - @returns: A string that should map to the auth levels defined - for the auth level type - - @raises KeyError: If the auth level type is not present in - this message - """ - return self.auth_levels[level_uri] - - @property - def nist_auth_level(self): - """Backward-compatibility accessor for the NIST auth level.""" - try: - return int(self.getAuthLevel(LEVELS_NIST)) - except KeyError: - return None - - def addPolicyURI(self, policy_uri): - """Add a authentication policy to this response - - This method is intended to be used by the provider to add a - policy that the provider conformed to when authenticating the user. - - @param policy_uri: The identifier for the preferred type of - authentication. - @see: http://openid.net/specs/openid-provider-authentication-policy-extension-1_0-01.html#auth_policies - """ - if policy_uri == AUTH_NONE: - raise RuntimeError( - 'To send no policies, do not set any on the response.') - - if policy_uri not in self.auth_policies: - self.auth_policies.append(policy_uri) - - @classmethod - def fromSuccessResponse(cls, success_response): - """Create a C{L{Response}} object from a successful OpenID - library response - (C{L{openid.consumer.consumer.SuccessResponse}}) response - message - - @param success_response: A SuccessResponse from consumer.complete() - @type success_response: C{L{openid.consumer.consumer.SuccessResponse}} - - @rtype: Response or None - @returns: A provider authentication policy response from the - data that was supplied with the C{id_res} response or None - if the provider sent no signed PAPE response arguments. - """ - self = cls() - - # PAPE requires that the args be signed. - args = success_response.getSignedNS(self.ns_uri) - is_openid1 = success_response.isOpenID1() - - # Only try to construct a PAPE response if the arguments were - # signed in the OpenID response. If not, return None. - if args is not None: - self.parseExtensionArgs(args, is_openid1) - return self - else: - return None - - def parseExtensionArgs(self, args, is_openid1, strict=False): - """Parse the provider authentication policy arguments into the - internal state of this object - - @param args: unqualified provider authentication policy - arguments - - @param strict: Whether to raise an exception when bad data is - encountered - - @returns: None. The data is parsed into the internal fields of - this object. - """ - policies_str = args.get('auth_policies') - if policies_str: - auth_policies = policies_str.split(' ') - elif strict: - raise ValueError('Missing auth_policies') - else: - auth_policies = [] - - if (len(auth_policies) > 1 and strict and AUTH_NONE in auth_policies): - raise ValueError('Got some auth policies, as well as the special ' - '"none" URI: %r' % (auth_policies,)) - - if 'none' in auth_policies: - msg = '"none" used as a policy URI (see PAPE draft < 5)' - if strict: - raise ValueError(msg) - else: - warnings.warn(msg, stacklevel=2) - - auth_policies = [u for u in auth_policies - if u not in ['none', AUTH_NONE]] - - self.auth_policies = auth_policies - - for (key, val) in args.iteritems(): - if key.startswith('auth_level.'): - alias = key[11:] - - # skip the already-processed namespace declarations - if alias.startswith('ns.'): - continue - - try: - uri = args['auth_level.ns.%s' % (alias,)] - except KeyError: - if is_openid1: - uri = self._default_auth_level_aliases.get(alias) - else: - uri = None - - if uri is None: - if strict: - raise ValueError( - 'Undefined auth level alias: %r' % (alias,)) - else: - self.setAuthLevel(uri, val, alias) - - auth_time = args.get('auth_time') - if auth_time: - if TIME_VALIDATOR.match(auth_time): - self.auth_time = auth_time - elif strict: - raise ValueError("auth_time must be in RFC3339 format") - - def getExtensionArgs(self): - """@see: C{L{Extension.getExtensionArgs}} - """ - if len(self.auth_policies) == 0: - ns_args = { - 'auth_policies': AUTH_NONE, - } - else: - ns_args = { - 'auth_policies': ' '.join(self.auth_policies), - } - - for level_type, level in self.auth_levels.iteritems(): - alias = self._getAlias(level_type) - ns_args['auth_level.ns.%s' % (alias,)] = level_type - ns_args['auth_level.%s' % (alias,)] = str(level) - - if self.auth_time is not None: - if not TIME_VALIDATOR.match(self.auth_time): - raise ValueError('auth_time must be in RFC3339 format') - - ns_args['auth_time'] = self.auth_time - - return ns_args - - -Response.ns_uri = ns_uri +warnings.warn("Module 'openid.extensions.draft.pape5' is deprecated in favor of 'openid.extensions.pape'.", + DeprecationWarning) diff --git a/openid/extensions/pape.py b/openid/extensions/pape.py new file mode 100644 index 00000000..70655006 --- /dev/null +++ b/openid/extensions/pape.py @@ -0,0 +1,473 @@ +"""An implementation of the OpenID Provider Authentication Policy +Extension 1.0, Draft 5 + +@see: http://openid.net/developers/specs/ + +@since: 2.1.0 +""" + +__all__ = [ + 'Request', + 'Response', + 'ns_uri', + 'AUTH_PHISHING_RESISTANT', + 'AUTH_MULTI_FACTOR', + 'AUTH_MULTI_FACTOR_PHYSICAL', + 'LEVELS_NIST', + 'LEVELS_JISA', +] + +import re +import warnings + +from openid.extension import Extension + +ns_uri = "http://specs.openid.net/extensions/pape/1.0" + +AUTH_MULTI_FACTOR_PHYSICAL = \ + 'http://schemas.openid.net/pape/policies/2007/06/multi-factor-physical' +AUTH_MULTI_FACTOR = \ + 'http://schemas.openid.net/pape/policies/2007/06/multi-factor' +AUTH_PHISHING_RESISTANT = \ + 'http://schemas.openid.net/pape/policies/2007/06/phishing-resistant' +AUTH_NONE = \ + 'http://schemas.openid.net/pape/policies/2007/06/none' + +TIME_VALIDATOR = re.compile('^\d\d\d\d-\d\d-\d\dT\d\d:\d\d:\d\dZ$') + +LEVELS_NIST = 'http://csrc.nist.gov/publications/nistpubs/800-63/SP800-63V1_0_2.pdf' +LEVELS_JISA = 'http://www.jisa.or.jp/spec/auth_level.html' + + +class PAPEExtension(Extension): + _default_auth_level_aliases = { + 'nist': LEVELS_NIST, + 'jisa': LEVELS_JISA, + } + + def __init__(self): + self.auth_level_aliases = self._default_auth_level_aliases.copy() + + def _addAuthLevelAlias(self, auth_level_uri, alias=None): + """Add an auth level URI alias to this request. + + @param auth_level_uri: The auth level URI to send in the + request. + + @param alias: The namespace alias to use for this auth level + in this message. May be None if the alias is not + important. + """ + if alias is None: + try: + alias = self._getAlias(auth_level_uri) + except KeyError: + alias = self._generateAlias() + else: + existing_uri = self.auth_level_aliases.get(alias) + if existing_uri is not None and existing_uri != auth_level_uri: + raise KeyError('Attempting to redefine alias %r from %r to %r', + alias, existing_uri, auth_level_uri) + + self.auth_level_aliases[alias] = auth_level_uri + + def _generateAlias(self): + """Return an unused auth level alias""" + for i in xrange(1000): + alias = 'cust%d' % (i,) + if alias not in self.auth_level_aliases: + return alias + + raise RuntimeError('Could not find an unused alias (tried 1000!)') + + def _getAlias(self, auth_level_uri): + """Return the alias for the specified auth level URI. + + @raises KeyError: if no alias is defined + """ + for (alias, existing_uri) in self.auth_level_aliases.iteritems(): + if auth_level_uri == existing_uri: + return alias + + raise KeyError(auth_level_uri) + + +class Request(PAPEExtension): + """A Provider Authentication Policy request, sent from a relying + party to a provider + + @ivar preferred_auth_policies: The authentication policies that + the relying party prefers + @type preferred_auth_policies: [str] + + @ivar max_auth_age: The maximum time, in seconds, that the relying + party wants to allow to have elapsed before the user must + re-authenticate + @type max_auth_age: int or NoneType + + @ivar preferred_auth_level_types: Ordered list of authentication + level namespace URIs + + @type preferred_auth_level_types: [str] + """ + + ns_alias = 'pape' + + def __init__(self, preferred_auth_policies=None, max_auth_age=None, + preferred_auth_level_types=None): + super(Request, self).__init__() + if preferred_auth_policies is None: + preferred_auth_policies = [] + + self.preferred_auth_policies = preferred_auth_policies + self.max_auth_age = max_auth_age + self.preferred_auth_level_types = [] + + if preferred_auth_level_types is not None: + for auth_level in preferred_auth_level_types: + self.addAuthLevel(auth_level) + + def __nonzero__(self): + return bool(self.preferred_auth_policies or + self.max_auth_age is not None or + self.preferred_auth_level_types) + + def addPolicyURI(self, policy_uri): + """Add an acceptable authentication policy URI to this request + + This method is intended to be used by the relying party to add + acceptable authentication types to the request. + + @param policy_uri: The identifier for the preferred type of + authentication. + @see: http://openid.net/specs/openid-provider-authentication-policy-extension-1_0-05.html#auth_policies + """ + if policy_uri not in self.preferred_auth_policies: + self.preferred_auth_policies.append(policy_uri) + + def addAuthLevel(self, auth_level_uri, alias=None): + self._addAuthLevelAlias(auth_level_uri, alias) + if auth_level_uri not in self.preferred_auth_level_types: + self.preferred_auth_level_types.append(auth_level_uri) + + def getExtensionArgs(self): + """@see: C{L{Extension.getExtensionArgs}} + """ + ns_args = { + 'preferred_auth_policies': ' '.join(self.preferred_auth_policies), + } + + if self.max_auth_age is not None: + ns_args['max_auth_age'] = str(self.max_auth_age) + + if self.preferred_auth_level_types: + preferred_types = [] + + for auth_level_uri in self.preferred_auth_level_types: + alias = self._getAlias(auth_level_uri) + ns_args['auth_level.ns.%s' % (alias,)] = auth_level_uri + preferred_types.append(alias) + + ns_args['preferred_auth_level_types'] = ' '.join(preferred_types) + + return ns_args + + @classmethod + def fromOpenIDRequest(cls, request): + """Instantiate a Request object from the arguments in a + C{checkid_*} OpenID message + """ + self = cls() + args = request.message.getArgs(self.ns_uri) + is_openid1 = request.message.isOpenID1() + + if args == {}: + return None + + self.parseExtensionArgs(args, is_openid1) + return self + + def parseExtensionArgs(self, args, is_openid1, strict=False): + """Set the state of this request to be that expressed in these + PAPE arguments + + @param args: The PAPE arguments without a namespace + + @param strict: Whether to raise an exception if the input is + out of spec or otherwise malformed. If strict is false, + malformed input will be ignored. + + @param is_openid1: Whether the input should be treated as part + of an OpenID1 request + + @rtype: None + + @raises ValueError: When the max_auth_age is not parseable as + an integer + """ + + # preferred_auth_policies is a space-separated list of policy URIs + self.preferred_auth_policies = [] + + policies_str = args.get('preferred_auth_policies') + if policies_str: + for uri in policies_str.split(' '): + if uri not in self.preferred_auth_policies: + self.preferred_auth_policies.append(uri) + + # max_auth_age is base-10 integer number of seconds + max_auth_age_str = args.get('max_auth_age') + self.max_auth_age = None + + if max_auth_age_str: + try: + self.max_auth_age = int(max_auth_age_str) + except ValueError: + if strict: + raise + + # Parse auth level information + preferred_auth_level_types = args.get('preferred_auth_level_types') + if preferred_auth_level_types: + aliases = preferred_auth_level_types.strip().split() + + for alias in aliases: + key = 'auth_level.ns.%s' % (alias,) + try: + uri = args[key] + except KeyError: + if is_openid1: + uri = self._default_auth_level_aliases.get(alias) + else: + uri = None + + if uri is None: + if strict: + raise ValueError('preferred auth level %r is not ' + 'defined in this message' % (alias,)) + else: + self.addAuthLevel(uri, alias) + + def preferredTypes(self, supported_types): + """Given a list of authentication policy URIs that a provider + supports, this method returns the subsequence of those types + that are preferred by the relying party. + + @param supported_types: A sequence of authentication policy + type URIs that are supported by a provider + + @returns: The sub-sequence of the supported types that are + preferred by the relying party. This list will be ordered + in the order that the types appear in the supported_types + sequence, and may be empty if the provider does not prefer + any of the supported authentication types. + + @returntype: [str] + """ + return [i for i in supported_types if i in self.preferred_auth_policies] + + +Request.ns_uri = ns_uri + + +class Response(PAPEExtension): + """A Provider Authentication Policy response, sent from a provider + to a relying party + + @ivar auth_policies: List of authentication policies conformed to + by this OpenID assertion, represented as policy URIs + """ + + ns_alias = 'pape' + + def __init__(self, auth_policies=None, auth_time=None, + auth_levels=None): + super(Response, self).__init__() + if auth_policies: + self.auth_policies = auth_policies + else: + self.auth_policies = [] + + self.auth_time = auth_time + self.auth_levels = {} + + if auth_levels is None: + auth_levels = {} + + for uri, level in auth_levels.iteritems(): + self.setAuthLevel(uri, level) + + def setAuthLevel(self, level_uri, level, alias=None): + """Set the value for the given auth level type. + + @param level: string representation of an authentication level + valid for level_uri + + @param alias: An optional namespace alias for the given auth + level URI. May be omitted if the alias is not + significant. The library will use a reasonable default for + widely-used auth level types. + """ + self._addAuthLevelAlias(level_uri, alias) + self.auth_levels[level_uri] = level + + def getAuthLevel(self, level_uri): + """Return the auth level for the specified auth level + identifier + + @returns: A string that should map to the auth levels defined + for the auth level type + + @raises KeyError: If the auth level type is not present in + this message + """ + return self.auth_levels[level_uri] + + @property + def nist_auth_level(self): + """Backward-compatibility accessor for the NIST auth level.""" + try: + return int(self.getAuthLevel(LEVELS_NIST)) + except KeyError: + return None + + def addPolicyURI(self, policy_uri): + """Add a authentication policy to this response + + This method is intended to be used by the provider to add a + policy that the provider conformed to when authenticating the user. + + @param policy_uri: The identifier for the preferred type of + authentication. + @see: http://openid.net/specs/openid-provider-authentication-policy-extension-1_0-01.html#auth_policies + """ + if policy_uri == AUTH_NONE: + raise RuntimeError( + 'To send no policies, do not set any on the response.') + + if policy_uri not in self.auth_policies: + self.auth_policies.append(policy_uri) + + @classmethod + def fromSuccessResponse(cls, success_response): + """Create a C{L{Response}} object from a successful OpenID + library response + (C{L{openid.consumer.consumer.SuccessResponse}}) response + message + + @param success_response: A SuccessResponse from consumer.complete() + @type success_response: C{L{openid.consumer.consumer.SuccessResponse}} + + @rtype: Response or None + @returns: A provider authentication policy response from the + data that was supplied with the C{id_res} response or None + if the provider sent no signed PAPE response arguments. + """ + self = cls() + + # PAPE requires that the args be signed. + args = success_response.getSignedNS(self.ns_uri) + is_openid1 = success_response.isOpenID1() + + # Only try to construct a PAPE response if the arguments were + # signed in the OpenID response. If not, return None. + if args is not None: + self.parseExtensionArgs(args, is_openid1) + return self + else: + return None + + def parseExtensionArgs(self, args, is_openid1, strict=False): + """Parse the provider authentication policy arguments into the + internal state of this object + + @param args: unqualified provider authentication policy + arguments + + @param strict: Whether to raise an exception when bad data is + encountered + + @returns: None. The data is parsed into the internal fields of + this object. + """ + policies_str = args.get('auth_policies') + if policies_str: + auth_policies = policies_str.split(' ') + elif strict: + raise ValueError('Missing auth_policies') + else: + auth_policies = [] + + if (len(auth_policies) > 1 and strict and AUTH_NONE in auth_policies): + raise ValueError('Got some auth policies, as well as the special ' + '"none" URI: %r' % (auth_policies,)) + + if 'none' in auth_policies: + msg = '"none" used as a policy URI (see PAPE draft < 5)' + if strict: + raise ValueError(msg) + else: + warnings.warn(msg, stacklevel=2) + + auth_policies = [u for u in auth_policies + if u not in ['none', AUTH_NONE]] + + self.auth_policies = auth_policies + + for (key, val) in args.iteritems(): + if key.startswith('auth_level.'): + alias = key[11:] + + # skip the already-processed namespace declarations + if alias.startswith('ns.'): + continue + + try: + uri = args['auth_level.ns.%s' % (alias,)] + except KeyError: + if is_openid1: + uri = self._default_auth_level_aliases.get(alias) + else: + uri = None + + if uri is None: + if strict: + raise ValueError( + 'Undefined auth level alias: %r' % (alias,)) + else: + self.setAuthLevel(uri, val, alias) + + auth_time = args.get('auth_time') + if auth_time: + if TIME_VALIDATOR.match(auth_time): + self.auth_time = auth_time + elif strict: + raise ValueError("auth_time must be in RFC3339 format") + + def getExtensionArgs(self): + """@see: C{L{Extension.getExtensionArgs}} + """ + if len(self.auth_policies) == 0: + ns_args = { + 'auth_policies': AUTH_NONE, + } + else: + ns_args = { + 'auth_policies': ' '.join(self.auth_policies), + } + + for level_type, level in self.auth_levels.iteritems(): + alias = self._getAlias(level_type) + ns_args['auth_level.ns.%s' % (alias,)] = level_type + ns_args['auth_level.%s' % (alias,)] = str(level) + + if self.auth_time is not None: + if not TIME_VALIDATOR.match(self.auth_time): + raise ValueError('auth_time must be in RFC3339 format') + + ns_args['auth_time'] = self.auth_time + + return ns_args + + +Response.ns_uri = ns_uri diff --git a/openid/test/test_pape.py b/openid/test/test_pape.py index 0507b2c8..056fb891 100644 --- a/openid/test/test_pape.py +++ b/openid/test/test_pape.py @@ -1,10 +1,390 @@ - import unittest +import warnings from openid.extensions import pape +from openid.message import OPENID2_NS, Message +from openid.server import server + +warnings.filterwarnings('ignore', module=__name__, + message='"none" used as a policy URI') + + +class PapeRequestTestCase(unittest.TestCase): + def setUp(self): + self.req = pape.Request() + + def test_construct(self): + self.assertEqual(self.req.preferred_auth_policies, []) + self.assertIsNone(self.req.max_auth_age) + self.assertEqual(self.req.ns_alias, 'pape') + self.assertFalse(self.req.preferred_auth_level_types) + + bogus_levels = ['http://janrain.com/our_levels'] + req2 = pape.Request( + [pape.AUTH_MULTI_FACTOR], 1000, bogus_levels) + self.assertEqual(req2.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR]) + self.assertEqual(req2.max_auth_age, 1000) + self.assertEqual(req2.preferred_auth_level_types, bogus_levels) + + def test_addAuthLevel(self): + self.req.addAuthLevel('http://example.com/', 'example') + self.assertEqual(self.req.preferred_auth_level_types, ['http://example.com/']) + self.assertEqual(self.req.auth_level_aliases['example'], 'http://example.com/') + + self.req.addAuthLevel('http://example.com/1', 'example1') + self.assertEqual(self.req.preferred_auth_level_types, ['http://example.com/', 'http://example.com/1']) + + self.req.addAuthLevel('http://example.com/', 'exmpl') + self.assertEqual(self.req.preferred_auth_level_types, ['http://example.com/', 'http://example.com/1']) + + self.req.addAuthLevel('http://example.com/', 'example') + self.assertEqual(self.req.preferred_auth_level_types, ['http://example.com/', 'http://example.com/1']) + + self.assertRaises(KeyError, self.req.addAuthLevel, 'http://example.com/2', 'example') + + # alias is None; we expect a new one to be generated. + uri = 'http://another.example.com/' + self.req.addAuthLevel(uri) + self.assert_(uri in self.req.auth_level_aliases.values()) + + # We don't expect a new alias to be generated if one already + # exists. + before_aliases = self.req.auth_level_aliases.keys() + self.req.addAuthLevel(uri) + after_aliases = self.req.auth_level_aliases.keys() + self.assertEqual(after_aliases, before_aliases) + + def test_add_policy_uri(self): + self.assertEqual(self.req.preferred_auth_policies, []) + self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) + self.assertEqual(self.req.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR]) + self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) + self.assertEqual(self.req.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR]) + self.req.addPolicyURI(pape.AUTH_PHISHING_RESISTANT) + self.assertEqual(self.req.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) + self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) + self.assertEqual(self.req.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) + + def test_getExtensionArgs(self): + self.assertEqual(self.req.getExtensionArgs(), {'preferred_auth_policies': ''}) + self.req.addPolicyURI('http://uri') + self.assertEqual(self.req.getExtensionArgs(), {'preferred_auth_policies': 'http://uri'}) + self.req.addPolicyURI('http://zig') + self.assertEqual(self.req.getExtensionArgs(), {'preferred_auth_policies': 'http://uri http://zig'}) + self.req.max_auth_age = 789 + self.assertEqual(self.req.getExtensionArgs(), + {'preferred_auth_policies': 'http://uri http://zig', 'max_auth_age': '789'}) + + def test_getExtensionArgsWithAuthLevels(self): + uri = 'http://example.com/auth_level' + alias = 'my_level' + self.req.addAuthLevel(uri, alias) + + uri2 = 'http://example.com/auth_level_2' + alias2 = 'my_level_2' + self.req.addAuthLevel(uri2, alias2) + + expected_args = { + ('auth_level.ns.%s' % alias): uri, + ('auth_level.ns.%s' % alias2): uri2, + 'preferred_auth_level_types': ' '.join([alias, alias2]), + 'preferred_auth_policies': '', + } + + self.assertEqual(self.req.getExtensionArgs(), expected_args) + + def test_parseExtensionArgsWithAuthLevels(self): + uri = 'http://example.com/auth_level' + alias = 'my_level' + + uri2 = 'http://example.com/auth_level_2' + alias2 = 'my_level_2' + + request_args = { + ('auth_level.ns.%s' % alias): uri, + ('auth_level.ns.%s' % alias2): uri2, + 'preferred_auth_level_types': ' '.join([alias, alias2]), + 'preferred_auth_policies': '', + } + + # Check request object state + self.req.parseExtensionArgs(request_args, is_openid1=False, strict=False) + + expected_auth_levels = [uri, uri2] + + self.assertEqual(self.req.preferred_auth_level_types, expected_auth_levels) + self.assertEqual(self.req.auth_level_aliases[alias], uri) + self.assertEqual(self.req.auth_level_aliases[alias2], uri2) + + def test_parseExtensionArgsWithAuthLevels_openID1(self): + request_args = { + 'preferred_auth_level_types': 'nist jisa', + } + expected_auth_levels = [pape.LEVELS_NIST, pape.LEVELS_JISA] + self.req.parseExtensionArgs(request_args, is_openid1=True) + self.assertEqual(self.req.preferred_auth_level_types, expected_auth_levels) + + self.req = pape.Request() + self.req.parseExtensionArgs(request_args, is_openid1=False) + self.assertEqual(self.req.preferred_auth_level_types, []) + + self.req = pape.Request() + self.assertRaises(ValueError, self.req.parseExtensionArgs, request_args, is_openid1=False, strict=True) + + def test_parseExtensionArgs_ignoreBadAuthLevels(self): + request_args = {'preferred_auth_level_types': 'monkeys'} + self.req.parseExtensionArgs(request_args, False) + self.assertEqual(self.req.preferred_auth_level_types, []) + + def test_parseExtensionArgs_strictBadAuthLevels(self): + request_args = {'preferred_auth_level_types': 'monkeys'} + self.assertRaises(ValueError, self.req.parseExtensionArgs, request_args, is_openid1=False, strict=True) + + def test_parseExtensionArgs(self): + args = {'preferred_auth_policies': 'http://foo http://bar', + 'max_auth_age': '9'} + self.req.parseExtensionArgs(args, False) + self.assertEqual(self.req.max_auth_age, 9) + self.assertEqual(self.req.preferred_auth_policies, ['http://foo', 'http://bar']) + self.assertEqual(self.req.preferred_auth_level_types, []) + + def test_parseExtensionArgs_strict_bad_auth_age(self): + args = {'max_auth_age': 'not an int'} + self.assertRaises(ValueError, self.req.parseExtensionArgs, args, is_openid1=False, strict=True) + + def test_parseExtensionArgs_empty(self): + self.req.parseExtensionArgs({}, False) + self.assertIsNone(self.req.max_auth_age) + self.assertEqual(self.req.preferred_auth_policies, []) + self.assertEqual(self.req.preferred_auth_level_types, []) + + def test_fromOpenIDRequest(self): + policy_uris = [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT] + openid_req_msg = Message.fromOpenIDArgs({ + 'mode': 'checkid_setup', + 'ns': OPENID2_NS, + 'ns.pape': pape.ns_uri, + 'pape.preferred_auth_policies': ' '.join(policy_uris), + 'pape.max_auth_age': '5476' + }) + oid_req = server.OpenIDRequest() + oid_req.message = openid_req_msg + req = pape.Request.fromOpenIDRequest(oid_req) + self.assertEqual(req.preferred_auth_policies, policy_uris) + self.assertEqual(req.max_auth_age, 5476) + + def test_fromOpenIDRequest_no_pape(self): + message = Message() + openid_req = server.OpenIDRequest() + openid_req.message = message + pape_req = pape.Request.fromOpenIDRequest(openid_req) + assert(pape_req is None) + + def test_preferred_types(self): + self.req.addPolicyURI(pape.AUTH_PHISHING_RESISTANT) + self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) + pt = self.req.preferredTypes([pape.AUTH_MULTI_FACTOR, + pape.AUTH_MULTI_FACTOR_PHYSICAL]) + self.assertEqual(pt, [pape.AUTH_MULTI_FACTOR]) + + +class DummySuccessResponse: + def __init__(self, message, signed_stuff): + self.message = message + self.signed_stuff = signed_stuff + + def isOpenID1(self): + return False + + def getSignedNS(self, ns_uri): + return self.signed_stuff + + +class PapeResponseTestCase(unittest.TestCase): + def setUp(self): + self.resp = pape.Response() + + def test_construct(self): + self.assertEqual(self.resp.auth_policies, []) + self.assertIsNone(self.resp.auth_time) + self.assertEqual(self.resp.ns_alias, 'pape') + self.assertIsNone(self.resp.nist_auth_level) + + req2 = pape.Response([pape.AUTH_MULTI_FACTOR], + "2004-12-11T10:30:44Z", {pape.LEVELS_NIST: 3}) + self.assertEqual(req2.auth_policies, [pape.AUTH_MULTI_FACTOR]) + self.assertEqual(req2.auth_time, "2004-12-11T10:30:44Z") + self.assertEqual(req2.nist_auth_level, 3) + + def test_add_policy_uri(self): + self.assertEqual(self.resp.auth_policies, []) + self.resp.addPolicyURI(pape.AUTH_MULTI_FACTOR) + self.assertEqual(self.resp.auth_policies, [pape.AUTH_MULTI_FACTOR]) + self.resp.addPolicyURI(pape.AUTH_MULTI_FACTOR) + self.assertEqual(self.resp.auth_policies, [pape.AUTH_MULTI_FACTOR]) + self.resp.addPolicyURI(pape.AUTH_PHISHING_RESISTANT) + self.assertEqual(self.resp.auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) + self.resp.addPolicyURI(pape.AUTH_MULTI_FACTOR) + self.assertEqual(self.resp.auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) + + self.assertRaises(RuntimeError, self.resp.addPolicyURI, pape.AUTH_NONE) + + def test_getExtensionArgs(self): + self.assertEqual(self.resp.getExtensionArgs(), {'auth_policies': pape.AUTH_NONE}) + self.resp.addPolicyURI('http://uri') + self.assertEqual(self.resp.getExtensionArgs(), {'auth_policies': 'http://uri'}) + self.resp.addPolicyURI('http://zig') + self.assertEqual(self.resp.getExtensionArgs(), {'auth_policies': 'http://uri http://zig'}) + self.resp.auth_time = "1776-07-04T14:43:12Z" + self.assertEqual(self.resp.getExtensionArgs(), + {'auth_policies': 'http://uri http://zig', 'auth_time': "1776-07-04T14:43:12Z"}) + self.resp.setAuthLevel(pape.LEVELS_NIST, '3') + nist_args = {'auth_policies': 'http://uri http://zig', 'auth_time': "1776-07-04T14:43:12Z", + 'auth_level.nist': '3', 'auth_level.ns.nist': pape.LEVELS_NIST} + self.assertEqual(self.resp.getExtensionArgs(), nist_args) + + def test_getExtensionArgs_error_auth_age(self): + self.resp.auth_time = "long ago" + self.assertRaises(ValueError, self.resp.getExtensionArgs) + + def test_parseExtensionArgs(self): + args = {'auth_policies': 'http://foo http://bar', + 'auth_time': '1970-01-01T00:00:00Z'} + self.resp.parseExtensionArgs(args, is_openid1=False) + self.assertEqual(self.resp.auth_time, '1970-01-01T00:00:00Z') + self.assertEqual(self.resp.auth_policies, ['http://foo', 'http://bar']) + + def test_parseExtensionArgs_valid_none(self): + args = {'auth_policies': pape.AUTH_NONE} + self.resp.parseExtensionArgs(args, is_openid1=False) + self.assertEqual(self.resp.auth_policies, []) + + def test_parseExtensionArgs_old_none(self): + args = {'auth_policies': 'none'} + self.resp.parseExtensionArgs(args, is_openid1=False) + self.assertEqual(self.resp.auth_policies, []) + + def test_parseExtensionArgs_old_none_strict(self): + args = {'auth_policies': 'none'} + self.assertRaises(ValueError, self.resp.parseExtensionArgs, args, is_openid1=False, strict=True) + + def test_parseExtensionArgs_empty(self): + self.resp.parseExtensionArgs({}, is_openid1=False) + self.assertIsNone(self.resp.auth_time) + self.assertEqual(self.resp.auth_policies, []) + + def test_parseExtensionArgs_empty_strict(self): + self.assertRaises(ValueError, self.resp.parseExtensionArgs, {}, is_openid1=False, strict=True) + + def test_parseExtensionArgs_ignore_superfluous_none(self): + policies = [pape.AUTH_NONE, pape.AUTH_MULTI_FACTOR_PHYSICAL] + + args = { + 'auth_policies': ' '.join(policies), + } + + self.resp.parseExtensionArgs(args, is_openid1=False, strict=False) + + self.assertEqual(self.resp.auth_policies, [pape.AUTH_MULTI_FACTOR_PHYSICAL]) + + def test_parseExtensionArgs_none_strict(self): + policies = [pape.AUTH_NONE, pape.AUTH_MULTI_FACTOR_PHYSICAL] + + args = { + 'auth_policies': ' '.join(policies), + } + + self.assertRaises(ValueError, self.resp.parseExtensionArgs, args, is_openid1=False, strict=True) + + def test_parseExtensionArgs_strict_bogus1(self): + args = {'auth_policies': 'http://foo http://bar', + 'auth_time': 'yesterday'} + self.assertRaises(ValueError, self.resp.parseExtensionArgs, args, is_openid1=False, strict=True) + + def test_parseExtensionArgs_openid1_strict(self): + args = {'auth_level.nist': '0', + 'auth_policies': pape.AUTH_NONE, + } + self.resp.parseExtensionArgs(args, strict=True, is_openid1=True) + self.assertEqual(self.resp.getAuthLevel(pape.LEVELS_NIST), '0') + self.assertEqual(self.resp.auth_policies, []) + + def test_parseExtensionArgs_strict_no_namespace_decl_openid2(self): + # Test the case where the namespace is not declared for an + # auth level. + args = {'auth_policies': pape.AUTH_NONE, + 'auth_level.nist': '0', + } + self.assertRaises(ValueError, self.resp.parseExtensionArgs, args, is_openid1=False, strict=True) + + def test_parseExtensionArgs_nostrict_no_namespace_decl_openid2(self): + # Test the case where the namespace is not declared for an + # auth level. + args = {'auth_policies': pape.AUTH_NONE, + 'auth_level.nist': '0', + } + self.resp.parseExtensionArgs(args, is_openid1=False, strict=False) + + # There is no namespace declaration for this auth level. + self.assertRaises(KeyError, self.resp.getAuthLevel, pape.LEVELS_NIST) + + def test_parseExtensionArgs_strict_good(self): + args = {'auth_policies': 'http://foo http://bar', + 'auth_time': '1970-01-01T00:00:00Z', + 'auth_level.nist': '0', + 'auth_level.ns.nist': pape.LEVELS_NIST} + self.resp.parseExtensionArgs(args, is_openid1=False, strict=True) + self.assertEqual(self.resp.auth_policies, ['http://foo', 'http://bar']) + self.assertEqual(self.resp.auth_time, '1970-01-01T00:00:00Z') + self.assertEqual(self.resp.nist_auth_level, 0) + + def test_parseExtensionArgs_nostrict_bogus(self): + args = {'auth_policies': 'http://foo http://bar', + 'auth_time': 'when the cows come home', + 'nist_auth_level': 'some'} + self.resp.parseExtensionArgs(args, is_openid1=False) + self.assertEqual(self.resp.auth_policies, ['http://foo', 'http://bar']) + self.assertIsNone(self.resp.auth_time) + self.assertIsNone(self.resp.nist_auth_level) + + def test_fromSuccessResponse(self): + policy_uris = [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT] + openid_req_msg = Message.fromOpenIDArgs({ + 'mode': 'id_res', + 'ns': OPENID2_NS, + 'ns.pape': pape.ns_uri, + 'pape.auth_policies': ' '.join(policy_uris), + 'pape.auth_time': '1970-01-01T00:00:00Z' + }) + signed_stuff = { + 'auth_policies': ' '.join(policy_uris), + 'auth_time': '1970-01-01T00:00:00Z' + } + oid_req = DummySuccessResponse(openid_req_msg, signed_stuff) + req = pape.Response.fromSuccessResponse(oid_req) + self.assertEqual(req.auth_policies, policy_uris) + self.assertEqual(req.auth_time, '1970-01-01T00:00:00Z') + + def test_fromSuccessResponseNoSignedArgs(self): + policy_uris = [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT] + openid_req_msg = Message.fromOpenIDArgs({ + 'mode': 'id_res', + 'ns': OPENID2_NS, + 'ns.pape': pape.ns_uri, + 'pape.auth_policies': ' '.join(policy_uris), + 'pape.auth_time': '1970-01-01T00:00:00Z' + }) + + signed_stuff = {} + + class NoSigningDummyResponse(DummySuccessResponse): + def getSignedNS(self, ns_uri): + return None + + oid_req = NoSigningDummyResponse(openid_req_msg, signed_stuff) + resp = pape.Response.fromSuccessResponse(oid_req) + self.assertIsNone(resp) -class PapeImportTestCase(unittest.TestCase): - def test_version(self): - from openid.extensions.draft import pape5 - self.assert_(pape is pape5) +if __name__ == '__main__': + unittest.main() diff --git a/openid/test/test_pape_draft5.py b/openid/test/test_pape_draft5.py index 104411a3..95852066 100644 --- a/openid/test/test_pape_draft5.py +++ b/openid/test/test_pape_draft5.py @@ -1,390 +1,10 @@ import unittest -import warnings -from openid.extensions.draft import pape5 as pape -from openid.message import OPENID2_NS, Message -from openid.server import server +from openid.extensions import pape -warnings.filterwarnings('ignore', module=__name__, - message='"none" used as a policy URI') - -class PapeRequestTestCase(unittest.TestCase): - def setUp(self): - self.req = pape.Request() - - def test_construct(self): - self.assertEqual(self.req.preferred_auth_policies, []) - self.assertIsNone(self.req.max_auth_age) - self.assertEqual(self.req.ns_alias, 'pape') - self.assertFalse(self.req.preferred_auth_level_types) - - bogus_levels = ['http://janrain.com/our_levels'] - req2 = pape.Request( - [pape.AUTH_MULTI_FACTOR], 1000, bogus_levels) - self.assertEqual(req2.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR]) - self.assertEqual(req2.max_auth_age, 1000) - self.assertEqual(req2.preferred_auth_level_types, bogus_levels) - - def test_addAuthLevel(self): - self.req.addAuthLevel('http://example.com/', 'example') - self.assertEqual(self.req.preferred_auth_level_types, ['http://example.com/']) - self.assertEqual(self.req.auth_level_aliases['example'], 'http://example.com/') - - self.req.addAuthLevel('http://example.com/1', 'example1') - self.assertEqual(self.req.preferred_auth_level_types, ['http://example.com/', 'http://example.com/1']) - - self.req.addAuthLevel('http://example.com/', 'exmpl') - self.assertEqual(self.req.preferred_auth_level_types, ['http://example.com/', 'http://example.com/1']) - - self.req.addAuthLevel('http://example.com/', 'example') - self.assertEqual(self.req.preferred_auth_level_types, ['http://example.com/', 'http://example.com/1']) - - self.assertRaises(KeyError, self.req.addAuthLevel, 'http://example.com/2', 'example') - - # alias is None; we expect a new one to be generated. - uri = 'http://another.example.com/' - self.req.addAuthLevel(uri) - self.assert_(uri in self.req.auth_level_aliases.values()) - - # We don't expect a new alias to be generated if one already - # exists. - before_aliases = self.req.auth_level_aliases.keys() - self.req.addAuthLevel(uri) - after_aliases = self.req.auth_level_aliases.keys() - self.assertEqual(after_aliases, before_aliases) - - def test_add_policy_uri(self): - self.assertEqual(self.req.preferred_auth_policies, []) - self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.assertEqual(self.req.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR]) - self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.assertEqual(self.req.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR]) - self.req.addPolicyURI(pape.AUTH_PHISHING_RESISTANT) - self.assertEqual(self.req.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) - self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.assertEqual(self.req.preferred_auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) - - def test_getExtensionArgs(self): - self.assertEqual(self.req.getExtensionArgs(), {'preferred_auth_policies': ''}) - self.req.addPolicyURI('http://uri') - self.assertEqual(self.req.getExtensionArgs(), {'preferred_auth_policies': 'http://uri'}) - self.req.addPolicyURI('http://zig') - self.assertEqual(self.req.getExtensionArgs(), {'preferred_auth_policies': 'http://uri http://zig'}) - self.req.max_auth_age = 789 - self.assertEqual(self.req.getExtensionArgs(), - {'preferred_auth_policies': 'http://uri http://zig', 'max_auth_age': '789'}) - - def test_getExtensionArgsWithAuthLevels(self): - uri = 'http://example.com/auth_level' - alias = 'my_level' - self.req.addAuthLevel(uri, alias) - - uri2 = 'http://example.com/auth_level_2' - alias2 = 'my_level_2' - self.req.addAuthLevel(uri2, alias2) - - expected_args = { - ('auth_level.ns.%s' % alias): uri, - ('auth_level.ns.%s' % alias2): uri2, - 'preferred_auth_level_types': ' '.join([alias, alias2]), - 'preferred_auth_policies': '', - } - - self.assertEqual(self.req.getExtensionArgs(), expected_args) - - def test_parseExtensionArgsWithAuthLevels(self): - uri = 'http://example.com/auth_level' - alias = 'my_level' - - uri2 = 'http://example.com/auth_level_2' - alias2 = 'my_level_2' - - request_args = { - ('auth_level.ns.%s' % alias): uri, - ('auth_level.ns.%s' % alias2): uri2, - 'preferred_auth_level_types': ' '.join([alias, alias2]), - 'preferred_auth_policies': '', - } - - # Check request object state - self.req.parseExtensionArgs(request_args, is_openid1=False, strict=False) - - expected_auth_levels = [uri, uri2] - - self.assertEqual(self.req.preferred_auth_level_types, expected_auth_levels) - self.assertEqual(self.req.auth_level_aliases[alias], uri) - self.assertEqual(self.req.auth_level_aliases[alias2], uri2) - - def test_parseExtensionArgsWithAuthLevels_openID1(self): - request_args = { - 'preferred_auth_level_types': 'nist jisa', - } - expected_auth_levels = [pape.LEVELS_NIST, pape.LEVELS_JISA] - self.req.parseExtensionArgs(request_args, is_openid1=True) - self.assertEqual(self.req.preferred_auth_level_types, expected_auth_levels) - - self.req = pape.Request() - self.req.parseExtensionArgs(request_args, is_openid1=False) - self.assertEqual(self.req.preferred_auth_level_types, []) - - self.req = pape.Request() - self.assertRaises(ValueError, self.req.parseExtensionArgs, request_args, is_openid1=False, strict=True) - - def test_parseExtensionArgs_ignoreBadAuthLevels(self): - request_args = {'preferred_auth_level_types': 'monkeys'} - self.req.parseExtensionArgs(request_args, False) - self.assertEqual(self.req.preferred_auth_level_types, []) - - def test_parseExtensionArgs_strictBadAuthLevels(self): - request_args = {'preferred_auth_level_types': 'monkeys'} - self.assertRaises(ValueError, self.req.parseExtensionArgs, request_args, is_openid1=False, strict=True) - - def test_parseExtensionArgs(self): - args = {'preferred_auth_policies': 'http://foo http://bar', - 'max_auth_age': '9'} - self.req.parseExtensionArgs(args, False) - self.assertEqual(self.req.max_auth_age, 9) - self.assertEqual(self.req.preferred_auth_policies, ['http://foo', 'http://bar']) - self.assertEqual(self.req.preferred_auth_level_types, []) - - def test_parseExtensionArgs_strict_bad_auth_age(self): - args = {'max_auth_age': 'not an int'} - self.assertRaises(ValueError, self.req.parseExtensionArgs, args, is_openid1=False, strict=True) - - def test_parseExtensionArgs_empty(self): - self.req.parseExtensionArgs({}, False) - self.assertIsNone(self.req.max_auth_age) - self.assertEqual(self.req.preferred_auth_policies, []) - self.assertEqual(self.req.preferred_auth_level_types, []) - - def test_fromOpenIDRequest(self): - policy_uris = [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT] - openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'checkid_setup', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.preferred_auth_policies': ' '.join(policy_uris), - 'pape.max_auth_age': '5476' - }) - oid_req = server.OpenIDRequest() - oid_req.message = openid_req_msg - req = pape.Request.fromOpenIDRequest(oid_req) - self.assertEqual(req.preferred_auth_policies, policy_uris) - self.assertEqual(req.max_auth_age, 5476) - - def test_fromOpenIDRequest_no_pape(self): - message = Message() - openid_req = server.OpenIDRequest() - openid_req.message = message - pape_req = pape.Request.fromOpenIDRequest(openid_req) - assert(pape_req is None) - - def test_preferred_types(self): - self.req.addPolicyURI(pape.AUTH_PHISHING_RESISTANT) - self.req.addPolicyURI(pape.AUTH_MULTI_FACTOR) - pt = self.req.preferredTypes([pape.AUTH_MULTI_FACTOR, - pape.AUTH_MULTI_FACTOR_PHYSICAL]) - self.assertEqual(pt, [pape.AUTH_MULTI_FACTOR]) - - -class DummySuccessResponse: - def __init__(self, message, signed_stuff): - self.message = message - self.signed_stuff = signed_stuff - - def isOpenID1(self): - return False - - def getSignedNS(self, ns_uri): - return self.signed_stuff - - -class PapeResponseTestCase(unittest.TestCase): - def setUp(self): - self.resp = pape.Response() - - def test_construct(self): - self.assertEqual(self.resp.auth_policies, []) - self.assertIsNone(self.resp.auth_time) - self.assertEqual(self.resp.ns_alias, 'pape') - self.assertIsNone(self.resp.nist_auth_level) - - req2 = pape.Response([pape.AUTH_MULTI_FACTOR], - "2004-12-11T10:30:44Z", {pape.LEVELS_NIST: 3}) - self.assertEqual(req2.auth_policies, [pape.AUTH_MULTI_FACTOR]) - self.assertEqual(req2.auth_time, "2004-12-11T10:30:44Z") - self.assertEqual(req2.nist_auth_level, 3) - - def test_add_policy_uri(self): - self.assertEqual(self.resp.auth_policies, []) - self.resp.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.assertEqual(self.resp.auth_policies, [pape.AUTH_MULTI_FACTOR]) - self.resp.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.assertEqual(self.resp.auth_policies, [pape.AUTH_MULTI_FACTOR]) - self.resp.addPolicyURI(pape.AUTH_PHISHING_RESISTANT) - self.assertEqual(self.resp.auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) - self.resp.addPolicyURI(pape.AUTH_MULTI_FACTOR) - self.assertEqual(self.resp.auth_policies, [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT]) - - self.assertRaises(RuntimeError, self.resp.addPolicyURI, pape.AUTH_NONE) - - def test_getExtensionArgs(self): - self.assertEqual(self.resp.getExtensionArgs(), {'auth_policies': pape.AUTH_NONE}) - self.resp.addPolicyURI('http://uri') - self.assertEqual(self.resp.getExtensionArgs(), {'auth_policies': 'http://uri'}) - self.resp.addPolicyURI('http://zig') - self.assertEqual(self.resp.getExtensionArgs(), {'auth_policies': 'http://uri http://zig'}) - self.resp.auth_time = "1776-07-04T14:43:12Z" - self.assertEqual(self.resp.getExtensionArgs(), - {'auth_policies': 'http://uri http://zig', 'auth_time': "1776-07-04T14:43:12Z"}) - self.resp.setAuthLevel(pape.LEVELS_NIST, '3') - nist_args = {'auth_policies': 'http://uri http://zig', 'auth_time': "1776-07-04T14:43:12Z", - 'auth_level.nist': '3', 'auth_level.ns.nist': pape.LEVELS_NIST} - self.assertEqual(self.resp.getExtensionArgs(), nist_args) - - def test_getExtensionArgs_error_auth_age(self): - self.resp.auth_time = "long ago" - self.assertRaises(ValueError, self.resp.getExtensionArgs) - - def test_parseExtensionArgs(self): - args = {'auth_policies': 'http://foo http://bar', - 'auth_time': '1970-01-01T00:00:00Z'} - self.resp.parseExtensionArgs(args, is_openid1=False) - self.assertEqual(self.resp.auth_time, '1970-01-01T00:00:00Z') - self.assertEqual(self.resp.auth_policies, ['http://foo', 'http://bar']) - - def test_parseExtensionArgs_valid_none(self): - args = {'auth_policies': pape.AUTH_NONE} - self.resp.parseExtensionArgs(args, is_openid1=False) - self.assertEqual(self.resp.auth_policies, []) - - def test_parseExtensionArgs_old_none(self): - args = {'auth_policies': 'none'} - self.resp.parseExtensionArgs(args, is_openid1=False) - self.assertEqual(self.resp.auth_policies, []) - - def test_parseExtensionArgs_old_none_strict(self): - args = {'auth_policies': 'none'} - self.assertRaises(ValueError, self.resp.parseExtensionArgs, args, is_openid1=False, strict=True) - - def test_parseExtensionArgs_empty(self): - self.resp.parseExtensionArgs({}, is_openid1=False) - self.assertIsNone(self.resp.auth_time) - self.assertEqual(self.resp.auth_policies, []) - - def test_parseExtensionArgs_empty_strict(self): - self.assertRaises(ValueError, self.resp.parseExtensionArgs, {}, is_openid1=False, strict=True) - - def test_parseExtensionArgs_ignore_superfluous_none(self): - policies = [pape.AUTH_NONE, pape.AUTH_MULTI_FACTOR_PHYSICAL] - - args = { - 'auth_policies': ' '.join(policies), - } - - self.resp.parseExtensionArgs(args, is_openid1=False, strict=False) - - self.assertEqual(self.resp.auth_policies, [pape.AUTH_MULTI_FACTOR_PHYSICAL]) - - def test_parseExtensionArgs_none_strict(self): - policies = [pape.AUTH_NONE, pape.AUTH_MULTI_FACTOR_PHYSICAL] - - args = { - 'auth_policies': ' '.join(policies), - } - - self.assertRaises(ValueError, self.resp.parseExtensionArgs, args, is_openid1=False, strict=True) - - def test_parseExtensionArgs_strict_bogus1(self): - args = {'auth_policies': 'http://foo http://bar', - 'auth_time': 'yesterday'} - self.assertRaises(ValueError, self.resp.parseExtensionArgs, args, is_openid1=False, strict=True) - - def test_parseExtensionArgs_openid1_strict(self): - args = {'auth_level.nist': '0', - 'auth_policies': pape.AUTH_NONE, - } - self.resp.parseExtensionArgs(args, strict=True, is_openid1=True) - self.assertEqual(self.resp.getAuthLevel(pape.LEVELS_NIST), '0') - self.assertEqual(self.resp.auth_policies, []) - - def test_parseExtensionArgs_strict_no_namespace_decl_openid2(self): - # Test the case where the namespace is not declared for an - # auth level. - args = {'auth_policies': pape.AUTH_NONE, - 'auth_level.nist': '0', - } - self.assertRaises(ValueError, self.resp.parseExtensionArgs, args, is_openid1=False, strict=True) - - def test_parseExtensionArgs_nostrict_no_namespace_decl_openid2(self): - # Test the case where the namespace is not declared for an - # auth level. - args = {'auth_policies': pape.AUTH_NONE, - 'auth_level.nist': '0', - } - self.resp.parseExtensionArgs(args, is_openid1=False, strict=False) - - # There is no namespace declaration for this auth level. - self.assertRaises(KeyError, self.resp.getAuthLevel, pape.LEVELS_NIST) - - def test_parseExtensionArgs_strict_good(self): - args = {'auth_policies': 'http://foo http://bar', - 'auth_time': '1970-01-01T00:00:00Z', - 'auth_level.nist': '0', - 'auth_level.ns.nist': pape.LEVELS_NIST} - self.resp.parseExtensionArgs(args, is_openid1=False, strict=True) - self.assertEqual(self.resp.auth_policies, ['http://foo', 'http://bar']) - self.assertEqual(self.resp.auth_time, '1970-01-01T00:00:00Z') - self.assertEqual(self.resp.nist_auth_level, 0) - - def test_parseExtensionArgs_nostrict_bogus(self): - args = {'auth_policies': 'http://foo http://bar', - 'auth_time': 'when the cows come home', - 'nist_auth_level': 'some'} - self.resp.parseExtensionArgs(args, is_openid1=False) - self.assertEqual(self.resp.auth_policies, ['http://foo', 'http://bar']) - self.assertIsNone(self.resp.auth_time) - self.assertIsNone(self.resp.nist_auth_level) - - def test_fromSuccessResponse(self): - policy_uris = [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT] - openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'id_res', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.auth_policies': ' '.join(policy_uris), - 'pape.auth_time': '1970-01-01T00:00:00Z' - }) - signed_stuff = { - 'auth_policies': ' '.join(policy_uris), - 'auth_time': '1970-01-01T00:00:00Z' - } - oid_req = DummySuccessResponse(openid_req_msg, signed_stuff) - req = pape.Response.fromSuccessResponse(oid_req) - self.assertEqual(req.auth_policies, policy_uris) - self.assertEqual(req.auth_time, '1970-01-01T00:00:00Z') - - def test_fromSuccessResponseNoSignedArgs(self): - policy_uris = [pape.AUTH_MULTI_FACTOR, pape.AUTH_PHISHING_RESISTANT] - openid_req_msg = Message.fromOpenIDArgs({ - 'mode': 'id_res', - 'ns': OPENID2_NS, - 'ns.pape': pape.ns_uri, - 'pape.auth_policies': ' '.join(policy_uris), - 'pape.auth_time': '1970-01-01T00:00:00Z' - }) - - signed_stuff = {} - - class NoSigningDummyResponse(DummySuccessResponse): - def getSignedNS(self, ns_uri): - return None - - oid_req = NoSigningDummyResponse(openid_req_msg, signed_stuff) - resp = pape.Response.fromSuccessResponse(oid_req) - self.assertIsNone(resp) - - -if __name__ == '__main__': - unittest.main() +class PapeImportTestCase(unittest.TestCase): + def test_version(self): + from openid.extensions.draft import pape5 + self.assertEqual(pape.Request, pape5.Request) + self.assertEqual(pape.Response, pape5.Response) From a821a555d6f9042460aa69ec6028c84cdf94ae82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 10 Apr 2018 15:03:59 +0200 Subject: [PATCH 053/151] Split as python-openid2 library --- NOTICE | 4 ---- README.md | 39 ++++++++++++--------------------------- examples/consumer.py | 2 +- setup.py | 10 +++++----- 4 files changed, 18 insertions(+), 37 deletions(-) delete mode 100644 NOTICE diff --git a/NOTICE b/NOTICE deleted file mode 100644 index e63503e9..00000000 --- a/NOTICE +++ /dev/null @@ -1,4 +0,0 @@ -Python OpenID may be obtained from -http://github.com/openid/python-openid -and we'd like to hear about how you're using this software. -Write to us at openid@janrain.com. diff --git a/README.md b/README.md index af1d6e1c..a259fb69 100644 --- a/README.md +++ b/README.md @@ -1,32 +1,23 @@ -This is the Python OpenID library. - -[![Build Status][travis-image]][travis-link] +# python-openid2 # -[travis-image]: https://secure.travis-ci.org/openid/python-openid.png?branch=master -[travis-link]: http://travis-ci.org/openid/python-openid +[![Build Status](https://travis-ci.org/ziima/python-openid.svg?branch=master)](https://travis-ci.org/ziima/python-openid) +This is the Python OpenID library. -REQUIREMENTS -============ +## REQUIREMENTS ## - Python 2.7. - lxml -INSTALLATION -============ +## INSTALLATION ## To install the base library, just run the following command: -python setup.py install +pip install python-openid2 -To run setup.py you need the distutils module from the Python standard -library; some distributions package this seperately in a "python-dev" -package. - -GETTING STARTED -=============== +## GETTING STARTED ## The examples directory includes an example server and consumer implementation. See the README file in that directory for more @@ -35,8 +26,7 @@ information on running the examples. Library documentation is available in html form in the doc directory. -LOGGING -======= +## LOGGING ## This library offers a logging hook that will record unexpected conditions that occur in library code. If a condition is recoverable, @@ -46,8 +36,7 @@ documentation for the openid.oidutil module for more on the logging hook. -DOCUMENTATION -============= +## DOCUMENTATION ## The documentation in this library is in Epydoc format, which is detailed at: @@ -55,14 +44,10 @@ detailed at: http://epydoc.sourceforge.net/ -CONTACT -======= +## CONTACT ## Send bug reports, suggestions, comments, and questions to -http://openid.net/developers/dev-mailing-lists/. +https://github.com/ziima/python-openid/issues/new If you have a bugfix or feature you'd like to contribute, don't -hesitate to send it to us. For more detailed information on how to -contribute, see - - http://openidenabled.com/contribute/ +hesitate to send it to us on GitHub. diff --git a/examples/consumer.py b/examples/consumer.py index fa6b3f01..d39a2608 100644 --- a/examples/consumer.py +++ b/examples/consumer.py @@ -424,7 +424,7 @@ def pageHeader(self, title):

%s

This example consumer uses the Python + "https://github.com/ziima/python-openid" >Python OpenID library. It just verifies that the identifier that you enter is your identifier.

diff --git a/setup.py b/setup.py index 6b2e6726..e7bb36d2 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import os import sys @@ -23,7 +24,7 @@ } setup( - name='python-openid', + name='python-openid2', version=version, description='OpenID support for servers and consumers.', long_description='''This is a set of Python packages to support use of @@ -31,7 +32,7 @@ single sign-on for your web site? Use the openid.consumer package. Want to run your own OpenID server? Check out openid.server. Includes example code and support for a variety of storage back-ends.''', - url='http://github.com/openid/python-openid', + url='https://github.com/ziima/python-openid', packages=['openid', 'openid.consumer', 'openid.server', @@ -44,9 +45,8 @@ extras_require=EXTRAS_REQUIRE, # license specified by classifier. # license=getLicense(), - author='JanRain', - author_email='openid@janrain.com', - download_url='http://github.com/openid/python-openid/tarball/%s' % (version,), + author='Vlastimil Zíma', + author_email='vlastimil.zima@gmail.com', classifiers=[ "Development Status :: 5 - Production/Stable", "Environment :: Web Environment", From 832164e534f7c01ed54296af63c83868b4e812b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 10 Apr 2018 15:08:03 +0200 Subject: [PATCH 054/151] Prepare version 2.3.0rc1 --- admin/setversion | 7 ------- openid/__init__.py | 13 +------------ setup.py | 5 +++-- 3 files changed, 4 insertions(+), 21 deletions(-) delete mode 100755 admin/setversion diff --git a/admin/setversion b/admin/setversion deleted file mode 100755 index ea2b20cb..00000000 --- a/admin/setversion +++ /dev/null @@ -1,7 +0,0 @@ -#!/usr/bin/env bash - -cat < Date: Tue, 10 Apr 2018 16:04:01 +0200 Subject: [PATCH 055/151] Add changelog --- Changelog.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 Changelog.md diff --git a/Changelog.md b/Changelog.md new file mode 100644 index 00000000..dc6f764f --- /dev/null +++ b/Changelog.md @@ -0,0 +1,27 @@ +# Changelog # + +## 2.3.0 ## + + * Prevent timing attacks on signature comparison. Thanks to Carl Howells. + * Prevent XXE attacks. + * Fix unicode errors. Thanks to Kai Lautaportti. + * Drop support for python versions < 2.7. + * Use logging module. Thanks to Attila-Mihaly Balazs. + * Allow signatory, encoder and decoder to be set for Server. Thanks to julio. + * Fix URL limit to server responses. Thanks to Rodrigo Primo. + * Fix several protocol errors. + * Add utility method to AX store extension. + * Fix curl detection. Thanks to Sergey Shepelev. + * Use setuptools. Thanks to Tres Seaver. + * Refactor `Message` class creation. + * Add `RequestsFetcher`. Thanks to Lennonka. + * Updated examples. + * Add tox for testing. Thanks to Marc Abramowitz. + * Refactor tests. + * Clean code and add static checks. + +### Deprecation ### + * `Message.setOpenIDNamespace()` method. + * `UndefinedOpenIDNamespace` exception. + * `OpenIDRequest.namespace` attribute. + * `openid.extensions.draft` packages, namely its `pape2` and `pape5` modules. From c4990b09a239f3c74c7dfafcfa830e473c7be233 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Mon, 30 Apr 2018 13:23:35 +0200 Subject: [PATCH 056/151] Update badges --- README.md | 5 ++++- setup.py | 29 +++++++++++++++++------------ 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index a259fb69..2770f60b 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,15 @@ # python-openid2 # [![Build Status](https://travis-ci.org/ziima/python-openid.svg?branch=master)](https://travis-ci.org/ziima/python-openid) +[![codecov](https://codecov.io/gh/ziima/python-openid/branch/master/graph/badge.svg)](https://codecov.io/gh/ziima/python-openid) +[![PyPI](https://img.shields.io/pypi/v/python-openid2.svg)](https://pypi.org/pypi/python-openid2/) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/python-openid2.svg)](https://pypi.org/pypi/python-openid2/) This is the Python OpenID library. ## REQUIREMENTS ## - - Python 2.7. + - Python 2.7 - lxml diff --git a/setup.py b/setup.py index 7a5d810b..89270f9c 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,21 @@ # Dependencies for Django example 'djopenid': ('django<1.11.99', ), } +CLASSIFIERS = [ + 'Development Status :: 5 - Production/Stable', + 'Environment :: Web Environment', + 'Intended Audience :: Developers', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: POSIX', + 'Programming Language :: Python', + 'Programming Language :: Python :: 2', + 'Programming Language :: Python :: 2.7', + 'Topic :: Internet :: WWW/HTTP', + 'Topic :: Internet :: WWW/HTTP :: Dynamic Content :: CGI Tools/Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + 'Topic :: System :: Systems Administration :: Authentication/Directory', +] + setup( name='python-openid2', @@ -42,22 +57,12 @@ 'openid.extensions', 'openid.extensions.draft', ], + python_requires='~=2.7', install_requires=INSTALL_REQUIRES, extras_require=EXTRAS_REQUIRE, # license specified by classifier. # license=getLicense(), author='Vlastimil Zíma', author_email='vlastimil.zima@gmail.com', - classifiers=[ - "Development Status :: 5 - Production/Stable", - "Environment :: Web Environment", - "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", - "Operating System :: POSIX", - "Programming Language :: Python", - "Topic :: Internet :: WWW/HTTP", - "Topic :: Internet :: WWW/HTTP :: Dynamic Content :: CGI Tools/Libraries", - "Topic :: Software Development :: Libraries :: Python Modules", - "Topic :: System :: Systems Administration :: Authentication/Directory", - ], + classifiers=CLASSIFIERS, ) From a1f08f9fd0cba5c7c97bdc592ca3c69919841012 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Mon, 30 Apr 2018 13:54:44 +0200 Subject: [PATCH 057/151] Update descriptions --- MANIFEST.in | 8 +------- README.md | 11 +++++++++-- setup.py | 10 ++++------ 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index a1d314b6..efa752ea 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,7 +1 @@ -include LICENSE NOTICE CHANGELOG MANIFEST.in NEWS background-associations.txt -graft admin -graft contrib -recursive-include examples README discover *.py *.html *.xml -recursive-include openid/test *.txt dhpriv n2b64 *.py -recursive-include openid/test/data * -recursive-include doc *.css *.html +include *.md diff --git a/README.md b/README.md index 2770f60b..bce1d2de 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,14 @@ [![PyPI](https://img.shields.io/pypi/v/python-openid2.svg)](https://pypi.org/pypi/python-openid2/) [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/python-openid2.svg)](https://pypi.org/pypi/python-openid2/) -This is the Python OpenID library. +Python OpenID library - OpenID support for servers and consumers. + +This is a set of Python packages to support use of the OpenID decentralized identity system in your application. +Want to enable single sign-on for your web site? +Use the `openid.consumer package`. +Want to run your own OpenID server? +Check out `openid.server`. +Includes example code and support for a variety of storage back-ends. ## REQUIREMENTS ## @@ -17,7 +24,7 @@ This is the Python OpenID library. To install the base library, just run the following command: -pip install python-openid2 + pip install python-openid2 ## GETTING STARTED ## diff --git a/setup.py b/setup.py index 89270f9c..3e8fdaa0 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ # Dependencies for Django example 'djopenid': ('django<1.11.99', ), } +LONG_DESCRIPTION = open('README.md').read() + '\n\n' + open('Changelog.md').read() CLASSIFIERS = [ 'Development Status :: 5 - Production/Stable', 'Environment :: Web Environment', @@ -42,12 +43,9 @@ setup( name='python-openid2', version=VERSION, - description='OpenID support for servers and consumers.', - long_description='''This is a set of Python packages to support use of -the OpenID decentralized identity system in your application. Want to enable -single sign-on for your web site? Use the openid.consumer package. Want to -run your own OpenID server? Check out openid.server. Includes example code -and support for a variety of storage back-ends.''', + description='Python OpenID library - OpenID support for servers and consumers.', + long_description=LONG_DESCRIPTION, + long_description_content_type='text/markdown', url='https://github.com/ziima/python-openid', packages=['openid', 'openid.consumer', From f638838a14956ca608368d2efa50e73b282a8581 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Mon, 30 Apr 2018 14:11:50 +0200 Subject: [PATCH 058/151] Bump version 2.3.0 --- openid/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openid/__init__.py b/openid/__init__.py index 69844336..e34eebfd 100644 --- a/openid/__init__.py +++ b/openid/__init__.py @@ -23,7 +23,7 @@ and limitations under the License. """ -__version__ = '2.3.0rc1' +__version__ = '2.3.0' __all__ = [ 'association', From a1f864ada10d00edc1e58e5ecb97fab6ed319a62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 9 Mar 2018 13:37:22 +0100 Subject: [PATCH 059/151] Refactor urinorm --- openid/test/data/trustroot.txt | 16 +-- openid/test/test_urinorm.py | 95 ++++++++++---- openid/test/urinorm.txt | 87 ------------- openid/urinorm.py | 220 ++++++++++++--------------------- setup.py | 1 + 5 files changed, 157 insertions(+), 262 deletions(-) delete mode 100644 openid/test/urinorm.txt diff --git a/openid/test/data/trustroot.txt b/openid/test/data/trustroot.txt index 3d948a4f..f46ec088 100644 --- a/openid/test/data/trustroot.txt +++ b/openid/test/data/trustroot.txt @@ -3,32 +3,31 @@ Trust root parsing checking ======================================== ---------------------------------------- -21: Does not parse +20: Does not parse ---------------------------------------- baz.org *.foo.com http://*.schtuff.*/ ftp://foo.com ftp://*.foo.com -http://*.foo.com:80:90/ http:/// http:// foo.*.com http://foo.*.com http://www.* http://*foo.com/ +http://.it/ +http://..it/ http://foo.com\/ http://localhost:1900foo/ http://foo.com/invalid#fragment -http://π.pi.com/ -http://lambda.com/Λ 5 ---------------------------------------- -15: Insane +13: Insane ---------------------------------------- http://*/ https://*/ @@ -43,11 +42,9 @@ http://*.museum/ https://*.museum/ http://www.schtuffcom/ http://it/ -http://..it/ -http://.it/ ---------------------------------------- -18: Sane +21: Sane ---------------------------------------- http://*.schtuff.com./ http://*.schtuff.com/ @@ -67,6 +64,9 @@ https://foo.com/ http://kink.fm/should/be/sane http://beta.lingu.no/ http://goathack.livejournal.org:8020/openid/login.bml +http://*.example.com:80:90/ +http://π.pi.example.com/ +http://lambda.example.com/Λ ======================================== return_to matching diff --git a/openid/test/test_urinorm.py b/openid/test/test_urinorm.py index 0db74eb0..50b53552 100644 --- a/openid/test/test_urinorm.py +++ b/openid/test/test_urinorm.py @@ -1,30 +1,77 @@ -import os -import unittest +# -*- coding: utf-8 -*- +"""Tests for `openid.urinorm` module.""" +from __future__ import unicode_literals -import openid.urinorm +import unittest -with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'urinorm.txt')) as test_data_file: - test_data = test_data_file.read() +from openid.urinorm import urinorm class UrinormTest(unittest.TestCase): + """Test `urinorm` function.""" + + def test_normalized(self): + self.assertEqual(urinorm('http://example.com/'), 'http://example.com/') + self.assertEqual(urinorm(b'http://example.com/'), 'http://example.com/') + + def test_lowercase_scheme(self): + self.assertEqual(urinorm('htTP://example.com/'), 'http://example.com/') + + def test_unsupported_scheme(self): + self.assertRaisesRegexp(ValueError, 'Not an absolute HTTP or HTTPS URI', urinorm, 'ftp://example.com/') + + def test_lowercase_hostname(self): + self.assertEqual(urinorm('http://exaMPLE.COm/'), 'http://example.com/') + + def test_idn_hostname(self): + self.assertEqual(urinorm('http://π.example.com/'), 'http://xn--1xa.example.com/') + + def test_empty_hostname(self): + self.assertEqual(urinorm('http://username@/'), 'http://username@/') + + def test_invalid_hostname(self): + self.assertRaisesRegexp(ValueError, 'Invalid hostname', urinorm, 'http://.it/') + self.assertRaisesRegexp(ValueError, 'Invalid hostname', urinorm, 'http://..it/') + self.assertRaisesRegexp(ValueError, 'Not an absolute URI', urinorm, 'http:///path/') + + def test_empty_port_section(self): + self.assertEqual(urinorm('http://example.com:/'), 'http://example.com/') + + def test_default_ports(self): + self.assertEqual(urinorm('http://example.com:80/'), 'http://example.com/') + self.assertEqual(urinorm('https://example.com:443/'), 'https://example.com/') + + def test_empty_path(self): + self.assertEqual(urinorm('http://example.com'), 'http://example.com/') + + def test_path_dots(self): + self.assertEqual(urinorm('http://example.com/./a'), 'http://example.com/a') + self.assertEqual(urinorm('http://example.com/../a'), 'http://example.com/a') + + self.assertEqual(urinorm('http://example.com/a/.'), 'http://example.com/a/') + self.assertEqual(urinorm('http://example.com/a/..'), 'http://example.com/') + self.assertEqual(urinorm('http://example.com/a/./'), 'http://example.com/a/') + self.assertEqual(urinorm('http://example.com/a/../'), 'http://example.com/') + + self.assertEqual(urinorm('http://example.com/a/./b'), 'http://example.com/a/b') + self.assertEqual(urinorm('http://example.com/a/../b'), 'http://example.com/b') + + self.assertEqual(urinorm('http://example.com/a/b/c/./../../g'), 'http://example.com/a/g') + self.assertEqual(urinorm('http://example.com/mid/content=5/../6'), 'http://example.com/mid/6') + + def test_path_percent_encoding(self): + self.assertEqual(urinorm('http://example.com/'), 'http://example.com/%08') + self.assertEqual(urinorm('http://example.com/Λ'), 'http://example.com/%CE%9B') + + def test_path_capitalize_percent_encoding(self): + self.assertEqual(urinorm('http://example.com/foo%2cbar'), 'http://example.com/foo%2Cbar') + + def test_path_percent_decode_unreserved(self): + self.assertEqual(urinorm('http://example.com/foo%2Dbar%2dbaz'), 'http://example.com/foo-bar-baz') + + def test_illegal_characters(self): + self.assertRaisesRegexp(ValueError, 'Illegal characters in URI', urinorm, 'http://.com/') - def runTest(self): - for case in test_data.split('\n\n'): - case = case.strip() - if not case: - continue - - desc, raw, expected = self.parse(case) - try: - actual = openid.urinorm.urinorm(raw) - except ValueError as why: - self.assertEqual(expected, 'fail', why) - else: - self.assertEqual(actual, expected, desc) - - def parse(self, full_case): - desc, case, expected = full_case.split('\n') - case = unicode(case, 'utf-8') - - return (desc, case, expected) + def test_realms(self): + # Urinorm supports OpenID realms with * in them + self.assertEqual(urinorm('http://*.example.com/'), 'http://*.example.com/') diff --git a/openid/test/urinorm.txt b/openid/test/urinorm.txt deleted file mode 100644 index a5db39e9..00000000 --- a/openid/test/urinorm.txt +++ /dev/null @@ -1,87 +0,0 @@ -Already normal form -http://example.com/ -http://example.com/ - -Add a trailing slash -http://example.com -http://example.com/ - -Remove an empty port segment -http://example.com:/ -http://example.com/ - -Remove a default port segment -http://example.com:80/ -http://example.com/ - -Capitalization in host names -http://wWw.exaMPLE.COm/ -http://www.example.com/ - -Capitalization in scheme names -htTP://example.com/ -http://example.com/ - -Capitalization in percent-escaped reserved characters -http://example.com/foo%2cbar -http://example.com/foo%2Cbar - -Unescape percent-encoded unreserved characters -http://example.com/foo%2Dbar%2dbaz -http://example.com/foo-bar-baz - -remove_dot_segments example 1 -http://example.com/a/b/c/./../../g -http://example.com/a/g - -remove_dot_segments example 2 -http://example.com/mid/content=5/../6 -http://example.com/mid/6 - -remove_dot_segments: single-dot -http://example.com/a/./b -http://example.com/a/b - -remove_dot_segments: double-dot -http://example.com/a/../b -http://example.com/b - -remove_dot_segments: leading double-dot -http://example.com/../b -http://example.com/b - -remove_dot_segments: trailing single-dot -http://example.com/a/. -http://example.com/a/ - -remove_dot_segments: trailing double-dot -http://example.com/a/.. -http://example.com/ - -remove_dot_segments: trailing single-dot-slash -http://example.com/a/./ -http://example.com/a/ - -remove_dot_segments: trailing double-dot-slash -http://example.com/a/../ -http://example.com/ - -Test of all kinds of syntax-based normalization -hTTPS://a/./b/../b/%63/%7bfoo%7d -https://a/b/c/%7Bfoo%7D - -Unsupported scheme -ftp://example.com/ -fail - -Non-absolute URI -http:/foo -fail - -Illegal character in URI -http://.com/ -fail - -Non-ascii character in URI -http://foo.com/ -fail diff --git a/openid/urinorm.py b/openid/urinorm.py index e7127d34..0da86ee4 100644 --- a/openid/urinorm.py +++ b/openid/urinorm.py @@ -1,105 +1,12 @@ -import re - -# from appendix B of rfc 3986 (http://www.ietf.org/rfc/rfc3986.txt) -uri_pattern = r'^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?' -uri_re = re.compile(uri_pattern) - -# gen-delims = ":" / "/" / "?" / "#" / "[" / "]" / "@" -# -# sub-delims = "!" / "$" / "&" / "'" / "(" / ")" -# / "*" / "+" / "," / ";" / "=" -# -# unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" - -uri_illegal_char_re = re.compile( - "[^-A-Za-z0-9:/?#[\]@!$&'()*+,;=._~%]", re.UNICODE) - -authority_pattern = r'^([^@]*@)?([^:]*)(:.*)?' -authority_re = re.compile(authority_pattern) - - -pct_encoded_pattern = r'%([0-9A-Fa-f]{2})' -pct_encoded_re = re.compile(pct_encoded_pattern) - -try: - unichr(0x10000) -except ValueError: - # narrow python build - UCSCHAR = [ - (0xA0, 0xD7FF), - (0xF900, 0xFDCF), - (0xFDF0, 0xFFEF), - ] - - IPRIVATE = [ - (0xE000, 0xF8FF), - ] -else: - UCSCHAR = [ - (0xA0, 0xD7FF), - (0xF900, 0xFDCF), - (0xFDF0, 0xFFEF), - (0x10000, 0x1FFFD), - (0x20000, 0x2FFFD), - (0x30000, 0x3FFFD), - (0x40000, 0x4FFFD), - (0x50000, 0x5FFFD), - (0x60000, 0x6FFFD), - (0x70000, 0x7FFFD), - (0x80000, 0x8FFFD), - (0x90000, 0x9FFFD), - (0xA0000, 0xAFFFD), - (0xB0000, 0xBFFFD), - (0xC0000, 0xCFFFD), - (0xD0000, 0xDFFFD), - (0xE1000, 0xEFFFD), - ] - - IPRIVATE = [ - (0xE000, 0xF8FF), - (0xF0000, 0xFFFFD), - (0x100000, 0x10FFFD), - ] - - -_unreserved = [False] * 256 -for _ in range(ord('A'), ord('Z') + 1): - _unreserved[_] = True -for _ in range(ord('0'), ord('9') + 1): - _unreserved[_] = True -for _ in range(ord('a'), ord('z') + 1): - _unreserved[_] = True -_unreserved[ord('-')] = True -_unreserved[ord('.')] = True -_unreserved[ord('_')] = True -_unreserved[ord('~')] = True - - -_escapeme_re = re.compile('[%s]' % ''.join(u'%s-%s' % (unichr(m_n[0]), unichr(m_n[1])) for m_n in UCSCHAR + IPRIVATE)) - - -def _pct_escape_unicode(char_match): - c = char_match.group() - return ''.join(['%%%X' % (ord(octet),) for octet in c.encode('utf-8')]) - - -def _pct_encoded_replace_unreserved(mo): - try: - i = int(mo.group(1), 16) - if _unreserved[i]: - return chr(i) - else: - return mo.group().upper() +"""URI normalization utilities.""" +from __future__ import unicode_literals - except ValueError: - return mo.group() +import string +import warnings +from urllib import quote, unquote, urlencode +from urlparse import parse_qsl, urlsplit, urlunsplit - -def _pct_encoded_replace(mo): - try: - return chr(int(mo.group(1), 16)) - except ValueError: - return mo.group() +import six def remove_dot_segments(path): @@ -137,65 +44,92 @@ def remove_dot_segments(path): return ''.join(result_segments) -def urinorm(uri): - if isinstance(uri, unicode): - uri = _escapeme_re.sub(_pct_escape_unicode, uri).encode('ascii') +GEN_DELIMS = ":" + "/" + "?" + "#" + "[" + "]" + "@" +SUB_DELIMS = "!" + "$" + "&" + "'" + "(" + ")" + "*" + "+" + "," + ";" + "=" +RESERVED = GEN_DELIMS + SUB_DELIMS +UNRESERVED = string.ascii_letters + string.digits + "-" + "." + "_" + "~" +# Allow "%" as percent encoding character +PERCENT_ENCODING_CHARACTER = "%" - illegal_mo = uri_illegal_char_re.search(uri) - if illegal_mo: - raise ValueError('Illegal characters in URI: %r at position %s' % - (illegal_mo.group(), illegal_mo.start())) - uri_mo = uri_re.match(uri) +def _check_disallowed_characters(uri_part, part_name): + # Roughly check the allowed characters. The check in not strict according to URI ABNF, but good enough. + # Also allow "%" for percent encoding. + if set(uri_part).difference(set(UNRESERVED + RESERVED + PERCENT_ENCODING_CHARACTER)): + raise ValueError('Illegal characters in URI {}: {}'.format(part_name, uri_part)) - scheme = uri_mo.group(2) - if scheme is None: - raise ValueError('No scheme specified') - scheme = scheme.lower() - if scheme not in ('http', 'https'): - raise ValueError('Not an absolute HTTP or HTTPS URI: %r' % (uri,)) +def urinorm(uri): + """Return normalized URI. - authority = uri_mo.group(4) - if authority is None: - raise ValueError('Not an absolute URI: %r' % (uri,)) + Normalization if performed according to RFC 3986, section 6 https://tools.ietf.org/html/rfc3986#section-6. + Supported URIs are URLs and OpenID realm URIs. - authority_mo = authority_re.match(authority) - if authority_mo is None: - raise ValueError('URI does not have a valid authority: %r' % (uri,)) + @type uri: six.text_type, six.binary_type deprecated + @rtype: six.text_type + @raise ValueError: If URI is invalid. + """ + # Transform the input to the unicode string + if isinstance(uri, six.binary_type): + warnings.warn("Binary input for urinorm is deprecated. Use text input instead.", DeprecationWarning) + uri = uri.decode('utf-8') - userinfo, host, port = authority_mo.groups() + split_uri = urlsplit(uri) - if userinfo is None: - userinfo = '' + # Normalize scheme + scheme = split_uri.scheme.lower() + if scheme not in ('http', 'https'): + raise ValueError('Not an absolute HTTP or HTTPS URI: {!r}'.format(uri)) - if '%' in host: - host = host.lower() - host = pct_encoded_re.sub(_pct_encoded_replace, host) - host = unicode(host, 'utf-8').encode('idna') - else: - host = host.lower() + # Normalize netloc + if not split_uri.netloc: + raise ValueError('Not an absolute URI: {!r}'.format(uri)) - if port: - if port == ':' or (scheme == 'http' and port == ':80') or (scheme == 'https' and port == ':443'): - port = '' + hostname = split_uri.hostname + if hostname is None: + hostname = '' else: - port = '' + hostname = hostname.lower() + # Unquote percent encoded characters + hostname = unquote(hostname) + # Quote IDN domain names + try: + hostname = hostname.encode('idna') + except ValueError as error: + raise ValueError('Invalid hostname {!r}: {}'.format(hostname, error)) + _check_disallowed_characters(hostname, 'hostname') - authority = userinfo + host + port + port = split_uri.port + if port is None: + port = '' + elif (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443): + port = '' - path = uri_mo.group(5) - path = pct_encoded_re.sub(_pct_encoded_replace_unreserved, path) + netloc = hostname + if port: + netloc = netloc + ':' + str(port) + userinfo_chunks = [i for i in (split_uri.username, split_uri.password) if i is not None] + if userinfo_chunks: + userinfo = ':'.join(userinfo_chunks) + _check_disallowed_characters(userinfo, 'userinfo') + netloc = userinfo + '@' + netloc + + # Normalize path + path = split_uri.path + # Unquote and quote - this normalizes the percent encoding + path = quote(unquote(path.encode('utf-8'))).decode('utf-8') path = remove_dot_segments(path) if not path: path = '/' + _check_disallowed_characters(path, 'path') - query = uri_mo.group(6) - if query is None: - query = '' + # Normalize query + data = parse_qsl(split_uri.query) + query = urlencode(data) + _check_disallowed_characters(query, 'query') - fragment = uri_mo.group(8) - if fragment is None: - fragment = '' + # Normalize fragment + fragment = unquote(split_uri.fragment) + _check_disallowed_characters(fragment, 'fragment') - return scheme + '://' + authority + path + query + fragment + return urlunsplit((scheme, netloc, path, query, fragment)) diff --git a/setup.py b/setup.py index 3e8fdaa0..6ac4e0e1 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,7 @@ # Import version from openid library itself VERSION = __import__('openid').__version__ INSTALL_REQUIRES = [ + 'six', 'lxml;platform_python_implementation=="CPython"', 'lxml <4.0;platform_python_implementation=="PyPy"', ] From 276bd8b0625c9c316d7cdec6b3d231e7cd224da9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 9 Mar 2018 14:24:21 +0100 Subject: [PATCH 060/151] Clean trustroot module * Remove useless code based on the changes in 'urinorm' function. --- openid/server/trustroot.py | 30 ++++-------------------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/openid/server/trustroot.py b/openid/server/trustroot.py index a71ed718..ed258953 100644 --- a/openid/server/trustroot.py +++ b/openid/server/trustroot.py @@ -19,7 +19,7 @@ import logging import re -from urlparse import urlparse, urlunparse +from urlparse import urlsplit, urlunsplit from openid import urinorm from openid.yadis import services @@ -27,7 +27,6 @@ _LOGGER = logging.getLogger(__name__) ############################################ -_protocols = ['http', 'https'] _top_level_domains = [ 'ac', 'ad', 'ae', 'aero', 'af', 'ag', 'ai', 'al', 'am', 'an', 'ao', 'aq', 'ar', 'arpa', 'as', 'asia', 'at', 'au', 'aw', @@ -89,29 +88,12 @@ def _parseURL(url): url = urinorm.urinorm(url) except ValueError: return None - proto, netloc, path, params, query, frag = urlparse(url) - if not path: - path = '/' - path = urlunparse(('', '', path, params, query, frag)) + split_url = urlsplit(url) - if ':' in netloc: - try: - host, port = netloc.split(':') - except ValueError: - return None - - if not re.match(r'\d+$', port): - return None - else: - host = netloc - port = '' + path = urlunsplit(('', '', split_url.path or '/', split_url.query, split_url.fragment)) - host = host.lower() - if not host_segment_re.match(host): - return None - - return proto, host, port, path + return split_url.scheme, split_url.hostname, split_url.port, path class TrustRoot(object): @@ -270,10 +252,6 @@ def parse(cls, trust_root): proto, host, port, path = url_parts - # check for valid prototype - if proto not in _protocols: - return None - # check for URI fragment if path.find('#') != -1: return None From cbc56b0104f29cb5cc344fa3f4d18e18cefb127d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Mon, 12 Mar 2018 09:23:07 +0100 Subject: [PATCH 061/151] Refactor IRI to URI --- openid/test/test_xri.py | 17 +++------- openid/yadis/xri.py | 71 ++++++++++++----------------------------- 2 files changed, 24 insertions(+), 64 deletions(-) diff --git a/openid/test/test_xri.py b/openid/test/test_xri.py index a5f0bfaf..341472ed 100644 --- a/openid/test/test_xri.py +++ b/openid/test/test_xri.py @@ -32,19 +32,10 @@ class XriTransformationTestCase(TestCase): def test_to_iri_normal(self): self.assertEqual(xri.toIRINormal('@example'), 'xri://@example') - try: - unichr(0x10000) - except ValueError: - # bleh narrow python build - def test_iri_to_url(self): - s = u'l\xa1m' - expected = 'l%C2%A1m' - self.assertEqual(xri.iriToURI(s), expected) - else: - def test_iri_to_url(self): - s = u'l\xa1m\U00101010n' - expected = 'l%C2%A1m%F4%81%80%90n' - self.assertEqual(xri.iriToURI(s), expected) + def test_iri_to_url(self): + s = u'l\xa1m\U00101010n' + expected = 'l%C2%A1m%F4%81%80%90n' + self.assertEqual(xri.iriToURI(s), expected) class CanonicalIDTest(TestCase): diff --git a/openid/yadis/xri.py b/openid/yadis/xri.py index 60e0675b..73ad7980 100644 --- a/openid/yadis/xri.py +++ b/openid/yadis/xri.py @@ -4,53 +4,15 @@ @see: XRI Syntax v2.0 at the U{OASIS XRI Technical Committee} """ - import re +import warnings +from urllib import quote -XRI_AUTHORITIES = ['!', '=', '@', '+', '$', '('] +import six -try: - unichr(0x10000) -except ValueError: - # narrow python build - UCSCHAR = [ - (0xA0, 0xD7FF), - (0xF900, 0xFDCF), - (0xFDF0, 0xFFEF), - ] - - IPRIVATE = [ - (0xE000, 0xF8FF), - ] -else: - UCSCHAR = [ - (0xA0, 0xD7FF), - (0xF900, 0xFDCF), - (0xFDF0, 0xFFEF), - (0x10000, 0x1FFFD), - (0x20000, 0x2FFFD), - (0x30000, 0x3FFFD), - (0x40000, 0x4FFFD), - (0x50000, 0x5FFFD), - (0x60000, 0x6FFFD), - (0x70000, 0x7FFFD), - (0x80000, 0x8FFFD), - (0x90000, 0x9FFFD), - (0xA0000, 0xAFFFD), - (0xB0000, 0xBFFFD), - (0xC0000, 0xCFFFD), - (0xD0000, 0xDFFFD), - (0xE1000, 0xEFFFD), - ] - - IPRIVATE = [ - (0xE000, 0xF8FF), - (0xF0000, 0xFFFFD), - (0x100000, 0x10FFFD), - ] - - -_escapeme_re = re.compile('[%s]' % ''.join(u'%s-%s' % (unichr(m_n[0]), unichr(m_n[1])) for m_n in UCSCHAR + IPRIVATE)) +from openid.urinorm import GEN_DELIMS, PERCENT_ENCODING_CHARACTER, SUB_DELIMS + +XRI_AUTHORITIES = ['!', '=', '@', '+', '$', '('] def identifierScheme(identifier): @@ -96,15 +58,22 @@ def toURINormal(xri): return iriToURI(toIRINormal(xri)) -def _percentEscapeUnicode(char_match): - c = char_match.group() - return ''.join(['%%%X' % (ord(octet),) for octet in c.encode('utf-8')]) +def iriToURI(iri): + """Transform an IRI to a URI by escaping unicode. + According to RFC 3987, section 3.1, "Mapping of IRIs to URIs" -def iriToURI(iri): - """Transform an IRI to a URI by escaping unicode.""" - # According to RFC 3987, section 3.1, "Mapping of IRIs to URIs" - return _escapeme_re.sub(_percentEscapeUnicode, iri) + @type iri: six.text_type, six.binary_type deprecated. + @rtype: six.text_type + """ + # Transform the input to the binary string. `quote` doesn't quote correctly unicode strings. + if isinstance(iri, six.text_type): + iri = iri.encode('utf-8') + else: + assert isinstance(iri, six.binary_type) + warnings.warn("Binary input for iriToURI is deprecated. Use text input instead.", DeprecationWarning) + + return quote(iri, (GEN_DELIMS + SUB_DELIMS + PERCENT_ENCODING_CHARACTER).encode('utf-8')).decode('utf-8') def providerIsAuthoritative(providerID, canonicalID): From 0123c0480ad84fd4f7a6cc7ddf0e40c75af98134 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 18 Apr 2018 14:49:31 +0200 Subject: [PATCH 062/151] Refactor YADIS meta tag parsing --- openid/test/data/test1-parsehtml.txt | 152 ------------------- openid/test/test_discover.py | 5 +- openid/test/test_parsehtml.py | 151 +++++++++++++------ openid/yadis/parsehtml.py | 215 ++++----------------------- 4 files changed, 144 insertions(+), 379 deletions(-) delete mode 100644 openid/test/data/test1-parsehtml.txt diff --git a/openid/test/data/test1-parsehtml.txt b/openid/test/data/test1-parsehtml.txt deleted file mode 100644 index 20791e10..00000000 --- a/openid/test/data/test1-parsehtml.txt +++ /dev/null @@ -1,152 +0,0 @@ -found - - - -found - - - -found - - - -found - - - -found - - - -found - - - -found - - - -found - - - -EOF - -' + '') + self.assertEqual(findHTMLMeta(buff), 'found') + def test_multiple_headers(self): + buff = StringIO('' + '' + '') + self.assertEqual(findHTMLMeta(buff), 'found') -class TestParseHTML(unittest.TestCase): - reserved_values = ['None', 'EOF'] + def test_standard_entity(self): + buff = StringIO('') + self.assertEqual(findHTMLMeta(buff), '&') - def test(self): - for expected, case in getCases(): - p = YadisHTMLParser() - try: - p.feed(case) - except ParseDone as why: - found = why[0] + def test_hex_entity(self): + buff = StringIO('') + self.assertEqual(findHTMLMeta(buff), 'found') - # make sure we protect outselves against accidental bogus - # test cases - assert found not in self.reserved_values + def test_decimal_entity(self): + buff = StringIO('') + self.assertEqual(findHTMLMeta(buff), 'found') - # convert to a string - if found is None: - found = 'None' + def test_empty_string(self): + buff = StringIO('') + self.assertEqual(findHTMLMeta(buff), '') - msg = "%r != %r for case %s" % (found, expected, case) - self.assertEqual(found, expected, msg) - except HTMLParseError: - self.assertEqual(expected, 'None', (case, expected)) - else: - self.assertEqual(expected, 'EOF', (case, expected)) + def test_empty_input(self): + buff = StringIO('') + self.assertRaises(MetaNotFound, findHTMLMeta, buff) + def test_invalid_html(self): + buff = StringIO('') + self.assertRaises(MetaNotFound, findHTMLMeta, buff) -def parseCases(data): - cases = [] - for chunk in data.split('\f\n'): - expected, case = chunk.split('\n', 1) - cases.append((expected, case)) - return cases + def test_meta_in_body(self): + buff = StringIO('') + self.assertRaises(MetaNotFound, findHTMLMeta, buff) + def test_no_content(self): + buff = StringIO('') + self.assertRaises(MetaNotFound, findHTMLMeta, buff) -filenames = ['data/test1-parsehtml.txt'] + def test_commented_header(self): + buff = StringIO('' + '' + '') + self.assertRaises(MetaNotFound, findHTMLMeta, buff) -default_test_files = [] -base = os.path.dirname(__file__) -for filename in filenames: - full_name = os.path.join(base, filename) - default_test_files.append(full_name) + def test_no_yadis_header(self): + buff = StringIO("A boring document" + "

A boring document

There's really nothing interesting about this

" + "") + self.assertRaises(MetaNotFound, findHTMLMeta, buff) + def test_unclosed_tag(self): + # script tag not closed + buff = StringIO(' - -''', flags) - -tag_expr = r''' -# Starts with the tag name at a word boundary, where the tag name is -# not a namespace -<%(tag_name)s\b(?!:) - -# All of the stuff up to a ">", hopefully attributes. -(?P[^>]*?) - -(?: # Match a short tag - /> - -| # Match a full tag - > - - (?P.*?) - - # Closed by - (?: # One of the specified close tags - - - # End of the string - | \Z - - ) - -) -''' - - -def tagMatcher(tag_name, *close_tags): - if close_tags: - options = '|'.join((tag_name,) + close_tags) - closers = '(?:%s)' % (options,) - else: - closers = tag_name - - expr = tag_expr % locals() - return re.compile(expr, flags) - - -# Must contain at least an open html and an open head tag -html_find = tagMatcher('html') -head_find = tagMatcher('head', 'body') -link_find = re.compile(r'\w+)= - -# Then either a quoted or unquoted attribute -(?: - - # Match everything that\'s between matching quote marks - (?P["\'])(?P.*?)(?P=qopen) -| - - # If the value is not quoted, match up to whitespace - (?P(?:[^\s<>/]|/(?!>))+) -) - -| - -(?P[<>]) -''', flags) - -# Entity replacement: -replacements = { - 'amp': '&', - 'lt': '<', - 'gt': '>', - 'quot': '"', -} - -ent_replace = re.compile(r'&(%s);' % '|'.join(replacements.keys())) - - -def replaceEnt(mo): - "Replace the entities that are specified by OpenID" - return replacements.get(mo.group(1), mo.group()) - - -def parseLinkAttrs(html): - """Find all link tags in a string representing a HTML document and - return a list of their attributes. - - @param html: the text to parse - @type html: str or unicode - - @return: A list of dictionaries of attributes, one for each link tag - @rtype: [[(type(html), type(html))]] - """ - stripped = removed_re.sub('', html) - html_mo = html_find.search(stripped) - if html_mo is None or html_mo.start('contents') == -1: - return [] - - start, end = html_mo.span('contents') - head_mo = head_find.search(stripped, start, end) - if head_mo is None or head_mo.start('contents') == -1: - return [] - - start, end = head_mo.span('contents') - link_mos = link_find.finditer(stripped, head_mo.start(), head_mo.end()) - - matches = [] - for link_mo in link_mos: - start = link_mo.start() + 5 - link_attrs = {} - for attr_mo in attr_find.finditer(stripped, start): - if attr_mo.lastgroup == 'end_link': - break - - # Either q_val or unq_val must be present, but not both - # unq_val is a True (non-empty) value if it is present - attr_name, q_val, unq_val = attr_mo.group( - 'attr_name', 'q_val', 'unq_val') - attr_val = ent_replace.sub(replaceEnt, unq_val or q_val) - - link_attrs[attr_name] = attr_val - - matches.append(link_attrs) - - return matches - - -def relMatches(rel_attr, target_rel): - """Does this target_rel appear in the rel_str?""" - # XXX: TESTME - rels = rel_attr.strip().split() - for rel in rels: - rel = rel.lower() - if rel == target_rel: - return 1 - - return 0 - - -def linkHasRel(link_attrs, target_rel): - """Does this link have target_rel as a relationship?""" - # XXX: TESTME - rel_attr = link_attrs.get('rel') - return rel_attr and relMatches(rel_attr, target_rel) - - -def findLinksRel(link_attrs_list, target_rel): - """Filter the list of link attributes on whether it has target_rel - as a relationship.""" - # XXX: TESTME - matchesTarget = partial(linkHasRel, target_rel=target_rel) - return [i for i in link_attrs_list if matchesTarget(i)] - - -def findFirstHref(link_attrs_list, target_rel): - """Return the value of the href attribute for the first link tag - in the list that has target_rel as a relationship.""" - # XXX: TESTME - matches = findLinksRel(link_attrs_list, target_rel) - if not matches: - return None - first = matches[0] - return first.get('href') diff --git a/openid/test/linkparse.txt b/openid/test/linkparse.txt deleted file mode 100644 index 74c63ca7..00000000 --- a/openid/test/linkparse.txt +++ /dev/null @@ -1,584 +0,0 @@ -Num Tests: 72 - -OpenID link parsing test cases -Copyright (C) 2005-2008, JanRain, Inc. -See COPYING for license information. - -File format ------------ - -All text before the first triple-newline (this chunk) should be ignored. - -This file may be interpreted as Latin-1 or UTF-8. - -Test cases separated by three line separators (`\n\n\n'). The test -cases consist of a headers section followed by a data block. These are -separated by a double newline. The headers consist of the header name, -followed by a colon, a space, the value, and a newline. There must be -one, and only one, `Name' header for a test case. There may be zero or -more link headers. The `Link' header consists of whitespace-separated -attribute pairs. A link header with an empty string as a value -indicates an empty but present link tag. The attribute pairs are `=' -separated and not quoted. - -Optional Links and attributes have a trailing `*'. A compilant -implementation may produce this as output or may not. A compliant -implementation will not produce any output that is absent from this -file. - - -Name: No link tag at all - - - - - - - -Name: Link element first - - - - -Name: Link inside HTML, not head - - - - - -Name: Link inside head, not html - - - - - -Name: Link inside html, after head - - - - - - - -Name: Link inside html, before head - - - - - - -Name: Link before html and head - - - - - - -Name: Link after html document with head - - - - - - - - -Name: Link inside html inside head, inside another html - - - - - - - -Name: Link inside html inside head - - - - - - -Name: link inside body inside head inside html - - - - - - - -Name: Link inside head inside head inside html - - - - - - - -Name: Link inside script inside head inside html - - - - - - -Name: Link inside comment inside head inside html - - - - - - -Name: Link inside of head after short head - - - - - - - -Name: Plain vanilla -Link: - - - - - - -Name: Ignore tags in the namespace -Link*: - - - - - - - - -Name: Short link tag -Link: - - - - - - -Name: Spaces in the HTML tag -Link: - - - - - - -Name: Spaces in the head tag -Link: - - - - - - -Name: Spaces in the link tag -Link: - - - - - - -Name: No whitespace -Link: - - - - -Name: Closed head tag -Link: - - - - - - - -Name: One good, one bad (after close head) -Link: - - - - - - - - -Name: One good, one bad (after open body) -Link: - - - - - - - - -Name: ill formed (missing close head) -Link: - - - - - - - -Name: Ill formed (no close head, link after ) -Link: - - - - - - - - -Name: Ignore random tags inside of html -Link: - - - - - -<link> - - -Name: case-folding -Link*: - -<HtMl> -<hEaD> -<LiNk> - - -Name: unexpected tags -Link: - -<butternut> -<html> -<summer> -<head> -<turban> -<link> - - -Name: un-closed script tags -Link*: - -<html> -<head> -<script> -<link> - - -Name: un-closed script tags (no whitespace) -Link*: - -<html><head><script><link> - - -Name: un-closed comment -Link*: - -<html> -<head> -<!-- -<link> - - -Name: un-closed CDATA -Link*: - -<html> -<head> -<![CDATA[ -<link> - - -Name: cdata-like -Link*: - -<html> -<head> -<![ACORN[ -<link> -]]> - - -Name: comment close only -Link: - -<html> -<head> -<link> ---> - - -Name: Vanilla, two links -Link: -Link: - -<html> -<head> -<link> -<link> - - -Name: extra tag, two links -Link: -Link: - -<html> -<gold nugget> -<head> -<link> -<link> - - -Name: case-fold, body ends, two links -Link: -Link*: - -<html> -<head> -<link> -<LiNk> -<body> -<link> - - -Name: simple, non-quoted rel -Link: rel=openid.server - -<html><head><link rel=openid.server> - - -Name: short tag has rel -Link: rel=openid.server - -<html><head><link rel=openid.server/> - - -Name: short tag w/space has rel -Link: rel=openid.server - -<html><head><link rel=openid.server /> - - -Name: extra non-attribute, has rel -Link: rel=openid.server - -<html><head><link hubbard rel=openid.server> - - -Name: non-attr, has rel, short -Link: rel=openid.server - -<html><head><link hubbard rel=openid.server/> - - -Name: non-attr, has rel, short, space -Link: rel=openid.server - -<html><head><link hubbard rel=openid.server /> - - -Name: misplaced slash has rel -Link: rel=openid.server - -<html><head><link / rel=openid.server> - - -Name: quoted rel -Link: rel=openid.server - -<html><head><link rel="openid.server"> - - -Name: single-quoted rel -Link: rel=openid.server - -<html><head><link rel='openid.server'> - - -Name: two links w/ rel -Link: x=y -Link: a=b - -<html><head><link x=y><link a=b> - - -Name: non-entity -Link: x=&y - -<html><head><link x=&y> - - -Name: quoted non-entity -Link: x=&y - -<html><head><link x="&y"> - - -Name: quoted entity -Link: x=& - -<html><head><link x="&"> - - -Name: entity not processed -Link: x= - -<html><head><link x=""> - - -Name: < -Link: x=< - -<html><head><link x="<"> - - -Name: > -Link: x=> - -<html><head><link x=">"> - - -Name: " -Link: x=" - -<html><head><link x="""> - - -Name: &" -Link: x=&" - -<html><head><link x="&""> - - -Name: mixed entity and non-entity -Link: x=&"…> - -<html><head><link x="&"…>"> - - -Name: mixed entity and non-entity (w/normal chars) -Link: x=x&"…>x - -<html><head><link x="x&"…>x"> - - -Name: broken tags -Link*: x=y - -<html><head><link x=y<> - - -Name: missing close pointy -Link*: x=y -Link: z=y - -<html><head><link x=y<link z=y /> - - -Name: missing attribute value -Link: x=y y*= -Link: x=y - -<html><head><link x=y y=><link x=y /> - - -Name: Missing close pointy (no following) -Link*: x=y - -<html><head><link x=y - - -Name: Should be quoted -Link*: x=< - -<html><head><link x="<"> - - -Name: Should be quoted (2) -Link*: x=> - -<html><head><link x=">"> - - -Name: Repeated attribute -Link: x=y - -<html><head><link x=z x=y> - - -Name: Repeated attribute (2) -Link: x=y - -<html><head><link x=y x=y> - - -Name: Two attributes -Link: x=y y=z - -<html><head><link x=y y=z> - - -Name: Well-formed link rel="openid.server" -Link: rel=openid.server href=http://www.myopenid.com/server - -<html> - <head> - <link rel="openid.server" - href="http://www.myopenid.com/server" /> - </head> -</html> - - -Name: Well-formed link rel="openid.server" and "openid.delegate" -Link: rel=openid.server href=http://www.myopenid.com/server -Link: rel=openid.delegate href=http://example.myopenid.com/ - -<html><head><link rel="openid.server" - href="http://www.myopenid.com/server" /> - <link rel="openid.delegate" href="http://example.myopenid.com/" /> -</head></html> - - -Name: from brian's livejournal page -Link: rel=stylesheet href=http://www.livejournal.com/~serotta/res/319998/stylesheet?1130478711 type=text/css -Link: rel=openid.server href=http://www.livejournal.com/openid/server.bml - -<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" - "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"> -<html xmlns="http://www.w3.org/1999/xhtml"> - <head> - <link rel="stylesheet" - href="http://www.livejournal.com/~serotta/res/319998/stylesheet?1130478711" - type="text/css" /> - <meta http-equiv="Content-Type" content="text/html; charset=utf-8" /> - <meta name="foaf:maker" - content="foaf:mbox_sha1sum '12f8abdacb5b1a806711e23249da592c0d316260'" /> - <meta name="robots" content="noindex, nofollow, noarchive" /> - <meta name="googlebot" content="nosnippet" /> - <link rel="openid.server" - href="http://www.livejournal.com/openid/server.bml" /> - <title>Brian - - - -Name: non-ascii (Latin-1 or UTF8) -Link: x=® - - - - diff --git a/openid/test/test_htmldiscover.py b/openid/test/test_htmldiscover.py index 65b036f1..b4caeb30 100644 --- a/openid/test/test_htmldiscover.py +++ b/openid/test/test_htmldiscover.py @@ -3,14 +3,15 @@ from openid.consumer.discover import OpenIDServiceEndpoint -class BadLinksTestCase(unittest.TestCase): - cases = [ - '', - "http://not.in.a.link.tag/", - '', - ] +class TestFromHTML(unittest.TestCase): + """Test `OpenIDServiceEndpoint.fromHTML`.""" - def test_from_html(self): - for html in self.cases: - actual = OpenIDServiceEndpoint.fromHTML('http://unused.url/', html) - self.assertEqual(actual, []) + def test_empty(self): + self.assertEqual(OpenIDServiceEndpoint.fromHTML('http://example.url/', ''), []) + + def test_invalid_html(self): + self.assertEqual(OpenIDServiceEndpoint.fromHTML('http://example.url/', "http://not.in.a.link.tag/"), []) + + def test_no_op_url(self): + html = '' + self.assertEqual(OpenIDServiceEndpoint.fromHTML('http://example.url/', html), []) diff --git a/openid/test/test_linkparse.py b/openid/test/test_linkparse.py deleted file mode 100644 index 077caaf4..00000000 --- a/openid/test/test_linkparse.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Test `openid.consumer.html_parse` module.""" -import os.path -import unittest - -from openid.consumer.html_parse import parseLinkAttrs - - -def parseLink(line): - parts = line.split() - optional = parts[0] == 'Link*:' - assert optional or parts[0] == 'Link:' - - attrs = {} - for attr in parts[1:]: - k, v = attr.split('=', 1) - if k[-1] == '*': - attr_optional = 1 - k = k[:-1] - else: - attr_optional = 0 - - attrs[k] = (attr_optional, v) - - return (optional, attrs) - - -def parseCase(s): - header, markup = s.split('\n\n', 1) - lines = header.split('\n') - name = lines.pop(0) - assert name.startswith('Name: ') - desc = name[6:] - return desc, markup, [parseLink(l) for l in lines] - - -def parseTests(s): - tests = [] - - cases = s.split('\n\n\n') - header = cases.pop(0) - tests_line, _ = header.split('\n', 1) - k, v = tests_line.split(': ') - assert k == 'Num Tests' - num_tests = int(v) - - for case in cases[:-1]: - desc, markup, links = parseCase(case) - tests.append((desc, markup, links, case)) - - assert len(tests) == num_tests, (len(tests), num_tests) - return num_tests, tests - - -with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'linkparse.txt')) as link_test_data_file: - link_test_data = link_test_data_file.read().decode('utf-8') - - -class LinkTest(unittest.TestCase): - """Test `parseLinkAttrs` function.""" - - def runTest(self): - num_tests, test_cases = parseTests(link_test_data) - - for desc, case, expected, raw in test_cases: - actual = parseLinkAttrs(case) - i = 0 - for optional, exp_link in expected: - if optional: - if i >= len(actual): - continue - - act_link = actual[i] - for k, (o, v) in exp_link.items(): - if o: - act_v = act_link.get(k) - if act_v is None: - continue - else: - act_v = act_link[k] - - if optional and v != act_v: - break - - self.assertEqual(v, act_v) - else: - i += 1 - - assert i == len(actual) From 3c069799e7b7fd99b41c24f412324330fdc7c16a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Thu, 3 May 2018 13:26:38 +0200 Subject: [PATCH 065/151] Add string_to_text utility function --- openid/oidutil.py | 17 +++++++++++++++++ openid/test/test_oidutil.py | 24 ++++++++++++++++++++++++ openid/urinorm.py | 8 ++------ 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/openid/oidutil.py b/openid/oidutil.py index 70384d32..34732ac1 100644 --- a/openid/oidutil.py +++ b/openid/oidutil.py @@ -9,8 +9,11 @@ import binascii import logging +import warnings from urllib import urlencode +import six + _LOGGER = logging.getLogger(__name__) @@ -148,3 +151,17 @@ def __hash__(self): def __repr__(self): return '' % (self.name,) + + +def string_to_text(value, deprecate_msg): + """ + Return input string coverted to text string. + + If input is text, it is returned as is. + If input is binary, it is decoded using UTF-8 to text. + """ + assert isinstance(value, (six.text_type, six.binary_type)) + if isinstance(value, six.binary_type): + warnings.warn(deprecate_msg, DeprecationWarning) + value = value.decode('utf-8') + return value diff --git a/openid/test/test_oidutil.py b/openid/test/test_oidutil.py index a8415e8f..0b499c82 100644 --- a/openid/test/test_oidutil.py +++ b/openid/test/test_oidutil.py @@ -3,8 +3,14 @@ import random import string import unittest +import warnings + +import six +from mock import sentinel +from testfixtures import ShouldWarn from openid import oidutil +from openid.oidutil import string_to_text class TestBase64(unittest.TestCase): @@ -159,3 +165,21 @@ def testCopyHash(self): # XXX: there are more functions that could benefit from being better # specified and tested in oidutil.py These include, but are not # limited to appendArgs + + +class TestToText(unittest.TestCase): + """Test `string_to_text` utility function.""" + + def test_text_input(self): + result = string_to_text(u'ěščřž', sentinel.msg) + self.assertIsInstance(result, six.text_type) + self.assertEqual(result, u'ěščřž') + + def test_binary_input(self): + warning_msg = 'Conversion warning' + with ShouldWarn(DeprecationWarning(warning_msg)): + warnings.simplefilter('always') + result = string_to_text('ěščřž'.encode('utf-8'), warning_msg) + + self.assertIsInstance(result, six.text_type) + self.assertEqual(result, u'ěščřž') diff --git a/openid/urinorm.py b/openid/urinorm.py index 0da86ee4..abce4053 100644 --- a/openid/urinorm.py +++ b/openid/urinorm.py @@ -2,11 +2,10 @@ from __future__ import unicode_literals import string -import warnings from urllib import quote, unquote, urlencode from urlparse import parse_qsl, urlsplit, urlunsplit -import six +from .oidutil import string_to_text def remove_dot_segments(path): @@ -69,10 +68,7 @@ def urinorm(uri): @rtype: six.text_type @raise ValueError: If URI is invalid. """ - # Transform the input to the unicode string - if isinstance(uri, six.binary_type): - warnings.warn("Binary input for urinorm is deprecated. Use text input instead.", DeprecationWarning) - uri = uri.decode('utf-8') + uri = string_to_text(uri, "Binary input for urinorm is deprecated. Use text input instead.") split_uri = urlsplit(uri) From 455b4239eeeac12d209d723c66d27df12447a946 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 2 May 2018 14:47:32 +0200 Subject: [PATCH 066/151] Transform Message API to text strings --- openid/message.py | 78 +++++++++++++++++++++---------------- openid/test/test_message.py | 22 ++++++----- 2 files changed, 57 insertions(+), 43 deletions(-) diff --git a/openid/message.py b/openid/message.py index ff8bcb5c..3e65b185 100644 --- a/openid/message.py +++ b/openid/message.py @@ -1,17 +1,23 @@ """Extension argument processing code """ -__all__ = ['Message', 'NamespaceMap', 'no_default', 'registerNamespaceAlias', - 'OPENID_NS', 'BARE_NS', 'OPENID1_NS', 'OPENID2_NS', 'SREG_URI', - 'IDENTIFIER_SELECT'] +from __future__ import unicode_literals import copy import urllib import warnings +import six from lxml import etree as ElementTree from openid import kvform, oidutil +from .oidutil import string_to_text + +__all__ = ['Message', 'NamespaceMap', 'no_default', 'registerNamespaceAlias', + 'OPENID_NS', 'BARE_NS', 'OPENID1_NS', 'OPENID2_NS', 'SREG_URI', + 'IDENTIFIER_SELECT'] + + # This doesn't REALLY belong here, but where is better? IDENTIFIER_SELECT = 'http://specs.openid.net/auth/2.0/identifier_select' @@ -154,6 +160,7 @@ def __init__(self, openid_namespace=None, implicit_namespace=None): def fromPostArgs(cls, args): """Construct a Message containing a set of POST arguments. + @type args: Dict[six.text_type, six.text_type] """ # Partition into "openid." args and bare args openid_args = {} @@ -183,6 +190,8 @@ def fromPostArgs(cls, args): def fromOpenIDArgs(cls, openid_args): """Construct a Message from a parsed KVForm message. + @type openid_args: Dict[six.text_type, six.text_type] + @raises InvalidOpenIDNamespace: if openid.ns is not in L{Message.allowed_openid_namespaces} """ @@ -197,6 +206,8 @@ def _fromOpenIDArgs(cls, openid_args): namespaces = {} ns_args = [] for key, value in openid_args.iteritems(): + key = string_to_text(key, "Binary keys in message creations are deprecated. Use text input instead.") + value = string_to_text(value, "Binary values in message creations are deprecated. Use text input instead.") if '.' not in key: ns_alias = NULL_NAMESPACE ns_key = key @@ -286,6 +297,8 @@ def copy(self): def toPostArgs(self): """Return all arguments with openid. in front of namespaced arguments. + + @rtype: Dict[six.text_type, six.text_type] """ args = {} @@ -297,12 +310,11 @@ def toPostArgs(self): ns_key = 'openid.ns' else: ns_key = 'openid.ns.' + alias - args[ns_key] = oidutil.toUnicode(ns_uri).encode('UTF-8') + args[ns_key] = ns_uri for (ns_uri, ns_key), value in self.args.iteritems(): key = self.getKey(ns_uri, ns_key) - # Ensure the resulting value is an UTF-8 encoded bytestring. - args[key] = oidutil.toUnicode(value).encode('UTF-8') + args[key] = value return args @@ -323,51 +335,49 @@ def toArgs(self): return kvargs def toFormMarkup(self, action_url, form_tag_attrs=None, - submit_text=u"Continue"): + submit_text="Continue"): """Generate HTML form markup that contains the values in this message, to be HTTP POSTed as x-www-form-urlencoded UTF-8. @param action_url: The URL to which the form will be POSTed - @type action_url: str + @type action_url: six.text_type, six.binary_type is deprecated @param form_tag_attrs: Dictionary of attributes to be added to the form tag. 'accept-charset' and 'enctype' have defaults that can be overridden. If a value is supplied for 'action' or 'method', it will be replaced. - @type form_tag_attrs: {unicode: unicode} + @type form_tag_attrs: Dict[six.text_type, six.text_type] @param submit_text: The text that will appear on the submit button for this form. - @type submit_text: unicode + @type submit_text: six.text_type @returns: A string containing (X)HTML markup for a form that encodes the values in this Message object. - @rtype: str or unicode + @rtype: six.text_type """ assert action_url is not None + action_url = string_to_text(action_url, "Binary values for action_url is deprecated. Use text input instead.") - form = ElementTree.Element(u'form') + form = ElementTree.Element('form') if form_tag_attrs: for name, attr in form_tag_attrs.iteritems(): form.attrib[name] = attr - form.attrib[u'action'] = oidutil.toUnicode(action_url) - form.attrib[u'method'] = u'post' - form.attrib[u'accept-charset'] = u'UTF-8' - form.attrib[u'enctype'] = u'application/x-www-form-urlencoded' + form.attrib['action'] = action_url + form.attrib['method'] = 'post' + form.attrib['accept-charset'] = 'UTF-8' + form.attrib['enctype'] = 'application/x-www-form-urlencoded' for name, value in self.toPostArgs().iteritems(): - attrs = {u'type': u'hidden', - u'name': oidutil.toUnicode(name), - u'value': oidutil.toUnicode(value)} - form.append(ElementTree.Element(u'input', attrs)) + attrs = {'type': 'hidden', 'name': name, 'value': value} + form.append(ElementTree.Element('input', attrs)) - submit = ElementTree.Element(u'input', - {u'type': 'submit', u'value': oidutil.toUnicode(submit_text)}) + submit = ElementTree.Element('input', {'type': 'submit', 'value': submit_text}) form.append(submit) - return ElementTree.tostring(form, encoding='utf-8') + return ElementTree.tostring(form, encoding='unicode') def toURL(self, base_url): """Generate a GET URL with the parameters in this message @@ -391,14 +401,14 @@ def _fixNS(self, namespace): this object @param namespace: The string or constant to convert - @type namespace: str or unicode or BARE_NS or OPENID_NS + @type namespace: six.text_type or BARE_NS or OPENID_NS """ if namespace == OPENID_NS: namespace = self.getOpenIDNamespace() if namespace is None: raise UndefinedOpenIDNamespace('OpenID namespace not set') - if namespace != BARE_NS and type(namespace) not in [str, unicode]: + if namespace != BARE_NS and not isinstance(namespace, six.string_types): raise TypeError( "Namespace must be BARE_NS, OPENID_NS or a string. got %r" % (namespace,)) @@ -441,21 +451,24 @@ def getArg(self, namespace, key, default=None): """Get a value for a namespaced key. @param namespace: The namespace in the message for this key - @type namespace: str + @type namespace: Union[six.text_type, NULL_NAMESPACE, OPENID_NS, BARE_NS], six.binary_type is deprecated @param key: The key to get within this namespace - @type key: str + @type key: six.text_type, six.binary_type is deprecated @param default: The value to use if this key is absent from this message. Using the special value openid.message.no_default will result in this method raising a KeyError instead of returning the default. - @rtype: str or the type of default + @rtype: six.text_type or the type of default @raises KeyError: if default is no_default @raises UndefinedOpenIDNamespace: if the message has not yet had an OpenID namespace set """ + if isinstance(namespace, six.string_types): + namespace = string_to_text(namespace, "Binary values for namespace are deprecated. Use text input instead.") + key = string_to_text(key, "Binary values for key are deprecated. Use text input instead.") namespace = self._fixNS(namespace) args_key = (namespace, key) try: @@ -484,7 +497,7 @@ def updateArgs(self, namespace, updates): """Set multiple key/value pairs in one call @param updates: The values to set - @type updates: {unicode:unicode} + @type updates: Dict[six.text_type, six.text_type] """ namespace = self._fixNS(namespace) for k, v in updates.iteritems(): @@ -582,7 +595,7 @@ def addAlias(self, namespace_uri, desired_alias, implicit=False): # Check that desired_alias does not contain a period as per # the spec. - if type(desired_alias) in [str, unicode]: + if isinstance(desired_alias, six.string_types): assert '.' not in desired_alias, \ "%r must not contain a dot" % (desired_alias,) @@ -609,8 +622,7 @@ def addAlias(self, namespace_uri, desired_alias, implicit=False): 'It is already mapped to alias %r') raise InvalidNamespace(fmt % (namespace_uri, desired_alias, alias)) - assert (desired_alias == NULL_NAMESPACE or - type(desired_alias) in [str, unicode]), repr(desired_alias) + assert (desired_alias == NULL_NAMESPACE or isinstance(desired_alias, six.string_types)), repr(desired_alias) assert namespace_uri not in self.implicit_namespaces self.alias_to_namespace[desired_alias] = namespace_uri self.namespace_to_alias[namespace_uri] = desired_alias @@ -629,7 +641,7 @@ def add(self, namespace_uri): # Fall back to generating a numerical alias i = 0 while True: - alias = 'ext' + str(i) + alias = 'ext' + six.text_type(i) try: self.addAlias(namespace_uri, alias) except KeyError: diff --git a/openid/test/test_message.py b/openid/test/test_message.py index f00182ab..24d33cfb 100644 --- a/openid/test/test_message.py +++ b/openid/test/test_message.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +from __future__ import unicode_literals + import unittest import urllib import warnings @@ -851,7 +853,7 @@ def test_openid_namespace_invalid(self): # Good guess! But wrong. 'http://openid.net/signon/2.0', # What? - u'http://specs%\\\r2Eopenid.net/auth/2.0', + 'http://specs%\\\r2Eopenid.net/auth/2.0', # Too much escapings! 'https%3A%2F%2Fround-lake.dustinice.workers.dev%3A443%2Fhttp%2Fspecs.openid.net%2Fauth%2F2.0', # This is a Type URI, not a openid.ns value. @@ -905,15 +907,15 @@ def test_fromPostArgs_ns11(self): # An example of the stuff that some Drupal installations send us, # which includes openid.ns but is 1.1. query = { - u'openid.assoc_handle': u'', - u'openid.claimed_id': u'http://foobar.invalid/', - u'openid.identity': u'http://foobar.myopenid.com', - u'openid.mode': u'checkid_setup', - u'openid.ns': u'http://openid.net/signon/1.1', - u'openid.ns.sreg': u'http://openid.net/extensions/sreg/1.1', - u'openid.return_to': u'http://drupal.invalid/return_to', - u'openid.sreg.required': u'nickname,email', - u'openid.trust_root': u'http://drupal.invalid', + 'openid.assoc_handle': '', + 'openid.claimed_id': 'http://foobar.invalid/', + 'openid.identity': 'http://foobar.myopenid.com', + 'openid.mode': 'checkid_setup', + 'openid.ns': 'http://openid.net/signon/1.1', + 'openid.ns.sreg': 'http://openid.net/extensions/sreg/1.1', + 'openid.return_to': 'http://drupal.invalid/return_to', + 'openid.sreg.required': 'nickname,email', + 'openid.trust_root': 'http://drupal.invalid', } m = Message.fromPostArgs(query) self.assertTrue(m.isOpenID1()) From 85574cd2161cf74782f51b04e988d83200bb2bf4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 2 May 2018 15:26:20 +0200 Subject: [PATCH 067/151] Transform key-value form API to text strings --- openid/association.py | 4 ++-- openid/kvform.py | 46 ++++++++++++++++++++++++-------------- openid/server/server.py | 2 +- openid/test/test_kvform.py | 18 +++++++-------- 4 files changed, 41 insertions(+), 29 deletions(-) diff --git a/openid/association.py b/openid/association.py index 920bd634..a265bd44 100644 --- a/openid/association.py +++ b/openid/association.py @@ -402,7 +402,7 @@ def serialize(self): @return: String in KV form suitable for deserialization by deserialize. - @rtype: str + @rtype: six.text_type """ data = { 'version': '2', @@ -465,7 +465,7 @@ def sign(self, pairs): @return: The binary signature of this sequence of pairs - @rtype: str + @rtype: six.text_type """ kv = kvform.seqToKV(pairs) diff --git a/openid/kvform.py b/openid/kvform.py index e0e91a0d..ca196c53 100644 --- a/openid/kvform.py +++ b/openid/kvform.py @@ -1,7 +1,14 @@ -__all__ = ['seqToKV', 'kvToSeq', 'dictToKV', 'kvToDict'] +"""Utilities for key-value format conversions.""" +from __future__ import unicode_literals import logging -import types + +import six + +from .oidutil import string_to_text + +__all__ = ['seqToKV', 'kvToSeq', 'dictToKV', 'kvToDict'] + _LOGGER = logging.getLogger(__name__) @@ -15,10 +22,10 @@ def seqToKV(seq, strict=False): key:value pairs. The pairs are generated in the order given. @param seq: The pairs - @type seq: [(str, (unicode|str))] + @type seq: List[Tuple[six.text_type, six.text_type]], binary_type values are deprecated. @return: A string representation of the sequence - @rtype: str + @rtype: six.text_type """ def err(msg): formatted = 'seqToKV warning: %s: %r' % (msg, seq) @@ -29,11 +36,15 @@ def err(msg): lines = [] for k, v in seq: - if isinstance(k, types.StringType): - k = k.decode('UTF8') - elif not isinstance(k, types.UnicodeType): - err('Converting key to string: %r' % k) - k = str(k) + if not isinstance(k, (six.text_type, six.binary_type)): + err('Converting key to text: %r' % k) + k = six.text_type(k) + if not isinstance(v, (six.text_type, six.binary_type)): + err('Converting value to text: %r' % v) + v = six.text_type(v) + + k = string_to_text(k, "Binary values for keys are deprecated. Use text input instead.") + v = string_to_text(v, "Binary values for values are deprecated. Use text input instead.") if '\n' in k: raise KVFormError( @@ -46,12 +57,6 @@ def err(msg): if k.strip() != k: err('Key has whitespace at beginning or end: %r' % (k,)) - if isinstance(v, types.StringType): - v = v.decode('UTF8') - elif not isinstance(v, types.UnicodeType): - err('Converting value to string: %r' % (v,)) - v = str(v) - if '\n' in v: raise KVFormError( 'Invalid input for seqToKV: value contains newline: %r' % (v,)) @@ -61,16 +66,21 @@ def err(msg): lines.append(k + ':' + v + '\n') - return ''.join(lines).encode('UTF8') + return ''.join(lines) def kvToSeq(data, strict=False): """ + Parse newline-terminated key:value pair string into a sequence. After one parse, seqToKV and kvToSeq are inverses, with no warnings:: seq = kvToSeq(s) seqToKV(kvToSeq(seq)) == seq + + @type data: six.text_type, six.binary_type is deprecated + + @rtype: List[Tuple[six.text_type, six.text_type]] """ def err(msg): formatted = 'kvToSeq warning: %s: %r' % (msg, data) @@ -79,6 +89,8 @@ def err(msg): else: _LOGGER.warn(formatted) + data = string_to_text(data, "Binary values for data are deprecated. Use text input instead.") + lines = data.split('\n') if lines[-1]: err('Does not end in a newline') @@ -112,7 +124,7 @@ def err(msg): 'whitespace in value %r') err(fmt % (line_num, v)) - pairs.append((k_s.decode('UTF8'), v_s.decode('UTF8'))) + pairs.append((k_s, v_s)) else: err('Line %d does not contain a colon' % line_num) diff --git a/openid/server/server.py b/openid/server/server.py index dfe44444..4576c82d 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -1068,7 +1068,7 @@ def encodeToKVForm(self): @see: OpenID Specs, U{Key-Value Colon/Newline format} - @returntype: str + @returntype: six.text_type """ return self.fields.toKVForm() diff --git a/openid/test/test_kvform.py b/openid/test/test_kvform.py index 187629a1..5ea6822c 100644 --- a/openid/test/test_kvform.py +++ b/openid/test/test_kvform.py @@ -1,5 +1,9 @@ +"""Tests for `openid.kvform` module.""" +from __future__ import unicode_literals + import unittest +import six from testfixtures import LogCapture from openid import kvform @@ -33,10 +37,6 @@ def cleanSeq(self, seq): and end of each value of each pair""" clean = [] for k, v in seq: - if isinstance(k, str): - k = k.decode('utf8') - if isinstance(v, str): - v = v.decode('utf8') clean.append((k.strip(), v.strip())) return clean @@ -46,7 +46,7 @@ def runTest(self): with LogCapture() as logbook: actual = kvform.seqToKV(kv_data) self.assertEqual(actual, result) - self.assertIsInstance(actual, str) + self.assertIsInstance(actual, six.text_type) # Parse back to sequence. Expected to be unchanged, except # stripping whitespace from start and end of values @@ -92,12 +92,12 @@ def runTest(self): kvseq_cases = [ ([], '', 0), - # Make sure that we handle non-ascii characters (also wider than 8 bits) - ([(u'\u03bbx', u'x')], '\xce\xbbx:x\n', 0), + # Make sure that we handle unicode characters + ([('\u03bbx', 'x')], '\u03bbx:x\n', 0), # If it's a UTF-8 str, make sure that it's equivalent to the same # string, decoded. - ([('\xce\xbbx', 'x')], '\xce\xbbx:x\n', 0), + ([(b'\xce\xbbx', b'x')], '\u03bbx:x\n', 0), ([('openid', 'useful'), ('a', 'b')], 'openid:useful\na:b\n', 0), @@ -113,7 +113,7 @@ def runTest(self): ([(' open id ', ' use ful '), (' a ', ' b ')], ' open id : use ful \n a : b \n', 4), - ([(u'foo', 'bar')], 'foo:bar\n', 0), + ([('foo', 'bar')], 'foo:bar\n', 0), ] kvexc_cases = [ From 2a28ff0db2ee6b2842bf09afeb037026ab9c9a9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 2 May 2018 15:28:59 +0200 Subject: [PATCH 068/151] Drop unused toUnicode utility function --- openid/oidutil.py | 15 +-------------- openid/test/test_oidutil.py | 13 ------------- 2 files changed, 1 insertion(+), 27 deletions(-) diff --git a/openid/oidutil.py b/openid/oidutil.py index 34732ac1..6e688fb1 100644 --- a/openid/oidutil.py +++ b/openid/oidutil.py @@ -5,7 +5,7 @@ interesting. """ -__all__ = ['log', 'appendArgs', 'toBase64', 'fromBase64', 'autoSubmitHTML', 'toUnicode'] +__all__ = ['log', 'appendArgs', 'toBase64', 'fromBase64', 'autoSubmitHTML'] import binascii import logging @@ -17,19 +17,6 @@ _LOGGER = logging.getLogger(__name__) -def toUnicode(value): - """Returns the given argument as a unicode object. - - @param value: A UTF-8 encoded string or a unicode (coercable) object - @type message: str or unicode - - @returns: Unicode object representing the input value. - """ - if isinstance(value, str): - return value.decode('utf-8') - return unicode(value) - - def autoSubmitHTML(form, title='OpenID transaction in progress'): return """ diff --git a/openid/test/test_oidutil.py b/openid/test/test_oidutil.py index 0b499c82..9713de81 100644 --- a/openid/test/test_oidutil.py +++ b/openid/test/test_oidutil.py @@ -137,19 +137,6 @@ def runTest(self): self.assertEqual(expected, result, '{} {}'.format(name, args)) -class TestUnicodeConversion(unittest.TestCase): - - def test_toUnicode(self): - # Unicode objects pass through - self.assertIsInstance(oidutil.toUnicode(u'fööbär'), unicode) - self.assertEquals(oidutil.toUnicode(u'fööbär'), u'fööbär') - # UTF-8 encoded string are decoded - self.assertIsInstance(oidutil.toUnicode('fööbär'), unicode) - self.assertEquals(oidutil.toUnicode('fööbär'), u'fööbär') - # Other encodings raise exceptions - self.assertRaises(UnicodeDecodeError, lambda: oidutil.toUnicode(u'fööbär'.encode('latin-1'))) - - class TestSymbol(unittest.TestCase): def testCopyHash(self): import copy From eaeadc59c2be733322d97e048e02a79686a6af55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 2 May 2018 15:39:27 +0200 Subject: [PATCH 069/151] Replace unicode with six.text_type --- openid/consumer/consumer.py | 4 ++-- openid/consumer/discover.py | 2 +- openid/extensions/ax.py | 14 +++++++------- openid/oidutil.py | 2 +- openid/server/server.py | 4 +++- openid/yadis/etxrd.py | 4 ++-- openid/yadis/xri.py | 6 +++--- openid/yadis/xrires.py | 6 +++--- 8 files changed, 22 insertions(+), 20 deletions(-) diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index a836c900..96c6ef73 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -308,7 +308,7 @@ def begin(self, user_url, anonymous=False): normalizing and resolving any redirects the server might issue. - @type user_url: unicode + @type user_url: six.text_type @param anonymous: Whether to make an anonymous request of the OpenID provider. Such a request does not ask for an authorization @@ -1616,7 +1616,7 @@ def formMarkup(self, realm, return_to=None, immediate=False, form_tag_attrs=None the form tag. 'accept-charset' and 'enctype' have defaults that can be overridden. If a value is supplied for 'action' or 'method', it will be replaced. - @type form_tag_attrs: {unicode: unicode} + @type form_tag_attrs: Dict[six.text_type, six.text_type] """ message = self.getMessage(realm, return_to, immediate) return message.toFormMarkup(self.endpoint.server_url, form_tag_attrs) diff --git a/openid/consumer/discover.py b/openid/consumer/discover.py index 0824af43..08af0a63 100644 --- a/openid/consumer/discover.py +++ b/openid/consumer/discover.py @@ -269,7 +269,7 @@ def findOPLocalIdentifier(service_element, type_uris): @returns: The OP-Local Identifier for this service element, if one is present, or None otherwise. - @rtype: str or unicode or NoneType + @rtype: six.text_type or NoneType """ # XXX: Test this function on its own! diff --git a/openid/extensions/ax.py b/openid/extensions/ax.py index b48d19ce..39b85cdc 100644 --- a/openid/extensions/ax.py +++ b/openid/extensions/ax.py @@ -212,7 +212,7 @@ def getExtensionArgs(self): """Get the serialized form of this attribute fetch request. @returns: The fetch request message parameters - @rtype: {unicode:unicode} + @rtype: Dict[six.text_type, six.text_type] """ aliases = NamespaceMap() @@ -418,7 +418,7 @@ def addValue(self, type_uri, value): @param value: The value to add to the response to the relying party for this attribute - @type value: unicode + @type value: six.text_type @returns: None """ @@ -436,7 +436,7 @@ def setValues(self, type_uri, values): @param type_uri: The URI for the attribute @param values: A list of values to send for this attribute. - @type values: [unicode] + @type values: List[six.text_type] """ self.data[type_uri] = values @@ -471,7 +471,7 @@ def parseExtensionArgs(self, ax_args): @param ax_args: The attribute exchange fetch_response arguments, with namespacing removed. - @type ax_args: {unicode:unicode} + @type ax_args: Dict[six.text_type, six.text_type] @returns: None @@ -525,7 +525,7 @@ def getSingle(self, type_uri, default=None): @returns: The value of the attribute in the fetch_response message, or the default supplied - @rtype: unicode or NoneType + @rtype: six.text_type or NoneType @raises ValueError: If there is more than one value for this parameter in the fetch_response message. @@ -554,7 +554,7 @@ def get(self, type_uri): @returns: The list of values for this attribute in the response. May be an empty list. - @rtype: [unicode] + @rtype: List[six.text_type] @raises KeyError: If the attribute was not sent in the response """ @@ -605,7 +605,7 @@ def getExtensionArgs(self): @returns: The dictionary of unqualified attribute exchange arguments that represent this fetch_response. - @rtype: {unicode;unicode} + @rtype: Dict[six.text_type, six.text_type] """ aliases = NamespaceMap() diff --git a/openid/oidutil.py b/openid/oidutil.py index 6e688fb1..3d3f03c2 100644 --- a/openid/oidutil.py +++ b/openid/oidutil.py @@ -121,7 +121,7 @@ def fromBase64(s): class Symbol(object): """This class implements an object that compares equal to others of the same type that have the same name. These are distict from - str or unicode objects. + string objects. """ def __init__(self, name): diff --git a/openid/server/server.py b/openid/server/server.py index 4576c82d..6afe8262 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -121,6 +121,8 @@ import warnings from copy import deepcopy +import six + from openid import cryptutil, kvform, oidutil from openid.association import Association, default_negotiator, getSecretSize from openid.dh import DiffieHellman @@ -1624,7 +1626,7 @@ def __init__(self, message, text=None, reference=None, contact=None): self.openid_message = message self.reference = reference self.contact = contact - assert type(message) not in [str, unicode] + assert not isinstance(message, six.string_types) Exception.__init__(self, text) def getReturnTo(self): diff --git a/openid/yadis/etxrd.py b/openid/yadis/etxrd.py index a96a107d..4039c9ec 100644 --- a/openid/yadis/etxrd.py +++ b/openid/yadis/etxrd.py @@ -153,13 +153,13 @@ def getCanonicalID(iname, xrd_tree): """Return the CanonicalID from this XRDS document. @param iname: the XRI being resolved. - @type iname: unicode + @type iname: six.text_type @param xrd_tree: The XRDS output from the resolver. @type xrd_tree: ElementTree @returns: The XRI CanonicalID or None. - @returntype: unicode or None + @returntype: six.text_type or None """ xrd_list = xrd_tree.findall(xrd_tag) xrd_list.reverse() diff --git a/openid/yadis/xri.py b/openid/yadis/xri.py index 73ad7980..7728c600 100644 --- a/openid/yadis/xri.py +++ b/openid/yadis/xri.py @@ -94,8 +94,8 @@ def rootAuthority(xri): rootAuthority("xri://@example") == "xri://@" - @type xri: unicode - @returntype: unicode + @type xri: six.text_type + @returntype: six.text_type """ if xri.startswith('xri://'): xri = xri[6:] @@ -127,7 +127,7 @@ def XRI(xri): canonicalization by ensuring the xri scheme is present. @param xri: an xri string - @type xri: unicode + @type xri: six.text_type """ if not xri.startswith('xri://'): xri = 'xri://' + xri diff --git a/openid/yadis/xrires.py b/openid/yadis/xrires.py index 4a365950..fe54c48d 100644 --- a/openid/yadis/xrires.py +++ b/openid/yadis/xrires.py @@ -23,7 +23,7 @@ def queryURL(self, xri, service_type=None): """Build a URL to query the proxy resolver. @param xri: An XRI to resolve. - @type xri: unicode + @type xri: six.text_type @param service_type: The service type to resolve, if you desire service endpoint selection. A service type is a URI. @@ -63,14 +63,14 @@ def query(self, xri, service_types): the fetching or parsing don't go so well. @param xri: An XRI to resolve. - @type xri: unicode + @type xri: six.text_type @param service_types: A list of services types to query for. Service types are URIs. @type service_types: list of str @returns: tuple of (CanonicalID, Service elements) - @returntype: (unicode, list of C{ElementTree.Element}s) + @returntype: (six.text_type, list of C{ElementTree.Element}s) """ # FIXME: No test coverage! services = [] From 656420617dddf554a12d6ef25251a8149b7f8cc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 9 May 2018 13:10:58 +0200 Subject: [PATCH 070/151] Update strings in cryptography and utilities API --- openid/cryptutil.py | 91 ++++++++++++++++++------ openid/dh.py | 11 ++- openid/oidutil.py | 50 ++++++++----- openid/store/nonce.py | 4 +- openid/test/test_association_response.py | 2 +- openid/test/test_consumer.py | 5 +- openid/test/test_cryptutil.py | 22 +++--- openid/test/test_dh.py | 33 +++++---- openid/test/test_oidutil.py | 22 +++--- openid/test/test_server.py | 6 +- openid/test/test_storetest.py | 2 +- 11 files changed, 165 insertions(+), 83 deletions(-) diff --git a/openid/cryptutil.py b/openid/cryptutil.py index 27ff7965..fd7b30ca 100644 --- a/openid/cryptutil.py +++ b/openid/cryptutil.py @@ -9,6 +9,16 @@ http://www.amk.ca/python/code/crypto """ +from __future__ import unicode_literals + +import hashlib +import hmac +import os +import random + +import six + +from openid.oidutil import fromBase64, string_to_text, toBase64 __all__ = [ 'base64ToLong', @@ -23,13 +33,6 @@ 'sha256', ] -import hashlib -import hmac -import os -import random - -from openid.oidutil import fromBase64, toBase64 - class HashContainer(object): def __init__(self, hash_constructor): @@ -42,18 +45,46 @@ def __init__(self, hash_constructor): def hmacSha1(key, text): - return hmac.new(key, text, sha1_module).digest() + """ + Return a SHA1 HMAC. + + @type key: six.binary_type + @type text: six.text_type, six.binary_type is deprecated + @rtype: six.binary_type + """ + text = string_to_text(text, "Binary values for text are deprecated. Use text input instead.") + return hmac.new(key, text.encode('utf-8'), sha1_module).digest() def sha1(s): + """ + Return a SHA1 hash. + + @type s: six.binary_type + @rtype: six.binary_type + """ return sha1_module.new(s).digest() def hmacSha256(key, text): - return hmac.new(key, text, sha256_module).digest() + """ + Return a SHA256 HMAC. + + @type key: six.binary_type + @type text: six.text_type, six.binary_type is deprecated + @rtype: six.binary_type + """ + text = string_to_text(text, "Binary values for text are deprecated. Use text input instead.") + return hmac.new(key, text.encode('utf-8'), sha256_module).digest() def sha256(s): + """ + Return a SHA256 hash. + + @type s: six.binary_type + @rtype: six.binary_type + """ return sha256_module.new(s).digest() @@ -64,12 +95,12 @@ def sha256(s): def longToBinary(value): if value == 0: - return '\x00' + return b'\x00' - return ''.join(reversed(pickle.encode_long(value))) + return pickle.encode_long(value)[::-1] def binaryToLong(s): - return pickle.decode_long(''.join(reversed(s))) + return pickle.decode_long(s[::-1]) else: # We have pycrypto @@ -77,20 +108,28 @@ def longToBinary(value): if value < 0: raise ValueError('This function only supports positive integers') - bytes = long_to_bytes(value) - if ord(bytes[0]) > 127: - return '\x00' + bytes + output = long_to_bytes(value) + if isinstance(output[0], int): + ord_first = output[0] else: - return bytes + ord_first = ord(output[0]) + if ord_first > 127: + return b'\x00' + output + else: + return output - def binaryToLong(bytes): - if not bytes: + def binaryToLong(s): + if not s: raise ValueError('Empty string passed to strToLong') - if ord(bytes[0]) > 127: + if isinstance(s[0], int): + ord_first = s[0] + else: + ord_first = ord(s[0]) + if ord_first > 127: raise ValueError('This function only supports positive integers') - return bytes_to_long(bytes) + return bytes_to_long(s) # A cryptographically safe source of random bytes try: @@ -179,12 +218,20 @@ def base64ToLong(s): def randomString(length, chrs=None): - """Produce a string of length random bytes, chosen from chrs.""" + """Produce a string of length random bytes, chosen from chrs. + + @type chrs: six.binary_type + @rtype: six.binary_type + """ if chrs is None: return getBytes(length) else: n = len(chrs) - return ''.join([chrs[randrange(n)] for _ in xrange(length)]) + random_chars = [chrs[randrange(n)] for _ in range(length)] + if six.PY2: + return b''.join(random_chars) + else: + return six.binary_type(random_chars) def const_eq(s1, s2): diff --git a/openid/dh.py b/openid/dh.py index 74065fd2..ab9f984c 100644 --- a/openid/dh.py +++ b/openid/dh.py @@ -1,7 +1,12 @@ +from __future__ import unicode_literals + +import six + from openid import cryptutil def _xor(a_b): + # Python 2 only a, b = a_b return chr(ord(a) ^ ord(b)) @@ -10,7 +15,11 @@ def strxor(x, y): if len(x) != len(y): raise ValueError('Inputs to strxor must have the same length') - return "".join(_xor((a, b)) for a, b in zip(x, y)) + if six.PY2: + return b"".join(_xor((a, b)) for a, b in zip(x, y)) + else: + assert six.PY3 + return bytes((a ^ b) for a, b in zip(x, y)) class DiffieHellman(object): diff --git a/openid/oidutil.py b/openid/oidutil.py index 3d3f03c2..b164165b 100644 --- a/openid/oidutil.py +++ b/openid/oidutil.py @@ -4,8 +4,7 @@ For users of this library, the C{L{log}} function is probably the most interesting. """ - -__all__ = ['log', 'appendArgs', 'toBase64', 'fromBase64', 'autoSubmitHTML'] +from __future__ import unicode_literals import binascii import logging @@ -14,6 +13,8 @@ import six +__all__ = ['log', 'appendArgs', 'toBase64', 'fromBase64', 'autoSubmitHTML'] + _LOGGER = logging.getLogger(__name__) @@ -44,7 +45,7 @@ def log(message, level=0): @param message: A string containing a debugging message from the OpenID library - @type message: str + @type message: six.text_type, six.binary_type is deprecated @param level: The severity of the log message. This parameter is currently unused, but in the future, the library may indicate @@ -53,6 +54,7 @@ def log(message, level=0): @returns: Nothing. """ + message = string_to_text(message, "Binary values for log are deprecated. Use text input instead.") logging.error("This is a legacy log message, please use the logging module. Message: %s", message) @@ -64,18 +66,20 @@ def appendArgs(url, args): detected or collapsed (both will appear in the output). @param url: The url to which the arguments will be appended - @type url: str + @type url: six.text_type, six.binary_type is deprecated @param args: The query arguments to add to the URL. If a dictionary is passed, the items will be sorted before appending them to the URL. If a sequence of pairs is passed, the order of the sequence will be preserved. - @type args: A dictionary from string to string, or a sequence of - pairs of strings. + @type args: Union[Dict[six.text_type, six.text_type], List[Tuple[six.text_type, six.text_type]]], + six.binary_type is deprecated @returns: The URL with the parameters added - @rtype: str + @rtype: six.text_type """ + url = string_to_text(url, "Binary values for appendArgs are deprecated. Use text input instead.") + if hasattr(args, 'items'): args = sorted(args.items()) else: @@ -89,28 +93,36 @@ def appendArgs(url, args): else: sep = '?' - # Map unicode to UTF-8 if present. Do not make any assumptions - # about the encodings of plain bytes (str). i = 0 for k, v in args: - if not isinstance(k, str): - k = k.encode('UTF-8') - - if not isinstance(v, str): - v = v.encode('UTF-8') - - args[i] = (k, v) + k = string_to_text(k, "Binary values for appendArgs are deprecated. Use text input instead.") + v = string_to_text(v, "Binary values for appendArgs are deprecated. Use text input instead.") + args[i] = (k.encode('utf-8'), v.encode('utf-8')) i += 1 - return '%s%s%s' % (url, sep, urlencode(args)) + encoded_args = urlencode(args) + # `urlencode` returns `str` in both py27 and py3+. We need to convert it to six.text_type. + if not isinstance(encoded_args, six.text_type): + encoded_args = encoded_args.decode('utf-8') + return '%s%s%s' % (url, sep, encoded_args) def toBase64(s): - """Represent string s as base64, omitting newlines""" - return binascii.b2a_base64(s)[:-1] + """Return string s as base64, omitting newlines. + + @type s: six.binary_type + @rtype six.text_type + """ + return binascii.b2a_base64(s)[:-1].decode('utf-8') def fromBase64(s): + """Return binary data from base64 encoded string. + + @type s: six.text_type, six.binary_type deprecated. + @rtype six.binary_type + """ + s = string_to_text(s, "Binary values for s are deprecated. Use text input instead.") try: return binascii.a2b_base64(s) except binascii.Error as why: diff --git a/openid/store/nonce.py b/openid/store/nonce.py index 800dfecf..60b3a891 100644 --- a/openid/store/nonce.py +++ b/openid/store/nonce.py @@ -10,7 +10,7 @@ from openid import cryptutil -NONCE_CHARS = string.ascii_letters + string.digits +NONCE_CHARS = (string.ascii_letters + string.digits).encode('utf-8') # Keep nonces for five hours (allow five hours for the combination of # request time and clock skew). This is probably way more than is @@ -89,7 +89,7 @@ def mkNonce(when=None): @see: time """ - salt = cryptutil.randomString(6, NONCE_CHARS) + salt = cryptutil.randomString(6, NONCE_CHARS).decode('utf-8') if when is None: t = gmtime() else: diff --git a/openid/test/test_association_response.py b/openid/test/test_association_response.py index 3e5dfd07..2df50f4f 100644 --- a/openid/test/test_association_response.py +++ b/openid/test/test_association_response.py @@ -296,7 +296,7 @@ def test_badExpiresIn(self): # sort of a unit test and sort of a functional test. I'm not terribly # fond of it. class TestExtractAssociationDiffieHellman(BaseAssocTest): - secret = 'x' * 20 + secret = b'x' * 20 def _setUpDH(self): sess, message = self.consumer._createAssociateRequest( diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index 2b0a2af8..8e63b684 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -22,8 +22,9 @@ from openid.yadis.manager import Discovery assocs = [ - ('another 20-byte key.', 'Snarky'), - ('\x00' * 20, 'Zeros'), + # (secret, handle) + (b'another 20-byte key.', 'Snarky'), + (b'\x00' * 20, 'Zeros'), ] diff --git a/openid/test/test_cryptutil.py b/openid/test/test_cryptutil.py index 4cf57b00..7697ada7 100644 --- a/openid/test/test_cryptutil.py +++ b/openid/test/test_cryptutil.py @@ -1,9 +1,13 @@ """Test `openid.cryptutil` module.""" +from __future__ import unicode_literals + import os.path import random import sys import unittest +import six + from openid import cryptutil # Most of the purpose of this test is to make sure that cryptutil can @@ -45,19 +49,19 @@ def test_binaryLongConvert(self): n += long(random.randrange(MAX)) s = cryptutil.longToBinary(n) - assert isinstance(s, str) + assert isinstance(s, six.binary_type) n_prime = cryptutil.binaryToLong(s) assert n == n_prime, (n, n_prime) cases = [ - ('\x00', 0), - ('\x01', 1), - ('\x7F', 127), - ('\x00\xFF', 255), - ('\x00\x80', 128), - ('\x00\x81', 129), - ('\x00\x80\x00', 32768), - ('OpenID is cool', 1611215304203901150134421257416556) + (b'\x00', 0), + (b'\x01', 1), + (b'\x7F', 127), + (b'\x00\xFF', 255), + (b'\x00\x80', 128), + (b'\x00\x81', 129), + (b'\x00\x80\x00', 32768), + (b'OpenID is cool', 1611215304203901150134421257416556) ] for s, n in cases: diff --git a/openid/test/test_dh.py b/openid/test/test_dh.py index 03ef20bc..707eb4d3 100644 --- a/openid/test/test_dh.py +++ b/openid/test/test_dh.py @@ -1,7 +1,11 @@ """Test `openid.dh` module.""" +from __future__ import unicode_literals + import os.path import unittest +import six + from openid.dh import DiffieHellman, strxor @@ -9,18 +13,18 @@ class TestStrXor(unittest.TestCase): """Test `strxor` function.""" def test_strxor(self): - NUL = '\x00' + NUL = b'\x00' cases = [ (NUL, NUL, NUL), - ('\x01', NUL, '\x01'), - ('a', 'a', NUL), - ('a', NUL, 'a'), - ('abc', NUL * 3, 'abc'), - ('x' * 10, NUL * 10, 'x' * 10), - ('\x01', '\x02', '\x03'), - ('\xf0', '\x0f', '\xff'), - ('\xff', '\x0f', '\xf0'), + (b'\x01', NUL, b'\x01'), + (b'a', b'a', NUL), + (b'a', NUL, b'a'), + (b'abc', NUL * 3, b'abc'), + (b'x' * 10, NUL * 10, b'x' * 10), + (b'\x01', b'\x02', b'\x03'), + (b'\xf0', b'\x0f', b'\xff'), + (b'\xff', b'\x0f', b'\xf0'), ] for aa, bb, expected in cases: @@ -28,12 +32,15 @@ def test_strxor(self): assert actual == expected, (aa, bb, expected, actual) exc_cases = [ - ('', 'a'), - ('foo', 'ba'), + (b'', b'a'), + (b'foo', b'ba'), (NUL * 3, NUL * 4), - (''.join(chr(i) for i in range(256)), - ''.join(chr(i) for i in range(128))), ] + if six.PY2: + exc_cases.append((b''.join(chr(i) for i in range(256)), b''.join(chr(i) for i in range(128)))) + else: + assert six.PY3 + exc_cases.append((bytes(i for i in range(256)), bytes(i for i in range(128)))) for aa, bb in exc_cases: try: diff --git a/openid/test/test_oidutil.py b/openid/test/test_oidutil.py index 9713de81..f0fa13d8 100644 --- a/openid/test/test_oidutil.py +++ b/openid/test/test_oidutil.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- """Test `openid.oidutil` module.""" +from __future__ import unicode_literals + import random import string import unittest @@ -28,12 +30,12 @@ def checkEncoded(s): assert isAllowed(c), s cases = [ - '', - 'x', - '\x00', - '\x01', - '\x00' * 100, - ''.join(chr(i) for i in range(256)), + b'', + b'x', + b'\x00', + b'\x01', + b'\x00' * 100, + b''.join(chr(i) for i in range(256)), ] for s in cases: @@ -45,7 +47,7 @@ def checkEncoded(s): # Randomized test for _ in xrange(50): n = random.randrange(2048) - s = ''.join(chr(random.randrange(256)) for i in range(n)) + s = b''.join(chr(random.randrange(256)) for i in range(n)) b64 = oidutil.toBase64(s) checkEncoded(b64) s_prime = oidutil.fromBase64(b64) @@ -158,9 +160,9 @@ class TestToText(unittest.TestCase): """Test `string_to_text` utility function.""" def test_text_input(self): - result = string_to_text(u'ěščřž', sentinel.msg) + result = string_to_text('ěščřž', sentinel.msg) self.assertIsInstance(result, six.text_type) - self.assertEqual(result, u'ěščřž') + self.assertEqual(result, 'ěščřž') def test_binary_input(self): warning_msg = 'Conversion warning' @@ -169,4 +171,4 @@ def test_binary_input(self): result = string_to_text('ěščřž'.encode('utf-8'), warning_msg) self.assertIsInstance(result, six.text_type) - self.assertEqual(result, u'ěščřž') + self.assertEqual(result, 'ěščřž') diff --git a/openid/test/test_server.py b/openid/test/test_server.py index d4a1f146..0ce5ea8f 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -1308,7 +1308,7 @@ def test_dhSHA1(self): self.assertTrue(rfg("enc_mac_key")) self.assertTrue(rfg("dh_server_public")) - enc_key = rfg("enc_mac_key").decode('base64') + enc_key = oidutil.fromBase64(rfg("enc_mac_key")) spub = cryptutil.base64ToLong(rfg("dh_server_public")) secret = consumer_dh.xorSecret(spub, enc_key, cryptutil.sha1) self.assertEqual(secret, self.assoc.secret) @@ -1333,7 +1333,7 @@ def test_dhSHA256(self): self.assertTrue(rfg("enc_mac_key")) self.assertTrue(rfg("dh_server_public")) - enc_key = rfg("enc_mac_key").decode('base64') + enc_key = oidutil.fromBase64(rfg("enc_mac_key")) spub = cryptutil.base64ToLong(rfg("dh_server_public")) secret = consumer_dh.xorSecret(spub, enc_key, cryptutil.sha256) self.assertEqual(secret, self.assoc.secret) @@ -1802,7 +1802,7 @@ def test_verifyBadSig(self): 'openid.apple': 'orange', 'openid.assoc_handle': assoc_handle, 'openid.signed': 'apple,assoc_handle,foo,signed', - 'openid.sig': 'uXoT1qm62/BB09Xbj98TQ8mlBco='.encode('rot13'), + 'openid.sig': 'invalid/BB09Xbj98TQ8mlBco=', }) with LogCapture() as logbook: diff --git a/openid/test/test_storetest.py b/openid/test/test_storetest.py index 6937f041..898a406b 100644 --- a/openid/test/test_storetest.py +++ b/openid/test/test_storetest.py @@ -20,7 +20,7 @@ def generateHandle(n): - return randomString(n, allowed_handle) + return randomString(n, allowed_handle.encode('utf-8')) generateSecret = randomString From 876e275b4f6c7ce0e3f2354a541e51771f72d89e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 9 May 2018 13:11:18 +0200 Subject: [PATCH 071/151] Transform Association API to text strings --- openid/association.py | 79 ++++++++++++++++----------------- openid/test/test_association.py | 37 +++++++-------- 2 files changed, 54 insertions(+), 62 deletions(-) diff --git a/openid/association.py b/openid/association.py index a265bd44..ae566e7a 100644 --- a/openid/association.py +++ b/openid/association.py @@ -24,6 +24,16 @@ does not support C{'no-encryption'} associations. It prefers HMAC-SHA1/DH-SHA1 association types if available. """ +from __future__ import unicode_literals + +import time + +import six + +from openid import cryptutil, kvform, oidutil +from openid.message import OPENID_NS + +from .oidutil import string_to_text __all__ = [ 'default_negotiator', @@ -32,10 +42,6 @@ 'Association', ] -import time - -from openid import cryptutil, kvform, oidutil -from openid.message import OPENID_NS all_association_types = [ 'HMAC-SHA1', @@ -132,7 +138,7 @@ class SessionNegotiator(object): determines preference. If an association/session type comes earlier in the list, the library is more likely to use that type. - @type allowed_types: [(str, str)] + @type allowed_types: List[Tuple[six.text_type, six.text_type]] """ def __init__(self, allowed_types): @@ -144,6 +150,11 @@ def copy(self): def setAllowedTypes(self, allowed_types): """Set the allowed association types, checking to make sure each combination is valid.""" + # Convert strings to text + allowed_types = [ + (string_to_text(a, "Binary values for assoc_type are deprecated. Use text input instead."), + string_to_text(s, "Binary values for session_type are deprecated. Use text input instead.")) + for a, s in allowed_types] for (assoc_type, session_type) in allowed_types: checkSessionType(assoc_type, session_type) @@ -209,14 +220,12 @@ class Association(object): C{L{assoc_type}} instance variables. @ivar handle: This is the handle the server gave this association. - - @type handle: C{str} + @type handle: six.text_type @ivar secret: This is the shared secret the server generated for this association. - - @type secret: C{str} + @type secret: six.binary_type @ivar issued: This is the time this association was issued, in @@ -236,8 +245,7 @@ class Association(object): @ivar assoc_type: This is the type of association this instance represents. The only valid value of this field at this time is C{'HMAC-SHA1'}, but new types may be defined in the future. - - @type assoc_type: C{str} + @type assoc_type: six.text_type @sort: __init__, fromExpiresIn, getExpiresIn, __eq__, __ne__, @@ -277,22 +285,17 @@ def fromExpiresIn(cls, expires_in, handle, secret, assoc_type): @param handle: This is the handle the server gave this association. - - @type handle: C{str} - + @type handle: six.text_type, six.binary_type is deprecated @param secret: This is the shared secret the server generated for this association. - - @type secret: C{str} - + @type secret: six.binary_type @param assoc_type: This is the type of association this instance represents. The only valid value of this field at this time is C{'HMAC-SHA1'}, but new types may be defined in the future. - - @type assoc_type: C{str} + @type assoc_type: six.text_type, six.binary_type is deprecated """ issued = int(time.time()) lifetime = expires_in @@ -305,14 +308,12 @@ def __init__(self, handle, secret, issued, lifetime, assoc_type): @param handle: This is the handle the server gave this association. - - @type handle: C{str} + @type handle: six.text_type, six.binary_type is deprecated @param secret: This is the shared secret the server generated for this association. - - @type secret: C{str} + @type secret: six.binary_type @param issued: This is the time this association was issued, @@ -333,8 +334,7 @@ def __init__(self, handle, secret, issued, lifetime, assoc_type): instance represents. The only valid value of this field at this time is C{'HMAC-SHA1'}, but new types may be defined in the future. - - @type assoc_type: C{str} + @type assoc_type: six.text_type, six.binary_type is deprecated """ if assoc_type not in all_association_types: fmt = '%r is not a supported association type' @@ -345,11 +345,13 @@ def __init__(self, handle, secret, issued, lifetime, assoc_type): # fmt = 'Wrong size secret (%s bytes) for association type %s' # raise ValueError(fmt % (len(secret), assoc_type)) - self.handle = handle + self.handle = string_to_text(handle, "Binary values for handle are deprecated. Use text input instead.") + assert isinstance(secret, six.binary_type) self.secret = secret self.issued = issued self.lifetime = lifetime - self.assoc_type = assoc_type + self.assoc_type = string_to_text(assoc_type, + "Binary values for assoc_type are deprecated. Use text input instead.") def getExpiresIn(self, now=None): """ @@ -408,8 +410,8 @@ def serialize(self): 'version': '2', 'handle': self.handle, 'secret': oidutil.toBase64(self.secret), - 'issued': str(int(self.issued)), - 'lifetime': str(int(self.lifetime)), + 'issued': six.text_type(int(self.issued)), + 'lifetime': six.text_type(int(self.lifetime)), 'assoc_type': self.assoc_type } @@ -429,13 +431,12 @@ def deserialize(cls, assoc_s): @param assoc_s: Association as serialized by serialize() - - @type assoc_s: str - + @type assoc_s: six.text_type, six.binary_type is deprecated @return: instance of this class """ - pairs = kvform.kvToSeq(assoc_s, strict=True) + pairs = kvform.kvToSeq( + string_to_text(assoc_s, "Binary values for assoc_s are deprecated. Use text input instead."), strict=True) keys = [] values = [] for k, v in pairs: @@ -459,14 +460,13 @@ def sign(self, pairs): @param pairs: The pairs to sign, in order - - @type pairs: sequence of (str, str) - + @type pairs: Iterable[six.text_type, six.text_type], six.binary_type is deprecated @return: The binary signature of this sequence of pairs - - @rtype: six.text_type + @rtype: six.binary_type """ + warning_msg = "Binary values for pairs are deprecated. Use text input instead." + pairs = [(string_to_text(a, warning_msg), string_to_text(b, warning_msg)) for a, b in pairs] kv = kvform.seqToKV(pairs) try: @@ -484,8 +484,7 @@ def getMessageSignature(self, message): signed list. @return: the signature, base64 encoded - - @rtype: str + @rtype: six.text_type @raises ValueError: If there is no signed list and I am not a sign-all type of association. diff --git a/openid/test/test_association.py b/openid/test/test_association.py index bd042e98..2dd6266f 100644 --- a/openid/test/test_association.py +++ b/openid/test/test_association.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + import time import unittest @@ -12,8 +14,7 @@ class AssociationSerializationTest(unittest.TestCase): def test_roundTrip(self): issued = int(time.time()) lifetime = 600 - assoc = association.Association( - 'handle', 'secret', issued, lifetime, 'HMAC-SHA1') + assoc = association.Association('handle', b'secret', issued, lifetime, 'HMAC-SHA1') s = assoc.serialize() assoc2 = association.Association.deserialize(s) self.assertEqual(assoc.handle, assoc2.handle) @@ -30,10 +31,10 @@ def createNonstandardConsumerDH(): class DiffieHellmanSessionTest(unittest.TestCase): secrets = [ - '\x00' * 20, - '\xff' * 20, - ' ' * 20, - 'This is a secret....', + b'\x00' * 20, + b'\xff' * 20, + b' ' * 20, + b'This is a secret....', ] session_factories = [ @@ -66,8 +67,7 @@ def setUp(self): 'sig': 'cephalopod', }) m.updateArgs(BARE_NS, {'xey': 'value'}) - self.assoc = association.Association.fromExpiresIn( - 3600, '{sha1}', 'very_secret', "HMAC-SHA1") + self.assoc = association.Association.fromExpiresIn(3600, '{sha1}', b'very_secret', "HMAC-SHA1") def testMakePairs(self): """Make pairs using the OpenID 1.x type signed list.""" @@ -85,18 +85,14 @@ def setUp(self): ('key2', 'value2')] def test_sha1(self): - assoc = association.Association.fromExpiresIn( - 3600, '{sha1}', 'very_secret', "HMAC-SHA1") - expected = ('\xe0\x1bv\x04\xf1G\xc0\xbb\x7f\x9a\x8b' - '\xe9\xbc\xee}\\\xe5\xbb7*') + assoc = association.Association.fromExpiresIn(3600, '{sha1}', b'very_secret', "HMAC-SHA1") + expected = (b'\xe0\x1bv\x04\xf1G\xc0\xbb\x7f\x9a\x8b\xe9\xbc\xee}\\\xe5\xbb7*') sig = assoc.sign(self.pairs) self.assertEqual(sig, expected) def test_sha256(self): - assoc = association.Association.fromExpiresIn( - 3600, '{sha256SA}', 'very_secret', "HMAC-SHA256") - expected = ('\xfd\xaa\xfe;\xac\xfc*\x988\xad\x05d6-\xeaVy' - '\xd5\xa5Z.<\xa9\xed\x18\x82\\$\x95x\x1c&') + assoc = association.Association.fromExpiresIn(3600, '{sha256SA}', b'very_secret', "HMAC-SHA256") + expected = (b'\xfd\xaa\xfe;\xac\xfc*\x988\xad\x05d6-\xeaVy\xd5\xa5Z.<\xa9\xed\x18\x82\\$\x95x\x1c&') sig = assoc.sign(self.pairs) self.assertEqual(sig, expected) @@ -112,16 +108,14 @@ def setUp(self): 'xey': 'value'} def test_signSHA1(self): - assoc = association.Association.fromExpiresIn( - 3600, '{sha1}', 'very_secret', "HMAC-SHA1") + assoc = association.Association.fromExpiresIn(3600, '{sha1}', b'very_secret', "HMAC-SHA1") signed = assoc.signMessage(self.message) self.assertTrue(signed.getArg(OPENID_NS, "sig")) self.assertEqual(signed.getArg(OPENID_NS, "signed"), "assoc_handle,identifier,mode,ns,signed") self.assertEqual(signed.getArg(BARE_NS, "xey"), "value") def test_signSHA256(self): - assoc = association.Association.fromExpiresIn( - 3600, '{sha1}', 'very_secret', "HMAC-SHA256") + assoc = association.Association.fromExpiresIn(3600, '{sha1}', b'very_secret', "HMAC-SHA256") signed = assoc.signMessage(self.message) self.assertTrue(signed.getArg(OPENID_NS, "sig")) self.assertEqual(signed.getArg(OPENID_NS, "signed"), "assoc_handle,identifier,mode,ns,signed") @@ -136,6 +130,5 @@ def test_aintGotSignedList(self): 'sig': 'coyote', }) m.updateArgs(BARE_NS, {'xey': 'value'}) - assoc = association.Association.fromExpiresIn( - 3600, '{sha1}', 'very_secret', "HMAC-SHA1") + assoc = association.Association.fromExpiresIn(3600, '{sha1}', b'very_secret', "HMAC-SHA1") self.assertRaises(ValueError, assoc.checkMessageSignature, m) From cbc2ed111c1cef8d704d96602d68a6701e5b98d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Thu, 3 May 2018 16:47:01 +0200 Subject: [PATCH 072/151] Transform Yadis API to text strings --- openid/test/test_accept.py | 14 +++++++------ openid/test/test_etxrd.py | 15 ++++++++------ openid/test/test_openidyadis.py | 9 ++++++-- openid/test/test_services.py | 2 ++ openid/test/test_xri.py | 2 ++ openid/test/test_xrires.py | 1 + openid/test/test_yadis_discover.py | 3 +-- openid/yadis/accept.py | 33 +++++++++++++++++++++--------- openid/yadis/constants.py | 5 ++++- openid/yadis/discover.py | 7 ++++--- openid/yadis/etxrd.py | 23 ++++++++++++--------- openid/yadis/filters.py | 10 ++++++--- openid/yadis/manager.py | 7 +++++-- openid/yadis/services.py | 12 +++++------ openid/yadis/xri.py | 2 ++ openid/yadis/xrires.py | 14 +++++++------ 16 files changed, 102 insertions(+), 57 deletions(-) diff --git a/openid/test/test_accept.py b/openid/test/test_accept.py index 0b2fbb91..d499959b 100644 --- a/openid/test/test_accept.py +++ b/openid/test/test_accept.py @@ -1,4 +1,6 @@ """Test `openid.yadis.accept` module.""" +from __future__ import unicode_literals + import os.path import unittest @@ -8,13 +10,13 @@ def getTestData(): """Read the test data off of disk - () -> [(int, str)] + () -> [(int, six.text_type)] """ filename = os.path.join(os.path.dirname(__file__), 'data', 'accept.txt') i = 1 lines = [] for line in file(filename): - lines.append((i, line)) + lines.append((i, line.decode('utf-8'))) i += 1 return lines @@ -22,7 +24,7 @@ def getTestData(): def chunk(lines): """Return groups of lines separated by whitespace or comments - [(int, str)] -> [[(int, str)]] + [(int, six.text_type)] -> [[(int, six.text_type)]] """ chunks = [] chunk = [] @@ -44,7 +46,7 @@ def chunk(lines): def parseLines(chunk): """Take the given chunk of lines and turn it into a test data dictionary - [(int, str)] -> {str:(int, str)} + [(int, six.text_type)] -> {six.text_type:(int, six.text_type)} """ items = {} for (lineno, line) in chunk: @@ -58,7 +60,7 @@ def parseLines(chunk): def parseAvailable(available_text): """Parse an Available: line's data - str -> [str] + six.text_type -> [six.text_type] """ return [s.strip() for s in available_text.split(',')] @@ -66,7 +68,7 @@ def parseAvailable(available_text): def parseExpected(expected_text): """Parse an Expected: line's data - str -> [(str, float)] + six.text_type -> [(six.text_type, float)] """ expected = [] if expected_text: diff --git a/openid/test/test_etxrd.py b/openid/test/test_etxrd.py index 2b842d7d..6387dbb7 100644 --- a/openid/test/test_etxrd.py +++ b/openid/test/test_etxrd.py @@ -1,7 +1,10 @@ +from __future__ import unicode_literals + import os.path import tempfile import unittest +import six from lxml import etree from openid.yadis import etxrd, services, xri @@ -71,11 +74,11 @@ def test_xxe(self): xxe_file.write(xxe_content) # XXE example from Testing for XML Injection (OTG-INPVAL-008) # https://www.owasp.org/index.php/Testing_for_XML_Injection_(OTG-INPVAL-008) - xml = ('' - '' - ']>' - '&xxe;') + xml = (b'' + b'' + b']>' + b'&xxe;') tree = etxrd.parseXRDS(xml % tmp_file) self.assertNotIn(xxe_content, etree.tostring(tree)) finally: @@ -226,7 +229,7 @@ def test(self): # somewhere in the resolution chain. def _getCanonicalID(self, iname, xrds, expectedID): - if isinstance(expectedID, (str, unicode, type(None))): + if isinstance(expectedID, six.string_types + (type(None), )): cid = etxrd.getCanonicalID(iname, xrds) self.assertEqual(cid, expectedID and xri.XRI(expectedID)) elif issubclass(expectedID, etxrd.XRDSError): diff --git a/openid/test/test_openidyadis.py b/openid/test/test_openidyadis.py index 0b19ce2d..4e77b606 100644 --- a/openid/test/test_openidyadis.py +++ b/openid/test/test_openidyadis.py @@ -1,5 +1,9 @@ +from __future__ import unicode_literals + import unittest +import six + from openid.consumer.discover import OPENID_1_0_TYPE, OPENID_1_1_TYPE, OpenIDServiceEndpoint from openid.yadis.services import applyFilter @@ -16,7 +20,8 @@ def mkXRDS(services): - return XRDS_BOILERPLATE % (services,) + xrds = XRDS_BOILERPLATE % services + return xrds.encode('utf-8') def mkService(uris=None, type_uris=None, local_id=None, dent=' '): @@ -35,7 +40,7 @@ def mkService(uris=None, type_uris=None, local_id=None, dent=' '): chunks.extend([dent2, '', uri, '\n']) if local_id: diff --git a/openid/test/test_services.py b/openid/test/test_services.py index 94ae817a..925e1cf5 100644 --- a/openid/test/test_services.py +++ b/openid/test/test_services.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + import unittest from openid.yadis import services diff --git a/openid/test/test_xri.py b/openid/test/test_xri.py index 341472ed..fd37b653 100644 --- a/openid/test/test_xri.py +++ b/openid/test/test_xri.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + from unittest import TestCase from openid.yadis import xri diff --git a/openid/test/test_xrires.py b/openid/test/test_xrires.py index 9a02bec5..48435e3a 100644 --- a/openid/test/test_xrires.py +++ b/openid/test/test_xrires.py @@ -1,3 +1,4 @@ +from __future__ import unicode_literals from unittest import TestCase diff --git a/openid/test/test_yadis_discover.py b/openid/test/test_yadis_discover.py index 472eef37..8e7add87 100644 --- a/openid/test/test_yadis_discover.py +++ b/openid/test/test_yadis_discover.py @@ -1,10 +1,9 @@ -#!/usr/bin/env python - """Tests for yadis.discover. @todo: Now that yadis.discover uses urljr.fetchers, we should be able to do tests with a mock fetcher instead of spawning threads with BaseHTTPServer. """ +from __future__ import unicode_literals import re import types diff --git a/openid/yadis/accept.py b/openid/yadis/accept.py index 2353bfbf..faccffab 100644 --- a/openid/yadis/accept.py +++ b/openid/yadis/accept.py @@ -1,21 +1,28 @@ -"""Functions for generating and parsing HTTP Accept: headers for -supporting server-directed content negotiation. -""" +"""Functions for generating and parsing HTTP Accept: headers for supporting server-directed content negotiation.""" +from __future__ import unicode_literals + from operator import itemgetter +import six + +from openid.oidutil import string_to_text + def generateAcceptHeader(*elements): """Generate an accept header value - [str or (str, float)] -> str + [six.text_type or (six.text_type, float)] -> six.text_type """ parts = [] for element in elements: - if isinstance(element, str): + if isinstance(element, six.string_types): qs = "1.0" - mtype = element + mtype = string_to_text(element, + "Binary values for generateAcceptHeader are deprecated. Use text input instead.") else: mtype, q = element + mtype = string_to_text(mtype, + "Binary values for generateAcceptHeader are deprecated. Use text input instead.") q = float(q) if q > 1 or q <= 0: raise ValueError('Invalid preference factor: %r' % q) @@ -41,8 +48,9 @@ def parseAcceptHeader(value): returns a list of tuples containing main MIME type, MIME subtype, and quality markdown. - str -> [(str, str, float)] + six.text_type -> [(six.text_type, six.text_type, float)] """ + value = string_to_text(value, "Binary values for parseAcceptHeader are deprecated. Use text input instead.") chunks = [chunk.strip() for chunk in value.split(',')] accept = [] for chunk in chunks: @@ -86,7 +94,7 @@ def matchTypes(accept_types, have_types): [('text/html', 1.0), ('text/plain', 0.5)] - Type signature: ([(str, str, float)], [str]) -> [(str, float)] + Type signature: ([(six.text_type, six.text_type, float)], [six.text_type]) -> [(six.text_type, float)] """ if not accept_types: # Accept all of them @@ -97,6 +105,8 @@ def matchTypes(accept_types, have_types): match_main = {} match_sub = {} for (main, sub, qvalue) in accept_types: + main = string_to_text(main, "Binary values for matchTypes accept_types are deprecated. Use text input instead.") + sub = string_to_text(sub, "Binary values for matchTypes accept_types are deprecated. Use text input instead.") if main == '*': default = max(default, qvalue) continue @@ -108,6 +118,7 @@ def matchTypes(accept_types, have_types): accepted_list = [] order_maintainer = 0 for mtype in have_types: + mtype = string_to_text(mtype, "Binary values for matchTypes have_types are deprecated. Use text input instead.") main, sub = mtype.split('/') if (main, sub) in match_sub: quality = match_sub[(main, sub)] @@ -119,7 +130,7 @@ def matchTypes(accept_types, have_types): order_maintainer += 1 accepted_list.sort() - return [(mtype, q) for (_, _, q, mtype) in accepted_list] + return [(match, q) for (_, _, q, match) in accepted_list] def getAcceptable(accept_header, have_types): @@ -130,8 +141,10 @@ def getAcceptable(accept_header, have_types): This is a convenience wrapper around matchTypes and parseAcceptHeader. - (str, [str]) -> [str] + (six.text_type, [six.text_type]) -> [six.text_type] """ + accept_header = string_to_text( + accept_header, "Binary values for getAcceptable accept_header are deprecated. Use text input instead.") accepted = parseAcceptHeader(accept_header) preferred = matchTypes(accepted, have_types) return [mtype for (mtype, _) in preferred] diff --git a/openid/yadis/constants.py b/openid/yadis/constants.py index d160c66f..6edb4640 100644 --- a/openid/yadis/constants.py +++ b/openid/yadis/constants.py @@ -1,6 +1,9 @@ -__all__ = ['YADIS_HEADER_NAME', 'YADIS_CONTENT_TYPE', 'YADIS_ACCEPT_HEADER'] +from __future__ import unicode_literals + from openid.yadis.accept import generateAcceptHeader +__all__ = ['YADIS_HEADER_NAME', 'YADIS_CONTENT_TYPE', 'YADIS_ACCEPT_HEADER'] + YADIS_HEADER_NAME = 'X-XRDS-Location' YADIS_CONTENT_TYPE = 'application/xrds+xml' diff --git a/openid/yadis/discover.py b/openid/yadis/discover.py index 83655a90..e1b494fa 100644 --- a/openid/yadis/discover.py +++ b/openid/yadis/discover.py @@ -1,5 +1,4 @@ -# -*- test-case-name: openid.test.test_yadis_discover -*- -__all__ = ['discover', 'DiscoveryResult', 'DiscoveryFailure'] +from __future__ import unicode_literals from StringIO import StringIO @@ -7,6 +6,8 @@ from openid.yadis.constants import YADIS_ACCEPT_HEADER, YADIS_CONTENT_TYPE, YADIS_HEADER_NAME from openid.yadis.parsehtml import MetaNotFound, findHTMLMeta +__all__ = ['discover', 'DiscoveryResult', 'DiscoveryFailure'] + class DiscoveryFailure(Exception): """Raised when a YADIS protocol error occurs in the discovery process""" @@ -107,7 +108,7 @@ def whereIsYadis(resp): [non-blocking] - @returns: str or None + @returns: six.text_type or None """ # Attempt to find out where to go to discover the document # or if we already have it diff --git a/openid/yadis/etxrd.py b/openid/yadis/etxrd.py index 4039c9ec..019b82ef 100644 --- a/openid/yadis/etxrd.py +++ b/openid/yadis/etxrd.py @@ -2,6 +2,15 @@ """ ElementTree interface to an XRD document. """ +from __future__ import unicode_literals + +import random +from datetime import datetime +from time import strptime + +from lxml import etree + +from openid.yadis import xri __all__ = [ 'nsTag', @@ -18,14 +27,6 @@ 'expandServices', ] -import random -from datetime import datetime -from time import strptime - -from lxml import etree - -from openid.yadis import xri - class XRDSError(Exception): """An error with the XRDS document.""" @@ -43,6 +44,8 @@ class XRDSFraud(XRDSError): def parseXRDS(text): """Parse the given text as an XRDS document. + @type text: six.binary_type + @return: ElementTree containing an XRDS document @raises XRDSError: When there is a parse error or the document does @@ -72,7 +75,7 @@ def nsTag(ns, t): def mkXRDTag(t): - """basestring -> basestring + """six.text_type -> six.text_type Create a tag name in the XRD 2.0 XML namespace suitable for using with ElementTree @@ -81,7 +84,7 @@ def mkXRDTag(t): def mkXRDSTag(t): - """basestring -> basestring + """six.text_type -> six.text_type Create a tag name in the XRDS XML namespace suitable for using with ElementTree diff --git a/openid/yadis/filters.py b/openid/yadis/filters.py index 0d87ad0e..c60e5ecc 100644 --- a/openid/yadis/filters.py +++ b/openid/yadis/filters.py @@ -2,6 +2,10 @@ endpoint information out of a Yadis XRD file using the ElementTree XML parser. """ +from __future__ import unicode_literals + +from openid.oidutil import string_to_text +from openid.yadis.etxrd import expandService __all__ = [ 'BasicServiceEndpoint', @@ -11,8 +15,6 @@ 'CompoundFilter', ] -from openid.yadis.etxrd import expandService - class BasicServiceEndpoint(object): """Generic endpoint object that contains parsed service @@ -41,11 +43,13 @@ def matchTypes(self, type_uris): of a single protocol. @param type_uris: The URIs that you wish to check - @type type_uris: iterable of str + @type type_uris: Iterable[six.text_type], six.binary_type is deprecated @return: all types that are in both in type_uris and self.type_uris """ + type_uris = [string_to_text(u, "Binary values for matchTypes are deprecated. Use text input instead.") + for u in type_uris] return [uri for uri in type_uris if uri in self.type_uris] @staticmethod diff --git a/openid/yadis/manager.py b/openid/yadis/manager.py index afd55eea..4da7c641 100644 --- a/openid/yadis/manager.py +++ b/openid/yadis/manager.py @@ -1,3 +1,6 @@ +from __future__ import unicode_literals + + class YadisServiceManager(object): """Holds the state of a list of selected Yadis services, managing storing it in a session and iterating over the services in order.""" @@ -93,7 +96,7 @@ def getNextService(self, discover): @param discover: a callable that takes a URL and returns a list of services - @type discover: str -> [service] + @type discover: six.text_type -> [service] @return: the next available service @@ -140,7 +143,7 @@ def getSessionKey(self): """Get the session key for this starting URL and suffix @return: The session key - @rtype: str + @rtype: six.text_type """ return self.PREFIX + self.session_key_suffix diff --git a/openid/yadis/services.py b/openid/yadis/services.py index 740fec0a..f45d4e9a 100644 --- a/openid/yadis/services.py +++ b/openid/yadis/services.py @@ -1,4 +1,6 @@ -# -*- test-case-name: openid.test.test_services -*- +from __future__ import unicode_literals + +import six from openid.yadis.discover import DiscoveryFailure, discover from openid.yadis.etxrd import XRDSError, iterServices, parseXRDS @@ -19,7 +21,7 @@ def getServiceEndpoints(input_url, flt=None): @return: The normalized identity URL and an iterable of endpoint objects generated by the filter function. - @rtype: (str, [endpoint]) + @rtype: (six.text_type, [endpoint]) @raises DiscoveryFailure: when Yadis fails to obtain an XRDS document. """ @@ -28,7 +30,7 @@ def getServiceEndpoints(input_url, flt=None): endpoints = applyFilter(result.normalized_uri, result.response_text, flt) except XRDSError as err: - raise DiscoveryFailure(str(err), None) + raise DiscoveryFailure(six.text_type(err), None) return (result.normalized_uri, endpoints) @@ -39,11 +41,9 @@ def applyFilter(normalized_uri, xrd_data, flt=None): @param normalized_uri: The input URL, after following redirects, as in the Yadis protocol. - @param xrd_data: The XML text the XRDS file fetched from the normalized URI. - @type xrd_data: str - + @type xrd_data: six.binary_type """ flt = mkFilter(flt) et = parseXRDS(xrd_data) diff --git a/openid/yadis/xri.py b/openid/yadis/xri.py index 7728c600..ea394f15 100644 --- a/openid/yadis/xri.py +++ b/openid/yadis/xri.py @@ -4,6 +4,8 @@ @see: XRI Syntax v2.0 at the U{OASIS XRI Technical Committee} """ +from __future__ import unicode_literals + import re import warnings from urllib import quote diff --git a/openid/yadis/xrires.py b/openid/yadis/xrires.py index fe54c48d..26bffb34 100644 --- a/openid/yadis/xrires.py +++ b/openid/yadis/xrires.py @@ -1,10 +1,10 @@ -# -*- test-case-name: openid.test.test_xrires -*- -"""XRI resolution. -""" +"""XRI resolution.""" +from __future__ import unicode_literals from urllib import urlencode from openid import fetchers +from openid.oidutil import string_to_text from openid.yadis import etxrd from openid.yadis.services import iterServices from openid.yadis.xri import toURINormal @@ -27,10 +27,10 @@ def queryURL(self, xri, service_type=None): @param service_type: The service type to resolve, if you desire service endpoint selection. A service type is a URI. - @type service_type: str + @type service_type: Optional[six.text_type], six.binary_type is deprecated @returns: a URL - @returntype: str + @returntype: six.text_type """ # Trim off the xri:// prefix. The proxy resolver didn't accept it # when this code was written, but that may (or may not) change for @@ -45,6 +45,8 @@ def queryURL(self, xri, service_type=None): '_xrd_r': 'application/xrds+xml', } if service_type: + service_type = string_to_text(service_type, + "Binary values for service_type are deprecated. Use text input instead.") args['_xrd_t'] = service_type else: # Don't perform service endpoint selection. @@ -67,7 +69,7 @@ def query(self, xri, service_types): @param service_types: A list of services types to query for. Service types are URIs. - @type service_types: list of str + @type service_types: List[six.text_type], six.binary_type is deprecated @returns: tuple of (CanonicalID, Service elements) @returntype: (six.text_type, list of C{ElementTree.Element}s) From 61e6d6a135600db480ed14ec8f6fbb8676668dc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 4 May 2018 13:41:04 +0200 Subject: [PATCH 073/151] Transform Extensions API to text strings --- openid/extensions/ax.py | 39 ++++++++++++++++++++------------ openid/extensions/draft/pape2.py | 21 +++++++++-------- openid/extensions/pape.py | 23 +++++++++++-------- openid/extensions/sreg.py | 37 ++++++++++++++++++++---------- openid/sreg.py | 1 + openid/test/test_ax.py | 4 ++-- openid/test/test_pape.py | 2 ++ openid/test/test_pape_draft2.py | 2 ++ openid/test/test_pape_draft5.py | 2 ++ openid/test/test_sreg.py | 2 ++ 10 files changed, 85 insertions(+), 48 deletions(-) diff --git a/openid/extensions/ax.py b/openid/extensions/ax.py index 39b85cdc..6bbad7c0 100644 --- a/openid/extensions/ax.py +++ b/openid/extensions/ax.py @@ -1,8 +1,15 @@ -# -*- test-case-name: openid.test.test_ax -*- """Implements the OpenID Attribute Exchange specification, version 1.0. @since: 2.1.0 """ +from __future__ import unicode_literals + +import six + +from openid import extension +from openid.message import OPENID_NS, NamespaceMap +from openid.oidutil import string_to_text +from openid.server.trustroot import TrustRoot __all__ = [ 'AttrInfo', @@ -12,10 +19,6 @@ 'StoreResponse', ] -from openid import extension -from openid.message import OPENID_NS, NamespaceMap -from openid.server.trustroot import TrustRoot - # Use this as the 'count' value for an attribute in a FetchRequest to # ask for as many values as the OP can provide. UNLIMITED_VALUES = "unlimited" @@ -107,7 +110,7 @@ class AttrInfo(object): represents and how it is serialized. For example, one type URI representing dates could represent a Unix timestamp in base 10 and another could represent a human-readable string. - @type type_uri: str + @type type_uri: six.text_type @ivar alias: The name that should be given to this alias in the request. If it is not supplied, a generic name will be @@ -115,7 +118,7 @@ class AttrInfo(object): value 'tstamp', set its alias to that value. If two attributes in the same message request to use the same alias, the request will fail to be generated. - @type alias: str or NoneType + @type alias: six.text_type or NoneType """ def __init__(self, type_uri, count=1, required=False, alias=None): @@ -148,7 +151,7 @@ def toTypeURIs(namespace_map, alias_list_s): @param alias_list_s: The string containing the comma-separated list of aliases. May also be None for convenience. - @type alias_list_s: str or NoneType + @type alias_list_s: Optional[six.text_type], six.binary_type is deprecated @returns: The list of namespace URIs that corresponds to the supplied list of aliases. If the string was zero-length or @@ -160,6 +163,8 @@ def toTypeURIs(namespace_map, alias_list_s): uris = [] if alias_list_s: + alias_list_s = string_to_text(alias_list_s, + "Binary values for alias_list_s are deprecated. Use text input instead.") for alias in alias_list_s.split(','): type_uri = namespace_map.getNamespaceURI(alias) if type_uri is None: @@ -178,7 +183,7 @@ class FetchRequest(AXMessage): @ivar requested_attributes: The attributes that have been requested thus far, indexed by the type URI. - @type requested_attributes: {str:AttrInfo} + @type requested_attributes: Dict[six.text_type, AttrInfo] @ivar update_url: A URL that will accept responses for this attribute exchange request, even in the absence of the user @@ -246,7 +251,7 @@ def getExtensionArgs(self): if_available.append(alias) if attribute.count != 1: - ax_args['count.' + alias] = str(attribute.count) + ax_args['count.' + alias] = six.text_type(attribute.count) ax_args['type.' + alias] = type_uri @@ -264,7 +269,7 @@ def getRequiredAttrs(self): @returns: A list of the type URIs for attributes that have been marked as required. - @rtype: [str] + @rtype: List[six.text_type] """ required = [] for type_uri, attribute in self.requested_attributes.iteritems(): @@ -457,7 +462,7 @@ def _getExtensionKVArgs(self, aliases=None): alias = aliases.add(type_uri) ax_args['type.' + alias] = type_uri - ax_args['count.' + alias] = str(len(values)) + ax_args['count.' + alias] = six.text_type(len(values)) for i, value in enumerate(values): key = 'value.%s.%d' % (alias, i + 1) @@ -498,7 +503,7 @@ def parseExtensionArgs(self, ax_args): except KeyError: value = ax_args['value.' + alias] - if value == u'': + if value == '': values = [] else: values = [value] @@ -517,8 +522,8 @@ def getSingle(self, type_uri, default=None): for this attribute, use the supplied default. If there is more than one value for this attribute, this method will fail. - @type type_uri: str @param type_uri: The URI for the attribute + @type type_uri: six.text_type, six.binary_type is deprecated @param default: The value to return if the attribute was not sent in the fetch_response. @@ -531,6 +536,7 @@ def getSingle(self, type_uri, default=None): parameter in the fetch_response message. @raises KeyError: If the attribute was not sent in this response """ + type_uri = string_to_text(type_uri, "Binary values for type_uri are deprecated. Use text input instead.") values = self.data.get(type_uri) if not values: return default @@ -593,9 +599,12 @@ def __init__(self, request=None, update_url=None): request. But if you do not supply the request, you may set the C{update_url} here. - @type update_url: str + @type update_url: Optional[six.text_type], six.binary_type is deprecated """ AXKeyValueMessage.__init__(self) + if update_url is not None: + update_url = string_to_text(update_url, + "Binary values for update_url are deprecated. Use text input instead.") self.update_url = update_url self.request = request diff --git a/openid/extensions/draft/pape2.py b/openid/extensions/draft/pape2.py index a26ddfcb..529f329d 100644 --- a/openid/extensions/draft/pape2.py +++ b/openid/extensions/draft/pape2.py @@ -5,6 +5,14 @@ @since: 2.1.0 """ +from __future__ import unicode_literals + +import re +import warnings + +import six + +from openid.extension import Extension __all__ = [ 'Request', @@ -15,11 +23,6 @@ 'AUTH_MULTI_FACTOR_PHYSICAL', ] -import re -import warnings - -from openid.extension import Extension - warnings.warn("Module 'openid.extensions.draft.pape2' is deprecated. Use 'openid.extensions.pape' instead.", DeprecationWarning) @@ -41,7 +44,7 @@ class Request(Extension): @ivar preferred_auth_policies: The authentication policies that the relying party prefers - @type preferred_auth_policies: [str] + @type preferred_auth_policies: List[six.text_type] @ivar max_auth_age: The maximum time, in seconds, that the relying party wants to allow to have elapsed before the user must @@ -84,7 +87,7 @@ def getExtensionArgs(self): } if self.max_auth_age is not None: - ns_args['max_auth_age'] = str(self.max_auth_age) + ns_args['max_auth_age'] = six.text_type(self.max_auth_age) return ns_args @@ -147,7 +150,7 @@ def preferredTypes(self, supported_types): sequence, and may be empty if the provider does not prefer any of the supported authentication types. - @returntype: [str] + @returntype: List[six.text_type] """ return [i for i in supported_types if i in self.preferred_auth_policies] @@ -268,7 +271,7 @@ def getExtensionArgs(self): if self.nist_auth_level not in range(0, 5): raise ValueError('nist_auth_level must be an integer between ' 'zero and four, inclusive') - ns_args['nist_auth_level'] = str(self.nist_auth_level) + ns_args['nist_auth_level'] = six.text_type(self.nist_auth_level) if self.auth_time is not None: if not TIME_VALIDATOR.match(self.auth_time): diff --git a/openid/extensions/pape.py b/openid/extensions/pape.py index 70655006..d69c4dbd 100644 --- a/openid/extensions/pape.py +++ b/openid/extensions/pape.py @@ -5,6 +5,14 @@ @since: 2.1.0 """ +from __future__ import unicode_literals + +import re +import warnings + +import six + +from openid.extension import Extension __all__ = [ 'Request', @@ -17,11 +25,6 @@ 'LEVELS_JISA', ] -import re -import warnings - -from openid.extension import Extension - ns_uri = "http://specs.openid.net/extensions/pape/1.0" AUTH_MULTI_FACTOR_PHYSICAL = \ @@ -98,7 +101,7 @@ class Request(PAPEExtension): @ivar preferred_auth_policies: The authentication policies that the relying party prefers - @type preferred_auth_policies: [str] + @type preferred_auth_policies: List[six.text_type] @ivar max_auth_age: The maximum time, in seconds, that the relying party wants to allow to have elapsed before the user must @@ -108,7 +111,7 @@ class Request(PAPEExtension): @ivar preferred_auth_level_types: Ordered list of authentication level namespace URIs - @type preferred_auth_level_types: [str] + @type preferred_auth_level_types: List[six.text_type] """ ns_alias = 'pape' @@ -158,7 +161,7 @@ def getExtensionArgs(self): } if self.max_auth_age is not None: - ns_args['max_auth_age'] = str(self.max_auth_age) + ns_args['max_auth_age'] = six.text_type(self.max_auth_age) if self.preferred_auth_level_types: preferred_types = [] @@ -262,7 +265,7 @@ def preferredTypes(self, supported_types): sequence, and may be empty if the provider does not prefer any of the supported authentication types. - @returntype: [str] + @returntype: List[six.text_type] """ return [i for i in supported_types if i in self.preferred_auth_policies] @@ -459,7 +462,7 @@ def getExtensionArgs(self): for level_type, level in self.auth_levels.iteritems(): alias = self._getAlias(level_type) ns_args['auth_level.ns.%s' % (alias,)] = level_type - ns_args['auth_level.%s' % (alias,)] = str(level) + ns_args['auth_level.%s' % (alias,)] = six.text_type(level) if self.auth_time is not None: if not TIME_VALIDATOR.match(self.auth_time): diff --git a/openid/extensions/sreg.py b/openid/extensions/sreg.py index 4bdb262e..7523e563 100644 --- a/openid/extensions/sreg.py +++ b/openid/extensions/sreg.py @@ -34,11 +34,15 @@ @var sreg_uri: The preferred URI to use for the simple registration namespace and XRD Type value """ +from __future__ import unicode_literals import logging +import six + from openid.extension import Extension from openid.message import NamespaceAliasRegistrationError, registerNamespaceAlias +from openid.oidutil import string_to_text __all__ = [ 'SRegRequest', @@ -138,7 +142,7 @@ def getSRegNS(message): @returns: the sreg namespace URI for the supplied message. The message may be modified to define a simple registration namespace. - @rtype: C{str} + @rtype: six.text_type @raise ValueError: when using OpenID 1 if the message defines the 'sreg' alias to be something other than a simple @@ -169,14 +173,14 @@ class SRegRequest(Extension): @ivar required: A list of the required fields in this simple registration request - @type required: [str] + @type required: List[six.text_type] @ivar optional: A list of the optional fields in this simple registration request - @type optional: [str] + @type optional: List[six.text_type] @ivar policy_url: The policy URL that was provided with the request - @type policy_url: str or NoneType + @type policy_url: Optional[six.text_type] @group Consumer: requestField, requestFields, getExtensionArgs, addToOpenIDRequest @group Server: fromOpenIDRequest, parseExtensionArgs @@ -246,7 +250,7 @@ def parseExtensionArgs(self, args, strict=False): >>> request.parseExtensionArgs(args) @param args: The unqualified simple registration arguments - @type args: {str:str} + @type args: Dict[six.text_type, six.text_type], six.binary_type is deprecated @param strict: Whether requests with fields that are not defined in the simple registration specification should be @@ -259,6 +263,7 @@ def parseExtensionArgs(self, args, strict=False): required = (list_name == 'required') items = args.get(list_name) if items: + items = string_to_text(items, "Binary values for args are deprecated. Use text input instead.") for field_name in items.split(','): try: self.requestField(field_name, required, strict) @@ -266,13 +271,17 @@ def parseExtensionArgs(self, args, strict=False): if strict: raise - self.policy_url = args.get('policy_url') + policy_url = args.get('policy_url') + if policy_url is not None: + policy_url = string_to_text(args.get('policy_url'), + "Binary values for args are deprecated. Use text input instead.") + self.policy_url = policy_url def allRequestedFields(self): """A list of all of the simple registration fields that were requested, whether they were required or optional. - @rtype: [str] + @rtype: List[six.text_type] """ return self.required + self.optional @@ -292,7 +301,7 @@ def requestField(self, field_name, required=False, strict=False): """Request the specified field from the OpenID user @param field_name: the unqualified simple registration field name - @type field_name: str + @type field_name: six.text_type, six.binary_type is deprecated @param required: whether the given field should be presented to the user as being a required to successfully complete @@ -305,6 +314,7 @@ def requestField(self, field_name, required=False, strict=False): registration field or strict is set and the field was requested more than once """ + field_name = string_to_text(field_name, "Binary values for field_name are deprecated. Use text input instead.") checkFieldName(field_name) if strict: @@ -329,7 +339,7 @@ def requestFields(self, field_names, required=False, strict=False): """Add the given list of fields to the request @param field_names: The simple registration data fields to request - @type field_names: [str] + @type field_names: List[six.text_type], six.binary_type is deprecated @param required: Whether these values should be presented to the user as required @@ -341,11 +351,13 @@ def requestFields(self, field_names, required=False, strict=False): registration field or strict is set and a field was requested more than once """ - if isinstance(field_names, basestring): + if isinstance(field_names, six.string_types): raise TypeError('Fields should be passed as a list of ' 'strings (not %r)' % (type(field_names),)) for field_name in field_names: + field_name = string_to_text(field_name, + "Binary values for field_names are deprecated. Use text input instead.") self.requestField(field_name, required, strict=strict) def getExtensionArgs(self): @@ -356,7 +368,7 @@ def getExtensionArgs(self): C{L{parseExtensionArgs}}. This method serializes the simple registration request fields. - @rtype: {str:str} + @rtype: Dict[six.text_type, six.text_type] """ args = {} @@ -417,7 +429,7 @@ def extractResponse(cls, request, data): registration field name to string (unicode) value. For instance, the nickname should be stored under the key 'nickname'. - @type data: {str:str} + @type data: Dict[six.text_type, six.text_type], six.binary_type is deprecated @returns: a simple registration response object @rtype: SRegResponse @@ -427,6 +439,7 @@ def extractResponse(cls, request, data): for field in request.allRequestedFields(): value = data.get(field) if value is not None: + value = string_to_text(value, "Binary values for data are deprecated. Use text input instead.") self.data[field] = value return self diff --git a/openid/sreg.py b/openid/sreg.py index bceb53fe..20b0d8ac 100644 --- a/openid/sreg.py +++ b/openid/sreg.py @@ -1,4 +1,5 @@ """moved to L{openid.extensions.sreg}""" +from __future__ import unicode_literals import warnings diff --git a/openid/test/test_ax.py b/openid/test/test_ax.py index 3221169c..adff4d9f 100644 --- a/openid/test/test_ax.py +++ b/openid/test/test_ax.py @@ -1,5 +1,5 @@ -"""Tests for the attribute exchange extension module -""" +"""Tests for the attribute exchange extension module.""" +from __future__ import unicode_literals import unittest diff --git a/openid/test/test_pape.py b/openid/test/test_pape.py index 056fb891..d6f55cd6 100644 --- a/openid/test/test_pape.py +++ b/openid/test/test_pape.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + import unittest import warnings diff --git a/openid/test/test_pape_draft2.py b/openid/test/test_pape_draft2.py index 67ebcc73..a1d3d87d 100644 --- a/openid/test/test_pape_draft2.py +++ b/openid/test/test_pape_draft2.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + import unittest from openid.extensions.draft import pape2 as pape diff --git a/openid/test/test_pape_draft5.py b/openid/test/test_pape_draft5.py index 95852066..fdb783d7 100644 --- a/openid/test/test_pape_draft5.py +++ b/openid/test/test_pape_draft5.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + import unittest from openid.extensions import pape diff --git a/openid/test/test_sreg.py b/openid/test/test_sreg.py index 80fd4420..b56c5c63 100644 --- a/openid/test/test_sreg.py +++ b/openid/test/test_sreg.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + import unittest from openid.extensions import sreg From 2f57983ab8ea63f4438acd048cce50cce1e142d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 9 May 2018 13:24:16 +0200 Subject: [PATCH 074/151] Transform remaining API to text strings --- .flake8 | 2 +- .isort.cfg | 1 + admin/builddiscover.py | 2 + admin/gettlds.py | 2 + contrib/associate | 4 +- contrib/openid-parse | 5 +- contrib/upgrade-store-1.1-to-2.0 | 9 +- examples/consumer.py | 6 +- examples/discover | 4 +- examples/djopenid/consumer/models.py | 1 + examples/djopenid/consumer/urls.py | 2 + examples/djopenid/consumer/views.py | 5 +- examples/djopenid/manage.py | 2 + examples/djopenid/server/models.py | 1 + examples/djopenid/server/tests.py | 7 +- examples/djopenid/server/urls.py | 2 + examples/djopenid/server/views.py | 4 +- examples/djopenid/settings.py | 2 + examples/djopenid/urls.py | 2 + examples/djopenid/util.py | 2 + examples/server.py | 2 + openid/consumer/consumer.py | 51 +++++----- openid/consumer/discover.py | 37 +++---- openid/extension.py | 2 + openid/extensions/draft/pape5.py | 2 + openid/fetchers.py | 21 ++-- openid/server/server.py | 119 ++++++++++++++--------- openid/server/trustroot.py | 34 ++++--- openid/store/filestore.py | 45 ++++++--- openid/store/interface.py | 28 ++---- openid/store/memstore.py | 5 +- openid/store/nonce.py | 25 +++-- openid/store/sqlstore.py | 52 ++++++---- openid/test/discoverdata.py | 2 + openid/test/test_association_response.py | 4 +- openid/test/test_auth_request.py | 2 + openid/test/test_consumer.py | 17 ++-- openid/test/test_discover.py | 2 + openid/test/test_extension.py | 2 + openid/test/test_fetchers.py | 24 ++--- openid/test/test_htmldiscover.py | 2 + openid/test/test_negotiation.py | 14 ++- openid/test/test_nonce.py | 2 + openid/test/test_rpverify.py | 18 ++-- openid/test/test_server.py | 38 +++----- openid/test/test_storetest.py | 2 + openid/test/test_symbol.py | 2 + openid/test/test_trustroot.py | 4 +- openid/test/test_verifydisco.py | 2 + openid/test/utils.py | 2 + openid/urinorm.py | 7 +- setup.py | 2 + 52 files changed, 385 insertions(+), 251 deletions(-) diff --git a/.flake8 b/.flake8 index 75ab4379..843eae45 100644 --- a/.flake8 +++ b/.flake8 @@ -2,4 +2,4 @@ max-line-length = 120 # Ignore E123 - enforce hang-closing instead ignore = E123,W503 -max-complexity = 22 +max-complexity = 24 diff --git a/.isort.cfg b/.isort.cfg index 3bf03262..4d1707b3 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -3,3 +3,4 @@ line_length = 120 combine_as_imports = true default_section = THIRDPARTY known_first_party = openid +add_imports = from __future__ import unicode_literals diff --git a/admin/builddiscover.py b/admin/builddiscover.py index 82681280..9b5da070 100755 --- a/admin/builddiscover.py +++ b/admin/builddiscover.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +from __future__ import unicode_literals + import os.path import urlparse diff --git a/admin/gettlds.py b/admin/gettlds.py index b2a7c92c..c4892769 100644 --- a/admin/gettlds.py +++ b/admin/gettlds.py @@ -8,6 +8,8 @@ Then cut-n-paste. """ +from __future__ import unicode_literals + import sys import urllib2 diff --git a/contrib/associate b/contrib/associate index 17eae99f..d84cfb31 100755 --- a/contrib/associate +++ b/contrib/associate @@ -1,6 +1,6 @@ #!/usr/bin/env python -"""Make an OpenID Assocition request against an endpoint -and print the results.""" +"""Make an OpenID Assocition request against an endpoint and print the results.""" +from __future__ import unicode_literals import sys from datetime import datetime diff --git a/contrib/openid-parse b/contrib/openid-parse index 5915ad36..b6c59ea9 100644 --- a/contrib/openid-parse +++ b/contrib/openid-parse @@ -6,6 +6,7 @@ with a pattern like 'GET /foo?bar=baz HTTP'. Requires the 'xsel' program to get the contents of the clipboard. """ +from __future__ import unicode_literals import re import subprocess @@ -13,6 +14,8 @@ import sys from pprint import pformat from urlparse import parse_qs, urlsplit, urlunsplit +import six + from openid import message OPENID_SORT_ORDER = ['mode', 'identity', 'claimed_id'] @@ -78,7 +81,7 @@ def openidFromQuery(query): s = formatOpenIDMessage(msg) except Exception as err: # XXX - side effect. - sys.stderr.write(str(err)) + sys.stderr.write(six.text_type(err)) s = pformat(query) return s diff --git a/contrib/upgrade-store-1.1-to-2.0 b/contrib/upgrade-store-1.1-to-2.0 index 2e73c4a2..48a62552 100644 --- a/contrib/upgrade-store-1.1-to-2.0 +++ b/contrib/upgrade-store-1.1-to-2.0 @@ -13,12 +13,15 @@ # TODO: # * test data for mysql and postgresql. # * automated tests. +from __future__ import unicode_literals import getpass import os import sys from optparse import OptionParser +import six + def askForPassword(): return getpass.getpass("DB Password: ") @@ -111,7 +114,7 @@ def main(argv=None): try: db_conn = sqlite.connect(options.sqlite_db_name) except Exception as e: - print "Could not connect to SQLite database:", str(e) + print "Could not connect to SQLite database:", six.text_type(e) return 1 if askForConfirmation(options.sqlite_db_name, options.tablename): @@ -134,7 +137,7 @@ def main(argv=None): host=options.db_host, password=password) except Exception as e: - print "Could not connect to PostgreSQL database:", str(e) + print "Could not connect to PostgreSQL database:", six.text_type(e) return 1 if askForConfirmation(options.postgres_db_name, options.tablename): @@ -155,7 +158,7 @@ def main(argv=None): db_conn = MySQLdb.connect(options.db_host, options.username, password, options.mysql_db_name) except Exception as e: - print "Could not connect to MySQL database:", str(e) + print "Could not connect to MySQL database:", six.text_type(e) return 1 if askForConfirmation(options.mysql_db_name, options.tablename): diff --git a/examples/consumer.py b/examples/consumer.py index d39a2608..dbce4117 100644 --- a/examples/consumer.py +++ b/examples/consumer.py @@ -8,6 +8,8 @@ """ __copyright__ = 'Copyright 2005-2008, Janrain, Inc.' +from __future__ import unicode_literals + import cgi import cgitb import optparse @@ -16,6 +18,8 @@ from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer from Cookie import SimpleCookie +import six + def quoteattr(s): qs = cgi.escape(s, 1) @@ -176,7 +180,7 @@ def doVerify(self): request = oidconsumer.begin(openid_url) except consumer.DiscoveryFailure as exc: fetch_error_string = 'Error in discovery: %s' % ( - cgi.escape(str(exc[0]))) + cgi.escape(six.text_type(exc[0]))) self.render(fetch_error_string, css_class='error', form_contents=openid_url) diff --git a/examples/discover b/examples/discover index e2ede67e..b334d94e 100644 --- a/examples/discover +++ b/examples/discover @@ -1,5 +1,7 @@ #!/usr/bin/env python -from openid.consumer.discover import discover, DiscoveryFailure +from __future__ import unicode_literals + +from openid.consumer.discover import DiscoveryFailure, discover from openid.fetchers import HTTPFetchingError names = [["server_url", "Server URL "], diff --git a/examples/djopenid/consumer/models.py b/examples/djopenid/consumer/models.py index b194906e..d9781327 100644 --- a/examples/djopenid/consumer/models.py +++ b/examples/djopenid/consumer/models.py @@ -1 +1,2 @@ """Required module for Django application.""" +from __future__ import unicode_literals diff --git a/examples/djopenid/consumer/urls.py b/examples/djopenid/consumer/urls.py index 9b37b1aa..b13f966e 100644 --- a/examples/djopenid/consumer/urls.py +++ b/examples/djopenid/consumer/urls.py @@ -1,4 +1,6 @@ """Consumer URLs.""" +from __future__ import unicode_literals + from django.conf.urls import url from djopenid.consumer.views import finishOpenID, rpXRDS, startOpenID diff --git a/examples/djopenid/consumer/views.py b/examples/djopenid/consumer/views.py index 74d26fd5..776dd0dc 100644 --- a/examples/djopenid/consumer/views.py +++ b/examples/djopenid/consumer/views.py @@ -1,3 +1,6 @@ +from __future__ import unicode_literals + +import six from django.http import HttpResponseRedirect from django.shortcuts import render from django.urls import reverse @@ -70,7 +73,7 @@ def startOpenID(request): auth_request = c.begin(openid_url) except DiscoveryFailure as e: # Some other protocol-level failure occurred. - error = "OpenID discovery error: %s" % (str(e),) + error = "OpenID discovery error: %s" % (six.text_type(e),) if error: # Render the page with an error. diff --git a/examples/djopenid/manage.py b/examples/djopenid/manage.py index fb88042f..2e2e83af 100644 --- a/examples/djopenid/manage.py +++ b/examples/djopenid/manage.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +from __future__ import unicode_literals + import os import sys diff --git a/examples/djopenid/server/models.py b/examples/djopenid/server/models.py index b194906e..d9781327 100644 --- a/examples/djopenid/server/models.py +++ b/examples/djopenid/server/models.py @@ -1 +1,2 @@ """Required module for Django application.""" +from __future__ import unicode_literals diff --git a/examples/djopenid/server/tests.py b/examples/djopenid/server/tests.py index 2a3b86b4..bdc08849 100644 --- a/examples/djopenid/server/tests.py +++ b/examples/djopenid/server/tests.py @@ -1,12 +1,13 @@ +from __future__ import unicode_literals + from urlparse import urljoin import django from django.http import HttpRequest from django.test.testcases import TestCase from django.urls import reverse - from openid.message import Message -from openid.server.server import CheckIDRequest, HTTP_REDIRECT +from openid.server.server import HTTP_REDIRECT, CheckIDRequest from openid.yadis.constants import YADIS_CONTENT_TYPE from openid.yadis.services import applyFilter @@ -85,7 +86,7 @@ def test_unreachableRealm(self): views.setRequest(self.request, self.openid_request) response = views.showDecidePage(self.request, self.openid_request) - self.assertIn('trust_root_valid is Unreachable', response.content) + self.assertContains(response, 'trust_root_valid is Unreachable') class TestGenericXRDS(TestCase): diff --git a/examples/djopenid/server/urls.py b/examples/djopenid/server/urls.py index 2eff514f..ead27ff8 100644 --- a/examples/djopenid/server/urls.py +++ b/examples/djopenid/server/urls.py @@ -1,4 +1,6 @@ """Server URLs.""" +from __future__ import unicode_literals + from django.conf.urls import url from django.views.generic import TemplateView diff --git a/examples/djopenid/server/views.py b/examples/djopenid/server/views.py index 2db2a415..9b799ee7 100644 --- a/examples/djopenid/server/views.py +++ b/examples/djopenid/server/views.py @@ -14,9 +14,11 @@ * 'openid_response' is an OpenID library response """ +from __future__ import unicode_literals import cgi +import six from django import http from django.shortcuts import render from django.urls import reverse @@ -104,7 +106,7 @@ def endpoint(request): openid_request = s.decodeRequest(query) except ProtocolError as why: # This means the incoming request was invalid. - return render(request, 'server/endpoint.html', {'error': str(why)}) + return render(request, 'server/endpoint.html', {'error': six.text_type(why)}) # If we did not get a request, display text indicating that this # is an endpoint. diff --git a/examples/djopenid/settings.py b/examples/djopenid/settings.py index fc2a2b1e..dad0aa17 100644 --- a/examples/djopenid/settings.py +++ b/examples/djopenid/settings.py @@ -1,4 +1,6 @@ """Example Django settings for djopenid project.""" +from __future__ import unicode_literals + import os import sys import warnings diff --git a/examples/djopenid/urls.py b/examples/djopenid/urls.py index 551fc5e7..5bff67fc 100644 --- a/examples/djopenid/urls.py +++ b/examples/djopenid/urls.py @@ -1,4 +1,6 @@ """Djopenid URLs.""" +from __future__ import unicode_literals + from django.conf.urls import include, url from django.views.generic import TemplateView diff --git a/examples/djopenid/util.py b/examples/djopenid/util.py index f98f6268..2e0a97c5 100644 --- a/examples/djopenid/util.py +++ b/examples/djopenid/util.py @@ -1,6 +1,8 @@ """ Utility code for the Django example consumer and server. """ +from __future__ import unicode_literals + from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.db import connection diff --git a/examples/server.py b/examples/server.py index 5ef52a69..12197495 100644 --- a/examples/server.py +++ b/examples/server.py @@ -2,6 +2,8 @@ __copyright__ = 'Copyright 2005-2008, Janrain, Inc.' +from __future__ import unicode_literals + import cgi import cgitb import Cookie diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index 96c6ef73..a190ffc5 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -186,17 +186,21 @@ L{SetupNeededResponse} objects. """ +from __future__ import unicode_literals import copy import logging from urlparse import parse_qsl, urldefrag, urlparse +import six + from openid import cryptutil, fetchers, oidutil, urinorm from openid.association import Association, SessionNegotiator, default_negotiator from openid.consumer.discover import (OPENID_1_0_TYPE, OPENID_1_1_TYPE, OPENID_2_0_TYPE, DiscoveryFailure, OpenIDServiceEndpoint, discover) from openid.dh import DiffieHellman from openid.message import BARE_NS, IDENTIFIER_SELECT, OPENID1_NS, OPENID2_NS, OPENID_NS, Message, no_default +from openid.oidutil import string_to_text from openid.store.nonce import mkNonce, split as splitNonce from openid.yadis.manager import Discovery @@ -373,7 +377,7 @@ def beginWithoutDiscovery(self, service, anonymous=False): try: auth_req.setAnonymous(anonymous) except ValueError as why: - raise ProtocolError(str(why)) + raise ProtocolError(six.text_type(why)) return auth_req @@ -437,7 +441,7 @@ def setAssociationPreference(self, association_preferences): (association type, association session type) pairs that should be allowed for this consumer to use, in order from most preferred to least preferred. - @type association_preferences: [(str, str)] + @type association_preferences: List[Tuple[six.text_type, six.text_type]], six.binary_type is deprecated @returns: None @@ -1054,7 +1058,7 @@ def _verifyDiscoveredServices(self, claimed_id, services, to_match_endpoints): self._verifyDiscoverySingle( endpoint, to_match_endpoint) except ProtocolError as why: - failure_messages.append(str(why)) + failure_messages.append(six.text_type(why)) else: # It matches, so discover verification has # succeeded. Return this endpoint. @@ -1258,20 +1262,23 @@ def _createAssociateRequest(self, endpoint, assoc_type, session_type): @param assoc_type: The association type that the request should ask for. - @type assoc_type: str + @type assoc_type: six.text_type, six.binary_type is deprecated @param session_type: The session type that should be used in the association request. The session_type is used to create an association session object, and that session object is asked for any additional fields that it needs to add to the request. - @type session_type: str + @type session_type: six.text_type, six.binary_type is deprecated @returns: a pair of the association session object and the request message that will be sent to the server. @rtype: (association session type (depends on session_type), openid.message.Message) """ + assoc_type = string_to_text(assoc_type, "Binary values for assoc_type are deprecated. Use text input instead.") + session_type = string_to_text(session_type, + "Binary values for assoc_type are deprecated. Use text input instead.") session_type_class = self.session_types[session_type] assoc_session = session_type_class() @@ -1303,7 +1310,7 @@ def _getOpenID1SessionType(self, assoc_response): return 'no-encryption' @returns: The association type for this message - @rtype: str + @rtype: six.text_type @raises KeyError: when the session_type field is absent. """ @@ -1472,20 +1479,20 @@ def addExtensionArg(self, namespace, key, value): @param namespace: The namespace for the extension. For example, the simple registration extension uses the namespace C{sreg}. - - @type namespace: str + @type namespace: six.text_type, six.binary_type is deprecated @param key: The key within the extension namespace. For example, the nickname field in the simple registration extension's key is C{nickname}. - - @type key: str + @type key: six.text_type, six.binary_type is deprecated @param value: The value to provide to the server for this argument. - - @type value: str + @type value: six.text_type, six.binary_type is deprecated """ + namespace = string_to_text(namespace, "Binary values for namespace are deprecated. Use text input instead.") + key = string_to_text(key, "Binary values for key are deprecated. Use text input instead.") + value = string_to_text(value, "Binary values for value are deprecated. Use text input instead.") self.message.setArg(namespace, key, value) def getMessage(self, realm, return_to=None, immediate=False): @@ -1493,8 +1500,7 @@ def getMessage(self, realm, return_to=None, immediate=False): @param realm: The URL (or URL pattern) that identifies your web site to the user when she is authorizing it. - - @type realm: str + @type realm: six.text_type, six.binary_type is deprecated @param return_to: The URL that the OpenID provider will send the user back to after attempting to verify her identity. @@ -1502,8 +1508,7 @@ def getMessage(self, realm, return_to=None, immediate=False): Not specifying a return_to URL means that the user will not be returned to the site issuing the request upon its completion. - - @type return_to: str + @type return_to: six.text_type, six.binary_type is deprecated @param immediate: If True, the OpenID provider is to send back a response immediately, useful for behind-the-scenes @@ -1517,7 +1522,9 @@ def getMessage(self, realm, return_to=None, immediate=False): @returntype: L{openid.message.Message} """ + realm = string_to_text(realm, "Binary values for realm are deprecated. Use text input instead.") if return_to: + return_to = string_to_text(return_to, "Binary values for return_to are deprecated. Use text input instead.") return_to = oidutil.appendArgs(return_to, self.return_to_args) elif immediate: raise ValueError( @@ -1580,8 +1587,7 @@ def redirectURL(self, realm, return_to=None, immediate=False): @param realm: The URL (or URL pattern) that identifies your web site to the user when she is authorizing it. - - @type realm: str + @type realm: six.text_type, six.binary_type is deprecated @param return_to: The URL that the OpenID provider will send the user back to after attempting to verify her identity. @@ -1589,8 +1595,7 @@ def redirectURL(self, realm, return_to=None, immediate=False): Not specifying a return_to URL means that the user will not be returned to the site issuing the request upon its completion. - - @type return_to: str + @type return_to: six.text_type, six.binary_type is deprecated @param immediate: If True, the OpenID provider is to send back a response immediately, useful for behind-the-scenes @@ -1604,7 +1609,7 @@ def redirectURL(self, realm, return_to=None, immediate=False): @returns: The URL to redirect the user agent to. - @returntype: str + @returntype: six.text_type """ message = self.getMessage(realm, return_to, immediate) return message.toURL(self.endpoint.server_url) @@ -1627,7 +1632,7 @@ def htmlMarkup(self, realm, return_to=None, immediate=False, form_tag_attrs=None @see: formMarkup - @returns: str + @returns: six.text_type """ return oidutil.autoSubmitHTML(self.formMarkup(realm, return_to, immediate, form_tag_attrs)) @@ -1772,7 +1777,7 @@ def getReturnTo(self): initial request, or C{None} if the response did not contain an C{openid.return_to} argument. - @returntype: str + @returntype: six.text_type """ return self.getSigned(OPENID_NS, 'return_to') diff --git a/openid/consumer/discover.py b/openid/consumer/discover.py index 08af0a63..a3ffa597 100644 --- a/openid/consumer/discover.py +++ b/openid/consumer/discover.py @@ -1,17 +1,5 @@ -# -*- test-case-name: openid.test.test_discover -*- -"""Functions to discover OpenID endpoints from identifiers. -""" - -__all__ = [ - 'DiscoveryFailure', - 'OPENID_1_0_NS', - 'OPENID_1_0_TYPE', - 'OPENID_1_1_TYPE', - 'OPENID_2_0_TYPE', - 'OPENID_IDP_2_0_TYPE', - 'OpenIDServiceEndpoint', - 'discover', -] +"""Functions to discover OpenID endpoints from identifiers.""" +from __future__ import unicode_literals import logging import urlparse @@ -21,11 +9,23 @@ from openid import fetchers, urinorm from openid.message import OPENID1_NS as OPENID_1_0_MESSAGE_NS, OPENID2_NS as OPENID_2_0_MESSAGE_NS +from openid.oidutil import string_to_text from openid.yadis import filters, xri, xrires from openid.yadis.discover import DiscoveryFailure, discover as yadisDiscover from openid.yadis.etxrd import XRD_NS_2_0, XRDSError, nsTag from openid.yadis.services import applyFilter as extractServices +__all__ = [ + 'DiscoveryFailure', + 'OPENID_1_0_NS', + 'OPENID_1_0_TYPE', + 'OPENID_1_1_TYPE', + 'OPENID_2_0_TYPE', + 'OPENID_IDP_2_0_TYPE', + 'OpenIDServiceEndpoint', + 'discover', +] + _LOGGER = logging.getLogger(__name__) OPENID_1_0_NS = 'http://openid.net/xmlns/1.0' @@ -263,7 +263,7 @@ def findOPLocalIdentifier(service_element, type_uris): @param type_uris: The xrd:Type values present in this service element. This function could extract them, but higher level code needs to do that anyway. - @type type_uris: [str] + @type type_uris: List[six.text_type], six.binary_type is deprecated @raises DiscoveryFailure: when discovery fails. @@ -272,6 +272,8 @@ def findOPLocalIdentifier(service_element, type_uris): @rtype: six.text_type or NoneType """ # XXX: Test this function on its own! + type_uris = [string_to_text(u, "Binary values for text_uris are deprecated. Use text input instead.") + for u in type_uris] # Build the list of tags that could contain the OP-Local Identifier local_id_tags = [] @@ -367,13 +369,14 @@ def discoverYadis(uri): on old-style discovery if Yadis fails. @param uri: normalized identity URL - @type uri: str + @type uri: six.text_type, six.binary_type is deprecated @return: (claimed_id, services) - @rtype: (str, list(OpenIDServiceEndpoint)) + @rtype: (six.text_type, list(OpenIDServiceEndpoint)) @raises DiscoveryFailure: when discovery fails. """ + uri = string_to_text(uri, "Binary values for discoverYadis are deprecated. Use text input instead.") # Might raise a yadis.discover.DiscoveryFailure if no document # came back for that URI at all. I don't think falling back # to OpenID 1.0 discovery on the same URL will help, so don't diff --git a/openid/extension.py b/openid/extension.py index 55e129b5..d8c6828e 100644 --- a/openid/extension.py +++ b/openid/extension.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + import warnings from openid import message as message_module diff --git a/openid/extensions/draft/pape5.py b/openid/extensions/draft/pape5.py index 47cf9b20..3a28dc20 100644 --- a/openid/extensions/draft/pape5.py +++ b/openid/extensions/draft/pape5.py @@ -5,6 +5,8 @@ @since: 2.1.0 """ +from __future__ import unicode_literals + import warnings from openid.extensions.pape import (AUTH_MULTI_FACTOR, AUTH_MULTI_FACTOR_PHYSICAL, AUTH_PHISHING_RESISTANT, LEVELS_JISA, diff --git a/openid/fetchers.py b/openid/fetchers.py index 2328fc9b..3bcb9d7d 100644 --- a/openid/fetchers.py +++ b/openid/fetchers.py @@ -1,11 +1,5 @@ -# -*- test-case-name: openid.test.test_fetchers -*- -""" -This module contains the HTTP fetcher interface and several implementations. -""" - -__all__ = ['fetch', 'getDefaultFetcher', 'setDefaultFetcher', 'HTTPResponse', - 'HTTPFetcher', 'createHTTPFetcher', 'HTTPFetchingError', - 'HTTPError'] +"""This module contains the HTTP fetcher interface and several implementations.""" +from __future__ import unicode_literals import cStringIO import sys @@ -15,6 +9,10 @@ import openid import openid.urinorm +__all__ = ['fetch', 'getDefaultFetcher', 'setDefaultFetcher', 'HTTPResponse', + 'HTTPFetcher', 'createHTTPFetcher', 'HTTPFetchingError', + 'HTTPError'] + # Try to import httplib2 for caching support # http://bitworking.org/projects/httplib2/ try: @@ -121,7 +119,10 @@ def usingCurl(): class HTTPResponse(object): - """XXX document attributes""" + """XXX document attributes + + @type body: six.binary_type + """ headers = None status = None body = None @@ -154,7 +155,7 @@ def fetch(self, url, body=None, headers=None): @param headers: HTTP headers to include with the request - @type headers: {str:str} + @type headers: Dict[six.text_type, six.text_type] @return: An object representing the server's HTTP response. If there are network or protocol errors, an exception will be diff --git a/openid/server/server.py b/openid/server/server.py index 6afe8262..701fbdf4 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -115,6 +115,7 @@ @group Response Encodings: ENCODE_KVFORM, ENCODE_HTML_FORM, ENCODE_URL """ +from __future__ import unicode_literals import logging import time @@ -128,6 +129,7 @@ from openid.dh import DiffieHellman from openid.message import (IDENTIFIER_SELECT, OPENID1_URL_LIMIT, OPENID2_NS, OPENID_NS, InvalidNamespace, InvalidOpenIDNamespace, Message) +from openid.oidutil import string_to_text from openid.server.trustroot import TrustRoot, verifyReturnTo from openid.store.nonce import mkNonce from openid.urinorm import urinorm @@ -151,7 +153,7 @@ class OpenIDRequest(object): """I represent an incoming OpenID request. @cvar mode: the C{X{openid.mode}} of this request. - @type mode: str + @type mode: six.text_type @ivar message: Original request message. @type message: Message @@ -177,16 +179,16 @@ class CheckAuthRequest(OpenIDRequest): """A request to verify the validity of a previous response. @cvar mode: "X{C{check_authentication}}" - @type mode: str + @type mode: six.text_type @ivar assoc_handle: The X{association handle} the response was signed with. - @type assoc_handle: str + @type assoc_handle: six.text_type @ivar signed: The message with the signature which wants checking. @type signed: L{Message} @ivar invalidate_handle: An X{association handle} the client is asking about the validity of. Optional, may be C{None}. - @type invalidate_handle: str + @type invalidate_handle: six.text_type @see: U{OpenID Specs, Mode: check_authentication } @@ -201,13 +203,17 @@ def __init__(self, assoc_handle, signed, invalidate_handle=None, message=None): These parameters are assigned directly as class attributes, see my L{class documentation} for their descriptions. - @type assoc_handle: str + @type assoc_handle: six.text_type, six.binary_type is deprecated @type signed: L{Message} - @type invalidate_handle: str + @type invalidate_handle: six.text_type, six.binary_type is deprecated """ super(CheckAuthRequest, self).__init__(message=message) - self.assoc_handle = assoc_handle + self.assoc_handle = string_to_text(assoc_handle, + "Binary values for assoc_handle are deprecated. Use text input instead.") self.signed = signed + if invalidate_handle is not None: + invalidate_handle = string_to_text( + invalidate_handle, "Binary values for invalidate_handle are deprecated. Use text input instead.") self.invalidate_handle = invalidate_handle @classmethod @@ -282,7 +288,7 @@ class PlainTextServerSession(object): @cvar session_type: The session_type for this association session. There is no type defined for plain-text in the OpenID specification, so we use 'no-encryption'. - @type session_type: str + @type session_type: six.text_type @see: U{OpenID Specs, Mode: associate } @@ -305,7 +311,7 @@ class DiffieHellmanSHA1ServerSession(object): @cvar session_type: The session_type for this association session. - @type session_type: str + @type session_type: six.text_type @ivar dh: The Diffie-Hellman algorithm values for this request @type dh: DiffieHellman @@ -388,11 +394,11 @@ class AssociateRequest(OpenIDRequest): """A request to establish an X{association}. @cvar mode: "X{C{check_authentication}}" - @type mode: str + @type mode: six.text_type @ivar assoc_type: The type of association. The protocol currently only defines one value for this, "X{C{HMAC-SHA1}}". - @type assoc_type: str + @type assoc_type: six.text_type @ivar session: An object that knows how to handle association requests of a certain type. @@ -526,32 +532,32 @@ class CheckIDRequest(OpenIDRequest): and X{C{checkid_setup}}. @cvar mode: "X{C{checkid_immediate}}" or "X{C{checkid_setup}}" - @type mode: str + @type mode: six.text_type @ivar immediate: Is this an immediate-mode request? @type immediate: bool @ivar identity: The OP-local identifier being checked. - @type identity: str + @type identity: six.text_type @ivar claimed_id: The claimed identifier. Not present in OpenID 1.x messages. - @type claimed_id: str or None + @type claimed_id: Optional[six.text_type] @ivar trust_root: "Are you Frank?" asks the checkid request. "Who wants to know?" C{trust_root}, that's who. This URL identifies the party making the request, and the user will use that to make her decision about what answer she trusts them to have. Referred to as "realm" in OpenID 2.0. - @type trust_root: str + @type trust_root: six.text_type @ivar return_to: The URL to send the user agent back to to reply to this request. - @type return_to: str + @type return_to: six.text_type @ivar assoc_handle: Provided in smart mode requests, a handle for a previously established association. C{None} for dumb mode requests. - @type assoc_handle: str + @type assoc_handle: six.text_type """ def __init__(self, identity, return_to, trust_root=None, immediate=False, @@ -631,7 +637,7 @@ def fromMessage(klass, message, op_endpoint): @param op_endpoint: The endpoint URL of the server that this message was sent to. - @type op_endpoint: str + @type op_endpoint: Optional[six.text_type], six.binary_type is deprecated @returntype: L{CheckIDRequest} """ @@ -665,6 +671,9 @@ def fromMessage(klass, message, op_endpoint): assoc_handle = message.getArg(OPENID_NS, 'assoc_handle') + if op_endpoint is not None: + op_endpoint = string_to_text(op_endpoint, + "Binary values for op_endpoint are deprecated. Use text input instead.") self = klass(identity, return_to, trust_root=trust_root, immediate=immediate, assoc_handle=assoc_handle, op_endpoint=op_endpoint, claimed_id=claimed_id, message=message) return self @@ -742,12 +751,11 @@ def answer(self, allow, server_url=None, identity=None, claimed_id=None): Optional for requests where C{CheckIDRequest.immediate} is C{False} or C{allow} is C{True}. - - @type server_url: str + @type server_url: Optional[six.text_type], six.binary_type is deprecated @param identity: The OP-local identifier to answer with. Only for use when the relying party requested identifier selection. - @type identity: str or None + @type identity: Optional[six.text_type], six.binary_type is deprecated @param claimed_id: The claimed identifier to answer with, for use with identifier selection in the case where the claimed identifier @@ -760,7 +768,7 @@ def answer(self, allow, server_url=None, identity=None, claimed_id=None): C{claimed_id} will default to that of the request. This parameter is new in OpenID 2.0. - @type claimed_id: str or None + @type claimed_id: Optional[six.text_type], six.binary_type is deprecated @returntype: L{OpenIDResponse} @@ -768,6 +776,12 @@ def answer(self, allow, server_url=None, identity=None, claimed_id=None): @raises NoReturnError: when I do not have a return_to. """ + if identity is not None: + identity = string_to_text(identity, "Binary values for identity are deprecated. Use text input instead.") + if claimed_id is not None: + claimed_id = string_to_text(claimed_id, + "Binary values for claimed_id are deprecated. Use text input instead.") + if not self.return_to: raise NoReturnToError @@ -779,6 +793,9 @@ def answer(self, allow, server_url=None, identity=None, claimed_id=None): "to respond to OpenID 2.0 messages." % (self,)) server_url = self.op_endpoint + else: + server_url = string_to_text(server_url, + "Binary values for server_url are deprecated. Use text input instead.") if allow: mode = 'id_res' @@ -876,9 +893,9 @@ def encodeToURL(self, server_url): """Encode this request as a URL to GET. @param server_url: The URL of the OpenID server to make this request of. - @type server_url: str + @type server_url: six.text_type, six.binary_type is deprecated - @returntype: str + @returntype: six.text_type @raises NoReturnError: when I do not have a return_to. """ @@ -903,6 +920,7 @@ def encodeToURL(self, server_url): response = Message(self.message.getOpenIDNamespace()) response.updateArgs(OPENID_NS, q) + server_url = string_to_text(server_url, "Binary values for server_url are deprecated. Use text input instead.") return response.toURL(server_url) def getCancelURL(self): @@ -915,7 +933,7 @@ def getCancelURL(self): that it knows that the user did make a decision. Or you could simulate this method by doing C{.answer(False).encodeToURL()}) - @returntype: str + @returntype: six.text_type @returns: The return_to URL with openid.mode = cancel. @raises NoReturnError: when I do not have a return_to. @@ -951,7 +969,7 @@ class OpenIDResponse(object): @type fields: L{openid.message.Message} @ivar signed: The names of the fields which should be signed. - @type signed: list of str + @type signed: List[six.text_type] """ # Implementer's note: In a more symmetric client/server @@ -983,7 +1001,7 @@ def toFormMarkup(self, form_tag_attrs=None): that can be overridden. If a value is supplied for 'action' or 'method', it will be replaced. - @returntype: str + @returntype: six.text_type @since: 2.1.0 """ @@ -994,7 +1012,7 @@ def toHTML(self, form_tag_attrs=None): """Returns an HTML document that auto-submits the form markup for this response. - @returntype: str + @returntype: six.text_type @see: toFormMarkup @@ -1044,7 +1062,7 @@ def encodeToURL(self): You will generally use this URL with a HTTP redirect. @returns: A URL to direct the user agent back to. - @returntype: str + @returntype: six.text_type """ return self.fields.toURL(self.request.return_to) @@ -1088,7 +1106,7 @@ class WebResponse(object): @type headers: dict @ivar body: The body of this response. - @type body: str + @type body: six.text_type """ def __init__(self, code=HTTP_OK, headers=None, body=""): @@ -1141,7 +1159,7 @@ def verify(self, assoc_handle, message): @param assoc_handle: The handle of the association used to sign the data. - @type assoc_handle: str + @type assoc_handle: six.text_type, six.binary_type is deprecated @param message: The signed message to verify @type message: openid.message.Message @@ -1149,6 +1167,8 @@ def verify(self, assoc_handle, message): @returns: C{True} if the signature is valid, C{False} if not. @returntype: bool """ + assoc_handle = string_to_text(assoc_handle, + "Binary values for assoc_handle are deprecated. Use text input instead.") assoc = self.getAssociation(assoc_handle, dumb=True) if not assoc: _LOGGER.error("failed to get assoc with handle %r to verify message %r", assoc_handle, message) @@ -1202,7 +1222,7 @@ def sign(self, response): try: signed_response.fields = assoc.signMessage(signed_response.fields) except kvform.KVFormError as err: - raise EncodingError(response, explanation=str(err)) + raise EncodingError(response, explanation=six.text_type(err)) return signed_response def createAssociation(self, dumb=True, assoc_type='HMAC-SHA1'): @@ -1213,11 +1233,13 @@ def createAssociation(self, dumb=True, assoc_type='HMAC-SHA1'): @param assoc_type: The type of association to create. Currently there is only one type defined, C{HMAC-SHA1}. - @type assoc_type: str + @type assoc_type: six.text_type, six.binary_type is deprecated @returns: the new association. @returntype: L{openid.association.Association} """ + assoc_type = string_to_text(assoc_type, "Binary values for assoc_type are deprecated. Use text input instead.") + secret = cryptutil.getBytes(getSecretSize(assoc_type)) uniq = oidutil.toBase64(cryptutil.getBytes(4)) handle = '{%s}{%x}{%s}' % (assoc_type, int(time.time()), uniq) @@ -1235,7 +1257,7 @@ def createAssociation(self, dumb=True, assoc_type='HMAC-SHA1'): def getAssociation(self, assoc_handle, dumb, checkExpiration=True): """Get the association with the specified handle. - @type assoc_handle: str + @type assoc_handle: six.text_type, six.binary_type is deprecated @param dumb: Is this association used with dumb mode? @type dumb: bool @@ -1252,6 +1274,8 @@ def getAssociation(self, assoc_handle, dumb, checkExpiration=True): if assoc_handle is None: raise ValueError("assoc_handle must not be None") + assoc_handle = string_to_text(assoc_handle, + "Binary values for assoc_handle are deprecated. Use text input instead.") if dumb: key = self._dumb_key @@ -1269,7 +1293,7 @@ def getAssociation(self, assoc_handle, dumb, checkExpiration=True): def invalidate(self, assoc_handle, dumb): """Invalidates the association with the given handle. - @type assoc_handle: str + @type assoc_handle: six.text_type, six.binary_type is deprecated @param dumb: Is this association used with dumb mode? @type dumb: bool @@ -1278,6 +1302,8 @@ def invalidate(self, assoc_handle, dumb): key = self._dumb_key else: key = self._normal_key + assoc_handle = string_to_text(assoc_handle, + "Binary values for assoc_handle are deprecated. Use text input instead.") self.store.removeAssociation(key, assoc_handle) @@ -1400,13 +1426,13 @@ def decode(self, query): query = query.copy() query['openid.ns'] = OPENID2_NS message = Message.fromPostArgs(query) - raise ProtocolError(message, str(err)) + raise ProtocolError(message, six.text_type(err)) except InvalidNamespace as err: # If openid.ns is OK, but there is problem with other namespaces # We keep only bare parts of query and we try to make a ProtocolError from it query = [(key, value) for key, value in query.items() if key.count('.') < 2] message = Message.fromPostArgs(dict(query)) - raise ProtocolError(message, str(err)) + raise ProtocolError(message, six.text_type(err)) mode = message.getArg(OPENID_NS, 'mode') if not mode: @@ -1469,7 +1495,7 @@ class Server(object): @type encoder: L{Encoder} @ivar op_endpoint: My URL. - @type op_endpoint: str + @type op_endpoint: six.text_type @ivar negotiator: I use this to determine which kinds of associations I can make and how. @@ -1488,7 +1514,7 @@ def __init__(self, store, op_endpoint=None, signatoryClass=None, encoderClass=No @param op_endpoint: My URL, the fully qualified address of this server's endpoint, i.e. C{http://example.com/server} - @type op_endpoint: str + @type op_endpoint: six.text_type, six.binary_type is deprecated @change: C{op_endpoint} is new in library version 2.0. It currently defaults to C{None} for compatibility with @@ -1521,7 +1547,8 @@ def __init__(self, store, op_endpoint=None, signatoryClass=None, encoderClass=No "for OpenID 2.0 servers" % (self.__class__.__module__, self.__class__.__name__), stacklevel=2) - self.op_endpoint = op_endpoint + self.op_endpoint = string_to_text(op_endpoint, + "Binary values for op_endpoint are deprecated. Use text input instead.") def handleRequest(self, request): """Handle a request. @@ -1621,18 +1648,20 @@ def __init__(self, message, text=None, reference=None, contact=None): @type message: openid.message.Message @param text: A message about the encountered error. Set as C{args[0]}. - @type text: str + @type text: six.text_type, six.binary_type is deprecated """ self.openid_message = message self.reference = reference self.contact = contact assert not isinstance(message, six.string_types) + if text is not None: + text = string_to_text(text, "Binary values for text are deprecated. Use text input instead.") Exception.__init__(self, text) def getReturnTo(self): """Get the return_to argument from the request, if any. - @returntype: str + @returntype: six.text_type """ if self.openid_message is None: return None @@ -1653,13 +1682,13 @@ def toMessage(self): namespace = self.openid_message.getOpenIDNamespace() reply = Message(namespace) reply.setArg(OPENID_NS, 'mode', 'error') - reply.setArg(OPENID_NS, 'error', str(self)) + reply.setArg(OPENID_NS, 'error', six.text_type(self)) if self.contact is not None: - reply.setArg(OPENID_NS, 'contact', str(self.contact)) + reply.setArg(OPENID_NS, 'contact', six.text_type(self.contact)) if self.reference is not None: - reply.setArg(OPENID_NS, 'reference', str(self.reference)) + reply.setArg(OPENID_NS, 'reference', six.text_type(self.reference)) return reply diff --git a/openid/server/trustroot.py b/openid/server/trustroot.py index ed258953..7c1b801a 100644 --- a/openid/server/trustroot.py +++ b/openid/server/trustroot.py @@ -1,4 +1,3 @@ -# -*- test-case-name: openid.test.test_rpverify -*- """ This module contains the C{L{TrustRoot}} class, which helps handle trust root checking. This module is used by the @@ -8,6 +7,17 @@ It also implements relying party return_to URL verification, based on the realm. """ +from __future__ import unicode_literals + +import logging +import re +from urlparse import urlsplit, urlunsplit + +import six + +from openid import urinorm +from openid.oidutil import string_to_text +from openid.yadis import services __all__ = [ 'TrustRoot', @@ -17,12 +27,6 @@ 'verifyReturnTo', ] -import logging -import re -from urlparse import urlsplit, urlunsplit - -from openid import urinorm -from openid.yadis import services _LOGGER = logging.getLogger(__name__) @@ -176,14 +180,14 @@ def validateURL(self, url): @param url: The URL to check - - @type url: C{str} + @type url: six.text_type, six.binary_type is deprecated @return: Whether the given URL is within this trust root. @rtype: C{bool} """ + url = string_to_text(url, "Binary values for validateURL are deprecated. Use text input instead.") url_parts = _parseURL(url) if url_parts is None: @@ -237,8 +241,7 @@ def parse(cls, trust_root): @param trust_root: This is the trust root to parse into a C{L{TrustRoot}} object. - - @type trust_root: C{str} + @type trust_root: six.text_type, six.binary_type is deprecated @return: A C{L{TrustRoot}} instance if trust_root parses as a @@ -246,6 +249,7 @@ def parse(cls, trust_root): @rtype: C{NoneType} or C{L{TrustRoot}} """ + trust_root = string_to_text(trust_root, "Binary values for trust_root are deprecated. Use text input instead.") url_parts = _parseURL(trust_root) if url_parts is None: return None @@ -279,7 +283,7 @@ def parse(cls, trust_root): @classmethod def checkSanity(cls, trust_root_string): - """str -> bool + """six.text_type -> bool, six.binary_type is deprecated is this a sane trust root? """ @@ -302,7 +306,7 @@ def buildDiscoveryURL(self): This function does not check to make sure that the realm is valid. Its behaviour on invalid inputs is undefined. - @rtype: str + @rtype: six.text_type @returns: The URL upon which relying party discovery should be run in order to verify the return_to URL @@ -352,7 +356,7 @@ def _extractReturnURL(endpoint): @returns: The endpoint URL or None if the endpoint is not a relying party endpoint. - @rtype: str or NoneType + @rtype: six.text_type or NoneType """ if endpoint.matchTypes([RP_RETURN_TO_URL_TYPE]): return endpoint.uri @@ -429,7 +433,7 @@ def verifyReturnTo(realm_str, return_to, _vrfy=getAllowedReturnURLs): try: allowable_urls = _vrfy(realm.buildDiscoveryURL()) except RealmVerificationRedirected as err: - _LOGGER.exception(str(err)) + _LOGGER.exception(six.text_type(err)) return False if returnToMatches(allowable_urls, return_to): diff --git a/openid/store/filestore.py b/openid/store/filestore.py index aadf20dc..a7a3e7f5 100644 --- a/openid/store/filestore.py +++ b/openid/store/filestore.py @@ -1,7 +1,5 @@ -""" -This module contains an C{L{OpenIDStore}} implementation backed by -flat files. -""" +"""This module contains an C{L{OpenIDStore}} implementation backed by flat files.""" +from __future__ import unicode_literals import logging import os @@ -13,6 +11,7 @@ from openid import cryptutil, oidutil from openid.association import Association +from openid.oidutil import string_to_text from openid.store import nonce from openid.store.interface import OpenIDStore @@ -23,7 +22,7 @@ def _safe64(s): - h64 = oidutil.toBase64(cryptutil.sha1(s)) + h64 = oidutil.toBase64(cryptutil.sha1(s.encode('utf-8'))) h64 = h64.replace('+', '_') h64 = h64.replace('/', '.') h64 = h64.replace('=', '') @@ -44,7 +43,7 @@ def _removeIfPresent(filename): """Attempt to remove a file, returning whether the file existed at the time of the call. - str -> bool + six.text_type -> bool """ try: os.unlink(filename) @@ -65,7 +64,7 @@ def _ensureDir(dir_name): Can raise OSError - str -> NoneType + six.text_type -> NoneType """ try: os.makedirs(dir_name) @@ -99,8 +98,7 @@ def __init__(self, directory): @param directory: This is the directory to put the store directories in. - - @type directory: C{str} + @type directory: six.text_type, six.binary_type is deprecated """ # Make absolute directory = os.path.normpath(os.path.abspath(directory)) @@ -136,7 +134,7 @@ def _mktemp(self): the store, it is safe to remove all of the files in the temporary directory. - () -> (file, str) + () -> (file, six.text_type) """ fd, name = mkstemp(dir=self.temp_dir) try: @@ -155,8 +153,11 @@ def getAssociationFilename(self, server_url, handle): contain the domain name from the server URL for ease of human inspection of the data directory. - (str, str) -> str + (six.text_type, six.text_type) -> six.text_type, six.binary_type is deprecated """ + server_url = string_to_text(server_url, "Binary values for server_url are deprecated. Use text input instead.") + handle = string_to_text(handle, "Binary values for handle are deprecated. Use text input instead.") + if server_url.find('://') == -1: raise ValueError('Bad server URL: %r' % server_url) @@ -175,15 +176,17 @@ def getAssociationFilename(self, server_url, handle): def storeAssociation(self, server_url, association): """Store an association in the association directory. - (str, Association) -> NoneType + (six.text_type, Association) -> NoneType, six.binary_type is deprecated """ + server_url = string_to_text(server_url, "Binary values for server_url are deprecated. Use text input instead.") + association_s = association.serialize() filename = self.getAssociationFilename(server_url, association.handle) tmp_file, tmp = self._mktemp() try: try: - tmp_file.write(association_s) + tmp_file.write(association_s.encode('utf-8')) os.fsync(tmp_file.fileno()) finally: tmp_file.close() @@ -218,8 +221,12 @@ def getAssociation(self, server_url, handle=None): """Retrieve an association. If no handle is specified, return the association with the latest expiration. - (str, str or NoneType) -> Association or NoneType + (six.text_type, Optional[six.text_type]) -> Association or NoneType, six.binary_type is deprecated """ + server_url = string_to_text(server_url, "Binary values for server_url are deprecated. Use text input instead.") + if handle is not None: + handle = string_to_text(handle, "Binary values for handle are deprecated. Use text input instead.") + if handle is None: handle = '' @@ -287,8 +294,11 @@ def _getAssociation(self, filename): def removeAssociation(self, server_url, handle): """Remove an association if it exists. Do nothing if it does not. - (str, str) -> bool + (six.text_type, six.text_type) -> bool, six.binary_type is deprecated """ + server_url = string_to_text(server_url, "Binary values for server_url are deprecated. Use text input instead.") + handle = string_to_text(handle, "Binary values for handle are deprecated. Use text input instead.") + assoc = self.getAssociation(server_url, handle) if assoc is None: return 0 @@ -299,8 +309,11 @@ def removeAssociation(self, server_url, handle): def useNonce(self, server_url, timestamp, salt): """Return whether this nonce is valid. - str -> bool + @type server_url: six.text_type, six.binary_type is deprecated + @rtype: bool """ + server_url = string_to_text(server_url, "Binary values for server_url are deprecated. Use text input instead.") + if abs(timestamp - time.time()) > nonce.SKEW: return False diff --git a/openid/store/interface.py b/openid/store/interface.py index 63776572..88fc0e95 100644 --- a/openid/store/interface.py +++ b/openid/store/interface.py @@ -1,7 +1,5 @@ -""" -This module contains the definition of the C{L{OpenIDStore}} -interface. -""" +"""This module contains the definition of the C{L{OpenIDStore}} interface.""" +from __future__ import unicode_literals class OpenIDStore(object): @@ -33,8 +31,7 @@ def storeAssociation(self, server_url, association): there are any limitations on the character set of the input string. In particular, expect to see unescaped non-url-safe characters in the server_url field. - - @type server_url: C{str} + @type server_url: six.text_type @param association: The C{L{Association @@ -74,16 +71,13 @@ def getAssociation(self, server_url, handle=None): any limitations on the character set of the input string. In particular, expect to see unescaped non-url-safe characters in the server_url field. - - @type server_url: C{str} - + @type server_url: six.text_type @param handle: This optional parameter is the handle of the specific association to get. If no specific handle is provided, any valid association matching the server URL is returned. - - @type handle: C{str} or C{NoneType} + @type handle: Optional[six.text_type] @return: The C{L{Association @@ -107,16 +101,13 @@ def removeAssociation(self, server_url, handle): assume there are any limitations on the character set of the input string. In particular, expect to see unescaped non-url-safe characters in the server_url field. - - @type server_url: C{str} - + @type server_url: six.text_type @param handle: This is the handle of the association to remove. If there isn't an association found that matches both the given URL and handle, then there was no matching handle found. - - @type handle: C{str} + @type handle: six.text_type @return: Returns whether or not the given association existed. @@ -144,8 +135,7 @@ def useNonce(self, server_url, timestamp, salt): @param server_url: The URL of the server from which the nonce originated. - - @type server_url: C{str} + @type server_url: six.text_type @param timestamp: The time that the nonce was created (to the nearest second), in seconds since January 1 1970 UTC. @@ -153,7 +143,7 @@ def useNonce(self, server_url, timestamp, salt): @param salt: A random string that makes two nonces from the same server issued during the same second unique. - @type salt: str + @type salt: six.text_type @return: Whether or not the nonce was valid. diff --git a/openid/store/memstore.py b/openid/store/memstore.py index 366a596e..d2a74f41 100644 --- a/openid/store/memstore.py +++ b/openid/store/memstore.py @@ -1,8 +1,11 @@ """A simple store using only in-process memory.""" +from __future__ import unicode_literals import copy import time +import six + from openid.store import nonce @@ -85,7 +88,7 @@ def useNonce(self, server_url, timestamp, salt): if abs(timestamp - time.time()) > nonce.SKEW: return False - anonce = (str(server_url), int(timestamp), str(salt)) + anonce = (six.text_type(server_url), int(timestamp), six.text_type(salt)) if anonce in self.nonces: return False else: diff --git a/openid/store/nonce.py b/openid/store/nonce.py index 60b3a891..f00f4e1c 100644 --- a/openid/store/nonce.py +++ b/openid/store/nonce.py @@ -1,14 +1,18 @@ -__all__ = [ - 'split', - 'mkNonce', - 'checkTimestamp', -] +from __future__ import unicode_literals import string from calendar import timegm from time import gmtime, strftime, strptime, time from openid import cryptutil +from openid.oidutil import string_to_text + +__all__ = [ + 'split', + 'mkNonce', + 'checkTimestamp', +] + NONCE_CHARS = (string.ascii_letters + string.digits).encode('utf-8') @@ -25,14 +29,17 @@ def split(nonce_string): """Extract a timestamp from the given nonce string @param nonce_string: the nonce from which to extract the timestamp - @type nonce_string: str + @type nonce_string: six.text_type, six.binary_type is deprecated @returns: A pair of a Unix timestamp and the salt characters - @returntype: (int, str) + @returntype: (int, six.text_type) @raises ValueError: if the nonce does not start with a correctly formatted time string """ + nonce_string = string_to_text(nonce_string, + "Binary values for nonce_string are deprecated. Use text input instead.") + timestamp_str = nonce_string[:time_str_len] timestamp = timegm(strptime(timestamp_str, time_fmt)) if timestamp < 0: @@ -45,7 +52,7 @@ def checkTimestamp(nonce_string, allowed_skew=SKEW, now=None): within the allowed clock-skew of the current time? @param nonce_string: The nonce that is being checked - @type nonce_string: str + @type nonce_string: six.text_type, six.binary_type is deprecated @param allowed_skew: How many seconds should be allowed for completing the request, allowing for clock skew. @@ -84,7 +91,7 @@ def mkNonce(when=None): nonce. Defaults to the current time. @type when: int - @returntype: str + @returntype: six.text_type @returns: A string that should be usable as a one-way nonce @see: time diff --git a/openid/store/sqlstore.py b/openid/store/sqlstore.py index 339931e2..17912982 100644 --- a/openid/store/sqlstore.py +++ b/openid/store/sqlstore.py @@ -7,10 +7,15 @@ python -c 'from openid.store import sqlstore; import pysqlite2.dbapi2; sqlstore.SQLiteStore(pysqlite2.dbapi2.connect("cstore.db")).createTables()' """ +from __future__ import unicode_literals + import re import time +import six + from openid.association import Association +from openid.oidutil import string_to_text from openid.store import nonce from openid.store.interface import OpenIDStore @@ -84,19 +89,20 @@ def __init__(self, conn, associations_table=None, nonces_table=None): specify the name of the table used for storing associations. The default value is specified in C{L{SQLStore.associations_table}}. - - @type associations_table: C{str} - + @type associations_table: six.text_type, six.binary_type is deprecated @param nonces_table: This is an optional parameter to specify the name of the table used for storing nonces. The default value is specified in C{L{SQLStore.nonces_table}}. - - @type nonces_table: C{str} + @type nonces_table: six.text_type, six.binary_type is deprecated """ self.conn = conn self.cur = None self._statement_cache = {} + associations_table = string_to_text( + associations_table, "Binary values for associations_table are deprecated. Use text input instead.") + nonces_table = string_to_text(nonces_table, + "Binary values for nonces_table are deprecated. Use text input instead.") self._table_names = { 'associations': associations_table or self.associations_table, 'nonces': nonces_table or self.nonces_table, @@ -115,13 +121,13 @@ def __init__(self, conn, associations_table=None, nonces_table=None): "(Maybe it can't be imported?)") def blobDecode(self, blob): - """Convert a blob as returned by the SQL engine into a str object. + """Convert a blob as returned by the SQL engine into a binary_type object. - str -> str""" + six.binary_type -> six.binary_type""" return blob def blobEncode(self, s): - """Convert a str object into the necessary object for storing + """Convert a six.binary_type object into the necessary object for storing in the database as a blob.""" return s @@ -142,8 +148,8 @@ def _execSQL(self, sql_name, *args): # so this ought to be safe. def unicode_to_str(arg): - if isinstance(arg, unicode): - return str(arg) + if isinstance(arg, six.text_type): + return arg.encode('utf-8') else: return arg str_args = [unicode_to_str(i) for i in args] @@ -216,8 +222,11 @@ def txn_getAssociation(self, server_url, handle=None): """Get the most recent association that has been set for this server URL and handle. - str -> NoneType or Association + @type server_url: six.text_type, six.binary_type is deprecated + @rtype: Optional[Association] """ + server_url = string_to_text(server_url, "Binary values for server_url are deprecated. Use text input instead.") + if handle is not None: self.db_get_assoc(server_url, handle) else: @@ -229,8 +238,10 @@ def txn_getAssociation(self, server_url, handle=None): else: associations = [] for values in rows: - assoc = Association(*values) - assoc.secret = self.blobDecode(assoc.secret) + # Decode secret before association is created + handle, secret, issued, lifetime, assoc_type = values + secret = self.blobDecode(secret) + assoc = Association(handle, secret, issued, lifetime, assoc_type) if assoc.getExpiresIn() == 0: self.txn_removeAssociation(server_url, assoc.handle) else: @@ -248,8 +259,11 @@ def txn_removeAssociation(self, server_url, handle): """Remove the association for the given server URL and handle, returning whether the association existed at all. - (str, str) -> bool + (six.text_type, six.text_type) -> bool, six.binary_type is deprecated """ + server_url = string_to_text(server_url, "Binary values for server_url are deprecated. Use text input instead.") + handle = string_to_text(handle, "Binary values for handle are deprecated. Use text input instead.") + self.db_remove_assoc(server_url, handle) return self.cur.rowcount > 0 # -1 is undefined @@ -259,7 +273,11 @@ def txn_useNonce(self, server_url, timestamp, salt): """Return whether this nonce is present, and if it is, then remove it from the set. - str -> bool""" + @type server_url: six.text_type, six.binary_type is deprecated + @rtype: bool + """ + server_url = string_to_text(server_url, "Binary values for server_url are deprecated. Use text input instead.") + if abs(timestamp - time.time()) > nonce.SKEW: return False @@ -342,7 +360,7 @@ class SQLiteStore(SQLStore): clean_nonce_sql = 'DELETE FROM %(nonces)s WHERE timestamp < ?;' def blobDecode(self, buf): - return str(buf) + return six.binary_type(buf) def blobEncode(self, s): return buffer(s) @@ -421,7 +439,7 @@ class MySQLStore(SQLStore): clean_nonce_sql = 'DELETE FROM %(nonces)s WHERE timestamp < %%s;' def blobDecode(self, blob): - if isinstance(blob, str): + if isinstance(blob, six.binary_type): # Versions of MySQLdb >= 1.2.2 return blob else: diff --git a/openid/test/discoverdata.py b/openid/test/discoverdata.py index 32d9619c..04990c5b 100644 --- a/openid/test/discoverdata.py +++ b/openid/test/discoverdata.py @@ -1,4 +1,6 @@ """Module to make discovery data test cases available""" +from __future__ import unicode_literals + import os.path import urlparse diff --git a/openid/test/test_association_response.py b/openid/test/test_association_response.py index 2df50f4f..0e1fc5e7 100644 --- a/openid/test/test_association_response.py +++ b/openid/test/test_association_response.py @@ -3,6 +3,8 @@ This duplicates some things that are covered by test_consumer, but this works for now. """ +from __future__ import unicode_literals + import unittest from testfixtures import LogCapture @@ -231,7 +233,7 @@ def test_explicitNoEncryption(self): class DummyAssociationSession(object): - secret = "shh! don't tell!" + secret = b"shh! don't tell!" extract_secret_called = False session_type = None diff --git a/openid/test/test_auth_request.py b/openid/test/test_auth_request.py index cc969878..7c21a789 100644 --- a/openid/test/test_auth_request.py +++ b/openid/test/test_auth_request.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + import unittest from openid import message diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index 8e63b684..0cd6a62a 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + import time import unittest import urlparse @@ -535,8 +537,7 @@ def setUp(self): def _createAssoc(self): issued = time.time() lifetime = 1000 - assoc = association.Association( - 'handle', 'secret', issued, lifetime, 'HMAC-SHA1') + assoc = association.Association('handle', b'secret', issued, lifetime, 'HMAC-SHA1') store = self.consumer.store store.storeAssociation(self.server_url, assoc) assoc2 = store.getAssociation(self.server_url) @@ -918,8 +919,7 @@ def test_checkAuthTriggeredWithAssoc(self): # handle that is in the message issued = time.time() lifetime = 1000 - assoc = association.Association( - 'handle', 'secret', issued, lifetime, 'HMAC-SHA1') + assoc = association.Association('handle', b'secret', issued, lifetime, 'HMAC-SHA1') self.store.storeAssociation(self.server_url, assoc) self.disableReturnToChecking() message = Message.fromPostArgs({ @@ -942,8 +942,7 @@ def test_expiredAssoc(self): issued = time.time() - 10 lifetime = 0 handle = 'handle' - assoc = association.Association( - handle, 'secret', issued, lifetime, 'HMAC-SHA1') + assoc = association.Association(handle, b'secret', issued, lifetime, 'HMAC-SHA1') self.assertLessEqual(assoc.expiresIn, 0) self.store.storeAssociation(self.server_url, assoc) @@ -962,14 +961,12 @@ def test_newerAssoc(self): good_issued = time.time() - 10 good_handle = 'handle' - good_assoc = association.Association( - good_handle, 'secret', good_issued, lifetime, 'HMAC-SHA1') + good_assoc = association.Association(good_handle, b'secret', good_issued, lifetime, 'HMAC-SHA1') self.store.storeAssociation(self.server_url, good_assoc) bad_issued = time.time() - 5 bad_handle = 'handle2' - bad_assoc = association.Association( - bad_handle, 'secret', bad_issued, lifetime, 'HMAC-SHA1') + bad_assoc = association.Association(bad_handle, b'secret', bad_issued, lifetime, 'HMAC-SHA1') self.store.storeAssociation(self.server_url, bad_assoc) query = { diff --git a/openid/test/test_discover.py b/openid/test/test_discover.py index f634b121..414f7ed7 100644 --- a/openid/test/test_discover.py +++ b/openid/test/test_discover.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +from __future__ import unicode_literals + import os.path import unittest from urlparse import urlsplit diff --git a/openid/test/test_extension.py b/openid/test/test_extension.py index 640f11a6..f851a0f9 100644 --- a/openid/test/test_extension.py +++ b/openid/test/test_extension.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + import unittest from openid import extension, message diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index 57b90896..5594a522 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + import socket import unittest import urllib2 @@ -74,11 +76,11 @@ def geturl(path): def plain(path, code): path = '/' + path expected = fetchers.HTTPResponse( - geturl(path), code, expected_headers, path) + geturl(path), code, expected_headers, path.encode('utf-8')) return (path, expected) expect_success = fetchers.HTTPResponse( - geturl('/success'), 200, expected_headers, '/success') + geturl('/success'), 200, expected_headers, b'/success') cases = [ ('/success', expect_success), ('/301redirect', expect_success), @@ -222,7 +224,7 @@ def _respond(self, http_code, extra_headers, body): for k, v in extra_headers: self.send_header(k, v) self.end_headers() - self.wfile.write(body) + self.wfile.write(body.encode('utf-8')) self.wfile.close() def finish(self): @@ -384,19 +386,19 @@ class TestRequestsFetcher(unittest.TestCase): def test_get(self): # Test GET response with responses.RequestsMock() as rsps: - rsps.add(responses.GET, 'http://example.cz/', status=200, body='BODY', + rsps.add(responses.GET, 'http://example.cz/', status=200, body=b'BODY', headers={'Content-Type': 'text/plain'}) response = self.fetcher.fetch('http://example.cz/') - expected = fetchers.HTTPResponse('http://example.cz/', 200, {'Content-Type': 'text/plain'}, 'BODY') + expected = fetchers.HTTPResponse('http://example.cz/', 200, {'Content-Type': 'text/plain'}, b'BODY') assertResponse(expected, response) def test_post(self): # Test POST response with responses.RequestsMock() as rsps: - rsps.add(responses.POST, 'http://example.cz/', status=200, body='BODY', + rsps.add(responses.POST, 'http://example.cz/', status=200, body=b'BODY', headers={'Content-Type': 'text/plain'}) response = self.fetcher.fetch('http://example.cz/', body='key=value') - expected = fetchers.HTTPResponse('http://example.cz/', 200, {'Content-Type': 'text/plain'}, 'BODY') + expected = fetchers.HTTPResponse('http://example.cz/', 200, {'Content-Type': 'text/plain'}, b'BODY') assertResponse(expected, response) def test_redirect(self): @@ -404,19 +406,19 @@ def test_redirect(self): with responses.RequestsMock() as rsps: rsps.add(responses.GET, 'http://example.cz/redirect/', status=302, headers={'Location': 'http://example.cz/target/'}) - rsps.add(responses.GET, 'http://example.cz/target/', status=200, body='BODY', + rsps.add(responses.GET, 'http://example.cz/target/', status=200, body=b'BODY', headers={'Content-Type': 'text/plain'}) response = self.fetcher.fetch('http://example.cz/redirect/') - expected = fetchers.HTTPResponse('http://example.cz/target/', 200, {'Content-Type': 'text/plain'}, 'BODY') + expected = fetchers.HTTPResponse('http://example.cz/target/', 200, {'Content-Type': 'text/plain'}, b'BODY') assertResponse(expected, response) def test_error(self): # Test error responses - returned as obtained with responses.RequestsMock() as rsps: - rsps.add(responses.GET, 'http://example.cz/error/', status=500, body='BODY', + rsps.add(responses.GET, 'http://example.cz/error/', status=500, body=b'BODY', headers={'Content-Type': 'text/plain'}) response = self.fetcher.fetch('http://example.cz/error/') - expected = fetchers.HTTPResponse('http://example.cz/error/', 500, {'Content-Type': 'text/plain'}, 'BODY') + expected = fetchers.HTTPResponse('http://example.cz/error/', 500, {'Content-Type': 'text/plain'}, b'BODY') assertResponse(expected, response) def test_invalid_url(self): diff --git a/openid/test/test_htmldiscover.py b/openid/test/test_htmldiscover.py index b4caeb30..cb354a6c 100644 --- a/openid/test/test_htmldiscover.py +++ b/openid/test/test_htmldiscover.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + import unittest from openid.consumer.discover import OpenIDServiceEndpoint diff --git a/openid/test/test_negotiation.py b/openid/test/test_negotiation.py index 6c4cf1fe..9b23233e 100644 --- a/openid/test/test_negotiation.py +++ b/openid/test/test_negotiation.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + import unittest from testfixtures import LogCapture, StringComparison @@ -124,8 +126,7 @@ def testUnsupportedWithRetry(self): msg.setArg(OPENID_NS, 'assoc_type', 'HMAC-SHA1') msg.setArg(OPENID_NS, 'session_type', 'DH-SHA1') - assoc = association.Association( - 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') + assoc = association.Association('handle', b'secret', 'issued', 10000, 'HMAC-SHA1') self.consumer.return_messages = [msg, assoc] with LogCapture() as logbook: @@ -157,8 +158,7 @@ def testValid(self): Test the valid case, wherein an association is returned on the first attempt to get one. """ - assoc = association.Association( - 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') + assoc = association.Association('handle', b'secret', 'issued', 10000, 'HMAC-SHA1') self.consumer.return_messages = [assoc] with LogCapture() as logbook: @@ -241,8 +241,7 @@ def testUnsupportedWithRetry(self): msg.setArg(OPENID_NS, 'assoc_type', 'HMAC-SHA1') msg.setArg(OPENID_NS, 'session_type', 'DH-SHA1') - assoc = association.Association( - 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') + assoc = association.Association('handle', b'secret', 'issued', 10000, 'HMAC-SHA1') self.consumer.return_messages = [msg, assoc] with LogCapture() as logbook: @@ -251,8 +250,7 @@ def testUnsupportedWithRetry(self): ('openid.consumer.consumer', 'ERROR', StringComparison('Server error when requesting an association .*'))) def testValid(self): - assoc = association.Association( - 'handle', 'secret', 'issued', 10000, 'HMAC-SHA1') + assoc = association.Association('handle', b'secret', 'issued', 10000, 'HMAC-SHA1') self.consumer.return_messages = [assoc] with LogCapture() as logbook: diff --git a/openid/test/test_nonce.py b/openid/test/test_nonce.py index 010178fa..d2746115 100644 --- a/openid/test/test_nonce.py +++ b/openid/test/test_nonce.py @@ -1,3 +1,5 @@ +from __future__ import unicode_literals + import re import unittest diff --git a/openid/test/test_rpverify.py b/openid/test/test_rpverify.py index cbcb6dfd..4e707d59 100644 --- a/openid/test/test_rpverify.py +++ b/openid/test/test_rpverify.py @@ -1,7 +1,5 @@ -"""Unit tests for verification of return_to URLs for a realm -""" - -__all__ = ['TestBuildDiscoveryURL'] +"""Unit tests for verification of return_to URLs for a realm.""" +from __future__ import unicode_literals import unittest @@ -11,6 +9,8 @@ from openid.yadis import services from openid.yadis.discover import DiscoveryFailure, DiscoveryResult +__all__ = ['TestBuildDiscoveryURL'] + class TestBuildDiscoveryURL(unittest.TestCase): """Tests for building the discovery URL from a realm and a @@ -75,7 +75,7 @@ def test_badXML(self): self.assertDiscoveryFailure('>') def test_noEntries(self): - self.assertReturnURLs('''\ + self.assertReturnURLs(b'''\ bytes --[utf-8]--> str + hostname = hostname.encode('idna').decode('utf-8') except ValueError as error: raise ValueError('Invalid hostname {!r}: {}'.format(hostname, error)) _check_disallowed_characters(hostname, 'hostname') @@ -103,7 +106,7 @@ def urinorm(uri): netloc = hostname if port: - netloc = netloc + ':' + str(port) + netloc = netloc + ':' + six.text_type(port) userinfo_chunks = [i for i in (split_uri.username, split_uri.password) if i is not None] if userinfo_chunks: userinfo = ':'.join(userinfo_chunks) diff --git a/setup.py b/setup.py index 6ac4e0e1..a333c44f 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +from __future__ import unicode_literals + import os import sys From 38e747ee7071a9af23edf0378028878c699fe512 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 2 May 2018 11:52:30 +0200 Subject: [PATCH 075/151] Use six for URL utilities --- admin/builddiscover.py | 9 +++++---- contrib/openid-parse | 2 +- examples/consumer.py | 9 +++++---- examples/djopenid/server/tests.py | 3 +-- examples/server.py | 2 +- openid/consumer/consumer.py | 2 +- openid/consumer/discover.py | 8 ++++---- openid/message.py | 4 ++-- openid/oidutil.py | 2 +- openid/server/trustroot.py | 3 +-- openid/test/discoverdata.py | 9 +++++---- openid/test/test_consumer.py | 8 ++++---- openid/test/test_discover.py | 3 ++- openid/test/test_message.py | 6 ++---- openid/test/test_server.py | 2 +- openid/test/test_yadis_discover.py | 5 +++-- openid/urinorm.py | 3 +-- openid/yadis/xri.py | 2 +- openid/yadis/xrires.py | 2 +- 19 files changed, 42 insertions(+), 42 deletions(-) diff --git a/admin/builddiscover.py b/admin/builddiscover.py index 9b5da070..4b572734 100755 --- a/admin/builddiscover.py +++ b/admin/builddiscover.py @@ -2,7 +2,8 @@ from __future__ import unicode_literals import os.path -import urlparse + +from six.moves.urllib.parse import urljoin from openid.test import discoverdata @@ -52,9 +53,9 @@ def writeTestFile(test_name): continue writeTestFile(input_name) - input_url = urlparse.urljoin(base_url, input_name) - id_url = urlparse.urljoin(base_url, id_name) - result_url = urlparse.urljoin(base_url, result_name) + input_url = urljoin(base_url, input_name) + id_url = urljoin(base_url, id_name) + result_url = urljoin(base_url, result_name) manifest.append('\t'.join((input_url, id_url, result_url))) manifest.append('\n') diff --git a/contrib/openid-parse b/contrib/openid-parse index b6c59ea9..c4e437a6 100644 --- a/contrib/openid-parse +++ b/contrib/openid-parse @@ -12,9 +12,9 @@ import re import subprocess import sys from pprint import pformat -from urlparse import parse_qs, urlsplit, urlunsplit import six +from six.moves.urllib.parse import parse_qs, urlsplit, urlunsplit from openid import message diff --git a/examples/consumer.py b/examples/consumer.py index dbce4117..5898ee46 100644 --- a/examples/consumer.py +++ b/examples/consumer.py @@ -14,11 +14,12 @@ import cgitb import optparse import sys -import urlparse +from six.moves.urllib.parse import urlparse, parse_qsl, urljoin from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer from Cookie import SimpleCookie import six +from six.moves.urllib.parse import parse_qsl, urljoin, urlparse def quoteattr(s): @@ -135,9 +136,9 @@ def do_GET(self): written to the requesting browser. """ try: - self.parsed_uri = urlparse.urlparse(self.path) + self.parsed_uri = urlparse(self.path) self.query = {} - for k, v in urlparse.parse_qsl(self.parsed_uri[4]): + for k, v in parse_qsl(self.parsed_uri[4]): self.query[k] = v.decode('utf-8') path = self.parsed_uri[2] @@ -343,7 +344,7 @@ def renderPAPE(self, pape_data): def buildURL(self, action, **query): """Build a URL relative to the server base_url, with the given query parameters added.""" - base = urlparse.urljoin(self.server.base_url, action) + base = urljoin(self.server.base_url, action) return appendArgs(base, query) def notFound(self): diff --git a/examples/djopenid/server/tests.py b/examples/djopenid/server/tests.py index bdc08849..02f97ab9 100644 --- a/examples/djopenid/server/tests.py +++ b/examples/djopenid/server/tests.py @@ -1,7 +1,5 @@ from __future__ import unicode_literals -from urlparse import urljoin - import django from django.http import HttpRequest from django.test.testcases import TestCase @@ -10,6 +8,7 @@ from openid.server.server import HTTP_REDIRECT, CheckIDRequest from openid.yadis.constants import YADIS_CONTENT_TYPE from openid.yadis.services import applyFilter +from six.moves.urllib.parse import urljoin from .. import util from ..server import views diff --git a/examples/server.py b/examples/server.py index 12197495..f9174d01 100644 --- a/examples/server.py +++ b/examples/server.py @@ -11,7 +11,7 @@ import sys import time from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer -from urlparse import parse_qsl, urlparse +from six.moves.urllib.parse import parse_qsl, urlparse def quoteattr(s): diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index a190ffc5..2d0b1c8c 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -190,9 +190,9 @@ import copy import logging -from urlparse import parse_qsl, urldefrag, urlparse import six +from six.moves.urllib.parse import parse_qsl, urldefrag, urlparse from openid import cryptutil, fetchers, oidutil, urinorm from openid.association import Association, SessionNegotiator, default_negotiator diff --git a/openid/consumer/discover.py b/openid/consumer/discover.py index a3ffa597..c3ddb66f 100644 --- a/openid/consumer/discover.py +++ b/openid/consumer/discover.py @@ -2,10 +2,10 @@ from __future__ import unicode_literals import logging -import urlparse from lxml.etree import LxmlError from lxml.html import document_fromstring +from six.moves.urllib.parse import urldefrag, urlparse from openid import fetchers, urinorm from openid.message import OPENID1_NS as OPENID_1_0_MESSAGE_NS, OPENID2_NS as OPENID_2_0_MESSAGE_NS @@ -87,7 +87,7 @@ def getDisplayIdentifier(self): if self.claimed_id is None: return None else: - return urlparse.urldefrag(self.claimed_id)[0] + return urldefrag(self.claimed_id)[0] def compatibilityMode(self): return self.preferredNamespace() != OPENID_2_0_MESSAGE_NS @@ -306,7 +306,7 @@ def normalizeURL(url): except ValueError as why: raise DiscoveryFailure('Normalizing identifier: %s' % (why[0],), None) else: - return urlparse.urldefrag(normalized)[0] + return urldefrag(normalized)[0] def normalizeXRI(xri): @@ -448,7 +448,7 @@ def discoverNoYadis(uri): def discoverURI(uri): - parsed = urlparse.urlparse(uri) + parsed = urlparse(uri) if parsed[0] and parsed[1]: if parsed[0] not in ['http', 'https']: raise DiscoveryFailure('URI scheme is not HTTP or HTTPS', None) diff --git a/openid/message.py b/openid/message.py index 3e65b185..a55b2291 100644 --- a/openid/message.py +++ b/openid/message.py @@ -3,11 +3,11 @@ from __future__ import unicode_literals import copy -import urllib import warnings import six from lxml import etree as ElementTree +from six.moves.urllib.parse import urlencode from openid import kvform, oidutil @@ -394,7 +394,7 @@ def toKVForm(self): def toURLEncoded(self): """Generate an x-www-urlencoded string""" args = sorted(self.toPostArgs().items()) - return urllib.urlencode(args) + return urlencode(args) def _fixNS(self, namespace): """Convert an input value into the internally used values of diff --git a/openid/oidutil.py b/openid/oidutil.py index b164165b..b5ddf31d 100644 --- a/openid/oidutil.py +++ b/openid/oidutil.py @@ -9,9 +9,9 @@ import binascii import logging import warnings -from urllib import urlencode import six +from six.moves.urllib.parse import urlencode __all__ = ['log', 'appendArgs', 'toBase64', 'fromBase64', 'autoSubmitHTML'] diff --git a/openid/server/trustroot.py b/openid/server/trustroot.py index 7c1b801a..a555d06d 100644 --- a/openid/server/trustroot.py +++ b/openid/server/trustroot.py @@ -11,9 +11,9 @@ import logging import re -from urlparse import urlsplit, urlunsplit import six +from six.moves.urllib.parse import urlsplit, urlunsplit from openid import urinorm from openid.oidutil import string_to_text @@ -27,7 +27,6 @@ 'verifyReturnTo', ] - _LOGGER = logging.getLogger(__name__) ############################################ diff --git a/openid/test/discoverdata.py b/openid/test/discoverdata.py index 04990c5b..379f4500 100644 --- a/openid/test/discoverdata.py +++ b/openid/test/discoverdata.py @@ -2,7 +2,8 @@ from __future__ import unicode_literals import os.path -import urlparse + +from six.moves.urllib.parse import urljoin from openid.yadis.constants import YADIS_HEADER_NAME from openid.yadis.discover import DiscoveryFailure, DiscoveryResult @@ -105,7 +106,7 @@ def generateSample(test_name, base_url, def generateResult(base_url, input_name, id_name, result_name, success): - input_url = urlparse.urljoin(base_url, input_name) + input_url = urljoin(base_url, input_name) # If the name is None then we expect the protocol to fail, which # we represent by None @@ -124,12 +125,12 @@ def generateResult(base_url, input_name, id_name, result_name, success): else: ctype = None - id_url = urlparse.urljoin(base_url, id_name) + id_url = urljoin(base_url, id_name) result = DiscoveryResult(input_url) result.normalized_uri = id_url if success: - result.xrds_uri = urlparse.urljoin(base_url, result_name) + result.xrds_uri = urljoin(base_url, result_name) result.content_type = ctype result.response_text = content return input_url, result diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index 0cd6a62a..268a641f 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -2,8 +2,8 @@ import time import unittest -import urlparse +from six.moves.urllib.parse import parse_qsl, urlparse from testfixtures import LogCapture, StringComparison from openid import association, cryptutil, fetchers, kvform, oidutil @@ -39,7 +39,7 @@ def mkSuccess(endpoint, q): def parseQuery(qs): q = {} - for (k, v) in urlparse.parse_qsl(qs): + for (k, v) in parse_qsl(qs): assert k not in q q[k] = v return q @@ -159,7 +159,7 @@ def run(): redirect_url = request.redirectURL(trust_root, return_to, immediate) - parsed = urlparse.urlparse(redirect_url) + parsed = urlparse(redirect_url) qs = parsed[4] q = parseQuery(qs) new_return_to = q['openid.return_to'] @@ -174,7 +174,7 @@ def run(): assert new_return_to.startswith(return_to) assert redirect_url.startswith(server_url) - parsed = urlparse.urlparse(new_return_to) + parsed = urlparse(new_return_to) query = parseQuery(parsed[4]) query.update({ 'openid.mode': 'id_res', diff --git a/openid/test/test_discover.py b/openid/test/test_discover.py index 414f7ed7..d294d3f9 100644 --- a/openid/test/test_discover.py +++ b/openid/test/test_discover.py @@ -3,7 +3,8 @@ import os.path import unittest -from urlparse import urlsplit + +from six.moves.urllib.parse import urlsplit from openid import fetchers, message from openid.consumer import discover diff --git a/openid/test/test_message.py b/openid/test/test_message.py index 24d33cfb..fd97dbf8 100644 --- a/openid/test/test_message.py +++ b/openid/test/test_message.py @@ -2,11 +2,10 @@ from __future__ import unicode_literals import unittest -import urllib import warnings -from urlparse import parse_qs from lxml import etree as ElementTree +from six.moves.urllib.parse import parse_qs, quote from testfixtures import ShouldWarn from openid.extensions import sreg @@ -435,8 +434,7 @@ def test_toKVForm(self): self.assertEqual(self.msg.toKVForm(), 'error:unit test\nmode:error\nns:%s\n' % OPENID2_NS) def _test_urlencoded(self, s): - expected = ('openid.error=unit+test&openid.mode=error&openid.ns=%s&xey=value' % - urllib.quote(OPENID2_NS, '')) + expected = ('openid.error=unit+test&openid.mode=error&openid.ns=%s&xey=value' % quote(OPENID2_NS, '')) self.assertEqual(s, expected) def test_toURLEncoded(self): diff --git a/openid/test/test_server.py b/openid/test/test_server.py index 149508a3..88c1c021 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -5,9 +5,9 @@ import unittest import warnings from functools import partial -from urlparse import parse_qs, parse_qsl, urlparse from mock import sentinel +from six.moves.urllib.parse import parse_qs, parse_qsl, urlparse from testfixtures import LogCapture, ShouldWarn, StringComparison from openid import association, cryptutil, oidutil diff --git a/openid/test/test_yadis_discover.py b/openid/test/test_yadis_discover.py index 8e7add87..eec94c85 100644 --- a/openid/test/test_yadis_discover.py +++ b/openid/test/test_yadis_discover.py @@ -8,7 +8,8 @@ import re import types import unittest -import urlparse + +from six.moves.urllib.parse import urlparse from openid import fetchers from openid.yadis.discover import DiscoveryFailure, discover @@ -50,7 +51,7 @@ def __init__(self, base_url): def fetch(self, url, headers, body): current_url = url while True: - parsed = urlparse.urlparse(current_url) + parsed = urlparse(current_url) path = parsed[2][1:] try: data = discoverdata.generateSample(path, self.base_url) diff --git a/openid/urinorm.py b/openid/urinorm.py index 4e1530e0..e8cc60ec 100644 --- a/openid/urinorm.py +++ b/openid/urinorm.py @@ -2,10 +2,9 @@ from __future__ import unicode_literals import string -from urllib import quote, unquote, urlencode -from urlparse import parse_qsl, urlsplit, urlunsplit import six +from six.moves.urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit, urlunsplit from .oidutil import string_to_text diff --git a/openid/yadis/xri.py b/openid/yadis/xri.py index ea394f15..58567922 100644 --- a/openid/yadis/xri.py +++ b/openid/yadis/xri.py @@ -8,9 +8,9 @@ import re import warnings -from urllib import quote import six +from six.moves.urllib.parse import quote from openid.urinorm import GEN_DELIMS, PERCENT_ENCODING_CHARACTER, SUB_DELIMS diff --git a/openid/yadis/xrires.py b/openid/yadis/xrires.py index 26bffb34..2bc671f2 100644 --- a/openid/yadis/xrires.py +++ b/openid/yadis/xrires.py @@ -1,7 +1,7 @@ """XRI resolution.""" from __future__ import unicode_literals -from urllib import urlencode +from six.moves.urllib.parse import urlencode from openid import fetchers from openid.oidutil import string_to_text From 1082699a5be1111b439fa21fe0123d5f80f91cbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 2 May 2018 11:58:26 +0200 Subject: [PATCH 076/151] Use StringIO and BytesIO from six --- openid/fetchers.py | 15 ++++++++------- openid/test/test_fetchers.py | 2 +- openid/test/test_parsehtml.py | 2 +- openid/yadis/discover.py | 8 ++++---- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/openid/fetchers.py b/openid/fetchers.py index 3bcb9d7d..bb85c3a0 100644 --- a/openid/fetchers.py +++ b/openid/fetchers.py @@ -1,11 +1,12 @@ """This module contains the HTTP fetcher interface and several implementations.""" from __future__ import unicode_literals -import cStringIO import sys import time import urllib2 +from six import BytesIO + import openid import openid.urinorm @@ -277,7 +278,7 @@ def _parseHeaders(self, header_file): # Remove the status line from the beginning of the input unused_http_status_line = header_file.readline().lower() - if unused_http_status_line.startswith('http/1.1 100 '): + if unused_http_status_line.startswith(b'http/1.1 100 '): unused_http_status_line = header_file.readline() unused_http_status_line = header_file.readline() @@ -291,15 +292,15 @@ def _parseHeaders(self, header_file): headers = {} for line in lines: try: - name, value = line.split(':', 1) + name, value = line.split(b':', 1) except ValueError: raise HTTPError( "Malformed HTTP header line in response: %r" % (line,)) - value = value.strip() + value = value.strip().decode('utf-8') # HTTP headers are case-insensitive - name = name.lower() + name = name.lower().decode('utf-8') headers[name] = value return headers @@ -340,7 +341,7 @@ def fetch(self, url, body=None, headers=None): if not self._checkURL(url): raise HTTPError("Fetching URL not allowed: %r" % (url,)) - data = cStringIO.StringIO() + data = BytesIO() def write_data(chunk): if data.tell() > 1024 * MAX_RESPONSE_KB: @@ -348,7 +349,7 @@ def write_data(chunk): else: return data.write(chunk) - response_header_data = cStringIO.StringIO() + response_header_data = BytesIO() c.setopt(pycurl.WRITEFUNCTION, write_data) c.setopt(pycurl.HEADERFUNCTION, response_header_data.write) c.setopt(pycurl.TIMEOUT, off) diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index 5594a522..77dac9c4 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -5,11 +5,11 @@ import urllib2 import warnings from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer -from cStringIO import StringIO from urllib import addinfourl import responses from mock import Mock, patch, sentinel +from six import StringIO from openid import fetchers diff --git a/openid/test/test_parsehtml.py b/openid/test/test_parsehtml.py index 8609322a..214aac00 100644 --- a/openid/test/test_parsehtml.py +++ b/openid/test/test_parsehtml.py @@ -2,9 +2,9 @@ from __future__ import unicode_literals import unittest -from StringIO import StringIO from mock import sentinel +from six import StringIO from openid.yadis.parsehtml import MetaNotFound, findHTMLMeta, xpath_lower_case diff --git a/openid/yadis/discover.py b/openid/yadis/discover.py index e1b494fa..769fb74a 100644 --- a/openid/yadis/discover.py +++ b/openid/yadis/discover.py @@ -1,6 +1,6 @@ from __future__ import unicode_literals -from StringIO import StringIO +from six import BytesIO, StringIO from openid import fetchers from openid.yadis.constants import YADIS_ACCEPT_HEADER, YADIS_CONTENT_TYPE, YADIS_HEADER_NAME @@ -138,14 +138,14 @@ def whereIsYadis(resp): encoding = 'UTF-8' try: - content = resp.body.decode(encoding) + buff = StringIO(resp.body.decode(encoding)) except UnicodeError: # Keep encoded version in case yadis location can be found before encoding shut this up. # Possible errors will be caught lower. - content = resp.body + buff = BytesIO(resp.body) try: - yadis_loc = findHTMLMeta(StringIO(content)) + yadis_loc = findHTMLMeta(buff) except (MetaNotFound, UnicodeError): # UnicodeError: Response body could not be encoded and xrds location # could not be found before troubles occurs. From fbec855b8472c6d6ae40f477b95f5762a919820e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 2 May 2018 12:49:24 +0200 Subject: [PATCH 077/151] Use six imports for fetchers --- examples/consumer.py | 3 +-- examples/server.py | 3 ++- openid/fetchers.py | 30 ++++++++++++++---------------- openid/test/test_fetchers.py | 21 +++++++++++---------- 4 files changed, 28 insertions(+), 29 deletions(-) diff --git a/examples/consumer.py b/examples/consumer.py index 5898ee46..19b29bdf 100644 --- a/examples/consumer.py +++ b/examples/consumer.py @@ -14,11 +14,10 @@ import cgitb import optparse import sys -from six.moves.urllib.parse import urlparse, parse_qsl, urljoin -from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer from Cookie import SimpleCookie import six +from six.moves.BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer from six.moves.urllib.parse import parse_qsl, urljoin, urlparse diff --git a/examples/server.py b/examples/server.py index f9174d01..d8434225 100644 --- a/examples/server.py +++ b/examples/server.py @@ -10,7 +10,8 @@ import optparse import sys import time -from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer + +from six.moves.BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer from six.moves.urllib.parse import parse_qsl, urlparse diff --git a/openid/fetchers.py b/openid/fetchers.py index bb85c3a0..611bc115 100644 --- a/openid/fetchers.py +++ b/openid/fetchers.py @@ -3,9 +3,10 @@ import sys import time -import urllib2 from six import BytesIO +from six.moves.urllib.error import HTTPError as UrllibHTTPError +from six.moves.urllib.request import Request, urlopen import openid import openid.urinorm @@ -55,7 +56,7 @@ def createHTTPFetcher(): 1. requests 2. curl 3. httplib2 - 4. urllib2 + 4. urllib """ if requests is not None: fetcher = RequestsFetcher() @@ -206,12 +207,11 @@ def fetch(self, *args, **kwargs): class Urllib2Fetcher(HTTPFetcher): - """An C{L{HTTPFetcher}} that uses urllib2. - """ + """An C{L{HTTPFetcher}} that uses urllib.""" # Parameterized for the benefit of testing frameworks, see # http://trac.openidenabled.com/trac/ticket/85 - urlopen = staticmethod(urllib2.urlopen) + urlopen = staticmethod(urlopen) def fetch(self, url, body=None, headers=None): if not _allowedURL(url): @@ -220,31 +220,29 @@ def fetch(self, url, body=None, headers=None): if headers is None: headers = {} - headers.setdefault( - 'User-Agent', - "%s Python-urllib/%s" % (USER_AGENT, urllib2.__version__,)) + headers.setdefault('User-Agent', "%s Python-urllib" % USER_AGENT) - req = urllib2.Request(url, data=body, headers=headers) + req = Request(url, data=body, headers=headers) try: f = self.urlopen(req) try: return self._makeResponse(f) finally: f.close() - except urllib2.HTTPError as why: + except UrllibHTTPError as why: try: return self._makeResponse(why) finally: why.close() - def _makeResponse(self, urllib2_response): + def _makeResponse(self, urllib_response): resp = HTTPResponse() - resp.body = urllib2_response.read(MAX_RESPONSE_KB * 1024) - resp.final_url = urllib2_response.geturl() - resp.headers = dict(urllib2_response.info().items()) + resp.body = urllib_response.read(MAX_RESPONSE_KB * 1024) + resp.final_url = urllib_response.geturl() + resp.headers = dict(urllib_response.info().items()) - if hasattr(urllib2_response, 'code'): - resp.status = urllib2_response.code + if hasattr(urllib_response, 'code'): + resp.status = urllib_response.code else: resp.status = 200 diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index 77dac9c4..327092cc 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -2,14 +2,15 @@ import socket import unittest -import urllib2 import warnings -from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer -from urllib import addinfourl import responses from mock import Mock, patch, sentinel from six import StringIO +from six.moves.BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer +from six.moves.urllib.error import HTTPError, URLError +from six.moves.urllib.request import BaseHandler, OpenerDirector, install_opener +from six.moves.urllib.response import addinfourl from openid import fetchers @@ -300,11 +301,11 @@ def test_notWrapped(self): self.assertNotIsInstance(fetchers.getDefaultFetcher(), fetchers.ExceptionWrappingFetcher) - with self.assertRaises(urllib2.URLError): + with self.assertRaises(URLError): fetchers.fetch('http://invalid.janrain.com/') -class TestHandler(urllib2.BaseHandler): +class TestHandler(BaseHandler): """Urllib2 test handler.""" def __init__(self, http_mock): @@ -322,13 +323,13 @@ class TestUrllib2Fetcher(unittest.TestCase): def setUp(self): self.http_mock = Mock(side_effect=[]) - opener = urllib2.OpenerDirector() + opener = OpenerDirector() opener.add_handler(TestHandler(self.http_mock)) - urllib2.install_opener(opener) + install_opener(opener) def tearDown(self): # Uninstall custom opener - urllib2.install_opener(None) + install_opener(None) def add_response(self, url, status_code, headers, body=None): response = addinfourl(StringIO(body or ''), headers, url, status_code) @@ -363,8 +364,8 @@ def test_invalid_url(self): def test_connection_error(self): # Test connection error - self.http_mock.side_effect = urllib2.HTTPError('http://example.cz/error/', 500, 'Error message', - {'Content-Type': 'text/plain'}, StringIO('BODY')) + self.http_mock.side_effect = HTTPError('http://example.cz/error/', 500, 'Error message', + {'Content-Type': 'text/plain'}, StringIO('BODY')) response = self.fetcher.fetch('http://example.cz/error/') expected = fetchers.HTTPResponse('http://example.cz/error/', 500, {'Content-Type': 'text/plain'}, 'BODY') assertResponse(expected, response) From afe8559a35e9a018b0ac1b6d077d973b8b93e457 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 2 May 2018 13:39:21 +0200 Subject: [PATCH 078/151] Make iterators python3 compatible --- examples/djopenid/consumer/views.py | 2 +- examples/djopenid/server/views.py | 2 +- examples/djopenid/util.py | 3 ++- examples/server.py | 2 +- openid/consumer/consumer.py | 4 ++-- openid/extensions/ax.py | 14 ++++++------ openid/extensions/pape.py | 8 +++---- openid/extensions/sreg.py | 4 ++-- openid/fetchers.py | 2 +- openid/message.py | 22 +++++++++++-------- openid/store/memstore.py | 6 ++--- openid/test/test_association_response.py | 2 +- openid/test/test_ax.py | 2 +- openid/test/test_fetchers.py | 2 +- openid/test/test_message.py | 28 +++++++----------------- openid/test/test_sreg.py | 2 +- 16 files changed, 49 insertions(+), 56 deletions(-) diff --git a/examples/djopenid/consumer/views.py b/examples/djopenid/consumer/views.py index 776dd0dc..6a8e2a45 100644 --- a/examples/djopenid/consumer/views.py +++ b/examples/djopenid/consumer/views.py @@ -104,7 +104,7 @@ def startOpenID(request): # the response. requested_policies = [] policy_prefix = 'policy_' - for k, v in request.POST.iteritems(): + for k, v in six.iteritems(request.POST): if k.startswith(policy_prefix): policy_attr = k[len(policy_prefix):] if policy_attr in PAPE_POLICIES: diff --git a/examples/djopenid/server/views.py b/examples/djopenid/server/views.py index 9b799ee7..55809608 100644 --- a/examples/djopenid/server/views.py +++ b/examples/djopenid/server/views.py @@ -257,7 +257,7 @@ def displayResponse(request, openid_response): r = http.HttpResponse(webresponse.body) r.status_code = webresponse.code - for header, value in webresponse.headers.iteritems(): + for header, value in webresponse.headers.items(): r[header] = value return r diff --git a/examples/djopenid/util.py b/examples/djopenid/util.py index 2e0a97c5..39203c46 100644 --- a/examples/djopenid/util.py +++ b/examples/djopenid/util.py @@ -3,6 +3,7 @@ """ from __future__ import unicode_literals +import six from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.db import connection @@ -84,7 +85,7 @@ def normalDict(request_data): values are lists, because in OpenID, each key in the query arg set can have at most one value. """ - return dict((k, v) for k, v in request_data.iteritems()) + return dict((k, v) for k, v in six.iteritems(request_data)) def renderXRDS(request, type_uris, endpoint_urls): diff --git a/examples/server.py b/examples/server.py index d8434225..934909eb 100644 --- a/examples/server.py +++ b/examples/server.py @@ -234,7 +234,7 @@ def displayResponse(self, response): return self.send_response(webresponse.code) - for header, value in webresponse.headers.iteritems(): + for header, value in webresponse.headers.items(): self.send_header(header, value) self.writeUserHeader() self.end_headers() diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index 2d0b1c8c..c08dc62c 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -864,7 +864,7 @@ def _verifyReturnToArgs(query): # Make sure all non-OpenID arguments in the response are also # in the signed return_to. bare_args = message.getArgs(BARE_NS) - for pair in bare_args.iteritems(): + for pair in six.iteritems(bare_args): if pair not in parsed_args: raise ProtocolError("Parameter %s not in return_to URL" % (pair[0],)) @@ -1743,7 +1743,7 @@ def getSignedNS(self, ns_uri): """ msg_args = self.message.getArgs(ns_uri) - for key in msg_args.iterkeys(): + for key in msg_args: if not self.isSigned(ns_uri, key): _LOGGER.info("SuccessResponse.getSignedNS: (%s, %s) not signed.", ns_uri, key) return None diff --git a/openid/extensions/ax.py b/openid/extensions/ax.py index 6bbad7c0..faffdb4f 100644 --- a/openid/extensions/ax.py +++ b/openid/extensions/ax.py @@ -226,7 +226,7 @@ def getExtensionArgs(self): ax_args = self._newArgs() - for type_uri, attribute in self.requested_attributes.iteritems(): + for type_uri, attribute in six.iteritems(self.requested_attributes): if attribute.alias is None: alias = aliases.add(type_uri) else: @@ -272,7 +272,7 @@ def getRequiredAttrs(self): @rtype: List[six.text_type] """ required = [] - for type_uri, attribute in self.requested_attributes.iteritems(): + for type_uri, attribute in six.iteritems(self.requested_attributes): if attribute.required: required.append(type_uri) @@ -345,7 +345,7 @@ def parseExtensionArgs(self, ax_args): aliases = NamespaceMap() - for key, value in ax_args.iteritems(): + for key, value in six.iteritems(ax_args): if key.startswith('type.'): alias = key[5:] type_uri = value @@ -388,7 +388,7 @@ def iterAttrs(self): """Iterate over the AttrInfo objects that are contained in this fetch_request. """ - return self.requested_attributes.itervalues() + return six.itervalues(self.requested_attributes) def __iter__(self): """Iterate over the attribute type URIs in this fetch_request @@ -458,7 +458,7 @@ def _getExtensionKVArgs(self, aliases=None): ax_args = {} - for type_uri, values in self.data.iteritems(): + for type_uri, values in six.iteritems(self.data): alias = aliases.add(type_uri) ax_args['type.' + alias] = type_uri @@ -490,14 +490,14 @@ def parseExtensionArgs(self, ax_args): aliases = NamespaceMap() - for key, value in ax_args.iteritems(): + for key, value in six.iteritems(ax_args): if key.startswith('type.'): type_uri = value alias = key[5:] checkAlias(alias) aliases.addAlias(type_uri, alias) - for type_uri, alias in aliases.iteritems(): + for type_uri, alias in aliases.items(): try: count_s = ax_args['count.' + alias] except KeyError: diff --git a/openid/extensions/pape.py b/openid/extensions/pape.py index d69c4dbd..b9c38129 100644 --- a/openid/extensions/pape.py +++ b/openid/extensions/pape.py @@ -88,7 +88,7 @@ def _getAlias(self, auth_level_uri): @raises KeyError: if no alias is defined """ - for (alias, existing_uri) in self.auth_level_aliases.iteritems(): + for (alias, existing_uri) in self.auth_level_aliases.items(): if auth_level_uri == existing_uri: return alias @@ -297,7 +297,7 @@ def __init__(self, auth_policies=None, auth_time=None, if auth_levels is None: auth_levels = {} - for uri, level in auth_levels.iteritems(): + for uri, level in auth_levels.items(): self.setAuthLevel(uri, level) def setAuthLevel(self, level_uri, level, alias=None): @@ -417,7 +417,7 @@ def parseExtensionArgs(self, args, is_openid1, strict=False): self.auth_policies = auth_policies - for (key, val) in args.iteritems(): + for (key, val) in six.iteritems(args): if key.startswith('auth_level.'): alias = key[11:] @@ -459,7 +459,7 @@ def getExtensionArgs(self): 'auth_policies': ' '.join(self.auth_policies), } - for level_type, level in self.auth_levels.iteritems(): + for level_type, level in self.auth_levels.items(): alias = self._getAlias(level_type) ns_args['auth_level.ns.%s' % (alias,)] = level_type ns_args['auth_level.%s' % (alias,)] = six.text_type(level) diff --git a/openid/extensions/sreg.py b/openid/extensions/sreg.py index 7523e563..1c5e9d61 100644 --- a/openid/extensions/sreg.py +++ b/openid/extensions/sreg.py @@ -502,13 +502,13 @@ def items(self): return self.data.items() def iteritems(self): - return self.data.iteritems() + return six.iteritems(self.data) def keys(self): return self.data.keys() def iterkeys(self): - return self.data.iterkeys() + return six.iterkeys(self.data) def has_key(self, key): return key in self diff --git a/openid/fetchers.py b/openid/fetchers.py index 611bc115..11a6b3a6 100644 --- a/openid/fetchers.py +++ b/openid/fetchers.py @@ -320,7 +320,7 @@ def fetch(self, url, body=None, headers=None): header_list = [] if headers is not None: - for header_name, header_value in headers.iteritems(): + for header_name, header_value in headers.items(): header_list.append('%s: %s' % (header_name, header_value)) c = pycurl.Curl() diff --git a/openid/message.py b/openid/message.py index a55b2291..d550b299 100644 --- a/openid/message.py +++ b/openid/message.py @@ -205,7 +205,7 @@ def _fromOpenIDArgs(cls, openid_args): # Other arguments namespaces = {} ns_args = [] - for key, value in openid_args.iteritems(): + for key, value in six.iteritems(openid_args): key = string_to_text(key, "Binary keys in message creations are deprecated. Use text input instead.") value = string_to_text(value, "Binary values in message creations are deprecated. Use text input instead.") if '.' not in key: @@ -303,7 +303,7 @@ def toPostArgs(self): args = {} # Add namespace definitions to the output - for ns_uri, alias in self.namespaces.iteritems(): + for ns_uri, alias in self.namespaces.items(): if self.namespaces.isImplicit(ns_uri): continue if alias == NULL_NAMESPACE: @@ -312,7 +312,7 @@ def toPostArgs(self): ns_key = 'openid.ns.' + alias args[ns_key] = ns_uri - for (ns_uri, ns_key), value in self.args.iteritems(): + for (ns_uri, ns_key), value in six.iteritems(self.args): key = self.getKey(ns_uri, ns_key) args[key] = value @@ -324,7 +324,7 @@ def toArgs(self): # FIXME - undocumented exception post_args = self.toPostArgs() kvargs = {} - for k, v in post_args.iteritems(): + for k, v in six.iteritems(post_args): if not k.startswith('openid.'): raise ValueError( 'This message can only be encoded as a POST, because it ' @@ -362,7 +362,7 @@ def toFormMarkup(self, action_url, form_tag_attrs=None, form = ElementTree.Element('form') if form_tag_attrs: - for name, attr in form_tag_attrs.iteritems(): + for name, attr in form_tag_attrs.items(): form.attrib[name] = attr form.attrib['action'] = action_url @@ -370,7 +370,7 @@ def toFormMarkup(self, action_url, form_tag_attrs=None, form.attrib['accept-charset'] = 'UTF-8' form.attrib['enctype'] = 'application/x-www-form-urlencoded' - for name, value in self.toPostArgs().iteritems(): + for name, value in six.iteritems(self.toPostArgs()): attrs = {'type': 'hidden', 'name': name, 'value': value} form.append(ElementTree.Element('input', attrs)) @@ -489,7 +489,7 @@ def getArgs(self, namespace): return dict([ (ns_key, value) for ((pair_ns, ns_key), value) - in self.args.iteritems() + in six.iteritems(self.args) if pair_ns == namespace ]) @@ -500,7 +500,7 @@ def updateArgs(self, namespace, updates): @type updates: Dict[six.text_type, six.text_type] """ namespace = self._fixNS(namespace) - for k, v in updates.iteritems(): + for k, v in six.iteritems(updates): self.setArg(namespace, k, v) def setArg(self, namespace, key, value): @@ -579,12 +579,16 @@ def iterAliases(self): """Return an iterator over the aliases""" return iter(self.alias_to_namespace) + def items(self): + """Iterate over the mapping.""" + return self.namespace_to_alias.items() + def iteritems(self): """Iterate over the mapping @returns: iterator of (namespace_uri, alias) """ - return self.namespace_to_alias.iteritems() + return six.iteritems(self.namespace_to_alias) def addAlias(self, namespace_uri, desired_alias, implicit=False): """Add an alias from this namespace URI to the desired alias diff --git a/openid/store/memstore.py b/openid/store/memstore.py index d2a74f41..8d271d66 100644 --- a/openid/store/memstore.py +++ b/openid/store/memstore.py @@ -44,7 +44,7 @@ def cleanup(self): @return: tuple of (removed associations, remaining associations) """ remove = [] - for handle, assoc in self.assocs.iteritems(): + for handle, assoc in six.iteritems(self.assocs): if assoc.getExpiresIn() == 0: remove.append(handle) for handle in remove: @@ -98,7 +98,7 @@ def useNonce(self, server_url, timestamp, salt): def cleanupNonces(self): now = time.time() expired = [] - for anonce in self.nonces.iterkeys(): + for anonce in self.nonces: if abs(anonce[1] - now) > nonce.SKEW: # removing items while iterating over the set could be bad. expired.append(anonce) @@ -110,7 +110,7 @@ def cleanupNonces(self): def cleanupAssociations(self): remove_urls = [] removed_assocs = 0 - for server_url, assocs in self.server_assocs.iteritems(): + for server_url, assocs in six.iteritems(self.server_assocs): removed, remaining = assocs.cleanup() removed_assocs += removed if not remaining: diff --git a/openid/test/test_association_response.py b/openid/test/test_association_response.py index 0e1fc5e7..4ef61262 100644 --- a/openid/test/test_association_response.py +++ b/openid/test/test_association_response.py @@ -117,7 +117,7 @@ class ExtractAssociationSessionTypeMismatch(BaseAssocTest): def mkTest(requested_session_type, response_session_type, openid1=False): def test(self): assoc_session = DummyAssocationSession(requested_session_type) - keys = association_response_values.keys() + keys = list(association_response_values.keys()) if openid1: keys.remove('ns') msg = mkAssocResponse(*keys) diff --git a/openid/test/test_ax.py b/openid/test/test_ax.py index adff4d9f..0b059618 100644 --- a/openid/test/test_ax.py +++ b/openid/test/test_ax.py @@ -217,7 +217,7 @@ def test_getExtensionArgs_noAlias(self): ) self.msg.add(attr) ax_args = self.msg.getExtensionArgs() - for k, v in ax_args.iteritems(): + for k, v in ax_args.items(): if v == attr.type_uri and k.startswith('type.'): alias = k[5:] break diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index 327092cc..6501b181 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -62,7 +62,7 @@ def assertResponse(expected, actual): # TODO: Delete these pops got_headers.pop('date', None) got_headers.pop('server', None) - for k, v in expected.headers.iteritems(): + for k, v in expected.headers.items(): assert got_headers[k] == v, (k, v, got_headers[k]) diff --git a/openid/test/test_message.py b/openid/test/test_message.py index fd97dbf8..8a636013 100644 --- a/openid/test/test_message.py +++ b/openid/test/test_message.py @@ -734,12 +734,12 @@ def _checkForm(self, html, message_, action_url, form = input_tree.getroot() # Check required form attributes - for k, v in self.required_form_attrs.iteritems(): + for k, v in self.required_form_attrs.items(): assert form.attrib[k] == v, \ "Expected '%s' for required form attribute '%s', got '%s'" % (v, k, form.attrib[k]) # Check extra form attributes - for k, v in form_tag_attrs.iteritems(): + for k, v in form_tag_attrs.items(): # Skip attributes that already passed the required # attribute check, since they should be ignored by the @@ -756,7 +756,7 @@ def _checkForm(self, html, message_, action_url, # For each post arg, make sure there is a hidden with that # value. Make sure there are no other hiddens. - for name, value in message_.toPostArgs().iteritems(): + for name, value in message_.toPostArgs().items(): for e in hiddens: if e.attrib['name'] == name: assert e.attrib['value'] == value, \ @@ -938,26 +938,14 @@ def test_iteration(self): self.assertTrue(nsm.isDefined(uripat % (n - 1))) nsm.add(uripat % n) + for (uri, alias) in nsm.items(): + self.assertEqual(uri[22:], alias[3:]) + for (uri, alias) in nsm.iteritems(): self.assertEqual(uri[22:], alias[3:]) - i = 0 - it = nsm.iterAliases() - try: - while True: - it.next() - i += 1 - except StopIteration: - self.assertEqual(i, 23) - - i = 0 - it = nsm.iterNamespaceURIs() - try: - while True: - it.next() - i += 1 - except StopIteration: - self.assertEqual(i, 23) + self.assertEqual(len(tuple(nsm.iterAliases())), 23) + self.assertEqual(len(tuple(nsm.iterNamespaceURIs())), 23) if __name__ == '__main__': diff --git a/openid/test/test_sreg.py b/openid/test/test_sreg.py index b56c5c63..1224cbd9 100644 --- a/openid/test/test_sreg.py +++ b/openid/test/test_sreg.py @@ -430,7 +430,7 @@ def test_fromSuccessResponse_unsigned(self): success_resp = DummySuccessResponse(message, {}) sreg_resp = sreg.SRegResponse.fromSuccessResponse(success_resp, signed_only=False) - self.assertEqual(sreg_resp.items(), [('nickname', 'The Mad Stork')]) + self.assertEqual(list(sreg_resp.items()), [('nickname', 'The Mad Stork')]) class SendFieldsTest(unittest.TestCase): From ea790df687c2581fd5526a9598aad5a7eb35bfe4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 2 May 2018 13:43:53 +0200 Subject: [PATCH 079/151] Python 3 compatible boolean conversions --- openid/extensions/draft/pape2.py | 5 ++++- openid/extensions/pape.py | 5 ++++- openid/extensions/sreg.py | 5 ++++- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/openid/extensions/draft/pape2.py b/openid/extensions/draft/pape2.py index 529f329d..e8dec915 100644 --- a/openid/extensions/draft/pape2.py +++ b/openid/extensions/draft/pape2.py @@ -62,10 +62,13 @@ def __init__(self, preferred_auth_policies=None, max_auth_age=None): self.preferred_auth_policies = preferred_auth_policies self.max_auth_age = max_auth_age - def __nonzero__(self): + def __bool__(self): return bool(self.preferred_auth_policies or self.max_auth_age is not None) + def __nonzero__(self): + return self.__bool__() + def addPolicyURI(self, policy_uri): """Add an acceptable authentication policy URI to this request diff --git a/openid/extensions/pape.py b/openid/extensions/pape.py index b9c38129..d834824c 100644 --- a/openid/extensions/pape.py +++ b/openid/extensions/pape.py @@ -130,11 +130,14 @@ def __init__(self, preferred_auth_policies=None, max_auth_age=None, for auth_level in preferred_auth_level_types: self.addAuthLevel(auth_level) - def __nonzero__(self): + def __bool__(self): return bool(self.preferred_auth_policies or self.max_auth_age is not None or self.preferred_auth_level_types) + def __nonzero__(self): + return self.__bool__() + def addPolicyURI(self, policy_uri): """Add an acceptable authentication policy URI to this request diff --git a/openid/extensions/sreg.py b/openid/extensions/sreg.py index 1c5e9d61..c557e216 100644 --- a/openid/extensions/sreg.py +++ b/openid/extensions/sreg.py @@ -524,5 +524,8 @@ def __getitem__(self, field_name): checkFieldName(field_name) return self.data[field_name] - def __nonzero__(self): + def __bool__(self): return bool(self.data) + + def __nonzero__(self): + return self.__bool__() From ac5e7589990f44fa4e6e0da93748c9b4407e1cf3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 2 May 2018 13:49:56 +0200 Subject: [PATCH 080/151] Update python builtin functions --- admin/builddiscover.py | 4 ++-- openid/cryptutil.py | 2 +- openid/extensions/pape.py | 2 +- openid/store/filestore.py | 4 ++-- openid/store/sqlstore.py | 6 +++++- openid/test/discoverdata.py | 4 ++-- openid/test/test_accept.py | 2 +- openid/test/test_cryptutil.py | 6 +++--- openid/test/test_dh.py | 2 +- openid/test/test_discover.py | 2 +- openid/test/test_etxrd.py | 8 ++++---- openid/test/test_oidutil.py | 2 +- 12 files changed, 24 insertions(+), 20 deletions(-) diff --git a/admin/builddiscover.py b/admin/builddiscover.py index 4b572734..0ac4cc88 100755 --- a/admin/builddiscover.py +++ b/admin/builddiscover.py @@ -44,7 +44,7 @@ def writeTestFile(test_name): test_name, template, base_url, discoverdata.example_xrds) out_file_name = os.path.join(out_dir, test_name) - out_file = file(out_file_name, 'w') + out_file = open(out_file_name, 'w') out_file.write(data) manifest = [manifest_header] @@ -61,7 +61,7 @@ def writeTestFile(test_name): manifest.append('\n') manifest_file_name = os.path.join(out_dir, 'manifest.txt') - manifest_file = file(manifest_file_name, 'w') + manifest_file = open(manifest_file_name, 'w') for chunk in manifest: manifest_file.write(chunk) manifest_file.close() diff --git a/openid/cryptutil.py b/openid/cryptutil.py index fd7b30ca..858de483 100644 --- a/openid/cryptutil.py +++ b/openid/cryptutil.py @@ -142,7 +142,7 @@ def binaryToLong(s): # have Windows equivalent here, but for now, require pycrypto # on Windows. try: - _urandom = file('/dev/urandom', 'rb') + _urandom = open('/dev/urandom', 'rb') except IOError: raise ImportError('No adequate source of randomness found!') else: diff --git a/openid/extensions/pape.py b/openid/extensions/pape.py index d834824c..56a0b923 100644 --- a/openid/extensions/pape.py +++ b/openid/extensions/pape.py @@ -76,7 +76,7 @@ def _addAuthLevelAlias(self, auth_level_uri, alias=None): def _generateAlias(self): """Return an unused auth level alias""" - for i in xrange(1000): + for i in range(1000): alias = 'cust%d' % (i,) if alias not in self.auth_level_aliases: return alias diff --git a/openid/store/filestore.py b/openid/store/filestore.py index a7a3e7f5..58ee2a72 100644 --- a/openid/store/filestore.py +++ b/openid/store/filestore.py @@ -265,7 +265,7 @@ def getAssociation(self, server_url, handle=None): def _getAssociation(self, filename): try: - assoc_file = file(filename, 'rb') + assoc_file = open(filename, 'rb') except IOError as why: if why.errno == ENOENT: # No association exists for that URL and handle @@ -349,7 +349,7 @@ def _allAssocs(self): association_filenames = [os.path.join(self.association_dir, f) for f in os.listdir(self.association_dir)] for association_filename in association_filenames: try: - association_file = file(association_filename, 'rb') + association_file = open(association_filename, 'rb') except IOError as why: if why.errno == ENOENT: _LOGGER.exception("%s disappeared during %s._allAssocs", diff --git a/openid/store/sqlstore.py b/openid/store/sqlstore.py index 17912982..72c8f5ee 100644 --- a/openid/store/sqlstore.py +++ b/openid/store/sqlstore.py @@ -314,6 +314,10 @@ class SQLiteStore(SQLStore): All other methods are implementation details. """ + try: + import sqlite3 + except ImportError: + sqlite3 = None create_nonce_sql = """ CREATE TABLE %(nonces)s ( @@ -363,7 +367,7 @@ def blobDecode(self, buf): return six.binary_type(buf) def blobEncode(self, s): - return buffer(s) + return self.sqlite3.Binary(s) def useNonce(self, *args, **kwargs): # Older versions of the sqlite wrapper do not raise diff --git a/openid/test/discoverdata.py b/openid/test/discoverdata.py index 379f4500..1540c344 100644 --- a/openid/test/discoverdata.py +++ b/openid/test/discoverdata.py @@ -49,7 +49,7 @@ def getDataName(*components): def getExampleXRDS(): filename = getDataName('example-xrds.xml') - return file(filename).read() + return open(filename).read() example_xrds = getExampleXRDS() @@ -59,7 +59,7 @@ def getExampleXRDS(): def readTests(filename): - data = file(filename).read() + data = open(filename).read() tests = {} for case in data.split('\f\n'): (name, content) = case.split('\n', 1) diff --git a/openid/test/test_accept.py b/openid/test/test_accept.py index d499959b..8ca25fa9 100644 --- a/openid/test/test_accept.py +++ b/openid/test/test_accept.py @@ -15,7 +15,7 @@ def getTestData(): filename = os.path.join(os.path.dirname(__file__), 'data', 'accept.txt') i = 1 lines = [] - for line in file(filename): + for line in open(filename): lines.append((i, line.decode('utf-8'))) i += 1 return lines diff --git a/openid/test/test_cryptutil.py b/openid/test/test_cryptutil.py index 7697ada7..5a866110 100644 --- a/openid/test/test_cryptutil.py +++ b/openid/test/test_cryptutil.py @@ -43,7 +43,7 @@ class TestLongBinary(unittest.TestCase): def test_binaryLongConvert(self): MAX = sys.maxsize - for iteration in xrange(500): + for iteration in range(500): n = 0 for i in range(10): n += long(random.randrange(MAX)) @@ -75,7 +75,7 @@ class TestLongToBase64(unittest.TestCase): """Test `longToBase64` function.""" def test_longToBase64(self): - f = file(os.path.join(os.path.dirname(__file__), 'n2b64')) + f = open(os.path.join(os.path.dirname(__file__), 'n2b64')) try: for line in f: parts = line.strip().split(' ') @@ -88,7 +88,7 @@ class TestBase64ToLong(unittest.TestCase): """Test `Base64ToLong` function.""" def test_base64ToLong(self): - f = file(os.path.join(os.path.dirname(__file__), 'n2b64')) + f = open(os.path.join(os.path.dirname(__file__), 'n2b64')) try: for line in f: parts = line.strip().split(' ') diff --git a/openid/test/test_dh.py b/openid/test/test_dh.py index 707eb4d3..c9a2c561 100644 --- a/openid/test/test_dh.py +++ b/openid/test/test_dh.py @@ -67,7 +67,7 @@ def test_exchange(self): assert s1 != s2 def test_public(self): - f = file(os.path.join(os.path.dirname(__file__), 'dhpriv')) + f = open(os.path.join(os.path.dirname(__file__), 'dhpriv')) dh = DiffieHellman.fromDefaults() try: for line in f: diff --git a/openid/test/test_discover.py b/openid/test/test_discover.py index d294d3f9..c99fc493 100644 --- a/openid/test/test_discover.py +++ b/openid/test/test_discover.py @@ -206,7 +206,7 @@ def readDataFile(filename): module_directory = os.path.dirname(os.path.abspath(__file__)) filename = os.path.join( module_directory, 'data', 'test_discover', filename) - return file(filename).read() + return open(filename).read() class TestDiscovery(BaseTestDiscovery): diff --git a/openid/test/test_etxrd.py b/openid/test/test_etxrd.py index 6387dbb7..69325843 100644 --- a/openid/test/test_etxrd.py +++ b/openid/test/test_etxrd.py @@ -87,7 +87,7 @@ def test_xxe(self): class TestServiceParser(unittest.TestCase): def setUp(self): - self.xmldoc = file(XRD_FILE).read() + self.xmldoc = open(XRD_FILE).read() self.yadis_url = 'http://unittest.url/' def _getServices(self, flt=None): @@ -155,7 +155,7 @@ def testGetSeveralForOne(self): def testNoXRDS(self): """Make sure that we get an exception when an XRDS element is not present""" - self.xmldoc = file(NOXRDS_FILE).read() + self.xmldoc = open(NOXRDS_FILE).read() self.assertRaises(etxrd.XRDSError, services.applyFilter, self.yadis_url, self.xmldoc, None) def testEmpty(self): @@ -167,7 +167,7 @@ def testEmpty(self): def testNoXRD(self): """Make sure that we get an exception when there is no XRD element present.""" - self.xmldoc = file(NOXRD_FILE).read() + self.xmldoc = open(NOXRD_FILE).read() self.assertRaises(etxrd.XRDSError, services.applyFilter, self.yadis_url, self.xmldoc, None) @@ -180,7 +180,7 @@ def mkTest(iname, filename, expectedID): filename = datapath(filename) def test(self): - xrds = etxrd.parseXRDS(file(filename).read()) + xrds = etxrd.parseXRDS(open(filename).read()) self._getCanonicalID(iname, xrds, expectedID) return test diff --git a/openid/test/test_oidutil.py b/openid/test/test_oidutil.py index f0fa13d8..daf830ce 100644 --- a/openid/test/test_oidutil.py +++ b/openid/test/test_oidutil.py @@ -45,7 +45,7 @@ def checkEncoded(s): assert s_prime == s, (s, b64, s_prime) # Randomized test - for _ in xrange(50): + for _ in range(50): n = random.randrange(2048) s = b''.join(chr(random.randrange(256)) for i in range(n)) b64 = oidutil.toBase64(s) From 8305d62a5d46e1df688aba742f0cfea0dfc4f778 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 9 May 2018 09:45:55 +0200 Subject: [PATCH 081/151] Fix URI quoting for python 3 --- openid/urinorm.py | 10 +++++++++- openid/yadis/xri.py | 19 ++++++++++--------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/openid/urinorm.py b/openid/urinorm.py index e8cc60ec..f212ad89 100644 --- a/openid/urinorm.py +++ b/openid/urinorm.py @@ -115,7 +115,15 @@ def urinorm(uri): # Normalize path path = split_uri.path # Unquote and quote - this normalizes the percent encoding - path = quote(unquote(path.encode('utf-8'))).decode('utf-8') + + # This is hackish. `unquote` and `quote` requires `str` in both py27 and py3+. + if isinstance(path, str): + # Python 3 branch + path = quote(unquote(path)) + else: + # Python 2 branch + path = quote(unquote(path.encode('utf-8'))).decode('utf-8') + path = remove_dot_segments(path) if not path: path = '/' diff --git a/openid/yadis/xri.py b/openid/yadis/xri.py index 58567922..2924f353 100644 --- a/openid/yadis/xri.py +++ b/openid/yadis/xri.py @@ -7,11 +7,10 @@ from __future__ import unicode_literals import re -import warnings -import six from six.moves.urllib.parse import quote +from openid.oidutil import string_to_text from openid.urinorm import GEN_DELIMS, PERCENT_ENCODING_CHARACTER, SUB_DELIMS XRI_AUTHORITIES = ['!', '=', '@', '+', '$', '('] @@ -68,14 +67,16 @@ def iriToURI(iri): @type iri: six.text_type, six.binary_type deprecated. @rtype: six.text_type """ - # Transform the input to the binary string. `quote` doesn't quote correctly unicode strings. - if isinstance(iri, six.text_type): - iri = iri.encode('utf-8') - else: - assert isinstance(iri, six.binary_type) - warnings.warn("Binary input for iriToURI is deprecated. Use text input instead.", DeprecationWarning) + iri = string_to_text(iri, "Binary input for iriToURI is deprecated. Use text input instead.") - return quote(iri, (GEN_DELIMS + SUB_DELIMS + PERCENT_ENCODING_CHARACTER).encode('utf-8')).decode('utf-8') + # This is hackish. `quote` requires `str` in both py27 and py3+. + if isinstance(iri, str): + # Python 3 branch + return quote(iri, GEN_DELIMS + SUB_DELIMS + PERCENT_ENCODING_CHARACTER) + else: + # Python 2 branch + return quote(iri.encode('utf-8'), + (GEN_DELIMS + SUB_DELIMS + PERCENT_ENCODING_CHARACTER).encode('utf-8')).decode('utf-8') def providerIsAuthoritative(providerID, canonicalID): From c987505065718ae400d51eea7143c25646480e6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 9 May 2018 10:08:07 +0200 Subject: [PATCH 082/151] Use long depending on python version --- openid/dh.py | 10 ++++++++-- openid/server/server.py | 2 +- openid/test/test_cryptutil.py | 17 +++++++++++------ openid/test/test_dh.py | 6 +++--- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/openid/dh.py b/openid/dh.py index ab9f984c..28ff403b 100644 --- a/openid/dh.py +++ b/openid/dh.py @@ -4,6 +4,12 @@ from openid import cryptutil +if six.PY2: + long_int = long +else: + assert six.PY3 + long_int = int + def _xor(a_b): # Python 2 only @@ -35,8 +41,8 @@ def fromDefaults(cls): return cls(cls.DEFAULT_MOD, cls.DEFAULT_GEN) def __init__(self, modulus, generator): - self.modulus = long(modulus) - self.generator = long(generator) + self.modulus = long_int(modulus) + self.generator = long_int(generator) self._setPrivate(cryptutil.randrange(1, modulus - 1)) diff --git a/openid/server/server.py b/openid/server/server.py index 701fbdf4..2d251d4c 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -318,7 +318,7 @@ class DiffieHellmanSHA1ServerSession(object): @ivar consumer_pubkey: The public key sent by the consumer in the associate request - @type consumer_pubkey: long + @type consumer_pubkey: int, long in Python 2 @see: U{OpenID Specs, Mode: associate } diff --git a/openid/test/test_cryptutil.py b/openid/test/test_cryptutil.py index 5a866110..248c1005 100644 --- a/openid/test/test_cryptutil.py +++ b/openid/test/test_cryptutil.py @@ -12,6 +12,11 @@ # Most of the purpose of this test is to make sure that cryptutil can # find a good source of randomness on this machine. +if six.PY2: + long_int = long +else: + assert six.PY3 + long_int = int class TestRandRange(unittest.TestCase): @@ -29,13 +34,13 @@ def test_cryptrand(self): a = cryptutil.randrange(2 ** 128) b = cryptutil.randrange(2 ** 128) - assert isinstance(a, long) - assert isinstance(b, long) + assert isinstance(a, long_int) + assert isinstance(b, long_int) assert b != a # Make sure that we can generate random numbers that are larger # than platform int size - cryptutil.randrange(long(sys.maxsize) + 1) + cryptutil.randrange(long_int(sys.maxsize) + 1) class TestLongBinary(unittest.TestCase): @@ -46,7 +51,7 @@ def test_binaryLongConvert(self): for iteration in range(500): n = 0 for i in range(10): - n += long(random.randrange(MAX)) + n += long_int(random.randrange(MAX)) s = cryptutil.longToBinary(n) assert isinstance(s, six.binary_type) @@ -79,7 +84,7 @@ def test_longToBase64(self): try: for line in f: parts = line.strip().split(' ') - assert parts[0] == cryptutil.longToBase64(long(parts[1])) + assert parts[0] == cryptutil.longToBase64(long_int(parts[1])) finally: f.close() @@ -92,6 +97,6 @@ def test_base64ToLong(self): try: for line in f: parts = line.strip().split(' ') - assert long(parts[1]) == cryptutil.base64ToLong(parts[0]) + assert long_int(parts[1]) == cryptutil.base64ToLong(parts[0]) finally: f.close() diff --git a/openid/test/test_dh.py b/openid/test/test_dh.py index c9a2c561..ddd84f0c 100644 --- a/openid/test/test_dh.py +++ b/openid/test/test_dh.py @@ -6,7 +6,7 @@ import six -from openid.dh import DiffieHellman, strxor +from openid.dh import DiffieHellman, long_int, strxor class TestStrXor(unittest.TestCase): @@ -72,8 +72,8 @@ def test_public(self): try: for line in f: parts = line.strip().split(' ') - dh._setPrivate(long(parts[0])) + dh._setPrivate(long_int(parts[0])) - assert dh.public == long(parts[1]) + assert dh.public == long_int(parts[1]) finally: f.close() From 8345d651bc80fa48961f2d4282a03ce605c216de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 9 May 2018 10:31:09 +0200 Subject: [PATCH 083/151] Fix test data loading --- openid/test/discoverdata.py | 2 +- openid/test/test_accept.py | 2 +- openid/test/test_discover.py | 4 ++-- openid/test/test_etxrd.py | 25 +++++++++++++------------ openid/test/test_trustroot.py | 2 +- openid/test/test_yadis_discover.py | 2 +- 6 files changed, 19 insertions(+), 18 deletions(-) diff --git a/openid/test/discoverdata.py b/openid/test/discoverdata.py index 1540c344..1cc14849 100644 --- a/openid/test/discoverdata.py +++ b/openid/test/discoverdata.py @@ -132,5 +132,5 @@ def generateResult(base_url, input_name, id_name, result_name, success): if success: result.xrds_uri = urljoin(base_url, result_name) result.content_type = ctype - result.response_text = content + result.response_text = content.encode('utf-8') return input_url, result diff --git a/openid/test/test_accept.py b/openid/test/test_accept.py index 8ca25fa9..7f645d8a 100644 --- a/openid/test/test_accept.py +++ b/openid/test/test_accept.py @@ -15,7 +15,7 @@ def getTestData(): filename = os.path.join(os.path.dirname(__file__), 'data', 'accept.txt') i = 1 lines = [] - for line in open(filename): + for line in open(filename, 'rb'): lines.append((i, line.decode('utf-8'))) i += 1 return lines diff --git a/openid/test/test_discover.py b/openid/test/test_discover.py index c99fc493..5e56a776 100644 --- a/openid/test/test_discover.py +++ b/openid/test/test_discover.py @@ -206,7 +206,7 @@ def readDataFile(filename): module_directory = os.path.dirname(os.path.abspath(__file__)) filename = os.path.join( module_directory, 'data', 'test_discover', filename) - return open(filename).read() + return open(filename, 'rb').read() class TestDiscovery(BaseTestDiscovery): @@ -258,7 +258,7 @@ def test_unicode_undecodable_html2(self): def test_noOpenID(self): services = self._discover(content_type='text/plain', - data="junk", + data=b"junk", expected_services=0) services = self._discover( diff --git a/openid/test/test_etxrd.py b/openid/test/test_etxrd.py index 69325843..f07dbb92 100644 --- a/openid/test/test_etxrd.py +++ b/openid/test/test_etxrd.py @@ -67,19 +67,20 @@ def test_invalid_xml(self): etxrd.parseXRDS(xml) def test_xxe(self): - xxe_content = 'XXE CONTENT' + xxe_content = b'XXE CONTENT' _, tmp_file = tempfile.mkstemp() try: - with open(tmp_file, 'w') as xxe_file: + with open(tmp_file, 'wb') as xxe_file: xxe_file.write(xxe_content) # XXE example from Testing for XML Injection (OTG-INPVAL-008) # https://www.owasp.org/index.php/Testing_for_XML_Injection_(OTG-INPVAL-008) - xml = (b'' - b'' - b']>' - b'&xxe;') - tree = etxrd.parseXRDS(xml % tmp_file) + xml = ('' + '' + ']>' + '&xxe;') + xml = xml % tmp_file + tree = etxrd.parseXRDS(xml.encode('utf-8')) self.assertNotIn(xxe_content, etree.tostring(tree)) finally: os.remove(tmp_file) @@ -87,7 +88,7 @@ def test_xxe(self): class TestServiceParser(unittest.TestCase): def setUp(self): - self.xmldoc = open(XRD_FILE).read() + self.xmldoc = open(XRD_FILE, 'rb').read() self.yadis_url = 'http://unittest.url/' def _getServices(self, flt=None): @@ -155,7 +156,7 @@ def testGetSeveralForOne(self): def testNoXRDS(self): """Make sure that we get an exception when an XRDS element is not present""" - self.xmldoc = open(NOXRDS_FILE).read() + self.xmldoc = open(NOXRDS_FILE, 'rb').read() self.assertRaises(etxrd.XRDSError, services.applyFilter, self.yadis_url, self.xmldoc, None) def testEmpty(self): @@ -167,7 +168,7 @@ def testEmpty(self): def testNoXRD(self): """Make sure that we get an exception when there is no XRD element present.""" - self.xmldoc = open(NOXRD_FILE).read() + self.xmldoc = open(NOXRD_FILE, 'rb').read() self.assertRaises(etxrd.XRDSError, services.applyFilter, self.yadis_url, self.xmldoc, None) @@ -180,7 +181,7 @@ def mkTest(iname, filename, expectedID): filename = datapath(filename) def test(self): - xrds = etxrd.parseXRDS(open(filename).read()) + xrds = etxrd.parseXRDS(open(filename, 'rb').read()) self._getCanonicalID(iname, xrds, expectedID) return test diff --git a/openid/test/test_trustroot.py b/openid/test/test_trustroot.py index 9d3e3f3f..f90dd90e 100644 --- a/openid/test/test_trustroot.py +++ b/openid/test/test_trustroot.py @@ -5,7 +5,7 @@ from openid.server.trustroot import TrustRoot -with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data', 'trustroot.txt')) as test_data_file: +with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data', 'trustroot.txt'), 'rb') as test_data_file: trustroot_test_data = test_data_file.read().decode('utf-8') diff --git a/openid/test/test_yadis_discover.py b/openid/test/test_yadis_discover.py index eec94c85..52baa14c 100644 --- a/openid/test/test_yadis_discover.py +++ b/openid/test/test_yadis_discover.py @@ -41,7 +41,7 @@ def mkResponse(data): status = int(status_mo.group(1)) return fetchers.HTTPResponse(status=status, headers=headers, - body=body) + body=body.encode('utf-8')) class TestFetcher(object): From cc2feb85c08738378d770b6c0f5057ba5430bb98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 9 May 2018 11:35:29 +0200 Subject: [PATCH 084/151] Update urinorm for python3 --- openid/test/data/trustroot.txt | 3 +-- openid/test/test_trustroot.py | 14 ++++++++++++++ openid/urinorm.py | 5 ++++- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/openid/test/data/trustroot.txt b/openid/test/data/trustroot.txt index f46ec088..81ebfc90 100644 --- a/openid/test/data/trustroot.txt +++ b/openid/test/data/trustroot.txt @@ -44,7 +44,7 @@ http://www.schtuffcom/ http://it/ ---------------------------------------- -21: Sane +20: Sane ---------------------------------------- http://*.schtuff.com./ http://*.schtuff.com/ @@ -64,7 +64,6 @@ https://foo.com/ http://kink.fm/should/be/sane http://beta.lingu.no/ http://goathack.livejournal.org:8020/openid/login.bml -http://*.example.com:80:90/ http://π.pi.example.com/ http://lambda.example.com/Λ diff --git a/openid/test/test_trustroot.py b/openid/test/test_trustroot.py index f90dd90e..c0511ec5 100644 --- a/openid/test/test_trustroot.py +++ b/openid/test/test_trustroot.py @@ -3,6 +3,8 @@ import os import unittest +import six + from openid.server.trustroot import TrustRoot with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data', 'trustroot.txt'), 'rb') as test_data_file: @@ -23,6 +25,18 @@ def test(self): else: assert tr is None, tr + @unittest.skipUnless(six.PY2, "Test for python 2 only") + def test_double_port_py2(self): + # Python 2 urlparse silently drops the ':90' port + trust_root = TrustRoot.parse('http://*.example.com:80:90/') + self.assertTrue(trust_root.isSane()) + self.assertEqual(trust_root.buildDiscoveryURL(), 'http://www.example.com/') + + @unittest.skipUnless(six.PY3, "Test for python 3 only") + def test_double_port_py3(self): + # Python 3 urllib.parse complains about invalid port + self.assertIsNone(TrustRoot.parse('http://*.example.com:80:90/')) + class MatchTest(unittest.TestCase): diff --git a/openid/urinorm.py b/openid/urinorm.py index f212ad89..6a5a5883 100644 --- a/openid/urinorm.py +++ b/openid/urinorm.py @@ -97,7 +97,10 @@ def urinorm(uri): raise ValueError('Invalid hostname {!r}: {}'.format(hostname, error)) _check_disallowed_characters(hostname, 'hostname') - port = split_uri.port + try: + port = split_uri.port + except ValueError as error: + raise ValueError('Invalid port in {!r}: {}'.format(split_uri.netloc, error)) if port is None: port = '' elif (scheme == 'http' and port == 80) or (scheme == 'https' and port == 443): From b34eb847afd905d1f4e3332da9647d034df8df0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 9 May 2018 15:05:51 +0200 Subject: [PATCH 085/151] Fix exception handling --- openid/consumer/consumer.py | 11 +++++------ openid/consumer/discover.py | 3 ++- openid/extensions/sreg.py | 2 +- openid/oidutil.py | 2 +- openid/server/server.py | 2 +- openid/store/sqlstore.py | 2 +- openid/test/test_consumer.py | 5 +++-- openid/test/test_fetchers.py | 3 ++- 8 files changed, 16 insertions(+), 14 deletions(-) diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index c08dc62c..65c9fe6a 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -341,8 +341,7 @@ def begin(self, user_url, anonymous=False): try: service = disco.getNextService(self._discover) except fetchers.HTTPFetchingError as why: - raise DiscoveryFailure( - 'Error fetching XRDS document: %s' % (why[0],), None) + raise DiscoveryFailure('Error fetching XRDS document: %s' % six.text_type(why), None) if service is None: raise DiscoveryFailure( @@ -648,7 +647,7 @@ def _complete_id_res(self, message, endpoint, return_to): try: return self._doIdRes(message, endpoint, return_to) except (ProtocolError, DiscoveryFailure) as why: - return FailureResponse(endpoint, why[0]) + return FailureResponse(endpoint, six.text_type(why)) def _completeInvalid(self, message, endpoint, _): mode = message.getArg(OPENID_NS, 'mode', '') @@ -770,7 +769,7 @@ def _idResCheckNonce(self, message, endpoint): try: timestamp, salt = splitNonce(nonce) except ValueError as why: - raise ProtocolError('Malformed nonce: %s' % (why[0],)) + raise ProtocolError('Malformed nonce: %s' % six.text_type(why)) if (self.store is not None and not self.store.useNonce(server_url, timestamp, salt)): raise ProtocolError('Nonce already used or out of range') @@ -1370,7 +1369,7 @@ def _extractAssociation(self, assoc_response, assoc_session): try: expires_in = int(expires_in_str) except ValueError as why: - raise ProtocolError('Invalid expires_in field: %s' % (why[0],)) + raise ProtocolError('Invalid expires_in field: %s' % six.text_type(why)) # OpenID 1 has funny association session behaviour. if assoc_response.isOpenID1(): @@ -1408,7 +1407,7 @@ def _extractAssociation(self, assoc_response, assoc_session): secret = assoc_session.extractSecret(assoc_response) except ValueError as why: fmt = 'Malformed response for %s session: %s' - raise ProtocolError(fmt % (assoc_session.session_type, why[0])) + raise ProtocolError(fmt % (assoc_session.session_type, six.text_type(why))) return Association.fromExpiresIn( expires_in, assoc_handle, secret, assoc_type) diff --git a/openid/consumer/discover.py b/openid/consumer/discover.py index c3ddb66f..b483bc0b 100644 --- a/openid/consumer/discover.py +++ b/openid/consumer/discover.py @@ -3,6 +3,7 @@ import logging +import six from lxml.etree import LxmlError from lxml.html import document_fromstring from six.moves.urllib.parse import urldefrag, urlparse @@ -304,7 +305,7 @@ def normalizeURL(url): try: normalized = urinorm.urinorm(url) except ValueError as why: - raise DiscoveryFailure('Normalizing identifier: %s' % (why[0],), None) + raise DiscoveryFailure('Normalizing identifier: %s' % six.text_type(why), None) else: return urldefrag(normalized)[0] diff --git a/openid/extensions/sreg.py b/openid/extensions/sreg.py index c557e216..7f9828e9 100644 --- a/openid/extensions/sreg.py +++ b/openid/extensions/sreg.py @@ -163,7 +163,7 @@ def getSRegNS(message): except KeyError as why: # An alias for the string 'sreg' already exists, but it's # defined for something other than simple registration - raise SRegNamespaceError(why[0]) + raise SRegNamespaceError(six.text_type(why)) return sreg_ns_uri diff --git a/openid/oidutil.py b/openid/oidutil.py index b5ddf31d..3ed32de8 100644 --- a/openid/oidutil.py +++ b/openid/oidutil.py @@ -127,7 +127,7 @@ def fromBase64(s): return binascii.a2b_base64(s) except binascii.Error as why: # Convert to a common exception type - raise ValueError(why[0]) + raise ValueError(six.text_type(why)) class Symbol(object): diff --git a/openid/server/server.py b/openid/server/server.py index 2d251d4c..33c97c7a 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -466,7 +466,7 @@ def fromMessage(klass, message, op_endpoint=UNUSED): session = session_class.fromMessage(message) except ValueError as why: raise ProtocolError(message, 'Error parsing %s session: %s' % - (session_class.session_type, why[0])) + (session_class.session_type, six.text_type(why))) if assoc_type not in session.allowed_assoc_types: fmt = 'Session type %s does not support association type %s' diff --git a/openid/store/sqlstore.py b/openid/store/sqlstore.py index 72c8f5ee..8865e4f1 100644 --- a/openid/store/sqlstore.py +++ b/openid/store/sqlstore.py @@ -376,7 +376,7 @@ def useNonce(self, *args, **kwargs): try: return super(SQLiteStore, self).useNonce(*args, **kwargs) except self.exceptions.OperationalError as why: - if re.match('^columns .* are not unique$', why[0]): + if re.match('^columns .* are not unique$', six.text_type(why)): return False else: raise diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index 268a641f..950bad9e 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -3,6 +3,7 @@ import time import unittest +import six from six.moves.urllib.parse import parse_qsl, urlparse from testfixtures import LogCapture, StringComparison @@ -740,7 +741,7 @@ def test(self): message = Message.fromOpenIDArgs(openid_args) with self.assertRaises(ProtocolError) as catch: self.consumer._idResCheckForFields(message) - self.assertTrue(catch.exception[0].startswith('Missing required')) + self.assertTrue(six.text_type(catch.exception).startswith('Missing required')) return test def mkMissingSignedTest(openid_args): @@ -748,7 +749,7 @@ def test(self): message = Message.fromOpenIDArgs(openid_args) with self.assertRaises(ProtocolError) as catch: self.consumer._idResCheckForFields(message) - self.assertTrue(catch.exception[0].endswith('not signed')) + self.assertTrue(six.text_type(catch.exception).endswith('not signed')) return test test_openid1Missing_returnToSig = mkMissingSignedTest( diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index 6501b181..8699e8f6 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -5,6 +5,7 @@ import warnings import responses +import six from mock import Mock, patch, sentinel from six import StringIO from six.moves.BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer @@ -133,7 +134,7 @@ def run_fetcher_tests(server): try: exc_fetchers.append(klass()) except RuntimeError as why: - if why[0].startswith('Cannot find %s library' % (library_name,)): + if six.text_type(why).startswith('Cannot find %s library' % (library_name,)): try: __import__(library_name) except ImportError: From 13735aa26af343ca96e9a94c1bc7786bd7da35b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 9 May 2018 15:38:44 +0200 Subject: [PATCH 086/151] Update prioSort ordering --- openid/yadis/etxrd.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/openid/yadis/etxrd.py b/openid/yadis/etxrd.py index 019b82ef..4eefc0bc 100644 --- a/openid/yadis/etxrd.py +++ b/openid/yadis/etxrd.py @@ -6,6 +6,8 @@ import random from datetime import datetime +from functools import total_ordering +from operator import itemgetter from time import strptime from lxml import etree @@ -190,17 +192,18 @@ def getCanonicalID(iname, xrd_tree): return canonicalID +@total_ordering class _Max(object): """Value that compares greater than any other value. Should only be used as a singleton. Implemented for use as a priority value for when a priority is not specified.""" - def __cmp__(self, other): - if other is self: - return 0 + def __eq__(self, other): + return self is other - return 1 + def __gt__(self, other): + return True Max = _Max() @@ -242,7 +245,7 @@ def prioSort(elements): # elements are load-balanced. random.shuffle(elements) - prio_elems = sorted((getPriority(e), e) for e in elements) + prio_elems = sorted(((getPriority(e), e) for e in elements), key=itemgetter(0)) sorted_elems = [s for (_, s) in prio_elems] return sorted_elems From db406bd9f6f5c01d965e3abdc0e47b77b7166828 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 9 May 2018 16:14:58 +0200 Subject: [PATCH 087/151] Fix print statement --- admin/gettlds.py | 4 ++-- contrib/associate | 8 ++++---- contrib/openid-parse | 4 ++-- contrib/upgrade-store-1.1-to-2.0 | 20 ++++++++++---------- examples/consumer.py | 4 ++-- examples/discover | 30 +++++++++++++++--------------- examples/server.py | 4 ++-- openid/test/test_accept.py | 4 ++-- openid/test/test_fetchers.py | 2 +- openid/test/test_oidutil.py | 15 +++++++++++---- openid/test/test_storetest.py | 6 ++---- openid/yadis/xrires.py | 2 +- 12 files changed, 54 insertions(+), 49 deletions(-) diff --git a/admin/gettlds.py b/admin/gettlds.py index c4892769..4b0c4033 100644 --- a/admin/gettlds.py +++ b/admin/gettlds.py @@ -38,10 +38,10 @@ tld = input_line.strip().lower() new_output_line = output_line + prefix + tld if len(new_output_line) > 60: - print output_line + line_suffix + print(output_line + line_suffix) output_line = line_prefix + tld else: output_line = new_output_line prefix = separator -print output_line + suffix +print(output_line + suffix) diff --git a/contrib/associate b/contrib/associate index d84cfb31..ca7e5884 100755 --- a/contrib/associate +++ b/contrib/associate @@ -27,9 +27,9 @@ def verboseAssociation(assoc): def main(): if not sys.argv[1:]: - print "Usage: %s ENDPOINT_URL..." % (sys.argv[0],) + print("Usage: %s ENDPOINT_URL..." % (sys.argv[0],)) for endpoint_url in sys.argv[1:]: - print "Associating with", endpoint_url + print("Associating with", endpoint_url) # This makes it clear why j3h made AssociationManager when we # did the ruby port. We can't invoke requestAssociation @@ -40,9 +40,9 @@ def main(): c = consumer.GenericConsumer(store) auth_req = c.begin(endpoint) if auth_req.assoc: - print verboseAssociation(auth_req.assoc) + print(verboseAssociation(auth_req.assoc)) else: - print " ...no association." + print(" ...no association.") if __name__ == '__main__': diff --git a/contrib/openid-parse b/contrib/openid-parse index c4e437a6..b227149d 100644 --- a/contrib/openid-parse +++ b/contrib/openid-parse @@ -57,10 +57,10 @@ def main(): output.append('at %s:\n%s' % (where, openidFromQuery(query))) if output: - print '\n\n'.join(output) + print('\n\n'.join(output)) elif errors: for err in errors: - print err + print(err) def queryFromURL(url): diff --git a/contrib/upgrade-store-1.1-to-2.0 b/contrib/upgrade-store-1.1-to-2.0 index 48a62552..2e09e0b1 100644 --- a/contrib/upgrade-store-1.1-to-2.0 +++ b/contrib/upgrade-store-1.1-to-2.0 @@ -28,8 +28,8 @@ def askForPassword(): def askForConfirmation(dbname, tablename): - print """The table %s from the database %s will be dropped, and - an empty table with the new nonce table schema will replace it.""" % (tablename, dbname) + print("""The table %s from the database %s will be dropped, and + an empty table with the new nonce table schema will replace it.""" % (tablename, dbname)) return raw_input("Continue? ").lower().strip().startswith('y') @@ -109,12 +109,12 @@ def main(argv=None): try: from pysqlite2 import dbapi2 as sqlite except ImportError: - print "You must have pysqlite2 installed in your PYTHONPATH." + print("You must have pysqlite2 installed in your PYTHONPATH.") return 1 try: db_conn = sqlite.connect(options.sqlite_db_name) except Exception as e: - print "Could not connect to SQLite database:", six.text_type(e) + print("Could not connect to SQLite database:", six.text_type(e)) return 1 if askForConfirmation(options.sqlite_db_name, options.tablename): @@ -122,13 +122,13 @@ def main(argv=None): if options.postgres_db_name: if not options.username: - print "A username is required to open a PostgreSQL Database." + print("A username is required to open a PostgreSQL Database.") return 1 password = askForPassword() try: import psycopg except ImportError: - print "You need psycopg installed to update a postgres DB." + print("You need psycopg installed to update a postgres DB.") return 1 try: @@ -137,7 +137,7 @@ def main(argv=None): host=options.db_host, password=password) except Exception as e: - print "Could not connect to PostgreSQL database:", six.text_type(e) + print("Could not connect to PostgreSQL database:", six.text_type(e)) return 1 if askForConfirmation(options.postgres_db_name, options.tablename): @@ -145,20 +145,20 @@ def main(argv=None): if options.mysql_db_name: if not options.username: - print "A username is required to open a MySQL Database." + print("A username is required to open a MySQL Database.") return 1 password = askForPassword() try: import MySQLdb except ImportError: - print "You must have MySQLdb installed to update a MySQL DB." + print("You must have MySQLdb installed to update a MySQL DB.") return 1 try: db_conn = MySQLdb.connect(options.db_host, options.username, password, options.mysql_db_name) except Exception as e: - print "Could not connect to MySQL database:", six.text_type(e) + print("Could not connect to MySQL database:", six.text_type(e)) return 1 if askForConfirmation(options.mysql_db_name, options.tablename): diff --git a/examples/consumer.py b/examples/consumer.py index 19b29bdf..9ef9e70a 100644 --- a/examples/consumer.py +++ b/examples/consumer.py @@ -473,8 +473,8 @@ def main(host, port, data_path, weak_ssl=False): addr = (host, port) server = OpenIDHTTPServer(store, addr, OpenIDRequestHandler) - print 'Server running at:' - print server.base_url + print('Server running at:') + print(server.base_url) server.serve_forever() diff --git a/examples/discover b/examples/discover index b334d94e..99ae8abf 100644 --- a/examples/discover +++ b/examples/discover @@ -11,40 +11,40 @@ names = [["server_url", "Server URL "], def show_services(user_input, normalized, services): - print " Claimed identifier:", normalized + print(" Claimed identifier:", normalized) if services: - print " Discovered OpenID services:" + print(" Discovered OpenID services:") for n, service in enumerate(services): - print " %s." % (n,) + print(" %s." % (n,)) for attr, name in names: val = getattr(service, attr, None) if val is not None: - print " %s: %s" % (name, val) + print(" %s: %s" % (name, val)) - print " Type URIs:" + print(" Type URIs:") for type_uri in service.type_uris: - print " *", type_uri + print(" *", type_uri) - print + print() else: - print " No OpenID services found" - print + print(" No OpenID services found") + print() if __name__ == "__main__": import sys for user_input in sys.argv[1:]: - print "=" * 50 - print "Running discovery on", user_input + print("=" * 50) + print("Running discovery on", user_input) try: normalized, services = discover(user_input) except DiscoveryFailure as why: - print "Discovery failed:", why - print + print("Discovery failed:", why) + print() except HTTPFetchingError as why: - print "HTTP request failed:", why - print + print("HTTP request failed:", why) + print() else: show_services(user_input, normalized, services) diff --git a/examples/server.py b/examples/server.py index 934909eb..59b97964 100644 --- a/examples/server.py +++ b/examples/server.py @@ -685,8 +685,8 @@ def main(host, port, data_path): httpserver.setOpenIDServer(oidserver) - print 'Server running at:' - print httpserver.base_url + print('Server running at:') + print(httpserver.base_url) httpserver.serve_forever() diff --git a/openid/test/test_accept.py b/openid/test/test_accept.py index 7f645d8a..8acea20c 100644 --- a/openid/test/test_accept.py +++ b/openid/test/test_accept.py @@ -101,7 +101,7 @@ def runTest(self): try: available = parseAvailable(avail_data) except Exception: - print 'On line', lno + print('On line', lno) raise lno, exp_data = data['expected'] @@ -109,7 +109,7 @@ def runTest(self): try: expected = parseExpected(exp_data) except Exception: - print 'On line', lno + print('On line', lno) raise accepted = accept.parseAcceptHeader(accept_header) diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index 8699e8f6..4be8b26f 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -101,7 +101,7 @@ def plain(path, code): try: actual = fetcher.fetch(fetch_url) except Exception: - print fetcher, fetch_url + print(fetcher, fetch_url) raise else: assertResponse(expected, actual) diff --git a/openid/test/test_oidutil.py b/openid/test/test_oidutil.py index daf830ce..8aacb583 100644 --- a/openid/test/test_oidutil.py +++ b/openid/test/test_oidutil.py @@ -23,11 +23,10 @@ def test_base64(self): allowed_d = {} for c in allowed_s: allowed_d[c] = None - isAllowed = allowed_d.has_key def checkEncoded(s): for c in s: - assert isAllowed(c), s + self.assertIn(c, allowed_d, msg=s) cases = [ b'', @@ -35,8 +34,12 @@ def checkEncoded(s): b'\x00', b'\x01', b'\x00' * 100, - b''.join(chr(i) for i in range(256)), ] + if six.PY2: + cases.append(b''.join(chr(i) for i in range(256))) + else: + assert six.PY3 + cases.append(bytes(i for i in range(256))) for s in cases: b64 = oidutil.toBase64(s) @@ -47,7 +50,11 @@ def checkEncoded(s): # Randomized test for _ in range(50): n = random.randrange(2048) - s = b''.join(chr(random.randrange(256)) for i in range(n)) + if six.PY2: + s = b''.join(chr(random.randrange(256)) for i in range(n)) + else: + assert six.PY3 + s = bytes(random.randrange(256) for i in range(n)) b64 = oidutil.toBase64(s) checkEncoded(b64) s_prime = oidutil.fromBase64(b64) diff --git a/openid/test/test_storetest.py b/openid/test/test_storetest.py index ef377d11..f315a3c5 100644 --- a/openid/test/test_storetest.py +++ b/openid/test/test_storetest.py @@ -58,8 +58,7 @@ def checkRetrieve(url, handle=None, expected=None): assert retrieved_assoc == expected, (retrieved_assoc, expected) if expected is not None: if retrieved_assoc is expected: - print ('Unexpected: retrieved a reference to the expected ' - 'value instead of a new object') + print('Unexpected: retrieved a reference to the expected value instead of a new object') assert retrieved_assoc.handle == expected.handle assert retrieved_assoc.secret == expected.secret @@ -284,8 +283,7 @@ def test_mysql(self): conn = MySQLdb.connect(user=db_user, passwd=db_passwd, host=db_host) except MySQLdb.OperationalError as why: if why[0] == 2005: - print ('Skipping MySQL store test (cannot connect ' - 'to test server on host %r)' % (db_host,)) + print('Skipping MySQL store test (cannot connect to test server on host %r)' % db_host) return else: raise diff --git a/openid/yadis/xrires.py b/openid/yadis/xrires.py index 2bc671f2..0eb04d7c 100644 --- a/openid/yadis/xrires.py +++ b/openid/yadis/xrires.py @@ -87,7 +87,7 @@ def query(self, xri, service_types): response = fetchers.fetch(url) if response.status not in (200, 206): # XXX: sucks to fail silently. - # print "response not OK:", response + # print("response not OK:", response) continue et = etxrd.parseXRDS(response.body) canonicalID = etxrd.getCanonicalID(xri, et) From f67b427cdccb8fbd62924709d6b425e86648ba51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 2 May 2018 12:50:41 +0200 Subject: [PATCH 088/151] Set up Python3 support --- .gitignore | 16 +++++++++------- .travis.yml | 3 +++ setup.py | 6 +++++- tox.ini | 4 ++-- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 58aac550..7affb99f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,10 +1,12 @@ -*~ *.pyc -*.swp -.tox -# Created in tests -/.coverage +__pycache__ +# Distribution +/dist +/*.egg-info +# Tests +/.tox /.eggs -/htmlcov -/python_openid.egg-info /sstore +# Coverage +/.coverage +/htmlcov diff --git a/.travis.yml b/.travis.yml index fe0b9baa..35cfdd7c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,6 +4,9 @@ sudo: false python: - "2.7" + - "3.4" + - "3.5" + - "3.6" - "pypy" addons: diff --git a/setup.py b/setup.py index a333c44f..a230b668 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,10 @@ 'Programming Language :: Python', 'Programming Language :: Python :: 2', 'Programming Language :: Python :: 2.7', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.4', + 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', 'Topic :: Internet :: WWW/HTTP', 'Topic :: Internet :: WWW/HTTP :: Dynamic Content :: CGI Tools/Libraries', 'Topic :: Software Development :: Libraries :: Python Modules', @@ -58,7 +62,7 @@ 'openid.extensions', 'openid.extensions.draft', ], - python_requires='~=2.7', + python_requires='>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*', install_requires=INSTALL_REQUIRES, extras_require=EXTRAS_REQUIRE, # license specified by classifier. diff --git a/tox.ini b/tox.ini index d75be233..a2ccdec7 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] envlist = quality - py27-{openid,djopenid,httplib2,pycurl,requests} + py{27,34,35,36}-{openid,djopenid,httplib2,pycurl,requests} pypy-{openid,djopenid,httplib2,pycurl,requests} # tox-travis specials @@ -19,7 +19,7 @@ extras = httplib2: httplib2 pycurl: pycurl requests: requests -passenv = CI TRAVIS TRAVIS_* +passenv = CI TRAVIS TRAVIS_* PYTHONWARNINGS setenv = DJANGO_SETTINGS_MODULE = djopenid.settings PYTHONPATH = {toxinidir}/examples:{env:PYTHONPATH:} From 083edc87fc0e1f97907e5169299cb6166a29dc78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 11 May 2018 11:04:32 +0200 Subject: [PATCH 089/151] Drop condition for SHA256 availability --- openid/association.py | 37 +++++++++++-------------------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/openid/association.py b/openid/association.py index ae566e7a..5baf06f2 100644 --- a/openid/association.py +++ b/openid/association.py @@ -17,8 +17,7 @@ @var default_negotiator: A C{L{SessionNegotiator}} that allows all association types that are specified by the OpenID specification. It prefers to use HMAC-SHA1/DH-SHA1, if it's - available. If HMAC-SHA256 is not supported by your Python runtime, - HMAC-SHA256 and DH-SHA256 will not be available. + available. @var encrypted_negotiator: A C{L{SessionNegotiator}} that does not support C{'no-encryption'} associations. It prefers @@ -48,31 +47,17 @@ 'HMAC-SHA256', ] -if hasattr(cryptutil, 'hmacSha256'): - supported_association_types = list(all_association_types) - - default_association_order = [ - ('HMAC-SHA1', 'DH-SHA1'), - ('HMAC-SHA1', 'no-encryption'), - ('HMAC-SHA256', 'DH-SHA256'), - ('HMAC-SHA256', 'no-encryption'), - ] - - only_encrypted_association_order = [ - ('HMAC-SHA1', 'DH-SHA1'), - ('HMAC-SHA256', 'DH-SHA256'), - ] -else: - supported_association_types = ['HMAC-SHA1'] - - default_association_order = [ - ('HMAC-SHA1', 'DH-SHA1'), - ('HMAC-SHA1', 'no-encryption'), - ] +default_association_order = [ + ('HMAC-SHA1', 'DH-SHA1'), + ('HMAC-SHA1', 'no-encryption'), + ('HMAC-SHA256', 'DH-SHA256'), + ('HMAC-SHA256', 'no-encryption'), +] - only_encrypted_association_order = [ - ('HMAC-SHA1', 'DH-SHA1'), - ] +only_encrypted_association_order = [ + ('HMAC-SHA1', 'DH-SHA1'), + ('HMAC-SHA256', 'DH-SHA256'), +] def getSessionTypes(assoc_type): From 04d9c38c2c4f3664805bc627724087e1cad3242d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 16 May 2018 08:34:55 +0200 Subject: [PATCH 090/151] Fix future imports in examples --- examples/consumer.py | 2 -- examples/server.py | 3 --- 2 files changed, 5 deletions(-) diff --git a/examples/consumer.py b/examples/consumer.py index 9ef9e70a..662079e7 100644 --- a/examples/consumer.py +++ b/examples/consumer.py @@ -6,8 +6,6 @@ and using the Python OpenID library. You can then move on to more robust examples, and integrating OpenID into your application. """ -__copyright__ = 'Copyright 2005-2008, Janrain, Inc.' - from __future__ import unicode_literals import cgi diff --git a/examples/server.py b/examples/server.py index 59b97964..b2909e7c 100644 --- a/examples/server.py +++ b/examples/server.py @@ -1,7 +1,4 @@ #!/usr/bin/env python - -__copyright__ = 'Copyright 2005-2008, Janrain, Inc.' - from __future__ import unicode_literals import cgi From 721a664e66f9ec0f860ea010b9f772e31ab0e9dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 15 May 2018 14:08:24 +0200 Subject: [PATCH 091/151] Prefer stronger association methods --- openid/association.py | 10 +++++----- openid/consumer/consumer.py | 2 +- openid/test/test_consumer.py | 35 +++++++++++++++++------------------ 3 files changed, 23 insertions(+), 24 deletions(-) diff --git a/openid/association.py b/openid/association.py index 5baf06f2..de607f4c 100644 --- a/openid/association.py +++ b/openid/association.py @@ -43,28 +43,28 @@ all_association_types = [ - 'HMAC-SHA1', 'HMAC-SHA256', + 'HMAC-SHA1', ] default_association_order = [ - ('HMAC-SHA1', 'DH-SHA1'), - ('HMAC-SHA1', 'no-encryption'), ('HMAC-SHA256', 'DH-SHA256'), ('HMAC-SHA256', 'no-encryption'), + ('HMAC-SHA1', 'DH-SHA1'), + ('HMAC-SHA1', 'no-encryption'), ] only_encrypted_association_order = [ - ('HMAC-SHA1', 'DH-SHA1'), ('HMAC-SHA256', 'DH-SHA256'), + ('HMAC-SHA1', 'DH-SHA1'), ] def getSessionTypes(assoc_type): """Return the allowed session types for a given association type""" assoc_to_session = { - 'HMAC-SHA1': ['DH-SHA1', 'no-encryption'], 'HMAC-SHA256': ['DH-SHA256', 'no-encryption'], + 'HMAC-SHA1': ['DH-SHA1', 'no-encryption'], } return assoc_to_session.get(assoc_type, []) diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index 65c9fe6a..5508c45c 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -579,8 +579,8 @@ class GenericConsumer(object): openid1_return_to_identifier_name = 'openid1_claimed_id' session_types = { - 'DH-SHA1': DiffieHellmanSHA1ConsumerSession, 'DH-SHA256': DiffieHellmanSHA256ConsumerSession, + 'DH-SHA1': DiffieHellmanSHA1ConsumerSession, 'no-encryption': PlainTextConsumerSession, } diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index 950bad9e..1ac1e16e 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -2,6 +2,7 @@ import time import unittest +from functools import partial import six from six.moves.urllib.parse import parse_qsl, urlparse @@ -18,7 +19,7 @@ from openid.extension import Extension from openid.fetchers import HTTPFetchingError, HTTPResponse from openid.message import BARE_NS, IDENTIFIER_SELECT, OPENID1_NS, OPENID2_NS, OPENID_NS, Message -from openid.server.server import DiffieHellmanSHA1ServerSession, PlainTextServerSession +from openid.server.server import DiffieHellmanSHA256ServerSession from openid.store import memstore from openid.store.nonce import mkNonce, split as splitNonce from openid.yadis.discover import DiscoveryFailure @@ -26,8 +27,8 @@ assocs = [ # (secret, handle) - (b'another 20-byte key.', 'Snarky'), - (b'\x00' * 20, 'Zeros'), + (b'another 32-byte very secret key.', 'Snarky'), + (b'\x00' * 32, 'Zeros'), ] @@ -51,22 +52,18 @@ def associate(qs, assoc_secret, assoc_handle): secret and handle.""" q = parseQuery(qs) assert q['openid.mode'] == 'associate' - assert q['openid.assoc_type'] == 'HMAC-SHA1' + assert q['openid.assoc_type'] == 'HMAC-SHA256' reply_dict = { - 'assoc_type': 'HMAC-SHA1', + 'assoc_type': 'HMAC-SHA256', 'assoc_handle': assoc_handle, 'expires_in': '600', } - if q.get('openid.session_type') == 'DH-SHA1': - assert len(q) == 6 or len(q) == 4 - message = Message.fromPostArgs(q) - session = DiffieHellmanSHA1ServerSession.fromMessage(message) - reply_dict['session_type'] = 'DH-SHA1' - else: - assert len(q) == 2 - session = PlainTextServerSession.fromQuery(q) - + assert q.get('openid.session_type') == 'DH-SHA256' + assert len(q) == 6 or len(q) == 4 + message = Message.fromPostArgs(q) + session = DiffieHellmanSHA256ServerSession.fromMessage(message) + reply_dict['session_type'] = 'DH-SHA256' reply_dict.update(session.answer(assoc_secret)) return kvform.dictToKV(reply_dict) @@ -112,7 +109,7 @@ def fetch(self, url, body=None, headers=None): except ValueError: pass # fall through else: - assert body.find('DH-SHA1') != -1 + assert body.find('DH-SHA256') != -1 response = associate( body, self.assoc_secret, self.assoc_handle) self.num_assocs += 1 @@ -121,16 +118,18 @@ def fetch(self, url, body=None, headers=None): return self.response(url, 404, 'Not found') -def makeFastConsumerSession(): +def makeFastConsumerSession(consumer_session_cls=DiffieHellmanSHA256ConsumerSession): """ Create custom DH object so tests run quickly. """ dh = DiffieHellman(100389557, 2) - return DiffieHellmanSHA1ConsumerSession(dh) + return consumer_session_cls(dh) def setConsumerSession(con): - con.session_types = {'DH-SHA1': makeFastConsumerSession} + con.session_types = { + 'DH-SHA256': makeFastConsumerSession, + 'DH-SHA1': partial(makeFastConsumerSession, consumer_session_cls=DiffieHellmanSHA1ConsumerSession)} def _test_success(server_url, user_url, delegate_url, links, immediate=False): From 2385f7b5f2318e890b34357b42611b6a49d04767 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 11 May 2018 15:02:24 +0200 Subject: [PATCH 092/151] Split function for nonce salt generation --- openid/store/nonce.py | 21 +++++++++++++++++---- openid/test/test_nonce.py | 18 +++++++++++++++++- 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/openid/store/nonce.py b/openid/store/nonce.py index f00f4e1c..8c9353ce 100644 --- a/openid/store/nonce.py +++ b/openid/store/nonce.py @@ -1,10 +1,11 @@ from __future__ import unicode_literals +import itertools +import random import string from calendar import timegm from time import gmtime, strftime, strptime, time -from openid import cryptutil from openid.oidutil import string_to_text __all__ = [ @@ -14,7 +15,7 @@ ] -NONCE_CHARS = (string.ascii_letters + string.digits).encode('utf-8') +NONCE_CHARS = string.ascii_letters + string.digits # Keep nonces for five hours (allow five hours for the combination of # request time and clock skew). This is probably way more than is @@ -84,6 +85,19 @@ def checkTimestamp(nonce_string, allowed_skew=SKEW, now=None): return past <= stamp <= future +def make_nonce_salt(length=6): + """ + Generate and return a nonce salt. + + @param length: Length of the generated string. + @type length: int + @rtype: six.text_type + """ + sys_random = random.SystemRandom() + random_chars = itertools.starmap(sys_random.choice, itertools.repeat((NONCE_CHARS, ), length)) + return ''.join(random_chars) + + def mkNonce(when=None): """Generate a nonce with the current timestamp @@ -96,11 +110,10 @@ def mkNonce(when=None): @see: time """ - salt = cryptutil.randomString(6, NONCE_CHARS).decode('utf-8') if when is None: t = gmtime() else: t = gmtime(when) time_str = strftime(time_fmt, t) - return time_str + salt + return time_str + make_nonce_salt() diff --git a/openid/test/test_nonce.py b/openid/test/test_nonce.py index d2746115..1ba85dd2 100644 --- a/openid/test/test_nonce.py +++ b/openid/test/test_nonce.py @@ -3,11 +3,27 @@ import re import unittest -from openid.store.nonce import checkTimestamp, mkNonce, split as splitNonce +import six + +from openid.store.nonce import checkTimestamp, make_nonce_salt, mkNonce, split as splitNonce nonce_re = re.compile(r'\A\d{4}-\d\d-\d\dT\d\d:\d\d:\d\dZ') +class TestMakeNonceSalt(unittest.TestCase): + """Test `make_nonce_salt` function.""" + + def test_default(self): + salt = make_nonce_salt() + self.assertIsInstance(salt, six.text_type) + self.assertEqual(len(salt), 6) + + def test_custom_length(self): + salt = make_nonce_salt(32) + self.assertIsInstance(salt, six.text_type) + self.assertEqual(len(salt), 32) + + class NonceTest(unittest.TestCase): def test_mkNonce(self): nonce = mkNonce() From 4555f1d2db53893f30e5247150baf50d0441d9b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 11 May 2018 15:15:40 +0200 Subject: [PATCH 093/151] Drop randomString function --- examples/consumer.py | 5 +++-- openid/cryptutil.py | 20 -------------------- openid/test/test_consumer.py | 3 ++- openid/test/test_storetest.py | 9 +++------ 4 files changed, 8 insertions(+), 29 deletions(-) diff --git a/examples/consumer.py b/examples/consumer.py index 662079e7..970865ab 100644 --- a/examples/consumer.py +++ b/examples/consumer.py @@ -39,11 +39,11 @@ def quoteattr(s): else: del openid from openid.consumer import consumer - from openid.cryptutil import randomString from openid.extensions import pape, sreg from openid.fetchers import Urllib2Fetcher, setDefaultFetcher from openid.oidutil import appendArgs from openid.store import filestore, memstore + from openid.store.nonce import make_nonce_salt # Used with an OpenID provider affiliate program. @@ -100,7 +100,8 @@ def getSession(self): # If a session id was not set, create a new one if sid is None: - sid = randomString(16, '0123456789abcdef') + # Pure pragmatism: Use function for nonce salt to generate session ID. + sid = make_nonce_salt(16) session = None else: session = self.server.sessions.get(sid) diff --git a/openid/cryptutil.py b/openid/cryptutil.py index 858de483..0179169e 100644 --- a/openid/cryptutil.py +++ b/openid/cryptutil.py @@ -16,8 +16,6 @@ import os import random -import six - from openid.oidutil import fromBase64, string_to_text, toBase64 __all__ = [ @@ -27,7 +25,6 @@ 'hmacSha256', 'longToBase64', 'longToBinary', - 'randomString', 'randrange', 'sha1', 'sha256', @@ -217,23 +214,6 @@ def base64ToLong(s): return binaryToLong(fromBase64(s)) -def randomString(length, chrs=None): - """Produce a string of length random bytes, chosen from chrs. - - @type chrs: six.binary_type - @rtype: six.binary_type - """ - if chrs is None: - return getBytes(length) - else: - n = len(chrs) - random_chars = [chrs[randrange(n)] for _ in range(length)] - if six.PY2: - return b''.join(random_chars) - else: - return six.binary_type(random_chars) - - def const_eq(s1, s2): if len(s1) != len(s2): return False diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index 1ac1e16e..c9cd8f5a 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import os import time import unittest from functools import partial @@ -1779,7 +1780,7 @@ def setUp(self): # base64(btwoc(g ^ xb mod p)) self.dh_server_public = cryptutil.longToBase64(self.server_dh.public) - self.secret = cryptutil.randomString(self.session_cls.secret_size) + self.secret = os.urandom(self.session_cls.secret_size) self.enc_mac_key = oidutil.toBase64( self.server_dh.xorSecret(self.consumer_dh.public, diff --git a/openid/test/test_storetest.py b/openid/test/test_storetest.py index f315a3c5..61a11afe 100644 --- a/openid/test/test_storetest.py +++ b/openid/test/test_storetest.py @@ -1,6 +1,7 @@ """Test `openid.store` module.""" from __future__ import unicode_literals +import itertools import os import random import socket @@ -9,7 +10,6 @@ import unittest from openid.association import Association -from openid.cryptutil import randomString from openid.store.nonce import mkNonce, split db_host = 'dbtest' @@ -22,10 +22,7 @@ def generateHandle(n): - return randomString(n, allowed_handle.encode('utf-8')) - - -generateSecret = randomString + return ''.join(itertools.starmap(random.choice, itertools.repeat((allowed_handle, ), n))) def getTmpDbName(): @@ -49,7 +46,7 @@ def testStore(store): server_url = 'http://www.myopenid.com/openid' def genAssoc(issued, lifetime=600): - sec = generateSecret(20) + sec = os.urandom(20) hdl = generateHandle(128) return Association(hdl, sec, now + issued, lifetime, 'HMAC-SHA1') From 6b15a022316b9653cea6d2f636f977eadc89c318 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 16 May 2018 11:38:09 +0200 Subject: [PATCH 094/151] Drop getBytes function --- openid/cryptutil.py | 32 +------------------------------- openid/server/server.py | 5 +++-- openid/test/test_cryptutil.py | 6 ------ 3 files changed, 4 insertions(+), 39 deletions(-) diff --git a/openid/cryptutil.py b/openid/cryptutil.py index 0179169e..3fddee6a 100644 --- a/openid/cryptutil.py +++ b/openid/cryptutil.py @@ -128,36 +128,6 @@ def binaryToLong(s): return bytes_to_long(s) -# A cryptographically safe source of random bytes -try: - getBytes = os.urandom -except AttributeError: - try: - from Crypto.Util.randpool import RandomPool - except ImportError: - # Fall back on /dev/urandom, if present. It would be nice to - # have Windows equivalent here, but for now, require pycrypto - # on Windows. - try: - _urandom = open('/dev/urandom', 'rb') - except IOError: - raise ImportError('No adequate source of randomness found!') - else: - def getBytes(n): - bytes = [] - while n: - chunk = _urandom.read(n) - n -= len(chunk) - bytes.append(chunk) - assert n >= 0 - return ''.join(bytes) - else: - _pool = RandomPool() - - def getBytes(n, pool=_pool): - if pool.entropy < n: - pool.randomize() - return pool.get_bytes(n) # A randrange function that works for longs try: @@ -197,7 +167,7 @@ def randrange(start, stop=None, step=1): _duplicate_cache[r] = (duplicate, nbytes) while True: - bytes = '\x00' + getBytes(nbytes) + bytes = '\x00' + os.urandom(nbytes) n = binaryToLong(bytes) # Keep looping if this value is in the low duplicated range if n >= duplicate: diff --git a/openid/server/server.py b/openid/server/server.py index 33c97c7a..dfc7fbab 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -118,6 +118,7 @@ from __future__ import unicode_literals import logging +import os import time import warnings from copy import deepcopy @@ -1240,8 +1241,8 @@ def createAssociation(self, dumb=True, assoc_type='HMAC-SHA1'): """ assoc_type = string_to_text(assoc_type, "Binary values for assoc_type are deprecated. Use text input instead.") - secret = cryptutil.getBytes(getSecretSize(assoc_type)) - uniq = oidutil.toBase64(cryptutil.getBytes(4)) + secret = os.urandom(getSecretSize(assoc_type)) + uniq = oidutil.toBase64(os.urandom(4)) handle = '{%s}{%x}{%s}' % (assoc_type, int(time.time()), uniq) assoc = Association.fromExpiresIn( diff --git a/openid/test/test_cryptutil.py b/openid/test/test_cryptutil.py index 248c1005..5d5cc5fd 100644 --- a/openid/test/test_cryptutil.py +++ b/openid/test/test_cryptutil.py @@ -26,12 +26,6 @@ def test_cryptrand(self): # It's possible, but HIGHLY unlikely that a correct implementation # will fail by returning the same number twice - s = cryptutil.getBytes(32) - t = cryptutil.getBytes(32) - assert len(s) == 32 - assert len(t) == 32 - assert s != t - a = cryptutil.randrange(2 ** 128) b = cryptutil.randrange(2 ** 128) assert isinstance(a, long_int) From 6f6b6972036adc5d1937fbb4f63c0b279d630cb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 11 May 2018 13:54:14 +0200 Subject: [PATCH 095/151] Use cryptography for signature comparison --- openid/association.py | 3 ++- openid/cryptutil.py | 11 ----------- setup.py | 1 + 3 files changed, 3 insertions(+), 12 deletions(-) diff --git a/openid/association.py b/openid/association.py index de607f4c..ca063bda 100644 --- a/openid/association.py +++ b/openid/association.py @@ -28,6 +28,7 @@ import time import six +from cryptography.hazmat.primitives.constant_time import bytes_eq from openid import cryptutil, kvform, oidutil from openid.message import OPENID_NS @@ -513,7 +514,7 @@ def checkMessageSignature(self, message): if not message_sig: raise ValueError("%s has no sig." % (message,)) calculated_sig = self.getMessageSignature(message) - return cryptutil.const_eq(calculated_sig, message_sig) + return bytes_eq(calculated_sig.encode('utf-8'), message_sig.encode('utf-8')) def _makePairs(self, message): signed = message.getArg(OPENID_NS, 'signed') diff --git a/openid/cryptutil.py b/openid/cryptutil.py index 3fddee6a..86c3e869 100644 --- a/openid/cryptutil.py +++ b/openid/cryptutil.py @@ -182,14 +182,3 @@ def longToBase64(l): def base64ToLong(s): return binaryToLong(fromBase64(s)) - - -def const_eq(s1, s2): - if len(s1) != len(s2): - return False - - result = True - for i in range(len(s1)): - result = result and (s1[i] == s2[i]) - - return result diff --git a/setup.py b/setup.py index a230b668..52bca806 100644 --- a/setup.py +++ b/setup.py @@ -13,6 +13,7 @@ VERSION = __import__('openid').__version__ INSTALL_REQUIRES = [ 'six', + 'cryptography', 'lxml;platform_python_implementation=="CPython"', 'lxml <4.0;platform_python_implementation=="PyPy"', ] From dae19740ac5c8887145ecee9d7c955390b431ea0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 22 May 2018 17:49:58 +0200 Subject: [PATCH 096/151] Refactor bytes <-> int conversions --- openid/cryptutil.py | 83 +++++++++++++++++------------------ openid/dh.py | 2 +- openid/test/test_cryptutil.py | 21 +++++++++ 3 files changed, 62 insertions(+), 44 deletions(-) diff --git a/openid/cryptutil.py b/openid/cryptutil.py index 86c3e869..c7de35f5 100644 --- a/openid/cryptutil.py +++ b/openid/cryptutil.py @@ -3,18 +3,18 @@ Other configurations will need a quality source of random bytes and access to a function that will convert binary strings to long -integers. This module will work with the Python Cryptography Toolkit -(pycrypto) if it is present. pycrypto can be found with a search -engine, but is currently found at: - -http://www.amk.ca/python/code/crypto +integers. """ from __future__ import unicode_literals +import codecs import hashlib import hmac import os import random +import warnings + +import six from openid.oidutil import fromBase64, string_to_text, toBase64 @@ -28,6 +28,8 @@ 'randrange', 'sha1', 'sha256', + 'int_to_bytes', + 'bytes_to_int', ] @@ -85,48 +87,43 @@ def sha256(s): return sha256_module.new(s).digest() -try: - from Crypto.Util.number import long_to_bytes, bytes_to_long -except ImportError: - import pickle +def bytes_to_int(value): + """ + Convert byte string to integer. - def longToBinary(value): - if value == 0: - return b'\x00' + @type value: six.binary_type + @rtype: Union[six.integer_types] + """ + return int(codecs.encode(value, 'hex'), 16) - return pickle.encode_long(value)[::-1] - def binaryToLong(s): - return pickle.decode_long(s[::-1]) -else: - # We have pycrypto +def int_to_bytes(value): + """ + Convert integer to byte string. - def longToBinary(value): - if value < 0: - raise ValueError('This function only supports positive integers') + @type value: Union[six.integer_types] + @rtype: six.binary_type + """ + hex_value = '{:x}'.format(value) + if len(hex_value) % 2: + hex_value = '0' + hex_value + array = bytearray.fromhex(hex_value) + # First bit must be zero. If it isn't, the bytes must be prepended by zero byte. + # See http://openid.net/specs/openid-authentication-2_0.html#btwoc for details. + if array[0] > 127: + array = bytearray([0]) + array + return six.binary_type(array) - output = long_to_bytes(value) - if isinstance(output[0], int): - ord_first = output[0] - else: - ord_first = ord(output[0]) - if ord_first > 127: - return b'\x00' + output - else: - return output - def binaryToLong(s): - if not s: - raise ValueError('Empty string passed to strToLong') +# Deprecated versions of bytes <--> int conversions +def longToBinary(value): + warnings.warn("Function longToBinary is deprecated in favor of int_to_bytes.", DeprecationWarning) + return int_to_bytes(value) - if isinstance(s[0], int): - ord_first = s[0] - else: - ord_first = ord(s[0]) - if ord_first > 127: - raise ValueError('This function only supports positive integers') - return bytes_to_long(s) +def binaryToLong(s): + warnings.warn("Function binaryToLong is deprecated in favor of bytes_to_int.", DeprecationWarning) + return bytes_to_int(s) # A randrange function that works for longs @@ -149,7 +146,7 @@ def randrange(start, stop=None, step=1): try: (duplicate, nbytes) = _duplicate_cache[r] except KeyError: - rbytes = longToBinary(r) + rbytes = int_to_bytes(r) if rbytes[0] == '\x00': nbytes = len(rbytes) - 1 else: @@ -168,7 +165,7 @@ def randrange(start, stop=None, step=1): while True: bytes = '\x00' + os.urandom(nbytes) - n = binaryToLong(bytes) + n = bytes_to_int(bytes) # Keep looping if this value is in the low duplicated range if n >= duplicate: break @@ -177,8 +174,8 @@ def randrange(start, stop=None, step=1): def longToBase64(l): - return toBase64(longToBinary(l)) + return toBase64(int_to_bytes(l)) def base64ToLong(s): - return binaryToLong(fromBase64(s)) + return bytes_to_int(fromBase64(s)) diff --git a/openid/dh.py b/openid/dh.py index 28ff403b..46a05edb 100644 --- a/openid/dh.py +++ b/openid/dh.py @@ -60,5 +60,5 @@ def getSharedSecret(self, composite): def xorSecret(self, composite, secret, hash_func): dh_shared = self.getSharedSecret(composite) - hashed_dh_shared = hash_func(cryptutil.longToBinary(dh_shared)) + hashed_dh_shared = hash_func(cryptutil.int_to_bytes(dh_shared)) return strxor(secret, hashed_dh_shared) diff --git a/openid/test/test_cryptutil.py b/openid/test/test_cryptutil.py index 5d5cc5fd..51595759 100644 --- a/openid/test/test_cryptutil.py +++ b/openid/test/test_cryptutil.py @@ -70,6 +70,27 @@ def test_binaryLongConvert(self): assert s == s_prime, (n, s, s_prime) +class TestBytesIntConversion(unittest.TestCase): + """Test bytes <-> int conversions.""" + + # Examples from http://openid.net/specs/openid-authentication-2_0.html#btwoc + cases = [ + (b'\x00', 0), + (b'\x01', 1), + (b'\x7F', 127), + (b'\x00\xFF', 255), + (b'\x00\x80', 128), + (b'\x00\x81', 129), + (b'\x00\x80\x00', 32768), + (b'OpenID is cool', 1611215304203901150134421257416556) + ] + + def test_conversions(self): + for string, number in self.cases: + self.assertEqual(cryptutil.bytes_to_int(string), number) + self.assertEqual(cryptutil.int_to_bytes(number), string) + + class TestLongToBase64(unittest.TestCase): """Test `longToBase64` function.""" From 6df1e9ab1d15cb0f3dca56c4d107521abf73ed03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 16 May 2018 17:48:20 +0200 Subject: [PATCH 097/151] Move default DH constants --- openid/constants.py | 11 +++++++++++ openid/dh.py | 13 ++++--------- 2 files changed, 15 insertions(+), 9 deletions(-) create mode 100644 openid/constants.py diff --git a/openid/constants.py b/openid/constants.py new file mode 100644 index 00000000..8128a27f --- /dev/null +++ b/openid/constants.py @@ -0,0 +1,11 @@ +"""Basic constants for openid library.""" +from __future__ import unicode_literals + +# Default Diffie-Hellman modulus and generator. +# Defined in OpenID specification http://openid.net/specs/openid-authentication-2_0.html#pvalue +DEFAULT_DH_MODULUS = int( + '155172898181473697471232257763715539915724801966915404479707795314057629378541917580651227423698188993727816152646' + '631438561595825688188889951272158842675419950341258706556549803580104870537681476726513255747040765857479291291572' + '334510643245094715007229621094194349783925984760375594985848253359305585439638443' +) +DEFAULT_DH_GENERATOR = 2 diff --git a/openid/dh.py b/openid/dh.py index 46a05edb..aeb00b51 100644 --- a/openid/dh.py +++ b/openid/dh.py @@ -3,6 +3,7 @@ import six from openid import cryptutil +from openid.constants import DEFAULT_DH_GENERATOR, DEFAULT_DH_MODULUS if six.PY2: long_int = long @@ -29,16 +30,10 @@ def strxor(x, y): class DiffieHellman(object): - DEFAULT_MOD = int('155172898181473697471232257763715539915724801966915404479707795314057629378541917580651227423698' - '188993727816152646631438561595825688188889951272158842675419950341258706556549803580104870537681' - '476726513255747040765857479291291572334510643245094715007229621094194349783925984760375594985848' - '253359305585439638443') - - DEFAULT_GEN = 2 @classmethod def fromDefaults(cls): - return cls(cls.DEFAULT_MOD, cls.DEFAULT_GEN) + return cls(DEFAULT_DH_MODULUS, DEFAULT_DH_GENERATOR) def __init__(self, modulus, generator): self.modulus = long_int(modulus) @@ -52,8 +47,8 @@ def _setPrivate(self, private): self.public = pow(self.generator, self.private, self.modulus) def usingDefaultValues(self): - return (self.modulus == self.DEFAULT_MOD and - self.generator == self.DEFAULT_GEN) + return (self.modulus == DEFAULT_DH_MODULUS and + self.generator == DEFAULT_DH_GENERATOR) def getSharedSecret(self, composite): return pow(composite, self.private, self.modulus) From ac2b6ed87e20243ba2423b43e01b8d0dd0876ee0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 22 May 2018 11:16:28 +0200 Subject: [PATCH 098/151] Use cryptography for DH parameters --- openid/dh.py | 37 +++++++++++++++++++++++++++---------- openid/test/test_dh.py | 16 +++++++++++++--- 2 files changed, 40 insertions(+), 13 deletions(-) diff --git a/openid/dh.py b/openid/dh.py index aeb00b51..37bb677d 100644 --- a/openid/dh.py +++ b/openid/dh.py @@ -1,16 +1,12 @@ +""""Utilities for Diffie-Hellman key exchange.""" from __future__ import unicode_literals import six +from cryptography.hazmat.primitives.asymmetric.dh import DHParameterNumbers from openid import cryptutil from openid.constants import DEFAULT_DH_GENERATOR, DEFAULT_DH_MODULUS -if six.PY2: - long_int = long -else: - assert six.PY3 - long_int = int - def _xor(a_b): # Python 2 only @@ -30,16 +26,37 @@ def strxor(x, y): class DiffieHellman(object): + """Utility for Diffie-Hellman key exchange.""" + + def __init__(self, modulus, generator): + """Create a new instance. + + @type modulus: Union[six.integer_types] + @type generator: Union[six.integer_types] + """ + self.parameter_numbers = DHParameterNumbers(modulus, generator) + self._setPrivate(cryptutil.randrange(1, modulus - 1)) @classmethod def fromDefaults(cls): + """Create Diffie-Hellman with the default modulus and generator.""" return cls(DEFAULT_DH_MODULUS, DEFAULT_DH_GENERATOR) - def __init__(self, modulus, generator): - self.modulus = long_int(modulus) - self.generator = long_int(generator) + @property + def modulus(self): + """Return the prime modulus value. - self._setPrivate(cryptutil.randrange(1, modulus - 1)) + @rtype: Union[six.integer_types] + """ + return self.parameter_numbers.p + + @property + def generator(self): + """Return the generator value. + + @rtype: Union[six.integer_types] + """ + return self.parameter_numbers.g def _setPrivate(self, private): """This is here to make testing easier""" diff --git a/openid/test/test_dh.py b/openid/test/test_dh.py index ddd84f0c..838e7460 100644 --- a/openid/test/test_dh.py +++ b/openid/test/test_dh.py @@ -6,7 +6,8 @@ import six -from openid.dh import DiffieHellman, long_int, strxor +from openid.constants import DEFAULT_DH_GENERATOR, DEFAULT_DH_MODULUS +from openid.dh import DiffieHellman, strxor class TestStrXor(unittest.TestCase): @@ -52,6 +53,15 @@ def test_strxor(self): class TestDiffieHellman(unittest.TestCase): + """Test `DiffieHellman` class.""" + + def test_modulus(self): + dh = DiffieHellman.fromDefaults() + self.assertEqual(dh.modulus, DEFAULT_DH_MODULUS) + + def test_generator(self): + dh = DiffieHellman.fromDefaults() + self.assertEqual(dh.generator, DEFAULT_DH_GENERATOR) def _test_dh(self): dh1 = DiffieHellman.fromDefaults() @@ -72,8 +82,8 @@ def test_public(self): try: for line in f: parts = line.strip().split(' ') - dh._setPrivate(long_int(parts[0])) + dh._setPrivate(int(parts[0])) - assert dh.public == long_int(parts[1]) + assert dh.public == int(parts[1]) finally: f.close() From b7dd61577c7feb26b8c205c3cffaaf48f47f4bfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 23 May 2018 10:04:12 +0200 Subject: [PATCH 099/151] Use cryptography for DH keys --- openid/consumer/consumer.py | 4 +--- openid/dh.py | 29 ++++++++++++++++++++------ openid/server/server.py | 2 +- openid/test/dhpriv | 29 -------------------------- openid/test/test_consumer.py | 4 ++-- openid/test/test_dh.py | 40 +++++++++++++++++++++++++++--------- openid/test/test_server.py | 4 ++-- 7 files changed, 59 insertions(+), 53 deletions(-) delete mode 100644 openid/test/dhpriv diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index 5508c45c..43220118 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -462,9 +462,7 @@ def __init__(self, dh=None): self.dh = dh def getRequest(self): - cpub = cryptutil.longToBase64(self.dh.public) - - args = {'dh_consumer_public': cpub} + args = {'dh_consumer_public': self.dh.public_key} if not self.dh.usingDefaultValues(): args.update({ diff --git a/openid/dh.py b/openid/dh.py index 37bb677d..2241b390 100644 --- a/openid/dh.py +++ b/openid/dh.py @@ -1,7 +1,10 @@ """"Utilities for Diffie-Hellman key exchange.""" from __future__ import unicode_literals +import warnings + import six +from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.asymmetric.dh import DHParameterNumbers from openid import cryptutil @@ -35,7 +38,8 @@ def __init__(self, modulus, generator): @type generator: Union[six.integer_types] """ self.parameter_numbers = DHParameterNumbers(modulus, generator) - self._setPrivate(cryptutil.randrange(1, modulus - 1)) + parameters = self.parameter_numbers.parameters(default_backend()) + self.private_key = parameters.generate_private_key() @classmethod def fromDefaults(cls): @@ -58,17 +62,30 @@ def generator(self): """ return self.parameter_numbers.g - def _setPrivate(self, private): - """This is here to make testing easier""" - self.private = private - self.public = pow(self.generator, self.private, self.modulus) + @property + def public(self): + """Return the public key. + + @rtype: Union[six.integer_types] + """ + warnings.warn("Attribute 'public' is deprecated. Use 'public_key' instead.", DeprecationWarning) + return self.private_key.public_key().public_numbers().y + + @property + def public_key(self): + """Return base64 encoded public key. + + @rtype: six.text_type + """ + return cryptutil.longToBase64(self.private_key.public_key().public_numbers().y) def usingDefaultValues(self): return (self.modulus == DEFAULT_DH_MODULUS and self.generator == DEFAULT_DH_GENERATOR) def getSharedSecret(self, composite): - return pow(composite, self.private, self.modulus) + private = self.private_key.private_numbers().x + return pow(composite, private, self.modulus) def xorSecret(self, composite, secret, hash_func): dh_shared = self.getSharedSecret(composite) diff --git a/openid/server/server.py b/openid/server/server.py index dfc7fbab..90315c68 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -380,7 +380,7 @@ def answer(self, secret): secret, self.hash_func) return { - 'dh_server_public': cryptutil.longToBase64(self.dh.public), + 'dh_server_public': self.dh.public_key, 'enc_mac_key': oidutil.toBase64(mac_key), } diff --git a/openid/test/dhpriv b/openid/test/dhpriv deleted file mode 100644 index 0fa52314..00000000 --- a/openid/test/dhpriv +++ /dev/null @@ -1,29 +0,0 @@ -130706940119084053627151828062879423433929180135817317038378606310097533503449582079984816816837125851552273641820339909167103200910805078308128174143174269944095368580519322913514764528012639683546377014716235962867583443566164615728897857285824741767070432119909660645255499710701356135207437699643611094585 139808169914464096465921128085565621767096724855516655439365028496569658038844954238931647642811548254956660405394116677296461848124300258439895306367561416289126854788101396379292925819850897858045772500578222021901631436550118958972312221974009238050517034542286574826081826542722270952769078386418682059418 -91966407878983240112417790733941098492087186469785726449910011271065622315680646030230288265496017310433513856308693810812043160919214636748486185212617634222158204354206411031403206076739932806412551605172319515223573351072757800448643935018534945933808900467686115619932664888581913179496050117713298715475 88086484332488517006277516020842172054013692832175783214603951240851750819999098631851571207693874357651112736088114133607400684776234181681933311972926752846692615822043533641407510569745606256772455614745111122033229877596984718963046218854103292937700694160593653595134512369959987897086639788909618660591 -94633950701209990078055218830969910271587805983595045023718108184189787131629772007048606080263109446462048743696369276578815611098215686598630889831104860221067872883514840819381234786050098278403321905311637820524177879167250981289318356078312300538871435101338967079907049912435983871847334104247675360099 136836393035803488129856151345450008294260680733328546556640578838845312279198933806383329293483852515700876505956362639881210101974254765087350842271260064592406308509078284840473735904755203614987286456952991025347168970462354352741159076541157478949094536405618626397435745496863324654768971213730622037771 -24685127248019769965088146297942173464487677364928435784091685260262292485380918213538979925891771204729738138857126454465630594391449913947358655368215901119137728648638547728497517587701248406019427282237279437409508871300675355166059811431191200555457304463617727969228965042729205402243355816702436970430 103488011917988946858248200111251786178288940265978921633592888293430082248387786443813155999158786903216094876295371112716734481877806417714913656921169196196571699893360825510307056269738593971532017994987406325068886420548597161498019372380511676314312298122272401348856314619382867707981701472607230523868 -116791045850880292989786005885944774698035781824784400772676299590038746153860847252706167458966356897309533614849402276819438194497464696186624618374179812548893947178936305721131565012344462048549467883494038577857638815386798694225798517783768606048713198211730870155881426709644960689953998714045816205549 25767875422998856261320430397505398614439586659207416236135894343577952114994718158163212134503751463610021489053571733974769536157057815413209619147486931502025658987681202196476489081257777148377685478756033509708349637895740799542063593586769082830323796978935454479273531157121440998804334199442003857410 -75582226959658406842894734694860761896800153014775231713388264961517169436476322183886891849966756849783437334069692683523296295601533803799559985845105706728538458624387103621364117548643541824878550074680443708148686601108223917493525070861593238005735446708555769966855130921562955491250908613793521520082 51100990616369611694975829054222013346248289055987940844427061856603230021472379888102172458517294080775792439385531234808129302064303666640376750139242970123503857186428797403843206765926798353022284672682073397573130625177187185114726049347844460311761033584101482859992951420083621362870301150543916815123 -22852401165908224137274273646590366934616265607879280260563022941455466297431255072303172649495519837876946233272420969249841381161312477263365567831938496555136366981954001163034914812189448922853839616662859772087929140818377228980710884492996109434435597500854043325062122184466315338260530734979159890875 35017410720028595029711778101507729481023945551700945988329114663345341120595162378885287946069695772429641825579528116641336456773227542256911497084242947904528367986325800537695079726856460817606404224094336361853766354225558025931211551975334149258299477750615397616908655079967952372222383056221992235704 -37364490883518159794654045194678325635036705086417851509136183713863262621334636905291385255662750747808690129471989906644041585863034419130023070856805511017402434123099100618568335168939301014148587149578150068910141065808373976114927339040964292334109797421173369274978107389084873550233108940239410902552 40916262212189137562350357241447034318002130016858244002788189310078477605649010031339865625243230798681216437501833540185827501244378529230150467789369234869122179247196276164931090039290879808162629109742198951942358028123056268054775108592325500609335947248599688175189333996086475013450537086042387719925 -42030470670714872936404499074069849778147578537708230270030877866700844337372497704027708080369726758812896818567830863540507961487472657570488625639077418109017434494794778542739932765561706796300920251933107517954265066804108669800167526425723377411855061131982689717887180411017924173629124764378241885274 124652439272864857598747946875599560379786580730218192165733924418687522301721706620565030507816884907589477351553268146177293719586287258662025940181301472851649975563004543250656807255226609296537922304346339513054316391667044301386950180277940536542183725690479451746977789001659540839582630251935163344393 -33176766914206542084736303652243484580303865879984981189372762326078776390896986743451688462101732968104375838228070296418541745483112261133079756514082093269959937647525005374035326747696591842313517634077723301677759648869372517403529488493581781546743147639937580084065663597330159470577639629864369972900 67485835091897238609131069363014775606263390149204621594445803179810038685760826651889895397414961195533694176706808504447269558421955735607423135937153901140512527504198912146656610630396284977496295289999655140295415981288181545277299615922576281262872097567020980675200178329219970170480653040350512964539 -131497983897702298481056962402569646971797912524360547236788650961059980711719600424210346263081838703940277066368168874781981151411096949736205282734026497995296147418292226818536168555712128736975034272678008697869326747592750850184857659420541708058277866000692785617873742438060271311159568468507825422571 5400380840349873337222394910303409203226429752629134721503171858543984393161548520471799318518954232197106728096866840965784563043721652790856860155702760027304915133166173298206604451826182024471262142046935060360564569939062438160049193241369468208458085699995573492688298015026628427440418009025072261296 -83265103005695640943261961853521077357830295830250157593141844209296716788437615940096402365505416686459260302419338241462783388722843946886845478224048360927114533590583464979009731440049610985062455108831881153988321298531365779084012803908832525921630534096740755274371500276660832724874701671184539131864 141285570207910287798371174771658911045525474449663877845558585668334618068814605961306961485855329182957174312715910923324965889174835444049526313968571611940626279733302104955951067959291852710640374412577070764165811275030632465290729619533330733368808295932659463215921521905553936914975786500018720073003 -68435028583616495789148116911096163791710022987677894923742899873596891423986951658100606742052014161171185231735413902875605720814417622409817842932759492013585936536452615480700628719795872201528559780249210820284350401473564919576289210869896327937002173624497942136329576506818749730506884927872345019446 134655528287263100540003157571441260698452262106680191153945271167894435782028803135774578949200580551016388918860856991026082917835209212892423567114480975540305860034439015788120390011692862968771136814777768281366591257663821495720134621172848947971117885754539770645621669309650476331439675400544167728223 -97765390064836080322590528352647421920257073063706996347334558390461274981996865736612531330863478931481491964338380362350271734683183807511097331539820133036984271653285063355715726806139083282458695728902452215405696318402583540317419929113959816258829534543044153959951908676300847164682178008704099351835 92552521881196975294401505656851872247567784546370503402756239533783651371688190302773864319828182042605239246779598629409815474038541272600580320815319709309111399294952620375093803971373108792300726524826209329889463854451846561437729676142864421966497641824498079067929811613947148353921163336822026640804 -145767094672933012300753301037546647564595762930138884463767054235112032706630891961371504668013023047595721138624016493638510710257541241706724342585654715468628355455898091951826598092812212209834746162089753649871544789379424903025374228231365026585872808685759231756517703720396301355299998059523896918448 116669462839999965355861187716880953863237226719689755457884414384663576662696981997535568446560375442532084973721539944428004043491468494548231348032618218312515409944970197902589794303562379864012797605284844016184274353252071642511293089390472576498394410829972525726474727579603392265177009323768966538608 -34172517877854802711907683049441723730724885305592620486269966708379625109832852005775048584124451699198484092407720344962116726808090368739361658889584507734617844212547181476646725256303630128954338675520938806905779837227983648887192531356390902975904503218654196581612781227843742951241442641220856414232 126013077261793777773236390821108423367648447987653714614732477073177878509574051196587476846560696305938891953527959347566502332765820074506907037627115954790645652211088723122982633069089920979477728376746424256704724173255656757918995039125823421607024407307091796807227896314403153380323770001854211384322 -9979624731056222925878866378063961280844793874828281622845276060532093809300121084179730782833657205171434732875093693074415298975346410131191865198158876447591891117577190438695367929923494177555818480377241891190442070100052523008290671797937772993634966511431668500154258765510857129203107386972819651767 76559085024395996164590986654274454741199399364851956129137304209855150918182685643729981600389513229011956888957763987167398150792454613751473654448162776379362213885827651020309844507723069713820393068520302223477225569348080362344052033711960892643036147232270133731530049660264526964146237693063093765111 -18162696663677410793062235946366423954875282212790518677684260521370996677183041664345920941714064628111537529793170736292618705900247450994864220481135611781148410617609559050220262121494712903009168783279356915189941268264177631458029177102542745167475619936272581126346266816618866806564180995726437177435 63244550218824945129624987597134280916829928261688093445040235408899092619821698537312158783367974202557699994650667088974727356690181336666077506063310290098995215324552449858513870629176838494348632073938023916155113126203791709810160925798130199717340478393420816876665127594623142175853115698049952126277 -4817943161362708117912118300716778687157593557807116683477307391846133734701449509121209661982298574607233039490570567781316652698287671086985501523197566560479906850423709894582834963398034434055472063156147829131181965140631257939036683622084290629927807369457311894970308590034407761706800045378158588657 61612160237840981966750225147965256022861527286827877531373888434780789812764688703260066154973576040405676432586962624922734102370509771313805122788566405984830112657060375568510809122230960988304085950306616401218206390412815884549481965750553137717475620505076144744211331973240555181377832337912951699135 -36363324947629373144612372870171042343590861026293829791335153646774927623889458346817049419803031378037141773848560341251355283891019532059644644509836766167835557471311319194033709837770615526356168418160386395260066262292757953919140150454538786106958252854181965875293629955562111756775391296856504912587 86831561031659073326747216166881733513938228972332631084118628692228329095617884068498116676787029033973607066377816508795286358748076949738854520048303930186595481606562375516134920902325649683618195251332651685732712539073110524182134321873838204219194459231650917098791250048469346563303077080880339797744 -26406869969418301728540993821409753036653370247174689204659006239823766914991146853283367848649039747728229875444327879875275718711878211919734397349994000106499628652960403076186651083084423734034070082770589453774926850920776427074440483233447839259180467805375782600203654373428926653730090468535611335253 100139935381469543084506312717977196291289016554846164338908226931204624582010530255955411615528804421371905642197394534614355186795223905217732992497673429554618838376065777445760355552020655667172127543653684405493978325270279321013143828897100500212200358450649158287605846102419527584313353072518101626851 -92613116984760565837109105383781193800503303131143575169488835702472221039082994091847595094556327985517286288659598094631489552181233202387028607421487026032402972597880028640156629614572656967808446397456622178472130864873587747608262139844319805074476178618930354824943672367046477408898479503054125369731 30023391082615178562263328892343821010986429338255434046051061316154579824472412477397496718186615690433045030046315908170615910505869972621853946234911296439134838951047107272129711854649412919542407760508235711897489847951451200722151978578883748353566191421685659370090024401368356823252748749449302536931 -31485815361342085113278193504381994806529237123359718043079410511224607873725611862217941085749929342777366642477711445011074784469367917758629403998067347054115844421430072631339788256386509261291675080191633908849638316409182455648806133048549359800886124554879661473112614246869101243501787363247762961784 114503770698890543429251666713050844656853278831559195214556474458830029271801818536133531843456707474500106283648085144619097572354066554819887152106174400667929098257361286338795493838820850475790977445807435511982704395422526800272723708548541616513134676140304653112325071112865020365664833601046215694089 -76882090884790547431641385530818076533805072109483843307806375918023300052767710853172670987385376253156912268523505310624133905633437815297307463917718596711590885553760690350221265675690787249135345226947453988081566088302642706234126002514517416493192624887800567412565527886687096028028124049522890448168 15056463217273240496622619354104573042767532856243223052125822509781815362480522535564283485059790932505429110157271454207173426525345813426696743168079246510944969446574354255284952839036431873039487144279164893710061580467579842173706653409487110282515691099753380094215805485573768509475850463001549608836 -52345178981230648108672997265819959243255047568833938156267924185186047373470984278294897653277996726416846430969793375429223610099546622112048283560483136389901514170116723365811871938630317974150540909650396429631704968748113009366339718498979597226137532343384889080245796447593572468846438769413505393967 32148494517199936472358017244372701214529606506776255341152991328091526865643069587953759877295255050519124541457805199596762210567333445908166076384465183589342153762720515477404466193879418014196727238972417616122646440870364200208488239778452378059236162633837824948613596114768455832408342040970780086 -41095268619128788015767564971105114602454449306041732792746397800275041704886345704294273937217484580365505320134717320083763349380629342859670693445658118959823430378844830923452105707338162448974869312012791385772125813291388247857971218575518319578818336960572244046567099555399203328678654466958536663208 92166550199033418923713824997841892577149715275633481076285269142670107687867024550593869464613175882141630640739938334001211714884975032600306279287443909448541179109981755796752132502127330056736913454039526413284519137059580845856736918773597087836203497066909257930043736166431682872083389105176299181629 -40049143661018504441607875135884755310012910557581028447435354354754245291878800571089144452035026644953322330676651798951447670184106450649737772686119714700743396359069052813433030118630105307022867200053964644574786137276428546712005171080129190959914708907200288299169344380390093918556722227705114244981 108159089972386282154772900619022507336076619354549601813179459338897131937353741544606392560724999980281424266891537298473163753022749859939445293926707568015958367188089915420630082556748668489756475027008449860889202622698060097015044886961901650857610841562477736791450080980702347705778074391774667412741 -69905259478181995876884927656894491893594530150260951315109404530530357998889589977208787140430938039028941393673520799460431992051993157468616168400324834880926190141581037597526917869362292931957289043707855837933490285814769110495657056206391880865972389421774822461752702336812585852278453803972600333734 71821415380277072313878763768684432371552628204186742842154591000123020597011744840460964835414360968627162765288463383113375595799297552681618876474019263288277398833725479226930770694271622605114061622753165584075733358178384410640349907375170170910499615355511313349300918885560131539570707695789106185664 -26945345439378873515011714350080059082081595419023056538696949766471272811362104837806324694947413603019863785876836706911406330379274553386254346050697348395574746891556054334903838949157798006141473389066020212044825140294048709654273698482867946522782450500680195477050110145664069582549935651920545151500 80313315938584480048642653013876614091607852535582224914294013785054094052454758327935781971746329853786568549510067442145637007308960551652864942042189241081946607011847245280773379099020221884296226818685556430275385068764313042226925852500883894269809033380734632866477789520106865758504064806906234130588 diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index c9cd8f5a..1bac4681 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -1778,12 +1778,12 @@ def setUp(self): self.consumer_dh = DiffieHellman(100389557, 2) # base64(btwoc(g ^ xb mod p)) - self.dh_server_public = cryptutil.longToBase64(self.server_dh.public) + self.dh_server_public = self.server_dh.public_key self.secret = os.urandom(self.session_cls.secret_size) self.enc_mac_key = oidutil.toBase64( - self.server_dh.xorSecret(self.consumer_dh.public, + self.server_dh.xorSecret(cryptutil.base64ToLong(self.consumer_dh.public_key), self.secret, self.session_cls.hash_func)) diff --git a/openid/test/test_dh.py b/openid/test/test_dh.py index 838e7460..84a79d86 100644 --- a/openid/test/test_dh.py +++ b/openid/test/test_dh.py @@ -1,12 +1,16 @@ """Test `openid.dh` module.""" from __future__ import unicode_literals -import os.path import unittest +import warnings import six +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.asymmetric.dh import DHPrivateNumbers, DHPublicNumbers +from testfixtures import ShouldWarn from openid.constants import DEFAULT_DH_GENERATOR, DEFAULT_DH_MODULUS +from openid.cryptutil import longToBase64 from openid.dh import DiffieHellman, strxor @@ -76,14 +80,30 @@ def test_exchange(self): s2 = self._test_dh() assert s1 != s2 + private_key = int( + '76773183260125655927407219021356850612958916567415386199501281181228346359328609688049646172182310748186340503' + '26318343789919595649515190982375134969315580266608309203790369036760020471410949003193451675532879428946682852' + '7087756147962428703119223967577366837042279080006329440425557036807436654929251188437293') + public_key = int( + '14830402392262721982219607342625341531794979311088664077137112813385301968870761946911013412944671626402638538' + '59019114967817783168739766941288204771883652891577627356203670315421489407520844320897873950439171044693921561' + '24149254347661216215110718681656349527564919668545970743829522251387472714136707262965225') + + def setup_keys(self, dh_object, public_key, private_key): + """Set up private and public key into DiffieHellman object.""" + public_numbers = DHPublicNumbers(public_key, dh_object.parameter_numbers) + private_numbers = DHPrivateNumbers(private_key, public_numbers) + dh_object.private_key = private_numbers.private_key(default_backend()) + def test_public(self): - f = open(os.path.join(os.path.dirname(__file__), 'dhpriv')) dh = DiffieHellman.fromDefaults() - try: - for line in f: - parts = line.strip().split(' ') - dh._setPrivate(int(parts[0])) - - assert dh.public == int(parts[1]) - finally: - f.close() + self.setup_keys(dh, self.public_key, self.private_key) + warning_msg = "Attribute 'public' is deprecated. Use 'public_key' instead." + with ShouldWarn(DeprecationWarning(warning_msg)): + warnings.simplefilter('always') + self.assertEqual(dh.public, self.server_public_key) + + def test_public_key(self): + dh = DiffieHellman.fromDefaults() + self.setup_keys(dh, self.public_key, self.private_key) + self.assertEqual(dh.public_key, longToBase64(self.public_key)) diff --git a/openid/test/test_server.py b/openid/test/test_server.py index 88c1c021..eadbc46d 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -1293,7 +1293,7 @@ def test_dhSHA1(self): from openid.dh import DiffieHellman from openid.server.server import DiffieHellmanSHA1ServerSession consumer_dh = DiffieHellman.fromDefaults() - cpub = consumer_dh.public + cpub = cryptutil.base64ToLong(consumer_dh.public_key) server_dh = DiffieHellman.fromDefaults() session = DiffieHellmanSHA1ServerSession(server_dh, cpub) self.request = server.AssociateRequest(session, 'HMAC-SHA1') @@ -1318,7 +1318,7 @@ def test_dhSHA256(self): from openid.dh import DiffieHellman from openid.server.server import DiffieHellmanSHA256ServerSession consumer_dh = DiffieHellman.fromDefaults() - cpub = consumer_dh.public + cpub = cryptutil.base64ToLong(consumer_dh.public_key) server_dh = DiffieHellman.fromDefaults() session = DiffieHellmanSHA256ServerSession(server_dh, cpub) self.request = server.AssociateRequest(session, 'HMAC-SHA256') From d66a8a820c231694cfef4c3a102943c3c42cdb32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 22 May 2018 13:11:47 +0200 Subject: [PATCH 100/151] Use cryptography for DH key exchange --- openid/dh.py | 26 +++++++++--- openid/test/test_dh.py | 93 +++++++++++++++++++++++++++++++++--------- 2 files changed, 95 insertions(+), 24 deletions(-) diff --git a/openid/dh.py b/openid/dh.py index 2241b390..b6c2908f 100644 --- a/openid/dh.py +++ b/openid/dh.py @@ -5,7 +5,7 @@ import six from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives.asymmetric.dh import DHParameterNumbers +from cryptography.hazmat.primitives.asymmetric.dh import DHParameterNumbers, DHPublicNumbers from openid import cryptutil from openid.constants import DEFAULT_DH_GENERATOR, DEFAULT_DH_MODULUS @@ -84,10 +84,26 @@ def usingDefaultValues(self): self.generator == DEFAULT_DH_GENERATOR) def getSharedSecret(self, composite): - private = self.private_key.private_numbers().x - return pow(composite, private, self.modulus) + """Return a shared secret. + + @param composite: Public key of the other party. + @type composite: Union[six.integer_types] + @rtype: Union[six.integer_types] + """ + warnings.warn("Method 'getSharedSecret' is deprecated in favor of 'get_shared_secret'.", DeprecationWarning) + return cryptutil.bytes_to_int(self.get_shared_secret(composite)) + + def get_shared_secret(self, public_key): + """Return a shared secret. + + @param public_key: Public key of the other party. + @type public_key: Union[six.integer_types] + @rtype: six.binary_type + """ + public_numbers = DHPublicNumbers(public_key, self.parameter_numbers) + return self.private_key.exchange(public_numbers.public_key(default_backend())) def xorSecret(self, composite, secret, hash_func): - dh_shared = self.getSharedSecret(composite) - hashed_dh_shared = hash_func(cryptutil.int_to_bytes(dh_shared)) + dh_shared = self.get_shared_secret(composite) + hashed_dh_shared = hash_func(dh_shared) return strxor(secret, hashed_dh_shared) diff --git a/openid/test/test_dh.py b/openid/test/test_dh.py index 84a79d86..110202e5 100644 --- a/openid/test/test_dh.py +++ b/openid/test/test_dh.py @@ -1,6 +1,7 @@ """Test `openid.dh` module.""" from __future__ import unicode_literals +import os import unittest import warnings @@ -10,7 +11,7 @@ from testfixtures import ShouldWarn from openid.constants import DEFAULT_DH_GENERATOR, DEFAULT_DH_MODULUS -from openid.cryptutil import longToBase64 +from openid.cryptutil import base64ToLong, bytes_to_int, longToBase64, sha256 from openid.dh import DiffieHellman, strxor @@ -67,27 +68,30 @@ def test_generator(self): dh = DiffieHellman.fromDefaults() self.assertEqual(dh.generator, DEFAULT_DH_GENERATOR) - def _test_dh(self): - dh1 = DiffieHellman.fromDefaults() - dh2 = DiffieHellman.fromDefaults() - secret1 = dh1.getSharedSecret(dh2.public) - secret2 = dh2.getSharedSecret(dh1.public) - assert secret1 == secret2 - return secret1 - - def test_exchange(self): - s1 = self._test_dh() - s2 = self._test_dh() - assert s1 != s2 - - private_key = int( + consumer_private_key = int( '76773183260125655927407219021356850612958916567415386199501281181228346359328609688049646172182310748186340503' '26318343789919595649515190982375134969315580266608309203790369036760020471410949003193451675532879428946682852' '7087756147962428703119223967577366837042279080006329440425557036807436654929251188437293') - public_key = int( + consumer_public_key = int( '14830402392262721982219607342625341531794979311088664077137112813385301968870761946911013412944671626402638538' '59019114967817783168739766941288204771883652891577627356203670315421489407520844320897873950439171044693921561' '24149254347661216215110718681656349527564919668545970743829522251387472714136707262965225') + server_private_key = int( + '15467965641543992347841556205070390914637305348154825847599734515099514013537846015402306363308433241908283446' + '71248072297246966864402013185397179020027880855596392908146308184428215791914057102026401324081917190180806065' + '52997123133752764540011560986670942115061415865499463644558159755273696690932941082271979') + server_public_key = int( + '34503131980021108262326730163610830553875615642061454929962013481368582594793479022634253261703143188115239697' + '31865012494779720501092100433895935952054678007893102647432613158698447525023861310539814658911402112680185359' + '5512256481326572078983201034675082346312609787920346766733771767752145619255920370032919' + ) + shared_secret = ( + b'\x14u\xa1_k\xf6\x83\xfbp#\xc9\x8e\xd4qb#\xdc\xe0D\xfe\xbf\x08\x16\xc9\xd3\xedwr\nC&\xf2\x14\xca\x90\xcdr\xa2' + b'\xc7\x96A\x89\xb66\x8e\'W"_\xea\xa4\xd8\x97\xf7e\xdby`\x90\xe0\x8aUG\xf9x;\xc7\xb5\x9a\x1duq]\x8cn\xe5\x14' + b'\xf0\x12\xe3\xf2\x15H\xce\xebe\xd3\xea\xedu\xa8\x9d\xf9>\xfb\xdeL<0\x02\xcb\xfa\xf8\xeb)+\xc1Qn\xa3\n"\x03n' + b'\x12I\x9a\x145p\xaf\x87J\xca\x16T\xb4\xd8') + secret = b'Rimmer ordered hot gazpacho soup' + mac_key = b'\x84\x06)\x1f6\xcf\xbcA\xec\xd0\x9d\xad\xf0\xa6"\xaa\x8cl-)\x91\xccg\xc2Bl\x0c\x83\xdbZ5\xfd' def setup_keys(self, dh_object, public_key, private_key): """Set up private and public key into DiffieHellman object.""" @@ -97,7 +101,7 @@ def setup_keys(self, dh_object, public_key, private_key): def test_public(self): dh = DiffieHellman.fromDefaults() - self.setup_keys(dh, self.public_key, self.private_key) + self.setup_keys(dh, self.server_public_key, self.server_private_key) warning_msg = "Attribute 'public' is deprecated. Use 'public_key' instead." with ShouldWarn(DeprecationWarning(warning_msg)): warnings.simplefilter('always') @@ -105,5 +109,56 @@ def test_public(self): def test_public_key(self): dh = DiffieHellman.fromDefaults() - self.setup_keys(dh, self.public_key, self.private_key) - self.assertEqual(dh.public_key, longToBase64(self.public_key)) + self.setup_keys(dh, self.server_public_key, self.server_private_key) + self.assertEqual(dh.public_key, longToBase64(self.server_public_key)) + + def test_get_shared_secret_server(self): + server_dh = DiffieHellman.fromDefaults() + self.setup_keys(server_dh, self.server_public_key, self.server_private_key) + self.assertEqual(server_dh.get_shared_secret(self.consumer_public_key), self.shared_secret) + + def test_get_shared_secret_consumer(self): + consumer_dh = DiffieHellman.fromDefaults() + self.setup_keys(consumer_dh, self.consumer_public_key, self.consumer_private_key) + self.assertEqual(consumer_dh.get_shared_secret(self.server_public_key), self.shared_secret) + + def test_getSharedSecret(self): + # Test the deprecated method + consumer_dh = DiffieHellman.fromDefaults() + self.setup_keys(consumer_dh, self.consumer_public_key, self.consumer_private_key) + warning_msg = "Method 'getSharedSecret' is deprecated in favor of 'get_shared_secret'." + with ShouldWarn(DeprecationWarning(warning_msg)): + warnings.simplefilter('always') + self.assertEqual(consumer_dh.getSharedSecret(self.server_public_key), bytes_to_int(self.shared_secret)) + + def test_exchange_server_static(self): + # Test key exchange - server part with static values + server_dh = DiffieHellman.fromDefaults() + self.setup_keys(server_dh, self.server_public_key, self.server_private_key) + + self.assertEqual(server_dh.xorSecret(self.consumer_public_key, self.secret, sha256), self.mac_key) + self.assertEqual(server_dh.public_key, longToBase64(self.server_public_key)) + + def test_exchange_consumer_static(self): + # Test key exchange - consumer part with static values + consumer_dh = DiffieHellman.fromDefaults() + self.setup_keys(consumer_dh, self.consumer_public_key, self.consumer_private_key) + + shared_secret = consumer_dh.xorSecret(self.server_public_key, self.mac_key, sha256) + # Check secret was negotiated correctly + self.assertEqual(shared_secret, self.secret) + + def test_exchange_dynamic(self): + # Test complete key exchange with random values + # Consumer part + consumer_dh = DiffieHellman.fromDefaults() + consumer_public_key = consumer_dh.public_key + # Server part + secret = os.urandom(32) + server_dh = DiffieHellman.fromDefaults() + mac_key = server_dh.xorSecret(base64ToLong(consumer_public_key), secret, sha256) + server_public_key = server_dh.public_key + # Consumer part + shared_secret = consumer_dh.xorSecret(base64ToLong(server_public_key), mac_key, sha256) + # Check secret was negotiated correctly + self.assertEqual(secret, shared_secret) From bf042df902a4f06ffac8b671b20efabf5a3973cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 23 May 2018 10:37:49 +0200 Subject: [PATCH 101/151] Drop randrange function --- openid/cryptutil.py | 50 ----------------------------------- openid/test/test_cryptutil.py | 18 ------------- 2 files changed, 68 deletions(-) diff --git a/openid/cryptutil.py b/openid/cryptutil.py index c7de35f5..87de8374 100644 --- a/openid/cryptutil.py +++ b/openid/cryptutil.py @@ -10,8 +10,6 @@ import codecs import hashlib import hmac -import os -import random import warnings import six @@ -25,7 +23,6 @@ 'hmacSha256', 'longToBase64', 'longToBinary', - 'randrange', 'sha1', 'sha256', 'int_to_bytes', @@ -126,53 +123,6 @@ def binaryToLong(s): return bytes_to_int(s) -# A randrange function that works for longs -try: - randrange = random.SystemRandom().randrange -except AttributeError: - # In Python 2.2's random.Random, randrange does not support - # numbers larger than sys.maxint for randrange. For simplicity, - # use this implementation for any Python that does not have - # random.SystemRandom - - _duplicate_cache = {} - - def randrange(start, stop=None, step=1): - if stop is None: - stop = start - start = 0 - - r = (stop - start) // step - try: - (duplicate, nbytes) = _duplicate_cache[r] - except KeyError: - rbytes = int_to_bytes(r) - if rbytes[0] == '\x00': - nbytes = len(rbytes) - 1 - else: - nbytes = len(rbytes) - - mxrand = (256 ** nbytes) - - # If we get a number less than this, then it is in the - # duplicated range. - duplicate = mxrand % r - - if len(_duplicate_cache) > 10: - _duplicate_cache.clear() - - _duplicate_cache[r] = (duplicate, nbytes) - - while True: - bytes = '\x00' + os.urandom(nbytes) - n = bytes_to_int(bytes) - # Keep looping if this value is in the low duplicated range - if n >= duplicate: - break - - return start + (n % r) * step - - def longToBase64(l): return toBase64(int_to_bytes(l)) diff --git a/openid/test/test_cryptutil.py b/openid/test/test_cryptutil.py index 51595759..ce268cf6 100644 --- a/openid/test/test_cryptutil.py +++ b/openid/test/test_cryptutil.py @@ -19,24 +19,6 @@ long_int = int -class TestRandRange(unittest.TestCase): - """Test `randrange` function.""" - - def test_cryptrand(self): - # It's possible, but HIGHLY unlikely that a correct implementation - # will fail by returning the same number twice - - a = cryptutil.randrange(2 ** 128) - b = cryptutil.randrange(2 ** 128) - assert isinstance(a, long_int) - assert isinstance(b, long_int) - assert b != a - - # Make sure that we can generate random numbers that are larger - # than platform int size - cryptutil.randrange(long_int(sys.maxsize) + 1) - - class TestLongBinary(unittest.TestCase): """Test `longToBinary` and `binaryToLong` functions.""" From 1ea54374d83972ff0ac195e11a6fd24cbdd901ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 30 May 2018 15:39:01 +0200 Subject: [PATCH 102/151] Fix codecov uploads --- .gitignore | 2 +- .travis.yml | 5 +++++ tox.ini | 7 ++----- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 7affb99f..7b31086f 100644 --- a/.gitignore +++ b/.gitignore @@ -8,5 +8,5 @@ __pycache__ /.eggs /sstore # Coverage -/.coverage +/.coverage* /htmlcov diff --git a/.travis.yml b/.travis.yml index 35cfdd7c..091d075a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -20,3 +20,8 @@ install: - pip install tox-travis script: - tox +after_success: + - coverage combine + - coverage report + - pip install codecov + - codecov diff --git a/tox.ini b/tox.ini index a2ccdec7..0dbed218 100644 --- a/tox.ini +++ b/tox.ini @@ -11,8 +11,6 @@ python = # Generic specification for all unspecific environments [testenv] -deps = - codecov extras = tests djopenid: djopenid @@ -24,9 +22,8 @@ setenv = DJANGO_SETTINGS_MODULE = djopenid.settings PYTHONPATH = {toxinidir}/examples:{env:PYTHONPATH:} commands = - coverage run --branch --source=openid,examples --module unittest discover --start=openid - djopenid: coverage run --branch --source=openid,examples --append --module unittest discover --start={toxinidir}/examples - codecov + coverage run --parallel-mode --branch --source=openid,examples --module unittest discover --start=openid + djopenid: coverage run --parallel-mode --branch --source=openid,examples --module unittest discover --start={toxinidir}/examples [testenv:quality] whitelist_externals = make From 7900c5819c5419b96e3d8346d0654e20f746181d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 18 May 2018 14:13:58 +0200 Subject: [PATCH 103/151] Capture deprecation warnings in tests --- openid/test/test_message.py | 35 ++++++++++++++++++++++++++------- openid/test/test_pape_draft5.py | 8 +++++++- openid/test/test_urinorm.py | 8 +++++++- 3 files changed, 42 insertions(+), 9 deletions(-) diff --git a/openid/test/test_message.py b/openid/test/test_message.py index 8a636013..0c04219a 100644 --- a/openid/test/test_message.py +++ b/openid/test/test_message.py @@ -56,7 +56,10 @@ def test_getKeyOpenID(self): # exception. I'm not sure which one is more right, since this # case should only happen when you're building a message from # scratch and so have no default namespace. - self.assertRaises(UndefinedOpenIDNamespace, self.msg.getKey, OPENID_NS, 'foo') + warning_msg = "UndefinedOpenIDNamespace exception is deprecated." + with ShouldWarn(DeprecationWarning(warning_msg)): + warnings.simplefilter('always') + self.assertRaises(UndefinedOpenIDNamespace, self.msg.getKey, OPENID_NS, 'foo') def test_getKeyBARE(self): self.assertEqual(self.msg.getKey(BARE_NS, 'foo'), 'foo') @@ -75,7 +78,10 @@ def test_hasKey(self): # exception. I'm not sure which one is more right, since this # case should only happen when you're building a message from # scratch and so have no default namespace. - self.assertRaises(UndefinedOpenIDNamespace, self.msg.hasKey, OPENID_NS, 'foo') + warning_msg = "UndefinedOpenIDNamespace exception is deprecated." + with ShouldWarn(DeprecationWarning(warning_msg)): + warnings.simplefilter('always') + self.assertRaises(UndefinedOpenIDNamespace, self.msg.hasKey, OPENID_NS, 'foo') def test_hasKeyBARE(self): self.assertFalse(self.msg.hasKey(BARE_NS, 'foo')) @@ -103,7 +109,10 @@ def test_getArg(self): # exception. I'm not sure which one is more right, since this # case should only happen when you're building a message from # scratch and so have no default namespace. - self.assertRaises(UndefinedOpenIDNamespace, self.msg.getArg, OPENID_NS, 'foo') + warning_msg = "UndefinedOpenIDNamespace exception is deprecated." + with ShouldWarn(DeprecationWarning(warning_msg)): + warnings.simplefilter('always') + self.assertRaises(UndefinedOpenIDNamespace, self.msg.getArg, OPENID_NS, 'foo') test_getArgBARE = mkGetArgTest(BARE_NS, 'foo') test_getArgNS1 = mkGetArgTest(OPENID1_NS, 'foo') @@ -115,7 +124,10 @@ def test_getArgs(self): # exception. I'm not sure which one is more right, since this # case should only happen when you're building a message from # scratch and so have no default namespace. - self.assertRaises(UndefinedOpenIDNamespace, self.msg.getArgs, OPENID_NS) + warning_msg = "UndefinedOpenIDNamespace exception is deprecated." + with ShouldWarn(DeprecationWarning(warning_msg)): + warnings.simplefilter('always') + self.assertRaises(UndefinedOpenIDNamespace, self.msg.getArgs, OPENID_NS) def test_getArgsBARE(self): self.assertEqual(self.msg.getArgs(BARE_NS), {}) @@ -130,7 +142,10 @@ def test_getArgsNS3(self): self.assertEqual(self.msg.getArgs('urn:nothing-significant'), {}) def test_updateArgs(self): - self.assertRaises(UndefinedOpenIDNamespace, self.msg.updateArgs, OPENID_NS, {'does not': 'matter'}) + warning_msg = "UndefinedOpenIDNamespace exception is deprecated." + with ShouldWarn(DeprecationWarning(warning_msg)): + warnings.simplefilter('always') + self.assertRaises(UndefinedOpenIDNamespace, self.msg.updateArgs, OPENID_NS, {'does not': 'matter'}) def _test_updateArgsNS(self, ns): update_args = { @@ -155,7 +170,10 @@ def test_updateArgsNS3(self): self._test_updateArgsNS('urn:nothing-significant') def test_setArg(self): - self.assertRaises(UndefinedOpenIDNamespace, self.msg.setArg, OPENID_NS, 'does not', 'matter') + warning_msg = "UndefinedOpenIDNamespace exception is deprecated." + with ShouldWarn(DeprecationWarning(warning_msg)): + warnings.simplefilter('always') + self.assertRaises(UndefinedOpenIDNamespace, self.msg.setArg, OPENID_NS, 'does not', 'matter') def _test_setArgNS(self, ns): key = 'Camper van Beethoven' @@ -185,7 +203,10 @@ def test_delArg(self): # right, since this case should only happen when you're # building a message from scratch and so have no default # namespace. - self.assertRaises(UndefinedOpenIDNamespace, self.msg.delArg, OPENID_NS, 'key') + warning_msg = "UndefinedOpenIDNamespace exception is deprecated." + with ShouldWarn(DeprecationWarning(warning_msg)): + warnings.simplefilter('always') + self.assertRaises(UndefinedOpenIDNamespace, self.msg.delArg, OPENID_NS, 'key') def _test_delArgNS(self, ns): key = 'Camper van Beethoven' diff --git a/openid/test/test_pape_draft5.py b/openid/test/test_pape_draft5.py index fdb783d7..0368c39d 100644 --- a/openid/test/test_pape_draft5.py +++ b/openid/test/test_pape_draft5.py @@ -1,12 +1,18 @@ from __future__ import unicode_literals import unittest +import warnings + +from testfixtures import ShouldWarn from openid.extensions import pape class PapeImportTestCase(unittest.TestCase): def test_version(self): - from openid.extensions.draft import pape5 + warning_msg = "Module 'openid.extensions.draft.pape5' is deprecated in favor of 'openid.extensions.pape'." + with ShouldWarn(DeprecationWarning(warning_msg)): + warnings.simplefilter('always') + from openid.extensions.draft import pape5 self.assertEqual(pape.Request, pape5.Request) self.assertEqual(pape.Response, pape5.Response) diff --git a/openid/test/test_urinorm.py b/openid/test/test_urinorm.py index 50b53552..b7be3d95 100644 --- a/openid/test/test_urinorm.py +++ b/openid/test/test_urinorm.py @@ -3,6 +3,9 @@ from __future__ import unicode_literals import unittest +import warnings + +from testfixtures import ShouldWarn from openid.urinorm import urinorm @@ -12,7 +15,10 @@ class UrinormTest(unittest.TestCase): def test_normalized(self): self.assertEqual(urinorm('http://example.com/'), 'http://example.com/') - self.assertEqual(urinorm(b'http://example.com/'), 'http://example.com/') + warning_msg = "Binary input for urinorm is deprecated. Use text input instead." + with ShouldWarn(DeprecationWarning(warning_msg)): + warnings.simplefilter('always') + self.assertEqual(urinorm(b'http://example.com/'), 'http://example.com/') def test_lowercase_scheme(self): self.assertEqual(urinorm('htTP://example.com/'), 'http://example.com/') From 938401234a5293a6d90c2205b4079a51b01110ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 18 May 2018 14:14:24 +0200 Subject: [PATCH 104/151] Define return type of getTypeURIs --- openid/yadis/etxrd.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/openid/yadis/etxrd.py b/openid/yadis/etxrd.py index 4eefc0bc..9e97e4c8 100644 --- a/openid/yadis/etxrd.py +++ b/openid/yadis/etxrd.py @@ -10,6 +10,7 @@ from operator import itemgetter from time import strptime +import six from lxml import etree from openid.yadis import xri @@ -266,10 +267,18 @@ def sortedURIs(service_element): def getTypeURIs(service_element): - """Given a Service element, return a list of the contents of all - Type tags""" - return [type_element.text for type_element - in service_element.findall(type_tag)] + """Given a Service element, return a list of the contents of all Type tags. + + @rtype: List[six.text_type] + """ + output = [] + for type_element in service_element.findall(type_tag): + type_uri = type_element.text + # Attribute `text` returns str in both python 2 and 3, convert to text_type in 2.7 + if not isinstance(type_uri, six.text_type): + type_uri = type_uri.decode('utf-8') + output.append(type_uri) + return output def expandService(service_element): From ce4641110182ea7d6c72bf0c02eb530966fd8a8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 18 May 2018 14:25:31 +0200 Subject: [PATCH 105/151] Clean up URL encoded messages --- openid/consumer/consumer.py | 2 +- openid/fetchers.py | 10 ++++++++++ openid/message.py | 11 +++++++++-- openid/test/test_consumer.py | 1 + openid/test/test_fetchers.py | 2 +- 5 files changed, 22 insertions(+), 4 deletions(-) diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index 43220118..c380a079 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -222,7 +222,7 @@ def makeKVPost(request_message, server_url): @rtype: L{openid.message.Message} """ # XXX: TESTME - resp = fetchers.fetch(server_url, body=request_message.toURLEncoded()) + resp = fetchers.fetch(server_url, body=request_message.toURLEncoded().encode('utf-8')) # Process response in separate function that can be shared by async code. return _httpResponseToMessage(resp, server_url) diff --git a/openid/fetchers.py b/openid/fetchers.py index 11a6b3a6..f8e6e37d 100644 --- a/openid/fetchers.py +++ b/openid/fetchers.py @@ -4,6 +4,7 @@ import sys import time +import six from six import BytesIO from six.moves.urllib.error import HTTPError as UrllibHTTPError from six.moves.urllib.request import Request, urlopen @@ -155,6 +156,7 @@ def fetch(self, url, body=None, headers=None): the way. If a body is specified, then the request will be a POST. Otherwise, it will be a GET. + @type body: six.binary_type @param headers: HTTP headers to include with the request @type headers: Dict[six.text_type, six.text_type] @@ -214,6 +216,8 @@ class Urllib2Fetcher(HTTPFetcher): urlopen = staticmethod(urlopen) def fetch(self, url, body=None, headers=None): + assert body is None or isinstance(body, six.binary_type) + if not _allowedURL(url): raise ValueError('Bad URL scheme: %r' % (url,)) @@ -309,6 +313,8 @@ def _checkURL(self, url): return _allowedURL(url) def fetch(self, url, body=None, headers=None): + assert body is None or isinstance(body, six.binary_type) + stop = int(time.time()) + self.ALLOWED_TIME off = self.ALLOWED_TIME @@ -415,6 +421,8 @@ def fetch(self, url, body=None, headers=None): @see: C{L{HTTPFetcher.fetch}} """ + assert body is None or isinstance(body, six.binary_type) + if body: method = 'POST' else: @@ -465,6 +473,8 @@ def fetch(self, url, body=None, headers=None): @see: C{L{HTTPFetcher.fetch}} """ + assert body is None or isinstance(body, six.binary_type) + if body: method = 'POST' else: diff --git a/openid/message.py b/openid/message.py index d550b299..761e4708 100644 --- a/openid/message.py +++ b/openid/message.py @@ -392,9 +392,16 @@ def toKVForm(self): return kvform.dictToKV(self.toArgs()) def toURLEncoded(self): - """Generate an x-www-urlencoded string""" + """Generate an x-www-urlencoded string + + @rtype: six.text_type + """ args = sorted(self.toPostArgs().items()) - return urlencode(args) + result = urlencode(args) + # Function `urlencode` returns str in both python 2 and 3, convert to text_type in 2.7 + if isinstance(result, six.binary_type): + result = result.decode('utf-8') + return result def _fixNS(self, namespace): """Convert an input value into the internally used values of diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index 1bac4681..9a15d680 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -105,6 +105,7 @@ def fetch(self, url, body=None, headers=None): if url in self.get_responses: return self.get_responses[url] else: + body = body.decode('utf-8') try: body.index('openid.mode=associate') except ValueError: diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index 4be8b26f..7936cb57 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -399,7 +399,7 @@ def test_post(self): with responses.RequestsMock() as rsps: rsps.add(responses.POST, 'http://example.cz/', status=200, body=b'BODY', headers={'Content-Type': 'text/plain'}) - response = self.fetcher.fetch('http://example.cz/', body='key=value') + response = self.fetcher.fetch('http://example.cz/', body=b'key=value') expected = fetchers.HTTPResponse('http://example.cz/', 200, {'Content-Type': 'text/plain'}, b'BODY') assertResponse(expected, response) From c7cacf71077a358732f9695a49f78e0231ef03c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Mon, 21 May 2018 18:24:56 +0200 Subject: [PATCH 106/151] Clean deprecation for test asserts --- openid/test/test_association_response.py | 9 +++++---- openid/test/test_consumer.py | 10 +++++----- openid/test/test_etxrd.py | 4 ++-- openid/test/test_fetchers.py | 6 +++--- openid/test/test_message.py | 2 +- openid/test/test_pape.py | 2 +- openid/test/test_server.py | 7 ++++--- openid/test/test_urinorm.py | 11 ++++++----- openid/test/test_verifydisco.py | 9 +++++---- 9 files changed, 32 insertions(+), 28 deletions(-) diff --git a/openid/test/test_association_response.py b/openid/test/test_association_response.py index 4ef61262..6b9689f3 100644 --- a/openid/test/test_association_response.py +++ b/openid/test/test_association_response.py @@ -7,6 +7,7 @@ import unittest +import six from testfixtures import LogCapture from openid.consumer.consumer import GenericConsumer, ProtocolError @@ -122,7 +123,7 @@ def test(self): keys.remove('ns') msg = mkAssocResponse(*keys) msg.setArg(OPENID_NS, 'session_type', response_session_type) - with self.assertRaisesRegexp(ProtocolError, 'Session type mismatch'): + with six.assertRaisesRegex(self, ProtocolError, 'Session type mismatch'): self.consumer._extractAssociation(msg, assoc_session) return test @@ -284,13 +285,13 @@ def test_badAssocType(self): # Make sure that the assoc type in the response is not valid # for the given session. self.assoc_session.allowed_assoc_types = [] - with self.assertRaisesRegexp(ProtocolError, 'Unsupported assoc_type for session'): + with six.assertRaisesRegex(self, ProtocolError, 'Unsupported assoc_type for session'): self.consumer._extractAssociation(self.assoc_response, self.assoc_session) def test_badExpiresIn(self): # Invalid value for expires_in should cause failure self.assoc_response.setArg(OPENID_NS, 'expires_in', 'forever') - with self.assertRaisesRegexp(ProtocolError, 'Invalid expires_in'): + with six.assertRaisesRegex(self, ProtocolError, 'Invalid expires_in'): self.consumer._extractAssociation(self.assoc_response, self.assoc_session) @@ -333,5 +334,5 @@ def test_openid2success(self): def test_badDHValues(self): sess, server_resp = self._setUpDH() server_resp.setArg(OPENID_NS, 'enc_mac_key', '\x00\x00\x00') - with self.assertRaisesRegexp(ProtocolError, 'Malformed response for'): + with six.assertRaisesRegex(self, ProtocolError, 'Malformed response for'): self.consumer._extractAssociation(server_resp, sess) diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index 9a15d680..3233da88 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -358,7 +358,7 @@ def test_notAList(self): # Value should be a single string. If it's a list, it should generate # an exception. query = {'openid.mode': ['cancel']} - with self.assertRaisesRegexp(TypeError, 'values'): + with six.assertRaisesRegex(self, TypeError, 'values'): Message.fromPostArgs(query) @@ -1401,7 +1401,7 @@ def getNextService(self, ignored): def test(): text = 'Error fetching XRDS document: Unit test' - with self.assertRaisesRegexp(DiscoveryFailure, text): + with six.assertRaisesRegex(self, DiscoveryFailure, text): self.consumer.begin('unused in this test') self.withDummyDiscovery(test, getNextService) @@ -1414,7 +1414,7 @@ def getNextService(self, ignored): def test(): text = 'No usable OpenID services found for http://a.user.url/' - with self.assertRaisesRegexp(DiscoveryFailure, text): + with six.assertRaisesRegex(self, DiscoveryFailure, text): self.consumer.begin(url) self.withDummyDiscovery(test, getNextService) @@ -1669,7 +1669,7 @@ def discoverAndVerify(claimed_id, to_match_endpoints): endpoint.server_url = "http://the-MOON.unittest/" endpoint.local_id = self.identifier self.services = [endpoint] - with self.assertRaisesRegexp(ProtocolError, text): + with six.assertRaisesRegex(self, ProtocolError, text): self.consumer._verifyDiscoveryResults(self.message, endpoint) def test_foreignDelegate(self): @@ -1690,7 +1690,7 @@ def discoverAndVerify(claimed_id, to_match_endpoints): endpoint.server_url = self.server_url endpoint.local_id = "http://unittest/juan-carlos" - with self.assertRaisesRegexp(ProtocolError, text): + with six.assertRaisesRegex(self, ProtocolError, text): self.consumer._verifyDiscoveryResults(self.message, endpoint) def test_nothingDiscovered(self): diff --git a/openid/test/test_etxrd.py b/openid/test/test_etxrd.py index f07dbb92..bc7d78f3 100644 --- a/openid/test/test_etxrd.py +++ b/openid/test/test_etxrd.py @@ -58,12 +58,12 @@ def test_minimal_xrds(self): def test_not_xrds(self): xml = '' - with self.assertRaisesRegexp(etxrd.XRDSError, 'Not an XRDS document'): + with six.assertRaisesRegex(self, etxrd.XRDSError, 'Not an XRDS document'): etxrd.parseXRDS(xml) def test_invalid_xml(self): xml = '<' - with self.assertRaisesRegexp(etxrd.XRDSError, 'Error parsing document as XML'): + with six.assertRaisesRegex(self, etxrd.XRDSError, 'Error parsing document as XML'): etxrd.parseXRDS(xml) def test_xxe(self): diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index 7936cb57..44070ff7 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -360,7 +360,7 @@ def test_error(self): assertResponse(expected, response) def test_invalid_url(self): - with self.assertRaisesRegexp(self.invalid_url_error, 'Bad URL scheme:'): + with six.assertRaisesRegex(self, self.invalid_url_error, 'Bad URL scheme:'): self.fetcher.fetch('invalid://example.cz/') def test_connection_error(self): @@ -425,12 +425,12 @@ def test_error(self): def test_invalid_url(self): invalid_url = 'invalid://example.cz/' - with self.assertRaisesRegexp(InvalidSchema, "No connection adapters were found for '" + invalid_url + "'"): + with six.assertRaisesRegex(self, InvalidSchema, "No connection adapters were found for '" + invalid_url + "'"): self.fetcher.fetch(invalid_url) def test_connection_error(self): # Test connection error with responses.RequestsMock() as rsps: rsps.add(responses.GET, 'http://example.cz/', body=ConnectionError('Name or service not known')) - with self.assertRaisesRegexp(ConnectionError, 'Name or service not known'): + with six.assertRaisesRegex(self, ConnectionError, 'Name or service not known'): self.fetcher.fetch('http://example.cz/') diff --git a/openid/test/test_message.py b/openid/test/test_message.py index 0c04219a..0e034a72 100644 --- a/openid/test/test_message.py +++ b/openid/test/test_message.py @@ -98,7 +98,7 @@ def test_hasKeyNS3(self): def test_getAliasedArgSuccess(self): msg = Message.fromPostArgs({'openid.ns.test': 'urn://foo', 'openid.test.flub': 'bogus'}) actual_uri = msg.getAliasedArg('ns.test', no_default) - self.assertEquals("urn://foo", actual_uri) + self.assertEqual("urn://foo", actual_uri) def test_getAliasedArgFailure(self): msg = Message.fromPostArgs({'openid.test.flub': 'bogus'}) diff --git a/openid/test/test_pape.py b/openid/test/test_pape.py index d6f55cd6..fcf871e2 100644 --- a/openid/test/test_pape.py +++ b/openid/test/test_pape.py @@ -47,7 +47,7 @@ def test_addAuthLevel(self): # alias is None; we expect a new one to be generated. uri = 'http://another.example.com/' self.req.addAuthLevel(uri) - self.assert_(uri in self.req.auth_level_aliases.values()) + self.assertIn(uri, self.req.auth_level_aliases.values()) # We don't expect a new alias to be generated if one already # exists. diff --git a/openid/test/test_server.py b/openid/test/test_server.py index eadbc46d..2d942e31 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -6,6 +6,7 @@ import warnings from functools import partial +import six from mock import sentinel from six.moves.urllib.parse import parse_qs, parse_qsl, urlparse from testfixtures import LogCapture, ShouldWarn, StringComparison @@ -199,7 +200,7 @@ def test_dictOfLists(self): 'openid.return_to': self.rt_url, 'openid.trust_root': self.tr_url, } - with self.assertRaisesRegexp(TypeError, 'values'): + with six.assertRaisesRegex(self, TypeError, 'values'): self.decode(args) def test_checkidImmediate(self): @@ -506,7 +507,7 @@ def test_invalidns(self): args = {'openid.ns': 'Tuesday', 'openid.mode': 'associate'} - with self.assertRaisesRegexp(server.ProtocolError, 'Tuesday') as catch: + with six.assertRaisesRegex(self, server.ProtocolError, 'Tuesday') as catch: self.decode(args) self.assertTrue(catch.exception.openid_message) @@ -788,7 +789,7 @@ def setUp(self): self.request = make_checkid_request(op_endpoint=self.op_endpoint) def test_openid2_requires_provider(self): - with self.assertRaisesRegexp(ValueError, 'CheckIDRequest requires op_endpoint'): + with six.assertRaisesRegex(self, ValueError, 'CheckIDRequest requires op_endpoint'): server.CheckIDRequest(sentinel.identity, sentinel.return_to, claimed_id=sentinel.claimed_id, message=Message(OPENID2_NS)) diff --git a/openid/test/test_urinorm.py b/openid/test/test_urinorm.py index b7be3d95..2c7aa0c5 100644 --- a/openid/test/test_urinorm.py +++ b/openid/test/test_urinorm.py @@ -5,6 +5,7 @@ import unittest import warnings +import six from testfixtures import ShouldWarn from openid.urinorm import urinorm @@ -24,7 +25,7 @@ def test_lowercase_scheme(self): self.assertEqual(urinorm('htTP://example.com/'), 'http://example.com/') def test_unsupported_scheme(self): - self.assertRaisesRegexp(ValueError, 'Not an absolute HTTP or HTTPS URI', urinorm, 'ftp://example.com/') + six.assertRaisesRegex(self, ValueError, 'Not an absolute HTTP or HTTPS URI', urinorm, 'ftp://example.com/') def test_lowercase_hostname(self): self.assertEqual(urinorm('http://exaMPLE.COm/'), 'http://example.com/') @@ -36,9 +37,9 @@ def test_empty_hostname(self): self.assertEqual(urinorm('http://username@/'), 'http://username@/') def test_invalid_hostname(self): - self.assertRaisesRegexp(ValueError, 'Invalid hostname', urinorm, 'http://.it/') - self.assertRaisesRegexp(ValueError, 'Invalid hostname', urinorm, 'http://..it/') - self.assertRaisesRegexp(ValueError, 'Not an absolute URI', urinorm, 'http:///path/') + six.assertRaisesRegex(self, ValueError, 'Invalid hostname', urinorm, 'http://.it/') + six.assertRaisesRegex(self, ValueError, 'Invalid hostname', urinorm, 'http://..it/') + six.assertRaisesRegex(self, ValueError, 'Not an absolute URI', urinorm, 'http:///path/') def test_empty_port_section(self): self.assertEqual(urinorm('http://example.com:/'), 'http://example.com/') @@ -76,7 +77,7 @@ def test_path_percent_decode_unreserved(self): self.assertEqual(urinorm('http://example.com/foo%2Dbar%2dbaz'), 'http://example.com/foo-bar-baz') def test_illegal_characters(self): - self.assertRaisesRegexp(ValueError, 'Illegal characters in URI', urinorm, 'http://.com/') + six.assertRaisesRegex(self, ValueError, 'Illegal characters in URI', urinorm, 'http://.com/') def test_realms(self): # Urinorm supports OpenID realms with * in them diff --git a/openid/test/test_verifydisco.py b/openid/test/test_verifydisco.py index 1217c650..3e749641 100644 --- a/openid/test/test_verifydisco.py +++ b/openid/test/test_verifydisco.py @@ -2,6 +2,7 @@ import unittest +import six from testfixtures import LogCapture, StringComparison from openid import message @@ -26,7 +27,7 @@ def test_openID1NoLocalID(self): msg = message.Message.fromOpenIDArgs({}) with LogCapture() as logbook: - with self.assertRaisesRegexp(consumer.ProtocolError, 'Missing required field openid.identity'): + with six.assertRaisesRegex(self, consumer.ProtocolError, 'Missing required field openid.identity'): self.consumer._verifyDiscoveryResults(msg, endpoint) self.assertEqual(logbook.records, []) @@ -47,7 +48,7 @@ def test_openID2LocalIDNoClaimed(self): 'op_endpoint': 'Phone Home', 'identity': 'Jose Lius Borges'}) with LogCapture() as logbook: - with self.assertRaisesRegexp(consumer.ProtocolError, 'openid.identity is present without'): + with six.assertRaisesRegex(self, consumer.ProtocolError, 'openid.identity is present without'): self.consumer._verifyDiscoveryResults(msg) self.assertEqual(logbook.records, []) @@ -56,7 +57,7 @@ def test_openID2NoLocalIDClaimed(self): 'op_endpoint': 'Phone Home', 'claimed_id': 'Manuel Noriega'}) with LogCapture() as logbook: - with self.assertRaisesRegexp(consumer.ProtocolError, 'openid.claimed_id is present without'): + with six.assertRaisesRegex(self, consumer.ProtocolError, 'openid.claimed_id is present without'): self.consumer._verifyDiscoveryResults(msg) self.assertEqual(logbook.records, []) @@ -147,7 +148,7 @@ def discoverAndVerify(claimed_id, to_match_endpoints): 'op_endpoint': endpoint.server_url}) with LogCapture() as logbook: - with self.assertRaisesRegexp(consumer.ProtocolError, text): + with six.assertRaisesRegex(self, consumer.ProtocolError, text): self.consumer._verifyDiscoveryResults(msg, endpoint) logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Error attempting to use .*')), From 7e8fd5130d61ee581fd677204ee96c5f60612223 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 22 May 2018 09:46:00 +0200 Subject: [PATCH 107/151] Clean unclosed file warnings --- openid/test/discoverdata.py | 3 ++- openid/test/test_accept.py | 8 +++---- openid/test/test_discover.py | 3 ++- openid/test/test_etxrd.py | 12 +++++++---- openid/test/test_fetchers.py | 42 +++++++++++++++++++++++------------- 5 files changed, 42 insertions(+), 26 deletions(-) diff --git a/openid/test/discoverdata.py b/openid/test/discoverdata.py index 1cc14849..4c8430fc 100644 --- a/openid/test/discoverdata.py +++ b/openid/test/discoverdata.py @@ -59,7 +59,8 @@ def getExampleXRDS(): def readTests(filename): - data = open(filename).read() + with open(filename) as data_file: + data = data_file.read() tests = {} for case in data.split('\f\n'): (name, content) = case.split('\n', 1) diff --git a/openid/test/test_accept.py b/openid/test/test_accept.py index 8acea20c..b10934ac 100644 --- a/openid/test/test_accept.py +++ b/openid/test/test_accept.py @@ -13,11 +13,9 @@ def getTestData(): () -> [(int, six.text_type)] """ filename = os.path.join(os.path.dirname(__file__), 'data', 'accept.txt') - i = 1 - lines = [] - for line in open(filename, 'rb'): - lines.append((i, line.decode('utf-8'))) - i += 1 + with open(filename, 'rb') as data_file: + content = data_file.read().decode('utf-8') + lines = enumerate(content.splitlines(), start=1) return lines diff --git a/openid/test/test_discover.py b/openid/test/test_discover.py index 5e56a776..75b7ab8f 100644 --- a/openid/test/test_discover.py +++ b/openid/test/test_discover.py @@ -206,7 +206,8 @@ def readDataFile(filename): module_directory = os.path.dirname(os.path.abspath(__file__)) filename = os.path.join( module_directory, 'data', 'test_discover', filename) - return open(filename, 'rb').read() + with open(filename, 'rb') as data_file: + return data_file.read() class TestDiscovery(BaseTestDiscovery): diff --git a/openid/test/test_etxrd.py b/openid/test/test_etxrd.py index bc7d78f3..c9455834 100644 --- a/openid/test/test_etxrd.py +++ b/openid/test/test_etxrd.py @@ -88,7 +88,8 @@ def test_xxe(self): class TestServiceParser(unittest.TestCase): def setUp(self): - self.xmldoc = open(XRD_FILE, 'rb').read() + with open(XRD_FILE, 'rb') as xrd_file: + self.xmldoc = xrd_file.read() self.yadis_url = 'http://unittest.url/' def _getServices(self, flt=None): @@ -156,7 +157,8 @@ def testGetSeveralForOne(self): def testNoXRDS(self): """Make sure that we get an exception when an XRDS element is not present""" - self.xmldoc = open(NOXRDS_FILE, 'rb').read() + with open(NOXRDS_FILE, 'rb') as xml_file: + self.xmldoc = xml_file.read() self.assertRaises(etxrd.XRDSError, services.applyFilter, self.yadis_url, self.xmldoc, None) def testEmpty(self): @@ -168,7 +170,8 @@ def testEmpty(self): def testNoXRD(self): """Make sure that we get an exception when there is no XRD element present.""" - self.xmldoc = open(NOXRD_FILE, 'rb').read() + with open(NOXRD_FILE, 'rb') as xml_file: + self.xmldoc = xml_file.read() self.assertRaises(etxrd.XRDSError, services.applyFilter, self.yadis_url, self.xmldoc, None) @@ -181,7 +184,8 @@ def mkTest(iname, filename, expectedID): filename = datapath(filename) def test(self): - xrds = etxrd.parseXRDS(open(filename, 'rb').read()) + with open(filename, 'rb') as xrds_file: + xrds = etxrd.parseXRDS(xrds_file.read()) self._getCanonicalID(iname, xrds, expectedID) return test diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index 44070ff7..d7b03d1a 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -176,21 +176,18 @@ def log_request(self, *args): pass def do_GET(self): - if self.path == '/closed': - self.wfile.close() + try: + http_code, location = self.cases[self.path] + except KeyError: + self.errorResponse('Bad path') else: - try: - http_code, location = self.cases[self.path] - except KeyError: - self.errorResponse('Bad path') - else: - extra_headers = [('Content-type', 'text/plain')] - if location is not None: - host, port = self.server.server_address - base = ('http://%s:%s' % (socket.getfqdn(host), port,)) - location = base + location - extra_headers.append(('Location', location)) - self._respond(http_code, extra_headers, self.path) + extra_headers = [('Content-type', 'text/plain')] + if location is not None: + host, port = self.server.server_address + base = ('http://%s:%s' % (socket.getfqdn(host), port,)) + location = base + location + extra_headers.append(('Location', location)) + self._respond(http_code, extra_headers, self.path) def do_POST(self): try: @@ -227,7 +224,6 @@ def _respond(self, http_code, extra_headers, body): self.send_header(k, v) self.end_headers() self.wfile.write(body.encode('utf-8')) - self.wfile.close() def finish(self): if not self.wfile.closed: @@ -235,6 +231,22 @@ def finish(self): self.wfile.close() self.rfile.close() + def parse_request(self): + """Contain a hook to simulate closed connection.""" + # Parse the request first + # BaseHTTPRequestHandler is old style class in 2.7 + if type(FetcherTestHandler) == type: + result = super(FetcherTestHandler, self).parse_request() + else: + result = BaseHTTPRequestHandler.parse_request(self) + # If the connection should be closed, do so. + if self.path == '/closed': + self.wfile.close() + return False + else: + # Otherwise continue as usual. + return result + class TestFetchers(unittest.TestCase): def test(self): From 0330c17624e19f20e177f3a8001c0270f6e09114 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 22 May 2018 09:39:06 +0200 Subject: [PATCH 108/151] Clean other deprecations --- examples/djopenid/server/views.py | 9 +++++++++ openid/consumer/consumer.py | 2 +- openid/kvform.py | 4 ++-- openid/server/server.py | 4 ++-- openid/store/filestore.py | 4 ++-- 5 files changed, 16 insertions(+), 7 deletions(-) diff --git a/examples/djopenid/server/views.py b/examples/djopenid/server/views.py index 55809608..77fbfa82 100644 --- a/examples/djopenid/server/views.py +++ b/examples/djopenid/server/views.py @@ -45,6 +45,9 @@ def getServer(request): Get a Server object to perform OpenID authentication. """ endpoint_url = request.build_absolute_uri(reverse('server:endpoint')) + # Method `build_absolute_uri` returns str in both python 2 and 3, convert to text_type in 2.7 + if isinstance(endpoint_url, six.binary_type): + endpoint_url = endpoint_url.decode('utf-8') return Server(getOpenIDStore(), endpoint_url) @@ -139,6 +142,9 @@ def handleCheckIDRequest(request, openid_request): if not openid_request.idSelect(): id_url = request.build_absolute_uri(reverse('server:local_id')) + # Method `build_absolute_uri` returns str in both python 2 and 3, convert to text_type in 2.7 + if isinstance(id_url, six.binary_type): + id_url = id_url.decode('utf-8') # Confirm that this server can actually vouch for that # identifier @@ -202,6 +208,9 @@ def processTrustResult(request): # The identifier that this server can vouch for response_identity = request.build_absolute_uri(reverse('server:local_id')) + # Method `build_absolute_uri` returns str in both python 2 and 3, convert to text_type in 2.7 + if isinstance(response_identity, six.binary_type): + response_identity = response_identity.decode('utf-8') # If the decision was to allow the verification, respond # accordingly. diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index c380a079..4557c13d 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -1322,7 +1322,7 @@ def _getOpenID1SessionType(self, assoc_response): # OpenID 1, but we'll accept it anyway, while issuing a # warning. if session_type == 'no-encryption': - _LOGGER.warn('OpenID server sent "no-encryption" for OpenID 1.X') + _LOGGER.warning('OpenID server sent "no-encryption" for OpenID 1.X') # Missing or empty session type is the way to flag a # 'no-encryption' response. Change the session type to diff --git a/openid/kvform.py b/openid/kvform.py index ca196c53..d26fb725 100644 --- a/openid/kvform.py +++ b/openid/kvform.py @@ -32,7 +32,7 @@ def err(msg): if strict: raise KVFormError(formatted) else: - _LOGGER.warn(formatted) + _LOGGER.warning(formatted) lines = [] for k, v in seq: @@ -87,7 +87,7 @@ def err(msg): if strict: raise KVFormError(formatted) else: - _LOGGER.warn(formatted) + _LOGGER.warning(formatted) data = string_to_text(data, "Binary values for data are deprecated. Use text input instead.") diff --git a/openid/server/server.py b/openid/server/server.py index 90315c68..a124e2b7 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -438,8 +438,8 @@ def fromMessage(klass, message, op_endpoint=UNUSED): if message.isOpenID1(): session_type = message.getArg(OPENID_NS, 'session_type') if session_type == 'no-encryption': - _LOGGER.warn('Received OpenID 1 request with a no-encryption ' - 'assocaition session type. Continuing anyway.') + _LOGGER.warning('Received OpenID 1 request with a no-encryption ' + 'assocaition session type. Continuing anyway.') elif not session_type: session_type = 'no-encryption' diff --git a/openid/store/filestore.py b/openid/store/filestore.py index 58ee2a72..2d79cf5f 100644 --- a/openid/store/filestore.py +++ b/openid/store/filestore.py @@ -279,7 +279,7 @@ def _getAssociation(self, filename): assoc_file.close() try: - association = Association.deserialize(assoc_s) + association = Association.deserialize(assoc_s.decode('utf-8')) except ValueError: _removeIfPresent(filename) return None @@ -364,7 +364,7 @@ def _allAssocs(self): # Remove expired or corrupted associations try: - association = Association.deserialize(assoc_s) + association = Association.deserialize(assoc_s.decode('utf-8')) except ValueError: _removeIfPresent(association_filename) else: From 90c1668f1f0bdbf5e59a167570b93fa92bca366e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 8 Jun 2018 09:16:23 +0200 Subject: [PATCH 109/151] Add fix_btwoc utility function --- openid/cryptutil.py | 26 +++++++++++++++++++------- openid/test/test_cryptutil.py | 20 ++++++++++++++++++++ 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/openid/cryptutil.py b/openid/cryptutil.py index 87de8374..f47a6c1b 100644 --- a/openid/cryptutil.py +++ b/openid/cryptutil.py @@ -12,8 +12,6 @@ import hmac import warnings -import six - from openid.oidutil import fromBase64, string_to_text, toBase64 __all__ = [ @@ -94,6 +92,23 @@ def bytes_to_int(value): return int(codecs.encode(value, 'hex'), 16) +def fix_btwoc(value): + """ + Utility function to ensure the output conforms the `btwoc` function output. + + See http://openid.net/specs/openid-authentication-2_0.html#btwoc for details. + + @type value: bytes or bytearray + @rtype: bytes + """ + # Conversion to bytearray is python 2/3 compatible + array = bytearray(value) + # First bit must be zero. If it isn't, the bytes must be prepended by zero byte. + if array[0] > 127: + array = bytearray([0]) + array + return bytes(array) + + def int_to_bytes(value): """ Convert integer to byte string. @@ -105,11 +120,8 @@ def int_to_bytes(value): if len(hex_value) % 2: hex_value = '0' + hex_value array = bytearray.fromhex(hex_value) - # First bit must be zero. If it isn't, the bytes must be prepended by zero byte. - # See http://openid.net/specs/openid-authentication-2_0.html#btwoc for details. - if array[0] > 127: - array = bytearray([0]) + array - return six.binary_type(array) + # The output must be `btwoc` compatible + return fix_btwoc(array) # Deprecated versions of bytes <--> int conversions diff --git a/openid/test/test_cryptutil.py b/openid/test/test_cryptutil.py index ce268cf6..eae2a1c6 100644 --- a/openid/test/test_cryptutil.py +++ b/openid/test/test_cryptutil.py @@ -52,6 +52,26 @@ def test_binaryLongConvert(self): assert s == s_prime, (n, s, s_prime) +class TestFixBtwoc(unittest.TestCase): + """Test `fix_btwoc` function.""" + + cases = ( + (b'\x00', b'\x00'), + (b'\x01', b'\x01'), + (b'\x7F', b'\x7F'), + (b'\x80', b'\x00\x80'), + (b'\xFF', b'\x00\xFF'), + ) + + def test_bytes(self): + for value, output in self.cases: + self.assertEqual(cryptutil.fix_btwoc(value), output) + + def test_bytearray(self): + for value, output in self.cases: + self.assertEqual(cryptutil.fix_btwoc(bytearray(value)), output) + + class TestBytesIntConversion(unittest.TestCase): """Test bytes <-> int conversions.""" From 82c0cba30b3b97af03f69a598d747a050d9131d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 8 Jun 2018 09:17:09 +0200 Subject: [PATCH 110/151] Fix DH MAC secret computations --- openid/dh.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/openid/dh.py b/openid/dh.py index b6c2908f..8f3b3954 100644 --- a/openid/dh.py +++ b/openid/dh.py @@ -105,5 +105,10 @@ def get_shared_secret(self, public_key): def xorSecret(self, composite, secret, hash_func): dh_shared = self.get_shared_secret(composite) + + # The DH secret must be `btwoc` compatible. + # See http://openid.net/specs/openid-authentication-2_0.html#rfc.section.8.2.3 for details. + dh_shared = cryptutil.fix_btwoc(dh_shared) + hashed_dh_shared = hash_func(dh_shared) return strxor(secret, hashed_dh_shared) From 0cca466bc0deda8c4e543b638b0aa61a014ead5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 6 Jun 2018 12:31:28 +0200 Subject: [PATCH 111/151] Refactor associate script --- contrib/associate | 49 ----------- contrib/associate.py | 204 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 204 insertions(+), 49 deletions(-) delete mode 100755 contrib/associate create mode 100755 contrib/associate.py diff --git a/contrib/associate b/contrib/associate deleted file mode 100755 index ca7e5884..00000000 --- a/contrib/associate +++ /dev/null @@ -1,49 +0,0 @@ -#!/usr/bin/env python -"""Make an OpenID Assocition request against an endpoint and print the results.""" -from __future__ import unicode_literals - -import sys -from datetime import datetime - -from openid.consumer import consumer -from openid.consumer.discover import OpenIDServiceEndpoint -from openid.store.memstore import MemoryStore - - -def verboseAssociation(assoc): - """A more verbose representation of an Association. - """ - d = assoc.__dict__ - issued_date = datetime.fromtimestamp(assoc.issued) - d['issued_iso'] = issued_date.isoformat() - fmt = """ Type: %(assoc_type)s - Handle: %(handle)s - Issued: %(issued)s [%(issued_iso)s] - Lifetime: %(lifetime)s - Secret: %(secret)r -""" - return fmt % d - - -def main(): - if not sys.argv[1:]: - print("Usage: %s ENDPOINT_URL..." % (sys.argv[0],)) - for endpoint_url in sys.argv[1:]: - print("Associating with", endpoint_url) - - # This makes it clear why j3h made AssociationManager when we - # did the ruby port. We can't invoke requestAssociation - # without these other trappings. - store = MemoryStore() - endpoint = OpenIDServiceEndpoint() - endpoint.server_url = endpoint_url - c = consumer.GenericConsumer(store) - auth_req = c.begin(endpoint) - if auth_req.assoc: - print(verboseAssociation(auth_req.assoc)) - else: - print(" ...no association.") - - -if __name__ == '__main__': - main() diff --git a/contrib/associate.py b/contrib/associate.py new file mode 100755 index 00000000..221f1ab9 --- /dev/null +++ b/contrib/associate.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python +""" +Make an OpenID association request against an endpoint and print the results. + +Usage: associate.py [options] + associate.py -h | --help + +Options: + -h, --help show this help message and exit + -a, --assoc-type=ASSOC_TYPE set custom association type [default: HMAC-SHA256] + -s, --session-type=SES_TYPE set custom session type [default: DH-SHA256] + --generate-modulus generate another modulus (may take some time) + --generator=GENERATOR set custom generator value [default: 2] + -d, --debug print debug information +""" +from __future__ import unicode_literals + +import base64 +import binascii +import codecs +import logging +import sys + +import requests +import six +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric.dh import DHParameterNumbers, DHPublicNumbers, generate_parameters +from docopt import docopt + +# This script is intentionaly and completely independent on the openid library. +# That should prevent any unwanted changes in association establishing. + +DEFAULT_DH_MODULUS = int( + '155172898181473697471232257763715539915724801966915404479707795314057629378541917580651227423698188993727816152646' + '631438561595825688188889951272158842675419950341258706556549803580104870537681476726513255747040765857479291291572' + '334510643245094715007229621094194349783925984760375594985848253359305585439638443' +) +DEFAULT_DH_GENERATOR = 2 + +OPENID2_NS = 'http://specs.openid.net/auth/2.0' + + +######################################################################################################################## +# Utilities copied from the openid library +def int_to_bytes(value): + """Convert integer -> bytes.""" + hex_value = '{:x}'.format(value) + if len(hex_value) % 2: + hex_value = '0' + hex_value + array = bytearray.fromhex(hex_value) + # First bit must be zero. If it isn't, the bytes must be prepended by zero byte. + # See http://openid.net/specs/openid-authentication-2_0.html#btwoc for details. + if array[0] > 127: + array = bytearray([0]) + array + return six.binary_type(array) + + +def int_to_base64(number): + """Convert int -> base64.""" + number_bytes = int_to_bytes(number) + return binascii.b2a_base64(number_bytes)[:-1].decode('utf-8') + + +def base64_to_int(value): + binary_value = binascii.a2b_base64(value) + return int(codecs.encode(binary_value, 'hex'), 16) + + +def strxor(x, y): + if len(x) != len(y): + raise ValueError('Inputs to strxor must have the same length') + + if six.PY2: + return b"".join(chr(ord(a) ^ ord(b)) for a, b in zip(x, y)) + else: + assert six.PY3 + return bytes((a ^ b) for a, b in zip(x, y)) + + +def parse_kv_response(response): + """Parse the key-value response.""" + decoded_data = {} + for line in response.iter_lines(): + line = line.strip() + if not line: + continue + pair = line.split(':', 1) + if not len(pair) == 2: + logging.warn("Not a key-value line: %s", line) + continue + key, value = pair + decoded_data[key.strip()] = value.strip() + return decoded_data + + +######################################################################################################################## +# The association code itself + +def parse_association_response(response): + """Parse the association response.""" + association_data = parse_kv_response(response) + if association_data.get('ns') != OPENID2_NS: + raise ValueError("Response is not an OpenID 2.0 response") + for key in ('assoc_type', 'session_type', 'assoc_handle', 'expires_in', 'dh_server_public', 'enc_mac_key'): + if key not in association_data: + raise ValueError("Required key {} is not in response.".format(key)) + return association_data + + +def establish_association(endpoint, assoc_type, session_type, generator, generate_modulus): + """Actually establish the association.""" + generator = int(generator) + if generate_modulus: + parameters = generate_parameters(generator=generator, key_size=2048, backend=default_backend()) + parameter_numbers = parameters.parameter_numbers() + else: + parameter_numbers = DHParameterNumbers(DEFAULT_DH_MODULUS, generator) + parameters = parameter_numbers.parameters(default_backend()) + private_key = parameters.generate_private_key() + public_key = int_to_base64(private_key.public_key().public_numbers().y) + logging.debug("Private key: %s", private_key.private_numbers().x) + logging.debug("Public key: %s", private_key.public_key().public_numbers().y) + + data = {'openid.ns': OPENID2_NS, + 'openid.mode': 'associate', + 'openid.assoc_type': assoc_type, + 'openid.session_type': session_type, + 'openid.dh_consumer_public': public_key} + if parameter_numbers != DHParameterNumbers(DEFAULT_DH_MODULUS, DEFAULT_DH_GENERATOR): + data['openid.dh_modulus'] = int_to_base64(parameter_numbers.p) + data['openid.dh_gen'] = int_to_base64(parameter_numbers.g) + logging.info("Query arguments: %s", data) + response = requests.post(endpoint, data=data) + + if response.status_code != 200: + if response.status_code == 400: + # Is it an error response? + error_data = parse_kv_response(response) + if error_data.get('mode') == 'error': + # It's an error response + raise ValueError("Server responded with error: {}".format(error_data.get('error'))) + raise ValueError("Response returned incorrect status code: {}".format(response.status_code)) + + association_data = parse_association_response(response) + logging.debug("Association data: %s", association_data) + if association_data['assoc_type'] != assoc_type: + raise ValueError( + "Unexpected assoc_type returned {}, expected {}".format(association_data['assoc_type'], assoc_type)) + if association_data['session_type'] != session_type: + raise ValueError( + "Unexpected session_type returned {}, expected {}".format(association_data['session_type'], session_type)) + + server_public_key = base64_to_int(association_data['dh_server_public']) + shared_secret = private_key.exchange( + DHPublicNumbers(server_public_key, parameter_numbers).public_key(default_backend())) + + # Not an ordinary DH secret is used here. + # According to http://openid.net/specs/openid-authentication-2_0.html#rfc.section.8.2.3, the first bit of + # the DH secret must be zero. If it isn't, the bytes must be prepended by zero byte before they're hashed. + shared_secret = bytearray(shared_secret) + if shared_secret[0] > 127: + shared_secret = bytearray([0]) + shared_secret + shared_secret = bytes(shared_secret) + logging.debug("DH shared secret: %s", base64.b64encode(shared_secret)) + + algorithm = getattr(hashes, assoc_type[5:]) + digest = hashes.Hash(algorithm(), backend=default_backend()) + digest.update(shared_secret) + hashed_dh_shared = digest.finalize() + + mac_key = strxor(base64.b64decode(association_data['enc_mac_key']), hashed_dh_shared) + + return {'assoc_type': association_data['assoc_type'], + 'session_type': association_data['session_type'], + 'assoc_handle': association_data['assoc_handle'], + 'expires_in': association_data['expires_in'], + 'mac_key': base64.b64encode(mac_key)} + + +def main(): + """Main script.""" + options = docopt(__doc__) + + # Set up logging + if options['--debug']: + level = logging.DEBUG + else: + level = logging.WARNING + logging.basicConfig(level=level, format='%(asctime)s %(levelname)s:%(funcName)s: %(message)s') + + try: + association = establish_association(options[''], options['--assoc-type'], options['--session-type'], + options['--generator'], options['--generate-modulus']) + except ValueError as error: + sys.stderr.write("Association failed.\n{}\n".format(error)) + sys.exit(1) + + for key, value in association.items(): + sys.stdout.write('{}: {}\n'.format(key, value)) + + +if __name__ == '__main__': + main() From fc1b20de7604c10e3f4f4a9eb7c891c5af3c9105 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 8 Jun 2018 13:19:03 +0200 Subject: [PATCH 112/151] Check empty identifiers Check if identifiers are defined instead of whether they're empty. See http://openid.net/specs/openid-authentication-2_0.html#rfc.section.9.1 --- openid/server/server.py | 4 ++-- openid/test/test_server.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/openid/server/server.py b/openid/server/server.py index a124e2b7..d4978dcf 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -579,11 +579,11 @@ def __init__(self, identity, return_to, trust_root=None, immediate=False, s = "OpenID 1 message did not contain openid.identity" raise ProtocolError(message, text=s) else: - if identity and not claimed_id: + if identity is not None and claimed_id is None: s = ("OpenID 2.0 message contained openid.identity but not " "claimed_id") raise ProtocolError(message, text=s) - elif claimed_id and not identity: + elif identity is None and claimed_id is not None: s = ("OpenID 2.0 message contained openid.claimed_id but not " "identity") raise ProtocolError(message, text=s) diff --git a/openid/test/test_server.py b/openid/test/test_server.py index 2d942e31..e3264b3b 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -284,6 +284,28 @@ def test_checkidSetupNoIdentityOpenID2(self): self.assertEqual(r.trust_root, self.tr_url) self.assertEqual(r.return_to, self.rt_url) + def test_checkidSetupEmptyIdentityOpenID2(self): + args = { + 'openid.ns': OPENID2_NS, + 'openid.mode': 'checkid_setup', + 'openid.assoc_handle': self.assoc_handle, + 'openid.return_to': self.rt_url, + 'openid.realm': self.tr_url, + 'openid.identity': '', + } + self.assertRaises(server.ProtocolError, self.decode, args) + + def test_checkidSetupEmptyClaimedIDOpenID2(self): + args = { + 'openid.ns': OPENID2_NS, + 'openid.mode': 'checkid_setup', + 'openid.assoc_handle': self.assoc_handle, + 'openid.return_to': self.rt_url, + 'openid.realm': self.tr_url, + 'openid.claimed_id': '', + } + self.assertRaises(server.ProtocolError, self.decode, args) + def test_checkidSetupNoReturnOpenID1(self): """Make sure an OpenID 1 request cannot be decoded if it lacks a return_to. From 0e8aca4126e07318efab1aa6fa38e0e3f2889609 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Thu, 31 May 2018 13:36:11 +0200 Subject: [PATCH 113/151] Use cryptography for hash algorithms in DH --- openid/consumer/consumer.py | 20 +++++++++++++++++--- openid/dh.py | 20 ++++++++++++++++++++ openid/server/server.py | 19 ++++++++++++++----- openid/test/test_consumer.py | 36 ++++++++++++++++++++++++++++++++---- openid/test/test_dh.py | 19 +++++++++++++++---- openid/test/test_server.py | 33 +++++++++++++++++++++++++++++---- tox.ini | 4 +++- 7 files changed, 130 insertions(+), 21 deletions(-) diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index 4557c13d..02dd689a 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -190,8 +190,10 @@ import copy import logging +import warnings import six +from cryptography.hazmat.primitives import hashes from six.moves.urllib.parse import parse_qsl, urldefrag, urlparse from openid import cryptutil, fetchers, oidutil, urinorm @@ -450,8 +452,16 @@ def setAssociationPreference(self, association_preferences): class DiffieHellmanSHA1ConsumerSession(object): + """Handler for Diffie-Hellman session. + + @cvar algorithm: Hash algorithm for MAC key generation. + @type algorithm: hashes.HashAlgorithm + @cvar hash_func: Hash function for MAC key generation. Deprecated attribute. + @type hash_func: function + """ session_type = 'DH-SHA1' - hash_func = staticmethod(cryptutil.sha1) + algorithm = hashes.SHA1() + hash_func = None secret_size = 20 allowed_assoc_types = ['HMAC-SHA1'] @@ -478,12 +488,16 @@ def extractSecret(self, response): enc_mac_key64 = response.getArg(OPENID_NS, 'enc_mac_key', no_default) dh_server_public = cryptutil.base64ToLong(dh_server_public64) enc_mac_key = oidutil.fromBase64(enc_mac_key64) - return self.dh.xorSecret(dh_server_public, enc_mac_key, self.hash_func) + if self.hash_func is not None: + warnings.warn("Attribute hash_func is deprecated, use algorithm instead.", DeprecationWarning) + return self.dh.xorSecret(dh_server_public, enc_mac_key, self.hash_func) + else: + return self.dh.xor_secret(dh_server_public, enc_mac_key, self.algorithm) class DiffieHellmanSHA256ConsumerSession(DiffieHellmanSHA1ConsumerSession): session_type = 'DH-SHA256' - hash_func = staticmethod(cryptutil.sha256) + algorithm = hashes.SHA256() secret_size = 32 allowed_assoc_types = ['HMAC-SHA256'] diff --git a/openid/dh.py b/openid/dh.py index 8f3b3954..8f491319 100644 --- a/openid/dh.py +++ b/openid/dh.py @@ -5,6 +5,7 @@ import six from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric.dh import DHParameterNumbers, DHPublicNumbers from openid import cryptutil @@ -104,6 +105,7 @@ def get_shared_secret(self, public_key): return self.private_key.exchange(public_numbers.public_key(default_backend())) def xorSecret(self, composite, secret, hash_func): + warnings.warn("Method 'xorSecret' is deprecated, use 'xor_secret' instead.", DeprecationWarning) dh_shared = self.get_shared_secret(composite) # The DH secret must be `btwoc` compatible. @@ -112,3 +114,21 @@ def xorSecret(self, composite, secret, hash_func): hashed_dh_shared = hash_func(dh_shared) return strxor(secret, hashed_dh_shared) + + def xor_secret(self, public_key, secret, algorithm): + """Return a XOR of a secret key and hash of a DH exchanged secret. + + @type public_key: Union[six.integer_types] + @type secret: bytes + @type algorithm: hashes.HashAlgorithm + """ + dh_shared = self.get_shared_secret(public_key) + + # The DH secret must be `btwoc` compatible. + # See http://openid.net/specs/openid-authentication-2_0.html#rfc.section.8.2.3 for details. + dh_shared = cryptutil.fix_btwoc(dh_shared) + + digest = hashes.Hash(algorithm, backend=default_backend()) + digest.update(dh_shared) + hashed_dh_shared = digest.finalize() + return strxor(secret, hashed_dh_shared) diff --git a/openid/server/server.py b/openid/server/server.py index d4978dcf..badaf95a 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -124,6 +124,7 @@ from copy import deepcopy import six +from cryptography.hazmat.primitives import hashes from openid import cryptutil, kvform, oidutil from openid.association import Association, default_negotiator, getSecretSize @@ -314,6 +315,11 @@ class DiffieHellmanSHA1ServerSession(object): session. @type session_type: six.text_type + @cvar algorithm: Hash algorithm for MAC key generation. + @type algorithm: hashes.HashAlgorithm + @cvar hash_func: Hash function for MAC key generation. Deprecated attribute. + @type hash_func: function + @ivar dh: The Diffie-Hellman algorithm values for this request @type dh: DiffieHellman @@ -326,7 +332,8 @@ class DiffieHellmanSHA1ServerSession(object): @see: AssociateRequest """ session_type = 'DH-SHA1' - hash_func = staticmethod(cryptutil.sha1) + algorithm = hashes.SHA1() + hash_func = None allowed_assoc_types = ['HMAC-SHA1'] def __init__(self, dh, consumer_pubkey): @@ -376,9 +383,11 @@ def fromMessage(cls, message): return cls(dh, consumer_pubkey) def answer(self, secret): - mac_key = self.dh.xorSecret(self.consumer_pubkey, - secret, - self.hash_func) + if self.hash_func is not None: + warnings.warn("Attribute hash_func is deprecated, use algorithm instead.", DeprecationWarning) + mac_key = self.dh.xorSecret(self.consumer_pubkey, secret, self.hash_func) + else: + mac_key = self.dh.xor_secret(self.consumer_pubkey, secret, self.algorithm) return { 'dh_server_public': self.dh.public_key, 'enc_mac_key': oidutil.toBase64(mac_key), @@ -387,7 +396,7 @@ def answer(self, secret): class DiffieHellmanSHA256ServerSession(DiffieHellmanSHA1ServerSession): session_type = 'DH-SHA256' - hash_func = staticmethod(cryptutil.sha256) + algorithm = hashes.SHA256() allowed_assoc_types = ['HMAC-SHA256'] diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index 3233da88..0796e1e6 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -3,11 +3,12 @@ import os import time import unittest +import warnings from functools import partial import six from six.moves.urllib.parse import parse_qsl, urlparse -from testfixtures import LogCapture, StringComparison +from testfixtures import LogCapture, ShouldWarn, StringComparison from openid import association, cryptutil, fetchers, kvform, oidutil from openid.consumer.consumer import (CANCEL, FAILURE, SETUP_NEEDED, SUCCESS, AuthRequest, CancelResponse, Consumer, @@ -1784,9 +1785,9 @@ def setUp(self): self.secret = os.urandom(self.session_cls.secret_size) self.enc_mac_key = oidutil.toBase64( - self.server_dh.xorSecret(cryptutil.base64ToLong(self.consumer_dh.public_key), - self.secret, - self.session_cls.hash_func)) + self.server_dh.xor_secret(cryptutil.base64ToLong(self.consumer_dh.public_key), + self.secret, + self.session_cls.algorithm)) self.consumer_session = self.session_cls(self.consumer_dh) @@ -1837,6 +1838,33 @@ class TestOpenID2SHA256(TestDiffieHellmanResponseParameters, unittest.TestCase): message_namespace = OPENID2_NS +class TestDiffieHellmanSHA1ConsumerSession(unittest.TestCase): + """Unittests of `DiffieHellmanSHA1ConsumerSession` class.""" + + def test_custom_hash_func(self): + def zero_hash(value): + return b'\x00' * 20 + + class ZeroHashConsumerSession(DiffieHellmanSHA1ConsumerSession): + hash_func = staticmethod(zero_hash) + + server_dh = DiffieHellman.fromDefaults() + consumer_dh = DiffieHellman.fromDefaults() + + msg = Message(OPENID2_NS) + msg.setArg(OPENID_NS, 'dh_server_public', server_dh.public_key) + msg.setArg(OPENID_NS, 'enc_mac_key', oidutil.toBase64(b'Rimmer is smeg head!')) + + consumer_session = ZeroHashConsumerSession(consumer_dh) + with ShouldWarn() as captured: + warnings.simplefilter('always') + self.assertEqual(consumer_session.extractSecret(msg), b'Rimmer is smeg head!') + # There are 2 warnings, we need to check only one. + self.assertIsInstance(captured[0].message, DeprecationWarning) + self.assertEqual(six.text_type(captured[0].message), + "Attribute hash_func is deprecated, use algorithm instead.") + + class TestNoStore(unittest.TestCase): def setUp(self): self.consumer = GenericConsumer(None) diff --git a/openid/test/test_dh.py b/openid/test/test_dh.py index 110202e5..84be7a54 100644 --- a/openid/test/test_dh.py +++ b/openid/test/test_dh.py @@ -7,6 +7,7 @@ import six from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric.dh import DHPrivateNumbers, DHPublicNumbers from testfixtures import ShouldWarn @@ -131,12 +132,22 @@ def test_getSharedSecret(self): warnings.simplefilter('always') self.assertEqual(consumer_dh.getSharedSecret(self.server_public_key), bytes_to_int(self.shared_secret)) + def test_xorSecret(self): + # Test key exchange - deprecated method + server_dh = DiffieHellman.fromDefaults() + self.setup_keys(server_dh, self.server_public_key, self.server_private_key) + + warning_msg = "Method 'xorSecret' is deprecated, use 'xor_secret' instead." + with ShouldWarn(DeprecationWarning(warning_msg)): + warnings.simplefilter('always') + self.assertEqual(server_dh.xorSecret(self.consumer_public_key, self.secret, sha256), self.mac_key) + def test_exchange_server_static(self): # Test key exchange - server part with static values server_dh = DiffieHellman.fromDefaults() self.setup_keys(server_dh, self.server_public_key, self.server_private_key) - self.assertEqual(server_dh.xorSecret(self.consumer_public_key, self.secret, sha256), self.mac_key) + self.assertEqual(server_dh.xor_secret(self.consumer_public_key, self.secret, hashes.SHA256()), self.mac_key) self.assertEqual(server_dh.public_key, longToBase64(self.server_public_key)) def test_exchange_consumer_static(self): @@ -144,7 +155,7 @@ def test_exchange_consumer_static(self): consumer_dh = DiffieHellman.fromDefaults() self.setup_keys(consumer_dh, self.consumer_public_key, self.consumer_private_key) - shared_secret = consumer_dh.xorSecret(self.server_public_key, self.mac_key, sha256) + shared_secret = consumer_dh.xor_secret(self.server_public_key, self.mac_key, hashes.SHA256()) # Check secret was negotiated correctly self.assertEqual(shared_secret, self.secret) @@ -156,9 +167,9 @@ def test_exchange_dynamic(self): # Server part secret = os.urandom(32) server_dh = DiffieHellman.fromDefaults() - mac_key = server_dh.xorSecret(base64ToLong(consumer_public_key), secret, sha256) + mac_key = server_dh.xor_secret(base64ToLong(consumer_public_key), secret, hashes.SHA256()) server_public_key = server_dh.public_key # Consumer part - shared_secret = consumer_dh.xorSecret(base64ToLong(server_public_key), mac_key, sha256) + shared_secret = consumer_dh.xor_secret(base64ToLong(server_public_key), mac_key, hashes.SHA256()) # Check secret was negotiated correctly self.assertEqual(secret, shared_secret) diff --git a/openid/test/test_server.py b/openid/test/test_server.py index e3264b3b..c5cd96fe 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -7,14 +7,17 @@ from functools import partial import six +from cryptography.hazmat.primitives import hashes from mock import sentinel from six.moves.urllib.parse import parse_qs, parse_qsl, urlparse from testfixtures import LogCapture, ShouldWarn, StringComparison from openid import association, cryptutil, oidutil from openid.consumer.consumer import DiffieHellmanSHA256ConsumerSession +from openid.dh import DiffieHellman from openid.message import IDENTIFIER_SELECT, OPENID1_NS, OPENID1_URL_LIMIT, OPENID2_NS, OPENID_NS, Message, no_default from openid.server import server +from openid.server.server import DiffieHellmanSHA1ServerSession from openid.store import memstore # In general, if you edit or add tests here, try to move in the direction @@ -1301,6 +1304,30 @@ def test_invalidatehandleNo(self): self.assertEqual(r.fields.getArgs(OPENID_NS), {'is_valid': 'true'}) +class TestDiffieHellmanSHA1ServerSession(unittest.TestCase): + """Unittests of `DiffieHellmanSHA1ServerSession` class.""" + + def test_custom_hash_func(self): + def zero_hash(value): + return b'\x00' * 20 + + class ZeroHashServerSession(DiffieHellmanSHA1ServerSession): + hash_func = staticmethod(zero_hash) + + server_dh = DiffieHellman.fromDefaults() + consumer_dh = DiffieHellman.fromDefaults() + + server_session = ZeroHashServerSession(server_dh, cryptutil.base64ToLong(consumer_dh.public_key)) + result = {'dh_server_public': server_dh.public_key, 'enc_mac_key': oidutil.toBase64(b'Rimmer is smeg head!')} + with ShouldWarn() as captured: + warnings.simplefilter('always') + self.assertEqual(server_session.answer(b'Rimmer is smeg head!'), result) + # There are 2 warnings, we need to check only one. + self.assertIsInstance(captured[0].message, DeprecationWarning) + self.assertEqual(six.text_type(captured[0].message), + "Attribute hash_func is deprecated, use algorithm instead.") + + class TestAssociate(unittest.TestCase): # TODO: test DH with non-default values for modulus and gen. # (important to do because we actually had it broken for a while.) @@ -1313,8 +1340,6 @@ def setUp(self): def test_dhSHA1(self): self.assoc = self.signatory.createAssociation(dumb=False, assoc_type='HMAC-SHA1') - from openid.dh import DiffieHellman - from openid.server.server import DiffieHellmanSHA1ServerSession consumer_dh = DiffieHellman.fromDefaults() cpub = cryptutil.base64ToLong(consumer_dh.public_key) server_dh = DiffieHellman.fromDefaults() @@ -1332,7 +1357,7 @@ def test_dhSHA1(self): enc_key = oidutil.fromBase64(rfg("enc_mac_key")) spub = cryptutil.base64ToLong(rfg("dh_server_public")) - secret = consumer_dh.xorSecret(spub, enc_key, cryptutil.sha1) + secret = consumer_dh.xor_secret(spub, enc_key, hashes.SHA1()) self.assertEqual(secret, self.assoc.secret) def test_dhSHA256(self): @@ -1357,7 +1382,7 @@ def test_dhSHA256(self): enc_key = oidutil.fromBase64(rfg("enc_mac_key")) spub = cryptutil.base64ToLong(rfg("dh_server_public")) - secret = consumer_dh.xorSecret(spub, enc_key, cryptutil.sha256) + secret = consumer_dh.xor_secret(spub, enc_key, hashes.SHA256()) self.assertEqual(secret, self.assoc.secret) def test_protoError256(self): diff --git a/tox.ini b/tox.ini index 0dbed218..9e6a8947 100644 --- a/tox.ini +++ b/tox.ini @@ -17,10 +17,12 @@ extras = httplib2: httplib2 pycurl: pycurl requests: requests -passenv = CI TRAVIS TRAVIS_* PYTHONWARNINGS +passenv = CI TRAVIS TRAVIS_* setenv = DJANGO_SETTINGS_MODULE = djopenid.settings PYTHONPATH = {toxinidir}/examples:{env:PYTHONPATH:} +# For some reason, python2.7 doesn't always apply `warnings.simplefilter` correctly. Set 'all' as default to avoid. + PYTHONWARNINGS = {env:PYTHONWARNINGS:all} commands = coverage run --parallel-mode --branch --source=openid,examples --module unittest discover --start=openid djopenid: coverage run --parallel-mode --branch --source=openid,examples --module unittest discover --start={toxinidir}/examples From e4a091d235202741132e03c8d0d643d94f8eaea5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Thu, 31 May 2018 14:15:46 +0200 Subject: [PATCH 114/151] Drop hash utilities --- openid/cryptutil.py | 22 ---------------------- openid/store/filestore.py | 5 +++-- openid/test/test_dh.py | 7 ++++++- 3 files changed, 9 insertions(+), 25 deletions(-) diff --git a/openid/cryptutil.py b/openid/cryptutil.py index f47a6c1b..e2aade9f 100644 --- a/openid/cryptutil.py +++ b/openid/cryptutil.py @@ -21,8 +21,6 @@ 'hmacSha256', 'longToBase64', 'longToBinary', - 'sha1', - 'sha256', 'int_to_bytes', 'bytes_to_int', ] @@ -50,16 +48,6 @@ def hmacSha1(key, text): return hmac.new(key, text.encode('utf-8'), sha1_module).digest() -def sha1(s): - """ - Return a SHA1 hash. - - @type s: six.binary_type - @rtype: six.binary_type - """ - return sha1_module.new(s).digest() - - def hmacSha256(key, text): """ Return a SHA256 HMAC. @@ -72,16 +60,6 @@ def hmacSha256(key, text): return hmac.new(key, text.encode('utf-8'), sha256_module).digest() -def sha256(s): - """ - Return a SHA256 hash. - - @type s: six.binary_type - @rtype: six.binary_type - """ - return sha256_module.new(s).digest() - - def bytes_to_int(value): """ Convert byte string to integer. diff --git a/openid/store/filestore.py b/openid/store/filestore.py index 2d79cf5f..33e35e0c 100644 --- a/openid/store/filestore.py +++ b/openid/store/filestore.py @@ -7,9 +7,10 @@ import string import time from errno import EEXIST, ENOENT +from hashlib import sha1 from tempfile import mkstemp -from openid import cryptutil, oidutil +from openid import oidutil from openid.association import Association from openid.oidutil import string_to_text from openid.store import nonce @@ -22,7 +23,7 @@ def _safe64(s): - h64 = oidutil.toBase64(cryptutil.sha1(s.encode('utf-8'))) + h64 = oidutil.toBase64(sha1(s.encode('utf-8')).digest()) h64 = h64.replace('+', '_') h64 = h64.replace('/', '.') h64 = h64.replace('=', '') diff --git a/openid/test/test_dh.py b/openid/test/test_dh.py index 84be7a54..84e55547 100644 --- a/openid/test/test_dh.py +++ b/openid/test/test_dh.py @@ -12,7 +12,7 @@ from testfixtures import ShouldWarn from openid.constants import DEFAULT_DH_GENERATOR, DEFAULT_DH_MODULUS -from openid.cryptutil import base64ToLong, bytes_to_int, longToBase64, sha256 +from openid.cryptutil import base64ToLong, bytes_to_int, longToBase64 from openid.dh import DiffieHellman, strxor @@ -137,6 +137,11 @@ def test_xorSecret(self): server_dh = DiffieHellman.fromDefaults() self.setup_keys(server_dh, self.server_public_key, self.server_private_key) + def sha256(value): + digest = hashes.Hash(hashes.SHA256(), backend=default_backend()) + digest.update(value) + return digest.finalize() + warning_msg = "Method 'xorSecret' is deprecated, use 'xor_secret' instead." with ShouldWarn(DeprecationWarning(warning_msg)): warnings.simplefilter('always') From add0101d92f44c4c87b1f0c78ffe19b6fcb9ad9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 5 Jun 2018 11:04:43 +0200 Subject: [PATCH 115/151] Use cryptography for HMAC --- openid/association.py | 19 +++++++++++++------ openid/cryptutil.py | 40 +--------------------------------------- 2 files changed, 14 insertions(+), 45 deletions(-) diff --git a/openid/association.py b/openid/association.py index ca063bda..f29a4c37 100644 --- a/openid/association.py +++ b/openid/association.py @@ -28,9 +28,12 @@ import time import six +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.constant_time import bytes_eq +from cryptography.hazmat.primitives.hmac import HMAC -from openid import cryptutil, kvform, oidutil +from openid import kvform, oidutil from openid.message import OPENID_NS from .oidutil import string_to_text @@ -233,6 +236,8 @@ class Association(object): is C{'HMAC-SHA1'}, but new types may be defined in the future. @type assoc_type: six.text_type + @cvar hmac_algorithms: Mapping of association type to hash algorithm. + @type hmac_algorithms: Dict[six.text_type, hashes.HashAlgorithm] @sort: __init__, fromExpiresIn, getExpiresIn, __eq__, __ne__, handle, secret, issued, lifetime, assoc_type @@ -248,9 +253,9 @@ class Association(object): 'assoc_type', ] - _macs = { - 'HMAC-SHA1': cryptutil.hmacSha1, - 'HMAC-SHA256': cryptutil.hmacSha256, + hmac_algorithms = { + 'HMAC-SHA1': hashes.SHA1(), + 'HMAC-SHA256': hashes.SHA256(), } @classmethod @@ -456,12 +461,14 @@ def sign(self, pairs): kv = kvform.seqToKV(pairs) try: - mac = self._macs[self.assoc_type] + algorithm = self.hmac_algorithms[self.assoc_type] except KeyError: raise ValueError( 'Unknown association type: %r' % (self.assoc_type,)) - return mac(self.secret, kv) + hmac = HMAC(self.secret, algorithm, backend=default_backend()) + hmac.update(kv.encode('utf-8')) + return hmac.finalize() def getMessageSignature(self, message): """Return the signature of a message. diff --git a/openid/cryptutil.py b/openid/cryptutil.py index e2aade9f..ded17d86 100644 --- a/openid/cryptutil.py +++ b/openid/cryptutil.py @@ -8,17 +8,13 @@ from __future__ import unicode_literals import codecs -import hashlib -import hmac import warnings -from openid.oidutil import fromBase64, string_to_text, toBase64 +from openid.oidutil import fromBase64, toBase64 __all__ = [ 'base64ToLong', 'binaryToLong', - 'hmacSha1', - 'hmacSha256', 'longToBase64', 'longToBinary', 'int_to_bytes', @@ -26,40 +22,6 @@ ] -class HashContainer(object): - def __init__(self, hash_constructor): - self.new = hash_constructor - self.digest_size = hash_constructor().digest_size - - -sha1_module = HashContainer(hashlib.sha1) -sha256_module = HashContainer(hashlib.sha256) - - -def hmacSha1(key, text): - """ - Return a SHA1 HMAC. - - @type key: six.binary_type - @type text: six.text_type, six.binary_type is deprecated - @rtype: six.binary_type - """ - text = string_to_text(text, "Binary values for text are deprecated. Use text input instead.") - return hmac.new(key, text.encode('utf-8'), sha1_module).digest() - - -def hmacSha256(key, text): - """ - Return a SHA256 HMAC. - - @type key: six.binary_type - @type text: six.text_type, six.binary_type is deprecated - @rtype: six.binary_type - """ - text = string_to_text(text, "Binary values for text are deprecated. Use text input instead.") - return hmac.new(key, text.encode('utf-8'), sha256_module).digest() - - def bytes_to_int(value): """ Convert byte string to integer. From b5fa6b97d988aaad9e2a8fc62fb72bf677672235 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 29 Jun 2018 18:52:30 +0200 Subject: [PATCH 116/151] Fix #21 - Update log message levels --- openid/consumer/consumer.py | 26 +++++++++++++------------- openid/consumer/discover.py | 2 +- openid/kvform.py | 4 ++-- openid/oidutil.py | 2 -- openid/server/server.py | 4 ++-- openid/server/trustroot.py | 6 +++--- openid/test/test_consumer.py | 6 +++--- openid/test/test_negotiation.py | 16 ++++++++-------- openid/test/test_rpverify.py | 4 ++-- openid/test/test_server.py | 4 ++-- openid/test/test_verifydisco.py | 6 +++--- 11 files changed, 39 insertions(+), 41 deletions(-) diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index 02dd689a..6834e93c 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -676,7 +676,7 @@ def _checkReturnTo(self, message, return_to): try: self._verifyReturnToArgs(message.toPostArgs()) except ProtocolError as why: - _LOGGER.exception("Verifying return_to arguments: %s", why) + _LOGGER.warning("Verifying return_to arguments: %s", why) return False # Check the return_to base URL against the one in the message. @@ -935,7 +935,7 @@ def _verifyDiscoveryResultsOpenID2(self, resp_msg, endpoint): try: self._verifyDiscoverySingle(endpoint, to_match) except ProtocolError as e: - _LOGGER.exception("Error attempting to use stored discovery information: %s", e) + _LOGGER.info("Unable to use stored discovery information: %s", e) _LOGGER.info("Attempting discovery to verify endpoint") endpoint = self._discoverAndVerify( to_match.claimed_id, [to_match]) @@ -978,7 +978,7 @@ def _verifyDiscoveryResultsOpenID1(self, resp_msg, endpoint): except TypeURIMismatch: self._verifyDiscoverySingle(endpoint, to_match_1_0) except ProtocolError as e: - _LOGGER.exception("Error attempting to use stored discovery information: %s", e) + _LOGGER.info("Unable to use stored discovery information: %s", e) _LOGGER.info("Attempting discovery to verify endpoint") else: return endpoint @@ -1075,9 +1075,9 @@ def _verifyDiscoveredServices(self, claimed_id, services, to_match_endpoints): # succeeded. Return this endpoint. return endpoint else: - _LOGGER.error('Discovery verification failure for %s', claimed_id) + _LOGGER.warning('Discovery verification failure for %s', claimed_id) for failure_message in failure_messages: - _LOGGER.error(' * Endpoint mismatch: %s', failure_message) + _LOGGER.warning(' * Endpoint mismatch: %s', failure_message) raise DiscoveryFailure( 'No matching endpoint found after discovering %s' @@ -1096,7 +1096,7 @@ def _checkAuth(self, message, server_url): try: response = self._makeKVPost(request, server_url) except (fetchers.HTTPFetchingError, ServerError) as e: - _LOGGER.exception('check_authentication failed: %s', e) + _LOGGER.info('check_authentication failed: %s', e) return False else: return self._processCheckAuthResponse(response, server_url) @@ -1130,14 +1130,14 @@ def _processCheckAuthResponse(self, response, server_url): if invalidate_handle is not None: _LOGGER.info('Received "invalidate_handle" from server %s', server_url) if self.store is None: - _LOGGER.error('Unexpectedly got invalidate_handle without a store!') + _LOGGER.warning('Unexpectedly got invalidate_handle without a store!') else: self.store.removeAssociation(server_url, invalidate_handle) if is_valid == 'true': return True else: - _LOGGER.error('Server responds that checkAuth call is not valid') + _LOGGER.info('Server responds that checkAuth call is not valid') return False def _getAssociation(self, endpoint): @@ -1217,7 +1217,7 @@ def _extractSupportedAssociationType(self, server_error, endpoint, # The server didn't like the association/session type # that we sent, and it sent us back a message that # might tell us how to handle it. - _LOGGER.error('Unsupported association type %s: %s', assoc_type, server_error.error_text) + _LOGGER.warning('Unsupported association type %s: %s', assoc_type, server_error.error_text) # Extract the session_type and assoc_type from the # error message @@ -1225,11 +1225,11 @@ def _extractSupportedAssociationType(self, server_error, endpoint, session_type = server_error.message.getArg(OPENID_NS, 'session_type') if assoc_type is None or session_type is None: - _LOGGER.error('Server responded with unsupported association session but did not supply a fallback.') + _LOGGER.warning('Server responded with unsupported association session but did not supply a fallback.') return None elif not self.negotiator.isAllowed(assoc_type, session_type): - _LOGGER.error('Server sent unsupported session/association type: session_type=%s, assoc_type=%s', - session_type, assoc_type) + _LOGGER.warning('Server sent unsupported session/association type: session_type=%s, assoc_type=%s', + session_type, assoc_type) return None else: return assoc_type, session_type @@ -1249,7 +1249,7 @@ def _requestAssociation(self, endpoint, assoc_type, session_type): try: response = self._makeKVPost(args, endpoint.server_url) except fetchers.HTTPFetchingError as why: - _LOGGER.exception('openid.associate request failed: %s', why) + _LOGGER.warning('openid.associate request failed: %s', why) return None try: diff --git a/openid/consumer/discover.py b/openid/consumer/discover.py index b483bc0b..8689bf77 100644 --- a/openid/consumer/discover.py +++ b/openid/consumer/discover.py @@ -422,7 +422,7 @@ def discoverXRI(iname): for service_element in services: endpoints.extend(flt.getServiceEndpoints(iname, service_element)) except XRDSError: - _LOGGER.exception('xrds error on %s', iname) + _LOGGER.info('xrds error on %s', iname) for endpoint in endpoints: # Is there a way to pass this through the filter to the endpoint diff --git a/openid/kvform.py b/openid/kvform.py index d26fb725..1d27d722 100644 --- a/openid/kvform.py +++ b/openid/kvform.py @@ -32,7 +32,7 @@ def err(msg): if strict: raise KVFormError(formatted) else: - _LOGGER.warning(formatted) + _LOGGER.debug(formatted) lines = [] for k, v in seq: @@ -87,7 +87,7 @@ def err(msg): if strict: raise KVFormError(formatted) else: - _LOGGER.warning(formatted) + _LOGGER.debug(formatted) data = string_to_text(data, "Binary values for data are deprecated. Use text input instead.") diff --git a/openid/oidutil.py b/openid/oidutil.py index 3ed32de8..0f9ff99e 100644 --- a/openid/oidutil.py +++ b/openid/oidutil.py @@ -15,8 +15,6 @@ __all__ = ['log', 'appendArgs', 'toBase64', 'fromBase64', 'autoSubmitHTML'] -_LOGGER = logging.getLogger(__name__) - def autoSubmitHTML(form, title='OpenID transaction in progress'): return """ diff --git a/openid/server/server.py b/openid/server/server.py index badaf95a..42efc8ee 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -1181,13 +1181,13 @@ def verify(self, assoc_handle, message): "Binary values for assoc_handle are deprecated. Use text input instead.") assoc = self.getAssociation(assoc_handle, dumb=True) if not assoc: - _LOGGER.error("failed to get assoc with handle %r to verify message %r", assoc_handle, message) + _LOGGER.info("failed to get assoc with handle %r to verify message %r", assoc_handle, message) return False try: valid = assoc.checkMessageSignature(message) except ValueError as ex: - _LOGGER.exception("Error in verifying %s with %s: %s", message, assoc, ex) + _LOGGER.info("Error in verifying %s with %s: %s", message, assoc, ex) return False return valid diff --git a/openid/server/trustroot.py b/openid/server/trustroot.py index a555d06d..db48ed87 100644 --- a/openid/server/trustroot.py +++ b/openid/server/trustroot.py @@ -432,12 +432,12 @@ def verifyReturnTo(realm_str, return_to, _vrfy=getAllowedReturnURLs): try: allowable_urls = _vrfy(realm.buildDiscoveryURL()) except RealmVerificationRedirected as err: - _LOGGER.exception(six.text_type(err)) + _LOGGER.info(six.text_type(err)) return False if returnToMatches(allowable_urls, return_to): return True else: - _LOGGER.error("Failed to validate return_to %r for realm %r, was not in %s", - return_to, realm_str, allowable_urls) + _LOGGER.info("Failed to validate return_to %r for realm %r, was not in %s", + return_to, realm_str, allowable_urls) return False diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index 0796e1e6..8232d05e 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -465,7 +465,7 @@ def discoverAndVerify(claimed_id, _to_match_endpoints): with LogCapture() as logbook: self.assertRaises(VerifiedError, self.consumer.complete, message, self.endpoint) - logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Error attempting to use .*')), + logbook.check(('openid.consumer.consumer', 'INFO', StringComparison('Unable to use stored discovery .*')), ('openid.consumer.consumer', 'INFO', 'Attempting discovery to verify endpoint')) @@ -604,7 +604,7 @@ def test_invalidateMissing_noStore(self): r = self.consumer._processCheckAuthResponse(response, self.server_url) self.assertTrue(r) logbook.check(('openid.consumer.consumer', 'INFO', StringComparison('Received "invalidate_handle" from .*')), - ('openid.consumer.consumer', 'ERROR', 'Unexpectedly got invalidate_handle without a store!')) + ('openid.consumer.consumer', 'WARNING', 'Unexpectedly got invalidate_handle without a store!')) def test_invalidatePresent(self): """invalidate_handle with a handle that exists @@ -1164,7 +1164,7 @@ def test_error(self): self.assertFalse(r) logbook.check(('openid.consumer.consumer', 'INFO', 'Using OpenID check_authentication'), ('openid.consumer.consumer', 'INFO', 'stuff'), - ('openid.consumer.consumer', 'ERROR', StringComparison('check_authentication failed: .*: 404'))) + ('openid.consumer.consumer', 'INFO', StringComparison('check_authentication failed: .*: 404'))) def test_bad_args(self): query = { diff --git a/openid/test/test_negotiation.py b/openid/test/test_negotiation.py index 9b23233e..0f59a2d9 100644 --- a/openid/test/test_negotiation.py +++ b/openid/test/test_negotiation.py @@ -70,8 +70,8 @@ def testEmptyAssocType(self): with LogCapture() as logbook: self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) no_fallback_msg = 'Server responded with unsupported association session but did not supply a fallback.' - logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Unsupported association type .*')), - ('openid.consumer.consumer', 'ERROR', no_fallback_msg)) + logbook.check(('openid.consumer.consumer', 'WARNING', StringComparison('Unsupported association type .*')), + ('openid.consumer.consumer', 'WARNING', no_fallback_msg)) def testEmptySessionType(self): """ @@ -88,8 +88,8 @@ def testEmptySessionType(self): with LogCapture() as logbook: self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) no_fallback_msg = 'Server responded with unsupported association session but did not supply a fallback.' - logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Unsupported association type .*')), - ('openid.consumer.consumer', 'ERROR', no_fallback_msg)) + logbook.check(('openid.consumer.consumer', 'WARNING', StringComparison('Unsupported association type .*')), + ('openid.consumer.consumer', 'WARNING', no_fallback_msg)) def testNotAllowed(self): """ @@ -112,8 +112,8 @@ def testNotAllowed(self): with LogCapture() as logbook: self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) unsupported_msg = StringComparison('Server sent unsupported session/association type: .*') - logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Unsupported association type .*')), - ('openid.consumer.consumer', 'ERROR', unsupported_msg)) + logbook.check(('openid.consumer.consumer', 'WARNING', StringComparison('Unsupported association type .*')), + ('openid.consumer.consumer', 'WARNING', unsupported_msg)) def testUnsupportedWithRetry(self): """ @@ -131,7 +131,7 @@ def testUnsupportedWithRetry(self): self.consumer.return_messages = [msg, assoc] with LogCapture() as logbook: self.assertEqual(self.consumer._negotiateAssociation(self.endpoint), assoc) - logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Unsupported association type .*'))) + logbook.check(('openid.consumer.consumer', 'WARNING', StringComparison('Unsupported association type .*'))) def testUnsupportedWithRetryAndFail(self): """ @@ -150,7 +150,7 @@ def testUnsupportedWithRetryAndFail(self): with LogCapture() as logbook: self.assertIsNone(self.consumer._negotiateAssociation(self.endpoint)) refused_msg = StringComparison('Server %s refused its .*' % self.endpoint.server_url) - logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Unsupported association type .*')), + logbook.check(('openid.consumer.consumer', 'WARNING', StringComparison('Unsupported association type .*')), ('openid.consumer.consumer', 'ERROR', refused_msg)) def testValid(self): diff --git a/openid/test/test_rpverify.py b/openid/test/test_rpverify.py index 4e707d59..82af2cf5 100644 --- a/openid/test/test_rpverify.py +++ b/openid/test/test_rpverify.py @@ -210,7 +210,7 @@ def vrfy(disco_url): with LogCapture() as logbook: self.assertFalse(trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) - logbook.check(('openid.server.trustroot', 'ERROR', StringComparison('Failed to validate return_to .*'))) + logbook.check(('openid.server.trustroot', 'INFO', StringComparison('Failed to validate return_to .*'))) def test_verifyFailIfDiscoveryRedirects(self): realm = 'http://*.example.com/' @@ -222,7 +222,7 @@ def vrfy(disco_url): with LogCapture() as logbook: self.assertFalse(trustroot.verifyReturnTo(realm, return_to, _vrfy=vrfy)) - logbook.check(('openid.server.trustroot', 'ERROR', StringComparison('Attempting to verify .*'))) + logbook.check(('openid.server.trustroot', 'INFO', StringComparison('Attempting to verify .*'))) if __name__ == '__main__': diff --git a/openid/test/test_server.py b/openid/test/test_server.py index c5cd96fe..09732d09 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -1864,7 +1864,7 @@ def test_verifyBadHandle(self): with LogCapture() as logbook: verified = self.signatory.verify(assoc_handle, signed) self.assertFalse(verified) - logbook.check(('openid.server.server', 'ERROR', StringComparison('failed to get assoc with handle .*'))) + logbook.check(('openid.server.server', 'INFO', StringComparison('failed to get assoc with handle .*'))) def test_verifyAssocMismatch(self): """Attempt to validate sign-all message with a signed-list assoc.""" @@ -1882,7 +1882,7 @@ def test_verifyAssocMismatch(self): with LogCapture() as logbook: verified = self.signatory.verify(assoc_handle, signed) self.assertFalse(verified) - logbook.check(('openid.server.server', 'ERROR', StringComparison('Error in verifying .*'))) + logbook.check(('openid.server.server', 'INFO', StringComparison('Error in verifying .*'))) def test_getAssoc(self): assoc_handle = self.makeAssoc(dumb=True) diff --git a/openid/test/test_verifydisco.py b/openid/test/test_verifydisco.py index 3e749641..f327ecd1 100644 --- a/openid/test/test_verifydisco.py +++ b/openid/test/test_verifydisco.py @@ -104,7 +104,7 @@ def test_openID2MismatchedDoesDisco(self): with LogCapture() as logbook: result = self.consumer._verifyDiscoveryResults(msg, mismatched) self.assertEqual(result, sentinel) - logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Error attempting to use .*')), + logbook.check(('openid.consumer.consumer', 'INFO', StringComparison('Unable to use stored discovery .*')), ('openid.consumer.consumer', 'INFO', 'Attempting discovery to verify endpoint')) def test_openid2UsePreDiscovered(self): @@ -151,7 +151,7 @@ def discoverAndVerify(claimed_id, to_match_endpoints): with six.assertRaisesRegex(self, consumer.ProtocolError, text): self.consumer._verifyDiscoveryResults(msg, endpoint) - logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Error attempting to use .*')), + logbook.check(('openid.consumer.consumer', 'INFO', StringComparison('Unable to use stored discovery .*')), ('openid.consumer.consumer', 'INFO', 'Attempting discovery to verify endpoint')) def test_openid1UsePreDiscovered(self): @@ -190,7 +190,7 @@ def discoverAndVerify(claimed_id, _to_match): with LogCapture() as logbook: self.assertRaises(VerifiedError, self.consumer._verifyDiscoveryResults, msg, endpoint) - logbook.check(('openid.consumer.consumer', 'ERROR', StringComparison('Error attempting to use .*')), + logbook.check(('openid.consumer.consumer', 'INFO', StringComparison('Unable to use stored discovery .*')), ('openid.consumer.consumer', 'INFO', 'Attempting discovery to verify endpoint')) def test_openid2Fragment(self): From b93c0753e81a2749bfc631b4b068a69671125fa9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 11 Jul 2018 11:16:16 +0200 Subject: [PATCH 117/151] Use base64 encoded Diffie-Hellman modulus and generator --- openid/constants.py | 9 +++------ openid/consumer/consumer.py | 6 ++---- openid/dh.py | 29 +++++++++++++++++++++++++---- openid/server/server.py | 2 -- openid/test/test_association.py | 3 ++- openid/test/test_consumer.py | 7 ++++--- openid/test/test_dh.py | 26 ++++++++++++++++++++++++-- openid/test/test_server.py | 15 ++++++--------- 8 files changed, 66 insertions(+), 31 deletions(-) diff --git a/openid/constants.py b/openid/constants.py index 8128a27f..f07107bc 100644 --- a/openid/constants.py +++ b/openid/constants.py @@ -3,9 +3,6 @@ # Default Diffie-Hellman modulus and generator. # Defined in OpenID specification http://openid.net/specs/openid-authentication-2_0.html#pvalue -DEFAULT_DH_MODULUS = int( - '155172898181473697471232257763715539915724801966915404479707795314057629378541917580651227423698188993727816152646' - '631438561595825688188889951272158842675419950341258706556549803580104870537681476726513255747040765857479291291572' - '334510643245094715007229621094194349783925984760375594985848253359305585439638443' -) -DEFAULT_DH_GENERATOR = 2 +DEFAULT_DH_MODULUS = ('ANz5OguIOXLsDhmYmsWizjEOHTdxfo2Vcbt2I3MYZuYe91ouJ4mLBX+YkcLiemOcPym2CBRYHNOyyjmG0mg3BVd9RcLn5S3I' + 'HHoXGHblzqdLFEi/368Ygo79JRnxTkXjgmY0rxlJ5bU1zIKaSDuKdiI+XUkKJX8Fvf8W8vsixYOr') +DEFAULT_DH_GENERATOR = 'Ag==' diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index 6834e93c..717eb2be 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -475,10 +475,8 @@ def getRequest(self): args = {'dh_consumer_public': self.dh.public_key} if not self.dh.usingDefaultValues(): - args.update({ - 'dh_modulus': cryptutil.longToBase64(self.dh.modulus), - 'dh_gen': cryptutil.longToBase64(self.dh.generator), - }) + modulus, generator = self.dh.parameters + args.update({'dh_modulus': modulus, 'dh_gen': generator}) return args diff --git a/openid/dh.py b/openid/dh.py index 8f491319..1993cd19 100644 --- a/openid/dh.py +++ b/openid/dh.py @@ -35,9 +35,18 @@ class DiffieHellman(object): def __init__(self, modulus, generator): """Create a new instance. - @type modulus: Union[six.integer_types] - @type generator: Union[six.integer_types] + @type modulus: six.text_type, Union[six.integer_types] are deprecated + @type generator: six.text_type, Union[six.integer_types] are deprecated """ + if isinstance(modulus, six.integer_types): + warnings.warn("Modulus should be passed as base64 encoded string.") + else: + modulus = cryptutil.base64ToLong(modulus) + if isinstance(generator, six.integer_types): + warnings.warn("Generator should be passed as base64 encoded string.") + else: + generator = cryptutil.base64ToLong(generator) + self.parameter_numbers = DHParameterNumbers(modulus, generator) parameters = self.parameter_numbers.parameters(default_backend()) self.private_key = parameters.generate_private_key() @@ -53,6 +62,7 @@ def modulus(self): @rtype: Union[six.integer_types] """ + warnings.warn("Modulus property will return base64 encoded string.", DeprecationWarning) return self.parameter_numbers.p @property @@ -61,8 +71,20 @@ def generator(self): @rtype: Union[six.integer_types] """ + warnings.warn("Generator property will return base64 encoded string.", DeprecationWarning) return self.parameter_numbers.g + @property + def parameters(self): + """Return base64 encoded modulus and generator. + + @return: Tuple with modulus and generator + @rtype: Tuple[six.text_type, six.text_type] + """ + modulus = self.parameter_numbers.p + generator = self.parameter_numbers.g + return cryptutil.longToBase64(modulus), cryptutil.longToBase64(generator) + @property def public(self): """Return the public key. @@ -81,8 +103,7 @@ def public_key(self): return cryptutil.longToBase64(self.private_key.public_key().public_numbers().y) def usingDefaultValues(self): - return (self.modulus == DEFAULT_DH_MODULUS and - self.generator == DEFAULT_DH_GENERATOR) + return self.parameters == (DEFAULT_DH_MODULUS, DEFAULT_DH_GENERATOR) def getSharedSecret(self, composite): """Return a shared secret. diff --git a/openid/server/server.py b/openid/server/server.py index 42efc8ee..aa2749a4 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -367,8 +367,6 @@ def fromMessage(cls, message): % (missing,)) if dh_modulus or dh_gen: - dh_modulus = cryptutil.base64ToLong(dh_modulus) - dh_gen = cryptutil.base64ToLong(dh_gen) dh = DiffieHellman(dh_modulus, dh_gen) else: dh = DiffieHellman.fromDefaults() diff --git a/openid/test/test_association.py b/openid/test/test_association.py index 2dd6266f..0763b12d 100644 --- a/openid/test/test_association.py +++ b/openid/test/test_association.py @@ -4,6 +4,7 @@ import unittest from openid import association +from openid.constants import DEFAULT_DH_GENERATOR from openid.consumer.consumer import DiffieHellmanSHA1ConsumerSession, PlainTextConsumerSession from openid.dh import DiffieHellman from openid.message import BARE_NS, OPENID2_NS, OPENID_NS, Message @@ -25,7 +26,7 @@ def test_roundTrip(self): def createNonstandardConsumerDH(): - nonstandard_dh = DiffieHellman(1315291, 2) + nonstandard_dh = DiffieHellman('FBHb', DEFAULT_DH_GENERATOR) return DiffieHellmanSHA1ConsumerSession(nonstandard_dh) diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index 8232d05e..8ab374f2 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -11,6 +11,7 @@ from testfixtures import LogCapture, ShouldWarn, StringComparison from openid import association, cryptutil, fetchers, kvform, oidutil +from openid.constants import DEFAULT_DH_GENERATOR from openid.consumer.consumer import (CANCEL, FAILURE, SETUP_NEEDED, SUCCESS, AuthRequest, CancelResponse, Consumer, DiffieHellmanSHA1ConsumerSession, DiffieHellmanSHA256ConsumerSession, FailureResponse, GenericConsumer, PlainTextConsumerSession, ProtocolError, @@ -125,7 +126,7 @@ def makeFastConsumerSession(consumer_session_cls=DiffieHellmanSHA256ConsumerSess """ Create custom DH object so tests run quickly. """ - dh = DiffieHellman(100389557, 2) + dh = DiffieHellman('BfvStQ==', DEFAULT_DH_GENERATOR) return consumer_session_cls(dh) @@ -1776,8 +1777,8 @@ class TestDiffieHellmanResponseParameters(object): def setUp(self): # Pre-compute DH with small prime so tests run quickly. - self.server_dh = DiffieHellman(100389557, 2) - self.consumer_dh = DiffieHellman(100389557, 2) + self.server_dh = DiffieHellman('BfvStQ==', DEFAULT_DH_GENERATOR) + self.consumer_dh = DiffieHellman('BfvStQ==', DEFAULT_DH_GENERATOR) # base64(btwoc(g ^ xb mod p)) self.dh_server_public = self.server_dh.public_key diff --git a/openid/test/test_dh.py b/openid/test/test_dh.py index 84e55547..8b8e6d6d 100644 --- a/openid/test/test_dh.py +++ b/openid/test/test_dh.py @@ -61,13 +61,35 @@ def test_strxor(self): class TestDiffieHellman(unittest.TestCase): """Test `DiffieHellman` class.""" + def test_init(self): + dh = DiffieHellman(DEFAULT_DH_MODULUS, DEFAULT_DH_GENERATOR) + self.assertTrue(dh.usingDefaultValues()) + + def test_init_int(self): + dh = DiffieHellman(base64ToLong(DEFAULT_DH_MODULUS), base64ToLong(DEFAULT_DH_GENERATOR)) + self.assertTrue(dh.usingDefaultValues()) + def test_modulus(self): dh = DiffieHellman.fromDefaults() - self.assertEqual(dh.modulus, DEFAULT_DH_MODULUS) + modulus = int('155172898181473697471232257763715539915724801966915404479707795314057629378541917580651227423698' + '188993727816152646631438561595825688188889951272158842675419950341258706556549803580104870537681' + '476726513255747040765857479291291572334510643245094715007229621094194349783925984760375594985848' + '253359305585439638443') + warning_msg = "Modulus property will return base64 encoded string." + with ShouldWarn(DeprecationWarning(warning_msg)): + warnings.simplefilter('always') + self.assertEqual(dh.modulus, modulus) def test_generator(self): dh = DiffieHellman.fromDefaults() - self.assertEqual(dh.generator, DEFAULT_DH_GENERATOR) + warning_msg = "Generator property will return base64 encoded string." + with ShouldWarn(DeprecationWarning(warning_msg)): + warnings.simplefilter('always') + self.assertEqual(dh.generator, 2) + + def test_parameters(self): + dh = DiffieHellman.fromDefaults() + self.assertEqual(dh.parameters, (DEFAULT_DH_MODULUS, DEFAULT_DH_GENERATOR)) consumer_private_key = int( '76773183260125655927407219021356850612958916567415386199501281181228346359328609688049646172182310748186340503' diff --git a/openid/test/test_server.py b/openid/test/test_server.py index 09732d09..207d56f3 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -26,11 +26,9 @@ # for more, see /etc/ssh/moduli -ALT_MODULUS = int('1423261515703355186607439952816216983770573549498844689430217675736088990483613604225135575535147900' - '4551229946895343158530081254885941985717109436635815890343316791551733211386105974742540867014420109' - '9811846875730766487278261498262568348338476437200556998366087779709990807518291581860338635288400119' - '293970087') -ALT_GEN = 5 +ALT_MODULUS = ('AMqt3ewWZ/xotfoV1TxOFTLdJFYaGi1HoSwBq+oeAHMfaSGqxAdCMR/fnmNLtxMb7hryQCYVVDiakQQl4ETojINZsBD1rSuA4pyxpbA' + 'nsZ2eAab2Om9F5dftL/aioAhQUKfQzzB8PbUdJJA1WQe0QnwjqY3x64q+8rogm7ev/oan') +ALT_GEN = 'BQ==' # Example values to be used in tests @@ -456,16 +454,15 @@ def test_associateDHModGen(self): 'openid.mode': 'associate', 'openid.session_type': 'DH-SHA1', 'openid.dh_consumer_public': "Rzup9265tw==", - 'openid.dh_modulus': cryptutil.longToBase64(ALT_MODULUS), - 'openid.dh_gen': cryptutil.longToBase64(ALT_GEN), + 'openid.dh_modulus': ALT_MODULUS, + 'openid.dh_gen': ALT_GEN, } r = self.decode(args) self.assertIsInstance(r, server.AssociateRequest) self.assertEqual(r.mode, "associate") self.assertEqual(r.session.session_type, "DH-SHA1") self.assertEqual(r.assoc_type, "HMAC-SHA1") - self.assertEqual(r.session.dh.modulus, ALT_MODULUS) - self.assertEqual(r.session.dh.generator, ALT_GEN) + self.assertEqual(r.session.dh.parameters, (ALT_MODULUS, ALT_GEN)) self.assertTrue(r.session.consumer_pubkey) def test_associateDHCorruptModGen(self): From fbef83dbaeedbed3ed1471ecaf4644337369c924 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 11 Jul 2018 13:51:59 +0200 Subject: [PATCH 118/151] Fix #13 - Base64 based API for Diffie-Hellman --- openid/consumer/consumer.py | 3 +- openid/dh.py | 21 +++++++----- openid/server/server.py | 39 ++++++++++++++-------- openid/test/test_consumer.py | 13 ++++---- openid/test/test_dh.py | 65 ++++++++++++++++-------------------- openid/test/test_server.py | 14 ++++---- 6 files changed, 82 insertions(+), 73 deletions(-) diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index 717eb2be..e1384774 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -188,6 +188,7 @@ """ from __future__ import unicode_literals +import base64 import copy import logging import warnings @@ -490,7 +491,7 @@ def extractSecret(self, response): warnings.warn("Attribute hash_func is deprecated, use algorithm instead.", DeprecationWarning) return self.dh.xorSecret(dh_server_public, enc_mac_key, self.hash_func) else: - return self.dh.xor_secret(dh_server_public, enc_mac_key, self.algorithm) + return base64.b64decode(self.dh.xor_secret(dh_server_public64, enc_mac_key64, self.algorithm)) class DiffieHellmanSHA256ConsumerSession(DiffieHellmanSHA1ConsumerSession): diff --git a/openid/dh.py b/openid/dh.py index 1993cd19..1af2d475 100644 --- a/openid/dh.py +++ b/openid/dh.py @@ -1,6 +1,7 @@ """"Utilities for Diffie-Hellman key exchange.""" from __future__ import unicode_literals +import base64 import warnings import six @@ -10,6 +11,7 @@ from openid import cryptutil from openid.constants import DEFAULT_DH_GENERATOR, DEFAULT_DH_MODULUS +from openid.oidutil import toBase64 def _xor(a_b): @@ -118,16 +120,16 @@ def getSharedSecret(self, composite): def get_shared_secret(self, public_key): """Return a shared secret. - @param public_key: Public key of the other party. - @type public_key: Union[six.integer_types] + @param public_key: Base64 encoded public key of the other party. + @type public_key: six.text_type @rtype: six.binary_type """ - public_numbers = DHPublicNumbers(public_key, self.parameter_numbers) + public_numbers = DHPublicNumbers(cryptutil.base64ToLong(public_key), self.parameter_numbers) return self.private_key.exchange(public_numbers.public_key(default_backend())) def xorSecret(self, composite, secret, hash_func): warnings.warn("Method 'xorSecret' is deprecated, use 'xor_secret' instead.", DeprecationWarning) - dh_shared = self.get_shared_secret(composite) + dh_shared = self.get_shared_secret(cryptutil.longToBase64(composite)) # The DH secret must be `btwoc` compatible. # See http://openid.net/specs/openid-authentication-2_0.html#rfc.section.8.2.3 for details. @@ -137,11 +139,14 @@ def xorSecret(self, composite, secret, hash_func): return strxor(secret, hashed_dh_shared) def xor_secret(self, public_key, secret, algorithm): - """Return a XOR of a secret key and hash of a DH exchanged secret. + """Return a base64 encoded XOR of a secret key and hash of a DH exchanged secret. - @type public_key: Union[six.integer_types] - @type secret: bytes + @param public_key: Base64 encoded public key of the other party. + @type public_key: six.text_type + @param secret: Base64 encoded secret + @type secret: six.text_type @type algorithm: hashes.HashAlgorithm + @rtype: six.text_type """ dh_shared = self.get_shared_secret(public_key) @@ -152,4 +157,4 @@ def xor_secret(self, public_key, secret, algorithm): digest = hashes.Hash(algorithm, backend=default_backend()) digest.update(dh_shared) hashed_dh_shared = digest.finalize() - return strxor(secret, hashed_dh_shared) + return toBase64(strxor(base64.b64decode(secret), hashed_dh_shared)) diff --git a/openid/server/server.py b/openid/server/server.py index aa2749a4..cf23c824 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -117,6 +117,7 @@ """ from __future__ import unicode_literals +import base64 import logging import os import time @@ -323,9 +324,8 @@ class DiffieHellmanSHA1ServerSession(object): @ivar dh: The Diffie-Hellman algorithm values for this request @type dh: DiffieHellman - @ivar consumer_pubkey: The public key sent by the consumer in the - associate request - @type consumer_pubkey: int, long in Python 2 + @ivar consumer_public_key: The public key sent by the consumer in the associate request + @type consumer_public_key: six.text_type @see: U{OpenID Specs, Mode: associate } @@ -336,9 +336,23 @@ class DiffieHellmanSHA1ServerSession(object): hash_func = None allowed_assoc_types = ['HMAC-SHA1'] - def __init__(self, dh, consumer_pubkey): + def __init__(self, dh, consumer_public_key): self.dh = dh - self.consumer_pubkey = consumer_pubkey + if isinstance(consumer_public_key, six.integer_types): + warnings.warn("Public key should be base64 encoded.", DeprecationWarning) + consumer_public_key = cryptutil.longToBase64(consumer_public_key) + # Check if the key can be decoded + try: + base64.b64decode(consumer_public_key) + except (ValueError, TypeError) as error: + raise ValueError("{!r} is not a valid base64 string: {}".format(consumer_public_key, error)) + self.consumer_public_key = consumer_public_key + + @property + def consumer_pubkey(self): + """Return consumer public key as integer.""" + warnings.warn("Attribute consumer_pubkey si deprecated, use consumer_public_key instead.", DeprecationWarning) + return cryptutil.base64ToLong(self.consumer_public_key) @classmethod def fromMessage(cls, message): @@ -371,24 +385,23 @@ def fromMessage(cls, message): else: dh = DiffieHellman.fromDefaults() - consumer_pubkey = message.getArg(OPENID_NS, 'dh_consumer_public') - if consumer_pubkey is None: + consumer_public_key = message.getArg(OPENID_NS, 'dh_consumer_public') + if consumer_public_key is None: raise ProtocolError(message, "Public key for DH-SHA1 session " "not found in message %s" % (message,)) - consumer_pubkey = cryptutil.base64ToLong(consumer_pubkey) - - return cls(dh, consumer_pubkey) + return cls(dh, consumer_public_key) def answer(self, secret): if self.hash_func is not None: warnings.warn("Attribute hash_func is deprecated, use algorithm instead.", DeprecationWarning) - mac_key = self.dh.xorSecret(self.consumer_pubkey, secret, self.hash_func) + mac_key = self.dh.xorSecret(cryptutil.base64ToLong(self.consumer_public_key), secret, self.hash_func) + mac_key = oidutil.toBase64(mac_key) else: - mac_key = self.dh.xor_secret(self.consumer_pubkey, secret, self.algorithm) + mac_key = self.dh.xor_secret(self.consumer_public_key, base64.b64encode(secret), self.algorithm) return { 'dh_server_public': self.dh.public_key, - 'enc_mac_key': oidutil.toBase64(mac_key), + 'enc_mac_key': mac_key, } diff --git a/openid/test/test_consumer.py b/openid/test/test_consumer.py index 8ab374f2..bcf9d2f5 100644 --- a/openid/test/test_consumer.py +++ b/openid/test/test_consumer.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import base64 import os import time import unittest @@ -10,7 +11,7 @@ from six.moves.urllib.parse import parse_qsl, urlparse from testfixtures import LogCapture, ShouldWarn, StringComparison -from openid import association, cryptutil, fetchers, kvform, oidutil +from openid import association, fetchers, kvform, oidutil from openid.constants import DEFAULT_DH_GENERATOR from openid.consumer.consumer import (CANCEL, FAILURE, SETUP_NEEDED, SUCCESS, AuthRequest, CancelResponse, Consumer, DiffieHellmanSHA1ConsumerSession, DiffieHellmanSHA256ConsumerSession, @@ -1783,12 +1784,10 @@ def setUp(self): # base64(btwoc(g ^ xb mod p)) self.dh_server_public = self.server_dh.public_key - self.secret = os.urandom(self.session_cls.secret_size) + self.secret = base64.b64encode(os.urandom(self.session_cls.secret_size)) - self.enc_mac_key = oidutil.toBase64( - self.server_dh.xor_secret(cryptutil.base64ToLong(self.consumer_dh.public_key), - self.secret, - self.session_cls.algorithm)) + self.enc_mac_key = self.server_dh.xor_secret(self.consumer_dh.public_key, self.secret, + self.session_cls.algorithm) self.consumer_session = self.session_cls(self.consumer_dh) @@ -1799,7 +1798,7 @@ def testExtractSecret(self): self.msg.setArg(OPENID_NS, 'enc_mac_key', self.enc_mac_key) extracted = self.consumer_session.extractSecret(self.msg) - self.assertEqual(extracted, self.secret) + self.assertEqual(extracted, base64.b64decode(self.secret)) def testAbsentServerPublic(self): self.msg.setArg(OPENID_NS, 'enc_mac_key', self.enc_mac_key) diff --git a/openid/test/test_dh.py b/openid/test/test_dh.py index 8b8e6d6d..a21491a8 100644 --- a/openid/test/test_dh.py +++ b/openid/test/test_dh.py @@ -1,6 +1,7 @@ """Test `openid.dh` module.""" from __future__ import unicode_literals +import base64 import os import unittest import warnings @@ -12,8 +13,9 @@ from testfixtures import ShouldWarn from openid.constants import DEFAULT_DH_GENERATOR, DEFAULT_DH_MODULUS -from openid.cryptutil import base64ToLong, bytes_to_int, longToBase64 +from openid.cryptutil import base64ToLong from openid.dh import DiffieHellman, strxor +from openid.oidutil import toBase64 class TestStrXor(unittest.TestCase): @@ -91,35 +93,23 @@ def test_parameters(self): dh = DiffieHellman.fromDefaults() self.assertEqual(dh.parameters, (DEFAULT_DH_MODULUS, DEFAULT_DH_GENERATOR)) - consumer_private_key = int( - '76773183260125655927407219021356850612958916567415386199501281181228346359328609688049646172182310748186340503' - '26318343789919595649515190982375134969315580266608309203790369036760020471410949003193451675532879428946682852' - '7087756147962428703119223967577366837042279080006329440425557036807436654929251188437293') - consumer_public_key = int( - '14830402392262721982219607342625341531794979311088664077137112813385301968870761946911013412944671626402638538' - '59019114967817783168739766941288204771883652891577627356203670315421489407520844320897873950439171044693921561' - '24149254347661216215110718681656349527564919668545970743829522251387472714136707262965225') - server_private_key = int( - '15467965641543992347841556205070390914637305348154825847599734515099514013537846015402306363308433241908283446' - '71248072297246966864402013185397179020027880855596392908146308184428215791914057102026401324081917190180806065' - '52997123133752764540011560986670942115061415865499463644558159755273696690932941082271979') - server_public_key = int( - '34503131980021108262326730163610830553875615642061454929962013481368582594793479022634253261703143188115239697' - '31865012494779720501092100433895935952054678007893102647432613158698447525023861310539814658911402112680185359' - '5512256481326572078983201034675082346312609787920346766733771767752145619255920370032919' - ) - shared_secret = ( - b'\x14u\xa1_k\xf6\x83\xfbp#\xc9\x8e\xd4qb#\xdc\xe0D\xfe\xbf\x08\x16\xc9\xd3\xedwr\nC&\xf2\x14\xca\x90\xcdr\xa2' - b'\xc7\x96A\x89\xb66\x8e\'W"_\xea\xa4\xd8\x97\xf7e\xdby`\x90\xe0\x8aUG\xf9x;\xc7\xb5\x9a\x1duq]\x8cn\xe5\x14' - b'\xf0\x12\xe3\xf2\x15H\xce\xebe\xd3\xea\xedu\xa8\x9d\xf9>\xfb\xdeL<0\x02\xcb\xfa\xf8\xeb)+\xc1Qn\xa3\n"\x03n' - b'\x12I\x9a\x145p\xaf\x87J\xca\x16T\xb4\xd8') - secret = b'Rimmer ordered hot gazpacho soup' - mac_key = b'\x84\x06)\x1f6\xcf\xbcA\xec\xd0\x9d\xad\xf0\xa6"\xaa\x8cl-)\x91\xccg\xc2Bl\x0c\x83\xdbZ5\xfd' + consumer_private_key = ('bVQh4Z81F5e57JCT1pmxADRktpYwIwhNjWkiIjg450sfYZOJ9Ntf4YHBhcBpkPyehdq/XL+yEWbZFig4wh2MdqES0X' + 'aOPRVl7ZzsjTNgztKUYE2mhiYQd4KMmB9uLExM72ntwcdZ3/vlb0Fq8DlIx3FhqeaYsKKTsdUW/KbJcS0=') + consumer_public_key = ('ANMxIwAeRWw5mZD3+DkoX3G6n/tuBGsjfk6R+vBW2zwve0BSlh1F0EsXlQEUuXJ+s1DQ8nFQLPYOLO0mLexXH0bSscv' + 'zhBldH+L+fxJfoL9xoTAxk7qqT659QqErhEMtQpBy7hK5L7Qb8R2NAUZ++MPxUNB71IBd6vMG6M6MueXp') + server_private_key = ('ANxFaZXkCVNESkYKFclilsm7tVIO1CNYy621Y44w19OPk7xE7zEZdttX/KfRSImecPpn+AATLhRZMuXzaq3KDFFTu9Nu' + 'hSINYml2f7xZd1+lYg6YhWiojfP3YPqLIV9sj/26O1A7pTcq6jajj/8E5P+qkr6+bSQhZ0UlZiBQUyDr') + server_public_key = ('MSJTx7cMqUBAcpLCan75t+8OSf3SZUSwivlEUYxMaHbbueKp1u4/7Fdw9sTCN3gA0iFE2dTOJpRUT4TmFomHnyIfBExdc' + 'wbkXiQIhsSnBJkGmPuAPkKFFHtB0pKET6bWZolwP5fp4lZOgM+7FIRte5OZd5XEJIN9vBYxo6NaoRc=') + shared_secret = ('FHWhX2v2g/twI8mO1HFiI9zgRP6/CBbJ0+13cgpDJvIUypDNcqLHlkGJtjaOJ1ciX+qk2Jf3Zdt5YJDgilVH+Xg7x7WaHXVxX' + 'Yxu5RTwEuPyFUjO62XT6u11qJ35PvveTDwwAsv6+OspK8FRbqMKIgNuEkmaFDVwr4dKyhZUtNg=') + secret = toBase64(b'Rimmer ordered hot gazpacho soup') + mac_key = 'hAYpHzbPvEHs0J2t8KYiqoxsLSmRzGfCQmwMg9taNf0=' def setup_keys(self, dh_object, public_key, private_key): """Set up private and public key into DiffieHellman object.""" - public_numbers = DHPublicNumbers(public_key, dh_object.parameter_numbers) - private_numbers = DHPrivateNumbers(private_key, public_numbers) + public_numbers = DHPublicNumbers(base64ToLong(public_key), dh_object.parameter_numbers) + private_numbers = DHPrivateNumbers(base64ToLong(private_key), public_numbers) dh_object.private_key = private_numbers.private_key(default_backend()) def test_public(self): @@ -128,22 +118,22 @@ def test_public(self): warning_msg = "Attribute 'public' is deprecated. Use 'public_key' instead." with ShouldWarn(DeprecationWarning(warning_msg)): warnings.simplefilter('always') - self.assertEqual(dh.public, self.server_public_key) + self.assertEqual(dh.public, base64ToLong(self.server_public_key)) def test_public_key(self): dh = DiffieHellman.fromDefaults() self.setup_keys(dh, self.server_public_key, self.server_private_key) - self.assertEqual(dh.public_key, longToBase64(self.server_public_key)) + self.assertEqual(dh.public_key, self.server_public_key) def test_get_shared_secret_server(self): server_dh = DiffieHellman.fromDefaults() self.setup_keys(server_dh, self.server_public_key, self.server_private_key) - self.assertEqual(server_dh.get_shared_secret(self.consumer_public_key), self.shared_secret) + self.assertEqual(server_dh.get_shared_secret(self.consumer_public_key), base64.b64decode(self.shared_secret)) def test_get_shared_secret_consumer(self): consumer_dh = DiffieHellman.fromDefaults() self.setup_keys(consumer_dh, self.consumer_public_key, self.consumer_private_key) - self.assertEqual(consumer_dh.get_shared_secret(self.server_public_key), self.shared_secret) + self.assertEqual(consumer_dh.get_shared_secret(self.server_public_key), base64.b64decode(self.shared_secret)) def test_getSharedSecret(self): # Test the deprecated method @@ -152,7 +142,7 @@ def test_getSharedSecret(self): warning_msg = "Method 'getSharedSecret' is deprecated in favor of 'get_shared_secret'." with ShouldWarn(DeprecationWarning(warning_msg)): warnings.simplefilter('always') - self.assertEqual(consumer_dh.getSharedSecret(self.server_public_key), bytes_to_int(self.shared_secret)) + self.assertEqual(consumer_dh.getSharedSecret(self.server_public_key), base64ToLong(self.shared_secret)) def test_xorSecret(self): # Test key exchange - deprecated method @@ -167,7 +157,8 @@ def sha256(value): warning_msg = "Method 'xorSecret' is deprecated, use 'xor_secret' instead." with ShouldWarn(DeprecationWarning(warning_msg)): warnings.simplefilter('always') - self.assertEqual(server_dh.xorSecret(self.consumer_public_key, self.secret, sha256), self.mac_key) + secret = server_dh.xorSecret(base64ToLong(self.consumer_public_key), base64.b64decode(self.secret), sha256) + self.assertEqual(secret, base64.b64decode(self.mac_key)) def test_exchange_server_static(self): # Test key exchange - server part with static values @@ -175,7 +166,7 @@ def test_exchange_server_static(self): self.setup_keys(server_dh, self.server_public_key, self.server_private_key) self.assertEqual(server_dh.xor_secret(self.consumer_public_key, self.secret, hashes.SHA256()), self.mac_key) - self.assertEqual(server_dh.public_key, longToBase64(self.server_public_key)) + self.assertEqual(server_dh.public_key, self.server_public_key) def test_exchange_consumer_static(self): # Test key exchange - consumer part with static values @@ -192,11 +183,11 @@ def test_exchange_dynamic(self): consumer_dh = DiffieHellman.fromDefaults() consumer_public_key = consumer_dh.public_key # Server part - secret = os.urandom(32) + secret = toBase64(os.urandom(32)) server_dh = DiffieHellman.fromDefaults() - mac_key = server_dh.xor_secret(base64ToLong(consumer_public_key), secret, hashes.SHA256()) + mac_key = server_dh.xor_secret(consumer_public_key, secret, hashes.SHA256()) server_public_key = server_dh.public_key # Consumer part - shared_secret = consumer_dh.xor_secret(base64ToLong(server_public_key), mac_key, hashes.SHA256()) + shared_secret = consumer_dh.xor_secret(server_public_key, mac_key, hashes.SHA256()) # Check secret was negotiated correctly self.assertEqual(secret, shared_secret) diff --git a/openid/test/test_server.py b/openid/test/test_server.py index 207d56f3..c522e08a 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -1314,7 +1314,7 @@ class ZeroHashServerSession(DiffieHellmanSHA1ServerSession): server_dh = DiffieHellman.fromDefaults() consumer_dh = DiffieHellman.fromDefaults() - server_session = ZeroHashServerSession(server_dh, cryptutil.base64ToLong(consumer_dh.public_key)) + server_session = ZeroHashServerSession(server_dh, consumer_dh.public_key) result = {'dh_server_public': server_dh.public_key, 'enc_mac_key': oidutil.toBase64(b'Rimmer is smeg head!')} with ShouldWarn() as captured: warnings.simplefilter('always') @@ -1352,10 +1352,10 @@ def test_dhSHA1(self): self.assertTrue(rfg("enc_mac_key")) self.assertTrue(rfg("dh_server_public")) - enc_key = oidutil.fromBase64(rfg("enc_mac_key")) - spub = cryptutil.base64ToLong(rfg("dh_server_public")) + enc_key = rfg("enc_mac_key") + spub = rfg("dh_server_public") secret = consumer_dh.xor_secret(spub, enc_key, hashes.SHA1()) - self.assertEqual(secret, self.assoc.secret) + self.assertEqual(secret, oidutil.toBase64(self.assoc.secret)) def test_dhSHA256(self): self.assoc = self.signatory.createAssociation( @@ -1377,10 +1377,10 @@ def test_dhSHA256(self): self.assertTrue(rfg("enc_mac_key")) self.assertTrue(rfg("dh_server_public")) - enc_key = oidutil.fromBase64(rfg("enc_mac_key")) - spub = cryptutil.base64ToLong(rfg("dh_server_public")) + enc_key = rfg("enc_mac_key") + spub = rfg("dh_server_public") secret = consumer_dh.xor_secret(spub, enc_key, hashes.SHA256()) - self.assertEqual(secret, self.assoc.secret) + self.assertEqual(secret, oidutil.toBase64(self.assoc.secret)) def test_protoError256(self): s256_session = DiffieHellmanSHA256ConsumerSession() From 37d2928e1983481b38d6dec5c82724ed1d87a8c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 11 Jul 2018 13:54:05 +0200 Subject: [PATCH 119/151] Make DiffieHellman.get_shared_secret private --- openid/dh.py | 10 +++++----- openid/test/test_dh.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/openid/dh.py b/openid/dh.py index 1af2d475..1b6f34c0 100644 --- a/openid/dh.py +++ b/openid/dh.py @@ -114,10 +114,10 @@ def getSharedSecret(self, composite): @type composite: Union[six.integer_types] @rtype: Union[six.integer_types] """ - warnings.warn("Method 'getSharedSecret' is deprecated in favor of 'get_shared_secret'.", DeprecationWarning) - return cryptutil.bytes_to_int(self.get_shared_secret(composite)) + warnings.warn("Method 'getSharedSecret' is deprecated in favor of '_get_shared_secret'.", DeprecationWarning) + return cryptutil.bytes_to_int(self._get_shared_secret(composite)) - def get_shared_secret(self, public_key): + def _get_shared_secret(self, public_key): """Return a shared secret. @param public_key: Base64 encoded public key of the other party. @@ -129,7 +129,7 @@ def get_shared_secret(self, public_key): def xorSecret(self, composite, secret, hash_func): warnings.warn("Method 'xorSecret' is deprecated, use 'xor_secret' instead.", DeprecationWarning) - dh_shared = self.get_shared_secret(cryptutil.longToBase64(composite)) + dh_shared = self._get_shared_secret(cryptutil.longToBase64(composite)) # The DH secret must be `btwoc` compatible. # See http://openid.net/specs/openid-authentication-2_0.html#rfc.section.8.2.3 for details. @@ -148,7 +148,7 @@ def xor_secret(self, public_key, secret, algorithm): @type algorithm: hashes.HashAlgorithm @rtype: six.text_type """ - dh_shared = self.get_shared_secret(public_key) + dh_shared = self._get_shared_secret(public_key) # The DH secret must be `btwoc` compatible. # See http://openid.net/specs/openid-authentication-2_0.html#rfc.section.8.2.3 for details. diff --git a/openid/test/test_dh.py b/openid/test/test_dh.py index a21491a8..b24353de 100644 --- a/openid/test/test_dh.py +++ b/openid/test/test_dh.py @@ -128,18 +128,18 @@ def test_public_key(self): def test_get_shared_secret_server(self): server_dh = DiffieHellman.fromDefaults() self.setup_keys(server_dh, self.server_public_key, self.server_private_key) - self.assertEqual(server_dh.get_shared_secret(self.consumer_public_key), base64.b64decode(self.shared_secret)) + self.assertEqual(server_dh._get_shared_secret(self.consumer_public_key), base64.b64decode(self.shared_secret)) def test_get_shared_secret_consumer(self): consumer_dh = DiffieHellman.fromDefaults() self.setup_keys(consumer_dh, self.consumer_public_key, self.consumer_private_key) - self.assertEqual(consumer_dh.get_shared_secret(self.server_public_key), base64.b64decode(self.shared_secret)) + self.assertEqual(consumer_dh._get_shared_secret(self.server_public_key), base64.b64decode(self.shared_secret)) def test_getSharedSecret(self): # Test the deprecated method consumer_dh = DiffieHellman.fromDefaults() self.setup_keys(consumer_dh, self.consumer_public_key, self.consumer_private_key) - warning_msg = "Method 'getSharedSecret' is deprecated in favor of 'get_shared_secret'." + warning_msg = "Method 'getSharedSecret' is deprecated in favor of '_get_shared_secret'." with ShouldWarn(DeprecationWarning(warning_msg)): warnings.simplefilter('always') self.assertEqual(consumer_dh.getSharedSecret(self.server_public_key), base64ToLong(self.shared_secret)) From 86519c055ec2bc77ff1da15f690e38ed75936e0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Thu, 12 Jul 2018 11:02:32 +0200 Subject: [PATCH 120/151] Changelog and version 3.0rc1 bump --- Changelog.md | 21 +++++++++++++++++++++ openid/__init__.py | 2 +- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/Changelog.md b/Changelog.md index dc6f764f..ac649836 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,5 +1,26 @@ # Changelog # +## 3.0 ## + + * Support Python3. + * Change most of the API to the text strings. UTF-8 encoded byte string should be compatible. + * Authentication methods based on SHA-256 are now preferred over SHA-1. + * Use `cryptography` library for cryptography tasks. + * Add new base64-based API for `DiffieHellman` class. + * Refactor script to negotiate association with an OpenID server. + * Decrease log levels on repetitive logs. + * Default fetcher is picked from more options. + * Remove `openid.consumer.html_parse` module. + * Remove `hmacSha*`, `randomString`, `randrange` and `sha*` functions from `openid.cryptutil`. + * A lot of refactoring and clean up. + +### Deprecation ### + * Binary strings are deprecated, unless explicitely allowed. + * `hash_func` is deprecated in favor of `algorithm` in `DiffieHellmanSHA*ServerSession` and `DiffieHellmanSHA*ConsumerSession`. + * `DiffieHellmanSHA*ServerSession.consumer_pubkey` is deprecated in favor of `consumer_public_key`. + * Functions `longToBinary` and `binaryToLong` deprecated in favor of `int_to_bytes` and `bytes_to_int`, respectively. + * Old `DiffieHellman` API is deprecated. + ## 2.3.0 ## * Prevent timing attacks on signature comparison. Thanks to Carl Howells. diff --git a/openid/__init__.py b/openid/__init__.py index e34eebfd..d74c93d0 100644 --- a/openid/__init__.py +++ b/openid/__init__.py @@ -23,7 +23,7 @@ and limitations under the License. """ -__version__ = '2.3.0' +__version__ = '3.0rc1' __all__ = [ 'association', From 0358d9eda92ab82c0b6f6d0e15d5268cc307b6ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Thu, 12 Jul 2018 11:10:56 +0200 Subject: [PATCH 121/151] Fixup readme --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index bce1d2de..9836f498 100644 --- a/README.md +++ b/README.md @@ -16,8 +16,10 @@ Includes example code and support for a variety of storage back-ends. ## REQUIREMENTS ## - - Python 2.7 + - Python 2.7, >3.4 - lxml + - six + - cryptography ## INSTALLATION ## From 814feb3b7d4fa6b95dfd136806649df25f53dd49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Thu, 19 Jul 2018 09:45:26 +0200 Subject: [PATCH 122/151] Update quality checks --- .flake8 | 5 ----- .isort.cfg | 6 ------ Makefile | 15 ++++----------- setup.cfg | 13 +++++++++++++ tox.ini | 4 ++-- 5 files changed, 19 insertions(+), 24 deletions(-) delete mode 100644 .flake8 delete mode 100644 .isort.cfg diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 843eae45..00000000 --- a/.flake8 +++ /dev/null @@ -1,5 +0,0 @@ -[flake8] -max-line-length = 120 -# Ignore E123 - enforce hang-closing instead -ignore = E123,W503 -max-complexity = 24 diff --git a/.isort.cfg b/.isort.cfg deleted file mode 100644 index 4d1707b3..00000000 --- a/.isort.cfg +++ /dev/null @@ -1,6 +0,0 @@ -[settings] -line_length = 120 -combine_as_imports = true -default_section = THIRDPARTY -known_first_party = openid -add_imports = from __future__ import unicode_literals diff --git a/Makefile b/Makefile index c1d31b0f..c97b35a5 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,10 @@ -.PHONY: test test-openid test-djopenid coverage isort check-all check-isort check-flake8 +.PHONY: all test test-openid test-djopenid coverage isort SOURCES = openid setup.py admin contrib -# Run tests by default -all: test +# Run tox by default +all: + tox test-openid: python -m unittest discover --start=openid @@ -22,11 +23,3 @@ coverage: isort: isort --recursive ${SOURCES} - -check-all: check-isort check-flake8 - -check-isort: - isort --check-only --diff --recursive ${SOURCES} - -check-flake8: - flake8 --format=pylint ${SOURCES} diff --git a/setup.cfg b/setup.cfg index 57067081..466c48c5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,16 @@ +[isort] +line_length = 120 +combine_as_imports = true +default_section = THIRDPARTY +known_first_party = openid +add_imports = from __future__ import unicode_literals + +[flake8] +max-line-length = 120 +# Ignore E123 - enforce hang-closing instead +ignore = E123,W503 +max-complexity = 24 + [sdist] force_manifest=1 formats=gztar,zip diff --git a/tox.ini b/tox.ini index 9e6a8947..2be8cffc 100644 --- a/tox.ini +++ b/tox.ini @@ -28,9 +28,9 @@ commands = djopenid: coverage run --parallel-mode --branch --source=openid,examples --module unittest discover --start={toxinidir}/examples [testenv:quality] -whitelist_externals = make basepython = python2.7 extras = quality commands = - make check-all + isort --check-only --diff --recursive openid setup.py admin contrib + flake8 --format=pylint openid setup.py admin contrib From 907c3410f082ea2b048cf84c22708195fa1d13e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Thu, 26 Jul 2018 14:38:56 +0200 Subject: [PATCH 123/151] Fix setuptools problem with unicode --- setup.py | 2 -- tox.ini | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 52bca806..621c4c8f 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- -from __future__ import unicode_literals - import os import sys diff --git a/tox.ini b/tox.ini index 2be8cffc..c548f3f5 100644 --- a/tox.ini +++ b/tox.ini @@ -32,5 +32,6 @@ basepython = python2.7 extras = quality commands = - isort --check-only --diff --recursive openid setup.py admin contrib +# setup.py is excluded from isort because distutils have problems with unicode_literals. + isort --check-only --diff --recursive openid admin contrib flake8 --format=pylint openid setup.py admin contrib From ba28662ab27f9d9f5523b68f2aaa78eef93cc5ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 1 Aug 2018 11:15:48 +0200 Subject: [PATCH 124/151] Version 3.0rc2 bump --- openid/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openid/__init__.py b/openid/__init__.py index d74c93d0..fb309945 100644 --- a/openid/__init__.py +++ b/openid/__init__.py @@ -23,7 +23,7 @@ and limitations under the License. """ -__version__ = '3.0rc1' +__version__ = '3.0rc2' __all__ = [ 'association', From 0693fc2038706a5f32e3393338c2d7d750bef1db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 14 Aug 2018 09:45:50 +0200 Subject: [PATCH 125/151] Version 3.0 bump --- openid/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openid/__init__.py b/openid/__init__.py index fb309945..60a05925 100644 --- a/openid/__init__.py +++ b/openid/__init__.py @@ -23,7 +23,7 @@ and limitations under the License. """ -__version__ = '3.0rc2' +__version__ = '3.0' __all__ = [ 'association', From 8241a199a48260744cf6cbe8cfa5edcf34f8a4c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 28 Aug 2018 14:29:06 +0200 Subject: [PATCH 126/151] Convert data values for extensions to text --- openid/extensions/ax.py | 16 ++++++++-------- openid/extensions/sreg.py | 15 ++++++--------- openid/oidutil.py | 15 +++++++++++++++ openid/test/test_ax.py | 10 ++++++++++ openid/test/test_oidutil.py | 22 +++++++++++++++++++++- openid/test/test_sreg.py | 6 ++++++ 6 files changed, 66 insertions(+), 18 deletions(-) diff --git a/openid/extensions/ax.py b/openid/extensions/ax.py index faffdb4f..79856f7c 100644 --- a/openid/extensions/ax.py +++ b/openid/extensions/ax.py @@ -8,7 +8,7 @@ from openid import extension from openid.message import OPENID_NS, NamespaceMap -from openid.oidutil import string_to_text +from openid.oidutil import force_text, string_to_text from openid.server.trustroot import TrustRoot __all__ = [ @@ -421,9 +421,9 @@ def addValue(self, type_uri, value): @param type_uri: The URI for the attribute - @param value: The value to add to the response to the relying - party for this attribute - @type value: six.text_type + @param value: The value to add to the response to the relying party for this attribute. It the value is not + a text, it will be converted. + @type value: Any @returns: None """ @@ -432,7 +432,7 @@ def addValue(self, type_uri, value): except KeyError: values = self.data[type_uri] = [] - values.append(value) + values.append(force_text(value)) def setValues(self, type_uri, values): """Set the values for the given attribute type. This replaces @@ -440,11 +440,11 @@ def setValues(self, type_uri, values): @param type_uri: The URI for the attribute - @param values: A list of values to send for this attribute. - @type values: List[six.text_type] + @param values: A list of values to send for this attribute. Values which are not text, will be converted. + @type values: List[Any] """ - self.data[type_uri] = values + self.data[type_uri] = [force_text(v) for v in values] def _getExtensionKVArgs(self, aliases=None): """Get the extension arguments for the key/value pairs diff --git a/openid/extensions/sreg.py b/openid/extensions/sreg.py index 7f9828e9..a90e4d42 100644 --- a/openid/extensions/sreg.py +++ b/openid/extensions/sreg.py @@ -42,7 +42,7 @@ from openid.extension import Extension from openid.message import NamespaceAliasRegistrationError, registerNamespaceAlias -from openid.oidutil import string_to_text +from openid.oidutil import force_text, string_to_text __all__ = [ 'SRegRequest', @@ -424,12 +424,10 @@ def extractResponse(cls, request, data): @param request: The simple registration request object @type request: SRegRequest - @param data: The simple registration data for this - response, as a dictionary from unqualified simple - registration field name to string (unicode) value. For - instance, the nickname should be stored under the key - 'nickname'. - @type data: Dict[six.text_type, six.text_type], six.binary_type is deprecated + @param data: The simple registration data for this response, as a mapping of unqualified simple registration + field name to value. For instance, the nickname should be stored under the key 'nickname'. If the value is + missing or None, it will be skipped. If the value is not a text, it will be converted. + @type data: Dict[six.text_type, Any] @returns: a simple registration response object @rtype: SRegResponse @@ -439,8 +437,7 @@ def extractResponse(cls, request, data): for field in request.allRequestedFields(): value = data.get(field) if value is not None: - value = string_to_text(value, "Binary values for data are deprecated. Use text input instead.") - self.data[field] = value + self.data[field] = force_text(value) return self # Assign getSRegArgs to a static method so that it can be diff --git a/openid/oidutil.py b/openid/oidutil.py index 0f9ff99e..884d38fa 100644 --- a/openid/oidutil.py +++ b/openid/oidutil.py @@ -162,3 +162,18 @@ def string_to_text(value, deprecate_msg): warnings.warn(deprecate_msg, DeprecationWarning) value = value.decode('utf-8') return value + + +def force_text(value): + """ + Return a text object representing value in UTF-8 encoding. + """ + if isinstance(value, six.text_type): + # It's already a text, just return it. + return value + elif isinstance(value, bytes): + # It's a byte string, decode it. + return value.decode('utf-8') + else: + # It's not a string, convert it. + return six.text_type(value) diff --git a/openid/test/test_ax.py b/openid/test/test_ax.py index 0b059618..d6d4c380 100644 --- a/openid/test/test_ax.py +++ b/openid/test/test_ax.py @@ -168,6 +168,16 @@ def test_doubleSingleton(self): def test_singletonValue(self): self.assertAXValues({'type.foo': 'urn:foo', 'value.foo': 'Westfall'}, {'urn:foo': ['Westfall']}) + def test_add_value_convert(self): + message = ax.AXKeyValueMessage() + message.addValue('http://example.com/attribute', 1492) + self.assertEqual(message.get('http://example.com/attribute'), ['1492']) + + def test_set_values_convert(self): + message = ax.AXKeyValueMessage() + message.setValues('http://example.com/attribute', [1492, True, None]) + self.assertEqual(message.get('http://example.com/attribute'), ['1492', 'True', 'None']) + class FetchRequestTest(unittest.TestCase): def setUp(self): diff --git a/openid/test/test_oidutil.py b/openid/test/test_oidutil.py index 8aacb583..9cd99439 100644 --- a/openid/test/test_oidutil.py +++ b/openid/test/test_oidutil.py @@ -12,7 +12,7 @@ from testfixtures import ShouldWarn from openid import oidutil -from openid.oidutil import string_to_text +from openid.oidutil import force_text, string_to_text class TestBase64(unittest.TestCase): @@ -179,3 +179,23 @@ def test_binary_input(self): self.assertIsInstance(result, six.text_type) self.assertEqual(result, 'ěščřž') + + +class TestForceText(unittest.TestCase): + """Test `force_text` utility function.""" + + def test_text(self): + self.assertEqual(force_text(''), '') + self.assertEqual(force_text('ascii'), 'ascii') + self.assertEqual(force_text('ůňíčóďé'), 'ůňíčóďé') + + def test_bytes(self): + self.assertEqual(force_text(b''), '') + self.assertEqual(force_text(b'ascii'), 'ascii') + self.assertEqual(force_text('ůňíčóďé'.encode('utf-8')), 'ůňíčóďé') + + def test_objects(self): + self.assertEqual(force_text(None), 'None') + self.assertEqual(force_text(14), '14') + self.assertEqual(force_text(True), 'True') + self.assertEqual(force_text(False), 'False') diff --git a/openid/test/test_sreg.py b/openid/test/test_sreg.py index 1224cbd9..d4989b53 100644 --- a/openid/test/test_sreg.py +++ b/openid/test/test_sreg.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals import unittest +from datetime import date from openid.extensions import sreg from openid.message import Message, NamespaceMap @@ -461,6 +462,11 @@ def test(self): sent_data = {'nickname': 'linusaur', 'email': 'president@whitehouse.gov', 'fullname': 'Leonhard Euler'} self.assertEqual(sreg_data_resp, sent_data) + def test_extract_response_conversion(self): + sreg_request = sreg.SRegRequest(required=['dob']) + sreg_response = sreg.SRegResponse.extractResponse(sreg_request, {'dob': date(1989, 11, 17)}) + self.assertEqual(sreg_response['dob'], '1989-11-17') + if __name__ == '__main__': unittest.main() From 065963b49eaf5e162f72510a84ffdaf565464a80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 5 Sep 2018 10:34:55 +0200 Subject: [PATCH 127/151] Allow str keys in Message under Python 2.7 --- openid/message.py | 1 - 1 file changed, 1 deletion(-) diff --git a/openid/message.py b/openid/message.py index 761e4708..554069f4 100644 --- a/openid/message.py +++ b/openid/message.py @@ -206,7 +206,6 @@ def _fromOpenIDArgs(cls, openid_args): namespaces = {} ns_args = [] for key, value in six.iteritems(openid_args): - key = string_to_text(key, "Binary keys in message creations are deprecated. Use text input instead.") value = string_to_text(value, "Binary values in message creations are deprecated. Use text input instead.") if '.' not in key: ns_alias = NULL_NAMESPACE From fc585b00fb89b2b8400ddf7f04246dd981e73ca7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 14 Nov 2018 09:54:52 +0100 Subject: [PATCH 128/151] Fix new flake8 warnings --- openid/consumer/consumer.py | 13 ++++++------- openid/extensions/ax.py | 3 +-- openid/extensions/draft/pape2.py | 5 ++--- openid/extensions/pape.py | 6 ++---- openid/extensions/sreg.py | 6 ++---- openid/server/server.py | 7 ++----- openid/server/trustroot.py | 9 ++++----- openid/store/memstore.py | 3 +-- openid/store/sqlstore.py | 3 +-- openid/test/test_server.py | 14 -------------- openid/yadis/discover.py | 3 +-- openid/yadis/xri.py | 2 +- 12 files changed, 23 insertions(+), 51 deletions(-) diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index e1384774..86acd8f5 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -909,8 +909,7 @@ def _verifyDiscoveryResultsOpenID2(self, resp_msg, endpoint): raise ProtocolError( 'openid.identity is present without openid.claimed_id') - elif (to_match.claimed_id is not None and - to_match.local_id is None): + elif (to_match.claimed_id is not None and to_match.local_id is None): raise ProtocolError( 'openid.claimed_id is present without openid.identity') @@ -1793,11 +1792,11 @@ def getReturnTo(self): def __eq__(self, other): return ( - (self.endpoint == other.endpoint) and - (self.identity_url == other.identity_url) and - (self.message == other.message) and - (self.signed_fields == other.signed_fields) and - (self.status == other.status)) + self.endpoint == other.endpoint + and self.identity_url == other.identity_url + and self.message == other.message + and self.signed_fields == other.signed_fields + and self.status == other.status) def __ne__(self, other): return not (self == other) diff --git a/openid/extensions/ax.py b/openid/extensions/ax.py index 79856f7c..e6ccaf21 100644 --- a/openid/extensions/ax.py +++ b/openid/extensions/ax.py @@ -312,8 +312,7 @@ def fromOpenIDRequest(cls, openid_request): message.getArg(OPENID_NS, 'return_to')) if not realm: - raise AXError(("Cannot validate update_url %r " + - "against absent realm") % (self.update_url,)) + raise AXError("Cannot validate update_url %r against absent realm" % self.update_url) tr = TrustRoot.parse(realm) if not tr.validateURL(self.update_url): diff --git a/openid/extensions/draft/pape2.py b/openid/extensions/draft/pape2.py index e8dec915..6e7e4565 100644 --- a/openid/extensions/draft/pape2.py +++ b/openid/extensions/draft/pape2.py @@ -35,7 +35,7 @@ AUTH_PHISHING_RESISTANT = \ 'http://schemas.openid.net/pape/policies/2007/06/phishing-resistant' -TIME_VALIDATOR = re.compile('^\d\d\d\d-\d\d-\d\dT\d\d:\d\d:\d\dZ$') +TIME_VALIDATOR = re.compile(r'^\d\d\d\d-\d\d-\d\dT\d\d:\d\d:\d\dZ$') class Request(Extension): @@ -63,8 +63,7 @@ def __init__(self, preferred_auth_policies=None, max_auth_age=None): self.max_auth_age = max_auth_age def __bool__(self): - return bool(self.preferred_auth_policies or - self.max_auth_age is not None) + return bool(self.preferred_auth_policies or self.max_auth_age is not None) def __nonzero__(self): return self.__bool__() diff --git a/openid/extensions/pape.py b/openid/extensions/pape.py index 56a0b923..b2543fad 100644 --- a/openid/extensions/pape.py +++ b/openid/extensions/pape.py @@ -36,7 +36,7 @@ AUTH_NONE = \ 'http://schemas.openid.net/pape/policies/2007/06/none' -TIME_VALIDATOR = re.compile('^\d\d\d\d-\d\d-\d\dT\d\d:\d\d:\d\dZ$') +TIME_VALIDATOR = re.compile(r'^\d\d\d\d-\d\d-\d\dT\d\d:\d\d:\d\dZ$') LEVELS_NIST = 'http://csrc.nist.gov/publications/nistpubs/800-63/SP800-63V1_0_2.pdf' LEVELS_JISA = 'http://www.jisa.or.jp/spec/auth_level.html' @@ -131,9 +131,7 @@ def __init__(self, preferred_auth_policies=None, max_auth_age=None, self.addAuthLevel(auth_level) def __bool__(self): - return bool(self.preferred_auth_policies or - self.max_auth_age is not None or - self.preferred_auth_level_types) + return bool(self.preferred_auth_policies or self.max_auth_age is not None or self.preferred_auth_level_types) def __nonzero__(self): return self.__bool__() diff --git a/openid/extensions/sreg.py b/openid/extensions/sreg.py index a90e4d42..543a8b4a 100644 --- a/openid/extensions/sreg.py +++ b/openid/extensions/sreg.py @@ -110,8 +110,7 @@ def supportsSReg(endpoint): @returns: Whether an sreg type was advertised by the endpoint @rtype: bool """ - return (endpoint.usesExtension(ns_uri_1_1) or - endpoint.usesExtension(ns_uri_1_0)) + return (endpoint.usesExtension(ns_uri_1_1) or endpoint.usesExtension(ns_uri_1_0)) class SRegNamespaceError(ValueError): @@ -294,8 +293,7 @@ def wereFieldsRequested(self): def __contains__(self, field_name): """Was this field in the request?""" - return (field_name in self.required or - field_name in self.optional) + return (field_name in self.required or field_name in self.optional) def requestField(self, field_name, required=False, strict=False): """Request the specified field from the OpenID user diff --git a/openid/server/server.py b/openid/server/server.py index cf23c824..95b847c7 100644 --- a/openid/server/server.py +++ b/openid/server/server.py @@ -367,9 +367,7 @@ def fromMessage(cls, message): """ dh_modulus = message.getArg(OPENID_NS, 'dh_modulus') dh_gen = message.getArg(OPENID_NS, 'dh_gen') - if (dh_modulus is None and dh_gen is not None or - dh_gen is None and dh_modulus is not None): - + if (dh_modulus is None and dh_gen is not None or dh_gen is None and dh_modulus is not None): if dh_modulus is None: missing = 'modulus' else: @@ -515,8 +513,7 @@ def answer(self, assoc): response.fields.updateArgs(OPENID_NS, self.session.answer(assoc.secret)) - if not (self.session.session_type == 'no-encryption' and - self.message.isOpenID1()): + if not (self.session.session_type == 'no-encryption' and self.message.isOpenID1()): # The session type "no-encryption" did not have a name # in OpenID v1, it was just omitted. response.fields.setArg( diff --git a/openid/server/trustroot.py b/openid/server/trustroot.py index db48ed87..159b60d7 100644 --- a/openid/server/trustroot.py +++ b/openid/server/trustroot.py @@ -206,8 +206,7 @@ def validateURL(self, url): if not self.wildcard: if host != self.host: return False - elif ((not host.endswith(self.host)) and - ('.' + host) != self.host): + elif ((not host.endswith(self.host)) and ('.' + host) != self.host): return False if path != self.path: @@ -379,13 +378,13 @@ def returnToMatches(allowed_return_to_urls, return_to): return_realm = TrustRoot.parse(allowed_return_to) if ( # Parses as a trust root - return_realm is not None and + return_realm is not None # Does not have a wildcard - not return_realm.wildcard and + and not return_realm.wildcard # Matches the return_to that we passed in with it - return_realm.validateURL(return_to) + and return_realm.validateURL(return_to) ): return True diff --git a/openid/store/memstore.py b/openid/store/memstore.py index 8d271d66..55876767 100644 --- a/openid/store/memstore.py +++ b/openid/store/memstore.py @@ -122,8 +122,7 @@ def cleanupAssociations(self): return removed_assocs def __eq__(self, other): - return ((self.server_assocs == other.server_assocs) and - (self.nonces == other.nonces)) + return ((self.server_assocs == other.server_assocs) and (self.nonces == other.nonces)) def __ne__(self, other): return not (self == other) diff --git a/openid/store/sqlstore.py b/openid/store/sqlstore.py index 8865e4f1..2d3b0b8a 100644 --- a/openid/store/sqlstore.py +++ b/openid/store/sqlstore.py @@ -115,8 +115,7 @@ def __init__(self, conn, associations_table=None, nonces_table=None): if hasattr(self.conn, 'IntegrityError') and hasattr(self.conn, 'OperationalError'): self.exceptions = self.conn - if not (hasattr(self.exceptions, 'IntegrityError') and - hasattr(self.exceptions, 'OperationalError')): + if not (hasattr(self.exceptions, 'IntegrityError') and hasattr(self.exceptions, 'OperationalError')): raise RuntimeError("Error using database connection module " "(Maybe it can't be imported?)") diff --git a/openid/test/test_server.py b/openid/test/test_server.py index c522e08a..f35a8bfc 100644 --- a/openid/test/test_server.py +++ b/openid/test/test_server.py @@ -486,20 +486,6 @@ def test_associateDHMissingModGen(self): } self.assertRaises(server.ProtocolError, self.decode, args) - -# def test_associateDHInvalidModGen(self): -# # test dh with properly encoded values that are not a valid -# # modulus/generator combination. -# args = { -# 'openid.mode': 'associate', -# 'openid.session_type': 'DH-SHA1', -# 'openid.dh_consumer_public': "Rzup9265tw==", -# 'openid.dh_modulus': cryptutil.longToBase64(9), -# 'openid.dh_gen': cryptutil.longToBase64(27) , -# } -# self.assertRaises(server.ProtocolError, self.decode, args) -# test_associateDHInvalidModGen.todo = "low-priority feature" - def test_associateWeirdSession(self): args = { 'openid.mode': 'associate', diff --git a/openid/yadis/discover.py b/openid/yadis/discover.py index 769fb74a..c4678a73 100644 --- a/openid/yadis/discover.py +++ b/openid/yadis/discover.py @@ -52,8 +52,7 @@ def usedYadisLocation(self): def isXRDS(self): """Is the response text supposed to be an XRDS document?""" - return (self.usedYadisLocation() or - self.content_type == YADIS_CONTENT_TYPE) + return (self.usedYadisLocation() or self.content_type == YADIS_CONTENT_TYPE) def discover(uri): diff --git a/openid/yadis/xri.py b/openid/yadis/xri.py index 2924f353..4712c6d2 100644 --- a/openid/yadis/xri.py +++ b/openid/yadis/xri.py @@ -34,7 +34,7 @@ def toIRINormal(xri): return escapeForIRI(xri) -_xref_re = re.compile('\((.*?)\)') +_xref_re = re.compile(r'\((.*?)\)') def _escape_xref(xref_match): From 14e765802f7265e84f5db01625e73e7417e99c5d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 14 Nov 2018 09:35:07 +0100 Subject: [PATCH 129/151] Fixup djopenid example --- examples/djopenid/README | 8 ++--- examples/djopenid/consumer/tests.py | 56 +++++++++++++++++++++++++++++ examples/djopenid/consumer/views.py | 4 +-- examples/djopenid/settings.py | 6 ++-- 4 files changed, 65 insertions(+), 9 deletions(-) create mode 100644 examples/djopenid/consumer/tests.py diff --git a/examples/djopenid/README b/examples/djopenid/README index e803648b..1b42d37a 100644 --- a/examples/djopenid/README +++ b/examples/djopenid/README @@ -12,9 +12,9 @@ SETUP 1. Install the OpenID library, version 2.0.0 or later. - 2. Install Django 0.95.1. + 2. Install Django. - If you find that the examples run on even newer versions of + If you find that the examples doesn't run on newer versions of Django, please let us know! 3. Modify djopenid/settings.py appropriately; you may wish to change @@ -23,11 +23,11 @@ SETUP 4. In examples/djopenid/ run: - python manage.py syncdb + python manage.py migrate 5. To run the example consumer or server, run - python manage.py runserver PORT + python manage.py runserver [PORT] where PORT is the port number on which to listen. diff --git a/examples/djopenid/consumer/tests.py b/examples/djopenid/consumer/tests.py new file mode 100644 index 00000000..c6b7b26d --- /dev/null +++ b/examples/djopenid/consumer/tests.py @@ -0,0 +1,56 @@ +"""Test the consumer.""" +from __future__ import unicode_literals + +import django +from django.test import TestCase +from openid.fetchers import setDefaultFetcher, HTTPResponse +from openid.yadis.constants import YADIS_CONTENT_TYPE + +# Allow django tests to run through discover +django.setup() + + +EXAMPLE_XRDS = b''' + + + + http://specs.openid.net/auth/2.0/server + http://example.com/ + + +''' + + +class FakeFetcher(object): + """Fake fetcher for tests.""" + + def __init__(self): + self.response = None + + def fetch(self, *args, **kwargs): + return self.response + + +class TestStartOpenID(TestCase): + """Test 'startOpenID' view.""" + + def setUp(self): + self.fetcher = FakeFetcher() + setDefaultFetcher(self.fetcher) + + def tearDown(self): + setDefaultFetcher(None) + + def test_get(self): + response = self.client.get('/consumer/') + self.assertContains(response, ' example consumer ') + + def test_post(self): + self.fetcher.response = HTTPResponse('http://example.com/', 200, {'content-type': YADIS_CONTENT_TYPE}, + EXAMPLE_XRDS) + + response = self.client.post('/consumer/', {'openid_identifier': 'http://example.com/'}) + + # Renders a POST form + self.assertContains(response, 'http://example.com/') + self.assertContains(response, 'openid.identity') diff --git a/examples/djopenid/consumer/views.py b/examples/djopenid/consumer/views.py index 6a8e2a45..0d2d8219 100644 --- a/examples/djopenid/consumer/views.py +++ b/examples/djopenid/consumer/views.py @@ -116,8 +116,8 @@ def startOpenID(request): # Compute the trust root and return URL values to build the # redirect information. - trust_root = util.request.build_absolute_uri(reverse('consumer:index')) - return_to = util.request.build_absolute_uri(reverse('consumer:return_to')) + trust_root = request.build_absolute_uri(reverse('consumer:index')) + return_to = request.build_absolute_uri(reverse('consumer:return_to')) # Send the browser to the server either by sending a redirect # URL or by generating a POST form. diff --git a/examples/djopenid/settings.py b/examples/djopenid/settings.py index dad0aa17..fd1a2ffd 100644 --- a/examples/djopenid/settings.py +++ b/examples/djopenid/settings.py @@ -24,6 +24,8 @@ } SECRET_KEY = 'u^bw6lmsa6fah0$^lz-ct$)y7x7#ag92-z+y45-8!(jk0lkavy' +SESSION_ENGINE = 'django.contrib.sessions.backends.file' +SESSION_SERIALIZER = 'django.contrib.sessions.serializers.PickleSerializer' TEMPLATES = [ { @@ -34,10 +36,8 @@ ] MIDDLEWARE = ( - 'django.middleware.common.CommonMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.middleware.doc.XViewMiddleware', + 'django.middleware.common.CommonMiddleware', ) ROOT_URLCONF = 'djopenid.urls' From c91131e427ff5a6b31fac888ee32ad70e6b020bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 8 Jan 2019 15:20:09 +0100 Subject: [PATCH 130/151] Fix associate script for P3 --- contrib/associate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/contrib/associate.py b/contrib/associate.py index 221f1ab9..d88b6980 100755 --- a/contrib/associate.py +++ b/contrib/associate.py @@ -81,7 +81,7 @@ def strxor(x, y): def parse_kv_response(response): """Parse the key-value response.""" decoded_data = {} - for line in response.iter_lines(): + for line in response.text.splitlines(): line = line.strip() if not line: continue @@ -175,7 +175,7 @@ def establish_association(endpoint, assoc_type, session_type, generator, generat 'session_type': association_data['session_type'], 'assoc_handle': association_data['assoc_handle'], 'expires_in': association_data['expires_in'], - 'mac_key': base64.b64encode(mac_key)} + 'mac_key': base64.b64encode(mac_key).decode('utf-8')} def main(): From 2b1e133ed17073e41639987053c8ca0e030fedb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Mon, 17 Jun 2019 10:50:09 +0200 Subject: [PATCH 131/151] Fix isort --- openid/__init__.py | 1 + openid/consumer/__init__.py | 1 + openid/extensions/__init__.py | 1 + openid/server/__init__.py | 1 + openid/store/__init__.py | 1 + openid/test/__init__.py | 2 ++ openid/yadis/__init__.py | 2 +- 7 files changed, 8 insertions(+), 1 deletion(-) diff --git a/openid/__init__.py b/openid/__init__.py index 60a05925..86f61745 100644 --- a/openid/__init__.py +++ b/openid/__init__.py @@ -22,6 +22,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +from __future__ import unicode_literals __version__ = '3.0' diff --git a/openid/consumer/__init__.py b/openid/consumer/__init__.py index aab51a29..bd7aa244 100644 --- a/openid/consumer/__init__.py +++ b/openid/consumer/__init__.py @@ -2,5 +2,6 @@ This package contains the portions of the library used only when implementing an OpenID consumer. """ +from __future__ import unicode_literals __all__ = ['consumer', 'discover'] diff --git a/openid/extensions/__init__.py b/openid/extensions/__init__.py index 5394e7a4..e1edd9c4 100644 --- a/openid/extensions/__init__.py +++ b/openid/extensions/__init__.py @@ -1,3 +1,4 @@ """OpenID Extension modules.""" +from __future__ import unicode_literals __all__ = ['ax', 'pape', 'sreg'] diff --git a/openid/server/__init__.py b/openid/server/__init__.py index c8fde257..b2e59d18 100644 --- a/openid/server/__init__.py +++ b/openid/server/__init__.py @@ -2,5 +2,6 @@ This package contains the portions of the library used only when implementing an OpenID server. See L{openid.server.server}. """ +from __future__ import unicode_literals __all__ = ['server', 'trustroot'] diff --git a/openid/store/__init__.py b/openid/store/__init__.py index 76509b51..02f20eaa 100644 --- a/openid/store/__init__.py +++ b/openid/store/__init__.py @@ -4,5 +4,6 @@ @sort: interface, filestore, sqlstore, memstore """ +from __future__ import unicode_literals __all__ = ['interface', 'filestore', 'sqlstore', 'memstore', 'nonce'] diff --git a/openid/test/__init__.py b/openid/test/__init__.py index a503e99a..8de16bf7 100644 --- a/openid/test/__init__.py +++ b/openid/test/__init__.py @@ -1,4 +1,6 @@ """Openid library tests.""" +from __future__ import unicode_literals + import unittest diff --git a/openid/yadis/__init__.py b/openid/yadis/__init__.py index a163f803..9fe91c7f 100644 --- a/openid/yadis/__init__.py +++ b/openid/yadis/__init__.py @@ -1,4 +1,4 @@ - +from __future__ import unicode_literals __all__ = [ 'constants', From f6719a7fa179481df547757ec20acc322d6c60c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Mon, 17 Jun 2019 10:27:56 +0200 Subject: [PATCH 132/151] Fix #29 - Fix assoc type as bytes in Association --- openid/association.py | 4 ++-- openid/test/test_association.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/openid/association.py b/openid/association.py index f29a4c37..9025ff86 100644 --- a/openid/association.py +++ b/openid/association.py @@ -327,6 +327,7 @@ def __init__(self, handle, secret, issued, lifetime, assoc_type): defined in the future. @type assoc_type: six.text_type, six.binary_type is deprecated """ + assoc_type = string_to_text(assoc_type, "Binary values for assoc_type are deprecated. Use text input instead.") if assoc_type not in all_association_types: fmt = '%r is not a supported association type' raise ValueError(fmt % (assoc_type,)) @@ -341,8 +342,7 @@ def __init__(self, handle, secret, issued, lifetime, assoc_type): self.secret = secret self.issued = issued self.lifetime = lifetime - self.assoc_type = string_to_text(assoc_type, - "Binary values for assoc_type are deprecated. Use text input instead.") + self.assoc_type = assoc_type def getExpiresIn(self, now=None): """ diff --git a/openid/test/test_association.py b/openid/test/test_association.py index 0763b12d..7a6b2eaf 100644 --- a/openid/test/test_association.py +++ b/openid/test/test_association.py @@ -11,6 +11,12 @@ from openid.server.server import DiffieHellmanSHA1ServerSession, PlainTextServerSession +class TestAssociation(unittest.TestCase): + def test_assoc_type_bytes(self): + assoc = association.Association('handle', b'secret', 1000, 1000, b'HMAC-SHA1') + self.assertEqual(assoc.assoc_type, 'HMAC-SHA1') + + class AssociationSerializationTest(unittest.TestCase): def test_roundTrip(self): issued = int(time.time()) From cbd51c1904927e2a3f03cc4c921390d681f0aff8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Mon, 17 Jun 2019 10:31:46 +0200 Subject: [PATCH 133/151] Add support for python 3.7 --- .travis.yml | 3 +++ setup.py | 1 + tox.ini | 6 +++--- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 091d075a..e8be1198 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,4 +1,6 @@ language: python +# Enable python 3.7 +dist: xenial sudo: false @@ -7,6 +9,7 @@ python: - "3.4" - "3.5" - "3.6" + - "3.7" - "pypy" addons: diff --git a/setup.py b/setup.py index 621c4c8f..f368c49a 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', 'Topic :: Internet :: WWW/HTTP', 'Topic :: Internet :: WWW/HTTP :: Dynamic Content :: CGI Tools/Libraries', 'Topic :: Software Development :: Libraries :: Python Modules', diff --git a/tox.ini b/tox.ini index c548f3f5..bbcdc0b2 100644 --- a/tox.ini +++ b/tox.ini @@ -1,13 +1,13 @@ [tox] envlist = quality - py{27,34,35,36}-{openid,djopenid,httplib2,pycurl,requests} + py{27,34,35,36,37}-{openid,djopenid,httplib2,pycurl,requests} pypy-{openid,djopenid,httplib2,pycurl,requests} # tox-travis specials [travis] python = - 2.7: py27, quality + 3.7: py37, quality # Generic specification for all unspecific environments [testenv] @@ -28,7 +28,7 @@ commands = djopenid: coverage run --parallel-mode --branch --source=openid,examples --module unittest discover --start={toxinidir}/examples [testenv:quality] -basepython = python2.7 +basepython = python3.7 extras = quality commands = From 4bd4181e1b6982fb6078d285c94e983a79bc3eb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Mon, 17 Jun 2019 14:09:38 +0200 Subject: [PATCH 134/151] Fix flake8 warnings --- openid/consumer/consumer.py | 2 +- openid/fetchers.py | 2 +- openid/test/test_cryptutil.py | 14 ++++---------- 3 files changed, 6 insertions(+), 12 deletions(-) diff --git a/openid/consumer/consumer.py b/openid/consumer/consumer.py index 86acd8f5..275a08a8 100644 --- a/openid/consumer/consumer.py +++ b/openid/consumer/consumer.py @@ -1185,7 +1185,7 @@ def _negotiateAssociation(self, endpoint): try: assoc = self._requestAssociation( endpoint, assoc_type, session_type) - except ServerError as why: + except ServerError: # Do not keep trying, since it rejected the # association type that it told us to use. _LOGGER.error('Server %s refused its suggested association type: session_type=%s, assoc_type=%s', diff --git a/openid/fetchers.py b/openid/fetchers.py index f8e6e37d..8a109933 100644 --- a/openid/fetchers.py +++ b/openid/fetchers.py @@ -289,7 +289,7 @@ def _parseHeaders(self, header_file): # and the blank line from the end empty_line = lines.pop() if empty_line: - raise HTTPError("No blank line at end of headers: %r" % (line,)) + raise HTTPError("No blank line at end of headers: %r" % empty_line) headers = {} for line in lines: diff --git a/openid/test/test_cryptutil.py b/openid/test/test_cryptutil.py index eae2a1c6..106be420 100644 --- a/openid/test/test_cryptutil.py +++ b/openid/test/test_cryptutil.py @@ -10,15 +10,9 @@ from openid import cryptutil + # Most of the purpose of this test is to make sure that cryptutil can # find a good source of randomness on this machine. -if six.PY2: - long_int = long -else: - assert six.PY3 - long_int = int - - class TestLongBinary(unittest.TestCase): """Test `longToBinary` and `binaryToLong` functions.""" @@ -27,7 +21,7 @@ def test_binaryLongConvert(self): for iteration in range(500): n = 0 for i in range(10): - n += long_int(random.randrange(MAX)) + n += random.randrange(MAX) s = cryptutil.longToBinary(n) assert isinstance(s, six.binary_type) @@ -101,7 +95,7 @@ def test_longToBase64(self): try: for line in f: parts = line.strip().split(' ') - assert parts[0] == cryptutil.longToBase64(long_int(parts[1])) + assert parts[0] == cryptutil.longToBase64(int(parts[1])) finally: f.close() @@ -114,6 +108,6 @@ def test_base64ToLong(self): try: for line in f: parts = line.strip().split(' ') - assert long_int(parts[1]) == cryptutil.base64ToLong(parts[0]) + assert int(parts[1]) == cryptutil.base64ToLong(parts[0]) finally: f.close() From 88dfb5c1eec98a5c39a7511a8ce47f0815af5d5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 18 Jun 2019 09:47:23 +0200 Subject: [PATCH 135/151] Set up bumpversion --- .bumpversion.cfg | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 .bumpversion.cfg diff --git a/.bumpversion.cfg b/.bumpversion.cfg new file mode 100644 index 00000000..0d9de3c1 --- /dev/null +++ b/.bumpversion.cfg @@ -0,0 +1,20 @@ +[bumpversion] +current_version = 3.0 +commit = True +tag = True +tag_name = {new_version} +parse = (?P\d+)\.(?P\d+)(?P.*) +serialize = + {major}.{minor}{rc} + {major}.{minor} + +[bumpversion:part:rc] +values = + final + rc1 + rc2 + rc3 + rc4 + rc5 + +[bumpversion:file:openid/__init__.py] From 987f55ce9ce3643929e2210eb86c68345decd401 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Thu, 20 Jun 2019 10:38:11 +0200 Subject: [PATCH 136/151] Fix bumpversion config --- .bumpversion.cfg | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 0d9de3c1..bfc46b64 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -9,12 +9,13 @@ serialize = {major}.{minor} [bumpversion:part:rc] +optional_value = final values = - final rc1 rc2 rc3 rc4 rc5 + final [bumpversion:file:openid/__init__.py] From 08e1dff25d908f29c4d9dab28876906020312aff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Mon, 24 Jun 2019 09:52:13 +0200 Subject: [PATCH 137/151] Update changelog --- Changelog.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/Changelog.md b/Changelog.md index ac649836..70fa0ea9 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,5 +1,13 @@ # Changelog # +## 3.1 ## + * Convert data values for extensions to text. + * Fixes in Python 2/3 support. + * Fix examples. + * Add support for python 3.7 + * Fix static code checks + * Use bumpversion + ## 3.0 ## * Support Python3. From 089d99a95092c95ed723ed7be64f7833828c9637 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Mon, 24 Jun 2019 09:52:34 +0200 Subject: [PATCH 138/151] =?UTF-8?q?Bump=20version:=203.0=20=E2=86=92=203.1?= =?UTF-8?q?rc1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 23 ++++++++++++----------- openid/__init__.py | 2 +- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index bfc46b64..08722df8 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,21 +1,22 @@ [bumpversion] -current_version = 3.0 +current_version = 3.1rc1 commit = True tag = True tag_name = {new_version} parse = (?P\d+)\.(?P\d+)(?P.*) -serialize = - {major}.{minor}{rc} - {major}.{minor} +serialize = + {major}.{minor}{rc} + {major}.{minor} [bumpversion:part:rc] optional_value = final -values = - rc1 - rc2 - rc3 - rc4 - rc5 - final +values = + rc1 + rc2 + rc3 + rc4 + rc5 + final [bumpversion:file:openid/__init__.py] + diff --git a/openid/__init__.py b/openid/__init__.py index 86f61745..a4652676 100644 --- a/openid/__init__.py +++ b/openid/__init__.py @@ -24,7 +24,7 @@ """ from __future__ import unicode_literals -__version__ = '3.0' +__version__ = '3.1rc1' __all__ = [ 'association', From 2d741abce6759ee4d0070d22cd57a9cbe22479b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Fri, 26 Jul 2019 10:17:29 +0200 Subject: [PATCH 139/151] =?UTF-8?q?Bump=20version:=203.1rc1=20=E2=86=92=20?= =?UTF-8?q?3.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- openid/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 08722df8..4f1f3084 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 3.1rc1 +current_version = 3.1 commit = True tag = True tag_name = {new_version} diff --git a/openid/__init__.py b/openid/__init__.py index a4652676..cceb98a3 100644 --- a/openid/__init__.py +++ b/openid/__init__.py @@ -24,7 +24,7 @@ """ from __future__ import unicode_literals -__version__ = '3.1rc1' +__version__ = '3.1' __all__ = [ 'association', From 4d674262d9542408e4ba7854fe21fd5fb895731b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 8 Oct 2019 10:56:02 +0200 Subject: [PATCH 140/151] Drop python 3.4 --- .travis.yml | 1 - README.md | 2 +- setup.py | 3 +-- tox.ini | 2 +- 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/.travis.yml b/.travis.yml index e8be1198..f2f2c462 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,7 +6,6 @@ sudo: false python: - "2.7" - - "3.4" - "3.5" - "3.6" - "3.7" diff --git a/README.md b/README.md index 9836f498..4a51958a 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Includes example code and support for a variety of storage back-ends. ## REQUIREMENTS ## - - Python 2.7, >3.4 + - Python 2.7, >=3.5 - lxml - six - cryptography diff --git a/setup.py b/setup.py index f368c49a..d0ad7690 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,6 @@ 'Programming Language :: Python :: 2', 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', @@ -62,7 +61,7 @@ 'openid.extensions', 'openid.extensions.draft', ], - python_requires='>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*', + python_requires='>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*', install_requires=INSTALL_REQUIRES, extras_require=EXTRAS_REQUIRE, # license specified by classifier. diff --git a/tox.ini b/tox.ini index bbcdc0b2..2f54bf0c 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] envlist = quality - py{27,34,35,36,37}-{openid,djopenid,httplib2,pycurl,requests} + py{27,35,36,37}-{openid,djopenid,httplib2,pycurl,requests} pypy-{openid,djopenid,httplib2,pycurl,requests} # tox-travis specials From ae769920b2dc641b9c7ddbc500071a66cb1d3667 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 8 Oct 2019 10:14:50 +0200 Subject: [PATCH 141/151] Fix false positive redirect when verifying consumer --- openid/server/trustroot.py | 2 +- openid/test/test_rpverify.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/openid/server/trustroot.py b/openid/server/trustroot.py index 159b60d7..349032f0 100644 --- a/openid/server/trustroot.py +++ b/openid/server/trustroot.py @@ -400,7 +400,7 @@ def getAllowedReturnURLs(relying_party_url): (rp_url_after_redirects, return_to_urls) = services.getServiceEndpoints( relying_party_url, _extractReturnURL) - if rp_url_after_redirects != relying_party_url: + if urinorm.urinorm(rp_url_after_redirects) != urinorm.urinorm(relying_party_url): # Verification caused a redirect raise RealmVerificationRedirected( relying_party_url, rp_url_after_redirects) diff --git a/openid/test/test_rpverify.py b/openid/test/test_rpverify.py index 82af2cf5..5b6780a5 100644 --- a/openid/test/test_rpverify.py +++ b/openid/test/test_rpverify.py @@ -3,9 +3,11 @@ import unittest +from mock import patch, sentinel from testfixtures import LogCapture, StringComparison from openid.server import trustroot +from openid.server.trustroot import getAllowedReturnURLs from openid.yadis import services from openid.yadis.discover import DiscoveryFailure, DiscoveryResult @@ -183,6 +185,24 @@ def test_noMatch(self): self.assertFalse(trustroot.returnToMatches([r], 'http://example.com/xss_exploit')) +class TestGetAllowedReturnURLs(unittest.TestCase): + + def test_equal(self): + with patch('openid.yadis.services.getServiceEndpoints', autospec=True, + return_value=('http://example.com/', sentinel.endpoints)): + endpoints = getAllowedReturnURLs('http://example.com/') + + self.assertEqual(endpoints, sentinel.endpoints) + + def test_normalized(self): + # Test redirect is not reported when the returned URL is normalized. + with patch('openid.yadis.services.getServiceEndpoints', autospec=True, + return_value=('http://example.com/', sentinel.endpoints)): + endpoints = getAllowedReturnURLs('http://example.com:80') + + self.assertEqual(endpoints, sentinel.endpoints) + + class TestVerifyReturnTo(unittest.TestCase): def test_bogusRealm(self): From 557dc2eac99b29feda2c7f207f7ad6cbe90dde09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 8 Oct 2019 14:57:20 +0200 Subject: [PATCH 142/151] =?UTF-8?q?Bump=20version:=203.1=20=E2=86=92=203.2?= =?UTF-8?q?rc1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- Changelog.md | 4 ++++ openid/__init__.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 4f1f3084..93c0a993 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 3.1 +current_version = 3.2rc1 commit = True tag = True tag_name = {new_version} diff --git a/Changelog.md b/Changelog.md index 70fa0ea9..5923c3fe 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,5 +1,9 @@ # Changelog # +## 3.2 ## + * Drop support for python 3.4. + * Fix false positive redirect error in consumer verification. + ## 3.1 ## * Convert data values for extensions to text. * Fixes in Python 2/3 support. diff --git a/openid/__init__.py b/openid/__init__.py index cceb98a3..5bb0e9f0 100644 --- a/openid/__init__.py +++ b/openid/__init__.py @@ -24,7 +24,7 @@ """ from __future__ import unicode_literals -__version__ = '3.1' +__version__ = '3.2rc1' __all__ = [ 'association', From 2c0d6f05ed118e64097ee9d0e1237a18c066c9e3 Mon Sep 17 00:00:00 2001 From: Colin Watson Date: Mon, 2 Mar 2020 11:47:34 +0100 Subject: [PATCH 143/151] Quieten some noisy deprecation warnings in tests --- openid/test/test_cryptutil.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/openid/test/test_cryptutil.py b/openid/test/test_cryptutil.py index 106be420..f0caed52 100644 --- a/openid/test/test_cryptutil.py +++ b/openid/test/test_cryptutil.py @@ -5,6 +5,7 @@ import random import sys import unittest +import warnings import six @@ -18,15 +19,17 @@ class TestLongBinary(unittest.TestCase): def test_binaryLongConvert(self): MAX = sys.maxsize - for iteration in range(500): - n = 0 - for i in range(10): - n += random.randrange(MAX) - - s = cryptutil.longToBinary(n) - assert isinstance(s, six.binary_type) - n_prime = cryptutil.binaryToLong(s) - assert n == n_prime, (n, n_prime) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=DeprecationWarning) + for iteration in range(500): + n = 0 + for i in range(10): + n += random.randrange(MAX) + + s = cryptutil.longToBinary(n) + assert isinstance(s, six.binary_type) + n_prime = cryptutil.binaryToLong(s) + assert n == n_prime, (n, n_prime) cases = [ (b'\x00', 0), @@ -39,11 +42,13 @@ def test_binaryLongConvert(self): (b'OpenID is cool', 1611215304203901150134421257416556) ] - for s, n in cases: - n_prime = cryptutil.binaryToLong(s) - s_prime = cryptutil.longToBinary(n) - assert n == n_prime, (s, n, n_prime) - assert s == s_prime, (n, s, s_prime) + with warnings.catch_warnings(): + warnings.simplefilter('ignore', category=DeprecationWarning) + for s, n in cases: + n_prime = cryptutil.binaryToLong(s) + s_prime = cryptutil.longToBinary(n) + assert n == n_prime, (s, n, n_prime) + assert s == s_prime, (n, s, s_prime) class TestFixBtwoc(unittest.TestCase): From 3f71c33a4b907a8120468e2f22da1d4a15824d90 Mon Sep 17 00:00:00 2001 From: Colin Watson Date: Mon, 2 Mar 2020 11:47:54 +0100 Subject: [PATCH 144/151] Fix TestRequestsFetcher failures `responses` needs the content type to be set using the `content_type` keyword argument; using `headers` for this results in responses with `Content-Type: text/plain, text/plain`. test_invalid_url needs a slight adjustment to pass on Python 2, due to the different `repr` for text strings. --- openid/test/test_fetchers.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/openid/test/test_fetchers.py b/openid/test/test_fetchers.py index d7b03d1a..19893407 100644 --- a/openid/test/test_fetchers.py +++ b/openid/test/test_fetchers.py @@ -401,7 +401,7 @@ def test_get(self): # Test GET response with responses.RequestsMock() as rsps: rsps.add(responses.GET, 'http://example.cz/', status=200, body=b'BODY', - headers={'Content-Type': 'text/plain'}) + content_type='text/plain') response = self.fetcher.fetch('http://example.cz/') expected = fetchers.HTTPResponse('http://example.cz/', 200, {'Content-Type': 'text/plain'}, b'BODY') assertResponse(expected, response) @@ -410,7 +410,7 @@ def test_post(self): # Test POST response with responses.RequestsMock() as rsps: rsps.add(responses.POST, 'http://example.cz/', status=200, body=b'BODY', - headers={'Content-Type': 'text/plain'}) + content_type='text/plain') response = self.fetcher.fetch('http://example.cz/', body=b'key=value') expected = fetchers.HTTPResponse('http://example.cz/', 200, {'Content-Type': 'text/plain'}, b'BODY') assertResponse(expected, response) @@ -421,7 +421,7 @@ def test_redirect(self): rsps.add(responses.GET, 'http://example.cz/redirect/', status=302, headers={'Location': 'http://example.cz/target/'}) rsps.add(responses.GET, 'http://example.cz/target/', status=200, body=b'BODY', - headers={'Content-Type': 'text/plain'}) + content_type='text/plain') response = self.fetcher.fetch('http://example.cz/redirect/') expected = fetchers.HTTPResponse('http://example.cz/target/', 200, {'Content-Type': 'text/plain'}, b'BODY') assertResponse(expected, response) @@ -430,14 +430,18 @@ def test_error(self): # Test error responses - returned as obtained with responses.RequestsMock() as rsps: rsps.add(responses.GET, 'http://example.cz/error/', status=500, body=b'BODY', - headers={'Content-Type': 'text/plain'}) + content_type='text/plain') response = self.fetcher.fetch('http://example.cz/error/') expected = fetchers.HTTPResponse('http://example.cz/error/', 500, {'Content-Type': 'text/plain'}, b'BODY') assertResponse(expected, response) def test_invalid_url(self): invalid_url = 'invalid://example.cz/' - with six.assertRaisesRegex(self, InvalidSchema, "No connection adapters were found for '" + invalid_url + "'"): + expected_message = ( + 'No connection adapters were found for ' + + ('u' if six.PY2 else '') + + "'" + invalid_url + "'") + with six.assertRaisesRegex(self, InvalidSchema, expected_message): self.fetcher.fetch(invalid_url) def test_connection_error(self): From e96fe4e81d144f0a46a0f725f6c52ca7466cba26 Mon Sep 17 00:00:00 2001 From: Colin Watson Date: Thu, 4 Jun 2020 02:40:40 +0100 Subject: [PATCH 145/151] Fix flake8 complaints flake8 complained about "[E741] ambiguous variable name 'l'" in a few places. --- openid/cryptutil.py | 4 ++-- openid/test/test_accept.py | 2 +- openid/test/test_openidyadis.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/openid/cryptutil.py b/openid/cryptutil.py index ded17d86..7e8cd0ac 100644 --- a/openid/cryptutil.py +++ b/openid/cryptutil.py @@ -75,8 +75,8 @@ def binaryToLong(s): return bytes_to_int(s) -def longToBase64(l): - return toBase64(int_to_bytes(l)) +def longToBase64(value): + return toBase64(int_to_bytes(value)) def base64ToLong(s): diff --git a/openid/test/test_accept.py b/openid/test/test_accept.py index b10934ac..aa13d875 100644 --- a/openid/test/test_accept.py +++ b/openid/test/test_accept.py @@ -89,7 +89,7 @@ class MatchAcceptTest(unittest.TestCase): def runTest(self): lines = getTestData() chunks = chunk(lines) - data_sets = [parseLines(l) for l in chunks] + data_sets = [parseLines(line) for line in chunks] for data in data_sets: lnos = [] lno, accept_header = data['accept'] diff --git a/openid/test/test_openidyadis.py b/openid/test/test_openidyadis.py index 4e77b606..3d17cb16 100644 --- a/openid/test/test_openidyadis.py +++ b/openid/test/test_openidyadis.py @@ -66,10 +66,10 @@ def mkService(uris=None, type_uris=None, local_id=None, dent=' '): # Used for generating test data -def subsets(l): +def subsets(lst): """Generate all non-empty sublists of a list""" subsets_list = [[]] - for x in l: + for x in lst: subsets_list += [[x] + t for t in subsets_list] return subsets_list From aba0333707d56102096e55f3caaddc56b3390528 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Thu, 2 Jul 2020 13:36:19 +0200 Subject: [PATCH 146/151] Fix urinorm - return plain sub delimiters in path, refs #41 --- openid/test/test_urinorm.py | 8 +++++++- openid/urinorm.py | 4 ++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/openid/test/test_urinorm.py b/openid/test/test_urinorm.py index 2c7aa0c5..53debfe3 100644 --- a/openid/test/test_urinorm.py +++ b/openid/test/test_urinorm.py @@ -71,11 +71,17 @@ def test_path_percent_encoding(self): self.assertEqual(urinorm('http://example.com/Λ'), 'http://example.com/%CE%9B') def test_path_capitalize_percent_encoding(self): - self.assertEqual(urinorm('http://example.com/foo%2cbar'), 'http://example.com/foo%2Cbar') + self.assertEqual(urinorm('http://example.com/foo%3abar'), 'http://example.com/foo%3Abar') def test_path_percent_decode_unreserved(self): self.assertEqual(urinorm('http://example.com/foo%2Dbar%2dbaz'), 'http://example.com/foo-bar-baz') + def test_path_keep_sub_delims(self): + self.assertEqual(urinorm('http://example.com/foo+!bar'), 'http://example.com/foo+!bar') + + def test_path_percent_decode_sub_delims(self): + self.assertEqual(urinorm('http://example.com/foo%2B%21bar'), 'http://example.com/foo+!bar') + def test_illegal_characters(self): six.assertRaisesRegex(self, ValueError, 'Illegal characters in URI', urinorm, 'http://.com/') diff --git a/openid/urinorm.py b/openid/urinorm.py index 6a5a5883..96787411 100644 --- a/openid/urinorm.py +++ b/openid/urinorm.py @@ -122,10 +122,10 @@ def urinorm(uri): # This is hackish. `unquote` and `quote` requires `str` in both py27 and py3+. if isinstance(path, str): # Python 3 branch - path = quote(unquote(path)) + path = quote(unquote(path), safe='/' + SUB_DELIMS) else: # Python 2 branch - path = quote(unquote(path.encode('utf-8'))).decode('utf-8') + path = quote(unquote(path.encode('utf-8')), safe=('/' + SUB_DELIMS).encode('utf-8')).decode('utf-8') path = remove_dot_segments(path) if not path: From b33d8f6aa4a60294ba8457f8b45de5ed68e68d83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Thu, 9 Jul 2020 09:06:59 +0200 Subject: [PATCH 147/151] Fix isort --- admin/gettlds.py | 1 + openid/test/test_storetest.py | 5 +++-- tox.ini | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/admin/gettlds.py b/admin/gettlds.py index 4b0c4033..bc8d7347 100644 --- a/admin/gettlds.py +++ b/admin/gettlds.py @@ -11,6 +11,7 @@ from __future__ import unicode_literals import sys + import urllib2 langs = { diff --git a/openid/test/test_storetest.py b/openid/test/test_storetest.py index 61a11afe..4faa756c 100644 --- a/openid/test/test_storetest.py +++ b/openid/test/test_storetest.py @@ -225,9 +225,10 @@ class TestFileOpenIDStore(unittest.TestCase): """Test `FileOpenIDStore` class.""" def test_filestore(self): - from openid.store import filestore - import tempfile import shutil + import tempfile + + from openid.store import filestore try: temp_dir = tempfile.mkdtemp() except AttributeError: diff --git a/tox.ini b/tox.ini index 2f54bf0c..28abe1dd 100644 --- a/tox.ini +++ b/tox.ini @@ -33,5 +33,5 @@ extras = quality commands = # setup.py is excluded from isort because distutils have problems with unicode_literals. - isort --check-only --diff --recursive openid admin contrib + isort --check-only --diff openid admin contrib flake8 --format=pylint openid setup.py admin contrib From 80cab9c5de068aaa46f9a6fa366545562a5ea267 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Wed, 8 Jul 2020 11:09:53 +0200 Subject: [PATCH 148/151] Add support for python 3.8 --- .travis.yml | 1 + setup.py | 1 + tox.ini | 6 +++--- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index f2f2c462..3db2d3d2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,6 +9,7 @@ python: - "3.5" - "3.6" - "3.7" + - "3.8" - "pypy" addons: diff --git a/setup.py b/setup.py index d0ad7690..f8f34c42 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', 'Topic :: Internet :: WWW/HTTP', 'Topic :: Internet :: WWW/HTTP :: Dynamic Content :: CGI Tools/Libraries', 'Topic :: Software Development :: Libraries :: Python Modules', diff --git a/tox.ini b/tox.ini index 28abe1dd..bd300ca1 100644 --- a/tox.ini +++ b/tox.ini @@ -1,13 +1,13 @@ [tox] envlist = quality - py{27,35,36,37}-{openid,djopenid,httplib2,pycurl,requests} + py{27,35,36,37,38}-{openid,djopenid,httplib2,pycurl,requests} pypy-{openid,djopenid,httplib2,pycurl,requests} # tox-travis specials [travis] python = - 3.7: py37, quality + 3.8: py38, quality # Generic specification for all unspecific environments [testenv] @@ -28,7 +28,7 @@ commands = djopenid: coverage run --parallel-mode --branch --source=openid,examples --module unittest discover --start={toxinidir}/examples [testenv:quality] -basepython = python3.7 +basepython = python3.8 extras = quality commands = From 40988b5603640e36ee90e134d64f0a33abc79af1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Tue, 14 Jul 2020 12:55:21 +0200 Subject: [PATCH 149/151] =?UTF-8?q?Bump=20version:=203.2rc1=20=E2=86=92=20?= =?UTF-8?q?3.2rc2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 3 +-- Changelog.md | 3 +++ openid/__init__.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 93c0a993..7e331d11 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 3.2rc1 +current_version = 3.2rc2 commit = True tag = True tag_name = {new_version} @@ -19,4 +19,3 @@ values = final [bumpversion:file:openid/__init__.py] - diff --git a/Changelog.md b/Changelog.md index 5923c3fe..8079af83 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,8 +1,11 @@ # Changelog # ## 3.2 ## + * Add support for python 3.8. * Drop support for python 3.4. * Fix false positive redirect error in consumer verification. + * Do not percent escape sub delimiters in path in URI normalization. Thanks Colin Watson for report. + * Fix tests and static code checks. Thanks Colin Watson. ## 3.1 ## * Convert data values for extensions to text. diff --git a/openid/__init__.py b/openid/__init__.py index 5bb0e9f0..66e47a7e 100644 --- a/openid/__init__.py +++ b/openid/__init__.py @@ -24,7 +24,7 @@ """ from __future__ import unicode_literals -__version__ = '3.2rc1' +__version__ = '3.2rc2' __all__ = [ 'association', From d093a0919198eb53826ae5753e517af10ad95d5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vlastimil=20Z=C3=ADma?= Date: Mon, 27 Jul 2020 11:10:14 +0200 Subject: [PATCH 150/151] =?UTF-8?q?Bump=20version:=203.2rc2=20=E2=86=92=20?= =?UTF-8?q?3.2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- openid/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 7e331d11..f02afbd4 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 3.2rc2 +current_version = 3.2 commit = True tag = True tag_name = {new_version} diff --git a/openid/__init__.py b/openid/__init__.py index 66e47a7e..7b33453f 100644 --- a/openid/__init__.py +++ b/openid/__init__.py @@ -24,7 +24,7 @@ """ from __future__ import unicode_literals -__version__ = '3.2rc2' +__version__ = '3.2' __all__ = [ 'association', From a2cb8bc70a12ad89a62a010fbe1569d21eed21d5 Mon Sep 17 00:00:00 2001 From: Colin Watson Date: Mon, 17 Aug 2020 18:38:33 +0100 Subject: [PATCH 151/151] Fix normalization of non-ASCII query strings on Python 2 urinorm currently deals with encoding issues when normalizing the path, but not the query string. However, in some cases it can happen that the query string contains non-ASCII characters, particularly if using https://openid.net/specs/openid-simple-registration-extension-1_0.html in which case the user's full name may very well not be entirely ASCII; on Python 2 this resulted in a UnicodeEncodeError in urlencode. Work around this. --- openid/test/test_urinorm.py | 8 ++++++++ openid/urinorm.py | 10 ++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/openid/test/test_urinorm.py b/openid/test/test_urinorm.py index 53debfe3..e85969b4 100644 --- a/openid/test/test_urinorm.py +++ b/openid/test/test_urinorm.py @@ -82,6 +82,14 @@ def test_path_keep_sub_delims(self): def test_path_percent_decode_sub_delims(self): self.assertEqual(urinorm('http://example.com/foo%2B%21bar'), 'http://example.com/foo+!bar') + def test_query_encoding(self): + self.assertEqual( + urinorm('http://example.com/?openid.sreg.fullname=Unícöde+Person'), + 'http://example.com/?openid.sreg.fullname=Un%C3%ADc%C3%B6de+Person') + self.assertEqual( + urinorm('http://example.com/?openid.sreg.fullname=Un%C3%ADc%C3%B6de+Person'), + 'http://example.com/?openid.sreg.fullname=Un%C3%ADc%C3%B6de+Person') + def test_illegal_characters(self): six.assertRaisesRegex(self, ValueError, 'Illegal characters in URI', urinorm, 'http://.com/') diff --git a/openid/urinorm.py b/openid/urinorm.py index 96787411..22b3dad1 100644 --- a/openid/urinorm.py +++ b/openid/urinorm.py @@ -132,8 +132,14 @@ def urinorm(uri): path = '/' _check_disallowed_characters(path, 'path') - # Normalize query - data = parse_qsl(split_uri.query) + # Normalize query. On Python 2, `urlencode` without `doseq=True` + # requires values to be convertible to native strings using `str()`. + if isinstance(split_uri.query, str): + # Python 3 branch + data = parse_qsl(split_uri.query) + else: + # Python 2 branch + data = parse_qsl(split_uri.query.encode('utf-8')) query = urlencode(data) _check_disallowed_characters(query, 'query')
Identity:%(identity)s