Initial commit: 首次建仓,建立目录结构

This commit is contained in:
FXY
2026-06-11 23:49:54 +08:00
commit 4038a476b5
9396 changed files with 2372905 additions and 0 deletions

View File

@ -0,0 +1 @@
9ef9f04e439a30021acccf8e501e87eba2e9b0d35e506de71716d248b4284043 /home/runner/work/aiohttp/aiohttp/aiohttp/_cparser.pxd

View File

@ -0,0 +1 @@
d067f01423cddb3c442933b5fcc039b18ab651fcec1bc91c577693aafc25cf78 /home/runner/work/aiohttp/aiohttp/aiohttp/_find_header.pxd

View File

@ -0,0 +1 @@
98d2d47a729345990d9575dcf68c550de5d5277d60b2b2385d1ebd8eb0d6c4fc /home/runner/work/aiohttp/aiohttp/aiohttp/_http_parser.pyx

View File

@ -0,0 +1 @@
16686a44bbfeab2bcbd8126f570935dbc545ab060c6f56f5bc72cf0adddac170 /home/runner/work/aiohttp/aiohttp/aiohttp/_http_writer.pyx

View File

@ -0,0 +1 @@
a46ad6c3a2faf8d26a2c6afc1a2210ce379a23f2799fce7b26a01f6ce5a40642 /home/runner/work/aiohttp/aiohttp/aiohttp/hdrs.py

View File

@ -0,0 +1,279 @@
__version__ = "3.14.1"
from typing import TYPE_CHECKING
from . import hdrs as hdrs
from .client import (
BaseConnector,
ClientConnectionError,
ClientConnectionResetError,
ClientConnectorCertificateError,
ClientConnectorDNSError,
ClientConnectorError,
ClientConnectorSSLError,
ClientError,
ClientHttpProxyError,
ClientOSError,
ClientPayloadError,
ClientProxyConnectionError,
ClientRequest,
ClientResponse,
ClientResponseError,
ClientSession,
ClientSSLError,
ClientTimeout,
ClientWebSocketResponse,
ClientWSTimeout,
ConnectionTimeoutError,
ContentTypeError,
Fingerprint,
InvalidURL,
InvalidUrlClientError,
InvalidUrlRedirectClientError,
NamedPipeConnector,
NonHttpUrlClientError,
NonHttpUrlRedirectClientError,
RedirectClientError,
RequestInfo,
ServerConnectionError,
ServerDisconnectedError,
ServerFingerprintMismatch,
ServerTimeoutError,
SocketTimeoutError,
TCPConnector,
TooManyRedirects,
UnixConnector,
WSMessageTypeError,
WSServerHandshakeError,
request,
)
from .client_middleware_digest_auth import DigestAuthMiddleware
from .client_middlewares import ClientHandlerType, ClientMiddlewareType
from .compression_utils import set_zlib_backend
from .connector import (
AddrInfoType as AddrInfoType,
SocketFactoryType as SocketFactoryType,
)
from .cookiejar import CookieJar as CookieJar, DummyCookieJar as DummyCookieJar
from .formdata import FormData as FormData
from .helpers import BasicAuth, ChainMapProxy, ETag, encode_basic_auth
from .http import (
HttpVersion as HttpVersion,
HttpVersion10 as HttpVersion10,
HttpVersion11 as HttpVersion11,
WebSocketError as WebSocketError,
WSCloseCode as WSCloseCode,
WSMessage as WSMessage,
WSMsgType as WSMsgType,
)
from .multipart import (
BadContentDispositionHeader as BadContentDispositionHeader,
BadContentDispositionParam as BadContentDispositionParam,
BodyPartReader as BodyPartReader,
MultipartReader as MultipartReader,
MultipartWriter as MultipartWriter,
content_disposition_filename as content_disposition_filename,
parse_content_disposition as parse_content_disposition,
)
from .payload import (
PAYLOAD_REGISTRY as PAYLOAD_REGISTRY,
AsyncIterablePayload as AsyncIterablePayload,
BufferedReaderPayload as BufferedReaderPayload,
BytesIOPayload as BytesIOPayload,
BytesPayload as BytesPayload,
IOBasePayload as IOBasePayload,
JsonPayload as JsonPayload,
Payload as Payload,
StringIOPayload as StringIOPayload,
StringPayload as StringPayload,
TextIOPayload as TextIOPayload,
get_payload as get_payload,
payload_type as payload_type,
)
from .payload_streamer import streamer as streamer
from .resolver import (
AsyncResolver as AsyncResolver,
DefaultResolver as DefaultResolver,
ThreadedResolver as ThreadedResolver,
)
from .streams import (
EMPTY_PAYLOAD as EMPTY_PAYLOAD,
DataQueue as DataQueue,
EofStream as EofStream,
FlowControlDataQueue as FlowControlDataQueue,
StreamReader as StreamReader,
)
from .tracing import (
TraceConfig as TraceConfig,
TraceConnectionCreateEndParams as TraceConnectionCreateEndParams,
TraceConnectionCreateStartParams as TraceConnectionCreateStartParams,
TraceConnectionQueuedEndParams as TraceConnectionQueuedEndParams,
TraceConnectionQueuedStartParams as TraceConnectionQueuedStartParams,
TraceConnectionReuseconnParams as TraceConnectionReuseconnParams,
TraceDnsCacheHitParams as TraceDnsCacheHitParams,
TraceDnsCacheMissParams as TraceDnsCacheMissParams,
TraceDnsResolveHostEndParams as TraceDnsResolveHostEndParams,
TraceDnsResolveHostStartParams as TraceDnsResolveHostStartParams,
TraceRequestChunkSentParams as TraceRequestChunkSentParams,
TraceRequestEndParams as TraceRequestEndParams,
TraceRequestExceptionParams as TraceRequestExceptionParams,
TraceRequestHeadersSentParams as TraceRequestHeadersSentParams,
TraceRequestRedirectParams as TraceRequestRedirectParams,
TraceRequestStartParams as TraceRequestStartParams,
TraceResponseChunkReceivedParams as TraceResponseChunkReceivedParams,
)
if TYPE_CHECKING:
# At runtime these are lazy-loaded at the bottom of the file.
from .worker import (
GunicornUVLoopWebWorker as GunicornUVLoopWebWorker,
GunicornWebWorker as GunicornWebWorker,
)
__all__: tuple[str, ...] = (
"hdrs",
# client
"AddrInfoType",
"BaseConnector",
"ClientConnectionError",
"ClientConnectionResetError",
"ClientConnectorCertificateError",
"ClientConnectorDNSError",
"ClientConnectorError",
"ClientConnectorSSLError",
"ClientError",
"ClientHttpProxyError",
"ClientOSError",
"ClientPayloadError",
"ClientProxyConnectionError",
"ClientResponse",
"ClientRequest",
"ClientResponseError",
"ClientSSLError",
"ClientSession",
"ClientTimeout",
"ClientWebSocketResponse",
"ClientWSTimeout",
"ConnectionTimeoutError",
"ContentTypeError",
"Fingerprint",
"FlowControlDataQueue",
"InvalidURL",
"InvalidUrlClientError",
"InvalidUrlRedirectClientError",
"NonHttpUrlClientError",
"NonHttpUrlRedirectClientError",
"RedirectClientError",
"RequestInfo",
"ServerConnectionError",
"ServerDisconnectedError",
"ServerFingerprintMismatch",
"ServerTimeoutError",
"SocketFactoryType",
"SocketTimeoutError",
"TCPConnector",
"TooManyRedirects",
"UnixConnector",
"NamedPipeConnector",
"WSServerHandshakeError",
"request",
# client_middleware
"ClientMiddlewareType",
"ClientHandlerType",
# cookiejar
"CookieJar",
"DummyCookieJar",
# formdata
"FormData",
# helpers
"BasicAuth",
"ChainMapProxy",
"DigestAuthMiddleware",
"ETag",
"encode_basic_auth",
"set_zlib_backend",
# http
"HttpVersion",
"HttpVersion10",
"HttpVersion11",
"WSMsgType",
"WSCloseCode",
"WSMessage",
"WebSocketError",
# multipart
"BadContentDispositionHeader",
"BadContentDispositionParam",
"BodyPartReader",
"MultipartReader",
"MultipartWriter",
"content_disposition_filename",
"parse_content_disposition",
# payload
"AsyncIterablePayload",
"BufferedReaderPayload",
"BytesIOPayload",
"BytesPayload",
"IOBasePayload",
"JsonPayload",
"PAYLOAD_REGISTRY",
"Payload",
"StringIOPayload",
"StringPayload",
"TextIOPayload",
"get_payload",
"payload_type",
# payload_streamer
"streamer",
# resolver
"AsyncResolver",
"DefaultResolver",
"ThreadedResolver",
# streams
"DataQueue",
"EMPTY_PAYLOAD",
"EofStream",
"StreamReader",
# tracing
"TraceConfig",
"TraceConnectionCreateEndParams",
"TraceConnectionCreateStartParams",
"TraceConnectionQueuedEndParams",
"TraceConnectionQueuedStartParams",
"TraceConnectionReuseconnParams",
"TraceDnsCacheHitParams",
"TraceDnsCacheMissParams",
"TraceDnsResolveHostEndParams",
"TraceDnsResolveHostStartParams",
"TraceRequestChunkSentParams",
"TraceRequestEndParams",
"TraceRequestExceptionParams",
"TraceRequestHeadersSentParams",
"TraceRequestRedirectParams",
"TraceRequestStartParams",
"TraceResponseChunkReceivedParams",
# workers (imported lazily with __getattr__)
"GunicornUVLoopWebWorker",
"GunicornWebWorker",
"WSMessageTypeError",
)
def __dir__() -> tuple[str, ...]:
return __all__ + ("__doc__",)
def __getattr__(name: str) -> object:
global GunicornUVLoopWebWorker, GunicornWebWorker
# Importing gunicorn takes a long time (>100ms), so only import if actually needed.
if name in ("GunicornUVLoopWebWorker", "GunicornWebWorker"):
try:
from .worker import GunicornUVLoopWebWorker as guv, GunicornWebWorker as gw
except ImportError:
return None
GunicornUVLoopWebWorker = guv # type: ignore[misc]
GunicornWebWorker = gw # type: ignore[misc]
return guv if name == "GunicornUVLoopWebWorker" else gw
raise AttributeError(f"module {__name__} has no attribute {name}")

View File

@ -0,0 +1,361 @@
"""
Internal cookie handling helpers.
This module contains internal utilities for cookie parsing and manipulation.
These are not part of the public API and may change without notice.
"""
import re
from collections.abc import Sequence
from http.cookies import CookieError, Morsel
from typing import cast
from .log import internal_logger
__all__ = (
"parse_set_cookie_headers",
"parse_cookie_header",
"preserve_morsel_with_coded_value",
)
# Cookie parsing constants
# Allow more characters in cookie names to handle real-world cookies
# that don't strictly follow RFC standards (fixes #2683)
# RFC 6265 defines cookie-name token as per RFC 2616 Section 2.2,
# but many servers send cookies with characters like {} [] () etc.
# This makes the cookie parser more tolerant of real-world cookies
# while still providing some validation to catch obviously malformed names.
_COOKIE_NAME_RE = re.compile(r"^[!#$%&\'()*+\-./0-9:<=>?@A-Z\[\]^_`a-z{|}~]+$")
_COOKIE_KNOWN_ATTRS = frozenset( # AKA Morsel._reserved
(
"path",
"domain",
"max-age",
"expires",
"secure",
"httponly",
"samesite",
"partitioned",
"version",
"comment",
)
)
_COOKIE_BOOL_ATTRS = frozenset( # AKA Morsel._flags
("secure", "httponly", "partitioned")
)
# SimpleCookie's pattern for parsing cookies with relaxed validation
# Based on http.cookies pattern but extended to allow more characters in cookie names
# to handle real-world cookies (fixes #2683)
_COOKIE_PATTERN = re.compile(
r"""
\s* # Optional whitespace at start of cookie
(?P<key> # Start of group 'key'
# aiohttp has extended to include [] for compatibility with real-world cookies
[\w\d!#%&'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\[\]]+ # Any word of at least one letter
) # End of group 'key'
( # Optional group: there may not be a value.
\s*=\s* # Equal Sign
(?P<val> # Start of group 'val'
"(?:[^\\"]|\\.)*" # Any double-quoted string (properly closed)
| # or
"[^";]* # Unmatched opening quote (differs from SimpleCookie - issue #7993)
| # or
# Special case for "expires" attr - RFC 822, RFC 850, RFC 1036, RFC 1123
(\w{3,6}day|\w{3}),\s # Day of the week or abbreviated day (with comma)
[\w\d\s-]{9,11}\s[\d:]{8}\s # Date and time in specific format
(GMT|[+-]\d{4}) # Timezone: GMT or RFC 2822 offset like -0000, +0100
# NOTE: RFC 2822 timezone support is an aiohttp extension
# for issue #4493 - SimpleCookie does NOT support this
| # or
# ANSI C asctime() format: "Wed Jun 9 10:18:14 2021"
# NOTE: This is an aiohttp extension for issue #4327 - SimpleCookie does NOT support this format
\w{3}\s+\w{3}\s+[\s\d]\d\s+\d{2}:\d{2}:\d{2}\s+\d{4}
| # or
[\w\d!#%&'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\=\[\]]* # Any word or empty string
) # End of group 'val'
)? # End of optional value group
\s* # Any number of spaces.
(\s+|;|$) # Ending either at space, semicolon, or EOS.
""",
re.VERBOSE | re.ASCII,
)
def preserve_morsel_with_coded_value(cookie: Morsel[str]) -> Morsel[str]:
"""
Preserve a Morsel's coded_value exactly as received from the server.
This function ensures that cookie encoding is preserved exactly as sent by
the server, which is critical for compatibility with old servers that have
strict requirements about cookie formats.
This addresses the issue described in https://github.com/aio-libs/aiohttp/pull/1453
where Python's SimpleCookie would re-encode cookies, breaking authentication
with certain servers.
Args:
cookie: A Morsel object from SimpleCookie
Returns:
A Morsel object with preserved coded_value
"""
mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel()))
# We use __setstate__ instead of the public set() API because it allows us to
# bypass validation and set already validated state. This is more stable than
# setting protected attributes directly and unlikely to change since it would
# break pickling.
try:
mrsl_val.__setstate__( # type: ignore[attr-defined]
{
"key": cookie.key,
"value": cookie.value,
"coded_value": cookie.coded_value,
}
)
except CookieError:
return cookie
return mrsl_val
_unquote_sub = re.compile(r"\\(?:([0-3][0-7][0-7])|(.))").sub
def _unquote_replace(m: re.Match[str]) -> str:
"""
Replace function for _unquote_sub regex substitution.
Handles escaped characters in cookie values:
- Octal sequences are converted to their character representation
- Other escaped characters are unescaped by removing the backslash
"""
if m[1]:
return chr(int(m[1], 8))
return m[2]
def _unquote(value: str) -> str:
"""
Unquote a cookie value.
Vendored from http.cookies._unquote to ensure compatibility.
Note: The original implementation checked for None, but we've removed
that check since all callers already ensure the value is not None.
"""
# If there aren't any doublequotes,
# then there can't be any special characters. See RFC 2109.
if len(value) < 2:
return value
if value[0] != '"' or value[-1] != '"':
return value
# We have to assume that we must decode this string.
# Down to work.
# Remove the "s
value = value[1:-1]
# Check for special sequences. Examples:
# \012 --> \n
# \" --> "
#
return _unquote_sub(_unquote_replace, value)
def parse_cookie_header(header: str) -> list[tuple[str, Morsel[str]]]:
"""
Parse a Cookie header according to RFC 6265 Section 5.4.
Cookie headers contain only name-value pairs separated by semicolons.
There are no attributes in Cookie headers - even names that match
attribute names (like 'path' or 'secure') should be treated as cookies.
This parser uses the same regex-based approach as parse_set_cookie_headers
to properly handle quoted values that may contain semicolons. When the
regex fails to match a malformed cookie, it falls back to simple parsing
to ensure subsequent cookies are not lost
https://github.com/aio-libs/aiohttp/issues/11632
Args:
header: The Cookie header value to parse
Returns:
List of (name, Morsel) tuples for compatibility with SimpleCookie.update()
"""
if not header:
return []
cookies: list[tuple[str, Morsel[str]]] = []
morsel: Morsel[str]
i = 0
n = len(header)
invalid_names = []
while i < n:
# Use the same pattern as parse_set_cookie_headers to find cookies
match = _COOKIE_PATTERN.match(header, i)
if not match:
# Fallback for malformed cookies https://github.com/aio-libs/aiohttp/issues/11632
# Find next semicolon to skip or attempt simple key=value parsing
next_semi = header.find(";", i)
eq_pos = header.find("=", i)
# Try to extract key=value if '=' comes before ';'
if eq_pos != -1 and (next_semi == -1 or eq_pos < next_semi):
end_pos = next_semi if next_semi != -1 else n
key = header[i:eq_pos].strip()
value = header[eq_pos + 1 : end_pos].strip()
# Validate the name (same as regex path)
if not _COOKIE_NAME_RE.match(key):
invalid_names.append(key)
else:
morsel = Morsel()
try:
morsel.__setstate__( # type: ignore[attr-defined]
{
"key": key,
"value": _unquote(value),
"coded_value": value,
}
)
except CookieError:
pass
else:
cookies.append((key, morsel))
# Move to next cookie or end
i = next_semi + 1 if next_semi != -1 else n
continue
key = match.group("key")
value = match.group("val") or ""
i = match.end(0)
# Validate the name
if not key or not _COOKIE_NAME_RE.match(key):
invalid_names.append(key)
continue
# Create new morsel
morsel = Morsel()
# Preserve the original value as coded_value (with quotes if present)
# We use __setstate__ instead of the public set() API because it allows us to
# bypass validation and set already validated state. This is more stable than
# setting protected attributes directly and unlikely to change since it would
# break pickling.
try:
morsel.__setstate__( # type: ignore[attr-defined]
{"key": key, "value": _unquote(value), "coded_value": value}
)
except CookieError:
continue
cookies.append((key, morsel))
if invalid_names:
internal_logger.debug(
"Cannot load cookie. Illegal cookie names: %r", invalid_names
)
return cookies
def parse_set_cookie_headers(headers: Sequence[str]) -> list[tuple[str, Morsel[str]]]:
"""
Parse cookie headers using a vendored version of SimpleCookie parsing.
This implementation is based on SimpleCookie.__parse_string to ensure
compatibility with how SimpleCookie parses cookies, including handling
of malformed cookies with missing semicolons.
This function is used for both Cookie and Set-Cookie headers in order to be
forgiving. Ideally we would have followed RFC 6265 Section 5.2 (for Cookie
headers) and RFC 6265 Section 4.2.1 (for Set-Cookie headers), but the
real world data makes it impossible since we need to be a bit more forgiving.
NOTE: This implementation differs from SimpleCookie in handling unmatched quotes.
SimpleCookie will stop parsing when it encounters a cookie value with an unmatched
quote (e.g., 'cookie="value'), causing subsequent cookies to be silently dropped.
This implementation handles unmatched quotes more gracefully to prevent cookie loss.
See https://github.com/aio-libs/aiohttp/issues/7993
"""
parsed_cookies: list[tuple[str, Morsel[str]]] = []
for header in headers:
if not header:
continue
# Parse cookie string using SimpleCookie's algorithm
i = 0
n = len(header)
current_morsel: Morsel[str] | None = None
morsel_seen = False
while 0 <= i < n:
# Start looking for a cookie
match = _COOKIE_PATTERN.match(header, i)
if not match:
# No more cookies
break
key, value = match.group("key"), match.group("val")
i = match.end(0)
lower_key = key.lower()
if key[0] == "$":
if not morsel_seen:
# We ignore attributes which pertain to the cookie
# mechanism as a whole, such as "$Version".
continue
# Process as attribute
if current_morsel is not None:
attr_lower_key = lower_key[1:]
if attr_lower_key in _COOKIE_KNOWN_ATTRS:
current_morsel[attr_lower_key] = value or ""
elif lower_key in _COOKIE_KNOWN_ATTRS:
if not morsel_seen:
# Invalid cookie string - attribute before cookie
break
if lower_key in _COOKIE_BOOL_ATTRS:
# Boolean attribute with any value should be True
if current_morsel is not None and current_morsel.isReservedKey(key):
current_morsel[lower_key] = True
elif value is None:
# Invalid cookie string - non-boolean attribute without value
break
elif current_morsel is not None:
# Regular attribute with value
current_morsel[lower_key] = _unquote(value)
elif value is not None:
# This is a cookie name=value pair
# Validate the name
if key in _COOKIE_KNOWN_ATTRS or not _COOKIE_NAME_RE.match(key):
internal_logger.warning(
"Can not load cookies: Illegal cookie name %r", key
)
current_morsel = None
else:
# Create new morsel
current_morsel = Morsel()
# Preserve the original value as coded_value (with quotes if present)
try:
current_morsel.__setstate__( # type: ignore[attr-defined]
{
"key": key,
"value": _unquote(value),
"coded_value": value,
}
)
except CookieError:
current_morsel = None
else:
parsed_cookies.append((key, current_morsel))
morsel_seen = True
else:
# Invalid cookie string - no value for non-attribute
break
return parsed_cookies

View File

@ -0,0 +1,159 @@
from libc.stdint cimport int32_t, uint8_t, uint16_t, uint64_t
cdef extern from "llhttp.h":
struct llhttp__internal_s:
int32_t _index
void* _span_pos0
void* _span_cb0
int32_t error
const char* reason
const char* error_pos
void* data
void* _current
uint64_t content_length
uint8_t type
uint8_t method
uint8_t http_major
uint8_t http_minor
uint8_t header_state
uint8_t lenient_flags
uint8_t upgrade
uint8_t finish
uint16_t flags
uint16_t status_code
void* settings
ctypedef llhttp__internal_s llhttp__internal_t
ctypedef llhttp__internal_t llhttp_t
ctypedef int (*llhttp_data_cb)(llhttp_t*, const char *at, size_t length) except -1
ctypedef int (*llhttp_cb)(llhttp_t*) except -1
struct llhttp_settings_s:
llhttp_cb on_message_begin
llhttp_data_cb on_url
llhttp_data_cb on_status
llhttp_data_cb on_header_field
llhttp_data_cb on_header_value
llhttp_cb on_headers_complete
llhttp_data_cb on_body
llhttp_cb on_message_complete
llhttp_cb on_chunk_header
llhttp_cb on_chunk_complete
llhttp_cb on_url_complete
llhttp_cb on_status_complete
llhttp_cb on_header_field_complete
llhttp_cb on_header_value_complete
ctypedef llhttp_settings_s llhttp_settings_t
enum llhttp_errno:
HPE_OK,
HPE_INTERNAL,
HPE_STRICT,
HPE_LF_EXPECTED,
HPE_UNEXPECTED_CONTENT_LENGTH,
HPE_CLOSED_CONNECTION,
HPE_INVALID_METHOD,
HPE_INVALID_URL,
HPE_INVALID_CONSTANT,
HPE_INVALID_VERSION,
HPE_INVALID_HEADER_TOKEN,
HPE_INVALID_CONTENT_LENGTH,
HPE_INVALID_CHUNK_SIZE,
HPE_INVALID_STATUS,
HPE_INVALID_EOF_STATE,
HPE_INVALID_TRANSFER_ENCODING,
HPE_CB_MESSAGE_BEGIN,
HPE_CB_HEADERS_COMPLETE,
HPE_CB_MESSAGE_COMPLETE,
HPE_CB_CHUNK_HEADER,
HPE_CB_CHUNK_COMPLETE,
HPE_PAUSED,
HPE_PAUSED_UPGRADE,
HPE_USER
ctypedef llhttp_errno llhttp_errno_t
enum llhttp_flags:
F_CHUNKED,
F_CONTENT_LENGTH
enum llhttp_type:
HTTP_REQUEST,
HTTP_RESPONSE,
HTTP_BOTH
enum llhttp_method:
HTTP_DELETE,
HTTP_GET,
HTTP_HEAD,
HTTP_POST,
HTTP_PUT,
HTTP_CONNECT,
HTTP_OPTIONS,
HTTP_TRACE,
HTTP_COPY,
HTTP_LOCK,
HTTP_MKCOL,
HTTP_MOVE,
HTTP_PROPFIND,
HTTP_PROPPATCH,
HTTP_SEARCH,
HTTP_UNLOCK,
HTTP_BIND,
HTTP_REBIND,
HTTP_UNBIND,
HTTP_ACL,
HTTP_REPORT,
HTTP_MKACTIVITY,
HTTP_CHECKOUT,
HTTP_MERGE,
HTTP_MSEARCH,
HTTP_NOTIFY,
HTTP_SUBSCRIBE,
HTTP_UNSUBSCRIBE,
HTTP_PATCH,
HTTP_PURGE,
HTTP_MKCALENDAR,
HTTP_LINK,
HTTP_UNLINK,
HTTP_SOURCE,
HTTP_PRI,
HTTP_DESCRIBE,
HTTP_ANNOUNCE,
HTTP_SETUP,
HTTP_PLAY,
HTTP_PAUSE,
HTTP_TEARDOWN,
HTTP_GET_PARAMETER,
HTTP_SET_PARAMETER,
HTTP_REDIRECT,
HTTP_RECORD,
HTTP_FLUSH
ctypedef llhttp_method llhttp_method_t;
void llhttp_settings_init(llhttp_settings_t* settings)
void llhttp_init(llhttp_t* parser, llhttp_type type,
const llhttp_settings_t* settings)
llhttp_errno_t llhttp_execute(llhttp_t* parser, const char* data, size_t len)
int llhttp_should_keep_alive(const llhttp_t* parser)
void llhttp_resume(llhttp_t* parser)
void llhttp_resume_after_upgrade(llhttp_t* parser)
llhttp_errno_t llhttp_get_errno(const llhttp_t* parser)
const char* llhttp_get_error_reason(const llhttp_t* parser)
const char* llhttp_get_error_pos(const llhttp_t* parser)
const char* llhttp_method_name(llhttp_method_t method)
void llhttp_set_lenient_headers(llhttp_t* parser, int enabled)
void llhttp_set_lenient_optional_cr_before_lf(llhttp_t* parser, int enabled)
void llhttp_set_lenient_spaces_after_chunk_size(llhttp_t* parser, int enabled)

View File

@ -0,0 +1,2 @@
cdef extern from "_find_header.h":
int find_header(char *, int)

View File

