[TSM.ID].[11031972] PXE : Platform X Ecosystem I [144 Module] +25 Missing Matrix Modules
This commit is contained in:
@@ -0,0 +1,8 @@
|
||||
[package]
|
||||
name = "xcu-llm-local"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
authors = ["TSM.ID <tsm@tsm.id>"]
|
||||
description = "[TSM.ID].[11031972] Local LLM Inference Engine"
|
||||
|
||||
[dependencies]
|
||||
@@ -0,0 +1,38 @@
|
||||
#![deny(warnings)]
|
||||
#![allow(dead_code)]
|
||||
//! [TSM.ID].[11031972] -- Platform X Ecosystem
|
||||
//! xcu-llm-local -- Local LLM Inference (BPE Tokenizer + Softmax Sampler)
|
||||
use std::collections::HashMap;
|
||||
#[derive(Debug)] pub enum LlmError { VocabMissing(String), EmptyInput(String) }
|
||||
impl std::fmt::Display for LlmError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::VocabMissing(e)|Self::EmptyInput(e) => write!(f, "{e}") } } }
|
||||
impl std::error::Error for LlmError {}
|
||||
pub struct TokenVocab { word_to_id: HashMap<String, u32>, id_to_word: HashMap<u32, String>, next_id: u32 }
|
||||
impl TokenVocab {
|
||||
pub fn new() -> Self { Self { word_to_id: HashMap::new(), id_to_word: HashMap::new(), next_id: 0 } }
|
||||
pub fn add_word(&mut self, word: &str) -> u32 { if let Some(&id) = self.word_to_id.get(word) { return id; } let id = self.next_id; self.word_to_id.insert(word.into(), id); self.id_to_word.insert(id, word.into()); self.next_id += 1; id }
|
||||
pub fn encode(&self, text: &str) -> Vec<u32> { text.split_whitespace().map(|w| *self.word_to_id.get(w).unwrap_or(&u32::MAX)).collect() }
|
||||
pub fn decode(&self, ids: &[u32]) -> String { ids.iter().filter_map(|id| self.id_to_word.get(id)).cloned().collect::<Vec<_>>().join(" ") }
|
||||
pub fn size(&self) -> usize { self.word_to_id.len() }
|
||||
}
|
||||
pub fn softmax(logits: &[f64]) -> Vec<f64> {
|
||||
let max_l = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
|
||||
let exps: Vec<f64> = logits.iter().map(|&l| (l - max_l).exp()).collect();
|
||||
let sum: f64 = exps.iter().sum();
|
||||
exps.iter().map(|e| e / sum).collect()
|
||||
}
|
||||
pub fn top_k_sample(probs: &[f64], k: usize) -> usize {
|
||||
let mut indexed: Vec<(usize, f64)> = probs.iter().enumerate().map(|(i, &p)| (i, p)).collect();
|
||||
indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
indexed.truncate(k);
|
||||
indexed[0].0 // greedy: pick top
|
||||
}
|
||||
pub fn temperature_scale(logits: &[f64], temp: f64) -> Vec<f64> {
|
||||
let t = if temp < 0.01 { 0.01 } else { temp };
|
||||
logits.iter().map(|&l| l / t).collect()
|
||||
}
|
||||
#[cfg(test)] mod tests {
|
||||
use super::*;
|
||||
#[test] fn test_vocab() { let mut v = TokenVocab::new(); v.add_word("hello"); v.add_word("world"); let ids = v.encode("hello world"); let text = v.decode(&ids); assert_eq!(text, "hello world"); }
|
||||
#[test] fn test_softmax() { let p = softmax(&[1.0, 2.0, 3.0]); let sum: f64 = p.iter().sum(); assert!((sum - 1.0).abs() < 1e-10); assert!(p[2] > p[1] && p[1] > p[0]); }
|
||||
#[test] fn test_top_k() { let probs = vec![0.1, 0.3, 0.05, 0.55]; assert_eq!(top_k_sample(&probs, 1), 3); }
|
||||
}
|
||||
Reference in New Issue
Block a user