import hashlib
import os
import random
import re
import sqlite3
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

import trafilatura
from dotenv import load_dotenv
from trafilatura.readability_lxml import is_probably_readerable

load_dotenv()

# Constants
MAX_FETCH_LIMIT = int(os.getenv("FETCH_LIMIT"))  # maximum number of URLs to fetch per run


# Open a database connection
def connect_db(db_path):
    return sqlite3.connect(db_path)


# Create the fetch_list table
def create_fetch_list_table(conn):
    cursor = conn.cursor()
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS fetch_list (
            url TEXT PRIMARY KEY,
            fetched_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
    """)
    conn.commit()


# Return URLs from url_list that have not been fetched yet
def get_unfetched_urls(conn, limit):
    cursor = conn.cursor()
    cursor.execute("""
        SELECT url FROM url_list
        WHERE url NOT IN (SELECT url FROM fetch_list)
        LIMIT ?
    """, (limit,))
    return [row[0] for row in cursor.fetchall()]


# Download a page and extract its main text content
def fetch_and_extract_content(conn, url):
    downloaded = trafilatura.fetch_url(url)
    if not downloaded:
        return None
    html_string = downloaded
    # When FETCH_IGNORE_CHECK is enabled, pages that fail the readability check
    # are recorded as fetched and skipped.
    if not is_probably_readerable(html_string) and os.getenv("FETCH_IGNORE_CHECK", "").upper() == "TRUE":
        print(f"URL {url} is not readable.")
        record_fetched_url(conn, url)
        return None
    content = trafilatura.extract(html_string, output_format="txt", url=url, favor_precision=True)
    print(f"Successfully extracted text for URL: {url}")
    return content


# MD5 hash of a URL, used as the base filename for saved segments
def md5_hash(url):
    return hashlib.md5(url.encode()).hexdigest()


# Segmentation rule: split on sentence-ending punctuation (Chinese and ASCII),
# then group sentences into segments of at most 12 sentences or ~1800 characters.
# Note that re.split() drops the delimiters, so saved segments contain no
# sentence-ending punctuation.
def split_content(content):
    sentences = re.split(r'[。!?;.!?;]', content)
    segments = []
    current_segment = []
    current_length = 0
    for sentence in sentences:
        sentence_length = len(sentence)
        if len(current_segment) >= 12 or current_length + sentence_length > 1800:
            segments.append(''.join(current_segment))
            current_segment = []
            current_length = 0
        current_segment.append(sentence)
        current_length += sentence_length
    if current_segment:
        segments.append(''.join(current_segment))
    return segments


# Save segments to individual text files named <md5(url)>_<index>.txt
def save_segments(url, segments, path):
    url_hash = md5_hash(url)
    for idx, segment in enumerate(segments):
        save_path = os.path.join(path, f"{url_hash}_{idx}.txt")
        with open(save_path, "w", encoding="utf-8") as f:
            f.write(segment)


# Record a URL as fetched
def record_fetched_url(conn, url):
    cursor = conn.cursor()
    cursor.execute("""
        INSERT INTO fetch_list (url, fetched_time)
        VALUES (?, CURRENT_TIMESTAMP)
    """, (url,))
    conn.commit()


# Process a single URL (runs in a worker thread, with its own connection and a
# randomized cooldown before and after the fetch)
def process_url(url, db_path, save_path):
    cooldown_base = float(os.getenv("FETCH_COOLDOWN"))
    time.sleep(random.random() * cooldown_base)
    conn = connect_db(db_path)
    content = fetch_and_extract_content(conn, url)
    if content:
        segments = split_content(content)
        save_segments(url, segments, save_path)
        record_fetched_url(conn, url)
    conn.close()
    time.sleep(random.random() * cooldown_base)


# Main entry point
def main():
    db_path = "crawler.db"
    save_path = "./source"
    os.makedirs(save_path, exist_ok=True)

    conn = connect_db(db_path)
    create_fetch_list_table(conn)
    unfetched_urls = get_unfetched_urls(conn, MAX_FETCH_LIMIT)
    conn.close()

    with ThreadPoolExecutor(max_workers=int(os.getenv("FETCH_THREADS"))) as executor:
        futures = [executor.submit(process_url, url, db_path, save_path) for url in unfetched_urls]
        for future in as_completed(futures):
            try:
                future.result()
            except Exception as e:
                print(f"An error occurred: {e}")


if __name__ == "__main__":
    main()
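
# -----------------------------------------------------------------------------
# Example .env configuration: a minimal sketch. The variable names are the ones
# read via os.getenv() above; the values shown are illustrative assumptions.
#
#   FETCH_LIMIT=100          # max URLs processed per run
#   FETCH_THREADS=4          # worker threads for the ThreadPoolExecutor
#   FETCH_COOLDOWN=2.0       # upper bound in seconds for each random sleep
#   FETCH_IGNORE_CHECK=true  # record-and-skip pages failing the readability check
#
# The script also assumes crawler.db already contains a url_list table with at
# least a url column, populated by an earlier crawling step. A hypothetical
# minimal schema:
#
#   CREATE TABLE IF NOT EXISTS url_list (url TEXT PRIMARY KEY);
# -----------------------------------------------------------------------------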