Skip to content

Commit 48fb671

Browse files
authored
[BugFix] Support PEP 649 & PEP 749 (Enable Python 3.14) (#7343)
* enable future support for python 3.14 * removing calling type on standard and extra
1 parent 06041a8 commit 48fb671

File tree

3 files changed

+66
-34
lines changed

3 files changed

+66
-34
lines changed

openbb_platform/core/openbb_core/app/provider_interface.py

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,24 @@ class CompanyNews(ProviderChoices):
552552

553553
return result
554554

555+
@staticmethod
556+
def _fields_to_pydantic(
557+
fields: list[TupleFieldType],
558+
) -> dict[str, tuple[type | None, Any]]:
559+
"""Convert dataclass fields to pydantic fields.
560+
561+
Parameters
562+
----------
563+
fields : list[TupleFieldType]
564+
List of (name, annotation, default) tuples.
565+
566+
Returns
567+
-------
568+
dict[str, tuple[type | None, Any]]
569+
Dictionary mapping field names to (annotation, default) tuples.
570+
"""
571+
return {name: (annotation, default) for name, annotation, default in fields}
572+
555573
def _generate_data_dc(
556574
self, map_: MapType
557575
) -> dict[str, dict[str, StandardData | ExtraData]]:
@@ -577,15 +595,15 @@ class EquityHistoricalData(StandardData):
577595
extra: dict
578596
standard, extra = self._extract_data(providers)
579597
result[model_name] = {
580-
"standard": make_dataclass(
581-
cls_name=model_name,
582-
fields=list(standard.values()), # type: ignore[arg-type]
583-
bases=(StandardData,),
598+
"standard": create_model( # type: ignore
599+
model_name,
600+
__base__=StandardData,
601+
**self._fields_to_pydantic(list(standard.values())), # type: ignore
584602
),
585-
"extra": make_dataclass(
586-
cls_name=model_name,
587-
fields=list(extra.values()), # type: ignore[arg-type]
588-
bases=(ExtraData,),
603+
"extra": create_model(
604+
model_name,
605+
__base__=ExtraData,
606+
**self._fields_to_pydantic(list(extra.values())), # type: ignore
589607
),
590608
}
591609

@@ -601,8 +619,9 @@ def _generate_return_schema(
601619
standard = dataclasses["standard"]
602620
extra = dataclasses["extra"]
603621

604-
fields = standard.model_fields.copy()
605-
fields.update(extra.model_fields)
622+
fields = getattr(standard, "model_fields", {}).copy()
623+
extra_fields = getattr(extra, "model_fields", {}).copy()
624+
fields.update(extra_fields)
606625

607626
fields_dict: dict[str, tuple[Any, Any]] = {}
608627

openbb_platform/core/openbb_core/app/static/package_builder.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import typing as typing_module
1313
from collections import OrderedDict
1414
from collections.abc import Callable
15-
from functools import partial
1615
from inspect import Parameter, _empty, isclass, signature
1716
from json import dumps, load
1817
from pathlib import Path
@@ -1096,15 +1095,15 @@ def format_params(
10961095
for meta in param.annotation.__metadata__
10971096
)
10981097
model = param.annotation.__args__[0]
1099-
is_pydantic_model = hasattr(model, "model_fields") or hasattr(
1098+
is_pydantic_model = hasattr(type(model), "model_fields") or hasattr(
11001099
model, "__pydantic_fields__"
11011100
)
11021101
is_get_request = not MethodDefinition.is_data_processing_function(path)
11031102

11041103
if is_pydantic_model and is_get_request and not has_depends:
11051104
# Unpack the model fields as query parameters
11061105
fields = getattr(
1107-
model,
1106+
type(model),
11081107
"model_fields",
11091108
getattr(model, "__pydantic_fields__", {}),
11101109
)
@@ -1747,7 +1746,7 @@ def build_command_method_body(
17471746
elif (
17481747
isinstance(param.annotation, _AnnotatedAlias)
17491748
and (
1750-
hasattr(param.annotation.__args__[0], "model_fields")
1749+
hasattr(type(param.annotation.__args__[0]), "model_fields")
17511750
or hasattr(param.annotation.__args__[0], "__pydantic_fields__")
17521751
)
17531752
and not MethodDefinition.is_data_processing_function(path)
@@ -1759,7 +1758,7 @@ def build_command_method_body(
17591758
if not has_depends:
17601759
model = param.annotation.__args__[0]
17611760
fields = getattr(
1762-
model,
1761+
type(model),
17631762
"model_fields",
17641763
getattr(model, "__pydantic_fields__", {}),
17651764
)
@@ -2668,23 +2667,26 @@ def generate( # pylint: disable=too-many-positional-arguments # noqa: PLR0912
26682667
param_types.update({k: v.type for k, v in kwarg_params.items()})
26692668
# Format the annotation to hide the metadata, tags, etc.
26702669
annotation = func.__annotations__.get("return")
2670+
model_fields = getattr(annotation, "model_fields", {})
26712671
results_type = (
26722672
cls._get_repr(
26732673
cls._get_generic_types(
2674-
annotation.model_fields["results"].annotation, # type: ignore[union-attr,arg-type]
2674+
model_fields["results"].annotation, # type: ignore[union-attr,arg-type]
26752675
[],
26762676
),
26772677
model_name,
26782678
)
2679-
if isclass(annotation) and issubclass(annotation, OBBject) # type: ignore[arg-type]
2679+
if isclass(annotation)
2680+
and issubclass(annotation, OBBject) # type: ignore[arg-type]
2681+
and "results" in model_fields
26802682
else model_name
26812683
)
26822684
doc = cls.generate_model_docstring(
26832685
model_name=model_name,
26842686
summary=func.__doc__ or "",
26852687
explicit_params=explicit_params,
26862688
kwarg_params=kwarg_params,
2687-
returns=return_schema.model_fields,
2689+
returns=getattr(return_schema, "model_fields", {}),
26882690
results_type=results_type,
26892691
sections=sections,
26902692
)
@@ -2796,8 +2798,10 @@ def generate( # pylint: disable=too-many-positional-arguments # noqa: PLR0912
27962798

27972799
if not is_primitive:
27982800
try:
2799-
if hasattr(return_annotation, "model_fields"):
2800-
fields = return_annotation.model_fields
2801+
if hasattr(type(return_annotation), "model_fields"):
2802+
fields = getattr(
2803+
type(return_annotation), "model_fields", {}
2804+
)
28012805

28022806
for field_name, field in fields.items():
28032807
field_type = cls.get_field_type(
@@ -2900,15 +2904,18 @@ def _get_generic_types(cls, type_: type, items: list) -> list[str]:
29002904
"""
29012905
if hasattr(type_, "__args__"):
29022906
origin = get_origin(type_)
2903-
# pylint: disable=unidiomatic-typecheck
2904-
if (
2905-
type(origin) is type
2907+
if origin is Union or origin is UnionType:
2908+
for arg in type_.__args__:
2909+
cls._get_generic_types(arg, items)
2910+
elif (
2911+
isinstance(origin, type)
29062912
and origin is not Annotated
2907-
and (name := getattr(type_, "_name", getattr(type_, "__name__", None)))
2913+
and (name := getattr(type_, "_name", getattr(origin, "__name__", None)))
29082914
):
29092915
items.append(name)
2910-
func = partial(cls._get_generic_types, items=items)
2911-
set().union(*map(func, type_.__args__), items) # type: ignore
2916+
for arg in type_.__args__:
2917+
cls._get_generic_types(arg, items)
2918+
29122919
return items
29132920

29142921
@staticmethod
@@ -4004,10 +4011,12 @@ def get_paths( # noqa: PLR0912
40044011
try:
40054012
module = sys.modules[route_func.__module__]
40064013
model_class = getattr(module, extracted_model_name, None)
4007-
if model_class and hasattr(model_class, "model_fields"):
4014+
if model_class and hasattr(type(model_class), "model_fields"):
40084015
# Set data to the fields
40094016
reference[path]["data"]["standard"] = []
4010-
for field_name, field in model_class.model_fields.items():
4017+
for field_name, field in getattr(
4018+
type(model_class), "model_fields", {}
4019+
).items():
40114020
field_type = DocstringGenerator.get_field_type(
40124021
field.annotation, field.is_required(), "website"
40134022
)

openbb_platform/obbject_extensions/charting/tests/test_charting.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,18 +136,20 @@ def test_get_chart_function(obbject):
136136

137137

138138
@patch("openbb_charting.charting.Charting._get_chart_function")
139-
@patch("openbb_charting.charting.Chart")
140-
def test_show(_, mock_get_chart_function, obbject):
139+
def test_show(mock_get_chart_function, obbject):
141140
"""Test show method."""
142141
# Arrange
143142
mock_function = MagicMock()
144143
mock_get_chart_function.return_value = mock_function
145144
mock_fig = MagicMock()
145+
mock_fig.show = MagicMock(
146+
return_value=MagicMock(to_plotly_json=MagicMock(return_value={}))
147+
)
146148
mock_function.return_value = (mock_fig, {"content": "mock_content"})
147149
obj = Charting(obbject)
148150

149151
# Act
150-
obj.show()
152+
obj.show(render=False)
151153

152154
# Assert
153155
mock_get_chart_function.assert_called_once()
@@ -156,19 +158,21 @@ def test_show(_, mock_get_chart_function, obbject):
156158

157159
@patch("openbb_charting.charting.Charting._prepare_data_as_df")
158160
@patch("openbb_charting.charting.Charting._get_chart_function")
159-
@patch("openbb_charting.charting.Chart")
160-
def test_to_chart(_, mock_get_chart_function, mock_prepare_data_as_df, obbject):
161+
def test_to_chart(mock_get_chart_function, mock_prepare_data_as_df, obbject):
161162
"""Test to_chart method."""
162163
# Arrange
163164
mock_prepare_data_as_df.return_value = (mock_dataframe, True)
164165
mock_function = MagicMock()
165166
mock_get_chart_function.return_value = mock_function
166167
mock_fig = MagicMock()
168+
mock_fig.show = MagicMock(
169+
return_value=MagicMock(to_plotly_json=MagicMock(return_value={}))
170+
)
167171
mock_function.return_value = (mock_fig, {"content": "mock_content"})
168172
obj = Charting(obbject)
169173

170174
# Act
171-
obj.to_chart()
175+
obj.to_chart(render=False)
172176

173177
# Assert
174178
mock_get_chart_function.assert_called_once()

0 commit comments

Comments
 (0)