5-2d-decision-boundary

Author

math4mads

决策边界图

probml page 84, figure 2.13

  • 概率值预测方法 参见: https://discourse.julialang.org/t/extracting-values-from-univariatefinite/62794/3

  • 平面上每个点作为预测数据集

  • 通过学习的模型对平面上每个点给出概率值

  • 利用 contour 方法绘出概率值的等高线

Code
import MLJ:fit!,predict,predict_mode,predict_mean,machine
using MLJ,GLMakie,Random,DataFrames

# Load the iris table as a DataFrame; `nums` is the default grid resolution per axis.
iris = DataFrame(load_iris())
nums = 100

# Collapse the species column into a binary flag: 1.0 for "virginica", 0.0 otherwise,
# then mark it as a categorical (Multiclass) target.
iris.target = [species == "virginica" ? 1.0 : 0.0 for species in iris.target]
iris = coerce(iris, :target => Multiclass)

# One sub-table per class, used later for the scatter overlays.
gdf = groupby(iris, :target)

# Features are columns 3 and 4; the target is the binary flag built above.
X = iris[:, 3:4]
y = iris[:, :target]

cats = levels(y)

"""
    boundary_data(df; n=nums)

Build a regular `n × n` evaluation grid spanning the value range of the first two
columns of `df`.

# Arguments
- `df`: a table supporting `df[:, 1]` and `df[:, 2]`; only these two columns are read.

# Keywords
- `n`: number of grid points along each axis (defaults to the global `nums`).

# Returns
A tuple `(tx, ty, x_test)` where `tx` and `ty` are the `LinRange` axes and `x_test`
is an `MLJ.table` with `n^2` rows and two columns holding every `(x, y)` grid point.
Rows are ordered with `x` varying fastest, so a length-`n^2` vector of per-row values
reshaped to `(n, n)` has entry `[i, j]` corresponding to `(tx[i], ty[j])`.
"""
function boundary_data(df; n=nums)
    xlow, xhigh = extrema(df[:, 1])
    ylow, yhigh = extrema(df[:, 2])
    tx = LinRange(xlow, xhigh, n)
    ty = LinRange(ylow, yhigh, n)
    # Build both coordinate columns in one shot instead of hcat-ing one grid point at
    # a time (which re-copies the growing matrix on every step). `outer` cycles tx,
    # `inner` holds each ty value for a full sweep of tx: x varies fastest.
    grid = hcat(repeat(tx; outer=n), repeat(ty; inner=n))
    x_test = MLJ.table(grid)
    return tx, ty, x_test
end
# Evaluation grid over the two feature columns of X.
tx, ty, x_test = boundary_data(X)


using MLJ
LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels
# Alternative toy data set: X, y = make_blobs(centers = 2)
mach = fit!(machine(LogisticClassifier(), X, y))
# Inspection calls; their values are not used further down.
predict(mach, X)
fitted_params(mach)

# Probabilistic prediction (one distribution per grid point).
probs = Array(predict(mach, x_test))

