from __future__ import annotations

import re
from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, MutableMapping, Sequence, ValuesView
from shlex import shlex
from typing import Any, BinaryIO, NamedTuple, TypeVar, cast
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit

from starlette.concurrency import run_in_threadpool
from starlette.types import Scope


class Address(NamedTuple):
    """
    A `(host, port)` pair.
    """

    host: str
    port: int


_KeyType = TypeVar("_KeyType")
# Mapping keys are invariant but their values are covariant since
# you can only read them
# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()`
_CovariantValueType = TypeVar("_CovariantValueType", covariant=True)

# Rejects Host header chars (/, ?, #, @, ...) that would let urlsplit produce a path differing from scope["path"].
_HOST_RE = re.compile(r"^([a-z0-9.-]+|\[[a-f0-9]*:[a-f0-9.:]+\])(?::[0-9]+)?$", re.IGNORECASE)


class URL:
    """
    A URL held as a string and parsed lazily with `urllib.parse.urlsplit`.

    It can be built from a URL string, from a connection `scope` mapping, or
    from individual `SplitResult` components passed as keyword arguments.
    Instances are treated as immutable: every "modifying" method returns a new
    instance.
    """

    def __init__(
        self,
        url: str = "",
        scope: Scope | None = None,
        **components: Any,
    ) -> None:
        """
        Build the URL from exactly one of `url`, `scope` or `**components`.

        With `scope`, the URL is assembled from `scope["path"]`, the optional
        `"scheme"` (default `"http"`), `"query_string"` and either the `host`
        header (only when it matches `_HOST_RE`) or the `"server"` pair. When
        neither a usable host header nor a server is available the URL is
        just the path (plus query string).

        With `**components`, the keyword arguments are applied to an empty
        URL through `replace()`.
        """
        if scope is not None:
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope["path"]
            query_string = scope.get("query_string", b"")

            # Only the first `host` header is considered.
            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            if host_header is not None and _HOST_RE.fullmatch(host_header):
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                url = path
            else:
                host, port = server
                default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme]
                # Leave the port out when it is the scheme's default.
                if port == default_port:
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        """The `urlsplit` result for this URL, computed once and cached on the instance."""
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        return self.components.scheme

    @property
    def netloc(self) -> str:
        return self.components.netloc

    @property
    def path(self) -> str:
        return self.components.path

    @property
    def query(self) -> str:
        return self.components.query

    @property
    def fragment(self) -> str:
        return self.components.fragment

    @property
    def username(self) -> None | str:
        return self.components.username

    @property
    def password(self) -> None | str:
        return self.components.password

    @property
    def hostname(self) -> None | str:
        return self.components.hostname

    @property
    def port(self) -> int | None:
        return self.components.port

    @property
    def is_secure(self) -> bool:
        """True when the scheme is `https` or `wss`."""
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: Any) -> URL:
        """
        Return a new URL with the given components replaced.

        `username`, `password`, `hostname` and `port` are combined into a new
        `netloc`; parts that are not supplied keep their current value. All
        remaining keyword arguments are passed to `SplitResult._replace`
        (`scheme`, `netloc`, `path`, `query`, `fragment`).
        """
        if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
            hostname = kwargs.pop("hostname", None)
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            if hostname is None:
                # Recover the current host from the netloc: drop any "user:pass@" prefix...
                netloc = self.netloc
                _, _, hostname = netloc.rpartition("@")

                # ...then drop a trailing ":port", unless the host ends with "]"
                # (a bracketed IPv6 literal without a port). `endswith` is used
                # rather than `hostname[-1]` so that an empty netloc, as in a
                # relative URL, does not raise IndexError.
                if not hostname.endswith("]"):
                    hostname = hostname.rsplit(":", 1)[0]

            netloc = hostname
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: Any) -> URL:
        """
        Return a new URL with the given query parameters set.

        Existing values for the given keys are replaced; other parameters are
        kept. Keys and values are converted with `str()`.
        """
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: Any) -> URL:
        """Return a new URL whose query string consists only of the given parameters."""
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(self, keys: str | Sequence[str]) -> URL:
        """Return a new URL with every occurrence of the given key(s) removed from the query."""
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: Any) -> bool:
        # Compared by string form, so a URL also equals a plain `str` with the same text.
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        url = str(self)
        if self.password:
            # Never show the real password in the repr.
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"


