diff --git a/Project.toml b/Project.toml index 7cc1bed..8ff4982 100644 --- a/Project.toml +++ b/Project.toml @@ -4,24 +4,29 @@ authors = ["Thibaut Lienart, Anthony Blaom"] version = "0.6.0" [deps] -MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [compat] +Aqua = "0.8" MLJBase = "1" MLJModelInterface = "1.4" +MLJTestInterface = "0.2" PythonCall = "0.9" +StableRNGs = "1" +Statistics = "1" Tables = "1.10" +Test = "1" julia = "1.6" [extras] +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["MLJBase", "MLJTestInterface", "StableRNGs", "Test"] +test = ["Aqua", "MLJBase", "MLJTestInterface", "StableRNGs", "Test"] diff --git a/src/macros.jl b/src/macros.jl index 8f45c3a..f9cda44 100644 --- a/src/macros.jl +++ b/src/macros.jl @@ -103,6 +103,17 @@ function _sk_finalize(m, clean_expr, fit_expr, expr) esc(expr) end +""" + _prepare_param(obj) + +Prepare model parameters for passing to the Python constructor. +""" +_prepare_param(obj) = obj + +function _prepare_param(obj::Dict) + return SK.PythonCall.pydict(obj) +end + # ================================= # Specifics for SUPERVISED MODELS # ================================= @@ -142,9 +153,10 @@ function _skmodel_fit_reg(modelname, params, save_std::Bool=false) pyisnull(parent) && ski!(parent, skmod) # retrieve the effective ScikitLearn constructor skconstr = getproperty(parent, mdl) - # build the scikitlearn model passing all the parameters - skmodel = skconstr( - $((Expr(:kw, p, :(model.$p)) for p in params)...)) + param_dict = Dict{Symbol, Any}( + [(p => _prepare_param(getfield(model, p))) for p in $params] + ) 
+ skmodel = skconstr(; param_dict...) # -------------------------------------------------------------- # fit and organise results X_py = ScikitLearnAPI.numpy.array(Xmatrix) @@ -180,8 +192,10 @@ function _skmodel_fit_clf(modelname, params) parent = eval(sksym) pyisnull(parent) && ski!(parent, skmod) skconstr = getproperty(parent, mdl) - skmodel = skconstr( - $((Expr(:kw, p, :(model.$p)) for p in params)...)) + param_dict = Dict{Symbol, Any}( + [(p => _prepare_param(getfield(model, p))) for p in $params] + ) + skmodel = skconstr(; param_dict...) fitres = SK.fit!(skmodel, Xmatrix, yplain) report = (; names) if ScikitLearnAPI.pyhasattr(fitres, "coef_") @@ -264,8 +278,10 @@ function _skmodel_fit_uns(modelname, params) parent = eval(sksym) pyisnull(parent) && ski!(parent, skmod) skconstr = getproperty(parent, mdl) - skmodel = skconstr( - $((Expr(:kw, p, :(model.$p)) for p in params)...)) + param_dict = Dict{Symbol, Any}( + [(p => _prepare_param(getfield(model, p))) for p in $params] + ) + skmodel = skconstr(; param_dict...) 
+ fitres = SK.fit!(skmodel, Xmatrix) # TODO: we may want to use the report later on report = NamedTuple() diff --git a/test/extras.jl b/test/extras.jl index 93c6845..e4475a2 100644 --- a/test/extras.jl +++ b/test/extras.jl @@ -6,3 +6,11 @@ f, = MB.fit(m, 1, X) @test f !== nothing end + +@testset "i63" begin + X, y = MB.make_blobs(500, 3, rng=555) + w = Dict(1=>0.2, 2=>0.7, 3=>0.1) + m = RandomForestClassifier(class_weight=w) + f, = MB.fit(m, 1, X, y) + @test f !== nothing +end diff --git a/test/runtests.jl b/test/runtests.jl index d7c7b69..12d7da3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,4 @@ +using Aqua using StableRNGs using MLJScikitLearnInterface using Test @@ -27,3 +28,5 @@ println("\nclustering"); include("models/clustering.jl") println("\nfeature importance tests"); include("feature_importance_tests.jl") println("\ngeneric interface tests"); include("generic_api_tests.jl") println("\nExtra tests from bug reports"); include("extras.jl") + +Aqua.test_all(MLJScikitLearnInterface, ambiguities=false)