regression.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. # -*- coding: utf-8 -*-
  2. """
  3. flask.testsuite.regression
  4. ~~~~~~~~~~~~~~~~~~~~~~~~~~
  5. Tests regressions.
  6. :copyright: (c) 2011 by Armin Ronacher.
  7. :license: BSD, see LICENSE for more details.
  8. """
  9. import os
  10. import gc
  11. import sys
  12. import flask
  13. import threading
  14. import unittest
  15. from werkzeug.exceptions import NotFound
  16. from flask.testsuite import FlaskTestCase
  17. _gc_lock = threading.Lock()
  18. class _NoLeakAsserter(object):
  19. def __init__(self, testcase):
  20. self.testcase = testcase
  21. def __enter__(self):
  22. gc.disable()
  23. _gc_lock.acquire()
  24. loc = flask._request_ctx_stack._local
  25. # Force Python to track this dictionary at all times.
  26. # This is necessary since Python only starts tracking
  27. # dicts if they contain mutable objects. It's a horrible,
  28. # horrible hack but makes this kinda testable.
  29. loc.__storage__['FOOO'] = [1, 2, 3]
  30. gc.collect()
  31. self.old_objects = len(gc.get_objects())
  32. def __exit__(self, exc_type, exc_value, tb):
  33. if not hasattr(sys, 'getrefcount'):
  34. gc.collect()
  35. new_objects = len(gc.get_objects())
  36. if new_objects > self.old_objects:
  37. self.testcase.fail('Example code leaked')
  38. _gc_lock.release()
  39. gc.enable()
  40. class MemoryTestCase(FlaskTestCase):
  41. def assert_no_leak(self):
  42. return _NoLeakAsserter(self)
  43. def test_memory_consumption(self):
  44. app = flask.Flask(__name__)
  45. @app.route('/')
  46. def index():
  47. return flask.render_template('simple_template.html', whiskey=42)
  48. def fire():
  49. with app.test_client() as c:
  50. rv = c.get('/')
  51. self.assert_equal(rv.status_code, 200)
  52. self.assert_equal(rv.data, b'<h1>42</h1>')
  53. # Trigger caches
  54. fire()
  55. # This test only works on CPython 2.7.
  56. if sys.version_info >= (2, 7) and \
  57. not hasattr(sys, 'pypy_translation_info'):
  58. with self.assert_no_leak():
  59. for x in range(10):
  60. fire()
  61. def test_safe_join_toplevel_pardir(self):
  62. from flask.helpers import safe_join
  63. with self.assert_raises(NotFound):
  64. safe_join('/foo', '..')
  65. class ExceptionTestCase(FlaskTestCase):
  66. def test_aborting(self):
  67. class Foo(Exception):
  68. whatever = 42
  69. app = flask.Flask(__name__)
  70. app.testing = True
  71. @app.errorhandler(Foo)
  72. def handle_foo(e):
  73. return str(e.whatever)
  74. @app.route('/')
  75. def index():
  76. raise flask.abort(flask.redirect(flask.url_for('test')))
  77. @app.route('/test')
  78. def test():
  79. raise Foo()
  80. with app.test_client() as c:
  81. rv = c.get('/')
  82. self.assertEqual(rv.headers['Location'], 'http://localhost/test')
  83. rv = c.get('/test')
  84. self.assertEqual(rv.data, b'42')
  85. def suite():
  86. suite = unittest.TestSuite()
  87. if os.environ.get('RUN_FLASK_MEMORY_TESTS') == '1':
  88. suite.addTest(unittest.makeSuite(MemoryTestCase))
  89. suite.addTest(unittest.makeSuite(ExceptionTestCase))
  90. return suite