1212import typing as typing_module
1313from collections import OrderedDict
1414from collections .abc import Callable
15- from functools import partial
1615from inspect import Parameter , _empty , isclass , signature
1716from json import dumps , load
1817from pathlib import Path
@@ -1096,15 +1095,15 @@ def format_params(
10961095 for meta in param .annotation .__metadata__
10971096 )
10981097 model = param .annotation .__args__ [0 ]
1099- is_pydantic_model = hasattr (model , "model_fields" ) or hasattr (
1098+ is_pydantic_model = hasattr (type ( model ) , "model_fields" ) or hasattr (
11001099 model , "__pydantic_fields__"
11011100 )
11021101 is_get_request = not MethodDefinition .is_data_processing_function (path )
11031102
11041103 if is_pydantic_model and is_get_request and not has_depends :
11051104 # Unpack the model fields as query parameters
11061105 fields = getattr (
1107- model ,
1106+ type ( model ) ,
11081107 "model_fields" ,
11091108 getattr (model , "__pydantic_fields__" , {}),
11101109 )
@@ -1747,7 +1746,7 @@ def build_command_method_body(
17471746 elif (
17481747 isinstance (param .annotation , _AnnotatedAlias )
17491748 and (
1750- hasattr (param .annotation .__args__ [0 ], "model_fields" )
1749+ hasattr (type ( param .annotation .__args__ [0 ]) , "model_fields" )
17511750 or hasattr (param .annotation .__args__ [0 ], "__pydantic_fields__" )
17521751 )
17531752 and not MethodDefinition .is_data_processing_function (path )
@@ -1759,7 +1758,7 @@ def build_command_method_body(
17591758 if not has_depends :
17601759 model = param .annotation .__args__ [0 ]
17611760 fields = getattr (
1762- model ,
1761+ type ( model ) ,
17631762 "model_fields" ,
17641763 getattr (model , "__pydantic_fields__" , {}),
17651764 )
@@ -2668,23 +2667,26 @@ def generate( # pylint: disable=too-many-positional-arguments # noqa: PLR0912
26682667 param_types .update ({k : v .type for k , v in kwarg_params .items ()})
26692668 # Format the annotation to hide the metadata, tags, etc.
26702669 annotation = func .__annotations__ .get ("return" )
2670+ model_fields = getattr (annotation , "model_fields" , {})
26712671 results_type = (
26722672 cls ._get_repr (
26732673 cls ._get_generic_types (
2674- annotation . model_fields ["results" ].annotation , # type: ignore[union-attr,arg-type]
2674+ model_fields ["results" ].annotation , # type: ignore[union-attr,arg-type]
26752675 [],
26762676 ),
26772677 model_name ,
26782678 )
2679- if isclass (annotation ) and issubclass (annotation , OBBject ) # type: ignore[arg-type]
2679+ if isclass (annotation )
2680+ and issubclass (annotation , OBBject ) # type: ignore[arg-type]
2681+ and "results" in model_fields
26802682 else model_name
26812683 )
26822684 doc = cls .generate_model_docstring (
26832685 model_name = model_name ,
26842686 summary = func .__doc__ or "" ,
26852687 explicit_params = explicit_params ,
26862688 kwarg_params = kwarg_params ,
2687- returns = return_schema . model_fields ,
2689+ returns = getattr ( return_schema , " model_fields" , {}) ,
26882690 results_type = results_type ,
26892691 sections = sections ,
26902692 )
@@ -2796,8 +2798,10 @@ def generate( # pylint: disable=too-many-positional-arguments # noqa: PLR0912
27962798
27972799 if not is_primitive :
27982800 try :
2799- if hasattr (return_annotation , "model_fields" ):
2800- fields = return_annotation .model_fields
2801+ if hasattr (type (return_annotation ), "model_fields" ):
2802+ fields = getattr (
2803+ type (return_annotation ), "model_fields" , {}
2804+ )
28012805
28022806 for field_name , field in fields .items ():
28032807 field_type = cls .get_field_type (
@@ -2900,15 +2904,18 @@ def _get_generic_types(cls, type_: type, items: list) -> list[str]:
29002904 """
29012905 if hasattr (type_ , "__args__" ):
29022906 origin = get_origin (type_ )
2903- # pylint: disable=unidiomatic-typecheck
2904- if (
2905- type (origin ) is type
2907+ if origin is Union or origin is UnionType :
2908+ for arg in type_ .__args__ :
2909+ cls ._get_generic_types (arg , items )
2910+ elif (
2911+ isinstance (origin , type )
29062912 and origin is not Annotated
2907- and (name := getattr (type_ , "_name" , getattr (type_ , "__name__" , None )))
2913+ and (name := getattr (type_ , "_name" , getattr (origin , "__name__" , None )))
29082914 ):
29092915 items .append (name )
2910- func = partial (cls ._get_generic_types , items = items )
2911- set ().union (* map (func , type_ .__args__ ), items ) # type: ignore
2916+ for arg in type_ .__args__ :
2917+ cls ._get_generic_types (arg , items )
2918+
29122919 return items
29132920
29142921 @staticmethod
@@ -4004,10 +4011,12 @@ def get_paths( # noqa: PLR0912
40044011 try :
40054012 module = sys .modules [route_func .__module__ ]
40064013 model_class = getattr (module , extracted_model_name , None )
4007- if model_class and hasattr (model_class , "model_fields" ):
4014+ if model_class and hasattr (type ( model_class ) , "model_fields" ):
40084015 # Set data to the fields
40094016 reference [path ]["data" ]["standard" ] = []
4010- for field_name , field in model_class .model_fields .items ():
4017+ for field_name , field in getattr (
4018+ type (model_class ), "model_fields" , {}
4019+ ).items ():
40114020 field_type = DocstringGenerator .get_field_type (
40124021 field .annotation , field .is_required (), "website"
40134022 )
0 commit comments