Skip to content

Commit

Permalink
Merge pull request #64 from tylerjthomas9/class-weight-dict
Browse files Browse the repository at this point in the history
Allow naive passing of Julia `class_weight` dictionary
  • Loading branch information
tylerjthomas9 committed Jan 3, 2024
2 parents 530ce27 + 75f718a commit cfff77e
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 9 deletions.
9 changes: 7 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,29 @@ authors = ["Thibaut Lienart, Anthony Blaom"]
version = "0.6.0"

[deps]
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[compat]
Aqua = "0.8"
MLJBase = "1"
MLJModelInterface = "1.4"
MLJTestInterface = "0.2"
PythonCall = "0.9"
StableRNGs = "1"
Statistics = "1"
Tables = "1.10"
Test = "1"
julia = "1.6"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["MLJBase", "MLJTestInterface", "StableRNGs", "Test"]
test = ["Aqua", "MLJBase", "MLJTestInterface", "StableRNGs", "Test"]
30 changes: 23 additions & 7 deletions src/macros.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,17 @@ function _sk_finalize(m, clean_expr, fit_expr, expr)
esc(expr)
end

"""
    _prepare_param(obj)

Prepare a single model hyper-parameter for passing to a scikit-learn
constructor. The generic fallback returns `obj` unchanged.
"""
_prepare_param(obj) = obj

"""
    _prepare_param(obj::AbstractDict)

Convert a Julia dictionary (e.g. a `class_weight` mapping of class label
to weight) into a Python `dict` via `PythonCall.pydict`, so that
scikit-learn receives a native Python dictionary rather than a wrapped
Julia object. Accepts any `AbstractDict`, not just `Dict`.
"""
function _prepare_param(obj::AbstractDict)
    return SK.PythonCall.pydict(obj)
end

# =================================
# Specifics for SUPERVISED MODELS
# =================================
Expand Down Expand Up @@ -142,9 +153,10 @@ function _skmodel_fit_reg(modelname, params, save_std::Bool=false)
pyisnull(parent) && ski!(parent, skmod)
# retrieve the effective ScikitLearn constructor
skconstr = getproperty(parent, mdl)
# build the scikitlearn model passing all the parameters
skmodel = skconstr(
$((Expr(:kw, p, :(model.$p)) for p in params)...))
param_dict = Dict{Symbol, Any}(
[(p => _prepare_param(getfield(model, p))) for p in $params]
)
skmodel = skconstr(; param_dict...)
# --------------------------------------------------------------
# fit and organise results
X_py = ScikitLearnAPI.numpy.array(Xmatrix)
Expand Down Expand Up @@ -180,8 +192,10 @@ function _skmodel_fit_clf(modelname, params)
parent = eval(sksym)
pyisnull(parent) && ski!(parent, skmod)
skconstr = getproperty(parent, mdl)
skmodel = skconstr(
$((Expr(:kw, p, :(model.$p)) for p in params)...))
param_dict = Dict{Symbol, Any}(
[(p => _prepare_param(getfield(model, p))) for p in $params]
)
skmodel = skconstr(; param_dict...)
fitres = SK.fit!(skmodel, Xmatrix, yplain)
report = (; names)
if ScikitLearnAPI.pyhasattr(fitres, "coef_")
Expand Down Expand Up @@ -264,8 +278,10 @@ function _skmodel_fit_uns(modelname, params)
parent = eval(sksym)
pyisnull(parent) && ski!(parent, skmod)
skconstr = getproperty(parent, mdl)
skmodel = skconstr(
$((Expr(:kw, p, :(model.$p)) for p in params)...))
param_dict = Dict{Symbol, Any}(
[(p => _prepare_param(getfield(model, p))) for p in $params]
)
skmodel = skconstr(; param_dict...)
fitres = SK.fit!(skmodel, Xmatrix)
# TODO: we may want to use the report later on
report = NamedTuple()
Expand Down
8 changes: 8 additions & 0 deletions test/extras.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,11 @@
f, = MB.fit(m, 1, X)
@test f !== nothing
end

@testset "i63" begin
    # Regression test: a Julia `Dict` supplied as `class_weight`
    # must be accepted and forwarded to scikit-learn (issue #63).
    features, labels = MB.make_blobs(500, 3, rng=555)
    weights = Dict(1 => 0.2, 2 => 0.7, 3 => 0.1)
    clf = RandomForestClassifier(class_weight=weights)
    fitresult, = MB.fit(clf, 1, features, labels)
    @test fitresult !== nothing
end
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Aqua
using StableRNGs
using MLJScikitLearnInterface
using Test
Expand Down Expand Up @@ -27,3 +28,5 @@ println("\nclustering"); include("models/clustering.jl")
println("\nfeature importance tests"); include("feature_importance_tests.jl")
println("\ngeneric interface tests"); include("generic_api_tests.jl")
println("\nExtra tests from bug reports"); include("extras.jl")

Aqua.test_all(MLJScikitLearnInterface, ambiguities=false)

0 comments on commit cfff77e

Please sign in to comment.