Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions dissect/database/sqlite3/sqlite3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@

if TYPE_CHECKING:
from collections.abc import Iterator
from types import TracebackType

from typing_extensions import Self

ENCODING = {
1: "utf-8",
Expand Down Expand Up @@ -78,13 +80,11 @@ def __init__(
wal: WAL | Path | BinaryIO | None = None,
checkpoint: Checkpoint | int | None = None,
):
# Use the provided file handle or try to open the file path.
if hasattr(fh, "read"):
name = getattr(fh, "name", None)
path = Path(name) if name else None
else:
if isinstance(fh, Path):
path = fh
fh = path.open("rb")
else:
path = None

self.fh = fh
self.path = path
Expand All @@ -105,12 +105,21 @@ def __init__(
raise InvalidDatabase("Usable page size is too small")

if wal:
self.wal = WAL(wal) if not isinstance(wal, WAL) else wal
elif path:
self.wal = wal if isinstance(wal, WAL) else WAL(wal)
else:
# Check for WAL sidecar next to the DB.
wal_path = path.with_name(f"{path.name}-wal")
if wal_path.exists():
self.wal = WAL(wal_path)
# If we have a path, we can deduce the WAL path.
# If we don't have a path, we can try to get it from the file handle.
if path is None:
# By deducing the path at this point and not earlier, we can keep the original passed
# path to indicate if we should close the file handle later on.
name = getattr(fh, "name", None)
path = Path(name) if name else None

if path is not None:
wal_path = path.with_name(f"{path.name}-wal")
if wal_path.exists() and wal_path.stat().st_size > 0:
self.wal = WAL(wal_path)

# If a checkpoint index was provided, resolve it to a Checkpoint object.
if self.wal and isinstance(checkpoint, int):
Expand All @@ -122,6 +131,23 @@ def __init__(

self.page = lru_cache(256)(self.page)

def __enter__(self) -> Self:
"""Return ``self`` upon entering the runtime context."""
return self

def __exit__(self, _: type[BaseException] | None, __: BaseException | None, ___: TracebackType | None) -> bool:
self.close()
return False

def close(self) -> None:
"""Close the database and WAL."""
# Only close DB handle if we opened it using a path
if self.path is not None:
self.fh.close()

if self.wal is not None:
self.wal.close()

def checkpoints(self) -> Iterator[SQLite3]:
"""Yield instances of the database at all available checkpoints in the WAL file, if applicable."""
if not self.wal:
Expand Down
17 changes: 10 additions & 7 deletions dissect/database/sqlite3/wal.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,13 @@


class WAL:
def __init__(self, fh: WAL | Path | BinaryIO):
def __init__(self, fh: Path | BinaryIO):
# Use the provided WAL file handle or try to open a sidecar WAL file.
if hasattr(fh, "read"):
name = getattr(fh, "name", None)
path = Path(name) if name else None
else:
if not isinstance(fh, Path):
fh = Path(fh)
if isinstance(fh, Path):
path = fh
fh = path.open("rb")
else:
path = None

self.fh = fh
self.path = path
Expand All @@ -45,6 +42,12 @@ def __init__(self, fh: WAL | Path | BinaryIO):

self.frame = lru_cache(1024)(self.frame)

def close(self) -> None:
"""Close the WAL."""
# Only close WAL handle if we opened it using a path
if self.path is not None:
self.fh.close()

def frame(self, frame_idx: int) -> Frame:
frame_size = len(c_sqlite3.wal_frame) + self.header.page_size
offset = len(c_sqlite3.wal_header) + frame_idx * frame_size
Expand Down
17 changes: 13 additions & 4 deletions tests/sqlite3/test_sqlite3.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,19 @@


@pytest.mark.parametrize(
("db_as_path"),
[pytest.param(True, id="db_as_path"), pytest.param(False, id="db_as_fh")],
("open_as_path"),
[pytest.param(True, id="as_path"), pytest.param(False, id="as_fh")],
)
def test_sqlite(sqlite_db: Path, db_as_path: bool) -> None:
db = sqlite3.SQLite3(sqlite_db) if db_as_path else sqlite3.SQLite3(sqlite_db.open("rb"))
def test_sqlite(sqlite_db: Path, open_as_path: bool) -> None:
db = sqlite3.SQLite3(sqlite_db if open_as_path else sqlite_db.open("rb"))
_assert_sqlite_db(db)
db.close()

with sqlite3.SQLite3(sqlite_db if open_as_path else sqlite_db.open("rb")) as db:
_assert_sqlite_db(db)


def _assert_sqlite_db(db: sqlite3.SQLite3) -> None:
assert db.header.magic == sqlite3.SQLITE3_HEADER_MAGIC

tables = list(db.tables())
Expand Down Expand Up @@ -67,6 +74,8 @@ def test_sqlite(sqlite_db: Path, db_as_path: bool) -> None:
assert table.row(0).__dict__ == rows[0].__dict__
assert list(rows[0]) == [("id", 1), ("name", "testing"), ("value", 1337)]

db.close()


@pytest.mark.parametrize(
("input", "encoding", "expected_output"),
Expand Down
6 changes: 6 additions & 0 deletions tests/sqlite3/test_wal.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,26 @@ def test_sqlite_wal(sqlite_db: Path, sqlite_wal: Path, db_as_path: bool, wal_as_
)
_assert_checkpoint_1(db)

db.close()

db = sqlite3.SQLite3(
sqlite_db if db_as_path else sqlite_db.open("rb"),
sqlite_wal if wal_as_path else sqlite_wal.open("rb"),
checkpoint=2,
)
_assert_checkpoint_2(db)

db.close()

db = sqlite3.SQLite3(
sqlite_db if db_as_path else sqlite_db.open("rb"),
sqlite_wal if wal_as_path else sqlite_wal.open("rb"),
checkpoint=3,
)
_assert_checkpoint_3(db)

db.close()


def _assert_checkpoint_1(s: sqlite3.SQLite3) -> None:
# After the first checkpoint the "after checkpoint" entries are present
Expand Down