Initial commit: 首次建仓,建立目录结构

This commit is contained in:
FXY
2026-06-11 23:49:54 +08:00
commit 4038a476b5
9396 changed files with 2372905 additions and 0 deletions

View File

@ -0,0 +1,645 @@
from __future__ import annotations
from concurrent.futures import ThreadPoolExecutor
from decimal import Decimal
import operator
import os
from sys import byteorder
import threading
from typing import (
TYPE_CHECKING,
ContextManager,
)
import numpy as np
from pandas._config import using_string_dtype
from pandas._config.localization import (
can_set_locale,
get_locales,
set_locale,
)
from pandas.compat import HAS_PYARROW
import pandas as pd
from pandas import (
ArrowDtype,
DataFrame,
Index,
MultiIndex,
RangeIndex,
Series,
)
from pandas._testing._io import (
round_trip_pathlib,
round_trip_pickle,
write_to_compressed,
)
from pandas._testing._warnings import (
assert_produces_warning,
maybe_produces_warning,
)
from pandas._testing.asserters import (
assert_almost_equal,
assert_attr_equal,
assert_categorical_equal,
assert_class_equal,
assert_contains_all,
assert_copy,
assert_datetime_array_equal,
assert_dict_equal,
assert_equal,
assert_extension_array_equal,
assert_frame_equal,
assert_index_equal,
assert_indexing_slices_equivalent,
assert_interval_array_equal,
assert_is_sorted,
assert_metadata_equivalent,
assert_numpy_array_equal,
assert_period_array_equal,
assert_series_equal,
assert_sp_array_equal,
assert_timedelta_array_equal,
raise_assert_detail,
)
from pandas._testing.compat import (
get_dtype,
get_obj,
)
from pandas._testing.contexts import (
decompress_file,
raises_chained_assignment_error,
set_timezone,
with_csv_dialect,
)
from pandas.core.arrays import (
ArrowExtensionArray,
BaseMaskedArray,
NumpyExtensionArray,
)
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
from pandas.core.construction import extract_array
if TYPE_CHECKING:
from collections.abc import Callable
from pandas._typing import (
Dtype,
NpDtype,
)
# Dtype specifier lists used to parametrize tests.
# Integer dtypes, split by signedness. The lowercase names are NumPy dtype
# strings; the capitalized "*_EA_*" names are presumably the pandas nullable
# extension dtypes -- confirm against pandas' dtype registry.
UNSIGNED_INT_NUMPY_DTYPES: list[NpDtype] = ["uint8", "uint16", "uint32", "uint64"]
UNSIGNED_INT_EA_DTYPES: list[Dtype] = ["UInt8", "UInt16", "UInt32", "UInt64"]
SIGNED_INT_NUMPY_DTYPES: list[NpDtype] = [int, "int8", "int16", "int32", "int64"]
SIGNED_INT_EA_DTYPES: list[Dtype] = ["Int8", "Int16", "Int32", "Int64"]
ALL_INT_NUMPY_DTYPES = UNSIGNED_INT_NUMPY_DTYPES + SIGNED_INT_NUMPY_DTYPES
ALL_INT_EA_DTYPES = UNSIGNED_INT_EA_DTYPES + SIGNED_INT_EA_DTYPES
ALL_INT_DTYPES: list[Dtype] = [*ALL_INT_NUMPY_DTYPES, *ALL_INT_EA_DTYPES]

# Floating-point and complex dtypes.
FLOAT_NUMPY_DTYPES: list[NpDtype] = [float, "float32", "float64"]
FLOAT_EA_DTYPES: list[Dtype] = ["Float32", "Float64"]
ALL_FLOAT_DTYPES: list[Dtype] = [*FLOAT_NUMPY_DTYPES, *FLOAT_EA_DTYPES]

COMPLEX_DTYPES: list[Dtype] = [complex, "complex64", "complex128"]
# The set of string dtype specifiers depends on whether the string-dtype
# option is active: when it is, only "U" is kept.
if using_string_dtype():
    STRING_DTYPES: list[Dtype] = ["U"]
else:
    STRING_DTYPES: list[Dtype] = [str, "str", "U"]  # type: ignore[no-redef]
COMPLEX_FLOAT_DTYPES: list[Dtype] = [*COMPLEX_DTYPES, *FLOAT_NUMPY_DTYPES]

# Each of the following lists pairs two spellings of the same dtype.
DATETIME64_DTYPES: list[Dtype] = ["datetime64[ns]", "M8[ns]"]
TIMEDELTA64_DTYPES: list[Dtype] = ["timedelta64[ns]", "m8[ns]"]

BOOL_DTYPES: list[Dtype] = [bool, "bool"]
BYTES_DTYPES: list[Dtype] = [bytes, "bytes"]
OBJECT_DTYPES: list[Dtype] = [object, "object"]

# Aggregates built by concatenating the lists above.
ALL_REAL_NUMPY_DTYPES = FLOAT_NUMPY_DTYPES + ALL_INT_NUMPY_DTYPES
ALL_REAL_EXTENSION_DTYPES = FLOAT_EA_DTYPES + ALL_INT_EA_DTYPES
ALL_REAL_DTYPES: list[Dtype] = [*ALL_REAL_NUMPY_DTYPES, *ALL_REAL_EXTENSION_DTYPES]
ALL_NUMERIC_DTYPES: list[Dtype] = [*ALL_REAL_DTYPES, *COMPLEX_DTYPES]

