From 1151e5f0afe705f5e55f7688bb871c6c403a6b97 Mon Sep 17 00:00:00 2001 From: Koen van Eijk Date: Sun, 9 Jun 2024 01:06:21 +0200 Subject: [PATCH] Improved readability --- openrecall/database.py | 86 +++++++++++++++++------------------------- 1 file changed, 35 insertions(+), 51 deletions(-) diff --git a/openrecall/database.py b/openrecall/database.py index 2f30c8c..216320a 100644 --- a/openrecall/database.py +++ b/openrecall/database.py @@ -1,62 +1,46 @@ import sqlite3 +from collections import namedtuple +from typing import Any, List from openrecall.config import db_path - -def create_db(): - conn = sqlite3.connect(db_path) - c = conn.cursor() - c.execute( - """CREATE TABLE IF NOT EXISTS entries - (id INTEGER PRIMARY KEY AUTOINCREMENT, app TEXT, title TEXT, text TEXT, timestamp INTEGER, embedding BLOB)""" - ) - conn.commit() - conn.close() +Entry = namedtuple("Entry", ["id", "app", "title", "text", "timestamp", "embedding"]) -def get_all_entries(): - conn = sqlite3.connect(db_path) - c = conn.cursor() - results = c.execute("SELECT * FROM entries").fetchall() - conn.close() - entries = [] - for result in results: - entries.append( - { - "id": result[0], - "app": result[1], - "title": result[2], - "text": result[3], - "timestamp": result[4], - "embedding": result[5], - } +def create_db() -> None: + with sqlite3.connect(db_path) as conn: + c = conn.cursor() + c.execute( + """CREATE TABLE IF NOT EXISTS entries + (id INTEGER PRIMARY KEY AUTOINCREMENT, app TEXT, title TEXT, text TEXT, timestamp INTEGER, embedding BLOB)""" ) - return entries + conn.commit() -def get_timestamps(): - conn = sqlite3.connect(db_path) - c = conn.cursor() - results = c.execute( - "SELECT timestamp FROM entries ORDER BY timestamp DESC LIMIT 1000" - ).fetchall() - timestamps = [result[0] for result in results] - conn.close() - return timestamps +def get_all_entries() -> List[Entry]: + with sqlite3.connect(db_path) as conn: + c = conn.cursor() + results = c.execute("SELECT * FROM entries").fetchall() + return [Entry(*result) for result in results] -def insert_entry(text, timestamp, embedding, app, title): - conn = sqlite3.connect(db_path) - c = conn.cursor() + +def get_timestamps() -> List[int]: + with sqlite3.connect(db_path) as conn: + c = conn.cursor() + results = c.execute( + "SELECT timestamp FROM entries ORDER BY timestamp DESC LIMIT 1000" + ).fetchall() + return [result[0] for result in results] + + +def insert_entry( + text: str, timestamp: int, embedding: Any, app: str, title: str +) -> None: embedding_bytes = embedding.tobytes() - c.execute( - "INSERT INTO entries (text, timestamp, embedding, app, title) VALUES (?, ?, ?, ?, ?)", - ( - text, - timestamp, - embedding_bytes, - app, - title, - ), - ) - conn.commit() - conn.close() + with sqlite3.connect(db_path) as conn: + c = conn.cursor() + c.execute( + "INSERT INTO entries (text, timestamp, embedding, app, title) VALUES (?, ?, ?, ?, ?)", + (text, timestamp, embedding_bytes, app, title), + ) + conn.commit()