diff options
Diffstat (limited to 'iTexSnip/Utils')
-rw-r--r-- | iTexSnip/Utils/KatexUtils.swift | 131 | ||||
-rw-r--r-- | iTexSnip/Utils/TexTellerModel.swift | 165 |
2 files changed, 296 insertions, 0 deletions
diff --git a/iTexSnip/Utils/KatexUtils.swift b/iTexSnip/Utils/KatexUtils.swift new file mode 100644 index 0000000..d339a71 --- /dev/null +++ b/iTexSnip/Utils/KatexUtils.swift @@ -0,0 +1,131 @@ +// +// KatexUtils.swift +// iTexSnip +// +// Created by Navan Chauhan on 10/13/24. +// + +import Foundation + +func change(_ inputStr: String, oldInst: String, newInst: String, oldSurrL: Character, oldSurrR: Character, newSurrL: String, newSurrR: String) -> String { + var result = "" + var i = 0 + let n = inputStr.count + let inputArray = Array(inputStr) // Convert string to array of characters for easier access + + while i < n { + // Get the range for the substring equivalent to oldInst + if i + oldInst.count <= n && inputStr[inputStr.index(inputStr.startIndex, offsetBy: i)..<inputStr.index(inputStr.startIndex, offsetBy: i + oldInst.count)] == oldInst { + // Check if the old_inst is followed by old_surr_l + let start = i + oldInst.count + if start < n && inputArray[start] == oldSurrL { + var count = 1 + var j = start + 1 + var escaped = false + + while j < n && count > 0 { + if inputArray[j] == "\\" && !escaped { + escaped = true + j += 1 + continue + } + + if inputArray[j] == oldSurrR && !escaped { + count -= 1 + if count == 0 { + break + } + } else if inputArray[j] == oldSurrL && !escaped { + count += 1 + } + + escaped = false + j += 1 + } + + if count == 0 { + let innerContent = String(inputArray[(start + 1)..<j]) + result += newInst + newSurrL + innerContent + newSurrR + i = j + 1 + continue + } else { + result += newInst + newSurrL + i = start + 1 + continue + } + } + } + result.append(inputArray[i]) + i += 1 + } + + if oldInst != newInst && result.contains(oldInst + String(oldSurrL)) { + return change(result, oldInst: oldInst, newInst: newInst, oldSurrL: oldSurrL, oldSurrR: oldSurrR, newSurrL: newSurrL, newSurrR: newSurrR) + } + + return result +} + + +func findSubstringPositions(_ string: String, substring: String) -> [Int] { + var positions: [Int] = [] + var searchRange = string.startIndex..<string.endIndex + + while let range = string.range(of: substring, options: [], range: searchRange) { + let position = string.distance(from: string.startIndex, to: range.lowerBound) + positions.append(position) + searchRange = range.upperBound..<string.endIndex + } + + return positions +} + +func rmDollarSurr(content: String) -> String { + let pattern = try! NSRegularExpression(pattern: "\\\\[a-zA-Z]+\\$.*?\\$|\\$.*?\\$", options: []) + var newContent = content + let matches = pattern.matches(in: content, options: [], range: NSRange(content.startIndex..<content.endIndex, in: content)) + + for match in matches.reversed() { + let matchedString = (content as NSString).substring(with: match.range) + if !matchedString.starts(with: "\\") { + let strippedMatch = matchedString.replacingOccurrences(of: "$", with: "") + newContent = newContent.replacingOccurrences(of: matchedString, with: " \(strippedMatch) ") + } + } + + return newContent +} + +func changeAll(inputStr: String, oldInst: String, newInst: String, oldSurrL: Character, oldSurrR: Character, newSurrL: String, newSurrR: String) -> String { + let positions = findSubstringPositions(inputStr, substring: oldInst + String(oldSurrL)) + var result = inputStr + + for pos in positions.reversed() { + let startIndex = result.index(result.startIndex, offsetBy: pos) + let substring = String(result[startIndex..<result.endIndex]) + let changedSubstring = change(substring, oldInst: oldInst, newInst: newInst, oldSurrL: oldSurrL, oldSurrR: oldSurrR, newSurrL: newSurrL, newSurrR: newSurrR) + result.replaceSubrange(startIndex..<result.endIndex, with: changedSubstring) + } + + return result +} + +func toKatex(formula: String) -> String { + var res = formula + // Remove mbox surrounding + res = changeAll(inputStr: res, oldInst: "\\mbox ", newInst: " ", oldSurrL: "{", oldSurrR: "}", newSurrL: "", newSurrR: "") + res = changeAll(inputStr: res, oldInst: "\\mbox", newInst: " ", oldSurrL: "{", oldSurrR: "}", newSurrL: "", newSurrR: "") + + // Additional processing similar to the Python version... + res = res.replacingOccurrences(of: "\\[", with: "") + res = res.replacingOccurrences(of: "\\]", with: "") + res = res.replacingOccurrences(of: "\\\\[?.!,\'\"](?:\\s|$)", with: "", options: .regularExpression) + + // Merge consecutive `text` + res = rmDollarSurr(content: res) + + // Remove extra spaces + res = res.replacingOccurrences(of: " +", with: " ", options: .regularExpression) + + return res.trimmingCharacters(in: .whitespacesAndNewlines) +} diff --git a/iTexSnip/Utils/TexTellerModel.swift b/iTexSnip/Utils/TexTellerModel.swift new file mode 100644 index 0000000..fb3bcd9 --- /dev/null +++ b/iTexSnip/Utils/TexTellerModel.swift @@ -0,0 +1,165 @@ +// +// TexTellerModel.swift +// iTexSnip +// +// Created by Navan Chauhan on 10/20/24. +// + +import OnnxRuntimeBindings +import AppKit + +public struct TexTellerModel { + public let encoderSession: ORTSession + public let decoderSession: ORTSession + private let tokenizer: RobertaTokenizerFast + + public init() throws { + guard let encoderModelPath = Bundle.main.path(forResource: "encoder_model", ofType: "onnx") else { + print("Encoder model not found...") + throw ModelError.encoderModelNotFound + } + guard let decoderModelPath = Bundle.main.path(forResource: "decoder_model", ofType: "onnx") else { + print("Decoder model not found...") + throw ModelError.decoderModelNotFound + } + let env = try ORTEnv(loggingLevel: .warning) + let coreMLOptions = ORTCoreMLExecutionProviderOptions() + coreMLOptions.enableOnSubgraphs = true + coreMLOptions.createMLProgram = false + let options = try ORTSessionOptions() +// try options.appendCoreMLExecutionProvider(with: coreMLOptions) + encoderSession = try ORTSession(env: env, modelPath: encoderModelPath, sessionOptions: options) + decoderSession = try ORTSession(env: env, modelPath: decoderModelPath, sessionOptions: options) + + self.tokenizer = RobertaTokenizerFast(vocabFile: "vocab", tokenizerFile: "tokenizer") + } + + public static func asyncInit() async throws -> TexTellerModel { + return try await withCheckedThrowingContinuation { continuation in + DispatchQueue.global(qos: .userInitiated).async { + do { + let model = try TexTellerModel() + continuation.resume(returning: model) + } catch { + continuation.resume(throwing: error) + } + } + } + } + + public func texIt(_ image: NSImage, rawString: Bool = false, debug: Bool = false) throws -> String { + let transformedImage = inferenceTransform(images: [image]) + if let firstTransformedImage = transformedImage.first { + let pixelValues = ciImageToFloatArray(firstTransformedImage, size: CGSize(width: FIXED_IMG_SIZE, height: FIXED_IMG_SIZE)) + if debug { + print("First few pixel inputs: \(pixelValues.prefix(10))") + } + let inputTensor = try ORTValue( + tensorData: NSMutableData( + data: Data(bytes: pixelValues, count: pixelValues.count * MemoryLayout<Float>.stride) + ), + elementType: .float, + shape: [ + 1, 1, NSNumber(value: FIXED_IMG_SIZE), NSNumber(value: FIXED_IMG_SIZE) + ] + ) + let encoderInput: [String: ORTValue] = [ + "pixel_values": inputTensor + ] + let encoderOutputNames = try self.encoderSession.outputNames() + let encoderOutputs: [String: ORTValue] = try self.encoderSession.run( + withInputs: encoderInput, + outputNames: Set(encoderOutputNames), + runOptions: nil + ) + + if (debug) { + print("Encoder output: \(encoderOutputs)") + } + + var decodedTokenIds: [Int] = [] + let startTokenId = 0 // TODO: Move to tokenizer directly? + let endTokenId = 2 + let maxDecoderLength: Int = 300 + var decoderInputIds: [Int] = [startTokenId] + let vocabSize = 15000 + + if (debug) { + let encoderHiddenStatesData = try encoderOutputs["last_hidden_state"]!.tensorData() as Data + let encoderHiddenStatesArray = encoderHiddenStatesData.withUnsafeBytes { + Array(UnsafeBufferPointer<Float>( + start: $0.baseAddress!.assumingMemoryBound(to: Float.self), + count: encoderHiddenStatesData.count / MemoryLayout<Float>.stride + )) + } + + print("First few values of encoder hidden states: \(encoderHiddenStatesArray.prefix(10))") + } + + let decoderOutputNames = try self.decoderSession.outputNames() + + for step in 0..<maxDecoderLength { + if (debug) { + print("Step \(step)") + } + + let decoderInputIdsTensor = try ORTValue( + tensorData: NSMutableData(data: Data(bytes: decoderInputIds, count: decoderInputIds.count * MemoryLayout<Int64>.stride)), + elementType: .int64, + shape: [1, NSNumber(value: decoderInputIds.count)] + ) + let decoderInputs: [String: ORTValue] = [ + "input_ids": decoderInputIdsTensor, + "encoder_hidden_states": encoderOutputs["last_hidden_state"]! + ] + let decoderOutputs: [String: ORTValue] = try self.decoderSession.run(withInputs: decoderInputs, outputNames: Set(decoderOutputNames), runOptions: nil) + let logitsTensor = decoderOutputs["logits"]! + let logitsData = try logitsTensor.tensorData() as Data + let logits = logitsData.withUnsafeBytes { + Array(UnsafeBufferPointer<Float>( + start: $0.baseAddress!.assumingMemoryBound(to: Float.self), + count: logitsData.count / MemoryLayout<Float>.stride + )) + } + let sequenceLength = decoderInputIds.count + let startIndex = (sequenceLength - 1) * vocabSize + let endIndex = startIndex + vocabSize + let lastTokenLogits = Array(logits[startIndex..<endIndex]) + let nextTokenId = lastTokenLogits.enumerated().max(by: { $0.element < $1.element})?.offset ?? 9 // TODO: Should I track if this fails + if (debug) { + print("Next token id: \(nextTokenId)") + } + if nextTokenId == endTokenId { + break + } + decodedTokenIds.append(nextTokenId) + decoderInputIds.append(nextTokenId) + } + + if rawString { + return tokenizer.decode(tokenIds: decodedTokenIds) + } + + return toKatex(formula: tokenizer.decode(tokenIds: decodedTokenIds)) + + + + + } + throw ModelError.imageError + } + + public func texIt(_ image: NSImage, rawString: Bool = false, debug: Bool = false) async throws -> String { + return try await withCheckedThrowingContinuation { continuation in + DispatchQueue.global(qos: .userInitiated).async { + do { + let result = try self.texIt(image, rawString: rawString, debug: debug) + continuation.resume(returning: result) + } catch { + continuation.resume(throwing: error) + } + } + } + } + +} |