ALL_NUMPY_DTYPES = (
    ALL_REAL_NUMPY_DTYPES
    + COMPLEX_DTYPES
    + STRING_DTYPES
    + DATETIME64_DTYPES
    + TIMEDELTA64_DTYPES
    + BOOL_DTYPES
    + OBJECT_DTYPES
    + BYTES_DTYPES
)
# NumPy scalar types of at most 32 bits (float16/32, int8-32, uint8-32).
NARROW_NP_DTYPES = [
    np.float16,
    np.float32,
    np.int8,
    np.int16,
    np.int32,
    np.uint8,
    np.uint16,
    np.uint32,
]

# Built-in Python types.
PYTHON_DATA_TYPES = [
    str,
    int,
    float,
    complex,
    list,
    tuple,
    range,
    dict,
    set,
    frozenset,
    bool,
    bytes,
    bytearray,
    memoryview,
]

# Byte-order prefix for the running interpreter: "<" when sys.byteorder is
# "little", ">" when it is "big".
ENDIAN = {"little": "<", "big": ">"}[byteorder]

# Assorted objects that represent a missing value.
NULL_OBJECTS = [None, np.nan, pd.NaT, float("nan"), pd.NA, Decimal("NaN")]
# One "NaT" scalar for every combination of np.datetime64 / np.timedelta64
# and the 13 unit codes listed below (26 objects in total).
NP_NAT_OBJECTS = [
    cls("NaT", unit)
    for cls in [np.datetime64, np.timedelta64]
    for unit in [
        "Y",
        "M",
        "W",
        "D",
        "h",
        "m",
        "s",
        "ms",
        "us",
        "ns",
        "ps",
        "fs",
        "as",
    ]
]
# pyarrow-based dtype lists. They are only populated when HAS_PYARROW is
# true; pyarrow is imported lazily inside that branch.
if HAS_PYARROW:
    import pyarrow as pa

    UNSIGNED_INT_PYARROW_DTYPES = [pa.uint8(), pa.uint16(), pa.uint32(), pa.uint64()]
    SIGNED_INT_PYARROW_DTYPES = [pa.int8(), pa.int16(), pa.int32(), pa.int64()]
    ALL_INT_PYARROW_DTYPES = UNSIGNED_INT_PYARROW_DTYPES + SIGNED_INT_PYARROW_DTYPES
    # String form of the ArrowDtype wrapping each pyarrow integer type.
    ALL_INT_PYARROW_DTYPES_STR_REPR = [
        str(ArrowDtype(typ)) for typ in ALL_INT_PYARROW_DTYPES
    ]

    # pa.float16 doesn't seem supported
    # https://github.com/apache/arrow/blob/master/python/pyarrow/src/arrow/python/helpers.cc#L86
    FLOAT_PYARROW_DTYPES = [pa.float32(), pa.float64()]
    FLOAT_PYARROW_DTYPES_STR_REPR = [
        str(ArrowDtype(typ)) for typ in FLOAT_PYARROW_DTYPES
    ]
    DECIMAL_PYARROW_DTYPES = [pa.decimal128(7, 3)]
    STRING_PYARROW_DTYPES = [pa.string()]
    BINARY_PYARROW_DTYPES = [pa.binary()]

    TIME_PYARROW_DTYPES = [
        pa.time32("s"),
        pa.time32("ms"),
        pa.time64("us"),
        pa.time64("ns"),
    ]
    DATE_PYARROW_DTYPES = [pa.date32(), pa.date64()]
    # Every unit combined with every timezone choice (16 timestamp types).
    DATETIME_PYARROW_DTYPES = [
        pa.timestamp(unit=unit, tz=tz)
        for unit in ["s", "ms", "us", "ns"]
        for tz in [None, "UTC", "US/Pacific", "US/Eastern"]
    ]
    TIMEDELTA_PYARROW_DTYPES = [pa.duration(unit) for unit in ["s", "ms", "us", "ns"]]

    BOOL_PYARROW_DTYPES = [pa.bool_()]

    # TODO: Add container like pyarrow types:
    #  https://arrow.apache.org/docs/python/api/datatypes.html#factory-functions
    ALL_PYARROW_DTYPES = (
        ALL_INT_PYARROW_DTYPES
        + FLOAT_PYARROW_DTYPES
        + DECIMAL_PYARROW_DTYPES
        + STRING_PYARROW_DTYPES
        + BINARY_PYARROW_DTYPES
        + TIME_PYARROW_DTYPES
        + DATE_PYARROW_DTYPES
        + DATETIME_PYARROW_DTYPES
        + TIMEDELTA_PYARROW_DTYPES
        + BOOL_PYARROW_DTYPES
    )
    ALL_REAL_PYARROW_DTYPES_STR_REPR = (
        ALL_INT_PYARROW_DTYPES_STR_REPR + FLOAT_PYARROW_DTYPES_STR_REPR
    )
else:
    # Without pyarrow only these four names are defined (as empty lists);
    # the other *_PYARROW_DTYPES names above do not exist in that case.
    FLOAT_PYARROW_DTYPES_STR_REPR = []
    ALL_INT_PYARROW_DTYPES_STR_REPR = []
    ALL_PYARROW_DTYPES = []
    ALL_REAL_PYARROW_DTYPES_STR_REPR = []

# NumPy floats, the real extension dtypes and (when available) the string
# representations of the real pyarrow dtypes.
ALL_REAL_NULLABLE_DTYPES = (
    FLOAT_NUMPY_DTYPES + ALL_REAL_EXTENSION_DTYPES + ALL_REAL_PYARROW_DTYPES_STR_REPR
)
# Names of the binary arithmetic dunder methods. Each operator is listed as
# the forward method immediately followed by its reflected ("r") variant.
arithmetic_dunder_methods = [
    f"__{prefix}{op}__"
    for op in ("add", "sub", "mul", "floordiv", "truediv", "pow", "mod")
    for prefix in ("", "r")
]

