Leveraging Semi-Supervised Concept-based Models with CME
CME builds on an observation similar to one highlighted in , where it was shown that vanilla CNN models often retain a large amount of concept-related information in their hidden space, which can be mined for concept information at no extra annotation cost. Importantly, that work considered the scenario where the underlying concepts are unknown and had to be extracted from a model's hidden space in an unsupervised fashion.
With CME, we make use of the above observation, but consider a scenario where the underlying concepts are known and only a small number of annotated samples is available for each of them. Similarly to , CME uses a given pre-trained vanilla CNN and this small set of concept annotations to extract further concept annotations in a semi-supervised fashion, as shown below:
As shown above, CME extracts the concept representation using a pre-trained model’s hidden space in a post-hoc fashion. Further details are given below.
Concept Encoder Training: instead of training concept encoders from scratch on the raw data, as is done in CBMs, we set up concept encoder training in a semi-supervised fashion, using the vanilla CNN's hidden space:
- We begin by pre-specifying a set of layers L from the vanilla CNN to use for concept extraction. This can range from all layers, to just the last few, depending on available compute capacity.
- Next, for each concept, we train a separate model on top of each layer in L, predicting that concept's values from the layer's hidden space.
- We then select the model (and corresponding layer) with the highest accuracy as the "best" model and layer for predicting that concept.
- At inference time, to predict concept i, we first retrieve the hidden space representation of the best layer for that concept, and then pass it through the corresponding predictive model.
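The layer-selection steps above can be sketched in a few lines of numpy. This is a minimal illustration, not CME's actual implementation: the function names are hypothetical, each gᵢ is assumed to be a least-squares linear model thresholded at 0.5, concepts are assumed binary, and "best accuracy" is assumed to be measured on a held-out validation split. Synthetic activations stand in for a real CNN's hidden layers.

```python
import numpy as np

def fit_linear_probe(H, y):
    """Hypothetical g_i: least-squares linear model from hidden features H (n, d) to a binary concept y (n,)."""
    X = np.hstack([H, np.ones((H.shape[0], 1))])  # append a bias column
    w, *_ = np.linalg.lstsq(X, y.astype(float), rcond=None)
    return w

def predict_linear_probe(w, H):
    X = np.hstack([H, np.ones((H.shape[0], 1))])
    return (X @ w >= 0.5).astype(int)  # threshold the regression output at 0.5

def select_best_layers(train_acts, val_acts, train_concepts, val_concepts):
    """For each concept, fit one probe per layer and keep the layer with the best validation accuracy.

    train_acts / val_acts: dict layer_name -> (n, d_layer) hidden activations (assumed precomputed)
    train_concepts / val_concepts: (n, k) binary concept labels
    Returns: dict concept_index -> (best_layer, probe_weights, val_accuracy)
    """
    best = {}
    for i in range(train_concepts.shape[1]):
        for layer in train_acts:
            w = fit_linear_probe(train_acts[layer], train_concepts[:, i])
            acc = np.mean(predict_linear_probe(w, val_acts[layer]) == val_concepts[:, i])
            if i not in best or acc > best[i][2]:
                best[i] = (layer, w, acc)
    return best

# Synthetic stand-in for two CNN layers: "layer_a" encodes concept 0, "layer_b" encodes concept 1.
rng = np.random.default_rng(0)
n = 200
concepts = rng.integers(0, 2, size=(n, 2))
acts = {
    "layer_a": np.hstack([concepts[:, :1] * 3.0 + rng.normal(0, 0.3, (n, 1)), rng.normal(size=(n, 4))]),
    "layer_b": np.hstack([concepts[:, 1:] * 3.0 + rng.normal(0, 0.3, (n, 1)), rng.normal(size=(n, 4))]),
}
train, val = slice(0, 100), slice(100, 200)
best = select_best_layers(
    {l: a[train] for l, a in acts.items()},
    {l: a[val] for l, a in acts.items()},
    concepts[train],
    concepts[val],
)
print({i: (layer, round(acc, 2)) for i, (layer, _, acc) in best.items()})
```

Swapping the linear probe for another simple model (e.g. a gradient-boosted classifier) only changes the two probe functions; the per-concept, per-layer selection loop stays the same.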
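Overall, the concept encoder function can be summarised as follows (assuming there are k concepts in total):

$$\hat{p}(x) = \big(\hat{p}_1(x), \dots, \hat{p}_k(x)\big), \qquad \hat{p}_i(x) = g_i^{l^j}\big(f^{l^j}(x)\big)$$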
- Here, p̂ (p-hat) on the left-hand side represents the concept encoder function.
- The gᵢ terms represent the hidden-space-to-concept models trained on top of the different layers' hidden spaces, with i denoting the concept index, ranging from 1 to k. In practice, these models can be fairly simple, such as linear regressors or gradient-boosted classifiers.
- The f(x) terms represent the sub-models of the original vanilla CNN that extract the input's hidden representation at a particular layer.
- In both cases, the lʲ superscripts specify the "best" layer on which these two models operate.
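The composition p̂ described above can be sketched as a higher-order function. This is a hedged illustration under stated assumptions: the names are hypothetical, each sub-model fˡ is assumed to be a callable returning a (n, d) activation array, each gᵢ a callable returning one prediction per input, and the tiny two-layer "network" below is a toy stand-in for a real pre-trained CNN.

```python
import numpy as np

def make_concept_encoder(sub_models, probes, best_layers):
    """Hypothetical p_hat(x) = (g_1(f^{l}(x)), ..., g_k(f^{l}(x))), with l the best layer per concept.

    sub_models: dict layer_name -> callable f^l, inputs (n, ...) -> activations (n, d_l)
    probes: list of k callables g_i, activations (n, d_l) -> concept predictions (n,)
    best_layers: list of k layer names; best_layers[i] is the layer selected for concept i
    """
    def p_hat(x):
        cache = {}  # compute each needed layer once, even if several concepts share it
        columns = []
        for g_i, layer in zip(probes, best_layers):
            if layer not in cache:
                cache[layer] = sub_models[layer](x)
            columns.append(g_i(cache[layer]))
        return np.stack(columns, axis=1)  # shape (n, k)
    return p_hat

# Toy two-layer network with fixed weights, standing in for the vanilla CNN.
relu = lambda z: np.maximum(z, 0.0)
W1 = np.array([[1.0, -1.0], [0.5, 2.0]])
W2 = np.array([[1.0], [1.0]])
f_l1 = lambda x: relu(x @ W1)        # sub-model up to layer "l1"
f_l2 = lambda x: relu(f_l1(x) @ W2)  # sub-model up to layer "l2"

# Toy threshold probes standing in for trained g_i models.
g_0 = lambda h: (h[:, 0] > 0.75).astype(int)
g_1 = lambda h: (h[:, 0] > 1.5).astype(int)

p_hat = make_concept_encoder({"l1": f_l1, "l2": f_l2}, [g_0, g_1], ["l1", "l2"])
result = p_hat(np.array([[1.0, 0.0], [0.0, 1.0]]))
print(result)
```

Caching per layer matters because, in practice, several concepts may share the same best layer; a real implementation would likely also reuse the CNN's forward pass across layers rather than recomputing prefixes as this toy does.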
Concept Processor Training: the concept processor in CME is set up by training models with the task labels as outputs and the concept encoder's predictions as inputs. Importantly, these models operate on a much more compact input representation (k concept values rather than the raw input), and can therefore be represented directly via interpretable models, such as Decision Trees (DTs) or Logistic Regression (LR) models.
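As a rough sketch of the processor step, the snippet below fits a logistic regression by plain batch gradient descent on concept predictions. Everything here is an illustrative assumption: the function names, the learning rate and step count, the toy binary concept predictions, and the hypothetical task "label is 1 iff both concepts are present". In practice a library model such as scikit-learn's `DecisionTreeClassifier` or `LogisticRegression` would be a natural drop-in.

```python
import numpy as np

def train_logistic_processor(C, y, lr=0.5, steps=5000):
    """Fit h(c) = sigmoid(c @ w + b) by batch gradient descent on mean cross-entropy.

    C: (n, k) concept predictions from the concept encoder; y: (n,) binary task labels.
    """
    n, k = C.shape
    w, b = np.zeros(k), 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(C @ w + b)))  # predicted task probability
        w -= lr * (C.T @ (p - y)) / n             # gradient of the mean loss w.r.t. w
        b -= lr * np.mean(p - y)                   # gradient of the mean loss w.r.t. b
    return w, b

def predict_processor(w, b, C):
    return (C @ w + b >= 0).astype(int)  # sigmoid(z) >= 0.5 exactly when z >= 0

# Toy concept predictions (k = 2) and a hypothetical task: positive iff both concepts are present.
C = np.array([[0, 0], [0, 1], [1, 0], [1, 1]] * 25, dtype=float)
y = (C[:, 0] * C[:, 1]).astype(float)
w, b = train_logistic_processor(C, y)
acc = np.mean(predict_processor(w, b, C) == y)
print("weights:", w.round(2), "bias:", round(b, 2), "accuracy:", acc)
```

Because the inputs are just k concept values, the learned weights can be read directly: here both concepts receive positive weight and the bias is negative, i.e. neither concept alone is enough to predict the positive class.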