author | Navan Chauhan <navanchauhan@gmail.com> | 2024-10-21 23:54:10 -0600
---|---|---
committer | Navan Chauhan <navanchauhan@gmail.com> | 2024-10-21 23:54:10 -0600
commit | 05165cc8d98ef5ffa8ee3a8ba9bf1ad5e0b5a9ab (patch)
tree | 7baea43c47d6c6fd00f87de3bb870df7966460ae /iTexSnip/Utils/RobertaTokenizerFast.swift
parent | 126c5a27ee98146c349303ecc7c77f6413cfe5fe (diff)
swift-format
Diffstat (limited to 'iTexSnip/Utils/RobertaTokenizerFast.swift')
-rw-r--r-- | iTexSnip/Utils/RobertaTokenizerFast.swift | 131
1 file changed, 69 insertions, 62 deletions
diff --git a/iTexSnip/Utils/RobertaTokenizerFast.swift b/iTexSnip/Utils/RobertaTokenizerFast.swift
index d61e8c7..888cc20 100644
--- a/iTexSnip/Utils/RobertaTokenizerFast.swift
+++ b/iTexSnip/Utils/RobertaTokenizerFast.swift
@@ -8,81 +8,88 @@
 import Foundation
 
 class RobertaTokenizerFast {
-    var vocab: [String: Int] = [:]
-    var idToToken: [Int: String] = [:]
-    var specialTokens: [String] = []
-    var unkTokenId: Int?
+  var vocab: [String: Int] = [:]
+  var idToToken: [Int: String] = [:]
+  var specialTokens: [String] = []
+  var unkTokenId: Int?
 
-    init(vocabFile: String, tokenizerFile: String) {
-        if let vocabURL = Bundle.main.url(forResource: vocabFile, withExtension: "json"),
-           let vocabData = try? Data(contentsOf: vocabURL),
-           let vocabDict = try? JSONSerialization.jsonObject(with: vocabData, options: []) as? [String: Int] {
-            self.vocab = vocabDict
-        }
+  init(vocabFile: String, tokenizerFile: String) {
+    if let vocabURL = Bundle.main.url(forResource: vocabFile, withExtension: "json"),
+      let vocabData = try? Data(contentsOf: vocabURL),
+      let vocabDict = try? JSONSerialization.jsonObject(with: vocabData, options: [])
+        as? [String: Int]
+    {
+      self.vocab = vocabDict
+    }
 
-        if let tokenizerURL = Bundle.main.url(forResource: tokenizerFile, withExtension: "json"),
-           let tokenizerData = try? Data(contentsOf: tokenizerURL),
-           let tokenizerConfig = try? JSONSerialization.jsonObject(with: tokenizerData, options: []) as? [String: Any] {
-            self.specialTokens = tokenizerConfig["added_tokens"] as? [String] ?? []
-        }
+    if let tokenizerURL = Bundle.main.url(forResource: tokenizerFile, withExtension: "json"),
+      let tokenizerData = try? Data(contentsOf: tokenizerURL),
+      let tokenizerConfig = try? JSONSerialization.jsonObject(with: tokenizerData, options: [])
+        as? [String: Any]
+    {
+      self.specialTokens = tokenizerConfig["added_tokens"] as? [String] ?? []
+    }
 
-        self.idToToken = vocab.reduce(into: [Int: String]()) { $0[$1.value] = $1.key }
+    self.idToToken = vocab.reduce(into: [Int: String]()) { $0[$1.value] = $1.key }
 
-        self.unkTokenId = vocab["<unk>"]
-    }
+    self.unkTokenId = vocab["<unk>"]
+  }
 
-    func encode(text: String) -> [Int] {
-        let tokens = tokenize(text)
-        return tokens.map { vocab[$0] ?? unkTokenId! }
-    }
+  func encode(text: String) -> [Int] {
+    let tokens = tokenize(text)
+    return tokens.map { vocab[$0] ?? unkTokenId! }
+  }
 
-    func decode(tokenIds: [Int], skipSpecialTokens: Bool = true) -> String {
-        let tokens = tokenIds.compactMap { idToToken[$0] }
-        let filteredTokens = skipSpecialTokens ? tokens.filter { !specialTokens.contains($0) && $0 != "</s>" } : tokens
-        return convertTokensToString(filteredTokens)
-    }
+  func decode(tokenIds: [Int], skipSpecialTokens: Bool = true) -> String {
+    let tokens = tokenIds.compactMap { idToToken[$0] }
+    let filteredTokens =
+      skipSpecialTokens ? tokens.filter { !specialTokens.contains($0) && $0 != "</s>" } : tokens
+    return convertTokensToString(filteredTokens)
+  }
 
-    private func tokenize(_ text: String) -> [String] {
-        let cleanedText = cleanText(text)
-        let words = cleanedText.split(separator: " ").map { String($0) }
-
-        var tokens: [String] = []
-        for word in words {
-            tokens.append(contentsOf: bpeEncode(word))
-        }
-        return tokens
+  private func tokenize(_ text: String) -> [String] {
+    let cleanedText = cleanText(text)
+    let words = cleanedText.split(separator: " ").map { String($0) }
+
+    var tokens: [String] = []
+    for word in words {
+      tokens.append(contentsOf: bpeEncode(word))
     }
+    return tokens
+  }
 
-    private func bpeEncode(_ word: String) -> [String] {
-        if vocab.keys.contains(word) {
-            return [word]
-        }
+  private func bpeEncode(_ word: String) -> [String] {
+    if vocab.keys.contains(word) {
+      return [word]
+    }
 
-        let chars = Array(word)
-        var tokens: [String] = []
-        var i = 0
+    let chars = Array(word)
+    var tokens: [String] = []
+    var i = 0
 
-        while i < chars.count {
-            if i < chars.count - 1 {
-                let pair = String(chars[i]) + String(chars[i + 1])
-                if vocab.keys.contains(pair) {
-                    tokens.append(pair)
-                    i += 2
-                    continue
-                }
-            }
-            tokens.append(String(chars[i]))
-            i += 1
+    while i < chars.count {
+      if i < chars.count - 1 {
+        let pair = String(chars[i]) + String(chars[i + 1])
+        if vocab.keys.contains(pair) {
+          tokens.append(pair)
+          i += 2
+          continue
         }
-        return tokens
+      }
+      tokens.append(String(chars[i]))
+      i += 1
     }
+    return tokens
+  }
 
-    private func cleanText(_ text: String) -> String {
-        return text.trimmingCharacters(in: .whitespacesAndNewlines)
-    }
+  private func cleanText(_ text: String) -> String {
+    return text.trimmingCharacters(in: .whitespacesAndNewlines)
+  }
 
-    private func convertTokensToString(_ tokens: [String]) -> String {
-        let text = tokens.joined().replacingOccurrences(of: "Ġ", with: " ")
-        return text.replacingOccurrences(of: "\\s([?.!,\'\"](?:\\s|$))", with: "$1", options: .regularExpression, range: nil).trimmingCharacters(in: .whitespaces)
-    }
+  private func convertTokensToString(_ tokens: [String]) -> String {
+    let text = tokens.joined().replacingOccurrences(of: "Ġ", with: " ")
+    return text.replacingOccurrences(
+      of: "\\s([?.!,\'\"](?:\\s|$))", with: "$1", options: .regularExpression, range: nil
+    ).trimmingCharacters(in: .whitespaces)
+  }
 }
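
For context, the class reformatted above is a small Swift port of a RoBERTa-style byte-level BPE tokenizer. A minimal usage sketch follows; the resource names "vocab" and "tokenizer" are hypothetical placeholders, since the diff only shows that init takes the base names of two bundled JSON files.

import Foundation

// Hypothetical resource names -- substitute the JSON files actually
// bundled with the app.
let tokenizer = RobertaTokenizerFast(vocabFile: "vocab", tokenizerFile: "tokenizer")

// encode(text:) splits on spaces, greedily merges adjacent character
// pairs found in the vocab, and maps anything unknown to <unk>.
let ids = tokenizer.encode(text: "x = 1 ?")

// decode(tokenIds:) drops special tokens (and "</s>") by default, then
// restores spacing from the "Ġ" word-boundary markers.
let roundTrip = tokenizer.decode(tokenIds: ids)
print(ids, roundTrip)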
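The detokenization step in convertTokensToString can also be exercised on its own. A standalone sketch of the same two Foundation calls, with made-up sample tokens:

import Foundation

// "Ġ" prefixes mark word boundaries in byte-level BPE vocabularies, so
// joining the tokens and swapping "Ġ" for a space recovers rough text.
let tokens = ["Ġx", "Ġ=", "Ġ1", "Ġ?"]
let joined = tokens.joined().replacingOccurrences(of: "Ġ", with: " ")

// The regex deletes the space before trailing punctuation:
// " x = 1 ?" becomes " x = 1?" before the final trim.
let text = joined.replacingOccurrences(
  of: "\\s([?.!,'\"](?:\\s|$))", with: "$1", options: .regularExpression
).trimmingCharacters(in: .whitespaces)
print(text)  // x = 1?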