class URLPath(str):
    """
    A URL path string that may also hold an associated protocol and/or host.
    Used by the routing to return `url_path_for` matches.
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
        assert protocol in ("http", "websocket", "")
        return str.__new__(cls, path)

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.protocol = protocol
        self.host = host

    def make_absolute_url(self, base_url: str | URL) -> URL:
        """
        Combine this path with `base_url` into an absolute URL.

        The scheme is derived from `self.protocol` and whether `base_url` is
        secure (`http`/`https` or `ws`/`wss`); without a protocol the base
        URL's scheme is used. `self.host` overrides the base URL's netloc when
        set, and the base URL's path (minus trailing slashes) is prepended to
        this path.
        """
        if isinstance(base_url, str):
            base_url = URL(base_url)
        if self.protocol:
            scheme = {
                "http": {True: "https", False: "http"},
                "websocket": {True: "wss", False: "ws"},
            }[self.protocol][base_url.is_secure]
        else:
            scheme = base_url.scheme

        netloc = self.host or base_url.netloc
        path = base_url.path.rstrip("/") + str(self)
        return URL(scheme=scheme, netloc=netloc, path=path)


class Secret:
    """
    Holds a string value that should not be revealed in tracebacks etc.
    You should cast the value to `str` at the point it is required.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        # The repr deliberately never contains the wrapped value.
        class_name = self.__class__.__name__
        return f"{class_name}('**********')"

    def __str__(self) -> str:
        return self._value

    def __bool__(self) -> bool:
        return bool(self._value)


class CommaSeparatedStrings(Sequence[str]):
    """
    A sequence of strings built from a comma separated string or from an
    existing sequence.

    A string is tokenised with `shlex` in POSIX mode using `,` as the only
    separator, and every token is stripped of surrounding whitespace.
    """

    def __init__(self, value: str | Sequence[str]):
        if isinstance(value, str):
            splitter = shlex(value, posix=True)
            splitter.whitespace = ","
            splitter.whitespace_split = True
            self._items = [item.strip() for item in splitter]
        else:
            self._items = list(value)

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int | slice) -> Any:
        return self._items[index]

    def __iter__(self) -> Iterator[str]:
        return iter(self._items)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        items = list(self)
        return f"{class_name}({items!r})"

    def __str__(self) -> str:
        return ", ".join(repr(item) for item in self)


class ImmutableMultiDict(Mapping[_KeyType, _CovariantValueType]):
    """
    A read-only mapping that can hold several values per key.

    `_list` keeps every `(key, value)` pair in insertion order, while `_dict`
    keeps only the last value seen for each key. The regular `Mapping`
    interface is backed by `_dict`; `getlist()` and `multi_items()` expose all
    values.
    """

    _dict: dict[_KeyType, _CovariantValueType]

    def __init__(
        self,
        *args: ImmutableMultiDict[_KeyType, _CovariantValueType]
        | Mapping[_KeyType, _CovariantValueType]
        | Iterable[tuple[_KeyType, _CovariantValueType]],
        **kwargs: Any,
    ) -> None:
        """
        Accept at most one positional source (another multidict, a mapping or
        an iterable of pairs) plus optional keyword items, which are appended
        after the positional items.
        """
        assert len(args) < 2, "Too many arguments."

        value: Any = args[0] if args else []
        if kwargs:
            value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()

        if not value:
            _items: list[tuple[Any, Any]] = []
        elif hasattr(value, "multi_items"):
            # Another multidict: keep all of its (possibly repeated) items.
            value = cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
            _items = list(value.multi_items())
        elif hasattr(value, "items"):
            value = cast(Mapping[_KeyType, _CovariantValueType], value)
            _items = list(value.items())
        else:
            value = cast("list[tuple[Any, Any]]", value)
            _items = list(value)

        # Later pairs overwrite earlier ones, so `_dict` holds the last value per key.
        self._dict = {k: v for k, v in _items}
        self._list = _items

    def getlist(self, key: Any) -> list[_CovariantValueType]:
        """Return every value stored for `key`, in insertion order (empty list if absent)."""
        return [item_value for item_key, item_value in self._list if item_key == key]

    def keys(self) -> KeysView[_KeyType]:
        return self._dict.keys()

    def values(self) -> ValuesView[_CovariantValueType]:
        return self._dict.values()

    def items(self) -> ItemsView[_KeyType, _CovariantValueType]:
        return self._dict.items()

    def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
        """Return a copy of all `(key, value)` pairs, including repeated keys."""
        return list(self._list)

    def __getitem__(self, key: _KeyType) -> _CovariantValueType:
        return self._dict[key]

    def __contains__(self, key: Any) -> bool:
        return key in self._dict

    def __iter__(self) -> Iterator[_KeyType]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __eq__(self, other: Any) -> bool:
        # Equal when `other` is an instance of this class holding the same pairs, in any order.
        if not isinstance(other, self.__class__):
            return False
        return sorted(self._list) == sorted(other._list)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        items = self.multi_items()
        return f"{class_name}({items!r})"


