From a6151b21b77c9f48150e281d01b02df0eda9bb2f Mon Sep 17 00:00:00 2001 From: Koen van Eijk Date: Sun, 9 Jun 2024 00:59:25 +0200 Subject: [PATCH] =?UTF-8?q?Cleaning=20up=20some=20code=20=F0=9F=A7=B9and?= =?UTF-8?q?=20refactoring=20=F0=9F=94=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- openrecall/app.py | 462 ++++++--------------------------------- openrecall/config.py | 30 +++ openrecall/database.py | 62 ++++++ openrecall/nlp.py | 16 ++ openrecall/ocr.py | 19 ++ openrecall/screenshot.py | 73 +++++++ openrecall/utils.py | 87 ++++++++ setup.py | 14 +- 8 files changed, 367 insertions(+), 396 deletions(-) create mode 100644 openrecall/config.py create mode 100644 openrecall/database.py create mode 100644 openrecall/nlp.py create mode 100644 openrecall/ocr.py create mode 100644 openrecall/screenshot.py create mode 100644 openrecall/utils.py diff --git a/openrecall/app.py b/openrecall/app.py index 70cf7a9..defd1e7 100644 --- a/openrecall/app.py +++ b/openrecall/app.py @@ -1,340 +1,30 @@ -import os -import sqlite3 -import sys -import threading -import time +from threading import Thread -import mss import numpy as np -from doctr.models import ocr_predictor from flask import Flask, render_template_string, request, send_from_directory -from PIL import Image -from sentence_transformers import SentenceTransformer +from jinja2 import BaseLoader - -def get_appdata_folder(app_name="openrecall"): - """ - Get the path to the application data folder. - - Args: - app_name (str): The name of the application. - - Returns: - str: The path to the application data folder. - """ - if sys.platform == "win32": - appdata = os.getenv("APPDATA") - if not appdata: - raise EnvironmentError("APPDATA environment variable is not set.") - path = os.path.join(appdata, app_name) - elif sys.platform == "darwin": - home = os.path.expanduser("~") - path = os.path.join(home, "Library", "Application Support", app_name) - else: # Linux and other Unix-like systems - home = os.path.expanduser("~") - path = os.path.join(home, ".local", "share", app_name) - - if not os.path.exists(path): - os.makedirs(path) - - return path - - -appdata_folder = get_appdata_folder() - -print(f"All data is stored in: {appdata_folder}") - -db_path = os.path.join(appdata_folder, "recall.db") - -screenshots_path = os.path.join(appdata_folder, "screenshots") - -# ensure the screenshots folder exists -if not os.path.exists(screenshots_path): - try: - os.makedirs(screenshots_path) - except: - pass - - -def get_active_app_name_osx(): - """Returns the name of the active application.""" - from AppKit import NSWorkspace - - active_app = NSWorkspace.sharedWorkspace().activeApplication() - return active_app["NSApplicationName"] - - -def get_active_window_title_osx(): - """Returns the title of the active window.""" - from Quartz import ( - CGWindowListCopyWindowInfo, - kCGNullWindowID, - kCGWindowListOptionOnScreenOnly, - ) - - app_name = get_active_app_name_osx() - windows = CGWindowListCopyWindowInfo( - kCGWindowListOptionOnScreenOnly, kCGNullWindowID - ) - - for window in windows: - if window["kCGWindowOwnerName"] == app_name: - return window.get("kCGWindowName", "Unknown") - - return None - - -def get_active_app_name_windows(): - """returns the app's name .exe""" - import psutil - import win32gui - import win32process - - # Get the handle of the foreground window - hwnd = win32gui.GetForegroundWindow() - - # Get the thread process ID of the foreground window - _, pid = win32process.GetWindowThreadProcessId(hwnd) - - # Get the process name using psutil - exe = psutil.Process(pid).name() - return exe - - -def get_active_window_title_windows(): - """Returns the title of the active window.""" - import win32gui - - hwnd = win32gui.GetForegroundWindow() - window_title = win32gui.GetWindowText(hwnd) - return window_title - - -def get_active_app_name(): - if sys.platform == "win32": - return get_active_app_name_windows() - elif sys.platform == "darwin": - return get_active_app_name_osx() - else: - raise NotImplementedError("This platform is not supported") - - -def get_active_window_title(): - if sys.platform == "win32": - return get_active_window_title_windows() - elif sys.platform == "darwin": - return get_active_window_title_osx() - else: - raise NotImplementedError("This platform is not supported") - - -def create_db(): - # create table if not exists for entries, with columns id, text, datetime, and embedding (blob) - conn = sqlite3.connect(db_path) - c = conn.cursor() - c.execute( - """CREATE TABLE IF NOT EXISTS entries - (id INTEGER PRIMARY KEY AUTOINCREMENT, app TEXT, title TEXT, text TEXT, timestamp INTEGER, embedding BLOB)""" - ) - conn.commit() - conn.close() - - -def get_embedding(text): - # Initialize the model - model = SentenceTransformer("all-MiniLM-L6-v2") - - # Split text into sentences - sentences = text.split("\n") - - # Get sentence embeddings - sentence_embeddings = model.encode(sentences) - - # Aggregate embeddings (mean pooling in this example) - mean = np.mean(sentence_embeddings, axis=0) - # convert to float64 - mean = mean.astype(np.float64) - return mean - - -ocr = ocr_predictor( - pretrained=True, - det_arch="db_mobilenet_v3_large", - reco_arch="crnn_mobilenet_v3_large", +from openrecall.config import screenshots_path, appdata_folder +from openrecall.database import create_db, get_all_entries, get_timestamps +from openrecall.screenshot import record_screenshots_thread +from openrecall.utils import ( + human_readable_time, + timestamp_to_human_readable, ) - - -def take_screenshot(monitor=1): - """ - Take a screenshot of the specified monitor. - - Args: - monitor (int): The index of the monitor to capture the screenshot from. - - Returns: - numpy.ndarray: The screenshot image as a numpy array. - """ - with mss.mss() as sct: - monitor_ = sct.monitors[monitor] - screenshot = np.array(sct.grab(monitor_)) - screenshot = screenshot[:, :, [2, 1, 0]] - return screenshot - -def record_screenshot_thread(): - """ - Thread function to continuously record screenshots and process them. - - This function takes screenshots at regular intervals and compares them with the previous screenshot. - If the new screenshot is different enough from the previous one, it saves the screenshot, performs OCR on it, - extracts the text, computes the embedding, and stores the entry in the database. - - Returns: - None - """ - last_screenshot = take_screenshot() - - while True: - screenshot = take_screenshot() - - if not is_similar(screenshot, last_screenshot): - last_screenshot = screenshot - image = Image.fromarray(screenshot) - timestamp = int(time.time()) - image.save( - os.path.join(screenshots_path, f"{timestamp}.webp"), - format="webp", - lossless=True, - ) - result = ocr([screenshot]) - text = "" - - for page in result.pages: - for block in page.blocks: - for line in block.lines: - for word in line.words: - text += word.value + " " - text += "\n" - text += "\n" - - embedding = get_embedding(text) - active_app_name = get_active_app_name() - active_window_title = get_active_window_title() - - # connect to db - conn = sqlite3.connect(db_path) - c = conn.cursor() - - # Insert the entry into the database - embedding_bytes = embedding.tobytes() - c.execute( - "INSERT INTO entries (text, timestamp, embedding, app, title) VALUES (?, ?, ?, ?, ?)", - ( - text, - timestamp, - embedding_bytes, - active_app_name, - active_window_title, - ), - ) - - # Commit the transaction - conn.commit() - conn.close() - - time.sleep(3) - - -def mean_structured_similarity_index(img1, img2, L=255): - """Compute the mean Structural Similarity Index between two images.""" - K1, K2 = 0.01, 0.03 - C1, C2 = (K1 * L) ** 2, (K2 * L) ** 2 - - # Convert images to grayscale - def rgb2gray(img): - return 0.2989 * img[..., 0] + 0.5870 * img[..., 1] + 0.1140 * img[..., 2] - - img1_gray = rgb2gray(img1) - img2_gray = rgb2gray(img2) - - # Means - mu1 = np.mean(img1_gray) - mu2 = np.mean(img2_gray) - - # Variances and covariances - sigma1_sq = np.var(img1_gray) - sigma2_sq = np.var(img2_gray) - sigma12 = np.mean((img1_gray - mu1) * (img2_gray - mu2)) - - # SSIM computation - ssim_index = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / ( - (mu1**2 + mu2**2 + C1) * (sigma1_sq + sigma2_sq + C2) - ) - - return ssim_index - - -def is_similar(img1, img2, similarity_threshold=0.9): - """Check if two images are similar based on a given similarity threshold.""" - similarity = mean_structured_similarity_index(img1, img2) - return similarity >= similarity_threshold - - -def cosine_similarity(a, b): - return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) - +from openrecall.nlp import get_embedding, cosine_similarity app = Flask(__name__) - -def human_readable_time(timestamp): - import datetime - - now = datetime.datetime.now() - dt_object = datetime.datetime.fromtimestamp(timestamp) - - diff = now - dt_object - - if diff.days > 0: - return f"{diff.days} days ago" - elif diff.seconds < 60: - return f"{diff.seconds} seconds ago" - elif diff.seconds < 3600: - return f"{diff.seconds // 60} minutes ago" - else: - return f"{diff.seconds // 3600} hours ago" - - -def timestamp_to_human_readable(timestamp): - import datetime - try: - dt_object = datetime.datetime.fromtimestamp(timestamp) - return dt_object.strftime("%Y-%m-%d %H:%M:%S") - except: - return "" - - app.jinja_env.filters["human_readable_time"] = human_readable_time app.jinja_env.filters["timestamp_to_human_readable"] = timestamp_to_human_readable - -@app.route("/") -def timeline(): - # connect to db - conn = sqlite3.connect(db_path) - c = conn.cursor() - results = c.execute( - "SELECT timestamp FROM entries ORDER BY timestamp DESC LIMIT 1000" - ).fetchall() - timestamps = [result[0] for result in results] - conn.close() - return render_template_string( - """ +base_template = """ - OpenRecall - Timeline + OpenRecall @@ -373,6 +63,38 @@ def timeline(): +{% block content %} + +{% endblock %} + + + + + + + + +""" + + +class StringLoader(BaseLoader): + def get_source(self, environment, template): + if template == "base_template": + return base_template, None, lambda: True + return None, None, None + + +app.jinja_env.loader = StringLoader() + + +@app.route("/") +def timeline(): + # connect to db + timestamps = get_timestamps() + return render_template_string( + """ +{% extends "base_template" %} +{% block content %} {% if timestamps|length > 0 %}
@@ -383,18 +105,6 @@ def timeline(): Image for timestamp
-{% else %} -
- -
- -{% endif %} - - - - - - - """, +{% else %} +
+ +
+{% endif %} +{% endblock %} +""", timestamps=timestamps, ) @@ -423,72 +139,34 @@ def timeline(): @app.route("/search") def search(): q = request.args.get("q") - - # load embeddings from db to numpy array - conn = sqlite3.connect(db_path) - c = conn.cursor() - - # Get all entries - results = c.execute("SELECT * FROM entries").fetchall() - embeddings = [] - - for result in results: - embeddings.append(np.frombuffer(result[5], dtype=np.float64)) - - embeddings = np.array(embeddings) - - # Get the embedding of the query + entries = get_all_entries() + embeddings = [ + np.frombuffer(entry["embedding"], dtype=np.float64) for entry in entries + ] query_embedding = get_embedding(q) - - # Compute the cosine similarity between the query and all entries - similarities = [] - - for embedding in embeddings: - similarities.append(cosine_similarity(query_embedding, embedding)) - - # Sort the entries by similarity + similarities = [cosine_similarity(query_embedding, emb) for emb in embeddings] indices = np.argsort(similarities)[::-1] - - entries = [] - - for i in indices: - result = results[i] - entries.append( - { - "text": result[3], - "timestamp": result[4], - "image_path": f"/static/{result[4]}.webp", - } - ) + sorted_entries = [entries[i] for i in indices] return render_template_string( """ - - - Search Results - - - - +{% extends "base_template" %} +{% block content %}
-

