Initial commit: Robot ökoszisztéma v2.0 - Stabilizált jármű és szerviz robotok

This commit is contained in:
Kincses
2026-03-04 02:03:03 +01:00
commit 250f4f4b8f
7942 changed files with 449625 additions and 0 deletions

View File

@@ -0,0 +1,8 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from dotenv.__main__ import cli
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(cli())

View File

@@ -0,0 +1,8 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from httpx import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())

View File

@@ -0,0 +1,164 @@
/* -*- indent-tabs-mode: nil; tab-width: 4; -*- */
/* Greenlet object interface */
#ifndef Py_GREENLETOBJECT_H
#define Py_GREENLETOBJECT_H
#include <Python.h>
#ifdef __cplusplus
extern "C" {
#endif
/* This is deprecated and undocumented. It does not change. */
#define GREENLET_VERSION "1.0.0"
#ifndef GREENLET_MODULE
#define implementation_ptr_t void*
#endif
typedef struct _greenlet {
PyObject_HEAD
PyObject* weakreflist;
PyObject* dict;
implementation_ptr_t pimpl;
} PyGreenlet;
#define PyGreenlet_Check(op) (op && PyObject_TypeCheck(op, &PyGreenlet_Type))
/* C API functions */
/* Total number of symbols that are exported */
#define PyGreenlet_API_pointers 12
#define PyGreenlet_Type_NUM 0
#define PyExc_GreenletError_NUM 1
#define PyExc_GreenletExit_NUM 2
#define PyGreenlet_New_NUM 3
#define PyGreenlet_GetCurrent_NUM 4
#define PyGreenlet_Throw_NUM 5
#define PyGreenlet_Switch_NUM 6
#define PyGreenlet_SetParent_NUM 7
#define PyGreenlet_MAIN_NUM 8
#define PyGreenlet_STARTED_NUM 9
#define PyGreenlet_ACTIVE_NUM 10
#define PyGreenlet_GET_PARENT_NUM 11
#ifndef GREENLET_MODULE
/* This section is used by modules that uses the greenlet C API */
static void** _PyGreenlet_API = NULL;
# define PyGreenlet_Type \
(*(PyTypeObject*)_PyGreenlet_API[PyGreenlet_Type_NUM])
# define PyExc_GreenletError \
((PyObject*)_PyGreenlet_API[PyExc_GreenletError_NUM])
# define PyExc_GreenletExit \
((PyObject*)_PyGreenlet_API[PyExc_GreenletExit_NUM])
/*
* PyGreenlet_New(PyObject *args)
*
* greenlet.greenlet(run, parent=None)
*/
# define PyGreenlet_New \
(*(PyGreenlet * (*)(PyObject * run, PyGreenlet * parent)) \
_PyGreenlet_API[PyGreenlet_New_NUM])
/*
* PyGreenlet_GetCurrent(void)
*
* greenlet.getcurrent()
*/
# define PyGreenlet_GetCurrent \
(*(PyGreenlet * (*)(void)) _PyGreenlet_API[PyGreenlet_GetCurrent_NUM])
/*
* PyGreenlet_Throw(
* PyGreenlet *greenlet,
* PyObject *typ,
* PyObject *val,
* PyObject *tb)
*
* g.throw(...)
*/
# define PyGreenlet_Throw \
(*(PyObject * (*)(PyGreenlet * self, \
PyObject * typ, \
PyObject * val, \
PyObject * tb)) \
_PyGreenlet_API[PyGreenlet_Throw_NUM])
/*
* PyGreenlet_Switch(PyGreenlet *greenlet, PyObject *args)
*
* g.switch(*args, **kwargs)
*/
# define PyGreenlet_Switch \
(*(PyObject * \
(*)(PyGreenlet * greenlet, PyObject * args, PyObject * kwargs)) \
_PyGreenlet_API[PyGreenlet_Switch_NUM])
/*
* PyGreenlet_SetParent(PyObject *greenlet, PyObject *new_parent)
*
* g.parent = new_parent
*/
# define PyGreenlet_SetParent \
(*(int (*)(PyGreenlet * greenlet, PyGreenlet * nparent)) \
_PyGreenlet_API[PyGreenlet_SetParent_NUM])
/*
* PyGreenlet_GetParent(PyObject* greenlet)
*
* return greenlet.parent;
*
* This could return NULL even if there is no exception active.
* If it does not return NULL, you are responsible for decrementing the
* reference count.
*/
# define PyGreenlet_GetParent \
(*(PyGreenlet* (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_GET_PARENT_NUM])
/*
* deprecated, undocumented alias.
*/
# define PyGreenlet_GET_PARENT PyGreenlet_GetParent
# define PyGreenlet_MAIN \
(*(int (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_MAIN_NUM])
# define PyGreenlet_STARTED \
(*(int (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_STARTED_NUM])
# define PyGreenlet_ACTIVE \
(*(int (*)(PyGreenlet*)) \
_PyGreenlet_API[PyGreenlet_ACTIVE_NUM])
/* Macro that imports greenlet and initializes C API */
/* NOTE: This has actually moved to ``greenlet._greenlet._C_API``, but we
keep the older definition to be sure older code that might have a copy of
the header still works. */
# define PyGreenlet_Import() \
{ \
_PyGreenlet_API = (void**)PyCapsule_Import("greenlet._C_API", 0); \
}
#endif /* GREENLET_MODULE */
#ifdef __cplusplus
}
#endif
#endif /* !Py_GREENLETOBJECT_H */

View File

@@ -0,0 +1,96 @@
Metadata-Version: 2.4
Name: anyio
Version: 4.12.1
Summary: High-level concurrency and networking framework on top of asyncio or Trio
Author-email: Alex Grönholm <alex.gronholm@nextday.fi>
License-Expression: MIT
Project-URL: Documentation, https://anyio.readthedocs.io/en/latest/
Project-URL: Changelog, https://anyio.readthedocs.io/en/stable/versionhistory.html
Project-URL: Source code, https://github.com/agronholm/anyio
Project-URL: Issue tracker, https://github.com/agronholm/anyio/issues
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: Framework :: AnyIO
Classifier: Typing :: Typed
Classifier: Programming Language :: Python
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: 3.14
Requires-Python: >=3.9
Description-Content-Type: text/x-rst
License-File: LICENSE
Requires-Dist: exceptiongroup>=1.0.2; python_version < "3.11"
Requires-Dist: idna>=2.8
Requires-Dist: typing_extensions>=4.5; python_version < "3.13"
Provides-Extra: trio
Requires-Dist: trio>=0.32.0; python_version >= "3.10" and extra == "trio"
Requires-Dist: trio>=0.31.0; python_version < "3.10" and extra == "trio"
Dynamic: license-file
.. image:: https://github.com/agronholm/anyio/actions/workflows/test.yml/badge.svg
:target: https://github.com/agronholm/anyio/actions/workflows/test.yml
:alt: Build Status
.. image:: https://coveralls.io/repos/github/agronholm/anyio/badge.svg?branch=master
:target: https://coveralls.io/github/agronholm/anyio?branch=master
:alt: Code Coverage
.. image:: https://readthedocs.org/projects/anyio/badge/?version=latest
:target: https://anyio.readthedocs.io/en/latest/?badge=latest
:alt: Documentation
.. image:: https://badges.gitter.im/gitterHQ/gitter.svg
:target: https://gitter.im/python-trio/AnyIO
:alt: Gitter chat
AnyIO is an asynchronous networking and concurrency library that works on top of either asyncio_ or
Trio_. It implements Trio-like `structured concurrency`_ (SC) on top of asyncio and works in harmony
with the native SC of Trio itself.
Applications and libraries written against AnyIO's API will run unmodified on either asyncio_ or
Trio_. AnyIO can also be adopted into a library or application incrementally bit by bit, no full
refactoring necessary. It will blend in with the native libraries of your chosen backend.
To find out why you might want to use AnyIO's APIs instead of asyncio's, you can read about it
`here <https://anyio.readthedocs.io/en/stable/why.html>`_.
Documentation
-------------
View full documentation at: https://anyio.readthedocs.io/
Features
--------
AnyIO offers the following functionality:
* Task groups (nurseries_ in trio terminology)
* High-level networking (TCP, UDP and UNIX sockets)
* `Happy eyeballs`_ algorithm for TCP connections (more robust than that of asyncio on Python
3.8)
* async/await style UDP sockets (unlike asyncio where you still have to use Transports and
Protocols)
* A versatile API for byte streams and object streams
* Inter-task synchronization and communication (locks, conditions, events, semaphores, object
streams)
* Worker threads
* Subprocesses
* Subinterpreter support for code parallelization (on Python 3.13 and later)
* Asynchronous file I/O (using worker threads)
* Signal handling
* Asynchronous version of the functools_ module
AnyIO also comes with its own pytest_ plugin which also supports asynchronous fixtures.
It even works with the popular Hypothesis_ library.
.. _asyncio: https://docs.python.org/3/library/asyncio.html
.. _Trio: https://github.com/python-trio/trio
.. _structured concurrency: https://en.wikipedia.org/wiki/Structured_concurrency
.. _nurseries: https://trio.readthedocs.io/en/stable/reference-core.html#nurseries-and-spawning
.. _Happy eyeballs: https://en.wikipedia.org/wiki/Happy_Eyeballs
.. _pytest: https://docs.pytest.org/en/latest/
.. _functools: https://docs.python.org/3/library/functools.html
.. _Hypothesis: https://hypothesis.works/

View File

@@ -0,0 +1,92 @@
anyio-4.12.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
anyio-4.12.1.dist-info/METADATA,sha256=DfiDab9Tmmcfy802lOLTMEHJQShkOSbopCwqCYbLuJk,4277
anyio-4.12.1.dist-info/RECORD,,
anyio-4.12.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
anyio-4.12.1.dist-info/entry_points.txt,sha256=_d6Yu6uiaZmNe0CydowirE9Cmg7zUL2g08tQpoS3Qvc,39
anyio-4.12.1.dist-info/licenses/LICENSE,sha256=U2GsncWPLvX9LpsJxoKXwX8ElQkJu8gCO9uC6s8iwrA,1081
anyio-4.12.1.dist-info/top_level.txt,sha256=QglSMiWX8_5dpoVAEIHdEYzvqFMdSYWmCj6tYw2ITkQ,6
anyio/__init__.py,sha256=7iDVqMUprUuKNY91FuoKqayAhR-OY136YDPI6P78HHk,6170
anyio/__pycache__/__init__.cpython-312.pyc,,
anyio/__pycache__/from_thread.cpython-312.pyc,,
anyio/__pycache__/functools.cpython-312.pyc,,
anyio/__pycache__/lowlevel.cpython-312.pyc,,
anyio/__pycache__/pytest_plugin.cpython-312.pyc,,
anyio/__pycache__/to_interpreter.cpython-312.pyc,,
anyio/__pycache__/to_process.cpython-312.pyc,,
anyio/__pycache__/to_thread.cpython-312.pyc,,
anyio/_backends/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
anyio/_backends/__pycache__/__init__.cpython-312.pyc,,
anyio/_backends/__pycache__/_asyncio.cpython-312.pyc,,
anyio/_backends/__pycache__/_trio.cpython-312.pyc,,
anyio/_backends/_asyncio.py,sha256=xG6qv60mgGnL0mK82dxjH2b8hlkMlJ-x2BqIq3qv70Y,98863
anyio/_backends/_trio.py,sha256=30Rctb7lm8g63ZHljVPVnj5aH-uK6oQvphjwUBoAzuI,41456
anyio/_core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
anyio/_core/__pycache__/__init__.cpython-312.pyc,,
anyio/_core/__pycache__/_asyncio_selector_thread.cpython-312.pyc,,
anyio/_core/__pycache__/_contextmanagers.cpython-312.pyc,,
anyio/_core/__pycache__/_eventloop.cpython-312.pyc,,
anyio/_core/__pycache__/_exceptions.cpython-312.pyc,,
anyio/_core/__pycache__/_fileio.cpython-312.pyc,,
anyio/_core/__pycache__/_resources.cpython-312.pyc,,
anyio/_core/__pycache__/_signals.cpython-312.pyc,,
anyio/_core/__pycache__/_sockets.cpython-312.pyc,,
anyio/_core/__pycache__/_streams.cpython-312.pyc,,
anyio/_core/__pycache__/_subprocesses.cpython-312.pyc,,
anyio/_core/__pycache__/_synchronization.cpython-312.pyc,,
anyio/_core/__pycache__/_tasks.cpython-312.pyc,,
anyio/_core/__pycache__/_tempfile.cpython-312.pyc,,
anyio/_core/__pycache__/_testing.cpython-312.pyc,,
anyio/_core/__pycache__/_typedattr.cpython-312.pyc,,
anyio/_core/_asyncio_selector_thread.py,sha256=2PdxFM3cs02Kp6BSppbvmRT7q7asreTW5FgBxEsflBo,5626
anyio/_core/_contextmanagers.py,sha256=YInBCabiEeS-UaP_Jdxa1CaFC71ETPW8HZTHIM8Rsc8,7215
anyio/_core/_eventloop.py,sha256=c2EdcBX-xnKwxPcC4Pjn3_qG9I-x4IWFO2R9RqCGjM4,6448
anyio/_core/_exceptions.py,sha256=Y3aq-Wxd7Q2HqwSg7nZPvRsHEuGazv_qeet6gqEBdPk,4407
anyio/_core/_fileio.py,sha256=uc7t10Vb-If7GbdWM_zFf-ajUe6uek63fSt7IBLlZW0,25731
anyio/_core/_resources.py,sha256=NbmU5O5UX3xEyACnkmYX28Fmwdl-f-ny0tHym26e0w0,435
anyio/_core/_signals.py,sha256=mjTBB2hTKNPRlU0IhnijeQedpWOGERDiMjSlJQsFrug,1016
anyio/_core/_sockets.py,sha256=RBXHcUqZt5gg_-OOfgHVv8uq2FSKk1uVUzTdpjBoI1o,34977
anyio/_core/_streams.py,sha256=FczFwIgDpnkK0bODWJXMpsUJYdvAD04kaUaGzJU8DK0,1806
anyio/_core/_subprocesses.py,sha256=EXm5igL7dj55iYkPlbYVAqtbqxJxjU-6OndSTIx9SRg,8047
anyio/_core/_synchronization.py,sha256=MgVVqFzvt580tHC31LiOcq1G6aryut--xRG4Ff8KwxQ,20869
anyio/_core/_tasks.py,sha256=pVB7K6AAulzUM8YgXAeqNZG44nSyZ1bYJjH8GznC00I,5435
anyio/_core/_tempfile.py,sha256=lHb7CW4FyIlpkf5ADAf4VmLHCKwEHF9nxqNyBCFFUiA,19697
anyio/_core/_testing.py,sha256=u7MPqGXwpTxqI7hclSdNA30z2GH1Nw258uwKvy_RfBg,2340
anyio/_core/_typedattr.py,sha256=P4ozZikn3-DbpoYcvyghS_FOYAgbmUxeoU8-L_07pZM,2508
anyio/abc/__init__.py,sha256=6mWhcl_pGXhrgZVHP_TCfMvIXIOp9mroEFM90fYCU_U,2869
anyio/abc/__pycache__/__init__.cpython-312.pyc,,
anyio/abc/__pycache__/_eventloop.cpython-312.pyc,,
anyio/abc/__pycache__/_resources.cpython-312.pyc,,
anyio/abc/__pycache__/_sockets.cpython-312.pyc,,
anyio/abc/__pycache__/_streams.cpython-312.pyc,,
anyio/abc/__pycache__/_subprocesses.cpython-312.pyc,,
anyio/abc/__pycache__/_tasks.cpython-312.pyc,,
anyio/abc/__pycache__/_testing.cpython-312.pyc,,
anyio/abc/_eventloop.py,sha256=GlzgB3UJGgG6Kr7olpjOZ-o00PghecXuofVDQ_5611Q,10749
anyio/abc/_resources.py,sha256=DrYvkNN1hH6Uvv5_5uKySvDsnknGVDe8FCKfko0VtN8,783
anyio/abc/_sockets.py,sha256=ECTY0jLEF18gryANHR3vFzXzGdZ-xPwELq1QdgOb0Jo,13258
anyio/abc/_streams.py,sha256=005GKSCXGprxnhucILboSqc2JFovECZk9m3p-qqxXVc,7640
anyio/abc/_subprocesses.py,sha256=cumAPJTktOQtw63IqG0lDpyZqu_l1EElvQHMiwJgL08,2067
anyio/abc/_tasks.py,sha256=KC7wrciE48AINOI-AhPutnFhe1ewfP7QnamFlDzqesQ,3721
anyio/abc/_testing.py,sha256=tBJUzkSfOXJw23fe8qSJ03kJlShOYjjaEyFB6k6MYT8,1821
anyio/from_thread.py,sha256=L-0w1HxJ6BSb-KuVi57k5Tkc3yzQrx3QK5tAxMPcY-0,19141
anyio/functools.py,sha256=HWj7GBEmc0Z-mZg3uok7Z7ZJn0rEC_0Pzbt0nYUDaTQ,10973
anyio/lowlevel.py,sha256=AyKLVK3LaWSoK39LkCKxE4_GDMLKZBNqTrLUgk63y80,5158
anyio/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
anyio/pytest_plugin.py,sha256=3jAFQn0jv_pyoWE2GBBlHaj9sqXj4e8vob0_hgrsXE8,10244
anyio/streams/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
anyio/streams/__pycache__/__init__.cpython-312.pyc,,
anyio/streams/__pycache__/buffered.cpython-312.pyc,,
anyio/streams/__pycache__/file.cpython-312.pyc,,
anyio/streams/__pycache__/memory.cpython-312.pyc,,
anyio/streams/__pycache__/stapled.cpython-312.pyc,,
anyio/streams/__pycache__/text.cpython-312.pyc,,
anyio/streams/__pycache__/tls.cpython-312.pyc,,
anyio/streams/buffered.py,sha256=2R3PeJhe4EXrdYqz44Y6-Eg9R6DrmlsYrP36Ir43-po,6263
anyio/streams/file.py,sha256=4WZ7XGz5WNu39FQHvqbe__TQ0HDP9OOhgO1mk9iVpVU,4470
anyio/streams/memory.py,sha256=F0zwzvFJKAhX_LRZGoKzzqDC2oMM-f-yyTBrEYEGOaU,10740
anyio/streams/stapled.py,sha256=T8Xqwf8K6EgURPxbt1N4i7A8BAk-gScv-GRhjLXIf_o,4390
anyio/streams/text.py,sha256=BcVAGJw1VRvtIqnv-o0Rb0pwH7p8vwlvl21xHq522ag,5765
anyio/streams/tls.py,sha256=Jpxy0Mfbcp1BxHCwE-YjSSFaLnIBbnnwur-excYThs4,15368
anyio/to_interpreter.py,sha256=_mLngrMy97TMR6VbW4Y6YzDUk9ZuPcQMPlkuyRh3C9k,7100
anyio/to_process.py,sha256=J7gAA_YOuoHqnpDAf5fm1Qu6kOmTzdFbiDNvnV755vk,9798
anyio/to_thread.py,sha256=menEgXYmUV7Fjg_9WqCV95P9MAtQS8BzPGGcWB_QnfQ,2687

View File

@@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: setuptools (80.9.0)
Root-Is-Purelib: true
Tag: py3-none-any

View File

@@ -0,0 +1,2 @@
[pytest11]
anyio = anyio.pytest_plugin

View File

@@ -0,0 +1,20 @@
The MIT License (MIT)
Copyright (c) 2018 Alex Grönholm
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@@ -0,0 +1,111 @@
from __future__ import annotations
from ._core._contextmanagers import AsyncContextManagerMixin as AsyncContextManagerMixin
from ._core._contextmanagers import ContextManagerMixin as ContextManagerMixin
from ._core._eventloop import current_time as current_time
from ._core._eventloop import get_all_backends as get_all_backends
from ._core._eventloop import get_available_backends as get_available_backends
from ._core._eventloop import get_cancelled_exc_class as get_cancelled_exc_class
from ._core._eventloop import run as run
from ._core._eventloop import sleep as sleep
from ._core._eventloop import sleep_forever as sleep_forever
from ._core._eventloop import sleep_until as sleep_until
from ._core._exceptions import BrokenResourceError as BrokenResourceError
from ._core._exceptions import BrokenWorkerInterpreter as BrokenWorkerInterpreter
from ._core._exceptions import BrokenWorkerProcess as BrokenWorkerProcess
from ._core._exceptions import BusyResourceError as BusyResourceError
from ._core._exceptions import ClosedResourceError as ClosedResourceError
from ._core._exceptions import ConnectionFailed as ConnectionFailed
from ._core._exceptions import DelimiterNotFound as DelimiterNotFound
from ._core._exceptions import EndOfStream as EndOfStream
from ._core._exceptions import IncompleteRead as IncompleteRead
from ._core._exceptions import NoEventLoopError as NoEventLoopError
from ._core._exceptions import RunFinishedError as RunFinishedError
from ._core._exceptions import TypedAttributeLookupError as TypedAttributeLookupError
from ._core._exceptions import WouldBlock as WouldBlock
from ._core._fileio import AsyncFile as AsyncFile
from ._core._fileio import Path as Path
from ._core._fileio import open_file as open_file
from ._core._fileio import wrap_file as wrap_file
from ._core._resources import aclose_forcefully as aclose_forcefully
from ._core._signals import open_signal_receiver as open_signal_receiver
from ._core._sockets import TCPConnectable as TCPConnectable
from ._core._sockets import UNIXConnectable as UNIXConnectable
from ._core._sockets import as_connectable as as_connectable
from ._core._sockets import connect_tcp as connect_tcp
from ._core._sockets import connect_unix as connect_unix
from ._core._sockets import create_connected_udp_socket as create_connected_udp_socket
from ._core._sockets import (
create_connected_unix_datagram_socket as create_connected_unix_datagram_socket,
)
from ._core._sockets import create_tcp_listener as create_tcp_listener
from ._core._sockets import create_udp_socket as create_udp_socket
from ._core._sockets import create_unix_datagram_socket as create_unix_datagram_socket
from ._core._sockets import create_unix_listener as create_unix_listener
from ._core._sockets import getaddrinfo as getaddrinfo
from ._core._sockets import getnameinfo as getnameinfo
from ._core._sockets import notify_closing as notify_closing
from ._core._sockets import wait_readable as wait_readable
from ._core._sockets import wait_socket_readable as wait_socket_readable
from ._core._sockets import wait_socket_writable as wait_socket_writable
from ._core._sockets import wait_writable as wait_writable
from ._core._streams import create_memory_object_stream as create_memory_object_stream
from ._core._subprocesses import open_process as open_process
from ._core._subprocesses import run_process as run_process
from ._core._synchronization import CapacityLimiter as CapacityLimiter
from ._core._synchronization import (
CapacityLimiterStatistics as CapacityLimiterStatistics,
)
from ._core._synchronization import Condition as Condition
from ._core._synchronization import ConditionStatistics as ConditionStatistics
from ._core._synchronization import Event as Event
from ._core._synchronization import EventStatistics as EventStatistics
from ._core._synchronization import Lock as Lock
from ._core._synchronization import LockStatistics as LockStatistics
from ._core._synchronization import ResourceGuard as ResourceGuard
from ._core._synchronization import Semaphore as Semaphore
from ._core._synchronization import SemaphoreStatistics as SemaphoreStatistics
from ._core._tasks import TASK_STATUS_IGNORED as TASK_STATUS_IGNORED
from ._core._tasks import CancelScope as CancelScope
from ._core._tasks import create_task_group as create_task_group
from ._core._tasks import current_effective_deadline as current_effective_deadline
from ._core._tasks import fail_after as fail_after
from ._core._tasks import move_on_after as move_on_after
from ._core._tempfile import NamedTemporaryFile as NamedTemporaryFile
from ._core._tempfile import SpooledTemporaryFile as SpooledTemporaryFile
from ._core._tempfile import TemporaryDirectory as TemporaryDirectory
from ._core._tempfile import TemporaryFile as TemporaryFile
from ._core._tempfile import gettempdir as gettempdir
from ._core._tempfile import gettempdirb as gettempdirb
from ._core._tempfile import mkdtemp as mkdtemp
from ._core._tempfile import mkstemp as mkstemp
from ._core._testing import TaskInfo as TaskInfo
from ._core._testing import get_current_task as get_current_task
from ._core._testing import get_running_tasks as get_running_tasks
from ._core._testing import wait_all_tasks_blocked as wait_all_tasks_blocked
from ._core._typedattr import TypedAttributeProvider as TypedAttributeProvider
from ._core._typedattr import TypedAttributeSet as TypedAttributeSet
from ._core._typedattr import typed_attribute as typed_attribute
# Re-export imports so they look like they live directly in this package
for __value in list(locals().values()):
if getattr(__value, "__module__", "").startswith("anyio."):
__value.__module__ = __name__
del __value
def __getattr__(attr: str) -> type[BrokenWorkerInterpreter]:
"""Support deprecated aliases."""
if attr == "BrokenWorkerIntepreter":
import warnings
warnings.warn(
"The 'BrokenWorkerIntepreter' alias is deprecated, use 'BrokenWorkerInterpreter' instead.",
DeprecationWarning,
stacklevel=2,
)
return BrokenWorkerInterpreter
raise AttributeError(f"module {__name__!r} has no attribute {attr!r}")

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,167 @@
from __future__ import annotations
import asyncio
import socket
import threading
from collections.abc import Callable
from selectors import EVENT_READ, EVENT_WRITE, DefaultSelector
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from _typeshed import FileDescriptorLike
_selector_lock = threading.Lock()
_selector: Selector | None = None
class Selector:
def __init__(self) -> None:
self._thread = threading.Thread(target=self.run, name="AnyIO socket selector")
self._selector = DefaultSelector()
self._send, self._receive = socket.socketpair()
self._send.setblocking(False)
self._receive.setblocking(False)
# This somewhat reduces the amount of memory wasted queueing up data
# for wakeups. With these settings, maximum number of 1-byte sends
# before getting BlockingIOError:
# Linux 4.8: 6
# macOS (darwin 15.5): 1
# Windows 10: 525347
# Windows you're weird. (And on Windows setting SNDBUF to 0 makes send
# blocking, even on non-blocking sockets, so don't do that.)
self._receive.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1)
self._send.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1)
# On Windows this is a TCP socket so this might matter. On other
# platforms this fails b/c AF_UNIX sockets aren't actually TCP.
try:
self._send.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
except OSError:
pass
self._selector.register(self._receive, EVENT_READ)
self._closed = False
def start(self) -> None:
self._thread.start()
threading._register_atexit(self._stop) # type: ignore[attr-defined]
def _stop(self) -> None:
global _selector
self._closed = True
self._notify_self()
self._send.close()
self._thread.join()
self._selector.unregister(self._receive)
self._receive.close()
self._selector.close()
_selector = None
assert not self._selector.get_map(), (
"selector still has registered file descriptors after shutdown"
)
def _notify_self(self) -> None:
try:
self._send.send(b"\x00")
except BlockingIOError:
pass
def add_reader(self, fd: FileDescriptorLike, callback: Callable[[], Any]) -> None:
loop = asyncio.get_running_loop()
try:
key = self._selector.get_key(fd)
except KeyError:
self._selector.register(fd, EVENT_READ, {EVENT_READ: (loop, callback)})
else:
if EVENT_READ in key.data:
raise ValueError(
"this file descriptor is already registered for reading"
)
key.data[EVENT_READ] = loop, callback
self._selector.modify(fd, key.events | EVENT_READ, key.data)
self._notify_self()
def add_writer(self, fd: FileDescriptorLike, callback: Callable[[], Any]) -> None:
loop = asyncio.get_running_loop()
try:
key = self._selector.get_key(fd)
except KeyError:
self._selector.register(fd, EVENT_WRITE, {EVENT_WRITE: (loop, callback)})
else:
if EVENT_WRITE in key.data:
raise ValueError(
"this file descriptor is already registered for writing"
)
key.data[EVENT_WRITE] = loop, callback
self._selector.modify(fd, key.events | EVENT_WRITE, key.data)
self._notify_self()
def remove_reader(self, fd: FileDescriptorLike) -> bool:
try:
key = self._selector.get_key(fd)
except KeyError:
return False
if new_events := key.events ^ EVENT_READ:
del key.data[EVENT_READ]
self._selector.modify(fd, new_events, key.data)
else:
self._selector.unregister(fd)
return True
def remove_writer(self, fd: FileDescriptorLike) -> bool:
try:
key = self._selector.get_key(fd)
except KeyError:
return False
if new_events := key.events ^ EVENT_WRITE:
del key.data[EVENT_WRITE]
self._selector.modify(fd, new_events, key.data)
else:
self._selector.unregister(fd)
return True
def run(self) -> None:
while not self._closed:
for key, events in self._selector.select():
if key.fileobj is self._receive:
try:
while self._receive.recv(4096):
pass
except BlockingIOError:
pass
continue
if events & EVENT_READ:
loop, callback = key.data[EVENT_READ]
self.remove_reader(key.fd)
try:
loop.call_soon_threadsafe(callback)
except RuntimeError:
pass # the loop was already closed
if events & EVENT_WRITE:
loop, callback = key.data[EVENT_WRITE]
self.remove_writer(key.fd)
try:
loop.call_soon_threadsafe(callback)
except RuntimeError:
pass # the loop was already closed
def get_selector() -> Selector:
global _selector
with _selector_lock:
if _selector is None:
_selector = Selector()
_selector.start()
return _selector

View File

@@ -0,0 +1,200 @@
from __future__ import annotations
from abc import abstractmethod
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from inspect import isasyncgen, iscoroutine, isgenerator
from types import TracebackType
from typing import Protocol, TypeVar, cast, final
_T_co = TypeVar("_T_co", covariant=True)
_ExitT_co = TypeVar("_ExitT_co", covariant=True, bound="bool | None")
class _SupportsCtxMgr(Protocol[_T_co, _ExitT_co]):
def __contextmanager__(self) -> AbstractContextManager[_T_co, _ExitT_co]: ...
class _SupportsAsyncCtxMgr(Protocol[_T_co, _ExitT_co]):
def __asynccontextmanager__(
self,
) -> AbstractAsyncContextManager[_T_co, _ExitT_co]: ...
class ContextManagerMixin:
"""
Mixin class providing context manager functionality via a generator-based
implementation.
This class allows you to implement a context manager via :meth:`__contextmanager__`
which should return a generator. The mechanics are meant to mirror those of
:func:`@contextmanager <contextlib.contextmanager>`.
.. note:: Classes using this mix-in are not reentrant as context managers, meaning
that once you enter it, you can't re-enter before first exiting it.
.. seealso:: :doc:`contextmanagers`
"""
__cm: AbstractContextManager[object, bool | None] | None = None
@final
def __enter__(self: _SupportsCtxMgr[_T_co, bool | None]) -> _T_co:
# Needed for mypy to assume self still has the __cm member
assert isinstance(self, ContextManagerMixin)
if self.__cm is not None:
raise RuntimeError(
f"this {self.__class__.__qualname__} has already been entered"
)
cm = self.__contextmanager__()
if not isinstance(cm, AbstractContextManager):
if isgenerator(cm):
raise TypeError(
"__contextmanager__() returned a generator object instead of "
"a context manager. Did you forget to add the @contextmanager "
"decorator?"
)
raise TypeError(
f"__contextmanager__() did not return a context manager object, "
f"but {cm.__class__!r}"
)
if cm is self:
raise TypeError(
f"{self.__class__.__qualname__}.__contextmanager__() returned "
f"self. Did you forget to add the @contextmanager decorator and a "
f"'yield' statement?"
)
value = cm.__enter__()
self.__cm = cm
return value
@final
def __exit__(
self: _SupportsCtxMgr[object, _ExitT_co],
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> _ExitT_co:
# Needed for mypy to assume self still has the __cm member
assert isinstance(self, ContextManagerMixin)
if self.__cm is None:
raise RuntimeError(
f"this {self.__class__.__qualname__} has not been entered yet"
)
# Prevent circular references
cm = self.__cm
del self.__cm
return cast(_ExitT_co, cm.__exit__(exc_type, exc_val, exc_tb))
@abstractmethod
def __contextmanager__(self) -> AbstractContextManager[object, bool | None]:
"""
Implement your context manager logic here.
This method **must** be decorated with
:func:`@contextmanager <contextlib.contextmanager>`.
.. note:: Remember that the ``yield`` will raise any exception raised in the
enclosed context block, so use a ``finally:`` block to clean up resources!
:return: a context manager object
"""
class AsyncContextManagerMixin:
"""
Mixin class providing async context manager functionality via a generator-based
implementation.
This class allows you to implement a context manager via
:meth:`__asynccontextmanager__`. The mechanics are meant to mirror those of
:func:`@asynccontextmanager <contextlib.asynccontextmanager>`.
.. note:: Classes using this mix-in are not reentrant as context managers, meaning
that once you enter it, you can't re-enter before first exiting it.
.. seealso:: :doc:`contextmanagers`
"""
__cm: AbstractAsyncContextManager[object, bool | None] | None = None
@final
async def __aenter__(self: _SupportsAsyncCtxMgr[_T_co, bool | None]) -> _T_co:
# Needed for mypy to assume self still has the __cm member
assert isinstance(self, AsyncContextManagerMixin)
if self.__cm is not None:
raise RuntimeError(
f"this {self.__class__.__qualname__} has already been entered"
)
cm = self.__asynccontextmanager__()
if not isinstance(cm, AbstractAsyncContextManager):
if isasyncgen(cm):
raise TypeError(
"__asynccontextmanager__() returned an async generator instead of "
"an async context manager. Did you forget to add the "
"@asynccontextmanager decorator?"
)
elif iscoroutine(cm):
cm.close()
raise TypeError(
"__asynccontextmanager__() returned a coroutine object instead of "
"an async context manager. Did you forget to add the "
"@asynccontextmanager decorator and a 'yield' statement?"
)
raise TypeError(
f"__asynccontextmanager__() did not return an async context manager, "
f"but {cm.__class__!r}"
)
if cm is self:
raise TypeError(
f"{self.__class__.__qualname__}.__asynccontextmanager__() returned "
f"self. Did you forget to add the @asynccontextmanager decorator and a "
f"'yield' statement?"
)
value = await cm.__aenter__()
self.__cm = cm
return value
@final
async def __aexit__(
self: _SupportsAsyncCtxMgr[object, _ExitT_co],
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> _ExitT_co:
assert isinstance(self, AsyncContextManagerMixin)
if self.__cm is None:
raise RuntimeError(
f"this {self.__class__.__qualname__} has not been entered yet"
)
# Prevent circular references
cm = self.__cm
del self.__cm
return cast(_ExitT_co, await cm.__aexit__(exc_type, exc_val, exc_tb))
@abstractmethod
def __asynccontextmanager__(
self,
) -> AbstractAsyncContextManager[object, bool | None]:
"""
Implement your async context manager logic here.
This method **must** be decorated with
:func:`@asynccontextmanager <contextlib.asynccontextmanager>`.
.. note:: Remember that the ``yield`` will raise any exception raised in the
enclosed context block, so use a ``finally:`` block to clean up resources!
:return: an async context manager object
"""

View File

@@ -0,0 +1,234 @@
from __future__ import annotations
import math
import sys
import threading
from collections.abc import Awaitable, Callable, Generator
from contextlib import contextmanager
from contextvars import Token
from importlib import import_module
from typing import TYPE_CHECKING, Any, TypeVar
from ._exceptions import NoEventLoopError
if sys.version_info >= (3, 11):
from typing import TypeVarTuple, Unpack
else:
from typing_extensions import TypeVarTuple, Unpack
sniffio: Any
try:
import sniffio
except ModuleNotFoundError:
sniffio = None
if TYPE_CHECKING:
from ..abc import AsyncBackend
# This must be updated when new backends are introduced
BACKENDS = "asyncio", "trio"
T_Retval = TypeVar("T_Retval")
PosArgsT = TypeVarTuple("PosArgsT")
threadlocals = threading.local()
loaded_backends: dict[str, type[AsyncBackend]] = {}
def run(
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
*args: Unpack[PosArgsT],
backend: str = "asyncio",
backend_options: dict[str, Any] | None = None,
) -> T_Retval:
"""
Run the given coroutine function in an asynchronous event loop.
The current thread must not be already running an event loop.
:param func: a coroutine function
:param args: positional arguments to ``func``
:param backend: name of the asynchronous event loop implementation currently
either ``asyncio`` or ``trio``
:param backend_options: keyword arguments to call the backend ``run()``
implementation with (documented :ref:`here <backend options>`)
:return: the return value of the coroutine function
:raises RuntimeError: if an asynchronous event loop is already running in this
thread
:raises LookupError: if the named backend is not found
"""
if asynclib_name := current_async_library():
raise RuntimeError(f"Already running {asynclib_name} in this thread")
try:
async_backend = get_async_backend(backend)
except ImportError as exc:
raise LookupError(f"No such backend: {backend}") from exc
token = None
if asynclib_name is None:
# Since we're in control of the event loop, we can cache the name of the async
# library
token = set_current_async_library(backend)
try:
backend_options = backend_options or {}
return async_backend.run(func, args, {}, backend_options)
finally:
reset_current_async_library(token)
async def sleep(delay: float) -> None:
"""
Pause the current task for the specified duration.
:param delay: the duration, in seconds
"""
return await get_async_backend().sleep(delay)
async def sleep_forever() -> None:
"""
Pause the current task until it's cancelled.
This is a shortcut for ``sleep(math.inf)``.
.. versionadded:: 3.1
"""
await sleep(math.inf)
async def sleep_until(deadline: float) -> None:
"""
Pause the current task until the given time.
:param deadline: the absolute time to wake up at (according to the internal
monotonic clock of the event loop)
.. versionadded:: 3.1
"""
now = current_time()
await sleep(max(deadline - now, 0))
def current_time() -> float:
"""
Return the current value of the event loop's internal clock.
:return: the clock value (seconds)
:raises NoEventLoopError: if no supported asynchronous event loop is running in the
current thread
"""
return get_async_backend().current_time()
def get_all_backends() -> tuple[str, ...]:
"""Return a tuple of the names of all built-in backends."""
return BACKENDS
def get_available_backends() -> tuple[str, ...]:
"""
Test for the availability of built-in backends.
:return a tuple of the built-in backend names that were successfully imported
.. versionadded:: 4.12
"""
available_backends: list[str] = []
for backend_name in get_all_backends():
try:
get_async_backend(backend_name)
except ImportError:
continue
available_backends.append(backend_name)
return tuple(available_backends)
def get_cancelled_exc_class() -> type[BaseException]:
"""
Return the current async library's cancellation exception class.
:raises NoEventLoopError: if no supported asynchronous event loop is running in the
current thread
"""
return get_async_backend().cancelled_exception_class()
#
# Private API
#
@contextmanager
def claim_worker_thread(
backend_class: type[AsyncBackend], token: object
) -> Generator[Any, None, None]:
from ..lowlevel import EventLoopToken
threadlocals.current_token = EventLoopToken(backend_class, token)
try:
yield
finally:
del threadlocals.current_token
def get_async_backend(asynclib_name: str | None = None) -> type[AsyncBackend]:
if asynclib_name is None:
asynclib_name = current_async_library()
if not asynclib_name:
raise NoEventLoopError(
f"Not currently running on any asynchronous event loop. "
f"Available async backends: {', '.join(get_all_backends())}"
)
# We use our own dict instead of sys.modules to get the already imported back-end
# class because the appropriate modules in sys.modules could potentially be only
# partially initialized
try:
return loaded_backends[asynclib_name]
except KeyError:
module = import_module(f"anyio._backends._{asynclib_name}")
loaded_backends[asynclib_name] = module.backend_class
return module.backend_class
def current_async_library() -> str | None:
if sniffio is None:
# If sniffio is not installed, we assume we're either running asyncio or nothing
import asyncio
try:
asyncio.get_running_loop()
return "asyncio"
except RuntimeError:
pass
else:
try:
return sniffio.current_async_library()
except sniffio.AsyncLibraryNotFoundError:
pass
return None
def set_current_async_library(asynclib_name: str | None) -> Token | None:
# no-op if sniffio is not installed
if sniffio is None:
return None
return sniffio.current_async_library_cvar.set(asynclib_name)
def reset_current_async_library(token: Token | None) -> None:
if token is not None:
sniffio.current_async_library_cvar.reset(token)

View File

@@ -0,0 +1,156 @@
from __future__ import annotations
import sys
from collections.abc import Generator
from textwrap import dedent
from typing import Any
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
class BrokenResourceError(Exception):
"""
Raised when trying to use a resource that has been rendered unusable due to external
causes (e.g. a send stream whose peer has disconnected).
"""
class BrokenWorkerProcess(Exception):
"""
Raised by :meth:`~anyio.to_process.run_sync` if the worker process terminates abruptly or
otherwise misbehaves.
"""
class BrokenWorkerInterpreter(Exception):
"""
Raised by :meth:`~anyio.to_interpreter.run_sync` if an unexpected exception is
raised in the subinterpreter.
"""
def __init__(self, excinfo: Any):
# This was adapted from concurrent.futures.interpreter.ExecutionFailed
msg = excinfo.formatted
if not msg:
if excinfo.type and excinfo.msg:
msg = f"{excinfo.type.__name__}: {excinfo.msg}"
else:
msg = excinfo.type.__name__ or excinfo.msg
super().__init__(msg)
self.excinfo = excinfo
def __str__(self) -> str:
try:
formatted = self.excinfo.errdisplay
except Exception:
return super().__str__()
else:
return dedent(
f"""
{super().__str__()}
Uncaught in the interpreter:
{formatted}
""".strip()
)
class BusyResourceError(Exception):
"""
Raised when two tasks are trying to read from or write to the same resource
concurrently.
"""
def __init__(self, action: str):
super().__init__(f"Another task is already {action} this resource")
class ClosedResourceError(Exception):
"""Raised when trying to use a resource that has been closed."""
class ConnectionFailed(OSError):
"""
Raised when a connection attempt fails.
.. note:: This class inherits from :exc:`OSError` for backwards compatibility.
"""
def iterate_exceptions(
exception: BaseException,
) -> Generator[BaseException, None, None]:
if isinstance(exception, BaseExceptionGroup):
for exc in exception.exceptions:
yield from iterate_exceptions(exc)
else:
yield exception
class DelimiterNotFound(Exception):
"""
Raised during
:meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_until` if the
maximum number of bytes has been read without the delimiter being found.
"""
def __init__(self, max_bytes: int) -> None:
super().__init__(
f"The delimiter was not found among the first {max_bytes} bytes"
)
class EndOfStream(Exception):
"""
Raised when trying to read from a stream that has been closed from the other end.
"""
class IncompleteRead(Exception):
"""
Raised during
:meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_exactly` or
:meth:`~anyio.streams.buffered.BufferedByteReceiveStream.receive_until` if the
connection is closed before the requested amount of bytes has been read.
"""
def __init__(self) -> None:
super().__init__(
"The stream was closed before the read operation could be completed"
)
class TypedAttributeLookupError(LookupError):
"""
Raised by :meth:`~anyio.TypedAttributeProvider.extra` when the given typed attribute
is not found and no default value has been given.
"""
class WouldBlock(Exception):
"""Raised by ``X_nowait`` functions if ``X()`` would block."""
class NoEventLoopError(RuntimeError):
"""
Raised by several functions that require an event loop to be running in the current
thread when there is no running event loop.
This is also raised by :func:`.from_thread.run` and :func:`.from_thread.run_sync`
if not calling from an AnyIO worker thread, and no ``token`` was passed.
"""
class RunFinishedError(RuntimeError):
"""
Raised by :func:`.from_thread.run` and :func:`.from_thread.run_sync` if the event
loop associated with the explicitly passed token has already finished.
"""
def __init__(self) -> None:
super().__init__(
"The event loop associated with the given token has already finished"
)

View File

@@ -0,0 +1,797 @@
from __future__ import annotations
import os
import pathlib
import sys
from collections.abc import (
AsyncIterator,
Callable,
Iterable,
Iterator,
Sequence,
)
from dataclasses import dataclass
from functools import partial
from os import PathLike
from typing import (
IO,
TYPE_CHECKING,
Any,
AnyStr,
ClassVar,
Final,
Generic,
overload,
)
from .. import to_thread
from ..abc import AsyncResource
if TYPE_CHECKING:
from types import ModuleType
from _typeshed import OpenBinaryMode, OpenTextMode, ReadableBuffer, WriteableBuffer
else:
ReadableBuffer = OpenBinaryMode = OpenTextMode = WriteableBuffer = object
class AsyncFile(AsyncResource, Generic[AnyStr]):
"""
An asynchronous file object.
This class wraps a standard file object and provides async friendly versions of the
following blocking methods (where available on the original file object):
* read
* read1
* readline
* readlines
* readinto
* readinto1
* write
* writelines
* truncate
* seek
* tell
* flush
All other methods are directly passed through.
This class supports the asynchronous context manager protocol which closes the
underlying file at the end of the context block.
This class also supports asynchronous iteration::
async with await open_file(...) as f:
async for line in f:
print(line)
"""
def __init__(self, fp: IO[AnyStr]) -> None:
self._fp: Any = fp
def __getattr__(self, name: str) -> object:
return getattr(self._fp, name)
@property
def wrapped(self) -> IO[AnyStr]:
"""The wrapped file object."""
return self._fp
async def __aiter__(self) -> AsyncIterator[AnyStr]:
while True:
line = await self.readline()
if line:
yield line
else:
break
async def aclose(self) -> None:
return await to_thread.run_sync(self._fp.close)
async def read(self, size: int = -1) -> AnyStr:
return await to_thread.run_sync(self._fp.read, size)
async def read1(self: AsyncFile[bytes], size: int = -1) -> bytes:
return await to_thread.run_sync(self._fp.read1, size)
async def readline(self) -> AnyStr:
return await to_thread.run_sync(self._fp.readline)
async def readlines(self) -> list[AnyStr]:
return await to_thread.run_sync(self._fp.readlines)
async def readinto(self: AsyncFile[bytes], b: WriteableBuffer) -> int:
return await to_thread.run_sync(self._fp.readinto, b)
async def readinto1(self: AsyncFile[bytes], b: WriteableBuffer) -> int:
return await to_thread.run_sync(self._fp.readinto1, b)
@overload
async def write(self: AsyncFile[bytes], b: ReadableBuffer) -> int: ...
@overload
async def write(self: AsyncFile[str], b: str) -> int: ...
async def write(self, b: ReadableBuffer | str) -> int:
return await to_thread.run_sync(self._fp.write, b)
@overload
async def writelines(
self: AsyncFile[bytes], lines: Iterable[ReadableBuffer]
) -> None: ...
@overload
async def writelines(self: AsyncFile[str], lines: Iterable[str]) -> None: ...
async def writelines(self, lines: Iterable[ReadableBuffer] | Iterable[str]) -> None:
return await to_thread.run_sync(self._fp.writelines, lines)
async def truncate(self, size: int | None = None) -> int:
return await to_thread.run_sync(self._fp.truncate, size)
async def seek(self, offset: int, whence: int | None = os.SEEK_SET) -> int:
return await to_thread.run_sync(self._fp.seek, offset, whence)
async def tell(self) -> int:
return await to_thread.run_sync(self._fp.tell)
async def flush(self) -> None:
return await to_thread.run_sync(self._fp.flush)
@overload
async def open_file(
file: str | PathLike[str] | int,
mode: OpenBinaryMode,
buffering: int = ...,
encoding: str | None = ...,
errors: str | None = ...,
newline: str | None = ...,
closefd: bool = ...,
opener: Callable[[str, int], int] | None = ...,
) -> AsyncFile[bytes]: ...
@overload
async def open_file(
file: str | PathLike[str] | int,
mode: OpenTextMode = ...,
buffering: int = ...,
encoding: str | None = ...,
errors: str | None = ...,
newline: str | None = ...,
closefd: bool = ...,
opener: Callable[[str, int], int] | None = ...,
) -> AsyncFile[str]: ...
async def open_file(
file: str | PathLike[str] | int,
mode: str = "r",
buffering: int = -1,
encoding: str | None = None,
errors: str | None = None,
newline: str | None = None,
closefd: bool = True,
opener: Callable[[str, int], int] | None = None,
) -> AsyncFile[Any]:
"""
Open a file asynchronously.
The arguments are exactly the same as for the builtin :func:`open`.
:return: an asynchronous file object
"""
fp = await to_thread.run_sync(
open, file, mode, buffering, encoding, errors, newline, closefd, opener
)
return AsyncFile(fp)
def wrap_file(file: IO[AnyStr]) -> AsyncFile[AnyStr]:
"""
Wrap an existing file as an asynchronous file.
:param file: an existing file-like object
:return: an asynchronous file object
"""
return AsyncFile(file)
@dataclass(eq=False)
class _PathIterator(AsyncIterator["Path"]):
iterator: Iterator[PathLike[str]]
async def __anext__(self) -> Path:
nextval = await to_thread.run_sync(
next, self.iterator, None, abandon_on_cancel=True
)
if nextval is None:
raise StopAsyncIteration from None
return Path(nextval)
class Path:
"""
An asynchronous version of :class:`pathlib.Path`.
This class cannot be substituted for :class:`pathlib.Path` or
:class:`pathlib.PurePath`, but it is compatible with the :class:`os.PathLike`
interface.
It implements the Python 3.10 version of :class:`pathlib.Path` interface, except for
the deprecated :meth:`~pathlib.Path.link_to` method.
Some methods may be unavailable or have limited functionality, based on the Python
version:
* :meth:`~pathlib.Path.copy` (available on Python 3.14 or later)
* :meth:`~pathlib.Path.copy_into` (available on Python 3.14 or later)
* :meth:`~pathlib.Path.from_uri` (available on Python 3.13 or later)
* :meth:`~pathlib.PurePath.full_match` (available on Python 3.13 or later)
* :attr:`~pathlib.Path.info` (available on Python 3.14 or later)
* :meth:`~pathlib.Path.is_junction` (available on Python 3.12 or later)
* :meth:`~pathlib.PurePath.match` (the ``case_sensitive`` parameter is only
available on Python 3.13 or later)
* :meth:`~pathlib.Path.move` (available on Python 3.14 or later)
* :meth:`~pathlib.Path.move_into` (available on Python 3.14 or later)
* :meth:`~pathlib.PurePath.relative_to` (the ``walk_up`` parameter is only available
on Python 3.12 or later)
* :meth:`~pathlib.Path.walk` (available on Python 3.12 or later)
Any methods that do disk I/O need to be awaited on. These methods are:
* :meth:`~pathlib.Path.absolute`
* :meth:`~pathlib.Path.chmod`
* :meth:`~pathlib.Path.cwd`
* :meth:`~pathlib.Path.exists`
* :meth:`~pathlib.Path.expanduser`
* :meth:`~pathlib.Path.group`
* :meth:`~pathlib.Path.hardlink_to`
* :meth:`~pathlib.Path.home`
* :meth:`~pathlib.Path.is_block_device`
* :meth:`~pathlib.Path.is_char_device`
* :meth:`~pathlib.Path.is_dir`
* :meth:`~pathlib.Path.is_fifo`
* :meth:`~pathlib.Path.is_file`
* :meth:`~pathlib.Path.is_junction`
* :meth:`~pathlib.Path.is_mount`
* :meth:`~pathlib.Path.is_socket`
* :meth:`~pathlib.Path.is_symlink`
* :meth:`~pathlib.Path.lchmod`
* :meth:`~pathlib.Path.lstat`
* :meth:`~pathlib.Path.mkdir`
* :meth:`~pathlib.Path.open`
* :meth:`~pathlib.Path.owner`
* :meth:`~pathlib.Path.read_bytes`
* :meth:`~pathlib.Path.read_text`
* :meth:`~pathlib.Path.readlink`
* :meth:`~pathlib.Path.rename`
* :meth:`~pathlib.Path.replace`
* :meth:`~pathlib.Path.resolve`
* :meth:`~pathlib.Path.rmdir`
* :meth:`~pathlib.Path.samefile`
* :meth:`~pathlib.Path.stat`
* :meth:`~pathlib.Path.symlink_to`
* :meth:`~pathlib.Path.touch`
* :meth:`~pathlib.Path.unlink`
* :meth:`~pathlib.Path.walk`
* :meth:`~pathlib.Path.write_bytes`
* :meth:`~pathlib.Path.write_text`
Additionally, the following methods return an async iterator yielding
:class:`~.Path` objects:
* :meth:`~pathlib.Path.glob`
* :meth:`~pathlib.Path.iterdir`
* :meth:`~pathlib.Path.rglob`
"""
__slots__ = "_path", "__weakref__"
__weakref__: Any
def __init__(self, *args: str | PathLike[str]) -> None:
self._path: Final[pathlib.Path] = pathlib.Path(*args)
def __fspath__(self) -> str:
return self._path.__fspath__()
def __str__(self) -> str:
return self._path.__str__()
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.as_posix()!r})"
def __bytes__(self) -> bytes:
return self._path.__bytes__()
def __hash__(self) -> int:
return self._path.__hash__()
def __eq__(self, other: object) -> bool:
target = other._path if isinstance(other, Path) else other
return self._path.__eq__(target)
def __lt__(self, other: pathlib.PurePath | Path) -> bool:
target = other._path if isinstance(other, Path) else other
return self._path.__lt__(target)
def __le__(self, other: pathlib.PurePath | Path) -> bool:
target = other._path if isinstance(other, Path) else other
return self._path.__le__(target)
def __gt__(self, other: pathlib.PurePath | Path) -> bool:
target = other._path if isinstance(other, Path) else other
return self._path.__gt__(target)
def __ge__(self, other: pathlib.PurePath | Path) -> bool:
target = other._path if isinstance(other, Path) else other
return self._path.__ge__(target)
def __truediv__(self, other: str | PathLike[str]) -> Path:
return Path(self._path / other)
def __rtruediv__(self, other: str | PathLike[str]) -> Path:
return Path(other) / self
@property
def parts(self) -> tuple[str, ...]:
return self._path.parts
@property
def drive(self) -> str:
return self._path.drive
@property
def root(self) -> str:
return self._path.root
@property
def anchor(self) -> str:
return self._path.anchor
@property
def parents(self) -> Sequence[Path]:
return tuple(Path(p) for p in self._path.parents)
@property
def parent(self) -> Path:
return Path(self._path.parent)
@property
def name(self) -> str:
return self._path.name
@property
def suffix(self) -> str:
return self._path.suffix
@property
def suffixes(self) -> list[str]:
return self._path.suffixes
@property
def stem(self) -> str:
return self._path.stem
async def absolute(self) -> Path:
path = await to_thread.run_sync(self._path.absolute)
return Path(path)
def as_posix(self) -> str:
return self._path.as_posix()
def as_uri(self) -> str:
return self._path.as_uri()
if sys.version_info >= (3, 13):
parser: ClassVar[ModuleType] = pathlib.Path.parser
@classmethod
def from_uri(cls, uri: str) -> Path:
return Path(pathlib.Path.from_uri(uri))
def full_match(
self, path_pattern: str, *, case_sensitive: bool | None = None
) -> bool:
return self._path.full_match(path_pattern, case_sensitive=case_sensitive)
def match(
self, path_pattern: str, *, case_sensitive: bool | None = None
) -> bool:
return self._path.match(path_pattern, case_sensitive=case_sensitive)
else:
def match(self, path_pattern: str) -> bool:
return self._path.match(path_pattern)
if sys.version_info >= (3, 14):
@property
def info(self) -> Any: # TODO: add return type annotation when Typeshed gets it
return self._path.info
async def copy(
self,
target: str | os.PathLike[str],
*,
follow_symlinks: bool = True,
preserve_metadata: bool = False,
) -> Path:
func = partial(
self._path.copy,
follow_symlinks=follow_symlinks,
preserve_metadata=preserve_metadata,
)
return Path(await to_thread.run_sync(func, pathlib.Path(target)))
async def copy_into(
self,
target_dir: str | os.PathLike[str],
*,
follow_symlinks: bool = True,
preserve_metadata: bool = False,
) -> Path:
func = partial(
self._path.copy_into,
follow_symlinks=follow_symlinks,
preserve_metadata=preserve_metadata,
)
return Path(await to_thread.run_sync(func, pathlib.Path(target_dir)))
async def move(self, target: str | os.PathLike[str]) -> Path:
# Upstream does not handle anyio.Path properly as a PathLike
target = pathlib.Path(target)
return Path(await to_thread.run_sync(self._path.move, target))
async def move_into(
self,
target_dir: str | os.PathLike[str],
) -> Path:
return Path(await to_thread.run_sync(self._path.move_into, target_dir))
def is_relative_to(self, other: str | PathLike[str]) -> bool:
try:
self.relative_to(other)
return True
except ValueError:
return False
async def chmod(self, mode: int, *, follow_symlinks: bool = True) -> None:
func = partial(os.chmod, follow_symlinks=follow_symlinks)
return await to_thread.run_sync(func, self._path, mode)
@classmethod
async def cwd(cls) -> Path:
path = await to_thread.run_sync(pathlib.Path.cwd)
return cls(path)
async def exists(self) -> bool:
return await to_thread.run_sync(self._path.exists, abandon_on_cancel=True)
async def expanduser(self) -> Path:
return Path(
await to_thread.run_sync(self._path.expanduser, abandon_on_cancel=True)
)
if sys.version_info < (3, 12):
# Python 3.11 and earlier
def glob(self, pattern: str) -> AsyncIterator[Path]:
gen = self._path.glob(pattern)
return _PathIterator(gen)
elif (3, 12) <= sys.version_info < (3, 13):
# changed in Python 3.12:
# - The case_sensitive parameter was added.
def glob(
self,
pattern: str,
*,
case_sensitive: bool | None = None,
) -> AsyncIterator[Path]:
gen = self._path.glob(pattern, case_sensitive=case_sensitive)
return _PathIterator(gen)
elif sys.version_info >= (3, 13):
# Changed in Python 3.13:
# - The recurse_symlinks parameter was added.
# - The pattern parameter accepts a path-like object.
def glob( # type: ignore[misc] # mypy doesn't allow for differing signatures in a conditional block
self,
pattern: str | PathLike[str],
*,
case_sensitive: bool | None = None,
recurse_symlinks: bool = False,
) -> AsyncIterator[Path]:
gen = self._path.glob(
pattern, # type: ignore[arg-type]
case_sensitive=case_sensitive,
recurse_symlinks=recurse_symlinks,
)
return _PathIterator(gen)
async def group(self) -> str:
return await to_thread.run_sync(self._path.group, abandon_on_cancel=True)
async def hardlink_to(
self, target: str | bytes | PathLike[str] | PathLike[bytes]
) -> None:
if isinstance(target, Path):
target = target._path
await to_thread.run_sync(os.link, target, self)
@classmethod
async def home(cls) -> Path:
home_path = await to_thread.run_sync(pathlib.Path.home)
return cls(home_path)
def is_absolute(self) -> bool:
return self._path.is_absolute()
async def is_block_device(self) -> bool:
return await to_thread.run_sync(
self._path.is_block_device, abandon_on_cancel=True
)
async def is_char_device(self) -> bool:
return await to_thread.run_sync(
self._path.is_char_device, abandon_on_cancel=True
)
async def is_dir(self) -> bool:
return await to_thread.run_sync(self._path.is_dir, abandon_on_cancel=True)
async def is_fifo(self) -> bool:
return await to_thread.run_sync(self._path.is_fifo, abandon_on_cancel=True)
async def is_file(self) -> bool:
return await to_thread.run_sync(self._path.is_file, abandon_on_cancel=True)
if sys.version_info >= (3, 12):
async def is_junction(self) -> bool:
return await to_thread.run_sync(self._path.is_junction)
async def is_mount(self) -> bool:
return await to_thread.run_sync(
os.path.ismount, self._path, abandon_on_cancel=True
)
def is_reserved(self) -> bool:
return self._path.is_reserved()
async def is_socket(self) -> bool:
return await to_thread.run_sync(self._path.is_socket, abandon_on_cancel=True)
async def is_symlink(self) -> bool:
return await to_thread.run_sync(self._path.is_symlink, abandon_on_cancel=True)
async def iterdir(self) -> AsyncIterator[Path]:
gen = (
self._path.iterdir()
if sys.version_info < (3, 13)
else await to_thread.run_sync(self._path.iterdir, abandon_on_cancel=True)
)
async for path in _PathIterator(gen):
yield path
def joinpath(self, *args: str | PathLike[str]) -> Path:
return Path(self._path.joinpath(*args))
async def lchmod(self, mode: int) -> None:
await to_thread.run_sync(self._path.lchmod, mode)
async def lstat(self) -> os.stat_result:
return await to_thread.run_sync(self._path.lstat, abandon_on_cancel=True)
async def mkdir(
self, mode: int = 0o777, parents: bool = False, exist_ok: bool = False
) -> None:
await to_thread.run_sync(self._path.mkdir, mode, parents, exist_ok)
@overload
async def open(
self,
mode: OpenBinaryMode,
buffering: int = ...,
encoding: str | None = ...,
errors: str | None = ...,
newline: str | None = ...,
) -> AsyncFile[bytes]: ...
@overload
async def open(
self,
mode: OpenTextMode = ...,
buffering: int = ...,
encoding: str | None = ...,
errors: str | None = ...,
newline: str | None = ...,
) -> AsyncFile[str]: ...
async def open(
self,
mode: str = "r",
buffering: int = -1,
encoding: str | None = None,
errors: str | None = None,
newline: str | None = None,
) -> AsyncFile[Any]:
fp = await to_thread.run_sync(
self._path.open, mode, buffering, encoding, errors, newline
)
return AsyncFile(fp)
async def owner(self) -> str:
return await to_thread.run_sync(self._path.owner, abandon_on_cancel=True)
async def read_bytes(self) -> bytes:
return await to_thread.run_sync(self._path.read_bytes)
async def read_text(
self, encoding: str | None = None, errors: str | None = None
) -> str:
return await to_thread.run_sync(self._path.read_text, encoding, errors)
if sys.version_info >= (3, 12):
def relative_to(
self, *other: str | PathLike[str], walk_up: bool = False
) -> Path:
# relative_to() should work with any PathLike but it doesn't
others = [pathlib.Path(other) for other in other]
return Path(self._path.relative_to(*others, walk_up=walk_up))
else:
def relative_to(self, *other: str | PathLike[str]) -> Path:
return Path(self._path.relative_to(*other))
async def readlink(self) -> Path:
target = await to_thread.run_sync(os.readlink, self._path)
return Path(target)
async def rename(self, target: str | pathlib.PurePath | Path) -> Path:
if isinstance(target, Path):
target = target._path
await to_thread.run_sync(self._path.rename, target)
return Path(target)
async def replace(self, target: str | pathlib.PurePath | Path) -> Path:
if isinstance(target, Path):
target = target._path
await to_thread.run_sync(self._path.replace, target)
return Path(target)
async def resolve(self, strict: bool = False) -> Path:
func = partial(self._path.resolve, strict=strict)
return Path(await to_thread.run_sync(func, abandon_on_cancel=True))
if sys.version_info < (3, 12):
# Pre Python 3.12
def rglob(self, pattern: str) -> AsyncIterator[Path]:
gen = self._path.rglob(pattern)
return _PathIterator(gen)
elif (3, 12) <= sys.version_info < (3, 13):
# Changed in Python 3.12:
# - The case_sensitive parameter was added.
def rglob(
self, pattern: str, *, case_sensitive: bool | None = None
) -> AsyncIterator[Path]:
gen = self._path.rglob(pattern, case_sensitive=case_sensitive)
return _PathIterator(gen)
elif sys.version_info >= (3, 13):
# Changed in Python 3.13:
# - The recurse_symlinks parameter was added.
# - The pattern parameter accepts a path-like object.
def rglob( # type: ignore[misc] # mypy doesn't allow for differing signatures in a conditional block
self,
pattern: str | PathLike[str],
*,
case_sensitive: bool | None = None,
recurse_symlinks: bool = False,
) -> AsyncIterator[Path]:
gen = self._path.rglob(
pattern, # type: ignore[arg-type]
case_sensitive=case_sensitive,
recurse_symlinks=recurse_symlinks,
)
return _PathIterator(gen)
async def rmdir(self) -> None:
await to_thread.run_sync(self._path.rmdir)
async def samefile(self, other_path: str | PathLike[str]) -> bool:
if isinstance(other_path, Path):
other_path = other_path._path
return await to_thread.run_sync(
self._path.samefile, other_path, abandon_on_cancel=True
)
async def stat(self, *, follow_symlinks: bool = True) -> os.stat_result:
func = partial(os.stat, follow_symlinks=follow_symlinks)
return await to_thread.run_sync(func, self._path, abandon_on_cancel=True)
async def symlink_to(
self,
target: str | bytes | PathLike[str] | PathLike[bytes],
target_is_directory: bool = False,
) -> None:
if isinstance(target, Path):
target = target._path
await to_thread.run_sync(self._path.symlink_to, target, target_is_directory)
async def touch(self, mode: int = 0o666, exist_ok: bool = True) -> None:
await to_thread.run_sync(self._path.touch, mode, exist_ok)
async def unlink(self, missing_ok: bool = False) -> None:
try:
await to_thread.run_sync(self._path.unlink)
except FileNotFoundError:
if not missing_ok:
raise
if sys.version_info >= (3, 12):
async def walk(
self,
top_down: bool = True,
on_error: Callable[[OSError], object] | None = None,
follow_symlinks: bool = False,
) -> AsyncIterator[tuple[Path, list[str], list[str]]]:
def get_next_value() -> tuple[pathlib.Path, list[str], list[str]] | None:
try:
return next(gen)
except StopIteration:
return None
gen = self._path.walk(top_down, on_error, follow_symlinks)
while True:
value = await to_thread.run_sync(get_next_value)
if value is None:
return
root, dirs, paths = value
yield Path(root), dirs, paths
def with_name(self, name: str) -> Path:
return Path(self._path.with_name(name))
def with_stem(self, stem: str) -> Path:
return Path(self._path.with_name(stem + self._path.suffix))
def with_suffix(self, suffix: str) -> Path:
return Path(self._path.with_suffix(suffix))
def with_segments(self, *pathsegments: str | PathLike[str]) -> Path:
return Path(*pathsegments)
async def write_bytes(self, data: bytes) -> int:
return await to_thread.run_sync(self._path.write_bytes, data)
async def write_text(
self,
data: str,
encoding: str | None = None,
errors: str | None = None,
newline: str | None = None,
) -> int:
# Path.write_text() does not support the "newline" parameter before Python 3.10
def sync_write_text() -> int:
with self._path.open(
"w", encoding=encoding, errors=errors, newline=newline
) as fp:
return fp.write(data)
return await to_thread.run_sync(sync_write_text)
PathLike.register(Path)

View File

@@ -0,0 +1,18 @@
from __future__ import annotations
from ..abc import AsyncResource
from ._tasks import CancelScope
async def aclose_forcefully(resource: AsyncResource) -> None:
"""
Close an asynchronous resource in a cancelled scope.
Doing this closes the resource without waiting on anything.
:param resource: the resource to close
"""
with CancelScope() as scope:
scope.cancel()
await resource.aclose()

View File

@@ -0,0 +1,29 @@
from __future__ import annotations
from collections.abc import AsyncIterator
from contextlib import AbstractContextManager
from signal import Signals
from ._eventloop import get_async_backend
def open_signal_receiver(
*signals: Signals,
) -> AbstractContextManager[AsyncIterator[Signals]]:
"""
Start receiving operating system signals.
:param signals: signals to receive (e.g. ``signal.SIGINT``)
:return: an asynchronous context manager for an asynchronous iterator which yields
signal numbers
:raises NoEventLoopError: if no supported asynchronous event loop is running in the
current thread
.. warning:: Windows does not support signals natively so it is best to avoid
relying on this in cross-platform applications.
.. warning:: On asyncio, this permanently replaces any previous signal handler for
the given signals, as set via :meth:`~asyncio.loop.add_signal_handler`.
"""
return get_async_backend().open_signal_receiver(*signals)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,52 @@
from __future__ import annotations
import math
from typing import TypeVar
from warnings import warn
from ..streams.memory import (
MemoryObjectReceiveStream,
MemoryObjectSendStream,
_MemoryObjectStreamState,
)
T_Item = TypeVar("T_Item")
class create_memory_object_stream(
tuple[MemoryObjectSendStream[T_Item], MemoryObjectReceiveStream[T_Item]],
):
"""
Create a memory object stream.
The stream's item type can be annotated like
:func:`create_memory_object_stream[T_Item]`.
:param max_buffer_size: number of items held in the buffer until ``send()`` starts
blocking
:param item_type: old way of marking the streams with the right generic type for
static typing (does nothing on AnyIO 4)
.. deprecated:: 4.0
Use ``create_memory_object_stream[YourItemType](...)`` instead.
:return: a tuple of (send stream, receive stream)
"""
def __new__( # type: ignore[misc]
cls, max_buffer_size: float = 0, item_type: object = None
) -> tuple[MemoryObjectSendStream[T_Item], MemoryObjectReceiveStream[T_Item]]:
if max_buffer_size != math.inf and not isinstance(max_buffer_size, int):
raise ValueError("max_buffer_size must be either an integer or math.inf")
if max_buffer_size < 0:
raise ValueError("max_buffer_size cannot be negative")
if item_type is not None:
warn(
"The item_type argument has been deprecated in AnyIO 4.0. "
"Use create_memory_object_stream[YourItemType](...) instead.",
DeprecationWarning,
stacklevel=2,
)
state = _MemoryObjectStreamState[T_Item](max_buffer_size)
return (MemoryObjectSendStream(state), MemoryObjectReceiveStream(state))

View File

@@ -0,0 +1,202 @@
from __future__ import annotations
import sys
from collections.abc import AsyncIterable, Iterable, Mapping, Sequence
from io import BytesIO
from os import PathLike
from subprocess import PIPE, CalledProcessError, CompletedProcess
from typing import IO, Any, Union, cast
from ..abc import Process
from ._eventloop import get_async_backend
from ._tasks import create_task_group
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
StrOrBytesPath: TypeAlias = Union[str, bytes, "PathLike[str]", "PathLike[bytes]"]
async def run_process(
command: StrOrBytesPath | Sequence[StrOrBytesPath],
*,
input: bytes | None = None,
stdin: int | IO[Any] | None = None,
stdout: int | IO[Any] | None = PIPE,
stderr: int | IO[Any] | None = PIPE,
check: bool = True,
cwd: StrOrBytesPath | None = None,
env: Mapping[str, str] | None = None,
startupinfo: Any = None,
creationflags: int = 0,
start_new_session: bool = False,
pass_fds: Sequence[int] = (),
user: str | int | None = None,
group: str | int | None = None,
extra_groups: Iterable[str | int] | None = None,
umask: int = -1,
) -> CompletedProcess[bytes]:
"""
Run an external command in a subprocess and wait until it completes.
.. seealso:: :func:`subprocess.run`
:param command: either a string to pass to the shell, or an iterable of strings
containing the executable name or path and its arguments
:param input: bytes passed to the standard input of the subprocess
:param stdin: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`,
a file-like object, or `None`; ``input`` overrides this
:param stdout: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`,
a file-like object, or `None`
:param stderr: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`,
:data:`subprocess.STDOUT`, a file-like object, or `None`
:param check: if ``True``, raise :exc:`~subprocess.CalledProcessError` if the
process terminates with a return code other than 0
:param cwd: If not ``None``, change the working directory to this before running the
command
:param env: if not ``None``, this mapping replaces the inherited environment
variables from the parent process
:param startupinfo: an instance of :class:`subprocess.STARTUPINFO` that can be used
to specify process startup parameters (Windows only)
:param creationflags: flags that can be used to control the creation of the
subprocess (see :class:`subprocess.Popen` for the specifics)
:param start_new_session: if ``true`` the setsid() system call will be made in the
child process prior to the execution of the subprocess. (POSIX only)
:param pass_fds: sequence of file descriptors to keep open between the parent and
child processes. (POSIX only)
:param user: effective user to run the process as (Python >= 3.9, POSIX only)
:param group: effective group to run the process as (Python >= 3.9, POSIX only)
:param extra_groups: supplementary groups to set in the subprocess (Python >= 3.9,
POSIX only)
:param umask: if not negative, this umask is applied in the child process before
running the given command (Python >= 3.9, POSIX only)
:return: an object representing the completed process
:raises ~subprocess.CalledProcessError: if ``check`` is ``True`` and the process
exits with a nonzero return code
"""
async def drain_stream(stream: AsyncIterable[bytes], index: int) -> None:
buffer = BytesIO()
async for chunk in stream:
buffer.write(chunk)
stream_contents[index] = buffer.getvalue()
if stdin is not None and input is not None:
raise ValueError("only one of stdin and input is allowed")
async with await open_process(
command,
stdin=PIPE if input else stdin,
stdout=stdout,
stderr=stderr,
cwd=cwd,
env=env,
startupinfo=startupinfo,
creationflags=creationflags,
start_new_session=start_new_session,
pass_fds=pass_fds,
user=user,
group=group,
extra_groups=extra_groups,
umask=umask,
) as process:
stream_contents: list[bytes | None] = [None, None]
async with create_task_group() as tg:
if process.stdout:
tg.start_soon(drain_stream, process.stdout, 0)
if process.stderr:
tg.start_soon(drain_stream, process.stderr, 1)
if process.stdin and input:
await process.stdin.send(input)
await process.stdin.aclose()
await process.wait()
output, errors = stream_contents
if check and process.returncode != 0:
raise CalledProcessError(cast(int, process.returncode), command, output, errors)
return CompletedProcess(command, cast(int, process.returncode), output, errors)
async def open_process(
command: StrOrBytesPath | Sequence[StrOrBytesPath],
*,
stdin: int | IO[Any] | None = PIPE,
stdout: int | IO[Any] | None = PIPE,
stderr: int | IO[Any] | None = PIPE,
cwd: StrOrBytesPath | None = None,
env: Mapping[str, str] | None = None,
startupinfo: Any = None,
creationflags: int = 0,
start_new_session: bool = False,
pass_fds: Sequence[int] = (),
user: str | int | None = None,
group: str | int | None = None,
extra_groups: Iterable[str | int] | None = None,
umask: int = -1,
) -> Process:
"""
Start an external command in a subprocess.
.. seealso:: :class:`subprocess.Popen`
:param command: either a string to pass to the shell, or an iterable of strings
containing the executable name or path and its arguments
:param stdin: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`, a
file-like object, or ``None``
:param stdout: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`,
a file-like object, or ``None``
:param stderr: one of :data:`subprocess.PIPE`, :data:`subprocess.DEVNULL`,
:data:`subprocess.STDOUT`, a file-like object, or ``None``
:param cwd: If not ``None``, the working directory is changed before executing
:param env: If env is not ``None``, it must be a mapping that defines the
environment variables for the new process
:param creationflags: flags that can be used to control the creation of the
subprocess (see :class:`subprocess.Popen` for the specifics)
:param startupinfo: an instance of :class:`subprocess.STARTUPINFO` that can be used
to specify process startup parameters (Windows only)
:param start_new_session: if ``true`` the setsid() system call will be made in the
child process prior to the execution of the subprocess. (POSIX only)
:param pass_fds: sequence of file descriptors to keep open between the parent and
child processes. (POSIX only)
:param user: effective user to run the process as (POSIX only)
:param group: effective group to run the process as (POSIX only)
:param extra_groups: supplementary groups to set in the subprocess (POSIX only)
:param umask: if not negative, this umask is applied in the child process before
running the given command (POSIX only)
:return: an asynchronous process object
"""
kwargs: dict[str, Any] = {}
if user is not None:
kwargs["user"] = user
if group is not None:
kwargs["group"] = group
if extra_groups is not None:
kwargs["extra_groups"] = group
if umask >= 0:
kwargs["umask"] = umask
return await get_async_backend().open_process(
command,
stdin=stdin,
stdout=stdout,
stderr=stderr,
cwd=cwd,
env=env,
startupinfo=startupinfo,
creationflags=creationflags,
start_new_session=start_new_session,
pass_fds=pass_fds,
**kwargs,
)

View File

@@ -0,0 +1,753 @@
from __future__ import annotations
import math
from collections import deque
from collections.abc import Callable
from dataclasses import dataclass
from types import TracebackType
from typing import TypeVar
from ..lowlevel import checkpoint_if_cancelled
from ._eventloop import get_async_backend
from ._exceptions import BusyResourceError, NoEventLoopError
from ._tasks import CancelScope
from ._testing import TaskInfo, get_current_task
T = TypeVar("T")
@dataclass(frozen=True)
class EventStatistics:
"""
:ivar int tasks_waiting: number of tasks waiting on :meth:`~.Event.wait`
"""
tasks_waiting: int
@dataclass(frozen=True)
class CapacityLimiterStatistics:
"""
:ivar int borrowed_tokens: number of tokens currently borrowed by tasks
:ivar float total_tokens: total number of available tokens
:ivar tuple borrowers: tasks or other objects currently holding tokens borrowed from
this limiter
:ivar int tasks_waiting: number of tasks waiting on
:meth:`~.CapacityLimiter.acquire` or
:meth:`~.CapacityLimiter.acquire_on_behalf_of`
"""
borrowed_tokens: int
total_tokens: float
borrowers: tuple[object, ...]
tasks_waiting: int
@dataclass(frozen=True)
class LockStatistics:
"""
:ivar bool locked: flag indicating if this lock is locked or not
:ivar ~anyio.TaskInfo owner: task currently holding the lock (or ``None`` if the
lock is not held by any task)
:ivar int tasks_waiting: number of tasks waiting on :meth:`~.Lock.acquire`
"""
locked: bool
owner: TaskInfo | None
tasks_waiting: int
@dataclass(frozen=True)
class ConditionStatistics:
"""
:ivar int tasks_waiting: number of tasks blocked on :meth:`~.Condition.wait`
:ivar ~anyio.LockStatistics lock_statistics: statistics of the underlying
:class:`~.Lock`
"""
tasks_waiting: int
lock_statistics: LockStatistics
@dataclass(frozen=True)
class SemaphoreStatistics:
"""
:ivar int tasks_waiting: number of tasks waiting on :meth:`~.Semaphore.acquire`
"""
tasks_waiting: int
class Event:
def __new__(cls) -> Event:
try:
return get_async_backend().create_event()
except NoEventLoopError:
return EventAdapter()
def set(self) -> None:
"""Set the flag, notifying all listeners."""
raise NotImplementedError
def is_set(self) -> bool:
"""Return ``True`` if the flag is set, ``False`` if not."""
raise NotImplementedError
async def wait(self) -> None:
"""
Wait until the flag has been set.
If the flag has already been set when this method is called, it returns
immediately.
"""
raise NotImplementedError
def statistics(self) -> EventStatistics:
"""Return statistics about the current state of this event."""
raise NotImplementedError
class EventAdapter(Event):
_internal_event: Event | None = None
_is_set: bool = False
def __new__(cls) -> EventAdapter:
return object.__new__(cls)
@property
def _event(self) -> Event:
if self._internal_event is None:
self._internal_event = get_async_backend().create_event()
if self._is_set:
self._internal_event.set()
return self._internal_event
def set(self) -> None:
if self._internal_event is None:
self._is_set = True
else:
self._event.set()
def is_set(self) -> bool:
if self._internal_event is None:
return self._is_set
return self._internal_event.is_set()
async def wait(self) -> None:
await self._event.wait()
def statistics(self) -> EventStatistics:
if self._internal_event is None:
return EventStatistics(tasks_waiting=0)
return self._internal_event.statistics()
class Lock:
def __new__(cls, *, fast_acquire: bool = False) -> Lock:
try:
return get_async_backend().create_lock(fast_acquire=fast_acquire)
except NoEventLoopError:
return LockAdapter(fast_acquire=fast_acquire)
async def __aenter__(self) -> None:
await self.acquire()
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.release()
async def acquire(self) -> None:
"""Acquire the lock."""
raise NotImplementedError
def acquire_nowait(self) -> None:
"""
Acquire the lock, without blocking.
:raises ~anyio.WouldBlock: if the operation would block
"""
raise NotImplementedError
def release(self) -> None:
"""Release the lock."""
raise NotImplementedError
def locked(self) -> bool:
"""Return True if the lock is currently held."""
raise NotImplementedError
def statistics(self) -> LockStatistics:
"""
Return statistics about the current state of this lock.
.. versionadded:: 3.0
"""
raise NotImplementedError
class LockAdapter(Lock):
_internal_lock: Lock | None = None
def __new__(cls, *, fast_acquire: bool = False) -> LockAdapter:
return object.__new__(cls)
def __init__(self, *, fast_acquire: bool = False):
self._fast_acquire = fast_acquire
@property
def _lock(self) -> Lock:
if self._internal_lock is None:
self._internal_lock = get_async_backend().create_lock(
fast_acquire=self._fast_acquire
)
return self._internal_lock
async def __aenter__(self) -> None:
await self._lock.acquire()
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self._internal_lock is not None:
self._internal_lock.release()
async def acquire(self) -> None:
"""Acquire the lock."""
await self._lock.acquire()
def acquire_nowait(self) -> None:
"""
Acquire the lock, without blocking.
:raises ~anyio.WouldBlock: if the operation would block
"""
self._lock.acquire_nowait()
def release(self) -> None:
"""Release the lock."""
self._lock.release()
def locked(self) -> bool:
"""Return True if the lock is currently held."""
return self._lock.locked()
def statistics(self) -> LockStatistics:
"""
Return statistics about the current state of this lock.
.. versionadded:: 3.0
"""
if self._internal_lock is None:
return LockStatistics(False, None, 0)
return self._internal_lock.statistics()
class Condition:
_owner_task: TaskInfo | None = None
def __init__(self, lock: Lock | None = None):
self._lock = lock or Lock()
self._waiters: deque[Event] = deque()
async def __aenter__(self) -> None:
await self.acquire()
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.release()
def _check_acquired(self) -> None:
if self._owner_task != get_current_task():
raise RuntimeError("The current task is not holding the underlying lock")
async def acquire(self) -> None:
"""Acquire the underlying lock."""
await self._lock.acquire()
self._owner_task = get_current_task()
def acquire_nowait(self) -> None:
"""
Acquire the underlying lock, without blocking.
:raises ~anyio.WouldBlock: if the operation would block
"""
self._lock.acquire_nowait()
self._owner_task = get_current_task()
def release(self) -> None:
"""Release the underlying lock."""
self._lock.release()
def locked(self) -> bool:
"""Return True if the lock is set."""
return self._lock.locked()
def notify(self, n: int = 1) -> None:
"""Notify exactly n listeners."""
self._check_acquired()
for _ in range(n):
try:
event = self._waiters.popleft()
except IndexError:
break
event.set()
def notify_all(self) -> None:
"""Notify all the listeners."""
self._check_acquired()
for event in self._waiters:
event.set()
self._waiters.clear()
async def wait(self) -> None:
"""Wait for a notification."""
await checkpoint_if_cancelled()
self._check_acquired()
event = Event()
self._waiters.append(event)
self.release()
try:
await event.wait()
except BaseException:
if not event.is_set():
self._waiters.remove(event)
raise
finally:
with CancelScope(shield=True):
await self.acquire()
async def wait_for(self, predicate: Callable[[], T]) -> T:
"""
Wait until a predicate becomes true.
:param predicate: a callable that returns a truthy value when the condition is
met
:return: the result of the predicate
.. versionadded:: 4.11.0
"""
while not (result := predicate()):
await self.wait()
return result
def statistics(self) -> ConditionStatistics:
"""
Return statistics about the current state of this condition.
.. versionadded:: 3.0
"""
return ConditionStatistics(len(self._waiters), self._lock.statistics())
class Semaphore:
def __new__(
cls,
initial_value: int,
*,
max_value: int | None = None,
fast_acquire: bool = False,
) -> Semaphore:
try:
return get_async_backend().create_semaphore(
initial_value, max_value=max_value, fast_acquire=fast_acquire
)
except NoEventLoopError:
return SemaphoreAdapter(initial_value, max_value=max_value)
def __init__(
self,
initial_value: int,
*,
max_value: int | None = None,
fast_acquire: bool = False,
):
if not isinstance(initial_value, int):
raise TypeError("initial_value must be an integer")
if initial_value < 0:
raise ValueError("initial_value must be >= 0")
if max_value is not None:
if not isinstance(max_value, int):
raise TypeError("max_value must be an integer or None")
if max_value < initial_value:
raise ValueError(
"max_value must be equal to or higher than initial_value"
)
self._fast_acquire = fast_acquire
async def __aenter__(self) -> Semaphore:
await self.acquire()
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.release()
async def acquire(self) -> None:
"""Decrement the semaphore value, blocking if necessary."""
raise NotImplementedError
def acquire_nowait(self) -> None:
"""
Acquire the underlying lock, without blocking.
:raises ~anyio.WouldBlock: if the operation would block
"""
raise NotImplementedError
def release(self) -> None:
"""Increment the semaphore value."""
raise NotImplementedError
@property
def value(self) -> int:
"""The current value of the semaphore."""
raise NotImplementedError
@property
def max_value(self) -> int | None:
"""The maximum value of the semaphore."""
raise NotImplementedError
def statistics(self) -> SemaphoreStatistics:
"""
Return statistics about the current state of this semaphore.
.. versionadded:: 3.0
"""
raise NotImplementedError
class SemaphoreAdapter(Semaphore):
_internal_semaphore: Semaphore | None = None
def __new__(
cls,
initial_value: int,
*,
max_value: int | None = None,
fast_acquire: bool = False,
) -> SemaphoreAdapter:
return object.__new__(cls)
def __init__(
self,
initial_value: int,
*,
max_value: int | None = None,
fast_acquire: bool = False,
) -> None:
super().__init__(initial_value, max_value=max_value, fast_acquire=fast_acquire)
self._initial_value = initial_value
self._max_value = max_value
@property
def _semaphore(self) -> Semaphore:
if self._internal_semaphore is None:
self._internal_semaphore = get_async_backend().create_semaphore(
self._initial_value, max_value=self._max_value
)
return self._internal_semaphore
async def acquire(self) -> None:
await self._semaphore.acquire()
def acquire_nowait(self) -> None:
self._semaphore.acquire_nowait()
def release(self) -> None:
self._semaphore.release()
@property
def value(self) -> int:
if self._internal_semaphore is None:
return self._initial_value
return self._semaphore.value
@property
def max_value(self) -> int | None:
return self._max_value
def statistics(self) -> SemaphoreStatistics:
if self._internal_semaphore is None:
return SemaphoreStatistics(tasks_waiting=0)
return self._semaphore.statistics()
class CapacityLimiter:
def __new__(cls, total_tokens: float) -> CapacityLimiter:
try:
return get_async_backend().create_capacity_limiter(total_tokens)
except NoEventLoopError:
return CapacityLimiterAdapter(total_tokens)
async def __aenter__(self) -> None:
raise NotImplementedError
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
raise NotImplementedError
@property
def total_tokens(self) -> float:
"""
The total number of tokens available for borrowing.
This is a read-write property. If the total number of tokens is increased, the
proportionate number of tasks waiting on this limiter will be granted their
tokens.
.. versionchanged:: 3.0
The property is now writable.
.. versionchanged:: 4.12
The value can now be set to 0.
"""
raise NotImplementedError
@total_tokens.setter
def total_tokens(self, value: float) -> None:
raise NotImplementedError
@property
def borrowed_tokens(self) -> int:
"""The number of tokens that have currently been borrowed."""
raise NotImplementedError
@property
def available_tokens(self) -> float:
"""The number of tokens currently available to be borrowed"""
raise NotImplementedError
def acquire_nowait(self) -> None:
"""
Acquire a token for the current task without waiting for one to become
available.
:raises ~anyio.WouldBlock: if there are no tokens available for borrowing
"""
raise NotImplementedError
def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
"""
Acquire a token without waiting for one to become available.
:param borrower: the entity borrowing a token
:raises ~anyio.WouldBlock: if there are no tokens available for borrowing
"""
raise NotImplementedError
async def acquire(self) -> None:
"""
Acquire a token for the current task, waiting if necessary for one to become
available.
"""
raise NotImplementedError
async def acquire_on_behalf_of(self, borrower: object) -> None:
"""
Acquire a token, waiting if necessary for one to become available.
:param borrower: the entity borrowing a token
"""
raise NotImplementedError
def release(self) -> None:
"""
Release the token held by the current task.
:raises RuntimeError: if the current task has not borrowed a token from this
limiter.
"""
raise NotImplementedError
def release_on_behalf_of(self, borrower: object) -> None:
"""
Release the token held by the given borrower.
:raises RuntimeError: if the borrower has not borrowed a token from this
limiter.
"""
raise NotImplementedError
def statistics(self) -> CapacityLimiterStatistics:
"""
Return statistics about the current state of this limiter.
.. versionadded:: 3.0
"""
raise NotImplementedError
class CapacityLimiterAdapter(CapacityLimiter):
_internal_limiter: CapacityLimiter | None = None
def __new__(cls, total_tokens: float) -> CapacityLimiterAdapter:
return object.__new__(cls)
def __init__(self, total_tokens: float) -> None:
self.total_tokens = total_tokens
@property
def _limiter(self) -> CapacityLimiter:
if self._internal_limiter is None:
self._internal_limiter = get_async_backend().create_capacity_limiter(
self._total_tokens
)
return self._internal_limiter
async def __aenter__(self) -> None:
await self._limiter.__aenter__()
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
return await self._limiter.__aexit__(exc_type, exc_val, exc_tb)
@property
def total_tokens(self) -> float:
if self._internal_limiter is None:
return self._total_tokens
return self._internal_limiter.total_tokens
@total_tokens.setter
def total_tokens(self, value: float) -> None:
if not isinstance(value, int) and value is not math.inf:
raise TypeError("total_tokens must be an int or math.inf")
elif value < 1:
raise ValueError("total_tokens must be >= 1")
if self._internal_limiter is None:
self._total_tokens = value
return
self._limiter.total_tokens = value
@property
def borrowed_tokens(self) -> int:
if self._internal_limiter is None:
return 0
return self._internal_limiter.borrowed_tokens
@property
def available_tokens(self) -> float:
if self._internal_limiter is None:
return self._total_tokens
return self._internal_limiter.available_tokens
def acquire_nowait(self) -> None:
self._limiter.acquire_nowait()
def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
self._limiter.acquire_on_behalf_of_nowait(borrower)
async def acquire(self) -> None:
await self._limiter.acquire()
async def acquire_on_behalf_of(self, borrower: object) -> None:
await self._limiter.acquire_on_behalf_of(borrower)
def release(self) -> None:
self._limiter.release()
def release_on_behalf_of(self, borrower: object) -> None:
self._limiter.release_on_behalf_of(borrower)
def statistics(self) -> CapacityLimiterStatistics:
if self._internal_limiter is None:
return CapacityLimiterStatistics(
borrowed_tokens=0,
total_tokens=self.total_tokens,
borrowers=(),
tasks_waiting=0,
)
return self._internal_limiter.statistics()
class ResourceGuard:
"""
A context manager for ensuring that a resource is only used by a single task at a
time.
Entering this context manager while the previous has not exited it yet will trigger
:exc:`BusyResourceError`.
:param action: the action to guard against (visible in the :exc:`BusyResourceError`
when triggered, e.g. "Another task is already {action} this resource")
.. versionadded:: 4.1
"""
__slots__ = "action", "_guarded"
def __init__(self, action: str = "using"):
self.action: str = action
self._guarded = False
def __enter__(self) -> None:
if self._guarded:
raise BusyResourceError(self.action)
self._guarded = True
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self._guarded = False

View File

@@ -0,0 +1,173 @@
from __future__ import annotations
import math
from collections.abc import Generator
from contextlib import contextmanager
from types import TracebackType
from ..abc._tasks import TaskGroup, TaskStatus
from ._eventloop import get_async_backend
class _IgnoredTaskStatus(TaskStatus[object]):
def started(self, value: object = None) -> None:
pass
TASK_STATUS_IGNORED = _IgnoredTaskStatus()
class CancelScope:
"""
Wraps a unit of work that can be made separately cancellable.
:param deadline: The time (clock value) when this scope is cancelled automatically
:param shield: ``True`` to shield the cancel scope from external cancellation
:raises NoEventLoopError: if no supported asynchronous event loop is running in the
current thread
"""
def __new__(
cls, *, deadline: float = math.inf, shield: bool = False
) -> CancelScope:
return get_async_backend().create_cancel_scope(shield=shield, deadline=deadline)
def cancel(self, reason: str | None = None) -> None:
"""
Cancel this scope immediately.
:param reason: a message describing the reason for the cancellation
"""
raise NotImplementedError
@property
def deadline(self) -> float:
"""
The time (clock value) when this scope is cancelled automatically.
Will be ``float('inf')`` if no timeout has been set.
"""
raise NotImplementedError
@deadline.setter
def deadline(self, value: float) -> None:
raise NotImplementedError
@property
def cancel_called(self) -> bool:
"""``True`` if :meth:`cancel` has been called."""
raise NotImplementedError
@property
def cancelled_caught(self) -> bool:
"""
``True`` if this scope suppressed a cancellation exception it itself raised.
This is typically used to check if any work was interrupted, or to see if the
scope was cancelled due to its deadline being reached. The value will, however,
only be ``True`` if the cancellation was triggered by the scope itself (and not
an outer scope).
"""
raise NotImplementedError
@property
def shield(self) -> bool:
"""
``True`` if this scope is shielded from external cancellation.
While a scope is shielded, it will not receive cancellations from outside.
"""
raise NotImplementedError
@shield.setter
def shield(self, value: bool) -> None:
raise NotImplementedError
def __enter__(self) -> CancelScope:
raise NotImplementedError
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool:
raise NotImplementedError
@contextmanager
def fail_after(
delay: float | None, shield: bool = False
) -> Generator[CancelScope, None, None]:
"""
Create a context manager which raises a :class:`TimeoutError` if does not finish in
time.
:param delay: maximum allowed time (in seconds) before raising the exception, or
``None`` to disable the timeout
:param shield: ``True`` to shield the cancel scope from external cancellation
:return: a context manager that yields a cancel scope
:rtype: :class:`~typing.ContextManager`\\[:class:`~anyio.CancelScope`\\]
:raises NoEventLoopError: if no supported asynchronous event loop is running in the
current thread
"""
current_time = get_async_backend().current_time
deadline = (current_time() + delay) if delay is not None else math.inf
with get_async_backend().create_cancel_scope(
deadline=deadline, shield=shield
) as cancel_scope:
yield cancel_scope
if cancel_scope.cancelled_caught and current_time() >= cancel_scope.deadline:
raise TimeoutError
def move_on_after(delay: float | None, shield: bool = False) -> CancelScope:
"""
Create a cancel scope with a deadline that expires after the given delay.
:param delay: maximum allowed time (in seconds) before exiting the context block, or
``None`` to disable the timeout
:param shield: ``True`` to shield the cancel scope from external cancellation
:return: a cancel scope
:raises NoEventLoopError: if no supported asynchronous event loop is running in the
current thread
"""
deadline = (
(get_async_backend().current_time() + delay) if delay is not None else math.inf
)
return get_async_backend().create_cancel_scope(deadline=deadline, shield=shield)
def current_effective_deadline() -> float:
"""
Return the nearest deadline among all the cancel scopes effective for the current
task.
:return: a clock value from the event loop's internal clock (or ``float('inf')`` if
there is no deadline in effect, or ``float('-inf')`` if the current scope has
been cancelled)
:rtype: float
:raises NoEventLoopError: if no supported asynchronous event loop is running in the
current thread
"""
return get_async_backend().current_effective_deadline()
def create_task_group() -> TaskGroup:
"""
Create a task group.
:return: a task group
:raises NoEventLoopError: if no supported asynchronous event loop is running in the
current thread
"""
return get_async_backend().create_task_group()

View File

@@ -0,0 +1,616 @@
from __future__ import annotations
import os
import sys
import tempfile
from collections.abc import Iterable
from io import BytesIO, TextIOWrapper
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
AnyStr,
Generic,
overload,
)
from .. import to_thread
from .._core._fileio import AsyncFile
from ..lowlevel import checkpoint_if_cancelled
if TYPE_CHECKING:
from _typeshed import OpenBinaryMode, OpenTextMode, ReadableBuffer, WriteableBuffer
class TemporaryFile(Generic[AnyStr]):
"""
An asynchronous temporary file that is automatically created and cleaned up.
This class provides an asynchronous context manager interface to a temporary file.
The file is created using Python's standard `tempfile.TemporaryFile` function in a
background thread, and is wrapped as an asynchronous file using `AsyncFile`.
:param mode: The mode in which the file is opened. Defaults to "w+b".
:param buffering: The buffering policy (-1 means the default buffering).
:param encoding: The encoding used to decode or encode the file. Only applicable in
text mode.
:param newline: Controls how universal newlines mode works (only applicable in text
mode).
:param suffix: The suffix for the temporary file name.
:param prefix: The prefix for the temporary file name.
:param dir: The directory in which the temporary file is created.
:param errors: The error handling scheme used for encoding/decoding errors.
"""
_async_file: AsyncFile[AnyStr]
@overload
def __init__(
self: TemporaryFile[bytes],
mode: OpenBinaryMode = ...,
buffering: int = ...,
encoding: str | None = ...,
newline: str | None = ...,
suffix: str | None = ...,
prefix: str | None = ...,
dir: str | None = ...,
*,
errors: str | None = ...,
): ...
@overload
def __init__(
self: TemporaryFile[str],
mode: OpenTextMode,
buffering: int = ...,
encoding: str | None = ...,
newline: str | None = ...,
suffix: str | None = ...,
prefix: str | None = ...,
dir: str | None = ...,
*,
errors: str | None = ...,
): ...
def __init__(
self,
mode: OpenTextMode | OpenBinaryMode = "w+b",
buffering: int = -1,
encoding: str | None = None,
newline: str | None = None,
suffix: str | None = None,
prefix: str | None = None,
dir: str | None = None,
*,
errors: str | None = None,
) -> None:
self.mode = mode
self.buffering = buffering
self.encoding = encoding
self.newline = newline
self.suffix: str | None = suffix
self.prefix: str | None = prefix
self.dir: str | None = dir
self.errors = errors
async def __aenter__(self) -> AsyncFile[AnyStr]:
fp = await to_thread.run_sync(
lambda: tempfile.TemporaryFile(
self.mode,
self.buffering,
self.encoding,
self.newline,
self.suffix,
self.prefix,
self.dir,
errors=self.errors,
)
)
self._async_file = AsyncFile(fp)
return self._async_file
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
await self._async_file.aclose()
class NamedTemporaryFile(Generic[AnyStr]):
"""
An asynchronous named temporary file that is automatically created and cleaned up.
This class provides an asynchronous context manager for a temporary file with a
visible name in the file system. It uses Python's standard
:func:`~tempfile.NamedTemporaryFile` function and wraps the file object with
:class:`AsyncFile` for asynchronous operations.
:param mode: The mode in which the file is opened. Defaults to "w+b".
:param buffering: The buffering policy (-1 means the default buffering).
:param encoding: The encoding used to decode or encode the file. Only applicable in
text mode.
:param newline: Controls how universal newlines mode works (only applicable in text
mode).
:param suffix: The suffix for the temporary file name.
:param prefix: The prefix for the temporary file name.
:param dir: The directory in which the temporary file is created.
:param delete: Whether to delete the file when it is closed.
:param errors: The error handling scheme used for encoding/decoding errors.
:param delete_on_close: (Python 3.12+) Whether to delete the file on close.
"""
_async_file: AsyncFile[AnyStr]
@overload
def __init__(
self: NamedTemporaryFile[bytes],
mode: OpenBinaryMode = ...,
buffering: int = ...,
encoding: str | None = ...,
newline: str | None = ...,
suffix: str | None = ...,
prefix: str | None = ...,
dir: str | None = ...,
delete: bool = ...,
*,
errors: str | None = ...,
delete_on_close: bool = ...,
): ...
@overload
def __init__(
self: NamedTemporaryFile[str],
mode: OpenTextMode,
buffering: int = ...,
encoding: str | None = ...,
newline: str | None = ...,
suffix: str | None = ...,
prefix: str | None = ...,
dir: str | None = ...,
delete: bool = ...,
*,
errors: str | None = ...,
delete_on_close: bool = ...,
): ...
def __init__(
self,
mode: OpenBinaryMode | OpenTextMode = "w+b",
buffering: int = -1,
encoding: str | None = None,
newline: str | None = None,
suffix: str | None = None,
prefix: str | None = None,
dir: str | None = None,
delete: bool = True,
*,
errors: str | None = None,
delete_on_close: bool = True,
) -> None:
self._params: dict[str, Any] = {
"mode": mode,
"buffering": buffering,
"encoding": encoding,
"newline": newline,
"suffix": suffix,
"prefix": prefix,
"dir": dir,
"delete": delete,
"errors": errors,
}
if sys.version_info >= (3, 12):
self._params["delete_on_close"] = delete_on_close
async def __aenter__(self) -> AsyncFile[AnyStr]:
fp = await to_thread.run_sync(
lambda: tempfile.NamedTemporaryFile(**self._params)
)
self._async_file = AsyncFile(fp)
return self._async_file
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
await self._async_file.aclose()
class SpooledTemporaryFile(AsyncFile[AnyStr]):
"""
An asynchronous spooled temporary file that starts in memory and is spooled to disk.
This class provides an asynchronous interface to a spooled temporary file, much like
Python's standard :class:`~tempfile.SpooledTemporaryFile`. It supports asynchronous
write operations and provides a method to force a rollover to disk.
:param max_size: Maximum size in bytes before the file is rolled over to disk.
:param mode: The mode in which the file is opened. Defaults to "w+b".
:param buffering: The buffering policy (-1 means the default buffering).
:param encoding: The encoding used to decode or encode the file (text mode only).
:param newline: Controls how universal newlines mode works (text mode only).
:param suffix: The suffix for the temporary file name.
:param prefix: The prefix for the temporary file name.
:param dir: The directory in which the temporary file is created.
:param errors: The error handling scheme used for encoding/decoding errors.
"""
_rolled: bool = False
@overload
def __init__(
self: SpooledTemporaryFile[bytes],
max_size: int = ...,
mode: OpenBinaryMode = ...,
buffering: int = ...,
encoding: str | None = ...,
newline: str | None = ...,
suffix: str | None = ...,
prefix: str | None = ...,
dir: str | None = ...,
*,
errors: str | None = ...,
): ...
@overload
def __init__(
self: SpooledTemporaryFile[str],
max_size: int = ...,
mode: OpenTextMode = ...,
buffering: int = ...,
encoding: str | None = ...,
newline: str | None = ...,
suffix: str | None = ...,
prefix: str | None = ...,
dir: str | None = ...,
*,
errors: str | None = ...,
): ...
def __init__(
self,
max_size: int = 0,
mode: OpenBinaryMode | OpenTextMode = "w+b",
buffering: int = -1,
encoding: str | None = None,
newline: str | None = None,
suffix: str | None = None,
prefix: str | None = None,
dir: str | None = None,
*,
errors: str | None = None,
) -> None:
self._tempfile_params: dict[str, Any] = {
"mode": mode,
"buffering": buffering,
"encoding": encoding,
"newline": newline,
"suffix": suffix,
"prefix": prefix,
"dir": dir,
"errors": errors,
}
self._max_size = max_size
if "b" in mode:
super().__init__(BytesIO()) # type: ignore[arg-type]
else:
super().__init__(
TextIOWrapper( # type: ignore[arg-type]
BytesIO(),
encoding=encoding,
errors=errors,
newline=newline,
write_through=True,
)
)
async def aclose(self) -> None:
if not self._rolled:
self._fp.close()
return
await super().aclose()
async def _check(self) -> None:
if self._rolled or self._fp.tell() <= self._max_size:
return
await self.rollover()
async def rollover(self) -> None:
if self._rolled:
return
self._rolled = True
buffer = self._fp
buffer.seek(0)
self._fp = await to_thread.run_sync(
lambda: tempfile.TemporaryFile(**self._tempfile_params)
)
await self.write(buffer.read())
buffer.close()
@property
def closed(self) -> bool:
return self._fp.closed
async def read(self, size: int = -1) -> AnyStr:
if not self._rolled:
await checkpoint_if_cancelled()
return self._fp.read(size)
return await super().read(size) # type: ignore[return-value]
async def read1(self: SpooledTemporaryFile[bytes], size: int = -1) -> bytes:
if not self._rolled:
await checkpoint_if_cancelled()
return self._fp.read1(size)
return await super().read1(size)
async def readline(self) -> AnyStr:
if not self._rolled:
await checkpoint_if_cancelled()
return self._fp.readline()
return await super().readline() # type: ignore[return-value]
async def readlines(self) -> list[AnyStr]:
if not self._rolled:
await checkpoint_if_cancelled()
return self._fp.readlines()
return await super().readlines() # type: ignore[return-value]
async def readinto(self: SpooledTemporaryFile[bytes], b: WriteableBuffer) -> int:
if not self._rolled:
await checkpoint_if_cancelled()
self._fp.readinto(b)
return await super().readinto(b)
async def readinto1(self: SpooledTemporaryFile[bytes], b: WriteableBuffer) -> int:
if not self._rolled:
await checkpoint_if_cancelled()
self._fp.readinto(b)
return await super().readinto1(b)
async def seek(self, offset: int, whence: int | None = os.SEEK_SET) -> int:
if not self._rolled:
await checkpoint_if_cancelled()
return self._fp.seek(offset, whence)
return await super().seek(offset, whence)
async def tell(self) -> int:
if not self._rolled:
await checkpoint_if_cancelled()
return self._fp.tell()
return await super().tell()
async def truncate(self, size: int | None = None) -> int:
if not self._rolled:
await checkpoint_if_cancelled()
return self._fp.truncate(size)
return await super().truncate(size)
@overload
async def write(self: SpooledTemporaryFile[bytes], b: ReadableBuffer) -> int: ...
@overload
async def write(self: SpooledTemporaryFile[str], b: str) -> int: ...
async def write(self, b: ReadableBuffer | str) -> int:
"""
Asynchronously write data to the spooled temporary file.
If the file has not yet been rolled over, the data is written synchronously,
and a rollover is triggered if the size exceeds the maximum size.
:param s: The data to write.
:return: The number of bytes written.
:raises RuntimeError: If the underlying file is not initialized.
"""
if not self._rolled:
await checkpoint_if_cancelled()
result = self._fp.write(b)
await self._check()
return result
return await super().write(b) # type: ignore[misc]
@overload
async def writelines(
self: SpooledTemporaryFile[bytes], lines: Iterable[ReadableBuffer]
) -> None: ...
@overload
async def writelines(
self: SpooledTemporaryFile[str], lines: Iterable[str]
) -> None: ...
async def writelines(self, lines: Iterable[str] | Iterable[ReadableBuffer]) -> None:
"""
Asynchronously write a list of lines to the spooled temporary file.
If the file has not yet been rolled over, the lines are written synchronously,
and a rollover is triggered if the size exceeds the maximum size.
:param lines: An iterable of lines to write.
:raises RuntimeError: If the underlying file is not initialized.
"""
if not self._rolled:
await checkpoint_if_cancelled()
result = self._fp.writelines(lines)
await self._check()
return result
return await super().writelines(lines) # type: ignore[misc]
class TemporaryDirectory(Generic[AnyStr]):
"""
An asynchronous temporary directory that is created and cleaned up automatically.
This class provides an asynchronous context manager for creating a temporary
directory. It wraps Python's standard :class:`~tempfile.TemporaryDirectory` to
perform directory creation and cleanup operations in a background thread.
:param suffix: Suffix to be added to the temporary directory name.
:param prefix: Prefix to be added to the temporary directory name.
:param dir: The parent directory where the temporary directory is created.
:param ignore_cleanup_errors: Whether to ignore errors during cleanup
(Python 3.10+).
:param delete: Whether to delete the directory upon closing (Python 3.12+).
"""
def __init__(
self,
suffix: AnyStr | None = None,
prefix: AnyStr | None = None,
dir: AnyStr | None = None,
*,
ignore_cleanup_errors: bool = False,
delete: bool = True,
) -> None:
self.suffix: AnyStr | None = suffix
self.prefix: AnyStr | None = prefix
self.dir: AnyStr | None = dir
self.ignore_cleanup_errors = ignore_cleanup_errors
self.delete = delete
self._tempdir: tempfile.TemporaryDirectory | None = None
async def __aenter__(self) -> str:
params: dict[str, Any] = {
"suffix": self.suffix,
"prefix": self.prefix,
"dir": self.dir,
}
if sys.version_info >= (3, 10):
params["ignore_cleanup_errors"] = self.ignore_cleanup_errors
if sys.version_info >= (3, 12):
params["delete"] = self.delete
self._tempdir = await to_thread.run_sync(
lambda: tempfile.TemporaryDirectory(**params)
)
return await to_thread.run_sync(self._tempdir.__enter__)
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
if self._tempdir is not None:
await to_thread.run_sync(
self._tempdir.__exit__, exc_type, exc_value, traceback
)
async def cleanup(self) -> None:
if self._tempdir is not None:
await to_thread.run_sync(self._tempdir.cleanup)
@overload
async def mkstemp(
suffix: str | None = None,
prefix: str | None = None,
dir: str | None = None,
text: bool = False,
) -> tuple[int, str]: ...
@overload
async def mkstemp(
suffix: bytes | None = None,
prefix: bytes | None = None,
dir: bytes | None = None,
text: bool = False,
) -> tuple[int, bytes]: ...
async def mkstemp(
suffix: AnyStr | None = None,
prefix: AnyStr | None = None,
dir: AnyStr | None = None,
text: bool = False,
) -> tuple[int, str | bytes]:
"""
Asynchronously create a temporary file and return an OS-level handle and the file
name.
This function wraps `tempfile.mkstemp` and executes it in a background thread.
:param suffix: Suffix to be added to the file name.
:param prefix: Prefix to be added to the file name.
:param dir: Directory in which the temporary file is created.
:param text: Whether the file is opened in text mode.
:return: A tuple containing the file descriptor and the file name.
"""
return await to_thread.run_sync(tempfile.mkstemp, suffix, prefix, dir, text)
@overload
async def mkdtemp(
suffix: str | None = None,
prefix: str | None = None,
dir: str | None = None,
) -> str: ...
@overload
async def mkdtemp(
suffix: bytes | None = None,
prefix: bytes | None = None,
dir: bytes | None = None,
) -> bytes: ...
async def mkdtemp(
suffix: AnyStr | None = None,
prefix: AnyStr | None = None,
dir: AnyStr | None = None,
) -> str | bytes:
"""
Asynchronously create a temporary directory and return its path.
This function wraps `tempfile.mkdtemp` and executes it in a background thread.
:param suffix: Suffix to be added to the directory name.
:param prefix: Prefix to be added to the directory name.
:param dir: Parent directory where the temporary directory is created.
:return: The path of the created temporary directory.
"""
return await to_thread.run_sync(tempfile.mkdtemp, suffix, prefix, dir)
async def gettempdir() -> str:
"""
Asynchronously return the name of the directory used for temporary files.
This function wraps `tempfile.gettempdir` and executes it in a background thread.
:return: The path of the temporary directory as a string.
"""
return await to_thread.run_sync(tempfile.gettempdir)
async def gettempdirb() -> bytes:
"""
Asynchronously return the name of the directory used for temporary files in bytes.
This function wraps `tempfile.gettempdirb` and executes it in a background thread.
:return: The path of the temporary directory as bytes.
"""
return await to_thread.run_sync(tempfile.gettempdirb)

View File

@@ -0,0 +1,82 @@
from __future__ import annotations
from collections.abc import Awaitable, Generator
from typing import Any, cast
from ._eventloop import get_async_backend
class TaskInfo:
"""
Represents an asynchronous task.
:ivar int id: the unique identifier of the task
:ivar parent_id: the identifier of the parent task, if any
:vartype parent_id: Optional[int]
:ivar str name: the description of the task (if any)
:ivar ~collections.abc.Coroutine coro: the coroutine object of the task
"""
__slots__ = "_name", "id", "parent_id", "name", "coro"
def __init__(
self,
id: int,
parent_id: int | None,
name: str | None,
coro: Generator[Any, Any, Any] | Awaitable[Any],
):
func = get_current_task
self._name = f"{func.__module__}.{func.__qualname__}"
self.id: int = id
self.parent_id: int | None = parent_id
self.name: str | None = name
self.coro: Generator[Any, Any, Any] | Awaitable[Any] = coro
def __eq__(self, other: object) -> bool:
if isinstance(other, TaskInfo):
return self.id == other.id
return NotImplemented
def __hash__(self) -> int:
return hash(self.id)
def __repr__(self) -> str:
return f"{self.__class__.__name__}(id={self.id!r}, name={self.name!r})"
def has_pending_cancellation(self) -> bool:
"""
Return ``True`` if the task has a cancellation pending, ``False`` otherwise.
"""
return False
def get_current_task() -> TaskInfo:
"""
Return the current task.
:return: a representation of the current task
:raises NoEventLoopError: if no supported asynchronous event loop is running in the
current thread
"""
return get_async_backend().get_current_task()
def get_running_tasks() -> list[TaskInfo]:
"""
Return a list of running tasks in the current event loop.
:return: a list of task info objects
:raises NoEventLoopError: if no supported asynchronous event loop is running in the
current thread
"""
return cast("list[TaskInfo]", get_async_backend().get_running_tasks())
async def wait_all_tasks_blocked() -> None:
"""Wait until all other tasks are waiting for something."""
await get_async_backend().wait_all_tasks_blocked()

View File

@@ -0,0 +1,81 @@
from __future__ import annotations
from collections.abc import Callable, Mapping
from typing import Any, TypeVar, final, overload
from ._exceptions import TypedAttributeLookupError
T_Attr = TypeVar("T_Attr")
T_Default = TypeVar("T_Default")
undefined = object()
def typed_attribute() -> Any:
"""Return a unique object, used to mark typed attributes."""
return object()
class TypedAttributeSet:
"""
Superclass for typed attribute collections.
Checks that every public attribute of every subclass has a type annotation.
"""
def __init_subclass__(cls) -> None:
annotations: dict[str, Any] = getattr(cls, "__annotations__", {})
for attrname in dir(cls):
if not attrname.startswith("_") and attrname not in annotations:
raise TypeError(
f"Attribute {attrname!r} is missing its type annotation"
)
super().__init_subclass__()
class TypedAttributeProvider:
"""Base class for classes that wish to provide typed extra attributes."""
@property
def extra_attributes(self) -> Mapping[T_Attr, Callable[[], T_Attr]]:
"""
A mapping of the extra attributes to callables that return the corresponding
values.
If the provider wraps another provider, the attributes from that wrapper should
also be included in the returned mapping (but the wrapper may override the
callables from the wrapped instance).
"""
return {}
@overload
def extra(self, attribute: T_Attr) -> T_Attr: ...
@overload
def extra(self, attribute: T_Attr, default: T_Default) -> T_Attr | T_Default: ...
@final
def extra(self, attribute: Any, default: object = undefined) -> object:
"""
extra(attribute, default=undefined)
Return the value of the given typed extra attribute.
:param attribute: the attribute (member of a :class:`~TypedAttributeSet`) to
look for
:param default: the value that should be returned if no value is found for the
attribute
:raises ~anyio.TypedAttributeLookupError: if the search failed and no default
value was given
"""
try:
getter = self.extra_attributes[attribute]
except KeyError:
if default is undefined:
raise TypedAttributeLookupError("Attribute not found") from None
else:
return default
return getter()

View File

@@ -0,0 +1,58 @@
from __future__ import annotations
from ._eventloop import AsyncBackend as AsyncBackend
from ._resources import AsyncResource as AsyncResource
from ._sockets import ConnectedUDPSocket as ConnectedUDPSocket
from ._sockets import ConnectedUNIXDatagramSocket as ConnectedUNIXDatagramSocket
from ._sockets import IPAddressType as IPAddressType
from ._sockets import IPSockAddrType as IPSockAddrType
from ._sockets import SocketAttribute as SocketAttribute
from ._sockets import SocketListener as SocketListener
from ._sockets import SocketStream as SocketStream
from ._sockets import UDPPacketType as UDPPacketType
from ._sockets import UDPSocket as UDPSocket
from ._sockets import UNIXDatagramPacketType as UNIXDatagramPacketType
from ._sockets import UNIXDatagramSocket as UNIXDatagramSocket
from ._sockets import UNIXSocketStream as UNIXSocketStream
from ._streams import AnyByteReceiveStream as AnyByteReceiveStream
from ._streams import AnyByteSendStream as AnyByteSendStream
from ._streams import AnyByteStream as AnyByteStream
from ._streams import AnyByteStreamConnectable as AnyByteStreamConnectable
from ._streams import AnyUnreliableByteReceiveStream as AnyUnreliableByteReceiveStream
from ._streams import AnyUnreliableByteSendStream as AnyUnreliableByteSendStream
from ._streams import AnyUnreliableByteStream as AnyUnreliableByteStream
from ._streams import ByteReceiveStream as ByteReceiveStream
from ._streams import ByteSendStream as ByteSendStream
from ._streams import ByteStream as ByteStream
from ._streams import ByteStreamConnectable as ByteStreamConnectable
from ._streams import Listener as Listener
from ._streams import ObjectReceiveStream as ObjectReceiveStream
from ._streams import ObjectSendStream as ObjectSendStream
from ._streams import ObjectStream as ObjectStream
from ._streams import ObjectStreamConnectable as ObjectStreamConnectable
from ._streams import UnreliableObjectReceiveStream as UnreliableObjectReceiveStream
from ._streams import UnreliableObjectSendStream as UnreliableObjectSendStream
from ._streams import UnreliableObjectStream as UnreliableObjectStream
from ._subprocesses import Process as Process
from ._tasks import TaskGroup as TaskGroup
from ._tasks import TaskStatus as TaskStatus
from ._testing import TestRunner as TestRunner
# Re-exported here, for backwards compatibility
# isort: off
from .._core._synchronization import (
CapacityLimiter as CapacityLimiter,
Condition as Condition,
Event as Event,
Lock as Lock,
Semaphore as Semaphore,
)
from .._core._tasks import CancelScope as CancelScope
from ..from_thread import BlockingPortal as BlockingPortal
# Re-export imports so they look like they live directly in this package
for __value in list(locals().values()):
if getattr(__value, "__module__", "").startswith("anyio.abc."):
__value.__module__ = __name__
del __value

View File

@@ -0,0 +1,414 @@
from __future__ import annotations
import math
import sys
from abc import ABCMeta, abstractmethod
from collections.abc import AsyncIterator, Awaitable, Callable, Sequence
from contextlib import AbstractContextManager
from os import PathLike
from signal import Signals
from socket import AddressFamily, SocketKind, socket
from typing import (
IO,
TYPE_CHECKING,
Any,
TypeVar,
Union,
overload,
)
if sys.version_info >= (3, 11):
from typing import TypeVarTuple, Unpack
else:
from typing_extensions import TypeVarTuple, Unpack
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
if TYPE_CHECKING:
from _typeshed import FileDescriptorLike
from .._core._synchronization import CapacityLimiter, Event, Lock, Semaphore
from .._core._tasks import CancelScope
from .._core._testing import TaskInfo
from ._sockets import (
ConnectedUDPSocket,
ConnectedUNIXDatagramSocket,
IPSockAddrType,
SocketListener,
SocketStream,
UDPSocket,
UNIXDatagramSocket,
UNIXSocketStream,
)
from ._subprocesses import Process
from ._tasks import TaskGroup
from ._testing import TestRunner
T_Retval = TypeVar("T_Retval")
PosArgsT = TypeVarTuple("PosArgsT")
StrOrBytesPath: TypeAlias = Union[str, bytes, "PathLike[str]", "PathLike[bytes]"]
class AsyncBackend(metaclass=ABCMeta):
@classmethod
@abstractmethod
def run(
cls,
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
args: tuple[Unpack[PosArgsT]],
kwargs: dict[str, Any],
options: dict[str, Any],
) -> T_Retval:
"""
Run the given coroutine function in an asynchronous event loop.
The current thread must not be already running an event loop.
:param func: a coroutine function
:param args: positional arguments to ``func``
:param kwargs: positional arguments to ``func``
:param options: keyword arguments to call the backend ``run()`` implementation
with
:return: the return value of the coroutine function
"""
@classmethod
@abstractmethod
def current_token(cls) -> object:
"""
Return an object that allows other threads to run code inside the event loop.
:return: a token object, specific to the event loop running in the current
thread
"""
@classmethod
@abstractmethod
def current_time(cls) -> float:
"""
Return the current value of the event loop's internal clock.
:return: the clock value (seconds)
"""
@classmethod
@abstractmethod
def cancelled_exception_class(cls) -> type[BaseException]:
"""Return the exception class that is raised in a task if it's cancelled."""
@classmethod
@abstractmethod
async def checkpoint(cls) -> None:
"""
Check if the task has been cancelled, and allow rescheduling of other tasks.
This is effectively the same as running :meth:`checkpoint_if_cancelled` and then
:meth:`cancel_shielded_checkpoint`.
"""
@classmethod
async def checkpoint_if_cancelled(cls) -> None:
"""
Check if the current task group has been cancelled.
This will check if the task has been cancelled, but will not allow other tasks
to be scheduled if not.
"""
if cls.current_effective_deadline() == -math.inf:
await cls.checkpoint()
@classmethod
async def cancel_shielded_checkpoint(cls) -> None:
"""
Allow the rescheduling of other tasks.
This will give other tasks the opportunity to run, but without checking if the
current task group has been cancelled, unlike with :meth:`checkpoint`.
"""
with cls.create_cancel_scope(shield=True):
await cls.sleep(0)
@classmethod
@abstractmethod
async def sleep(cls, delay: float) -> None:
"""
Pause the current task for the specified duration.
:param delay: the duration, in seconds
"""
@classmethod
@abstractmethod
def create_cancel_scope(
cls, *, deadline: float = math.inf, shield: bool = False
) -> CancelScope:
pass
@classmethod
@abstractmethod
def current_effective_deadline(cls) -> float:
"""
Return the nearest deadline among all the cancel scopes effective for the
current task.
:return:
- a clock value from the event loop's internal clock
- ``inf`` if there is no deadline in effect
- ``-inf`` if the current scope has been cancelled
:rtype: float
"""
@classmethod
@abstractmethod
def create_task_group(cls) -> TaskGroup:
pass
@classmethod
@abstractmethod
def create_event(cls) -> Event:
pass
@classmethod
@abstractmethod
def create_lock(cls, *, fast_acquire: bool) -> Lock:
pass
@classmethod
@abstractmethod
def create_semaphore(
cls,
initial_value: int,
*,
max_value: int | None = None,
fast_acquire: bool = False,
) -> Semaphore:
pass
@classmethod
@abstractmethod
def create_capacity_limiter(cls, total_tokens: float) -> CapacityLimiter:
pass
@classmethod
@abstractmethod
async def run_sync_in_worker_thread(
cls,
func: Callable[[Unpack[PosArgsT]], T_Retval],
args: tuple[Unpack[PosArgsT]],
abandon_on_cancel: bool = False,
limiter: CapacityLimiter | None = None,
) -> T_Retval:
pass
@classmethod
@abstractmethod
def check_cancelled(cls) -> None:
pass
@classmethod
@abstractmethod
def run_async_from_thread(
cls,
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
args: tuple[Unpack[PosArgsT]],
token: object,
) -> T_Retval:
pass
@classmethod
@abstractmethod
def run_sync_from_thread(
cls,
func: Callable[[Unpack[PosArgsT]], T_Retval],
args: tuple[Unpack[PosArgsT]],
token: object,
) -> T_Retval:
pass
@classmethod
@abstractmethod
async def open_process(
cls,
command: StrOrBytesPath | Sequence[StrOrBytesPath],
*,
stdin: int | IO[Any] | None,
stdout: int | IO[Any] | None,
stderr: int | IO[Any] | None,
**kwargs: Any,
) -> Process:
pass
@classmethod
@abstractmethod
def setup_process_pool_exit_at_shutdown(cls, workers: set[Process]) -> None:
pass
@classmethod
@abstractmethod
async def connect_tcp(
cls, host: str, port: int, local_address: IPSockAddrType | None = None
) -> SocketStream:
pass
@classmethod
@abstractmethod
async def connect_unix(cls, path: str | bytes) -> UNIXSocketStream:
pass
@classmethod
@abstractmethod
def create_tcp_listener(cls, sock: socket) -> SocketListener:
pass
@classmethod
@abstractmethod
def create_unix_listener(cls, sock: socket) -> SocketListener:
pass
@classmethod
@abstractmethod
async def create_udp_socket(
cls,
family: AddressFamily,
local_address: IPSockAddrType | None,
remote_address: IPSockAddrType | None,
reuse_port: bool,
) -> UDPSocket | ConnectedUDPSocket:
pass
@classmethod
@overload
async def create_unix_datagram_socket(
cls, raw_socket: socket, remote_path: None
) -> UNIXDatagramSocket: ...
@classmethod
@overload
async def create_unix_datagram_socket(
cls, raw_socket: socket, remote_path: str | bytes
) -> ConnectedUNIXDatagramSocket: ...
@classmethod
@abstractmethod
async def create_unix_datagram_socket(
cls, raw_socket: socket, remote_path: str | bytes | None
) -> UNIXDatagramSocket | ConnectedUNIXDatagramSocket:
pass
@classmethod
@abstractmethod
async def getaddrinfo(
cls,
host: bytes | str | None,
port: str | int | None,
*,
family: int | AddressFamily = 0,
type: int | SocketKind = 0,
proto: int = 0,
flags: int = 0,
) -> Sequence[
tuple[
AddressFamily,
SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int] | tuple[int, bytes],
]
]:
pass
@classmethod
@abstractmethod
async def getnameinfo(
cls, sockaddr: IPSockAddrType, flags: int = 0
) -> tuple[str, str]:
pass
@classmethod
@abstractmethod
async def wait_readable(cls, obj: FileDescriptorLike) -> None:
pass
@classmethod
@abstractmethod
async def wait_writable(cls, obj: FileDescriptorLike) -> None:
pass
@classmethod
@abstractmethod
def notify_closing(cls, obj: FileDescriptorLike) -> None:
pass
@classmethod
@abstractmethod
async def wrap_listener_socket(cls, sock: socket) -> SocketListener:
pass
@classmethod
@abstractmethod
async def wrap_stream_socket(cls, sock: socket) -> SocketStream:
pass
@classmethod
@abstractmethod
async def wrap_unix_stream_socket(cls, sock: socket) -> UNIXSocketStream:
pass
@classmethod
@abstractmethod
async def wrap_udp_socket(cls, sock: socket) -> UDPSocket:
pass
@classmethod
@abstractmethod
async def wrap_connected_udp_socket(cls, sock: socket) -> ConnectedUDPSocket:
pass
@classmethod
@abstractmethod
async def wrap_unix_datagram_socket(cls, sock: socket) -> UNIXDatagramSocket:
pass
@classmethod
@abstractmethod
async def wrap_connected_unix_datagram_socket(
cls, sock: socket
) -> ConnectedUNIXDatagramSocket:
pass
@classmethod
@abstractmethod
def current_default_thread_limiter(cls) -> CapacityLimiter:
pass
@classmethod
@abstractmethod
def open_signal_receiver(
cls, *signals: Signals
) -> AbstractContextManager[AsyncIterator[Signals]]:
pass
@classmethod
@abstractmethod
def get_current_task(cls) -> TaskInfo:
pass
@classmethod
@abstractmethod
def get_running_tasks(cls) -> Sequence[TaskInfo]:
pass
@classmethod
@abstractmethod
async def wait_all_tasks_blocked(cls) -> None:
pass
@classmethod
@abstractmethod
def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
pass

View File

@@ -0,0 +1,33 @@
from __future__ import annotations
from abc import ABCMeta, abstractmethod
from types import TracebackType
from typing import TypeVar
T = TypeVar("T")
class AsyncResource(metaclass=ABCMeta):
"""
Abstract base class for all closeable asynchronous resources.
Works as an asynchronous context manager which returns the instance itself on enter,
and calls :meth:`aclose` on exit.
"""
__slots__ = ()
async def __aenter__(self: T) -> T:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.aclose()
@abstractmethod
async def aclose(self) -> None:
"""Close the resource."""

View File

@@ -0,0 +1,405 @@
from __future__ import annotations
import errno
import socket
import sys
from abc import abstractmethod
from collections.abc import Callable, Collection, Mapping
from contextlib import AsyncExitStack
from io import IOBase
from ipaddress import IPv4Address, IPv6Address
from socket import AddressFamily
from typing import Any, TypeVar, Union
from .._core._eventloop import get_async_backend
from .._core._typedattr import (
TypedAttributeProvider,
TypedAttributeSet,
typed_attribute,
)
from ._streams import ByteStream, Listener, UnreliableObjectStream
from ._tasks import TaskGroup
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
IPAddressType: TypeAlias = Union[str, IPv4Address, IPv6Address]
IPSockAddrType: TypeAlias = tuple[str, int]
SockAddrType: TypeAlias = Union[IPSockAddrType, str]
UDPPacketType: TypeAlias = tuple[bytes, IPSockAddrType]
UNIXDatagramPacketType: TypeAlias = tuple[bytes, str]
T_Retval = TypeVar("T_Retval")
def _validate_socket(
sock_or_fd: socket.socket | int,
sock_type: socket.SocketKind,
addr_family: socket.AddressFamily = socket.AF_UNSPEC,
*,
require_connected: bool = False,
require_bound: bool = False,
) -> socket.socket:
if isinstance(sock_or_fd, int):
try:
sock = socket.socket(fileno=sock_or_fd)
except OSError as exc:
if exc.errno == errno.ENOTSOCK:
raise ValueError(
"the file descriptor does not refer to a socket"
) from exc
elif require_connected:
raise ValueError("the socket must be connected") from exc
elif require_bound:
raise ValueError("the socket must be bound to a local address") from exc
else:
raise
elif isinstance(sock_or_fd, socket.socket):
sock = sock_or_fd
else:
raise TypeError(
f"expected an int or socket, got {type(sock_or_fd).__qualname__} instead"
)
try:
if require_connected:
try:
sock.getpeername()
except OSError as exc:
raise ValueError("the socket must be connected") from exc
if require_bound:
try:
if sock.family in (socket.AF_INET, socket.AF_INET6):
bound_addr = sock.getsockname()[1]
else:
bound_addr = sock.getsockname()
except OSError:
bound_addr = None
if not bound_addr:
raise ValueError("the socket must be bound to a local address")
if addr_family != socket.AF_UNSPEC and sock.family != addr_family:
raise ValueError(
f"address family mismatch: expected {addr_family.name}, got "
f"{sock.family.name}"
)
if sock.type != sock_type:
raise ValueError(
f"socket type mismatch: expected {sock_type.name}, got {sock.type.name}"
)
except BaseException:
# Avoid ResourceWarning from the locally constructed socket object
if isinstance(sock_or_fd, int):
sock.detach()
raise
sock.setblocking(False)
return sock
class SocketAttribute(TypedAttributeSet):
"""
.. attribute:: family
:type: socket.AddressFamily
the address family of the underlying socket
.. attribute:: local_address
:type: tuple[str, int] | str
the local address the underlying socket is connected to
.. attribute:: local_port
:type: int
for IP based sockets, the local port the underlying socket is bound to
.. attribute:: raw_socket
:type: socket.socket
the underlying stdlib socket object
.. attribute:: remote_address
:type: tuple[str, int] | str
the remote address the underlying socket is connected to
.. attribute:: remote_port
:type: int
for IP based sockets, the remote port the underlying socket is connected to
"""
family: AddressFamily = typed_attribute()
local_address: SockAddrType = typed_attribute()
local_port: int = typed_attribute()
raw_socket: socket.socket = typed_attribute()
remote_address: SockAddrType = typed_attribute()
remote_port: int = typed_attribute()
class _SocketProvider(TypedAttributeProvider):
@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
from .._core._sockets import convert_ipv6_sockaddr as convert
attributes: dict[Any, Callable[[], Any]] = {
SocketAttribute.family: lambda: self._raw_socket.family,
SocketAttribute.local_address: lambda: convert(
self._raw_socket.getsockname()
),
SocketAttribute.raw_socket: lambda: self._raw_socket,
}
try:
peername: tuple[str, int] | None = convert(self._raw_socket.getpeername())
except OSError:
peername = None
# Provide the remote address for connected sockets
if peername is not None:
attributes[SocketAttribute.remote_address] = lambda: peername
# Provide local and remote ports for IP based sockets
if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6):
attributes[SocketAttribute.local_port] = (
lambda: self._raw_socket.getsockname()[1]
)
if peername is not None:
remote_port = peername[1]
attributes[SocketAttribute.remote_port] = lambda: remote_port
return attributes
@property
@abstractmethod
def _raw_socket(self) -> socket.socket:
pass
class SocketStream(ByteStream, _SocketProvider):
"""
Transports bytes over a socket.
Supports all relevant extra attributes from :class:`~SocketAttribute`.
"""
@classmethod
async def from_socket(cls, sock_or_fd: socket.socket | int) -> SocketStream:
"""
Wrap an existing socket object or file descriptor as a socket stream.
The newly created socket wrapper takes ownership of the socket being passed in.
The existing socket must already be connected.
:param sock_or_fd: a socket object or file descriptor
:return: a socket stream
"""
sock = _validate_socket(sock_or_fd, socket.SOCK_STREAM, require_connected=True)
return await get_async_backend().wrap_stream_socket(sock)
class UNIXSocketStream(SocketStream):
@classmethod
async def from_socket(cls, sock_or_fd: socket.socket | int) -> UNIXSocketStream:
"""
Wrap an existing socket object or file descriptor as a UNIX socket stream.
The newly created socket wrapper takes ownership of the socket being passed in.
The existing socket must already be connected.
:param sock_or_fd: a socket object or file descriptor
:return: a UNIX socket stream
"""
sock = _validate_socket(
sock_or_fd, socket.SOCK_STREAM, socket.AF_UNIX, require_connected=True
)
return await get_async_backend().wrap_unix_stream_socket(sock)
@abstractmethod
async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
"""
Send file descriptors along with a message to the peer.
:param message: a non-empty bytestring
:param fds: a collection of files (either numeric file descriptors or open file
or socket objects)
"""
@abstractmethod
async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
"""
Receive file descriptors along with a message from the peer.
:param msglen: length of the message to expect from the peer
:param maxfds: maximum number of file descriptors to expect from the peer
:return: a tuple of (message, file descriptors)
"""
class SocketListener(Listener[SocketStream], _SocketProvider):
"""
Listens to incoming socket connections.
Supports all relevant extra attributes from :class:`~SocketAttribute`.
"""
@classmethod
async def from_socket(
cls,
sock_or_fd: socket.socket | int,
) -> SocketListener:
"""
Wrap an existing socket object or file descriptor as a socket listener.
The newly created listener takes ownership of the socket being passed in.
:param sock_or_fd: a socket object or file descriptor
:return: a socket listener
"""
sock = _validate_socket(sock_or_fd, socket.SOCK_STREAM, require_bound=True)
return await get_async_backend().wrap_listener_socket(sock)
@abstractmethod
async def accept(self) -> SocketStream:
"""Accept an incoming connection."""
async def serve(
self,
handler: Callable[[SocketStream], Any],
task_group: TaskGroup | None = None,
) -> None:
from .. import create_task_group
async with AsyncExitStack() as stack:
if task_group is None:
task_group = await stack.enter_async_context(create_task_group())
while True:
stream = await self.accept()
task_group.start_soon(handler, stream)
class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider):
"""
Represents an unconnected UDP socket.
Supports all relevant extra attributes from :class:`~SocketAttribute`.
"""
@classmethod
async def from_socket(cls, sock_or_fd: socket.socket | int) -> UDPSocket:
"""
Wrap an existing socket object or file descriptor as a UDP socket.
The newly created socket wrapper takes ownership of the socket being passed in.
The existing socket must be bound to a local address.
:param sock_or_fd: a socket object or file descriptor
:return: a UDP socket
"""
sock = _validate_socket(sock_or_fd, socket.SOCK_DGRAM, require_bound=True)
return await get_async_backend().wrap_udp_socket(sock)
async def sendto(self, data: bytes, host: str, port: int) -> None:
"""
Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, (host, port))).
"""
return await self.send((data, (host, port)))
class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider):
"""
Represents an connected UDP socket.
Supports all relevant extra attributes from :class:`~SocketAttribute`.
"""
@classmethod
async def from_socket(cls, sock_or_fd: socket.socket | int) -> ConnectedUDPSocket:
"""
Wrap an existing socket object or file descriptor as a connected UDP socket.
The newly created socket wrapper takes ownership of the socket being passed in.
The existing socket must already be connected.
:param sock_or_fd: a socket object or file descriptor
:return: a connected UDP socket
"""
sock = _validate_socket(
sock_or_fd,
socket.SOCK_DGRAM,
require_connected=True,
)
return await get_async_backend().wrap_connected_udp_socket(sock)
class UNIXDatagramSocket(
UnreliableObjectStream[UNIXDatagramPacketType], _SocketProvider
):
"""
Represents an unconnected Unix datagram socket.
Supports all relevant extra attributes from :class:`~SocketAttribute`.
"""
@classmethod
async def from_socket(
cls,
sock_or_fd: socket.socket | int,
) -> UNIXDatagramSocket:
"""
Wrap an existing socket object or file descriptor as a UNIX datagram
socket.
The newly created socket wrapper takes ownership of the socket being passed in.
:param sock_or_fd: a socket object or file descriptor
:return: a UNIX datagram socket
"""
sock = _validate_socket(sock_or_fd, socket.SOCK_DGRAM, socket.AF_UNIX)
return await get_async_backend().wrap_unix_datagram_socket(sock)
async def sendto(self, data: bytes, path: str) -> None:
"""Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, path))."""
return await self.send((data, path))
class ConnectedUNIXDatagramSocket(UnreliableObjectStream[bytes], _SocketProvider):
"""
Represents a connected Unix datagram socket.
Supports all relevant extra attributes from :class:`~SocketAttribute`.
"""
@classmethod
async def from_socket(
cls,
sock_or_fd: socket.socket | int,
) -> ConnectedUNIXDatagramSocket:
"""
Wrap an existing socket object or file descriptor as a connected UNIX datagram
socket.
The newly created socket wrapper takes ownership of the socket being passed in.
The existing socket must already be connected.
:param sock_or_fd: a socket object or file descriptor
:return: a connected UNIX datagram socket
"""
sock = _validate_socket(
sock_or_fd, socket.SOCK_DGRAM, socket.AF_UNIX, require_connected=True
)
return await get_async_backend().wrap_connected_unix_datagram_socket(sock)

View File

@@ -0,0 +1,239 @@
from __future__ import annotations
import sys
from abc import ABCMeta, abstractmethod
from collections.abc import Callable
from typing import Any, Generic, TypeVar, Union
from .._core._exceptions import EndOfStream
from .._core._typedattr import TypedAttributeProvider
from ._resources import AsyncResource
from ._tasks import TaskGroup
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
T_Item = TypeVar("T_Item")
T_co = TypeVar("T_co", covariant=True)
T_contra = TypeVar("T_contra", contravariant=True)
class UnreliableObjectReceiveStream(
Generic[T_co], AsyncResource, TypedAttributeProvider
):
"""
An interface for receiving objects.
This interface makes no guarantees that the received messages arrive in the order in
which they were sent, or that no messages are missed.
Asynchronously iterating over objects of this type will yield objects matching the
given type parameter.
"""
def __aiter__(self) -> UnreliableObjectReceiveStream[T_co]:
return self
async def __anext__(self) -> T_co:
try:
return await self.receive()
except EndOfStream:
raise StopAsyncIteration from None
@abstractmethod
async def receive(self) -> T_co:
"""
Receive the next item.
:raises ~anyio.ClosedResourceError: if the receive stream has been explicitly
closed
:raises ~anyio.EndOfStream: if this stream has been closed from the other end
:raises ~anyio.BrokenResourceError: if this stream has been rendered unusable
due to external causes
"""
class UnreliableObjectSendStream(
Generic[T_contra], AsyncResource, TypedAttributeProvider
):
"""
An interface for sending objects.
This interface makes no guarantees that the messages sent will reach the
recipient(s) in the same order in which they were sent, or at all.
"""
@abstractmethod
async def send(self, item: T_contra) -> None:
"""
Send an item to the peer(s).
:param item: the item to send
:raises ~anyio.ClosedResourceError: if the send stream has been explicitly
closed
:raises ~anyio.BrokenResourceError: if this stream has been rendered unusable
due to external causes
"""
class UnreliableObjectStream(
UnreliableObjectReceiveStream[T_Item], UnreliableObjectSendStream[T_Item]
):
"""
A bidirectional message stream which does not guarantee the order or reliability of
message delivery.
"""
class ObjectReceiveStream(UnreliableObjectReceiveStream[T_co]):
"""
A receive message stream which guarantees that messages are received in the same
order in which they were sent, and that no messages are missed.
"""
class ObjectSendStream(UnreliableObjectSendStream[T_contra]):
"""
A send message stream which guarantees that messages are delivered in the same order
in which they were sent, without missing any messages in the middle.
"""
class ObjectStream(
ObjectReceiveStream[T_Item],
ObjectSendStream[T_Item],
UnreliableObjectStream[T_Item],
):
"""
A bidirectional message stream which guarantees the order and reliability of message
delivery.
"""
@abstractmethod
async def send_eof(self) -> None:
"""
Send an end-of-file indication to the peer.
You should not try to send any further data to this stream after calling this
method. This method is idempotent (does nothing on successive calls).
"""
class ByteReceiveStream(AsyncResource, TypedAttributeProvider):
"""
An interface for receiving bytes from a single peer.
Iterating this byte stream will yield a byte string of arbitrary length, but no more
than 65536 bytes.
"""
def __aiter__(self) -> ByteReceiveStream:
return self
async def __anext__(self) -> bytes:
try:
return await self.receive()
except EndOfStream:
raise StopAsyncIteration from None
@abstractmethod
async def receive(self, max_bytes: int = 65536) -> bytes:
"""
Receive at most ``max_bytes`` bytes from the peer.
.. note:: Implementers of this interface should not return an empty
:class:`bytes` object, and users should ignore them.
:param max_bytes: maximum number of bytes to receive
:return: the received bytes
:raises ~anyio.EndOfStream: if this stream has been closed from the other end
"""
class ByteSendStream(AsyncResource, TypedAttributeProvider):
"""An interface for sending bytes to a single peer."""
@abstractmethod
async def send(self, item: bytes) -> None:
"""
Send the given bytes to the peer.
:param item: the bytes to send
"""
class ByteStream(ByteReceiveStream, ByteSendStream):
"""A bidirectional byte stream."""
@abstractmethod
async def send_eof(self) -> None:
"""
Send an end-of-file indication to the peer.
You should not try to send any further data to this stream after calling this
method. This method is idempotent (does nothing on successive calls).
"""
#: Type alias for all unreliable bytes-oriented receive streams.
AnyUnreliableByteReceiveStream: TypeAlias = Union[
UnreliableObjectReceiveStream[bytes], ByteReceiveStream
]
#: Type alias for all unreliable bytes-oriented send streams.
AnyUnreliableByteSendStream: TypeAlias = Union[
UnreliableObjectSendStream[bytes], ByteSendStream
]
#: Type alias for all unreliable bytes-oriented streams.
AnyUnreliableByteStream: TypeAlias = Union[UnreliableObjectStream[bytes], ByteStream]
#: Type alias for all bytes-oriented receive streams.
AnyByteReceiveStream: TypeAlias = Union[ObjectReceiveStream[bytes], ByteReceiveStream]
#: Type alias for all bytes-oriented send streams.
AnyByteSendStream: TypeAlias = Union[ObjectSendStream[bytes], ByteSendStream]
#: Type alias for all bytes-oriented streams.
AnyByteStream: TypeAlias = Union[ObjectStream[bytes], ByteStream]
class Listener(Generic[T_co], AsyncResource, TypedAttributeProvider):
"""An interface for objects that let you accept incoming connections."""
@abstractmethod
async def serve(
self, handler: Callable[[T_co], Any], task_group: TaskGroup | None = None
) -> None:
"""
Accept incoming connections as they come in and start tasks to handle them.
:param handler: a callable that will be used to handle each accepted connection
:param task_group: the task group that will be used to start tasks for handling
each accepted connection (if omitted, an ad-hoc task group will be created)
"""
class ObjectStreamConnectable(Generic[T_co], metaclass=ABCMeta):
@abstractmethod
async def connect(self) -> ObjectStream[T_co]:
"""
Connect to the remote endpoint.
:return: an object stream connected to the remote end
:raises ConnectionFailed: if the connection fails
"""
class ByteStreamConnectable(metaclass=ABCMeta):
@abstractmethod
async def connect(self) -> ByteStream:
"""
Connect to the remote endpoint.
:return: a bytestream connected to the remote end
:raises ConnectionFailed: if the connection fails
"""
#: Type alias for all connectables returning bytestreams or bytes-oriented object streams
AnyByteStreamConnectable: TypeAlias = Union[
ObjectStreamConnectable[bytes], ByteStreamConnectable
]

View File

@@ -0,0 +1,79 @@
from __future__ import annotations
from abc import abstractmethod
from signal import Signals
from ._resources import AsyncResource
from ._streams import ByteReceiveStream, ByteSendStream
class Process(AsyncResource):
"""An asynchronous version of :class:`subprocess.Popen`."""
@abstractmethod
async def wait(self) -> int:
"""
Wait until the process exits.
:return: the exit code of the process
"""
@abstractmethod
def terminate(self) -> None:
"""
Terminates the process, gracefully if possible.
On Windows, this calls ``TerminateProcess()``.
On POSIX systems, this sends ``SIGTERM`` to the process.
.. seealso:: :meth:`subprocess.Popen.terminate`
"""
@abstractmethod
def kill(self) -> None:
"""
Kills the process.
On Windows, this calls ``TerminateProcess()``.
On POSIX systems, this sends ``SIGKILL`` to the process.
.. seealso:: :meth:`subprocess.Popen.kill`
"""
@abstractmethod
def send_signal(self, signal: Signals) -> None:
"""
Send a signal to the subprocess.
.. seealso:: :meth:`subprocess.Popen.send_signal`
:param signal: the signal number (e.g. :data:`signal.SIGHUP`)
"""
@property
@abstractmethod
def pid(self) -> int:
"""The process ID of the process."""
@property
@abstractmethod
def returncode(self) -> int | None:
"""
The return code of the process. If the process has not yet terminated, this will
be ``None``.
"""
@property
@abstractmethod
def stdin(self) -> ByteSendStream | None:
"""The stream for the standard input of the process."""
@property
@abstractmethod
def stdout(self) -> ByteReceiveStream | None:
"""The stream for the standard output of the process."""
@property
@abstractmethod
def stderr(self) -> ByteReceiveStream | None:
"""The stream for the standard error output of the process."""

View File

@@ -0,0 +1,117 @@
from __future__ import annotations
import sys
from abc import ABCMeta, abstractmethod
from collections.abc import Awaitable, Callable
from types import TracebackType
from typing import TYPE_CHECKING, Any, Protocol, overload
if sys.version_info >= (3, 13):
from typing import TypeVar
else:
from typing_extensions import TypeVar
if sys.version_info >= (3, 11):
from typing import TypeVarTuple, Unpack
else:
from typing_extensions import TypeVarTuple, Unpack
if TYPE_CHECKING:
from .._core._tasks import CancelScope
T_Retval = TypeVar("T_Retval")
T_contra = TypeVar("T_contra", contravariant=True, default=None)
PosArgsT = TypeVarTuple("PosArgsT")
class TaskStatus(Protocol[T_contra]):
@overload
def started(self: TaskStatus[None]) -> None: ...
@overload
def started(self, value: T_contra) -> None: ...
def started(self, value: T_contra | None = None) -> None:
"""
Signal that the task has started.
:param value: object passed back to the starter of the task
"""
class TaskGroup(metaclass=ABCMeta):
"""
Groups several asynchronous tasks together.
:ivar cancel_scope: the cancel scope inherited by all child tasks
:vartype cancel_scope: CancelScope
.. note:: On asyncio, support for eager task factories is considered to be
**experimental**. In particular, they don't follow the usual semantics of new
tasks being scheduled on the next iteration of the event loop, and may thus
cause unexpected behavior in code that wasn't written with such semantics in
mind.
"""
cancel_scope: CancelScope
@abstractmethod
def start_soon(
self,
func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
*args: Unpack[PosArgsT],
name: object = None,
) -> None:
"""
Start a new task in this task group.
:param func: a coroutine function
:param args: positional arguments to call the function with
:param name: name of the task, for the purposes of introspection and debugging
.. versionadded:: 3.0
"""
@abstractmethod
async def start(
self,
func: Callable[..., Awaitable[Any]],
*args: object,
name: object = None,
) -> Any:
"""
Start a new task and wait until it signals for readiness.
The target callable must accept a keyword argument ``task_status`` (of type
:class:`TaskStatus`). Awaiting on this method will return whatever was passed to
``task_status.started()`` (``None`` by default).
.. note:: The :class:`TaskStatus` class is generic, and the type argument should
indicate the type of the value that will be passed to
``task_status.started()``.
:param func: a coroutine function that accepts the ``task_status`` keyword
argument
:param args: positional arguments to call the function with
:param name: an optional name for the task, for introspection and debugging
:return: the value passed to ``task_status.started()``
:raises RuntimeError: if the task finishes without calling
``task_status.started()``
.. seealso:: :ref:`start_initialize`
.. versionadded:: 3.0
"""
@abstractmethod
async def __aenter__(self) -> TaskGroup:
"""Enter the task group context and allow starting new tasks."""
@abstractmethod
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool:
"""Exit the task group context waiting for all tasks to finish."""

View File

@@ -0,0 +1,65 @@
from __future__ import annotations
import types
from abc import ABCMeta, abstractmethod
from collections.abc import AsyncGenerator, Callable, Coroutine, Iterable
from typing import Any, TypeVar
_T = TypeVar("_T")
class TestRunner(metaclass=ABCMeta):
"""
Encapsulates a running event loop. Every call made through this object will use the
same event loop.
"""
def __enter__(self) -> TestRunner:
return self
@abstractmethod
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: types.TracebackType | None,
) -> bool | None: ...
@abstractmethod
def run_asyncgen_fixture(
self,
fixture_func: Callable[..., AsyncGenerator[_T, Any]],
kwargs: dict[str, Any],
) -> Iterable[_T]:
"""
Run an async generator fixture.
:param fixture_func: the fixture function
:param kwargs: keyword arguments to call the fixture function with
:return: an iterator yielding the value yielded from the async generator
"""
@abstractmethod
def run_fixture(
self,
fixture_func: Callable[..., Coroutine[Any, Any, _T]],
kwargs: dict[str, Any],
) -> _T:
"""
Run an async fixture.
:param fixture_func: the fixture function
:param kwargs: keyword arguments to call the fixture function with
:return: the return value of the fixture function
"""
@abstractmethod
def run_test(
self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
) -> None:
"""
Run an async test function.
:param test_func: the test function
:param kwargs: keyword arguments to call the test function with
"""

View File

@@ -0,0 +1,578 @@
from __future__ import annotations
__all__ = (
"BlockingPortal",
"BlockingPortalProvider",
"check_cancelled",
"run",
"run_sync",
"start_blocking_portal",
)
import sys
from collections.abc import Awaitable, Callable, Generator
from concurrent.futures import Future
from contextlib import (
AbstractAsyncContextManager,
AbstractContextManager,
contextmanager,
)
from dataclasses import dataclass, field
from functools import partial
from inspect import isawaitable
from threading import Lock, Thread, current_thread, get_ident
from types import TracebackType
from typing import (
Any,
Generic,
TypeVar,
cast,
overload,
)
from ._core._eventloop import (
get_cancelled_exc_class,
threadlocals,
)
from ._core._eventloop import run as run_eventloop
from ._core._exceptions import NoEventLoopError
from ._core._synchronization import Event
from ._core._tasks import CancelScope, create_task_group
from .abc._tasks import TaskStatus
from .lowlevel import EventLoopToken, current_token
if sys.version_info >= (3, 11):
from typing import TypeVarTuple, Unpack
else:
from typing_extensions import TypeVarTuple, Unpack
T_Retval = TypeVar("T_Retval")
T_co = TypeVar("T_co", covariant=True)
PosArgsT = TypeVarTuple("PosArgsT")
def _token_or_error(token: EventLoopToken | None) -> EventLoopToken:
if token is not None:
return token
try:
return threadlocals.current_token
except AttributeError:
raise NoEventLoopError(
"Not running inside an AnyIO worker thread, and no event loop token was "
"provided"
) from None
def run(
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
*args: Unpack[PosArgsT],
token: EventLoopToken | None = None,
) -> T_Retval:
"""
Call a coroutine function from a worker thread.
:param func: a coroutine function
:param args: positional arguments for the callable
:param token: an event loop token to use to get back to the event loop thread
(required if calling this function from outside an AnyIO worker thread)
:return: the return value of the coroutine function
:raises MissingTokenError: if no token was provided and called from outside an
AnyIO worker thread
:raises RunFinishedError: if the event loop tied to ``token`` is no longer running
.. versionchanged:: 4.11.0
Added the ``token`` parameter.
"""
explicit_token = token is not None
token = _token_or_error(token)
return token.backend_class.run_async_from_thread(
func, args, token=token.native_token if explicit_token else None
)
def run_sync(
func: Callable[[Unpack[PosArgsT]], T_Retval],
*args: Unpack[PosArgsT],
token: EventLoopToken | None = None,
) -> T_Retval:
"""
Call a function in the event loop thread from a worker thread.
:param func: a callable
:param args: positional arguments for the callable
:param token: an event loop token to use to get back to the event loop thread
(required if calling this function from outside an AnyIO worker thread)
:return: the return value of the callable
:raises MissingTokenError: if no token was provided and called from outside an
AnyIO worker thread
:raises RunFinishedError: if the event loop tied to ``token`` is no longer running
.. versionchanged:: 4.11.0
Added the ``token`` parameter.
"""
explicit_token = token is not None
token = _token_or_error(token)
return token.backend_class.run_sync_from_thread(
func, args, token=token.native_token if explicit_token else None
)
class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager):
_enter_future: Future[T_co]
_exit_future: Future[bool | None]
_exit_event: Event
_exit_exc_info: tuple[
type[BaseException] | None, BaseException | None, TracebackType | None
] = (None, None, None)
def __init__(
self, async_cm: AbstractAsyncContextManager[T_co], portal: BlockingPortal
):
self._async_cm = async_cm
self._portal = portal
async def run_async_cm(self) -> bool | None:
try:
self._exit_event = Event()
value = await self._async_cm.__aenter__()
except BaseException as exc:
self._enter_future.set_exception(exc)
raise
else:
self._enter_future.set_result(value)
try:
# Wait for the sync context manager to exit.
# This next statement can raise `get_cancelled_exc_class()` if
# something went wrong in a task group in this async context
# manager.
await self._exit_event.wait()
finally:
# In case of cancellation, it could be that we end up here before
# `_BlockingAsyncContextManager.__exit__` is called, and an
# `_exit_exc_info` has been set.
result = await self._async_cm.__aexit__(*self._exit_exc_info)
return result
def __enter__(self) -> T_co:
self._enter_future = Future()
self._exit_future = self._portal.start_task_soon(self.run_async_cm)
return self._enter_future.result()
def __exit__(
self,
__exc_type: type[BaseException] | None,
__exc_value: BaseException | None,
__traceback: TracebackType | None,
) -> bool | None:
self._exit_exc_info = __exc_type, __exc_value, __traceback
self._portal.call(self._exit_event.set)
return self._exit_future.result()
class _BlockingPortalTaskStatus(TaskStatus):
def __init__(self, future: Future):
self._future = future
def started(self, value: object = None) -> None:
self._future.set_result(value)
class BlockingPortal:
"""
An object that lets external threads run code in an asynchronous event loop.
:raises NoEventLoopError: if no supported asynchronous event loop is running in the
current thread
"""
def __init__(self) -> None:
self._token = current_token()
self._event_loop_thread_id: int | None = get_ident()
self._stop_event = Event()
self._task_group = create_task_group()
async def __aenter__(self) -> BlockingPortal:
await self._task_group.__aenter__()
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool:
await self.stop()
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
def _check_running(self) -> None:
if self._event_loop_thread_id is None:
raise RuntimeError("This portal is not running")
if self._event_loop_thread_id == get_ident():
raise RuntimeError(
"This method cannot be called from the event loop thread"
)
async def sleep_until_stopped(self) -> None:
"""Sleep until :meth:`stop` is called."""
await self._stop_event.wait()
async def stop(self, cancel_remaining: bool = False) -> None:
"""
Signal the portal to shut down.
This marks the portal as no longer accepting new calls and exits from
:meth:`sleep_until_stopped`.
:param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False``
to let them finish before returning
"""
self._event_loop_thread_id = None
self._stop_event.set()
if cancel_remaining:
self._task_group.cancel_scope.cancel("the blocking portal is shutting down")
async def _call_func(
self,
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
args: tuple[Unpack[PosArgsT]],
kwargs: dict[str, Any],
future: Future[T_Retval],
) -> None:
def callback(f: Future[T_Retval]) -> None:
if f.cancelled():
if self._event_loop_thread_id == get_ident():
scope.cancel("the future was cancelled")
elif self._event_loop_thread_id is not None:
self.call(scope.cancel, "the future was cancelled")
try:
retval_or_awaitable = func(*args, **kwargs)
if isawaitable(retval_or_awaitable):
with CancelScope() as scope:
future.add_done_callback(callback)
retval = await retval_or_awaitable
else:
retval = retval_or_awaitable
except get_cancelled_exc_class():
future.cancel()
future.set_running_or_notify_cancel()
except BaseException as exc:
if not future.cancelled():
future.set_exception(exc)
# Let base exceptions fall through
if not isinstance(exc, Exception):
raise
else:
if not future.cancelled():
future.set_result(retval)
finally:
scope = None # type: ignore[assignment]
def _spawn_task_from_thread(
self,
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
args: tuple[Unpack[PosArgsT]],
kwargs: dict[str, Any],
name: object,
future: Future[T_Retval],
) -> None:
"""
Spawn a new task using the given callable.
:param func: a callable
:param args: positional arguments to be passed to the callable
:param kwargs: keyword arguments to be passed to the callable
:param name: name of the task (will be coerced to a string if not ``None``)
:param future: a future that will resolve to the return value of the callable,
or the exception raised during its execution
"""
run_sync(
partial(self._task_group.start_soon, name=name),
self._call_func,
func,
args,
kwargs,
future,
token=self._token,
)
@overload
def call(
self,
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
*args: Unpack[PosArgsT],
) -> T_Retval: ...
@overload
def call(
self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
) -> T_Retval: ...
def call(
self,
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
*args: Unpack[PosArgsT],
) -> T_Retval:
"""
Call the given function in the event loop thread.
If the callable returns a coroutine object, it is awaited on.
:param func: any callable
:raises RuntimeError: if the portal is not running or if this method is called
from within the event loop thread
"""
return cast(T_Retval, self.start_task_soon(func, *args).result())
@overload
def start_task_soon(
self,
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
*args: Unpack[PosArgsT],
name: object = None,
) -> Future[T_Retval]: ...
@overload
def start_task_soon(
self,
func: Callable[[Unpack[PosArgsT]], T_Retval],
*args: Unpack[PosArgsT],
name: object = None,
) -> Future[T_Retval]: ...
def start_task_soon(
self,
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
*args: Unpack[PosArgsT],
name: object = None,
) -> Future[T_Retval]:
"""
Start a task in the portal's task group.
The task will be run inside a cancel scope which can be cancelled by cancelling
the returned future.
:param func: the target function
:param args: positional arguments passed to ``func``
:param name: name of the task (will be coerced to a string if not ``None``)
:return: a future that resolves with the return value of the callable if the
task completes successfully, or with the exception raised in the task
:raises RuntimeError: if the portal is not running or if this method is called
from within the event loop thread
:rtype: concurrent.futures.Future[T_Retval]
.. versionadded:: 3.0
"""
self._check_running()
f: Future[T_Retval] = Future()
self._spawn_task_from_thread(func, args, {}, name, f)
return f
def start_task(
self,
func: Callable[..., Awaitable[T_Retval]],
*args: object,
name: object = None,
) -> tuple[Future[T_Retval], Any]:
"""
Start a task in the portal's task group and wait until it signals for readiness.
This method works the same way as :meth:`.abc.TaskGroup.start`.
:param func: the target function
:param args: positional arguments passed to ``func``
:param name: name of the task (will be coerced to a string if not ``None``)
:return: a tuple of (future, task_status_value) where the ``task_status_value``
is the value passed to ``task_status.started()`` from within the target
function
:rtype: tuple[concurrent.futures.Future[T_Retval], Any]
.. versionadded:: 3.0
"""
def task_done(future: Future[T_Retval]) -> None:
if not task_status_future.done():
if future.cancelled():
task_status_future.cancel()
elif future.exception():
task_status_future.set_exception(future.exception())
else:
exc = RuntimeError(
"Task exited without calling task_status.started()"
)
task_status_future.set_exception(exc)
self._check_running()
task_status_future: Future = Future()
task_status = _BlockingPortalTaskStatus(task_status_future)
f: Future = Future()
f.add_done_callback(task_done)
self._spawn_task_from_thread(func, args, {"task_status": task_status}, name, f)
return f, task_status_future.result()
def wrap_async_context_manager(
self, cm: AbstractAsyncContextManager[T_co]
) -> AbstractContextManager[T_co]:
"""
Wrap an async context manager as a synchronous context manager via this portal.
Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping
in the middle until the synchronous context manager exits.
:param cm: an asynchronous context manager
:return: a synchronous context manager
.. versionadded:: 2.1
"""
return _BlockingAsyncContextManager(cm, self)
@dataclass
class BlockingPortalProvider:
"""
A manager for a blocking portal. Used as a context manager. The first thread to
enter this context manager causes a blocking portal to be started with the specific
parameters, and the last thread to exit causes the portal to be shut down. Thus,
there will be exactly one blocking portal running in this context as long as at
least one thread has entered this context manager.
The parameters are the same as for :func:`~anyio.run`.
:param backend: name of the backend
:param backend_options: backend options
.. versionadded:: 4.4
"""
backend: str = "asyncio"
backend_options: dict[str, Any] | None = None
_lock: Lock = field(init=False, default_factory=Lock)
_leases: int = field(init=False, default=0)
_portal: BlockingPortal = field(init=False)
_portal_cm: AbstractContextManager[BlockingPortal] | None = field(
init=False, default=None
)
def __enter__(self) -> BlockingPortal:
with self._lock:
if self._portal_cm is None:
self._portal_cm = start_blocking_portal(
self.backend, self.backend_options
)
self._portal = self._portal_cm.__enter__()
self._leases += 1
return self._portal
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
portal_cm: AbstractContextManager[BlockingPortal] | None = None
with self._lock:
assert self._portal_cm
assert self._leases > 0
self._leases -= 1
if not self._leases:
portal_cm = self._portal_cm
self._portal_cm = None
del self._portal
if portal_cm:
portal_cm.__exit__(None, None, None)
@contextmanager
def start_blocking_portal(
backend: str = "asyncio",
backend_options: dict[str, Any] | None = None,
*,
name: str | None = None,
) -> Generator[BlockingPortal, Any, None]:
"""
Start a new event loop in a new thread and run a blocking portal in its main task.
The parameters are the same as for :func:`~anyio.run`.
:param backend: name of the backend
:param backend_options: backend options
:param name: name of the thread
:return: a context manager that yields a blocking portal
.. versionchanged:: 3.0
Usage as a context manager is now required.
"""
async def run_portal() -> None:
async with BlockingPortal() as portal_:
if name is None:
current_thread().name = f"{backend}-portal-{id(portal_):x}"
future.set_result(portal_)
await portal_.sleep_until_stopped()
def run_blocking_portal() -> None:
if future.set_running_or_notify_cancel():
try:
run_eventloop(
run_portal, backend=backend, backend_options=backend_options
)
except BaseException as exc:
if not future.done():
future.set_exception(exc)
future: Future[BlockingPortal] = Future()
thread = Thread(target=run_blocking_portal, daemon=True, name=name)
thread.start()
try:
cancel_remaining_tasks = False
portal = future.result()
try:
yield portal
except BaseException:
cancel_remaining_tasks = True
raise
finally:
try:
portal.call(portal.stop, cancel_remaining_tasks)
except RuntimeError:
pass
finally:
thread.join()
def check_cancelled() -> None:
"""
Check if the cancel scope of the host task's running the current worker thread has
been cancelled.
If the host task's current cancel scope has indeed been cancelled, the
backend-specific cancellation exception will be raised.
:raises RuntimeError: if the current thread was not spawned by
:func:`.to_thread.run_sync`
"""
try:
token: EventLoopToken = threadlocals.current_token
except AttributeError:
raise NoEventLoopError(
"This function can only be called inside an AnyIO worker thread"
) from None
token.backend_class.check_cancelled()

View File

@@ -0,0 +1,375 @@
from __future__ import annotations
__all__ = (
"AsyncCacheInfo",
"AsyncCacheParameters",
"AsyncLRUCacheWrapper",
"cache",
"lru_cache",
"reduce",
)
import functools
import sys
from collections import OrderedDict
from collections.abc import (
AsyncIterable,
Awaitable,
Callable,
Coroutine,
Hashable,
Iterable,
)
from functools import update_wrapper
from inspect import iscoroutinefunction
from typing import (
Any,
Generic,
NamedTuple,
TypedDict,
TypeVar,
cast,
final,
overload,
)
from weakref import WeakKeyDictionary
from ._core._synchronization import Lock
from .lowlevel import RunVar, checkpoint
if sys.version_info >= (3, 11):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
T = TypeVar("T")
S = TypeVar("S")
P = ParamSpec("P")
lru_cache_items: RunVar[
WeakKeyDictionary[
AsyncLRUCacheWrapper[Any, Any],
OrderedDict[Hashable, tuple[_InitialMissingType, Lock] | tuple[Any, None]],
]
] = RunVar("lru_cache_items")
class _InitialMissingType:
pass
initial_missing: _InitialMissingType = _InitialMissingType()
class AsyncCacheInfo(NamedTuple):
hits: int
misses: int
maxsize: int | None
currsize: int
class AsyncCacheParameters(TypedDict):
maxsize: int | None
typed: bool
always_checkpoint: bool
class _LRUMethodWrapper(Generic[T]):
def __init__(self, wrapper: AsyncLRUCacheWrapper[..., T], instance: object):
self.__wrapper = wrapper
self.__instance = instance
def cache_info(self) -> AsyncCacheInfo:
return self.__wrapper.cache_info()
def cache_parameters(self) -> AsyncCacheParameters:
return self.__wrapper.cache_parameters()
def cache_clear(self) -> None:
self.__wrapper.cache_clear()
async def __call__(self, *args: Any, **kwargs: Any) -> T:
if self.__instance is None:
return await self.__wrapper(*args, **kwargs)
return await self.__wrapper(self.__instance, *args, **kwargs)
@final
class AsyncLRUCacheWrapper(Generic[P, T]):
def __init__(
self,
func: Callable[P, Awaitable[T]],
maxsize: int | None,
typed: bool,
always_checkpoint: bool,
):
self.__wrapped__ = func
self._hits: int = 0
self._misses: int = 0
self._maxsize = max(maxsize, 0) if maxsize is not None else None
self._currsize: int = 0
self._typed = typed
self._always_checkpoint = always_checkpoint
update_wrapper(self, func)
def cache_info(self) -> AsyncCacheInfo:
return AsyncCacheInfo(self._hits, self._misses, self._maxsize, self._currsize)
def cache_parameters(self) -> AsyncCacheParameters:
return {
"maxsize": self._maxsize,
"typed": self._typed,
"always_checkpoint": self._always_checkpoint,
}
def cache_clear(self) -> None:
if cache := lru_cache_items.get(None):
cache.pop(self, None)
self._hits = self._misses = self._currsize = 0
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
# Easy case first: if maxsize == 0, no caching is done
if self._maxsize == 0:
value = await self.__wrapped__(*args, **kwargs)
self._misses += 1
return value
# The key is constructed as a flat tuple to avoid memory overhead
key: tuple[Any, ...] = args
if kwargs:
# initial_missing is used as a separator
key += (initial_missing,) + sum(kwargs.items(), ())
if self._typed:
key += tuple(type(arg) for arg in args)
if kwargs:
key += (initial_missing,) + tuple(type(val) for val in kwargs.values())
try:
cache = lru_cache_items.get()
except LookupError:
cache = WeakKeyDictionary()
lru_cache_items.set(cache)
try:
cache_entry = cache[self]
except KeyError:
cache_entry = cache[self] = OrderedDict()
cached_value: T | _InitialMissingType
try:
cached_value, lock = cache_entry[key]
except KeyError:
# We're the first task to call this function
cached_value, lock = (
initial_missing,
Lock(fast_acquire=not self._always_checkpoint),
)
cache_entry[key] = cached_value, lock
if lock is None:
# The value was already cached
self._hits += 1
cache_entry.move_to_end(key)
if self._always_checkpoint:
await checkpoint()
return cast(T, cached_value)
async with lock:
# Check if another task filled the cache while we acquired the lock
if (cached_value := cache_entry[key][0]) is initial_missing:
self._misses += 1
if self._maxsize is not None and self._currsize >= self._maxsize:
cache_entry.popitem(last=False)
else:
self._currsize += 1
value = await self.__wrapped__(*args, **kwargs)
cache_entry[key] = value, None
else:
# Another task filled the cache while we were waiting for the lock
self._hits += 1
cache_entry.move_to_end(key)
value = cast(T, cached_value)
return value
def __get__(
self, instance: object, owner: type | None = None
) -> _LRUMethodWrapper[T]:
wrapper = _LRUMethodWrapper(self, instance)
update_wrapper(wrapper, self.__wrapped__)
return wrapper
class _LRUCacheWrapper(Generic[T]):
def __init__(self, maxsize: int | None, typed: bool, always_checkpoint: bool):
self._maxsize = maxsize
self._typed = typed
self._always_checkpoint = always_checkpoint
@overload
def __call__( # type: ignore[overload-overlap]
self, func: Callable[P, Coroutine[Any, Any, T]], /
) -> AsyncLRUCacheWrapper[P, T]: ...
@overload
def __call__(
self, func: Callable[..., T], /
) -> functools._lru_cache_wrapper[T]: ...
def __call__(
self, f: Callable[P, Coroutine[Any, Any, T]] | Callable[..., T], /
) -> AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T]:
if iscoroutinefunction(f):
return AsyncLRUCacheWrapper(
f, self._maxsize, self._typed, self._always_checkpoint
)
return functools.lru_cache(maxsize=self._maxsize, typed=self._typed)(f) # type: ignore[arg-type]
@overload
def cache( # type: ignore[overload-overlap]
func: Callable[P, Coroutine[Any, Any, T]], /
) -> AsyncLRUCacheWrapper[P, T]: ...
@overload
def cache(func: Callable[..., T], /) -> functools._lru_cache_wrapper[T]: ...
def cache(
func: Callable[..., T] | Callable[P, Coroutine[Any, Any, T]], /
) -> AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T]:
"""
A convenient shortcut for :func:`lru_cache` with ``maxsize=None``.
This is the asynchronous equivalent to :func:`functools.cache`.
"""
return lru_cache(maxsize=None)(func)
@overload
def lru_cache(
*, maxsize: int | None = ..., typed: bool = ..., always_checkpoint: bool = ...
) -> _LRUCacheWrapper[Any]: ...
@overload
def lru_cache( # type: ignore[overload-overlap]
func: Callable[P, Coroutine[Any, Any, T]], /
) -> AsyncLRUCacheWrapper[P, T]: ...
@overload
def lru_cache(func: Callable[..., T], /) -> functools._lru_cache_wrapper[T]: ...
def lru_cache(
func: Callable[P, Coroutine[Any, Any, T]] | Callable[..., T] | None = None,
/,
*,
maxsize: int | None = 128,
typed: bool = False,
always_checkpoint: bool = False,
) -> (
AsyncLRUCacheWrapper[P, T] | functools._lru_cache_wrapper[T] | _LRUCacheWrapper[Any]
):
"""
An asynchronous version of :func:`functools.lru_cache`.
If a synchronous function is passed, the standard library
:func:`functools.lru_cache` is applied instead.
:param always_checkpoint: if ``True``, every call to the cached function will be
guaranteed to yield control to the event loop at least once
.. note:: Caches and locks are managed on a per-event loop basis.
"""
if func is None:
return _LRUCacheWrapper[Any](maxsize, typed, always_checkpoint)
if not callable(func):
raise TypeError("the first argument must be callable")
return _LRUCacheWrapper[T](maxsize, typed, always_checkpoint)(func)
@overload
async def reduce(
function: Callable[[T, S], Awaitable[T]],
iterable: Iterable[S] | AsyncIterable[S],
/,
initial: T,
) -> T: ...
@overload
async def reduce(
function: Callable[[T, T], Awaitable[T]],
iterable: Iterable[T] | AsyncIterable[T],
/,
) -> T: ...
async def reduce( # type: ignore[misc]
function: Callable[[T, T], Awaitable[T]] | Callable[[T, S], Awaitable[T]],
iterable: Iterable[T] | Iterable[S] | AsyncIterable[T] | AsyncIterable[S],
/,
initial: T | _InitialMissingType = initial_missing,
) -> T:
"""
Asynchronous version of :func:`functools.reduce`.
:param function: a coroutine function that takes two arguments: the accumulated
value and the next element from the iterable
:param iterable: an iterable or async iterable
:param initial: the initial value (if missing, the first element of the iterable is
used as the initial value)
"""
element: Any
function_called = False
if isinstance(iterable, AsyncIterable):
async_it = iterable.__aiter__()
if initial is initial_missing:
try:
value = cast(T, await async_it.__anext__())
except StopAsyncIteration:
raise TypeError(
"reduce() of empty sequence with no initial value"
) from None
else:
value = cast(T, initial)
async for element in async_it:
value = await function(value, element)
function_called = True
elif isinstance(iterable, Iterable):
it = iter(iterable)
if initial is initial_missing:
try:
value = cast(T, next(it))
except StopIteration:
raise TypeError(
"reduce() of empty sequence with no initial value"
) from None
else:
value = cast(T, initial)
for element in it:
value = await function(value, element)
function_called = True
else:
raise TypeError("reduce() argument 2 must be an iterable or async iterable")
# Make sure there is at least one checkpoint, even if an empty iterable and an
# initial value were given
if not function_called:
await checkpoint()
return value

View File

@@ -0,0 +1,196 @@
from __future__ import annotations
__all__ = (
"EventLoopToken",
"RunvarToken",
"RunVar",
"checkpoint",
"checkpoint_if_cancelled",
"cancel_shielded_checkpoint",
"current_token",
)
import enum
from dataclasses import dataclass
from types import TracebackType
from typing import Any, Generic, Literal, TypeVar, final, overload
from weakref import WeakKeyDictionary
from ._core._eventloop import get_async_backend
from .abc import AsyncBackend
T = TypeVar("T")
D = TypeVar("D")
async def checkpoint() -> None:
"""
Check for cancellation and allow the scheduler to switch to another task.
Equivalent to (but more efficient than)::
await checkpoint_if_cancelled()
await cancel_shielded_checkpoint()
.. versionadded:: 3.0
"""
await get_async_backend().checkpoint()
async def checkpoint_if_cancelled() -> None:
"""
Enter a checkpoint if the enclosing cancel scope has been cancelled.
This does not allow the scheduler to switch to a different task.
.. versionadded:: 3.0
"""
await get_async_backend().checkpoint_if_cancelled()
async def cancel_shielded_checkpoint() -> None:
"""
Allow the scheduler to switch to another task but without checking for cancellation.
Equivalent to (but potentially more efficient than)::
with CancelScope(shield=True):
await checkpoint()
.. versionadded:: 3.0
"""
await get_async_backend().cancel_shielded_checkpoint()
@final
@dataclass(frozen=True, repr=False)
class EventLoopToken:
"""
An opaque object that holds a reference to an event loop.
.. versionadded:: 4.11.0
"""
backend_class: type[AsyncBackend]
native_token: object
def current_token() -> EventLoopToken:
"""
Return a token object that can be used to call code in the current event loop from
another thread.
:raises NoEventLoopError: if no supported asynchronous event loop is running in the
current thread
.. versionadded:: 4.11.0
"""
backend_class = get_async_backend()
raw_token = backend_class.current_token()
return EventLoopToken(backend_class, raw_token)
_run_vars: WeakKeyDictionary[object, dict[RunVar[Any], Any]] = WeakKeyDictionary()
class _NoValueSet(enum.Enum):
NO_VALUE_SET = enum.auto()
class RunvarToken(Generic[T]):
__slots__ = "_var", "_value", "_redeemed"
def __init__(self, var: RunVar[T], value: T | Literal[_NoValueSet.NO_VALUE_SET]):
self._var = var
self._value: T | Literal[_NoValueSet.NO_VALUE_SET] = value
self._redeemed = False
def __enter__(self) -> RunvarToken[T]:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self._var.reset(self)
class RunVar(Generic[T]):
"""
Like a :class:`~contextvars.ContextVar`, except scoped to the running event loop.
Can be used as a context manager, Just like :class:`~contextvars.ContextVar`, that
will reset the variable to its previous value when the context block is exited.
"""
__slots__ = "_name", "_default"
NO_VALUE_SET: Literal[_NoValueSet.NO_VALUE_SET] = _NoValueSet.NO_VALUE_SET
def __init__(
self, name: str, default: T | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET
):
self._name = name
self._default = default
@property
def _current_vars(self) -> dict[RunVar[T], T]:
native_token = current_token().native_token
try:
return _run_vars[native_token]
except KeyError:
run_vars = _run_vars[native_token] = {}
return run_vars
@overload
def get(self, default: D) -> T | D: ...
@overload
def get(self) -> T: ...
def get(
self, default: D | Literal[_NoValueSet.NO_VALUE_SET] = NO_VALUE_SET
) -> T | D:
try:
return self._current_vars[self]
except KeyError:
if default is not RunVar.NO_VALUE_SET:
return default
elif self._default is not RunVar.NO_VALUE_SET:
return self._default
raise LookupError(
f'Run variable "{self._name}" has no value and no default set'
)
def set(self, value: T) -> RunvarToken[T]:
current_vars = self._current_vars
token = RunvarToken(self, current_vars.get(self, RunVar.NO_VALUE_SET))
current_vars[self] = value
return token
def reset(self, token: RunvarToken[T]) -> None:
if token._var is not self:
raise ValueError("This token does not belong to this RunVar")
if token._redeemed:
raise ValueError("This token has already been used")
if token._value is _NoValueSet.NO_VALUE_SET:
try:
del self._current_vars[self]
except KeyError:
pass
else:
self._current_vars[self] = token._value
token._redeemed = True
def __repr__(self) -> str:
return f"<RunVar name={self._name!r}>"

View File

@@ -0,0 +1,302 @@
from __future__ import annotations
import socket
import sys
from collections.abc import Callable, Generator, Iterator
from contextlib import ExitStack, contextmanager
from inspect import isasyncgenfunction, iscoroutinefunction, ismethod
from typing import Any, cast
import pytest
from _pytest.fixtures import SubRequest
from _pytest.outcomes import Exit
from . import get_available_backends
from ._core._eventloop import (
current_async_library,
get_async_backend,
reset_current_async_library,
set_current_async_library,
)
from ._core._exceptions import iterate_exceptions
from .abc import TestRunner
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
_current_runner: TestRunner | None = None
_runner_stack: ExitStack | None = None
_runner_leases = 0
def extract_backend_and_options(backend: object) -> tuple[str, dict[str, Any]]:
if isinstance(backend, str):
return backend, {}
elif isinstance(backend, tuple) and len(backend) == 2:
if isinstance(backend[0], str) and isinstance(backend[1], dict):
return cast(tuple[str, dict[str, Any]], backend)
raise TypeError("anyio_backend must be either a string or tuple of (string, dict)")
@contextmanager
def get_runner(
backend_name: str, backend_options: dict[str, Any]
) -> Iterator[TestRunner]:
global _current_runner, _runner_leases, _runner_stack
if _current_runner is None:
asynclib = get_async_backend(backend_name)
_runner_stack = ExitStack()
if current_async_library() is None:
# Since we're in control of the event loop, we can cache the name of the
# async library
token = set_current_async_library(backend_name)
_runner_stack.callback(reset_current_async_library, token)
backend_options = backend_options or {}
_current_runner = _runner_stack.enter_context(
asynclib.create_test_runner(backend_options)
)
_runner_leases += 1
try:
yield _current_runner
finally:
_runner_leases -= 1
if not _runner_leases:
assert _runner_stack is not None
_runner_stack.close()
_runner_stack = _current_runner = None
def pytest_addoption(parser: pytest.Parser) -> None:
parser.addini(
"anyio_mode",
default="strict",
help='AnyIO plugin mode (either "strict" or "auto")',
)
def pytest_configure(config: pytest.Config) -> None:
config.addinivalue_line(
"markers",
"anyio: mark the (coroutine function) test to be run asynchronously via anyio.",
)
if (
config.getini("anyio_mode") == "auto"
and config.pluginmanager.has_plugin("asyncio")
and config.getini("asyncio_mode") == "auto"
):
config.issue_config_time_warning(
pytest.PytestConfigWarning(
"AnyIO auto mode has been enabled together with pytest-asyncio auto "
"mode. This may cause unexpected behavior."
),
1,
)
@pytest.hookimpl(hookwrapper=True)
def pytest_fixture_setup(fixturedef: Any, request: Any) -> Generator[Any]:
def wrapper(anyio_backend: Any, request: SubRequest, **kwargs: Any) -> Any:
# Rebind any fixture methods to the request instance
if (
request.instance
and ismethod(func)
and type(func.__self__) is type(request.instance)
):
local_func = func.__func__.__get__(request.instance)
else:
local_func = func
backend_name, backend_options = extract_backend_and_options(anyio_backend)
if has_backend_arg:
kwargs["anyio_backend"] = anyio_backend
if has_request_arg:
kwargs["request"] = request
with get_runner(backend_name, backend_options) as runner:
if isasyncgenfunction(local_func):
yield from runner.run_asyncgen_fixture(local_func, kwargs)
else:
yield runner.run_fixture(local_func, kwargs)
# Only apply this to coroutine functions and async generator functions in requests
# that involve the anyio_backend fixture
func = fixturedef.func
if isasyncgenfunction(func) or iscoroutinefunction(func):
if "anyio_backend" in request.fixturenames:
fixturedef.func = wrapper
original_argname = fixturedef.argnames
if not (has_backend_arg := "anyio_backend" in fixturedef.argnames):
fixturedef.argnames += ("anyio_backend",)
if not (has_request_arg := "request" in fixturedef.argnames):
fixturedef.argnames += ("request",)
try:
return (yield)
finally:
fixturedef.func = func
fixturedef.argnames = original_argname
return (yield)
@pytest.hookimpl(tryfirst=True)
def pytest_pycollect_makeitem(
collector: pytest.Module | pytest.Class, name: str, obj: object
) -> None:
if collector.istestfunction(obj, name):
inner_func = obj.hypothesis.inner_test if hasattr(obj, "hypothesis") else obj
if iscoroutinefunction(inner_func):
anyio_auto_mode = collector.config.getini("anyio_mode") == "auto"
marker = collector.get_closest_marker("anyio")
own_markers = getattr(obj, "pytestmark", ())
if (
anyio_auto_mode
or marker
or any(marker.name == "anyio" for marker in own_markers)
):
pytest.mark.usefixtures("anyio_backend")(obj)
@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem: Any) -> bool | None:
def run_with_hypothesis(**kwargs: Any) -> None:
with get_runner(backend_name, backend_options) as runner:
runner.run_test(original_func, kwargs)
backend = pyfuncitem.funcargs.get("anyio_backend")
if backend:
backend_name, backend_options = extract_backend_and_options(backend)
if hasattr(pyfuncitem.obj, "hypothesis"):
# Wrap the inner test function unless it's already wrapped
original_func = pyfuncitem.obj.hypothesis.inner_test
if original_func.__qualname__ != run_with_hypothesis.__qualname__:
if iscoroutinefunction(original_func):
pyfuncitem.obj.hypothesis.inner_test = run_with_hypothesis
return None
if iscoroutinefunction(pyfuncitem.obj):
funcargs = pyfuncitem.funcargs
testargs = {arg: funcargs[arg] for arg in pyfuncitem._fixtureinfo.argnames}
with get_runner(backend_name, backend_options) as runner:
try:
runner.run_test(pyfuncitem.obj, testargs)
except ExceptionGroup as excgrp:
for exc in iterate_exceptions(excgrp):
if isinstance(exc, (Exit, KeyboardInterrupt, SystemExit)):
raise exc from excgrp
raise
return True
return None
@pytest.fixture(scope="module", params=get_available_backends())
def anyio_backend(request: Any) -> Any:
return request.param
@pytest.fixture
def anyio_backend_name(anyio_backend: Any) -> str:
if isinstance(anyio_backend, str):
return anyio_backend
else:
return anyio_backend[0]
@pytest.fixture
def anyio_backend_options(anyio_backend: Any) -> dict[str, Any]:
if isinstance(anyio_backend, str):
return {}
else:
return anyio_backend[1]
class FreePortFactory:
"""
Manages port generation based on specified socket kind, ensuring no duplicate
ports are generated.
This class provides functionality for generating available free ports on the
system. It is initialized with a specific socket kind and can generate ports
for given address families while avoiding reuse of previously generated ports.
Users should not instantiate this class directly, but use the
``free_tcp_port_factory`` and ``free_udp_port_factory`` fixtures instead. For simple
uses cases, ``free_tcp_port`` and ``free_udp_port`` can be used instead.
"""
def __init__(self, kind: socket.SocketKind) -> None:
self._kind = kind
self._generated = set[int]()
@property
def kind(self) -> socket.SocketKind:
"""
The type of socket connection (e.g., :data:`~socket.SOCK_STREAM` or
:data:`~socket.SOCK_DGRAM`) used to bind for checking port availability
"""
return self._kind
def __call__(self, family: socket.AddressFamily | None = None) -> int:
"""
Return an unbound port for the given address family.
:param family: if omitted, both IPv4 and IPv6 addresses will be tried
:return: a port number
"""
if family is not None:
families = [family]
else:
families = [socket.AF_INET]
if socket.has_ipv6:
families.append(socket.AF_INET6)
while True:
port = 0
with ExitStack() as stack:
for family in families:
sock = stack.enter_context(socket.socket(family, self._kind))
addr = "::1" if family == socket.AF_INET6 else "127.0.0.1"
try:
sock.bind((addr, port))
except OSError:
break
if not port:
port = sock.getsockname()[1]
else:
if port not in self._generated:
self._generated.add(port)
return port
@pytest.fixture(scope="session")
def free_tcp_port_factory() -> FreePortFactory:
return FreePortFactory(socket.SOCK_STREAM)
@pytest.fixture(scope="session")
def free_udp_port_factory() -> FreePortFactory:
return FreePortFactory(socket.SOCK_DGRAM)
@pytest.fixture
def free_tcp_port(free_tcp_port_factory: Callable[[], int]) -> int:
return free_tcp_port_factory()
@pytest.fixture
def free_udp_port(free_udp_port_factory: Callable[[], int]) -> int:
return free_udp_port_factory()

View File

@@ -0,0 +1,188 @@
from __future__ import annotations
__all__ = (
"BufferedByteReceiveStream",
"BufferedByteStream",
"BufferedConnectable",
)
import sys
from collections.abc import Callable, Iterable, Mapping
from dataclasses import dataclass, field
from typing import Any, SupportsIndex
from .. import ClosedResourceError, DelimiterNotFound, EndOfStream, IncompleteRead
from ..abc import (
AnyByteReceiveStream,
AnyByteStream,
AnyByteStreamConnectable,
ByteReceiveStream,
ByteStream,
ByteStreamConnectable,
)
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
@dataclass(eq=False)
class BufferedByteReceiveStream(ByteReceiveStream):
"""
Wraps any bytes-based receive stream and uses a buffer to provide sophisticated
receiving capabilities in the form of a byte stream.
"""
receive_stream: AnyByteReceiveStream
_buffer: bytearray = field(init=False, default_factory=bytearray)
_closed: bool = field(init=False, default=False)
async def aclose(self) -> None:
await self.receive_stream.aclose()
self._closed = True
@property
def buffer(self) -> bytes:
"""The bytes currently in the buffer."""
return bytes(self._buffer)
@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return self.receive_stream.extra_attributes
def feed_data(self, data: Iterable[SupportsIndex], /) -> None:
"""
Append data directly into the buffer.
Any data in the buffer will be consumed by receive operations before receiving
anything from the wrapped stream.
:param data: the data to append to the buffer (can be bytes or anything else
that supports ``__index__()``)
"""
self._buffer.extend(data)
async def receive(self, max_bytes: int = 65536) -> bytes:
if self._closed:
raise ClosedResourceError
if self._buffer:
chunk = bytes(self._buffer[:max_bytes])
del self._buffer[:max_bytes]
return chunk
elif isinstance(self.receive_stream, ByteReceiveStream):
return await self.receive_stream.receive(max_bytes)
else:
# With a bytes-oriented object stream, we need to handle any surplus bytes
# we get from the receive() call
chunk = await self.receive_stream.receive()
if len(chunk) > max_bytes:
# Save the surplus bytes in the buffer
self._buffer.extend(chunk[max_bytes:])
return chunk[:max_bytes]
else:
return chunk
async def receive_exactly(self, nbytes: int) -> bytes:
"""
Read exactly the given amount of bytes from the stream.
:param nbytes: the number of bytes to read
:return: the bytes read
:raises ~anyio.IncompleteRead: if the stream was closed before the requested
amount of bytes could be read from the stream
"""
while True:
remaining = nbytes - len(self._buffer)
if remaining <= 0:
retval = self._buffer[:nbytes]
del self._buffer[:nbytes]
return bytes(retval)
try:
if isinstance(self.receive_stream, ByteReceiveStream):
chunk = await self.receive_stream.receive(remaining)
else:
chunk = await self.receive_stream.receive()
except EndOfStream as exc:
raise IncompleteRead from exc
self._buffer.extend(chunk)
async def receive_until(self, delimiter: bytes, max_bytes: int) -> bytes:
"""
Read from the stream until the delimiter is found or max_bytes have been read.
:param delimiter: the marker to look for in the stream
:param max_bytes: maximum number of bytes that will be read before raising
:exc:`~anyio.DelimiterNotFound`
:return: the bytes read (not including the delimiter)
:raises ~anyio.IncompleteRead: if the stream was closed before the delimiter
was found
:raises ~anyio.DelimiterNotFound: if the delimiter is not found within the
bytes read up to the maximum allowed
"""
delimiter_size = len(delimiter)
offset = 0
while True:
# Check if the delimiter can be found in the current buffer
index = self._buffer.find(delimiter, offset)
if index >= 0:
found = self._buffer[:index]
del self._buffer[: index + len(delimiter) :]
return bytes(found)
# Check if the buffer is already at or over the limit
if len(self._buffer) >= max_bytes:
raise DelimiterNotFound(max_bytes)
# Read more data into the buffer from the socket
try:
data = await self.receive_stream.receive()
except EndOfStream as exc:
raise IncompleteRead from exc
# Move the offset forward and add the new data to the buffer
offset = max(len(self._buffer) - delimiter_size + 1, 0)
self._buffer.extend(data)
class BufferedByteStream(BufferedByteReceiveStream, ByteStream):
"""
A full-duplex variant of :class:`BufferedByteReceiveStream`. All writes are passed
through to the wrapped stream as-is.
"""
def __init__(self, stream: AnyByteStream):
"""
:param stream: the stream to be wrapped
"""
super().__init__(stream)
self._stream = stream
@override
async def send_eof(self) -> None:
await self._stream.send_eof()
@override
async def send(self, item: bytes) -> None:
await self._stream.send(item)
class BufferedConnectable(ByteStreamConnectable):
def __init__(self, connectable: AnyByteStreamConnectable):
"""
:param connectable: the connectable to wrap
"""
self.connectable = connectable
@override
async def connect(self) -> BufferedByteStream:
stream = await self.connectable.connect()
return BufferedByteStream(stream)

View File

@@ -0,0 +1,154 @@
from __future__ import annotations
__all__ = (
"FileReadStream",
"FileStreamAttribute",
"FileWriteStream",
)
from collections.abc import Callable, Mapping
from io import SEEK_SET, UnsupportedOperation
from os import PathLike
from pathlib import Path
from typing import Any, BinaryIO, cast
from .. import (
BrokenResourceError,
ClosedResourceError,
EndOfStream,
TypedAttributeSet,
to_thread,
typed_attribute,
)
from ..abc import ByteReceiveStream, ByteSendStream
class FileStreamAttribute(TypedAttributeSet):
#: the open file descriptor
file: BinaryIO = typed_attribute()
#: the path of the file on the file system, if available (file must be a real file)
path: Path = typed_attribute()
#: the file number, if available (file must be a real file or a TTY)
fileno: int = typed_attribute()
class _BaseFileStream:
def __init__(self, file: BinaryIO):
self._file = file
async def aclose(self) -> None:
await to_thread.run_sync(self._file.close)
@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
attributes: dict[Any, Callable[[], Any]] = {
FileStreamAttribute.file: lambda: self._file,
}
if hasattr(self._file, "name"):
attributes[FileStreamAttribute.path] = lambda: Path(self._file.name)
try:
self._file.fileno()
except UnsupportedOperation:
pass
else:
attributes[FileStreamAttribute.fileno] = lambda: self._file.fileno()
return attributes
class FileReadStream(_BaseFileStream, ByteReceiveStream):
"""
A byte stream that reads from a file in the file system.
:param file: a file that has been opened for reading in binary mode
.. versionadded:: 3.0
"""
@classmethod
async def from_path(cls, path: str | PathLike[str]) -> FileReadStream:
"""
Create a file read stream by opening the given file.
:param path: path of the file to read from
"""
file = await to_thread.run_sync(Path(path).open, "rb")
return cls(cast(BinaryIO, file))
async def receive(self, max_bytes: int = 65536) -> bytes:
try:
data = await to_thread.run_sync(self._file.read, max_bytes)
except ValueError:
raise ClosedResourceError from None
except OSError as exc:
raise BrokenResourceError from exc
if data:
return data
else:
raise EndOfStream
async def seek(self, position: int, whence: int = SEEK_SET) -> int:
"""
Seek the file to the given position.
.. seealso:: :meth:`io.IOBase.seek`
.. note:: Not all file descriptors are seekable.
:param position: position to seek the file to
:param whence: controls how ``position`` is interpreted
:return: the new absolute position
:raises OSError: if the file is not seekable
"""
return await to_thread.run_sync(self._file.seek, position, whence)
async def tell(self) -> int:
"""
Return the current stream position.
.. note:: Not all file descriptors are seekable.
:return: the current absolute position
:raises OSError: if the file is not seekable
"""
return await to_thread.run_sync(self._file.tell)
class FileWriteStream(_BaseFileStream, ByteSendStream):
"""
A byte stream that writes to a file in the file system.
:param file: a file that has been opened for writing in binary mode
.. versionadded:: 3.0
"""
@classmethod
async def from_path(
cls, path: str | PathLike[str], append: bool = False
) -> FileWriteStream:
"""
Create a file write stream by opening the given file for writing.
:param path: path of the file to write to
:param append: if ``True``, open the file for appending; if ``False``, any
existing file at the given path will be truncated
"""
mode = "ab" if append else "wb"
file = await to_thread.run_sync(Path(path).open, mode)
return cls(cast(BinaryIO, file))
async def send(self, item: bytes) -> None:
try:
await to_thread.run_sync(self._file.write, item)
except ValueError:
raise ClosedResourceError from None
except OSError as exc:
raise BrokenResourceError from exc

View File

@@ -0,0 +1,325 @@
from __future__ import annotations
__all__ = (
"MemoryObjectReceiveStream",
"MemoryObjectSendStream",
"MemoryObjectStreamStatistics",
)
import warnings
from collections import OrderedDict, deque
from dataclasses import dataclass, field
from types import TracebackType
from typing import Generic, NamedTuple, TypeVar
from .. import (
BrokenResourceError,
ClosedResourceError,
EndOfStream,
WouldBlock,
)
from .._core._testing import TaskInfo, get_current_task
from ..abc import Event, ObjectReceiveStream, ObjectSendStream
from ..lowlevel import checkpoint
T_Item = TypeVar("T_Item")
T_co = TypeVar("T_co", covariant=True)
T_contra = TypeVar("T_contra", contravariant=True)
class MemoryObjectStreamStatistics(NamedTuple):
current_buffer_used: int #: number of items stored in the buffer
#: maximum number of items that can be stored on this stream (or :data:`math.inf`)
max_buffer_size: float
open_send_streams: int #: number of unclosed clones of the send stream
open_receive_streams: int #: number of unclosed clones of the receive stream
#: number of tasks blocked on :meth:`MemoryObjectSendStream.send`
tasks_waiting_send: int
#: number of tasks blocked on :meth:`MemoryObjectReceiveStream.receive`
tasks_waiting_receive: int
@dataclass(eq=False)
class _MemoryObjectItemReceiver(Generic[T_Item]):
task_info: TaskInfo = field(init=False, default_factory=get_current_task)
item: T_Item = field(init=False)
def __repr__(self) -> str:
# When item is not defined, we get following error with default __repr__:
# AttributeError: 'MemoryObjectItemReceiver' object has no attribute 'item'
item = getattr(self, "item", None)
return f"{self.__class__.__name__}(task_info={self.task_info}, item={item!r})"
@dataclass(eq=False)
class _MemoryObjectStreamState(Generic[T_Item]):
max_buffer_size: float = field()
buffer: deque[T_Item] = field(init=False, default_factory=deque)
open_send_channels: int = field(init=False, default=0)
open_receive_channels: int = field(init=False, default=0)
waiting_receivers: OrderedDict[Event, _MemoryObjectItemReceiver[T_Item]] = field(
init=False, default_factory=OrderedDict
)
waiting_senders: OrderedDict[Event, T_Item] = field(
init=False, default_factory=OrderedDict
)
def statistics(self) -> MemoryObjectStreamStatistics:
return MemoryObjectStreamStatistics(
len(self.buffer),
self.max_buffer_size,
self.open_send_channels,
self.open_receive_channels,
len(self.waiting_senders),
len(self.waiting_receivers),
)
@dataclass(eq=False)
class MemoryObjectReceiveStream(Generic[T_co], ObjectReceiveStream[T_co]):
_state: _MemoryObjectStreamState[T_co]
_closed: bool = field(init=False, default=False)
def __post_init__(self) -> None:
self._state.open_receive_channels += 1
def receive_nowait(self) -> T_co:
"""
Receive the next item if it can be done without waiting.
:return: the received item
:raises ~anyio.ClosedResourceError: if this send stream has been closed
:raises ~anyio.EndOfStream: if the buffer is empty and this stream has been
closed from the sending end
:raises ~anyio.WouldBlock: if there are no items in the buffer and no tasks
waiting to send
"""
if self._closed:
raise ClosedResourceError
if self._state.waiting_senders:
# Get the item from the next sender
send_event, item = self._state.waiting_senders.popitem(last=False)
self._state.buffer.append(item)
send_event.set()
if self._state.buffer:
return self._state.buffer.popleft()
elif not self._state.open_send_channels:
raise EndOfStream
raise WouldBlock
async def receive(self) -> T_co:
await checkpoint()
try:
return self.receive_nowait()
except WouldBlock:
# Add ourselves in the queue
receive_event = Event()
receiver = _MemoryObjectItemReceiver[T_co]()
self._state.waiting_receivers[receive_event] = receiver
try:
await receive_event.wait()
finally:
self._state.waiting_receivers.pop(receive_event, None)
try:
return receiver.item
except AttributeError:
raise EndOfStream from None
def clone(self) -> MemoryObjectReceiveStream[T_co]:
"""
Create a clone of this receive stream.
Each clone can be closed separately. Only when all clones have been closed will
the receiving end of the memory stream be considered closed by the sending ends.
:return: the cloned stream
"""
if self._closed:
raise ClosedResourceError
return MemoryObjectReceiveStream(_state=self._state)
def close(self) -> None:
"""
Close the stream.
This works the exact same way as :meth:`aclose`, but is provided as a special
case for the benefit of synchronous callbacks.
"""
if not self._closed:
self._closed = True
self._state.open_receive_channels -= 1
if self._state.open_receive_channels == 0:
send_events = list(self._state.waiting_senders.keys())
for event in send_events:
event.set()
async def aclose(self) -> None:
self.close()
def statistics(self) -> MemoryObjectStreamStatistics:
"""
Return statistics about the current state of this stream.
.. versionadded:: 3.0
"""
return self._state.statistics()
def __enter__(self) -> MemoryObjectReceiveStream[T_co]:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()
def __del__(self) -> None:
if not self._closed:
warnings.warn(
f"Unclosed <{self.__class__.__name__} at {id(self):x}>",
ResourceWarning,
stacklevel=1,
source=self,
)
@dataclass(eq=False)
class MemoryObjectSendStream(Generic[T_contra], ObjectSendStream[T_contra]):
_state: _MemoryObjectStreamState[T_contra]
_closed: bool = field(init=False, default=False)
def __post_init__(self) -> None:
self._state.open_send_channels += 1
def send_nowait(self, item: T_contra) -> None:
"""
Send an item immediately if it can be done without waiting.
:param item: the item to send
:raises ~anyio.ClosedResourceError: if this send stream has been closed
:raises ~anyio.BrokenResourceError: if the stream has been closed from the
receiving end
:raises ~anyio.WouldBlock: if the buffer is full and there are no tasks waiting
to receive
"""
if self._closed:
raise ClosedResourceError
if not self._state.open_receive_channels:
raise BrokenResourceError
while self._state.waiting_receivers:
receive_event, receiver = self._state.waiting_receivers.popitem(last=False)
if not receiver.task_info.has_pending_cancellation():
receiver.item = item
receive_event.set()
return
if len(self._state.buffer) < self._state.max_buffer_size:
self._state.buffer.append(item)
else:
raise WouldBlock
async def send(self, item: T_contra) -> None:
"""
Send an item to the stream.
If the buffer is full, this method blocks until there is again room in the
buffer or the item can be sent directly to a receiver.
:param item: the item to send
:raises ~anyio.ClosedResourceError: if this send stream has been closed
:raises ~anyio.BrokenResourceError: if the stream has been closed from the
receiving end
"""
await checkpoint()
try:
self.send_nowait(item)
except WouldBlock:
# Wait until there's someone on the receiving end
send_event = Event()
self._state.waiting_senders[send_event] = item
try:
await send_event.wait()
except BaseException:
self._state.waiting_senders.pop(send_event, None)
raise
if send_event in self._state.waiting_senders:
del self._state.waiting_senders[send_event]
raise BrokenResourceError from None
def clone(self) -> MemoryObjectSendStream[T_contra]:
"""
Create a clone of this send stream.
Each clone can be closed separately. Only when all clones have been closed will
the sending end of the memory stream be considered closed by the receiving ends.
:return: the cloned stream
"""
if self._closed:
raise ClosedResourceError
return MemoryObjectSendStream(_state=self._state)
def close(self) -> None:
"""
Close the stream.
This works the exact same way as :meth:`aclose`, but is provided as a special
case for the benefit of synchronous callbacks.
"""
if not self._closed:
self._closed = True
self._state.open_send_channels -= 1
if self._state.open_send_channels == 0:
receive_events = list(self._state.waiting_receivers.keys())
self._state.waiting_receivers.clear()
for event in receive_events:
event.set()
async def aclose(self) -> None:
self.close()
def statistics(self) -> MemoryObjectStreamStatistics:
"""
Return statistics about the current state of this stream.
.. versionadded:: 3.0
"""
return self._state.statistics()
def __enter__(self) -> MemoryObjectSendStream[T_contra]:
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()
def __del__(self) -> None:
if not self._closed:
warnings.warn(
f"Unclosed <{self.__class__.__name__} at {id(self):x}>",
ResourceWarning,
stacklevel=1,
source=self,
)

View File

@@ -0,0 +1,147 @@
from __future__ import annotations
__all__ = (
"MultiListener",
"StapledByteStream",
"StapledObjectStream",
)
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass
from typing import Any, Generic, TypeVar
from ..abc import (
ByteReceiveStream,
ByteSendStream,
ByteStream,
Listener,
ObjectReceiveStream,
ObjectSendStream,
ObjectStream,
TaskGroup,
)
T_Item = TypeVar("T_Item")
T_Stream = TypeVar("T_Stream")
@dataclass(eq=False)
class StapledByteStream(ByteStream):
"""
Combines two byte streams into a single, bidirectional byte stream.
Extra attributes will be provided from both streams, with the receive stream
providing the values in case of a conflict.
:param ByteSendStream send_stream: the sending byte stream
:param ByteReceiveStream receive_stream: the receiving byte stream
"""
send_stream: ByteSendStream
receive_stream: ByteReceiveStream
async def receive(self, max_bytes: int = 65536) -> bytes:
return await self.receive_stream.receive(max_bytes)
async def send(self, item: bytes) -> None:
await self.send_stream.send(item)
async def send_eof(self) -> None:
await self.send_stream.aclose()
async def aclose(self) -> None:
await self.send_stream.aclose()
await self.receive_stream.aclose()
@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return {
**self.send_stream.extra_attributes,
**self.receive_stream.extra_attributes,
}
@dataclass(eq=False)
class StapledObjectStream(Generic[T_Item], ObjectStream[T_Item]):
"""
Combines two object streams into a single, bidirectional object stream.
Extra attributes will be provided from both streams, with the receive stream
providing the values in case of a conflict.
:param ObjectSendStream send_stream: the sending object stream
:param ObjectReceiveStream receive_stream: the receiving object stream
"""
send_stream: ObjectSendStream[T_Item]
receive_stream: ObjectReceiveStream[T_Item]
async def receive(self) -> T_Item:
return await self.receive_stream.receive()
async def send(self, item: T_Item) -> None:
await self.send_stream.send(item)
async def send_eof(self) -> None:
await self.send_stream.aclose()
async def aclose(self) -> None:
await self.send_stream.aclose()
await self.receive_stream.aclose()
@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return {
**self.send_stream.extra_attributes,
**self.receive_stream.extra_attributes,
}
@dataclass(eq=False)
class MultiListener(Generic[T_Stream], Listener[T_Stream]):
"""
Combines multiple listeners into one, serving connections from all of them at once.
Any MultiListeners in the given collection of listeners will have their listeners
moved into this one.
Extra attributes are provided from each listener, with each successive listener
overriding any conflicting attributes from the previous one.
:param listeners: listeners to serve
:type listeners: Sequence[Listener[T_Stream]]
"""
listeners: Sequence[Listener[T_Stream]]
def __post_init__(self) -> None:
listeners: list[Listener[T_Stream]] = []
for listener in self.listeners:
if isinstance(listener, MultiListener):
listeners.extend(listener.listeners)
del listener.listeners[:] # type: ignore[attr-defined]
else:
listeners.append(listener)
self.listeners = listeners
async def serve(
self, handler: Callable[[T_Stream], Any], task_group: TaskGroup | None = None
) -> None:
from .. import create_task_group
async with create_task_group() as tg:
for listener in self.listeners:
tg.start_soon(listener.serve, handler, task_group)
async def aclose(self) -> None:
for listener in self.listeners:
await listener.aclose()
@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
attributes: dict = {}
for listener in self.listeners:
attributes.update(listener.extra_attributes)
return attributes

View File

@@ -0,0 +1,176 @@
from __future__ import annotations
__all__ = (
"TextConnectable",
"TextReceiveStream",
"TextSendStream",
"TextStream",
)
import codecs
import sys
from collections.abc import Callable, Mapping
from dataclasses import InitVar, dataclass, field
from typing import Any
from ..abc import (
AnyByteReceiveStream,
AnyByteSendStream,
AnyByteStream,
AnyByteStreamConnectable,
ObjectReceiveStream,
ObjectSendStream,
ObjectStream,
ObjectStreamConnectable,
)
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
@dataclass(eq=False)
class TextReceiveStream(ObjectReceiveStream[str]):
"""
Stream wrapper that decodes bytes to strings using the given encoding.
Decoding is done using :class:`~codecs.IncrementalDecoder` which returns any
completely received unicode characters as soon as they come in.
:param transport_stream: any bytes-based receive stream
:param encoding: character encoding to use for decoding bytes to strings (defaults
to ``utf-8``)
:param errors: handling scheme for decoding errors (defaults to ``strict``; see the
`codecs module documentation`_ for a comprehensive list of options)
.. _codecs module documentation:
https://docs.python.org/3/library/codecs.html#codec-objects
"""
transport_stream: AnyByteReceiveStream
encoding: InitVar[str] = "utf-8"
errors: InitVar[str] = "strict"
_decoder: codecs.IncrementalDecoder = field(init=False)
def __post_init__(self, encoding: str, errors: str) -> None:
decoder_class = codecs.getincrementaldecoder(encoding)
self._decoder = decoder_class(errors=errors)
async def receive(self) -> str:
while True:
chunk = await self.transport_stream.receive()
decoded = self._decoder.decode(chunk)
if decoded:
return decoded
async def aclose(self) -> None:
await self.transport_stream.aclose()
self._decoder.reset()
@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return self.transport_stream.extra_attributes
@dataclass(eq=False)
class TextSendStream(ObjectSendStream[str]):
"""
Sends strings to the wrapped stream as bytes using the given encoding.
:param AnyByteSendStream transport_stream: any bytes-based send stream
:param str encoding: character encoding to use for encoding strings to bytes
(defaults to ``utf-8``)
:param str errors: handling scheme for encoding errors (defaults to ``strict``; see
the `codecs module documentation`_ for a comprehensive list of options)
.. _codecs module documentation:
https://docs.python.org/3/library/codecs.html#codec-objects
"""
transport_stream: AnyByteSendStream
encoding: InitVar[str] = "utf-8"
errors: str = "strict"
_encoder: Callable[..., tuple[bytes, int]] = field(init=False)
def __post_init__(self, encoding: str) -> None:
self._encoder = codecs.getencoder(encoding)
async def send(self, item: str) -> None:
encoded = self._encoder(item, self.errors)[0]
await self.transport_stream.send(encoded)
async def aclose(self) -> None:
await self.transport_stream.aclose()
@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return self.transport_stream.extra_attributes
@dataclass(eq=False)
class TextStream(ObjectStream[str]):
"""
A bidirectional stream that decodes bytes to strings on receive and encodes strings
to bytes on send.
Extra attributes will be provided from both streams, with the receive stream
providing the values in case of a conflict.
:param AnyByteStream transport_stream: any bytes-based stream
:param str encoding: character encoding to use for encoding/decoding strings to/from
bytes (defaults to ``utf-8``)
:param str errors: handling scheme for encoding errors (defaults to ``strict``; see
the `codecs module documentation`_ for a comprehensive list of options)
.. _codecs module documentation:
https://docs.python.org/3/library/codecs.html#codec-objects
"""
transport_stream: AnyByteStream
encoding: InitVar[str] = "utf-8"
errors: InitVar[str] = "strict"
_receive_stream: TextReceiveStream = field(init=False)
_send_stream: TextSendStream = field(init=False)
def __post_init__(self, encoding: str, errors: str) -> None:
self._receive_stream = TextReceiveStream(
self.transport_stream, encoding=encoding, errors=errors
)
self._send_stream = TextSendStream(
self.transport_stream, encoding=encoding, errors=errors
)
async def receive(self) -> str:
return await self._receive_stream.receive()
async def send(self, item: str) -> None:
await self._send_stream.send(item)
async def send_eof(self) -> None:
await self.transport_stream.send_eof()
async def aclose(self) -> None:
await self._send_stream.aclose()
await self._receive_stream.aclose()
@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return {
**self._send_stream.extra_attributes,
**self._receive_stream.extra_attributes,
}
class TextConnectable(ObjectStreamConnectable[str]):
def __init__(self, connectable: AnyByteStreamConnectable):
"""
:param connectable: the bytestream endpoint to wrap
"""
self.connectable = connectable
@override
async def connect(self) -> TextStream:
stream = await self.connectable.connect()
return TextStream(stream)

View File

@@ -0,0 +1,424 @@
from __future__ import annotations
__all__ = (
"TLSAttribute",
"TLSConnectable",
"TLSListener",
"TLSStream",
)
import logging
import re
import ssl
import sys
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from functools import wraps
from ssl import SSLContext
from typing import Any, TypeVar
from .. import (
BrokenResourceError,
EndOfStream,
aclose_forcefully,
get_cancelled_exc_class,
to_thread,
)
from .._core._typedattr import TypedAttributeSet, typed_attribute
from ..abc import (
AnyByteStream,
AnyByteStreamConnectable,
ByteStream,
ByteStreamConnectable,
Listener,
TaskGroup,
)
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
if sys.version_info >= (3, 11):
from typing import TypeVarTuple, Unpack
else:
from typing_extensions import TypeVarTuple, Unpack
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
T_Retval = TypeVar("T_Retval")
PosArgsT = TypeVarTuple("PosArgsT")
_PCTRTT: TypeAlias = tuple[tuple[str, str], ...]
_PCTRTTT: TypeAlias = tuple[_PCTRTT, ...]
class TLSAttribute(TypedAttributeSet):
"""Contains Transport Layer Security related attributes."""
#: the selected ALPN protocol
alpn_protocol: str | None = typed_attribute()
#: the channel binding for type ``tls-unique``
channel_binding_tls_unique: bytes = typed_attribute()
#: the selected cipher
cipher: tuple[str, str, int] = typed_attribute()
#: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert`
# for more information)
peer_certificate: None | (dict[str, str | _PCTRTTT | _PCTRTT]) = typed_attribute()
#: the peer certificate in binary form
peer_certificate_binary: bytes | None = typed_attribute()
#: ``True`` if this is the server side of the connection
server_side: bool = typed_attribute()
#: ciphers shared by the client during the TLS handshake (``None`` if this is the
#: client side)
shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
#: the :class:`~ssl.SSLObject` used for encryption
ssl_object: ssl.SSLObject = typed_attribute()
#: ``True`` if this stream does (and expects) a closing TLS handshake when the
#: stream is being closed
standard_compatible: bool = typed_attribute()
#: the TLS protocol version (e.g. ``TLSv1.2``)
tls_version: str = typed_attribute()
@dataclass(eq=False)
class TLSStream(ByteStream):
"""
A stream wrapper that encrypts all sent data and decrypts received data.
This class has no public initializer; use :meth:`wrap` instead.
All extra attributes from :class:`~TLSAttribute` are supported.
:var AnyByteStream transport_stream: the wrapped stream
"""
transport_stream: AnyByteStream
standard_compatible: bool
_ssl_object: ssl.SSLObject
_read_bio: ssl.MemoryBIO
_write_bio: ssl.MemoryBIO
@classmethod
async def wrap(
cls,
transport_stream: AnyByteStream,
*,
server_side: bool | None = None,
hostname: str | None = None,
ssl_context: ssl.SSLContext | None = None,
standard_compatible: bool = True,
) -> TLSStream:
"""
Wrap an existing stream with Transport Layer Security.
This performs a TLS handshake with the peer.
:param transport_stream: a bytes-transporting stream to wrap
:param server_side: ``True`` if this is the server side of the connection,
``False`` if this is the client side (if omitted, will be set to ``False``
if ``hostname`` has been provided, ``False`` otherwise). Used only to create
a default context when an explicit context has not been provided.
:param hostname: host name of the peer (if host name checking is desired)
:param ssl_context: the SSLContext object to use (if not provided, a secure
default will be created)
:param standard_compatible: if ``False``, skip the closing handshake when
closing the connection, and don't raise an exception if the peer does the
same
:raises ~ssl.SSLError: if the TLS handshake fails
"""
if server_side is None:
server_side = not hostname
if not ssl_context:
purpose = (
ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
)
ssl_context = ssl.create_default_context(purpose)
# Re-enable detection of unexpected EOFs if it was disabled by Python
if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF
bio_in = ssl.MemoryBIO()
bio_out = ssl.MemoryBIO()
# External SSLContext implementations may do blocking I/O in wrap_bio(),
# but the standard library implementation won't
if type(ssl_context) is ssl.SSLContext:
ssl_object = ssl_context.wrap_bio(
bio_in, bio_out, server_side=server_side, server_hostname=hostname
)
else:
ssl_object = await to_thread.run_sync(
ssl_context.wrap_bio,
bio_in,
bio_out,
server_side,
hostname,
None,
)
wrapper = cls(
transport_stream=transport_stream,
standard_compatible=standard_compatible,
_ssl_object=ssl_object,
_read_bio=bio_in,
_write_bio=bio_out,
)
await wrapper._call_sslobject_method(ssl_object.do_handshake)
return wrapper
async def _call_sslobject_method(
self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
) -> T_Retval:
while True:
try:
result = func(*args)
except ssl.SSLWantReadError:
try:
# Flush any pending writes first
if self._write_bio.pending:
await self.transport_stream.send(self._write_bio.read())
data = await self.transport_stream.receive()
except EndOfStream:
self._read_bio.write_eof()
except OSError as exc:
self._read_bio.write_eof()
self._write_bio.write_eof()
raise BrokenResourceError from exc
else:
self._read_bio.write(data)
except ssl.SSLWantWriteError:
await self.transport_stream.send(self._write_bio.read())
except ssl.SSLSyscallError as exc:
self._read_bio.write_eof()
self._write_bio.write_eof()
raise BrokenResourceError from exc
except ssl.SSLError as exc:
self._read_bio.write_eof()
self._write_bio.write_eof()
if isinstance(exc, ssl.SSLEOFError) or (
exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
):
if self.standard_compatible:
raise BrokenResourceError from exc
else:
raise EndOfStream from None
raise
else:
# Flush any pending writes first
if self._write_bio.pending:
await self.transport_stream.send(self._write_bio.read())
return result
async def unwrap(self) -> tuple[AnyByteStream, bytes]:
"""
Does the TLS closing handshake.
:return: a tuple of (wrapped byte stream, bytes left in the read buffer)
"""
await self._call_sslobject_method(self._ssl_object.unwrap)
self._read_bio.write_eof()
self._write_bio.write_eof()
return self.transport_stream, self._read_bio.read()
async def aclose(self) -> None:
if self.standard_compatible:
try:
await self.unwrap()
except BaseException:
await aclose_forcefully(self.transport_stream)
raise
await self.transport_stream.aclose()
async def receive(self, max_bytes: int = 65536) -> bytes:
data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
if not data:
raise EndOfStream
return data
async def send(self, item: bytes) -> None:
await self._call_sslobject_method(self._ssl_object.write, item)
async def send_eof(self) -> None:
tls_version = self.extra(TLSAttribute.tls_version)
match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
if match:
major, minor = int(match.group(1)), int(match.group(2) or 0)
if (major, minor) < (1, 3):
raise NotImplementedError(
f"send_eof() requires at least TLSv1.3; current "
f"session uses {tls_version}"
)
raise NotImplementedError(
"send_eof() has not yet been implemented for TLS streams"
)
@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return {
**self.transport_stream.extra_attributes,
TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
TLSAttribute.channel_binding_tls_unique: (
self._ssl_object.get_channel_binding
),
TLSAttribute.cipher: self._ssl_object.cipher,
TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
True
),
TLSAttribute.server_side: lambda: self._ssl_object.server_side,
TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers()
if self._ssl_object.server_side
else None,
TLSAttribute.standard_compatible: lambda: self.standard_compatible,
TLSAttribute.ssl_object: lambda: self._ssl_object,
TLSAttribute.tls_version: self._ssl_object.version,
}
@dataclass(eq=False)
class TLSListener(Listener[TLSStream]):
"""
A convenience listener that wraps another listener and auto-negotiates a TLS session
on every accepted connection.
If the TLS handshake times out or raises an exception,
:meth:`handle_handshake_error` is called to do whatever post-mortem processing is
deemed necessary.
Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.
:param Listener listener: the listener to wrap
:param ssl_context: the SSL context object
:param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
:param handshake_timeout: time limit for the TLS handshake
(passed to :func:`~anyio.fail_after`)
"""
listener: Listener[Any]
ssl_context: ssl.SSLContext
standard_compatible: bool = True
handshake_timeout: float = 30
@staticmethod
async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
"""
Handle an exception raised during the TLS handshake.
This method does 3 things:
#. Forcefully closes the original stream
#. Logs the exception (unless it was a cancellation exception) using the
``anyio.streams.tls`` logger
#. Reraises the exception if it was a base exception or a cancellation exception
:param exc: the exception
:param stream: the original stream
"""
await aclose_forcefully(stream)
# Log all except cancellation exceptions
if not isinstance(exc, get_cancelled_exc_class()):
# CPython (as of 3.11.5) returns incorrect `sys.exc_info()` here when using
# any asyncio implementation, so we explicitly pass the exception to log
# (https://github.com/python/cpython/issues/108668). Trio does not have this
# issue because it works around the CPython bug.
logging.getLogger(__name__).exception(
"Error during TLS handshake", exc_info=exc
)
# Only reraise base exceptions and cancellation exceptions
if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
raise
async def serve(
self,
handler: Callable[[TLSStream], Any],
task_group: TaskGroup | None = None,
) -> None:
@wraps(handler)
async def handler_wrapper(stream: AnyByteStream) -> None:
from .. import fail_after
try:
with fail_after(self.handshake_timeout):
wrapped_stream = await TLSStream.wrap(
stream,
ssl_context=self.ssl_context,
standard_compatible=self.standard_compatible,
)
except BaseException as exc:
await self.handle_handshake_error(exc, stream)
else:
await handler(wrapped_stream)
await self.listener.serve(handler_wrapper, task_group)
async def aclose(self) -> None:
await self.listener.aclose()
@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return {
TLSAttribute.standard_compatible: lambda: self.standard_compatible,
}
class TLSConnectable(ByteStreamConnectable):
"""
Wraps another connectable and does TLS negotiation after a successful connection.
:param connectable: the connectable to wrap
:param hostname: host name of the server (if host name checking is desired)
:param ssl_context: the SSLContext object to use (if not provided, a secure default
will be created)
:param standard_compatible: if ``False``, skip the closing handshake when closing
the connection, and don't raise an exception if the server does the same
"""
def __init__(
self,
connectable: AnyByteStreamConnectable,
*,
hostname: str | None = None,
ssl_context: ssl.SSLContext | None = None,
standard_compatible: bool = True,
) -> None:
self.connectable = connectable
self.ssl_context: SSLContext = ssl_context or ssl.create_default_context(
ssl.Purpose.SERVER_AUTH
)
if not isinstance(self.ssl_context, ssl.SSLContext):
raise TypeError(
"ssl_context must be an instance of ssl.SSLContext, not "
f"{type(self.ssl_context).__name__}"
)
self.hostname = hostname
self.standard_compatible = standard_compatible
@override
async def connect(self) -> TLSStream:
stream = await self.connectable.connect()
try:
return await TLSStream.wrap(
stream,
hostname=self.hostname,
ssl_context=self.ssl_context,
standard_compatible=self.standard_compatible,
)
except BaseException:
await aclose_forcefully(stream)
raise

View File

@@ -0,0 +1,246 @@
from __future__ import annotations
__all__ = (
"run_sync",
"current_default_interpreter_limiter",
)
import atexit
import os
import sys
from collections import deque
from collections.abc import Callable
from typing import Any, Final, TypeVar
from . import current_time, to_thread
from ._core._exceptions import BrokenWorkerInterpreter
from ._core._synchronization import CapacityLimiter
from .lowlevel import RunVar
if sys.version_info >= (3, 11):
from typing import TypeVarTuple, Unpack
else:
from typing_extensions import TypeVarTuple, Unpack
if sys.version_info >= (3, 14):
from concurrent.interpreters import ExecutionFailed, create
def _interp_call(
func: Callable[..., Any], args: tuple[Any, ...]
) -> tuple[Any, bool]:
try:
retval = func(*args)
except BaseException as exc:
return exc, True
else:
return retval, False
class _Worker:
last_used: float = 0
def __init__(self) -> None:
self._interpreter = create()
def destroy(self) -> None:
self._interpreter.close()
def call(
self,
func: Callable[..., T_Retval],
args: tuple[Any, ...],
) -> T_Retval:
try:
res, is_exception = self._interpreter.call(_interp_call, func, args)
except ExecutionFailed as exc:
raise BrokenWorkerInterpreter(exc.excinfo) from exc
if is_exception:
raise res
return res
elif sys.version_info >= (3, 13):
import _interpqueues
import _interpreters
UNBOUND: Final = 2 # I have no clue how this works, but it was used in the stdlib
FMT_UNPICKLED: Final = 0
FMT_PICKLED: Final = 1
QUEUE_PICKLE_ARGS: Final = (FMT_PICKLED, UNBOUND)
QUEUE_UNPICKLE_ARGS: Final = (FMT_UNPICKLED, UNBOUND)
_run_func = compile(
"""
import _interpqueues
from _interpreters import NotShareableError
from pickle import loads, dumps, HIGHEST_PROTOCOL
QUEUE_PICKLE_ARGS = (1, 2)
QUEUE_UNPICKLE_ARGS = (0, 2)
item = _interpqueues.get(queue_id)[0]
try:
func, args = loads(item)
retval = func(*args)
except BaseException as exc:
is_exception = True
retval = exc
else:
is_exception = False
try:
_interpqueues.put(queue_id, (retval, is_exception), *QUEUE_UNPICKLE_ARGS)
except NotShareableError:
retval = dumps(retval, HIGHEST_PROTOCOL)
_interpqueues.put(queue_id, (retval, is_exception), *QUEUE_PICKLE_ARGS)
""",
"<string>",
"exec",
)
class _Worker:
last_used: float = 0
def __init__(self) -> None:
self._interpreter_id = _interpreters.create()
self._queue_id = _interpqueues.create(1, *QUEUE_UNPICKLE_ARGS)
_interpreters.set___main___attrs(
self._interpreter_id, {"queue_id": self._queue_id}
)
def destroy(self) -> None:
_interpqueues.destroy(self._queue_id)
_interpreters.destroy(self._interpreter_id)
def call(
self,
func: Callable[..., T_Retval],
args: tuple[Any, ...],
) -> T_Retval:
import pickle
item = pickle.dumps((func, args), pickle.HIGHEST_PROTOCOL)
_interpqueues.put(self._queue_id, item, *QUEUE_PICKLE_ARGS)
exc_info = _interpreters.exec(self._interpreter_id, _run_func)
if exc_info:
raise BrokenWorkerInterpreter(exc_info)
res = _interpqueues.get(self._queue_id)
(res, is_exception), fmt = res[:2]
if fmt == FMT_PICKLED:
res = pickle.loads(res)
if is_exception:
raise res
return res
else:
class _Worker:
last_used: float = 0
def __init__(self) -> None:
raise RuntimeError("subinterpreters require at least Python 3.13")
def call(
self,
func: Callable[..., T_Retval],
args: tuple[Any, ...],
) -> T_Retval:
raise NotImplementedError
def destroy(self) -> None:
pass
DEFAULT_CPU_COUNT: Final = 8 # this is just an arbitrarily selected value
MAX_WORKER_IDLE_TIME = (
30 # seconds a subinterpreter can be idle before becoming eligible for pruning
)
T_Retval = TypeVar("T_Retval")
PosArgsT = TypeVarTuple("PosArgsT")
_idle_workers = RunVar[deque[_Worker]]("_available_workers")
_default_interpreter_limiter = RunVar[CapacityLimiter]("_default_interpreter_limiter")
def _stop_workers(workers: deque[_Worker]) -> None:
for worker in workers:
worker.destroy()
workers.clear()
async def run_sync(
func: Callable[[Unpack[PosArgsT]], T_Retval],
*args: Unpack[PosArgsT],
limiter: CapacityLimiter | None = None,
) -> T_Retval:
"""
Call the given function with the given arguments in a subinterpreter.
.. warning:: On Python 3.13, the :mod:`concurrent.interpreters` module was not yet
available, so the code path for that Python version relies on an undocumented,
private API. As such, it is recommended to not rely on this function for anything
mission-critical on Python 3.13.
:param func: a callable
:param args: the positional arguments for the callable
:param limiter: capacity limiter to use to limit the total number of subinterpreters
running (if omitted, the default limiter is used)
:return: the result of the call
:raises BrokenWorkerInterpreter: if there's an internal error in a subinterpreter
"""
if limiter is None:
limiter = current_default_interpreter_limiter()
try:
idle_workers = _idle_workers.get()
except LookupError:
idle_workers = deque()
_idle_workers.set(idle_workers)
atexit.register(_stop_workers, idle_workers)
async with limiter:
try:
worker = idle_workers.pop()
except IndexError:
worker = _Worker()
try:
return await to_thread.run_sync(
worker.call,
func,
args,
limiter=limiter,
)
finally:
# Prune workers that have been idle for too long
now = current_time()
while idle_workers:
if now - idle_workers[0].last_used <= MAX_WORKER_IDLE_TIME:
break
await to_thread.run_sync(idle_workers.popleft().destroy, limiter=limiter)
worker.last_used = current_time()
idle_workers.append(worker)
def current_default_interpreter_limiter() -> CapacityLimiter:
"""
Return the capacity limiter used by default to limit the number of concurrently
running subinterpreters.
Defaults to the number of CPU cores.
:return: a capacity limiter object
"""
try:
return _default_interpreter_limiter.get()
except LookupError:
limiter = CapacityLimiter(os.cpu_count() or DEFAULT_CPU_COUNT)
_default_interpreter_limiter.set(limiter)
return limiter

View File

@@ -0,0 +1,266 @@
from __future__ import annotations
__all__ = (
"current_default_process_limiter",
"process_worker",
"run_sync",
)
import os
import pickle
import subprocess
import sys
from collections import deque
from collections.abc import Callable
from importlib.util import module_from_spec, spec_from_file_location
from typing import TypeVar, cast
from ._core._eventloop import current_time, get_async_backend, get_cancelled_exc_class
from ._core._exceptions import BrokenWorkerProcess
from ._core._subprocesses import open_process
from ._core._synchronization import CapacityLimiter
from ._core._tasks import CancelScope, fail_after
from .abc import ByteReceiveStream, ByteSendStream, Process
from .lowlevel import RunVar, checkpoint_if_cancelled
from .streams.buffered import BufferedByteReceiveStream
if sys.version_info >= (3, 11):
from typing import TypeVarTuple, Unpack
else:
from typing_extensions import TypeVarTuple, Unpack
WORKER_MAX_IDLE_TIME = 300 # 5 minutes
T_Retval = TypeVar("T_Retval")
PosArgsT = TypeVarTuple("PosArgsT")
_process_pool_workers: RunVar[set[Process]] = RunVar("_process_pool_workers")
_process_pool_idle_workers: RunVar[deque[tuple[Process, float]]] = RunVar(
"_process_pool_idle_workers"
)
_default_process_limiter: RunVar[CapacityLimiter] = RunVar("_default_process_limiter")
async def run_sync( # type: ignore[return]
func: Callable[[Unpack[PosArgsT]], T_Retval],
*args: Unpack[PosArgsT],
cancellable: bool = False,
limiter: CapacityLimiter | None = None,
) -> T_Retval:
"""
Call the given function with the given arguments in a worker process.
If the ``cancellable`` option is enabled and the task waiting for its completion is
cancelled, the worker process running it will be abruptly terminated using SIGKILL
(or ``terminateProcess()`` on Windows).
:param func: a callable
:param args: positional arguments for the callable
:param cancellable: ``True`` to allow cancellation of the operation while it's
running
:param limiter: capacity limiter to use to limit the total amount of processes
running (if omitted, the default limiter is used)
:raises NoEventLoopError: if no supported asynchronous event loop is running in the
current thread
:return: an awaitable that yields the return value of the function.
"""
async def send_raw_command(pickled_cmd: bytes) -> object:
try:
await stdin.send(pickled_cmd)
response = await buffered.receive_until(b"\n", 50)
status, length = response.split(b" ")
if status not in (b"RETURN", b"EXCEPTION"):
raise RuntimeError(
f"Worker process returned unexpected response: {response!r}"
)
pickled_response = await buffered.receive_exactly(int(length))
except BaseException as exc:
workers.discard(process)
try:
process.kill()
with CancelScope(shield=True):
await process.aclose()
except ProcessLookupError:
pass
if isinstance(exc, get_cancelled_exc_class()):
raise
else:
raise BrokenWorkerProcess from exc
retval = pickle.loads(pickled_response)
if status == b"EXCEPTION":
assert isinstance(retval, BaseException)
raise retval
else:
return retval
# First pickle the request before trying to reserve a worker process
await checkpoint_if_cancelled()
request = pickle.dumps(("run", func, args), protocol=pickle.HIGHEST_PROTOCOL)
# If this is the first run in this event loop thread, set up the necessary variables
try:
workers = _process_pool_workers.get()
idle_workers = _process_pool_idle_workers.get()
except LookupError:
workers = set()
idle_workers = deque()
_process_pool_workers.set(workers)
_process_pool_idle_workers.set(idle_workers)
get_async_backend().setup_process_pool_exit_at_shutdown(workers)
async with limiter or current_default_process_limiter():
# Pop processes from the pool (starting from the most recently used) until we
# find one that hasn't exited yet
process: Process
while idle_workers:
process, idle_since = idle_workers.pop()
if process.returncode is None:
stdin = cast(ByteSendStream, process.stdin)
buffered = BufferedByteReceiveStream(
cast(ByteReceiveStream, process.stdout)
)
# Prune any other workers that have been idle for WORKER_MAX_IDLE_TIME
# seconds or longer
now = current_time()
killed_processes: list[Process] = []
while idle_workers:
if now - idle_workers[0][1] < WORKER_MAX_IDLE_TIME:
break
process_to_kill, idle_since = idle_workers.popleft()
process_to_kill.kill()
workers.remove(process_to_kill)
killed_processes.append(process_to_kill)
with CancelScope(shield=True):
for killed_process in killed_processes:
await killed_process.aclose()
break
workers.remove(process)
else:
command = [sys.executable, "-u", "-m", __name__]
process = await open_process(
command, stdin=subprocess.PIPE, stdout=subprocess.PIPE
)
try:
stdin = cast(ByteSendStream, process.stdin)
buffered = BufferedByteReceiveStream(
cast(ByteReceiveStream, process.stdout)
)
with fail_after(20):
message = await buffered.receive(6)
if message != b"READY\n":
raise BrokenWorkerProcess(
f"Worker process returned unexpected response: {message!r}"
)
main_module_path = getattr(sys.modules["__main__"], "__file__", None)
pickled = pickle.dumps(
("init", sys.path, main_module_path),
protocol=pickle.HIGHEST_PROTOCOL,
)
await send_raw_command(pickled)
except (BrokenWorkerProcess, get_cancelled_exc_class()):
raise
except BaseException as exc:
process.kill()
raise BrokenWorkerProcess(
"Error during worker process initialization"
) from exc
workers.add(process)
with CancelScope(shield=not cancellable):
try:
return cast(T_Retval, await send_raw_command(request))
finally:
if process in workers:
idle_workers.append((process, current_time()))
def current_default_process_limiter() -> CapacityLimiter:
"""
Return the capacity limiter that is used by default to limit the number of worker
processes.
:return: a capacity limiter object
"""
try:
return _default_process_limiter.get()
except LookupError:
limiter = CapacityLimiter(os.cpu_count() or 2)
_default_process_limiter.set(limiter)
return limiter
def process_worker() -> None:
# Redirect standard streams to os.devnull so that user code won't interfere with the
# parent-worker communication
stdin = sys.stdin
stdout = sys.stdout
sys.stdin = open(os.devnull)
sys.stdout = open(os.devnull, "w")
stdout.buffer.write(b"READY\n")
while True:
retval = exception = None
try:
command, *args = pickle.load(stdin.buffer)
except EOFError:
return
except BaseException as exc:
exception = exc
else:
if command == "run":
func, args = args
try:
retval = func(*args)
except BaseException as exc:
exception = exc
elif command == "init":
main_module_path: str | None
sys.path, main_module_path = args
del sys.modules["__main__"]
if main_module_path and os.path.isfile(main_module_path):
# Load the parent's main module but as __mp_main__ instead of
# __main__ (like multiprocessing does) to avoid infinite recursion
try:
spec = spec_from_file_location("__mp_main__", main_module_path)
if spec and spec.loader:
main = module_from_spec(spec)
spec.loader.exec_module(main)
sys.modules["__main__"] = main
except BaseException as exc:
exception = exc
try:
if exception is not None:
status = b"EXCEPTION"
pickled = pickle.dumps(exception, pickle.HIGHEST_PROTOCOL)
else:
status = b"RETURN"
pickled = pickle.dumps(retval, pickle.HIGHEST_PROTOCOL)
except BaseException as exc:
exception = exc
status = b"EXCEPTION"
pickled = pickle.dumps(exc, pickle.HIGHEST_PROTOCOL)
stdout.buffer.write(b"%s %d\n" % (status, len(pickled)))
stdout.buffer.write(pickled)
# Respect SIGTERM
if isinstance(exception, SystemExit):
raise exception
if __name__ == "__main__":
process_worker()

View File

@@ -0,0 +1,78 @@
from __future__ import annotations
__all__ = (
"run_sync",
"current_default_thread_limiter",
)
import sys
from collections.abc import Callable
from typing import TypeVar
from warnings import warn
from ._core._eventloop import get_async_backend
from .abc import CapacityLimiter
if sys.version_info >= (3, 11):
from typing import TypeVarTuple, Unpack
else:
from typing_extensions import TypeVarTuple, Unpack
T_Retval = TypeVar("T_Retval")
PosArgsT = TypeVarTuple("PosArgsT")
async def run_sync(
func: Callable[[Unpack[PosArgsT]], T_Retval],
*args: Unpack[PosArgsT],
abandon_on_cancel: bool = False,
cancellable: bool | None = None,
limiter: CapacityLimiter | None = None,
) -> T_Retval:
"""
Call the given function with the given arguments in a worker thread.
If the ``cancellable`` option is enabled and the task waiting for its completion is
cancelled, the thread will still run its course but its return value (or any raised
exception) will be ignored.
:param func: a callable
:param args: positional arguments for the callable
:param abandon_on_cancel: ``True`` to abandon the thread (leaving it to run
unchecked on own) if the host task is cancelled, ``False`` to ignore
cancellations in the host task until the operation has completed in the worker
thread
:param cancellable: deprecated alias of ``abandon_on_cancel``; will override
``abandon_on_cancel`` if both parameters are passed
:param limiter: capacity limiter to use to limit the total amount of threads running
(if omitted, the default limiter is used)
:raises NoEventLoopError: if no supported asynchronous event loop is running in the
current thread
:return: an awaitable that yields the return value of the function.
"""
if cancellable is not None:
abandon_on_cancel = cancellable
warn(
"The `cancellable=` keyword argument to `anyio.to_thread.run_sync` is "
"deprecated since AnyIO 4.1.0; use `abandon_on_cancel=` instead",
DeprecationWarning,
stacklevel=2,
)
return await get_async_backend().run_sync_in_worker_thread(
func, args, abandon_on_cancel=abandon_on_cancel, limiter=limiter
)
def current_default_thread_limiter() -> CapacityLimiter:
"""
Return the capacity limiter that is used by default to limit the number of
concurrent threads.
:return: a capacity limiter object
:raises NoEventLoopError: if no supported asynchronous event loop is running in the
current thread
"""
return get_async_backend().current_default_thread_limiter()

View File

@@ -0,0 +1,131 @@
Metadata-Version: 2.4
Name: asyncpg
Version: 0.31.0
Summary: An asyncio PostgreSQL driver
Author-email: MagicStack Inc <hello@magic.io>
License-Expression: Apache-2.0
Project-URL: github, https://github.com/MagicStack/asyncpg
Keywords: database,postgres
Classifier: Development Status :: 5 - Production/Stable
Classifier: Framework :: AsyncIO
Classifier: Intended Audience :: Developers
Classifier: Operating System :: POSIX
Classifier: Operating System :: MacOS :: MacOS X
Classifier: Operating System :: Microsoft :: Windows
Classifier: Programming Language :: Python :: 3 :: Only
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Programming Language :: Python :: 3.14
Classifier: Programming Language :: Python :: Implementation :: CPython
Classifier: Programming Language :: Python :: Free Threading :: 2 - Beta
Classifier: Topic :: Database :: Front-Ends
Requires-Python: >=3.9.0
Description-Content-Type: text/x-rst
License-File: LICENSE
Requires-Dist: async_timeout>=4.0.3; python_version < "3.11.0"
Provides-Extra: gssauth
Requires-Dist: gssapi; platform_system != "Windows" and extra == "gssauth"
Requires-Dist: sspilib; platform_system == "Windows" and extra == "gssauth"
Dynamic: license-file
asyncpg -- A fast PostgreSQL Database Client Library for Python/asyncio
=======================================================================
.. image:: https://github.com/MagicStack/asyncpg/workflows/Tests/badge.svg
:target: https://github.com/MagicStack/asyncpg/actions?query=workflow%3ATests+branch%3Amaster
:alt: GitHub Actions status
.. image:: https://img.shields.io/pypi/v/asyncpg.svg
:target: https://pypi.python.org/pypi/asyncpg
**asyncpg** is a database interface library designed specifically for
PostgreSQL and Python/asyncio. asyncpg is an efficient, clean implementation
of PostgreSQL server binary protocol for use with Python's ``asyncio``
framework. You can read more about asyncpg in an introductory
`blog post <http://magic.io/blog/asyncpg-1m-rows-from-postgres-to-python/>`_.
asyncpg requires Python 3.9 or later and is supported for PostgreSQL
versions 9.5 to 18. Other PostgreSQL versions or other databases
implementing the PostgreSQL protocol *may* work, but are not being
actively tested.
Documentation
-------------
The project documentation can be found
`here <https://magicstack.github.io/asyncpg/current/>`_.
Performance
-----------
In our testing asyncpg is, on average, **5x** faster than psycopg3.
.. image:: https://raw.githubusercontent.com/MagicStack/asyncpg/master/performance.png?fddca40ab0
:target: https://gistpreview.github.io/?0ed296e93523831ea0918d42dd1258c2
The above results are a geometric mean of benchmarks obtained with PostgreSQL
`client driver benchmarking toolbench <https://github.com/MagicStack/pgbench>`_
in June 2023 (click on the chart to see full details).
Features
--------
asyncpg implements PostgreSQL server protocol natively and exposes its
features directly, as opposed to hiding them behind a generic facade
like DB-API.
This enables asyncpg to have easy-to-use support for:
* **prepared statements**
* **scrollable cursors**
* **partial iteration** on query results
* automatic encoding and decoding of composite types, arrays,
and any combination of those
* straightforward support for custom data types
Installation
------------
asyncpg is available on PyPI. When not using GSSAPI/SSPI authentication it
has no dependencies. Use pip to install::
$ pip install asyncpg
If you need GSSAPI/SSPI authentication, use::
$ pip install 'asyncpg[gssauth]'
For more details, please `see the documentation
<https://magicstack.github.io/asyncpg/current/installation.html>`_.
Basic Usage
-----------
.. code-block:: python
import asyncio
import asyncpg
async def run():
conn = await asyncpg.connect(user='user', password='password',
database='database', host='127.0.0.1')
values = await conn.fetch(
'SELECT * FROM mytable WHERE id = $1',
10,
)
await conn.close()
asyncio.run(run())
License
-------
asyncpg is developed and distributed under the Apache 2.0 license.

View File

@@ -0,0 +1,115 @@
asyncpg-0.31.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
asyncpg-0.31.0.dist-info/METADATA,sha256=6_wrzxCAjX9RTCqcxrKtAFVVNEZ2fTxlHcbvo62M0R8,4412
asyncpg-0.31.0.dist-info/RECORD,,
asyncpg-0.31.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
asyncpg-0.31.0.dist-info/WHEEL,sha256=VXvNKn6nFeCM45GEUrNLJOO_J_e-cNJphGt9rWFxyE0,113
asyncpg-0.31.0.dist-info/licenses/LICENSE,sha256=2SItc_2sUJkhdAdu-gT0T2-82dVhVafHCS6YdXBCpvY,11466
asyncpg-0.31.0.dist-info/top_level.txt,sha256=DdhVhpzCq49mykkHNag6i9zuJx05_tx4CMZymM1F8dU,8
asyncpg/__init__.py,sha256=bzD31aMekbKR9waMXuAxIYFbmrQ-S1Mttjmru_sSjo8,647
asyncpg/__pycache__/__init__.cpython-312.pyc,,
asyncpg/__pycache__/_asyncio_compat.cpython-312.pyc,,
asyncpg/__pycache__/_version.cpython-312.pyc,,
asyncpg/__pycache__/cluster.cpython-312.pyc,,
asyncpg/__pycache__/compat.cpython-312.pyc,,
asyncpg/__pycache__/connect_utils.cpython-312.pyc,,
asyncpg/__pycache__/connection.cpython-312.pyc,,
asyncpg/__pycache__/connresource.cpython-312.pyc,,
asyncpg/__pycache__/cursor.cpython-312.pyc,,
asyncpg/__pycache__/introspection.cpython-312.pyc,,
asyncpg/__pycache__/pool.cpython-312.pyc,,
asyncpg/__pycache__/prepared_stmt.cpython-312.pyc,,
asyncpg/__pycache__/serverversion.cpython-312.pyc,,
asyncpg/__pycache__/transaction.cpython-312.pyc,,
asyncpg/__pycache__/types.cpython-312.pyc,,
asyncpg/__pycache__/utils.cpython-312.pyc,,
asyncpg/_asyncio_compat.py,sha256=pXF_aF4o_AqxNql0sPnuGdoe5sSSwQxHpKWF6ShZTbo,2540
asyncpg/_testbase/__init__.py,sha256=IzMqfgI5gtOxajneoeWyoI4NtmE5sp7S5dXmU0gwwB8,16499
asyncpg/_testbase/__pycache__/__init__.cpython-312.pyc,,
asyncpg/_testbase/__pycache__/fuzzer.cpython-312.pyc,,
asyncpg/_testbase/fuzzer.py,sha256=3Uxdu0YXei-7JZMCuCI3bxKMdnbuossV-KC68GG-AS4,9804
asyncpg/_version.py,sha256=DIxy-OSz203zJ1z-llM0b0YFnqQtPgPR9jOjCWlyutE,641
asyncpg/cluster.py,sha256=s_HmtiEGJqJ6GQWa6_zmfe11fZ29OpOtMT6Ufcu-g0g,24476
asyncpg/compat.py,sha256=ebs2IeJw82rY9m0ZCmOYUqry_2nF3zqTi3tsWP5FT2o,2459
asyncpg/connect_utils.py,sha256=-kLsbKn6zO5ixgatHN3Av3FvkZrLIAXF-tpao_bmy5w,43438
asyncpg/connection.py,sha256=6avVwVO8cM-tthIV8RFA6QdBwv7Vkp5XS2VI74CStIw,99069
asyncpg/connresource.py,sha256=tBAidNpEhbDvrMOKQbwn3ZNgIVAtsVxARxTnwj5fk-Q,1384
asyncpg/cursor.py,sha256=rKeSIJMW5mUpvsian6a1MLrLoEwbkYTZsmZtEgwFT6s,9160
asyncpg/exceptions/__init__.py,sha256=FXUYDFQw9gxE3mVz99FmsldYxivLUMtTIhXzu5tZ7Pk,29157
asyncpg/exceptions/__pycache__/__init__.cpython-312.pyc,,
asyncpg/exceptions/__pycache__/_base.cpython-312.pyc,,
asyncpg/exceptions/_base.py,sha256=u62xv69n4AHO1xr35FjdgZhYvqdeb_mkQKyp-ip_AyQ,9260
asyncpg/introspection.py,sha256=0eRQtt0mKPGv8V2fnTSizC_vLuk8jbO1VwoiQ5SAcd4,9340
asyncpg/pgproto/__init__.pxd,sha256=uUIkKuI6IGnQ5tZXtrjOC_13qjp9MZOwewKlrxKFzPY,213
asyncpg/pgproto/__init__.py,sha256=uUIkKuI6IGnQ5tZXtrjOC_13qjp9MZOwewKlrxKFzPY,213
asyncpg/pgproto/__pycache__/__init__.cpython-312.pyc,,
asyncpg/pgproto/__pycache__/types.cpython-312.pyc,,
asyncpg/pgproto/buffer.pxd,sha256=d-hqi81ZVLD16lT-NacaQXTYKCtE-RYxWJmZisliZV8,4542
asyncpg/pgproto/buffer.pxi,sha256=zdk5rOOoenHpwwYXcEo-NyUcLgUv3QxA40JCXarKVsM,93
asyncpg/pgproto/buffer.pyx,sha256=mcaNScr1jPI8kpxT9jfg_Nicy8JOjQv6WM2j5JNPoAA,25710
asyncpg/pgproto/codecs/__init__.pxd,sha256=zHtFdDnGn8eDSPOTPw8JzD8w_OCJyqL0F6s9dR8Vz2Y,6130
asyncpg/pgproto/codecs/bits.pyx,sha256=x4MMVRLotz9R8n81E0S3lQQk23AvLlODb2pe_NGYqCI,1475
asyncpg/pgproto/codecs/bytea.pyx,sha256=ot-oFH-hzQ89EUWneHk5QDUxl2krKkpYE_nWklVHXWU,997
asyncpg/pgproto/codecs/context.pyx,sha256=oYurToHnpZz-Q8kPzRORFS_RyV4HH5kscNKsZYPt4FU,623
asyncpg/pgproto/codecs/datetime.pyx,sha256=wLcPVOoPMsI7P8VIuQRO7oyqhr8VRIrKMvDtSqjYSVo,12855
asyncpg/pgproto/codecs/float.pyx,sha256=A6XXA2NdS82EENhADA35LInxLcJsRpXvF6JVme_6HCc,1031
asyncpg/pgproto/codecs/geometry.pyx,sha256=DtRADwsifbzAZyACxakne2MVApcUNji8EyOgtKuoEaw,4665
asyncpg/pgproto/codecs/hstore.pyx,sha256=sXwFn3uzypvPkYIFH0FykiW9RU8qRme2N0lg8UoB6kg,2018
asyncpg/pgproto/codecs/int.pyx,sha256=4RuntTl_4-I7ekCSONK9y4CWFghUmaFGldXL6ruLgxM,4527
asyncpg/pgproto/codecs/json.pyx,sha256=fs7d0sroyMM9UZW-mmGgvHtVG7MiBac7Inb_wz1mMRs,1454
asyncpg/pgproto/codecs/jsonpath.pyx,sha256=bAXgTvPzQlkJdlHHB95CNl03J2WAd_iK3JsE1PXI2KU,833
asyncpg/pgproto/codecs/misc.pyx,sha256=ul5HFobQ1H3shO6ThrSlkEHO1lvxOoqTnRej3UabKiQ,484
asyncpg/pgproto/codecs/network.pyx,sha256=1oFM__xT5H3pIZrLyRqjNqrR6z1UNlqMOWGTGnsbOyw,3917
asyncpg/pgproto/codecs/numeric.pyx,sha256=TAN5stFXzmEiyP69MDG1oXryPAFCyZmxHcqPc-vy7LM,10373
asyncpg/pgproto/codecs/pg_snapshot.pyx,sha256=WGJ-dv7JXVufybAiuScth7KlXXLRdMqSKbtfT4kpVWI,1814
asyncpg/pgproto/codecs/text.pyx,sha256=yHpJCRxrf2Pgmz1abYSgvFQDRcgCJN137aniygOo_ec,1516
asyncpg/pgproto/codecs/tid.pyx,sha256=_9L8C9NSDV6Ehk48VV8xOLDNLVJz2R88EornZbHcq88,1549
asyncpg/pgproto/codecs/uuid.pyx,sha256=XIydQCaPUlfz9MvVDOu_5BTHd1kSKmJ1r3kBpsfjfYE,855
asyncpg/pgproto/consts.pxi,sha256=emui5kw362ivo2G15l8epVAA8m5eDlcbcrRn794_kew,282
asyncpg/pgproto/cpythonx.pxd,sha256=B9fAfasXgoWN-Z-STGCxbu0sW-QR8EblCIbxlzPo0Uc,736
asyncpg/pgproto/debug.pxd,sha256=SuLG2tteWe3cXnS0czRTTNnnm2QGgG02icp_6G_X9Yw,263
asyncpg/pgproto/frb.pxd,sha256=B2s2dw-SkzfKWeLEWzVLTkjjYYW53pazPcVNH3vPxAk,1212
asyncpg/pgproto/frb.pyx,sha256=7bipWSBXebweq3JBFlCvSwa03fIZGLkKPqWbJ8VFWFI,409
asyncpg/pgproto/hton.pxd,sha256=Swx5ry82iWYO9Ok4fRa_b7cLSrIPyxNYlyXm-ncYweo,953
asyncpg/pgproto/pgproto.cpython-312-x86_64-linux-gnu.so,sha256=NTdRB22nirJunzh0LxTilo0tt4TwIBn6OuKV3eDvhkc,3188448
asyncpg/pgproto/pgproto.pxd,sha256=QUUxWiHKdKfFxdDT0czSvOFsA4b59MJRR6WlUbJFgPg,430
asyncpg/pgproto/pgproto.pyi,sha256=vDsno93anu44CMa0TGTOcQSBT8mRXV-vvOH0v-kGCI0,411
asyncpg/pgproto/pgproto.pyx,sha256=bK75qfRQlofzO8dDzJ2mHUE0wLeXSsc5SLeAGvyXSeE,1249
asyncpg/pgproto/tohex.pxd,sha256=fQVaxBu6dBw2P_ROR8MSPVDlVep0McKi69fdQBLhifI,361
asyncpg/pgproto/types.py,sha256=7Onq2d7i01PLjR199TdraZFB66CdgIEXYuHoQCT3jKU,12775
asyncpg/pgproto/uuid.pyx,sha256=heUtr0KxJSSNQ92Wz9xYCHbuzS_uMHzqmJrEtKW_0cc,10243
asyncpg/pool.py,sha256=q8cQHSnbR_8x---isjqk_o_FmSS9LICuNiKkICSvhM8,42368
asyncpg/prepared_stmt.py,sha256=rLWS-YvCtxEgMWEIlbSgcuUcB39wzCoxfkzCE806Mac,9783
asyncpg/protocol/__init__.py,sha256=fDWUanigffIYzRccQeD2nhQXekl0bNGvefJDcMiOE8I,359
asyncpg/protocol/__pycache__/__init__.cpython-312.pyc,,
asyncpg/protocol/codecs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
asyncpg/protocol/codecs/__pycache__/__init__.cpython-312.pyc,,
asyncpg/protocol/codecs/array.pyx,sha256=1S_6xdgxllG8_1Lb68XdPkH1QgF63gAAmjh091Q7Dyk,29486
asyncpg/protocol/codecs/base.pxd,sha256=wSk_mmRGI-4echq_F-bj5Am33kd3cuUODFe5KgSLfME,6614
asyncpg/protocol/codecs/base.pyx,sha256=58vw_sWBBPIRHLBACBUpUJ0geAREGDaS434-V_eBMUk,34460
asyncpg/protocol/codecs/pgproto.pyx,sha256=5PDv1JT_nXbDbHtYVrGCcZN3CxzQdgwqlXT8GpyMamk,17175
asyncpg/protocol/codecs/range.pyx,sha256=-P-acyY2e5TlEtjqbkeH28PYk-DGLxqbmzKDFGL5BbI,6359
asyncpg/protocol/codecs/record.pyx,sha256=l17HPv3ZeZzvDMXmh-FTdOQ0LxqaQsge_4hlmnGaf6s,2362
asyncpg/protocol/codecs/textutils.pyx,sha256=UmTt1Zs5N2oLVDMTSlSe1zAFt5q4_4akbXZoS6HSPO8,2011
asyncpg/protocol/consts.pxi,sha256=VT7NLBpLgPUvcUbPflrX84I79JZiFg4zFzBK28nCRZo,381
asyncpg/protocol/coreproto.pxd,sha256=77yJqaBMGWHmxyihZIFfyVgfzICF9jLwKSvtuCoE8rM,6215
asyncpg/protocol/coreproto.pyx,sha256=G6HhC5sbsAFVNmNvvp8gFgaeTaVDRBRkfB8X0DXHYFA,41005
asyncpg/protocol/cpythonx.pxd,sha256=VX71g4PiwXWGTY-BzBPm7S-AiX5ySRrY40qAggH-BIA,613
asyncpg/protocol/encodings.pyx,sha256=rTjWPi-nMun7x8UYCkwsUN16pL5M5LPODYWI28nIKL4,1634
asyncpg/protocol/pgtypes.pxi,sha256=4XVeMr1Y04QOJV9f76V2882nT1vmcK-8J4AhFw0eFOw,6918
asyncpg/protocol/prepared_stmt.pxd,sha256=GhHzJgQMehpWg0i3XSmbkJH6G5nnnmdNCf2EU_gXhDY,1115
asyncpg/protocol/prepared_stmt.pyx,sha256=f-YQVhV5trxNdnlyjHCjOmi9byP5mbP__eps_JamQts,13036
asyncpg/protocol/protocol.cpython-312-x86_64-linux-gnu.so,sha256=oecLbiotWzaj1mpL0m8YlOgb5n9I_Jpq0kxctF6elN8,8949160
asyncpg/protocol/protocol.pxd,sha256=yOVFbkD7mA8VK5IGIJ4dGTyvHKWZTQOFfCFNfdeUdK8,1927
asyncpg/protocol/protocol.pyi,sha256=F3LFOSUEWZF5-XLpAkUomHflQT9ITm985J-HsN1QVO8,9108
asyncpg/protocol/protocol.pyx,sha256=li4lQhO3PeulcvSy6_Gi4PJAUp3GHE9e-nqh6xtQ-QU,34473
asyncpg/protocol/record.cpython-312-x86_64-linux-gnu.so,sha256=mjK7E05BF9dbEBdlLAIjrQvvWH3vHB-fANm2hOJIFus,135280
asyncpg/protocol/record.pyi,sha256=KJ6nF9Ad5tz4Z8wEAye_rzjBNghyCXUh9PAOqW53jW0,741
asyncpg/protocol/recordcapi.pxd,sha256=AkcjhRpPAmun1gmohKLwvalcqODvAEdUijnsK6wY6Lg,360
asyncpg/protocol/scram.pxd,sha256=t_nkicIS_4AzxyHoq-aYUNrFNv8O0W7E090HfMAIuno,1299
asyncpg/protocol/scram.pyx,sha256=nT_Rawg6h3OrRWDBwWN7lju5_hnOmXpwWFWVrb3l_dQ,14594
asyncpg/protocol/settings.pxd,sha256=8DTwZ5mi0aAUJRWE6SUIRDhWFGFis1mj8lcA8hNFTL0,1066
asyncpg/protocol/settings.pyx,sha256=yICjZF5FXwfmdxQBg-1qO0XbpLvZL11-c3aMbiwM7oo,3777
asyncpg/serverversion.py,sha256=WwlqBJkXZHvvnFluubCjPoaX_7OqjR8QgiOe90w6C9E,2133
asyncpg/transaction.py,sha256=uAJok6Shx7-Kdt5l4NX-GJtLxVJSPXTOJUryGdbIVG8,8497
asyncpg/types.py,sha256=rvWDTt-ZF56HrAjtDlOe7aodldEmtTTCDo3l134VHVM,5512
asyncpg/utils.py,sha256=Y0vATexoIHFkpWURlqnlUZUacc4F1iZJ9rWJ3654OnM,1495

View File

@@ -0,0 +1,5 @@
Wheel-Version: 1.0
Generator: setuptools (80.9.0)
Root-Is-Purelib: false
Tag: cp312-cp312-manylinux_2_28_x86_64

View File

@@ -0,0 +1,204 @@
Copyright (C) 2016-present the asyncpg authors and contributors.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright (C) 2016-present the asyncpg authors and contributors
<see AUTHORS file>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@@ -0,0 +1,24 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from __future__ import annotations
from .connection import connect, Connection # NOQA
from .exceptions import * # NOQA
from .pool import create_pool, Pool # NOQA
from .protocol import Record # NOQA
from .types import * # NOQA
from ._version import __version__ # NOQA
from . import exceptions
__all__: tuple[str, ...] = (
'connect', 'create_pool', 'Pool', 'Record', 'Connection'
)
__all__ += exceptions.__all__ # NOQA

View File

@@ -0,0 +1,94 @@
# Backports from Python/Lib/asyncio for older Pythons
#
# Copyright (c) 2001-2023 Python Software Foundation; All Rights Reserved
#
# SPDX-License-Identifier: PSF-2.0
from __future__ import annotations
import asyncio
import functools
import sys
import typing
if typing.TYPE_CHECKING:
from . import compat
if sys.version_info < (3, 11):
from async_timeout import timeout as timeout_ctx
else:
from asyncio import timeout as timeout_ctx
_T = typing.TypeVar('_T')
async def wait_for(fut: compat.Awaitable[_T], timeout: float | None) -> _T:
"""Wait for the single Future or coroutine to complete, with timeout.
Coroutine will be wrapped in Task.
Returns result of the Future or coroutine. When a timeout occurs,
it cancels the task and raises TimeoutError. To avoid the task
cancellation, wrap it in shield().
If the wait is cancelled, the task is also cancelled.
If the task supresses the cancellation and returns a value instead,
that value is returned.
This function is a coroutine.
"""
# The special case for timeout <= 0 is for the following case:
#
# async def test_waitfor():
# func_started = False
#
# async def func():
# nonlocal func_started
# func_started = True
#
# try:
# await asyncio.wait_for(func(), 0)
# except asyncio.TimeoutError:
# assert not func_started
# else:
# assert False
#
# asyncio.run(test_waitfor())
if timeout is not None and timeout <= 0:
fut = asyncio.ensure_future(fut)
if fut.done():
return fut.result()
await _cancel_and_wait(fut)
try:
return fut.result()
except asyncio.CancelledError as exc:
raise TimeoutError from exc
async with timeout_ctx(timeout):
return await fut
async def _cancel_and_wait(fut: asyncio.Future[_T]) -> None:
"""Cancel the *fut* future or task and wait until it completes."""
loop = asyncio.get_running_loop()
waiter = loop.create_future()
cb = functools.partial(_release_waiter, waiter)
fut.add_done_callback(cb)
try:
fut.cancel()
# We cannot wait on *fut* directly to make
# sure _cancel_and_wait itself is reliably cancellable.
await waiter
finally:
fut.remove_done_callback(cb)
def _release_waiter(waiter: asyncio.Future[typing.Any], *args: object) -> None:
if not waiter.done():
waiter.set_result(None)

View File

@@ -0,0 +1,543 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import asyncio
import atexit
import contextlib
import functools
import inspect
import logging
import os
import re
import textwrap
import time
import traceback
import unittest
import asyncpg
from asyncpg import cluster as pg_cluster
from asyncpg import connection as pg_connection
from asyncpg import pool as pg_pool
from . import fuzzer
@contextlib.contextmanager
def silence_asyncio_long_exec_warning():
def flt(log_record):
msg = log_record.getMessage()
return not msg.startswith('Executing ')
logger = logging.getLogger('asyncio')
logger.addFilter(flt)
try:
yield
finally:
logger.removeFilter(flt)
def with_timeout(timeout):
def wrap(func):
func.__timeout__ = timeout
return func
return wrap
class TestCaseMeta(type(unittest.TestCase)):
TEST_TIMEOUT = None
@staticmethod
def _iter_methods(bases, ns):
for base in bases:
for methname in dir(base):
if not methname.startswith('test_'):
continue
meth = getattr(base, methname)
if not inspect.iscoroutinefunction(meth):
continue
yield methname, meth
for methname, meth in ns.items():
if not methname.startswith('test_'):
continue
if not inspect.iscoroutinefunction(meth):
continue
yield methname, meth
def __new__(mcls, name, bases, ns):
for methname, meth in mcls._iter_methods(bases, ns):
@functools.wraps(meth)
def wrapper(self, *args, __meth__=meth, **kwargs):
coro = __meth__(self, *args, **kwargs)
timeout = getattr(__meth__, '__timeout__', mcls.TEST_TIMEOUT)
if timeout:
coro = asyncio.wait_for(coro, timeout)
try:
self.loop.run_until_complete(coro)
except asyncio.TimeoutError:
raise self.failureException(
'test timed out after {} seconds'.format(
timeout)) from None
else:
self.loop.run_until_complete(coro)
ns[methname] = wrapper
return super().__new__(mcls, name, bases, ns)
class TestCase(unittest.TestCase, metaclass=TestCaseMeta):
@classmethod
def setUpClass(cls):
if os.environ.get('USE_UVLOOP'):
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
cls.loop = loop
@classmethod
def tearDownClass(cls):
cls.loop.close()
asyncio.set_event_loop(None)
def setUp(self):
self.loop.set_exception_handler(self.loop_exception_handler)
self.__unhandled_exceptions = []
def tearDown(self):
excs = []
for exc in self.__unhandled_exceptions:
if isinstance(exc, ConnectionResetError):
texc = traceback.TracebackException.from_exception(
exc, lookup_lines=False)
if texc.stack[-1].name == "_call_connection_lost":
# On Windows calling socket.shutdown may raise
# ConnectionResetError, which happens in the
# finally block of _call_connection_lost.
continue
excs.append(exc)
if excs:
formatted = []
for i, context in enumerate(excs):
formatted.append(self._format_loop_exception(context, i + 1))
self.fail(
'unexpected exceptions in asynchronous code:\n' +
'\n'.join(formatted))
@contextlib.contextmanager
def assertRunUnder(self, delta):
st = time.monotonic()
try:
yield
finally:
elapsed = time.monotonic() - st
if elapsed > delta:
raise AssertionError(
'running block took {:0.3f}s which is longer '
'than the expected maximum of {:0.3f}s'.format(
elapsed, delta))
@contextlib.contextmanager
def assertLoopErrorHandlerCalled(self, msg_re: str):
contexts = []
def handler(loop, ctx):
contexts.append(ctx)
old_handler = self.loop.get_exception_handler()
self.loop.set_exception_handler(handler)
try:
yield
for ctx in contexts:
msg = ctx.get('message')
if msg and re.search(msg_re, msg):
return
raise AssertionError(
'no message matching {!r} was logged with '
'loop.call_exception_handler()'.format(msg_re))
finally:
self.loop.set_exception_handler(old_handler)
def loop_exception_handler(self, loop, context):
self.__unhandled_exceptions.append(context)
loop.default_exception_handler(context)
def _format_loop_exception(self, context, n):
message = context.get('message', 'Unhandled exception in event loop')
exception = context.get('exception')
if exception is not None:
exc_info = (type(exception), exception, exception.__traceback__)
else:
exc_info = None
lines = []
for key in sorted(context):
if key in {'message', 'exception'}:
continue
value = context[key]
if key == 'source_traceback':
tb = ''.join(traceback.format_list(value))
value = 'Object created at (most recent call last):\n'
value += tb.rstrip()
else:
try:
value = repr(value)
except Exception as ex:
value = ('Exception in __repr__ {!r}; '
'value type: {!r}'.format(ex, type(value)))
lines.append('[{}]: {}\n\n'.format(key, value))
if exc_info is not None:
lines.append('[exception]:\n')
formatted_exc = textwrap.indent(
''.join(traceback.format_exception(*exc_info)), ' ')
lines.append(formatted_exc)
details = textwrap.indent(''.join(lines), ' ')
return '{:02d}. {}:\n{}\n'.format(n, message, details)
_default_cluster = None
def _init_cluster(ClusterCls, cluster_kwargs, initdb_options=None):
cluster = ClusterCls(**cluster_kwargs)
cluster.init(**(initdb_options or {}))
cluster.trust_local_connections()
atexit.register(_shutdown_cluster, cluster)
return cluster
def _get_initdb_options(initdb_options=None):
if not initdb_options:
initdb_options = {}
else:
initdb_options = dict(initdb_options)
# Make the default superuser name stable.
if 'username' not in initdb_options:
initdb_options['username'] = 'postgres'
return initdb_options
def _init_default_cluster(initdb_options=None):
global _default_cluster
if _default_cluster is None:
pg_host = os.environ.get('PGHOST')
if pg_host:
# Using existing cluster, assuming it is initialized and running
_default_cluster = pg_cluster.RunningCluster()
else:
_default_cluster = _init_cluster(
pg_cluster.TempCluster,
cluster_kwargs={
"data_dir_suffix": ".apgtest",
},
initdb_options=_get_initdb_options(initdb_options),
)
return _default_cluster
def _shutdown_cluster(cluster):
if cluster.get_status() == 'running':
cluster.stop()
if cluster.get_status() != 'not-initialized':
cluster.destroy()
def create_pool(dsn=None, *,
min_size=10,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=60.0,
connect=None,
setup=None,
init=None,
loop=None,
pool_class=pg_pool.Pool,
connection_class=pg_connection.Connection,
record_class=asyncpg.Record,
**connect_kwargs):
return pool_class(
dsn,
min_size=min_size,
max_size=max_size,
max_queries=max_queries,
loop=loop,
connect=connect,
setup=setup,
init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
connection_class=connection_class,
record_class=record_class,
**connect_kwargs,
)
class ClusterTestCase(TestCase):
@classmethod
def get_server_settings(cls):
settings = {
'log_connections': 'on'
}
if cls.cluster.get_pg_version() >= (11, 0):
# JITting messes up timing tests, and
# is not essential for testing.
settings['jit'] = 'off'
return settings
@classmethod
def new_cluster(cls, ClusterCls, *, cluster_kwargs={}, initdb_options={}):
cluster = _init_cluster(ClusterCls, cluster_kwargs,
_get_initdb_options(initdb_options))
cls._clusters.append(cluster)
return cluster
@classmethod
def start_cluster(cls, cluster, *, server_settings={}):
cluster.start(port='dynamic', server_settings=server_settings)
@classmethod
def setup_cluster(cls):
cls.cluster = _init_default_cluster()
if cls.cluster.get_status() != 'running':
cls.cluster.start(
port='dynamic', server_settings=cls.get_server_settings())
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._clusters = []
cls.setup_cluster()
@classmethod
def tearDownClass(cls):
super().tearDownClass()
for cluster in cls._clusters:
if cluster is not _default_cluster:
cluster.stop()
cluster.destroy()
cls._clusters = []
@classmethod
def get_connection_spec(cls, kwargs={}):
conn_spec = cls.cluster.get_connection_spec()
if kwargs.get('dsn'):
conn_spec.pop('host')
conn_spec.update(kwargs)
if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
if 'database' not in conn_spec:
conn_spec['database'] = 'postgres'
if 'user' not in conn_spec:
conn_spec['user'] = 'postgres'
return conn_spec
@classmethod
def connect(cls, **kwargs):
conn_spec = cls.get_connection_spec(kwargs)
return pg_connection.connect(**conn_spec, loop=cls.loop)
def setUp(self):
super().setUp()
self._pools = []
def tearDown(self):
super().tearDown()
for pool in self._pools:
pool.terminate()
self._pools = []
def create_pool(self, pool_class=pg_pool.Pool,
connection_class=pg_connection.Connection, **kwargs):
conn_spec = self.get_connection_spec(kwargs)
pool = create_pool(loop=self.loop, pool_class=pool_class,
connection_class=connection_class, **conn_spec)
self._pools.append(pool)
return pool
class ProxiedClusterTestCase(ClusterTestCase):
@classmethod
def get_server_settings(cls):
settings = dict(super().get_server_settings())
settings['listen_addresses'] = '127.0.0.1'
return settings
@classmethod
def get_proxy_settings(cls):
return {'fuzzing-mode': None}
@classmethod
def setUpClass(cls):
super().setUpClass()
conn_spec = cls.cluster.get_connection_spec()
host = conn_spec.get('host')
if not host:
host = '127.0.0.1'
elif host.startswith('/'):
host = '127.0.0.1'
cls.proxy = fuzzer.TCPFuzzingProxy(
backend_host=host,
backend_port=conn_spec['port'],
)
cls.proxy.start()
@classmethod
def tearDownClass(cls):
cls.proxy.stop()
super().tearDownClass()
@classmethod
def get_connection_spec(cls, kwargs):
conn_spec = super().get_connection_spec(kwargs)
conn_spec['host'] = cls.proxy.listening_addr
conn_spec['port'] = cls.proxy.listening_port
return conn_spec
def tearDown(self):
self.proxy.reset()
super().tearDown()
def with_connection_options(**options):
if not options:
raise ValueError('no connection options were specified')
def wrap(func):
func.__connect_options__ = options
return func
return wrap
class ConnectedTestCase(ClusterTestCase):
def setUp(self):
super().setUp()
# Extract options set up with `with_connection_options`.
test_func = getattr(self, self._testMethodName).__func__
opts = getattr(test_func, '__connect_options__', {})
self.con = self.loop.run_until_complete(self.connect(**opts))
self.server_version = self.con.get_server_version()
def tearDown(self):
try:
self.loop.run_until_complete(self.con.close())
self.con = None
finally:
super().tearDown()
class HotStandbyTestCase(ClusterTestCase):
@classmethod
def setup_cluster(cls):
cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster)
cls.start_cluster(
cls.master_cluster,
server_settings={
'max_wal_senders': 10,
'wal_level': 'hot_standby'
}
)
con = None
try:
con = cls.loop.run_until_complete(
cls.master_cluster.connect(
database='postgres', user='postgres', loop=cls.loop))
cls.loop.run_until_complete(
con.execute('''
CREATE ROLE replication WITH LOGIN REPLICATION
'''))
cls.master_cluster.trust_local_replication_by('replication')
conn_spec = cls.master_cluster.get_connection_spec()
cls.standby_cluster = cls.new_cluster(
pg_cluster.HotStandbyCluster,
cluster_kwargs={
'master': conn_spec,
'replication_user': 'replication'
}
)
cls.start_cluster(
cls.standby_cluster,
server_settings={
'hot_standby': True
}
)
finally:
if con is not None:
cls.loop.run_until_complete(con.close())
@classmethod
def get_cluster_connection_spec(cls, cluster, kwargs={}):
conn_spec = cluster.get_connection_spec()
if kwargs.get('dsn'):
conn_spec.pop('host')
conn_spec.update(kwargs)
if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
if 'database' not in conn_spec:
conn_spec['database'] = 'postgres'
if 'user' not in conn_spec:
conn_spec['user'] = 'postgres'
return conn_spec
@classmethod
def get_connection_spec(cls, kwargs={}):
primary_spec = cls.get_cluster_connection_spec(
cls.master_cluster, kwargs
)
standby_spec = cls.get_cluster_connection_spec(
cls.standby_cluster, kwargs
)
return {
'host': [primary_spec['host'], standby_spec['host']],
'port': [primary_spec['port'], standby_spec['port']],
'database': primary_spec['database'],
'user': primary_spec['user'],
**kwargs
}
@classmethod
def connect_primary(cls, **kwargs):
conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
return pg_connection.connect(**conn_spec, loop=cls.loop)
@classmethod
def connect_standby(cls, **kwargs):
conn_spec = cls.get_cluster_connection_spec(
cls.standby_cluster,
kwargs
)
return pg_connection.connect(**conn_spec, loop=cls.loop)

View File

@@ -0,0 +1,306 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import asyncio
import socket
import threading
import typing
from asyncpg import cluster
class StopServer(Exception):
pass
class TCPFuzzingProxy:
def __init__(self, *, listening_addr: str='127.0.0.1',
listening_port: typing.Optional[int]=None,
backend_host: str, backend_port: int,
settings: typing.Optional[dict]=None) -> None:
self.listening_addr = listening_addr
self.listening_port = listening_port
self.backend_host = backend_host
self.backend_port = backend_port
self.settings = settings or {}
self.loop = None
self.connectivity = None
self.connectivity_loss = None
self.stop_event = None
self.connections = {}
self.sock = None
self.listen_task = None
async def _wait(self, work):
work_task = asyncio.ensure_future(work)
stop_event_task = asyncio.ensure_future(self.stop_event.wait())
try:
await asyncio.wait(
[work_task, stop_event_task],
return_when=asyncio.FIRST_COMPLETED)
if self.stop_event.is_set():
raise StopServer()
else:
return work_task.result()
finally:
if not work_task.done():
work_task.cancel()
if not stop_event_task.done():
stop_event_task.cancel()
def start(self):
started = threading.Event()
self.thread = threading.Thread(
target=self._start_thread, args=(started,))
self.thread.start()
if not started.wait(timeout=2):
raise RuntimeError('fuzzer proxy failed to start')
def stop(self):
self.loop.call_soon_threadsafe(self._stop)
self.thread.join()
def _stop(self):
self.stop_event.set()
def _start_thread(self, started_event):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.connectivity = asyncio.Event()
self.connectivity.set()
self.connectivity_loss = asyncio.Event()
self.stop_event = asyncio.Event()
if self.listening_port is None:
self.listening_port = cluster.find_available_port()
self.sock = socket.socket()
self.sock.bind((self.listening_addr, self.listening_port))
self.sock.listen(50)
self.sock.setblocking(False)
try:
self.loop.run_until_complete(self._main(started_event))
finally:
self.loop.close()
async def _main(self, started_event):
self.listen_task = asyncio.ensure_future(self.listen())
# Notify the main thread that we are ready to go.
started_event.set()
try:
await self.listen_task
finally:
for c in list(self.connections):
c.close()
await asyncio.sleep(0.01)
if hasattr(self.loop, 'remove_reader'):
self.loop.remove_reader(self.sock.fileno())
self.sock.close()
async def listen(self):
while True:
try:
client_sock, _ = await self._wait(
self.loop.sock_accept(self.sock))
backend_sock = socket.socket()
backend_sock.setblocking(False)
await self._wait(self.loop.sock_connect(
backend_sock, (self.backend_host, self.backend_port)))
except StopServer:
break
conn = Connection(client_sock, backend_sock, self)
conn_task = self.loop.create_task(conn.handle())
self.connections[conn] = conn_task
def trigger_connectivity_loss(self):
self.loop.call_soon_threadsafe(self._trigger_connectivity_loss)
def _trigger_connectivity_loss(self):
self.connectivity.clear()
self.connectivity_loss.set()
def restore_connectivity(self):
self.loop.call_soon_threadsafe(self._restore_connectivity)
def _restore_connectivity(self):
self.connectivity.set()
self.connectivity_loss.clear()
def reset(self):
self.restore_connectivity()
def _close_connection(self, connection):
conn_task = self.connections.pop(connection, None)
if conn_task is not None:
conn_task.cancel()
def close_all_connections(self):
for conn in list(self.connections):
self.loop.call_soon_threadsafe(self._close_connection, conn)
class Connection:
def __init__(self, client_sock, backend_sock, proxy):
self.client_sock = client_sock
self.backend_sock = backend_sock
self.proxy = proxy
self.loop = proxy.loop
self.connectivity = proxy.connectivity
self.connectivity_loss = proxy.connectivity_loss
self.proxy_to_backend_task = None
self.proxy_from_backend_task = None
self.is_closed = False
def close(self):
if self.is_closed:
return
self.is_closed = True
if self.proxy_to_backend_task is not None:
self.proxy_to_backend_task.cancel()
self.proxy_to_backend_task = None
if self.proxy_from_backend_task is not None:
self.proxy_from_backend_task.cancel()
self.proxy_from_backend_task = None
self.proxy._close_connection(self)
async def handle(self):
self.proxy_to_backend_task = asyncio.ensure_future(
self.proxy_to_backend())
self.proxy_from_backend_task = asyncio.ensure_future(
self.proxy_from_backend())
try:
await asyncio.wait(
[self.proxy_to_backend_task, self.proxy_from_backend_task],
return_when=asyncio.FIRST_COMPLETED)
finally:
if self.proxy_to_backend_task is not None:
self.proxy_to_backend_task.cancel()
if self.proxy_from_backend_task is not None:
self.proxy_from_backend_task.cancel()
# Asyncio fails to properly remove the readers and writers
# when the task doing recv() or send() is cancelled, so
# we must remove the readers and writers manually before
# closing the sockets.
self.loop.remove_reader(self.client_sock.fileno())
self.loop.remove_writer(self.client_sock.fileno())
self.loop.remove_reader(self.backend_sock.fileno())
self.loop.remove_writer(self.backend_sock.fileno())
self.client_sock.close()
self.backend_sock.close()
async def _read(self, sock, n):
read_task = asyncio.ensure_future(
self.loop.sock_recv(sock, n))
conn_event_task = asyncio.ensure_future(
self.connectivity_loss.wait())
try:
await asyncio.wait(
[read_task, conn_event_task],
return_when=asyncio.FIRST_COMPLETED)
if self.connectivity_loss.is_set():
return None
else:
return read_task.result()
finally:
if not self.loop.is_closed():
if not read_task.done():
read_task.cancel()
if not conn_event_task.done():
conn_event_task.cancel()
async def _write(self, sock, data):
write_task = asyncio.ensure_future(
self.loop.sock_sendall(sock, data))
conn_event_task = asyncio.ensure_future(
self.connectivity_loss.wait())
try:
await asyncio.wait(
[write_task, conn_event_task],
return_when=asyncio.FIRST_COMPLETED)
if self.connectivity_loss.is_set():
return None
else:
return write_task.result()
finally:
if not self.loop.is_closed():
if not write_task.done():
write_task.cancel()
if not conn_event_task.done():
conn_event_task.cancel()
async def proxy_to_backend(self):
buf = None
try:
while True:
await self.connectivity.wait()
if buf is not None:
data = buf
buf = None
else:
data = await self._read(self.client_sock, 4096)
if data == b'':
break
if self.connectivity_loss.is_set():
if data:
buf = data
continue
await self._write(self.backend_sock, data)
except ConnectionError:
pass
finally:
if not self.loop.is_closed():
self.loop.call_soon(self.close)
async def proxy_from_backend(self):
buf = None
try:
while True:
await self.connectivity.wait()
if buf is not None:
data = buf
buf = None
else:
data = await self._read(self.backend_sock, 4096)
if data == b'':
break
if self.connectivity_loss.is_set():
if data:
buf = data
continue
await self._write(self.client_sock, data)
except ConnectionError:
pass
finally:
if not self.loop.is_closed():
self.loop.call_soon(self.close)

View File

@@ -0,0 +1,17 @@
# This file MUST NOT contain anything but the __version__ assignment.
#
# When making a release, change the value of __version__
# to an appropriate value, and open a pull request against
# the correct branch (master if making a new feature release).
# The commit message MUST contain a properly formatted release
# log, and the commit must be signed.
#
# The release automation will: build and test the packages for the
# supported platforms, publish the packages on PyPI, merge the PR
# to the target branch, create a Git tag pointing to the commit.
from __future__ import annotations
import typing
__version__: typing.Final = '0.31.0'

View File

@@ -0,0 +1,729 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import asyncio
import os
import os.path
import platform
import random
import re
import shutil
import socket
import string
import subprocess
import sys
import tempfile
import textwrap
import time
import asyncpg
from asyncpg import serverversion
_system = platform.uname().system
if _system == 'Windows':
def platform_exe(name):
if name.endswith('.exe'):
return name
return name + '.exe'
else:
def platform_exe(name):
return name
def find_available_port():
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
sock.bind(('127.0.0.1', 0))
return sock.getsockname()[1]
except Exception:
return None
finally:
sock.close()
def _world_readable_mkdtemp(suffix=None, prefix=None, dir=None):
name = "".join(random.choices(string.ascii_lowercase, k=8))
if dir is None:
dir = tempfile.gettempdir()
if prefix is None:
prefix = tempfile.gettempprefix()
if suffix is None:
suffix = ""
fn = os.path.join(dir, prefix + name + suffix)
os.mkdir(fn, 0o755)
return fn
def _mkdtemp(suffix=None, prefix=None, dir=None):
if _system == 'Windows' and os.environ.get("GITHUB_ACTIONS"):
# Due to mitigations introduced in python/cpython#118486
# when Python runs in a session created via an SSH connection
# tempfile.mkdtemp creates directories that are not accessible.
return _world_readable_mkdtemp(suffix, prefix, dir)
else:
return tempfile.mkdtemp(suffix, prefix, dir)
class ClusterError(Exception):
pass
class Cluster:
def __init__(self, data_dir, *, pg_config_path=None):
self._data_dir = data_dir
self._pg_config_path = pg_config_path
self._pg_bin_dir = (
os.environ.get('PGINSTALLATION')
or os.environ.get('PGBIN')
)
self._pg_ctl = None
self._daemon_pid = None
self._daemon_process = None
self._connection_addr = None
self._connection_spec_override = None
def get_pg_version(self):
return self._pg_version
def is_managed(self):
return True
def get_data_dir(self):
return self._data_dir
def get_status(self):
if self._pg_ctl is None:
self._init_env()
process = subprocess.run(
[self._pg_ctl, 'status', '-D', self._data_dir],
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = process.stdout, process.stderr
if (process.returncode == 4 or not os.path.exists(self._data_dir) or
not os.listdir(self._data_dir)):
return 'not-initialized'
elif process.returncode == 3:
return 'stopped'
elif process.returncode == 0:
r = re.match(r'.*PID\s?:\s+(\d+).*', stdout.decode())
if not r:
raise ClusterError(
'could not parse pg_ctl status output: {}'.format(
stdout.decode()))
self._daemon_pid = int(r.group(1))
return self._test_connection(timeout=0)
else:
raise ClusterError(
'pg_ctl status exited with status {:d}: {}'.format(
process.returncode, stderr))
async def connect(self, loop=None, **kwargs):
conn_info = self.get_connection_spec()
conn_info.update(kwargs)
return await asyncpg.connect(loop=loop, **conn_info)
def init(self, **settings):
"""Initialize cluster."""
if self.get_status() != 'not-initialized':
raise ClusterError(
'cluster in {!r} has already been initialized'.format(
self._data_dir))
settings = dict(settings)
if 'encoding' not in settings:
settings['encoding'] = 'UTF-8'
if settings:
settings_args = ['--{}={}'.format(k, v)
for k, v in settings.items()]
extra_args = ['-o'] + [' '.join(settings_args)]
else:
extra_args = []
os.makedirs(self._data_dir, exist_ok=True)
process = subprocess.run(
[self._pg_ctl, 'init', '-D', self._data_dir] + extra_args,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
cwd=self._data_dir,
)
output = process.stdout
if process.returncode != 0:
raise ClusterError(
'pg_ctl init exited with status {:d}:\n{}'.format(
process.returncode, output.decode()))
return output.decode()
def start(self, wait=60, *, server_settings={}, **opts):
"""Start the cluster."""
status = self.get_status()
if status == 'running':
return
elif status == 'not-initialized':
raise ClusterError(
'cluster in {!r} has not been initialized'.format(
self._data_dir))
port = opts.pop('port', None)
if port == 'dynamic':
port = find_available_port()
extra_args = ['--{}={}'.format(k, v) for k, v in opts.items()]
extra_args.append('--port={}'.format(port))
sockdir = server_settings.get('unix_socket_directories')
if sockdir is None:
sockdir = server_settings.get('unix_socket_directory')
if sockdir is None and _system != 'Windows':
sockdir = tempfile.gettempdir()
ssl_key = server_settings.get('ssl_key_file')
if ssl_key:
# Make sure server certificate key file has correct permissions.
keyfile = os.path.join(self._data_dir, 'srvkey.pem')
shutil.copy(ssl_key, keyfile)
os.chmod(keyfile, 0o600)
server_settings = server_settings.copy()
server_settings['ssl_key_file'] = keyfile
if sockdir is not None:
if self._pg_version < (9, 3):
sockdir_opt = 'unix_socket_directory'
else:
sockdir_opt = 'unix_socket_directories'
server_settings[sockdir_opt] = sockdir
for k, v in server_settings.items():
extra_args.extend(['-c', '{}={}'.format(k, v)])
if _system == 'Windows':
# On Windows we have to use pg_ctl as direct execution
# of postgres daemon under an Administrative account
# is not permitted and there is no easy way to drop
# privileges.
if os.getenv('ASYNCPG_DEBUG_SERVER'):
stdout = sys.stdout
print(
'asyncpg.cluster: Running',
' '.join([
self._pg_ctl, 'start', '-D', self._data_dir,
'-o', ' '.join(extra_args)
]),
file=sys.stderr,
)
else:
stdout = subprocess.DEVNULL
process = subprocess.run(
[self._pg_ctl, 'start', '-D', self._data_dir,
'-o', ' '.join(extra_args)],
stdout=stdout,
stderr=subprocess.STDOUT,
cwd=self._data_dir,
)
if process.returncode != 0:
if process.stderr:
stderr = ':\n{}'.format(process.stderr.decode())
else:
stderr = ''
raise ClusterError(
'pg_ctl start exited with status {:d}{}'.format(
process.returncode, stderr))
else:
if os.getenv('ASYNCPG_DEBUG_SERVER'):
stdout = sys.stdout
else:
stdout = subprocess.DEVNULL
self._daemon_process = \
subprocess.Popen(
[self._postgres, '-D', self._data_dir, *extra_args],
stdout=stdout,
stderr=subprocess.STDOUT,
cwd=self._data_dir,
)
self._daemon_pid = self._daemon_process.pid
self._test_connection(timeout=wait)
def reload(self):
"""Reload server configuration."""
status = self.get_status()
if status != 'running':
raise ClusterError('cannot reload: cluster is not running')
process = subprocess.run(
[self._pg_ctl, 'reload', '-D', self._data_dir],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=self._data_dir,
)
stderr = process.stderr
if process.returncode != 0:
raise ClusterError(
'pg_ctl stop exited with status {:d}: {}'.format(
process.returncode, stderr.decode()))
def stop(self, wait=60):
process = subprocess.run(
[self._pg_ctl, 'stop', '-D', self._data_dir, '-t', str(wait),
'-m', 'fast'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=self._data_dir,
)
stderr = process.stderr
if process.returncode != 0:
raise ClusterError(
'pg_ctl stop exited with status {:d}: {}'.format(
process.returncode, stderr.decode()))
if (self._daemon_process is not None and
self._daemon_process.returncode is None):
self._daemon_process.kill()
def destroy(self):
status = self.get_status()
if status == 'stopped' or status == 'not-initialized':
shutil.rmtree(self._data_dir)
else:
raise ClusterError('cannot destroy {} cluster'.format(status))
def _get_connection_spec(self):
if self._connection_addr is None:
self._connection_addr = self._connection_addr_from_pidfile()
if self._connection_addr is not None:
if self._connection_spec_override:
args = self._connection_addr.copy()
args.update(self._connection_spec_override)
return args
else:
return self._connection_addr
def get_connection_spec(self):
status = self.get_status()
if status != 'running':
raise ClusterError('cluster is not running')
return self._get_connection_spec()
def override_connection_spec(self, **kwargs):
self._connection_spec_override = kwargs
def reset_wal(self, *, oid=None, xid=None):
status = self.get_status()
if status == 'not-initialized':
raise ClusterError(
'cannot modify WAL status: cluster is not initialized')
if status == 'running':
raise ClusterError(
'cannot modify WAL status: cluster is running')
opts = []
if oid is not None:
opts.extend(['-o', str(oid)])
if xid is not None:
opts.extend(['-x', str(xid)])
if not opts:
return
opts.append(self._data_dir)
try:
reset_wal = self._find_pg_binary('pg_resetwal')
except ClusterError:
reset_wal = self._find_pg_binary('pg_resetxlog')
process = subprocess.run(
[reset_wal] + opts,
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stderr = process.stderr
if process.returncode != 0:
raise ClusterError(
'pg_resetwal exited with status {:d}: {}'.format(
process.returncode, stderr.decode()))
def reset_hba(self):
"""Remove all records from pg_hba.conf."""
status = self.get_status()
if status == 'not-initialized':
raise ClusterError(
'cannot modify HBA records: cluster is not initialized')
pg_hba = os.path.join(self._data_dir, 'pg_hba.conf')
try:
with open(pg_hba, 'w'):
pass
except IOError as e:
raise ClusterError(
'cannot modify HBA records: {}'.format(e)) from e
def add_hba_entry(self, *, type='host', database, user, address=None,
auth_method, auth_options=None):
"""Add a record to pg_hba.conf."""
status = self.get_status()
if status == 'not-initialized':
raise ClusterError(
'cannot modify HBA records: cluster is not initialized')
if type not in {'local', 'host', 'hostssl', 'hostnossl'}:
raise ValueError('invalid HBA record type: {!r}'.format(type))
pg_hba = os.path.join(self._data_dir, 'pg_hba.conf')
record = '{} {} {}'.format(type, database, user)
if type != 'local':
if address is None:
raise ValueError(
'{!r} entry requires a valid address'.format(type))
else:
record += ' {}'.format(address)
record += ' {}'.format(auth_method)
if auth_options is not None:
record += ' ' + ' '.join(
'{}={}'.format(k, v) for k, v in auth_options)
try:
with open(pg_hba, 'a') as f:
print(record, file=f)
except IOError as e:
raise ClusterError(
'cannot modify HBA records: {}'.format(e)) from e
def trust_local_connections(self):
self.reset_hba()
if _system != 'Windows':
self.add_hba_entry(type='local', database='all',
user='all', auth_method='trust')
self.add_hba_entry(type='host', address='127.0.0.1/32',
database='all', user='all',
auth_method='trust')
self.add_hba_entry(type='host', address='::1/128',
database='all', user='all',
auth_method='trust')
status = self.get_status()
if status == 'running':
self.reload()
def trust_local_replication_by(self, user):
if _system != 'Windows':
self.add_hba_entry(type='local', database='replication',
user=user, auth_method='trust')
self.add_hba_entry(type='host', address='127.0.0.1/32',
database='replication', user=user,
auth_method='trust')
self.add_hba_entry(type='host', address='::1/128',
database='replication', user=user,
auth_method='trust')
status = self.get_status()
if status == 'running':
self.reload()
def _init_env(self):
if not self._pg_bin_dir:
pg_config = self._find_pg_config(self._pg_config_path)
pg_config_data = self._run_pg_config(pg_config)
self._pg_bin_dir = pg_config_data.get('bindir')
if not self._pg_bin_dir:
raise ClusterError(
'pg_config output did not provide the BINDIR value')
self._pg_ctl = self._find_pg_binary('pg_ctl')
self._postgres = self._find_pg_binary('postgres')
self._pg_version = self._get_pg_version()
def _connection_addr_from_pidfile(self):
pidfile = os.path.join(self._data_dir, 'postmaster.pid')
try:
with open(pidfile, 'rt') as f:
piddata = f.read()
except FileNotFoundError:
return None
lines = piddata.splitlines()
if len(lines) < 6:
# A complete postgres pidfile is at least 6 lines
return None
pmpid = int(lines[0])
if self._daemon_pid and pmpid != self._daemon_pid:
# This might be an old pidfile left from previous postgres
# daemon run.
return None
portnum = lines[3]
sockdir = lines[4]
hostaddr = lines[5]
if sockdir:
if sockdir[0] != '/':
# Relative sockdir
sockdir = os.path.normpath(
os.path.join(self._data_dir, sockdir))
host_str = sockdir
else:
host_str = hostaddr
if host_str == '*':
host_str = 'localhost'
elif host_str == '0.0.0.0':
host_str = '127.0.0.1'
elif host_str == '::':
host_str = '::1'
return {
'host': host_str,
'port': portnum
}
def _test_connection(self, timeout=60):
self._connection_addr = None
loop = asyncio.new_event_loop()
try:
for i in range(timeout):
if self._connection_addr is None:
conn_spec = self._get_connection_spec()
if conn_spec is None:
time.sleep(1)
continue
try:
con = loop.run_until_complete(
asyncpg.connect(database='postgres',
user='postgres',
timeout=5, loop=loop,
**self._connection_addr))
except (OSError, asyncio.TimeoutError,
asyncpg.CannotConnectNowError,
asyncpg.PostgresConnectionError):
time.sleep(1)
continue
except asyncpg.PostgresError:
# Any other error other than ServerNotReadyError or
# ConnectionError is interpreted to indicate the server is
# up.
break
else:
loop.run_until_complete(con.close())
break
finally:
loop.close()
return 'running'
def _run_pg_config(self, pg_config_path):
process = subprocess.run(
pg_config_path, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = process.stdout, process.stderr
if process.returncode != 0:
raise ClusterError('pg_config exited with status {:d}: {}'.format(
process.returncode, stderr))
else:
config = {}
for line in stdout.splitlines():
k, eq, v = line.decode('utf-8').partition('=')
if eq:
config[k.strip().lower()] = v.strip()
return config
def _find_pg_config(self, pg_config_path):
if pg_config_path is None:
pg_install = (
os.environ.get('PGINSTALLATION')
or os.environ.get('PGBIN')
)
if pg_install:
pg_config_path = platform_exe(
os.path.join(pg_install, 'pg_config'))
else:
pathenv = os.environ.get('PATH').split(os.pathsep)
for path in pathenv:
pg_config_path = platform_exe(
os.path.join(path, 'pg_config'))
if os.path.exists(pg_config_path):
break
else:
pg_config_path = None
if not pg_config_path:
raise ClusterError('could not find pg_config executable')
if not os.path.isfile(pg_config_path):
raise ClusterError('{!r} is not an executable'.format(
pg_config_path))
return pg_config_path
def _find_pg_binary(self, binary):
bpath = platform_exe(os.path.join(self._pg_bin_dir, binary))
if not os.path.isfile(bpath):
raise ClusterError(
'could not find {} executable: '.format(binary) +
'{!r} does not exist or is not a file'.format(bpath))
return bpath
def _get_pg_version(self):
process = subprocess.run(
[self._postgres, '--version'],
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = process.stdout, process.stderr
if process.returncode != 0:
raise ClusterError(
'postgres --version exited with status {:d}: {}'.format(
process.returncode, stderr))
version_string = stdout.decode('utf-8').strip(' \n')
prefix = 'postgres (PostgreSQL) '
if not version_string.startswith(prefix):
raise ClusterError(
'could not determine server version from {!r}'.format(
version_string))
version_string = version_string[len(prefix):]
return serverversion.split_server_version_string(version_string)
class TempCluster(Cluster):
def __init__(self, *,
data_dir_suffix=None, data_dir_prefix=None,
data_dir_parent=None, pg_config_path=None):
self._data_dir = _mkdtemp(suffix=data_dir_suffix,
prefix=data_dir_prefix,
dir=data_dir_parent)
super().__init__(self._data_dir, pg_config_path=pg_config_path)
class HotStandbyCluster(TempCluster):
def __init__(self, *,
master, replication_user,
data_dir_suffix=None, data_dir_prefix=None,
data_dir_parent=None, pg_config_path=None):
self._master = master
self._repl_user = replication_user
super().__init__(
data_dir_suffix=data_dir_suffix,
data_dir_prefix=data_dir_prefix,
data_dir_parent=data_dir_parent,
pg_config_path=pg_config_path)
def _init_env(self):
super()._init_env()
self._pg_basebackup = self._find_pg_binary('pg_basebackup')
def init(self, **settings):
"""Initialize cluster."""
if self.get_status() != 'not-initialized':
raise ClusterError(
'cluster in {!r} has already been initialized'.format(
self._data_dir))
process = subprocess.run(
[self._pg_basebackup, '-h', self._master['host'],
'-p', self._master['port'], '-D', self._data_dir,
'-U', self._repl_user],
stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
output = process.stdout
if process.returncode != 0:
raise ClusterError(
'pg_basebackup init exited with status {:d}:\n{}'.format(
process.returncode, output.decode()))
if self._pg_version < (12, 0):
with open(os.path.join(self._data_dir, 'recovery.conf'), 'w') as f:
f.write(textwrap.dedent("""\
standby_mode = 'on'
primary_conninfo = 'host={host} port={port} user={user}'
""".format(
host=self._master['host'],
port=self._master['port'],
user=self._repl_user)))
else:
f = open(os.path.join(self._data_dir, 'standby.signal'), 'w')
f.close()
return output.decode()
def start(self, wait=60, *, server_settings={}, **opts):
if self._pg_version >= (12, 0):
server_settings = server_settings.copy()
server_settings['primary_conninfo'] = (
'"host={host} port={port} user={user}"'.format(
host=self._master['host'],
port=self._master['port'],
user=self._repl_user,
)
)
super().start(wait=wait, server_settings=server_settings, **opts)
class RunningCluster(Cluster):
def __init__(self, **kwargs):
self.conn_spec = kwargs
def is_managed(self):
return False
def get_connection_spec(self):
return dict(self.conn_spec)
def get_status(self):
return 'running'
def init(self, **settings):
pass
def start(self, wait=60, **settings):
pass
def stop(self, wait=60):
pass
def destroy(self):
pass
def reset_hba(self):
raise ClusterError('cannot modify HBA records of unmanaged cluster')
def add_hba_entry(self, *, type='host', database, user, address=None,
auth_method, auth_options=None):
raise ClusterError('cannot modify HBA records of unmanaged cluster')

View File

@@ -0,0 +1,88 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from __future__ import annotations
import enum
import pathlib
import platform
import typing
import sys
if typing.TYPE_CHECKING:
import asyncio
SYSTEM: typing.Final = platform.uname().system
if sys.platform == 'win32':
import ctypes.wintypes
CSIDL_APPDATA: typing.Final = 0x001a
def get_pg_home_directory() -> pathlib.Path | None:
# We cannot simply use expanduser() as that returns the user's
# home directory, whereas Postgres stores its config in
# %AppData% on Windows.
buf = ctypes.create_unicode_buffer(ctypes.wintypes.MAX_PATH)
r = ctypes.windll.shell32.SHGetFolderPathW(0, CSIDL_APPDATA, 0, 0, buf)
if r:
return None
else:
return pathlib.Path(buf.value) / 'postgresql'
else:
def get_pg_home_directory() -> pathlib.Path | None:
try:
return pathlib.Path.home()
except (RuntimeError, KeyError):
return None
async def wait_closed(stream: asyncio.StreamWriter) -> None:
# Not all asyncio versions have StreamWriter.wait_closed().
if hasattr(stream, 'wait_closed'):
try:
await stream.wait_closed()
except ConnectionResetError:
# On Windows wait_closed() sometimes propagates
# ConnectionResetError which is totally unnecessary.
pass
if sys.version_info < (3, 12):
def markcoroutinefunction(c): # type: ignore
pass
else:
from inspect import markcoroutinefunction # noqa: F401
if sys.version_info < (3, 12):
from ._asyncio_compat import wait_for as wait_for # noqa: F401
else:
from asyncio import wait_for as wait_for # noqa: F401
if sys.version_info < (3, 11):
from ._asyncio_compat import timeout_ctx as timeout # noqa: F401
else:
from asyncio import timeout as timeout # noqa: F401
if sys.version_info < (3, 9):
from typing import ( # noqa: F401
Awaitable as Awaitable,
)
else:
from collections.abc import ( # noqa: F401
Awaitable as Awaitable,
)
if sys.version_info < (3, 11):
class StrEnum(str, enum.Enum):
__str__ = str.__str__
__repr__ = enum.Enum.__repr__
else:
from enum import StrEnum as StrEnum # noqa: F401

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,44 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import functools
from . import exceptions
def guarded(meth):
"""A decorator to add a sanity check to ConnectionResource methods."""
@functools.wraps(meth)
def _check(self, *args, **kwargs):
self._check_conn_validity(meth.__name__)
return meth(self, *args, **kwargs)
return _check
class ConnectionResource:
__slots__ = ('_connection', '_con_release_ctr')
def __init__(self, connection):
self._connection = connection
self._con_release_ctr = connection._pool_release_ctr
def _check_conn_validity(self, meth_name):
con_release_ctr = self._connection._pool_release_ctr
if con_release_ctr != self._con_release_ctr:
raise exceptions.InterfaceError(
'cannot call {}.{}(): '
'the underlying connection has been released back '
'to the pool'.format(self.__class__.__name__, meth_name))
if self._connection.is_closed():
raise exceptions.InterfaceError(
'cannot call {}.{}(): '
'the underlying connection is closed'.format(
self.__class__.__name__, meth_name))

View File

@@ -0,0 +1,323 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import collections
from . import connresource
from . import exceptions
class CursorFactory(connresource.ConnectionResource):
"""A cursor interface for the results of a query.
A cursor interface can be used to initiate efficient traversal of the
results of a large query.
"""
__slots__ = (
'_state',
'_args',
'_prefetch',
'_query',
'_timeout',
'_record_class',
)
def __init__(
self,
connection,
query,
state,
args,
prefetch,
timeout,
record_class
):
super().__init__(connection)
self._args = args
self._prefetch = prefetch
self._query = query
self._timeout = timeout
self._state = state
self._record_class = record_class
if state is not None:
state.attach()
@connresource.guarded
def __aiter__(self):
prefetch = 50 if self._prefetch is None else self._prefetch
return CursorIterator(
self._connection,
self._query,
self._state,
self._args,
self._record_class,
prefetch,
self._timeout,
)
@connresource.guarded
def __await__(self):
if self._prefetch is not None:
raise exceptions.InterfaceError(
'prefetch argument can only be specified for iterable cursor')
cursor = Cursor(
self._connection,
self._query,
self._state,
self._args,
self._record_class,
)
return cursor._init(self._timeout).__await__()
def __del__(self):
if self._state is not None:
self._state.detach()
self._connection._maybe_gc_stmt(self._state)
class BaseCursor(connresource.ConnectionResource):
__slots__ = (
'_state',
'_args',
'_portal_name',
'_exhausted',
'_query',
'_record_class',
)
def __init__(self, connection, query, state, args, record_class):
super().__init__(connection)
self._args = args
self._state = state
if state is not None:
state.attach()
self._portal_name = None
self._exhausted = False
self._query = query
self._record_class = record_class
def _check_ready(self):
if self._state is None:
raise exceptions.InterfaceError(
'cursor: no associated prepared statement')
if self._state.closed:
raise exceptions.InterfaceError(
'cursor: the prepared statement is closed')
if not self._connection._top_xact:
raise exceptions.NoActiveSQLTransactionError(
'cursor cannot be created outside of a transaction')
async def _bind_exec(self, n, timeout):
self._check_ready()
if self._portal_name:
raise exceptions.InterfaceError(
'cursor already has an open portal')
con = self._connection
protocol = con._protocol
self._portal_name = con._get_unique_id('portal')
buffer, _, self._exhausted = await protocol.bind_execute(
self._state, self._args, self._portal_name, n, True, timeout)
return buffer
async def _bind(self, timeout):
self._check_ready()
if self._portal_name:
raise exceptions.InterfaceError(
'cursor already has an open portal')
con = self._connection
protocol = con._protocol
self._portal_name = con._get_unique_id('portal')
buffer = await protocol.bind(self._state, self._args,
self._portal_name,
timeout)
return buffer
async def _exec(self, n, timeout):
self._check_ready()
if not self._portal_name:
raise exceptions.InterfaceError(
'cursor does not have an open portal')
protocol = self._connection._protocol
buffer, _, self._exhausted = await protocol.execute(
self._state, self._portal_name, n, True, timeout)
return buffer
async def _close_portal(self, timeout):
self._check_ready()
if not self._portal_name:
raise exceptions.InterfaceError(
'cursor does not have an open portal')
protocol = self._connection._protocol
await protocol.close_portal(self._portal_name, timeout)
self._portal_name = None
def __repr__(self):
attrs = []
if self._exhausted:
attrs.append('exhausted')
attrs.append('') # to separate from id
if self.__class__.__module__.startswith('asyncpg.'):
mod = 'asyncpg'
else:
mod = self.__class__.__module__
return '<{}.{} "{!s:.30}" {}{:#x}>'.format(
mod, self.__class__.__name__,
self._state.query,
' '.join(attrs), id(self))
def __del__(self):
if self._state is not None:
self._state.detach()
self._connection._maybe_gc_stmt(self._state)
class CursorIterator(BaseCursor):
__slots__ = ('_buffer', '_prefetch', '_timeout')
def __init__(
self,
connection,
query,
state,
args,
record_class,
prefetch,
timeout
):
super().__init__(connection, query, state, args, record_class)
if prefetch <= 0:
raise exceptions.InterfaceError(
'prefetch argument must be greater than zero')
self._buffer = collections.deque()
self._prefetch = prefetch
self._timeout = timeout
@connresource.guarded
def __aiter__(self):
return self
@connresource.guarded
async def __anext__(self):
if self._state is None:
self._state = await self._connection._get_statement(
self._query,
self._timeout,
named=True,
record_class=self._record_class,
)
self._state.attach()
if not self._portal_name and not self._exhausted:
buffer = await self._bind_exec(self._prefetch, self._timeout)
self._buffer.extend(buffer)
if not self._buffer and not self._exhausted:
buffer = await self._exec(self._prefetch, self._timeout)
self._buffer.extend(buffer)
if self._portal_name and self._exhausted:
await self._close_portal(self._timeout)
if self._buffer:
return self._buffer.popleft()
raise StopAsyncIteration
class Cursor(BaseCursor):
"""An open *portal* into the results of a query."""
__slots__ = ()
async def _init(self, timeout):
if self._state is None:
self._state = await self._connection._get_statement(
self._query,
timeout,
named=True,
record_class=self._record_class,
)
self._state.attach()
self._check_ready()
await self._bind(timeout)
return self
@connresource.guarded
async def fetch(self, n, *, timeout=None):
r"""Return the next *n* rows as a list of :class:`Record` objects.
:param float timeout: Optional timeout value in seconds.
:return: A list of :class:`Record` instances.
"""
self._check_ready()
if n <= 0:
raise exceptions.InterfaceError('n must be greater than zero')
if self._exhausted:
return []
recs = await self._exec(n, timeout)
if len(recs) < n:
self._exhausted = True
return recs
@connresource.guarded
async def fetchrow(self, *, timeout=None):
r"""Return the next row.
:param float timeout: Optional timeout value in seconds.
:return: A :class:`Record` instance.
"""
self._check_ready()
if self._exhausted:
return None
recs = await self._exec(1, timeout)
if len(recs) < 1:
self._exhausted = True
return None
return recs[0]
@connresource.guarded
async def forward(self, n, *, timeout=None) -> int:
r"""Skip over the next *n* rows.
:param float timeout: Optional timeout value in seconds.
:return: A number of rows actually skipped over (<= *n*).
"""
self._check_ready()
if n <= 0:
raise exceptions.InterfaceError('n must be greater than zero')
protocol = self._connection._protocol
status = await protocol.query('MOVE FORWARD {:d} {}'.format(
n, self._portal_name), timeout)
advanced = int(status.split()[1])
if advanced < n:
self._exhausted = True
return advanced

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,299 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import asyncpg
import sys
import textwrap
__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
'ClientConfigurationError',
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched',
'UnsupportedServerFeatureError')
def _is_asyncpg_class(cls):
modname = cls.__module__
return modname == 'asyncpg' or modname.startswith('asyncpg.')
class PostgresMessageMeta(type):
_message_map = {}
_field_map = {
'S': 'severity',
'V': 'severity_en',
'C': 'sqlstate',
'M': 'message',
'D': 'detail',
'H': 'hint',
'P': 'position',
'p': 'internal_position',
'q': 'internal_query',
'W': 'context',
's': 'schema_name',
't': 'table_name',
'c': 'column_name',
'd': 'data_type_name',
'n': 'constraint_name',
'F': 'server_source_filename',
'L': 'server_source_line',
'R': 'server_source_function'
}
def __new__(mcls, name, bases, dct):
cls = super().__new__(mcls, name, bases, dct)
if cls.__module__ == mcls.__module__ and name == 'PostgresMessage':
for f in mcls._field_map.values():
setattr(cls, f, None)
if _is_asyncpg_class(cls):
mod = sys.modules[cls.__module__]
if hasattr(mod, name):
raise RuntimeError('exception class redefinition: {}'.format(
name))
code = dct.get('sqlstate')
if code is not None:
existing = mcls._message_map.get(code)
if existing is not None:
raise TypeError('{} has duplicate SQLSTATE code, which is'
'already defined by {}'.format(
name, existing.__name__))
mcls._message_map[code] = cls
return cls
@classmethod
def get_message_class_for_sqlstate(mcls, code):
return mcls._message_map.get(code, UnknownPostgresError)
class PostgresMessage(metaclass=PostgresMessageMeta):
@classmethod
def _get_error_class(cls, fields):
sqlstate = fields.get('C')
return type(cls).get_message_class_for_sqlstate(sqlstate)
@classmethod
def _get_error_dict(cls, fields, query):
dct = {
'query': query
}
field_map = type(cls)._field_map
for k, v in fields.items():
field = field_map.get(k)
if field:
dct[field] = v
return dct
@classmethod
def _make_constructor(cls, fields, query=None):
dct = cls._get_error_dict(fields, query)
exccls = cls._get_error_class(fields)
message = dct.get('message', '')
# PostgreSQL will raise an exception when it detects
# that the result type of the query has changed from
# when the statement was prepared.
#
# The original error is somewhat cryptic and unspecific,
# so we raise a custom subclass that is easier to handle
# and identify.
#
# Note that we specifically do not rely on the error
# message, as it is localizable.
is_icse = (
exccls.__name__ == 'FeatureNotSupportedError' and
_is_asyncpg_class(exccls) and
dct.get('server_source_function') == 'RevalidateCachedQuery'
)
if is_icse:
exceptions = sys.modules[exccls.__module__]
exccls = exceptions.InvalidCachedStatementError
message = ('cached statement plan is invalid due to a database '
'schema or configuration change')
is_prepared_stmt_error = (
exccls.__name__ in ('DuplicatePreparedStatementError',
'InvalidSQLStatementNameError') and
_is_asyncpg_class(exccls)
)
if is_prepared_stmt_error:
hint = dct.get('hint', '')
hint += textwrap.dedent("""\
NOTE: pgbouncer with pool_mode set to "transaction" or
"statement" does not support prepared statements properly.
You have two options:
* if you are using pgbouncer for connection pooling to a
single server, switch to the connection pool functionality
provided by asyncpg, it is a much better option for this
purpose;
* if you have no option of avoiding the use of pgbouncer,
then you can set statement_cache_size to 0 when creating
the asyncpg connection object.
""")
dct['hint'] = hint
return exccls, message, dct
def as_dict(self):
dct = {}
for f in type(self)._field_map.values():
val = getattr(self, f)
if val is not None:
dct[f] = val
return dct
class PostgresError(PostgresMessage, Exception):
"""Base class for all Postgres errors."""
def __str__(self):
msg = self.args[0]
if self.detail:
msg += '\nDETAIL: {}'.format(self.detail)
if self.hint:
msg += '\nHINT: {}'.format(self.hint)
return msg
@classmethod
def new(cls, fields, query=None):
exccls, message, dct = cls._make_constructor(fields, query)
ex = exccls(message)
ex.__dict__.update(dct)
return ex
class FatalPostgresError(PostgresError):
"""A fatal error that should result in server disconnection."""
class UnknownPostgresError(FatalPostgresError):
"""An error with an unknown SQLSTATE code."""
class InterfaceMessage:
def __init__(self, *, detail=None, hint=None):
self.detail = detail
self.hint = hint
def __str__(self):
msg = self.args[0]
if self.detail:
msg += '\nDETAIL: {}'.format(self.detail)
if self.hint:
msg += '\nHINT: {}'.format(self.hint)
return msg
class InterfaceError(InterfaceMessage, Exception):
"""An error caused by improper use of asyncpg API."""
def __init__(self, msg, *, detail=None, hint=None):
InterfaceMessage.__init__(self, detail=detail, hint=hint)
Exception.__init__(self, msg)
def with_msg(self, msg):
return type(self)(
msg,
detail=self.detail,
hint=self.hint,
).with_traceback(
self.__traceback__
)
class ClientConfigurationError(InterfaceError, ValueError):
"""An error caused by improper client configuration."""
class DataError(InterfaceError, ValueError):
"""An error caused by invalid query input."""
class UnsupportedClientFeatureError(InterfaceError):
"""Requested feature is unsupported by asyncpg."""
class UnsupportedServerFeatureError(InterfaceError):
"""Requested feature is unsupported by PostgreSQL server."""
class InterfaceWarning(InterfaceMessage, UserWarning):
"""A warning caused by an improper use of asyncpg API."""
def __init__(self, msg, *, detail=None, hint=None):
InterfaceMessage.__init__(self, detail=detail, hint=hint)
UserWarning.__init__(self, msg)
class InternalClientError(Exception):
"""All unexpected errors not classified otherwise."""
class ProtocolError(InternalClientError):
"""Unexpected condition in the handling of PostgreSQL protocol input."""
class TargetServerAttributeNotMatched(InternalClientError):
"""Could not find a host that satisfies the target attribute requirement"""
class OutdatedSchemaCacheError(InternalClientError):
"""A value decoding error caused by a schema change before row fetching."""
def __init__(self, msg, *, schema=None, data_type=None, position=None):
super().__init__(msg)
self.schema_name = schema
self.data_type_name = data_type
self.position = position
class PostgresLogMessage(PostgresMessage):
"""A base class for non-error server messages."""
def __str__(self):
return '{}: {}'.format(type(self).__name__, self.message)
def __setattr__(self, name, val):
raise TypeError('instances of {} are immutable'.format(
type(self).__name__))
@classmethod
def new(cls, fields, query=None):
exccls, message_text, dct = cls._make_constructor(fields, query)
if exccls is UnknownPostgresError:
exccls = PostgresLogMessage
if exccls is PostgresLogMessage:
severity = dct.get('severity_en') or dct.get('severity')
if severity and severity.upper() == 'WARNING':
exccls = asyncpg.PostgresWarning
if issubclass(exccls, (BaseException, Warning)):
msg = exccls(message_text)
else:
msg = exccls()
msg.__dict__.update(dct)
return msg

View File

@@ -0,0 +1,296 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from __future__ import annotations
import typing
from .protocol.protocol import _create_record # type: ignore
if typing.TYPE_CHECKING:
from . import protocol
_TYPEINFO_13: typing.Final = '''\
(
SELECT
t.oid AS oid,
ns.nspname AS ns,
t.typname AS name,
t.typtype AS kind,
(CASE WHEN t.typtype = 'd' THEN
(WITH RECURSIVE typebases(oid, depth) AS (
SELECT
t2.typbasetype AS oid,
0 AS depth
FROM
pg_type t2
WHERE
t2.oid = t.oid
UNION ALL
SELECT
t2.typbasetype AS oid,
tb.depth + 1 AS depth
FROM
pg_type t2,
typebases tb
WHERE
tb.oid = t2.oid
AND t2.typbasetype != 0
) SELECT oid FROM typebases ORDER BY depth DESC LIMIT 1)
ELSE NULL
END) AS basetype,
t.typelem AS elemtype,
elem_t.typdelim AS elemdelim,
range_t.rngsubtype AS range_subtype,
(CASE WHEN t.typtype = 'c' THEN
(SELECT
array_agg(ia.atttypid ORDER BY ia.attnum)
FROM
pg_attribute ia
INNER JOIN pg_class c
ON (ia.attrelid = c.oid)
WHERE
ia.attnum > 0 AND NOT ia.attisdropped
AND c.reltype = t.oid)
ELSE NULL
END) AS attrtypoids,
(CASE WHEN t.typtype = 'c' THEN
(SELECT
array_agg(ia.attname::text ORDER BY ia.attnum)
FROM
pg_attribute ia
INNER JOIN pg_class c
ON (ia.attrelid = c.oid)
WHERE
ia.attnum > 0 AND NOT ia.attisdropped
AND c.reltype = t.oid)
ELSE NULL
END) AS attrnames
FROM
pg_catalog.pg_type AS t
INNER JOIN pg_catalog.pg_namespace ns ON (
ns.oid = t.typnamespace)
LEFT JOIN pg_type elem_t ON (
t.typlen = -1 AND
t.typelem != 0 AND
t.typelem = elem_t.oid
)
LEFT JOIN pg_range range_t ON (
t.oid = range_t.rngtypid
)
)
'''
INTRO_LOOKUP_TYPES_13 = '''\
WITH RECURSIVE typeinfo_tree(
oid, ns, name, kind, basetype, elemtype, elemdelim,
range_subtype, attrtypoids, attrnames, depth)
AS (
SELECT
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
ti.elemtype, ti.elemdelim, ti.range_subtype,
ti.attrtypoids, ti.attrnames, 0
FROM
{typeinfo} AS ti
WHERE
ti.oid = any($1::oid[])
UNION ALL
SELECT
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
ti.elemtype, ti.elemdelim, ti.range_subtype,
ti.attrtypoids, ti.attrnames, tt.depth + 1
FROM
{typeinfo} ti,
typeinfo_tree tt
WHERE
(tt.elemtype IS NOT NULL AND ti.oid = tt.elemtype)
OR (tt.attrtypoids IS NOT NULL AND ti.oid = any(tt.attrtypoids))
OR (tt.range_subtype IS NOT NULL AND ti.oid = tt.range_subtype)
OR (tt.basetype IS NOT NULL AND ti.oid = tt.basetype)
)
SELECT DISTINCT
*,
basetype::regtype::text AS basetype_name,
elemtype::regtype::text AS elemtype_name,
range_subtype::regtype::text AS range_subtype_name
FROM
typeinfo_tree
ORDER BY
depth DESC
'''.format(typeinfo=_TYPEINFO_13)
_TYPEINFO: typing.Final = '''\
(
SELECT
t.oid AS oid,
ns.nspname AS ns,
t.typname AS name,
t.typtype AS kind,
(CASE WHEN t.typtype = 'd' THEN
(WITH RECURSIVE typebases(oid, depth) AS (
SELECT
t2.typbasetype AS oid,
0 AS depth
FROM
pg_type t2
WHERE
t2.oid = t.oid
UNION ALL
SELECT
t2.typbasetype AS oid,
tb.depth + 1 AS depth
FROM
pg_type t2,
typebases tb
WHERE
tb.oid = t2.oid
AND t2.typbasetype != 0
) SELECT oid FROM typebases ORDER BY depth DESC LIMIT 1)
ELSE NULL
END) AS basetype,
t.typelem AS elemtype,
elem_t.typdelim AS elemdelim,
COALESCE(
range_t.rngsubtype,
multirange_t.rngsubtype) AS range_subtype,
(CASE WHEN t.typtype = 'c' THEN
(SELECT
array_agg(ia.atttypid ORDER BY ia.attnum)
FROM
pg_attribute ia
INNER JOIN pg_class c
ON (ia.attrelid = c.oid)
WHERE
ia.attnum > 0 AND NOT ia.attisdropped
AND c.reltype = t.oid)
ELSE NULL
END) AS attrtypoids,
(CASE WHEN t.typtype = 'c' THEN
(SELECT
array_agg(ia.attname::text ORDER BY ia.attnum)
FROM
pg_attribute ia
INNER JOIN pg_class c
ON (ia.attrelid = c.oid)
WHERE
ia.attnum > 0 AND NOT ia.attisdropped
AND c.reltype = t.oid)
ELSE NULL
END) AS attrnames
FROM
pg_catalog.pg_type AS t
INNER JOIN pg_catalog.pg_namespace ns ON (
ns.oid = t.typnamespace)
LEFT JOIN pg_type elem_t ON (
t.typlen = -1 AND
t.typelem != 0 AND
t.typelem = elem_t.oid
)
LEFT JOIN pg_range range_t ON (
t.oid = range_t.rngtypid
)
LEFT JOIN pg_range multirange_t ON (
t.oid = multirange_t.rngmultitypid
)
)
'''
INTRO_LOOKUP_TYPES = '''\
WITH RECURSIVE typeinfo_tree(
oid, ns, name, kind, basetype, elemtype, elemdelim,
range_subtype, attrtypoids, attrnames, depth)
AS (
SELECT
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
ti.elemtype, ti.elemdelim, ti.range_subtype,
ti.attrtypoids, ti.attrnames, 0
FROM
{typeinfo} AS ti
WHERE
ti.oid = any($1::oid[])
UNION ALL
SELECT
ti.oid, ti.ns, ti.name, ti.kind, ti.basetype,
ti.elemtype, ti.elemdelim, ti.range_subtype,
ti.attrtypoids, ti.attrnames, tt.depth + 1
FROM
{typeinfo} ti,
typeinfo_tree tt
WHERE
(tt.elemtype IS NOT NULL AND ti.oid = tt.elemtype)
OR (tt.attrtypoids IS NOT NULL AND ti.oid = any(tt.attrtypoids))
OR (tt.range_subtype IS NOT NULL AND ti.oid = tt.range_subtype)
OR (tt.basetype IS NOT NULL AND ti.oid = tt.basetype)
)
SELECT DISTINCT
*,
basetype::regtype::text AS basetype_name,
elemtype::regtype::text AS elemtype_name,
range_subtype::regtype::text AS range_subtype_name
FROM
typeinfo_tree
ORDER BY
depth DESC
'''.format(typeinfo=_TYPEINFO)
TYPE_BY_NAME: typing.Final = '''\
SELECT
t.oid,
t.typelem AS elemtype,
t.typtype AS kind
FROM
pg_catalog.pg_type AS t
INNER JOIN pg_catalog.pg_namespace ns ON (ns.oid = t.typnamespace)
WHERE
t.typname = $1 AND ns.nspname = $2
'''
def TypeRecord(
rec: typing.Tuple[int, typing.Optional[int], bytes],
) -> protocol.Record:
assert len(rec) == 3
return _create_record( # type: ignore
{"oid": 0, "elemtype": 1, "kind": 2}, rec)
# 'b' for a base type, 'd' for a domain, 'e' for enum.
SCALAR_TYPE_KINDS = (b'b', b'd', b'e')
def is_scalar_type(typeinfo: protocol.Record) -> bool:
return (
typeinfo['kind'] in SCALAR_TYPE_KINDS and
not typeinfo['elemtype']
)
def is_domain_type(typeinfo: protocol.Record) -> bool:
return typeinfo['kind'] == b'd' # type: ignore[no-any-return]
def is_composite_type(typeinfo: protocol.Record) -> bool:
return typeinfo['kind'] == b'c' # type: ignore[no-any-return]

View File

@@ -0,0 +1,5 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

View File

@@ -0,0 +1,5 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0

View File

@@ -0,0 +1,143 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from libc.stdint cimport int8_t, uint8_t, int16_t, uint16_t, \
int32_t, uint32_t, int64_t, uint64_t
include "./buffer.pxi"
cdef class WriteBuffer:
cdef:
# Preallocated small buffer
bint _smallbuf_inuse
char _smallbuf[_BUFFER_INITIAL_SIZE]
char *_buf
# Allocated size
ssize_t _size
# Length of data in the buffer
ssize_t _length
# Number of memoryviews attached to the buffer
int _view_count
# True is start_message was used
bint _message_mode
cdef inline len(self):
return self._length
cdef inline write_len_prefixed_utf8(self, str s):
return self.write_len_prefixed_bytes(s.encode('utf-8'))
cdef inline _check_readonly(self)
cdef inline _ensure_alloced(self, ssize_t extra_length)
cdef _reallocate(self, ssize_t new_size)
cdef inline reset(self)
cdef inline start_message(self, char type)
cdef inline end_message(self)
cdef write_buffer(self, WriteBuffer buf)
cdef write_byte(self, char b)
cdef write_bytes(self, bytes data)
cdef write_len_prefixed_buffer(self, WriteBuffer buf)
cdef write_len_prefixed_bytes(self, bytes data)
cdef write_bytestring(self, bytes string)
cdef write_str(self, str string, str encoding)
cdef write_frbuf(self, FRBuffer *buf)
cdef write_cstr(self, const char *data, ssize_t len)
cdef write_int16(self, int16_t i)
cdef write_int32(self, int32_t i)
cdef write_int64(self, int64_t i)
cdef write_float(self, float f)
cdef write_double(self, double d)
@staticmethod
cdef WriteBuffer new_message(char type)
@staticmethod
cdef WriteBuffer new()
ctypedef const char * (*try_consume_message_method)(object, ssize_t*)
ctypedef int32_t (*take_message_type_method)(object, char) except -1
ctypedef int32_t (*take_message_method)(object) except -1
ctypedef char (*get_message_type_method)(object)
cdef class ReadBuffer:
cdef:
# A deque of buffers (bytes objects)
object _bufs
object _bufs_append
object _bufs_popleft
# A pointer to the first buffer in `_bufs`
bytes _buf0
# A pointer to the previous first buffer
# (used to prolong the life of _buf0 when using
# methods like _try_read_bytes)
bytes _buf0_prev
# Number of buffers in `_bufs`
int32_t _bufs_len
# A read position in the first buffer in `_bufs`
ssize_t _pos0
# Length of the first buffer in `_bufs`
ssize_t _len0
# A total number of buffered bytes in ReadBuffer
ssize_t _length
char _current_message_type
int32_t _current_message_len
ssize_t _current_message_len_unread
bint _current_message_ready
cdef inline len(self):
return self._length
cdef inline char get_message_type(self):
return self._current_message_type
cdef inline int32_t get_message_length(self):
return self._current_message_len
cdef feed_data(self, data)
cdef inline _ensure_first_buf(self)
cdef _switch_to_next_buf(self)
cdef inline char read_byte(self) except? -1
cdef inline const char* _try_read_bytes(self, ssize_t nbytes)
cdef inline _read_into(self, char *buf, ssize_t nbytes)
cdef inline _read_and_discard(self, ssize_t nbytes)
cdef bytes read_bytes(self, ssize_t nbytes)
cdef bytes read_len_prefixed_bytes(self)
cdef str read_len_prefixed_utf8(self)
cdef read_uuid(self)
cdef inline int64_t read_int64(self) except? -1
cdef inline int32_t read_int32(self) except? -1
cdef inline int16_t read_int16(self) except? -1
cdef inline read_null_str(self)
cdef int32_t take_message(self) except -1
cdef inline int32_t take_message_type(self, char mtype) except -1
cdef int32_t put_message(self) except -1
cdef inline const char* try_consume_message(self, ssize_t* len)
cdef bytes consume_message(self)
cdef discard_message(self)
cdef int32_t redirect_messages(self, WriteBuffer buf, char mtype, int stop_at=?)
cdef bytearray consume_messages(self, char mtype)
cdef finish_message(self)
cdef inline _finish_message(self)
@staticmethod
cdef ReadBuffer new_message_parser(object data)

View File

@@ -0,0 +1,3 @@
DEF _BUFFER_INITIAL_SIZE = 1024
DEF _BUFFER_MAX_GROW = 65536
DEF _BUFFER_FREELIST_SIZE = 256

View File

@@ -0,0 +1,829 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from libc.string cimport memcpy
import collections
class BufferError(Exception):
pass
include "./buffer.pxi"
@cython.no_gc_clear
@cython.final
@cython.freelist(_BUFFER_FREELIST_SIZE)
cdef class WriteBuffer:
def __cinit__(self):
self._smallbuf_inuse = True
self._buf = self._smallbuf
self._size = _BUFFER_INITIAL_SIZE
self._length = 0
self._message_mode = 0
def __dealloc__(self):
if self._buf is not NULL and not self._smallbuf_inuse:
cpython.PyMem_Free(self._buf)
self._buf = NULL
self._size = 0
if self._view_count:
raise BufferError(
'Deallocating buffer with attached memoryviews')
def __getbuffer__(self, Py_buffer *buffer, int flags):
self._view_count += 1
cpython.PyBuffer_FillInfo(
buffer, self, self._buf, self._length,
1, # read-only
flags)
def __releasebuffer__(self, Py_buffer *buffer):
self._view_count -= 1
cdef inline _check_readonly(self):
if self._view_count:
raise BufferError('the buffer is in read-only mode')
cdef inline _ensure_alloced(self, ssize_t extra_length):
cdef ssize_t new_size = extra_length + self._length
if new_size > self._size:
self._reallocate(new_size)
cdef _reallocate(self, ssize_t new_size):
cdef char *new_buf
if new_size < _BUFFER_MAX_GROW:
new_size = _BUFFER_MAX_GROW
else:
# Add a little extra
new_size += _BUFFER_INITIAL_SIZE
if self._smallbuf_inuse:
new_buf = <char*>cpython.PyMem_Malloc(
sizeof(char) * <size_t>new_size)
if new_buf is NULL:
self._buf = NULL
self._size = 0
self._length = 0
raise MemoryError
memcpy(new_buf, self._buf, <size_t>self._size)
self._size = new_size
self._buf = new_buf
self._smallbuf_inuse = False
else:
new_buf = <char*>cpython.PyMem_Realloc(
<void*>self._buf, <size_t>new_size)
if new_buf is NULL:
cpython.PyMem_Free(self._buf)
self._buf = NULL
self._size = 0
self._length = 0
raise MemoryError
self._buf = new_buf
self._size = new_size
cdef inline start_message(self, char type):
if self._length != 0:
raise BufferError(
'cannot start_message for a non-empty buffer')
self._ensure_alloced(5)
self._message_mode = 1
self._buf[0] = type
self._length = 5
cdef inline end_message(self):
# "length-1" to exclude the message type byte
cdef ssize_t mlen = self._length - 1
self._check_readonly()
if not self._message_mode:
raise BufferError(
'end_message can only be called with start_message')
if self._length < 5:
raise BufferError('end_message: buffer is too small')
if mlen > _MAXINT32:
raise BufferError('end_message: message is too large')
hton.pack_int32(&self._buf[1], <int32_t>mlen)
return self
cdef inline reset(self):
self._length = 0
self._message_mode = 0
cdef write_buffer(self, WriteBuffer buf):
self._check_readonly()
if not buf._length:
return
self._ensure_alloced(buf._length)
memcpy(self._buf + self._length,
<void*>buf._buf,
<size_t>buf._length)
self._length += buf._length
cdef write_byte(self, char b):
self._check_readonly()
self._ensure_alloced(1)
self._buf[self._length] = b
self._length += 1
cdef write_bytes(self, bytes data):
cdef char* buf
cdef ssize_t len
cpython.PyBytes_AsStringAndSize(data, &buf, &len)
self.write_cstr(buf, len)
cdef write_bytestring(self, bytes string):
cdef char* buf
cdef ssize_t len
cpython.PyBytes_AsStringAndSize(string, &buf, &len)
# PyBytes_AsStringAndSize returns a null-terminated buffer,
# but the null byte is not counted in len. hence the + 1
self.write_cstr(buf, len + 1)
cdef write_str(self, str string, str encoding):
self.write_bytestring(string.encode(encoding))
cdef write_len_prefixed_buffer(self, WriteBuffer buf):
# Write a length-prefixed (not NULL-terminated) bytes sequence.
self.write_int32(<int32_t>buf.len())
self.write_buffer(buf)
cdef write_len_prefixed_bytes(self, bytes data):
# Write a length-prefixed (not NULL-terminated) bytes sequence.
cdef:
char *buf
ssize_t size
cpython.PyBytes_AsStringAndSize(data, &buf, &size)
if size > _MAXINT32:
raise BufferError('string is too large')
# `size` does not account for the NULL at the end.
self.write_int32(<int32_t>size)
self.write_cstr(buf, size)
cdef write_frbuf(self, FRBuffer *buf):
cdef:
ssize_t buf_len = buf.len
if buf_len > 0:
self.write_cstr(frb_read_all(buf), buf_len)
cdef write_cstr(self, const char *data, ssize_t len):
self._check_readonly()
self._ensure_alloced(len)
memcpy(self._buf + self._length, <void*>data, <size_t>len)
self._length += len
cdef write_int16(self, int16_t i):
self._check_readonly()
self._ensure_alloced(2)
hton.pack_int16(&self._buf[self._length], i)
self._length += 2
cdef write_int32(self, int32_t i):
self._check_readonly()
self._ensure_alloced(4)
hton.pack_int32(&self._buf[self._length], i)
self._length += 4
cdef write_int64(self, int64_t i):
self._check_readonly()
self._ensure_alloced(8)
hton.pack_int64(&self._buf[self._length], i)
self._length += 8
cdef write_float(self, float f):
self._check_readonly()
self._ensure_alloced(4)
hton.pack_float(&self._buf[self._length], f)
self._length += 4
cdef write_double(self, double d):
self._check_readonly()
self._ensure_alloced(8)
hton.pack_double(&self._buf[self._length], d)
self._length += 8
@staticmethod
cdef WriteBuffer new_message(char type):
cdef WriteBuffer buf
buf = WriteBuffer.__new__(WriteBuffer)
buf.start_message(type)
return buf
@staticmethod
cdef WriteBuffer new():
cdef WriteBuffer buf
buf = WriteBuffer.__new__(WriteBuffer)
return buf
@cython.no_gc_clear
@cython.final
@cython.freelist(_BUFFER_FREELIST_SIZE)
cdef class ReadBuffer:
def __cinit__(self):
self._bufs = collections.deque()
self._bufs_append = self._bufs.append
self._bufs_popleft = self._bufs.popleft
self._bufs_len = 0
self._buf0 = None
self._buf0_prev = None
self._pos0 = 0
self._len0 = 0
self._length = 0
self._current_message_type = 0
self._current_message_len = 0
self._current_message_len_unread = 0
self._current_message_ready = 0
cdef feed_data(self, data):
cdef:
ssize_t dlen
bytes data_bytes
if not cpython.PyBytes_CheckExact(data):
if cpythonx.PyByteArray_CheckExact(data):
# ProactorEventLoop in Python 3.10+ seems to be sending
# bytearray objects instead of bytes. Handle this here
# to avoid duplicating this check in every data_received().
data = bytes(data)
else:
raise BufferError(
'feed_data: a bytes or bytearray object expected')
# Uncomment the below code to test code paths that
# read single int/str/bytes sequences are split over
# multiple received buffers.
#
# ll = 107
# if len(data) > ll:
# self.feed_data(data[:ll])
# self.feed_data(data[ll:])
# return
data_bytes = <bytes>data
dlen = cpython.Py_SIZE(data_bytes)
if dlen == 0:
# EOF?
return
self._bufs_append(data_bytes)
self._length += dlen
if self._bufs_len == 0:
# First buffer
self._len0 = dlen
self._buf0 = data_bytes
self._bufs_len += 1
cdef inline _ensure_first_buf(self):
if PG_DEBUG:
if self._len0 == 0:
raise BufferError('empty first buffer')
if self._length == 0:
raise BufferError('empty buffer')
if self._pos0 == self._len0:
self._switch_to_next_buf()
cdef _switch_to_next_buf(self):
# The first buffer is fully read, discard it
self._bufs_popleft()
self._bufs_len -= 1
# Shouldn't fail, since we've checked that `_length >= 1`
# in _ensure_first_buf()
self._buf0_prev = self._buf0
self._buf0 = <bytes>self._bufs[0]
self._pos0 = 0
self._len0 = len(self._buf0)
if PG_DEBUG:
if self._len0 < 1:
raise BufferError(
'debug: second buffer of ReadBuffer is empty')
cdef inline const char* _try_read_bytes(self, ssize_t nbytes):
# Try to read *nbytes* from the first buffer.
#
# Returns pointer to data if there is at least *nbytes*
# in the buffer, NULL otherwise.
#
# Important: caller must call _ensure_first_buf() prior
# to calling try_read_bytes, and must not overread
cdef:
const char *result
if PG_DEBUG:
if nbytes > self._length:
return NULL
if self._current_message_ready:
if self._current_message_len_unread < nbytes:
return NULL
if self._pos0 + nbytes <= self._len0:
result = cpython.PyBytes_AS_STRING(self._buf0)
result += self._pos0
self._pos0 += nbytes
self._length -= nbytes
if self._current_message_ready:
self._current_message_len_unread -= nbytes
return result
else:
return NULL
cdef inline _read_into(self, char *buf, ssize_t nbytes):
cdef:
ssize_t nread
char *buf0
while True:
buf0 = cpython.PyBytes_AS_STRING(self._buf0)
if self._pos0 + nbytes > self._len0:
nread = self._len0 - self._pos0
memcpy(buf, buf0 + self._pos0, <size_t>nread)
self._pos0 = self._len0
self._length -= nread
nbytes -= nread
buf += nread
self._ensure_first_buf()
else:
memcpy(buf, buf0 + self._pos0, <size_t>nbytes)
self._pos0 += nbytes
self._length -= nbytes
break
cdef inline _read_and_discard(self, ssize_t nbytes):
cdef:
ssize_t nread
self._ensure_first_buf()
while True:
if self._pos0 + nbytes > self._len0:
nread = self._len0 - self._pos0
self._pos0 = self._len0
self._length -= nread
nbytes -= nread
self._ensure_first_buf()
else:
self._pos0 += nbytes
self._length -= nbytes
break
cdef bytes read_bytes(self, ssize_t nbytes):
cdef:
bytes result
ssize_t nread
const char *cbuf
char *buf
self._ensure_first_buf()
cbuf = self._try_read_bytes(nbytes)
if cbuf != NULL:
return cpython.PyBytes_FromStringAndSize(cbuf, nbytes)
if nbytes > self._length:
raise BufferError(
'not enough data to read {} bytes'.format(nbytes))
if self._current_message_ready:
self._current_message_len_unread -= nbytes
if self._current_message_len_unread < 0:
raise BufferError('buffer overread')
result = cpython.PyBytes_FromStringAndSize(NULL, nbytes)
buf = cpython.PyBytes_AS_STRING(result)
self._read_into(buf, nbytes)
return result
cdef bytes read_len_prefixed_bytes(self):
cdef int32_t size = self.read_int32()
if size < 0:
raise BufferError(
'negative length for a len-prefixed bytes value')
if size == 0:
return b''
return self.read_bytes(size)
cdef str read_len_prefixed_utf8(self):
cdef:
int32_t size
const char *cbuf
size = self.read_int32()
if size < 0:
raise BufferError(
'negative length for a len-prefixed bytes value')
if size == 0:
return ''
self._ensure_first_buf()
cbuf = self._try_read_bytes(size)
if cbuf != NULL:
return cpython.PyUnicode_DecodeUTF8(cbuf, size, NULL)
else:
return self.read_bytes(size).decode('utf-8')
cdef read_uuid(self):
cdef:
bytes mem
const char *cbuf
self._ensure_first_buf()
cbuf = self._try_read_bytes(16)
if cbuf != NULL:
return pg_uuid_from_buf(cbuf)
else:
return pg_UUID(self.read_bytes(16))
cdef inline char read_byte(self) except? -1:
cdef const char *first_byte
if PG_DEBUG:
if not self._buf0:
raise BufferError(
'debug: first buffer of ReadBuffer is empty')
self._ensure_first_buf()
first_byte = self._try_read_bytes(1)
if first_byte is NULL:
raise BufferError('not enough data to read one byte')
return first_byte[0]
cdef inline int64_t read_int64(self) except? -1:
cdef:
bytes mem
const char *cbuf
self._ensure_first_buf()
cbuf = self._try_read_bytes(8)
if cbuf != NULL:
return hton.unpack_int64(cbuf)
else:
mem = self.read_bytes(8)
return hton.unpack_int64(cpython.PyBytes_AS_STRING(mem))
cdef inline int32_t read_int32(self) except? -1:
cdef:
bytes mem
const char *cbuf
self._ensure_first_buf()
cbuf = self._try_read_bytes(4)
if cbuf != NULL:
return hton.unpack_int32(cbuf)
else:
mem = self.read_bytes(4)
return hton.unpack_int32(cpython.PyBytes_AS_STRING(mem))
cdef inline int16_t read_int16(self) except? -1:
cdef:
bytes mem
const char *cbuf
self._ensure_first_buf()
cbuf = self._try_read_bytes(2)
if cbuf != NULL:
return hton.unpack_int16(cbuf)
else:
mem = self.read_bytes(2)
return hton.unpack_int16(cpython.PyBytes_AS_STRING(mem))
cdef inline read_null_str(self):
if not self._current_message_ready:
raise BufferError(
'read_null_str only works when the message guaranteed '
'to be in the buffer')
cdef:
ssize_t pos
ssize_t nread
bytes result
const char *buf
const char *buf_start
self._ensure_first_buf()
buf_start = cpython.PyBytes_AS_STRING(self._buf0)
buf = buf_start + self._pos0
while buf - buf_start < self._len0:
if buf[0] == 0:
pos = buf - buf_start
nread = pos - self._pos0
buf = self._try_read_bytes(nread + 1)
if buf != NULL:
return cpython.PyBytes_FromStringAndSize(buf, nread)
else:
break
else:
buf += 1
result = b''
while True:
pos = self._buf0.find(b'\x00', self._pos0)
if pos >= 0:
result += self._buf0[self._pos0 : pos]
nread = pos - self._pos0 + 1
self._pos0 = pos + 1
self._length -= nread
self._current_message_len_unread -= nread
if self._current_message_len_unread < 0:
raise BufferError(
'read_null_str: buffer overread')
return result
else:
result += self._buf0[self._pos0:]
nread = self._len0 - self._pos0
self._pos0 = self._len0
self._length -= nread
self._current_message_len_unread -= nread
if self._current_message_len_unread < 0:
raise BufferError(
'read_null_str: buffer overread')
self._ensure_first_buf()
cdef int32_t take_message(self) except -1:
cdef:
const char *cbuf
if self._current_message_ready:
return 1
if self._current_message_type == 0:
if self._length < 1:
return 0
self._ensure_first_buf()
cbuf = self._try_read_bytes(1)
if cbuf == NULL:
raise BufferError(
'failed to read one byte on a non-empty buffer')
self._current_message_type = cbuf[0]
if self._current_message_len == 0:
if self._length < 4:
return 0
self._ensure_first_buf()
cbuf = self._try_read_bytes(4)
if cbuf != NULL:
self._current_message_len = hton.unpack_int32(cbuf)
else:
self._current_message_len = self.read_int32()
self._current_message_len_unread = self._current_message_len - 4
if self._length < self._current_message_len_unread:
return 0
self._current_message_ready = 1
return 1
cdef inline int32_t take_message_type(self, char mtype) except -1:
cdef const char *buf0
if self._current_message_ready:
return self._current_message_type == mtype
elif self._length >= 1:
self._ensure_first_buf()
buf0 = cpython.PyBytes_AS_STRING(self._buf0)
return buf0[self._pos0] == mtype and self.take_message()
else:
return 0
cdef int32_t put_message(self) except -1:
if not self._current_message_ready:
raise BufferError(
'cannot put message: no message taken')
self._current_message_ready = False
return 0
cdef inline const char* try_consume_message(self, ssize_t* len):
cdef:
ssize_t buf_len
const char *buf
if not self._current_message_ready:
return NULL
self._ensure_first_buf()
buf_len = self._current_message_len_unread
buf = self._try_read_bytes(buf_len)
if buf != NULL:
len[0] = buf_len
self._finish_message()
return buf
cdef discard_message(self):
if not self._current_message_ready:
raise BufferError('no message to discard')
if self._current_message_len_unread > 0:
self._read_and_discard(self._current_message_len_unread)
self._current_message_len_unread = 0
self._finish_message()
cdef bytes consume_message(self):
if not self._current_message_ready:
raise BufferError('no message to consume')
if self._current_message_len_unread > 0:
mem = self.read_bytes(self._current_message_len_unread)
else:
mem = b''
self._finish_message()
return mem
cdef int32_t redirect_messages(self, WriteBuffer buf, char mtype,
int stop_at=0):
# Redirects messages from self into buf until either
# a message with a type different than mtype is encountered, or
# buf contains stop_at bytes.
# Returns the number of messages redirected.
if not self._current_message_ready:
raise BufferError(
'consume_full_messages called on a buffer without a '
'complete first message')
if mtype != self._current_message_type:
raise BufferError(
'consume_full_messages called with a wrong mtype')
if self._current_message_len_unread != self._current_message_len - 4:
raise BufferError(
'consume_full_messages called on a partially read message')
cdef:
const char* cbuf
ssize_t cbuf_len
int32_t msg_len
ssize_t new_pos0
ssize_t pos_delta
int32_t done
int32_t count
count = 0
while True:
count += 1
buf.write_byte(mtype)
buf.write_int32(self._current_message_len)
cbuf = self.try_consume_message(&cbuf_len)
if cbuf != NULL:
buf.write_cstr(cbuf, cbuf_len)
else:
buf.write_bytes(self.consume_message())
if self._length > 0:
self._ensure_first_buf()
else:
return count
if stop_at and buf._length >= stop_at:
return count
# Fast path: exhaust buf0 as efficiently as possible.
if self._pos0 + 5 <= self._len0:
cbuf = cpython.PyBytes_AS_STRING(self._buf0)
new_pos0 = self._pos0
cbuf_len = self._len0
done = 0
# Scan the first buffer and find the position of the
# end of the last "mtype" message.
while new_pos0 + 5 <= cbuf_len:
if (cbuf + new_pos0)[0] != mtype:
done = 1
break
if (stop_at and
(buf._length + new_pos0 - self._pos0) > stop_at):
done = 1
break
msg_len = hton.unpack_int32(cbuf + new_pos0 + 1) + 1
if new_pos0 + msg_len > cbuf_len:
break
new_pos0 += msg_len
count += 1
if new_pos0 != self._pos0:
assert self._pos0 < new_pos0 <= self._len0
pos_delta = new_pos0 - self._pos0
buf.write_cstr(
cbuf + self._pos0,
pos_delta
)
self._pos0 = new_pos0
self._length -= pos_delta
assert self._length >= 0
if done:
# The next message is of a different type.
return count
# Back to slow path.
if not self.take_message_type(mtype):
return count
cdef bytearray consume_messages(self, char mtype):
"""Consume consecutive messages of the same type."""
cdef:
char *buf
ssize_t nbytes
ssize_t total_bytes = 0
bytearray result
if not self.take_message_type(mtype):
return None
# consume_messages is a volume-oriented method, so
# we assume that the remainder of the buffer will contain
# messages of the requested type.
result = cpythonx.PyByteArray_FromStringAndSize(NULL, self._length)
buf = cpythonx.PyByteArray_AsString(result)
while self.take_message_type(mtype):
self._ensure_first_buf()
nbytes = self._current_message_len_unread
self._read_into(buf, nbytes)
buf += nbytes
total_bytes += nbytes
self._finish_message()
# Clamp the result to an actual size read.
cpythonx.PyByteArray_Resize(result, total_bytes)
return result
cdef finish_message(self):
if self._current_message_type == 0 or not self._current_message_ready:
# The message has already been finished (e.g by consume_message()),
# or has been put back by put_message().
return
if self._current_message_len_unread:
if PG_DEBUG:
mtype = chr(self._current_message_type)
discarded = self.consume_message()
if PG_DEBUG:
print('!!! discarding message {!r} unread data: {!r}'.format(
mtype,
discarded))
self._finish_message()
cdef inline _finish_message(self):
self._current_message_type = 0
self._current_message_len = 0
self._current_message_ready = 0
self._current_message_len_unread = 0
@staticmethod
cdef ReadBuffer new_message_parser(object data):
cdef ReadBuffer buf
buf = ReadBuffer.__new__(ReadBuffer)
buf.feed_data(data)
buf._current_message_ready = 1
buf._current_message_len_unread = buf._len0
return buf

View File

@@ -0,0 +1,159 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef class CodecContext:
cpdef get_text_codec(self)
cdef is_encoding_utf8(self)
cpdef get_json_decoder(self)
cdef is_decoding_json(self)
cpdef get_json_encoder(self)
cdef is_encoding_json(self)
ctypedef object (*encode_func)(CodecContext settings,
WriteBuffer buf,
object obj)
ctypedef object (*decode_func)(CodecContext settings,
FRBuffer *buf)
# Datetime
cdef date_encode(CodecContext settings, WriteBuffer buf, obj)
cdef date_decode(CodecContext settings, FRBuffer * buf)
cdef date_encode_tuple(CodecContext settings, WriteBuffer buf, obj)
cdef date_decode_tuple(CodecContext settings, FRBuffer * buf)
cdef timestamp_encode(CodecContext settings, WriteBuffer buf, obj)
cdef timestamp_decode(CodecContext settings, FRBuffer * buf)
cdef timestamp_encode_tuple(CodecContext settings, WriteBuffer buf, obj)
cdef timestamp_decode_tuple(CodecContext settings, FRBuffer * buf)
cdef timestamptz_encode(CodecContext settings, WriteBuffer buf, obj)
cdef timestamptz_decode(CodecContext settings, FRBuffer * buf)
cdef time_encode(CodecContext settings, WriteBuffer buf, obj)
cdef time_decode(CodecContext settings, FRBuffer * buf)
cdef time_encode_tuple(CodecContext settings, WriteBuffer buf, obj)
cdef time_decode_tuple(CodecContext settings, FRBuffer * buf)
cdef timetz_encode(CodecContext settings, WriteBuffer buf, obj)
cdef timetz_decode(CodecContext settings, FRBuffer * buf)
cdef timetz_encode_tuple(CodecContext settings, WriteBuffer buf, obj)
cdef timetz_decode_tuple(CodecContext settings, FRBuffer * buf)
cdef interval_encode(CodecContext settings, WriteBuffer buf, obj)
cdef interval_decode(CodecContext settings, FRBuffer * buf)
cdef interval_encode_tuple(CodecContext settings, WriteBuffer buf, tuple obj)
cdef interval_decode_tuple(CodecContext settings, FRBuffer * buf)
# Bits
cdef bits_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef bits_decode(CodecContext settings, FRBuffer * buf)
# Bools
cdef bool_encode(CodecContext settings, WriteBuffer buf, obj)
cdef bool_decode(CodecContext settings, FRBuffer * buf)
# Geometry
cdef box_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef box_decode(CodecContext settings, FRBuffer * buf)
cdef line_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef line_decode(CodecContext settings, FRBuffer * buf)
cdef lseg_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef lseg_decode(CodecContext settings, FRBuffer * buf)
cdef point_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef point_decode(CodecContext settings, FRBuffer * buf)
cdef path_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef path_decode(CodecContext settings, FRBuffer * buf)
cdef poly_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef poly_decode(CodecContext settings, FRBuffer * buf)
cdef circle_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef circle_decode(CodecContext settings, FRBuffer * buf)
# Hstore
cdef hstore_encode(CodecContext settings, WriteBuffer buf, obj)
cdef hstore_decode(CodecContext settings, FRBuffer * buf)
# Ints
cdef int2_encode(CodecContext settings, WriteBuffer buf, obj)
cdef int2_decode(CodecContext settings, FRBuffer * buf)
cdef int4_encode(CodecContext settings, WriteBuffer buf, obj)
cdef int4_decode(CodecContext settings, FRBuffer * buf)
cdef uint4_encode(CodecContext settings, WriteBuffer buf, obj)
cdef uint4_decode(CodecContext settings, FRBuffer * buf)
cdef int8_encode(CodecContext settings, WriteBuffer buf, obj)
cdef int8_decode(CodecContext settings, FRBuffer * buf)
cdef uint8_encode(CodecContext settings, WriteBuffer buf, obj)
cdef uint8_decode(CodecContext settings, FRBuffer * buf)
# Floats
cdef float4_encode(CodecContext settings, WriteBuffer buf, obj)
cdef float4_decode(CodecContext settings, FRBuffer * buf)
cdef float8_encode(CodecContext settings, WriteBuffer buf, obj)
cdef float8_decode(CodecContext settings, FRBuffer * buf)
# JSON
cdef jsonb_encode(CodecContext settings, WriteBuffer buf, obj)
cdef jsonb_decode(CodecContext settings, FRBuffer * buf)
cdef json_encode(CodecContext settings, WriteBuffer buf, obj)
cdef json_decode(CodecContext settings, FRBuffer *buf)
# JSON path
cdef jsonpath_encode(CodecContext settings, WriteBuffer buf, obj)
cdef jsonpath_decode(CodecContext settings, FRBuffer * buf)
# Text
cdef as_pg_string_and_size(
CodecContext settings, obj, char **cstr, ssize_t *size)
cdef text_encode(CodecContext settings, WriteBuffer buf, obj)
cdef text_decode(CodecContext settings, FRBuffer * buf)
# Bytea
cdef bytea_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef bytea_decode(CodecContext settings, FRBuffer * buf)
# UUID
cdef uuid_encode(CodecContext settings, WriteBuffer wbuf, obj)
cdef uuid_decode(CodecContext settings, FRBuffer * buf)
# Numeric
cdef numeric_encode_text(CodecContext settings, WriteBuffer buf, obj)
cdef numeric_decode_text(CodecContext settings, FRBuffer * buf)
cdef numeric_encode_binary(CodecContext settings, WriteBuffer buf, obj)
cdef numeric_decode_binary(CodecContext settings, FRBuffer * buf)
cdef numeric_decode_binary_ex(CodecContext settings, FRBuffer * buf,
bint trail_fract_zero)
# Void
cdef void_encode(CodecContext settings, WriteBuffer buf, obj)
cdef void_decode(CodecContext settings, FRBuffer * buf)
# tid
cdef tid_encode(CodecContext settings, WriteBuffer buf, obj)
cdef tid_decode(CodecContext settings, FRBuffer * buf)
# Network
cdef cidr_encode(CodecContext settings, WriteBuffer buf, obj)
cdef cidr_decode(CodecContext settings, FRBuffer * buf)
cdef inet_encode(CodecContext settings, WriteBuffer buf, obj)
cdef inet_decode(CodecContext settings, FRBuffer * buf)
# pg_snapshot
cdef pg_snapshot_encode(CodecContext settings, WriteBuffer buf, obj)
cdef pg_snapshot_decode(CodecContext settings, FRBuffer * buf)

View File

@@ -0,0 +1,47 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef bits_encode(CodecContext settings, WriteBuffer wbuf, obj):
cdef:
Py_buffer pybuf
bint pybuf_used = False
char *buf
ssize_t len
ssize_t bitlen
if cpython.PyBytes_CheckExact(obj):
buf = cpython.PyBytes_AS_STRING(obj)
len = cpython.Py_SIZE(obj)
bitlen = len * 8
elif isinstance(obj, pgproto_types.BitString):
cpython.PyBytes_AsStringAndSize(obj.bytes, &buf, &len)
bitlen = obj.__len__()
else:
cpython.PyObject_GetBuffer(obj, &pybuf, cpython.PyBUF_SIMPLE)
pybuf_used = True
buf = <char*>pybuf.buf
len = pybuf.len
bitlen = len * 8
try:
if bitlen > _MAXINT32:
raise ValueError('bit value too long')
wbuf.write_int32(4 + <int32_t>len)
wbuf.write_int32(<int32_t>bitlen)
wbuf.write_cstr(buf, len)
finally:
if pybuf_used:
cpython.PyBuffer_Release(&pybuf)
cdef bits_decode(CodecContext settings, FRBuffer *buf):
cdef:
int32_t bitlen = hton.unpack_int32(frb_read(buf, 4))
ssize_t buf_len = buf.len
bytes_ = cpython.PyBytes_FromStringAndSize(frb_read_all(buf), buf_len)
return pgproto_types.BitString.frombytes(bytes_, bitlen)

View File

@@ -0,0 +1,34 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef bytea_encode(CodecContext settings, WriteBuffer wbuf, obj):
cdef:
Py_buffer pybuf
bint pybuf_used = False
char *buf
ssize_t len
if cpython.PyBytes_CheckExact(obj):
buf = cpython.PyBytes_AS_STRING(obj)
len = cpython.Py_SIZE(obj)
else:
cpython.PyObject_GetBuffer(obj, &pybuf, cpython.PyBUF_SIMPLE)
pybuf_used = True
buf = <char*>pybuf.buf
len = pybuf.len
try:
wbuf.write_int32(<int32_t>len)
wbuf.write_cstr(buf, len)
finally:
if pybuf_used:
cpython.PyBuffer_Release(&pybuf)
cdef bytea_decode(CodecContext settings, FRBuffer *buf):
cdef ssize_t buf_len = buf.len
return cpython.PyBytes_FromStringAndSize(frb_read_all(buf), buf_len)

View File

@@ -0,0 +1,26 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef class CodecContext:
cpdef get_text_codec(self):
raise NotImplementedError
cdef is_encoding_utf8(self):
raise NotImplementedError
cpdef get_json_decoder(self):
raise NotImplementedError
cdef is_decoding_json(self):
return False
cpdef get_json_encoder(self):
raise NotImplementedError
cdef is_encoding_json(self):
return False

View File

@@ -0,0 +1,423 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cimport cpython.datetime
import datetime
cpython.datetime.import_datetime()
utc = datetime.timezone.utc
date_from_ordinal = datetime.date.fromordinal
timedelta = datetime.timedelta
pg_epoch_datetime = datetime.datetime(2000, 1, 1)
cdef int32_t pg_epoch_datetime_ts = \
<int32_t>cpython.PyLong_AsLong(int(pg_epoch_datetime.timestamp()))
pg_epoch_datetime_utc = datetime.datetime(2000, 1, 1, tzinfo=utc)
cdef int32_t pg_epoch_datetime_utc_ts = \
<int32_t>cpython.PyLong_AsLong(int(pg_epoch_datetime_utc.timestamp()))
pg_epoch_date = datetime.date(2000, 1, 1)
cdef int32_t pg_date_offset_ord = \
<int32_t>cpython.PyLong_AsLong(pg_epoch_date.toordinal())
# Binary representations of infinity for datetimes.
cdef const int64_t pg_time64_infinity = 0x7fffffffffffffff
cdef const int64_t pg_time64_negative_infinity = <int64_t>0x8000000000000000
cdef const int32_t pg_date_infinity = 0x7fffffff
cdef const int32_t pg_date_negative_infinity = <int32_t>0x80000000
infinity_datetime = datetime.datetime(
datetime.MAXYEAR, 12, 31, 23, 59, 59, 999999)
cdef int32_t infinity_datetime_ord = <int32_t>cpython.PyLong_AsLong(
infinity_datetime.toordinal())
cdef int64_t infinity_datetime_ts = 252455615999999999
negative_infinity_datetime = datetime.datetime(
datetime.MINYEAR, 1, 1, 0, 0, 0, 0)
cdef int32_t negative_infinity_datetime_ord = <int32_t>cpython.PyLong_AsLong(
negative_infinity_datetime.toordinal())
cdef int64_t negative_infinity_datetime_ts = -63082281600000000
infinity_date = datetime.date(datetime.MAXYEAR, 12, 31)
cdef int32_t infinity_date_ord = <int32_t>cpython.PyLong_AsLong(
infinity_date.toordinal())
negative_infinity_date = datetime.date(datetime.MINYEAR, 1, 1)
cdef int32_t negative_infinity_date_ord = <int32_t>cpython.PyLong_AsLong(
negative_infinity_date.toordinal())
cdef inline _local_timezone():
d = datetime.datetime.now(datetime.timezone.utc).astimezone()
return datetime.timezone(d.utcoffset())
cdef inline _encode_time(WriteBuffer buf, int64_t seconds,
int32_t microseconds):
# XXX: add support for double timestamps
# int64 timestamps,
cdef int64_t ts = seconds * 1000000 + microseconds
if ts == infinity_datetime_ts:
buf.write_int64(pg_time64_infinity)
elif ts == negative_infinity_datetime_ts:
buf.write_int64(pg_time64_negative_infinity)
else:
buf.write_int64(ts)
cdef inline int32_t _decode_time(FRBuffer *buf, int64_t *seconds,
int32_t *microseconds):
cdef int64_t ts = hton.unpack_int64(frb_read(buf, 8))
if ts == pg_time64_infinity:
return 1
elif ts == pg_time64_negative_infinity:
return -1
else:
seconds[0] = ts // 1000000
microseconds[0] = <int32_t>(ts % 1000000)
return 0
cdef date_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
int32_t ordinal = <int32_t>cpython.PyLong_AsLong(obj.toordinal())
int32_t pg_ordinal
if ordinal == infinity_date_ord:
pg_ordinal = pg_date_infinity
elif ordinal == negative_infinity_date_ord:
pg_ordinal = pg_date_negative_infinity
else:
pg_ordinal = ordinal - pg_date_offset_ord
buf.write_int32(4)
buf.write_int32(pg_ordinal)
cdef date_encode_tuple(CodecContext settings, WriteBuffer buf, obj):
cdef:
int32_t pg_ordinal
if len(obj) != 1:
raise ValueError(
'date tuple encoder: expecting 1 element '
'in tuple, got {}'.format(len(obj)))
pg_ordinal = obj[0]
buf.write_int32(4)
buf.write_int32(pg_ordinal)
cdef date_decode(CodecContext settings, FRBuffer *buf):
cdef int32_t pg_ordinal = hton.unpack_int32(frb_read(buf, 4))
if pg_ordinal == pg_date_infinity:
return infinity_date
elif pg_ordinal == pg_date_negative_infinity:
return negative_infinity_date
else:
return date_from_ordinal(pg_ordinal + pg_date_offset_ord)
cdef date_decode_tuple(CodecContext settings, FRBuffer *buf):
cdef int32_t pg_ordinal = hton.unpack_int32(frb_read(buf, 4))
return (pg_ordinal,)
cdef timestamp_encode(CodecContext settings, WriteBuffer buf, obj):
if not cpython.datetime.PyDateTime_Check(obj):
if cpython.datetime.PyDate_Check(obj):
obj = datetime.datetime(obj.year, obj.month, obj.day)
else:
raise TypeError(
'expected a datetime.date or datetime.datetime instance, '
'got {!r}'.format(type(obj).__name__)
)
delta = obj - pg_epoch_datetime
cdef:
int64_t seconds = cpython.PyLong_AsLongLong(delta.days) * 86400 + \
cpython.PyLong_AsLong(delta.seconds)
int32_t microseconds = <int32_t>cpython.PyLong_AsLong(
delta.microseconds)
buf.write_int32(8)
_encode_time(buf, seconds, microseconds)
cdef timestamp_encode_tuple(CodecContext settings, WriteBuffer buf, obj):
cdef:
int64_t microseconds
if len(obj) != 1:
raise ValueError(
'timestamp tuple encoder: expecting 1 element '
'in tuple, got {}'.format(len(obj)))
microseconds = obj[0]
buf.write_int32(8)
buf.write_int64(microseconds)
cdef timestamp_decode(CodecContext settings, FRBuffer *buf):
cdef:
int64_t seconds = 0
int32_t microseconds = 0
int32_t inf = _decode_time(buf, &seconds, &microseconds)
if inf > 0:
# positive infinity
return infinity_datetime
elif inf < 0:
# negative infinity
return negative_infinity_datetime
else:
return pg_epoch_datetime.__add__(
timedelta(0, seconds, microseconds))
cdef timestamp_decode_tuple(CodecContext settings, FRBuffer *buf):
cdef:
int64_t ts = hton.unpack_int64(frb_read(buf, 8))
return (ts,)
cdef timestamptz_encode(CodecContext settings, WriteBuffer buf, obj):
if not cpython.datetime.PyDateTime_Check(obj):
if cpython.datetime.PyDate_Check(obj):
obj = datetime.datetime(obj.year, obj.month, obj.day,
tzinfo=_local_timezone())
else:
raise TypeError(
'expected a datetime.date or datetime.datetime instance, '
'got {!r}'.format(type(obj).__name__)
)
buf.write_int32(8)
if obj == infinity_datetime:
buf.write_int64(pg_time64_infinity)
return
elif obj == negative_infinity_datetime:
buf.write_int64(pg_time64_negative_infinity)
return
utc_dt = obj.astimezone(utc)
delta = utc_dt - pg_epoch_datetime_utc
cdef:
int64_t seconds = cpython.PyLong_AsLongLong(delta.days) * 86400 + \
cpython.PyLong_AsLong(delta.seconds)
int32_t microseconds = <int32_t>cpython.PyLong_AsLong(
delta.microseconds)
_encode_time(buf, seconds, microseconds)
cdef timestamptz_decode(CodecContext settings, FRBuffer *buf):
cdef:
int64_t seconds = 0
int32_t microseconds = 0
int32_t inf = _decode_time(buf, &seconds, &microseconds)
if inf > 0:
# positive infinity
return infinity_datetime
elif inf < 0:
# negative infinity
return negative_infinity_datetime
else:
return pg_epoch_datetime_utc.__add__(
timedelta(0, seconds, microseconds))
cdef time_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
int64_t seconds = cpython.PyLong_AsLong(obj.hour) * 3600 + \
cpython.PyLong_AsLong(obj.minute) * 60 + \
cpython.PyLong_AsLong(obj.second)
int32_t microseconds = <int32_t>cpython.PyLong_AsLong(obj.microsecond)
buf.write_int32(8)
_encode_time(buf, seconds, microseconds)
cdef time_encode_tuple(CodecContext settings, WriteBuffer buf, obj):
cdef:
int64_t microseconds
if len(obj) != 1:
raise ValueError(
'time tuple encoder: expecting 1 element '
'in tuple, got {}'.format(len(obj)))
microseconds = obj[0]
buf.write_int32(8)
buf.write_int64(microseconds)
cdef time_decode(CodecContext settings, FRBuffer *buf):
cdef:
int64_t seconds = 0
int32_t microseconds = 0
_decode_time(buf, &seconds, &microseconds)
cdef:
int64_t minutes = <int64_t>(seconds / 60)
int64_t sec = seconds % 60
int64_t hours = <int64_t>(minutes / 60)
int64_t min = minutes % 60
return datetime.time(hours, min, sec, microseconds)
cdef time_decode_tuple(CodecContext settings, FRBuffer *buf):
cdef:
int64_t ts = hton.unpack_int64(frb_read(buf, 8))
return (ts,)
cdef timetz_encode(CodecContext settings, WriteBuffer buf, obj):
offset = obj.tzinfo.utcoffset(None)
cdef:
int32_t offset_sec = \
<int32_t>cpython.PyLong_AsLong(offset.days) * 24 * 60 * 60 + \
<int32_t>cpython.PyLong_AsLong(offset.seconds)
int64_t seconds = cpython.PyLong_AsLong(obj.hour) * 3600 + \
cpython.PyLong_AsLong(obj.minute) * 60 + \
cpython.PyLong_AsLong(obj.second)
int32_t microseconds = <int32_t>cpython.PyLong_AsLong(obj.microsecond)
buf.write_int32(12)
_encode_time(buf, seconds, microseconds)
# In Python utcoffset() is the difference between the local time
# and the UTC, whereas in PostgreSQL it's the opposite,
# so we need to flip the sign.
buf.write_int32(-offset_sec)
cdef timetz_encode_tuple(CodecContext settings, WriteBuffer buf, obj):
cdef:
int64_t microseconds
int32_t offset_sec
if len(obj) != 2:
raise ValueError(
'time tuple encoder: expecting 2 elements2 '
'in tuple, got {}'.format(len(obj)))
microseconds = obj[0]
offset_sec = obj[1]
buf.write_int32(12)
buf.write_int64(microseconds)
buf.write_int32(offset_sec)
cdef timetz_decode(CodecContext settings, FRBuffer *buf):
time = time_decode(settings, buf)
cdef int32_t offset = <int32_t>(hton.unpack_int32(frb_read(buf, 4)) / 60)
# See the comment in the `timetz_encode` method.
return time.replace(tzinfo=datetime.timezone(timedelta(minutes=-offset)))
cdef timetz_decode_tuple(CodecContext settings, FRBuffer *buf):
cdef:
int64_t microseconds = hton.unpack_int64(frb_read(buf, 8))
int32_t offset_sec = hton.unpack_int32(frb_read(buf, 4))
return (microseconds, offset_sec)
cdef interval_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
int32_t days = <int32_t>cpython.PyLong_AsLong(obj.days)
int64_t seconds = cpython.PyLong_AsLongLong(obj.seconds)
int32_t microseconds = <int32_t>cpython.PyLong_AsLong(obj.microseconds)
buf.write_int32(16)
_encode_time(buf, seconds, microseconds)
buf.write_int32(days)
buf.write_int32(0) # Months
cdef interval_encode_tuple(CodecContext settings, WriteBuffer buf,
tuple obj):
cdef:
int32_t months
int32_t days
int64_t microseconds
if len(obj) != 3:
raise ValueError(
'interval tuple encoder: expecting 3 elements '
'in tuple, got {}'.format(len(obj)))
months = obj[0]
days = obj[1]
microseconds = obj[2]
buf.write_int32(16)
buf.write_int64(microseconds)
buf.write_int32(days)
buf.write_int32(months)
cdef interval_decode(CodecContext settings, FRBuffer *buf):
cdef:
int32_t days
int32_t months
int32_t years
int64_t seconds = 0
int32_t microseconds = 0
_decode_time(buf, &seconds, &microseconds)
days = hton.unpack_int32(frb_read(buf, 4))
months = hton.unpack_int32(frb_read(buf, 4))
if months < 0:
years = -<int32_t>(-months // 12)
months = -<int32_t>(-months % 12)
else:
years = <int32_t>(months // 12)
months = <int32_t>(months % 12)
return datetime.timedelta(days=days + months * 30 + years * 365,
seconds=seconds, microseconds=microseconds)
cdef interval_decode_tuple(CodecContext settings, FRBuffer *buf):
cdef:
int32_t days
int32_t months
int64_t microseconds
microseconds = hton.unpack_int64(frb_read(buf, 8))
days = hton.unpack_int32(frb_read(buf, 4))
months = hton.unpack_int32(frb_read(buf, 4))
return (months, days, microseconds)

View File

@@ -0,0 +1,34 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from libc cimport math
cdef float4_encode(CodecContext settings, WriteBuffer buf, obj):
cdef double dval = cpython.PyFloat_AsDouble(obj)
cdef float fval = <float>dval
if math.isinf(fval) and not math.isinf(dval):
raise ValueError('value out of float32 range')
buf.write_int32(4)
buf.write_float(fval)
cdef float4_decode(CodecContext settings, FRBuffer *buf):
cdef float f = hton.unpack_float(frb_read(buf, 4))
return cpython.PyFloat_FromDouble(f)
cdef float8_encode(CodecContext settings, WriteBuffer buf, obj):
cdef double dval = cpython.PyFloat_AsDouble(obj)
buf.write_int32(8)
buf.write_double(dval)
cdef float8_decode(CodecContext settings, FRBuffer *buf):
cdef double f = hton.unpack_double(frb_read(buf, 8))
return cpython.PyFloat_FromDouble(f)

View File

@@ -0,0 +1,164 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef inline _encode_points(WriteBuffer wbuf, object points):
cdef object point
for point in points:
wbuf.write_double(point[0])
wbuf.write_double(point[1])
cdef inline _decode_points(FRBuffer *buf):
cdef:
int32_t npts = hton.unpack_int32(frb_read(buf, 4))
pts = cpython.PyTuple_New(npts)
int32_t i
object point
double x
double y
for i in range(npts):
x = hton.unpack_double(frb_read(buf, 8))
y = hton.unpack_double(frb_read(buf, 8))
point = pgproto_types.Point(x, y)
cpython.Py_INCREF(point)
cpython.PyTuple_SET_ITEM(pts, i, point)
return pts
cdef box_encode(CodecContext settings, WriteBuffer wbuf, obj):
wbuf.write_int32(32)
_encode_points(wbuf, (obj[0], obj[1]))
cdef box_decode(CodecContext settings, FRBuffer *buf):
cdef:
double high_x = hton.unpack_double(frb_read(buf, 8))
double high_y = hton.unpack_double(frb_read(buf, 8))
double low_x = hton.unpack_double(frb_read(buf, 8))
double low_y = hton.unpack_double(frb_read(buf, 8))
return pgproto_types.Box(
pgproto_types.Point(high_x, high_y),
pgproto_types.Point(low_x, low_y))
cdef line_encode(CodecContext settings, WriteBuffer wbuf, obj):
wbuf.write_int32(24)
wbuf.write_double(obj[0])
wbuf.write_double(obj[1])
wbuf.write_double(obj[2])
cdef line_decode(CodecContext settings, FRBuffer *buf):
cdef:
double A = hton.unpack_double(frb_read(buf, 8))
double B = hton.unpack_double(frb_read(buf, 8))
double C = hton.unpack_double(frb_read(buf, 8))
return pgproto_types.Line(A, B, C)
cdef lseg_encode(CodecContext settings, WriteBuffer wbuf, obj):
wbuf.write_int32(32)
_encode_points(wbuf, (obj[0], obj[1]))
cdef lseg_decode(CodecContext settings, FRBuffer *buf):
cdef:
double p1_x = hton.unpack_double(frb_read(buf, 8))
double p1_y = hton.unpack_double(frb_read(buf, 8))
double p2_x = hton.unpack_double(frb_read(buf, 8))
double p2_y = hton.unpack_double(frb_read(buf, 8))
return pgproto_types.LineSegment((p1_x, p1_y), (p2_x, p2_y))
cdef point_encode(CodecContext settings, WriteBuffer wbuf, obj):
wbuf.write_int32(16)
wbuf.write_double(obj[0])
wbuf.write_double(obj[1])
cdef point_decode(CodecContext settings, FRBuffer *buf):
cdef:
double x = hton.unpack_double(frb_read(buf, 8))
double y = hton.unpack_double(frb_read(buf, 8))
return pgproto_types.Point(x, y)
cdef path_encode(CodecContext settings, WriteBuffer wbuf, obj):
cdef:
int8_t is_closed = 0
ssize_t npts
ssize_t encoded_len
int32_t i
if cpython.PyTuple_Check(obj):
is_closed = 1
elif cpython.PyList_Check(obj):
is_closed = 0
elif isinstance(obj, pgproto_types.Path):
is_closed = obj.is_closed
npts = len(obj)
encoded_len = 1 + 4 + 16 * npts
if encoded_len > _MAXINT32:
raise ValueError('path value too long')
wbuf.write_int32(<int32_t>encoded_len)
wbuf.write_byte(is_closed)
wbuf.write_int32(<int32_t>npts)
_encode_points(wbuf, obj)
cdef path_decode(CodecContext settings, FRBuffer *buf):
cdef:
int8_t is_closed = <int8_t>(frb_read(buf, 1)[0])
return pgproto_types.Path(*_decode_points(buf), is_closed=is_closed == 1)
cdef poly_encode(CodecContext settings, WriteBuffer wbuf, obj):
cdef:
bint is_closed
ssize_t npts
ssize_t encoded_len
int32_t i
npts = len(obj)
encoded_len = 4 + 16 * npts
if encoded_len > _MAXINT32:
raise ValueError('polygon value too long')
wbuf.write_int32(<int32_t>encoded_len)
wbuf.write_int32(<int32_t>npts)
_encode_points(wbuf, obj)
cdef poly_decode(CodecContext settings, FRBuffer *buf):
return pgproto_types.Polygon(*_decode_points(buf))
cdef circle_encode(CodecContext settings, WriteBuffer wbuf, obj):
wbuf.write_int32(24)
wbuf.write_double(obj[0][0])
wbuf.write_double(obj[0][1])
wbuf.write_double(obj[1])
cdef circle_decode(CodecContext settings, FRBuffer *buf):
cdef:
double center_x = hton.unpack_double(frb_read(buf, 8))
double center_y = hton.unpack_double(frb_read(buf, 8))
double radius = hton.unpack_double(frb_read(buf, 8))
return pgproto_types.Circle((center_x, center_y), radius)

View File

@@ -0,0 +1,73 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef hstore_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
char *str
ssize_t size
ssize_t count
object items
WriteBuffer item_buf = WriteBuffer.new()
count = len(obj)
if count > _MAXINT32:
raise ValueError('hstore value is too large')
item_buf.write_int32(<int32_t>count)
if hasattr(obj, 'items'):
items = obj.items()
else:
items = obj
for k, v in items:
if k is None:
raise ValueError('null value not allowed in hstore key')
as_pg_string_and_size(settings, k, &str, &size)
item_buf.write_int32(<int32_t>size)
item_buf.write_cstr(str, size)
if v is None:
item_buf.write_int32(<int32_t>-1)
else:
as_pg_string_and_size(settings, v, &str, &size)
item_buf.write_int32(<int32_t>size)
item_buf.write_cstr(str, size)
buf.write_int32(item_buf.len())
buf.write_buffer(item_buf)
cdef hstore_decode(CodecContext settings, FRBuffer *buf):
cdef:
dict result
uint32_t elem_count
int32_t elem_len
uint32_t i
str k
str v
result = {}
elem_count = <uint32_t>hton.unpack_int32(frb_read(buf, 4))
if elem_count == 0:
return result
for i in range(elem_count):
elem_len = hton.unpack_int32(frb_read(buf, 4))
if elem_len < 0:
raise ValueError('null value not allowed in hstore key')
k = decode_pg_string(settings, frb_read(buf, elem_len), elem_len)
elem_len = hton.unpack_int32(frb_read(buf, 4))
if elem_len < 0:
v = None
else:
v = decode_pg_string(settings, frb_read(buf, elem_len), elem_len)
result[k] = v
return result

View File

@@ -0,0 +1,144 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef bool_encode(CodecContext settings, WriteBuffer buf, obj):
if not cpython.PyBool_Check(obj):
raise TypeError('a boolean is required (got type {})'.format(
type(obj).__name__))
buf.write_int32(1)
buf.write_byte(b'\x01' if obj is True else b'\x00')
cdef bool_decode(CodecContext settings, FRBuffer *buf):
return frb_read(buf, 1)[0] is b'\x01'
cdef int2_encode(CodecContext settings, WriteBuffer buf, obj):
cdef int overflow = 0
cdef long val
try:
if type(obj) is not int and hasattr(type(obj), '__int__'):
# Silence a Python warning about implicit __int__
# conversion.
obj = int(obj)
val = cpython.PyLong_AsLong(obj)
except OverflowError:
overflow = 1
if overflow or val < INT16_MIN or val > INT16_MAX:
raise OverflowError('value out of int16 range')
buf.write_int32(2)
buf.write_int16(<int16_t>val)
cdef int2_decode(CodecContext settings, FRBuffer *buf):
return cpython.PyLong_FromLong(hton.unpack_int16(frb_read(buf, 2)))
cdef int4_encode(CodecContext settings, WriteBuffer buf, obj):
cdef int overflow = 0
cdef long val = 0
try:
if type(obj) is not int and hasattr(type(obj), '__int__'):
# Silence a Python warning about implicit __int__
# conversion.
obj = int(obj)
val = cpython.PyLong_AsLong(obj)
except OverflowError:
overflow = 1
# "long" and "long long" have the same size for x86_64, need an extra check
if overflow or (sizeof(val) > 4 and (val < INT32_MIN or val > INT32_MAX)):
raise OverflowError('value out of int32 range')
buf.write_int32(4)
buf.write_int32(<int32_t>val)
cdef int4_decode(CodecContext settings, FRBuffer *buf):
return cpython.PyLong_FromLong(hton.unpack_int32(frb_read(buf, 4)))
cdef uint4_encode(CodecContext settings, WriteBuffer buf, obj):
cdef int overflow = 0
cdef unsigned long val = 0
try:
if type(obj) is not int and hasattr(type(obj), '__int__'):
# Silence a Python warning about implicit __int__
# conversion.
obj = int(obj)
val = cpython.PyLong_AsUnsignedLong(obj)
except OverflowError:
overflow = 1
# "long" and "long long" have the same size for x86_64, need an extra check
if overflow or (sizeof(val) > 4 and val > UINT32_MAX):
raise OverflowError('value out of uint32 range')
buf.write_int32(4)
buf.write_int32(<int32_t>val)
cdef uint4_decode(CodecContext settings, FRBuffer *buf):
return cpython.PyLong_FromUnsignedLong(
<uint32_t>hton.unpack_int32(frb_read(buf, 4)))
cdef int8_encode(CodecContext settings, WriteBuffer buf, obj):
cdef int overflow = 0
cdef long long val
try:
if type(obj) is not int and hasattr(type(obj), '__int__'):
# Silence a Python warning about implicit __int__
# conversion.
obj = int(obj)
val = cpython.PyLong_AsLongLong(obj)
except OverflowError:
overflow = 1
# Just in case for systems with "long long" bigger than 8 bytes
if overflow or (sizeof(val) > 8 and (val < INT64_MIN or val > INT64_MAX)):
raise OverflowError('value out of int64 range')
buf.write_int32(8)
buf.write_int64(<int64_t>val)
cdef int8_decode(CodecContext settings, FRBuffer *buf):
return cpython.PyLong_FromLongLong(hton.unpack_int64(frb_read(buf, 8)))
cdef uint8_encode(CodecContext settings, WriteBuffer buf, obj):
cdef int overflow = 0
cdef unsigned long long val = 0
try:
if type(obj) is not int and hasattr(type(obj), '__int__'):
# Silence a Python warning about implicit __int__
# conversion.
obj = int(obj)
val = cpython.PyLong_AsUnsignedLongLong(obj)
except OverflowError:
overflow = 1
# Just in case for systems with "long long" bigger than 8 bytes
if overflow or (sizeof(val) > 8 and val > UINT64_MAX):
raise OverflowError('value out of uint64 range')
buf.write_int32(8)
buf.write_int64(<int64_t>val)
cdef uint8_decode(CodecContext settings, FRBuffer *buf):
return cpython.PyLong_FromUnsignedLongLong(
<uint64_t>hton.unpack_int64(frb_read(buf, 8)))

View File

@@ -0,0 +1,57 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef jsonb_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
char *str
ssize_t size
if settings.is_encoding_json():
obj = settings.get_json_encoder().encode(obj)
as_pg_string_and_size(settings, obj, &str, &size)
if size > 0x7fffffff - 1:
raise ValueError('string too long')
buf.write_int32(<int32_t>size + 1)
buf.write_byte(1) # JSONB format version
buf.write_cstr(str, size)
cdef jsonb_decode(CodecContext settings, FRBuffer *buf):
cdef uint8_t format = <uint8_t>(frb_read(buf, 1)[0])
if format != 1:
raise ValueError('unexpected JSONB format: {}'.format(format))
rv = text_decode(settings, buf)
if settings.is_decoding_json():
rv = settings.get_json_decoder().decode(rv)
return rv
cdef json_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
char *str
ssize_t size
if settings.is_encoding_json():
obj = settings.get_json_encoder().encode(obj)
text_encode(settings, buf, obj)
cdef json_decode(CodecContext settings, FRBuffer *buf):
rv = text_decode(settings, buf)
if settings.is_decoding_json():
rv = settings.get_json_decoder().decode(rv)
return rv

View File

@@ -0,0 +1,29 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef jsonpath_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
char *str
ssize_t size
as_pg_string_and_size(settings, obj, &str, &size)
if size > 0x7fffffff - 1:
raise ValueError('string too long')
buf.write_int32(<int32_t>size + 1)
buf.write_byte(1) # jsonpath format version
buf.write_cstr(str, size)
cdef jsonpath_decode(CodecContext settings, FRBuffer *buf):
cdef uint8_t format = <uint8_t>(frb_read(buf, 1)[0])
if format != 1:
raise ValueError('unexpected jsonpath format: {}'.format(format))
return text_decode(settings, buf)

View File

@@ -0,0 +1,16 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef void_encode(CodecContext settings, WriteBuffer buf, obj):
# Void is zero bytes
buf.write_int32(0)
cdef void_decode(CodecContext settings, FRBuffer *buf):
# Do nothing; void will be passed as NULL so this function
# will never be called.
pass

View File

@@ -0,0 +1,139 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
import ipaddress
# defined in postgresql/src/include/inet.h
#
DEF PGSQL_AF_INET = 2 # AF_INET
DEF PGSQL_AF_INET6 = 3 # AF_INET + 1
_ipaddr = ipaddress.ip_address
_ipiface = ipaddress.ip_interface
_ipnet = ipaddress.ip_network
cdef inline uint8_t _ip_max_prefix_len(int32_t family):
# Maximum number of bits in the network prefix of the specified
# IP protocol version.
if family == PGSQL_AF_INET:
return 32
else:
return 128
cdef inline int32_t _ip_addr_len(int32_t family):
# Length of address in bytes for the specified IP protocol version.
if family == PGSQL_AF_INET:
return 4
else:
return 16
cdef inline int8_t _ver_to_family(int32_t version):
if version == 4:
return PGSQL_AF_INET
else:
return PGSQL_AF_INET6
cdef inline _net_encode(WriteBuffer buf, int8_t family, uint32_t bits,
int8_t is_cidr, bytes addr):
cdef:
char *addrbytes
ssize_t addrlen
cpython.PyBytes_AsStringAndSize(addr, &addrbytes, &addrlen)
buf.write_int32(4 + <int32_t>addrlen)
buf.write_byte(family)
buf.write_byte(<int8_t>bits)
buf.write_byte(is_cidr)
buf.write_byte(<int8_t>addrlen)
buf.write_cstr(addrbytes, addrlen)
cdef net_decode(CodecContext settings, FRBuffer *buf, bint as_cidr):
cdef:
int32_t family = <int32_t>frb_read(buf, 1)[0]
uint8_t bits = <uint8_t>frb_read(buf, 1)[0]
int prefix_len
int32_t is_cidr = <int32_t>frb_read(buf, 1)[0]
int32_t addrlen = <int32_t>frb_read(buf, 1)[0]
bytes addr
uint8_t max_prefix_len = _ip_max_prefix_len(family)
if is_cidr != as_cidr:
raise ValueError('unexpected CIDR flag set in non-cidr value')
if family != PGSQL_AF_INET and family != PGSQL_AF_INET6:
raise ValueError('invalid address family in "{}" value'.format(
'cidr' if is_cidr else 'inet'
))
max_prefix_len = _ip_max_prefix_len(family)
if bits > max_prefix_len:
raise ValueError('invalid network prefix length in "{}" value'.format(
'cidr' if is_cidr else 'inet'
))
if addrlen != _ip_addr_len(family):
raise ValueError('invalid address length in "{}" value'.format(
'cidr' if is_cidr else 'inet'
))
addr = cpython.PyBytes_FromStringAndSize(frb_read(buf, addrlen), addrlen)
if as_cidr or bits != max_prefix_len:
prefix_len = cpython.PyLong_FromLong(bits)
if as_cidr:
return _ipnet((addr, prefix_len))
else:
return _ipiface((addr, prefix_len))
else:
return _ipaddr(addr)
cdef cidr_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
object ipnet
int8_t family
ipnet = _ipnet(obj)
family = _ver_to_family(ipnet.version)
_net_encode(buf, family, ipnet.prefixlen, 1, ipnet.network_address.packed)
cdef cidr_decode(CodecContext settings, FRBuffer *buf):
return net_decode(settings, buf, True)
cdef inet_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
object ipaddr
int8_t family
try:
ipaddr = _ipaddr(obj)
except ValueError:
# PostgreSQL accepts *both* CIDR and host values
# for the host datatype.
ipaddr = _ipiface(obj)
family = _ver_to_family(ipaddr.version)
_net_encode(buf, family, ipaddr.network.prefixlen, 1, ipaddr.packed)
else:
family = _ver_to_family(ipaddr.version)
_net_encode(buf, family, _ip_max_prefix_len(family), 0, ipaddr.packed)
cdef inet_decode(CodecContext settings, FRBuffer *buf):
return net_decode(settings, buf, False)

View File

@@ -0,0 +1,356 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from libc.math cimport abs, log10
from libc.stdio cimport snprintf
import decimal
# defined in postgresql/src/backend/utils/adt/numeric.c
DEF DEC_DIGITS = 4
DEF MAX_DSCALE = 0x3FFF
DEF NUMERIC_POS = 0x0000
DEF NUMERIC_NEG = 0x4000
DEF NUMERIC_NAN = 0xC000
DEF NUMERIC_PINF = 0xD000
DEF NUMERIC_NINF = 0xF000
_Dec = decimal.Decimal
cdef numeric_encode_text(CodecContext settings, WriteBuffer buf, obj):
text_encode(settings, buf, str(obj))
cdef numeric_decode_text(CodecContext settings, FRBuffer *buf):
return _Dec(text_decode(settings, buf))
cdef numeric_encode_binary(CodecContext settings, WriteBuffer buf, obj):
cdef:
object dec
object dt
int64_t exponent
int64_t i
int64_t j
tuple pydigits
int64_t num_pydigits
int16_t pgdigit
int64_t num_pgdigits
int16_t dscale
int64_t dweight
int64_t weight
uint16_t sign
int64_t padding_size = 0
if isinstance(obj, _Dec):
dec = obj
else:
dec = _Dec(obj)
dt = dec.as_tuple()
if dt.exponent == 'n' or dt.exponent == 'N':
# NaN
sign = NUMERIC_NAN
num_pgdigits = 0
weight = 0
dscale = 0
elif dt.exponent == 'F':
# Infinity
if dt.sign:
sign = NUMERIC_NINF
else:
sign = NUMERIC_PINF
num_pgdigits = 0
weight = 0
dscale = 0
else:
exponent = dt.exponent
if exponent < 0 and -exponent > MAX_DSCALE:
raise ValueError(
'cannot encode Decimal value into numeric: '
'exponent is too small')
if dt.sign:
sign = NUMERIC_NEG
else:
sign = NUMERIC_POS
pydigits = dt.digits
num_pydigits = len(pydigits)
dweight = num_pydigits + exponent - 1
if dweight >= 0:
weight = (dweight + DEC_DIGITS) // DEC_DIGITS - 1
else:
weight = -((-dweight - 1) // DEC_DIGITS + 1)
if weight > 2 ** 16 - 1:
raise ValueError(
'cannot encode Decimal value into numeric: '
'exponent is too large')
padding_size = \
(weight + 1) * DEC_DIGITS - (dweight + 1)
num_pgdigits = \
(num_pydigits + padding_size + DEC_DIGITS - 1) // DEC_DIGITS
if num_pgdigits > 2 ** 16 - 1:
raise ValueError(
'cannot encode Decimal value into numeric: '
'number of digits is too large')
# Pad decimal digits to provide room for correct Postgres
# digit alignment in the digit computation loop.
pydigits = (0,) * DEC_DIGITS + pydigits + (0,) * DEC_DIGITS
if exponent < 0:
if -exponent > MAX_DSCALE:
raise ValueError(
'cannot encode Decimal value into numeric: '
'exponent is too small')
dscale = <int16_t>-exponent
else:
dscale = 0
buf.write_int32(2 + 2 + 2 + 2 + 2 * <uint16_t>num_pgdigits)
buf.write_int16(<int16_t>num_pgdigits)
buf.write_int16(<int16_t>weight)
buf.write_int16(<int16_t>sign)
buf.write_int16(dscale)
j = DEC_DIGITS - padding_size
for i in range(num_pgdigits):
pgdigit = (pydigits[j] * 1000 + pydigits[j + 1] * 100 +
pydigits[j + 2] * 10 + pydigits[j + 3])
j += DEC_DIGITS
buf.write_int16(pgdigit)
# The decoding strategy here is to form a string representation of
# the numeric var, as it is faster than passing an iterable of digits.
# For this reason the below code is pure overhead and is ~25% slower
# than the simple text decoder above. That said, we need the binary
# decoder to support binary COPY with numeric values.
cdef numeric_decode_binary_ex(
CodecContext settings,
FRBuffer *buf,
bint trail_fract_zero,
):
cdef:
uint16_t num_pgdigits = <uint16_t>hton.unpack_int16(frb_read(buf, 2))
int16_t weight = hton.unpack_int16(frb_read(buf, 2))
uint16_t sign = <uint16_t>hton.unpack_int16(frb_read(buf, 2))
uint16_t dscale = <uint16_t>hton.unpack_int16(frb_read(buf, 2))
int16_t pgdigit0
ssize_t i
int16_t pgdigit
object pydigits
ssize_t num_pydigits
ssize_t actual_num_pydigits
ssize_t buf_size
int64_t exponent
int64_t abs_exponent
ssize_t exponent_chars
ssize_t front_padding = 0
ssize_t num_fract_digits
ssize_t trailing_fract_zeros_adj
char smallbuf[_NUMERIC_DECODER_SMALLBUF_SIZE]
char *charbuf
char *bufptr
bint buf_allocated = False
if sign == NUMERIC_NAN:
# Not-a-number
return _Dec('NaN')
elif sign == NUMERIC_PINF:
# +Infinity
return _Dec('Infinity')
elif sign == NUMERIC_NINF:
# -Infinity
return _Dec('-Infinity')
if num_pgdigits == 0:
# Zero
return _Dec('0e-' + str(dscale))
pgdigit0 = hton.unpack_int16(frb_read(buf, 2))
if weight >= 0:
if pgdigit0 < 10:
front_padding = 3
elif pgdigit0 < 100:
front_padding = 2
elif pgdigit0 < 1000:
front_padding = 1
# The number of fractional decimal digits actually encoded in
# base-DEC_DEIGITS digits sent by Postgres.
num_fract_digits = (num_pgdigits - weight - 1) * DEC_DIGITS
# The trailing zero adjustment necessary to obtain exactly
# dscale number of fractional digits in output. May be negative,
# which indicates that trailing zeros in the last input digit
# should be discarded.
trailing_fract_zeros_adj = dscale - num_fract_digits
# Maximum possible number of decimal digits in base 10.
# The actual number might be up to 3 digits smaller due to
# leading zeros in first input digit.
num_pydigits = num_pgdigits * DEC_DIGITS
if trailing_fract_zeros_adj > 0:
num_pydigits += trailing_fract_zeros_adj
# Exponent.
exponent = (weight + 1) * DEC_DIGITS - front_padding
abs_exponent = abs(exponent)
if abs_exponent != 0:
# Number of characters required to render absolute exponent value
# in decimal.
exponent_chars = <ssize_t>log10(<double>abs_exponent) + 1
else:
exponent_chars = 0
# Output buffer size.
buf_size = (
1 + # sign
1 + # leading zero
1 + # decimal dot
num_pydigits + # digits
1 + # possible trailing zero padding
2 + # exponent indicator (E-,E+)
exponent_chars + # exponent
1 # null terminator char
)
if buf_size > _NUMERIC_DECODER_SMALLBUF_SIZE:
charbuf = <char *>cpython.PyMem_Malloc(<size_t>buf_size)
buf_allocated = True
else:
charbuf = smallbuf
try:
bufptr = charbuf
if sign == NUMERIC_NEG:
bufptr[0] = b'-'
bufptr += 1
bufptr[0] = b'0'
bufptr[1] = b'.'
bufptr += 2
if weight >= 0:
bufptr = _unpack_digit_stripping_lzeros(bufptr, pgdigit0)
else:
bufptr = _unpack_digit(bufptr, pgdigit0)
for i in range(1, num_pgdigits):
pgdigit = hton.unpack_int16(frb_read(buf, 2))
bufptr = _unpack_digit(bufptr, pgdigit)
if dscale:
if trailing_fract_zeros_adj > 0:
for i in range(trailing_fract_zeros_adj):
bufptr[i] = <char>b'0'
# If display scale is _less_ than the number of rendered digits,
# trailing_fract_zeros_adj will be negative and this will strip
# the excess trailing zeros.
bufptr += trailing_fract_zeros_adj
if trail_fract_zero:
# Check if the number of rendered digits matches the exponent,
# and if so, add another trailing zero, so the result always
# appears with a decimal point.
actual_num_pydigits = bufptr - charbuf - 2
if sign == NUMERIC_NEG:
actual_num_pydigits -= 1
if actual_num_pydigits == abs_exponent:
bufptr[0] = <char>b'0'
bufptr += 1
if exponent != 0:
bufptr[0] = b'E'
if exponent < 0:
bufptr[1] = b'-'
else:
bufptr[1] = b'+'
bufptr += 2
snprintf(bufptr, <size_t>exponent_chars + 1, '%d',
<int>abs_exponent)
bufptr += exponent_chars
bufptr[0] = 0
pydigits = cpythonx.PyUnicode_FromString(charbuf)
return _Dec(pydigits)
finally:
if buf_allocated:
cpython.PyMem_Free(charbuf)
cdef numeric_decode_binary(CodecContext settings, FRBuffer *buf):
return numeric_decode_binary_ex(settings, buf, False)
cdef inline char *_unpack_digit_stripping_lzeros(char *buf, int64_t pgdigit):
cdef:
int64_t d
bint significant
d = pgdigit // 1000
significant = (d > 0)
if significant:
pgdigit -= d * 1000
buf[0] = <char>(d + <int32_t>b'0')
buf += 1
d = pgdigit // 100
significant |= (d > 0)
if significant:
pgdigit -= d * 100
buf[0] = <char>(d + <int32_t>b'0')
buf += 1
d = pgdigit // 10
significant |= (d > 0)
if significant:
pgdigit -= d * 10
buf[0] = <char>(d + <int32_t>b'0')
buf += 1
buf[0] = <char>(pgdigit + <int32_t>b'0')
buf += 1
return buf
cdef inline char *_unpack_digit(char *buf, int64_t pgdigit):
cdef:
int64_t d
d = pgdigit // 1000
pgdigit -= d * 1000
buf[0] = <char>(d + <int32_t>b'0')
d = pgdigit // 100
pgdigit -= d * 100
buf[1] = <char>(d + <int32_t>b'0')
d = pgdigit // 10
pgdigit -= d * 10
buf[2] = <char>(d + <int32_t>b'0')
buf[3] = <char>(pgdigit + <int32_t>b'0')
buf += 4
return buf

View File

@@ -0,0 +1,63 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef pg_snapshot_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
ssize_t nxip
uint64_t xmin
uint64_t xmax
int i
WriteBuffer xip_buf = WriteBuffer.new()
if not (cpython.PyTuple_Check(obj) or cpython.PyList_Check(obj)):
raise TypeError(
'list or tuple expected (got type {})'.format(type(obj)))
if len(obj) != 3:
raise ValueError(
'invalid number of elements in txid_snapshot tuple, expecting 4')
nxip = len(obj[2])
if nxip > _MAXINT32:
raise ValueError('txid_snapshot value is too long')
xmin = obj[0]
xmax = obj[1]
for i in range(nxip):
xip_buf.write_int64(
<int64_t>cpython.PyLong_AsUnsignedLongLong(obj[2][i]))
buf.write_int32(20 + xip_buf.len())
buf.write_int32(<int32_t>nxip)
buf.write_int64(<int64_t>xmin)
buf.write_int64(<int64_t>xmax)
buf.write_buffer(xip_buf)
cdef pg_snapshot_decode(CodecContext settings, FRBuffer *buf):
cdef:
int32_t nxip
uint64_t xmin
uint64_t xmax
tuple xip_tup
int32_t i
object xip
nxip = hton.unpack_int32(frb_read(buf, 4))
xmin = <uint64_t>hton.unpack_int64(frb_read(buf, 8))
xmax = <uint64_t>hton.unpack_int64(frb_read(buf, 8))
xip_tup = cpython.PyTuple_New(nxip)
for i in range(nxip):
xip = cpython.PyLong_FromUnsignedLongLong(
<uint64_t>hton.unpack_int64(frb_read(buf, 8)))
cpython.Py_INCREF(xip)
cpython.PyTuple_SET_ITEM(xip_tup, i, xip)
return (xmin, xmax, xip_tup)

View File

@@ -0,0 +1,48 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef inline as_pg_string_and_size(
CodecContext settings, obj, char **cstr, ssize_t *size):
if not cpython.PyUnicode_Check(obj):
raise TypeError('expected str, got {}'.format(type(obj).__name__))
if settings.is_encoding_utf8():
cstr[0] = <char*>cpythonx.PyUnicode_AsUTF8AndSize(obj, size)
else:
encoded = settings.get_text_codec().encode(obj)[0]
cpython.PyBytes_AsStringAndSize(encoded, cstr, size)
if size[0] > 0x7fffffff:
raise ValueError('string too long')
cdef text_encode(CodecContext settings, WriteBuffer buf, obj):
cdef:
char *str
ssize_t size
as_pg_string_and_size(settings, obj, &str, &size)
buf.write_int32(<int32_t>size)
buf.write_cstr(str, size)
cdef inline decode_pg_string(CodecContext settings, const char* data,
ssize_t len):
if settings.is_encoding_utf8():
# decode UTF-8 in strict mode
return cpython.PyUnicode_DecodeUTF8(data, len, NULL)
else:
bytes = cpython.PyBytes_FromStringAndSize(data, len)
return settings.get_text_codec().decode(bytes)[0]
cdef text_decode(CodecContext settings, FRBuffer *buf):
cdef ssize_t buf_len = buf.len
return decode_pg_string(settings, frb_read_all(buf), buf_len)

View File

@@ -0,0 +1,51 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef tid_encode(CodecContext settings, WriteBuffer buf, obj):
cdef int overflow = 0
cdef unsigned long block, offset
if not (cpython.PyTuple_Check(obj) or cpython.PyList_Check(obj)):
raise TypeError(
'list or tuple expected (got type {})'.format(type(obj)))
if len(obj) != 2:
raise ValueError(
'invalid number of elements in tid tuple, expecting 2')
try:
block = cpython.PyLong_AsUnsignedLong(obj[0])
except OverflowError:
overflow = 1
# "long" and "long long" have the same size for x86_64, need an extra check
if overflow or (sizeof(block) > 4 and block > UINT32_MAX):
raise OverflowError('tuple id block value out of uint32 range')
try:
offset = cpython.PyLong_AsUnsignedLong(obj[1])
overflow = 0
except OverflowError:
overflow = 1
if overflow or offset > 65535:
raise OverflowError('tuple id offset value out of uint16 range')
buf.write_int32(6)
buf.write_int32(<int32_t>block)
buf.write_int16(<int16_t>offset)
cdef tid_decode(CodecContext settings, FRBuffer *buf):
cdef:
uint32_t block
uint16_t offset
block = <uint32_t>hton.unpack_int32(frb_read(buf, 4))
offset = <uint16_t>hton.unpack_int16(frb_read(buf, 2))
return (block, offset)

View File

@@ -0,0 +1,27 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef uuid_encode(CodecContext settings, WriteBuffer wbuf, obj):
cdef:
char buf[16]
if type(obj) is pg_UUID:
wbuf.write_int32(<int32_t>16)
wbuf.write_cstr((<UUID>obj)._data, 16)
elif cpython.PyUnicode_Check(obj):
pg_uuid_bytes_from_str(obj, buf)
wbuf.write_int32(<int32_t>16)
wbuf.write_cstr(buf, 16)
else:
bytea_encode(settings, wbuf, obj.bytes)
cdef uuid_decode(CodecContext settings, FRBuffer *buf):
if buf.len != 16:
raise TypeError(
f'cannot decode UUID, expected 16 bytes, got {buf.len}')
return pg_uuid_from_buf(frb_read_all(buf))

View File

@@ -0,0 +1,9 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
DEF _MAXINT32 = 2**31 - 1
DEF _NUMERIC_DECODER_SMALLBUF_SIZE = 256

View File

@@ -0,0 +1,23 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
from cpython cimport Py_buffer
cdef extern from "Python.h":
int PyUnicode_1BYTE_KIND
int PyByteArray_CheckExact(object)
int PyByteArray_Resize(object, ssize_t) except -1
object PyByteArray_FromStringAndSize(const char *, ssize_t)
char* PyByteArray_AsString(object)
object PyUnicode_FromString(const char *u)
const char* PyUnicode_AsUTF8AndSize(
object unicode, ssize_t *size) except NULL
object PyUnicode_FromKindAndData(
int kind, const void *buffer, Py_ssize_t size)

View File

@@ -0,0 +1,10 @@
# Copyright (C) 2016-present the asyncpg authors and contributors
# <see AUTHORS file>
#
# This module is part of asyncpg and is released under
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0
cdef extern from "debug.h":
cdef int PG_DEBUG

Some files were not shown because too many files have changed in this diff Show More