class MultiDict(ImmutableMultiDict[Any, Any]):
    """
    A mutable multidict. Every mutation keeps `_list` (all pairs) and `_dict`
    (last value per key) in sync.
    """

    def __setitem__(self, key: Any, value: Any) -> None:
        self.setlist(key, [value])

    def __delitem__(self, key: Any) -> None:
        self._list = [(k, v) for k, v in self._list if k != key]
        del self._dict[key]

    def pop(self, key: Any, default: Any = None) -> Any:
        """Remove every item for `key` and return its last value, or `default` if absent."""
        self._list = [(k, v) for k, v in self._list if k != key]
        return self._dict.pop(key, default)

    def popitem(self) -> tuple[Any, Any]:
        """Remove a key via `dict.popitem()` together with all of its items; return `(key, last value)`."""
        key, value = self._dict.popitem()
        self._list = [(k, v) for k, v in self._list if k != key]
        return key, value

    def poplist(self, key: Any) -> list[Any]:
        """Remove `key` and return all of its values (empty list if absent)."""
        values = [v for k, v in self._list if k == key]
        self.pop(key)
        return values

    def clear(self) -> None:
        self._dict.clear()
        self._list.clear()

    def setdefault(self, key: Any, default: Any = None) -> Any:
        if key not in self:
            self._dict[key] = default
            self._list.append((key, default))

        return self[key]

    def setlist(self, key: Any, values: list[Any]) -> None:
        """Replace all values of `key` with `values`; an empty list removes the key."""
        if not values:
            self.pop(key, None)
        else:
            existing_items = [(k, v) for (k, v) in self._list if k != key]
            self._list = existing_items + [(key, value) for value in values]
            self._dict[key] = values[-1]

    def append(self, key: Any, value: Any) -> None:
        """Add another value for `key`, keeping any existing ones."""
        self._list.append((key, value))
        self._dict[key] = value

    def update(
        self,
        *args: MultiDict | Mapping[Any, Any] | list[tuple[Any, Any]],
        **kwargs: Any,
    ) -> None:
        """Merge in new items; for every incoming key, existing values are replaced."""
        value = MultiDict(*args, **kwargs)
        existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()]
        self._list = existing_items + value.multi_items()
        self._dict.update(value)


class QueryParams(ImmutableMultiDict[str, str]):
    """
    An immutable multidict.
    """

    def __init__(
        self,
        *args: ImmutableMultiDict[Any, Any] | Mapping[Any, Any] | list[tuple[Any, Any]] | str | bytes,
        **kwargs: Any,
    ) -> None:
        """
        Accept the same inputs as `ImmutableMultiDict`, plus a raw query
        string given as `str` or `bytes` (decoded as latin-1), which is parsed
        with `parse_qsl` keeping blank values. All keys and values are
        converted with `str()`.
        """
        assert len(args) < 2, "Too many arguments."

        value = args[0] if args else []

        if isinstance(value, str):
            super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
        elif isinstance(value, bytes):
            super().__init__(parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs)
        else:
            super().__init__(*args, **kwargs)  # type: ignore[arg-type]
        self._list = [(str(k), str(v)) for k, v in self._list]
        self._dict = {str(k): str(v) for k, v in self._dict.items()}

    def __str__(self) -> str:
        return urlencode(self._list)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        query_string = str(self)
        return f"{class_name}({query_string!r})"