@ -0,0 +1,83 @@
# The file is autogenerated from aiohttp/hdrs.py
# Run ./tools/gen.py to update it after the origin changing.
from . import hdrs
cdef tuple headers = (
hdrs.ACCEPT,
hdrs.ACCEPT_CHARSET,
hdrs.ACCEPT_ENCODING,
hdrs.ACCEPT_LANGUAGE,
hdrs.ACCEPT_RANGES,
hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
hdrs.ACCESS_CONTROL_ALLOW_HEADERS,
hdrs.ACCESS_CONTROL_ALLOW_METHODS,
hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
hdrs.ACCESS_CONTROL_MAX_AGE,
hdrs.ACCESS_CONTROL_REQUEST_HEADERS,
hdrs.ACCESS_CONTROL_REQUEST_METHOD,
hdrs.AGE,
hdrs.ALLOW,
hdrs.AUTHORIZATION,
hdrs.CACHE_CONTROL,
hdrs.CONNECTION,
hdrs.CONTENT_DISPOSITION,
hdrs.CONTENT_ENCODING,
hdrs.CONTENT_LANGUAGE,
hdrs.CONTENT_LENGTH,
hdrs.CONTENT_LOCATION,
hdrs.CONTENT_MD5,
hdrs.CONTENT_RANGE,
hdrs.CONTENT_TRANSFER_ENCODING,
hdrs.CONTENT_TYPE,
hdrs.COOKIE,
hdrs.DATE,
hdrs.DESTINATION,
hdrs.DIGEST,
hdrs.ETAG,
hdrs.EXPECT,
hdrs.EXPIRES,
hdrs.FORWARDED,
hdrs.FROM,
hdrs.HOST,
hdrs.IF_MATCH,
hdrs.IF_MODIFIED_SINCE,
hdrs.IF_NONE_MATCH,
hdrs.IF_RANGE,
hdrs.IF_UNMODIFIED_SINCE,
hdrs.KEEP_ALIVE,
hdrs.LAST_EVENT_ID,
hdrs.LAST_MODIFIED,
hdrs.LINK,
hdrs.LOCATION,
hdrs.MAX_FORWARDS,
hdrs.ORIGIN,
hdrs.PRAGMA,
hdrs.PROXY_AUTHENTICATE,
hdrs.PROXY_AUTHORIZATION,
hdrs.RANGE,
hdrs.REFERER,
hdrs.RETRY_AFTER,
hdrs.SEC_WEBSOCKET_ACCEPT,
hdrs.SEC_WEBSOCKET_EXTENSIONS,
hdrs.SEC_WEBSOCKET_KEY,
hdrs.SEC_WEBSOCKET_KEY1,
hdrs.SEC_WEBSOCKET_PROTOCOL,
hdrs.SEC_WEBSOCKET_VERSION,
hdrs.SERVER,
hdrs.SET_COOKIE,
hdrs.TE,
hdrs.TRAILER,
hdrs.TRANSFER_ENCODING,
hdrs.URI,
hdrs.UPGRADE,
hdrs.USER_AGENT,
hdrs.VARY,
hdrs.VIA,
hdrs.WWW_AUTHENTICATE,
hdrs.WANT_DIGEST,
hdrs.WARNING,
hdrs.X_FORWARDED_FOR,
hdrs.X_FORWARDED_HOST,
hdrs.X_FORWARDED_PROTO,
)

View File

@ -0,0 +1,965 @@
# Based on https://github.com/MagicStack/httptools
#
from cpython cimport (
Py_buffer,
PyBUF_SIMPLE,
PyBuffer_Release,
PyBytes_AsString,
PyBytes_AsStringAndSize,
PyObject_GetBuffer,
)
from cpython.mem cimport PyMem_Free, PyMem_Malloc
from libc.limits cimport ULLONG_MAX
from libc.string cimport memcpy
from multidict import CIMultiDict as _CIMultiDict, CIMultiDictProxy as _CIMultiDictProxy
from yarl import URL as _URL
from aiohttp import hdrs
from aiohttp.helpers import DEBUG, set_exception
from .http_exceptions import (
BadHttpMessage,
BadHttpMethod,
BadStatusLine,
ContentLengthError,
InvalidHeader,
InvalidURLError,
LineTooLong,
PayloadEncodingError,
TransferEncodingError,
)
from .http_parser import DeflateBuffer as _DeflateBuffer
from .http_writer import (
HttpVersion as _HttpVersion,
HttpVersion10 as _HttpVersion10,
HttpVersion11 as _HttpVersion11,
)
from .streams import EMPTY_PAYLOAD as _EMPTY_PAYLOAD, StreamReader as _StreamReader
cimport cython
from aiohttp cimport _cparser as cparser
include "_headers.pxi"
from aiohttp cimport _find_header
cdef frozenset ALLOWED_UPGRADES = frozenset({"websocket"})
DEF DEFAULT_FREELIST_SIZE = 250
cdef extern from "Python.h":
int PyByteArray_Resize(object, Py_ssize_t) except -1
Py_ssize_t PyByteArray_Size(object) except -1
char* PyByteArray_AsString(object)
__all__ = ('HttpRequestParser', 'HttpResponseParser',
'RawRequestMessage', 'RawResponseMessage')
cdef object URL = _URL
cdef object URL_build = URL.build
cdef object CIMultiDict = _CIMultiDict
cdef object CIMultiDictProxy = _CIMultiDictProxy
cdef object HttpVersion = _HttpVersion
cdef object HttpVersion10 = _HttpVersion10
cdef object HttpVersion11 = _HttpVersion11
cdef object SEC_WEBSOCKET_KEY1 = hdrs.SEC_WEBSOCKET_KEY1
cdef object CONTENT_ENCODING = hdrs.CONTENT_ENCODING
cdef object EMPTY_PAYLOAD = _EMPTY_PAYLOAD
cdef object StreamReader = _StreamReader
cdef object DeflateBuffer = _DeflateBuffer
cdef tuple EMPTY_FEED_DATA_RESULT = ((), False, b"")
# RFC 9110 singleton headers — duplicates are rejected in strict mode.
# In lax mode (response parser default), the check is skipped entirely
# since real-world servers (e.g. Google APIs, Werkzeug) commonly send
# duplicate headers like Content-Type or Server.
cdef frozenset SINGLETON_HEADERS = frozenset({
hdrs.CONTENT_LENGTH,
hdrs.CONTENT_LOCATION,
hdrs.CONTENT_RANGE,
hdrs.CONTENT_TYPE,
hdrs.ETAG,
hdrs.HOST,
hdrs.MAX_FORWARDS,
hdrs.SERVER,
hdrs.TRANSFER_ENCODING,
hdrs.USER_AGENT,
})
cdef inline object extend(object buf, const char* at, size_t length):
cdef Py_ssize_t s
cdef char* ptr
s = PyByteArray_Size(buf)
PyByteArray_Resize(buf, s + length)
ptr = PyByteArray_AsString(buf)
memcpy(ptr + s, at, length)
DEF METHODS_COUNT = 46;
cdef list _http_method = []
for i in range(METHODS_COUNT):
_http_method.append(
cparser.llhttp_method_name(<cparser.llhttp_method_t> i).decode('ascii'))
cdef inline str http_method_str(int i):
if i < METHODS_COUNT:
return <str>_http_method[i]
else:
return "<unknown>"
cdef inline object find_header(bytes raw_header):
cdef Py_ssize_t size
cdef char *buf
cdef int idx
PyBytes_AsStringAndSize(raw_header, &buf, &size)
idx = _find_header.find_header(buf, size)
if idx == -1:
return raw_header.decode('utf-8', 'surrogateescape')
return headers[idx]
@cython.freelist(DEFAULT_FREELIST_SIZE)
cdef class RawRequestMessage:
cdef readonly str method
cdef readonly str path
cdef readonly object version # HttpVersion
cdef readonly object headers # CIMultiDict
cdef readonly object raw_headers # tuple
cdef readonly object should_close
cdef readonly object compression
cdef readonly object upgrade
cdef readonly object chunked
cdef readonly object url # yarl.URL
def __init__(self, method, path, version, headers, raw_headers,
should_close, compression, upgrade, chunked, url):
self.method = method
self.path = path
self.version = version
self.headers = headers
self.raw_headers = raw_headers
self.should_close = should_close
self.compression = compression
self.upgrade = upgrade
self.chunked = chunked
self.url = url
def __repr__(self):
info = []
info.append(("method", self.method))
info.append(("path", self.path))
info.append(("version", self.version))
info.append(("headers", self.headers))
info.append(("raw_headers", self.raw_headers))
info.append(("should_close", self.should_close))
info.append(("compression", self.compression))
info.append(("upgrade", self.upgrade))
info.append(("chunked", self.chunked))
info.append(("url", self.url))
sinfo = ', '.join(name + '=' + repr(val) for name, val in info)
return '<RawRequestMessage(' + sinfo + ')>'
def _replace(self, **dct):
cdef RawRequestMessage ret
ret = _new_request_message(self.method,
self.path,
self.version,
self.headers,
self.raw_headers,
self.should_close,
self.compression,
self.upgrade,
self.chunked,
self.url)
if "method" in dct:
ret.method = dct["method"]
if "path" in dct:
ret.path = dct["path"]
if "version" in dct:
ret.version = dct["version"]
if "headers" in dct:
ret.headers = dct["headers"]
if "raw_headers" in dct:
ret.raw_headers = dct["raw_headers"]
if "should_close" in dct:
ret.should_close = dct["should_close"]
if "compression" in dct:
ret.compression = dct["compression"]
if "upgrade" in dct:
ret.upgrade = dct["upgrade"]
if "chunked" in dct:
ret.chunked = dct["chunked"]
if "url" in dct:
ret.url = dct["url"]
return ret
cdef _new_request_message(str method,
str path,
object version,
object headers,
object raw_headers,
bint should_close,
object compression,
bint upgrade,
bint chunked,
object url):
cdef RawRequestMessage ret
ret = RawRequestMessage.__new__(RawRequestMessage)
ret.method = method
ret.path = path
ret.version = version
ret.headers = headers
ret.raw_headers = raw_headers
ret.should_close = should_close
ret.compression = compression
ret.upgrade = upgrade
ret.chunked = chunked
ret.url = url
return ret
@cython.freelist(DEFAULT_FREELIST_SIZE)
cdef class RawResponseMessage:
cdef readonly object version # HttpVersion
cdef readonly int code
cdef readonly str reason
cdef readonly object headers # CIMultiDict
cdef readonly object raw_headers # tuple
cdef readonly object should_close
cdef readonly object compression
cdef readonly object upgrade
cdef readonly object chunked
def __init__(self, version, code, reason, headers, raw_headers,
should_close, compression, upgrade, chunked):
self.version = version
self.code = code
self.reason = reason
self.headers = headers
self.raw_headers = raw_headers
self.should_close = should_close
self.compression = compression
self.upgrade = upgrade
self.chunked = chunked
def __repr__(self):
info = []
info.append(("version", self.version))
info.append(("code", self.code))
info.append(("reason", self.reason))
info.append(("headers", self.headers))
info.append(("raw_headers", self.raw_headers))
info.append(("should_close", self.should_close))
info.append(("compression", self.compression))
info.append(("upgrade", self.upgrade))
info.append(("chunked", self.chunked))
sinfo = ', '.join(name + '=' + repr(val) for name, val in info)
return '<RawResponseMessage(' + sinfo + ')>'
cdef _new_response_message(object version,
int code,
str reason,
object headers,
object raw_headers,
bint should_close,
object compression,
bint upgrade,
bint chunked):
cdef RawResponseMessage ret
ret = RawResponseMessage.__new__(RawResponseMessage)
ret.version = version
ret.code = code
ret.reason = reason
ret.headers = headers
ret.raw_headers = raw_headers
ret.should_close = should_close
ret.compression = compression
ret.upgrade = upgrade
ret.chunked = chunked
return ret
@cython.internal
cdef class HttpParser:
cdef:
cparser.llhttp_t* _cparser
cparser.llhttp_settings_t* _csettings
bytes _raw_name
object _name
bytes _raw_value
bint _has_value
int _header_name_size
readonly object protocol
object _loop
object _timer
size_t _max_line_size
size_t _max_field_size
size_t _max_headers
bint _response_with_body
bint _read_until_eof
bint _lax
bytes _tail
bint _started
object _url
bytearray _buf
str _path
str _reason
list _headers
set _seen_singletons
list _raw_headers
bint _upgraded
list _messages
bint _more_data_available
bint _paused
Py_ssize_t _msg_in_flight
Py_ssize_t _max_msg_queue_size
bint _eof_pending
object _payload
unsigned long long _content_length_expected
bint _payload_error
object _payload_exception
object _last_error
bint _auto_decompress
int _limit
str _content_encoding
Py_buffer py_buf
def __cinit__(self):
self._cparser = <cparser.llhttp_t*> \
PyMem_Malloc(sizeof(cparser.llhttp_t))
if self._cparser is NULL:
raise MemoryError()
self._csettings = <cparser.llhttp_settings_t*> \
PyMem_Malloc(sizeof(cparser.llhttp_settings_t))
if self._csettings is NULL:
raise MemoryError()
def __dealloc__(self):
PyMem_Free(self._cparser)
PyMem_Free(self._csettings)
cdef _init(
self, cparser.llhttp_type mode,
object protocol, object loop, int limit,
object timer=None,
size_t max_line_size=8190, size_t max_headers=128,
size_t max_field_size=8190, payload_exception=None,
bint response_with_body=True, bint read_until_eof=False,
bint auto_decompress=True,
Py_ssize_t max_msg_queue_size=0,
):
cparser.llhttp_settings_init(self._csettings)
cparser.llhttp_init(self._cparser, mode, self._csettings)
self._cparser.data = <void*>self
self._cparser.content_length = 0
self._content_length_expected = 0
self.protocol = protocol
self._loop = loop
self._timer = timer
self._buf = bytearray()
self._more_data_available = False
self._paused = False
self._msg_in_flight = 0
self._max_msg_queue_size = max_msg_queue_size
self._eof_pending = False
self._payload = None
self._payload_error = 0
self._payload_exception = payload_exception
self._messages = []
self._raw_name = b""
self._raw_value = b""
self._tail = b""
self._has_value = False
self._header_name_size = 0
self._max_line_size = max_line_size
self._max_headers = max_headers
self._max_field_size = max_field_size
self._response_with_body = response_with_body
self._read_until_eof = read_until_eof
self._upgraded = False
self._auto_decompress = auto_decompress
self._content_encoding = None
self._lax = False
self._seen_singletons = set()
self._csettings.on_url = cb_on_url
self._csettings.on_status = cb_on_status
self._csettings.on_header_field = cb_on_header_field
self._csettings.on_header_value = cb_on_header_value
self._csettings.on_headers_complete = cb_on_headers_complete
self._csettings.on_body = cb_on_body
self._csettings.on_message_begin = cb_on_message_begin
self._csettings.on_message_complete = cb_on_message_complete
self._csettings.on_chunk_header = cb_on_chunk_header
self._csettings.on_chunk_complete = cb_on_chunk_complete
self._last_error = None
self._limit = limit
cdef _process_header(self):
cdef str value
if self._raw_name != b"":
name = find_header(self._raw_name)
value = self._raw_value.decode('utf-8', 'surrogateescape')
# reject null bytes in header values - matches the Python parser
# check at http_parser.py. llhttp in lenient mode doesn't reject
# these itself, so we need to catch them here.
# ref: RFC 9110 section 5.5 (CTL chars forbidden in field values)
if "\x00" in value:
raise InvalidHeader(self._raw_value)
if not self._lax and name in SINGLETON_HEADERS:
if name in self._seen_singletons:
raise BadHttpMessage(f"Duplicate '{name}' header found.")
self._seen_singletons.add(name)
self._headers.append((name, value))
if len(self._headers) > self._max_headers:
raise BadHttpMessage("Too many headers received")
if name is CONTENT_ENCODING:
self._content_encoding = value
self._has_value = False
self._header_name_size = 0
self._raw_headers.append((self._raw_name, self._raw_value))
self._raw_name = b""
self._raw_value = b""
cdef _on_header_field(self, char* at, size_t length):
if self._has_value:
self._process_header()
if self._raw_name == b"":
self._raw_name = at[:length]
else:
self._raw_name += at[:length]
cdef _on_header_value(self, char* at, size_t length):
if self._raw_value == b"":
self._raw_value = at[:length]
else:
self._raw_value += at[:length]
self._has_value = True
cdef _on_headers_complete(self):
cdef str h_upg
cdef str enc
self._process_header()
http_version = self.http_version()
should_close = not cparser.llhttp_should_keep_alive(self._cparser)
upgrade = self._cparser.upgrade
chunked = self._cparser.flags & cparser.F_CHUNKED
raw_headers = tuple(self._raw_headers)
headers = CIMultiDictProxy(CIMultiDict(self._headers))
if self._cparser.type == cparser.HTTP_REQUEST:
if http_version == HttpVersion11 and hdrs.HOST not in headers:
raise BadHttpMessage("Missing 'Host' header in request.")
h_upg = headers.get("upgrade", "")
if (upgrade and h_upg.isascii() and h_upg.lower() in ALLOWED_UPGRADES) or self._cparser.method == cparser.HTTP_CONNECT:
self._upgraded = True
else:
if upgrade and self._cparser.status_code == 101:
self._upgraded = True
# do not support old websocket spec
if SEC_WEBSOCKET_KEY1 in headers:
raise InvalidHeader(SEC_WEBSOCKET_KEY1)
encoding = None
enc = self._content_encoding
if enc is not None:
self._content_encoding = None
if enc.isascii() and enc.lower() in {"gzip", "deflate", "br", "zstd"}:
encoding = enc
if self._cparser.type == cparser.HTTP_REQUEST:
method = http_method_str(self._cparser.method)
msg = _new_request_message(
method, self._path,
http_version, headers, raw_headers,
should_close, encoding, upgrade, chunked, self._url)
else:
msg = _new_response_message(
http_version, self._cparser.status_code, self._reason,
headers, raw_headers, should_close, encoding,
upgrade, chunked)
if (
self._response_with_body
and (
ULLONG_MAX > self._cparser.content_length > 0 or chunked or
self._cparser.method == cparser.HTTP_CONNECT or
(self._cparser.status_code >= 199 and
self._cparser.content_length == 0 and
self._read_until_eof)
)
):
payload = StreamReader(
self.protocol, timer=self._timer, loop=self._loop,
limit=self._limit)
else:
payload = EMPTY_PAYLOAD
self._payload = payload
self._content_length_expected = self._cparser.content_length
if encoding is not None and self._auto_decompress:
self._payload = DeflateBuffer(payload, encoding, max_decompress_size=self._limit)
self._messages.append((msg, payload))
cdef _on_message_complete(self):
self._payload.feed_eof()
self._payload = None
cdef _on_chunk_header(self):
self._payload.begin_http_chunk_receiving()
cdef _on_chunk_complete(self):
self._payload.end_http_chunk_receiving()
cdef object _on_status_complete(self):
pass
cdef inline http_version(self):
cdef cparser.llhttp_t* parser = self._cparser
if parser.http_major == 1:
if parser.http_minor == 0:
return HttpVersion10
elif parser.http_minor == 1:
return HttpVersion11
return HttpVersion(parser.http_major, parser.http_minor)
### Public API ###
def pause_reading(self):
assert self._payload is not None
self._paused = True
def message_consumed(self):
# Protocol drained a queued message; free a slot for parsing.
if self._msg_in_flight > 0:
self._msg_in_flight -= 1
def feed_eof(self):
cdef bytes desc
if self._payload is not None:
if self._cparser.flags & cparser.F_CHUNKED:
raise TransferEncodingError(
"Not enough data to satisfy transfer length header.")
elif self._cparser.flags & cparser.F_CONTENT_LENGTH:
received = self._content_length_expected - self._cparser.content_length
raise ContentLengthError(
f"Not enough data to satisfy content length header "
f"(received {received} of {self._content_length_expected} bytes).")
elif cparser.llhttp_get_errno(self._cparser) != cparser.HPE_OK:
desc = cparser.llhttp_get_error_reason(self._cparser)
raise PayloadEncodingError(desc.decode('latin-1'))
else:
self._eof_pending = True
while self._more_data_available:
if self._paused:
self._paused = False
return # Will resume via feed_data(b"") later
self._more_data_available = self._payload.feed_data(b"", 0)
self._payload.feed_eof()
self._payload = None
self._more_data_available = False
self._eof_pending = False
elif self._started:
self._on_headers_complete()
if self._messages:
return self._messages[-1][0]
def feed_data(self, incoming_data):
cdef:
size_t data_len
size_t nb
char* base
cdef cparser.llhttp_errno_t errno
cdef bytes data
# Proactor loop sends bytearray.
# Ensure cython sees `data` as bytes
if type(incoming_data) is not bytes:
data = bytes(incoming_data)
else:
data = incoming_data
if self._tail:
data, self._tail = self._tail + data, b""
if self._more_data_available:
result = cb_on_body(self._cparser, b"", 0)
if result is cparser.HPE_PAUSED:
self._tail = data
return EMPTY_FEED_DATA_RESULT
if self._eof_pending:
self._payload.feed_eof()
self._payload = None
self._eof_pending = False
# We can't have new messages here, otherwise we wouldn't have
# received EOF.
return EMPTY_FEED_DATA_RESULT
PyObject_GetBuffer(data, &self.py_buf, PyBUF_SIMPLE)
# Cache buffer pointer before PyBuffer_Release to avoid use-after-release.
base = <char*>self.py_buf.buf
data_len = <size_t>self.py_buf.len
errno = cparser.llhttp_execute(
self._cparser,
base,
data_len)
if errno is cparser.HPE_PAUSED_UPGRADE:
cparser.llhttp_resume_after_upgrade(self._cparser)
nb = cparser.llhttp_get_error_pos(self._cparser) - base
elif errno is cparser.HPE_PAUSED:
cparser.llhttp_resume(self._cparser)
pos = cparser.llhttp_get_error_pos(self._cparser) - base
self._tail = data[pos:]
PyBuffer_Release(&self.py_buf)
if errno not in (cparser.HPE_OK, cparser.HPE_PAUSED, cparser.HPE_PAUSED_UPGRADE):
if self._payload_error == 0:
if self._last_error is not None:
ex = self._last_error
self._last_error = None
else:
after = cparser.llhttp_get_error_pos(self._cparser)
before = data[:after - base]
after_b = after.split(b"\r\n", 1)[0]
before = before.rsplit(b"\r\n", 1)[-1]
data = before + after_b
pointer = " " * (len(repr(before))-1) + "^"
ex = parser_error_from_errno(self._cparser, data, pointer)
self._payload = None
raise ex
if self._messages:
messages = self._messages
self._messages = []
else:
messages = ()
if self._upgraded:
return messages, True, data[nb:]
if not messages: # Shortcut to reduce Python overhead
return EMPTY_FEED_DATA_RESULT
return messages, False, b""
def set_upgraded(self, val):
self._upgraded = val
cdef class HttpRequestParser(HttpParser):
def __init__(
self, protocol, loop, int limit, timer=None,
size_t max_line_size=8190, size_t max_headers=128,
size_t max_field_size=8190, payload_exception=None,
bint response_with_body=True, bint read_until_eof=False,
bint auto_decompress=True, Py_ssize_t max_msg_queue_size=0,
):
self._init(cparser.HTTP_REQUEST, protocol, loop, limit, timer,
max_line_size, max_headers, max_field_size,
payload_exception, response_with_body, read_until_eof,
auto_decompress, max_msg_queue_size)
cdef object _on_status_complete(self):
cdef int idx1, idx2
if not self._buf:
return
self._path = self._buf.decode('utf-8', 'surrogateescape')
try:
idx3 = len(self._path)
if self._cparser.method == cparser.HTTP_CONNECT:
# authority-form,
# https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.3
self._url = URL.build(authority=self._path, encoded=True)
elif idx3 > 1 and self._path[0] == '/':
# origin-form,
# https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.1
idx1 = self._path.find("?")
if idx1 == -1:
query = ""
idx2 = self._path.find("#")
if idx2 == -1:
path = self._path
fragment = ""
else:
path = self._path[0: idx2]
fragment = self._path[idx2+1:]
else:
path = self._path[0:idx1]
idx1 += 1
idx2 = self._path.find("#", idx1+1)
if idx2 == -1:
query = self._path[idx1:]
fragment = ""
else:
query = self._path[idx1: idx2]
fragment = self._path[idx2+1:]
self._url = URL.build(
path=path,
query_string=query,
fragment=fragment,
encoded=True,
)
else:
# absolute-form for proxy maybe,
# https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.2
self._url = URL(self._path, encoded=True)
finally:
PyByteArray_Resize(self._buf, 0)
cdef class HttpResponseParser(HttpParser):
def __init__(
self, protocol, loop, int limit, timer=None,
size_t max_line_size=8190, size_t max_headers=128,
size_t max_field_size=8190, payload_exception=None,
bint response_with_body=True, bint read_until_eof=False,
bint auto_decompress=True
):
self._init(cparser.HTTP_RESPONSE, protocol, loop, limit, timer,
max_line_size, max_headers, max_field_size,
payload_exception, response_with_body, read_until_eof,
auto_decompress)
# Use strict parsing on dev mode, so users are warned about broken servers.
if not DEBUG:
cparser.llhttp_set_lenient_headers(self._cparser, 1)
cparser.llhttp_set_lenient_optional_cr_before_lf(self._cparser, 1)
cparser.llhttp_set_lenient_spaces_after_chunk_size(self._cparser, 1)
self._lax = True
cdef object _on_status_complete(self):
if self._buf:
self._reason = self._buf.decode('utf-8', 'surrogateescape')
PyByteArray_Resize(self._buf, 0)
else:
self._reason = self._reason or ''
cdef int cb_on_message_begin(cparser.llhttp_t* parser) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
pyparser._started = True
pyparser._headers = []
pyparser._seen_singletons = set()
pyparser._raw_headers = []
PyByteArray_Resize(pyparser._buf, 0)
pyparser._path = None
pyparser._reason = None
return 0
cdef int cb_on_url(cparser.llhttp_t* parser,
const char *at, size_t length) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
try:
if len(pyparser._buf) + length > pyparser._max_line_size:
status = pyparser._buf + at[:length]
raise LineTooLong(status[:100] + b"...", pyparser._max_line_size)
extend(pyparser._buf, at, length)
except BaseException as ex:
pyparser._last_error = ex
return -1
else:
return 0
cdef int cb_on_status(cparser.llhttp_t* parser,
const char *at, size_t length) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
try:
if len(pyparser._buf) + length > pyparser._max_line_size:
reason = pyparser._buf + at[:length]
raise LineTooLong(reason[:100] + b"...", pyparser._max_line_size)
extend(pyparser._buf, at, length)
except BaseException as ex:
pyparser._last_error = ex
return -1
else:
return 0
cdef int cb_on_header_field(cparser.llhttp_t* parser,
const char *at, size_t length) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
cdef Py_ssize_t size
try:
pyparser._on_status_complete()
size = len(pyparser._raw_name) + length
if size > pyparser._max_field_size:
name = pyparser._raw_name + at[:length]
raise LineTooLong(name[:100] + b"...", pyparser._max_field_size)
pyparser._header_name_size = size
pyparser._on_header_field(at, length)
except BaseException as ex:
pyparser._last_error = ex
return -1
else:
return 0
cdef int cb_on_header_value(cparser.llhttp_t* parser,
const char *at, size_t length) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
cdef Py_ssize_t size
try:
size = len(pyparser._raw_value) + length
if pyparser._header_name_size + size > pyparser._max_field_size:
value = pyparser._raw_value + at[:length]
raise LineTooLong(value[:100] + b"...", pyparser._max_field_size)
pyparser._on_header_value(at, length)
except BaseException as ex:
pyparser._last_error = ex
return -1
else:
return 0
cdef int cb_on_headers_complete(cparser.llhttp_t* parser) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
try:
pyparser._on_status_complete()
pyparser._on_headers_complete()
except BaseException as exc:
pyparser._last_error = exc
return -1
else:
if pyparser._upgraded or pyparser._cparser.method == cparser.HTTP_CONNECT:
return 2
if not pyparser._response_with_body:
return 1
return 0
cdef int cb_on_body(cparser.llhttp_t* parser,
const char *at, size_t length) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
cdef bytes body = at[:length]
while body or pyparser._more_data_available:
try:
pyparser._more_data_available = pyparser._payload.feed_data(body, length)
except BaseException as underlying_exc:
reraised_exc = underlying_exc
if pyparser._payload_exception is not None:
reraised_exc = pyparser._payload_exception(str(underlying_exc))
set_exception(pyparser._payload, reraised_exc, underlying_exc)
pyparser._payload_error = 1
pyparser._paused = False
return -1
body = b""
length = 0
if pyparser._paused:
pyparser._paused = False
return cparser.HPE_PAUSED
pyparser._paused = False
return 0
cdef int cb_on_message_complete(cparser.llhttp_t* parser) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
try:
pyparser._started = False
pyparser._on_message_complete()
except BaseException as exc:
pyparser._last_error = exc
return -1
else:
if pyparser._max_msg_queue_size:
pyparser._msg_in_flight += 1
if pyparser._msg_in_flight >= pyparser._max_msg_queue_size:
# Queue full: pause llhttp between messages. feed_data() buffers
# the remainder as tail; resumes once the queue drains.
return cparser.HPE_PAUSED
return 0
cdef int cb_on_chunk_header(cparser.llhttp_t* parser) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
try:
pyparser._on_chunk_header()
except BaseException as exc:
pyparser._last_error = exc
return -1
else:
return 0
cdef int cb_on_chunk_complete(cparser.llhttp_t* parser) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data
try:
pyparser._on_chunk_complete()
except BaseException as exc:
pyparser._last_error = exc
return -1
else:
return 0
cdef parser_error_from_errno(cparser.llhttp_t* parser, data, pointer):
cdef cparser.llhttp_errno_t errno = cparser.llhttp_get_errno(parser)
cdef bytes desc = cparser.llhttp_get_error_reason(parser)
err_msg = "{}:\n\n {!r}\n {}".format(desc.decode("latin-1"), data, pointer)
if errno in {cparser.HPE_CB_MESSAGE_BEGIN,
cparser.HPE_CB_HEADERS_COMPLETE,
cparser.HPE_CB_MESSAGE_COMPLETE,
cparser.HPE_CB_CHUNK_HEADER,
cparser.HPE_CB_CHUNK_COMPLETE,
cparser.HPE_INVALID_HEADER_TOKEN,
cparser.HPE_INVALID_CONTENT_LENGTH,
cparser.HPE_INVALID_CHUNK_SIZE,
cparser.HPE_INVALID_EOF_STATE,
cparser.HPE_INVALID_TRANSFER_ENCODING}:
return BadHttpMessage(err_msg)
elif errno == cparser.HPE_INVALID_METHOD:
if data.startswith(b"\x16\x03"):
return BadHttpMethod(error="Received HTTPS traffic on an HTTP port")
return BadHttpMethod(error=err_msg)
elif errno in {cparser.HPE_INVALID_STATUS,
cparser.HPE_INVALID_VERSION,
cparser.HPE_INVALID_CONSTANT}:
return BadStatusLine(error=f"Bad status line:\n {err_msg}")
elif errno == cparser.HPE_INVALID_URL:
return InvalidURLError(err_msg)
return BadHttpMessage(err_msg)

