1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
|
/*
See LICENSE folder for this sample’s licensing information.
Abstract:
Provides helper types for the BERT model's outputs.
*/
import CoreML
extension Array where Element: Comparable {
    /// Provides the indices of the largest elements, ordered from the largest element to the smallest.
    ///
    /// - parameters:
    ///     - count: The number of indices to return, at most. Values larger than the array's
    ///       length are capped to the length; negative values are treated as zero.
    /// - returns: An array of at most `count` offsets into this array.
    func indicesOfLargest(_ count: Int = 10) -> [Int] {
        // Clamp on both sides: a negative limit would otherwise form an invalid range and trap.
        let limit = Swift.max(0, Swift.min(count, self.count))

        return enumerated()
            .sorted { $0.element > $1.element }  // Descending by value.
            .prefix(limit)
            .map { $0.offset }
    }
}
extension MLMultiArray {
    /// Creates a copy of the multi-array's contents into a Doubles array.
    ///
    /// - returns: An array of `count` Doubles.
    func doubleArray() -> [Double] {
        // Reinterpreting the raw storage as `Double` is only valid when the elements
        // really are Doubles. For any other element type, convert element by element
        // through the NSNumber subscript instead of misreading (and over-reading) memory.
        guard dataType == .double else {
            return (0..<count).map { self[$0].doubleValue }
        }

        // Bind the underlying `dataPointer` memory to make a native Swift `[Double]`.
        let unsafeMutablePointer = dataPointer.bindMemory(to: Double.self, capacity: count)
        let unsafeBufferPointer = UnsafeBufferPointer(start: unsafeMutablePointer, count: count)
        return [Double](unsafeBufferPointer)
    }
}
extension BERT {
    /// Finds the indices of the best start logit and end logit given a prediction output and a range.
    ///
    /// - parameters:
    ///     - prediction: A feature provider that supplies the output MLMultiArrays from a BERT model.
    ///     - range: A range of the output tokens to search.
    /// - returns: The best start and end indices, expressed as offsets from `range.lowerBound`
    ///   (the logits are sliced to `range` before searching), or `nil` if `range` does not fit
    ///   inside both logits arrays or no valid start/end pairing exists.
    /// - Tag: BestLogitIndices
    func bestLogitsIndices(from prediction: BERTQAFP16Output, in range: Range<Int>) -> (start: Int, end: Int)? {
        // Convert the logits MLMultiArrays to [Double].
        let startLogits = prediction.startLogits.doubleArray()
        let endLogits = prediction.endLogits.doubleArray()

        // Slicing with a range that reaches outside either array would trap,
        // so report "no answer" instead.
        let sharedCount = Swift.min(startLogits.count, endLogits.count)
        guard range.lowerBound >= 0, range.upperBound <= sharedCount else {
            return nil
        }

        // Isolate the logits for the document.
        let startLogitsOfDoc = [Double](startLogits[range])
        let endLogitsOfDoc = [Double](endLogits[range])

        // Only keep the top 20 (out of the possible ~380) indices for faster searching.
        let topStartIndices = startLogitsOfDoc.indicesOfLargest(20)
        let topEndIndices = endLogitsOfDoc.indicesOfLargest(20)

        // Search for the highest valued logit pairing.
        let bestPair = findBestLogitPair(startLogits: startLogitsOfDoc,
                                         bestStartIndices: topStartIndices,
                                         endLogits: endLogitsOfDoc,
                                         bestEndIndices: topEndIndices)

        // (-1, -1) is the search's "nothing found" sentinel.
        guard bestPair.start >= 0 && bestPair.end >= 0 else {
            return nil
        }

        return bestPair
    }

    /// Searches the given indices for the highest valued start and end logits.
    ///
    /// Only pairings whose end index is at or after the start index are considered.
    /// When several pairings share the highest sum, the one with the smallest start
    /// index (and then the smallest end index) is returned.
    ///
    /// - parameters:
    ///     - startLogits: An array of all the start logits.
    ///     - bestStartIndices: An array of the best start logit indices.
    ///     - endLogits: An array of all the end logits.
    ///     - bestEndIndices: An array of the best end logit indices.
    /// - returns: A tuple of the best start index and best end index,
    ///   or `(-1, -1)` if no valid pairing exists.
    func findBestLogitPair(startLogits: [Double],
                           bestStartIndices: [Int],
                           endLogits: [Double],
                           bestEndIndices: [Int]) -> (start: Int, end: Int) {
        // An index can only take part in a pairing if it is valid for both arrays.
        let validIndices = 0..<Swift.min(startLogits.count, endLogits.count)

        // Visit just the candidates, de-duplicated and in ascending order. Combined with
        // the strict `>` below, the ascending order keeps the earliest pairing on ties.
        let candidateStarts = Set(bestStartIndices).filter { validIndices.contains($0) }.sorted()
        let candidateEnds = Set(bestEndIndices).filter { validIndices.contains($0) }.sorted()

        var bestSum = -Double.infinity
        var bestStart = -1
        var bestEnd = -1

        for start in candidateStarts {
            // An answer cannot end before it starts.
            for end in candidateEnds where end >= start {
                let logitSum = startLogits[start] + endLogits[end]

                if logitSum > bestSum {
                    bestSum = logitSum
                    bestStart = start
                    bestEnd = end
                }
            }
        }

        return (bestStart, bestEnd)
    }
}
|