Implement func.call emission and function caching across MLIR dialects and ZML module/ops, propagating tags and donations.

This commit is contained in:
Tarry Singh 2023-10-19 17:01:55 +00:00
parent 37de7b9613
commit 98b512c495
7 changed files with 162 additions and 135 deletions

View File

@ -13,8 +13,7 @@ pub fn func(
location: mlir.Location, location: mlir.Location,
}, },
) mlir.Operation { ) mlir.Operation {
const AttrTuple = struct { [:0]const u8, mlir.Attribute }; var attrs_tuple_buffer = std.BoundedArray(mlir.AttrTuple, 4){};
var attrs_tuple_buffer = std.BoundedArray(AttrTuple, 4){};
attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute).? }); attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute).? });
attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type).?).as(mlir.Attribute).? }); attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type).?).as(mlir.Attribute).? });
if (args.arg_attrs.len > 0) { if (args.arg_attrs.len > 0) {

View File

@ -189,7 +189,7 @@ pub fn dot_general(
}, },
) mlir.Operation { ) mlir.Operation {
const precisions = [1]mlir.Attribute{opts.precision.precisionAttr(ctx)} ** 2; const precisions = [1]mlir.Attribute{opts.precision.precisionAttr(ctx)} ** 2;
const attributes = [3]mlir.Operation.AttrTuple{ const attributes = [3]mlir.AttrTuple{
.{ .{
"dot_dimension_numbers", DotDimensionNumbersAttribute.init(ctx, .{ "dot_dimension_numbers", DotDimensionNumbersAttribute.init(ctx, .{
.lhs_batching_dimensions = opts.lhs_batching_dimensions, .lhs_batching_dimensions = opts.lhs_batching_dimensions,

View File

@ -333,6 +333,8 @@ pub const Identifier = struct {
} }
}; };
pub const AttrTuple = struct { [:0]const u8, Attribute };
pub const Attribute = struct { pub const Attribute = struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,
pub usingnamespace MlirHelpers(Attribute, .{ pub usingnamespace MlirHelpers(Attribute, .{
@ -791,8 +793,6 @@ pub const Operation = struct {
) orelse Error.InvalidMlir; ) orelse Error.InvalidMlir;
} }
pub const AttrTuple = struct { [:0]const u8, Attribute };
pub fn make(ctx: Context, op_name: [:0]const u8, args: struct { pub fn make(ctx: Context, op_name: [:0]const u8, args: struct {
operands: ?[]const Value = null, operands: ?[]const Value = null,
variadic_operands: ?[]const []const Value = null, variadic_operands: ?[]const []const Value = null,

View File

@ -86,12 +86,14 @@ pub fn MapType(From: type, To: type) type {
/// For example it can convert from a comptime array to a runtime slice. /// For example it can convert from a comptime array to a runtime slice.
/// `mapAlloc` can allocate new slices to write the result if the result struct requires it. /// `mapAlloc` can allocate new slices to write the result if the result struct requires it.
/// The caller is owning said allocations, using an `ArenaAllocator` might help tracking them. /// The caller is owning said allocations, using an `ArenaAllocator` might help tracking them.
// TODO: handle tuple to slice conversion ///
/// Note: to avoid infinite loop, mapAlloc doesn't look for `From` fields inside `To` struct.
/// Any `To` struct inside `from` will be copied over to the target.
pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam(cb, 0), from: anytype, to: anytype) !void { pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam(cb, 0), from: anytype, to: anytype) !void {
// const Ctx = FnParam(cb, 0); // TODO: handle tuple to slice conversion
const From = FnParam(cb, 1); const From = FnParam(cb, 1);
const To = stdx.meta.FnResult(cb);
const FromStruct = @TypeOf(from); const FromStruct = @TypeOf(from);
const type_info_to_ptr = @typeInfo(@TypeOf(to)); const type_info_to_ptr = @typeInfo(@TypeOf(to));
if (type_info_to_ptr != .Pointer) { if (type_info_to_ptr != .Pointer) {
stdx.debug.compileError("convertType is expecting a mutable `to` argument but received: {}", .{@TypeOf(to)}); stdx.debug.compileError("convertType is expecting a mutable `to` argument but received: {}", .{@TypeOf(to)});
@ -100,11 +102,12 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
const type_info_to = @typeInfo(ToStruct); const type_info_to = @typeInfo(ToStruct);
if (FromStruct == From) { if (FromStruct == From) {
// Special case for converting from shape to tensor: // We have an issue with `Tensor` -> `Shape` -> `Tensor` conversion when compiling ZML functions where one argument is a Shape itself.
// If the target type is Shape, skip tensor conversion. // Normally we should call `cb` on all `Shape`.
// A general `to.* = from` assignment causes a Zig error in this scenario. // But the "ShapeOf" struct will have more Shape than needed on the output.
// (see below) // So here we take a hint from the receiving object.
if (ToStruct == @import("shape.zig").Shape and FromStruct == ToStruct) { // FromStruct) { // If the target is indeed a Tensor, use the callback, but if the target is `Shape` just copy it over.
if (ToStruct != To and FromStruct == ToStruct) {
to.* = from; to.* = from;
} else { } else {
to.* = @call(.auto, cb, .{ ctx, from }); to.* = @call(.auto, cb, .{ ctx, from });
@ -112,19 +115,11 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
return; return;
} }
// This is generally due to a user error, but let this fn compile, if (FromStruct == To) {
// and the user will have a Zig error.
if (FromStruct == ToStruct) {
to.* = from; to.* = from;
return; return;
} }
// Don't go into Shape objects because of the weird tag.
// TODO: we could not error on pointers to basic types like u8
if (FromStruct == @import("shape.zig").Shape) {
to.* = from;
return;
}
switch (type_info_to) { switch (type_info_to) {
.Struct => |info| inline for (info.fields) |field| { .Struct => |info| inline for (info.fields) |field| {
// if (field.is_comptime) continue; // if (field.is_comptime) continue;
@ -155,10 +150,11 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
.One => switch (type_info_to_ptr.Pointer.size) { .One => switch (type_info_to_ptr.Pointer.size) {
// pointer to array -> slice promotion // pointer to array -> slice promotion
.Slice => { .Slice => {
to.* = try allocator.alloc(type_info_to_ptr.Pointer.child, from.len); const items = try allocator.alloc(type_info_to_ptr.Pointer.child, from.len);
for (from, to.*) |f, *t| { for (from, items) |f, *t| {
try mapAlloc(cb, allocator, ctx, f, t); try mapAlloc(cb, allocator, ctx, f, t);
} }
to.* = items;
}, },
else => try mapAlloc(cb, allocator, ctx, from.*, to.*), else => try mapAlloc(cb, allocator, ctx, from.*, to.*),
}, },
@ -177,7 +173,7 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
} else { } else {
to.* = null; to.* = null;
}, },
.Int, .Float => to.* = from, .Int, .Float, .Enum => to.* = from,
else => stdx.debug.compileError("zml.meta.mapAlloc doesn't support: {}", .{FromStruct}), else => stdx.debug.compileError("zml.meta.mapAlloc doesn't support: {}", .{FromStruct}),
} }
} }

View File

@ -49,7 +49,7 @@ const Block = union(BlockKind) {
.op_result => |parent_op| self.appendOperationRecursive(parent_op), .op_result => |parent_op| self.appendOperationRecursive(parent_op),
.block_argument => |arg| { .block_argument => |arg| {
// Hermetic blocks are not allowed to use arguments from other blocks. // Hermetic blocks are not allowed to use arguments from other blocks.
std.debug.assert(self == .open or self.block().eql(arg.block())); stdx.debug.assert(self == .open or self.block().eql(arg.block()), "Can't add {} from {?x} block to {?x} block", .{ arg, arg.block()._inner.ptr, self.block()._inner.ptr });
}, },
.null => @panic("InvalidMlir"), .null => @panic("InvalidMlir"),
} }
@ -75,6 +75,11 @@ pub const MlirFn = struct {
res_shapes: []Shape, res_shapes: []Shape,
res_donations: []Tensor._Donation, res_donations: []Tensor._Donation,
mlir_fn: mlir.Operation, mlir_fn: mlir.Operation,
pub const Kind = enum {
main,
private,
};
}; };
pub const CompilationContext = struct { pub const CompilationContext = struct {
@ -151,7 +156,6 @@ pub const CompilationContext = struct {
pub fn deactivate(self: *CompilationContext) void { pub fn deactivate(self: *CompilationContext) void {
std.debug.assert(_current != null and _current.? == self); std.debug.assert(_current != null and _current.? == self);
_current = self._previous; _current = self._previous;
self._previous = null;
} }
pub fn current() *CompilationContext { pub fn current() *CompilationContext {
@ -182,9 +186,9 @@ pub const CompilationContext = struct {
const arena = arena_state.allocator(); const arena = arena_state.allocator();
var timer = std.time.Timer.start() catch null; var timer = std.time.Timer.start() catch null;
const tensor_args = self.tensorFromShapes(stdx.meta.FnArgs(func), arena, args); const tensor_args = try self.tensorFromShapes(stdx.meta.FnArgs(func), arena, args);
// Run in a dedicated thread because compilation relies on `threadlocal`. // Run in a dedicated thread because compilation relies on `threadlocal`.
const f = try asynk.callBlocking(CompilationContext.generateBytecode, .{ self, arena, "main", func, &tensor_args }); const f = try asynk.callBlocking(CompilationContext.emitMlir, .{ self, arena, func, &tensor_args, .{ .name = "main", .kind = .main } });
const module = self._module; const module = self._module;
module.getBody().appendOperation(f.mlir_fn); module.getBody().appendOperation(f.mlir_fn);
@ -329,15 +333,18 @@ pub const CompilationContext = struct {
/// Generate an MLIR function from a ZML function. /// Generate an MLIR function from a ZML function.
/// The caller is responsible to have properly created the input /// The caller is responsible to have properly created the input
/// tensors with unique tensor ids. /// tensors with unique tensor ids.
pub fn generateBytecode( pub fn emitMlir(
self: *CompilationContext, self: *CompilationContext,
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
fn_name: []const u8,
comptime func: anytype, comptime func: anytype,
args: *const stdx.meta.FnArgs(func), args: *const stdx.meta.FnArgs(func),
opts: struct {
name: []const u8,
kind: MlirFn.Kind = .private,
},
) error{OutOfMemory}!MlirFn { ) error{OutOfMemory}!MlirFn {
const frame = self._tracer.frameStart("generateBytecode.emit"); const frame = self._tracer.frameStart("emitMlir.emit");
errdefer self._tracer.frameEnd(frame, "generateBytecode.emit"); errdefer self._tracer.frameEnd(frame, "emitMlir.emit");
// Note: only temp allocations are done in the arena, // Note: only temp allocations are done in the arena,
// the other allocations are managed by the caller. // the other allocations are managed by the caller.
@ -371,11 +378,6 @@ pub const CompilationContext = struct {
var fn_body = self.openBlock(.hermetic, input_types, locations) catch unreachable; var fn_body = self.openBlock(.hermetic, input_types, locations) catch unreachable;
{ {
defer self.closeBlock(fn_body); defer self.closeBlock(fn_body);
// Note: we could shrink self._buffer_to_arg once we called `func`.
// But for now we are only compiling one function per CompilationContext.
// So we don't need to do this since we won't reuse self._buffer_to_arg anyway.
// const n = self._buffer_to_arg.count();
// defer self._buffer_to_arg.shrinkRetainingCapacity(n);
try self._buffer_to_arg.ensureUnusedCapacity(self._allocator, @intCast(tensor_count)); try self._buffer_to_arg.ensureUnusedCapacity(self._allocator, @intCast(tensor_count));
const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), 0); const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), 0);
@ -400,14 +402,15 @@ pub const CompilationContext = struct {
const res_attrs = try arena.alloc(AttributeList, out_tensor_count); const res_attrs = try arena.alloc(AttributeList, out_tensor_count);
@memset(res_attrs, .{}); @memset(res_attrs, .{});
// Donations attributes only make sense on the main function. if (opts.kind == .main) {
self.addDonationsAttributes(arg_attrs, fn_res_donations); self.addDonationsAttributes(arg_attrs, fn_res_donations);
if (self._platform.sharding().num_partitions > 1) { if (self._platform.sharding().num_partitions > 1) {
self.addShardingAttributes(arg_attrs, res_attrs, input_shapes.items, fn_res_shapes); self.addShardingAttributes(arg_attrs, res_attrs, input_shapes.items, fn_res_shapes);
} }
}
const mlir_fn = dialect.func.func(self.mlirCtx(), .{ const mlir_fn = dialect.func.func(self.mlirCtx(), .{
.sym_name = fn_name, .sym_name = opts.name,
.args = input_types, .args = input_types,
.arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs), .arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs),
.results = fn_res_types, .results = fn_res_types,
@ -416,9 +419,9 @@ pub const CompilationContext = struct {
.location = loc, .location = loc,
}); });
self._tracer.frameEnd(frame, "generateBytecode.emit"); self._tracer.frameEnd(frame, "emitMlir.emit");
const canonicalize_frame = self._tracer.frameStart("generateBytecode.canonicalize"); const canonicalize_frame = self._tracer.frameStart("emitMlir.canonicalize");
defer self._tracer.frameEnd(canonicalize_frame, "generateBytecode.canonicalize"); defer self._tracer.frameEnd(canonicalize_frame, "emitMlir.canonicalize");
self._mlir_canonicalizer.runOnOp(mlir_fn) catch |err| switch (err) { self._mlir_canonicalizer.runOnOp(mlir_fn) catch |err| switch (err) {
error.InvalidMlir => { error.InvalidMlir => {
log.err("Failed to canonicalize invalid mlir: {}", .{mlir_fn.mlirFormatter(.{})}); log.err("Failed to canonicalize invalid mlir: {}", .{mlir_fn.mlirFormatter(.{})});
@ -429,7 +432,7 @@ pub const CompilationContext = struct {
return .{ return .{
.mlir_fn = mlir_fn, .mlir_fn = mlir_fn,
.name = fn_name, .name = opts.name,
.num_args = @intCast(tensor_count), .num_args = @intCast(tensor_count),
.res_types = fn_res_types, .res_types = fn_res_types,
.res_shapes = fn_res_shapes, .res_shapes = fn_res_shapes,
@ -478,7 +481,13 @@ pub const CompilationContext = struct {
const Local = struct { const Local = struct {
bias: Tensor, bias: Tensor,
pub fn forward(self: @This(), x: Tensor) Tensor { pub fn forward(self: @This(), x: Tensor, y: Tensor) [2]Tensor {
const x1 = zml.ops.call(self, .inner, .{x});
const x2 = zml.ops.call(self, .inner, .{x1});
return .{ x1.reuseBuffer(y), x2 };
}
pub fn inner(self: @This(), x: Tensor) Tensor {
const y = x.add(self.bias); const y = x.add(self.bias);
return y.reuseBuffer(x); return y.reuseBuffer(x);
} }
@ -490,20 +499,26 @@ pub const CompilationContext = struct {
var comp = try zml.module.CompilationContext.init(allocator, "test", platform); var comp = try zml.module.CompilationContext.init(allocator, "test", platform);
defer comp.deinit(); defer comp.deinit();
var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .arg_id = 1234 } } }; var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } };
const f = try comp.generateBytecode(allocator, "test.generateBytecode.Local.forward", Local.forward, &tensor_args); const f = try comp.emitMlir(allocator, Local.forward, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main });
var mlir_bytecode: std.ArrayListUnmanaged(u8) = .{}; var mlir_bytecode: std.ArrayListUnmanaged(u8) = .{};
try mlir_bytecode.writer(allocator).print("{}", .{f.mlir_fn.mlirFormatter(.{})}); try mlir_bytecode.writer(allocator).print("{}", .{f.mlir_fn.mlirFormatter(.{})});
// Check that the `x` input argument gives its buffer to the result tensor. // Check that the `x` input argument gives its buffer to the result tensor.
// `%arg0` is the bias of the model, `%arg1` is `x`. // `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`.
try std.testing.expectEqual(2, f.num_args); try std.testing.expectEqual(3, f.num_args);
std.testing.expect(std.mem.indexOf(u8, mlir_bytecode.items, "tf.aliasing_output = 0 : i32") != null) catch |err| { // We should have two buffers being donated.
const template = "tf.aliasing_output = {d} : i32";
var buf = template.*;
for (0..2) |i| {
const alias_attr = std.fmt.bufPrint(&buf, template, .{i}) catch unreachable;
std.testing.expect(std.mem.indexOf(u8, mlir_bytecode.items, alias_attr) != null) catch |err| {
log.warn("Didn't produced the expected IR:\n{s}", .{mlir_bytecode.items}); log.warn("Didn't produced the expected IR:\n{s}", .{mlir_bytecode.items});
return err; return err;
}; };
} }
}
pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.StringAttribute { pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.StringAttribute {
const mlir_ctx = self.mlirCtx(); const mlir_ctx = self.mlirCtx();
@ -608,45 +623,77 @@ pub const CompilationContext = struct {
// first, do the "compile" and check the bytecode // first, do the "compile" and check the bytecode
// the result of this will also have the correct tags of the result shapes // the result of this will also have the correct tags of the result shapes
const dummy_result = self.generateMlirBytecodeForFunction( const args_hash = hashArgs(args);
arena, const key: FnCache.Key = .{ .fn_ptr = &func, .input_hash = args_hash };
func_name,
func,
args,
) catch unreachable; // TODO: do we like unreachable?
const bytecode_hash = hashArgs(dummy_result.bytecode_tmp);
const key: FnCache.Key = .{ .fn_ptr = &func, .input_hash = bytecode_hash };
const function = self._fn_cache.getEntry(key) orelse b: { const function = self._fn_cache.getEntry(key) orelse b: {
const full_name: [:0]const u8 = if (std.mem.eql(u8, "main", func_name)) const full_name: [:0]const u8 = if (std.mem.eql(u8, "main", func_name))
arena.dupeZ(u8, func_name) catch unreachable arena.dupeZ(u8, func_name) catch unreachable
else else
std.fmt.allocPrintZ(arena, "{s}_{x}", .{ func_name, key.input_hash }) catch unreachable; std.fmt.allocPrintZ(arena, "{s}_{x}", .{ func_name, key.input_hash }) catch unreachable;
log.info("addFuncToModule {any} {s}", .{ key, full_name }); const og_buffer_to_arg = self._buffer_to_arg;
defer {
self._buffer_to_arg.deinit(self._allocator);
self._buffer_to_arg = og_buffer_to_arg;
}
const value = self.addFuncToModule( // Reset the buffer -> assignment
arena, self._buffer_to_arg = .{};
full_name,
func,
args,
) catch unreachable;
break :b self._fn_cache.addEntry(key, value) catch unreachable; var arg_id: u16 = 0;
var tensor_args: @TypeOf(args) = args;
meta.mapAlloc(struct {
fn cb(arg_id_: *u16, x: Tensor) Tensor {
const a = arg_id_.*;
arg_id_.* += 1;
return Tensor{ ._shape = x._shape, ._id = .{ .arg_id = a }, ._donation = .{ .arg = a } };
}
}.cb, self._allocator, &arg_id, args, &tensor_args) catch @panic("OutOfMemory");
const f = self.emitMlir(arena, func, &tensor_args, .{
.name = full_name,
}) catch @panic("OOM");
self._module.getBody().appendOperation(f.mlir_fn);
break :b self._fn_cache.addEntry(key, f) catch unreachable;
}; };
// Note: we won't increase the size of the cache until next `call` so
// we can use the memory there without worrying about fragmentation.
const loc = self.mlirCtx().location(@src()); const loc = self.mlirCtx().location(@src());
const values = arena.alloc(mlir.Value, function.n_model + function.n_args) catch unreachable; const values = arena.alloc(mlir.Value, function.num_args) catch unreachable;
self.extractValues(&args, values[function.n_model..]); self.extractValues(&args, values);
const op = dialect.func.call(self.mlirCtx(), function.name, values, function.res_types, loc); const donations = arena.alloc(Tensor._Donation, function.num_args) catch unreachable;
// TODO: tags seem to be lost by `callFunc`. meta.collectBuf(struct {
pub fn cb(ctx: *const CompilationContext, x: Tensor) Tensor._Donation {
return ctx.getValueAndDonation(x)[1];
}
}.cb, self, &args, donations);
const op = dialect.func.call(self.mlirCtx(), @ptrCast(function.name), values, function.res_types, loc);
// Create the result tensor object by combining the operand results,
// as well as the registered shapes and donations.
// Note: this assumes res can be stack-allocated.
// Maybe it'd be simpler to just call the Zig function twice to do the shape/donation propagation for us.
// But this is blocked on https://github.com/zml/zml/issues/97
var res: stdx.meta.FnResult(func) = undefined; var res: stdx.meta.FnResult(func) = undefined;
assignResults(op, &res, function.res_shapes); const LocalContext = struct { index: usize = 0, op: mlir.Operation, function: MlirFn, donations: []Tensor._Donation };
var context: LocalContext = .{ .op = op, .function = function, .donations = donations };
meta.visit((struct {
fn cb(ctx: *LocalContext, tensor: *Tensor) void {
const i = ctx.index;
ctx.index += 1;
var new = Tensor.fromMlirValue(ctx.op.result(i));
new._shape = ctx.function.res_shapes[i];
new._donation = switch (ctx.function.res_donations[i]) {
.no_buffer => .no_buffer,
.arg => |input_arg| ctx.donations[input_arg],
.input_buffer => .no_buffer, // user escaped the sandbox
};
tensor.* = new;
}
}).cb, &context, &res);
std.debug.assert(context.index == op.numResults());
return res; return res;
} }
@ -669,7 +716,7 @@ pub const CompilationContext = struct {
const res = ctx.self._buffer_to_arg.getOrPutAssumeCapacity(tensor._id); const res = ctx.self._buffer_to_arg.getOrPutAssumeCapacity(tensor._id);
if (res.found_existing) { if (res.found_existing) {
stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {} and {}({}).", .{ res.key_ptr.*, tensor, tensor._id }); stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {} and {} at index {} ({}).", .{ res.value_ptr.*[0], tensor, ctx.index, tensor._id });
} else { } else {
res.value_ptr.* = .{ arg_value, .{ .arg = @intCast(ctx.index) } }; res.value_ptr.* = .{ arg_value, .{ .arg = @intCast(ctx.index) } };
} }
@ -681,7 +728,7 @@ pub const CompilationContext = struct {
/// Create tensor from the given shapes. /// Create tensor from the given shapes.
/// Each created tensor will receive a unique id, local to this CompilationContext. /// Each created tensor will receive a unique id, local to this CompilationContext.
pub fn tensorFromShapes(self: *CompilationContext, ArgsT: type, allocator: std.mem.Allocator, args_shapes: anytype) ArgsT { pub fn tensorFromShapes(self: *CompilationContext, ArgsT: type, allocator: std.mem.Allocator, args_shapes: anytype) !ArgsT {
const Local = struct { const Local = struct {
fn tensorFromShape(arg_id: *u64, shape: Shape) Tensor { fn tensorFromShape(arg_id: *u64, shape: Shape) Tensor {
defer arg_id.* += 1; defer arg_id.* += 1;
@ -951,29 +998,6 @@ fn assignBlockArguments(v: anytype, block: mlir.Block, start: usize) usize {
return context.index; return context.index;
} }
/// Visit the given struct and assign op results to each tensor found.
fn assignResults(op: mlir.Operation, v: anytype, shapes: []Shape) void {
const LocalContext = struct {
index: usize,
op: mlir.Operation,
shapes: ?[]Shape,
};
var context = LocalContext{ .index = 0, .op = op, .shapes = shapes };
meta.visit((struct {
fn cb(inner_ctx: *LocalContext, tensor: *Tensor) void {
var new = Tensor.fromMlirValue(inner_ctx.op.result(inner_ctx.index));
if (inner_ctx.shapes) |sh| {
new._shape = sh[inner_ctx.index];
} else {
new._shape._tags = tensor._shape._tags;
}
tensor.* = new;
inner_ctx.index += 1;
}
}).cb, &context, v);
std.debug.assert(context.index == op.numResults());
}
pub const XxHash64Writer = struct { pub const XxHash64Writer = struct {
hasher: *std.hash.XxHash64, hasher: *std.hash.XxHash64,
@ -1039,8 +1063,7 @@ pub const FnCache = struct {
const owned_value: MlirFn = .{ const owned_value: MlirFn = .{
.name = name_copy, .name = name_copy,
.mlir_fn = value.mlir_fn, .mlir_fn = value.mlir_fn,
.n_model = value.n_model, .num_args = value.num_args,
.n_args = value.n_args,
.res_types = res_types_copy, .res_types = res_types_copy,
.res_shapes = res_shapes_copy, .res_shapes = res_shapes_copy,
.res_donations = res_donations_copy, .res_donations = res_donations_copy,
@ -1055,47 +1078,55 @@ test FnCache {
const zml = @import("zml.zig"); const zml = @import("zml.zig");
const platform = zml.testing.env(); const platform = zml.testing.env();
const Layer = struct {
const Layer_ = @This();
w: Tensor,
b: Tensor,
pub fn forward(self: Layer_, x: Tensor) Tensor {
const wx = self.w.dotGeneral(x, &.{.{ -1, 0 }}, &.{});
return wx.add(self.b.broad(wx.shape())).relu();
}
};
const NN = struct { const NN = struct {
const NN_ = @This(); const NN_ = @This();
layer_weights: [3]Tensor, layers: [3]Layer,
layer_biases: [3]Tensor,
pub fn forward(self: NN_, x0: Tensor) Tensor { pub fn forward(self: NN_, x0: Tensor) Tensor {
var x = x0; var x = x0;
for (self.layer_weights, self.layer_biases) |w, b| { for (self.layers) |layer| {
// TODO use the `call` magic helper x = ops.call(layer, .forward, .{x});
// x = ops.callFunc(ctx, NN_, "reluLayer", .{ w, b, x });
x = NN_.reluLayer(w, b, x);
} }
return x; return x;
} }
pub fn forwardRefImpl(self: NN_, x0: Tensor) Tensor { pub fn forwardRefImpl(self: NN_, x0: Tensor) Tensor {
var x = x0; var x = x0;
for (self.layer_weights, self.layer_biases) |w, b| { for (self.layers) |layer| {
x = NN_.reluLayer(w, b, x); x = layer.forward(x);
} }
return x; return x;
} }
pub fn reluLayer(w: Tensor, b: Tensor, x: Tensor) Tensor {
const wx = w.dotGeneral(x, &.{.{ -1, 0 }}, &.{});
return wx.add(b.broadcastLeft(wx.shape())).relu();
}
}; };
const x = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ -1, 1 }); const x = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ -1, 1 });
const nn: zml.Bufferized(NN) = .{ const nn: zml.Bufferized(NN) = .{
.layer_weights = .{ .layers = .{
try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }), .{
try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }), .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }),
// third layer is different .b = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 0, 0 }),
try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }), },
.{
.w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }),
.b = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 10, 10 }),
},
// third layer is different
.{
.w = try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }),
.b = try zml.Buffer.fromSlice(platform, .{3}, &[_]f16{ -10, -10, -10 }),
}, },
.layer_biases = .{
try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 0, 0 }),
try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 10, 10 }),
try zml.Buffer.fromSlice(platform, .{3}, &[_]f16{ -10, -10, -10 }),
}, },
}; };
const res = try zml.testing.compileAndCall(platform, NN.forward, .{ nn, x }); const res = try zml.testing.compileAndCall(platform, NN.forward, .{ nn, x });

View File

@ -30,9 +30,10 @@ test {
/// Generate an MLIR call to the given member function with the given tensors. /// Generate an MLIR call to the given member function with the given tensors.
pub fn call(self: anytype, comptime func: stdx.meta.DeclEnum(@TypeOf(self)), args: anytype) @TypeOf(@call(.auto, @field(stdx.meta.UnwrapPtr(@TypeOf(self)), @tagName(func)), .{self} ++ args)) { pub fn call(self: anytype, comptime func: stdx.meta.DeclEnum(@TypeOf(self)), args: anytype) @TypeOf(@call(.auto, @field(stdx.meta.UnwrapPtr(@TypeOf(self)), @tagName(func)), .{self} ++ args)) {
// TODO: this should use `self.getContext().callFunc(self, args)` const ctx = CompilationContext.current();
const name = @typeName(@TypeOf(self)) ++ "." ++ @tagName(func);
return @call(.auto, @field(@TypeOf(self), @tagName(func)), .{self} ++ args); const actual_fn = @field(@TypeOf(self), @tagName(func));
return ctx.callFunc(name, actual_fn, .{self} ++ args);
} }
pub fn while_( pub fn while_(

View File

@ -138,7 +138,7 @@ pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bu
} }
}; };
var shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)) = undefined; var shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)) = undefined;
try meta.mapAlloc(Local.bufferToShape, allocator, {}, buffer_args, &shape_args); try meta.mapAlloc(Local.bufferToShape, arena.allocator(), {}, buffer_args, &shape_args);
const mod = try zml.compileFn(allocator, func, shape_args, platform); const mod = try zml.compileFn(allocator, func, shape_args, platform);
defer mod.deinit(); defer mod.deinit();