# Probability of class 1.0 ("virginica") at each grid point, rounded to two digits.
# Reshape by the actual axis lengths rather than the global `nums`, so the matrix
# stays consistent with `tx`/`ty` even if the grid was built with a different `n`.
probs_res = reshape(round.(pdf.(probs, 1.0); digits=2), length(tx), length(ty))
[ Info: For silent loading, specify `verbosity=0`. 
[ Info: Training machine(LogisticClassifier(lambda = 2.220446049250313e-16, …), …).
┌ Info: Solver: MLJLinearModels.LBFGS{Optim.Options{Float64, Nothing}, NamedTuple{(), Tuple{}}}
│   optim_options: Optim.Options{Float64, Nothing}
└   lbfgs_options: NamedTuple{(), Tuple{}} NamedTuple()
import MLJLinearModels ✔
100×100 Matrix{Float64}:
 0.0   0.0   0.0   0.0   0.0   0.0   …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0   0.0   0.0   0.0   0.0   0.0      0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0   0.0   0.0   0.0   0.0   0.0      0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0   0.0   0.0   0.0   0.0   0.0      0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0   0.0   0.0   0.0   0.0   0.0      0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0   0.0   0.0   0.0   0.0   0.0   …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0   0.0   0.0   0.0   0.0   0.0      0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0   0.0   0.0   0.0   0.0   0.0      0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0   0.0   0.0   0.0   0.0   0.0      0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0   0.0   0.0   0.0   0.0   0.0      0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0   0.0   0.0   0.0   0.0   0.0   …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0   0.0   0.0   0.0   0.0   0.0      0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0   0.0   0.0   0.0   0.0   0.0      0.0  0.0  0.0  0.0  0.0  0.0  0.0
 ⋮                             ⋮     ⋱            ⋮                   
 0.0   0.0   0.0   0.0   0.0   0.0      1.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.0   0.0   0.0   0.0   0.0   0.0      1.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.0   0.0   0.0   0.0   0.0   0.0   …  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.0   0.0   0.0   0.0   0.0   0.0      1.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.0   0.0   0.0   0.0   0.0   0.0      1.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.0   0.0   0.0   0.0   0.0   0.0      1.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.0   0.0   0.0   0.0   0.01  0.01     1.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.0   0.0   0.0   0.01  0.01  0.01  …  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.0   0.0   0.01  0.01  0.01  0.01     1.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.01  0.01  0.01  0.01  0.01  0.02     1.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.01  0.01  0.01  0.02  0.02  0.03     1.0  1.0  1.0  1.0  1.0  1.0  1.0
 0.01  0.01  0.02  0.02  0.03  0.04     1.0  1.0  1.0  1.0  1.0  1.0  1.0

plot results

Code
    # Shared figure that both plotting functions below draw into (one panel each).
    # NOTE(review): newer Makie releases appear to rename `resolution` to `size` —
    # confirm against the installed GLMakie version before changing.
    fig=Figure(resolution=(1600,800))

    """
        plot_res_contour()

    Draw labelled contour lines of `probs_res` over the `tx` × `ty` grid into panel
    `fig[1, 1]`, overlay the samples of the two groups in `gdf`, add a legend, and
    return the global `fig`.

    Reads the globals `fig`, `tx`, `ty`, `probs_res` and `gdf`. Columns 3 and 4 of
    each group are plotted on the petal-length / petal-width axes; column 5 of the
    group's first row decides the legend label.
    To persist the result, pass the returned figure to `save(path, fig)`.
    """
    function plot_res_contour()
        axis = Axis(
            fig[1, 1];
            xlabel="petal-length",
            ylabel="petal-width",
            title="2d2class-contour",
        )
        contour!(axis, tx, ty, probs_res; labels=true)

        for (group, colour) in zip((gdf[1], gdf[2]), (:red, :blue))
            legend_label = group[1, 5] == 1 ? "virginica" : "non-virginica"
            scatter!(
                axis, group[:, 3], group[:, 4];
                color=(colour, 0.8),
                marker=:circle,
                markersize=10,
                strokewidth=1,
                strokecolor=:black,
                label=legend_label,
            )
        end

        axislegend(axis; position=:lt)
        return fig
    end

    """
        plot_res_contourf()

    Draw a filled contour plot (6 levels, semi-transparent `:heat` colormap) of
    `probs_res` over the `tx` × `ty` grid into panel `fig[1, 2]`, overlay the samples
    of the two groups in `gdf`, add a legend, and return the global `fig`.

    Reads the globals `fig`, `tx`, `ty`, `probs_res` and `gdf`. Columns 3 and 4 of
    each group are plotted on the petal-length / petal-width axes; column 5 of the
    group's first row decides the legend label.
    To persist the result, pass the returned figure to `save(path, fig)`.
    """
    function plot_res_contourf()
        axis = Axis(
            fig[1, 2];
            xlabel="petal-length",
            ylabel="petal-width",
            title="2d2class-contourf",
        )
        contourf!(axis, tx, ty, probs_res; levels=6, colormap=(:heat, 0.5))

        for (idx, colour) in enumerate((:red, :blue))
            group = gdf[idx]
            legend_label = group[1, 5] == 1.0 ? "virginica" : "non-virginica"
            scatter!(
                axis, group[:, 3], group[:, 4];
                color=(colour, 0.8),
                marker=:circle,
                markersize=10,
                strokewidth=1,
                strokecolor=:black,
                label=legend_label,
            )
        end

        axislegend(axis; position=:lt)
        return fig
    end

    # Fill the right-hand panel (fig[1, 2]) first, then the left-hand panel (fig[1, 1]).
    # Both functions return `fig`, so the last call yields the complete two-panel figure.
    plot_res_contourf()
    plot_res_contour()