diff options
author | Navan Chauhan <navanchauhan@gmail.com> | 2024-10-21 23:51:06 -0600 |
---|---|---|
committer | Navan Chauhan <navanchauhan@gmail.com> | 2024-10-21 23:51:06 -0600 |
commit | 05cf4dd46aebfaa7812989c88eb91d303a43d3e7 (patch) | |
tree | 6cf0542910c72f8135345facbcf39a8eb409d97e /iTexSnip/Utils | |
parent | c343140bcf9862b7b4b0d465b67e51eb42f45008 (diff) |
bruh
Diffstat (limited to 'iTexSnip/Utils')
-rw-r--r-- | iTexSnip/Utils/ImageUtils.swift | 176 | ||||
-rw-r--r-- | iTexSnip/Utils/RobertaTokenizerFast.swift | 88 | ||||
-rw-r--r-- | iTexSnip/Utils/TexTellerModel.swift | 6 |
3 files changed, 270 insertions, 0 deletions
diff --git a/iTexSnip/Utils/ImageUtils.swift b/iTexSnip/Utils/ImageUtils.swift new file mode 100644 index 0000000..e59c4e5 --- /dev/null +++ b/iTexSnip/Utils/ImageUtils.swift @@ -0,0 +1,176 @@ +// +// ImageUtils.swift +// iTexSnip +// +// Created by Navan Chauhan on 10/13/24. +// + +import Foundation +import CoreImage +import AppKit + +let IMAGE_MEAN: CGFloat = 0.9545467 +let IMAGE_STD: CGFloat = 0.15394445 +let FIXED_IMG_SIZE: CGFloat = 448 +let IMG_CHANNELS: Int = 1 +let MIN_HEIGHT: CGFloat = 12 +let MIN_WIDTH: CGFloat = 30 + +// Load image from URL +func loadImage(from urlString: String) -> NSImage? { + guard let url = URL(string: urlString), let imageData = try? Data(contentsOf: url) else { + return nil + } + return NSImage(data: imageData) +} + +// Helper to convert NSImage to CIImage +func nsImageToCIImage(_ image: NSImage) -> CIImage? { + guard let data = image.tiffRepresentation, + let bitmapImage = NSBitmapImageRep(data: data), + let cgImage = bitmapImage.cgImage else { + return nil + } + return CIImage(cgImage: cgImage) +} + +func trimWhiteBorder(image: CIImage) -> CIImage? { + let context = CIContext() + + // Render the CIImage to a CGImage for pixel analysis + guard let cgImage = context.createCGImage(image, from: image.extent) else { + return nil + } + + // Access the pixel data + let width = cgImage.width + let height = cgImage.height + let colorSpace = CGColorSpaceCreateDeviceRGB() + let bytesPerPixel = 4 + let bytesPerRow = bytesPerPixel * width + let bitmapInfo = CGImageAlphaInfo.premultipliedLast.rawValue + var pixelData = [UInt8](repeating: 0, count: height * bytesPerRow) + + guard let contextRef = CGContext( + data: &pixelData, + width: width, + height: height, + bitsPerComponent: 8, + bytesPerRow: bytesPerRow, + space: colorSpace, + bitmapInfo: bitmapInfo + ) else { + return nil + } + + contextRef.draw(cgImage, in: CGRect(x: 0, y: 0, width: CGFloat(width), height: CGFloat(height))) + + // Define the white color in RGBA + let whitePixel: [UInt8] = [255, 255, 255, 255] + + var minX = width + var minY = height + var maxX: Int = 0 + var maxY: Int = 0 + + // Scan the pixels to find the bounding box of non-white content + for y in 0..<height { + for x in 0..<width { + let pixelIndex = (y * bytesPerRow) + (x * bytesPerPixel) + let pixel = Array(pixelData[pixelIndex..<(pixelIndex + 4)]) + + if pixel != whitePixel { + if x < minX { minX = x } + if x > maxX { maxX = x } + if y < minY { minY = y } + if y > maxY { maxY = y } + } + } + } + + // If no non-white content was found, return the original image + if minX == width || minY == height || maxX == 0 || maxY == 0 { + return image + } + + // Compute the bounding box and crop the image + let croppedRect = CGRect(x: CGFloat(minX), y: CGFloat(minY), width: CGFloat(maxX - minX), height: CGFloat(maxY - minY)) + return image.cropped(to: croppedRect) +} +// Padding image with white border +func addWhiteBorder(to image: CIImage, maxSize: CGFloat) -> CIImage { + let randomPadding = (0..<4).map { _ in CGFloat(arc4random_uniform(UInt32(maxSize))) } + var xPadding = randomPadding[0] + randomPadding[2] + var yPadding = randomPadding[1] + randomPadding[3] + + // Ensure the minimum width and height + if xPadding + image.extent.width < MIN_WIDTH { + let compensateWidth = (MIN_WIDTH - (xPadding + image.extent.width)) * 0.5 + 1 + xPadding += compensateWidth + } + if yPadding + image.extent.height < MIN_HEIGHT { + let compensateHeight = (MIN_HEIGHT - (yPadding + image.extent.height)) * 0.5 + 1 + yPadding += compensateHeight + } + + // Adding padding with a constant white color + let padFilter = CIFilter(name: "CICrop")! + let paddedRect = CGRect(x: image.extent.origin.x - randomPadding[0], + y: image.extent.origin.y - randomPadding[1], + width: image.extent.width + xPadding, + height: image.extent.height + yPadding) + padFilter.setValue(image, forKey: kCIInputImageKey) + padFilter.setValue(CIVector(cgRect: paddedRect), forKey: "inputRectangle") + + return padFilter.outputImage ?? image +} + +// Padding images to a required size +func padding(images: [CIImage], requiredSize: CGFloat) -> [CIImage] { + return images.map { image in + let widthPadding = requiredSize - image.extent.width + let heightPadding = requiredSize - image.extent.height + return addWhiteBorder(to: image, maxSize: max(widthPadding, heightPadding)) + } +} + +// Transform pipeline to apply resize, normalize, etc. +func inferenceTransform(images: [NSImage]) -> [CIImage] { + let ciImages = images.compactMap { nsImageToCIImage($0) } + + let trimmedImages = ciImages.compactMap { trimWhiteBorder(image: $0) } + let paddedImages = padding(images: trimmedImages, requiredSize: FIXED_IMG_SIZE) + + return paddedImages +} + +func ciImageToFloatArray(_ image: CIImage, size: CGSize) -> [Float] { + // Render the CIImage to a bitmap context + let context = CIContext() + guard let cgImage = context.createCGImage(image, from: image.extent) else { + return [] + } + + let width = Int(size.width) + let height = Int(size.height) + var pixelData = [UInt8](repeating: 0, count: width * height) // Use UInt8 for grayscale + + // Create bitmap context for rendering + let colorSpace = CGColorSpaceCreateDeviceGray() + guard let contextRef = CGContext( + data: &pixelData, + width: width, + height: height, + bitsPerComponent: 8, + bytesPerRow: width, + space: colorSpace, + bitmapInfo: CGImageAlphaInfo.none.rawValue + ) else { + return [] + } + + contextRef.draw(cgImage, in: CGRect(x: 0, y: 0, width: CGFloat(width), height: CGFloat(height))) + + // Normalize pixel values to [0, 1] + return pixelData.map { Float($0) / 255.0 } +} diff --git a/iTexSnip/Utils/RobertaTokenizerFast.swift b/iTexSnip/Utils/RobertaTokenizerFast.swift new file mode 100644 index 0000000..d61e8c7 --- /dev/null +++ b/iTexSnip/Utils/RobertaTokenizerFast.swift @@ -0,0 +1,88 @@ +// +// RobertaTokenizerFast.swift +// iTexSnip +// +// Created by Navan Chauhan on 10/13/24. +// + +import Foundation + +class RobertaTokenizerFast { + var vocab: [String: Int] = [:] + var idToToken: [Int: String] = [:] + var specialTokens: [String] = [] + var unkTokenId: Int? + + init(vocabFile: String, tokenizerFile: String) { + if let vocabURL = Bundle.main.url(forResource: vocabFile, withExtension: "json"), + let vocabData = try? Data(contentsOf: vocabURL), + let vocabDict = try? JSONSerialization.jsonObject(with: vocabData, options: []) as? [String: Int] { + self.vocab = vocabDict + } + + if let tokenizerURL = Bundle.main.url(forResource: tokenizerFile, withExtension: "json"), + let tokenizerData = try? Data(contentsOf: tokenizerURL), + let tokenizerConfig = try? JSONSerialization.jsonObject(with: tokenizerData, options: []) as? [String: Any] { + self.specialTokens = tokenizerConfig["added_tokens"] as? [String] ?? [] + } + + self.idToToken = vocab.reduce(into: [Int: String]()) { $0[$1.value] = $1.key } + + self.unkTokenId = vocab["<unk>"] + } + + func encode(text: String) -> [Int] { + let tokens = tokenize(text) + return tokens.map { vocab[$0] ?? unkTokenId! } + } + + func decode(tokenIds: [Int], skipSpecialTokens: Bool = true) -> String { + let tokens = tokenIds.compactMap { idToToken[$0] } + let filteredTokens = skipSpecialTokens ? tokens.filter { !specialTokens.contains($0) && $0 != "</s>" } : tokens + return convertTokensToString(filteredTokens) + } + + private func tokenize(_ text: String) -> [String] { + let cleanedText = cleanText(text) + let words = cleanedText.split(separator: " ").map { String($0) } + + var tokens: [String] = [] + for word in words { + tokens.append(contentsOf: bpeEncode(word)) + } + return tokens + } + + private func bpeEncode(_ word: String) -> [String] { + if vocab.keys.contains(word) { + return [word] + } + + let chars = Array(word) + var tokens: [String] = [] + var i = 0 + + while i < chars.count { + if i < chars.count - 1 { + let pair = String(chars[i]) + String(chars[i + 1]) + if vocab.keys.contains(pair) { + tokens.append(pair) + i += 2 + continue + } + } + tokens.append(String(chars[i])) + i += 1 + } + return tokens + } + + private func cleanText(_ text: String) -> String { + return text.trimmingCharacters(in: .whitespacesAndNewlines) + } + + private func convertTokensToString(_ tokens: [String]) -> String { + let text = tokens.joined().replacingOccurrences(of: "Ġ", with: " ") + return text.replacingOccurrences(of: "\\s([?.!,\'\"](?:\\s|$))", with: "$1", options: .regularExpression, range: nil).trimmingCharacters(in: .whitespaces) + } +} diff --git a/iTexSnip/Utils/TexTellerModel.swift b/iTexSnip/Utils/TexTellerModel.swift index fb3bcd9..2f71919 100644 --- a/iTexSnip/Utils/TexTellerModel.swift +++ b/iTexSnip/Utils/TexTellerModel.swift @@ -8,6 +8,12 @@ import OnnxRuntimeBindings import AppKit +public enum ModelError: Error { + case encoderModelNotFound + case decoderModelNotFound + case imageError +} + public struct TexTellerModel { public let encoderSession: ORTSession public let decoderSession: ORTSession |