View File

@ -0,0 +1,164 @@
from cpython.bytes cimport PyBytes_FromStringAndSize
from cpython.exc cimport PyErr_NoMemory
from cpython.mem cimport PyMem_Free, PyMem_Malloc, PyMem_Realloc
from cpython.object cimport PyObject_Str
from libc.stdint cimport uint8_t, uint64_t
from libc.string cimport memcpy
from multidict import istr
DEF BUF_SIZE = 16 * 1024 # 16KiB
cdef object _istr = istr
# ----------------- writer ---------------------------
cdef struct Writer:
char *buf
Py_ssize_t size
Py_ssize_t pos
bint heap_allocated
cdef inline void _init_writer(Writer* writer, char *buf):
writer.buf = buf
writer.size = BUF_SIZE
writer.pos = 0
writer.heap_allocated = 0
cdef inline void _release_writer(Writer* writer):
if writer.heap_allocated:
PyMem_Free(writer.buf)
cdef inline int _write_byte(Writer* writer, uint8_t ch):
cdef char * buf
cdef Py_ssize_t size
if writer.pos == writer.size:
# reallocate
size = writer.size + BUF_SIZE
if not writer.heap_allocated:
buf = <char*>PyMem_Malloc(size)
if buf == NULL:
PyErr_NoMemory()
return -1
memcpy(buf, writer.buf, writer.size)
else:
buf = <char*>PyMem_Realloc(writer.buf, size)
if buf == NULL:
PyErr_NoMemory()
return -1
writer.buf = buf
writer.size = size
writer.heap_allocated = 1
writer.buf[writer.pos] = <char>ch
writer.pos += 1
return 0
cdef inline int _write_utf8(Writer* writer, Py_UCS4 symbol):
cdef uint64_t utf = <uint64_t> symbol
if utf < 0x80:
return _write_byte(writer, <uint8_t>utf)
elif utf < 0x800:
if _write_byte(writer, <uint8_t>(0xc0 | (utf >> 6))) < 0:
return -1
return _write_byte(writer, <uint8_t>(0x80 | (utf & 0x3f)))
elif 0xD800 <= utf <= 0xDFFF:
# surogate pair, ignored
return 0
elif utf < 0x10000:
if _write_byte(writer, <uint8_t>(0xe0 | (utf >> 12))) < 0:
return -1
if _write_byte(writer, <uint8_t>(0x80 | ((utf >> 6) & 0x3f))) < 0:
return -1
return _write_byte(writer, <uint8_t>(0x80 | (utf & 0x3f)))
elif utf > 0x10FFFF:
# symbol is too large
return 0
else:
if _write_byte(writer, <uint8_t>(0xf0 | (utf >> 18))) < 0:
return -1
if _write_byte(writer,
<uint8_t>(0x80 | ((utf >> 12) & 0x3f))) < 0:
return -1
if _write_byte(writer,
<uint8_t>(0x80 | ((utf >> 6) & 0x3f))) < 0:
return -1
return _write_byte(writer, <uint8_t>(0x80 | (utf & 0x3f)))
cdef inline int _write_str(Writer* writer, str s):
cdef Py_UCS4 ch
for ch in s:
if _write_utf8(writer, ch) < 0:
return -1
cdef inline int _write_str_raise_on_nlcr(Writer* writer, object s):
cdef Py_UCS4 ch
cdef str out_str
if type(s) is str:
out_str = <str>s
elif type(s) is _istr:
out_str = PyObject_Str(s)
elif not isinstance(s, str):
raise TypeError("Cannot serialize non-str key {!r}".format(s))
else:
out_str = str(s)
for ch in out_str:
# https://www.rfc-editor.org/info/rfc9110/#section-5.5-5
# https://www.rfc-editor.org/info/rfc9112/#section-4-3
if (ch < 0x20 and ch != 0x09) or ch == 0x7F:
raise ValueError(
"Forbidden control character detected in headers. "
"Potential header injection attack."
)
if _write_utf8(writer, ch) < 0:
return -1
# --------------- _serialize_headers ----------------------
def _serialize_headers(str status_line, headers):
cdef Writer writer
cdef object key
cdef object val
cdef char buf[BUF_SIZE]
_init_writer(&writer, buf)
try:
if _write_str_raise_on_nlcr(&writer, status_line) < 0:
raise
if _write_byte(&writer, b'\r') < 0:
raise
if _write_byte(&writer, b'\n') < 0:
raise
for key, val in headers.items():
if _write_str_raise_on_nlcr(&writer, key) < 0:
raise
if _write_byte(&writer, b':') < 0:
raise
if _write_byte(&writer, b' ') < 0:
raise
if _write_str_raise_on_nlcr(&writer, val) < 0:
raise
if _write_byte(&writer, b'\r') < 0:
raise
if _write_byte(&writer, b'\n') < 0:
raise
if _write_byte(&writer, b'\r') < 0:
raise
if _write_byte(&writer, b'\n') < 0:
raise
return PyBytes_FromStringAndSize(writer.buf, writer.pos)
finally:
_release_writer(&writer)

View File

@ -0,0 +1 @@
b01999d409b29bd916e067bc963d5f2d9ee63cfc9ae0bccb769910131417bf93 /home/runner/work/aiohttp/aiohttp/aiohttp/_websocket/mask.pxd

View File

@ -0,0 +1 @@
0478ceb55d0ed30ef1a7da742cd003449bc69a07cf9fdb06789bd2b347cbfffe /home/runner/work/aiohttp/aiohttp/aiohttp/_websocket/mask.pyx

View File

@ -0,0 +1 @@
97e3831a92693b1e05c69b02b644722139a646f065468f26bfceea36079065ba /home/runner/work/aiohttp/aiohttp/aiohttp/_websocket/reader_c.pxd

View File

@ -0,0 +1 @@
"""WebSocket protocol versions 13 and 8."""

View File

@ -0,0 +1,148 @@
"""Helpers for WebSocket protocol versions 13 and 8."""
import functools
import re
from re import Pattern
from struct import Struct
from typing import TYPE_CHECKING, Final
from ..helpers import NO_EXTENSIONS
from .models import WSHandshakeError
UNPACK_LEN3 = Struct("!Q").unpack_from
UNPACK_CLOSE_CODE = Struct("!H").unpack
PACK_LEN1 = Struct("!BB").pack
PACK_LEN2 = Struct("!BBH").pack
PACK_LEN3 = Struct("!BBQ").pack
PACK_CLOSE_CODE = Struct("!H").pack
PACK_RANDBITS = Struct("!L").pack
MSG_SIZE: Final[int] = 2**14
MASK_LEN: Final[int] = 4
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
# Used by _websocket_mask_python
@functools.lru_cache
def _xor_table() -> list[bytes]:
return [bytes(a ^ b for a in range(256)) for b in range(256)]
def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
"""Websocket masking function.
`mask` is a `bytes` object of length 4; `data` is a `bytearray`
object of any length. The contents of `data` are masked with `mask`,
as specified in section 5.3 of RFC 6455.
Note that this function mutates the `data` argument.
This pure-python implementation may be replaced by an optimized
version when available.
"""
assert isinstance(data, bytearray), data
assert len(mask) == 4, mask
if data:
_XOR_TABLE = _xor_table()
a, b, c, d = (_XOR_TABLE[n] for n in mask)
data[::4] = data[::4].translate(a)
data[1::4] = data[1::4].translate(b)
data[2::4] = data[2::4].translate(c)
data[3::4] = data[3::4].translate(d)
if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
websocket_mask = _websocket_mask_python
else:
try:
from .mask import _websocket_mask_cython # type: ignore[import-not-found]
websocket_mask = _websocket_mask_cython
except ImportError: # pragma: no cover
websocket_mask = _websocket_mask_python
_WS_EXT_RE: Final[Pattern[str]] = re.compile(
r"^(?:;\s*(?:"
r"(server_no_context_takeover)|"
r"(client_no_context_takeover)|"
r"(server_max_window_bits(?:=(\d+))?)|"
r"(client_max_window_bits(?:=(\d+))?)))*$"
)
_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")
def ws_ext_parse(extstr: str | None, isserver: bool = False) -> tuple[int, bool]:
if not extstr:
return 0, False
compress = 0
notakeover = False
for ext in _WS_EXT_RE_SPLIT.finditer(extstr):
defext = ext.group(1)
# Return compress = 15 when get `permessage-deflate`
if not defext:
compress = 15
break
match = _WS_EXT_RE.match(defext)
if match:
compress = 15
if isserver:
# Server never fail to detect compress handshake.
# Server does not need to send max wbit to client
if match.group(4):
compress = int(match.group(4))
# Group3 must match if group4 matches
# Compress wbit 8 does not support in zlib
# If compress level not support,
# CONTINUE to next extension
if compress > 15 or compress < 9:
compress = 0
continue
if match.group(1):
notakeover = True
# Ignore regex group 5 & 6 for client_max_window_bits
break
else:
if match.group(6):
compress = int(match.group(6))
# Group5 must match if group6 matches
# Compress wbit 8 does not support in zlib
# If compress level not support,
# FAIL the parse progress
if compress > 15 or compress < 9:
raise WSHandshakeError("Invalid window size")
if match.group(2):
notakeover = True
# Ignore regex group 5 & 6 for client_max_window_bits
break
# Return Fail if client side and not match
elif not isserver:
raise WSHandshakeError("Extension for deflate not supported" + ext.group(1))
return compress, notakeover
def ws_ext_gen(
compress: int = 15, isserver: bool = False, server_notakeover: bool = False
) -> str:
# client_notakeover=False not used for server
# compress wbit 8 does not support in zlib
if compress < 9 or compress > 15:
raise ValueError(
"Compress wbits must between 9 and 15, zlib does not support wbits=8"
)
enabledext = ["permessage-deflate"]
if not isserver:
enabledext.append("client_max_window_bits")
if compress < 15:
enabledext.append("server_max_window_bits=" + str(compress))
if server_notakeover:
enabledext.append("server_no_context_takeover")
# if client_notakeover:
# enabledext.append('client_no_context_takeover')
return "; ".join(enabledext)

View File

@ -0,0 +1,3 @@
"""Cython declarations for websocket masking."""
cpdef void _websocket_mask_cython(bytes mask, bytearray data)

View File

@ -0,0 +1,48 @@
from cpython cimport PyBytes_AsString
#from cpython cimport PyByteArray_AsString # cython still not exports that
cdef extern from "Python.h":
char* PyByteArray_AsString(bytearray ba) except NULL
from libc.stdint cimport uint32_t, uint64_t, uintmax_t
cpdef void _websocket_mask_cython(bytes mask, bytearray data):
"""Note, this function mutates its `data` argument
"""
cdef:
Py_ssize_t data_len, i
# bit operations on signed integers are implementation-specific
unsigned char * in_buf
const unsigned char * mask_buf
uint32_t uint32_msk
uint64_t uint64_msk
assert len(mask) == 4
data_len = len(data)
in_buf = <unsigned char*>PyByteArray_AsString(data)
mask_buf = <const unsigned char*>PyBytes_AsString(mask)
uint32_msk = (<uint32_t*>mask_buf)[0]
# TODO: align in_data ptr to achieve even faster speeds
# does it need in python ?! malloc() always aligns to sizeof(long) bytes
if sizeof(size_t) >= 8:
uint64_msk = uint32_msk
uint64_msk = (uint64_msk << 32) | uint32_msk
while data_len >= 8:
(<uint64_t*>in_buf)[0] ^= uint64_msk
in_buf += 8
data_len -= 8
while data_len >= 4:
(<uint32_t*>in_buf)[0] ^= uint32_msk
in_buf += 4
data_len -= 4
for i in range(0, data_len):
in_buf[i] ^= mask_buf[i]

View File

@ -0,0 +1,107 @@
"""Models for WebSocket protocol versions 13 and 8."""
import json
from collections.abc import Callable
from enum import IntEnum
from typing import Any, Final, NamedTuple, cast
WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF])
class WSCloseCode(IntEnum):
OK = 1000
GOING_AWAY = 1001
PROTOCOL_ERROR = 1002
UNSUPPORTED_DATA = 1003
ABNORMAL_CLOSURE = 1006
INVALID_TEXT = 1007
POLICY_VIOLATION = 1008
MESSAGE_TOO_BIG = 1009
MANDATORY_EXTENSION = 1010
INTERNAL_ERROR = 1011
SERVICE_RESTART = 1012
TRY_AGAIN_LATER = 1013
BAD_GATEWAY = 1014
class WSMsgType(IntEnum):
# websocket spec types
CONTINUATION = 0x0
TEXT = 0x1
BINARY = 0x2
PING = 0x9
PONG = 0xA
CLOSE = 0x8
# aiohttp specific types
CLOSING = 0x100
CLOSED = 0x101
ERROR = 0x102
text = TEXT
binary = BINARY
ping = PING
pong = PONG
close = CLOSE
closing = CLOSING
closed = CLOSED
error = ERROR
class WSMessage(NamedTuple):
type: WSMsgType
# To type correctly, this would need some kind of tagged union for each type.
data: Any
extra: str | None
def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
"""Return parsed JSON data.
.. versionadded:: 0.22
"""
return loads(self.data)
class WSMessageTextBytes(NamedTuple):
"""WebSocket TEXT message with raw bytes (no UTF-8 decoding)."""
type: WSMsgType
# To type correctly, this would need some kind of tagged union for each type.
# In 4.0, we use a union of message types to properly type data, but in 3.x
# we keep it as Any to avoid a breaking change.
data: Any
extra: str | None
def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
"""Return parsed JSON data."""
return loads(self.data)
# Type aliases for message types based on decode_text setting
# When decode_text=True, TEXT messages have str data (WSMessage)
# When decode_text=False, TEXT messages have bytes data (WSMessageTextBytes)
WSMessageDecodeText = WSMessage
WSMessageNoDecodeText = WSMessage | WSMessageTextBytes
# Constructing the tuple directly to avoid the overhead of
# the lambda and arg processing since NamedTuples are constructed
# with a run time built lambda
# https://github.com/python/cpython/blob/d83fcf8371f2f33c7797bc8f5423a8bca8c46e5c/Lib/collections/__init__.py#L441
WS_CLOSED_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSED, None, None))
WS_CLOSING_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSING, None, None))
class WebSocketError(Exception):
"""WebSocket protocol parser error."""
def __init__(self, code: int, message: str) -> None:
self.code = code
super().__init__(code, message)
def __str__(self) -> str:
return cast(str, self.args[1])
class WSHandshakeError(Exception):
"""WebSocket protocol handshake error."""

View File

@ -0,0 +1,31 @@
"""Reader for WebSocket protocol versions 13 and 8."""
from typing import TYPE_CHECKING
from ..helpers import NO_EXTENSIONS
if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
from .reader_py import (
WebSocketDataQueue as WebSocketDataQueuePython,
WebSocketReader as WebSocketReaderPython,
)
WebSocketReader = WebSocketReaderPython
WebSocketDataQueue = WebSocketDataQueuePython
else:
try:
from .reader_c import ( # type: ignore[import-not-found]
WebSocketDataQueue as WebSocketDataQueueCython,
WebSocketReader as WebSocketReaderCython,
)
WebSocketReader = WebSocketReaderCython
WebSocketDataQueue = WebSocketDataQueueCython
except ImportError: # pragma: no cover
from .reader_py import (
WebSocketDataQueue as WebSocketDataQueuePython,
WebSocketReader as WebSocketReaderPython,
)
WebSocketReader = WebSocketReaderPython
WebSocketDataQueue = WebSocketDataQueuePython

View File

@ -0,0 +1,112 @@
import cython
from .mask cimport _websocket_mask_cython as websocket_mask
cdef unsigned int READ_HEADER
cdef unsigned int READ_PAYLOAD_LENGTH
cdef unsigned int READ_PAYLOAD_MASK
cdef unsigned int READ_PAYLOAD
cdef int OP_CODE_NOT_SET
cdef int OP_CODE_CONTINUATION
cdef int OP_CODE_TEXT
cdef int OP_CODE_BINARY
cdef int OP_CODE_CLOSE
cdef int OP_CODE_PING
cdef int OP_CODE_PONG
cdef int COMPRESSED_NOT_SET
cdef int COMPRESSED_FALSE
cdef int COMPRESSED_TRUE
cdef object UNPACK_LEN3
cdef object UNPACK_CLOSE_CODE
cdef object TUPLE_NEW
cdef object WSMsgType
cdef object WSMessage
cdef object WSMessageTextBytes
cdef object WS_MSG_TYPE_TEXT
cdef object WS_MSG_TYPE_BINARY
cdef set ALLOWED_CLOSE_CODES
cdef set MESSAGE_TYPES_WITH_CONTENT
cdef tuple EMPTY_FRAME
cdef tuple EMPTY_FRAME_ERROR
cdef class WebSocketDataQueue:
cdef unsigned int _size
cdef public object _protocol
cdef unsigned int _limit
cdef object _loop
cdef bint _eof
cdef object _waiter
cdef object _exception
cdef public object _buffer
cdef object _get_buffer
cdef object _put_buffer
cdef void _release_waiter(self)
cpdef void feed_data(self, object data, unsigned int size)
@cython.locals(size="unsigned int")
cdef _read_from_buffer(self)
cdef class WebSocketReader:
cdef WebSocketDataQueue queue
cdef unsigned int _max_msg_size
cdef bint _decode_text
cdef Exception _exc
cdef bytearray _partial
cdef unsigned int _state
cdef int _opcode
cdef bint _frame_fin
cdef int _frame_opcode
cdef list _payload_fragments
cdef Py_ssize_t _frame_payload_len
cdef bytes _tail
cdef bint _has_mask
cdef bytes _frame_mask
cdef Py_ssize_t _payload_bytes_to_read
cdef unsigned int _payload_len_flag
cdef int _compressed
cdef object _decompressobj
cdef bint _compress
cpdef tuple feed_data(self, object data)
@cython.locals(
is_continuation=bint,
fin=bint,
has_partial=bint,
payload_merged=bytes,
)
cpdef void _handle_frame(self, bint fin, int opcode, object payload, int compressed) except *
@cython.locals(
start_pos=Py_ssize_t,
data_len=Py_ssize_t,
length=Py_ssize_t,
chunk_size=Py_ssize_t,
chunk_len=Py_ssize_t,
data_len=Py_ssize_t,
data_cstr="const unsigned char *",
first_byte="unsigned char",
second_byte="unsigned char",
f_start_pos=Py_ssize_t,
f_end_pos=Py_ssize_t,
has_mask=bint,
fin=bint,
had_fragments=Py_ssize_t,
payload_bytearray=bytearray,
)
cpdef void _feed_data(self, bytes data) except *

View File