class UploadFile:
    """
    An uploaded file included as part of the request data.
    """

    def __init__(
        self,
        file: BinaryIO,
        *,
        size: int | None = None,
        filename: str | None = None,
        headers: Headers | None = None,
    ) -> None:
        self.filename = filename
        self.file = file
        self.size = size
        self.headers = headers or Headers()

        # Capture max size from SpooledTemporaryFile if one is provided. This slightly speeds up future checks.
        # Note 0 means unlimited mirroring SpooledTemporaryFile's __init__
        self._max_mem_size = getattr(self.file, "_max_size", 0)

    @property
    def content_type(self) -> str | None:
        return self.headers.get("content-type", None)

    @property
    def _in_memory(self) -> bool:
        # check for SpooledTemporaryFile._rolled
        # Files without a `_rolled` attribute are treated as not in memory.
        rolled_to_disk = getattr(self.file, "_rolled", True)
        return not rolled_to_disk

    def _will_roll(self, size_to_add: int) -> bool:
        """Return True when writing `size_to_add` more bytes should go through the threadpool."""
        # If we're not in_memory then we will always roll
        if not self._in_memory:
            return True

        # Check for SpooledTemporaryFile._max_size
        future_size = self.file.tell() + size_to_add
        return bool(future_size > self._max_mem_size) if self._max_mem_size else False

    async def write(self, data: bytes) -> None:
        """Write `data`, updating `size` when it is tracked; uses the threadpool unless the write stays in memory."""
        new_data_len = len(data)
        if self.size is not None:
            self.size += new_data_len

        if self._will_roll(new_data_len):
            await run_in_threadpool(self.file.write, data)
        else:
            self.file.write(data)

    async def read(self, size: int = -1) -> bytes:
        if self._in_memory:
            return self.file.read(size)
        return await run_in_threadpool(self.file.read, size)

    async def seek(self, offset: int) -> None:
        if self._in_memory:
            self.file.seek(offset)
        else:
            await run_in_threadpool(self.file.seek, offset)

    async def close(self) -> None:
        if self._in_memory:
            self.file.close()
        else:
            await run_in_threadpool(self.file.close)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(filename={self.filename!r}, size={self.size!r}, headers={self.headers!r})"


class FormData(ImmutableMultiDict[str, UploadFile | str]):
    """
    An immutable multidict, containing both file uploads and text input.
    """

    def __init__(
        self,
        *args: FormData | Mapping[str, str | UploadFile] | list[tuple[str, str | UploadFile]],
        **kwargs: str | UploadFile,
    ) -> None:
        super().__init__(*args, **kwargs)

    async def close(self) -> None:
        """Close every `UploadFile` held in the form."""
        for key, value in self.multi_items():
            if isinstance(value, UploadFile):
                await value.close()


