From de7c5990bb0b9baa042602ff766d085a85378b4f Mon Sep 17 00:00:00 2001 From: alikia2x Date: Sun, 13 Oct 2024 18:33:49 +0800 Subject: [PATCH] add: energy score into inferencing --- components/onesearch/onesearch.tsx | 16 ++++++++++------ lib/nlp/energyScore.ts | 13 +++++++++++++ 2 files changed, 23 insertions(+), 6 deletions(-) create mode 100644 lib/nlp/energyScore.ts diff --git a/components/onesearch/onesearch.tsx b/components/onesearch/onesearch.tsx index 14b09be..66e8b9a 100644 --- a/components/onesearch/onesearch.tsx +++ b/components/onesearch/onesearch.tsx @@ -21,6 +21,8 @@ import tokenize from "lib/nlp/tokenize/tokenizer"; import { getEmbedding, getEmbeddingLayer } from "lib/nlp/getEmbedding"; import { loadVocab } from "lib/nlp/tokenize/loadVocab"; import BPETokenizer from "lib/nlp/tokenize/BPEtokenizer"; +import energyScore from "lib/nlp/energyScore"; +import bytesToUnicode from "lib/nlp/tokenize/bytesToUnicode"; interface EmbeddingLayer { [key: number]: Float32Array; @@ -130,16 +132,12 @@ export default function OneSearch() { } async function getNLUResult(query: string) { - const start = new Date().getTime(); if (embeddingLayer === null || NLUsession === null || tokenizer == null) return; - const tokenIds = await tokenize(query, tokenizer); - console.log(new Date().getTime() - start, "ms"); + const tokenIds = await tokenize(bytesToUnicode(query), tokenizer); const embeddings = getEmbedding(tokenIds, embeddingLayer, 64); const inputTensor = new ort.Tensor("float32", embeddings, [1, 64, 96]); const feeds = { input: inputTensor }; - console.log(new Date().getTime() - start, "ms"); const results = await NLUsession.run(feeds); - console.log(new Date().getTime() - start, "ms"); return results; } @@ -171,7 +169,13 @@ export default function OneSearch() { (async function () { const result = await getNLUResult(query); - console.log(result); + if (result === undefined) return; + const rawData = result.output.data; + const data: number[] = []; + for (let i=0;i sum + Math.exp(val - maxVal), 0); + return Math.log(sumExp) + maxVal; +} + +function minusEnergyScore(logits: number[]): number { + return logsumexp(logits); +} + +const energyScore = minusEnergyScore; + +export default energyScore;