diff --git a/.bumpversion.cfg b/.bumpversion.cfg
index 4512da1..ef646f8 100644
--- a/.bumpversion.cfg
+++ b/.bumpversion.cfg
@@ -1,5 +1,5 @@
[bumpversion]
-current_version = 0.0.74
+current_version = 0.0.75
commit = True
tag = True
diff --git a/README.md b/README.md
index e4d097b..825c4e0 100644
--- a/README.md
+++ b/README.md
@@ -13,15 +13,12 @@ You can install SqliteCloud Package using Python Package Index (PYPI):
$ pip install SqliteCloud
```
-- Follow the instructions reported here https://github.com/sqlitecloud/sdk/tree/master/C to build the driver.
-
-- Set SQLITECLOUD_DRIVER_PATH environment variable to the path of the driver file build.
-
## Usage
```python
-from sqlitecloud.client import SqliteCloudClient, SqliteCloudAccount
+from sqlitecloud.client import SqliteCloudClient
+from sqlitecloud.types import SqliteCloudAccount
```
### _Init a connection_
@@ -45,9 +42,8 @@ conn = client.open_connection()
### _Execute a query_
You can bind values to parametric queries: you can pass parameters as positional values in an array
```python
-result = client.exec_statement(
- "SELECT * FROM table_name WHERE id = ?",
- [1],
+result = client.exec_query(
+ "SELECT * FROM table_name WHERE id = 1"
conn=conn
)
```
diff --git a/samples.ipynb b/samples.ipynb
index e7809fb..17ce3da 100644
--- a/samples.ipynb
+++ b/samples.ipynb
@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -37,7 +37,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -75,7 +75,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -92,7 +92,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -124,7 +124,7 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
diff --git a/src/README-PYPI.md b/src/README-PYPI.md
index e4d097b..825c4e0 100644
--- a/src/README-PYPI.md
+++ b/src/README-PYPI.md
@@ -13,15 +13,12 @@ You can install SqliteCloud Package using Python Package Index (PYPI):
$ pip install SqliteCloud
```
-- Follow the instructions reported here https://github.com/sqlitecloud/sdk/tree/master/C to build the driver.
-
-- Set SQLITECLOUD_DRIVER_PATH environment variable to the path of the driver file build.
-
## Usage
```python
-from sqlitecloud.client import SqliteCloudClient, SqliteCloudAccount
+from sqlitecloud.client import SqliteCloudClient
+from sqlitecloud.types import SqliteCloudAccount
```
### _Init a connection_
@@ -45,9 +42,8 @@ conn = client.open_connection()
### _Execute a query_
You can bind values to parametric queries: you can pass parameters as positional values in an array
```python
-result = client.exec_statement(
- "SELECT * FROM table_name WHERE id = ?",
- [1],
+result = client.exec_query(
+ "SELECT * FROM table_name WHERE id = 1"
conn=conn
)
```
diff --git a/src/setup.py b/src/setup.py
index 72ad776..99522dd 100644
--- a/src/setup.py
+++ b/src/setup.py
@@ -18,24 +18,24 @@ def read_file(filename):
setup(
name='SqliteCloud',
- version='0.0.74',
- author='Sam Reghenzi & Matteo Fredi',
+ version='0.0.75',
+ author='sqlitecloud.io',
description='A Python package for working with SQLite databases in the cloud.',
long_description=read_file('README-PYPI.md'),
long_description_content_type='text/markdown',
url="https://github.com/sqlitecloud/python",
packages=find_packages(),
install_requires=[
- 'mypy == 1.6.1',
- 'mypy-extensions == 1.0.0',
- 'typing-extensions == 4.8.0',
- 'black == 23.7.0',
- 'python-dotenv == 1.0.0',
+ 'lz4 == 3.1.10',
],
classifiers=[
'Development Status :: 3 - Alpha',
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
+ 'Programming Language :: Python :: 3.6',
+ 'Programming Language :: Python :: 3.7',
+ 'Programming Language :: Python :: 3.8',
+ 'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Programming Language :: Python :: 3.12',
diff --git a/src/sqlitecloud/__init__.py b/src/sqlitecloud/__init__.py
index ae8e873..1cf6267 100644
--- a/src/sqlitecloud/__init__.py
+++ b/src/sqlitecloud/__init__.py
@@ -1 +1 @@
-VERSION = "0.0.74"
+VERSION = "0.1.0"
diff --git a/src/sqlitecloud/client.py b/src/sqlitecloud/client.py
index bad9894..647e6c5 100644
--- a/src/sqlitecloud/client.py
+++ b/src/sqlitecloud/client.py
@@ -25,15 +25,16 @@ def __init__(
self,
cloud_account: Optional[SqliteCloudAccount] = None,
connection_str: Optional[str] = None,
- # pub_subs: SQCloudPubSubCallback = [],
) -> None:
"""Initializes a new instance of the class with connection information.
Args:
- connection_str (str): The connection string for the database.
+ cloud_account (SqliteCloudAccount): The account information for the SQlite Cloud database.
+ connection_str (str): The connection string for the SQlite Cloud database.
+ Eg: sqlitecloud://user:pass@host.com:port/dbname?timeout=10&apikey=abcd123
"""
- self.driver = Driver()
+ self._driver = Driver()
self.config = SQCloudConfig()
@@ -53,7 +54,7 @@ def open_connection(self) -> SQCloudConnect:
Raises:
SQCloudException: If an error occurs while opening the connection.
"""
- connection = self.driver.connect(
+ connection = self._driver.connect(
self.config.account.hostname, self.config.account.port, self.config
)
@@ -61,10 +62,21 @@ def open_connection(self) -> SQCloudConnect:
def disconnect(self, conn: SQCloudConnect) -> None:
"""Close the connection to the database."""
- self.driver.disconnect(conn)
+ self._driver.disconnect(conn)
+
+ def is_connected(self, conn: SQCloudConnect) -> bool:
+ """Check if the connection is still open.
+
+ Args:
+ conn (SQCloudConnect): The connection to the database.
+
+ Returns:
+ bool: True if the connection is open, False otherwise.
+ """
+ return self._driver.is_connected(conn)
def exec_query(
- self, query: str, conn: SQCloudConnect = None
+ self, query: str, conn: SQCloudConnect
) -> SqliteCloudResultSet:
"""Executes a SQL query on the SQLite Cloud database.
@@ -73,15 +85,11 @@ def exec_query(
Returns:
SqliteCloudResultSet: The result set of the executed query.
- """
- provided_connection = conn is not None
- if not provided_connection:
- conn = self.open_connection()
-
- result = self.driver.execute(query, conn)
- if not provided_connection:
- self.disconnect(conn)
+ Raises:
+ SQCloudException: If an error occurs while executing the query.
+ """
+ result = self._driver.execute(query, conn)
return SqliteCloudResultSet(result)
@@ -92,7 +100,7 @@ def sendblob(self, blob: bytes, conn: SQCloudConnect) -> SqliteCloudResultSet:
blob (bytes): The blob to be sent to the database.
conn (SQCloudConnect): The connection to the database.
"""
- return self.driver.sendblob(blob, conn)
+ return self._driver.send_blob(blob, conn)
def _parse_connection_string(self, connection_string) -> SQCloudConfig:
# URL STRING FORMAT
diff --git a/src/sqlitecloud/download.py b/src/sqlitecloud/download.py
new file mode 100644
index 0000000..92bbf92
--- /dev/null
+++ b/src/sqlitecloud/download.py
@@ -0,0 +1,41 @@
+from io import BufferedWriter
+import logging
+
+from sqlitecloud.driver import Driver
+from sqlitecloud.types import SQCloudConnect
+
+
+def xCallback(
+ fd: BufferedWriter, data: bytes, blen: int, ntot: int, nprogress: int
+) -> None:
+ """
+ Callback function used for downloading data.
+ Data is passed to the callback to be written to the file and to
+ monitor the progress.
+
+ Args:
+ fd (BufferedWriter): The file descriptor to write the downloaded data to.
+ data (bytes): The data to be written.
+ blen (int): The length of the data.
+ ntot (int): The total length of the data being downloaded.
+ nprogress (int): The number of bytes already downloaded.
+ """
+ fd.write(data)
+
+ if blen == 0:
+ logging.log(logging.DEBUG, "DOWNLOAD COMPLETE")
+ else:
+ logging.log(logging.DEBUG, f"{(nprogress + blen) / ntot * 100:.2f}%")
+
+
+def download_db(connection: SQCloudConnect, dbname: str, filename: str) -> None:
+ """
+ Download a database from the server.
+
+ Raises:
+ SQCloudException: If an error occurs while downloading the database.
+ """
+ driver = Driver()
+
+ with open(filename, "wb") as fd:
+ driver.download_database(connection, dbname, fd, xCallback, False)
diff --git a/src/sqlitecloud/driver.py b/src/sqlitecloud/driver.py
index fd49e82..aeb7767 100644
--- a/src/sqlitecloud/driver.py
+++ b/src/sqlitecloud/driver.py
@@ -1,10 +1,16 @@
+from io import BufferedReader, BufferedWriter
+import logging
+import select
import ssl
-from typing import Optional, Union
+import threading
+from typing import Callable, Optional, Union
import lz4.block
-from sqlitecloud.resultset import SQCloudResult
+from sqlitecloud.resultset import SQCloudResult, SqliteCloudResultSet
from sqlitecloud.types import (
SQCLOUD_CMD,
+ SQCLOUD_DEFAULT,
SQCLOUD_INTERNAL_ERRCODE,
+ SQCLOUD_RESULT_TYPE,
SQCLOUD_ROWSET,
SQCloudConfig,
SQCloudConnect,
@@ -17,6 +23,8 @@
class Driver:
+ SQCLOUD_DEFAULT_UPLOAD_SIZE = 512 * 1024
+
def __init__(self) -> None:
# Used while parsing chunked rowset
self._rowset: SQCloudResult = None
@@ -25,7 +33,7 @@ def connect(
self, hostname: str, port: int, config: SQCloudConfig
) -> SQCloudConnect:
"""
- Connects to the SQLite Cloud server.
+ Connect to the SQLite Cloud server.
Args:
hostname (str): The hostname of the server.
@@ -36,10 +44,77 @@ def connect(
SQCloudConnect: The connection object.
Raises:
- SQCloudException: If an error occurs while initializing the socket.
+ SQCloudException: If an error occurs while connecting the socket.
+ """
+ sock = self._internal_connect(hostname, port, config)
+
+ connection = SQCloudConnect()
+ connection.config = config
+ connection.socket = sock
+
+ self._internal_config_apply(connection, config)
+
+ return connection
+
+ def disconnect(self, conn: SQCloudConnect, only_main_socket: bool = False) -> None:
+ """
+ Disconnect from the SQLite Cloud server.
+ """
+ try:
+ if conn.socket:
+ conn.socket.close()
+ if not only_main_socket and conn.pubsub_socket:
+ conn.pubsub_socket.close()
+ except Exception:
+ pass
+ finally:
+ conn.socket = None
+ if not only_main_socket:
+ conn.pubsub_socket = None
+
+ def execute(self, command: str, connection: SQCloudConnect) -> SQCloudResult:
+ """
+ Execute a query on the SQLite Cloud server.
+ """
+ return self._internal_run_command(connection, command)
+
+ def send_blob(self, blob: bytes, conn: SQCloudConnect) -> SQCloudResult:
+ """
+ Send a blob to the SQLite Cloud server.
+ """
+ try:
+ conn.isblob = True
+ return self._internal_run_command(conn, blob)
+ finally:
+ conn.isblob = False
+
+ def is_connected(
+ self, connection: SQCloudConnect, main_socket: bool = True
+ ) -> bool:
+ """
+ Check if the connection is still open.
+ """
+ sock = connection.socket if main_socket else connection.pubsub_socket
+
+ if not sock:
+ return False
+ try:
+ sock.sendall(b"")
+ except OSError:
+ return False
+
+ return True
+
+ def _internal_connect(
+ self, hostname: str, port: int, config: SQCloudConfig
+ ) -> socket:
+ """
+ Create a socket connection to the SQLite Cloud server.
"""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(config.connect_timeout)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+ sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
if not config.insecure:
context = ssl.create_default_context(cafile=config.root_certificate)
@@ -59,36 +134,243 @@ def connect(
errmsg = f"An error occurred while initializing the socket."
raise SQCloudException(errmsg) from e
- connection = SQCloudConnect()
- connection.socket = sock
- connection.config = config
+ return sock
- self._internal_config_apply(connection, config)
+ def _internal_reconnect(self, buffer: bytes) -> bool:
+ return True
- return connection
+ def _internal_setup_pubsub(self, connection: SQCloudConnect, buffer: bytes) -> bool:
+ """
+ Prepare the connection for PubSub.
+ Opens a new specific socket and starts the thread to listen for incoming messages.
+ """
+ if self.is_connected(connection, False):
+ return True
+
+ if connection.pubsub_callback is None:
+ raise SQCloudException(
+ "A callback function must be provided to setup the PubSub connection."
+ )
+
+ connection.pubsub_socket = self._internal_connect(
+ connection.config.account.hostname,
+ connection.config.account.port,
+ connection.config,
+ )
+
+ self._internal_run_command(connection, buffer, False)
+ thread = threading.Thread(
+ target=self._internal_pubsub_thread, args=(connection,)
+ )
+ # kill the thread when the main one is terminated
+ thread.daemon = True
+ thread.start()
+ connection.pubsub_thread = thread
+
+ return True
+
+ def _internal_pubsub_thread(self, connection: SQCloudConnect) -> None:
+ blen = 2048
+ buffer: bytes = b""
- def disconnect(self, conn: SQCloudConnect):
try:
- if conn.socket:
- conn.socket.close()
+ while True:
+ tread = 0
+
+ try:
+ if not connection.pubsub_socket:
+ logging.info("PubSub socket dismissed.")
+ break
+
+ # wait for the socket to be readable (no timeout)
+ ready_to_read, _, errors = select.select(
+ [connection.pubsub_socket], [], []
+ )
+ # eg, no data to read
+ if len(ready_to_read) == 0:
+ continue
+ # eg, if the socket is closed
+ if len(errors) > 0:
+ break
+
+ data = connection.pubsub_socket.recv(blen)
+ if not data:
+ logging.info("PubSub connection closed.")
+ break
+ except Exception as e:
+ logging.error(
+ f"An error occurred while reading data: {SQCLOUD_INTERNAL_ERRCODE.NETWORK.value} ({e})."
+ )
+ break
+
+ nread = len(data)
+ tread += nread
+ blen -= nread
+ buffer += data
+
+ sqcloud_number = self._internal_parse_number(buffer)
+ clen = sqcloud_number.value
+ if clen == 0:
+ continue
+
+ # check if read is complete
+ # clen is the lenght parsed in the buffer
+ # cstart is the index of the first space
+ cstart = sqcloud_number.cstart
+ if clen + cstart != tread:
+ continue
+
+ result = self._internal_parse_buffer(connection, buffer, tread)
+ if result.tag == SQCLOUD_RESULT_TYPE.RESULT_STRING:
+ result.tag = SQCLOUD_RESULT_TYPE.RESULT_JSON
+
+ connection.pubsub_callback(
+ connection, SqliteCloudResultSet(result), connection.pubsub_data
+ )
+ except Exception as e:
+ logging.error(f"An error occurred while parsing data: {e}.")
+
finally:
- conn.socket = None
+ connection.pubsub_callback(connection, None, connection.pubsub_data)
+
+ def upload_database(
+ self,
+ connection: SQCloudConnect,
+ dbname: str,
+ key: Optional[str],
+ is_file_transfer: bool,
+ snapshot_id: int,
+ is_internal_db: bool,
+ fd: BufferedReader,
+ dbsize: int,
+ xCallback: Callable[[BufferedReader, int, int, int], bytes],
+ ) -> None:
+ """
+ Uploads a database to the server.
- def execute(self, command: str, connection: SQCloudConnect) -> SQCloudResult:
- return self._internal_run_command(connection, command)
+ Args:
+ connection (SQCloudConnect): The connection object to the SQLite Cloud server.
+ dbname (str): The name of the database to upload.
+ key (Optional[str]): The encryption key for the database, if applicable.
+ is_file_transfer (bool): Indicates whether the database is being transferred as a file.
+ snapshot_id (int): The ID of the snapshot to upload.
+ is_internal_db (bool): Indicates whether the database is an internal database.
+ fd (BufferedReader): The file descriptor of the database file.
+ dbsize (int): The size of the database file.
+ xCallback (Callable[[BufferedReader, int, int, int], bytes]): The callback function to read the buffer.
- def sendblob(self, blob: bytes, conn: SQCloudConnect) -> SQCloudResult:
+ Raises:
+ SQCloudException: If an error occurs during the upload process.
+
+ """
+ keyarg = "KEY " if key else ""
+ keyvalue = key if key else ""
+
+ # prepare command to execute
+ command = ""
+ if is_file_transfer:
+ internalarg = "INTERNAL" if is_internal_db else ""
+ command = f"TRANSFER DATABASE '{dbname}' {keyarg}{keyvalue} SNAPSHOT {snapshot_id} {internalarg}"
+ else:
+ command = f"UPLOAD DATABASE '{dbname}' {keyarg}{keyvalue}"
+
+ # execute command on server side
+ result = self._internal_run_command(connection, command)
+ if not result.data[0]:
+ raise SQCloudException(
+ "An error occurred while initializing the upload of the database."
+ )
+
+ buffer: bytes = b""
+ blen = 0
+ nprogress = 0
try:
- conn.isblob = True
- return self._internal_run_command(conn, blob)
- finally:
- conn.isblob = False
+ while True:
+ # execute callback to read buffer
+ blen = SQCLOUD_DEFAULT.UPLOAD_SIZE.value
+ try:
+ buffer = xCallback(fd, blen, dbsize, nprogress)
+ blen = len(buffer)
+ except Exception as e:
+ raise SQCloudException(
+ "An error occurred while reading the file."
+ ) from e
+
+ try:
+ # send also the final confirmation blob of zero bytes
+ self.send_blob(buffer, connection)
+ except Exception as e:
+ raise SQCloudException(
+ "An error occurred while uploading the file."
+ ) from e
+
+ # update progress
+ nprogress += blen
+
+ if blen == 0:
+ # Upload completed
+ break
+ except Exception as e:
+ self._internal_run_command(connection, "UPLOAD ABORT")
+ raise e
+
+ def download_database(
+ self,
+ connection: SQCloudConnect,
+ dbname: str,
+ fd: BufferedWriter,
+ xCallback: Callable[[BufferedWriter, int, int, int], bytes],
+ if_exists: bool,
+ ) -> None:
+ """
+ Downloads a database from the SQLite Cloud service.
- def _internal_reconnect(self, buffer: bytes) -> bool:
- return True
+ Args:
+ connection (SQCloudConnect): The connection object used to communicate with the SQLite Cloud service.
+ dbname (str): The name of the database to download.
+ fd (BufferedWriter): The file descriptor to write the downloaded data to.
+ xCallback (Callable[[BufferedWriter, int, int, int], bytes]): A callback function to write downloaded data with the download progress information.
+ if_exists (bool): If True, the download won't rise an exception if database is missing.
- def _internal_setup_pubsub(self, buffer: bytes) -> bool:
- return True
+ Raises:
+ SQCloudException: If an error occurs while downloading the database.
+
+ """
+ exists_cmd = " IF EXISTS" if if_exists else ""
+ result = self._internal_run_command(
+ connection, f"DOWNLOAD DATABASE {dbname}{exists_cmd};"
+ )
+
+ if result.nrows == 0:
+ raise SQCloudException(
+ "An error occurred while initializing the download of the database."
+ )
+
+ # result is an ARRAY (database size, number of pages, raft_index)
+ download_info = result.data[0]
+ db_size = int(download_info[0])
+
+ # loop to download
+ progress_size = 0
+
+ try:
+ while progress_size < db_size:
+ result = self._internal_run_command(connection, "DOWNLOAD STEP")
+
+ # res is BLOB, decode it
+ data = result.data[0]
+ data_len = len(data)
+
+ # execute callback (with progress_size updated)
+ progress_size += data_len
+ xCallback(fd, data, data_len, db_size, progress_size)
+
+ # check exit condition
+ if data_len == 0:
+ break
+ except Exception as e:
+ self._internal_run_command(connection, "DOWNLOAD ABORT")
+ raise e
def _internal_config_apply(
self, connection: SQCloudConnect, config: SQCloudConfig
@@ -135,13 +417,19 @@ def _internal_config_apply(
self._internal_run_command(connection, buffer)
def _internal_run_command(
- self, connection: SQCloudConnect, command: Union[str, bytes]
- ) -> None:
- self._internal_socket_write(connection, command)
- return self._internal_socket_read(connection)
+ self,
+ connection: SQCloudConnect,
+ command: Union[str, bytes],
+ main_socket: bool = True,
+ ) -> SQCloudResult:
+ self._internal_socket_write(connection, command, main_socket)
+ return self._internal_socket_read(connection, main_socket)
def _internal_socket_write(
- self, connection: SQCloudConnect, command: Union[str, bytes]
+ self,
+ connection: SQCloudConnect,
+ command: Union[str, bytes],
+ main_socket: bool = True,
) -> None:
# compute header
delimit = "$" if connection.isblob else "+"
@@ -149,27 +437,31 @@ def _internal_socket_write(
buffer_len = len(buffer)
header = f"{delimit}{buffer_len} "
+ sock = connection.socket if main_socket else connection.pubsub_socket
+
# write header
try:
- connection.socket.sendall(header.encode())
+ sock.sendall(header.encode())
except Exception as exc:
raise SQCloudException(
"An error occurred while writing header data.",
- SQCLOUD_INTERNAL_ERRCODE.INTERNAL_ERRCODE_NETWORK,
+ SQCLOUD_INTERNAL_ERRCODE.NETWORK,
) from exc
# write buffer
if buffer_len == 0:
return
try:
- connection.socket.sendall(buffer)
+ sock.sendall(buffer)
except Exception as exc:
raise SQCloudException(
"An error occurred while writing data.",
- SQCLOUD_INTERNAL_ERRCODE.INTERNAL_ERRCODE_NETWORK,
+ SQCLOUD_INTERNAL_ERRCODE.NETWORK,
) from exc
- def _internal_socket_read(self, connection: SQCloudConnect) -> SQCloudResult:
+ def _internal_socket_read(
+ self, connection: SQCloudConnect, main_socket: bool = True
+ ) -> SQCloudResult:
"""
Read from the socket and parse the response.
@@ -182,15 +474,17 @@ def _internal_socket_read(self, connection: SQCloudConnect) -> SQCloudResult:
buffer_size = 8192
nread = 0
+ sock = connection.socket if main_socket else connection.pubsub_socket
+
while True:
try:
- data = connection.socket.recv(buffer_size)
+ data = sock.recv(buffer_size)
if not data:
raise SQCloudException("Incomplete response from server.")
except Exception as exc:
raise SQCloudException(
"An error occurred while reading data from the socket.",
- SQCLOUD_INTERNAL_ERRCODE.INTERNAL_ERRCODE_NETWORK,
+ SQCLOUD_INTERNAL_ERRCODE.NETWORK,
) from exc
# the expected data length to read
@@ -275,7 +569,7 @@ def _internal_parse_buffer(
# check OK value
if buffer == b"+2 OK":
- return SQCloudResult(True)
+ return SQCloudResult(SQCLOUD_RESULT_TYPE.RESULT_OK, True)
cmd = chr(buffer[0])
@@ -306,7 +600,13 @@ def _internal_parse_buffer(
len_ = sqlite_number.value
cstart = sqlite_number.cstart
if len_ == 0:
- return SQCloudResult("")
+ return SQCloudResult(SQCLOUD_RESULT_TYPE.RESULT_STRING, "")
+
+ tag = (
+ SQCLOUD_RESULT_TYPE.RESULT_JSON
+ if cmd == SQCLOUD_CMD.JSON.value
+ else SQCLOUD_RESULT_TYPE.RESULT_STRING
+ )
if cmd == SQCLOUD_CMD.ZEROSTRING.value:
len_ -= 1
@@ -315,14 +615,23 @@ def _internal_parse_buffer(
if cmd == SQCLOUD_CMD.COMMAND.value:
return self._internal_run_command(connection, clone)
elif cmd == SQCLOUD_CMD.PUBSUB.value:
- return SQCloudResult(self._internal_setup_pubsub(clone))
+ return SQCloudResult(
+ SQCLOUD_RESULT_TYPE.RESULT_OK,
+ self._internal_setup_pubsub(connection, clone),
+ )
elif cmd == SQCLOUD_CMD.RECONNECT.value:
- return SQCloudResult(self._internal_reconnect(clone))
+ return SQCloudResult(
+ SQCLOUD_RESULT_TYPE.RESULT_OK, self._internal_reconnect(clone)
+ )
elif cmd == SQCLOUD_CMD.ARRAY.value:
- return SQCloudResult(self._internal_parse_array(clone))
+ return SQCloudResult(
+ SQCLOUD_RESULT_TYPE.RESULT_ARRAY, self._internal_parse_array(clone)
+ )
+ elif cmd == SQCLOUD_CMD.BLOB.value:
+ tag = SQCLOUD_RESULT_TYPE.RESULT_BLOB
clone = clone.decode() if cmd != SQCLOUD_CMD.BLOB.value else clone
- return SQCloudResult(clone)
+ return SQCloudResult(tag, clone)
elif cmd == SQCLOUD_CMD.ERROR.value:
# -LEN ERRCODE:EXTCODE ERRMSG
@@ -376,25 +685,29 @@ def _internal_parse_buffer(
return rowset
elif cmd == SQCLOUD_CMD.NULL.value:
- return None
+ return SQCloudResult(SQCLOUD_RESULT_TYPE.RESULT_NONE, None)
elif cmd in [SQCLOUD_CMD.INT.value, SQCLOUD_CMD.FLOAT.value]:
sqcloud_value = self._internal_parse_value(buffer)
clone = sqcloud_value.value
+ tag = (
+ SQCLOUD_RESULT_TYPE.RESULT_INTEGER
+ if cmd == SQCLOUD_CMD.INT.value
+ else SQCLOUD_RESULT_TYPE.RESULT_FLOAT
+ )
+
if clone is None:
- return SQCloudResult(0)
+ return SQCloudResult(tag, 0)
if cmd == SQCLOUD_CMD.INT.value:
- return SQCloudResult(int(clone))
- return SQCloudResult(float(clone))
+ return SQCloudResult(tag, int(clone))
+ return SQCloudResult(tag, float(clone))
elif cmd == SQCLOUD_CMD.RAWJSON.value:
- # TODO: isn't implemented in C?
- return SQCloudResult(None)
+ return SQCloudResult(SQCLOUD_RESULT_TYPE.RESULT_NONE, None)
- # TODO: exception here?
- return SQCloudResult(None)
+ return SQCloudResult(SQCLOUD_RESULT_TYPE.RESULT_NONE, None)
def _internal_uncompress_data(self, buffer: bytes) -> Optional[bytes]:
"""
@@ -545,7 +858,7 @@ def _internal_parse_rowset(
# idx == 1 means first chunk for chunked rowset
first_chunk = (ischunk and idx == 1) or (not ischunk and idx == 0)
if first_chunk:
- rowset = SQCloudResult()
+ rowset = SQCloudResult(SQCLOUD_RESULT_TYPE.RESULT_ROWSET)
rowset.nrows = nrows
rowset.ncols = ncols
rowset.version = version
diff --git a/src/sqlitecloud/pubsub.py b/src/sqlitecloud/pubsub.py
new file mode 100644
index 0000000..148a7c2
--- /dev/null
+++ b/src/sqlitecloud/pubsub.py
@@ -0,0 +1,70 @@
+import socket
+from sqlite3 import connect
+from typing import Callable, Optional
+from sqlitecloud.driver import Driver
+from sqlitecloud.resultset import SqliteCloudResultSet
+from sqlitecloud.types import SQCLOUD_PUBSUB_SUBJECT, SQCloudConnect
+
+
+class SqliteCloudPubSub:
+ def __init__(self) -> None:
+ self._driver = Driver()
+
+ def listen(
+ self,
+ connection: SQCloudConnect,
+ subject_type: SQCLOUD_PUBSUB_SUBJECT,
+ subject_name: str,
+ callback: Callable[
+ [SQCloudConnect, Optional[SqliteCloudResultSet], Optional[any]], None
+ ],
+ data: Optional[any] = None,
+ ) -> None:
+ subject = "TABLE " if subject_type.value == "TABLE" else ""
+
+ connection.pubsub_callback = callback
+ connection.pubsub_data = data
+
+ self._driver.execute(f"LISTEN {subject}{subject_name};", connection)
+
+ def unlisten(
+ self,
+ connection: SQCloudConnect,
+ subject_type: SQCLOUD_PUBSUB_SUBJECT,
+ subject_name: str,
+ ) -> None:
+ subject = "TABLE " if subject_type.value == "TABLE" else ""
+
+ self._driver.execute(f"UNLISTEN {subject}{subject_name};", connection)
+
+ connection.pubsub_callback = None
+ connection.pubsub_data = None
+
+ def create_channel(
+ self, connection: SQCloudConnect, name: str, if_not_exists: bool = False
+ ) -> None:
+ if if_not_exists:
+ self._driver.execute(f"CREATE CHANNEL {name} IF NOT EXISTS;", connection)
+ else:
+ self._driver.execute(f"CREATE CHANNEL {name};", connection)
+
+ def notify_channel(self, connection: SQCloudConnect, name: str, data: str) -> None:
+ self._driver.execute(f"NOTIFY {name} '{data}';", connection)
+
+ def set_pubsub_only(self, connection: SQCloudConnect) -> None:
+ """
+ Close the main socket, leaving only the pub/sub socket opened and ready
+ to receive incoming notifications from subscripted channels and tables.
+
+ Connection is no longer able to send commands.
+ """
+ self._driver.execute("PUBSUB ONLY;", connection)
+ self._driver.disconnect(connection, only_main_socket=True)
+
+ def is_connected(self, connection: SQCloudConnect) -> bool:
+ return self._driver.is_connected(connection, False)
+
+ def list_connections(self, connection: SQCloudConnect) -> SqliteCloudResultSet:
+ return SqliteCloudResultSet(
+ self._driver.execute("LIST PUBSUB CONNECTIONS;", connection)
+ )
diff --git a/src/sqlitecloud/resultset.py b/src/sqlitecloud/resultset.py
index 3d35d94..0220665 100644
--- a/src/sqlitecloud/resultset.py
+++ b/src/sqlitecloud/resultset.py
@@ -1,8 +1,11 @@
from typing import Any, Dict, List, Optional
+from sqlitecloud.types import SQCLOUD_RESULT_TYPE
+
class SQCloudResult:
- def __init__(self, result: Optional[any] = None) -> None:
+ def __init__(self, tag: SQCLOUD_RESULT_TYPE, result: Optional[any] = None) -> None:
+ self.tag: SQCLOUD_RESULT_TYPE = tag
self.nrows: int = 0
self.ncols: int = 0
self.version: int = 0
@@ -34,7 +37,7 @@ def __init__(self, result: SQCloudResult) -> None:
self._iter_row: int = 0
self._result: SQCloudResult = result
- def __getattr__(self, attr: str) -> Any:
+ def __getattr__(self, attr: str) -> Optional[Any]:
return getattr(self._result, attr)
def __iter__(self):
diff --git a/src/sqlitecloud/types.py b/src/sqlitecloud/types.py
index 118ad1e..6fb1aec 100644
--- a/src/sqlitecloud/types.py
+++ b/src/sqlitecloud/types.py
@@ -1,8 +1,17 @@
+from asyncio import AbstractEventLoop
from enum import Enum
-from typing import Optional
+from threading import Thread
+import types
+from typing import Callable, Optional
from enum import Enum
+class SQCLOUD_DEFAULT(Enum):
+ PORT = 8860
+ TIMEOUT = 12
+ UPLOAD_SIZE = 512 * 1024
+
+
class SQCLOUD_CMD(Enum):
STRING = "+"
ZEROSTRING = "!"
@@ -27,18 +36,48 @@ class SQCLOUD_ROWSET(Enum):
class SQCLOUD_INTERNAL_ERRCODE(Enum):
- INTERNAL_ERRCODE_NONE = 0
- INTERNAL_ERRCODE_NETWORK = 100005
+ """
+ Clients error codes.
+ """
+
+ NONE = 0
+ NETWORK = 100005
+
+
+class SQCLOUD_ERRCODE(Enum):
+ """
+ Error codes from Sqlite Cloud.
+ """
+ MEM = 10000
+ NOTFOUND = 10001
+ COMMAND = 10002
+ INTERNAL = 10003
+ AUTH = 10004
+ GENERIC = 10005
+ RAFT = 10006
+
+
+class SQCLOUD_RESULT_TYPE(Enum):
+ RESULT_OK = 0
+ RESULT_ERROR = 1
+ RESULT_STRING = 2
+ RESULT_INTEGER = 3
+ RESULT_FLOAT = 4
+ RESULT_ROWSET = 5
+ RESULT_ARRAY = 6
+ RESULT_NONE = 7
+ RESULT_JSON = 8
+ RESULT_BLOB = 9
+
+
+class SQCLOUD_PUBSUB_SUBJECT(Enum):
+ """
+ Subjects that can be subscribed to by PubSub.
+ """
-class SQCLOUD_CLOUD_ERRCODE(Enum):
- CLOUD_ERRCODE_MEM = 10000
- CLOUD_ERRCODE_NOTFOUND = 10001
- CLOUD_ERRCODE_COMMAND = 10002
- CLOUD_ERRCODE_INTERNAL = 10003
- CLOUD_ERRCODE_AUTH = 10004
- CLOUD_ERRCODE_GENERIC = 10005
- CLOUD_ERRCODE_RAFT = 10006
+ TABLE = "TABLE"
+ CHANNEL = "CHANNEL"
class SQCloudRowsetSignature:
@@ -62,7 +101,7 @@ def __init__(
password: Optional[str] = "",
hostname: Optional[str] = "",
dbname: Optional[str] = "",
- port: Optional[int] = 8860,
+ port: Optional[int] = SQCLOUD_DEFAULT.PORT.value,
apikey: Optional[str] = "",
) -> None:
# User name is required unless connectionstring is provided
@@ -90,6 +129,13 @@ def __init__(self):
self.config: SQCloudConfig
self.isblob: bool = False
+ self.pubsub_socket: any = None
+ self.pubsub_callback: Callable[
+ [SQCloudConnect, Optional[types.SqliteCloudResultSet], Optional[any]], None
+ ] = None
+ self.pubsub_data: any = None
+ self.pubsub_thread: AbstractEventLoop = None
+
class SQCloudConfig:
def __init__(self) -> None:
@@ -98,7 +144,7 @@ def __init__(self) -> None:
# Optional query timeout passed directly to TLS socket
self.timeout = 0
# Socket connection timeout
- self.connect_timeout = 20
+ self.connect_timeout = SQCLOUD_DEFAULT.TIMEOUT.value
# Enable compression
self.compression = False
@@ -131,9 +177,7 @@ def __init__(self) -> None:
class SQCloudException(Exception):
- def __init__(
- self, message: str, code: Optional[int] = -1, xerrcode: Optional[int] = 0
- ) -> None:
+ def __init__(self, message: str, code: int = -1, xerrcode: int = 0) -> None:
self.errmsg = str(message)
self.errcode = code
self.xerrcode = xerrcode
diff --git a/src/sqlitecloud/upload.py b/src/sqlitecloud/upload.py
new file mode 100644
index 0000000..9f90a7e
--- /dev/null
+++ b/src/sqlitecloud/upload.py
@@ -0,0 +1,65 @@
+from io import BufferedReader
+import os
+from typing import Optional
+from sqlitecloud.driver import Driver
+from sqlitecloud.types import SQCloudConnect
+import logging
+
+def xCallback(fd: BufferedReader, blen: int, ntot: int, nprogress: int) -> bytes:
+ """
+ Callback function used for uploading data.
+
+ Args:
+ fd (BufferedReader): The file descriptor to read data from.
+ blen (int): The length of the buffer to read.
+ ntot (int): The total number of bytes to be uploaded.
+ nprogress (int): The number of bytes already uploaded.
+
+ Returns:
+ bytes: The buffer containing the read data.
+ """
+ buffer = fd.read(blen)
+ nread = len(buffer)
+
+ if nread == 0:
+ logging.log(logging.DEBUG, "UPLOAD COMPLETE\n\n")
+ else:
+ logging.log(logging.DEBUG, f"{(nprogress + nread) / ntot * 100:.2f}%")
+
+ return buffer
+
+
+def upload_db(
+ connection: SQCloudConnect, dbname: str, key: Optional[str], filename: str
+) -> None:
+ """
+ Uploads a SQLite database to the SQLite Cloud node using the provided connection.
+
+ Args:
+ connection (SQCloudConnect): The connection object used to connect to the node.
+ dbname (str): The name of the database in SQLite Cloud.
+ key (Optional[str]): The encryption key for the database. If None, no encryption is used.
+ filename (str): The path to the SQLite database file to be uploaded.
+
+ Raises:
+ SQCloudException: If an error occurs while uploading the database.
+
+ """
+
+ # Create a driver object
+ driver = Driver()
+
+ with open(filename, 'rb') as fd:
+ dbsize = os.path.getsize(filename)
+
+ driver.upload_database(
+ connection,
+ dbname,
+ key,
+ False,
+ 0,
+ False,
+ fd,
+ dbsize,
+ xCallback,
+ )
diff --git a/src/tests/assets/test.db b/src/tests/assets/test.db
new file mode 100644
index 0000000..cf7714a
Binary files /dev/null and b/src/tests/assets/test.db differ
diff --git a/src/tests/conftest.py b/src/tests/conftest.py
index 434c04b..cf8b266 100644
--- a/src/tests/conftest.py
+++ b/src/tests/conftest.py
@@ -1,6 +1,29 @@
+import os
import pytest
from dotenv import load_dotenv
+from sqlitecloud.client import SqliteCloudClient
+from sqlitecloud.types import SQCloudConnect, SqliteCloudAccount
+
@pytest.fixture(autouse=True)
def load_env_vars():
load_dotenv(".env")
+
+@pytest.fixture()
+def sqlitecloud_connection():
+ account = SqliteCloudAccount()
+ account.username = os.getenv("SQLITE_USER")
+ account.password = os.getenv("SQLITE_PASSWORD")
+ account.dbname = os.getenv("SQLITE_DB")
+ account.hostname = os.getenv("SQLITE_HOST")
+ account.port = 8860
+
+ client = SqliteCloudClient(cloud_account=account)
+
+ connection = client.open_connection()
+ assert isinstance(connection, SQCloudConnect)
+ assert client.is_connected(connection)
+
+ yield (connection, client)
+
+ client.disconnect(connection)
\ No newline at end of file
diff --git a/src/tests/integration/test_client.py b/src/tests/integration/test_client.py
index b4f17bb..1ae2622 100644
--- a/src/tests/integration/test_client.py
+++ b/src/tests/integration/test_client.py
@@ -1,4 +1,5 @@
import json
+from multiprocessing import connection
import os
import sqlite3
import tempfile
@@ -7,8 +8,9 @@
import pytest
from sqlitecloud.client import SqliteCloudClient
from sqlitecloud.types import (
- SQCLOUD_CLOUD_ERRCODE,
+ SQCLOUD_ERRCODE,
SQCLOUD_INTERNAL_ERRCODE,
+ SQCLOUD_RESULT_TYPE,
SQCloudConnect,
SQCloudException,
SqliteCloudAccount,
@@ -22,24 +24,6 @@ class TestClient:
# Will except queries to be quicker than this
EXPECT_SPEED_MS = 6 * 1000
- @pytest.fixture()
- def sqlitecloud_connection(self):
- account = SqliteCloudAccount()
- account.username = os.getenv("SQLITE_USER")
- account.password = os.getenv("SQLITE_PASSWORD")
- account.dbname = os.getenv("SQLITE_DB")
- account.hostname = os.getenv("SQLITE_HOST")
- account.port = 8860
-
- client = SqliteCloudClient(cloud_account=account)
-
- connection = client.open_connection()
- assert isinstance(connection, SQCloudConnect)
-
- yield (connection, client)
-
- client.disconnect(connection)
-
def test_connection_with_credentials(self):
account = SqliteCloudAccount()
account.username = os.getenv("SQLITE_USER")
@@ -97,6 +81,39 @@ def test_connect_with_string_with_credentials(self):
client.disconnect(conn)
+ def test_is_connected(self):
+ account = SqliteCloudAccount()
+ account.username = os.getenv("SQLITE_API_KEY")
+ account.hostname = os.getenv("SQLITE_HOST")
+ account.port = 8860
+
+ client = SqliteCloudClient(cloud_account=account)
+
+ conn = client.open_connection()
+ assert client.is_connected(conn) == True
+
+ client.disconnect(conn)
+ assert client.is_connected(conn) == False
+
+ def test_disconnect(self):
+ account = SqliteCloudAccount()
+ account.username = os.getenv("SQLITE_API_KEY")
+ account.hostname = os.getenv("SQLITE_HOST")
+ account.port = 8860
+
+ client = SqliteCloudClient(cloud_account=account)
+
+ conn = client.open_connection()
+ assert client.is_connected(conn) == True
+
+ client.disconnect(conn)
+ assert client.is_connected(conn) == False
+ assert conn.socket is None
+ assert conn.pubsub_socket is None
+
+ # disconnecting a second time should not raise an exception
+ client.disconnect(conn)
+
def test_select(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection
@@ -119,6 +136,7 @@ def test_rowset_data(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection
result = client.exec_query("SELECT AlbumId FROM albums LIMIT 2", connection)
+ assert SQCLOUD_RESULT_TYPE.RESULT_ROWSET == result.tag
assert 2 == result.nrows
assert 1 == result.ncols
assert 2 == result.version
@@ -178,33 +196,38 @@ def test_integer(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection
result = client.exec_query("TEST INTEGER", connection)
+ assert SQCLOUD_RESULT_TYPE.RESULT_INTEGER == result.tag
assert 123456 == result.get_result()
def test_float(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection
result = client.exec_query("TEST FLOAT", connection)
+ assert SQCLOUD_RESULT_TYPE.RESULT_FLOAT == result.tag
assert 3.1415926 == result.get_result()
def test_string(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection
result = client.exec_query("TEST STRING", connection)
- assert "Hello World, this is a test string." == result.get_result()
+ assert SQCLOUD_RESULT_TYPE.RESULT_STRING == result.tag
+ assert result.get_result() == "Hello World, this is a test string."
def test_zero_string(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection
result = client.exec_query("TEST ZERO_STRING", connection)
+ assert SQCLOUD_RESULT_TYPE.RESULT_STRING == result.tag
assert (
- "Hello World, this is a zero-terminated test string." == result.get_result()
+ result.get_result() == "Hello World, this is a zero-terminated test string."
)
def test_empty_string(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection
result = client.exec_query("TEST STRING0", connection)
- assert "" == result.get_result()
+ assert SQCLOUD_RESULT_TYPE.RESULT_STRING == result.tag
+ assert result.get_result() == ""
def test_command(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection
@@ -216,6 +239,7 @@ def test_json(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection
result = client.exec_query("TEST JSON", connection)
+ assert SQCLOUD_RESULT_TYPE.RESULT_JSON == result.tag
assert {
"msg-from": {"class": "soldier", "name": "Wixilav"},
"msg-to": {"class": "supreme-commander", "name": "[Redacted]"},
@@ -232,13 +256,15 @@ def test_blob(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection
result = client.exec_query("TEST BLOB", connection)
- assert 1000 == len(result.get_result())
+ assert SQCLOUD_RESULT_TYPE.RESULT_BLOB == result.tag
+ assert len(result.get_result()) == 1000
def test_blob0(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection
result = client.exec_query("TEST BLOB0", connection)
- assert 0 == len(result.get_result())
+ assert SQCLOUD_RESULT_TYPE.RESULT_STRING == result.tag
+ assert len(result.get_result()) == 0
def test_error(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection
@@ -246,8 +272,8 @@ def test_error(self, sqlitecloud_connection):
with pytest.raises(SQCloudException) as e:
client.exec_query("TEST ERROR", connection)
- assert 66666 == e.value.errcode
- assert "This is a test error message with a devil error code." == e.value.errmsg
+ assert e.value.errcode == 66666
+ assert e.value.errmsg == "This is a test error message with a devil error code."
def test_ext_error(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection
@@ -255,11 +281,11 @@ def test_ext_error(self, sqlitecloud_connection):
with pytest.raises(SQCloudException) as e:
client.exec_query("TEST EXTERROR", connection)
- assert 66666 == e.value.errcode
- assert 333 == e.value.xerrcode
+ assert e.value.errcode == 66666
+ assert e.value.xerrcode == 333
assert (
- "This is a test error message with an extcode and a devil error code."
- == e.value.errmsg
+ e.value.errmsg
+ == "This is a test error message with an extcode and a devil error code."
)
def test_array(self, sqlitecloud_connection):
@@ -267,6 +293,8 @@ def test_array(self, sqlitecloud_connection):
result = client.exec_query("TEST ARRAY", connection)
result_array = result.get_result()
+
+ assert SQCLOUD_RESULT_TYPE.RESULT_ARRAY == result.tag
assert isinstance(result_array, list)
assert len(result_array) == 5
assert result_array[0] == "Hello World"
@@ -278,6 +306,7 @@ def test_rowset(self, sqlitecloud_connection):
connection, client = sqlitecloud_connection
result = client.exec_query("TEST ROWSET", connection)
+ assert SQCLOUD_RESULT_TYPE.RESULT_ROWSET == result.tag
assert result.nrows >= 30
assert result.ncols == 2
assert result.version in [1, 2]
@@ -293,12 +322,17 @@ def test_max_rows_option(self):
client = SqliteCloudClient(cloud_account=account)
client.config.maxrows = 1
- rowset = client.exec_query("TEST ROWSET_CHUNK")
+ connection = client.open_connection()
+
+ rowset = client.exec_query("TEST ROWSET_CHUNK", connection)
+
+ client.disconnect(connection)
# maxrows cannot be tested at this level.
# just expect everything is ok
assert rowset.nrows > 100
+
def test_max_rowset_option_to_fail_when_rowset_is_bigger(self):
account = SqliteCloudAccount()
account.hostname = os.getenv("SQLITE_HOST")
@@ -308,12 +342,17 @@ def test_max_rowset_option_to_fail_when_rowset_is_bigger(self):
client = SqliteCloudClient(cloud_account=account)
client.config.maxrowset = 1024
+ connection = client.open_connection()
+
with pytest.raises(SQCloudException) as e:
- client.exec_query("SELECT * FROM albums")
+ client.exec_query("SELECT * FROM albums", connection)
- assert SQCLOUD_CLOUD_ERRCODE.CLOUD_ERRCODE_INTERNAL.value == e.value.errcode
+ client.disconnect(connection)
+
+ assert SQCLOUD_ERRCODE.INTERNAL.value == e.value.errcode
assert "RowSet too big to be sent (limit set to 1024 bytes)." == e.value.errmsg
+
def test_max_rowset_option_to_succeed_when_rowset_is_lighter(self):
account = SqliteCloudAccount()
account.hostname = os.getenv("SQLITE_HOST")
@@ -323,7 +362,11 @@ def test_max_rowset_option_to_succeed_when_rowset_is_lighter(self):
client = SqliteCloudClient(cloud_account=account)
client.config.maxrowset = 1024
- rowset = client.exec_query("SELECT 'hello world'")
+ connection = client.open_connection()
+
+ rowset = client.exec_query("SELECT 'hello world'", connection)
+
+ client.disconnect(connection)
assert 1 == rowset.nrows
@@ -332,6 +375,7 @@ def test_chunked_rowset(self, sqlitecloud_connection):
rowset = client.exec_query("TEST ROWSET_CHUNK", connection)
+ assert SQCLOUD_RESULT_TYPE.RESULT_ROWSET == rowset.tag
assert 147 == rowset.nrows
assert 1 == rowset.ncols
assert 147 == len(rowset.data)
@@ -381,6 +425,8 @@ def test_query_timeout(self):
client = SqliteCloudClient(cloud_account=account)
client.config.timeout = 1 # 1 sec
+ connection = client.open_connection()
+
# this operation should take more than 1 sec
with pytest.raises(SQCloudException) as e:
# just a long running query
@@ -392,10 +438,13 @@ def test_query_timeout(self):
SELECT i FROM r
LIMIT 10000000
)
- SELECT i FROM r WHERE i = 1;"""
+ SELECT i FROM r WHERE i = 1;""",
+ connection
)
- assert e.value.errcode == SQCLOUD_INTERNAL_ERRCODE.INTERNAL_ERRCODE_NETWORK
+ client.disconnect(connection)
+
+ assert e.value.errcode == SQCLOUD_INTERNAL_ERRCODE.NETWORK
assert e.value.errmsg == "An error occurred while reading data from the socket."
def test_XXL_query(self, sqlitecloud_connection):
@@ -470,7 +519,11 @@ def test_select_database(self):
client = SqliteCloudClient(cloud_account=account)
- rowset = client.exec_query("USE DATABASE chinook.sqlite")
+ connection = client.open_connection()
+
+ rowset = client.exec_query("USE DATABASE chinook.sqlite", connection)
+
+ client.disconnect(connection)
assert rowset.get_result()
@@ -554,36 +607,6 @@ def test_stress_test_20x_batched_selects(self, sqlitecloud_connection):
query_ms < self.EXPECT_SPEED_MS
), f"{num_queries}x batched selects, {query_ms}ms per query"
- def test_download_database(self, sqlitecloud_connection):
- connection, client = sqlitecloud_connection
-
- rowset = client.exec_query(
- "DOWNLOAD DATABASE " + os.getenv("SQLITE_DB"), connection
- )
-
- result_array = rowset.get_result()
-
- db_size = int(result_array[0])
-
- tot_read = 0
- data: bytes = b""
- while tot_read < db_size:
- result = client.exec_query("DOWNLOAD STEP;", connection)
-
- data += result.get_result()
- tot_read += len(data)
-
- temp_file = tempfile.mkstemp(prefix="chinook")[1]
- with open(temp_file, "wb") as f:
- f.write(data)
-
- db = sqlite3.connect(temp_file)
- cursor = db.execute("SELECT * FROM albums")
- rowset = cursor.fetchall()
-
- assert cursor.description[0][0] == "AlbumId"
- assert cursor.description[1][0] == "Title"
-
def test_compression_single_column(self):
account = SqliteCloudAccount()
account.hostname = os.getenv("SQLITE_HOST")
@@ -593,13 +616,17 @@ def test_compression_single_column(self):
client = SqliteCloudClient(cloud_account=account)
client.config.compression = True
+ connection = client.open_connection()
+
# min compression size for rowset set by default to 20400 bytes
blob_size = 20 * 1024
# rowset = client.exec_query("SELECT * from albums inner join albums a2 on albums.AlbumId = a2.AlbumId")
rowset = client.exec_query(
- f"SELECT hex(randomblob({blob_size})) AS 'someColumnName'"
+ f"SELECT hex(randomblob({blob_size})) AS 'someColumnName'", connection
)
+ client.disconnect(connection)
+
assert rowset.nrows == 1
assert rowset.ncols == 1
assert rowset.get_name(0) == "someColumnName"
@@ -614,11 +641,15 @@ def test_compression_multiple_columns(self):
client = SqliteCloudClient(cloud_account=account)
client.config.compression = True
+ connection = client.open_connection()
+
# min compression size for rowset set by default to 20400 bytes
rowset = client.exec_query(
- "SELECT * from albums inner join albums a2 on albums.AlbumId = a2.AlbumId"
+ "SELECT * from albums inner join albums a2 on albums.AlbumId = a2.AlbumId", connection
)
+ client.disconnect(connection)
+
assert rowset.nrows > 0
assert rowset.ncols > 0
assert rowset.get_name(0) == "AlbumId"
diff --git a/src/tests/integration/test_download.py b/src/tests/integration/test_download.py
new file mode 100644
index 0000000..513eaaa
--- /dev/null
+++ b/src/tests/integration/test_download.py
@@ -0,0 +1,34 @@
+import os
+import sqlite3
+import tempfile
+
+import pytest
+
+from sqlitecloud import download
+from sqlitecloud.client import SqliteCloudClient
+from sqlitecloud.types import SQCLOUD_ERRCODE, SQCloudConnect, SQCloudException, SqliteCloudAccount
+
+
+class TestDownload:
+ def test_download_database(self, sqlitecloud_connection):
+ connection, _ = sqlitecloud_connection
+
+ temp_file = tempfile.mkstemp(prefix="chinook")[1]
+ download.download_db(connection, "chinook.sqlite", temp_file)
+
+ db = sqlite3.connect(temp_file)
+ cursor = db.execute("SELECT * FROM albums")
+
+ assert cursor.description[0][0] == "AlbumId"
+ assert cursor.description[1][0] == "Title"
+
+ def test_download_missing_database(self, sqlitecloud_connection):
+ connection, _ = sqlitecloud_connection
+
+ temp_file = tempfile.mkstemp(prefix="missing")[1]
+
+ with pytest.raises(SQCloudException) as e:
+ download.download_db(connection, "missing.sqlite", temp_file)
+
+ assert e.value.errcode == SQCLOUD_ERRCODE.COMMAND.value
+ assert e.value.errmsg == "Database missing.sqlite does not exist."
\ No newline at end of file
diff --git a/src/tests/integration/test_driver.py b/src/tests/integration/test_driver.py
index c631829..ac1376e 100644
--- a/src/tests/integration/test_driver.py
+++ b/src/tests/integration/test_driver.py
@@ -1,90 +1,15 @@
+import tempfile
from sqlitecloud.driver import Driver
-import pytest
-
class TestDriver:
- @pytest.fixture(
- params=[
- (":0 ", 0, 0, 3),
- (":123 ", 123, 0, 5),
- (",123.456 ", 1230456, 0, 9),
- ("-1:1234 ", 1, 1234, 8),
- ("-0:0 ", 0, 0, 5),
- ("-123:456 ", 123, 456, 9),
- ("-123: ", 123, 0, 6),
- ("-1234:5678 ", 1234, 5678, 11),
- ("-1234: ", 1234, 0, 7),
- ]
- )
- def number_data(self, request):
- return request.param
-
- def test_parse_number(self, number_data):
+ def test_download_missing_database_without_error_when_expected(self, sqlitecloud_connection):
driver = Driver()
- buffer, expected_value, expected_extcode, expected_cstart = number_data
- result = driver._internal_parse_number(buffer.encode())
-
- assert expected_value == result.value
- assert expected_extcode == result.extcode
- assert expected_cstart == result.cstart
-
- @pytest.fixture(
- params=[
- ("+5 Hello", "Hello", 5, 8),
- ("+11 Hello World", "Hello World", 11, 15),
- ("!6 Hello0", "Hello", 5, 9),
- ("+0 ", "", 0, 3),
- (":5678 ", "5678", 0, 6),
- (":0 ", "0", 0, 3),
- (",3.14 ", "3.14", 0, 6),
- (",0 ", "0", 0, 3),
- (",0.0 ", "0.0", 0, 5),
- ("_ ", None, 0, 2),
- ],
- ids=[
- "String",
- "String with space",
- "String zero-terminated",
- "Empty string",
- "Integer",
- "Integer zero",
- "Float",
- "Float zero",
- "Float 0.0",
- "Null",
- ],
- )
- def value_data(self, request):
- return request.param
-
- def test_parse_value(self, value_data):
- driver = Driver()
- buffer, expected_value, expected_len, expected_cellsize = value_data
-
- result = driver._internal_parse_value(buffer.encode())
+
+ connection, _ = sqlitecloud_connection
- assert expected_value == result.value
- assert expected_len == result.len
- assert expected_cellsize == result.cellsize
-
- def test_parse_array(self):
- driver = Driver()
- buffer = b"=5 +11 Hello World:123456 ,3.1415 _ $10 0123456789"
- expected_list = ["Hello World", "123456", "3.1415", None, "0123456789"]
-
- result = driver._internal_parse_array(buffer)
-
- assert expected_list == result
-
- def test_parse_rowset_signature(self):
- driver = Driver()
- buffer = b"*35 0:1 1 2 +2 42+7 'hello':42 +5 hello"
+ temp_file = tempfile.mkstemp(prefix="missing")[1]
- result = driver._internal_parse_rowset_signature(buffer)
+ if_exists = True
- assert 12 == result.start
- assert 35 == result.len
- assert 0 == result.idx
- assert 1 == result.version
- assert 1 == result.nrows
- assert 2 == result.ncols
\ No newline at end of file
+ with open(temp_file, "wb") as fd:
+ driver.download_database(connection, "missing.sqlite", fd, lambda x, y, z, k: None, if_exists=if_exists)
\ No newline at end of file
diff --git a/src/tests/integration/test_pubsub.py b/src/tests/integration/test_pubsub.py
new file mode 100644
index 0000000..122df55
--- /dev/null
+++ b/src/tests/integration/test_pubsub.py
@@ -0,0 +1,155 @@
+from time import sleep
+import time
+
+import pytest
+
+from sqlitecloud.pubsub import SqliteCloudPubSub
+from sqlitecloud.resultset import SqliteCloudResultSet
+from sqlitecloud.types import (
+ SQCLOUD_ERRCODE,
+ SQCLOUD_PUBSUB_SUBJECT,
+ SQCLOUD_RESULT_TYPE,
+ SQCloudException,
+)
+
+
+class TestPubSub:
+ def test_listen_channel_and_notify(self, sqlitecloud_connection):
+ connection, _ = sqlitecloud_connection
+
+ callback_called = False
+
+ def assert_callback(conn, result, data):
+ nonlocal callback_called
+
+ if isinstance(result, SqliteCloudResultSet):
+ assert result.tag == SQCLOUD_RESULT_TYPE.RESULT_JSON
+ assert data == ["somedata"]
+ callback_called = True
+
+ pubsub = SqliteCloudPubSub()
+ type = SQCLOUD_PUBSUB_SUBJECT.CHANNEL
+ channel = "channel" + str(int(time.time()))
+
+ pubsub.create_channel(connection, channel)
+ pubsub.listen(connection, type, channel, assert_callback, ["somedata"])
+
+ pubsub.notify_channel(connection, channel, "somedata2")
+
+ # wait for callback to be called
+ sleep(1)
+
+ assert callback_called
+
+ def test_unlisten_channel(self, sqlitecloud_connection):
+ connection, _ = sqlitecloud_connection
+
+ pubsub = SqliteCloudPubSub()
+ type = SQCLOUD_PUBSUB_SUBJECT.CHANNEL
+ channel_name = "channel" + str(int(time.time()))
+
+ pubsub.create_channel(connection, channel_name)
+ pubsub.listen(connection, type, channel_name, lambda conn, result, data: None)
+
+ result = pubsub.list_connections(connection)
+ assert channel_name in result.data
+
+ pubsub.unlisten(connection, type, channel_name)
+
+ result = pubsub.list_connections(connection)
+
+ assert channel_name not in result.data
+ assert connection.pubsub_callback is None
+ assert connection.pubsub_data is None
+
+ def test_create_channel_to_fail_if_exists(self, sqlitecloud_connection):
+ connection, _ = sqlitecloud_connection
+
+ pubsub = SqliteCloudPubSub()
+ channel_name = "channel" + str(int(time.time()))
+
+ pubsub.create_channel(connection, channel_name, if_not_exists=True)
+
+ with pytest.raises(SQCloudException) as e:
+ pubsub.create_channel(connection, channel_name, if_not_exists=False)
+
+ assert (
+ e.value.errmsg
+ == f"Cannot create channel {channel_name} because it already exists."
+ )
+ assert e.value.errcode == SQCLOUD_ERRCODE.GENERIC.value
+
+ def test_is_connected(self, sqlitecloud_connection):
+ connection, _ = sqlitecloud_connection
+
+ pubsub = SqliteCloudPubSub()
+ channel_name = "channel" + str(int(time.time()))
+
+ assert not pubsub.is_connected(connection)
+
+ pubsub.create_channel(connection, channel_name, if_not_exists=True)
+ pubsub.listen(connection, SQCLOUD_PUBSUB_SUBJECT.CHANNEL, channel_name, lambda conn, result, data: None)
+
+ assert pubsub.is_connected(connection)
+
+ def test_set_pubsub_only(self, sqlitecloud_connection):
+ connection, client = sqlitecloud_connection
+
+ callback_called = False
+
+ def assert_callback(conn, result, data):
+ nonlocal callback_called
+
+ if isinstance(result, SqliteCloudResultSet):
+ assert result.get_result() is not None
+ callback_called = True
+
+ pubsub = SqliteCloudPubSub()
+ type = SQCLOUD_PUBSUB_SUBJECT.CHANNEL
+ channel = "channel" + str(int(time.time()))
+
+ pubsub.create_channel(connection, channel, if_not_exists=True)
+ pubsub.listen(connection, type, channel, assert_callback)
+
+ pubsub.set_pubsub_only(connection)
+
+ assert not client.is_connected(connection)
+ assert pubsub.is_connected(connection)
+
+ connection2 = client.open_connection()
+ pubsub2 = SqliteCloudPubSub()
+ pubsub2.notify_channel(connection2, channel, "message-in-a-bottle")
+
+ # wait for callback to be called
+ sleep(2)
+
+ assert callback_called
+
+ client.disconnect(connection2)
+
+ def test_listen_table_for_update(self, sqlitecloud_connection):
+ connection, client = sqlitecloud_connection
+
+ callback_called = False
+
+ def assert_callback(conn, result, data):
+ nonlocal callback_called
+
+ if isinstance(result, SqliteCloudResultSet):
+ assert result.tag == SQCLOUD_RESULT_TYPE.RESULT_JSON
+ assert new_name in result.get_result()
+ assert data == ["somedata"]
+ callback_called = True
+
+ pubsub = SqliteCloudPubSub()
+ type = SQCLOUD_PUBSUB_SUBJECT.TABLE
+ new_name = "Rock"+ str(int(time.time()))
+
+ pubsub.listen(connection, type, "genres", assert_callback, ["somedata"])
+
+ client.exec_query(f"UPDATE genres SET Name = '{new_name}' WHERE GenreId = 1;", connection)
+
+ # wait for callback to be called
+ sleep(1)
+
+ assert callback_called
\ No newline at end of file
diff --git a/src/tests/integration/test_upload.py b/src/tests/integration/test_upload.py
new file mode 100644
index 0000000..eef127f
--- /dev/null
+++ b/src/tests/integration/test_upload.py
@@ -0,0 +1,30 @@
+import os
+import uuid
+import pytest
+from sqlitecloud.client import SqliteCloudClient
+from sqlitecloud.types import SQCloudConnect, SqliteCloudAccount
+from sqlitecloud.upload import upload_db
+
+
+class TestUpload:
+ def test_upload_db(self, sqlitecloud_connection):
+ connection, client = sqlitecloud_connection
+
+ dbname = f"testUploadDb{str(uuid.uuid4())}"
+ key = None
+ filename = os.path.join(os.path.dirname(__file__), "..", "assets", "test.db")
+
+ upload_db(connection, dbname, key, filename)
+
+ try:
+ rowset = client.exec_query(
+ f"USE DATABASE {dbname}; SELECT * FROM contacts", connection
+ )
+
+ assert rowset.nrows == 1
+ assert rowset.ncols == 5
+ assert rowset.get_value(0, 1) == "John"
+ assert rowset.get_name(4) == "phone"
+ finally:
+ # delete uploaded database
+ client.exec_query(f"UNUSE DATABASE; REMOVE DATABASE {dbname}", connection)
diff --git a/src/tests/unit/test_driver.py b/src/tests/unit/test_driver.py
new file mode 100644
index 0000000..750b529
--- /dev/null
+++ b/src/tests/unit/test_driver.py
@@ -0,0 +1,103 @@
+import pytest
+from sqlitecloud.driver import Driver
+
+
+class TestDriver:
+ @pytest.fixture(
+ params=[
+ (":0 ", 0, 0, 3),
+ (":123 ", 123, 0, 5),
+ (",123.456 ", 1230456, 0, 9),
+ ("-1:1234 ", 1, 1234, 8),
+ ("-0:0 ", 0, 0, 5),
+ ("-123:456 ", 123, 456, 9),
+ ("-123: ", 123, 0, 6),
+ ("-1234:5678 ", 1234, 5678, 11),
+ ("-1234: ", 1234, 0, 7),
+ ]
+ )
+ def number_data(self, request):
+ return request.param
+
+ def test_parse_number(self, number_data):
+ driver = Driver()
+ buffer, expected_value, expected_extcode, expected_cstart = number_data
+ result = driver._internal_parse_number(buffer.encode())
+
+ assert expected_value == result.value
+ assert expected_extcode == result.extcode
+ assert expected_cstart == result.cstart
+
+ @pytest.fixture(
+ params=[
+ ("+5 Hello", "Hello", 5, 8),
+ ("+11 Hello World", "Hello World", 11, 15),
+ ("!6 Hello0", "Hello", 5, 9),
+ ("+0 ", "", 0, 3),
+ (":5678 ", "5678", 0, 6),
+ (":0 ", "0", 0, 3),
+ (",3.14 ", "3.14", 0, 6),
+ (",0 ", "0", 0, 3),
+ (",0.0 ", "0.0", 0, 5),
+ ("_ ", None, 0, 2),
+ ],
+ ids=[
+ "String",
+ "String with space",
+ "String zero-terminated",
+ "Empty string",
+ "Integer",
+ "Integer zero",
+ "Float",
+ "Float zero",
+ "Float 0.0",
+ "Null",
+ ],
+ )
+ def value_data(self, request):
+ return request.param
+
+ def test_parse_value(self, value_data):
+ driver = Driver()
+ buffer, expected_value, expected_len, expected_cellsize = value_data
+
+ result = driver._internal_parse_value(buffer.encode())
+
+ assert expected_value == result.value
+ assert expected_len == result.len
+ assert expected_cellsize == result.cellsize
+
+ def test_parse_array(self):
+ driver = Driver()
+ buffer = b"=5 +11 Hello World:123456 ,3.1415 _ $10 0123456789"
+ expected_list = ["Hello World", "123456", "3.1415", None, "0123456789"]
+
+ result = driver._internal_parse_array(buffer)
+
+ assert expected_list == result
+
+ def test_parse_rowset_signature(self):
+ driver = Driver()
+ buffer = b"*35 0:1 1 2 +2 42+7 'hello':42 +5 hello"
+
+ result = driver._internal_parse_rowset_signature(buffer)
+
+ assert 12 == result.start
+ assert 35 == result.len
+ assert 0 == result.idx
+ assert 1 == result.version
+ assert 1 == result.nrows
+ assert 2 == result.ncols
+
+ def test_parse_rowset_signature(self):
+ driver = Driver()
+ buffer = b"*35 0:1 1 2 +2 42+7 'hello':42 +5 hello"
+
+ result = driver._internal_parse_rowset_signature(buffer)
+
+ assert 12 == result.start
+ assert 35 == result.len
+ assert 0 == result.idx
+ assert 1 == result.version
+ assert 1 == result.nrows
+ assert 2 == result.ncols
diff --git a/src/tests/unit/test_resultset.py b/src/tests/unit/test_resultset.py
index 18b36f6..b865e43 100644
--- a/src/tests/unit/test_resultset.py
+++ b/src/tests/unit/test_resultset.py
@@ -1,19 +1,19 @@
import pytest
from sqlitecloud.resultset import SQCloudResult, SqliteCloudResultSet
+from sqlitecloud.types import SQCLOUD_RESULT_TYPE
class TestSqCloudResult:
def test_init_data(self):
- result = SQCloudResult()
+ result = SQCloudResult(SQCLOUD_RESULT_TYPE.RESULT_INTEGER)
result.init_data(42)
assert 1 == result.nrows
assert 1 == result.ncols
assert [42] == result.data
assert True is result.is_result
- # TODO
def test_init_data_with_array(self):
- result = SQCloudResult()
+ result = SQCloudResult(SQCLOUD_RESULT_TYPE.RESULT_ARRAY)
result.init_data([42, 43, 44])
assert 1 == result.nrows
@@ -22,7 +22,8 @@ def test_init_data_with_array(self):
assert True is result.is_result
def test_init_as_dataset(self):
- result = SQCloudResult()
+ result = SQCloudResult(SQCLOUD_RESULT_TYPE.RESULT_ROWSET)
+
assert False is result.is_result
assert 0 == result.nrows
assert 0 == result.ncols
@@ -31,7 +32,7 @@ def test_init_as_dataset(self):
class TestSqliteCloudResultSet:
def test_next(self):
- result = SQCloudResult(result=42)
+ result = SQCloudResult(SQCLOUD_RESULT_TYPE.RESULT_INTEGER, result=42)
result_set = SqliteCloudResultSet(result)
assert {"result": 42} == next(result_set)
@@ -39,13 +40,13 @@ def test_next(self):
next(result_set)
def test_iter_result(self):
- result = SQCloudResult(result=42)
+ result = SQCloudResult(SQCLOUD_RESULT_TYPE.RESULT_INTEGER, result=42)
result_set = SqliteCloudResultSet(result)
for row in result_set:
assert {"result": 42} == row
def test_iter_rowset(self):
- rowset = SQCloudResult()
+ rowset = SQCloudResult(SQCLOUD_RESULT_TYPE.RESULT_ROWSET)
rowset.nrows = 2
rowset.ncols = 2
rowset.colname = ["name", "age"]
@@ -62,7 +63,7 @@ def test_iter_rowset(self):
assert {"name": "Doe", "age": 24} == out[1]
def test_get_value_with_rowset(self):
- rowset = SQCloudResult()
+ rowset = SQCloudResult(SQCLOUD_RESULT_TYPE.RESULT_ROWSET)
rowset.nrows = 2
rowset.ncols = 2
rowset.colname = ["name", "age"]
@@ -75,13 +76,13 @@ def test_get_value_with_rowset(self):
assert None == result_set.get_value(2, 2)
def test_get_value_array(self):
- result = SQCloudResult(result=[1, 2, 3])
+ result = SQCloudResult(SQCLOUD_RESULT_TYPE.RESULT_ARRAY, result=[1, 2, 3])
result_set = SqliteCloudResultSet(result)
assert [1, 2, 3] == result_set.get_value(0, 0)
def test_get_colname(self):
- result = SQCloudResult()
+ result = SQCloudResult(SQCLOUD_RESULT_TYPE.RESULT_ROWSET)
result.ncols = 2
result.colname = ["name", "age"]
result_set = SqliteCloudResultSet(result)
@@ -91,7 +92,7 @@ def test_get_colname(self):
assert None == result_set.get_name(2)
def test_get_result_with_single_value(self):
- result = SQCloudResult(result=42)
+ result = SQCloudResult(SQCLOUD_RESULT_TYPE.RESULT_INTEGER, result=42)
result_set = SqliteCloudResultSet(result)
assert 42 == result_set.get_result()