Implement func.call emission and function caching across MLIR dialects and ZML module/ops, propagating tags and donations.
parent 37de7b9613
commit 98b512c495
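For context, the user-visible change: `zml.ops.call` no longer inlines the callee's ops into the caller's body; it now emits the callee once as its own `func.func` (cached in `FnCache`), then emits a `func.call` per call site, propagating result tags and buffer donations across the call. A minimal sketch of the API, assuming hypothetical `Block`/`Model` structs that are not part of this diff:

    const zml = @import("zml");
    const Tensor = zml.Tensor;

    const Block = struct {
        bias: Tensor,
        pub fn forward(self: Block, x: Tensor) Tensor {
            return x.add(self.bias);
        }
    };

    const Model = struct {
        block: Block,
        pub fn forward(self: Model, x: Tensor) Tensor {
            // Emits roughly `func.func private @Block.forward_<hash>` once,
            // then a `func.call` per call site (sharing the cached function
            // when the argument hash matches). Before this commit, both
            // calls were inlined into the caller.
            const h = zml.ops.call(self.block, .forward, .{x});
            return zml.ops.call(self.block, .forward, .{h});
        }
    };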
@@ -13,8 +13,7 @@ pub fn func(
         location: mlir.Location,
     },
 ) mlir.Operation {
-    const AttrTuple = struct { [:0]const u8, mlir.Attribute };
-    var attrs_tuple_buffer = std.BoundedArray(AttrTuple, 4){};
+    var attrs_tuple_buffer = std.BoundedArray(mlir.AttrTuple, 4){};
     attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute).? });
     attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type).?).as(mlir.Attribute).? });
     if (args.arg_attrs.len > 0) {
@@ -189,7 +189,7 @@ pub fn dot_general(
     },
 ) mlir.Operation {
     const precisions = [1]mlir.Attribute{opts.precision.precisionAttr(ctx)} ** 2;
-    const attributes = [3]mlir.Operation.AttrTuple{
+    const attributes = [3]mlir.AttrTuple{
         .{
             "dot_dimension_numbers", DotDimensionNumbersAttribute.init(ctx, .{
                 .lhs_batching_dimensions = opts.lhs_batching_dimensions,
@@ -333,6 +333,8 @@ pub const Identifier = struct {
     }
 };

+pub const AttrTuple = struct { [:0]const u8, Attribute };
+
 pub const Attribute = struct {
     _inner: c.MlirAttribute,
     pub usingnamespace MlirHelpers(Attribute, .{
@@ -791,8 +793,6 @@ pub const Operation = struct {
         ) orelse Error.InvalidMlir;
     }

-    pub const AttrTuple = struct { [:0]const u8, Attribute };
-
    pub fn make(ctx: Context, op_name: [:0]const u8, args: struct {
        operands: ?[]const Value = null,
        variadic_operands: ?[]const []const Value = null,
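The four hunks above are one mechanical change: `AttrTuple` is hoisted out of `mlir.Operation` into the top-level `mlir` namespace, so dialect helpers (`func`, `dot_general`) no longer redeclare it locally or reach through `Operation`. A sketch of the resulting usage, where `ctx` and the attribute values are placeholders:

    var attrs = std.BoundedArray(mlir.AttrTuple, 4){};
    attrs.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, "my_fn").as(mlir.Attribute).? });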
36 zml/meta.zig
@@ -86,12 +86,14 @@ pub fn MapType(From: type, To: type) type {
 /// For example it can convert from a comptime array to a runtime slice.
 /// `mapAlloc` can allocate new slices to write the result if the result struct requires it.
 /// The caller is owning said allocations, using an `ArenaAllocator` might help tracking them.
-// TODO: handle tuple to slice conversion
+///
+/// Note: to avoid infinite loop, mapAlloc doesn't look for `From` fields inside `To` struct.
+/// Any `To` struct inside `from` will be copied over to the target.
 pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam(cb, 0), from: anytype, to: anytype) !void {
-    // const Ctx = FnParam(cb, 0);
+    // TODO: handle tuple to slice conversion
     const From = FnParam(cb, 1);
     const To = stdx.meta.FnResult(cb);
     const FromStruct = @TypeOf(from);

     const type_info_to_ptr = @typeInfo(@TypeOf(to));
     if (type_info_to_ptr != .Pointer) {
         stdx.debug.compileError("convertType is expecting a mutable `to` argument but received: {}", .{@TypeOf(to)});
@@ -100,11 +102,12 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
     const type_info_to = @typeInfo(ToStruct);

     if (FromStruct == From) {
-        // Special case for converting from shape to tensor:
-        // If the target type is Shape, skip tensor conversion.
-        // A general `to.* = from` assignment causes a Zig error in this scenario.
-        // (see below)
-        if (ToStruct == @import("shape.zig").Shape and FromStruct == ToStruct) { // FromStruct) {
+        // We have an issues with `Tensor` -> `Shape` -> `Tensor` conversion when compiling ZML functions where one argument is a Shape itself.
+        // Normally we should call `cb` on all `Shape`.
+        // But the "ShapeOf" struct will have more Shape than need on the output.
+        // So here we take a hint from the receiving object.
+        // If the target is indeed a Tensor, use the callback, but if the target is `Shape` just copy it over.
+        if (ToStruct != To and FromStruct == ToStruct) {
             to.* = from;
         } else {
             to.* = @call(.auto, cb, .{ ctx, from });
@@ -112,19 +115,11 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
         return;
     }

     // This is generally due to a user error, but let this fn compile,
     // and the user will have a Zig error.
-    if (FromStruct == ToStruct) {
+    if (FromStruct == To) {
         to.* = from;
         return;
     }

-    // Don't go into Shape objects because of the weird tag.
-    // TODO: we could not error on pointers to basic types like u8
-    if (FromStruct == @import("shape.zig").Shape) {
-        to.* = from;
-        return;
-    }
     switch (type_info_to) {
         .Struct => |info| inline for (info.fields) |field| {
             // if (field.is_comptime) continue;
@@ -155,10 +150,11 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
         .One => switch (type_info_to_ptr.Pointer.size) {
             // pointer to array -> slice promotion
             .Slice => {
-                to.* = try allocator.alloc(type_info_to_ptr.Pointer.child, from.len);
-                for (from, to.*) |f, *t| {
+                const items = try allocator.alloc(type_info_to_ptr.Pointer.child, from.len);
+                for (from, items) |f, *t| {
                     try mapAlloc(cb, allocator, ctx, f, t);
                 }
+                to.* = items;
             },
             else => try mapAlloc(cb, allocator, ctx, from.*, to.*),
         },
@@ -177,7 +173,7 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
             } else {
                 to.* = null;
         },
-        .Int, .Float => to.* = from,
+        .Int, .Float, .Enum => to.* = from,
         else => stdx.debug.compileError("zml.meta.mapAlloc doesn't support: {}", .{FromStruct}),
     }
 }

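These `mapAlloc` changes (slice results written back only once fully built, `.Enum` accepted, and the refined Shape/Tensor special case) matter for the new call path below: `callFunc` in zml/module.zig uses `mapAlloc` to renumber every `Tensor` in an argument tuple into fresh function arguments. A hedged usage sketch, assuming a hypothetical `args` tuple holding Tensors and an arena in scope:

    // mapAlloc(cb, allocator, ctx, from, to): walks `from`, applies `cb` to each
    // value matching the callback's parameter type, and writes the mapped
    // structure into `to`, allocating slices as needed (hence the arena advice).
    var next_id: u16 = 0;
    var renumbered: @TypeOf(args) = args;
    try zml.meta.mapAlloc(struct {
        fn cb(n: *u16, t: Tensor) Tensor {
            defer n.* += 1;
            return Tensor{ ._shape = t._shape, ._id = .{ .arg_id = n.* } };
        }
    }.cb, arena.allocator(), &next_id, args, &renumbered);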
231 zml/module.zig
@@ -49,7 +49,7 @@ const Block = union(BlockKind) {
             .op_result => |parent_op| self.appendOperationRecursive(parent_op),
             .block_argument => |arg| {
                 // Hermetic blocks are not allowed to use arguments from other blocks.
-                std.debug.assert(self == .open or self.block().eql(arg.block()));
+                stdx.debug.assert(self == .open or self.block().eql(arg.block()), "Can't add {} from {?x} block to {?x} block", .{ arg, arg.block()._inner.ptr, self.block()._inner.ptr });
             },
             .null => @panic("InvalidMlir"),
         }
@@ -75,6 +75,11 @@ pub const MlirFn = struct {
     res_shapes: []Shape,
     res_donations: []Tensor._Donation,
     mlir_fn: mlir.Operation,
+
+    pub const Kind = enum {
+        main,
+        private,
+    };
 };

 pub const CompilationContext = struct {
@@ -151,7 +156,6 @@ pub const CompilationContext = struct {
     pub fn deactivate(self: *CompilationContext) void {
         std.debug.assert(_current != null and _current.? == self);
         _current = self._previous;
-        self._previous = null;
     }

     pub fn current() *CompilationContext {
@@ -182,9 +186,9 @@ pub const CompilationContext = struct {
         const arena = arena_state.allocator();

         var timer = std.time.Timer.start() catch null;
-        const tensor_args = self.tensorFromShapes(stdx.meta.FnArgs(func), arena, args);
+        const tensor_args = try self.tensorFromShapes(stdx.meta.FnArgs(func), arena, args);
         // Run in a dedicated thread because compilation relies on `threadlocal`.
-        const f = try asynk.callBlocking(CompilationContext.generateBytecode, .{ self, arena, "main", func, &tensor_args });
+        const f = try asynk.callBlocking(CompilationContext.emitMlir, .{ self, arena, func, &tensor_args, .{ .name = "main", .kind = .main } });
         const module = self._module;
         module.getBody().appendOperation(f.mlir_fn);

@@ -329,15 +333,18 @@ pub const CompilationContext = struct {
     /// Generate an MLIR function from a ZML function.
     /// The caller is responsible to have properly created the input
     /// tensors with unique tensor ids.
-    pub fn generateBytecode(
+    pub fn emitMlir(
         self: *CompilationContext,
         allocator: std.mem.Allocator,
-        fn_name: []const u8,
         comptime func: anytype,
         args: *const stdx.meta.FnArgs(func),
+        opts: struct {
+            name: []const u8,
+            kind: MlirFn.Kind = .private,
+        },
     ) error{OutOfMemory}!MlirFn {
-        const frame = self._tracer.frameStart("generateBytecode.emit");
-        errdefer self._tracer.frameEnd(frame, "generateBytecode.emit");
+        const frame = self._tracer.frameStart("emitMlir.emit");
+        errdefer self._tracer.frameEnd(frame, "emitMlir.emit");

         // Note: only temp allocations are done in the arena,
         // the other allocations are managed by the caller.
@@ -371,11 +378,6 @@ pub const CompilationContext = struct {
         var fn_body = self.openBlock(.hermetic, input_types, locations) catch unreachable;
         {
             defer self.closeBlock(fn_body);
-            // Note: we could shrink self._buffer_to_arg once we called `func`.
-            // But for now we are only compiling one function per CompilationContext.
-            // So we don't need to do this since we won't reuse self._buffer_to_arg anyway.
-            // const n = self._buffer_to_arg.count();
-            // defer self._buffer_to_arg.shrinkRetainingCapacity(n);

             try self._buffer_to_arg.ensureUnusedCapacity(self._allocator, @intCast(tensor_count));
             const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), 0);
@@ -400,14 +402,15 @@ pub const CompilationContext = struct {
         const res_attrs = try arena.alloc(AttributeList, out_tensor_count);
         @memset(res_attrs, .{});

+        // Donations attributes only make sense on the main function.
+        if (opts.kind == .main) {
             self.addDonationsAttributes(arg_attrs, fn_res_donations);

             if (self._platform.sharding().num_partitions > 1) {
                 self.addShardingAttributes(arg_attrs, res_attrs, input_shapes.items, fn_res_shapes);
             }
+        }

         const mlir_fn = dialect.func.func(self.mlirCtx(), .{
-            .sym_name = fn_name,
+            .sym_name = opts.name,
             .args = input_types,
             .arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs),
             .results = fn_res_types,
@@ -416,9 +419,9 @@ pub const CompilationContext = struct {
             .location = loc,
         });

-        self._tracer.frameEnd(frame, "generateBytecode.emit");
-        const canonicalize_frame = self._tracer.frameStart("generateBytecode.canonicalize");
-        defer self._tracer.frameEnd(canonicalize_frame, "generateBytecode.canonicalize");
+        self._tracer.frameEnd(frame, "emitMlir.emit");
+        const canonicalize_frame = self._tracer.frameStart("emitMlir.canonicalize");
+        defer self._tracer.frameEnd(canonicalize_frame, "emitMlir.canonicalize");
         self._mlir_canonicalizer.runOnOp(mlir_fn) catch |err| switch (err) {
             error.InvalidMlir => {
                 log.err("Failed to canonicalize invalid mlir: {}", .{mlir_fn.mlirFormatter(.{})});
@@ -429,7 +432,7 @@ pub const CompilationContext = struct {

         return .{
             .mlir_fn = mlir_fn,
-            .name = fn_name,
+            .name = opts.name,
             .num_args = @intCast(tensor_count),
             .res_types = fn_res_types,
             .res_shapes = fn_res_shapes,
@@ -478,7 +481,13 @@ pub const CompilationContext = struct {
         const Local = struct {
             bias: Tensor,

-            pub fn forward(self: @This(), x: Tensor) Tensor {
+            pub fn forward(self: @This(), x: Tensor, y: Tensor) [2]Tensor {
+                const x1 = zml.ops.call(self, .inner, .{x});
+                const x2 = zml.ops.call(self, .inner, .{x1});
+                return .{ x1.reuseBuffer(y), x2 };
+            }
+
+            pub fn inner(self: @This(), x: Tensor) Tensor {
                 const y = x.add(self.bias);
                 return y.reuseBuffer(x);
             }
@@ -490,20 +499,26 @@ pub const CompilationContext = struct {

         var comp = try zml.module.CompilationContext.init(allocator, "test", platform);
         defer comp.deinit();
-        var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .arg_id = 1234 } } };
-        const f = try comp.generateBytecode(allocator, "test.generateBytecode.Local.forward", Local.forward, &tensor_args);
+        var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } };
+        const f = try comp.emitMlir(allocator, Local.forward, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main });

         var mlir_bytecode: std.ArrayListUnmanaged(u8) = .{};
         try mlir_bytecode.writer(allocator).print("{}", .{f.mlir_fn.mlirFormatter(.{})});

         // Check that the `x` input argument gives its buffer to the result tensor.
-        // `%arg0` is the bias of the model, `%arg1` is `x`.
-        try std.testing.expectEqual(2, f.num_args);
-        std.testing.expect(std.mem.indexOf(u8, mlir_bytecode.items, "tf.aliasing_output = 0 : i32") != null) catch |err| {
+        // `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`.
+        try std.testing.expectEqual(3, f.num_args);
+        // We should have two buffers being donated.
+        const template = "tf.aliasing_output = {d} : i32";
+        var buf = template.*;
+        for (0..2) |i| {
+            const alias_attr = std.fmt.bufPrint(&buf, template, .{i}) catch unreachable;
+            std.testing.expect(std.mem.indexOf(u8, mlir_bytecode.items, alias_attr) != null) catch |err| {
                 log.warn("Didn't produced the expected IR:\n{s}", .{mlir_bytecode.items});
                 return err;
             };
+        }
     }

     pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.StringAttribute {
         const mlir_ctx = self.mlirCtx();
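The reworked test above is the donation contract in miniature: `reuseBuffer` marks a result as taking over an input's buffer, and on a `.main` function each donated argument surfaces as a `tf.aliasing_output = <n> : i32` attribute. An annotated restatement of the test's `forward` (the comments are added here, not part of the diff):

    pub fn forward(self: @This(), x: Tensor, y: Tensor) [2]Tensor {
        // inner() ends with reuseBuffer(x), so x's donation flows through the call.
        const x1 = zml.ops.call(self, .inner, .{x});
        const x2 = zml.ops.call(self, .inner, .{x1});
        // x1 additionally takes over y's buffer: two arguments end up carrying
        // tf.aliasing_output attributes, which the test greps for.
        return .{ x1.reuseBuffer(y), x2 };
    }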
@@ -608,45 +623,77 @@ pub const CompilationContext = struct {

-        // first, do the "compile" and check the bytecode
-        // the result of this will also have the correct tags of the result shapes
-        const dummy_result = self.generateMlirBytecodeForFunction(
-            arena,
-            func_name,
-            func,
-            args,
-        ) catch unreachable; // TODO: do we like unreachable?
-        const bytecode_hash = hashArgs(dummy_result.bytecode_tmp);
-
-        const key: FnCache.Key = .{ .fn_ptr = &func, .input_hash = bytecode_hash };
+        const args_hash = hashArgs(args);
+        const key: FnCache.Key = .{ .fn_ptr = &func, .input_hash = args_hash };
         const function = self._fn_cache.getEntry(key) orelse b: {
             const full_name: [:0]const u8 = if (std.mem.eql(u8, "main", func_name))
                 arena.dupeZ(u8, func_name) catch unreachable
             else
                 std.fmt.allocPrintZ(arena, "{s}_{x}", .{ func_name, key.input_hash }) catch unreachable;

-            log.info("addFuncToModule {any} {s}", .{ key, full_name });
+            const og_buffer_to_arg = self._buffer_to_arg;
+            defer {
+                self._buffer_to_arg.deinit(self._allocator);
+                self._buffer_to_arg = og_buffer_to_arg;
+            }

-            const value = self.addFuncToModule(
-                arena,
-                full_name,
-                func,
-                args,
-            ) catch unreachable;
-            // Reset the buffer -> assignement
-            self._buffer_to_arg = .{};
-
-            break :b self._fn_cache.addEntry(key, value) catch unreachable;
+            var arg_id: u16 = 0;
+            var tensor_args: @TypeOf(args) = args;
+            meta.mapAlloc(struct {
+                fn cb(arg_id_: *u16, x: Tensor) Tensor {
+                    const a = arg_id_.*;
+                    arg_id_.* += 1;
+                    return Tensor{ ._shape = x._shape, ._id = .{ .arg_id = a }, ._donation = .{ .arg = a } };
+                }
+            }.cb, self._allocator, &arg_id, args, &tensor_args) catch @panic("OutOfMemory");
+
+            const f = self.emitMlir(arena, func, &tensor_args, .{
+                .name = full_name,
+            }) catch @panic("OOM");
+            self._module.getBody().appendOperation(f.mlir_fn);
+
+            break :b self._fn_cache.addEntry(key, f) catch unreachable;
         };

         // Note: we won't increase the size of the cache until next `call` so
         // we can use the memory there without worrying about fragmentation.

         const loc = self.mlirCtx().location(@src());

-        const values = arena.alloc(mlir.Value, function.n_model + function.n_args) catch unreachable;
-        self.extractValues(&args, values[function.n_model..]);
+        const values = arena.alloc(mlir.Value, function.num_args) catch unreachable;
+        self.extractValues(&args, values);

-        const op = dialect.func.call(self.mlirCtx(), function.name, values, function.res_types, loc);
-        // TODO: tags seem to be lost by `callFunc`.
+        const donations = arena.alloc(Tensor._Donation, function.num_args) catch unreachable;
+        meta.collectBuf(struct {
+            pub fn cb(ctx: *const CompilationContext, x: Tensor) Tensor._Donation {
+                return ctx.getValueAndDonation(x)[1];
+            }
+        }.cb, self, &args, donations);
+
+        const op = dialect.func.call(self.mlirCtx(), @ptrCast(function.name), values, function.res_types, loc);
+        // Create the result tensor object by combining the operand results,
+        // as well as the registered shapes and donations.
+        // Note: this assume res can be stack-allocated.
+        // Maybe it'd be simpler to just call the Zig function twice to do the shape/donation propagation for us.
+        // But this is blocked on https://github.com/zml/zml/issues/97
         var res: stdx.meta.FnResult(func) = undefined;
-        assignResults(op, &res, function.res_shapes);
+        const LocalContext = struct { index: usize = 0, op: mlir.Operation, function: MlirFn, donations: []Tensor._Donation };
+        var context: LocalContext = .{ .op = op, .function = function, .donations = donations };
+        meta.visit((struct {
+            fn cb(ctx: *LocalContext, tensor: *Tensor) void {
+                const i = ctx.index;
+                ctx.index += 1;
+                var new = Tensor.fromMlirValue(ctx.op.result(i));
+                new._shape = ctx.function.res_shapes[i];
+                new._donation = switch (ctx.function.res_donations[i]) {
+                    .no_buffer => .no_buffer,
+                    .arg => |input_arg| ctx.donations[input_arg],
+                    .input_buffer => .no_buffer, // user escaped the sandbox
+                };
+                tensor.* = new;
+            }
+        }).cb, &context, &res);
+        std.debug.assert(context.index == op.numResults());
         return res;
     }
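The cache rework above is the heart of the commit: previously the callee was speculatively compiled and its bytecode hashed to form the cache key; now the key is the function pointer plus a hash of the arguments, so compilation only happens on a miss. A condensed paraphrase of the new lookup, using the same names as the hunk:

    const key: FnCache.Key = .{ .fn_ptr = &func, .input_hash = hashArgs(args) };
    const function = self._fn_cache.getEntry(key) orelse blk: {
        // Miss: renumber the tensor args, emit the func.func once, register it.
        const f = self.emitMlir(arena, func, &tensor_args, .{ .name = full_name }) catch @panic("OOM");
        self._module.getBody().appendOperation(f.mlir_fn);
        break :blk self._fn_cache.addEntry(key, f) catch unreachable;
    };
    // Hit or miss, only a func.call against the cached symbol is emitted here.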
@@ -669,7 +716,7 @@ pub const CompilationContext = struct {

         const res = ctx.self._buffer_to_arg.getOrPutAssumeCapacity(tensor._id);
         if (res.found_existing) {
-            stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {} and {}({}).", .{ res.key_ptr.*, tensor, tensor._id });
+            stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {} and {} at index {} ({}).", .{ res.value_ptr.*[0], tensor, ctx.index, tensor._id });
         } else {
             res.value_ptr.* = .{ arg_value, .{ .arg = @intCast(ctx.index) } };
         }
@@ -681,7 +728,7 @@ pub const CompilationContext = struct {

     /// Create tensor from the given shapes.
     /// Each created tensor will receive a unique id, local to this CompilationContext.
-    pub fn tensorFromShapes(self: *CompilationContext, ArgsT: type, allocator: std.mem.Allocator, args_shapes: anytype) ArgsT {
+    pub fn tensorFromShapes(self: *CompilationContext, ArgsT: type, allocator: std.mem.Allocator, args_shapes: anytype) !ArgsT {
         const Local = struct {
             fn tensorFromShape(arg_id: *u64, shape: Shape) Tensor {
                 defer arg_id.* += 1;
@@ -951,29 +998,6 @@ fn assignBlockArguments(v: anytype, block: mlir.Block, start: usize) usize {
     return context.index;
 }

-/// Visit the given struct and assign op results to each tensor found.
-fn assignResults(op: mlir.Operation, v: anytype, shapes: []Shape) void {
-    const LocalContext = struct {
-        index: usize,
-        op: mlir.Operation,
-        shapes: ?[]Shape,
-    };
-    var context = LocalContext{ .index = 0, .op = op, .shapes = shapes };
-    meta.visit((struct {
-        fn cb(inner_ctx: *LocalContext, tensor: *Tensor) void {
-            var new = Tensor.fromMlirValue(inner_ctx.op.result(inner_ctx.index));
-            if (inner_ctx.shapes) |sh| {
-                new._shape = sh[inner_ctx.index];
-            } else {
-                new._shape._tags = tensor._shape._tags;
-            }
-            tensor.* = new;
-            inner_ctx.index += 1;
-        }
-    }).cb, &context, v);
-    std.debug.assert(context.index == op.numResults());
-}
-
 pub const XxHash64Writer = struct {
     hasher: *std.hash.XxHash64,

@@ -1039,8 +1063,7 @@ pub const FnCache = struct {
         const owned_value: MlirFn = .{
             .name = name_copy,
             .mlir_fn = value.mlir_fn,
-            .n_model = value.n_model,
-            .n_args = value.n_args,
+            .num_args = value.num_args,
             .res_types = res_types_copy,
             .res_shapes = res_shapes_copy,
             .res_donations = res_donations_copy,
@@ -1055,47 +1078,55 @@ test FnCache {
     const zml = @import("zml.zig");
     const platform = zml.testing.env();

+    const Layer = struct {
+        const Layer_ = @This();
+
+        w: Tensor,
+        b: Tensor,
+
+        pub fn forward(self: Layer_, x: Tensor) Tensor {
+            const wx = self.w.dotGeneral(x, &.{.{ -1, 0 }}, &.{});
+            return wx.add(self.b.broad(wx.shape())).relu();
+        }
+    };
+
     const NN = struct {
         const NN_ = @This();
-        layer_weights: [3]Tensor,
-        layer_biases: [3]Tensor,
+        layers: [3]Layer,

         pub fn forward(self: NN_, x0: Tensor) Tensor {
             var x = x0;
-            for (self.layer_weights, self.layer_biases) |w, b| {
-                // TODO use the `call` magic helper
-                // x = ops.callFunc(ctx, NN_, "reluLayer", .{ w, b, x });
-                x = NN_.reluLayer(w, b, x);
+            for (self.layers) |layer| {
+                x = ops.call(layer, .forward, .{x});
             }
             return x;
         }

         pub fn forwardRefImpl(self: NN_, x0: Tensor) Tensor {
             var x = x0;
-            for (self.layer_weights, self.layer_biases) |w, b| {
-                x = NN_.reluLayer(w, b, x);
+            for (self.layers) |layer| {
+                x = layer.forward(x);
             }
             return x;
         }
-
-        pub fn reluLayer(w: Tensor, b: Tensor, x: Tensor) Tensor {
-            const wx = w.dotGeneral(x, &.{.{ -1, 0 }}, &.{});
-            return wx.add(b.broadcastLeft(wx.shape())).relu();
-        }
     };

     const x = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ -1, 1 });
     const nn: zml.Bufferized(NN) = .{
-        .layer_weights = .{
-            try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }),
-            try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }),
-            // third layer is different
-            try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }),
-        },
+        .layers = .{
+            .{
+                .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }),
+                .b = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 0, 0 }),
+            },
+            .{
+                .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }),
+                .b = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 10, 10 }),
+            },
+            // third layer is different
+            .{
+                .w = try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }),
+                .b = try zml.Buffer.fromSlice(platform, .{3}, &[_]f16{ -10, -10, -10 }),
+            },
+        },
-        .layer_biases = .{
-            try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 0, 0 }),
-            try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 10, 10 }),
-            try zml.Buffer.fromSlice(platform, .{3}, &[_]f16{ -10, -10, -10 }),
-        },
     };
     const res = try zml.testing.compileAndCall(platform, NN.forward, .{ nn, x });

@@ -30,9 +30,10 @@ test {

 /// Generate an MLIR call to the given member function with the given tensors.
 pub fn call(self: anytype, comptime func: stdx.meta.DeclEnum(@TypeOf(self)), args: anytype) @TypeOf(@call(.auto, @field(stdx.meta.UnwrapPtr(@TypeOf(self)), @tagName(func)), .{self} ++ args)) {
-    // TODO: this should use `self.getContext().callFunc(self, args)`
-
-    return @call(.auto, @field(@TypeOf(self), @tagName(func)), .{self} ++ args);
+    const ctx = CompilationContext.current();
+    const name = @typeName(@TypeOf(self)) ++ "." ++ @tagName(func);
+    const actual_fn = @field(@TypeOf(self), @tagName(func));
+    return ctx.callFunc(name, actual_fn, .{self} ++ args);
 }

 pub fn while_(
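Note that the new `call` body resolves `CompilationContext.current()`, so it only works while a compilation is active (inside a function being traced by `compileFn`/`compileAndCall`); the return type is still computed from the plain Zig call, so caller signatures are unchanged. Usage stays as in the FnCache test above:

    // Inside a traced forward(), with `layer: Layer` as in the test:
    x = zml.ops.call(layer, .forward, .{x});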
@@ -138,7 +138,7 @@ pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bu
         }
     };
     var shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)) = undefined;
-    try meta.mapAlloc(Local.bufferToShape, allocator, {}, buffer_args, &shape_args);
+    try meta.mapAlloc(Local.bufferToShape, arena.allocator(), {}, buffer_args, &shape_args);

     const mod = try zml.compileFn(allocator, func, shape_args, platform);
     defer mod.deinit();
