zml: clean up dead and commented code; note that copyslice is currently broken and pending reimplementation
This commit is contained in:
parent
058e1415fa
commit
be6328813d
@ -288,79 +288,6 @@ fn elementTypeOrSelf(typ: mlir.Type) mlir.Type {
|
|||||||
} else typ;
|
} else typ;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn scatter(
|
|
||||||
ctx: mlir.Context,
|
|
||||||
// inputs
|
|
||||||
inputs: []const mlir.Value,
|
|
||||||
scatter_indices: mlir.Value,
|
|
||||||
updates: []const mlir.Value,
|
|
||||||
// input functions
|
|
||||||
update_ctx: anytype, // for update_fn
|
|
||||||
update_fn: fn (anytype, mlir.Context, []const mlir.Value, []const mlir.Value) mlir.Operation,
|
|
||||||
// attributes
|
|
||||||
args: struct {
|
|
||||||
update_window_dims: []const i64,
|
|
||||||
inserted_window_dims: []const i64,
|
|
||||||
input_batching_dims: []const i64,
|
|
||||||
scatter_indices_batching_dims: []const i64,
|
|
||||||
scatter_dims_to_operand_dims: []const i64,
|
|
||||||
index_vector_dim: i64,
|
|
||||||
indices_are_sorted: bool = false,
|
|
||||||
unique_indices: bool = false,
|
|
||||||
},
|
|
||||||
// zml loc
|
|
||||||
location: mlir.Location,
|
|
||||||
) mlir.Operation {
|
|
||||||
// create block for update_fn
|
|
||||||
const MaxBlockArguments = 32; // TODO(rene): where does this 32 come from?
|
|
||||||
// taken from reduce
|
|
||||||
|
|
||||||
const block_n_args = inputs.len * 2; // TODO(rene): is this correct? yes, passes tests: block_inputs plus block_accumulators = inputs
|
|
||||||
const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0..block_n_args];
|
|
||||||
var scatter_elem_types: [MaxBlockArguments]mlir.Type = undefined;
|
|
||||||
for (inputs, 0..) |input, i| {
|
|
||||||
const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type).?;
|
|
||||||
scatter_elem_types[i] = arg_type;
|
|
||||||
scatter_elem_types[inputs.len + i] = arg_type;
|
|
||||||
}
|
|
||||||
|
|
||||||
var block = mlir.Block.open(scatter_elem_types[0..block_n_args], locations) catch unreachable;
|
|
||||||
{
|
|
||||||
defer block.close();
|
|
||||||
var block_inputs: [MaxBlockArguments / 2]mlir.Value = undefined;
|
|
||||||
var block_accs: [MaxBlockArguments / 2]mlir.Value = undefined;
|
|
||||||
for (0..inputs.len) |i| {
|
|
||||||
block_inputs[i] = block.argument(i);
|
|
||||||
block_accs[i] = block.argument(inputs.len + i);
|
|
||||||
}
|
|
||||||
_ = update_fn(update_ctx, ctx, block_inputs[0..inputs.len], block_accs[0..inputs.len]);
|
|
||||||
}
|
|
||||||
return mlir.Operation.make(
|
|
||||||
ctx,
|
|
||||||
"stablehlo.scatter",
|
|
||||||
.{
|
|
||||||
.variadic_operands = &.{ inputs, &.{scatter_indices}, updates },
|
|
||||||
// .blocks = &.{block},
|
|
||||||
.block = block,
|
|
||||||
.attributes = &.{
|
|
||||||
.{ "scatter_dimension_numbers", ScatterDimensionNumbersAttribute.init(
|
|
||||||
ctx,
|
|
||||||
args.update_window_dims,
|
|
||||||
args.inserted_window_dims,
|
|
||||||
args.input_batching_dims,
|
|
||||||
args.scatter_indices_batching_dims,
|
|
||||||
args.scatter_dims_to_operand_dims,
|
|
||||||
args.index_vector_dim,
|
|
||||||
).as(mlir.Attribute).? },
|
|
||||||
.{ "indices_are_sorted", mlir.BoolAttribute.init(ctx, args.indices_are_sorted).as(mlir.Attribute).? },
|
|
||||||
.{ "unique_indices", mlir.BoolAttribute.init(ctx, args.unique_indices).as(mlir.Attribute).? },
|
|
||||||
},
|
|
||||||
.result_type_inference = true,
|
|
||||||
.location = location,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn iota(ctx: mlir.Context, dimension: i64, result_type: mlir.Type, location: mlir.Location) mlir.Operation {
|
pub fn iota(ctx: mlir.Context, dimension: i64, result_type: mlir.Type, location: mlir.Location) mlir.Operation {
|
||||||
return mlir.Operation.make(ctx, "stablehlo.iota", .{
|
return mlir.Operation.make(ctx, "stablehlo.iota", .{
|
||||||
.operands = &.{},
|
.operands = &.{},
|
||||||
@ -439,66 +366,6 @@ pub fn reduce(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const ReduceWindowOpts = struct {
|
|
||||||
window_dimensions: []const i64,
|
|
||||||
window_strides: []const i64,
|
|
||||||
base_dilations: []const i64,
|
|
||||||
window_dilations: []const i64,
|
|
||||||
padding_values: []const i64,
|
|
||||||
padding_shape: []const i64,
|
|
||||||
};
|
|
||||||
|
|
||||||
// pub fn reduce_window(
|
|
||||||
// ctx: mlir.Context,
|
|
||||||
// inputs: []const mlir.Value,
|
|
||||||
// init_values: []const mlir.Value,
|
|
||||||
// opts: ReduceWindowOpts,
|
|
||||||
// blkctx: anytype,
|
|
||||||
// blkfn: fn (anytype, mlir.Context, []const mlir.Value, []const mlir.Value) mlir.Operation,
|
|
||||||
// location: mlir.Location,
|
|
||||||
// ) mlir.Operation {
|
|
||||||
// // TODO: move to ops.zig, and refactor similar to `reduce`
|
|
||||||
// const MaxBlockArguments = 32;
|
|
||||||
|
|
||||||
// const block_n_args = inputs.len + init_values.len;
|
|
||||||
// const locations = ([_]mlir.Location{mlir.Location.unknown(ctx)} ** MaxBlockArguments)[0..block_n_args];
|
|
||||||
// var reduce_elem_types: [MaxBlockArguments]mlir.Type = undefined;
|
|
||||||
// for (inputs, 0..) |input, i| {
|
|
||||||
// const arg_type = mlir.RankedTensorType.init(&.{}, elementTypeOrSelf(input.getType())).as(mlir.Type).?;
|
|
||||||
// reduce_elem_types[i] = arg_type;
|
|
||||||
// reduce_elem_types[inputs.len + i] = arg_type;
|
|
||||||
// }
|
|
||||||
// const module = @import("../module.zig");
|
|
||||||
// const comp = module.getCompilationContext();
|
|
||||||
// var block = comp.openBlock(reduce_elem_types[0..block_n_args], locations) catch unreachable;
|
|
||||||
// {
|
|
||||||
// defer comp.closeBlock(block);
|
|
||||||
|
|
||||||
// var block_inputs: [MaxBlockArguments / 2]mlir.Value = undefined;
|
|
||||||
// var block_accs: [MaxBlockArguments / 2]mlir.Value = undefined;
|
|
||||||
// for (0..inputs.len) |i| {
|
|
||||||
// block_inputs[i] = block.argument(i);
|
|
||||||
// block_accs[i] = block.argument(inputs.len + i);
|
|
||||||
// }
|
|
||||||
// _ = blkfn(blkctx, ctx, block_inputs[0..inputs.len], block_accs[0..init_values.len]);
|
|
||||||
// }
|
|
||||||
|
|
||||||
// const pad_shape = mlir.RankedTensorType.init(opts.padding_shape, DataType.i64.mlirType(ctx)).as(mlir.Type).?;
|
|
||||||
// return mlir.Operation.make(ctx, "stablehlo.reduce_window", .{
|
|
||||||
// .variadic_operands = &.{ inputs, init_values },
|
|
||||||
// .result_type_inference = true,
|
|
||||||
// .blocks = &.{block},
|
|
||||||
// .attributes = &.{
|
|
||||||
// .{ "window_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_dimensions).as(mlir.Attribute).? },
|
|
||||||
// .{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_strides).as(mlir.Attribute).? },
|
|
||||||
// .{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx, opts.base_dilations).as(mlir.Attribute).? },
|
|
||||||
// .{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx, opts.window_dilations).as(mlir.Attribute).? },
|
|
||||||
// .{ "padding", mlir.DenseIntOrFPElementsAttribute(.i64).init(pad_shape, std.mem.sliceAsBytes(opts.padding_values)).as(mlir.Attribute).? },
|
|
||||||
// },
|
|
||||||
// .location = location,
|
|
||||||
// });
|
|
||||||
// }
|
|
||||||
|
|
||||||
pub fn sort(
|
pub fn sort(
|
||||||
ctx: mlir.Context,
|
ctx: mlir.Context,
|
||||||
inputs: []const mlir.Value,
|
inputs: []const mlir.Value,
|
||||||
|
|||||||
280
mlir/mlir.zig
280
mlir/mlir.zig
@ -691,10 +691,6 @@ pub const OperationState = struct {
|
|||||||
c.mlirOperationStateAddOwnedRegions(self.innerPtr(), @intCast(regions.len), @ptrCast(regions.ptr));
|
c.mlirOperationStateAddOwnedRegions(self.innerPtr(), @intCast(regions.len), @ptrCast(regions.ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub fn addSuccessor(self: *Self, successor: Operation) void {
|
|
||||||
// c.mlirOperationStateAddSuccessors(self.innerPtr(), 1, &[_]c.MlirOperation{successor.inner()});
|
|
||||||
// }
|
|
||||||
|
|
||||||
pub fn addAttribute(self: *Self, ctx: Context, name: [:0]const u8, attr: Attribute) void {
|
pub fn addAttribute(self: *Self, ctx: Context, name: [:0]const u8, attr: Attribute) void {
|
||||||
c.mlirOperationStateAddAttributes(self.innerPtr(), 1, @ptrCast(&.{
|
c.mlirOperationStateAddAttributes(self.innerPtr(), 1, @ptrCast(&.{
|
||||||
.{
|
.{
|
||||||
@ -745,9 +741,9 @@ pub const DictionaryAttribute = struct {
|
|||||||
return NamedAttribute.wrap(c.mlirDictionaryAttrGetElement(self.inner(), @intCast(pos)));
|
return NamedAttribute.wrap(c.mlirDictionaryAttrGetElement(self.inner(), @intCast(pos)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub fn getByName(self: Self, name: [:0]const u8) ?NamedAttribute {
|
pub fn getByName(self: Self, name: [:0]const u8) ?NamedAttribute {
|
||||||
// return NamedAttribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self.inner(), name));
|
return NamedAttribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self.inner(), name));
|
||||||
// }
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Operation = struct {
|
pub const Operation = struct {
|
||||||
@ -1519,276 +1515,6 @@ pub const DialectHandle = struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// pub const AnyQuantizedType = MlirWrapperType(c.MlirType, .{
|
|
||||||
// .is_a_fn = c.mlirTypeIsAAnyQuantizedType,
|
|
||||||
// .is_null_fn = c.mlirTypeIsNull,
|
|
||||||
// .dump_fn = c.mlirTypeDump,
|
|
||||||
// .equal_fn = c.mlirTypeEqual,
|
|
||||||
// }, struct {
|
|
||||||
// const Self = AnyQuantizedType;
|
|
||||||
|
|
||||||
// pub fn init(
|
|
||||||
// flags: quant.QuantizationFlags,
|
|
||||||
// storageType: Type,
|
|
||||||
// expressedType: Type,
|
|
||||||
// storageTypeMin: i64,
|
|
||||||
// storageTypeMax: i64,
|
|
||||||
// ) Self {
|
|
||||||
// return Self.wrap(c.mlirAnyQuantizedTypeGet(
|
|
||||||
// @intCast(@intFromEnum(flags)),
|
|
||||||
// storageType.inner(),
|
|
||||||
// expressedType.inner(),
|
|
||||||
// storageTypeMin,
|
|
||||||
// storageTypeMax,
|
|
||||||
// ));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getExpressedType(self: Self) Type {
|
|
||||||
// return Type.wrap(c.mlirQuantizedTypeGetExpressedType(self.inner()));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getFlags(self: Self) quant.QuantizationFlags {
|
|
||||||
// return @enumFromInt(c.mlirQuantizedTypeGetFlags(self.inner()));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn isSigned(self: Self) bool {
|
|
||||||
// return c.mlirQuantizedTypeIsSigned(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getStorageType(self: Self) Type {
|
|
||||||
// return Type.wrap(c.mlirQuantizedTypeGetStorageType(self.inner()));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getStorageTypeMin(self: Self) i64 {
|
|
||||||
// return c.mlirQuantizedTypeGetStorageTypeMin(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getStorageTypeMax(self: Self) i64 {
|
|
||||||
// return c.mlirQuantizedTypeGetStorageTypeMax(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getStorageTypeIntegralWidth(self: Self) c_uint {
|
|
||||||
// return c.mlirQuantizedTypeGetStorageTypeIntegralWidth(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getQuantizedElementType(self: Self) Type {
|
|
||||||
// return Type.wrap(c.mlirQuantizedTypeGetQuantizedElementType(self.inner()));
|
|
||||||
// }
|
|
||||||
// });
|
|
||||||
|
|
||||||
// pub const UniformQuantizedType = MlirWrapperType(c.MlirType, .{
|
|
||||||
// .is_a_fn = c.mlirTypeIsAUniformQuantizedType,
|
|
||||||
// .is_null_fn = c.mlirTypeIsNull,
|
|
||||||
// .dump_fn = c.mlirTypeDump,
|
|
||||||
// .equal_fn = c.mlirTypeEqual,
|
|
||||||
// }, struct {
|
|
||||||
// const Self = AnyQuantizedType;
|
|
||||||
|
|
||||||
// pub fn init(
|
|
||||||
// flags: quant.QuantizationFlags,
|
|
||||||
// storageType: Type,
|
|
||||||
// expressedType: Type,
|
|
||||||
// scale: f64,
|
|
||||||
// zeroPoint: i64,
|
|
||||||
// storageTypeMin: i64,
|
|
||||||
// storageTypeMax: i64,
|
|
||||||
// ) Self {
|
|
||||||
// return Self.wrap(c.mlirUniformQuantizedTypeGet(
|
|
||||||
// @intCast(@intFromEnum(flags)),
|
|
||||||
// storageType.inner(),
|
|
||||||
// expressedType.inner(),
|
|
||||||
// scale,
|
|
||||||
// zeroPoint,
|
|
||||||
// storageTypeMin,
|
|
||||||
// storageTypeMax,
|
|
||||||
// ));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getExpressedType(self: Self) Type {
|
|
||||||
// return Type.wrap(c.mlirQuantizedTypeGetExpressedType(self.inner()));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getFlags(self: Self) quant.QuantizationFlags {
|
|
||||||
// return @enumFromInt(c.mlirQuantizedTypeGetFlags(self.inner()));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn isSigned(self: Self) bool {
|
|
||||||
// return c.mlirQuantizedTypeIsSigned(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getStorageType(self: Self) Type {
|
|
||||||
// return Type.wrap(c.mlirQuantizedTypeGetStorageType(self.inner()));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getStorageTypeMin(self: Self) i64 {
|
|
||||||
// return c.mlirQuantizedTypeGetStorageTypeMin(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getStorageTypeMax(self: Self) i64 {
|
|
||||||
// return c.mlirQuantizedTypeGetStorageTypeMax(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getStorageTypeIntegralWidth(self: Self) c_uint {
|
|
||||||
// return c.mlirQuantizedTypeGetStorageTypeIntegralWidth(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getQuantizedElementType(self: Self) Type {
|
|
||||||
// return Type.wrap(c.mlirQuantizedTypeGetQuantizedElementType(self.inner()));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getScale(self: Self) f64 {
|
|
||||||
// return c.mlirUniformQuantizedTypeGetScale(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getZeroPoint(self: Self) i64 {
|
|
||||||
// return c.mlirUniformQuantizedTypeGetZeroPoint(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn isFixedPoint(self: Self) bool {
|
|
||||||
// return c.mlirUniformQuantizedTypeIsFixedPoint(self.inner());
|
|
||||||
// }
|
|
||||||
// });
|
|
||||||
|
|
||||||
// pub const QuantizedPerAxisType = MlirWrapperType(c.MlirType, .{
|
|
||||||
// .is_a_fn = c.mlirTypeIsAUniformQuantizedPerAxisType,
|
|
||||||
// .is_null_fn = c.mlirTypeIsNull,
|
|
||||||
// .dump_fn = c.mlirTypeDump,
|
|
||||||
// .equal_fn = c.mlirTypeEqual,
|
|
||||||
// }, struct {
|
|
||||||
// const Self = AnyQuantizedType;
|
|
||||||
|
|
||||||
// pub fn init(
|
|
||||||
// flags: quant.QuantizationFlags,
|
|
||||||
// storageType: Type,
|
|
||||||
// expressedType: Type,
|
|
||||||
// nDims: usize,
|
|
||||||
// scales: []f64,
|
|
||||||
// zeroPoints: []i64,
|
|
||||||
// quantizedDimension: i32,
|
|
||||||
// storageTypeMin: i64,
|
|
||||||
// storageTypeMax: i64,
|
|
||||||
// ) Self {
|
|
||||||
// std.debug.assert(scales.len == nDims);
|
|
||||||
// std.debug.assert(zeroPoints.len == nDims);
|
|
||||||
// return Self.wrap(c.mlirUniformQuantizedPerAxisTypeGet(
|
|
||||||
// @intCast(@intFromEnum(flags)),
|
|
||||||
// storageType.inner(),
|
|
||||||
// expressedType.inner(),
|
|
||||||
// @intCast(nDims),
|
|
||||||
// scales.ptr,
|
|
||||||
// zeroPoints.ptr,
|
|
||||||
// quantizedDimension,
|
|
||||||
// storageTypeMin,
|
|
||||||
// storageTypeMax,
|
|
||||||
// ));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getExpressedType(self: Self) Type {
|
|
||||||
// return Type.wrap(c.mlirQuantizedTypeGetExpressedType(self.inner()));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getFlags(self: Self) quant.QuantizationFlags {
|
|
||||||
// return @enumFromInt(c.mlirQuantizedTypeGetFlags(self.inner()));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn isSigned(self: Self) bool {
|
|
||||||
// return c.mlirQuantizedTypeIsSigned(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getStorageType(self: Self) Type {
|
|
||||||
// return Type.wrap(c.mlirQuantizedTypeGetStorageType(self.inner()));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getStorageTypeMin(self: Self) i64 {
|
|
||||||
// return c.mlirQuantizedTypeGetStorageTypeMin(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getStorageTypeMax(self: Self) i64 {
|
|
||||||
// return c.mlirQuantizedTypeGetStorageTypeMax(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getStorageTypeIntegralWidth(self: Self) c_uint {
|
|
||||||
// return c.mlirQuantizedTypeGetStorageTypeIntegralWidth(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getQuantizedElementType(self: Self) Type {
|
|
||||||
// return Type.wrap(c.mlirQuantizedTypeGetQuantizedElementType(self.inner()));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getNumDims(self: Self) usize {
|
|
||||||
// return @intCast(c.mlirUniformQuantizedPerAxisTypeGetNumDims(self.inner()));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getScale(self: Self) f64 {
|
|
||||||
// return @intCast(c.mlirUniformQuantizedPerAxisTypeGetScale(self.inner()));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getZeroPoint(self: Self, pos: usize) i64 {
|
|
||||||
// return c.mlirUniformQuantizedPerAxisTypeGetZeroPoint(self.inner(), @intCast(pos));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getQuantizedDimension(self: Self) i32 {
|
|
||||||
// return c.mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn isFixedPoint(self: Self) bool {
|
|
||||||
// return c.mlirUniformQuantizedPerAxisTypeIsFixedPoint(self.inner());
|
|
||||||
// }
|
|
||||||
// });
|
|
||||||
|
|
||||||
// pub const CalibratedQuantizedType = MlirWrapperType(c.MlirType, .{
|
|
||||||
// .is_a_fn = c.mlirTypeIsACalibratedQuantizedType,
|
|
||||||
// .is_null_fn = c.mlirTypeIsNull,
|
|
||||||
// .dump_fn = c.mlirTypeDump,
|
|
||||||
// .equal_fn = c.mlirTypeEqual,
|
|
||||||
// }, struct {
|
|
||||||
// const Self = AnyQuantizedType;
|
|
||||||
|
|
||||||
// pub fn init(expressedType: Type, min: f64, max: f64) Self {
|
|
||||||
// return Self.wrap(c.mlirCalibratedQuantizedTypeGet(expressedType.inner(), min, max));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getExpressedType(self: Self) Type {
|
|
||||||
// return Type.wrap(c.mlirQuantizedTypeGetExpressedType(self.inner()));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getFlags(self: Self) quant.QuantizationFlags {
|
|
||||||
// return @enumFromInt(c.mlirQuantizedTypeGetFlags(self.inner()));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn isSigned(self: Self) bool {
|
|
||||||
// return c.mlirQuantizedTypeIsSigned(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getStorageType(self: Self) Type {
|
|
||||||
// return Type.wrap(c.mlirQuantizedTypeGetStorageType(self.inner()));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getStorageTypeMin(self: Self) i64 {
|
|
||||||
// return c.mlirQuantizedTypeGetStorageTypeMin(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getStorageTypeMax(self: Self) i64 {
|
|
||||||
// return c.mlirQuantizedTypeGetStorageTypeMax(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getStorageTypeIntegralWidth(self: Self) c_uint {
|
|
||||||
// return c.mlirQuantizedTypeGetStorageTypeIntegralWidth(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getQuantizedElementType(self: Self) Type {
|
|
||||||
// return Type.wrap(c.mlirQuantizedTypeGetQuantizedElementType(self.inner()));
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getMin(self: Self) f64 {
|
|
||||||
// return c.mlirCalibratedQuantizedTypeGetMin(self.inner());
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getMax(self: Self) f64 {
|
|
||||||
// return c.mlirCalibratedQuantizedTypeGetMax(self.inner());
|
|
||||||
// }
|
|
||||||
// });
|
|
||||||
|
|
||||||
pub const ShapedType = struct {
|
pub const ShapedType = struct {
|
||||||
_inner: c.MlirType,
|
_inner: c.MlirType,
|
||||||
pub usingnamespace MlirHelpers(ShapedType, .{
|
pub usingnamespace MlirHelpers(ShapedType, .{
|
||||||
|
|||||||
@ -319,14 +319,6 @@ pub const Client = opaque {
|
|||||||
return Profiler.init(null, options);
|
return Profiler.init(null, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub fn getGpuCustomCallRegistry(self: *const Client, api: *const Api) ?*GpuCustomCallRegistry {
|
|
||||||
// if (api.lookupExtension(c.PJRT_Gpu_Custom_Call, c.PJRT_Extension_Type_Gpu_Custom_Call)) |ext| {
|
|
||||||
// return .{ .custom_call_register = ext.custom_call.? };
|
|
||||||
// }
|
|
||||||
// log.warn("No Gpu Custom Call registry found for platform: {}", .{self});
|
|
||||||
// return null;
|
|
||||||
// }
|
|
||||||
|
|
||||||
pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {
|
pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {
|
||||||
const ret = try api.call(.PJRT_Executable_DeserializeAndLoad, .{
|
const ret = try api.call(.PJRT_Executable_DeserializeAndLoad, .{
|
||||||
.client = self.inner(),
|
.client = self.inner(),
|
||||||
@ -365,32 +357,6 @@ pub const Client = opaque {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// // pub const CustomCallSignature = *const fn (*anyopaque, **anyopaque, [*c]const u8, usize) callconv(.C) void;
|
|
||||||
|
|
||||||
// // pub const GpuCustomCallRegistry = struct {
|
|
||||||
// // custom_call_register: *const c.PJRT_Gpu_Register_Custom_Call,
|
|
||||||
|
|
||||||
// // pub fn registerCustomCall(self: GpuCustomCallRegistry, api: *const Api, api_version: usize, name: []const u8, func: CustomCallSignature) ApiError!void {
|
|
||||||
// // var ret = pjrtStruct(c.PJRT_Gpu_Register_Custom_Call_Args{
|
|
||||||
// // .function_name = name.ptr,
|
|
||||||
// // .function_name_size = name.len,
|
|
||||||
// // .api_version = @intCast(api_version),
|
|
||||||
// // .custom_call_function = @ptrCast(@constCast(func)),
|
|
||||||
// // });
|
|
||||||
// // const result = self.custom_call_register(&ret);
|
|
||||||
// // if (result) |pjrt_c_error| {
|
|
||||||
// // const pjrt_error = .{ .inner = pjrt_c_error };
|
|
||||||
// // log.err("{s}", .{pjrt_error.getMessage(api)});
|
|
||||||
// // return pjrt_error.getCode().toApiError();
|
|
||||||
// // }
|
|
||||||
// // }
|
|
||||||
// // };
|
|
||||||
|
|
||||||
// // const OldPjrtExtension = extern struct {
|
|
||||||
// // type: c.PJRT_Extension_Type,
|
|
||||||
// // next: [*]OldPjrtExtension,
|
|
||||||
// // };
|
|
||||||
|
|
||||||
pub const Device = opaque {
|
pub const Device = opaque {
|
||||||
const inner = InnerMixin(c.PJRT_Device).inner;
|
const inner = InnerMixin(c.PJRT_Device).inner;
|
||||||
|
|
||||||
|
|||||||
@ -128,64 +128,6 @@ pub const Profiler = struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// If this was working it would be a good alternative to xspace_to_json.cc
|
|
||||||
// const xspace = @import("xspace.pb.zig");
|
|
||||||
// pub fn printDataAsXSpace(allocator: std.mem.Allocator, data: []const u8) void {
|
|
||||||
// var arena = std.heap.ArenaAllocator.init(allocator);
|
|
||||||
// defer arena.deinit();
|
|
||||||
//
|
|
||||||
// const space = xspace.XSpace.decode(data, arena.allocator()) catch |e| {
|
|
||||||
// std.log.err("Couldn't load profiling data: {}", .{e});
|
|
||||||
// return;
|
|
||||||
// };
|
|
||||||
//
|
|
||||||
// for (space.errors.items) |err| {
|
|
||||||
// std.log.err("{s}", .{err.getSlice()});
|
|
||||||
// }
|
|
||||||
// for (space.warnings.items) |warning| {
|
|
||||||
// std.log.warn("{s}", .{warning.getSlice()});
|
|
||||||
// }
|
|
||||||
// for (space.hostnames.items) |host| {
|
|
||||||
// std.log.info("Profiled host {s}", .{host.getSlice()});
|
|
||||||
// }
|
|
||||||
// for (space.planes.items) |plane| {
|
|
||||||
// var event_metadata = std.hash_map.AutoHashMap(i64, xspace.XEventMetadata).init(arena.allocator());
|
|
||||||
// event_metadata.ensureTotalCapacity(@intCast(plane.event_metadata.items.len)) catch return;
|
|
||||||
// defer event_metadata.deinit();
|
|
||||||
// for (plane.event_metadata.items) |event_meta_entry| {
|
|
||||||
// if (event_meta_entry.value) |event_meta| {
|
|
||||||
// event_metadata.putAssumeCapacity(event_meta.id, event_meta);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// std.log.info("Profiled device {s}", .{plane.name.getSlice()});
|
|
||||||
|
|
||||||
// for (plane.lines.items) |line| {
|
|
||||||
// std.log.info(
|
|
||||||
// "{d} -> {d} xline {s} ({d} events)",
|
|
||||||
// .{ line.timestamp_ns, line.duration_ps, line.name.getSlice(), line.events.items.len },
|
|
||||||
// );
|
|
||||||
// const ps_per_ns: i64 = 1000;
|
|
||||||
// var duration_ns: i64 = 0;
|
|
||||||
// var last_metadata_id: i64 = 0;
|
|
||||||
// for (line.events.items) |event| {
|
|
||||||
// if (event.metadata_id != last_metadata_id and duration_ns != 0) {
|
|
||||||
// const duration_us = @as(f32, @floatFromInt(duration_ns)) / std.time.ns_per_us;
|
|
||||||
// const meta = event_metadata.get(event.metadata_id).?;
|
|
||||||
// std.log.info("event {s}: {d:.1}μs", .{ meta.name.getSlice(), duration_us });
|
|
||||||
|
|
||||||
// last_metadata_id = event.metadata_id;
|
|
||||||
// duration_ns = 0;
|
|
||||||
// }
|
|
||||||
// duration_ns += @divFloor(event.duration_ps, ps_per_ns);
|
|
||||||
|
|
||||||
// const duration_us = @as(f32, @floatFromInt(duration_ns)) / std.time.ns_per_us;
|
|
||||||
// const meta = event_metadata.get(event.metadata_id).?;
|
|
||||||
// std.log.info("event {s}: {d:.1}μs", .{ meta.name.getSlice(), duration_us });
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
const ProfilingData = union(enum) {
|
const ProfilingData = union(enum) {
|
||||||
owned: []const u8,
|
owned: []const u8,
|
||||||
external: []const u8,
|
external: []const u8,
|
||||||
|
|||||||
82
zml/aio.zig
82
zml/aio.zig
@ -300,23 +300,12 @@ fn _populateStruct(
|
|||||||
log.warn("No layer found at {s}", .{prefix});
|
log.warn("No layer found at {s}", .{prefix});
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
} else if (ptr_info.size == .One) {
|
|
||||||
//if (ptr_info.child != zml.Tensor and ptr_info.child != ?zml.Tensor) {
|
|
||||||
// // Note: should we recurse on all pointers ?
|
|
||||||
// log.warn("Not looking into: {any}", .{prefix});
|
|
||||||
// return false;
|
|
||||||
//}
|
|
||||||
//obj.* = try allocator.create(ptr_info.child);
|
|
||||||
//return try _populateStruct(allocator, buffer_store, unique_id, prefix, obj.*, required);
|
|
||||||
} else {
|
} else {
|
||||||
std.log.err("{s} - {s}: {s} type not supported", .{ @src().fn_name, prefix, @typeName(T) });
|
std.log.err("{s} - {s}: {s} type not supported", .{ @src().fn_name, prefix, @typeName(T) });
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
.Struct => |struct_info| {
|
.Struct => |struct_info| {
|
||||||
// TODO(Corentin): See if we keep that
|
|
||||||
//if (@hasDecl(T, "_zml_reader_skip_me_")) return false;
|
|
||||||
|
|
||||||
var partial_struct = false;
|
var partial_struct = false;
|
||||||
inline for (struct_info.fields) |field| {
|
inline for (struct_info.fields) |field| {
|
||||||
try prefix_builder.push(allocator, field.name);
|
try prefix_builder.push(allocator, field.name);
|
||||||
@ -343,46 +332,12 @@ fn _populateStruct(
|
|||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
//.Array => |array_info| {
|
|
||||||
// var new_prefix = prefix;
|
|
||||||
// if (prefix.items.len > 0)
|
|
||||||
// new_prefix.appendAssumeCapacity('.');
|
|
||||||
// const len = new_prefix.items.len;
|
|
||||||
// for (obj, 0..) |*value, i| {
|
|
||||||
// new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
|
|
||||||
// const found = try _populateStruct(allocator, buffer_store, unique_id, new_prefix, value, required);
|
|
||||||
// if (!found) return false;
|
|
||||||
// new_prefix.shrinkRetainingCapacity(len);
|
|
||||||
// }
|
|
||||||
// const num_layers = buffer_store.numLayers(prefix.items);
|
|
||||||
// if (num_layers != array_info.len) {
|
|
||||||
// log.warn("Found {d} layers with prefix {s}, but only loaded {d}", .{ num_layers, prefix.items, array_info.len });
|
|
||||||
// }
|
|
||||||
// return true;
|
|
||||||
//},
|
|
||||||
.Optional => |opt_info| {
|
.Optional => |opt_info| {
|
||||||
obj.* = @as(opt_info.child, undefined);
|
obj.* = @as(opt_info.child, undefined);
|
||||||
const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &(obj.*.?), false);
|
const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &(obj.*.?), false);
|
||||||
if (!found) obj.* = null;
|
if (!found) obj.* = null;
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
//.Union => |union_info| {
|
|
||||||
// // Note: the main issue here is that several fields could match but we only return the first one.
|
|
||||||
// inline for (union_info.fields) |field| {
|
|
||||||
// // interpret obj as a "field", and try to populate that.
|
|
||||||
// obj.* = @unionInit(T, field.name, undefined);
|
|
||||||
// const found = try _populateStruct(allocator, buffer_store, unique_id, prefix, &@field(obj.*, field.name), false);
|
|
||||||
// if (found) {
|
|
||||||
// std.log.info("Interpreted {s} as {s}", .{ prefix.items, @typeName(field.type) });
|
|
||||||
// return true;
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// obj.* = undefined;
|
|
||||||
// if (required) {
|
|
||||||
// std.log.err("Not able to intepret {s} as any member of the union: {s}", .{ prefix.items, @typeName(T) });
|
|
||||||
// }
|
|
||||||
// return false;
|
|
||||||
//},
|
|
||||||
.Int => {
|
.Int => {
|
||||||
obj.* = undefined;
|
obj.* = undefined;
|
||||||
return true;
|
return true;
|
||||||
@ -540,9 +495,6 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
} else return error.TypeNotSupported;
|
} else return error.TypeNotSupported;
|
||||||
},
|
},
|
||||||
.Struct => |struct_info| {
|
.Struct => |struct_info| {
|
||||||
// TODO(Corentin): See if we keep that
|
|
||||||
//if (@hasDecl(T, "_zml_reader_skip_me_")) return false;
|
|
||||||
|
|
||||||
inline for (struct_info.fields) |field| {
|
inline for (struct_info.fields) |field| {
|
||||||
try prefix_builder.push(allocator, field.name);
|
try prefix_builder.push(allocator, field.name);
|
||||||
defer prefix_builder.pop();
|
defer prefix_builder.pop();
|
||||||
@ -550,23 +502,6 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &@field(obj, field.name), platform);
|
try visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &@field(obj, field.name), platform);
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
//.Array => |array_info| {
|
|
||||||
// var new_prefix = prefix;
|
|
||||||
// if (prefix.items.len > 0)
|
|
||||||
// new_prefix.appendAssumeCapacity('.');
|
|
||||||
// const len = new_prefix.items.len;
|
|
||||||
// for (obj, 0..) |*value, i| {
|
|
||||||
// new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
|
|
||||||
// const found = try _populateStruct(allocator, buffer_store, unique_id, new_prefix, value, required);
|
|
||||||
// if (!found) return false;
|
|
||||||
// new_prefix.shrinkRetainingCapacity(len);
|
|
||||||
// }
|
|
||||||
// const num_layers = buffer_store.numLayers(prefix.items);
|
|
||||||
// if (num_layers != array_info.len) {
|
|
||||||
// log.warn("Found {d} layers with prefix {s}, but only loaded {d}", .{ num_layers, prefix.items, array_info.len });
|
|
||||||
// }
|
|
||||||
// return true;
|
|
||||||
//},
|
|
||||||
.Optional => |opt_info| {
|
.Optional => |opt_info| {
|
||||||
var child = @as(opt_info.child, undefined);
|
var child = @as(opt_info.child, undefined);
|
||||||
if (visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &child, platform)) {
|
if (visitStructAndLoadBuffer(allocator, prefix_builder, buffer_store, &child, platform)) {
|
||||||
@ -576,23 +511,6 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
else => return err,
|
else => return err,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
//.Union => |union_info| {
|
|
||||||
// // Note: the main issue here is that several fields could match but we only return the first one.
|
|
||||||
// inline for (union_info.fields) |field| {
|
|
||||||
// // interpret obj as a "field", and try to populate that.
|
|
||||||
// obj.* = @unionInit(T, field.name, undefined);
|
|
||||||
// const found = try _populateStruct(allocator, buffer_store, unique_id, prefix, &@field(obj.*, field.name), false);
|
|
||||||
// if (found) {
|
|
||||||
// std.log.info("Interpreted {s} as {s}", .{ prefix.items, @typeName(field.type) });
|
|
||||||
// return true;
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// obj.* = undefined;
|
|
||||||
// if (required) {
|
|
||||||
// std.log.err("Not able to intepret {s} as any member of the union: {s}", .{ prefix.items, @typeName(T) });
|
|
||||||
// }
|
|
||||||
// return false;
|
|
||||||
//},
|
|
||||||
else => {},
|
else => {},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -95,10 +95,6 @@ pub const Decoder = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn parseOps(self: *Decoder, allocator: Allocator, seekable_stream: anytype) ![]PickleOp {
|
fn parseOps(self: *Decoder, allocator: Allocator, seekable_stream: anytype) ![]PickleOp {
|
||||||
// TODO(SuperAuguste): deflate using `std.compress.flate`'s `decompressor`
|
|
||||||
// TODO(SuperAuguste): explore swapping in non-generic reader here instead of using switch(?)
|
|
||||||
// not sure if that'd actually be beneficial in any way
|
|
||||||
|
|
||||||
var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream);
|
var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream);
|
||||||
var filename_buf: [std.fs.max_path_bytes]u8 = undefined;
|
var filename_buf: [std.fs.max_path_bytes]u8 = undefined;
|
||||||
while (try iter.next()) |entry| {
|
while (try iter.next()) |entry| {
|
||||||
|
|||||||
@ -49,10 +49,6 @@ pub fn collectDims(
|
|||||||
expected_dim.* = DIM_MISMATCH;
|
expected_dim.* = DIM_MISMATCH;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// TODO: strict mode:
|
|
||||||
// else if (mode == .strict) {
|
|
||||||
// @compileError("Found unexpected axis " ++ @tagName(a) ++ " when collecting " ++ @typeName(ShapeStruct(dims)));
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}).cb, &context, v);
|
}).cb, &context, v);
|
||||||
|
|||||||
@ -190,126 +190,6 @@ pub const HostBuffer = struct {
|
|||||||
res._shape = self._shape.reshape(shape_);
|
res._shape = self._shape.reshape(shape_);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Per-axis selection passed to `copySlice` and `copySlice1d`.
/// All fields have defaults, so `.{}` selects a whole axis.
pub const Slice = struct {
|
|
||||||
/// NOTE(review): not read by `copySlice`; presumably meant to select one index — confirm before relying on it.
single: ?i64 = null,
|
|
||||||
/// First selected index. `copySlice` adds the axis dimension to negative values.
start: i64 = 0,
|
|
||||||
/// Exclusive upper bound. `copySlice` substitutes the axis dimension when this is null.
end: ?i64 = null,
|
|
||||||
/// Distance between two consecutive selected indices.
step: i64 = 1,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Slices `self` along a single axis and copies the selection into a new buffer.
///
/// `axis` is resolved through `self._shape.axis(axis)`. Every other axis gets a
/// default `Slice` (`.{}`), and the work is delegated to `copySlice`.
/// Errors returned by `copySlice` are propagated unchanged.
pub inline fn copySlice1d(self: HostBuffer, allocator: std.mem.Allocator, axis: i8, _args: Slice) !HostBuffer {
    const target_axis = self._shape.axis(axis);
    const rank_ = self._shape.rank();

    // Start from "take everything" on all axes, then override the requested one.
    var per_axis = [_]Slice{.{}} ** 5;
    per_axis[target_axis] = _args;

    return copySlice(self, allocator, per_axis[0..rank_]);
}
|
|
||||||
|
|
||||||
/// Copies a strided sub-region of `self` into a newly allocated `HostBuffer`.
///
/// `slices[a]` describes the selection along axis `a`:
///   - a negative `start` is offset by the dimension of that axis,
///   - a null `end` is replaced by the dimension of that axis,
///   - `step` is the distance between two selected indices.
/// Axes not covered by `slices` are copied in full. `Slice.single` is not read here.
///
/// Only tensors of rank <= 5 are supported; this is asserted up front.
/// The result is allocated with `HostBuffer.empty(allocator, ...)`, and an allocation
/// failure is returned to the caller.
pub fn copySlice(self: HostBuffer, allocator: std.mem.Allocator, slices: []const Slice) !HostBuffer {
    // Check the rank before indexing the fixed 5-entry scratch arrays below.
    const rk = self.rank();
    meta.assert(rk <= 5, "copySlice only supports less than 5-D tensors. Received: {}", .{self});

    const byte_size = self.dtype().sizeOf();
    var start_indices = [_]usize{0} ** 5;
    var strides_ = [_]usize{1} ** 5;
    const dims = self._shape.dims();
    // `sh` starts as the source shape and is narrowed, axis by axis, into the result shape.
    var sh = self._shape;

    for (slices, 0..) |_args, a| {
        const args: Slice = .{
            // A negative start counts from the end of the axis.
            .start = if (_args.start >= 0) _args.start else _args.start + dims[a],
            .end = _args.end orelse dims[a],
            .step = _args.step,
        };
        start_indices[a] = @intCast(args.start);
        strides_[a] = @intCast(args.step);
        // Number of selected elements; for a positive step this is ceil((end - start) / step).
        sh._dims.set(a, b: {
            const range = args.end.? - args.start;
            const counts = @divFloor(range - 1, args.step) + 1;
            break :b counts;
        });
    }

    // Byte strides used to read from the source buffer.
    const raw_strides: [Shape.MAX_RANK]usize = blk: {
        var res: [Shape.MAX_RANK]usize = undefined;
        const _strides = self._shape.computeStrides(byte_size);
        for (_strides.constSlice(), 0..rk) |stride, i| res[i] = @intCast(stride);
        break :blk res;
    };

    const result_tensor = try HostBuffer.empty(allocator, sh);

    // Byte strides used to write into the result buffer. They must come from the
    // *result* shape `sh`: using the source shape here places every row after the
    // first at the wrong offset and can write past the end of the result.
    const res_strides: [Shape.MAX_RANK]usize = blk: {
        var res: [Shape.MAX_RANK]usize = undefined;
        const _strides = sh.computeStrides(byte_size);
        for (_strides.constSlice(), 0..rk) |stride, i| res[i] = @intCast(stride);
        break :blk res;
    };

    const src_data = self.data;
    const data_ = @constCast(result_tensor.data);
    // One nested loop per supported axis; the element copy happens at depth `rk`.
    for (0..@intCast(sh.dim(0))) |j0| {
        const off0 = (j0 * strides_[0] + start_indices[0]) * raw_strides[0];
        const res_off0 = j0 * res_strides[0];
        if (rk == 1) {
            @memcpy(data_[res_off0..][0..byte_size], src_data[off0..][0..byte_size]);
            continue;
        }
        for (0..@intCast(sh.dim(1))) |j1| {
            const off1 = off0 + (j1 * strides_[1] + start_indices[1]) * raw_strides[1];
            const res_off1 = res_off0 + j1 * res_strides[1];
            if (rk == 2) {
                @memcpy(data_[res_off1..][0..byte_size], src_data[off1..][0..byte_size]);
                continue;
            }
            for (0..@intCast(sh.dim(2))) |j2| {
                const off2 = off1 + (j2 * strides_[2] + start_indices[2]) * raw_strides[2];
                const res_off2 = res_off1 + j2 * res_strides[2];
                if (rk == 3) {
                    @memcpy(data_[res_off2..][0..byte_size], src_data[off2..][0..byte_size]);
                    continue;
                }
                for (0..@intCast(sh.dim(3))) |j3| {
                    const off3 = off2 + (j3 * strides_[3] + start_indices[3]) * raw_strides[3];
                    const res_off3 = res_off2 + j3 * res_strides[3];
                    if (rk == 4) {
                        @memcpy(data_[res_off3..][0..byte_size], src_data[off3..][0..byte_size]);
                        continue;
                    }
                    for (0..@intCast(sh.dim(4))) |j4| {
                        const off4 = off3 + (j4 * strides_[4] + start_indices[4]) * raw_strides[4];
                        const res_off4 = res_off3 + j4 * res_strides[4];
                        @memcpy(data_[res_off4..][0..byte_size], src_data[off4..][0..byte_size]);
                    }
                }
            }
        }
    }

    return result_tensor;
}
|
|
||||||
|
|
||||||
test copySlice {
    // All result buffers come from one arena and are released together on exit.
    var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
    defer arena.deinit();
    const arena_allocator = arena.allocator();

    // 2x5 input holding 0..9 in row-major order.
    const input = HostBuffer.fromSlice(.{ 2, 5 }, &[_]f32{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });

    // Keep only the first row.
    {
        const first_row = try copySlice1d(input, arena_allocator, 0, .{ .end = 1 });
        try std.testing.expectEqualSlices(f32, &.{ 0, 1, 2, 3, 4 }, first_row.items(f32));
    }
    // Disabled: marked as failing.
    // {
    //     const last_two = try copySlice1d(input, arena_allocator, -1, .{ .start = -2 });
    //     try std.testing.expectEqualSlices(f32, &.{ 3, 4, 8, 9 }, last_two.items(f32));
    // }
    // Disabled: marked as failing.
    // {
    //     const odd_columns = try copySlice1d(input, arena_allocator, 1, .{ .start = 1, .step = 2 });
    //     try std.testing.expectEqualSlices(f32, &.{ 1, 3, 6, 8 }, odd_columns.items(f32));
    // }
    // Second row only, every other column starting at column 1.
    {
        const strided = try copySlice(input, arena_allocator, &.{ .{ .start = 1 }, .{ .start = 1, .step = 2 } });
        try std.testing.expectEqualSlices(f32, &.{ 6, 8 }, strided.items(f32));
    }
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
fn parseArrayInfo(T: type) Shape {
|
fn parseArrayInfo(T: type) Shape {
|
||||||
|
|||||||
@ -746,9 +746,7 @@ fn compileInternal(
|
|||||||
|
|
||||||
var timer = std.time.Timer.start() catch null;
|
var timer = std.time.Timer.start() catch null;
|
||||||
const tensor_args = context.tensorFromShapes(ModuleSignature(func).ArgsT, arena, args);
|
const tensor_args = context.tensorFromShapes(ModuleSignature(func).ArgsT, arena, args);
|
||||||
// TODO: this is fast, doesn't make system call, and use mutable state.
|
// Run in a dedicated thread because compilation relies on `threadlocal`.
|
||||||
// does it need to be async ?
|
|
||||||
// const f = try CompilationContext.generateBytecode(context, arena, "main", func, &model, &tensor_args, .{ .add_donations_attributes = true });
|
|
||||||
const f = try asynk.callGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args, .{ .add_donations_attributes = true } });
|
const f = try asynk.callGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args, .{ .add_donations_attributes = true } });
|
||||||
context._module.getBody().appendOperation(f.mlir_fn);
|
context._module.getBody().appendOperation(f.mlir_fn);
|
||||||
|
|
||||||
|
|||||||
@ -218,13 +218,6 @@ test "real/img" {
|
|||||||
const platform = zml.testing.env();
|
const platform = zml.testing.env();
|
||||||
|
|
||||||
const Fns = struct {
|
const Fns = struct {
|
||||||
// fn testSplitMergeIsId(impl: RopeOpts.Implementation) Tensor {
|
|
||||||
// const x = Tensor.arange(.{ .end = 20 }, .f32).reshape(.{ 5, 4 });
|
|
||||||
// const real, const imag = splitRealImg(x, impl);
|
|
||||||
// const y = mergeRealImg(real, imag, impl);
|
|
||||||
// return y.cmp(.EQ, x).flatten(0).convert(.i32).sum(-1);
|
|
||||||
// }
|
|
||||||
|
|
||||||
fn testSplitMergeIsId(impl: RopeOpts.Implementation) Tensor {
|
fn testSplitMergeIsId(impl: RopeOpts.Implementation) Tensor {
|
||||||
const x = Tensor.arange(.{ .end = 20 }, .f32).reshape(.{ 5, 4 });
|
const x = Tensor.arange(.{ .end = 20 }, .f32).reshape(.{ 5, 4 });
|
||||||
const real, const imag = splitRealImg(x, impl);
|
const real, const imag = splitRealImg(x, impl);
|
||||||
|
|||||||
11
zml/ops.zig
11
zml/ops.zig
@ -547,17 +547,6 @@ fn _BlockSign(comptime func: anytype, blk_type: BlockType) BlockSignature {
|
|||||||
if (i >= arg_start) {
|
if (i >= arg_start) {
|
||||||
n_tensors += staticCountTensors(ArgType) orelse @compileError("Can't use " ++ @typeName(ArgType) ++ " in an MLIR function, because it has a variable number of tensors");
|
n_tensors += staticCountTensors(ArgType) orelse @compileError("Can't use " ++ @typeName(ArgType) ++ " in an MLIR function, because it has a variable number of tensors");
|
||||||
}
|
}
|
||||||
|
|
||||||
// if (arg.type) |ArgType| {
|
|
||||||
// full_args[i] = ArgType;
|
|
||||||
// if (i >= arg_start) {
|
|
||||||
// n_tensors += staticCountTensors(ArgType) orelse @compileError("Can't use " ++ @typeName(ArgType) ++ " in an MLIR function, because it has a variable number of tensors");
|
|
||||||
// }
|
|
||||||
// } else {
|
|
||||||
// // anytype are considered to not have tensors.
|
|
||||||
// // violation of this will be detected when calling `compile()` but not at Zig compile time.
|
|
||||||
// full_args[i] = void;
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
const FullArgs = std.meta.Tuple(&full_args);
|
const FullArgs = std.meta.Tuple(&full_args);
|
||||||
const BlkCtx = switch (blk_type) {
|
const BlkCtx = switch (blk_type) {
|
||||||
|
|||||||
@ -167,34 +167,8 @@ pub const Client = opaque {
|
|||||||
/// Returns the profiler produced by `self.inner().getProfiler(api, options)`;
/// `api` and `options` are forwarded unchanged.
pub fn getProfiler(self: *const Client, api: *const Api, options: pjrt.Profiler.Options) pjrt.Profiler {
|
pub fn getProfiler(self: *const Client, api: *const Api, options: pjrt.Profiler.Options) pjrt.Profiler {
|
||||||
return self.inner().getProfiler(api, options);
|
return self.inner().getProfiler(api, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub fn getGpuCustomCallRegistry(self: Client) ?GpuCustomCallRegistry {
|
|
||||||
// return switch (self.inner) {
|
|
||||||
// inline else => |v, tag| if (v.getGpuCustomCallRegistry()) |registry| GpuCustomCallRegistry.wrap(tag, registry) else null,
|
|
||||||
// };
|
|
||||||
// }
|
|
||||||
|
|
||||||
// pub fn getGpuCustomCallRegistry(self: *const Client, api: *const Api) ?*GpuCustomCallRegistry {
|
|
||||||
// if (api.lookupExtension(c.PJRT_Gpu_Custom_Call, c.PJRT_Extension_Type_Gpu_Custom_Call)) |ext| {
|
|
||||||
// return .{ .custom_call_register = ext.custom_call.? };
|
|
||||||
// }
|
|
||||||
// log.warn("No Gpu Custom Call registry found for platform: {}", .{self});
|
|
||||||
// return null;
|
|
||||||
// }
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// pub const GpuCustomCallRegistry = struct {
|
|
||||||
// pub usingnamespace WrapperMixin(GpuCustomCallRegistry, pjrt.GpuCustomCallRegistry);
|
|
||||||
|
|
||||||
// inner: GpuCustomCallRegistry.UnionType,
|
|
||||||
|
|
||||||
// pub fn registerCustomCall(self: GpuCustomCallRegistry, api_version: usize, name: []const u8, func: pjrt.CustomCallSignature) ApiError!void {
|
|
||||||
// return switch (self.inner) {
|
|
||||||
// inline else => |v| v.registerCustomCall(api_version, name, func),
|
|
||||||
// };
|
|
||||||
// }
|
|
||||||
// };
|
|
||||||
|
|
||||||
pub const Buffer = opaque {
|
pub const Buffer = opaque {
|
||||||
const inner = InnerMixin(pjrt.Buffer).inner;
|
const inner = InnerMixin(pjrt.Buffer).inner;
|
||||||
|
|
||||||
|
|||||||
@ -348,9 +348,6 @@ pub const Shape = struct {
|
|||||||
return self.dtype().sizeOf() * self.count();
|
return self.dtype().sizeOf() * self.count();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Aliases
|
|
||||||
pub const numel = count;
|
|
||||||
|
|
||||||
/// Compares the two shapes described, ignoring tagging.
|
/// Compares the two shapes described, ignoring tagging.
|
||||||
pub fn eql(self: Shape, other: Shape) bool {
|
pub fn eql(self: Shape, other: Shape) bool {
|
||||||
return std.mem.eql(i64, self.dims(), other.dims()) and self.dtype() == other.dtype();
|
return std.mem.eql(i64, self.dims(), other.dims()) and self.dtype() == other.dtype();
|
||||||
@ -883,78 +880,6 @@ pub const Shape = struct {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parses an anytype argument of the form `val` or `.{ .a = val }`.
|
|
||||||
/// Helps offering consistent API through ZML.
|
|
||||||
// pub fn parseTaggedValue(
|
|
||||||
// T: type,
|
|
||||||
// default_tag: EnumLiteral,
|
|
||||||
// d: anytype,
|
|
||||||
// ) struct { Tag, T } {
|
|
||||||
// const err_msg = "Expected one tagged dimension, received a tuple: " ++ @typeName(@TypeOf(d));
|
|
||||||
// return switch (@typeInfo(@TypeOf(d))) {
|
|
||||||
// .Int, .ComptimeInt => .{ toTag(default_tag), @intCast(d) },
|
|
||||||
// .Struct => |struct_info| {
|
|
||||||
// if (struct_info.fields.len != 1) @compileError(err_msg);
|
|
||||||
// const name = struct_info.fields[0].name;
|
|
||||||
// return .{ name.ptr, @intCast(@field(d, name)) };
|
|
||||||
// },
|
|
||||||
// else => @compileError(err_msg),
|
|
||||||
// };
|
|
||||||
// }
|
|
||||||
|
|
||||||
/// Parses a list of tags `.{ .a, .b, .c }` into a `[]Tag`
|
|
||||||
// pub inline fn parseTagList(comptime axes_: anytype) []Tag {
|
|
||||||
// switch (@typeInfo(@TypeOf(axes_))) {
|
|
||||||
// .Struct, .Array => {
|
|
||||||
// var _tags: [axes_.len]Tag = undefined;
|
|
||||||
// inline for (axes_, &_tags) |a, *t| t.* = toTag(a);
|
|
||||||
// return &_tags;
|
|
||||||
// },
|
|
||||||
// else => @compileError("Expected a tuple of enum literal, but found " ++ @tagName(@TypeOf(axes))),
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
/// Parses a comptime struct into a struct similarly to Shape.init,
|
|
||||||
/// but with a custom type in place of the `i64` dimensions.
|
|
||||||
/// Helps offering consistent API through ZML.
|
|
||||||
// pub fn parseShapedValue(T: type, value: anytype) struct {
|
|
||||||
// std.BoundedArray(Tag, MAX_RANK),
|
|
||||||
// std.BoundedArray(T, MAX_RANK),
|
|
||||||
// } {
|
|
||||||
// const too_long_err = std.fmt.comptimePrint("Received too many axes, maximum supported is {d}", .{MAX_RANK});
|
|
||||||
|
|
||||||
// var _tags: [MAX_RANK]Tag = [_]Tag{TagUnknown} ** MAX_RANK;
|
|
||||||
// const struct_info = switch (@typeInfo(@TypeOf(value))) {
|
|
||||||
// .Struct => |struct_info| struct_info,
|
|
||||||
// else => return .{
|
|
||||||
// .{ .len = 0, .buffer = _tags },
|
|
||||||
// std.BoundedArray(T, MAX_RANK).fromSlice(value) catch @panic(too_long_err),
|
|
||||||
// },
|
|
||||||
// };
|
|
||||||
|
|
||||||
// meta.assertComptime(struct_info.fields.len <= MAX_RANK, too_long_err, .{});
|
|
||||||
|
|
||||||
// var values: std.BoundedArray(T, MAX_RANK) = .{};
|
|
||||||
// inline for (struct_info.fields) |field| {
|
|
||||||
// if (T == Tag) {
|
|
||||||
// values.appendAssumeCapacity(toTag(@field(value, field.name)));
|
|
||||||
// } else {
|
|
||||||
// // If you have an error here it means Zig wasn't able to convert between the
|
|
||||||
// // value you passed and the expected `T`.
|
|
||||||
// values.appendAssumeCapacity(@field(value, field.name));
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// if (!struct_info.is_tuple) {
|
|
||||||
// inline for (struct_info.fields, 0..) |field, i| {
|
|
||||||
// _tags[i] = toTag(field);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// return .{
|
|
||||||
// .{ .len = struct_info.fields.len, .buffer = _tags },
|
|
||||||
// values,
|
|
||||||
// };
|
|
||||||
// }
|
|
||||||
|
|
||||||
fn intersectTags(a: []const Tag, b: []const Tag) TagsArray {
|
fn intersectTags(a: []const Tag, b: []const Tag) TagsArray {
|
||||||
var res = TagsArray.init(0) catch unreachable;
|
var res = TagsArray.init(0) catch unreachable;
|
||||||
for (a) |tag_| {
|
for (a) |tag_| {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user