add: energy score into inferencing
This commit is contained in:
parent
1acb0a7f11
commit
de7c5990bb
@ -21,6 +21,8 @@ import tokenize from "lib/nlp/tokenize/tokenizer";
|
|||||||
import { getEmbedding, getEmbeddingLayer } from "lib/nlp/getEmbedding";
|
import { getEmbedding, getEmbeddingLayer } from "lib/nlp/getEmbedding";
|
||||||
import { loadVocab } from "lib/nlp/tokenize/loadVocab";
|
import { loadVocab } from "lib/nlp/tokenize/loadVocab";
|
||||||
import BPETokenizer from "lib/nlp/tokenize/BPEtokenizer";
|
import BPETokenizer from "lib/nlp/tokenize/BPEtokenizer";
|
||||||
|
import energyScore from "lib/nlp/energyScore";
|
||||||
|
import bytesToUnicode from "lib/nlp/tokenize/bytesToUnicode";
|
||||||
|
|
||||||
interface EmbeddingLayer {
|
interface EmbeddingLayer {
|
||||||
[key: number]: Float32Array<ArrayBufferLike>;
|
[key: number]: Float32Array<ArrayBufferLike>;
|
||||||
@ -130,16 +132,12 @@ export default function OneSearch() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function getNLUResult(query: string) {
|
async function getNLUResult(query: string) {
|
||||||
const start = new Date().getTime();
|
|
||||||
if (embeddingLayer === null || NLUsession === null || tokenizer == null) return;
|
if (embeddingLayer === null || NLUsession === null || tokenizer == null) return;
|
||||||
const tokenIds = await tokenize(query, tokenizer);
|
const tokenIds = await tokenize(bytesToUnicode(query), tokenizer);
|
||||||
console.log(new Date().getTime() - start, "ms");
|
|
||||||
const embeddings = getEmbedding(tokenIds, embeddingLayer, 64);
|
const embeddings = getEmbedding(tokenIds, embeddingLayer, 64);
|
||||||
const inputTensor = new ort.Tensor("float32", embeddings, [1, 64, 96]);
|
const inputTensor = new ort.Tensor("float32", embeddings, [1, 64, 96]);
|
||||||
const feeds = { input: inputTensor };
|
const feeds = { input: inputTensor };
|
||||||
console.log(new Date().getTime() - start, "ms");
|
|
||||||
const results = await NLUsession.run(feeds);
|
const results = await NLUsession.run(feeds);
|
||||||
console.log(new Date().getTime() - start, "ms");
|
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -171,7 +169,13 @@ export default function OneSearch() {
|
|||||||
|
|
||||||
(async function () {
|
(async function () {
|
||||||
const result = await getNLUResult(query);
|
const result = await getNLUResult(query);
|
||||||
console.log(result);
|
if (result === undefined) return;
|
||||||
|
const rawData = result.output.data;
|
||||||
|
const data: number[] = [];
|
||||||
|
for (let i=0;i<rawData.length;i++){
|
||||||
|
data.push(rawData[i] as number);
|
||||||
|
}
|
||||||
|
console.log(data, energyScore(data));
|
||||||
})();
|
})();
|
||||||
}, [query, engineName]);
|
}, [query, engineName]);
|
||||||
|
|
||||||
|
13
lib/nlp/energyScore.ts
Normal file
13
lib/nlp/energyScore.ts
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
function logsumexp(arr: number[]): number {
|
||||||
|
const maxVal = Math.max(...arr);
|
||||||
|
const sumExp = arr.reduce((sum, val) => sum + Math.exp(val - maxVal), 0);
|
||||||
|
return Math.log(sumExp) + maxVal;
|
||||||
|
}
|
||||||
|
|
||||||
|
function minusEnergyScore(logits: number[]): number {
|
||||||
|
return logsumexp(logits);
|
||||||
|
}
|
||||||
|
|
||||||
|
const energyScore = minusEnergyScore;
|
||||||
|
|
||||||
|
export default energyScore;
|
Loading…
Reference in New Issue
Block a user