class Headers(Mapping[str, str]):
    """
    An immutable, case-insensitive multidict.
    """

    def __init__(
        self,
        headers: Mapping[str, str] | None = None,
        raw: list[tuple[bytes, bytes]] | None = None,
        scope: MutableMapping[str, Any] | None = None,
    ) -> None:
        """
        Build from at most one of `headers` (a str mapping; keys are
        lower-cased and both parts encoded as latin-1), `raw` (a list of
        byte pairs, stored as given) or `scope` (whose `"headers"` entry is
        used).
        """
        self._list: list[tuple[bytes, bytes]] = []
        if headers is not None:
            assert raw is None, 'Cannot set both "headers" and "raw".'
            assert scope is None, 'Cannot set both "headers" and "scope".'
            self._list = [(key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in headers.items()]
        elif raw is not None:
            assert scope is None, 'Cannot set both "raw" and "scope".'
            self._list = raw
        elif scope is not None:
            # scope["headers"] isn't necessarily a list
            # it might be a tuple or other iterable
            # The list is also written back so the scope and this object share it.
            self._list = scope["headers"] = list(scope["headers"])

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        """A copy of the underlying list of `(name, value)` byte pairs."""
        return list(self._list)

    def keys(self) -> list[str]:  # type: ignore[override]
        return [key.decode("latin-1") for key, value in self._list]

    def values(self) -> list[str]:  # type: ignore[override]
        return [value.decode("latin-1") for key, value in self._list]

    def items(self) -> list[tuple[str, str]]:  # type: ignore[override]
        return [(key.decode("latin-1"), value.decode("latin-1")) for key, value in self._list]

    def getlist(self, key: str) -> list[str]:
        """Return every value stored for header `key` (matched after lower-casing)."""
        get_header_key = key.lower().encode("latin-1")
        return [item_value.decode("latin-1") for item_key, item_value in self._list if item_key == get_header_key]

    def mutablecopy(self) -> MutableHeaders:
        """Return a `MutableHeaders` working on a copy of the header list."""
        return MutableHeaders(raw=self._list[:])

    def __getitem__(self, key: str) -> str:
        # Returns the first matching header; raises KeyError when there is none.
        get_header_key = key.lower().encode("latin-1")
        for header_key, header_value in self._list:
            if header_key == get_header_key:
                return header_value.decode("latin-1")
        raise KeyError(key)

    def __contains__(self, key: Any) -> bool:
        get_header_key = key.lower().encode("latin-1")
        for header_key, header_value in self._list:
            if header_key == get_header_key:
                return True
        return False

    def __iter__(self) -> Iterator[Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        # Counts every stored pair, so repeated header names count more than once.
        return len(self._list)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Headers):
            return False
        return sorted(self._list) == sorted(other._list)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        as_dict = dict(self.items())
        # The dict form is only used when it loses nothing, i.e. no header name repeats.
        if len(as_dict) == len(self):
            return f"{class_name}({as_dict!r})"
        return f"{class_name}(raw={self.raw!r})"


class MutableHeaders(Headers):
    """
    A `Headers` variant whose entries can be modified in place.
    """

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        Retains insertion order.
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")

        found_indexes: list[int] = []
        for idx, (item_key, item_value) in enumerate(self._list):
            if item_key == set_key:
                found_indexes.append(idx)

        # Delete duplicates from the back so the remaining indexes stay valid.
        for idx in reversed(found_indexes[1:]):
            del self._list[idx]

        if found_indexes:
            idx = found_indexes[0]
            self._list[idx] = (set_key, set_value)
        else:
            self._list.append((set_key, set_value))

    def __delitem__(self, key: str) -> None:
        """
        Remove the header `key`.
        """
        del_key = key.lower().encode("latin-1")

        pop_indexes: list[int] = []
        for idx, (item_key, item_value) in enumerate(self._list):
            if item_key == del_key:
                pop_indexes.append(idx)

        for idx in reversed(pop_indexes):
            del self._list[idx]

    def __ior__(self, other: Mapping[str, str]) -> MutableHeaders:
        if not isinstance(other, Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        self.update(other)
        return self

    def __or__(self, other: Mapping[str, str]) -> MutableHeaders:
        if not isinstance(other, Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        new = self.mutablecopy()
        new.update(other)
        return new

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        """The underlying list itself (not a copy)."""
        return self._list

    def setdefault(self, key: str, value: str) -> str:
        """
        If the header `key` does not exist, then set it to `value`.
        Returns the header value.
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")

        for item_key, item_value in self._list:
            if item_key == set_key:
                return item_value.decode("latin-1")
        self._list.append((set_key, set_value))
        return value

    def update(self, other: Mapping[str, str]) -> None:
        """Set every header from `other`, replacing existing entries with the same name."""
        for key, val in other.items():
            self[key] = val

    def append(self, key: str, value: str) -> None:
        """
        Append a header, preserving any duplicate entries.
        """
        append_key = key.lower().encode("latin-1")
        append_value = value.encode("latin-1")
        self._list.append((append_key, append_value))

    def add_vary_header(self, vary: str) -> None:
        """Add `vary` to the `vary` header, joining it to an existing value with `", "`."""
        existing = self.get("vary")
        if existing is not None:
            vary = ", ".join([existing, vary])
        self["vary"] = vary


class State:
    """
    An object that can be used to store arbitrary state.

    Used for `request.state` and `app.state`.
    """

    _state: dict[str, Any]

    def __init__(self, state: dict[str, Any] | None = None):
        if state is None:
            state = {}
        # Bypass our own `__setattr__`, which would try to store into `_state` itself.
        super().__setattr__("_state", state)

    def __setattr__(self, key: Any, value: Any) -> None:
        self._state[key] = value

    def __getattr__(self, key: Any) -> Any:
        message = "'{}' object has no attribute '{}'"
        if key == "_state":
            # `__getattr__` only runs after normal lookup has failed, so getting
            # here means the instance has no `_state` yet (for example the bare
            # instances that `copy` and `pickle` create before restoring state).
            # Reading `self._state` below would re-enter this method forever, so
            # raise a regular AttributeError instead.
            raise AttributeError(message.format(self.__class__.__name__, key))
        try:
            return self._state[key]
        except KeyError:
            raise AttributeError(message.format(self.__class__.__name__, key))

    def __delattr__(self, key: Any) -> None:
        del self._state[key]

    def __getitem__(self, key: str) -> Any:
        return self._state[key]

    def __setitem__(self, key: str, value: Any) -> None:
        self._state[key] = value

    def __delitem__(self, key: str) -> None:
        del self._state[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._state)

    def __len__(self) -> int:
        return len(self._state)