path: root/iTexSnip/Utils/RobertaTokenizerFast.swift
author    Navan Chauhan <navanchauhan@gmail.com>    2024-10-21 23:54:10 -0600
committer Navan Chauhan <navanchauhan@gmail.com>    2024-10-21 23:54:10 -0600
commit    05165cc8d98ef5ffa8ee3a8ba9bf1ad5e0b5a9ab (patch)
tree      7baea43c47d6c6fd00f87de3bb870df7966460ae /iTexSnip/Utils/RobertaTokenizerFast.swift
parent    126c5a27ee98146c349303ecc7c77f6413cfe5fe (diff)
swift-format
Diffstat (limited to 'iTexSnip/Utils/RobertaTokenizerFast.swift')
-rw-r--r--  iTexSnip/Utils/RobertaTokenizerFast.swift  131
1 file changed, 69 insertions, 62 deletions
diff --git a/iTexSnip/Utils/RobertaTokenizerFast.swift b/iTexSnip/Utils/RobertaTokenizerFast.swift
index d61e8c7..888cc20 100644
--- a/iTexSnip/Utils/RobertaTokenizerFast.swift
+++ b/iTexSnip/Utils/RobertaTokenizerFast.swift
@@ -8,81 +8,88 @@
import Foundation
class RobertaTokenizerFast {
- var vocab: [String: Int] = [:]
- var idToToken: [Int: String] = [:]
- var specialTokens: [String] = []
- var unkTokenId: Int?
+ var vocab: [String: Int] = [:]
+ var idToToken: [Int: String] = [:]
+ var specialTokens: [String] = []
+ var unkTokenId: Int?
- init(vocabFile: String, tokenizerFile: String) {
- if let vocabURL = Bundle.main.url(forResource: vocabFile, withExtension: "json"),
- let vocabData = try? Data(contentsOf: vocabURL),
- let vocabDict = try? JSONSerialization.jsonObject(with: vocabData, options: []) as? [String: Int] {
- self.vocab = vocabDict
- }
+ init(vocabFile: String, tokenizerFile: String) {
+ if let vocabURL = Bundle.main.url(forResource: vocabFile, withExtension: "json"),
+ let vocabData = try? Data(contentsOf: vocabURL),
+ let vocabDict = try? JSONSerialization.jsonObject(with: vocabData, options: [])
+ as? [String: Int]
+ {
+ self.vocab = vocabDict
+ }
- if let tokenizerURL = Bundle.main.url(forResource: tokenizerFile, withExtension: "json"),
- let tokenizerData = try? Data(contentsOf: tokenizerURL),
- let tokenizerConfig = try? JSONSerialization.jsonObject(with: tokenizerData, options: []) as? [String: Any] {
- self.specialTokens = tokenizerConfig["added_tokens"] as? [String] ?? []
- }
+ if let tokenizerURL = Bundle.main.url(forResource: tokenizerFile, withExtension: "json"),
+ let tokenizerData = try? Data(contentsOf: tokenizerURL),
+ let tokenizerConfig = try? JSONSerialization.jsonObject(with: tokenizerData, options: [])
+ as? [String: Any]
+ {
+ self.specialTokens = tokenizerConfig["added_tokens"] as? [String] ?? []
+ }
- self.idToToken = vocab.reduce(into: [Int: String]()) { $0[$1.value] = $1.key }
+ self.idToToken = vocab.reduce(into: [Int: String]()) { $0[$1.value] = $1.key }
- self.unkTokenId = vocab["<unk>"]
- }
+ self.unkTokenId = vocab["<unk>"]
+ }
- func encode(text: String) -> [Int] {
- let tokens = tokenize(text)
- return tokens.map { vocab[$0] ?? unkTokenId! }
- }
+ func encode(text: String) -> [Int] {
+ let tokens = tokenize(text)
+ return tokens.map { vocab[$0] ?? unkTokenId! }
+ }
- func decode(tokenIds: [Int], skipSpecialTokens: Bool = true) -> String {
- let tokens = tokenIds.compactMap { idToToken[$0] }
- let filteredTokens = skipSpecialTokens ? tokens.filter { !specialTokens.contains($0) && $0 != "</s>" } : tokens
- return convertTokensToString(filteredTokens)
- }
+ func decode(tokenIds: [Int], skipSpecialTokens: Bool = true) -> String {
+ let tokens = tokenIds.compactMap { idToToken[$0] }
+ let filteredTokens =
+ skipSpecialTokens ? tokens.filter { !specialTokens.contains($0) && $0 != "</s>" } : tokens
+ return convertTokensToString(filteredTokens)
+ }
- private func tokenize(_ text: String) -> [String] {
- let cleanedText = cleanText(text)
- let words = cleanedText.split(separator: " ").map { String($0) }
-
- var tokens: [String] = []
- for word in words {
- tokens.append(contentsOf: bpeEncode(word))
- }
- return tokens
+ private func tokenize(_ text: String) -> [String] {
+ let cleanedText = cleanText(text)
+ let words = cleanedText.split(separator: " ").map { String($0) }
+
+ var tokens: [String] = []
+ for word in words {
+ tokens.append(contentsOf: bpeEncode(word))
}
+ return tokens
+ }
- private func bpeEncode(_ word: String) -> [String] {
- if vocab.keys.contains(word) {
- return [word]
- }
+ private func bpeEncode(_ word: String) -> [String] {
+ if vocab.keys.contains(word) {
+ return [word]
+ }
- let chars = Array(word)
- var tokens: [String] = []
- var i = 0
+ let chars = Array(word)
+ var tokens: [String] = []
+ var i = 0
- while i < chars.count {
- if i < chars.count - 1 {
- let pair = String(chars[i]) + String(chars[i + 1])
- if vocab.keys.contains(pair) {
- tokens.append(pair)
- i += 2
- continue
- }
- }
- tokens.append(String(chars[i]))
- i += 1
+ while i < chars.count {
+ if i < chars.count - 1 {
+ let pair = String(chars[i]) + String(chars[i + 1])
+ if vocab.keys.contains(pair) {
+ tokens.append(pair)
+ i += 2
+ continue
}
- return tokens
+ }
+ tokens.append(String(chars[i]))
+ i += 1
}
+ return tokens
+ }
- private func cleanText(_ text: String) -> String {
- return text.trimmingCharacters(in: .whitespacesAndNewlines)
- }
+ private func cleanText(_ text: String) -> String {
+ return text.trimmingCharacters(in: .whitespacesAndNewlines)
+ }
- private func convertTokensToString(_ tokens: [String]) -> String {
- let text = tokens.joined().replacingOccurrences(of: "Ġ", with: " ")
- return text.replacingOccurrences(of: "\\s([?.!,\'\"](?:\\s|$))", with: "$1", options: .regularExpression, range: nil).trimmingCharacters(in: .whitespaces)
- }
+ private func convertTokensToString(_ tokens: [String]) -> String {
+ let text = tokens.joined().replacingOccurrences(of: "Ġ", with: " ")
+ return text.replacingOccurrences(
+ of: "\\s([?.!,\'\"](?:\\s|$))", with: "$1", options: .regularExpression, range: nil
+ ).trimmingCharacters(in: .whitespaces)
+ }
}
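
For context, a minimal usage sketch of the tokenizer touched by this commit. The resource names "vocab" and "tokenizer" below are placeholders for whatever JSON files the app actually bundles; they are not taken from this diff.

// Minimal usage sketch (not part of the commit).
// Assumes vocab.json and tokenizer.json are bundled with the app,
// since the initializer loads both from Bundle.main.
let tokenizer = RobertaTokenizerFast(vocabFile: "vocab", tokenizerFile: "tokenizer")

// Encode a string to token ids; tokens missing from the vocab map to <unk>.
let ids = tokenizer.encode(text: "x^2 + y^2 = z^2")

// Decode back to text; special tokens such as </s> are skipped by default.
let text = tokenizer.decode(tokenIds: ids)
print(ids, text)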