# Names of the six rich-comparison dunder methods.
comparison_dunder_methods = [
    f"__{op}__" for op in ("eq", "ne", "le", "lt", "ge", "gt")
]
# -----------------------------------------------------------------------------
# Comparators
def box_expected(expected, box_cls, transpose: bool = True):
"""
Helper function to wrap the expected output of a test in a given box_class.
Parameters
----------
expected : np.ndarray, Index, Series
box_cls : {Index, Series, DataFrame}
Returns
-------
subclass of box_cls
"""
if box_cls is pd.array:
if isinstance(expected, RangeIndex):
# pd.array would return an IntegerArray
expected = NumpyExtensionArray(np.asarray(expected._values))
else:
expected = pd.array(expected, copy=False)
elif box_cls is Index:
expected = Index(expected, copy=False)
elif box_cls is Series:
expected = Series(expected)
elif box_cls is DataFrame:
expected = Series(expected).to_frame()
if transpose:
# for vector operations, we need a DataFrame to be a single-row,
# not a single-column, in order to operate against non-DataFrame
# vectors of the same length. But convert to two rows to avoid
# single-row special cases in datetime arithmetic
expected = expected.T
expected = pd.concat([expected] * 2, ignore_index=True)
elif box_cls is np.ndarray or box_cls is np.array:
expected = np.array(expected)
elif box_cls is to_array:
expected = to_array(expected)
else:
raise NotImplementedError(box_cls)
return expected
def to_array(obj):
    """
    Convert ``obj`` to an array without casting numpy dtypes to nullable ones.

    Objects that expose no ``dtype`` attribute (or whose ``dtype`` is None)
    go through ``np.asarray``; everything else is unwrapped with
    ``extract_array(..., extract_numpy=True)``.
    """
    # temporary implementation until we get pd.array in place
    if getattr(obj, "dtype", None) is None:
        return np.asarray(obj)
    return extract_array(obj, extract_numpy=True)
class SubclassedSeries(Series):
    """Series subclass used to test that subclasses are preserved."""

    _metadata = ["testattr", "name"]

    @property
    def _constructor(self):
        # For testing, these properties return a generic callable rather
        # than the class itself. Here the two are equivalent, but this makes
        # sure nothing relies on the property returning a class.
        # See https://github.com/pandas-dev/pandas/pull/46018 and
        # https://github.com/pandas-dev/pandas/issues/32638 and linked issues
        def make_series(*args, **kwargs):
            return SubclassedSeries(*args, **kwargs)

        return make_series

    @property
    def _constructor_expanddim(self):
        def make_frame(*args, **kwargs):
            return SubclassedDataFrame(*args, **kwargs)

        return make_frame
class SubclassedDataFrame(DataFrame):
_metadata = ["testattr"]
@property
def _constructor(self):
return lambda *args, **kwargs: SubclassedDataFrame(*args, **kwargs)
# error: Cannot override writeable attribute with read-only property
@property
def _constructor_sliced(self): # type: ignore[override]
return lambda *args, **kwargs: SubclassedSeries(*args, **kwargs)
def convert_rows_list_to_csv_str(rows_list: list[str]) -> str:
    """
    Join CSV rows into one string using the current OS line separator.

    Used to build the expected value of ``to_csv()``.

    Parameters
    ----------
    rows_list : list[str]
        Each element is one row of the CSV, without a line terminator.

    Returns
    -------
    str
        The rows separated by ``os.linesep``, with one trailing
        ``os.linesep`` (an empty list therefore yields just the separator).
    """
    return f"{os.linesep.join(rows_list)}{os.linesep}"
def external_error_raised(expected_exception: type[Exception]) -> ContextManager:
    """
    Mark a ``pytest.raises`` whose error message comes from external code.

    Parameters
    ----------
    expected_exception : type[Exception]
        The exception class expected to be raised.

    Returns
    -------
    ContextManager
        The result of ``pytest.raises(expected_exception, match=None)``,
        i.e. the message is intentionally not checked.
    """
    # Imported lazily so importing this module does not require pytest.
    import pytest

    return pytest.raises(expected_exception, match=None)
def get_cython_table_params(ndframe, func_names_and_expected):
    """
    Pair an object with each (function name, expected result) entry.

    Parameters
    ----------
    ndframe : DataFrame or Series
    func_names_and_expected : Sequence of two items
        The first item is the name of an NDFrame method ('sum', 'prod') etc.
        The second item is the expected return value.

    Returns
    -------
    list
        One ``(ndframe, func_name, expected)`` tuple per input pair, in the
        same order as ``func_names_and_expected``.
    """
    return [
        (ndframe, func_name, expected)
        for func_name, expected in func_names_and_expected
    ]
def get_op_from_name(op_name: str) -> Callable:
    """
    Look up the ``operator`` function for an op name.

    Parameters
    ----------
    op_name : str
        The op name, in form of "add" or "__add__". Names that do not exist
        in ``operator`` are treated as reflected ops ("radd", "__rsub__").

    Returns
    -------
    function
        A two-argument function performing the operation.

    Raises
    ------
    AttributeError
        If neither the name nor the name without its first character is an
        attribute of the ``operator`` module.
    """
    bare_name = op_name.strip("_")
    try:
        return getattr(operator, bare_name)
    except AttributeError:
        # Assume it is the reverse operator: drop the leading "r" and swap
        # the operands.
        forward = getattr(operator, bare_name[1:])

        def reflected(x, y):
            return forward(y, x)

        return reflected
