59 changes: 59 additions & 0 deletions .github/workflows/xpu-compile.yml
@@ -0,0 +1,59 @@
name: xpu-compile

on:
  workflow_dispatch:
  schedule:
    - cron: "0 0 * * *"
  pull_request:
    paths:
      - ".github/workflows/xpu-compile.yml"

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

permissions:
  contents: read
  issues: write

jobs:
  compile-tests:
    runs-on: [self-hosted, intel, xpu]

Collaborator:
@YizhouZ - does it make more sense to add this to its own workflow, or to add it to the existing xpu workflow?

Contributor Author:
you mean adding a new worker for this workflow?

    container:
      image: intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04
      ports:
        - 80
      options: --privileged -it --rm --device /dev/dri:/dev/dri -v /dev/dri/by-path:/dev/dri/by-path --ipc=host --cap-add=ALL

    steps:
      - uses: actions/checkout@v4
      - name: Install prerequisite
        run: |
          apt-get update
          apt-get install clinfo libaio-dev python3-pip -y
          pip install torch==2.3.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/torch/
          pip install intel-extension-for-pytorch==2.3.110+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/intel-extension-for-pytorch/
          pip install oneccl_bind_pt==2.3.100+xpu -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/oneccl-bind-pt/
          pip install torchvision==0.18.1 -f https://pytorch-extension.intel.com/release-whl/stable/xpu/us/torchvision/
          pip install https://github.com/intel/intel-xpu-backend-for-triton/releases/download/v3.0.0b2/triton_xpu-3.0.0b2-cp310-cp310-linux_x86_64.whl
          pip install py-cpuinfo numpy
          pip install .[dev,autotuning]

      - name: Check container state
        run: |
          ldd --version
          ds_report
          python3 -c "import torch; print('torch:', torch.__version__, torch)"
          python3 -c "import torch; import intel_extension_for_pytorch; print('XPU available:', torch.xpu.is_available())"
          python3 -c "from deepspeed.accelerator import get_accelerator; print('accelerator:', get_accelerator()._name)"
          pip list

      - name: Compile Status
        shell: bash
        run: |
          export FI_HMEM=system
          ulimit -n 1048575
          cd tests/torch_compile
          export ZE_AFFINITY_MASK=0,1
          deepspeed test_compile.py --deepspeed_config ds_config.json 2>&1 | tee log.txt
          cat log.txt | grep "'graph_breaks'" | sed 's/,/ /g' | awk '{print $2}' >> $GITHUB_STEP_SUMMARY
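
The last line of the Compile Status step greps the Counter line that test_compile.py prints at rank 0 and writes one field from it to the job summary. As an illustrative, non-authoritative sketch of that extraction in Python (the helper name graph_breaks_from_log and the regex are assumptions, not part of this PR):

import re


def graph_breaks_from_log(log_text: str) -> int:
    # Scan the captured log for the printed Counter line and pull the
    # integer that follows the 'graph_breaks' key.
    for line in log_text.splitlines():
        if "'graph_breaks'" in line:
            match = re.search(r"'graph_breaks':\s*(\d+)", line)
            if match:
                return int(match.group(1))
    raise ValueError("no 'graph_breaks' entry found in the log")
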
41 changes: 41 additions & 0 deletions tests/torch_compile/ds_config.json
@@ -0,0 +1,41 @@
{
  "train_batch_size": 8,
  "steps_per_print": 2000,
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.001,
      "betas": [
        0.8,
        0.999
      ],
      "eps": 1e-8,
      "weight_decay": 3e-7
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": 0,
      "warmup_max_lr": 0.001,
      "warmup_num_steps": 1000
    }
  },
  "gradient_clipping": 1.0,
  "prescale_gradients": false,
  "bf16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 500,
    "hysteresis": 2,
    "min_loss_scale": 1,
    "initial_scale_power": 15
  },
  "wall_clock_breakdown": false,
  "zero_optimization": {
    "stage": 3,
    "reduce_scatter": true,
    "overlap_comm": false,
    "contiguous_gradients": false
  }
}
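
For local experimentation, the same JSON can also be handed to DeepSpeed programmatically rather than via --deepspeed_config. A minimal sketch, assuming the config argument of deepspeed.initialize and a throwaway single-layer model (both are illustrative here, not part of this PR):

import json

import deepspeed
import torch

# Load the ZeRO stage 3 / bf16 config shown above and pass it to
# deepspeed.initialize as a dict instead of a command-line path.
with open("tests/torch_compile/ds_config.json") as f:
    ds_config = json.load(f)

model = torch.nn.Linear(1024, 256, bias=False)  # placeholder model for illustration
model_engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                                      model_parameters=model.parameters(),
                                                      config=ds_config)
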
99 changes: 99 additions & 0 deletions tests/torch_compile/test_compile.py
@@ -0,0 +1,99 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import argparse
import deepspeed
from deepspeed.accelerator import get_accelerator
from deepspeed import comm

import torch
import intel_extension_for_pytorch  # noqa: F401 # type: ignore
from torch.utils.data import Dataset, DataLoader

torch._dynamo.config.cache_size_limit = 100

import collections


def get_dynamo_stats():
    # TODO: consider deepcopy'ing the entire counters struct and
    # adding a helper to do subtraction on it
    return collections.Counter({
        "calls_captured": torch._dynamo.utils.counters["stats"]["calls_captured"],
        "unique_graphs": torch._dynamo.utils.counters["stats"]["unique_graphs"],
        "graph_breaks": sum(torch._dynamo.utils.counters["graph_break"].values()),
        # NB: The plus removes zero counts
        "unique_graph_breaks": len(+torch._dynamo.utils.counters["graph_break"]),
        "autograd_captures": torch._dynamo.utils.counters["compiled_autograd"]["captures"],
        "autograd_compiles": torch._dynamo.utils.counters["compiled_autograd"]["compiles"],
        "cudagraph_skips": torch._dynamo.utils.counters["inductor"]["cudagraph_skips"],
    })


class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size).to(torch.bfloat16)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


data_size = 1024
data_length = 100
rand_loader = DataLoader(dataset=RandomDataset(data_size, data_length), batch_size=1, shuffle=False)


class MyModule(torch.nn.Module):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.fc0 = torch.nn.Linear(1024, 256, bias=False)
        self.fc1 = torch.nn.Linear(256, 256, bias=False)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, data, residual):
        output = residual + self.fc1(self.fc0(self.dropout(data))) * 0.5
        return output


model = MyModule()
params = model.parameters()

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')
parser.add_argument('--deepspeed_config',
                    type=str,
                    default='ds_config.json',
                    help='path to DeepSpeed configuration file')
cmd_args = parser.parse_args()

# initialize the DeepSpeed engine
model_engine, optimizer, _, _ = deepspeed.initialize(args=cmd_args, model=model, model_parameters=params)
model_engine.compile()

residual = torch.rand(256, 256, dtype=torch.float).to(get_accelerator().current_device_name())

start_stats = get_dynamo_stats()

for step, batch in enumerate(rand_loader):
    if step % 10 == 0 and comm.get_rank() == 0:
        print(f'step={step}')
    # forward() method
    loss = model_engine(batch.to(get_accelerator().current_device_name()), residual).sum()
    # runs backpropagation
    model_engine.backward(loss)
    # weight update
    model_engine.step()

dynamo_stats = get_dynamo_stats()
dynamo_stats.subtract(start_stats)

if comm.get_rank() == 0:
    print(dynamo_stats)

Collaborator:
@YizhouZ - could you run the pre-commit formatter on the PR, that will resolve the formatting/python issues.

Contributor Author:
sure!