Remove usingnamespace from MLIR.

This commit is contained in:
Tarry Singh 2025-01-28 09:35:58 +00:00
parent f8ab0d7b2a
commit 0a2ab7c8cb
10 changed files with 777 additions and 1031 deletions

View File

@ -73,7 +73,7 @@ pub fn cmpi(ctx: mlir.Context, predicate: CmpIPredicate, lhs: mlir.Value, rhs: m
.operands = &.{ lhs, rhs }, .operands = &.{ lhs, rhs },
.result_type_inference = true, .result_type_inference = true,
.attributes = &.{ .attributes = &.{
.{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute) }, .{ "predicate", .int(ctx, .i64, @intFromEnum(predicate)) },
}, },
.location = location, .location = location,
}); });
@ -103,7 +103,7 @@ pub fn cmpf(ctx: mlir.Context, predicate: CmpFPredicate, lhs: mlir.Value, rhs: m
.operands = &.{ lhs, rhs }, .operands = &.{ lhs, rhs },
.result_type_inference = true, .result_type_inference = true,
.attributes = &.{ .attributes = &.{
.{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute) }, .{ "predicate", .int(ctx, .i64, @intFromEnum(predicate)) },
}, },
.location = location, .location = location,
}); });

View File

@ -127,10 +127,10 @@ pub const DotPrecision = union(enum) {
// When we specify the dot algorithm, we should not specify the precision. // When we specify the dot algorithm, we should not specify the precision.
.algorithm => .DEFAULT, .algorithm => .DEFAULT,
}); });
return precision.as(mlir.Attribute); return precision.asAttr();
} }
pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.Type) ?mlir.Attribute { pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.RankedTensorType) ?mlir.Attribute {
return switch (self) { return switch (self) {
.algorithm => |algo| algo.asAttr(ctx, operand_type), .algorithm => |algo| algo.asAttr(ctx, operand_type),
else => null, else => null,
@ -156,15 +156,14 @@ pub const DotAlgorithm = struct {
.allow_imprecise_accumulation = false, .allow_imprecise_accumulation = false,
}; };
pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, operand_type: mlir.Type) mlir.Attribute { pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, tensor_type: mlir.RankedTensorType) mlir.Attribute {
const tensor_type = operand_type.as(mlir.RankedTensorType);
const elem_type = tensor_type.getElementType(); const elem_type = tensor_type.getElementType();
return mlir.Attribute.wrap(c.stablehloDotAlgorithmGet( return mlir.Attribute.wrap(c.stablehloDotAlgorithmGet(
ctx.inner(), ctx._inner,
elem_type.inner(), elem_type._inner,
elem_type.inner(), elem_type._inner,
self.accumulation.asType(ctx).inner(), self.accumulation.asType(ctx)._inner,
self.component_count, self.component_count,
self.component_count, self.component_count,
self.num_primitive_operations, self.num_primitive_operations,
@ -197,11 +196,11 @@ pub fn dot_general(
.rhs_batching_dimensions = opts.rhs_batching_dimensions, .rhs_batching_dimensions = opts.rhs_batching_dimensions,
.lhs_contracting_dimensions = opts.lhs_contracting_dimensions, .lhs_contracting_dimensions = opts.lhs_contracting_dimensions,
.rhs_contracting_dimensions = opts.rhs_contracting_dimensions, .rhs_contracting_dimensions = opts.rhs_contracting_dimensions,
}).as(mlir.Attribute), }).asAttr(),
}, },
.{ "precision_config", .array(ctx, &precisions) }, .{ "precision_config", .array(ctx, &precisions) },
// keep algorithm as the last attribute so we can omit it when it's not set. // keep algorithm as the last attribute so we can omit it when it's not set.
.{ "algorithm", opts.precision.algorithmAttr(ctx, lhs.getType()) orelse undefined }, .{ "algorithm", opts.precision.algorithmAttr(ctx, lhs.getType().as(mlir.RankedTensorType).?) orelse undefined },
}; };
const n_attributes = if (opts.precision == .algorithm) attributes.len else attributes.len - 1; const n_attributes = if (opts.precision == .algorithm) attributes.len else attributes.len - 1;
return mlir.Operation.make(ctx, "stablehlo.dot_general", .{ return mlir.Operation.make(ctx, "stablehlo.dot_general", .{
@ -214,19 +213,15 @@ pub fn dot_general(
pub fn constant( pub fn constant(
ctx: mlir.Context, ctx: mlir.Context,
result_type: mlir.RankedTensorType, dims: []const i64,
elem_type: mlir.DenseElementsAttributeTypes, elem_type: mlir.DenseElementsAttributeTypes,
raw_bytes: []const u8, raw_bytes: []const u8,
location: mlir.Location, location: mlir.Location,
) mlir.Operation { ) mlir.Operation {
const attribute = switch (elem_type) {
inline else => |dt| mlir.DenseElementsAttribute(dt).init(result_type.as(mlir.Type), raw_bytes).as(mlir.Attribute),
};
return mlir.Operation.make(ctx, "stablehlo.constant", .{ return mlir.Operation.make(ctx, "stablehlo.constant", .{
.operands = &.{}, .operands = &.{},
.results = &.{result_type.as(mlir.Type)}, .results = &.{.tensor(dims, elem_type.mlirType(ctx))},
.attributes = &.{.{ "value", attribute }}, .attributes = &.{.{ "value", .denseElementsFromBytes(ctx, dims, elem_type, raw_bytes) }},
.location = location, .location = location,
}); });
} }
@ -285,10 +280,10 @@ pub fn concatenate(ctx: mlir.Context, inputs: []const mlir.Value, dimension: i64
}); });
} }
pub fn reshape(ctx: mlir.Context, value: mlir.Value, result_type: mlir.RankedTensorType, location: mlir.Location) mlir.Operation { pub fn reshape(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, location: mlir.Location) mlir.Operation {
return mlir.Operation.make(ctx, "stablehlo.reshape", .{ return mlir.Operation.make(ctx, "stablehlo.reshape", .{
.operands = &.{value}, .operands = &.{value},
.results = &.{result_type.as(mlir.Type)}, .results = &.{result_type},
.location = location, .location = location,
}); });
} }
@ -332,7 +327,7 @@ pub fn gather(
args.start_indices_batching_dims, args.start_indices_batching_dims,
args.start_index_map, args.start_index_map,
args.index_vector_dim, args.index_vector_dim,
).as(mlir.Attribute) }, ).asAttr() },
.{ "slice_sizes", .dense(ctx, .i64, slice_sizes) }, .{ "slice_sizes", .dense(ctx, .i64, slice_sizes) },
.{ "indices_are_sorted", .boolean(ctx, args.indices_are_sorted) }, .{ "indices_are_sorted", .boolean(ctx, args.indices_are_sorted) },
}, },
@ -358,22 +353,20 @@ pub const ScatterArgs = struct {
unique_indices: bool = false, unique_indices: bool = false,
pub fn getScatterDimensionNumbers(self: ScatterArgs, ctx: mlir.Context) mlir.Attribute { pub fn getScatterDimensionNumbers(self: ScatterArgs, ctx: mlir.Context) mlir.Attribute {
return mlir.Attribute.wrap( return .{ ._inner = c.stablehloScatterDimensionNumbersGet(
c.stablehloScatterDimensionNumbersGet( ctx._inner,
ctx.inner(), @intCast(self.update_window_dims.len),
@intCast(self.update_window_dims.len), self.update_window_dims.ptr,
self.update_window_dims.ptr, @intCast(self.inserted_window_dims.len),
@intCast(self.inserted_window_dims.len), self.inserted_window_dims.ptr,
self.inserted_window_dims.ptr, @intCast(self.input_batching_dims.len),
@intCast(self.input_batching_dims.len), self.input_batching_dims.ptr,
self.input_batching_dims.ptr, @intCast(self.scatter_indices_batching_dims.len),
@intCast(self.scatter_indices_batching_dims.len), self.scatter_indices_batching_dims.ptr,
self.scatter_indices_batching_dims.ptr, @intCast(self.scatter_dims_to_operand_dims.len),
@intCast(self.scatter_dims_to_operand_dims.len), self.scatter_dims_to_operand_dims.ptr,
self.scatter_dims_to_operand_dims.ptr, self.index_vector_dim,
self.index_vector_dim, ) };
),
);
} }
}; };
@ -431,8 +424,8 @@ pub fn compare(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, comparison_d
.operands = &.{ lhs, rhs }, .operands = &.{ lhs, rhs },
.result_type_inference = true, .result_type_inference = true,
.attributes = &.{ .attributes = &.{
.{ "comparison_direction", comparison_direction.as(mlir.Attribute) }, .{ "comparison_direction", comparison_direction.asAttr() },
.{ "compare_type", compare_type.as(mlir.Attribute) }, .{ "compare_type", compare_type.asAttr() },
}, },
.location = location, .location = location,
}); });
@ -580,7 +573,7 @@ pub fn triangular_solve(ctx: mlir.Context, value: mlir.Value, other: mlir.Value,
.{ "left_side", .i1FromBool(ctx, opts.left_side) }, .{ "left_side", .i1FromBool(ctx, opts.left_side) },
.{ "lower", .i1FromBool(ctx, opts.lower) }, .{ "lower", .i1FromBool(ctx, opts.lower) },
.{ "unit_diagonal", .i1FromBool(ctx, opts.unit_diagonal) }, .{ "unit_diagonal", .i1FromBool(ctx, opts.unit_diagonal) },
.{ "transpose_a", Transpose.init(ctx, opts.transpose_a).as(mlir.Attribute) }, .{ "transpose_a", Transpose.init(ctx, opts.transpose_a).asAttr() },
}, },
.location = location, .location = location,
}); });
@ -596,7 +589,7 @@ pub fn fft(ctx: mlir.Context, value: mlir.Value, location: mlir.Location, opts:
.operands = &.{value}, .operands = &.{value},
.result_type_inference = true, .result_type_inference = true,
.attributes = &.{ .attributes = &.{
.{ "fft_type", FftType.init(ctx, opts.kind).as(mlir.Attribute) }, .{ "fft_type", FftType.init(ctx, opts.kind).asAttr() },
.{ "fft_length", .dense(ctx, .i64, opts.length) }, .{ "fft_length", .dense(ctx, .i64, opts.length) },
}, },
.location = location, .location = location,
@ -608,7 +601,7 @@ pub fn rng(ctx: mlir.Context, a: mlir.Value, b: mlir.Value, shape: mlir.Value, r
.operands = &.{ a, b, shape }, .operands = &.{ a, b, shape },
.result_type_inference = true, .result_type_inference = true,
.attributes = &.{ .attributes = &.{
.{ "rng_distribution", RngDistribution.init(ctx, rng_distribution).as(mlir.Attribute) }, .{ "rng_distribution", RngDistribution.init(ctx, rng_distribution).asAttr() },
}, },
.location = location, .location = location,
}); });
@ -619,7 +612,7 @@ pub fn rng_bit_generator(ctx: mlir.Context, rng_algorithm: RngAlgorithm.Type, in
.operands = &.{initial_state}, .operands = &.{initial_state},
.results = &.{ res_state_type, res_type }, .results = &.{ res_state_type, res_type },
.attributes = &.{ .attributes = &.{
.{ "rng_algorithm", RngAlgorithm.init(ctx, rng_algorithm).as(mlir.Attribute) }, .{ "rng_algorithm", RngAlgorithm.init(ctx, rng_algorithm).asAttr() },
}, },
.location = location, .location = location,
}); });
@ -695,7 +688,7 @@ pub fn convolution(
) mlir.Operation { ) mlir.Operation {
var max_precisions: [2]mlir.Attribute = undefined; var max_precisions: [2]mlir.Attribute = undefined;
for (opts.precision_config, 0..) |p, i| { for (opts.precision_config, 0..) |p, i| {
max_precisions[i] = PrecisionAttribute.init(ctx, p).as(mlir.Attribute); max_precisions[i] = PrecisionAttribute.init(ctx, p).asAttr();
} }
var window_reversal: [3]i32 = undefined; var window_reversal: [3]i32 = undefined;
for (opts.window_reversal, 0..) |w, i| { for (opts.window_reversal, 0..) |w, i| {
@ -721,7 +714,7 @@ pub fn convolution(
.output_batch_dimension = opts.output_batch_dimension, .output_batch_dimension = opts.output_batch_dimension,
.output_feature_dimension = opts.output_feature_dimension, .output_feature_dimension = opts.output_feature_dimension,
.output_spatial_dimensions = opts.output_spatial_dimensions, .output_spatial_dimensions = opts.output_spatial_dimensions,
}).as(mlir.Attribute), }).asAttr(),
}, },
.{ "feature_group_count", .int(ctx, .i64, opts.feature_group_count) }, .{ "feature_group_count", .int(ctx, .i64, opts.feature_group_count) },
.{ "batch_group_count", .int(ctx, .i64, opts.batch_group_count) }, .{ "batch_group_count", .int(ctx, .i64, opts.batch_group_count) },
@ -756,13 +749,13 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
const backend_config = opts.backend_config orelse mlir.Attribute.string(ctx, ""); const backend_config = opts.backend_config orelse mlir.Attribute.string(ctx, "");
if (@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi)) { if (@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi)) {
stdx.debug.assert( stdx.debug.assert(
backend_config.is_a(mlir.StringAttribute), backend_config.isA(mlir.StringAttribute),
"API version < 4 requires a string as backend_config, got {}", "API version < 4 requires a string as backend_config, got {}",
.{backend_config}, .{backend_config},
); );
} else { } else {
stdx.debug.assert( stdx.debug.assert(
backend_config.is_a(mlir.DictionaryAttribute), backend_config.isA(mlir.DictionaryAttribute),
"API version >= 4 requires a dictionary as backend_config, got {}", "API version >= 4 requires a dictionary as backend_config, got {}",
.{backend_config}, .{backend_config},
); );
@ -780,7 +773,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
var output_operand_aliases: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; var output_operand_aliases: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
for (opts.output_operand_aliases) |alias| { for (opts.output_operand_aliases) |alias| {
output_operand_aliases.appendAssumeCapacity( output_operand_aliases.appendAssumeCapacity(
OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute), OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).asAttr(),
); );
} }
attrs.appendAssumeCapacity(.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) }); attrs.appendAssumeCapacity(.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) });
@ -805,7 +798,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
const operand_layouts = blk: { const operand_layouts = blk: {
var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
for (inputs) |input| { for (inputs) |input| {
const ranked_type = input.getType().as(mlir.RankedTensorType); const ranked_type = input.getType().as(mlir.RankedTensorType).?;
const ol = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_type.getRank() ..]; const ol = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_type.getRank() ..];
ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol)); ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol));
} }
@ -824,7 +817,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
const result_layouts = blk: { const result_layouts = blk: {
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
for (res_types) |t| { for (res_types) |t| {
const ranked_t = t.as(mlir.RankedTensorType); const ranked_t = t.as(mlir.RankedTensorType).?;
const rl = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_t.getRank() ..]; const rl = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_t.getRank() ..];
ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl)); ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl));
} }
@ -846,13 +839,10 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
pub const DotDimensionNumbersAttribute = struct { pub const DotDimensionNumbersAttribute = struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(DotDimensionNumbersAttribute, .{ pub const is_a_fn = c.stablehloAttributeIsADotDimensionNumbers;
.is_a_fn = c.stablehloAttributeIsADotDimensionNumbers,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
const Self = DotDimensionNumbersAttribute; const Self = DotDimensionNumbersAttribute;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub fn init(ctx: mlir.Context, args: struct { pub fn init(ctx: mlir.Context, args: struct {
lhs_batching_dimensions: []const i64, lhs_batching_dimensions: []const i64,
@ -860,9 +850,9 @@ pub const DotDimensionNumbersAttribute = struct {
lhs_contracting_dimensions: []const i64, lhs_contracting_dimensions: []const i64,
rhs_contracting_dimensions: []const i64, rhs_contracting_dimensions: []const i64,
}) Self { }) Self {
return Self.wrap( return .{
c.stablehloDotDimensionNumbersGet( ._inner = c.stablehloDotDimensionNumbersGet(
ctx.inner(), ctx._inner,
@intCast(args.lhs_batching_dimensions.len), @intCast(args.lhs_batching_dimensions.len),
args.lhs_batching_dimensions.ptr, args.lhs_batching_dimensions.ptr,
@intCast(args.rhs_batching_dimensions.len), @intCast(args.rhs_batching_dimensions.len),
@ -872,52 +862,49 @@ pub const DotDimensionNumbersAttribute = struct {
@intCast(args.rhs_contracting_dimensions.len), @intCast(args.rhs_contracting_dimensions.len),
args.rhs_contracting_dimensions.ptr, args.rhs_contracting_dimensions.ptr,
), ),
); };
} }
pub fn getLhsBatchingDimensionsSize(self: Self) usize { pub fn getLhsBatchingDimensionsSize(self: Self) usize {
return @intCast(c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsSize(self.inner())); return @intCast(c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsSize(self._inner));
} }
pub fn getLhsBatchingDimensionsElem(self: Self, pos: usize) i64 { pub fn getLhsBatchingDimensionsElem(self: Self, pos: usize) i64 {
return c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsElem(self.inner(), @intCast(pos)); return c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsElem(self._inner, @intCast(pos));
} }
pub fn getRhsBatchingDimensionsSize(self: Self) usize { pub fn getRhsBatchingDimensionsSize(self: Self) usize {
return @intCast(c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsSize(self.inner())); return @intCast(c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsSize(self._inner));
} }
pub fn getRhsBatchingDimensionsElem(self: Self, pos: usize) i64 { pub fn getRhsBatchingDimensionsElem(self: Self, pos: usize) i64 {
return c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsElem(self.inner(), @intCast(pos)); return c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsElem(self._inner, @intCast(pos));
} }
pub fn getLhsContractingDimensionsSize(self: Self) usize { pub fn getLhsContractingDimensionsSize(self: Self) usize {
return @intCast(c.stablehloDotDimensionNumbersGetLhsContractingDimensionsSize(self.inner())); return @intCast(c.stablehloDotDimensionNumbersGetLhsContractingDimensionsSize(self._inner));
} }
pub fn getLhsContractingDimensionsElem(self: Self, pos: usize) i64 { pub fn getLhsContractingDimensionsElem(self: Self, pos: usize) i64 {
return c.stablehloDotDimensionNumbersGetLhsContractingDimensionsElem(self.inner(), @intCast(pos)); return c.stablehloDotDimensionNumbersGetLhsContractingDimensionsElem(self._inner, @intCast(pos));
} }
pub fn getRhsContractingDimensionsSize(self: Self) usize { pub fn getRhsContractingDimensionsSize(self: Self) usize {
return @intCast(c.stablehloDotDimensionNumbersGetRhsContractingDimensionsSize(self.inner())); return @intCast(c.stablehloDotDimensionNumbersGetRhsContractingDimensionsSize(self._inner));
} }
pub fn getRhsContractingDimensionsElem(self: Self, pos: usize) i64 { pub fn getRhsContractingDimensionsElem(self: Self, pos: usize) i64 {
return c.stablehloDotDimensionNumbersGetRhsContractingDimensionsElem(self.inner(), @intCast(pos)); return c.stablehloDotDimensionNumbersGetRhsContractingDimensionsElem(self._inner, @intCast(pos));
} }
}; };
pub const GatherDimensionNumbersAttribute = struct { pub const GatherDimensionNumbersAttribute = struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(GatherDimensionNumbersAttribute, .{ pub const is_a_fn = c.stablehloAttributeIsAGatherDimensionNumbers;
.is_a_fn = c.stablehloAttributeIsAGatherDimensionNumbers,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
const Self = GatherDimensionNumbersAttribute; const Self = GatherDimensionNumbersAttribute;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub fn init( pub fn init(
ctx: mlir.Context, ctx: mlir.Context,
@ -928,9 +915,9 @@ pub const GatherDimensionNumbersAttribute = struct {
start_index_map: []const i64, start_index_map: []const i64,
index_vector_dim: i64, index_vector_dim: i64,
) Self { ) Self {
return Self.wrap( return .{
c.stablehloGatherDimensionNumbersGet( ._inner = c.stablehloGatherDimensionNumbersGet(
ctx.inner(), ctx._inner,
@intCast(offset_dims.len), @intCast(offset_dims.len),
offset_dims.ptr, offset_dims.ptr,
@intCast(collapsed_slice_dims.len), @intCast(collapsed_slice_dims.len),
@ -943,64 +930,61 @@ pub const GatherDimensionNumbersAttribute = struct {
start_index_map.ptr, start_index_map.ptr,
index_vector_dim, index_vector_dim,
), ),
); };
} }
pub fn getOffsetDimsSize(self: Self) usize { pub fn getOffsetDimsSize(self: Self) usize {
return @intCast(c.stablehloGatherDimensionNumbersGetOffsetDimsSize(self.inner())); return @intCast(c.stablehloGatherDimensionNumbersGetOffsetDimsSize(self._inner));
} }
pub fn getOffsetDimsElem(self: Self, pos: usize) i64 { pub fn getOffsetDimsElem(self: Self, pos: usize) i64 {
return c.stablehloGatherDimensionNumbersGetOffsetDimsElem(self.inner(), @intCast(pos)); return c.stablehloGatherDimensionNumbersGetOffsetDimsElem(self._inner, @intCast(pos));
} }
pub fn getCollapsedSliceDimsSize(self: Self) usize { pub fn getCollapsedSliceDimsSize(self: Self) usize {
return @intCast(c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsSize(self.inner())); return @intCast(c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsSize(self._inner));
} }
pub fn getCollapsedSliceDimsElem(self: Self, pos: usize) i64 { pub fn getCollapsedSliceDimsElem(self: Self, pos: usize) i64 {
return c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsElem(self.inner(), @intCast(pos)); return c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsElem(self._inner, @intCast(pos));
} }
pub fn getStartIndexMapSize(self: Self) usize { pub fn getStartIndexMapSize(self: Self) usize {
return @intCast(c.stablehloGatherDimensionNumbersGetStartIndexMapSize(self.inner())); return @intCast(c.stablehloGatherDimensionNumbersGetStartIndexMapSize(self._inner));
} }
pub fn getOperandBatchingDimsSize(self: Self) usize { pub fn getOperandBatchingDimsSize(self: Self) usize {
return @intCast(c.stablehloGatherDimensionNumbersGetOperandBatchingDimsSize(self.inner())); return @intCast(c.stablehloGatherDimensionNumbersGetOperandBatchingDimsSize(self._inner));
} }
pub fn getOperandBatchingDimsElem(self: Self, pos: usize) i64 { pub fn getOperandBatchingDimsElem(self: Self, pos: usize) i64 {
return c.stablehloGatherDimensionNumbersGetOperandBatchingDimsElem(self.inner(), @intCast(pos)); return c.stablehloGatherDimensionNumbersGetOperandBatchingDimsElem(self._inner, @intCast(pos));
} }
pub fn getStartIndicesBatchingDimsSize(self: Self) usize { pub fn getStartIndicesBatchingDimsSize(self: Self) usize {
return @intCast(c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsSize(self.inner())); return @intCast(c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsSize(self._inner));
} }
pub fn getStartIndicesBatchingDimsElem(self: Self, pos: usize) i64 { pub fn getStartIndicesBatchingDimsElem(self: Self, pos: usize) i64 {
return c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsElem(self.inner(), @intCast(pos)); return c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsElem(self._inner, @intCast(pos));
} }
pub fn getStartIndexMapElem(self: Self, pos: usize) i64 { pub fn getStartIndexMapElem(self: Self, pos: usize) i64 {
return c.stablehloGatherDimensionNumbersGetStartIndexMapElem(self.inner(), @intCast(pos)); return c.stablehloGatherDimensionNumbersGetStartIndexMapElem(self._inner, @intCast(pos));
} }
pub fn getIndexVectorDim(self: Self) usize { pub fn getIndexVectorDim(self: Self) usize {
return @intCast(c.stablehloGatherDimensionNumbersGetIndexVectorDim(self.inner())); return @intCast(c.stablehloGatherDimensionNumbersGetIndexVectorDim(self._inner));
} }
}; };
pub const ConvDimensionNumbersAttribute = struct { pub const ConvDimensionNumbersAttribute = struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(ConvDimensionNumbersAttribute, .{ pub const is_a_fn = c.stablehloAttributeIsAConvDimensionNumbers;
.is_a_fn = c.stablehloAttributeIsAConvDimensionNumbers,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
const Self = ConvDimensionNumbersAttribute; const Self = ConvDimensionNumbersAttribute;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub fn init(ctx: mlir.Context, args: struct { pub fn init(ctx: mlir.Context, args: struct {
input_batch_dimension: i64, input_batch_dimension: i64,
@ -1013,9 +997,9 @@ pub const ConvDimensionNumbersAttribute = struct {
output_feature_dimension: i64, output_feature_dimension: i64,
output_spatial_dimensions: []const i64, output_spatial_dimensions: []const i64,
}) Self { }) Self {
return Self.wrap( return .{
c.stablehloConvDimensionNumbersGet( ._inner = c.stablehloConvDimensionNumbersGet(
ctx.inner(), ctx._inner,
args.input_batch_dimension, args.input_batch_dimension,
args.input_feature_dimension, args.input_feature_dimension,
@intCast(args.input_spatial_dimensions.len), @intCast(args.input_spatial_dimensions.len),
@ -1029,67 +1013,64 @@ pub const ConvDimensionNumbersAttribute = struct {
@intCast(args.output_spatial_dimensions.len), @intCast(args.output_spatial_dimensions.len),
args.output_spatial_dimensions.ptr, args.output_spatial_dimensions.ptr,
), ),
); };
} }
pub fn getInputBatchDimension(self: Self) i64 { pub fn getInputBatchDimension(self: Self) i64 {
return c.stablehloConvDimensionNumbersGetInputBatchDimension(self.inner()); return c.stablehloConvDimensionNumbersGetInputBatchDimension(self._inner);
} }
pub fn getInputFeatureDimension(self: Self) i64 { pub fn getInputFeatureDimension(self: Self) i64 {
return c.stablehloConvDimensionNumbersGetInputFeatureDimension(self.inner()); return c.stablehloConvDimensionNumbersGetInputFeatureDimension(self._inner);
} }
pub fn getInputSpatialDimensionsSize(self: Self) usize { pub fn getInputSpatialDimensionsSize(self: Self) usize {
return @intCast(c.stablehloConvDimensionNumbersGetInputSpatialDimensionsSize(self.inner())); return @intCast(c.stablehloConvDimensionNumbersGetInputSpatialDimensionsSize(self._inner));
} }
pub fn getInputSpatialDimensionsElem(self: Self, pos: usize) i64 { pub fn getInputSpatialDimensionsElem(self: Self, pos: usize) i64 {
return c.stablehloConvDimensionNumbersGetInputSpatialDimensionsElem(self.inner(), @intCast(pos)); return c.stablehloConvDimensionNumbersGetInputSpatialDimensionsElem(self._inner, @intCast(pos));
} }
pub fn getKernelInputFeatureDimension(self: Self) i64 { pub fn getKernelInputFeatureDimension(self: Self) i64 {
return c.stablehloConvDimensionNumbersGetKernelInputFeatureDimension(self.inner()); return c.stablehloConvDimensionNumbersGetKernelInputFeatureDimension(self._inner);
} }
pub fn getKernelOutputFeatureDimension(self: Self) i64 { pub fn getKernelOutputFeatureDimension(self: Self) i64 {
return c.stablehloConvDimensionNumbersGetKernelOutputFeatureDimension(self.inner()); return c.stablehloConvDimensionNumbersGetKernelOutputFeatureDimension(self._inner);
} }
pub fn getKernelSpatialDimensionsSize(self: Self) usize { pub fn getKernelSpatialDimensionsSize(self: Self) usize {
return @intCast(c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsSize(self.inner())); return @intCast(c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsSize(self._inner));
} }
pub fn getKernelSpatialDimensionsElem(self: Self, pos: usize) i64 { pub fn getKernelSpatialDimensionsElem(self: Self, pos: usize) i64 {
return c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsElem(self.inner(), @intCast(pos)); return c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsElem(self._inner, @intCast(pos));
} }
pub fn getOutputBatchDimension(self: Self) i64 { pub fn getOutputBatchDimension(self: Self) i64 {
return c.stablehloConvDimensionNumbersGetOutputBatchDimension(self.inner()); return c.stablehloConvDimensionNumbersGetOutputBatchDimension(self._inner);
} }
pub fn getOutputFeatureDimension(self: Self) i64 { pub fn getOutputFeatureDimension(self: Self) i64 {
return c.stablehloConvDimensionNumbersGetOutputFeatureDimension(self.inner()); return c.stablehloConvDimensionNumbersGetOutputFeatureDimension(self._inner);
} }
pub fn getOutputSpatialDimensionsSize(self: Self) usize { pub fn getOutputSpatialDimensionsSize(self: Self) usize {
return @intCast(c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsSize(self.inner())); return @intCast(c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsSize(self._inner));
} }
pub fn getOutputSpatialDimensionsElem(self: Self, pos: usize) i64 { pub fn getOutputSpatialDimensionsElem(self: Self, pos: usize) i64 {
return c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsElem(self.inner(), @intCast(pos)); return c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsElem(self._inner, @intCast(pos));
} }
}; };
pub const OutputOperandAliasAttribute = struct { pub const OutputOperandAliasAttribute = struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(OutputOperandAliasAttribute, .{ pub const is_a_fn = c.stablehloAttributeIsAOutputOperandAlias;
.is_a_fn = c.stablehloAttributeIsAOutputOperandAlias, pub const asAttr = mlir.Attribute.fromAny(OutputOperandAliasAttribute);
.is_null_fn = c.mlirAttributeIsNull, pub const eql = mlir.Attribute.eqlAny(OutputOperandAliasAttribute);
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
pub fn init( pub fn init(
ctx: mlir.Context, ctx: mlir.Context,
@ -1097,27 +1078,24 @@ pub const OutputOperandAliasAttribute = struct {
operand_index: i64, operand_index: i64,
operand_tuple_indices: []const i64, operand_tuple_indices: []const i64,
) OutputOperandAliasAttribute { ) OutputOperandAliasAttribute {
return OutputOperandAliasAttribute.wrap(c.stablehloOutputOperandAliasGet( return .{ ._inner = c.stablehloOutputOperandAliasGet(
ctx.inner(), ctx._inner,
@intCast(output_tuple_indices.len), @intCast(output_tuple_indices.len),
output_tuple_indices.ptr, output_tuple_indices.ptr,
@intCast(operand_index), @intCast(operand_index),
@intCast(operand_tuple_indices.len), @intCast(operand_tuple_indices.len),
operand_tuple_indices.ptr, operand_tuple_indices.ptr,
)); ) };
} }
}; };
pub const PrecisionAttribute = struct { pub const PrecisionAttribute = struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(PrecisionAttribute, .{ pub const is_a_fn = c.stablehloAttributeIsAPrecisionAttr;
.is_a_fn = c.stablehloAttributeIsAPrecisionAttr,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
const Self = PrecisionAttribute; const Self = PrecisionAttribute;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub const Precision = enum { pub const Precision = enum {
DEFAULT, DEFAULT,
@ -1126,11 +1104,11 @@ pub const PrecisionAttribute = struct {
}; };
pub fn init(ctx: mlir.Context, value: Precision) Self { pub fn init(ctx: mlir.Context, value: Precision) Self {
return Self.wrap(c.stablehloPrecisionAttrGet(ctx.inner(), mlir.stringRef(@tagName(value)))); return .{ ._inner = c.stablehloPrecisionAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
} }
pub fn getValue(self: Self) Precision { pub fn getValue(self: Self) Precision {
const value = mlir.fromStringRef(c.stablehloPrecisionAttrGetValue(self.inner())); const value = mlir.fromStringRef(c.stablehloPrecisionAttrGetValue(self._inner));
return std.meta.stringToEnum(Precision, value) orelse unreachable; return std.meta.stringToEnum(Precision, value) orelse unreachable;
} }
}; };
@ -1138,13 +1116,10 @@ pub const PrecisionAttribute = struct {
pub const ComparisonDirection = struct { pub const ComparisonDirection = struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(ComparisonDirection, .{ pub const is_a_fn = c.stablehloAttributeIsAComparisonDirectionAttr;
.is_a_fn = c.stablehloAttributeIsAComparisonDirectionAttr,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
const Self = ComparisonDirection; const Self = ComparisonDirection;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub const Direction = enum { pub const Direction = enum {
EQ, EQ,
@ -1156,11 +1131,11 @@ pub const ComparisonDirection = struct {
}; };
pub fn init(ctx: mlir.Context, value: Direction) Self { pub fn init(ctx: mlir.Context, value: Direction) Self {
return Self.wrap(c.stablehloComparisonDirectionAttrGet(ctx.inner(), mlir.stringRef(@tagName(value)))); return .{ ._inner = c.stablehloComparisonDirectionAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
} }
pub fn getValue(self: Self) Direction { pub fn getValue(self: Self) Direction {
const value = mlir.fromStringRef(c.stablehloComparisonDirectionAttrGetValue(self.inner())); const value = mlir.fromStringRef(c.stablehloComparisonDirectionAttrGetValue(self._inner));
return std.meta.stringToEnum(Direction, value) orelse unreachable; return std.meta.stringToEnum(Direction, value) orelse unreachable;
} }
}; };
@ -1168,13 +1143,10 @@ pub const ComparisonDirection = struct {
pub const CompareType = struct { pub const CompareType = struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(CompareType, .{ pub const is_a_fn = c.stablehloAttributeIsAComparisonTypeAttr;
.is_a_fn = c.stablehloAttributeIsAComparisonTypeAttr,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
const Self = CompareType; const Self = CompareType;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub const Type = enum { pub const Type = enum {
SIGNED, SIGNED,
@ -1184,11 +1156,11 @@ pub const CompareType = struct {
}; };
pub fn init(ctx: mlir.Context, value: Type) Self { pub fn init(ctx: mlir.Context, value: Type) Self {
return Self.wrap(c.stablehloComparisonTypeAttrGet(ctx.inner(), mlir.stringRef(@tagName(value)))); return .{ ._inner = c.stablehloComparisonTypeAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
} }
pub fn getValue(self: Self) Type { pub fn getValue(self: Self) Type {
const value = mlir.fromStringRef(c.stablehloComparisonTypeAttrGetValue(self.inner())); const value = mlir.fromStringRef(c.stablehloComparisonTypeAttrGetValue(self._inner));
return std.meta.stringToEnum(Type, value) orelse unreachable; return std.meta.stringToEnum(Type, value) orelse unreachable;
} }
}; };
@ -1196,13 +1168,10 @@ pub const CompareType = struct {
pub const Transpose = struct { pub const Transpose = struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(Transpose, .{ pub const is_a_fn = c.stablehloAttributeIsATransposeAttr;
.is_a_fn = c.stablehloAttributeIsATransposeAttr,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
const Self = Transpose; const Self = Transpose;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub const Type = enum { pub const Type = enum {
NO_TRANSPOSE, NO_TRANSPOSE,
@ -1211,11 +1180,11 @@ pub const Transpose = struct {
}; };
pub fn init(ctx: mlir.Context, value: Type) Self { pub fn init(ctx: mlir.Context, value: Type) Self {
return Self.wrap(c.stablehloTransposeAttrGet(ctx.inner(), mlir.stringRef(@tagName(value)))); return .{ ._inner = c.stablehloTransposeAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
} }
pub fn getValue(self: Self) Type { pub fn getValue(self: Self) Type {
const value = mlir.fromStringRef(c.stablehloTransposeAttrGetValue(self.inner())); const value = mlir.fromStringRef(c.stablehloTransposeAttrGetValue(self._inner));
return std.meta.stringToEnum(Type, value) orelse unreachable; return std.meta.stringToEnum(Type, value) orelse unreachable;
} }
}; };
@ -1223,13 +1192,10 @@ pub const Transpose = struct {
pub const FftType = struct { pub const FftType = struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(FftType, .{ pub const is_a_fn = c.stablehloAttributeIsAFftTypeAttr;
.is_a_fn = c.stablehloAttributeIsAFftTypeAttr,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
const Self = FftType; const Self = FftType;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub const Type = enum { pub const Type = enum {
FFT, FFT,
@ -1239,11 +1205,11 @@ pub const FftType = struct {
}; };
pub fn init(ctx: mlir.Context, value: Type) Self { pub fn init(ctx: mlir.Context, value: Type) Self {
return Self.wrap(c.stablehloFftTypeAttrGet(ctx.inner(), mlir.stringRef(@tagName(value)))); return .{ ._inner = c.stablehloFftTypeAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
} }
pub fn getValue(self: Self) Type { pub fn getValue(self: Self) Type {
const value = mlir.fromStringRef(c.stablehloFftTypeAttrGetValue(self.inner())); const value = mlir.fromStringRef(c.stablehloFftTypeAttrGetValue(self._inner));
return std.meta.stringToEnum(Type, value) orelse unreachable; return std.meta.stringToEnum(Type, value) orelse unreachable;
} }
}; };
@ -1251,13 +1217,10 @@ pub const FftType = struct {
pub const RngDistribution = struct { pub const RngDistribution = struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(RngDistribution, .{ pub const is_a_fn = c.stablehloAttributeIsARngDistributionAttr;
.is_a_fn = c.stablehloAttributeIsARngDistributionAttr,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
const Self = RngDistribution; const Self = RngDistribution;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub const Type = enum { pub const Type = enum {
UNIFORM, UNIFORM,
@ -1265,11 +1228,11 @@ pub const RngDistribution = struct {
}; };
pub fn init(ctx: mlir.Context, value: Type) Self { pub fn init(ctx: mlir.Context, value: Type) Self {
return Self.wrap(c.stablehloRngDistributionAttrGet(ctx.inner(), mlir.stringRef(@tagName(value)))); return .{ ._inner = c.stablehloRngDistributionAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
} }
pub fn getValue(self: Self) Type { pub fn getValue(self: Self) Type {
const value = mlir.fromStringRef(c.stablehloRngDistributionAttrGetValue(self.inner())); const value = mlir.fromStringRef(c.stablehloRngDistributionAttrGetValue(self._inner));
return std.meta.stringToEnum(Type, value) orelse unreachable; return std.meta.stringToEnum(Type, value) orelse unreachable;
} }
}; };
@ -1277,13 +1240,10 @@ pub const RngDistribution = struct {
pub const RngAlgorithm = struct { pub const RngAlgorithm = struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(RngAlgorithm, .{ pub const is_a_fn = c.stablehloAttributeIsARngAlgorithmAttr;
.is_a_fn = c.stablehloAttributeIsARngAlgorithmAttr,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
const Self = RngAlgorithm; const Self = RngAlgorithm;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub const Type = enum { pub const Type = enum {
DEFAULT, DEFAULT,
@ -1292,11 +1252,11 @@ pub const RngAlgorithm = struct {
}; };
pub fn init(ctx: mlir.Context, value: Type) Self { pub fn init(ctx: mlir.Context, value: Type) Self {
return Self.wrap(c.stablehloRngAlgorithmAttrGet(ctx.inner(), mlir.stringRef(@tagName(value)))); return .{ ._inner = c.stablehloRngAlgorithmAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
} }
pub fn getValue(self: Self) Type { pub fn getValue(self: Self) Type {
const value = mlir.fromStringRef(c.stablehloRngAlgorithmAttrGetValue(self.inner())); const value = mlir.fromStringRef(c.stablehloRngAlgorithmAttrGetValue(self._inner));
return std.meta.stringToEnum(Type, value) orelse unreachable; return std.meta.stringToEnum(Type, value) orelse unreachable;
} }
}; };

File diff suppressed because it is too large Load Diff

View File

@ -1,176 +0,0 @@
const mlir = @This();
const builtin = @import("builtin");
const std = @import("std");
const stdx = @import("stdx");
const dtype = @import("dtype.zig");
const Shape = @import("shape.zig").Shape;
const Tensor = @import("tensor.zig").Tensor;
const log = std.log.scoped(.@"zml/mlir");
pub usingnamespace @import("mlir");
pub const ext = struct {
pub fn mlirType(ctx: mlir.Context, sh: Shape) mlir.Type {
return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype())).as(mlir.Type);
}
pub fn denseElementAttrType(dt: dtype.DataType) ?mlir.DenseElementsAttributeTypes {
return switch (dt) {
.bool => .bool,
.i8 => .i8,
.i16 => .i16,
.i32 => .i32,
.i64 => .i64,
.u8 => .u8,
.u16 => .u16,
.u32 => .u32,
.u64 => .u64,
.bf16 => .bf16,
.f16 => .f16,
.f32 => .f32,
.f64 => .f64,
else => null,
};
}
pub fn denseElementsAttr(dt: dtype.DataType, _: usize, bytes: []const u8, ranked_type: mlir.RankedTensorType) mlir.Attribute {
const ranked_type_ = ranked_type.as(mlir.Type);
return switch (dt) {
.bool => mlir.DenseElementsAttribute(.bool).init(ranked_type_, bytes).as(mlir.Attribute),
.i8 => mlir.DenseElementsAttribute(.i8).init(ranked_type_, bytes).as(mlir.Attribute),
.i16 => mlir.DenseElementsAttribute(.i16).init(ranked_type_, bytes).as(mlir.Attribute),
.i32 => mlir.DenseElementsAttribute(.i32).init(ranked_type_, bytes).as(mlir.Attribute),
.i64 => mlir.DenseElementsAttribute(.i64).init(ranked_type_, bytes).as(mlir.Attribute),
.u8 => mlir.DenseElementsAttribute(.u8).init(ranked_type_, bytes).as(mlir.Attribute),
.u16 => mlir.DenseElementsAttribute(.u16).init(ranked_type_, bytes).as(mlir.Attribute),
.u32 => mlir.DenseElementsAttribute(.u32).init(ranked_type_, bytes).as(mlir.Attribute),
.u64 => mlir.DenseElementsAttribute(.u64).init(ranked_type_, bytes).as(mlir.Attribute),
.bf16 => mlir.DenseElementsAttribute(.bf16).init(ranked_type_, bytes).as(mlir.Attribute),
.f16 => mlir.DenseElementsAttribute(.f16).init(ranked_type_, bytes).as(mlir.Attribute),
.f32 => mlir.DenseElementsAttribute(.f32).init(ranked_type_, bytes).as(mlir.Attribute),
.f64 => mlir.DenseElementsAttribute(.f64).init(ranked_type_, bytes).as(mlir.Attribute),
inline else => |tag| @panic("Unsupported data type: " ++ @tagName(tag)),
};
}
pub const RankedTensorType = struct {
pub fn fromShape(ctx: mlir.Context, sh: Shape) mlir.RankedTensorType {
return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype()));
}
};
pub const Type = struct {
pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type {
return switch (dt) {
.bool => mlir.IntegerType(.i1).init(ctx).as(mlir.Type),
.f8e4m3b11fnuz => mlir.FloatType(.f8e4m3b11fnuz).init(ctx).as(mlir.Type),
.f8e4m3fn => mlir.FloatType(.f8e4m3fn).init(ctx).as(mlir.Type),
.f8e4m3fnuz => mlir.FloatType(.f8e4m3fnuz).init(ctx).as(mlir.Type),
.f8e5m2 => mlir.FloatType(.f8e5m2).init(ctx).as(mlir.Type),
.f8e5m2fnuz => mlir.FloatType(.f8e5m2fnuz).init(ctx).as(mlir.Type),
.bf16 => mlir.FloatType(.bf16).init(ctx).as(mlir.Type),
.f16 => mlir.FloatType(.f16).init(ctx).as(mlir.Type),
.f32 => mlir.FloatType(.f32).init(ctx).as(mlir.Type),
.f64 => mlir.FloatType(.f64).init(ctx).as(mlir.Type),
.i4 => mlir.IntegerType(.i4).init(ctx).as(mlir.Type),
.i8 => mlir.IntegerType(.i8).init(ctx).as(mlir.Type),
.i16 => mlir.IntegerType(.i16).init(ctx).as(mlir.Type),
.i32 => mlir.IntegerType(.i32).init(ctx).as(mlir.Type),
.i64 => mlir.IntegerType(.i64).init(ctx).as(mlir.Type),
.u4 => mlir.IntegerType(.u4).init(ctx).as(mlir.Type),
.u8 => mlir.IntegerType(.u8).init(ctx).as(mlir.Type),
.u16 => mlir.IntegerType(.u16).init(ctx).as(mlir.Type),
.u32 => mlir.IntegerType(.u32).init(ctx).as(mlir.Type),
.u64 => mlir.IntegerType(.u64).init(ctx).as(mlir.Type),
.c64 => mlir.ComplexType(.c64).init(ctx).as(mlir.Type),
.c128 => mlir.ComplexType(.c128).init(ctx).as(mlir.Type),
};
}
pub fn toDType(mlir_type: mlir.Type) dtype.DataType {
const mapping = .{
.{ .bool, mlir.IntegerType(.i1) },
.{ .f8e4m3b11fnuz, mlir.FloatType(.f8e4m3b11fnuz) },
.{ .f8e4m3fn, mlir.FloatType(.f8e4m3fn) },
.{ .f8e4m3fnuz, mlir.FloatType(.f8e4m3fnuz) },
.{ .f8e5m2, mlir.FloatType(.f8e5m2) },
.{ .f8e5m2fnuz, mlir.FloatType(.f8e5m2fnuz) },
.{ .bf16, mlir.FloatType(.bf16) },
.{ .f16, mlir.FloatType(.f16) },
.{ .f32, mlir.FloatType(.f32) },
.{ .f64, mlir.FloatType(.f64) },
.{ .i4, mlir.IntegerType(.i4) },
.{ .i8, mlir.IntegerType(.i8) },
.{ .i16, mlir.IntegerType(.i16) },
.{ .i32, mlir.IntegerType(.i32) },
.{ .i64, mlir.IntegerType(.i64) },
.{ .u4, mlir.IntegerType(.u4) },
.{ .u8, mlir.IntegerType(.u8) },
.{ .u16, mlir.IntegerType(.u16) },
.{ .u32, mlir.IntegerType(.u32) },
.{ .u64, mlir.IntegerType(.u64) },
.{ .c64, mlir.ComplexType(.c64) },
.{ .c128, mlir.ComplexType(.c128) },
};
inline for (mapping) |entry| {
const dt, const mlirT = entry;
if (mlir_type.is_a(mlirT)) {
return dt;
}
}
stdx.debug.panic("Could not convert mlir.Type to DataType: {}", .{mlir_type});
}
};
pub const Attribute = struct {
pub fn fromData(data: dtype.Data, ctx: mlir.Context) mlir.Attribute {
switch (data) {
.bool => |val| {
return mlir.IntegerAttribute(.i1).init(ctx, @intFromBool(val)).as(mlir.Attribute);
},
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz => |val, tag| {
const float_type = @field(mlir.FloatTypes, @tagName(tag));
const float_attr = mlir.FloatAttribute(float_type).init(ctx, val.toF32());
return float_attr.as(mlir.Attribute);
},
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |val, tag| {
const int_type = @field(mlir.IntegerTypes, @tagName(tag));
const int_attr = mlir.IntegerAttribute(int_type).init(ctx, @intCast(val));
return int_attr.as(mlir.Attribute);
},
inline else => |_, tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}),
}
}
};
pub const DenseElementsAttribute = struct {
pub fn fromData(data: dtype.Data, result_type: mlir.Type) mlir.Attribute {
return switch (data.dtype()) {
.bool => mlir.DenseElementsAttribute(.bool).init(result_type, data.constSlice()).as(mlir.Attribute),
.i8 => mlir.DenseElementsAttribute(.i8).init(result_type, data.constSlice()).as(mlir.Attribute),
.i16 => mlir.DenseElementsAttribute(.i16).init(result_type, data.constSlice()).as(mlir.Attribute),
.i32 => mlir.DenseElementsAttribute(.i32).init(result_type, data.constSlice()).as(mlir.Attribute),
.i64 => mlir.DenseElementsAttribute(.i64).init(result_type, data.constSlice()).as(mlir.Attribute),
.u8 => mlir.DenseElementsAttribute(.u8).init(result_type, data.constSlice()).as(mlir.Attribute),
.u16 => mlir.DenseElementsAttribute(.u16).init(result_type, data.constSlice()).as(mlir.Attribute),
.u32 => mlir.DenseElementsAttribute(.u32).init(result_type, data.constSlice()).as(mlir.Attribute),
.u64 => mlir.DenseElementsAttribute(.u64).init(result_type, data.constSlice()).as(mlir.Attribute),
.bf16 => mlir.DenseElementsAttribute(.bf16).init(result_type, data.constSlice()).as(mlir.Attribute),
.f16 => mlir.DenseElementsAttribute(.f16).init(result_type, data.constSlice()).as(mlir.Attribute),
.f32 => mlir.DenseElementsAttribute(.f32).init(result_type, data.constSlice()).as(mlir.Attribute),
.f64 => mlir.DenseElementsAttribute(.f64).init(result_type, data.constSlice()).as(mlir.Attribute),
inline else => |tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}),
};
}
};
};

101
zml/mlirx.zig Normal file
View File

@ -0,0 +1,101 @@
const std = @import("std");
const mlir = @import("mlir");
const dtype = @import("dtype.zig");
const Shape = @import("shape.zig").Shape;
const mlirx = @This();
/// Returns the mlir.Type corresponding to a given zml.Shape.
pub fn tensorType(ctx: mlir.Context, sh: Shape) mlir.Type {
return .tensor(sh.dims(), mlirx.Type.fromDType(ctx, sh.dtype()));
}
pub fn denseElementAttrType(dt: dtype.DataType) ?mlir.DenseElementsAttributeTypes {
return switch (dt) {
.bool => .bool,
.i8 => .i8,
.i16 => .i16,
.i32 => .i32,
.i64 => .i64,
.u8 => .u8,
.u16 => .u16,
.u32 => .u32,
.u64 => .u64,
.bf16 => .bf16,
.f16 => .f16,
.f32 => .f32,
.f64 => .f64,
else => null,
};
}
pub const Type = struct {
pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type {
return switch (dt) {
.bool => .int(ctx, .i1),
.f8e4m3b11fnuz => .float(ctx, .f8e4m3b11fnuz),
.f8e4m3fn => .float(ctx, .f8e4m3fn),
.f8e4m3fnuz => .float(ctx, .f8e4m3fnuz),
.f8e5m2 => .float(ctx, .f8e5m2),
.f8e5m2fnuz => .float(ctx, .f8e5m2fnuz),
.bf16 => .float(ctx, .bf16),
.f16 => .float(ctx, .f16),
.f32 => .float(ctx, .f32),
.f64 => .float(ctx, .f64),
.i4 => .int(ctx, .i4),
.i8 => .int(ctx, .i8),
.i16 => .int(ctx, .i16),
.i32 => .int(ctx, .i32),
.i64 => .int(ctx, .i64),
.u4 => .int(ctx, .u4),
.u8 => .int(ctx, .u8),
.u16 => .int(ctx, .u16),
.u32 => .int(ctx, .u32),
.u64 => .int(ctx, .u64),
.c64 => .complex(ctx, .c64),
.c128 => .complex(ctx, .c128),
};
}
pub fn toDType(mlir_type: mlir.Type) dtype.DataType {
const mapping = .{
.{ .bool, mlir.IntegerType(.i1) },
.{ .f8e4m3b11fnuz, mlir.FloatType(.f8e4m3b11fnuz) },
.{ .f8e4m3fn, mlir.FloatType(.f8e4m3fn) },
.{ .f8e4m3fnuz, mlir.FloatType(.f8e4m3fnuz) },
.{ .f8e5m2, mlir.FloatType(.f8e5m2) },
.{ .f8e5m2fnuz, mlir.FloatType(.f8e5m2fnuz) },
.{ .bf16, mlir.FloatType(.bf16) },
.{ .f16, mlir.FloatType(.f16) },
.{ .f32, mlir.FloatType(.f32) },
.{ .f64, mlir.FloatType(.f64) },
.{ .i4, mlir.IntegerType(.i4) },
.{ .i8, mlir.IntegerType(.i8) },
.{ .i16, mlir.IntegerType(.i16) },
.{ .i32, mlir.IntegerType(.i32) },
.{ .i64, mlir.IntegerType(.i64) },
.{ .u4, mlir.IntegerType(.u4) },
.{ .u8, mlir.IntegerType(.u8) },
.{ .u16, mlir.IntegerType(.u16) },
.{ .u32, mlir.IntegerType(.u32) },
.{ .u64, mlir.IntegerType(.u64) },
.{ .c64, mlir.ComplexType(.c64) },
.{ .c128, mlir.ComplexType(.c128) },
};
inline for (mapping) |entry| {
const dt, const mlirT = entry;
if (mlirT.is_a_fn(mlir_type._inner)) {
return dt;
}
}
std.debug.panic("Could not convert mlir.Type to DataType: {}", .{mlir_type});
}
};

View File

@ -2,21 +2,18 @@ const std = @import("std");
const asynk = @import("async"); const asynk = @import("async");
const dialect = @import("mlir/dialects"); const dialect = @import("mlir/dialects");
const runfiles = @import("runfiles"); const mlir = @import("mlir");
const stdx = @import("stdx"); const stdx = @import("stdx");
const xla_pb = @import("//xla:xla_proto"); const xla_pb = @import("//xla:xla_proto");
const BaseExe = @import("exe.zig").BaseExe; const BaseExe = @import("exe.zig").BaseExe;
const Buffer = @import("buffer.zig").Buffer; const Buffer = @import("buffer.zig").Buffer;
const Bufferized = @import("tensor.zig").Bufferized;
const meta = @import("meta.zig"); const meta = @import("meta.zig");
const mlir = @import("mlir.zig"); const mlirx = @import("mlirx.zig");
const Location = mlir.Location;
const ops = @import("ops.zig"); const ops = @import("ops.zig");
const pjrt = @import("pjrtx.zig"); const pjrt = @import("pjrtx.zig");
const Platform = @import("platform.zig").Platform; const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape; const Shape = @import("shape.zig").Shape;
const ShapeOf = @import("tensor.zig").ShapeOf;
const Target = @import("platform.zig").Target; const Target = @import("platform.zig").Target;
const Tensor = @import("tensor.zig").Tensor; const Tensor = @import("tensor.zig").Tensor;
const Tracer = @import("tools/tracer.zig").Tracer; const Tracer = @import("tools/tracer.zig").Tracer;
@ -170,8 +167,8 @@ pub const CompilationContext = struct {
const sharding = self._platform.sharding(); const sharding = self._platform.sharding();
const mlir_ctx = self._mlir_ctx; const mlir_ctx = self._mlir_ctx;
module.op().setAttributeByName("mhlo.num_replicas", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_replicas).asAttr()); module.op().setAttributeByName("mhlo.num_replicas", .int(mlir_ctx, .i32, sharding.num_replicas));
module.op().setAttributeByName("mhlo.num_partitions", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_partitions).asAttr()); module.op().setAttributeByName("mhlo.num_partitions", .int(mlir_ctx, .i32, sharding.num_partitions));
const module_hash = computeModuleHash(self._platform, module); const module_hash = computeModuleHash(self._platform, module);
var module_dir: ?[]const u8 = null; var module_dir: ?[]const u8 = null;
@ -346,7 +343,7 @@ pub const CompilationContext = struct {
stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{}); stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});
const input_types = try arena.alloc(mlir.Type, tensor_count); const input_types = try arena.alloc(mlir.Type, tensor_count);
for (input_types, input_shapes.items) |*t, sh| t.* = mlir.ext.mlirType(mlir_ctx, sh); for (input_types, input_shapes.items) |*t, sh| t.* = mlirx.tensorType(mlir_ctx, sh);
const og_block_args = self._block_args; const og_block_args = self._block_args;
defer { defer {
@ -947,7 +944,7 @@ pub fn fillMlirTypes(v: anytype, mlir_ctx: mlir.Context, types: []mlir.Type) voi
var context = LocalContext{ .mlir_ctx = mlir_ctx, .types = types }; var context = LocalContext{ .mlir_ctx = mlir_ctx, .types = types };
meta.visit((struct { meta.visit((struct {
fn cb(inner_context: *LocalContext, tensor: *const Tensor) void { fn cb(inner_context: *LocalContext, tensor: *const Tensor) void {
inner_context.types[inner_context.index] = mlir.ext.mlirType(inner_context.mlir_ctx, tensor.shape()); inner_context.types[inner_context.index] = mlirx.tensorType(inner_context.mlir_ctx, tensor.shape());
inner_context.index += 1; inner_context.index += 1;
} }
}).cb, &context, v); }).cb, &context, v);

View File

@ -5,7 +5,7 @@ const dialect = @import("mlir/dialects");
const Context = @import("../context.zig").Context; const Context = @import("../context.zig").Context;
const DataType = @import("../dtype.zig").DataType; const DataType = @import("../dtype.zig").DataType;
const Data = @import("../dtype.zig").Data; const Data = @import("../dtype.zig").Data;
const mlir = @import("../mlir.zig"); const mlirx = @import("../mlirx.zig");
const module = @import("../module.zig"); const module = @import("../module.zig");
const CompilationContext = module.CompilationContext; const CompilationContext = module.CompilationContext;
const SdpaOpts = @import("../nn.zig").SdpaOpts; const SdpaOpts = @import("../nn.zig").SdpaOpts;
@ -130,7 +130,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
.api_version = .original, .api_version = .original,
}, },
&.{ &.{
mlir.ext.mlirType(mlir_ctx, q.shape()), mlirx.tensorType(mlir_ctx, q.shape()),
.tensor(&.{0}, .int(mlir_ctx, .u8)), .tensor(&.{0}, .int(mlir_ctx, .u8)),
}, },
loc, loc,

View File

@ -1,24 +1,16 @@
const std = @import("std"); const std = @import("std");
const assert = std.debug.assert;
const mlir = @import("mlir");
const stdx = @import("stdx"); const stdx = @import("stdx");
const _collectAxes = @import("tensor.zig")._collectAxes; const _collectAxes = @import("tensor.zig")._collectAxes;
const buffer = @import("buffer.zig"); const Buffer = @import("buffer.zig").Buffer;
const Buffer = buffer.Buffer; const CompilationContext = @import("module.zig").CompilationContext;
const Bufferized = @import("tensor.zig").Bufferized;
const Context = @import("context.zig").Context; const Context = @import("context.zig").Context;
const Data = @import("dtype.zig").Data;
const DataType = @import("dtype.zig").DataType;
const helpers = @import("helpers.zig");
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const meta = @import("meta.zig"); const meta = @import("meta.zig");
const mlir = @import("mlir.zig"); const mlirx = @import("mlirx.zig");
const module = @import("module.zig");
const CompilationContext = module.CompilationContext;
const Platform = @import("platform.zig").Platform; const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape; const Shape = @import("shape.zig").Shape;
const ShapeOf = @import("tensor.zig").ShapeOf;
const Tensor = @import("tensor.zig").Tensor; const Tensor = @import("tensor.zig").Tensor;
const EnumLiteral = @TypeOf(.enum_literal); const EnumLiteral = @TypeOf(.enum_literal);
@ -200,14 +192,14 @@ pub fn reduce(
mlir_ctx, mlir_ctx,
val, val,
inner_ctx.broadcasting_axes[0 .. tensor.rank() - inner_ctx.n_reduced], inner_ctx.broadcasting_axes[0 .. tensor.rank() - inner_ctx.n_reduced],
mlir.ext.RankedTensorType.fromShape(mlir_ctx, reduced_shape).as(mlir.Type), mlirx.tensorType(mlir_ctx, reduced_shape),
inner_ctx.loc, inner_ctx.loc,
); );
tensor.* = Tensor._result(reduced_shape, broad_val.result(0)); tensor.* = Tensor._result(reduced_shape, broad_val.result(0));
inner_ctx.index += 1; inner_ctx.index += 1;
} }
}).cb, &local_context, &res); }).cb, &local_context, &res);
assert(local_context.index == op.numResults()); std.debug.assert(local_context.index == op.numResults());
return res; return res;
} }
@ -248,7 +240,8 @@ pub fn reduceWindow(
.{ "window_strides", .dense(ctx.mlirCtx(), .i64, opts.window_strides) }, .{ "window_strides", .dense(ctx.mlirCtx(), .i64, opts.window_strides) },
.{ "base_dilations", .dense(ctx.mlirCtx(), .i64, opts.base_dilations) }, .{ "base_dilations", .dense(ctx.mlirCtx(), .i64, opts.base_dilations) },
.{ "window_dilations", .dense(ctx.mlirCtx(), .i64, opts.window_dilations) }, .{ "window_dilations", .dense(ctx.mlirCtx(), .i64, opts.window_dilations) },
.{ "padding", .denseElements(ctx.mlirCtx(), &.{ @intCast(opts.padding.len), 2 }, .i64, opts.padding) }, // Cast the [][2]i64 to []i64 (safe)
.{ "padding", .denseElements(ctx.mlirCtx(), &.{ @intCast(opts.padding.len), 2 }, .i64, @ptrCast(opts.padding)) },
}, },
.location = loc, .location = loc,
}); });
@ -609,8 +602,8 @@ pub fn sort(
.result_type_inference = true, .result_type_inference = true,
.blocks = &.{block}, .blocks = &.{block},
.attributes = &.{ .attributes = &.{
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx.mlirCtx(), dimension).as(mlir.Attribute) }, .{ "dimension", .int(ctx.mlirCtx(), .i64, dimension) },
.{ "is_stable", mlir.BoolAttribute.init(ctx.mlirCtx(), is_stable).as(mlir.Attribute) }, .{ "is_stable", .boolean(ctx.mlirCtx(), is_stable) },
}, },
.location = loc, .location = loc,
}); });
@ -767,7 +760,7 @@ pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base
inner_ctx.index += 1; inner_ctx.index += 1;
} }
}).cb, &context, &res); }).cb, &context, &res);
assert(context.index == op.numResults()); std.debug.assert(context.index == op.numResults());
return res; return res;
} }
@ -817,7 +810,7 @@ pub fn triton(inputs: anytype, outputs: anytype, opts: TritonOps) [outputs.len]T
var res_types: [outputs.len]mlir.Type = undefined; var res_types: [outputs.len]mlir.Type = undefined;
inline for (outputs, 0..) |output, i| { inline for (outputs, 0..) |output, i| {
res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output); res_types[i] = mlirx.tensorType(ctx.mlirCtx(), output);
} }
const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &.{ const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &.{
@ -1031,7 +1024,7 @@ pub fn scatter(
inner_ctx.index += 1; inner_ctx.index += 1;
} }
}).cb, &local_context, &res); }).cb, &local_context, &res);
assert(local_context.index == op.numResults()); std.debug.assert(local_context.index == op.numResults());
return res; return res;
} }
@ -1327,30 +1320,30 @@ pub fn customCall(target_name: [:0]const u8, inputs: anytype, outputs: anytype,
} }
fn customCallInternal(target_name: [:0]const u8, inputs: []const Tensor, outputs: []const Shape, metadata: anytype, opts: CustomCallOptions) []Tensor { fn customCallInternal(target_name: [:0]const u8, inputs: []const Tensor, outputs: []const Shape, metadata: anytype, opts: CustomCallOptions) []Tensor {
const ctx = module.CompilationContext.current(); const ctx = CompilationContext.current();
const values = ctx.allocator().alloc(mlir.Value, inputs.len) catch unreachable; const values = ctx.allocator().alloc(mlir.Value, inputs.len) catch unreachable;
ctx.extractValues(inputs, values); ctx.extractValues(inputs, values);
const res_types = ctx.allocator().alloc(mlir.Type, outputs.len) catch unreachable; const res_types = ctx.allocator().alloc(mlir.Type, outputs.len) catch unreachable;
for (outputs, 0..) |output, i| { for (outputs, 0..) |output, i| {
res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output); res_types[i] = mlirx.tensorType(ctx.mlirCtx(), output);
} }
const metadata_type_info = @typeInfo(@TypeOf(metadata)); const metadata_type_info = @typeInfo(@TypeOf(metadata));
var metadata_attributes_tuple: [metadata_type_info.@"struct".fields.len]mlir.AttrTuple = undefined; var metadata_attributes_tuple: [metadata_type_info.@"struct".fields.len]mlir.AttrTuple = undefined;
inline for (metadata_type_info.@"struct".fields, 0..) |field, i| { inline for (metadata_type_info.@"struct".fields, 0..) |field, i| {
const attribute: mlir.Attribute = switch (@typeInfo(field.type)) { const attribute: mlir.Attribute = switch (@typeInfo(field.type)) {
.int, .comptime_int => mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@field(metadata, field.name))), .int, .comptime_int => .int(ctx.mlirCtx(), .u64, @bitCast(@field(metadata, field.name))),
else => @compileError("Unsupported metadata type: " ++ @typeName(field.type)), else => @compileError("Unsupported metadata type: " ++ @typeName(field.type)),
}; };
metadata_attributes_tuple[i] = .{ field.name, attribute }; metadata_attributes_tuple[i] = .{ field.name, attribute };
} }
const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &(.{ const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &(metadata_attributes_tuple ++ [_]mlir.AttrTuple{
.{ "pjrt_api", mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_api))) }, .{ "pjrt_api", .int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_api))) },
.{ "pjrt_client", mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_client))) }, .{ "pjrt_client", .int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_client))) },
} ++ metadata_attributes_tuple)); }));
const operands_layouts = ctx.allocator().alloc([]const usize, inputs.len) catch unreachable; const operands_layouts = ctx.allocator().alloc([]const usize, inputs.len) catch unreachable;
for (inputs, 0..) |input, i| { for (inputs, 0..) |input, i| {

View File

@ -1,20 +1,17 @@
const std = @import("std"); const std = @import("std");
const assert = std.debug.assert;
const testing = std.testing;
const builtin = @import("builtin"); const builtin = @import("builtin");
const mlir = @import("mlir");
const stdx = @import("stdx"); const stdx = @import("stdx");
const Buffer = @import("buffer.zig").Buffer; const Buffer = @import("buffer.zig").Buffer;
const CompilationContext = @import("module.zig").CompilationContext;
const Data = @import("dtype.zig").Data; const Data = @import("dtype.zig").Data;
const DataType = @import("dtype.zig").DataType; const DataType = @import("dtype.zig").DataType;
const HostBuffer = @import("hostbuffer.zig").HostBuffer; const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const Memory = @import("buffer.zig").Buffer.Memory; const Memory = @import("buffer.zig").Buffer.Memory;
const meta = @import("meta.zig"); const meta = @import("meta.zig");
const mlir = @import("mlir.zig"); const mlirx = @import("mlirx.zig");
const Location = mlir.Location;
const module = @import("module.zig");
const CompilationContext = module.CompilationContext;
const ops = @import("ops.zig"); const ops = @import("ops.zig");
const Platform = @import("platform.zig").Platform; const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape; const Shape = @import("shape.zig").Shape;
@ -112,12 +109,12 @@ pub const Tensor = struct {
/// ///
/// The shape is derived from the type of the mlir.Value. /// The shape is derived from the type of the mlir.Value.
pub fn fromMlirValue(val: mlir.Value) Tensor { pub fn fromMlirValue(val: mlir.Value) Tensor {
const ranked_tensor = val.getType().as(mlir.RankedTensorType); const ranked_tensor = val.getType().as(mlir.RankedTensorType).?;
const n = ranked_tensor.getRank(); const n = ranked_tensor.getRank();
stdx.debug.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK }); stdx.debug.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK });
var sh: Shape = .{ ._dtype = mlir.ext.Type.toDType(ranked_tensor.getElementType()) }; var sh: Shape = .{ ._dtype = mlirx.Type.toDType(ranked_tensor.getElementType()) };
for (0..n) |i| { for (0..n) |i| {
sh._dims.appendAssumeCapacity(ranked_tensor.getDimension(i)); sh._dims.appendAssumeCapacity(ranked_tensor.getDimension(i));
} }
@ -322,7 +319,7 @@ pub const Tensor = struct {
const op = dialect.stablehlo.bitcast_convert( const op = dialect.stablehlo.bitcast_convert(
self.getContext().mlirCtx(), self.getContext().mlirCtx(),
self.value(), self.value(),
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type), mlirx.tensorType(self.getContext().mlirCtx(), res_shape),
loc, loc,
); );
@ -559,8 +556,8 @@ pub const Tensor = struct {
ctx.mlirCtx(), ctx.mlirCtx(),
self.algorithm, self.algorithm,
self._state.value(), self._state.value(),
mlir.ext.mlirType(ctx.mlirCtx(), self._state._shape), mlirx.tensorType(ctx.mlirCtx(), self._state._shape),
mlir.ext.mlirType(ctx.mlirCtx(), sh), mlirx.tensorType(ctx.mlirCtx(), sh),
loc, loc,
); );
return .{ self.update(op.result(0)), _result(sh, op.result(1)) }; return .{ self.update(op.result(0)), _result(sh, op.result(1)) };
@ -870,7 +867,7 @@ pub const Tensor = struct {
self.value(), self.value(),
other.value(), other.value(),
used_opts, used_opts,
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape).as(mlir.Type), mlirx.tensorType(self.getContext().mlirCtx(), new_shape),
loc, loc,
); );
@ -1052,7 +1049,7 @@ pub const Tensor = struct {
const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) }); const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) });
const mlir_ctx = self.getContext().mlirCtx(); const mlir_ctx = self.getContext().mlirCtx();
const res_type = mlir.ext.mlirType(mlir_ctx, self.shape().withDtype(to)); const res_type = mlirx.tensorType(mlir_ctx, self.shape().withDtype(to));
const op = dialect.stablehlo.convert(mlir_ctx, self.value(), res_type, loc); const op = dialect.stablehlo.convert(mlir_ctx, self.value(), res_type, loc);
return _result(self._shape.withDtype(to), op.result(0)); return _result(self._shape.withDtype(to), op.result(0));
} }
@ -1217,7 +1214,7 @@ pub const Tensor = struct {
mlir_ctx, mlir_ctx,
lhs.value(), lhs.value(),
rhs.value(), rhs.value(),
mlir.ext.mlirType(mlir_ctx, res_shape), mlirx.tensorType(mlir_ctx, res_shape),
loc, loc,
.{ .{
.lhs_batching_dimensions = lhs_batching_axes.constSlice(), .lhs_batching_dimensions = lhs_batching_axes.constSlice(),
@ -1392,7 +1389,7 @@ pub const Tensor = struct {
[2][5]f32{ .{ 0, 1, 1, 0, 1 }, .{ 3, 1, 0, 2, 1 } }, [2][5]f32{ .{ 0, 1, 1, 0, 1 }, .{ 3, 1, 0, 2, 1 } },
); );
const res = try zml.testing.compileAndCall(platform, Local._cumsum, .{x}); const res = try zml.testing.compileAndCall(platform, Local._cumsum, .{x});
try testing.expectEqual( try std.testing.expectEqual(
[2][5]f32{ .{ 0, 1, 2, 2, 3 }, .{ 3, 4, 4, 6, 7 } }, [2][5]f32{ .{ 0, 1, 2, 2, 3 }, .{ 3, 4, 4, 6, 7 } },
try res.getValue([2][5]f32), try res.getValue([2][5]f32),
); );
@ -1424,7 +1421,7 @@ pub const Tensor = struct {
const op = dialect.stablehlo.transpose( const op = dialect.stablehlo.transpose(
self.getContext().mlirCtx(), self.getContext().mlirCtx(),
self.value(), self.value(),
mlir.ext.mlirType(self.getContext().mlirCtx(), res_shape), mlirx.tensorType(self.getContext().mlirCtx(), res_shape),
loc, loc,
.{ .permutation = toI64(permutation) }, .{ .permutation = toI64(permutation) },
); );
@ -1457,7 +1454,7 @@ pub const Tensor = struct {
const reshaped_val = dialect.stablehlo.reshape( const reshaped_val = dialect.stablehlo.reshape(
self.getContext().mlirCtx(), self.getContext().mlirCtx(),
self.value(), self.value(),
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape), mlirx.tensorType(self.getContext().mlirCtx(), new_shape),
loc, loc,
); );
return _result(new_shape, reshaped_val.result(0)); return _result(new_shape, reshaped_val.result(0));
@ -1474,7 +1471,7 @@ pub const Tensor = struct {
const reshaped_val = dialect.stablehlo.reshape( const reshaped_val = dialect.stablehlo.reshape(
self.getContext().mlirCtx(), self.getContext().mlirCtx(),
self.value(), self.value(),
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape), mlirx.tensorType(self.getContext().mlirCtx(), new_shape),
loc, loc,
); );
return _result(new_shape, reshaped_val.result(0)); return _result(new_shape, reshaped_val.result(0));
@ -1512,7 +1509,7 @@ pub const Tensor = struct {
const reshaped_val = dialect.stablehlo.reshape( const reshaped_val = dialect.stablehlo.reshape(
self.getContext().mlirCtx(), self.getContext().mlirCtx(),
self.value(), self.value(),
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape), mlirx.tensorType(self.getContext().mlirCtx(), new_shape),
loc, loc,
); );
// log.debug("flatten({d}, {d}) -> {d}", .{ self.dims(), axis_, new_shape[0 .. self.rank() - 1] }); // log.debug("flatten({d}, {d}) -> {d}", .{ self.dims(), axis_, new_shape[0 .. self.rank() - 1] });
@ -1586,7 +1583,7 @@ pub const Tensor = struct {
const mlir_ctx = self.getContext().mlirCtx(); const mlir_ctx = self.getContext().mlirCtx();
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "slices={any}", .{slices}); const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "slices={any}", .{slices});
const result_type = mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type); const result_type = mlirx.tensorType(mlir_ctx, res_shape);
const slice_op = dialect.stablehlo.slice( const slice_op = dialect.stablehlo.slice(
mlir_ctx, mlir_ctx,
self.value(), self.value(),
@ -1620,15 +1617,15 @@ pub const Tensor = struct {
{ {
const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, 0, .{ .end = 1 } }); const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, 0, .{ .end = 1 } });
try testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32)); try std.testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32));
} }
{ {
const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, 1, .{ .start = 1, .step = 2 } }); const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, 1, .{ .start = 1, .step = 2 } });
try testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32)); try std.testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32));
} }
{ {
const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, -1, .{ .start = -2 } }); const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, -1, .{ .start = -2 } });
try testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32)); try std.testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32));
} }
} }
@ -1838,7 +1835,7 @@ pub const Tensor = struct {
const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable; const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable;
const sh = Shape.init(.{n_steps}, dt); const sh = Shape.init(.{n_steps}, dt);
var op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlir.ext.mlirType(ctx.mlirCtx(), sh), loc); var op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlirx.tensorType(ctx.mlirCtx(), sh), loc);
var res = _result(sh, op.result(0)); var res = _result(sh, op.result(0));
if (args.step != 1) { if (args.step != 1) {
@ -1868,7 +1865,7 @@ pub const Tensor = struct {
var op = dialect.stablehlo.iota( var op = dialect.stablehlo.iota(
mlir_ctx, mlir_ctx,
a, a,
mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type), mlirx.tensorType(mlir_ctx, res_shape),
loc, loc,
); );
return _result(res_shape, op.result(0)); return _result(res_shape, op.result(0));
@ -1890,7 +1887,7 @@ pub const Tensor = struct {
const loc = ctx.location(@src(), "linspace({}, dtype={})", .{ args, dt }); const loc = ctx.location(@src(), "linspace({}, dtype={})", .{ args, dt });
const sh = Shape.init(.{args.steps}, dt); const sh = Shape.init(.{args.steps}, dt);
var iota_op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlir.ext.mlirType(ctx.mlirCtx(), sh), loc); var iota_op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlirx.tensorType(ctx.mlirCtx(), sh), loc);
var res = _result(sh, iota_op.result(0)); var res = _result(sh, iota_op.result(0));
if (args.steps != 1) { if (args.steps != 1) {
@ -1933,21 +1930,19 @@ pub const Tensor = struct {
/// Returns a constant Tensor with the given value. /// Returns a constant Tensor with the given value.
pub fn constant(dimz: anytype, val: Data) Tensor { pub fn constant(dimz: anytype, val: Data) Tensor {
const sh = Shape.init(dimz, val.dtype()); const sh = Shape.init(dimz, val.dtype());
const singleton_sh = Shape.init(.{}, val.dtype());
const ctx = CompilationContext.current().mlirCtx(); const ctx = CompilationContext.current().mlirCtx();
const loc = CompilationContext.current().location(@src(), "dims={d}, value={}", .{ sh, val }); const loc = CompilationContext.current().location(@src(), "dims={d}, value={}", .{ sh, val });
const res_type = mlir.ext.RankedTensorType.fromShape(ctx, singleton_sh);
var constant_op = if (mlir.ext.denseElementAttrType(val.dtype())) |elem_type| var constant_op = if (mlirx.denseElementAttrType(val.dtype())) |elem_type|
dialect.stablehlo.constant(ctx, res_type, elem_type, val.constSlice(), loc) dialect.stablehlo.constant(ctx, &.{}, elem_type, val.constSlice(), loc)
else blk: { else blk: {
// Not all dtype can be serialized in the IR. If that's not possible, use f32. // Not all dtype can be serialized in the IR. If that's not possible, use f32.
const val_f32 = val.as(f32); const val_f32 = val.as(f32);
break :blk dialect.stablehlo.constant(ctx, res_type, .f32, std.mem.asBytes(&val_f32), loc); break :blk dialect.stablehlo.constant(ctx, &.{}, .f32, std.mem.asBytes(&val_f32), loc);
}; };
if (sh.rank() > 0) { if (sh.rank() > 0) {
constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type), loc); constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlirx.tensorType(ctx, sh), loc);
} }
return _result(sh, constant_op.result(0)).convert(val.dtype()); return _result(sh, constant_op.result(0)).convert(val.dtype());
} }
@ -1955,10 +1950,9 @@ pub const Tensor = struct {
/// Embeds a buffer with concrete values into an Mlir program. /// Embeds a buffer with concrete values into an Mlir program.
pub fn constantTensor(val: HostBuffer) Tensor { pub fn constantTensor(val: HostBuffer) Tensor {
const ctx = CompilationContext.current().mlirCtx(); const ctx = CompilationContext.current().mlirCtx();
const result_type = mlir.ext.RankedTensorType.fromShape(ctx, val.shape());
const loc = ctx.location(@src()); const loc = ctx.location(@src());
const elem_type = mlir.ext.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()}); const elem_type = mlirx.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()});
const constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.bytes(), loc); const constant_op = dialect.stablehlo.constant(ctx, val.shape().dims(), elem_type, val.bytes(), loc);
return _result(val.shape(), constant_op.result(0)); return _result(val.shape(), constant_op.result(0));
} }
@ -1994,7 +1988,7 @@ pub const Tensor = struct {
return _result(res_shape, self.value()); return _result(res_shape, self.value());
} }
const ctx = self.getContext(); const ctx = self.getContext();
const result_type = mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type); const result_type = mlirx.tensorType(ctx.mlirCtx(), res_shape);
const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ }); const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ });
const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc); const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc);
@ -2052,7 +2046,7 @@ pub const Tensor = struct {
/// Reshapes the input Tensor with the given shape. /// Reshapes the input Tensor with the given shape.
pub fn reshape(self: Tensor, output_shape_: anytype) Tensor { pub fn reshape(self: Tensor, output_shape_: anytype) Tensor {
const output_shape = self._shape.reshape(output_shape_); const output_shape = self._shape.reshape(output_shape_);
const tensor_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), output_shape); const tensor_type = mlirx.tensorType(self.getContext().mlirCtx(), output_shape);
const loc = self.getContext().location(@src(), "reshape({any})", .{output_shape}); const loc = self.getContext().location(@src(), "reshape({any})", .{output_shape});
const reshape_value = dialect.stablehlo.reshape(self.getContext().mlirCtx(), self.value(), tensor_type, loc); const reshape_value = dialect.stablehlo.reshape(self.getContext().mlirCtx(), self.value(), tensor_type, loc);
return _result(output_shape, reshape_value.result(0)); return _result(output_shape, reshape_value.result(0));
@ -2846,9 +2840,9 @@ pub const Tensor = struct {
const res = argmax.call(.{x}); const res = argmax.call(.{x});
const max_ = res.values.getValue(f32); const max_ = res.values.getValue(f32);
const max_idx = res.indices.getValue(i32); const max_idx = res.indices.getValue(i32);
try testing.expectEqual(max_, 7.9); try std.testing.expectEqual(max_, 7.9);
// We should always return the first max found. // We should always return the first max found.
try testing.expectEqual(max_idx, 2); try std.testing.expectEqual(max_idx, 2);
} }
// Test with Nan // Test with Nan
@ -2857,8 +2851,8 @@ pub const Tensor = struct {
const res = argmax.call(.{x}); const res = argmax.call(.{x});
const max_ = try res.values.getValue(f32); const max_ = try res.values.getValue(f32);
const max_idx = try res.indices.getValue(i32); const max_idx = try res.indices.getValue(i32);
try testing.expect(std.math.isNan(max_)); try std.testing.expect(std.math.isNan(max_));
try testing.expectEqual(max_idx, 1); try std.testing.expectEqual(max_idx, 1);
} }
} }
@ -2907,7 +2901,7 @@ pub const Tensor = struct {
const x = try zml.Buffer.fromSlice(platform, .{ 2, 5 }, &[_]f32{ -0.9264, 0.7156, 1.0202, 0.3992, 1.2349, 1.0003, -0.1932, 1.3935, 0.7316, 0.0851 }); const x = try zml.Buffer.fromSlice(platform, .{ 2, 5 }, &[_]f32{ -0.9264, 0.7156, 1.0202, 0.3992, 1.2349, 1.0003, -0.1932, 1.3935, 0.7316, 0.0851 });
const res = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 1, .{} }); const res = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 1, .{} });
const res_cpu = try res.toHostAlloc(allocator); const res_cpu = try res.toHostAlloc(allocator);
try testing.expectEqualSlices(i32, &.{ 0, 3, 1, 2, 4, 1, 4, 3, 0, 2 }, res_cpu.items(i32)); try std.testing.expectEqualSlices(i32, &.{ 0, 3, 1, 2, 4, 1, 4, 3, 0, 2 }, res_cpu.items(i32));
} }
// 3D Tensor, dim = 1, descending // 3D Tensor, dim = 1, descending
{ {
@ -2920,7 +2914,7 @@ pub const Tensor = struct {
}); });
const res_dev = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 1, .{ .descending = true } }); const res_dev = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 1, .{ .descending = true } });
const res = try res_dev.toHostAlloc(allocator); const res = try res_dev.toHostAlloc(allocator);
try testing.expectEqualSlices(i32, &.{ try std.testing.expectEqualSlices(i32, &.{
4, 1, 1, 2, 0, 2, 0, 0, 3, 4, 4, 1, 1, 2, 0, 2, 0, 0, 3, 4,
2, 0, 4, 4, 1, 3, 4, 4, 1, 0, 2, 0, 4, 4, 1, 3, 4, 4, 1, 0,
1, 4, 2, 0, 2, 4, 2, 2, 0, 3, 1, 4, 2, 0, 2, 4, 2, 2, 0, 3,
@ -2942,7 +2936,7 @@ pub const Tensor = struct {
}); });
const res_dev = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 3, .{} }); const res_dev = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 3, .{} });
const res = try res_dev.toHostAlloc(allocator); const res = try res_dev.toHostAlloc(allocator);
try testing.expectEqualSlices(i32, &.{ try std.testing.expectEqualSlices(i32, &.{
2, 1, 3, 0, 2, 1, 3, 0,
2, 3, 1, 0, 2, 3, 1, 0,
3, 2, 0, 1, 3, 2, 0, 1,
@ -3262,7 +3256,7 @@ pub const Tensor = struct {
const z = try zml.Buffer.scalar(platform, 4, .i32); const z = try zml.Buffer.scalar(platform, 4, .i32);
const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 0, .{ .len = 2, .start = z } }); const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 0, .{ .len = 2, .start = z } });
try testing.expectEqual([2]T{ 4, 5 }, try res.getValue([2]T)); try std.testing.expectEqual([2]T{ 4, 5 }, try res.getValue([2]T));
} }
{ {
@ -3271,7 +3265,7 @@ pub const Tensor = struct {
const z = try zml.Buffer.scalar(platform, 3, .i32); const z = try zml.Buffer.scalar(platform, 3, .i32);
const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 1, .{ .len = 2, .start = z } }); const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 1, .{ .len = 2, .start = z } });
try testing.expectEqual([4]T{ 3, 4, 8, 9 }, res.getValue([4]T)); try std.testing.expectEqual([4]T{ 3, 4, 8, 9 }, res.getValue([4]T));
} }
} }
@ -3389,7 +3383,7 @@ pub const Tensor = struct {
}._fwd, }._fwd,
.{ x.withTags(.{.a}), .{ .a = idx }, y.withTags(.{.a}) }, .{ x.withTags(.{.a}), .{ .a = idx }, y.withTags(.{.a}) },
); );
try testing.expectEqual([10]f32{ 0, 1, 2, 3, -1, -1, 6, 7, 8, 9 }, try res.getValue([10]f32)); try std.testing.expectEqual([10]f32{ 0, 1, 2, 3, -1, -1, 6, 7, 8, 9 }, try res.getValue([10]f32));
} }
{ {
@ -3407,7 +3401,7 @@ pub const Tensor = struct {
}._fwd, }._fwd,
.{ x.withTags(.{ .a, .b }), idx, y.withTags(.{.a}) }, .{ x.withTags(.{ .a, .b }), idx, y.withTags(.{.a}) },
); );
try testing.expectEqualDeep( try std.testing.expectEqualDeep(
[2][5]f32{ .{ 0, 1, 2, -1, 4 }, .{ 5, 6, 7, -1, 9 } }, [2][5]f32{ .{ 0, 1, 2, -1, 4 }, .{ 5, 6, 7, -1, 9 } },
try res.getValue([2][5]f32), try res.getValue([2][5]f32),
); );
@ -3427,7 +3421,7 @@ pub const Tensor = struct {
}._fwd, }._fwd,
.{ x, idx, y }, .{ x, idx, y },
); );
try testing.expectEqualDeep( try std.testing.expectEqualDeep(
[2][5]f32{ .{ 0, 1, 2, -1, 4 }, .{ 5, 6, 7, -1, 9 } }, [2][5]f32{ .{ 0, 1, 2, -1, 4 }, .{ 5, 6, 7, -1, 9 } },
res.getValue([2][5]f32), res.getValue([2][5]f32),
); );
@ -3448,7 +3442,7 @@ pub const Tensor = struct {
}._fwd, }._fwd,
.{ x.withTags(.{ .a, .b }), .{ .a = idx_a, .b = idx_b }, y.withTags(.{.a}) }, .{ x.withTags(.{ .a, .b }), .{ .a = idx_a, .b = idx_b }, y.withTags(.{.a}) },
); );
try testing.expectEqualDeep( try std.testing.expectEqualDeep(
[2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, -1, 9 } }, [2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, -1, 9 } },
res.getValue([2][5]f32), res.getValue([2][5]f32),
); );
@ -3466,7 +3460,7 @@ pub const Tensor = struct {
} }
}; };
const res = try zml.testing.compileAndCall(platform, A._fwd, .{ x, .{ idx_a, idx_b }, y }); const res = try zml.testing.compileAndCall(platform, A._fwd, .{ x, .{ idx_a, idx_b }, y });
try testing.expectEqualDeep( try std.testing.expectEqualDeep(
[2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, -1, 9 } }, [2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, -1, 9 } },
res.getValue([2][5]f32), res.getValue([2][5]f32),
); );
@ -3531,7 +3525,7 @@ pub const Tensor = struct {
const x = try zml.Buffer.fromArray(platform, [2][2]u8{ .{ 1, 2 }, .{ 3, 4 } }); const x = try zml.Buffer.fromArray(platform, [2][2]u8{ .{ 1, 2 }, .{ 3, 4 } });
{ {
const res = try zml.testing.compileAndCall(platform, Local._toDiag, .{x}); const res = try zml.testing.compileAndCall(platform, Local._toDiag, .{x});
try testing.expectEqual( try std.testing.expectEqual(
[2][2][2]u8{ .{ [2][2][2]u8{ .{
.{ 1, 0 }, .{ 1, 0 },
.{ 0, 2 }, .{ 0, 2 },
@ -3582,7 +3576,7 @@ pub const Tensor = struct {
}); });
{ {
const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, 0 }); const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, 0 });
try testing.expectEqual( try std.testing.expectEqual(
[3][3]u8{ [3][3]u8{
.{ 1, 0, 0 }, .{ 1, 0, 0 },
.{ 1, 1, 0 }, .{ 1, 1, 0 },
@ -3593,7 +3587,7 @@ pub const Tensor = struct {
} }
{ {
const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, 1 }); const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, 1 });
try testing.expectEqual( try std.testing.expectEqual(
[3][3]u8{ [3][3]u8{
.{ 1, 1, 0 }, .{ 1, 1, 0 },
.{ 1, 1, 1 }, .{ 1, 1, 1 },
@ -3604,7 +3598,7 @@ pub const Tensor = struct {
} }
{ {
const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, -1 }); const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, -1 });
try testing.expectEqual( try std.testing.expectEqual(
[3][3]u8{ [3][3]u8{
.{ 0, 0, 0 }, .{ 0, 0, 0 },
.{ 1, 0, 0 }, .{ 1, 0, 0 },

View File

@ -25,7 +25,7 @@ pub const nn = @import("nn.zig");
pub const module = @import("module.zig"); pub const module = @import("module.zig");
pub const meta = @import("meta.zig"); pub const meta = @import("meta.zig");
pub const platform = @import("platform.zig"); pub const platform = @import("platform.zig");
pub const mlir = @import("mlir.zig"); pub const mlir = @import("mlirx.zig");
pub const pjrt = @import("pjrtx.zig"); pub const pjrt = @import("pjrtx.zig");
pub const testing = @import("testing.zig"); pub const testing = @import("testing.zig");
pub const torch = @import("torch.zig"); pub const torch = @import("torch.zig");