# -----------------------------------------------------------------------------
# Indexing test helpers
def getitem(x):
    """Return ``x`` unchanged, so that indexing goes through ``x[...]``."""
    return x


def setitem(x):
    """Return ``x`` unchanged, so that assignment goes through ``x[...] = v``."""
    return x


def loc(x):
    """Return the ``.loc`` indexer of ``x``."""
    return x.loc


def iloc(x):
    """Return the ``.iloc`` indexer of ``x``."""
    return x.iloc


def at(x):
    """Return the ``.at`` indexer of ``x``."""
    return x.at


def iat(x):
    """Return the ``.iat`` indexer of ``x``."""
    return x.iat
# -----------------------------------------------------------------------------
# Supported units, ordered from coarsest to finest.
_UNITS = ["s", "ms", "us", "ns"]


def get_finest_unit(left: str, right: str) -> str:
    """
    Return whichever of two units comes later in ``_UNITS`` (the finer one).

    When both rank the same, ``left`` is returned. A unit that is not in
    ``_UNITS`` raises ValueError.
    """
    return max(left, right, key=_UNITS.index)
def shares_memory(left, right) -> bool:
    """
    Pandas-compat for np.shares_memory.

    ``left`` is unwrapped one layer at a time (Index/Series to its values,
    extension arrays to their backing arrays, a single-block DataFrame to
    that block) until the comparison can be made.

    Raises
    ------
    NotImplementedError
        If no rule applies to ``type(left)``.
    """
    if isinstance(left, np.ndarray):
        if isinstance(right, np.ndarray):
            return np.shares_memory(left, right)
        # Call with reversed args to get to the unpacking logic below.
        return shares_memory(right, left)

    if isinstance(left, RangeIndex):
        return False

    if isinstance(left, MultiIndex):
        return shares_memory(left._codes, right)

    if isinstance(left, (Index, Series)):
        right_values = right._values if isinstance(right, (Index, Series)) else right
        return shares_memory(left._values, right_values)

    if isinstance(left, NDArrayBackedExtensionArray):
        return shares_memory(left._ndarray, right)

    if isinstance(left, pd.core.arrays.SparseArray):
        return shares_memory(left.sp_values, right)

    if isinstance(left, pd.core.arrays.IntervalArray):
        return shares_memory(left._left, right) or shares_memory(left._right, right)

    if isinstance(left, ArrowExtensionArray):
        if not isinstance(right, ArrowExtensionArray):
            # With one ArrowExtensionArray and one other array, assume they
            # can only share memory if they share the same numpy buffer.
            return np.shares_memory(left, right)
        # https://github.com/pandas-dev/pandas/pull/43930#discussion_r736862669
        # Compare the address of buffer 1 of the first chunk on each side.
        left_buffer = left._pa_array.chunk(0).buffers()[1]
        right_buffer = right._pa_array.chunk(0).buffers()[1]
        return left_buffer.address == right_buffer.address

    if isinstance(left, BaseMaskedArray) and isinstance(right, BaseMaskedArray):
        # By convention, we'll say these share memory if they share *either*
        # the _data or the _mask
        return np.shares_memory(left._data, right._data) or np.shares_memory(
            left._mask, right._mask
        )

    if isinstance(left, DataFrame) and len(left._mgr.blocks) == 1:
        only_block_values = left._mgr.blocks[0].values
        return shares_memory(only_block_values, right)

    raise NotImplementedError(type(left), type(right))
