From 05165cc8d98ef5ffa8ee3a8ba9bf1ad5e0b5a9ab Mon Sep 17 00:00:00 2001 From: Navan Chauhan Date: Mon, 21 Oct 2024 23:54:10 -0600 Subject: swift-format --- iTexSnip/Utils/ImageUtils.swift | 275 +++++++++++++------------- iTexSnip/Utils/KatexUtils.swift | 222 +++++++++++---------- iTexSnip/Utils/RobertaTokenizerFast.swift | 131 +++++++------ iTexSnip/Utils/TexTellerModel.swift | 311 +++++++++++++++--------------- 4 files changed, 491 insertions(+), 448 deletions(-) (limited to 'iTexSnip/Utils') diff --git a/iTexSnip/Utils/ImageUtils.swift b/iTexSnip/Utils/ImageUtils.swift index e59c4e5..73bab84 100644 --- a/iTexSnip/Utils/ImageUtils.swift +++ b/iTexSnip/Utils/ImageUtils.swift @@ -5,9 +5,9 @@ // Created by Navan Chauhan on 10/13/24. // -import Foundation -import CoreImage import AppKit +import CoreImage +import Foundation let IMAGE_MEAN: CGFloat = 0.9545467 let IMAGE_STD: CGFloat = 0.15394445 @@ -18,159 +18,166 @@ let MIN_WIDTH: CGFloat = 30 // Load image from URL func loadImage(from urlString: String) -> NSImage? { - guard let url = URL(string: urlString), let imageData = try? Data(contentsOf: url) else { - return nil - } - return NSImage(data: imageData) + guard let url = URL(string: urlString), let imageData = try? Data(contentsOf: url) else { + return nil + } + return NSImage(data: imageData) } // Helper to convert NSImage to CIImage func nsImageToCIImage(_ image: NSImage) -> CIImage? { - guard let data = image.tiffRepresentation, - let bitmapImage = NSBitmapImageRep(data: data), - let cgImage = bitmapImage.cgImage else { - return nil - } - return CIImage(cgImage: cgImage) + guard let data = image.tiffRepresentation, + let bitmapImage = NSBitmapImageRep(data: data), + let cgImage = bitmapImage.cgImage + else { + return nil + } + return CIImage(cgImage: cgImage) } func trimWhiteBorder(image: CIImage) -> CIImage? { - let context = CIContext() - - // Render the CIImage to a CGImage for pixel analysis - guard let cgImage = context.createCGImage(image, from: image.extent) else { - return nil - } - - // Access the pixel data - let width = cgImage.width - let height = cgImage.height - let colorSpace = CGColorSpaceCreateDeviceRGB() - let bytesPerPixel = 4 - let bytesPerRow = bytesPerPixel * width - let bitmapInfo = CGImageAlphaInfo.premultipliedLast.rawValue - var pixelData = [UInt8](repeating: 0, count: height * bytesPerRow) - - guard let contextRef = CGContext( - data: &pixelData, - width: width, - height: height, - bitsPerComponent: 8, - bytesPerRow: bytesPerRow, - space: colorSpace, - bitmapInfo: bitmapInfo - ) else { - return nil - } - - contextRef.draw(cgImage, in: CGRect(x: 0, y: 0, width: CGFloat(width), height: CGFloat(height))) - - // Define the white color in RGBA - let whitePixel: [UInt8] = [255, 255, 255, 255] - - var minX = width - var minY = height - var maxX: Int = 0 - var maxY: Int = 0 - - // Scan the pixels to find the bounding box of non-white content - for y in 0.. maxX { maxX = x } - if y < minY { minY = y } - if y > maxY { maxY = y } - } - } - } - - // If no non-white content was found, return the original image - if minX == width || minY == height || maxX == 0 || maxY == 0 { - return image + let context = CIContext() + + // Render the CIImage to a CGImage for pixel analysis + guard let cgImage = context.createCGImage(image, from: image.extent) else { + return nil + } + + // Access the pixel data + let width = cgImage.width + let height = cgImage.height + let colorSpace = CGColorSpaceCreateDeviceRGB() + let bytesPerPixel = 4 + let bytesPerRow = bytesPerPixel * width + let bitmapInfo = CGImageAlphaInfo.premultipliedLast.rawValue + var pixelData = [UInt8](repeating: 0, count: height * bytesPerRow) + + guard + let contextRef = CGContext( + data: &pixelData, + width: width, + height: height, + bitsPerComponent: 8, + bytesPerRow: bytesPerRow, + space: colorSpace, + bitmapInfo: bitmapInfo + ) + else { + return nil + } + + contextRef.draw(cgImage, in: CGRect(x: 0, y: 0, width: CGFloat(width), height: CGFloat(height))) + + // Define the white color in RGBA + let whitePixel: [UInt8] = [255, 255, 255, 255] + + var minX = width + var minY = height + var maxX: Int = 0 + var maxY: Int = 0 + + // Scan the pixels to find the bounding box of non-white content + for y in 0.. maxX { maxX = x } + if y < minY { minY = y } + if y > maxY { maxY = y } + } } + } + + // If no non-white content was found, return the original image + if minX == width || minY == height || maxX == 0 || maxY == 0 { + return image + } - // Compute the bounding box and crop the image - let croppedRect = CGRect(x: CGFloat(minX), y: CGFloat(minY), width: CGFloat(maxX - minX), height: CGFloat(maxY - minY)) - return image.cropped(to: croppedRect) + // Compute the bounding box and crop the image + let croppedRect = CGRect( + x: CGFloat(minX), y: CGFloat(minY), width: CGFloat(maxX - minX), height: CGFloat(maxY - minY)) + return image.cropped(to: croppedRect) } // Padding image with white border func addWhiteBorder(to image: CIImage, maxSize: CGFloat) -> CIImage { - let randomPadding = (0..<4).map { _ in CGFloat(arc4random_uniform(UInt32(maxSize))) } - var xPadding = randomPadding[0] + randomPadding[2] - var yPadding = randomPadding[1] + randomPadding[3] - - // Ensure the minimum width and height - if xPadding + image.extent.width < MIN_WIDTH { - let compensateWidth = (MIN_WIDTH - (xPadding + image.extent.width)) * 0.5 + 1 - xPadding += compensateWidth - } - if yPadding + image.extent.height < MIN_HEIGHT { - let compensateHeight = (MIN_HEIGHT - (yPadding + image.extent.height)) * 0.5 + 1 - yPadding += compensateHeight - } - - // Adding padding with a constant white color - let padFilter = CIFilter(name: "CICrop")! - let paddedRect = CGRect(x: image.extent.origin.x - randomPadding[0], - y: image.extent.origin.y - randomPadding[1], - width: image.extent.width + xPadding, - height: image.extent.height + yPadding) - padFilter.setValue(image, forKey: kCIInputImageKey) - padFilter.setValue(CIVector(cgRect: paddedRect), forKey: "inputRectangle") - - return padFilter.outputImage ?? image + let randomPadding = (0..<4).map { _ in CGFloat(arc4random_uniform(UInt32(maxSize))) } + var xPadding = randomPadding[0] + randomPadding[2] + var yPadding = randomPadding[1] + randomPadding[3] + + // Ensure the minimum width and height + if xPadding + image.extent.width < MIN_WIDTH { + let compensateWidth = (MIN_WIDTH - (xPadding + image.extent.width)) * 0.5 + 1 + xPadding += compensateWidth + } + if yPadding + image.extent.height < MIN_HEIGHT { + let compensateHeight = (MIN_HEIGHT - (yPadding + image.extent.height)) * 0.5 + 1 + yPadding += compensateHeight + } + + // Adding padding with a constant white color + let padFilter = CIFilter(name: "CICrop")! + let paddedRect = CGRect( + x: image.extent.origin.x - randomPadding[0], + y: image.extent.origin.y - randomPadding[1], + width: image.extent.width + xPadding, + height: image.extent.height + yPadding) + padFilter.setValue(image, forKey: kCIInputImageKey) + padFilter.setValue(CIVector(cgRect: paddedRect), forKey: "inputRectangle") + + return padFilter.outputImage ?? image } // Padding images to a required size func padding(images: [CIImage], requiredSize: CGFloat) -> [CIImage] { - return images.map { image in - let widthPadding = requiredSize - image.extent.width - let heightPadding = requiredSize - image.extent.height - return addWhiteBorder(to: image, maxSize: max(widthPadding, heightPadding)) - } + return images.map { image in + let widthPadding = requiredSize - image.extent.width + let heightPadding = requiredSize - image.extent.height + return addWhiteBorder(to: image, maxSize: max(widthPadding, heightPadding)) + } } // Transform pipeline to apply resize, normalize, etc. func inferenceTransform(images: [NSImage]) -> [CIImage] { - let ciImages = images.compactMap { nsImageToCIImage($0) } - - let trimmedImages = ciImages.compactMap { trimWhiteBorder(image: $0) } - let paddedImages = padding(images: trimmedImages, requiredSize: FIXED_IMG_SIZE) - - return paddedImages -} - -func ciImageToFloatArray(_ image: CIImage, size: CGSize) -> [Float] { - // Render the CIImage to a bitmap context - let context = CIContext() - guard let cgImage = context.createCGImage(image, from: image.extent) else { - return [] - } + let ciImages = images.compactMap { nsImageToCIImage($0) } - let width = Int(size.width) - let height = Int(size.height) - var pixelData = [UInt8](repeating: 0, count: width * height) // Use UInt8 for grayscale - - // Create bitmap context for rendering - let colorSpace = CGColorSpaceCreateDeviceGray() - guard let contextRef = CGContext( - data: &pixelData, - width: width, - height: height, - bitsPerComponent: 8, - bytesPerRow: width, - space: colorSpace, - bitmapInfo: CGImageAlphaInfo.none.rawValue - ) else { - return [] - } + let trimmedImages = ciImages.compactMap { trimWhiteBorder(image: $0) } + let paddedImages = padding(images: trimmedImages, requiredSize: FIXED_IMG_SIZE) - contextRef.draw(cgImage, in: CGRect(x: 0, y: 0, width: CGFloat(width), height: CGFloat(height))) + return paddedImages +} - // Normalize pixel values to [0, 1] - return pixelData.map { Float($0) / 255.0 } +func ciImageToFloatArray(_ image: CIImage, size: CGSize) -> [Float] { + // Render the CIImage to a bitmap context + let context = CIContext() + guard let cgImage = context.createCGImage(image, from: image.extent) else { + return [] + } + + let width = Int(size.width) + let height = Int(size.height) + var pixelData = [UInt8](repeating: 0, count: width * height) // Use UInt8 for grayscale + + // Create bitmap context for rendering + let colorSpace = CGColorSpaceCreateDeviceGray() + guard + let contextRef = CGContext( + data: &pixelData, + width: width, + height: height, + bitsPerComponent: 8, + bytesPerRow: width, + space: colorSpace, + bitmapInfo: CGImageAlphaInfo.none.rawValue + ) + else { + return [] + } + + contextRef.draw(cgImage, in: CGRect(x: 0, y: 0, width: CGFloat(width), height: CGFloat(height))) + + // Normalize pixel values to [0, 1] + return pixelData.map { Float($0) / 255.0 } } diff --git a/iTexSnip/Utils/KatexUtils.swift b/iTexSnip/Utils/KatexUtils.swift index d339a71..697e700 100644 --- a/iTexSnip/Utils/KatexUtils.swift +++ b/iTexSnip/Utils/KatexUtils.swift @@ -7,125 +7,145 @@ import Foundation -func change(_ inputStr: String, oldInst: String, newInst: String, oldSurrL: Character, oldSurrR: Character, newSurrL: String, newSurrR: String) -> String { - var result = "" - var i = 0 - let n = inputStr.count - let inputArray = Array(inputStr) // Convert string to array of characters for easier access - - while i < n { - // Get the range for the substring equivalent to oldInst - if i + oldInst.count <= n && inputStr[inputStr.index(inputStr.startIndex, offsetBy: i).. 0 { - if inputArray[j] == "\\" && !escaped { - escaped = true - j += 1 - continue - } - - if inputArray[j] == oldSurrR && !escaped { - count -= 1 - if count == 0 { - break - } - } else if inputArray[j] == oldSurrL && !escaped { - count += 1 - } - - escaped = false - j += 1 - } - - if count == 0 { - let innerContent = String(inputArray[(start + 1).. String { + var result = "" + var i = 0 + let n = inputStr.count + let inputArray = Array(inputStr) // Convert string to array of characters for easier access + + while i < n { + // Get the range for the substring equivalent to oldInst + if i + oldInst.count <= n + && inputStr[ + inputStr.index( + inputStr.startIndex, offsetBy: i).. 0 { + if inputArray[j] == "\\" && !escaped { + escaped = true + j += 1 + continue + } + + if inputArray[j] == oldSurrR && !escaped { + count -= 1 + if count == 0 { + break } + } else if inputArray[j] == oldSurrL && !escaped { + count += 1 + } + + escaped = false + j += 1 } - result.append(inputArray[i]) - i += 1 - } - if oldInst != newInst && result.contains(oldInst + String(oldSurrL)) { - return change(result, oldInst: oldInst, newInst: newInst, oldSurrL: oldSurrL, oldSurrR: oldSurrR, newSurrL: newSurrL, newSurrR: newSurrR) + if count == 0 { + let innerContent = String(inputArray[(start + 1).. [Int] { - var positions: [Int] = [] - var searchRange = string.startIndex.. String { - let pattern = try! NSRegularExpression(pattern: "\\\\[a-zA-Z]+\\$.*?\\$|\\$.*?\\$", options: []) - var newContent = content - let matches = pattern.matches(in: content, options: [], range: NSRange(content.startIndex.. String { - let positions = findSubstringPositions(inputStr, substring: oldInst + String(oldSurrL)) - var result = inputStr - - for pos in positions.reversed() { - let startIndex = result.index(result.startIndex, offsetBy: pos) - let substring = String(result[startIndex.. String { + let positions = findSubstringPositions(inputStr, substring: oldInst + String(oldSurrL)) + var result = inputStr + + for pos in positions.reversed() { + let startIndex = result.index(result.startIndex, offsetBy: pos) + let substring = String(result[startIndex.. String { - var res = formula - // Remove mbox surrounding - res = changeAll(inputStr: res, oldInst: "\\mbox ", newInst: " ", oldSurrL: "{", oldSurrR: "}", newSurrL: "", newSurrR: "") - res = changeAll(inputStr: res, oldInst: "\\mbox", newInst: " ", oldSurrL: "{", oldSurrR: "}", newSurrL: "", newSurrR: "") - - // Additional processing similar to the Python version... - res = res.replacingOccurrences(of: "\\[", with: "") - res = res.replacingOccurrences(of: "\\]", with: "") - res = res.replacingOccurrences(of: "\\\\[?.!,\'\"](?:\\s|$)", with: "", options: .regularExpression) - - // Merge consecutive `text` - res = rmDollarSurr(content: res) - - // Remove extra spaces - res = res.replacingOccurrences(of: " +", with: " ", options: .regularExpression) - - return res.trimmingCharacters(in: .whitespacesAndNewlines) + var res = formula + // Remove mbox surrounding + res = changeAll( + inputStr: res, oldInst: "\\mbox ", newInst: " ", oldSurrL: "{", oldSurrR: "}", newSurrL: "", + newSurrR: "") + res = changeAll( + inputStr: res, oldInst: "\\mbox", newInst: " ", oldSurrL: "{", oldSurrR: "}", newSurrL: "", + newSurrR: "") + + // Additional processing similar to the Python version... + res = res.replacingOccurrences(of: "\\[", with: "") + res = res.replacingOccurrences(of: "\\]", with: "") + res = res.replacingOccurrences( + of: "\\\\[?.!,\'\"](?:\\s|$)", with: "", options: .regularExpression) + + // Merge consecutive `text` + res = rmDollarSurr(content: res) + + // Remove extra spaces + res = res.replacingOccurrences(of: " +", with: " ", options: .regularExpression) + + return res.trimmingCharacters(in: .whitespacesAndNewlines) } diff --git a/iTexSnip/Utils/RobertaTokenizerFast.swift b/iTexSnip/Utils/RobertaTokenizerFast.swift index d61e8c7..888cc20 100644 --- a/iTexSnip/Utils/RobertaTokenizerFast.swift +++ b/iTexSnip/Utils/RobertaTokenizerFast.swift @@ -8,81 +8,88 @@ import Foundation class RobertaTokenizerFast { - var vocab: [String: Int] = [:] - var idToToken: [Int: String] = [:] - var specialTokens: [String] = [] - var unkTokenId: Int? + var vocab: [String: Int] = [:] + var idToToken: [Int: String] = [:] + var specialTokens: [String] = [] + var unkTokenId: Int? - init(vocabFile: String, tokenizerFile: String) { - if let vocabURL = Bundle.main.url(forResource: vocabFile, withExtension: "json"), - let vocabData = try? Data(contentsOf: vocabURL), - let vocabDict = try? JSONSerialization.jsonObject(with: vocabData, options: []) as? [String: Int] { - self.vocab = vocabDict - } + init(vocabFile: String, tokenizerFile: String) { + if let vocabURL = Bundle.main.url(forResource: vocabFile, withExtension: "json"), + let vocabData = try? Data(contentsOf: vocabURL), + let vocabDict = try? JSONSerialization.jsonObject(with: vocabData, options: []) + as? [String: Int] + { + self.vocab = vocabDict + } - if let tokenizerURL = Bundle.main.url(forResource: tokenizerFile, withExtension: "json"), - let tokenizerData = try? Data(contentsOf: tokenizerURL), - let tokenizerConfig = try? JSONSerialization.jsonObject(with: tokenizerData, options: []) as? [String: Any] { - self.specialTokens = tokenizerConfig["added_tokens"] as? [String] ?? [] - } + if let tokenizerURL = Bundle.main.url(forResource: tokenizerFile, withExtension: "json"), + let tokenizerData = try? Data(contentsOf: tokenizerURL), + let tokenizerConfig = try? JSONSerialization.jsonObject(with: tokenizerData, options: []) + as? [String: Any] + { + self.specialTokens = tokenizerConfig["added_tokens"] as? [String] ?? [] + } - self.idToToken = vocab.reduce(into: [Int: String]()) { $0[$1.value] = $1.key } + self.idToToken = vocab.reduce(into: [Int: String]()) { $0[$1.value] = $1.key } - self.unkTokenId = vocab[""] - } + self.unkTokenId = vocab[""] + } - func encode(text: String) -> [Int] { - let tokens = tokenize(text) - return tokens.map { vocab[$0] ?? unkTokenId! } - } + func encode(text: String) -> [Int] { + let tokens = tokenize(text) + return tokens.map { vocab[$0] ?? unkTokenId! } + } - func decode(tokenIds: [Int], skipSpecialTokens: Bool = true) -> String { - let tokens = tokenIds.compactMap { idToToken[$0] } - let filteredTokens = skipSpecialTokens ? tokens.filter { !specialTokens.contains($0) && $0 != "" } : tokens - return convertTokensToString(filteredTokens) - } + func decode(tokenIds: [Int], skipSpecialTokens: Bool = true) -> String { + let tokens = tokenIds.compactMap { idToToken[$0] } + let filteredTokens = + skipSpecialTokens ? tokens.filter { !specialTokens.contains($0) && $0 != "" } : tokens + return convertTokensToString(filteredTokens) + } - private func tokenize(_ text: String) -> [String] { - let cleanedText = cleanText(text) - let words = cleanedText.split(separator: " ").map { String($0) } - - var tokens: [String] = [] - for word in words { - tokens.append(contentsOf: bpeEncode(word)) - } - return tokens + private func tokenize(_ text: String) -> [String] { + let cleanedText = cleanText(text) + let words = cleanedText.split(separator: " ").map { String($0) } + + var tokens: [String] = [] + for word in words { + tokens.append(contentsOf: bpeEncode(word)) } + return tokens + } - private func bpeEncode(_ word: String) -> [String] { - if vocab.keys.contains(word) { - return [word] - } + private func bpeEncode(_ word: String) -> [String] { + if vocab.keys.contains(word) { + return [word] + } - let chars = Array(word) - var tokens: [String] = [] - var i = 0 + let chars = Array(word) + var tokens: [String] = [] + var i = 0 - while i < chars.count { - if i < chars.count - 1 { - let pair = String(chars[i]) + String(chars[i + 1]) - if vocab.keys.contains(pair) { - tokens.append(pair) - i += 2 - continue - } - } - tokens.append(String(chars[i])) - i += 1 + while i < chars.count { + if i < chars.count - 1 { + let pair = String(chars[i]) + String(chars[i + 1]) + if vocab.keys.contains(pair) { + tokens.append(pair) + i += 2 + continue } - return tokens + } + tokens.append(String(chars[i])) + i += 1 } + return tokens + } - private func cleanText(_ text: String) -> String { - return text.trimmingCharacters(in: .whitespacesAndNewlines) - } + private func cleanText(_ text: String) -> String { + return text.trimmingCharacters(in: .whitespacesAndNewlines) + } - private func convertTokensToString(_ tokens: [String]) -> String { - let text = tokens.joined().replacingOccurrences(of: "Ġ", with: " ") - return text.replacingOccurrences(of: "\\s([?.!,\'\"](?:\\s|$))", with: "$1", options: .regularExpression, range: nil).trimmingCharacters(in: .whitespaces) - } + private func convertTokensToString(_ tokens: [String]) -> String { + let text = tokens.joined().replacingOccurrences(of: "Ġ", with: " ") + return text.replacingOccurrences( + of: "\\s([?.!,\'\"](?:\\s|$))", with: "$1", options: .regularExpression, range: nil + ).trimmingCharacters(in: .whitespaces) + } } diff --git a/iTexSnip/Utils/TexTellerModel.swift b/iTexSnip/Utils/TexTellerModel.swift index 2f71919..0104bfb 100644 --- a/iTexSnip/Utils/TexTellerModel.swift +++ b/iTexSnip/Utils/TexTellerModel.swift @@ -5,167 +5,176 @@ // Created by Navan Chauhan on 10/20/24. // -import OnnxRuntimeBindings import AppKit +import OnnxRuntimeBindings public enum ModelError: Error { - case encoderModelNotFound - case decoderModelNotFound - case imageError + case encoderModelNotFound + case decoderModelNotFound + case imageError } public struct TexTellerModel { - public let encoderSession: ORTSession - public let decoderSession: ORTSession - private let tokenizer: RobertaTokenizerFast - - public init() throws { - guard let encoderModelPath = Bundle.main.path(forResource: "encoder_model", ofType: "onnx") else { - print("Encoder model not found...") - throw ModelError.encoderModelNotFound - } - guard let decoderModelPath = Bundle.main.path(forResource: "decoder_model", ofType: "onnx") else { - print("Decoder model not found...") - throw ModelError.decoderModelNotFound - } - let env = try ORTEnv(loggingLevel: .warning) - let coreMLOptions = ORTCoreMLExecutionProviderOptions() - coreMLOptions.enableOnSubgraphs = true - coreMLOptions.createMLProgram = false - let options = try ORTSessionOptions() -// try options.appendCoreMLExecutionProvider(with: coreMLOptions) - encoderSession = try ORTSession(env: env, modelPath: encoderModelPath, sessionOptions: options) - decoderSession = try ORTSession(env: env, modelPath: decoderModelPath, sessionOptions: options) - - self.tokenizer = RobertaTokenizerFast(vocabFile: "vocab", tokenizerFile: "tokenizer") + public let encoderSession: ORTSession + public let decoderSession: ORTSession + private let tokenizer: RobertaTokenizerFast + + public init() throws { + guard let encoderModelPath = Bundle.main.path(forResource: "encoder_model", ofType: "onnx") + else { + print("Encoder model not found...") + throw ModelError.encoderModelNotFound + } + guard let decoderModelPath = Bundle.main.path(forResource: "decoder_model", ofType: "onnx") + else { + print("Decoder model not found...") + throw ModelError.decoderModelNotFound } - - public static func asyncInit() async throws -> TexTellerModel { - return try await withCheckedThrowingContinuation { continuation in - DispatchQueue.global(qos: .userInitiated).async { - do { - let model = try TexTellerModel() - continuation.resume(returning: model) - } catch { - continuation.resume(throwing: error) - } - } + let env = try ORTEnv(loggingLevel: .warning) + let coreMLOptions = ORTCoreMLExecutionProviderOptions() + coreMLOptions.enableOnSubgraphs = true + coreMLOptions.createMLProgram = false + let options = try ORTSessionOptions() + // try options.appendCoreMLExecutionProvider(with: coreMLOptions) + encoderSession = try ORTSession(env: env, modelPath: encoderModelPath, sessionOptions: options) + decoderSession = try ORTSession(env: env, modelPath: decoderModelPath, sessionOptions: options) + + self.tokenizer = RobertaTokenizerFast(vocabFile: "vocab", tokenizerFile: "tokenizer") + } + + public static func asyncInit() async throws -> TexTellerModel { + return try await withCheckedThrowingContinuation { continuation in + DispatchQueue.global(qos: .userInitiated).async { + do { + let model = try TexTellerModel() + continuation.resume(returning: model) + } catch { + continuation.resume(throwing: error) } + } } - - public func texIt(_ image: NSImage, rawString: Bool = false, debug: Bool = false) throws -> String { - let transformedImage = inferenceTransform(images: [image]) - if let firstTransformedImage = transformedImage.first { - let pixelValues = ciImageToFloatArray(firstTransformedImage, size: CGSize(width: FIXED_IMG_SIZE, height: FIXED_IMG_SIZE)) - if debug { - print("First few pixel inputs: \(pixelValues.prefix(10))") - } - let inputTensor = try ORTValue( - tensorData: NSMutableData( - data: Data(bytes: pixelValues, count: pixelValues.count * MemoryLayout.stride) - ), - elementType: .float, - shape: [ - 1, 1, NSNumber(value: FIXED_IMG_SIZE), NSNumber(value: FIXED_IMG_SIZE) - ] - ) - let encoderInput: [String: ORTValue] = [ - "pixel_values": inputTensor - ] - let encoderOutputNames = try self.encoderSession.outputNames() - let encoderOutputs: [String: ORTValue] = try self.encoderSession.run( - withInputs: encoderInput, - outputNames: Set(encoderOutputNames), - runOptions: nil - ) - - if (debug) { - print("Encoder output: \(encoderOutputs)") - } - - var decodedTokenIds: [Int] = [] - let startTokenId = 0 // TODO: Move to tokenizer directly? - let endTokenId = 2 - let maxDecoderLength: Int = 300 - var decoderInputIds: [Int] = [startTokenId] - let vocabSize = 15000 - - if (debug) { - let encoderHiddenStatesData = try encoderOutputs["last_hidden_state"]!.tensorData() as Data - let encoderHiddenStatesArray = encoderHiddenStatesData.withUnsafeBytes { - Array(UnsafeBufferPointer( - start: $0.baseAddress!.assumingMemoryBound(to: Float.self), - count: encoderHiddenStatesData.count / MemoryLayout.stride - )) - } - - print("First few values of encoder hidden states: \(encoderHiddenStatesArray.prefix(10))") - } - - let decoderOutputNames = try self.decoderSession.outputNames() - - for step in 0...stride)), - elementType: .int64, - shape: [1, NSNumber(value: decoderInputIds.count)] - ) - let decoderInputs: [String: ORTValue] = [ - "input_ids": decoderInputIdsTensor, - "encoder_hidden_states": encoderOutputs["last_hidden_state"]! - ] - let decoderOutputs: [String: ORTValue] = try self.decoderSession.run(withInputs: decoderInputs, outputNames: Set(decoderOutputNames), runOptions: nil) - let logitsTensor = decoderOutputs["logits"]! - let logitsData = try logitsTensor.tensorData() as Data - let logits = logitsData.withUnsafeBytes { - Array(UnsafeBufferPointer( - start: $0.baseAddress!.assumingMemoryBound(to: Float.self), - count: logitsData.count / MemoryLayout.stride - )) - } - let sequenceLength = decoderInputIds.count - let startIndex = (sequenceLength - 1) * vocabSize - let endIndex = startIndex + vocabSize - let lastTokenLogits = Array(logits[startIndex.. String + { + let transformedImage = inferenceTransform(images: [image]) + if let firstTransformedImage = transformedImage.first { + let pixelValues = ciImageToFloatArray( + firstTransformedImage, size: CGSize(width: FIXED_IMG_SIZE, height: FIXED_IMG_SIZE)) + if debug { + print("First few pixel inputs: \(pixelValues.prefix(10))") + } + let inputTensor = try ORTValue( + tensorData: NSMutableData( + data: Data(bytes: pixelValues, count: pixelValues.count * MemoryLayout.stride) + ), + elementType: .float, + shape: [ + 1, 1, NSNumber(value: FIXED_IMG_SIZE), NSNumber(value: FIXED_IMG_SIZE), + ] + ) + let encoderInput: [String: ORTValue] = [ + "pixel_values": inputTensor + ] + let encoderOutputNames = try self.encoderSession.outputNames() + let encoderOutputs: [String: ORTValue] = try self.encoderSession.run( + withInputs: encoderInput, + outputNames: Set(encoderOutputNames), + runOptions: nil + ) + + if debug { + print("Encoder output: \(encoderOutputs)") + } + + var decodedTokenIds: [Int] = [] + let startTokenId = 0 // TODO: Move to tokenizer directly? + let endTokenId = 2 + let maxDecoderLength: Int = 300 + var decoderInputIds: [Int] = [startTokenId] + let vocabSize = 15000 + + if debug { + let encoderHiddenStatesData = try encoderOutputs["last_hidden_state"]!.tensorData() as Data + let encoderHiddenStatesArray = encoderHiddenStatesData.withUnsafeBytes { + Array( + UnsafeBufferPointer( + start: $0.baseAddress!.assumingMemoryBound(to: Float.self), + count: encoderHiddenStatesData.count / MemoryLayout.stride + )) + } + + print("First few values of encoder hidden states: \(encoderHiddenStatesArray.prefix(10))") + } + + let decoderOutputNames = try self.decoderSession.outputNames() + + for step in 0...stride)), + elementType: .int64, + shape: [1, NSNumber(value: decoderInputIds.count)] + ) + let decoderInputs: [String: ORTValue] = [ + "input_ids": decoderInputIdsTensor, + "encoder_hidden_states": encoderOutputs["last_hidden_state"]!, + ] + let decoderOutputs: [String: ORTValue] = try self.decoderSession.run( + withInputs: decoderInputs, outputNames: Set(decoderOutputNames), runOptions: nil) + let logitsTensor = decoderOutputs["logits"]! + let logitsData = try logitsTensor.tensorData() as Data + let logits = logitsData.withUnsafeBytes { + Array( + UnsafeBufferPointer( + start: $0.baseAddress!.assumingMemoryBound(to: Float.self), + count: logitsData.count / MemoryLayout.stride + )) + } + let sequenceLength = decoderInputIds.count + let startIndex = (sequenceLength - 1) * vocabSize + let endIndex = startIndex + vocabSize + let lastTokenLogits = Array(logits[startIndex.. String { - return try await withCheckedThrowingContinuation { continuation in - DispatchQueue.global(qos: .userInitiated).async { - do { - let result = try self.texIt(image, rawString: rawString, debug: debug) - continuation.resume(returning: result) - } catch { - continuation.resume(throwing: error) - } - } - } + throw ModelError.imageError + } + + public func texIt(_ image: NSImage, rawString: Bool = false, debug: Bool = false) async throws + -> String + { + return try await withCheckedThrowingContinuation { continuation in + DispatchQueue.global(qos: .userInitiated).async { + do { + let result = try self.texIt(image, rawString: rawString, debug: debug) + continuation.resume(returning: result) + } catch { + continuation.resume(throwing: error) + } + } } - + } + } -- cgit v1.2.3