Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

[WIP] Stateful Destructive Device Slicing#2726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
hunhoffe wants to merge 12 commits into main
base:main
Choose a base branch
Loading
fromdevice-slicer
Draft
Show file tree
Hide file tree
Changes fromall commits
Commits
Show all changes
12 commits
Select commitHold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 120 additions & 74 deletionspython/iron/device/device.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -9,32 +9,24 @@

from ... import ir # type: ignore
from ...dialects._aie_enum_gen import WireBundle # type: ignore
from ...dialects.aie import AIEDevice, tile,TileOp,get_target_model # type: ignore
from ...dialects.aie import AIEDevice, tile, get_target_model # type: ignore
from ..resolvable import Resolvable
from .tile import Tile

import re


class Device(Resolvable):
"""
A base class for representations of a device of a specific type.

Note: this class is abstract because it does not implement Resolve
"""

class __DeviceTile(Resolvable):
class DeviceLike(Resolvable):
class __DeviceTile(Tile, Resolvable):
"""
Interior class for tiles objects owned by a particular device.
This is needed to ensure we don't generate more than one MLIR operation corresponding
to the same logical tile within a device.
"""

def __init__(self, col: int, row: int) -> None:
self._col: int = col
self._row: int = row
self._op: TileOp | None = None
super().__init__()
Tile.__init__(self, col, row)
Resolvable.__init__(self)

