aboutsummaryrefslogtreecommitdiff
path: root/iTexSnip/Utils
diff options
context:
space:
mode:
Diffstat (limited to 'iTexSnip/Utils')
-rw-r--r--iTexSnip/Utils/KatexUtils.swift131
-rw-r--r--iTexSnip/Utils/TexTellerModel.swift165
2 files changed, 296 insertions, 0 deletions
diff --git a/iTexSnip/Utils/KatexUtils.swift b/iTexSnip/Utils/KatexUtils.swift
new file mode 100644
index 0000000..d339a71
--- /dev/null
+++ b/iTexSnip/Utils/KatexUtils.swift
@@ -0,0 +1,131 @@
+//
+// KatexUtils.swift
+// iTexSnip
+//
+// Created by Navan Chauhan on 10/13/24.
+//
+
+import Foundation
+
+func change(_ inputStr: String, oldInst: String, newInst: String, oldSurrL: Character, oldSurrR: Character, newSurrL: String, newSurrR: String) -> String {
+ var result = ""
+ var i = 0
+ let n = inputStr.count
+ let inputArray = Array(inputStr) // Convert string to array of characters for easier access
+
+ while i < n {
+ // Get the range for the substring equivalent to oldInst
+ if i + oldInst.count <= n && inputStr[inputStr.index(inputStr.startIndex, offsetBy: i)..<inputStr.index(inputStr.startIndex, offsetBy: i + oldInst.count)] == oldInst {
+ // Check if the old_inst is followed by old_surr_l
+ let start = i + oldInst.count
+ if start < n && inputArray[start] == oldSurrL {
+ var count = 1
+ var j = start + 1
+ var escaped = false
+
+ while j < n && count > 0 {
+ if inputArray[j] == "\\" && !escaped {
+ escaped = true
+ j += 1
+ continue
+ }
+
+ if inputArray[j] == oldSurrR && !escaped {
+ count -= 1
+ if count == 0 {
+ break
+ }
+ } else if inputArray[j] == oldSurrL && !escaped {
+ count += 1
+ }
+
+ escaped = false
+ j += 1
+ }
+
+ if count == 0 {
+ let innerContent = String(inputArray[(start + 1)..<j])
+ result += newInst + newSurrL + innerContent + newSurrR
+ i = j + 1
+ continue
+ } else {
+ result += newInst + newSurrL
+ i = start + 1
+ continue
+ }
+ }
+ }
+ result.append(inputArray[i])
+ i += 1
+ }
+
+ if oldInst != newInst && result.contains(oldInst + String(oldSurrL)) {
+ return change(result, oldInst: oldInst, newInst: newInst, oldSurrL: oldSurrL, oldSurrR: oldSurrR, newSurrL: newSurrL, newSurrR: newSurrR)
+ }
+
+ return result
+}
+
+
+func findSubstringPositions(_ string: String, substring: String) -> [Int] {
+ var positions: [Int] = []
+ var searchRange = string.startIndex..<string.endIndex
+
+ while let range = string.range(of: substring, options: [], range: searchRange) {
+ let position = string.distance(from: string.startIndex, to: range.lowerBound)
+ positions.append(position)
+ searchRange = range.upperBound..<string.endIndex
+ }
+
+ return positions
+}
+
+func rmDollarSurr(content: String) -> String {
+ let pattern = try! NSRegularExpression(pattern: "\\\\[a-zA-Z]+\\$.*?\\$|\\$.*?\\$", options: [])
+ var newContent = content
+ let matches = pattern.matches(in: content, options: [], range: NSRange(content.startIndex..<content.endIndex, in: content))
+
+ for match in matches.reversed() {
+ let matchedString = (content as NSString).substring(with: match.range)
+ if !matchedString.starts(with: "\\") {
+ let strippedMatch = matchedString.replacingOccurrences(of: "$", with: "")
+ newContent = newContent.replacingOccurrences(of: matchedString, with: " \(strippedMatch) ")
+ }
+ }
+
+ return newContent
+}
+
+func changeAll(inputStr: String, oldInst: String, newInst: String, oldSurrL: Character, oldSurrR: Character, newSurrL: String, newSurrR: String) -> String {
+ let positions = findSubstringPositions(inputStr, substring: oldInst + String(oldSurrL))
+ var result = inputStr
+
+ for pos in positions.reversed() {
+ let startIndex = result.index(result.startIndex, offsetBy: pos)
+ let substring = String(result[startIndex..<result.endIndex])
+ let changedSubstring = change(substring, oldInst: oldInst, newInst: newInst, oldSurrL: oldSurrL, oldSurrR: oldSurrR, newSurrL: newSurrL, newSurrR: newSurrR)
+ result.replaceSubrange(startIndex..<result.endIndex, with: changedSubstring)
+ }
+
+ return result
+}
+
+func toKatex(formula: String) -> String {
+ var res = formula
+ // Remove mbox surrounding
+ res = changeAll(inputStr: res, oldInst: "\\mbox ", newInst: " ", oldSurrL: "{", oldSurrR: "}", newSurrL: "", newSurrR: "")
+ res = changeAll(inputStr: res, oldInst: "\\mbox", newInst: " ", oldSurrL: "{", oldSurrR: "}", newSurrL: "", newSurrR: "")
+
+ // Additional processing similar to the Python version...
+ res = res.replacingOccurrences(of: "\\[", with: "")
+ res = res.replacingOccurrences(of: "\\]", with: "")
+ res = res.replacingOccurrences(of: "\\\\[?.!,\'\"](?:\\s|$)", with: "", options: .regularExpression)
+
+ // Merge consecutive `text`
+ res = rmDollarSurr(content: res)
+
+ // Remove extra spaces
+ res = res.replacingOccurrences(of: " +", with: " ", options: .regularExpression)
+
+ return res.trimmingCharacters(in: .whitespacesAndNewlines)
+}
diff --git a/iTexSnip/Utils/TexTellerModel.swift b/iTexSnip/Utils/TexTellerModel.swift
new file mode 100644
index 0000000..fb3bcd9
--- /dev/null
+++ b/iTexSnip/Utils/TexTellerModel.swift
@@ -0,0 +1,165 @@
+//
+// TexTellerModel.swift
+// iTexSnip
+//
+// Created by Navan Chauhan on 10/20/24.
+//
+
+import OnnxRuntimeBindings
+import AppKit
+
+public struct TexTellerModel {
+ public let encoderSession: ORTSession
+ public let decoderSession: ORTSession
+ private let tokenizer: RobertaTokenizerFast
+
+ public init() throws {
+ guard let encoderModelPath = Bundle.main.path(forResource: "encoder_model", ofType: "onnx") else {
+ print("Encoder model not found...")
+ throw ModelError.encoderModelNotFound
+ }
+ guard let decoderModelPath = Bundle.main.path(forResource: "decoder_model", ofType: "onnx") else {
+ print("Decoder model not found...")
+ throw ModelError.decoderModelNotFound
+ }
+ let env = try ORTEnv(loggingLevel: .warning)
+ let coreMLOptions = ORTCoreMLExecutionProviderOptions()
+ coreMLOptions.enableOnSubgraphs = true
+ coreMLOptions.createMLProgram = false
+ let options = try ORTSessionOptions()
+// try options.appendCoreMLExecutionProvider(with: coreMLOptions)
+ encoderSession = try ORTSession(env: env, modelPath: encoderModelPath, sessionOptions: options)
+ decoderSession = try ORTSession(env: env, modelPath: decoderModelPath, sessionOptions: options)
+
+ self.tokenizer = RobertaTokenizerFast(vocabFile: "vocab", tokenizerFile: "tokenizer")
+ }
+
+ public static func asyncInit() async throws -> TexTellerModel {
+ return try await withCheckedThrowingContinuation { continuation in
+ DispatchQueue.global(qos: .userInitiated).async {
+ do {
+ let model = try TexTellerModel()
+ continuation.resume(returning: model)
+ } catch {
+ continuation.resume(throwing: error)
+ }
+ }
+ }
+ }
+
+ public func texIt(_ image: NSImage, rawString: Bool = false, debug: Bool = false) throws -> String {
+ let transformedImage = inferenceTransform(images: [image])
+ if let firstTransformedImage = transformedImage.first {
+ let pixelValues = ciImageToFloatArray(firstTransformedImage, size: CGSize(width: FIXED_IMG_SIZE, height: FIXED_IMG_SIZE))
+ if debug {
+ print("First few pixel inputs: \(pixelValues.prefix(10))")
+ }
+ let inputTensor = try ORTValue(
+ tensorData: NSMutableData(
+ data: Data(bytes: pixelValues, count: pixelValues.count * MemoryLayout<Float>.stride)
+ ),
+ elementType: .float,
+ shape: [
+ 1, 1, NSNumber(value: FIXED_IMG_SIZE), NSNumber(value: FIXED_IMG_SIZE)
+ ]
+ )
+ let encoderInput: [String: ORTValue] = [
+ "pixel_values": inputTensor
+ ]
+ let encoderOutputNames = try self.encoderSession.outputNames()
+ let encoderOutputs: [String: ORTValue] = try self.encoderSession.run(
+ withInputs: encoderInput,
+ outputNames: Set(encoderOutputNames),
+ runOptions: nil
+ )
+
+ if (debug) {
+ print("Encoder output: \(encoderOutputs)")
+ }
+
+ var decodedTokenIds: [Int] = []
+ let startTokenId = 0 // TODO: Move to tokenizer directly?
+ let endTokenId = 2
+ let maxDecoderLength: Int = 300
+ var decoderInputIds: [Int] = [startTokenId]
+ let vocabSize = 15000
+
+ if (debug) {
+ let encoderHiddenStatesData = try encoderOutputs["last_hidden_state"]!.tensorData() as Data
+ let encoderHiddenStatesArray = encoderHiddenStatesData.withUnsafeBytes {
+ Array(UnsafeBufferPointer<Float>(
+ start: $0.baseAddress!.assumingMemoryBound(to: Float.self),
+ count: encoderHiddenStatesData.count / MemoryLayout<Float>.stride
+ ))
+ }
+
+ print("First few values of encoder hidden states: \(encoderHiddenStatesArray.prefix(10))")
+ }
+
+ let decoderOutputNames = try self.decoderSession.outputNames()
+
+ for step in 0..<maxDecoderLength {
+ if (debug) {
+ print("Step \(step)")
+ }
+
+ let decoderInputIdsTensor = try ORTValue(
+ tensorData: NSMutableData(data: Data(bytes: decoderInputIds, count: decoderInputIds.count * MemoryLayout<Int64>.stride)),
+ elementType: .int64,
+ shape: [1, NSNumber(value: decoderInputIds.count)]
+ )
+ let decoderInputs: [String: ORTValue] = [
+ "input_ids": decoderInputIdsTensor,
+ "encoder_hidden_states": encoderOutputs["last_hidden_state"]!
+ ]
+ let decoderOutputs: [String: ORTValue] = try self.decoderSession.run(withInputs: decoderInputs, outputNames: Set(decoderOutputNames), runOptions: nil)
+ let logitsTensor = decoderOutputs["logits"]!
+ let logitsData = try logitsTensor.tensorData() as Data
+ let logits = logitsData.withUnsafeBytes {
+ Array(UnsafeBufferPointer<Float>(
+ start: $0.baseAddress!.assumingMemoryBound(to: Float.self),
+ count: logitsData.count / MemoryLayout<Float>.stride
+ ))
+ }
+ let sequenceLength = decoderInputIds.count
+ let startIndex = (sequenceLength - 1) * vocabSize
+ let endIndex = startIndex + vocabSize
+ let lastTokenLogits = Array(logits[startIndex..<endIndex])
+ let nextTokenId = lastTokenLogits.enumerated().max(by: { $0.element < $1.element})?.offset ?? 9 // TODO: Should I track if this fails
+ if (debug) {
+ print("Next token id: \(nextTokenId)")
+ }
+ if nextTokenId == endTokenId {
+ break
+ }
+ decodedTokenIds.append(nextTokenId)
+ decoderInputIds.append(nextTokenId)
+ }
+
+ if rawString {
+ return tokenizer.decode(tokenIds: decodedTokenIds)
+ }
+
+ return toKatex(formula: tokenizer.decode(tokenIds: decodedTokenIds))
+
+
+
+
+ }
+ throw ModelError.imageError
+ }
+
+ public func texIt(_ image: NSImage, rawString: Bool = false, debug: Bool = false) async throws -> String {
+ return try await withCheckedThrowingContinuation { continuation in
+ DispatchQueue.global(qos: .userInitiated).async {
+ do {
+ let result = try self.texIt(image, rawString: rawString, debug: debug)
+ continuation.resume(returning: result)
+ } catch {
+ continuation.resume(throwing: error)
+ }
+ }
+ }
+ }
+
+}