@@ -160,6 +160,58 @@ def transform(self, X):
160160 Input data transformed to the metric space by :math:`XL^{\\ top}`
161161 """
162162
163+ class BilinearMixin (BaseMetricLearner , metaclass = ABCMeta ):
164+
165+ def score_pairs (self , pairs ):
166+ r"""
167+ Parameters
168+ ----------
169+ pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2)
170+ 3D Array of pairs to score, with each row corresponding to two points,
171+ for 2D array of indices of pairs if the metric learner uses a
172+ preprocessor.
173+
174+ Returns
175+ -------
176+ scores : `numpy.ndarray` of shape=(n_pairs,)
177+ The learned Mahalanobis distance for every pair.
178+ """
179+ check_is_fitted (self , ['preprocessor_' , 'components_' ])
180+ pairs = check_input (pairs , type_of_inputs = 'tuples' ,
181+ preprocessor = self .preprocessor_ ,
182+ estimator = self , tuple_size = 2 )
183+ return np .dot (np .dot (pairs [:, 1 , :], self .components_ ), pairs [:, 0 , :].T )
184+
185+ def get_metric (self ):
186+ check_is_fitted (self , 'components_' )
187+ components = self .components_ .copy ()
188+
189+ def metric_fun (u , v ):
190+ """This function computes the metric between u and v, according to the
191+ previously learned metric.
192+
193+ Parameters
194+ ----------
195+ u : array-like, shape=(n_features,)
196+ The first point involved in the distance computation.
197+
198+ v : array-like, shape=(n_features,)
199+ The second point involved in the distance computation.
200+
201+ Returns
202+ -------
203+ distance : float
204+ The distance between u and v according to the new metric.
205+ """
206+ u = validate_vector (u )
207+ v = validate_vector (v )
208+ return np .dot (np .dot (u , components ), v .T )
209+
210+ return metric_fun
211+
212+ def get_bilinear_matrix (self ):
213+ check_is_fitted (self , 'components_' )
214+ return self .components_
163215
164216class MahalanobisMixin (BaseMetricLearner , MetricTransformer ,
165217 metaclass = ABCMeta ):
0 commit comments