# BertGeneration

[BertGeneration](https://huggingface.co/papers/1907.12461) leverages pretrained BERT checkpoints for sequence-to-sequence tasks with the [EncoderDecoderModel](/docs/transformers/v4.57.0/en/model_doc/encoder-decoder#transformers.EncoderDecoderModel) architecture. It adapts BERT for generative tasks, so pretrained BERT weights can warm-start both the encoder and the decoder of a sequence-to-sequence model.

You can find all the original BERT checkpoints under the [BERT](https://huggingface.co/collections/google/bert-release-64ff5e7a4be99045d1896dbc) collection.

> [!TIP]
> This model was contributed by [patrickvonplaten](https://huggingface.co/patrickvonplaten).
>
> Click on the BertGeneration models in the right sidebar for more examples of how to apply BertGeneration to different sequence generation tasks.

The example below demonstrates how to use BertGeneration with [EncoderDecoderModel](/docs/transformers/v4.57.0/en/model_doc/encoder-decoder#transformers.EncoderDecoderModel) for sequence-to-sequence tasks.

<hfoptions id="usage">
<hfoption id="Pipeline">

```python
import torch
from transformers import pipeline

pipeline = pipeline(
    task="text2text-generation",
    model="google/roberta2roberta_L-24_discofuse",
    dtype=torch.float16,
    device=0
)
pipeline("Plants create energy through ")
```

</hfoption>
<hfoption id="AutoModel">

```python
import torch
from transformers import EncoderDecoderModel, AutoTokenizer

model = EncoderDecoderModel.from_pretrained("google/roberta2roberta_L-24_discofuse", dtype="auto")
tokenizer = AutoTokenizer.from_pretrained("google/roberta2roberta_L-24_discofuse")

input_ids = tokenizer(
    "Plants create energy through ", add_special_tokens=False, return_tensors="pt"
).input_ids

outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))
```

</hfoption>
<hfoption id="transformers CLI">

```bash
echo -e "Plants create energy through " | transformers run --task text2text-generation --model "google/roberta2roberta_L-24_discofuse" --device 0
```

</hfoption>
</hfoptions>

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends.

The example below uses [BitsAndBytesConfig](../quantization/bitsandbytes) to quantize the weights to 4-bit.

```python
import torch
from transformers import EncoderDecoderModel, AutoTokenizer, BitsAndBytesConfig

# Configure 4-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

model = EncoderDecoderModel.from_pretrained(
    "google/roberta2roberta_L-24_discofuse",
    quantization_config=quantization_config,
    dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained("google/roberta2roberta_L-24_discofuse")

input_ids = tokenizer(
    "Plants create energy through ", add_special_tokens=False, return_tensors="pt"
).input_ids

outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))
```

## Notes

- [BertGenerationEncoder](/docs/transformers/v4.57.0/en/model_doc/bert-generation#transformers.BertGenerationEncoder) and [BertGenerationDecoder](/docs/transformers/v4.57.0/en/model_doc/bert-generation#transformers.BertGenerationDecoder) should be used in combination with [EncoderDecoderModel](/docs/transformers/v4.57.0/en/model_doc/encoder-decoder#transformers.EncoderDecoderModel) for sequence-to-sequence tasks.

   ```python
   from transformers import BertGenerationEncoder, BertGenerationDecoder, BertTokenizer, EncoderDecoderModel
   
   # leverage checkpoints for Bert2Bert model
   # use BERT's cls token as BOS token and sep token as EOS token
   encoder = BertGenerationEncoder.from_pretrained("google-bert/bert-large-uncased", bos_token_id=101, eos_token_id=102)
   # add cross attention layers and use BERT's cls token as BOS token and sep token as EOS token
   decoder = BertGenerationDecoder.from_pretrained(
       "google-bert/bert-large-uncased", add_cross_attention=True, is_decoder=True, bos_token_id=101, eos_token_id=102
   )
   bert2bert = EncoderDecoderModel(encoder=encoder, decoder=decoder)

   # create tokenizer
   tokenizer = BertTokenizer.from_pretrained("google-bert/bert-large-uncased")

   input_ids = tokenizer(
       "This is a long article to summarize", add_special_tokens=False, return_tensors="pt"
   ).input_ids
   labels = tokenizer("This is a short summary", return_tensors="pt").input_ids

   # train
   loss = bert2bert(input_ids=input_ids, decoder_input_ids=labels, labels=labels).loss
   loss.backward()
   ```

- For summarization, sentence splitting, sentence fusion, and translation, no special tokens are required for the input.
- No EOS token should be added to the end of the input for most generation tasks, as shown in the sketch below.
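
   The sketch below builds on the `bert2bert` model and `tokenizer` from the example above. It is only illustrative: the decoder start and pad token ids are assumptions chosen to match the BOS and pad ids used above, and the output is not meaningful until the model has been fine-tuned.

   ```python
   # minimal sketch, reusing `bert2bert` and `tokenizer` from the example above
   # the generation settings below are illustrative assumptions
   bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id  # 101, same BOS id as above
   bert2bert.config.pad_token_id = tokenizer.pad_token_id

   # no special tokens and no EOS appended to the encoder input
   input_ids = tokenizer(
       "This is a long article to summarize", add_special_tokens=False, return_tensors="pt"
   ).input_ids

   # outputs are meaningless before fine-tuning because the cross-attention weights are randomly initialized
   outputs = bert2bert.generate(input_ids, max_new_tokens=20)
   print(tokenizer.decode(outputs[0], skip_special_tokens=True))
   ```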

## BertGenerationConfig[[transformers.BertGenerationConfig]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>class transformers.BertGenerationConfig</name><anchor>transformers.BertGenerationConfig</anchor><source>https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/bert_generation/configuration_bert_generation.py#L20</source><parameters>[{"name": "vocab_size", "val": " = 50358"}, {"name": "hidden_size", "val": " = 1024"}, {"name": "num_hidden_layers", "val": " = 24"}, {"name": "num_attention_heads", "val": " = 16"}, {"name": "intermediate_size", "val": " = 4096"}, {"name": "hidden_act", "val": " = 'gelu'"}, {"name": "hidden_dropout_prob", "val": " = 0.1"}, {"name": "attention_probs_dropout_prob", "val": " = 0.1"}, {"name": "max_position_embeddings", "val": " = 512"}, {"name": "initializer_range", "val": " = 0.02"}, {"name": "layer_norm_eps", "val": " = 1e-12"}, {"name": "pad_token_id", "val": " = 0"}, {"name": "bos_token_id", "val": " = 2"}, {"name": "eos_token_id", "val": " = 1"}, {"name": "position_embedding_type", "val": " = 'absolute'"}, {"name": "use_cache", "val": " = True"}, {"name": "**kwargs", "val": ""}]</parameters><paramsdesc>- **vocab_size** (`int`, *optional*, defaults to 50358) --
  Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
  `input_ids` passed when calling a BertGeneration model.
- **hidden_size** (`int`, *optional*, defaults to 1024) --
  Dimensionality of the encoder layers and the pooler layer.
- **num_hidden_layers** (`int`, *optional*, defaults to 24) --
  Number of hidden layers in the Transformer encoder.
- **num_attention_heads** (`int`, *optional*, defaults to 16) --
  Number of attention heads for each attention layer in the Transformer encoder.
- **intermediate_size** (`int`, *optional*, defaults to 4096) --
  Dimensionality of the "intermediate" (often called feed-forward) layer in the Transformer encoder.
- **hidden_act** (`str` or `function`, *optional*, defaults to `"gelu"`) --
  The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
  `"relu"`, `"silu"` and `"gelu_new"` are supported.
- **hidden_dropout_prob** (`float`, *optional*, defaults to 0.1) --
  The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
- **attention_probs_dropout_prob** (`float`, *optional*, defaults to 0.1) --
  The dropout ratio for the attention probabilities.
- **max_position_embeddings** (`int`, *optional*, defaults to 512) --
  The maximum sequence length that this model might ever be used with. Typically set this to something large
  just in case (e.g., 512 or 1024 or 2048).
- **initializer_range** (`float`, *optional*, defaults to 0.02) --
  The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- **layer_norm_eps** (`float`, *optional*, defaults to 1e-12) --
  The epsilon used by the layer normalization layers.
- **pad_token_id** (`int`, *optional*, defaults to 0) --
  Padding token id.
- **bos_token_id** (`int`, *optional*, defaults to 2) --
  Beginning of stream token id.
- **eos_token_id** (`int`, *optional*, defaults to 1) --
  End of stream token id.
- **position_embedding_type** (`str`, *optional*, defaults to `"absolute"`) --
  Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
  positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
  [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
  For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
  with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
- **use_cache** (`bool`, *optional*, defaults to `True`) --
  Whether or not the model should return the last key/values attentions (not used by all models). Only
  relevant if `config.is_decoder=True`.</paramsdesc><paramgroups>0</paramgroups></docstring>

This is the configuration class to store the configuration of a `BertGenerationPreTrainedModel`. It is used to
instantiate a BertGeneration model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the BertGeneration
[google/bert_for_seq_generation_L-24_bbc_encoder](https://huggingface.co/google/bert_for_seq_generation_L-24_bbc_encoder)
architecture.

Configuration objects inherit from [PretrainedConfig](/docs/transformers/v4.57.0/en/main_classes/configuration#transformers.PretrainedConfig) and can be used to control the model outputs. Read the
documentation from [PretrainedConfig](/docs/transformers/v4.57.0/en/main_classes/configuration#transformers.PretrainedConfig) for more information.



<ExampleCodeBlock anchor="transformers.BertGenerationConfig.example">

Examples:

```python
>>> from transformers import BertGenerationConfig, BertGenerationEncoder

>>> # Initializing a BertGeneration config
>>> configuration = BertGenerationConfig()

>>> # Initializing a model (with random weights) from the config
>>> model = BertGenerationEncoder(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```

</ExampleCodeBlock>

</div>

## BertGenerationTokenizer[[transformers.BertGenerationTokenizer]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>class transformers.BertGenerationTokenizer</name><anchor>transformers.BertGenerationTokenizer</anchor><source>https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/bert_generation/tokenization_bert_generation.py#L34</source><parameters>[{"name": "vocab_file", "val": ""}, {"name": "bos_token", "val": " = '<s>'"}, {"name": "eos_token", "val": " = '</s>'"}, {"name": "unk_token", "val": " = '<unk>'"}, {"name": "pad_token", "val": " = '<pad>'"}, {"name": "sep_token", "val": " = '<::::>'"}, {"name": "sp_model_kwargs", "val": ": typing.Optional[dict[str, typing.Any]] = None"}, {"name": "**kwargs", "val": ""}]</parameters><paramsdesc>- **vocab_file** (`str`) --
  [SentencePiece](https://github.com/google/sentencepiece) file (generally has a *.spm* extension) that
  contains the vocabulary necessary to instantiate a tokenizer.
- **bos_token** (`str`, *optional*, defaults to `"<s>"`) --
  The begin of sequence token.
- **eos_token** (`str`, *optional*, defaults to `"</s>"`) --
  The end of sequence token.
- **unk_token** (`str`, *optional*, defaults to `"<unk>"`) --
  The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
  token instead.
- **pad_token** (`str`, *optional*, defaults to `"<pad>"`) --
  The token used for padding, for example when batching sequences of different lengths.
- **sep_token** (`str`, *optional*, defaults to `"<::::>"`) --
  The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
  sequence classification or for a text and a question for question answering. It is also used as the last
  token of a sequence built with special tokens.
- **sp_model_kwargs** (`dict`, *optional*) --
  Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
  SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
  to set:

  - `enable_sampling`: Enable subword regularization.
  - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.

    - `nbest_size = {0,1}`: No sampling is performed.
    - `nbest_size > 1`: samples from the nbest_size results.
    - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
      using forward-filtering-and-backward-sampling algorithm.

  - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
    BPE-dropout.</paramsdesc><paramgroups>0</paramgroups></docstring>

Construct a BertGeneration tokenizer. Based on [SentencePiece](https://github.com/google/sentencepiece).

This tokenizer inherits from [PreTrainedTokenizer](/docs/transformers/v4.57.0/en/main_classes/tokenizer#transformers.PreTrainedTokenizer) which contains most of the main methods. Users should refer to
this superclass for more information regarding those methods.
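
A minimal usage sketch, assuming the [google/bert_for_seq_generation_L-24_bbc_encoder](https://huggingface.co/google/bert_for_seq_generation_L-24_bbc_encoder) checkpoint (which ships a compatible SentencePiece vocabulary) and that the `sentencepiece` package is installed:

```python
from transformers import BertGenerationTokenizer

tokenizer = BertGenerationTokenizer.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")

# encode without special tokens, as recommended for most generation tasks
ids = tokenizer("Plants create energy through photosynthesis.", add_special_tokens=False).input_ids
print(ids)
print(tokenizer.decode(ids))
```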





<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>save_vocabulary</name><anchor>transformers.BertGenerationTokenizer.save_vocabulary</anchor><source>https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/bert_generation/tokenization_bert_generation.py#L159</source><parameters>[{"name": "save_directory", "val": ": str"}, {"name": "filename_prefix", "val": ": typing.Optional[str] = None"}]</parameters></docstring>


</div></div>

## BertGenerationEncoder[[transformers.BertGenerationEncoder]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>class transformers.BertGenerationEncoder</name><anchor>transformers.BertGenerationEncoder</anchor><source>https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/bert_generation/modeling_bert_generation.py#L592</source><parameters>[{"name": "config", "val": ""}]</parameters><paramsdesc>- **config** ([BertGenerationConfig](/docs/transformers/v4.57.0/en/model_doc/bert-generation#transformers.BertGenerationConfig)) --
  Model configuration class with all the parameters of the model. Initializing with a config file does not
  load the weights associated with the model, only the configuration. Check out the
  [from_pretrained()](/docs/transformers/v4.57.0/en/main_classes/model#transformers.PreTrainedModel.from_pretrained) method to load the model weights.</paramsdesc><paramgroups>0</paramgroups></docstring>

The bare BertGeneration model transformer outputting raw hidden-states without any specific head on top.

This model inherits from [PreTrainedModel](/docs/transformers/v4.57.0/en/main_classes/model#transformers.PreTrainedModel). Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)

This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.





<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>forward</name><anchor>transformers.BertGenerationEncoder.forward</anchor><source>https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/bert_generation/modeling_bert_generation.py#L633</source><parameters>[{"name": "input_ids", "val": ": typing.Optional[torch.Tensor] = None"}, {"name": "attention_mask", "val": ": typing.Optional[torch.Tensor] = None"}, {"name": "position_ids", "val": ": typing.Optional[torch.Tensor] = None"}, {"name": "head_mask", "val": ": typing.Optional[torch.Tensor] = None"}, {"name": "inputs_embeds", "val": ": typing.Optional[torch.Tensor] = None"}, {"name": "encoder_hidden_states", "val": ": typing.Optional[torch.Tensor] = None"}, {"name": "encoder_attention_mask", "val": ": typing.Optional[torch.Tensor] = None"}, {"name": "past_key_values", "val": ": typing.Optional[transformers.cache_utils.Cache] = None"}, {"name": "use_cache", "val": ": typing.Optional[bool] = None"}, {"name": "output_attentions", "val": ": typing.Optional[bool] = None"}, {"name": "output_hidden_states", "val": ": typing.Optional[bool] = None"}, {"name": "return_dict", "val": ": typing.Optional[bool] = None"}, {"name": "**kwargs", "val": ""}]</parameters><paramsdesc>- **input_ids** (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) --
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.

  Indices can be obtained using [AutoTokenizer](/docs/transformers/v4.57.0/en/model_doc/auto#transformers.AutoTokenizer). See [PreTrainedTokenizer.encode()](/docs/transformers/v4.57.0/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.encode) and
  [PreTrainedTokenizer.__call__()](/docs/transformers/v4.57.0/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.__call__) for details.

  [What are input IDs?](../glossary#input-ids)
- **attention_mask** (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) --
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

  - 1 for tokens that are **not masked**,
  - 0 for tokens that are **masked**.

  [What are attention masks?](../glossary#attention-mask)
- **position_ids** (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) --
  Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`.

  [What are position IDs?](../glossary#position-ids)
- **head_mask** (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*) --
  Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

  - 1 indicates the head is **not masked**,
  - 0 indicates the head is **masked**.
- **inputs_embeds** (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) --
  Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  model's internal embedding lookup matrix.
- **encoder_hidden_states** (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) --
  Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  if the model is configured as a decoder.
- **encoder_attention_mask** (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) --
  Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
  the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

  - 1 for tokens that are **not masked**,
  - 0 for tokens that are **masked**.
- **past_key_values** (`~cache_utils.Cache`, *optional*) --
  Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
  blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
  returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

  Only [Cache](/docs/transformers/v4.57.0/en/internal/generation_utils#transformers.Cache) instance is allowed as input, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  If no `past_key_values` are passed, [DynamicCache](/docs/transformers/v4.57.0/en/internal/generation_utils#transformers.DynamicCache) will be initialized by default.

  The model will output the same cache format that is fed as input.

  If `past_key_values` are used, the user is expected to input only unprocessed `input_ids` (those that don't
  have their past key value states given to this model) of shape `(batch_size, unprocessed_length)` instead of all `input_ids`
  of shape `(batch_size, sequence_length)`.
- **use_cache** (`bool`, *optional*) --
  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  `past_key_values`).
- **output_attentions** (`bool`, *optional*) --
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  tensors for more detail.
- **output_hidden_states** (`bool`, *optional*) --
  Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  more detail.
- **return_dict** (`bool`, *optional*) --
  Whether or not to return a [ModelOutput](/docs/transformers/v4.57.0/en/main_classes/output#transformers.utils.ModelOutput) instead of a plain tuple.</paramsdesc><paramgroups>0</paramgroups><rettype>[transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions](/docs/transformers/v4.57.0/en/main_classes/output#transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions) or `tuple(torch.FloatTensor)`</rettype><retdesc>A [transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions](/docs/transformers/v4.57.0/en/main_classes/output#transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions) or a tuple of
`torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various
elements depending on the configuration ([BertGenerationConfig](/docs/transformers/v4.57.0/en/model_doc/bert-generation#transformers.BertGenerationConfig)) and inputs.

- **last_hidden_state** (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`) -- Sequence of hidden-states at the output of the last layer of the model.

  If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,
  hidden_size)` is output.
- **past_key_values** (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`) -- It is a [Cache](/docs/transformers/v4.57.0/en/internal/generation_utils#transformers.Cache) instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

  Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if
  `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`
  input) to speed up sequential decoding.
- **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`) -- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
  one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

  Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) -- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  sequence_length)`.

  Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  heads.
- **cross_attentions** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`) -- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  sequence_length)`.

  Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
  weighted average in the cross-attention heads.</retdesc></docstring>
The [BertGenerationEncoder](/docs/transformers/v4.57.0/en/model_doc/bert-generation#transformers.BertGenerationEncoder) forward method, overrides the `__call__` special method.

<Tip>

Although the recipe for forward pass needs to be defined within this function, one should call the `Module`
instance afterwards instead of this since the former takes care of running the pre and post processing steps while
the latter silently ignores them.

</Tip>
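
Example (a minimal sketch, assuming the `google/bert_for_seq_generation_L-24_bbc_encoder` checkpoint):

```python
import torch
from transformers import AutoTokenizer, BertGenerationEncoder

tokenizer = AutoTokenizer.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
model = BertGenerationEncoder.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")

inputs = tokenizer(
    "Plants create energy through photosynthesis.", return_token_type_ids=False, return_tensors="pt"
)
with torch.no_grad():
    outputs = model(**inputs)

# last_hidden_state has shape (batch_size, sequence_length, hidden_size)
print(outputs.last_hidden_state.shape)
```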








</div></div>

## BertGenerationDecoder[[transformers.BertGenerationDecoder]]

<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>class transformers.BertGenerationDecoder</name><anchor>transformers.BertGenerationDecoder</anchor><source>https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/bert_generation/modeling_bert_generation.py#L765</source><parameters>[{"name": "config", "val": ""}]</parameters><paramsdesc>- **config** ([BertGenerationConfig](/docs/transformers/v4.57.0/en/model_doc/bert-generation#transformers.BertGenerationConfig)) --
  Model configuration class with all the parameters of the model. Initializing with a config file does not
  load the weights associated with the model, only the configuration. Check out the
  [from_pretrained()](/docs/transformers/v4.57.0/en/main_classes/model#transformers.PreTrainedModel.from_pretrained) method to load the model weights.</paramsdesc><paramgroups>0</paramgroups></docstring>

BertGeneration Model with a `language modeling` head on top for CLM fine-tuning.

This model inherits from [PreTrainedModel](/docs/transformers/v4.57.0/en/main_classes/model#transformers.PreTrainedModel). Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)

This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.





<div class="docstring border-l-2 border-t-2 pl-4 pt-3.5 border-gray-100 rounded-tl-xl mb-6 mt-8">


<docstring><name>forward</name><anchor>transformers.BertGenerationDecoder.forward</anchor><source>https://github.com/huggingface/transformers/blob/v4.57.0/src/transformers/models/bert_generation/modeling_bert_generation.py#L787</source><parameters>[{"name": "input_ids", "val": ": typing.Optional[torch.Tensor] = None"}, {"name": "attention_mask", "val": ": typing.Optional[torch.Tensor] = None"}, {"name": "position_ids", "val": ": typing.Optional[torch.Tensor] = None"}, {"name": "head_mask", "val": ": typing.Optional[torch.Tensor] = None"}, {"name": "inputs_embeds", "val": ": typing.Optional[torch.Tensor] = None"}, {"name": "encoder_hidden_states", "val": ": typing.Optional[torch.Tensor] = None"}, {"name": "encoder_attention_mask", "val": ": typing.Optional[torch.Tensor] = None"}, {"name": "labels", "val": ": typing.Optional[torch.Tensor] = None"}, {"name": "past_key_values", "val": ": typing.Optional[transformers.cache_utils.Cache] = None"}, {"name": "use_cache", "val": ": typing.Optional[bool] = None"}, {"name": "output_attentions", "val": ": typing.Optional[bool] = None"}, {"name": "output_hidden_states", "val": ": typing.Optional[bool] = None"}, {"name": "return_dict", "val": ": typing.Optional[bool] = None"}, {"name": "**kwargs", "val": ""}]</parameters><paramsdesc>- **input_ids** (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) --
  Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.

  Indices can be obtained using [AutoTokenizer](/docs/transformers/v4.57.0/en/model_doc/auto#transformers.AutoTokenizer). See [PreTrainedTokenizer.encode()](/docs/transformers/v4.57.0/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.encode) and
  [PreTrainedTokenizer.__call__()](/docs/transformers/v4.57.0/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.__call__) for details.

  [What are input IDs?](../glossary#input-ids)
- **attention_mask** (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) --
  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

  - 1 for tokens that are **not masked**,
  - 0 for tokens that are **masked**.

  [What are attention masks?](../glossary#attention-mask)
- **position_ids** (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) --
  Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`.

  [What are position IDs?](../glossary#position-ids)
- **head_mask** (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*) --
  Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

  - 1 indicates the head is **not masked**,
  - 0 indicates the head is **masked**.
- **inputs_embeds** (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) --
  Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  model's internal embedding lookup matrix.
- **encoder_hidden_states** (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) --
  Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
  if the model is configured as a decoder.
- **encoder_attention_mask** (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) --
  Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
  the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

  - 1 for tokens that are **not masked**,
  - 0 for tokens that are **masked**.
- **labels** (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*) --
  Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
  ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
- **past_key_values** (`~cache_utils.Cache`, *optional*) --
  Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
  blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
  returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

  Only [Cache](/docs/transformers/v4.57.0/en/internal/generation_utils#transformers.Cache) instance is allowed as input, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  If no `past_key_values` are passed, [DynamicCache](/docs/transformers/v4.57.0/en/internal/generation_utils#transformers.DynamicCache) will be initialized by default.

  The model will output the same cache format that is fed as input.

  If `past_key_values` are used, the user is expected to input only unprocessed `input_ids` (those that don't
  have their past key value states given to this model) of shape `(batch_size, unprocessed_length)` instead of all `input_ids`
  of shape `(batch_size, sequence_length)`.
- **use_cache** (`bool`, *optional*) --
  If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
  `past_key_values`).
- **output_attentions** (`bool`, *optional*) --
  Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
  tensors for more detail.
- **output_hidden_states** (`bool`, *optional*) --
  Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
  more detail.
- **return_dict** (`bool`, *optional*) --
  Whether or not to return a [ModelOutput](/docs/transformers/v4.57.0/en/main_classes/output#transformers.utils.ModelOutput) instead of a plain tuple.</paramsdesc><paramgroups>0</paramgroups><rettype>[transformers.modeling_outputs.CausalLMOutputWithCrossAttentions](/docs/transformers/v4.57.0/en/main_classes/output#transformers.modeling_outputs.CausalLMOutputWithCrossAttentions) or `tuple(torch.FloatTensor)`</rettype><retdesc>A [transformers.modeling_outputs.CausalLMOutputWithCrossAttentions](/docs/transformers/v4.57.0/en/main_classes/output#transformers.modeling_outputs.CausalLMOutputWithCrossAttentions) or a tuple of
`torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various
elements depending on the configuration ([BertGenerationConfig](/docs/transformers/v4.57.0/en/model_doc/bert-generation#transformers.BertGenerationConfig)) and inputs.

- **loss** (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) -- Language modeling loss (for next-token prediction).
- **logits** (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`) -- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
- **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`) -- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
  one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

  Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
- **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) -- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  sequence_length)`.

  Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
  heads.
- **cross_attentions** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) -- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  sequence_length)`.

  Cross attentions weights after the attention softmax, used to compute the weighted average in the
  cross-attention heads.
- **past_key_values** (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`) -- It is a [Cache](/docs/transformers/v4.57.0/en/internal/generation_utils#transformers.Cache) instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

  Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
  `past_key_values` input) to speed up sequential decoding.</retdesc></docstring>
The [BertGenerationDecoder](/docs/transformers/v4.57.0/en/model_doc/bert-generation#transformers.BertGenerationDecoder) forward method, overrides the `__call__` special method.

<Tip>

Although the recipe for forward pass needs to be defined within this function, one should call the `Module`
instance afterwards instead of this since the former takes care of running the pre and post processing steps while
the latter silently ignores them.

</Tip>







<ExampleCodeBlock anchor="transformers.BertGenerationDecoder.forward.example">

Example:

```python
>>> from transformers import AutoTokenizer, BertGenerationDecoder, BertGenerationConfig
>>> import torch

>>> tokenizer = AutoTokenizer.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
>>> config = BertGenerationConfig.from_pretrained("google/bert_for_seq_generation_L-24_bbc_encoder")
>>> config.is_decoder = True
>>> model = BertGenerationDecoder.from_pretrained(
...     "google/bert_for_seq_generation_L-24_bbc_encoder", config=config
... )

>>> inputs = tokenizer("Hello, my dog is cute", return_token_type_ids=False, return_tensors="pt")
>>> outputs = model(**inputs)

>>> prediction_logits = outputs.logits
```

</ExampleCodeBlock>

</div></div>

<EditOnGithub source="https://github.com/huggingface/transformers/blob/main/docs/source/en/model_doc/bert-generation.md" />