Remove usingnamespace from MLIR.

This commit is contained in:
Tarry Singh 2025-01-28 09:35:58 +00:00
parent f8ab0d7b2a
commit 0a2ab7c8cb
10 changed files with 777 additions and 1031 deletions

View File

@ -73,7 +73,7 @@ pub fn cmpi(ctx: mlir.Context, predicate: CmpIPredicate, lhs: mlir.Value, rhs: m
.operands = &.{ lhs, rhs },
.result_type_inference = true,
.attributes = &.{
.{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute) },
.{ "predicate", .int(ctx, .i64, @intFromEnum(predicate)) },
},
.location = location,
});
@ -103,7 +103,7 @@ pub fn cmpf(ctx: mlir.Context, predicate: CmpFPredicate, lhs: mlir.Value, rhs: m
.operands = &.{ lhs, rhs },
.result_type_inference = true,
.attributes = &.{
.{ "predicate", mlir.IntegerAttribute(.i64).init(ctx, @intFromEnum(predicate)).as(mlir.Attribute) },
.{ "predicate", .int(ctx, .i64, @intFromEnum(predicate)) },
},
.location = location,
});

View File

@ -127,10 +127,10 @@ pub const DotPrecision = union(enum) {
// When we specify the dot algorithm, we should not specify the precision.
.algorithm => .DEFAULT,
});
return precision.as(mlir.Attribute);
return precision.asAttr();
}
pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.Type) ?mlir.Attribute {
pub fn algorithmAttr(self: DotPrecision, ctx: mlir.Context, operand_type: mlir.RankedTensorType) ?mlir.Attribute {
return switch (self) {
.algorithm => |algo| algo.asAttr(ctx, operand_type),
else => null,
@ -156,15 +156,14 @@ pub const DotAlgorithm = struct {
.allow_imprecise_accumulation = false,
};
pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, operand_type: mlir.Type) mlir.Attribute {
const tensor_type = operand_type.as(mlir.RankedTensorType);
pub fn asAttr(self: DotAlgorithm, ctx: mlir.Context, tensor_type: mlir.RankedTensorType) mlir.Attribute {
const elem_type = tensor_type.getElementType();
return mlir.Attribute.wrap(c.stablehloDotAlgorithmGet(
ctx.inner(),
elem_type.inner(),
elem_type.inner(),
self.accumulation.asType(ctx).inner(),
ctx._inner,
elem_type._inner,
elem_type._inner,
self.accumulation.asType(ctx)._inner,
self.component_count,
self.component_count,
self.num_primitive_operations,
@ -197,11 +196,11 @@ pub fn dot_general(
.rhs_batching_dimensions = opts.rhs_batching_dimensions,
.lhs_contracting_dimensions = opts.lhs_contracting_dimensions,
.rhs_contracting_dimensions = opts.rhs_contracting_dimensions,
}).as(mlir.Attribute),
}).asAttr(),
},
.{ "precision_config", .array(ctx, &precisions) },
// keep algorithm as the last attribute so we can omit it when it's not set.
.{ "algorithm", opts.precision.algorithmAttr(ctx, lhs.getType()) orelse undefined },
.{ "algorithm", opts.precision.algorithmAttr(ctx, lhs.getType().as(mlir.RankedTensorType).?) orelse undefined },
};
const n_attributes = if (opts.precision == .algorithm) attributes.len else attributes.len - 1;
return mlir.Operation.make(ctx, "stablehlo.dot_general", .{
@ -214,19 +213,15 @@ pub fn dot_general(
pub fn constant(
ctx: mlir.Context,
result_type: mlir.RankedTensorType,
dims: []const i64,
elem_type: mlir.DenseElementsAttributeTypes,
raw_bytes: []const u8,
location: mlir.Location,
) mlir.Operation {
const attribute = switch (elem_type) {
inline else => |dt| mlir.DenseElementsAttribute(dt).init(result_type.as(mlir.Type), raw_bytes).as(mlir.Attribute),
};
return mlir.Operation.make(ctx, "stablehlo.constant", .{
.operands = &.{},
.results = &.{result_type.as(mlir.Type)},
.attributes = &.{.{ "value", attribute }},
.results = &.{.tensor(dims, elem_type.mlirType(ctx))},
.attributes = &.{.{ "value", .denseElementsFromBytes(ctx, dims, elem_type, raw_bytes) }},
.location = location,
});
}
@ -285,10 +280,10 @@ pub fn concatenate(ctx: mlir.Context, inputs: []const mlir.Value, dimension: i64
});
}
pub fn reshape(ctx: mlir.Context, value: mlir.Value, result_type: mlir.RankedTensorType, location: mlir.Location) mlir.Operation {
pub fn reshape(ctx: mlir.Context, value: mlir.Value, result_type: mlir.Type, location: mlir.Location) mlir.Operation {
return mlir.Operation.make(ctx, "stablehlo.reshape", .{
.operands = &.{value},
.results = &.{result_type.as(mlir.Type)},
.results = &.{result_type},
.location = location,
});
}
@ -332,7 +327,7 @@ pub fn gather(
args.start_indices_batching_dims,
args.start_index_map,
args.index_vector_dim,
).as(mlir.Attribute) },
).asAttr() },
.{ "slice_sizes", .dense(ctx, .i64, slice_sizes) },
.{ "indices_are_sorted", .boolean(ctx, args.indices_are_sorted) },
},
@ -358,22 +353,20 @@ pub const ScatterArgs = struct {
unique_indices: bool = false,
pub fn getScatterDimensionNumbers(self: ScatterArgs, ctx: mlir.Context) mlir.Attribute {
return mlir.Attribute.wrap(
c.stablehloScatterDimensionNumbersGet(
ctx.inner(),
@intCast(self.update_window_dims.len),
self.update_window_dims.ptr,
@intCast(self.inserted_window_dims.len),
self.inserted_window_dims.ptr,
@intCast(self.input_batching_dims.len),
self.input_batching_dims.ptr,
@intCast(self.scatter_indices_batching_dims.len),
self.scatter_indices_batching_dims.ptr,
@intCast(self.scatter_dims_to_operand_dims.len),
self.scatter_dims_to_operand_dims.ptr,
self.index_vector_dim,
),
);
return .{ ._inner = c.stablehloScatterDimensionNumbersGet(
ctx._inner,
@intCast(self.update_window_dims.len),
self.update_window_dims.ptr,
@intCast(self.inserted_window_dims.len),
self.inserted_window_dims.ptr,
@intCast(self.input_batching_dims.len),
self.input_batching_dims.ptr,
@intCast(self.scatter_indices_batching_dims.len),
self.scatter_indices_batching_dims.ptr,
@intCast(self.scatter_dims_to_operand_dims.len),
self.scatter_dims_to_operand_dims.ptr,
self.index_vector_dim,
) };
}
};
@ -431,8 +424,8 @@ pub fn compare(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, comparison_d
.operands = &.{ lhs, rhs },
.result_type_inference = true,
.attributes = &.{
.{ "comparison_direction", comparison_direction.as(mlir.Attribute) },
.{ "compare_type", compare_type.as(mlir.Attribute) },
.{ "comparison_direction", comparison_direction.asAttr() },
.{ "compare_type", compare_type.asAttr() },
},
.location = location,
});
@ -580,7 +573,7 @@ pub fn triangular_solve(ctx: mlir.Context, value: mlir.Value, other: mlir.Value,
.{ "left_side", .i1FromBool(ctx, opts.left_side) },
.{ "lower", .i1FromBool(ctx, opts.lower) },
.{ "unit_diagonal", .i1FromBool(ctx, opts.unit_diagonal) },
.{ "transpose_a", Transpose.init(ctx, opts.transpose_a).as(mlir.Attribute) },
.{ "transpose_a", Transpose.init(ctx, opts.transpose_a).asAttr() },
},
.location = location,
});
@ -596,7 +589,7 @@ pub fn fft(ctx: mlir.Context, value: mlir.Value, location: mlir.Location, opts:
.operands = &.{value},
.result_type_inference = true,
.attributes = &.{
.{ "fft_type", FftType.init(ctx, opts.kind).as(mlir.Attribute) },
.{ "fft_type", FftType.init(ctx, opts.kind).asAttr() },
.{ "fft_length", .dense(ctx, .i64, opts.length) },
},
.location = location,
@ -608,7 +601,7 @@ pub fn rng(ctx: mlir.Context, a: mlir.Value, b: mlir.Value, shape: mlir.Value, r
.operands = &.{ a, b, shape },
.result_type_inference = true,
.attributes = &.{
.{ "rng_distribution", RngDistribution.init(ctx, rng_distribution).as(mlir.Attribute) },
.{ "rng_distribution", RngDistribution.init(ctx, rng_distribution).asAttr() },
},
.location = location,
});
@ -619,7 +612,7 @@ pub fn rng_bit_generator(ctx: mlir.Context, rng_algorithm: RngAlgorithm.Type, in
.operands = &.{initial_state},
.results = &.{ res_state_type, res_type },
.attributes = &.{
.{ "rng_algorithm", RngAlgorithm.init(ctx, rng_algorithm).as(mlir.Attribute) },
.{ "rng_algorithm", RngAlgorithm.init(ctx, rng_algorithm).asAttr() },
},
.location = location,
});
@ -695,7 +688,7 @@ pub fn convolution(
) mlir.Operation {
var max_precisions: [2]mlir.Attribute = undefined;
for (opts.precision_config, 0..) |p, i| {
max_precisions[i] = PrecisionAttribute.init(ctx, p).as(mlir.Attribute);
max_precisions[i] = PrecisionAttribute.init(ctx, p).asAttr();
}
var window_reversal: [3]i32 = undefined;
for (opts.window_reversal, 0..) |w, i| {
@ -721,7 +714,7 @@ pub fn convolution(
.output_batch_dimension = opts.output_batch_dimension,
.output_feature_dimension = opts.output_feature_dimension,
.output_spatial_dimensions = opts.output_spatial_dimensions,
}).as(mlir.Attribute),
}).asAttr(),
},
.{ "feature_group_count", .int(ctx, .i64, opts.feature_group_count) },
.{ "batch_group_count", .int(ctx, .i64, opts.batch_group_count) },
@ -756,13 +749,13 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
const backend_config = opts.backend_config orelse mlir.Attribute.string(ctx, "");
if (@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi)) {
stdx.debug.assert(
backend_config.is_a(mlir.StringAttribute),
backend_config.isA(mlir.StringAttribute),
"API version < 4 requires a string as backend_config, got {}",
.{backend_config},
);
} else {
stdx.debug.assert(
backend_config.is_a(mlir.DictionaryAttribute),
backend_config.isA(mlir.DictionaryAttribute),
"API version >= 4 requires a dictionary as backend_config, got {}",
.{backend_config},
);
@ -780,7 +773,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
var output_operand_aliases: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
for (opts.output_operand_aliases) |alias| {
output_operand_aliases.appendAssumeCapacity(
OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute),
OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).asAttr(),
);
}
attrs.appendAssumeCapacity(.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) });
@ -805,7 +798,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
const operand_layouts = blk: {
var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
for (inputs) |input| {
const ranked_type = input.getType().as(mlir.RankedTensorType);
const ranked_type = input.getType().as(mlir.RankedTensorType).?;
const ol = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_type.getRank() ..];
ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol));
}
@ -824,7 +817,7 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
const result_layouts = blk: {
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
for (res_types) |t| {
const ranked_t = t.as(mlir.RankedTensorType);
const ranked_t = t.as(mlir.RankedTensorType).?;
const rl = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - ranked_t.getRank() ..];
ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl));
}
@ -846,13 +839,10 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
pub const DotDimensionNumbersAttribute = struct {
_inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(DotDimensionNumbersAttribute, .{
.is_a_fn = c.stablehloAttributeIsADotDimensionNumbers,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
pub const is_a_fn = c.stablehloAttributeIsADotDimensionNumbers;
const Self = DotDimensionNumbersAttribute;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub fn init(ctx: mlir.Context, args: struct {
lhs_batching_dimensions: []const i64,
@ -860,9 +850,9 @@ pub const DotDimensionNumbersAttribute = struct {
lhs_contracting_dimensions: []const i64,
rhs_contracting_dimensions: []const i64,
}) Self {
return Self.wrap(
c.stablehloDotDimensionNumbersGet(
ctx.inner(),
return .{
._inner = c.stablehloDotDimensionNumbersGet(
ctx._inner,
@intCast(args.lhs_batching_dimensions.len),
args.lhs_batching_dimensions.ptr,
@intCast(args.rhs_batching_dimensions.len),
@ -872,52 +862,49 @@ pub const DotDimensionNumbersAttribute = struct {
@intCast(args.rhs_contracting_dimensions.len),
args.rhs_contracting_dimensions.ptr,
),
);
};
}
pub fn getLhsBatchingDimensionsSize(self: Self) usize {
return @intCast(c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsSize(self.inner()));
return @intCast(c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsSize(self._inner));
}
pub fn getLhsBatchingDimensionsElem(self: Self, pos: usize) i64 {
return c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsElem(self.inner(), @intCast(pos));
return c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsElem(self._inner, @intCast(pos));
}
pub fn getRhsBatchingDimensionsSize(self: Self) usize {
return @intCast(c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsSize(self.inner()));
return @intCast(c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsSize(self._inner));
}
pub fn getRhsBatchingDimensionsElem(self: Self, pos: usize) i64 {
return c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsElem(self.inner(), @intCast(pos));
return c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsElem(self._inner, @intCast(pos));
}
pub fn getLhsContractingDimensionsSize(self: Self) usize {
return @intCast(c.stablehloDotDimensionNumbersGetLhsContractingDimensionsSize(self.inner()));
return @intCast(c.stablehloDotDimensionNumbersGetLhsContractingDimensionsSize(self._inner));
}
pub fn getLhsContractingDimensionsElem(self: Self, pos: usize) i64 {
return c.stablehloDotDimensionNumbersGetLhsContractingDimensionsElem(self.inner(), @intCast(pos));
return c.stablehloDotDimensionNumbersGetLhsContractingDimensionsElem(self._inner, @intCast(pos));
}
pub fn getRhsContractingDimensionsSize(self: Self) usize {
return @intCast(c.stablehloDotDimensionNumbersGetRhsContractingDimensionsSize(self.inner()));
return @intCast(c.stablehloDotDimensionNumbersGetRhsContractingDimensionsSize(self._inner));
}
pub fn getRhsContractingDimensionsElem(self: Self, pos: usize) i64 {
return c.stablehloDotDimensionNumbersGetRhsContractingDimensionsElem(self.inner(), @intCast(pos));
return c.stablehloDotDimensionNumbersGetRhsContractingDimensionsElem(self._inner, @intCast(pos));
}
};
pub const GatherDimensionNumbersAttribute = struct {
_inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(GatherDimensionNumbersAttribute, .{
.is_a_fn = c.stablehloAttributeIsAGatherDimensionNumbers,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
pub const is_a_fn = c.stablehloAttributeIsAGatherDimensionNumbers;
const Self = GatherDimensionNumbersAttribute;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub fn init(
ctx: mlir.Context,
@ -928,9 +915,9 @@ pub const GatherDimensionNumbersAttribute = struct {
start_index_map: []const i64,
index_vector_dim: i64,
) Self {
return Self.wrap(
c.stablehloGatherDimensionNumbersGet(
ctx.inner(),
return .{
._inner = c.stablehloGatherDimensionNumbersGet(
ctx._inner,
@intCast(offset_dims.len),
offset_dims.ptr,
@intCast(collapsed_slice_dims.len),
@ -943,64 +930,61 @@ pub const GatherDimensionNumbersAttribute = struct {
start_index_map.ptr,
index_vector_dim,
),
);
};
}
pub fn getOffsetDimsSize(self: Self) usize {
return @intCast(c.stablehloGatherDimensionNumbersGetOffsetDimsSize(self.inner()));
return @intCast(c.stablehloGatherDimensionNumbersGetOffsetDimsSize(self._inner));
}
pub fn getOffsetDimsElem(self: Self, pos: usize) i64 {
return c.stablehloGatherDimensionNumbersGetOffsetDimsElem(self.inner(), @intCast(pos));
return c.stablehloGatherDimensionNumbersGetOffsetDimsElem(self._inner, @intCast(pos));
}
pub fn getCollapsedSliceDimsSize(self: Self) usize {
return @intCast(c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsSize(self.inner()));
return @intCast(c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsSize(self._inner));
}
pub fn getCollapsedSliceDimsElem(self: Self, pos: usize) i64 {
return c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsElem(self.inner(), @intCast(pos));
return c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsElem(self._inner, @intCast(pos));
}
pub fn getStartIndexMapSize(self: Self) usize {
return @intCast(c.stablehloGatherDimensionNumbersGetStartIndexMapSize(self.inner()));
return @intCast(c.stablehloGatherDimensionNumbersGetStartIndexMapSize(self._inner));
}
pub fn getOperandBatchingDimsSize(self: Self) usize {
return @intCast(c.stablehloGatherDimensionNumbersGetOperandBatchingDimsSize(self.inner()));
return @intCast(c.stablehloGatherDimensionNumbersGetOperandBatchingDimsSize(self._inner));
}
pub fn getOperandBatchingDimsElem(self: Self, pos: usize) i64 {
return c.stablehloGatherDimensionNumbersGetOperandBatchingDimsElem(self.inner(), @intCast(pos));
return c.stablehloGatherDimensionNumbersGetOperandBatchingDimsElem(self._inner, @intCast(pos));
}
pub fn getStartIndicesBatchingDimsSize(self: Self) usize {
return @intCast(c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsSize(self.inner()));
return @intCast(c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsSize(self._inner));
}
pub fn getStartIndicesBatchingDimsElem(self: Self, pos: usize) i64 {
return c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsElem(self.inner(), @intCast(pos));
return c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsElem(self._inner, @intCast(pos));
}
pub fn getStartIndexMapElem(self: Self, pos: usize) i64 {
return c.stablehloGatherDimensionNumbersGetStartIndexMapElem(self.inner(), @intCast(pos));
return c.stablehloGatherDimensionNumbersGetStartIndexMapElem(self._inner, @intCast(pos));
}
pub fn getIndexVectorDim(self: Self) usize {
return @intCast(c.stablehloGatherDimensionNumbersGetIndexVectorDim(self.inner()));
return @intCast(c.stablehloGatherDimensionNumbersGetIndexVectorDim(self._inner));
}
};
pub const ConvDimensionNumbersAttribute = struct {
_inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(ConvDimensionNumbersAttribute, .{
.is_a_fn = c.stablehloAttributeIsAConvDimensionNumbers,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
pub const is_a_fn = c.stablehloAttributeIsAConvDimensionNumbers;
const Self = ConvDimensionNumbersAttribute;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub fn init(ctx: mlir.Context, args: struct {
input_batch_dimension: i64,
@ -1013,9 +997,9 @@ pub const ConvDimensionNumbersAttribute = struct {
output_feature_dimension: i64,
output_spatial_dimensions: []const i64,
}) Self {
return Self.wrap(
c.stablehloConvDimensionNumbersGet(
ctx.inner(),
return .{
._inner = c.stablehloConvDimensionNumbersGet(
ctx._inner,
args.input_batch_dimension,
args.input_feature_dimension,
@intCast(args.input_spatial_dimensions.len),
@ -1029,67 +1013,64 @@ pub const ConvDimensionNumbersAttribute = struct {
@intCast(args.output_spatial_dimensions.len),
args.output_spatial_dimensions.ptr,
),
);
};
}
pub fn getInputBatchDimension(self: Self) i64 {
return c.stablehloConvDimensionNumbersGetInputBatchDimension(self.inner());
return c.stablehloConvDimensionNumbersGetInputBatchDimension(self._inner);
}
pub fn getInputFeatureDimension(self: Self) i64 {
return c.stablehloConvDimensionNumbersGetInputFeatureDimension(self.inner());
return c.stablehloConvDimensionNumbersGetInputFeatureDimension(self._inner);
}
pub fn getInputSpatialDimensionsSize(self: Self) usize {
return @intCast(c.stablehloConvDimensionNumbersGetInputSpatialDimensionsSize(self.inner()));
return @intCast(c.stablehloConvDimensionNumbersGetInputSpatialDimensionsSize(self._inner));
}
pub fn getInputSpatialDimensionsElem(self: Self, pos: usize) i64 {
return c.stablehloConvDimensionNumbersGetInputSpatialDimensionsElem(self.inner(), @intCast(pos));
return c.stablehloConvDimensionNumbersGetInputSpatialDimensionsElem(self._inner, @intCast(pos));
}
pub fn getKernelInputFeatureDimension(self: Self) i64 {
return c.stablehloConvDimensionNumbersGetKernelInputFeatureDimension(self.inner());
return c.stablehloConvDimensionNumbersGetKernelInputFeatureDimension(self._inner);
}
pub fn getKernelOutputFeatureDimension(self: Self) i64 {
return c.stablehloConvDimensionNumbersGetKernelOutputFeatureDimension(self.inner());
return c.stablehloConvDimensionNumbersGetKernelOutputFeatureDimension(self._inner);
}
pub fn getKernelSpatialDimensionsSize(self: Self) usize {
return @intCast(c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsSize(self.inner()));
return @intCast(c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsSize(self._inner));
}
pub fn getKernelSpatialDimensionsElem(self: Self, pos: usize) i64 {
return c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsElem(self.inner(), @intCast(pos));
return c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsElem(self._inner, @intCast(pos));
}
pub fn getOutputBatchDimension(self: Self) i64 {
return c.stablehloConvDimensionNumbersGetOutputBatchDimension(self.inner());
return c.stablehloConvDimensionNumbersGetOutputBatchDimension(self._inner);
}
pub fn getOutputFeatureDimension(self: Self) i64 {
return c.stablehloConvDimensionNumbersGetOutputFeatureDimension(self.inner());
return c.stablehloConvDimensionNumbersGetOutputFeatureDimension(self._inner);
}
pub fn getOutputSpatialDimensionsSize(self: Self) usize {
return @intCast(c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsSize(self.inner()));
return @intCast(c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsSize(self._inner));
}
pub fn getOutputSpatialDimensionsElem(self: Self, pos: usize) i64 {
return c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsElem(self.inner(), @intCast(pos));
return c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsElem(self._inner, @intCast(pos));
}
};
pub const OutputOperandAliasAttribute = struct {
_inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(OutputOperandAliasAttribute, .{
.is_a_fn = c.stablehloAttributeIsAOutputOperandAlias,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
pub const is_a_fn = c.stablehloAttributeIsAOutputOperandAlias;
pub const asAttr = mlir.Attribute.fromAny(OutputOperandAliasAttribute);
pub const eql = mlir.Attribute.eqlAny(OutputOperandAliasAttribute);
pub fn init(
ctx: mlir.Context,
@ -1097,27 +1078,24 @@ pub const OutputOperandAliasAttribute = struct {
operand_index: i64,
operand_tuple_indices: []const i64,
) OutputOperandAliasAttribute {
return OutputOperandAliasAttribute.wrap(c.stablehloOutputOperandAliasGet(
ctx.inner(),
return .{ ._inner = c.stablehloOutputOperandAliasGet(
ctx._inner,
@intCast(output_tuple_indices.len),
output_tuple_indices.ptr,
@intCast(operand_index),
@intCast(operand_tuple_indices.len),
operand_tuple_indices.ptr,
));
) };
}
};
pub const PrecisionAttribute = struct {
_inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(PrecisionAttribute, .{
.is_a_fn = c.stablehloAttributeIsAPrecisionAttr,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
pub const is_a_fn = c.stablehloAttributeIsAPrecisionAttr;
const Self = PrecisionAttribute;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub const Precision = enum {
DEFAULT,
@ -1126,11 +1104,11 @@ pub const PrecisionAttribute = struct {
};
pub fn init(ctx: mlir.Context, value: Precision) Self {
return Self.wrap(c.stablehloPrecisionAttrGet(ctx.inner(), mlir.stringRef(@tagName(value))));
return .{ ._inner = c.stablehloPrecisionAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
}
pub fn getValue(self: Self) Precision {
const value = mlir.fromStringRef(c.stablehloPrecisionAttrGetValue(self.inner()));
const value = mlir.fromStringRef(c.stablehloPrecisionAttrGetValue(self._inner));
return std.meta.stringToEnum(Precision, value) orelse unreachable;
}
};
@ -1138,13 +1116,10 @@ pub const PrecisionAttribute = struct {
pub const ComparisonDirection = struct {
_inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(ComparisonDirection, .{
.is_a_fn = c.stablehloAttributeIsAComparisonDirectionAttr,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
pub const is_a_fn = c.stablehloAttributeIsAComparisonDirectionAttr;
const Self = ComparisonDirection;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub const Direction = enum {
EQ,
@ -1156,11 +1131,11 @@ pub const ComparisonDirection = struct {
};
pub fn init(ctx: mlir.Context, value: Direction) Self {
return Self.wrap(c.stablehloComparisonDirectionAttrGet(ctx.inner(), mlir.stringRef(@tagName(value))));
return .{ ._inner = c.stablehloComparisonDirectionAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
}
pub fn getValue(self: Self) Direction {
const value = mlir.fromStringRef(c.stablehloComparisonDirectionAttrGetValue(self.inner()));
const value = mlir.fromStringRef(c.stablehloComparisonDirectionAttrGetValue(self._inner));
return std.meta.stringToEnum(Direction, value) orelse unreachable;
}
};
@ -1168,13 +1143,10 @@ pub const ComparisonDirection = struct {
pub const CompareType = struct {
_inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(CompareType, .{
.is_a_fn = c.stablehloAttributeIsAComparisonTypeAttr,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
pub const is_a_fn = c.stablehloAttributeIsAComparisonTypeAttr;
const Self = CompareType;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub const Type = enum {
SIGNED,
@ -1184,11 +1156,11 @@ pub const CompareType = struct {
};
pub fn init(ctx: mlir.Context, value: Type) Self {
return Self.wrap(c.stablehloComparisonTypeAttrGet(ctx.inner(), mlir.stringRef(@tagName(value))));
return .{ ._inner = c.stablehloComparisonTypeAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
}
pub fn getValue(self: Self) Type {
const value = mlir.fromStringRef(c.stablehloComparisonTypeAttrGetValue(self.inner()));
const value = mlir.fromStringRef(c.stablehloComparisonTypeAttrGetValue(self._inner));
return std.meta.stringToEnum(Type, value) orelse unreachable;
}
};
@ -1196,13 +1168,10 @@ pub const CompareType = struct {
pub const Transpose = struct {
_inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(Transpose, .{
.is_a_fn = c.stablehloAttributeIsATransposeAttr,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
pub const is_a_fn = c.stablehloAttributeIsATransposeAttr;
const Self = Transpose;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub const Type = enum {
NO_TRANSPOSE,
@ -1211,11 +1180,11 @@ pub const Transpose = struct {
};
pub fn init(ctx: mlir.Context, value: Type) Self {
return Self.wrap(c.stablehloTransposeAttrGet(ctx.inner(), mlir.stringRef(@tagName(value))));
return .{ ._inner = c.stablehloTransposeAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
}
pub fn getValue(self: Self) Type {
const value = mlir.fromStringRef(c.stablehloTransposeAttrGetValue(self.inner()));
const value = mlir.fromStringRef(c.stablehloTransposeAttrGetValue(self._inner));
return std.meta.stringToEnum(Type, value) orelse unreachable;
}
};
@ -1223,13 +1192,10 @@ pub const Transpose = struct {
pub const FftType = struct {
_inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(FftType, .{
.is_a_fn = c.stablehloAttributeIsAFftTypeAttr,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
pub const is_a_fn = c.stablehloAttributeIsAFftTypeAttr;
const Self = FftType;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub const Type = enum {
FFT,
@ -1239,11 +1205,11 @@ pub const FftType = struct {
};
pub fn init(ctx: mlir.Context, value: Type) Self {
return Self.wrap(c.stablehloFftTypeAttrGet(ctx.inner(), mlir.stringRef(@tagName(value))));
return .{ ._inner = c.stablehloFftTypeAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
}
pub fn getValue(self: Self) Type {
const value = mlir.fromStringRef(c.stablehloFftTypeAttrGetValue(self.inner()));
const value = mlir.fromStringRef(c.stablehloFftTypeAttrGetValue(self._inner));
return std.meta.stringToEnum(Type, value) orelse unreachable;
}
};
@ -1251,13 +1217,10 @@ pub const FftType = struct {
pub const RngDistribution = struct {
_inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(RngDistribution, .{
.is_a_fn = c.stablehloAttributeIsARngDistributionAttr,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
pub const is_a_fn = c.stablehloAttributeIsARngDistributionAttr;
const Self = RngDistribution;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub const Type = enum {
UNIFORM,
@ -1265,11 +1228,11 @@ pub const RngDistribution = struct {
};
pub fn init(ctx: mlir.Context, value: Type) Self {
return Self.wrap(c.stablehloRngDistributionAttrGet(ctx.inner(), mlir.stringRef(@tagName(value))));
return .{ ._inner = c.stablehloRngDistributionAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
}
pub fn getValue(self: Self) Type {
const value = mlir.fromStringRef(c.stablehloRngDistributionAttrGetValue(self.inner()));
const value = mlir.fromStringRef(c.stablehloRngDistributionAttrGetValue(self._inner));
return std.meta.stringToEnum(Type, value) orelse unreachable;
}
};
@ -1277,13 +1240,10 @@ pub const RngDistribution = struct {
pub const RngAlgorithm = struct {
_inner: c.MlirAttribute,
pub usingnamespace mlir.MlirHelpers(RngAlgorithm, .{
.is_a_fn = c.stablehloAttributeIsARngAlgorithmAttr,
.is_null_fn = c.mlirAttributeIsNull,
.dump_fn = c.mlirAttributeDump,
.equal_fn = c.mlirAttributeEqual,
});
pub const is_a_fn = c.stablehloAttributeIsARngAlgorithmAttr;
const Self = RngAlgorithm;
pub const asAttr = mlir.Attribute.fromAny(Self);
pub const eql = mlir.Attribute.eqlAny(Self);
pub const Type = enum {
DEFAULT,
@ -1292,11 +1252,11 @@ pub const RngAlgorithm = struct {
};
pub fn init(ctx: mlir.Context, value: Type) Self {
return Self.wrap(c.stablehloRngAlgorithmAttrGet(ctx.inner(), mlir.stringRef(@tagName(value))));
return .{ ._inner = c.stablehloRngAlgorithmAttrGet(ctx._inner, mlir.stringRef(@tagName(value))) };
}
pub fn getValue(self: Self) Type {
const value = mlir.fromStringRef(c.stablehloRngAlgorithmAttrGetValue(self.inner()));
const value = mlir.fromStringRef(c.stablehloRngAlgorithmAttrGetValue(self._inner));
return std.meta.stringToEnum(Type, value) orelse unreachable;
}
};

File diff suppressed because it is too large Load Diff

View File

@ -1,176 +0,0 @@
const mlir = @This();
const builtin = @import("builtin");
const std = @import("std");
const stdx = @import("stdx");
const dtype = @import("dtype.zig");
const Shape = @import("shape.zig").Shape;
const Tensor = @import("tensor.zig").Tensor;
const log = std.log.scoped(.@"zml/mlir");
pub usingnamespace @import("mlir");
pub const ext = struct {
pub fn mlirType(ctx: mlir.Context, sh: Shape) mlir.Type {
return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype())).as(mlir.Type);
}
pub fn denseElementAttrType(dt: dtype.DataType) ?mlir.DenseElementsAttributeTypes {
return switch (dt) {
.bool => .bool,
.i8 => .i8,
.i16 => .i16,
.i32 => .i32,
.i64 => .i64,
.u8 => .u8,
.u16 => .u16,
.u32 => .u32,
.u64 => .u64,
.bf16 => .bf16,
.f16 => .f16,
.f32 => .f32,
.f64 => .f64,
else => null,
};
}
pub fn denseElementsAttr(dt: dtype.DataType, _: usize, bytes: []const u8, ranked_type: mlir.RankedTensorType) mlir.Attribute {
const ranked_type_ = ranked_type.as(mlir.Type);
return switch (dt) {
.bool => mlir.DenseElementsAttribute(.bool).init(ranked_type_, bytes).as(mlir.Attribute),
.i8 => mlir.DenseElementsAttribute(.i8).init(ranked_type_, bytes).as(mlir.Attribute),
.i16 => mlir.DenseElementsAttribute(.i16).init(ranked_type_, bytes).as(mlir.Attribute),
.i32 => mlir.DenseElementsAttribute(.i32).init(ranked_type_, bytes).as(mlir.Attribute),
.i64 => mlir.DenseElementsAttribute(.i64).init(ranked_type_, bytes).as(mlir.Attribute),
.u8 => mlir.DenseElementsAttribute(.u8).init(ranked_type_, bytes).as(mlir.Attribute),
.u16 => mlir.DenseElementsAttribute(.u16).init(ranked_type_, bytes).as(mlir.Attribute),
.u32 => mlir.DenseElementsAttribute(.u32).init(ranked_type_, bytes).as(mlir.Attribute),
.u64 => mlir.DenseElementsAttribute(.u64).init(ranked_type_, bytes).as(mlir.Attribute),
.bf16 => mlir.DenseElementsAttribute(.bf16).init(ranked_type_, bytes).as(mlir.Attribute),
.f16 => mlir.DenseElementsAttribute(.f16).init(ranked_type_, bytes).as(mlir.Attribute),
.f32 => mlir.DenseElementsAttribute(.f32).init(ranked_type_, bytes).as(mlir.Attribute),
.f64 => mlir.DenseElementsAttribute(.f64).init(ranked_type_, bytes).as(mlir.Attribute),
inline else => |tag| @panic("Unsupported data type: " ++ @tagName(tag)),
};
}
pub const RankedTensorType = struct {
pub fn fromShape(ctx: mlir.Context, sh: Shape) mlir.RankedTensorType {
return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype()));
}
};
pub const Type = struct {
pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type {
return switch (dt) {
.bool => mlir.IntegerType(.i1).init(ctx).as(mlir.Type),
.f8e4m3b11fnuz => mlir.FloatType(.f8e4m3b11fnuz).init(ctx).as(mlir.Type),
.f8e4m3fn => mlir.FloatType(.f8e4m3fn).init(ctx).as(mlir.Type),
.f8e4m3fnuz => mlir.FloatType(.f8e4m3fnuz).init(ctx).as(mlir.Type),
.f8e5m2 => mlir.FloatType(.f8e5m2).init(ctx).as(mlir.Type),
.f8e5m2fnuz => mlir.FloatType(.f8e5m2fnuz).init(ctx).as(mlir.Type),
.bf16 => mlir.FloatType(.bf16).init(ctx).as(mlir.Type),
.f16 => mlir.FloatType(.f16).init(ctx).as(mlir.Type),
.f32 => mlir.FloatType(.f32).init(ctx).as(mlir.Type),
.f64 => mlir.FloatType(.f64).init(ctx).as(mlir.Type),
.i4 => mlir.IntegerType(.i4).init(ctx).as(mlir.Type),
.i8 => mlir.IntegerType(.i8).init(ctx).as(mlir.Type),
.i16 => mlir.IntegerType(.i16).init(ctx).as(mlir.Type),
.i32 => mlir.IntegerType(.i32).init(ctx).as(mlir.Type),
.i64 => mlir.IntegerType(.i64).init(ctx).as(mlir.Type),
.u4 => mlir.IntegerType(.u4).init(ctx).as(mlir.Type),
.u8 => mlir.IntegerType(.u8).init(ctx).as(mlir.Type),
.u16 => mlir.IntegerType(.u16).init(ctx).as(mlir.Type),
.u32 => mlir.IntegerType(.u32).init(ctx).as(mlir.Type),
.u64 => mlir.IntegerType(.u64).init(ctx).as(mlir.Type),
.c64 => mlir.ComplexType(.c64).init(ctx).as(mlir.Type),
.c128 => mlir.ComplexType(.c128).init(ctx).as(mlir.Type),
};
}
pub fn toDType(mlir_type: mlir.Type) dtype.DataType {
const mapping = .{
.{ .bool, mlir.IntegerType(.i1) },
.{ .f8e4m3b11fnuz, mlir.FloatType(.f8e4m3b11fnuz) },
.{ .f8e4m3fn, mlir.FloatType(.f8e4m3fn) },
.{ .f8e4m3fnuz, mlir.FloatType(.f8e4m3fnuz) },
.{ .f8e5m2, mlir.FloatType(.f8e5m2) },
.{ .f8e5m2fnuz, mlir.FloatType(.f8e5m2fnuz) },
.{ .bf16, mlir.FloatType(.bf16) },
.{ .f16, mlir.FloatType(.f16) },
.{ .f32, mlir.FloatType(.f32) },
.{ .f64, mlir.FloatType(.f64) },
.{ .i4, mlir.IntegerType(.i4) },
.{ .i8, mlir.IntegerType(.i8) },
.{ .i16, mlir.IntegerType(.i16) },
.{ .i32, mlir.IntegerType(.i32) },
.{ .i64, mlir.IntegerType(.i64) },
.{ .u4, mlir.IntegerType(.u4) },
.{ .u8, mlir.IntegerType(.u8) },
.{ .u16, mlir.IntegerType(.u16) },
.{ .u32, mlir.IntegerType(.u32) },
.{ .u64, mlir.IntegerType(.u64) },
.{ .c64, mlir.ComplexType(.c64) },
.{ .c128, mlir.ComplexType(.c128) },
};
inline for (mapping) |entry| {
const dt, const mlirT = entry;
if (mlir_type.is_a(mlirT)) {
return dt;
}
}
stdx.debug.panic("Could not convert mlir.Type to DataType: {}", .{mlir_type});
}
};
pub const Attribute = struct {
pub fn fromData(data: dtype.Data, ctx: mlir.Context) mlir.Attribute {
switch (data) {
.bool => |val| {
return mlir.IntegerAttribute(.i1).init(ctx, @intFromBool(val)).as(mlir.Attribute);
},
inline .f8e4m3b11fnuz, .f8e4m3fn, .f8e4m3fnuz, .f8e5m2, .f8e5m2fnuz => |val, tag| {
const float_type = @field(mlir.FloatTypes, @tagName(tag));
const float_attr = mlir.FloatAttribute(float_type).init(ctx, val.toF32());
return float_attr.as(mlir.Attribute);
},
inline .i4, .i8, .i16, .i32, .i64, .u4, .u8, .u16, .u32, .u64 => |val, tag| {
const int_type = @field(mlir.IntegerTypes, @tagName(tag));
const int_attr = mlir.IntegerAttribute(int_type).init(ctx, @intCast(val));
return int_attr.as(mlir.Attribute);
},
inline else => |_, tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}),
}
}
};
pub const DenseElementsAttribute = struct {
pub fn fromData(data: dtype.Data, result_type: mlir.Type) mlir.Attribute {
return switch (data.dtype()) {
.bool => mlir.DenseElementsAttribute(.bool).init(result_type, data.constSlice()).as(mlir.Attribute),
.i8 => mlir.DenseElementsAttribute(.i8).init(result_type, data.constSlice()).as(mlir.Attribute),
.i16 => mlir.DenseElementsAttribute(.i16).init(result_type, data.constSlice()).as(mlir.Attribute),
.i32 => mlir.DenseElementsAttribute(.i32).init(result_type, data.constSlice()).as(mlir.Attribute),
.i64 => mlir.DenseElementsAttribute(.i64).init(result_type, data.constSlice()).as(mlir.Attribute),
.u8 => mlir.DenseElementsAttribute(.u8).init(result_type, data.constSlice()).as(mlir.Attribute),
.u16 => mlir.DenseElementsAttribute(.u16).init(result_type, data.constSlice()).as(mlir.Attribute),
.u32 => mlir.DenseElementsAttribute(.u32).init(result_type, data.constSlice()).as(mlir.Attribute),
.u64 => mlir.DenseElementsAttribute(.u64).init(result_type, data.constSlice()).as(mlir.Attribute),
.bf16 => mlir.DenseElementsAttribute(.bf16).init(result_type, data.constSlice()).as(mlir.Attribute),
.f16 => mlir.DenseElementsAttribute(.f16).init(result_type, data.constSlice()).as(mlir.Attribute),
.f32 => mlir.DenseElementsAttribute(.f32).init(result_type, data.constSlice()).as(mlir.Attribute),
.f64 => mlir.DenseElementsAttribute(.f64).init(result_type, data.constSlice()).as(mlir.Attribute),
inline else => |tag| stdx.debug.panic("Unsupported data type: {any}", .{tag}),
};
}
};
};

101
zml/mlirx.zig Normal file
View File

@ -0,0 +1,101 @@
const std = @import("std");
const mlir = @import("mlir");
const dtype = @import("dtype.zig");
const Shape = @import("shape.zig").Shape;
const mlirx = @This();
/// Builds the ranked-tensor mlir.Type for a zml.Shape: the shape's dims
/// become the tensor dimensions and its dtype is mapped to the MLIR
/// element type via `mlirx.Type.fromDType`.
pub fn tensorType(ctx: mlir.Context, sh: Shape) mlir.Type {
    const elem_type = mlirx.Type.fromDType(ctx, sh.dtype());
    return .tensor(sh.dims(), elem_type);
}
/// Maps a zml DataType to the matching `mlir.DenseElementsAttributeTypes`
/// tag, or null for dtypes that have no dense-elements serialization
/// (the remaining tags: i4/u4, the f8 variants, and complex types).
pub fn denseElementAttrType(dt: dtype.DataType) ?mlir.DenseElementsAttributeTypes {
    return switch (dt) {
        // The supported tags share their names across both enums, so the
        // mapping is a comptime tag-name lookup rather than 13 identical arms.
        inline .bool,
        .i8,
        .i16,
        .i32,
        .i64,
        .u8,
        .u16,
        .u32,
        .u64,
        .bf16,
        .f16,
        .f32,
        .f64,
        => |tag| @field(mlir.DenseElementsAttributeTypes, @tagName(tag)),
        else => null,
    };
}
pub const Type = struct {
    /// Returns the MLIR element type corresponding to a zml DataType.
    /// The switch is exhaustive over DataType, so this cannot fail.
    pub fn fromDType(ctx: mlir.Context, dt: dtype.DataType) mlir.Type {
        return switch (dt) {
            // bool is modeled as a 1-bit integer in MLIR.
            .bool => .int(ctx, .i1),
            .f8e4m3b11fnuz => .float(ctx, .f8e4m3b11fnuz),
            .f8e4m3fn => .float(ctx, .f8e4m3fn),
            .f8e4m3fnuz => .float(ctx, .f8e4m3fnuz),
            .f8e5m2 => .float(ctx, .f8e5m2),
            .f8e5m2fnuz => .float(ctx, .f8e5m2fnuz),
            .bf16 => .float(ctx, .bf16),
            .f16 => .float(ctx, .f16),
            .f32 => .float(ctx, .f32),
            .f64 => .float(ctx, .f64),
            .i4 => .int(ctx, .i4),
            .i8 => .int(ctx, .i8),
            .i16 => .int(ctx, .i16),
            .i32 => .int(ctx, .i32),
            .i64 => .int(ctx, .i64),
            .u4 => .int(ctx, .u4),
            .u8 => .int(ctx, .u8),
            .u16 => .int(ctx, .u16),
            .u32 => .int(ctx, .u32),
            .u64 => .int(ctx, .u64),
            .c64 => .complex(ctx, .c64),
            .c128 => .complex(ctx, .c128),
        };
    }
    /// Inverse of `fromDType`: recovers the zml DataType from an MLIR type.
    /// Panics when the MLIR type has no DataType equivalent.
    pub fn toDType(mlir_type: mlir.Type) dtype.DataType {
        // Comptime table of (DataType tag, MLIR wrapper type) pairs,
        // probed in order; the first wrapper whose predicate matches wins.
        const mapping = .{
            .{ .bool, mlir.IntegerType(.i1) },
            .{ .f8e4m3b11fnuz, mlir.FloatType(.f8e4m3b11fnuz) },
            .{ .f8e4m3fn, mlir.FloatType(.f8e4m3fn) },
            .{ .f8e4m3fnuz, mlir.FloatType(.f8e4m3fnuz) },
            .{ .f8e5m2, mlir.FloatType(.f8e5m2) },
            .{ .f8e5m2fnuz, mlir.FloatType(.f8e5m2fnuz) },
            .{ .bf16, mlir.FloatType(.bf16) },
            .{ .f16, mlir.FloatType(.f16) },
            .{ .f32, mlir.FloatType(.f32) },
            .{ .f64, mlir.FloatType(.f64) },
            .{ .i4, mlir.IntegerType(.i4) },
            .{ .i8, mlir.IntegerType(.i8) },
            .{ .i16, mlir.IntegerType(.i16) },
            .{ .i32, mlir.IntegerType(.i32) },
            .{ .i64, mlir.IntegerType(.i64) },
            .{ .u4, mlir.IntegerType(.u4) },
            .{ .u8, mlir.IntegerType(.u8) },
            .{ .u16, mlir.IntegerType(.u16) },
            .{ .u32, mlir.IntegerType(.u32) },
            .{ .u64, mlir.IntegerType(.u64) },
            .{ .c64, mlir.ComplexType(.c64) },
            .{ .c128, mlir.ComplexType(.c128) },
        };
        inline for (mapping) |entry| {
            const dt, const mlirT = entry;
            // NOTE(review): `is_a_fn` is the wrapper's raw C `IsA` predicate
            // applied to the unwrapped handle. Confirm it discriminates
            // bit width / signedness (e.g. i8 vs i16 vs u8) and not just
            // the type family — if it only checks "is an integer", the
            // first integer entry in `mapping` would match every integer
            // type. The old code went through `mlir_type.is_a(mlirT)`.
            if (mlirT.is_a_fn(mlir_type._inner)) {
                return dt;
            }
        }
        // No entry matched: unsupported MLIR type reaching zml.
        std.debug.panic("Could not convert mlir.Type to DataType: {}", .{mlir_type});
    }
};

View File

@ -2,21 +2,18 @@ const std = @import("std");
const asynk = @import("async");
const dialect = @import("mlir/dialects");
const runfiles = @import("runfiles");
const mlir = @import("mlir");
const stdx = @import("stdx");
const xla_pb = @import("//xla:xla_proto");
const BaseExe = @import("exe.zig").BaseExe;
const Buffer = @import("buffer.zig").Buffer;
const Bufferized = @import("tensor.zig").Bufferized;
const meta = @import("meta.zig");
const mlir = @import("mlir.zig");
const Location = mlir.Location;
const mlirx = @import("mlirx.zig");
const ops = @import("ops.zig");
const pjrt = @import("pjrtx.zig");
const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape;
const ShapeOf = @import("tensor.zig").ShapeOf;
const Target = @import("platform.zig").Target;
const Tensor = @import("tensor.zig").Tensor;
const Tracer = @import("tools/tracer.zig").Tracer;
@ -170,8 +167,8 @@ pub const CompilationContext = struct {
const sharding = self._platform.sharding();
const mlir_ctx = self._mlir_ctx;
module.op().setAttributeByName("mhlo.num_replicas", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_replicas).asAttr());
module.op().setAttributeByName("mhlo.num_partitions", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_partitions).asAttr());
module.op().setAttributeByName("mhlo.num_replicas", .int(mlir_ctx, .i32, sharding.num_replicas));
module.op().setAttributeByName("mhlo.num_partitions", .int(mlir_ctx, .i32, sharding.num_partitions));
const module_hash = computeModuleHash(self._platform, module);
var module_dir: ?[]const u8 = null;
@ -346,7 +343,7 @@ pub const CompilationContext = struct {
stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});
const input_types = try arena.alloc(mlir.Type, tensor_count);
for (input_types, input_shapes.items) |*t, sh| t.* = mlir.ext.mlirType(mlir_ctx, sh);
for (input_types, input_shapes.items) |*t, sh| t.* = mlirx.tensorType(mlir_ctx, sh);
const og_block_args = self._block_args;
defer {
@ -947,7 +944,7 @@ pub fn fillMlirTypes(v: anytype, mlir_ctx: mlir.Context, types: []mlir.Type) voi
var context = LocalContext{ .mlir_ctx = mlir_ctx, .types = types };
meta.visit((struct {
fn cb(inner_context: *LocalContext, tensor: *const Tensor) void {
inner_context.types[inner_context.index] = mlir.ext.mlirType(inner_context.mlir_ctx, tensor.shape());
inner_context.types[inner_context.index] = mlirx.tensorType(inner_context.mlir_ctx, tensor.shape());
inner_context.index += 1;
}
}).cb, &context, v);

View File

@ -5,7 +5,7 @@ const dialect = @import("mlir/dialects");
const Context = @import("../context.zig").Context;
const DataType = @import("../dtype.zig").DataType;
const Data = @import("../dtype.zig").Data;
const mlir = @import("../mlir.zig");
const mlirx = @import("../mlirx.zig");
const module = @import("../module.zig");
const CompilationContext = module.CompilationContext;
const SdpaOpts = @import("../nn.zig").SdpaOpts;
@ -130,7 +130,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
.api_version = .original,
},
&.{
mlir.ext.mlirType(mlir_ctx, q.shape()),
mlirx.tensorType(mlir_ctx, q.shape()),
.tensor(&.{0}, .int(mlir_ctx, .u8)),
},
loc,

View File

@ -1,24 +1,16 @@
const std = @import("std");
const assert = std.debug.assert;
const mlir = @import("mlir");
const stdx = @import("stdx");
const _collectAxes = @import("tensor.zig")._collectAxes;
const buffer = @import("buffer.zig");
const Buffer = buffer.Buffer;
const Bufferized = @import("tensor.zig").Bufferized;
const Buffer = @import("buffer.zig").Buffer;
const CompilationContext = @import("module.zig").CompilationContext;
const Context = @import("context.zig").Context;
const Data = @import("dtype.zig").Data;
const DataType = @import("dtype.zig").DataType;
const helpers = @import("helpers.zig");
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const meta = @import("meta.zig");
const mlir = @import("mlir.zig");
const module = @import("module.zig");
const CompilationContext = module.CompilationContext;
const mlirx = @import("mlirx.zig");
const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape;
const ShapeOf = @import("tensor.zig").ShapeOf;
const Tensor = @import("tensor.zig").Tensor;
const EnumLiteral = @TypeOf(.enum_literal);
@ -200,14 +192,14 @@ pub fn reduce(
mlir_ctx,
val,
inner_ctx.broadcasting_axes[0 .. tensor.rank() - inner_ctx.n_reduced],
mlir.ext.RankedTensorType.fromShape(mlir_ctx, reduced_shape).as(mlir.Type),
mlirx.tensorType(mlir_ctx, reduced_shape),
inner_ctx.loc,
);
tensor.* = Tensor._result(reduced_shape, broad_val.result(0));
inner_ctx.index += 1;
}
}).cb, &local_context, &res);
assert(local_context.index == op.numResults());
std.debug.assert(local_context.index == op.numResults());
return res;
}
@ -248,7 +240,8 @@ pub fn reduceWindow(
.{ "window_strides", .dense(ctx.mlirCtx(), .i64, opts.window_strides) },
.{ "base_dilations", .dense(ctx.mlirCtx(), .i64, opts.base_dilations) },
.{ "window_dilations", .dense(ctx.mlirCtx(), .i64, opts.window_dilations) },
.{ "padding", .denseElements(ctx.mlirCtx(), &.{ @intCast(opts.padding.len), 2 }, .i64, opts.padding) },
// @ptrCast reinterprets the [][2]i64 padding pairs as a flat []i64 view
// of the same bytes (element count adjusted accordingly); no copy is made.
.{ "padding", .denseElements(ctx.mlirCtx(), &.{ @intCast(opts.padding.len), 2 }, .i64, @ptrCast(opts.padding)) },
},
.location = loc,
});
@ -609,8 +602,8 @@ pub fn sort(
.result_type_inference = true,
.blocks = &.{block},
.attributes = &.{
.{ "dimension", mlir.IntegerAttribute(.i64).init(ctx.mlirCtx(), dimension).as(mlir.Attribute) },
.{ "is_stable", mlir.BoolAttribute.init(ctx.mlirCtx(), is_stable).as(mlir.Attribute) },
.{ "dimension", .int(ctx.mlirCtx(), .i64, dimension) },
.{ "is_stable", .boolean(ctx.mlirCtx(), is_stable) },
},
.location = loc,
});
@ -767,7 +760,7 @@ pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base
inner_ctx.index += 1;
}
}).cb, &context, &res);
assert(context.index == op.numResults());
std.debug.assert(context.index == op.numResults());
return res;
}
@ -817,7 +810,7 @@ pub fn triton(inputs: anytype, outputs: anytype, opts: TritonOps) [outputs.len]T
var res_types: [outputs.len]mlir.Type = undefined;
inline for (outputs, 0..) |output, i| {
res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output);
res_types[i] = mlirx.tensorType(ctx.mlirCtx(), output);
}
const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &.{
@ -1031,7 +1024,7 @@ pub fn scatter(
inner_ctx.index += 1;
}
}).cb, &local_context, &res);
assert(local_context.index == op.numResults());
std.debug.assert(local_context.index == op.numResults());
return res;
}
@ -1327,30 +1320,30 @@ pub fn customCall(target_name: [:0]const u8, inputs: anytype, outputs: anytype,
}
fn customCallInternal(target_name: [:0]const u8, inputs: []const Tensor, outputs: []const Shape, metadata: anytype, opts: CustomCallOptions) []Tensor {
const ctx = module.CompilationContext.current();
const ctx = CompilationContext.current();
const values = ctx.allocator().alloc(mlir.Value, inputs.len) catch unreachable;
ctx.extractValues(inputs, values);
const res_types = ctx.allocator().alloc(mlir.Type, outputs.len) catch unreachable;
for (outputs, 0..) |output, i| {
res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output);
res_types[i] = mlirx.tensorType(ctx.mlirCtx(), output);
}
const metadata_type_info = @typeInfo(@TypeOf(metadata));
var metadata_attributes_tuple: [metadata_type_info.@"struct".fields.len]mlir.AttrTuple = undefined;
inline for (metadata_type_info.@"struct".fields, 0..) |field, i| {
const attribute: mlir.Attribute = switch (@typeInfo(field.type)) {
.int, .comptime_int => mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@field(metadata, field.name))),
.int, .comptime_int => .int(ctx.mlirCtx(), .u64, @bitCast(@field(metadata, field.name))),
else => @compileError("Unsupported metadata type: " ++ @typeName(field.type)),
};
metadata_attributes_tuple[i] = .{ field.name, attribute };
}
const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &(.{
.{ "pjrt_api", mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_api))) },
.{ "pjrt_client", mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_client))) },
} ++ metadata_attributes_tuple));
const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &(metadata_attributes_tuple ++ [_]mlir.AttrTuple{
.{ "pjrt_api", .int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_api))) },
.{ "pjrt_client", .int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_client))) },
}));
const operands_layouts = ctx.allocator().alloc([]const usize, inputs.len) catch unreachable;
for (inputs, 0..) |input, i| {

View File

@ -1,20 +1,17 @@
const std = @import("std");
const assert = std.debug.assert;
const testing = std.testing;
const builtin = @import("builtin");
const mlir = @import("mlir");
const stdx = @import("stdx");
const Buffer = @import("buffer.zig").Buffer;
const CompilationContext = @import("module.zig").CompilationContext;
const Data = @import("dtype.zig").Data;
const DataType = @import("dtype.zig").DataType;
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const Memory = @import("buffer.zig").Buffer.Memory;
const meta = @import("meta.zig");
const mlir = @import("mlir.zig");
const Location = mlir.Location;
const module = @import("module.zig");
const CompilationContext = module.CompilationContext;
const mlirx = @import("mlirx.zig");
const ops = @import("ops.zig");
const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape;
@ -112,12 +109,12 @@ pub const Tensor = struct {
///
/// The shape is derived from the type of the mlir.Value.
pub fn fromMlirValue(val: mlir.Value) Tensor {
const ranked_tensor = val.getType().as(mlir.RankedTensorType);
const ranked_tensor = val.getType().as(mlir.RankedTensorType).?;
const n = ranked_tensor.getRank();
stdx.debug.assert(n <= MAX_RANK, "Can't represent MLIR tensor of rank {}, max supported rank is {}.", .{ n, MAX_RANK });
var sh: Shape = .{ ._dtype = mlir.ext.Type.toDType(ranked_tensor.getElementType()) };
var sh: Shape = .{ ._dtype = mlirx.Type.toDType(ranked_tensor.getElementType()) };
for (0..n) |i| {
sh._dims.appendAssumeCapacity(ranked_tensor.getDimension(i));
}
@ -322,7 +319,7 @@ pub const Tensor = struct {
const op = dialect.stablehlo.bitcast_convert(
self.getContext().mlirCtx(),
self.value(),
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type),
mlirx.tensorType(self.getContext().mlirCtx(), res_shape),
loc,
);
@ -559,8 +556,8 @@ pub const Tensor = struct {
ctx.mlirCtx(),
self.algorithm,
self._state.value(),
mlir.ext.mlirType(ctx.mlirCtx(), self._state._shape),
mlir.ext.mlirType(ctx.mlirCtx(), sh),
mlirx.tensorType(ctx.mlirCtx(), self._state._shape),
mlirx.tensorType(ctx.mlirCtx(), sh),
loc,
);
return .{ self.update(op.result(0)), _result(sh, op.result(1)) };
@ -870,7 +867,7 @@ pub const Tensor = struct {
self.value(),
other.value(),
used_opts,
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape).as(mlir.Type),
mlirx.tensorType(self.getContext().mlirCtx(), new_shape),
loc,
);
@ -1052,7 +1049,7 @@ pub const Tensor = struct {
const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) });
const mlir_ctx = self.getContext().mlirCtx();
const res_type = mlir.ext.mlirType(mlir_ctx, self.shape().withDtype(to));
const res_type = mlirx.tensorType(mlir_ctx, self.shape().withDtype(to));
const op = dialect.stablehlo.convert(mlir_ctx, self.value(), res_type, loc);
return _result(self._shape.withDtype(to), op.result(0));
}
@ -1217,7 +1214,7 @@ pub const Tensor = struct {
mlir_ctx,
lhs.value(),
rhs.value(),
mlir.ext.mlirType(mlir_ctx, res_shape),
mlirx.tensorType(mlir_ctx, res_shape),
loc,
.{
.lhs_batching_dimensions = lhs_batching_axes.constSlice(),
@ -1392,7 +1389,7 @@ pub const Tensor = struct {
[2][5]f32{ .{ 0, 1, 1, 0, 1 }, .{ 3, 1, 0, 2, 1 } },
);
const res = try zml.testing.compileAndCall(platform, Local._cumsum, .{x});
try testing.expectEqual(
try std.testing.expectEqual(
[2][5]f32{ .{ 0, 1, 2, 2, 3 }, .{ 3, 4, 4, 6, 7 } },
try res.getValue([2][5]f32),
);
@ -1424,7 +1421,7 @@ pub const Tensor = struct {
const op = dialect.stablehlo.transpose(
self.getContext().mlirCtx(),
self.value(),
mlir.ext.mlirType(self.getContext().mlirCtx(), res_shape),
mlirx.tensorType(self.getContext().mlirCtx(), res_shape),
loc,
.{ .permutation = toI64(permutation) },
);
@ -1457,7 +1454,7 @@ pub const Tensor = struct {
const reshaped_val = dialect.stablehlo.reshape(
self.getContext().mlirCtx(),
self.value(),
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape),
mlirx.tensorType(self.getContext().mlirCtx(), new_shape),
loc,
);
return _result(new_shape, reshaped_val.result(0));
@ -1474,7 +1471,7 @@ pub const Tensor = struct {
const reshaped_val = dialect.stablehlo.reshape(
self.getContext().mlirCtx(),
self.value(),
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape),
mlirx.tensorType(self.getContext().mlirCtx(), new_shape),
loc,
);
return _result(new_shape, reshaped_val.result(0));
@ -1512,7 +1509,7 @@ pub const Tensor = struct {
const reshaped_val = dialect.stablehlo.reshape(
self.getContext().mlirCtx(),
self.value(),
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), new_shape),
mlirx.tensorType(self.getContext().mlirCtx(), new_shape),
loc,
);
// log.debug("flatten({d}, {d}) -> {d}", .{ self.dims(), axis_, new_shape[0 .. self.rank() - 1] });
@ -1586,7 +1583,7 @@ pub const Tensor = struct {
const mlir_ctx = self.getContext().mlirCtx();
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "slices={any}", .{slices});
const result_type = mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type);
const result_type = mlirx.tensorType(mlir_ctx, res_shape);
const slice_op = dialect.stablehlo.slice(
mlir_ctx,
self.value(),
@ -1620,15 +1617,15 @@ pub const Tensor = struct {
{
const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, 0, .{ .end = 1 } });
try testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32));
try std.testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32));
}
{
const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, 1, .{ .start = 1, .step = 2 } });
try testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32));
try std.testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32));
}
{
const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, -1, .{ .start = -2 } });
try testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32));
try std.testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32));
}
}
@ -1838,7 +1835,7 @@ pub const Tensor = struct {
const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable;
const sh = Shape.init(.{n_steps}, dt);
var op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlir.ext.mlirType(ctx.mlirCtx(), sh), loc);
var op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlirx.tensorType(ctx.mlirCtx(), sh), loc);
var res = _result(sh, op.result(0));
if (args.step != 1) {
@ -1868,7 +1865,7 @@ pub const Tensor = struct {
var op = dialect.stablehlo.iota(
mlir_ctx,
a,
mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).as(mlir.Type),
mlirx.tensorType(mlir_ctx, res_shape),
loc,
);
return _result(res_shape, op.result(0));
@ -1890,7 +1887,7 @@ pub const Tensor = struct {
const loc = ctx.location(@src(), "linspace({}, dtype={})", .{ args, dt });
const sh = Shape.init(.{args.steps}, dt);
var iota_op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlir.ext.mlirType(ctx.mlirCtx(), sh), loc);
var iota_op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlirx.tensorType(ctx.mlirCtx(), sh), loc);
var res = _result(sh, iota_op.result(0));
if (args.steps != 1) {
@ -1933,21 +1930,19 @@ pub const Tensor = struct {
/// Returns a constant Tensor with the given value.
pub fn constant(dimz: anytype, val: Data) Tensor {
const sh = Shape.init(dimz, val.dtype());
const singleton_sh = Shape.init(.{}, val.dtype());
const ctx = CompilationContext.current().mlirCtx();
const loc = CompilationContext.current().location(@src(), "dims={d}, value={}", .{ sh, val });
const res_type = mlir.ext.RankedTensorType.fromShape(ctx, singleton_sh);
var constant_op = if (mlir.ext.denseElementAttrType(val.dtype())) |elem_type|
dialect.stablehlo.constant(ctx, res_type, elem_type, val.constSlice(), loc)
var constant_op = if (mlirx.denseElementAttrType(val.dtype())) |elem_type|
dialect.stablehlo.constant(ctx, &.{}, elem_type, val.constSlice(), loc)
else blk: {
// Not all dtype can be serialized in the IR. If that's not possible, use f32.
const val_f32 = val.as(f32);
break :blk dialect.stablehlo.constant(ctx, res_type, .f32, std.mem.asBytes(&val_f32), loc);
break :blk dialect.stablehlo.constant(ctx, &.{}, .f32, std.mem.asBytes(&val_f32), loc);
};
if (sh.rank() > 0) {
constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type), loc);
constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlirx.tensorType(ctx, sh), loc);
}
return _result(sh, constant_op.result(0)).convert(val.dtype());
}
@ -1955,10 +1950,9 @@ pub const Tensor = struct {
/// Embeds a buffer with concrete values into an Mlir program.
pub fn constantTensor(val: HostBuffer) Tensor {
const ctx = CompilationContext.current().mlirCtx();
const result_type = mlir.ext.RankedTensorType.fromShape(ctx, val.shape());
const loc = ctx.location(@src());
const elem_type = mlir.ext.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()});
const constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.bytes(), loc);
const elem_type = mlirx.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()});
const constant_op = dialect.stablehlo.constant(ctx, val.shape().dims(), elem_type, val.bytes(), loc);
return _result(val.shape(), constant_op.result(0));
}
@ -1994,7 +1988,7 @@ pub const Tensor = struct {
return _result(res_shape, self.value());
}
const ctx = self.getContext();
const result_type = mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type);
const result_type = mlirx.tensorType(ctx.mlirCtx(), res_shape);
const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ });
const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc);
@ -2052,7 +2046,7 @@ pub const Tensor = struct {
/// Reshapes the input Tensor with the given shape.
pub fn reshape(self: Tensor, output_shape_: anytype) Tensor {
const output_shape = self._shape.reshape(output_shape_);
const tensor_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), output_shape);
const tensor_type = mlirx.tensorType(self.getContext().mlirCtx(), output_shape);
const loc = self.getContext().location(@src(), "reshape({any})", .{output_shape});
const reshape_value = dialect.stablehlo.reshape(self.getContext().mlirCtx(), self.value(), tensor_type, loc);
return _result(output_shape, reshape_value.result(0));
@ -2846,9 +2840,9 @@ pub const Tensor = struct {
const res = argmax.call(.{x});
const max_ = res.values.getValue(f32);
const max_idx = res.indices.getValue(i32);
try testing.expectEqual(max_, 7.9);
try std.testing.expectEqual(max_, 7.9);
// We should always return the first max found.
try testing.expectEqual(max_idx, 2);
try std.testing.expectEqual(max_idx, 2);
}
// Test with Nan
@ -2857,8 +2851,8 @@ pub const Tensor = struct {
const res = argmax.call(.{x});
const max_ = try res.values.getValue(f32);
const max_idx = try res.indices.getValue(i32);
try testing.expect(std.math.isNan(max_));
try testing.expectEqual(max_idx, 1);
try std.testing.expect(std.math.isNan(max_));
try std.testing.expectEqual(max_idx, 1);
}
}
@ -2907,7 +2901,7 @@ pub const Tensor = struct {
const x = try zml.Buffer.fromSlice(platform, .{ 2, 5 }, &[_]f32{ -0.9264, 0.7156, 1.0202, 0.3992, 1.2349, 1.0003, -0.1932, 1.3935, 0.7316, 0.0851 });
const res = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 1, .{} });
const res_cpu = try res.toHostAlloc(allocator);
try testing.expectEqualSlices(i32, &.{ 0, 3, 1, 2, 4, 1, 4, 3, 0, 2 }, res_cpu.items(i32));
try std.testing.expectEqualSlices(i32, &.{ 0, 3, 1, 2, 4, 1, 4, 3, 0, 2 }, res_cpu.items(i32));
}
// 3D Tensor, dim = 1, descending
{
@ -2920,7 +2914,7 @@ pub const Tensor = struct {
});
const res_dev = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 1, .{ .descending = true } });
const res = try res_dev.toHostAlloc(allocator);
try testing.expectEqualSlices(i32, &.{
try std.testing.expectEqualSlices(i32, &.{
4, 1, 1, 2, 0, 2, 0, 0, 3, 4,
2, 0, 4, 4, 1, 3, 4, 4, 1, 0,
1, 4, 2, 0, 2, 4, 2, 2, 0, 3,
@ -2942,7 +2936,7 @@ pub const Tensor = struct {
});
const res_dev = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 3, .{} });
const res = try res_dev.toHostAlloc(allocator);
try testing.expectEqualSlices(i32, &.{
try std.testing.expectEqualSlices(i32, &.{
2, 1, 3, 0,
2, 3, 1, 0,
3, 2, 0, 1,
@ -3262,7 +3256,7 @@ pub const Tensor = struct {
const z = try zml.Buffer.scalar(platform, 4, .i32);
const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 0, .{ .len = 2, .start = z } });
try testing.expectEqual([2]T{ 4, 5 }, try res.getValue([2]T));
try std.testing.expectEqual([2]T{ 4, 5 }, try res.getValue([2]T));
}
{
@ -3271,7 +3265,7 @@ pub const Tensor = struct {
const z = try zml.Buffer.scalar(platform, 3, .i32);
const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 1, .{ .len = 2, .start = z } });
try testing.expectEqual([4]T{ 3, 4, 8, 9 }, res.getValue([4]T));
try std.testing.expectEqual([4]T{ 3, 4, 8, 9 }, res.getValue([4]T));
}
}
@ -3389,7 +3383,7 @@ pub const Tensor = struct {
}._fwd,
.{ x.withTags(.{.a}), .{ .a = idx }, y.withTags(.{.a}) },
);
try testing.expectEqual([10]f32{ 0, 1, 2, 3, -1, -1, 6, 7, 8, 9 }, try res.getValue([10]f32));
try std.testing.expectEqual([10]f32{ 0, 1, 2, 3, -1, -1, 6, 7, 8, 9 }, try res.getValue([10]f32));
}
{
@ -3407,7 +3401,7 @@ pub const Tensor = struct {
}._fwd,
.{ x.withTags(.{ .a, .b }), idx, y.withTags(.{.a}) },
);
try testing.expectEqualDeep(
try std.testing.expectEqualDeep(
[2][5]f32{ .{ 0, 1, 2, -1, 4 }, .{ 5, 6, 7, -1, 9 } },
try res.getValue([2][5]f32),
);
@ -3427,7 +3421,7 @@ pub const Tensor = struct {
}._fwd,
.{ x, idx, y },
);
try testing.expectEqualDeep(
try std.testing.expectEqualDeep(
[2][5]f32{ .{ 0, 1, 2, -1, 4 }, .{ 5, 6, 7, -1, 9 } },
res.getValue([2][5]f32),
);
@ -3448,7 +3442,7 @@ pub const Tensor = struct {
}._fwd,
.{ x.withTags(.{ .a, .b }), .{ .a = idx_a, .b = idx_b }, y.withTags(.{.a}) },
);
try testing.expectEqualDeep(
try std.testing.expectEqualDeep(
[2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, -1, 9 } },
res.getValue([2][5]f32),
);
@ -3466,7 +3460,7 @@ pub const Tensor = struct {
}
};
const res = try zml.testing.compileAndCall(platform, A._fwd, .{ x, .{ idx_a, idx_b }, y });
try testing.expectEqualDeep(
try std.testing.expectEqualDeep(
[2][5]f32{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, -1, 9 } },
res.getValue([2][5]f32),
);
@ -3531,7 +3525,7 @@ pub const Tensor = struct {
const x = try zml.Buffer.fromArray(platform, [2][2]u8{ .{ 1, 2 }, .{ 3, 4 } });
{
const res = try zml.testing.compileAndCall(platform, Local._toDiag, .{x});
try testing.expectEqual(
try std.testing.expectEqual(
[2][2][2]u8{ .{
.{ 1, 0 },
.{ 0, 2 },
@ -3582,7 +3576,7 @@ pub const Tensor = struct {
});
{
const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, 0 });
try testing.expectEqual(
try std.testing.expectEqual(
[3][3]u8{
.{ 1, 0, 0 },
.{ 1, 1, 0 },
@ -3593,7 +3587,7 @@ pub const Tensor = struct {
}
{
const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, 1 });
try testing.expectEqual(
try std.testing.expectEqual(
[3][3]u8{
.{ 1, 1, 0 },
.{ 1, 1, 1 },
@ -3604,7 +3598,7 @@ pub const Tensor = struct {
}
{
const res = try zml.testing.compileAndCall(platform, Local._tri, .{ x, -1 });
try testing.expectEqual(
try std.testing.expectEqual(
[3][3]u8{
.{ 0, 0, 0 },
.{ 1, 0, 0 },

View File

@ -25,7 +25,7 @@ pub const nn = @import("nn.zig");
pub const module = @import("module.zig");
pub const meta = @import("meta.zig");
pub const platform = @import("platform.zig");
pub const mlir = @import("mlir.zig");
pub const mlir = @import("mlirx.zig");
pub const pjrt = @import("pjrtx.zig");
pub const testing = @import("testing.zig");
pub const torch = @import("torch.zig");