@ -0,0 +1,509 @@
"""Reader for WebSocket protocol versions 13 and 8."""
import asyncio
import builtins
from collections import deque
from typing import Final
from ..base_protocol import BaseProtocol
from ..compression_utils import ZLibDecompressor
from ..helpers import _EXC_SENTINEL, set_exception
from ..streams import EofStream
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
from .models import (
WS_DEFLATE_TRAILING,
WebSocketError,
WSCloseCode,
WSMessage,
WSMessageTextBytes,
WSMsgType,
)
ALLOWED_CLOSE_CODES: Final[set[int]] = {int(i) for i in WSCloseCode}
# States for the reader, used to parse the WebSocket frame
# integer values are used so they can be cythonized
READ_HEADER = 1
READ_PAYLOAD_LENGTH = 2
READ_PAYLOAD_MASK = 3
READ_PAYLOAD = 4
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
WS_MSG_TYPE_TEXT = WSMsgType.TEXT
# WSMsgType values unpacked so they can by cythonized to ints
OP_CODE_NOT_SET = -1
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
OP_CODE_TEXT = WSMsgType.TEXT.value
OP_CODE_BINARY = WSMsgType.BINARY.value
OP_CODE_CLOSE = WSMsgType.CLOSE.value
OP_CODE_PING = WSMsgType.PING.value
OP_CODE_PONG = WSMsgType.PONG.value
EMPTY_FRAME_ERROR = (True, b"")
EMPTY_FRAME = (False, b"")
COMPRESSED_NOT_SET = -1
COMPRESSED_FALSE = 0
COMPRESSED_TRUE = 1
TUPLE_NEW = tuple.__new__
cython_int = int # Typed to int in Python, but cython with use a signed int in the pxd
class WebSocketDataQueue:
"""WebSocketDataQueue resumes and pauses an underlying stream.
It is a destination for WebSocket data.
"""
def __init__(
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
) -> None:
self._size = 0
self._protocol = protocol
self._limit = limit * 2
self._loop = loop
self._eof = False
self._waiter: asyncio.Future[None] | None = None
self._exception: BaseException | None = None
self._buffer: deque[tuple[WSMessage | WSMessageTextBytes, int]] = deque()
self._get_buffer = self._buffer.popleft
self._put_buffer = self._buffer.append
def is_eof(self) -> bool:
return self._eof
def exception(self) -> BaseException | None:
return self._exception
def set_exception(
self,
exc: BaseException,
exc_cause: builtins.BaseException = _EXC_SENTINEL,
) -> None:
self._eof = True
self._exception = exc
if (waiter := self._waiter) is not None:
self._waiter = None
set_exception(waiter, exc, exc_cause)
def _release_waiter(self) -> None:
if (waiter := self._waiter) is None:
return
self._waiter = None
if not waiter.done():
waiter.set_result(None)
def feed_eof(self) -> None:
self._eof = True
self._release_waiter()
self._exception = None # Break cyclic references
def feed_data(
self, data: "WSMessage | WSMessageTextBytes", size: "cython_int"
) -> None:
self._size += size
self._put_buffer((data, size))
self._release_waiter()
if self._size > self._limit and not self._protocol._reading_paused:
self._protocol.pause_reading()
async def read(self) -> WSMessage | WSMessageTextBytes:
if not self._buffer and not self._eof:
assert not self._waiter
self._waiter = self._loop.create_future()
try:
await self._waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
self._waiter = None
raise
return self._read_from_buffer()
def _read_from_buffer(self) -> WSMessage | WSMessageTextBytes:
if self._buffer:
data, size = self._get_buffer()
self._size -= size
if self._size < self._limit and self._protocol._reading_paused:
self._protocol.resume_reading()
return data
if self._exception is not None:
raise self._exception
raise EofStream
class WebSocketReader:
def __init__(
self,
queue: WebSocketDataQueue,
max_msg_size: int,
compress: bool = True,
decode_text: bool = True,
) -> None:
self.queue = queue
self._max_msg_size = max_msg_size
self._decode_text = decode_text
self._exc: Exception | None = None
self._partial = bytearray()
self._state = READ_HEADER
self._opcode: int = OP_CODE_NOT_SET
self._frame_fin = False
self._frame_opcode: int = OP_CODE_NOT_SET
self._payload_fragments: list[bytes] = []
self._frame_payload_len = 0
self._tail: bytes = b""
self._has_mask = False
self._frame_mask: bytes | None = None
self._payload_bytes_to_read = 0
self._payload_len_flag = 0
self._compressed: int = COMPRESSED_NOT_SET
self._decompressobj: ZLibDecompressor | None = None
self._compress = compress
def feed_eof(self) -> None:
self.queue.feed_eof()
# data can be bytearray on Windows because proactor event loop uses bytearray
# and asyncio types this to Union[bytes, bytearray, memoryview] so we need
# coerce data to bytes if it is not
def feed_data(self, data: bytes | bytearray | memoryview) -> tuple[bool, bytes]:
if type(data) is not bytes:
data = bytes(data)
if self._exc is not None:
return True, data
try:
self._feed_data(data)
except Exception as exc:
self._exc = exc
set_exception(self.queue, exc)
return EMPTY_FRAME_ERROR
return EMPTY_FRAME
def _handle_frame(
self,
fin: bool,
opcode: int | cython_int, # Union intended: Cython pxd uses C int
payload: bytes | bytearray,
compressed: int | cython_int, # Union intended: Cython pxd uses C int
) -> None:
msg: WSMessage
if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
# Validate continuation frames before processing
if opcode == OP_CODE_CONTINUATION and self._opcode == OP_CODE_NOT_SET:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Continuation frame for non started message",
)
# load text/binary
if not fin:
# got partial frame payload
if opcode != OP_CODE_CONTINUATION:
self._opcode = opcode
self._partial += payload
return
has_partial = bool(self._partial)
if opcode == OP_CODE_CONTINUATION:
opcode = self._opcode
self._opcode = OP_CODE_NOT_SET
# previous frame was non finished
# we should get continuation opcode
elif has_partial:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"The opcode in non-fin frame is expected "
f"to be zero, got {opcode!r}",
)
assembled_payload: bytes | bytearray
if has_partial:
assembled_payload = self._partial + payload
self._partial.clear()
else:
assembled_payload = payload
# Decompress process must to be done after all packets
# received.
if compressed:
if not self._decompressobj:
self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
# XXX: It's possible that the zlib backend (isal is known to
# do this, maybe others too?) will return max_length bytes,
# but internally buffer more data such that the payload is
# >max_length, so we return one extra byte and if we're able
# to do that, then the message is too big.
payload_merged = self._decompressobj.decompress_sync(
assembled_payload + WS_DEFLATE_TRAILING,
(
self._max_msg_size + 1
if self._max_msg_size
else self._max_msg_size
),
)
if self._max_msg_size and len(payload_merged) > self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
f"Decompressed message exceeds size limit {self._max_msg_size}",
)
elif type(assembled_payload) is bytes:
payload_merged = assembled_payload
else:
payload_merged = bytes(assembled_payload)
if opcode == OP_CODE_TEXT:
if self._decode_text:
try:
text = payload_merged.decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
# XXX: The Text and Binary messages here can be a performance
# bottleneck, so we use tuple.__new__ to improve performance.
# This is not type safe, but many tests should fail in
# test_client_ws_functional.py if this is wrong.
self.queue.feed_data(
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
len(payload_merged),
)
else:
# Return raw bytes for TEXT messages when decode_text=False
self.queue.feed_data(
TUPLE_NEW(
WSMessageTextBytes, (WS_MSG_TYPE_TEXT, payload_merged, "")
),
len(payload_merged),
)
else:
self.queue.feed_data(
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
len(payload_merged),
)
elif opcode == OP_CODE_CLOSE:
if len(payload) >= 2:
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
f"Invalid close code: {close_code}",
)
try:
close_message = payload[2:].decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message))
elif payload:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
f"Invalid close frame: {fin} {opcode} {payload!r}",
)
else:
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))
self.queue.feed_data(msg, 0)
elif opcode == OP_CODE_PING:
msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
self.queue.feed_data(msg, len(payload))
elif opcode == OP_CODE_PONG:
msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
self.queue.feed_data(msg, len(payload))
else:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
)
def _feed_data(self, data: bytes) -> None:
"""Return the next frame from the socket."""
if self._tail:
data, self._tail = self._tail + data, b""
start_pos: int = 0
data_len = len(data)
data_cstr = data
while True:
# read header
if self._state == READ_HEADER:
if data_len - start_pos < 2:
break
first_byte = data_cstr[start_pos]
second_byte = data_cstr[start_pos + 1]
start_pos += 2
fin = (first_byte >> 7) & 1
rsv1 = (first_byte >> 6) & 1
rsv2 = (first_byte >> 5) & 1
rsv3 = (first_byte >> 4) & 1
opcode = first_byte & 0xF
# frame-fin = %x0 ; more frames of this message follow
# / %x1 ; final frame of this message
# frame-rsv1 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv2 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv3 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
#
# Remove rsv1 from this test for deflate development
if rsv2 or rsv3 or (rsv1 and not self._compress):
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)
if opcode not in {
OP_CODE_CONTINUATION,
OP_CODE_TEXT,
OP_CODE_BINARY,
OP_CODE_CLOSE,
OP_CODE_PING,
OP_CODE_PONG,
}:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
f"Unexpected opcode={opcode!r}",
)
if opcode > 0x7 and fin == 0:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received fragmented control frame",
)
has_mask = (second_byte >> 7) & 1
length = second_byte & 0x7F
# Control frames MUST have a payload
# length of 125 bytes or less
if opcode > 0x7 and length > 125:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Control frame payload cannot be larger than 125 bytes",
)
# Set compress status if last package is FIN
# OR set compress status if this is first fragment
# Raise error if not first fragment with rsv1 = 0x1
if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
elif rsv1:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)
self._frame_fin = bool(fin)
self._frame_opcode = opcode
self._has_mask = bool(has_mask)
self._payload_len_flag = length
self._state = READ_PAYLOAD_LENGTH
# read payload length
if self._state == READ_PAYLOAD_LENGTH:
len_flag = self._payload_len_flag
if len_flag == 126:
if data_len - start_pos < 2:
break
first_byte = data_cstr[start_pos]
second_byte = data_cstr[start_pos + 1]
start_pos += 2
self._payload_bytes_to_read = first_byte << 8 | second_byte
elif len_flag > 126:
if data_len - start_pos < 8:
break
self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
start_pos += 8
else:
self._payload_bytes_to_read = len_flag
# Reject oversized data frames before buffering any payload
# bytes. Control frames are capped at 125 bytes (checked in
# READ_HEADER) so only text/binary/continuation need this.
if self._max_msg_size and self._frame_opcode in {
OP_CODE_TEXT,
OP_CODE_BINARY,
OP_CODE_CONTINUATION,
}:
projected_size = self._payload_bytes_to_read + len(self._partial)
if projected_size >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
f"Message size {projected_size} "
f"exceeds limit {self._max_msg_size}",
)
self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD
# read payload mask
if self._state == READ_PAYLOAD_MASK:
if data_len - start_pos < 4:
break
self._frame_mask = data_cstr[start_pos : start_pos + 4]
start_pos += 4
self._state = READ_PAYLOAD
if self._state == READ_PAYLOAD:
chunk_len = data_len - start_pos
if self._payload_bytes_to_read >= chunk_len:
f_end_pos = data_len
self._payload_bytes_to_read -= chunk_len
else:
f_end_pos = start_pos + self._payload_bytes_to_read
self._payload_bytes_to_read = 0
had_fragments = self._frame_payload_len
self._frame_payload_len += f_end_pos - start_pos
f_start_pos = start_pos
start_pos = f_end_pos
if self._payload_bytes_to_read != 0:
# If we don't have a complete frame, we need to save the
# data for the next call to feed_data.
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
break
payload: bytes | bytearray
if had_fragments:
# We have to join the payload fragments get the payload
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
if self._has_mask:
assert self._frame_mask is not None
payload_bytearray = bytearray(b"".join(self._payload_fragments))
websocket_mask(self._frame_mask, payload_bytearray)
payload = payload_bytearray
else:
payload = b"".join(self._payload_fragments)
self._payload_fragments.clear()
elif self._has_mask:
assert self._frame_mask is not None
payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment]
if type(payload_bytearray) is not bytearray: # pragma: no branch
# Cython will do the conversion for us
# but we need to do it for Python and we
# will always get here in Python
payload_bytearray = bytearray(payload_bytearray)
websocket_mask(self._frame_mask, payload_bytearray)
payload = payload_bytearray
else:
payload = data_cstr[f_start_pos:f_end_pos]
self._handle_frame(
self._frame_fin, self._frame_opcode, payload, self._compressed
)
self._frame_payload_len = 0
self._state = READ_HEADER
# XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""

View File

@ -0,0 +1,509 @@
"""Reader for WebSocket protocol versions 13 and 8."""
import asyncio
import builtins
from collections import deque
from typing import Final
from ..base_protocol import BaseProtocol
from ..compression_utils import ZLibDecompressor
from ..helpers import _EXC_SENTINEL, set_exception
from ..streams import EofStream
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
from .models import (
WS_DEFLATE_TRAILING,
WebSocketError,
WSCloseCode,
WSMessage,
WSMessageTextBytes,
WSMsgType,
)
ALLOWED_CLOSE_CODES: Final[set[int]] = {int(i) for i in WSCloseCode}
# States for the reader, used to parse the WebSocket frame
# integer values are used so they can be cythonized
READ_HEADER = 1
READ_PAYLOAD_LENGTH = 2
READ_PAYLOAD_MASK = 3
READ_PAYLOAD = 4
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
WS_MSG_TYPE_TEXT = WSMsgType.TEXT
# WSMsgType values unpacked so they can by cythonized to ints
OP_CODE_NOT_SET = -1
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
OP_CODE_TEXT = WSMsgType.TEXT.value
OP_CODE_BINARY = WSMsgType.BINARY.value
OP_CODE_CLOSE = WSMsgType.CLOSE.value
OP_CODE_PING = WSMsgType.PING.value
OP_CODE_PONG = WSMsgType.PONG.value
EMPTY_FRAME_ERROR = (True, b"")
EMPTY_FRAME = (False, b"")
COMPRESSED_NOT_SET = -1
COMPRESSED_FALSE = 0
COMPRESSED_TRUE = 1
TUPLE_NEW = tuple.__new__
cython_int = int # Typed to int in Python, but cython with use a signed int in the pxd
class WebSocketDataQueue:
"""WebSocketDataQueue resumes and pauses an underlying stream.
It is a destination for WebSocket data.
"""
def __init__(
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
) -> None:
self._size = 0
self._protocol = protocol
self._limit = limit * 2
self._loop = loop
self._eof = False
self._waiter: asyncio.Future[None] | None = None
self._exception: BaseException | None = None
self._buffer: deque[tuple[WSMessage | WSMessageTextBytes, int]] = deque()
self._get_buffer = self._buffer.popleft
self._put_buffer = self._buffer.append
def is_eof(self) -> bool:
return self._eof
def exception(self) -> BaseException | None:
return self._exception
def set_exception(
self,
exc: BaseException,
exc_cause: builtins.BaseException = _EXC_SENTINEL,
) -> None:
self._eof = True
self._exception = exc
if (waiter := self._waiter) is not None:
self._waiter = None
set_exception(waiter, exc, exc_cause)
def _release_waiter(self) -> None:
if (waiter := self._waiter) is None:
return
self._waiter = None
if not waiter.done():
waiter.set_result(None)
def feed_eof(self) -> None:
self._eof = True
self._release_waiter()
self._exception = None # Break cyclic references
def feed_data(
self, data: "WSMessage | WSMessageTextBytes", size: "cython_int"
) -> None:
self._size += size
self._put_buffer((data, size))
self._release_waiter()
if self._size > self._limit and not self._protocol._reading_paused:
self._protocol.pause_reading()
async def read(self) -> WSMessage | WSMessageTextBytes:
if not self._buffer and not self._eof:
assert not self._waiter
self._waiter = self._loop.create_future()
try:
await self._waiter
except (asyncio.CancelledError, asyncio.TimeoutError):
self._waiter = None
raise
return self._read_from_buffer()
def _read_from_buffer(self) -> WSMessage | WSMessageTextBytes:
if self._buffer:
data, size = self._get_buffer()
self._size -= size
if self._size < self._limit and self._protocol._reading_paused:
self._protocol.resume_reading()
return data
if self._exception is not None:
raise self._exception
raise EofStream
class WebSocketReader:
def __init__(
self,
queue: WebSocketDataQueue,
max_msg_size: int,
compress: bool = True,
decode_text: bool = True,
) -> None:
self.queue = queue
self._max_msg_size = max_msg_size
self._decode_text = decode_text
self._exc: Exception | None = None
self._partial = bytearray()
self._state = READ_HEADER
self._opcode: int = OP_CODE_NOT_SET
self._frame_fin = False
self._frame_opcode: int = OP_CODE_NOT_SET
self._payload_fragments: list[bytes] = []
self._frame_payload_len = 0
self._tail: bytes = b""
self._has_mask = False
self._frame_mask: bytes | None = None
self._payload_bytes_to_read = 0
self._payload_len_flag = 0
self._compressed: int = COMPRESSED_NOT_SET
self._decompressobj: ZLibDecompressor | None = None
self._compress = compress
def feed_eof(self) -> None:
self.queue.feed_eof()
# data can be bytearray on Windows because proactor event loop uses bytearray
# and asyncio types this to Union[bytes, bytearray, memoryview] so we need
# coerce data to bytes if it is not
def feed_data(self, data: bytes | bytearray | memoryview) -> tuple[bool, bytes]:
if type(data) is not bytes:
data = bytes(data)
if self._exc is not None:
return True, data
try:
self._feed_data(data)
except Exception as exc:
self._exc = exc
set_exception(self.queue, exc)
return EMPTY_FRAME_ERROR
return EMPTY_FRAME
def _handle_frame(
self,
fin: bool,
opcode: int | cython_int, # Union intended: Cython pxd uses C int
payload: bytes | bytearray,
compressed: int | cython_int, # Union intended: Cython pxd uses C int
) -> None:
msg: WSMessage
if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
# Validate continuation frames before processing
if opcode == OP_CODE_CONTINUATION and self._opcode == OP_CODE_NOT_SET:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Continuation frame for non started message",
)
# load text/binary
if not fin:
# got partial frame payload
if opcode != OP_CODE_CONTINUATION:
self._opcode = opcode
self._partial += payload
return
has_partial = bool(self._partial)
if opcode == OP_CODE_CONTINUATION:
opcode = self._opcode
self._opcode = OP_CODE_NOT_SET
# previous frame was non finished
# we should get continuation opcode
elif has_partial:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"The opcode in non-fin frame is expected "
f"to be zero, got {opcode!r}",
)
assembled_payload: bytes | bytearray
if has_partial:
assembled_payload = self._partial + payload
self._partial.clear()
else:
assembled_payload = payload
# Decompress process must to be done after all packets
# received.
if compressed:
if not self._decompressobj:
self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
# XXX: It's possible that the zlib backend (isal is known to
# do this, maybe others too?) will return max_length bytes,
# but internally buffer more data such that the payload is
# >max_length, so we return one extra byte and if we're able
# to do that, then the message is too big.
payload_merged = self._decompressobj.decompress_sync(
assembled_payload + WS_DEFLATE_TRAILING,
(
self._max_msg_size + 1
if self._max_msg_size
else self._max_msg_size
),
)
if self._max_msg_size and len(payload_merged) > self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
f"Decompressed message exceeds size limit {self._max_msg_size}",
)
elif type(assembled_payload) is bytes:
payload_merged = assembled_payload
else:
payload_merged = bytes(assembled_payload)
if opcode == OP_CODE_TEXT:
if self._decode_text:
try:
text = payload_merged.decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
# XXX: The Text and Binary messages here can be a performance
# bottleneck, so we use tuple.__new__ to improve performance.
# This is not type safe, but many tests should fail in
# test_client_ws_functional.py if this is wrong.
self.queue.feed_data(
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
len(payload_merged),
)
else:
# Return raw bytes for TEXT messages when decode_text=False
self.queue.feed_data(
TUPLE_NEW(
WSMessageTextBytes, (WS_MSG_TYPE_TEXT, payload_merged, "")
),
len(payload_merged),
)
else:
self.queue.feed_data(
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
len(payload_merged),
)
elif opcode == OP_CODE_CLOSE:
if len(payload) >= 2:
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
f"Invalid close code: {close_code}",
)
try:
close_message = payload[2:].decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message))
elif payload:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
f"Invalid close frame: {fin} {opcode} {payload!r}",
)
else:
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))
self.queue.feed_data(msg, 0)
elif opcode == OP_CODE_PING:
msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
self.queue.feed_data(msg, len(payload))
elif opcode == OP_CODE_PONG:
msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
self.queue.feed_data(msg, len(payload))
else:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
)
def _feed_data(self, data: bytes) -> None:
"""Return the next frame from the socket."""
if self._tail:
data, self._tail = self._tail + data, b""
start_pos: int = 0
data_len = len(data)
data_cstr = data
while True:
# read header
if self._state == READ_HEADER:
if data_len - start_pos < 2:
break
first_byte = data_cstr[start_pos]
second_byte = data_cstr[start_pos + 1]
start_pos += 2
fin = (first_byte >> 7) & 1
rsv1 = (first_byte >> 6) & 1
rsv2 = (first_byte >> 5) & 1
rsv3 = (first_byte >> 4) & 1
opcode = first_byte & 0xF
# frame-fin = %x0 ; more frames of this message follow
# / %x1 ; final frame of this message
# frame-rsv1 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv2 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
# frame-rsv3 = %x0 ;
# 1 bit, MUST be 0 unless negotiated otherwise
#
# Remove rsv1 from this test for deflate development
if rsv2 or rsv3 or (rsv1 and not self._compress):
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)
if opcode not in {
OP_CODE_CONTINUATION,
OP_CODE_TEXT,
OP_CODE_BINARY,
OP_CODE_CLOSE,
OP_CODE_PING,
OP_CODE_PONG,
}:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
f"Unexpected opcode={opcode!r}",
)
if opcode > 0x7 and fin == 0:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received fragmented control frame",
)
has_mask = (second_byte >> 7) & 1
length = second_byte & 0x7F
# Control frames MUST have a payload
# length of 125 bytes or less
if opcode > 0x7 and length > 125:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Control frame payload cannot be larger than 125 bytes",
)
# Set compress status if last package is FIN
# OR set compress status if this is first fragment
# Raise error if not first fragment with rsv1 = 0x1
if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
elif rsv1:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Received frame with non-zero reserved bits",
)
self._frame_fin = bool(fin)
self._frame_opcode = opcode
self._has_mask = bool(has_mask)
self._payload_len_flag = length
self._state = READ_PAYLOAD_LENGTH
# read payload length
if self._state == READ_PAYLOAD_LENGTH:
len_flag = self._payload_len_flag
if len_flag == 126:
if data_len - start_pos < 2:
break
first_byte = data_cstr[start_pos]
second_byte = data_cstr[start_pos + 1]
start_pos += 2
self._payload_bytes_to_read = first_byte << 8 | second_byte
elif len_flag > 126:
if data_len - start_pos < 8:
break
self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
start_pos += 8
else:
self._payload_bytes_to_read = len_flag
# Reject oversized data frames before buffering any payload
# bytes. Control frames are capped at 125 bytes (checked in
# READ_HEADER) so only text/binary/continuation need this.
if self._max_msg_size and self._frame_opcode in {
OP_CODE_TEXT,
OP_CODE_BINARY,
OP_CODE_CONTINUATION,
}:
projected_size = self._payload_bytes_to_read + len(self._partial)
if projected_size >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
f"Message size {projected_size} "
f"exceeds limit {self._max_msg_size}",
)
self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD
# read payload mask
if self._state == READ_PAYLOAD_MASK:
if data_len - start_pos < 4:
break
self._frame_mask = data_cstr[start_pos : start_pos + 4]
start_pos += 4
self._state = READ_PAYLOAD
if self._state == READ_PAYLOAD:
chunk_len = data_len - start_pos
if self._payload_bytes_to_read >= chunk_len:
f_end_pos = data_len
self._payload_bytes_to_read -= chunk_len
else:
f_end_pos = start_pos + self._payload_bytes_to_read
self._payload_bytes_to_read = 0
had_fragments = self._frame_payload_len
self._frame_payload_len += f_end_pos - start_pos
f_start_pos = start_pos
start_pos = f_end_pos
if self._payload_bytes_to_read != 0:
# If we don't have a complete frame, we need to save the
# data for the next call to feed_data.
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
break
payload: bytes | bytearray
if had_fragments:
# We have to join the payload fragments get the payload
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
if self._has_mask:
assert self._frame_mask is not None
payload_bytearray = bytearray(b"".join(self._payload_fragments))
websocket_mask(self._frame_mask, payload_bytearray)
payload = payload_bytearray
else:
payload = b"".join(self._payload_fragments)
self._payload_fragments.clear()
elif self._has_mask:
assert self._frame_mask is not None
payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment]
if type(payload_bytearray) is not bytearray: # pragma: no branch
# Cython will do the conversion for us
# but we need to do it for Python and we
# will always get here in Python
payload_bytearray = bytearray(payload_bytearray)
websocket_mask(self._frame_mask, payload_bytearray)
payload = payload_bytearray
else:
payload = data_cstr[f_start_pos:f_end_pos]
self._handle_frame(
self._frame_fin, self._frame_opcode, payload, self._compressed
)
self._frame_payload_len = 0
self._state = READ_HEADER
# XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""

View File

@ -0,0 +1,261 @@
"""WebSocket protocol versions 13 and 8."""
import asyncio
import random
import sys
from functools import partial
from typing import Final, Optional, Set
from ..base_protocol import BaseProtocol
from ..client_exceptions import ClientConnectionResetError
from ..compression_utils import ZLibBackend, ZLibCompressor
from ..helpers import DEFAULT_CHUNK_SIZE
from .helpers import (
MASK_LEN,
MSG_SIZE,
PACK_CLOSE_CODE,
PACK_LEN1,
PACK_LEN2,
PACK_LEN3,
PACK_RANDBITS,
websocket_mask,
)
from .models import WS_DEFLATE_TRAILING, WSMsgType
# WebSocket opcode boundary: opcodes 0-7 are data frames, 8-15 are control frames
# Control frames (ping, pong, close) are never compressed
WS_CONTROL_FRAME_OPCODE: Final[int] = 8
# For websockets, keeping latency low is extremely important as implementations
# generally expect to be able to send and receive messages quickly. We use a
# larger chunk size to reduce the number of executor calls and avoid task
# creation overhead, since both are significant sources of latency when chunks
# are small. A size of 16KiB was chosen as a balance between avoiding task
# overhead and not blocking the event loop too long with synchronous compression.
WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 16 * 1024
class WebSocketWriter:
"""WebSocket writer.
The writer is responsible for sending messages to the client. It is
created by the protocol when a connection is established. The writer
should avoid implementing any application logic and should only be
concerned with the low-level details of the WebSocket protocol.
"""
def __init__(
self,
protocol: BaseProtocol,
transport: asyncio.Transport,
*,
use_mask: bool = False,
limit: int = DEFAULT_CHUNK_SIZE,
random: random.Random = random.Random(),
compress: int = 0,
notakeover: bool = False,
) -> None:
"""Initialize a WebSocket writer."""
self.protocol = protocol
self.transport = transport
self.use_mask = use_mask
self.get_random_bits = partial(random.getrandbits, 32)
self.compress = compress
self.notakeover = notakeover
self._closing = False
self._limit = limit
self._output_size = 0
self._compressobj: Optional[ZLibCompressor] = None
self._send_lock = asyncio.Lock()
self._background_tasks: Set[asyncio.Task[None]] = set()
async def send_frame(
self, message: bytes, opcode: int, compress: int | None = None
) -> None:
"""Send a frame over the websocket with message as its payload."""
if self._closing and not (opcode & WSMsgType.CLOSE):
raise ClientConnectionResetError("Cannot write to closing transport")
if not (compress or self.compress) or opcode >= WS_CONTROL_FRAME_OPCODE:
# Non-compressed frames don't need lock or shield
self._write_websocket_frame(message, opcode, 0)
elif len(message) <= WEBSOCKET_MAX_SYNC_CHUNK_SIZE:
# Small compressed payloads - compress synchronously in event loop
# We need the lock even though sync compression has no await points.
# This prevents small frames from interleaving with large frames that
# compress in the executor, avoiding compressor state corruption.
async with self._send_lock:
self._send_compressed_frame_sync(message, opcode, compress)
else:
# Large compressed frames need shield to prevent corruption
# For large compressed frames, the entire compress+send
# operation must be atomic. If cancelled after compression but
# before send, the compressor state would be advanced but data
# not sent, corrupting subsequent frames.
# Create a task to shield from cancellation
# The lock is acquired inside the shielded task so the entire
# operation (lock + compress + send) completes atomically.
# Use eager_start on Python 3.12+ to avoid scheduling overhead
loop = asyncio.get_running_loop()
coro = self._send_compressed_frame_async_locked(message, opcode, compress)
if sys.version_info >= (3, 12):
send_task = asyncio.Task(coro, loop=loop, eager_start=True)
else:
send_task = loop.create_task(coro)
# Keep a strong reference to prevent garbage collection
self._background_tasks.add(send_task)
send_task.add_done_callback(self._background_tasks.discard)
await asyncio.shield(send_task)
# It is safe to return control to the event loop when using compression
# after this point as we have already sent or buffered all the data.
# Once we have written output_size up to the limit, we call the
# drain helper which waits for the transport to be ready to accept
# more data. This is a flow control mechanism to prevent the buffer
# from growing too large. The drain helper will return right away
# if the writer is not paused.
if self._output_size > self._limit:
self._output_size = 0
if self.protocol._paused:
await self.protocol._drain_helper()
def _write_websocket_frame(self, message: bytes, opcode: int, rsv: int) -> None:
"""
Write a websocket frame to the transport.
This method handles frame header construction, masking, and writing to transport.
It does not handle compression or flow control - those are the responsibility
of the caller.
"""
msg_length = len(message)
use_mask = self.use_mask
mask_bit = 0x80 if use_mask else 0
# Depending on the message length, the header is assembled differently.
# The first byte is reserved for the opcode and the RSV bits.
first_byte = 0x80 | rsv | opcode
if msg_length < 126:
header = PACK_LEN1(first_byte, msg_length | mask_bit)
header_len = 2
elif msg_length < 65536:
header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length)
header_len = 4
else:
header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length)
header_len = 10
if self.transport.is_closing():
raise ClientConnectionResetError("Cannot write to closing transport")
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.3
# If we are using a mask, we need to generate it randomly
# and apply it to the message before sending it. A mask is
# a 32-bit value that is applied to the message using a
# bitwise XOR operation. It is used to prevent certain types
# of attacks on the websocket protocol. The mask is only used
# when aiohttp is acting as a client. Servers do not use a mask.
if use_mask:
mask = PACK_RANDBITS(self.get_random_bits())
message_arr = bytearray(message)
websocket_mask(mask, message_arr)
self.transport.write(header + mask + message_arr)
self._output_size += MASK_LEN
elif msg_length > MSG_SIZE:
self.transport.write(header)
self.transport.write(message)
else:
self.transport.write(header + message)
self._output_size += header_len + msg_length
def _get_compressor(self, compress: int | None) -> ZLibCompressor:
"""Get or create a compressor object for the given compression level."""
if compress:
# Do not set self._compress if compressing is for this frame
return ZLibCompressor(
level=ZLibBackend.Z_BEST_SPEED,
wbits=-compress,
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
)
if not self._compressobj:
self._compressobj = ZLibCompressor(
level=ZLibBackend.Z_BEST_SPEED,
wbits=-self.compress,
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
)
return self._compressobj
def _send_compressed_frame_sync(
self, message: bytes, opcode: int, compress: int | None
) -> None:
"""
Synchronous send for small compressed frames.
This is used for small compressed payloads that compress synchronously in the event loop.
Since there are no await points, this is inherently cancellation-safe.
"""
# RSV are the reserved bits in the frame header. They are used to
# indicate that the frame is using an extension.
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
compressobj = self._get_compressor(compress)
# (0x40) RSV1 is set for compressed frames
# https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
self._write_websocket_frame(
(
compressobj.compress_sync(message)
+ compressobj.flush(
ZLibBackend.Z_FULL_FLUSH
if self.notakeover
else ZLibBackend.Z_SYNC_FLUSH
)
).removesuffix(WS_DEFLATE_TRAILING),
opcode,
0x40,
)
async def _send_compressed_frame_async_locked(
self, message: bytes, opcode: int, compress: int | None
) -> None:
"""
Async send for large compressed frames with lock.
Acquires the lock and compresses large payloads asynchronously in
the executor. The lock is held for the entire operation to ensure
the compressor state is not corrupted by concurrent sends.
MUST be run shielded from cancellation. If cancelled after
compression but before sending, the compressor state would be
advanced but data not sent, corrupting subsequent frames.
"""
async with self._send_lock:
# RSV are the reserved bits in the frame header. They are used to
# indicate that the frame is using an extension.
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
compressobj = self._get_compressor(compress)
# (0x40) RSV1 is set for compressed frames
# https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
self._write_websocket_frame(
(
await compressobj.compress(message)
+ compressobj.flush(
ZLibBackend.Z_FULL_FLUSH
if self.notakeover
else ZLibBackend.Z_SYNC_FLUSH
)
).removesuffix(WS_DEFLATE_TRAILING),
opcode,
0x40,
)
async def close(self, code: int = 1000, message: bytes | str = b"") -> None:
"""Close the websocket, sending the specified code and message."""
if isinstance(message, str):
message = message.encode("utf-8")
try:
await self.send_frame(
PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
)
finally:
self._closing = True

