from __future__ import annotations import inspect from typing import TYPE_CHECKING, Any, Type, Union, Generic, TypeVar, Callable, cast from datetime import date, datetime from typing_extensions import ( Unpack, Literal, ClassVar, Protocol, Required, TypedDict, final, override, runtime_checkable, ) import pydantic import pydantic.generics from pydantic.fields import FieldInfo from ._types import ( Body, IncEx, Query, ModelT, Headers, Timeout, NotGiven, AnyMapping, HttpxRequestFiles, ) from ._utils import ( is_list, is_given, is_mapping, parse_date, parse_datetime, strip_not_given, ) from ._compat import PYDANTIC_V2, ConfigDict from ._compat import GenericModel as BaseGenericModel from ._compat import ( get_args, is_union, parse_obj, get_origin, is_literal_type, get_model_config, get_model_fields, field_get_default, ) from ._constants import RAW_RESPONSE_HEADER __all__ = ["BaseModel", "GenericModel"] _T = TypeVar("_T") @runtime_checkable class _ConfigProtocol(Protocol): allow_population_by_field_name: bool class BaseModel(pydantic.BaseModel): if PYDANTIC_V2: model_config: ClassVar[ConfigDict] = ConfigDict(extra="allow") else: @property @override def model_fields_set(self) -> set[str]: # a forwards-compat shim for pydantic v2 return self.__fields_set__ # type: ignore class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] extra: Any = pydantic.Extra.allow # type: ignore @override def __str__(self) -> str: # mypy complains about an invalid self arg return f'{self.__repr_name__()}({self.__repr_str__(", ")})' # type: ignore[misc] # Override the 'construct' method in a way that supports recursive parsing without validation. # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836. @classmethod @override def construct( cls: Type[ModelT], _fields_set: set[str] | None = None, **values: object, ) -> ModelT: m = cls.__new__(cls) fields_values: dict[str, object] = {} config = get_model_config(cls) populate_by_name = ( config.allow_population_by_field_name if isinstance(config, _ConfigProtocol) else config.get("populate_by_name") ) if _fields_set is None: _fields_set = set() model_fields = get_model_fields(cls) for name, field in model_fields.items(): key = field.alias if key is None or (key not in values and populate_by_name): key = name if key in values: fields_values[name] = _construct_field(value=values[key], field=field, key=key) _fields_set.add(name) else: fields_values[name] = field_get_default(field) _extra = {} for key, value in values.items(): if key not in model_fields: if PYDANTIC_V2: _extra[key] = value else: _fields_set.add(key) fields_values[key] = value object.__setattr__(m, "__dict__", fields_values) if PYDANTIC_V2: # these properties are copied from Pydantic's `model_construct()` method object.__setattr__(m, "__pydantic_private__", None) object.__setattr__(m, "__pydantic_extra__", _extra) object.__setattr__(m, "__pydantic_fields_set__", _fields_set) else: # init_private_attributes() does not exist in v2 m._init_private_attributes() # type: ignore # copied from Pydantic v1's `construct()` method object.__setattr__(m, "__fields_set__", _fields_set) return m if not TYPE_CHECKING: # type checkers incorrectly complain about this assignment # because the type signatures are technically different # although not in practice model_construct = construct if not PYDANTIC_V2: # we define aliases for some of the new pydantic v2 methods so # that we can just document these methods without having to specify # a specific pydantic version as some users may not know which # pydantic version they are currently using @override def model_dump( self, *, mode: Literal["json", "python"] | str = "python", include: IncEx = None, exclude: IncEx = None, by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, warnings: bool = True, ) -> dict[str, Any]: """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump Generate a dictionary representation of the model, optionally specifying which fields to include or exclude. Args: mode: The mode in which `to_python` should run. If mode is 'json', the dictionary will only contain JSON serializable types. If mode is 'python', the dictionary may contain any Python objects. include: A list of fields to include in the output. exclude: A list of fields to exclude from the output. by_alias: Whether to use the field's alias in the dictionary key if defined. exclude_unset: Whether to exclude fields that are unset or None from the output. exclude_defaults: Whether to exclude fields that are set to their default value from the output. exclude_none: Whether to exclude fields that have a value of `None` from the output. round_trip: Whether to enable serialization and deserialization round-trip support. warnings: Whether to log warnings when invalid fields are encountered. Returns: A dictionary representation of the model. """ if mode != "python": raise ValueError("mode is only supported in Pydantic v2") if round_trip != False: raise ValueError("round_trip is only supported in Pydantic v2") if warnings != True: raise ValueError("warnings is only supported in Pydantic v2") return super().dict( # pyright: ignore[reportDeprecated] include=include, exclude=exclude, by_alias=by_alias, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, ) @override def model_dump_json( self, *, indent: int | None = None, include: IncEx = None, exclude: IncEx = None, by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, round_trip: bool = False, warnings: bool = True, ) -> str: """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json Generates a JSON representation of the model using Pydantic's `to_json` method. Args: indent: Indentation to use in the JSON output. If None is passed, the output will be compact. include: Field(s) to include in the JSON output. Can take either a string or set of strings. exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings. by_alias: Whether to serialize using field aliases. exclude_unset: Whether to exclude fields that have not been explicitly set. exclude_defaults: Whether to exclude fields that have the default value. exclude_none: Whether to exclude fields that have a value of `None`. round_trip: Whether to use serialization/deserialization between JSON and class instance. warnings: Whether to show any warnings that occurred during serialization. Returns: A JSON string representation of the model. """ if round_trip != False: raise ValueError("round_trip is only supported in Pydantic v2") if warnings != True: raise ValueError("warnings is only supported in Pydantic v2") return super().json( # type: ignore[reportDeprecated] indent=indent, include=include, exclude=exclude, by_alias=by_alias, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, ) def _construct_field(value: object, field: FieldInfo, key: str) -> object: if value is None: return field_get_default(field) if PYDANTIC_V2: type_ = field.annotation else: type_ = cast(type, field.outer_type_) # type: ignore if type_ is None: raise RuntimeError(f"Unexpected field type is None for {key}") return construct_type(value=value, type_=type_) def is_basemodel(type_: type) -> bool: """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`""" origin = get_origin(type_) or type_ if is_union(type_): for variant in get_args(type_): if is_basemodel(variant): return True return False return issubclass(origin, BaseModel) or issubclass(origin, GenericModel) def construct_type(*, value: object, type_: type) -> object: """Loose coercion to the expected type with construction of nested values. If the given value does not match the expected type then it is returned as-is. """ # we need to use the origin class for any types that are subscripted generics # e.g. Dict[str, object] origin = get_origin(type_) or type_ args = get_args(type_) if is_union(origin): try: return validate_type(type_=type_, value=value) except Exception: pass # if the data is not valid, use the first variant that doesn't fail while deserializing for variant in args: try: return construct_type(value=value, type_=variant) except Exception: continue raise RuntimeError(f"Could not convert data into a valid instance of {type_}") if origin == dict: if not is_mapping(value): return value _, items_type = get_args(type_) # Dict[_, items_type] return {key: construct_type(value=item, type_=items_type) for key, item in value.items()} if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)): if is_list(value): return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value] if is_mapping(value): if issubclass(type_, BaseModel): return type_.construct(**value) # type: ignore[arg-type] return cast(Any, type_).construct(**value) if origin == list: if not is_list(value): return value inner_type = args[0] # List[inner_type] return [construct_type(value=entry, type_=inner_type) for entry in value] if origin == float: if isinstance(value, int): coerced = float(value) if coerced != value: return value return coerced return value if type_ == datetime: try: return parse_datetime(value) # type: ignore except Exception: return value if type_ == date: try: return parse_date(value) # type: ignore except Exception: return value return value def validate_type(*, type_: type[_T], value: object) -> _T: """Strict validation that the given value matches the expected type""" if inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel): return cast(_T, parse_obj(type_, value)) return cast(_T, _validate_non_model_type(type_=type_, value=value)) # our use of subclasssing here causes weirdness for type checkers, # so we just pretend that we don't subclass if TYPE_CHECKING: GenericModel = BaseModel else: class GenericModel(BaseGenericModel, BaseModel): pass if PYDANTIC_V2: from pydantic import TypeAdapter def _validate_non_model_type(*, type_: type[_T], value: object) -> _T: return TypeAdapter(type_).validate_python(value) elif not TYPE_CHECKING: # TODO: condition is weird class RootModel(GenericModel, Generic[_T]): """Used as a placeholder to easily convert runtime types to a Pydantic format to provide validation. For example: ```py validated = RootModel[int](__root__='5').__root__ # validated: 5 ``` """ __root__: _T def _validate_non_model_type(*, type_: type[_T], value: object) -> _T: model = _create_pydantic_model(type_).validate(value) return cast(_T, model.__root__) def _create_pydantic_model(type_: _T) -> Type[RootModel[_T]]: return RootModel[type_] # type: ignore class FinalRequestOptionsInput(TypedDict, total=False): method: Required[str] url: Required[str] params: Query headers: Headers max_retries: int timeout: float | Timeout | None files: HttpxRequestFiles | None idempotency_key: str json_data: Body extra_json: AnyMapping @final class FinalRequestOptions(pydantic.BaseModel): method: str url: str params: Query = {} headers: Union[Headers, NotGiven] = NotGiven() max_retries: Union[int, NotGiven] = NotGiven() timeout: Union[float, Timeout, None, NotGiven] = NotGiven() files: Union[HttpxRequestFiles, None] = None idempotency_key: Union[str, None] = None post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven() # It should be noted that we cannot use `json` here as that would override # a BaseModel method in an incompatible fashion. json_data: Union[Body, None] = None extra_json: Union[AnyMapping, None] = None if PYDANTIC_V2: model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) else: class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated] arbitrary_types_allowed: bool = True def get_max_retries(self, max_retries: int) -> int: if isinstance(self.max_retries, NotGiven): return max_retries return self.max_retries def _strip_raw_response_header(self) -> None: if not is_given(self.headers): return if self.headers.get(RAW_RESPONSE_HEADER): self.headers = {**self.headers} self.headers.pop(RAW_RESPONSE_HEADER) # override the `construct` method so that we can run custom transformations. # this is necessary as we don't want to do any actual runtime type checking # (which means we can't use validators) but we do want to ensure that `NotGiven` # values are not present # # type ignore required because we're adding explicit types to `**values` @classmethod def construct( # type: ignore cls, _fields_set: set[str] | None = None, **values: Unpack[FinalRequestOptionsInput], ) -> FinalRequestOptions: kwargs: dict[str, Any] = { # we unconditionally call `strip_not_given` on any value # as it will just ignore any non-mapping types key: strip_not_given(value) for key, value in values.items() } if PYDANTIC_V2: return super().model_construct(_fields_set, **kwargs) return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated] if not TYPE_CHECKING: # type checkers incorrectly complain about this assignment model_construct = construct