Diffstat (limited to 'iTexSnip/Utils/TexTellerModel.swift')
-rw-r--r--  iTexSnip/Utils/TexTellerModel.swift  311
1 file changed, 160 insertions(+), 151 deletions(-)
diff --git a/iTexSnip/Utils/TexTellerModel.swift b/iTexSnip/Utils/TexTellerModel.swift
index 2f71919..0104bfb 100644
--- a/iTexSnip/Utils/TexTellerModel.swift
+++ b/iTexSnip/Utils/TexTellerModel.swift
@@ -5,167 +5,176 @@
// Created by Navan Chauhan on 10/20/24.
//
-import OnnxRuntimeBindings
import AppKit
+import OnnxRuntimeBindings
public enum ModelError: Error {
- case encoderModelNotFound
- case decoderModelNotFound
- case imageError
+ case encoderModelNotFound
+ case decoderModelNotFound
+ case imageError
}
public struct TexTellerModel {
- public let encoderSession: ORTSession
- public let decoderSession: ORTSession
- private let tokenizer: RobertaTokenizerFast
-
- public init() throws {
- guard let encoderModelPath = Bundle.main.path(forResource: "encoder_model", ofType: "onnx") else {
- print("Encoder model not found...")
- throw ModelError.encoderModelNotFound
- }
- guard let decoderModelPath = Bundle.main.path(forResource: "decoder_model", ofType: "onnx") else {
- print("Decoder model not found...")
- throw ModelError.decoderModelNotFound
- }
- let env = try ORTEnv(loggingLevel: .warning)
- let coreMLOptions = ORTCoreMLExecutionProviderOptions()
- coreMLOptions.enableOnSubgraphs = true
- coreMLOptions.createMLProgram = false
- let options = try ORTSessionOptions()
-// try options.appendCoreMLExecutionProvider(with: coreMLOptions)
- encoderSession = try ORTSession(env: env, modelPath: encoderModelPath, sessionOptions: options)
- decoderSession = try ORTSession(env: env, modelPath: decoderModelPath, sessionOptions: options)
-
- self.tokenizer = RobertaTokenizerFast(vocabFile: "vocab", tokenizerFile: "tokenizer")
+ public let encoderSession: ORTSession
+ public let decoderSession: ORTSession
+ private let tokenizer: RobertaTokenizerFast
+
+ public init() throws {
+ guard let encoderModelPath = Bundle.main.path(forResource: "encoder_model", ofType: "onnx")
+ else {
+ print("Encoder model not found...")
+ throw ModelError.encoderModelNotFound
+ }
+ guard let decoderModelPath = Bundle.main.path(forResource: "decoder_model", ofType: "onnx")
+ else {
+ print("Decoder model not found...")
+ throw ModelError.decoderModelNotFound
}
-
- public static func asyncInit() async throws -> TexTellerModel {
- return try await withCheckedThrowingContinuation { continuation in
- DispatchQueue.global(qos: .userInitiated).async {
- do {
- let model = try TexTellerModel()
- continuation.resume(returning: model)
- } catch {
- continuation.resume(throwing: error)
- }
- }
+ let env = try ORTEnv(loggingLevel: .warning)
+ let coreMLOptions = ORTCoreMLExecutionProviderOptions()
+ coreMLOptions.enableOnSubgraphs = true
+ coreMLOptions.createMLProgram = false
+ let options = try ORTSessionOptions()
+ // try options.appendCoreMLExecutionProvider(with: coreMLOptions)
+ encoderSession = try ORTSession(env: env, modelPath: encoderModelPath, sessionOptions: options)
+ decoderSession = try ORTSession(env: env, modelPath: decoderModelPath, sessionOptions: options)
+
+ self.tokenizer = RobertaTokenizerFast(vocabFile: "vocab", tokenizerFile: "tokenizer")
+ }
+
+ public static func asyncInit() async throws -> TexTellerModel {
+ return try await withCheckedThrowingContinuation { continuation in
+ DispatchQueue.global(qos: .userInitiated).async {
+ do {
+ let model = try TexTellerModel()
+ continuation.resume(returning: model)
+ } catch {
+ continuation.resume(throwing: error)
}
+ }
}
-
- public func texIt(_ image: NSImage, rawString: Bool = false, debug: Bool = false) throws -> String {
- let transformedImage = inferenceTransform(images: [image])
- if let firstTransformedImage = transformedImage.first {
- let pixelValues = ciImageToFloatArray(firstTransformedImage, size: CGSize(width: FIXED_IMG_SIZE, height: FIXED_IMG_SIZE))
- if debug {
- print("First few pixel inputs: \(pixelValues.prefix(10))")
- }
- let inputTensor = try ORTValue(
- tensorData: NSMutableData(
- data: Data(bytes: pixelValues, count: pixelValues.count * MemoryLayout<Float>.stride)
- ),
- elementType: .float,
- shape: [
- 1, 1, NSNumber(value: FIXED_IMG_SIZE), NSNumber(value: FIXED_IMG_SIZE)
- ]
- )
- let encoderInput: [String: ORTValue] = [
- "pixel_values": inputTensor
- ]
- let encoderOutputNames = try self.encoderSession.outputNames()
- let encoderOutputs: [String: ORTValue] = try self.encoderSession.run(
- withInputs: encoderInput,
- outputNames: Set(encoderOutputNames),
- runOptions: nil
- )
-
- if (debug) {
- print("Encoder output: \(encoderOutputs)")
- }
-
- var decodedTokenIds: [Int] = []
- let startTokenId = 0 // TODO: Move to tokenizer directly?
- let endTokenId = 2
- let maxDecoderLength: Int = 300
- var decoderInputIds: [Int] = [startTokenId]
- let vocabSize = 15000
-
- if (debug) {
- let encoderHiddenStatesData = try encoderOutputs["last_hidden_state"]!.tensorData() as Data
- let encoderHiddenStatesArray = encoderHiddenStatesData.withUnsafeBytes {
- Array(UnsafeBufferPointer<Float>(
- start: $0.baseAddress!.assumingMemoryBound(to: Float.self),
- count: encoderHiddenStatesData.count / MemoryLayout<Float>.stride
- ))
- }
-
- print("First few values of encoder hidden states: \(encoderHiddenStatesArray.prefix(10))")
- }
-
- let decoderOutputNames = try self.decoderSession.outputNames()
-
- for step in 0..<maxDecoderLength {
- if (debug) {
- print("Step \(step)")
- }
-
- let decoderInputIdsTensor = try ORTValue(
- tensorData: NSMutableData(data: Data(bytes: decoderInputIds, count: decoderInputIds.count * MemoryLayout<Int64>.stride)),
- elementType: .int64,
- shape: [1, NSNumber(value: decoderInputIds.count)]
- )
- let decoderInputs: [String: ORTValue] = [
- "input_ids": decoderInputIdsTensor,
- "encoder_hidden_states": encoderOutputs["last_hidden_state"]!
- ]
- let decoderOutputs: [String: ORTValue] = try self.decoderSession.run(withInputs: decoderInputs, outputNames: Set(decoderOutputNames), runOptions: nil)
- let logitsTensor = decoderOutputs["logits"]!
- let logitsData = try logitsTensor.tensorData() as Data
- let logits = logitsData.withUnsafeBytes {
- Array(UnsafeBufferPointer<Float>(
- start: $0.baseAddress!.assumingMemoryBound(to: Float.self),
- count: logitsData.count / MemoryLayout<Float>.stride
- ))
- }
- let sequenceLength = decoderInputIds.count
- let startIndex = (sequenceLength - 1) * vocabSize
- let endIndex = startIndex + vocabSize
- let lastTokenLogits = Array(logits[startIndex..<endIndex])
- let nextTokenId = lastTokenLogits.enumerated().max(by: { $0.element < $1.element})?.offset ?? 9 // TODO: Should I track if this fails
- if (debug) {
- print("Next token id: \(nextTokenId)")
- }
- if nextTokenId == endTokenId {
- break
- }
- decodedTokenIds.append(nextTokenId)
- decoderInputIds.append(nextTokenId)
- }
-
- if rawString {
- return tokenizer.decode(tokenIds: decodedTokenIds)
- }
-
- return toKatex(formula: tokenizer.decode(tokenIds: decodedTokenIds))
-
-
-
-
+ }
+
+ public func texIt(_ image: NSImage, rawString: Bool = false, debug: Bool = false) throws -> String
+ {
+ let transformedImage = inferenceTransform(images: [image])
+ if let firstTransformedImage = transformedImage.first {
+ let pixelValues = ciImageToFloatArray(
+ firstTransformedImage, size: CGSize(width: FIXED_IMG_SIZE, height: FIXED_IMG_SIZE))
+ if debug {
+ print("First few pixel inputs: \(pixelValues.prefix(10))")
+ }
+ let inputTensor = try ORTValue(
+ tensorData: NSMutableData(
+ data: Data(bytes: pixelValues, count: pixelValues.count * MemoryLayout<Float>.stride)
+ ),
+ elementType: .float,
+ shape: [
+ 1, 1, NSNumber(value: FIXED_IMG_SIZE), NSNumber(value: FIXED_IMG_SIZE),
+ ]
+ )
+ let encoderInput: [String: ORTValue] = [
+ "pixel_values": inputTensor
+ ]
+ let encoderOutputNames = try self.encoderSession.outputNames()
+ let encoderOutputs: [String: ORTValue] = try self.encoderSession.run(
+ withInputs: encoderInput,
+ outputNames: Set(encoderOutputNames),
+ runOptions: nil
+ )
+
+ if debug {
+ print("Encoder output: \(encoderOutputs)")
+ }
+
+ var decodedTokenIds: [Int] = []
+ let startTokenId = 0 // TODO: Move to tokenizer directly?
+ let endTokenId = 2
+ let maxDecoderLength: Int = 300
+ var decoderInputIds: [Int] = [startTokenId]
+ let vocabSize = 15000
+
+ if debug {
+ let encoderHiddenStatesData = try encoderOutputs["last_hidden_state"]!.tensorData() as Data
+ let encoderHiddenStatesArray = encoderHiddenStatesData.withUnsafeBytes {
+ Array(
+ UnsafeBufferPointer<Float>(
+ start: $0.baseAddress!.assumingMemoryBound(to: Float.self),
+ count: encoderHiddenStatesData.count / MemoryLayout<Float>.stride
+ ))
+ }
+
+ print("First few values of encoder hidden states: \(encoderHiddenStatesArray.prefix(10))")
+ }
+
+ let decoderOutputNames = try self.decoderSession.outputNames()
+
+ for step in 0..<maxDecoderLength {
+ if debug {
+ print("Step \(step)")
+ }
+
+ let decoderInputIdsTensor = try ORTValue(
+ tensorData: NSMutableData(
+ data: Data(
+ bytes: decoderInputIds, count: decoderInputIds.count * MemoryLayout<Int64>.stride)),
+ elementType: .int64,
+ shape: [1, NSNumber(value: decoderInputIds.count)]
+ )
+ let decoderInputs: [String: ORTValue] = [
+ "input_ids": decoderInputIdsTensor,
+ "encoder_hidden_states": encoderOutputs["last_hidden_state"]!,
+ ]
+ let decoderOutputs: [String: ORTValue] = try self.decoderSession.run(
+ withInputs: decoderInputs, outputNames: Set(decoderOutputNames), runOptions: nil)
+ let logitsTensor = decoderOutputs["logits"]!
+ let logitsData = try logitsTensor.tensorData() as Data
+ let logits = logitsData.withUnsafeBytes {
+ Array(
+ UnsafeBufferPointer<Float>(
+ start: $0.baseAddress!.assumingMemoryBound(to: Float.self),
+ count: logitsData.count / MemoryLayout<Float>.stride
+ ))
+ }
+ let sequenceLength = decoderInputIds.count
+ let startIndex = (sequenceLength - 1) * vocabSize
+ let endIndex = startIndex + vocabSize
+ let lastTokenLogits = Array(logits[startIndex..<endIndex])
+ let nextTokenId =
+ lastTokenLogits.enumerated().max(by: { $0.element < $1.element })?.offset ?? 9 // TODO: Should I track if this fails
+ if debug {
+ print("Next token id: \(nextTokenId)")
+ }
+ if nextTokenId == endTokenId {
+ break
}
- throw ModelError.imageError
+ decodedTokenIds.append(nextTokenId)
+ decoderInputIds.append(nextTokenId)
+ }
+
+ if rawString {
+ return tokenizer.decode(tokenIds: decodedTokenIds)
+ }
+
+ return toKatex(formula: tokenizer.decode(tokenIds: decodedTokenIds))
+
}
-
- public func texIt(_ image: NSImage, rawString: Bool = false, debug: Bool = false) async throws -> String {
- return try await withCheckedThrowingContinuation { continuation in
- DispatchQueue.global(qos: .userInitiated).async {
- do {
- let result = try self.texIt(image, rawString: rawString, debug: debug)
- continuation.resume(returning: result)
- } catch {
- continuation.resume(throwing: error)
- }
- }
- }
+ throw ModelError.imageError
+ }
+
+ public func texIt(_ image: NSImage, rawString: Bool = false, debug: Bool = false) async throws
+ -> String
+ {
+ return try await withCheckedThrowingContinuation { continuation in
+ DispatchQueue.global(qos: .userInitiated).async {
+ do {
+ let result = try self.texIt(image, rawString: rawString, debug: debug)
+ continuation.resume(returning: result)
+ } catch {
+ continuation.resume(throwing: error)
+ }
+ }
}
-
+ }
+
}
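
For context, here is one way the reformatted API might be driven from an async call site. This is a sketch and not part of the commit: the wrapper function recognizeFormula is hypothetical, and it assumes encoder_model.onnx, decoder_model.onnx, and the tokenizer resources are bundled exactly as the code above expects.

    import AppKit

    // Hypothetical call site (not part of the commit above).
    func recognizeFormula(from image: NSImage) async {
      do {
        // Load both ONNX sessions off the main thread.
        let model = try await TexTellerModel.asyncInit()
        // Run recognition; passing rawString: true would skip the
        // KaTeX post-processing step (toKatex) and return the raw
        // decoded token string instead.
        let latex = try await model.texIt(image, rawString: false)
        print("Recognized LaTeX: \(latex)")
      } catch {
        print("Recognition failed: \(error)")
      }
    }

Both asyncInit() and the async texIt overload bridge the blocking ONNX Runtime work onto a background queue via withCheckedThrowingContinuation, so a call site like this never blocks the main thread.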