View File

@ -0,0 +1,270 @@
import asyncio
import logging
import socket
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable, Generator, Iterable, Sequence, Sized
from http.cookies import BaseCookie, Morsel, SimpleCookie
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, TypedDict
from multidict import CIMultiDict
from yarl import URL
from ._cookie_helpers import parse_set_cookie_headers
from .typedefs import LooseCookies
if TYPE_CHECKING:
from .web_app import Application
from .web_exceptions import HTTPException
from .web_request import BaseRequest, Request
from .web_response import StreamResponse
else:
BaseRequest = Request = Application = StreamResponse = Any
HTTPException = Any
class AbstractRouter(ABC):
def __init__(self) -> None:
self._frozen = False
def post_init(self, app: Application) -> None:
"""Post init stage.
Not an abstract method for sake of backward compatibility,
but if the router wants to be aware of the application
it can override this.
"""
@property
def frozen(self) -> bool:
return self._frozen
def freeze(self) -> None:
"""Freeze router."""
self._frozen = True
@abstractmethod
async def resolve(self, request: Request) -> "AbstractMatchInfo":
"""Return MATCH_INFO for given request"""
class AbstractMatchInfo(ABC):
__slots__ = ()
@property # pragma: no branch
@abstractmethod
def handler(self) -> Callable[[Request], Awaitable[StreamResponse]]:
"""Execute matched request handler"""
@property
@abstractmethod
def expect_handler(
self,
) -> Callable[[Request], Awaitable[StreamResponse | None]]:
"""Expect handler for 100-continue processing"""
@property # pragma: no branch
@abstractmethod
def http_exception(self) -> HTTPException | None:
"""HTTPException instance raised on router's resolving, or None"""
@abstractmethod # pragma: no branch
def get_info(self) -> dict[str, Any]:
"""Return a dict with additional info useful for introspection"""
@property # pragma: no branch
@abstractmethod
def apps(self) -> tuple[Application, ...]:
"""Stack of nested applications.
Top level application is left-most element.
"""
@abstractmethod
def add_app(self, app: Application) -> None:
"""Add application to the nested apps stack."""
@abstractmethod
def freeze(self) -> None:
"""Freeze the match info.
The method is called after route resolution.
After the call .add_app() is forbidden.
"""
class AbstractView(ABC):
"""Abstract class based view."""
def __init__(self, request: Request) -> None:
self._request = request
@property
def request(self) -> Request:
"""Request instance."""
return self._request
@abstractmethod
def __await__(self) -> Generator[None, None, StreamResponse]:
"""Execute the view handler."""
class ResolveResult(TypedDict):
"""Resolve result.
This is the result returned from an AbstractResolver's
resolve method.
:param hostname: The hostname that was provided.
:param host: The IP address that was resolved.
:param port: The port that was resolved.
:param family: The address family that was resolved.
:param proto: The protocol that was resolved.
:param flags: The flags that were resolved.
"""
hostname: str
host: str
port: int
family: int
proto: int
flags: int
class AbstractResolver(ABC):
"""Abstract DNS resolver."""
@abstractmethod
async def resolve(
self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
) -> list[ResolveResult]:
"""Return IP address for given hostname"""
@abstractmethod
async def close(self) -> None:
"""Release resolver"""
if TYPE_CHECKING:
IterableBase = Iterable[Morsel[str]]
else:
IterableBase = Iterable
ClearCookiePredicate = Callable[["Morsel[str]"], bool]
class AbstractCookieJar(Sized, IterableBase):
"""Abstract Cookie Jar."""
def __init__(self, *, loop: asyncio.AbstractEventLoop | None = None) -> None:
self._loop = loop or asyncio.get_running_loop()
@property
@abstractmethod
def unsafe(self) -> bool:
"""Return True if cookies can be used with IP addresses."""
@property
@abstractmethod
def quote_cookie(self) -> bool:
"""Return True if cookies should be quoted."""
@property
@abstractmethod
def cookies(self) -> MappingProxyType[tuple[str, str], SimpleCookie]:
"""Return the cookies stored in this jar."""
@property
@abstractmethod
def host_only_cookies(self) -> frozenset[tuple[str, str]]:
"""Return the host-only cookies stored in this jar."""
@abstractmethod
def clear(self, predicate: ClearCookiePredicate | None = None) -> None:
"""Clear all cookies if no predicate is passed."""
@abstractmethod
def clear_domain(self, domain: str) -> None:
"""Clear all cookies for domain and all subdomains."""
@abstractmethod
def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
"""Update cookies."""
def update_cookies_from_headers(
self, headers: Sequence[str], response_url: URL
) -> None:
"""Update cookies from raw Set-Cookie headers."""
if headers and (cookies_to_update := parse_set_cookie_headers(headers)):
self.update_cookies(cookies_to_update, response_url)
@abstractmethod
def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
"""Return the jar's cookies filtered by their attributes."""
class AbstractStreamWriter(ABC):
"""Abstract stream writer."""
buffer_size: int = 0
output_size: int = 0
length: int | None = 0
@abstractmethod
async def write(self, chunk: bytes | bytearray | memoryview) -> None:
"""Write chunk into stream."""
@abstractmethod
async def write_eof(self, chunk: bytes = b"") -> None:
"""Write last chunk."""
@abstractmethod
async def drain(self) -> None:
"""Flush the write buffer."""
@abstractmethod
def enable_compression(
self, encoding: str = "deflate", strategy: int | None = None
) -> None:
"""Enable HTTP body compression"""
@abstractmethod
def enable_chunking(self) -> None:
"""Enable HTTP chunked mode"""
@abstractmethod
async def write_headers(
self, status_line: str, headers: "CIMultiDict[str]"
) -> None:
"""Write HTTP headers"""
def send_headers(self) -> None:
"""Force sending buffered headers if not already sent.
Required only if write_headers() buffers headers instead of sending immediately.
For backwards compatibility, this method does nothing by default.
"""
class AbstractAccessLogger(ABC):
"""Abstract writer to access log."""
__slots__ = ("logger", "log_format")
def __init__(self, logger: logging.Logger, log_format: str) -> None:
self.logger = logger
self.log_format = log_format
@abstractmethod
def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None:
"""Emit log to logger."""
@property
def enabled(self) -> bool:
"""Check if logger is enabled."""
return True

View File

@ -0,0 +1,140 @@
import asyncio
from typing import TYPE_CHECKING, Any, cast
from .client_exceptions import ClientConnectionResetError
from .helpers import set_exception
from .tcp_helpers import tcp_nodelay
if TYPE_CHECKING:
from .http_parser import HttpParser
# Raised by transport.pause_reading()/resume_reading() when the transport
# does not support flow control; safe to ignore.
# NOTE: Catch these with a plain try/except/pass, never contextlib.suppress():
# pause/resume run on the hot read path and suppress() is ~6x slower than
# try/except here (it builds a context manager and unpacks this tuple per call).
PAUSE_RESUME_READING_ERRORS = (AttributeError, NotImplementedError, RuntimeError)
class BaseProtocol(asyncio.Protocol):
__slots__ = (
"_loop",
"_paused",
"_parser",
"_drain_waiter",
"_connection_lost",
"_reading_paused",
"_upgraded",
"transport",
)
def __init__(
self, loop: asyncio.AbstractEventLoop, parser: "HttpParser[Any] | None" = None
) -> None:
self._loop: asyncio.AbstractEventLoop = loop
self._paused = False
self._drain_waiter: asyncio.Future[None] | None = None
self._reading_paused = False
self._parser = parser
self._upgraded = False
self.transport: asyncio.Transport | None = None
@property
def connected(self) -> bool:
"""Return True if the connection is open."""
return self.transport is not None
@property
def writing_paused(self) -> bool:
return self._paused
def pause_writing(self) -> None:
assert not self._paused
self._paused = True
def resume_writing(self) -> None:
assert self._paused
self._paused = False
waiter = self._drain_waiter
if waiter is not None:
self._drain_waiter = None
if not waiter.done():
waiter.set_result(None)
def pause_reading(self) -> None:
self._reading_paused = True
# Parser shouldn't be paused on websockets.
if not self._upgraded:
assert self._parser is not None
self._parser.pause_reading()
if self.transport is not None:
try:
self.transport.pause_reading()
except PAUSE_RESUME_READING_ERRORS:
# Transport lacks flow control; nothing to pause. Intentionally
# ignored (see PAUSE_RESUME_READING_ERRORS; do not use suppress).
pass
def _reading_paused_for_msg_queue(self) -> bool:
"""Keep the transport paused for protocol-specific reasons (overridden)."""
return False
def resume_reading(self, resume_parser: bool = True) -> None:
self._reading_paused = False
# This will resume parsing any unprocessed data from the last pause.
if not self._upgraded and resume_parser:
self.data_received(b"")
# Reading may have been paused again in the above call if there was a lot of
# compressed data still pending.
if (
not self._reading_paused
and not self._reading_paused_for_msg_queue()
and self.transport is not None
):
try:
self.transport.resume_reading()
except PAUSE_RESUME_READING_ERRORS:
# Transport lacks flow control; nothing to resume. Intentionally
# ignored (see PAUSE_RESUME_READING_ERRORS; do not use suppress).
pass
self._reading_paused = False
def connection_made(self, transport: asyncio.BaseTransport) -> None:
tr = cast(asyncio.Transport, transport)
tcp_nodelay(tr, True)
self.transport = tr
def connection_lost(self, exc: BaseException | None) -> None:
# Wake up the writer if currently paused.
self.transport = None
if not self._paused:
return
waiter = self._drain_waiter
if waiter is None:
return
self._drain_waiter = None
if waiter.done():
return
if exc is None:
waiter.set_result(None)
else:
set_exception(
waiter,
ConnectionError("Connection lost"),
exc,
)
async def _drain_helper(self) -> None:
if self.transport is None:
raise ClientConnectionResetError("Connection lost")
if not self._paused:
return
waiter = self._drain_waiter
if waiter is None:
waiter = self._loop.create_future()
self._drain_waiter = waiter
await waiter

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,426 @@
"""HTTP related errors."""
import asyncio
import warnings
from typing import TYPE_CHECKING, Union
from multidict import MultiMapping
from .typedefs import StrOrURL
if TYPE_CHECKING:
import ssl
SSLContext = ssl.SSLContext
else:
try:
import ssl
SSLContext = ssl.SSLContext
except ImportError: # pragma: no cover
ssl = SSLContext = None # type: ignore[assignment]
if TYPE_CHECKING:
from .client_reqrep import ClientResponse, ConnectionKey, Fingerprint, RequestInfo
from .http_parser import RawResponseMessage
else:
RequestInfo = ClientResponse = ConnectionKey = RawResponseMessage = None
__all__ = (
"ClientError",
"ClientConnectionError",
"ClientConnectionResetError",
"ClientOSError",
"ClientConnectorError",
"ClientProxyConnectionError",
"ClientSSLError",
"ClientConnectorDNSError",
"ClientConnectorSSLError",
"ClientConnectorCertificateError",
"ConnectionTimeoutError",
"SocketTimeoutError",
"ServerConnectionError",
"ServerTimeoutError",
"ServerDisconnectedError",
"ServerFingerprintMismatch",
"ClientResponseError",
"ClientHttpProxyError",
"WSServerHandshakeError",
"ContentTypeError",
"ClientPayloadError",
"InvalidURL",
"InvalidUrlClientError",
"RedirectClientError",
"NonHttpUrlClientError",
"InvalidUrlRedirectClientError",
"NonHttpUrlRedirectClientError",
"WSMessageTypeError",
)
class ClientError(Exception):
"""Base class for client connection errors."""
class ClientResponseError(ClientError):
"""Base class for exceptions that occur after getting a response.
request_info: An instance of RequestInfo.
history: A sequence of responses, if redirects occurred.
status: HTTP status code.
message: Error message.
headers: Response headers.
"""
def __init__(
self,
request_info: RequestInfo,
history: tuple[ClientResponse, ...],
*,
code: int | None = None,
status: int | None = None,
message: str = "",
headers: MultiMapping[str] | None = None,
) -> None:
self.request_info = request_info
if code is not None:
if status is not None:
raise ValueError(
"Both code and status arguments are provided; "
"code is deprecated, use status instead"
)
warnings.warn(
"code argument is deprecated, use status instead",
DeprecationWarning,
stacklevel=2,
)
if status is not None:
self.status = status
elif code is not None:
self.status = code
else:
self.status = 0
self.message = message
self.headers = headers
self.history = history
self.args = (request_info, history)
def __str__(self) -> str:
return f"{self.status}, message={self.message!r}, url={str(self.request_info.real_url)!r}"
def __repr__(self) -> str:
args = f"{self.request_info!r}, {self.history!r}"
if self.status != 0:
args += f", status={self.status!r}"
if self.message != "":
args += f", message={self.message!r}"
if self.headers is not None:
args += f", headers={self.headers!r}"
return f"{type(self).__name__}({args})"
@property
def code(self) -> int:
warnings.warn(
"code property is deprecated, use status instead",
DeprecationWarning,
stacklevel=2,
)
return self.status
@code.setter
def code(self, value: int) -> None:
warnings.warn(
"code property is deprecated, use status instead",
DeprecationWarning,
stacklevel=2,
)
self.status = value
class ContentTypeError(ClientResponseError):
"""ContentType found is not valid."""
class WSServerHandshakeError(ClientResponseError):
"""websocket server handshake error."""
class ClientHttpProxyError(ClientResponseError):
"""HTTP proxy error.
Raised in :class:`aiohttp.connector.TCPConnector` if
proxy responds with status other than ``200 OK``
on ``CONNECT`` request.
"""
class TooManyRedirects(ClientResponseError):
"""Client was redirected too many times."""
class ClientConnectionError(ClientError):
"""Base class for client socket errors."""
class ClientConnectionResetError(ClientConnectionError, ConnectionResetError):
"""ConnectionResetError"""
class ClientOSError(ClientConnectionError, OSError):
"""OSError error."""
class ClientConnectorError(ClientOSError):
"""Client connector error.
Raised in :class:`aiohttp.connector.TCPConnector` if
a connection can not be established.
"""
def __init__(self, connection_key: ConnectionKey, os_error: OSError) -> None:
self._conn_key = connection_key
self._os_error = os_error
super().__init__(os_error.errno, os_error.strerror)
self.args = (connection_key, os_error)
@property
def os_error(self) -> OSError:
return self._os_error
@property
def host(self) -> str:
return self._conn_key.host
@property
def port(self) -> int | None:
return self._conn_key.port
@property
def ssl(self) -> Union[SSLContext, bool, "Fingerprint"]:
return self._conn_key.ssl
def __str__(self) -> str:
return "Cannot connect to host {0.host}:{0.port} ssl:{1} [{2}]".format(
self, "default" if self.ssl is True else self.ssl, self.strerror
)
# OSError.__reduce__ does too much black magick
__reduce__ = BaseException.__reduce__
class ClientConnectorDNSError(ClientConnectorError):
"""DNS resolution failed during client connection.
Raised in :class:`aiohttp.connector.TCPConnector` if
DNS resolution fails.
"""
class ClientProxyConnectionError(ClientConnectorError):
"""Proxy connection error.
Raised in :class:`aiohttp.connector.TCPConnector` if
connection to proxy can not be established.
"""
class UnixClientConnectorError(ClientConnectorError):
"""Unix connector error.
Raised in :py:class:`aiohttp.connector.UnixConnector`
if connection to unix socket can not be established.
"""
def __init__(
self, path: str, connection_key: ConnectionKey, os_error: OSError
) -> None:
self._path = path
super().__init__(connection_key, os_error)
@property
def path(self) -> str:
return self._path
def __str__(self) -> str:
return "Cannot connect to unix socket {0.path} ssl:{1} [{2}]".format(
self, "default" if self.ssl is True else self.ssl, self.strerror
)
class ServerConnectionError(ClientConnectionError):
"""Server connection errors."""
class ServerDisconnectedError(ServerConnectionError):
"""Server disconnected."""
def __init__(self, message: RawResponseMessage | str | None = None) -> None:
if message is None:
message = "Server disconnected"
self.args = (message,)
self.message = message
class ServerTimeoutError(ServerConnectionError, asyncio.TimeoutError):
"""Server timeout error."""
class ConnectionTimeoutError(ServerTimeoutError):
"""Connection timeout error."""
class SocketTimeoutError(ServerTimeoutError):
"""Socket timeout error."""
class ServerFingerprintMismatch(ServerConnectionError):
"""SSL certificate does not match expected fingerprint."""
def __init__(self, expected: bytes, got: bytes, host: str, port: int) -> None:
self.expected = expected
self.got = got
self.host = host
self.port = port
self.args = (expected, got, host, port)
def __repr__(self) -> str:
return f"<{self.__class__.__name__} expected={self.expected!r} got={self.got!r} host={self.host!r} port={self.port!r}>"
class ClientPayloadError(ClientError):
"""Response payload error."""
class InvalidURL(ClientError, ValueError):
"""Invalid URL.
URL used for fetching is malformed, e.g. it doesn't contains host
part.
"""
# Derive from ValueError for backward compatibility
def __init__(self, url: StrOrURL, description: str | None = None) -> None:
# The type of url is not yarl.URL because the exception can be raised
# on URL(url) call
self._url = url
self._description = description
if description:
super().__init__(url, description)
else:
super().__init__(url)
@property
def url(self) -> StrOrURL:
return self._url
@property
def description(self) -> "str | None":
return self._description
def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self}>"
def __str__(self) -> str:
if self._description:
return f"{self._url} - {self._description}"
return str(self._url)
class InvalidUrlClientError(InvalidURL):
"""Invalid URL client error."""
class RedirectClientError(ClientError):
"""Client redirect error."""
class NonHttpUrlClientError(ClientError):
"""Non http URL client error."""
class InvalidUrlRedirectClientError(InvalidUrlClientError, RedirectClientError):
"""Invalid URL redirect client error."""
class NonHttpUrlRedirectClientError(NonHttpUrlClientError, RedirectClientError):
"""Non http URL redirect client error."""
class ClientSSLError(ClientConnectorError):
"""Base error for ssl.*Errors."""
if ssl is not None:
cert_errors = (ssl.CertificateError,)
cert_errors_bases = (
ClientSSLError,
ssl.CertificateError,
)
ssl_errors = (ssl.SSLError,)
ssl_error_bases = (ClientSSLError, ssl.SSLError)
else: # pragma: no cover
cert_errors = tuple()
cert_errors_bases = (
ClientSSLError,
ValueError,
)
ssl_errors = tuple()
ssl_error_bases = (ClientSSLError,)
class ClientConnectorSSLError(*ssl_error_bases): # type: ignore[misc]
"""Response ssl error."""
class ClientConnectorCertificateError(*cert_errors_bases): # type: ignore[misc]
"""Response certificate error."""
_conn_key: ConnectionKey
def __init__(
# TODO: If we require ssl in future, this can become ssl.CertificateError
self,
connection_key: ConnectionKey,
certificate_error: Exception,
) -> None:
if isinstance(certificate_error, cert_errors + (OSError,)):
# ssl.CertificateError has errno and strerror, so we should be fine
os_error = certificate_error
else:
os_error = OSError()
super().__init__(connection_key, os_error)
self._certificate_error = certificate_error
self.args = (connection_key, certificate_error)
@property
def certificate_error(self) -> Exception:
return self._certificate_error
@property
def host(self) -> str:
return self._conn_key.host
@property
def port(self) -> int | None:
return self._conn_key.port
@property
def ssl(self) -> bool:
return self._conn_key.is_ssl
def __str__(self) -> str:
return (
f"Cannot connect to host {self.host}:{self.port} ssl:{self.ssl} "
f"[{self.certificate_error.__class__.__name__}: "
f"{self.certificate_error.args}]"
)
class WSMessageTypeError(TypeError):
"""WebSocket message type is not valid."""

View File

