zml: clean up dead and commented code; note that copySlice is currently broken and pending reimplementation

This commit is contained in:
Tarry Singh 2023-02-08 17:13:47 +00:00
parent 058e1415fa
commit be6328813d
13 changed files with 4 additions and 834 deletions

View File

@ -288,79 +288,6 @@ fn elementTypeOrSelf(typ: mlir.Type) mlir.Type {
} else typ;
}
/// Builds a `stablehlo.scatter` operation.
///
/// `inputs` are the tensors being updated, `scatter_indices` selects where the
/// updates land, and `updates` holds the values to combine in. `update_fn` is
/// invoked once (with `update_ctx`) to fill the op's update-computation block;
/// it receives the current elements and the update elements as rank-0 block
/// arguments and must emit the combining operation.
///
/// NOTE(review): `inputs.len` must be <= MaxBlockArguments / 2; larger inputs
/// would index past the fixed-size scratch arrays below (caught by Zig's
/// bounds checks in safe builds) — confirm callers never exceed this.
pub fn scatter(
    ctx: mlir.Context,
    // inputs
    inputs: []const mlir.Value,
    scatter_indices: mlir.Value,
    updates: []const mlir.Value,
    // input functions
    update_ctx: anytype, // for update_fn
    update_fn: fn (anytype, mlir.Context, []const mlir.Value, []const mlir.Value) mlir.Operation,
    // attributes
    args: struct {
        update_window_dims: []const i64,
        inserted_window_dims: []const i64,
        input_batching_dims: []const i64,
        scatter_indices_batching_dims: []const i64,
        scatter_dims_to_operand_dims: []const i64,
        index_vector_dim: i64,
        indices_are_sorted: bool = false,
        unique_indices: bool = false,
    },
    // zml loc
    location: mlir.Location,
) mlir.Operation {
    // create block for update_fn
    const MaxBlockArguments = 32; // TODO(rene): where does this 32 come from?
    // taken from reduce
    const block_n_args = inputs.len * 2; // TODO(rene): is this correct? yes, passes tests: block_inputs plus block_accumulators = inputs
    const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0..block_n_args];
    var scatter_elem_types: [MaxBlockArguments]mlir.Type = undefined;
    for (inputs, 0..) |input, i| {
        // Block arguments are scalar (rank-0) tensors of each input's element
        // type: first the current values, then (offset by inputs.len) the updates.
        const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type).?;
        scatter_elem_types[i] = arg_type;
        scatter_elem_types[inputs.len + i] = arg_type;
    }
    var block = mlir.Block.open(scatter_elem_types[0..block_n_args], locations) catch unreachable;
    {
        // Close the block once the update computation has been emitted into it.
        defer block.close();
        var block_inputs: [MaxBlockArguments / 2]mlir.Value = undefined;
        var block_accs: [MaxBlockArguments / 2]mlir.Value = undefined;
        for (0..inputs.len) |i| {
            block_inputs[i] = block.argument(i);
            block_accs[i] = block.argument(inputs.len + i);
        }
        _ = update_fn(update_ctx, ctx, block_inputs[0..inputs.len], block_accs[0..inputs.len]);
    }
    return mlir.Operation.make(
        ctx,
        "stablehlo.scatter",
        .{
            .variadic_operands = &.{ inputs, &.{scatter_indices}, updates },
            .block = block,
            .attributes = &.{
                .{ "scatter_dimension_numbers", ScatterDimensionNumbersAttribute.init(
                    ctx,
                    args.update_window_dims,
                    args.inserted_window_dims,
                    args.input_batching_dims,
                    args.scatter_indices_batching_dims,
                    args.scatter_dims_to_operand_dims,
                    args.index_vector_dim,
                ).as(mlir.Attribute).? },
                .{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute).? },
                .{ "unique_indices", mlir.BoolAttribute.init(ctx, args.unique_indices).as(mlir.Attribute).? },
            },
            .result_type_inference = true,
            .location = location,
        },
    );
}
pub fn iota(ctx: mlir.Context, dimension: i64, result_type: mlir.Type, location: mlir.Location) mlir.Operation {
return mlir.Operation.make(ctx, "stablehlo.iota", .{
.operands = &.{},
@ -439,66 +366,6 @@ pub fn reduce(
});
}
/// Attribute bundle for a `stablehlo.reduce_window` operation.
/// Each slice provides one value per window dimension.
pub const ReduceWindowOpts = struct {
    // Size of the sliding window along each dimension.
    window_dimensions: []const i64,
    // Step of the window along each dimension.
    window_strides: []const i64,
    // Dilation applied to the input (base) along each dimension.
    base_dilations: []const i64,
    // Dilation applied to the window along each dimension.
    window_dilations: []const i64,
    // Flattened padding amounts, laid out according to `padding_shape`.
    padding_values: []const i64,
    // Shape of the padding attribute tensor — presumably [rank, 2]; confirm
    // against the stablehlo.reduce_window spec before relying on it.
    padding_shape: []const i64,
};
// pub fn reduce_window(
// ctx: mlir.Context,
// inputs: []const mlir.Value,
// init_values: []const mlir.Value,
// opts: ReduceWindowOpts,
// blkctx: anytype,
// blkfn: fn (anytype, mlir.Context, []const mlir.Value, []const mlir.Value) mlir.Operation,
// location: mlir.Location,
// ) mlir.Operation {
// // TODO: move to ops.zig, and refactor similar to `reduce`
// const MaxBlockArguments = 32;
// const block_n_args = inputs.len + init_values.len;
// const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0..block_n_args];
// var reduce_elem_types: [MaxBlockArguments]mlir.Type = undefined;
// for (inputs, 0..) |input, i| {
// const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type).?;
// reduce_elem_types[i] = arg_type;
// reduce_elem_types[inputs.len + i] = arg_type;
// }
// const module = @import("../module.zig");
// const comp = module.getCompilationContext();
// var block = comp.openBlock(reduce_elem_types[0..block_n_args], locations) catch unreachable;
// {
// defer comp.closeBlock(block);
// var block_inputs: [MaxBlockArguments / 2]mlir.Value = undefined;
// var block_accs: [MaxBlockArguments / 2]mlir.Value = undefined;
// for (0..inputs.len) |i| {
// block_inputs[i] = block.argument(i);
// block_accs[i] = block.argument(inputs.len + i);
// }
// _ = blkfn(blkctx, ctx, block_inputs[0..inputs.len], block_accs[0..init_values.len]);
// }
// const pad_shape = mlir.RankedTensorType.init(opts.padding_shape, DataType.i64.mlirType(ctx)).as(mlir.Type).?;
// return mlir.Operation.make(ctx, "stablehlo.reduce_window", .{
// .variadic_operands = &.{ inputs, init_values },
// .result_type_inference = true,
// .blocks = &.{block},
// .attributes = &.{
// .{ "window_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_dimensions).as(mlir.Attribute).? },
// .{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_strides).as(mlir.Attribute).? },
// .{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx, opts.base_dilations).as(mlir.Attribute).? },
// .{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_dilations).as(mlir.Attribute).? },
// .{ "padding", mlir.DenseIntOrFPElementsAttribute(.i64).init(pad_shape, std.mem.sliceAsBytes(opts.padding_values)).as(mlir.Attribute).? },
// },
// .location = location,
// });
// }
pub fn sort(
ctx: mlir.Context,
inputs: []const mlir.Value,

View File

@ -691,10 +691,6 @@ pub const OperationState = struct {
c.mlirOperationStateAddOwnedRegions(self.innerPtr(), @intCast(regions.len), @ptrCast(regions.ptr));
}
// pub fn addSuccessor(self: *Self, successor: Operation) void {
// c.mlirOperationStateAddSuccessors(self.innerPtr(), 1, &[_]c.MlirOperation{successor.inner()});
// }
pub fn addAttribute(self: *Self, ctx: Context, name: [:0]const u8, attr: Attribute) void {
c.mlirOperationStateAddAttributes(self.innerPtr(), 1, @ptrCast(&.{
.{
@ -745,9 +741,9 @@ pub const DictionaryAttribute = struct {
return NamedAttribute.wrap(c.mlirDictionaryAttrGetElement(self.inner(), @intCast(pos)));
}
// pub fn getByName(self: Self, name: [:0]const u8) ?NamedAttribute {
// return NamedAttribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self.inner(), name));
// }
/// Looks up an entry of the dictionary attribute by name.
/// Returns null when no attribute with that name exists.
pub fn getByName(self: Self, name: [:0]const u8) ?NamedAttribute {
    return NamedAttribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self.inner(), name));
}
};
pub const Operation = struct {
@ -1519,276 +1515,6 @@ pub const DialectHandle = struct {
}
};
// pub const AnyQuantizedType = MlirWrapperType(c.MlirType, .{
// .is_a_fn = c.mlirTypeIsAAnyQuantizedType,
// .is_null_fn = c.mlirTypeIsNull,
// .dump_fn = c.mlirTypeDump,
// .equal_fn = c.mlirTypeEqual,
// }, struct {
// const Self = AnyQuantizedType;
// pub fn init(
// flags: quant.QuantizationFlags,
// storageType: Type,
// expressedType: Type,
// storageTypeMin: i64,
// storageTypeMax: i64,
// ) Self {
// return Self.wrap(c.mlirAnyQuantizedTypeGet(
// @intCast(@intFromEnum(flags)),
// storageType.inner(),
// expressedType.inner(),
// storageTypeMin,
// storageTypeMax,
// ));
// }
// pub fn getExpressedType(self: Self) Type {
// return Type.wrap(c.mlirQuantizedTypeGetExpressedType(self.inner()));
// }
// pub fn getFlags(self: Self) quant.QuantizationFlags {
// return @enumFromInt(c.mlirQuantizedTypeGetFlags(self.inner()));
// }
// pub fn isSigned(self: Self) bool {
// return c.mlirQuantizedTypeIsSigned(self.inner());
// }
// pub fn getStorageType(self: Self) Type {
// return Type.wrap(c.mlirQuantizedTypeGetStorageType(self.inner()));
// }
// pub fn getStorageTypeMin(self: Self) i64 {
// return c.mlirQuantizedTypeGetStorageTypeMin(self.inner());
// }
// pub fn getStorageTypeMax(self: Self) i64 {
// return c.mlirQuantizedTypeGetStorageTypeMax(self.inner());
// }
// pub fn getStorageTypeIntegralWidth(self: Self) c_uint {
// return c.mlirQuantizedTypeGetStorageTypeIntegralWidth(self.inner());
// }
// pub fn getQuantizedElementType(self: Self) Type {
// return Type.wrap(c.mlirQuantizedTypeGetQuantizedElementType(self.inner()));
// }
// });
// pub const UniformQuantizedType = MlirWrapperType(c.MlirType, .{
// .is_a_fn = c.mlirTypeIsAUniformQuantizedType,
// .is_null_fn = c.mlirTypeIsNull,
// .dump_fn = c.mlirTypeDump,
// .equal_fn = c.mlirTypeEqual,
// }, struct {
// const Self = AnyQuantizedType;
// pub fn init(
// flags: quant.QuantizationFlags,
// storageType: Type,
// expressedType: Type,
// scale: f64,
// zeroPoint: i64,
// storageTypeMin: i64,
// storageTypeMax: i64,
// ) Self {
// return Self.wrap(c.mlirUniformQuantizedTypeGet(
// @intCast(@intFromEnum(flags)),
// storageType.inner(),
// expressedType.inner(),
// scale,
// zeroPoint,
// storageTypeMin,
// storageTypeMax,
// ));
// }
// pub fn getExpressedType(self: Self) Type {
// return Type.wrap(c.mlirQuantizedTypeGetExpressedType(self.inner()));
// }
// pub fn getFlags(self: Self) quant.QuantizationFlags {
// return @enumFromInt(c.mlirQuantizedTypeGetFlags(self.inner()));
// }
// pub fn isSigned(self: Self) bool {
// return c.mlirQuantizedTypeIsSigned(self.inner());
// }
// pub fn getStorageType(self: Self) Type {
// return Type.wrap(c.mlirQuantizedTypeGetStorageType(self.inner()));
// }
// pub fn getStorageTypeMin(self: Self) i64 {
// return c.mlirQuantizedTypeGetStorageTypeMin(self.inner());
// }
// pub fn getStorageTypeMax(self: Self) i64 {
// return c.mlirQuantizedTypeGetStorageTypeMax(self.inner());
// }
// pub fn getStorageTypeIntegralWidth(self: Self) c_uint {
// return c.mlirQuantizedTypeGetStorageTypeIntegralWidth(self.inner());
// }
// pub fn getQuantizedElementType(self: Self) Type {
// return Type.wrap(c.mlirQuantizedTypeGetQuantizedElementType(self.inner()));
// }
// pub fn getScale(self: Self) f64 {
// return c.mlirUniformQuantizedTypeGetScale(self.inner());
// }
// pub fn getZeroPoint(self: Self) i64 {
// return c.mlirUniformQuantizedTypeGetZeroPoint(self.inner());
// }
// pub fn isFixedPoint(self: Self) bool {
// return c.mlirUniformQuantizedTypeIsFixedPoint(self.inner());
// }
// });
// pub const QuantizedPerAxisType = MlirWrapperType(c.MlirType, .{
// .is_a_fn = c.mlirTypeIsAUniformQuantizedPerAxisType,
// .is_null_fn = c.mlirTypeIsNull,
// .dump_fn = c.mlirTypeDump,
// .equal_fn = c.mlirTypeEqual,
// }, struct {
// const Self = AnyQuantizedType;
// pub fn init(
// flags: quant.QuantizationFlags,
// storageType: Type,
// expressedType: Type,
// nDims: usize,
// scales: []f64,
// zeroPoints: []i64,
// quantizedDimension: i32,
// storageTypeMin: i64,
// storageTypeMax: i64,
// ) Self {
// std.debug.assert(scales.len == nDims);
// std.debug.assert(zeroPoints.len == nDims);
// return Self.wrap(c.mlirUniformQuantizedPerAxisTypeGet(
// @intCast(@intFromEnum(flags)),
// storageType.inner(),
// expressedType.inner(),
// @intCast(nDims),
// scales.ptr,
// zeroPoints.ptr,
// quantizedDimension,
// storageTypeMin,
// storageTypeMax,
// ));
// }
// pub fn getExpressedType(self: Self) Type {
// return Type.wrap(c.mlirQuantizedTypeGetExpressedType(self.inner()));
// }
// pub fn getFlags(self: Self) quant.QuantizationFlags {
// return @enumFromInt(c.mlirQuantizedTypeGetFlags(self.inner()));
// }
// pub fn isSigned(self: Self) bool {
// return c.mlirQuantizedTypeIsSigned(self.inner());
// }
// pub fn getStorageType(self: Self) Type {
// return Type.wrap(c.mlirQuantizedTypeGetStorageType(self.inner()));
// }
// pub fn getStorageTypeMin(self: Self) i64 {
// return c.mlirQuantizedTypeGetStorageTypeMin(self.inner());
// }
// pub fn getStorageTypeMax(self: Self) i64 {
// return c.mlirQuantizedTypeGetStorageTypeMax(self.inner());
// }
// pub fn getStorageTypeIntegralWidth(self: Self) c_uint {
// return c.mlirQuantizedTypeGetStorageTypeIntegralWidth(self.inner());
// }
// pub fn getQuantizedElementType(self: Self) Type {
// return Type.wrap(c.mlirQuantizedTypeGetQuantizedElementType(self.inner()));
// }
// pub fn getNumDims(self: Self) usize {
// return @intCast(c.mlirUniformQuantizedPerAxisTypeGetNumDims(self.inner()));
// }
// pub fn getScale(self: Self) f64 {
// return @intCast(c.mlirUniformQuantizedPerAxisTypeGetScale(self.inner()));
// }
// pub fn getZeroPoint(self: Self, pos: usize) i64 {
// return c.mlirUniformQuantizedPerAxisTypeGetZeroPoint(self.inner(), @intCast(pos));
// }
// pub fn getQuantizedDimension(self: Self) i32 {
// return c.mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(self.inner());
// }
// pub fn isFixedPoint(self: Self) bool {
// return c.mlirUniformQuantizedPerAxisTypeIsFixedPoint(self.inner());
// }
// });
// pub const CalibratedQuantizedType = MlirWrapperType(c.MlirType, .{
// .is_a_fn = c.mlirTypeIsACalibratedQuantizedType,
// .is_null_fn = c.mlirTypeIsNull,
// .dump_fn = c.mlirTypeDump,
// .equal_fn = c.mlirTypeEqual,
// }, struct {
// const Self = AnyQuantizedType;
// pub fn init(expressedType: Type, min: f64, max: f64) Self {
// return Self.wrap(c.mlirCalibratedQuantizedTypeGet(expressedType.inner(), min, max));
// }
// pub fn getExpressedType(self: Self) Type {
// return Type.wrap(c.mlirQuantizedTypeGetExpressedType(self.inner()));
// }
// pub fn getFlags(self: Self) quant.QuantizationFlags {
// return @enumFromInt(c.mlirQuantizedTypeGetFlags(self.inner()));
// }
// pub fn isSigned(self: Self) bool {
// return c.mlirQuantizedTypeIsSigned(self.inner());
// }
// pub fn getStorageType(self: Self) Type {
// return Type.wrap(c.mlirQuantizedTypeGetStorageType(self.inner()));
// }
// pub fn getStorageTypeMin(self: Self) i64 {
// return c.mlirQuantizedTypeGetStorageTypeMin(self.inner());
// }
// pub fn getStorageTypeMax(self: Self) i64 {
// return c.mlirQuantizedTypeGetStorageTypeMax(self.inner());
// }
// pub fn getStorageTypeIntegralWidth(self: Self) c_uint {
// return c.mlirQuantizedTypeGetStorageTypeIntegralWidth(self.inner());
// }
// pub fn getQuantizedElementType(self: Self) Type {
// return Type.wrap(c.mlirQuantizedTypeGetQuantizedElementType(self.inner()));
// }
// pub fn getMin(self: Self) f64 {
// return c.mlirCalibratedQuantizedTypeGetMin(self.inner());
// }
// pub fn getMax(self: Self) f64 {
// return c.mlirCalibratedQuantizedTypeGetMax(self.inner());
// }
// });
pub const ShapedType = struct {
_inner: c.MlirType,
pub usingnamespace MlirHelpers(ShapedType, .{

View File

@ -319,14 +319,6 @@ pub const Client = opaque {
return Profiler.init(null, options);
}
// pub fn getGpuCustomCallRegistry(self: *const Client, api: *const Api) ?*GpuCustomCallRegistry {
// if (api.lookupExtension(c.PJRT_Gpu_Custom_Call, c.PJRT_Extension_Type_Gpu_Custom_Call)) |ext| {
// return .{ .custom_call_register = ext.custom_call.? };
// }
// log.warn("No Gpu Custom Call registry found for platform: {}", .{self});
// return null;
// }
pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {
const ret = try api.call(.PJRT_Executable_DeserializeAndLoad, .{
.client = self.inner(),
@ -365,32 +357,6 @@ pub const Client = opaque {
}
};
// // pub const CustomCallSignature = *const fn (*anyopaque, **anyopaque, [*c]const u8, usize) callconv(.C) void;
// // pub const GpuCustomCallRegistry = struct {
// // custom_call_register: *const c.PJRT_Gpu_Register_Custom_Call,
// // pub fn registerCustomCall(self: GpuCustomCallRegistry, api: *const Api, api_version: usize, name: []const u8, func: CustomCallSignature) ApiError!void {
// // var ret = pjrtStruct(c.PJRT_Gpu_Register_Custom_Call_Args{
// // .function_name = name.ptr,
// // .function_name_size = name.len,
// // .api_version = @intCast(api_version),
// // .custom_call_function = @ptrCast(@constCast(func)),
// // });
// // const result = self.custom_call_register(&ret);
// // if (result) |pjrt_c_error| {
// // const pjrt_error = .{ .inner = pjrt_c_error };
// // log.err("{s}", .{pjrt_error.getMessage(api)});
// // return pjrt_error.getCode().toApiError();
// // }
// // }
// // };
// // const OldPjrtExtension = extern struct {
// // type: c.PJRT_Extension_Type,
// // next: [*]OldPjrtExtension,
// // };
pub const Device = opaque {
const inner = InnerMixin(c.PJRT_Device).inner;

View File

@ -128,64 +128,6 @@ pub const Profiler = struct {
}
};
// If this was working it would be a good alternative to xspace_to_json.cc
// const xspace = @import("xspace.pb.zig");
// pub fn printDataAsXSpace(allocator: std.mem.Allocator, data: []const u8) void {
// var arena = std.heap.ArenaAllocator.init(allocator);
// defer arena.deinit();
//
// const space = xspace.XSpace.decode(data, arena.allocator()) catch |e| {
// std.log.err("Couldn't load profiling data: {}", .{e});
// return;
// };
//
// for (space.errors.items) |err| {
// std.log.err("{s}", .{err.getSlice()});
// }
// for (space.warnings.items) |warning| {
// std.log.warn("{s}", .{warning.getSlice()});
// }
// for (space.hostnames.items) |host| {
// std.log.info("Profiled host {s}", .{host.getSlice()});
// }
// for (space.planes.items) |plane| {
// var event_metadata = std.hash_map.AutoHashMap(i64, xspace.XEventMetadata).init(arena.allocator());
// event_metadata.ensureTotalCapacity(@intCast(plane.event_metadata.items.len)) catch return;
// defer event_metadata.deinit();
// for (plane.event_metadata.items) |event_meta_entry| {
// if (event_meta_entry.value) |event_meta| {
// event_metadata.putAssumeCapacity(event_meta.id, event_meta);
// }
// }
// std.log.info("Profiled device {s}", .{plane.name.getSlice()});
// for (plane.lines.items) |line| {
// std.log.info(
// "{d} -> {d} xline {s} ({d} events)",
// .{ line.timestamp_ns, line.duration_ps, line.name.getSlice(), line.events.items.len },
// );
// const ps_per_ns: i64 = 1000;
// var duration_ns: i64 = 0;
// var last_metadata_id: i64 = 0;
// for (line.events.items) |event| {
// if (event.metadata_id != last_metadata_id and duration_ns != 0) {
// const duration_us = @as(f32, @floatFromInt(duration_ns)) / std.time.ns_per_us;
// const meta = event_metadata.get(event.metadata_id).?;
// std.log.info("event {s}: {d:.1}μs", .{ meta.name.getSlice(), duration_us });
// last_metadata_id = event.metadata_id;
// duration_ns = 0;
// }
// duration_ns += @divFloor(event.duration_ps, ps_per_ns);
// const duration_us = @as(f32, @floatFromInt(duration_ns)) / std.time.ns_per_us;
// const meta = event_metadata.get(event.metadata_id).?;
// std.log.info("event {s}: {d:.1}μs", .{ meta.name.getSlice(), duration_us });
// }
// }
// }
// }
const ProfilingData = union(enum) {
owned: []const u8,
external: []const u8,

View File

@ -300,23 +300,12 @@ fn _populateStruct(
log.warn("No layer found at {s}", .{prefix});
}
return true;
} else if (ptr_info.size == .One) {
//if (ptr_info.child != zml.Tensor and ptr_info.child != ?zml.Tensor) {
// // Note: should we recurse on all pointers ?
// log.warn("Not looking into: {any}", .{prefix});
// return false;
//}
//obj.* = try allocator.create(ptr_info.child);
//return try _populateStruct(allocator, buffer_store, unique_id, prefix, obj.*, required);
} else {
std.log.err("{s} - {s}: {s} type not supported", .{ @src().fn_name, prefix, @typeName(T) });
return false;
}
},
.Struct => |struct_info| {
// TODO(Corentin): See if we keep that
//if (@hasDecl(T, "_zml_reader_skip_me_")) return false;
var partial_struct = false;
inline for (struct_info.fields) |field| {
try prefix_builder.push(allocator, field.name);
@ -343,46 +332,12 @@ fn _populateStruct(
}
return true;
},
//.Array => |array_info| {
// var new_prefix = prefix;
// if (prefix.items.len > 0)
// new_prefix.appendAssumeCapacity('.');
// const len = new_prefix.items.len;
// for (obj, 0..) |*value, i| {
// new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
// const found = try _populateStruct(allocator, buffer_store, unique_id, new_prefix, value, required);
// if (!found) return false;
// new_prefix.shrinkRetainingCapacity(len);
// }
// const num_layers = buffer_store.numLayers(prefix.items);
// if (num_layers != array_info.len) {
// log.warn("Found {d} layers with prefix {s}, but only loaded {d}", .{ num_layers, prefix.items, array_info.len });
// }
// return true;
//},
.Optional => |opt_info| {
obj.* = @as(opt_info.child, undefined);
const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &(obj.*.?), false);
if (!found) obj.* = null;
return true;
},
//.Union => |union_info| {
// // Note: the main issue here is that several fields could match but we only return the first one.
// inline for (union_info.fields) |field| {
// // interpret obj as a "field", and try to populate that.
// obj.* = @unionInit(T, field.name, undefined);
// const found = try _populateStruct(allocator, buffer_store, unique_id, prefix, &@field(obj.*, field.name), false);
// if (found) {
// std.log.info("Interpreted {s} as {s}", .{ prefix.items, @typeName(field.type) });
// return true;
// }
// }
// obj.* = undefined;
// if (required) {
// std.log.err("Not able to intepret {s} as any member of the union: {s}", .{ prefix.items, @typeName(T) });
// }
// return false;
//},
.Int => {
obj.* = undefined;
return true;
@ -540,9 +495,6 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
} else return error.TypeNotSupported;
},
.Struct => |struct_info| {
// TODO(Corentin): See if we keep that
//if (@hasDecl(T, "_zml_reader_skip_me_")) return false;
inline for (struct_info.fields) |field| {
try prefix_builder.push(allocator, field.name);
defer prefix_builder.pop();
@ -550,23 +502,6 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &@field(obj, field.name), platform);
}
},
//.Array => |array_info| {
// var new_prefix = prefix;
// if (prefix.items.len > 0)
// new_prefix.appendAssumeCapacity('.');
// const len = new_prefix.items.len;
// for (obj, 0..) |*value, i| {
// new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
// const found = try _populateStruct(allocator, buffer_store, unique_id, new_prefix, value, required);
// if (!found) return false;
// new_prefix.shrinkRetainingCapacity(len);
// }
// const num_layers = buffer_store.numLayers(prefix.items);
// if (num_layers != array_info.len) {
// log.warn("Found {d} layers with prefix {s}, but only loaded {d}", .{ num_layers, prefix.items, array_info.len });
// }
// return true;
//},
.Optional => |opt_info| {
var child = @as(opt_info.child, undefined);
if (visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &child, platform)) {
@ -576,23 +511,6 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
else => return err,
}
},
//.Union => |union_info| {
// // Note: the main issue here is that several fields could match but we only return the first one.
// inline for (union_info.fields) |field| {
// // interpret obj as a "field", and try to populate that.
// obj.* = @unionInit(T, field.name, undefined);
// const found = try _populateStruct(allocator, buffer_store, unique_id, prefix, &@field(obj.*, field.name), false);
// if (found) {
// std.log.info("Interpreted {s} as {s}", .{ prefix.items, @typeName(field.type) });
// return true;
// }
// }
// obj.* = undefined;
// if (required) {
// std.log.err("Not able to intepret {s} as any member of the union: {s}", .{ prefix.items, @typeName(T) });
// }
// return false;
//},
else => {},
}
}

View File

@ -95,10 +95,6 @@ pub const Decoder = struct {
}
fn parseOps(self: *Decoder, allocator: Allocator, seekable_stream: anytype) ![]PickleOp {
// TODO(SuperAuguste): deflate using `std.compress.flate`'s `decompressor`
// TODO(SuperAuguste): explore swapping in non-generic reader here instead of using switch(?)
// not sure if that'd actually be beneficial in any way
var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream);
var filename_buf: [std.fs.max_path_bytes]u8 = undefined;
while (try iter.next()) |entry| {

View File

@ -49,10 +49,6 @@ pub fn collectDims(
expected_dim.* = DIM_MISMATCH;
}
}
// TODO: strict mode:
// else if (mode == .strict) {
// @compileError("Found unexpected axis " ++ @tagName(a) ++ " when collecting " ++ @typeName(ShapeStruct(dims)));
// }
}
}
}).cb, &context, v);

