//
//  TexTellerModel.swift
//  iTexSnip
//
//  Created by Navan Chauhan on 10/20/24.
//

import AppKit
import OnnxRuntimeBindings

/// Errors thrown while loading the bundled ONNX models or preprocessing input images.
public enum ModelError: Error {
  case encoderModelNotFound
  case decoderModelNotFound
  case imageError
}

public struct TexTellerModel {
  public let encoderSession: ORTSession
  public let decoderSession: ORTSession
  private let tokenizer: RobertaTokenizerFast

  public init() throws {
    guard let encoderModelPath = Bundle.main.path(forResource: "encoder_model", ofType: "onnx")
    else {
      print("Encoder model not found...")
      throw ModelError.encoderModelNotFound
    }
    guard let decoderModelPath = Bundle.main.path(forResource: "decoder_model", ofType: "onnx")
    else {
      print("Decoder model not found...")
      throw ModelError.decoderModelNotFound
    }
    let env = try ORTEnv(loggingLevel: .warning)
    let coreMLOptions = ORTCoreMLExecutionProviderOptions()
    coreMLOptions.enableOnSubgraphs = true
    coreMLOptions.createMLProgram = false
    let options = try ORTSessionOptions()
    // CoreML execution provider is configured above but currently disabled;
    // inference runs on the default CPU provider.
    // try options.appendCoreMLExecutionProvider(with: coreMLOptions)
    encoderSession = try ORTSession(env: env, modelPath: encoderModelPath, sessionOptions: options)
    decoderSession = try ORTSession(env: env, modelPath: decoderModelPath, sessionOptions: options)

    self.tokenizer = RobertaTokenizerFast(vocabFile: "vocab", tokenizerFile: "tokenizer")
  }

  /// Loads the model on a background queue so callers don't block the main thread.
  public static func asyncInit() async throws -> TexTellerModel {
    return try await withCheckedThrowingContinuation { continuation in
      DispatchQueue.global(qos: .userInitiated).async {
        do {
          let model = try TexTellerModel()
          continuation.resume(returning: model)
        } catch {
          continuation.resume(throwing: error)
        }
      }
    }
  }

  public func texIt(_ image: NSImage, rawString: Bool = false, debug: Bool = false) throws -> String {
    let transformedImage = inferenceTransform(images: [image])
    if let firstTransformedImage = transformedImage.first {
      let pixelValues = ciImageToFloatArray(
        firstTransformedImage,
        size: CGSize(width: FIXED_IMG_SIZE, height: FIXED_IMG_SIZE)
      )
      if debug {
        print("First few pixel inputs: \(pixelValues.prefix(10))")
      }
      // Grayscale input tensor of shape [1, 1, FIXED_IMG_SIZE, FIXED_IMG_SIZE].
      let inputTensor = try ORTValue(
        tensorData: NSMutableData(
          data: Data(bytes: pixelValues, count: pixelValues.count * MemoryLayout<Float>.stride)
        ),
        elementType: .float,
        shape: [1, 1, NSNumber(value: FIXED_IMG_SIZE), NSNumber(value: FIXED_IMG_SIZE)]
      )
      let encoderInput: [String: ORTValue] = [
        "pixel_values": inputTensor
      ]
      let encoderOutputNames = try self.encoderSession.outputNames()
      let encoderOutputs: [String: ORTValue] = try self.encoderSession.run(
        withInputs: encoderInput,
        outputNames: Set(encoderOutputNames),
        runOptions: nil
      )
      if debug {
        print("Encoder output: \(encoderOutputs)")
      }
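
      // Greedy autoregressive decoding: the decoder is re-run over the growing
      // token sequence each step, and the argmax of the final position's
      // logits is appended until the end token or maxDecoderLength is reached.
      // The special-token ids (BOS = 0, EOS = 2) and vocab size below are
      // hard-coded to match the bundled TexTeller RoBERTa tokenizer.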
      var decodedTokenIds: [Int] = []
      let startTokenId = 0  // TODO: Move to tokenizer directly?
      let endTokenId = 2
      let maxDecoderLength: Int = 300
      var decoderInputIds: [Int] = [startTokenId]
      let vocabSize = 15000

      if debug {
        let encoderHiddenStatesData =
          try encoderOutputs["last_hidden_state"]!.tensorData() as Data
        let encoderHiddenStatesArray = encoderHiddenStatesData.withUnsafeBytes {
          Array(
            UnsafeBufferPointer(
              start: $0.baseAddress!.assumingMemoryBound(to: Float.self),
              count: encoderHiddenStatesData.count / MemoryLayout<Float>.stride
            ))
        }
        print("First few values of encoder hidden states: \(encoderHiddenStatesArray.prefix(10))")
      }

      let decoderOutputNames = try self.decoderSession.outputNames()

      for _ in 0..<maxDecoderLength {
        // Current token sequence as an int64 tensor of shape [1, sequenceLength].
        let decoderInputIdsTensor = try ORTValue(
          tensorData: NSMutableData(
            data: Data(
              bytes: decoderInputIds.map { Int64($0) },
              count: decoderInputIds.count * MemoryLayout<Int64>.stride
            )
          ),
          elementType: .int64,
          shape: [1, NSNumber(value: decoderInputIds.count)]
        )
        let decoderInputs: [String: ORTValue] = [
          "input_ids": decoderInputIdsTensor,
          "encoder_hidden_states": encoderOutputs["last_hidden_state"]!,
        ]
        let decoderOutputs: [String: ORTValue] = try self.decoderSession.run(
          withInputs: decoderInputs,
          outputNames: Set(decoderOutputNames),
          runOptions: nil
        )
        let logitsTensor = decoderOutputs["logits"]!
        let logitsData = try logitsTensor.tensorData() as Data
        let logits = logitsData.withUnsafeBytes {
          Array(
            UnsafeBufferPointer(
              start: $0.baseAddress!.assumingMemoryBound(to: Float.self),
              count: logitsData.count / MemoryLayout<Float>.stride
            ))
        }

        // Logits have shape [1, sequenceLength, vocabSize]; slice out the
        // distribution for the last position and take its argmax.
        let sequenceLength = decoderInputIds.count
        let startIndex = (sequenceLength - 1) * vocabSize
        let endIndex = startIndex + vocabSize
        let lastTokenLogits = Array(logits[startIndex..<endIndex])
        guard
          let nextTokenId = lastTokenLogits.enumerated().max(by: { $0.element < $1.element })?.offset
        else {
          break
        }
        if nextTokenId == endTokenId {
          break
        }
        decodedTokenIds.append(nextTokenId)
        decoderInputIds.append(nextTokenId)
      }

      // `decode(tokenIds:)` and `toKatex(formula:)` are assumed project helpers:
      // the first maps token ids back to a LaTeX string, the second normalizes
      // it into KaTeX-compatible output when rawString is false.
      let decoded = tokenizer.decode(tokenIds: decodedTokenIds)
      if rawString {
        return decoded
      }
      return toKatex(formula: decoded)
    }
    throw ModelError.imageError
  }

  public func texIt(_ image: NSImage, rawString: Bool = false, debug: Bool = false) async throws
    -> String
  {
    return try await withCheckedThrowingContinuation { continuation in
      DispatchQueue.global(qos: .userInitiated).async {
        do {
          let result = try self.texIt(image, rawString: rawString, debug: debug)
          continuation.resume(returning: result)
        } catch {
          continuation.resume(throwing: error)
        }
      }
    }
  }
}
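
// Usage sketch (illustrative addition, not part of the original API surface):
// a minimal example of driving the model end to end, assuming the ONNX models
// and tokenizer resources ship in the app bundle. `recognizeFormula(in:)` is a
// hypothetical helper name used only for this example.
func recognizeFormula(in image: NSImage) async {
  do {
    // Load both ONNX sessions off the main thread, then run the
    // encoder + greedy-decode pipeline on the snipped image.
    let model = try await TexTellerModel.asyncInit()
    let latex = try await model.texIt(image, rawString: false, debug: false)
    print("Recognized LaTeX: \(latex)")
  } catch {
    print("Formula recognition failed: \(error)")
  }
}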