@ -0,0 +1,494 @@
"""
Digest authentication middleware for aiohttp client.
This middleware implements HTTP Digest Authentication according to RFC 7616,
providing a more secure alternative to Basic Authentication. It supports all
standard hash algorithms including MD5, SHA, SHA-256, SHA-512 and their session
variants, as well as both 'auth' and 'auth-int' quality of protection (qop) options.
"""
import hashlib
import os
import re
import sys
import time
from collections.abc import Callable
from typing import Final, Literal, TypedDict
from yarl import URL
from . import hdrs
from .client_exceptions import ClientError
from .client_middlewares import ClientHandlerType
from .client_reqrep import ClientRequest, ClientResponse
from .payload import Payload
class DigestAuthChallenge(TypedDict, total=False):
realm: str
nonce: str
qop: str
algorithm: str
opaque: str
domain: str
stale: str
DigestFunctions: dict[str, Callable[[bytes], "hashlib._Hash"]] = {
"MD5": hashlib.md5,
"MD5-SESS": hashlib.md5,
"SHA": hashlib.sha1,
"SHA-SESS": hashlib.sha1,
"SHA256": hashlib.sha256,
"SHA256-SESS": hashlib.sha256,
"SHA-256": hashlib.sha256,
"SHA-256-SESS": hashlib.sha256,
"SHA512": hashlib.sha512,
"SHA512-SESS": hashlib.sha512,
"SHA-512": hashlib.sha512,
"SHA-512-SESS": hashlib.sha512,
}
# Compile the regex pattern once at module level for performance
_HEADER_PAIRS_PATTERN = re.compile(
r'(?:^|\s|,\s*)(\w+)\s*=\s*(?:"((?:[^"\\]|\\.)*)"|([^\s,]+))'
if sys.version_info < (3, 11)
else r'(?:^|\s|,\s*)((?>\w+))\s*=\s*(?:"((?:[^"\\]|\\.)*)"|([^\s,]+))'
# +------------|--------|--|-|-|--|----|------|----|--||-----|-> Match valid start/sep
# +--------|--|-|-|--|----|------|----|--||-----|-> alphanumeric key (atomic
# | | | | | | | | || | group reduces backtracking)
# +--|-|-|--|----|------|----|--||-----|-> maybe whitespace
# | | | | | | | || |
# +-|-|--|----|------|----|--||-----|-> = (delimiter)
# +-|--|----|------|----|--||-----|-> maybe whitespace
# | | | | | || |
# +--|----|------|----|--||-----|-> group quoted or unquoted
# | | | | || |
# +----|------|----|--||-----|-> if quoted...
# +------|----|--||-----|-> anything but " or \
# +----|--||-----|-> escaped characters allowed
# +--||-----|-> or can be empty string
# || |
# +|-----|-> if unquoted...
# +-----|-> anything but , or <space>
# +-> at least one char req'd
)
# RFC 7616: Challenge parameters to extract
CHALLENGE_FIELDS: Final[
tuple[
Literal["realm", "nonce", "qop", "algorithm", "opaque", "domain", "stale"], ...
]
] = (
"realm",
"nonce",
"qop",
"algorithm",
"opaque",
"domain",
"stale",
)
# Supported digest authentication algorithms
# Use a tuple of sorted keys for predictable documentation and error messages
SUPPORTED_ALGORITHMS: Final[tuple[str, ...]] = tuple(sorted(DigestFunctions.keys()))
# RFC 7616: Fields that require quoting in the Digest auth header
# These fields must be enclosed in double quotes in the Authorization header.
# Algorithm, qop, and nc are never quoted per RFC specifications.
# This frozen set is used by the template-based header construction to
# automatically determine which fields need quotes.
QUOTED_AUTH_FIELDS: Final[frozenset[str]] = frozenset(
{"username", "realm", "nonce", "uri", "response", "opaque", "cnonce"}
)
def escape_quotes(value: str) -> str:
"""Escape double quotes for HTTP header values."""
return value.replace('"', '\\"')
def unescape_quotes(value: str) -> str:
"""Unescape double quotes in HTTP header values."""
return value.replace('\\"', '"')
def parse_header_pairs(header: str) -> dict[str, str]:
"""
Parse key-value pairs from WWW-Authenticate or similar HTTP headers.
This function handles the complex format of WWW-Authenticate header values,
supporting both quoted and unquoted values, proper handling of commas in
quoted values, and whitespace variations per RFC 7616.
Examples of supported formats:
- key1="value1", key2=value2
- key1 = "value1" , key2="value, with, commas"
- key1=value1,key2="value2"
- realm="example.com", nonce="12345", qop="auth"
Args:
header: The header value string to parse
Returns:
Dictionary mapping parameter names to their values
"""
return {
stripped_key: unescape_quotes(quoted_val) if quoted_val else unquoted_val
for key, quoted_val, unquoted_val in _HEADER_PAIRS_PATTERN.findall(header)
if (stripped_key := key.strip())
}
class DigestAuthMiddleware:
"""
HTTP digest authentication middleware for aiohttp client.
This middleware intercepts 401 Unauthorized responses containing a Digest
authentication challenge, calculates the appropriate digest credentials,
and automatically retries the request with the proper Authorization header.
Features:
- Handles all aspects of Digest authentication handshake automatically
- Supports all standard hash algorithms:
- MD5, MD5-SESS
- SHA, SHA-SESS
- SHA256, SHA256-SESS, SHA-256, SHA-256-SESS
- SHA512, SHA512-SESS, SHA-512, SHA-512-SESS
- Supports 'auth' and 'auth-int' quality of protection modes
- Properly handles quoted strings and parameter parsing
- Includes replay attack protection with client nonce count tracking
- Supports preemptive authentication per RFC 7616 Section 3.6
Origin scoping:
The credentials are scoped to the origin of the first request the
middleware handles. A request to a different origin is passed through
untouched, so it never receives a digest response computed from those
credentials, unless that origin falls within a protection space the
anchor origin advertised through the RFC 7616 ``domain`` directive. Make
the first request through the middleware against the intended origin, as
the anchor is pinned to it and not reset for the life of the instance.
Standards compliance:
- RFC 7616: HTTP Digest Access Authentication (primary reference)
- RFC 2617: HTTP Authentication (deprecated by RFC 7616)
- RFC 1945: Section 11.1 (username restrictions)
Implementation notes:
The core digest calculation is inspired by the implementation in
https://github.com/requests/requests/blob/v2.18.4/requests/auth.py
with added support for modern digest auth features and error handling.
"""
def __init__(
self,
login: str,
password: str,
preemptive: bool = True,
) -> None:
if login is None:
raise ValueError("None is not allowed as login value")
if password is None:
raise ValueError("None is not allowed as password value")
if ":" in login:
raise ValueError('A ":" is not allowed in username (RFC 1945#section-11.1)')
self._login_str: Final[str] = login
self._login_bytes: Final[bytes] = login.encode("utf-8")
self._password_bytes: Final[bytes] = password.encode("utf-8")
self._last_nonce_bytes = b""
self._nonce_count = 0
self._challenge: DigestAuthChallenge = {}
self._preemptive: bool = preemptive
# Set of URLs defining the protection space
self._protection_space: list[str] = []
# Origin the credentials are scoped to; set on the first request.
self._origin: URL | None = None
async def _encode(self, method: str, url: URL, body: Payload | Literal[b""]) -> str:
"""
Build digest authorization header for the current challenge.
Args:
method: The HTTP method (GET, POST, etc.)
url: The request URL
body: The request body (used for qop=auth-int)
Returns:
A fully formatted Digest authorization header string
Raises:
ClientError: If the challenge is missing required parameters or
contains unsupported values
"""
challenge = self._challenge
if "realm" not in challenge:
raise ClientError(
"Malformed Digest auth challenge: Missing 'realm' parameter"
)
if "nonce" not in challenge:
raise ClientError(
"Malformed Digest auth challenge: Missing 'nonce' parameter"
)
# Empty realm values are allowed per RFC 7616 (SHOULD, not MUST, contain host name)
realm = challenge["realm"]
nonce = challenge["nonce"]
# Empty nonce values are not allowed as they are security-critical for replay protection
if not nonce:
raise ClientError(
"Security issue: Digest auth challenge contains empty 'nonce' value"
)
qop_raw = challenge.get("qop", "")
# Preserve original algorithm case for response while using uppercase for processing
algorithm_original = challenge.get("algorithm", "MD5")
algorithm = algorithm_original.upper()
opaque = challenge.get("opaque", "")
# Convert string values to bytes once
nonce_bytes = nonce.encode("utf-8")
realm_bytes = realm.encode("utf-8")
# Use the encoded request-target (raw_path_qs) since that is what is
# transmitted on the wire and what the server signs against. Using the
# decoded form would cause digest verification to fail when the path
# or query string contains percent-encoded reserved characters.
path = URL(url).raw_path_qs
# Process QoP
qop = ""
qop_bytes = b""
if qop_raw:
valid_qops = {"auth", "auth-int"}.intersection(
{q.strip() for q in qop_raw.split(",") if q.strip()}
)
if not valid_qops:
raise ClientError(
f"Digest auth error: Unsupported Quality of Protection (qop) value(s): {qop_raw}"
)
qop = "auth-int" if "auth-int" in valid_qops else "auth"
qop_bytes = qop.encode("utf-8")
if algorithm not in DigestFunctions:
raise ClientError(
f"Digest auth error: Unsupported hash algorithm: {algorithm}. "
f"Supported algorithms: {', '.join(SUPPORTED_ALGORITHMS)}"
)
hash_fn: Final = DigestFunctions[algorithm]
def H(x: bytes) -> bytes:
"""RFC 7616 Section 3: Hash function H(data) = hex(hash(data))."""
return hash_fn(x).hexdigest().encode()
def KD(s: bytes, d: bytes) -> bytes:
"""RFC 7616 Section 3: KD(secret, data) = H(concat(secret, ":", data))."""
return H(b":".join((s, d)))
# Calculate A1 and A2
A1 = b":".join((self._login_bytes, realm_bytes, self._password_bytes))
A2 = f"{method.upper()}:{path}".encode()
if qop == "auth-int":
if isinstance(body, Payload): # will always be empty bytes unless Payload
entity_bytes = await body.as_bytes() # Get bytes from Payload
else:
entity_bytes = body
entity_hash = H(entity_bytes)
A2 = b":".join((A2, entity_hash))
HA1 = H(A1)
HA2 = H(A2)
# Nonce count handling
if nonce_bytes == self._last_nonce_bytes:
self._nonce_count += 1
else:
self._nonce_count = 1
self._last_nonce_bytes = nonce_bytes
ncvalue = f"{self._nonce_count:08x}"
ncvalue_bytes = ncvalue.encode("utf-8")
# Generate client nonce
cnonce = hashlib.sha1(
b"".join(
[
str(self._nonce_count).encode("utf-8"),
nonce_bytes,
time.ctime().encode("utf-8"),
os.urandom(8),
]
)
).hexdigest()[:16]
cnonce_bytes = cnonce.encode("utf-8")
# Special handling for session-based algorithms
if algorithm.upper().endswith("-SESS"):
HA1 = H(b":".join((HA1, nonce_bytes, cnonce_bytes)))
# Calculate the response digest
if qop:
noncebit = b":".join(
(nonce_bytes, ncvalue_bytes, cnonce_bytes, qop_bytes, HA2)
)
response_digest = KD(HA1, noncebit)
else:
response_digest = KD(HA1, b":".join((nonce_bytes, HA2)))
# Define a dict mapping of header fields to their values
# Group fields into always-present, optional, and qop-dependent
header_fields = {
# Always present fields
"username": escape_quotes(self._login_str),
"realm": escape_quotes(realm),
"nonce": escape_quotes(nonce),
"uri": path,
"response": response_digest.decode(),
"algorithm": algorithm_original,
}
# Optional fields
if opaque:
header_fields["opaque"] = escape_quotes(opaque)
# QoP-dependent fields
if qop:
header_fields["qop"] = qop
header_fields["nc"] = ncvalue
header_fields["cnonce"] = cnonce
# Build header using templates for each field type
pairs: list[str] = []
for field, value in header_fields.items():
if field in QUOTED_AUTH_FIELDS:
pairs.append(f'{field}="{value}"')
else:
pairs.append(f"{field}={value}")
return f"Digest {', '.join(pairs)}"
def _in_protection_space(self, url: URL) -> bool:
"""
Check if the given URL is within the current protection space.
According to RFC 7616, a URI is in the protection space if any URI
in the protection space is a prefix of it (after both have been made absolute).
"""
request_str = str(url)
for space_str in self._protection_space:
# Check if request starts with space URL
if not request_str.startswith(space_str):
continue
# Exact match or space ends with / (proper directory prefix)
if len(request_str) == len(space_str) or space_str[-1] == "/":
return True
# Check next char is / to ensure proper path boundary
if request_str[len(space_str)] == "/":
return True
return False
def _authenticate(self, response: ClientResponse) -> bool:
"""
Takes the given response and tries digest-auth, if needed.
Returns true if the original request must be resent.
"""
if response.status != 401:
return False
auth_header = response.headers.get("www-authenticate", "")
if not auth_header:
return False # No authentication header present
method, sep, headers = auth_header.partition(" ")
if not sep:
# No space found in www-authenticate header
return False # Malformed auth header, missing scheme separator
if method.lower() != "digest":
# Not a digest auth challenge (could be Basic, Bearer, etc.)
return False
if not headers:
# We have a digest scheme but no parameters
return False # Malformed digest header, missing parameters
# We have a digest auth header with content
if not (header_pairs := parse_header_pairs(headers)):
# Failed to parse any key-value pairs
return False # Malformed digest header, no valid parameters
# Extract challenge parameters
self._challenge = {}
for field in CHALLENGE_FIELDS:
if (value := header_pairs.get(field)) is not None:
self._challenge[field] = value
# Update protection space based on domain parameter or default to origin
origin = response.url.origin()
if domain := self._challenge.get("domain"):
# Parse space-separated list of URIs
self._protection_space = []
for uri in domain.split():
# Remove quotes if present
uri = uri.strip('"')
if uri.startswith("/"):
# Path-absolute, relative to origin
self._protection_space.append(str(origin.join(URL(uri))))
else:
# Absolute URI
self._protection_space.append(str(URL(uri)))
else:
# No domain specified, protection space is entire origin
self._protection_space = [str(origin)]
# Return True only if we found at least one challenge parameter
return bool(self._challenge)
async def __call__(
self, request: ClientRequest, handler: ClientHandlerType
) -> ClientResponse:
"""Run the digest auth middleware."""
# Credentials are scoped to the first request's origin. Other origins
# pass through untouched unless a challenge from the anchor origin
# advertised them via RFC 7616 domain; mirrors aiohttp stripping
# Authorization on cross-origin redirects.
origin = request.url.origin()
if self._origin is None:
self._origin = origin
elif origin != self._origin and not self._in_protection_space(request.url):
return await handler(request)
response = None
for retry_count in range(2):
# Apply authorization header if:
# 1. This is a retry after 401 (retry_count > 0), OR
# 2. Preemptive auth is enabled AND we have a challenge AND the URL is in protection space
if retry_count > 0 or (
self._preemptive
and self._challenge
and self._in_protection_space(request.url)
):
request.headers[hdrs.AUTHORIZATION] = await self._encode(
request.method, request.url, request.body
)
# Send the request
response = await handler(request)
# Check if we need to authenticate
if not self._authenticate(response):
break
# At this point, response is guaranteed to be defined
assert response is not None
return response

View File

@ -0,0 +1,55 @@
"""Client middleware support."""
from collections.abc import Awaitable, Callable, Sequence
from .client_reqrep import ClientRequest, ClientResponse
__all__ = ("ClientMiddlewareType", "ClientHandlerType", "build_client_middlewares")
# Type alias for client request handlers - functions that process requests and return responses
ClientHandlerType = Callable[[ClientRequest], Awaitable[ClientResponse]]
# Type for client middleware - similar to server but uses ClientRequest/ClientResponse
ClientMiddlewareType = Callable[
[ClientRequest, ClientHandlerType], Awaitable[ClientResponse]
]
def build_client_middlewares(
handler: ClientHandlerType,
middlewares: Sequence[ClientMiddlewareType],
) -> ClientHandlerType:
"""
Apply middlewares to request handler.
The middlewares are applied in reverse order, so the first middleware
in the list wraps all subsequent middlewares and the handler.
This implementation avoids using partial/update_wrapper to minimize overhead
and doesn't cache to avoid holding references to stateful middleware.
"""
# Optimize for single middleware case
if len(middlewares) == 1:
middleware = middlewares[0]
async def single_middleware_handler(req: ClientRequest) -> ClientResponse:
return await middleware(req, handler)
return single_middleware_handler
# Build the chain for multiple middlewares
current_handler = handler
for middleware in reversed(middlewares):
# Create a new closure that captures the current state
def make_wrapper(
mw: ClientMiddlewareType, next_h: ClientHandlerType
) -> ClientHandlerType:
async def wrapped(req: ClientRequest) -> ClientResponse:
return await mw(req, next_h)
return wrapped
current_handler = make_wrapper(middleware, current_handler)
return current_handler

View File

@ -0,0 +1,370 @@
import asyncio
from contextlib import suppress
from typing import Any, Callable
from .base_protocol import BaseProtocol
from .client_exceptions import (
ClientConnectionError,
ClientOSError,
ClientPayloadError,
ServerDisconnectedError,
SocketTimeoutError,
)
from .helpers import (
_EXC_SENTINEL,
DEFAULT_CHUNK_SIZE,
EMPTY_BODY_STATUS_CODES,
BaseTimerContext,
set_exception,
set_result,
)
from .http import HttpResponseParser, RawResponseMessage
from .http_exceptions import HttpProcessingError
from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader
class ResponseHandler(BaseProtocol, DataQueue[tuple[RawResponseMessage, StreamReader]]):
"""Helper class to adapt between Protocol and StreamReader."""
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
BaseProtocol.__init__(self, loop=loop, parser=None)
DataQueue.__init__(self, loop)
self._should_close = False
self._payload: StreamReader | None = None
self._skip_payload = False
self._payload_parser = None
self._data_received_cb: Callable[[], None] | None = None
self._timer = None
self._tail = b""
self._read_timeout: float | None = None
self._read_timeout_handle: asyncio.TimerHandle | None = None
self._timeout_ceil_threshold: float | None = 5
self._closed: None | asyncio.Future[None] = None
self._connection_lost_called = False
@property
def closed(self) -> None | asyncio.Future[None]:
"""Future that is set when the connection is closed.
This property returns a Future that will be completed when the connection
is closed. The Future is created lazily on first access to avoid creating
futures that will never be awaited.
Returns:
- A Future[None] if the connection is still open or was closed after
this property was accessed
- None if connection_lost() was already called before this property
was ever accessed (indicating no one is waiting for the closure)
"""
if self._closed is None and not self._connection_lost_called:
self._closed = self._loop.create_future()
return self._closed
@property
def upgraded(self) -> bool:
return self._upgraded
@property
def should_close(self) -> bool:
return bool(
self._should_close
or (self._payload is not None and not self._payload.is_eof())
or self._upgraded
or self._exception is not None
or self._payload_parser is not None
or self._buffer
or self._tail
)
def force_close(self) -> None:
self._should_close = True
def close(self) -> None:
self._exception = None # Break cyclic references
transport = self.transport
if transport is not None:
transport.close()
self.transport = None
self._payload = None
self._drop_timeout()
def abort(self) -> None:
self._exception = None # Break cyclic references
transport = self.transport
if transport is not None:
transport.abort()
self.transport = None
self._payload = None
self._drop_timeout()
def is_connected(self) -> bool:
return self.transport is not None and not self.transport.is_closing()
def connection_lost(self, exc: BaseException | None) -> None:
self._connection_lost_called = True
self._drop_timeout()
original_connection_error = exc
reraised_exc = original_connection_error
connection_closed_cleanly = original_connection_error is None
if self._closed is not None:
# If someone is waiting for the closed future,
# we should set it to None or an exception. If
# self._closed is None, it means that
# connection_lost() was called already
# or nobody is waiting for it.
if connection_closed_cleanly:
set_result(self._closed, None)
else:
assert original_connection_error is not None
set_exception(
self._closed,
ClientConnectionError(
f"Connection lost: {original_connection_error !s}",
),
original_connection_error,
)
if self._payload_parser is not None:
with suppress(Exception): # FIXME: log this somehow?
self._payload_parser.feed_eof()
uncompleted = None
if self._parser is not None:
try:
uncompleted = self._parser.feed_eof()
except Exception as underlying_exc:
if self._payload is not None:
client_payload_exc_msg = (
f"Response payload is not completed: {underlying_exc !r}"
)
if not connection_closed_cleanly:
client_payload_exc_msg = (
f"{client_payload_exc_msg !s}. "
f"{original_connection_error !r}"
)
set_exception(
self._payload,
ClientPayloadError(client_payload_exc_msg),
underlying_exc,
)
if not self.is_eof():
if isinstance(original_connection_error, OSError):
reraised_exc = ClientOSError(*original_connection_error.args)
if connection_closed_cleanly:
reraised_exc = ServerDisconnectedError(uncompleted)
# assigns self._should_close to True as side effect,
# we do it anyway below
underlying_non_eof_exc = (
_EXC_SENTINEL
if connection_closed_cleanly
else original_connection_error
)
assert underlying_non_eof_exc is not None
assert reraised_exc is not None
self.set_exception(reraised_exc, underlying_non_eof_exc)
self._should_close = True
self._parser = None
self._payload = None
self._payload_parser = None
self._reading_paused = False
super().connection_lost(reraised_exc)
def eof_received(self) -> None:
# should call parser.feed_eof() most likely
self._drop_timeout()
def pause_reading(self) -> None:
super().pause_reading()
self._drop_timeout()
def resume_reading(self, resume_parser: bool = True) -> None:
super().resume_reading(resume_parser)
self._reschedule_timeout()
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
self._should_close = True
self._drop_timeout()
super().set_exception(exc, exc_cause)
def set_parser(
self,
parser: Any,
payload: Any,
data_received_cb: Callable[[], None] | None = None,
) -> None:
# TODO: actual types are:
# parser: WebSocketReader
# payload: WebSocketDataQueue
# but they are not generi enough
# Need an ABC for both types
self._payload = payload
self._payload_parser = parser
self._data_received_cb = data_received_cb
self._drop_timeout()
if self._tail:
data, self._tail = self._tail, b""
self.data_received(data)
def set_response_params(
self,
*,
timer: BaseTimerContext | None = None,
skip_payload: bool = False,
read_until_eof: bool = False,
auto_decompress: bool = True,
read_timeout: float | None = None,
read_bufsize: int = DEFAULT_CHUNK_SIZE,
timeout_ceil_threshold: float = 5,
max_line_size: int = 8190,
max_field_size: int = 8190,
max_headers: int = 128,
) -> None:
self._skip_payload = skip_payload
self._read_timeout = read_timeout
self._timeout_ceil_threshold = timeout_ceil_threshold
self._parser = HttpResponseParser(
self,
self._loop,
read_bufsize,
timer=timer,
payload_exception=ClientPayloadError,
response_with_body=not skip_payload,
read_until_eof=read_until_eof,
auto_decompress=auto_decompress,
max_line_size=max_line_size,
max_field_size=max_field_size,
max_headers=max_headers,
)
if self._tail:
data, self._tail = self._tail, b""
self.data_received(data)
def _drop_timeout(self) -> None:
if self._read_timeout_handle is not None:
self._read_timeout_handle.cancel()
self._read_timeout_handle = None
def _reschedule_timeout(self) -> None:
timeout = self._read_timeout
if self._read_timeout_handle is not None:
self._read_timeout_handle.cancel()
if timeout:
self._read_timeout_handle = self._loop.call_later(
timeout, self._on_read_timeout
)
else:
self._read_timeout_handle = None
def start_timeout(self) -> None:
self._reschedule_timeout()
@property
def read_timeout(self) -> float | None:
return self._read_timeout
@read_timeout.setter
def read_timeout(self, read_timeout: float | None) -> None:
self._read_timeout = read_timeout
def _on_read_timeout(self) -> None:
exc = SocketTimeoutError("Timeout on reading data from socket")
self.set_exception(exc)
if self._payload is not None:
set_exception(self._payload, exc)
def data_received(self, data: bytes) -> None:
# If no data, then we are resuming decompression. We haven't received
# data from the socket, so we can avoid the reschedule overhead.
if data:
self._reschedule_timeout()
# custom payload parser - currently always WebSocketReader
if self._payload_parser is not None:
if self._data_received_cb is not None:
self._data_received_cb()
eof, tail = self._payload_parser.feed_data(data)
if eof:
self._payload = None
self._payload_parser = None
if tail:
self.data_received(tail)
return
if self._upgraded or self._parser is None:
# i.e. websocket connection, websocket parser is not set yet
self._tail += data
return
# parse http messages
try:
messages, upgraded, tail = self._parser.feed_data(data)
except BaseException as underlying_exc:
if self.transport is not None:
# connection.release() could be called BEFORE
# data_received(), the transport is already
# closed in this case
self.transport.close()
if not isinstance(underlying_exc, Exception):
raise
# should_close is True after the call
if isinstance(underlying_exc, HttpProcessingError):
exc = HttpProcessingError(
code=underlying_exc.code,
message=underlying_exc.message,
headers=underlying_exc.headers,
)
else:
exc = HttpProcessingError()
self.set_exception(exc, underlying_exc)
return
self._upgraded = upgraded
payload: StreamReader | None = None
for message, payload in messages:
if message.should_close:
self._should_close = True
self._payload = payload
if self._skip_payload or message.code in EMPTY_BODY_STATUS_CODES:
self.feed_data((message, EMPTY_PAYLOAD), 0)
else:
self.feed_data((message, payload), 0)
if payload is not None:
# new message(s) was processed
# register timeout handler unsubscribing
# either on end-of-stream or immediately for
# EMPTY_PAYLOAD
if payload is not EMPTY_PAYLOAD:
payload.on_eof(self._drop_timeout)
else:
self._drop_timeout()
if upgraded and tail:
self.data_received(tail)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,560 @@
"""WebSocket client for asyncio."""
import asyncio
import sys
from collections.abc import Callable
from types import TracebackType
from typing import Any, Generic, Literal, Optional, cast, overload
import attr
from ._websocket.reader import WebSocketDataQueue
from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError
from .client_reqrep import ClientResponse
from .helpers import calculate_timeout_when, set_result
from .http import (
WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE,
WebSocketError,
WSCloseCode,
WSMessage,
WSMessageDecodeText,
WSMessageNoDecodeText,
WSMsgType,
)
from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter
from .streams import EofStream
from .typedefs import (
DEFAULT_JSON_DECODER,
DEFAULT_JSON_ENCODER,
JSONBytesEncoder,
JSONDecoder,
JSONEncoder,
)
if sys.version_info >= (3, 13):
from typing import TypeVar
else:
from typing_extensions import TypeVar
if sys.version_info >= (3, 11):
import asyncio as async_timeout
from typing import Self
else:
import async_timeout
from typing_extensions import Self
# TypeVar for whether text messages are decoded to str (True) or kept as bytes (False)
# Covariant because it only affects return types, not input types
_DecodeText = TypeVar("_DecodeText", bound=bool, covariant=True, default=Literal[True])
@attr.s(frozen=True, slots=True)
class ClientWSTimeout:
ws_receive = attr.ib(type=Optional[float], default=None)
ws_close = attr.ib(type=Optional[float], default=None)
DEFAULT_WS_CLIENT_TIMEOUT = ClientWSTimeout(ws_receive=None, ws_close=10.0)
class ClientWebSocketResponse(Generic[_DecodeText]):
def __init__(
self,
reader: WebSocketDataQueue,
writer: WebSocketWriter,
protocol: str | None,
response: ClientResponse,
timeout: ClientWSTimeout,
autoclose: bool,
autoping: bool,
loop: asyncio.AbstractEventLoop,
*,
heartbeat: float | None = None,
compress: int = 0,
client_notakeover: bool = False,
) -> None:
self._response = response
self._conn = response.connection
self._writer = writer
self._reader = reader
self._protocol = protocol
self._closed = False
self._closing = False
self._close_code: int | None = None
self._timeout = timeout
self._autoclose = autoclose
self._autoping = autoping
self._heartbeat = heartbeat
self._heartbeat_cb: asyncio.TimerHandle | None = None
self._heartbeat_when: float = 0.0
if heartbeat is not None:
self._pong_heartbeat = heartbeat / 2.0
self._pong_response_cb: asyncio.TimerHandle | None = None
self._loop = loop
self._waiting: bool = False
self._close_wait: asyncio.Future[None] | None = None
self._exception: BaseException | None = None
self._compress = compress
self._client_notakeover = client_notakeover
self._ping_task: asyncio.Task[None] | None = None
self._need_heartbeat_reset = False
self._heartbeat_reset_handle: asyncio.Handle | None = None
self._reset_heartbeat()
def _cancel_heartbeat(self) -> None:
self._cancel_pong_response_cb()
if self._heartbeat_reset_handle is not None:
self._heartbeat_reset_handle.cancel()
self._heartbeat_reset_handle = None
self._need_heartbeat_reset = False
if self._heartbeat_cb is not None:
self._heartbeat_cb.cancel()
self._heartbeat_cb = None
if self._ping_task is not None:
self._ping_task.cancel()
self._ping_task = None
def _cancel_pong_response_cb(self) -> None:
if self._pong_response_cb is not None:
self._pong_response_cb.cancel()
self._pong_response_cb = None
def _on_data_received(self) -> None:
if self._heartbeat is None or self._need_heartbeat_reset:
return
loop = self._loop
assert loop is not None
# Coalesce multiple chunks received in the same loop tick into a single
# heartbeat reset. Resetting immediately per chunk increases timer churn.
self._need_heartbeat_reset = True
self._heartbeat_reset_handle = loop.call_soon(self._flush_heartbeat_reset)
def _flush_heartbeat_reset(self) -> None:
self._heartbeat_reset_handle = None
if not self._need_heartbeat_reset:
return
self._reset_heartbeat()
self._need_heartbeat_reset = False
def _reset_heartbeat(self) -> None:
if self._heartbeat is None:
return
self._cancel_pong_response_cb()
loop = self._loop
assert loop is not None
conn = self._conn
timeout_ceil_threshold = (
conn._connector._timeout_ceil_threshold if conn is not None else 5
)
now = loop.time()
when = calculate_timeout_when(now, self._heartbeat, timeout_ceil_threshold)
self._heartbeat_when = when
if self._heartbeat_cb is None:
# We do not cancel the previous heartbeat_cb here because
# it generates a significant amount of TimerHandle churn
# which causes asyncio to rebuild the heap frequently.
# Instead _send_heartbeat() will reschedule the next
# heartbeat if it fires too early.
self._heartbeat_cb = loop.call_at(when, self._send_heartbeat)
def _send_heartbeat(self) -> None:
self._heartbeat_cb = None
# If heartbeat reset is pending (data is being received), skip sending
# the ping and let the reset callback handle rescheduling the heartbeat.
if self._need_heartbeat_reset:
return
loop = self._loop
now = loop.time()
if now < self._heartbeat_when:
# Heartbeat fired too early, reschedule
self._heartbeat_cb = loop.call_at(
self._heartbeat_when, self._send_heartbeat
)
return
conn = self._conn
timeout_ceil_threshold = (
conn._connector._timeout_ceil_threshold if conn is not None else 5
)
when = calculate_timeout_when(now, self._pong_heartbeat, timeout_ceil_threshold)
self._cancel_pong_response_cb()
self._pong_response_cb = loop.call_at(when, self._pong_not_received)
coro = self._writer.send_frame(b"", WSMsgType.PING)
if sys.version_info >= (3, 12):
# Optimization for Python 3.12, try to send the ping
# immediately to avoid having to schedule
# the task on the event loop.
ping_task = asyncio.Task(coro, loop=loop, eager_start=True)
else:
ping_task = loop.create_task(coro)
if not ping_task.done():
self._ping_task = ping_task
ping_task.add_done_callback(self._ping_task_done)
else:
self._ping_task_done(ping_task)
def _ping_task_done(self, task: "asyncio.Task[None]") -> None:
"""Callback for when the ping task completes."""
if not task.cancelled() and (exc := task.exception()):
self._handle_ping_pong_exception(exc)
self._ping_task = None
def _pong_not_received(self) -> None:
self._handle_ping_pong_exception(
ServerTimeoutError(f"No PONG received after {self._pong_heartbeat} seconds")
)
def _handle_ping_pong_exception(self, exc: BaseException) -> None:
"""Handle exceptions raised during ping/pong processing."""
if self._closed:
return
self._set_closed()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
self._response.close()
if self._waiting and not self._closing:
self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None), 0)
def _set_closed(self) -> None:
"""Set the connection to closed.
Cancel any heartbeat timers and set the closed flag.
"""
self._closed = True
self._cancel_heartbeat()
def _set_closing(self) -> None:
"""Set the connection to closing.
Cancel any heartbeat timers and set the closing flag.
"""
self._closing = True
self._cancel_heartbeat()
@property
def closed(self) -> bool:
return self._closed
@property
def close_code(self) -> int | None:
return self._close_code
@property
def protocol(self) -> str | None:
return self._protocol
@property
def compress(self) -> int:
return self._compress
@property
def client_notakeover(self) -> bool:
return self._client_notakeover
def get_extra_info(self, name: str, default: Any = None) -> Any:
"""extra info from connection transport"""
conn = self._response.connection
if conn is None:
return default
transport = conn.transport
if transport is None:
return default
return transport.get_extra_info(name, default)
def exception(self) -> BaseException | None:
return self._exception
async def ping(self, message: bytes = b"") -> None:
await self._writer.send_frame(message, WSMsgType.PING)
async def pong(self, message: bytes = b"") -> None:
await self._writer.send_frame(message, WSMsgType.PONG)
async def send_frame(
self, message: bytes, opcode: WSMsgType, compress: int | None = None
) -> None:
"""Send a frame over the websocket."""
await self._writer.send_frame(message, opcode, compress)
async def send_str(self, data: str, compress: int | None = None) -> None:
if not isinstance(data, str):
raise TypeError("data argument must be str (%r)" % type(data))
await self._writer.send_frame(
data.encode("utf-8"), WSMsgType.TEXT, compress=compress
)
async def send_bytes(self, data: bytes, compress: int | None = None) -> None:
if not isinstance(data, (bytes, bytearray, memoryview)):
raise TypeError("data argument must be byte-ish (%r)" % type(data))
await self._writer.send_frame(data, WSMsgType.BINARY, compress=compress)
async def send_json(
self,
data: Any,
compress: int | None = None,
*,
dumps: JSONEncoder = DEFAULT_JSON_ENCODER,
) -> None:
await self.send_str(dumps(data), compress=compress)
async def send_json_bytes(
self,
data: Any,
compress: int | None = None,
*,
dumps: JSONBytesEncoder,
) -> None:
"""Send JSON data using a bytes-returning encoder as a binary frame.
Use this when your JSON encoder (like orjson) returns bytes
instead of str, avoiding the encode/decode overhead.
"""
await self.send_bytes(dumps(data), compress=compress)
async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool:
# we need to break `receive()` cycle first,
# `close()` may be called from different task
if self._waiting and not self._closing:
assert self._loop is not None
self._close_wait = self._loop.create_future()
self._set_closing()
self._reader.feed_data(WS_CLOSING_MESSAGE, 0)
await self._close_wait
if self._closed:
return False
self._set_closed()
try:
await self._writer.close(code, message)
except asyncio.CancelledError:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._response.close()
raise
except Exception as exc:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
self._response.close()
return True
if self._close_code:
self._response.close()
return True
while True:
try:
async with async_timeout.timeout(self._timeout.ws_close):
msg = await self._reader.read()
except asyncio.CancelledError:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._response.close()
raise
except Exception as exc:
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
self._exception = exc
self._response.close()
return True
if msg.type is WSMsgType.CLOSE:
self._close_code = msg.data
self._response.close()
return True
@overload
async def receive(
self: "ClientWebSocketResponse[Literal[True]]", timeout: float | None = None
) -> WSMessageDecodeText: ...
@overload
async def receive(
self: "ClientWebSocketResponse[Literal[False]]", timeout: float | None = None
) -> WSMessageNoDecodeText: ...
@overload
async def receive(
self: "ClientWebSocketResponse[_DecodeText]", timeout: float | None = None
) -> WSMessageDecodeText | WSMessageNoDecodeText: ...
async def receive(
self, timeout: float | None = None
) -> WSMessageDecodeText | WSMessageNoDecodeText:
receive_timeout = timeout or self._timeout.ws_receive
while True:
if self._waiting:
raise RuntimeError("Concurrent call to receive() is not allowed")
if self._closed:
return WS_CLOSED_MESSAGE
elif self._closing:
await self.close()
return WS_CLOSED_MESSAGE
try:
self._waiting = True
try:
if receive_timeout:
# Entering the context manager and creating
# Timeout() object can take almost 50% of the
# run time in this loop so we avoid it if
# there is no read timeout.
async with async_timeout.timeout(receive_timeout):
msg = await self._reader.read()
else:
msg = await self._reader.read()
finally:
self._waiting = False
if self._close_wait:
set_result(self._close_wait, None)
except (asyncio.CancelledError, asyncio.TimeoutError):
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
raise
except EofStream:
self._close_code = WSCloseCode.OK
await self.close()
return WSMessage(WSMsgType.CLOSED, None, None)
except ClientError:
# Likely ServerDisconnectedError when connection is lost
self._set_closed()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
return WS_CLOSED_MESSAGE
except WebSocketError as exc:
self._close_code = exc.code
await self.close(code=exc.code)
return WSMessage(WSMsgType.ERROR, exc, None)
except Exception as exc:
self._exception = exc
self._set_closing()
self._close_code = WSCloseCode.ABNORMAL_CLOSURE
await self.close()
return WSMessage(WSMsgType.ERROR, exc, None)
if msg.type not in _INTERNAL_RECEIVE_TYPES:
# If its not a close/closing/ping/pong message
# we can return it immediately
return msg
if msg.type is WSMsgType.CLOSE:
self._set_closing()
self._close_code = msg.data
if not self._closed and self._autoclose:
await self.close()
elif msg.type is WSMsgType.CLOSING:
self._set_closing()
elif msg.type is WSMsgType.PING and self._autoping:
await self.pong(msg.data)
continue
elif msg.type is WSMsgType.PONG and self._autoping:
continue
return msg
@overload
async def receive_str(
self: "ClientWebSocketResponse[Literal[True]]", *, timeout: float | None = None
) -> str: ...
@overload
async def receive_str(
self: "ClientWebSocketResponse[Literal[False]]", *, timeout: float | None = None
) -> bytes: ...
@overload
async def receive_str(
self: "ClientWebSocketResponse[_DecodeText]", *, timeout: float | None = None
) -> str | bytes: ...
async def receive_str(self, *, timeout: float | None = None) -> str | bytes:
"""Receive TEXT message.
Returns str when decode_text=True (default), bytes when decode_text=False.
"""
msg = await self.receive(timeout)
if msg.type is not WSMsgType.TEXT:
raise WSMessageTypeError(
f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT"
)
return cast(str, msg.data)
async def receive_bytes(self, *, timeout: float | None = None) -> bytes:
msg = await self.receive(timeout)
if msg.type is not WSMsgType.BINARY:
raise WSMessageTypeError(
f"Received message {msg.type}:{msg.data!r} is not WSMsgType.BINARY"
)
return cast(bytes, msg.data)
@overload
async def receive_json(
self: "ClientWebSocketResponse[Literal[True]]",
*,
loads: JSONDecoder = ...,
timeout: float | None = None,
) -> Any: ...
@overload
async def receive_json(
self: "ClientWebSocketResponse[Literal[False]]",
*,
loads: Callable[[bytes], Any] = ...,
timeout: float | None = None,
) -> Any: ...
@overload
async def receive_json(
self: "ClientWebSocketResponse[_DecodeText]",
*,
loads: JSONDecoder | Callable[[bytes], Any] = ...,
timeout: float | None = None,
) -> Any: ...
async def receive_json(
self,
*,
loads: JSONDecoder | Callable[[bytes], Any] = DEFAULT_JSON_DECODER,
timeout: float | None = None,
) -> Any:
data = await self.receive_str(timeout=timeout)
return loads(data) # type: ignore[arg-type]
def __aiter__(self) -> Self:
return self
@overload
async def __anext__(
self: "ClientWebSocketResponse[Literal[True]]",
) -> WSMessageDecodeText: ...
@overload
async def __anext__(
self: "ClientWebSocketResponse[Literal[False]]",
) -> WSMessageNoDecodeText: ...
@overload
async def __anext__(
self: "ClientWebSocketResponse[_DecodeText]",
) -> WSMessageDecodeText | WSMessageNoDecodeText: ...
async def __anext__(self) -> WSMessageDecodeText | WSMessageNoDecodeText:
msg = await self.receive()
if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED):
raise StopAsyncIteration
return msg
async def __aenter__(self) -> Self:
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.close()

