@@ -284,15 +284,15 @@ struct smo_out {
284284 * @param[in] x data of size mxn where m is the number of observation
285285 * @param[in] y labels must be a vector of size pow2(m) with -1 or 1 the first m elements and 0 after
286286 * @param c parameter for C-SVM
287- * @param tol criteria for stopping condition, should be greater than eps
288- * @param eps threshold above which alpha needs to be to be used as a weight of a support vector
287+ * @param tol criteria for stopping condition
288+ * @param alpha_eps threshold above which alpha needs to be to be used as a weight of a support vector
289289 * @param max_nb_iter maximum number of iterations
290290 * @param kernel_cache
291291 * @return smo_out containing the support vectors svs, the alphas (multiplied by their respective labels) and
292292 * the offset rho
293293 */
294294template <class KernelCacheT , class T >
295- smo_out<T> smo (queue& q, matrix_t <T>& x, vector_t <T>& y, T c, T tol, T eps , SYCLIndexT max_nb_iter,
295+ smo_out<T> smo (queue& q, matrix_t <T>& x, vector_t <T>& y, T c, T tol, T alpha_eps , SYCLIndexT max_nb_iter,
296296 KernelCacheT kernel_cache) {
297297 auto m = access_ker_dim (x, 0 );
298298 assert_eq (y.kernel_range .get_global_linear_range (), to_pow2 (m));
@@ -309,8 +309,8 @@ smo_out<T> smo(queue& q, matrix_t<T>& x, vector_t<T>& y, T c, T tol, T eps, SYCL
309309 vector_t <T> vec_cond_greater (y.data_range , y.kernel_range );
310310 vector_t <T> vec_cond_less (y.data_range , y.kernel_range );
311311
312- auto cond_greater = [c, eps ](T y, T a) { return T ((y > 0 && a < c) || (y < 0 && a > eps )); };
313- auto cond_less = [c, eps ](T y, T a) { return T ((y > 0 && a > eps ) || (y < 0 && a < c)); };
312+ auto cond_greater = [c, alpha_eps ](T y, T a) { return T ((y > 0 && a < c) || (y < 0 && a > alpha_eps )); };
313+ auto cond_less = [c, alpha_eps ](T y, T a) { return T ((y > 0 && a > alpha_eps ) || (y < 0 && a < c)); };
314314
315315 // Compute initial cond
316316 vec_unary_op (q, y, vec_cond_greater, ml::functors::positive<T>());
@@ -325,6 +325,7 @@ smo_out<T> smo(queue& q, matrix_t<T>& x, vector_t<T>& y, T c, T tol, T eps, SYCL
325325 SYCLIndexT j;
326326 T diff;
327327 SYCLIndexT nb_iter = 0 ;
328+ T eps = 1E-8 ;
328329 while (nb_iter < max_nb_iter) {
329330 if (!detail::select_wss (q, y, gradient, vec_cond_greater, vec_cond_less, tol, eps, start_search_indices,
330331 start_search_rng, find_size_threshold_host, kernel_cache, i, j, diff)) {
@@ -360,19 +361,17 @@ smo_out<T> smo(queue& q, matrix_t<T>& x, vector_t<T>& y, T c, T tol, T eps, SYCL
360361 // Update gradient
361362 T delta_ai = yi * (ai - old_ai);
362363 T delta_aj = yj * (aj - old_aj);
364+
363365 // Shouldn't happen in theory but can because of precision issue
364- if (std::abs (delta_ai) < eps && std::abs (delta_aj) < eps) {
365- std::cerr << " SVM cannot converge, try setting a smaller eps or a bigger tol." << std::endl;
366- break ;
367- }
368- else {
369- detail::update_gradient (q, delta_ai, delta_aj, ker_i_t , ker_j_t , gradient);
370- vec_cond_greater.write_from_host (i, cond_greater (yi, ai));
371- vec_cond_greater.write_from_host (j, cond_greater (yj, aj));
372- vec_cond_less.write_from_host (i, cond_less (yi, ai));
373- vec_cond_less.write_from_host (j, cond_less (yj, aj));
374- ++nb_iter;
375- }
366+ assert (std::abs (delta_ai) >= eps);
367+ assert (std::abs (delta_aj) >= eps);
368+
369+ detail::update_gradient (q, delta_ai, delta_aj, ker_i_t , ker_j_t , gradient);
370+ vec_cond_greater.write_from_host (i, cond_greater (yi, ai));
371+ vec_cond_greater.write_from_host (j, cond_greater (yj, aj));
372+ vec_cond_less.write_from_host (i, cond_less (yi, ai));
373+ vec_cond_less.write_from_host (j, cond_less (yj, aj));
374+ ++nb_iter;
376375 }
377376
378377 if (nb_iter == max_nb_iter)
@@ -382,7 +381,7 @@ smo_out<T> smo(queue& q, matrix_t<T>& x, vector_t<T>& y, T c, T tol, T eps, SYCL
382381 auto host_alphas = alphas.template get_access <access::mode::read>();
383382 std::vector<uint32_t > host_sv_indices;
384383 for (unsigned k = 0 ; k < m; ++k) {
385- if (host_alphas[k] > eps )
384+ if (host_alphas[k] > alpha_eps )
386385 host_sv_indices.push_back (k);
387386 }
388387 auto nb_sv = host_sv_indices.size ();
0 commit comments