
Pruning: Keras subclassed model increased support  #155

@alanchiao


Currently, the pruning API throws an error when a subclassed model is passed to it. Users can work around this by going into the subclassed model and applying pruning to its individual Sequential/Functional models and tf.keras layers.

Better support is important for various use cases (e.g. the Object Detection and BERT examples) and for issues such as this one.

We can provide better support for pruning an entire subclassed model.

  • Pruning only some layers of the model would still require going into the model definition itself, though it would now be possible to prune a whole subclassed model nested inside another subclassed model.
  • This would only prune variables that live inside a tf.keras.Layer (either a built-in layer or a custom layer implementing the PrunableLayer interface).

Implementation-wise, we can iterate through the layers of a subclassed model (and any nested models) and apply pruning to all of them. Replacing a layer in an already-created model will be tricky, and we would have to do it without clone_model.
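The traversal can be sketched in plain Python. This is a minimal sketch under loud assumptions: every class and function in it (Layer, Dense, Activation, CustomAttention, Model, PruneWrapper, is_prunable, prune_in_place) is a hypothetical stand-in, not part of tf.keras or tfmot, and it assumes a subclassed model keeps its layers as instance attributes or Python lists of layers. In a real implementation the wrapper would come from tfmot's prune_low_magnitude, and Keras attribute tracking may need more care than a plain setattr.

```python
from typing import Any, List


class Layer:
    """Hypothetical stand-in for tf.keras.layers.Layer (not the real class)."""

    def __init__(self, name: str) -> None:
        self.name = name


class Dense(Layer):
    """Stand-in for a built-in layer that pruning already supports."""


class Activation(Layer):
    """Stand-in for a layer with no weights worth pruning."""


class CustomAttention(Layer):
    """Stand-in for a custom layer following the PrunableLayer idea."""

    def get_prunable_weights(self) -> List[str]:
        return ["kernel"]


class Model(Layer):
    """Stand-in for a subclassed model: layers are plain attributes."""


class PruneWrapper(Layer):
    """Stand-in for the wrapper prune_low_magnitude would put around a layer."""

    def __init__(self, layer: Layer) -> None:
        super().__init__(f"prune_{layer.name}")
        self.layer = layer


def is_prunable(layer: Layer) -> bool:
    # Assumed rule: supported built-in layers, or custom layers that expose
    # get_prunable_weights(). Wrappers and other layers are left untouched.
    return isinstance(layer, Dense) or hasattr(layer, "get_prunable_weights")


def prune_in_place(model: Model, path: str = "model") -> List[str]:
    """Wrap every prunable layer reachable from `model` without cloning it.

    Returns the attribute paths of the layers that were wrapped.
    """
    wrapped: List[str] = []

    def visit(value: Any, where: str) -> Any:
        if isinstance(value, Model):
            # Nested subclassed model: recurse and mutate it in place.
            wrapped.extend(prune_in_place(value, where))
            return value
        if isinstance(value, Layer) and is_prunable(value):
            wrapped.append(where)
            return PruneWrapper(value)
        if isinstance(value, list):
            # Layers are often kept in Python lists (e.g. a stack of blocks).
            return [visit(item, f"{where}[{i}]") for i, item in enumerate(value)]
        # Anything else, e.g. a variable created directly on the model, is skipped.
        return value

    # Rebinding each attribute swaps the layer inside the already-created object.
    for attr, value in list(vars(model).items()):
        setattr(model, attr, visit(value, f"{path}.{attr}"))
    return wrapped


class Encoder(Model):
    def __init__(self) -> None:
        super().__init__("encoder")
        self.blocks = [Dense("d0"), CustomAttention("attn")]


class Detector(Model):
    def __init__(self) -> None:
        super().__init__("detector")
        self.encoder = Encoder()  # subclassed model nested in a subclassed model
        self.head = Dense("head")
        self.act = Activation("relu")
        self.scale = 0.5  # stand-in for a variable that lives outside any layer


model = Detector()
print(prune_in_place(model))
# ['model.encoder.blocks[0]', 'model.encoder.blocks[1]', 'model.head']
```

Rebinding attributes is one way to swap layers without clone_model. Note that in this sketch a layer that is never stored as an attribute (for example, one created inside call()) would not be reached, and running it a second time wraps nothing because wrappers are not considered prunable.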
