diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig index c20dd85..af95f10 100644 --- a/mlir/dialects/stablehlo.zig +++ b/mlir/dialects/stablehlo.zig @@ -288,79 +288,6 @@ fn elementTypeOrSelf(typ: mlir.Type) mlir.Type { } else typ; } -pub fn scatter( - ctx: mlir.Context, - // inputs - inputs: []const mlir.Value, - scatter_indices: mlir.Value, - updates: []const mlir.Value, - // input functions - update_ctx: anytype, // for update_fn - update_fn: fn (anytype, mlir.Context, []const mlir.Value, []const mlir.Value) mlir.Operation, - // attributes - args: struct { - update_window_dims: []const i64, - inserted_window_dims: []const i64, - input_batching_dims: []const i64, - scatter_indices_batching_dims: []const i64, - scatter_dims_to_operand_dims: []const i64, - index_vector_dim: i64, - indices_are_sorted: bool = false, - unique_indices: bool = false, - }, - // zml loc - location: mlir.Location, -) mlir.Operation { - // create block for update_fn - const MaxBlockArguments = 32; // TODO(rene): where does this 32 come from? - // taken from reduce - - const block_n_args = inputs.len * 2; // TODO(rene): is this correct? yes, passes tests: block_inputs plus block_accumulators = inputs - const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0..block_n_args]; - var scatter_elem_types: [MaxBlockArguments]mlir.Type = undefined; - for (inputs, 0..) |input, i| { - const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type).?; - scatter_elem_types[i] = arg_type; - scatter_elem_types[inputs.len + i] = arg_type; - } - - var block = mlir.Block.open(scatter_elem_types[0..block_n_args], locations) catch unreachable; - { - defer block.close(); - var block_inputs: [MaxBlockArguments / 2]mlir.Value = undefined; - var block_accs: [MaxBlockArguments / 2]mlir.Value = undefined; - for (0..inputs.len) |i| { - block_inputs[i] = block.argument(i); - block_accs[i] = block.argument(inputs.len + i); - } - _ = update_fn(update_ctx, ctx, block_inputs[0..inputs.len], block_accs[0..inputs.len]); - } - return mlir.Operation.make( - ctx, - "stablehlo.scatter", - .{ - .variadic_operands = &.{ inputs, &.{scatter_indices}, updates }, - // .blocks = &.{block}, - .block = block, - .attributes = &.{ - .{ "scatter_dimension_numbers", ScatterDimensionNumbersAttribute.init( - ctx, - args.update_window_dims, - args.inserted_window_dims, - args.input_batching_dims, - args.scatter_indices_batching_dims, - args.scatter_dims_to_operand_dims, - args.index_vector_dim, - ).as(mlir.Attribute).? }, - .{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute).? }, - .{ "unique_indices", mlir.BoolAttribute.init(ctx, args.unique_indices).as(mlir.Attribute).? }, - }, - .result_type_inference = true, - .location = location, - }, - ); -} - pub fn iota(ctx: mlir.Context, dimension: i64, result_type: mlir.Type, location: mlir.Location) mlir.Operation { return mlir.Operation.make(ctx, "stablehlo.iota", .{ .operands = &.{}, @@ -439,66 +366,6 @@ pub fn reduce( }); } -pub const ReduceWindowOpts = struct { - window_dimensions: []const i64, - window_strides: []const i64, - base_dilations: []const i64, - window_dilations: []const i64, - padding_values: []const i64, - padding_shape: []const i64, -}; - -// pub fn reduce_window( -// ctx: mlir.Context, -// inputs: []const mlir.Value, -// init_values: []const mlir.Value, -// opts: ReduceWindowOpts, -// blkctx: anytype, -// blkfn: fn (anytype, mlir.Context, []const mlir.Value, []const mlir.Value) mlir.Operation, -// location: mlir.Location, -// ) mlir.Operation { -// // TODO: move to ops.zig, and refactor similar to `reduce` -// const MaxBlockArguments = 32; - -// const block_n_args = inputs.len + init_values.len; -// const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0..block_n_args]; -// var reduce_elem_types: [MaxBlockArguments]mlir.Type = undefined; -// for (inputs, 0..) |input, i| { -// const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type).?; -// reduce_elem_types[i] = arg_type; -// reduce_elem_types[inputs.len + i] = arg_type; -// } -// const module = @import("../module.zig"); -// const comp = module.getCompilationContext(); -// var block = comp.openBlock(reduce_elem_types[0..block_n_args], locations) catch unreachable; -// { -// defer comp.closeBlock(block); - -// var block_inputs: [MaxBlockArguments / 2]mlir.Value = undefined; -// var block_accs: [MaxBlockArguments / 2]mlir.Value = undefined; -// for (0..inputs.len) |i| { -// block_inputs[i] = block.argument(i); -// block_accs[i] = block.argument(inputs.len + i); -// } -// _ = blkfn(blkctx, ctx, block_inputs[0..inputs.len], block_accs[0..init_values.len]); -// } - -// const pad_shape = mlir.RankedTensorType.init(opts.padding_shape, DataType.i64.mlirType(ctx)).as(mlir.Type).?; -// return mlir.Operation.make(ctx, "stablehlo.reduce_window", .{ -// .variadic_operands = &.{ inputs, init_values }, -// .result_type_inference = true, -// .blocks = &.{block}, -// .attributes = &.{ -// .{ "window_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_dimensions).as(mlir.Attribute).? }, -// .{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_strides).as(mlir.Attribute).? }, -// .{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx, opts.base_dilations).as(mlir.Attribute).? }, -// .{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_dilations).as(mlir.Attribute).? }, -// .{ "padding", mlir.DenseIntOrFPElementsAttribute(.i64).init(pad_shape, std.mem.sliceAsBytes(opts.padding_values)).as(mlir.Attribute).? }, -// }, -// .location = location, -// }); -// } - pub fn sort( ctx: mlir.Context, inputs: []const mlir.Value, diff --git a/mlir/mlir.zig b/mlir/mlir.zig index 6eda578..f3ccc70 100644 --- a/mlir/mlir.zig +++ b/mlir/mlir.zig @@ -691,10 +691,6 @@ pub const OperationState = struct { c.mlirOperationStateAddOwnedRegions(self.innerPtr(), @intCast(regions.len), @ptrCast(regions.ptr)); } - // pub fn addSuccessor(self: *Self, successor: Operation) void { - // c.mlirOperationStateAddSuccessors(self.innerPtr(), 1, &[_]c.MlirOperation{successor.inner()}); - // } - pub fn addAttribute(self: *Self, ctx: Context, name: [:0]const u8, attr: Attribute) void { c.mlirOperationStateAddAttributes(self.innerPtr(), 1, @ptrCast(&.{ .{ @@ -745,9 +741,9 @@ pub const DictionaryAttribute = struct { return NamedAttribute.wrap(c.mlirDictionaryAttrGetElement(self.inner(), @intCast(pos))); } - // pub fn getByName(self: Self, name: [:0]const u8) ?NamedAttribute { - // return NamedAttribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self.inner(), name)); - // } + pub fn getByName(self: Self, name: [:0]const u8) ?NamedAttribute { + return NamedAttribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self.inner(), name)); + } }; pub const Operation = struct { @@ -1519,276 +1515,6 @@ pub const DialectHandle = struct { } }; -// pub const AnyQuantizedType = MlirWrapperType(c.MlirType, .{ -// .is_a_fn = c.mlirTypeIsAAnyQuantizedType, -// .is_null_fn = c.mlirTypeIsNull, -// .dump_fn = c.mlirTypeDump, -// .equal_fn = c.mlirTypeEqual, -// }, struct { -// const Self = AnyQuantizedType; - -// pub fn init( -// flags: quant.QuantizationFlags, -// storageType: Type, -// expressedType: Type, -// storageTypeMin: i64, -// storageTypeMax: i64, -// ) Self { -// return Self.wrap(c.mlirAnyQuantizedTypeGet( -// @intCast(@intFromEnum(flags)), -// storageType.inner(), -// expressedType.inner(), -// storageTypeMin, -// storageTypeMax, -// )); -// } - -// pub fn getExpressedType(self: Self) Type { -// return Type.wrap(c.mlirQuantizedTypeGetExpressedType(self.inner())); -// } - -// pub fn getFlags(self: Self) quant.QuantizationFlags { -// return @enumFromInt(c.mlirQuantizedTypeGetFlags(self.inner())); -// } - -// pub fn isSigned(self: Self) bool { -// return c.mlirQuantizedTypeIsSigned(self.inner()); -// } - -// pub fn getStorageType(self: Self) Type { -// return Type.wrap(c.mlirQuantizedTypeGetStorageType(self.inner())); -// } - -// pub fn getStorageTypeMin(self: Self) i64 { -// return c.mlirQuantizedTypeGetStorageTypeMin(self.inner()); -// } - -// pub fn getStorageTypeMax(self: Self) i64 { -// return c.mlirQuantizedTypeGetStorageTypeMax(self.inner()); -// } - -// pub fn getStorageTypeIntegralWidth(self: Self) c_uint { -// return c.mlirQuantizedTypeGetStorageTypeIntegralWidth(self.inner()); -// } - -// pub fn getQuantizedElementType(self: Self) Type { -// return Type.wrap(c.mlirQuantizedTypeGetQuantizedElementType(self.inner())); -// } -// }); - -// pub const UniformQuantizedType = MlirWrapperType(c.MlirType, .{ -// .is_a_fn = c.mlirTypeIsAUniformQuantizedType, -// .is_null_fn = c.mlirTypeIsNull, -// .dump_fn = c.mlirTypeDump, -// .equal_fn = c.mlirTypeEqual, -// }, struct { -// const Self = AnyQuantizedType; - -// pub fn init( -// flags: quant.QuantizationFlags, -// storageType: Type, -// expressedType: Type, -// scale: f64, -// zeroPoint: i64, -// storageTypeMin: i64, -// storageTypeMax: i64, -// ) Self { -// return Self.wrap(c.mlirUniformQuantizedTypeGet( -// @intCast(@intFromEnum(flags)), -// storageType.inner(), -// expressedType.inner(), -// scale, -// zeroPoint, -// storageTypeMin, -// storageTypeMax, -// )); -// } - -// pub fn getExpressedType(self: Self) Type { -// return Type.wrap(c.mlirQuantizedTypeGetExpressedType(self.inner())); -// } - -// pub fn getFlags(self: Self) quant.QuantizationFlags { -// return @enumFromInt(c.mlirQuantizedTypeGetFlags(self.inner())); -// } - -// pub fn isSigned(self: Self) bool { -// return c.mlirQuantizedTypeIsSigned(self.inner()); -// } - -// pub fn getStorageType(self: Self) Type { -// return Type.wrap(c.mlirQuantizedTypeGetStorageType(self.inner())); -// } - -// pub fn getStorageTypeMin(self: Self) i64 { -// return c.mlirQuantizedTypeGetStorageTypeMin(self.inner()); -// } - -// pub fn getStorageTypeMax(self: Self) i64 { -// return c.mlirQuantizedTypeGetStorageTypeMax(self.inner()); -// } - -// pub fn getStorageTypeIntegralWidth(self: Self) c_uint { -// return c.mlirQuantizedTypeGetStorageTypeIntegralWidth(self.inner()); -// } - -// pub fn getQuantizedElementType(self: Self) Type { -// return Type.wrap(c.mlirQuantizedTypeGetQuantizedElementType(self.inner())); -// } - -// pub fn getScale(self: Self) f64 { -// return c.mlirUniformQuantizedTypeGetScale(self.inner()); -// } - -// pub fn getZeroPoint(self: Self) i64 { -// return c.mlirUniformQuantizedTypeGetZeroPoint(self.inner()); -// } - -// pub fn isFixedPoint(self: Self) bool { -// return c.mlirUniformQuantizedTypeIsFixedPoint(self.inner()); -// } -// }); - -// pub const QuantizedPerAxisType = MlirWrapperType(c.MlirType, .{ -// .is_a_fn = c.mlirTypeIsAUniformQuantizedPerAxisType, -// .is_null_fn = c.mlirTypeIsNull, -// .dump_fn = c.mlirTypeDump, -// .equal_fn = c.mlirTypeEqual, -// }, struct { -// const Self = AnyQuantizedType; - -// pub fn init( -// flags: quant.QuantizationFlags, -// storageType: Type, -// expressedType: Type, -// nDims: usize, -// scales: []f64, -// zeroPoints: []i64, -// quantizedDimension: i32, -// storageTypeMin: i64, -// storageTypeMax: i64, -// ) Self { -// std.debug.assert(scales.len == nDims); -// std.debug.assert(zeroPoints.len == nDims); -// return Self.wrap(c.mlirUniformQuantizedPerAxisTypeGet( -// @intCast(@intFromEnum(flags)), -// storageType.inner(), -// expressedType.inner(), -// @intCast(nDims), -// scales.ptr, -// zeroPoints.ptr, -// quantizedDimension, -// storageTypeMin, -// storageTypeMax, -// )); -// } - -// pub fn getExpressedType(self: Self) Type { -// return Type.wrap(c.mlirQuantizedTypeGetExpressedType(self.inner())); -// } - -// pub fn getFlags(self: Self) quant.QuantizationFlags { -// return @enumFromInt(c.mlirQuantizedTypeGetFlags(self.inner())); -// } - -// pub fn isSigned(self: Self) bool { -// return c.mlirQuantizedTypeIsSigned(self.inner()); -// } - -// pub fn getStorageType(self: Self) Type { -// return Type.wrap(c.mlirQuantizedTypeGetStorageType(self.inner())); -// } - -// pub fn getStorageTypeMin(self: Self) i64 { -// return c.mlirQuantizedTypeGetStorageTypeMin(self.inner()); -// } - -// pub fn getStorageTypeMax(self: Self) i64 { -// return c.mlirQuantizedTypeGetStorageTypeMax(self.inner()); -// } - -// pub fn getStorageTypeIntegralWidth(self: Self) c_uint { -// return c.mlirQuantizedTypeGetStorageTypeIntegralWidth(self.inner()); -// } - -// pub fn getQuantizedElementType(self: Self) Type { -// return Type.wrap(c.mlirQuantizedTypeGetQuantizedElementType(self.inner())); -// } - -// pub fn getNumDims(self: Self) usize { -// return @intCast(c.mlirUniformQuantizedPerAxisTypeGetNumDims(self.inner())); -// } - -// pub fn getScale(self: Self) f64 { -// return @intCast(c.mlirUniformQuantizedPerAxisTypeGetScale(self.inner())); -// } - -// pub fn getZeroPoint(self: Self, pos: usize) i64 { -// return c.mlirUniformQuantizedPerAxisTypeGetZeroPoint(self.inner(), @intCast(pos)); -// } - -// pub fn getQuantizedDimension(self: Self) i32 { -// return c.mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(self.inner()); -// } - -// pub fn isFixedPoint(self: Self) bool { -// return c.mlirUniformQuantizedPerAxisTypeIsFixedPoint(self.inner()); -// } -// }); - -// pub const CalibratedQuantizedType = MlirWrapperType(c.MlirType, .{ -// .is_a_fn = c.mlirTypeIsACalibratedQuantizedType, -// .is_null_fn = c.mlirTypeIsNull, -// .dump_fn = c.mlirTypeDump, -// .equal_fn = c.mlirTypeEqual, -// }, struct { -// const Self = AnyQuantizedType; - -// pub fn init(expressedType: Type, min: f64, max: f64) Self { -// return Self.wrap(c.mlirCalibratedQuantizedTypeGet(expressedType.inner(), min, max)); -// } - -// pub fn getExpressedType(self: Self) Type { -// return Type.wrap(c.mlirQuantizedTypeGetExpressedType(self.inner())); -// } - -// pub fn getFlags(self: Self) quant.QuantizationFlags { -// return @enumFromInt(c.mlirQuantizedTypeGetFlags(self.inner())); -// } - -// pub fn isSigned(self: Self) bool { -// return c.mlirQuantizedTypeIsSigned(self.inner()); -// } - -// pub fn getStorageType(self: Self) Type { -// return Type.wrap(c.mlirQuantizedTypeGetStorageType(self.inner())); -// } - -// pub fn getStorageTypeMin(self: Self) i64 { -// return c.mlirQuantizedTypeGetStorageTypeMin(self.inner()); -// } - -// pub fn getStorageTypeMax(self: Self) i64 { -// return c.mlirQuantizedTypeGetStorageTypeMax(self.inner()); -// } - -// pub fn getStorageTypeIntegralWidth(self: Self) c_uint { -// return c.mlirQuantizedTypeGetStorageTypeIntegralWidth(self.inner()); -// } - -// pub fn getQuantizedElementType(self: Self) Type { -// return Type.wrap(c.mlirQuantizedTypeGetQuantizedElementType(self.inner())); -// } - -// pub fn getMin(self: Self) f64 { -// return c.mlirCalibratedQuantizedTypeGetMin(self.inner()); -// } - -// pub fn getMax(self: Self) f64 { -// return c.mlirCalibratedQuantizedTypeGetMax(self.inner()); -// } -// }); - pub const ShapedType = struct { _inner: c.MlirType, pub usingnamespace MlirHelpers(ShapedType, .{ diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index 14d040d..642d97a 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -319,14 +319,6 @@ pub const Client = opaque { return Profiler.init(null, options); } - // pub fn getGpuCustomCallRegistry(self: *const Client, api: *const Api) ?*GpuCustomCallRegistry { - // if (api.lookupExtension(c.PJRT_Gpu_Custom_Call, c.PJRT_Extension_Type_Gpu_Custom_Call)) |ext| { - // return .{ .custom_call_register = ext.custom_call.? }; - // } - // log.warn("No Gpu Custom Call registry found for platform: {}", .{self}); - // return null; - // } - pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable { const ret = try api.call(.PJRT_Executable_DeserializeAndLoad, .{ .client = self.inner(), @@ -365,32 +357,6 @@ pub const Client = opaque { } }; -// // pub const CustomCallSignature = *const fn (*anyopaque, **anyopaque, [*c]const u8, usize) callconv(.C) void; - -// // pub const GpuCustomCallRegistry = struct { -// // custom_call_register: *const c.PJRT_Gpu_Register_Custom_Call, - -// // pub fn registerCustomCall(self: GpuCustomCallRegistry, api: *const Api, api_version: usize, name: []const u8, func: CustomCallSignature) ApiError!void { -// // var ret = pjrtStruct(c.PJRT_Gpu_Register_Custom_Call_Args{ -// // .function_name = name.ptr, -// // .function_name_size = name.len, -// // .api_version = @intCast(api_version), -// // .custom_call_function = @ptrCast(@constCast(func)), -// // }); -// // const result = self.custom_call_register(&ret); -// // if (result) |pjrt_c_error| { -// // const pjrt_error = .{ .inner = pjrt_c_error }; -// // log.err("{s}", .{pjrt_error.getMessage(api)}); -// // return pjrt_error.getCode().toApiError(); -// // } -// // } -// // }; - -// // const OldPjrtExtension = extern struct { -// // type: c.PJRT_Extension_Type, -// // next: [*]OldPjrtExtension, -// // }; - pub const Device = opaque { const inner = InnerMixin(c.PJRT_Device).inner; diff --git a/pjrt/profiler.zig b/pjrt/profiler.zig index 1cef566..e148bb5 100644 --- a/pjrt/profiler.zig +++ b/pjrt/profiler.zig @@ -128,64 +128,6 @@ pub const Profiler = struct { } }; -// If this was working it would be a good alternative to xspace_to_json.cc -// const xspace = @import("xspace.pb.zig"); -// pub fn printDataAsXSpace(allocator: std.mem.Allocator, data: []const u8) void { -// var arena = std.heap.ArenaAllocator.init(allocator); -// defer arena.deinit(); -// -// const space = xspace.XSpace.decode(data, arena.allocator()) catch |e| { -// std.log.err("Couldn't load profiling data: {}", .{e}); -// return; -// }; -// -// for (space.errors.items) |err| { -// std.log.err("{s}", .{err.getSlice()}); -// } -// for (space.warnings.items) |warning| { -// std.log.warn("{s}", .{warning.getSlice()}); -// } -// for (space.hostnames.items) |host| { -// std.log.info("Profiled host {s}", .{host.getSlice()}); -// } -// for (space.planes.items) |plane| { -// var event_metadata = std.hash_map.AutoHashMap(i64, xspace.XEventMetadata).init(arena.allocator()); -// event_metadata.ensureTotalCapacity(@intCast(plane.event_metadata.items.len)) catch return; -// defer event_metadata.deinit(); -// for (plane.event_metadata.items) |event_meta_entry| { -// if (event_meta_entry.value) |event_meta| { -// event_metadata.putAssumeCapacity(event_meta.id, event_meta); -// } -// } -// std.log.info("Profiled device {s}", .{plane.name.getSlice()}); - -// for (plane.lines.items) |line| { -// std.log.info( -// "{d} -> {d} xline {s} ({d} events)", -// .{ line.timestamp_ns, line.duration_ps, line.name.getSlice(), line.events.items.len }, -// ); -// const ps_per_ns: i64 = 1000; -// var duration_ns: i64 = 0; -// var last_metadata_id: i64 = 0; -// for (line.events.items) |event| { -// if (event.metadata_id != last_metadata_id and duration_ns != 0) { -// const duration_us = @as(f32, @floatFromInt(duration_ns)) / std.time.ns_per_us; -// const meta = event_metadata.get(event.metadata_id).?; -// std.log.info("event {s}: {d:.1}μs", .{ meta.name.getSlice(), duration_us }); - -// last_metadata_id = event.metadata_id; -// duration_ns = 0; -// } -// duration_ns += @divFloor(event.duration_ps, ps_per_ns); - -// const duration_us = @as(f32, @floatFromInt(duration_ns)) / std.time.ns_per_us; -// const meta = event_metadata.get(event.metadata_id).?; -// std.log.info("event {s}: {d:.1}μs", .{ meta.name.getSlice(), duration_us }); -// } -// } -// } -// } - const ProfilingData = union(enum) { owned: []const u8, external: []const u8, diff --git a/zml/aio.zig b/zml/aio.zig index 1d185f9..6997db2 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -300,23 +300,12 @@ fn _populateStruct( log.warn("No layer found at {s}", .{prefix}); } return true; - } else if (ptr_info.size == .One) { - //if (ptr_info.child != zml.Tensor and ptr_info.child != ?zml.Tensor) { - // // Note: should we recurse on all pointers ? - // log.warn("Not looking into: {any}", .{prefix}); - // return false; - //} - //obj.* = try allocator.create(ptr_info.child); - //return try _populateStruct(allocator, buffer_store, unique_id, prefix, obj.*, required); } else { std.log.err("{s} - {s}: {s} type not supported", .{ @src().fn_name, prefix, @typeName(T) }); return false; } }, .Struct => |struct_info| { - // TODO(Corentin): See if we keep that - //if (@hasDecl(T, "_zml_reader_skip_me_")) return false; - var partial_struct = false; inline for (struct_info.fields) |field| { try prefix_builder.push(allocator, field.name); @@ -343,46 +332,12 @@ fn _populateStruct( } return true; }, - //.Array => |array_info| { - // var new_prefix = prefix; - // if (prefix.items.len > 0) - // new_prefix.appendAssumeCapacity('.'); - // const len = new_prefix.items.len; - // for (obj, 0..) |*value, i| { - // new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); - // const found = try _populateStruct(allocator, buffer_store, unique_id, new_prefix, value, required); - // if (!found) return false; - // new_prefix.shrinkRetainingCapacity(len); - // } - // const num_layers = buffer_store.numLayers(prefix.items); - // if (num_layers != array_info.len) { - // log.warn("Found {d} layers with prefix {s}, but only loaded {d}", .{ num_layers, prefix.items, array_info.len }); - // } - // return true; - //}, .Optional => |opt_info| { obj.* = @as(opt_info.child, undefined); const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &(obj.*.?), false); if (!found) obj.* = null; return true; }, - //.Union => |union_info| { - // // Note: the main issue here is that several fields could match but we only return the first one. - // inline for (union_info.fields) |field| { - // // interpret obj as a "field", and try to populate that. - // obj.* = @unionInit(T, field.name, undefined); - // const found = try _populateStruct(allocator, buffer_store, unique_id, prefix, &@field(obj.*, field.name), false); - // if (found) { - // std.log.info("Interpreted {s} as {s}", .{ prefix.items, @typeName(field.type) }); - // return true; - // } - // } - // obj.* = undefined; - // if (required) { - // std.log.err("Not able to intepret {s} as any member of the union: {s}", .{ prefix.items, @typeName(T) }); - // } - // return false; - //}, .Int => { obj.* = undefined; return true; @@ -540,9 +495,6 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi } else return error.TypeNotSupported; }, .Struct => |struct_info| { - // TODO(Corentin): See if we keep that - //if (@hasDecl(T, "_zml_reader_skip_me_")) return false; - inline for (struct_info.fields) |field| { try prefix_builder.push(allocator, field.name); defer prefix_builder.pop(); @@ -550,23 +502,6 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &@field(obj, field.name), platform); } }, - //.Array => |array_info| { - // var new_prefix = prefix; - // if (prefix.items.len > 0) - // new_prefix.appendAssumeCapacity('.'); - // const len = new_prefix.items.len; - // for (obj, 0..) |*value, i| { - // new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); - // const found = try _populateStruct(allocator, buffer_store, unique_id, new_prefix, value, required); - // if (!found) return false; - // new_prefix.shrinkRetainingCapacity(len); - // } - // const num_layers = buffer_store.numLayers(prefix.items); - // if (num_layers != array_info.len) { - // log.warn("Found {d} layers with prefix {s}, but only loaded {d}", .{ num_layers, prefix.items, array_info.len }); - // } - // return true; - //}, .Optional => |opt_info| { var child = @as(opt_info.child, undefined); if (visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &child, platform)) { @@ -576,23 +511,6 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi else => return err, } }, - //.Union => |union_info| { - // // Note: the main issue here is that several fields could match but we only return the first one. - // inline for (union_info.fields) |field| { - // // interpret obj as a "field", and try to populate that. - // obj.* = @unionInit(T, field.name, undefined); - // const found = try _populateStruct(allocator, buffer_store, unique_id, prefix, &@field(obj.*, field.name), false); - // if (found) { - // std.log.info("Interpreted {s} as {s}", .{ prefix.items, @typeName(field.type) }); - // return true; - // } - // } - // obj.* = undefined; - // if (required) { - // std.log.err("Not able to intepret {s} as any member of the union: {s}", .{ prefix.items, @typeName(T) }); - // } - // return false; - //}, else => {}, } } diff --git a/zml/aio/torch/parser.zig b/zml/aio/torch/parser.zig index 9d3feec..a17ac85 100644 --- a/zml/aio/torch/parser.zig +++ b/zml/aio/torch/parser.zig @@ -95,10 +95,6 @@ pub const Decoder = struct { } fn parseOps(self: *Decoder, allocator: Allocator, seekable_stream: anytype) ![]PickleOp { - // TODO(SuperAuguste): deflate using `std.compress.flate`'s `decompressor` - // TODO(SuperAuguste): explore swapping in non-generic reader here instead of using switch(?) - // not sure if that'd actually be beneficial in any way - var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream); var filename_buf: [std.fs.max_path_bytes]u8 = undefined; while (try iter.next()) |entry| { diff --git a/zml/helpers.zig b/zml/helpers.zig index a042e87..1c01bd9 100644 --- a/zml/helpers.zig +++ b/zml/helpers.zig @@ -49,10 +49,6 @@ pub fn collectDims( expected_dim.* = DIM_MISMATCH; } } - // TODO: strict mode: - // else if (mode == .strict) { - // @compileError("Found unexpected axis " ++ @tagName(a) ++ " when collecting " ++ @typeName(ShapeStruct(dims))); - // } } } }).cb, &context, v); diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig index fa16886..f4b92ec 100644 --- a/zml/hostbuffer.zig +++ b/zml/hostbuffer.zig @@ -190,126 +190,6 @@ pub const HostBuffer = struct { res._shape = self._shape.reshape(shape_); return res; } - - pub const Slice = struct { - single: ?i64 = null, - start: i64 = 0, - end: ?i64 = null, - step: i64 = 1, - }; - - pub inline fn copySlice1d(self: HostBuffer, allocator: std.mem.Allocator, axis: i8, _args: Slice) !HostBuffer { - var slices = [_]Slice{.{}} ** 5; - slices[self._shape.axis(axis)] = _args; - return copySlice(self, allocator, slices[0..self._shape.rank()]); - } - - pub fn copySlice(self: HostBuffer, allocator: std.mem.Allocator, slices: []const Slice) !HostBuffer { - const byte_size = self.dtype().sizeOf(); - var start_indices = [_]usize{0} ** 5; - var strides_ = [_]usize{1} ** 5; - const dims = self._shape.dims(); - var sh = self._shape; - - for (slices, 0..) |_args, a| { - const args: Slice = .{ - .start = if (_args.start >= 0) _args.start else _args.start + dims[a], - .end = _args.end orelse dims[a], - .step = _args.step, - }; - start_indices[a] = @intCast(args.start); - strides_[a] = @intCast(args.step); - sh._dims.set(a, b: { - const range = args.end.? - args.start; - const counts = @divFloor(range - 1, args.step) + 1; - break :b counts; - }); - } - - const rk = self.rank(); - meta.assert(rk <= 5, "copySlice only supports less than 5-D tensors. Received: {}", .{self}); - const raw_strides: [Shape.MAX_RANK]usize = blk: { - var res: [Shape.MAX_RANK]usize = undefined; - const _strides = self._shape.computeStrides(self.dtype().sizeOf()); - for (_strides.constSlice(), 0..rk) |stride, i| res[i] = @intCast(stride); - break :blk res; - }; - - const result_tensor = try HostBuffer.empty(allocator, sh); - - const res_strides: [Shape.MAX_RANK]usize = blk: { - var res: [Shape.MAX_RANK]usize = undefined; - const _strides = self._shape.computeStrides(self.dtype().sizeOf()); - for (_strides.constSlice(), 0..rk) |stride, i| res[i] = @intCast(stride); - break :blk res; - }; - - const src_data = self.data; - const data_ = @constCast(result_tensor.data); - for (0..@intCast(sh.dim(0))) |j0| { - const off0 = (j0 * strides_[0] + start_indices[0]) * raw_strides[0]; - const res_off0 = j0 * res_strides[0]; - if (rk == 1) { - @memcpy(data_[res_off0..][0..byte_size], src_data[off0..][0..byte_size]); - continue; - } - for (0..@intCast(sh.dim(1))) |j1| { - const off1 = off0 + (j1 * strides_[1] + start_indices[1]) * raw_strides[1]; - const res_off1 = res_off0 + j1 * res_strides[1]; - if (rk == 2) { - @memcpy(data_[res_off1..][0..byte_size], src_data[off1..][0..byte_size]); - continue; - } - for (0..@intCast(sh.dim(2))) |j2| { - const off2 = off1 + (j2 * strides_[2] + start_indices[2]) * raw_strides[2]; - const res_off2 = res_off1 + j2 * res_strides[2]; - if (rk == 3) { - @memcpy(data_[res_off2..][0..byte_size], src_data[off2..][0..byte_size]); - continue; - } - for (0..@intCast(sh.dim(3))) |j3| { - const off3 = off2 + (j3 * strides_[3] + start_indices[3]) * raw_strides[3]; - const res_off3 = res_off2 + j3 * res_strides[3]; - if (rk == 4) { - @memcpy(data_[res_off3..][0..byte_size], src_data[off3..][0..byte_size]); - continue; - } - for (0..@intCast(sh.dim(4))) |j4| { - const off4 = off3 + (j4 * strides_[4] + start_indices[4]) * raw_strides[4]; - const res_off4 = res_off3 + j4 * res_strides[4]; - @memcpy(data_[res_off4..][0..byte_size], src_data[off4..][0..byte_size]); - } - } - } - } - } - - return result_tensor; - } - - test copySlice { - var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator); - defer arena_state.deinit(); - const allocator = arena_state.allocator(); - - const x = HostBuffer.fromSlice(.{ 2, 5 }, &[_]f32{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }); - { - const res = try copySlice1d(x, allocator, 0, .{ .end = 1 }); - try std.testing.expectEqualSlices(f32, &.{ 0, 1, 2, 3, 4 }, res.items(f32)); - } - // { // failing - // const res = try copySlice1d(x, allocator, -1, .{ .start = -2 }); - // try testing.expectEqualSlices(f32, &.{ 3, 4, 8, 9 }, res.items(f32)); - // } - // {// failing - // const res = try copySlice1d(x, allocator, 1, .{ .start = 1, .step = 2 }); - // try testing.expectEqualSlices(f32, &.{ 1, 3, 6, 8 }, res.items(f32)); - // } - { - const res = try copySlice(x, allocator, &.{ .{ .start = 1 }, .{ .start = 1, .step = 2 } }); - try std.testing.expectEqualSlices(f32, &.{ 6, 8 }, res.items(f32)); - } - } }; fn parseArrayInfo(T: type) Shape { diff --git a/zml/module.zig b/zml/module.zig index b227ecd..63c80f9 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -746,9 +746,7 @@ fn compileInternal( var timer = std.time.Timer.start() catch null; const tensor_args = context.tensorFromShapes(ModuleSignature(func).ArgsT, arena, args); - // TODO: this is fast, doesn't make system call, and use mutable state. - // does it need to be async ? - // const f = try CompilationContext.generateBytecode(context, arena, "main", func, &model, &tensor_args, .{ .add_donations_attributes = true }); + // Run in a dedicated thread because compilation relies on `threadlocal`. const f = try asynk.callGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args, .{ .add_donations_attributes = true } }); context._module.getBody().appendOperation(f.mlir_fn); diff --git a/zml/nn.zig b/zml/nn.zig index a7c0cc3..eff5803 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -218,13 +218,6 @@ test "real/img" { const platform = zml.testing.env(); const Fns = struct { - // fn testSplitMergeIsId(impl: RopeOpts.Implementation) Tensor { - // const x = Tensor.arange(.{ .end = 20 }, .f32).reshape(.{ 5, 4 }); - // const real, const imag = splitRealImg(x, impl); - // const y = mergeRealImg(real, imag, impl); - // return y.cmp(.EQ, x).flatten(0).convert(.i32).sum(-1); - // } - fn testSplitMergeIsId(impl: RopeOpts.Implementation) Tensor { const x = Tensor.arange(.{ .end = 20 }, .f32).reshape(.{ 5, 4 }); const real, const imag = splitRealImg(x, impl); diff --git a/zml/ops.zig b/zml/ops.zig index 4d7e758..9335ff7 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -547,17 +547,6 @@ fn _BlockSign(comptime func: anytype, blk_type: BlockType) BlockSignature { if (i >= arg_start) { n_tensors += staticCountTensors(ArgType) orelse @compileError("Can't use " ++ @typeName(ArgType) ++ " in an MLIR function, because it has a variable number of tensors"); } - - // if (arg.type) |ArgType| { - // full_args[i] = ArgType; - // if (i >= arg_start) { - // n_tensors += staticCountTensors(ArgType) orelse @compileError("Can't use " ++ @typeName(ArgType) ++ " in an MLIR function, because it has a variable number of tensors"); - // } - // } else { - // // anytype are considered to not have tensors. - // // violation of this will be detected when calling `compile()` but not at Zig compile time. - // full_args[i] = void; - // } } const FullArgs = std.meta.Tuple(&full_args); const BlkCtx = switch (blk_type) { diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index 1e8fd96..96457e7 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -167,34 +167,8 @@ pub const Client = opaque { pub fn getProfiler(self: *const Client, api: *const Api, options: pjrt.Profiler.Options) pjrt.Profiler { return self.inner().getProfiler(api, options); } - - // pub fn getGpuCustomCallRegistry(self: Client) ?GpuCustomCallRegistry { - // return switch (self.inner) { - // inline else => |v, tag| if (v.getGpuCustomCallRegistry()) |registry| GpuCustomCallRegistry.wrap(tag, registry) else null, - // }; - // } - - // pub fn getGpuCustomCallRegistry(self: *const Client, api: *const Api) ?*GpuCustomCallRegistry { - // if (api.lookupExtension(c.PJRT_Gpu_Custom_Call, c.PJRT_Extension_Type_Gpu_Custom_Call)) |ext| { - // return .{ .custom_call_register = ext.custom_call.? }; - // } - // log.warn("No Gpu Custom Call registry found for platform: {}", .{self}); - // return null; - // } }; -// pub const GpuCustomCallRegistry = struct { -// pub usingnamespace WrapperMixin(GpuCustomCallRegistry, pjrt.GpuCustomCallRegistry); - -// inner: GpuCustomCallRegistry.UnionType, - -// pub fn registerCustomCall(self: GpuCustomCallRegistry, api_version: usize, name: []const u8, func: pjrt.CustomCallSignature) ApiError!void { -// return switch (self.inner) { -// inline else => |v| v.registerCustomCall(api_version, name, func), -// }; -// } -// }; - pub const Buffer = opaque { const inner = InnerMixin(pjrt.Buffer).inner; diff --git a/zml/shape.zig b/zml/shape.zig index 172327e..1c6fce8 100644 --- a/zml/shape.zig +++ b/zml/shape.zig @@ -348,9 +348,6 @@ pub const Shape = struct { return self.dtype().sizeOf() * self.count(); } - // Aliases - pub const numel = count; - /// Compares the two shapes described, ignoring tagging. pub fn eql(self: Shape, other: Shape) bool { return std.mem.eql(i64, self.dims(), other.dims()) and self.dtype() == other.dtype(); @@ -883,78 +880,6 @@ pub const Shape = struct { ); } - /// Parses an anytype argument of the form `val` or `.{ .a = val }`.l - /// Helps offering consistent API through ZML. - // pub fn parseTaggedValue( - // T: type, - // default_tag: EnumLiteral, - // d: anytype, - // ) struct { Tag, T } { - // const err_msg = "Expected one tagged dimension, received a tuple: " ++ @typeName(@TypeOf(d)); - // return switch (@typeInfo(@TypeOf(d))) { - // .Int, .ComptimeInt => .{ toTag(default_tag), @intCast(d) }, - // .Struct => |struct_info| { - // if (struct_info.fields.len != 1) @compileError(err_msg); - // const name = struct_info.fields[0].name; - // return .{ name.ptr, @intCast(@field(d, name)) }; - // }, - // else => @compileError(err_msg), - // }; - // } - - /// Parses a list of tags `.{ .a, .b, .c }` into a `[]Tag` - // pub inline fn parseTagList(comptime axes_: anytype) []Tag { - // switch (@typeInfo(@TypeOf(axes_))) { - // .Struct, .Array => { - // var _tags: [axes_.len]Tag = undefined; - // inline for (axes_, &_tags) |a, *t| t.* = toTag(a); - // return &_tags; - // }, - // else => @compileError("Expected a tuple of enum literal, but found " ++ @tagName(@TypeOf(axes))), - // } - // } - - /// Parses a comptime struct into a struct similarly to Shape.init, - /// but with a custom type in place of the `i64` dimensions. - /// Helps offering consistent API through ZML. - // pub fn parseShapedValue(T: type, value: anytype) struct { - // std.BoundedArray(Tag, MAX_RANK), - // std.BoundedArray(T, MAX_RANK), - // } { - // const too_long_err = std.fmt.comptimePrint("Received too many axes, maximum supported is {d}", .{MAX_RANK}); - - // var _tags: [MAX_RANK]Tag = [_]Tag{TagUnknown} ** MAX_RANK; - // const struct_info = switch (@typeInfo(@TypeOf(value))) { - // .Struct => |struct_info| struct_info, - // else => return .{ - // .{ .len = 0, .buffer = _tags }, - // std.BoundedArray(T, MAX_RANK).fromSlice(value) catch @panic(too_long_err), - // }, - // }; - - // meta.assertComptime(struct_info.fields.len <= MAX_RANK, too_long_err, .{}); - - // var values: std.BoundedArray(T, MAX_RANK) = .{}; - // inline for (struct_info.fields) |field| { - // if (T == Tag) { - // values.appendAssumeCapacity(toTag(@field(value, field.name))); - // } else { - // // If you have an error here it means Zig wasn't able to convert between the - // // value you passed and the expected `T`. - // values.appendAssumeCapacity(@field(value, field.name)); - // } - // } - // if (!struct_info.is_tuple) { - // inline for (struct_info.fields, 0..) |field, i| { - // _tags[i] = toTag(field); - // } - // } - // return .{ - // .{ .len = struct_info.fields.len, .buffer = _tags }, - // values, - // }; - // } - fn intersectTags(a: []const Tag, b: []const Tag) TagsArray { var res = TagsArray.init(0) catch unreachable; for (a) |tag_| {