Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,7 @@
"ts"
],
"rust-analyzer.cargo.features": ["connections", "vendored-openssl"],
"rust-analyzer.linkedProjects": [
"rs/Cargo.toml"
],
}
2 changes: 1 addition & 1 deletion cs/src/Contracts/DevTunnels.Contracts.csproj
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<RootNamespace>Microsoft.DevTunnels.Contracts</RootNamespace>
Expand Down
54 changes: 47 additions & 7 deletions cs/tools/TunnelsSDK.Generator/RustContractWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -188,16 +188,33 @@ private void WriteInterfaceContract(

s.Append(FormatDocComment(type.GetDocumentationCommentXml(), ""));
s.Append("#[derive(Clone, Debug, Deserialize, Serialize");
if (DefaultDerivers.Contains(rsName))
// Add Default for types in the explicit list, and for types that will be
// embedded into their base type via #[serde(flatten)].
var willBeEmbedded = type.BaseType?.ToString() is string bt &&
bt.StartsWith(this.csNamespace) &&
allTypes.Any((t) => SymbolEqualityComparer.Default.Equals(t.BaseType, type.BaseType));
if (DefaultDerivers.Contains(rsName) || willBeEmbedded)
{
s.Append(", Default");
}
s.AppendLine(")]");
s.AppendLine("#[serde(rename_all(serialize = \"camelCase\", deserialize = \"camelCase\"))]");
s.Append($"pub struct {rsName} {{");

// Check if this type has derived types. If so, we embed the derived types
// into this base type (like Go's struct embedding) rather than having
// derived types embed the base.
var derivedTypes = allTypes.Where(
(t) => SymbolEqualityComparer.Default.Equals(t.BaseType, type)).ToArray();

var fullBaseType = type.BaseType?.ToString();
if (fullBaseType != null && fullBaseType.StartsWith(this.csNamespace))
// Only add #[serde(flatten)] pub base if the base type does NOT embed
// derived types. When a base type has derived types, it embeds them
// (like Go's struct embedding), so derived types must not embed back.
var baseEmbedsDerived = fullBaseType != null &&
fullBaseType.StartsWith(this.csNamespace) &&
allTypes.Any((t) => SymbolEqualityComparer.Default.Equals(t.BaseType, type.BaseType));
if (fullBaseType != null && fullBaseType.StartsWith(this.csNamespace) && !baseEmbedsDerived)
{
var rsBaseType = fullBaseType.Substring(this.csNamespace.Length + 1);
s.AppendLine();
Expand All @@ -206,6 +223,10 @@ private void WriteInterfaceContract(
imports.Add($"crate::contracts::{rsBaseType}");
}

// A type is "embedded" if its base type embeds it via #[serde(flatten)].
// In that case, all fields must tolerate missing values in JSON.
var isEmbeddedType = baseEmbedsDerived;

var properties = type.GetMembers()
.OfType<IPropertySymbol>()
.Where((p) => !p.IsStatic)
Expand All @@ -216,7 +237,19 @@ private void WriteInterfaceContract(
{
s.AppendLine();
s.Append(FormatDocComment(property.GetDocumentationCommentXml(), " "));
AppendStructProperty(type, property, imports, s);
AppendStructProperty(type, property, imports, s, isEmbeddedType);
}

// Embed derived types via #[serde(flatten)], similar to Go's struct
// embedding. This allows the base type to deserialize fields from all
// derived types.
foreach (var derivedType in derivedTypes.OrderBy((t) => t.Name))
{
var fieldName = ToSnakeCase(derivedType.Name);
s.AppendLine();
s.AppendLine(" #[serde(flatten)]");
s.AppendLine($" pub {fieldName}: {derivedType.Name},");
imports.Add($"crate::contracts::{derivedType.Name}");
}

s.AppendLine("}");
Expand Down Expand Up @@ -394,7 +427,7 @@ private string FormatDocComment(string? comment, string prefix)

return s.ToString();
}
private void AppendStructProperty(ITypeSymbol parentType, IPropertySymbol property, SortedSet<string> imports, StringBuilder s)
private void AppendStructProperty(ITypeSymbol parentType, IPropertySymbol property, SortedSet<string> imports, StringBuilder s, bool isEmbeddedType = false)
{
var csType = property.Type.ToString();
var isNullable = csType.EndsWith("?");
Expand All @@ -418,7 +451,9 @@ private void AppendStructProperty(ITypeSymbol parentType, IPropertySymbol proper
if (isArray)
{
csType = csType.Substring(0, csType.Length - 2);
if (isNullable || ignoreWhenDefault)
// When a type is embedded (flattened) into a base type, all array
// fields must default to empty since they may not be present in JSON.
if (isNullable || ignoreWhenDefault || isEmbeddedType)
{
serdeDeclarations.Add("skip_serializing_if = \"Vec::is_empty\"");
serdeDeclarations.Add("default");
Expand All @@ -432,6 +467,11 @@ private void AppendStructProperty(ITypeSymbol parentType, IPropertySymbol proper
serdeDeclarations.Add("default");
}

if (isNullable)
{
serdeDeclarations.Add("skip_serializing_if = \"Option::is_none\"");
}

if (serdeDeclarations.Count > 0)
{
s.AppendLine($" #[serde({string.Join(", ", serdeDeclarations)})]");
Expand Down Expand Up @@ -472,7 +512,7 @@ private void AppendStructProperty(ITypeSymbol parentType, IPropertySymbol proper
"long" => "i64",
"ulong" => "u64",
"string" => "String",
"System.DateTime" => "DateTime<Utc>",
"System.DateTime" => "Timestamp",
"System.Text.RegularExpressions.Regex" => "regexp.Regexp",
"System.Collections.Generic.IDictionary<string, string>"
=> "HashMap<String, String>",
Expand All @@ -494,7 +534,7 @@ private void AppendStructProperty(ITypeSymbol parentType, IPropertySymbol proper

if (csType == "System.DateTime")
{
imports.Add("chrono::{DateTime, Utc}");
imports.Add("jiff::Timestamp");
}
else if (csType.Contains("IDictionary<"))
{
Expand Down
Loading
Loading