Todo: 集成多平台 解决因SaiNiu线程抢占资源问题 本地提交测试环境打包 和 正式打包脚本与正式环境打包bat 提交Python32环境包 改进多日志文件生成情况修改打包日志细节
This commit is contained in:
12
Utils/PythonNew32/Lib/test/test_asyncio/__init__.py
Normal file
12
Utils/PythonNew32/Lib/test/test_asyncio/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import os
|
||||
from test import support
|
||||
from test.support import load_package_tests
|
||||
from test.support import import_helper
|
||||
|
||||
support.requires_working_socket(module=True)
|
||||
|
||||
# Skip tests if we don't have concurrent.futures.
|
||||
import_helper.import_module('concurrent.futures')
|
||||
|
||||
def load_tests(*args):
|
||||
return load_package_tests(os.path.dirname(__file__), *args)
|
||||
4
Utils/PythonNew32/Lib/test/test_asyncio/__main__.py
Normal file
4
Utils/PythonNew32/Lib/test/test_asyncio/__main__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from . import load_tests
|
||||
import unittest
|
||||
|
||||
unittest.main()
|
||||
8
Utils/PythonNew32/Lib/test/test_asyncio/echo.py
Normal file
8
Utils/PythonNew32/Lib/test/test_asyncio/echo.py
Normal file
@@ -0,0 +1,8 @@
|
||||
import os
|
||||
|
||||
if __name__ == '__main__':
|
||||
while True:
|
||||
buf = os.read(0, 1024)
|
||||
if not buf:
|
||||
break
|
||||
os.write(1, buf)
|
||||
6
Utils/PythonNew32/Lib/test/test_asyncio/echo2.py
Normal file
6
Utils/PythonNew32/Lib/test/test_asyncio/echo2.py
Normal file
@@ -0,0 +1,6 @@
|
||||
import os
|
||||
|
||||
if __name__ == '__main__':
|
||||
buf = os.read(0, 1024)
|
||||
os.write(1, b'OUT:'+buf)
|
||||
os.write(2, b'ERR:'+buf)
|
||||
11
Utils/PythonNew32/Lib/test/test_asyncio/echo3.py
Normal file
11
Utils/PythonNew32/Lib/test/test_asyncio/echo3.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import os
|
||||
|
||||
if __name__ == '__main__':
|
||||
while True:
|
||||
buf = os.read(0, 1024)
|
||||
if not buf:
|
||||
break
|
||||
try:
|
||||
os.write(1, b'OUT:'+buf)
|
||||
except OSError as ex:
|
||||
os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii'))
|
||||
269
Utils/PythonNew32/Lib/test/test_asyncio/functional.py
Normal file
269
Utils/PythonNew32/Lib/test/test_asyncio/functional.py
Normal file
@@ -0,0 +1,269 @@
|
||||
import asyncio
|
||||
import asyncio.events
|
||||
import contextlib
|
||||
import os
|
||||
import pprint
|
||||
import select
|
||||
import socket
|
||||
import tempfile
|
||||
import threading
|
||||
from test import support
|
||||
|
||||
|
||||
class FunctionalTestCaseMixin:
|
||||
|
||||
def new_loop(self):
|
||||
return asyncio.new_event_loop()
|
||||
|
||||
def run_loop_briefly(self, *, delay=0.01):
|
||||
self.loop.run_until_complete(asyncio.sleep(delay))
|
||||
|
||||
def loop_exception_handler(self, loop, context):
|
||||
self.__unhandled_exceptions.append(context)
|
||||
self.loop.default_exception_handler(context)
|
||||
|
||||
def setUp(self):
|
||||
self.loop = self.new_loop()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
self.loop.set_exception_handler(self.loop_exception_handler)
|
||||
self.__unhandled_exceptions = []
|
||||
|
||||
def tearDown(self):
|
||||
try:
|
||||
self.loop.close()
|
||||
|
||||
if self.__unhandled_exceptions:
|
||||
print('Unexpected calls to loop.call_exception_handler():')
|
||||
pprint.pprint(self.__unhandled_exceptions)
|
||||
self.fail('unexpected calls to loop.call_exception_handler()')
|
||||
|
||||
finally:
|
||||
asyncio.set_event_loop(None)
|
||||
self.loop = None
|
||||
|
||||
def tcp_server(self, server_prog, *,
|
||||
family=socket.AF_INET,
|
||||
addr=None,
|
||||
timeout=support.LOOPBACK_TIMEOUT,
|
||||
backlog=1,
|
||||
max_clients=10):
|
||||
|
||||
if addr is None:
|
||||
if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
|
||||
with tempfile.NamedTemporaryFile() as tmp:
|
||||
addr = tmp.name
|
||||
else:
|
||||
addr = ('127.0.0.1', 0)
|
||||
|
||||
sock = socket.create_server(addr, family=family, backlog=backlog)
|
||||
if timeout is None:
|
||||
raise RuntimeError('timeout is required')
|
||||
if timeout <= 0:
|
||||
raise RuntimeError('only blocking sockets are supported')
|
||||
sock.settimeout(timeout)
|
||||
|
||||
return TestThreadedServer(
|
||||
self, sock, server_prog, timeout, max_clients)
|
||||
|
||||
def tcp_client(self, client_prog,
|
||||
family=socket.AF_INET,
|
||||
timeout=support.LOOPBACK_TIMEOUT):
|
||||
|
||||
sock = socket.socket(family, socket.SOCK_STREAM)
|
||||
|
||||
if timeout is None:
|
||||
raise RuntimeError('timeout is required')
|
||||
if timeout <= 0:
|
||||
raise RuntimeError('only blocking sockets are supported')
|
||||
sock.settimeout(timeout)
|
||||
|
||||
return TestThreadedClient(
|
||||
self, sock, client_prog, timeout)
|
||||
|
||||
def unix_server(self, *args, **kwargs):
|
||||
if not hasattr(socket, 'AF_UNIX'):
|
||||
raise NotImplementedError
|
||||
return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
|
||||
|
||||
def unix_client(self, *args, **kwargs):
|
||||
if not hasattr(socket, 'AF_UNIX'):
|
||||
raise NotImplementedError
|
||||
return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def unix_sock_name(self):
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
fn = os.path.join(td, 'sock')
|
||||
try:
|
||||
yield fn
|
||||
finally:
|
||||
try:
|
||||
os.unlink(fn)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _abort_socket_test(self, ex):
|
||||
try:
|
||||
self.loop.stop()
|
||||
finally:
|
||||
self.fail(ex)
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Socket Testing Utilities
|
||||
##############################################################################
|
||||
|
||||
|
||||
class TestSocketWrapper:
|
||||
|
||||
def __init__(self, sock):
|
||||
self.__sock = sock
|
||||
|
||||
def recv_all(self, n):
|
||||
buf = b''
|
||||
while len(buf) < n:
|
||||
data = self.recv(n - len(buf))
|
||||
if data == b'':
|
||||
raise ConnectionAbortedError
|
||||
buf += data
|
||||
return buf
|
||||
|
||||
def start_tls(self, ssl_context, *,
|
||||
server_side=False,
|
||||
server_hostname=None):
|
||||
|
||||
ssl_sock = ssl_context.wrap_socket(
|
||||
self.__sock, server_side=server_side,
|
||||
server_hostname=server_hostname,
|
||||
do_handshake_on_connect=False)
|
||||
|
||||
try:
|
||||
ssl_sock.do_handshake()
|
||||
except:
|
||||
ssl_sock.close()
|
||||
raise
|
||||
finally:
|
||||
self.__sock.close()
|
||||
|
||||
self.__sock = ssl_sock
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.__sock, name)
|
||||
|
||||
def __repr__(self):
|
||||
return '<{} {!r}>'.format(type(self).__name__, self.__sock)
|
||||
|
||||
|
||||
class SocketThread(threading.Thread):
|
||||
|
||||
def stop(self):
|
||||
self._active = False
|
||||
self.join()
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
self.stop()
|
||||
|
||||
|
||||
class TestThreadedClient(SocketThread):
|
||||
|
||||
def __init__(self, test, sock, prog, timeout):
|
||||
threading.Thread.__init__(self, None, None, 'test-client')
|
||||
self.daemon = True
|
||||
|
||||
self._timeout = timeout
|
||||
self._sock = sock
|
||||
self._active = True
|
||||
self._prog = prog
|
||||
self._test = test
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
self._prog(TestSocketWrapper(self._sock))
|
||||
except Exception as ex:
|
||||
self._test._abort_socket_test(ex)
|
||||
|
||||
|
||||
class TestThreadedServer(SocketThread):
|
||||
|
||||
def __init__(self, test, sock, prog, timeout, max_clients):
|
||||
threading.Thread.__init__(self, None, None, 'test-server')
|
||||
self.daemon = True
|
||||
|
||||
self._clients = 0
|
||||
self._finished_clients = 0
|
||||
self._max_clients = max_clients
|
||||
self._timeout = timeout
|
||||
self._sock = sock
|
||||
self._active = True
|
||||
|
||||
self._prog = prog
|
||||
|
||||
self._s1, self._s2 = socket.socketpair()
|
||||
self._s1.setblocking(False)
|
||||
|
||||
self._test = test
|
||||
|
||||
def stop(self):
|
||||
try:
|
||||
if self._s2 and self._s2.fileno() != -1:
|
||||
try:
|
||||
self._s2.send(b'stop')
|
||||
except OSError:
|
||||
pass
|
||||
finally:
|
||||
super().stop()
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
with self._sock:
|
||||
self._sock.setblocking(False)
|
||||
self._run()
|
||||
finally:
|
||||
self._s1.close()
|
||||
self._s2.close()
|
||||
|
||||
def _run(self):
|
||||
while self._active:
|
||||
if self._clients >= self._max_clients:
|
||||
return
|
||||
|
||||
r, w, x = select.select(
|
||||
[self._sock, self._s1], [], [], self._timeout)
|
||||
|
||||
if self._s1 in r:
|
||||
return
|
||||
|
||||
if self._sock in r:
|
||||
try:
|
||||
conn, addr = self._sock.accept()
|
||||
except BlockingIOError:
|
||||
continue
|
||||
except TimeoutError:
|
||||
if not self._active:
|
||||
return
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
self._clients += 1
|
||||
conn.settimeout(self._timeout)
|
||||
try:
|
||||
with conn:
|
||||
self._handle_client(conn)
|
||||
except Exception as ex:
|
||||
self._active = False
|
||||
try:
|
||||
raise
|
||||
finally:
|
||||
self._test._abort_socket_test(ex)
|
||||
|
||||
def _handle_client(self, sock):
|
||||
self._prog(TestSocketWrapper(sock))
|
||||
|
||||
@property
|
||||
def addr(self):
|
||||
return self._sock.getsockname()
|
||||
2236
Utils/PythonNew32/Lib/test/test_asyncio/test_base_events.py
Normal file
2236
Utils/PythonNew32/Lib/test/test_asyncio/test_base_events.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,89 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from test.test_asyncio import functional as func_tests
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
class ReceiveStuffProto(asyncio.BufferedProtocol):
|
||||
def __init__(self, cb, con_lost_fut):
|
||||
self.cb = cb
|
||||
self.con_lost_fut = con_lost_fut
|
||||
|
||||
def get_buffer(self, sizehint):
|
||||
self.buffer = bytearray(100)
|
||||
return self.buffer
|
||||
|
||||
def buffer_updated(self, nbytes):
|
||||
self.cb(self.buffer[:nbytes])
|
||||
|
||||
def connection_lost(self, exc):
|
||||
if exc is None:
|
||||
self.con_lost_fut.set_result(None)
|
||||
else:
|
||||
self.con_lost_fut.set_exception(exc)
|
||||
|
||||
|
||||
class BaseTestBufferedProtocol(func_tests.FunctionalTestCaseMixin):
|
||||
|
||||
def new_loop(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_buffered_proto_create_connection(self):
|
||||
|
||||
NOISE = b'12345678+' * 1024
|
||||
|
||||
async def client(addr):
|
||||
data = b''
|
||||
|
||||
def on_buf(buf):
|
||||
nonlocal data
|
||||
data += buf
|
||||
if data == NOISE:
|
||||
tr.write(b'1')
|
||||
|
||||
conn_lost_fut = self.loop.create_future()
|
||||
|
||||
tr, pr = await self.loop.create_connection(
|
||||
lambda: ReceiveStuffProto(on_buf, conn_lost_fut), *addr)
|
||||
|
||||
await conn_lost_fut
|
||||
|
||||
async def on_server_client(reader, writer):
|
||||
writer.write(NOISE)
|
||||
await reader.readexactly(1)
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
srv = self.loop.run_until_complete(
|
||||
asyncio.start_server(
|
||||
on_server_client, '127.0.0.1', 0))
|
||||
|
||||
addr = srv.sockets[0].getsockname()
|
||||
self.loop.run_until_complete(
|
||||
asyncio.wait_for(client(addr), 5))
|
||||
|
||||
srv.close()
|
||||
self.loop.run_until_complete(srv.wait_closed())
|
||||
|
||||
|
||||
class BufferedProtocolSelectorTests(BaseTestBufferedProtocol,
|
||||
unittest.TestCase):
|
||||
|
||||
def new_loop(self):
|
||||
return asyncio.SelectorEventLoop()
|
||||
|
||||
|
||||
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
|
||||
class BufferedProtocolProactorTests(BaseTestBufferedProtocol,
|
||||
unittest.TestCase):
|
||||
|
||||
def new_loop(self):
|
||||
return asyncio.ProactorEventLoop()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
38
Utils/PythonNew32/Lib/test/test_asyncio/test_context.py
Normal file
38
Utils/PythonNew32/Lib/test/test_asyncio/test_context.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import asyncio
|
||||
import decimal
|
||||
import unittest
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
@unittest.skipUnless(decimal.HAVE_CONTEXTVAR, "decimal is built with a thread-local context")
|
||||
class DecimalContextTest(unittest.TestCase):
|
||||
|
||||
def test_asyncio_task_decimal_context(self):
|
||||
async def fractions(t, precision, x, y):
|
||||
with decimal.localcontext() as ctx:
|
||||
ctx.prec = precision
|
||||
a = decimal.Decimal(x) / decimal.Decimal(y)
|
||||
await asyncio.sleep(t)
|
||||
b = decimal.Decimal(x) / decimal.Decimal(y ** 2)
|
||||
return a, b
|
||||
|
||||
async def main():
|
||||
r1, r2 = await asyncio.gather(
|
||||
fractions(0.1, 3, 1, 3), fractions(0.2, 6, 1, 3))
|
||||
|
||||
return r1, r2
|
||||
|
||||
r1, r2 = asyncio.run(main())
|
||||
|
||||
self.assertEqual(str(r1[0]), '0.333')
|
||||
self.assertEqual(str(r1[1]), '0.111')
|
||||
|
||||
self.assertEqual(str(r2[0]), '0.333333')
|
||||
self.assertEqual(str(r2[1]), '0.111111')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -0,0 +1,430 @@
|
||||
"""Tests for base_events.py"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import unittest
|
||||
|
||||
from unittest import mock
|
||||
from asyncio import tasks
|
||||
from test.test_asyncio import utils as test_utils
|
||||
from test.support.script_helper import assert_python_ok
|
||||
|
||||
MOCK_ANY = mock.ANY
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
class EagerTaskFactoryLoopTests:
|
||||
|
||||
Task = None
|
||||
|
||||
def run_coro(self, coro):
|
||||
"""
|
||||
Helper method to run the `coro` coroutine in the test event loop.
|
||||
It helps with making sure the event loop is running before starting
|
||||
to execute `coro`. This is important for testing the eager step
|
||||
functionality, since an eager step is taken only if the event loop
|
||||
is already running.
|
||||
"""
|
||||
|
||||
async def coro_runner():
|
||||
self.assertTrue(asyncio.get_event_loop().is_running())
|
||||
return await coro
|
||||
|
||||
return self.loop.run_until_complete(coro)
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.eager_task_factory = asyncio.create_eager_task_factory(self.Task)
|
||||
self.loop.set_task_factory(self.eager_task_factory)
|
||||
self.set_event_loop(self.loop)
|
||||
|
||||
def test_eager_task_factory_set(self):
|
||||
self.assertIsNotNone(self.eager_task_factory)
|
||||
self.assertIs(self.loop.get_task_factory(), self.eager_task_factory)
|
||||
|
||||
async def noop(): pass
|
||||
|
||||
async def run():
|
||||
t = self.loop.create_task(noop())
|
||||
self.assertIsInstance(t, self.Task)
|
||||
await t
|
||||
|
||||
self.run_coro(run())
|
||||
|
||||
def test_await_future_during_eager_step(self):
|
||||
|
||||
async def set_result(fut, val):
|
||||
fut.set_result(val)
|
||||
|
||||
async def run():
|
||||
fut = self.loop.create_future()
|
||||
t = self.loop.create_task(set_result(fut, 'my message'))
|
||||
# assert the eager step completed the task
|
||||
self.assertTrue(t.done())
|
||||
return await fut
|
||||
|
||||
self.assertEqual(self.run_coro(run()), 'my message')
|
||||
|
||||
def test_eager_completion(self):
|
||||
|
||||
async def coro():
|
||||
return 'hello'
|
||||
|
||||
async def run():
|
||||
t = self.loop.create_task(coro())
|
||||
# assert the eager step completed the task
|
||||
self.assertTrue(t.done())
|
||||
return await t
|
||||
|
||||
self.assertEqual(self.run_coro(run()), 'hello')
|
||||
|
||||
def test_block_after_eager_step(self):
|
||||
|
||||
async def coro():
|
||||
await asyncio.sleep(0.1)
|
||||
return 'finished after blocking'
|
||||
|
||||
async def run():
|
||||
t = self.loop.create_task(coro())
|
||||
self.assertFalse(t.done())
|
||||
result = await t
|
||||
self.assertTrue(t.done())
|
||||
return result
|
||||
|
||||
self.assertEqual(self.run_coro(run()), 'finished after blocking')
|
||||
|
||||
def test_cancellation_after_eager_completion(self):
|
||||
|
||||
async def coro():
|
||||
return 'finished without blocking'
|
||||
|
||||
async def run():
|
||||
t = self.loop.create_task(coro())
|
||||
t.cancel()
|
||||
result = await t
|
||||
# finished task can't be cancelled
|
||||
self.assertFalse(t.cancelled())
|
||||
return result
|
||||
|
||||
self.assertEqual(self.run_coro(run()), 'finished without blocking')
|
||||
|
||||
def test_cancellation_after_eager_step_blocks(self):
|
||||
|
||||
async def coro():
|
||||
await asyncio.sleep(0.1)
|
||||
return 'finished after blocking'
|
||||
|
||||
async def run():
|
||||
t = self.loop.create_task(coro())
|
||||
t.cancel('cancellation message')
|
||||
self.assertGreater(t.cancelling(), 0)
|
||||
result = await t
|
||||
|
||||
with self.assertRaises(asyncio.CancelledError) as cm:
|
||||
self.run_coro(run())
|
||||
|
||||
self.assertEqual('cancellation message', cm.exception.args[0])
|
||||
|
||||
def test_current_task(self):
|
||||
captured_current_task = None
|
||||
|
||||
async def coro():
|
||||
nonlocal captured_current_task
|
||||
captured_current_task = asyncio.current_task()
|
||||
# verify the task before and after blocking is identical
|
||||
await asyncio.sleep(0.1)
|
||||
self.assertIs(asyncio.current_task(), captured_current_task)
|
||||
|
||||
async def run():
|
||||
t = self.loop.create_task(coro())
|
||||
self.assertIs(captured_current_task, t)
|
||||
await t
|
||||
|
||||
self.run_coro(run())
|
||||
captured_current_task = None
|
||||
|
||||
def test_all_tasks_with_eager_completion(self):
|
||||
captured_all_tasks = None
|
||||
|
||||
async def coro():
|
||||
nonlocal captured_all_tasks
|
||||
captured_all_tasks = asyncio.all_tasks()
|
||||
|
||||
async def run():
|
||||
t = self.loop.create_task(coro())
|
||||
self.assertIn(t, captured_all_tasks)
|
||||
self.assertNotIn(t, asyncio.all_tasks())
|
||||
|
||||
self.run_coro(run())
|
||||
|
||||
def test_all_tasks_with_blocking(self):
|
||||
captured_eager_all_tasks = None
|
||||
|
||||
async def coro(fut1, fut2):
|
||||
nonlocal captured_eager_all_tasks
|
||||
captured_eager_all_tasks = asyncio.all_tasks()
|
||||
await fut1
|
||||
fut2.set_result(None)
|
||||
|
||||
async def run():
|
||||
fut1 = self.loop.create_future()
|
||||
fut2 = self.loop.create_future()
|
||||
t = self.loop.create_task(coro(fut1, fut2))
|
||||
self.assertIn(t, captured_eager_all_tasks)
|
||||
self.assertIn(t, asyncio.all_tasks())
|
||||
fut1.set_result(None)
|
||||
await fut2
|
||||
self.assertNotIn(t, asyncio.all_tasks())
|
||||
|
||||
self.run_coro(run())
|
||||
|
||||
def test_context_vars(self):
|
||||
cv = contextvars.ContextVar('cv', default=0)
|
||||
|
||||
coro_first_step_ran = False
|
||||
coro_second_step_ran = False
|
||||
|
||||
async def coro():
|
||||
nonlocal coro_first_step_ran
|
||||
nonlocal coro_second_step_ran
|
||||
self.assertEqual(cv.get(), 1)
|
||||
cv.set(2)
|
||||
self.assertEqual(cv.get(), 2)
|
||||
coro_first_step_ran = True
|
||||
await asyncio.sleep(0.1)
|
||||
self.assertEqual(cv.get(), 2)
|
||||
cv.set(3)
|
||||
self.assertEqual(cv.get(), 3)
|
||||
coro_second_step_ran = True
|
||||
|
||||
async def run():
|
||||
cv.set(1)
|
||||
t = self.loop.create_task(coro())
|
||||
self.assertTrue(coro_first_step_ran)
|
||||
self.assertFalse(coro_second_step_ran)
|
||||
self.assertEqual(cv.get(), 1)
|
||||
await t
|
||||
self.assertTrue(coro_second_step_ran)
|
||||
self.assertEqual(cv.get(), 1)
|
||||
|
||||
self.run_coro(run())
|
||||
|
||||
def test_staggered_race_with_eager_tasks(self):
|
||||
# See https://github.com/python/cpython/issues/124309
|
||||
|
||||
async def fail():
|
||||
await asyncio.sleep(0)
|
||||
raise ValueError("no good")
|
||||
|
||||
async def blocked():
|
||||
fut = asyncio.Future()
|
||||
await fut
|
||||
|
||||
async def run():
|
||||
winner, index, excs = await asyncio.staggered.staggered_race(
|
||||
[
|
||||
lambda: blocked(),
|
||||
lambda: asyncio.sleep(1, result="sleep1"),
|
||||
lambda: fail()
|
||||
],
|
||||
delay=0.25
|
||||
)
|
||||
self.assertEqual(winner, 'sleep1')
|
||||
self.assertEqual(index, 1)
|
||||
self.assertIsNone(excs[index])
|
||||
self.assertIsInstance(excs[0], asyncio.CancelledError)
|
||||
self.assertIsInstance(excs[2], ValueError)
|
||||
|
||||
self.run_coro(run())
|
||||
|
||||
def test_staggered_race_with_eager_tasks_no_delay(self):
|
||||
# See https://github.com/python/cpython/issues/124309
|
||||
async def fail():
|
||||
raise ValueError("no good")
|
||||
|
||||
async def run():
|
||||
winner, index, excs = await asyncio.staggered.staggered_race(
|
||||
[
|
||||
lambda: fail(),
|
||||
lambda: asyncio.sleep(1, result="sleep1"),
|
||||
lambda: asyncio.sleep(0, result="sleep0"),
|
||||
],
|
||||
delay=None
|
||||
)
|
||||
self.assertEqual(winner, 'sleep1')
|
||||
self.assertEqual(index, 1)
|
||||
self.assertIsNone(excs[index])
|
||||
self.assertIsInstance(excs[0], ValueError)
|
||||
self.assertEqual(len(excs), 2)
|
||||
|
||||
self.run_coro(run())
|
||||
|
||||
|
||||
class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
|
||||
Task = tasks._PyTask
|
||||
|
||||
|
||||
@unittest.skipUnless(hasattr(tasks, '_CTask'),
|
||||
'requires the C _asyncio module')
|
||||
class CEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
|
||||
Task = getattr(tasks, '_CTask', None)
|
||||
|
||||
def test_issue105987(self):
|
||||
code = """if 1:
|
||||
from _asyncio import _swap_current_task
|
||||
|
||||
class DummyTask:
|
||||
pass
|
||||
|
||||
class DummyLoop:
|
||||
pass
|
||||
|
||||
l = DummyLoop()
|
||||
_swap_current_task(l, DummyTask())
|
||||
t = _swap_current_task(l, None)
|
||||
"""
|
||||
|
||||
_, out, err = assert_python_ok("-c", code)
|
||||
self.assertFalse(err)
|
||||
|
||||
def test_issue122332(self):
|
||||
async def coro():
|
||||
pass
|
||||
|
||||
async def run():
|
||||
task = self.loop.create_task(coro())
|
||||
await task
|
||||
self.assertIsNone(task.get_coro())
|
||||
|
||||
self.run_coro(run())
|
||||
|
||||
def test_name(self):
|
||||
name = None
|
||||
async def coro():
|
||||
nonlocal name
|
||||
name = asyncio.current_task().get_name()
|
||||
|
||||
async def main():
|
||||
task = self.loop.create_task(coro(), name="test name")
|
||||
self.assertEqual(name, "test name")
|
||||
await task
|
||||
|
||||
self.run_coro(coro())
|
||||
|
||||
class AsyncTaskCounter:
|
||||
def __init__(self, loop, *, task_class, eager):
|
||||
self.suspense_count = 0
|
||||
self.task_count = 0
|
||||
|
||||
def CountingTask(*args, eager_start=False, **kwargs):
|
||||
if not eager_start:
|
||||
self.task_count += 1
|
||||
kwargs["eager_start"] = eager_start
|
||||
return task_class(*args, **kwargs)
|
||||
|
||||
if eager:
|
||||
factory = asyncio.create_eager_task_factory(CountingTask)
|
||||
else:
|
||||
def factory(loop, coro, **kwargs):
|
||||
return CountingTask(coro, loop=loop, **kwargs)
|
||||
loop.set_task_factory(factory)
|
||||
|
||||
def get(self):
|
||||
return self.task_count
|
||||
|
||||
|
||||
async def awaitable_chain(depth):
|
||||
if depth == 0:
|
||||
return 0
|
||||
return 1 + await awaitable_chain(depth - 1)
|
||||
|
||||
|
||||
async def recursive_taskgroups(width, depth):
|
||||
if depth == 0:
|
||||
return
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
futures = [
|
||||
tg.create_task(recursive_taskgroups(width, depth - 1))
|
||||
for _ in range(width)
|
||||
]
|
||||
|
||||
|
||||
async def recursive_gather(width, depth):
|
||||
if depth == 0:
|
||||
return
|
||||
|
||||
await asyncio.gather(
|
||||
*[recursive_gather(width, depth - 1) for _ in range(width)]
|
||||
)
|
||||
|
||||
|
||||
class BaseTaskCountingTests:
|
||||
|
||||
Task = None
|
||||
eager = None
|
||||
expected_task_count = None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.counter = AsyncTaskCounter(self.loop, task_class=self.Task, eager=self.eager)
|
||||
self.set_event_loop(self.loop)
|
||||
|
||||
def test_awaitables_chain(self):
|
||||
observed_depth = self.loop.run_until_complete(awaitable_chain(100))
|
||||
self.assertEqual(observed_depth, 100)
|
||||
self.assertEqual(self.counter.get(), 0 if self.eager else 1)
|
||||
|
||||
def test_recursive_taskgroups(self):
|
||||
num_tasks = self.loop.run_until_complete(recursive_taskgroups(5, 4))
|
||||
self.assertEqual(self.counter.get(), self.expected_task_count)
|
||||
|
||||
def test_recursive_gather(self):
|
||||
self.loop.run_until_complete(recursive_gather(5, 4))
|
||||
self.assertEqual(self.counter.get(), self.expected_task_count)
|
||||
|
||||
|
||||
class BaseNonEagerTaskFactoryTests(BaseTaskCountingTests):
|
||||
eager = False
|
||||
expected_task_count = 781 # 1 + 5 + 5^2 + 5^3 + 5^4
|
||||
|
||||
|
||||
class BaseEagerTaskFactoryTests(BaseTaskCountingTests):
|
||||
eager = True
|
||||
expected_task_count = 0
|
||||
|
||||
|
||||
class NonEagerTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
|
||||
Task = asyncio.Task
|
||||
|
||||
|
||||
class EagerTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
|
||||
Task = asyncio.Task
|
||||
|
||||
|
||||
class NonEagerPyTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
|
||||
Task = tasks._PyTask
|
||||
|
||||
|
||||
class EagerPyTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
|
||||
Task = tasks._PyTask
|
||||
|
||||
|
||||
@unittest.skipUnless(hasattr(tasks, '_CTask'),
|
||||
'requires the C _asyncio module')
|
||||
class NonEagerCTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
|
||||
Task = getattr(tasks, '_CTask', None)
|
||||
|
||||
|
||||
@unittest.skipUnless(hasattr(tasks, '_CTask'),
|
||||
'requires the C _asyncio module')
|
||||
class EagerCTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
|
||||
Task = getattr(tasks, '_CTask', None)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
3053
Utils/PythonNew32/Lib/test/test_asyncio/test_events.py
Normal file
3053
Utils/PythonNew32/Lib/test/test_asyncio/test_events.py
Normal file
File diff suppressed because it is too large
Load Diff
1148
Utils/PythonNew32/Lib/test/test_asyncio/test_futures.py
Normal file
1148
Utils/PythonNew32/Lib/test/test_asyncio/test_futures.py
Normal file
File diff suppressed because it is too large
Load Diff
95
Utils/PythonNew32/Lib/test/test_asyncio/test_futures2.py
Normal file
95
Utils/PythonNew32/Lib/test/test_asyncio/test_futures2.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# IsolatedAsyncioTestCase based tests
|
||||
import asyncio
|
||||
import contextvars
|
||||
import traceback
|
||||
import unittest
|
||||
from asyncio import tasks
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
class FutureTests:
|
||||
|
||||
async def test_future_traceback(self):
|
||||
|
||||
async def raise_exc():
|
||||
raise TypeError(42)
|
||||
|
||||
future = self.cls(raise_exc())
|
||||
|
||||
for _ in range(5):
|
||||
try:
|
||||
await future
|
||||
except TypeError as e:
|
||||
tb = ''.join(traceback.format_tb(e.__traceback__))
|
||||
self.assertEqual(tb.count("await future"), 1)
|
||||
else:
|
||||
self.fail('TypeError was not raised')
|
||||
|
||||
async def test_task_exc_handler_correct_context(self):
|
||||
# see https://github.com/python/cpython/issues/96704
|
||||
name = contextvars.ContextVar('name', default='foo')
|
||||
exc_handler_called = False
|
||||
|
||||
def exc_handler(*args):
|
||||
self.assertEqual(name.get(), 'bar')
|
||||
nonlocal exc_handler_called
|
||||
exc_handler_called = True
|
||||
|
||||
async def task():
|
||||
name.set('bar')
|
||||
1/0
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.set_exception_handler(exc_handler)
|
||||
self.cls(task())
|
||||
await asyncio.sleep(0)
|
||||
self.assertTrue(exc_handler_called)
|
||||
|
||||
async def test_handle_exc_handler_correct_context(self):
|
||||
# see https://github.com/python/cpython/issues/96704
|
||||
name = contextvars.ContextVar('name', default='foo')
|
||||
exc_handler_called = False
|
||||
|
||||
def exc_handler(*args):
|
||||
self.assertEqual(name.get(), 'bar')
|
||||
nonlocal exc_handler_called
|
||||
exc_handler_called = True
|
||||
|
||||
def callback():
|
||||
name.set('bar')
|
||||
1/0
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.set_exception_handler(exc_handler)
|
||||
loop.call_soon(callback)
|
||||
await asyncio.sleep(0)
|
||||
self.assertTrue(exc_handler_called)
|
||||
|
||||
@unittest.skipUnless(hasattr(tasks, '_CTask'),
|
||||
'requires the C _asyncio module')
|
||||
class CFutureTests(FutureTests, unittest.IsolatedAsyncioTestCase):
|
||||
cls = tasks._CTask
|
||||
|
||||
class PyFutureTests(FutureTests, unittest.IsolatedAsyncioTestCase):
|
||||
cls = tasks._PyTask
|
||||
|
||||
class FutureReprTests(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_recursive_repr_for_pending_tasks(self):
|
||||
# The call crashes if the guard for recursive call
|
||||
# in base_futures:_future_repr_info is absent
|
||||
# See Also: https://bugs.python.org/issue42183
|
||||
|
||||
async def func():
|
||||
return asyncio.all_tasks()
|
||||
|
||||
# The repr() call should not raise RecursionError at first.
|
||||
waiter = await asyncio.wait_for(asyncio.Task(func()),timeout=10)
|
||||
self.assertIn('...', repr(waiter))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
1825
Utils/PythonNew32/Lib/test/test_asyncio/test_locks.py
Normal file
1825
Utils/PythonNew32/Lib/test/test_asyncio/test_locks.py
Normal file
File diff suppressed because it is too large
Load Diff
212
Utils/PythonNew32/Lib/test/test_asyncio/test_pep492.py
Normal file
212
Utils/PythonNew32/Lib/test/test_asyncio/test_pep492.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""Tests support for new syntax introduced by PEP 492."""
|
||||
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import asyncio
|
||||
from test.test_asyncio import utils as test_utils
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
# Test that asyncio.iscoroutine() uses collections.abc.Coroutine
|
||||
class FakeCoro:
|
||||
def send(self, value):
|
||||
pass
|
||||
|
||||
def throw(self, typ, val=None, tb=None):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def __await__(self):
|
||||
yield
|
||||
|
||||
|
||||
class BaseTest(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = asyncio.BaseEventLoop()
|
||||
self.loop._process_events = mock.Mock()
|
||||
self.loop._selector = mock.Mock()
|
||||
self.loop._selector.select.return_value = ()
|
||||
self.set_event_loop(self.loop)
|
||||
|
||||
|
||||
class LockTests(BaseTest):
|
||||
|
||||
def test_context_manager_async_with(self):
|
||||
primitives = [
|
||||
asyncio.Lock(),
|
||||
asyncio.Condition(),
|
||||
asyncio.Semaphore(),
|
||||
asyncio.BoundedSemaphore(),
|
||||
]
|
||||
|
||||
async def test(lock):
|
||||
await asyncio.sleep(0.01)
|
||||
self.assertFalse(lock.locked())
|
||||
async with lock as _lock:
|
||||
self.assertIs(_lock, None)
|
||||
self.assertTrue(lock.locked())
|
||||
await asyncio.sleep(0.01)
|
||||
self.assertTrue(lock.locked())
|
||||
self.assertFalse(lock.locked())
|
||||
|
||||
for primitive in primitives:
|
||||
self.loop.run_until_complete(test(primitive))
|
||||
self.assertFalse(primitive.locked())
|
||||
|
||||
def test_context_manager_with_await(self):
|
||||
primitives = [
|
||||
asyncio.Lock(),
|
||||
asyncio.Condition(),
|
||||
asyncio.Semaphore(),
|
||||
asyncio.BoundedSemaphore(),
|
||||
]
|
||||
|
||||
async def test(lock):
|
||||
await asyncio.sleep(0.01)
|
||||
self.assertFalse(lock.locked())
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
"can't be used in 'await' expression"
|
||||
):
|
||||
with await lock:
|
||||
pass
|
||||
|
||||
for primitive in primitives:
|
||||
self.loop.run_until_complete(test(primitive))
|
||||
self.assertFalse(primitive.locked())
|
||||
|
||||
|
||||
class StreamReaderTests(BaseTest):
|
||||
|
||||
def test_readline(self):
|
||||
DATA = b'line1\nline2\nline3'
|
||||
|
||||
stream = asyncio.StreamReader(loop=self.loop)
|
||||
stream.feed_data(DATA)
|
||||
stream.feed_eof()
|
||||
|
||||
async def reader():
|
||||
data = []
|
||||
async for line in stream:
|
||||
data.append(line)
|
||||
return data
|
||||
|
||||
data = self.loop.run_until_complete(reader())
|
||||
self.assertEqual(data, [b'line1\n', b'line2\n', b'line3'])
|
||||
|
||||
|
||||
class CoroutineTests(BaseTest):
|
||||
|
||||
def test_iscoroutine(self):
|
||||
async def foo(): pass
|
||||
|
||||
f = foo()
|
||||
try:
|
||||
self.assertTrue(asyncio.iscoroutine(f))
|
||||
finally:
|
||||
f.close() # silence warning
|
||||
|
||||
self.assertTrue(asyncio.iscoroutine(FakeCoro()))
|
||||
|
||||
def test_iscoroutine_generator(self):
|
||||
def foo(): yield
|
||||
|
||||
self.assertFalse(asyncio.iscoroutine(foo()))
|
||||
|
||||
|
||||
def test_iscoroutinefunction(self):
|
||||
async def foo(): pass
|
||||
self.assertTrue(asyncio.iscoroutinefunction(foo))
|
||||
|
||||
def test_async_def_coroutines(self):
|
||||
async def bar():
|
||||
return 'spam'
|
||||
async def foo():
|
||||
return await bar()
|
||||
|
||||
# production mode
|
||||
data = self.loop.run_until_complete(foo())
|
||||
self.assertEqual(data, 'spam')
|
||||
|
||||
# debug mode
|
||||
self.loop.set_debug(True)
|
||||
data = self.loop.run_until_complete(foo())
|
||||
self.assertEqual(data, 'spam')
|
||||
|
||||
def test_debug_mode_manages_coroutine_origin_tracking(self):
|
||||
async def start():
|
||||
self.assertTrue(sys.get_coroutine_origin_tracking_depth() > 0)
|
||||
|
||||
self.assertEqual(sys.get_coroutine_origin_tracking_depth(), 0)
|
||||
self.loop.set_debug(True)
|
||||
self.loop.run_until_complete(start())
|
||||
self.assertEqual(sys.get_coroutine_origin_tracking_depth(), 0)
|
||||
|
||||
def test_types_coroutine(self):
|
||||
def gen():
|
||||
yield from ()
|
||||
return 'spam'
|
||||
|
||||
@types.coroutine
|
||||
def func():
|
||||
return gen()
|
||||
|
||||
async def coro():
|
||||
wrapper = func()
|
||||
self.assertIsInstance(wrapper, types._GeneratorWrapper)
|
||||
return await wrapper
|
||||
|
||||
data = self.loop.run_until_complete(coro())
|
||||
self.assertEqual(data, 'spam')
|
||||
|
||||
def test_task_print_stack(self):
|
||||
T = None
|
||||
|
||||
async def foo():
|
||||
f = T.get_stack(limit=1)
|
||||
try:
|
||||
self.assertEqual(f[0].f_code.co_name, 'foo')
|
||||
finally:
|
||||
f = None
|
||||
|
||||
async def runner():
|
||||
nonlocal T
|
||||
T = asyncio.ensure_future(foo(), loop=self.loop)
|
||||
await T
|
||||
|
||||
self.loop.run_until_complete(runner())
|
||||
|
||||
def test_double_await(self):
|
||||
async def afunc():
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def runner():
|
||||
coro = afunc()
|
||||
t = self.loop.create_task(coro)
|
||||
try:
|
||||
await asyncio.sleep(0)
|
||||
await coro
|
||||
finally:
|
||||
t.cancel()
|
||||
|
||||
self.loop.set_debug(True)
|
||||
with self.assertRaises(
|
||||
RuntimeError,
|
||||
msg='coroutine is being awaited already'):
|
||||
|
||||
self.loop.run_until_complete(runner())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
1092
Utils/PythonNew32/Lib/test/test_asyncio/test_proactor_events.py
Normal file
1092
Utils/PythonNew32/Lib/test/test_asyncio/test_proactor_events.py
Normal file
File diff suppressed because it is too large
Load Diff
67
Utils/PythonNew32/Lib/test/test_asyncio/test_protocols.py
Normal file
67
Utils/PythonNew32/Lib/test/test_asyncio/test_protocols.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
import asyncio
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
# not needed for the test file but added for uniformness with all other
|
||||
# asyncio test files for the sake of unified cleanup
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
class ProtocolsAbsTests(unittest.TestCase):
|
||||
|
||||
def test_base_protocol(self):
|
||||
f = mock.Mock()
|
||||
p = asyncio.BaseProtocol()
|
||||
self.assertIsNone(p.connection_made(f))
|
||||
self.assertIsNone(p.connection_lost(f))
|
||||
self.assertIsNone(p.pause_writing())
|
||||
self.assertIsNone(p.resume_writing())
|
||||
self.assertFalse(hasattr(p, '__dict__'))
|
||||
|
||||
def test_protocol(self):
|
||||
f = mock.Mock()
|
||||
p = asyncio.Protocol()
|
||||
self.assertIsNone(p.connection_made(f))
|
||||
self.assertIsNone(p.connection_lost(f))
|
||||
self.assertIsNone(p.data_received(f))
|
||||
self.assertIsNone(p.eof_received())
|
||||
self.assertIsNone(p.pause_writing())
|
||||
self.assertIsNone(p.resume_writing())
|
||||
self.assertFalse(hasattr(p, '__dict__'))
|
||||
|
||||
def test_buffered_protocol(self):
|
||||
f = mock.Mock()
|
||||
p = asyncio.BufferedProtocol()
|
||||
self.assertIsNone(p.connection_made(f))
|
||||
self.assertIsNone(p.connection_lost(f))
|
||||
self.assertIsNone(p.get_buffer(100))
|
||||
self.assertIsNone(p.buffer_updated(150))
|
||||
self.assertIsNone(p.pause_writing())
|
||||
self.assertIsNone(p.resume_writing())
|
||||
self.assertFalse(hasattr(p, '__dict__'))
|
||||
|
||||
def test_datagram_protocol(self):
|
||||
f = mock.Mock()
|
||||
dp = asyncio.DatagramProtocol()
|
||||
self.assertIsNone(dp.connection_made(f))
|
||||
self.assertIsNone(dp.connection_lost(f))
|
||||
self.assertIsNone(dp.error_received(f))
|
||||
self.assertIsNone(dp.datagram_received(f, f))
|
||||
self.assertFalse(hasattr(dp, '__dict__'))
|
||||
|
||||
def test_subprocess_protocol(self):
|
||||
f = mock.Mock()
|
||||
sp = asyncio.SubprocessProtocol()
|
||||
self.assertIsNone(sp.connection_made(f))
|
||||
self.assertIsNone(sp.connection_lost(f))
|
||||
self.assertIsNone(sp.pipe_data_received(1, f))
|
||||
self.assertIsNone(sp.pipe_connection_lost(1, f))
|
||||
self.assertIsNone(sp.process_exited())
|
||||
self.assertFalse(hasattr(sp, '__dict__'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
725
Utils/PythonNew32/Lib/test/test_asyncio/test_queues.py
Normal file
725
Utils/PythonNew32/Lib/test/test_asyncio/test_queues.py
Normal file
@@ -0,0 +1,725 @@
|
||||
"""Tests for queues.py"""
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
from types import GenericAlias
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
class QueueBasicTests(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def _test_repr_or_str(self, fn, expect_id):
|
||||
"""Test Queue's repr or str.
|
||||
|
||||
fn is repr or str. expect_id is True if we expect the Queue's id to
|
||||
appear in fn(Queue()).
|
||||
"""
|
||||
q = asyncio.Queue()
|
||||
self.assertTrue(fn(q).startswith('<Queue'), fn(q))
|
||||
id_is_present = hex(id(q)) in fn(q)
|
||||
self.assertEqual(expect_id, id_is_present)
|
||||
|
||||
# getters
|
||||
q = asyncio.Queue()
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
# Start a task that waits to get.
|
||||
getter = tg.create_task(q.get())
|
||||
# Let it start waiting.
|
||||
await asyncio.sleep(0)
|
||||
self.assertTrue('_getters[1]' in fn(q))
|
||||
# resume q.get coroutine to finish generator
|
||||
q.put_nowait(0)
|
||||
|
||||
self.assertEqual(0, await getter)
|
||||
|
||||
# putters
|
||||
q = asyncio.Queue(maxsize=1)
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
q.put_nowait(1)
|
||||
# Start a task that waits to put.
|
||||
putter = tg.create_task(q.put(2))
|
||||
# Let it start waiting.
|
||||
await asyncio.sleep(0)
|
||||
self.assertTrue('_putters[1]' in fn(q))
|
||||
# resume q.put coroutine to finish generator
|
||||
q.get_nowait()
|
||||
|
||||
self.assertTrue(putter.done())
|
||||
|
||||
q = asyncio.Queue()
|
||||
q.put_nowait(1)
|
||||
self.assertTrue('_queue=[1]' in fn(q))
|
||||
|
||||
async def test_repr(self):
|
||||
await self._test_repr_or_str(repr, True)
|
||||
|
||||
async def test_str(self):
|
||||
await self._test_repr_or_str(str, False)
|
||||
|
||||
def test_generic_alias(self):
|
||||
q = asyncio.Queue[int]
|
||||
self.assertEqual(q.__args__, (int,))
|
||||
self.assertIsInstance(q, GenericAlias)
|
||||
|
||||
async def test_empty(self):
|
||||
q = asyncio.Queue()
|
||||
self.assertTrue(q.empty())
|
||||
await q.put(1)
|
||||
self.assertFalse(q.empty())
|
||||
self.assertEqual(1, await q.get())
|
||||
self.assertTrue(q.empty())
|
||||
|
||||
async def test_full(self):
|
||||
q = asyncio.Queue()
|
||||
self.assertFalse(q.full())
|
||||
|
||||
q = asyncio.Queue(maxsize=1)
|
||||
await q.put(1)
|
||||
self.assertTrue(q.full())
|
||||
|
||||
async def test_order(self):
|
||||
q = asyncio.Queue()
|
||||
for i in [1, 3, 2]:
|
||||
await q.put(i)
|
||||
|
||||
items = [await q.get() for _ in range(3)]
|
||||
self.assertEqual([1, 3, 2], items)
|
||||
|
||||
async def test_maxsize(self):
|
||||
q = asyncio.Queue(maxsize=2)
|
||||
self.assertEqual(2, q.maxsize)
|
||||
have_been_put = []
|
||||
|
||||
async def putter():
|
||||
for i in range(3):
|
||||
await q.put(i)
|
||||
have_been_put.append(i)
|
||||
return True
|
||||
|
||||
t = asyncio.create_task(putter())
|
||||
for i in range(2):
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# The putter is blocked after putting two items.
|
||||
self.assertEqual([0, 1], have_been_put)
|
||||
self.assertEqual(0, await q.get())
|
||||
|
||||
# Let the putter resume and put last item.
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual([0, 1, 2], have_been_put)
|
||||
self.assertEqual(1, await q.get())
|
||||
self.assertEqual(2, await q.get())
|
||||
|
||||
self.assertTrue(t.done())
|
||||
self.assertTrue(t.result())
|
||||
|
||||
|
||||
class QueueGetTests(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_blocking_get(self):
|
||||
q = asyncio.Queue()
|
||||
q.put_nowait(1)
|
||||
|
||||
self.assertEqual(1, await q.get())
|
||||
|
||||
async def test_get_with_putters(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
q = asyncio.Queue(1)
|
||||
await q.put(1)
|
||||
|
||||
waiter = loop.create_future()
|
||||
q._putters.append(waiter)
|
||||
|
||||
self.assertEqual(1, await q.get())
|
||||
self.assertTrue(waiter.done())
|
||||
self.assertIsNone(waiter.result())
|
||||
|
||||
async def test_blocking_get_wait(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
q = asyncio.Queue()
|
||||
started = asyncio.Event()
|
||||
finished = False
|
||||
|
||||
async def queue_get():
|
||||
nonlocal finished
|
||||
started.set()
|
||||
res = await q.get()
|
||||
finished = True
|
||||
return res
|
||||
|
||||
queue_get_task = asyncio.create_task(queue_get())
|
||||
await started.wait()
|
||||
self.assertFalse(finished)
|
||||
loop.call_later(0.01, q.put_nowait, 1)
|
||||
res = await queue_get_task
|
||||
self.assertTrue(finished)
|
||||
self.assertEqual(1, res)
|
||||
|
||||
def test_nonblocking_get(self):
|
||||
q = asyncio.Queue()
|
||||
q.put_nowait(1)
|
||||
self.assertEqual(1, q.get_nowait())
|
||||
|
||||
def test_nonblocking_get_exception(self):
|
||||
q = asyncio.Queue()
|
||||
self.assertRaises(asyncio.QueueEmpty, q.get_nowait)
|
||||
|
||||
async def test_get_cancelled_race(self):
|
||||
q = asyncio.Queue()
|
||||
|
||||
t1 = asyncio.create_task(q.get())
|
||||
t2 = asyncio.create_task(q.get())
|
||||
|
||||
await asyncio.sleep(0)
|
||||
t1.cancel()
|
||||
await asyncio.sleep(0)
|
||||
self.assertTrue(t1.done())
|
||||
await q.put('a')
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual('a', await t2)
|
||||
|
||||
async def test_get_with_waiting_putters(self):
|
||||
q = asyncio.Queue(maxsize=1)
|
||||
asyncio.create_task(q.put('a'))
|
||||
asyncio.create_task(q.put('b'))
|
||||
self.assertEqual(await q.get(), 'a')
|
||||
self.assertEqual(await q.get(), 'b')
|
||||
|
||||
async def test_why_are_getters_waiting(self):
|
||||
async def consumer(queue, num_expected):
|
||||
for _ in range(num_expected):
|
||||
await queue.get()
|
||||
|
||||
async def producer(queue, num_items):
|
||||
for i in range(num_items):
|
||||
await queue.put(i)
|
||||
|
||||
producer_num_items = 5
|
||||
|
||||
q = asyncio.Queue(1)
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
tg.create_task(producer(q, producer_num_items))
|
||||
tg.create_task(consumer(q, producer_num_items))
|
||||
|
||||
async def test_cancelled_getters_not_being_held_in_self_getters(self):
|
||||
queue = asyncio.Queue(maxsize=5)
|
||||
|
||||
with self.assertRaises(TimeoutError):
|
||||
await asyncio.wait_for(queue.get(), 0.1)
|
||||
|
||||
self.assertEqual(len(queue._getters), 0)
|
||||
|
||||
|
||||
class QueuePutTests(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_blocking_put(self):
|
||||
q = asyncio.Queue()
|
||||
|
||||
# No maxsize, won't block.
|
||||
await q.put(1)
|
||||
self.assertEqual(1, await q.get())
|
||||
|
||||
async def test_blocking_put_wait(self):
|
||||
q = asyncio.Queue(maxsize=1)
|
||||
started = asyncio.Event()
|
||||
finished = False
|
||||
|
||||
async def queue_put():
|
||||
nonlocal finished
|
||||
started.set()
|
||||
await q.put(1)
|
||||
await q.put(2)
|
||||
finished = True
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.call_later(0.01, q.get_nowait)
|
||||
queue_put_task = asyncio.create_task(queue_put())
|
||||
await started.wait()
|
||||
self.assertFalse(finished)
|
||||
await queue_put_task
|
||||
self.assertTrue(finished)
|
||||
|
||||
def test_nonblocking_put(self):
|
||||
q = asyncio.Queue()
|
||||
q.put_nowait(1)
|
||||
self.assertEqual(1, q.get_nowait())
|
||||
|
||||
async def test_get_cancel_drop_one_pending_reader(self):
|
||||
q = asyncio.Queue()
|
||||
|
||||
reader = asyncio.create_task(q.get())
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
q.put_nowait(1)
|
||||
q.put_nowait(2)
|
||||
reader.cancel()
|
||||
|
||||
try:
|
||||
await reader
|
||||
except asyncio.CancelledError:
|
||||
# try again
|
||||
reader = asyncio.create_task(q.get())
|
||||
await reader
|
||||
|
||||
result = reader.result()
|
||||
# if we get 2, it means 1 got dropped!
|
||||
self.assertEqual(1, result)
|
||||
|
||||
async def test_get_cancel_drop_many_pending_readers(self):
|
||||
q = asyncio.Queue()
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
reader1 = tg.create_task(q.get())
|
||||
reader2 = tg.create_task(q.get())
|
||||
reader3 = tg.create_task(q.get())
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
q.put_nowait(1)
|
||||
q.put_nowait(2)
|
||||
reader1.cancel()
|
||||
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await reader1
|
||||
|
||||
await reader3
|
||||
|
||||
# It is undefined in which order concurrent readers receive results.
|
||||
self.assertEqual({reader2.result(), reader3.result()}, {1, 2})
|
||||
|
||||
async def test_put_cancel_drop(self):
|
||||
q = asyncio.Queue(1)
|
||||
|
||||
q.put_nowait(1)
|
||||
|
||||
# putting a second item in the queue has to block (qsize=1)
|
||||
writer = asyncio.create_task(q.put(2))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
value1 = q.get_nowait()
|
||||
self.assertEqual(value1, 1)
|
||||
|
||||
writer.cancel()
|
||||
try:
|
||||
await writer
|
||||
except asyncio.CancelledError:
|
||||
# try again
|
||||
writer = asyncio.create_task(q.put(2))
|
||||
await writer
|
||||
|
||||
value2 = q.get_nowait()
|
||||
self.assertEqual(value2, 2)
|
||||
self.assertEqual(q.qsize(), 0)
|
||||
|
||||
def test_nonblocking_put_exception(self):
|
||||
q = asyncio.Queue(maxsize=1, )
|
||||
q.put_nowait(1)
|
||||
self.assertRaises(asyncio.QueueFull, q.put_nowait, 2)
|
||||
|
||||
async def test_float_maxsize(self):
|
||||
q = asyncio.Queue(maxsize=1.3, )
|
||||
q.put_nowait(1)
|
||||
q.put_nowait(2)
|
||||
self.assertTrue(q.full())
|
||||
self.assertRaises(asyncio.QueueFull, q.put_nowait, 3)
|
||||
|
||||
q = asyncio.Queue(maxsize=1.3, )
|
||||
|
||||
await q.put(1)
|
||||
await q.put(2)
|
||||
self.assertTrue(q.full())
|
||||
|
||||
async def test_put_cancelled(self):
|
||||
q = asyncio.Queue()
|
||||
|
||||
async def queue_put():
|
||||
await q.put(1)
|
||||
return True
|
||||
|
||||
t = asyncio.create_task(queue_put())
|
||||
|
||||
self.assertEqual(1, await q.get())
|
||||
self.assertTrue(t.done())
|
||||
self.assertTrue(t.result())
|
||||
|
||||
async def test_put_cancelled_race(self):
|
||||
q = asyncio.Queue(maxsize=1)
|
||||
|
||||
put_a = asyncio.create_task(q.put('a'))
|
||||
put_b = asyncio.create_task(q.put('b'))
|
||||
put_c = asyncio.create_task(q.put('X'))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
self.assertTrue(put_a.done())
|
||||
self.assertFalse(put_b.done())
|
||||
|
||||
put_c.cancel()
|
||||
await asyncio.sleep(0)
|
||||
self.assertTrue(put_c.done())
|
||||
self.assertEqual(q.get_nowait(), 'a')
|
||||
await asyncio.sleep(0)
|
||||
self.assertEqual(q.get_nowait(), 'b')
|
||||
|
||||
await put_b
|
||||
|
||||
async def test_put_with_waiting_getters(self):
|
||||
q = asyncio.Queue()
|
||||
t = asyncio.create_task(q.get())
|
||||
await asyncio.sleep(0)
|
||||
await q.put('a')
|
||||
self.assertEqual(await t, 'a')
|
||||
|
||||
async def test_why_are_putters_waiting(self):
|
||||
queue = asyncio.Queue(2)
|
||||
|
||||
async def putter(item):
|
||||
await queue.put(item)
|
||||
|
||||
async def getter():
|
||||
await asyncio.sleep(0)
|
||||
num = queue.qsize()
|
||||
for _ in range(num):
|
||||
queue.get_nowait()
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
tg.create_task(getter())
|
||||
tg.create_task(putter(0))
|
||||
tg.create_task(putter(1))
|
||||
tg.create_task(putter(2))
|
||||
tg.create_task(putter(3))
|
||||
|
||||
async def test_cancelled_puts_not_being_held_in_self_putters(self):
|
||||
# Full queue.
|
||||
queue = asyncio.Queue(maxsize=1)
|
||||
queue.put_nowait(1)
|
||||
|
||||
# Task waiting for space to put an item in the queue.
|
||||
put_task = asyncio.create_task(queue.put(1))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Check that the putter is correctly removed from queue._putters when
|
||||
# the task is canceled.
|
||||
self.assertEqual(len(queue._putters), 1)
|
||||
put_task.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await put_task
|
||||
self.assertEqual(len(queue._putters), 0)
|
||||
|
||||
async def test_cancelled_put_silence_value_error_exception(self):
|
||||
# Full Queue.
|
||||
queue = asyncio.Queue(1)
|
||||
queue.put_nowait(1)
|
||||
|
||||
# Task waiting for space to put a item in the queue.
|
||||
put_task = asyncio.create_task(queue.put(1))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# get_nowait() remove the future of put_task from queue._putters.
|
||||
queue.get_nowait()
|
||||
# When canceled, queue.put is going to remove its future from
|
||||
# self._putters but it was removed previously by queue.get_nowait().
|
||||
put_task.cancel()
|
||||
|
||||
# The ValueError exception triggered by queue._putters.remove(putter)
|
||||
# inside queue.put should be silenced.
|
||||
# If the ValueError is silenced we should catch a CancelledError.
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await put_task
|
||||
|
||||
|
||||
class LifoQueueTests(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_order(self):
|
||||
q = asyncio.LifoQueue()
|
||||
for i in [1, 3, 2]:
|
||||
await q.put(i)
|
||||
|
||||
items = [await q.get() for _ in range(3)]
|
||||
self.assertEqual([2, 3, 1], items)
|
||||
|
||||
|
||||
class PriorityQueueTests(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_order(self):
|
||||
q = asyncio.PriorityQueue()
|
||||
for i in [1, 3, 2]:
|
||||
await q.put(i)
|
||||
|
||||
items = [await q.get() for _ in range(3)]
|
||||
self.assertEqual([1, 2, 3], items)
|
||||
|
||||
|
||||
class _QueueJoinTestMixin:
|
||||
|
||||
q_class = None
|
||||
|
||||
def test_task_done_underflow(self):
|
||||
q = self.q_class()
|
||||
self.assertRaises(ValueError, q.task_done)
|
||||
|
||||
async def test_task_done(self):
|
||||
q = self.q_class()
|
||||
for i in range(100):
|
||||
q.put_nowait(i)
|
||||
|
||||
accumulator = 0
|
||||
|
||||
# Two workers get items from the queue and call task_done after each.
|
||||
# Join the queue and assert all items have been processed.
|
||||
running = True
|
||||
|
||||
async def worker():
|
||||
nonlocal accumulator
|
||||
|
||||
while running:
|
||||
item = await q.get()
|
||||
accumulator += item
|
||||
q.task_done()
|
||||
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
tasks = [tg.create_task(worker())
|
||||
for index in range(2)]
|
||||
|
||||
await q.join()
|
||||
self.assertEqual(sum(range(100)), accumulator)
|
||||
|
||||
# close running generators
|
||||
running = False
|
||||
for i in range(len(tasks)):
|
||||
q.put_nowait(0)
|
||||
|
||||
async def test_join_empty_queue(self):
|
||||
q = self.q_class()
|
||||
|
||||
# Test that a queue join()s successfully, and before anything else
|
||||
# (done twice for insurance).
|
||||
|
||||
await q.join()
|
||||
await q.join()
|
||||
|
||||
async def test_format(self):
|
||||
q = self.q_class()
|
||||
self.assertEqual(q._format(), 'maxsize=0')
|
||||
|
||||
q._unfinished_tasks = 2
|
||||
self.assertEqual(q._format(), 'maxsize=0 tasks=2')
|
||||
|
||||
|
||||
class QueueJoinTests(_QueueJoinTestMixin, unittest.IsolatedAsyncioTestCase):
|
||||
q_class = asyncio.Queue
|
||||
|
||||
|
||||
class LifoQueueJoinTests(_QueueJoinTestMixin, unittest.IsolatedAsyncioTestCase):
|
||||
q_class = asyncio.LifoQueue
|
||||
|
||||
|
||||
class PriorityQueueJoinTests(_QueueJoinTestMixin, unittest.IsolatedAsyncioTestCase):
|
||||
q_class = asyncio.PriorityQueue
|
||||
|
||||
|
||||
class _QueueShutdownTestMixin:
|
||||
q_class = None
|
||||
|
||||
def assertRaisesShutdown(self, msg="Didn't appear to shut-down queue"):
|
||||
return self.assertRaises(asyncio.QueueShutDown, msg=msg)
|
||||
|
||||
async def test_format(self):
|
||||
q = self.q_class()
|
||||
q.shutdown()
|
||||
self.assertEqual(q._format(), 'maxsize=0 shutdown')
|
||||
|
||||
async def test_shutdown_empty(self):
|
||||
# Test shutting down an empty queue
|
||||
|
||||
# Setup empty queue, and join() and get() tasks
|
||||
q = self.q_class()
|
||||
loop = asyncio.get_running_loop()
|
||||
get_task = loop.create_task(q.get())
|
||||
await asyncio.sleep(0) # want get task pending before shutdown
|
||||
|
||||
# Perform shut-down
|
||||
q.shutdown(immediate=False) # unfinished tasks: 0 -> 0
|
||||
|
||||
self.assertEqual(q.qsize(), 0)
|
||||
|
||||
# Ensure join() task successfully finishes
|
||||
await q.join()
|
||||
|
||||
# Ensure get() task is finished, and raised ShutDown
|
||||
await asyncio.sleep(0)
|
||||
self.assertTrue(get_task.done())
|
||||
with self.assertRaisesShutdown():
|
||||
await get_task
|
||||
|
||||
# Ensure put() and get() raise ShutDown
|
||||
with self.assertRaisesShutdown():
|
||||
await q.put("data")
|
||||
with self.assertRaisesShutdown():
|
||||
q.put_nowait("data")
|
||||
|
||||
with self.assertRaisesShutdown():
|
||||
await q.get()
|
||||
with self.assertRaisesShutdown():
|
||||
q.get_nowait()
|
||||
|
||||
async def test_shutdown_nonempty(self):
|
||||
# Test shutting down a non-empty queue
|
||||
|
||||
# Setup full queue with 1 item, and join() and put() tasks
|
||||
q = self.q_class(maxsize=1)
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
q.put_nowait("data")
|
||||
join_task = loop.create_task(q.join())
|
||||
put_task = loop.create_task(q.put("data2"))
|
||||
|
||||
# Ensure put() task is not finished
|
||||
await asyncio.sleep(0)
|
||||
self.assertFalse(put_task.done())
|
||||
|
||||
# Perform shut-down
|
||||
q.shutdown(immediate=False) # unfinished tasks: 1 -> 1
|
||||
|
||||
self.assertEqual(q.qsize(), 1)
|
||||
|
||||
# Ensure put() task is finished, and raised ShutDown
|
||||
await asyncio.sleep(0)
|
||||
self.assertTrue(put_task.done())
|
||||
with self.assertRaisesShutdown():
|
||||
await put_task
|
||||
|
||||
# Ensure get() succeeds on enqueued item
|
||||
self.assertEqual(await q.get(), "data")
|
||||
|
||||
# Ensure join() task is not finished
|
||||
await asyncio.sleep(0)
|
||||
self.assertFalse(join_task.done())
|
||||
|
||||
# Ensure put() and get() raise ShutDown
|
||||
with self.assertRaisesShutdown():
|
||||
await q.put("data")
|
||||
with self.assertRaisesShutdown():
|
||||
q.put_nowait("data")
|
||||
|
||||
with self.assertRaisesShutdown():
|
||||
await q.get()
|
||||
with self.assertRaisesShutdown():
|
||||
q.get_nowait()
|
||||
|
||||
# Ensure there is 1 unfinished task, and join() task succeeds
|
||||
q.task_done()
|
||||
|
||||
await asyncio.sleep(0)
|
||||
self.assertTrue(join_task.done())
|
||||
await join_task
|
||||
|
||||
with self.assertRaises(
|
||||
ValueError, msg="Didn't appear to mark all tasks done"
|
||||
):
|
||||
q.task_done()
|
||||
|
||||
async def test_shutdown_immediate(self):
|
||||
# Test immediately shutting down a queue
|
||||
|
||||
# Setup queue with 1 item, and a join() task
|
||||
q = self.q_class()
|
||||
loop = asyncio.get_running_loop()
|
||||
q.put_nowait("data")
|
||||
join_task = loop.create_task(q.join())
|
||||
|
||||
# Perform shut-down
|
||||
q.shutdown(immediate=True) # unfinished tasks: 1 -> 0
|
||||
|
||||
self.assertEqual(q.qsize(), 0)
|
||||
|
||||
# Ensure join() task has successfully finished
|
||||
await asyncio.sleep(0)
|
||||
self.assertTrue(join_task.done())
|
||||
await join_task
|
||||
|
||||
# Ensure put() and get() raise ShutDown
|
||||
with self.assertRaisesShutdown():
|
||||
await q.put("data")
|
||||
with self.assertRaisesShutdown():
|
||||
q.put_nowait("data")
|
||||
|
||||
with self.assertRaisesShutdown():
|
||||
await q.get()
|
||||
with self.assertRaisesShutdown():
|
||||
q.get_nowait()
|
||||
|
||||
# Ensure there are no unfinished tasks
|
||||
with self.assertRaises(
|
||||
ValueError, msg="Didn't appear to mark all tasks done"
|
||||
):
|
||||
q.task_done()
|
||||
|
||||
async def test_shutdown_immediate_with_unfinished(self):
|
||||
# Test immediately shutting down a queue with unfinished tasks
|
||||
|
||||
# Setup queue with 2 items (1 retrieved), and a join() task
|
||||
q = self.q_class()
|
||||
loop = asyncio.get_running_loop()
|
||||
q.put_nowait("data")
|
||||
q.put_nowait("data")
|
||||
join_task = loop.create_task(q.join())
|
||||
self.assertEqual(await q.get(), "data")
|
||||
|
||||
# Perform shut-down
|
||||
q.shutdown(immediate=True) # unfinished tasks: 2 -> 1
|
||||
|
||||
self.assertEqual(q.qsize(), 0)
|
||||
|
||||
# Ensure join() task is not finished
|
||||
await asyncio.sleep(0)
|
||||
self.assertFalse(join_task.done())
|
||||
|
||||
# Ensure put() and get() raise ShutDown
|
||||
with self.assertRaisesShutdown():
|
||||
await q.put("data")
|
||||
with self.assertRaisesShutdown():
|
||||
q.put_nowait("data")
|
||||
|
||||
with self.assertRaisesShutdown():
|
||||
await q.get()
|
||||
with self.assertRaisesShutdown():
|
||||
q.get_nowait()
|
||||
|
||||
# Ensure there is 1 unfinished task
|
||||
q.task_done()
|
||||
with self.assertRaises(
|
||||
ValueError, msg="Didn't appear to mark all tasks done"
|
||||
):
|
||||
q.task_done()
|
||||
|
||||
# Ensure join() task has successfully finished
|
||||
await asyncio.sleep(0)
|
||||
self.assertTrue(join_task.done())
|
||||
await join_task
|
||||
|
||||
|
||||
class QueueShutdownTests(
|
||||
_QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
q_class = asyncio.Queue
|
||||
|
||||
|
||||
class LifoQueueShutdownTests(
|
||||
_QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
q_class = asyncio.LifoQueue
|
||||
|
||||
|
||||
class PriorityQueueShutdownTests(
|
||||
_QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase
|
||||
):
|
||||
q_class = asyncio.PriorityQueue
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
518
Utils/PythonNew32/Lib/test/test_asyncio/test_runners.py
Normal file
518
Utils/PythonNew32/Lib/test/test_asyncio/test_runners.py
Normal file
@@ -0,0 +1,518 @@
|
||||
import _thread
|
||||
import asyncio
|
||||
import contextvars
|
||||
import re
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import unittest
|
||||
from test.test_asyncio import utils as test_utils
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
def interrupt_self():
|
||||
_thread.interrupt_main()
|
||||
|
||||
|
||||
class TestPolicy(asyncio.AbstractEventLoopPolicy):
|
||||
|
||||
def __init__(self, loop_factory):
|
||||
self.loop_factory = loop_factory
|
||||
self.loop = None
|
||||
|
||||
def get_event_loop(self):
|
||||
# shouldn't ever be called by asyncio.run()
|
||||
raise RuntimeError
|
||||
|
||||
def new_event_loop(self):
|
||||
return self.loop_factory()
|
||||
|
||||
def set_event_loop(self, loop):
|
||||
if loop is not None:
|
||||
# we want to check if the loop is closed
|
||||
# in BaseTest.tearDown
|
||||
self.loop = loop
|
||||
|
||||
|
||||
class BaseTest(unittest.TestCase):
|
||||
|
||||
def new_loop(self):
|
||||
loop = asyncio.BaseEventLoop()
|
||||
loop._process_events = mock.Mock()
|
||||
# Mock waking event loop from select
|
||||
loop._write_to_self = mock.Mock()
|
||||
loop._write_to_self.return_value = None
|
||||
loop._selector = mock.Mock()
|
||||
loop._selector.select.return_value = ()
|
||||
loop.shutdown_ag_run = False
|
||||
|
||||
async def shutdown_asyncgens():
|
||||
loop.shutdown_ag_run = True
|
||||
loop.shutdown_asyncgens = shutdown_asyncgens
|
||||
|
||||
return loop
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
policy = TestPolicy(self.new_loop)
|
||||
asyncio.set_event_loop_policy(policy)
|
||||
|
||||
def tearDown(self):
|
||||
policy = asyncio.get_event_loop_policy()
|
||||
if policy.loop is not None:
|
||||
self.assertTrue(policy.loop.is_closed())
|
||||
self.assertTrue(policy.loop.shutdown_ag_run)
|
||||
|
||||
asyncio.set_event_loop_policy(None)
|
||||
super().tearDown()
|
||||
|
||||
|
||||
class RunTests(BaseTest):
|
||||
|
||||
def test_asyncio_run_return(self):
|
||||
async def main():
|
||||
await asyncio.sleep(0)
|
||||
return 42
|
||||
|
||||
self.assertEqual(asyncio.run(main()), 42)
|
||||
|
||||
def test_asyncio_run_raises(self):
|
||||
async def main():
|
||||
await asyncio.sleep(0)
|
||||
raise ValueError('spam')
|
||||
|
||||
with self.assertRaisesRegex(ValueError, 'spam'):
|
||||
asyncio.run(main())
|
||||
|
||||
def test_asyncio_run_only_coro(self):
|
||||
for o in {1, lambda: None}:
|
||||
with self.subTest(obj=o), \
|
||||
self.assertRaisesRegex(ValueError,
|
||||
'a coroutine was expected'):
|
||||
asyncio.run(o)
|
||||
|
||||
def test_asyncio_run_debug(self):
|
||||
async def main(expected):
|
||||
loop = asyncio.get_event_loop()
|
||||
self.assertIs(loop.get_debug(), expected)
|
||||
|
||||
asyncio.run(main(False), debug=False)
|
||||
asyncio.run(main(True), debug=True)
|
||||
with mock.patch('asyncio.coroutines._is_debug_mode', lambda: True):
|
||||
asyncio.run(main(True))
|
||||
asyncio.run(main(False), debug=False)
|
||||
with mock.patch('asyncio.coroutines._is_debug_mode', lambda: False):
|
||||
asyncio.run(main(True), debug=True)
|
||||
asyncio.run(main(False))
|
||||
|
||||
def test_asyncio_run_from_running_loop(self):
|
||||
async def main():
|
||||
coro = main()
|
||||
try:
|
||||
asyncio.run(coro)
|
||||
finally:
|
||||
coro.close() # Suppress ResourceWarning
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
'cannot be called from a running'):
|
||||
asyncio.run(main())
|
||||
|
||||
def test_asyncio_run_cancels_hanging_tasks(self):
|
||||
lo_task = None
|
||||
|
||||
async def leftover():
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def main():
|
||||
nonlocal lo_task
|
||||
lo_task = asyncio.create_task(leftover())
|
||||
return 123
|
||||
|
||||
self.assertEqual(asyncio.run(main()), 123)
|
||||
self.assertTrue(lo_task.done())
|
||||
|
||||
def test_asyncio_run_reports_hanging_tasks_errors(self):
|
||||
lo_task = None
|
||||
call_exc_handler_mock = mock.Mock()
|
||||
|
||||
async def leftover():
|
||||
try:
|
||||
await asyncio.sleep(0.1)
|
||||
except asyncio.CancelledError:
|
||||
1 / 0
|
||||
|
||||
async def main():
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.call_exception_handler = call_exc_handler_mock
|
||||
|
||||
nonlocal lo_task
|
||||
lo_task = asyncio.create_task(leftover())
|
||||
return 123
|
||||
|
||||
self.assertEqual(asyncio.run(main()), 123)
|
||||
self.assertTrue(lo_task.done())
|
||||
|
||||
call_exc_handler_mock.assert_called_with({
|
||||
'message': test_utils.MockPattern(r'asyncio.run.*shutdown'),
|
||||
'task': lo_task,
|
||||
'exception': test_utils.MockInstanceOf(ZeroDivisionError)
|
||||
})
|
||||
|
||||
def test_asyncio_run_closes_gens_after_hanging_tasks_errors(self):
|
||||
spinner = None
|
||||
lazyboy = None
|
||||
|
||||
class FancyExit(Exception):
|
||||
pass
|
||||
|
||||
async def fidget():
|
||||
while True:
|
||||
yield 1
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def spin():
|
||||
nonlocal spinner
|
||||
spinner = fidget()
|
||||
try:
|
||||
async for the_meaning_of_life in spinner: # NoQA
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
1 / 0
|
||||
|
||||
async def main():
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.call_exception_handler = mock.Mock()
|
||||
|
||||
nonlocal lazyboy
|
||||
lazyboy = asyncio.create_task(spin())
|
||||
raise FancyExit
|
||||
|
||||
with self.assertRaises(FancyExit):
|
||||
asyncio.run(main())
|
||||
|
||||
self.assertTrue(lazyboy.done())
|
||||
|
||||
self.assertIsNone(spinner.ag_frame)
|
||||
self.assertFalse(spinner.ag_running)
|
||||
|
||||
def test_asyncio_run_set_event_loop(self):
|
||||
#See https://github.com/python/cpython/issues/93896
|
||||
|
||||
async def main():
|
||||
await asyncio.sleep(0)
|
||||
return 42
|
||||
|
||||
policy = asyncio.get_event_loop_policy()
|
||||
policy.set_event_loop = mock.Mock()
|
||||
asyncio.run(main())
|
||||
self.assertTrue(policy.set_event_loop.called)
|
||||
|
||||
def test_asyncio_run_without_uncancel(self):
|
||||
# See https://github.com/python/cpython/issues/95097
|
||||
class Task:
|
||||
def __init__(self, loop, coro, **kwargs):
|
||||
self._task = asyncio.Task(coro, loop=loop, **kwargs)
|
||||
|
||||
def cancel(self, *args, **kwargs):
|
||||
return self._task.cancel(*args, **kwargs)
|
||||
|
||||
def add_done_callback(self, *args, **kwargs):
|
||||
return self._task.add_done_callback(*args, **kwargs)
|
||||
|
||||
def remove_done_callback(self, *args, **kwargs):
|
||||
return self._task.remove_done_callback(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def _asyncio_future_blocking(self):
|
||||
return self._task._asyncio_future_blocking
|
||||
|
||||
def result(self, *args, **kwargs):
|
||||
return self._task.result(*args, **kwargs)
|
||||
|
||||
def done(self, *args, **kwargs):
|
||||
return self._task.done(*args, **kwargs)
|
||||
|
||||
def cancelled(self, *args, **kwargs):
|
||||
return self._task.cancelled(*args, **kwargs)
|
||||
|
||||
def exception(self, *args, **kwargs):
|
||||
return self._task.exception(*args, **kwargs)
|
||||
|
||||
def get_loop(self, *args, **kwargs):
|
||||
return self._task.get_loop(*args, **kwargs)
|
||||
|
||||
def set_name(self, *args, **kwargs):
|
||||
return self._task.set_name(*args, **kwargs)
|
||||
|
||||
async def main():
|
||||
interrupt_self()
|
||||
await asyncio.Event().wait()
|
||||
|
||||
def new_event_loop():
|
||||
loop = self.new_loop()
|
||||
loop.set_task_factory(Task)
|
||||
return loop
|
||||
|
||||
asyncio.set_event_loop_policy(TestPolicy(new_event_loop))
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
asyncio.run(main())
|
||||
|
||||
def test_asyncio_run_loop_factory(self):
|
||||
factory = mock.Mock()
|
||||
loop = factory.return_value = self.new_loop()
|
||||
|
||||
async def main():
|
||||
self.assertEqual(asyncio.get_running_loop(), loop)
|
||||
|
||||
asyncio.run(main(), loop_factory=factory)
|
||||
factory.assert_called_once_with()
|
||||
|
||||
def test_loop_factory_default_event_loop(self):
|
||||
async def main():
|
||||
if sys.platform == "win32":
|
||||
self.assertIsInstance(asyncio.get_running_loop(), asyncio.ProactorEventLoop)
|
||||
else:
|
||||
self.assertIsInstance(asyncio.get_running_loop(), asyncio.SelectorEventLoop)
|
||||
|
||||
|
||||
asyncio.run(main(), loop_factory=asyncio.EventLoop)
|
||||
|
||||
|
||||
class RunnerTests(BaseTest):
|
||||
|
||||
def test_non_debug(self):
|
||||
with asyncio.Runner(debug=False) as runner:
|
||||
self.assertFalse(runner.get_loop().get_debug())
|
||||
|
||||
def test_debug(self):
|
||||
with asyncio.Runner(debug=True) as runner:
|
||||
self.assertTrue(runner.get_loop().get_debug())
|
||||
|
||||
def test_custom_factory(self):
|
||||
loop = mock.Mock()
|
||||
with asyncio.Runner(loop_factory=lambda: loop) as runner:
|
||||
self.assertIs(runner.get_loop(), loop)
|
||||
|
||||
def test_run(self):
|
||||
async def f():
|
||||
await asyncio.sleep(0)
|
||||
return 'done'
|
||||
|
||||
with asyncio.Runner() as runner:
|
||||
self.assertEqual('done', runner.run(f()))
|
||||
loop = runner.get_loop()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Runner is closed"
|
||||
):
|
||||
runner.get_loop()
|
||||
|
||||
self.assertTrue(loop.is_closed())
|
||||
|
||||
def test_run_non_coro(self):
|
||||
with asyncio.Runner() as runner:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"a coroutine was expected"
|
||||
):
|
||||
runner.run(123)
|
||||
|
||||
def test_run_future(self):
|
||||
with asyncio.Runner() as runner:
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"a coroutine was expected"
|
||||
):
|
||||
fut = runner.get_loop().create_future()
|
||||
runner.run(fut)
|
||||
|
||||
def test_explicit_close(self):
|
||||
runner = asyncio.Runner()
|
||||
loop = runner.get_loop()
|
||||
runner.close()
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Runner is closed"
|
||||
):
|
||||
runner.get_loop()
|
||||
|
||||
self.assertTrue(loop.is_closed())
|
||||
|
||||
def test_double_close(self):
|
||||
runner = asyncio.Runner()
|
||||
loop = runner.get_loop()
|
||||
|
||||
runner.close()
|
||||
self.assertTrue(loop.is_closed())
|
||||
|
||||
# the second call is no-op
|
||||
runner.close()
|
||||
self.assertTrue(loop.is_closed())
|
||||
|
||||
def test_second_with_block_raises(self):
|
||||
ret = []
|
||||
|
||||
async def f(arg):
|
||||
ret.append(arg)
|
||||
|
||||
runner = asyncio.Runner()
|
||||
with runner:
|
||||
runner.run(f(1))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Runner is closed"
|
||||
):
|
||||
with runner:
|
||||
runner.run(f(2))
|
||||
|
||||
self.assertEqual([1], ret)
|
||||
|
||||
def test_run_keeps_context(self):
|
||||
cvar = contextvars.ContextVar("cvar", default=-1)
|
||||
|
||||
async def f(val):
|
||||
old = cvar.get()
|
||||
await asyncio.sleep(0)
|
||||
cvar.set(val)
|
||||
return old
|
||||
|
||||
async def get_context():
|
||||
return contextvars.copy_context()
|
||||
|
||||
with asyncio.Runner() as runner:
|
||||
self.assertEqual(-1, runner.run(f(1)))
|
||||
self.assertEqual(1, runner.run(f(2)))
|
||||
|
||||
self.assertEqual(2, runner.run(get_context()).get(cvar))
|
||||
|
||||
def test_recursive_run(self):
|
||||
async def g():
|
||||
pass
|
||||
|
||||
async def f():
|
||||
runner.run(g())
|
||||
|
||||
with asyncio.Runner() as runner:
|
||||
with self.assertWarnsRegex(
|
||||
RuntimeWarning,
|
||||
"coroutine .+ was never awaited",
|
||||
):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
re.escape(
|
||||
"Runner.run() cannot be called from a running event loop"
|
||||
),
|
||||
):
|
||||
runner.run(f())
|
||||
|
||||
def test_interrupt_call_soon(self):
|
||||
# The only case when task is not suspended by waiting a future
|
||||
# or another task
|
||||
assert threading.current_thread() is threading.main_thread()
|
||||
|
||||
async def coro():
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
while True:
|
||||
await asyncio.sleep(0)
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
with asyncio.Runner() as runner:
|
||||
runner.get_loop().call_later(0.1, interrupt_self)
|
||||
with self.assertRaises(KeyboardInterrupt):
|
||||
runner.run(coro())
|
||||
|
||||
def test_interrupt_wait(self):
|
||||
# interrupting when waiting a future cancels both future and main task
|
||||
assert threading.current_thread() is threading.main_thread()
|
||||
|
||||
async def coro(fut):
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await fut
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
with asyncio.Runner() as runner:
|
||||
fut = runner.get_loop().create_future()
|
||||
runner.get_loop().call_later(0.1, interrupt_self)
|
||||
|
||||
with self.assertRaises(KeyboardInterrupt):
|
||||
runner.run(coro(fut))
|
||||
|
||||
self.assertTrue(fut.cancelled())
|
||||
|
||||
def test_interrupt_cancelled_task(self):
|
||||
# interrupting cancelled main task doesn't raise KeyboardInterrupt
|
||||
assert threading.current_thread() is threading.main_thread()
|
||||
|
||||
async def subtask(task):
|
||||
await asyncio.sleep(0)
|
||||
task.cancel()
|
||||
interrupt_self()
|
||||
|
||||
async def coro():
|
||||
asyncio.create_task(subtask(asyncio.current_task()))
|
||||
await asyncio.sleep(10)
|
||||
|
||||
with asyncio.Runner() as runner:
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
runner.run(coro())
|
||||
|
||||
def test_signal_install_not_supported_ok(self):
|
||||
# signal.signal() can throw if the "main thread" doesn't have signals enabled
|
||||
assert threading.current_thread() is threading.main_thread()
|
||||
|
||||
async def coro():
|
||||
pass
|
||||
|
||||
with asyncio.Runner() as runner:
|
||||
with patch.object(
|
||||
signal,
|
||||
"signal",
|
||||
side_effect=ValueError(
|
||||
"signal only works in main thread of the main interpreter"
|
||||
)
|
||||
):
|
||||
runner.run(coro())
|
||||
|
||||
def test_set_event_loop_called_once(self):
|
||||
# See https://github.com/python/cpython/issues/95736
|
||||
async def coro():
|
||||
pass
|
||||
|
||||
policy = asyncio.get_event_loop_policy()
|
||||
policy.set_event_loop = mock.Mock()
|
||||
runner = asyncio.Runner()
|
||||
runner.run(coro())
|
||||
runner.run(coro())
|
||||
|
||||
self.assertEqual(1, policy.set_event_loop.call_count)
|
||||
runner.close()
|
||||
|
||||
def test_no_repr_is_call_on_the_task_result(self):
|
||||
# See https://github.com/python/cpython/issues/112559.
|
||||
class MyResult:
|
||||
def __init__(self):
|
||||
self.repr_count = 0
|
||||
def __repr__(self):
|
||||
self.repr_count += 1
|
||||
return super().__repr__()
|
||||
|
||||
async def coro():
|
||||
return MyResult()
|
||||
|
||||
|
||||
with asyncio.Runner() as runner:
|
||||
result = runner.run(coro())
|
||||
|
||||
self.assertEqual(0, result.repr_count)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
1566
Utils/PythonNew32/Lib/test/test_asyncio/test_selector_events.py
Normal file
1566
Utils/PythonNew32/Lib/test/test_asyncio/test_selector_events.py
Normal file
File diff suppressed because it is too large
Load Diff
585
Utils/PythonNew32/Lib/test/test_asyncio/test_sendfile.py
Normal file
585
Utils/PythonNew32/Lib/test/test_asyncio/test_sendfile.py
Normal file
@@ -0,0 +1,585 @@
|
||||
"""Tests for sendfile functionality."""
|
||||
|
||||
import asyncio
|
||||
import errno
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
from asyncio import base_events
|
||||
from asyncio import constants
|
||||
from unittest import mock
|
||||
from test import support
|
||||
from test.support import os_helper
|
||||
from test.support import socket_helper
|
||||
from test.test_asyncio import utils as test_utils
|
||||
|
||||
try:
|
||||
import ssl
|
||||
except ImportError:
|
||||
ssl = None
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
class MySendfileProto(asyncio.Protocol):
|
||||
|
||||
def __init__(self, loop=None, close_after=0):
|
||||
self.transport = None
|
||||
self.state = 'INITIAL'
|
||||
self.nbytes = 0
|
||||
if loop is not None:
|
||||
self.connected = loop.create_future()
|
||||
self.done = loop.create_future()
|
||||
self.data = bytearray()
|
||||
self.close_after = close_after
|
||||
|
||||
def _assert_state(self, *expected):
|
||||
if self.state not in expected:
|
||||
raise AssertionError(f'state: {self.state!r}, expected: {expected!r}')
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
self._assert_state('INITIAL')
|
||||
self.state = 'CONNECTED'
|
||||
if self.connected:
|
||||
self.connected.set_result(None)
|
||||
|
||||
def eof_received(self):
|
||||
self._assert_state('CONNECTED')
|
||||
self.state = 'EOF'
|
||||
|
||||
def connection_lost(self, exc):
|
||||
self._assert_state('CONNECTED', 'EOF')
|
||||
self.state = 'CLOSED'
|
||||
if self.done:
|
||||
self.done.set_result(None)
|
||||
|
||||
def data_received(self, data):
|
||||
self._assert_state('CONNECTED')
|
||||
self.nbytes += len(data)
|
||||
self.data.extend(data)
|
||||
super().data_received(data)
|
||||
if self.close_after and self.nbytes >= self.close_after:
|
||||
self.transport.close()
|
||||
|
||||
|
||||
class MyProto(asyncio.Protocol):
|
||||
|
||||
def __init__(self, loop):
|
||||
self.started = False
|
||||
self.closed = False
|
||||
self.data = bytearray()
|
||||
self.fut = loop.create_future()
|
||||
self.transport = None
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.started = True
|
||||
self.transport = transport
|
||||
|
||||
def data_received(self, data):
|
||||
self.data.extend(data)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
self.closed = True
|
||||
self.fut.set_result(None)
|
||||
|
||||
async def wait_closed(self):
|
||||
await self.fut
|
||||
|
||||
|
||||
class SendfileBase:
|
||||
|
||||
# Linux >= 6.10 seems buffering up to 17 pages of data.
|
||||
# So DATA should be large enough to make this test reliable even with a
|
||||
# 64 KiB page configuration.
|
||||
DATA = b"x" * (1024 * 17 * 64 + 1)
|
||||
# Reduce socket buffer size to test on relative small data sets.
|
||||
BUF_SIZE = 4 * 1024 # 4 KiB
|
||||
|
||||
def create_event_loop(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
with open(os_helper.TESTFN, 'wb') as fp:
|
||||
fp.write(cls.DATA)
|
||||
super().setUpClass()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
os_helper.unlink(os_helper.TESTFN)
|
||||
super().tearDownClass()
|
||||
|
||||
def setUp(self):
|
||||
self.file = open(os_helper.TESTFN, 'rb')
|
||||
self.addCleanup(self.file.close)
|
||||
self.loop = self.create_event_loop()
|
||||
self.set_event_loop(self.loop)
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
# just in case if we have transport close callbacks
|
||||
if not self.loop.is_closed():
|
||||
test_utils.run_briefly(self.loop)
|
||||
|
||||
self.doCleanups()
|
||||
support.gc_collect()
|
||||
super().tearDown()
|
||||
|
||||
def run_loop(self, coro):
|
||||
return self.loop.run_until_complete(coro)
|
||||
|
||||
|
||||
class SockSendfileMixin(SendfileBase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE
|
||||
constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16
|
||||
super().setUpClass()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize
|
||||
super().tearDownClass()
|
||||
|
||||
def make_socket(self, cleanup=True):
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.setblocking(False)
|
||||
if cleanup:
|
||||
self.addCleanup(sock.close)
|
||||
return sock
|
||||
|
||||
def reduce_receive_buffer_size(self, sock):
|
||||
# Reduce receive socket buffer size to test on relative
|
||||
# small data sets.
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE)
|
||||
|
||||
def reduce_send_buffer_size(self, sock, transport=None):
|
||||
# Reduce send socket buffer size to test on relative small data sets.
|
||||
|
||||
# On macOS, SO_SNDBUF is reset by connect(). So this method
|
||||
# should be called after the socket is connected.
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE)
|
||||
|
||||
if transport is not None:
|
||||
transport.set_write_buffer_limits(high=self.BUF_SIZE)
|
||||
|
||||
def prepare_socksendfile(self):
|
||||
proto = MyProto(self.loop)
|
||||
port = socket_helper.find_unused_port()
|
||||
srv_sock = self.make_socket(cleanup=False)
|
||||
srv_sock.bind((socket_helper.HOST, port))
|
||||
server = self.run_loop(self.loop.create_server(
|
||||
lambda: proto, sock=srv_sock))
|
||||
self.reduce_receive_buffer_size(srv_sock)
|
||||
|
||||
sock = self.make_socket()
|
||||
self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port)))
|
||||
self.reduce_send_buffer_size(sock)
|
||||
|
||||
def cleanup():
|
||||
if proto.transport is not None:
|
||||
# can be None if the task was cancelled before
|
||||
# connection_made callback
|
||||
proto.transport.close()
|
||||
self.run_loop(proto.wait_closed())
|
||||
|
||||
server.close()
|
||||
self.run_loop(server.wait_closed())
|
||||
|
||||
self.addCleanup(cleanup)
|
||||
|
||||
return sock, proto
|
||||
|
||||
def test_sock_sendfile_success(self):
|
||||
sock, proto = self.prepare_socksendfile()
|
||||
ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
|
||||
sock.close()
|
||||
self.run_loop(proto.wait_closed())
|
||||
|
||||
self.assertEqual(ret, len(self.DATA))
|
||||
self.assertEqual(proto.data, self.DATA)
|
||||
self.assertEqual(self.file.tell(), len(self.DATA))
|
||||
|
||||
def test_sock_sendfile_with_offset_and_count(self):
|
||||
sock, proto = self.prepare_socksendfile()
|
||||
ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
|
||||
1000, 2000))
|
||||
sock.close()
|
||||
self.run_loop(proto.wait_closed())
|
||||
|
||||
self.assertEqual(proto.data, self.DATA[1000:3000])
|
||||
self.assertEqual(self.file.tell(), 3000)
|
||||
self.assertEqual(ret, 2000)
|
||||
|
||||
def test_sock_sendfile_zero_size(self):
|
||||
sock, proto = self.prepare_socksendfile()
|
||||
with tempfile.TemporaryFile() as f:
|
||||
ret = self.run_loop(self.loop.sock_sendfile(sock, f,
|
||||
0, None))
|
||||
sock.close()
|
||||
self.run_loop(proto.wait_closed())
|
||||
|
||||
self.assertEqual(ret, 0)
|
||||
self.assertEqual(self.file.tell(), 0)
|
||||
|
||||
def test_sock_sendfile_mix_with_regular_send(self):
|
||||
buf = b"mix_regular_send" * (4 * 1024) # 64 KiB
|
||||
sock, proto = self.prepare_socksendfile()
|
||||
self.run_loop(self.loop.sock_sendall(sock, buf))
|
||||
ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
|
||||
self.run_loop(self.loop.sock_sendall(sock, buf))
|
||||
sock.close()
|
||||
self.run_loop(proto.wait_closed())
|
||||
|
||||
self.assertEqual(ret, len(self.DATA))
|
||||
expected = buf + self.DATA + buf
|
||||
self.assertEqual(proto.data, expected)
|
||||
self.assertEqual(self.file.tell(), len(self.DATA))
|
||||
|
||||
|
||||
class SendfileMixin(SendfileBase):
|
||||
|
||||
# Note: sendfile via SSL transport is equal to sendfile fallback
|
||||
|
||||
def prepare_sendfile(self, *, is_ssl=False, close_after=0):
|
||||
port = socket_helper.find_unused_port()
|
||||
srv_proto = MySendfileProto(loop=self.loop,
|
||||
close_after=close_after)
|
||||
if is_ssl:
|
||||
if not ssl:
|
||||
self.skipTest("No ssl module")
|
||||
srv_ctx = test_utils.simple_server_sslcontext()
|
||||
cli_ctx = test_utils.simple_client_sslcontext()
|
||||
else:
|
||||
srv_ctx = None
|
||||
cli_ctx = None
|
||||
srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
srv_sock.bind((socket_helper.HOST, port))
|
||||
server = self.run_loop(self.loop.create_server(
|
||||
lambda: srv_proto, sock=srv_sock, ssl=srv_ctx))
|
||||
self.reduce_receive_buffer_size(srv_sock)
|
||||
|
||||
if is_ssl:
|
||||
server_hostname = socket_helper.HOST
|
||||
else:
|
||||
server_hostname = None
|
||||
cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
cli_sock.connect((socket_helper.HOST, port))
|
||||
|
||||
cli_proto = MySendfileProto(loop=self.loop)
|
||||
tr, pr = self.run_loop(self.loop.create_connection(
|
||||
lambda: cli_proto, sock=cli_sock,
|
||||
ssl=cli_ctx, server_hostname=server_hostname))
|
||||
self.reduce_send_buffer_size(cli_sock, transport=tr)
|
||||
|
||||
def cleanup():
|
||||
srv_proto.transport.close()
|
||||
cli_proto.transport.close()
|
||||
self.run_loop(srv_proto.done)
|
||||
self.run_loop(cli_proto.done)
|
||||
|
||||
server.close()
|
||||
self.run_loop(server.wait_closed())
|
||||
|
||||
self.addCleanup(cleanup)
|
||||
return srv_proto, cli_proto
|
||||
|
||||
@unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported")
|
||||
def test_sendfile_not_supported(self):
|
||||
tr, pr = self.run_loop(
|
||||
self.loop.create_datagram_endpoint(
|
||||
asyncio.DatagramProtocol,
|
||||
family=socket.AF_INET))
|
||||
try:
|
||||
with self.assertRaisesRegex(RuntimeError, "not supported"):
|
||||
self.run_loop(
|
||||
self.loop.sendfile(tr, self.file))
|
||||
self.assertEqual(0, self.file.tell())
|
||||
finally:
|
||||
# don't use self.addCleanup because it produces resource warning
|
||||
tr.close()
|
||||
|
||||
def test_sendfile(self):
|
||||
srv_proto, cli_proto = self.prepare_sendfile()
|
||||
ret = self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file))
|
||||
cli_proto.transport.close()
|
||||
self.run_loop(srv_proto.done)
|
||||
self.assertEqual(ret, len(self.DATA))
|
||||
self.assertEqual(srv_proto.nbytes, len(self.DATA))
|
||||
self.assertEqual(srv_proto.data, self.DATA)
|
||||
self.assertEqual(self.file.tell(), len(self.DATA))
|
||||
|
||||
def test_sendfile_force_fallback(self):
|
||||
srv_proto, cli_proto = self.prepare_sendfile()
|
||||
|
||||
def sendfile_native(transp, file, offset, count):
|
||||
# to raise SendfileNotAvailableError
|
||||
return base_events.BaseEventLoop._sendfile_native(
|
||||
self.loop, transp, file, offset, count)
|
||||
|
||||
self.loop._sendfile_native = sendfile_native
|
||||
|
||||
ret = self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file))
|
||||
cli_proto.transport.close()
|
||||
self.run_loop(srv_proto.done)
|
||||
self.assertEqual(ret, len(self.DATA))
|
||||
self.assertEqual(srv_proto.nbytes, len(self.DATA))
|
||||
self.assertEqual(srv_proto.data, self.DATA)
|
||||
self.assertEqual(self.file.tell(), len(self.DATA))
|
||||
|
||||
def test_sendfile_force_unsupported_native(self):
|
||||
if sys.platform == 'win32':
|
||||
if isinstance(self.loop, asyncio.ProactorEventLoop):
|
||||
self.skipTest("Fails on proactor event loop")
|
||||
srv_proto, cli_proto = self.prepare_sendfile()
|
||||
|
||||
def sendfile_native(transp, file, offset, count):
|
||||
# to raise SendfileNotAvailableError
|
||||
return base_events.BaseEventLoop._sendfile_native(
|
||||
self.loop, transp, file, offset, count)
|
||||
|
||||
self.loop._sendfile_native = sendfile_native
|
||||
|
||||
with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
|
||||
"not supported"):
|
||||
self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file,
|
||||
fallback=False))
|
||||
|
||||
cli_proto.transport.close()
|
||||
self.run_loop(srv_proto.done)
|
||||
self.assertEqual(srv_proto.nbytes, 0)
|
||||
self.assertEqual(self.file.tell(), 0)
|
||||
|
||||
def test_sendfile_ssl(self):
|
||||
srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
|
||||
ret = self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file))
|
||||
cli_proto.transport.close()
|
||||
self.run_loop(srv_proto.done)
|
||||
self.assertEqual(ret, len(self.DATA))
|
||||
self.assertEqual(srv_proto.nbytes, len(self.DATA))
|
||||
self.assertEqual(srv_proto.data, self.DATA)
|
||||
self.assertEqual(self.file.tell(), len(self.DATA))
|
||||
|
||||
def test_sendfile_for_closing_transp(self):
|
||||
srv_proto, cli_proto = self.prepare_sendfile()
|
||||
cli_proto.transport.close()
|
||||
with self.assertRaisesRegex(RuntimeError, "is closing"):
|
||||
self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
|
||||
self.run_loop(srv_proto.done)
|
||||
self.assertEqual(srv_proto.nbytes, 0)
|
||||
self.assertEqual(self.file.tell(), 0)
|
||||
|
||||
def test_sendfile_pre_and_post_data(self):
|
||||
srv_proto, cli_proto = self.prepare_sendfile()
|
||||
PREFIX = b'PREFIX__' * 1024 # 8 KiB
|
||||
SUFFIX = b'--SUFFIX' * 1024 # 8 KiB
|
||||
cli_proto.transport.write(PREFIX)
|
||||
ret = self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file))
|
||||
cli_proto.transport.write(SUFFIX)
|
||||
cli_proto.transport.close()
|
||||
self.run_loop(srv_proto.done)
|
||||
self.assertEqual(ret, len(self.DATA))
|
||||
self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
|
||||
self.assertEqual(self.file.tell(), len(self.DATA))
|
||||
|
||||
def test_sendfile_ssl_pre_and_post_data(self):
|
||||
srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
|
||||
PREFIX = b'zxcvbnm' * 1024
|
||||
SUFFIX = b'0987654321' * 1024
|
||||
cli_proto.transport.write(PREFIX)
|
||||
ret = self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file))
|
||||
cli_proto.transport.write(SUFFIX)
|
||||
cli_proto.transport.close()
|
||||
self.run_loop(srv_proto.done)
|
||||
self.assertEqual(ret, len(self.DATA))
|
||||
self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
|
||||
self.assertEqual(self.file.tell(), len(self.DATA))
|
||||
|
||||
def test_sendfile_partial(self):
|
||||
srv_proto, cli_proto = self.prepare_sendfile()
|
||||
ret = self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
|
||||
cli_proto.transport.close()
|
||||
self.run_loop(srv_proto.done)
|
||||
self.assertEqual(ret, 100)
|
||||
self.assertEqual(srv_proto.nbytes, 100)
|
||||
self.assertEqual(srv_proto.data, self.DATA[1000:1100])
|
||||
self.assertEqual(self.file.tell(), 1100)
|
||||
|
||||
def test_sendfile_ssl_partial(self):
|
||||
srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
|
||||
ret = self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
|
||||
cli_proto.transport.close()
|
||||
self.run_loop(srv_proto.done)
|
||||
self.assertEqual(ret, 100)
|
||||
self.assertEqual(srv_proto.nbytes, 100)
|
||||
self.assertEqual(srv_proto.data, self.DATA[1000:1100])
|
||||
self.assertEqual(self.file.tell(), 1100)
|
||||
|
||||
def test_sendfile_close_peer_after_receiving(self):
|
||||
srv_proto, cli_proto = self.prepare_sendfile(
|
||||
close_after=len(self.DATA))
|
||||
ret = self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file))
|
||||
cli_proto.transport.close()
|
||||
self.run_loop(srv_proto.done)
|
||||
self.assertEqual(ret, len(self.DATA))
|
||||
self.assertEqual(srv_proto.nbytes, len(self.DATA))
|
||||
self.assertEqual(srv_proto.data, self.DATA)
|
||||
self.assertEqual(self.file.tell(), len(self.DATA))
|
||||
|
||||
def test_sendfile_ssl_close_peer_after_receiving(self):
|
||||
srv_proto, cli_proto = self.prepare_sendfile(
|
||||
is_ssl=True, close_after=len(self.DATA))
|
||||
ret = self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file))
|
||||
self.run_loop(srv_proto.done)
|
||||
self.assertEqual(ret, len(self.DATA))
|
||||
self.assertEqual(srv_proto.nbytes, len(self.DATA))
|
||||
self.assertEqual(srv_proto.data, self.DATA)
|
||||
self.assertEqual(self.file.tell(), len(self.DATA))
|
||||
|
||||
# On Solaris, lowering SO_RCVBUF on a TCP connection after it has been
|
||||
# established has no effect. Due to its age, this bug affects both Oracle
|
||||
# Solaris as well as all other OpenSolaris forks (unless they fixed it
|
||||
# themselves).
|
||||
@unittest.skipIf(sys.platform.startswith('sunos'),
|
||||
"Doesn't work on Solaris")
|
||||
def test_sendfile_close_peer_in_the_middle_of_receiving(self):
|
||||
srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
|
||||
with self.assertRaises(ConnectionError):
|
||||
self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file))
|
||||
self.run_loop(srv_proto.done)
|
||||
|
||||
self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
|
||||
srv_proto.nbytes)
|
||||
if not (sys.platform == 'win32'
|
||||
and isinstance(self.loop, asyncio.ProactorEventLoop)):
|
||||
# On Windows, Proactor uses transmitFile, which does not update tell()
|
||||
self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
|
||||
self.file.tell())
|
||||
self.assertTrue(cli_proto.transport.is_closing())
|
||||
|
||||
def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self):
|
||||
|
||||
def sendfile_native(transp, file, offset, count):
|
||||
# to raise SendfileNotAvailableError
|
||||
return base_events.BaseEventLoop._sendfile_native(
|
||||
self.loop, transp, file, offset, count)
|
||||
|
||||
self.loop._sendfile_native = sendfile_native
|
||||
|
||||
srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
|
||||
with self.assertRaises(ConnectionError):
|
||||
try:
|
||||
self.run_loop(
|
||||
self.loop.sendfile(cli_proto.transport, self.file))
|
||||
except OSError as e:
|
||||
# macOS may raise OSError of EPROTOTYPE when writing to a
|
||||
# socket that is in the process of closing down.
|
||||
if e.errno == errno.EPROTOTYPE and sys.platform == "darwin":
|
||||
raise ConnectionError
|
||||
else:
|
||||
raise
|
||||
|
||||
self.run_loop(srv_proto.done)
|
||||
|
||||
self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
|
||||
srv_proto.nbytes)
|
||||
self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
|
||||
self.file.tell())
|
||||
|
||||
@unittest.skipIf(not hasattr(os, 'sendfile'),
|
||||
"Don't have native sendfile support")
|
||||
def test_sendfile_prevents_bare_write(self):
|
||||
srv_proto, cli_proto = self.prepare_sendfile()
|
||||
fut = self.loop.create_future()
|
||||
|
||||
async def coro():
|
||||
fut.set_result(None)
|
||||
return await self.loop.sendfile(cli_proto.transport, self.file)
|
||||
|
||||
t = self.loop.create_task(coro())
|
||||
self.run_loop(fut)
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
"sendfile is in progress"):
|
||||
cli_proto.transport.write(b'data')
|
||||
ret = self.run_loop(t)
|
||||
self.assertEqual(ret, len(self.DATA))
|
||||
|
||||
def test_sendfile_no_fallback_for_fallback_transport(self):
|
||||
transport = mock.Mock()
|
||||
transport.is_closing.side_effect = lambda: False
|
||||
transport._sendfile_compatible = constants._SendfileMode.FALLBACK
|
||||
with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'):
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sendfile(transport, None, fallback=False))
|
||||
|
||||
|
||||
class SendfileTestsBase(SendfileMixin, SockSendfileMixin):
|
||||
pass
|
||||
|
||||
|
||||
if sys.platform == 'win32':
|
||||
|
||||
class SelectEventLoopTests(SendfileTestsBase,
|
||||
test_utils.TestCase):
|
||||
|
||||
def create_event_loop(self):
|
||||
return asyncio.SelectorEventLoop()
|
||||
|
||||
class ProactorEventLoopTests(SendfileTestsBase,
|
||||
test_utils.TestCase):
|
||||
|
||||
def create_event_loop(self):
|
||||
return asyncio.ProactorEventLoop()
|
||||
|
||||
else:
|
||||
import selectors
|
||||
|
||||
if hasattr(selectors, 'KqueueSelector'):
|
||||
class KqueueEventLoopTests(SendfileTestsBase,
|
||||
test_utils.TestCase):
|
||||
|
||||
def create_event_loop(self):
|
||||
return asyncio.SelectorEventLoop(
|
||||
selectors.KqueueSelector())
|
||||
|
||||
if hasattr(selectors, 'EpollSelector'):
|
||||
class EPollEventLoopTests(SendfileTestsBase,
|
||||
test_utils.TestCase):
|
||||
|
||||
def create_event_loop(self):
|
||||
return asyncio.SelectorEventLoop(selectors.EpollSelector())
|
||||
|
||||
if hasattr(selectors, 'PollSelector'):
|
||||
class PollEventLoopTests(SendfileTestsBase,
|
||||
test_utils.TestCase):
|
||||
|
||||
def create_event_loop(self):
|
||||
return asyncio.SelectorEventLoop(selectors.PollSelector())
|
||||
|
||||
# Should always exist.
|
||||
class SelectEventLoopTests(SendfileTestsBase,
|
||||
test_utils.TestCase):
|
||||
|
||||
def create_event_loop(self):
|
||||
return asyncio.SelectorEventLoop(selectors.SelectSelector())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
352
Utils/PythonNew32/Lib/test/test_asyncio/test_server.py
Normal file
352
Utils/PythonNew32/Lib/test/test_asyncio/test_server.py
Normal file
@@ -0,0 +1,352 @@
|
||||
import asyncio
|
||||
import os
|
||||
import socket
|
||||
import time
|
||||
import threading
|
||||
import unittest
|
||||
|
||||
from test.support import socket_helper
|
||||
from test.test_asyncio import utils as test_utils
|
||||
from test.test_asyncio import functional as func_tests
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
class BaseStartServer(func_tests.FunctionalTestCaseMixin):
|
||||
|
||||
def new_loop(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_start_server_1(self):
|
||||
HELLO_MSG = b'1' * 1024 * 5 + b'\n'
|
||||
|
||||
def client(sock, addr):
|
||||
for i in range(10):
|
||||
time.sleep(0.2)
|
||||
if srv.is_serving():
|
||||
break
|
||||
else:
|
||||
raise RuntimeError
|
||||
|
||||
sock.settimeout(2)
|
||||
sock.connect(addr)
|
||||
sock.send(HELLO_MSG)
|
||||
sock.recv_all(1)
|
||||
sock.close()
|
||||
|
||||
async def serve(reader, writer):
|
||||
await reader.readline()
|
||||
main_task.cancel()
|
||||
writer.write(b'1')
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
async def main(srv):
|
||||
async with srv:
|
||||
await srv.serve_forever()
|
||||
|
||||
srv = self.loop.run_until_complete(asyncio.start_server(
|
||||
serve, socket_helper.HOSTv4, 0, start_serving=False))
|
||||
|
||||
self.assertFalse(srv.is_serving())
|
||||
|
||||
main_task = self.loop.create_task(main(srv))
|
||||
|
||||
addr = srv.sockets[0].getsockname()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
with self.tcp_client(lambda sock: client(sock, addr)):
|
||||
self.loop.run_until_complete(main_task)
|
||||
|
||||
self.assertEqual(srv.sockets, ())
|
||||
|
||||
self.assertIsNone(srv._sockets)
|
||||
self.assertIsNone(srv._waiters)
|
||||
self.assertFalse(srv.is_serving())
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r'is closed'):
|
||||
self.loop.run_until_complete(srv.serve_forever())
|
||||
|
||||
|
||||
class SelectorStartServerTests(BaseStartServer, unittest.TestCase):
|
||||
|
||||
def new_loop(self):
|
||||
return asyncio.SelectorEventLoop()
|
||||
|
||||
@socket_helper.skip_unless_bind_unix_socket
|
||||
def test_start_unix_server_1(self):
|
||||
HELLO_MSG = b'1' * 1024 * 5 + b'\n'
|
||||
started = threading.Event()
|
||||
|
||||
def client(sock, addr):
|
||||
sock.settimeout(2)
|
||||
started.wait(5)
|
||||
sock.connect(addr)
|
||||
sock.send(HELLO_MSG)
|
||||
sock.recv_all(1)
|
||||
sock.close()
|
||||
|
||||
async def serve(reader, writer):
|
||||
await reader.readline()
|
||||
main_task.cancel()
|
||||
writer.write(b'1')
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
async def main(srv):
|
||||
async with srv:
|
||||
self.assertFalse(srv.is_serving())
|
||||
await srv.start_serving()
|
||||
self.assertTrue(srv.is_serving())
|
||||
started.set()
|
||||
await srv.serve_forever()
|
||||
|
||||
with test_utils.unix_socket_path() as addr:
|
||||
srv = self.loop.run_until_complete(asyncio.start_unix_server(
|
||||
serve, addr, start_serving=False))
|
||||
|
||||
main_task = self.loop.create_task(main(srv))
|
||||
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
with self.unix_client(lambda sock: client(sock, addr)):
|
||||
self.loop.run_until_complete(main_task)
|
||||
|
||||
self.assertEqual(srv.sockets, ())
|
||||
|
||||
self.assertIsNone(srv._sockets)
|
||||
self.assertIsNone(srv._waiters)
|
||||
self.assertFalse(srv.is_serving())
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r'is closed'):
|
||||
self.loop.run_until_complete(srv.serve_forever())
|
||||
|
||||
|
||||
class TestServer2(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_wait_closed_basic(self):
|
||||
async def serve(rd, wr):
|
||||
try:
|
||||
await rd.read()
|
||||
finally:
|
||||
wr.close()
|
||||
await wr.wait_closed()
|
||||
|
||||
srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
|
||||
self.addCleanup(srv.close)
|
||||
|
||||
# active count = 0, not closed: should block
|
||||
task1 = asyncio.create_task(srv.wait_closed())
|
||||
await asyncio.sleep(0)
|
||||
self.assertFalse(task1.done())
|
||||
|
||||
# active count != 0, not closed: should block
|
||||
addr = srv.sockets[0].getsockname()
|
||||
(rd, wr) = await asyncio.open_connection(addr[0], addr[1])
|
||||
task2 = asyncio.create_task(srv.wait_closed())
|
||||
await asyncio.sleep(0)
|
||||
self.assertFalse(task1.done())
|
||||
self.assertFalse(task2.done())
|
||||
|
||||
srv.close()
|
||||
await asyncio.sleep(0)
|
||||
# active count != 0, closed: should block
|
||||
task3 = asyncio.create_task(srv.wait_closed())
|
||||
await asyncio.sleep(0)
|
||||
self.assertFalse(task1.done())
|
||||
self.assertFalse(task2.done())
|
||||
self.assertFalse(task3.done())
|
||||
|
||||
wr.close()
|
||||
await wr.wait_closed()
|
||||
# active count == 0, closed: should unblock
|
||||
await task1
|
||||
await task2
|
||||
await task3
|
||||
await srv.wait_closed() # Return immediately
|
||||
|
||||
async def test_wait_closed_race(self):
|
||||
# Test a regression in 3.12.0, should be fixed in 3.12.1
|
||||
async def serve(rd, wr):
|
||||
try:
|
||||
await rd.read()
|
||||
finally:
|
||||
wr.close()
|
||||
await wr.wait_closed()
|
||||
|
||||
srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
|
||||
self.addCleanup(srv.close)
|
||||
|
||||
task = asyncio.create_task(srv.wait_closed())
|
||||
await asyncio.sleep(0)
|
||||
self.assertFalse(task.done())
|
||||
addr = srv.sockets[0].getsockname()
|
||||
(rd, wr) = await asyncio.open_connection(addr[0], addr[1])
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.call_soon(srv.close)
|
||||
loop.call_soon(wr.close)
|
||||
await srv.wait_closed()
|
||||
|
||||
async def test_close_clients(self):
|
||||
async def serve(rd, wr):
|
||||
try:
|
||||
await rd.read()
|
||||
finally:
|
||||
wr.close()
|
||||
await wr.wait_closed()
|
||||
|
||||
srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
|
||||
self.addCleanup(srv.close)
|
||||
|
||||
addr = srv.sockets[0].getsockname()
|
||||
(rd, wr) = await asyncio.open_connection(addr[0], addr[1])
|
||||
self.addCleanup(wr.close)
|
||||
|
||||
task = asyncio.create_task(srv.wait_closed())
|
||||
await asyncio.sleep(0)
|
||||
self.assertFalse(task.done())
|
||||
|
||||
srv.close()
|
||||
srv.close_clients()
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
self.assertTrue(task.done())
|
||||
|
||||
async def test_abort_clients(self):
|
||||
async def serve(rd, wr):
|
||||
fut.set_result((rd, wr))
|
||||
await wr.wait_closed()
|
||||
|
||||
fut = asyncio.Future()
|
||||
srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
|
||||
self.addCleanup(srv.close)
|
||||
|
||||
addr = srv.sockets[0].getsockname()
|
||||
(c_rd, c_wr) = await asyncio.open_connection(addr[0], addr[1], limit=4096)
|
||||
self.addCleanup(c_wr.close)
|
||||
|
||||
(s_rd, s_wr) = await fut
|
||||
|
||||
# Limit the socket buffers so we can more reliably overfill them
|
||||
s_sock = s_wr.get_extra_info('socket')
|
||||
s_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 65536)
|
||||
c_sock = c_wr.get_extra_info('socket')
|
||||
c_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 65536)
|
||||
|
||||
# Get the reader in to a paused state by sending more than twice
|
||||
# the configured limit
|
||||
s_wr.write(b'a' * 4096)
|
||||
s_wr.write(b'a' * 4096)
|
||||
s_wr.write(b'a' * 4096)
|
||||
while c_wr.transport.is_reading():
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Get the writer in a waiting state by sending data until the
|
||||
# kernel stops accepting more data in the send buffer.
|
||||
# gh-122136: getsockopt() does not reliably report the buffer size
|
||||
# available for message content.
|
||||
# We loop until we start filling up the asyncio buffer.
|
||||
# To avoid an infinite loop we cap at 10 times the expected value
|
||||
c_bufsize = c_sock.getsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF)
|
||||
s_bufsize = s_sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)
|
||||
for i in range(10):
|
||||
s_wr.write(b'a' * c_bufsize)
|
||||
s_wr.write(b'a' * s_bufsize)
|
||||
if s_wr.transport.get_write_buffer_size() > 0:
|
||||
break
|
||||
self.assertNotEqual(s_wr.transport.get_write_buffer_size(), 0)
|
||||
|
||||
task = asyncio.create_task(srv.wait_closed())
|
||||
await asyncio.sleep(0)
|
||||
self.assertFalse(task.done())
|
||||
|
||||
srv.close()
|
||||
srv.abort_clients()
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
self.assertTrue(task.done())
|
||||
|
||||
|
||||
# Test the various corner cases of Unix server socket removal
|
||||
class UnixServerCleanupTests(unittest.IsolatedAsyncioTestCase):
|
||||
@socket_helper.skip_unless_bind_unix_socket
|
||||
async def test_unix_server_addr_cleanup(self):
|
||||
# Default scenario
|
||||
with test_utils.unix_socket_path() as addr:
|
||||
async def serve(*args):
|
||||
pass
|
||||
|
||||
srv = await asyncio.start_unix_server(serve, addr)
|
||||
|
||||
srv.close()
|
||||
self.assertFalse(os.path.exists(addr))
|
||||
|
||||
@socket_helper.skip_unless_bind_unix_socket
|
||||
async def test_unix_server_sock_cleanup(self):
|
||||
# Using already bound socket
|
||||
with test_utils.unix_socket_path() as addr:
|
||||
async def serve(*args):
|
||||
pass
|
||||
|
||||
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(addr)
|
||||
|
||||
srv = await asyncio.start_unix_server(serve, sock=sock)
|
||||
|
||||
srv.close()
|
||||
self.assertFalse(os.path.exists(addr))
|
||||
|
||||
@socket_helper.skip_unless_bind_unix_socket
|
||||
async def test_unix_server_cleanup_gone(self):
|
||||
# Someone else has already cleaned up the socket
|
||||
with test_utils.unix_socket_path() as addr:
|
||||
async def serve(*args):
|
||||
pass
|
||||
|
||||
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(addr)
|
||||
|
||||
srv = await asyncio.start_unix_server(serve, sock=sock)
|
||||
|
||||
os.unlink(addr)
|
||||
|
||||
srv.close()
|
||||
|
||||
@socket_helper.skip_unless_bind_unix_socket
|
||||
async def test_unix_server_cleanup_replaced(self):
|
||||
# Someone else has replaced the socket with their own
|
||||
with test_utils.unix_socket_path() as addr:
|
||||
async def serve(*args):
|
||||
pass
|
||||
|
||||
srv = await asyncio.start_unix_server(serve, addr)
|
||||
|
||||
os.unlink(addr)
|
||||
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
|
||||
sock.bind(addr)
|
||||
|
||||
srv.close()
|
||||
self.assertTrue(os.path.exists(addr))
|
||||
|
||||
@socket_helper.skip_unless_bind_unix_socket
|
||||
async def test_unix_server_cleanup_prevented(self):
|
||||
# Automatic cleanup explicitly disabled
|
||||
with test_utils.unix_socket_path() as addr:
|
||||
async def serve(*args):
|
||||
pass
|
||||
|
||||
srv = await asyncio.start_unix_server(serve, addr, cleanup_socket=False)
|
||||
|
||||
srv.close()
|
||||
self.assertTrue(os.path.exists(addr))
|
||||
|
||||
|
||||
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
|
||||
class ProactorStartServerTests(BaseStartServer, unittest.TestCase):
|
||||
|
||||
def new_loop(self):
|
||||
return asyncio.ProactorEventLoop()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
679
Utils/PythonNew32/Lib/test/test_asyncio/test_sock_lowlevel.py
Normal file
679
Utils/PythonNew32/Lib/test/test_asyncio/test_sock_lowlevel.py
Normal file
@@ -0,0 +1,679 @@
|
||||
import socket
|
||||
import asyncio
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from asyncio import proactor_events
|
||||
from itertools import cycle, islice
|
||||
from unittest.mock import Mock
|
||||
from test.test_asyncio import utils as test_utils
|
||||
from test import support
|
||||
from test.support import socket_helper
|
||||
|
||||
if socket_helper.tcp_blackhole():
|
||||
raise unittest.SkipTest('Not relevant to ProactorEventLoop')
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
class MyProto(asyncio.Protocol):
|
||||
connected = None
|
||||
done = None
|
||||
|
||||
def __init__(self, loop=None):
|
||||
self.transport = None
|
||||
self.state = 'INITIAL'
|
||||
self.nbytes = 0
|
||||
if loop is not None:
|
||||
self.connected = loop.create_future()
|
||||
self.done = loop.create_future()
|
||||
|
||||
def _assert_state(self, *expected):
|
||||
if self.state not in expected:
|
||||
raise AssertionError(f'state: {self.state!r}, expected: {expected!r}')
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
self._assert_state('INITIAL')
|
||||
self.state = 'CONNECTED'
|
||||
if self.connected:
|
||||
self.connected.set_result(None)
|
||||
transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
|
||||
|
||||
def data_received(self, data):
|
||||
self._assert_state('CONNECTED')
|
||||
self.nbytes += len(data)
|
||||
|
||||
def eof_received(self):
|
||||
self._assert_state('CONNECTED')
|
||||
self.state = 'EOF'
|
||||
|
||||
def connection_lost(self, exc):
|
||||
self._assert_state('CONNECTED', 'EOF')
|
||||
self.state = 'CLOSED'
|
||||
if self.done:
|
||||
self.done.set_result(None)
|
||||
|
||||
|
||||
class BaseSockTestsMixin:
|
||||
|
||||
def create_event_loop(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def setUp(self):
|
||||
self.loop = self.create_event_loop()
|
||||
self.set_event_loop(self.loop)
|
||||
super().setUp()
|
||||
|
||||
def tearDown(self):
|
||||
# just in case if we have transport close callbacks
|
||||
if not self.loop.is_closed():
|
||||
test_utils.run_briefly(self.loop)
|
||||
|
||||
self.doCleanups()
|
||||
support.gc_collect()
|
||||
super().tearDown()
|
||||
|
||||
def _basetest_sock_client_ops(self, httpd, sock):
|
||||
if not isinstance(self.loop, proactor_events.BaseProactorEventLoop):
|
||||
# in debug mode, socket operations must fail
|
||||
# if the socket is not in blocking mode
|
||||
self.loop.set_debug(True)
|
||||
sock.setblocking(True)
|
||||
with self.assertRaises(ValueError):
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_connect(sock, httpd.address))
|
||||
with self.assertRaises(ValueError):
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
|
||||
with self.assertRaises(ValueError):
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_recv(sock, 1024))
|
||||
with self.assertRaises(ValueError):
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_recv_into(sock, bytearray()))
|
||||
with self.assertRaises(ValueError):
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_accept(sock))
|
||||
|
||||
# test in non-blocking mode
|
||||
sock.setblocking(False)
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_connect(sock, httpd.address))
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
|
||||
data = self.loop.run_until_complete(
|
||||
self.loop.sock_recv(sock, 1024))
|
||||
# consume data
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_recv(sock, 1024))
|
||||
sock.close()
|
||||
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
|
||||
|
||||
def _basetest_sock_recv_into(self, httpd, sock):
|
||||
# same as _basetest_sock_client_ops, but using sock_recv_into
|
||||
sock.setblocking(False)
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_connect(sock, httpd.address))
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
|
||||
data = bytearray(1024)
|
||||
with memoryview(data) as buf:
|
||||
nbytes = self.loop.run_until_complete(
|
||||
self.loop.sock_recv_into(sock, buf[:1024]))
|
||||
# consume data
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_recv_into(sock, buf[nbytes:]))
|
||||
sock.close()
|
||||
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
|
||||
|
||||
def test_sock_client_ops(self):
|
||||
with test_utils.run_test_server() as httpd:
|
||||
sock = socket.socket()
|
||||
self._basetest_sock_client_ops(httpd, sock)
|
||||
sock = socket.socket()
|
||||
self._basetest_sock_recv_into(httpd, sock)
|
||||
|
||||
async def _basetest_sock_recv_racing(self, httpd, sock):
|
||||
sock.setblocking(False)
|
||||
await self.loop.sock_connect(sock, httpd.address)
|
||||
|
||||
task = asyncio.create_task(self.loop.sock_recv(sock, 1024))
|
||||
await asyncio.sleep(0)
|
||||
task.cancel()
|
||||
|
||||
asyncio.create_task(
|
||||
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
|
||||
data = await self.loop.sock_recv(sock, 1024)
|
||||
# consume data
|
||||
await self.loop.sock_recv(sock, 1024)
|
||||
|
||||
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
|
||||
|
||||
async def _basetest_sock_recv_into_racing(self, httpd, sock):
|
||||
sock.setblocking(False)
|
||||
await self.loop.sock_connect(sock, httpd.address)
|
||||
|
||||
data = bytearray(1024)
|
||||
with memoryview(data) as buf:
|
||||
task = asyncio.create_task(
|
||||
self.loop.sock_recv_into(sock, buf[:1024]))
|
||||
await asyncio.sleep(0)
|
||||
task.cancel()
|
||||
|
||||
task = asyncio.create_task(
|
||||
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
|
||||
nbytes = await self.loop.sock_recv_into(sock, buf[:1024])
|
||||
# consume data
|
||||
await self.loop.sock_recv_into(sock, buf[nbytes:])
|
||||
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
|
||||
|
||||
await task
|
||||
|
||||
async def _basetest_sock_send_racing(self, listener, sock):
|
||||
listener.bind(('127.0.0.1', 0))
|
||||
listener.listen(1)
|
||||
|
||||
# make connection
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
|
||||
sock.setblocking(False)
|
||||
task = asyncio.create_task(
|
||||
self.loop.sock_connect(sock, listener.getsockname()))
|
||||
await asyncio.sleep(0)
|
||||
server = listener.accept()[0]
|
||||
server.setblocking(False)
|
||||
|
||||
with server:
|
||||
await task
|
||||
|
||||
# fill the buffer until sending 5 chars would block
|
||||
size = 8192
|
||||
while size >= 4:
|
||||
with self.assertRaises(BlockingIOError):
|
||||
while True:
|
||||
sock.send(b' ' * size)
|
||||
size = int(size / 2)
|
||||
|
||||
# cancel a blocked sock_sendall
|
||||
task = asyncio.create_task(
|
||||
self.loop.sock_sendall(sock, b'hello'))
|
||||
await asyncio.sleep(0)
|
||||
task.cancel()
|
||||
|
||||
# receive everything that is not a space
|
||||
async def recv_all():
|
||||
rv = b''
|
||||
while True:
|
||||
buf = await self.loop.sock_recv(server, 8192)
|
||||
if not buf:
|
||||
return rv
|
||||
rv += buf.strip()
|
||||
task = asyncio.create_task(recv_all())
|
||||
|
||||
# immediately make another sock_sendall call
|
||||
await self.loop.sock_sendall(sock, b'world')
|
||||
sock.shutdown(socket.SHUT_WR)
|
||||
data = await task
|
||||
# ProactorEventLoop could deliver hello, so endswith is necessary
|
||||
self.assertTrue(data.endswith(b'world'))
|
||||
|
||||
# After the first connect attempt before the listener is ready,
|
||||
# the socket needs time to "recover" to make the next connect call.
|
||||
# On Linux, a second retry will do. On Windows, the waiting time is
|
||||
# unpredictable; and on FreeBSD the socket may never come back
|
||||
# because it's a loopback address. Here we'll just retry for a few
|
||||
# times, and have to skip the test if it's not working. See also:
|
||||
# https://stackoverflow.com/a/54437602/3316267
|
||||
# https://lists.freebsd.org/pipermail/freebsd-current/2005-May/049876.html
|
||||
async def _basetest_sock_connect_racing(self, listener, sock):
|
||||
listener.bind(('127.0.0.1', 0))
|
||||
addr = listener.getsockname()
|
||||
sock.setblocking(False)
|
||||
|
||||
task = asyncio.create_task(self.loop.sock_connect(sock, addr))
|
||||
await asyncio.sleep(0)
|
||||
task.cancel()
|
||||
|
||||
listener.listen(1)
|
||||
|
||||
skip_reason = "Max retries reached"
|
||||
for i in range(128):
|
||||
try:
|
||||
await self.loop.sock_connect(sock, addr)
|
||||
except ConnectionRefusedError as e:
|
||||
skip_reason = e
|
||||
except OSError as e:
|
||||
skip_reason = e
|
||||
|
||||
# Retry only for this error:
|
||||
# [WinError 10022] An invalid argument was supplied
|
||||
if getattr(e, 'winerror', 0) != 10022:
|
||||
break
|
||||
else:
|
||||
# success
|
||||
return
|
||||
|
||||
self.skipTest(skip_reason)
|
||||
|
||||
def test_sock_client_racing(self):
|
||||
with test_utils.run_test_server() as httpd:
|
||||
sock = socket.socket()
|
||||
with sock:
|
||||
self.loop.run_until_complete(asyncio.wait_for(
|
||||
self._basetest_sock_recv_racing(httpd, sock), 10))
|
||||
sock = socket.socket()
|
||||
with sock:
|
||||
self.loop.run_until_complete(asyncio.wait_for(
|
||||
self._basetest_sock_recv_into_racing(httpd, sock), 10))
|
||||
listener = socket.socket()
|
||||
sock = socket.socket()
|
||||
with listener, sock:
|
||||
self.loop.run_until_complete(asyncio.wait_for(
|
||||
self._basetest_sock_send_racing(listener, sock), 10))
|
||||
|
||||
def test_sock_client_connect_racing(self):
|
||||
listener = socket.socket()
|
||||
sock = socket.socket()
|
||||
with listener, sock:
|
||||
self.loop.run_until_complete(asyncio.wait_for(
|
||||
self._basetest_sock_connect_racing(listener, sock), 10))
|
||||
|
||||
async def _basetest_huge_content(self, address):
|
||||
sock = socket.socket()
|
||||
sock.setblocking(False)
|
||||
DATA_SIZE = 10_000_00
|
||||
|
||||
chunk = b'0123456789' * (DATA_SIZE // 10)
|
||||
|
||||
await self.loop.sock_connect(sock, address)
|
||||
await self.loop.sock_sendall(sock,
|
||||
(b'POST /loop HTTP/1.0\r\n' +
|
||||
b'Content-Length: %d\r\n' % DATA_SIZE +
|
||||
b'\r\n'))
|
||||
|
||||
task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))
|
||||
|
||||
data = await self.loop.sock_recv(sock, DATA_SIZE)
|
||||
# HTTP headers size is less than MTU,
|
||||
# they are sent by the first packet always
|
||||
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
|
||||
while data.find(b'\r\n\r\n') == -1:
|
||||
data += await self.loop.sock_recv(sock, DATA_SIZE)
|
||||
# Strip headers
|
||||
headers = data[:data.index(b'\r\n\r\n') + 4]
|
||||
data = data[len(headers):]
|
||||
|
||||
size = DATA_SIZE
|
||||
checker = cycle(b'0123456789')
|
||||
|
||||
expected = bytes(islice(checker, len(data)))
|
||||
self.assertEqual(data, expected)
|
||||
size -= len(data)
|
||||
|
||||
while True:
|
||||
data = await self.loop.sock_recv(sock, DATA_SIZE)
|
||||
if not data:
|
||||
break
|
||||
expected = bytes(islice(checker, len(data)))
|
||||
self.assertEqual(data, expected)
|
||||
size -= len(data)
|
||||
self.assertEqual(size, 0)
|
||||
|
||||
await task
|
||||
sock.close()
|
||||
|
||||
def test_huge_content(self):
|
||||
with test_utils.run_test_server() as httpd:
|
||||
self.loop.run_until_complete(
|
||||
self._basetest_huge_content(httpd.address))
|
||||
|
||||
async def _basetest_huge_content_recvinto(self, address):
|
||||
sock = socket.socket()
|
||||
sock.setblocking(False)
|
||||
DATA_SIZE = 10_000_00
|
||||
|
||||
chunk = b'0123456789' * (DATA_SIZE // 10)
|
||||
|
||||
await self.loop.sock_connect(sock, address)
|
||||
await self.loop.sock_sendall(sock,
|
||||
(b'POST /loop HTTP/1.0\r\n' +
|
||||
b'Content-Length: %d\r\n' % DATA_SIZE +
|
||||
b'\r\n'))
|
||||
|
||||
task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))
|
||||
|
||||
array = bytearray(DATA_SIZE)
|
||||
buf = memoryview(array)
|
||||
|
||||
nbytes = await self.loop.sock_recv_into(sock, buf)
|
||||
data = bytes(buf[:nbytes])
|
||||
# HTTP headers size is less than MTU,
|
||||
# they are sent by the first packet always
|
||||
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
|
||||
while data.find(b'\r\n\r\n') == -1:
|
||||
nbytes = await self.loop.sock_recv_into(sock, buf)
|
||||
data = bytes(buf[:nbytes])
|
||||
# Strip headers
|
||||
headers = data[:data.index(b'\r\n\r\n') + 4]
|
||||
data = data[len(headers):]
|
||||
|
||||
size = DATA_SIZE
|
||||
checker = cycle(b'0123456789')
|
||||
|
||||
expected = bytes(islice(checker, len(data)))
|
||||
self.assertEqual(data, expected)
|
||||
size -= len(data)
|
||||
|
||||
while True:
|
||||
nbytes = await self.loop.sock_recv_into(sock, buf)
|
||||
data = buf[:nbytes]
|
||||
if not data:
|
||||
break
|
||||
expected = bytes(islice(checker, len(data)))
|
||||
self.assertEqual(data, expected)
|
||||
size -= len(data)
|
||||
self.assertEqual(size, 0)
|
||||
|
||||
await task
|
||||
sock.close()
|
||||
|
||||
def test_huge_content_recvinto(self):
|
||||
with test_utils.run_test_server() as httpd:
|
||||
self.loop.run_until_complete(
|
||||
self._basetest_huge_content_recvinto(httpd.address))
|
||||
|
||||
async def _basetest_datagram_recvfrom(self, server_address):
|
||||
# Happy path, sock.sendto() returns immediately
|
||||
data = b'\x01' * 4096
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
|
||||
sock.setblocking(False)
|
||||
await self.loop.sock_sendto(sock, data, server_address)
|
||||
received_data, from_addr = await self.loop.sock_recvfrom(
|
||||
sock, 4096)
|
||||
self.assertEqual(received_data, data)
|
||||
self.assertEqual(from_addr, server_address)
|
||||
|
||||
def test_recvfrom(self):
|
||||
with test_utils.run_udp_echo_server() as server_address:
|
||||
self.loop.run_until_complete(
|
||||
self._basetest_datagram_recvfrom(server_address))
|
||||
|
||||
async def _basetest_datagram_recvfrom_into(self, server_address):
|
||||
# Happy path, sock.sendto() returns immediately
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
|
||||
sock.setblocking(False)
|
||||
|
||||
buf = bytearray(4096)
|
||||
data = b'\x01' * 4096
|
||||
await self.loop.sock_sendto(sock, data, server_address)
|
||||
num_bytes, from_addr = await self.loop.sock_recvfrom_into(
|
||||
sock, buf)
|
||||
self.assertEqual(num_bytes, 4096)
|
||||
self.assertEqual(buf, data)
|
||||
self.assertEqual(from_addr, server_address)
|
||||
|
||||
buf = bytearray(8192)
|
||||
await self.loop.sock_sendto(sock, data, server_address)
|
||||
num_bytes, from_addr = await self.loop.sock_recvfrom_into(
|
||||
sock, buf, 4096)
|
||||
self.assertEqual(num_bytes, 4096)
|
||||
self.assertEqual(buf[:4096], data[:4096])
|
||||
self.assertEqual(from_addr, server_address)
|
||||
|
||||
def test_recvfrom_into(self):
|
||||
with test_utils.run_udp_echo_server() as server_address:
|
||||
self.loop.run_until_complete(
|
||||
self._basetest_datagram_recvfrom_into(server_address))
|
||||
|
||||
async def _basetest_datagram_sendto_blocking(self, server_address):
|
||||
# Sad path, sock.sendto() raises BlockingIOError
|
||||
# This involves patching sock.sendto() to raise BlockingIOError but
|
||||
# sendto() is not used by the proactor event loop
|
||||
data = b'\x01' * 4096
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
|
||||
sock.setblocking(False)
|
||||
mock_sock = Mock(sock)
|
||||
mock_sock.gettimeout = sock.gettimeout
|
||||
mock_sock.sendto.configure_mock(side_effect=BlockingIOError)
|
||||
mock_sock.fileno = sock.fileno
|
||||
self.loop.call_soon(
|
||||
lambda: setattr(mock_sock, 'sendto', sock.sendto)
|
||||
)
|
||||
await self.loop.sock_sendto(mock_sock, data, server_address)
|
||||
|
||||
received_data, from_addr = await self.loop.sock_recvfrom(
|
||||
sock, 4096)
|
||||
self.assertEqual(received_data, data)
|
||||
self.assertEqual(from_addr, server_address)
|
||||
|
||||
def test_sendto_blocking(self):
|
||||
if sys.platform == 'win32':
|
||||
if isinstance(self.loop, asyncio.ProactorEventLoop):
|
||||
raise unittest.SkipTest('Not relevant to ProactorEventLoop')
|
||||
|
||||
with test_utils.run_udp_echo_server() as server_address:
|
||||
self.loop.run_until_complete(
|
||||
self._basetest_datagram_sendto_blocking(server_address))
|
||||
|
||||
@socket_helper.skip_unless_bind_unix_socket
|
||||
def test_unix_sock_client_ops(self):
|
||||
with test_utils.run_test_unix_server() as httpd:
|
||||
sock = socket.socket(socket.AF_UNIX)
|
||||
self._basetest_sock_client_ops(httpd, sock)
|
||||
sock = socket.socket(socket.AF_UNIX)
|
||||
self._basetest_sock_recv_into(httpd, sock)
|
||||
|
||||
def test_sock_client_fail(self):
|
||||
# Make sure that we will get an unused port
|
||||
address = None
|
||||
try:
|
||||
s = socket.socket()
|
||||
s.bind(('127.0.0.1', 0))
|
||||
address = s.getsockname()
|
||||
finally:
|
||||
s.close()
|
||||
|
||||
sock = socket.socket()
|
||||
sock.setblocking(False)
|
||||
with self.assertRaises(ConnectionRefusedError):
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_connect(sock, address))
|
||||
sock.close()
|
||||
|
||||
def test_sock_accept(self):
|
||||
listener = socket.socket()
|
||||
listener.setblocking(False)
|
||||
listener.bind(('127.0.0.1', 0))
|
||||
listener.listen(1)
|
||||
client = socket.socket()
|
||||
client.connect(listener.getsockname())
|
||||
|
||||
f = self.loop.sock_accept(listener)
|
||||
conn, addr = self.loop.run_until_complete(f)
|
||||
self.assertEqual(conn.gettimeout(), 0)
|
||||
self.assertEqual(addr, client.getsockname())
|
||||
self.assertEqual(client.getpeername(), listener.getsockname())
|
||||
client.close()
|
||||
conn.close()
|
||||
listener.close()
|
||||
|
||||
def test_cancel_sock_accept(self):
|
||||
listener = socket.socket()
|
||||
listener.setblocking(False)
|
||||
listener.bind(('127.0.0.1', 0))
|
||||
listener.listen(1)
|
||||
sockaddr = listener.getsockname()
|
||||
f = asyncio.wait_for(self.loop.sock_accept(listener), 0.1)
|
||||
with self.assertRaises(asyncio.TimeoutError):
|
||||
self.loop.run_until_complete(f)
|
||||
|
||||
listener.close()
|
||||
client = socket.socket()
|
||||
client.setblocking(False)
|
||||
f = self.loop.sock_connect(client, sockaddr)
|
||||
with self.assertRaises(ConnectionRefusedError):
|
||||
self.loop.run_until_complete(f)
|
||||
|
||||
client.close()
|
||||
|
||||
def test_create_connection_sock(self):
|
||||
with test_utils.run_test_server() as httpd:
|
||||
sock = None
|
||||
infos = self.loop.run_until_complete(
|
||||
self.loop.getaddrinfo(
|
||||
*httpd.address, type=socket.SOCK_STREAM))
|
||||
for family, type, proto, cname, address in infos:
|
||||
try:
|
||||
sock = socket.socket(family=family, type=type, proto=proto)
|
||||
sock.setblocking(False)
|
||||
self.loop.run_until_complete(
|
||||
self.loop.sock_connect(sock, address))
|
||||
except BaseException:
|
||||
pass
|
||||
else:
|
||||
break
|
||||
else:
|
||||
self.fail('Can not create socket.')
|
||||
|
||||
f = self.loop.create_connection(
|
||||
lambda: MyProto(loop=self.loop), sock=sock)
|
||||
tr, pr = self.loop.run_until_complete(f)
|
||||
self.assertIsInstance(tr, asyncio.Transport)
|
||||
self.assertIsInstance(pr, asyncio.Protocol)
|
||||
self.loop.run_until_complete(pr.done)
|
||||
self.assertGreater(pr.nbytes, 0)
|
||||
tr.close()
|
||||
|
||||
|
||||
if sys.platform == 'win32':
|
||||
|
||||
class SelectEventLoopTests(BaseSockTestsMixin,
|
||||
test_utils.TestCase):
|
||||
|
||||
def create_event_loop(self):
|
||||
return asyncio.SelectorEventLoop()
|
||||
|
||||
|
||||
class ProactorEventLoopTests(BaseSockTestsMixin,
|
||||
test_utils.TestCase):
|
||||
|
||||
def create_event_loop(self):
|
||||
return asyncio.ProactorEventLoop()
|
||||
|
||||
|
||||
async def _basetest_datagram_send_to_non_listening_address(self,
|
||||
recvfrom):
|
||||
# see:
|
||||
# https://github.com/python/cpython/issues/91227
|
||||
# https://github.com/python/cpython/issues/88906
|
||||
# https://bugs.python.org/issue47071
|
||||
# https://bugs.python.org/issue44743
|
||||
# The Proactor event loop would fail to receive datagram messages
|
||||
# after sending a message to an address that wasn't listening.
|
||||
|
||||
def create_socket():
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
sock.setblocking(False)
|
||||
sock.bind(('127.0.0.1', 0))
|
||||
return sock
|
||||
|
||||
socket_1 = create_socket()
|
||||
addr_1 = socket_1.getsockname()
|
||||
|
||||
socket_2 = create_socket()
|
||||
addr_2 = socket_2.getsockname()
|
||||
|
||||
# creating and immediately closing this to try to get an address
|
||||
# that is not listening
|
||||
socket_3 = create_socket()
|
||||
addr_3 = socket_3.getsockname()
|
||||
socket_3.shutdown(socket.SHUT_RDWR)
|
||||
socket_3.close()
|
||||
|
||||
socket_1_recv_task = self.loop.create_task(recvfrom(socket_1))
|
||||
socket_2_recv_task = self.loop.create_task(recvfrom(socket_2))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
await self.loop.sock_sendto(socket_1, b'a', addr_2)
|
||||
self.assertEqual(await socket_2_recv_task, b'a')
|
||||
|
||||
await self.loop.sock_sendto(socket_2, b'b', addr_1)
|
||||
self.assertEqual(await socket_1_recv_task, b'b')
|
||||
socket_1_recv_task = self.loop.create_task(recvfrom(socket_1))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# this should send to an address that isn't listening
|
||||
await self.loop.sock_sendto(socket_1, b'c', addr_3)
|
||||
self.assertEqual(await socket_1_recv_task, b'')
|
||||
socket_1_recv_task = self.loop.create_task(recvfrom(socket_1))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# socket 1 should still be able to receive messages after sending
|
||||
# to an address that wasn't listening
|
||||
socket_2.sendto(b'd', addr_1)
|
||||
self.assertEqual(await socket_1_recv_task, b'd')
|
||||
|
||||
socket_1.shutdown(socket.SHUT_RDWR)
|
||||
socket_1.close()
|
||||
socket_2.shutdown(socket.SHUT_RDWR)
|
||||
socket_2.close()
|
||||
|
||||
|
||||
def test_datagram_send_to_non_listening_address_recvfrom(self):
|
||||
async def recvfrom(socket):
|
||||
data, _ = await self.loop.sock_recvfrom(socket, 4096)
|
||||
return data
|
||||
|
||||
self.loop.run_until_complete(
|
||||
self._basetest_datagram_send_to_non_listening_address(
|
||||
recvfrom))
|
||||
|
||||
|
||||
def test_datagram_send_to_non_listening_address_recvfrom_into(self):
|
||||
async def recvfrom_into(socket):
|
||||
buf = bytearray(4096)
|
||||
length, _ = await self.loop.sock_recvfrom_into(socket, buf,
|
||||
4096)
|
||||
return buf[:length]
|
||||
|
||||
self.loop.run_until_complete(
|
||||
self._basetest_datagram_send_to_non_listening_address(
|
||||
recvfrom_into))
|
||||
|
||||
else:
|
||||
import selectors
|
||||
|
||||
if hasattr(selectors, 'KqueueSelector'):
|
||||
class KqueueEventLoopTests(BaseSockTestsMixin,
|
||||
test_utils.TestCase):
|
||||
|
||||
def create_event_loop(self):
|
||||
return asyncio.SelectorEventLoop(
|
||||
selectors.KqueueSelector())
|
||||
|
||||
if hasattr(selectors, 'EpollSelector'):
|
||||
class EPollEventLoopTests(BaseSockTestsMixin,
|
||||
test_utils.TestCase):
|
||||
|
||||
def create_event_loop(self):
|
||||
return asyncio.SelectorEventLoop(selectors.EpollSelector())
|
||||
|
||||
if hasattr(selectors, 'PollSelector'):
|
||||
class PollEventLoopTests(BaseSockTestsMixin,
|
||||
test_utils.TestCase):
|
||||
|
||||
def create_event_loop(self):
|
||||
return asyncio.SelectorEventLoop(selectors.PollSelector())
|
||||
|
||||
# Should always exist.
|
||||
class SelectEventLoopTests(BaseSockTestsMixin,
|
||||
test_utils.TestCase):
|
||||
|
||||
def create_event_loop(self):
|
||||
return asyncio.SelectorEventLoop(selectors.SelectSelector())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
1905
Utils/PythonNew32/Lib/test/test_asyncio/test_ssl.py
Normal file
1905
Utils/PythonNew32/Lib/test/test_asyncio/test_ssl.py
Normal file
File diff suppressed because it is too large
Load Diff
843
Utils/PythonNew32/Lib/test/test_asyncio/test_sslproto.py
Normal file
843
Utils/PythonNew32/Lib/test/test_asyncio/test_sslproto.py
Normal file
@@ -0,0 +1,843 @@
|
||||
"""Tests for asyncio/sslproto.py."""
|
||||
|
||||
import logging
|
||||
import socket
|
||||
import unittest
|
||||
import weakref
|
||||
from test import support
|
||||
from test.support import socket_helper
|
||||
from unittest import mock
|
||||
try:
|
||||
import ssl
|
||||
except ImportError:
|
||||
ssl = None
|
||||
|
||||
import asyncio
|
||||
from asyncio import log
|
||||
from asyncio import protocols
|
||||
from asyncio import sslproto
|
||||
from test.test_asyncio import utils as test_utils
|
||||
from test.test_asyncio import functional as func_tests
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
class SslProtoHandshakeTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.set_event_loop(self.loop)
|
||||
|
||||
def ssl_protocol(self, *, waiter=None, proto=None):
|
||||
sslcontext = test_utils.dummy_ssl_context()
|
||||
if proto is None: # app protocol
|
||||
proto = asyncio.Protocol()
|
||||
ssl_proto = sslproto.SSLProtocol(self.loop, proto, sslcontext, waiter,
|
||||
ssl_handshake_timeout=0.1)
|
||||
self.assertIs(ssl_proto._app_transport.get_protocol(), proto)
|
||||
self.addCleanup(ssl_proto._app_transport.close)
|
||||
return ssl_proto
|
||||
|
||||
def connection_made(self, ssl_proto, *, do_handshake=None):
|
||||
transport = mock.Mock()
|
||||
sslobj = mock.Mock()
|
||||
# emulate reading decompressed data
|
||||
sslobj.read.side_effect = ssl.SSLWantReadError
|
||||
sslobj.write.side_effect = ssl.SSLWantReadError
|
||||
if do_handshake is not None:
|
||||
sslobj.do_handshake = do_handshake
|
||||
ssl_proto._sslobj = sslobj
|
||||
ssl_proto.connection_made(transport)
|
||||
return transport
|
||||
|
||||
def test_handshake_timeout_zero(self):
|
||||
sslcontext = test_utils.dummy_ssl_context()
|
||||
app_proto = mock.Mock()
|
||||
waiter = mock.Mock()
|
||||
with self.assertRaisesRegex(ValueError, 'a positive number'):
|
||||
sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
|
||||
ssl_handshake_timeout=0)
|
||||
|
||||
def test_handshake_timeout_negative(self):
|
||||
sslcontext = test_utils.dummy_ssl_context()
|
||||
app_proto = mock.Mock()
|
||||
waiter = mock.Mock()
|
||||
with self.assertRaisesRegex(ValueError, 'a positive number'):
|
||||
sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
|
||||
ssl_handshake_timeout=-10)
|
||||
|
||||
def test_eof_received_waiter(self):
|
||||
waiter = self.loop.create_future()
|
||||
ssl_proto = self.ssl_protocol(waiter=waiter)
|
||||
self.connection_made(
|
||||
ssl_proto,
|
||||
do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
|
||||
)
|
||||
ssl_proto.eof_received()
|
||||
test_utils.run_briefly(self.loop)
|
||||
self.assertIsInstance(waiter.exception(), ConnectionResetError)
|
||||
|
||||
def test_fatal_error_no_name_error(self):
|
||||
# From issue #363.
|
||||
# _fatal_error() generates a NameError if sslproto.py
|
||||
# does not import base_events.
|
||||
waiter = self.loop.create_future()
|
||||
ssl_proto = self.ssl_protocol(waiter=waiter)
|
||||
# Temporarily turn off error logging so as not to spoil test output.
|
||||
log_level = log.logger.getEffectiveLevel()
|
||||
log.logger.setLevel(logging.FATAL)
|
||||
try:
|
||||
ssl_proto._fatal_error(None)
|
||||
finally:
|
||||
# Restore error logging.
|
||||
log.logger.setLevel(log_level)
|
||||
|
||||
def test_connection_lost(self):
|
||||
# From issue #472.
|
||||
# yield from waiter hang if lost_connection was called.
|
||||
waiter = self.loop.create_future()
|
||||
ssl_proto = self.ssl_protocol(waiter=waiter)
|
||||
self.connection_made(
|
||||
ssl_proto,
|
||||
do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
|
||||
)
|
||||
ssl_proto.connection_lost(ConnectionAbortedError)
|
||||
test_utils.run_briefly(self.loop)
|
||||
self.assertIsInstance(waiter.exception(), ConnectionAbortedError)
|
||||
|
||||
def test_connection_lost_when_busy(self):
|
||||
# gh-118950: SSLProtocol.connection_lost not being called when OSError
|
||||
# is thrown on asyncio.write.
|
||||
sock = mock.Mock()
|
||||
sock.fileno = mock.Mock(return_value=12345)
|
||||
sock.send = mock.Mock(side_effect=BrokenPipeError)
|
||||
|
||||
# construct StreamWriter chain that contains loop dependant logic this emulates
|
||||
# what _make_ssl_transport() does in BaseSelectorEventLoop
|
||||
reader = asyncio.StreamReader(limit=2 ** 16, loop=self.loop)
|
||||
protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop)
|
||||
ssl_proto = self.ssl_protocol(proto=protocol)
|
||||
|
||||
# emulate reading decompressed data
|
||||
sslobj = mock.Mock()
|
||||
sslobj.read.side_effect = ssl.SSLWantReadError
|
||||
sslobj.write.side_effect = ssl.SSLWantReadError
|
||||
ssl_proto._sslobj = sslobj
|
||||
|
||||
# emulate outgoing data
|
||||
data = b'An interesting message'
|
||||
|
||||
outgoing = mock.Mock()
|
||||
outgoing.read = mock.Mock(return_value=data)
|
||||
outgoing.pending = len(data)
|
||||
ssl_proto._outgoing = outgoing
|
||||
|
||||
# use correct socket transport to initialize the SSLProtocol
|
||||
self.loop._make_socket_transport(sock, ssl_proto)
|
||||
|
||||
transport = ssl_proto._app_transport
|
||||
writer = asyncio.StreamWriter(transport, protocol, reader, self.loop)
|
||||
|
||||
async def main():
|
||||
# writes data to transport
|
||||
async def write():
|
||||
writer.write(data)
|
||||
await writer.drain()
|
||||
|
||||
# try to write for the first time
|
||||
await write()
|
||||
# try to write for the second time, this raises as the connection_lost
|
||||
# callback should be done with error
|
||||
with self.assertRaises(ConnectionResetError):
|
||||
await write()
|
||||
|
||||
self.loop.run_until_complete(main())
|
||||
|
||||
def test_close_during_handshake(self):
|
||||
# bpo-29743 Closing transport during handshake process leaks socket
|
||||
waiter = self.loop.create_future()
|
||||
ssl_proto = self.ssl_protocol(waiter=waiter)
|
||||
|
||||
transport = self.connection_made(
|
||||
ssl_proto,
|
||||
do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
|
||||
)
|
||||
test_utils.run_briefly(self.loop)
|
||||
|
||||
ssl_proto._app_transport.close()
|
||||
self.assertTrue(transport._force_close.called)
|
||||
|
||||
def test_close_during_ssl_over_ssl(self):
|
||||
# gh-113214: passing exceptions from the inner wrapped SSL protocol to the
|
||||
# shim transport provided by the outer SSL protocol should not raise
|
||||
# attribute errors
|
||||
outer = self.ssl_protocol(proto=self.ssl_protocol())
|
||||
self.connection_made(outer)
|
||||
# Closing the outer app transport should not raise an exception
|
||||
messages = []
|
||||
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
|
||||
outer._app_transport.close()
|
||||
self.assertEqual(messages, [])
|
||||
|
||||
def test_get_extra_info_on_closed_connection(self):
|
||||
waiter = self.loop.create_future()
|
||||
ssl_proto = self.ssl_protocol(waiter=waiter)
|
||||
self.assertIsNone(ssl_proto._get_extra_info('socket'))
|
||||
default = object()
|
||||
self.assertIs(ssl_proto._get_extra_info('socket', default), default)
|
||||
self.connection_made(ssl_proto)
|
||||
self.assertIsNotNone(ssl_proto._get_extra_info('socket'))
|
||||
ssl_proto.connection_lost(None)
|
||||
self.assertIsNone(ssl_proto._get_extra_info('socket'))
|
||||
|
||||
def test_set_new_app_protocol(self):
|
||||
waiter = self.loop.create_future()
|
||||
ssl_proto = self.ssl_protocol(waiter=waiter)
|
||||
new_app_proto = asyncio.Protocol()
|
||||
ssl_proto._app_transport.set_protocol(new_app_proto)
|
||||
self.assertIs(ssl_proto._app_transport.get_protocol(), new_app_proto)
|
||||
self.assertIs(ssl_proto._app_protocol, new_app_proto)
|
||||
|
||||
def test_data_received_after_closing(self):
|
||||
ssl_proto = self.ssl_protocol()
|
||||
self.connection_made(ssl_proto)
|
||||
transp = ssl_proto._app_transport
|
||||
|
||||
transp.close()
|
||||
|
||||
# should not raise
|
||||
self.assertIsNone(ssl_proto.buffer_updated(5))
|
||||
|
||||
def test_write_after_closing(self):
|
||||
ssl_proto = self.ssl_protocol()
|
||||
self.connection_made(ssl_proto)
|
||||
transp = ssl_proto._app_transport
|
||||
transp.close()
|
||||
|
||||
# should not raise
|
||||
self.assertIsNone(transp.write(b'data'))
|
||||
|
||||
|
||||
##############################################################################
|
||||
# Start TLS Tests
|
||||
##############################################################################
|
||||
|
||||
|
||||
class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
|
||||
|
||||
PAYLOAD_SIZE = 1024 * 100
|
||||
TIMEOUT = support.LONG_TIMEOUT
|
||||
|
||||
def new_loop(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def test_buf_feed_data(self):
|
||||
|
||||
class Proto(asyncio.BufferedProtocol):
|
||||
|
||||
def __init__(self, bufsize, usemv):
|
||||
self.buf = bytearray(bufsize)
|
||||
self.mv = memoryview(self.buf)
|
||||
self.data = b''
|
||||
self.usemv = usemv
|
||||
|
||||
def get_buffer(self, sizehint):
|
||||
if self.usemv:
|
||||
return self.mv
|
||||
else:
|
||||
return self.buf
|
||||
|
||||
def buffer_updated(self, nsize):
|
||||
if self.usemv:
|
||||
self.data += self.mv[:nsize]
|
||||
else:
|
||||
self.data += self.buf[:nsize]
|
||||
|
||||
for usemv in [False, True]:
|
||||
proto = Proto(1, usemv)
|
||||
protocols._feed_data_to_buffered_proto(proto, b'12345')
|
||||
self.assertEqual(proto.data, b'12345')
|
||||
|
||||
proto = Proto(2, usemv)
|
||||
protocols._feed_data_to_buffered_proto(proto, b'12345')
|
||||
self.assertEqual(proto.data, b'12345')
|
||||
|
||||
proto = Proto(2, usemv)
|
||||
protocols._feed_data_to_buffered_proto(proto, b'1234')
|
||||
self.assertEqual(proto.data, b'1234')
|
||||
|
||||
proto = Proto(4, usemv)
|
||||
protocols._feed_data_to_buffered_proto(proto, b'1234')
|
||||
self.assertEqual(proto.data, b'1234')
|
||||
|
||||
proto = Proto(100, usemv)
|
||||
protocols._feed_data_to_buffered_proto(proto, b'12345')
|
||||
self.assertEqual(proto.data, b'12345')
|
||||
|
||||
proto = Proto(0, usemv)
|
||||
with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
|
||||
protocols._feed_data_to_buffered_proto(proto, b'12345')
|
||||
|
||||
def test_start_tls_client_reg_proto_1(self):
|
||||
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
|
||||
|
||||
server_context = test_utils.simple_server_sslcontext()
|
||||
client_context = test_utils.simple_client_sslcontext()
|
||||
|
||||
def serve(sock):
|
||||
sock.settimeout(self.TIMEOUT)
|
||||
|
||||
data = sock.recv_all(len(HELLO_MSG))
|
||||
self.assertEqual(len(data), len(HELLO_MSG))
|
||||
|
||||
sock.start_tls(server_context, server_side=True)
|
||||
|
||||
sock.sendall(b'O')
|
||||
data = sock.recv_all(len(HELLO_MSG))
|
||||
self.assertEqual(len(data), len(HELLO_MSG))
|
||||
|
||||
sock.shutdown(socket.SHUT_RDWR)
|
||||
sock.close()
|
||||
|
||||
class ClientProto(asyncio.Protocol):
|
||||
def __init__(self, on_data, on_eof):
|
||||
self.on_data = on_data
|
||||
self.on_eof = on_eof
|
||||
self.con_made_cnt = 0
|
||||
|
||||
def connection_made(proto, tr):
|
||||
proto.con_made_cnt += 1
|
||||
# Ensure connection_made gets called only once.
|
||||
self.assertEqual(proto.con_made_cnt, 1)
|
||||
|
||||
def data_received(self, data):
|
||||
self.on_data.set_result(data)
|
||||
|
||||
def eof_received(self):
|
||||
self.on_eof.set_result(True)
|
||||
|
||||
async def client(addr):
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
on_data = self.loop.create_future()
|
||||
on_eof = self.loop.create_future()
|
||||
|
||||
tr, proto = await self.loop.create_connection(
|
||||
lambda: ClientProto(on_data, on_eof), *addr)
|
||||
|
||||
tr.write(HELLO_MSG)
|
||||
new_tr = await self.loop.start_tls(tr, proto, client_context)
|
||||
|
||||
self.assertEqual(await on_data, b'O')
|
||||
new_tr.write(HELLO_MSG)
|
||||
await on_eof
|
||||
|
||||
new_tr.close()
|
||||
|
||||
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
|
||||
self.loop.run_until_complete(
|
||||
asyncio.wait_for(client(srv.addr),
|
||||
timeout=support.SHORT_TIMEOUT))
|
||||
|
||||
# No garbage is left if SSL is closed uncleanly
|
||||
client_context = weakref.ref(client_context)
|
||||
support.gc_collect()
|
||||
self.assertIsNone(client_context())
|
||||
|
||||
def test_create_connection_memory_leak(self):
|
||||
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
|
||||
|
||||
server_context = test_utils.simple_server_sslcontext()
|
||||
client_context = test_utils.simple_client_sslcontext()
|
||||
|
||||
def serve(sock):
|
||||
sock.settimeout(self.TIMEOUT)
|
||||
|
||||
sock.start_tls(server_context, server_side=True)
|
||||
|
||||
sock.sendall(b'O')
|
||||
data = sock.recv_all(len(HELLO_MSG))
|
||||
self.assertEqual(len(data), len(HELLO_MSG))
|
||||
|
||||
sock.shutdown(socket.SHUT_RDWR)
|
||||
sock.close()
|
||||
|
||||
class ClientProto(asyncio.Protocol):
|
||||
def __init__(self, on_data, on_eof):
|
||||
self.on_data = on_data
|
||||
self.on_eof = on_eof
|
||||
self.con_made_cnt = 0
|
||||
|
||||
def connection_made(proto, tr):
|
||||
# XXX: We assume user stores the transport in protocol
|
||||
proto.tr = tr
|
||||
proto.con_made_cnt += 1
|
||||
# Ensure connection_made gets called only once.
|
||||
self.assertEqual(proto.con_made_cnt, 1)
|
||||
|
||||
def data_received(self, data):
|
||||
self.on_data.set_result(data)
|
||||
|
||||
def eof_received(self):
|
||||
self.on_eof.set_result(True)
|
||||
|
||||
async def client(addr):
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
on_data = self.loop.create_future()
|
||||
on_eof = self.loop.create_future()
|
||||
|
||||
tr, proto = await self.loop.create_connection(
|
||||
lambda: ClientProto(on_data, on_eof), *addr,
|
||||
ssl=client_context)
|
||||
|
||||
self.assertEqual(await on_data, b'O')
|
||||
tr.write(HELLO_MSG)
|
||||
await on_eof
|
||||
|
||||
tr.close()
|
||||
|
||||
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
|
||||
self.loop.run_until_complete(
|
||||
asyncio.wait_for(client(srv.addr),
|
||||
timeout=support.SHORT_TIMEOUT))
|
||||
|
||||
# No garbage is left for SSL client from loop.create_connection, even
|
||||
# if user stores the SSLTransport in corresponding protocol instance
|
||||
client_context = weakref.ref(client_context)
|
||||
support.gc_collect()
|
||||
self.assertIsNone(client_context())
|
||||
|
||||
@socket_helper.skip_if_tcp_blackhole
|
||||
def test_start_tls_client_buf_proto_1(self):
|
||||
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
|
||||
|
||||
server_context = test_utils.simple_server_sslcontext()
|
||||
client_context = test_utils.simple_client_sslcontext()
|
||||
client_con_made_calls = 0
|
||||
|
||||
def serve(sock):
|
||||
sock.settimeout(self.TIMEOUT)
|
||||
|
||||
data = sock.recv_all(len(HELLO_MSG))
|
||||
self.assertEqual(len(data), len(HELLO_MSG))
|
||||
|
||||
sock.start_tls(server_context, server_side=True)
|
||||
|
||||
sock.sendall(b'O')
|
||||
data = sock.recv_all(len(HELLO_MSG))
|
||||
self.assertEqual(len(data), len(HELLO_MSG))
|
||||
|
||||
sock.sendall(b'2')
|
||||
data = sock.recv_all(len(HELLO_MSG))
|
||||
self.assertEqual(len(data), len(HELLO_MSG))
|
||||
|
||||
sock.shutdown(socket.SHUT_RDWR)
|
||||
sock.close()
|
||||
|
||||
class ClientProtoFirst(asyncio.BufferedProtocol):
|
||||
def __init__(self, on_data):
|
||||
self.on_data = on_data
|
||||
self.buf = bytearray(1)
|
||||
|
||||
def connection_made(self, tr):
|
||||
nonlocal client_con_made_calls
|
||||
client_con_made_calls += 1
|
||||
|
||||
def get_buffer(self, sizehint):
|
||||
return self.buf
|
||||
|
||||
def buffer_updated(slf, nsize):
|
||||
self.assertEqual(nsize, 1)
|
||||
slf.on_data.set_result(bytes(slf.buf[:nsize]))
|
||||
|
||||
class ClientProtoSecond(asyncio.Protocol):
|
||||
def __init__(self, on_data, on_eof):
|
||||
self.on_data = on_data
|
||||
self.on_eof = on_eof
|
||||
self.con_made_cnt = 0
|
||||
|
||||
def connection_made(self, tr):
|
||||
nonlocal client_con_made_calls
|
||||
client_con_made_calls += 1
|
||||
|
||||
def data_received(self, data):
|
||||
self.on_data.set_result(data)
|
||||
|
||||
def eof_received(self):
|
||||
self.on_eof.set_result(True)
|
||||
|
||||
async def client(addr):
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
on_data1 = self.loop.create_future()
|
||||
on_data2 = self.loop.create_future()
|
||||
on_eof = self.loop.create_future()
|
||||
|
||||
tr, proto = await self.loop.create_connection(
|
||||
lambda: ClientProtoFirst(on_data1), *addr)
|
||||
|
||||
tr.write(HELLO_MSG)
|
||||
new_tr = await self.loop.start_tls(tr, proto, client_context)
|
||||
|
||||
self.assertEqual(await on_data1, b'O')
|
||||
new_tr.write(HELLO_MSG)
|
||||
|
||||
new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
|
||||
self.assertEqual(await on_data2, b'2')
|
||||
new_tr.write(HELLO_MSG)
|
||||
await on_eof
|
||||
|
||||
new_tr.close()
|
||||
|
||||
# connection_made() should be called only once -- when
|
||||
# we establish connection for the first time. Start TLS
|
||||
# doesn't call connection_made() on application protocols.
|
||||
self.assertEqual(client_con_made_calls, 1)
|
||||
|
||||
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
|
||||
self.loop.run_until_complete(
|
||||
asyncio.wait_for(client(srv.addr),
|
||||
timeout=self.TIMEOUT))
|
||||
|
||||
def test_start_tls_slow_client_cancel(self):
|
||||
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
|
||||
|
||||
client_context = test_utils.simple_client_sslcontext()
|
||||
server_waits_on_handshake = self.loop.create_future()
|
||||
|
||||
def serve(sock):
|
||||
sock.settimeout(self.TIMEOUT)
|
||||
|
||||
data = sock.recv_all(len(HELLO_MSG))
|
||||
self.assertEqual(len(data), len(HELLO_MSG))
|
||||
|
||||
try:
|
||||
self.loop.call_soon_threadsafe(
|
||||
server_waits_on_handshake.set_result, None)
|
||||
data = sock.recv_all(1024 * 1024)
|
||||
except ConnectionAbortedError:
|
||||
pass
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
class ClientProto(asyncio.Protocol):
|
||||
def __init__(self, on_data, on_eof):
|
||||
self.on_data = on_data
|
||||
self.on_eof = on_eof
|
||||
self.con_made_cnt = 0
|
||||
|
||||
def connection_made(proto, tr):
|
||||
proto.con_made_cnt += 1
|
||||
# Ensure connection_made gets called only once.
|
||||
self.assertEqual(proto.con_made_cnt, 1)
|
||||
|
||||
def data_received(self, data):
|
||||
self.on_data.set_result(data)
|
||||
|
||||
def eof_received(self):
|
||||
self.on_eof.set_result(True)
|
||||
|
||||
async def client(addr):
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
on_data = self.loop.create_future()
|
||||
on_eof = self.loop.create_future()
|
||||
|
||||
tr, proto = await self.loop.create_connection(
|
||||
lambda: ClientProto(on_data, on_eof), *addr)
|
||||
|
||||
tr.write(HELLO_MSG)
|
||||
|
||||
await server_waits_on_handshake
|
||||
|
||||
with self.assertRaises(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(
|
||||
self.loop.start_tls(tr, proto, client_context),
|
||||
0.5)
|
||||
|
||||
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
|
||||
self.loop.run_until_complete(
|
||||
asyncio.wait_for(client(srv.addr),
|
||||
timeout=support.SHORT_TIMEOUT))
|
||||
|
||||
@socket_helper.skip_if_tcp_blackhole
|
||||
def test_start_tls_server_1(self):
|
||||
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
|
||||
ANSWER = b'answer'
|
||||
|
||||
server_context = test_utils.simple_server_sslcontext()
|
||||
client_context = test_utils.simple_client_sslcontext()
|
||||
answer = None
|
||||
|
||||
def client(sock, addr):
|
||||
nonlocal answer
|
||||
sock.settimeout(self.TIMEOUT)
|
||||
|
||||
sock.connect(addr)
|
||||
data = sock.recv_all(len(HELLO_MSG))
|
||||
self.assertEqual(len(data), len(HELLO_MSG))
|
||||
|
||||
sock.start_tls(client_context)
|
||||
sock.sendall(HELLO_MSG)
|
||||
answer = sock.recv_all(len(ANSWER))
|
||||
sock.close()
|
||||
|
||||
class ServerProto(asyncio.Protocol):
|
||||
def __init__(self, on_con, on_con_lost, on_got_hello):
|
||||
self.on_con = on_con
|
||||
self.on_con_lost = on_con_lost
|
||||
self.on_got_hello = on_got_hello
|
||||
self.data = b''
|
||||
self.transport = None
|
||||
|
||||
def connection_made(self, tr):
|
||||
self.transport = tr
|
||||
self.on_con.set_result(tr)
|
||||
|
||||
def replace_transport(self, tr):
|
||||
self.transport = tr
|
||||
|
||||
def data_received(self, data):
|
||||
self.data += data
|
||||
if len(self.data) >= len(HELLO_MSG):
|
||||
self.on_got_hello.set_result(None)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
self.transport = None
|
||||
if exc is None:
|
||||
self.on_con_lost.set_result(None)
|
||||
else:
|
||||
self.on_con_lost.set_exception(exc)
|
||||
|
||||
async def main(proto, on_con, on_con_lost, on_got_hello):
|
||||
tr = await on_con
|
||||
tr.write(HELLO_MSG)
|
||||
|
||||
self.assertEqual(proto.data, b'')
|
||||
|
||||
new_tr = await self.loop.start_tls(
|
||||
tr, proto, server_context,
|
||||
server_side=True,
|
||||
ssl_handshake_timeout=self.TIMEOUT)
|
||||
proto.replace_transport(new_tr)
|
||||
|
||||
await on_got_hello
|
||||
new_tr.write(ANSWER)
|
||||
|
||||
await on_con_lost
|
||||
self.assertEqual(proto.data, HELLO_MSG)
|
||||
new_tr.close()
|
||||
|
||||
async def run_main():
|
||||
on_con = self.loop.create_future()
|
||||
on_con_lost = self.loop.create_future()
|
||||
on_got_hello = self.loop.create_future()
|
||||
proto = ServerProto(on_con, on_con_lost, on_got_hello)
|
||||
|
||||
server = await self.loop.create_server(
|
||||
lambda: proto, '127.0.0.1', 0)
|
||||
addr = server.sockets[0].getsockname()
|
||||
|
||||
with self.tcp_client(lambda sock: client(sock, addr),
|
||||
timeout=self.TIMEOUT):
|
||||
await asyncio.wait_for(
|
||||
main(proto, on_con, on_con_lost, on_got_hello),
|
||||
timeout=self.TIMEOUT)
|
||||
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
self.assertEqual(answer, ANSWER)
|
||||
|
||||
self.loop.run_until_complete(run_main())
|
||||
|
||||
def test_start_tls_wrong_args(self):
|
||||
async def main():
|
||||
with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
|
||||
await self.loop.start_tls(None, None, None)
|
||||
|
||||
sslctx = test_utils.simple_server_sslcontext()
|
||||
with self.assertRaisesRegex(TypeError, 'is not supported'):
|
||||
await self.loop.start_tls(None, None, sslctx)
|
||||
|
||||
self.loop.run_until_complete(main())
|
||||
|
||||
def test_handshake_timeout(self):
|
||||
# bpo-29970: Check that a connection is aborted if handshake is not
|
||||
# completed in timeout period, instead of remaining open indefinitely
|
||||
client_sslctx = test_utils.simple_client_sslcontext()
|
||||
|
||||
messages = []
|
||||
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
|
||||
|
||||
server_side_aborted = False
|
||||
|
||||
def server(sock):
|
||||
nonlocal server_side_aborted
|
||||
try:
|
||||
sock.recv_all(1024 * 1024)
|
||||
except ConnectionAbortedError:
|
||||
server_side_aborted = True
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
async def client(addr):
|
||||
await asyncio.wait_for(
|
||||
self.loop.create_connection(
|
||||
asyncio.Protocol,
|
||||
*addr,
|
||||
ssl=client_sslctx,
|
||||
server_hostname='',
|
||||
ssl_handshake_timeout=support.SHORT_TIMEOUT),
|
||||
0.5)
|
||||
|
||||
with self.tcp_server(server,
|
||||
max_clients=1,
|
||||
backlog=1) as srv:
|
||||
|
||||
with self.assertRaises(asyncio.TimeoutError):
|
||||
self.loop.run_until_complete(client(srv.addr))
|
||||
|
||||
self.assertTrue(server_side_aborted)
|
||||
|
||||
# Python issue #23197: cancelling a handshake must not raise an
|
||||
# exception or log an error, even if the handshake failed
|
||||
self.assertEqual(messages, [])
|
||||
|
||||
# The 10s handshake timeout should be cancelled to free related
|
||||
# objects without really waiting for 10s
|
||||
client_sslctx = weakref.ref(client_sslctx)
|
||||
support.gc_collect()
|
||||
self.assertIsNone(client_sslctx())
|
||||
|
||||
def test_create_connection_ssl_slow_handshake(self):
|
||||
client_sslctx = test_utils.simple_client_sslcontext()
|
||||
|
||||
messages = []
|
||||
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
|
||||
|
||||
def server(sock):
|
||||
try:
|
||||
sock.recv_all(1024 * 1024)
|
||||
except ConnectionAbortedError:
|
||||
pass
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
async def client(addr):
|
||||
reader, writer = await asyncio.open_connection(
|
||||
*addr,
|
||||
ssl=client_sslctx,
|
||||
server_hostname='',
|
||||
ssl_handshake_timeout=1.0)
|
||||
|
||||
with self.tcp_server(server,
|
||||
max_clients=1,
|
||||
backlog=1) as srv:
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ConnectionAbortedError,
|
||||
r'SSL handshake.*is taking longer'):
|
||||
|
||||
self.loop.run_until_complete(client(srv.addr))
|
||||
|
||||
self.assertEqual(messages, [])
|
||||
|
||||
def test_create_connection_ssl_failed_certificate(self):
|
||||
self.loop.set_exception_handler(lambda loop, ctx: None)
|
||||
|
||||
sslctx = test_utils.simple_server_sslcontext()
|
||||
client_sslctx = test_utils.simple_client_sslcontext(
|
||||
disable_verify=False)
|
||||
|
||||
def server(sock):
|
||||
try:
|
||||
sock.start_tls(
|
||||
sslctx,
|
||||
server_side=True)
|
||||
except ssl.SSLError:
|
||||
pass
|
||||
except OSError:
|
||||
pass
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
async def client(addr):
|
||||
reader, writer = await asyncio.open_connection(
|
||||
*addr,
|
||||
ssl=client_sslctx,
|
||||
server_hostname='',
|
||||
ssl_handshake_timeout=support.LOOPBACK_TIMEOUT)
|
||||
|
||||
with self.tcp_server(server,
|
||||
max_clients=1,
|
||||
backlog=1) as srv:
|
||||
|
||||
with self.assertRaises(ssl.SSLCertVerificationError):
|
||||
self.loop.run_until_complete(client(srv.addr))
|
||||
|
||||
def test_start_tls_client_corrupted_ssl(self):
|
||||
self.loop.set_exception_handler(lambda loop, ctx: None)
|
||||
|
||||
sslctx = test_utils.simple_server_sslcontext()
|
||||
client_sslctx = test_utils.simple_client_sslcontext()
|
||||
|
||||
def server(sock):
|
||||
orig_sock = sock.dup()
|
||||
try:
|
||||
sock.start_tls(
|
||||
sslctx,
|
||||
server_side=True)
|
||||
sock.sendall(b'A\n')
|
||||
sock.recv_all(1)
|
||||
orig_sock.send(b'please corrupt the SSL connection')
|
||||
except ssl.SSLError:
|
||||
pass
|
||||
finally:
|
||||
orig_sock.close()
|
||||
sock.close()
|
||||
|
||||
async def client(addr):
|
||||
reader, writer = await asyncio.open_connection(
|
||||
*addr,
|
||||
ssl=client_sslctx,
|
||||
server_hostname='')
|
||||
|
||||
self.assertEqual(await reader.readline(), b'A\n')
|
||||
writer.write(b'B')
|
||||
with self.assertRaises(ssl.SSLError):
|
||||
await reader.readline()
|
||||
|
||||
writer.close()
|
||||
return 'OK'
|
||||
|
||||
with self.tcp_server(server,
|
||||
max_clients=1,
|
||||
backlog=1) as srv:
|
||||
|
||||
res = self.loop.run_until_complete(client(srv.addr))
|
||||
|
||||
self.assertEqual(res, 'OK')
|
||||
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
|
||||
|
||||
def new_loop(self):
|
||||
return asyncio.SelectorEventLoop()
|
||||
|
||||
|
||||
@unittest.skipIf(ssl is None, 'No ssl module')
|
||||
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
|
||||
class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase):
|
||||
|
||||
def new_loop(self):
|
||||
return asyncio.ProactorEventLoop()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
151
Utils/PythonNew32/Lib/test/test_asyncio/test_staggered.py
Normal file
151
Utils/PythonNew32/Lib/test/test_asyncio/test_staggered.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from asyncio.staggered import staggered_race
|
||||
|
||||
from test import support
|
||||
|
||||
support.requires_working_socket(module=True)
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
class StaggeredTests(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_empty(self):
|
||||
winner, index, excs = await staggered_race(
|
||||
[],
|
||||
delay=None,
|
||||
)
|
||||
|
||||
self.assertIs(winner, None)
|
||||
self.assertIs(index, None)
|
||||
self.assertEqual(excs, [])
|
||||
|
||||
async def test_one_successful(self):
|
||||
async def coro(index):
|
||||
return f'Res: {index}'
|
||||
|
||||
winner, index, excs = await staggered_race(
|
||||
[
|
||||
lambda: coro(0),
|
||||
lambda: coro(1),
|
||||
],
|
||||
delay=None,
|
||||
)
|
||||
|
||||
self.assertEqual(winner, 'Res: 0')
|
||||
self.assertEqual(index, 0)
|
||||
self.assertEqual(excs, [None])
|
||||
|
||||
async def test_first_error_second_successful(self):
|
||||
async def coro(index):
|
||||
if index == 0:
|
||||
raise ValueError(index)
|
||||
return f'Res: {index}'
|
||||
|
||||
winner, index, excs = await staggered_race(
|
||||
[
|
||||
lambda: coro(0),
|
||||
lambda: coro(1),
|
||||
],
|
||||
delay=None,
|
||||
)
|
||||
|
||||
self.assertEqual(winner, 'Res: 1')
|
||||
self.assertEqual(index, 1)
|
||||
self.assertEqual(len(excs), 2)
|
||||
self.assertIsInstance(excs[0], ValueError)
|
||||
self.assertIs(excs[1], None)
|
||||
|
||||
async def test_first_timeout_second_successful(self):
|
||||
async def coro(index):
|
||||
if index == 0:
|
||||
await asyncio.sleep(10) # much bigger than delay
|
||||
return f'Res: {index}'
|
||||
|
||||
winner, index, excs = await staggered_race(
|
||||
[
|
||||
lambda: coro(0),
|
||||
lambda: coro(1),
|
||||
],
|
||||
delay=0.1,
|
||||
)
|
||||
|
||||
self.assertEqual(winner, 'Res: 1')
|
||||
self.assertEqual(index, 1)
|
||||
self.assertEqual(len(excs), 2)
|
||||
self.assertIsInstance(excs[0], asyncio.CancelledError)
|
||||
self.assertIs(excs[1], None)
|
||||
|
||||
async def test_none_successful(self):
|
||||
async def coro(index):
|
||||
raise ValueError(index)
|
||||
|
||||
winner, index, excs = await staggered_race(
|
||||
[
|
||||
lambda: coro(0),
|
||||
lambda: coro(1),
|
||||
],
|
||||
delay=None,
|
||||
)
|
||||
|
||||
self.assertIs(winner, None)
|
||||
self.assertIs(index, None)
|
||||
self.assertEqual(len(excs), 2)
|
||||
self.assertIsInstance(excs[0], ValueError)
|
||||
self.assertIsInstance(excs[1], ValueError)
|
||||
|
||||
|
||||
async def test_multiple_winners(self):
|
||||
event = asyncio.Event()
|
||||
|
||||
async def coro(index):
|
||||
await event.wait()
|
||||
return index
|
||||
|
||||
async def do_set():
|
||||
event.set()
|
||||
await asyncio.Event().wait()
|
||||
|
||||
winner, index, excs = await staggered_race(
|
||||
[
|
||||
lambda: coro(0),
|
||||
lambda: coro(1),
|
||||
do_set,
|
||||
],
|
||||
delay=0.1,
|
||||
)
|
||||
self.assertIs(winner, 0)
|
||||
self.assertIs(index, 0)
|
||||
self.assertEqual(len(excs), 3)
|
||||
self.assertIsNone(excs[0], None)
|
||||
self.assertIsInstance(excs[1], asyncio.CancelledError)
|
||||
self.assertIsInstance(excs[2], asyncio.CancelledError)
|
||||
|
||||
|
||||
async def test_cancelled(self):
|
||||
log = []
|
||||
with self.assertRaises(TimeoutError):
|
||||
async with asyncio.timeout(None) as cs_outer, asyncio.timeout(None) as cs_inner:
|
||||
async def coro_fn():
|
||||
cs_inner.reschedule(-1)
|
||||
await asyncio.sleep(0)
|
||||
try:
|
||||
await asyncio.sleep(0)
|
||||
except asyncio.CancelledError:
|
||||
log.append("cancelled 1")
|
||||
|
||||
cs_outer.reschedule(-1)
|
||||
await asyncio.sleep(0)
|
||||
try:
|
||||
await asyncio.sleep(0)
|
||||
except asyncio.CancelledError:
|
||||
log.append("cancelled 2")
|
||||
try:
|
||||
await staggered_race([coro_fn], delay=None)
|
||||
except asyncio.CancelledError:
|
||||
log.append("cancelled 3")
|
||||
raise
|
||||
|
||||
self.assertListEqual(log, ["cancelled 1", "cancelled 2", "cancelled 3"])
|
||||
1269
Utils/PythonNew32/Lib/test/test_asyncio/test_streams.py
Normal file
1269
Utils/PythonNew32/Lib/test/test_asyncio/test_streams.py
Normal file
File diff suppressed because it is too large
Load Diff
1028
Utils/PythonNew32/Lib/test/test_asyncio/test_subprocess.py
Normal file
1028
Utils/PythonNew32/Lib/test/test_asyncio/test_subprocess.py
Normal file
File diff suppressed because it is too large
Load Diff
1097
Utils/PythonNew32/Lib/test/test_asyncio/test_taskgroups.py
Normal file
1097
Utils/PythonNew32/Lib/test/test_asyncio/test_taskgroups.py
Normal file
File diff suppressed because it is too large
Load Diff
3601
Utils/PythonNew32/Lib/test/test_asyncio/test_tasks.py
Normal file
3601
Utils/PythonNew32/Lib/test/test_asyncio/test_tasks.py
Normal file
File diff suppressed because it is too large
Load Diff
66
Utils/PythonNew32/Lib/test/test_asyncio/test_threads.py
Normal file
66
Utils/PythonNew32/Lib/test/test_asyncio/test_threads.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Tests for asyncio/threads.py"""
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from contextvars import ContextVar
|
||||
from unittest import mock
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
class ToThreadTests(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_to_thread(self):
|
||||
result = await asyncio.to_thread(sum, [40, 2])
|
||||
self.assertEqual(result, 42)
|
||||
|
||||
async def test_to_thread_exception(self):
|
||||
def raise_runtime():
|
||||
raise RuntimeError("test")
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "test"):
|
||||
await asyncio.to_thread(raise_runtime)
|
||||
|
||||
async def test_to_thread_once(self):
|
||||
func = mock.Mock()
|
||||
|
||||
await asyncio.to_thread(func)
|
||||
func.assert_called_once()
|
||||
|
||||
async def test_to_thread_concurrent(self):
|
||||
calls = []
|
||||
def func():
|
||||
calls.append(1)
|
||||
|
||||
futs = []
|
||||
for _ in range(10):
|
||||
fut = asyncio.to_thread(func)
|
||||
futs.append(fut)
|
||||
await asyncio.gather(*futs)
|
||||
|
||||
self.assertEqual(sum(calls), 10)
|
||||
|
||||
async def test_to_thread_args_kwargs(self):
|
||||
# Unlike run_in_executor(), to_thread() should directly accept kwargs.
|
||||
func = mock.Mock()
|
||||
|
||||
await asyncio.to_thread(func, 'test', something=True)
|
||||
|
||||
func.assert_called_once_with('test', something=True)
|
||||
|
||||
async def test_to_thread_contextvars(self):
|
||||
test_ctx = ContextVar('test_ctx')
|
||||
|
||||
def get_ctx():
|
||||
return test_ctx.get()
|
||||
|
||||
test_ctx.set('parrot')
|
||||
result = await asyncio.to_thread(get_ctx)
|
||||
|
||||
self.assertEqual(result, 'parrot')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
411
Utils/PythonNew32/Lib/test/test_asyncio/test_timeouts.py
Normal file
411
Utils/PythonNew32/Lib/test/test_asyncio/test_timeouts.py
Normal file
@@ -0,0 +1,411 @@
|
||||
"""Tests for asyncio/timeouts.py"""
|
||||
|
||||
import unittest
|
||||
import time
|
||||
|
||||
import asyncio
|
||||
|
||||
from test.test_asyncio.utils import await_without_task
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
class TimeoutTests(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_timeout_basic(self):
|
||||
with self.assertRaises(TimeoutError):
|
||||
async with asyncio.timeout(0.01) as cm:
|
||||
await asyncio.sleep(10)
|
||||
self.assertTrue(cm.expired())
|
||||
|
||||
async def test_timeout_at_basic(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
with self.assertRaises(TimeoutError):
|
||||
deadline = loop.time() + 0.01
|
||||
async with asyncio.timeout_at(deadline) as cm:
|
||||
await asyncio.sleep(10)
|
||||
self.assertTrue(cm.expired())
|
||||
self.assertEqual(deadline, cm.when())
|
||||
|
||||
async def test_nested_timeouts(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
cancelled = False
|
||||
with self.assertRaises(TimeoutError):
|
||||
deadline = loop.time() + 0.01
|
||||
async with asyncio.timeout_at(deadline) as cm1:
|
||||
# Only the topmost context manager should raise TimeoutError
|
||||
try:
|
||||
async with asyncio.timeout_at(deadline) as cm2:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
cancelled = True
|
||||
raise
|
||||
self.assertTrue(cancelled)
|
||||
self.assertTrue(cm1.expired())
|
||||
self.assertTrue(cm2.expired())
|
||||
|
||||
async def test_waiter_cancelled(self):
|
||||
cancelled = False
|
||||
with self.assertRaises(TimeoutError):
|
||||
async with asyncio.timeout(0.01):
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
cancelled = True
|
||||
raise
|
||||
self.assertTrue(cancelled)
|
||||
|
||||
async def test_timeout_not_called(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
async with asyncio.timeout(10) as cm:
|
||||
await asyncio.sleep(0.01)
|
||||
t1 = loop.time()
|
||||
|
||||
self.assertFalse(cm.expired())
|
||||
self.assertGreater(cm.when(), t1)
|
||||
|
||||
async def test_timeout_disabled(self):
|
||||
async with asyncio.timeout(None) as cm:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
self.assertFalse(cm.expired())
|
||||
self.assertIsNone(cm.when())
|
||||
|
||||
async def test_timeout_at_disabled(self):
|
||||
async with asyncio.timeout_at(None) as cm:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
self.assertFalse(cm.expired())
|
||||
self.assertIsNone(cm.when())
|
||||
|
||||
async def test_timeout_zero(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
t0 = loop.time()
|
||||
with self.assertRaises(TimeoutError):
|
||||
async with asyncio.timeout(0) as cm:
|
||||
await asyncio.sleep(10)
|
||||
t1 = loop.time()
|
||||
self.assertTrue(cm.expired())
|
||||
self.assertTrue(t0 <= cm.when() <= t1)
|
||||
|
||||
async def test_timeout_zero_sleep_zero(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
t0 = loop.time()
|
||||
with self.assertRaises(TimeoutError):
|
||||
async with asyncio.timeout(0) as cm:
|
||||
await asyncio.sleep(0)
|
||||
t1 = loop.time()
|
||||
self.assertTrue(cm.expired())
|
||||
self.assertTrue(t0 <= cm.when() <= t1)
|
||||
|
||||
async def test_timeout_in_the_past_sleep_zero(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
t0 = loop.time()
|
||||
with self.assertRaises(TimeoutError):
|
||||
async with asyncio.timeout(-11) as cm:
|
||||
await asyncio.sleep(0)
|
||||
t1 = loop.time()
|
||||
self.assertTrue(cm.expired())
|
||||
self.assertTrue(t0 >= cm.when() <= t1)
|
||||
|
||||
async def test_foreign_exception_passed(self):
|
||||
with self.assertRaises(KeyError):
|
||||
async with asyncio.timeout(0.01) as cm:
|
||||
raise KeyError
|
||||
self.assertFalse(cm.expired())
|
||||
|
||||
async def test_timeout_exception_context(self):
|
||||
with self.assertRaises(TimeoutError) as cm:
|
||||
async with asyncio.timeout(0.01):
|
||||
try:
|
||||
1/0
|
||||
finally:
|
||||
await asyncio.sleep(1)
|
||||
e = cm.exception
|
||||
# Expect TimeoutError caused by CancelledError raised during handling
|
||||
# of ZeroDivisionError.
|
||||
e2 = e.__cause__
|
||||
self.assertIsInstance(e2, asyncio.CancelledError)
|
||||
self.assertIs(e.__context__, e2)
|
||||
self.assertIsNone(e2.__cause__)
|
||||
self.assertIsInstance(e2.__context__, ZeroDivisionError)
|
||||
|
||||
async def test_foreign_exception_on_timeout(self):
|
||||
async def crash():
|
||||
try:
|
||||
await asyncio.sleep(1)
|
||||
finally:
|
||||
1/0
|
||||
with self.assertRaises(ZeroDivisionError) as cm:
|
||||
async with asyncio.timeout(0.01):
|
||||
await crash()
|
||||
e = cm.exception
|
||||
# Expect ZeroDivisionError raised during handling of TimeoutError
|
||||
# caused by CancelledError.
|
||||
self.assertIsNone(e.__cause__)
|
||||
e2 = e.__context__
|
||||
self.assertIsInstance(e2, TimeoutError)
|
||||
e3 = e2.__cause__
|
||||
self.assertIsInstance(e3, asyncio.CancelledError)
|
||||
self.assertIs(e2.__context__, e3)
|
||||
|
||||
async def test_foreign_exception_on_timeout_2(self):
|
||||
with self.assertRaises(ZeroDivisionError) as cm:
|
||||
async with asyncio.timeout(0.01):
|
||||
try:
|
||||
try:
|
||||
raise ValueError
|
||||
finally:
|
||||
await asyncio.sleep(1)
|
||||
finally:
|
||||
try:
|
||||
raise KeyError
|
||||
finally:
|
||||
1/0
|
||||
e = cm.exception
|
||||
# Expect ZeroDivisionError raised during handling of KeyError
|
||||
# raised during handling of TimeoutError caused by CancelledError.
|
||||
self.assertIsNone(e.__cause__)
|
||||
e2 = e.__context__
|
||||
self.assertIsInstance(e2, KeyError)
|
||||
self.assertIsNone(e2.__cause__)
|
||||
e3 = e2.__context__
|
||||
self.assertIsInstance(e3, TimeoutError)
|
||||
e4 = e3.__cause__
|
||||
self.assertIsInstance(e4, asyncio.CancelledError)
|
||||
self.assertIsNone(e4.__cause__)
|
||||
self.assertIsInstance(e4.__context__, ValueError)
|
||||
self.assertIs(e3.__context__, e4)
|
||||
|
||||
async def test_foreign_cancel_doesnt_timeout_if_not_expired(self):
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
async with asyncio.timeout(10) as cm:
|
||||
asyncio.current_task().cancel()
|
||||
await asyncio.sleep(10)
|
||||
self.assertFalse(cm.expired())
|
||||
|
||||
async def test_outer_task_is_not_cancelled(self):
|
||||
async def outer() -> None:
|
||||
with self.assertRaises(TimeoutError):
|
||||
async with asyncio.timeout(0.001):
|
||||
await asyncio.sleep(10)
|
||||
|
||||
task = asyncio.create_task(outer())
|
||||
await task
|
||||
self.assertFalse(task.cancelled())
|
||||
self.assertTrue(task.done())
|
||||
|
||||
async def test_nested_timeouts_concurrent(self):
|
||||
with self.assertRaises(TimeoutError):
|
||||
async with asyncio.timeout(0.002):
|
||||
with self.assertRaises(TimeoutError):
|
||||
async with asyncio.timeout(0.1):
|
||||
# Pretend we crunch some numbers.
|
||||
time.sleep(0.01)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def test_nested_timeouts_loop_busy(self):
|
||||
# After the inner timeout is an expensive operation which should
|
||||
# be stopped by the outer timeout.
|
||||
loop = asyncio.get_running_loop()
|
||||
# Disable a message about long running task
|
||||
loop.slow_callback_duration = 10
|
||||
t0 = loop.time()
|
||||
with self.assertRaises(TimeoutError):
|
||||
async with asyncio.timeout(0.1): # (1)
|
||||
with self.assertRaises(TimeoutError):
|
||||
async with asyncio.timeout(0.01): # (2)
|
||||
# Pretend the loop is busy for a while.
|
||||
time.sleep(0.1)
|
||||
await asyncio.sleep(1)
|
||||
# TimeoutError was cought by (2)
|
||||
await asyncio.sleep(10) # This sleep should be interrupted by (1)
|
||||
t1 = loop.time()
|
||||
self.assertTrue(t0 <= t1 <= t0 + 1)
|
||||
|
||||
async def test_reschedule(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
fut = loop.create_future()
|
||||
deadline1 = loop.time() + 10
|
||||
deadline2 = deadline1 + 20
|
||||
|
||||
async def f():
|
||||
async with asyncio.timeout_at(deadline1) as cm:
|
||||
fut.set_result(cm)
|
||||
await asyncio.sleep(50)
|
||||
|
||||
task = asyncio.create_task(f())
|
||||
cm = await fut
|
||||
|
||||
self.assertEqual(cm.when(), deadline1)
|
||||
cm.reschedule(deadline2)
|
||||
self.assertEqual(cm.when(), deadline2)
|
||||
cm.reschedule(None)
|
||||
self.assertIsNone(cm.when())
|
||||
|
||||
task.cancel()
|
||||
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await task
|
||||
self.assertFalse(cm.expired())
|
||||
|
||||
async def test_repr_active(self):
|
||||
async with asyncio.timeout(10) as cm:
|
||||
self.assertRegex(repr(cm), r"<Timeout \[active\] when=\d+\.\d*>")
|
||||
|
||||
async def test_repr_expired(self):
|
||||
with self.assertRaises(TimeoutError):
|
||||
async with asyncio.timeout(0.01) as cm:
|
||||
await asyncio.sleep(10)
|
||||
self.assertEqual(repr(cm), "<Timeout [expired]>")
|
||||
|
||||
async def test_repr_finished(self):
|
||||
async with asyncio.timeout(10) as cm:
|
||||
await asyncio.sleep(0)
|
||||
|
||||
self.assertEqual(repr(cm), "<Timeout [finished]>")
|
||||
|
||||
async def test_repr_disabled(self):
|
||||
async with asyncio.timeout(None) as cm:
|
||||
self.assertEqual(repr(cm), r"<Timeout [active] when=None>")
|
||||
|
||||
async def test_nested_timeout_in_finally(self):
|
||||
with self.assertRaises(TimeoutError) as cm1:
|
||||
async with asyncio.timeout(0.01):
|
||||
try:
|
||||
await asyncio.sleep(1)
|
||||
finally:
|
||||
with self.assertRaises(TimeoutError) as cm2:
|
||||
async with asyncio.timeout(0.01):
|
||||
await asyncio.sleep(10)
|
||||
e1 = cm1.exception
|
||||
# Expect TimeoutError caused by CancelledError.
|
||||
e12 = e1.__cause__
|
||||
self.assertIsInstance(e12, asyncio.CancelledError)
|
||||
self.assertIsNone(e12.__cause__)
|
||||
self.assertIsNone(e12.__context__)
|
||||
self.assertIs(e1.__context__, e12)
|
||||
e2 = cm2.exception
|
||||
# Expect TimeoutError caused by CancelledError raised during
|
||||
# handling of other CancelledError (which is the same as in
|
||||
# the above chain).
|
||||
e22 = e2.__cause__
|
||||
self.assertIsInstance(e22, asyncio.CancelledError)
|
||||
self.assertIsNone(e22.__cause__)
|
||||
self.assertIs(e22.__context__, e12)
|
||||
self.assertIs(e2.__context__, e22)
|
||||
|
||||
async def test_timeout_after_cancellation(self):
|
||||
try:
|
||||
asyncio.current_task().cancel()
|
||||
await asyncio.sleep(1) # work which will be cancelled
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
with self.assertRaises(TimeoutError) as cm:
|
||||
async with asyncio.timeout(0.0):
|
||||
await asyncio.sleep(1) # some cleanup
|
||||
|
||||
async def test_cancel_in_timeout_after_cancellation(self):
|
||||
try:
|
||||
asyncio.current_task().cancel()
|
||||
await asyncio.sleep(1) # work which will be cancelled
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
async with asyncio.timeout(1.0):
|
||||
asyncio.current_task().cancel()
|
||||
await asyncio.sleep(2) # some cleanup
|
||||
|
||||
async def test_timeout_already_entered(self):
|
||||
async with asyncio.timeout(0.01) as cm:
|
||||
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
|
||||
async with cm:
|
||||
pass
|
||||
|
||||
async def test_timeout_double_enter(self):
|
||||
async with asyncio.timeout(0.01) as cm:
|
||||
pass
|
||||
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
|
||||
async with cm:
|
||||
pass
|
||||
|
||||
async def test_timeout_finished(self):
|
||||
async with asyncio.timeout(0.01) as cm:
|
||||
pass
|
||||
with self.assertRaisesRegex(RuntimeError, "finished"):
|
||||
cm.reschedule(0.02)
|
||||
|
||||
async def test_timeout_expired(self):
|
||||
with self.assertRaises(TimeoutError):
|
||||
async with asyncio.timeout(0.01) as cm:
|
||||
await asyncio.sleep(1)
|
||||
with self.assertRaisesRegex(RuntimeError, "expired"):
|
||||
cm.reschedule(0.02)
|
||||
|
||||
async def test_timeout_expiring(self):
|
||||
async with asyncio.timeout(0.01) as cm:
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await asyncio.sleep(1)
|
||||
with self.assertRaisesRegex(RuntimeError, "expiring"):
|
||||
cm.reschedule(0.02)
|
||||
|
||||
async def test_timeout_not_entered(self):
|
||||
cm = asyncio.timeout(0.01)
|
||||
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
|
||||
cm.reschedule(0.02)
|
||||
|
||||
async def test_timeout_without_task(self):
|
||||
cm = asyncio.timeout(0.01)
|
||||
with self.assertRaisesRegex(RuntimeError, "task"):
|
||||
await await_without_task(cm.__aenter__())
|
||||
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
|
||||
cm.reschedule(0.02)
|
||||
|
||||
async def test_timeout_taskgroup(self):
|
||||
async def task():
|
||||
try:
|
||||
await asyncio.sleep(2) # Will be interrupted after 0.01 second
|
||||
finally:
|
||||
1/0 # Crash in cleanup
|
||||
|
||||
with self.assertRaises(ExceptionGroup) as cm:
|
||||
async with asyncio.timeout(0.01):
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
tg.create_task(task())
|
||||
try:
|
||||
raise ValueError
|
||||
finally:
|
||||
await asyncio.sleep(1)
|
||||
eg = cm.exception
|
||||
# Expect ExceptionGroup raised during handling of TimeoutError caused
|
||||
# by CancelledError raised during handling of ValueError.
|
||||
self.assertIsNone(eg.__cause__)
|
||||
e_1 = eg.__context__
|
||||
self.assertIsInstance(e_1, TimeoutError)
|
||||
e_2 = e_1.__cause__
|
||||
self.assertIsInstance(e_2, asyncio.CancelledError)
|
||||
self.assertIsNone(e_2.__cause__)
|
||||
self.assertIsInstance(e_2.__context__, ValueError)
|
||||
self.assertIs(e_1.__context__, e_2)
|
||||
|
||||
self.assertEqual(len(eg.exceptions), 1, eg)
|
||||
e1 = eg.exceptions[0]
|
||||
# Expect ZeroDivisionError raised during handling of TimeoutError
|
||||
# caused by CancelledError (it is a different CancelledError).
|
||||
self.assertIsInstance(e1, ZeroDivisionError)
|
||||
self.assertIsNone(e1.__cause__)
|
||||
e2 = e1.__context__
|
||||
self.assertIsInstance(e2, TimeoutError)
|
||||
e3 = e2.__cause__
|
||||
self.assertIsInstance(e3, asyncio.CancelledError)
|
||||
self.assertIsNone(e3.__context__)
|
||||
self.assertIsNone(e3.__cause__)
|
||||
self.assertIs(e2.__context__, e3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
103
Utils/PythonNew32/Lib/test/test_asyncio/test_transports.py
Normal file
103
Utils/PythonNew32/Lib/test/test_asyncio/test_transports.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Tests for transports.py."""
|
||||
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
import asyncio
|
||||
from asyncio import transports
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
# not needed for the test file but added for uniformness with all other
|
||||
# asyncio test files for the sake of unified cleanup
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
class TransportTests(unittest.TestCase):
|
||||
|
||||
def test_ctor_extra_is_none(self):
|
||||
transport = asyncio.Transport()
|
||||
self.assertEqual(transport._extra, {})
|
||||
|
||||
def test_get_extra_info(self):
|
||||
transport = asyncio.Transport({'extra': 'info'})
|
||||
self.assertEqual('info', transport.get_extra_info('extra'))
|
||||
self.assertIsNone(transport.get_extra_info('unknown'))
|
||||
|
||||
default = object()
|
||||
self.assertIs(default, transport.get_extra_info('unknown', default))
|
||||
|
||||
def test_writelines(self):
|
||||
writer = mock.Mock()
|
||||
|
||||
class MyTransport(asyncio.Transport):
|
||||
def write(self, data):
|
||||
writer(data)
|
||||
|
||||
transport = MyTransport()
|
||||
|
||||
transport.writelines([b'line1',
|
||||
bytearray(b'line2'),
|
||||
memoryview(b'line3')])
|
||||
self.assertEqual(1, writer.call_count)
|
||||
writer.assert_called_with(b'line1line2line3')
|
||||
|
||||
def test_not_implemented(self):
|
||||
transport = asyncio.Transport()
|
||||
|
||||
self.assertRaises(NotImplementedError,
|
||||
transport.set_write_buffer_limits)
|
||||
self.assertRaises(NotImplementedError, transport.get_write_buffer_size)
|
||||
self.assertRaises(NotImplementedError, transport.write, 'data')
|
||||
self.assertRaises(NotImplementedError, transport.write_eof)
|
||||
self.assertRaises(NotImplementedError, transport.can_write_eof)
|
||||
self.assertRaises(NotImplementedError, transport.pause_reading)
|
||||
self.assertRaises(NotImplementedError, transport.resume_reading)
|
||||
self.assertRaises(NotImplementedError, transport.is_reading)
|
||||
self.assertRaises(NotImplementedError, transport.close)
|
||||
self.assertRaises(NotImplementedError, transport.abort)
|
||||
|
||||
def test_dgram_not_implemented(self):
|
||||
transport = asyncio.DatagramTransport()
|
||||
|
||||
self.assertRaises(NotImplementedError, transport.sendto, 'data')
|
||||
self.assertRaises(NotImplementedError, transport.abort)
|
||||
|
||||
def test_subprocess_transport_not_implemented(self):
|
||||
transport = asyncio.SubprocessTransport()
|
||||
|
||||
self.assertRaises(NotImplementedError, transport.get_pid)
|
||||
self.assertRaises(NotImplementedError, transport.get_returncode)
|
||||
self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1)
|
||||
self.assertRaises(NotImplementedError, transport.send_signal, 1)
|
||||
self.assertRaises(NotImplementedError, transport.terminate)
|
||||
self.assertRaises(NotImplementedError, transport.kill)
|
||||
|
||||
def test_flowcontrol_mixin_set_write_limits(self):
|
||||
|
||||
class MyTransport(transports._FlowControlMixin,
|
||||
transports.Transport):
|
||||
|
||||
def get_write_buffer_size(self):
|
||||
return 512
|
||||
|
||||
loop = mock.Mock()
|
||||
transport = MyTransport(loop=loop)
|
||||
transport._protocol = mock.Mock()
|
||||
|
||||
self.assertFalse(transport._protocol_paused)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, 'high.*must be >= low'):
|
||||
transport.set_write_buffer_limits(high=0, low=1)
|
||||
|
||||
transport.set_write_buffer_limits(high=1024, low=128)
|
||||
self.assertFalse(transport._protocol_paused)
|
||||
self.assertEqual(transport.get_write_buffer_limits(), (128, 1024))
|
||||
|
||||
transport.set_write_buffer_limits(high=256, low=128)
|
||||
self.assertTrue(transport._protocol_paused)
|
||||
self.assertEqual(transport.get_write_buffer_limits(), (128, 256))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
1993
Utils/PythonNew32/Lib/test/test_asyncio/test_unix_events.py
Normal file
1993
Utils/PythonNew32/Lib/test/test_asyncio/test_unix_events.py
Normal file
File diff suppressed because it is too large
Load Diff
353
Utils/PythonNew32/Lib/test/test_asyncio/test_waitfor.py
Normal file
353
Utils/PythonNew32/Lib/test/test_asyncio/test_waitfor.py
Normal file
@@ -0,0 +1,353 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
import time
|
||||
from test import support
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
# The following value can be used as a very small timeout:
|
||||
# it passes check "timeout > 0", but has almost
|
||||
# no effect on the test performance
|
||||
_EPSILON = 0.0001
|
||||
|
||||
|
||||
class SlowTask:
|
||||
""" Task will run for this defined time, ignoring cancel requests """
|
||||
TASK_TIMEOUT = 0.2
|
||||
|
||||
def __init__(self):
|
||||
self.exited = False
|
||||
|
||||
async def run(self):
|
||||
exitat = time.monotonic() + self.TASK_TIMEOUT
|
||||
|
||||
while True:
|
||||
tosleep = exitat - time.monotonic()
|
||||
if tosleep <= 0:
|
||||
break
|
||||
|
||||
try:
|
||||
await asyncio.sleep(tosleep)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self.exited = True
|
||||
|
||||
|
||||
class AsyncioWaitForTest(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_asyncio_wait_for_cancelled(self):
|
||||
t = SlowTask()
|
||||
|
||||
waitfortask = asyncio.create_task(
|
||||
asyncio.wait_for(t.run(), t.TASK_TIMEOUT * 2))
|
||||
await asyncio.sleep(0)
|
||||
waitfortask.cancel()
|
||||
await asyncio.wait({waitfortask})
|
||||
|
||||
self.assertTrue(t.exited)
|
||||
|
||||
async def test_asyncio_wait_for_timeout(self):
|
||||
t = SlowTask()
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(t.run(), t.TASK_TIMEOUT / 2)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
self.assertTrue(t.exited)
|
||||
|
||||
async def test_wait_for_timeout_less_then_0_or_0_future_done(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
fut = loop.create_future()
|
||||
fut.set_result('done')
|
||||
|
||||
ret = await asyncio.wait_for(fut, 0)
|
||||
|
||||
self.assertEqual(ret, 'done')
|
||||
self.assertTrue(fut.done())
|
||||
|
||||
async def test_wait_for_timeout_less_then_0_or_0_coroutine_do_not_started(self):
|
||||
foo_started = False
|
||||
|
||||
async def foo():
|
||||
nonlocal foo_started
|
||||
foo_started = True
|
||||
|
||||
with self.assertRaises(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(foo(), 0)
|
||||
|
||||
self.assertEqual(foo_started, False)
|
||||
|
||||
async def test_wait_for_timeout_less_then_0_or_0(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
for timeout in [0, -1]:
|
||||
with self.subTest(timeout=timeout):
|
||||
foo_running = None
|
||||
started = loop.create_future()
|
||||
|
||||
async def foo():
|
||||
nonlocal foo_running
|
||||
foo_running = True
|
||||
started.set_result(None)
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
finally:
|
||||
foo_running = False
|
||||
return 'done'
|
||||
|
||||
fut = asyncio.create_task(foo())
|
||||
await started
|
||||
|
||||
with self.assertRaises(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(fut, timeout)
|
||||
|
||||
self.assertTrue(fut.done())
|
||||
# it should have been cancelled due to the timeout
|
||||
self.assertTrue(fut.cancelled())
|
||||
self.assertEqual(foo_running, False)
|
||||
|
||||
async def test_wait_for(self):
|
||||
foo_running = None
|
||||
|
||||
async def foo():
|
||||
nonlocal foo_running
|
||||
foo_running = True
|
||||
try:
|
||||
await asyncio.sleep(support.LONG_TIMEOUT)
|
||||
finally:
|
||||
foo_running = False
|
||||
return 'done'
|
||||
|
||||
fut = asyncio.create_task(foo())
|
||||
|
||||
with self.assertRaises(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(fut, 0.1)
|
||||
self.assertTrue(fut.done())
|
||||
# it should have been cancelled due to the timeout
|
||||
self.assertTrue(fut.cancelled())
|
||||
self.assertEqual(foo_running, False)
|
||||
|
||||
async def test_wait_for_blocking(self):
|
||||
async def coro():
|
||||
return 'done'
|
||||
|
||||
res = await asyncio.wait_for(coro(), timeout=None)
|
||||
self.assertEqual(res, 'done')
|
||||
|
||||
async def test_wait_for_race_condition(self):
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
fut = loop.create_future()
|
||||
task = asyncio.wait_for(fut, timeout=0.2)
|
||||
loop.call_soon(fut.set_result, "ok")
|
||||
res = await task
|
||||
self.assertEqual(res, "ok")
|
||||
|
||||
async def test_wait_for_cancellation_race_condition(self):
|
||||
async def inner():
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await asyncio.sleep(1)
|
||||
return 1
|
||||
|
||||
result = await asyncio.wait_for(inner(), timeout=.01)
|
||||
self.assertEqual(result, 1)
|
||||
|
||||
async def test_wait_for_waits_for_task_cancellation(self):
|
||||
task_done = False
|
||||
|
||||
async def inner():
|
||||
nonlocal task_done
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
await asyncio.sleep(_EPSILON)
|
||||
raise
|
||||
finally:
|
||||
task_done = True
|
||||
|
||||
inner_task = asyncio.create_task(inner())
|
||||
|
||||
with self.assertRaises(asyncio.TimeoutError) as cm:
|
||||
await asyncio.wait_for(inner_task, timeout=_EPSILON)
|
||||
|
||||
self.assertTrue(task_done)
|
||||
chained = cm.exception.__context__
|
||||
self.assertEqual(type(chained), asyncio.CancelledError)
|
||||
|
||||
async def test_wait_for_waits_for_task_cancellation_w_timeout_0(self):
|
||||
task_done = False
|
||||
|
||||
async def foo():
|
||||
async def inner():
|
||||
nonlocal task_done
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
await asyncio.sleep(_EPSILON)
|
||||
raise
|
||||
finally:
|
||||
task_done = True
|
||||
|
||||
inner_task = asyncio.create_task(inner())
|
||||
await asyncio.sleep(_EPSILON)
|
||||
await asyncio.wait_for(inner_task, timeout=0)
|
||||
|
||||
with self.assertRaises(asyncio.TimeoutError) as cm:
|
||||
await foo()
|
||||
|
||||
self.assertTrue(task_done)
|
||||
chained = cm.exception.__context__
|
||||
self.assertEqual(type(chained), asyncio.CancelledError)
|
||||
|
||||
async def test_wait_for_reraises_exception_during_cancellation(self):
|
||||
class FooException(Exception):
|
||||
pass
|
||||
|
||||
async def foo():
|
||||
async def inner():
|
||||
try:
|
||||
await asyncio.sleep(0.2)
|
||||
finally:
|
||||
raise FooException
|
||||
|
||||
inner_task = asyncio.create_task(inner())
|
||||
|
||||
await asyncio.wait_for(inner_task, timeout=_EPSILON)
|
||||
|
||||
with self.assertRaises(FooException):
|
||||
await foo()
|
||||
|
||||
async def _test_cancel_wait_for(self, timeout):
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
async def blocking_coroutine():
|
||||
fut = loop.create_future()
|
||||
# Block: fut result is never set
|
||||
await fut
|
||||
|
||||
task = asyncio.create_task(blocking_coroutine())
|
||||
|
||||
wait = asyncio.create_task(asyncio.wait_for(task, timeout))
|
||||
loop.call_soon(wait.cancel)
|
||||
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await wait
|
||||
|
||||
# Python issue #23219: cancelling the wait must also cancel the task
|
||||
self.assertTrue(task.cancelled())
|
||||
|
||||
async def test_cancel_blocking_wait_for(self):
|
||||
await self._test_cancel_wait_for(None)
|
||||
|
||||
async def test_cancel_wait_for(self):
|
||||
await self._test_cancel_wait_for(60.0)
|
||||
|
||||
async def test_wait_for_cancel_suppressed(self):
|
||||
# GH-86296: Suppressing CancelledError is discouraged
|
||||
# but if a task suppresses CancelledError and returns a value,
|
||||
# `wait_for` should return the value instead of raising CancelledError.
|
||||
# This is the same behavior as `asyncio.timeout`.
|
||||
|
||||
async def return_42():
|
||||
try:
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
return 42
|
||||
|
||||
res = await asyncio.wait_for(return_42(), timeout=0.1)
|
||||
self.assertEqual(res, 42)
|
||||
|
||||
|
||||
async def test_wait_for_issue86296(self):
|
||||
# GH-86296: The task should get cancelled and not run to completion.
|
||||
# inner completes in one cycle of the event loop so it
|
||||
# completes before the task is cancelled.
|
||||
|
||||
async def inner():
|
||||
return 'done'
|
||||
|
||||
inner_task = asyncio.create_task(inner())
|
||||
reached_end = False
|
||||
|
||||
async def wait_for_coro():
|
||||
await asyncio.wait_for(inner_task, timeout=100)
|
||||
await asyncio.sleep(1)
|
||||
nonlocal reached_end
|
||||
reached_end = True
|
||||
|
||||
task = asyncio.create_task(wait_for_coro())
|
||||
self.assertFalse(task.done())
|
||||
# Run the task
|
||||
await asyncio.sleep(0)
|
||||
task.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
await task
|
||||
self.assertTrue(inner_task.done())
|
||||
self.assertEqual(await inner_task, 'done')
|
||||
self.assertFalse(reached_end)
|
||||
|
||||
|
||||
class WaitForShieldTests(unittest.IsolatedAsyncioTestCase):
|
||||
|
||||
async def test_zero_timeout(self):
|
||||
# `asyncio.shield` creates a new task which wraps the passed in
|
||||
# awaitable and shields it from cancellation so with timeout=0
|
||||
# the task returned by `asyncio.shield` aka shielded_task gets
|
||||
# cancelled immediately and the task wrapped by it is scheduled
|
||||
# to run.
|
||||
|
||||
async def coro():
|
||||
await asyncio.sleep(0.01)
|
||||
return 'done'
|
||||
|
||||
task = asyncio.create_task(coro())
|
||||
with self.assertRaises(asyncio.TimeoutError):
|
||||
shielded_task = asyncio.shield(task)
|
||||
await asyncio.wait_for(shielded_task, timeout=0)
|
||||
|
||||
# Task is running in background
|
||||
self.assertFalse(task.done())
|
||||
self.assertFalse(task.cancelled())
|
||||
self.assertTrue(shielded_task.cancelled())
|
||||
|
||||
# Wait for the task to complete
|
||||
await asyncio.sleep(0.1)
|
||||
self.assertTrue(task.done())
|
||||
|
||||
|
||||
async def test_none_timeout(self):
|
||||
# With timeout=None the timeout is disabled so it
|
||||
# runs till completion.
|
||||
async def coro():
|
||||
await asyncio.sleep(0.1)
|
||||
return 'done'
|
||||
|
||||
task = asyncio.create_task(coro())
|
||||
await asyncio.wait_for(asyncio.shield(task), timeout=None)
|
||||
|
||||
self.assertTrue(task.done())
|
||||
self.assertEqual(await task, "done")
|
||||
|
||||
async def test_shielded_timeout(self):
|
||||
# shield prevents the task from being cancelled.
|
||||
async def coro():
|
||||
await asyncio.sleep(0.1)
|
||||
return 'done'
|
||||
|
||||
task = asyncio.create_task(coro())
|
||||
with self.assertRaises(asyncio.TimeoutError):
|
||||
await asyncio.wait_for(asyncio.shield(task), timeout=0.01)
|
||||
|
||||
self.assertFalse(task.done())
|
||||
self.assertFalse(task.cancelled())
|
||||
self.assertEqual(await task, "done")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
359
Utils/PythonNew32/Lib/test/test_asyncio/test_windows_events.py
Normal file
359
Utils/PythonNew32/Lib/test/test_asyncio/test_windows_events.py
Normal file
@@ -0,0 +1,359 @@
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
if sys.platform != 'win32':
|
||||
raise unittest.SkipTest('Windows only')
|
||||
|
||||
import _overlapped
|
||||
import _winapi
|
||||
|
||||
import asyncio
|
||||
from asyncio import windows_events
|
||||
from test.test_asyncio import utils as test_utils
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
class UpperProto(asyncio.Protocol):
|
||||
def __init__(self):
|
||||
self.buf = []
|
||||
|
||||
def connection_made(self, trans):
|
||||
self.trans = trans
|
||||
|
||||
def data_received(self, data):
|
||||
self.buf.append(data)
|
||||
if b'\n' in data:
|
||||
self.trans.write(b''.join(self.buf).upper())
|
||||
self.trans.close()
|
||||
|
||||
|
||||
class WindowsEventsTestCase(test_utils.TestCase):
|
||||
def _unraisablehook(self, unraisable):
|
||||
# Storing unraisable.object can resurrect an object which is being
|
||||
# finalized. Storing unraisable.exc_value creates a reference cycle.
|
||||
self._unraisable = unraisable
|
||||
print(unraisable)
|
||||
|
||||
def setUp(self):
|
||||
self._prev_unraisablehook = sys.unraisablehook
|
||||
self._unraisable = None
|
||||
sys.unraisablehook = self._unraisablehook
|
||||
|
||||
def tearDown(self):
|
||||
sys.unraisablehook = self._prev_unraisablehook
|
||||
self.assertIsNone(self._unraisable)
|
||||
|
||||
class ProactorLoopCtrlC(WindowsEventsTestCase):
|
||||
|
||||
def test_ctrl_c(self):
|
||||
|
||||
def SIGINT_after_delay():
|
||||
time.sleep(0.1)
|
||||
signal.raise_signal(signal.SIGINT)
|
||||
|
||||
thread = threading.Thread(target=SIGINT_after_delay)
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
# only start the loop once the event loop is running
|
||||
loop.call_soon(thread.start)
|
||||
loop.run_forever()
|
||||
self.fail("should not fall through 'run_forever'")
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
self.close_loop(loop)
|
||||
thread.join()
|
||||
|
||||
|
||||
class ProactorMultithreading(WindowsEventsTestCase):
|
||||
def test_run_from_nonmain_thread(self):
|
||||
finished = False
|
||||
|
||||
async def coro():
|
||||
await asyncio.sleep(0)
|
||||
|
||||
def func():
|
||||
nonlocal finished
|
||||
loop = asyncio.new_event_loop()
|
||||
loop.run_until_complete(coro())
|
||||
# close() must not call signal.set_wakeup_fd()
|
||||
loop.close()
|
||||
finished = True
|
||||
|
||||
thread = threading.Thread(target=func)
|
||||
thread.start()
|
||||
thread.join()
|
||||
self.assertTrue(finished)
|
||||
|
||||
|
||||
class ProactorTests(WindowsEventsTestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = asyncio.ProactorEventLoop()
|
||||
self.set_event_loop(self.loop)
|
||||
|
||||
def test_close(self):
|
||||
a, b = socket.socketpair()
|
||||
trans = self.loop._make_socket_transport(a, asyncio.Protocol())
|
||||
f = asyncio.ensure_future(self.loop.sock_recv(b, 100), loop=self.loop)
|
||||
trans.close()
|
||||
self.loop.run_until_complete(f)
|
||||
self.assertEqual(f.result(), b'')
|
||||
b.close()
|
||||
|
||||
def test_double_bind(self):
|
||||
ADDRESS = r'\\.\pipe\test_double_bind-%s' % os.getpid()
|
||||
server1 = windows_events.PipeServer(ADDRESS)
|
||||
with self.assertRaises(PermissionError):
|
||||
windows_events.PipeServer(ADDRESS)
|
||||
server1.close()
|
||||
|
||||
def test_pipe(self):
|
||||
res = self.loop.run_until_complete(self._test_pipe())
|
||||
self.assertEqual(res, 'done')
|
||||
|
||||
async def _test_pipe(self):
|
||||
ADDRESS = r'\\.\pipe\_test_pipe-%s' % os.getpid()
|
||||
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
await self.loop.create_pipe_connection(
|
||||
asyncio.Protocol, ADDRESS)
|
||||
|
||||
[server] = await self.loop.start_serving_pipe(
|
||||
UpperProto, ADDRESS)
|
||||
self.assertIsInstance(server, windows_events.PipeServer)
|
||||
|
||||
clients = []
|
||||
for i in range(5):
|
||||
stream_reader = asyncio.StreamReader(loop=self.loop)
|
||||
protocol = asyncio.StreamReaderProtocol(stream_reader,
|
||||
loop=self.loop)
|
||||
trans, proto = await self.loop.create_pipe_connection(
|
||||
lambda: protocol, ADDRESS)
|
||||
self.assertIsInstance(trans, asyncio.Transport)
|
||||
self.assertEqual(protocol, proto)
|
||||
clients.append((stream_reader, trans))
|
||||
|
||||
for i, (r, w) in enumerate(clients):
|
||||
w.write('lower-{}\n'.format(i).encode())
|
||||
|
||||
for i, (r, w) in enumerate(clients):
|
||||
response = await r.readline()
|
||||
self.assertEqual(response, 'LOWER-{}\n'.format(i).encode())
|
||||
w.close()
|
||||
|
||||
server.close()
|
||||
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
await self.loop.create_pipe_connection(
|
||||
asyncio.Protocol, ADDRESS)
|
||||
|
||||
return 'done'
|
||||
|
||||
def test_connect_pipe_cancel(self):
|
||||
exc = OSError()
|
||||
exc.winerror = _overlapped.ERROR_PIPE_BUSY
|
||||
with mock.patch.object(_overlapped, 'ConnectPipe',
|
||||
side_effect=exc) as connect:
|
||||
coro = self.loop._proactor.connect_pipe('pipe_address')
|
||||
task = self.loop.create_task(coro)
|
||||
|
||||
# check that it's possible to cancel connect_pipe()
|
||||
task.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
self.loop.run_until_complete(task)
|
||||
|
||||
def test_wait_for_handle(self):
|
||||
event = _overlapped.CreateEvent(None, True, False, None)
|
||||
self.addCleanup(_winapi.CloseHandle, event)
|
||||
|
||||
# Wait for unset event with 0.5s timeout;
|
||||
# result should be False at timeout
|
||||
timeout = 0.5
|
||||
fut = self.loop._proactor.wait_for_handle(event, timeout)
|
||||
start = self.loop.time()
|
||||
done = self.loop.run_until_complete(fut)
|
||||
elapsed = self.loop.time() - start
|
||||
|
||||
self.assertEqual(done, False)
|
||||
self.assertFalse(fut.result())
|
||||
self.assertGreaterEqual(elapsed, timeout - test_utils.CLOCK_RES)
|
||||
|
||||
_overlapped.SetEvent(event)
|
||||
|
||||
# Wait for set event;
|
||||
# result should be True immediately
|
||||
fut = self.loop._proactor.wait_for_handle(event, 10)
|
||||
done = self.loop.run_until_complete(fut)
|
||||
|
||||
self.assertEqual(done, True)
|
||||
self.assertTrue(fut.result())
|
||||
|
||||
# asyncio issue #195: cancelling a done _WaitHandleFuture
|
||||
# must not crash
|
||||
fut.cancel()
|
||||
|
||||
def test_wait_for_handle_cancel(self):
|
||||
event = _overlapped.CreateEvent(None, True, False, None)
|
||||
self.addCleanup(_winapi.CloseHandle, event)
|
||||
|
||||
# Wait for unset event with a cancelled future;
|
||||
# CancelledError should be raised immediately
|
||||
fut = self.loop._proactor.wait_for_handle(event, 10)
|
||||
fut.cancel()
|
||||
with self.assertRaises(asyncio.CancelledError):
|
||||
self.loop.run_until_complete(fut)
|
||||
|
||||
# asyncio issue #195: cancelling a _WaitHandleFuture twice
|
||||
# must not crash
|
||||
fut = self.loop._proactor.wait_for_handle(event)
|
||||
fut.cancel()
|
||||
fut.cancel()
|
||||
|
||||
def test_read_self_pipe_restart(self):
|
||||
# Regression test for https://bugs.python.org/issue39010
|
||||
# Previously, restarting a proactor event loop in certain states
|
||||
# would lead to spurious ConnectionResetErrors being logged.
|
||||
self.loop.call_exception_handler = mock.Mock()
|
||||
# Start an operation in another thread so that the self-pipe is used.
|
||||
# This is theoretically timing-dependent (the task in the executor
|
||||
# must complete before our start/stop cycles), but in practice it
|
||||
# seems to work every time.
|
||||
f = self.loop.run_in_executor(None, lambda: None)
|
||||
self.loop.stop()
|
||||
self.loop.run_forever()
|
||||
self.loop.stop()
|
||||
self.loop.run_forever()
|
||||
|
||||
# Shut everything down cleanly. This is an important part of the
|
||||
# test - in issue 39010, the error occurred during loop.close(),
|
||||
# so we want to close the loop during the test instead of leaving
|
||||
# it for tearDown.
|
||||
#
|
||||
# First wait for f to complete to avoid a "future's result was never
|
||||
# retrieved" error.
|
||||
self.loop.run_until_complete(f)
|
||||
# Now shut down the loop itself (self.close_loop also shuts down the
|
||||
# loop's default executor).
|
||||
self.close_loop(self.loop)
|
||||
self.assertFalse(self.loop.call_exception_handler.called)
|
||||
|
||||
def test_address_argument_type_error(self):
|
||||
# Regression test for https://github.com/python/cpython/issues/98793
|
||||
proactor = self.loop._proactor
|
||||
sock = socket.socket(type=socket.SOCK_DGRAM)
|
||||
bad_address = None
|
||||
with self.assertRaises(TypeError):
|
||||
proactor.connect(sock, bad_address)
|
||||
with self.assertRaises(TypeError):
|
||||
proactor.sendto(sock, b'abc', addr=bad_address)
|
||||
sock.close()
|
||||
|
||||
def test_client_pipe_stat(self):
|
||||
res = self.loop.run_until_complete(self._test_client_pipe_stat())
|
||||
self.assertEqual(res, 'done')
|
||||
|
||||
async def _test_client_pipe_stat(self):
|
||||
# Regression test for https://github.com/python/cpython/issues/100573
|
||||
ADDRESS = r'\\.\pipe\test_client_pipe_stat-%s' % os.getpid()
|
||||
|
||||
async def probe():
|
||||
# See https://github.com/python/cpython/pull/100959#discussion_r1068533658
|
||||
h = _overlapped.ConnectPipe(ADDRESS)
|
||||
try:
|
||||
_winapi.CloseHandle(_overlapped.ConnectPipe(ADDRESS))
|
||||
except OSError as e:
|
||||
if e.winerror != _overlapped.ERROR_PIPE_BUSY:
|
||||
raise
|
||||
finally:
|
||||
_winapi.CloseHandle(h)
|
||||
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
await probe()
|
||||
|
||||
[server] = await self.loop.start_serving_pipe(asyncio.Protocol, ADDRESS)
|
||||
self.assertIsInstance(server, windows_events.PipeServer)
|
||||
|
||||
errors = []
|
||||
self.loop.set_exception_handler(lambda _, data: errors.append(data))
|
||||
|
||||
for i in range(5):
|
||||
await self.loop.create_task(probe())
|
||||
|
||||
self.assertEqual(len(errors), 0, errors)
|
||||
|
||||
server.close()
|
||||
|
||||
with self.assertRaises(FileNotFoundError):
|
||||
await probe()
|
||||
|
||||
return "done"
|
||||
|
||||
def test_loop_restart(self):
|
||||
# We're fishing for the "RuntimeError: <_overlapped.Overlapped object at XXX>
|
||||
# still has pending operation at deallocation, the process may crash" error
|
||||
stop = threading.Event()
|
||||
def threadMain():
|
||||
while not stop.is_set():
|
||||
self.loop.call_soon_threadsafe(lambda: None)
|
||||
time.sleep(0.01)
|
||||
thr = threading.Thread(target=threadMain)
|
||||
|
||||
# In 10 60-second runs of this test prior to the fix:
|
||||
# time in seconds until failure: (none), 15.0, 6.4, (none), 7.6, 8.3, 1.7, 22.2, 23.5, 8.3
|
||||
# 10 seconds had a 50% failure rate but longer would be more costly
|
||||
end_time = time.time() + 10 # Run for 10 seconds
|
||||
self.loop.call_soon(thr.start)
|
||||
while not self._unraisable: # Stop if we got an unraisable exc
|
||||
self.loop.stop()
|
||||
self.loop.run_forever()
|
||||
if time.time() >= end_time:
|
||||
break
|
||||
|
||||
stop.set()
|
||||
thr.join()
|
||||
|
||||
|
||||
class WinPolicyTests(WindowsEventsTestCase):
|
||||
|
||||
def test_selector_win_policy(self):
|
||||
async def main():
|
||||
self.assertIsInstance(
|
||||
asyncio.get_running_loop(),
|
||||
asyncio.SelectorEventLoop)
|
||||
|
||||
old_policy = asyncio.get_event_loop_policy()
|
||||
try:
|
||||
asyncio.set_event_loop_policy(
|
||||
asyncio.WindowsSelectorEventLoopPolicy())
|
||||
asyncio.run(main())
|
||||
finally:
|
||||
asyncio.set_event_loop_policy(old_policy)
|
||||
|
||||
def test_proactor_win_policy(self):
|
||||
async def main():
|
||||
self.assertIsInstance(
|
||||
asyncio.get_running_loop(),
|
||||
asyncio.ProactorEventLoop)
|
||||
|
||||
old_policy = asyncio.get_event_loop_policy()
|
||||
try:
|
||||
asyncio.set_event_loop_policy(
|
||||
asyncio.WindowsProactorEventLoopPolicy())
|
||||
asyncio.run(main())
|
||||
finally:
|
||||
asyncio.set_event_loop_policy(old_policy)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
133
Utils/PythonNew32/Lib/test/test_asyncio/test_windows_utils.py
Normal file
133
Utils/PythonNew32/Lib/test/test_asyncio/test_windows_utils.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Tests for window_utils"""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
if sys.platform != 'win32':
|
||||
raise unittest.SkipTest('Windows only')
|
||||
|
||||
import _overlapped
|
||||
import _winapi
|
||||
|
||||
import asyncio
|
||||
from asyncio import windows_utils
|
||||
from test import support
|
||||
|
||||
|
||||
def tearDownModule():
|
||||
asyncio.set_event_loop_policy(None)
|
||||
|
||||
|
||||
class PipeTests(unittest.TestCase):
|
||||
|
||||
def test_pipe_overlapped(self):
|
||||
h1, h2 = windows_utils.pipe(overlapped=(True, True))
|
||||
try:
|
||||
ov1 = _overlapped.Overlapped()
|
||||
self.assertFalse(ov1.pending)
|
||||
self.assertEqual(ov1.error, 0)
|
||||
|
||||
ov1.ReadFile(h1, 100)
|
||||
self.assertTrue(ov1.pending)
|
||||
self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING)
|
||||
ERROR_IO_INCOMPLETE = 996
|
||||
try:
|
||||
ov1.getresult()
|
||||
except OSError as e:
|
||||
self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE)
|
||||
else:
|
||||
raise RuntimeError('expected ERROR_IO_INCOMPLETE')
|
||||
|
||||
ov2 = _overlapped.Overlapped()
|
||||
self.assertFalse(ov2.pending)
|
||||
self.assertEqual(ov2.error, 0)
|
||||
|
||||
ov2.WriteFile(h2, b"hello")
|
||||
self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING})
|
||||
|
||||
res = _winapi.WaitForMultipleObjects([ov2.event], False, 100)
|
||||
self.assertEqual(res, _winapi.WAIT_OBJECT_0)
|
||||
|
||||
self.assertFalse(ov1.pending)
|
||||
self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE)
|
||||
self.assertFalse(ov2.pending)
|
||||
self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING})
|
||||
self.assertEqual(ov1.getresult(), b"hello")
|
||||
finally:
|
||||
_winapi.CloseHandle(h1)
|
||||
_winapi.CloseHandle(h2)
|
||||
|
||||
def test_pipe_handle(self):
|
||||
h, _ = windows_utils.pipe(overlapped=(True, True))
|
||||
_winapi.CloseHandle(_)
|
||||
p = windows_utils.PipeHandle(h)
|
||||
self.assertEqual(p.fileno(), h)
|
||||
self.assertEqual(p.handle, h)
|
||||
|
||||
# check garbage collection of p closes handle
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", "", ResourceWarning)
|
||||
del p
|
||||
support.gc_collect()
|
||||
try:
|
||||
_winapi.CloseHandle(h)
|
||||
except OSError as e:
|
||||
self.assertEqual(e.winerror, 6) # ERROR_INVALID_HANDLE
|
||||
else:
|
||||
raise RuntimeError('expected ERROR_INVALID_HANDLE')
|
||||
|
||||
|
||||
class PopenTests(unittest.TestCase):
|
||||
|
||||
def test_popen(self):
|
||||
command = r"""if 1:
|
||||
import sys
|
||||
s = sys.stdin.readline()
|
||||
sys.stdout.write(s.upper())
|
||||
sys.stderr.write('stderr')
|
||||
"""
|
||||
msg = b"blah\n"
|
||||
|
||||
p = windows_utils.Popen([sys.executable, '-c', command],
|
||||
stdin=windows_utils.PIPE,
|
||||
stdout=windows_utils.PIPE,
|
||||
stderr=windows_utils.PIPE)
|
||||
|
||||
for f in [p.stdin, p.stdout, p.stderr]:
|
||||
self.assertIsInstance(f, windows_utils.PipeHandle)
|
||||
|
||||
ovin = _overlapped.Overlapped()
|
||||
ovout = _overlapped.Overlapped()
|
||||
overr = _overlapped.Overlapped()
|
||||
|
||||
ovin.WriteFile(p.stdin.handle, msg)
|
||||
ovout.ReadFile(p.stdout.handle, 100)
|
||||
overr.ReadFile(p.stderr.handle, 100)
|
||||
|
||||
events = [ovin.event, ovout.event, overr.event]
|
||||
# Super-long timeout for slow buildbots.
|
||||
res = _winapi.WaitForMultipleObjects(events, True,
|
||||
int(support.SHORT_TIMEOUT * 1000))
|
||||
self.assertEqual(res, _winapi.WAIT_OBJECT_0)
|
||||
self.assertFalse(ovout.pending)
|
||||
self.assertFalse(overr.pending)
|
||||
self.assertFalse(ovin.pending)
|
||||
|
||||
self.assertEqual(ovin.getresult(), len(msg))
|
||||
out = ovout.getresult().rstrip()
|
||||
err = overr.getresult().rstrip()
|
||||
|
||||
self.assertGreater(len(out), 0)
|
||||
self.assertGreater(len(err), 0)
|
||||
# allow for partial reads...
|
||||
self.assertTrue(msg.upper().rstrip().startswith(out))
|
||||
self.assertTrue(b"stderr".startswith(err))
|
||||
|
||||
# The context manager calls wait() and closes resources
|
||||
with p:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
638
Utils/PythonNew32/Lib/test/test_asyncio/utils.py
Normal file
638
Utils/PythonNew32/Lib/test/test_asyncio/utils.py
Normal file
@@ -0,0 +1,638 @@
|
||||
"""Utilities shared by tests."""
|
||||
|
||||
import asyncio
|
||||
import collections
|
||||
import contextlib
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import selectors
|
||||
import socket
|
||||
import socketserver
|
||||
import sys
|
||||
import threading
|
||||
import unittest
|
||||
import weakref
|
||||
import warnings
|
||||
from unittest import mock
|
||||
|
||||
from http.server import HTTPServer
|
||||
from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
|
||||
|
||||
try:
|
||||
import ssl
|
||||
except ImportError: # pragma: no cover
|
||||
ssl = None
|
||||
|
||||
from asyncio import base_events
|
||||
from asyncio import events
|
||||
from asyncio import format_helpers
|
||||
from asyncio import tasks
|
||||
from asyncio.log import logger
|
||||
from test import support
|
||||
from test.support import socket_helper
|
||||
from test.support import threading_helper
|
||||
|
||||
|
||||
# Use the maximum known clock resolution (gh-75191, gh-110088): Windows
|
||||
# GetTickCount64() has a resolution of 15.6 ms. Use 50 ms to tolerate rounding
|
||||
# issues.
|
||||
CLOCK_RES = 0.050
|
||||
|
||||
|
||||
def data_file(*filename):
|
||||
fullname = os.path.join(support.TEST_HOME_DIR, *filename)
|
||||
if os.path.isfile(fullname):
|
||||
return fullname
|
||||
fullname = os.path.join(os.path.dirname(__file__), '..', *filename)
|
||||
if os.path.isfile(fullname):
|
||||
return fullname
|
||||
raise FileNotFoundError(os.path.join(filename))
|
||||
|
||||
|
||||
ONLYCERT = data_file('certdata', 'ssl_cert.pem')
|
||||
ONLYKEY = data_file('certdata', 'ssl_key.pem')
|
||||
SIGNED_CERTFILE = data_file('certdata', 'keycert3.pem')
|
||||
SIGNING_CA = data_file('certdata', 'pycacert.pem')
|
||||
PEERCERT = {
|
||||
'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
|
||||
'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
|
||||
'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
|
||||
'issuer': ((('countryName', 'XY'),),
|
||||
(('organizationName', 'Python Software Foundation CA'),),
|
||||
(('commonName', 'our-ca-server'),)),
|
||||
'notAfter': 'Oct 28 14:23:16 2037 GMT',
|
||||
'notBefore': 'Aug 29 14:23:16 2018 GMT',
|
||||
'serialNumber': 'CB2D80995A69525C',
|
||||
'subject': ((('countryName', 'XY'),),
|
||||
(('localityName', 'Castle Anthrax'),),
|
||||
(('organizationName', 'Python Software Foundation'),),
|
||||
(('commonName', 'localhost'),)),
|
||||
'subjectAltName': (('DNS', 'localhost'),),
|
||||
'version': 3
|
||||
}
|
||||
|
||||
|
||||
def simple_server_sslcontext():
|
||||
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
server_context.load_cert_chain(ONLYCERT, ONLYKEY)
|
||||
server_context.check_hostname = False
|
||||
server_context.verify_mode = ssl.CERT_NONE
|
||||
return server_context
|
||||
|
||||
|
||||
def simple_client_sslcontext(*, disable_verify=True):
|
||||
client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
client_context.check_hostname = False
|
||||
if disable_verify:
|
||||
client_context.verify_mode = ssl.CERT_NONE
|
||||
return client_context
|
||||
|
||||
|
||||
def dummy_ssl_context():
|
||||
if ssl is None:
|
||||
return None
|
||||
else:
|
||||
return simple_client_sslcontext(disable_verify=True)
|
||||
|
||||
|
||||
def run_briefly(loop):
|
||||
async def once():
|
||||
pass
|
||||
gen = once()
|
||||
t = loop.create_task(gen)
|
||||
# Don't log a warning if the task is not done after run_until_complete().
|
||||
# It occurs if the loop is stopped or if a task raises a BaseException.
|
||||
t._log_destroy_pending = False
|
||||
try:
|
||||
loop.run_until_complete(t)
|
||||
finally:
|
||||
gen.close()
|
||||
|
||||
|
||||
def run_until(loop, pred, timeout=support.SHORT_TIMEOUT):
|
||||
delay = 0.001
|
||||
for _ in support.busy_retry(timeout, error=False):
|
||||
if pred():
|
||||
break
|
||||
loop.run_until_complete(tasks.sleep(delay))
|
||||
delay = max(delay * 2, 1.0)
|
||||
else:
|
||||
raise TimeoutError()
|
||||
|
||||
|
||||
def run_once(loop):
|
||||
"""Legacy API to run once through the event loop.
|
||||
|
||||
This is the recommended pattern for test code. It will poll the
|
||||
selector once and run all callbacks scheduled in response to I/O
|
||||
events.
|
||||
"""
|
||||
loop.call_soon(loop.stop)
|
||||
loop.run_forever()
|
||||
|
||||
|
||||
class SilentWSGIRequestHandler(WSGIRequestHandler):
|
||||
|
||||
def get_stderr(self):
|
||||
return io.StringIO()
|
||||
|
||||
def log_message(self, format, *args):
|
||||
pass
|
||||
|
||||
|
||||
class SilentWSGIServer(WSGIServer):
|
||||
|
||||
request_timeout = support.LOOPBACK_TIMEOUT
|
||||
|
||||
def get_request(self):
|
||||
request, client_addr = super().get_request()
|
||||
request.settimeout(self.request_timeout)
|
||||
return request, client_addr
|
||||
|
||||
def handle_error(self, request, client_address):
|
||||
pass
|
||||
|
||||
|
||||
class SSLWSGIServerMixin:
|
||||
|
||||
def finish_request(self, request, client_address):
|
||||
# The relative location of our test directory (which
|
||||
# contains the ssl key and certificate files) differs
|
||||
# between the stdlib and stand-alone asyncio.
|
||||
# Prefer our own if we can find it.
|
||||
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
context.load_cert_chain(ONLYCERT, ONLYKEY)
|
||||
|
||||
ssock = context.wrap_socket(request, server_side=True)
|
||||
try:
|
||||
self.RequestHandlerClass(ssock, client_address, self)
|
||||
ssock.close()
|
||||
except OSError:
|
||||
# maybe socket has been closed by peer
|
||||
pass
|
||||
|
||||
|
||||
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
|
||||
pass
|
||||
|
||||
|
||||
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
|
||||
|
||||
def loop(environ):
|
||||
size = int(environ['CONTENT_LENGTH'])
|
||||
while size:
|
||||
data = environ['wsgi.input'].read(min(size, 0x10000))
|
||||
yield data
|
||||
size -= len(data)
|
||||
|
||||
def app(environ, start_response):
|
||||
status = '200 OK'
|
||||
headers = [('Content-type', 'text/plain')]
|
||||
start_response(status, headers)
|
||||
if environ['PATH_INFO'] == '/loop':
|
||||
return loop(environ)
|
||||
else:
|
||||
return [b'Test message']
|
||||
|
||||
# Run the test WSGI server in a separate thread in order not to
|
||||
# interfere with event handling in the main thread
|
||||
server_class = server_ssl_cls if use_ssl else server_cls
|
||||
httpd = server_class(address, SilentWSGIRequestHandler)
|
||||
httpd.set_app(app)
|
||||
httpd.address = httpd.server_address
|
||||
server_thread = threading.Thread(
|
||||
target=lambda: httpd.serve_forever(poll_interval=0.05))
|
||||
server_thread.start()
|
||||
try:
|
||||
yield httpd
|
||||
finally:
|
||||
httpd.shutdown()
|
||||
httpd.server_close()
|
||||
server_thread.join()
|
||||
|
||||
|
||||
if hasattr(socket, 'AF_UNIX'):
|
||||
|
||||
class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
|
||||
|
||||
def server_bind(self):
|
||||
socketserver.UnixStreamServer.server_bind(self)
|
||||
self.server_name = '127.0.0.1'
|
||||
self.server_port = 80
|
||||
|
||||
|
||||
class UnixWSGIServer(UnixHTTPServer, WSGIServer):
|
||||
|
||||
request_timeout = support.LOOPBACK_TIMEOUT
|
||||
|
||||
def server_bind(self):
|
||||
UnixHTTPServer.server_bind(self)
|
||||
self.setup_environ()
|
||||
|
||||
def get_request(self):
|
||||
request, client_addr = super().get_request()
|
||||
request.settimeout(self.request_timeout)
|
||||
# Code in the stdlib expects that get_request
|
||||
# will return a socket and a tuple (host, port).
|
||||
# However, this isn't true for UNIX sockets,
|
||||
# as the second return value will be a path;
|
||||
# hence we return some fake data sufficient
|
||||
# to get the tests going
|
||||
return request, ('127.0.0.1', '')
|
||||
|
||||
|
||||
class SilentUnixWSGIServer(UnixWSGIServer):
|
||||
|
||||
def handle_error(self, request, client_address):
|
||||
pass
|
||||
|
||||
|
||||
class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
|
||||
pass
|
||||
|
||||
|
||||
def gen_unix_socket_path():
|
||||
return socket_helper.create_unix_domain_name()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def unix_socket_path():
|
||||
path = gen_unix_socket_path()
|
||||
try:
|
||||
yield path
|
||||
finally:
|
||||
try:
|
||||
os.unlink(path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def run_test_unix_server(*, use_ssl=False):
|
||||
with unix_socket_path() as path:
|
||||
yield from _run_test_server(address=path, use_ssl=use_ssl,
|
||||
server_cls=SilentUnixWSGIServer,
|
||||
server_ssl_cls=UnixSSLWSGIServer)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
|
||||
yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
|
||||
server_cls=SilentWSGIServer,
|
||||
server_ssl_cls=SSLWSGIServer)
|
||||
|
||||
|
||||
def echo_datagrams(sock):
|
||||
while True:
|
||||
data, addr = sock.recvfrom(4096)
|
||||
if data == b'STOP':
|
||||
sock.close()
|
||||
break
|
||||
else:
|
||||
sock.sendto(data, addr)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def run_udp_echo_server(*, host='127.0.0.1', port=0):
|
||||
addr_info = socket.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
|
||||
family, type, proto, _, sockaddr = addr_info[0]
|
||||
sock = socket.socket(family, type, proto)
|
||||
sock.bind((host, port))
|
||||
sockname = sock.getsockname()
|
||||
thread = threading.Thread(target=lambda: echo_datagrams(sock))
|
||||
thread.start()
|
||||
try:
|
||||
yield sockname
|
||||
finally:
|
||||
# gh-122187: use a separate socket to send the stop message to avoid
|
||||
# TSan reported race on the same socket.
|
||||
sock2 = socket.socket(family, type, proto)
|
||||
sock2.sendto(b'STOP', sockname)
|
||||
sock2.close()
|
||||
thread.join()
|
||||
|
||||
|
||||
def make_test_protocol(base):
|
||||
dct = {}
|
||||
for name in dir(base):
|
||||
if name.startswith('__') and name.endswith('__'):
|
||||
# skip magic names
|
||||
continue
|
||||
dct[name] = MockCallback(return_value=None)
|
||||
return type('TestProtocol', (base,) + base.__bases__, dct)()
|
||||
|
||||
|
||||
class TestSelector(selectors.BaseSelector):
|
||||
|
||||
def __init__(self):
|
||||
self.keys = {}
|
||||
|
||||
def register(self, fileobj, events, data=None):
|
||||
key = selectors.SelectorKey(fileobj, 0, events, data)
|
||||
self.keys[fileobj] = key
|
||||
return key
|
||||
|
||||
def unregister(self, fileobj):
|
||||
return self.keys.pop(fileobj)
|
||||
|
||||
def select(self, timeout):
|
||||
return []
|
||||
|
||||
def get_map(self):
|
||||
return self.keys
|
||||
|
||||
|
||||
class TestLoop(base_events.BaseEventLoop):
|
||||
"""Loop for unittests.
|
||||
|
||||
It manages self time directly.
|
||||
If something scheduled to be executed later then
|
||||
on next loop iteration after all ready handlers done
|
||||
generator passed to __init__ is calling.
|
||||
|
||||
Generator should be like this:
|
||||
|
||||
def gen():
|
||||
...
|
||||
when = yield ...
|
||||
... = yield time_advance
|
||||
|
||||
Value returned by yield is absolute time of next scheduled handler.
|
||||
Value passed to yield is time advance to move loop's time forward.
|
||||
"""
|
||||
|
||||
def __init__(self, gen=None):
|
||||
super().__init__()
|
||||
|
||||
if gen is None:
|
||||
def gen():
|
||||
yield
|
||||
self._check_on_close = False
|
||||
else:
|
||||
self._check_on_close = True
|
||||
|
||||
self._gen = gen()
|
||||
next(self._gen)
|
||||
self._time = 0
|
||||
self._clock_resolution = 1e-9
|
||||
self._timers = []
|
||||
self._selector = TestSelector()
|
||||
|
||||
self.readers = {}
|
||||
self.writers = {}
|
||||
self.reset_counters()
|
||||
|
||||
self._transports = weakref.WeakValueDictionary()
|
||||
|
||||
def time(self):
|
||||
return self._time
|
||||
|
||||
def advance_time(self, advance):
|
||||
"""Move test time forward."""
|
||||
if advance:
|
||||
self._time += advance
|
||||
|
||||
def close(self):
|
||||
super().close()
|
||||
if self._check_on_close:
|
||||
try:
|
||||
self._gen.send(0)
|
||||
except StopIteration:
|
||||
pass
|
||||
else: # pragma: no cover
|
||||
raise AssertionError("Time generator is not finished")
|
||||
|
||||
def _add_reader(self, fd, callback, *args):
|
||||
self.readers[fd] = events.Handle(callback, args, self, None)
|
||||
|
||||
def _remove_reader(self, fd):
|
||||
self.remove_reader_count[fd] += 1
|
||||
if fd in self.readers:
|
||||
del self.readers[fd]
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def assert_reader(self, fd, callback, *args):
|
||||
if fd not in self.readers:
|
||||
raise AssertionError(f'fd {fd} is not registered')
|
||||
handle = self.readers[fd]
|
||||
if handle._callback != callback:
|
||||
raise AssertionError(
|
||||
f'unexpected callback: {handle._callback} != {callback}')
|
||||
if handle._args != args:
|
||||
raise AssertionError(
|
||||
f'unexpected callback args: {handle._args} != {args}')
|
||||
|
||||
def assert_no_reader(self, fd):
|
||||
if fd in self.readers:
|
||||
raise AssertionError(f'fd {fd} is registered')
|
||||
|
||||
def _add_writer(self, fd, callback, *args):
|
||||
self.writers[fd] = events.Handle(callback, args, self, None)
|
||||
|
||||
def _remove_writer(self, fd):
|
||||
self.remove_writer_count[fd] += 1
|
||||
if fd in self.writers:
|
||||
del self.writers[fd]
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def assert_writer(self, fd, callback, *args):
|
||||
if fd not in self.writers:
|
||||
raise AssertionError(f'fd {fd} is not registered')
|
||||
handle = self.writers[fd]
|
||||
if handle._callback != callback:
|
||||
raise AssertionError(f'{handle._callback!r} != {callback!r}')
|
||||
if handle._args != args:
|
||||
raise AssertionError(f'{handle._args!r} != {args!r}')
|
||||
|
||||
def _ensure_fd_no_transport(self, fd):
|
||||
if not isinstance(fd, int):
|
||||
try:
|
||||
fd = int(fd.fileno())
|
||||
except (AttributeError, TypeError, ValueError):
|
||||
# This code matches selectors._fileobj_to_fd function.
|
||||
raise ValueError("Invalid file object: "
|
||||
"{!r}".format(fd)) from None
|
||||
try:
|
||||
transport = self._transports[fd]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError(
|
||||
'File descriptor {!r} is used by transport {!r}'.format(
|
||||
fd, transport))
|
||||
|
||||
def add_reader(self, fd, callback, *args):
|
||||
"""Add a reader callback."""
|
||||
self._ensure_fd_no_transport(fd)
|
||||
return self._add_reader(fd, callback, *args)
|
||||
|
||||
def remove_reader(self, fd):
|
||||
"""Remove a reader callback."""
|
||||
self._ensure_fd_no_transport(fd)
|
||||
return self._remove_reader(fd)
|
||||
|
||||
def add_writer(self, fd, callback, *args):
|
||||
"""Add a writer callback.."""
|
||||
self._ensure_fd_no_transport(fd)
|
||||
return self._add_writer(fd, callback, *args)
|
||||
|
||||
def remove_writer(self, fd):
|
||||
"""Remove a writer callback."""
|
||||
self._ensure_fd_no_transport(fd)
|
||||
return self._remove_writer(fd)
|
||||
|
||||
def reset_counters(self):
|
||||
self.remove_reader_count = collections.defaultdict(int)
|
||||
self.remove_writer_count = collections.defaultdict(int)
|
||||
|
||||
def _run_once(self):
|
||||
super()._run_once()
|
||||
for when in self._timers:
|
||||
advance = self._gen.send(when)
|
||||
self.advance_time(advance)
|
||||
self._timers = []
|
||||
|
||||
def call_at(self, when, callback, *args, context=None):
|
||||
self._timers.append(when)
|
||||
return super().call_at(when, callback, *args, context=context)
|
||||
|
||||
def _process_events(self, event_list):
|
||||
return
|
||||
|
||||
def _write_to_self(self):
|
||||
pass
|
||||
|
||||
|
||||
def MockCallback(**kwargs):
|
||||
return mock.Mock(spec=['__call__'], **kwargs)
|
||||
|
||||
|
||||
class MockPattern(str):
|
||||
"""A regex based str with a fuzzy __eq__.
|
||||
|
||||
Use this helper with 'mock.assert_called_with', or anywhere
|
||||
where a regex comparison between strings is needed.
|
||||
|
||||
For instance:
|
||||
mock_call.assert_called_with(MockPattern('spam.*ham'))
|
||||
"""
|
||||
def __eq__(self, other):
|
||||
return bool(re.search(str(self), other, re.S))
|
||||
|
||||
|
||||
class MockInstanceOf:
|
||||
def __init__(self, type):
|
||||
self._type = type
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, self._type)
|
||||
|
||||
|
||||
def get_function_source(func):
|
||||
source = format_helpers._get_function_source(func)
|
||||
if source is None:
|
||||
raise ValueError("unable to get the source of %r" % (func,))
|
||||
return source
|
||||
|
||||
|
||||
class TestCase(unittest.TestCase):
|
||||
@staticmethod
|
||||
def close_loop(loop):
|
||||
if loop._default_executor is not None:
|
||||
if not loop.is_closed():
|
||||
loop.run_until_complete(loop.shutdown_default_executor())
|
||||
else:
|
||||
loop._default_executor.shutdown(wait=True)
|
||||
loop.close()
|
||||
|
||||
policy = support.maybe_get_event_loop_policy()
|
||||
if policy is not None:
|
||||
try:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', DeprecationWarning)
|
||||
watcher = policy.get_child_watcher()
|
||||
except NotImplementedError:
|
||||
# watcher is not implemented by EventLoopPolicy, e.g. Windows
|
||||
pass
|
||||
else:
|
||||
if isinstance(watcher, asyncio.ThreadedChildWatcher):
|
||||
# Wait for subprocess to finish, but not forever
|
||||
for thread in list(watcher._threads.values()):
|
||||
thread.join(timeout=support.SHORT_TIMEOUT)
|
||||
if thread.is_alive():
|
||||
raise RuntimeError(f"thread {thread} still alive: "
|
||||
"subprocess still running")
|
||||
|
||||
|
||||
def set_event_loop(self, loop, *, cleanup=True):
|
||||
if loop is None:
|
||||
raise AssertionError('loop is None')
|
||||
# ensure that the event loop is passed explicitly in asyncio
|
||||
events.set_event_loop(None)
|
||||
if cleanup:
|
||||
self.addCleanup(self.close_loop, loop)
|
||||
|
||||
def new_test_loop(self, gen=None):
|
||||
loop = TestLoop(gen)
|
||||
self.set_event_loop(loop)
|
||||
return loop
|
||||
|
||||
def setUp(self):
|
||||
self._thread_cleanup = threading_helper.threading_setup()
|
||||
|
||||
def tearDown(self):
|
||||
events.set_event_loop(None)
|
||||
|
||||
# Detect CPython bug #23353: ensure that yield/yield-from is not used
|
||||
# in an except block of a generator
|
||||
self.assertIsNone(sys.exception())
|
||||
|
||||
self.doCleanups()
|
||||
threading_helper.threading_cleanup(*self._thread_cleanup)
|
||||
support.reap_children()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def disable_logger():
|
||||
"""Context manager to disable asyncio logger.
|
||||
|
||||
For example, it can be used to ignore warnings in debug mode.
|
||||
"""
|
||||
old_level = logger.level
|
||||
try:
|
||||
logger.setLevel(logging.CRITICAL+1)
|
||||
yield
|
||||
finally:
|
||||
logger.setLevel(old_level)
|
||||
|
||||
|
||||
def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
|
||||
family=socket.AF_INET):
|
||||
"""Create a mock of a non-blocking socket."""
|
||||
sock = mock.MagicMock(socket.socket)
|
||||
sock.proto = proto
|
||||
sock.type = type
|
||||
sock.family = family
|
||||
sock.gettimeout.return_value = 0.0
|
||||
return sock
|
||||
|
||||
|
||||
async def await_without_task(coro):
|
||||
exc = None
|
||||
def func():
|
||||
try:
|
||||
for _ in coro.__await__():
|
||||
pass
|
||||
except BaseException as err:
|
||||
nonlocal exc
|
||||
exc = err
|
||||
asyncio.get_running_loop().call_soon(func)
|
||||
await asyncio.sleep(0)
|
||||
if exc is not None:
|
||||
raise exc
|
||||
Reference in New Issue
Block a user