Search Results

{% for entry in entries %} -
- - - - - - - - """, - entries=entries, +{% endblock %} +""", + entries=sorted_entries, ) @@ -516,8 +188,10 @@ def serve_image(filename): if __name__ == "__main__": create_db() + print(f"Appdata folder: {appdata_folder}") + # Start the thread to record screenshots - t = threading.Thread(target=record_screenshot_thread) + t = Thread(target=record_screenshots_thread) t.start() app.run(port=8082) diff --git a/openrecall/config.py b/openrecall/config.py new file mode 100644 index 0000000..ea246c0 --- /dev/null +++ b/openrecall/config.py @@ -0,0 +1,30 @@ +import os +import sys + + +def get_appdata_folder(app_name="openrecall"): + if sys.platform == "win32": + appdata = os.getenv("APPDATA") + if not appdata: + raise EnvironmentError("APPDATA environment variable is not set.") + path = os.path.join(appdata, app_name) + elif sys.platform == "darwin": + home = os.path.expanduser("~") + path = os.path.join(home, "Library", "Application Support", app_name) + else: + home = os.path.expanduser("~") + path = os.path.join(home, ".local", "share", app_name) + if not os.path.exists(path): + os.makedirs(path) + return path + + +appdata_folder = get_appdata_folder() +db_path = os.path.join(appdata_folder, "recall.db") +screenshots_path = os.path.join(appdata_folder, "screenshots") + +if not os.path.exists(screenshots_path): + try: + os.makedirs(screenshots_path) + except: + pass diff --git a/openrecall/database.py b/openrecall/database.py new file mode 100644 index 0000000..2f30c8c --- /dev/null +++ b/openrecall/database.py @@ -0,0 +1,62 @@ +import sqlite3 + +from openrecall.config import db_path + + +def create_db(): + conn = sqlite3.connect(db_path) + c = conn.cursor() + c.execute( + """CREATE TABLE IF NOT EXISTS entries + (id INTEGER PRIMARY KEY AUTOINCREMENT, app TEXT, title TEXT, text TEXT, timestamp INTEGER, embedding BLOB)""" + ) + conn.commit() + conn.close() + + +def get_all_entries(): + conn = sqlite3.connect(db_path) + c = conn.cursor() + results = c.execute("SELECT * FROM entries").fetchall() + conn.close() + entries = [] + for result in results: + entries.append( + { + "id": result[0], + "app": result[1], + "title": result[2], + "text": result[3], + "timestamp": result[4], + "embedding": result[5], + } + ) + return entries + + +def get_timestamps(): + conn = sqlite3.connect(db_path) + c = conn.cursor() + results = c.execute( + "SELECT timestamp FROM entries ORDER BY timestamp DESC LIMIT 1000" + ).fetchall() + timestamps = [result[0] for result in results] + conn.close() + return timestamps + +def insert_entry(text, timestamp, embedding, app, title): + conn = sqlite3.connect(db_path) + c = conn.cursor() + embedding_bytes = embedding.tobytes() + c.execute( + "INSERT INTO entries (text, timestamp, embedding, app, title) VALUES (?, ?, ?, ?, ?)", + ( + text, + timestamp, + embedding_bytes, + app, + title, + ), + ) + conn.commit() + conn.close() diff --git a/openrecall/nlp.py b/openrecall/nlp.py new file mode 100644 index 0000000..6f73098 --- /dev/null +++ b/openrecall/nlp.py @@ -0,0 +1,16 @@ +from sentence_transformers import SentenceTransformer +import numpy as np + + +def get_embedding(text): + model = SentenceTransformer("all-MiniLM-L6-v2") + sentences = text.split("\n") + sentence_embeddings = model.encode(sentences) + mean = np.mean(sentence_embeddings, axis=0) + mean = mean.astype(np.float64) + return mean + + +def cosine_similarity(a, b): + return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) + diff --git a/openrecall/ocr.py b/openrecall/ocr.py new file mode 100644 index 0000000..5975cb8 --- /dev/null +++ b/openrecall/ocr.py @@ -0,0 +1,19 @@ +from doctr.models import ocr_predictor + +ocr = ocr_predictor( + pretrained=True, + det_arch="db_mobilenet_v3_large", + reco_arch="crnn_mobilenet_v3_large", +) + +def extract_text_from_image(image): + result = ocr([image]) + text = "" + for page in result.pages: + for block in page.blocks: + for line in block.lines: + for word in line.words: + text += word.value + " " + text += "\n" + text += "\n" + return text diff --git a/openrecall/screenshot.py b/openrecall/screenshot.py new file mode 100644 index 0000000..19f0a81 --- /dev/null +++ b/openrecall/screenshot.py @@ -0,0 +1,73 @@ +import os +import time + +import mss +import numpy as np +from PIL import Image + +from openrecall.config import db_path, screenshots_path +from openrecall.ocr import extract_text_from_image +from openrecall.utils import ( + get_active_app_name, + get_active_window_title +) +from openrecall.nlp import get_embedding +from openrecall.database import insert_entry + +def mean_structured_similarity_index(img1, img2, L=255): + K1, K2 = 0.01, 0.03 + C1, C2 = (K1 * L) ** 2, (K2 * L) ** 2 + + def rgb2gray(img): + return 0.2989 * img[..., 0] + 0.5870 * img[..., 1] + 0.1140 * img[..., 2] + + img1_gray = rgb2gray(img1) + img2_gray = rgb2gray(img2) + mu1 = np.mean(img1_gray) + mu2 = np.mean(img2_gray) + sigma1_sq = np.var(img1_gray) + sigma2_sq = np.var(img2_gray) + sigma12 = np.mean((img1_gray - mu1) * (img2_gray - mu2)) + ssim_index = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / ( + (mu1**2 + mu2**2 + C1) * (sigma1_sq + sigma2_sq + C2) + ) + return ssim_index + + +def is_similar(img1, img2, similarity_threshold=0.9): + similarity = mean_structured_similarity_index(img1, img2) + return similarity >= similarity_threshold + + +def take_screenshots(monitor=1): + screenshots = [] + with mss.mss() as sct: + for monitor in range(len(sct.monitors)): + monitor_ = sct.monitors[monitor] + screenshot = np.array(sct.grab(monitor_)) + screenshot = screenshot[:, :, [2, 1, 0]] + screenshots.append(screenshot) + return screenshots + + +def record_screenshots_thread(): + last_screenshots = take_screenshots() + while True: + screenshots = take_screenshots() + for i, screenshot in enumerate(screenshots): + last_screenshot = last_screenshots[i] + if not is_similar(screenshot, last_screenshot): + last_screenshots[i] = screenshot + image = Image.fromarray(screenshot) + timestamp = int(time.time()) + image.save( + os.path.join(screenshots_path, f"{timestamp}.webp"), + format="webp", + lossless=True, + ) + text = extract_text_from_image(screenshot) + embedding = get_embedding(text) + active_app_name = get_active_app_name() + active_window_title = get_active_window_title() + insert_entry(text, timestamp, embedding, active_app_name, active_window_title) + time.sleep(3) diff --git a/openrecall/utils.py b/openrecall/utils.py new file mode 100644 index 0000000..76cfb11 --- /dev/null +++ b/openrecall/utils.py @@ -0,0 +1,87 @@ +import sys + + +def human_readable_time(timestamp): + import datetime + + now = datetime.datetime.now() + dt_object = datetime.datetime.fromtimestamp(timestamp) + diff = now - dt_object + if diff.days > 0: + return f"{diff.days} days ago" + elif diff.seconds < 60: + return f"{diff.seconds} seconds ago" + elif diff.seconds < 3600: + return f"{diff.seconds // 60} minutes ago" + else: + return f"{diff.seconds // 3600} hours ago" + + +def timestamp_to_human_readable(timestamp): + import datetime + + try: + dt_object = datetime.datetime.fromtimestamp(timestamp) + return dt_object.strftime("%Y-%m-%d %H:%M:%S") + except: + return "" + + +def get_active_app_name_osx(): + from AppKit import NSWorkspace + + active_app = NSWorkspace.sharedWorkspace().activeApplication() + return active_app["NSApplicationName"] + + +def get_active_window_title_osx(): + from Quartz import ( + CGWindowListCopyWindowInfo, + kCGNullWindowID, + kCGWindowListOptionOnScreenOnly, + ) + + app_name = get_active_app_name_osx() + windows = CGWindowListCopyWindowInfo( + kCGWindowListOptionOnScreenOnly, kCGNullWindowID + ) + for window in windows: + if window["kCGWindowOwnerName"] == app_name: + return window.get("kCGWindowName", "Unknown") + return None + + +def get_active_app_name_windows(): + import psutil + import win32gui + import win32process + + hwnd = win32gui.GetForegroundWindow() + _, pid = win32process.GetWindowThreadProcessId(hwnd) + exe = psutil.Process(pid).name() + return exe + + +def get_active_window_title_windows(): + import win32gui + + hwnd = win32gui.GetForegroundWindow() + return win32gui.GetWindowText(hwnd) + + +def get_active_app_name(): + if sys.platform == "win32": + return get_active_app_name_windows() + elif sys.platform == "darwin": + return get_active_app_name_osx() + else: + raise NotImplementedError("This platform is not supported") + + +def get_active_window_title(): + if sys.platform == "win32": + return get_active_window_title_windows() + elif sys.platform == "darwin": + return get_active_window_title_osx() + else: + raise NotImplementedError("This platform is not supported") diff --git a/setup.py b/setup.py index c57134d..8ea3b44 100644 --- a/setup.py +++ b/setup.py @@ -17,14 +17,24 @@ install_requires = [ "torchvision==0.18.0", "shapely", "h5py", - "rapidfuzz" + "rapidfuzz", ] import subprocess import sys + def install_doctr(): - subprocess.run([sys.executable, "-m", "pip", "install", "git+https://github.com/koenvaneijk/doctr.git"]) + subprocess.run( + [ + sys.executable, + "-m", + "pip", + "install", + "git+https://github.com/koenvaneijk/doctr.git", + ] + ) + install_doctr()