View File

@ -190,126 +190,6 @@ pub const HostBuffer = struct {
res._shape = self._shape.reshape(shape_);
return res;
}
/// Python-style slice specification for one axis.
/// A negative `start` is interpreted relative to the axis length;
/// `end == null` means "up to the end of the axis".
pub const Slice = struct {
    // NOTE(review): never read by copySlice/copySlice1d below — confirm
    // whether single-element selection is still planned or this can go.
    single: ?i64 = null,
    start: i64 = 0,
    end: ?i64 = null,
    step: i64 = 1,
};
/// Copies a slice taken along a single `axis` (all other axes taken in full).
/// `axis` may be negative (resolved via `Shape.axis`). Delegates to
/// `copySlice`, so the same rank <= 5 limit applies.
pub inline fn copySlice1d(self: HostBuffer, allocator: std.mem.Allocator, axis: i8, _args: Slice) !HostBuffer {
    var slices = [_]Slice{.{}} ** 5;
    slices[self._shape.axis(axis)] = _args;
    return copySlice(self, allocator, slices[0..self._shape.rank()]);
}
/// Copies the elements selected by `slices` into a freshly allocated HostBuffer.
/// One `Slice` per leading axis; axes beyond `slices.len` are copied in full.
/// Negative `start` and `end` values are interpreted relative to the axis
/// length. Caller owns the returned buffer. Only ranks <= 5 are supported.
///
/// Fix: the destination strides were previously computed from the *source*
/// shape; whenever a slice shrank any axis other than the innermost one this
/// wrote past the end of the freshly allocated result buffer (the two
/// commented-out "failing" cases in the test below). They are now derived
/// from the sliced result shape `sh`. Negative `end` is also normalized,
/// mirroring the existing handling of negative `start`.
pub fn copySlice(self: HostBuffer, allocator: std.mem.Allocator, slices: []const Slice) !HostBuffer {
    const rk = self.rank();
    // Assert before touching the fixed-size length-5 scratch arrays below.
    meta.assert(rk <= 5, "copySlice only supports less than 5-D tensors. Received: {}", .{self});
    const byte_size = self.dtype().sizeOf();
    var start_indices = [_]usize{0} ** 5;
    var strides_ = [_]usize{1} ** 5;
    const dims = self._shape.dims();
    var sh = self._shape;
    for (slices, 0..) |_args, a| {
        // Normalize negative start/end against the axis length.
        const end_raw = _args.end orelse dims[a];
        const args: Slice = .{
            .start = if (_args.start >= 0) _args.start else _args.start + dims[a],
            .end = if (end_raw >= 0) end_raw else end_raw + dims[a],
            .step = _args.step,
        };
        start_indices[a] = @intCast(args.start);
        strides_[a] = @intCast(args.step);
        sh._dims.set(a, b: {
            // Number of selected elements: ceil((end - start) / step).
            const range = args.end.? - args.start;
            const counts = @divFloor(range - 1, args.step) + 1;
            break :b counts;
        });
    }
    // Byte strides of the source buffer.
    const raw_strides: [Shape.MAX_RANK]usize = blk: {
        var res: [Shape.MAX_RANK]usize = undefined;
        const _strides = self._shape.computeStrides(byte_size);
        for (_strides.constSlice(), 0..rk) |stride, i| res[i] = @intCast(stride);
        break :blk res;
    };
    const result_tensor = try HostBuffer.empty(allocator, sh);
    // Byte strides of the result buffer — computed from the sliced shape `sh`,
    // NOT from `self._shape` (that was the out-of-bounds bug).
    const res_strides: [Shape.MAX_RANK]usize = blk: {
        var res: [Shape.MAX_RANK]usize = undefined;
        const _strides = sh.computeStrides(byte_size);
        for (_strides.constSlice(), 0..rk) |stride, i| res[i] = @intCast(stride);
        break :blk res;
    };
    const src_data = self.data;
    const data_ = @constCast(result_tensor.data);
    // Up to 5 nested loops; each level short-circuits with a memcpy of one
    // element once the innermost rank is reached.
    for (0..@intCast(sh.dim(0))) |j0| {
        const off0 = (j0 * strides_[0] + start_indices[0]) * raw_strides[0];
        const res_off0 = j0 * res_strides[0];
        if (rk == 1) {
            @memcpy(data_[res_off0..][0..byte_size], src_data[off0..][0..byte_size]);
            continue;
        }
        for (0..@intCast(sh.dim(1))) |j1| {
            const off1 = off0 + (j1 * strides_[1] + start_indices[1]) * raw_strides[1];
            const res_off1 = res_off0 + j1 * res_strides[1];
            if (rk == 2) {
                @memcpy(data_[res_off1..][0..byte_size], src_data[off1..][0..byte_size]);
                continue;
            }
            for (0..@intCast(sh.dim(2))) |j2| {
                const off2 = off1 + (j2 * strides_[2] + start_indices[2]) * raw_strides[2];
                const res_off2 = res_off1 + j2 * res_strides[2];
                if (rk == 3) {
                    @memcpy(data_[res_off2..][0..byte_size], src_data[off2..][0..byte_size]);
                    continue;
                }
                for (0..@intCast(sh.dim(3))) |j3| {
                    const off3 = off2 + (j3 * strides_[3] + start_indices[3]) * raw_strides[3];
                    const res_off3 = res_off2 + j3 * res_strides[3];
                    if (rk == 4) {
                        @memcpy(data_[res_off3..][0..byte_size], src_data[off3..][0..byte_size]);
                        continue;
                    }
                    for (0..@intCast(sh.dim(4))) |j4| {
                        const off4 = off3 + (j4 * strides_[4] + start_indices[4]) * raw_strides[4];
                        const res_off4 = res_off3 + j4 * res_strides[4];
                        @memcpy(data_[res_off4..][0..byte_size], src_data[off4..][0..byte_size]);
                    }
                }
            }
        }
    }
    return result_tensor;
}
test copySlice {
    var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator);
    defer arena_state.deinit();
    const allocator = arena_state.allocator();
    // 2x5 row-major buffer: row 0 is {0..4}, row 1 is {5..9}.
    const x = HostBuffer.fromSlice(.{ 2, 5 }, &[_]f32{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
    {
        // Keep only the first row.
        const res = try copySlice1d(x, allocator, 0, .{ .end = 1 });
        try std.testing.expectEqualSlices(f32, &.{ 0, 1, 2, 3, 4 }, res.items(f32));
    }
    // NOTE(review): the two cases below fail because copySlice derives the
    // destination strides from the source shape instead of the sliced result
    // shape; re-enable them once that is fixed.
    // { // failing
    // const res = try copySlice1d(x, allocator, -1, .{ .start = -2 });
    // try testing.expectEqualSlices(f32, &.{ 3, 4, 8, 9 }, res.items(f32));
    // }
    // { // failing
    // const res = try copySlice1d(x, allocator, 1, .{ .start = 1, .step = 2 });
    // try testing.expectEqualSlices(f32, &.{ 1, 3, 6, 8 }, res.items(f32));
    // }
    {
        // Drop row 0, then every other column starting at 1 -> {6, 8}.
        const res = try copySlice(x, allocator, &.{ .{ .start = 1 }, .{ .start = 1, .step = 2 } });
        try std.testing.expectEqualSlices(f32, &.{ 6, 8 }, res.items(f32));
    }
}
};
fn parseArrayInfo(T: type) Shape {

View File

@ -746,9 +746,7 @@ fn compileInternal(
var timer = std.time.Timer.start() catch null;
const tensor_args = context.tensorFromShapes(ModuleSignature(func).ArgsT, arena, args);
// TODO: this is fast, doesn't make system call, and use mutable state.
// does it need to be async ?
// const f = try CompilationContext.generateBytecode(context, arena, "main", func, &model, &tensor_args, .{ .add_donations_attributes = true });
// Run in a dedicated thread because compilation relies on `threadlocal`.
const f = try asynk.callGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args, .{ .add_donations_attributes = true } });
context._module.getBody().appendOperation(f.mlir_fn);

