diff --git a/dissect/database/sqlite3/sqlite3.py b/dissect/database/sqlite3/sqlite3.py index 25b32e6..3af3e97 100644 --- a/dissect/database/sqlite3/sqlite3.py +++ b/dissect/database/sqlite3/sqlite3.py @@ -19,7 +19,9 @@ if TYPE_CHECKING: from collections.abc import Iterator + from types import TracebackType + from typing_extensions import Self ENCODING = { 1: "utf-8", @@ -78,13 +80,11 @@ def __init__( wal: WAL | Path | BinaryIO | None = None, checkpoint: Checkpoint | int | None = None, ): - # Use the provided file handle or try to open the file path. - if hasattr(fh, "read"): - name = getattr(fh, "name", None) - path = Path(name) if name else None - else: + if isinstance(fh, Path): path = fh fh = path.open("rb") + else: + path = None self.fh = fh self.path = path @@ -105,12 +105,21 @@ def __init__( raise InvalidDatabase("Usable page size is too small") if wal: - self.wal = WAL(wal) if not isinstance(wal, WAL) else wal - elif path: + self.wal = wal if isinstance(wal, WAL) else WAL(wal) + else: # Check for WAL sidecar next to the DB. - wal_path = path.with_name(f"{path.name}-wal") - if wal_path.exists(): - self.wal = WAL(wal_path) + # If we have a path, we can deduce the WAL path. + # If we don't have a path, we can try to get it from the file handle. + if path is None: + # By deducing the path at this point and not earlier, we can keep the original passed + # path to indicate if we should close the file handle later on. + name = getattr(fh, "name", None) + path = Path(name) if name else None + + if path is not None: + wal_path = path.with_name(f"{path.name}-wal") + if wal_path.exists() and wal_path.stat().st_size > 0: + self.wal = WAL(wal_path) # If a checkpoint index was provided, resolve it to a Checkpoint object. if self.wal and isinstance(checkpoint, int): @@ -122,6 +131,23 @@ def __init__( self.page = lru_cache(256)(self.page) + def __enter__(self) -> Self: + """Return ``self`` upon entering the runtime context.""" + return self + + def __exit__(self, _: type[BaseException] | None, __: BaseException | None, ___: TracebackType | None) -> bool: + self.close() + return False + + def close(self) -> None: + """Close the database and WAL.""" + # Only close DB handle if we opened it using a path + if self.path is not None: + self.fh.close() + + if self.wal is not None: + self.wal.close() + def checkpoints(self) -> Iterator[SQLite3]: """Yield instances of the database at all available checkpoints in the WAL file, if applicable.""" if not self.wal: diff --git a/dissect/database/sqlite3/wal.py b/dissect/database/sqlite3/wal.py index 309f156..7d4ec76 100644 --- a/dissect/database/sqlite3/wal.py +++ b/dissect/database/sqlite3/wal.py @@ -23,16 +23,13 @@ class WAL: - def __init__(self, fh: WAL | Path | BinaryIO): + def __init__(self, fh: Path | BinaryIO): # Use the provided WAL file handle or try to open a sidecar WAL file. - if hasattr(fh, "read"): - name = getattr(fh, "name", None) - path = Path(name) if name else None - else: - if not isinstance(fh, Path): - fh = Path(fh) + if isinstance(fh, Path): path = fh fh = path.open("rb") + else: + path = None self.fh = fh self.path = path @@ -45,6 +42,12 @@ def __init__(self, fh: WAL | Path | BinaryIO): self.frame = lru_cache(1024)(self.frame) + def close(self) -> None: + """Close the WAL.""" + # Only close WAL handle if we opened it using a path + if self.path is not None: + self.fh.close() + def frame(self, frame_idx: int) -> Frame: frame_size = len(c_sqlite3.wal_frame) + self.header.page_size offset = len(c_sqlite3.wal_header) + frame_idx * frame_size diff --git a/tests/sqlite3/test_sqlite3.py b/tests/sqlite3/test_sqlite3.py index a0ce473..23cecbb 100644 --- a/tests/sqlite3/test_sqlite3.py +++ b/tests/sqlite3/test_sqlite3.py @@ -12,12 +12,19 @@ @pytest.mark.parametrize( - ("db_as_path"), - [pytest.param(True, id="db_as_path"), pytest.param(False, id="db_as_fh")], + ("open_as_path"), + [pytest.param(True, id="as_path"), pytest.param(False, id="as_fh")], ) -def test_sqlite(sqlite_db: Path, db_as_path: bool) -> None: - db = sqlite3.SQLite3(sqlite_db) if db_as_path else sqlite3.SQLite3(sqlite_db.open("rb")) +def test_sqlite(sqlite_db: Path, open_as_path: bool) -> None: + db = sqlite3.SQLite3(sqlite_db if open_as_path else sqlite_db.open("rb")) + _assert_sqlite_db(db) + db.close() + with sqlite3.SQLite3(sqlite_db if open_as_path else sqlite_db.open("rb")) as db: + _assert_sqlite_db(db) + + +def _assert_sqlite_db(db: sqlite3.SQLite3) -> None: assert db.header.magic == sqlite3.SQLITE3_HEADER_MAGIC tables = list(db.tables()) @@ -67,6 +74,8 @@ def test_sqlite(sqlite_db: Path, db_as_path: bool) -> None: assert table.row(0).__dict__ == rows[0].__dict__ assert list(rows[0]) == [("id", 1), ("name", "testing"), ("value", 1337)] + db.close() + @pytest.mark.parametrize( ("input", "encoding", "expected_output"), diff --git a/tests/sqlite3/test_wal.py b/tests/sqlite3/test_wal.py index cc01925..6d477fe 100644 --- a/tests/sqlite3/test_wal.py +++ b/tests/sqlite3/test_wal.py @@ -26,6 +26,8 @@ def test_sqlite_wal(sqlite_db: Path, sqlite_wal: Path, db_as_path: bool, wal_as_ ) _assert_checkpoint_1(db) + db.close() + db = sqlite3.SQLite3( sqlite_db if db_as_path else sqlite_db.open("rb"), sqlite_wal if wal_as_path else sqlite_wal.open("rb"), @@ -33,6 +35,8 @@ def test_sqlite_wal(sqlite_db: Path, sqlite_wal: Path, db_as_path: bool, wal_as_ ) _assert_checkpoint_2(db) + db.close() + db = sqlite3.SQLite3( sqlite_db if db_as_path else sqlite_db.open("rb"), sqlite_wal if wal_as_path else sqlite_wal.open("rb"), @@ -40,6 +44,8 @@ def test_sqlite_wal(sqlite_db: Path, sqlite_wal: Path, db_as_path: bool, wal_as_ ) _assert_checkpoint_3(db) + db.close() + def _assert_checkpoint_1(s: sqlite3.SQLite3) -> None: # After the first checkpoint the "after checkpoint" entries are present