def run_multithreaded(closure, max_workers, arguments=None, pass_barrier=False):
    """
    Run ``closure`` once in each of ``max_workers`` threads.

    Parameters
    ----------
    closure : callable
        The function submitted to every worker.
    max_workers : int
        Size of the thread pool and number of submissions.
    arguments : iterable, optional
        Positional arguments passed to every call. It is copied into a new
        list, so the caller's object is not modified.
    pass_barrier : bool, default False
        If True, a ``threading.Barrier(max_workers)`` is appended as the
        last positional argument of every call.

    Notes
    -----
    If submitting the tasks raises RuntimeError the current test is skipped
    via ``pytest.skip``. Exceptions raised by ``closure`` propagate when the
    results are collected.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        call_args = [] if arguments is None else list(arguments)

        if pass_barrier:
            barrier = threading.Barrier(max_workers)
            call_args.append(barrier)

        futures = []
        try:
            # A plain loop (not a comprehension) so that ``futures`` keeps
            # the tasks already submitted if a later submit() fails.
            for _ in range(max_workers):
                futures.append(pool.submit(closure, *call_args))  # noqa: PERF401
        except RuntimeError as e:
            import pytest

            pytest.skip(
                f"Spawning {max_workers} threads failed with "
                f"error {e!r} (likely due to resource limits on the "
                "system running the tests)"
            )
        finally:
            # Not every worker was started: break the barrier so the ones
            # that were started do not wait for the missing parties.
            if pass_barrier and len(futures) < max_workers:
                barrier.abort()

        for future in futures:
            future.result()
# Names exported by a star-import of this module: the dtype constant lists
# first, then the two Subclassed* classes, then the helper and assertion
# functions. Keep the entries sorted within each group.
__all__ = [
    "ALL_INT_EA_DTYPES",
    "ALL_INT_NUMPY_DTYPES",
    "ALL_NUMPY_DTYPES",
    "ALL_REAL_NUMPY_DTYPES",
    "BOOL_DTYPES",
    "BYTES_DTYPES",
    "COMPLEX_DTYPES",
    "DATETIME64_DTYPES",
    "ENDIAN",
    "FLOAT_EA_DTYPES",
    "FLOAT_NUMPY_DTYPES",
    "NARROW_NP_DTYPES",
    "NP_NAT_OBJECTS",
    "NULL_OBJECTS",
    "OBJECT_DTYPES",
    "SIGNED_INT_EA_DTYPES",
    "SIGNED_INT_NUMPY_DTYPES",
    "STRING_DTYPES",
    "TIMEDELTA64_DTYPES",
    "UNSIGNED_INT_EA_DTYPES",
    "UNSIGNED_INT_NUMPY_DTYPES",
    "SubclassedDataFrame",
    "SubclassedSeries",
    "assert_almost_equal",
    "assert_attr_equal",
    "assert_categorical_equal",
    "assert_class_equal",
    "assert_contains_all",
    "assert_copy",
    "assert_datetime_array_equal",
    "assert_dict_equal",
    "assert_equal",
    "assert_extension_array_equal",
    "assert_frame_equal",
    "assert_index_equal",
    "assert_indexing_slices_equivalent",
    "assert_interval_array_equal",
    "assert_is_sorted",
    "assert_metadata_equivalent",
    "assert_numpy_array_equal",
    "assert_period_array_equal",
    "assert_produces_warning",
    "assert_series_equal",
    "assert_sp_array_equal",
    "assert_timedelta_array_equal",
    "at",
    "box_expected",
    "can_set_locale",
    "convert_rows_list_to_csv_str",
    "decompress_file",
    "external_error_raised",
    "get_cython_table_params",
    "get_dtype",
    "get_finest_unit",
    "get_locales",
    "get_obj",
    "get_op_from_name",
    "getitem",
    "iat",
    "iloc",
    "loc",
    "maybe_produces_warning",
    "raise_assert_detail",
    "raises_chained_assignment_error",
    "round_trip_pathlib",
    "round_trip_pickle",
    "run_multithreaded",
    "set_locale",
    "set_timezone",
    "setitem",
    "shares_memory",
    "to_array",
    "with_csv_dialect",
    "write_to_compressed",
]

View File

@ -0,0 +1,89 @@
"""
Hypothesis data generator helpers.
"""
from datetime import datetime
from hypothesis import strategies as st
from hypothesis.extra.dateutil import timezones as dateutil_timezones
from pandas.compat import is_platform_windows
import pandas as pd
from pandas.tseries.offsets import (
BMonthBegin,
BMonthEnd,
BQuarterBegin,
BQuarterEnd,
BYearBegin,
BYearEnd,
MonthBegin,
MonthEnd,
QuarterBegin,
QuarterEnd,
YearBegin,
YearEnd,
)
# Strategies producing lists of 3 to 10 elements in which every element is
# either None or a value of the named kind.
OPTIONAL_INTS = st.lists(st.one_of(st.integers(), st.none()), max_size=10, min_size=3)

OPTIONAL_FLOATS = st.lists(st.one_of(st.floats(), st.none()), max_size=10, min_size=3)

OPTIONAL_TEXT = st.lists(st.one_of(st.none(), st.text()), max_size=10, min_size=3)

# Elements are None or a dict mapping text to integers.
OPTIONAL_DICTS = st.lists(
    st.one_of(st.none(), st.dictionaries(st.text(), st.integers())),
    max_size=10,
    min_size=3,
)

# Elements are None or an inner list of 3 to 10 text values.
OPTIONAL_LISTS = st.lists(
    st.one_of(st.none(), st.lists(st.text(), max_size=10, min_size=3)),
    max_size=10,
    min_size=3,
)

# Draws a whole list from any one of the five strategies above.
OPTIONAL_ONE_OF_ALL = st.one_of(
    OPTIONAL_DICTS, OPTIONAL_FLOATS, OPTIONAL_INTS, OPTIONAL_LISTS, OPTIONAL_TEXT
)

# Timezone-naive datetimes. On Windows the range is bounded below at
# 1900-01-01; elsewhere the strategy's default range is used.
if is_platform_windows():
    DATETIME_NO_TZ = st.datetimes(min_value=datetime(1900, 1, 1))
else:
    DATETIME_NO_TZ = st.datetimes()

# min_value and max_value are identical, so the datetime is always
# 1900-01-01 00:00; only the timezone varies (None, a dateutil timezone or
# one drawn from st.timezones()).
DATETIME_JAN_1_1900_OPTIONAL_TZ = st.datetimes(
    min_value=pd.Timestamp(1900, 1, 1).to_pydatetime(),  # pyright: ignore[reportArgumentType]
    max_value=pd.Timestamp(1900, 1, 1).to_pydatetime(),  # pyright: ignore[reportArgumentType]
    timezones=st.one_of(st.none(), dateutil_timezones(), st.timezones()),
)

# Timezone-naive datetimes bounded by pd.Timestamp.min and pd.Timestamp.max.
DATETIME_IN_PD_TIMESTAMP_RANGE_NO_TZ = st.datetimes(
    min_value=pd.Timestamp.min.to_pydatetime(warn=False),
    max_value=pd.Timestamp.max.to_pydatetime(warn=False),
)

# Integers in the closed range [-999, 999].
INT_NEG_999_TO_POS_999 = st.integers(-999, 999)

# The strategy for each type is registered in conftest.py, as they don't carry
# enough runtime information (e.g. type hints) to infer how to build them.
# Draws an instance of any one of the month/quarter/year (and business
# variants) begin/end offset classes listed below.
YQM_OFFSET = st.one_of(
    *map(
        st.from_type,
        [
            MonthBegin,
            MonthEnd,
            BMonthBegin,
            BMonthEnd,
            QuarterBegin,
            QuarterEnd,
            BQuarterBegin,
            BQuarterEnd,
            YearBegin,
            YearEnd,
            BYearBegin,
            BYearEnd,
        ],
    )
)

View File

@ -0,0 +1,129 @@
from __future__ import annotations
import gzip
import io
import tarfile
from typing import (
TYPE_CHECKING,
Any,
)
import zipfile
from pandas.compat._optional import import_optional_dependency
import pandas as pd
if TYPE_CHECKING:
from collections.abc import Callable
from pathlib import Path
from pandas import (
DataFrame,
Series,
)
# ------------------------------------------------------------------
# File-IO
def round_trip_pickle(obj: Any, tmp_path: Path) -> DataFrame | Series:
"""
Pickle an object and then read it again.
Parameters
----------
obj : any object
The object to pickle and then re-read.
path : str, path object or file-like object, default None
The path where the pickled object is written and then read.
Returns
-------
pandas object
The original object that was pickled and then re-read.
"""
pd.to_pickle(obj, tmp_path)
return pd.read_pickle(tmp_path)
def round_trip_pathlib(writer, reader, tmp_path: Path):
    """
    Write an object to the file at ``tmp_path`` and read it back.

    Parameters
    ----------
    writer : callable bound to pandas object
        IO writing function (e.g. DataFrame.to_csv ); called with ``tmp_path``.
    reader : callable
        IO reading function (e.g. pd.read_csv ); called with ``tmp_path``.
    tmp_path : pathlib.Path
        The path the object is written to and then read from.

    Returns
    -------
    pandas object
        Whatever ``reader`` returns after ``writer`` has run.
    """
    writer(tmp_path)
    return reader(tmp_path)
def write_to_compressed(compression, path: str, data, dest: str = "test") -> None:
    """
    Write data to a compressed file.

    Parameters
    ----------
    compression : {'gzip', 'bz2', 'zip', 'tar', 'xz', 'zstd'}
        The compression type to use.
    path : str
        The file path to write the data.
    data : bytes
        The data to write. The gzip/bz2/xz/zstd/tar writers are opened in
        binary mode and need bytes.
    dest : str, default "test"
        The name of the archive member (for ZIP and TAR only).

    Raises
    ------
    ValueError : An invalid compression value was passed in.
    """
    # Defaults for the stream compressors: open in binary write mode and
    # call ``f.write(data)``. The archive formats override all three.
    args: tuple[Any, ...] = (data,)
    mode = "wb"
    method = "write"
    compress_method: Callable

    if compression == "zip":
        compress_method = zipfile.ZipFile
        mode = "w"
        args = (dest, data)
        method = "writestr"
    elif compression == "tar":
        compress_method = tarfile.TarFile
        mode = "w"
        # addfile() needs the member size set up front plus a file object
        # to read the payload from.
        tar_info = tarfile.TarInfo(name=dest)
        tar_info.size = len(data)
        args = (tar_info, io.BytesIO(data))
        method = "addfile"
    elif compression == "gzip":
        compress_method = gzip.GzipFile
    elif compression == "bz2":
        import bz2

        compress_method = bz2.BZ2File
    elif compression == "zstd":
        compress_method = import_optional_dependency("zstandard").open
    elif compression == "xz":
        import lzma

        compress_method = lzma.LZMAFile
    else:
        raise ValueError(f"Unrecognized compression type: {compression}")

    # error: No overload variant of "ZipFile" matches argument types "str", "str"
    # error: No overload variant of "BZ2File" matches argument types "str", "str"
    # error: Argument "mode" to "TarFile" has incompatible type "str";
    # expected "Literal['r', 'a', 'w', 'x']
    with compress_method(path, mode=mode) as f:  # type: ignore[call-overload, arg-type]
        getattr(f, method)(*args)

View File

@ -0,0 +1,266 @@
from __future__ import annotations
from contextlib import (
AbstractContextManager,
contextmanager,
nullcontext,
)
import inspect
import re
import sys
from typing import (
TYPE_CHECKING,
Literal,
Union,
cast,
)
import warnings
if TYPE_CHECKING:
from collections.abc import (
Generator,
Sequence,
)
@contextmanager
def assert_produces_warning(
expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None = Warning,
filter_level: Literal[
"error", "ignore", "always", "default", "module", "once"
] = "always",
check_stacklevel: bool = True,
raise_on_extra_warnings: bool = True,
match: str | tuple[str | None, ...] | None = None,
must_find_all_warnings: bool = True,
) -> Generator[list[warnings.WarningMessage]]:
"""
Context manager for running code expected to either raise a specific warning,
multiple specific warnings, or not raise any warnings. Verifies that the code
raises the expected warning(s), and that it does not raise any other unexpected
warnings. It is basically a wrapper around ``warnings.catch_warnings``.
Parameters
----------
expected_warning : {Warning, False, tuple[Warning, ...], None}, default Warning
The type of Exception raised. ``exception.Warning`` is the base
class for all warnings. To raise multiple types of exceptions,
pass them as a tuple. To check that no warning is returned,
specify ``False`` or ``None``.
filter_level : str or None, default "always"
Specifies whether warnings are ignored, displayed, or turned
into errors.
Valid values are:
* "error" - turns matching warnings into exceptions
* "ignore" - discard the warning
* "always" - always emit a warning
* "default" - print the warning the first time it is generated
from each location
* "module" - print the warning the first time it is generated
from each module
* "once" - print the warning the first time it is generated
check_stacklevel : bool, default True
If True, displays the line that called the function containing
the warning to show were the function is called. Otherwise, the
line that implements the function is displayed.
raise_on_extra_warnings : bool, default True
Whether extra warnings not of the type `expected_warning` should
cause the test to fail.
match : {str, tuple[str, ...]}, optional
Match warning message. If it's a tuple, it has to be the size of
`expected_warning`. If additionally `must_find_all_warnings` is
True, each expected warning's message gets matched with a respective
match. Otherwise, multiple values get treated as an alternative.
must_find_all_warnings : bool, default True
If True and `expected_warning` is a tuple, each expected warning
type must get encountered. Otherwise, even one expected warning
results in success.
Examples
--------
>>> import warnings
>>> with assert_produces_warning():
... warnings.warn(UserWarning())
>>> with assert_produces_warning(False):
... warnings.warn(RuntimeWarning())
Traceback (most recent call last):
...
AssertionError: Caused unexpected warning(s): ['RuntimeWarning'].
>>> with assert_produces_warning(UserWarning):
... warnings.warn(RuntimeWarning())
Traceback (most recent call last):
...
AssertionError: Did not see expected warning of class 'UserWarning'.
..warn:: This is *not* thread-safe.
"""
__tracebackhide__ = True
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter(filter_level)
try:
yield w
finally:
if expected_warning:
if isinstance(expected_warning, tuple) and must_find_all_warnings:
match = (
match
if isinstance(match, tuple)
else (match,) * len(expected_warning)
)
for warning_type, warning_match in zip(
expected_warning, match, strict=True
):
_assert_caught_expected_warnings(
caught_warnings=w,
expected_warning=warning_type,
match=warning_match,
check_stacklevel=check_stacklevel,
)
else:
expected_warning = cast(
Union[type[Warning], tuple[type[Warning], ...]],
expected_warning,
)
match = (
"|".join(m for m in match if m)
if isinstance(match, tuple)
else match
)
_assert_caught_expected_warnings(
caught_warnings=w,
expected_warning=expected_warning,
match=match,
check_stacklevel=check_stacklevel,
)
if raise_on_extra_warnings:
_assert_caught_no_extra_warnings(
caught_warnings=w,
expected_warning=expected_warning,
)
def maybe_produces_warning(
    warning: type[Warning], condition: bool, **kwargs
) -> AbstractContextManager:
    """
    Return a context manager that checks for ``warning`` only if ``condition``.

    When ``condition`` is true the result is
    ``assert_produces_warning(warning, **kwargs)``; otherwise it is a
    ``nullcontext()`` and ``kwargs`` are ignored.
    """
    return assert_produces_warning(warning, **kwargs) if condition else nullcontext()
def _assert_caught_expected_warnings(
*,
caught_warnings: Sequence[warnings.WarningMessage],
expected_warning: type[Warning] | tuple[type[Warning], ...],
match: str | None,
check_stacklevel: bool,
) -> None:
"""Assert that there was the expected warning among the caught warnings."""
saw_warning = False
matched_message = False
unmatched_messages = []
warning_name = (
tuple(x.__name__ for x in expected_warning)
if isinstance(expected_warning, tuple)
else expected_warning.__name__
)
for actual_warning in caught_warnings:
if issubclass(actual_warning.category, expected_warning):
saw_warning = True
if check_stacklevel:
_assert_raised_with_correct_stacklevel(actual_warning)
if match is not None:
if re.search(match, str(actual_warning.message)):
matched_message = True
else:
unmatched_messages.append(actual_warning.message)
if not saw_warning:
raise AssertionError(f"Did not see expected warning of class {warning_name!r}")
if match and not matched_message:
raise AssertionError(
f"Did not see warning {warning_name!r} "
f"matching '{match}'. The emitted warning messages are "
f"{unmatched_messages}"
)
def _assert_caught_no_extra_warnings(
*,
caught_warnings: Sequence[warnings.WarningMessage],
expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None,
) -> None:
"""Assert that no extra warnings apart from the expected ones are caught."""
extra_warnings = []
for actual_warning in caught_warnings:
if _is_unexpected_warning(actual_warning, expected_warning):
# GH#38630 pytest.filterwarnings does not suppress these.
if actual_warning.category == ResourceWarning:
# GH 44732: Don't make the CI flaky by filtering SSL-related
# ResourceWarning from dependencies
if "unclosed <ssl.SSLSocket" in str(actual_warning.message):
continue
# GH 44844: Matplotlib leaves font files open during the entire process
# upon import. Don't make CI flaky if ResourceWarning raised
# due to these open files.
if any("matplotlib" in mod for mod in sys.modules):
continue
if actual_warning.category == EncodingWarning:
# EncodingWarnings are checked in the CI
# pyproject.toml errors on EncodingWarnings in pandas
# Ignore EncodingWarnings from other libraries
continue
extra_warnings.append(
(
actual_warning.category.__name__,
actual_warning.message,
actual_warning.filename,
actual_warning.lineno,
)
)
if extra_warnings:
raise AssertionError(f"Caused unexpected warning(s): {extra_warnings!r}")
def _is_unexpected_warning(
actual_warning: warnings.WarningMessage,
expected_warning: type[Warning] | bool | tuple[type[Warning], ...] | None,
) -> bool:
"""Check if the actual warning issued is unexpected."""
if actual_warning and not expected_warning:
return True
expected_warning = cast(type[Warning], expected_warning)
return bool(not issubclass(actual_warning.category, expected_warning))
def _assert_raised_with_correct_stacklevel(
    actual_warning: warnings.WarningMessage,
) -> None:
    """
    Assert that a warning is attributed to the file four frames up the stack.

    Parameters
    ----------
    actual_warning : warnings.WarningMessage
        The recorded warning whose ``filename`` is checked.

    Raises
    ------
    AssertionError
        If ``actual_warning.filename`` differs from the file of the frame
        found four levels above this function.
    """
    # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
    frame = inspect.currentframe()
    # Walk a fixed four frames up from this function. NOTE(review): this
    # presumably lands on the code running inside the
    # ``with assert_produces_warning(...)`` block, which only holds while the
    # chain of calls leading here keeps its current depth -- confirm before
    # adding or removing any intermediate call.
    for _ in range(4):
        frame = frame.f_back  # type: ignore[union-attr]
    try:
        caller_filename = inspect.getfile(frame)  # type: ignore[arg-type]
    finally:
        # See note in
        # https://docs.python.org/3/library/inspect.html#inspect.Traceback
        del frame
    msg = (
        "Warning not set with correct stacklevel. "
        f"File where warning is raised: {actual_warning.filename} != "
        f"{caller_filename}. Warning message: {actual_warning.message}"
    )
    assert actual_warning.filename == caller_filename, msg

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,30 @@
"""
Helpers for sharing tests between DataFrame/Series
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from pandas import DataFrame
if TYPE_CHECKING:
from pandas._typing import DtypeObj
def get_dtype(obj) -> DtypeObj:
    """
    Return the dtype of a Series-like object or of a DataFrame's first column.
    """
    if not isinstance(obj, DataFrame):
        return obj.dtype
    # Note: we are assuming only one column
    return obj.dtypes.iat[0]
