Todo: integrate multi-platform support; fix the resource contention caused by SaiNiu thread preemption; add local test-environment packaging plus the release packaging script and the release-environment packaging .bat; commit the Python32 environment package; improve handling of multiple generated log files and refine the packaging log details.

This commit is contained in:
2025-09-18 15:52:03 +08:00
parent 8b9fc925fa
commit 7cfc0c22b7
7608 changed files with 2424791 additions and 25 deletions

View File

@@ -0,0 +1,12 @@
import os
from test import support
from test.support import load_package_tests
from test.support import import_helper
support.requires_working_socket(module=True)
# Skip tests if we don't have concurrent.futures.
import_helper.import_module('concurrent.futures')
def load_tests(*args):
return load_package_tests(os.path.dirname(__file__), *args)

View File

@@ -0,0 +1,4 @@
from . import load_tests
import unittest
unittest.main()

View File

@@ -0,0 +1,8 @@
import os
if __name__ == '__main__':
while True:
buf = os.read(0, 1024)
if not buf:
break
os.write(1, buf)

View File

@@ -0,0 +1,6 @@
import os
if __name__ == '__main__':
buf = os.read(0, 1024)
os.write(1, b'OUT:'+buf)
os.write(2, b'ERR:'+buf)

View File

@@ -0,0 +1,11 @@
import os
if __name__ == '__main__':
while True:
buf = os.read(0, 1024)
if not buf:
break
try:
os.write(1, b'OUT:'+buf)
except OSError as ex:
os.write(2, b'ERR:' + ex.__class__.__name__.encode('ascii'))

View File

@@ -0,0 +1,269 @@
import asyncio
import asyncio.events
import contextlib
import os
import pprint
import select
import socket
import tempfile
import threading
from test import support
class FunctionalTestCaseMixin:
def new_loop(self):
return asyncio.new_event_loop()
def run_loop_briefly(self, *, delay=0.01):
self.loop.run_until_complete(asyncio.sleep(delay))
def loop_exception_handler(self, loop, context):
self.__unhandled_exceptions.append(context)
self.loop.default_exception_handler(context)
def setUp(self):
self.loop = self.new_loop()
asyncio.set_event_loop(None)
self.loop.set_exception_handler(self.loop_exception_handler)
self.__unhandled_exceptions = []
def tearDown(self):
try:
self.loop.close()
if self.__unhandled_exceptions:
print('Unexpected calls to loop.call_exception_handler():')
pprint.pprint(self.__unhandled_exceptions)
self.fail('unexpected calls to loop.call_exception_handler()')
finally:
asyncio.set_event_loop(None)
self.loop = None
def tcp_server(self, server_prog, *,
family=socket.AF_INET,
addr=None,
timeout=support.LOOPBACK_TIMEOUT,
backlog=1,
max_clients=10):
if addr is None:
if hasattr(socket, 'AF_UNIX') and family == socket.AF_UNIX:
with tempfile.NamedTemporaryFile() as tmp:
addr = tmp.name
else:
addr = ('127.0.0.1', 0)
sock = socket.create_server(addr, family=family, backlog=backlog)
if timeout is None:
raise RuntimeError('timeout is required')
if timeout <= 0:
raise RuntimeError('only blocking sockets are supported')
sock.settimeout(timeout)
return TestThreadedServer(
self, sock, server_prog, timeout, max_clients)
def tcp_client(self, client_prog,
family=socket.AF_INET,
timeout=support.LOOPBACK_TIMEOUT):
sock = socket.socket(family, socket.SOCK_STREAM)
if timeout is None:
raise RuntimeError('timeout is required')
if timeout <= 0:
raise RuntimeError('only blocking sockets are supported')
sock.settimeout(timeout)
return TestThreadedClient(
self, sock, client_prog, timeout)
def unix_server(self, *args, **kwargs):
if not hasattr(socket, 'AF_UNIX'):
raise NotImplementedError
return self.tcp_server(*args, family=socket.AF_UNIX, **kwargs)
def unix_client(self, *args, **kwargs):
if not hasattr(socket, 'AF_UNIX'):
raise NotImplementedError
return self.tcp_client(*args, family=socket.AF_UNIX, **kwargs)
@contextlib.contextmanager
def unix_sock_name(self):
with tempfile.TemporaryDirectory() as td:
fn = os.path.join(td, 'sock')
try:
yield fn
finally:
try:
os.unlink(fn)
except OSError:
pass
def _abort_socket_test(self, ex):
try:
self.loop.stop()
finally:
self.fail(ex)
##############################################################################
# Socket Testing Utilities
##############################################################################
class TestSocketWrapper:
def __init__(self, sock):
self.__sock = sock
def recv_all(self, n):
buf = b''
while len(buf) < n:
data = self.recv(n - len(buf))
if data == b'':
raise ConnectionAbortedError
buf += data
return buf
def start_tls(self, ssl_context, *,
server_side=False,
server_hostname=None):
ssl_sock = ssl_context.wrap_socket(
self.__sock, server_side=server_side,
server_hostname=server_hostname,
do_handshake_on_connect=False)
try:
ssl_sock.do_handshake()
except:
ssl_sock.close()
raise
finally:
self.__sock.close()
self.__sock = ssl_sock
def __getattr__(self, name):
return getattr(self.__sock, name)
def __repr__(self):
return '<{} {!r}>'.format(type(self).__name__, self.__sock)
class SocketThread(threading.Thread):
def stop(self):
self._active = False
self.join()
def __enter__(self):
self.start()
return self
def __exit__(self, *exc):
self.stop()
class TestThreadedClient(SocketThread):
def __init__(self, test, sock, prog, timeout):
threading.Thread.__init__(self, None, None, 'test-client')
self.daemon = True
self._timeout = timeout
self._sock = sock
self._active = True
self._prog = prog
self._test = test
def run(self):
try:
self._prog(TestSocketWrapper(self._sock))
except Exception as ex:
self._test._abort_socket_test(ex)
class TestThreadedServer(SocketThread):
def __init__(self, test, sock, prog, timeout, max_clients):
threading.Thread.__init__(self, None, None, 'test-server')
self.daemon = True
self._clients = 0
self._finished_clients = 0
self._max_clients = max_clients
self._timeout = timeout
self._sock = sock
self._active = True
self._prog = prog
self._s1, self._s2 = socket.socketpair()
self._s1.setblocking(False)
self._test = test
def stop(self):
try:
if self._s2 and self._s2.fileno() != -1:
try:
self._s2.send(b'stop')
except OSError:
pass
finally:
super().stop()
def run(self):
try:
with self._sock:
self._sock.setblocking(False)
self._run()
finally:
self._s1.close()
self._s2.close()
def _run(self):
while self._active:
if self._clients >= self._max_clients:
return
r, w, x = select.select(
[self._sock, self._s1], [], [], self._timeout)
if self._s1 in r:
return
if self._sock in r:
try:
conn, addr = self._sock.accept()
except BlockingIOError:
continue
except TimeoutError:
if not self._active:
return
else:
raise
else:
self._clients += 1
conn.settimeout(self._timeout)
try:
with conn:
self._handle_client(conn)
except Exception as ex:
self._active = False
try:
raise
finally:
self._test._abort_socket_test(ex)
def _handle_client(self, sock):
self._prog(TestSocketWrapper(sock))
@property
def addr(self):
return self._sock.getsockname()
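
For context, a usage sketch (not part of this commit) of how a test built on FunctionalTestCaseMixin might pair the threaded server above with asyncio client code. The test name and payloads are hypothetical; it assumes the mixin's setUp() has created self.loop as defined above.

    def test_echo_roundtrip(self):
        # Server callback runs on a background thread with a TestSocketWrapper.
        def server(sock):
            data = sock.recv_all(4)          # blocking read of exactly 4 bytes
            sock.sendall(data.upper())
            sock.close()

        # Client side runs on the asyncio event loop owned by the mixin.
        async def client(addr):
            reader, writer = await asyncio.open_connection(*addr)
            writer.write(b'ping')
            self.assertEqual(await reader.readexactly(4), b'PING')
            writer.close()
            await writer.wait_closed()

        with self.tcp_server(server) as srv:
            self.loop.run_until_complete(client(srv.addr))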

File diff suppressed because it is too large

View File

@@ -0,0 +1,89 @@
import asyncio
import unittest
from test.test_asyncio import functional as func_tests
def tearDownModule():
asyncio.set_event_loop_policy(None)
class ReceiveStuffProto(asyncio.BufferedProtocol):
def __init__(self, cb, con_lost_fut):
self.cb = cb
self.con_lost_fut = con_lost_fut
def get_buffer(self, sizehint):
self.buffer = bytearray(100)
return self.buffer
def buffer_updated(self, nbytes):
self.cb(self.buffer[:nbytes])
def connection_lost(self, exc):
if exc is None:
self.con_lost_fut.set_result(None)
else:
self.con_lost_fut.set_exception(exc)
class BaseTestBufferedProtocol(func_tests.FunctionalTestCaseMixin):
def new_loop(self):
raise NotImplementedError
def test_buffered_proto_create_connection(self):
NOISE = b'12345678+' * 1024
async def client(addr):
data = b''
def on_buf(buf):
nonlocal data
data += buf
if data == NOISE:
tr.write(b'1')
conn_lost_fut = self.loop.create_future()
tr, pr = await self.loop.create_connection(
lambda: ReceiveStuffProto(on_buf, conn_lost_fut), *addr)
await conn_lost_fut
async def on_server_client(reader, writer):
writer.write(NOISE)
await reader.readexactly(1)
writer.close()
await writer.wait_closed()
srv = self.loop.run_until_complete(
asyncio.start_server(
on_server_client, '127.0.0.1', 0))
addr = srv.sockets[0].getsockname()
self.loop.run_until_complete(
asyncio.wait_for(client(addr), 5))
srv.close()
self.loop.run_until_complete(srv.wait_closed())
class BufferedProtocolSelectorTests(BaseTestBufferedProtocol,
unittest.TestCase):
def new_loop(self):
return asyncio.SelectorEventLoop()
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class BufferedProtocolProactorTests(BaseTestBufferedProtocol,
unittest.TestCase):
def new_loop(self):
return asyncio.ProactorEventLoop()
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,38 @@
import asyncio
import decimal
import unittest
def tearDownModule():
asyncio.set_event_loop_policy(None)
@unittest.skipUnless(decimal.HAVE_CONTEXTVAR, "decimal is built with a thread-local context")
class DecimalContextTest(unittest.TestCase):
def test_asyncio_task_decimal_context(self):
async def fractions(t, precision, x, y):
with decimal.localcontext() as ctx:
ctx.prec = precision
a = decimal.Decimal(x) / decimal.Decimal(y)
await asyncio.sleep(t)
b = decimal.Decimal(x) / decimal.Decimal(y ** 2)
return a, b
async def main():
r1, r2 = await asyncio.gather(
fractions(0.1, 3, 1, 3), fractions(0.2, 6, 1, 3))
return r1, r2
r1, r2 = asyncio.run(main())
self.assertEqual(str(r1[0]), '0.333')
self.assertEqual(str(r1[1]), '0.111')
self.assertEqual(str(r2[0]), '0.333333')
self.assertEqual(str(r2[1]), '0.111111')
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,430 @@
"""Tests for base_events.py"""
import asyncio
import contextvars
import unittest
from unittest import mock
from asyncio import tasks
from test.test_asyncio import utils as test_utils
from test.support.script_helper import assert_python_ok
MOCK_ANY = mock.ANY
def tearDownModule():
asyncio.set_event_loop_policy(None)
class EagerTaskFactoryLoopTests:
Task = None
def run_coro(self, coro):
"""
Helper method to run the `coro` coroutine in the test event loop.
It helps with making sure the event loop is running before starting
to execute `coro`. This is important for testing the eager step
functionality, since an eager step is taken only if the event loop
is already running.
"""
async def coro_runner():
self.assertTrue(asyncio.get_event_loop().is_running())
return await coro
        return self.loop.run_until_complete(coro_runner())
def setUp(self):
super().setUp()
self.loop = asyncio.new_event_loop()
self.eager_task_factory = asyncio.create_eager_task_factory(self.Task)
self.loop.set_task_factory(self.eager_task_factory)
self.set_event_loop(self.loop)
def test_eager_task_factory_set(self):
self.assertIsNotNone(self.eager_task_factory)
self.assertIs(self.loop.get_task_factory(), self.eager_task_factory)
async def noop(): pass
async def run():
t = self.loop.create_task(noop())
self.assertIsInstance(t, self.Task)
await t
self.run_coro(run())
def test_await_future_during_eager_step(self):
async def set_result(fut, val):
fut.set_result(val)
async def run():
fut = self.loop.create_future()
t = self.loop.create_task(set_result(fut, 'my message'))
# assert the eager step completed the task
self.assertTrue(t.done())
return await fut
self.assertEqual(self.run_coro(run()), 'my message')
def test_eager_completion(self):
async def coro():
return 'hello'
async def run():
t = self.loop.create_task(coro())
# assert the eager step completed the task
self.assertTrue(t.done())
return await t
self.assertEqual(self.run_coro(run()), 'hello')
def test_block_after_eager_step(self):
async def coro():
await asyncio.sleep(0.1)
return 'finished after blocking'
async def run():
t = self.loop.create_task(coro())
self.assertFalse(t.done())
result = await t
self.assertTrue(t.done())
return result
self.assertEqual(self.run_coro(run()), 'finished after blocking')
def test_cancellation_after_eager_completion(self):
async def coro():
return 'finished without blocking'
async def run():
t = self.loop.create_task(coro())
t.cancel()
result = await t
# finished task can't be cancelled
self.assertFalse(t.cancelled())
return result
self.assertEqual(self.run_coro(run()), 'finished without blocking')
def test_cancellation_after_eager_step_blocks(self):
async def coro():
await asyncio.sleep(0.1)
return 'finished after blocking'
async def run():
t = self.loop.create_task(coro())
t.cancel('cancellation message')
self.assertGreater(t.cancelling(), 0)
result = await t
with self.assertRaises(asyncio.CancelledError) as cm:
self.run_coro(run())
self.assertEqual('cancellation message', cm.exception.args[0])
def test_current_task(self):
captured_current_task = None
async def coro():
nonlocal captured_current_task
captured_current_task = asyncio.current_task()
# verify the task before and after blocking is identical
await asyncio.sleep(0.1)
self.assertIs(asyncio.current_task(), captured_current_task)
async def run():
t = self.loop.create_task(coro())
self.assertIs(captured_current_task, t)
await t
self.run_coro(run())
captured_current_task = None
def test_all_tasks_with_eager_completion(self):
captured_all_tasks = None
async def coro():
nonlocal captured_all_tasks
captured_all_tasks = asyncio.all_tasks()
async def run():
t = self.loop.create_task(coro())
self.assertIn(t, captured_all_tasks)
self.assertNotIn(t, asyncio.all_tasks())
self.run_coro(run())
def test_all_tasks_with_blocking(self):
captured_eager_all_tasks = None
async def coro(fut1, fut2):
nonlocal captured_eager_all_tasks
captured_eager_all_tasks = asyncio.all_tasks()
await fut1
fut2.set_result(None)
async def run():
fut1 = self.loop.create_future()
fut2 = self.loop.create_future()
t = self.loop.create_task(coro(fut1, fut2))
self.assertIn(t, captured_eager_all_tasks)
self.assertIn(t, asyncio.all_tasks())
fut1.set_result(None)
await fut2
self.assertNotIn(t, asyncio.all_tasks())
self.run_coro(run())
def test_context_vars(self):
cv = contextvars.ContextVar('cv', default=0)
coro_first_step_ran = False
coro_second_step_ran = False
async def coro():
nonlocal coro_first_step_ran
nonlocal coro_second_step_ran
self.assertEqual(cv.get(), 1)
cv.set(2)
self.assertEqual(cv.get(), 2)
coro_first_step_ran = True
await asyncio.sleep(0.1)
self.assertEqual(cv.get(), 2)
cv.set(3)
self.assertEqual(cv.get(), 3)
coro_second_step_ran = True
async def run():
cv.set(1)
t = self.loop.create_task(coro())
self.assertTrue(coro_first_step_ran)
self.assertFalse(coro_second_step_ran)
self.assertEqual(cv.get(), 1)
await t
self.assertTrue(coro_second_step_ran)
self.assertEqual(cv.get(), 1)
self.run_coro(run())
def test_staggered_race_with_eager_tasks(self):
# See https://github.com/python/cpython/issues/124309
async def fail():
await asyncio.sleep(0)
raise ValueError("no good")
async def blocked():
fut = asyncio.Future()
await fut
async def run():
winner, index, excs = await asyncio.staggered.staggered_race(
[
lambda: blocked(),
lambda: asyncio.sleep(1, result="sleep1"),
lambda: fail()
],
delay=0.25
)
self.assertEqual(winner, 'sleep1')
self.assertEqual(index, 1)
self.assertIsNone(excs[index])
self.assertIsInstance(excs[0], asyncio.CancelledError)
self.assertIsInstance(excs[2], ValueError)
self.run_coro(run())
def test_staggered_race_with_eager_tasks_no_delay(self):
# See https://github.com/python/cpython/issues/124309
async def fail():
raise ValueError("no good")
async def run():
winner, index, excs = await asyncio.staggered.staggered_race(
[
lambda: fail(),
lambda: asyncio.sleep(1, result="sleep1"),
lambda: asyncio.sleep(0, result="sleep0"),
],
delay=None
)
self.assertEqual(winner, 'sleep1')
self.assertEqual(index, 1)
self.assertIsNone(excs[index])
self.assertIsInstance(excs[0], ValueError)
self.assertEqual(len(excs), 2)
self.run_coro(run())
class PyEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
Task = tasks._PyTask
@unittest.skipUnless(hasattr(tasks, '_CTask'),
'requires the C _asyncio module')
class CEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase):
Task = getattr(tasks, '_CTask', None)
def test_issue105987(self):
code = """if 1:
from _asyncio import _swap_current_task
class DummyTask:
pass
class DummyLoop:
pass
l = DummyLoop()
_swap_current_task(l, DummyTask())
t = _swap_current_task(l, None)
"""
_, out, err = assert_python_ok("-c", code)
self.assertFalse(err)
def test_issue122332(self):
async def coro():
pass
async def run():
task = self.loop.create_task(coro())
await task
self.assertIsNone(task.get_coro())
self.run_coro(run())
def test_name(self):
name = None
async def coro():
nonlocal name
name = asyncio.current_task().get_name()
async def main():
task = self.loop.create_task(coro(), name="test name")
self.assertEqual(name, "test name")
await task
        self.run_coro(main())
class AsyncTaskCounter:
def __init__(self, loop, *, task_class, eager):
self.suspense_count = 0
self.task_count = 0
def CountingTask(*args, eager_start=False, **kwargs):
if not eager_start:
self.task_count += 1
kwargs["eager_start"] = eager_start
return task_class(*args, **kwargs)
if eager:
factory = asyncio.create_eager_task_factory(CountingTask)
else:
def factory(loop, coro, **kwargs):
return CountingTask(coro, loop=loop, **kwargs)
loop.set_task_factory(factory)
def get(self):
return self.task_count
async def awaitable_chain(depth):
if depth == 0:
return 0
return 1 + await awaitable_chain(depth - 1)
async def recursive_taskgroups(width, depth):
if depth == 0:
return
async with asyncio.TaskGroup() as tg:
futures = [
tg.create_task(recursive_taskgroups(width, depth - 1))
for _ in range(width)
]
async def recursive_gather(width, depth):
if depth == 0:
return
await asyncio.gather(
*[recursive_gather(width, depth - 1) for _ in range(width)]
)
class BaseTaskCountingTests:
Task = None
eager = None
expected_task_count = None
def setUp(self):
super().setUp()
self.loop = asyncio.new_event_loop()
self.counter = AsyncTaskCounter(self.loop, task_class=self.Task, eager=self.eager)
self.set_event_loop(self.loop)
def test_awaitables_chain(self):
observed_depth = self.loop.run_until_complete(awaitable_chain(100))
self.assertEqual(observed_depth, 100)
self.assertEqual(self.counter.get(), 0 if self.eager else 1)
def test_recursive_taskgroups(self):
num_tasks = self.loop.run_until_complete(recursive_taskgroups(5, 4))
self.assertEqual(self.counter.get(), self.expected_task_count)
def test_recursive_gather(self):
self.loop.run_until_complete(recursive_gather(5, 4))
self.assertEqual(self.counter.get(), self.expected_task_count)
class BaseNonEagerTaskFactoryTests(BaseTaskCountingTests):
eager = False
expected_task_count = 781 # 1 + 5 + 5^2 + 5^3 + 5^4
class BaseEagerTaskFactoryTests(BaseTaskCountingTests):
eager = True
expected_task_count = 0
class NonEagerTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
Task = asyncio.Task
class EagerTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
Task = asyncio.Task
class NonEagerPyTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
Task = tasks._PyTask
class EagerPyTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
Task = tasks._PyTask
@unittest.skipUnless(hasattr(tasks, '_CTask'),
'requires the C _asyncio module')
class NonEagerCTaskTests(BaseNonEagerTaskFactoryTests, test_utils.TestCase):
Task = getattr(tasks, '_CTask', None)
@unittest.skipUnless(hasattr(tasks, '_CTask'),
'requires the C _asyncio module')
class EagerCTaskTests(BaseEagerTaskFactoryTests, test_utils.TestCase):
Task = getattr(tasks, '_CTask', None)
if __name__ == '__main__':
unittest.main()
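
As a brief illustration of the eager-step behavior these tests exercise, here is a minimal sketch (not part of the commit, assuming Python 3.12+ where asyncio.create_eager_task_factory is available): when the loop is already running, a task whose coroutine finishes without awaiting anything pending is completed inside create_task() itself.

import asyncio

async def main():
    loop = asyncio.get_running_loop()
    # Install the eager task factory, as the tests above do in setUp().
    loop.set_task_factory(asyncio.create_eager_task_factory(asyncio.Task))

    async def coro():
        return 'hello'

    t = loop.create_task(coro())
    print(t.done())   # True: the eager first step already ran coro() to completion
    print(await t)    # 'hello'

asyncio.run(main())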

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,95 @@
# IsolatedAsyncioTestCase based tests
import asyncio
import contextvars
import traceback
import unittest
from asyncio import tasks
def tearDownModule():
asyncio.set_event_loop_policy(None)
class FutureTests:
async def test_future_traceback(self):
async def raise_exc():
raise TypeError(42)
future = self.cls(raise_exc())
for _ in range(5):
try:
await future
except TypeError as e:
tb = ''.join(traceback.format_tb(e.__traceback__))
self.assertEqual(tb.count("await future"), 1)
else:
self.fail('TypeError was not raised')
async def test_task_exc_handler_correct_context(self):
# see https://github.com/python/cpython/issues/96704
name = contextvars.ContextVar('name', default='foo')
exc_handler_called = False
def exc_handler(*args):
self.assertEqual(name.get(), 'bar')
nonlocal exc_handler_called
exc_handler_called = True
async def task():
name.set('bar')
1/0
loop = asyncio.get_running_loop()
loop.set_exception_handler(exc_handler)
self.cls(task())
await asyncio.sleep(0)
self.assertTrue(exc_handler_called)
async def test_handle_exc_handler_correct_context(self):
# see https://github.com/python/cpython/issues/96704
name = contextvars.ContextVar('name', default='foo')
exc_handler_called = False
def exc_handler(*args):
self.assertEqual(name.get(), 'bar')
nonlocal exc_handler_called
exc_handler_called = True
def callback():
name.set('bar')
1/0
loop = asyncio.get_running_loop()
loop.set_exception_handler(exc_handler)
loop.call_soon(callback)
await asyncio.sleep(0)
self.assertTrue(exc_handler_called)
@unittest.skipUnless(hasattr(tasks, '_CTask'),
'requires the C _asyncio module')
class CFutureTests(FutureTests, unittest.IsolatedAsyncioTestCase):
cls = tasks._CTask
class PyFutureTests(FutureTests, unittest.IsolatedAsyncioTestCase):
cls = tasks._PyTask
class FutureReprTests(unittest.IsolatedAsyncioTestCase):
async def test_recursive_repr_for_pending_tasks(self):
# The call crashes if the guard for recursive call
# in base_futures:_future_repr_info is absent
# See Also: https://bugs.python.org/issue42183
async def func():
return asyncio.all_tasks()
# The repr() call should not raise RecursionError at first.
waiter = await asyncio.wait_for(asyncio.Task(func()),timeout=10)
self.assertIn('...', repr(waiter))
if __name__ == '__main__':
unittest.main()

File diff suppressed because it is too large

View File

@@ -0,0 +1,212 @@
"""Tests support for new syntax introduced by PEP 492."""
import sys
import types
import unittest
from unittest import mock
import asyncio
from test.test_asyncio import utils as test_utils
def tearDownModule():
asyncio.set_event_loop_policy(None)
# Test that asyncio.iscoroutine() uses collections.abc.Coroutine
class FakeCoro:
def send(self, value):
pass
def throw(self, typ, val=None, tb=None):
pass
def close(self):
pass
def __await__(self):
yield
class BaseTest(test_utils.TestCase):
def setUp(self):
super().setUp()
self.loop = asyncio.BaseEventLoop()
self.loop._process_events = mock.Mock()
self.loop._selector = mock.Mock()
self.loop._selector.select.return_value = ()
self.set_event_loop(self.loop)
class LockTests(BaseTest):
def test_context_manager_async_with(self):
primitives = [
asyncio.Lock(),
asyncio.Condition(),
asyncio.Semaphore(),
asyncio.BoundedSemaphore(),
]
async def test(lock):
await asyncio.sleep(0.01)
self.assertFalse(lock.locked())
async with lock as _lock:
self.assertIs(_lock, None)
self.assertTrue(lock.locked())
await asyncio.sleep(0.01)
self.assertTrue(lock.locked())
self.assertFalse(lock.locked())
for primitive in primitives:
self.loop.run_until_complete(test(primitive))
self.assertFalse(primitive.locked())
def test_context_manager_with_await(self):
primitives = [
asyncio.Lock(),
asyncio.Condition(),
asyncio.Semaphore(),
asyncio.BoundedSemaphore(),
]
async def test(lock):
await asyncio.sleep(0.01)
self.assertFalse(lock.locked())
with self.assertRaisesRegex(
TypeError,
"can't be used in 'await' expression"
):
with await lock:
pass
for primitive in primitives:
self.loop.run_until_complete(test(primitive))
self.assertFalse(primitive.locked())
class StreamReaderTests(BaseTest):
def test_readline(self):
DATA = b'line1\nline2\nline3'
stream = asyncio.StreamReader(loop=self.loop)
stream.feed_data(DATA)
stream.feed_eof()
async def reader():
data = []
async for line in stream:
data.append(line)
return data
data = self.loop.run_until_complete(reader())
self.assertEqual(data, [b'line1\n', b'line2\n', b'line3'])
class CoroutineTests(BaseTest):
def test_iscoroutine(self):
async def foo(): pass
f = foo()
try:
self.assertTrue(asyncio.iscoroutine(f))
finally:
f.close() # silence warning
self.assertTrue(asyncio.iscoroutine(FakeCoro()))
def test_iscoroutine_generator(self):
def foo(): yield
self.assertFalse(asyncio.iscoroutine(foo()))
def test_iscoroutinefunction(self):
async def foo(): pass
self.assertTrue(asyncio.iscoroutinefunction(foo))
def test_async_def_coroutines(self):
async def bar():
return 'spam'
async def foo():
return await bar()
# production mode
data = self.loop.run_until_complete(foo())
self.assertEqual(data, 'spam')
# debug mode
self.loop.set_debug(True)
data = self.loop.run_until_complete(foo())
self.assertEqual(data, 'spam')
def test_debug_mode_manages_coroutine_origin_tracking(self):
async def start():
self.assertTrue(sys.get_coroutine_origin_tracking_depth() > 0)
self.assertEqual(sys.get_coroutine_origin_tracking_depth(), 0)
self.loop.set_debug(True)
self.loop.run_until_complete(start())
self.assertEqual(sys.get_coroutine_origin_tracking_depth(), 0)
def test_types_coroutine(self):
def gen():
yield from ()
return 'spam'
@types.coroutine
def func():
return gen()
async def coro():
wrapper = func()
self.assertIsInstance(wrapper, types._GeneratorWrapper)
return await wrapper
data = self.loop.run_until_complete(coro())
self.assertEqual(data, 'spam')
def test_task_print_stack(self):
T = None
async def foo():
f = T.get_stack(limit=1)
try:
self.assertEqual(f[0].f_code.co_name, 'foo')
finally:
f = None
async def runner():
nonlocal T
T = asyncio.ensure_future(foo(), loop=self.loop)
await T
self.loop.run_until_complete(runner())
def test_double_await(self):
async def afunc():
await asyncio.sleep(0.1)
async def runner():
coro = afunc()
t = self.loop.create_task(coro)
try:
await asyncio.sleep(0)
await coro
finally:
t.cancel()
self.loop.set_debug(True)
with self.assertRaises(
RuntimeError,
msg='coroutine is being awaited already'):
self.loop.run_until_complete(runner())
if __name__ == '__main__':
unittest.main()

File diff suppressed because it is too large

View File

@@ -0,0 +1,67 @@
import unittest
from unittest import mock
import asyncio
def tearDownModule():
    # not needed for this test file, but added for uniformity with all other
    # asyncio test files for the sake of unified cleanup
asyncio.set_event_loop_policy(None)
class ProtocolsAbsTests(unittest.TestCase):
def test_base_protocol(self):
f = mock.Mock()
p = asyncio.BaseProtocol()
self.assertIsNone(p.connection_made(f))
self.assertIsNone(p.connection_lost(f))
self.assertIsNone(p.pause_writing())
self.assertIsNone(p.resume_writing())
self.assertFalse(hasattr(p, '__dict__'))
def test_protocol(self):
f = mock.Mock()
p = asyncio.Protocol()
self.assertIsNone(p.connection_made(f))
self.assertIsNone(p.connection_lost(f))
self.assertIsNone(p.data_received(f))
self.assertIsNone(p.eof_received())
self.assertIsNone(p.pause_writing())
self.assertIsNone(p.resume_writing())
self.assertFalse(hasattr(p, '__dict__'))
def test_buffered_protocol(self):
f = mock.Mock()
p = asyncio.BufferedProtocol()
self.assertIsNone(p.connection_made(f))
self.assertIsNone(p.connection_lost(f))
self.assertIsNone(p.get_buffer(100))
self.assertIsNone(p.buffer_updated(150))
self.assertIsNone(p.pause_writing())
self.assertIsNone(p.resume_writing())
self.assertFalse(hasattr(p, '__dict__'))
def test_datagram_protocol(self):
f = mock.Mock()
dp = asyncio.DatagramProtocol()
self.assertIsNone(dp.connection_made(f))
self.assertIsNone(dp.connection_lost(f))
self.assertIsNone(dp.error_received(f))
self.assertIsNone(dp.datagram_received(f, f))
self.assertFalse(hasattr(dp, '__dict__'))
def test_subprocess_protocol(self):
f = mock.Mock()
sp = asyncio.SubprocessProtocol()
self.assertIsNone(sp.connection_made(f))
self.assertIsNone(sp.connection_lost(f))
self.assertIsNone(sp.pipe_data_received(1, f))
self.assertIsNone(sp.pipe_connection_lost(1, f))
self.assertIsNone(sp.process_exited())
self.assertFalse(hasattr(sp, '__dict__'))
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,725 @@
"""Tests for queues.py"""
import asyncio
import unittest
from types import GenericAlias
def tearDownModule():
asyncio.set_event_loop_policy(None)
class QueueBasicTests(unittest.IsolatedAsyncioTestCase):
async def _test_repr_or_str(self, fn, expect_id):
"""Test Queue's repr or str.
fn is repr or str. expect_id is True if we expect the Queue's id to
appear in fn(Queue()).
"""
q = asyncio.Queue()
self.assertTrue(fn(q).startswith('<Queue'), fn(q))
id_is_present = hex(id(q)) in fn(q)
self.assertEqual(expect_id, id_is_present)
# getters
q = asyncio.Queue()
async with asyncio.TaskGroup() as tg:
# Start a task that waits to get.
getter = tg.create_task(q.get())
# Let it start waiting.
await asyncio.sleep(0)
self.assertTrue('_getters[1]' in fn(q))
# resume q.get coroutine to finish generator
q.put_nowait(0)
self.assertEqual(0, await getter)
# putters
q = asyncio.Queue(maxsize=1)
async with asyncio.TaskGroup() as tg:
q.put_nowait(1)
# Start a task that waits to put.
putter = tg.create_task(q.put(2))
# Let it start waiting.
await asyncio.sleep(0)
self.assertTrue('_putters[1]' in fn(q))
# resume q.put coroutine to finish generator
q.get_nowait()
self.assertTrue(putter.done())
q = asyncio.Queue()
q.put_nowait(1)
self.assertTrue('_queue=[1]' in fn(q))
async def test_repr(self):
await self._test_repr_or_str(repr, True)
async def test_str(self):
await self._test_repr_or_str(str, False)
def test_generic_alias(self):
q = asyncio.Queue[int]
self.assertEqual(q.__args__, (int,))
self.assertIsInstance(q, GenericAlias)
async def test_empty(self):
q = asyncio.Queue()
self.assertTrue(q.empty())
await q.put(1)
self.assertFalse(q.empty())
self.assertEqual(1, await q.get())
self.assertTrue(q.empty())
async def test_full(self):
q = asyncio.Queue()
self.assertFalse(q.full())
q = asyncio.Queue(maxsize=1)
await q.put(1)
self.assertTrue(q.full())
async def test_order(self):
q = asyncio.Queue()
for i in [1, 3, 2]:
await q.put(i)
items = [await q.get() for _ in range(3)]
self.assertEqual([1, 3, 2], items)
async def test_maxsize(self):
q = asyncio.Queue(maxsize=2)
self.assertEqual(2, q.maxsize)
have_been_put = []
async def putter():
for i in range(3):
await q.put(i)
have_been_put.append(i)
return True
t = asyncio.create_task(putter())
for i in range(2):
await asyncio.sleep(0)
# The putter is blocked after putting two items.
self.assertEqual([0, 1], have_been_put)
self.assertEqual(0, await q.get())
# Let the putter resume and put last item.
await asyncio.sleep(0)
self.assertEqual([0, 1, 2], have_been_put)
self.assertEqual(1, await q.get())
self.assertEqual(2, await q.get())
self.assertTrue(t.done())
self.assertTrue(t.result())
class QueueGetTests(unittest.IsolatedAsyncioTestCase):
async def test_blocking_get(self):
q = asyncio.Queue()
q.put_nowait(1)
self.assertEqual(1, await q.get())
async def test_get_with_putters(self):
loop = asyncio.get_running_loop()
q = asyncio.Queue(1)
await q.put(1)
waiter = loop.create_future()
q._putters.append(waiter)
self.assertEqual(1, await q.get())
self.assertTrue(waiter.done())
self.assertIsNone(waiter.result())
async def test_blocking_get_wait(self):
loop = asyncio.get_running_loop()
q = asyncio.Queue()
started = asyncio.Event()
finished = False
async def queue_get():
nonlocal finished
started.set()
res = await q.get()
finished = True
return res
queue_get_task = asyncio.create_task(queue_get())
await started.wait()
self.assertFalse(finished)
loop.call_later(0.01, q.put_nowait, 1)
res = await queue_get_task
self.assertTrue(finished)
self.assertEqual(1, res)
def test_nonblocking_get(self):
q = asyncio.Queue()
q.put_nowait(1)
self.assertEqual(1, q.get_nowait())
def test_nonblocking_get_exception(self):
q = asyncio.Queue()
self.assertRaises(asyncio.QueueEmpty, q.get_nowait)
async def test_get_cancelled_race(self):
q = asyncio.Queue()
t1 = asyncio.create_task(q.get())
t2 = asyncio.create_task(q.get())
await asyncio.sleep(0)
t1.cancel()
await asyncio.sleep(0)
self.assertTrue(t1.done())
await q.put('a')
await asyncio.sleep(0)
self.assertEqual('a', await t2)
async def test_get_with_waiting_putters(self):
q = asyncio.Queue(maxsize=1)
asyncio.create_task(q.put('a'))
asyncio.create_task(q.put('b'))
self.assertEqual(await q.get(), 'a')
self.assertEqual(await q.get(), 'b')
async def test_why_are_getters_waiting(self):
async def consumer(queue, num_expected):
for _ in range(num_expected):
await queue.get()
async def producer(queue, num_items):
for i in range(num_items):
await queue.put(i)
producer_num_items = 5
q = asyncio.Queue(1)
async with asyncio.TaskGroup() as tg:
tg.create_task(producer(q, producer_num_items))
tg.create_task(consumer(q, producer_num_items))
async def test_cancelled_getters_not_being_held_in_self_getters(self):
queue = asyncio.Queue(maxsize=5)
with self.assertRaises(TimeoutError):
await asyncio.wait_for(queue.get(), 0.1)
self.assertEqual(len(queue._getters), 0)
class QueuePutTests(unittest.IsolatedAsyncioTestCase):
async def test_blocking_put(self):
q = asyncio.Queue()
# No maxsize, won't block.
await q.put(1)
self.assertEqual(1, await q.get())
async def test_blocking_put_wait(self):
q = asyncio.Queue(maxsize=1)
started = asyncio.Event()
finished = False
async def queue_put():
nonlocal finished
started.set()
await q.put(1)
await q.put(2)
finished = True
loop = asyncio.get_running_loop()
loop.call_later(0.01, q.get_nowait)
queue_put_task = asyncio.create_task(queue_put())
await started.wait()
self.assertFalse(finished)
await queue_put_task
self.assertTrue(finished)
def test_nonblocking_put(self):
q = asyncio.Queue()
q.put_nowait(1)
self.assertEqual(1, q.get_nowait())
async def test_get_cancel_drop_one_pending_reader(self):
q = asyncio.Queue()
reader = asyncio.create_task(q.get())
await asyncio.sleep(0)
q.put_nowait(1)
q.put_nowait(2)
reader.cancel()
try:
await reader
except asyncio.CancelledError:
# try again
reader = asyncio.create_task(q.get())
await reader
result = reader.result()
# if we get 2, it means 1 got dropped!
self.assertEqual(1, result)
async def test_get_cancel_drop_many_pending_readers(self):
q = asyncio.Queue()
async with asyncio.TaskGroup() as tg:
reader1 = tg.create_task(q.get())
reader2 = tg.create_task(q.get())
reader3 = tg.create_task(q.get())
await asyncio.sleep(0)
q.put_nowait(1)
q.put_nowait(2)
reader1.cancel()
with self.assertRaises(asyncio.CancelledError):
await reader1
await reader3
# It is undefined in which order concurrent readers receive results.
self.assertEqual({reader2.result(), reader3.result()}, {1, 2})
async def test_put_cancel_drop(self):
q = asyncio.Queue(1)
q.put_nowait(1)
# putting a second item in the queue has to block (qsize=1)
writer = asyncio.create_task(q.put(2))
await asyncio.sleep(0)
value1 = q.get_nowait()
self.assertEqual(value1, 1)
writer.cancel()
try:
await writer
except asyncio.CancelledError:
# try again
writer = asyncio.create_task(q.put(2))
await writer
value2 = q.get_nowait()
self.assertEqual(value2, 2)
self.assertEqual(q.qsize(), 0)
def test_nonblocking_put_exception(self):
q = asyncio.Queue(maxsize=1, )
q.put_nowait(1)
self.assertRaises(asyncio.QueueFull, q.put_nowait, 2)
async def test_float_maxsize(self):
q = asyncio.Queue(maxsize=1.3, )
q.put_nowait(1)
q.put_nowait(2)
self.assertTrue(q.full())
self.assertRaises(asyncio.QueueFull, q.put_nowait, 3)
q = asyncio.Queue(maxsize=1.3, )
await q.put(1)
await q.put(2)
self.assertTrue(q.full())
async def test_put_cancelled(self):
q = asyncio.Queue()
async def queue_put():
await q.put(1)
return True
t = asyncio.create_task(queue_put())
self.assertEqual(1, await q.get())
self.assertTrue(t.done())
self.assertTrue(t.result())
async def test_put_cancelled_race(self):
q = asyncio.Queue(maxsize=1)
put_a = asyncio.create_task(q.put('a'))
put_b = asyncio.create_task(q.put('b'))
put_c = asyncio.create_task(q.put('X'))
await asyncio.sleep(0)
self.assertTrue(put_a.done())
self.assertFalse(put_b.done())
put_c.cancel()
await asyncio.sleep(0)
self.assertTrue(put_c.done())
self.assertEqual(q.get_nowait(), 'a')
await asyncio.sleep(0)
self.assertEqual(q.get_nowait(), 'b')
await put_b
async def test_put_with_waiting_getters(self):
q = asyncio.Queue()
t = asyncio.create_task(q.get())
await asyncio.sleep(0)
await q.put('a')
self.assertEqual(await t, 'a')
async def test_why_are_putters_waiting(self):
queue = asyncio.Queue(2)
async def putter(item):
await queue.put(item)
async def getter():
await asyncio.sleep(0)
num = queue.qsize()
for _ in range(num):
queue.get_nowait()
async with asyncio.TaskGroup() as tg:
tg.create_task(getter())
tg.create_task(putter(0))
tg.create_task(putter(1))
tg.create_task(putter(2))
tg.create_task(putter(3))
async def test_cancelled_puts_not_being_held_in_self_putters(self):
# Full queue.
queue = asyncio.Queue(maxsize=1)
queue.put_nowait(1)
# Task waiting for space to put an item in the queue.
put_task = asyncio.create_task(queue.put(1))
await asyncio.sleep(0)
# Check that the putter is correctly removed from queue._putters when
# the task is canceled.
self.assertEqual(len(queue._putters), 1)
put_task.cancel()
with self.assertRaises(asyncio.CancelledError):
await put_task
self.assertEqual(len(queue._putters), 0)
async def test_cancelled_put_silence_value_error_exception(self):
# Full Queue.
queue = asyncio.Queue(1)
queue.put_nowait(1)
        # Task waiting for space to put an item in the queue.
put_task = asyncio.create_task(queue.put(1))
await asyncio.sleep(0)
# get_nowait() remove the future of put_task from queue._putters.
queue.get_nowait()
# When canceled, queue.put is going to remove its future from
# self._putters but it was removed previously by queue.get_nowait().
put_task.cancel()
# The ValueError exception triggered by queue._putters.remove(putter)
# inside queue.put should be silenced.
# If the ValueError is silenced we should catch a CancelledError.
with self.assertRaises(asyncio.CancelledError):
await put_task
class LifoQueueTests(unittest.IsolatedAsyncioTestCase):
async def test_order(self):
q = asyncio.LifoQueue()
for i in [1, 3, 2]:
await q.put(i)
items = [await q.get() for _ in range(3)]
self.assertEqual([2, 3, 1], items)
class PriorityQueueTests(unittest.IsolatedAsyncioTestCase):
async def test_order(self):
q = asyncio.PriorityQueue()
for i in [1, 3, 2]:
await q.put(i)
items = [await q.get() for _ in range(3)]
self.assertEqual([1, 2, 3], items)
class _QueueJoinTestMixin:
q_class = None
def test_task_done_underflow(self):
q = self.q_class()
self.assertRaises(ValueError, q.task_done)
async def test_task_done(self):
q = self.q_class()
for i in range(100):
q.put_nowait(i)
accumulator = 0
# Two workers get items from the queue and call task_done after each.
# Join the queue and assert all items have been processed.
running = True
async def worker():
nonlocal accumulator
while running:
item = await q.get()
accumulator += item
q.task_done()
async with asyncio.TaskGroup() as tg:
tasks = [tg.create_task(worker())
for index in range(2)]
await q.join()
self.assertEqual(sum(range(100)), accumulator)
# close running generators
running = False
for i in range(len(tasks)):
q.put_nowait(0)
async def test_join_empty_queue(self):
q = self.q_class()
# Test that a queue join()s successfully, and before anything else
# (done twice for insurance).
await q.join()
await q.join()
async def test_format(self):
q = self.q_class()
self.assertEqual(q._format(), 'maxsize=0')
q._unfinished_tasks = 2
self.assertEqual(q._format(), 'maxsize=0 tasks=2')
class QueueJoinTests(_QueueJoinTestMixin, unittest.IsolatedAsyncioTestCase):
q_class = asyncio.Queue
class LifoQueueJoinTests(_QueueJoinTestMixin, unittest.IsolatedAsyncioTestCase):
q_class = asyncio.LifoQueue
class PriorityQueueJoinTests(_QueueJoinTestMixin, unittest.IsolatedAsyncioTestCase):
q_class = asyncio.PriorityQueue
class _QueueShutdownTestMixin:
q_class = None
def assertRaisesShutdown(self, msg="Didn't appear to shut-down queue"):
return self.assertRaises(asyncio.QueueShutDown, msg=msg)
async def test_format(self):
q = self.q_class()
q.shutdown()
self.assertEqual(q._format(), 'maxsize=0 shutdown')
async def test_shutdown_empty(self):
# Test shutting down an empty queue
# Setup empty queue, and join() and get() tasks
q = self.q_class()
loop = asyncio.get_running_loop()
get_task = loop.create_task(q.get())
await asyncio.sleep(0) # want get task pending before shutdown
# Perform shut-down
q.shutdown(immediate=False) # unfinished tasks: 0 -> 0
self.assertEqual(q.qsize(), 0)
# Ensure join() task successfully finishes
await q.join()
# Ensure get() task is finished, and raised ShutDown
await asyncio.sleep(0)
self.assertTrue(get_task.done())
with self.assertRaisesShutdown():
await get_task
# Ensure put() and get() raise ShutDown
with self.assertRaisesShutdown():
await q.put("data")
with self.assertRaisesShutdown():
q.put_nowait("data")
with self.assertRaisesShutdown():
await q.get()
with self.assertRaisesShutdown():
q.get_nowait()
async def test_shutdown_nonempty(self):
# Test shutting down a non-empty queue
# Setup full queue with 1 item, and join() and put() tasks
q = self.q_class(maxsize=1)
loop = asyncio.get_running_loop()
q.put_nowait("data")
join_task = loop.create_task(q.join())
put_task = loop.create_task(q.put("data2"))
# Ensure put() task is not finished
await asyncio.sleep(0)
self.assertFalse(put_task.done())
# Perform shut-down
q.shutdown(immediate=False) # unfinished tasks: 1 -> 1
self.assertEqual(q.qsize(), 1)
# Ensure put() task is finished, and raised ShutDown
await asyncio.sleep(0)
self.assertTrue(put_task.done())
with self.assertRaisesShutdown():
await put_task
# Ensure get() succeeds on enqueued item
self.assertEqual(await q.get(), "data")
# Ensure join() task is not finished
await asyncio.sleep(0)
self.assertFalse(join_task.done())
# Ensure put() and get() raise ShutDown
with self.assertRaisesShutdown():
await q.put("data")
with self.assertRaisesShutdown():
q.put_nowait("data")
with self.assertRaisesShutdown():
await q.get()
with self.assertRaisesShutdown():
q.get_nowait()
# Ensure there is 1 unfinished task, and join() task succeeds
q.task_done()
await asyncio.sleep(0)
self.assertTrue(join_task.done())
await join_task
with self.assertRaises(
ValueError, msg="Didn't appear to mark all tasks done"
):
q.task_done()
async def test_shutdown_immediate(self):
# Test immediately shutting down a queue
# Setup queue with 1 item, and a join() task
q = self.q_class()
loop = asyncio.get_running_loop()
q.put_nowait("data")
join_task = loop.create_task(q.join())
# Perform shut-down
q.shutdown(immediate=True) # unfinished tasks: 1 -> 0
self.assertEqual(q.qsize(), 0)
# Ensure join() task has successfully finished
await asyncio.sleep(0)
self.assertTrue(join_task.done())
await join_task
# Ensure put() and get() raise ShutDown
with self.assertRaisesShutdown():
await q.put("data")
with self.assertRaisesShutdown():
q.put_nowait("data")
with self.assertRaisesShutdown():
await q.get()
with self.assertRaisesShutdown():
q.get_nowait()
# Ensure there are no unfinished tasks
with self.assertRaises(
ValueError, msg="Didn't appear to mark all tasks done"
):
q.task_done()
async def test_shutdown_immediate_with_unfinished(self):
# Test immediately shutting down a queue with unfinished tasks
# Setup queue with 2 items (1 retrieved), and a join() task
q = self.q_class()
loop = asyncio.get_running_loop()
q.put_nowait("data")
q.put_nowait("data")
join_task = loop.create_task(q.join())
self.assertEqual(await q.get(), "data")
# Perform shut-down
q.shutdown(immediate=True) # unfinished tasks: 2 -> 1
self.assertEqual(q.qsize(), 0)
# Ensure join() task is not finished
await asyncio.sleep(0)
self.assertFalse(join_task.done())
# Ensure put() and get() raise ShutDown
with self.assertRaisesShutdown():
await q.put("data")
with self.assertRaisesShutdown():
q.put_nowait("data")
with self.assertRaisesShutdown():
await q.get()
with self.assertRaisesShutdown():
q.get_nowait()
# Ensure there is 1 unfinished task
q.task_done()
with self.assertRaises(
ValueError, msg="Didn't appear to mark all tasks done"
):
q.task_done()
# Ensure join() task has successfully finished
await asyncio.sleep(0)
self.assertTrue(join_task.done())
await join_task
class QueueShutdownTests(
_QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase
):
q_class = asyncio.Queue
class LifoQueueShutdownTests(
_QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase
):
q_class = asyncio.LifoQueue
class PriorityQueueShutdownTests(
_QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase
):
q_class = asyncio.PriorityQueue
if __name__ == '__main__':
unittest.main()
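
For readers new to the queue shutdown API exercised by the tests above, a minimal sketch (not part of the commit, assuming Python 3.13+ where Queue.shutdown() and asyncio.QueueShutDown exist):

import asyncio

async def demo():
    q = asyncio.Queue()
    q.put_nowait('data')
    # immediate=True drains the queue and wakes every blocked getter/putter.
    q.shutdown(immediate=True)
    try:
        q.get_nowait()
    except asyncio.QueueShutDown:
        print('queue is shut down')

asyncio.run(demo())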

View File

@@ -0,0 +1,518 @@
import _thread
import asyncio
import contextvars
import re
import signal
import sys
import threading
import unittest
from test.test_asyncio import utils as test_utils
from unittest import mock
from unittest.mock import patch
def tearDownModule():
asyncio.set_event_loop_policy(None)
def interrupt_self():
_thread.interrupt_main()
class TestPolicy(asyncio.AbstractEventLoopPolicy):
def __init__(self, loop_factory):
self.loop_factory = loop_factory
self.loop = None
def get_event_loop(self):
# shouldn't ever be called by asyncio.run()
raise RuntimeError
def new_event_loop(self):
return self.loop_factory()
def set_event_loop(self, loop):
if loop is not None:
# we want to check if the loop is closed
# in BaseTest.tearDown
self.loop = loop
class BaseTest(unittest.TestCase):
def new_loop(self):
loop = asyncio.BaseEventLoop()
loop._process_events = mock.Mock()
# Mock waking event loop from select
loop._write_to_self = mock.Mock()
loop._write_to_self.return_value = None
loop._selector = mock.Mock()
loop._selector.select.return_value = ()
loop.shutdown_ag_run = False
async def shutdown_asyncgens():
loop.shutdown_ag_run = True
loop.shutdown_asyncgens = shutdown_asyncgens
return loop
def setUp(self):
super().setUp()
policy = TestPolicy(self.new_loop)
asyncio.set_event_loop_policy(policy)
def tearDown(self):
policy = asyncio.get_event_loop_policy()
if policy.loop is not None:
self.assertTrue(policy.loop.is_closed())
self.assertTrue(policy.loop.shutdown_ag_run)
asyncio.set_event_loop_policy(None)
super().tearDown()
class RunTests(BaseTest):
def test_asyncio_run_return(self):
async def main():
await asyncio.sleep(0)
return 42
self.assertEqual(asyncio.run(main()), 42)
def test_asyncio_run_raises(self):
async def main():
await asyncio.sleep(0)
raise ValueError('spam')
with self.assertRaisesRegex(ValueError, 'spam'):
asyncio.run(main())
def test_asyncio_run_only_coro(self):
for o in {1, lambda: None}:
with self.subTest(obj=o), \
self.assertRaisesRegex(ValueError,
'a coroutine was expected'):
asyncio.run(o)
def test_asyncio_run_debug(self):
async def main(expected):
loop = asyncio.get_event_loop()
self.assertIs(loop.get_debug(), expected)
asyncio.run(main(False), debug=False)
asyncio.run(main(True), debug=True)
with mock.patch('asyncio.coroutines._is_debug_mode', lambda: True):
asyncio.run(main(True))
asyncio.run(main(False), debug=False)
with mock.patch('asyncio.coroutines._is_debug_mode', lambda: False):
asyncio.run(main(True), debug=True)
asyncio.run(main(False))
def test_asyncio_run_from_running_loop(self):
async def main():
coro = main()
try:
asyncio.run(coro)
finally:
coro.close() # Suppress ResourceWarning
with self.assertRaisesRegex(RuntimeError,
'cannot be called from a running'):
asyncio.run(main())
def test_asyncio_run_cancels_hanging_tasks(self):
lo_task = None
async def leftover():
await asyncio.sleep(0.1)
async def main():
nonlocal lo_task
lo_task = asyncio.create_task(leftover())
return 123
self.assertEqual(asyncio.run(main()), 123)
self.assertTrue(lo_task.done())
def test_asyncio_run_reports_hanging_tasks_errors(self):
lo_task = None
call_exc_handler_mock = mock.Mock()
async def leftover():
try:
await asyncio.sleep(0.1)
except asyncio.CancelledError:
1 / 0
async def main():
loop = asyncio.get_running_loop()
loop.call_exception_handler = call_exc_handler_mock
nonlocal lo_task
lo_task = asyncio.create_task(leftover())
return 123
self.assertEqual(asyncio.run(main()), 123)
self.assertTrue(lo_task.done())
call_exc_handler_mock.assert_called_with({
'message': test_utils.MockPattern(r'asyncio.run.*shutdown'),
'task': lo_task,
'exception': test_utils.MockInstanceOf(ZeroDivisionError)
})
def test_asyncio_run_closes_gens_after_hanging_tasks_errors(self):
spinner = None
lazyboy = None
class FancyExit(Exception):
pass
async def fidget():
while True:
yield 1
await asyncio.sleep(1)
async def spin():
nonlocal spinner
spinner = fidget()
try:
async for the_meaning_of_life in spinner: # NoQA
pass
except asyncio.CancelledError:
1 / 0
async def main():
loop = asyncio.get_running_loop()
loop.call_exception_handler = mock.Mock()
nonlocal lazyboy
lazyboy = asyncio.create_task(spin())
raise FancyExit
with self.assertRaises(FancyExit):
asyncio.run(main())
self.assertTrue(lazyboy.done())
self.assertIsNone(spinner.ag_frame)
self.assertFalse(spinner.ag_running)
def test_asyncio_run_set_event_loop(self):
#See https://github.com/python/cpython/issues/93896
async def main():
await asyncio.sleep(0)
return 42
policy = asyncio.get_event_loop_policy()
policy.set_event_loop = mock.Mock()
asyncio.run(main())
self.assertTrue(policy.set_event_loop.called)
def test_asyncio_run_without_uncancel(self):
# See https://github.com/python/cpython/issues/95097
class Task:
def __init__(self, loop, coro, **kwargs):
self._task = asyncio.Task(coro, loop=loop, **kwargs)
def cancel(self, *args, **kwargs):
return self._task.cancel(*args, **kwargs)
def add_done_callback(self, *args, **kwargs):
return self._task.add_done_callback(*args, **kwargs)
def remove_done_callback(self, *args, **kwargs):
return self._task.remove_done_callback(*args, **kwargs)
@property
def _asyncio_future_blocking(self):
return self._task._asyncio_future_blocking
def result(self, *args, **kwargs):
return self._task.result(*args, **kwargs)
def done(self, *args, **kwargs):
return self._task.done(*args, **kwargs)
def cancelled(self, *args, **kwargs):
return self._task.cancelled(*args, **kwargs)
def exception(self, *args, **kwargs):
return self._task.exception(*args, **kwargs)
def get_loop(self, *args, **kwargs):
return self._task.get_loop(*args, **kwargs)
def set_name(self, *args, **kwargs):
return self._task.set_name(*args, **kwargs)
async def main():
interrupt_self()
await asyncio.Event().wait()
def new_event_loop():
loop = self.new_loop()
loop.set_task_factory(Task)
return loop
asyncio.set_event_loop_policy(TestPolicy(new_event_loop))
with self.assertRaises(asyncio.CancelledError):
asyncio.run(main())
def test_asyncio_run_loop_factory(self):
factory = mock.Mock()
loop = factory.return_value = self.new_loop()
async def main():
self.assertEqual(asyncio.get_running_loop(), loop)
asyncio.run(main(), loop_factory=factory)
factory.assert_called_once_with()
def test_loop_factory_default_event_loop(self):
async def main():
if sys.platform == "win32":
self.assertIsInstance(asyncio.get_running_loop(), asyncio.ProactorEventLoop)
else:
self.assertIsInstance(asyncio.get_running_loop(), asyncio.SelectorEventLoop)
asyncio.run(main(), loop_factory=asyncio.EventLoop)
class RunnerTests(BaseTest):
def test_non_debug(self):
with asyncio.Runner(debug=False) as runner:
self.assertFalse(runner.get_loop().get_debug())
def test_debug(self):
with asyncio.Runner(debug=True) as runner:
self.assertTrue(runner.get_loop().get_debug())
def test_custom_factory(self):
loop = mock.Mock()
with asyncio.Runner(loop_factory=lambda: loop) as runner:
self.assertIs(runner.get_loop(), loop)
def test_run(self):
async def f():
await asyncio.sleep(0)
return 'done'
with asyncio.Runner() as runner:
self.assertEqual('done', runner.run(f()))
loop = runner.get_loop()
with self.assertRaisesRegex(
RuntimeError,
"Runner is closed"
):
runner.get_loop()
self.assertTrue(loop.is_closed())
def test_run_non_coro(self):
with asyncio.Runner() as runner:
with self.assertRaisesRegex(
ValueError,
"a coroutine was expected"
):
runner.run(123)
def test_run_future(self):
with asyncio.Runner() as runner:
with self.assertRaisesRegex(
ValueError,
"a coroutine was expected"
):
fut = runner.get_loop().create_future()
runner.run(fut)
def test_explicit_close(self):
runner = asyncio.Runner()
loop = runner.get_loop()
runner.close()
with self.assertRaisesRegex(
RuntimeError,
"Runner is closed"
):
runner.get_loop()
self.assertTrue(loop.is_closed())
def test_double_close(self):
runner = asyncio.Runner()
loop = runner.get_loop()
runner.close()
self.assertTrue(loop.is_closed())
# the second call is no-op
runner.close()
self.assertTrue(loop.is_closed())
def test_second_with_block_raises(self):
ret = []
async def f(arg):
ret.append(arg)
runner = asyncio.Runner()
with runner:
runner.run(f(1))
with self.assertRaisesRegex(
RuntimeError,
"Runner is closed"
):
with runner:
runner.run(f(2))
self.assertEqual([1], ret)
def test_run_keeps_context(self):
cvar = contextvars.ContextVar("cvar", default=-1)
async def f(val):
old = cvar.get()
await asyncio.sleep(0)
cvar.set(val)
return old
async def get_context():
return contextvars.copy_context()
with asyncio.Runner() as runner:
self.assertEqual(-1, runner.run(f(1)))
self.assertEqual(1, runner.run(f(2)))
self.assertEqual(2, runner.run(get_context()).get(cvar))
def test_recursive_run(self):
async def g():
pass
async def f():
runner.run(g())
with asyncio.Runner() as runner:
with self.assertWarnsRegex(
RuntimeWarning,
"coroutine .+ was never awaited",
):
with self.assertRaisesRegex(
RuntimeError,
re.escape(
"Runner.run() cannot be called from a running event loop"
),
):
runner.run(f())
def test_interrupt_call_soon(self):
# The only case when task is not suspended by waiting a future
# or another task
assert threading.current_thread() is threading.main_thread()
async def coro():
with self.assertRaises(asyncio.CancelledError):
while True:
await asyncio.sleep(0)
raise asyncio.CancelledError()
with asyncio.Runner() as runner:
runner.get_loop().call_later(0.1, interrupt_self)
with self.assertRaises(KeyboardInterrupt):
runner.run(coro())
def test_interrupt_wait(self):
# interrupting when waiting a future cancels both future and main task
assert threading.current_thread() is threading.main_thread()
async def coro(fut):
with self.assertRaises(asyncio.CancelledError):
await fut
raise asyncio.CancelledError()
with asyncio.Runner() as runner:
fut = runner.get_loop().create_future()
runner.get_loop().call_later(0.1, interrupt_self)
with self.assertRaises(KeyboardInterrupt):
runner.run(coro(fut))
self.assertTrue(fut.cancelled())
def test_interrupt_cancelled_task(self):
# interrupting cancelled main task doesn't raise KeyboardInterrupt
assert threading.current_thread() is threading.main_thread()
async def subtask(task):
await asyncio.sleep(0)
task.cancel()
interrupt_self()
async def coro():
asyncio.create_task(subtask(asyncio.current_task()))
await asyncio.sleep(10)
with asyncio.Runner() as runner:
with self.assertRaises(asyncio.CancelledError):
runner.run(coro())
def test_signal_install_not_supported_ok(self):
# signal.signal() can throw if the "main thread" doesn't have signals enabled
assert threading.current_thread() is threading.main_thread()
async def coro():
pass
with asyncio.Runner() as runner:
with patch.object(
signal,
"signal",
side_effect=ValueError(
"signal only works in main thread of the main interpreter"
)
):
runner.run(coro())
def test_set_event_loop_called_once(self):
# See https://github.com/python/cpython/issues/95736
async def coro():
pass
policy = asyncio.get_event_loop_policy()
policy.set_event_loop = mock.Mock()
runner = asyncio.Runner()
runner.run(coro())
runner.run(coro())
self.assertEqual(1, policy.set_event_loop.call_count)
runner.close()
def test_no_repr_is_call_on_the_task_result(self):
# See https://github.com/python/cpython/issues/112559.
class MyResult:
def __init__(self):
self.repr_count = 0
def __repr__(self):
self.repr_count += 1
return super().__repr__()
async def coro():
return MyResult()
with asyncio.Runner() as runner:
result = runner.run(coro())
self.assertEqual(0, result.repr_count)
if __name__ == '__main__':
unittest.main()
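
As a quick reference for the Runner API covered above, a minimal sketch (not part of the commit, assuming Python 3.11+ where asyncio.Runner is available): one Runner reuses a single event loop and contextvars context across several run() calls, and closes the loop when its with-block exits.

import asyncio

async def greet(name):
    await asyncio.sleep(0)
    return f'hello {name}'

with asyncio.Runner() as runner:
    print(runner.run(greet('a')))   # 'hello a'
    print(runner.run(greet('b')))   # 'hello b' -- same loop as the first call
    loop = runner.get_loop()

print(loop.is_closed())  # True: leaving the with-block closed the loop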

File diff suppressed because it is too large

View File

@@ -0,0 +1,585 @@
"""Tests for sendfile functionality."""
import asyncio
import errno
import os
import socket
import sys
import tempfile
import unittest
from asyncio import base_events
from asyncio import constants
from unittest import mock
from test import support
from test.support import os_helper
from test.support import socket_helper
from test.test_asyncio import utils as test_utils
try:
import ssl
except ImportError:
ssl = None
def tearDownModule():
asyncio.set_event_loop_policy(None)
class MySendfileProto(asyncio.Protocol):
def __init__(self, loop=None, close_after=0):
self.transport = None
self.state = 'INITIAL'
self.nbytes = 0
if loop is not None:
self.connected = loop.create_future()
self.done = loop.create_future()
self.data = bytearray()
self.close_after = close_after
def _assert_state(self, *expected):
if self.state not in expected:
raise AssertionError(f'state: {self.state!r}, expected: {expected!r}')
def connection_made(self, transport):
self.transport = transport
self._assert_state('INITIAL')
self.state = 'CONNECTED'
if self.connected:
self.connected.set_result(None)
def eof_received(self):
self._assert_state('CONNECTED')
self.state = 'EOF'
def connection_lost(self, exc):
self._assert_state('CONNECTED', 'EOF')
self.state = 'CLOSED'
if self.done:
self.done.set_result(None)
def data_received(self, data):
self._assert_state('CONNECTED')
self.nbytes += len(data)
self.data.extend(data)
super().data_received(data)
if self.close_after and self.nbytes >= self.close_after:
self.transport.close()
class MyProto(asyncio.Protocol):
def __init__(self, loop):
self.started = False
self.closed = False
self.data = bytearray()
self.fut = loop.create_future()
self.transport = None
def connection_made(self, transport):
self.started = True
self.transport = transport
def data_received(self, data):
self.data.extend(data)
def connection_lost(self, exc):
self.closed = True
self.fut.set_result(None)
async def wait_closed(self):
await self.fut
class SendfileBase:
    # Linux >= 6.10 seems to buffer up to 17 pages of data.
# So DATA should be large enough to make this test reliable even with a
# 64 KiB page configuration.
DATA = b"x" * (1024 * 17 * 64 + 1)
    # Reduce the socket buffer size to test on relatively small data sets.
BUF_SIZE = 4 * 1024 # 4 KiB
def create_event_loop(self):
raise NotImplementedError
@classmethod
def setUpClass(cls):
with open(os_helper.TESTFN, 'wb') as fp:
fp.write(cls.DATA)
super().setUpClass()
@classmethod
def tearDownClass(cls):
os_helper.unlink(os_helper.TESTFN)
super().tearDownClass()
def setUp(self):
self.file = open(os_helper.TESTFN, 'rb')
self.addCleanup(self.file.close)
self.loop = self.create_event_loop()
self.set_event_loop(self.loop)
super().setUp()
def tearDown(self):
        # just in case we have transport close callbacks pending
if not self.loop.is_closed():
test_utils.run_briefly(self.loop)
self.doCleanups()
support.gc_collect()
super().tearDown()
def run_loop(self, coro):
return self.loop.run_until_complete(coro)
class SockSendfileMixin(SendfileBase):
@classmethod
def setUpClass(cls):
cls.__old_bufsize = constants.SENDFILE_FALLBACK_READBUFFER_SIZE
constants.SENDFILE_FALLBACK_READBUFFER_SIZE = 1024 * 16
super().setUpClass()
@classmethod
def tearDownClass(cls):
constants.SENDFILE_FALLBACK_READBUFFER_SIZE = cls.__old_bufsize
super().tearDownClass()
def make_socket(self, cleanup=True):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setblocking(False)
if cleanup:
self.addCleanup(sock.close)
return sock
def reduce_receive_buffer_size(self, sock):
        # Reduce the receive socket buffer size to test on relatively
        # small data sets.
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, self.BUF_SIZE)
def reduce_send_buffer_size(self, sock, transport=None):
        # Reduce the send socket buffer size to test on relatively small data sets.
# On macOS, SO_SNDBUF is reset by connect(). So this method
# should be called after the socket is connected.
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, self.BUF_SIZE)
if transport is not None:
transport.set_write_buffer_limits(high=self.BUF_SIZE)
def prepare_socksendfile(self):
proto = MyProto(self.loop)
port = socket_helper.find_unused_port()
srv_sock = self.make_socket(cleanup=False)
srv_sock.bind((socket_helper.HOST, port))
server = self.run_loop(self.loop.create_server(
lambda: proto, sock=srv_sock))
self.reduce_receive_buffer_size(srv_sock)
sock = self.make_socket()
self.run_loop(self.loop.sock_connect(sock, ('127.0.0.1', port)))
self.reduce_send_buffer_size(sock)
def cleanup():
if proto.transport is not None:
# can be None if the task was cancelled before
# connection_made callback
proto.transport.close()
self.run_loop(proto.wait_closed())
server.close()
self.run_loop(server.wait_closed())
self.addCleanup(cleanup)
return sock, proto
def test_sock_sendfile_success(self):
sock, proto = self.prepare_socksendfile()
ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
sock.close()
self.run_loop(proto.wait_closed())
self.assertEqual(ret, len(self.DATA))
self.assertEqual(proto.data, self.DATA)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sock_sendfile_with_offset_and_count(self):
sock, proto = self.prepare_socksendfile()
ret = self.run_loop(self.loop.sock_sendfile(sock, self.file,
1000, 2000))
sock.close()
self.run_loop(proto.wait_closed())
self.assertEqual(proto.data, self.DATA[1000:3000])
self.assertEqual(self.file.tell(), 3000)
self.assertEqual(ret, 2000)
def test_sock_sendfile_zero_size(self):
sock, proto = self.prepare_socksendfile()
with tempfile.TemporaryFile() as f:
ret = self.run_loop(self.loop.sock_sendfile(sock, f,
0, None))
sock.close()
self.run_loop(proto.wait_closed())
self.assertEqual(ret, 0)
self.assertEqual(self.file.tell(), 0)
def test_sock_sendfile_mix_with_regular_send(self):
buf = b"mix_regular_send" * (4 * 1024) # 64 KiB
sock, proto = self.prepare_socksendfile()
self.run_loop(self.loop.sock_sendall(sock, buf))
ret = self.run_loop(self.loop.sock_sendfile(sock, self.file))
self.run_loop(self.loop.sock_sendall(sock, buf))
sock.close()
self.run_loop(proto.wait_closed())
self.assertEqual(ret, len(self.DATA))
expected = buf + self.DATA + buf
self.assertEqual(proto.data, expected)
self.assertEqual(self.file.tell(), len(self.DATA))
class SendfileMixin(SendfileBase):
# Note: sendfile via SSL transport is equal to sendfile fallback
def prepare_sendfile(self, *, is_ssl=False, close_after=0):
port = socket_helper.find_unused_port()
srv_proto = MySendfileProto(loop=self.loop,
close_after=close_after)
if is_ssl:
if not ssl:
self.skipTest("No ssl module")
srv_ctx = test_utils.simple_server_sslcontext()
cli_ctx = test_utils.simple_client_sslcontext()
else:
srv_ctx = None
cli_ctx = None
srv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
srv_sock.bind((socket_helper.HOST, port))
server = self.run_loop(self.loop.create_server(
lambda: srv_proto, sock=srv_sock, ssl=srv_ctx))
self.reduce_receive_buffer_size(srv_sock)
if is_ssl:
server_hostname = socket_helper.HOST
else:
server_hostname = None
cli_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
cli_sock.connect((socket_helper.HOST, port))
cli_proto = MySendfileProto(loop=self.loop)
tr, pr = self.run_loop(self.loop.create_connection(
lambda: cli_proto, sock=cli_sock,
ssl=cli_ctx, server_hostname=server_hostname))
self.reduce_send_buffer_size(cli_sock, transport=tr)
def cleanup():
srv_proto.transport.close()
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.run_loop(cli_proto.done)
server.close()
self.run_loop(server.wait_closed())
self.addCleanup(cleanup)
return srv_proto, cli_proto
@unittest.skipIf(sys.platform == 'win32', "UDP sockets are not supported")
def test_sendfile_not_supported(self):
tr, pr = self.run_loop(
self.loop.create_datagram_endpoint(
asyncio.DatagramProtocol,
family=socket.AF_INET))
try:
with self.assertRaisesRegex(RuntimeError, "not supported"):
self.run_loop(
self.loop.sendfile(tr, self.file))
self.assertEqual(0, self.file.tell())
finally:
# don't use self.addCleanup because it produces resource warning
tr.close()
def test_sendfile(self):
srv_proto, cli_proto = self.prepare_sendfile()
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.assertEqual(ret, len(self.DATA))
self.assertEqual(srv_proto.nbytes, len(self.DATA))
self.assertEqual(srv_proto.data, self.DATA)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_force_fallback(self):
srv_proto, cli_proto = self.prepare_sendfile()
def sendfile_native(transp, file, offset, count):
# to raise SendfileNotAvailableError
return base_events.BaseEventLoop._sendfile_native(
self.loop, transp, file, offset, count)
self.loop._sendfile_native = sendfile_native
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.assertEqual(ret, len(self.DATA))
self.assertEqual(srv_proto.nbytes, len(self.DATA))
self.assertEqual(srv_proto.data, self.DATA)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_force_unsupported_native(self):
if sys.platform == 'win32':
if isinstance(self.loop, asyncio.ProactorEventLoop):
self.skipTest("Fails on proactor event loop")
srv_proto, cli_proto = self.prepare_sendfile()
def sendfile_native(transp, file, offset, count):
# to raise SendfileNotAvailableError
return base_events.BaseEventLoop._sendfile_native(
self.loop, transp, file, offset, count)
self.loop._sendfile_native = sendfile_native
with self.assertRaisesRegex(asyncio.SendfileNotAvailableError,
"not supported"):
self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file,
fallback=False))
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.assertEqual(srv_proto.nbytes, 0)
self.assertEqual(self.file.tell(), 0)
def test_sendfile_ssl(self):
srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.assertEqual(ret, len(self.DATA))
self.assertEqual(srv_proto.nbytes, len(self.DATA))
self.assertEqual(srv_proto.data, self.DATA)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_for_closing_transp(self):
srv_proto, cli_proto = self.prepare_sendfile()
cli_proto.transport.close()
with self.assertRaisesRegex(RuntimeError, "is closing"):
self.run_loop(self.loop.sendfile(cli_proto.transport, self.file))
self.run_loop(srv_proto.done)
self.assertEqual(srv_proto.nbytes, 0)
self.assertEqual(self.file.tell(), 0)
def test_sendfile_pre_and_post_data(self):
srv_proto, cli_proto = self.prepare_sendfile()
PREFIX = b'PREFIX__' * 1024 # 8 KiB
SUFFIX = b'--SUFFIX' * 1024 # 8 KiB
cli_proto.transport.write(PREFIX)
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
cli_proto.transport.write(SUFFIX)
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.assertEqual(ret, len(self.DATA))
self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_ssl_pre_and_post_data(self):
srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
PREFIX = b'zxcvbnm' * 1024
SUFFIX = b'0987654321' * 1024
cli_proto.transport.write(PREFIX)
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
cli_proto.transport.write(SUFFIX)
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.assertEqual(ret, len(self.DATA))
self.assertEqual(srv_proto.data, PREFIX + self.DATA + SUFFIX)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_partial(self):
srv_proto, cli_proto = self.prepare_sendfile()
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.assertEqual(ret, 100)
self.assertEqual(srv_proto.nbytes, 100)
self.assertEqual(srv_proto.data, self.DATA[1000:1100])
self.assertEqual(self.file.tell(), 1100)
def test_sendfile_ssl_partial(self):
srv_proto, cli_proto = self.prepare_sendfile(is_ssl=True)
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file, 1000, 100))
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.assertEqual(ret, 100)
self.assertEqual(srv_proto.nbytes, 100)
self.assertEqual(srv_proto.data, self.DATA[1000:1100])
self.assertEqual(self.file.tell(), 1100)
def test_sendfile_close_peer_after_receiving(self):
srv_proto, cli_proto = self.prepare_sendfile(
close_after=len(self.DATA))
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
cli_proto.transport.close()
self.run_loop(srv_proto.done)
self.assertEqual(ret, len(self.DATA))
self.assertEqual(srv_proto.nbytes, len(self.DATA))
self.assertEqual(srv_proto.data, self.DATA)
self.assertEqual(self.file.tell(), len(self.DATA))
def test_sendfile_ssl_close_peer_after_receiving(self):
srv_proto, cli_proto = self.prepare_sendfile(
is_ssl=True, close_after=len(self.DATA))
ret = self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
self.run_loop(srv_proto.done)
self.assertEqual(ret, len(self.DATA))
self.assertEqual(srv_proto.nbytes, len(self.DATA))
self.assertEqual(srv_proto.data, self.DATA)
self.assertEqual(self.file.tell(), len(self.DATA))
# On Solaris, lowering SO_RCVBUF on a TCP connection after it has been
# established has no effect. Due to its age, this bug affects both Oracle
# Solaris as well as all other OpenSolaris forks (unless they fixed it
# themselves).
@unittest.skipIf(sys.platform.startswith('sunos'),
"Doesn't work on Solaris")
def test_sendfile_close_peer_in_the_middle_of_receiving(self):
srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
with self.assertRaises(ConnectionError):
self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
self.run_loop(srv_proto.done)
self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
srv_proto.nbytes)
if not (sys.platform == 'win32'
and isinstance(self.loop, asyncio.ProactorEventLoop)):
            # On Windows, Proactor uses TransmitFile, which does not update tell()
self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
self.file.tell())
self.assertTrue(cli_proto.transport.is_closing())
def test_sendfile_fallback_close_peer_in_the_middle_of_receiving(self):
def sendfile_native(transp, file, offset, count):
# to raise SendfileNotAvailableError
return base_events.BaseEventLoop._sendfile_native(
self.loop, transp, file, offset, count)
self.loop._sendfile_native = sendfile_native
srv_proto, cli_proto = self.prepare_sendfile(close_after=1024)
with self.assertRaises(ConnectionError):
try:
self.run_loop(
self.loop.sendfile(cli_proto.transport, self.file))
except OSError as e:
# macOS may raise OSError of EPROTOTYPE when writing to a
# socket that is in the process of closing down.
if e.errno == errno.EPROTOTYPE and sys.platform == "darwin":
raise ConnectionError
else:
raise
self.run_loop(srv_proto.done)
self.assertTrue(1024 <= srv_proto.nbytes < len(self.DATA),
srv_proto.nbytes)
self.assertTrue(1024 <= self.file.tell() < len(self.DATA),
self.file.tell())
@unittest.skipIf(not hasattr(os, 'sendfile'),
"Don't have native sendfile support")
def test_sendfile_prevents_bare_write(self):
srv_proto, cli_proto = self.prepare_sendfile()
fut = self.loop.create_future()
async def coro():
fut.set_result(None)
return await self.loop.sendfile(cli_proto.transport, self.file)
t = self.loop.create_task(coro())
self.run_loop(fut)
with self.assertRaisesRegex(RuntimeError,
"sendfile is in progress"):
cli_proto.transport.write(b'data')
ret = self.run_loop(t)
self.assertEqual(ret, len(self.DATA))
def test_sendfile_no_fallback_for_fallback_transport(self):
transport = mock.Mock()
transport.is_closing.side_effect = lambda: False
transport._sendfile_compatible = constants._SendfileMode.FALLBACK
with self.assertRaisesRegex(RuntimeError, 'fallback is disabled'):
self.loop.run_until_complete(
self.loop.sendfile(transport, None, fallback=False))
class SendfileTestsBase(SendfileMixin, SockSendfileMixin):
pass
if sys.platform == 'win32':
class SelectEventLoopTests(SendfileTestsBase,
test_utils.TestCase):
def create_event_loop(self):
return asyncio.SelectorEventLoop()
class ProactorEventLoopTests(SendfileTestsBase,
test_utils.TestCase):
def create_event_loop(self):
return asyncio.ProactorEventLoop()
else:
import selectors
if hasattr(selectors, 'KqueueSelector'):
class KqueueEventLoopTests(SendfileTestsBase,
test_utils.TestCase):
def create_event_loop(self):
return asyncio.SelectorEventLoop(
selectors.KqueueSelector())
if hasattr(selectors, 'EpollSelector'):
class EPollEventLoopTests(SendfileTestsBase,
test_utils.TestCase):
def create_event_loop(self):
return asyncio.SelectorEventLoop(selectors.EpollSelector())
if hasattr(selectors, 'PollSelector'):
class PollEventLoopTests(SendfileTestsBase,
test_utils.TestCase):
def create_event_loop(self):
return asyncio.SelectorEventLoop(selectors.PollSelector())
# Should always exist.
class SelectEventLoopTests(SendfileTestsBase,
test_utils.TestCase):
def create_event_loop(self):
return asyncio.SelectorEventLoop(selectors.SelectSelector())
if __name__ == '__main__':
unittest.main()
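
The sendfile tests above cover both loop.sock_sendfile() on raw sockets and loop.sendfile() on transports, with and without the os.sendfile() fast path. A rough sketch of the transport-level API being exercised (host, port and file name are placeholders, not values from the tests):

import asyncio

async def send_file(path, host, port):
    loop = asyncio.get_running_loop()
    transport, protocol = await loop.create_connection(asyncio.Protocol, host, port)
    try:
        with open(path, 'rb') as f:
            # Uses os.sendfile() where the platform and transport allow it,
            # otherwise falls back to a plain read/write loop (e.g. over SSL).
            return await loop.sendfile(transport, f)
    finally:
        transport.close()

# asyncio.run(send_file('payload.bin', '127.0.0.1', 8888))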

View File

@@ -0,0 +1,352 @@
import asyncio
import os
import socket
import time
import threading
import unittest
from test.support import socket_helper
from test.test_asyncio import utils as test_utils
from test.test_asyncio import functional as func_tests
def tearDownModule():
asyncio.set_event_loop_policy(None)
class BaseStartServer(func_tests.FunctionalTestCaseMixin):
def new_loop(self):
raise NotImplementedError
def test_start_server_1(self):
HELLO_MSG = b'1' * 1024 * 5 + b'\n'
def client(sock, addr):
for i in range(10):
time.sleep(0.2)
if srv.is_serving():
break
else:
raise RuntimeError
sock.settimeout(2)
sock.connect(addr)
sock.send(HELLO_MSG)
sock.recv_all(1)
sock.close()
async def serve(reader, writer):
await reader.readline()
main_task.cancel()
writer.write(b'1')
writer.close()
await writer.wait_closed()
async def main(srv):
async with srv:
await srv.serve_forever()
srv = self.loop.run_until_complete(asyncio.start_server(
serve, socket_helper.HOSTv4, 0, start_serving=False))
self.assertFalse(srv.is_serving())
main_task = self.loop.create_task(main(srv))
addr = srv.sockets[0].getsockname()
with self.assertRaises(asyncio.CancelledError):
with self.tcp_client(lambda sock: client(sock, addr)):
self.loop.run_until_complete(main_task)
self.assertEqual(srv.sockets, ())
self.assertIsNone(srv._sockets)
self.assertIsNone(srv._waiters)
self.assertFalse(srv.is_serving())
with self.assertRaisesRegex(RuntimeError, r'is closed'):
self.loop.run_until_complete(srv.serve_forever())
class SelectorStartServerTests(BaseStartServer, unittest.TestCase):
def new_loop(self):
return asyncio.SelectorEventLoop()
@socket_helper.skip_unless_bind_unix_socket
def test_start_unix_server_1(self):
HELLO_MSG = b'1' * 1024 * 5 + b'\n'
started = threading.Event()
def client(sock, addr):
sock.settimeout(2)
started.wait(5)
sock.connect(addr)
sock.send(HELLO_MSG)
sock.recv_all(1)
sock.close()
async def serve(reader, writer):
await reader.readline()
main_task.cancel()
writer.write(b'1')
writer.close()
await writer.wait_closed()
async def main(srv):
async with srv:
self.assertFalse(srv.is_serving())
await srv.start_serving()
self.assertTrue(srv.is_serving())
started.set()
await srv.serve_forever()
with test_utils.unix_socket_path() as addr:
srv = self.loop.run_until_complete(asyncio.start_unix_server(
serve, addr, start_serving=False))
main_task = self.loop.create_task(main(srv))
with self.assertRaises(asyncio.CancelledError):
with self.unix_client(lambda sock: client(sock, addr)):
self.loop.run_until_complete(main_task)
self.assertEqual(srv.sockets, ())
self.assertIsNone(srv._sockets)
self.assertIsNone(srv._waiters)
self.assertFalse(srv.is_serving())
with self.assertRaisesRegex(RuntimeError, r'is closed'):
self.loop.run_until_complete(srv.serve_forever())
class TestServer2(unittest.IsolatedAsyncioTestCase):
async def test_wait_closed_basic(self):
async def serve(rd, wr):
try:
await rd.read()
finally:
wr.close()
await wr.wait_closed()
srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
self.addCleanup(srv.close)
# active count = 0, not closed: should block
task1 = asyncio.create_task(srv.wait_closed())
await asyncio.sleep(0)
self.assertFalse(task1.done())
# active count != 0, not closed: should block
addr = srv.sockets[0].getsockname()
(rd, wr) = await asyncio.open_connection(addr[0], addr[1])
task2 = asyncio.create_task(srv.wait_closed())
await asyncio.sleep(0)
self.assertFalse(task1.done())
self.assertFalse(task2.done())
srv.close()
await asyncio.sleep(0)
# active count != 0, closed: should block
task3 = asyncio.create_task(srv.wait_closed())
await asyncio.sleep(0)
self.assertFalse(task1.done())
self.assertFalse(task2.done())
self.assertFalse(task3.done())
wr.close()
await wr.wait_closed()
# active count == 0, closed: should unblock
await task1
await task2
await task3
await srv.wait_closed() # Return immediately
async def test_wait_closed_race(self):
# Test a regression in 3.12.0, should be fixed in 3.12.1
async def serve(rd, wr):
try:
await rd.read()
finally:
wr.close()
await wr.wait_closed()
srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
self.addCleanup(srv.close)
task = asyncio.create_task(srv.wait_closed())
await asyncio.sleep(0)
self.assertFalse(task.done())
addr = srv.sockets[0].getsockname()
(rd, wr) = await asyncio.open_connection(addr[0], addr[1])
loop = asyncio.get_running_loop()
loop.call_soon(srv.close)
loop.call_soon(wr.close)
await srv.wait_closed()
async def test_close_clients(self):
async def serve(rd, wr):
try:
await rd.read()
finally:
wr.close()
await wr.wait_closed()
srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
self.addCleanup(srv.close)
addr = srv.sockets[0].getsockname()
(rd, wr) = await asyncio.open_connection(addr[0], addr[1])
self.addCleanup(wr.close)
task = asyncio.create_task(srv.wait_closed())
await asyncio.sleep(0)
self.assertFalse(task.done())
srv.close()
srv.close_clients()
await asyncio.sleep(0)
await asyncio.sleep(0)
self.assertTrue(task.done())
async def test_abort_clients(self):
async def serve(rd, wr):
fut.set_result((rd, wr))
await wr.wait_closed()
fut = asyncio.Future()
srv = await asyncio.start_server(serve, socket_helper.HOSTv4, 0)
self.addCleanup(srv.close)
addr = srv.sockets[0].getsockname()
(c_rd, c_wr) = await asyncio.open_connection(addr[0], addr[1], limit=4096)
self.addCleanup(c_wr.close)
(s_rd, s_wr) = await fut
# Limit the socket buffers so we can more reliably overfill them
s_sock = s_wr.get_extra_info('socket')
s_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 65536)
c_sock = c_wr.get_extra_info('socket')
c_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 65536)
        # Get the reader into a paused state by sending more than twice
# the configured limit
s_wr.write(b'a' * 4096)
s_wr.write(b'a' * 4096)
s_wr.write(b'a' * 4096)
while c_wr.transport.is_reading():
await asyncio.sleep(0)
# Get the writer in a waiting state by sending data until the
# kernel stops accepting more data in the send buffer.
# gh-122136: getsockopt() does not reliably report the buffer size
# available for message content.
# We loop until we start filling up the asyncio buffer.
# To avoid an infinite loop we cap at 10 times the expected value
c_bufsize = c_sock.getsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF)
s_bufsize = s_sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF)
for i in range(10):
s_wr.write(b'a' * c_bufsize)
s_wr.write(b'a' * s_bufsize)
if s_wr.transport.get_write_buffer_size() > 0:
break
self.assertNotEqual(s_wr.transport.get_write_buffer_size(), 0)
task = asyncio.create_task(srv.wait_closed())
await asyncio.sleep(0)
self.assertFalse(task.done())
srv.close()
srv.abort_clients()
await asyncio.sleep(0)
await asyncio.sleep(0)
self.assertTrue(task.done())
# Test the various corner cases of Unix server socket removal
class UnixServerCleanupTests(unittest.IsolatedAsyncioTestCase):
@socket_helper.skip_unless_bind_unix_socket
async def test_unix_server_addr_cleanup(self):
# Default scenario
with test_utils.unix_socket_path() as addr:
async def serve(*args):
pass
srv = await asyncio.start_unix_server(serve, addr)
srv.close()
self.assertFalse(os.path.exists(addr))
@socket_helper.skip_unless_bind_unix_socket
async def test_unix_server_sock_cleanup(self):
# Using already bound socket
with test_utils.unix_socket_path() as addr:
async def serve(*args):
pass
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
sock.bind(addr)
srv = await asyncio.start_unix_server(serve, sock=sock)
srv.close()
self.assertFalse(os.path.exists(addr))
@socket_helper.skip_unless_bind_unix_socket
async def test_unix_server_cleanup_gone(self):
# Someone else has already cleaned up the socket
with test_utils.unix_socket_path() as addr:
async def serve(*args):
pass
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
sock.bind(addr)
srv = await asyncio.start_unix_server(serve, sock=sock)
os.unlink(addr)
srv.close()
@socket_helper.skip_unless_bind_unix_socket
async def test_unix_server_cleanup_replaced(self):
# Someone else has replaced the socket with their own
with test_utils.unix_socket_path() as addr:
async def serve(*args):
pass
srv = await asyncio.start_unix_server(serve, addr)
os.unlink(addr)
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
sock.bind(addr)
srv.close()
self.assertTrue(os.path.exists(addr))
@socket_helper.skip_unless_bind_unix_socket
async def test_unix_server_cleanup_prevented(self):
# Automatic cleanup explicitly disabled
with test_utils.unix_socket_path() as addr:
async def serve(*args):
pass
srv = await asyncio.start_unix_server(serve, addr, cleanup_socket=False)
srv.close()
self.assertTrue(os.path.exists(addr))
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class ProactorStartServerTests(BaseStartServer, unittest.TestCase):
def new_loop(self):
return asyncio.ProactorEventLoop()
if __name__ == '__main__':
unittest.main()
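
The server tests above exercise asyncio.start_server()/start_unix_server(), serve_forever(), wait_closed(), Unix socket cleanup, and the close_clients()/abort_clients() helpers. For reference, a minimal echo server using the same high-level API (the address is illustrative):

import asyncio

async def handle(reader, writer):
    data = await reader.readline()
    writer.write(data)              # echo one line back
    writer.close()
    await writer.wait_closed()

async def main():
    server = await asyncio.start_server(handle, '127.0.0.1', 8888)
    async with server:
        await server.serve_forever()

# asyncio.run(main())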

View File

@@ -0,0 +1,679 @@
import socket
import asyncio
import sys
import unittest
from asyncio import proactor_events
from itertools import cycle, islice
from unittest.mock import Mock
from test.test_asyncio import utils as test_utils
from test import support
from test.support import socket_helper
if socket_helper.tcp_blackhole():
raise unittest.SkipTest('Not relevant to ProactorEventLoop')
def tearDownModule():
asyncio.set_event_loop_policy(None)
class MyProto(asyncio.Protocol):
connected = None
done = None
def __init__(self, loop=None):
self.transport = None
self.state = 'INITIAL'
self.nbytes = 0
if loop is not None:
self.connected = loop.create_future()
self.done = loop.create_future()
def _assert_state(self, *expected):
if self.state not in expected:
raise AssertionError(f'state: {self.state!r}, expected: {expected!r}')
def connection_made(self, transport):
self.transport = transport
self._assert_state('INITIAL')
self.state = 'CONNECTED'
if self.connected:
self.connected.set_result(None)
transport.write(b'GET / HTTP/1.0\r\nHost: example.com\r\n\r\n')
def data_received(self, data):
self._assert_state('CONNECTED')
self.nbytes += len(data)
def eof_received(self):
self._assert_state('CONNECTED')
self.state = 'EOF'
def connection_lost(self, exc):
self._assert_state('CONNECTED', 'EOF')
self.state = 'CLOSED'
if self.done:
self.done.set_result(None)
class BaseSockTestsMixin:
def create_event_loop(self):
raise NotImplementedError
def setUp(self):
self.loop = self.create_event_loop()
self.set_event_loop(self.loop)
super().setUp()
def tearDown(self):
        # just in case we have transport close callbacks pending
if not self.loop.is_closed():
test_utils.run_briefly(self.loop)
self.doCleanups()
support.gc_collect()
super().tearDown()
def _basetest_sock_client_ops(self, httpd, sock):
if not isinstance(self.loop, proactor_events.BaseProactorEventLoop):
            # in debug mode, socket operations must fail
            # if the socket is in blocking mode
self.loop.set_debug(True)
sock.setblocking(True)
with self.assertRaises(ValueError):
self.loop.run_until_complete(
self.loop.sock_connect(sock, httpd.address))
with self.assertRaises(ValueError):
self.loop.run_until_complete(
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
with self.assertRaises(ValueError):
self.loop.run_until_complete(
self.loop.sock_recv(sock, 1024))
with self.assertRaises(ValueError):
self.loop.run_until_complete(
self.loop.sock_recv_into(sock, bytearray()))
with self.assertRaises(ValueError):
self.loop.run_until_complete(
self.loop.sock_accept(sock))
# test in non-blocking mode
sock.setblocking(False)
self.loop.run_until_complete(
self.loop.sock_connect(sock, httpd.address))
self.loop.run_until_complete(
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
data = self.loop.run_until_complete(
self.loop.sock_recv(sock, 1024))
# consume data
self.loop.run_until_complete(
self.loop.sock_recv(sock, 1024))
sock.close()
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
def _basetest_sock_recv_into(self, httpd, sock):
# same as _basetest_sock_client_ops, but using sock_recv_into
sock.setblocking(False)
self.loop.run_until_complete(
self.loop.sock_connect(sock, httpd.address))
self.loop.run_until_complete(
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
data = bytearray(1024)
with memoryview(data) as buf:
nbytes = self.loop.run_until_complete(
self.loop.sock_recv_into(sock, buf[:1024]))
# consume data
self.loop.run_until_complete(
self.loop.sock_recv_into(sock, buf[nbytes:]))
sock.close()
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
def test_sock_client_ops(self):
with test_utils.run_test_server() as httpd:
sock = socket.socket()
self._basetest_sock_client_ops(httpd, sock)
sock = socket.socket()
self._basetest_sock_recv_into(httpd, sock)
async def _basetest_sock_recv_racing(self, httpd, sock):
sock.setblocking(False)
await self.loop.sock_connect(sock, httpd.address)
task = asyncio.create_task(self.loop.sock_recv(sock, 1024))
await asyncio.sleep(0)
task.cancel()
asyncio.create_task(
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
data = await self.loop.sock_recv(sock, 1024)
# consume data
await self.loop.sock_recv(sock, 1024)
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
async def _basetest_sock_recv_into_racing(self, httpd, sock):
sock.setblocking(False)
await self.loop.sock_connect(sock, httpd.address)
data = bytearray(1024)
with memoryview(data) as buf:
task = asyncio.create_task(
self.loop.sock_recv_into(sock, buf[:1024]))
await asyncio.sleep(0)
task.cancel()
task = asyncio.create_task(
self.loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n'))
nbytes = await self.loop.sock_recv_into(sock, buf[:1024])
# consume data
await self.loop.sock_recv_into(sock, buf[nbytes:])
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
await task
async def _basetest_sock_send_racing(self, listener, sock):
listener.bind(('127.0.0.1', 0))
listener.listen(1)
# make connection
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
sock.setblocking(False)
task = asyncio.create_task(
self.loop.sock_connect(sock, listener.getsockname()))
await asyncio.sleep(0)
server = listener.accept()[0]
server.setblocking(False)
with server:
await task
# fill the buffer until sending 5 chars would block
size = 8192
while size >= 4:
with self.assertRaises(BlockingIOError):
while True:
sock.send(b' ' * size)
size = int(size / 2)
# cancel a blocked sock_sendall
task = asyncio.create_task(
self.loop.sock_sendall(sock, b'hello'))
await asyncio.sleep(0)
task.cancel()
# receive everything that is not a space
async def recv_all():
rv = b''
while True:
buf = await self.loop.sock_recv(server, 8192)
if not buf:
return rv
rv += buf.strip()
task = asyncio.create_task(recv_all())
# immediately make another sock_sendall call
await self.loop.sock_sendall(sock, b'world')
sock.shutdown(socket.SHUT_WR)
data = await task
# ProactorEventLoop could deliver hello, so endswith is necessary
self.assertTrue(data.endswith(b'world'))
# After the first connect attempt before the listener is ready,
# the socket needs time to "recover" to make the next connect call.
# On Linux, a second retry will do. On Windows, the waiting time is
# unpredictable; and on FreeBSD the socket may never come back
# because it's a loopback address. Here we'll just retry for a few
# times, and have to skip the test if it's not working. See also:
# https://stackoverflow.com/a/54437602/3316267
# https://lists.freebsd.org/pipermail/freebsd-current/2005-May/049876.html
async def _basetest_sock_connect_racing(self, listener, sock):
listener.bind(('127.0.0.1', 0))
addr = listener.getsockname()
sock.setblocking(False)
task = asyncio.create_task(self.loop.sock_connect(sock, addr))
await asyncio.sleep(0)
task.cancel()
listener.listen(1)
skip_reason = "Max retries reached"
for i in range(128):
try:
await self.loop.sock_connect(sock, addr)
except ConnectionRefusedError as e:
skip_reason = e
except OSError as e:
skip_reason = e
# Retry only for this error:
# [WinError 10022] An invalid argument was supplied
if getattr(e, 'winerror', 0) != 10022:
break
else:
# success
return
self.skipTest(skip_reason)
def test_sock_client_racing(self):
with test_utils.run_test_server() as httpd:
sock = socket.socket()
with sock:
self.loop.run_until_complete(asyncio.wait_for(
self._basetest_sock_recv_racing(httpd, sock), 10))
sock = socket.socket()
with sock:
self.loop.run_until_complete(asyncio.wait_for(
self._basetest_sock_recv_into_racing(httpd, sock), 10))
listener = socket.socket()
sock = socket.socket()
with listener, sock:
self.loop.run_until_complete(asyncio.wait_for(
self._basetest_sock_send_racing(listener, sock), 10))
def test_sock_client_connect_racing(self):
listener = socket.socket()
sock = socket.socket()
with listener, sock:
self.loop.run_until_complete(asyncio.wait_for(
self._basetest_sock_connect_racing(listener, sock), 10))
async def _basetest_huge_content(self, address):
sock = socket.socket()
sock.setblocking(False)
DATA_SIZE = 10_000_00
chunk = b'0123456789' * (DATA_SIZE // 10)
await self.loop.sock_connect(sock, address)
await self.loop.sock_sendall(sock,
(b'POST /loop HTTP/1.0\r\n' +
b'Content-Length: %d\r\n' % DATA_SIZE +
b'\r\n'))
task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))
data = await self.loop.sock_recv(sock, DATA_SIZE)
        # The HTTP headers are smaller than the MTU,
        # so they are always sent in the first packet.
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
while data.find(b'\r\n\r\n') == -1:
data += await self.loop.sock_recv(sock, DATA_SIZE)
# Strip headers
headers = data[:data.index(b'\r\n\r\n') + 4]
data = data[len(headers):]
size = DATA_SIZE
checker = cycle(b'0123456789')
expected = bytes(islice(checker, len(data)))
self.assertEqual(data, expected)
size -= len(data)
while True:
data = await self.loop.sock_recv(sock, DATA_SIZE)
if not data:
break
expected = bytes(islice(checker, len(data)))
self.assertEqual(data, expected)
size -= len(data)
self.assertEqual(size, 0)
await task
sock.close()
def test_huge_content(self):
with test_utils.run_test_server() as httpd:
self.loop.run_until_complete(
self._basetest_huge_content(httpd.address))
async def _basetest_huge_content_recvinto(self, address):
sock = socket.socket()
sock.setblocking(False)
DATA_SIZE = 10_000_00
chunk = b'0123456789' * (DATA_SIZE // 10)
await self.loop.sock_connect(sock, address)
await self.loop.sock_sendall(sock,
(b'POST /loop HTTP/1.0\r\n' +
b'Content-Length: %d\r\n' % DATA_SIZE +
b'\r\n'))
task = asyncio.create_task(self.loop.sock_sendall(sock, chunk))
array = bytearray(DATA_SIZE)
buf = memoryview(array)
nbytes = await self.loop.sock_recv_into(sock, buf)
data = bytes(buf[:nbytes])
        # The HTTP headers are smaller than the MTU,
        # so they are always sent in the first packet.
self.assertTrue(data.startswith(b'HTTP/1.0 200 OK'))
while data.find(b'\r\n\r\n') == -1:
nbytes = await self.loop.sock_recv_into(sock, buf)
data = bytes(buf[:nbytes])
# Strip headers
headers = data[:data.index(b'\r\n\r\n') + 4]
data = data[len(headers):]
size = DATA_SIZE
checker = cycle(b'0123456789')
expected = bytes(islice(checker, len(data)))
self.assertEqual(data, expected)
size -= len(data)
while True:
nbytes = await self.loop.sock_recv_into(sock, buf)
data = buf[:nbytes]
if not data:
break
expected = bytes(islice(checker, len(data)))
self.assertEqual(data, expected)
size -= len(data)
self.assertEqual(size, 0)
await task
sock.close()
def test_huge_content_recvinto(self):
with test_utils.run_test_server() as httpd:
self.loop.run_until_complete(
self._basetest_huge_content_recvinto(httpd.address))
async def _basetest_datagram_recvfrom(self, server_address):
# Happy path, sock.sendto() returns immediately
data = b'\x01' * 4096
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
sock.setblocking(False)
await self.loop.sock_sendto(sock, data, server_address)
received_data, from_addr = await self.loop.sock_recvfrom(
sock, 4096)
self.assertEqual(received_data, data)
self.assertEqual(from_addr, server_address)
def test_recvfrom(self):
with test_utils.run_udp_echo_server() as server_address:
self.loop.run_until_complete(
self._basetest_datagram_recvfrom(server_address))
async def _basetest_datagram_recvfrom_into(self, server_address):
# Happy path, sock.sendto() returns immediately
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
sock.setblocking(False)
buf = bytearray(4096)
data = b'\x01' * 4096
await self.loop.sock_sendto(sock, data, server_address)
num_bytes, from_addr = await self.loop.sock_recvfrom_into(
sock, buf)
self.assertEqual(num_bytes, 4096)
self.assertEqual(buf, data)
self.assertEqual(from_addr, server_address)
buf = bytearray(8192)
await self.loop.sock_sendto(sock, data, server_address)
num_bytes, from_addr = await self.loop.sock_recvfrom_into(
sock, buf, 4096)
self.assertEqual(num_bytes, 4096)
self.assertEqual(buf[:4096], data[:4096])
self.assertEqual(from_addr, server_address)
def test_recvfrom_into(self):
with test_utils.run_udp_echo_server() as server_address:
self.loop.run_until_complete(
self._basetest_datagram_recvfrom_into(server_address))
async def _basetest_datagram_sendto_blocking(self, server_address):
# Sad path, sock.sendto() raises BlockingIOError
# This involves patching sock.sendto() to raise BlockingIOError but
# sendto() is not used by the proactor event loop
data = b'\x01' * 4096
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
sock.setblocking(False)
mock_sock = Mock(sock)
mock_sock.gettimeout = sock.gettimeout
mock_sock.sendto.configure_mock(side_effect=BlockingIOError)
mock_sock.fileno = sock.fileno
self.loop.call_soon(
lambda: setattr(mock_sock, 'sendto', sock.sendto)
)
await self.loop.sock_sendto(mock_sock, data, server_address)
received_data, from_addr = await self.loop.sock_recvfrom(
sock, 4096)
self.assertEqual(received_data, data)
self.assertEqual(from_addr, server_address)
def test_sendto_blocking(self):
if sys.platform == 'win32':
if isinstance(self.loop, asyncio.ProactorEventLoop):
raise unittest.SkipTest('Not relevant to ProactorEventLoop')
with test_utils.run_udp_echo_server() as server_address:
self.loop.run_until_complete(
self._basetest_datagram_sendto_blocking(server_address))
@socket_helper.skip_unless_bind_unix_socket
def test_unix_sock_client_ops(self):
with test_utils.run_test_unix_server() as httpd:
sock = socket.socket(socket.AF_UNIX)
self._basetest_sock_client_ops(httpd, sock)
sock = socket.socket(socket.AF_UNIX)
self._basetest_sock_recv_into(httpd, sock)
def test_sock_client_fail(self):
# Make sure that we will get an unused port
address = None
try:
s = socket.socket()
s.bind(('127.0.0.1', 0))
address = s.getsockname()
finally:
s.close()
sock = socket.socket()
sock.setblocking(False)
with self.assertRaises(ConnectionRefusedError):
self.loop.run_until_complete(
self.loop.sock_connect(sock, address))
sock.close()
def test_sock_accept(self):
listener = socket.socket()
listener.setblocking(False)
listener.bind(('127.0.0.1', 0))
listener.listen(1)
client = socket.socket()
client.connect(listener.getsockname())
f = self.loop.sock_accept(listener)
conn, addr = self.loop.run_until_complete(f)
self.assertEqual(conn.gettimeout(), 0)
self.assertEqual(addr, client.getsockname())
self.assertEqual(client.getpeername(), listener.getsockname())
client.close()
conn.close()
listener.close()
def test_cancel_sock_accept(self):
listener = socket.socket()
listener.setblocking(False)
listener.bind(('127.0.0.1', 0))
listener.listen(1)
sockaddr = listener.getsockname()
f = asyncio.wait_for(self.loop.sock_accept(listener), 0.1)
with self.assertRaises(asyncio.TimeoutError):
self.loop.run_until_complete(f)
listener.close()
client = socket.socket()
client.setblocking(False)
f = self.loop.sock_connect(client, sockaddr)
with self.assertRaises(ConnectionRefusedError):
self.loop.run_until_complete(f)
client.close()
def test_create_connection_sock(self):
with test_utils.run_test_server() as httpd:
sock = None
infos = self.loop.run_until_complete(
self.loop.getaddrinfo(
*httpd.address, type=socket.SOCK_STREAM))
for family, type, proto, cname, address in infos:
try:
sock = socket.socket(family=family, type=type, proto=proto)
sock.setblocking(False)
self.loop.run_until_complete(
self.loop.sock_connect(sock, address))
except BaseException:
pass
else:
break
else:
self.fail('Can not create socket.')
f = self.loop.create_connection(
lambda: MyProto(loop=self.loop), sock=sock)
tr, pr = self.loop.run_until_complete(f)
self.assertIsInstance(tr, asyncio.Transport)
self.assertIsInstance(pr, asyncio.Protocol)
self.loop.run_until_complete(pr.done)
self.assertGreater(pr.nbytes, 0)
tr.close()
if sys.platform == 'win32':
class SelectEventLoopTests(BaseSockTestsMixin,
test_utils.TestCase):
def create_event_loop(self):
return asyncio.SelectorEventLoop()
class ProactorEventLoopTests(BaseSockTestsMixin,
test_utils.TestCase):
def create_event_loop(self):
return asyncio.ProactorEventLoop()
async def _basetest_datagram_send_to_non_listening_address(self,
recvfrom):
# see:
# https://github.com/python/cpython/issues/91227
# https://github.com/python/cpython/issues/88906
# https://bugs.python.org/issue47071
# https://bugs.python.org/issue44743
# The Proactor event loop would fail to receive datagram messages
# after sending a message to an address that wasn't listening.
def create_socket():
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setblocking(False)
sock.bind(('127.0.0.1', 0))
return sock
socket_1 = create_socket()
addr_1 = socket_1.getsockname()
socket_2 = create_socket()
addr_2 = socket_2.getsockname()
# creating and immediately closing this to try to get an address
# that is not listening
socket_3 = create_socket()
addr_3 = socket_3.getsockname()
socket_3.shutdown(socket.SHUT_RDWR)
socket_3.close()
socket_1_recv_task = self.loop.create_task(recvfrom(socket_1))
socket_2_recv_task = self.loop.create_task(recvfrom(socket_2))
await asyncio.sleep(0)
await self.loop.sock_sendto(socket_1, b'a', addr_2)
self.assertEqual(await socket_2_recv_task, b'a')
await self.loop.sock_sendto(socket_2, b'b', addr_1)
self.assertEqual(await socket_1_recv_task, b'b')
socket_1_recv_task = self.loop.create_task(recvfrom(socket_1))
await asyncio.sleep(0)
# this should send to an address that isn't listening
await self.loop.sock_sendto(socket_1, b'c', addr_3)
self.assertEqual(await socket_1_recv_task, b'')
socket_1_recv_task = self.loop.create_task(recvfrom(socket_1))
await asyncio.sleep(0)
# socket 1 should still be able to receive messages after sending
# to an address that wasn't listening
socket_2.sendto(b'd', addr_1)
self.assertEqual(await socket_1_recv_task, b'd')
socket_1.shutdown(socket.SHUT_RDWR)
socket_1.close()
socket_2.shutdown(socket.SHUT_RDWR)
socket_2.close()
def test_datagram_send_to_non_listening_address_recvfrom(self):
async def recvfrom(socket):
data, _ = await self.loop.sock_recvfrom(socket, 4096)
return data
self.loop.run_until_complete(
self._basetest_datagram_send_to_non_listening_address(
recvfrom))
def test_datagram_send_to_non_listening_address_recvfrom_into(self):
async def recvfrom_into(socket):
buf = bytearray(4096)
length, _ = await self.loop.sock_recvfrom_into(socket, buf,
4096)
return buf[:length]
self.loop.run_until_complete(
self._basetest_datagram_send_to_non_listening_address(
recvfrom_into))
else:
import selectors
if hasattr(selectors, 'KqueueSelector'):
class KqueueEventLoopTests(BaseSockTestsMixin,
test_utils.TestCase):
def create_event_loop(self):
return asyncio.SelectorEventLoop(
selectors.KqueueSelector())
if hasattr(selectors, 'EpollSelector'):
class EPollEventLoopTests(BaseSockTestsMixin,
test_utils.TestCase):
def create_event_loop(self):
return asyncio.SelectorEventLoop(selectors.EpollSelector())
if hasattr(selectors, 'PollSelector'):
class PollEventLoopTests(BaseSockTestsMixin,
test_utils.TestCase):
def create_event_loop(self):
return asyncio.SelectorEventLoop(selectors.PollSelector())
# Should always exist.
class SelectEventLoopTests(BaseSockTestsMixin,
test_utils.TestCase):
def create_event_loop(self):
return asyncio.SelectorEventLoop(selectors.SelectSelector())
if __name__ == '__main__':
unittest.main()
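
These low-level tests drive loop.sock_connect(), sock_sendall(), sock_recv()/sock_recv_into() and the datagram variants directly against non-blocking sockets. A small sketch of that style of usage outside the test harness (the address is a placeholder):

import asyncio
import socket

async def fetch(address):
    loop = asyncio.get_running_loop()
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setblocking(False)         # the sock_* APIs require non-blocking sockets
    try:
        await loop.sock_connect(sock, address)
        await loop.sock_sendall(sock, b'GET / HTTP/1.0\r\n\r\n')
        chunks = []
        while True:
            data = await loop.sock_recv(sock, 4096)
            if not data:
                break
            chunks.append(data)
        return b''.join(chunks)
    finally:
        sock.close()

# asyncio.run(fetch(('127.0.0.1', 8080)))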

File diff suppressed because it is too large

View File

@@ -0,0 +1,843 @@
"""Tests for asyncio/sslproto.py."""
import logging
import socket
import unittest
import weakref
from test import support
from test.support import socket_helper
from unittest import mock
try:
import ssl
except ImportError:
ssl = None
import asyncio
from asyncio import log
from asyncio import protocols
from asyncio import sslproto
from test.test_asyncio import utils as test_utils
from test.test_asyncio import functional as func_tests
def tearDownModule():
asyncio.set_event_loop_policy(None)
@unittest.skipIf(ssl is None, 'No ssl module')
class SslProtoHandshakeTests(test_utils.TestCase):
def setUp(self):
super().setUp()
self.loop = asyncio.new_event_loop()
self.set_event_loop(self.loop)
def ssl_protocol(self, *, waiter=None, proto=None):
sslcontext = test_utils.dummy_ssl_context()
if proto is None: # app protocol
proto = asyncio.Protocol()
ssl_proto = sslproto.SSLProtocol(self.loop, proto, sslcontext, waiter,
ssl_handshake_timeout=0.1)
self.assertIs(ssl_proto._app_transport.get_protocol(), proto)
self.addCleanup(ssl_proto._app_transport.close)
return ssl_proto
def connection_made(self, ssl_proto, *, do_handshake=None):
transport = mock.Mock()
sslobj = mock.Mock()
# emulate reading decompressed data
sslobj.read.side_effect = ssl.SSLWantReadError
sslobj.write.side_effect = ssl.SSLWantReadError
if do_handshake is not None:
sslobj.do_handshake = do_handshake
ssl_proto._sslobj = sslobj
ssl_proto.connection_made(transport)
return transport
def test_handshake_timeout_zero(self):
sslcontext = test_utils.dummy_ssl_context()
app_proto = mock.Mock()
waiter = mock.Mock()
with self.assertRaisesRegex(ValueError, 'a positive number'):
sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
ssl_handshake_timeout=0)
def test_handshake_timeout_negative(self):
sslcontext = test_utils.dummy_ssl_context()
app_proto = mock.Mock()
waiter = mock.Mock()
with self.assertRaisesRegex(ValueError, 'a positive number'):
sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
ssl_handshake_timeout=-10)
def test_eof_received_waiter(self):
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
self.connection_made(
ssl_proto,
do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
)
ssl_proto.eof_received()
test_utils.run_briefly(self.loop)
self.assertIsInstance(waiter.exception(), ConnectionResetError)
def test_fatal_error_no_name_error(self):
# From issue #363.
# _fatal_error() generates a NameError if sslproto.py
# does not import base_events.
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
# Temporarily turn off error logging so as not to spoil test output.
log_level = log.logger.getEffectiveLevel()
log.logger.setLevel(logging.FATAL)
try:
ssl_proto._fatal_error(None)
finally:
# Restore error logging.
log.logger.setLevel(log_level)
def test_connection_lost(self):
# From issue #472.
        # Awaiting the waiter would hang if connection_lost() was called.
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
self.connection_made(
ssl_proto,
do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
)
ssl_proto.connection_lost(ConnectionAbortedError)
test_utils.run_briefly(self.loop)
self.assertIsInstance(waiter.exception(), ConnectionAbortedError)
def test_connection_lost_when_busy(self):
# gh-118950: SSLProtocol.connection_lost not being called when OSError
# is thrown on asyncio.write.
sock = mock.Mock()
sock.fileno = mock.Mock(return_value=12345)
sock.send = mock.Mock(side_effect=BrokenPipeError)
        # Construct a StreamWriter chain that contains loop-dependent logic;
        # this emulates what _make_ssl_transport() does in BaseSelectorEventLoop.
reader = asyncio.StreamReader(limit=2 ** 16, loop=self.loop)
protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop)
ssl_proto = self.ssl_protocol(proto=protocol)
# emulate reading decompressed data
sslobj = mock.Mock()
sslobj.read.side_effect = ssl.SSLWantReadError
sslobj.write.side_effect = ssl.SSLWantReadError
ssl_proto._sslobj = sslobj
# emulate outgoing data
data = b'An interesting message'
outgoing = mock.Mock()
outgoing.read = mock.Mock(return_value=data)
outgoing.pending = len(data)
ssl_proto._outgoing = outgoing
# use correct socket transport to initialize the SSLProtocol
self.loop._make_socket_transport(sock, ssl_proto)
transport = ssl_proto._app_transport
writer = asyncio.StreamWriter(transport, protocol, reader, self.loop)
async def main():
# writes data to transport
async def write():
writer.write(data)
await writer.drain()
# try to write for the first time
await write()
            # the second write raises because the connection_lost() callback
            # should already have been invoked with the error
with self.assertRaises(ConnectionResetError):
await write()
self.loop.run_until_complete(main())
def test_close_during_handshake(self):
# bpo-29743 Closing transport during handshake process leaks socket
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
transport = self.connection_made(
ssl_proto,
do_handshake=mock.Mock(side_effect=ssl.SSLWantReadError)
)
test_utils.run_briefly(self.loop)
ssl_proto._app_transport.close()
self.assertTrue(transport._force_close.called)
def test_close_during_ssl_over_ssl(self):
# gh-113214: passing exceptions from the inner wrapped SSL protocol to the
# shim transport provided by the outer SSL protocol should not raise
# attribute errors
outer = self.ssl_protocol(proto=self.ssl_protocol())
self.connection_made(outer)
# Closing the outer app transport should not raise an exception
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
outer._app_transport.close()
self.assertEqual(messages, [])
def test_get_extra_info_on_closed_connection(self):
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
self.assertIsNone(ssl_proto._get_extra_info('socket'))
default = object()
self.assertIs(ssl_proto._get_extra_info('socket', default), default)
self.connection_made(ssl_proto)
self.assertIsNotNone(ssl_proto._get_extra_info('socket'))
ssl_proto.connection_lost(None)
self.assertIsNone(ssl_proto._get_extra_info('socket'))
def test_set_new_app_protocol(self):
waiter = self.loop.create_future()
ssl_proto = self.ssl_protocol(waiter=waiter)
new_app_proto = asyncio.Protocol()
ssl_proto._app_transport.set_protocol(new_app_proto)
self.assertIs(ssl_proto._app_transport.get_protocol(), new_app_proto)
self.assertIs(ssl_proto._app_protocol, new_app_proto)
def test_data_received_after_closing(self):
ssl_proto = self.ssl_protocol()
self.connection_made(ssl_proto)
transp = ssl_proto._app_transport
transp.close()
# should not raise
self.assertIsNone(ssl_proto.buffer_updated(5))
def test_write_after_closing(self):
ssl_proto = self.ssl_protocol()
self.connection_made(ssl_proto)
transp = ssl_proto._app_transport
transp.close()
# should not raise
self.assertIsNone(transp.write(b'data'))
##############################################################################
# Start TLS Tests
##############################################################################
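# For orientation, the start-TLS flow tested below looks roughly like this
# (a hedged sketch, not code from this module):
#
#     transport, protocol = await loop.create_connection(lambda: proto, host, port)
#     new_transport = await loop.start_tls(transport, proto, client_ssl_context)
#     new_transport.write(data)          # traffic is now encrypted
#
# loop.start_tls() returns a new transport; the old one must not be used afterwards.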
class BaseStartTLS(func_tests.FunctionalTestCaseMixin):
PAYLOAD_SIZE = 1024 * 100
TIMEOUT = support.LONG_TIMEOUT
def new_loop(self):
raise NotImplementedError
def test_buf_feed_data(self):
class Proto(asyncio.BufferedProtocol):
def __init__(self, bufsize, usemv):
self.buf = bytearray(bufsize)
self.mv = memoryview(self.buf)
self.data = b''
self.usemv = usemv
def get_buffer(self, sizehint):
if self.usemv:
return self.mv
else:
return self.buf
def buffer_updated(self, nsize):
if self.usemv:
self.data += self.mv[:nsize]
else:
self.data += self.buf[:nsize]
for usemv in [False, True]:
proto = Proto(1, usemv)
protocols._feed_data_to_buffered_proto(proto, b'12345')
self.assertEqual(proto.data, b'12345')
proto = Proto(2, usemv)
protocols._feed_data_to_buffered_proto(proto, b'12345')
self.assertEqual(proto.data, b'12345')
proto = Proto(2, usemv)
protocols._feed_data_to_buffered_proto(proto, b'1234')
self.assertEqual(proto.data, b'1234')
proto = Proto(4, usemv)
protocols._feed_data_to_buffered_proto(proto, b'1234')
self.assertEqual(proto.data, b'1234')
proto = Proto(100, usemv)
protocols._feed_data_to_buffered_proto(proto, b'12345')
self.assertEqual(proto.data, b'12345')
proto = Proto(0, usemv)
with self.assertRaisesRegex(RuntimeError, 'empty buffer'):
protocols._feed_data_to_buffered_proto(proto, b'12345')
def test_start_tls_client_reg_proto_1(self):
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext()
def serve(sock):
sock.settimeout(self.TIMEOUT)
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
sock.start_tls(server_context, server_side=True)
sock.sendall(b'O')
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
sock.shutdown(socket.SHUT_RDWR)
sock.close()
class ClientProto(asyncio.Protocol):
def __init__(self, on_data, on_eof):
self.on_data = on_data
self.on_eof = on_eof
self.con_made_cnt = 0
def connection_made(proto, tr):
proto.con_made_cnt += 1
# Ensure connection_made gets called only once.
self.assertEqual(proto.con_made_cnt, 1)
def data_received(self, data):
self.on_data.set_result(data)
def eof_received(self):
self.on_eof.set_result(True)
async def client(addr):
await asyncio.sleep(0.5)
on_data = self.loop.create_future()
on_eof = self.loop.create_future()
tr, proto = await self.loop.create_connection(
lambda: ClientProto(on_data, on_eof), *addr)
tr.write(HELLO_MSG)
new_tr = await self.loop.start_tls(tr, proto, client_context)
self.assertEqual(await on_data, b'O')
new_tr.write(HELLO_MSG)
await on_eof
new_tr.close()
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr),
timeout=support.SHORT_TIMEOUT))
# No garbage is left if SSL is closed uncleanly
client_context = weakref.ref(client_context)
support.gc_collect()
self.assertIsNone(client_context())
def test_create_connection_memory_leak(self):
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext()
def serve(sock):
sock.settimeout(self.TIMEOUT)
sock.start_tls(server_context, server_side=True)
sock.sendall(b'O')
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
sock.shutdown(socket.SHUT_RDWR)
sock.close()
class ClientProto(asyncio.Protocol):
def __init__(self, on_data, on_eof):
self.on_data = on_data
self.on_eof = on_eof
self.con_made_cnt = 0
def connection_made(proto, tr):
# XXX: We assume user stores the transport in protocol
proto.tr = tr
proto.con_made_cnt += 1
# Ensure connection_made gets called only once.
self.assertEqual(proto.con_made_cnt, 1)
def data_received(self, data):
self.on_data.set_result(data)
def eof_received(self):
self.on_eof.set_result(True)
async def client(addr):
await asyncio.sleep(0.5)
on_data = self.loop.create_future()
on_eof = self.loop.create_future()
tr, proto = await self.loop.create_connection(
lambda: ClientProto(on_data, on_eof), *addr,
ssl=client_context)
self.assertEqual(await on_data, b'O')
tr.write(HELLO_MSG)
await on_eof
tr.close()
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr),
timeout=support.SHORT_TIMEOUT))
# No garbage is left for SSL client from loop.create_connection, even
# if user stores the SSLTransport in corresponding protocol instance
client_context = weakref.ref(client_context)
support.gc_collect()
self.assertIsNone(client_context())
@socket_helper.skip_if_tcp_blackhole
def test_start_tls_client_buf_proto_1(self):
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext()
client_con_made_calls = 0
def serve(sock):
sock.settimeout(self.TIMEOUT)
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
sock.start_tls(server_context, server_side=True)
sock.sendall(b'O')
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
sock.sendall(b'2')
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
sock.shutdown(socket.SHUT_RDWR)
sock.close()
class ClientProtoFirst(asyncio.BufferedProtocol):
def __init__(self, on_data):
self.on_data = on_data
self.buf = bytearray(1)
def connection_made(self, tr):
nonlocal client_con_made_calls
client_con_made_calls += 1
def get_buffer(self, sizehint):
return self.buf
def buffer_updated(slf, nsize):
self.assertEqual(nsize, 1)
slf.on_data.set_result(bytes(slf.buf[:nsize]))
class ClientProtoSecond(asyncio.Protocol):
def __init__(self, on_data, on_eof):
self.on_data = on_data
self.on_eof = on_eof
self.con_made_cnt = 0
def connection_made(self, tr):
nonlocal client_con_made_calls
client_con_made_calls += 1
def data_received(self, data):
self.on_data.set_result(data)
def eof_received(self):
self.on_eof.set_result(True)
async def client(addr):
await asyncio.sleep(0.5)
on_data1 = self.loop.create_future()
on_data2 = self.loop.create_future()
on_eof = self.loop.create_future()
tr, proto = await self.loop.create_connection(
lambda: ClientProtoFirst(on_data1), *addr)
tr.write(HELLO_MSG)
new_tr = await self.loop.start_tls(tr, proto, client_context)
self.assertEqual(await on_data1, b'O')
new_tr.write(HELLO_MSG)
new_tr.set_protocol(ClientProtoSecond(on_data2, on_eof))
self.assertEqual(await on_data2, b'2')
new_tr.write(HELLO_MSG)
await on_eof
new_tr.close()
# connection_made() should be called only once -- when
# we establish connection for the first time. Start TLS
# doesn't call connection_made() on application protocols.
self.assertEqual(client_con_made_calls, 1)
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr),
timeout=self.TIMEOUT))
def test_start_tls_slow_client_cancel(self):
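        # The server reads the plaintext hello but never begins a TLS handshake,
        # so the client's start_tls() cannot complete and must be cancelled
        # cleanly by the 0.5s wait_for() timeout below.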
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
client_context = test_utils.simple_client_sslcontext()
server_waits_on_handshake = self.loop.create_future()
def serve(sock):
sock.settimeout(self.TIMEOUT)
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
try:
self.loop.call_soon_threadsafe(
server_waits_on_handshake.set_result, None)
data = sock.recv_all(1024 * 1024)
except ConnectionAbortedError:
pass
finally:
sock.close()
class ClientProto(asyncio.Protocol):
def __init__(self, on_data, on_eof):
self.on_data = on_data
self.on_eof = on_eof
self.con_made_cnt = 0
def connection_made(proto, tr):
proto.con_made_cnt += 1
# Ensure connection_made gets called only once.
self.assertEqual(proto.con_made_cnt, 1)
def data_received(self, data):
self.on_data.set_result(data)
def eof_received(self):
self.on_eof.set_result(True)
async def client(addr):
await asyncio.sleep(0.5)
on_data = self.loop.create_future()
on_eof = self.loop.create_future()
tr, proto = await self.loop.create_connection(
lambda: ClientProto(on_data, on_eof), *addr)
tr.write(HELLO_MSG)
await server_waits_on_handshake
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(
self.loop.start_tls(tr, proto, client_context),
0.5)
with self.tcp_server(serve, timeout=self.TIMEOUT) as srv:
self.loop.run_until_complete(
asyncio.wait_for(client(srv.addr),
timeout=support.SHORT_TIMEOUT))
@socket_helper.skip_if_tcp_blackhole
def test_start_tls_server_1(self):
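        # Server-side upgrade: the asyncio side accepts a plain connection, sends
        # HELLO_MSG in clear text, then upgrades with start_tls(server_side=True);
        # the threaded client mirrors the upgrade and expects ANSWER over TLS.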
HELLO_MSG = b'1' * self.PAYLOAD_SIZE
ANSWER = b'answer'
server_context = test_utils.simple_server_sslcontext()
client_context = test_utils.simple_client_sslcontext()
answer = None
def client(sock, addr):
nonlocal answer
sock.settimeout(self.TIMEOUT)
sock.connect(addr)
data = sock.recv_all(len(HELLO_MSG))
self.assertEqual(len(data), len(HELLO_MSG))
sock.start_tls(client_context)
sock.sendall(HELLO_MSG)
answer = sock.recv_all(len(ANSWER))
sock.close()
class ServerProto(asyncio.Protocol):
def __init__(self, on_con, on_con_lost, on_got_hello):
self.on_con = on_con
self.on_con_lost = on_con_lost
self.on_got_hello = on_got_hello
self.data = b''
self.transport = None
def connection_made(self, tr):
self.transport = tr
self.on_con.set_result(tr)
def replace_transport(self, tr):
self.transport = tr
def data_received(self, data):
self.data += data
if len(self.data) >= len(HELLO_MSG):
self.on_got_hello.set_result(None)
def connection_lost(self, exc):
self.transport = None
if exc is None:
self.on_con_lost.set_result(None)
else:
self.on_con_lost.set_exception(exc)
async def main(proto, on_con, on_con_lost, on_got_hello):
tr = await on_con
tr.write(HELLO_MSG)
self.assertEqual(proto.data, b'')
new_tr = await self.loop.start_tls(
tr, proto, server_context,
server_side=True,
ssl_handshake_timeout=self.TIMEOUT)
proto.replace_transport(new_tr)
await on_got_hello
new_tr.write(ANSWER)
await on_con_lost
self.assertEqual(proto.data, HELLO_MSG)
new_tr.close()
async def run_main():
on_con = self.loop.create_future()
on_con_lost = self.loop.create_future()
on_got_hello = self.loop.create_future()
proto = ServerProto(on_con, on_con_lost, on_got_hello)
server = await self.loop.create_server(
lambda: proto, '127.0.0.1', 0)
addr = server.sockets[0].getsockname()
with self.tcp_client(lambda sock: client(sock, addr),
timeout=self.TIMEOUT):
await asyncio.wait_for(
main(proto, on_con, on_con_lost, on_got_hello),
timeout=self.TIMEOUT)
server.close()
await server.wait_closed()
self.assertEqual(answer, ANSWER)
self.loop.run_until_complete(run_main())
def test_start_tls_wrong_args(self):
async def main():
with self.assertRaisesRegex(TypeError, 'SSLContext, got'):
await self.loop.start_tls(None, None, None)
sslctx = test_utils.simple_server_sslcontext()
with self.assertRaisesRegex(TypeError, 'is not supported'):
await self.loop.start_tls(None, None, sslctx)
self.loop.run_until_complete(main())
def test_handshake_timeout(self):
# bpo-29970: Check that a connection is aborted if handshake is not
# completed in timeout period, instead of remaining open indefinitely
client_sslctx = test_utils.simple_client_sslcontext()
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
server_side_aborted = False
def server(sock):
nonlocal server_side_aborted
try:
sock.recv_all(1024 * 1024)
except ConnectionAbortedError:
server_side_aborted = True
finally:
sock.close()
async def client(addr):
await asyncio.wait_for(
self.loop.create_connection(
asyncio.Protocol,
*addr,
ssl=client_sslctx,
server_hostname='',
ssl_handshake_timeout=support.SHORT_TIMEOUT),
0.5)
with self.tcp_server(server,
max_clients=1,
backlog=1) as srv:
with self.assertRaises(asyncio.TimeoutError):
self.loop.run_until_complete(client(srv.addr))
self.assertTrue(server_side_aborted)
# Python issue #23197: cancelling a handshake must not raise an
# exception or log an error, even if the handshake failed
self.assertEqual(messages, [])
        # The pending handshake timeout should be cancelled to free related
        # objects without actually waiting for it to expire
client_sslctx = weakref.ref(client_sslctx)
support.gc_collect()
self.assertIsNone(client_sslctx())
def test_create_connection_ssl_slow_handshake(self):
client_sslctx = test_utils.simple_client_sslcontext()
messages = []
self.loop.set_exception_handler(lambda loop, ctx: messages.append(ctx))
def server(sock):
try:
sock.recv_all(1024 * 1024)
except ConnectionAbortedError:
pass
finally:
sock.close()
async def client(addr):
reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='',
ssl_handshake_timeout=1.0)
with self.tcp_server(server,
max_clients=1,
backlog=1) as srv:
with self.assertRaisesRegex(
ConnectionAbortedError,
r'SSL handshake.*is taking longer'):
self.loop.run_until_complete(client(srv.addr))
self.assertEqual(messages, [])
def test_create_connection_ssl_failed_certificate(self):
self.loop.set_exception_handler(lambda loop, ctx: None)
sslctx = test_utils.simple_server_sslcontext()
client_sslctx = test_utils.simple_client_sslcontext(
disable_verify=False)
def server(sock):
try:
sock.start_tls(
sslctx,
server_side=True)
except ssl.SSLError:
pass
except OSError:
pass
finally:
sock.close()
async def client(addr):
reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='',
ssl_handshake_timeout=support.LOOPBACK_TIMEOUT)
with self.tcp_server(server,
max_clients=1,
backlog=1) as srv:
with self.assertRaises(ssl.SSLCertVerificationError):
self.loop.run_until_complete(client(srv.addr))
def test_start_tls_client_corrupted_ssl(self):
self.loop.set_exception_handler(lambda loop, ctx: None)
sslctx = test_utils.simple_server_sslcontext()
client_sslctx = test_utils.simple_client_sslcontext()
def server(sock):
orig_sock = sock.dup()
try:
sock.start_tls(
sslctx,
server_side=True)
sock.sendall(b'A\n')
sock.recv_all(1)
orig_sock.send(b'please corrupt the SSL connection')
except ssl.SSLError:
pass
finally:
orig_sock.close()
sock.close()
async def client(addr):
reader, writer = await asyncio.open_connection(
*addr,
ssl=client_sslctx,
server_hostname='')
self.assertEqual(await reader.readline(), b'A\n')
writer.write(b'B')
with self.assertRaises(ssl.SSLError):
await reader.readline()
writer.close()
return 'OK'
with self.tcp_server(server,
max_clients=1,
backlog=1) as srv:
res = self.loop.run_until_complete(client(srv.addr))
self.assertEqual(res, 'OK')
@unittest.skipIf(ssl is None, 'No ssl module')
class SelectorStartTLSTests(BaseStartTLS, unittest.TestCase):
def new_loop(self):
return asyncio.SelectorEventLoop()
@unittest.skipIf(ssl is None, 'No ssl module')
@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class ProactorStartTLSTests(BaseStartTLS, unittest.TestCase):
def new_loop(self):
return asyncio.ProactorEventLoop()
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,151 @@
import asyncio
import unittest
from asyncio.staggered import staggered_race
from test import support
support.requires_working_socket(module=True)
def tearDownModule():
asyncio.set_event_loop_policy(None)
class StaggeredTests(unittest.IsolatedAsyncioTestCase):
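    # staggered_race() returns a (winner_result, winner_index, exceptions) triple;
    # the tests below exercise the empty, success, failure and cancellation paths.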
async def test_empty(self):
winner, index, excs = await staggered_race(
[],
delay=None,
)
self.assertIs(winner, None)
self.assertIs(index, None)
self.assertEqual(excs, [])
async def test_one_successful(self):
async def coro(index):
return f'Res: {index}'
winner, index, excs = await staggered_race(
[
lambda: coro(0),
lambda: coro(1),
],
delay=None,
)
self.assertEqual(winner, 'Res: 0')
self.assertEqual(index, 0)
self.assertEqual(excs, [None])
async def test_first_error_second_successful(self):
async def coro(index):
if index == 0:
raise ValueError(index)
return f'Res: {index}'
winner, index, excs = await staggered_race(
[
lambda: coro(0),
lambda: coro(1),
],
delay=None,
)
self.assertEqual(winner, 'Res: 1')
self.assertEqual(index, 1)
self.assertEqual(len(excs), 2)
self.assertIsInstance(excs[0], ValueError)
self.assertIs(excs[1], None)
async def test_first_timeout_second_successful(self):
async def coro(index):
if index == 0:
await asyncio.sleep(10) # much bigger than delay
return f'Res: {index}'
winner, index, excs = await staggered_race(
[
lambda: coro(0),
lambda: coro(1),
],
delay=0.1,
)
self.assertEqual(winner, 'Res: 1')
self.assertEqual(index, 1)
self.assertEqual(len(excs), 2)
self.assertIsInstance(excs[0], asyncio.CancelledError)
self.assertIs(excs[1], None)
async def test_none_successful(self):
async def coro(index):
raise ValueError(index)
winner, index, excs = await staggered_race(
[
lambda: coro(0),
lambda: coro(1),
],
delay=None,
)
self.assertIs(winner, None)
self.assertIs(index, None)
self.assertEqual(len(excs), 2)
self.assertIsInstance(excs[0], ValueError)
self.assertIsInstance(excs[1], ValueError)
async def test_multiple_winners(self):
event = asyncio.Event()
async def coro(index):
await event.wait()
return index
async def do_set():
event.set()
await asyncio.Event().wait()
winner, index, excs = await staggered_race(
[
lambda: coro(0),
lambda: coro(1),
do_set,
],
delay=0.1,
)
self.assertIs(winner, 0)
self.assertIs(index, 0)
self.assertEqual(len(excs), 3)
self.assertIsNone(excs[0], None)
self.assertIsInstance(excs[1], asyncio.CancelledError)
self.assertIsInstance(excs[2], asyncio.CancelledError)
async def test_cancelled(self):
log = []
with self.assertRaises(TimeoutError):
async with asyncio.timeout(None) as cs_outer, asyncio.timeout(None) as cs_inner:
async def coro_fn():
cs_inner.reschedule(-1)
await asyncio.sleep(0)
try:
await asyncio.sleep(0)
except asyncio.CancelledError:
log.append("cancelled 1")
cs_outer.reschedule(-1)
await asyncio.sleep(0)
try:
await asyncio.sleep(0)
except asyncio.CancelledError:
log.append("cancelled 2")
try:
await staggered_race([coro_fn], delay=None)
except asyncio.CancelledError:
log.append("cancelled 3")
raise
self.assertListEqual(log, ["cancelled 1", "cancelled 2", "cancelled 3"])

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,66 @@
"""Tests for asyncio/threads.py"""
import asyncio
import unittest
from contextvars import ContextVar
from unittest import mock
def tearDownModule():
asyncio.set_event_loop_policy(None)
class ToThreadTests(unittest.IsolatedAsyncioTestCase):
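    # asyncio.to_thread() runs a blocking callable in a worker thread and, unlike
    # run_in_executor(), accepts keyword arguments directly and propagates the
    # caller's contextvars.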
async def test_to_thread(self):
result = await asyncio.to_thread(sum, [40, 2])
self.assertEqual(result, 42)
async def test_to_thread_exception(self):
def raise_runtime():
raise RuntimeError("test")
with self.assertRaisesRegex(RuntimeError, "test"):
await asyncio.to_thread(raise_runtime)
async def test_to_thread_once(self):
func = mock.Mock()
await asyncio.to_thread(func)
func.assert_called_once()
async def test_to_thread_concurrent(self):
calls = []
def func():
calls.append(1)
futs = []
for _ in range(10):
fut = asyncio.to_thread(func)
futs.append(fut)
await asyncio.gather(*futs)
self.assertEqual(sum(calls), 10)
async def test_to_thread_args_kwargs(self):
# Unlike run_in_executor(), to_thread() should directly accept kwargs.
func = mock.Mock()
await asyncio.to_thread(func, 'test', something=True)
func.assert_called_once_with('test', something=True)
async def test_to_thread_contextvars(self):
test_ctx = ContextVar('test_ctx')
def get_ctx():
return test_ctx.get()
test_ctx.set('parrot')
result = await asyncio.to_thread(get_ctx)
self.assertEqual(result, 'parrot')
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,411 @@
"""Tests for asyncio/timeouts.py"""
import unittest
import time
import asyncio
from test.test_asyncio.utils import await_without_task
def tearDownModule():
asyncio.set_event_loop_policy(None)
class TimeoutTests(unittest.IsolatedAsyncioTestCase):
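    # asyncio.timeout()/timeout_at() cancel the enclosing task when the deadline
    # passes and convert that CancelledError into TimeoutError on block exit.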
async def test_timeout_basic(self):
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.01) as cm:
await asyncio.sleep(10)
self.assertTrue(cm.expired())
async def test_timeout_at_basic(self):
loop = asyncio.get_running_loop()
with self.assertRaises(TimeoutError):
deadline = loop.time() + 0.01
async with asyncio.timeout_at(deadline) as cm:
await asyncio.sleep(10)
self.assertTrue(cm.expired())
self.assertEqual(deadline, cm.when())
async def test_nested_timeouts(self):
loop = asyncio.get_running_loop()
cancelled = False
with self.assertRaises(TimeoutError):
deadline = loop.time() + 0.01
async with asyncio.timeout_at(deadline) as cm1:
# Only the topmost context manager should raise TimeoutError
try:
async with asyncio.timeout_at(deadline) as cm2:
await asyncio.sleep(10)
except asyncio.CancelledError:
cancelled = True
raise
self.assertTrue(cancelled)
self.assertTrue(cm1.expired())
self.assertTrue(cm2.expired())
async def test_waiter_cancelled(self):
cancelled = False
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.01):
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
cancelled = True
raise
self.assertTrue(cancelled)
async def test_timeout_not_called(self):
loop = asyncio.get_running_loop()
async with asyncio.timeout(10) as cm:
await asyncio.sleep(0.01)
t1 = loop.time()
self.assertFalse(cm.expired())
self.assertGreater(cm.when(), t1)
async def test_timeout_disabled(self):
async with asyncio.timeout(None) as cm:
await asyncio.sleep(0.01)
self.assertFalse(cm.expired())
self.assertIsNone(cm.when())
async def test_timeout_at_disabled(self):
async with asyncio.timeout_at(None) as cm:
await asyncio.sleep(0.01)
self.assertFalse(cm.expired())
self.assertIsNone(cm.when())
async def test_timeout_zero(self):
loop = asyncio.get_running_loop()
t0 = loop.time()
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0) as cm:
await asyncio.sleep(10)
t1 = loop.time()
self.assertTrue(cm.expired())
self.assertTrue(t0 <= cm.when() <= t1)
async def test_timeout_zero_sleep_zero(self):
loop = asyncio.get_running_loop()
t0 = loop.time()
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0) as cm:
await asyncio.sleep(0)
t1 = loop.time()
self.assertTrue(cm.expired())
self.assertTrue(t0 <= cm.when() <= t1)
async def test_timeout_in_the_past_sleep_zero(self):
loop = asyncio.get_running_loop()
t0 = loop.time()
with self.assertRaises(TimeoutError):
async with asyncio.timeout(-11) as cm:
await asyncio.sleep(0)
t1 = loop.time()
self.assertTrue(cm.expired())
self.assertTrue(t0 >= cm.when() <= t1)
async def test_foreign_exception_passed(self):
with self.assertRaises(KeyError):
async with asyncio.timeout(0.01) as cm:
raise KeyError
self.assertFalse(cm.expired())
async def test_timeout_exception_context(self):
with self.assertRaises(TimeoutError) as cm:
async with asyncio.timeout(0.01):
try:
1/0
finally:
await asyncio.sleep(1)
e = cm.exception
# Expect TimeoutError caused by CancelledError raised during handling
# of ZeroDivisionError.
e2 = e.__cause__
self.assertIsInstance(e2, asyncio.CancelledError)
self.assertIs(e.__context__, e2)
self.assertIsNone(e2.__cause__)
self.assertIsInstance(e2.__context__, ZeroDivisionError)
async def test_foreign_exception_on_timeout(self):
async def crash():
try:
await asyncio.sleep(1)
finally:
1/0
with self.assertRaises(ZeroDivisionError) as cm:
async with asyncio.timeout(0.01):
await crash()
e = cm.exception
# Expect ZeroDivisionError raised during handling of TimeoutError
# caused by CancelledError.
self.assertIsNone(e.__cause__)
e2 = e.__context__
self.assertIsInstance(e2, TimeoutError)
e3 = e2.__cause__
self.assertIsInstance(e3, asyncio.CancelledError)
self.assertIs(e2.__context__, e3)
async def test_foreign_exception_on_timeout_2(self):
with self.assertRaises(ZeroDivisionError) as cm:
async with asyncio.timeout(0.01):
try:
try:
raise ValueError
finally:
await asyncio.sleep(1)
finally:
try:
raise KeyError
finally:
1/0
e = cm.exception
# Expect ZeroDivisionError raised during handling of KeyError
# raised during handling of TimeoutError caused by CancelledError.
self.assertIsNone(e.__cause__)
e2 = e.__context__
self.assertIsInstance(e2, KeyError)
self.assertIsNone(e2.__cause__)
e3 = e2.__context__
self.assertIsInstance(e3, TimeoutError)
e4 = e3.__cause__
self.assertIsInstance(e4, asyncio.CancelledError)
self.assertIsNone(e4.__cause__)
self.assertIsInstance(e4.__context__, ValueError)
self.assertIs(e3.__context__, e4)
async def test_foreign_cancel_doesnt_timeout_if_not_expired(self):
with self.assertRaises(asyncio.CancelledError):
async with asyncio.timeout(10) as cm:
asyncio.current_task().cancel()
await asyncio.sleep(10)
self.assertFalse(cm.expired())
async def test_outer_task_is_not_cancelled(self):
async def outer() -> None:
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.001):
await asyncio.sleep(10)
task = asyncio.create_task(outer())
await task
self.assertFalse(task.cancelled())
self.assertTrue(task.done())
async def test_nested_timeouts_concurrent(self):
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.002):
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.1):
# Pretend we crunch some numbers.
time.sleep(0.01)
await asyncio.sleep(1)
async def test_nested_timeouts_loop_busy(self):
# After the inner timeout is an expensive operation which should
# be stopped by the outer timeout.
loop = asyncio.get_running_loop()
# Disable a message about long running task
loop.slow_callback_duration = 10
t0 = loop.time()
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.1): # (1)
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.01): # (2)
# Pretend the loop is busy for a while.
time.sleep(0.1)
await asyncio.sleep(1)
                # TimeoutError was caught by (2)
await asyncio.sleep(10) # This sleep should be interrupted by (1)
t1 = loop.time()
self.assertTrue(t0 <= t1 <= t0 + 1)
async def test_reschedule(self):
loop = asyncio.get_running_loop()
fut = loop.create_future()
deadline1 = loop.time() + 10
deadline2 = deadline1 + 20
async def f():
async with asyncio.timeout_at(deadline1) as cm:
fut.set_result(cm)
await asyncio.sleep(50)
task = asyncio.create_task(f())
cm = await fut
self.assertEqual(cm.when(), deadline1)
cm.reschedule(deadline2)
self.assertEqual(cm.when(), deadline2)
cm.reschedule(None)
self.assertIsNone(cm.when())
task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task
self.assertFalse(cm.expired())
async def test_repr_active(self):
async with asyncio.timeout(10) as cm:
self.assertRegex(repr(cm), r"<Timeout \[active\] when=\d+\.\d*>")
async def test_repr_expired(self):
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.01) as cm:
await asyncio.sleep(10)
self.assertEqual(repr(cm), "<Timeout [expired]>")
async def test_repr_finished(self):
async with asyncio.timeout(10) as cm:
await asyncio.sleep(0)
self.assertEqual(repr(cm), "<Timeout [finished]>")
async def test_repr_disabled(self):
async with asyncio.timeout(None) as cm:
self.assertEqual(repr(cm), r"<Timeout [active] when=None>")
async def test_nested_timeout_in_finally(self):
with self.assertRaises(TimeoutError) as cm1:
async with asyncio.timeout(0.01):
try:
await asyncio.sleep(1)
finally:
with self.assertRaises(TimeoutError) as cm2:
async with asyncio.timeout(0.01):
await asyncio.sleep(10)
e1 = cm1.exception
# Expect TimeoutError caused by CancelledError.
e12 = e1.__cause__
self.assertIsInstance(e12, asyncio.CancelledError)
self.assertIsNone(e12.__cause__)
self.assertIsNone(e12.__context__)
self.assertIs(e1.__context__, e12)
e2 = cm2.exception
# Expect TimeoutError caused by CancelledError raised during
# handling of other CancelledError (which is the same as in
# the above chain).
e22 = e2.__cause__
self.assertIsInstance(e22, asyncio.CancelledError)
self.assertIsNone(e22.__cause__)
self.assertIs(e22.__context__, e12)
self.assertIs(e2.__context__, e22)
async def test_timeout_after_cancellation(self):
try:
asyncio.current_task().cancel()
await asyncio.sleep(1) # work which will be cancelled
except asyncio.CancelledError:
pass
finally:
with self.assertRaises(TimeoutError) as cm:
async with asyncio.timeout(0.0):
await asyncio.sleep(1) # some cleanup
async def test_cancel_in_timeout_after_cancellation(self):
try:
asyncio.current_task().cancel()
await asyncio.sleep(1) # work which will be cancelled
except asyncio.CancelledError:
pass
finally:
with self.assertRaises(asyncio.CancelledError):
async with asyncio.timeout(1.0):
asyncio.current_task().cancel()
await asyncio.sleep(2) # some cleanup
async def test_timeout_already_entered(self):
async with asyncio.timeout(0.01) as cm:
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with cm:
pass
async def test_timeout_double_enter(self):
async with asyncio.timeout(0.01) as cm:
pass
with self.assertRaisesRegex(RuntimeError, "has already been entered"):
async with cm:
pass
async def test_timeout_finished(self):
async with asyncio.timeout(0.01) as cm:
pass
with self.assertRaisesRegex(RuntimeError, "finished"):
cm.reschedule(0.02)
async def test_timeout_expired(self):
with self.assertRaises(TimeoutError):
async with asyncio.timeout(0.01) as cm:
await asyncio.sleep(1)
with self.assertRaisesRegex(RuntimeError, "expired"):
cm.reschedule(0.02)
async def test_timeout_expiring(self):
async with asyncio.timeout(0.01) as cm:
with self.assertRaises(asyncio.CancelledError):
await asyncio.sleep(1)
with self.assertRaisesRegex(RuntimeError, "expiring"):
cm.reschedule(0.02)
async def test_timeout_not_entered(self):
cm = asyncio.timeout(0.01)
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
cm.reschedule(0.02)
async def test_timeout_without_task(self):
cm = asyncio.timeout(0.01)
with self.assertRaisesRegex(RuntimeError, "task"):
await await_without_task(cm.__aenter__())
with self.assertRaisesRegex(RuntimeError, "has not been entered"):
cm.reschedule(0.02)
async def test_timeout_taskgroup(self):
async def task():
try:
await asyncio.sleep(2) # Will be interrupted after 0.01 second
finally:
1/0 # Crash in cleanup
with self.assertRaises(ExceptionGroup) as cm:
async with asyncio.timeout(0.01):
async with asyncio.TaskGroup() as tg:
tg.create_task(task())
try:
raise ValueError
finally:
await asyncio.sleep(1)
eg = cm.exception
# Expect ExceptionGroup raised during handling of TimeoutError caused
# by CancelledError raised during handling of ValueError.
self.assertIsNone(eg.__cause__)
e_1 = eg.__context__
self.assertIsInstance(e_1, TimeoutError)
e_2 = e_1.__cause__
self.assertIsInstance(e_2, asyncio.CancelledError)
self.assertIsNone(e_2.__cause__)
self.assertIsInstance(e_2.__context__, ValueError)
self.assertIs(e_1.__context__, e_2)
self.assertEqual(len(eg.exceptions), 1, eg)
e1 = eg.exceptions[0]
# Expect ZeroDivisionError raised during handling of TimeoutError
# caused by CancelledError (it is a different CancelledError).
self.assertIsInstance(e1, ZeroDivisionError)
self.assertIsNone(e1.__cause__)
e2 = e1.__context__
self.assertIsInstance(e2, TimeoutError)
e3 = e2.__cause__
self.assertIsInstance(e3, asyncio.CancelledError)
self.assertIsNone(e3.__context__)
self.assertIsNone(e3.__cause__)
self.assertIs(e2.__context__, e3)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,103 @@
"""Tests for transports.py."""
import unittest
from unittest import mock
import asyncio
from asyncio import transports
def tearDownModule():
    # not needed for this test file but added for uniformity with the other
    # asyncio test files for the sake of unified cleanup
asyncio.set_event_loop_policy(None)
class TransportTests(unittest.TestCase):
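    # The base transport classes are abstract interfaces: their I/O methods raise
    # NotImplementedError until a concrete event loop supplies implementations.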
def test_ctor_extra_is_none(self):
transport = asyncio.Transport()
self.assertEqual(transport._extra, {})
def test_get_extra_info(self):
transport = asyncio.Transport({'extra': 'info'})
self.assertEqual('info', transport.get_extra_info('extra'))
self.assertIsNone(transport.get_extra_info('unknown'))
default = object()
self.assertIs(default, transport.get_extra_info('unknown', default))
def test_writelines(self):
writer = mock.Mock()
class MyTransport(asyncio.Transport):
def write(self, data):
writer(data)
transport = MyTransport()
transport.writelines([b'line1',
bytearray(b'line2'),
memoryview(b'line3')])
self.assertEqual(1, writer.call_count)
writer.assert_called_with(b'line1line2line3')
def test_not_implemented(self):
transport = asyncio.Transport()
self.assertRaises(NotImplementedError,
transport.set_write_buffer_limits)
self.assertRaises(NotImplementedError, transport.get_write_buffer_size)
self.assertRaises(NotImplementedError, transport.write, 'data')
self.assertRaises(NotImplementedError, transport.write_eof)
self.assertRaises(NotImplementedError, transport.can_write_eof)
self.assertRaises(NotImplementedError, transport.pause_reading)
self.assertRaises(NotImplementedError, transport.resume_reading)
self.assertRaises(NotImplementedError, transport.is_reading)
self.assertRaises(NotImplementedError, transport.close)
self.assertRaises(NotImplementedError, transport.abort)
def test_dgram_not_implemented(self):
transport = asyncio.DatagramTransport()
self.assertRaises(NotImplementedError, transport.sendto, 'data')
self.assertRaises(NotImplementedError, transport.abort)
def test_subprocess_transport_not_implemented(self):
transport = asyncio.SubprocessTransport()
self.assertRaises(NotImplementedError, transport.get_pid)
self.assertRaises(NotImplementedError, transport.get_returncode)
self.assertRaises(NotImplementedError, transport.get_pipe_transport, 1)
self.assertRaises(NotImplementedError, transport.send_signal, 1)
self.assertRaises(NotImplementedError, transport.terminate)
self.assertRaises(NotImplementedError, transport.kill)
def test_flowcontrol_mixin_set_write_limits(self):
class MyTransport(transports._FlowControlMixin,
transports.Transport):
def get_write_buffer_size(self):
return 512
loop = mock.Mock()
transport = MyTransport(loop=loop)
transport._protocol = mock.Mock()
self.assertFalse(transport._protocol_paused)
with self.assertRaisesRegex(ValueError, 'high.*must be >= low'):
transport.set_write_buffer_limits(high=0, low=1)
transport.set_write_buffer_limits(high=1024, low=128)
self.assertFalse(transport._protocol_paused)
self.assertEqual(transport.get_write_buffer_limits(), (128, 1024))
transport.set_write_buffer_limits(high=256, low=128)
self.assertTrue(transport._protocol_paused)
self.assertEqual(transport.get_write_buffer_limits(), (128, 256))
if __name__ == '__main__':
unittest.main()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,353 @@
import asyncio
import unittest
import time
from test import support
def tearDownModule():
asyncio.set_event_loop_policy(None)
# The following value can be used as a very small timeout:
# it passes check "timeout > 0", but has almost
# no effect on the test performance
_EPSILON = 0.0001
class SlowTask:
""" Task will run for this defined time, ignoring cancel requests """
TASK_TIMEOUT = 0.2
def __init__(self):
self.exited = False
async def run(self):
exitat = time.monotonic() + self.TASK_TIMEOUT
while True:
tosleep = exitat - time.monotonic()
if tosleep <= 0:
break
try:
await asyncio.sleep(tosleep)
except asyncio.CancelledError:
pass
self.exited = True
class AsyncioWaitForTest(unittest.IsolatedAsyncioTestCase):
async def test_asyncio_wait_for_cancelled(self):
t = SlowTask()
waitfortask = asyncio.create_task(
asyncio.wait_for(t.run(), t.TASK_TIMEOUT * 2))
await asyncio.sleep(0)
waitfortask.cancel()
await asyncio.wait({waitfortask})
self.assertTrue(t.exited)
async def test_asyncio_wait_for_timeout(self):
t = SlowTask()
try:
await asyncio.wait_for(t.run(), t.TASK_TIMEOUT / 2)
except asyncio.TimeoutError:
pass
self.assertTrue(t.exited)
async def test_wait_for_timeout_less_then_0_or_0_future_done(self):
loop = asyncio.get_running_loop()
fut = loop.create_future()
fut.set_result('done')
ret = await asyncio.wait_for(fut, 0)
self.assertEqual(ret, 'done')
self.assertTrue(fut.done())
async def test_wait_for_timeout_less_then_0_or_0_coroutine_do_not_started(self):
foo_started = False
async def foo():
nonlocal foo_started
foo_started = True
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(foo(), 0)
self.assertEqual(foo_started, False)
async def test_wait_for_timeout_less_then_0_or_0(self):
loop = asyncio.get_running_loop()
for timeout in [0, -1]:
with self.subTest(timeout=timeout):
foo_running = None
started = loop.create_future()
async def foo():
nonlocal foo_running
foo_running = True
started.set_result(None)
try:
await asyncio.sleep(10)
finally:
foo_running = False
return 'done'
fut = asyncio.create_task(foo())
await started
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(fut, timeout)
self.assertTrue(fut.done())
# it should have been cancelled due to the timeout
self.assertTrue(fut.cancelled())
self.assertEqual(foo_running, False)
async def test_wait_for(self):
foo_running = None
async def foo():
nonlocal foo_running
foo_running = True
try:
await asyncio.sleep(support.LONG_TIMEOUT)
finally:
foo_running = False
return 'done'
fut = asyncio.create_task(foo())
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(fut, 0.1)
self.assertTrue(fut.done())
# it should have been cancelled due to the timeout
self.assertTrue(fut.cancelled())
self.assertEqual(foo_running, False)
async def test_wait_for_blocking(self):
async def coro():
return 'done'
res = await asyncio.wait_for(coro(), timeout=None)
self.assertEqual(res, 'done')
async def test_wait_for_race_condition(self):
loop = asyncio.get_running_loop()
fut = loop.create_future()
task = asyncio.wait_for(fut, timeout=0.2)
loop.call_soon(fut.set_result, "ok")
res = await task
self.assertEqual(res, "ok")
async def test_wait_for_cancellation_race_condition(self):
async def inner():
with self.assertRaises(asyncio.CancelledError):
await asyncio.sleep(1)
return 1
result = await asyncio.wait_for(inner(), timeout=.01)
self.assertEqual(result, 1)
async def test_wait_for_waits_for_task_cancellation(self):
task_done = False
async def inner():
nonlocal task_done
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
await asyncio.sleep(_EPSILON)
raise
finally:
task_done = True
inner_task = asyncio.create_task(inner())
with self.assertRaises(asyncio.TimeoutError) as cm:
await asyncio.wait_for(inner_task, timeout=_EPSILON)
self.assertTrue(task_done)
chained = cm.exception.__context__
self.assertEqual(type(chained), asyncio.CancelledError)
async def test_wait_for_waits_for_task_cancellation_w_timeout_0(self):
task_done = False
async def foo():
async def inner():
nonlocal task_done
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
await asyncio.sleep(_EPSILON)
raise
finally:
task_done = True
inner_task = asyncio.create_task(inner())
await asyncio.sleep(_EPSILON)
await asyncio.wait_for(inner_task, timeout=0)
with self.assertRaises(asyncio.TimeoutError) as cm:
await foo()
self.assertTrue(task_done)
chained = cm.exception.__context__
self.assertEqual(type(chained), asyncio.CancelledError)
async def test_wait_for_reraises_exception_during_cancellation(self):
class FooException(Exception):
pass
async def foo():
async def inner():
try:
await asyncio.sleep(0.2)
finally:
raise FooException
inner_task = asyncio.create_task(inner())
await asyncio.wait_for(inner_task, timeout=_EPSILON)
with self.assertRaises(FooException):
await foo()
async def _test_cancel_wait_for(self, timeout):
loop = asyncio.get_running_loop()
async def blocking_coroutine():
fut = loop.create_future()
# Block: fut result is never set
await fut
task = asyncio.create_task(blocking_coroutine())
wait = asyncio.create_task(asyncio.wait_for(task, timeout))
loop.call_soon(wait.cancel)
with self.assertRaises(asyncio.CancelledError):
await wait
# Python issue #23219: cancelling the wait must also cancel the task
self.assertTrue(task.cancelled())
async def test_cancel_blocking_wait_for(self):
await self._test_cancel_wait_for(None)
async def test_cancel_wait_for(self):
await self._test_cancel_wait_for(60.0)
async def test_wait_for_cancel_suppressed(self):
# GH-86296: Suppressing CancelledError is discouraged
# but if a task suppresses CancelledError and returns a value,
# `wait_for` should return the value instead of raising CancelledError.
# This is the same behavior as `asyncio.timeout`.
async def return_42():
try:
await asyncio.sleep(10)
except asyncio.CancelledError:
return 42
res = await asyncio.wait_for(return_42(), timeout=0.1)
self.assertEqual(res, 42)
async def test_wait_for_issue86296(self):
# GH-86296: The task should get cancelled and not run to completion.
# inner completes in one cycle of the event loop so it
# completes before the task is cancelled.
async def inner():
return 'done'
inner_task = asyncio.create_task(inner())
reached_end = False
async def wait_for_coro():
await asyncio.wait_for(inner_task, timeout=100)
await asyncio.sleep(1)
nonlocal reached_end
reached_end = True
task = asyncio.create_task(wait_for_coro())
self.assertFalse(task.done())
# Run the task
await asyncio.sleep(0)
task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task
self.assertTrue(inner_task.done())
self.assertEqual(await inner_task, 'done')
self.assertFalse(reached_end)
class WaitForShieldTests(unittest.IsolatedAsyncioTestCase):
async def test_zero_timeout(self):
# `asyncio.shield` creates a new task which wraps the passed in
# awaitable and shields it from cancellation so with timeout=0
# the task returned by `asyncio.shield` aka shielded_task gets
# cancelled immediately and the task wrapped by it is scheduled
# to run.
async def coro():
await asyncio.sleep(0.01)
return 'done'
task = asyncio.create_task(coro())
with self.assertRaises(asyncio.TimeoutError):
shielded_task = asyncio.shield(task)
await asyncio.wait_for(shielded_task, timeout=0)
# Task is running in background
self.assertFalse(task.done())
self.assertFalse(task.cancelled())
self.assertTrue(shielded_task.cancelled())
# Wait for the task to complete
await asyncio.sleep(0.1)
self.assertTrue(task.done())
async def test_none_timeout(self):
# With timeout=None the timeout is disabled so it
# runs till completion.
async def coro():
await asyncio.sleep(0.1)
return 'done'
task = asyncio.create_task(coro())
await asyncio.wait_for(asyncio.shield(task), timeout=None)
self.assertTrue(task.done())
self.assertEqual(await task, "done")
async def test_shielded_timeout(self):
# shield prevents the task from being cancelled.
async def coro():
await asyncio.sleep(0.1)
return 'done'
task = asyncio.create_task(coro())
with self.assertRaises(asyncio.TimeoutError):
await asyncio.wait_for(asyncio.shield(task), timeout=0.01)
self.assertFalse(task.done())
self.assertFalse(task.cancelled())
self.assertEqual(await task, "done")
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,359 @@
import os
import signal
import socket
import sys
import time
import threading
import unittest
from unittest import mock
if sys.platform != 'win32':
raise unittest.SkipTest('Windows only')
import _overlapped
import _winapi
import asyncio
from asyncio import windows_events
from test.test_asyncio import utils as test_utils
def tearDownModule():
asyncio.set_event_loop_policy(None)
class UpperProto(asyncio.Protocol):
def __init__(self):
self.buf = []
def connection_made(self, trans):
self.trans = trans
def data_received(self, data):
self.buf.append(data)
if b'\n' in data:
self.trans.write(b''.join(self.buf).upper())
self.trans.close()
class WindowsEventsTestCase(test_utils.TestCase):
def _unraisablehook(self, unraisable):
# Storing unraisable.object can resurrect an object which is being
# finalized. Storing unraisable.exc_value creates a reference cycle.
self._unraisable = unraisable
print(unraisable)
def setUp(self):
self._prev_unraisablehook = sys.unraisablehook
self._unraisable = None
sys.unraisablehook = self._unraisablehook
def tearDown(self):
sys.unraisablehook = self._prev_unraisablehook
self.assertIsNone(self._unraisable)
class ProactorLoopCtrlC(WindowsEventsTestCase):
def test_ctrl_c(self):
def SIGINT_after_delay():
time.sleep(0.1)
signal.raise_signal(signal.SIGINT)
thread = threading.Thread(target=SIGINT_after_delay)
loop = asyncio.new_event_loop()
try:
# only start the loop once the event loop is running
loop.call_soon(thread.start)
loop.run_forever()
self.fail("should not fall through 'run_forever'")
except KeyboardInterrupt:
pass
finally:
self.close_loop(loop)
thread.join()
class ProactorMultithreading(WindowsEventsTestCase):
def test_run_from_nonmain_thread(self):
finished = False
async def coro():
await asyncio.sleep(0)
def func():
nonlocal finished
loop = asyncio.new_event_loop()
loop.run_until_complete(coro())
# close() must not call signal.set_wakeup_fd()
loop.close()
finished = True
thread = threading.Thread(target=func)
thread.start()
thread.join()
self.assertTrue(finished)
class ProactorTests(WindowsEventsTestCase):
def setUp(self):
super().setUp()
self.loop = asyncio.ProactorEventLoop()
self.set_event_loop(self.loop)
def test_close(self):
a, b = socket.socketpair()
trans = self.loop._make_socket_transport(a, asyncio.Protocol())
f = asyncio.ensure_future(self.loop.sock_recv(b, 100), loop=self.loop)
trans.close()
self.loop.run_until_complete(f)
self.assertEqual(f.result(), b'')
b.close()
def test_double_bind(self):
ADDRESS = r'\\.\pipe\test_double_bind-%s' % os.getpid()
server1 = windows_events.PipeServer(ADDRESS)
with self.assertRaises(PermissionError):
windows_events.PipeServer(ADDRESS)
server1.close()
def test_pipe(self):
res = self.loop.run_until_complete(self._test_pipe())
self.assertEqual(res, 'done')
async def _test_pipe(self):
ADDRESS = r'\\.\pipe\_test_pipe-%s' % os.getpid()
with self.assertRaises(FileNotFoundError):
await self.loop.create_pipe_connection(
asyncio.Protocol, ADDRESS)
[server] = await self.loop.start_serving_pipe(
UpperProto, ADDRESS)
self.assertIsInstance(server, windows_events.PipeServer)
clients = []
for i in range(5):
stream_reader = asyncio.StreamReader(loop=self.loop)
protocol = asyncio.StreamReaderProtocol(stream_reader,
loop=self.loop)
trans, proto = await self.loop.create_pipe_connection(
lambda: protocol, ADDRESS)
self.assertIsInstance(trans, asyncio.Transport)
self.assertEqual(protocol, proto)
clients.append((stream_reader, trans))
for i, (r, w) in enumerate(clients):
w.write('lower-{}\n'.format(i).encode())
for i, (r, w) in enumerate(clients):
response = await r.readline()
self.assertEqual(response, 'LOWER-{}\n'.format(i).encode())
w.close()
server.close()
with self.assertRaises(FileNotFoundError):
await self.loop.create_pipe_connection(
asyncio.Protocol, ADDRESS)
return 'done'
def test_connect_pipe_cancel(self):
exc = OSError()
exc.winerror = _overlapped.ERROR_PIPE_BUSY
with mock.patch.object(_overlapped, 'ConnectPipe',
side_effect=exc) as connect:
coro = self.loop._proactor.connect_pipe('pipe_address')
task = self.loop.create_task(coro)
# check that it's possible to cancel connect_pipe()
task.cancel()
with self.assertRaises(asyncio.CancelledError):
self.loop.run_until_complete(task)
def test_wait_for_handle(self):
event = _overlapped.CreateEvent(None, True, False, None)
self.addCleanup(_winapi.CloseHandle, event)
# Wait for unset event with 0.5s timeout;
# result should be False at timeout
timeout = 0.5
fut = self.loop._proactor.wait_for_handle(event, timeout)
start = self.loop.time()
done = self.loop.run_until_complete(fut)
elapsed = self.loop.time() - start
self.assertEqual(done, False)
self.assertFalse(fut.result())
self.assertGreaterEqual(elapsed, timeout - test_utils.CLOCK_RES)
_overlapped.SetEvent(event)
# Wait for set event;
# result should be True immediately
fut = self.loop._proactor.wait_for_handle(event, 10)
done = self.loop.run_until_complete(fut)
self.assertEqual(done, True)
self.assertTrue(fut.result())
# asyncio issue #195: cancelling a done _WaitHandleFuture
# must not crash
fut.cancel()
def test_wait_for_handle_cancel(self):
event = _overlapped.CreateEvent(None, True, False, None)
self.addCleanup(_winapi.CloseHandle, event)
# Wait for unset event with a cancelled future;
# CancelledError should be raised immediately
fut = self.loop._proactor.wait_for_handle(event, 10)
fut.cancel()
with self.assertRaises(asyncio.CancelledError):
self.loop.run_until_complete(fut)
# asyncio issue #195: cancelling a _WaitHandleFuture twice
# must not crash
fut = self.loop._proactor.wait_for_handle(event)
fut.cancel()
fut.cancel()
def test_read_self_pipe_restart(self):
# Regression test for https://bugs.python.org/issue39010
# Previously, restarting a proactor event loop in certain states
# would lead to spurious ConnectionResetErrors being logged.
self.loop.call_exception_handler = mock.Mock()
# Start an operation in another thread so that the self-pipe is used.
# This is theoretically timing-dependent (the task in the executor
# must complete before our start/stop cycles), but in practice it
# seems to work every time.
f = self.loop.run_in_executor(None, lambda: None)
self.loop.stop()
self.loop.run_forever()
self.loop.stop()
self.loop.run_forever()
# Shut everything down cleanly. This is an important part of the
# test - in issue 39010, the error occurred during loop.close(),
# so we want to close the loop during the test instead of leaving
# it for tearDown.
#
# First wait for f to complete to avoid a "future's result was never
# retrieved" error.
self.loop.run_until_complete(f)
# Now shut down the loop itself (self.close_loop also shuts down the
# loop's default executor).
self.close_loop(self.loop)
self.assertFalse(self.loop.call_exception_handler.called)
def test_address_argument_type_error(self):
# Regression test for https://github.com/python/cpython/issues/98793
proactor = self.loop._proactor
sock = socket.socket(type=socket.SOCK_DGRAM)
bad_address = None
with self.assertRaises(TypeError):
proactor.connect(sock, bad_address)
with self.assertRaises(TypeError):
proactor.sendto(sock, b'abc', addr=bad_address)
sock.close()
def test_client_pipe_stat(self):
res = self.loop.run_until_complete(self._test_client_pipe_stat())
self.assertEqual(res, 'done')
async def _test_client_pipe_stat(self):
# Regression test for https://github.com/python/cpython/issues/100573
ADDRESS = r'\\.\pipe\test_client_pipe_stat-%s' % os.getpid()
async def probe():
# See https://github.com/python/cpython/pull/100959#discussion_r1068533658
h = _overlapped.ConnectPipe(ADDRESS)
try:
_winapi.CloseHandle(_overlapped.ConnectPipe(ADDRESS))
except OSError as e:
if e.winerror != _overlapped.ERROR_PIPE_BUSY:
raise
finally:
_winapi.CloseHandle(h)
with self.assertRaises(FileNotFoundError):
await probe()
[server] = await self.loop.start_serving_pipe(asyncio.Protocol, ADDRESS)
self.assertIsInstance(server, windows_events.PipeServer)
errors = []
self.loop.set_exception_handler(lambda _, data: errors.append(data))
for i in range(5):
await self.loop.create_task(probe())
self.assertEqual(len(errors), 0, errors)
server.close()
with self.assertRaises(FileNotFoundError):
await probe()
return "done"
def test_loop_restart(self):
# We're fishing for the "RuntimeError: <_overlapped.Overlapped object at XXX>
# still has pending operation at deallocation, the process may crash" error
stop = threading.Event()
def threadMain():
while not stop.is_set():
self.loop.call_soon_threadsafe(lambda: None)
time.sleep(0.01)
thr = threading.Thread(target=threadMain)
# In 10 60-second runs of this test prior to the fix:
# time in seconds until failure: (none), 15.0, 6.4, (none), 7.6, 8.3, 1.7, 22.2, 23.5, 8.3
# 10 seconds had a 50% failure rate but longer would be more costly
end_time = time.time() + 10 # Run for 10 seconds
self.loop.call_soon(thr.start)
while not self._unraisable: # Stop if we got an unraisable exc
self.loop.stop()
self.loop.run_forever()
if time.time() >= end_time:
break
stop.set()
thr.join()
class WinPolicyTests(WindowsEventsTestCase):
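    # The event loop policy determines which loop asyncio.run() creates:
    # WindowsSelectorEventLoopPolicy -> SelectorEventLoop,
    # WindowsProactorEventLoopPolicy -> ProactorEventLoop.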
def test_selector_win_policy(self):
async def main():
self.assertIsInstance(
asyncio.get_running_loop(),
asyncio.SelectorEventLoop)
old_policy = asyncio.get_event_loop_policy()
try:
asyncio.set_event_loop_policy(
asyncio.WindowsSelectorEventLoopPolicy())
asyncio.run(main())
finally:
asyncio.set_event_loop_policy(old_policy)
def test_proactor_win_policy(self):
async def main():
self.assertIsInstance(
asyncio.get_running_loop(),
asyncio.ProactorEventLoop)
old_policy = asyncio.get_event_loop_policy()
try:
asyncio.set_event_loop_policy(
asyncio.WindowsProactorEventLoopPolicy())
asyncio.run(main())
finally:
asyncio.set_event_loop_policy(old_policy)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,133 @@
"""Tests for window_utils"""
import sys
import unittest
import warnings
if sys.platform != 'win32':
raise unittest.SkipTest('Windows only')
import _overlapped
import _winapi
import asyncio
from asyncio import windows_utils
from test import support
def tearDownModule():
asyncio.set_event_loop_policy(None)
class PipeTests(unittest.TestCase):
def test_pipe_overlapped(self):
h1, h2 = windows_utils.pipe(overlapped=(True, True))
try:
ov1 = _overlapped.Overlapped()
self.assertFalse(ov1.pending)
self.assertEqual(ov1.error, 0)
ov1.ReadFile(h1, 100)
self.assertTrue(ov1.pending)
self.assertEqual(ov1.error, _winapi.ERROR_IO_PENDING)
ERROR_IO_INCOMPLETE = 996
try:
ov1.getresult()
except OSError as e:
self.assertEqual(e.winerror, ERROR_IO_INCOMPLETE)
else:
raise RuntimeError('expected ERROR_IO_INCOMPLETE')
ov2 = _overlapped.Overlapped()
self.assertFalse(ov2.pending)
self.assertEqual(ov2.error, 0)
ov2.WriteFile(h2, b"hello")
self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING})
res = _winapi.WaitForMultipleObjects([ov2.event], False, 100)
self.assertEqual(res, _winapi.WAIT_OBJECT_0)
self.assertFalse(ov1.pending)
self.assertEqual(ov1.error, ERROR_IO_INCOMPLETE)
self.assertFalse(ov2.pending)
self.assertIn(ov2.error, {0, _winapi.ERROR_IO_PENDING})
self.assertEqual(ov1.getresult(), b"hello")
finally:
_winapi.CloseHandle(h1)
_winapi.CloseHandle(h2)
def test_pipe_handle(self):
h, _ = windows_utils.pipe(overlapped=(True, True))
_winapi.CloseHandle(_)
p = windows_utils.PipeHandle(h)
self.assertEqual(p.fileno(), h)
self.assertEqual(p.handle, h)
# check garbage collection of p closes handle
with warnings.catch_warnings():
warnings.filterwarnings("ignore", "", ResourceWarning)
del p
support.gc_collect()
try:
_winapi.CloseHandle(h)
except OSError as e:
self.assertEqual(e.winerror, 6) # ERROR_INVALID_HANDLE
else:
raise RuntimeError('expected ERROR_INVALID_HANDLE')
class PopenTests(unittest.TestCase):
def test_popen(self):
command = r"""if 1:
import sys
s = sys.stdin.readline()
sys.stdout.write(s.upper())
sys.stderr.write('stderr')
"""
msg = b"blah\n"
p = windows_utils.Popen([sys.executable, '-c', command],
stdin=windows_utils.PIPE,
stdout=windows_utils.PIPE,
stderr=windows_utils.PIPE)
for f in [p.stdin, p.stdout, p.stderr]:
self.assertIsInstance(f, windows_utils.PipeHandle)
ovin = _overlapped.Overlapped()
ovout = _overlapped.Overlapped()
overr = _overlapped.Overlapped()
ovin.WriteFile(p.stdin.handle, msg)
ovout.ReadFile(p.stdout.handle, 100)
overr.ReadFile(p.stderr.handle, 100)
events = [ovin.event, ovout.event, overr.event]
# Super-long timeout for slow buildbots.
res = _winapi.WaitForMultipleObjects(events, True,
int(support.SHORT_TIMEOUT * 1000))
self.assertEqual(res, _winapi.WAIT_OBJECT_0)
self.assertFalse(ovout.pending)
self.assertFalse(overr.pending)
self.assertFalse(ovin.pending)
self.assertEqual(ovin.getresult(), len(msg))
out = ovout.getresult().rstrip()
err = overr.getresult().rstrip()
self.assertGreater(len(out), 0)
self.assertGreater(len(err), 0)
# allow for partial reads...
self.assertTrue(msg.upper().rstrip().startswith(out))
self.assertTrue(b"stderr".startswith(err))
# The context manager calls wait() and closes resources
with p:
pass
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,638 @@
"""Utilities shared by tests."""
import asyncio
import collections
import contextlib
import io
import logging
import os
import re
import selectors
import socket
import socketserver
import sys
import threading
import unittest
import weakref
import warnings
from unittest import mock
from http.server import HTTPServer
from wsgiref.simple_server import WSGIRequestHandler, WSGIServer
try:
import ssl
except ImportError: # pragma: no cover
ssl = None
from asyncio import base_events
from asyncio import events
from asyncio import format_helpers
from asyncio import tasks
from asyncio.log import logger
from test import support
from test.support import socket_helper
from test.support import threading_helper
# Use the maximum known clock resolution (gh-75191, gh-110088): Windows
# GetTickCount64() has a resolution of 15.6 ms. Use 50 ms to tolerate rounding
# issues.
CLOCK_RES = 0.050
def data_file(*filename):
fullname = os.path.join(support.TEST_HOME_DIR, *filename)
if os.path.isfile(fullname):
return fullname
fullname = os.path.join(os.path.dirname(__file__), '..', *filename)
if os.path.isfile(fullname):
return fullname
    raise FileNotFoundError(os.path.join(*filename))
ONLYCERT = data_file('certdata', 'ssl_cert.pem')
ONLYKEY = data_file('certdata', 'ssl_key.pem')
SIGNED_CERTFILE = data_file('certdata', 'keycert3.pem')
SIGNING_CA = data_file('certdata', 'pycacert.pem')
PEERCERT = {
'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
'issuer': ((('countryName', 'XY'),),
(('organizationName', 'Python Software Foundation CA'),),
(('commonName', 'our-ca-server'),)),
'notAfter': 'Oct 28 14:23:16 2037 GMT',
'notBefore': 'Aug 29 14:23:16 2018 GMT',
'serialNumber': 'CB2D80995A69525C',
'subject': ((('countryName', 'XY'),),
(('localityName', 'Castle Anthrax'),),
(('organizationName', 'Python Software Foundation'),),
(('commonName', 'localhost'),)),
'subjectAltName': (('DNS', 'localhost'),),
'version': 3
}
def simple_server_sslcontext():
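    # Loopback test context using the bundled self-signed certificate; hostname
    # checking and peer verification are disabled so tests do not depend on a
    # real certificate chain.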
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server_context.load_cert_chain(ONLYCERT, ONLYKEY)
server_context.check_hostname = False
server_context.verify_mode = ssl.CERT_NONE
return server_context
def simple_client_sslcontext(*, disable_verify=True):
client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
client_context.check_hostname = False
if disable_verify:
client_context.verify_mode = ssl.CERT_NONE
return client_context
def dummy_ssl_context():
if ssl is None:
return None
else:
return simple_client_sslcontext(disable_verify=True)
def run_briefly(loop):
async def once():
pass
gen = once()
t = loop.create_task(gen)
# Don't log a warning if the task is not done after run_until_complete().
# It occurs if the loop is stopped or if a task raises a BaseException.
t._log_destroy_pending = False
try:
loop.run_until_complete(t)
finally:
gen.close()
def run_until(loop, pred, timeout=support.SHORT_TIMEOUT):
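    # Repeatedly sleep on the loop and re-check pred() until it returns true;
    # TimeoutError is raised if the overall timeout elapses first.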
delay = 0.001
for _ in support.busy_retry(timeout, error=False):
if pred():
break
loop.run_until_complete(tasks.sleep(delay))
delay = max(delay * 2, 1.0)
else:
raise TimeoutError()
def run_once(loop):
"""Legacy API to run once through the event loop.
This is the recommended pattern for test code. It will poll the
selector once and run all callbacks scheduled in response to I/O
events.
"""
loop.call_soon(loop.stop)
loop.run_forever()
class SilentWSGIRequestHandler(WSGIRequestHandler):
def get_stderr(self):
return io.StringIO()
def log_message(self, format, *args):
pass
class SilentWSGIServer(WSGIServer):
request_timeout = support.LOOPBACK_TIMEOUT
def get_request(self):
request, client_addr = super().get_request()
request.settimeout(self.request_timeout)
return request, client_addr
def handle_error(self, request, client_address):
pass
class SSLWSGIServerMixin:
def finish_request(self, request, client_address):
# The relative location of our test directory (which
# contains the ssl key and certificate files) differs
# between the stdlib and stand-alone asyncio.
# Prefer our own if we can find it.
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.load_cert_chain(ONLYCERT, ONLYKEY)
ssock = context.wrap_socket(request, server_side=True)
try:
self.RequestHandlerClass(ssock, client_address, self)
ssock.close()
except OSError:
# maybe socket has been closed by peer
pass
class SSLWSGIServer(SSLWSGIServerMixin, SilentWSGIServer):
pass
def _run_test_server(*, address, use_ssl=False, server_cls, server_ssl_cls):
def loop(environ):
size = int(environ['CONTENT_LENGTH'])
while size:
data = environ['wsgi.input'].read(min(size, 0x10000))
yield data
size -= len(data)
def app(environ, start_response):
status = '200 OK'
headers = [('Content-type', 'text/plain')]
start_response(status, headers)
if environ['PATH_INFO'] == '/loop':
return loop(environ)
else:
return [b'Test message']
# Run the test WSGI server in a separate thread in order not to
# interfere with event handling in the main thread
server_class = server_ssl_cls if use_ssl else server_cls
httpd = server_class(address, SilentWSGIRequestHandler)
httpd.set_app(app)
httpd.address = httpd.server_address
server_thread = threading.Thread(
target=lambda: httpd.serve_forever(poll_interval=0.05))
server_thread.start()
try:
yield httpd
finally:
httpd.shutdown()
httpd.server_close()
server_thread.join()
if hasattr(socket, 'AF_UNIX'):
class UnixHTTPServer(socketserver.UnixStreamServer, HTTPServer):
def server_bind(self):
socketserver.UnixStreamServer.server_bind(self)
self.server_name = '127.0.0.1'
self.server_port = 80
class UnixWSGIServer(UnixHTTPServer, WSGIServer):
request_timeout = support.LOOPBACK_TIMEOUT
def server_bind(self):
UnixHTTPServer.server_bind(self)
self.setup_environ()
def get_request(self):
request, client_addr = super().get_request()
request.settimeout(self.request_timeout)
# Code in the stdlib expects that get_request
# will return a socket and a tuple (host, port).
# However, this isn't true for UNIX sockets,
# as the second return value will be a path;
# hence we return some fake data sufficient
# to get the tests going
return request, ('127.0.0.1', '')
class SilentUnixWSGIServer(UnixWSGIServer):
def handle_error(self, request, client_address):
pass
class UnixSSLWSGIServer(SSLWSGIServerMixin, SilentUnixWSGIServer):
pass
def gen_unix_socket_path():
return socket_helper.create_unix_domain_name()
@contextlib.contextmanager
def unix_socket_path():
path = gen_unix_socket_path()
try:
yield path
finally:
try:
os.unlink(path)
except OSError:
pass
@contextlib.contextmanager
def run_test_unix_server(*, use_ssl=False):
with unix_socket_path() as path:
yield from _run_test_server(address=path, use_ssl=use_ssl,
server_cls=SilentUnixWSGIServer,
server_ssl_cls=UnixSSLWSGIServer)
@contextlib.contextmanager
def run_test_server(*, host='127.0.0.1', port=0, use_ssl=False):
yield from _run_test_server(address=(host, port), use_ssl=use_ssl,
server_cls=SilentWSGIServer,
server_ssl_cls=SSLWSGIServer)
def echo_datagrams(sock):
while True:
data, addr = sock.recvfrom(4096)
if data == b'STOP':
sock.close()
break
else:
sock.sendto(data, addr)
@contextlib.contextmanager
def run_udp_echo_server(*, host='127.0.0.1', port=0):
addr_info = socket.getaddrinfo(host, port, type=socket.SOCK_DGRAM)
family, type, proto, _, sockaddr = addr_info[0]
sock = socket.socket(family, type, proto)
sock.bind((host, port))
sockname = sock.getsockname()
thread = threading.Thread(target=lambda: echo_datagrams(sock))
thread.start()
try:
yield sockname
finally:
        # gh-122187: use a separate socket to send the stop message, to avoid
        # a TSan-reported race on the same socket.
sock2 = socket.socket(family, type, proto)
sock2.sendto(b'STOP', sockname)
sock2.close()
thread.join()
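# Illustrative sketch, not part of the upstream helpers: a blocking round trip
# through run_udp_echo_server().  The helper name and the b'ping' payload are
# assumptions made for this example.
def _example_udp_echo_roundtrip():
    with run_udp_echo_server() as addr:
        client = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            client.settimeout(support.LOOPBACK_TIMEOUT)
            client.sendto(b'ping', addr)
            data, _ = client.recvfrom(4096)
            assert data == b'ping'
        finally:
            client.close()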
def make_test_protocol(base):
dct = {}
for name in dir(base):
if name.startswith('__') and name.endswith('__'):
# skip magic names
continue
dct[name] = MockCallback(return_value=None)
return type('TestProtocol', (base,) + base.__bases__, dct)()
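# Illustrative sketch, not part of the upstream helpers: every non-dunder
# attribute of the generated protocol is a Mock, so calls can be asserted on
# directly.  Using asyncio.Protocol as the base class is an assumption made
# for this example.
def _example_make_test_protocol():
    proto = make_test_protocol(asyncio.Protocol)
    proto.data_received(b'abc')
    proto.data_received.assert_called_once_with(b'abc')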
class TestSelector(selectors.BaseSelector):
def __init__(self):
self.keys = {}
def register(self, fileobj, events, data=None):
key = selectors.SelectorKey(fileobj, 0, events, data)
self.keys[fileobj] = key
return key
def unregister(self, fileobj):
return self.keys.pop(fileobj)
def select(self, timeout):
return []
def get_map(self):
return self.keys
class TestLoop(base_events.BaseEventLoop):
"""Loop for unittests.
    It manages its own clock directly: whenever something is scheduled to
    run later, the generator passed to __init__ is resumed on the next loop
    iteration, after all ready handlers have been executed.
    The generator should look like this:
        def gen():
            ...
            when = yield ...
            ... = yield time_advance
    Inside the generator, the value a yield expression evaluates to is the
    absolute time of the next scheduled handler; the value the generator
    yields back is the time advance used to move the loop's clock forward.
"""
def __init__(self, gen=None):
super().__init__()
if gen is None:
def gen():
yield
self._check_on_close = False
else:
self._check_on_close = True
self._gen = gen()
next(self._gen)
self._time = 0
self._clock_resolution = 1e-9
self._timers = []
self._selector = TestSelector()
self.readers = {}
self.writers = {}
self.reset_counters()
self._transports = weakref.WeakValueDictionary()
def time(self):
return self._time
def advance_time(self, advance):
"""Move test time forward."""
if advance:
self._time += advance
def close(self):
super().close()
if self._check_on_close:
try:
self._gen.send(0)
except StopIteration:
pass
else: # pragma: no cover
raise AssertionError("Time generator is not finished")
def _add_reader(self, fd, callback, *args):
self.readers[fd] = events.Handle(callback, args, self, None)
def _remove_reader(self, fd):
self.remove_reader_count[fd] += 1
if fd in self.readers:
del self.readers[fd]
return True
else:
return False
def assert_reader(self, fd, callback, *args):
if fd not in self.readers:
raise AssertionError(f'fd {fd} is not registered')
handle = self.readers[fd]
if handle._callback != callback:
raise AssertionError(
f'unexpected callback: {handle._callback} != {callback}')
if handle._args != args:
raise AssertionError(
f'unexpected callback args: {handle._args} != {args}')
def assert_no_reader(self, fd):
if fd in self.readers:
raise AssertionError(f'fd {fd} is registered')
def _add_writer(self, fd, callback, *args):
self.writers[fd] = events.Handle(callback, args, self, None)
def _remove_writer(self, fd):
self.remove_writer_count[fd] += 1
if fd in self.writers:
del self.writers[fd]
return True
else:
return False
def assert_writer(self, fd, callback, *args):
if fd not in self.writers:
raise AssertionError(f'fd {fd} is not registered')
handle = self.writers[fd]
if handle._callback != callback:
raise AssertionError(f'{handle._callback!r} != {callback!r}')
if handle._args != args:
raise AssertionError(f'{handle._args!r} != {args!r}')
def _ensure_fd_no_transport(self, fd):
if not isinstance(fd, int):
try:
fd = int(fd.fileno())
except (AttributeError, TypeError, ValueError):
                # This code matches the selectors._fileobj_to_fd() function.
raise ValueError("Invalid file object: "
"{!r}".format(fd)) from None
try:
transport = self._transports[fd]
except KeyError:
pass
else:
raise RuntimeError(
'File descriptor {!r} is used by transport {!r}'.format(
fd, transport))
def add_reader(self, fd, callback, *args):
"""Add a reader callback."""
self._ensure_fd_no_transport(fd)
return self._add_reader(fd, callback, *args)
def remove_reader(self, fd):
"""Remove a reader callback."""
self._ensure_fd_no_transport(fd)
return self._remove_reader(fd)
def add_writer(self, fd, callback, *args):
"""Add a writer callback.."""
self._ensure_fd_no_transport(fd)
return self._add_writer(fd, callback, *args)
def remove_writer(self, fd):
"""Remove a writer callback."""
self._ensure_fd_no_transport(fd)
return self._remove_writer(fd)
def reset_counters(self):
self.remove_reader_count = collections.defaultdict(int)
self.remove_writer_count = collections.defaultdict(int)
def _run_once(self):
super()._run_once()
for when in self._timers:
advance = self._gen.send(when)
self.advance_time(advance)
self._timers = []
def call_at(self, when, callback, *args, context=None):
self._timers.append(when)
return super().call_at(when, callback, *args, context=context)
def _process_events(self, event_list):
return
def _write_to_self(self):
pass
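# Illustrative sketch, not part of the upstream helpers: driving TestLoop with
# a time generator as described in its docstring.  The 0.1 second delay and
# the assertions are values chosen for this example only.
def _example_test_loop_time_generator():
    def gen():
        when = yield          # absolute time of the first scheduled handler
        assert when == 0.1
        yield 0.1             # advance the fake clock past that handler
    loop = TestLoop(gen)
    try:
        loop.call_later(0.1, lambda: None)
        loop.run_until_complete(asyncio.sleep(0))
        assert loop.time() == 0.1
    finally:
        loop.close()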
def MockCallback(**kwargs):
return mock.Mock(spec=['__call__'], **kwargs)
class MockPattern(str):
"""A regex based str with a fuzzy __eq__.
Use this helper with 'mock.assert_called_with', or anywhere
where a regex comparison between strings is needed.
For instance:
mock_call.assert_called_with(MockPattern('spam.*ham'))
"""
def __eq__(self, other):
return bool(re.search(str(self), other, re.S))
class MockInstanceOf:
def __init__(self, type):
self._type = type
def __eq__(self, other):
return isinstance(other, self._type)
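# Illustrative sketch, not part of the upstream helpers: matching a call with
# MockPattern and MockInstanceOf.  The logged message and the integer argument
# are assumptions made for this example.
def _example_mock_matchers():
    m = mock.Mock()
    m('error: connection lost', 42)
    m.assert_called_with(MockPattern('error: .*lost'), MockInstanceOf(int))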
def get_function_source(func):
source = format_helpers._get_function_source(func)
if source is None:
raise ValueError("unable to get the source of %r" % (func,))
return source
class TestCase(unittest.TestCase):
@staticmethod
def close_loop(loop):
if loop._default_executor is not None:
if not loop.is_closed():
loop.run_until_complete(loop.shutdown_default_executor())
else:
loop._default_executor.shutdown(wait=True)
loop.close()
policy = support.maybe_get_event_loop_policy()
if policy is not None:
try:
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
watcher = policy.get_child_watcher()
except NotImplementedError:
# watcher is not implemented by EventLoopPolicy, e.g. Windows
pass
else:
if isinstance(watcher, asyncio.ThreadedChildWatcher):
# Wait for subprocess to finish, but not forever
for thread in list(watcher._threads.values()):
thread.join(timeout=support.SHORT_TIMEOUT)
if thread.is_alive():
raise RuntimeError(f"thread {thread} still alive: "
"subprocess still running")
def set_event_loop(self, loop, *, cleanup=True):
if loop is None:
raise AssertionError('loop is None')
# ensure that the event loop is passed explicitly in asyncio
events.set_event_loop(None)
if cleanup:
self.addCleanup(self.close_loop, loop)
def new_test_loop(self, gen=None):
loop = TestLoop(gen)
self.set_event_loop(loop)
return loop
def setUp(self):
self._thread_cleanup = threading_helper.threading_setup()
def tearDown(self):
events.set_event_loop(None)
# Detect CPython bug #23353: ensure that yield/yield-from is not used
# in an except block of a generator
self.assertIsNone(sys.exception())
self.doCleanups()
threading_helper.threading_cleanup(*self._thread_cleanup)
support.reap_children()
@contextlib.contextmanager
def disable_logger():
"""Context manager to disable asyncio logger.
For example, it can be used to ignore warnings in debug mode.
"""
old_level = logger.level
try:
logger.setLevel(logging.CRITICAL+1)
yield
finally:
logger.setLevel(old_level)
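# Illustrative sketch, not part of the upstream helpers: silencing the asyncio
# logger around a call that would otherwise emit a debug-mode warning.  The
# message text is an assumption made for this example.
def _example_disable_logger():
    with disable_logger():
        logger.warning('suppressed while the context manager is active')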
def mock_nonblocking_socket(proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM,
family=socket.AF_INET):
"""Create a mock of a non-blocking socket."""
sock = mock.MagicMock(socket.socket)
sock.proto = proto
sock.type = type
sock.family = family
sock.gettimeout.return_value = 0.0
return sock
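# Illustrative sketch, not part of the upstream helpers: the mock behaves like
# a non-blocking socket for code that checks gettimeout(), and its methods can
# be programmed like any MagicMock.  The b'payload' value is an assumption.
def _example_mock_nonblocking_socket():
    sock = mock_nonblocking_socket()
    sock.recv.return_value = b'payload'
    assert sock.gettimeout() == 0.0
    assert sock.recv(1024) == b'payload'
    sock.recv.assert_called_with(1024)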
async def await_without_task(coro):
exc = None
def func():
try:
for _ in coro.__await__():
pass
except BaseException as err:
nonlocal exc
exc = err
asyncio.get_running_loop().call_soon(func)
await asyncio.sleep(0)
if exc is not None:
raise exc
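# Illustrative sketch, not part of the upstream helpers: awaiting a coroutine
# outside of any Task, which exercises the bare __await__ code path.  Using
# asyncio.Event as the awaited primitive is an assumption for this example.
async def _example_await_without_task():
    event = asyncio.Event()
    event.set()
    await await_without_task(event.wait())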