Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 41 additions & 14 deletions src/google/adk/sessions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,53 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing_extensions import override

from .base_session_service import BaseSessionService
from .in_memory_session_service import InMemorySessionService
from .session import Session
from .state import State
from .vertex_ai_session_service import VertexAiSessionService

try:
from .database_session_service import DatabaseSessionService
except ImportError:
# This handles the case where optional dependencies (like sqlalchemy)
# are not installed. A placeholder class ensures the symbol is always
# available for documentation tools and static analysis.
class DatabaseSessionService(BaseSessionService):
"""Placeholder for DatabaseSessionService when dependencies are not installed."""

_ERROR_MESSAGE = (
'DatabaseSessionService requires sqlalchemy>=2.0, please ensure it is'
' installed correctly.'
)

def __init__(self, *args, **kwargs):
raise ImportError(self._ERROR_MESSAGE)

@override
async def create_session(self, *args, **kwargs):
raise ImportError(self._ERROR_MESSAGE)

@override
async def get_session(self, *args, **kwargs):
raise ImportError(self._ERROR_MESSAGE)

@override
async def list_sessions(self, *args, **kwargs):
raise ImportError(self._ERROR_MESSAGE)

@override
async def delete_session(self, *args, **kwargs):
raise ImportError(self._ERROR_MESSAGE)

@override
async def append_event(self, *args, **kwargs):
raise ImportError(self._ERROR_MESSAGE)


__all__ = [
'BaseSessionService',
'DatabaseSessionService',
Expand All @@ -25,17 +66,3 @@
'State',
'VertexAiSessionService',
]


def __getattr__(name: str):
if name == 'DatabaseSessionService':
try:
from .database_session_service import DatabaseSessionService

return DatabaseSessionService
except ImportError as e:
raise ImportError(
'DatabaseSessionService requires sqlalchemy>=2.0, please ensure it is'
' installed correctly.'
) from e
raise AttributeError(f'module {__name__!r} has no attribute {name!r}')