def get_obj(df: DataFrame, klass):
    """
    For sharing tests using frame_or_series, either return the DataFrame
    unchanged or return its first column as a Series.

    The DataFrame itself is returned only when ``klass`` is ``DataFrame``;
    any other ``klass`` yields the first column.
    """
    return df if klass is DataFrame else df._ixs(0, axis=1)

View File

@ -0,0 +1,151 @@
from __future__ import annotations
from contextlib import contextmanager
import os
import sys
from typing import (
IO,
TYPE_CHECKING,
)
from pandas.compat import CHAINED_WARNING_DISABLED
from pandas.errors import ChainedAssignmentError
from pandas.io.common import get_handle
if TYPE_CHECKING:
from collections.abc import Generator
from pandas._typing import (
BaseBuffer,
CompressionOptions,
FilePath,
)
@contextmanager
def decompress_file(
path: FilePath | BaseBuffer, compression: CompressionOptions
) -> Generator[IO[bytes]]:
"""
Open a compressed file and return a file object.
Parameters
----------
path : str
The path where the file is read from.
compression : {'gzip', 'bz2', 'zip', 'xz', 'zstd', None}
Name of the decompression to use
Returns
-------
file object
"""
with get_handle(path, "rb", compression=compression, is_text=False) as handle:
yield handle.handle
@contextmanager
def set_timezone(tz: str) -> Generator[None]:
    """
    Context manager for temporarily setting a timezone.

    The ``TZ`` environment variable is set to ``tz`` on entry and restored
    to its previous value (or removed, if it was not set) on exit. After
    each change ``time.tzset()`` is called so the process actually picks up
    the new value. On platforms without ``time.tzset`` this is a no-op.

    Parameters
    ----------
    tz : str
        A string representing a valid timezone.

    Examples
    --------
    >>> from datetime import datetime
    >>> from dateutil.tz import tzlocal
    >>> tzlocal().tzname(datetime(2021, 1, 1))  # doctest: +SKIP
    'IST'

    >>> with set_timezone("US/Eastern"):
    ...     tzlocal().tzname(datetime(2021, 1, 1))
    'EST'
    """
    import time

    def _apply(zone) -> None:
        # Point the process at ``zone``; None means "TZ not set".
        if not hasattr(time, "tzset"):
            return
        if zone is None:
            os.environ.pop("TZ", None)
        else:
            os.environ["TZ"] = zone
        # Re-read TZ in both cases: removing the variable without calling
        # tzset() would leave the temporary zone active after the block.
        # The platform check allows typing checks to pass on Windows.
        if sys.platform != "win32":
            time.tzset()

    orig_tz = os.environ.get("TZ")
    _apply(tz)
    try:
        yield
    finally:
        _apply(orig_tz)
