path: root/iTexSnip/Utils/RobertaTokenizerFast.swift
author     Navan Chauhan <navanchauhan@gmail.com>  2024-10-21 23:51:06 -0600
committer  Navan Chauhan <navanchauhan@gmail.com>  2024-10-21 23:51:06 -0600
commit     05cf4dd46aebfaa7812989c88eb91d303a43d3e7 (patch)
tree       6cf0542910c72f8135345facbcf39a8eb409d97e /iTexSnip/Utils/RobertaTokenizerFast.swift
parent     c343140bcf9862b7b4b0d465b67e51eb42f45008 (diff)
bruh
Diffstat (limited to 'iTexSnip/Utils/RobertaTokenizerFast.swift')
-rw-r--r--  iTexSnip/Utils/RobertaTokenizerFast.swift  | 88
1 file changed, 88 insertions(+), 0 deletions(-)
diff --git a/iTexSnip/Utils/RobertaTokenizerFast.swift b/iTexSnip/Utils/RobertaTokenizerFast.swift
new file mode 100644
index 0000000..d61e8c7
--- /dev/null
+++ b/iTexSnip/Utils/RobertaTokenizerFast.swift
@@ -0,0 +1,88 @@
+//
+// RobertaTokenizerFast.swift
+// iTexSnip
+//
+// Created by Navan Chauhan on 10/13/24.
+//
+
+import Foundation
+
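+// Minimal Roberta-style tokenizer that loads its vocabulary and tokenizer
+// config from JSON resources in the app bundle. This is a simplified greedy
+// encoder rather than a full port of the Hugging Face fast tokenizer.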
+class RobertaTokenizerFast {
+  var vocab: [String: Int] = [:]
+  var idToToken: [Int: String] = [:]
+  var specialTokens: [String] = []
+  var unkTokenId: Int?
+
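+  // Load the vocabulary and tokenizer config from the main bundle. Both are
+  // expected to be bundled JSON resources (e.g. vocab.json and tokenizer.json)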
+  init(vocabFile: String, tokenizerFile: String) {
+    if let vocabURL = Bundle.main.url(forResource: vocabFile, withExtension: "json"),
+      let vocabData = try? Data(contentsOf: vocabURL),
+      let vocabDict = try? JSONSerialization.jsonObject(with: vocabData, options: []) as? [String: Int] {
+      self.vocab = vocabDict
+    }
+
+    if let tokenizerURL = Bundle.main.url(forResource: tokenizerFile, withExtension: "json"),
+      let tokenizerData = try? Data(contentsOf: tokenizerURL),
+      let tokenizerConfig = try? JSONSerialization.jsonObject(with: tokenizerData, options: []) as? [String: Any] {
+      // "added_tokens" in a Hugging Face tokenizer.json is an array of objects,
+      // each holding the token string under its "content" key, not a [String]
+      let addedTokens = tokenizerConfig["added_tokens"] as? [[String: Any]] ?? []
+      self.specialTokens = addedTokens.compactMap { $0["content"] as? String }
+    }
+
+    // Build the reverse mapping for decoding
+    self.idToToken = vocab.reduce(into: [Int: String]()) { $0[$1.value] = $1.key }
+
+    self.unkTokenId = vocab["<unk>"]
+  }
+
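+  // Encode text into token ids, mapping out-of-vocabulary tokens to <unk>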
+  func encode(text: String) -> [Int] {
+    let tokens = tokenize(text)
+    // Avoid force-unwrapping unkTokenId in case the vocab has no <unk> entry;
+    // fall back to id 0 as a last resort
+    return tokens.map { vocab[$0] ?? unkTokenId ?? 0 }
+  }
+
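+  // Decode token ids back into text, optionally dropping special tokens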
+  func decode(tokenIds: [Int], skipSpecialTokens: Bool = true) -> String {
+    let tokens = tokenIds.compactMap { idToToken[$0] }
+    let filteredTokens =
+      skipSpecialTokens
+      ? tokens.filter { !specialTokens.contains($0) && $0 != "</s>" }
+      : tokens
+    return convertTokensToString(filteredTokens)
+  }
+
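+  // Whitespace-split, then BPE-encode each word. Note that, unlike the
+  // reference Roberta tokenizer, this does not prepend the "Ġ" space marker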
+  private func tokenize(_ text: String) -> [String] {
+    let cleanedText = cleanText(text)
+    let words = cleanedText.split(separator: " ").map { String($0) }
+
+    var tokens: [String] = []
+    for word in words {
+      tokens.append(contentsOf: bpeEncode(word))
+    }
+    return tokens
+  }
+
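+  // Simplified single-pass greedy merge: take an adjacent character pair when
+  // it exists in the vocab, otherwise emit single characters. A full BPE
+  // implementation would instead merge iteratively by learned merge priority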
+  private func bpeEncode(_ word: String) -> [String] {
+    if vocab.keys.contains(word) {
+      return [word]
+    }
+
+    let chars = Array(word)
+    var tokens: [String] = []
+    var i = 0
+
+    while i < chars.count {
+      if i < chars.count - 1 {
+        let pair = String(chars[i]) + String(chars[i + 1])
+        if vocab.keys.contains(pair) {
+          tokens.append(pair)
+          i += 2
+          continue
+        }
+      }
+      tokens.append(String(chars[i]))
+      i += 1
+    }
+    return tokens
+  }
+
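+  // Trim leading and trailing whitespace before tokenizing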
+  private func cleanText(_ text: String) -> String {
+    return text.trimmingCharacters(in: .whitespacesAndNewlines)
+  }
+
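+  // Join tokens, map the byte-level BPE space marker "Ġ" back to spaces, and
+  // drop the stray space left before sentence-final punctuation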
+  private func convertTokensToString(_ tokens: [String]) -> String {
+    let text = tokens.joined().replacingOccurrences(of: "Ġ", with: " ")
+    return text
+      .replacingOccurrences(
+        of: "\\s([?.!,\'\"](?:\\s|$))", with: "$1", options: .regularExpression
+      )
+      .trimmingCharacters(in: .whitespaces)
+  }
+}
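
A minimal usage sketch (the resource names "vocab" and "tokenizer" are
illustrative, not confirmed by this commit; the initializer appends the
.json extension itself):

  let tokenizer = RobertaTokenizerFast(vocabFile: "vocab", tokenizerFile: "tokenizer")
  let ids = tokenizer.encode(text: "Hello world!")
  let text = tokenizer.decode(tokenIds: ids)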