View File

@ -0,0 +1,447 @@
import asyncio
import sys
import zlib
from abc import ABC, abstractmethod
from concurrent.futures import Executor
from typing import Any, Final, Protocol, TypedDict, cast
if sys.version_info >= (3, 12):
from collections.abc import Buffer
else:
from typing import Union
Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
try:
try:
import brotlicffi as brotli
except ImportError:
import brotli
HAS_BROTLI = True
except ImportError: # pragma: no cover
HAS_BROTLI = False
try:
if sys.version_info >= (3, 14):
from compression.zstd import ZstdDecompressor # noqa: I900
else: # TODO(PY314): Remove mentions of backports.zstd across codebase
from backports.zstd import ZstdDecompressor
HAS_ZSTD = True
except ImportError:
HAS_ZSTD = False
MAX_SYNC_CHUNK_SIZE = 4096
# Unlimited decompression constants - different libraries use different conventions
ZLIB_MAX_LENGTH_UNLIMITED = 0 # zlib uses 0 to mean unlimited
ZSTD_MAX_LENGTH_UNLIMITED = -1 # zstd uses -1 to mean unlimited
class ZLibCompressObjProtocol(Protocol):
def compress(self, data: Buffer) -> bytes: ...
def flush(self, mode: int = ..., /) -> bytes: ...
class ZLibDecompressObjProtocol(Protocol):
def decompress(self, data: Buffer, max_length: int = ...) -> bytes: ...
def flush(self, length: int = ..., /) -> bytes: ...
@property
def eof(self) -> bool: ...
@property
def unconsumed_tail(self) -> bytes: ...
@property
def unused_data(self) -> bytes: ...
class ZLibBackendProtocol(Protocol):
MAX_WBITS: int
Z_FULL_FLUSH: int
Z_SYNC_FLUSH: int
Z_BEST_SPEED: int
Z_FINISH: int
def compressobj(
self,
level: int = ...,
method: int = ...,
wbits: int = ...,
memLevel: int = ...,
strategy: int = ...,
zdict: Buffer | None = ...,
) -> ZLibCompressObjProtocol: ...
def decompressobj(
self, wbits: int = ..., zdict: Buffer = ...
) -> ZLibDecompressObjProtocol: ...
def compress(
self, data: Buffer, /, level: int = ..., wbits: int = ...
) -> bytes: ...
def decompress(
self, data: Buffer, /, wbits: int = ..., bufsize: int = ...
) -> bytes: ...
class CompressObjArgs(TypedDict, total=False):
wbits: int
strategy: int
level: int
class ZLibBackendWrapper:
def __init__(self, _zlib_backend: ZLibBackendProtocol):
self._zlib_backend: ZLibBackendProtocol = _zlib_backend
@property
def name(self) -> str:
return getattr(self._zlib_backend, "__name__", "undefined")
@property
def MAX_WBITS(self) -> int:
return self._zlib_backend.MAX_WBITS
@property
def Z_FULL_FLUSH(self) -> int:
return self._zlib_backend.Z_FULL_FLUSH
@property
def Z_SYNC_FLUSH(self) -> int:
return self._zlib_backend.Z_SYNC_FLUSH
@property
def Z_BEST_SPEED(self) -> int:
return self._zlib_backend.Z_BEST_SPEED
@property
def Z_FINISH(self) -> int:
return self._zlib_backend.Z_FINISH
def compressobj(self, *args: Any, **kwargs: Any) -> ZLibCompressObjProtocol:
return self._zlib_backend.compressobj(*args, **kwargs)
def decompressobj(self, *args: Any, **kwargs: Any) -> ZLibDecompressObjProtocol:
return self._zlib_backend.decompressobj(*args, **kwargs)
def compress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
return self._zlib_backend.compress(data, *args, **kwargs)
def decompress(self, data: Buffer, *args: Any, **kwargs: Any) -> bytes:
return self._zlib_backend.decompress(data, *args, **kwargs)
# Everything not explicitly listed in the Protocol we just pass through
def __getattr__(self, attrname: str) -> Any:
return getattr(self._zlib_backend, attrname)
ZLibBackend: ZLibBackendWrapper = ZLibBackendWrapper(zlib)
def set_zlib_backend(new_zlib_backend: ZLibBackendProtocol) -> None:
ZLibBackend._zlib_backend = new_zlib_backend
def encoding_to_mode(
encoding: str | None = None,
suppress_deflate_header: bool = False,
) -> int:
if encoding == "gzip":
return 16 + ZLibBackend.MAX_WBITS
return -ZLibBackend.MAX_WBITS if suppress_deflate_header else ZLibBackend.MAX_WBITS
class DecompressionBaseHandler(ABC):
def __init__(
self,
executor: Executor | None = None,
max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
):
"""Base class for decompression handlers."""
self._executor = executor
self._max_sync_chunk_size = max_sync_chunk_size
@abstractmethod
def decompress_sync(
self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
) -> bytes:
"""Decompress the given data."""
async def decompress(
self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
) -> bytes:
"""Decompress the given data."""
if (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
):
return await asyncio.get_event_loop().run_in_executor(
self._executor, self.decompress_sync, data, max_length
)
return self.decompress_sync(data, max_length)
@property
@abstractmethod
def data_available(self) -> bool:
"""Return True if more output is available by passing b""."""
class ZLibCompressor:
def __init__(
self,
encoding: str | None = None,
suppress_deflate_header: bool = False,
level: int | None = None,
wbits: int | None = None,
strategy: int | None = None,
executor: Executor | None = None,
max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
):
self._executor = executor
self._max_sync_chunk_size = max_sync_chunk_size
self._mode = (
encoding_to_mode(encoding, suppress_deflate_header)
if wbits is None
else wbits
)
self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
kwargs: CompressObjArgs = {}
kwargs["wbits"] = self._mode
if strategy is not None:
kwargs["strategy"] = strategy
if level is not None:
kwargs["level"] = level
self._compressor = self._zlib_backend.compressobj(**kwargs)
def compress_sync(self, data: Buffer) -> bytes:
return self._compressor.compress(data)
async def compress(self, data: Buffer) -> bytes:
"""Compress the data and returned the compressed bytes.
Note that flush() must be called after the last call to compress()
If the data size is large than the max_sync_chunk_size, the compression
will be done in the executor. Otherwise, the compression will be done
in the event loop.
**WARNING: This method is NOT cancellation-safe when used with flush().**
If this operation is cancelled, the compressor state may be corrupted.
The connection MUST be closed after cancellation to avoid data corruption
in subsequent compress operations.
For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
compress() + flush() + send operations in a shield and lock to ensure atomicity.
"""
# For large payloads, offload compression to executor to avoid blocking event loop
should_use_executor = (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
)
if should_use_executor:
return await asyncio.get_running_loop().run_in_executor(
self._executor, self._compressor.compress, data
)
return self.compress_sync(data)
def flush(self, mode: int | None = None) -> bytes:
"""Flush the compressor synchronously.
**WARNING: This method is NOT cancellation-safe when called after compress().**
The flush() operation accesses shared compressor state. If compress() was
cancelled, calling flush() may result in corrupted data. The connection MUST
be closed after compress() cancellation.
For cancellation-safe compression (e.g., WebSocket), the caller MUST wrap
compress() + flush() + send operations in a shield and lock to ensure atomicity.
"""
return self._compressor.flush(
mode if mode is not None else self._zlib_backend.Z_FINISH
)
class ZLibDecompressor(DecompressionBaseHandler):
def __init__(
self,
encoding: str | None = None,
suppress_deflate_header: bool = False,
executor: Executor | None = None,
max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
):
super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
self._mode = encoding_to_mode(encoding, suppress_deflate_header)
self._zlib_backend: Final = ZLibBackendWrapper(ZLibBackend._zlib_backend)
self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)
self._last_empty = False
self._pending_unused_data: bytes | None = None
def decompress_sync(
self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
) -> bytes:
if self._pending_unused_data is not None:
data = self._pending_unused_data + bytes(data)
self._pending_unused_data = None
result = self._decompressor.decompress(
self._decompressor.unconsumed_tail + data, max_length
)
# Only way to know that isal has no further data is checking we get no output
self._last_empty = result == b""
# Handle concatenated gzip/deflate streams (multi-member).
# After a member ends, unused_data holds the start of the next member.
# Create a fresh decompressor for each subsequent member.
while self._decompressor.eof and self._decompressor.unused_data:
unused = self._decompressor.unused_data
self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)
if max_length != ZLIB_MAX_LENGTH_UNLIMITED:
max_length -= len(result)
if max_length <= 0:
self._pending_unused_data = unused
break
chunk = self._decompressor.decompress(unused, max_length)
self._last_empty = chunk == b""
result += chunk
# Member ended exactly at chunk boundary — no unused_data, but the
# next feed_data() call would fail on the spent decompressor.
# Only reset for gzip; deflate's feed_eof() relies on eof=True to
# confirm the stream is complete.
if self._decompressor.eof and self._mode > self._zlib_backend.MAX_WBITS:
self._decompressor = self._zlib_backend.decompressobj(wbits=self._mode)
return result
def flush(self, length: int = 0) -> bytes:
return (
self._decompressor.flush(length)
if length > 0
else self._decompressor.flush()
)
@property
def data_available(self) -> bool:
return (
bool(self._decompressor.unconsumed_tail)
or not self._last_empty
or self._pending_unused_data is not None
)
@property
def eof(self) -> bool:
return self._decompressor.eof
class BrotliDecompressor(DecompressionBaseHandler):
# Supports both 'brotlipy' and 'Brotli' packages
# since they share an import name. The top branches
# are for 'brotlipy' and bottom branches for 'Brotli'
def __init__(
self,
executor: Executor | None = None,
max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
) -> None:
"""Decompress data using the Brotli library."""
if not HAS_BROTLI:
raise RuntimeError(
"The brotli decompression is not available. "
"Please install `Brotli` module"
)
self._obj = brotli.Decompressor()
self._last_empty = False
super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
def decompress_sync(
self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
) -> bytes:
"""Decompress the given data."""
if hasattr(self._obj, "decompress"):
if max_length == ZLIB_MAX_LENGTH_UNLIMITED:
result = cast(bytes, self._obj.decompress(data))
else:
result = cast(bytes, self._obj.decompress(data, max_length))
else:
if max_length == ZLIB_MAX_LENGTH_UNLIMITED:
result = cast(bytes, self._obj.process(data))
else:
result = cast(bytes, self._obj.process(data, max_length))
# Only way to know that brotli has no further data is checking we get no output
self._last_empty = result == b""
return result
def flush(self) -> bytes:
"""Flush the decompressor."""
if hasattr(self._obj, "flush"):
return cast(bytes, self._obj.flush())
return b""
@property
def data_available(self) -> bool:
return not self._obj.is_finished() and not self._last_empty
class ZSTDDecompressor(DecompressionBaseHandler):
def __init__(
self,
executor: Executor | None = None,
max_sync_chunk_size: int | None = MAX_SYNC_CHUNK_SIZE,
) -> None:
if not HAS_ZSTD:
raise RuntimeError(
"The zstd decompression is not available. "
"Please install `backports.zstd` module"
)
self._obj = ZstdDecompressor()
self._pending_unused_data: bytes | None = None
super().__init__(executor=executor, max_sync_chunk_size=max_sync_chunk_size)
def decompress_sync(
self, data: Buffer, max_length: int = ZLIB_MAX_LENGTH_UNLIMITED
) -> bytes:
# zstd uses -1 for unlimited, while zlib uses 0 for unlimited
# Convert the zlib convention (0=unlimited) to zstd convention (-1=unlimited)
zstd_max_length = (
ZSTD_MAX_LENGTH_UNLIMITED
if max_length == ZLIB_MAX_LENGTH_UNLIMITED
else max_length
)
if self._pending_unused_data is not None:
data = self._pending_unused_data + data
self._pending_unused_data = None
result = self._obj.decompress(data, zstd_max_length)
# Handle multi-frame zstd streams.
# https://datatracker.ietf.org/doc/html/rfc8878#section-3.1.1
# ZstdDecompressor handles one frame only. When a frame ends,
# eof becomes True and any trailing data goes to unused_data.
# We create a fresh decompressor to continue with the next frame.
while self._obj.eof and self._obj.unused_data:
unused_data = self._obj.unused_data
self._obj = ZstdDecompressor()
if zstd_max_length != ZSTD_MAX_LENGTH_UNLIMITED:
zstd_max_length -= len(result)
if zstd_max_length <= 0:
self._pending_unused_data = unused_data
break
result += self._obj.decompress(unused_data, zstd_max_length)
# Frame ended exactly at chunk boundary — no unused_data, but the
# next feed_data() call would fail on the spent decompressor.
# Prepare a fresh one for the next chunk.
if self._obj.eof:
self._obj = ZstdDecompressor()
return result
def flush(self) -> bytes:
return b""
@property
def data_available(self) -> bool:
return (
not self._obj.needs_input and not self._obj.eof
) or self._pending_unused_data is not None

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,678 @@
import asyncio
import calendar
import contextlib
import datetime
import heapq
import itertools
import json
import os
import pathlib
import pickle
import re
import time
import warnings
from collections import defaultdict
from collections.abc import Iterable, Iterator, Mapping
from http.cookies import BaseCookie, Morsel, SimpleCookie
from types import MappingProxyType
from typing import Union
from yarl import URL
from ._cookie_helpers import preserve_morsel_with_coded_value
from .abc import AbstractCookieJar, ClearCookiePredicate
from .helpers import is_ip_address
from .typedefs import LooseCookies, PathLike, StrOrURL
__all__ = ("CookieJar", "DummyCookieJar")
CookieItem = Union[str, "Morsel[str]"]
# We cache these string methods here as their use is in performance critical code.
_FORMAT_PATH = "{}/{}".format
_FORMAT_DOMAIN_REVERSED = "{1}.{0}".format
# The minimum number of scheduled cookie expirations before we start cleaning up
# the expiration heap. This is a performance optimization to avoid cleaning up the
# heap too often when there are only a few scheduled expirations.
_MIN_SCHEDULED_COOKIE_EXPIRATION = 100
_SIMPLE_COOKIE = SimpleCookie()
# Not persisted; the absolute deadline is saved instead.
_RELATIVE_EXPIRY_ATTRS = frozenset(("max-age", "expires"))
class _RestrictedCookieUnpickler(pickle._Unpickler):
"""A restricted unpickler that only allows cookie-related types.
This prevents arbitrary code execution when loading pickled cookie data
from untrusted sources. Only types that are expected in a serialized
CookieJar are permitted.
Subclasses :class:`pickle._Unpickler` (the pure-Python implementation)
rather than :class:`pickle.Unpickler` because the accelerated unpickler
on some implementations (notably PyPy) does not dispatch through
:meth:`find_class` overrides.
See: https://docs.python.org/3/library/pickle.html#restricting-globals
"""
_ALLOWED_CLASSES: frozenset[tuple[str, str]] = frozenset(
{
# Core cookie types
("http.cookies", "SimpleCookie"),
("http.cookies", "Morsel"),
# Container types used by CookieJar._cookies
("collections", "defaultdict"),
# builtins that pickle uses for reconstruction
("builtins", "tuple"),
("builtins", "set"),
("builtins", "frozenset"),
("builtins", "dict"),
}
)
def find_class(self, module: str, name: str) -> type:
if (module, name) not in self._ALLOWED_CLASSES:
raise pickle.UnpicklingError(
f"Forbidden class: {module}.{name}. "
"CookieJar.load() only allows cookie-related types for security. "
"See https://docs.python.org/3/library/pickle.html#restricting-globals"
)
return super().find_class(module, name) # type: ignore[no-any-return]
class CookieJar(AbstractCookieJar):
"""Implements cookie storage adhering to RFC 6265."""
DATE_TOKENS_RE = re.compile(
r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*"
r"(?P<token>[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)"
)
DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})")
DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})")
DATE_MONTH_RE = re.compile(
"(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|(aug)|(sep)|(oct)|(nov)|(dec)",
re.I,
)
DATE_YEAR_RE = re.compile(r"(\d{2,4})")
# calendar.timegm() fails for timestamps after datetime.datetime.max
# Minus one as a loss of precision occurs when timestamp() is called.
MAX_TIME = (
int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1
)
try:
calendar.timegm(time.gmtime(MAX_TIME))
except OSError:
# Hit the maximum representable time on Windows
# https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64
MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1))
except OverflowError:
# #4515: datetime.max may not be representable on 32-bit platforms
MAX_TIME = 2**31 - 1
# Avoid minuses in the future, 3x faster
SUB_MAX_TIME = MAX_TIME - 1
def __init__(
self,
*,
unsafe: bool = False,
quote_cookie: bool = True,
treat_as_secure_origin: StrOrURL | list[StrOrURL] | None = None,
loop: asyncio.AbstractEventLoop | None = None,
) -> None:
super().__init__(loop=loop)
self._cookies: defaultdict[tuple[str, str], SimpleCookie] = defaultdict(
SimpleCookie
)
self._morsel_cache: defaultdict[tuple[str, str], dict[str, Morsel[str]]] = (
defaultdict(dict)
)
self._host_only_cookies: set[tuple[str, str]] = set()
self._unsafe = unsafe
self._quote_cookie = quote_cookie
if treat_as_secure_origin is None:
treat_as_secure_origin = []
elif isinstance(treat_as_secure_origin, URL):
treat_as_secure_origin = [treat_as_secure_origin.origin()]
elif isinstance(treat_as_secure_origin, str):
treat_as_secure_origin = [URL(treat_as_secure_origin).origin()]
else:
treat_as_secure_origin = [
URL(url).origin() if isinstance(url, str) else url.origin()
for url in treat_as_secure_origin
]
self._treat_as_secure_origin = treat_as_secure_origin
self._expire_heap: list[tuple[float, tuple[str, str, str]]] = []
self._expirations: dict[tuple[str, str, str], float] = {}
@property
def unsafe(self) -> bool:
return self._unsafe
@property
def quote_cookie(self) -> bool:
return self._quote_cookie
@property
def cookies(self) -> MappingProxyType[tuple[str, str], SimpleCookie]:
"""Return the cookies stored in this jar."""
return MappingProxyType(self._cookies)
@property
def host_only_cookies(self) -> frozenset[tuple[str, str]]:
"""Return the host-only cookies stored in this jar."""
return frozenset(self._host_only_cookies)
def save(self, file_path: PathLike) -> None:
"""Save cookies to a file using JSON format.
:param file_path: Path to file where cookies will be serialized,
:class:`str` or :class:`pathlib.Path` instance.
"""
file_path = pathlib.Path(file_path)
data: dict[str, dict[str, dict[str, str | bool | float]]] = {}
for (domain, path), cookie in self._cookies.items():
key = f"{domain}|{path}"
data[key] = {}
for name, morsel in cookie.items():
morsel_data: dict[str, str | bool | float] = {
"key": morsel.key,
"value": morsel.value,
"coded_value": morsel.coded_value,
}
# Skip relative expiry; the absolute deadline is saved below.
for attr in morsel._reserved: # type: ignore[attr-defined]
if attr in _RELATIVE_EXPIRY_ATTRS:
continue
attr_val = morsel[attr]
if attr_val:
morsel_data[attr] = attr_val
# Persist or it reloads as a domain cookie and leaks to subdomains.
if (domain, name) in self._host_only_cookies:
morsel_data["host_only"] = True
if (exp := self._expirations.get((domain, path, name))) is not None:
morsel_data["expires_timestamp"] = exp
data[key][name] = morsel_data
# Cookie persistence may include authentication/session tokens.
# Use 0o600 at creation time to avoid umask-dependent overexposure
# and enforce least-privilege access to sensitive credential data.
with open(
file_path,
mode="w",
encoding="utf-8",
opener=lambda path, flags: os.open(path, flags, 0o600),
) as f:
json.dump(data, f, indent=2)
def load(self, file_path: PathLike) -> None:
"""Load cookies from a file.
Tries to load JSON format first. Falls back to loading legacy
pickle format (using a restricted unpickler) for backward
compatibility with existing cookie files.
Replaces the current jar contents; loaded cookies pass through the
same acceptance rules as :meth:`update_cookies`.
:param file_path: Path to file from where cookies will be
imported, :class:`str` or :class:`pathlib.Path` instance.
"""
file_path = pathlib.Path(file_path)
# Try JSON format first
try:
with file_path.open(mode="r", encoding="utf-8") as f:
data = json.load(f)
self._load_json_data(data)
except (json.JSONDecodeError, UnicodeDecodeError, ValueError):
# Fall back to legacy pickle format with restricted unpickler
with file_path.open(mode="rb") as f:
self._cookies = _RestrictedCookieUnpickler(f).load()
def _load_json_data(
self, data: dict[str, dict[str, dict[str, str | bool | float]]]
) -> None:
"""Replace contents, routing cookies through update_cookies()."""
self.clear()
for compound_key, cookie_data in data.items():
domain, path = compound_key.split("|", 1)
for name, morsel_data in cookie_data.items():
morsel: Morsel[str] = Morsel()
# Use __setstate__ to bypass validation, same pattern
# used in _build_morsel and _cookie_helpers.
morsel.__setstate__( # type: ignore[attr-defined]
{
"key": morsel_data["key"],
"value": morsel_data["value"],
"coded_value": morsel_data["coded_value"],
}
)
# Restore morsel attributes
for attr in morsel._reserved: # type: ignore[attr-defined]
if attr in morsel_data and attr not in (
"key",
"value",
"coded_value",
):
morsel[attr] = morsel_data[attr]
# Drop the domain so update_cookies() re-marks it host-only.
if morsel_data.get("host_only"):
morsel["domain"] = ""
response_url = (
URL.build(scheme="https", host=domain) if domain else URL()
)
self.update_cookies({name: morsel}, response_url)
# Restore the absolute deadline; update_cookies() schedules none.
if (exp := morsel_data.get("expires_timestamp")) is not None:
self._expire_cookie(float(exp), domain, path, name)
self._do_expiration()
def clear(self, predicate: ClearCookiePredicate | None = None) -> None:
if predicate is None:
self._expire_heap.clear()
self._cookies.clear()
self._morsel_cache.clear()
self._host_only_cookies.clear()
self._expirations.clear()
return
now = time.time()
to_del = [
key
for (domain, path), cookie in self._cookies.items()
for name, morsel in cookie.items()
if (
(key := (domain, path, name)) in self._expirations
and self._expirations[key] <= now
)
or predicate(morsel)
]
if to_del:
self._delete_cookies(to_del)
def clear_domain(self, domain: str) -> None:
self.clear(lambda x: self._is_domain_match(domain, x["domain"]))
def __iter__(self) -> "Iterator[Morsel[str]]":
self._do_expiration()
for val in self._cookies.values():
yield from val.values()
def __len__(self) -> int:
"""Return number of cookies.
This function does not iterate self to avoid unnecessary expiration
checks.
"""
return sum(len(cookie.values()) for cookie in self._cookies.values())
def _do_expiration(self) -> None:
"""Remove expired cookies."""
if not (expire_heap_len := len(self._expire_heap)):
return
# If the expiration heap grows larger than the number expirations
# times two, we clean it up to avoid keeping expired entries in
# the heap and consuming memory. We guard this with a minimum
# threshold to avoid cleaning up the heap too often when there are
# only a few scheduled expirations.
if (
expire_heap_len > _MIN_SCHEDULED_COOKIE_EXPIRATION
and expire_heap_len > len(self._expirations) * 2
):
# Remove any expired entries from the expiration heap
# that do not match the expiration time in the expirations
# as it means the cookie has been re-added to the heap
# with a different expiration time.
self._expire_heap = [
entry
for entry in self._expire_heap
if self._expirations.get(entry[1]) == entry[0]
]
heapq.heapify(self._expire_heap)
now = time.time()
to_del: list[tuple[str, str, str]] = []
# Find any expired cookies and add them to the to-delete list
while self._expire_heap:
when, cookie_key = self._expire_heap[0]
if when > now:
break
heapq.heappop(self._expire_heap)
# Check if the cookie hasn't been re-added to the heap
# with a different expiration time as it will be removed
# later when it reaches the top of the heap and its
# expiration time is met.
if self._expirations.get(cookie_key) == when:
to_del.append(cookie_key)
if to_del:
self._delete_cookies(to_del)
def _delete_cookies(self, to_del: list[tuple[str, str, str]]) -> None:
for domain, path, name in to_del:
self._host_only_cookies.discard((domain, name))
self._cookies[(domain, path)].pop(name, None)
self._morsel_cache[(domain, path)].pop(name, None)
self._expirations.pop((domain, path, name), None)
def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None:
cookie_key = (domain, path, name)
if self._expirations.get(cookie_key) == when:
# Avoid adding duplicates to the heap
return
heapq.heappush(self._expire_heap, (when, cookie_key))
self._expirations[cookie_key] = when
def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
"""Update cookies."""
hostname = response_url.raw_host
if not self._unsafe and is_ip_address(hostname):
# Don't accept cookies from IPs
return
if isinstance(cookies, Mapping):
cookies = cookies.items()
for name, cookie in cookies:
if not isinstance(cookie, Morsel):
tmp = SimpleCookie()
tmp[name] = cookie # type: ignore[assignment]
cookie = tmp[name]
domain = cookie["domain"]
# ignore domains with trailing dots
if domain and domain[-1] == ".":
domain = ""
del cookie["domain"]
if not domain and hostname is not None:
# Set the cookie's domain to the response hostname
# and set its host-only-flag
self._host_only_cookies.add((hostname, name))
domain = cookie["domain"] = hostname
if domain and domain[0] == ".":
# Remove leading dot
domain = domain[1:]
cookie["domain"] = domain
if hostname and not self._is_domain_match(domain, hostname):
# Setting cookies for different domains is not allowed
continue
path = cookie["path"]
if not path or path[0] != "/":
# Set the cookie's path to the response path
path = response_url.path
if not path.startswith("/"):
path = "/"
else:
# Cut everything from the last slash to the end
path = "/" + path[1 : path.rfind("/")]
cookie["path"] = path
path = path.rstrip("/")
if max_age := cookie["max-age"]:
try:
delta_seconds = int(max_age)
max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME)
self._expire_cookie(max_age_expiration, domain, path, name)
except ValueError:
cookie["max-age"] = ""
elif expires := cookie["expires"]:
if expire_time := self._parse_date(expires):
self._expire_cookie(expire_time, domain, path, name)
else:
cookie["expires"] = ""
key = (domain, path)
if self._cookies[key].get(name) != cookie:
# Don't blow away the cache if the same
# cookie gets set again
self._cookies[key][name] = cookie
self._morsel_cache[key].pop(name, None)
self._do_expiration()
def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]":
"""Returns this jar's cookies filtered by their attributes."""
# We always use BaseCookie now since all
# cookies set on on filtered are fully constructed
# Morsels, not just names and values.
filtered: BaseCookie[str] = BaseCookie()
if not self._cookies:
# Skip do_expiration() if there are no cookies.
return filtered
self._do_expiration()
if not self._cookies:
# Skip rest of function if no non-expired cookies.
return filtered
if type(request_url) is not URL:
warnings.warn(
"filter_cookies expects yarl.URL instances only,"
f"and will stop working in 4.x, got {type(request_url)}",
DeprecationWarning,
stacklevel=2,
)
request_url = URL(request_url)
hostname = request_url.raw_host or ""
is_not_secure = request_url.scheme not in ("https", "wss")
if is_not_secure and self._treat_as_secure_origin:
request_origin = URL()
with contextlib.suppress(ValueError):
request_origin = request_url.origin()
is_not_secure = request_origin not in self._treat_as_secure_origin
# Send shared cookie
key = ("", "")
for c in self._cookies[key].values():
# Check cache first
if c.key in self._morsel_cache[key]:
filtered[c.key] = self._morsel_cache[key][c.key]
continue
# Build and cache the morsel
mrsl_val = self._build_morsel(c)
self._morsel_cache[key][c.key] = mrsl_val
filtered[c.key] = mrsl_val
if is_ip_address(hostname):
if not self._unsafe:
return filtered
domains: Iterable[str] = (hostname,)
else:
# Get all the subdomains that might match a cookie (e.g. "foo.bar.com", "bar.com", "com")
domains = itertools.accumulate(
reversed(hostname.split(".")), _FORMAT_DOMAIN_REVERSED
)
# Get all the path prefixes that might match a cookie (e.g. "", "/foo", "/foo/bar")
paths = itertools.accumulate(request_url.path.split("/"), _FORMAT_PATH)
# Create every combination of (domain, path) pairs.
pairs = itertools.product(domains, paths)
path_len = len(request_url.path)
# Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4
for p in pairs:
if p not in self._cookies:
continue
for name, cookie in self._cookies[p].items():
domain = cookie["domain"]
if (domain, name) in self._host_only_cookies and domain != hostname:
continue
# Skip edge case when the cookie has a trailing slash but request doesn't.
if len(cookie["path"]) > path_len:
continue
if is_not_secure and cookie["secure"]:
continue
# We already built the Morsel so reuse it here
if name in self._morsel_cache[p]:
filtered[name] = self._morsel_cache[p][name]
continue
# Build and cache the morsel
mrsl_val = self._build_morsel(cookie)
self._morsel_cache[p][name] = mrsl_val
filtered[name] = mrsl_val
return filtered
def _build_morsel(self, cookie: Morsel[str]) -> Morsel[str]:
"""Build a morsel for sending, respecting quote_cookie setting."""
if self._quote_cookie and cookie.coded_value and cookie.coded_value[0] == '"':
return preserve_morsel_with_coded_value(cookie)
morsel: Morsel[str] = Morsel()
if self._quote_cookie:
value, coded_value = _SIMPLE_COOKIE.value_encode(cookie.value)
else:
coded_value = value = cookie.value
# We use __setstate__ instead of the public set() API because it allows us to
# bypass validation and set already validated state. This is more stable than
# setting protected attributes directly and unlikely to change since it would
# break pickling.
morsel.__setstate__({"key": cookie.key, "value": value, "coded_value": coded_value}) # type: ignore[attr-defined]
return morsel
@staticmethod
def _is_domain_match(domain: str, hostname: str) -> bool:
"""Implements domain matching adhering to RFC 6265."""
if hostname == domain:
return True
if not hostname.endswith(domain):
return False
non_matching = hostname[: -len(domain)]
if not non_matching.endswith("."):
return False
return not is_ip_address(hostname)
@classmethod
def _parse_date(cls, date_str: str) -> int | None:
"""Implements date string parsing adhering to RFC 6265."""
if not date_str:
return None
found_time = False
found_day = False
found_month = False
found_year = False
hour = minute = second = 0
day = 0
month = 0
year = 0
for token_match in cls.DATE_TOKENS_RE.finditer(date_str):
token = token_match.group("token")
if not found_time:
time_match = cls.DATE_HMS_TIME_RE.match(token)
if time_match:
found_time = True
hour, minute, second = (int(s) for s in time_match.groups())
continue
if not found_day:
day_match = cls.DATE_DAY_OF_MONTH_RE.match(token)
if day_match:
found_day = True
day = int(day_match.group())
continue
if not found_month:
month_match = cls.DATE_MONTH_RE.match(token)
if month_match:
found_month = True
assert month_match.lastindex is not None
month = month_match.lastindex
continue
if not found_year:
year_match = cls.DATE_YEAR_RE.match(token)
if year_match:
found_year = True
year = int(year_match.group())
if 70 <= year <= 99:
year += 1900
elif 0 <= year <= 69:
year += 2000
if False in (found_day, found_month, found_year, found_time):
return None
if not 1 <= day <= 31:
return None
if year < 1601 or hour > 23 or minute > 59 or second > 59:
return None
return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1))
class DummyCookieJar(AbstractCookieJar):
"""Implements a dummy cookie storage.
It can be used with the ClientSession when no cookie processing is needed.
"""
def __init__(self, *, loop: asyncio.AbstractEventLoop | None = None) -> None:
super().__init__(loop=loop)
def __iter__(self) -> "Iterator[Morsel[str]]":
while False:
yield None
def __len__(self) -> int:
return 0
@property
def unsafe(self) -> bool:
return False
@property
def quote_cookie(self) -> bool:
return True
@property
def cookies(self) -> MappingProxyType[tuple[str, str], SimpleCookie]:
"""Return an empty mapping."""
return MappingProxyType({})
@property
def host_only_cookies(self) -> frozenset[tuple[str, str]]:
"""Return an empty frozenset."""
return frozenset()
def clear(self, predicate: ClearCookiePredicate | None = None) -> None:
pass
def clear_domain(self, domain: str) -> None:
pass
def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None:
pass
def filter_cookies(self, request_url: URL) -> "BaseCookie[str]":
return SimpleCookie()

