Exploring the Deep Fusion of Large Language Models and Diffusion Transformers for Text-to-Image Synthesis
Create a virtual environment (Python~=3.10):
conda create -n fuse-dit python=3.10
conda activate fuse-dit
Clone the repository:
git clone https://github.com/tang-bd/fuse-dit.git
cd fuse-dit
Install the dependencies:
pip install -r requirements.txt
For TPU devices, additionally install PyTorch/XLA (~=2.5.0):
pip install torch~=2.5.0 'torch_xla[tpu]~=2.5.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
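To verify the PyTorch/XLA installation, a quick check along these lines (not part of this repository) should print an XLA device:
# Quick sanity check that PyTorch/XLA can see a TPU/XLA device.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
x = torch.ones(2, 2, device=device)
print(device, x.sum().item())  # expect an XLA device and 4.0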
Download the CC12M dataset and synthetic captions, then merge them into WebDataset .tar format:
python utils/process_cc12m.py --dataset_path /path/to/dataset/ \
--captions_path /path/to/synthetic/captions/ --output_dir /output/dir/
Download the text file containing the SA-1B dataset links, then use the provided script to download the full dataset:
python utils/download_sa1b.py --input_file /path/to/sa1b/links \
--raw_dir /downloaded/files/dir/ --images_dir /extracted/files/dir/
Download the synthetic captions, then merge them with the full dataset into WebDataset .tar format:
python utils/process_sa1b.py --dataset_path /extracted/files/dir/ \
--captions_path /path/to/synthetic/captions/ --output_dir /output/dir/
Download the JourneyDB dataset and convert it into WebDataset .tar format:
python utils/clean_journeydb.py \
--input_path /path/to/dataset/data/train/train_anno_realease_repath.jsonl \
--output_dir /output/dir
python utils/process_journeydb.py \
--dataset_path /path/to/dataset/data/train/imgs \
--output_dir /output/dir
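Before training, you can optionally spot-check the generated .tar shards. The snippet below is a minimal sketch using the webdataset package (not part of this repository); the shard filename is a placeholder, and the per-sample keys depend on how the processing scripts wrote them.
import webdataset as wds

# Iterate over a few raw samples from one shard; without a decoder, each
# sample maps extension names (e.g. "jpg", "txt") to raw bytes.
ds = wds.WebDataset("/output/dir/00000.tar")  # placeholder shard path
for i, sample in enumerate(ds):
    print(sample["__key__"], sorted(k for k in sample if not k.startswith("__")))
    if i >= 2:
        break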
To ensure reproducibility, we recommend training on TPU devices. The research experiments were conducted using TPU v4-256 pods. To create a TPU pod:
gcloud alpha compute tpus queued-resources create <your-tpu-pod> --node-id=<your-tpu-pod> \
--zone=<your-zone> \
--project=<your-project> \
--accelerator-type=v4-256 \
--runtime-version=tpu-ubuntu2204-base \
--best-effort
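Since --best-effort provisioning can take time, you can check whether the queued resource has been provisioned before launching training (the exact gcloud surface may vary with your SDK version):
gcloud alpha compute tpus queued-resources describe <your-tpu-pod> \
--zone=<your-zone> --project=<your-project>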
To launch training on TPU devices:
gcloud alpha compute tpus tpu-vm ssh <your-tpu-pod> --worker=all --command=" \
export XLA_USE_SPMD=1; export XLA_DISABLE_FUNCTIONALIZATION=1; \
export PT_XLA_DEBUG_LEVEL=1; cd fuse-dit; \
python train.py -c /training/config.yaml"Configuration files used in the research experiments are provided in configs. For local training, adapt these files to suit your environment. Note that for TPU devices, the batch_size parameter is specified per node, not per chip.
To launch training on GPU devices:
deepspeed train.py -c /training/config.yaml
Additional DeepSpeed configuration is required for multi-node training.
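For multi-node runs, the standard DeepSpeed launcher expects a hostfile listing the nodes and their GPU counts, plus passwordless SSH between them; the sketch below uses placeholder hostnames and slot counts, and further setup may still be required:
# hostfile: one reachable node per line with its number of GPUs
worker-1 slots=8
worker-2 slots=8

deepspeed --hostfile=hostfile train.py -c /training/config.yaml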
Before running inference, convert the trained model checkpoint into diffusers pipeline format:
python utils/save_pipeline.py --checkpoint /path/to/checkpoint/ \
--trainer <model-trainer> --type <model-type> --compression
<model-trainer>: spmd (TPU) or deepspeed (GPU), based on the training setup.
<model-type>: baseline-dit or fuse-dit, depending on the model architecture.
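For example, to convert a fuse-DiT checkpoint trained with SPMD on TPUs (paths are placeholders):
python utils/save_pipeline.py --checkpoint /path/to/checkpoint/ \
--trainer spmd --type fuse-dit --compression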
For convenience, our pre-trained model can be downloaded directly here as well.
Example inference code is provided in inference.py:
python inference.py --checkpoint_path /path/to/pipeline/ --prompt "your prompt" \
--resolution 512 \
--num_inference_steps 25 \
--guidance_scale 6.0 \
--save_path /save/path.jpg
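To call the converted pipeline from Python instead, something along these lines should work. This is a minimal sketch that assumes the saved directory loads through the generic diffusers DiffusionPipeline.from_pretrained interface; treat inference.py as the authoritative reference.
import torch
from diffusers import DiffusionPipeline

# Load the pipeline produced by utils/save_pipeline.py (path is a placeholder).
pipe = DiffusionPipeline.from_pretrained("/path/to/pipeline/", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Sample one image; the keyword arguments mirror the inference.py flags above.
image = pipe(
    "your prompt",
    height=512,
    width=512,
    num_inference_steps=25,
    guidance_scale=6.0,
).images[0]
image.save("/save/path.jpg")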
Follow the instructions in the official GenEval repository to set up the benchmark, and use the provided script with appropriate configuration to sample images.
accelerate launch evaluation/sample_geneval.py evaluation/geneval.yaml
Then follow the official instructions to evaluate the sampled images.
Follow the instructions in the official DPG-Bench repository to set up the benchmark, and use the provided script with appropriate configuration to sample images.
accelerate launch evaluation/sample_dpgbench.py evaluation/dpgbench.yaml
Then follow the official instructions to evaluate the sampled images.
Download the MJHQ-30K dataset, and use the provided script with appropriate configuration to sample images.
accelerate launch evaluation/sample_mjhq.py \
--checkpoint /path/to/pipeline/ \
--model_type <model-type> \
--prompts /path/to/mjhq/meta_data.json \
--resolution 512 \
--num_inference_steps 25 \
--guidance_scale 6.0 \
--batch_size <batch-size> \
--save_dir /save/dir/
Then compute the FID score:
python evaluation/fid.py --real_images /path/to/mjhq/imgs/ --fake_images /save/dir/
If you find our work useful for your research and applications, please cite using this BibTeX:
@article{tang2025exploringdeepfusion,
  title={Exploring the Deep Fusion of Large Language Models and Diffusion Transformers for Text-to-Image Synthesis},
  author={Bingda Tang and Boyang Zheng and Xichen Pan and Sayak Paul and Saining Xie},
  year={2025},
  journal={arXiv preprint arXiv:2505.10046},
}
If you have any questions or suggestions, please feel free to contact: tangbd2003@gmail.com.
We also thank Shusheng Yang, Shengbang Tong, Wenhao Chai, Nanye Ma, Sihan Xu, and Chenyu Li for insightful discussions. This work was mainly supported by the Google TPU Research Cloud (TRC) program, the Google Cloud Research Credits program (GCP19980904), Open Path AI Foundation, and Lambda Labs.