def resolve(
self,
Expand All@@ -44,55 +36,28 @@ def resolve(
) -> None:
if self._op == None:
self._op = tile(
self._col,
self._row,
self.col,
self.row,
loc=loc,
ip=ip,
allocation_scheme=allocation_scheme,
)

@property
def op(self) -> TileOp:
if not self._op:
raise ValueError("Cannot get operation before it is set.")
return self._op

@op.setter
def op(self, op: TileOp):
if self._op:
raise ValueError("Cannot set operation more than once.")
self._op = op

def __init__(self, device: AIEDevice) -> None:
"""Initialize a representation of a device.

Args:
device (AIEDevice): aie device
"""
def __init__(
self, device: AIEDevice, tiles: list[list[Tile]] | None = None
) -> None:
self._device = device
self._tiles: list[list[Device.__DeviceTile]] = []
self._tm = get_target_model(device)
for c in range(self._tm.columns()):
self._tiles.append([])
for r in range(self._tm.rows()):
self._tiles[c].append(Device.__DeviceTile(c, r))

def tile_iterator(self) -> Generator[Tile, None, None]:
"""
Iterates over the device tiles deterministically
"""
for c in range(self._tm.columns()):
for r in range(self._tm.rows()):
yield self._tiles[c][r]
return None

@property
def rows(self) -> int:
return self._tm.rows()

@property
def cols(self) -> int:
return self._tm.columns()
if tiles is None:
self._tiles: list[list[DeviceLike.__DeviceTile]] = []
for c in range(self._tm.columns()):
self._tiles.append([])
for r in range(self._tm.rows()):
self._tiles[c].append(DeviceLike.__DeviceTile(c, r))
else:
self._tiles = tiles
self.ncols = len(self._tiles)
self.nrows = len(self._tiles[0]) if self.ncols else 0

def get_shim_tiles(self) -> list[Tile]:
"""Returns a list of all shim tiles on the device.
Expand All@@ -101,9 +66,9 @@ def get_shim_tiles(self) -> list[Tile]:
list[Tile]: A list of shim tiles.
"""
return [
Tile(t._col, t._row)
t
for t in self.tile_iterator()
if self._tm.is_shim_noc_or_pl_tile(t._col, t._row)
if self._tm.is_shim_noc_or_pl_tile(t.col, t.row)
]

def get_mem_tiles(self) -> list[Tile]:
Expand All@@ -112,11 +77,7 @@ def get_mem_tiles(self) -> list[Tile]:
Returns:
list[Tile]: A list of mem tiles.
"""
return [
Tile(t._col, t._row)
for t in self.tile_iterator()
if self._tm.is_mem_tile(t._col, t._row)
]
return [t for t in self.tile_iterator() if self._tm.is_mem_tile(t.col, t.row)]

def get_compute_tiles(self) -> list[Tile]:
"""Returns a list of all compute tiles on the device.
Expand All@@ -125,9 +86,9 @@ def get_compute_tiles(self) -> list[Tile]:
list[Tile]: A list of compute tiles.
"""
return [
Tile(t._col, t._row)
Tile(t.col, t.row)
for t in self.tile_iterator()
if self._tm.is_core_tile(t._col, t._row)
if self._tm.is_core_tile(t.col, t.row)
]

def get_num_source_switchbox_connections(self, t: Tile) -> int:
Expand DownExpand Up@@ -237,29 +198,114 @@ def resolve_tile(
tile: Tile,
loc: ir.Location | None = None,
ip: ir.InsertionPoint | None = None,
) -> None:
self._tiles[tile.col][tile.row].resolve(loc, ip, tile.allocation_scheme)
):
self._tiles[tile.col][tile.row].resolve(
loc, ip, getattr(tile, "allocation_scheme", None)
)
tile.op = self._tiles[tile.col][tile.row].op

def tile_iterator(self) -> Generator[Tile, None, None]:
"""
Iterates over the available device tiles deterministically
"""
for c in range(self._tm.columns()):
for r in range(self._tm.rows()):
yield self._tiles[c][r]


class DeviceView(DeviceLike):
    """
    A group of tiles sliced out of a Device.

    The view shares the tile objects of the device it was sliced from; it does
    not create new tiles.
    """

    def __init__(self, device: "Device", tiles: list[list[Tile]]):
        """Initialize a view over tiles taken from a device.

        Args:
            device (Device): The device the tiles were sliced from.
            tiles (list[list[Tile]]): The tiles in the view, as a list of
                columns where each column is a list of tiles. May be empty.
        """
        super().__init__(device=device._device, tiles=tiles)
        self._device_instance = device
        # (col, row) coordinates of every tile contained in this view.
        self._coords = {(t.col, t.row) for col_tiles in tiles for t in col_tiles}

    def tile_iterator(self) -> Generator[Tile, None, None]:
        """
        Iterates over the tiles of the view deterministically
        """
        # Keep ordering consistent from the device we sliced from:
        # column by column, and row by row within each column.
        for col_tiles in self._tiles:
            yield from col_tiles


class Device(DeviceLike):
    """
    A base class for representations of a device of a specific type.

    Indexing a device (``device[cols]`` or ``device[cols, rows]``) claims the
    selected tiles and returns them as a DeviceView. Claimed tiles are no
    longer yielded by ``tile_iterator`` and cannot be claimed a second time.
    """

    def __init__(self, device: AIEDevice) -> None:
        """Initialize a representation of a device.

        Args:
            device (AIEDevice): aie device
        """
        super().__init__(device=device)
        # (col, row) coordinates that were already handed out by __getitem__.
        self._claimed_tiles = set()

    @staticmethod
    def _resolve_axis(index: "int | slice", extent: int) -> "list[int] | range":
        """Turn an index along one axis into the concrete positions it selects.

        Integers are bounds-checked and normalized to their non-negative form,
        so that a tile is always tracked under a single coordinate no matter
        whether it was addressed with a positive or a negative index.

        Args:
            index (int | slice): The index for this axis.
            extent (int): The number of positions along this axis.

        Raises:
            IndexError: If an integer index is outside ``[-extent, extent)``.

        Returns:
            list[int] | range: The selected positions (possibly empty for a slice).
        """
        if isinstance(index, int):
            if not -extent <= index < extent:
                raise IndexError("Tile index out of range.")
            return [index % extent]
        return range(extent)[index]

    def __getitem__(self, key) -> "DeviceView":
        """Claim a group of tiles and return them as a view.

        Args:
            key: An integer or slice selecting columns, or a tuple of up to
                two integers/slices selecting ``(columns, rows)``. When rows
                are not given, all rows are selected.

        Raises:
            IndexError: If the key has an unsupported type, has more than two
                dimensions, or contains an out-of-range integer.
            ValueError: If any selected tile has already been claimed. In that
                case no tile of the selection is claimed.

        Returns:
            DeviceView: A view containing the claimed tiles (empty if the
                selection is empty).
        """
        if isinstance(key, tuple):
            if len(key) > 2:
                raise IndexError("Only 2D slicing is supported for devices.")
            if len(key) == 2:
                col_slice, row_slice = key
            else:
                col_slice, row_slice = key[0], slice(None, None, None)
        elif isinstance(key, (int, slice)):
            col_slice, row_slice = key, slice(None, None, None)
        else:
            raise IndexError(
                "Device indices must be integers, slices, or a 2-tuple of those."
            )

        # Handle slices and integers for cols and rows
        cols = self._resolve_axis(col_slice, self._tm.columns())
        rows = self._resolve_axis(row_slice, self._tm.rows())

        if not cols or not rows:
            return DeviceView(self, [])

        tiles_to_claim = []
        coords_to_claim = set()
        for c in cols:
            tiles_to_claim.append([])
            for r in rows:
                if (c, r) in self._claimed_tiles:
                    raise ValueError(f"Tile ({c}, {r}) has already been claimed.")
                coords_to_claim.add((c, r))
                tiles_to_claim[-1].append(self._tiles[c][r])

        # Only record the claim once the whole selection is known to be free.
        self._claimed_tiles.update(coords_to_claim)
        return DeviceView(self, tiles_to_claim)

    def tile_iterator(self) -> Generator[Tile, None, None]:
        """
        Iterates over the device tiles that have not been claimed
        """
        for t in super().tile_iterator():
            if (t.col, t.row) not in self._claimed_tiles:
                yield self._tiles[t.col][t.row]


def create_class(class_name, device):
    """Create a Device subclass for one device and register it in this module.

    The new class takes no constructor arguments and is stored in the module
    globals under ``class_name``.

    Args:
        class_name: The name of the class to create.
        device: The aie device the class represents; it is passed to
            ``Device.__init__`` and returned by ``resolve``.
    """

    def _device__init__(self) -> None:
        # Refer to the class through the closure instead of looking it up in
        # globals() at call time, so instances keep working even if the
        # module-level name is later rebound to a different class.
        super(device_class, self).__init__(device=device)

    def _device_resolve(
        self,
        loc: ir.Location | None = None,
        ip: ir.InsertionPoint | None = None,
    ) -> AIEDevice:
        return device

    device_class = type(
        class_name,
        (Device,),
        {
            "__init__": _device__init__,
            "resolve": _device_resolve,
            "__doc__": f"A representation of a device that resolves to {device}",
        },
    )
    globals()[class_name] = device_class
Expand Down
7 changes: 3 additions & 4 deletionspython/iron/hostruntime/config.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -43,7 +43,7 @@ def detect_npu_device():
"RyzenAI-npu6",
]
):
return NPU2()
return NPU2
elif any(
keyword.lower() in output.lower()
for keyword in [
Expand All@@ -52,7 +52,7 @@ def detect_npu_device():
"RyzenAI-npu1",
]
):
return NPU1()
return NPU1
else:
raise RuntimeError("No supported NPU device found.")

Expand DownExpand Up@@ -89,5 +89,4 @@ def get_current_device():
global config
if "device" not in config:
config["device"] = detect_npu_device()

return config["device"]
return config["device"]() if config["device"] else None
19 changes: 2 additions & 17 deletionspython/iron/placers.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -7,7 +7,6 @@
# (c) Copyright 2024 Advanced Micro Devices, Inc.

from abc import ABCMeta, abstractmethod
from typing import Optional
import statistics

from .device import Device
Expand DownExpand Up@@ -55,9 +54,8 @@ class SequentialPlacer(Placer):
tiles in a row-wise direction up to the defined limit then move to the next column for subsequent placement.
"""

def __init__(self, cores_per_col: Optional[int] = None):
def __init__(self):
super().__init__()
self.cores_per_col = cores_per_col

def make_placement(
self,
Expand DownExpand Up@@ -97,19 +95,6 @@ def make_placement(
)
computes.remove(worker.tile)

# Shorten the list of compute tiles available if the cores per column value is set
if self.cores_per_col is not None:
unused_computes_at_col = {
column: [tile for tile in computes if tile.col == column]
for column in range(device.cols)
}
computes = []
for col, tiles in unused_computes_at_col.items():
if len(tiles) < self.cores_per_col:
raise ValueError(f"Not enough compute tiles at column {col}!")
else:
computes.extend(tiles[: self.cores_per_col])

for worker in workers:
if worker.tile == AnyComputeTile:
if compute_idx >= len(computes):
Expand DownExpand Up@@ -261,7 +246,7 @@ def _find_col_match(self, col: int, tiles: list[Tile], device: Device) -> Tile:
The column is increased until a tile is found in the device, or an error is signaled.
"""
new_col = col
while new_col < device.cols:
while new_col < device.ncols:
for t in tiles:
if t.col == new_col:
return t
Expand Down
5 changes: 0 additions & 5 deletionspython/iron/program.py
View file
Open in desktop
Original file line numberDiff line numberDiff line change
Expand Up@@ -46,11 +46,6 @@ def resolve_program(self, placer: Placer | None = None, device_name="main"):
module (Module): The module containing the MLIR context information.
"""
with mlir_mod_ctx() as ctx:
# Create a fresh device instance of the same type to avoid stale MLIR operations
# This preserves the device configuration while ensuring clean state
device_type = type(self._device)
# For dynamically created device classes, the constructor takes no arguments
self._device = device_type()

@device(self._device.resolve(), sym_name=device_name)
def device_body():
Expand Down
Loading
Loading

[8]ページ先頭

©2009-2025 Movatter.jp