View File

@ -0,0 +1,184 @@
import io
import warnings
from collections.abc import Iterable
from typing import Any
from urllib.parse import urlencode
from multidict import MultiDict, MultiDictProxy
from . import hdrs, multipart, payload
from .helpers import guess_filename
from .http_writer import _safe_header
from .payload import Payload
__all__ = ("FormData",)
class FormData:
"""Helper class for form body generation.
Supports multipart/form-data and application/x-www-form-urlencoded.
"""
def __init__(
self,
fields: Iterable[Any] = (),
quote_fields: bool = True,
charset: str | None = None,
*,
default_to_multipart: bool = False,
) -> None:
self._writer = multipart.MultipartWriter("form-data")
self._fields: list[Any] = []
self._is_multipart = default_to_multipart
self._quote_fields = quote_fields
self._charset = charset
if isinstance(fields, dict):
fields = list(fields.items())
elif not isinstance(fields, (list, tuple)):
fields = (fields,)
self.add_fields(*fields)
@property
def is_multipart(self) -> bool:
return self._is_multipart
def add_field(
self,
name: str,
value: Any,
*,
content_type: str | None = None,
filename: str | None = None,
content_transfer_encoding: str | None = None,
) -> None:
if isinstance(value, io.IOBase):
self._is_multipart = True
elif isinstance(value, (bytes, bytearray, memoryview)):
msg = (
"In v4, passing bytes will no longer create a file field. "
"Please explicitly use the filename parameter or pass a BytesIO object."
)
if filename is None and content_transfer_encoding is None:
warnings.warn(msg, DeprecationWarning)
filename = name
_safe_header(name)
type_options: MultiDict[str] = MultiDict({"name": name})
if filename is not None and not isinstance(filename, str):
raise TypeError("filename must be an instance of str. Got: %s" % filename)
if filename is None and isinstance(value, io.IOBase):
filename = guess_filename(value, name)
if filename is not None:
_safe_header(filename)
type_options["filename"] = filename
self._is_multipart = True
headers = {}
if content_type is not None:
if not isinstance(content_type, str):
raise TypeError(
"content_type must be an instance of str. Got: %s" % content_type
)
_safe_header(content_type)
headers[hdrs.CONTENT_TYPE] = content_type
self._is_multipart = True
if content_transfer_encoding is not None:
if not isinstance(content_transfer_encoding, str):
raise TypeError(
"content_transfer_encoding must be an instance"
" of str. Got: %s" % content_transfer_encoding
)
msg = (
"content_transfer_encoding is deprecated. "
"To maintain compatibility with v4 please pass a BytesPayload."
)
warnings.warn(msg, DeprecationWarning)
self._is_multipart = True
self._fields.append((type_options, headers, value))
def add_fields(self, *fields: Any) -> None:
to_add = list(fields)
while to_add:
rec = to_add.pop(0)
if isinstance(rec, io.IOBase):
k = guess_filename(rec, "unknown")
self.add_field(k, rec) # type: ignore[arg-type]
elif isinstance(rec, (MultiDictProxy, MultiDict)):
to_add.extend(rec.items())
elif isinstance(rec, (list, tuple)) and len(rec) == 2:
k, fp = rec
self.add_field(k, fp)
else:
raise TypeError(
"Only io.IOBase, multidict and (name, file) "
"pairs allowed, use .add_field() for passing "
f"more complex parameters, got {rec!r}"
)
def _gen_form_urlencoded(self) -> payload.BytesPayload:
# form data (x-www-form-urlencoded)
data = []
for type_options, _, value in self._fields:
data.append((type_options["name"], value))
charset = self._charset if self._charset is not None else "utf-8"
if charset == "utf-8":
content_type = "application/x-www-form-urlencoded"
else:
content_type = "application/x-www-form-urlencoded; charset=%s" % charset
return payload.BytesPayload(
urlencode(data, doseq=True, encoding=charset).encode(),
content_type=content_type,
)
def _gen_form_data(self) -> multipart.MultipartWriter:
"""Encode a list of fields using the multipart/form-data MIME format"""
for dispparams, headers, value in self._fields:
try:
if hdrs.CONTENT_TYPE in headers:
part = payload.get_payload(
value,
content_type=headers[hdrs.CONTENT_TYPE],
headers=headers,
encoding=self._charset,
)
else:
part = payload.get_payload(
value, headers=headers, encoding=self._charset
)
except Exception as exc:
raise TypeError(
"Can not serialize value type: %r\n "
"headers: %r\n value: %r" % (type(value), headers, value)
) from exc
if dispparams:
part.set_content_disposition(
"form-data", quote_fields=self._quote_fields, **dispparams
)
# FIXME cgi.FieldStorage doesn't likes body parts with
# Content-Length which were sent via chunked transfer encoding
assert part.headers is not None
part.headers.popall(hdrs.CONTENT_LENGTH, None)
self._writer.append_payload(part)
self._fields.clear()
return self._writer
def __call__(self) -> Payload:
if self._is_multipart:
return self._gen_form_data()
else:
return self._gen_form_urlencoded()

View File

@ -0,0 +1,121 @@
"""HTTP Headers constants."""
# After changing the file content call ./tools/gen.py
# to regenerate the headers parser
import itertools
from typing import Final
from multidict import istr
METH_ANY: Final[str] = "*"
METH_CONNECT: Final[str] = "CONNECT"
METH_HEAD: Final[str] = "HEAD"
METH_GET: Final[str] = "GET"
METH_DELETE: Final[str] = "DELETE"
METH_OPTIONS: Final[str] = "OPTIONS"
METH_PATCH: Final[str] = "PATCH"
METH_POST: Final[str] = "POST"
METH_PUT: Final[str] = "PUT"
METH_TRACE: Final[str] = "TRACE"
METH_ALL: Final[set[str]] = {
METH_CONNECT,
METH_HEAD,
METH_GET,
METH_DELETE,
METH_OPTIONS,
METH_PATCH,
METH_POST,
METH_PUT,
METH_TRACE,
}
ACCEPT: Final[istr] = istr("Accept")
ACCEPT_CHARSET: Final[istr] = istr("Accept-Charset")
ACCEPT_ENCODING: Final[istr] = istr("Accept-Encoding")
ACCEPT_LANGUAGE: Final[istr] = istr("Accept-Language")
ACCEPT_RANGES: Final[istr] = istr("Accept-Ranges")
ACCESS_CONTROL_MAX_AGE: Final[istr] = istr("Access-Control-Max-Age")
ACCESS_CONTROL_ALLOW_CREDENTIALS: Final[istr] = istr("Access-Control-Allow-Credentials")
ACCESS_CONTROL_ALLOW_HEADERS: Final[istr] = istr("Access-Control-Allow-Headers")
ACCESS_CONTROL_ALLOW_METHODS: Final[istr] = istr("Access-Control-Allow-Methods")
ACCESS_CONTROL_ALLOW_ORIGIN: Final[istr] = istr("Access-Control-Allow-Origin")
ACCESS_CONTROL_EXPOSE_HEADERS: Final[istr] = istr("Access-Control-Expose-Headers")
ACCESS_CONTROL_REQUEST_HEADERS: Final[istr] = istr("Access-Control-Request-Headers")
ACCESS_CONTROL_REQUEST_METHOD: Final[istr] = istr("Access-Control-Request-Method")
AGE: Final[istr] = istr("Age")
ALLOW: Final[istr] = istr("Allow")
AUTHORIZATION: Final[istr] = istr("Authorization")
CACHE_CONTROL: Final[istr] = istr("Cache-Control")
CONNECTION: Final[istr] = istr("Connection")
CONTENT_DISPOSITION: Final[istr] = istr("Content-Disposition")
CONTENT_ENCODING: Final[istr] = istr("Content-Encoding")
CONTENT_LANGUAGE: Final[istr] = istr("Content-Language")
CONTENT_LENGTH: Final[istr] = istr("Content-Length")
CONTENT_LOCATION: Final[istr] = istr("Content-Location")
CONTENT_MD5: Final[istr] = istr("Content-MD5")
CONTENT_RANGE: Final[istr] = istr("Content-Range")
CONTENT_TRANSFER_ENCODING: Final[istr] = istr("Content-Transfer-Encoding")
CONTENT_TYPE: Final[istr] = istr("Content-Type")
COOKIE: Final[istr] = istr("Cookie")
DATE: Final[istr] = istr("Date")
DESTINATION: Final[istr] = istr("Destination")
DIGEST: Final[istr] = istr("Digest")
ETAG: Final[istr] = istr("Etag")
EXPECT: Final[istr] = istr("Expect")
EXPIRES: Final[istr] = istr("Expires")
FORWARDED: Final[istr] = istr("Forwarded")
FROM: Final[istr] = istr("From")
HOST: Final[istr] = istr("Host")
IF_MATCH: Final[istr] = istr("If-Match")
IF_MODIFIED_SINCE: Final[istr] = istr("If-Modified-Since")
IF_NONE_MATCH: Final[istr] = istr("If-None-Match")
IF_RANGE: Final[istr] = istr("If-Range")
IF_UNMODIFIED_SINCE: Final[istr] = istr("If-Unmodified-Since")
KEEP_ALIVE: Final[istr] = istr("Keep-Alive")
LAST_EVENT_ID: Final[istr] = istr("Last-Event-ID")
LAST_MODIFIED: Final[istr] = istr("Last-Modified")
LINK: Final[istr] = istr("Link")
LOCATION: Final[istr] = istr("Location")
MAX_FORWARDS: Final[istr] = istr("Max-Forwards")
ORIGIN: Final[istr] = istr("Origin")
PRAGMA: Final[istr] = istr("Pragma")
PROXY_AUTHENTICATE: Final[istr] = istr("Proxy-Authenticate")
PROXY_AUTHORIZATION: Final[istr] = istr("Proxy-Authorization")
RANGE: Final[istr] = istr("Range")
REFERER: Final[istr] = istr("Referer")
RETRY_AFTER: Final[istr] = istr("Retry-After")
SEC_WEBSOCKET_ACCEPT: Final[istr] = istr("Sec-WebSocket-Accept")
SEC_WEBSOCKET_VERSION: Final[istr] = istr("Sec-WebSocket-Version")
SEC_WEBSOCKET_PROTOCOL: Final[istr] = istr("Sec-WebSocket-Protocol")
SEC_WEBSOCKET_EXTENSIONS: Final[istr] = istr("Sec-WebSocket-Extensions")
SEC_WEBSOCKET_KEY: Final[istr] = istr("Sec-WebSocket-Key")
SEC_WEBSOCKET_KEY1: Final[istr] = istr("Sec-WebSocket-Key1")
SERVER: Final[istr] = istr("Server")
SET_COOKIE: Final[istr] = istr("Set-Cookie")
TE: Final[istr] = istr("TE")
TRAILER: Final[istr] = istr("Trailer")
TRANSFER_ENCODING: Final[istr] = istr("Transfer-Encoding")
UPGRADE: Final[istr] = istr("Upgrade")
URI: Final[istr] = istr("URI")
USER_AGENT: Final[istr] = istr("User-Agent")
VARY: Final[istr] = istr("Vary")
VIA: Final[istr] = istr("Via")
WANT_DIGEST: Final[istr] = istr("Want-Digest")
WARNING: Final[istr] = istr("Warning")
WWW_AUTHENTICATE: Final[istr] = istr("WWW-Authenticate")
X_FORWARDED_FOR: Final[istr] = istr("X-Forwarded-For")
X_FORWARDED_HOST: Final[istr] = istr("X-Forwarded-Host")
X_FORWARDED_PROTO: Final[istr] = istr("X-Forwarded-Proto")
# These are the upper/lower case variants of the headers/methods
# Example: {'hOst', 'host', 'HoST', 'HOSt', 'hOsT', 'HosT', 'hoSt', ...}
METH_HEAD_ALL: Final = frozenset(
map("".join, itertools.product(*zip(METH_HEAD.upper(), METH_HEAD.lower())))
)
METH_CONNECT_ALL: Final = frozenset(
map("".join, itertools.product(*zip(METH_CONNECT.upper(), METH_CONNECT.lower())))
)
HOST_ALL: Final = frozenset(
map("".join, itertools.product(*zip(HOST.upper(), HOST.lower())))
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,78 @@
import sys
from collections.abc import Mapping
from http import HTTPStatus
from . import __version__
from .http_exceptions import HttpProcessingError as HttpProcessingError
from .http_parser import (
HeadersParser as HeadersParser,
HttpParser as HttpParser,
HttpRequestParser as HttpRequestParser,
HttpResponseParser as HttpResponseParser,
RawRequestMessage as RawRequestMessage,
RawResponseMessage as RawResponseMessage,
)
from .http_websocket import (
WS_CLOSED_MESSAGE as WS_CLOSED_MESSAGE,
WS_CLOSING_MESSAGE as WS_CLOSING_MESSAGE,
WS_KEY as WS_KEY,
WebSocketError as WebSocketError,
WebSocketReader as WebSocketReader,
WebSocketWriter as WebSocketWriter,
WSCloseCode as WSCloseCode,
WSMessage as WSMessage,
WSMessageDecodeText as WSMessageDecodeText,
WSMessageNoDecodeText as WSMessageNoDecodeText,
WSMessageTextBytes as WSMessageTextBytes,
WSMsgType as WSMsgType,
ws_ext_gen as ws_ext_gen,
ws_ext_parse as ws_ext_parse,
)
from .http_writer import (
HttpVersion as HttpVersion,
HttpVersion10 as HttpVersion10,
HttpVersion11 as HttpVersion11,
StreamWriter as StreamWriter,
)
__all__ = (
"HttpProcessingError",
"RESPONSES",
"SERVER_SOFTWARE",
# .http_writer
"StreamWriter",
"HttpVersion",
"HttpVersion10",
"HttpVersion11",
# .http_parser
"HeadersParser",
"HttpParser",
"HttpRequestParser",
"HttpResponseParser",
"RawRequestMessage",
"RawResponseMessage",
# .http_websocket
"WS_CLOSED_MESSAGE",
"WS_CLOSING_MESSAGE",
"WS_KEY",
"WebSocketReader",
"WebSocketWriter",
"ws_ext_gen",
"ws_ext_parse",
"WSMessage",
"WSMessageDecodeText",
"WSMessageNoDecodeText",
"WSMessageTextBytes",
"WebSocketError",
"WSMsgType",
"WSCloseCode",
)
SERVER_SOFTWARE: str = (
f"Python/{sys.version_info[0]}.{sys.version_info[1]} aiohttp/{__version__}"
)
RESPONSES: Mapping[int, tuple[str, str]] = {
v: (v.phrase, v.description) for v in HTTPStatus.__members__.values()
}

Some files were not shown because too many files have changed in this diff Show More