9-2d-decision-boundary

Author

math4mad

Introduction
  1. ref: probml, page 84, figure 2.13
  2. dataset: iris
  3. plots: GLMakie's contourf method

1. load packages

Code
    import MLJ: fit!, fitted_params
    using GLMakie, MLJ, CSV, DataFrames

2. process data

2.1 import iris dataset

Code
iris = load_iris();

# selectrows(iris, 1:3) |> pretty

iris = DataFrames.DataFrame(iris);
first(iris, 5) |> display

# shuffle the rows and split off the target column
y, X = unpack(iris, ==(:target); rng=123);

# keep only petal_length and petal_width (columns 3 and 4) as features
X = select!(X, 3:4)

byCat = iris.target
categ = unique(byCat)                     # the three species
colors = [:orange, :lightgreen, :purple];
5×5 DataFrame
 Row │ sepal_length  sepal_width  petal_length  petal_width  target
     │ Float64       Float64      Float64       Float64      Cat…
─────┼───────────────────────────────────────────────────────────────
   1 │          5.1          3.5           1.4          0.2  setosa
   2 │          4.9          3.0           1.4          0.2  setosa
   3 │          4.7          3.2           1.3          0.2  setosa
   4 │          4.6          3.1           1.5          0.2  setosa
   5 │          5.0          3.6           1.4          0.2  setosa
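
A quick check, not part of the original listing, that MLJ sees the right scientific types before fitting: the two selected petal features should be Continuous and the target Multiclass. Both schema and scitype are available after `using MLJ`.

Code
schema(X)    # both petal columns are expected to have scitype Continuous
scitype(y)   # expected: AbstractVector{Multiclass{3}}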

2.2 make decision boundary data

Generating the decision boundary really just means using the trained model to predict at every point in a region: a grid is built spanning (roughly) the minimum and maximum values of the two features, and this grid serves as the test data.

Code
# grid data: a 40 × 40 mesh of test points over the petal_length/petal_width plane
n1 = n2 = 40
tx = LinRange(0, 8, n1)
ty = LinRange(-1, 4, n2)
X_test = mapreduce(collect, hcat, Iterators.product(tx, ty))   # 2 × 1600 matrix of grid points
X_test = MLJ.table(X_test')                                    # MLJ table with columns :x1, :x2
Tables.MatrixTable{LinearAlgebra.Adjoint{Float64, Matrix{Float64}}} with 1600 rows, 2 columns, and schema:
 :x1  Float64
 :x2  Float64
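
The ranges 0 to 8 and -1 to 4 above are hard-coded to cover the iris petal measurements. A data-driven variant (a sketch, not part of the original) builds the same grid from the feature extrema instead, reusing the names tx, ty and X_test:

Code
xmin, xmax = extrema(X.petal_length)
ymin, ymax = extrema(X.petal_width)
tx = LinRange(xmin - 0.5, xmax + 0.5, 40)   # pad slightly beyond the observed data
ty = LinRange(ymin - 0.5, ymax + 0.5, 40)
X_test = MLJ.table(mapreduce(collect, hcat, Iterators.product(tx, ty))')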

3. Logistic model

3.1 training the model

Code
    LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels
    mach = fit!(machine(LogisticClassifier(), X, y))   # train on the two petal features
    predict(mach, X)                                   # probabilistic predictions on the training data
    fitted_params(mach)                                # coefficients and intercepts of the fitted model
    probs = predict(mach, X_test) |> Array             # class probabilities for every grid point
    probres = [broadcast(pdf, probs, cat) for cat in categ]   # one probability vector per class
import MLJLinearModels ✔
[ Info: For silent loading, specify `verbosity=0`. 
[ Info: Training machine(LogisticClassifier(lambda = 2.220446049250313e-16, …), …).
┌ Info: Solver: MLJLinearModels.LBFGS{Optim.Options{Float64, Nothing}, NamedTuple{(), Tuple{}}}
│   optim_options: Optim.Options{Float64, Nothing}
└   lbfgs_options: NamedTuple{(), Tuple{}} NamedTuple()
3-element Vector{Vector{Float64}}:
 [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0  …  9.671834765600391e-132, 7.370874765820906e-135, 5.617320408186208e-138, 4.2809421636823794e-141, 3.2624925190464766e-144, 2.4863352574885955e-147, 1.8948282567827453e-150, 1.4440426374071983e-153, 1.1005003388488668e-156, 8.38687836794779e-160]
 [2.6707314921073497e-56, 1.0760270944086094e-53, 4.335270360659973e-51, 1.7466631832673993e-48, 7.037236485793986e-46, 2.835274587075858e-43, 1.142320852844153e-40, 4.6023652763324474e-38, 1.8542746623291115e-35, 7.470798854313656e-33  …  1.3627080472304016e-14, 4.184137656019423e-15, 1.2847218419308875e-15, 3.9446842977544977e-16, 1.2111986969540427e-16, 3.718934578207533e-17, 1.1418831965200673e-17, 3.5061042539858484e-18, 1.0765345419987361e-18, 3.305453963038536e-19]
 [1.6814079540479983e-80, 2.2062917118777524e-77, 2.8950280068446543e-74, 3.798766552624987e-71, 4.9846244275441556e-68, 6.540670593853871e-65, 8.582466430351137e-62, 1.1261648017760677e-58, 1.4777187549191659e-55, 1.9390170205959138e-52  …  0.9999999999999865, 0.9999999999999958, 0.9999999999999987, 0.9999999999999996, 0.9999999999999998, 1.0, 1.0, 1.0, 1.0, 1.0]
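Before plotting, a quick sanity check (a sketch, not in the original; MLJ's evaluate! with cross-validation would be the more rigorous option) that the classifier separates the training data well on these two features:

Code
yhat = predict_mode(mach, X)   # hard class labels (most probable class per point)
accuracy(yhat, y)              # fraction of training points classified correctly
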
3.2 plot the decision boundary

Code
function plot_res_contour()
    fig = Figure()
    ax = Axis(fig[1, 1], xlabel="petal-length", ylabel="petal-width", title="2d3class-contour")

    # most probable class (1, 2 or 3) at every grid point, reshaped onto the (tx, ty) grid
    zclass = [argmax([p[i] for p in probres]) for i in eachindex(probres[1])]
    z = reshape(Float64.(zclass), length(tx), length(ty))

    # filled contours of the three decision regions (GLMakie contourf)
    contourf!(ax, tx, ty, z; levels=[0.5, 1.5, 2.5, 3.5],
              colormap=cgrad(colors; alpha=0.4, categorical=true))

    # overlay the observed data points, colored by class
    for (idx, cat) in enumerate(categ)
        indc = findall(x -> x == cat, byCat)
        scatter!(ax, iris[:, 3][indc], iris[:, 4][indc];
                 color=(colors[idx], 0.8), markersize=10,
                 strokewidth=1, strokecolor=:black, label="$cat")
    end
    axislegend(ax; position=:lt)
    fig
end
plot_res_contour()
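
A closely related view (again a sketch, not in the original) draws only the p = 0.5 iso-probability line of each class with contour!, which traces the decision boundary itself instead of filled regions; it reuses probres, tx and ty from above:

Code
fig2 = Figure()
ax2 = Axis(fig2[1, 1], xlabel="petal-length", ylabel="petal-width", title="p = 0.5 contours")
for (idx, cat) in enumerate(categ)
    z = reshape(probres[idx], length(tx), length(ty))
    contour!(ax2, tx, ty, z; levels=[0.5], color=colors[idx], linewidth=2)
end
fig2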