
Improve docs for EntityEmbedder wrapper #303


Open
ablaom opened this issue Apr 7, 2025 · 4 comments · May be fixed by #304
Labels
documentation Improvements or additions to documentation

Comments

@ablaom
Collaborator

ablaom commented Apr 7, 2025

There is no reference to this wrapper at all in the docs.

Also, the docstring is confusing. What is the model that is bound to the machine? (It is the output of `model = EntityEmbedder(model=atomic_model)`.)
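
To spell that out (an illustrative sketch only; `atomic_model` stands for any supported supervised MLJFlux model, and `X`, `y` for suitable data):

    model = EntityEmbedder(model=atomic_model)  # `model` is the wrapper, not `atomic_model`
    mach = machine(model, X, y)                 # it is this wrapper that gets bound to the machine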

ablaom added the documentation label Apr 7, 2025
@ablaom
Collaborator Author

ablaom commented Apr 7, 2025

By the way, my understanding is that this wrapper turns the atomic supervised model into an unsupervised one (but with a training target), so that it can be used in a pipeline and be interpreted as a transformer rather than a supervised model. Perhaps the example in the docstring could explicitly give a pipeline example.
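
Something like this is what I have in mind (a sketch only; `SomeClassifier` is a hypothetical downstream supervised model that requires continuous features):

    emb = EntityEmbedder(model=atomic_model)  # behaves as a transformer
    pipe = emb |> SomeClassifier()            # categorical features are embedded before `SomeClassifier` sees them
    mach = machine(pipe, X, y)
    fit!(mach)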

ablaom linked a pull request Apr 8, 2025 that will close this issue
@EssamWisam
Collaborator

EssamWisam commented May 11, 2025

There is no reference to this wrapper at all in the docs.


We do mention that built-in support exists for the neural models. We likely chose not to expose it here because it may be more reachable if it's exposed with its friends in MLJTransforms (I'm already working on the MLJTransforms docs). If you like, we can add a link here to the MLJTransforms documentation page for the `EntityEmbedder` wrapper (or vice versa)?

I have also just found that I created https://fluxml.ai/MLJFlux.jl/dev/common_workflows/entity_embeddings/notebook/, which is in the common workflows section.

@EssamWisam
Collaborator

EssamWisam commented May 11, 2025

Also, the doc string is confusing. What is the model that is bound to the machine? (It is the output of model = EntityEmbedder(model=atomic_model).)

@ablaom I thought I could check with you quickly here before rushing into a PR. I have added or modified the lines surrounded with `++++`:

    EntityEmbedder(; model=mljflux_neural_model)

`EntityEmbedder` implements entity embeddings as in the "Entity Embeddings of Categorical Variables" paper by Cheng Guo and Felix Berkhahn.

# Training data

In MLJ (or MLJBase), bind an instance of this unsupervised model to data with
++++
    mach = machine(embed_model, X, y)
++++

Here:
++++
- `embed_model` is an instance of `EntityEmbedder`, which wraps a supervised MLJFlux model.
  The supervised model must be one of: `MLJFlux.NeuralNetworkClassifier`, `MLJFlux.NeuralNetworkBinaryClassifier`,
  `MLJFlux.NeuralNetworkRegressor`, `MLJFlux.MultitargetNeuralNetworkRegressor`.
++++
- `X` is any table of input features supported by the model being wrapped. Features to be transformed must
   have element scitype `Multiclass` or `OrderedFactor`. Use `schema(X)` to 
   check scitypes. 

- `y` is the target, which can be any `AbstractVector` supported by the model being wrapped.

And here is an example I worked on to showcase the use of a pipeline:

    using MLJ
    using CategoricalArrays
    import Pkg; Pkg.add("MLJLIBSVMInterface")       # For SVC

    # Setup some data
    N = 200
    X = (;
        Column1 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
        Column2 = categorical(repeat(['a', 'b', 'c', 'd', 'e'], Int(N / 5))),
        Column3 = categorical(repeat(["b", "c", "d", "f", "f"], Int(N / 5)), ordered = true),
        Column4 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
        Column5 = randn(Float32, N),
        Column6 = categorical(
            repeat(["group1", "group1", "group2", "group2", "group3"], Int(N / 5)),
        ),
    )
    y = categorical(repeat(["class1", "class2", "class3", "class4", "class5"], Int(N / 5)))

    # Load the entity embedder, its neural network backbone, and the SVC, which
    # inherently supports only continuous features
    EntityEmbedder = @load EntityEmbedder pkg=MLJFlux
    NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux
    SVC = @load SVC pkg=LIBSVM

    emb = EntityEmbedder(NeuralNetworkClassifier(embedding_dims=Dict(:Column2 => 2, :Column3 => 2)))
    clf = SVC(cost = 1.0)

    pipeline = emb |> clf

    # Construct machine
    mach = machine(pipeline, X, y)

    # Train model
    fit!(mach)

    # Predict
    yhat = predict(mach, X)

    # Transform data using the fitted embedder to encode categorical columns
    machy = machine(emb, X, y)
    fit!(machy)
    Xnew = transform(machy, X)
    Xnew
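
To sanity-check the transform, one can inspect the scitypes of the output (assuming, as I believe is the case, that the wrapper replaces each categorical column with continuous embedding columns):

    schema(Xnew)  # embedded columns should now have `Continuous` element scitype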

I wonder whether it's optimal to have this in place of the previous example, which I found better for being more low-level and intuitive, assuming no advanced knowledge of MLJ. What's your view? If we don't include the pipeline usage here, I'd be happy to turn this into a tutorial using a real dataset and showing a plot (which would live in the MLJTransforms docs).

@ablaom
Collaborator Author

ablaom commented May 12, 2025

@EssamWisam This basically looks like the right path. I have a few suggestions, but they'll be easier to make in a PR.

Re my complaints about documentation, I only meant that I could not find any reference to the wrapper, `EntityEmbedder`. I think all we need is to interpolate the (revised) docstring somewhere in the MLJFlux docs. I think it might be confusing to omit it here, as all the other models defined by MLJFlux also appear in the docs. Of course, we can cross-reference from MLJTransforms and/or MLJ. Another argument for housing the docs here is that effective use of the wrapper really requires familiarity with the Flux API, which the other encoders don't require.