@contextmanager
def with_csv_dialect(name: str, **kwargs) -> Generator[None]:
    """
    Context manager to temporarily register a CSV dialect for parsing CSV.

    The dialect is registered on entry and unregistered on exit, including
    when the body raises.

    Parameters
    ----------
    name : str
        The name of the dialect.
    kwargs : mapping
        The parameters for the dialect.

    Raises
    ------
    ValueError : the name of the dialect conflicts with a builtin one.

    See Also
    --------
    csv : Python's CSV library.
    """
    import csv

    if name in {"excel", "excel-tab", "unix"}:
        raise ValueError("Cannot override builtin dialect.")

    csv.register_dialect(name, **kwargs)
    try:
        yield
    finally:
        csv.unregister_dialect(name)
def raises_chained_assignment_error(extra_warnings=(), extra_match=()):
from pandas._testing import assert_produces_warning
if CHAINED_WARNING_DISABLED:
if not extra_warnings:
from contextlib import nullcontext
return nullcontext()
else:
return assert_produces_warning(
extra_warnings,
match=extra_match,
)
else:
warning = ChainedAssignmentError
match = (
"A value is being set on a copy of a DataFrame or Series "
"through chained assignment"
)
if extra_warnings:
warning = (warning, *extra_warnings) # type: ignore[assignment]
return assert_produces_warning(
warning,
match=(match, *extra_match),
)