View File

@ -218,13 +218,6 @@ test "real/img" {
const platform = zml.testing.env();
const Fns = struct {
// fn testSplitMergeIsId(impl: RopeOpts.Implementation) Tensor {
// const x = Tensor.arange(.{ .end = 20 }, .f32).reshape(.{ 5, 4 });
// const real, const imag = splitRealImg(x, impl);
// const y = mergeRealImg(real, imag, impl);
// return y.cmp(.EQ, x).flatten(0).convert(.i32).sum(-1);
// }
fn testSplitMergeIsId(impl: RopeOpts.Implementation) Tensor {
const x = Tensor.arange(.{ .end = 20 }, .f32).reshape(.{ 5, 4 });
const real, const imag = splitRealImg(x, impl);

View File

@ -547,17 +547,6 @@ fn _BlockSign(comptime func: anytype, blk_type: BlockType) BlockSignature {
if (i >= arg_start) {
n_tensors += staticCountTensors(ArgType) orelse @compileError("Can't use " ++ @typeName(ArgType) ++ " in an MLIR function, because it has a variable number of tensors");
}
// if (arg.type) |ArgType| {
// full_args[i] = ArgType;
// if (i >= arg_start) {
// n_tensors += staticCountTensors(ArgType) orelse @compileError("Can't use " ++ @typeName(ArgType) ++ " in an MLIR function, because it has a variable number of tensors");
// }
// } else {
// // anytype are considered to not have tensors.
// // violation of this will be detected when calling `compile()` but not at Zig compile time.
// full_args[i] = void;
// }
}
const FullArgs = std.meta.Tuple(&full_args);
const BlkCtx = switch (blk_type) {

View File

@ -167,34 +167,8 @@ pub const Client = opaque {
/// Thin wrapper: forwards to the underlying pjrt client's profiler.
pub fn getProfiler(self: *const Client, api: *const Api, options: pjrt.Profiler.Options) pjrt.Profiler {
    return self.inner().getProfiler(api, options);
}
// pub fn getGpuCustomCallRegistry(self: Client) ?GpuCustomCallRegistry {
// return switch (self.inner) {
// inline else => |v, tag| if (v.getGpuCustomCallRegistry()) |registry| GpuCustomCallRegistry.wrap(tag, registry) else null,
// };
// }
// pub fn getGpuCustomCallRegistry(self: *const Client, api: *const Api) ?*GpuCustomCallRegistry {
// if (api.lookupExtension(c.PJRT_Gpu_Custom_Call, c.PJRT_Extension_Type_Gpu_Custom_Call)) |ext| {
// return .{ .custom_call_register = ext.custom_call.? };
// }
// log.warn("No Gpu Custom Call registry found for platform: {}", .{self});
// return null;
// }
};
// pub const GpuCustomCallRegistry = struct {
// pub usingnamespace WrapperMixin(GpuCustomCallRegistry, pjrt.GpuCustomCallRegistry);
// inner: GpuCustomCallRegistry.UnionType,
// pub fn registerCustomCall(self: GpuCustomCallRegistry, api_version: usize, name: []const u8, func: pjrt.CustomCallSignature) ApiError!void {
// return switch (self.inner) {
// inline else => |v| v.registerCustomCall(api_version, name, func),
// };
// }
// };
pub const Buffer = opaque {
const inner = InnerMixin(pjrt.Buffer).inner;

View File

@ -348,9 +348,6 @@ pub const Shape = struct {
return self.dtype().sizeOf() * self.count();
}
// Aliases
pub const numel = count;
/// Returns true when both shapes have identical dimensions and dtype.
/// Tags are deliberately ignored in the comparison.
pub fn eql(self: Shape, other: Shape) bool {
    if (self.dtype() != other.dtype()) return false;
    return std.mem.eql(i64, self.dims(), other.dims());
}
);
}
/// Parses an anytype argument of the form `val` or `.{ .a = val }`.
/// Helps offering consistent API through ZML.
// pub fn parseTaggedValue(
// T: type,
// default_tag: EnumLiteral,
// d: anytype,
// ) struct { Tag, T } {
// const err_msg = "Expected one tagged dimension, received a tuple: " ++ @typeName(@TypeOf(d));
// return switch (@typeInfo(@TypeOf(d))) {
// .Int, .ComptimeInt => .{ toTag(default_tag), @intCast(d) },
// .Struct => |struct_info| {
// if (struct_info.fields.len != 1) @compileError(err_msg);
// const name = struct_info.fields[0].name;
// return .{ name.ptr, @intCast(@field(d, name)) };
// },
// else => @compileError(err_msg),
// };
// }
/// Parses a list of tags `.{ .a, .b, .c }` into a `[]Tag`
// pub inline fn parseTagList(comptime axes_: anytype) []Tag {
// switch (@typeInfo(@TypeOf(axes_))) {
// .Struct, .Array => {
// var _tags: [axes_.len]Tag = undefined;
// inline for (axes_, &_tags) |a, *t| t.* = toTag(a);
// return &_tags;
// },
// else => @compileError("Expected a tuple of enum literal, but found " ++ @tagName(@TypeOf(axes))),
// }
// }
/// Parses a comptime struct into a struct similarly to Shape.init,
/// but with a custom type in place of the `i64` dimensions.
/// Helps offering consistent API through ZML.
// pub fn parseShapedValue(T: type, value: anytype) struct {
// std.BoundedArray(Tag, MAX_RANK),
// std.BoundedArray(T, MAX_RANK),
// } {
// const too_long_err = std.fmt.comptimePrint("Received too many axes, maximum supported is {d}", .{MAX_RANK});
// var _tags: [MAX_RANK]Tag = [_]Tag{TagUnknown} ** MAX_RANK;
// const struct_info = switch (@typeInfo(@TypeOf(value))) {
// .Struct => |struct_info| struct_info,
// else => return .{
// .{ .len = 0, .buffer = _tags },
// std.BoundedArray(T, MAX_RANK).fromSlice(value) catch @panic(too_long_err),
// },
// };
// meta.assertComptime(struct_info.fields.len <= MAX_RANK, too_long_err, .{});
// var values: std.BoundedArray(T, MAX_RANK) = .{};
// inline for (struct_info.fields) |field| {
// if (T == Tag) {
// values.appendAssumeCapacity(toTag(@field(value, field.name)));
// } else {
// // If you have an error here it means Zig wasn't able to convert between the
// // value you passed and the expected `T`.
// values.appendAssumeCapacity(@field(value, field.name));
// }
// }
// if (!struct_info.is_tuple) {
// inline for (struct_info.fields, 0..) |field, i| {
// _tags[i] = toTag(field);
// }
// }
// return .{
// .{ .len = struct_info.fields.len, .buffer = _tags },
// values,
// };
// }
fn intersectTags(a: []const Tag, b: []const Tag) TagsArray {
var res = TagsArray.init(0) catch unreachable;
for (a) |tag_| {