aboutsummaryrefslogtreecommitdiff
path: root/iTexSnip/Utils
diff options
context:
space:
mode:
Diffstat (limited to 'iTexSnip/Utils')
-rw-r--r--iTexSnip/Utils/ImageUtils.swift176
-rw-r--r--iTexSnip/Utils/RobertaTokenizerFast.swift88
-rw-r--r--iTexSnip/Utils/TexTellerModel.swift6
3 files changed, 270 insertions, 0 deletions
diff --git a/iTexSnip/Utils/ImageUtils.swift b/iTexSnip/Utils/ImageUtils.swift
new file mode 100644
index 0000000..e59c4e5
--- /dev/null
+++ b/iTexSnip/Utils/ImageUtils.swift
@@ -0,0 +1,176 @@
+//
+// ImageUtils.swift
+// iTexSnip
+//
+// Created by Navan Chauhan on 10/13/24.
+//
+
+import Foundation
+import CoreImage
+import AppKit
+
+let IMAGE_MEAN: CGFloat = 0.9545467
+let IMAGE_STD: CGFloat = 0.15394445
+let FIXED_IMG_SIZE: CGFloat = 448
+let IMG_CHANNELS: Int = 1
+let MIN_HEIGHT: CGFloat = 12
+let MIN_WIDTH: CGFloat = 30
+
+// Load image from URL
+func loadImage(from urlString: String) -> NSImage? {
+ guard let url = URL(string: urlString), let imageData = try? Data(contentsOf: url) else {
+ return nil
+ }
+ return NSImage(data: imageData)
+}
+
+// Helper to convert NSImage to CIImage
+func nsImageToCIImage(_ image: NSImage) -> CIImage? {
+ guard let data = image.tiffRepresentation,
+ let bitmapImage = NSBitmapImageRep(data: data),
+ let cgImage = bitmapImage.cgImage else {
+ return nil
+ }
+ return CIImage(cgImage: cgImage)
+}
+
+func trimWhiteBorder(image: CIImage) -> CIImage? {
+ let context = CIContext()
+
+ // Render the CIImage to a CGImage for pixel analysis
+ guard let cgImage = context.createCGImage(image, from: image.extent) else {
+ return nil
+ }
+
+ // Access the pixel data
+ let width = cgImage.width
+ let height = cgImage.height
+ let colorSpace = CGColorSpaceCreateDeviceRGB()
+ let bytesPerPixel = 4
+ let bytesPerRow = bytesPerPixel * width
+ let bitmapInfo = CGImageAlphaInfo.premultipliedLast.rawValue
+ var pixelData = [UInt8](repeating: 0, count: height * bytesPerRow)
+
+ guard let contextRef = CGContext(
+ data: &pixelData,
+ width: width,
+ height: height,
+ bitsPerComponent: 8,
+ bytesPerRow: bytesPerRow,
+ space: colorSpace,
+ bitmapInfo: bitmapInfo
+ ) else {
+ return nil
+ }
+
+ contextRef.draw(cgImage, in: CGRect(x: 0, y: 0, width: CGFloat(width), height: CGFloat(height)))
+
+ // Define the white color in RGBA
+ let whitePixel: [UInt8] = [255, 255, 255, 255]
+
+ var minX = width
+ var minY = height
+ var maxX: Int = 0
+ var maxY: Int = 0
+
+ // Scan the pixels to find the bounding box of non-white content
+ for y in 0..<height {
+ for x in 0..<width {
+ let pixelIndex = (y * bytesPerRow) + (x * bytesPerPixel)
+ let pixel = Array(pixelData[pixelIndex..<(pixelIndex + 4)])
+
+ if pixel != whitePixel {
+ if x < minX { minX = x }
+ if x > maxX { maxX = x }
+ if y < minY { minY = y }
+ if y > maxY { maxY = y }
+ }
+ }
+ }
+
+ // If no non-white content was found, return the original image
+ if minX == width || minY == height || maxX == 0 || maxY == 0 {
+ return image
+ }
+
+ // Compute the bounding box and crop the image
+ let croppedRect = CGRect(x: CGFloat(minX), y: CGFloat(minY), width: CGFloat(maxX - minX), height: CGFloat(maxY - minY))
+ return image.cropped(to: croppedRect)
+}
+// Padding image with white border
+func addWhiteBorder(to image: CIImage, maxSize: CGFloat) -> CIImage {
+ let randomPadding = (0..<4).map { _ in CGFloat(arc4random_uniform(UInt32(maxSize))) }
+ var xPadding = randomPadding[0] + randomPadding[2]
+ var yPadding = randomPadding[1] + randomPadding[3]
+
+ // Ensure the minimum width and height
+ if xPadding + image.extent.width < MIN_WIDTH {
+ let compensateWidth = (MIN_WIDTH - (xPadding + image.extent.width)) * 0.5 + 1
+ xPadding += compensateWidth
+ }
+ if yPadding + image.extent.height < MIN_HEIGHT {
+ let compensateHeight = (MIN_HEIGHT - (yPadding + image.extent.height)) * 0.5 + 1
+ yPadding += compensateHeight
+ }
+
+ // Adding padding with a constant white color
+ let padFilter = CIFilter(name: "CICrop")!
+ let paddedRect = CGRect(x: image.extent.origin.x - randomPadding[0],
+ y: image.extent.origin.y - randomPadding[1],
+ width: image.extent.width + xPadding,
+ height: image.extent.height + yPadding)
+ padFilter.setValue(image, forKey: kCIInputImageKey)
+ padFilter.setValue(CIVector(cgRect: paddedRect), forKey: "inputRectangle")
+
+ return padFilter.outputImage ?? image
+}
+
+// Padding images to a required size
+func padding(images: [CIImage], requiredSize: CGFloat) -> [CIImage] {
+ return images.map { image in
+ let widthPadding = requiredSize - image.extent.width
+ let heightPadding = requiredSize - image.extent.height
+ return addWhiteBorder(to: image, maxSize: max(widthPadding, heightPadding))
+ }
+}
+
+// Transform pipeline to apply resize, normalize, etc.
+func inferenceTransform(images: [NSImage]) -> [CIImage] {
+ let ciImages = images.compactMap { nsImageToCIImage($0) }
+
+ let trimmedImages = ciImages.compactMap { trimWhiteBorder(image: $0) }
+ let paddedImages = padding(images: trimmedImages, requiredSize: FIXED_IMG_SIZE)
+
+ return paddedImages
+}
+
+func ciImageToFloatArray(_ image: CIImage, size: CGSize) -> [Float] {
+ // Render the CIImage to a bitmap context
+ let context = CIContext()
+ guard let cgImage = context.createCGImage(image, from: image.extent) else {
+ return []
+ }
+
+ let width = Int(size.width)
+ let height = Int(size.height)
+ var pixelData = [UInt8](repeating: 0, count: width * height) // Use UInt8 for grayscale
+
+ // Create bitmap context for rendering
+ let colorSpace = CGColorSpaceCreateDeviceGray()
+ guard let contextRef = CGContext(
+ data: &pixelData,
+ width: width,
+ height: height,
+ bitsPerComponent: 8,
+ bytesPerRow: width,
+ space: colorSpace,
+ bitmapInfo: CGImageAlphaInfo.none.rawValue
+ ) else {
+ return []
+ }
+
+ contextRef.draw(cgImage, in: CGRect(x: 0, y: 0, width: CGFloat(width), height: CGFloat(height)))
+
+ // Normalize pixel values to [0, 1]
+ return pixelData.map { Float($0) / 255.0 }
+}
diff --git a/iTexSnip/Utils/RobertaTokenizerFast.swift b/iTexSnip/Utils/RobertaTokenizerFast.swift
new file mode 100644
index 0000000..d61e8c7
--- /dev/null
+++ b/iTexSnip/Utils/RobertaTokenizerFast.swift
@@ -0,0 +1,88 @@
+//
+// RobertaTokenizerFast.swift
+// iTexSnip
+//
+// Created by Navan Chauhan on 10/13/24.
+//
+
+import Foundation
+
+class RobertaTokenizerFast {
+ var vocab: [String: Int] = [:]
+ var idToToken: [Int: String] = [:]
+ var specialTokens: [String] = []
+ var unkTokenId: Int?
+
+ init(vocabFile: String, tokenizerFile: String) {
+ if let vocabURL = Bundle.main.url(forResource: vocabFile, withExtension: "json"),
+ let vocabData = try? Data(contentsOf: vocabURL),
+ let vocabDict = try? JSONSerialization.jsonObject(with: vocabData, options: []) as? [String: Int] {
+ self.vocab = vocabDict
+ }
+
+ if let tokenizerURL = Bundle.main.url(forResource: tokenizerFile, withExtension: "json"),
+ let tokenizerData = try? Data(contentsOf: tokenizerURL),
+ let tokenizerConfig = try? JSONSerialization.jsonObject(with: tokenizerData, options: []) as? [String: Any] {
+ self.specialTokens = tokenizerConfig["added_tokens"] as? [String] ?? []
+ }
+
+ self.idToToken = vocab.reduce(into: [Int: String]()) { $0[$1.value] = $1.key }
+
+ self.unkTokenId = vocab["<unk>"]
+ }
+
+ func encode(text: String) -> [Int] {
+ let tokens = tokenize(text)
+ return tokens.map { vocab[$0] ?? unkTokenId! }
+ }
+
+ func decode(tokenIds: [Int], skipSpecialTokens: Bool = true) -> String {
+ let tokens = tokenIds.compactMap { idToToken[$0] }
+ let filteredTokens = skipSpecialTokens ? tokens.filter { !specialTokens.contains($0) && $0 != "</s>" } : tokens
+ return convertTokensToString(filteredTokens)
+ }
+
+ private func tokenize(_ text: String) -> [String] {
+ let cleanedText = cleanText(text)
+ let words = cleanedText.split(separator: " ").map { String($0) }
+
+ var tokens: [String] = []
+ for word in words {
+ tokens.append(contentsOf: bpeEncode(word))
+ }
+ return tokens
+ }
+
+ private func bpeEncode(_ word: String) -> [String] {
+ if vocab.keys.contains(word) {
+ return [word]
+ }
+
+ let chars = Array(word)
+ var tokens: [String] = []
+ var i = 0
+
+ while i < chars.count {
+ if i < chars.count - 1 {
+ let pair = String(chars[i]) + String(chars[i + 1])
+ if vocab.keys.contains(pair) {
+ tokens.append(pair)
+ i += 2
+ continue
+ }
+ }
+ tokens.append(String(chars[i]))
+ i += 1
+ }
+ return tokens
+ }
+
+ private func cleanText(_ text: String) -> String {
+ return text.trimmingCharacters(in: .whitespacesAndNewlines)
+ }
+
+ private func convertTokensToString(_ tokens: [String]) -> String {
+ let text = tokens.joined().replacingOccurrences(of: "Ġ", with: " ")
+ return text.replacingOccurrences(of: "\\s([?.!,\'\"](?:\\s|$))", with: "$1", options: .regularExpression, range: nil).trimmingCharacters(in: .whitespaces)
+ }
+}
diff --git a/iTexSnip/Utils/TexTellerModel.swift b/iTexSnip/Utils/TexTellerModel.swift
index fb3bcd9..2f71919 100644
--- a/iTexSnip/Utils/TexTellerModel.swift
+++ b/iTexSnip/Utils/TexTellerModel.swift
@@ -8,6 +8,12 @@
import OnnxRuntimeBindings
import AppKit
+public enum ModelError: Error {
+ case encoderModelNotFound
+ case decoderModelNotFound
+ case imageError
+}
+
public struct TexTellerModel {
public let encoderSession: ORTSession
public let decoderSession: ORTSession