Updates for new Core API v0.13.0
mmottl committed Nov 22, 2019 · 1 parent ad9cb9c · commit 80ea624
Showing 5 changed files with 25 additions and 22 deletions.
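Nearly every change in this commit follows one pattern: Core v0.13 (via Base) shadows the polymorphic comparison operators with int-only versions, so float comparisons must be scoped explicitly. A minimal standalone sketch of the idiom the diff applies throughout — `check_sigma2` mirrors the diff, while `within_tol` and `grow` are illustrative names:

```ocaml
open Core

(* Under Core, top-level (<), (>), (<=), (>=) are specialized to int,
   so [sigma2 < 0.] no longer type-checks.  A local open [Float.( ... )]
   rebinds the comparison and arithmetic operators to floats; inside it,
   [+] and [-] mean [+.] and [-.]. *)
let check_sigma2 sigma2 =
  if Float.(sigma2 < 0.) then failwith "sigma2 < 0"

let within_tol ~finite_el ~deriv ~tol = Float.(abs (finite_el - deriv) <= tol)

(* The polymorphic [max] is shadowed too; outside a local open, use
   [Float.max] explicitly. *)
let grow x = Float.max 0.5 (1. +. x)

let () =
  check_sigma2 0.1;
  printf "%b %f\n" (within_tol ~finite_el:1.0 ~deriv:1.005 ~tol:0.01) (grow 0.2)
```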
4 changes: 3 additions & 1 deletion CHANGES.md
@@ -1,7 +1,9 @@
-### ?.?.? (????-??-??)
+### 1.5.0 (2019-11-22)
 
 * Switched to OPAM file generation via `dune-project`
 
+* Updates for new Core API v0.13.0
+
 
 ### 1.4.1 (2018-10-24)
5 changes: 3 additions & 2 deletions app/ocaml_gpr.ml
@@ -135,7 +135,8 @@ module Args = struct
       );
     ]
 
-  let usage_msg = sprintf "%s: -cmd [ train | test ] -model file" Sys.argv.(0)
+  let usage_msg =
+    sprintf "%s: -cmd [ train | test ] -model file" (Sys.get_argv ()).(0)
 
   let anon_fun _ = failwith "no anonymous arguments allowed"
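This hunk exists because Core v0.13 deprecates the mutable `Sys.argv` in favour of `Sys.get_argv ()`. A minimal standalone sketch of the new call:

```ocaml
open Core

(* [Sys.get_argv ()] returns the argument array at call time; indexing
   then works exactly as with the old [Sys.argv]. *)
let () =
  let argv = Sys.get_argv () in
  printf "program: %s (%d argument(s))\n" argv.(0) (Array.length argv - 1)
```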
@@ -325,7 +326,7 @@ let train args =
   let last_deriv_time = ref 0. in
   let maybe_print last_time line =
     let now = Unix.gettimeofday () in
-    if !last_time +. 1. < now then begin
+    if Float.(!last_time + 1. < now) then begin
       last_time := now;
       prerr_endline line;
     end
2 changes: 1 addition & 1 deletion dune-project
@@ -21,7 +21,7 @@ and GPR implements some of the latest advances in this field.")
   (ocaml (>= 4.08))
   (dune (>= 1.10))
   base-threads
-  (core (>= 0.9.1))
+  (core (>= v0.13))
   (lacaml (>= 11.0.0))
   (gsl (>= 1.24.0))
  )
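This version bound now lives in one place: per the changelog entry above, `gpr.opam` is generated from `dune-project` via dune's `generate_opam_files` support (available since dune 1.10, matching the bound above). A rough sketch of such a setup, with the depends list taken from this diff; the project's actual stanzas may differ:

```
(lang dune 1.10)
(name gpr)

(generate_opam_files true)

(package
 (name gpr)
 (depends
  (ocaml (>= 4.08))
  (dune (>= 1.10))
  base-threads
  (core (>= v0.13))
  (lacaml (>= 11.0.0))
  (gsl (>= 1.24.0))))
```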
2 changes: 1 addition & 1 deletion gpr.opam
@@ -21,7 +21,7 @@ depends: [
   "ocaml" {>= "4.08"}
   "dune" {>= "1.10"}
   "base-threads"
-  "core" {>= "0.9.1"}
+  "core" {>= "v0.13"}
   "lacaml" {>= "11.0.0"}
   "gsl" {>= "1.24.0"}
 ]
34 changes: 17 additions & 17 deletions src/fitc_gp.ml
@@ -151,7 +151,7 @@ module Make_common (Spec : Specs.Eval) = struct
     type co_variance_coeffs = mat * mat
 
     let check_sigma2 sigma2 =
-      if sigma2 < 0. then failwith "Model.check_sigma2: sigma2 < 0"
+      if Float.(sigma2 < 0.) then failwith "Model.check_sigma2: sigma2 < 0"
 
     let calc_internal inputs sigma2 ~kn_diag ~v_mat ~r_vec =
       check_sigma2 sigma2;
@@ -189,7 +189,7 @@ module Make_common (Spec : Specs.Eval) = struct
         else
           let el =
             let el = r_mat.{r, r} in
-            if el > 0. then el
+            if Float.(el > 0.) then el
             else
               (* Cannot happen with LAPACK version 3.2 and greater *)
               let neg_el = -. el in
@@ -340,15 +340,15 @@ module Make_common (Spec : Specs.Eval) = struct
       let means = Trained.calc_means trained in
       let rec loop madsum i =
         if i = 0 then madsum /. f_samples
-        else loop (madsum +. Float.abs (y.{i} -. means.{i})) (i - 1)
+        else loop Float.(madsum + abs (y.{i} - means.{i})) (i - 1)
       in
       loop 0. n_samples
 
     let calc_maxad ({ Trained.y } as trained) =
       let means = Trained.calc_means trained in
       let rec loop maxad i =
         if i = 0 then maxad
-        else loop (max maxad (Float.abs (y.{i} -. means.{i}))) (i - 1)
+        else loop Float.(max maxad (abs (y.{i} - means.{i}))) (i - 1)
       in
       loop 0. (Vec.dim y)
@@ -368,7 +368,7 @@ module Make_common (Spec : Specs.Eval) = struct
           if i = 0 then madsum /. f_samples, maxad
           else
             let ad = Float.abs (y.{i} -. means.{i}) in
-            loop ~madsum:(madsum +. ad) ~maxad:(max maxad ad) (i - 1)
+            loop ~madsum:(madsum +. ad) ~maxad:(Float.max maxad ad) (i - 1)
         in
         loop ~madsum:0. ~maxad:0. n_samples
       in
@@ -1209,7 +1209,7 @@ module Make_common_deriv (Spec : Specs.Deriv) = struct
     let is_bad_deriv ~finite_el ~deriv ~tol =
       Float.is_nan finite_el
       || Float.is_nan deriv
-      || Float.abs (finite_el -. deriv) > tol
+      || Float.(abs (finite_el - deriv) > tol)
 
     let check_deriv_hyper ?(eps = 1e-8) ?(tol = 1e-2)
         kernel1 inducing_points1 points1 hyper =
@@ -1436,7 +1436,7 @@ module Make_common_deriv (Spec : Specs.Deriv) = struct
   module Optim = struct
     let get_sigma2 targets = function
       | None -> Vec.sqr_nrm2 targets /. float (Vec.dim targets)
-      | Some sigma2 when sigma2 < 0. ->
+      | Some sigma2 when Float.(sigma2 < 0.) ->
           failwithf "Optim.get_sigma2: sigma2 < 0: %f" sigma2 ()
       | Some sigma2 -> sigma2
@@ -1494,7 +1494,7 @@ module Make_common_deriv (Spec : Specs.Deriv) = struct
     exception Optim_exception of exn
 
     let check_exception seen_exception_ref res =
-      if Float.classify res = Float.Class.Nan then
+      if Poly.(Float.classify res = Float.Class.Nan) then
         match !seen_exception_ref with
         | None ->
             failwith "Gpr.Optim.Gsl: optimization function returned nan"
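One change in this hunk differs from the rest: `Float.classify` returns a variant (`Float.Class.t`), not a float, so the int-only `(=)` is escaped with `Poly.( ... )`, which locally restores polymorphic equality. A minimal sketch with an illustrative name (`is_nan_class`); `Float.is_nan res` would be an equivalent comparison-free alternative:

```ocaml
open Core

(* [Poly.( ... )] restores the polymorphic (=) that Core shadows,
   allowing direct comparison of variant values such as float classes. *)
let is_nan_class res = Poly.(Float.classify res = Float.Class.Nan)

let () = printf "%b %b\n" (is_nan_class Float.nan) (is_nan_class 1.0)
```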
@@ -1561,7 +1561,7 @@ module Make_common_deriv (Spec : Specs.Deriv) = struct
       let update_best_model trained log_evidence =
         match !best_model_ref with
         | Some (_, old_log_evidence)
-          when old_log_evidence >= log_evidence -> ()
+          when Float.(old_log_evidence >= log_evidence) -> ()
         | _ ->
             report_trained_model ~iter:!iter_count trained;
             best_model_ref := Some (trained, log_evidence)
@@ -1633,7 +1633,7 @@ module Make_common_deriv (Spec : Specs.Deriv) = struct
           try report_gradient_norm ~iter:!iter_count gnorm
           with exc -> raise (Optim_exception exc)
         end;
-        if gnorm < epsabs then get_best_model ()
+        if Float.(gnorm < epsabs) then get_best_model ()
         else begin
           incr iter_count;
           Gd.iterate mumin;
@@ -1676,15 +1676,15 @@ module Make_common_deriv (Spec : Specs.Deriv) = struct
         | Some max_iter -> max_iter
       in
       let rec loop n ~best_le ~best ~t =
-        if n = 0 || gradient_norm t < epsabs then best
+        if n = 0 || Float.(gradient_norm t < epsabs) then best
         else
           let new_t = step t in
           let best_le, best =
             let new_trained = get_trained new_t in
             let new_log_evidence =
               Eval_trained.calc_log_evidence new_trained
             in
-            if new_log_evidence <= best_le then best_le, best
+            if Float.(new_log_evidence <= best_le) then best_le, best
             else begin
               report new_t;
               new_log_evidence, new_t
@@ -1714,7 +1714,7 @@ module Make_common_deriv (Spec : Specs.Deriv) = struct
         ~inputs ~targets () =
       let loc = "Gpr.Fitc_gp.Optim.SGD.create" in
       let fail_neg0 what v =
-        if v <= 0. then failwithf "%s: %s (%f) <= 0" loc what v ()
+        if Float.(v <= 0.) then failwithf "%s: %s (%f) <= 0" loc what v ()
       in
       fail_neg0 "tau" tau;
       fail_neg0 "eta0" eta;
@@ -1817,14 +1817,14 @@ module Make_common_deriv (Spec : Specs.Deriv) = struct
       let lambda =
         match lambda with
         | None -> 0.1
-        | Some lambda when lambda < 0. || lambda > 1. ->
+        | Some lambda when Float.(lambda < 0. || lambda > 1.) ->
            failwithf "%s: violating 0 <= lambda(%f) <= 1" loc lambda ()
        | Some lambda -> lambda
      in
      let mu =
        match mu with
        | None -> 1e-3
-        | Some mu when mu < 0. ->
+        | Some mu when Float.(mu < 0.) ->
            failwithf "%s: violating 0 <= mu(%f)" loc mu ()
        | Some mu -> mu
      in
@@ -1850,7 +1850,7 @@ module Make_common_deriv (Spec : Specs.Deriv) = struct
       else begin
         for i = 1 to n_all_hypers do
           let eta0_i = eta0.{i} in
-          if eta0_i <= 0. then
+          if Float.(eta0_i <= 0.) then
             failwithf "%s: eta0.{%d} < 0: %f" loc i eta0_i ()
         done;
         eta0
@@ -1933,7 +1933,7 @@ module Make_common_deriv (Spec : Specs.Deriv) = struct
       for i = 1 to n_all_hypers do
         eta.{i} <-
           old_eta.{i} *.
-            max 0.5 (1. +. mu *. old_gradient.{i} *. old_nu.{i})
+            Float.max 0.5 (1. +. mu *. old_gradient.{i} *. old_nu.{i})
       done;
       let sigma2, hyper_ix =
         if learn_sigma2 then
